├── .gitignore ├── LICENSE ├── README.md └── online_softmax_torch.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 dhcode95 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # online-softmax 2 | simplest online-softmax notebook for explaining Flash Attention 3 | 4 | Blog link: **[手撕Online-Softmax](https://zhuanlan.zhihu.com/p/5078640012)** 5 | 6 | ## Implementation 7 | 8 | run `online_softmax_torch.ipynb` 9 | 10 | we show the block online softmax result 11 | 12 | ```python 13 | X = torch.tensor([-0.3, 0.2, 0.5, 0.7, 0.1, 0.8]) 14 | X_softmax = F.softmax(X, dim = 0) 15 | print(X_softmax) 16 | 17 | X_block = torch.split(X, split_size_or_sections = 3 , dim = 0) 18 | 19 | # we parallel compute different block max & sum 20 | X_block_0_max = X_block[0].max() 21 | X_block_0_sum = torch.exp(X_block[0] - X_block_0_max).sum() 22 | 23 | X_block_1_max = X_block[1].max() 24 | X_block_1_sum = torch.exp(X_block[1] - X_block_1_max).sum() 25 | 26 | # online block update max & sum 27 | X_block_1_max_update = torch.max(X_block_0_max, X_block_1_max) # X_block[1] is the newly arrived block 28 | X_block_1_sum_update = X_block_0_sum * torch.exp(X_block_0_max - X_block_1_max_update) \ 29 | + torch.exp(X_block[1] - X_block_1_max_update).sum() # block sum 30 | 31 | X_block_online_softmax = torch.exp(X - X_block_1_max_update) / X_block_1_sum_update 32 | print(X_block_online_softmax) 33 | ``` 34 | 35 | output is 36 | 37 | ``` 38 | tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485]) 39 | tensor([0.0827, 0.1364,
0.1841, 0.2249, 0.1234, 0.2485]) 40 | ``` 41 | 42 | ## Softmax Series 43 | 44 | ### softmax 45 | 46 | $$ 47 | \tilde{x}_i=\frac{e^{x_i}}{\sum_j^Ne^{x_j}} 48 | $$ 49 | 50 | ### safe softmax 51 | 52 | $$ 53 | \tilde{x}_i=\frac{e^{x_i-\max(x_{:N})}}{\sum_j^Ne^{x_j-\max(x_{:N})}} 54 | $$ 55 | 56 | Note $M=\max(x_{:N})$, so 57 | $$ 58 | \begin{align} 59 | \tilde{x}_i &=\frac{e^{x_i-\max(x_{:N})}}{\sum_j^Ne^{x_j-\max(x_{:N})}}\\ 60 | &=\frac{e^{x_i-M}}{\sum_j^Ne^{x_j-M}}\\ 61 | &=\frac{e^{x_i}/e^{M}}{\sum_j^Ne^{x_j}/e^{M}} \\ 62 | &=\frac{e^{x_i}}{\sum_j^Ne^{x_j}} \\ 63 | \end{align} 64 | $$ 65 | 66 | 67 | ### online softmax 68 | 69 | 1. We first compute the `1:N` element maximum value $\max(x_{:N})$ and the softmax denominator $l_N$ 70 | 71 | 2. We add a new element $x_{N+1}$, then update $\max(x_{:N+1})$ and $l_{N+1}$ as follows. 72 | 73 | $$ 74 | \begin{align} 75 | l_{N} &= \sum_j^N e^{x_j-\max(x_{:N})}\\ 76 | \max(x_{:N+1})&=\max( \max(x_{:N}), x_{N+1} )\\ 77 | l_{N+1} &= \sum_j^{N+1} e^{x_j-\max(x_{:N+1})} \\ 78 | &= (\sum_j^N e^{x_j-\max(x_{:N+1})}) +e^{x_{N+1}-\max(x_{:N+1})} \\ 79 | &=(\sum_j^N e^{x_j-\max(x_{:N})}e^{\max(x_{:N})-\max(x_{:N+1})})+e^{x_{N+1}-\max(x_{:N+1})} \\ 80 | &=(\sum_j^N e^{x_j-\max(x_{:N})})(e^{\max(x_{:N})-\max(x_{:N+1})}) +e^{x_{N+1}-\max(x_{:N+1})} \\ 81 | &=l_N (e^{\max(x_{:N})-\max(x_{:N+1})})+e^{x_{N+1}-\max(x_{:N+1})} \\ 82 | \end{align} 83 | $$ 84 | 85 | we cannot use $l_{N+1}=l_{N}+x_{N+1}$, because safe softmax needs every element to subtract the same maximum value. 86 | 87 | 3. We can apply the softmax function using the adjusted numerator and denominator values. 88 | 89 | $$ 90 | \tilde{x}_{i}=\frac{e^{x_i-\max(x_{:N+1})}}{l_{N+1}} 91 | $$ 92 | 93 | ### block online softmax 94 | 95 | online softmax makes the cumulative sum $l$ update dynamically as each new element is added. A more efficient method is to update the sum $l$ block-wise, as whole blocks of elements are added. The advantage is that we can compute online softmax in parallel 96 | 97 | 1.
we separately compute each block's $l^{(t)}$ and $m^{(t)}$ 98 | 99 | $$ 100 | \begin{align} 101 | l^{(1)} &= l_{N} = \sum_j^N e^{x_j-\max(x_{:N})}\\ 102 | m^{(1)} &= \max(x_{:N}) \\ 103 | l^{(2)} &= l_{N+1:2N} = \sum_{j=N+1}^{2N} e^{x_j-\max(x_{{N+1}:2N})}\\ 104 | m^{(2)} &= \max(x_{N+1:2N}) \\ 105 | \end{align} 106 | $$ 107 | 108 | 2. it is easy to update the global $m,l$ 109 | $$ 110 | \begin{align} 111 | m=\max({x_{:2N}})&=\max(\max({x_{:N}}),\max(x_{N+1:2N}))\\ 112 | &=\max(m^{(1)},m^{(2)}) 113 | \end{align} 114 | $$ 115 | but $l$ can NOT be updated as follows: 116 | $$ 117 | l=l_{:2N}\neq l^{(1)}+l^{(2)} 118 | $$ 119 | 120 | 3. So we use the block sums $l^{(t)}$ and block maxima $m^{(t)}$ to **online**-update the global $l$ 121 | 122 | $$ 123 | \begin{align} 124 | l^{(1)}&= \sum_j^N e^{x_j-\max(x_{:N})} = \sum_j^N e^{x_j-m^{(1)}}\\ 125 | l^{(2)} &= \sum_{j=N+1}^{2N} e^{x_j-\max(x_{{N+1}:2N})} = \sum_{j=N+1}^{2N} e^{x_j-m^{(2)}}\\ 126 | l &= \sum_{j}^{2N} e^{x_j-\max(x_{:2N})} \\ 127 | &= (\sum_j^N e^{x_j-\max(x_{:2N})}) +(\sum_{j=N+1}^{2N}e^{x_j-\max(x_{:2N})}) \\ 128 | &= (\sum_j^N e^{x_j-m}) +(\sum_{j=N+1}^{2N}e^{x_j-m}) \\ 129 | &= (\sum_j^N e^{x_j-m^{(1)}}) (e^{m^{(1)}-m}) +(\sum_{j=N+1}^{2N}e^{x_j-m^{(2)}})(e^{m^{(2)}-m}) \\ 130 | &= l^{(1)} (e^{m^{(1)}-m}) +l^{(2)}(e^{m^{(2)}-m}) 131 | \end{align} 132 | $$ 133 | 134 | 4.
update the block softmax as: 135 | 136 | $$ 137 | \tilde{x}_{i} =\frac{e^{x_i-m}}{l} 138 | $$ 139 | 140 | ### multi block online softmax 141 | 142 | we do multi block online softmax with a for-loop: 143 | $$ 144 | l_\text{new}= l_\text{old} (e^{m_\text{old}-m}) +l_\text{new}(e^{m_{\text{new}}-m}) 145 | $$ 146 | denote the current block max/sum as $m_\text{new},l_\text{new}$; here $m=\max(m_\text{old},m_\text{new})$, and then update: 147 | $$ 148 | m_\text{old} \leftarrow m,\quad l_\text{old} \leftarrow l_\text{new} 149 | $$ 150 | 151 | ### batch online softmax 152 | 153 | In the attention mechanism, we need a softmax over the attention score matrix 154 | $$ 155 | S=QK^T,S\in\mathbb{R}^{N\times N} 156 | $$ 157 | the query is a row-wise matrix $Q\in\mathbb{R}^{N\times D}$; 158 | 159 | and we need the softmax of the attention score: 160 | $$ 161 | P_{i,:}=\text{softmax}(S_{i,:}) 162 | $$ 163 | when we use online-softmax, we can update the k-row max $M^{(t)}$ and row-wise sum $L^{(t)}$ in parallel, 164 | $$ 165 | L = L^{(1)}(e^{M^{(1)}-M})+L^{(2)}(e^{M^{(2)}-M}) 166 | $$ 167 | where $L,M\in\mathbb{R}^{k\times 1}$ 168 | 169 | ## Reference 170 | 171 | [手撕Flash Attention](https://zhuanlan.zhihu.com/p/663932651) 172 | 173 | [Online normalizer calculation for softmax](https://arxiv.org/abs/1805.02867) 174 | 175 | -------------------------------------------------------------------------------- /online_softmax_torch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6759a9b1-a71f-4f2d-b6f7-881d1c580261", 6 | "metadata": {}, 7 | "source": [ 8 | "# Online Softmax" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "e9a81889-4e32-410f-9942-ae1e50197e73", 14 | "metadata": {}, 15 | "source": [ 16 | "github: xiaodongguaAIGC\n", 17 | "\n", 18 | "- softmax\n", 19 | "- Safe Softmax\n", 20 | "- online softmax\n", 21 | "- block online softmax\n", 22 | "- multi block online softmax\n", 23 | "- batch online softmax\n", 24 | "- multi block
batch online softmax" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "id": "ab5f591c-ac6d-4ddc-b7ef-b93d9e2187af", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import torch\n", 35 | "import torch.nn.functional as F" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "id": "c7e9c2d6-63ff-4e55-ad3a-f8bec54c8fb1", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "X = torch.tensor([-0.3, 0.2, 0.5, 0.7, 0.1, 0.8])" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "0e28e1fc-6bc2-4915-b2a5-04dc1b2b80dc", 51 | "metadata": {}, 52 | "source": [ 53 | "## Softmax By Torch" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "bd32b22f-3dbe-45d2-ada1-74c13d3afe98", 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "data": { 64 | "text/html": [ 65 | "
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
 66 |        "
\n" 67 | ], 68 | "text/plain": [ 69 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 70 | ] 71 | }, 72 | "metadata": {}, 73 | "output_type": "display_data" 74 | } 75 | ], 76 | "source": [ 77 | "X_softmax = F.softmax(X, dim = 0)\n", 78 | "print(X_softmax)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "d0398b98-93fa-4a6d-b206-4e0fe3678c34", 84 | "metadata": {}, 85 | "source": [ 86 | "## Softmax By Handwrite" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "id": "8761f59c-1954-4052-95fa-7e9cced802f4", 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/html": [ 98 | "
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
 99 |        "
\n" 100 | ], 101 | "text/plain": [ 102 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 103 | ] 104 | }, 105 | "metadata": {}, 106 | "output_type": "display_data" 107 | } 108 | ], 109 | "source": [ 110 | "X_exp_sum = X.exp().sum()\n", 111 | "X_softmax_hand = torch.exp(X) / X_exp_sum\n", 112 | "print(X_softmax_hand)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "81247fc4-f14b-4ce2-a715-9f1200d63225", 118 | "metadata": {}, 119 | "source": [ 120 | "## Safe Softmax By Handwrite" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "id": "edba72aa-d000-4ae0-badc-1f9b171d9f7c", 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/html": [ 132 | "
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
133 |        "
\n" 134 | ], 135 | "text/plain": [ 136 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 137 | ] 138 | }, 139 | "metadata": {}, 140 | "output_type": "display_data" 141 | } 142 | ], 143 | "source": [ 144 | "X_max = X.max()\n", 145 | "X_exp_sum_sub_max = torch.exp(X-X_max).sum()\n", 146 | "X_safe_softmax_hand = torch.exp(X - X_max) / X_exp_sum_sub_max\n", 147 | "print(X_safe_softmax_hand)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "01273773-397c-4b0d-9607-01b1de1dea45", 153 | "metadata": {}, 154 | "source": [ 155 | "## Online Softmax" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 6, 161 | "id": "4f5f40da-622d-4745-b061-ccc392d32a7a", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/html": [ 167 | "
input x\n",
168 |        "
\n" 169 | ], 170 | "text/plain": [ 171 | "input x\n" 172 | ] 173 | }, 174 | "metadata": {}, 175 | "output_type": "display_data" 176 | }, 177 | { 178 | "data": { 179 | "text/html": [ 180 | "
tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000,  0.8000])\n",
181 |        "
\n" 182 | ], 183 | "text/plain": [ 184 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.3000\u001b[0m, \u001b[1;36m0.2000\u001b[0m, \u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.7000\u001b[0m, \u001b[1;36m0.1000\u001b[0m, \u001b[1;36m0.8000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 185 | ] 186 | }, 187 | "metadata": {}, 188 | "output_type": "display_data" 189 | }, 190 | { 191 | "data": { 192 | "text/html": [ 193 | "
tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000])\n",
194 |        "
\n" 195 | ], 196 | "text/plain": [ 197 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.3000\u001b[0m, \u001b[1;36m0.2000\u001b[0m, \u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.7000\u001b[0m, \u001b[1;36m0.1000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 198 | ] 199 | }, 200 | "metadata": {}, 201 | "output_type": "display_data" 202 | }, 203 | { 204 | "data": { 205 | "text/html": [ 206 | "
tensor(0.8000)\n",
207 |        "
\n" 208 | ], 209 | "text/plain": [ 210 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m0.8000\u001b[0m\u001b[1m)\u001b[0m\n" 211 | ] 212 | }, 213 | "metadata": {}, 214 | "output_type": "display_data" 215 | }, 216 | { 217 | "data": { 218 | "text/html": [ 219 | "
online softmax result:  tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
220 |        "
\n" 221 | ], 222 | "text/plain": [ 223 | "online softmax result: \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 224 | ] 225 | }, 226 | "metadata": {}, 227 | "output_type": "display_data" 228 | } 229 | ], 230 | "source": [ 231 | "X_pre = X[:-1]\n", 232 | "print('input x')\n", 233 | "print(X)\n", 234 | "print(X_pre)\n", 235 | "print(X[-1])\n", 236 | "\n", 237 | "# we calculative t-1 time Online Softmax\n", 238 | "X_max_pre = X_pre.max()\n", 239 | "X_sum_pre = torch.exp(X_pre - X_max_pre).sum()\n", 240 | "\n", 241 | "# we calculative t time Online Softmax\n", 242 | "X_max_cur = torch.max(X_max_pre, X[-1]) # X[-1] is new data\n", 243 | "X_sum_cur = X_sum_pre * torch.exp(X_max_pre - X_max_cur) + torch.exp(X[-1] - X_max_cur)\n", 244 | "\n", 245 | "# final we calculative online softmax\n", 246 | "X_online_softmax = torch.exp(X - X_max_cur) / X_sum_cur\n", 247 | "print('online softmax result: ', X_online_softmax)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "id": "2b80860f-c775-4992-a580-a5c91e080670", 253 | "metadata": {}, 254 | "source": [ 255 | "## Block Online Softmax" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 7, 261 | "id": "c6ac9ff3-e6df-4cd9-8fd6-2d491e41da23", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/html": [ 267 | "
tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000,  0.8000])\n",
268 |        "
\n" 269 | ], 270 | "text/plain": [ 271 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.3000\u001b[0m, \u001b[1;36m0.2000\u001b[0m, \u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.7000\u001b[0m, \u001b[1;36m0.1000\u001b[0m, \u001b[1;36m0.8000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 272 | ] 273 | }, 274 | "metadata": {}, 275 | "output_type": "display_data" 276 | }, 277 | { 278 | "data": { 279 | "text/html": [ 280 | "
(tensor([-0.3000,  0.2000,  0.5000]), tensor([0.7000, 0.1000, 0.8000]))\n",
281 |        "
\n" 282 | ], 283 | "text/plain": [ 284 | "\u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.3000\u001b[0m, \u001b[1;36m0.2000\u001b[0m, \u001b[1;36m0.5000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.7000\u001b[0m, \u001b[1;36m0.1000\u001b[0m, \u001b[1;36m0.8000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n" 285 | ] 286 | }, 287 | "metadata": {}, 288 | "output_type": "display_data" 289 | } 290 | ], 291 | "source": [ 292 | "X_block = torch.split(X, split_size_or_sections = 3 , dim = 0) \n", 293 | "print(X)\n", 294 | "print(X_block)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 8, 300 | "id": "83c30afc-84ac-4605-8a7f-f8360a6a3c2a", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "# we parallel calculate different block max & sum\n", 305 | "X_block_0_max = X_block[0].max()\n", 306 | "X_block_0_sum = torch.exp(X_block[0] - X_block_0_max).sum()\n", 307 | "\n", 308 | "X_block_1_max = X_block[1].max()\n", 309 | "X_block_1_sum = torch.exp(X_block[1] - X_block_1_max).sum()" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 9, 315 | "id": "f4e18507-5238-456e-956f-06568376b173", 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "data": { 320 | "text/html": [ 321 | "
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
322 |        "
\n" 323 | ], 324 | "text/plain": [ 325 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 326 | ] 327 | }, 328 | "metadata": {}, 329 | "output_type": "display_data" 330 | } 331 | ], 332 | "source": [ 333 | "# parallel online block update max & sum\n", 334 | "X_max_global = torch.max(X_block_0_max, X_block_1_max) \n", 335 | "L_global = (X_block_0_sum * torch.exp(X_block_0_max - X_max_global) \\\n", 336 | " + X_block_1_sum * torch.exp(X_block_1_max - X_max_global)) # block sum\n", 337 | "\n", 338 | "X_block_online_softmax_parallel = torch.exp(X - X_max_global) / L_global\n", 339 | "print(X_block_online_softmax_parallel)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 10, 345 | "id": "9f618de1-ec3d-4d14-a10b-77e291480249", 346 | "metadata": {}, 347 | "outputs": [ 348 | { 349 | "data": { 350 | "text/html": [ 351 | "
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
352 |        "
\n" 353 | ], 354 | "text/plain": [ 355 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 356 | ] 357 | }, 358 | "metadata": {}, 359 | "output_type": "display_data" 360 | } 361 | ], 362 | "source": [ 363 | "# online block update max & sum\n", 364 | "# updated version for multi-block, simpler version\n", 365 | "X_block_1_max_update = torch.max(X_block_0_max, X_block_1_max) \n", 366 | "X_block_1_sum_update = X_block_0_sum * torch.exp(X_block_0_max - X_block_1_max_update) \\\n", 367 | " + torch.exp(X_block[1] - X_block_1_max_update).sum() # block sum\n", 368 | "\n", 369 | "X_block_online_softmax = torch.exp(X - X_block_1_max_update) / X_block_1_sum_update\n", 370 | "print(X_block_online_softmax)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "id": "d68937aa-1d6d-4f3a-a013-0ef6fb763607", 376 | "metadata": {}, 377 | "source": [ 378 | "## Multi Block Online Softmax" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": 11, 384 | "id": "7a815125-e36d-4098-924c-abfbdac07fd4", 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "text/html": [ 390 | "
tensor([-0.3000,  0.2000,  0.5000,  0.7000,  0.1000,  0.8000])\n",
391 |        "
\n" 392 | ], 393 | "text/plain": [ 394 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.3000\u001b[0m, \u001b[1;36m0.2000\u001b[0m, \u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.7000\u001b[0m, \u001b[1;36m0.1000\u001b[0m, \u001b[1;36m0.8000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 395 | ] 396 | }, 397 | "metadata": {}, 398 | "output_type": "display_data" 399 | }, 400 | { 401 | "data": { 402 | "text/html": [ 403 | "
(tensor([-0.3000,  0.2000]), tensor([0.5000, 0.7000]), tensor([0.1000, 0.8000]))\n",
404 |        "
\n" 405 | ], 406 | "text/plain": [ 407 | "\u001b[1m(\u001b[0m\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.3000\u001b[0m, \u001b[1;36m0.2000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.5000\u001b[0m, \u001b[1;36m0.7000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.1000\u001b[0m, \u001b[1;36m0.8000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n" 408 | ] 409 | }, 410 | "metadata": {}, 411 | "output_type": "display_data" 412 | } 413 | ], 414 | "source": [ 415 | "X_block = torch.split(X, split_size_or_sections = 2, dim = 0) \n", 416 | "print(X)\n", 417 | "print(X_block)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 12, 423 | "id": "26ed247b-2d4b-478a-a9d3-328cb23e0073", 424 | "metadata": {}, 425 | "outputs": [ 426 | { 427 | "data": { 428 | "text/html": [ 429 | "
tensor([0.0827, 0.1364, 0.1841, 0.2249, 0.1234, 0.2485])\n",
430 |        "
\n" 431 | ], 432 | "text/plain": [ 433 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.0827\u001b[0m, \u001b[1;36m0.1364\u001b[0m, \u001b[1;36m0.1841\u001b[0m, \u001b[1;36m0.2249\u001b[0m, \u001b[1;36m0.1234\u001b[0m, \u001b[1;36m0.2485\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 434 | ] 435 | }, 436 | "metadata": {}, 437 | "output_type": "display_data" 438 | }, 439 | { 440 | "data": { 441 | "text/html": [ 442 | "
tensor(1.0000)\n",
443 |        "
\n" 444 | ], 445 | "text/plain": [ 446 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1.0000\u001b[0m\u001b[1m)\u001b[0m\n" 447 | ] 448 | }, 449 | "metadata": {}, 450 | "output_type": "display_data" 451 | } 452 | ], 453 | "source": [ 454 | "# online multi-block update max & sum\n", 455 | "M_old = torch.tensor([-100000.0])\n", 456 | "L_old = torch.tensor([0.0])\n", 457 | "\n", 458 | "for i in range(len(X_block)):\n", 459 | " M = torch.max(X_block[i])\n", 460 | " M_new = torch.max(M, M_old) \n", 461 | " \n", 462 | " L_new = L_old * torch.exp(M_old - M_new) \\\n", 463 | " + torch.exp(X_block[i] - M).sum() * torch.exp(M - M_new) \n", 464 | " \n", 465 | " # use simplest format\n", 466 | " # L_new = L_old * torch.exp(M_old - M_new) \\\n", 467 | " # + torch.exp(X_block[i] - M_new).sum() \n", 468 | " \n", 469 | " M_old = M_new\n", 470 | " L_old = L_new\n", 471 | "\n", 472 | "X_multi_block_online_softmax = torch.exp(X - M_old) / L_old\n", 473 | "print(X_multi_block_online_softmax)\n", 474 | "print(X_multi_block_online_softmax.sum())" 475 | ] 476 | }, 477 | { 478 | "cell_type": "markdown", 479 | "id": "9b0aadfb-df9a-4835-9945-b7c5b40bd995", 480 | "metadata": {}, 481 | "source": [ 482 | "## Batch Online Softmax" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "id": "a9c5288f-495e-415e-8a01-d9a683b489ef", 488 | "metadata": {}, 489 | "source": [ 490 | "### Batch Online Softmax by Torch" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 13, 496 | "id": "57e6e18e-4098-439e-8d97-094041d48f90", 497 | "metadata": {}, 498 | "outputs": [ 499 | { 500 | "data": { 501 | "text/html": [ 502 | "
tensor([[ 1.2568,  0.7383,  0.3952,  1.0712, -0.4927, -1.4833],\n",
503 |        "        [ 0.0294,  0.8888, -0.1123,  1.7380, -1.3930,  0.9524],\n",
504 |        "        [ 0.3562, -0.3688, -0.3282,  3.2616,  1.5259, -0.1031],\n",
505 |        "        [ 0.1950,  0.3002,  1.2051, -1.3824,  1.6178,  0.9580]])\n",
506 |        "
\n" 507 | ], 508 | "text/plain": [ 509 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m1.2568\u001b[0m, \u001b[1;36m0.7383\u001b[0m, \u001b[1;36m0.3952\u001b[0m, \u001b[1;36m1.0712\u001b[0m, \u001b[1;36m-0.4927\u001b[0m, \u001b[1;36m-1.4833\u001b[0m\u001b[1m]\u001b[0m,\n", 510 | " \u001b[1m[\u001b[0m \u001b[1;36m0.0294\u001b[0m, \u001b[1;36m0.8888\u001b[0m, \u001b[1;36m-0.1123\u001b[0m, \u001b[1;36m1.7380\u001b[0m, \u001b[1;36m-1.3930\u001b[0m, \u001b[1;36m0.9524\u001b[0m\u001b[1m]\u001b[0m,\n", 511 | " \u001b[1m[\u001b[0m \u001b[1;36m0.3562\u001b[0m, \u001b[1;36m-0.3688\u001b[0m, \u001b[1;36m-0.3282\u001b[0m, \u001b[1;36m3.2616\u001b[0m, \u001b[1;36m1.5259\u001b[0m, \u001b[1;36m-0.1031\u001b[0m\u001b[1m]\u001b[0m,\n", 512 | " \u001b[1m[\u001b[0m \u001b[1;36m0.1950\u001b[0m, \u001b[1;36m0.3002\u001b[0m, \u001b[1;36m1.2051\u001b[0m, \u001b[1;36m-1.3824\u001b[0m, \u001b[1;36m1.6178\u001b[0m, \u001b[1;36m0.9580\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 513 | ] 514 | }, 515 | "metadata": {}, 516 | "output_type": "display_data" 517 | }, 518 | { 519 | "data": { 520 | "text/html": [ 521 | "
tensor([[0.3240, 0.1929, 0.1369, 0.2691, 0.0563, 0.0209],\n",
522 |        "        [0.0799, 0.1888, 0.0694, 0.4414, 0.0193, 0.2012],\n",
523 |        "        [0.0415, 0.0201, 0.0209, 0.7578, 0.1336, 0.0262],\n",
524 |        "        [0.0881, 0.0978, 0.2418, 0.0182, 0.3653, 0.1888]])\n",
525 |        "
\n" 526 | ], 527 | "text/plain": [ 528 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.3240\u001b[0m, \u001b[1;36m0.1929\u001b[0m, \u001b[1;36m0.1369\u001b[0m, \u001b[1;36m0.2691\u001b[0m, \u001b[1;36m0.0563\u001b[0m, \u001b[1;36m0.0209\u001b[0m\u001b[1m]\u001b[0m,\n", 529 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0799\u001b[0m, \u001b[1;36m0.1888\u001b[0m, \u001b[1;36m0.0694\u001b[0m, \u001b[1;36m0.4414\u001b[0m, \u001b[1;36m0.0193\u001b[0m, \u001b[1;36m0.2012\u001b[0m\u001b[1m]\u001b[0m,\n", 530 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0415\u001b[0m, \u001b[1;36m0.0201\u001b[0m, \u001b[1;36m0.0209\u001b[0m, \u001b[1;36m0.7578\u001b[0m, \u001b[1;36m0.1336\u001b[0m, \u001b[1;36m0.0262\u001b[0m\u001b[1m]\u001b[0m,\n", 531 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0881\u001b[0m, \u001b[1;36m0.0978\u001b[0m, \u001b[1;36m0.2418\u001b[0m, \u001b[1;36m0.0182\u001b[0m, \u001b[1;36m0.3653\u001b[0m, \u001b[1;36m0.1888\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 532 | ] 533 | }, 534 | "metadata": {}, 535 | "output_type": "display_data" 536 | }, 537 | { 538 | "data": { 539 | "text/html": [ 540 | "
tensor([1.0000, 1.0000, 1.0000, 1.0000])\n",
541 |        "
\n" 542 | ], 543 | "text/plain": [ 544 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1.0000\u001b[0m, \u001b[1;36m1.0000\u001b[0m, \u001b[1;36m1.0000\u001b[0m, \u001b[1;36m1.0000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 545 | ] 546 | }, 547 | "metadata": {}, 548 | "output_type": "display_data" 549 | } 550 | ], 551 | "source": [ 552 | "X_batch = torch.randn(4, 6)\n", 553 | "print(X_batch)\n", 554 | "X_batch_softmax = F.softmax(X_batch, dim = 1) \n", 555 | "print(X_batch_softmax)\n", 556 | "X_batch_softmax_evaluete = X_batch_softmax.sum(dim = 1)\n", 557 | "print(X_batch_softmax_evaluete) # row prob sum is 1" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "id": "9821e291-a47a-48b8-9c63-2f39d4960867", 563 | "metadata": {}, 564 | "source": [ 565 | "### Batch Online Softmax by Hand" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 14, 571 | "id": "b8abed75-6800-4cdb-be7b-a3a958a32256", 572 | "metadata": {}, 573 | "outputs": [ 574 | { 575 | "data": { 576 | "text/html": [ 577 | "
4 3\n",
578 |        "
\n" 579 | ], 580 | "text/plain": [ 581 | "\u001b[1;36m4\u001b[0m \u001b[1;36m3\u001b[0m\n" 582 | ] 583 | }, 584 | "metadata": {}, 585 | "output_type": "display_data" 586 | }, 587 | { 588 | "data": { 589 | "text/html": [ 590 | "
tensor([[ 1.2568,  0.7383,  0.3952,  1.0712, -0.4927, -1.4833],\n",
591 |        "        [ 0.0294,  0.8888, -0.1123,  1.7380, -1.3930,  0.9524],\n",
592 |        "        [ 0.3562, -0.3688, -0.3282,  3.2616,  1.5259, -0.1031],\n",
593 |        "        [ 0.1950,  0.3002,  1.2051, -1.3824,  1.6178,  0.9580]])\n",
594 |        "
\n" 595 | ], 596 | "text/plain": [ 597 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m1.2568\u001b[0m, \u001b[1;36m0.7383\u001b[0m, \u001b[1;36m0.3952\u001b[0m, \u001b[1;36m1.0712\u001b[0m, \u001b[1;36m-0.4927\u001b[0m, \u001b[1;36m-1.4833\u001b[0m\u001b[1m]\u001b[0m,\n", 598 | " \u001b[1m[\u001b[0m \u001b[1;36m0.0294\u001b[0m, \u001b[1;36m0.8888\u001b[0m, \u001b[1;36m-0.1123\u001b[0m, \u001b[1;36m1.7380\u001b[0m, \u001b[1;36m-1.3930\u001b[0m, \u001b[1;36m0.9524\u001b[0m\u001b[1m]\u001b[0m,\n", 599 | " \u001b[1m[\u001b[0m \u001b[1;36m0.3562\u001b[0m, \u001b[1;36m-0.3688\u001b[0m, \u001b[1;36m-0.3282\u001b[0m, \u001b[1;36m3.2616\u001b[0m, \u001b[1;36m1.5259\u001b[0m, \u001b[1;36m-0.1031\u001b[0m\u001b[1m]\u001b[0m,\n", 600 | " \u001b[1m[\u001b[0m \u001b[1;36m0.1950\u001b[0m, \u001b[1;36m0.3002\u001b[0m, \u001b[1;36m1.2051\u001b[0m, \u001b[1;36m-1.3824\u001b[0m, \u001b[1;36m1.6178\u001b[0m, \u001b[1;36m0.9580\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 601 | ] 602 | }, 603 | "metadata": {}, 604 | "output_type": "display_data" 605 | }, 606 | { 607 | "data": { 608 | "text/html": [ 609 | "
tensor([[ 1.2568,  0.7383,  0.3952],\n",
610 |        "        [ 0.0294,  0.8888, -0.1123],\n",
611 |        "        [ 0.3562, -0.3688, -0.3282],\n",
612 |        "        [ 0.1950,  0.3002,  1.2051]])\n",
613 |        "
\n" 614 | ], 615 | "text/plain": [ 616 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m1.2568\u001b[0m, \u001b[1;36m0.7383\u001b[0m, \u001b[1;36m0.3952\u001b[0m\u001b[1m]\u001b[0m,\n", 617 | " \u001b[1m[\u001b[0m \u001b[1;36m0.0294\u001b[0m, \u001b[1;36m0.8888\u001b[0m, \u001b[1;36m-0.1123\u001b[0m\u001b[1m]\u001b[0m,\n", 618 | " \u001b[1m[\u001b[0m \u001b[1;36m0.3562\u001b[0m, \u001b[1;36m-0.3688\u001b[0m, \u001b[1;36m-0.3282\u001b[0m\u001b[1m]\u001b[0m,\n", 619 | " \u001b[1m[\u001b[0m \u001b[1;36m0.1950\u001b[0m, \u001b[1;36m0.3002\u001b[0m, \u001b[1;36m1.2051\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 620 | ] 621 | }, 622 | "metadata": {}, 623 | "output_type": "display_data" 624 | }, 625 | { 626 | "data": { 627 | "text/html": [ 628 | "
tensor([[ 1.0712, -0.4927, -1.4833],\n",
629 |        "        [ 1.7380, -1.3930,  0.9524],\n",
630 |        "        [ 3.2616,  1.5259, -0.1031],\n",
631 |        "        [-1.3824,  1.6178,  0.9580]])\n",
632 |        "
\n" 633 | ], 634 | "text/plain": [ 635 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m1.0712\u001b[0m, \u001b[1;36m-0.4927\u001b[0m, \u001b[1;36m-1.4833\u001b[0m\u001b[1m]\u001b[0m,\n", 636 | " \u001b[1m[\u001b[0m \u001b[1;36m1.7380\u001b[0m, \u001b[1;36m-1.3930\u001b[0m, \u001b[1;36m0.9524\u001b[0m\u001b[1m]\u001b[0m,\n", 637 | " \u001b[1m[\u001b[0m \u001b[1;36m3.2616\u001b[0m, \u001b[1;36m1.5259\u001b[0m, \u001b[1;36m-0.1031\u001b[0m\u001b[1m]\u001b[0m,\n", 638 | " \u001b[1m[\u001b[0m\u001b[1;36m-1.3824\u001b[0m, \u001b[1;36m1.6178\u001b[0m, \u001b[1;36m0.9580\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 639 | ] 640 | }, 641 | "metadata": {}, 642 | "output_type": "display_data" 643 | } 644 | ], 645 | "source": [ 646 | "b, d = X_batch.shape\n", 647 | "print(b, d//2)\n", 648 | "\n", 649 | "X_batch_block_0 = X_batch[:, :d//2]\n", 650 | "X_batch_block_1 = X_batch[:, d//2:]\n", 651 | "\n", 652 | "print(X_batch)\n", 653 | "print(X_batch_block_0)\n", 654 | "print(X_batch_block_1)" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 15, 660 | "id": "8c89f06f-5536-42ef-880e-55359005c1a5", 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "data": { 665 | "text/html": [ 666 | "
tensor([[1.2568],\n",
667 |        "        [0.8888],\n",
668 |        "        [0.3562],\n",
669 |        "        [1.2051]])\n",
670 |        "
\n" 671 | ], 672 | "text/plain": [ 673 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1.2568\u001b[0m\u001b[1m]\u001b[0m,\n", 674 | " \u001b[1m[\u001b[0m\u001b[1;36m0.8888\u001b[0m\u001b[1m]\u001b[0m,\n", 675 | " \u001b[1m[\u001b[0m\u001b[1;36m0.3562\u001b[0m\u001b[1m]\u001b[0m,\n", 676 | " \u001b[1m[\u001b[0m\u001b[1;36m1.2051\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 677 | ] 678 | }, 679 | "metadata": {}, 680 | "output_type": "display_data" 681 | }, 682 | { 683 | "data": { 684 | "text/html": [ 685 | "
tensor([[2.0179],\n",
686 |        "        [1.7909],\n",
687 |        "        [1.9887],\n",
688 |        "        [1.7688]])\n",
689 |        "
\n" 690 | ], 691 | "text/plain": [ 692 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m2.0179\u001b[0m\u001b[1m]\u001b[0m,\n", 693 | " \u001b[1m[\u001b[0m\u001b[1;36m1.7909\u001b[0m\u001b[1m]\u001b[0m,\n", 694 | " \u001b[1m[\u001b[0m\u001b[1;36m1.9887\u001b[0m\u001b[1m]\u001b[0m,\n", 695 | " \u001b[1m[\u001b[0m\u001b[1;36m1.7688\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 696 | ] 697 | }, 698 | "metadata": {}, 699 | "output_type": "display_data" 700 | } 701 | ], 702 | "source": [ 703 | "# we parallel calculate different block max & sum\n", 704 | "X_batch_0_max, _ = X_batch_block_0.max(dim = 1, keepdim = True)\n", 705 | "X_batch_0_sum = torch.exp(X_batch_block_0 - X_batch_0_max).sum(dim = 1, keepdim = True)\n", 706 | "\n", 707 | "X_batch_1_max, _ = X_batch_block_1.max(dim = 1, keepdim = True)\n", 708 | "X_batch_1_sum = torch.exp(X_batch_block_1 - X_batch_1_max).sum(dim = 1, keepdim = True)\n", 709 | "\n", 710 | "print(X_batch_0_max)\n", 711 | "print(X_batch_0_sum)" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": 16, 717 | "id": "9b97605a-4d38-40d1-bb90-c4243c608daa", 718 | "metadata": {}, 719 | "outputs": [ 720 | { 721 | "data": { 722 | "text/html": [ 723 | "
tensor([[0.3240, 0.1929, 0.1369, 0.2691, 0.0563, 0.0209],\n",
724 |        "        [0.0799, 0.1888, 0.0694, 0.4414, 0.0193, 0.2012],\n",
725 |        "        [0.0415, 0.0201, 0.0209, 0.7578, 0.1336, 0.0262],\n",
726 |        "        [0.0881, 0.0978, 0.2418, 0.0182, 0.3653, 0.1888]])\n",
727 |        "
\n" 728 | ], 729 | "text/plain": [ 730 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.3240\u001b[0m, \u001b[1;36m0.1929\u001b[0m, \u001b[1;36m0.1369\u001b[0m, \u001b[1;36m0.2691\u001b[0m, \u001b[1;36m0.0563\u001b[0m, \u001b[1;36m0.0209\u001b[0m\u001b[1m]\u001b[0m,\n", 731 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0799\u001b[0m, \u001b[1;36m0.1888\u001b[0m, \u001b[1;36m0.0694\u001b[0m, \u001b[1;36m0.4414\u001b[0m, \u001b[1;36m0.0193\u001b[0m, \u001b[1;36m0.2012\u001b[0m\u001b[1m]\u001b[0m,\n", 732 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0415\u001b[0m, \u001b[1;36m0.0201\u001b[0m, \u001b[1;36m0.0209\u001b[0m, \u001b[1;36m0.7578\u001b[0m, \u001b[1;36m0.1336\u001b[0m, \u001b[1;36m0.0262\u001b[0m\u001b[1m]\u001b[0m,\n", 733 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0881\u001b[0m, \u001b[1;36m0.0978\u001b[0m, \u001b[1;36m0.2418\u001b[0m, \u001b[1;36m0.0182\u001b[0m, \u001b[1;36m0.3653\u001b[0m, \u001b[1;36m0.1888\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 734 | ] 735 | }, 736 | "metadata": {}, 737 | "output_type": "display_data" 738 | } 739 | ], 740 | "source": [ 741 | "# online batch block update max & sum\n", 742 | "X_batch_1_max_update = torch.maximum(X_batch_0_max, X_batch_1_max) # 逐个元素找最大值\n", 743 | "X_batch_1_sum_update = X_batch_0_sum * torch.exp(X_batch_0_max - X_batch_1_max_update) \\\n", 744 | " + torch.exp(X_batch_block_1 - X_batch_1_max_update).sum(dim = 1, keepdim = True) # block sum\n", 745 | "\n", 746 | "X_batch_online_softmax = torch.exp(X_batch - X_batch_1_max_update) / X_batch_1_sum_update\n", 747 | "print(X_batch_online_softmax)" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 17, 753 | "id": "540be634-264d-454f-a242-18cd0680ac38", 754 | "metadata": {}, 755 | "outputs": [ 756 | { 757 | "data": { 758 | "text/html": [ 759 | "
tensor([[0.3240, 0.1929, 0.1369, 0.2691, 0.0563, 0.0209],\n",
760 |        "        [0.0799, 0.1888, 0.0694, 0.4414, 0.0193, 0.2012],\n",
761 |        "        [0.0415, 0.0201, 0.0209, 0.7578, 0.1336, 0.0262],\n",
762 |        "        [0.0881, 0.0978, 0.2418, 0.0182, 0.3653, 0.1888]])\n",
763 |        "
\n" 764 | ], 765 | "text/plain": [ 766 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.3240\u001b[0m, \u001b[1;36m0.1929\u001b[0m, \u001b[1;36m0.1369\u001b[0m, \u001b[1;36m0.2691\u001b[0m, \u001b[1;36m0.0563\u001b[0m, \u001b[1;36m0.0209\u001b[0m\u001b[1m]\u001b[0m,\n", 767 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0799\u001b[0m, \u001b[1;36m0.1888\u001b[0m, \u001b[1;36m0.0694\u001b[0m, \u001b[1;36m0.4414\u001b[0m, \u001b[1;36m0.0193\u001b[0m, \u001b[1;36m0.2012\u001b[0m\u001b[1m]\u001b[0m,\n", 768 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0415\u001b[0m, \u001b[1;36m0.0201\u001b[0m, \u001b[1;36m0.0209\u001b[0m, \u001b[1;36m0.7578\u001b[0m, \u001b[1;36m0.1336\u001b[0m, \u001b[1;36m0.0262\u001b[0m\u001b[1m]\u001b[0m,\n", 769 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0881\u001b[0m, \u001b[1;36m0.0978\u001b[0m, \u001b[1;36m0.2418\u001b[0m, \u001b[1;36m0.0182\u001b[0m, \u001b[1;36m0.3653\u001b[0m, \u001b[1;36m0.1888\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 770 | ] 771 | }, 772 | "metadata": {}, 773 | "output_type": "display_data" 774 | } 775 | ], 776 | "source": [ 777 | "X_batch_softmax_torch = F.softmax(X_batch, dim = 1) \n", 778 | "print(X_batch_softmax_torch)" 779 | ] 780 | }, 781 | { 782 | "cell_type": "markdown", 783 | "id": "87ec5c37-f50d-44bc-8643-14aeb69018e1", 784 | "metadata": {}, 785 | "source": [ 786 | "### Multi Block Batch Online Softmax" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "execution_count": 18, 792 | "id": "19f1d55f-a943-460e-8fcf-04798b1d4423", 793 | "metadata": {}, 794 | "outputs": [ 795 | { 796 | "data": { 797 | "text/html": [ 798 | "
(\n",
799 |        "    tensor([[ 1.2568,  0.7383],\n",
800 |        "        [ 0.0294,  0.8888],\n",
801 |        "        [ 0.3562, -0.3688],\n",
802 |        "        [ 0.1950,  0.3002]]),\n",
803 |        "    tensor([[ 0.3952,  1.0712],\n",
804 |        "        [-0.1123,  1.7380],\n",
805 |        "        [-0.3282,  3.2616],\n",
806 |        "        [ 1.2051, -1.3824]]),\n",
807 |        "    tensor([[-0.4927, -1.4833],\n",
808 |        "        [-1.3930,  0.9524],\n",
809 |        "        [ 1.5259, -0.1031],\n",
810 |        "        [ 1.6178,  0.9580]])\n",
811 |        ")\n",
812 |        "
\n" 813 | ], 814 | "text/plain": [ 815 | "\u001b[1m(\u001b[0m\n", 816 | " \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m1.2568\u001b[0m, \u001b[1;36m0.7383\u001b[0m\u001b[1m]\u001b[0m,\n", 817 | " \u001b[1m[\u001b[0m \u001b[1;36m0.0294\u001b[0m, \u001b[1;36m0.8888\u001b[0m\u001b[1m]\u001b[0m,\n", 818 | " \u001b[1m[\u001b[0m \u001b[1;36m0.3562\u001b[0m, \u001b[1;36m-0.3688\u001b[0m\u001b[1m]\u001b[0m,\n", 819 | " \u001b[1m[\u001b[0m \u001b[1;36m0.1950\u001b[0m, \u001b[1;36m0.3002\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", 820 | " \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m0.3952\u001b[0m, \u001b[1;36m1.0712\u001b[0m\u001b[1m]\u001b[0m,\n", 821 | " \u001b[1m[\u001b[0m\u001b[1;36m-0.1123\u001b[0m, \u001b[1;36m1.7380\u001b[0m\u001b[1m]\u001b[0m,\n", 822 | " \u001b[1m[\u001b[0m\u001b[1;36m-0.3282\u001b[0m, \u001b[1;36m3.2616\u001b[0m\u001b[1m]\u001b[0m,\n", 823 | " \u001b[1m[\u001b[0m \u001b[1;36m1.2051\u001b[0m, \u001b[1;36m-1.3824\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", 824 | " \u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m-0.4927\u001b[0m, \u001b[1;36m-1.4833\u001b[0m\u001b[1m]\u001b[0m,\n", 825 | " \u001b[1m[\u001b[0m\u001b[1;36m-1.3930\u001b[0m, \u001b[1;36m0.9524\u001b[0m\u001b[1m]\u001b[0m,\n", 826 | " \u001b[1m[\u001b[0m \u001b[1;36m1.5259\u001b[0m, \u001b[1;36m-0.1031\u001b[0m\u001b[1m]\u001b[0m,\n", 827 | " \u001b[1m[\u001b[0m \u001b[1;36m1.6178\u001b[0m, \u001b[1;36m0.9580\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n", 828 | "\u001b[1m)\u001b[0m\n" 829 | ] 830 | }, 831 | "metadata": {}, 832 | "output_type": "display_data" 833 | } 834 | ], 835 | "source": [ 836 | "# X_batch = torch.randn(4, 6)\n", 837 | "X_blocks = torch.split(X_batch, 2, dim=1)\n", 838 | "print(X_blocks)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 
| "execution_count": 19, 844 | "id": "c2ff29eb-9d6b-43ae-b7af-d6bf6c594b5d", 845 | "metadata": {}, 846 | "outputs": [ 847 | { 848 | "data": { 849 | "text/html": [ 850 | "
tensor([[0.3240, 0.1929, 0.1369, 0.2691, 0.0563, 0.0209],\n",
851 |        "        [0.0799, 0.1888, 0.0694, 0.4414, 0.0193, 0.2012],\n",
852 |        "        [0.0415, 0.0201, 0.0209, 0.7578, 0.1336, 0.0262],\n",
853 |        "        [0.0881, 0.0978, 0.2418, 0.0182, 0.3653, 0.1888]])\n",
854 |        "
\n" 855 | ], 856 | "text/plain": [ 857 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.3240\u001b[0m, \u001b[1;36m0.1929\u001b[0m, \u001b[1;36m0.1369\u001b[0m, \u001b[1;36m0.2691\u001b[0m, \u001b[1;36m0.0563\u001b[0m, \u001b[1;36m0.0209\u001b[0m\u001b[1m]\u001b[0m,\n", 858 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0799\u001b[0m, \u001b[1;36m0.1888\u001b[0m, \u001b[1;36m0.0694\u001b[0m, \u001b[1;36m0.4414\u001b[0m, \u001b[1;36m0.0193\u001b[0m, \u001b[1;36m0.2012\u001b[0m\u001b[1m]\u001b[0m,\n", 859 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0415\u001b[0m, \u001b[1;36m0.0201\u001b[0m, \u001b[1;36m0.0209\u001b[0m, \u001b[1;36m0.7578\u001b[0m, \u001b[1;36m0.1336\u001b[0m, \u001b[1;36m0.0262\u001b[0m\u001b[1m]\u001b[0m,\n", 860 | " \u001b[1m[\u001b[0m\u001b[1;36m0.0881\u001b[0m, \u001b[1;36m0.0978\u001b[0m, \u001b[1;36m0.2418\u001b[0m, \u001b[1;36m0.0182\u001b[0m, \u001b[1;36m0.3653\u001b[0m, \u001b[1;36m0.1888\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 861 | ] 862 | }, 863 | "metadata": {}, 864 | "output_type": "display_data" 865 | }, 866 | { 867 | "data": { 868 | "text/html": [ 869 | "
tensor([[1.0000],\n",
870 |        "        [1.0000],\n",
871 |        "        [1.0000],\n",
872 |        "        [1.0000]])\n",
873 |        "
\n" 874 | ], 875 | "text/plain": [ 876 | "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m1.0000\u001b[0m\u001b[1m]\u001b[0m,\n", 877 | " \u001b[1m[\u001b[0m\u001b[1;36m1.0000\u001b[0m\u001b[1m]\u001b[0m,\n", 878 | " \u001b[1m[\u001b[0m\u001b[1;36m1.0000\u001b[0m\u001b[1m]\u001b[0m,\n", 879 | " \u001b[1m[\u001b[0m\u001b[1;36m1.0000\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" 880 | ] 881 | }, 882 | "metadata": {}, 883 | "output_type": "display_data" 884 | } 885 | ], 886 | "source": [ 887 | "b, d = X_batch.shape\n", 888 | "M_old = torch.ones((b,1)) * -100000.0\n", 889 | "L_old = torch.zeros((b,1))\n", 890 | "\n", 891 | "for X_block in X_blocks:\n", 892 | " M,_ = torch.max(X_block, dim = 1, keepdim = True)\n", 893 | " M_new = torch.maximum(M, M_old) \n", 894 | " \n", 895 | " L_new = L_old * torch.exp(M_old - M_new) \\\n", 896 | " + torch.exp(X_block - M_new).sum(dim = 1, keepdim = True) \n", 897 | " \n", 898 | " M_old = M_new\n", 899 | " L_old = L_new\n", 900 | "\n", 901 | "X_blocks_batch = torch.exp(X_batch - M_old) / L_old\n", 902 | "print(X_blocks_batch)\n", 903 | "print(X_blocks_batch.sum(dim = 1, keepdim = True))" 904 | ] 905 | } 906 | ], 907 | "metadata": { 908 | "kernelspec": { 909 | "display_name": "Python 3 (ipykernel)", 910 | "language": "python", 911 | "name": "python3" 912 | }, 913 | "language_info": { 914 | "codemirror_mode": { 915 | "name": "ipython", 916 | "version": 3 917 | }, 918 | "file_extension": ".py", 919 | "mimetype": "text/x-python", 920 | "name": "python", 921 | "nbconvert_exporter": "python", 922 | "pygments_lexer": "ipython3", 923 | "version": "3.11.9" 924 | } 925 | }, 926 | "nbformat": 4, 927 | "nbformat_minor": 5 928 | } 929 | --------------------------------------------------------------------------------