├── setup.cfg ├── requirements.txt ├── pyproject.toml ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── hello_triton.py ├── grid.py ├── atomic.py ├── autograd.py ├── reduction.py ├── vector_add.py ├── block_ptr.py ├── heuristics.py ├── matmul.py ├── argmax.py ├── .gitignore ├── cnn.py ├── dropout.py └── relu.py /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = E9, F63, F7, F82 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | matplotlib 5 | pandas 6 | triton 7 | tabulate -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.flake8] 2 | max-line-length = 120 3 | select = ["E9", "F63", "F7", "F82"] 4 | 5 | [tool.black] 6 | line-length = 120 7 | target-version = ["py38"] 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/pycqa/flake8 9 | rev: 6.1.0 10 | hooks: 11 | - id: flake8 12 | - repo: https://github.com/psf/black 13 | rev: 23.9.0 14 | hooks: 15 | - id: black 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 daemyung jang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the 
"Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![header](https://capsule-render.vercel.app/api?type=waving&color=gradient&height=256&text=삼각형의%20실전!%20Triton%20초급) 2 | 3 | # 반갑습니다! 
🤗 4 | 5 | [![Inflearn](https://img.shields.io/badge/-Inflearn-brightgreen?style=for-the-badge)](https://inf.run/KRwn) 6 | [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/BQbBddeXfK) 7 | [![Triton](https://img.shields.io/badge/OpenAI-100000?style=for-the-badge&logo=openai&logoColor=white)](https://github.com/openai/triton) 8 | [![LinkedIn](https://img.shields.io/badge/linkedin-%230077B5.svg?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/djang-88b01b91) 9 | 10 | ------------------------------------------------------------------------------------------------------------------------ 11 | 12 | ## 📖 개요 13 | 14 | 이 저장소의 코드는 [삼각형의 실전! Triton 초급](https://inf.run/n1KPQ) 강의를 위해서 작성되었습니다. \ 15 | 강의에서 코드에 대한 설명을 들을 수 있기 때문에 코드를 위한 설명은 별도로 제공하지 않습니다. \ 16 | 학생이신 경우에 저에게 이메일을 보내주시면 **50%** 할인을 지원해 드리겠습니다. 17 | 18 | ## 🎓 학습 방법 19 | 20 | 예제를 실행하고 수정하면서 Triton과 친해질 수 있습니다. \ 21 | 이해가 안되거나 궁금한게 생기면 바로 질문하면 학습 효율을 높일 수 있습니다. 22 | 23 | ## 🙋 질문 방법 24 | 25 | 질문은 인프런을 통하거나 이슈를 생성해서 물어보면 됩니다. 26 | -------------------------------------------------------------------------------- /hello_triton.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Daemyung Jang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
import triton
import triton.language as tl


@triton.jit
def hello_triton():
    """Device-side kernel that prints a greeting from the GPU."""
    tl.device_print("Hello Triton!")


def main():
    # A grid of one: launch exactly one program instance of the kernel.
    grid = (1,)
    hello_triton[grid]()


if __name__ == "__main__":
    main()
import triton
import triton.language as tl


@triton.jit
def print_grid():
    """Print the id of the program instance executing this kernel."""
    pid = tl.program_id(0)
    tl.device_print("pid: ", pid)


def main():
    # Launch two program instances; each one prints its own program id.
    print_grid[lambda meta: (2,)]()


if __name__ == "__main__":
    main()
import torch
import triton
import triton.language as tl


@triton.jit
def atomic_kernel(x_ptr, increment):
    """Atomically add `increment` to the scalar stored at `x_ptr`."""
    tl.atomic_add(x_ptr, increment)


def atomic(increment):
    """Launch 1024 programs that each atomically add `increment` to one scalar.

    The atomic guarantees every addition lands, so the result is
    1024 * increment regardless of execution order.
    """
    x = torch.zeros(1, device="cuda")
    atomic_kernel[lambda meta: (1024,)](x, increment)
    return x


def main():
    print(atomic(2))


if __name__ == "__main__":
    main()
from typing import Any

import torch


class Mul(torch.autograd.Function):
    """Custom autograd function computing the element-wise product a * b."""

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        """Compute a * b and stash both operands for the backward pass."""
        a, b = args
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        """Return gradients w.r.t. a and b.

        d(a*b)/da = b and d(a*b)/db = a, each scaled by the incoming
        gradient. The original implementation returned `b, a` unscaled,
        which silently dropped `grad_outputs` and is only correct when the
        upstream gradient happens to be 1 (e.g. a bare `c.backward()`).
        """
        (grad_output,) = grad_outputs
        a, b = ctx.saved_tensors
        return grad_output * b, grad_output * a


mul = Mul.apply


def main():
    a = torch.tensor(2.0, requires_grad=True)
    b = torch.tensor(3.0)
    c = mul(a, b)
    print(f"{a} * {b} = {c}")
    c.backward()
    print(f"a's gradient is {a.grad}")


if __name__ == "__main__":
    main()
import torch
import triton
import triton.language as tl


@triton.jit
def combine_add(a, b):
    """Binary combine function for tl.reduce: plain addition."""
    return a + b


@triton.jit
def sum_kernel(y_ptr, x_ptr, size, block_size: tl.constexpr, ):
    """Reduce `size` elements starting at x_ptr into one scalar at y_ptr."""
    offsets = tl.arange(0, block_size)
    # block_size is a power of two >= size, so mask out the padded tail.
    x = tl.load(x_ptr + offsets, offsets < size)
    total = tl.reduce(x, 0, combine_add)
    tl.store(y_ptr, total)


def sum(x):
    """Sum all elements of a 1-D CUDA tensor using a single Triton program."""
    size = x.numel()
    y = torch.empty(1, device="cuda")
    sum_kernel[lambda meta: (1,)](y, x, size, triton.next_power_of_2(size))
    return y


def main():
    x = torch.randn(1024, device="cuda")
    assert torch.allclose(sum(x), torch.sum(x))


if __name__ == "__main__":
    main()
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, size, block_size: tl.constexpr):
    """Each program adds one block_size-wide slice of x and y into z."""
    block_start = tl.program_id(0) * block_size
    offsets = block_start + tl.arange(0, block_size)
    # Mask keeps the final (possibly partial) block from running off the end.
    mask = offsets < size

    x = tl.load(x_ptr + offsets, mask)
    y = tl.load(y_ptr + offsets, mask)
    tl.store(z_ptr + offsets, x + y, mask)


def add(x, y):
    """Element-wise addition of two CUDA tensors via the Triton kernel."""
    z = torch.empty_like(x, device="cuda")
    size = z.numel()
    add_kernel[lambda meta: (triton.cdiv(size, meta["block_size"]),)](x, y, z, size, 1024)
    return z


def main():
    size = 2**16
    x = torch.rand(size, device="cuda")
    y = torch.rand(size, device="cuda")
    assert torch.allclose(add(x, y), x + y)


if __name__ == "__main__":
    main()
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, size, block_size: tl.constexpr):
    """Block-pointer variant of element-wise addition: z = x + y."""
    block_offset = tl.program_id(0) * block_size

    # 1-D block pointers describing this program's window into each tensor.
    x_block_ptr = tl.make_block_ptr(
        x_ptr, shape=(size,), strides=(1,), offsets=(block_offset,), block_shape=(block_size,), order=(0,)
    )
    y_block_ptr = tl.make_block_ptr(
        y_ptr, shape=(size,), strides=(1,), offsets=(block_offset,), block_shape=(block_size,), order=(0,)
    )
    z_block_ptr = tl.make_block_ptr(
        z_ptr, shape=(size,), strides=(1,), offsets=(block_offset,), block_shape=(block_size,), order=(0,)
    )

    # boundary_check=(0,) masks dimension 0 so the tail block stays in bounds.
    total = tl.load(x_block_ptr, boundary_check=(0,)) + tl.load(y_block_ptr, boundary_check=(0,))
    tl.store(z_block_ptr, total, boundary_check=(0,))


