├── .github
│   └── workflows
│       ├── ci.yml
│       └── docs.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── index.md
│   ├── requirements.txt
│   └── triton_call.md
├── examples
│   ├── JAX_+_Triton_Flash_Attention.ipynb
│   ├── add.py
│   ├── fused_attention.py
│   ├── fusion
│   │   ├── benchmark_matmul.py
│   │   └── nn.py
│   ├── matmul.py
│   └── softmax.py
├── jax_triton
│   ├── __init__.py
│   ├── triton_lib.py
│   ├── utils.py
│   └── version.py
├── mkdocs.yml
├── pyproject.toml
└── tests
    ├── cluster_test.py
    ├── triton_call_test.py
    └── triton_test.py

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/.github/workflows/ci.yml

--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/.github/workflows/docs.yml

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
build/
*.egg-info
*.so

--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/.pre-commit-config.yaml

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/CONTRIBUTING.md

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/LICENSE

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/README.md

--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/docs/index.md

--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/docs/requirements.txt

--------------------------------------------------------------------------------
/docs/triton_call.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/docs/triton_call.md

--------------------------------------------------------------------------------
/examples/JAX_+_Triton_Flash_Attention.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/JAX_+_Triton_Flash_Attention.ipynb

--------------------------------------------------------------------------------
/examples/add.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/add.py

--------------------------------------------------------------------------------
/examples/fused_attention.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/fused_attention.py

--------------------------------------------------------------------------------
/examples/fusion/benchmark_matmul.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/fusion/benchmark_matmul.py

--------------------------------------------------------------------------------
/examples/fusion/nn.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/fusion/nn.py

--------------------------------------------------------------------------------
/examples/matmul.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/matmul.py

--------------------------------------------------------------------------------
/examples/softmax.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/examples/softmax.py

--------------------------------------------------------------------------------
/jax_triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/jax_triton/__init__.py

--------------------------------------------------------------------------------
/jax_triton/triton_lib.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/jax_triton/triton_lib.py

--------------------------------------------------------------------------------
/jax_triton/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/jax_triton/utils.py

--------------------------------------------------------------------------------
/jax_triton/version.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/jax_triton/version.py

--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/mkdocs.yml

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/pyproject.toml

--------------------------------------------------------------------------------
/tests/cluster_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/tests/cluster_test.py

--------------------------------------------------------------------------------
/tests/triton_call_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/tests/triton_call_test.py

--------------------------------------------------------------------------------
/tests/triton_test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jax-ml/jax-triton/HEAD/tests/triton_test.py

--------------------------------------------------------------------------------