├── .github ├── license-check │ ├── config.json │ └── header.txt └── workflows │ └── license-header-check.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── NOTICE ├── README.md ├── demo ├── integrate_mlstm_via_backend_module_option1.ipynb ├── integrate_mlstm_via_direct_import_option2and3.ipynb └── kernel_speed_benchmark.ipynb ├── dev_notes.md ├── envs ├── environment_pt240cu121.yaml ├── environment_pt240cu124.yaml └── environment_pt251cu124.yaml ├── flash_attention.patch ├── mlstm_kernels ├── __init__.py ├── baselines │ ├── __init__.py │ ├── flash_attention │ │ ├── __init__.py │ │ ├── flash_attention_triton.py │ │ ├── torch_sdp_attention.py │ │ └── triton_tutorial.py │ ├── flash_linear_attention │ │ ├── __init__.py │ │ ├── fla_utils.py │ │ ├── gla │ │ │ ├── __init__.py │ │ │ ├── chunk.py │ │ │ ├── chunk_fuse.py │ │ │ ├── chunk_util.py │ │ │ ├── naive.py │ │ │ └── recurrent_fuse.py │ │ └── simple_gla │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── chunk.py │ │ │ └── naive.py │ └── lightning_attention │ │ ├── __init__.py │ │ ├── lightning_attn2.py │ │ └── utils.py ├── jax │ ├── __init__.py │ ├── chunkwise │ │ ├── __init__.py │ │ ├── native │ │ │ ├── __init__.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ ├── triton_limit_chunk │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── bw_parallel.py │ │ │ ├── bw_recurrent.py │ │ │ ├── fw.py │ │ │ ├── fw_parallel.py │ │ │ ├── fw_recurrent.py │ │ │ └── fwbw.py │ │ ├── triton_xl_chunk │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── bw_parallel_dK.py │ │ │ ├── bw_parallel_dQ.py │ │ │ ├── bw_parallel_dV.py │ │ │ ├── bw_recurrent.py │ │ │ ├── chunkwise_gates.py │ │ │ ├── fw.py │ │ │ ├── fw_parallel.py │ │ │ ├── fw_recurrent.py │ │ │ └── fwbw.py │ │ └── triton_xl_chunk_siging │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── bw_parallel_dK.py │ │ │ ├── bw_parallel_dQ.py │ │ │ ├── bw_parallel_dV.py │ │ │ ├── bw_recurrent.py │ │ │ ├── chunkwise_gates.py │ │ │ ├── fw.py │ │ │ ├── fw_parallel.py │ │ │ ├── fw_recurrent.py │ │ │ └── fwbw.py │ ├── parallel │ │ ├── __init__.py │ │ ├── native │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ ├── native_siging │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ └── native_stablef │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ ├── recurrent │ │ ├── __init__.py │ │ ├── native_sequence.py │ │ ├── native_sequence_scan.py │ │ ├── native_step.py │ │ └── triton_step.py │ ├── stride_utils.py │ ├── utils.py │ └── xla_utils.py ├── torch │ ├── __init__.py │ ├── backend_module.py │ ├── chunkwise │ │ ├── __init__.py │ │ ├── native │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ ├── triton_limit_chunk │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── bw_parallel.py │ │ │ ├── bw_recurrent.py │ │ │ ├── chunkwise_gates.py │ │ │ ├── fw.py │ │ │ ├── fw_parallel.py │ │ │ ├── fw_recurrent.py │ │ │ └── fwbw.py │ │ ├── triton_xl_chunk │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── bw_parallel_dK.py │ │ │ ├── bw_parallel_dQ.py │ │ │ ├── bw_parallel_dV.py │ │ │ ├── bw_recurrent.py │ │ │ ├── chunkwise_gates.py │ │ │ ├── fw.py │ │ │ ├── fw_parallel.py │ │ │ ├── fw_recurrent.py │ │ │ └── fwbw.py │ │ └── triton_xl_chunk_siging │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── bw_parallel_dK.py │ │ │ ├── bw_parallel_dQ.py │ │ │ ├── bw_parallel_dV.py │ │ │ ├── bw_recurrent.py │ │ │ ├── chunkwise_gates.py │ │ │ ├── fw.py │ │ │ ├── fw_parallel.py │ │ │ ├── fw_recurrent.py │ │ │ └── fwbw.py │ ├── 
kernel_wrappers.py │ ├── parallel │ │ ├── __init__.py │ │ ├── _legacy_native_siging │ │ │ ├── __init__.py │ │ │ ├── ops.py │ │ │ └── sig_ingate.py │ │ ├── _native_tiled.py │ │ ├── native │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ ├── native_siging │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ ├── native_stablef │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ │ └── triton_limit_headdim │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── fw.py │ │ │ └── fwbw.py │ ├── recurrent │ │ ├── __init__.py │ │ ├── native_sequence.py │ │ ├── native_step.py │ │ ├── triton_step.py │ │ └── triton_step_alternate.py │ └── utils.py ├── triton │ ├── __init__.py │ ├── chunkwise │ │ ├── __init__.py │ │ ├── kernel_param_heuristics.py │ │ ├── limit_chunk │ │ │ ├── __init__.py │ │ │ ├── bw_kernel_parallel.py │ │ │ ├── bw_kernel_recurrent.py │ │ │ ├── fw_kernel_parallel.py │ │ │ └── fw_kernel_recurrent.py │ │ ├── xl_chunk │ │ │ ├── __init__.py │ │ │ ├── bw_kernel_parallel_dK.py │ │ │ ├── bw_kernel_parallel_dQ.py │ │ │ ├── bw_kernel_parallel_dV.py │ │ │ ├── bw_kernel_recurrent.py │ │ │ ├── fw_kernel_parallel.py │ │ │ └── fw_kernel_recurrent.py │ │ └── xl_chunk_siging │ │ │ ├── __init__.py │ │ │ ├── bw_kernel_parallel_dK.py │ │ │ ├── bw_kernel_parallel_dQ.py │ │ │ ├── bw_kernel_parallel_dV.py │ │ │ ├── bw_kernel_recurrent.py │ │ │ ├── fw_kernel_parallel.py │ │ │ └── fw_kernel_recurrent.py │ ├── kernel_param_heuristics.py │ ├── parallel │ │ ├── __init__.py │ │ └── limit_headdim │ │ │ ├── __init__.py │ │ │ ├── bw_kernel.py │ │ │ └── fw_kernel.py │ └── recurrent │ │ ├── __init__.py │ │ ├── fw_step_alternate.py │ │ └── fw_step_fused.py └── utils │ ├── __init__.py │ ├── analysis │ ├── __init__.py │ ├── roofline_analysis │ │ ├── __init__.py │ │ ├── flop_optimal_chunk_size.ipynb │ │ ├── flops_mlstm.py │ │ ├── memops_mlstm.py │ │ ├── mlstm_flop_analysis.ipynb │ │ ├── mlstm_memop_analysis.ipynb │ │ ├── mlstm_roofline_analysis.ipynb │ │ ├── plot_config.py │ │ ├── plot_mlstm_arithmetic_intensity.py │ │ ├── plot_mlstm_flop_analysis.py │ │ ├── plot_mlstm_optimal_chunksize.py │ │ ├── plot_roofline_model.py │ │ ├── plot_runtime.py │ │ └── roofline_analysis_mlstm.py │ └── transfer_behavior │ │ ├── __init__.py │ │ ├── _mlstm_cells.py │ │ ├── _norm_layers.py │ │ ├── generate_transfer_behavior_data.py │ │ ├── mlstm_cell_func.py │ │ └── plot_transfer_behavior.py │ ├── benchmark │ ├── __init__.py │ ├── benchmarks │ │ ├── huggingface_model_benchmark.py │ │ ├── huggingface_model_configs.py │ │ ├── inference_kernel_benchmarks.py │ │ ├── interface.py │ │ ├── model_benchmarks.py │ │ ├── training_kernel_benchmarks.py │ │ └── vllm_model_benchmark.py │ ├── cuda_graphs.py │ ├── param_handling.py │ ├── plot_config.py │ ├── plot_results.py │ ├── run_benchmark.py │ ├── runtime.py │ └── utils.py │ ├── flops │ ├── __init__.py │ ├── mlstm_block_flop_counts.py │ ├── mlstm_flop_analysis.py │ ├── model_flops_computation.py │ ├── slstm_block_flop_counts.py │ └── transformer_block_flop_counts.py │ ├── kernels.py │ ├── plot │ ├── __init__.py │ ├── bar_plot.py │ ├── diff_imshow.py │ ├── diff_lineplot.py │ └── ewma.py │ ├── test │ ├── __init__.py │ ├── checks.py │ ├── fixtures.py │ ├── test_fwbw.py │ └── test_templates │ │ └── __init__.py │ └── time.py ├── notebooks ├── flop_counting │ ├── mlstm_block--count_flops_by_hand.ipynb │ ├── mlstm_cell--count_flops_by_hand.ipynb │ └── mlstm_cell_vs_slstm_cell_flops.ipynb ├── plots_7B_model_benchmark │ ├── __init__.py │ ├── 
avg_acc_ruler_abl.pkl │ ├── avg_acc_ruler_main.pkl │ ├── dumps │ │ └── scheduler │ │ │ ├── loss_df.pkl │ │ │ ├── lr_df.pkl │ │ │ └── ppl_df.pkl │ ├── gen_time_mem_data.p │ ├── gen_time_mem_data_vllm.p │ ├── hf_7B_genlength_results--final.ipynb │ ├── hf_7B_genlength_results--for paper--final.ipynb │ ├── hf_7B_throughput_results--final.ipynb │ ├── hf_7B_throughput_results--for paper--final.ipynb │ ├── hf_7B_timetofirsttok_results--final.ipynb │ ├── hf_7B_timetofirsttok_results--for paper--final.ipynb │ ├── plot_config.py │ ├── plot_config_for_paper.py │ ├── plot_paper_figures_custom.py │ ├── plot_results_for_paper.py │ ├── ruler_7B_igate_ablation_plot.ipynb │ ├── throughput_data.p │ ├── throughput_df.p │ ├── throughput_vllm_df.p │ ├── ttft_raw_data.p │ ├── ttft_raw_data_vllm.p │ ├── vllmhf_7B_genlength_results--final.ipynb │ ├── vllmhf_7B_throughput_results--final.ipynb │ └── vllmhf_7B_timetofirsttok_results--final.ipynb ├── plots_mlstm_kernel_benchmark │ ├── consttoken_results.pkl │ ├── gen_results.pkl │ ├── plot_config.py │ ├── plot_consttoken_results.ipynb │ └── plot_generate_results.ipynb ├── plots_mlstm_kernel_benchmark_tfla_paper │ ├── mlstm_tfla_paper_consttoken_benchmark_results.p │ ├── mlstm_tfla_paper_consttoken_benchmark_results_lightn_attn.p │ ├── mlstm_tfla_paper_consttoken_benchmark_results_rerun.p │ ├── mlstm_tfla_paper_head_dim_benchmark_results.p │ ├── plot_config.py │ ├── plot_mlstm_tfla_consttoken_benchmark--rerun.ipynb │ ├── plot_mlstm_tfla_consttoken_benchmark.ipynb │ ├── plot_mlstm_tfla_consttoken_benchmark_appendix_gla.ipynb │ ├── plot_mlstm_tfla_consttoken_benchmark_appendix_lightnattn.ipynb │ ├── plot_mlstm_tfla_headdim_benchmark.ipynb │ └── plot_mlstm_tfla_memory_runtime_tradeoff.ipynb ├── plots_roofline_analysis │ ├── plot_arithmetic_intensity_mlstm.ipynb │ ├── plot_flop_comparison_mlstm_formulations.ipynb │ ├── plot_optimal_chunk_size.ipynb │ ├── plot_roofline_model.ipynb │ └── plot_theoretical_runtime_mlstm.ipynb ├── transfer_behavior_analysis │ ├── m_state_explore.ipynb │ ├── mlstm_kernels_epsilon_runs.ipynb │ ├── norm_eps_vs_igbias_raw_data_df.p │ ├── plot_config.py │ ├── plot_transfer_behavior_norm_eps_grid--experimental.ipynb │ ├── plot_transfer_behavior_norm_eps_grid--paper.ipynb │ └── plot_transfer_behavior_sig_vs_exp--paper.ipynb └── triton_explore │ └── understand_block_ptr.ipynb ├── notebooks_kernel_dev ├── integrate_lightning_attn2.ipynb ├── kernel_speed_benchmark.ipynb └── mlstm_siging │ ├── native_siging_legacy_test.ipynb │ ├── native_siging_optional_normalize_test.ipynb │ ├── triton_xl_chunk_exping_test.ipynb │ └── triton_xl_chunk_siging_test.ipynb ├── pyproject.toml ├── pytest.ini ├── res ├── Figure 2 - paper.svg ├── Figure_1-7.pdf ├── Figure_1-7.svg └── plot_tfla_mlstm_kernel_benchmark--paper-rerun.svg ├── scripts ├── run_hf_model_benchmark.py ├── run_hf_model_benchmark_debug.py ├── run_hf_model_benchmark_with_profile.py ├── run_inference_kernel_benchmarks.py ├── run_training_kernel_benchmark.sh ├── run_training_kernel_benchmarks.py ├── run_training_kernel_benchmarks_with_profile.py └── run_vllm_model_benchmark.py ├── setup.cfg └── tests ├── __init__.py ├── conftest.py ├── jax ├── __init__.py ├── chunkwise │ ├── __init__.py │ ├── conftest.py │ ├── test_chunkwise_native.py │ ├── test_chunkwise_triton_limit_chunk.py │ ├── test_chunkwise_triton_xl_chunk.py │ └── test_chunkwise_triton_xl_chunk_siging.py ├── conftest.py ├── losses_tests.py ├── parallel │ ├── __init__.py │ ├── test_parallel_native.py │ ├── test_parallel_native_siging.py │ ├── 
test_parallel_native_stablef.py │ ├── test_parallel_native_vs_native_stablef.py │ └── test_vs_torch.py ├── recurrent │ ├── __init__.py │ ├── test_recurrent_sequence_native.py │ └── test_recurrent_sequence_scan_native.py ├── template_test_against_pytorch.py └── template_test_parallel_interface.py ├── test_padding.py └── torch ├── __init__.py ├── chunkwise ├── __init__.py ├── conftest.py ├── test_chunkwise_native.py ├── test_chunkwise_triton_limit_chunk.py ├── test_chunkwise_triton_xl_chunk.py └── test_chunkwise_triton_xl_chunk_siging.py ├── conftest.py ├── losses_tests.py ├── parallel ├── __init__.py ├── test_parallel_native_siging.py ├── test_parallel_native_vs_native_stablef.py └── test_parallel_triton_limit_headdim.py ├── recurrent ├── __init__.py ├── test_recurrent_sequence_native_vs_parallel_native.py ├── test_recurrent_sequence_triton_vs_parallel_native.py └── test_sequence_chunked.py ├── template_test_arbitrary_sequence_length.py ├── template_test_parallel_interface.py ├── test_arbitrary_sequence_length.py ├── test_backend_module.py └── test_pad_zeros.py /.github/license-check/config.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "include": [ 4 | "**/*.py" 5 | ], 6 | "exclude": [ 7 | "mlstm_kernels/baselines/flash_attention/flash_attention_triton.py", 8 | "mlstm_kernels/baselines/flash_attention/triton_tutorial.py", 9 | "mlstm_kernels/baselines/flash_linear_attention/gla/__init__.py", 10 | "mlstm_kernels/baselines/flash_linear_attention/fla_utils.py", 11 | "mlstm_kernels/baselines/flash_linear_attention/gla/chunk_fuse.py", 12 | "mlstm_kernels/baselines/flash_linear_attention/gla/chunk.py", 13 | "mlstm_kernels/baselines/flash_linear_attention/gla/recurrent_fuse.py", 14 | "mlstm_kernels/baselines/flash_linear_attention/simple_gla/__init__.py", 15 | "mlstm_kernels/baselines/flash_linear_attention/simple_gla/chunk.py", 16 | "mlstm_kernels/baselines/lightning_attention/**", 17 | "**/*.md" 18 | ], 19 | "license": "./.github/license-check/header.txt" 20 | }, 21 | { 22 | "include": [ 23 | "node_modules/**" 24 | ] 25 | } 26 | ] 27 | -------------------------------------------------------------------------------- /.github/license-check/header.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | -------------------------------------------------------------------------------- /.github/workflows/license-header-check.yml: -------------------------------------------------------------------------------- 1 | name: License Check 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Check license headers 11 | uses: viperproject/check-license-header@v2 12 | with: 13 | path: ./ 14 | config: ./.github/license-check/config.json 15 | strict: false 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### CUSTOM 2 | wandb 3 | outputs 4 | data/* 5 | outputs_kernel_benchmarks* 6 | tests_outputs 7 | .vscode 8 | outputs_tests 9 | configs 10 | configs/* 11 | resources/*.p 12 | tests_outputs 13 | *.png 14 | *.csv 15 | notebooks/*/plots/ 16 | 17 | **/build/** 18 | 19 | <** 20 | 21 | *.nsys-rep 22 | *.ncu-rep 23 | nvidia_nsight/ 24 | 25 | .ruff_cache 26 | 27 | ### DEFAULTS 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | 33 | # C extensions 34 | *.so 35 | 36 | # Distribution / packaging 37 | .Python 38 | develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | .eggs/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | wheels/ 48 | pip-wheel-metadata/ 49 | share/python-wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | # Jupyter Notebook 62 | .ipynb_checkpoints 63 | 64 | # IPython 65 | profile_default/ 66 | ipython_config.py 67 | 68 | # pyenv 69 | .python-version 70 | 71 | # Unit test / coverage reports 72 | htmlcov/ 73 | .tox/ 74 | .nox/ 75 | .coverage 76 | .coverage.* 77 | .cache 78 | nosetests.xml 79 | coverage.xml 80 | *.cover 81 | *.py,cover 82 | .hypothesis/ 83 | .pytest_cache/ 84 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.6.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-case-conflict 9 | - id: check-docstring-first 10 | - id: check-executables-have-shebangs 11 | # - id: check-json # TODO want to allow comments in vscode launch.json 12 | - id: check-shebang-scripts-are-executable 13 | - id: check-merge-conflict 14 | - id: check-symlinks 15 | - id: check-toml 16 | - id: check-yaml 17 | - id: debug-statements 18 | - id: detect-private-key 19 | - id: end-of-file-fixer 20 | - id: mixed-line-ending 21 | # - id: pretty-format-json 22 | # args: [ --autofix, --no-sort-keys ] 23 | # - id: name-tests-test # TODO this does not allow for other files in tests/ folder 24 | # args: [ --pytest-test-first ] 25 | # - id: no-commit-to-branch 26 | - id: trailing-whitespace 27 | - repo: https://github.com/asottile/setup-cfg-fmt 28 | rev: v2.5.0 29 | hooks: 30 | - id: setup-cfg-fmt 31 | - repo: https://github.com/asottile/pyupgrade 32 | rev: v3.17.0 33 | hooks: 34 | - id: pyupgrade 35 | args: [ --py311-plus ] 36 | - repo: https://github.com/astral-sh/ruff-pre-commit 37 | rev: v0.6.1 38 | hooks: 39 | # - id: ruff 40 | # types_or: [ python, pyi, jupyter ] 41 | # args: [ --fix ] 42 | - id: ruff-format
types_or: [ python, pyi, jupyter ] 44 | # This would clean up all notebooks. 45 | # - repo: https://github.com/srstevenson/nb-clean 46 | # rev: 4.0.1 47 | # hooks: 48 | # - id: nb-clean 49 | # args: 50 | # - --remove-empty-cells 51 | # - --preserve-cell-metadata 52 | # - tags 53 | # - slideshow 54 | # - -- 55 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "pytest", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "stopOnEntry": false, 12 | // "python": "${command:python.interpreterPath}", 13 | "module": "pytest", 14 | // "program": "${workspaceFolder}/run.py", 15 | "cwd": "${workspaceFolder}", 16 | "console": "integratedTerminal", 17 | "justMyCode": false, 18 | "args": [ 19 | // "-cn", 20 | "${workspaceFolder}/tests/torch/test_arbitrary_sequence_length.py", 21 | ], 22 | "env": { 23 | "CUDA_VISIBLE_DEVICES": "0", 24 | "LOGLEVEL": "DEBUG" 25 | } 26 | }, 27 | { 28 | "name": "kernel_benchmark", 29 | "type": "debugpy", 30 | "request": "launch", 31 | "stopOnEntry": false, 32 | "python": "${command:python.interpreterPath}", 33 | "program": "${workspaceFolder}/scripts/run_kernel_benchmark.py", 34 | "cwd": "${workspaceFolder}", 35 | "console": "integratedTerminal", 36 | "justMyCode": false, 37 | "args": [ 38 | // "-cn", 39 | // "${workspaceFolder}/scripts/run_kernel_benchmark.py", 40 | ], 41 | "env": { 42 | "CUDA_VISIBLE_DEVICES": "0", 43 | "LOGLEVEL": "DEBUG", 44 | "PYTHONPATH": "${workspaceFolder}" 45 | } 46 | }, 47 | 48 | { 49 | "name": "model_benchmark_with_profile", 50 | "type": "debugpy", 51 | "request": "launch", 52 | "stopOnEntry": false, 53 | "python": "${command:python.interpreterPath}", 54 | "program": "${workspaceFolder}/scripts/run_hf_model_benchmark_with_profile.py", 55 | "cwd": "${workspaceFolder}", 56 | "console": "integratedTerminal", 57 | "justMyCode": false, 58 | "args": [ 59 | // "-cn", 60 | // "${workspaceFolder}/scripts/run_kernel_benchmark.py", 61 | ], 62 | "env": { 63 | "CUDA_VISIBLE_DEVICES": "0", 64 | "LOGLEVEL": "DEBUG", 65 | "PYTHONPATH": "${workspaceFolder}" 66 | } 67 | }, 68 | 69 | { 70 | "name": "hf_model_benchmark_debug", 71 | "type": "debugpy", 72 | "request": "launch", 73 | "stopOnEntry": false, 74 | "python": "${command:python.interpreterPath}", 75 | "program": "${workspaceFolder}/scripts/run_hf_model_benchmark_debug.py", 76 | "cwd": "${workspaceFolder}", 77 | "console": "integratedTerminal", 78 | "justMyCode": false, 79 | "args": [ 80 | // "-cn", 81 | // "${workspaceFolder}/scripts/run_kernel_benchmark.py", 82 | ], 83 | "env": { 84 | "CUDA_VISIBLE_DEVICES": "0", 85 | "LOGLEVEL": "DEBUG", 86 | "PYTHONPATH": "${workspaceFolder}" 87 | } 88 | }, 89 | 90 | ] 91 | } 92 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.rulers": [120], 3 | "[python]": { 4 | "editor.defaultFormatter": "charliermarsh.ruff", 5 | "editor.formatOnType": true, 6 | "editor.codeActionsOnSave": { 7 | "source.organizeImports": "explicit" 8 | } 9 | }, 10 | "editor.formatOnSave": false, 11 | "isort.args": ["--profile", "ruff"], 12 | 
"files.watcherExclude": { 13 | "outputs/**": true, 14 | "**/*.feather": true, 15 | "data/**": true, 16 | "**/.git/objects/**": true, 17 | "**/.git/subtree-cache/**": true 18 | }, 19 | "python.linting.enabled": false, 20 | "C_Cpp.intelliSenseEngine": "disabled", 21 | "C_Cpp.default.cppStandard": "c++17", 22 | "C_Cpp.clang_format_fallbackStyle": "LLVM", 23 | "jupyter.debugJustMyCode": false, 24 | "editor.fastScrollSensitivity": 50, 25 | "editor.mouseWheelScrollSensitivity": 3, 26 | "python.testing.pytestArgs": ["tests"], 27 | "python.testing.unittestEnabled": false, 28 | "python.testing.pytestEnabled": true, 29 | "files.autoSave": "onFocusChange" 30 | } 31 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | NXAI mLSTM Kernels 2 | Copyright (c) NXAI GmbH. 3 | 4 | This product includes software from the triton project (https://github.com/triton-lang/triton/) licensed under the MIT License. 5 | Copyright 2018-2020 Philippe Tillet. 6 | Copyright 2020-2022 OpenAI. 7 | 8 | This product includes software from Dao-AILab (https://github.com/Dao-AILab/flash-attention) licensed under the BSD-3-Clause License. 9 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 10 | All rights reserved. 11 | 12 | This product includes software from Songlin Yang (https://github.com/sustcsonglin/flash-linear-attention) licensed under the MIT License. 13 | Copyright (c) 2024 Songlin Yang. 14 | -------------------------------------------------------------------------------- /demo/integrate_mlstm_via_backend_module_option1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "\n", 11 | "sys.path.append(\"..\")\n", 12 | "\n", 13 | "import torch" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from mlstm_kernels.torch.backend_module import mLSTMBackendConfig, mLSTMBackend" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# we use the mLSTMexp TFLA kernel\n", 32 | "# we also configure to use the triton step kernel for inference\n", 33 | "mlstm_backend_config = mLSTMBackendConfig(\n", 34 | " chunkwise_kernel=\"chunkwise--triton_xl_chunk\",\n", 35 | " sequence_kernel=\"native_sequence__triton\",\n", 36 | " step_kernel=\"triton\",\n", 37 | " chunk_size=256,\n", 38 | " return_last_states=False,\n", 39 | ")\n", 40 | "\n", 41 | "mlstm_backend = mLSTMBackend(mlstm_backend_config)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# run the kernel\n", 51 | "DEVICE = torch.device(\"cuda\")\n", 52 | "DTYPE = torch.bfloat16\n", 53 | "B = 2\n", 54 | "S = 512\n", 55 | "DHQK = 128\n", 56 | "DHHV = 256\n", 57 | "NH = 4\n", 58 | "\n", 59 | "# create input tensors\n", 60 | "torch.manual_seed(1)\n", 61 | "matQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\n", 62 | "matK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\n", 63 | "matV = torch.randn((B, NH, S, DHHV), dtype=DTYPE, device=DEVICE)\n", 64 | "vecI = torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)\n", 65 | "vecF = 3.0 + torch.randn((B, NH, 
S), dtype=DTYPE, device=DEVICE)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "matH1 = mlstm_backend(q=matQ, k=matK, v=matV, i=vecI, f=vecF)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# directly import mLSTMexp TFLA kernel\n", 84 | "from mlstm_kernels.torch.chunkwise.triton_xl_chunk import mlstm_chunkwise__xl_chunk" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "matH2 = mlstm_chunkwise__xl_chunk(\n", 94 | "    q=matQ, k=matK, v=matV, i=vecI, f=vecF, return_last_states=False, chunk_size=256\n", 95 | ")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "torch.allclose(matH1, matH2, atol=1e-5, rtol=1e-3)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "mlstmpt251cu124", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.11.11" 132 | } 133 | }, 134 | "nbformat": 4, 135 | "nbformat_minor": 2 136 | } 137 | -------------------------------------------------------------------------------- /dev_notes.md: -------------------------------------------------------------------------------- 1 | # Dev Notes 2 | 3 | ## Profiling Kernels with Nsight Systems & Nsight Compute 4 | 5 | ### Nsight Systems 6 | 7 | Documentation: 8 | 9 | Command: 10 | 11 | ```bash 12 | PYTHONPATH=. nsys profile -t cuda,osrt,nvtx,cudnn,cublas -w true -o ./nvidia_nsight/nsys_mlstm_xlchunksize python scripts/run_training_kernel_benchmarks_with_profile.py 13 | ``` 14 | 15 | ### Nsight Compute 16 | 17 | Documentation: 18 | 19 | Command: 20 | 21 | ```bash 22 | PYTHONPATH=. ncu -o kernel_prof -f -c 1 -k mlstm_chunkwise__parallel_fw_Hintra_kernel --set=full python ./scripts/run_training_kernel_benchmarks_with_profile.py 23 | ``` 24 | 25 | ## Running kernel benchmarks with baselines 26 | 27 | To run the benchmarks including all baselines, you have to install: 28 | ```bash 29 | pip install mamba_ssm causal_conv1d fla 30 | ``` 31 | For `FlashAttention3`, you have to clone the original repo `https://github.com/Dao-AILab/flash-attention`: 32 | ```bash 33 | # clone FlashAttention and change into the cloned repo 34 | cd .. 35 | git clone https://github.com/Dao-AILab/flash-attention && cd flash-attention 36 | # Apply CONDA ENV patch 37 | git apply ../mlstm_kernels/flash_attention.patch 38 | # Install flash attention 3 39 | cd hopper 40 | PYTHONPATH=. python3 setup.py install 41 | cd .. 42 | # Install regular flash attention 2 43 | python3 -m pip install -e .
44 | # Go back to this repo 45 | cd ../mlstm_kernels 46 | ``` 47 | -------------------------------------------------------------------------------- /envs/environment_pt240cu121.yaml: -------------------------------------------------------------------------------- 1 | name: xlstmpt240cu121 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults # exclude this to work on leonardo 7 | dependencies: 8 | - cuda=12.1 9 | - cuda-nvcc=12.1 10 | - gxx_linux-64=11.2.0 11 | - python=3.11 12 | - pip 13 | - pytorch=2.4.0 14 | - pytorch-cuda=12.1 15 | - torchvision 16 | - cmake 17 | - ninja 18 | - cuda-toolkit=12.1 19 | - cuda-cccl=12.1 20 | - pip: 21 | - einops 22 | - pre-commit #==3.6.0 23 | - ipykernel #==6.29.0 24 | - dacite #==1.8.1 25 | - omegaconf #==2.3.0 26 | - torchmetrics #==1.3.0 27 | - tqdm #==4.66.1 28 | - pytest #==8.0.0 29 | - pytest-xdist #==3.5.0 30 | - numpy<2.0 #==1.26.4 31 | - matplotlib 32 | - pandas 33 | -------------------------------------------------------------------------------- /envs/environment_pt240cu124.yaml: -------------------------------------------------------------------------------- 1 | name: xlstmpt240cu124 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults # exclude this to work on leonardo 7 | dependencies: 8 | - cuda=12.4 9 | - cuda-nvcc=12.4 10 | - gxx_linux-64=13.2.0 11 | - python=3.11 12 | - pip 13 | - pytorch=2.4.0 14 | - pytorch-cuda=12.4 15 | - torchvision 16 | - cmake 17 | - ninja 18 | - cuda-toolkit=12.4 19 | - cuda-cccl=12.4 20 | - pip: 21 | - einops #==0.7.0 22 | - opt_einsum #==3.3.0 23 | - transformers #==4.37.2 24 | - datasets #==2.16.1 25 | - pre-commit #==3.6.0 26 | - reportlab #==4.0.9 27 | - lm-eval #==0.4.0 28 | - joypy #==0.2.6 29 | - ipykernel #==6.29.0 30 | - dacite #==1.8.1 31 | - ftfy #==6.1.3 32 | - ninja #==1.11.1.1 33 | - huggingface-hub #==0.20.3 34 | - joblib #==1.3.2 35 | - lightning #==2.1.3 36 | - rich #==13.7.0 37 | - lm-dataformat #==0.0.20 38 | - omegaconf #==2.3.0 39 | - sentencepiece #==0.1.99 40 | - tokenizers #==0.15.1 41 | - torchmetrics #==1.3.0 42 | - tqdm #==4.66.1 43 | - wandb #==0.16.2 44 | - seaborn #==0.13.2 45 | - pytest #==8.0.0 46 | - pytest-xdist #==3.5.0 47 | - openpyxl #==3.1.2 48 | - gitpython #==3.1.41 49 | - scipy #==1.12.0 50 | - pykeops #==2.2.1 51 | - hydra-core #==1.3.2 52 | - torchtext #==0.16.2 53 | - tensorboard #==2.15.1 54 | - cryptography #==42.0.2 55 | - tensorflow 56 | - numpy #==1.26.4 57 | 58 | 59 | # Please install by hand 60 | # - packaging 61 | # - mamba-ssm #==1.0.1 62 | # - causal-conv1d #==1.1.1 63 | # - git+https://github.com/idiap/fast-transformers@master 64 | # - flash-attn 65 | -------------------------------------------------------------------------------- /envs/environment_pt251cu124.yaml: -------------------------------------------------------------------------------- 1 | name: xlstmpt251cu124 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults # exclude this to work on leonardo 7 | dependencies: 8 | - cuda=12.4 9 | - cuda-nvcc=12.4 10 | - gxx_linux-64=13.2.0 11 | - python=3.11 12 | - pip 13 | - pytorch=2.5.1 14 | - pytorch-cuda=12.4 15 | - torchvision 16 | - cmake 17 | - ninja 18 | - cuda-toolkit=12.4 19 | - cuda-cccl=12.4 20 | - pip: 21 | - einops #==0.7.0 22 | - opt_einsum #==3.3.0 23 | - transformers #==4.37.2 # replace this with xlstm transformers 24 | - datasets #==2.16.1 25 | - pre-commit #==3.6.0 26 | - reportlab #==4.0.9 27 | - lm-eval #==0.4.0 28 | - joypy #==0.2.6 29 | - ipykernel #==6.29.0 30 | - dacite 
#==1.8.1 31 | - ftfy #==6.1.3 32 | - ninja #==1.11.1.1 33 | - huggingface-hub #==0.20.3 34 | - joblib #==1.3.2 35 | - lightning #==2.1.3 36 | - rich #==13.7.0 37 | - lm-dataformat #==0.0.20 38 | - omegaconf #==2.3.0 39 | - sentencepiece #==0.1.99 40 | - tokenizers #==0.15.1 41 | - torchmetrics #==1.3.0 42 | - tqdm #==4.66.1 43 | - wandb #==0.16.2 44 | - seaborn #==0.13.2 45 | - pytest #==8.0.0 46 | - pytest-xdist #==3.5.0 47 | - openpyxl #==3.1.2 48 | - gitpython #==3.1.41 49 | - scipy #==1.12.0 50 | - pykeops #==2.2.1 51 | - hydra-core #==1.3.2 52 | - torchtext #==0.16.2 53 | - tensorboard #==2.15.1 54 | - cryptography #==42.0.2 55 | - tensorflow 56 | - numpy #==1.26.4 57 | - mamba-ssm 58 | - causal_conv1d 59 | - git+https://github.com/sustcsonglin/flash-linear-attention 60 | 61 | 62 | # Please install by hand 63 | # - packaging 64 | # - mamba-ssm #==1.0.1 65 | # - causal-conv1d #==1.1.1 66 | # - git+https://github.com/idiap/fast-transformers@master 67 | # - flash-attn 68 | -------------------------------------------------------------------------------- /flash_attention.patch: -------------------------------------------------------------------------------- 1 | diff --git a/hopper/setup.py b/hopper/setup.py 2 | index f9f3cfd..9098993 100644 3 | --- a/hopper/setup.py 4 | +++ b/hopper/setup.py 5 | @@ -220,7 +220,8 @@ if not SKIP_CUDA_BUILD: 6 | }, 7 | include_dirs=include_dirs, 8 | # Without this we get and error about cuTensorMapEncodeTiled not defined 9 | - libraries=["cuda"] 10 | + libraries=["cuda"], 11 | + extra_link_args=["-L"+os.environ["CONDA_PREFIX"]+"/lib/stubs"] if "CONDA_PREFIX" in os.environ else [], 12 | ) 13 | ) 14 | # ext_modules.append( 15 | -------------------------------------------------------------------------------- /mlstm_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | __version__ = "2.0.0" 5 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | from .flash_attention_triton import attention_causal as attention_causal_triton_flash 5 | from .torch_sdp_attention import ( 6 | attention_causal_pt_cudnn as attention_causal_torch_cudnn, 7 | attention_causal_pt_efficient as attention_causal_torch_efficient, 8 | attention_causal_pt_fa2 as attention_causal_torch_flash, 9 | attention_causal_pt_math as attention_causal_torch_math, 10 | ) 11 | from .triton_tutorial import attention_causal as attention_causal_triton_tutorial 12 | 13 | registry = { 14 | "torch_flash": attention_causal_torch_flash, 15 | "torch_cudnn": attention_causal_torch_cudnn, 16 | "torch_math": attention_causal_torch_math, 17 | "torch_efficient": attention_causal_torch_efficient, 18 | "triton_flash": attention_causal_triton_flash, 19 | "triton_tutorial": attention_causal_triton_tutorial, 20 | } 21 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_attention/torch_sdp_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import torch 5 | from torch.nn.attention import SDPBackend, sdpa_kernel 6 | from torch.nn.functional import scaled_dot_product_attention 7 | 8 | 9 | def attention_causal_pt_fa2( 10 | query: torch.Tensor, 11 | key: torch.Tensor, 12 | value: torch.Tensor, 13 | scale: float = None, 14 | ) -> torch.Tensor: 15 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION): 16 | return scaled_dot_product_attention(query, key, value, is_causal=True, scale=scale) 17 | 18 | 19 | def attention_causal_pt_cudnn( 20 | query: torch.Tensor, 21 | key: torch.Tensor, 22 | value: torch.Tensor, 23 | scale: float = None, 24 | ) -> torch.Tensor: 25 | with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): 26 | return scaled_dot_product_attention(query, key, value, is_causal=True, scale=scale) 27 | 28 | 29 | def attention_causal_pt_math( 30 | query: torch.Tensor, 31 | key: torch.Tensor, 32 | value: torch.Tensor, 33 | scale: float = None, 34 | ) -> torch.Tensor: 35 | with sdpa_kernel(SDPBackend.MATH): 36 | return scaled_dot_product_attention(query, key, value, is_causal=True, scale=scale) 37 | 38 | 39 | def attention_causal_pt_efficient( 40 | query: torch.Tensor, 41 | key: torch.Tensor, 42 | value: torch.Tensor, 43 | scale: float = None, 44 | ) -> torch.Tensor: 45 | with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): 46 | return scaled_dot_product_attention(query, key, value, is_causal=True, scale=scale) 47 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_linear_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
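A minimal usage sketch for the attention baseline registry defined in `flash_attention/__init__.py` above; the toy shapes, the bfloat16 dtype, and the CUDA device are assumptions:

```python
import torch

from mlstm_kernels.baselines.flash_attention import registry

# assumed toy shapes (B, NH, S, DH) on a CUDA device
q = torch.randn(1, 4, 256, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# pick the PyTorch flash-attention backend from the registry
attention_fn = registry["torch_flash"]
out = attention_fn(q, k, v)  # -> (B, NH, S, DH), causal attention
print(out.shape)
```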
3 | 4 | from .gla import ( 5 | chunk_gla as gla_triton, 6 | fused_chunk_gla as fused_gla_triton, 7 | fused_recurrent_gla as fused_recurrent_gla_triton, 8 | ) 9 | from .simple_gla import chunk_simple_gla as simple_gla_triton 10 | 11 | registry = { 12 | "triton_simple_gla": simple_gla_triton, 13 | "triton_gla": gla_triton, 14 | "triton_fused_gla": fused_gla_triton, 15 | "triton_fused_recurrent_gla": fused_recurrent_gla_triton, 16 | } 17 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_linear_attention/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # This module is copied from https://github.com/sustcsonglin/flash-linear-attention/blob/fee90b2e72366a46c60e3ef16431133aa5aced8d/fla/ops/gla 2 | # Adapted to make it work in this codebase 3 | 4 | from .chunk import chunk_gla 5 | from .chunk_fuse import fused_chunk_gla 6 | from .recurrent_fuse import fused_recurrent_gla 7 | 8 | __all__ = ["chunk_gla", "fused_chunk_gla", "fused_recurrent_gla"] 9 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_linear_attention/simple_gla/README.md: -------------------------------------------------------------------------------- 1 | - Simple GLA 2 | 3 | Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. 4 | 5 | $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. 6 | 7 | Copy from https://github.com/sustcsonglin/flash-linear-attention/blob/245a3acdc1e02ca24212b6b16a122c917692b3b4/fla/ops/simple_gla 8 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_linear_attention/simple_gla/__init__.py: -------------------------------------------------------------------------------- 1 | # This module is copied from https://github.com/sustcsonglin/flash-linear-attention/blob/fee90b2e72366a46c60e3ef16431133aa5aced8d/fla/ops/simple_gla 2 | # Adapted to make it work in this codebase 3 | 4 | 5 | from .chunk import chunk_simple_gla 6 | 7 | __all__ = ["chunk_simple_gla"] 8 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/flash_linear_attention/simple_gla/naive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
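Unrolling the recurrence from the `simple_gla` README above (with $S_0 = 0$, and up to the $1/\sqrt{d_k}$ scaling of $q$) gives the closed form that both the chunkwise kernel and the naive reference below implement, shown here as a sketch in the README's notation:

$$S_t = \sum_{s=1}^{t} \Big(\prod_{r=s+1}^{t} g_r\Big) K_s V_s^{\top}, \qquad o_t = S_t^{\top} q_t .$$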
3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | 8 | def torch_simple_gla(q, k, v, g, chunk_size=64): 9 | q = rearrange(q, "b h (n c) d -> b h n c d", c=chunk_size) * (q.shape[-1] ** -0.5) 10 | k = rearrange(k, "b h (n c) d -> b h n c d", c=chunk_size) 11 | v = rearrange(v, "b h (n c) d -> b h n c d", c=chunk_size) 12 | g = rearrange(g, "b h (n c) -> b h n c", c=chunk_size) 13 | g = g.cumsum(-1) 14 | kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) 15 | S = torch.zeros_like(kv) 16 | 17 | for i in range(1, g.shape[-2]): 18 | S[:, :, i] = ( 19 | S[:, :, i - 1].clone() * g[:, :, i - 1, -1, None, None].exp() 20 | + kv[:, :, i - 1] 21 | ) 22 | 23 | inter = (q * g[..., None].exp()) @ S 24 | attn = q @ k.transpose(-1, -2) 25 | attn = attn * (g[..., None] - g[..., None, :]).exp() 26 | attn = attn.masked_fill( 27 | torch.triu( 28 | torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1 29 | ), 30 | 0, 31 | ) 32 | intra = attn @ v 33 | o = inter + intra 34 | return rearrange(o, "b h n c d -> b h (n c) d") 35 | 36 | 37 | def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64): 38 | # q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 39 | # k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 40 | # v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 41 | # g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 42 | # g = g.cumsum(-1) 43 | # kv = k.transpose(-1, -2) @ v 44 | 45 | B, H, T, DK = q.shape 46 | q = q * (DK**-0.5) 47 | _, _, _, DV = v.shape 48 | S = torch.zeros(B, H, DK, DV).to(q) 49 | o = torch.zeros(B, H, T, DV).to(q) 50 | for i in range(T): 51 | gate = g[:, :, i].exp() 52 | key = k[:, :, i] 53 | value = v[:, :, i] 54 | kv = key.unsqueeze(-1) * value.unsqueeze(-2) 55 | S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv 56 | q_i = q[:, :, i, :] 57 | o_i = (q_i.unsqueeze(-1) * S).sum(-2) 58 | o[:, :, i] = o_i 59 | 60 | return o 61 | -------------------------------------------------------------------------------- /mlstm_kernels/baselines/lightning_attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/mlstm_kernels/baselines/lightning_attention/__init__.py -------------------------------------------------------------------------------- /mlstm_kernels/baselines/lightning_attention/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def _build_slope_tensor(n_attention_heads: int): 7 | def get_slopes(n): 8 | def get_slopes_power_of_2(n): 9 | start = 2 ** (-(2 ** -(math.log2(n) - 3))) 10 | ratio = start 11 | return [start * ratio**i for i in range(n)] 12 | 13 | if math.log2(n).is_integer(): 14 | return get_slopes_power_of_2( 15 | n 16 | ) # In the paper, we only train models that have 2^a heads for some a. This function has 17 | else: # some good properties that only occur when the input is a power of 2. To maintain that even 18 | closest_power_of_2 = 2 ** math.floor( 19 | math.log2(n) 20 | ) # when the number of heads is not a power of 2, we use this workaround. 
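# For n = 2**a heads, get_slopes returns the ALiBi slope set 2^(-8/n), 2^(-16/n), ..., 2^(-8)
# (Press et al., 2021); lightning attention reuses these slopes as per-head decay rates.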
21 | return ( 22 | get_slopes_power_of_2(closest_power_of_2) 23 | + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] 24 | ) 25 | 26 | # h, 1, 1 27 | slopes = torch.tensor(get_slopes(n_attention_heads)).reshape( 28 | n_attention_heads, 1, 1 29 | ) 30 | 31 | return slopes 32 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from collections.abc import Callable 5 | 6 | 7 | def _create_module_sequence_backend_registry() -> dict[str, dict[str, Callable]]: 8 | from .chunkwise import registry as mlstm_chunkwise_registry 9 | from .parallel import registry as mlstm_parallel_registry 10 | 11 | module_backend_registry = { 12 | "chunkwise": mlstm_chunkwise_registry, 13 | "parallel": mlstm_parallel_registry, 14 | } 15 | return module_backend_registry 16 | 17 | 18 | def get_available_mlstm_kernels() -> list[str]: 19 | """ 20 | Get a list of available mlstm sequence kernel names. 21 | """ 22 | module_backend_registry = _create_module_sequence_backend_registry() 23 | 24 | backend_names = [ 25 | f"{module_key}--{kernel_key}" 26 | for module_key in module_backend_registry.keys() 27 | for kernel_key in module_backend_registry[module_key].keys() 28 | ] 29 | return backend_names 30 | 31 | 32 | def get_available_mlstm_step_kernels() -> list[str]: 33 | from .recurrent import registry_step as mlstm_recurrent_step_registry 34 | 35 | backend_names = list(mlstm_recurrent_step_registry.keys()) 36 | return backend_names 37 | 38 | 39 | def get_mlstm_kernel(name: str) -> Callable: 40 | """ 41 | Get a mlstm sequence kernel function by name. 42 | 43 | Naming convention: 44 | name = "<module_name>--<backend_name>" 45 | 46 | module_name: The name of the module containing the kernel function. 47 | Example: "chunkwise", "parallel", "recurrent" 48 | 49 | backend_name: The name of the kernel function as defined in the registry in the __init__.py file of the module. 50 | """ 51 | 52 | module_backend_registry = _create_module_sequence_backend_registry() 53 | 54 | module_name, backend_name = name.split("--") 55 | 56 | if module_name not in module_backend_registry: 57 | raise ValueError( 58 | f"Unknown module name: {module_name}. Available module names: {list(module_backend_registry.keys())}" 59 | ) 60 | 61 | if backend_name not in module_backend_registry[module_name]: 62 | raise ValueError( 63 | f"Unknown backend name: {backend_name}. Available backend names: {list(module_backend_registry[module_name].keys())}" 64 | ) 65 | 66 | return module_backend_registry[module_name][backend_name] 67 | 68 | 69 | def get_mlstm_step_kernel(name: str) -> Callable: 70 | """ 71 | Get a mlstm step kernel function by name. 72 | 73 | Naming convention: 74 | name = "<backend_name>" 75 | 76 | backend_name: The name of the kernel function as defined in the registry in the __init__.py file of the module. 77 | """ 78 | from .recurrent import registry_step as mlstm_recurrent_step_registry 79 | 80 | if name not in mlstm_recurrent_step_registry: 81 | raise ValueError( 82 | f"Unknown backend name: {name}. 
Available backend names: {list(mlstm_recurrent_step_registry.keys())}" 83 | ) 84 | 85 | return mlstm_recurrent_step_registry[name] 86 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .native import mlstm_chunkwise__native_autograd 5 | from .triton_limit_chunk import mlstm_chunkwise__limit_chunk 6 | from .triton_xl_chunk import mlstm_chunkwise__xl_chunk 7 | from .triton_xl_chunk_siging import mlstm_siging_chunkwise__xl_chunk 8 | 9 | registry = { 10 | "native_autograd": mlstm_chunkwise__native_autograd, 11 | "triton_limit_chunk": mlstm_chunkwise__limit_chunk, 12 | "triton_xl_chunk": mlstm_chunkwise__xl_chunk, 13 | "triton_xl_chunk_siging": mlstm_siging_chunkwise__xl_chunk, 14 | } 15 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/native/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_chunkwise__native_autograd 5 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/native/fwbw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import jax 5 | 6 | from .fw import mlstm_chunkwise_fw 7 | 8 | 9 | def mlstm_chunkwise__native_autograd( 10 | q: jax.Array, 11 | k: jax.Array, 12 | v: jax.Array, 13 | i: jax.Array, 14 | f: jax.Array, 15 | c_initial: jax.Array = None, 16 | n_initial: jax.Array = None, 17 | m_initial: jax.Array = None, 18 | return_last_states: bool = False, 19 | eps: float = 1e-6, 20 | chunk_size: int = 64, 21 | **kwargs, 22 | ) -> jax.Array | tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]: 23 | matH_out, _, _, last_states, _ = mlstm_chunkwise_fw( 24 | matQ=q, 25 | matK=k, 26 | matV=v, 27 | vecI=i, 28 | vecF=f, 29 | matC_initial=c_initial, 30 | vecN_initial=n_initial, 31 | scaM_initial=m_initial, 32 | return_last_states=return_last_states, 33 | return_all_states=False, 34 | eps=eps, 35 | chunk_size=chunk_size, 36 | ) 37 | if return_last_states: 38 | return matH_out, last_states 39 | else: 40 | return matH_out 41 | 42 | 43 | # TODO bring this into jax 44 | # def mlstm_chunkwise_custbw( 45 | # q: jax.Array, 46 | # k: jax.Array, 47 | # v: jax.Array, 48 | # i: jax.Array, 49 | # f: jax.Array, 50 | # c_initial: jax.Array = None, 51 | # n_initial: jax.Array = None, 52 | # m_initial: jax.Array = None, 53 | # return_last_states: bool = False, 54 | # eps: float = 1e-6, 55 | # chunk_size: int = 64, 56 | # autocast_kernel_dtype: torch.dtype = torch.float32, 57 | # ) -> jax.Array | tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]: 58 | # _mlstm_chunkwise_fwbw = _get_chunkwise_fwbw_kernel(autocast_kernel_dtype) 59 | # matH_out, matC_last, vecN_last, scaM_last = _mlstm_chunkwise_fwbw.apply( 60 | # q, 61 | # k, 62 | # v, 63 | # i, 64 | # f, 65 | # c_initial, 66 | # n_initial, 67 | # m_initial, 68 | # None, 69 | # return_last_states, 70 | # True, 71 | # 
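A usage sketch for the registry helpers in `mlstm_kernels/jax/__init__.py` above; the kernel name follows the `"<module_name>--<backend_name>"` convention from the docstring, and the toy shapes are assumptions:

```python
import jax
import jax.numpy as jnp

from mlstm_kernels.jax import get_mlstm_kernel

# resolve a sequence kernel via the "<module_name>--<backend_name>" convention
kernel = get_mlstm_kernel("chunkwise--native_autograd")

B, NH, S, DHQK, DHHV = 1, 2, 128, 32, 32  # assumed toy shapes
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, NH, S, DHQK))
k = jax.random.normal(kk, (B, NH, S, DHQK))
v = jax.random.normal(kv, (B, NH, S, DHHV))
i = jnp.zeros((B, NH, S))       # input gate preactivations
f = jnp.full((B, NH, S), 3.0)   # forget gate preactivations

h = kernel(q, k, v, i, f, chunk_size=64)  # -> (B, NH, S, DHHV)
```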
chunk_size, 72 | # eps, 73 | # ) 74 | # if return_last_states: 75 | # return matH_out, (matC_last, vecN_last, scaM_last) 76 | # else: 77 | # return matH_out 78 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/triton_limit_chunk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_chunkwise__limit_chunk 5 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/triton_xl_chunk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_chunkwise__xl_chunk 5 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/triton_xl_chunk/chunkwise_gates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | """In this file we compute the chunkwise or cumulative gates (i.e. vecA and vecB) 5 | for the forward and backward pass of the mLSTM. 6 | We use the stable formulations, i.e. we avoid subtraction of forget gates. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | 12 | 13 | def compute_chunkwise_log_gates_vecB_vecA( 14 | vecI: jax.Array, # (B, NH, S) 15 | vecF: jax.Array, # (B, NH, S) 16 | chunk_size: int, 17 | ): 18 | B, NH, S = vecI.shape 19 | assert S % chunk_size == 0, f"S={S} is not divisible by chunk_size={chunk_size}" 20 | NC = S // chunk_size 21 | L = chunk_size 22 | 23 | # compute vecB 24 | vecF_logsig: jax.Array = jax.nn.log_sigmoid(vecF.astype(jnp.float32)) 25 | vecF_logsig_chunked = vecF_logsig.reshape(B, NH, NC, L) 26 | vecB = vecF_logsig_chunked.cumsum(axis=-1) 27 | 28 | # compute vecA 29 | vecI_chunked = vecI.reshape(B, NH, NC, L) 30 | # unstable vecA computation: 31 | # vecA = (vecB[..., -1, None] - vecB) + vecI # (B, NH, NC, L) 32 | # stable vecA computation: 33 | vecF_cumsum = jnp.flip( 34 | jnp.flip(vecF_logsig_chunked[..., 1:], axis=-1).cumsum(-1), axis=-1 35 | ) 36 | vecA = ( 37 | jnp.concat( 38 | [ 39 | vecF_cumsum, 40 | jnp.zeros((B, NH, NC, 1), dtype=jnp.float32), 41 | ], 42 | axis=-1, 43 | ) 44 | + vecI_chunked 45 | ) # (B, NH, NC, L) 46 | return vecB, vecA 47 | 48 | 49 | def compute_chunkwise_log_gates_vecB( 50 | vecF: jax.Array, # (B, NH, S) 51 | chunk_size: int, 52 | ): 53 | B, NH, S = vecF.shape 54 | assert S % chunk_size == 0, f"S={S} is not divisible by chunk_size={chunk_size}" 55 | NC = S // chunk_size 56 | L = chunk_size 57 | 58 | # compute vecB 59 | vecF_logsig: jax.Array = jax.nn.log_sigmoid(vecF.astype(jnp.float32)) 60 | vecF_logsig_chunked = vecF_logsig.reshape(B, NH, NC, L) 61 | vecB = vecF_logsig_chunked.cumsum(axis=-1) 62 | 63 | return vecB 64 | 65 | 66 | def compute_gate_grads_vecDeltaI_vecDeltaF( 67 | matQ: jax.Array, 68 | matK: jax.Array, 69 | matDeltaQ: jax.Array, 70 | matDeltaK: jax.Array, 71 | vecF: jax.Array, 72 | ) -> tuple[jax.Array, jax.Array]: 73 | # postprocessing: compute deltaF and deltaI gradients 74 | # vecF = rearrange(vecF, "b nh nc l -> b nh (nc l)") 75 | # compute 
the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 76 | matQ = matQ.astype(jnp.float32) 77 | matK = matK.astype(jnp.float32) 78 | matDeltaQ = matDeltaQ.astype(jnp.float32) 79 | matDeltaK = matDeltaK.astype(jnp.float32) 80 | vecDeltaFbar_acc = ((matQ * matDeltaQ) - (matK * matDeltaK)).sum(-1) 81 | # 82 | # vecDeltaFbar = jnp.flip(jnp.cumsum(jnp.flip(vecDeltaFbar_acc, axis=-1).astype(jnp.float32), axis=-1), axis=-1) 83 | # vecDeltaF = vecDeltaFbar * jax.nn.sigmoid(-vecF) 84 | # align with limit_chunk kernel: 85 | vecDeltaFbar = jnp.flip(vecDeltaFbar_acc, axis=-1).astype(jnp.float32) 86 | vecDeltaFbar = jnp.flip(vecDeltaFbar.cumsum(axis=-1), axis=-1) 87 | vecDeltaF = vecDeltaFbar * jax.nn.sigmoid(-vecF) 88 | # compute deltaI 89 | # both are equivalent: 90 | # vecDeltaI = (matV * matDeltaV).sum(-1) 91 | vecDeltaI = (matK * matDeltaK).sum(-1) 92 | 93 | return vecDeltaI, vecDeltaF 94 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/triton_xl_chunk_siging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_siging_chunkwise__xl_chunk 5 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/chunkwise/triton_xl_chunk_siging/chunkwise_gates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | """In this file we compute the chunkwise or cumulative gates (i.e. vecA and vecB) 5 | for the forward and backward pass of the mLSTM. 6 | We use the stable formulations, i.e. we avoid subtraction of forget gates. 
7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | 12 | 13 | def compute_chunkwise_log_gates_vecB_vecA( 14 | vecI: jax.Array, # (B, NH, S) 15 | vecF: jax.Array, # (B, NH, S) 16 | chunk_size: int, 17 | ) -> tuple[jax.Array, jax.Array]: 18 | B, NH, S = vecI.shape 19 | assert S % chunk_size == 0, f"S={S} is not divisible by chunk_size={chunk_size}" 20 | NC = S // chunk_size 21 | L = chunk_size 22 | 23 | # compute vecB 24 | vecF_logsig: jax.Array = jax.nn.log_sigmoid(vecF.astype(jnp.float32)) 25 | vecF_logsig_chunked = vecF_logsig.reshape(B, NH, NC, L) 26 | vecB = vecF_logsig_chunked.cumsum(axis=-1) 27 | 28 | vecI_logsig = jax.nn.log_sigmoid(vecI.astype(jnp.float32)) 29 | 30 | # compute vecA 31 | vecI_logsig_chunked = vecI_logsig.reshape(B, NH, NC, L) 32 | # unstable vecA computation: 33 | # vecA = (vecB[..., -1, None] - vecB) + vecI # (B, NH, NC, L) 34 | # stable vecA computation: 35 | vecF_cumsum = jnp.flip( 36 | jnp.flip(vecF_logsig_chunked[..., 1:], axis=-1).cumsum(-1), axis=-1 37 | ) 38 | vecA = ( 39 | jnp.concat( 40 | [ 41 | vecF_cumsum, 42 | jnp.zeros((B, NH, NC, 1), dtype=jnp.float32), 43 | ], 44 | axis=-1, 45 | ) 46 | + vecI_logsig_chunked 47 | ) # (B, NH, NC, L) 48 | return vecB, vecA 49 | 50 | 51 | def compute_chunkwise_log_gates_vecB( 52 | vecF: jax.Array, # (B, NH, S) 53 | chunk_size: int, 54 | ) -> jax.Array: 55 | B, NH, S = vecF.shape 56 | assert S % chunk_size == 0, f"S={S} is not divisible by chunk_size={chunk_size}" 57 | NC = S // chunk_size 58 | L = chunk_size 59 | 60 | # compute vecB 61 | vecF_logsig: jax.Array = jax.nn.log_sigmoid(vecF.astype(jnp.float32)) 62 | vecF_logsig_chunked = vecF_logsig.reshape(B, NH, NC, L) 63 | vecB = vecF_logsig_chunked.cumsum(axis=-1) 64 | return vecB 65 | 66 | 67 | def compute_gate_grads_vecDeltaI_vecDeltaF( 68 | matQ: jax.Array, 69 | matK: jax.Array, 70 | matDeltaQ: jax.Array, 71 | matDeltaK: jax.Array, 72 | vecI: jax.Array, 73 | vecF: jax.Array, 74 | ) -> tuple[jax.Array, jax.Array]: 75 | """Compute the gradients of the input and forget gates.""" 76 | 77 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 78 | matQ = matQ.astype(jnp.float32) 79 | matK = matK.astype(jnp.float32) 80 | matDeltaQ = matDeltaQ.astype(jnp.float32) 81 | matDeltaK = matDeltaK.astype(jnp.float32) 82 | vecDeltaFbar_acc = ((matQ * matDeltaQ) - (matK * matDeltaK)).sum(-1) 83 | vecDeltaFbar = jnp.flip( 84 | jnp.cumsum(jnp.flip(vecDeltaFbar_acc, axis=-1).astype(jnp.float32), axis=-1), 85 | axis=-1, 86 | ) 87 | vecDeltaF = vecDeltaFbar * jax.nn.sigmoid(-vecF) 88 | 89 | # compute deltaI 90 | # both are equivalent: 91 | # vecDeltaIbar = (matV * matDeltaV).sum(-1) 92 | vecDeltaIbar = (matK * matDeltaK).sum(-1) 93 | vecDeltaI = vecDeltaIbar * jax.nn.sigmoid(-vecI) 94 | return vecDeltaI, vecDeltaF 95 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | from .native import mlstm_parallel__native_autograd, mlstm_parallel__native_custbw 5 | from .native_siging import ( 6 | mlstm_siging_parallel__native_autograd, 7 | mlstm_siging_parallel__native_custbw, 8 | ) 9 | from .native_stablef import ( 10 | mlstm_parallel__native_stablef_autograd, 11 | mlstm_parallel__native_stablef_custbw, 12 | ) 13 | 14 | registry = { 15 | "native_autograd": mlstm_parallel__native_autograd, 16 | "native_custbw": mlstm_parallel__native_custbw, 17 | "native_stablef_autograd": mlstm_parallel__native_stablef_autograd, 18 | "native_stablef_custbw": mlstm_parallel__native_stablef_custbw, 19 | "native_siging_autograd": mlstm_siging_parallel__native_autograd, 20 | "native_siging_custbw": mlstm_siging_parallel__native_custbw, 21 | } 22 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_parallel__native_autograd, mlstm_parallel__native_custbw 5 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native/bw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | def mlstm_parallel_bw( 9 | matDeltaHtilde: jax.Array, 10 | matQ: jax.Array, 11 | matK: jax.Array, 12 | matV: jax.Array, 13 | vecI: jax.Array, 14 | vecF: jax.Array, 15 | vecN: jax.Array, 16 | vecM: jax.Array, 17 | eps: float = 1e-6, 18 | ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: 19 | B, NH, S, DHQK = matQ.shape 20 | assert matK.shape == (B, NH, S, DHQK) 21 | assert vecI.shape == (B, NH, S) 22 | assert vecF.shape == (B, NH, S) 23 | 24 | vecLogSigF = jax.nn.log_sigmoid(vecF) # (B, NH, S) 25 | vecLogSigF_cumsum = jnp.cumsum(vecLogSigF, axis=-1) 26 | 27 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 28 | 29 | ltr = jnp.tril( 30 | jnp.ones( 31 | (S, S), 32 | dtype=jnp.bool_, 33 | ) 34 | ) 35 | 36 | matLogSigF_mask = jnp.where(ltr, matLogSigF, -float("inf")) 37 | 38 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 39 | 40 | matLogD_stabilized = matLogD - vecM[:, :, :, None] 41 | 42 | matD = jnp.exp(matLogD_stabilized) # (B, NH, S, S) 43 | 44 | # intermediate delta-errors 45 | matDeltaC = matDeltaHtilde @ matV.swapaxes(-2, -1) / (vecN[:, :, :, None] + eps) 46 | 47 | matS = (matQ @ matK.swapaxes(-2, -1)) * (DHQK**-0.5) 48 | 49 | matDeltaDtilde = matDeltaC * matD * matS 50 | 51 | vecDeltaI = jnp.sum(matDeltaDtilde, axis=-2) 52 | 53 | # output delta-errors / gradients 54 | matP = matDeltaC * matD 55 | 56 | matDeltaQ = (matP @ matK) * (DHQK**-0.5) 57 | matDeltaK = (matP.swapaxes(-2, -1) @ matQ) * (DHQK**-0.5) 58 | 59 | matCtilde = matS * matD 60 | matDeltaV = matCtilde.swapaxes(-2, -1) @ ( 61 | matDeltaHtilde / (vecN[:, :, :, None] + eps) 62 | ) 63 | 64 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 65 | vecDeltaFbar_acc = jnp.sum((matQ * matDeltaQ - matK * matDeltaK), axis=-1) 66 | vecDeltaFbar = jnp.flip( 67 | jnp.cumsum(jnp.flip(vecDeltaFbar_acc, axis=-1), axis=-1), axis=-1 68 | ) 69 | vecDeltaF = vecDeltaFbar 
* jax.nn.sigmoid(-vecF) 70 | 71 | return ( 72 | matDeltaQ, 73 | matDeltaK, 74 | matDeltaV, 75 | vecDeltaI, 76 | vecDeltaF, 77 | ) 78 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """ 6 | Jax. 7 | 8 | mLSTM forward and backward pass. Parallel formulation. 9 | """ 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | 14 | 15 | def mlstm_parallel_fw( 16 | matQ: jax.Array, 17 | matK: jax.Array, 18 | matV: jax.Array, 19 | vecI: jax.Array, 20 | vecF: jax.Array, 21 | eps: float = 1e-6, 22 | ) -> jax.Array: 23 | B, NH, S, DHQK = matQ.shape 24 | assert matK.shape == (B, NH, S, DHQK) 25 | assert vecI.shape == (B, NH, S) 26 | assert vecF.shape == (B, NH, S) 27 | 28 | vecLogSigF = jax.nn.log_sigmoid(vecF) # (B, NH, S) 29 | vecLogSigF_cumsum = jnp.cumsum(vecLogSigF, axis=-1) 30 | 31 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 32 | 33 | ltr = jnp.tril(jnp.ones((S, S), dtype=jnp.bool_)) 34 | 35 | matLogSigF_mask = jnp.where(ltr, matLogSigF, -float("inf")) 36 | 37 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 38 | 39 | vecM = jnp.max(matLogD, axis=-1, keepdims=True) # (B, NH, S, 1) 40 | matLogD_stabilized = matLogD - vecM 41 | 42 | matD = jnp.exp(matLogD_stabilized) # (B, NH, S, S) 43 | 44 | matS = (matQ @ matK.swapaxes(-2, -1)) * (DHQK**-0.5) # (B, NH, S, S) 45 | 46 | matCtilde = matS * matD # (B, NH, S, S) 47 | vecN = jnp.maximum( 48 | jnp.abs(jnp.sum(matCtilde, axis=-1, keepdims=True)), jnp.exp(-vecM) 49 | ) # (B, NH, S, 1) 50 | # (B, NH, S, S) 51 | matC = matCtilde / (vecN + eps) 52 | 53 | matH = matC @ matV # (B, NH, S, DH) 54 | 55 | vecN = vecN.squeeze(-1) 56 | vecM = vecM.squeeze(-1) 57 | 58 | return (matH, vecN, vecM) 59 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native_siging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import ( 5 | mlstm_siging_parallel__native_autograd, 6 | mlstm_siging_parallel__native_custbw, 7 | ) 8 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native_siging/bw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | # Maximilian Beck 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | 10 | def mlstm_siging_parallel_bw( 11 | matDeltaHtilde: jax.Array, 12 | matQ: jax.Array, 13 | matK: jax.Array, 14 | matV: jax.Array, 15 | vecI: jax.Array, 16 | vecF: jax.Array, 17 | vecN: jax.Array, 18 | eps: float = 1e-6, 19 | stable_fgate: bool = True, 20 | normalize: bool = True, 21 | ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: 22 | B, NH, S, DHQK = matQ.shape 23 | assert matK.shape == (B, NH, S, DHQK) 24 | assert vecI.shape == (B, NH, S) 25 | assert vecF.shape == (B, NH, S) 26 | 27 | vecLogSigF = jax.nn.log_sigmoid(vecF) # (B, NH, S) 28 | 29 | if stable_fgate: 30 | matLogSigF_tril = jnp.tril(vecLogSigF[:, :, :, None].repeat(S, axis=-1), k=-1) 31 | matLogSigF = jnp.cumsum(matLogSigF_tril, axis=-2) 32 | else: 33 | vecLogSigF_cumsum = jnp.cumsum(vecLogSigF, axis=-1) 34 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 35 | 36 | ltr = jnp.tril(jnp.ones((S, S), dtype=jnp.bool_)) 37 | 38 | matLogSigF_mask = jnp.where(ltr, matLogSigF, -float("inf")) 39 | 40 | vecLogSigI = jax.nn.log_sigmoid(vecI) 41 | 42 | matLogD = matLogSigF_mask + vecLogSigI[:, :, None, :] 43 | 44 | matD = jnp.exp(matLogD) # (B, NH, S, S) 45 | 46 | # intermediate delta-errors 47 | if normalize: 48 | matDeltaC = matDeltaHtilde @ matV.swapaxes(-2, -1) / (vecN[:, :, :, None] + eps) 49 | else: 50 | matDeltaC = matDeltaHtilde @ matV.swapaxes(-2, -1) 51 | 52 | matS = (matQ @ matK.swapaxes(-2, -1)) * (DHQK**-0.5) 53 | 54 | matDeltaDtilde = matDeltaC * matD * matS 55 | 56 | vecDeltaIbar = jnp.sum(matDeltaDtilde, axis=-2) 57 | 58 | # output delta-errors / gradients 59 | matP = matDeltaC * matD 60 | 61 | matDeltaQ = (matP @ matK) * (DHQK**-0.5) 62 | matDeltaK = (matP.swapaxes(-2, -1) @ matQ) * (DHQK**-0.5) 63 | 64 | matCtilde: jax.Array = matS * matD 65 | 66 | if normalize: 67 | matDeltaV = matCtilde.swapaxes(-2, -1) @ ( 68 | matDeltaHtilde / (vecN[:, :, :, None] + eps) 69 | ) 70 | else: 71 | matDeltaV = matCtilde.swapaxes(-2, -1) @ matDeltaHtilde 72 | 73 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 74 | vecDeltaFbar_acc = jnp.sum((matQ * matDeltaQ - matK * matDeltaK), axis=-1) 75 | vecDeltaFbar = jnp.flip( 76 | jnp.cumsum(jnp.flip(vecDeltaFbar_acc, axis=-1), axis=-1), axis=-1 77 | ) 78 | vecDeltaF = vecDeltaFbar * jax.nn.sigmoid(-vecF) 79 | 80 | vecDeltaI = vecDeltaIbar * jax.nn.sigmoid(-vecI) 81 | 82 | return ( 83 | matDeltaQ, 84 | matDeltaK, 85 | matDeltaV, 86 | vecDeltaI, 87 | vecDeltaF, 88 | ) 89 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native_siging/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """ 6 | Jax. 7 | 8 | mLSTM sigmoid input gate forward pass. Parallel formulation. 
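In formulas: H = ((Q K^T / sqrt(d_qk)) * D) V with D[t, s] = sigmoid(i_s) * prod_{k=s+1}^{t} sigmoid(f_k) for s <= t and 0 otherwise, computed in log space; with normalize=True each row is divided by max(|row sum of the gated scores|, 1) + eps.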
9 | """ 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | 14 | 15 | def mlstm_siging_parallel_fw( 16 | matQ: jax.Array, 17 | matK: jax.Array, 18 | matV: jax.Array, 19 | vecI: jax.Array, 20 | vecF: jax.Array, 21 | eps: float = 1e-6, 22 | stable_fgate: bool = True, 23 | normalize: bool = True, 24 | ) -> jax.Array: 25 | B, NH, S, DHQK = matQ.shape 26 | assert matK.shape == (B, NH, S, DHQK) 27 | assert vecI.shape == (B, NH, S) 28 | assert vecF.shape == (B, NH, S) 29 | 30 | vecLogSigF = jax.nn.log_sigmoid(vecF) # (B, NH, S) 31 | 32 | if stable_fgate: 33 | matLogSigF_tril = jnp.tril(vecLogSigF[:, :, :, None].repeat(S, axis=-1), k=-1) 34 | matLogSigF = jnp.cumsum(matLogSigF_tril, axis=-2) 35 | else: 36 | vecLogSigF_cumsum = jnp.cumsum(vecLogSigF, axis=-1) 37 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 38 | 39 | ltr = jnp.tril(jnp.ones((S, S), dtype=jnp.bool_)) 40 | 41 | matLogSigF_mask = jnp.where(ltr, matLogSigF, -float("inf")) 42 | 43 | vecLogSigI = jax.nn.log_sigmoid(vecI) 44 | 45 | matLogD = matLogSigF_mask + vecLogSigI[:, :, None, :] 46 | 47 | matD = jnp.exp(matLogD) # (B, NH, S, S) 48 | 49 | matS = (matQ @ matK.swapaxes(-2, -1)) * (DHQK**-0.5) # (B, NH, S, S) 50 | 51 | matCtilde = matS * matD # (B, NH, S, S) 52 | if normalize: 53 | vecN = jnp.maximum( 54 | jnp.abs(jnp.sum(matCtilde, axis=-1, keepdims=True)), 55 | jnp.array([1.0]), 56 | ) # (B, NH, S, 1) 57 | # (B, NH, S, S) 58 | matC = matCtilde / (vecN + eps) 59 | vecN = vecN.squeeze(-1) 60 | else: 61 | matC = matCtilde 62 | vecN = None 63 | 64 | matH = matC @ matV # (B, NH, S, DH) 65 | 66 | return (matH, vecN) 67 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native_stablef/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import ( 5 | mlstm_parallel__native_stablef_autograd, 6 | mlstm_parallel__native_stablef_custbw, 7 | ) 8 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native_stablef/bw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | # Maximilian Beck 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | 10 | def mlstm_parallel_bw( 11 | matDeltaHtilde: jax.Array, 12 | matQ: jax.Array, 13 | matK: jax.Array, 14 | matV: jax.Array, 15 | vecI: jax.Array, 16 | vecF: jax.Array, 17 | vecN: jax.Array, 18 | vecM: jax.Array, 19 | eps: float = 1e-6, 20 | ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: 21 | B, NH, S, DHQK = matQ.shape 22 | assert matK.shape == (B, NH, S, DHQK) 23 | assert vecI.shape == (B, NH, S) 24 | assert vecF.shape == (B, NH, S) 25 | 26 | vecLogSigF = jax.nn.log_sigmoid(vecF) # (B, NH, S) 27 | 28 | matLogSigF_tril = jnp.tril(vecLogSigF[:, :, :, None].repeat(S, axis=-1), k=-1) 29 | matLogSigF_cum = jnp.cumsum(matLogSigF_tril, axis=-2) 30 | 31 | ltr = jnp.tril( 32 | jnp.ones( 33 | (S, S), 34 | dtype=jnp.bool_, 35 | ) 36 | ) 37 | 38 | matLogSigF_mask = jnp.where(ltr, matLogSigF_cum, -float("inf")) 39 | 40 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 41 | 42 | matLogD_stabilized = matLogD - vecM[:, :, :, None] 43 | 44 | matD = jnp.exp(matLogD_stabilized) # (B, NH, S, S) 45 | 46 | # intermediate delta-errors 47 | matDeltaC = matDeltaHtilde @ matV.swapaxes(-2, -1) / (vecN[:, :, :, None] + eps) 48 | 49 | matS = (matQ @ matK.swapaxes(-2, -1)) * (DHQK**-0.5) 50 | 51 | matDeltaDtilde = matDeltaC * matD * matS 52 | 53 | vecDeltaI = jnp.sum(matDeltaDtilde, axis=-2) 54 | 55 | # output delta-errors / gradients 56 | matP = matDeltaC * matD 57 | 58 | matDeltaQ = (matP @ matK) * (DHQK**-0.5) 59 | matDeltaK = (matP.swapaxes(-2, -1) @ matQ) * (DHQK**-0.5) 60 | 61 | matCtilde = matS * matD 62 | matDeltaV = matCtilde.swapaxes(-2, -1) @ ( 63 | matDeltaHtilde / (vecN[:, :, :, None] + eps) 64 | ) 65 | 66 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 67 | vecDeltaFbar_acc = jnp.sum((matQ * matDeltaQ - matK * matDeltaK), axis=-1) 68 | vecDeltaFbar = jnp.flip( 69 | jnp.cumsum(jnp.flip(vecDeltaFbar_acc, axis=-1), axis=-1), axis=-1 70 | ) 71 | vecDeltaF = vecDeltaFbar * jax.nn.sigmoid(-vecF) 72 | 73 | return ( 74 | matDeltaQ, 75 | matDeltaK, 76 | matDeltaV, 77 | vecDeltaI, 78 | vecDeltaF, 79 | ) 80 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/parallel/native_stablef/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """ 6 | Jax. 7 | 8 | mLSTM forward and backward pass. Parallel formulation. 
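In formulas: log D[t, s] = i_s + sum_{k=s+1}^{t} log(sigmoid(f_k)) for s <= t, stabilized per row with m_t = max_s log D[t, s], and H = ((Q K^T / sqrt(d_qk)) * exp(log D - m)) V, where each row is divided by max(|row sum|, exp(-m_t)) + eps.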
9 | """ 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | 14 | 15 | def mlstm_parallel_fw( 16 | matQ: jax.Array, 17 | matK: jax.Array, 18 | matV: jax.Array, 19 | vecI: jax.Array, 20 | vecF: jax.Array, 21 | eps: float = 1e-6, 22 | ) -> jax.Array: 23 | B, NH, S, DHQK = matQ.shape 24 | assert matK.shape == (B, NH, S, DHQK) 25 | assert vecI.shape == (B, NH, S) 26 | assert vecF.shape == (B, NH, S) 27 | 28 | vecLogSigF = jax.nn.log_sigmoid(vecF) # (B, NH, S) 29 | 30 | matLogSigF_tril = jnp.tril(vecLogSigF[:, :, :, None].repeat(S, axis=-1), k=-1) 31 | matLogSigF_cum = jnp.cumsum(matLogSigF_tril, axis=-2) 32 | 33 | ltr = jnp.tril( 34 | jnp.ones( 35 | (S, S), 36 | dtype=jnp.bool_, 37 | ) 38 | ) 39 | 40 | matLogSigF_mask = jnp.where(ltr, matLogSigF_cum, -float("inf")) 41 | 42 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 43 | 44 | vecM = jnp.max(matLogD, axis=-1, keepdims=True) # (B, NH, S, 1) 45 | matLogD_stabilized = matLogD - vecM 46 | 47 | matD = jnp.exp(matLogD_stabilized) # (B, NH, S, S) 48 | 49 | matS = (matQ @ matK.swapaxes(-2, -1)) * (DHQK**-0.5) # (B, NH, S, S) 50 | 51 | matCtilde = matS * matD # (B, NH, S, S) 52 | vecN = jnp.maximum( 53 | jnp.abs(jnp.sum(matCtilde, axis=-1, keepdims=True)), jnp.exp(-vecM) 54 | ) # (B, NH, S, 1) 55 | # (B, NH, S, S) 56 | matC = matCtilde / (vecN + eps) 57 | 58 | matH = matC @ matV # (B, NH, S, DH) 59 | 60 | vecN = vecN.squeeze(-1) 61 | vecM = vecM.squeeze(-1) 62 | 63 | return (matH, vecN, vecM) 64 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .native_sequence import ( 5 | mlstm_recurrent_sequence__native_fw, 6 | mlstm_recurrent_sequence__triton_step_fused_fw, 7 | ) 8 | from .native_step import mlstm_recurrent_step__native 9 | from .triton_step import mlstm_recurrent_step__triton 10 | 11 | registry_step = { 12 | "native": mlstm_recurrent_step__native, 13 | "triton": mlstm_recurrent_step__triton, 14 | } 15 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/stride_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import jax 5 | import numpy as np 6 | 7 | 8 | def get_strides(array: jax.Array | jax.ShapeDtypeStruct) -> list[int]: 9 | """ 10 | Returns the strides of a JAX array. 11 | 12 | Args: 13 | array: JAX array or shape-dtype struct. 14 | 15 | Returns: 16 | The strides of the array. Length is equal to the number of dimensions. 17 | """ 18 | shape = array.shape 19 | size = array.size 20 | strides = [] 21 | for s in shape: 22 | size = size // s 23 | strides.append(int(size)) 24 | return strides 25 | 26 | 27 | def get_stride(array: jax.Array | jax.ShapeDtypeStruct, axis: int) -> int: 28 | """ 29 | Returns the stride of a JAX array at a given axis. 30 | 31 | To calculate all strides, use get_strides. 32 | 33 | Args: 34 | array: JAX array or shape-dtype struct. 35 | axis: The axis at which to calculate the stride. 36 | 37 | Returns: 38 | The stride of the array at the given axis. 
39 | """ 40 | shape = array.shape 41 | size = array.size 42 | stride = size // np.prod(shape[: axis + 1]) 43 | return int(stride) 44 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import triton.language as tl 8 | 9 | 10 | def dtype2str(dtype: jnp.dtype) -> str: 11 | if dtype == jnp.float32: 12 | return "fp32" 13 | elif dtype == jnp.float16: 14 | return "fp16" 15 | elif dtype == jnp.float64: 16 | return "fp64" 17 | elif dtype == jnp.bfloat16: 18 | return "bf16" 19 | else: 20 | raise ValueError(f"Unsupported dtype: {dtype}") 21 | 22 | 23 | def jax2triton_dtype(dtype): 24 | """ 25 | Converts a JAX dtype to a Triton dtype. 26 | 27 | Args: 28 | dtype: JAX dtype. 29 | 30 | Returns: 31 | Triton dtype. 32 | """ 33 | # We need to grab the dtype from the dtype object in jax 34 | # >> dt = jnp.float32 35 | # >> str(dt), str(dt.dtype) 36 | # Output: 37 | # ("", 'float32') 38 | if hasattr(dtype, "dtype"): 39 | dtype = dtype.dtype 40 | return getattr(tl, str(dtype)) 41 | 42 | 43 | def to_numpy(tensor: jnp.ndarray) -> np.ndarray: 44 | return jax.device_get(tensor) 45 | -------------------------------------------------------------------------------- /mlstm_kernels/jax/xla_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import os 5 | 6 | 7 | def simulate_CPU_devices(device_count: int = 8): 8 | """ 9 | Simulate a CPU with a given number of devices. 10 | 11 | Args: 12 | device_count: The number of devices to simulate. 13 | """ 14 | # Set XLA flags to simulate a CPU with a given number of devices 15 | flags = os.environ.get("XLA_FLAGS", "") 16 | flags += f" --xla_force_host_platform_device_count={device_count}" 17 | os.environ["XLA_FLAGS"] = flags 18 | # Disable CUDA to force XLA to use the CPU 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 20 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .native import mlstm_chunkwise__native_autograd, mlstm_chunkwise__native_custbw 5 | from .triton_limit_chunk import mlstm_chunkwise__limit_chunk 6 | from .triton_xl_chunk import mlstm_chunkwise__xl_chunk 7 | from .triton_xl_chunk_siging import mlstm_siging_chunkwise__xl_chunk 8 | 9 | registry = { 10 | "native_autograd": mlstm_chunkwise__native_autograd, 11 | "native_custbw": mlstm_chunkwise__native_custbw, 12 | "triton_limit_chunk": mlstm_chunkwise__limit_chunk, 13 | "triton_xl_chunk": mlstm_chunkwise__xl_chunk, 14 | "triton_xl_chunk_siging": mlstm_siging_chunkwise__xl_chunk, 15 | } 16 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/native/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_chunkwise__native_autograd, mlstm_chunkwise__native_custbw 5 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/triton_limit_chunk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_chunkwise__limit_chunk 5 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/triton_limit_chunk/chunkwise_gates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | 6 | 7 | import torch 8 | 9 | 10 | # Note: we separate this into an extra function for torch.compile. 11 | # torch.compile will compile this into a single kernel with ca. 0.2 ms runtime (compared to 2.5 ms non-fused kernels) 12 | # for a 1.3B sized model with ctx8192. 13 | @torch.compile 14 | def compute_gate_grads_vecDeltaI_vecDeltaF( 15 | matQ: torch.Tensor, 16 | matK: torch.Tensor, 17 | matDeltaQ: torch.Tensor, 18 | matDeltaK: torch.Tensor, 19 | vecF: torch.Tensor, 20 | ) -> tuple[torch.Tensor, torch.Tensor]: 21 | #! postprocessing: compute deltaF and deltaI gradients 22 | ## ? postprocessing 23 | # vecF = rearrange(vecF, "b nh nc l -> b nh (nc l)") 24 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 25 | matQ = matQ.to(torch.float32) 26 | matK = matK.to(torch.float32) 27 | matDeltaQ = matDeltaQ.to(torch.float32) 28 | matDeltaK = matDeltaK.to(torch.float32) 29 | vecDeltaFbar_acc = ((matQ * matDeltaQ) - (matK * matDeltaK)).sum(-1) 30 | vecDeltaFbar = vecDeltaFbar_acc.flip(-1).to(torch.float32).cumsum(-1).flip(-1) 31 | vecDeltaF = vecDeltaFbar * torch.sigmoid(-vecF) 32 | ## ? end postprocessing 33 | # compute deltaI 34 | # both are equivalent: 35 | # vecDeltaI = (matV * matDeltaV).sum(-1) 36 | vecDeltaI = (matK * matDeltaK).sum(-1) 37 | return vecDeltaI, vecDeltaF 38 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/triton_xl_chunk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_chunkwise__xl_chunk 5 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/triton_xl_chunk/chunkwise_gates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """In this file we compute the chunkwise or cumulative gates (i.e. vecA and vecB) 6 | for the forward and backward pass of the mLSTM. 7 | We use the stable formulations, i.e. we avoid subtraction of forget gates.
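In formulas, for a chunk of length L with chunk-local index j in [0, L-1] and bf_k = log(sigmoid(f_k)), the functions below compute vecB_j = sum_{k<=j} bf_k and vecA_j = sum_{k>j} bf_k + i_j; the tail sum in vecA is accumulated with a flipped cumsum instead of the difference vecB_{L-1} - vecB_j.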
8 | """ 9 | 10 | import torch 11 | from einops import rearrange 12 | from torch.nn.functional import logsigmoid 13 | 14 | 15 | @torch.compile 16 | def compute_chunkwise_log_gates_vecB_vecA( 17 | vecI: torch.Tensor, # (B, NH, S) 18 | vecF: torch.Tensor, # (B, NH, S) 19 | chunk_size: int, 20 | ): 21 | B, NH, S = vecI.shape 22 | assert S % chunk_size == 0, f"S={S} is not divisible by chunk_size={chunk_size}" 23 | _device = vecI.device 24 | NC = S // chunk_size 25 | L = chunk_size 26 | 27 | # compute vecB 28 | vecF_logsig = logsigmoid(vecF.to(dtype=torch.float32)) 29 | vecF_logsig_chunked = rearrange(vecF_logsig, "b nh (nc l) -> b nh nc l", nc=NC, l=L) 30 | vecB = vecF_logsig_chunked.cumsum(dim=-1) 31 | 32 | # compute vecA 33 | vecI_chunked = rearrange(vecI, "b nh (nc l) -> b nh nc l", nc=NC, l=L) 34 | # unstable vecA computation: 35 | # vecA = (vecB[..., -1, None] - vecB) + vecI # (B, NH, NC, L) 36 | # stable vecA computation: 37 | vecA = ( 38 | torch.cat( 39 | [ 40 | vecF_logsig_chunked[..., 1:].flip(-1).cumsum(-1).flip(-1), 41 | torch.zeros((B, NH, NC, 1), device=_device, dtype=torch.float32), 42 | ], 43 | dim=-1, 44 | ) 45 | + vecI_chunked 46 | ) # (B, NH, NC, L) 47 | return vecB, vecA 48 | 49 | 50 | @torch.compile 51 | def compute_chunkwise_log_gates_vecB( 52 | vecF: torch.Tensor, # (B, NH, S) 53 | chunk_size: int, 54 | ): 55 | B, NH, S = vecF.shape 56 | assert S % chunk_size == 0, f"S={S} is not divisible by chunk_size={chunk_size}" 57 | NC = S // chunk_size 58 | L = chunk_size 59 | 60 | # compute vecB 61 | vecF_logsig = logsigmoid(vecF.to(dtype=torch.float32)) 62 | vecF_logsig_chunked = rearrange(vecF_logsig, "b nh (nc l) -> b nh nc l", nc=NC, l=L) 63 | vecB = vecF_logsig_chunked.cumsum(dim=-1) 64 | 65 | return vecB 66 | 67 | 68 | # Note: we separate this into a extra function for torch.compile. 69 | # torch.compile will compile this into a single kernel with ca. 0.2 ms runtime (compared to 2.5 ms non-fused kernels) 70 | # for a 1.3B sized model with ctx8192. 71 | @torch.compile 72 | def compute_gate_grads_vecDeltaI_vecDeltaF( 73 | matQ: torch.Tensor, 74 | matK: torch.Tensor, 75 | matDeltaQ: torch.Tensor, 76 | matDeltaK: torch.Tensor, 77 | vecF: torch.Tensor, 78 | ) -> tuple[torch.Tensor, torch.Tensor]: 79 | #! postprocessing: compute deltaF and deltaI gradients 80 | ## ? postprocessing 81 | # vecF = rearrange(vecF, "b nh nc l -> b nh (nc l)") 82 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 83 | matQ = matQ.to(torch.float32) 84 | matK = matK.to(torch.float32) 85 | matDeltaQ = matDeltaQ.to(torch.float32) 86 | matDeltaK = matDeltaK.to(torch.float32) 87 | vecDeltaFbar_acc = ((matQ * matDeltaQ) - (matK * matDeltaK)).sum(-1) 88 | vecDeltaFbar = vecDeltaFbar_acc.flip(-1).to(torch.float32).cumsum(-1).flip(-1) 89 | vecDeltaF = vecDeltaFbar * torch.sigmoid(-vecF) 90 | ## ? end postprocessing 91 | # compute deltaI 92 | # both are equivalent: 93 | # vecDeltaI = (matV * matDeltaV).sum(-1) 94 | vecDeltaI = (matK * matDeltaK).sum(-1) 95 | return vecDeltaI, vecDeltaF 96 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/chunkwise/triton_xl_chunk_siging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | from .fwbw import mlstm_siging_chunkwise__xl_chunk 5 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .native import mlstm_parallel__native_autograd, mlstm_parallel__native_custbw 5 | from .native_siging import ( 6 | mlstm_siging_parallel__native_autograd, 7 | mlstm_siging_parallel__native_custbw, 8 | ) 9 | from .native_stablef import ( 10 | mlstm_parallel__native_stablef_autograd, 11 | mlstm_parallel__native_stablef_custbw, 12 | ) 13 | from .triton_limit_headdim import mlstm_parallel__limit_headdim 14 | 15 | registry = { 16 | "native_autograd": mlstm_parallel__native_autograd, 17 | "native_custbw": mlstm_parallel__native_custbw, 18 | "native_stablef_autograd": mlstm_parallel__native_stablef_autograd, 19 | "native_stablef_custbw": mlstm_parallel__native_stablef_custbw, 20 | "triton_limit_headdim": mlstm_parallel__limit_headdim, 21 | "native_siging_autograd": mlstm_siging_parallel__native_autograd, 22 | "native_siging_custbw": mlstm_siging_parallel__native_custbw, 23 | } 24 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/_legacy_native_siging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/_legacy_native_siging/sig_ingate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import torch 5 | 6 | from .ops import ( 7 | build_forget_gate_matrix, 8 | qkdecmask_normalize_parallel, 9 | ) 10 | 11 | 12 | def mlstm_siging_parallel( 13 | queries: torch.Tensor, 14 | keys: torch.Tensor, 15 | values: torch.Tensor, 16 | igate_preact: torch.Tensor, 17 | fgate_preact: torch.Tensor, 18 | lower_triangular_matrix: torch.Tensor = None, 19 | qk_decmask_normalize: bool = True, 20 | normalization_mode: str = "max_abs_sum_C_1", 21 | normalize_sqrt_d: bool = False, # in order to match with new implementation 22 | normalizer_offset: float = 0.0, 23 | eps: float = 1e-6, 24 | **kwargs, 25 | ): 26 | """This is the core linear Hopfield retrieval operation in parallel form. 27 | It has sigmoid input gates instead of exponential ones. 28 | 29 | Args: 30 | queries (torch.Tensor): (B, NH, S, DH) 31 | keys (torch.Tensor): (B, NH, S, DH) 32 | values (torch.Tensor): (B, NH, S, DH) 33 | igate_preact (torch.Tensor): (B, NH, S, 1) 34 | fgate_preact (torch.Tensor): (B, NH, S, 1) 35 | lower_triangular_matrix (torch.Tensor, optional): (S,S). Defaults to None. 36 | qk_decmask_normalize (bool, optional): Whether to normalize the combination matrix C. Defaults to True. 37 | normalization_mode (str, optional): Normalization mode for the combination matrix C. Defaults to "max_abs_sum_C_1". 38 | normalize_sqrt_d (bool, optional): Whether to normalize the combination matrix C by the sqrt of the qk head dimension. 39 | Originally, this was not present.
In the new implementation we add this in order to 40 | match with the exponential input gate version. Defaults to False. 41 | normalizer_offset (float, optional): Offset for the normalizer. This number is added to the denominator. 42 | Defaults to 0.0. 43 | eps (float, optional): Numerical epsilon added in the denominator when normalizing the combination matrix C. Defaults to 1e-6. 44 | 45 | Returns: 46 | torch.Tensor: (B, NH, S, DH), retrieved values 47 | """ 48 | B, NH, S, DHQK = queries.shape 49 | 50 | # forget gate matrix 51 | fgates = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, S, 1) 52 | fg_matrix = build_forget_gate_matrix( 53 | per_timestep_fg_gate_vals=fgates, 54 | per_time_step_decay_vals_in_logspace=True, 55 | return_log_matrix=True, 56 | lower_triangular_matrix=lower_triangular_matrix, 57 | eps=0.0, 58 | ) # (B, NH, S, S) 59 | 60 | # input gates 61 | igates = torch.nn.functional.logsigmoid(igate_preact) # (B, NH, S, 1) 62 | # gate decay matrix D 63 | log_D_matrix = fg_matrix + igates.transpose(-2, -1) # (B, NH, S, S) 64 | D_matrix = torch.exp(log_D_matrix) # (B, NH, S, S) 65 | # combination matrix C 66 | qk_matrix = queries @ keys.transpose(-2, -1) # (B, NH, S, S) 67 | C_matrix = qk_matrix * D_matrix # (B, NH, S, S) 68 | if normalize_sqrt_d: 69 | C_matrix = C_matrix * (DHQK**-0.5) 70 | if qk_decmask_normalize: 71 | # (B, NH, S, S) 72 | C_matrix_normalized = qkdecmask_normalize_parallel( 73 | C_matrix=C_matrix, 74 | normalization_mode=normalization_mode, 75 | normalizer_offset=normalizer_offset, 76 | eps=eps, 77 | ) 78 | else: 79 | C_matrix_normalized = C_matrix 80 | 81 | # retrieved values 82 | retrieved_values = C_matrix_normalized @ values # (B, NH, S, DH) 83 | return retrieved_values 84 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_parallel__native_autograd, mlstm_parallel__native_custbw 5 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native/bw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
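# Summary of the gradient identities implemented below (notation as in the code,
# with matS = Q K^T / sqrt(d), gate matrix matD, normalizer vecN, and
# matDeltaC = matDeltaHtilde @ V^T / (vecN + eps)):
#   matDeltaQ = ((matDeltaC * matD) @ K) / sqrt(d)
#   matDeltaK = ((matDeltaC * matD)^T @ Q) / sqrt(d)
#   matDeltaV = (matS * matD)^T @ (matDeltaHtilde / (vecN + eps))
#   vecDeltaI = (matDeltaC * matD * matS).sum(dim=-2)
#   vecDeltaF = rev_cumsum((Q * dQ - K * dK).sum(-1)) * sigmoid(-vecF)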
3 | 4 | import torch 5 | from torch.nn.functional import logsigmoid 6 | 7 | 8 | def mlstm_parallel_bw( 9 | matDeltaHtilde: torch.Tensor, 10 | matQ: torch.Tensor, 11 | matK: torch.Tensor, 12 | matV: torch.Tensor, 13 | vecI: torch.Tensor, 14 | vecF: torch.Tensor, 15 | vecN: torch.Tensor, 16 | vecM: torch.Tensor, 17 | eps: float = 1e-6, 18 | ) -> tuple[torch.Tensor, ...]: 19 | B, NH, S, DHQK = matQ.shape 20 | assert matK.shape == (B, NH, S, DHQK) 21 | assert vecI.shape == (B, NH, S) 22 | assert vecF.shape == (B, NH, S) 23 | 24 | _dtype, _device = matQ.dtype, matQ.device 25 | 26 | vecLogSigF = logsigmoid(vecF) # (B, NH, S) 27 | vecLogSigF_cumsum = vecLogSigF.cumsum(-1) 28 | 29 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 30 | 31 | ltr = torch.tril( 32 | torch.ones( 33 | (S, S), 34 | dtype=torch.bool, 35 | device=_device, 36 | ) 37 | ) 38 | 39 | matLogSigF_mask = torch.where(ltr, matLogSigF, -float("inf")) 40 | 41 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 42 | 43 | matLogD_stabilized = matLogD - vecM[:, :, :, None] 44 | 45 | matD = torch.exp(matLogD_stabilized) # (B, NH, S, S) 46 | 47 | # intermediate delta-errors 48 | matDeltaC = matDeltaHtilde @ matV.transpose(-2, -1) / (vecN[:, :, :, None] + eps) 49 | 50 | matS = (matQ @ matK.transpose(-2, -1)) * (DHQK**-0.5) 51 | 52 | matDeltaDtilde = matDeltaC * matD * matS 53 | 54 | vecDeltaI = torch.sum(matDeltaDtilde, dim=-2) 55 | 56 | # output delta-errors / gradients 57 | matP = matDeltaC * matD 58 | 59 | matDeltaQ = (matP @ matK) * (DHQK**-0.5) 60 | matDeltaK = (matP.transpose(-2, -1) @ matQ) * (DHQK**-0.5) 61 | 62 | matCtilde = matS * matD 63 | matDeltaV = matCtilde.transpose(-2, -1) @ ( 64 | matDeltaHtilde / (vecN[:, :, :, None] + eps) 65 | ) 66 | 67 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 68 | vecDeltaFbar_acc = (matQ * matDeltaQ - matK * matDeltaK).sum(-1) 69 | vecDeltaFbar = vecDeltaFbar_acc.flip(-1).cumsum(-1).flip(-1) 70 | vecDeltaF = vecDeltaFbar * torch.sigmoid(-vecF) 71 | 72 | return ( 73 | matDeltaQ, 74 | matDeltaK, 75 | matDeltaV, 76 | vecDeltaI, 77 | vecDeltaF, 78 | ) 79 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """ 6 | PyTorch 7 | 8 | mLSTM forward and backward pass. Parallel formulation. 
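Shapes: matQ and matK are (B, NH, S, DHQK), matV is (B, NH, S, DHV), vecI and vecF are (B, NH, S); the return value is the tuple (matH, vecN, vecM) with matH of shape (B, NH, S, DHV) and vecN, vecM of shape (B, NH, S).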
9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def mlstm_parallel_fw( 16 | matQ: torch.Tensor, 17 | matK: torch.Tensor, 18 | matV: torch.Tensor, 19 | vecI: torch.Tensor, 20 | vecF: torch.Tensor, 21 | eps: float = 1e-6, 22 | ) -> torch.Tensor: 23 | B, NH, S, DHQK = matQ.shape 24 | assert matK.shape == (B, NH, S, DHQK) 25 | assert vecI.shape == (B, NH, S) 26 | assert vecF.shape == (B, NH, S) 27 | 28 | _dtype, _device = matQ.dtype, matQ.device 29 | 30 | vecLogSigF = F.logsigmoid(vecF) # (B, NH, S) 31 | 32 | vecLogSigF_cumsum = vecLogSigF.cumsum(-1) 33 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 34 | 35 | ltr = torch.tril( 36 | torch.ones( 37 | (S, S), 38 | dtype=torch.bool, 39 | device=_device, 40 | ) 41 | ) 42 | 43 | matLogSigF_mask = torch.where(ltr, matLogSigF, -float("inf")) 44 | 45 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 46 | 47 | vecM, _ = torch.max(matLogD, dim=-1, keepdim=True) # (B, NH, S, 1) 48 | matLogD_stabilized = matLogD - vecM 49 | 50 | matD = torch.exp(matLogD_stabilized) # (B, NH, S, S) 51 | 52 | matS = (matQ @ matK.transpose(-2, -1)) * (DHQK**-0.5) # (B, NH, S, S) 53 | 54 | matCtilde = matS * matD # (B, NH, S, S) 55 | vecN = torch.maximum( 56 | matCtilde.sum(dim=-1, keepdim=True).abs(), torch.exp(-vecM) 57 | ) # (B, NH, S, 1) 58 | # (B, NH, S, S) 59 | matC = matCtilde / (vecN + eps) 60 | 61 | matH = matC @ matV # (B, NH, S, DH) 62 | 63 | vecN = vecN.squeeze(-1) 64 | vecM = vecM.squeeze(-1) 65 | 66 | return (matH, vecN, vecM) 67 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native_siging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import ( 5 | mlstm_siging_parallel__native_autograd, 6 | mlstm_siging_parallel__native_custbw, 7 | ) 8 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native_siging/bw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | # Maximilian Beck 5 | import torch 6 | from torch.nn.functional import logsigmoid 7 | 8 | 9 | def mlstm_siging_parallel_bw( 10 | matDeltaHtilde: torch.Tensor, 11 | matQ: torch.Tensor, 12 | matK: torch.Tensor, 13 | matV: torch.Tensor, 14 | vecI: torch.Tensor, 15 | vecF: torch.Tensor, 16 | vecN: torch.Tensor, 17 | eps: float = 1e-6, 18 | stable_fgate: bool = True, 19 | normalize: bool = True, 20 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 21 | B, NH, S, DHQK = matQ.shape 22 | assert matK.shape == (B, NH, S, DHQK) 23 | assert vecI.shape == (B, NH, S) 24 | assert vecF.shape == (B, NH, S) 25 | 26 | _dtype, _device = matQ.dtype, matQ.device 27 | 28 | vecLogSigF = logsigmoid(vecF) # (B, NH, S) 29 | 30 | if stable_fgate: 31 | matLogSigF_tril = vecLogSigF[:, :, :, None].repeat(1, 1, 1, S).tril(-1) 32 | matLogSigF = matLogSigF_tril.cumsum(-2) 33 | else: 34 | vecLogSigF_cumsum = vecLogSigF.cumsum(-1) 35 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 36 | 37 | ltr = torch.tril( 38 | torch.ones( 39 | (S, S), 40 | dtype=torch.bool, 41 | device=_device, 42 | ) 43 | ) 44 | 45 | matLogSigF_mask = torch.where(ltr, matLogSigF, -float("inf")) 46 | 47 | vecLogSigI = logsigmoid(vecI) 48 | 49 | matLogD = matLogSigF_mask + vecLogSigI[:, :, None, :] 50 | 51 | matD = torch.exp(matLogD) # (B, NH, S, S) 52 | 53 | # intermediate delta-errors 54 | if normalize: 55 | matDeltaC = ( 56 | matDeltaHtilde @ matV.transpose(-2, -1) / (vecN[:, :, :, None] + eps) 57 | ) 58 | else: 59 | matDeltaC = matDeltaHtilde @ matV.transpose(-2, -1) 60 | 61 | matS = (matQ @ matK.transpose(-2, -1)) * (DHQK**-0.5) 62 | 63 | matDeltaDtilde = matDeltaC * matD * matS 64 | 65 | vecDeltaIbar = torch.sum(matDeltaDtilde, dim=-2) 66 | 67 | # output delta-errors / gradients 68 | matP = matDeltaC * matD 69 | 70 | matDeltaQ = (matP @ matK) * (DHQK**-0.5) 71 | matDeltaK = (matP.transpose(-2, -1) @ matQ) * (DHQK**-0.5) 72 | 73 | matCtilde = matS * matD 74 | 75 | if normalize: 76 | matDeltaV = matCtilde.transpose(-2, -1) @ ( 77 | matDeltaHtilde / (vecN[:, :, :, None] + eps) 78 | ) 79 | else: 80 | matDeltaV = matCtilde.transpose(-2, -1) @ matDeltaHtilde 81 | 82 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 83 | vecDeltaFbar_acc = (matQ * matDeltaQ - matK * matDeltaK).sum(-1) 84 | vecDeltaFbar = vecDeltaFbar_acc.flip(-1).cumsum(-1).flip(-1) 85 | vecDeltaF = vecDeltaFbar * torch.sigmoid(-vecF) 86 | 87 | vecDeltaI = vecDeltaIbar * torch.sigmoid(-vecI) 88 | 89 | return ( 90 | matDeltaQ, 91 | matDeltaK, 92 | matDeltaV, 93 | vecDeltaI, 94 | vecDeltaF, 95 | ) 96 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native_siging/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """ 6 | PyTorch 7 | 8 | mLSTM sigmoid input gate forward pass. Parallel formulation. 
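Returns the tuple (matH, vecN); with normalize=False the division by the normalizer is skipped and vecN is returned as None.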
9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def mlstm_siging_parallel_fw( 16 | matQ: torch.Tensor, 17 | matK: torch.Tensor, 18 | matV: torch.Tensor, 19 | vecI: torch.Tensor, 20 | vecF: torch.Tensor, 21 | eps: float = 1e-6, 22 | stable_fgate: bool = True, 23 | normalize: bool = True, 24 | ) -> torch.Tensor: 25 | B, NH, S, DHQK = matQ.shape 26 | assert matK.shape == (B, NH, S, DHQK) 27 | assert vecI.shape == (B, NH, S) 28 | assert vecF.shape == (B, NH, S) 29 | 30 | _dtype, _device = matQ.dtype, matQ.device 31 | 32 | vecLogSigF = F.logsigmoid(vecF) # (B, NH, S) 33 | 34 | if stable_fgate: 35 | matLogSigF_tril = vecLogSigF[:, :, :, None].repeat(1, 1, 1, S).tril(-1) 36 | matLogSigF = matLogSigF_tril.cumsum(-2) 37 | else: 38 | vecLogSigF_cumsum = vecLogSigF.cumsum(-1) 39 | matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :] 40 | 41 | ltr = torch.tril( 42 | torch.ones( 43 | (S, S), 44 | dtype=torch.bool, 45 | device=_device, 46 | ) 47 | ) 48 | 49 | matLogSigF_mask = torch.where(ltr, matLogSigF, -float("inf")) 50 | 51 | vecLogSigI = F.logsigmoid(vecI) 52 | 53 | matLogD = matLogSigF_mask + vecLogSigI[:, :, None, :] 54 | 55 | matD = torch.exp(matLogD) # (B, NH, S, S) 56 | 57 | matS = (matQ @ matK.transpose(-2, -1)) * (DHQK**-0.5) # (B, NH, S, S) 58 | 59 | matCtilde = matS * matD # (B, NH, S, S) 60 | if normalize: 61 | vecN = torch.maximum( 62 | matCtilde.sum(dim=-1, keepdim=True).abs(), 63 | torch.tensor([1.0], dtype=_dtype, device=_device), 64 | ) # (B, NH, S, 1) 65 | # (B, NH, S, S) 66 | matC = matCtilde / (vecN + eps) 67 | vecN = vecN.squeeze(-1) 68 | else: 69 | matC = matCtilde 70 | vecN = None 71 | 72 | matH = matC @ matV # (B, NH, S, DH) 73 | 74 | return (matH, vecN) 75 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native_stablef/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import ( 5 | mlstm_parallel__native_stablef_autograd, 6 | mlstm_parallel__native_stablef_custbw, 7 | ) 8 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native_stablef/bw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import torch 5 | from torch.nn.functional import logsigmoid 6 | 7 | 8 | def mlstm_parallel_bw( 9 | matDeltaHtilde: torch.Tensor, 10 | matQ: torch.Tensor, 11 | matK: torch.Tensor, 12 | matV: torch.Tensor, 13 | vecI: torch.Tensor, 14 | vecF: torch.Tensor, 15 | vecN: torch.Tensor, 16 | vecM: torch.Tensor, 17 | eps: float = 1e-6, 18 | ) -> tuple[torch.Tensor, ...]: 19 | B, NH, S, DHQK = matQ.shape 20 | assert matK.shape == (B, NH, S, DHQK) 21 | assert vecI.shape == (B, NH, S) 22 | assert vecF.shape == (B, NH, S) 23 | 24 | _dtype, _device = matQ.dtype, matQ.device 25 | 26 | vecLogSigF = logsigmoid(vecF) # (B, NH, S) 27 | 28 | matLogSigF_tril = vecLogSigF[:, :, :, None].repeat(1, 1, 1, S).tril(-1) 29 | matLogSigF_cum = matLogSigF_tril.cumsum(-2) 30 | 31 | ltr = torch.tril( 32 | torch.ones( 33 | (S, S), 34 | dtype=torch.bool, 35 | device=_device, 36 | ) 37 | ) 38 | 39 | matLogSigF_mask = torch.where(ltr, matLogSigF_cum, -float("inf")) 40 | 41 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 42 | 43 | matLogD_stabilized = matLogD - vecM[:, :, :, None] 44 | 45 | matD = torch.exp(matLogD_stabilized) # (B, NH, S, S) 46 | 47 | # intermediate delta-errors 48 | matDeltaC = matDeltaHtilde @ matV.transpose(-2, -1) / (vecN[:, :, :, None] + eps) 49 | 50 | matS = (matQ @ matK.transpose(-2, -1)) * (DHQK**-0.5) 51 | 52 | matDeltaDtilde = matDeltaC * matD * matS 53 | 54 | vecDeltaI = torch.sum(matDeltaDtilde, dim=-2) 55 | 56 | # output delta-errors / gradients 57 | matP = matDeltaC * matD 58 | 59 | matDeltaQ = (matP @ matK) * (DHQK**-0.5) 60 | matDeltaK = (matP.transpose(-2, -1) @ matQ) * (DHQK**-0.5) 61 | 62 | matCtilde = matS * matD 63 | matDeltaV = matCtilde.transpose(-2, -1) @ ( 64 | matDeltaHtilde / (vecN[:, :, :, None] + eps) 65 | ) 66 | 67 | # compute the vecDeltaFbar values with dfbar = rev_cumsum((q*dq - k*dk).sum(-1)) 68 | vecDeltaFbar_acc = (matQ * matDeltaQ - matK * matDeltaK).sum(-1) 69 | vecDeltaFbar = vecDeltaFbar_acc.flip(-1).cumsum(-1).flip(-1) 70 | vecDeltaF = vecDeltaFbar * torch.sigmoid(-vecF) 71 | 72 | return ( 73 | matDeltaQ, 74 | matDeltaK, 75 | matDeltaV, 76 | vecDeltaI, 77 | vecDeltaF, 78 | ) 79 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/native_stablef/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | # Maximilian Beck 5 | """ 6 | PyTorch 7 | 8 | mLSTM forward and backward pass. Parallel formulation. 
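Stable variant: the cumulative forget gates are built row-wise from a lower-triangular matrix (tril followed by a cumsum over rows) instead of as a difference of two 1-D cumulative sums, which avoids subtracting large cumulative forget gate values.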
9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | def mlstm_parallel_fw( 16 | matQ: torch.Tensor, 17 | matK: torch.Tensor, 18 | matV: torch.Tensor, 19 | vecI: torch.Tensor, 20 | vecF: torch.Tensor, 21 | eps: float = 1e-6, 22 | ) -> torch.Tensor: 23 | B, NH, S, DHQK = matQ.shape 24 | assert matK.shape == (B, NH, S, DHQK) 25 | assert vecI.shape == (B, NH, S) 26 | assert vecF.shape == (B, NH, S) 27 | 28 | _dtype, _device = matQ.dtype, matQ.device 29 | 30 | vecLogSigF = F.logsigmoid(vecF) # (B, NH, S) 31 | 32 | matLogSigF_tril = vecLogSigF[:, :, :, None].repeat(1, 1, 1, S).tril(-1) 33 | matLogSigF_cum = matLogSigF_tril.cumsum(-2) 34 | 35 | ltr = torch.tril( 36 | torch.ones( 37 | (S, S), 38 | dtype=torch.bool, 39 | device=_device, 40 | ) 41 | ) 42 | 43 | matLogSigF_mask = torch.where(ltr, matLogSigF_cum, -float("inf")) 44 | 45 | matLogD = matLogSigF_mask + vecI[:, :, None, :] 46 | 47 | vecM, _ = torch.max(matLogD, dim=-1, keepdim=True) # (B, NH, S, 1) 48 | matLogD_stabilized = matLogD - vecM 49 | 50 | matD = torch.exp(matLogD_stabilized) # (B, NH, S, S) 51 | 52 | matS = (matQ @ matK.transpose(-2, -1)) * (DHQK**-0.5) # (B, NH, S, S) 53 | 54 | matCtilde = matS * matD # (B, NH, S, S) 55 | vecN = torch.maximum( 56 | matCtilde.sum(dim=-1, keepdim=True).abs(), torch.exp(-vecM) 57 | ) # (B, NH, S, 1) 58 | # (B, NH, S, S) 59 | matC = matCtilde / (vecN + eps) 60 | 61 | matH = matC @ matV # (B, NH, S, DH) 62 | 63 | vecN = vecN.squeeze(-1) 64 | vecM = vecM.squeeze(-1) 65 | 66 | return (matH, vecN, vecM) 67 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/triton_limit_headdim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fwbw import mlstm_parallel__limit_headdim 5 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/parallel/triton_limit_headdim/fw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import torch 5 | import triton 6 | 7 | from ....triton.parallel.limit_headdim import mlstm_parallel_fw_kernel 8 | 9 | MINIMUM_MAX_VAL = -10 # -float("inf") # -10.0 10 | 11 | 12 | def mlstm_parallel_fw( 13 | matQ: torch.Tensor, 14 | matK: torch.Tensor, 15 | matV: torch.Tensor, 16 | vecI: torch.Tensor, 17 | vecF: torch.Tensor, 18 | eps: float = 1e-6, 19 | # BLOCK_Q: int = BLOCK_Q, 20 | # BLOCK_KV: int = BLOCK_KV, 21 | ) -> torch.Tensor: 22 | # batch size, number of heads, sequence length, head dimension 23 | BS, NH, SL, DH = matQ.shape 24 | assert vecI.shape == (BS, NH, SL) 25 | assert vecF.shape == (BS, NH, SL) 26 | 27 | # shape constraints 28 | HEAD_DIM_Q, HEAD_DIM_K = matQ.shape[-1], matK.shape[-1] 29 | # when v is in float8_e5m2 it is transposed. 
30 | HEAD_DIM_V = matV.shape[-1] 31 | assert ( 32 | HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V 33 | ), f"Q, K, V must have the same head dimension" 34 | assert HEAD_DIM_K in { 35 | 16, 36 | 32, 37 | 64, 38 | 128, 39 | 256, 40 | }, f"Only head dimensions 16, 32, 64, 128, 256 are supported, got {HEAD_DIM_K}" 41 | 42 | def grid(args): 43 | return ( 44 | triton.cdiv(matQ.shape[2], args["BLOCK_Q"]), 45 | matQ.shape[0] * matQ.shape[1], 46 | 1, 47 | ) 48 | 49 | # fix grid for debugging 50 | # def grid(args): 51 | # return triton.cdiv(matQ.shape[2], BLOCK_Q), matQ.shape[0] * matQ.shape[1], 1 52 | 53 | # print(f"Triton grid: {grid(None)}, BLOCK_Q: {BLOCK_Q}, BLOCK_KV: {BLOCK_KV}") 54 | 55 | matH = torch.empty_like(matQ) 56 | 57 | vecN = torch.zeros( 58 | (matQ.shape[0], matQ.shape[1], matQ.shape[2]), 59 | device=matQ.device, 60 | dtype=torch.float32, 61 | ) 62 | vecM = torch.zeros( 63 | (matQ.shape[0], matQ.shape[1], matQ.shape[2]), 64 | device=matQ.device, 65 | dtype=torch.float32, 66 | ) 67 | 68 | # Note we want to compute the forget gate cumsum in float32. 69 | # This results in a more accurate cumsum and lower numerical precision errors. 70 | vecF_cs = torch.nn.functional.logsigmoid(vecF.to(dtype=torch.float32)).cumsum(-1) 71 | 72 | mlstm_parallel_fw_kernel[grid]( 73 | matQ=matQ.contiguous(), 74 | matK=matK.contiguous(), 75 | matV=matV.contiguous(), 76 | vecI=vecI.contiguous(), 77 | vecF_cs=vecF_cs.contiguous(), 78 | qk_scale=HEAD_DIM_Q**0.5, 79 | matH=matH, 80 | vecN=vecN, 81 | vecM=vecM, 82 | stride_qz=matQ.stride(0), 83 | stride_qh=matQ.stride(1), 84 | stride_qm=matQ.stride(2), 85 | stride_qk=matQ.stride(3), 86 | stride_kz=matK.stride(0), 87 | stride_kh=matK.stride(1), 88 | stride_kn=matK.stride(2), 89 | stride_kk=matK.stride(3), 90 | stride_vz=matV.stride(0), 91 | stride_vh=matV.stride(1), 92 | stride_vk=matV.stride(2), 93 | stride_vn=matV.stride(3), 94 | stride_hz=matH.stride(0), 95 | stride_hh=matH.stride(1), 96 | stride_hm=matH.stride(2), 97 | stride_hn=matH.stride(3), 98 | stride_ifmn_z=vecF_cs.stride(0), 99 | stride_ifmn_h=vecF_cs.stride(1), 100 | stride_ifmn_m=vecF_cs.stride(2), 101 | Z=BS, 102 | H=NH, 103 | N_CTX=SL, 104 | HEAD_DIM=HEAD_DIM_K, 105 | # BLOCK_Q=BLOCK_Q, 106 | # BLOCK_KV=BLOCK_KV, 107 | MINIMUM_MAX_VAL=MINIMUM_MAX_VAL, 108 | EPS=eps, 109 | ) 110 | 111 | return matH, vecM, vecN 112 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
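# The alternate Triton step kernel is imported for completeness but is not
# registered in the dictionaries below (its registry entries are commented out).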
3 | 4 | from .native_sequence import ( 5 | mlstm_recurrent_sequence__native_fw, 6 | mlstm_recurrent_sequence__triton_alternate_step_fw, 7 | mlstm_recurrent_sequence__triton_step_fused_fw, 8 | ) 9 | from .native_step import mlstm_recurrent_step__native 10 | from .triton_step import mlstm_recurrent_step__triton 11 | from .triton_step_alternate import mlstm_recurrent_step__triton_alternate 12 | 13 | registry_step = { 14 | "native": mlstm_recurrent_step__native, 15 | # "triton_alternate": mlstm_recurrent_step__triton_alternate, 16 | "triton": mlstm_recurrent_step__triton, 17 | } 18 | 19 | 20 | registry_sequence = { 21 | "native_sequence__native": mlstm_recurrent_sequence__native_fw, 22 | # "native_sequence__triton_alternate": mlstm_recurrent_sequence__triton_alternate_step_fw, 23 | "native_sequence__triton": mlstm_recurrent_sequence__triton_step_fused_fw, 24 | } 25 | -------------------------------------------------------------------------------- /mlstm_kernels/torch/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import functools 5 | 6 | import numpy as np 7 | import torch 8 | import triton.language as tl 9 | 10 | _torch_to_triton_dtype = { 11 | torch.float32: tl.float32, 12 | torch.float16: tl.float16, 13 | torch.bfloat16: tl.bfloat16, 14 | } 15 | 16 | 17 | def dtype2str(dtype: torch.dtype) -> str: 18 | if dtype == torch.float32: 19 | return "fp32" 20 | elif dtype == torch.float16: 21 | return "fp16" 22 | elif dtype == torch.float64: 23 | return "fp64" 24 | elif dtype == torch.bfloat16: 25 | return "bf16" 26 | else: 27 | raise ValueError(f"Unsupported dtype: {dtype}") 28 | 29 | 30 | def contiguous(fn): 31 | @functools.wraps(fn) 32 | def wrapper(ctx, *args, **kwargs): 33 | return fn( 34 | ctx, 35 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 36 | **{ 37 | k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) 38 | for k, v in kwargs.items() 39 | }, 40 | ) 41 | 42 | return wrapper 43 | 44 | 45 | def contiguous_noctx(fn): 46 | @functools.wraps(fn) 47 | def wrapper(*args, **kwargs): 48 | return fn( 49 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 50 | **{ 51 | k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) 52 | for k, v in kwargs.items() 53 | }, 54 | ) 55 | 56 | return wrapper 57 | 58 | 59 | def torch2triton_dtype(dtype): 60 | return _torch_to_triton_dtype[dtype] 61 | 62 | 63 | def to_numpy(tensor: torch.Tensor) -> np.ndarray: 64 | return tensor.detach().cpu().to(dtype=torch.float64).numpy() 65 | 66 | 67 | def tensor_or_none(x): 68 | return x if x is None else torch.tensor(x) 69 | 70 | 71 | def int_or_none(x): 72 | return x if x is None else int(x) 73 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/chunkwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/chunkwise/limit_chunk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .bw_kernel_parallel import mlstm_chunkwise__parallel_bw_dQKV_kernel 5 | from .bw_kernel_recurrent import mlstm_chunkwise__recurrent_bw_dC_kernel 6 | from .fw_kernel_parallel import mlstm_chunkwise__parallel_fw_H_kernel 7 | from .fw_kernel_recurrent import mlstm_chunkwise__recurrent_fw_C_kernel 8 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/chunkwise/xl_chunk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .bw_kernel_parallel_dK import mlstm_chunkwise__parallel_bw_dK_kernel 5 | from .bw_kernel_parallel_dQ import mlstm_chunkwise__parallel_bw_dQ_kernel 6 | from .bw_kernel_parallel_dV import mlstm_chunkwise__parallel_bw_dV_kernel 7 | from .bw_kernel_recurrent import mlstm_chunkwise__recurrent_bw_dC_kernel 8 | from .fw_kernel_parallel import mlstm_chunkwise__parallel_fw_Hintra_kernel 9 | from .fw_kernel_recurrent import mlstm_chunkwise__recurrent_fw_C_kernel 10 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/chunkwise/xl_chunk_siging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .bw_kernel_parallel_dK import mlstm_siging_chunkwise__parallel_bw_dK_kernel 5 | from .bw_kernel_parallel_dQ import mlstm_siging_chunkwise__parallel_bw_dQ_kernel 6 | from .bw_kernel_parallel_dV import mlstm_siging_chunkwise__parallel_bw_dV_kernel 7 | from .bw_kernel_recurrent import mlstm_siging_chunkwise__recurrent_bw_dC_kernel 8 | from .fw_kernel_parallel import mlstm_siging_chunkwise__parallel_fw_Hintra_kernel 9 | from .fw_kernel_recurrent import mlstm_siging_chunkwise__recurrent_fw_C_kernel 10 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/kernel_param_heuristics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import triton 5 | 6 | from ..utils.kernels import is_power_of_2 7 | 8 | 9 | def get_head_dim_block_size(head_dim: int, min_block_size: int = 64) -> int: 10 | # TODO make proper tests, for when and where this check is necessary. 11 | # For 160M model size, this check is not necessary. 12 | # assert ( 13 | # is_power_of_2(head_dim) or head_dim % min_block_size == 0 14 | # ), f"head_dim must be a power of 2 or multiple of {min_block_size}. Got {head_dim}." 
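# (editorial note) The heuristic below returns the next power of two of
# `head_dim`, capped at `min_block_size`: with the default min_block_size=64,
# head_dim=32 -> 32, head_dim=48 -> 64, head_dim=128 -> 64.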
15 | 16 | return min(min_block_size, triton.next_power_of_2(head_dim)) 17 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/parallel/limit_headdim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from .fw_kernel import mlstm_parallel_fw_kernel 5 | from .bw_kernel import mlstm_parallel_bw_dKdV_kernel, mlstm_parallel_bw_dQ_kernel 6 | -------------------------------------------------------------------------------- /mlstm_kernels/triton/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/analysis/roofline_analysis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/analysis/roofline_analysis/plot_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
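# Usage sketch (editorial addition): wrap plotting code in the rc context
# defined below so the shared font/line settings apply, then save in all
# three formats via savefig:
#
#   with get_plot_mpl_context():
#       fig, ax = plt.subplots(figsize=FIGSIZE)
#       ax.plot([1, 2, 3], [1, 4, 9])
#   savefig(fig, "example")  # writes ./plots/plot_example.{png,pdf,svg}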
3 | 4 | from pathlib import Path 5 | 6 | import matplotlib as mpl 7 | import matplotlib.pyplot as plt 8 | 9 | fontsize_delta = 2.5 10 | FONTSIZE = 12 + fontsize_delta 11 | SMALL_OFFSET = 1 12 | FONTSIZE_SMALL = FONTSIZE - SMALL_OFFSET 13 | FONTSIZE_TICKS = FONTSIZE_SMALL 14 | 15 | MARKERSIZE = 6.0 16 | LINEWIDTH = 2.0 # default 1.5 17 | 18 | FIGSIZE = (2 * 12 * 1 / 2.54, 2 * 8 * 1 / 2.54) 19 | FIGSIZE_2COL = (4 * 0.7 * 12 * 1 / 2.54, 2 * 0.7 * 8 * 1 / 2.54) 20 | 21 | GRIDSPEC_KWARGS = {"wspace": 0.115, "hspace": 0} 22 | 23 | 24 | def get_plot_mpl_context(): 25 | return mpl.rc_context( 26 | rc={ 27 | "text.usetex": False, 28 | "font.size": FONTSIZE, 29 | "axes.labelsize": FONTSIZE, 30 | "legend.fontsize": FONTSIZE_SMALL, 31 | "xtick.labelsize": FONTSIZE_TICKS, 32 | "ytick.labelsize": FONTSIZE_TICKS, 33 | "axes.titlesize": FONTSIZE, 34 | "lines.markersize": MARKERSIZE, 35 | "lines.linewidth": LINEWIDTH, 36 | } 37 | ) 38 | 39 | 40 | def savefig(fig, filename: str): 41 | dir = Path("./plots/") 42 | dir.mkdir(parents=True, exist_ok=True) 43 | 44 | if filename is not None: 45 | for file_ending in ["png", "pdf", "svg"]: 46 | file = Path(f"./plots/plot_{filename}.{file_ending}") 47 | fig.savefig(file, dpi=300, bbox_inches="tight", pad_inches=-0.0020) 48 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/analysis/transfer_behavior/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/analysis/transfer_behavior/_norm_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | """This module contains the normalization functions. 5 | 6 | We do not add scale parameters here as they are initialized to 1.0. 7 | """ 8 | 9 | import torch 10 | 11 | 12 | def rms_normalize( 13 | x: torch.Tensor, eps: float = 1e-6, force_float32_reductions: bool = True 14 | ) -> torch.Tensor: 15 | # x: (B, ..., S,..., D) 16 | # apply rms norm over the last dimension, i.e. D dimension 17 | in_dtype = x.dtype 18 | if force_float32_reductions: 19 | x = x.float() 20 | x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) 21 | return x.to(in_dtype) 22 | 23 | 24 | def layer_normalize( 25 | x: torch.Tensor, eps: float = 1e-6, force_float32_reductions: bool = True 26 | ) -> torch.Tensor: 27 | # x: (B, ..., S,..., D) 28 | # apply layer norm over the last dimension, i.e. 
D dimension
29 |     in_dtype = x.dtype
30 |     if force_float32_reductions:
31 |         x = x.float()
32 |     x_centered = x - x.mean(dim=-1, keepdim=True)
33 |     y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
34 |     return y.to(in_dtype)
35 | 
36 | 
37 | def no_normalize(x: torch.Tensor, **kwargs) -> torch.Tensor:
38 |     return x
39 | 
40 | 
41 | def apply_normalize(
42 |     norm_specifier: str,
43 |     x: torch.Tensor,
44 |     eps: float = 1e-6,
45 |     force_float32_reductions: bool = True,
46 | ) -> torch.Tensor:
47 |     if norm_specifier == "rms":
48 |         return rms_normalize(
49 |             x=x, eps=eps, force_float32_reductions=force_float32_reductions
50 |         )
51 |     elif norm_specifier == "layer":
52 |         return layer_normalize(
53 |             x=x, eps=eps, force_float32_reductions=force_float32_reductions
54 |         )
55 |     elif norm_specifier == "none":
56 |         return no_normalize(x=x)
57 |     else:
58 |         raise ValueError(f"Unsupported norm specifier {norm_specifier}.")
59 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/benchmark/cuda_graphs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
4 | import logging
5 | from collections.abc import Callable
6 | from typing import Any
7 | 
8 | import torch
9 | 
10 | LOGGER = logging.getLogger(__name__)
11 | 
12 | 
13 | def compile_with_cuda_graphs(
14 |     fn: Callable[..., Any], warmups: int = 3
15 | ) -> torch.cuda.CUDAGraph:
16 |     """
17 |     Compile the provided function with CUDA graphs.
18 | 
19 |     Args:
20 |         fn: The function to compile. Should take no arguments.
21 |         warmups: The number of warmup iterations to run.
22 | 
23 |     Returns:
24 |         The compiled CUDA graph. Can be executed with `graph.replay()`.
25 |     """
26 |     s = torch.cuda.Stream()
27 |     s.wait_stream(torch.cuda.current_stream())
28 |     with torch.cuda.stream(s):
29 |         for idx in range(warmups):
30 |             LOGGER.debug(f"Running CUDA Graph Warmup Step {idx + 1}/{warmups}")
31 |             fn()
32 |         s.synchronize()
33 |     if torch.distributed.is_initialized():
34 |         torch.distributed.barrier()
35 |     torch.cuda.current_stream().wait_stream(s)
36 | 
37 |     LOGGER.debug("Tracing CUDA Graph for benchmark function.")
38 |     graph = torch.cuda.CUDAGraph()
39 |     with torch.cuda.graph(graph):
40 |         fn()
41 |     LOGGER.debug("CUDA Graph traced.")
42 | 
43 |     return graph
44 | 
45 | 
46 | def compile_kwargs_with_cuda_graphs(
47 |     fn: Callable[[Any], Any],
48 |     inputs: dict,
49 |     warmups: int = 3,
50 |     clone_outputs: bool = False,
51 | ) -> tuple[torch.cuda.CUDAGraph, Callable[[Any], Any]]:
52 |     """
53 |     Compile the provided function with CUDA graphs, capturing input/output buffers.
54 |     Args:
55 |         fn: The function to compile. Called with the keyword arguments in `inputs`.
56 |         inputs: Keyword arguments for `fn`; their tensors become the graph's static buffers.
57 |         warmups: The number of warmup iterations to run.
58 |         clone_outputs: If True, the replay function returns clones of the output tensors.
59 |     Returns:
60 |         The compiled CUDA graph and a replay function that reuses the captured buffers.
61 |     """
62 |     # Warmup.
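# (editorial note) CUDA graph capture requires warmed-up kernels: the first
# calls may trigger lazy initialization (memory pools, autotuning, library
# handles) that must not be recorded into the graph, so the warmup iterations
# run on a side stream that is synchronized with the current stream before
# and after.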
63 |     s = torch.cuda.Stream()
64 |     s.wait_stream(torch.cuda.current_stream())
65 |     with torch.cuda.stream(s):
66 |         for idx in range(warmups):
67 |             LOGGER.debug(f"Running CUDA Graph Warmup Step {idx + 1}/{warmups}")
68 |             _ = fn(**inputs)
69 |         s.synchronize()
70 |     if torch.distributed.is_initialized():
71 |         torch.distributed.barrier()
72 |     torch.cuda.current_stream().wait_stream(s)
73 | 
74 |     # Trace the CUDA graph.
75 |     LOGGER.debug("Tracing CUDA Graph for benchmark function.")
76 |     graph = torch.cuda.CUDAGraph()
77 |     with torch.cuda.graph(graph):
78 |         outputs = fn(**inputs)
79 |     LOGGER.debug("CUDA Graph traced.")
80 | 
81 |     # Create a replay function that copies new inputs into the captured buffers.
82 |     def fn_replay(**new_inputs):
83 |         tree_map(
84 |             lambda x, y: x.copy_(y)
85 |             if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor)
86 |             else None,
87 |             inputs,
88 |             new_inputs,
89 |         )
90 |         graph.replay()
91 |         if clone_outputs:
92 |             return tree_map(
93 |                 lambda x: x.clone() if isinstance(x, torch.Tensor) else x, outputs
94 |             )
95 |         else:
96 |             return outputs
97 | 
98 |     return graph, fn_replay
99 | 
100 | 
101 | def tree_map(fn, tree, *rest):
102 |     import jax  # lazy import: only needed for its generic pytree mapping
103 | 
104 |     return jax.tree_util.tree_map(fn, tree, *rest)
105 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/benchmark/plot_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
4 | FONTSIZE = 12
5 | SMALL_OFFSET = 1
6 | FONTSIZE_SMALL = FONTSIZE - SMALL_OFFSET
7 | FONTSIZE_TICKS = 11
8 | fontsize_delta = 0
9 | 
10 | MARKERSIZE = 6.0
11 | LINEWIDTH = 2.0  # default 1.5
12 | 
13 | FIGSIZE = (2 * 12 * 1 / 2.54, 2 * 8 * 1 / 2.54)
14 | FIGSIZE_2COL = (4 * 0.7 * 12 * 1 / 2.54, 2 * 0.7 * 8 * 1 / 2.54)
15 | 
16 | GRIDSPEC_KWARGS = {"wspace": 0.115, "hspace": 0}
17 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/benchmark/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
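# Usage sketch (editorial addition, names hypothetical): setup_output_folder
# (defined below) creates a timestamped output directory and configures
# file + stdout logging:
#
#   from mlstm_kernels.utils.benchmark.utils import setup_output_folder
#   out_dir = setup_output_folder(name_suffix="my_benchmark")
#   (out_dir / "results.json").write_text("{}")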
3 | 4 | from pathlib import Path 5 | 6 | 7 | def setup_output_folder( 8 | output_dir: str = "./outputs_kernel_benchmarks", 9 | name_suffix: str | None = None, 10 | log_level: int | str | None = None, 11 | ) -> Path: 12 | import logging 13 | import sys 14 | from datetime import datetime 15 | 16 | timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 17 | 18 | output_folder_name = timestamp 19 | if name_suffix is not None: 20 | output_folder_name += f"__{name_suffix}" 21 | 22 | output_folder = Path(output_dir) / output_folder_name 23 | 24 | output_folder.mkdir(parents=True, exist_ok=False) 25 | 26 | logfile = output_folder / "benchmark.log" 27 | file_handler = logging.FileHandler(filename=logfile) 28 | stdout_handler = logging.StreamHandler(sys.stdout) 29 | if log_level is None: 30 | log_level = logging.INFO 31 | logging.basicConfig( 32 | handlers=[file_handler, stdout_handler], 33 | format="[%(asctime)s][%(name)s:%(lineno)d][%(levelname)s] - %(message)s", 34 | level=log_level, 35 | force=True, 36 | ) 37 | LOGGER = logging.getLogger(__name__) 38 | LOGGER.info(f"Logging to {logfile}") 39 | return output_folder 40 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/flops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/flops/model_flops_computation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
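# Worked example (editorial addition, numbers hypothetical): with the
# embedding and logits FLOPs enabled, compute_total_model_flops (below)
# evaluates
#   (total_fw_block_flops * num_blocks + 2 * (2 * S * vocab_size * embedding_dim))
#   * batch_size * backward_flop_factor * num_train_steps
# e.g. for a 24-block model with 1e9 forward FLOPs per block:
#
#   flops = compute_total_model_flops(
#       total_fw_block_flops=10**9, batch_size=8, num_blocks=24,
#       vocab_size=50304, embedding_dim=2048, sequence_length=2048,
#   )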
3 | 
4 | 
5 | def compute_total_model_flops(
6 |     total_fw_block_flops: int | None = None,
7 |     batch_size: int | None = None,
8 |     num_blocks: int | None = None,
9 |     vocab_size: int | None = None,
10 |     embedding_dim: int | None = None,
11 |     sequence_length: int | None = None,
12 |     num_train_steps: int = 1,
13 |     include_embedding_flops: bool = True,
14 |     include_logits_flops: bool = True,
15 |     backward_flop_factor: float = 2.0,
16 |     total_fw_block2_flops: int | None = None,
17 |     num_blocks2: int | None = None,
18 | ) -> int:
19 |     total_flops = 0
20 | 
21 |     total_block_flops = total_fw_block_flops * num_blocks
22 |     if total_fw_block2_flops is not None:
23 |         assert (
24 |             num_blocks2 is not None
25 |         ), "num_blocks2 must be provided if total_fw_block2_flops is provided"
26 |         total_block_flops += total_fw_block2_flops * num_blocks2
27 | 
28 |     total_flops += total_block_flops
29 |     if include_embedding_flops:
30 |         embedding_flops = 2 * sequence_length * vocab_size * embedding_dim
31 |         total_flops += embedding_flops
32 |     if include_logits_flops:
33 |         logit_flops = 2 * sequence_length * vocab_size * embedding_dim
34 |         total_flops += logit_flops
35 | 
36 |     total_flops = int(total_flops * batch_size * backward_flop_factor * num_train_steps)
37 | 
38 |     return total_flops
39 | 
40 | 
41 | def compute_total_model_flops_for_block_flops_dict(
42 |     block_flops_dict: dict[str, tuple[int, ...]],
43 |     batch_size: int | None = None,
44 |     num_blocks: int | None = None,
45 |     vocab_size: int | None = None,
46 |     embedding_dim: int | None = None,
47 |     sequence_length: int | None = None,
48 |     num_train_steps: int | None = None,
49 |     sizes_dict: dict[str, dict[str, int | float]] = {},
50 |     include_embedding_flops: bool = True,
51 |     include_logits_flops: bool = True,
52 |     backward_flop_factor: float = 2.0,
53 |     block2_flops_dict: dict[str, tuple[int, ...]] | None = None,
54 |     num_blocks2: int | None = None,
55 | ) -> dict[str, tuple[int]]:
56 |     total_flops_dict = {}
57 |     for model_size in sizes_dict.keys():
58 |         block_flops = block_flops_dict[model_size]
59 |         total_fw_block_flops = block_flops[0]
60 |         size_dict = sizes_dict.get(model_size, {})
61 |         if batch_size is not None:
62 |             size_dict["batch_size"] = batch_size
63 |         if num_blocks is not None:
64 |             size_dict["num_blocks"] = num_blocks
65 |         if vocab_size is not None:
66 |             size_dict["vocab_size"] = vocab_size
67 |         if embedding_dim is not None:
68 |             size_dict["embedding_dim"] = embedding_dim
69 |         if sequence_length is not None:
70 |             size_dict["sequence_length"] = sequence_length
71 |         if num_train_steps is not None:
72 |             size_dict["num_train_steps"] = num_train_steps
73 | 
74 |         if block2_flops_dict is not None:
75 |             block_flops2 = block2_flops_dict[model_size]
76 |             total_fw_block2_flops = block_flops2[0]
77 |         else:
78 |             total_fw_block2_flops = None
79 | 
80 |         if num_blocks2 is not None:
81 |             size_dict["num_blocks2"] = num_blocks2
82 | 
83 |         total_flops = compute_total_model_flops(
84 |             total_fw_block_flops=total_fw_block_flops,
85 |             total_fw_block2_flops=total_fw_block2_flops,
86 |             include_embedding_flops=include_embedding_flops,
87 |             include_logits_flops=include_logits_flops,
88 |             backward_flop_factor=backward_flop_factor,
89 |             **size_dict,
90 |         )
91 |         total_flops_dict[model_size] = (total_flops,)
92 | 
93 |     return total_flops_dict
94 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/flops/slstm_block_flop_counts.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
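# Usage sketch (editorial addition): forward FLOPs of one sLSTM block at
# sequence length 2048 for the 1.3B config (d=2048, Nh=4):
#
#   total, linear, cell = count_flops_slstm_block_fw(S=2048, d=2048, Nh=4)
#   # or for all predefined model sizes:
#   flops_dict = get_slstm_fw_flop_dict(sequence_length=2048)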
3 | 4 | """Contains the sLSTM block flop counts.""" 5 | 6 | from collections.abc import Callable 7 | 8 | 9 | def count_flops_slstm_cell_fw(S, d, Nh, factor_exp=1, factor_div=1): 10 | """Counts the number of flops in the forward pass of an sLSTM cell.""" 11 | assert d % Nh == 0, f"d must be divisible by Nh, got d={d} and Nh={Nh}" 12 | dh = d // Nh 13 | 14 | return S * Nh * dh * (8 * dh + 5 * factor_exp + 4 + 5 + 1 * factor_div) 15 | 16 | 17 | def _count_ln_flops(d): 18 | # not sure about this 19 | # d for mean, d to subtract mean, d for variance, d for division 20 | # return 4 * d 21 | 22 | # do not count ln flops 23 | return 0 24 | 25 | 26 | def count_flops_slstm_block_fw( 27 | S, 28 | d, 29 | Nh, 30 | conv1d_kernel_size=4, 31 | pf_ffn=1.3, 32 | factor_exp=1, 33 | count_ln_flops: Callable[[int], int] = _count_ln_flops, 34 | ): 35 | slstm_cell_flops = count_flops_slstm_cell_fw(S=S, d=d, Nh=Nh) 36 | dh = d // Nh 37 | conv1d_flops = ( 38 | 2 * conv1d_kernel_size * (S + conv1d_kernel_size - 1) * dh * Nh + S * dh * Nh 39 | ) 40 | ffn_flops = 4 * S * d * d * pf_ffn + S * d * factor_exp 41 | 42 | skip_ln_flops = 2 * S * d + (2 + 1) * S * count_ln_flops( 43 | d 44 | ) # 2 block pre-norm, 1 group norm 45 | 46 | total_flops = int(slstm_cell_flops + conv1d_flops + ffn_flops + skip_ln_flops) 47 | linear_layer_flops = int(ffn_flops) 48 | return total_flops, linear_layer_flops, int(slstm_cell_flops) 49 | 50 | 51 | def get_slstm_fw_flop_dict(sequence_length: int) -> dict[str, tuple[int, int, int]]: 52 | slstm_size_dict = { 53 | "125M": dict(d=768, Nh=4), 54 | "350M": dict(d=1024, Nh=4), 55 | "760M": dict(d=1536, Nh=4), 56 | "1.3B": dict(d=2048, Nh=4), 57 | } 58 | 59 | flops_dict = { 60 | k: count_flops_slstm_block_fw(sequence_length, **v) 61 | for k, v in slstm_size_dict.items() 62 | } 63 | return flops_dict 64 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/flops/transformer_block_flop_counts.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
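# Usage sketch (editorial addition):
#
#   total, ff, attn = count_flops_transformer_block_fw(S=2048, d=2048, Nh=16)
#   flops_dict = get_transformer_fw_flop_dict(sequence_length=2048)  # per model size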
3 | 
4 | 
5 | def count_flops_transformer_block_fw(
6 |     S: int, d: int, Nh: int, ff_ratio: float = 4.0
7 | ) -> tuple[int, int, int]:
8 |     """DeepMind method for forward pass FLOPs counting of decoder-only Transformers.
9 |     See the Chinchilla paper or this blog post:
10 |     https://www.adamcasson.com/posts/transformer-flops
11 |     """
12 |     d_attn = d // Nh
13 |     d_ff = d * ff_ratio
14 | 
15 |     attn_qkv = 2 * S * 3 * d * (d_attn * Nh)
16 |     attn_logits = 2 * S * S * (d_attn * Nh)
17 |     attn_softmax = 3 * Nh * S * S
18 |     attn_reduce = 2 * S * S * (d_attn * Nh)
19 |     attn_project = 2 * S * (d_attn * Nh) * d
20 |     total_attn = attn_qkv + attn_logits + attn_softmax + attn_reduce + attn_project
21 | 
22 |     ff = 2 * S * (d * d_ff + d * d_ff)
23 | 
24 |     total_flops = total_attn + ff
25 | 
26 |     return int(total_flops), int(ff), int(total_attn)
27 | 
28 | 
29 | def get_transformer_fw_flop_dict(
30 |     sequence_length: int,
31 | ) -> dict[str, tuple[int, int, int]]:
32 |     """Returns a dictionary with the FLOPs of the transformer block for different hidden sizes."""
33 |     transformer_size_dict = {
34 |         "125M": dict(d=768, Nh=12),
35 |         "350M": dict(d=1024, Nh=16),
36 |         "760M": dict(d=1536, Nh=16),
37 |         "1.3B": dict(d=2048, Nh=16),
38 |         "2.7B": dict(d=2560, Nh=32),
39 |         "7B": dict(d=4096, Nh=32),
40 |     }
41 | 
42 |     flops_dict = {
43 |         k: count_flops_transformer_block_fw(sequence_length, **v)
44 |         for k, v in transformer_size_dict.items()
45 |     }
46 |     return flops_dict
47 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/kernels.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
4 | 
5 | def is_power_of_2(n):
6 |     assert isinstance(n, int)
7 |     return n > 0 and (n & (n - 1)) == 0
8 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/plot/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
4 | from .diff_imshow import plot_numerical_diffs_per_batchhead, plot_numerical_diffs_single
5 | from .diff_lineplot import (
6 |     compute_errors_per_batchhead,
7 |     plot_error_statistics_over_time_per_batchhead,
8 | )
9 | 
--------------------------------------------------------------------------------
/mlstm_kernels/utils/plot/diff_lineplot.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
4 | import numpy as np
5 | from matplotlib import pyplot as plt
6 | 
7 | from .ewma import ewma_vectorized
8 | 
9 | 
10 | def compute_errors_per_batchhead(
11 |     baseline: np.ndarray,  # (B, NH, S, ...)
12 |     target: np.ndarray,  # (B, NH, S, ...)
13 | ) -> np.ndarray: # (B * NH, S, F) F are flattened features 14 | # compute the difference in float64 to avoid numerical issues 15 | error = np.abs(baseline.astype(np.float64) - target.astype(np.float64)) 16 | all_timesteps_np = error 17 | 18 | B, NH, S = error.shape[:3] 19 | # reshape to target shape 20 | # flatten B, NH 21 | all_timesteps_np = all_timesteps_np.reshape(-1, *all_timesteps_np.shape[2:]) 22 | 23 | # flatten features 24 | all_timesteps_np = all_timesteps_np.reshape(B * NH, S, -1) 25 | 26 | return all_timesteps_np 27 | 28 | 29 | def plot_error_statistics_over_time_single( 30 | errors: np.ndarray, # shape: (num_timesteps, num_features) 31 | percentiles: list = [50, 90, 100], 32 | title: str = "", 33 | add_mean: bool = False, 34 | ema_alpha: float = 0.02, 35 | figsize=(10, 6), 36 | ): 37 | assert ( 38 | len(errors.shape) == 2 39 | ), "errors must have shape (num_timesteps, num_features)" 40 | title = f"{title}--ema{ema_alpha}" 41 | 42 | # compute percentiles 43 | percentiles_sequence_data = np.percentile(errors, percentiles, axis=-1) 44 | 45 | # plot 46 | fig, ax = plt.subplots(figsize=figsize) 47 | 48 | for i, p in enumerate(percentiles): 49 | ema_percentile_data = ewma_vectorized( 50 | percentiles_sequence_data[i], alpha=ema_alpha 51 | ) 52 | 53 | ax.plot(ema_percentile_data, label=f"{percentiles[i]}th percentile") 54 | 55 | if add_mean: 56 | ema_mean_data = ewma_vectorized(np.mean(errors, axis=-1), alpha=ema_alpha) 57 | ax.plot(ema_mean_data, label="mean") 58 | 59 | ax.set_title(title) 60 | ax.set_xlabel("timestep") 61 | ax.set_ylabel("error") 62 | ax.legend() 63 | 64 | ax.grid(alpha=0.5) 65 | ax.spines["top"].set_visible(False) 66 | ax.spines["right"].set_visible(False) 67 | return fig 68 | 69 | 70 | def plot_error_statistics_over_time_per_batchhead( 71 | errors: np.ndarray, # shape: (num_batchheads, num_timesteps, num_features) 72 | percentiles: list = [50, 90, 100], 73 | title: str = "", 74 | add_mean: bool = False, 75 | ema_alpha: float = 0.02, 76 | max_num_batchhead_plots: int = -1, # -1 means all 77 | figsize=(10, 6), 78 | ): 79 | num_batchheads = errors.shape[0] 80 | if max_num_batchhead_plots > 0: 81 | max_num_batchhead_plots = min(num_batchheads, max_num_batchhead_plots) 82 | else: 83 | max_num_batchhead_plots = num_batchheads 84 | 85 | figs = [] 86 | for i in range(max_num_batchhead_plots): 87 | title_i = f"BH({i}):{title}" 88 | fig = plot_error_statistics_over_time_single( 89 | errors[i], 90 | percentiles, 91 | title=title_i, 92 | add_mean=add_mean, 93 | ema_alpha=ema_alpha, 94 | figsize=figsize, 95 | ) 96 | figs.append(fig) 97 | 98 | return figs 99 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/plot/ewma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | """ 4 | This files contains exponential moving average (EMA) functions in numpy. 5 | Obtained from https://stackoverflow.com/questions/42869495/numpy-version-of-exponential-weighted-moving-average-equivalent-to-pandas-ewm 6 | """ 7 | 8 | import numpy as np 9 | 10 | 11 | def ewma(x, alpha): 12 | """ 13 | Returns the exponentially weighted moving average of x. 
14 | 15 | Parameters: 16 | ----------- 17 | x : array-like 18 | alpha : float {0 <= alpha <= 1} 19 | 20 | Returns: 21 | -------- 22 | ewma: numpy array 23 | the exponentially weighted moving average 24 | """ 25 | # Coerce x to an array 26 | x = np.array(x) 27 | n = x.size 28 | 29 | # Create an initial weight matrix of (1-alpha), and a matrix of powers 30 | # to raise the weights by 31 | w0 = np.ones(shape=(n, n)) * (1 - alpha) 32 | p = np.vstack([np.arange(i, i - n, -1) for i in range(n)]) 33 | 34 | # Create the weight matrix 35 | w = np.tril(w0**p, 0) 36 | 37 | # Calculate the ewma 38 | return np.dot(w, x[:: np.newaxis]) / w.sum(axis=1) 39 | 40 | 41 | def ewma_vectorized(data, alpha, offset=None, dtype=None, order="C", out=None): 42 | """ 43 | Calculates the exponential moving average over a vector. 44 | Will fail for large inputs. 45 | :param data: Input data 46 | :param alpha: scalar float in range (0,1) 47 | The alpha parameter for the moving average. 48 | :param offset: optional 49 | The offset for the moving average, scalar. Defaults to data[0]. 50 | :param dtype: optional 51 | Data type used for calculations. Defaults to float64 unless 52 | data.dtype is float32, then it will use float32. 53 | :param order: {'C', 'F', 'A'}, optional 54 | Order to use when flattening the data. Defaults to 'C'. 55 | :param out: ndarray, or None, optional 56 | A location into which the result is stored. If provided, it must have 57 | the same shape as the input. If not provided or `None`, 58 | a freshly-allocated array is returned. 59 | """ 60 | data = np.array(data, copy=False) 61 | 62 | if dtype is None: 63 | if data.dtype == np.float32: 64 | dtype = np.float32 65 | else: 66 | dtype = np.float64 67 | else: 68 | dtype = np.dtype(dtype) 69 | 70 | if data.ndim > 1: 71 | # flatten input 72 | data = data.reshape(-1, order) 73 | 74 | if out is None: 75 | out = np.empty_like(data, dtype=dtype) 76 | else: 77 | assert out.shape == data.shape 78 | assert out.dtype == dtype 79 | 80 | if data.size < 1: 81 | # empty input, return empty array 82 | return out 83 | 84 | if offset is None: 85 | offset = data[0] 86 | 87 | alpha = np.asarray(alpha, dtype=dtype) 88 | 89 | # scaling_factors -> 0 as len(data) gets large 90 | # this leads to divide-by-zeros below 91 | scaling_factors = np.power( 92 | 1.0 - alpha, np.arange(data.size + 1, dtype=dtype), dtype=dtype 93 | ) 94 | # create cumulative sum array 95 | np.multiply( 96 | data, (alpha * scaling_factors[-2]) / scaling_factors[:-1], dtype=dtype, out=out 97 | ) 98 | np.cumsum(out, dtype=dtype, out=out) 99 | 100 | # cumsums / scaling 101 | out /= scaling_factors[-2::-1] 102 | 103 | if offset != 0: 104 | offset = np.asarray(offset, dtype=dtype) 105 | # add offsets 106 | out += offset * scaling_factors[1:] 107 | 108 | return out 109 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/test/fixtures.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
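# Usage sketch (editorial addition, test name hypothetical): the fixtures
# below can be requested by name in any test, e.g.
#
#   def test_kernel_output_dump(test_session_folder):
#       (test_session_folder / "dump.txt").write_text("ok")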
3 | 
4 | import logging
5 | import sys
6 | from datetime import datetime
7 | from pathlib import Path
8 | 
9 | import pytest
10 | 
11 | # We declare this here globally to enforce that there is only one timestamp per test session.
12 | TIMESTAMP = None
13 | 
14 | TEST_OUTPUT_FOLDER = Path(__file__).parents[3] / "outputs_tests"
15 | 
16 | 
17 | @pytest.fixture(scope="session")
18 | def test_session_folder() -> Path:
19 |     global TIMESTAMP
20 |     if TIMESTAMP is None:
21 |         TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
22 | 
23 |     timestamp = TIMESTAMP
24 | 
25 |     test_output_folder = TEST_OUTPUT_FOLDER / timestamp
26 | 
27 |     test_output_folder.mkdir(parents=True, exist_ok=True)
28 | 
29 |     logfile = test_output_folder / "pytest.log"
30 |     file_handler = logging.FileHandler(filename=logfile)
31 |     stdout_handler = logging.StreamHandler(sys.stdout)
32 |     logging.basicConfig(
33 |         handlers=[file_handler, stdout_handler],
34 |         format="%(asctime)s %(levelname)s %(message)s",
35 |         level=logging.INFO,
36 |         force=True,
37 |     )
38 |     LOGGER = logging.getLogger(__name__)
39 |     LOGGER.info(f"Logging to {logfile}")
40 |     return test_output_folder
41 | 
42 | 
43 | @pytest.fixture
44 | def test_output_folder() -> Path:
45 |     return TEST_OUTPUT_FOLDER / "test_data"
--------------------------------------------------------------------------------
/mlstm_kernels/utils/test/test_fwbw.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) NXAI GmbH.
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3 | 
4 | import torch
5 | 
6 | 
7 | def test_forward(
8 |     f1,
9 |     f2,
10 |     inputs: tuple[torch.Tensor, ...],
11 |     comp_func=torch.allclose,
12 |     comp_func_kwargs={},
13 |     show_diff_func=None,
14 | ):
15 |     out1 = f1(*inputs)
16 |     out2 = f2(*inputs)
17 | 
18 |     if torch.any(torch.isnan(out1)):
19 |         print("NaN found in output 1")
20 | 
21 |     if torch.any(torch.isnan(out2)):
22 |         print("NaN found in output 2")
23 | 
24 |     if out1.shape != out2.shape:
25 |         print("Bad output shape")
26 | 
27 |     if not comp_func(out1, out2, **comp_func_kwargs):
28 |         print("Difference")
29 |         if show_diff_func is not None:
30 |             show_diff_func(out1, out2)
31 | 
32 | 
33 | def test_backward(
34 |     f1,
35 |     f2,
36 |     inputs: tuple[torch.Tensor, ...],
37 |     mask: torch.Tensor | None = None,
38 |     comp_func=torch.allclose,
39 |     comp_func_kwargs={},
40 |     show_diff_func=None,
41 | ):
42 |     inputs1 = [inp.clone().detach() if inp is not None else None for inp in inputs]
43 |     inputs2 = [inp.clone().detach() if inp is not None else None for inp in inputs]
44 |     for inp in inputs1:
45 |         if inp is not None:
46 |             inp.requires_grad_(True)
47 |     for inp in inputs2:
48 |         if inp is not None:
49 |             inp.requires_grad_(True)
50 | 
51 |     out1 = f1(*inputs1)
52 |     out2 = f2(*inputs2)
53 | 
54 |     if mask is None:
55 |         mask = torch.randn_like(out1)
56 |     mask1 = mask.clone().detach()
57 |     mask2 = mask.clone().detach()
58 | 
59 |     l1 = (out1 * mask1).sum()
60 |     l1.backward()
61 |     l2 = (out2 * mask2).sum()
62 |     l2.backward()
63 | 
64 |     for n, (inp1, inp2) in enumerate(zip(inputs1, inputs2)):
65 |         if inp1 is None and inp2 is None:
66 |             continue
67 |         if inp1.grad is None and inp2.grad is None:
68 |             continue
69 |         if inp1.grad is not None and torch.any(torch.isnan(inp1.grad)):
70 |             print(f"NaN found in grad {n} of inp1")
71 |         if inp2.grad is not None and torch.any(torch.isnan(inp2.grad)):
72 |             print(f"NaN found in grad {n} of inp2")
73 |         # if n == 4:
74 |         #     print(inp1.grad, inp2.grad)
75 |         if not comp_func(inp1.grad, inp2.grad,
**comp_func_kwargs): 76 | print(f"Difference in {n}-th gradient") 77 | if show_diff_func is not None: 78 | show_diff_func(inp1.grad, inp2.grad) 79 | -------------------------------------------------------------------------------- /mlstm_kernels/utils/test/test_templates/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /notebooks/flop_counting/mlstm_cell_vs_slstm_cell_flops.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "import sys\n", 21 | "\n", 22 | "sys.path.append(\"../..\")\n", 23 | "from mlstm_kernels.flops_utils.mlstm_block_flop_counts import (\n", 24 | " count_flops_mlstm_v2_block_fw,\n", 25 | ")\n", 26 | "from mlstm_kernels.flops_utils.slstm_block_flop_counts import count_flops_slstm_block_fw" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 11, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "S = 2048\n", 36 | "d = 2048\n", 37 | "Nh = 4\n", 38 | "\n", 39 | "dqk = d // Nh\n", 40 | "dv = d // Nh" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 12, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "mlstm_block_flops = count_flops_mlstm_v2_block_fw(S=S, d=d, Nh=Nh, dqk=dqk, dv=dv)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 14, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "slstm_block_flops = count_flops_slstm_block_fw(S=S, d=d, Nh=Nh)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 15, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "0.5475116631333724\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "# slstm / mlstm total\n", 76 | "print(slstm_block_flops[0] / mlstm_block_flops[0])" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 16, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "0.43332112778813753\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "# slstm / mlstm linear\n", 94 | "print(slstm_block_flops[1] / mlstm_block_flops[1])" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 17, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "1.7461187947593693\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "# slstm / mlstm cell\n", 112 | "print(slstm_block_flops[2] / mlstm_block_flops[2])" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "xlstmpt240cu124", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | 
}, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.11.9" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 2 144 | } 145 | -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/avg_acc_ruler_abl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/avg_acc_ruler_abl.pkl -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/avg_acc_ruler_main.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/avg_acc_ruler_main.pkl -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/dumps/scheduler/loss_df.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/dumps/scheduler/loss_df.pkl -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/dumps/scheduler/lr_df.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/dumps/scheduler/lr_df.pkl -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/dumps/scheduler/ppl_df.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/dumps/scheduler/ppl_df.pkl -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/gen_time_mem_data.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/gen_time_mem_data.p -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/gen_time_mem_data_vllm.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/gen_time_mem_data_vllm.p -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/plot_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from pathlib import Path 5 | 6 | import matplotlib as mpl 7 | 8 | model_colors = { 9 | "mlstm_simple": mpl.colormaps["tab10"].colors[0], 10 | "xlstm": mpl.colormaps["tab10"].colors[1], 11 | "llama2": mpl.colormaps["tab10"].colors[2], 12 | "llama3": mpl.colormaps["tab10"].colors[3], 13 | "ministral8b": mpl.colormaps["tab10"].colors[5], 14 | "codestral_mamba": mpl.colormaps["tab10"].colors[6], 15 | "falcon_mamba": mpl.colormaps["tab10"].colors[4], 16 | "zamba2": mpl.colormaps["tab10"].colors[7], 17 | } 18 | 19 | xlstm_colors = { 20 | "llama3": "#165b89ff", 21 | "llama2": "#80a8b3ff", 22 | # "xLSTM": "#cc4391ff", 23 | "xlstm": "#861657ff", 24 | "codestral_mamba": "#d08814ff", 25 | "falcon_mamba": "#ffd449ff", 26 | "RWKV4": "#145815ff", 27 | } 28 | 29 | model_labels = { 30 | "mlstm_simple": "mLSTM simple", 31 | "xlstm": "xLSTM 7B", 32 | "llama2": "Llama 2 7B", 33 | "llama3": "Llama 3 8B", 34 | "ministral8b": "Ministral8B", 35 | "codestral_mamba": "CodestralMamba (Mamba 2) 7B", 36 | "falcon_mamba": "FalconMamba (Mamba 1) 7B", 37 | "zamba2": "Zamba2", 38 | } 39 | 40 | linestyle_mapping = { 41 | "__tcm__": {"linestyle": "--", "label": ""}, 42 | } 43 | 44 | style_dict = { 45 | "mlstm_simple": { 46 | "color": model_colors["mlstm_simple"], 47 | "label": model_labels["mlstm_simple"], 48 | }, 49 | "xlstm": {"color": xlstm_colors["xlstm"], "label": model_labels["xlstm"]}, 50 | "llama2": {"color": xlstm_colors["llama2"], "label": model_labels["llama2"]}, 51 | "llama3": {"color": xlstm_colors["llama3"], "label": model_labels["llama3"]}, 52 | "ministral8b": { 53 | "color": model_colors["ministral8b"], 54 | "label": model_labels["ministral8b"], 55 | }, 56 | "codestral_mamba": { 57 | "color": xlstm_colors["codestral_mamba"], 58 | "label": model_labels["codestral_mamba"], 59 | }, 60 | "falcon_mamba": { 61 | "color": xlstm_colors["falcon_mamba"], 62 | "label": model_labels["falcon_mamba"], 63 | }, 64 | "zamba2": {"color": model_colors["zamba2"], "label": model_labels["zamba2"]}, 65 | } 66 | 67 | 68 | def savefig(fig, filename: str): 69 | dir = Path("./plots/") 70 | dir.mkdir(parents=True, exist_ok=True) 71 | 72 | if filename is not None: 73 | for file_ending in ["png", "pdf", "svg"]: 74 | file = Path(f"./plots/plot_{filename}.{file_ending}") 75 | fig.savefig(file, dpi=300, bbox_inches="tight", pad_inches=-0.0020) 76 | -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/plot_config_for_paper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import matplotlib as mpl 5 | 6 | FONTSIZE = 12 7 | SMALL_OFFSET = 1 8 | FONTSIZE_SMALL = FONTSIZE - SMALL_OFFSET 9 | FONTSIZE_TICKS = FONTSIZE_SMALL 10 | fontsize_delta = 0 11 | 12 | MARKERSIZE = 6.0 13 | LINEWIDTH = 2.0 # default 1.5 14 | 15 | nice_look_term = 0 16 | half_of_icml_width = ( 17 | 6.75 / 2 - 0.25 / 2 - 1 18 | ) # ICML width of page: 6.75 inches + 0.25 inches between columns 19 | factor = 5.5 20 | desired_aspect_ratio = 2 21 | FIGSIZE_2COL = ( 22 | factor * half_of_icml_width, 23 | factor * half_of_icml_width / desired_aspect_ratio, 24 | ) 25 | one_col_fig_size_factor = 1 26 | FIGSIZE = ( 27 | one_col_fig_size_factor * (FIGSIZE_2COL[0]), 28 | one_col_fig_size_factor * (FIGSIZE_2COL[1] * 2), 29 | ) 30 | 31 | GRIDSPEC_KWARGS = {"wspace": 0.115, "hspace": 0} 32 | 33 | model_colors = { 34 | "mlstm_simple": mpl.colormaps["tab10"].colors[0], 35 | "xlstm": mpl.colormaps["tab10"].colors[1], 36 | "llama2": mpl.colormaps["tab10"].colors[2], 37 | "llama3": mpl.colormaps["tab10"].colors[3], 38 | "ministral8b": mpl.colormaps["tab10"].colors[5], 39 | "codestral_mamba": mpl.colormaps["tab10"].colors[6], 40 | "falcon_mamba": mpl.colormaps["tab10"].colors[4], 41 | "zamba2": mpl.colormaps["tab10"].colors[7], 42 | } 43 | 44 | xlstm_colors = { 45 | "llama3": "#165b89ff", 46 | "llama2": "#80a8b3ff", 47 | # "xLSTM": "#cc4391ff", 48 | "xlstm": "#861657ff", 49 | "codestral_mamba": "#d08814ff", 50 | "falcon_mamba": "#ffd449ff", 51 | "RWKV4": "#145815ff", 52 | } 53 | 54 | model_labels = { 55 | "mlstm_simple": "mLSTM simple", 56 | "xlstm": "xLSTM 7B", 57 | "llama2": "Llama 2 7B", 58 | "llama3": "Llama 3 8B", 59 | "ministral8b": "Ministral8B", 60 | "codestral_mamba": "CodestralMamba 7B", 61 | "falcon_mamba": "FalconMamba 7B", 62 | "zamba2": "Zamba2", 63 | } 64 | 65 | linestyle_mapping = { 66 | "__tcm__": {"linestyle": "--", "label": ""}, 67 | } 68 | 69 | style_dict = { 70 | "mlstm_simple": { 71 | "color": model_colors["mlstm_simple"], 72 | "label": model_labels["mlstm_simple"], 73 | }, 74 | "xlstm": {"color": xlstm_colors["xlstm"], "label": model_labels["xlstm"]}, 75 | "llama2": {"color": xlstm_colors["llama2"], "label": model_labels["llama2"]}, 76 | "llama3": {"color": xlstm_colors["llama3"], "label": model_labels["llama3"]}, 77 | "ministral8b": { 78 | "color": model_colors["ministral8b"], 79 | "label": model_labels["ministral8b"], 80 | }, 81 | "codestral_mamba": { 82 | "color": xlstm_colors["codestral_mamba"], 83 | "label": model_labels["codestral_mamba"], 84 | }, 85 | "falcon_mamba": { 86 | "color": xlstm_colors["falcon_mamba"], 87 | "label": model_labels["falcon_mamba"], 88 | }, 89 | "zamba2": {"color": model_colors["zamba2"], "label": model_labels["zamba2"]}, 90 | } 91 | -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/throughput_data.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/throughput_data.p -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/throughput_df.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/throughput_df.p -------------------------------------------------------------------------------- 
/notebooks/plots_7B_model_benchmark/throughput_vllm_df.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/throughput_vllm_df.p -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/ttft_raw_data.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/ttft_raw_data.p -------------------------------------------------------------------------------- /notebooks/plots_7B_model_benchmark/ttft_raw_data_vllm.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_7B_model_benchmark/ttft_raw_data_vllm.p -------------------------------------------------------------------------------- /notebooks/plots_mlstm_kernel_benchmark/consttoken_results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_mlstm_kernel_benchmark/consttoken_results.pkl -------------------------------------------------------------------------------- /notebooks/plots_mlstm_kernel_benchmark/gen_results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_mlstm_kernel_benchmark/gen_results.pkl -------------------------------------------------------------------------------- /notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_consttoken_benchmark_results.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_consttoken_benchmark_results.p -------------------------------------------------------------------------------- /notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_consttoken_benchmark_results_lightn_attn.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_consttoken_benchmark_results_lightn_attn.p -------------------------------------------------------------------------------- /notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_consttoken_benchmark_results_rerun.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_consttoken_benchmark_results_rerun.p -------------------------------------------------------------------------------- /notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_head_dim_benchmark_results.p: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/plots_mlstm_kernel_benchmark_tfla_paper/mlstm_tfla_paper_head_dim_benchmark_results.p -------------------------------------------------------------------------------- /notebooks/plots_roofline_analysis/plot_arithmetic_intensity_mlstm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from matplotlib.figure import Figure\n", 12 | "from matplotlib.axes import Axes\n", 13 | "\n", 14 | "import sys\n", 15 | "\n", 16 | "sys.path.append(\"../..\")" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "%load_ext autoreload\n", 26 | "%autoreload 2\n", 27 | "from mlstm_kernels.utils.analysis.roofline_analysis.plot_mlstm_arithmetic_intensity import (\n", 28 | " create_mlstm_arithmetic_intensity_plot,\n", 29 | ")\n", 30 | "from mlstm_kernels.utils.analysis.roofline_analysis.plot_config import savefig" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "fig = create_mlstm_arithmetic_intensity_plot()\n", 40 | "fig" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "savefig(fig, \"arithmetic_intensity_over_chunksize_mlstmsig\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [] 58 | } 59 | ], 60 | "metadata": { 61 | "kernelspec": { 62 | "display_name": "mlstmpt251cu124", 63 | "language": "python", 64 | "name": "python3" 65 | }, 66 | "language_info": { 67 | "codemirror_mode": { 68 | "name": "ipython", 69 | "version": 3 70 | }, 71 | "file_extension": ".py", 72 | "mimetype": "text/x-python", 73 | "name": "python", 74 | "nbconvert_exporter": "python", 75 | "pygments_lexer": "ipython3", 76 | "version": "3.11.11" 77 | } 78 | }, 79 | "nbformat": 4, 80 | "nbformat_minor": 2 81 | } 82 | -------------------------------------------------------------------------------- /notebooks/plots_roofline_analysis/plot_flop_comparison_mlstm_formulations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from matplotlib.figure import Figure\n", 12 | "from matplotlib.axes import Axes\n", 13 | "\n", 14 | "import sys\n", 15 | "\n", 16 | "sys.path.append(\"../..\")" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "%load_ext autoreload\n", 26 | "%autoreload 2\n", 27 | "from mlstm_kernels.utils.analysis.roofline_analysis.plot_mlstm_flop_analysis import (\n", 28 | " create_mlstm_flop_formulation_comparison_plot,\n", 29 | ")\n", 30 | "from mlstm_kernels.utils.analysis.roofline_analysis.plot_config import savefig" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "fig = 
create_mlstm_flop_formulation_comparison_plot()\n", 40 | "fig" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "savefig(fig, \"flops_mlstmsig_comparison\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [] 58 | } 59 | ], 60 | "metadata": { 61 | "kernelspec": { 62 | "display_name": "mlstmpt251cu124", 63 | "language": "python", 64 | "name": "python3" 65 | }, 66 | "language_info": { 67 | "codemirror_mode": { 68 | "name": "ipython", 69 | "version": 3 70 | }, 71 | "file_extension": ".py", 72 | "mimetype": "text/x-python", 73 | "name": "python", 74 | "nbconvert_exporter": "python", 75 | "pygments_lexer": "ipython3", 76 | "version": "3.11.11" 77 | } 78 | }, 79 | "nbformat": 4, 80 | "nbformat_minor": 2 81 | } 82 | -------------------------------------------------------------------------------- /notebooks/plots_roofline_analysis/plot_optimal_chunk_size.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "\n", 11 | "sys.path.append(\"../..\")" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%load_ext autoreload\n", 21 | "%autoreload 2\n", 22 | "from mlstm_kernels.utils.analysis.roofline_analysis.plot_mlstm_optimal_chunksize import (\n", 23 | " create_mlstm_flop_optimal_chunksize_plot,\n", 24 | " create_mlstm_runtime_optimal_chunk_size_plot,\n", 25 | " create_mlstm_runtime_optimal_chunk_size_over_acc_intensity_plot,\n", 26 | ")\n", 27 | "from mlstm_kernels.utils.analysis.roofline_analysis.plot_config import savefig" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "# FLOP optimal chunk size" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "fig = create_mlstm_flop_optimal_chunksize_plot()\n", 44 | "fig" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "savefig(fig, \"mlstm_flop_optimal_chunksize\")" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# Runtime optimal chunk size" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "fig = create_mlstm_runtime_optimal_chunk_size_plot()\n", 70 | "fig" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "savefig(fig, \"mlstm_runtime_optimal_chunksize\")" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "# Runtime optimal chunk size over arithmetic intensity" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "fig = create_mlstm_runtime_optimal_chunk_size_over_acc_intensity_plot()\n", 96 | "fig" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "savefig(fig, \"mlstm_runtime_optimal_chunksize_over_acc_intensity\")" 106 | 
] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "mlstmpt251cu124", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.11.11" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 2 144 | } 145 | -------------------------------------------------------------------------------- /notebooks/transfer_behavior_analysis/m_state_explore.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import matplotlib as mpl\n", 11 | "import matplotlib.pyplot as plt" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "S = 32" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 16, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "fg_range = torch.linspace(3, 6, 50)\n", 30 | "full_range = torch.linspace(0, 10, 100)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 17, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "fg_vals = torch.nn.functional.sigmoid(fg_range)\n", 40 | "full_vals = torch.nn.functional.sigmoid(full_range)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# plt.plot(full_range, full_vals)\n", 50 | "plt.plot(fg_range, fg_vals)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 35, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "fg_std = 1.0\n", 60 | "fg_preact_vals = 0.99 + torch.randn([S]) * fg_std" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "fg_act_vals = torch.nn.functional.sigmoid(fg_preact_vals)\n", 70 | "fg_act_vals" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 33, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "log_fg_act_vals = torch.log(fg_act_vals)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "torch.cumsum(log_fg_act_vals, 0)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "mlstmpt251cu124_beck", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | "pygments_lexer": "ipython3", 115 | "version": "3.11.11" 116 | } 117 | }, 118 | "nbformat": 4, 119 | 
"nbformat_minor": 2 120 | } 121 | -------------------------------------------------------------------------------- /notebooks/transfer_behavior_analysis/norm_eps_vs_igbias_raw_data_df.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/notebooks/transfer_behavior_analysis/norm_eps_vs_igbias_raw_data_df.p -------------------------------------------------------------------------------- /notebooks/transfer_behavior_analysis/plot_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from pathlib import Path 5 | 6 | import matplotlib as mpl 7 | import matplotlib.pyplot as plt 8 | 9 | fontsize_delta = 2.5 10 | FONTSIZE = 12 + fontsize_delta 11 | SMALL_OFFSET = 1 12 | FONTSIZE_SMALL = FONTSIZE - SMALL_OFFSET 13 | FONTSIZE_TICKS = FONTSIZE_SMALL 14 | 15 | MARKERSIZE = 6.0 16 | LINEWIDTH = 2.0 # default 1.5 17 | 18 | FIGSIZE = (2 * 12 * 1 / 2.54, 2 * 8 * 1 / 2.54) 19 | FIGSIZE_2COL = (4 * 0.7 * 12 * 1 / 2.54, 2 * 0.7 * 8 * 1 / 2.54) 20 | 21 | GRIDSPEC_KWARGS = {"wspace": 0.115, "hspace": 0} 22 | 23 | 24 | def get_tb_plot_mpl_context(): 25 | return mpl.rc_context( 26 | rc={ 27 | "text.usetex": False, 28 | "font.size": FONTSIZE, 29 | "axes.labelsize": FONTSIZE, 30 | "legend.fontsize": FONTSIZE_SMALL, 31 | "xtick.labelsize": FONTSIZE_TICKS, 32 | "ytick.labelsize": FONTSIZE_TICKS, 33 | "axes.titlesize": FONTSIZE, 34 | "lines.markersize": MARKERSIZE, 35 | "lines.linewidth": LINEWIDTH, 36 | } 37 | ) 38 | 39 | 40 | def savefig(fig, filename: str): 41 | dir = Path("./plots/") 42 | dir.mkdir(parents=True, exist_ok=True) 43 | 44 | if filename is not None: 45 | for file_ending in ["png", "pdf", "svg"]: 46 | file = Path(f"./plots/plot_{filename}.{file_ending}") 47 | fig.savefig(file, dpi=300, bbox_inches="tight", pad_inches=-0.0020) 48 | -------------------------------------------------------------------------------- /notebooks/transfer_behavior_analysis/plot_transfer_behavior_norm_eps_grid--experimental.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "import sys\n", 12 | "\n", 13 | "sys.path.append(\"../..\")\n", 14 | "\n", 15 | "import torch\n", 16 | "import numpy as np\n", 17 | "\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "\n", 20 | "from mlstm_kernels.utils.analysis.transfer_behavior.plot_transfer_behavior import (\n", 21 | " generate_generate_norm_eps_grid_transfer_behavior_plot,\n", 22 | ")\n", 23 | "from plot_config import get_tb_plot_mpl_context, savefig" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "num_points = 50" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "with get_tb_plot_mpl_context():\n", 42 | " fig = generate_generate_norm_eps_grid_transfer_behavior_plot(\n", 43 | " mlstm_func_specifiers=[\n", 44 | " \"tb__mlstmsig--paper\",\n", 45 | " \"tb__mlstmsig--max_sum_abs_1-1.0\",\n", 46 | " \"tb__mlstmsig--max_sum_abs_1-0.001\",\n", 47 | " 
\"tb__mlstmsig--max_sum_abs_1-0.000001\",\n", 48 | " \"tb__mlstmsig--max_sum_abs_1-0.00000001\",\n", 49 | " ],\n", 50 | " norm_epsilons=[1e-2, 1e-6],\n", 51 | " norm_specifier=\"rms\",\n", 52 | " metric_specifier=\"abs_max_mean-v\",\n", 53 | " seq_len=512,\n", 54 | " dhqk=128,\n", 55 | " dhhv=128,\n", 56 | " backend_eps=1e-6,\n", 57 | " qkv_std=(1.0, 1.0, 1.0),\n", 58 | " z_levels=np.linspace(0, 2, 100).tolist(),\n", 59 | " igate_preact_offsets=np.linspace(-12, 8, num_points).tolist(),\n", 60 | " fgate_preact_offsets=np.linspace(-5, 12, num_points).tolist(),\n", 61 | " igate_preact_init_fn=torch.zeros,\n", 62 | " fgate_preact_init_fn=torch.zeros,\n", 63 | " dtype=torch.bfloat16,\n", 64 | " device=torch.device(\"cuda\"),\n", 65 | " colorbar_fraction=0.2,\n", 66 | " fig_height=4,\n", 67 | " )\n", 68 | "fig" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# savefig(fig=fig, filename=\"transfer_behavior_app--mlstm_sig_paper\")" 78 | ] 79 | } 80 | ], 81 | "metadata": { 82 | "kernelspec": { 83 | "display_name": "mlstmpt251cu124", 84 | "language": "python", 85 | "name": "python3" 86 | }, 87 | "language_info": { 88 | "codemirror_mode": { 89 | "name": "ipython", 90 | "version": 3 91 | }, 92 | "file_extension": ".py", 93 | "mimetype": "text/x-python", 94 | "name": "python", 95 | "nbconvert_exporter": "python", 96 | "pygments_lexer": "ipython3", 97 | "version": "3.11.11" 98 | } 99 | }, 100 | "nbformat": 4, 101 | "nbformat_minor": 2 102 | } 103 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "mlstm_kernels" 7 | version = "2.0.0" 8 | authors = [ 9 | { name="Maximilian Beck", email="beck@ml.jku.at" }, 10 | { name="Korbinian Poeppel", email="poeppel@ml.jku.at" }, 11 | { name="Phillip Lippe", email="phillip.lippe@gmail.com" }, 12 | { name="Sebastian Boeck", email="sebastian.boeck@nx-ai.com" }, 13 | ] 14 | description = "A library providing fast and efficient mLSTM kernels for the xLSTM." 15 | readme = "README.md" 16 | license = {file="LICENSE"} 17 | requires-python = ">=3.11" 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "Operating System :: OS Independent", 21 | ] 22 | keywords = ["mLSTM", "xLSTM", "LSTM", "Transformer", "Machine Learning", "Deep Learning", "State Space Models"] 23 | dependencies = [ 24 | "dacite", 25 | "einops", 26 | "ipykernel", 27 | "matplotlib", 28 | "numpy", 29 | "omegaconf", 30 | "rich", 31 | "torch", 32 | "tqdm", 33 | ] 34 | 35 | # [tool.setuptools] 36 | # include_package_data = true 37 | 38 | [tool.setuptools.package-data] 39 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | pythonpath = . 
3 | log_cli = true 4 | log_cli_level = INFO 5 | log_format = %(asctime)s %(levelname)s %(message)s 6 | log_date_format = %Y-%m-%d %H:%M:%S 7 | -------------------------------------------------------------------------------- /res/Figure_1-7.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NX-AI/mlstm_kernels/0f172a0b88ce2b8c54154072afd58c842975df4b/res/Figure_1-7.pdf -------------------------------------------------------------------------------- /scripts/run_training_kernel_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # # Execute the first Python script 4 | # COMMON_ARGS="--folder_suffix rerun_v0" 5 | 6 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mamba $COMMON_ARGS 7 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton $COMMON_ARGS 8 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark fla $COMMON_ARGS 9 | 10 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton $COMMON_ARGS --num_heads 16 11 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark fla $COMMON_ARGS --num_heads 16 12 | 13 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton $COMMON_ARGS --num_heads 32 14 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark fla $COMMON_ARGS --num_heads 32 15 | 16 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton $COMMON_ARGS --num_heads 64 17 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark fla $COMMON_ARGS --num_heads 64 18 | 19 | # Comparison to lightning attention 2 20 | 21 | COMMON_ARGS="--folder_suffix lightnattn_v1" 22 | 23 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark lightning_attn2 $COMMON_ARGS --num_heads 32 24 | # python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark lightning_attn2 $COMMON_ARGS --num_heads 64 25 | 26 | python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton $COMMON_ARGS --num_heads 32 --half_qkdim 0 27 | python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton $COMMON_ARGS --num_heads 64 --half_qkdim 0 28 | 29 | 30 | # # Check if the first script ran successfully 31 | # if [ $? -ne 0 ]; then 32 | # echo "The first script encountered an error. Exiting." 33 | # exit 1 34 | # fi 35 | 36 | 37 | # # Check if the second script ran successfully 38 | # if [ $? -ne 0 ]; then 39 | # echo "The second script encountered an error. Exiting." 40 | # exit 1 41 | # fi 42 | 43 | # echo "Both scripts executed successfully." 44 | -------------------------------------------------------------------------------- /scripts/run_training_kernel_benchmarks_with_profile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
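# This script wraps a single training-kernel benchmark (mlstm_chunkwise__xl_chunk,
# B=8, S=8192, configured via the inline YAML in _benchmark_to_profile) in
# torch.profiler, writes a TensorBoard trace to <output_folder>/tensorboard, and
# prints a summary table sorted by cuda_time_total. Viewing the trace with
# tensorboard --logdir <output_folder>/tensorboard additionally requires the
# torch-tb-profiler plugin, which is not among the project dependencies.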
3 | 4 | from functools import partial 5 | from pathlib import Path 6 | 7 | import torch 8 | from dacite import from_dict 9 | from omegaconf import OmegaConf 10 | from torch.profiler import ProfilerActivity, profile, record_function, schedule 11 | 12 | from mlstm_kernels.utils.benchmark.benchmarks.training_kernel_benchmarks import ( 13 | create_training_kernel_benchmark, 14 | ) 15 | from mlstm_kernels.utils.benchmark.param_handling import BenchmarkConfig 16 | from mlstm_kernels.utils.benchmark.run_benchmark import ( 17 | run_and_record_benchmarks, 18 | run_benchmarks, 19 | ) 20 | from mlstm_kernels.utils.benchmark.utils import setup_output_folder 21 | 22 | run_training_benchmarks = partial( 23 | run_benchmarks, benchmark_creator=create_training_kernel_benchmark 24 | ) 25 | 26 | 27 | def _benchmark_to_profile(output_folder: Path): 28 | B = 8 29 | S = 8192 30 | 31 | cfg_yaml = f""" 32 | vary_type: grid 33 | vary_params: 34 | 35 | fixed_params: 36 | batch_size: {B} 37 | sequence_length: {S} 38 | rep: 25 39 | warmup: 5 40 | 41 | x_axis_param: "batch_size" 42 | 43 | kernel_specs: 44 | - kernel_name: "mlstm_chunkwise__xl_chunk" 45 | dtype: bfloat16 46 | fwbw: True 47 | use_torch_compile: True 48 | additional_params: 49 | head_dim_qk: 256 50 | head_dim_v: 512 51 | num_heads: 4 52 | chunk_size_inter: 128 53 | chunk_size_intra: 128 54 | 55 | siz_b_L_parallel: 64 56 | siz_b_L_loop: 64 57 | siz_b_DH_parallel: 128 58 | siz_b_DH_loop: 64 59 | 60 | num_warps_intra: 4 61 | num_warps_inter: 4 62 | num_stages_intra: 1 63 | num_stages_inter: 1 64 | recompute_states_in_bw: False 65 | 66 | benchmark_name: "compare to flash_attention" 67 | """ 68 | cfg = from_dict( 69 | data_class=BenchmarkConfig, 70 | data=OmegaConf.to_container(OmegaConf.create(cfg_yaml)), 71 | ) 72 | 73 | run_and_record_benchmarks(cfg, create_training_kernel_benchmark, output_folder) 74 | 75 | 76 | def run_multiple_benchmarks(output_dir: str = "./outputs_kernel_benchmarks_profiler"): 77 | output_folder = setup_output_folder(output_dir) 78 | 79 | # _sequence_length_benchmark(output_folder, batch_size=1, num_heads=16, head_dim=256) 80 | # _batch_size_benchmark(output_folder, seq_len=8192, num_heads=16, head_dim=256) 81 | # _sequence_length_benchmark(output_folder, batch_size=1, num_heads=8, head_dim=512) 82 | # _batch_size_benchmark(output_folder, seq_len=8192, num_heads=8, head_dim=512) 83 | 84 | activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA, ProfilerActivity.XPU] 85 | sort_by_keyword = "cuda_time_total" 86 | with profile( 87 | activities=activities, 88 | record_shapes=True, 89 | profile_memory=True, 90 | use_cuda=True, 91 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 92 | output_folder / "tensorboard" 93 | ), 94 | ) as prof: 95 | _benchmark_to_profile(output_folder) 96 | 97 | print( 98 | prof.key_averages().table( 99 | sort_by=sort_by_keyword, row_limit=50, max_name_column_width=100 100 | ) 101 | ) 102 | 103 | 104 | if __name__ == "__main__": 105 | run_multiple_benchmarks() 106 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = mlstm_kernels 3 | version = 2.0.0 4 | description = A library providing fast and efficient mLSTM kernels for the xLSTM. 
5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | license_files = LICENSE 8 | classifiers = 9 | Operating System :: OS Independent 10 | Programming Language :: Python :: 3 11 | Programming Language :: Python :: 3 :: Only 12 | 13 | [options] 14 | packages = find: 15 | install_requires = 16 | dacite 17 | einops 18 | ipykernel 19 | matplotlib 20 | numpy 21 | omegaconf 22 | rich 23 | torch 24 | tqdm 25 | python_requires = >=3.11 26 | include_package_data = True 27 | 28 | [options.packages.find] 29 | exclude = 30 | test* 31 | res* 32 | notebooks* 33 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from mlstm_kernels.utils.test.fixtures import test_session_folder # noqa 5 | from mlstm_kernels.utils.test.fixtures import test_output_folder # noqa 6 | 7 | 8 | import pytest 9 | 10 | combinations_long = { 11 | "S": [256], 12 | "B": [1], 13 | "NH": [2], 14 | "DHQK": [64], 15 | "DHHV": [128], 16 | } 17 | combinations_long_list = [values for values in zip(*combinations_long.values())] 18 | 19 | final_combinations = combinations_long_list 20 | 21 | combinations_other = { 22 | "S": [256, 256, 256, 256, 256, 256, 256], 23 | "B": [4, 2, 4, 1, 2, 2, 1], 24 | "NH": [2, 4, 8, 2, 4, 2, 2], 25 | "DHQK": [64, 32, 16, 48, 256, 24, 256], 26 | "DHHV": [128, 64, 32, 96, 512, 48, 256], 27 | } 28 | combinations_other_list = [values for values in zip(*combinations_other.values())] 29 | 30 | 31 | pytest.short_test = False 32 | -------------------------------------------------------------------------------- /tests/jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/jax/chunkwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/jax/chunkwise/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
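# The mlstm_state_passing_test fixture below checks that chunked execution with
# explicit state passing matches a single pass over the full sequence: it runs
# kernel_fn(q, k, v, i, f) once over the whole context, then again in num_chunks
# slices, feeding each chunk's returned (c, n, m) last-states into the next call
# via c_initial/n_initial/m_initial, and asserts that the outputs and the final
# states agree within rtol/atol.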
3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import pytest 8 | 9 | 10 | @pytest.fixture 11 | def mlstm_state_passing_test() -> callable: 12 | def _mlstm_state_passing_test( 13 | kernel_fn: callable, 14 | q: jax.Array, 15 | k: jax.Array, 16 | v: jax.Array, 17 | igate_preact: jax.Array, 18 | fgate_preact: jax.Array, 19 | num_chunks: int = 4, 20 | rtol: float = 1e-5, 21 | atol: float = 1e-5, 22 | ) -> jax.Array: 23 | ctx_len = q.shape[2] 24 | input_arrays = (q, k, v, igate_preact, fgate_preact) 25 | h_full_solo = kernel_fn(*input_arrays, return_last_states=False) 26 | h_full_states, (c_full, n_full, m_full) = kernel_fn( 27 | *input_arrays, return_last_states=True 28 | ) 29 | h_chunked = [] 30 | c_chunked, n_chunked, m_chunked = None, None, None 31 | chunk_size = ctx_len // num_chunks 32 | for i in range(num_chunks): 33 | input_chunk = jax.tree.map( 34 | lambda x: x[:, :, i * chunk_size : (i + 1) * chunk_size], input_arrays 35 | ) 36 | h_chunked_i, (c_chunked, n_chunked, m_chunked) = kernel_fn( 37 | *input_chunk, 38 | c_initial=c_chunked, 39 | n_initial=n_chunked, 40 | m_initial=m_chunked, 41 | return_last_states=True, 42 | ) 43 | h_chunked.append(h_chunked_i) 44 | h_chunked = jnp.concatenate(h_chunked, axis=2) 45 | 46 | h_full_solo = jax.device_get(h_full_solo) 47 | h_full_states = jax.device_get(h_full_states) 48 | h_chunked = jax.device_get(h_chunked) 49 | 50 | np.testing.assert_allclose( 51 | h_full_solo, 52 | h_full_states, 53 | rtol=rtol, 54 | atol=atol, 55 | err_msg="H state with return_last_states=False vs True do not match.", 56 | ) 57 | np.testing.assert_allclose( 58 | h_full_states[:, :, :chunk_size], 59 | h_chunked[:, :, :chunk_size], 60 | rtol=rtol, 61 | atol=atol, 62 | err_msg="H state with single forward vs chunked do not match in the first chunk, ie without state passing.", 63 | ) 64 | np.testing.assert_allclose( 65 | h_full_states[:, :, chunk_size:], 66 | h_chunked[:, :, chunk_size:], 67 | rtol=rtol, 68 | atol=atol, 69 | err_msg="H state with single forward vs chunked do not match after the first chunk, ie with state passing.", 70 | ) 71 | 72 | c_full, n_full, m_full = jax.device_get((c_full, n_full, m_full)) 73 | c_chunked, n_chunked, m_chunked = jax.device_get( 74 | (c_chunked, n_chunked, m_chunked) 75 | ) 76 | 77 | np.testing.assert_allclose( 78 | c_full, 79 | c_chunked, 80 | rtol=rtol, 81 | atol=atol, 82 | err_msg="C state with single forward vs chunked do not match.", 83 | ) 84 | np.testing.assert_allclose( 85 | n_full, 86 | n_chunked, 87 | rtol=rtol, 88 | atol=atol, 89 | err_msg="N state with single forward vs chunked do not match.", 90 | ) 91 | np.testing.assert_allclose( 92 | m_full, 93 | m_chunked, 94 | rtol=rtol, 95 | atol=atol, 96 | err_msg="M state with single forward vs chunked do not match.", 97 | ) 98 | 99 | return _mlstm_state_passing_test 100 | -------------------------------------------------------------------------------- /tests/jax/chunkwise/test_chunkwise_native.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | from mlstm_kernels.jax.chunkwise.native import mlstm_chunkwise__native_autograd 7 | from mlstm_kernels.jax.parallel.native_stablef import ( 8 | mlstm_parallel__native_stablef_autograd, 9 | ) 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import pytest 14 | 15 | from ...conftest import final_combinations 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | TEST_FOLDER_NAME_PREFIX = "chunkwise-jax__native" 20 | 21 | 22 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 23 | def test_native_chunkwise_torch_vs_native_parallel_stablef_fp32( 24 | test_session_folder, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 25 | ): 26 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 27 | mlstm_parallel_interface_test( 28 | baseline_fn=mlstm_parallel__native_stablef_autograd, 29 | target_fn=mlstm_chunkwise__native_autograd, 30 | baseline_name="native_parallel_stablef_autograd", 31 | target_name="native_chunkwise_autograd", 32 | S=S, 33 | B=B, 34 | NH=NH, 35 | DHQK=DHQK, 36 | DHHV=DHHV, 37 | dtype=jnp.float32, 38 | atol_fw=3e-3, 39 | rtol_fw=1e-2, 40 | atol_fwbw=2e-2, 41 | rtol_fwbw=5e-2, 42 | vmax=1e-3, 43 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 44 | save_dir=str(test_session_folder), 45 | add_fp64_baseline=False, 46 | use_jit=False, 47 | ) 48 | -------------------------------------------------------------------------------- /tests/jax/losses_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | def loss_layernorm_offset_quadratic( 9 | input_tensor: jax.Array, seed: int = 0, eps: float = 1e-5 10 | ) -> jax.Array: 11 | rng_key = jax.random.PRNGKey(seed) 12 | rng_key, rng_key_offset = jax.random.split(rng_key) 13 | offset = jax.random.normal(rng_key_offset, input_tensor.shape) 14 | assert len(input_tensor.shape) == 4 15 | 16 | input_tensor_scaled = jax.nn.standardize(input_tensor, axis=-1) 17 | 18 | loss = jnp.sum((input_tensor_scaled + offset) ** 2) 19 | return loss 20 | -------------------------------------------------------------------------------- /tests/jax/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/jax/parallel/test_parallel_native.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | from mlstm_kernels.jax.parallel.native import ( 7 | mlstm_parallel__native_autograd, 8 | mlstm_parallel__native_custbw, 9 | ) 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import pytest 14 | 15 | from ...conftest import final_combinations 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | TEST_FOLDER_NAME_PREFIX = "parallel-jax-native" 20 | 21 | 22 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 23 | def test_parallel_native_autograd_vs_native_custbw_fp32( 24 | test_session_folder, 25 | test_output_folder, 26 | mlstm_parallel_interface_test, 27 | S, 28 | B, 29 | NH, 30 | DHQK, 31 | DHHV, 32 | ): 33 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 34 | mlstm_parallel_interface_test( 35 | baseline_fn=mlstm_parallel__native_autograd, 36 | target_fn=mlstm_parallel__native_custbw, 37 | baseline_name="parallel_unstable_autograd", 38 | target_name="parallel_unstable_custbw", 39 | S=S, 40 | B=B, 41 | NH=NH, 42 | DHQK=DHQK, 43 | DHHV=DHHV, 44 | dtype=jnp.float32, 45 | atol_fw=1e-3, 46 | rtol_fw=1e-2, 47 | atol_fwbw=2e-1, # vecFgrad has high errors: Max absolute difference: 0.22783774 48 | rtol_fwbw=5e-2, 49 | vmax=1e-3, 50 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 51 | save_dir=str(test_session_folder), 52 | add_fp64_baseline=False, 53 | run_backward=True, 54 | # save_output_tensors_dir=str(test_output_folder / "test_data"), 55 | use_jit=False, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/jax/parallel/test_parallel_native_siging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import pytest 9 | 10 | from mlstm_kernels.jax.parallel.native_siging import ( 11 | mlstm_siging_parallel__native_autograd, 12 | mlstm_siging_parallel__native_custbw, 13 | ) 14 | 15 | from ...conftest import final_combinations 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | TEST_FOLDER_NAME_PREFIX = "parallel-jax-native-siging" 20 | 21 | 22 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 23 | @pytest.mark.parametrize("normalize", [True, False]) 24 | @pytest.mark.parametrize("stable_fgate", [True, False]) 25 | def test_parallel_native_siging_autograd_vs_native_custbw_fp32( 26 | test_session_folder, 27 | test_output_folder, 28 | mlstm_parallel_interface_test, 29 | S, 30 | B, 31 | NH, 32 | DHQK, 33 | DHHV, 34 | normalize, 35 | stable_fgate, 36 | ): 37 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 38 | mlstm_parallel_interface_test( 39 | baseline_fn=mlstm_siging_parallel__native_autograd, 40 | target_fn=mlstm_siging_parallel__native_custbw, 41 | baseline_name=f"parallel_siging_stablef_{stable_fgate}_norm{normalize}_autograd", 42 | target_name=f"parallel_siging_stablef_{stable_fgate}_norm{normalize}_custbw", 43 | S=S, 44 | B=B, 45 | NH=NH, 46 | DHQK=DHQK, 47 | DHHV=DHHV, 48 | dtype=jnp.float32, 49 | atol_fw=1e-3, 50 | rtol_fw=0.01, 51 | atol_fwbw=0.2, # vecFgrad has high errors: Max absolute difference: 0.27992487 52 | rtol_fwbw=0.1, 53 | vmax=1e-3, 54 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 55 | save_dir=str(test_session_folder), 56 | add_fp64_baseline=False, 57 | run_backward=True, 58 | # save_output_tensors_dir=str(test_output_folder / "test_data"), 59 | use_jit=False, 60 | ) 61 | -------------------------------------------------------------------------------- /tests/jax/parallel/test_parallel_native_stablef.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import pytest 9 | 10 | from mlstm_kernels.jax.parallel.native_stablef import ( 11 | mlstm_parallel__native_stablef_autograd, 12 | mlstm_parallel__native_stablef_custbw, 13 | ) 14 | 15 | from ...conftest import final_combinations 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | TEST_FOLDER_NAME_PREFIX = "parallel-jax-native" 20 | 21 | 22 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 23 | def test_parallel_stablef_native_autograd_vs_native_custbw_fp32( 24 | test_session_folder, 25 | test_output_folder, 26 | mlstm_parallel_interface_test, 27 | S, 28 | B, 29 | NH, 30 | DHQK, 31 | DHHV, 32 | ): 33 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 34 | mlstm_parallel_interface_test( 35 | baseline_fn=mlstm_parallel__native_stablef_autograd, 36 | target_fn=mlstm_parallel__native_stablef_custbw, 37 | baseline_name="parallel_stable_autograd", 38 | target_name="parallel_stable_custbw", 39 | S=S, 40 | B=B, 41 | NH=NH, 42 | DHQK=DHQK, 43 | DHHV=DHHV, 44 | dtype=jnp.float32, 45 | atol_fw=1e-3, 46 | rtol_fw=1e-2, 47 | atol_fwbw=0.15, # 3.2e-2, # matQgrad has high errors: Max absolute difference: 0.33696747 48 | rtol_fwbw=0.1, # 9e-2, 49 | vmax=1e-3, 50 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 51 | save_dir=str(test_session_folder), 52 | add_fp64_baseline=False, 53 | run_backward=True, 54 | # save_output_tensors_dir=str(test_output_folder / "test_data"), 55 | use_jit=False, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/jax/parallel/test_parallel_native_vs_native_stablef.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | from mlstm_kernels.jax.parallel.native import mlstm_parallel__native_autograd 7 | from mlstm_kernels.jax.parallel.native_stablef import ( 8 | mlstm_parallel__native_stablef_autograd, 9 | ) 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import pytest 14 | 15 | from ...conftest import final_combinations 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | TEST_FOLDER_NAME_PREFIX = "parallel-jax-native" 20 | 21 | 22 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 23 | def test_parallel_native_vs_native_stablef_fp32( 24 | test_session_folder, 25 | test_output_folder, 26 | mlstm_parallel_interface_test, 27 | S, 28 | B, 29 | NH, 30 | DHQK, 31 | DHHV, 32 | ): 33 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 34 | mlstm_parallel_interface_test( 35 | baseline_fn=mlstm_parallel__native_autograd, 36 | target_fn=mlstm_parallel__native_stablef_autograd, 37 | baseline_name="parallel_unstable_autograd", 38 | target_name="parallel_stablef_autograd", 39 | S=S, 40 | B=B, 41 | NH=NH, 42 | DHQK=DHQK, 43 | DHHV=DHHV, 44 | dtype=jnp.float32, 45 | atol_fw=1e-3, 46 | rtol_fw=1e-2, 47 | atol_fwbw=3e-2, 48 | rtol_fwbw=6e-2, 49 | vmax=1e-3, 50 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 51 | save_dir=str(test_session_folder), 52 | add_fp64_baseline=False, 53 | run_backward=True, 54 | # save_output_tensors_dir=str(test_output_folder / "test_data"), 55 | use_jit=False, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/jax/parallel/test_vs_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | from mlstm_kernels.jax.parallel.native import mlstm_parallel__native_autograd 5 | from mlstm_kernels.jax.parallel.native_stablef import ( 6 | mlstm_parallel__native_stablef_autograd, 7 | ) 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import pytest 13 | 14 | 15 | def test_mlstm_parallel_jax_vs_torch(torch_parallel_stablef_vs_unstablef_test_data): 16 | test_data = torch_parallel_stablef_vs_unstablef_test_data 17 | matQ = jnp.array(test_data["matQ"]) 18 | matK = jnp.array(test_data["matK"]) 19 | matV = jnp.array(test_data["matV"]) 20 | vecI = jnp.array(test_data["vecI"]) 21 | vecF = jnp.array(test_data["vecF"]) 22 | print(torch_parallel_stablef_vs_unstablef_test_data.keys()) 23 | matH_torch_unstable = test_data["matH_target"] 24 | matH_torch_stable = test_data["matH_baseline"] 25 | 26 | matH_jax_unstable = mlstm_parallel__native_autograd(matQ, matK, matV, vecI, vecF) 27 | matH_jax_unstable = jax.device_get(matH_jax_unstable) 28 | 29 | np.testing.assert_allclose( 30 | matH_torch_unstable, matH_jax_unstable, atol=3e-3, rtol=6e-2 31 | ) 32 | 33 | matH_jax_stable = mlstm_parallel__native_stablef_autograd( 34 | matQ, matK, matV, vecI, vecF 35 | ) 36 | matH_jax_stable = jax.device_get(matH_jax_stable) 37 | 38 | np.testing.assert_allclose(matH_torch_stable, matH_jax_stable, atol=3e-3, rtol=6e-2) 39 | -------------------------------------------------------------------------------- /tests/jax/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | -------------------------------------------------------------------------------- /tests/jax/recurrent/test_recurrent_sequence_native.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import logging 5 | from pathlib import Path 6 | 7 | from mlstm_kernels.jax.parallel.native_stablef import ( 8 | mlstm_parallel__native_stablef_autograd, 9 | ) 10 | from mlstm_kernels.jax.recurrent.native_sequence import ( 11 | mlstm_recurrent_sequence__native_fw, 12 | mlstm_recurrent_sequence__triton_step_fused_fw, 13 | ) 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import pytest 18 | 19 | from ...conftest import final_combinations 20 | 21 | LOGGER = logging.getLogger(__name__) 22 | 23 | TEST_FOLDER_NAME_PREFIX = "recurrent_sequence-jax__native" 24 | 25 | 26 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 27 | def test_native_recurrent_sequence_native_step_vs_native_parallel_stablef_fp32( 28 | test_session_folder: Path, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 29 | ): 30 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 31 | mlstm_parallel_interface_test( 32 | baseline_fn=mlstm_parallel__native_stablef_autograd, 33 | target_fn=mlstm_recurrent_sequence__native_fw, 34 | baseline_name="native_parallel_stablef_autograd", 35 | target_name="native_recurrent_sequence_native_step", 36 | S=S, 37 | B=B, 38 | NH=NH, 39 | DHQK=DHQK, 40 | DHHV=DHHV, 41 | dtype=jnp.float32, 42 | atol_fw=1.1e-2, # Max absolute difference: 0.01150007 43 | rtol_fw=5e-2, 44 | atol_fwbw=2e-2, 45 | rtol_fwbw=5e-2, 46 | vmax=1e-3, 47 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 48 | save_dir=str(test_session_folder), 49 | add_fp64_baseline=False, 50 | use_jit=False, 51 | run_backward=False, 52 | ) 53 | 54 | 55 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 56 | def test_native_recurrent_sequence_triton_step_fused_vs_native_parallel_stablef_fp32( 57 | test_session_folder: Path, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 58 | ): 59 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 60 | mlstm_parallel_interface_test( 61 | baseline_fn=mlstm_parallel__native_stablef_autograd, 62 | target_fn=mlstm_recurrent_sequence__triton_step_fused_fw, 63 | baseline_name="native_parallel_stablef_autograd", 64 | target_name="native_recurrent_sequence_triton_step_fused", 65 | S=S, 66 | B=B, 67 | NH=NH, 68 | DHQK=DHQK, 69 | DHHV=DHHV, 70 | dtype=jnp.float32, 71 | atol_fw=1.1e-2, # Max absolute difference: 0.0114983 72 | rtol_fw=5e-2, 73 | atol_fwbw=2e-2, 74 | rtol_fwbw=5e-2, 75 | vmax=1e-3, 76 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 77 | save_dir=str(test_session_folder), 78 | add_fp64_baseline=False, 79 | use_jit=False, 80 | run_backward=False, 81 | ) 82 | -------------------------------------------------------------------------------- /tests/jax/recurrent/test_recurrent_sequence_scan_native.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | from pathlib import Path 6 | 7 | from mlstm_kernels.jax.parallel.native_stablef import ( 8 | mlstm_parallel__native_stablef_autograd, 9 | ) 10 | from mlstm_kernels.jax.recurrent.native_sequence_scan import ( 11 | mlstm_recurrent_sequence__native_fw, 12 | mlstm_recurrent_sequence__triton_step_fused_fw, 13 | ) 14 | 15 | import jax 16 | import jax.numpy as jnp 17 | import pytest 18 | 19 | from ...conftest import final_combinations 20 | 21 | LOGGER = logging.getLogger(__name__) 22 | 23 | TEST_FOLDER_NAME_PREFIX = "recurrent_sequence-jax__native" 24 | 25 | 26 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 27 | def test_native_recurrent_sequence_native_step_vs_native_parallel_stablef_fp32( 28 | test_session_folder: Path, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 29 | ): 30 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 31 | mlstm_parallel_interface_test( 32 | baseline_fn=mlstm_parallel__native_stablef_autograd, 33 | target_fn=mlstm_recurrent_sequence__native_fw, 34 | baseline_name="native_parallel_stablef_autograd", 35 | target_name="native_recurrent_sequence_native_step", 36 | S=S, 37 | B=B, 38 | NH=NH, 39 | DHQK=DHQK, 40 | DHHV=DHHV, 41 | dtype=jnp.float32, 42 | atol_fw=1.1e-2, # Max absolute difference: 0.01150007 43 | rtol_fw=5e-2, 44 | atol_fwbw=2e-2, 45 | rtol_fwbw=5e-2, 46 | vmax=1e-3, 47 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 48 | save_dir=str(test_session_folder), 49 | add_fp64_baseline=False, 50 | use_jit=False, 51 | run_backward=False, 52 | ) 53 | 54 | 55 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 56 | def test_native_recurrent_sequence_triton_step_fused_vs_native_parallel_stablef_fp32( 57 | test_session_folder: Path, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 58 | ): 59 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 60 | mlstm_parallel_interface_test( 61 | baseline_fn=mlstm_parallel__native_stablef_autograd, 62 | target_fn=mlstm_recurrent_sequence__triton_step_fused_fw, 63 | baseline_name="native_parallel_stablef_autograd", 64 | target_name="native_recurrent_sequence_triton_step_fused", 65 | S=S, 66 | B=B, 67 | NH=NH, 68 | DHQK=DHQK, 69 | DHHV=DHHV, 70 | dtype=jnp.float32, 71 | atol_fw=1.1e-2, # Max absolute difference: 0.0114983 72 | rtol_fw=5e-2, 73 | atol_fwbw=2e-2, 74 | rtol_fwbw=5e-2, 75 | vmax=1e-3, 76 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 77 | save_dir=str(test_session_folder), 78 | add_fp64_baseline=False, 79 | use_jit=False, 80 | run_backward=False, 81 | ) 82 | -------------------------------------------------------------------------------- /tests/test_padding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
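# These tests cover the chunk-size padding path of the kernel registry: with
# padded_chunk_size=64, a sequence length that is not a multiple of the chunk
# size (S=63) must be padded internally while the returned hidden states keep
# the unpadded input shape [B, N, S, H]; S=128 covers the already-aligned case.
# The kernel is resolved either via get_kernel(name, ...) or via
# get_whole_registry(...)[name].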
3 | 4 | import sys 5 | import unittest 6 | 7 | import torch 8 | 9 | print(sys.path) 10 | 11 | from mlstm_kernels import get_kernel, get_whole_registry 12 | 13 | 14 | class TestPadding(unittest.TestCase): 15 | def test_padding(self): 16 | B, N, S, H = 1, 1, 63, 128 17 | dtype = torch.bfloat16 18 | device = torch.device("cuda") 19 | q, k, v = ( 20 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 21 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 22 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 23 | ) 24 | i, f = ( 25 | torch.zeros([B, N, S], device=device, dtype=dtype), 26 | torch.zeros([B, N, S], device=device, dtype=dtype), 27 | ) 28 | kernel = get_kernel("mlstm_chunkwise--triton", padded_chunk_size=64) 29 | h = kernel(q, k, v, i, f) 30 | assert h.shape == v.shape 31 | 32 | B, N, S, H = 1, 1, 128, 128 33 | dtype = torch.bfloat16 34 | device = torch.device("cuda") 35 | q, k, v = ( 36 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 37 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 38 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 39 | ) 40 | i, f = ( 41 | torch.zeros([B, N, S], device=device, dtype=dtype), 42 | torch.zeros([B, N, S], device=device, dtype=dtype), 43 | ) 44 | kernel = get_kernel("mlstm_chunkwise--triton", padded_chunk_size=64) 45 | h = kernel(q, k, v, i, f) 46 | assert h.shape == v.shape 47 | 48 | def test_padding_whole_registry(self): 49 | B, N, S, H = 1, 1, 63, 128 50 | dtype = torch.bfloat16 51 | device = torch.device("cuda") 52 | q, k, v = ( 53 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 54 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 55 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 56 | ) 57 | i, f = ( 58 | torch.zeros([B, N, S], device=device, dtype=dtype), 59 | torch.zeros([B, N, S], device=device, dtype=dtype), 60 | ) 61 | kernel = get_whole_registry(padded_chunk_size=64)["mlstm_chunkwise--triton"] 62 | h = kernel(q, k, v, i, f) 63 | assert h.shape == v.shape 64 | 65 | B, N, S, H = 1, 1, 128, 128 66 | dtype = torch.bfloat16 67 | device = torch.device("cuda") 68 | q, k, v = ( 69 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 70 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 71 | torch.zeros([B, N, S, H], device=device, dtype=dtype), 72 | ) 73 | i, f = ( 74 | torch.zeros([B, N, S], device=device, dtype=dtype), 75 | torch.zeros([B, N, S], device=device, dtype=dtype), 76 | ) 77 | kernel = get_whole_registry(padded_chunk_size=64)["mlstm_chunkwise--triton"] 78 | h = kernel(q, k, v, i, f) 79 | assert h.shape == v.shape 80 | 81 | 82 | # print("Hello") 83 | -------------------------------------------------------------------------------- /tests/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/torch/chunkwise/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/torch/chunkwise/test_chunkwise_native.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 
2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import logging 5 | 6 | from mlstm_kernels.torch.chunkwise.native import mlstm_chunkwise__native_custbw 7 | from mlstm_kernels.torch.parallel.native_stablef import ( 8 | mlstm_parallel__native_stablef_custbw, 9 | ) 10 | 11 | import pytest 12 | import torch 13 | 14 | from ...conftest import final_combinations 15 | 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | TEST_FOLDER_NAME_PREFIX = "chunkwise-torch" 19 | 20 | 21 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 22 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 23 | def test_native_chunkwise_torch_vs_native_parallel_stablef_fp32( 24 | test_session_folder, 25 | test_output_folder, 26 | mlstm_parallel_interface_test, 27 | S, 28 | B, 29 | NH, 30 | DHQK, 31 | DHHV, 32 | ): 33 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 34 | mlstm_parallel_interface_test( 35 | baseline_fn=mlstm_parallel__native_stablef_custbw, 36 | target_fn=mlstm_chunkwise__native_custbw, 37 | baseline_name="native_parallel_stablef_custbw", 38 | target_name="native_chunkwise_custbw", 39 | S=S, 40 | B=B, 41 | NH=NH, 42 | DHQK=DHQK, 43 | DHHV=DHHV, 44 | dtype=torch.float32, 45 | atol_fw=1e-4, 46 | rtol_fw=1e-3, 47 | atol_fwbw=1e-3, 48 | rtol_fwbw=1e-2, 49 | vmax=1e-4, 50 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 51 | save_dir=str(test_session_folder), 52 | add_fp64_baseline=False, 53 | # save_output_tensors_dir=str(test_output_folder), 54 | ) 55 | 56 | 57 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 58 | def test_state_passing(mlstm_state_passing_test, state_passing_qkvif): 59 | num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64 60 | 61 | mlstm_state_passing_test( 62 | kernel_fn=mlstm_chunkwise__native_custbw, 63 | q=state_passing_qkvif[0], 64 | k=state_passing_qkvif[1], 65 | v=state_passing_qkvif[2], 66 | igate_preact=state_passing_qkvif[3], 67 | fgate_preact=state_passing_qkvif[4], 68 | num_chunks=num_chunks, 69 | rtol=1e-5, 70 | atol=1e-5, 71 | device="cuda", 72 | ) 73 | -------------------------------------------------------------------------------- /tests/torch/chunkwise/test_chunkwise_triton_limit_chunk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | import pytest 7 | import torch 8 | 9 | from mlstm_kernels.torch.chunkwise.triton_limit_chunk import ( 10 | mlstm_chunkwise__limit_chunk, 11 | ) 12 | from mlstm_kernels.torch.parallel.native_stablef import ( 13 | mlstm_parallel__native_stablef_custbw, 14 | ) 15 | 16 | from ...conftest import final_combinations 17 | 18 | LOGGER = logging.getLogger(__name__) 19 | 20 | TEST_FOLDER_NAME_PREFIX = "chunkwise-triton_limit_chunk" 21 | 22 | 23 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 24 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 25 | def test_triton_chunkwise_limit_chunk_vs_native_parallel_stablef_fp32( 26 | test_session_folder, 27 | test_output_folder, 28 | mlstm_parallel_interface_test, 29 | S, 30 | B, 31 | NH, 32 | DHQK, 33 | DHHV, 34 | ): 35 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 36 | mlstm_parallel_interface_test( 37 | baseline_fn=mlstm_parallel__native_stablef_custbw, 38 | target_fn=mlstm_chunkwise__limit_chunk, 39 | baseline_name="native_parallel_stablef_custbw", 40 | target_name="triton_chunkwise_limit_chunk", 41 | S=S, 42 | B=B, 43 | NH=NH, 44 | DHQK=DHQK, 45 | DHHV=DHHV, 46 | dtype=torch.float32, 47 | atol_fw=2e-2, 48 | rtol_fw=5e-2, 49 | atol_fwbw=3e-1, 50 | rtol_fwbw=8e-1, 51 | vmax=1e-3, 52 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 53 | save_dir=str(test_session_folder), 54 | add_fp64_baseline=False, 55 | save_output_tensors_dir=str(test_output_folder), 56 | ) 57 | 58 | 59 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 60 | def test_state_passing(mlstm_state_passing_test, state_passing_qkvif): 61 | num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64 62 | 63 | mlstm_state_passing_test( 64 | kernel_fn=mlstm_chunkwise__limit_chunk, 65 | q=state_passing_qkvif[0], 66 | k=state_passing_qkvif[1], 67 | v=state_passing_qkvif[2], 68 | igate_preact=state_passing_qkvif[3], 69 | fgate_preact=state_passing_qkvif[4], 70 | num_chunks=num_chunks, 71 | rtol=1e-5, 72 | atol=1e-5, 73 | device="cuda", 74 | ) 75 | -------------------------------------------------------------------------------- /tests/torch/losses_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import torch 5 | 6 | 7 | def loss_layernorm_offset_quadratic( 8 | input_tensor: torch.Tensor, seed: int = 0, eps: float = 1e-5 9 | ) -> torch.Tensor: 10 | torch.manual_seed(seed) 11 | offset = torch.randn_like(input_tensor) 12 | assert len(input_tensor.shape) == 4 13 | 14 | input_tensor_scaled = ( 15 | input_tensor - input_tensor.mean(-1, keepdim=True) 16 | ) / torch.sqrt(input_tensor.var(dim=-1, keepdim=True, unbiased=False) + eps) 17 | 18 | loss = ((input_tensor_scaled + offset) ** 2).sum() 19 | return loss 20 | -------------------------------------------------------------------------------- /tests/torch/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | -------------------------------------------------------------------------------- /tests/torch/parallel/test_parallel_native_siging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | 4 | import logging 5 | from functools import partial 6 | 7 | import pytest 8 | import torch 9 | 10 | from mlstm_kernels.torch.parallel.native_siging import ( 11 | mlstm_siging_parallel__native_autograd, 12 | mlstm_siging_parallel__native_custbw, 13 | ) 14 | 15 | from ...conftest import final_combinations 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | TEST_FOLDER_NAME_PREFIX = "parallel-torch-native_siging" 20 | 21 | 22 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 23 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 24 | def test_parallel_native_siging_stablef_normalized( 25 | test_session_folder, 26 | test_output_folder, 27 | mlstm_parallel_interface_test, 28 | S, 29 | B, 30 | NH, 31 | DHQK, 32 | DHHV, 33 | ): 34 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 35 | mlstm_parallel_interface_test( 36 | baseline_fn=partial( 37 | mlstm_siging_parallel__native_autograd, stable_fgate=True, normalize=True 38 | ), 39 | target_fn=partial( 40 | mlstm_siging_parallel__native_custbw, stable_fgate=True, normalize=True 41 | ), 42 | baseline_name="parallel_siging_stable_normalized_autograd", 43 | target_name="parallel_siging_stable_normalized_custbw", 44 | S=S, 45 | B=B, 46 | NH=NH, 47 | DHQK=DHQK, 48 | DHHV=DHHV, 49 | dtype=torch.float32, 50 | atol_fw=1e-4, 51 | rtol_fw=1e-3, 52 | atol_fwbw=2e-3, 53 | rtol_fwbw=1e-2, 54 | vmax=1e-3, 55 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 56 | save_dir=str(test_session_folder), 57 | add_fp64_baseline=False, 58 | save_output_tensors_dir=str(test_output_folder), 59 | ) 60 | 61 | 62 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 63 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 64 | def test_parallel_native_siging_stablef_unnormalized( 65 | test_session_folder, 66 | test_output_folder, 67 | mlstm_parallel_interface_test, 68 | S, 69 | B, 70 | NH, 71 | DHQK, 72 | DHHV, 73 | ): 74 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 75 | mlstm_parallel_interface_test( 76 | baseline_fn=partial( 77 | mlstm_siging_parallel__native_autograd, stable_fgate=True, normalize=False 78 | ), 79 | target_fn=partial( 80 | mlstm_siging_parallel__native_custbw, stable_fgate=True, normalize=False 81 | ), 82 | baseline_name="parallel_siging_stable_unnormalized_autograd", 83 | target_name="parallel_siging_stable_unnormalized_custbw", 84 | S=S, 85 | B=B, 86 | NH=NH, 87 | DHQK=DHQK, 88 | DHHV=DHHV, 89 | dtype=torch.float32, 90 | atol_fw=1e-4, 91 | rtol_fw=1e-3, 92 | atol_fwbw=2e-4, 93 | rtol_fwbw=5e-3, 94 | vmax=1e-3, 95 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 96 | save_dir=str(test_session_folder), 97 | add_fp64_baseline=False, 98 | save_output_tensors_dir=str(test_output_folder), 99 | ) 100 | -------------------------------------------------------------------------------- /tests/torch/parallel/test_parallel_triton_limit_headdim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | import pytest 7 | import torch 8 | 9 | from mlstm_kernels.torch.parallel.native_stablef import ( 10 | mlstm_parallel__native_stablef_custbw, 11 | ) 12 | from mlstm_kernels.torch.parallel.triton_limit_headdim import ( 13 | mlstm_parallel__limit_headdim, 14 | ) 15 | 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | TEST_FOLDER_NAME_PREFIX = "parallel-triton_limit_headdim" 19 | 20 | combinations_long = { 21 | "S": [256], # [8192], 22 | "B": [1], # [2, 2, 2, 2], 23 | "NH": [2], # [3, 3, 3, 3], 24 | "DHQK": [128], # [5, 5, 5, 5], 25 | "DHHV": [128], # [5, 5, 5, 5], 26 | } 27 | combinations = [values for values in zip(*combinations_long.values())] 28 | 29 | 30 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 31 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], combinations) 32 | def test_triton_parallel_limit_headdim_vs_native_parallel_stablef_fp32( 33 | test_session_folder, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 34 | ): 35 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 36 | mlstm_parallel_interface_test( 37 | baseline_fn=mlstm_parallel__native_stablef_custbw, 38 | target_fn=mlstm_parallel__limit_headdim, 39 | baseline_name="native_parallel_stablef_custbw", 40 | target_name="triton_parallel_limit_headdim", 41 | S=S, 42 | B=B, 43 | NH=NH, 44 | DHQK=DHQK, 45 | DHHV=DHHV, 46 | dtype=torch.float32, 47 | atol_fw=1e-2, 48 | rtol_fw=5e-2, 49 | atol_fwbw=3e-1, # we need to increase this tolerance for vecF.grad (max diff val 0.267...) 50 | rtol_fwbw=1.0, 51 | vmax=1e-2, 52 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 53 | save_dir=str(test_session_folder), 54 | add_fp64_baseline=False, 55 | ) 56 | -------------------------------------------------------------------------------- /tests/torch/recurrent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 3 | -------------------------------------------------------------------------------- /tests/torch/recurrent/test_recurrent_sequence_native_vs_parallel_native.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import logging 5 | 6 | from mlstm_kernels.torch.parallel.native_stablef import ( 7 | mlstm_parallel__native_stablef_autograd, 8 | ) 9 | from mlstm_kernels.torch.recurrent.native_sequence import ( 10 | mlstm_recurrent_sequence__native_fw, 11 | ) 12 | 13 | import pytest 14 | import torch 15 | 16 | from ...conftest import final_combinations 17 | 18 | LOGGER = logging.getLogger(__name__) 19 | 20 | TEST_FOLDER_NAME_PREFIX = "recurrent_seq-torch_native" 21 | 22 | 23 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 24 | @pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) 25 | def test_recurrent_sequence_native_vs_native_parallel_stablef_fp32( 26 | test_session_folder, mlstm_parallel_interface_test, S, B, NH, DHQK, DHHV 27 | ): 28 | print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}") 29 | mlstm_parallel_interface_test( 30 | baseline_fn=mlstm_parallel__native_stablef_autograd, 31 | target_fn=mlstm_recurrent_sequence__native_fw, 32 | baseline_name="native_parallel_stablef_autograd", 33 | target_name="native_recurrent_sequence__native_fw", 34 | S=S, 35 | B=B, 36 | NH=NH, 37 | DHQK=DHQK, 38 | DHHV=DHHV, 39 | dtype=torch.float32, 40 | atol_fw=1e-4, 41 | rtol_fw=1e-4, 42 | atol_fwbw=1e-4, 43 | rtol_fwbw=1e-2, 44 | vmax=1e-3, 45 | test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX, 46 | save_dir=str(test_session_folder), 47 | add_fp64_baseline=False, 48 | ) 49 | -------------------------------------------------------------------------------- /tests/torch/recurrent/test_sequence_chunked.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) NXAI GmbH. 2 | # This software may be used and distributed according to the terms of the NXAI Community License Agreement. 
3 | 4 | import pytest 5 | import torch 6 | 7 | from mlstm_kernels.torch.recurrent.native_sequence import ( 8 | mlstm_recurrent_sequence__native_fw, 9 | mlstm_recurrent_sequence__triton_step_fused_fw, 10 | ) 11 | 12 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 13 | def test_state_passing__native_step(mlstm_state_passing_test, state_passing_qkvif): 14 | num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64 15 | 16 | mlstm_state_passing_test( 17 | kernel_fn=mlstm_recurrent_sequence__native_fw, 18 | q=state_passing_qkvif[0], 19 | k=state_passing_qkvif[1], 20 | v=state_passing_qkvif[2], 21 | igate_preact=state_passing_qkvif[3], 22 | fgate_preact=state_passing_qkvif[4], 23 | num_chunks=num_chunks, 24 | rtol=1e-6, 25 | atol=1e-6, 26 | device="cuda", 27 | ) 28 | 29 | 30 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 31 | def test_state_passing__triton_step_fused( 32 | mlstm_state_passing_test, state_passing_qkvif 33 | ): 34 | num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64 35 | 36 | mlstm_state_passing_test( 37 | kernel_fn=mlstm_recurrent_sequence__triton_step_fused_fw, 38 | q=state_passing_qkvif[0], 39 | k=state_passing_qkvif[1], 40 | v=state_passing_qkvif[2], 41 | igate_preact=state_passing_qkvif[3], 42 | fgate_preact=state_passing_qkvif[4], 43 | num_chunks=num_chunks, 44 | rtol=1e-6, 45 | atol=1e-6, 46 | device="cuda", 47 | ) 48 | 49 | 50 | # There is probably a bug in triton_step 51 | # @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") 52 | # def test_state_passing__triton_step(mlstm_state_passing_test, state_passing_qkvif): 53 | # num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64 54 | 55 | # mlstm_state_passing_test( 56 | # kernel_fn=mlstm_recurrent_sequence__triton_step_fw, 57 | # q=state_passing_qkvif[0], 58 | # k=state_passing_qkvif[1], 59 | # v=state_passing_qkvif[2], 60 | # igate_preact=state_passing_qkvif[3], 61 | # fgate_preact=state_passing_qkvif[4], 62 | # num_chunks=num_chunks, 63 | # rtol=2e-3, 64 | # atol=2e-3, 65 | # device="cuda", 66 | # ) 67 | --------------------------------------------------------------------------------
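As an illustration of the state-passing interface the tests above exercise, the following minimal sketch (not a file from this repository) runs the torch limit-chunk kernel once over a full sequence and once in two halves with explicit state threading. It assumes the torch chunkwise kernels accept the same c_initial/n_initial/m_initial/return_last_states keywords that the jax fixture in tests/jax/chunkwise/conftest.py uses, and it requires a CUDA device:

import torch

from mlstm_kernels.torch.chunkwise.triton_limit_chunk import (
    mlstm_chunkwise__limit_chunk,
)

B, NH, S, DHQK, DHHV = 1, 2, 256, 64, 128
device, dtype = torch.device("cuda"), torch.float32

q = torch.randn([B, NH, S, DHQK], device=device, dtype=dtype)
k = torch.randn([B, NH, S, DHQK], device=device, dtype=dtype)
v = torch.randn([B, NH, S, DHHV], device=device, dtype=dtype)
vecI = torch.randn([B, NH, S], device=device, dtype=dtype)
vecF = torch.randn([B, NH, S], device=device, dtype=dtype)

# One pass over the full sequence, returning the last (c, n, m) states.
h_full, (c, n, m) = mlstm_chunkwise__limit_chunk(
    q, k, v, vecI, vecF, return_last_states=True
)

# The same sequence in two halves, threading the recurrent states through.
h_parts, state = [], (None, None, None)
for s in range(0, S, S // 2):
    sl = slice(s, s + S // 2)
    h_part, state = mlstm_chunkwise__limit_chunk(
        q[:, :, sl],
        k[:, :, sl],
        v[:, :, sl],
        vecI[:, :, sl],
        vecF[:, :, sl],
        c_initial=state[0],
        n_initial=state[1],
        m_initial=state[2],
        return_last_states=True,
    )
    h_parts.append(h_part)

# Chunked execution with state passing should reproduce the full-sequence
# output; the tolerances mirror those used by test_state_passing above.
torch.testing.assert_close(h_full, torch.cat(h_parts, dim=2), rtol=1e-5, atol=1e-5)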