def add(x, y):
    """Element-wise addition of two CUDA tensors via the block-pointer kernel."""
    z = torch.empty_like(x, device="cuda")
    size = z.numel()
    add_kernel[lambda meta: (triton.cdiv(size, meta["block_size"]),)](x, y, z, size, 1024)
    return z


def main():
    size = 2**16
    x = torch.rand(size, device="cuda")
    y = torch.rand(size, device="cuda")
    assert torch.allclose(add(x, y), x + y)


if __name__ == "__main__":
    main()
import torch
import triton
import triton.language as tl


# Derive `boundary_check` from the launch arguments: it is truthy exactly when
# `size` is not a multiple of `block_size`, i.e. when the last block is partial
# and loads/stores must be bounds-checked.
# BUG FIX: the kernel's parameter is named `size`, but the original heuristic
# read args["x_size"], which does not exist in the argument dict and raised a
# KeyError as soon as the kernel was launched.
@triton.heuristics({"boundary_check": lambda args: args["size"] % args["block_size"]})
@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, size, block_size: tl.constexpr, boundary_check: tl.constexpr):
    """Element-wise add with block pointers; boundary checks only when needed."""
    offset = tl.program_id(0) * block_size

    x_block_ptr = tl.make_block_ptr(
        x_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
    )
    y_block_ptr = tl.make_block_ptr(
        y_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
    )

    # boundary_check is a constexpr, so the unused branch is compiled away and
    # the evenly-divisible case pays no masking cost.
    if boundary_check:
        x = tl.load(x_block_ptr, boundary_check=(0,))
        y = tl.load(y_block_ptr, boundary_check=(0,))
    else:
        x = tl.load(x_block_ptr)
        y = tl.load(y_block_ptr)

    z = x + y

    z_block_ptr = tl.make_block_ptr(
        z_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
    )

    if boundary_check:
        tl.store(z_block_ptr, z, boundary_check=(0,))
    else:
        tl.store(z_block_ptr, z)


