├── .gitignore
├── LICENSE
├── README.md
├── outer_softmax.py
├── outer_softmax_online.py
├── softmax_loop_along_reduce_axis_v1.py
├── softmax_loop_along_reduce_axis_v2.py
├── softmax_naive.py
├── softmax_online_v1.py
├── softmax_online_v2.py
├── softmax_online_v2_evict.py
├── softmax_online_v2_rev.py
├── softmax_online_v2_spec.py
├── softmax_online_v2_spec_rev.py
├── softmax_online_v2_spec_rev_evict.py
└── softmax_split.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Clement Chan
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # optimize_softmax
2 | Optimized softmax kernels written in Triton, covering several cases: a naive single-tile kernel, looped and online (streaming) variants along the reduction axis, outer-axis (dim 0) softmax, and a split (two-stage logsumexp) version for very long rows.
3 |
--------------------------------------------------------------------------------
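
Each kernel file in this repository follows the same layout: a Triton kernel, a softmax() wrapper that picks the tile size and launch grid, pytest tests against torch.softmax, and a pandas-based benchmark that writes an .xlsx file. A minimal usage sketch (assuming a CUDA device and that torch, triton, pytest, and pandas are installed; the benchmarks additionally need openpyxl for df.to_excel):

    import torch
    from softmax_online_v2 import softmax  # any variant whose wrapper has the signature softmax(x)

    x = torch.randn(128, 32 * 1024, device="cuda")
    torch.testing.assert_close(softmax(x), torch.softmax(x, dim=-1))

Running pytest on a file executes that variant's correctness tests, and running the file with python executes its benchmark (or example) entry point.
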
/outer_softmax.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | K,
12 | TILE_N: tl.constexpr,
13 | TILE_K: tl.constexpr,
14 | ):
15 | pid_k = tl.program_id(0)
16 | pid_m = tl.program_id(1)
17 |
18 | k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)
19 | n_offsets = tl.arange(0, TILE_N)
20 | offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
21 | mask = (n_offsets[:, None] < N) & (k_offsets < K)
22 | input_ptrs = input_ptr + offset
23 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
24 | m = tl.max(inp, 0)
25 | e = tl.exp(inp - m[None, :])
26 | z = tl.sum(e, 0)
27 | out = e / z
28 | output_ptrs = output_ptr + offset
29 | tl.store(output_ptrs, out, mask=mask)
30 |
31 |
32 | def softmax(x, TILE_K):
33 | inp = x.contiguous()
34 | M = 1
35 | N, K = inp.shape
36 | out = torch.empty_like(x)
37 | TILE_N = triton.next_power_of_2(N)
38 | # TILE_K = 1
39 | grid = (triton.cdiv(K, TILE_K), M, 1)
40 | softmax_kernel[grid](
41 | out,
42 | inp,
43 | M,
44 | N,
45 | K,
46 | TILE_N,
47 | TILE_K,
48 | )
49 | return out
50 |
51 |
52 | import pytest
53 |
54 | @pytest.mark.parametrize("n", [10, 128])
55 | @pytest.mark.parametrize("m", [512, 1024, 32 * 1024])
56 | @pytest.mark.parametrize("TILE_K", [1, 2, 4])
57 | def test_softmax(m, n, TILE_K):
58 | x = torch.randn((m, n), device="cuda")
59 | hyp = softmax(x, TILE_K)
60 | ref = torch.softmax(x, dim=0)
61 | torch.testing.assert_close(hyp, ref)
62 |
63 | def benchmark_softmax(m, n):
64 | x = torch.randn((m, n), device="cuda")
65 | t1 = triton.testing.do_bench(lambda: softmax(x, 1), return_mode="median")
66 | t2 = triton.testing.do_bench(lambda: softmax(x, 2), return_mode="median")
67 | t3 = triton.testing.do_bench(lambda: softmax(x, 4), return_mode="median")
68 | t4 = triton.testing.do_bench(lambda: torch.softmax(x, dim=0), return_mode="median")
69 | def throughput(t):
70 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
71 | return throughput(t1), throughput(t2), throughput(t3), throughput(t4)
72 |
73 | import pandas as pd
74 | def run_benchmark():
75 | records = []
76 | for n in [10, 128, 1024, 4096]:
77 | for m in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64 * 1024, 128 * 1024]:
78 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
79 | continue
80 | t1, t2, t3, t4 = benchmark_softmax(m, n)
81 | record = (m, n, t1, t2, t3, t4)
82 | records.append(record)
83 | df = pd.DataFrame.from_records(records, columns=["reduce_size", "post_size", "naive_outer_k1", "naive_outer_k2", "naive_outer_k4", "torch"])
84 | print(df)
85 | df.to_excel("naive_outer.xlsx")
86 |
87 | def run_an_example(m, n, tile_k):
88 | x = torch.randn((m, n), device="cuda")
89 | y = softmax(x, tile_k)
90 |
91 | if __name__ == "__main__":
92 | run_benchmark()
93 | # run_an_example(4096, 4096, )
--------------------------------------------------------------------------------
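
outer_softmax.py above performs the reduction over the leading axis (dim 0): the wrapper views the input as (N, K) with M = 1, each program owns a TILE_K-wide strip of columns, and offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets addresses a (TILE_N, TILE_K) tile of the row-major array. A plain PyTorch reference for the same computation (a sketch for comparison, not part of the repository):

    import torch

    def outer_softmax_reference(x: torch.Tensor) -> torch.Tensor:
        # softmax over dim 0, i.e. what softmax_kernel computes per column
        m = x.max(dim=0, keepdim=True).values
        e = torch.exp(x - m)
        return e / e.sum(dim=0, keepdim=True)

    x = torch.randn(1024, 128, device="cuda")
    torch.testing.assert_close(outer_softmax_reference(x), torch.softmax(x, dim=0))
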
/outer_softmax_online.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 |
6 | @triton.jit
7 | def next_multiple_of(a, b):
8 | # the smallest x >= a such that x % b == 0
9 | return tl.cdiv(a, b) * b
10 |
11 |
12 | @triton.jit
13 | def prev_multiple_of(a, b):
14 | # the largest x <= a such that x % b == 0
115 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
116 | continue
117 | for tile_k in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
118 | t1 = benchmark_softmax(m, n, tile_k)
119 | record = (m, n, tile_k, t1, triton.cdiv(n, tile_k))
120 | records.append(record)
121 | df = pd.DataFrame.from_records(records, columns=["reduce_size", "post_size", "tile_k", "online", "grid"])
122 | print(df)
123 | df.to_excel("online_outer.xlsx")
124 |
125 | def run_an_example(m, n, tile_k):
126 | x = torch.randn((m, n), device="cuda")
127 | y = softmax(x, tile_k)
128 |
129 | if __name__ == "__main__":
130 | run_benchmark()
131 | # run_an_example(4096, 32 * 1024, 64)
132 | # run_an_example(4096, 32 * 1024, 2)
133 | # run_an_example(4096, 32 * 1024, 1024)
134 |
--------------------------------------------------------------------------------
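
The two @triton.jit helpers at the top of outer_softmax_online.py round an offset up or down to a multiple of the tile size. Their intended behavior, written as plain Python (a sketch inferred from the comments and names; prev_multiple_of is assumed to be the usual floor-division form):

    def next_multiple_of(a: int, b: int) -> int:
        # smallest x >= a with x % b == 0
        return -(-a // b) * b  # equivalent to ceil(a / b) * b

    def prev_multiple_of(a: int, b: int) -> int:
        # largest x <= a with x % b == 0
        return (a // b) * b

    assert next_multiple_of(10, 4) == 12 and prev_multiple_of(10, 4) == 8
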
/softmax_loop_along_reduce_axis_v1.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel_loop_v1(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | TILE_N: tl.constexpr,
12 | ):
13 | pid_m = tl.program_id(0)
14 | m = tl.full((), value=-float("inf"), dtype=output_ptr.dtype.element_ty)
15 | for start_n in range(0, N, TILE_N):
16 | n_offsets = start_n + tl.arange(0, TILE_N)
17 | offset = pid_m * N + n_offsets
18 | input_ptrs = input_ptr + offset
19 | mask = n_offsets < N
20 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
21 | m = tl.maximum(m, tl.max(inp, 0))
22 |
23 | z = tl.full((), value=0, dtype=output_ptr.dtype.element_ty)
24 | for start_n in range(0, N, TILE_N):
25 | n_offsets = start_n + tl.arange(0, TILE_N)
26 | offset = pid_m * N + n_offsets
27 | input_ptrs = input_ptr + offset
28 | mask = n_offsets < N
29 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
30 | e = tl.exp(inp - m)
31 | z += tl.sum(e)
32 |
33 | for start_n in range(0, N, TILE_N):
34 | n_offsets = start_n + tl.arange(0, TILE_N)
35 | offset = pid_m * N + n_offsets
36 | input_ptrs = input_ptr + offset
37 | mask = n_offsets < N
38 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
39 | e = tl.exp(inp - m)
40 | out = e / z
41 | output_ptrs = output_ptr + offset
42 | tl.store(output_ptrs, out, mask=mask)
43 |
44 |
45 | def softmax(x):
46 | M, N = x.shape
47 | out = torch.empty_like(x)
48 | # TODO tune
49 | TILE_N = min(4096, triton.next_power_of_2(N))
50 | grid = (M, 1, 1)
51 | softmax_kernel_loop_v1[grid](out, x, M, N, TILE_N)
52 | return out
53 |
54 | import pytest
55 |
56 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
57 | @pytest.mark.parametrize("m", [10, 128])
58 | def test_softmax(m, n):
59 | x = torch.randn((m, n), device="cuda")
60 | hyp = softmax(x)
61 | ref = torch.softmax(x, dim=-1)
62 | torch.testing.assert_close(hyp, ref)
63 |
64 | def benchmark_softmax(m, n):
65 | x = torch.randn((m, n), device="cuda")
66 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
67 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
68 | def throughput(t):
69 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
70 | return throughput(t1), throughput(t2)
71 |
72 | import pandas as pd
73 | def run_benchmark():
74 | records = []
75 | for m in [10, 128, 1024, 4096]:
76 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]:
77 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
78 | continue
79 | t1, t2 = benchmark_softmax(m, n)
80 | record = (m, n, t1, t2)
81 | records.append(record)
82 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "loop_v1", "torch"])
83 | print(df)
84 | df.to_excel("loop_v1.xlsx")
85 |
86 | if __name__ == "__main__":
87 | run_benchmark()
88 |
--------------------------------------------------------------------------------
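
softmax_kernel_loop_v1 makes three passes over each row: one to find the row maximum m, one to accumulate the normalizer z = sum(exp(x - m)), and one to write exp(x - m) / z, so the input is read from global memory three times. The same three-pass structure in plain PyTorch (a reference sketch, not the kernel itself):

    import torch

    def softmax_three_pass(x: torch.Tensor) -> torch.Tensor:
        m = x.max(dim=-1, keepdim=True).values          # pass 1: row maximum
        z = torch.exp(x - m).sum(dim=-1, keepdim=True)  # pass 2: normalizer
        return torch.exp(x - m) / z                     # pass 3: normalize

    x = torch.randn(8, 1000, device="cuda")
    torch.testing.assert_close(softmax_three_pass(x), torch.softmax(x, dim=-1))
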
/softmax_loop_along_reduce_axis_v2.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel_loop_v2(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | TILE_N: tl.constexpr,
12 | ):
13 | pid_m = tl.program_id(0)
14 | m = tl.full((TILE_N, ), value=-float("inf"), dtype=output_ptr.dtype.element_ty)
15 | for start_n in range(0, N, TILE_N):
16 | n_offsets = start_n + tl.arange(0, TILE_N)
17 | offset = pid_m * N + n_offsets
18 | input_ptrs = input_ptr + offset
19 | mask = n_offsets < N
20 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
21 | m = tl.maximum(m, inp)
22 | m = tl.max(m, 0)
23 |
24 | z = tl.full((TILE_N, ), value=0, dtype=output_ptr.dtype.element_ty)
25 | for start_n in range(0, N, TILE_N):
26 | n_offsets = start_n + tl.arange(0, TILE_N)
27 | offset = pid_m * N + n_offsets
28 | input_ptrs = input_ptr + offset
29 | mask = n_offsets < N
30 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
31 | e = tl.exp(inp - m)
32 | z += e
33 | z = tl.sum(z, 0)
34 |
35 | for start_n in range(0, N, TILE_N):
36 | n_offsets = start_n + tl.arange(0, TILE_N)
37 | offset = pid_m * N + n_offsets
38 | input_ptrs = input_ptr + offset
39 | mask = n_offsets < N
40 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
41 | e = tl.exp(inp - m)
42 | out = e / z
43 | output_ptrs = output_ptr + offset
44 | tl.store(output_ptrs, out, mask=mask)
45 |
46 |
47 | def softmax(x):
48 | M, N = x.shape
49 | out = torch.empty_like(x)
50 | # TODO: tune
51 | TILE_N = min(4096, triton.next_power_of_2(N))
52 | grid = (M, 1, 1)
53 | softmax_kernel_loop_v2[grid](out, x, M, N, TILE_N)
54 | return out
55 |
56 | import pytest
57 |
58 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
59 | @pytest.mark.parametrize("m", [10, 128])
60 | def test_softmax(m, n):
61 | x = torch.randn((m, n), device="cuda")
62 | hyp = softmax(x)
63 | ref = torch.softmax(x, dim=-1)
64 | torch.testing.assert_close(hyp, ref)
65 |
66 | def benchmark_softmax(m, n):
67 | x = torch.randn((m, n), device="cuda")
68 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
69 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
70 | def throughput(t):
71 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
72 | return throughput(t1), throughput(t2)
73 |
74 | import pandas as pd
75 | def run_benchmark():
76 | records = []
77 | for m in [10, 128, 1024, 4096]:
78 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]:
79 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
80 | continue
81 | t1, t2 = benchmark_softmax(m, n)
82 | record = (m, n, t1, t2)
83 | records.append(record)
84 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "loop_v2", "torch"])
85 | print(df)
86 | df.to_excel("loop_v2.xlsx")
87 |
88 | if __name__ == "__main__":
89 | run_benchmark()
--------------------------------------------------------------------------------
/softmax_naive.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | TILE_N: tl.constexpr,
12 | ):
13 | pid_m = tl.program_id(0)
14 | n_offsets = tl.arange(0, TILE_N)
15 | offset = pid_m * N + n_offsets
16 | input_ptrs = input_ptr + offset
17 | mask = n_offsets < N
18 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
19 | m = tl.max(inp, 0)
20 | e = tl.exp(inp - m)
21 | z = tl.sum(e, 0)
22 | out = e / z
23 | output_ptrs = output_ptr + offset
24 | tl.store(output_ptrs, out, mask=mask)
25 |
26 |
27 | def softmax(x):
28 | M, N = x.shape
29 | out = torch.empty_like(x)
30 | TILE_N = triton.next_power_of_2(N)
31 | grid = (M, 1, 1)
32 | softmax_kernel[grid](out, x, M, N, TILE_N)
33 | return out
34 |
35 | import pytest
36 |
37 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
38 | @pytest.mark.parametrize("m", [10, 128, 1024])
39 | def test_softmax(m, n):
40 | x = torch.randn((m, n), device="cuda")
41 | hyp = softmax(x)
42 | ref = torch.softmax(x, dim=-1)
43 | torch.testing.assert_close(hyp, ref)
44 |
45 | def benchmark_softmax(m, n):
46 | x = torch.randn((m, n), device="cuda")
47 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
48 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
49 | def throughput(t):
50 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
51 | return throughput(t1), throughput(t2)
52 |
53 | import pandas as pd
54 | def run_benchmark():
55 | records = []
56 | for m in [10, 128, 1024, 4096]:
57 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024]:
58 | t1, t2 = benchmark_softmax(m, n)
59 | record = (m, n, t1, t2)
60 | records.append(record)
61 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "naive", "torch"])
62 | print(df)
63 | df.to_excel("naive.xlsx")
64 |
65 | def run_an_example(m, n):
66 | x = torch.randn((m, n), device="cuda")
67 | y = softmax(x)
68 |
69 | if __name__ == "__main__":
70 | run_an_example(4096, 4 * 1024)
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
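
softmax_naive.py loads a whole row as a single TILE_N = next_power_of_2(N) tile, so the entire reduction happens in one pass but the tile must fit on-chip; out-of-range lanes are loaded as -inf so they contribute exp(-inf) = 0 to the sum. A small PyTorch sketch of that padding trick (for illustration only):

    import torch

    N, TILE_N = 1000, 1024                       # TILE_N: next power of two >= N
    row = torch.randn(N, device="cuda")
    padded = torch.full((TILE_N,), float("-inf"), device="cuda")
    padded[:N] = row                             # masked lanes behave like other=-inf
    m = padded.max()
    e = torch.exp(padded - m)                    # exp(-inf - m) == 0, so the padding is inert
    out = (e / e.sum())[:N]
    torch.testing.assert_close(out, torch.softmax(row, dim=0))
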
/softmax_online_v1.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel_online_v1(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | TILE_N: tl.constexpr,
12 | ):
13 | pid_m = tl.program_id(0)
14 | m = tl.full((), value=-float("inf"), dtype=output_ptr.dtype.element_ty)
15 | z = tl.full((), value=0, dtype=output_ptr.dtype.element_ty)
16 | for start_n in range(0, N, TILE_N):
17 | n_offsets = start_n + tl.arange(0, TILE_N)
18 | offset = pid_m * N + n_offsets
19 | input_ptrs = input_ptr + offset
20 | mask = n_offsets < N
21 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
22 | new_m = tl.maximum(m, tl.max(inp, 0))
23 | new_z = tl.exp(m - new_m) * z + tl.sum(tl.exp(inp - new_m), 0)
24 | m = new_m
25 | z = new_z
26 |
27 | for start_n in range(0, N, TILE_N):
28 | n_offsets = start_n + tl.arange(0, TILE_N)
29 | offset = pid_m * N + n_offsets
30 | input_ptrs = input_ptr + offset
31 | mask = n_offsets < N
32 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
33 | e = tl.exp(inp - m)
34 | out = e / z
35 | output_ptrs = output_ptr + offset
36 | tl.store(output_ptrs, out, mask=mask)
37 |
38 |
39 | def softmax(x):
40 | M, N = x.shape
41 | out = torch.empty_like(x)
42 | # TODO tune
43 | TILE_N = min(4096, triton.next_power_of_2(N))
44 | grid = (M, 1, 1)
45 | softmax_kernel_online_v1[grid](out, x, M, N, TILE_N)
46 | return out
47 |
48 | import pytest
49 |
50 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
51 | @pytest.mark.parametrize("m", [10, 128])
52 | def test_softmax(m, n):
53 | x = torch.randn((m, n), device="cuda")
54 | hyp = softmax(x)
55 | ref = torch.softmax(x, dim=-1)
56 | torch.testing.assert_close(hyp, ref)
57 |
58 | def benchmark_softmax(m, n):
59 | x = torch.randn((m, n), device="cuda")
60 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
61 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
62 | def throughput(t):
63 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
64 | return throughput(t1), throughput(t2)
65 |
66 | import pandas as pd
67 | def run_benchmark():
68 | records = []
69 | for m in [10, 128, 1024, 4096]:
70 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]:
71 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
72 | continue
73 | t1, t2 = benchmark_softmax(m, n)
74 | record = (m, n, t1, t2)
75 | records.append(record)
76 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v1", "torch"])
77 | print(df)
78 | df.to_excel("online_v1.xlsx")
79 |
80 | if __name__ == "__main__":
81 | run_benchmark()
--------------------------------------------------------------------------------
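
The online variants fuse the max pass and the sum pass: for every tile the running maximum m and running normalizer z are updated together via z_new = exp(m - m_new) * z + sum(exp(x_tile - m_new)), so each row is read only twice instead of three times. A scalar Python sketch of the recurrence used in softmax_kernel_online_v1:

    import math

    def online_max_and_normalizer(row, tile=4):
        m, z = -math.inf, 0.0
        for start in range(0, len(row), tile):
            chunk = row[start:start + tile]
            new_m = max(m, max(chunk))
            # rescale the old partial sum to the new maximum, then add this tile
            z = math.exp(m - new_m) * z + sum(math.exp(v - new_m) for v in chunk)
            m = new_m
        return m, z

    m, z = online_max_and_normalizer([1.0, 3.0, 2.0, 0.5, 4.0])
    # softmax(row)[i] == exp(row[i] - m) / z
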
/softmax_online_v2.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel_online_v2(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | TILE_N: tl.constexpr,
12 | ):
13 | pid_m = tl.program_id(0)
14 | m = tl.full((TILE_N,), value=-float("inf"), dtype=output_ptr.dtype.element_ty)
15 | z = tl.full((TILE_N,), value=0, dtype=output_ptr.dtype.element_ty)
16 | for start_n in range(0, N, TILE_N):
17 | n_offsets = start_n + tl.arange(0, TILE_N)
18 | offset = pid_m * N + n_offsets
19 | input_ptrs = input_ptr + offset
20 | mask = n_offsets < N
21 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
22 | new_m = tl.maximum(m, inp)
23 | new_z = tl.exp(m - new_m) * z + tl.exp(inp - new_m)
24 | m = new_m
25 | z = new_z
26 | final_m = tl.max(m, 0)
27 | z = tl.sum(tl.exp(m - final_m) * z)
28 | m = final_m
29 |
30 | for start_n in range(0, N, TILE_N):
31 | n_offsets = start_n + tl.arange(0, TILE_N)
32 | offset = pid_m * N + n_offsets
33 | input_ptrs = input_ptr + offset
34 | mask = n_offsets < N
35 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
36 | e = tl.exp(inp - m)
37 | out = e / z
38 | output_ptrs = output_ptr + offset
39 | tl.store(output_ptrs, out, mask=mask)
40 |
41 |
42 | def softmax(x):
43 | M, N = x.shape
44 | out = torch.empty_like(x)
45 | # TODO tune
46 | TILE_N = min(4096, triton.next_power_of_2(N))
47 | grid = (M, 1, 1)
48 | softmax_kernel_online_v2[grid](out, x, M, N, TILE_N)
49 | return out
50 |
51 | import pytest
52 |
53 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
54 | @pytest.mark.parametrize("m", [10, 128])
55 | def test_softmax(m, n):
56 | x = torch.randn((m, n), device="cuda")
57 | hyp = softmax(x)
58 | ref = torch.softmax(x, dim=-1)
59 | torch.testing.assert_close(hyp, ref)
60 |
61 | def benchmark_softmax(m, n):
62 | x = torch.randn((m, n), device="cuda")
63 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
64 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
65 | def throughput(t):
66 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
67 | return throughput(t1), throughput(t2)
68 |
69 | import pandas as pd
70 | def run_benchmark():
71 | records = []
72 | for m in [10, 128, 1024, 4096]:
73 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]:
74 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
75 | continue
76 | t1, t2 = benchmark_softmax(m, n)
77 | record = (m, n, t1, t2)
78 | records.append(record)
79 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2", "torch"])
80 | print(df)
81 | df.to_excel("online_v2.xlsx")
82 |
83 | if __name__ == "__main__":
84 | run_benchmark()
--------------------------------------------------------------------------------
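
softmax_kernel_online_v2 keeps m and z as TILE_N-wide vectors (one running maximum and normalizer per lane) and performs the cross-lane reduction only once, after the loop: final_m = max(m) and z = sum(exp(m - final_m) * z). A vectorized PyTorch sketch of that final combine step (assuming m and z hold per-lane partial results):

    import torch

    def combine_lanes(m: torch.Tensor, z: torch.Tensor):
        # m[i], z[i]: running max / normalizer a lane accumulated over its strided slice
        final_m = m.max()
        final_z = (torch.exp(m - final_m) * z).sum()
        return final_m, final_z
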
/softmax_online_v2_evict.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def softmax_kernel_online_v2(
7 | output_ptr,
8 | input_ptr,
9 | M,
10 | N,
11 | TILE_N: tl.constexpr,
12 | ):
13 | pid_m = tl.program_id(0)
14 | m = tl.full((TILE_N,), value=-float("inf"), dtype=output_ptr.dtype.element_ty)
15 | z = tl.full((TILE_N,), value=0, dtype=output_ptr.dtype.element_ty)
16 | for start_n in range(0, N, TILE_N):
17 | n_offsets = start_n + tl.arange(0, TILE_N)
18 | offset = pid_m * N + n_offsets
19 | input_ptrs = input_ptr + offset
20 | mask = n_offsets < N
21 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
22 | new_m = tl.maximum(m, inp)
23 | new_z = tl.exp(m - new_m) * z + tl.exp(inp - new_m)
24 | m = new_m
25 | z = new_z
26 | final_m = tl.max(m, 0)
27 | z = tl.sum(tl.exp(m - final_m) * z)
28 | m = final_m
29 |
30 | for start_n in range(0, N, TILE_N):
31 | n_offsets = start_n + tl.arange(0, TILE_N)
32 | offset = pid_m * N + n_offsets
33 | input_ptrs = input_ptr + offset
34 | mask = n_offsets < N
35 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf"), eviction_policy="evict_first").to(output_ptr.dtype.element_ty)
36 | e = tl.exp(inp - m)
37 | out = e / z
38 | output_ptrs = output_ptr + offset
39 | tl.store(output_ptrs, out, mask=mask)
40 |
41 |
42 | def softmax(x):
43 | M, N = x.shape
44 | out = torch.empty_like(x)
45 | # TODO tune
46 | TILE_N = min(4096, triton.next_power_of_2(N))
47 | grid = (M, 1, 1)
48 | softmax_kernel_online_v2[grid](out, x, M, N, TILE_N)
49 | return out
50 |
51 | import pytest
52 |
53 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
54 | @pytest.mark.parametrize("m", [10, 128])
55 | def test_softmax(m, n):
56 | x = torch.randn((m, n), device="cuda")
57 | hyp = softmax(x)
58 | ref = torch.softmax(x, dim=-1)
59 | torch.testing.assert_close(hyp, ref)
60 |
61 | def benchmark_softmax(m, n):
62 | x = torch.randn((m, n), device="cuda")
63 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
64 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
65 | def throughput(t):
66 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
67 | return throughput(t1), throughput(t2)
68 |
69 | import pandas as pd
70 | def run_benchmark():
71 | records = []
72 | for m in [10, 128, 1024, 4096]:
73 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]:
74 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
75 | continue
76 | t1, t2 = benchmark_softmax(m, n)
77 | record = (m, n, t1, t2)
78 | records.append(record)
79 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_evict", "torch"])
80 | print(df)
81 | df.to_excel("online_v2_evict.xlsx")
82 |
83 | if __name__ == "__main__":
84 | run_benchmark()
--------------------------------------------------------------------------------
/softmax_online_v2_rev.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def next_multiple_of(a, b):
7 | # the smallest x >= a such that x % b == 0
8 | return tl.cdiv(a, b) * b
9 |
10 |
11 | @triton.jit
12 | def prev_multiple_of(a, b):
13 | # the largest x <= a such that x % b == 0
86 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
87 | continue
88 | t1, t2 = benchmark_softmax(m, n)
89 | record = (m, n, t1, t2)
90 | records.append(record)
91 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_rev", "torch"])
92 | print(df)
93 | df.to_excel("online_v2_rev.xlsx")
94 |
95 | if __name__ == "__main__":
96 | run_benchmark()
--------------------------------------------------------------------------------
/softmax_online_v2_spec.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def next_multiple_of(a, b):
7 | # the smallest x >= a such that x % b == 0
8 | return tl.cdiv(a, b) * b
9 |
10 |
11 | @triton.jit
12 | def prev_multiple_of(a, b):
13 | # the largest x <= a such that x % b == 0
106 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
107 | continue
108 | t1, t2 = benchmark_softmax(m, n)
109 | record = (m, n, t1, t2)
110 | records.append(record)
111 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_spec", "torch"])
112 | print(df)
113 | df.to_excel("online_v2_spec.xlsx")
114 |
115 | if __name__ == "__main__":
116 | run_benchmark()
--------------------------------------------------------------------------------
/softmax_online_v2_spec_rev.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def next_multiple_of(a, b):
7 | # the smallest x >= a such that x % b == 0
8 | return tl.cdiv(a, b) * b
9 |
10 |
11 | @triton.jit
12 | def prev_multiple_of(a, b):
13 | # the largest x 1024 * 1024 * 1024 * 4:
107 | continue
108 | t1, t2 = benchmark_softmax(m, n)
109 | record = (m, n, t1, t2)
110 | records.append(record)
111 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_spec_rev", "torch"])
112 | print(df)
113 | df.to_excel("online_v2_spec_rev.xlsx")
114 |
115 | if __name__ == "__main__":
116 | run_benchmark()
--------------------------------------------------------------------------------
/softmax_online_v2_spec_rev_evict.py:
--------------------------------------------------------------------------------
1 | import triton
2 | from triton import language as tl
3 | import torch
4 |
5 | @triton.jit
6 | def next_multiple_of(a, b):
7 | # the smallest x >= a such that x % b == 0
8 | return tl.cdiv(a, b) * b
9 |
10 |
11 | @triton.jit
12 | def prev_multiple_of(a, b):
13 | # the largest x 1024 * 1024 * 1024 * 4:
107 | continue
108 | t1, t2 = benchmark_softmax(m, n)
109 | record = (m, n, t1, t2)
110 | records.append(record)
111 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_spec_rev_evict", "torch"])
112 | print(df)
113 | df.to_excel("online_v2_spec_rev_evict.xlsx")
114 |
115 | if __name__ == "__main__":
116 | run_benchmark()
--------------------------------------------------------------------------------
/softmax_split.py:
--------------------------------------------------------------------------------
1 | import math
2 | import triton
3 | from triton import language as tl
4 | import torch
5 |
6 | @triton.jit
7 | def logsumexp_kernel(
8 | out_ptr,
9 | in_ptr,
10 | M,
11 | N,
12 | TILE_N: tl.constexpr,
13 | ):
14 | pid_n = tl.program_id(0)
15 | num_programs_n = tl.num_programs(0)
16 | pid_m = tl.program_id(1)
17 |
18 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)
19 | mask = n_offsets < N
20 | offset = pid_m * N + n_offsets
21 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to(out_ptr.dtype.element_ty)
22 | m = tl.max(inp, 0)
23 | e = tl.exp(inp - m)
24 | z = tl.sum(e, 0)
25 | logz = m + tl.log(z)
26 |
27 | output_ptrs = out_ptr + pid_m * num_programs_n + pid_n
28 | tl.store(output_ptrs, logz)
29 |
30 | @triton.jit
31 | def combine_logsumexp_kernel(out_ptr, inp_ptr, M, N, TILE_N: tl.constexpr):
32 | pid_m = tl.program_id(0)
33 | n_offsets = tl.arange(0, TILE_N)
34 | mask = n_offsets < N
35 | logzs = tl.load(inp_ptr + pid_m * N + n_offsets, other=-float("inf"), mask=mask).to(out_ptr.dtype.element_ty)
36 | m = tl.max(logzs, 0)
37 | e = tl.exp(logzs - m)
38 | z = tl.sum(e, 0)
39 | logz = m + tl.log(z)
40 | tl.store(out_ptr + pid_m, logz)
41 |
42 | @triton.jit
43 | def softmax_kernel(out_ptr, in_ptr, logz_ptr, M, N, TILE_N: tl.constexpr):
44 | pid_n = tl.program_id(0)
45 | pid_m = tl.program_id(1)
46 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)
47 | offset = pid_m * N + n_offsets
48 | mask = n_offsets < N
49 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to(out_ptr.dtype.element_ty)
50 | logz = tl.load(logz_ptr + pid_m).to(out_ptr.dtype.element_ty)
51 | out = tl.exp(inp - logz)
52 | tl.store(out_ptr + offset, out, mask=mask)
53 |
54 |
55 |
56 | def softmax(x):
57 | M, N = x.shape
58 |
59 | num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count
60 |
61 | TILE_N = min(4096, triton.next_power_of_2(N))
62 | num_tiles_n = triton.cdiv(N, TILE_N)
63 | logz = torch.empty((M, num_tiles_n), dtype=x.dtype, device=x.device)
64 | grid = (num_tiles_n, M, 1)
65 | logsumexp_kernel[grid](logz, x, M, N, TILE_N)
66 |
67 | combined_logz = torch.empty((M, ), dtype=x.dtype, device=x.device)
68 | TILE_N = triton.next_power_of_2(num_tiles_n)
69 | grid = (M, 1, 1)
70 | combine_logsumexp_kernel[grid](combined_logz, logz, M, num_tiles_n, TILE_N)
71 |
72 | out = torch.empty_like(x)
73 | TILE_N = min(4096, triton.next_power_of_2(N))
74 | num_tiles_n = triton.cdiv(N, TILE_N)
75 | grid = (num_tiles_n, M, 1)
76 | softmax_kernel[grid](out, x, combined_logz, M, N, TILE_N)
77 | return out
78 |
79 | import pytest
80 |
81 | @pytest.mark.parametrize("n", [128 * 1024, 1024 * 1024, 8192 * 1024])
82 | @pytest.mark.parametrize("m", [10, ])
83 | def test_softmax(m, n):
84 | x = torch.randn((m, n), device="cuda")
85 | hyp = softmax(x)
86 | ref = torch.softmax(x, dim=-1)
87 | torch.testing.assert_close(hyp, ref)
88 |
89 | def benchmark_softmax(m, n):
90 | x = torch.randn((m, n), device="cuda")
91 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
92 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median")
93 | def throughput(t):
94 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3)
95 | return throughput(t1), throughput(t2)
96 |
97 | import pandas as pd
98 | def run_benchmark():
99 | records = []
100 | for m in [1, 8, 16, 32]:
101 | for n in [16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024, 2048 * 1024, 4096 * 1024, 8192 * 1024]:
102 | if m * n * 4 > 1024 * 1024 * 1024 * 4:
103 | continue
104 | t1, t2 = benchmark_softmax(m, n)
105 | record = (m, n, t1, t2)
106 | records.append(record)
107 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "split", "torch"])
108 | print(df)
109 | df.to_excel("split.xlsx")
110 |
111 | if __name__ == "__main__":
112 | run_benchmark()
--------------------------------------------------------------------------------
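
softmax_split.py splits a single long row across programs: logsumexp_kernel writes one local logsumexp per tile, combine_logsumexp_kernel reduces those partial values (a logsumexp of logsumexps), and softmax_kernel then only needs exp(x - logz). The combine step works because exp(l_i) is the sum of exp over tile i, so the logsumexp of the l_i equals the logsumexp of the whole row. A compact PyTorch check of that identity (illustrative only):

    import torch

    x = torch.randn(4, 1 << 16, device="cuda")
    tiles = x.chunk(8, dim=-1)                                        # split the reduction axis
    partial = torch.stack([t.logsumexp(dim=-1) for t in tiles], -1)   # one logsumexp per tile
    logz = partial.logsumexp(dim=-1, keepdim=True)                    # combine the partials
    torch.testing.assert_close(torch.exp(x - logz), torch.softmax(x, dim=-1))
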