├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── results-causal-fa.png
│   ├── results-causal.png
│   └── results-random.png
├── fa2_custom_mask
│   ├── __init__.py
│   ├── fa2_bwd.py
│   ├── fa2_custom_mask.py
│   ├── fa2_fwd.py
│   └── utils.py
├── fa2_original.py
├── pyproject.toml
├── requirements.txt
├── setup.py
└── test_benchmark.py

/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/.gitignore
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/README.md
--------------------------------------------------------------------------------
/data/results-causal-fa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/data/results-causal-fa.png
--------------------------------------------------------------------------------
/data/results-causal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/data/results-causal.png
--------------------------------------------------------------------------------
/data/results-random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/data/results-random.png
--------------------------------------------------------------------------------
/fa2_custom_mask/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/fa2_custom_mask/__init__.py
--------------------------------------------------------------------------------
/fa2_custom_mask/fa2_bwd.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/fa2_custom_mask/fa2_bwd.py
--------------------------------------------------------------------------------
/fa2_custom_mask/fa2_custom_mask.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/fa2_custom_mask/fa2_custom_mask.py
--------------------------------------------------------------------------------
/fa2_custom_mask/fa2_fwd.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/fa2_custom_mask/fa2_fwd.py
--------------------------------------------------------------------------------
/fa2_custom_mask/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/fa2_custom_mask/utils.py
--------------------------------------------------------------------------------
/fa2_original.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/fa2_original.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/pyproject.toml
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pytest~=8.3.1
triton~=3.0.0
torch~=2.2.1
numpy==1.26.4
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/setup.py
--------------------------------------------------------------------------------
/test_benchmark.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexzhang13/flashattention2-custom-mask/HEAD/test_benchmark.py
--------------------------------------------------------------------------------