def add(x, y):
    """Element-wise addition of two CUDA tensors via the heuristic kernel."""
    z = torch.empty_like(x, device="cuda")
    size = z.numel()

    def grid(meta):
        return (triton.cdiv(size, meta["block_size"]),)

    add_kernel[grid](x, y, z, size, 1024)

    return z


def main():
    size = 2**16
    x = torch.rand(size, device="cuda")
    y = torch.rand(size, device="cuda")

    a = add(x, y)
    b = x + y

    assert torch.allclose(a, b)


if __name__ == "__main__":
    main()
documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    x_ptr,
    y_ptr,
    z_ptr,
    m_size,
    k_size,
    n_size,
    m_block_size: tl.constexpr,
    k_block_size: tl.constexpr,
    n_block_size: tl.constexpr,
):
    """Compute one (m_block_size, n_block_size) tile of z = x @ y."""
    pid = tl.program_id(0)
    # Map the linear program id onto a 2-D grid of output tiles.
    num_n_blocks = tl.cdiv(n_size, n_block_size)
    m_block = pid // num_n_blocks
    n_block = pid % num_n_blocks

    m_offsets = m_block * m_block_size + tl.arange(0, m_block_size)
    n_offsets = n_block * n_block_size + tl.arange(0, n_block_size)
    k_offsets = tl.arange(0, k_block_size)

    # Row-major addressing: x is (m, k), y is (k, n), z is (m, n).
    x_ptrs = x_ptr + m_offsets[:, None] * k_size + k_offsets[None, :]
    y_ptrs = y_ptr + k_offsets[:, None] * n_size + n_offsets[None, :]

    accumulator = tl.zeros((m_block_size, n_block_size), tl.float32)

    for _ in range(0, k_size, k_block_size):
        accumulator += tl.dot(tl.load(x_ptrs), tl.load(y_ptrs), allow_tf32=False)
        # Advance both operands one k-block along the reduction dimension.
        x_ptrs += k_block_size
        y_ptrs += k_block_size * n_size

    z_ptrs = z_ptr + m_offsets[:, None] * n_size + n_offsets[None, :]
    tl.store(z_ptrs, accumulator)


