├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── Cargo.toml
├── README.md
├── csrc
│   ├── block_info.h
│   ├── epilogue_bwd_sm90_tma.hpp
│   ├── epilogue_fwd_sm90_tma.hpp
│   ├── error.h
│   ├── flash.h
│   ├── flash_api.cu
│   ├── flash_api_rust.cu
│   ├── flash_attn_ops.cpp
│   ├── flash_bwd_hdim128_bf16_sm90.cu
│   ├── flash_bwd_hdim128_fp16_sm90.cu
│   ├── flash_bwd_hdim64_bf16_sm90.cu
│   ├── flash_bwd_hdim64_fp16_sm90.cu
│   ├── flash_bwd_hdim96_bf16_sm90.cu
│   ├── flash_bwd_hdim96_fp16_sm90.cu
│   ├── flash_bwd_kernel.h
│   ├── flash_bwd_launch_template.h
│   ├── flash_bwd_postprocess_kernel.h
│   ├── flash_bwd_preprocess_kernel.h
│   ├── flash_fwd_hdim128_bf16_sm90.cu
│   ├── flash_fwd_hdim128_e4m3_sm90.cu
│   ├── flash_fwd_hdim128_fp16_sm90.cu
│   ├── flash_fwd_hdim256_bf16_sm90.cu
│   ├── flash_fwd_hdim256_e4m3_sm90.cu
│   ├── flash_fwd_hdim256_fp16_sm90.cu
│   ├── flash_fwd_hdim64_bf16_sm90.cu
│   ├── flash_fwd_hdim64_e4m3_sm90.cu
│   ├── flash_fwd_hdim64_fp16_sm90.cu
│   ├── flash_fwd_kernel.h
│   ├── flash_fwd_launch_template.h
│   ├── kernel_helpers.h
│   ├── kernel_traits.h
│   ├── kernels.h
│   ├── mainloop_bwd_sm90_tma_gmma_ws.hpp
│   ├── mainloop_fwd_sm90_tma_gmma_ws.hpp
│   ├── named_barrier.hpp
│   ├── pybind11_kernel_helpers.h
│   ├── seq_len.h
│   ├── softmax.h
│   ├── static_switch.h
│   ├── tile_scheduler.hpp
│   ├── tile_scheduler_bwd.hpp
│   └── utils.h
├── jax_flash_attn
│   ├── __init__.py
│   └── register_ops.py
├── pyproject.toml
├── python
│   └── jflash_attn
│       ├── __init__.py
│       └── register_ops.py
├── rustfmt.toml
├── src
│   ├── ffi.rs
│   └── lib.rs
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/.gitignore
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/.gitmodules
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/CMakeLists.txt
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/Cargo.toml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/README.md
--------------------------------------------------------------------------------
/csrc/block_info.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/block_info.h
--------------------------------------------------------------------------------
/csrc/epilogue_bwd_sm90_tma.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/epilogue_bwd_sm90_tma.hpp
--------------------------------------------------------------------------------
/csrc/epilogue_fwd_sm90_tma.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/epilogue_fwd_sm90_tma.hpp
--------------------------------------------------------------------------------
/csrc/error.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/error.h
--------------------------------------------------------------------------------
/csrc/flash.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash.h
--------------------------------------------------------------------------------
/csrc/flash_api.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_api.cu
--------------------------------------------------------------------------------
/csrc/flash_api_rust.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_api_rust.cu
--------------------------------------------------------------------------------
/csrc/flash_attn_ops.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_attn_ops.cpp
--------------------------------------------------------------------------------
/csrc/flash_bwd_hdim128_bf16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_hdim128_bf16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_bwd_hdim128_fp16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_hdim128_fp16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_bwd_hdim64_bf16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_hdim64_bf16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_bwd_hdim64_fp16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_hdim64_fp16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_bwd_hdim96_bf16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_hdim96_bf16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_bwd_hdim96_fp16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_hdim96_fp16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_bwd_kernel.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_kernel.h
--------------------------------------------------------------------------------
/csrc/flash_bwd_launch_template.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_launch_template.h
--------------------------------------------------------------------------------
/csrc/flash_bwd_postprocess_kernel.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_postprocess_kernel.h
--------------------------------------------------------------------------------
/csrc/flash_bwd_preprocess_kernel.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_bwd_preprocess_kernel.h
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim128_bf16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim128_bf16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim128_e4m3_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim128_e4m3_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim128_fp16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim128_fp16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim256_bf16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim256_bf16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim256_e4m3_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim256_e4m3_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim256_fp16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim256_fp16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim64_bf16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim64_bf16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim64_e4m3_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim64_e4m3_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_hdim64_fp16_sm90.cu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_hdim64_fp16_sm90.cu
--------------------------------------------------------------------------------
/csrc/flash_fwd_kernel.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_kernel.h
--------------------------------------------------------------------------------
/csrc/flash_fwd_launch_template.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/flash_fwd_launch_template.h
--------------------------------------------------------------------------------
/csrc/kernel_helpers.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/kernel_helpers.h
--------------------------------------------------------------------------------
/csrc/kernel_traits.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/kernel_traits.h
--------------------------------------------------------------------------------
/csrc/kernels.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/kernels.h
--------------------------------------------------------------------------------
/csrc/mainloop_bwd_sm90_tma_gmma_ws.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/mainloop_bwd_sm90_tma_gmma_ws.hpp
--------------------------------------------------------------------------------
/csrc/mainloop_fwd_sm90_tma_gmma_ws.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/mainloop_fwd_sm90_tma_gmma_ws.hpp
--------------------------------------------------------------------------------
/csrc/named_barrier.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/named_barrier.hpp
--------------------------------------------------------------------------------
/csrc/pybind11_kernel_helpers.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/pybind11_kernel_helpers.h
--------------------------------------------------------------------------------
/csrc/seq_len.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/seq_len.h
--------------------------------------------------------------------------------
/csrc/softmax.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/softmax.h
--------------------------------------------------------------------------------
/csrc/static_switch.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/static_switch.h
--------------------------------------------------------------------------------
/csrc/tile_scheduler.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/tile_scheduler.hpp
--------------------------------------------------------------------------------
/csrc/tile_scheduler_bwd.hpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/tile_scheduler_bwd.hpp
--------------------------------------------------------------------------------
/csrc/utils.h:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/csrc/utils.h
--------------------------------------------------------------------------------
/jax_flash_attn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/jax_flash_attn/__init__.py
--------------------------------------------------------------------------------
/jax_flash_attn/register_ops.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/jax_flash_attn/register_ops.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/pyproject.toml
--------------------------------------------------------------------------------
/python/jflash_attn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/python/jflash_attn/__init__.py
--------------------------------------------------------------------------------
/python/jflash_attn/register_ops.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/python/jflash_attn/register_ops.py
--------------------------------------------------------------------------------
/rustfmt.toml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/rustfmt.toml
--------------------------------------------------------------------------------
/src/ffi.rs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/src/ffi.rs
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/src/lib.rs
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyutai-labs/jax-flash-attn3/HEAD/test.py
--------------------------------------------------------------------------------