def matmul(x, y):
    """Multiply two row-major CUDA matrices with the tiled Triton kernel.

    Block sizes are set equal to the matrix sizes, so a single program
    computes the whole output tile (sufficient for this small example).
    """
    m_size, k_size = x.shape
    _, n_size = y.shape
    z = torch.empty(m_size, n_size, device="cuda")

    def grid(meta):
        return (triton.cdiv(m_size, meta["m_block_size"]) * triton.cdiv(n_size, meta["n_block_size"]),)

    matmul_kernel[grid](x, y, z, m_size, k_size, n_size, m_size, k_size, n_size)

    return z


def main():
    x = torch.randn(16, 16, device="cuda")
    y = torch.randn(16, 16, device="cuda")
    assert torch.allclose(matmul(x, y), torch.matmul(x, y))


if __name__ == "__main__":
    main()
2 | # 3 | # Copyright (c) 2023 Daemyung Jang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
import torch
import triton
import triton.language as tl
import triton.testing as testing


@triton.jit
def argmax_kernel(output_ptr, input_ptr, num_batches, size, block_size: tl.constexpr):
    """Each program writes the argmax of one row of a (num_batches, size) matrix."""
    batch = tl.program_id(0)

    output_block_ptr = tl.make_block_ptr(
        output_ptr, shape=(num_batches,), strides=(1,), offsets=(batch,), block_shape=(1,), order=(0,)
    )
    input_block_ptr = tl.make_block_ptr(
        input_ptr,
        shape=(num_batches, size),
        strides=(size, 1),
        offsets=(batch, 0),
        block_shape=(1, block_size),
        order=(1, 0),
    )

    row = tl.load(input_block_ptr, boundary_check=(1,))
    # block_size is the next power of two >= size; force the padded lanes to
    # -inf so they can never win the argmax.
    row = tl.where(tl.arange(0, block_size) < size, row, float("-inf"))
    result = tl.argmax(row, 1)
    tl.store(output_block_ptr, result.to(tl.int64))


def argmax(input, dim):
    """Row-wise argmax of a 2-D tensor; only dim=1 is supported."""
    if dim != 1:
        raise RuntimeError("Only 1 dim is supported.")

    num_batches, size = input.shape
    output = torch.empty(num_batches, device=input.device, dtype=torch.int64)
    block_size = triton.next_power_of_2(size)
    argmax_kernel[lambda meta: (num_batches,)](output, input, num_batches, size, block_size)
    return output


def validate():
    input = torch.rand(2, 4096, device="cuda")
    assert torch.allclose(argmax(input, 1), torch.argmax(input, 1))


@testing.perf_report(
    [
        testing.Benchmark(
            x_names=["size"],
            x_vals=[256 * i for i in range(1, 11, 1)],
            x_log=True,
            line_arg="backend",
            line_vals=["triton", "torch"],
            line_names=["Triton", "Torch"],
            ylabel="milliseconds",
            plot_name="argmax-performance",
            args={"num_batches": 8},
        ),
    ]
)
def benchmark(num_batches, size, backend):
    """Time the Triton kernel against torch.argmax across row sizes."""
    input = torch.rand(num_batches, size, device="cuda")

    if backend == "triton":
        return testing.do_bench(lambda: argmax(input, 1))
    return testing.do_bench(lambda: torch.argmax(input, 1))


def main():
    validate()
    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    main()
60 | .idea/**/dictionaries 61 | .idea/**/shelf 62 | 63 | # AWS User-specific 64 | .idea/**/aws.xml 65 | 66 | # Generated files 67 | .idea/**/contentModel.xml 68 | 69 | # Sensitive or high-churn files 70 | .idea/**/dataSources/ 71 | .idea/**/dataSources.ids 72 | .idea/**/dataSources.local.xml 73 | .idea/**/sqlDataSources.xml 74 | .idea/**/dynamic.xml 75 | .idea/**/uiDesigner.xml 76 | .idea/**/dbnavigator.xml 77 | 78 | # Gradle 79 | .idea/**/gradle.xml 80 | .idea/**/libraries 81 | 82 | # Gradle and Maven with auto-import 83 | # When using Gradle or Maven with auto-import, you should exclude module files, 84 | # since they will be recreated, and may cause churn. Uncomment if using 85 | # auto-import. 86 | # .idea/artifacts 87 | # .idea/compiler.xml 88 | # .idea/jarRepositories.xml 89 | # .idea/modules.xml 90 | # .idea/*.iml 91 | # .idea/modules 92 | # *.iml 93 | # *.ipr 94 | 95 | # CMake 96 | cmake-build-*/ 97 | 98 | # Mongo Explorer plugin 99 | .idea/**/mongoSettings.xml 100 | 101 | # File-based project format 102 | *.iws 103 | 104 | # IntelliJ 105 | out/ 106 | 107 | # mpeltonen/sbt-idea plugin 108 | .idea_modules/ 109 | 110 | # JIRA plugin 111 | atlassian-ide-plugin.xml 112 | 113 | # Cursive Clojure plugin 114 | .idea/replstate.xml 115 | 116 | # SonarLint plugin 117 | .idea/sonarlint/ 118 | 119 | # Crashlytics plugin (for Android Studio and IntelliJ) 120 | com_crashlytics_export_strings.xml 121 | crashlytics.properties 122 | crashlytics-build.properties 123 | fabric.properties 124 | 125 | # Editor-based Rest Client 126 | .idea/httpRequests 127 | 128 | # Android studio 3.1+ serialized cache file 129 | .idea/caches/build_file_checksums.ser 130 | 131 | ### PyCharm+all Patch ### 132 | # Ignore everything but code style settings and run configurations 133 | # that are supposed to be shared within teams. 
134 | 135 | .idea/* 136 | 137 | !.idea/codeStyles 138 | !.idea/runConfigurations 139 | 140 | ### Windows ### 141 | # Windows thumbnail cache files 142 | Thumbs.db 143 | Thumbs.db:encryptable 144 | ehthumbs.db 145 | ehthumbs_vista.db 146 | 147 | # Dump file 148 | *.stackdump 149 | 150 | # Folder config file 151 | [Dd]esktop.ini 152 | 153 | # Recycle Bin used on file shares 154 | $RECYCLE.BIN/ 155 | 156 | # Windows Installer files 157 | *.cab 158 | *.msi 159 | *.msix 160 | *.msm 161 | *.msp 162 | 163 | # Windows shortcuts 164 | *.lnk 165 | 166 | # End of https://www.toptal.com/developers/gitignore/api/windows,linux,macos,pycharm+all -------------------------------------------------------------------------------- /cnn.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Daemyung Jang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import matplotlib.pyplot as plt 24 | import torch 25 | import torch.nn as nn 26 | import torch.optim as optim 27 | import torch.hub as hub 28 | from torch.utils.data import DataLoader 29 | from torchvision import datasets, transforms 30 | 31 | 32 | class CNN(nn.Module): 33 | def __init__(self, device=None): 34 | super().__init__() 35 | factory_kwargs = {"device": device} 36 | self.model = nn.Sequential( 37 | nn.Conv2d(1, 32, 3, 1, **factory_kwargs), 38 | nn.ReLU(), 39 | nn.Conv2d(32, 64, 3, 1, **factory_kwargs), 40 | nn.ReLU(), 41 | nn.MaxPool2d(2), 42 | nn.Dropout(0.25), 43 | nn.Flatten(1), 44 | nn.Linear(9216, 128, **factory_kwargs), 45 | nn.ReLU(), 46 | nn.Dropout(0.5), 47 | nn.Linear(128, 10, **factory_kwargs), 48 | nn.Softmax(1), 49 | ) 50 | 51 | def forward(self, input): 52 | return self.model(input) 53 | 54 | 55 | def train(epoch, num_epochs, dataloader, model, criterion, optimizer, device): 56 | model.train() 57 | progress = hub.tqdm(dataloader, desc=f"[TRAIN] {epoch+1}/{num_epochs}") 58 | total_loss = 0 59 | 60 | for idx, (inputs, labels) in enumerate(progress): 61 | inputs = inputs.to(device) 62 | labels = labels.to(device) 63 | optimizer.zero_grad() 64 | outputs = model(inputs) 65 | loss = criterion(outputs, labels) 66 | loss.backward() 67 | optimizer.step() 68 | progress.set_postfix({"Loss": f"{loss:.4f}"}) 69 | total_loss += loss 70 | 71 | print(f"[TRAIN] {epoch+1}/{num_epochs} Average Loss: {total_loss/len(dataloader):.4f}") 72 | 73 | 74 | def test(epoch, num_epochs, dataloader, model, device): 75 | model.eval() 76 | progress = hub.tqdm(dataloader, desc=f"[TEST] {epoch+1}/{num_epochs}") 77 | total_correct = 0 78 | num_inputs = 0 79 | 80 | for idx, (inputs, 
labels) in enumerate(progress): 81 | inputs = inputs.to(device) 82 | labels = labels.to(device) 83 | outputs = model(inputs) 84 | outputs = torch.argmax(outputs, 1) 85 | correct = torch.sum(torch.eq(outputs, labels)) 86 | progress.set_postfix({"Accuracy": f"{100 * correct / labels.numel():.2f}"}) 87 | total_correct += correct 88 | num_inputs += labels.numel() 89 | 90 | print(f"[TEST] {epoch+1}/{num_epochs} Average Accuracy: {total_correct / num_inputs * 100:.2f}") 91 | 92 | 93 | def main(): 94 | num_epochs = 4 95 | num_batches = 256 96 | 97 | if torch.cuda.is_available(): 98 | device = torch.device("cuda") 99 | elif torch.backends.mps.is_available(): 100 | device = torch.device("mps") 101 | else: 102 | device = torch.device("cpu") 103 | 104 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.1307, 0.3081)]) 105 | train_dataset = datasets.MNIST("./datasets", train=True, transform=transform, download=True) 106 | test_dataset = datasets.MNIST("./datasets", train=False, transform=transform, download=True) 107 | train_dataloader = DataLoader(train_dataset, num_batches, True) 108 | test_dataloader = DataLoader(test_dataset, num_batches, True) 109 | 110 | cnn = CNN(device) 111 | criterion = nn.CrossEntropyLoss() 112 | optimizer = optim.Adam(cnn.parameters()) 113 | 114 | for epoch in range(0, num_epochs): 115 | train(epoch, num_epochs, train_dataloader, cnn, criterion, optimizer, device) 116 | test(epoch, num_epochs, test_dataloader, cnn, device) 117 | 118 | inputs, labels = next(iter(train_dataloader)) 119 | num_rows = 3 120 | num_cols = 5 121 | num_inputs = num_rows * num_cols 122 | inputs = inputs[:num_inputs,] 123 | labels = labels[:num_inputs,] 124 | outputs = cnn(inputs.to(device)) 125 | outputs = torch.argmax(outputs, 1) 126 | fig, axs = plt.subplots(num_rows, num_cols, figsize=(5, 5)) 127 | axs = axs.ravel() 128 | 129 | for i in range(num_inputs): 130 | axs[i].imshow(torch.permute(inputs[i], (1, 2, 0)).numpy()) 131 | 
axs[i].set_title(f"{labels[i]}/{outputs[i]}") 132 | axs[i].axis("off") 133 | 134 | plt.show() 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /dropout.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Daemyung Jang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
from typing import Any

import torch
import torch.autograd as autograd
import torch.random as random
import torch.nn as nn
import triton
import triton.language as tl


class DropoutKernel:
    """Triton kernels implementing inverted dropout (survivors scaled by 1/(1-p))."""

    @staticmethod
    @triton.jit
    def forward(output_ptr, input_ptr, size, p, seed, block_size: tl.constexpr):
        """Zero each element with probability ``p``; scale the rest by 1/(1-p)."""
        pid = tl.program_id(0)
        offset = pid * block_size

        input_block_ptr = tl.make_block_ptr(
            input_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )
        output_block_ptr = tl.make_block_ptr(
            output_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )

        # tl.rand is a counter-based RNG: the same (seed, element offset) pair
        # always produces the same value, so backward can rebuild this mask.
        offsets = tl.arange(0, block_size) + offset
        random_values = tl.rand(seed, offsets)
        condition = random_values > p
        input = tl.load(input_block_ptr, boundary_check=(0,))
        output = tl.where(condition, input * (1 / (1 - p)), 0.0)
        tl.store(output_block_ptr, output, boundary_check=(0,))

    @staticmethod
    @triton.jit
    def backward(grad_input_ptr, grad_output_ptr, size, p, seed, block_size: tl.constexpr):
        """Scale gradients by the forward keep-mask, regenerated from ``seed``.

        Bug fix: the previous implementation inferred the mask from
        ``output > 0``, which silently zeroed the gradient of any *kept*
        element whose input was <= 0 (hidden in the demo only because its
        inputs are torch.rand, i.e. strictly positive). Recomputing
        ``tl.rand(seed, offsets)`` reproduces the exact forward mask for
        every input value, so the saved output is no longer needed.
        """
        pid = tl.program_id(0)
        offset = pid * block_size

        grad_input_block_ptr = tl.make_block_ptr(
            grad_input_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )
        grad_output_block_ptr = tl.make_block_ptr(
            grad_output_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )

        # Replay the exact RNG stream used by forward.
        offsets = tl.arange(0, block_size) + offset
        random_values = tl.rand(seed, offsets)
        condition = random_values > p
        grad_output = tl.load(grad_output_block_ptr, boundary_check=(0,))
        grad_input = tl.where(condition, grad_output * (1 / (1 - p)), 0.0)
        tl.store(grad_input_block_ptr, grad_input, boundary_check=(0,))


class DropoutFunction(autograd.Function):
    """autograd.Function gluing the Triton dropout kernels into torch autograd."""

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        input, p, training = args
        output = torch.empty_like(input)
        size = input.numel()
        block_size = triton.next_power_of_2(input.shape[-1])
        # Draw the seed once and stash it so backward replays the same mask.
        seed = random.seed()

        def grid(meta):
            return (triton.cdiv(size, meta["block_size"]),)

        DropoutKernel.forward[grid](output, input, size, p, seed, block_size)

        ctx.p = p
        ctx.seed = seed

        return output

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        (grad_output,) = grad_outputs
        grad_input = torch.empty_like(grad_output)
        size = grad_input.numel()
        block_size = triton.next_power_of_2(grad_input.shape[-1])

        def grid(meta):
            return (triton.cdiv(size, meta["block_size"]),)

        DropoutKernel.backward[grid](grad_input, grad_output, size, ctx.p, ctx.seed, block_size)

        # No gradients for the `p` and `training` arguments.
        return grad_input, None, None


def dropout(input, p=0.5, training=True):
    """Apply dropout with drop probability ``p``; identity when not training."""
    if training:
        return DropoutFunction.apply(input, p, training)
    else:
        return input


class Dropout(nn.Module):
    """nn.Module wrapper so the Triton dropout respects module train/eval mode."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # drop probability forwarded to dropout()

    def forward(self, input):
        return dropout(input, self.p, self.training)


def main():
    """Small demo: forward + backward of the Triton dropout on CUDA."""
    input = torch.rand(6, device="cuda", requires_grad=True)
    p = 0.4
    output = dropout(input, p)
    grad_output = torch.ones_like(output)
    output.backward(grad_output)

    print(f"input : {input.data}")
    print(f"output : {output.data}")
    print(f"grad_input: {input.grad.data}")


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 Daemyung Jang 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
from typing import Any

import torch
import torch.autograd as autograd
import torch.nn as nn
import triton
import triton.language as tl


class ReLUKernel:
    """Triton kernels implementing the ReLU forward and backward passes."""

    @staticmethod
    @triton.jit
    def forward(output_ptr, input_ptr, size, block_size: tl.constexpr):
        # Each program instance handles one contiguous chunk of block_size elements.
        pid = tl.program_id(0)
        offset = pid * block_size

        # NOTE(review): strides=(1,) assumes both tensors are contiguous — confirm callers.
        input_block_ptr = tl.make_block_ptr(
            input_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )
        output_block_ptr = tl.make_block_ptr(
            output_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )

        # boundary_check masks the tail chunk when size is not a multiple of block_size.
        input = tl.load(input_block_ptr, boundary_check=(0,))
        condition = input >= 0
        output = tl.where(condition, input, 0)
        tl.store(output_block_ptr, output, boundary_check=(0,))

    @staticmethod
    @triton.jit
    def backward(grad_input_ptr, grad_output_ptr, input_ptr, size, block_size: tl.constexpr):
        # grad_input = grad_output where the saved forward input was >= 0, else 0.
        pid = tl.program_id(0)
        offset = pid * block_size

        grad_input_block_ptr = tl.make_block_ptr(
            grad_input_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )
        grad_output_block_ptr = tl.make_block_ptr(
            grad_output_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )
        input_block_ptr = tl.make_block_ptr(
            input_ptr, shape=(size,), strides=(1,), offsets=(offset,), block_shape=(block_size,), order=(0,)
        )

        grad_output = tl.load(grad_output_block_ptr, boundary_check=(0,))
        input = tl.load(input_block_ptr, boundary_check=(0,))
        # Subgradient convention: the gradient at input == 0 is passed through (>= 0).
        condition = input >= 0
        grad_input = tl.where(condition, grad_output, 0)
        tl.store(grad_input_block_ptr, grad_input, boundary_check=(0,))


class ReLUFunction(autograd.Function):
    """autograd.Function wrapper dispatching to the Triton ReLU kernels."""

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        (input,) = args
        output = torch.empty_like(input)
        size = input.numel()
        # Block size derived from the last dimension, rounded up to a power of
        # two (Triton block shapes must be powers of two).
        block_size = triton.next_power_of_2(input.shape[-1])

        def grid(meta):
            # Enough programs to cover every element once.
            return (triton.cdiv(size, meta["block_size"]),)

        ReLUKernel.forward[grid](output, input, size, block_size)

        # The forward input is needed to recompute the >= 0 mask in backward.
        ctx.save_for_backward(input)

        return output

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        (grad_output,) = grad_outputs
        (input,) = ctx.saved_tensors
        grad_input = torch.empty_like(grad_output)
        size = grad_input.numel()
        block_size = triton.next_power_of_2(grad_input.shape[-1])

        def grid(meta):
            return (triton.cdiv(size, meta["block_size"]),)

        ReLUKernel.backward[grid](grad_input, grad_output, input, size, block_size)

        return grad_input


def relu(input):
    """Differentiable functional ReLU backed by the Triton kernels."""
    return ReLUFunction.apply(input)


class ReLU(nn.Module):
    """nn.Module wrapper around relu() for use inside nn.Sequential models."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return relu(input)


def main():
    """Compare the Triton ReLU against torch.nn.functional.relu on random CUDA data."""
    # Inputs in [-1, 1) so both positive and negative branches are exercised.
    input = torch.rand(6, device="cuda") * 2 - 1

    # Triton path: clone so both implementations see identical values.
    input_a = input.clone()
    input_a.requires_grad = True
    grad_output_a = torch.ones_like(input_a)
    output_a = relu(input_a)
    output_a.backward(grad_output_a)

    # Reference torch path.
    input_b = input.clone()
    input_b.requires_grad = True
    grad_output_b = torch.ones_like(input_b)
    output_b = torch.nn.functional.relu(input_b)
    output_b.backward(grad_output_b)

    print(f"input ⬇️\ntriton: {input_a.data}\ntorch : {input_b.data}\n")
    print(f"output ⬇️\ntriton: {output_a.data}\ntorch : {output_b.data}\n")
    print(f"input_grad ⬇️\ntriton: {input_a.grad.data}\ntorch : {input_b.grad.data}\n")

    assert torch.allclose(input_a, input_b)
    assert torch.allclose(output_a, output_b)
    assert torch.allclose(input_a.grad, input_b.grad)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------