├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug.yml │ ├── config.yml │ ├── feature-request.yml │ └── rfc.yml └── workflows │ ├── issue.yml │ ├── lint.yaml │ ├── release.yml │ └── stable.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── FAQs.md ├── LICENSE ├── README.md ├── benchmarks ├── benchmark_generation.py ├── benchmark_training_throughput.py ├── modules │ ├── benchmark_conv.py │ ├── benchmark_cross_entropy.py │ ├── benchmark_l2norm.py │ ├── benchmark_layernorm.py │ └── benchmark_tokenshift.py └── ops │ ├── benchmark.py │ ├── benchmark_abc.py │ ├── benchmark_based.py │ ├── benchmark_delta_rule.py │ ├── benchmark_fla.py │ ├── benchmark_gla.py │ ├── benchmark_gsa.py │ ├── benchmark_hgrn.py │ ├── benchmark_nsa.py │ ├── benchmark_retention.py │ ├── benchmark_rwkv.py │ ├── benchmark_rwkv6.py │ ├── benchmark_rwkv7.py │ ├── benchmark_simple_gla_vs_mamba2.py │ ├── benchmark_titans.py │ └── benchmark_ttt.py ├── evals ├── harness.py └── ppl.py ├── examples └── training.md ├── fla ├── __init__.py ├── layers │ ├── __init__.py │ ├── abc.py │ ├── attn.py │ ├── based.py │ ├── bitattn.py │ ├── delta_net.py │ ├── forgetting_attn.py │ ├── gated_deltanet.py │ ├── gated_deltaproduct.py │ ├── gla.py │ ├── gsa.py │ ├── hgrn.py │ ├── hgrn2.py │ ├── lightnet.py │ ├── linear_attn.py │ ├── mamba.py │ ├── mamba2.py │ ├── mesa_net.py │ ├── multiscale_retention.py │ ├── nsa.py │ ├── path_attn.py │ ├── rebased.py │ ├── rodimus.py │ ├── rwkv6.py │ ├── rwkv7.py │ ├── simple_gla.py │ └── utils.py ├── models │ ├── __init__.py │ ├── abc │ │ ├── __init__.py │ │ ├── configuration_abc.py │ │ └── modeling_abc.py │ ├── bitnet │ │ ├── __init__.py │ │ ├── configuration_bitnet.py │ │ └── modeling_bitnet.py │ ├── delta_net │ │ ├── __init__.py │ │ ├── configuration_delta_net.py │ │ └── modeling_delta_net.py │ ├── forgetting_transformer │ │ ├── __init__.py │ │ ├── configuration_forgetting_transformer.py │ │ └── modeling_forgetting_transformer.py │ ├── gated_deltanet │ │ ├── __init__.py │ │ ├── configuration_gated_deltanet.py │ │ └── modeling_gated_deltanet.py │ ├── gated_deltaproduct │ │ ├── __init__.py │ │ ├── configuration_gated_deltaproduct.py │ │ └── modeling_gated_deltaproduct.py │ ├── gla │ │ ├── __init__.py │ │ ├── configuration_gla.py │ │ └── modeling_gla.py │ ├── gsa │ │ ├── __init__.py │ │ ├── configuration_gsa.py │ │ └── modeling_gsa.py │ ├── hgrn │ │ ├── __init__.py │ │ ├── configuration_hgrn.py │ │ └── modeling_hgrn.py │ ├── hgrn2 │ │ ├── __init__.py │ │ ├── configuration_hgrn2.py │ │ └── modeling_hgrn2.py │ ├── lightnet │ │ ├── __init__.py │ │ ├── configuration_lightnet.py │ │ └── modeling_lightnet.py │ ├── linear_attn │ │ ├── __init__.py │ │ ├── configuration_linear_attn.py │ │ └── modeling_linear_attn.py │ ├── mamba │ │ ├── __init__.py │ │ ├── configuration_mamba.py │ │ └── modeling_mamba.py │ ├── mamba2 │ │ ├── __init__.py │ │ ├── configuration_mamba2.py │ │ └── modeling_mamba2.py │ ├── mesa_net │ │ ├── __init__.py │ │ ├── configuration_mesa_net.py │ │ └── modeling_mesa_net.py │ ├── nsa │ │ ├── __init__.py │ │ ├── configuration_nsa.py │ │ └── modeling_nsa.py │ ├── path_attn │ │ ├── __init__.py │ │ ├── configuration_path_attention.py │ │ └── modeling_path_attention.py │ ├── retnet │ │ ├── __init__.py │ │ ├── configuration_retnet.py │ │ └── modeling_retnet.py │ ├── rodimus │ │ ├── __init__.py │ │ ├── chat_format.py │ │ ├── configuration_rodimus.py │ │ ├── modeling_rodimus.py │ │ └── tokenization_rodimus_fast.py │ ├── rwkv6 │ │ ├── __init__.py │ │ ├── configuration_rwkv6.py │ │ └── 
modeling_rwkv6.py │ ├── rwkv7 │ │ ├── __init__.py │ │ ├── configuration_rwkv7.py │ │ └── modeling_rwkv7.py │ ├── samba │ │ ├── __init__.py │ │ ├── configuration_samba.py │ │ └── modeling_samba.py │ ├── transformer │ │ ├── __init__.py │ │ ├── configuration_transformer.py │ │ └── modeling_transformer.py │ └── utils.py ├── modules │ ├── __init__.py │ ├── activations.py │ ├── convolution.py │ ├── feature_map.py │ ├── fused_bitlinear.py │ ├── fused_cross_entropy.py │ ├── fused_kl_div.py │ ├── fused_linear_cross_entropy.py │ ├── fused_norm_gate.py │ ├── grpo.py │ ├── l2norm.py │ ├── layernorm.py │ ├── layernorm_gated.py │ ├── mlp.py │ ├── parallel.py │ ├── rotary.py │ └── token_shift.py ├── ops │ ├── __init__.py │ ├── abc │ │ ├── __init__.py │ │ ├── chunk.py │ │ └── naive.py │ ├── attn │ │ ├── __init__.py │ │ ├── decoding.py │ │ └── parallel.py │ ├── based │ │ ├── __init__.py │ │ ├── fused_chunk.py │ │ ├── naive.py │ │ └── parallel.py │ ├── common │ │ ├── __init__.py │ │ ├── chunk_delta_h.py │ │ ├── chunk_h.py │ │ ├── chunk_h_parallel.py │ │ ├── chunk_h_split.py │ │ ├── chunk_o.py │ │ ├── chunk_scaled_dot_kkt.py │ │ └── fused_recurrent.py │ ├── delta_rule │ │ ├── README.md │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_chunk.py │ │ ├── fused_recurrent.py │ │ ├── naive.py │ │ ├── parallel.py │ │ └── wy_fast.py │ ├── forgetting_attn │ │ ├── __init__.py │ │ └── parallel.py │ ├── gated_delta_rule │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_recurrent.py │ │ └── wy_fast.py │ ├── generalized_delta_rule │ │ ├── README.md │ │ ├── __init__.py │ │ ├── dplr │ │ │ ├── __init__.py │ │ │ ├── chunk.py │ │ │ ├── chunk_A_bwd.py │ │ │ ├── chunk_A_fwd.py │ │ │ ├── chunk_h_bwd.py │ │ │ ├── chunk_h_fwd.py │ │ │ ├── chunk_o_bwd.py │ │ │ ├── chunk_o_fwd.py │ │ │ ├── fused_recurrent.py │ │ │ ├── naive.py │ │ │ ├── wy_fast_bwd.py │ │ │ └── wy_fast_fwd.py │ │ └── iplr │ │ │ ├── __init__.py │ │ │ ├── chunk.py │ │ │ ├── fused_recurrent.py │ │ │ ├── naive.py │ │ │ └── wy_fast.py │ ├── gla │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_chunk.py │ │ ├── fused_recurrent.py │ │ └── naive.py │ ├── gsa │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_recurrent.py │ │ └── naive.py │ ├── hgrn │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_recurrent.py │ │ └── naive.py │ ├── lightning_attn │ │ ├── __init__.py │ │ ├── chunk.py │ │ └── fused_recurrent.py │ ├── linear_attn │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_chunk.py │ │ ├── fused_recurrent.py │ │ ├── naive.py │ │ └── utils.py │ ├── mesa_net │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_cg_solver_fwd.py │ │ ├── chunk_h_fwd.py │ │ ├── chunk_h_kk_intra_bwd.py │ │ ├── chunk_h_kv_intra_bwd.py │ │ ├── decoding_one_step.py │ │ └── naive.py │ ├── nsa │ │ ├── __init__.py │ │ ├── compression.py │ │ ├── naive.py │ │ ├── parallel.py │ │ └── utils.py │ ├── path_attn │ │ ├── __init__.py │ │ ├── cumprod_householder_bwd.py │ │ ├── cumprod_householder_fwd.py │ │ ├── intra_chunk_preprocess_bwd.py │ │ ├── intra_chunk_preprocess_bwd_prepare.py │ │ ├── intra_chunk_preprocess_fwd.py │ │ ├── parallel.py │ │ ├── parallel_path_bwd_inter_dkv.py │ │ ├── parallel_path_bwd_inter_dqh.py │ │ ├── parallel_path_bwd_intra.py │ │ ├── parallel_path_fwd.py │ │ └── prepare_k_cache.py │ ├── rebased │ │ ├── __init__.py │ │ ├── naive.py │ │ └── parallel.py │ ├── retention │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_chunk.py │ │ ├── fused_recurrent.py │ │ ├── naive.py │ │ └── parallel.py │ ├── rwkv4 │ │ ├── __init__.py │ │ └── fused_recurrent.py │ ├── rwkv6 │ │ ├── __init__.py │ 
│ ├── chunk.py │ │ ├── chunk_naive.py │ │ ├── fused_recurrent.py │ │ └── recurrent_naive.py │ ├── rwkv7 │ │ ├── RWKV7(Goose).md │ │ ├── __init__.py │ │ ├── channel_mixing.py │ │ ├── chunk.py │ │ ├── fused_addcmul.py │ │ ├── fused_k_update.py │ │ ├── fused_recurrent.py │ │ └── recurrent_naive.py │ ├── simple_gla │ │ ├── README.md │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_recurrent.py │ │ ├── naive.py │ │ └── parallel.py │ ├── titans │ │ ├── __init__.py │ │ ├── log_impl.py │ │ └── naive.py │ ├── ttt │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── fused_chunk.py │ │ └── naive.py │ └── utils │ │ ├── __init__.py │ │ ├── asm.py │ │ ├── cumsum.py │ │ ├── index.py │ │ ├── logcumsumexp.py │ │ ├── logsumexp.py │ │ ├── matmul.py │ │ ├── op.py │ │ ├── pack.py │ │ ├── pooling.py │ │ ├── softmax.py │ │ └── solve_tril.py └── utils.py ├── legacy └── training │ ├── README.md │ ├── configs │ ├── gla_1B.json │ ├── gla_340M.json │ ├── gla_7B.json │ └── transformer_340M.json │ ├── flame │ ├── __init__.py │ ├── data.py │ ├── logging.py │ └── parser.py │ ├── preprocess.py │ ├── run.py │ └── train.sh ├── scripts └── find_dependent_tests.py ├── setup.py ├── tests ├── modules │ ├── test_conv.py │ ├── test_cross_entropy.py │ ├── test_grpo.py │ ├── test_kl_div.py │ ├── test_l2norm.py │ ├── test_layernorm.py │ ├── test_layernorm_gated.py │ ├── test_rotary.py │ └── test_token_shift.py ├── ops │ ├── test_attn.py │ ├── test_based.py │ ├── test_cumsum.py │ ├── test_delta.py │ ├── test_dplr_delta.py │ ├── test_forgetting_attn.py │ ├── test_gated_delta.py │ ├── test_gla.py │ ├── test_gsa.py │ ├── test_hgrn.py │ ├── test_iplr_delta.py │ ├── test_linear_attn.py │ ├── test_mesa.py │ ├── test_nsa.py │ ├── test_path_attn.py │ ├── test_retention.py │ ├── test_rwkv6.py │ ├── test_rwkv7.py │ ├── test_simple_gla.py │ ├── test_solve_tril.py │ ├── test_titans.py │ ├── test_ttt.py │ └── test_utils.py ├── test_fused_chunk.py ├── test_generation.py ├── test_model.py └── utils │ └── test_rwkv7_conversion.py └── utils ├── convert_from_llama.py ├── convert_from_rwkv6.py └── convert_from_rwkv7.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 127 3 | exclude = 4 | ./.git, 5 | ./docs, 6 | ./build, 7 | ./scripts, 8 | ./venv, 9 | .flake8, 10 | .pre-commit-config.yaml, 11 | *.pyi, 12 | *.md, 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: 🐞 Bug Report 2 | description: Create a report to help reproduce and resolve the bug. 3 | title: "[Bug] " 4 | labels: ["bug"] 5 | body: 6 | - type: checkboxes 7 | id: checklist 8 | attributes: 9 | label: Checklist 10 | description: Please confirm these steps before submitting 11 | options: 12 | - label: "I have checked [FAQs](https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md) and existing issues for similar problems" 13 | - label: "Please report this bug in English to ensure wider understanding and support" 14 | 15 | - type: textarea 16 | id: description 17 | attributes: 18 | label: Describe the Bug 19 | description: Provide a clear and concise description of the bug. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: reproduction 25 | attributes: 26 | label: Steps to Reproduce the Bug 27 | description: | 28 | Please include a code sample that reproduces the issue you encountered. This can be a Colab link or a code snippet. 
29 | If you have code snippets, error messages, or stack traces, please include them here as well. 30 | Important! Use code formatting to clearly present your code. See [GitHub's guide on code formatting](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting). 31 | Avoid using screenshots, as they can be difficult to read and do not allow others to easily copy and paste your code. 32 | placeholder: | 33 | Code snippets or steps to reproduce the behavior. 34 | validations: 35 | required: true 36 | 37 | - type: textarea 38 | id: expected-behavior 39 | attributes: 40 | label: Expected Behavior 41 | description: Provide a clear and concise description of the expected results. 42 | validations: 43 | required: true 44 | 45 | - type: textarea 46 | id: environment-info 47 | attributes: 48 | label: Environment Information 49 | description: Please share your environment details, including Python version, Torch version, Triton version, and platform. 50 | value: | 51 | 1. Torch: 52 | 2. Triton: 53 | validations: 54 | required: true 55 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | contact_links: 2 | - name: 🤗 FLA Hub 3 | url: https://huggingface.co/fla-hub 4 | about: For questions or issues related to released models, datasets, or other resources, please open an issue on our Hugging Face hub. 5 | - name: 💬 Discord Forum 6 | url: https://discord.gg/vDaJTmKNcS 7 | about: Join our Discord community to ask questions, get support, share ideas, and interact with other members. This is the best place for general discussions and quick queries. 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature Request 2 | description: Suggest an idea for this project. 3 | title: "[Feature Request] " 4 | labels: ["enhancement"] 5 | body: 6 | - type: textarea 7 | id: feature-request 8 | attributes: 9 | label: Feature Request 10 | description: Provide a clear and concise description of your feature proposal. 11 | validations: 12 | required: true 13 | 14 | - type: textarea 15 | id: motivation 16 | attributes: 17 | label: Motivation 18 | description: | 19 | Please explain the motivation behind your proposal. Is your feature request addressing a specific problem? For example, "I'm often frustrated when [...]." If this is related to another GitHub issue, please link it here as well. 20 | validations: 21 | required: true 22 | 23 | - type: textarea 24 | id: contribution 25 | attributes: 26 | label: Your Contribution 27 | description: | 28 | How can you contribute to this feature? For example, could you help by submitting a PR? 29 | validations: 30 | required: true -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/rfc.yml: -------------------------------------------------------------------------------- 1 | name: 📝 Request for Comments 2 | description: Propose a new feature or change and gather feedback. 3 | title: "[RFC] " 4 | labels: ["enhancement"] 5 | body: 6 | - type: textarea 7 | id: proposal 8 | attributes: 9 | label: Proposal 10 | description: Provide a brief description of your proposal for a new feature or change. 
11 | validations: 12 | required: true 13 | 14 | - type: textarea 15 | id: rationale 16 | attributes: 17 | label: Rationale 18 | description: Why do you think this feature is important or beneficial? 19 | validations: 20 | required: false 21 | -------------------------------------------------------------------------------- /.github/workflows/issue.yml: -------------------------------------------------------------------------------- 1 | name: issues 2 | on: 3 | schedule: 4 | - cron: "0 0 * * 0" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v9.1.0 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 7 17 | stale-issue-label: "stale" 18 | exempt-issue-labels: "enhancement" 19 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 20 | close-issue-message: "This issue was closed because it has been inactive for 7 days since being marked as stale." 21 | days-before-pr-stale: -1 22 | days-before-pr-close: -1 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ['3.12'] 15 | steps: 16 | - name: Check out repo 17 | uses: actions/checkout@v4 18 | - name: Setup python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Update pip 23 | run: python -m pip install --upgrade pip 24 | - name: Install lint utilities 25 | run: | 26 | python -m pip install pre-commit 27 | pre-commit install-hooks 28 | - name: Get changed files 29 | id: changed-files 30 | uses: tj-actions/changed-files@v46.0.1 31 | - name: Lint modified files 32 | run: pre-commit run --files ${{ steps.changed-files.outputs.all_changed_files }} 33 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: release 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.12' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install setuptools wheel twine packaging 23 | - name: Build and publish 24 | env: 25 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 26 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 27 | run: | 28 | python setup.py sdist bdist_wheel 29 | twine upload dist/* 30 | -------------------------------------------------------------------------------- /.github/workflows/stable.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: 
https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: pip-for-stable-branch 5 | 6 | on: 7 | push: 8 | branches: 9 | - stable 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.12' 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install setuptools wheel twine packaging 24 | - name: Build and publish 25 | env: 26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 28 | run: | 29 | python setup.py sdist bdist_wheel 30 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # test file 2 | test.py 3 | 4 | # data files 5 | data 6 | 7 | # bash scripts 8 | *.sh 9 | 10 | # docs 11 | docs/_build 12 | 13 | # intermediate files 14 | build 15 | dist 16 | *.egg-info 17 | 18 | # experimental results 19 | exp 20 | results 21 | wandb 22 | *.csv 23 | *.png 24 | *.html 25 | 26 | # log and config files 27 | log.* 28 | *.log 29 | *.cfg 30 | *.ini 31 | 32 | # pycache 33 | __pycache__ 34 | 35 | # saved model 36 | *.pkl 37 | *.pt 38 | 39 | # hidden files 40 | .* 41 | 42 | # vscode 43 | .vscode 44 | 45 | # macOS 46 | .DS_Store 47 | 48 | 49 | rwkvfla/* 50 | *.ncu-rep 51 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-symlinks 6 | - id: trailing-whitespace 7 | args: [--markdown-linebreak-ext=md] 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-toml 11 | - id: check-ast 12 | - id: check-added-large-files 13 | - id: check-merge-conflict 14 | - id: detect-private-key 15 | - id: debug-statements 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | - repo: https://github.com/PyCQA/flake8 21 | rev: 7.0.0 22 | hooks: 23 | - id: flake8 24 | args: [--max-line-length=127] 25 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Yang" 5 | given-names: "Songlin" 6 | orcid: "https://orcid.org/0000-0002-5944-0110" 7 | - family-names: "Zhang" 8 | given-names: "Yu" 9 | orcid: "https://orcid.org/0000-0002-8345-3835" 10 | title: "FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism" 11 | version: 0.1 12 | date-released: 2024-01-18 13 | url: "https://github.com/fla-org/flash-linear-attention" 14 | -------------------------------------------------------------------------------- /FAQs.md: -------------------------------------------------------------------------------- 1 | # Triton FAQs and Common Issues 2 | 3 | ## Recommended Setup Approach 4 | 5 | > [!IMPORTANT] 6 | > Triton nightly builds often depend on the latest PyTorch nightly versions. To prevent conflicts with existing installations, we strongly recommend creating a fresh conda environment. 
This isolates the installation from any existing PyTorch/Triton versions that might cause compatibility issues. 7 | 8 | ## Common Issues and Solutions 9 | 10 | ### 1. MMA Assertion Error on H100 11 | 12 | **Error:** 13 | ```py 14 | Assertion `!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere()) && "mma -> mma layout conversion is only supported on Ampere"' failed. 15 | ``` 16 | 17 | **Solution:** 18 | This issue was fixed in [PR #4492](https://github.com/triton-lang/triton/pull/4492). Install the nightly version: 19 | 20 | ```sh 21 | # Create fresh environment (strongly recommended!!!) 22 | conda create -n triton-nightly python=3.12 23 | conda activate triton-nightly 24 | 25 | # Install PyTorch nightly (required for Triton nightly compatibility) 26 | pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 27 | 28 | # Install Triton nightly 29 | pip uninstall triton pytorch-triton -y 30 | pip install -U triton-nightly --index-url http://pypi.fla-org.com/simple --trusted-host pypi.fla-org.com 31 | 32 | # Install flash-linear-attention 33 | pip install einops ninja datasets transformers numpy 34 | pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention --no-deps 35 | 36 | # Optional: Install flash-attention 37 | conda install nvidia/label/cuda-12.6.3::cuda-nvcc 38 | pip install packaging psutil ninja 39 | pip install flash-attn --no-deps --no-cache-dir --no-build-isolation 40 | 41 | # Optional: Verify flash-attention installation 42 | pip install pytest 43 | pytest tests/ops/test_attn.py 44 | ``` 45 | 46 | ### 2. AttributeError: 'NoneType' object has no attribute 'start' 47 | 48 | **Solution:** 49 | This is a known issue ([triton-lang/triton#5224](https://github.com/triton-lang/triton/issues/5224)). Upgrade to Python 3.10+. 50 | 51 | ### 3. H100 LinearLayout Assertion Error 52 | 53 | **Error:** 54 | ``` 55 | mlir::triton::LinearLayout::reshapeOuts(...) failed. 56 | ``` 57 | 58 | **Solution:** 59 | This is a known issue ([triton-lang/triton#5609](https://github.com/triton-lang/triton/issues/5609)). Follow the same installation steps as in Issue #1 above. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2025 Songlin Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /benchmarks/benchmark_generation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang. 3 | 4 | import argparse 5 | import time 6 | 7 | import torch 8 | from datasets import load_dataset 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | import fla # noqa 12 | 13 | 14 | def sizeof_fmt(num, suffix='B'): 15 | for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'): 16 | if abs(num) < 1024.0: 17 | return f'{num:3.1f}{unit}{suffix}' 18 | num /= 1024.0 19 | return f'{num:.1f}Yi{suffix}' 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description="Generation benchmarking") 24 | parser.add_argument("--path", type=str, default="fla-hub/transformer-1.3B-100B") 25 | parser.add_argument("--data", type=str, default="fla-hub/pg19") 26 | parser.add_argument("--length", type=int, default=128) 27 | parser.add_argument("--maxlen", type=int, default=128) 28 | parser.add_argument("--no-cache", action='store_true') 29 | parser.add_argument("--temperature", type=float, default=0.5) 30 | parser.add_argument("--topp", type=float, default=0.2) 31 | parser.add_argument("--repetition_penalty", type=float, default=1.1) 32 | parser.add_argument("--compile", action='store_true') 33 | args = parser.parse_args() 34 | 35 | device = "cuda" 36 | dtype = torch.bfloat16 37 | torch.manual_seed(0) 38 | 39 | print(f"Loading {args.path}") 40 | tokenizer = AutoTokenizer.from_pretrained( 41 | args.path, 42 | trust_remote_code=True, 43 | add_eos_token=False 44 | ) 45 | tokenizer.pad_token_id = tokenizer.eos_token_id 46 | print(f"{tokenizer}") 47 | 48 | model = AutoModelForCausalLM.from_pretrained( 49 | args.path, 50 | device_map={"": device}, 51 | torch_dtype=dtype, 52 | use_cache=not args.no_cache 53 | ) 54 | if args.compile: 55 | print("Compiling the model") 56 | model = torch.compile(model) 57 | model.eval() 58 | print(f"{model.config}\n{model}\nNumber of parameters: {model.num_parameters()} ({sizeof_fmt(model.num_parameters())})\n") 59 | 60 | print(f"Loading {args.data}") 61 | dataset = load_dataset(args.data, split='train', trust_remote_code=True) 62 | print(f"{dataset}") 63 | 64 | prompt = dataset[0]['text'] 65 | tokens = tokenizer(prompt, return_tensors="pt") 66 | input_ids = tokens.input_ids.to(device=device)[:, :args.length].contiguous() 67 | max_length = input_ids.shape[1] + args.maxlen 68 | 69 | torch.cuda.synchronize() 70 | start = time.time() 71 | with torch.inference_mode(): 72 | text = model.generate( 73 | input_ids=input_ids, 74 | use_cache=not args.no_cache, 75 | max_length=max_length, 76 | pad_token_id=tokenizer.eos_token_id, 77 | eos_token_id=tokenizer.bos_token_id, 78 | do_sample=True, 79 | temperature=args.temperature, 80 | top_p=args.topp, 81 | repetition_penalty=args.repetition_penalty 82 | ) 83 | torch.cuda.synchronize() 84 | elapsed = time.time() - start 85 | print(f"Prompt:\n{tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0].strip()}\n") 86 | print(f"Generated:\n{tokenizer.batch_decode(text, skip_special_tokens=True)[0].strip()}\n") 87 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(text[0]) - len(input_ids[0])}") 88 | print(f"Total prompt processing + decoding time: {elapsed * 1000:.0f}ms") 89 | print(f"Max memory used: {sizeof_fmt(torch.cuda.max_memory_allocated())}") 90 | 
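The script above benchmarks generation from a Hub checkpoint. For a quick smoke test of the same timing loop without downloading anything, a randomly initialized FLA model can stand in for the checkpoint. The sketch below is a minimal, illustrative example: the `GLAConfig` field values are assumptions rather than recommended settings, and it reuses the synchronize-then-time pattern from the script above.

```py
# Minimal sketch: time greedy decoding on a small, randomly initialized GLA model.
# The config values here are illustrative assumptions, not tuned settings.
import time

import torch

from fla.models import GLAConfig, GLAForCausalLM

device = 'cuda'
config = GLAConfig(hidden_size=1024, num_hidden_layers=4, num_heads=4, vocab_size=32000)
model = GLAForCausalLM(config).to(device=device, dtype=torch.bfloat16).eval()

# Random 128-token prompt; real benchmarks should tokenize text as in the script above.
input_ids = torch.randint(0, config.vocab_size, (1, 128), device=device)

torch.cuda.synchronize()
start = time.time()
with torch.inference_mode():
    out = model.generate(input_ids, max_new_tokens=64, do_sample=False)
torch.cuda.synchronize()
elapsed = time.time() - start
print(f"Decode throughput: {(out.shape[1] - input_ids.shape[1]) / elapsed:.1f} tokens/s")
```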
-------------------------------------------------------------------------------- /benchmarks/modules/benchmark_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import triton 7 | 8 | from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss 9 | 10 | 11 | @triton.testing.perf_report( 12 | triton.testing.Benchmark( 13 | # argument names to use as an x-axis for the plot 14 | x_names=['T'], 15 | # different possible values for `x_name` 16 | x_vals=[128 * 2 ** i for i in range(0, 8)], 17 | # argument name whose value corresponds to a different line in the plot 18 | line_arg='provider', 19 | # possible values for `line_arg`` 20 | line_vals=['naive', 'fused', 'fused_linear', 'naive_bwd', 'fused_bwd', 'fused_linear_bwd'], 21 | # label name for the lines 22 | line_names=['naive', 'fused', 'fused_linear', 'naive_bwd', 'fused_bwd', 'fused_linear_bwd'], 23 | # line styles 24 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 25 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')], 26 | ylabel="Execution Time (ms)", # label name for the y-axis 27 | # name for the plot. Used also as a file name for saving the plot. 28 | plot_name="Performance", 29 | args={}, 30 | ) 31 | ) 32 | def benchmark(T, provider): 33 | from fla.utils import device 34 | dtype = torch.bfloat16 35 | requires_grad = True 36 | B, H, V = 4, 4096, 120000 37 | 38 | x = torch.randn(B * T, H, device=device, requires_grad=requires_grad, dtype=dtype) 39 | target = torch.randint(0, V, (B * T,), device=device, dtype=torch.int64) 40 | w = torch.randn(V, H, device=device, requires_grad=requires_grad, dtype=dtype) 41 | b = torch.randn(V, device=device, requires_grad=requires_grad, dtype=dtype) 42 | 43 | quantiles = [0.5, 0.2, 0.8] 44 | results = 0, 0, 0 45 | if provider == 'naive': 46 | criterion = nn.CrossEntropyLoss() 47 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target), quantiles=quantiles) 48 | elif provider == 'naive_bwd': 49 | criterion = nn.CrossEntropyLoss() 50 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target).backward(), quantiles=quantiles) 51 | elif provider == 'fused': 52 | criterion = FusedCrossEntropyLoss() 53 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target), quantiles=quantiles) 54 | elif provider == 'fused_bwd': 55 | criterion = FusedCrossEntropyLoss() 56 | results = triton.testing.do_bench(lambda: criterion(F.linear(x, w, b), target).backward(), quantiles=quantiles) 57 | elif provider == 'fused_linear': 58 | criterion = FusedLinearCrossEntropyLoss() 59 | results = triton.testing.do_bench(lambda: criterion(x, target, w, b), quantiles=quantiles) 60 | elif provider == 'fused_linear_bwd': 61 | criterion = FusedLinearCrossEntropyLoss() 62 | results = triton.testing.do_bench(lambda: criterion(x, target, w, b).backward(), quantiles=quantiles) 63 | return results 64 | 65 | 66 | if __name__ == '__main__': 67 | benchmark.run(print_data=True) 68 | -------------------------------------------------------------------------------- /benchmarks/modules/benchmark_l2norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import triton 8 | 9 | from fla.modules.l2norm import l2norm 10 
| 11 | 12 | @triton.testing.perf_report( 13 | triton.testing.Benchmark( 14 | # argument names to use as an x-axis for the plot 15 | x_names=['B', 'T', 'H', 'D'], 16 | # different possible values for `x_name` 17 | x_vals=[(16, 128 * 2 ** i, h, 2048//h) for h in [1, 2, 4, 8, 16] for i in range(1, 8)], 18 | # argument name whose value corresponds to a different line in the plot 19 | line_arg='provider', 20 | # possible values for `line_arg`` 21 | line_vals=['naive', 'compiled', 'fused', 'naive_bwd', 'compiled_bwd', 'fused_bwd'], 22 | # label name for the lines 23 | line_names=['naive', 'compiled', 'fused', 'naive_bwd', 'compiled_bwd', 'fused_bwd'], 24 | # line styles 25 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 26 | ('cyan', ':'), ('yellow', 'dotted'), ('cyan', '--'), ('cyan', '-'), ('black', ':')], 27 | ylabel="Execution Time (ms)", # label name for the y-axis 28 | # name for the plot. Used also as a file name for saving the plot. 29 | plot_name="Performance", 30 | args={}, 31 | ) 32 | ) 33 | def benchmark(B, H, D, T, provider): 34 | from fla.utils import device 35 | dtype = torch.bfloat16 36 | requires_grad = True 37 | x = torch.randn(B * T, D, device=device, requires_grad=requires_grad, dtype=dtype) 38 | 39 | quantiles = [0.5, 0.2, 0.8] 40 | results = 0, 0, 0 41 | if provider.startswith('naive'): 42 | norm = partial(F.normalize, dim=-1, p=2) 43 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 44 | if provider.startswith('compiled'): 45 | norm = torch.compile(partial(F.normalize, dim=-1, p=2)) 46 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 47 | if provider.startswith('fused'): 48 | norm = l2norm 49 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 50 | if provider.startswith('naive_bwd'): 51 | norm = partial(F.normalize, dim=-1, p=2) 52 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 53 | if provider.startswith('compiled_bwd'): 54 | norm = torch.compile(partial(F.normalize, dim=-1, p=2)) 55 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 56 | if provider.startswith('fused_bwd'): 57 | norm = l2norm 58 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 59 | return results 60 | 61 | 62 | if __name__ == '__main__': 63 | benchmark.run(print_data=True) 64 | -------------------------------------------------------------------------------- /benchmarks/modules/benchmark_layernorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import triton 6 | 7 | from fla.modules import GroupNorm, LayerNorm 8 | 9 | 10 | @triton.testing.perf_report( 11 | triton.testing.Benchmark( 12 | # argument names to use as an x-axis for the plot 13 | x_names=['T'], 14 | # different possible values for `x_name` 15 | x_vals=[128 * 2 ** i for i in range(0, 8)], 16 | # argument name whose value corresponds to a different line in the plot 17 | line_arg='provider', 18 | # possible values for `line_arg`` 19 | line_vals=['naive_ln', 'fused_ln', 'naive_gn', 'fused_gn', 20 | 'naive_ln_bwd', 'fused_ln_bwd', 'naive_gn_bwd', 'fused_gn_bwd'], 21 | # label name for the lines 22 | line_names=['naive_ln', 'fused_ln', 'naive_gn', 'fused_gn', 23 | 'naive_ln_bwd', 'fused_ln_bwd', 'naive_gn_bwd', 'fused_gn_bwd'], 24 | # line styles 25 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 26 | ('cyan', ':'), ('yellow', 'dotted'), 
('cyan', '--'), ('cyan', '-'), ('black', ':')], 27 | ylabel="Execution Time (ms)", # label name for the y-axis 28 | # name for the plot. Used also as a file name for saving the plot. 29 | plot_name="Performance", 30 | args={}, 31 | ) 32 | ) 33 | def benchmark(T, provider): 34 | from fla.utils import device 35 | dtype = torch.bfloat16 36 | requires_grad = True 37 | B, D = 16, 1024 38 | 39 | x = torch.randn(B * T, D, device=device, requires_grad=requires_grad, dtype=dtype) 40 | 41 | quantiles = [0.5, 0.2, 0.8] 42 | results = 0, 0, 0 43 | if provider.startswith('naive_ln'): 44 | norm = nn.LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 45 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 46 | if provider.startswith('fused_ln'): 47 | norm = LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 48 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 49 | if provider.startswith('naive_gn'): 50 | norm = nn.GroupNorm(4, D).to(device=device, dtype=dtype) 51 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 52 | if provider.startswith('fused_gn'): 53 | norm = GroupNorm(4, D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 54 | results = triton.testing.do_bench(lambda: norm(x), quantiles=quantiles) 55 | if provider.startswith('naive_ln_bwd'): 56 | norm = nn.LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 57 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 58 | if provider.startswith('fused_ln_bwd'): 59 | norm = LayerNorm(D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 60 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 61 | if provider.startswith('naive_gn_bwd'): 62 | norm = nn.GroupNorm(4, D).to(device=device, dtype=dtype) 63 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 64 | if provider.startswith('fused_gn_bwd'): 65 | norm = GroupNorm(4, D, elementwise_affine=True, bias=True).to(device=device, dtype=dtype) 66 | results = triton.testing.do_bench(lambda: norm(x).backward(x), quantiles=quantiles) 67 | return results 68 | 69 | 70 | if __name__ == '__main__': 71 | benchmark.run(print_data=True) 72 | -------------------------------------------------------------------------------- /benchmarks/modules/benchmark_tokenshift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import triton 4 | 5 | from fla.modules.token_shift import token_shift 6 | 7 | 8 | def token_shift_ref(x): 9 | shifted = nn.functional.pad(x, (0, 0, 1, -1)) 10 | delta = shifted - x 11 | return delta 12 | 13 | 14 | @triton.testing.perf_report( 15 | triton.testing.Benchmark( 16 | # argument names to use as an x-axis for the plot 17 | x_names=['T'], 18 | # different possible values for `x_name` 19 | x_vals=[128 * 2 ** i for i in range(0, 8)], 20 | # argument name whose value corresponds to a different line in the plot 21 | line_arg='provider', 22 | # possible values for `line_arg`` 23 | line_vals=['naive_token_shift', 'fused_token_shift', 'naive_token_shift_bwd', 'fused_token_shift_bwd'], 24 | # label name for the lines 25 | line_names=['naive_token_shift', 'fused_token_shift', 'naive_token_shift_bwd', 'fused_token_shift_bwd'], 26 | # line styles 27 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 28 | ('cyan', ':')], 29 | ylabel="Execution Time 
(ms)", # label name for the y-axis 30 | # name for the plot. Used also as a file name for saving the plot. 31 | plot_name="Performance", 32 | args={}, 33 | ) 34 | ) 35 | def benchmark(T, provider): 36 | from fla.utils import device 37 | dtype = torch.bfloat16 38 | requires_grad = True 39 | B, D = 8, 4096 40 | 41 | x = torch.randn(B, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 42 | 43 | quantiles = [0.5, 0.2, 0.8] 44 | results = 0, 0, 0 45 | if provider.startswith('naive_token_shift'): 46 | results = triton.testing.do_bench(lambda: token_shift_ref(x), quantiles=quantiles) 47 | if provider.startswith('fused_token_shift'): 48 | results = triton.testing.do_bench(lambda: token_shift(x), quantiles=quantiles) 49 | if provider.startswith('naive_token_shift_bwd'): 50 | grad_output = torch.randn_like(x) 51 | results = triton.testing.do_bench(lambda: token_shift_ref(x).backward(grad_output), quantiles=quantiles) 52 | if provider.startswith('fused_token_shift_bwd'): 53 | grad_output = torch.randn_like(x) 54 | results = triton.testing.do_bench(lambda: token_shift(x).backward(grad_output), quantiles=quantiles) 55 | return results 56 | 57 | 58 | if __name__ == '__main__': 59 | benchmark.run(print_data=True) 60 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | from torch.nn import functional as F 6 | 7 | from fla.ops.abc import chunk_abc 8 | from fla.ops.gla import chunk_gla 9 | from fla.ops.retention import chunk_retention 10 | 11 | try: 12 | from flash_attn import flash_attn_func 13 | HAS_FLASH = True 14 | except BaseException: 15 | HAS_FLASH = False 16 | 17 | 18 | @triton.testing.perf_report( 19 | triton.testing.Benchmark( 20 | # argument names to use as an x-axis for the plot 21 | x_names=['T'], 22 | # different possible values for `x_name` 23 | x_vals=[128 * 2 ** i for i in range(0, 8)], 24 | # argument name whose value corresponds to a different line in the plot 25 | line_arg='provider', 26 | # possible values for `line_arg`` 27 | line_vals=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 28 | # label name for the lines 29 | line_names=['abc', 'gla', 'abc_bwd', 'gla_bwd', 'retention_bwd', 'flash_bwd'], 30 | # line styles 31 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), 32 | ('cyan', ':'), ('yellow', 'dotted'), ('black', ':')], 33 | ylabel="Execution Time (ms)", # label name for the y-axis 34 | # name for the plot. Used also as a file name for saving the plot. 
35 | plot_name="Performance", 36 | args={}, 37 | ) 38 | ) 39 | def benchmark(T, provider): 40 | from fla.utils import device 41 | dtype = torch.bfloat16 42 | requires_grad = True 43 | B, H, D, M = 16, 4, 128, 64 44 | 45 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 46 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 47 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 48 | if provider.startswith('flash'): 49 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 50 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 51 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 52 | if provider.startswith('gla'): 53 | g = F.logsigmoid(torch.randn(B, H, T, D, device=device, dtype=dtype)) 54 | g = g.clamp_min(-5).requires_grad_(requires_grad) 55 | if provider.startswith('abc'): 56 | s = torch.randn(B, H, T, M, device=device, requires_grad=requires_grad, dtype=dtype) 57 | 58 | do = torch.ones_like(v, dtype=dtype) 59 | 60 | quantiles = [0.5, 0.2, 0.8] 61 | if provider == 'abc': 62 | results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, s), quantiles=quantiles) 63 | elif provider == 'gla': 64 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g), quantiles=quantiles) 65 | elif provider == 'abc_bwd': 66 | results = triton.testing.do_bench(lambda: chunk_abc(q, k, v, s)[0].backward(do), quantiles=quantiles) 67 | elif provider == 'gla_bwd': 68 | results = triton.testing.do_bench(lambda: chunk_gla(q, k, v, g)[0].backward(do), quantiles=quantiles) 69 | elif provider == 'retention_bwd': 70 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 71 | elif provider == 'flash_bwd': 72 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 73 | return results 74 | 75 | 76 | if __name__ == '__main__': 77 | benchmark.run(print_data=True) 78 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | 6 | from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn 7 | 8 | 9 | @triton.testing.perf_report( 10 | triton.testing.Benchmark( 11 | # argument names to use as an x-axis for the plot 12 | x_names=['T'], 13 | # different possible values for `x_name` 14 | x_vals=[128 * 2 ** i for i in range(0, 8)], 15 | # argument name whose value corresponds to a different line in the plot 16 | line_arg='provider', 17 | # possible values for `line_arg`` 18 | line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'], 19 | # label name for the lines 20 | line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'], 21 | # line styles 22 | styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')], 23 | ylabel="Execution Time (ms)", # label name for the y-axis 24 | # name for the plot. Used also as a file name for saving the plot. 
25 | plot_name="Performance", 26 | args={}, 27 | ) 28 | ) 29 | def benchmark(T, provider): 30 | from fla.utils import device 31 | dtype = torch.bfloat16 32 | B, D = 16, 512 33 | 34 | x = torch.randn((B, T, D), dtype=dtype, device=device) 35 | g = torch.randn((B, T, D), dtype=dtype, device=device).sigmoid() 36 | x = (1 - g) * x 37 | x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g)) 38 | do = torch.randn_like(x, dtype=dtype) 39 | quantiles = [0.5, 0.2, 0.8] 40 | results = 0, 0, 0 41 | if provider == 'chunk': 42 | results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles) 43 | if provider == 'recurrent': 44 | results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles) 45 | if provider == 'chunk_bwd': 46 | results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles) 47 | if provider == 'recurrent_bwd': 48 | results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles) 49 | return results 50 | 51 | 52 | if __name__ == '__main__': 53 | benchmark.run(print_data=True) 54 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_nsa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import triton 5 | from flash_attn import flash_attn_func 6 | 7 | from fla.ops.nsa import parallel_nsa 8 | 9 | 10 | @triton.testing.perf_report( 11 | triton.testing.Benchmark( 12 | # argument names to use as an x-axis for the plot 13 | x_names=['T'], 14 | # different possible values for `x_name` 15 | x_vals=[128 * 2 ** i for i in range(0, 8)], 16 | # argument name whose value corresponds to a different line in the plot 17 | line_arg='provider', 18 | # possible values for `line_arg`` 19 | line_vals=['nsa', 'nsa_bwd', 'flash', 'flash_bwd'], 20 | # label name for the lines 21 | line_names=['nsa', 'nsa_bwd', 'flash', 'flash_bwd'], 22 | # line styles 23 | styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('green', 'dotted'), 24 | ('blue', 'dotted'), ('red', 'dotted'), ('cyan', '-'), ('cyan', 'dotted')], 25 | ylabel="Execution Time (ms)", # label name for the y-axis 26 | # name for the plot. Used also as a file name for saving the plot. 
27 | plot_name="Performance", 28 | args={}, 29 | ) 30 | ) 31 | def benchmark(T, provider): 32 | from fla.utils import device 33 | dtype = torch.bfloat16 34 | requires_grad = True 35 | B, H, HQ, D, S = 4, 4, 64, 128, 16 36 | block_size = 64 37 | 38 | q = torch.randn(B, T, HQ, D, device=device, requires_grad=requires_grad, dtype=dtype) 39 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 40 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 41 | do = torch.ones_like(q, dtype=dtype) 42 | 43 | indices = torch.full((B, T, H, S), T, dtype=torch.long, device=device) 44 | for b in range(B): 45 | for t in range(T): 46 | for h in range(H): 47 | i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S] 48 | indices[b, t, h, :len(i_i)] = i_i 49 | indices = indices.sort(-1)[0] 50 | 51 | quantiles = [0.5, 0.2, 0.8] 52 | results = 0, 0, 0 53 | if provider == 'nsa': 54 | results = triton.testing.do_bench( 55 | lambda: parallel_nsa(q, k, v, indices, block_size), 56 | quantiles=quantiles 57 | ) 58 | elif provider == 'nsa_bwd': 59 | results = triton.testing.do_bench( 60 | lambda: parallel_nsa(q, k, v, indices, block_size).backward(do), 61 | quantiles=quantiles 62 | ) 63 | elif provider == 'flash': 64 | results = triton.testing.do_bench( 65 | lambda: flash_attn_func(q, k, v, causal=True), 66 | quantiles=quantiles 67 | ) 68 | elif provider == 'flash_bwd': 69 | results = triton.testing.do_bench( 70 | lambda: flash_attn_func(q, k, v, causal=True).backward(do), 71 | quantiles=quantiles 72 | ) 73 | return results 74 | 75 | 76 | if __name__ == '__main__': 77 | benchmark.run(print_data=True, save_path='.') 78 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_retention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import torch 6 | import triton 7 | from flash_attn import flash_attn_func 8 | 9 | from fla.ops.retention import chunk_retention, parallel_retention 10 | 11 | 12 | @triton.testing.perf_report( 13 | triton.testing.Benchmark( 14 | # argument names to use as an x-axis for the plot 15 | x_names=['T'], 16 | # different possible values for `x_name` 17 | x_vals=[128 * 2 ** i for i in range(0, 8)], 18 | # argument name whose value corresponds to a different line in the plot 19 | line_arg='provider', 20 | # possible values for `line_arg`` 21 | line_vals=['chunk', 'parallel', 'flash', 'chunk_bwd', 'parallel_bwd', 'flash_bwd'], 22 | # label name for the lines 23 | line_names=['chunk_fwd', 'parallel_fwd', 'flash_fwd', 'chunk_fwdbwd', 'parallel_fwdbwd', 'flash_fwdbwd'], 24 | # line styles 25 | styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('green', 'dotted'), ('blue', 'dotted'), ('red', 'dotted')], 26 | ylabel="Execution Time (ms)", # label name for the y-axis 27 | # name for the plot. Used also as a file name for saving the plot. 
28 | plot_name="Performance", 29 | args={}, 30 | ) 31 | ) 32 | def benchmark(T, provider): 33 | from fla.utils import device 34 | dtype = torch.bfloat16 35 | requires_grad = True 36 | B, H, D = 4, 8, 256 37 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 38 | 39 | if provider == 'flash' or provider == 'flash_bwd': 40 | q = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 41 | k = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 42 | v = torch.randn(B, T, H, D, device=device, requires_grad=requires_grad, dtype=dtype) 43 | else: 44 | q = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 45 | k = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 46 | v = torch.randn(B, H, T, D, device=device, requires_grad=requires_grad, dtype=dtype) 47 | do = torch.ones_like(q, dtype=dtype) 48 | 49 | quantiles = [0.5, 0.2, 0.8] 50 | results = 0, 0, 0 51 | if provider == 'chunk': 52 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v), quantiles=quantiles) 53 | elif provider == 'parallel': 54 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v), quantiles=quantiles) 55 | elif provider == 'flash': 56 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True), quantiles=quantiles) 57 | elif provider == 'chunk_bwd': 58 | results = triton.testing.do_bench(lambda: chunk_retention(q, k, v)[0].backward(do), quantiles=quantiles) 59 | elif provider == 'parallel_bwd': 60 | results = triton.testing.do_bench(lambda: parallel_retention(q, k, v)[0].backward(do), quantiles=quantiles) 61 | elif provider == 'flash_bwd': 62 | results = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True).backward(do), quantiles=quantiles) 63 | return results 64 | 65 | 66 | if __name__ == '__main__': 67 | benchmark.run(print_data=True, save_path='.') 68 | -------------------------------------------------------------------------------- /benchmarks/ops/benchmark_simple_gla_vs_mamba2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dependencies: 3 | $ pip install mamba-ssm==2.2.2 triton==2.3.1 4 | 5 | For correctness check, see: 6 | https://github.com/sustcsonglin/flash-linear-attention/pull/49 7 | """ 8 | 9 | import torch 10 | import triton 11 | from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined 12 | 13 | from fla.ops.simple_gla import chunk_simple_gla 14 | 15 | 16 | @triton.testing.perf_report( 17 | triton.testing.Benchmark( 18 | # argument names to use as an x-axis for the plot 19 | x_names=['T'], 20 | # different possible values for `x_name` 21 | x_vals=[64] + [128 * 2 ** i for i in range(0, 8)], 22 | # argument name whose value corresponds to a different line in the plot 23 | line_arg='provider', 24 | # possible values for `line_arg`` 25 | line_vals=["chunk_simple_gla", "mamba2_ssd"], 26 | # label name for the lines 27 | line_names=["chunk_simple_gla", "mamba2_ssd"], 28 | # line styles 29 | styles=[('blue', '-'), ('red', '-')], 30 | ylabel="Execution Time (ms)", # label name for the y-axis 31 | # name for the plot. Used also as a file name for saving the plot. 
32 | plot_name="Performance", 33 | args={}, 34 | ) 35 | ) 36 | def benchmark(T, provider): 37 | # TODO: also add bwd pass benchmark 38 | from fla.utils import device 39 | dtype = torch.bfloat16 40 | B, H, D = 16, 8, 128 41 | # TODO: test more shapes 42 | # TODO: different values for D_V and D_QK 43 | # TODO: different values for H_Q and H_KV 44 | final_state = False # does not impact performance 45 | 46 | # initialize Mamba2-format inputs 47 | X_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device) 48 | dt_mamba = torch.ones(B, T, H, dtype=dtype, device=device) 49 | A_mamba = -0.1 * torch.rand(H, dtype=dtype, device=device) 50 | B_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device) 51 | C_mamba = 0.1 * torch.randn(B, T, H, D, dtype=dtype, device=device) 52 | 53 | quantiles = [0.5, 0.2, 0.8] 54 | if provider == 'chunk_simple_gla': 55 | # mapping inputs Mamba2 -> FLA 56 | # C, B, X: [B, T, H, D] -> [B, H, T, D] 57 | # g: [B, T, H] -> [B, H, T] 58 | q = C_mamba.transpose(1, 2).contiguous() 59 | k = B_mamba.transpose(1, 2).contiguous() 60 | v = X_mamba.transpose(1, 2).contiguous() 61 | g = (A_mamba * dt_mamba).transpose(1, 2).contiguous() 62 | # NOTE: whether to include the memory-copy cost of `contiguous()`? 63 | # this depends on the memory layout used by surrounding non-SSM layers 64 | 65 | results = triton.testing.do_bench( 66 | lambda: chunk_simple_gla( 67 | q, k, v, g, scale=1.0, output_final_state=final_state 68 | ), quantiles=quantiles 69 | ) 70 | 71 | elif provider == 'mamba2_ssd': 72 | # NOTE: `chunk_size` is configurable in mamba2 kernel 73 | # here sets to the same hard-coded `BT = 64` as in simple_gla kernel 74 | # TODO: benchmark different chunk sizes 75 | results = triton.testing.do_bench( 76 | lambda: mamba_chunk_scan_combined( 77 | X_mamba, dt_mamba, A_mamba, B_mamba, C_mamba, 78 | chunk_size=64, D=None, return_final_states=final_state 79 | ), 80 | quantiles=quantiles 81 | ) 82 | return results 83 | 84 | 85 | if __name__ == '__main__': 86 | benchmark.run(print_data=True, save_path='.') 87 | -------------------------------------------------------------------------------- /evals/harness.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import fla # noqa 6 | from lm_eval.__main__ import cli_evaluate 7 | from lm_eval.api.registry import register_model 8 | from lm_eval.models.huggingface import HFLM 9 | 10 | 11 | @register_model('fla') 12 | class FlashLinearAttentionLMWrapper(HFLM): 13 | def __init__(self, **kwargs) -> FlashLinearAttentionLMWrapper: 14 | 15 | # TODO: provide options for doing inference with different kernels 16 | 17 | super().__init__(**kwargs) 18 | 19 | 20 | if __name__ == "__main__": 21 | cli_evaluate() 22 | -------------------------------------------------------------------------------- /fla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.layers import ( 4 | ABCAttention, 5 | Attention, 6 | BasedLinearAttention, 7 | BitAttention, 8 | DeltaNet, 9 | GatedDeltaNet, 10 | GatedDeltaProduct, 11 | GatedLinearAttention, 12 | GatedSlotAttention, 13 | HGRN2Attention, 14 | HGRNAttention, 15 | LightNetAttention, 16 | LinearAttention, 17 | MesaNet, 18 | MultiScaleRetention, 19 | NativeSparseAttention, 20 | PaTHAttention, 21 | ReBasedLinearAttention, 22 | RWKV6Attention, 23 | RWKV7Attention 24 | ) 25 | from fla.models import ( 26 | ABCForCausalLM, 
27 | ABCModel, 28 | BitNetForCausalLM, 29 | BitNetModel, 30 | DeltaNetForCausalLM, 31 | DeltaNetModel, 32 | GatedDeltaNetForCausalLM, 33 | GatedDeltaNetModel, 34 | GatedDeltaProductForCausalLM, 35 | GatedDeltaProductModel, 36 | GLAForCausalLM, 37 | GLAModel, 38 | GSAForCausalLM, 39 | GSAModel, 40 | HGRN2ForCausalLM, 41 | HGRN2Model, 42 | HGRNForCausalLM, 43 | HGRNModel, 44 | LightNetForCausalLM, 45 | LightNetModel, 46 | LinearAttentionForCausalLM, 47 | LinearAttentionModel, 48 | MesaNetConfig, 49 | MesaNetForCausalLM, 50 | MesaNetModel, 51 | NSAForCausalLM, 52 | NSAModel, 53 | PaTHAttentionForCausalLM, 54 | PaTHAttentionModel, 55 | RetNetForCausalLM, 56 | RetNetModel, 57 | RodimusForCausalLM, 58 | RodimusModel, 59 | RWKV6ForCausalLM, 60 | RWKV6Model, 61 | RWKV7ForCausalLM, 62 | RWKV7Model, 63 | TransformerForCausalLM, 64 | TransformerModel 65 | ) 66 | 67 | __all__ = [ 68 | 'ABCAttention', 69 | 'Attention', 70 | 'BasedLinearAttention', 71 | 'BitAttention', 72 | 'DeltaNet', 73 | 'GatedDeltaNet', 74 | 'GatedDeltaProduct', 75 | 'GatedLinearAttention', 76 | 'GatedSlotAttention', 77 | 'HGRNAttention', 78 | 'HGRN2Attention', 79 | 'LightNetAttention', 80 | 'LinearAttention', 81 | 'MultiScaleRetention', 82 | 'NativeSparseAttention', 83 | 'PaTHAttention', 84 | 'ReBasedLinearAttention', 85 | 'RWKV6Attention', 86 | 'RWKV7Attention', 87 | 'ABCForCausalLM', 88 | 'ABCModel', 89 | 'BitNetForCausalLM', 90 | 'BitNetModel', 91 | 'DeltaNetForCausalLM', 92 | 'DeltaNetModel', 93 | 'GatedDeltaNetForCausalLM', 94 | 'GatedDeltaNetModel', 95 | 'GatedDeltaProductForCausalLM', 96 | 'GatedDeltaProductModel', 97 | 'GLAForCausalLM', 98 | 'GLAModel', 99 | 'GSAForCausalLM', 100 | 'GSAModel', 101 | 'HGRNForCausalLM', 102 | 'HGRNModel', 103 | 'HGRN2ForCausalLM', 104 | 'HGRN2Model', 105 | 'LightNetForCausalLM', 106 | 'LightNetModel', 107 | 'LinearAttentionForCausalLM', 108 | 'LinearAttentionModel', 109 | 'NSAForCausalLM', 110 | 'NSAModel', 111 | 'PaTHAttentionForCausalLM', 112 | 'PaTHAttentionModel', 113 | 'RetNetForCausalLM', 114 | 'RetNetModel', 115 | 'RWKV6ForCausalLM', 116 | 'RWKV6Model', 117 | 'RWKV7ForCausalLM', 118 | 'RWKV7Model', 119 | 'RodimusForCausalLM', 120 | 'RodimusModel', 121 | 'TransformerForCausalLM', 122 | 'TransformerModel', 123 | 'MesaNetConfig', 124 | 'MesaNetForCausalLM', 125 | 'MesaNetModel', 126 | 'MesaNet' 127 | ] 128 | 129 | __version__ = '0.7' 130 | -------------------------------------------------------------------------------- /fla/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from .abc import ABCAttention 5 | from .attn import Attention 6 | from .based import BasedLinearAttention 7 | from .bitattn import BitAttention 8 | from .delta_net import DeltaNet 9 | from .forgetting_attn import ForgettingAttention 10 | from .gated_deltanet import GatedDeltaNet 11 | from .gated_deltaproduct import GatedDeltaProduct 12 | from .gla import GatedLinearAttention 13 | from .gsa import GatedSlotAttention 14 | from .hgrn import HGRNAttention 15 | from .hgrn2 import HGRN2Attention 16 | from .lightnet import LightNetAttention 17 | from .linear_attn import LinearAttention 18 | from .mamba import Mamba 19 | from .mamba2 import Mamba2 20 | from .mesa_net import MesaNet 21 | from .multiscale_retention import MultiScaleRetention 22 | from .nsa import NativeSparseAttention 23 | from .path_attn import PaTHAttention 24 | from .rebased import ReBasedLinearAttention 25 | from 
.rodimus import RodimusAttention, SlidingWindowSharedKeyAttention 26 | from .rwkv6 import RWKV6Attention 27 | from .rwkv7 import RWKV7Attention 28 | 29 | __all__ = [ 30 | 'ABCAttention', 31 | 'Attention', 32 | 'BasedLinearAttention', 33 | 'BitAttention', 34 | 'DeltaNet', 35 | 'ForgettingAttention', 36 | 'GatedDeltaNet', 37 | 'GatedDeltaProduct', 38 | 'GatedLinearAttention', 39 | 'GatedSlotAttention', 40 | 'HGRNAttention', 41 | 'HGRN2Attention', 42 | 'LightNetAttention', 43 | 'LinearAttention', 44 | 'Mamba', 45 | 'Mamba2', 46 | 'MultiScaleRetention', 47 | 'NativeSparseAttention', 48 | 'ReBasedLinearAttention', 49 | 'RWKV6Attention', 50 | 'RWKV7Attention', 51 | 'RodimusAttention', 52 | 'SlidingWindowSharedKeyAttention', 53 | 'PaTHAttention', 54 | 'MesaNet' 55 | ] 56 | -------------------------------------------------------------------------------- /fla/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel 4 | from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel 5 | from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel 6 | from fla.models.forgetting_transformer import ( 7 | ForgettingTransformerConfig, 8 | ForgettingTransformerForCausalLM, 9 | ForgettingTransformerModel 10 | ) 11 | from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel 12 | from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel 13 | from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel 14 | from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel 15 | from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel 16 | from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model 17 | from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel 18 | from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel 19 | from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel 20 | from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model 21 | from fla.models.mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetModel 22 | from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel 23 | from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel 24 | from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel 25 | from fla.models.rodimus import RodimusConfig, RodimusForCausalLM, RodimusModel, RodimusTokenizer 26 | from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model 27 | from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model 28 | from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel 29 | from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel 30 | 31 | __all__ = [ 32 | 'ABCConfig', 'ABCForCausalLM', 'ABCModel', 33 | 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel', 34 | 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', 35 | 'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel', 36 | 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel', 37 | 'GLAConfig', 'GLAForCausalLM', 'GLAModel', 38 | 'GSAConfig', 'GSAForCausalLM', 'GSAModel', 39 | 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', 40 | 
'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', 41 | 'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel', 42 | 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 43 | 'MambaConfig', 'MambaForCausalLM', 'MambaModel', 44 | 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model', 45 | 'NSAConfig', 'NSAForCausalLM', 'NSAModel', 46 | 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', 47 | 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', 48 | 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model', 49 | 'SambaConfig', 'SambaForCausalLM', 'SambaModel', 50 | 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel', 51 | 'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel', 52 | 'RodimusConfig', 'RodimusForCausalLM', 'RodimusModel', 'RodimusTokenizer', 53 | 'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel', 54 | 'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel', 55 | ] 56 | -------------------------------------------------------------------------------- /fla/models/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.abc.configuration_abc import ABCConfig 6 | from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel 7 | 8 | AutoConfig.register(ABCConfig.model_type, ABCConfig, exist_ok=True) 9 | AutoModel.register(ABCConfig, ABCModel, exist_ok=True) 10 | AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] 14 | -------------------------------------------------------------------------------- /fla/models/abc/configuration_abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class ABCConfig(PretrainedConfig): 9 | 10 | model_type = 'abc' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | gate_low_rank_dim: int = 16, 17 | clamp_min: float = -32, 18 | clamp_max: float = 32, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | num_slots: Optional[int] = 64, 24 | use_short_conv: bool = False, 25 | conv_size: int = 4, 26 | exapnd_k: float = 0.5, 27 | exapnd_v: float = 1, 28 | hidden_act: str = "swish", 29 | max_position_embeddings: int = 2048, 30 | elementwise_affine: Optional[bool] = True, 31 | norm_eps: float = 1e-6, 32 | use_rope: bool = True, 33 | attn: Optional[Dict] = None, 34 | use_cache: bool = True, 35 | pad_token_id: int = None, 36 | bos_token_id: int = 1, 37 | eos_token_id: int = 2, 38 | tie_word_embeddings: bool = False, 39 | initializer_range: float = 0.02, 40 | fuse_norm: bool = True, 41 | fuse_swiglu: bool = True, 42 | fuse_cross_entropy: bool = True, 43 | vocab_size: int = 32000, 44 | **kwargs 45 | ): 46 | self.hidden_size = hidden_size 47 | self.gate_low_rank_dim = gate_low_rank_dim 48 | self.clamp_min = clamp_min 49 | self.clamp_max = clamp_max 50 | self.hidden_ratio = hidden_ratio 51 | self.intermediate_size = intermediate_size 52 | self.num_hidden_layers = num_hidden_layers 53 | self.num_heads = num_heads 54 | self.num_slots = num_slots 55 | self.use_short_conv = use_short_conv 56 | 
self.conv_size = conv_size 57 | self.expand_k = exapnd_k 58 | self.expand_v = exapnd_v 59 | self.hidden_act = hidden_act 60 | self.max_position_embeddings = max_position_embeddings 61 | self.elementwise_affine = elementwise_affine 62 | self.norm_eps = norm_eps 63 | self.use_rope = use_rope 64 | self.attn = attn 65 | self.use_cache = use_cache 66 | self.initializer_range = initializer_range 67 | 68 | self.fuse_norm = fuse_norm 69 | self.fuse_swiglu = fuse_swiglu 70 | self.fuse_cross_entropy = fuse_cross_entropy 71 | self.vocab_size = vocab_size 72 | 73 | if attn is not None: 74 | if not isinstance(attn, Dict): 75 | raise ValueError("attn must be a dictionary") 76 | if 'layers' not in attn: 77 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 78 | if 'num_heads' not in attn: 79 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 80 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 81 | attn['qkv_bias'] = attn.get('qkv_bias', False) 82 | attn['window_size'] = attn.get('window_size', None) 83 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 84 | 85 | super().__init__( 86 | pad_token_id=pad_token_id, 87 | bos_token_id=bos_token_id, 88 | eos_token_id=eos_token_id, 89 | tie_word_embeddings=tie_word_embeddings, 90 | **kwargs, 91 | ) 92 | -------------------------------------------------------------------------------- /fla/models/bitnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.bitnet.configuration_bitnet import BitNetConfig 6 | from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel 7 | 8 | AutoConfig.register(BitNetConfig.model_type, BitNetConfig, exist_ok=True) 9 | AutoModel.register(BitNetConfig, BitNetModel, exist_ok=True) 10 | AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel'] 14 | -------------------------------------------------------------------------------- /fla/models/bitnet/configuration_bitnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class BitNetConfig(PretrainedConfig): 9 | 10 | model_type = 'bitnet' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | num_hidden_layers: int = 24, 17 | num_heads: int = 32, 18 | num_kv_heads: int = None, 19 | window_size: Optional[int] = None, 20 | rope_theta: Optional[float] = 10000., 21 | max_position_embeddings: int = 2048, 22 | hidden_ratio: Optional[int] = 4, 23 | intermediate_size: Optional[int] = None, 24 | hidden_act: str = "swish", 25 | initializer_range: float = 0.02, 26 | elementwise_affine: Optional[bool] = True, 27 | norm_eps: float = 1e-6, 28 | use_cache: bool = True, 29 | pad_token_id: int = None, 30 | bos_token_id: int = 1, 31 | eos_token_id: int = 2, 32 | tie_word_embeddings: bool = False, 33 | fuse_norm: bool = True, 34 | fuse_swiglu: bool = True, 35 | fuse_cross_entropy: bool = True, 36 | vocab_size: int = 32000, 37 | **kwargs, 38 | ): 39 | self.hidden_size = hidden_size 40 | self.num_hidden_layers = num_hidden_layers 41 | self.num_heads = num_heads 42 | self.num_kv_heads 
= num_kv_heads 43 | self.window_size = window_size 44 | self.rope_theta = rope_theta 45 | self.max_position_embeddings = max_position_embeddings 46 | 47 | self.hidden_ratio = hidden_ratio 48 | self.intermediate_size = intermediate_size 49 | self.hidden_act = hidden_act 50 | 51 | self.initializer_range = initializer_range 52 | self.elementwise_affine = elementwise_affine 53 | self.norm_eps = norm_eps 54 | self.use_cache = use_cache 55 | 56 | self.fuse_norm = fuse_norm 57 | self.fuse_swiglu = fuse_swiglu 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | self.vocab_size = vocab_size 60 | 61 | super().__init__( 62 | pad_token_id=pad_token_id, 63 | bos_token_id=bos_token_id, 64 | eos_token_id=eos_token_id, 65 | tie_word_embeddings=tie_word_embeddings, 66 | **kwargs, 67 | ) 68 | -------------------------------------------------------------------------------- /fla/models/delta_net/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.delta_net.configuration_delta_net import DeltaNetConfig 6 | from fla.models.delta_net.modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel 7 | 8 | AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig, exist_ok=True) 9 | AutoModel.register(DeltaNetConfig, DeltaNetModel, exist_ok=True) 10 | AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM, exist_ok=True) 11 | 12 | __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel'] 13 | -------------------------------------------------------------------------------- /fla/models/delta_net/configuration_delta_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class DeltaNetConfig(PretrainedConfig): 9 | 10 | model_type = 'delta_net' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | use_gate: bool = False, 20 | use_short_conv: bool = True, 21 | conv_size: int = 4, 22 | use_beta: bool = True, 23 | use_output_norm: bool = True, 24 | num_heads: int = 16, 25 | qk_norm: str = 'l2', 26 | qk_activation: str = 'silu', 27 | max_position_embeddings: int = 2048, 28 | hidden_ratio: Optional[int] = 4, 29 | intermediate_size: Optional[int] = None, 30 | hidden_act: str = "swish", 31 | num_hidden_layers: int = 24, 32 | norm_eps: float = 1e-6, 33 | attn: Optional[Dict] = None, 34 | use_cache: bool = True, 35 | pad_token_id: int = None, 36 | bos_token_id: int = 1, 37 | eos_token_id: int = 2, 38 | tie_word_embeddings: bool = False, 39 | initializer_range: float = 0.02, 40 | fuse_norm: bool = True, 41 | fuse_swiglu: bool = True, 42 | fuse_cross_entropy: bool = True, 43 | vocab_size: int = 32000, 44 | **kwargs 45 | ): 46 | self.attn_mode = attn_mode 47 | self.hidden_size = hidden_size 48 | self.expand_k = expand_k 49 | self.expand_v = expand_v 50 | self.use_gate = use_gate 51 | self.use_short_conv = use_short_conv 52 | self.conv_size = conv_size 53 | self.use_beta = use_beta 54 | self.use_output_norm = use_output_norm 55 | self.num_heads = num_heads 56 | self.qk_norm = qk_norm 57 | self.qk_activation = qk_activation 58 | self.max_position_embeddings = max_position_embeddings 59 | 60 | self.hidden_ratio 
= hidden_ratio 61 | self.intermediate_size = intermediate_size 62 | self.hidden_act = hidden_act 63 | self.num_hidden_layers = num_hidden_layers 64 | self.norm_eps = norm_eps 65 | self.attn = attn 66 | self.use_cache = use_cache 67 | self.initializer_range = initializer_range 68 | self.fuse_norm = fuse_norm 69 | self.fuse_swiglu = fuse_swiglu 70 | self.fuse_cross_entropy = fuse_cross_entropy 71 | self.vocab_size = vocab_size 72 | 73 | if attn is not None: 74 | if not isinstance(attn, Dict): 75 | raise ValueError("attn must be a dictionary") 76 | if 'layers' not in attn: 77 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 78 | if 'num_heads' not in attn: 79 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 80 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 81 | attn['qkv_bias'] = attn.get('qkv_bias', False) 82 | attn['window_size'] = attn.get('window_size', None) 83 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 84 | 85 | super().__init__( 86 | pad_token_id=pad_token_id, 87 | bos_token_id=bos_token_id, 88 | eos_token_id=eos_token_id, 89 | tie_word_embeddings=tie_word_embeddings, 90 | **kwargs, 91 | ) 92 | -------------------------------------------------------------------------------- /fla/models/forgetting_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig 6 | from fla.models.forgetting_transformer.modeling_forgetting_transformer import ( 7 | ForgettingTransformerForCausalLM, 8 | ForgettingTransformerModel 9 | ) 10 | 11 | AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig, exist_ok=True) 12 | AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel, exist_ok=True) 13 | AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM, exist_ok=True) 14 | 15 | 16 | __all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel'] 17 | -------------------------------------------------------------------------------- /fla/models/forgetting_transformer/configuration_forgetting_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class ForgettingTransformerConfig(PretrainedConfig): 9 | 10 | model_type = 'forgetting_transformer' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | num_hidden_layers: int = 24, 17 | num_heads: int = 32, 18 | num_kv_heads: Optional[int] = None, 19 | qkv_bias: bool = False, 20 | qk_norm: bool = False, 21 | window_size: Optional[int] = None, 22 | use_output_gate: bool = False, 23 | hidden_ratio: Optional[int] = 4, 24 | intermediate_size: Optional[int] = None, 25 | hidden_act: str = "swish", 26 | initializer_range: float = 0.02, 27 | elementwise_affine: Optional[bool] = True, 28 | norm_eps: float = 1e-6, 29 | use_cache: bool = True, 30 | pad_token_id: Optional[int] = None, 31 | bos_token_id: int = 1, 32 | eos_token_id: int = 2, 33 | tie_word_embeddings: bool = False, 34 | fuse_norm: bool = True, 35 | 
fuse_swiglu: bool = True, 36 | fuse_cross_entropy: bool = True, 37 | vocab_size: int = 32000, 38 | **kwargs, 39 | ): 40 | self.hidden_size = hidden_size 41 | self.num_hidden_layers = num_hidden_layers 42 | self.num_heads = num_heads 43 | self.num_kv_heads = num_kv_heads 44 | self.qkv_bias = qkv_bias 45 | self.qk_norm = qk_norm 46 | self.window_size = window_size 47 | self.use_output_gate = use_output_gate 48 | self.hidden_ratio = hidden_ratio 49 | self.intermediate_size = intermediate_size 50 | self.hidden_act = hidden_act 51 | 52 | self.initializer_range = initializer_range 53 | self.elementwise_affine = elementwise_affine 54 | self.norm_eps = norm_eps 55 | self.use_cache = use_cache 56 | 57 | self.fuse_norm = fuse_norm 58 | self.fuse_swiglu = fuse_swiglu 59 | self.fuse_cross_entropy = fuse_cross_entropy 60 | self.vocab_size = vocab_size 61 | 62 | super().__init__( 63 | pad_token_id=pad_token_id, 64 | bos_token_id=bos_token_id, 65 | eos_token_id=eos_token_id, 66 | tie_word_embeddings=tie_word_embeddings, 67 | **kwargs, 68 | ) 69 | -------------------------------------------------------------------------------- /fla/models/gated_deltanet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig 6 | from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel 7 | 8 | AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig, exist_ok=True) 9 | AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel, exist_ok=True) 10 | AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM, exist_ok=True) 11 | 12 | __all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel'] 13 | -------------------------------------------------------------------------------- /fla/models/gated_deltanet/configuration_gated_deltanet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class GatedDeltaNetConfig(PretrainedConfig): 9 | model_type = 'gated_deltanet' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | 12 | def __init__( 13 | self, 14 | attn_mode: str = "chunk", 15 | hidden_size: int = 2048, 16 | expand_v: int = 2, 17 | use_gate: bool = True, 18 | use_short_conv: bool = True, 19 | conv_size: int = 4, 20 | head_dim: int = 256, 21 | num_heads: int = 6, 22 | num_v_heads: Optional[int] = None, 23 | max_position_embeddings: int = 2048, 24 | hidden_ratio: Optional[int] = 4, 25 | intermediate_size: Optional[int] = None, 26 | hidden_act: str = "swish", 27 | num_hidden_layers: int = 21, 28 | norm_eps: float = 1e-6, 29 | attn: Optional[Dict] = None, 30 | use_cache: bool = True, 31 | pad_token_id: int = None, 32 | bos_token_id: int = 1, 33 | eos_token_id: int = 2, 34 | tie_word_embeddings: bool = False, 35 | initializer_range: float = 0.02, 36 | fuse_norm: bool = True, 37 | fuse_swiglu: bool = True, 38 | fuse_cross_entropy: bool = True, 39 | vocab_size: int = 32000, 40 | **kwargs 41 | ): 42 | self.attn_mode = attn_mode 43 | self.hidden_size = hidden_size 44 | self.expand_v = expand_v 45 | self.use_gate = use_gate 46 | self.use_short_conv = use_short_conv 47 | self.conv_size = conv_size 48 | 
self.head_dim = head_dim 49 | self.num_heads = num_heads 50 | self.num_v_heads = num_v_heads 51 | self.max_position_embeddings = max_position_embeddings 52 | 53 | self.hidden_ratio = hidden_ratio 54 | self.intermediate_size = intermediate_size 55 | self.hidden_act = hidden_act 56 | self.num_hidden_layers = num_hidden_layers 57 | self.norm_eps = norm_eps 58 | self.attn = attn 59 | self.use_cache = use_cache 60 | self.initializer_range = initializer_range 61 | 62 | self.fuse_norm = fuse_norm 63 | self.fuse_swiglu = fuse_swiglu 64 | self.fuse_cross_entropy = fuse_cross_entropy 65 | self.vocab_size = vocab_size 66 | 67 | if attn is not None: 68 | if not isinstance(attn, Dict): 69 | raise ValueError("attn must be a dictionary") 70 | if 'layers' not in attn: 71 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 72 | if 'num_heads' not in attn: 73 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 74 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 75 | attn['qkv_bias'] = attn.get('qkv_bias', False) 76 | attn['window_size'] = attn.get('window_size', None) 77 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 78 | 79 | super().__init__( 80 | pad_token_id=pad_token_id, 81 | bos_token_id=bos_token_id, 82 | eos_token_id=eos_token_id, 83 | tie_word_embeddings=tie_word_embeddings, 84 | **kwargs, 85 | ) 86 | -------------------------------------------------------------------------------- /fla/models/gated_deltaproduct/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | 3 | from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig 4 | from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel 5 | 6 | AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig, exist_ok=True) 7 | AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel, exist_ok=True) 8 | AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM, exist_ok=True) 9 | 10 | __all__ = [ 11 | "GatedDeltaProductConfig", 12 | "GatedDeltaProductForCausalLM", 13 | "GatedDeltaProductModel", 14 | ] 15 | -------------------------------------------------------------------------------- /fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class GatedDeltaProductConfig(PretrainedConfig): 9 | model_type = 'gated_deltaproduct' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | 12 | def __init__( 13 | self, 14 | attn_mode: str = "chunk", 15 | hidden_size: int = 2048, 16 | expand_v: int = 2, 17 | use_gate: bool = True, 18 | use_short_conv: bool = True, 19 | conv_size: int = 4, 20 | head_dim: int = 256, 21 | num_heads: int = 6, 22 | max_position_embeddings: int = 2048, 23 | hidden_ratio: Optional[int] = 4, 24 | intermediate_size: Optional[int] = None, 25 | hidden_act: str = "swish", 26 | num_hidden_layers: int = 21, 27 | norm_eps: float = 1e-6, 28 | attn: Optional[Dict] = None, 29 | use_cache: bool = True, 30 | pad_token_id: int = None, 31 | bos_token_id: int = 1, 32 | eos_token_id: int = 2, 33 | tie_word_embeddings: bool = False, 
34 | initializer_range: float = 0.02, 35 | fuse_norm: bool = True, 36 | fuse_swiglu: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | vocab_size: int = 32000, 39 | use_forget_gate: bool = False, 40 | allow_neg_eigval: bool = False, 41 | num_householder: int = 1, 42 | **kwargs, 43 | ): 44 | self.attn_mode = attn_mode 45 | self.hidden_size = hidden_size 46 | self.expand_v = expand_v 47 | self.use_gate = use_gate 48 | self.use_short_conv = use_short_conv 49 | self.conv_size = conv_size 50 | self.head_dim = head_dim 51 | self.num_heads = num_heads 52 | self.max_position_embeddings = max_position_embeddings 53 | 54 | self.hidden_ratio = hidden_ratio 55 | self.intermediate_size = intermediate_size 56 | self.hidden_act = hidden_act 57 | self.num_hidden_layers = num_hidden_layers 58 | self.norm_eps = norm_eps 59 | self.attn = attn 60 | self.use_cache = use_cache 61 | self.initializer_range = initializer_range 62 | 63 | self.fuse_norm = fuse_norm 64 | self.fuse_swiglu = fuse_swiglu 65 | self.fuse_cross_entropy = fuse_cross_entropy 66 | self.vocab_size = vocab_size 67 | 68 | # DeltaProduct specific 69 | self.allow_neg_eigval = allow_neg_eigval 70 | self.num_householder = num_householder 71 | self.use_forget_gate = use_forget_gate 72 | 73 | if attn is not None: 74 | if not isinstance(attn, Dict): 75 | raise ValueError("attn must be a dictionary") 76 | if 'layers' not in attn: 77 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 78 | if 'num_heads' not in attn: 79 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 80 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 81 | attn['qkv_bias'] = attn.get('qkv_bias', False) 82 | attn['window_size'] = attn.get('window_size', None) 83 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 
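# Editor's note (hedged sketch, not part of the original file): the `attn` dict normalized
# above turns the listed layer indices into standard softmax-attention layers, giving a
# hybrid model. A minimal example, assuming only the required keys are passed:
#   config = GatedDeltaProductConfig(attn={'layers': [1, 3], 'num_heads': 16})
# Keys left unset ('num_kv_heads', 'qkv_bias', 'window_size', 'rope_theta') are filled
# with the defaults set in the block above.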
84 | 85 | super().__init__( 86 | pad_token_id=pad_token_id, 87 | bos_token_id=bos_token_id, 88 | eos_token_id=eos_token_id, 89 | tie_word_embeddings=tie_word_embeddings, 90 | **kwargs, 91 | ) 92 | -------------------------------------------------------------------------------- /fla/models/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.gla.configuration_gla import GLAConfig 6 | from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel 7 | 8 | AutoConfig.register(GLAConfig.model_type, GLAConfig, exist_ok=True) 9 | AutoModel.register(GLAConfig, GLAModel, exist_ok=True) 10 | AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] 14 | -------------------------------------------------------------------------------- /fla/models/gsa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.gsa.configuration_gsa import GSAConfig 6 | from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel 7 | 8 | AutoConfig.register(GSAConfig.model_type, GSAConfig, exist_ok=True) 9 | AutoModel.register(GSAConfig, GSAModel, exist_ok=True) 10 | AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel'] 14 | -------------------------------------------------------------------------------- /fla/models/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.hgrn.configuration_hgrn import HGRNConfig 6 | from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel 7 | 8 | AutoConfig.register(HGRNConfig.model_type, HGRNConfig, exist_ok=True) 9 | AutoModel.register(HGRNConfig, HGRNModel, exist_ok=True) 10 | AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] 14 | -------------------------------------------------------------------------------- /fla/models/hgrn/configuration_hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRNConfig(PretrainedConfig): 9 | 10 | model_type = 'hgrn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "fused_recurrent", 16 | hidden_size: int = 2048, 17 | num_hidden_layers: int = 24, 18 | expand_ratio: Optional[int] = 1, 19 | use_short_conv: bool = False, 20 | conv_size: int = 4, 21 | use_lower_bound: bool = True, 22 | max_position_embeddings: int = 2048, 23 | hidden_ratio: Optional[int] = 4, 24 | intermediate_size: Optional[int] = None, 25 | hidden_act: str = "swish", 26 | elementwise_affine: Optional[bool] = True, 27 | norm_eps: float = 1e-6, 28 | attn: Optional[Dict] = None, 29 | use_cache: bool = True, 30 | pad_token_id: int = None, 31 | bos_token_id: int = 1, 32 | eos_token_id: int = 2, 33 | tie_word_embeddings: bool = False, 34 | initializer_range: 
float = 0.02, 35 | fuse_norm: bool = True, 36 | fuse_swiglu: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | vocab_size: int = 32000, 39 | **kwargs 40 | ): 41 | self.attn_mode = attn_mode 42 | self.hidden_size = hidden_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.expand_ratio = expand_ratio 45 | self.use_short_conv = use_short_conv 46 | self.conv_size = conv_size 47 | self.use_lower_bound = use_lower_bound 48 | self.max_position_embeddings = max_position_embeddings 49 | self.hidden_ratio = hidden_ratio 50 | self.intermediate_size = intermediate_size 51 | self.elementwise_affine = elementwise_affine 52 | self.attn = attn 53 | self.norm_eps = norm_eps 54 | self.hidden_act = hidden_act 55 | self.use_cache = use_cache 56 | self.initializer_range = initializer_range 57 | 58 | self.fuse_norm = fuse_norm 59 | self.fuse_swiglu = fuse_swiglu 60 | self.fuse_cross_entropy = fuse_cross_entropy 61 | self.vocab_size = vocab_size 62 | 63 | if attn is not None: 64 | if not isinstance(attn, Dict): 65 | raise ValueError("attn must be a dictionary") 66 | if 'layers' not in attn: 67 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 68 | if 'num_heads' not in attn: 69 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 70 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 71 | attn['qkv_bias'] = attn.get('qkv_bias', False) 72 | attn['window_size'] = attn.get('window_size', None) 73 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 74 | 75 | super().__init__( 76 | pad_token_id=pad_token_id, 77 | bos_token_id=bos_token_id, 78 | eos_token_id=eos_token_id, 79 | tie_word_embeddings=tie_word_embeddings, 80 | **kwargs, 81 | ) 82 | -------------------------------------------------------------------------------- /fla/models/hgrn2/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config 6 | from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model 7 | 8 | AutoConfig.register(HGRN2Config.model_type, HGRN2Config, exist_ok=True) 9 | AutoModel.register(HGRN2Config, HGRN2Model, exist_ok=True) 10 | AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] 14 | -------------------------------------------------------------------------------- /fla/models/lightnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.lightnet.configuration_lightnet import LightNetConfig 6 | from fla.models.lightnet.modeling_lightnet import LightNetForCausalLM, LightNetModel 7 | 8 | AutoConfig.register(LightNetConfig.model_type, LightNetConfig, exist_ok=True) 9 | AutoModel.register(LightNetConfig, LightNetModel, exist_ok=True) 10 | AutoModelForCausalLM.register(LightNetConfig, LightNetForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['LightNetConfig', 'LightNetForCausalLM', 'LightNetModel'] 14 | -------------------------------------------------------------------------------- /fla/models/lightnet/configuration_lightnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 
from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class LightNetConfig(PretrainedConfig): 9 | 10 | model_type = 'lightnet' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | num_hidden_layers: int = 24, 17 | attn_mode: str = "chunk", 18 | num_heads: Optional[int] = None, 19 | expand_ratio: Optional[int] = 128, 20 | use_short_conv: bool = False, 21 | conv_size: int = 4, 22 | hidden_ratio: Optional[int] = 4, 23 | intermediate_size: Optional[int] = None, 24 | hidden_act: str = "swish", 25 | max_position_embeddings: int = 2048, 26 | gate_low_rank_dim: int = 128, 27 | elementwise_affine: Optional[bool] = True, 28 | norm_eps: float = 1e-6, 29 | attn: Optional[Dict] = None, 30 | use_cache: bool = True, 31 | pad_token_id: int = None, 32 | bos_token_id: int = 1, 33 | eos_token_id: int = 2, 34 | tie_word_embeddings: bool = False, 35 | initializer_range: float = 0.02, 36 | fuse_norm: bool = True, 37 | fuse_swiglu: bool = True, 38 | fuse_cross_entropy: bool = True, 39 | vocab_size: int = 32000, 40 | **kwargs 41 | ): 42 | self.hidden_size = hidden_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.attn_mode = attn_mode 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.max_position_embeddings = max_position_embeddings 50 | self.gate_low_rank_dim = gate_low_rank_dim 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.attn = attn 57 | self.use_cache = use_cache 58 | self.initializer_range = initializer_range 59 | 60 | self.fuse_norm = fuse_norm 61 | self.fuse_swiglu = fuse_swiglu 62 | self.fuse_cross_entropy = fuse_cross_entropy 63 | self.vocab_size = vocab_size 64 | 65 | if attn is not None: 66 | if not isinstance(attn, Dict): 67 | raise ValueError("attn must be a dictionary") 68 | if 'layers' not in attn: 69 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 70 | if 'num_heads' not in attn: 71 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 72 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 73 | attn['qkv_bias'] = attn.get('qkv_bias', False) 74 | attn['window_size'] = attn.get('window_size', None) 75 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 
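# Editor's note (hedged usage sketch, not part of the original file): like the other FLA
# configs, this class plugs into the registered Auto* classes, so a model can be built as
#   from fla.models import LightNetConfig, LightNetForCausalLM
#   model = LightNetForCausalLM(LightNetConfig(hidden_size=1024, num_hidden_layers=12))
# Any keyword shown in the signature above can be overridden the same way.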
76 | 77 | super().__init__( 78 | pad_token_id=pad_token_id, 79 | bos_token_id=bos_token_id, 80 | eos_token_id=eos_token_id, 81 | tie_word_embeddings=tie_word_embeddings, 82 | **kwargs, 83 | ) 84 | -------------------------------------------------------------------------------- /fla/models/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig 6 | from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel 7 | 8 | AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig, exist_ok=True) 9 | AutoModel.register(LinearAttentionConfig, LinearAttentionModel, exist_ok=True) 10 | AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM, exist_ok=True) 11 | 12 | __all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] 13 | -------------------------------------------------------------------------------- /fla/models/mamba/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.mamba.configuration_mamba import MambaConfig 6 | from fla.models.mamba.modeling_mamba import MambaBlock, MambaForCausalLM, MambaModel 7 | 8 | AutoConfig.register(MambaConfig.model_type, MambaConfig, exist_ok=True) 9 | AutoModel.register(MambaConfig, MambaModel, exist_ok=True) 10 | AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] 14 | -------------------------------------------------------------------------------- /fla/models/mamba2/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.mamba2.configuration_mamba2 import Mamba2Config 6 | from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model 7 | 8 | AutoConfig.register(Mamba2Config.model_type, Mamba2Config, exist_ok=True) 9 | AutoModel.register(Mamba2Config, Mamba2Model, exist_ok=True) 10 | AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model'] 14 | -------------------------------------------------------------------------------- /fla/models/mesa_net/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.mesa_net.configuration_mesa_net import MesaNetConfig 6 | from fla.models.mesa_net.modeling_mesa_net import MesaNetForCausalLM, MesaNetModel 7 | 8 | AutoConfig.register(MesaNetConfig.model_type, MesaNetConfig, exist_ok=True) 9 | AutoModel.register(MesaNetConfig, MesaNetModel, exist_ok=True) 10 | AutoModelForCausalLM.register(MesaNetConfig, MesaNetForCausalLM, exist_ok=True) 11 | 12 | __all__ = ['MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel'] 13 | -------------------------------------------------------------------------------- /fla/models/mesa_net/configuration_mesa_net.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class MesaNetConfig(PretrainedConfig): 9 | model_type = 'mesa_net' 10 | keys_to_ignore_at_inference = ['past_key_values'] 11 | 12 | def __init__( 13 | self, 14 | attn_mode: str = "chunk", 15 | hidden_size: int = 2048, 16 | expand_v: int = 1, 17 | use_gate: bool = False, 18 | use_short_conv: bool = True, 19 | conv_size: int = 4, 20 | num_heads: int = 16, 21 | lambda_lower_bound: float = 0.25, 22 | max_position_embeddings: int = 2048, 23 | hidden_ratio: Optional[int] = 4, 24 | intermediate_size: Optional[int] = None, 25 | hidden_act: str = "swish", 26 | num_hidden_layers: int = 24, 27 | norm_eps: float = 1e-6, 28 | attn: Optional[Dict] = None, 29 | use_cache: bool = True, 30 | pad_token_id: int = None, 31 | bos_token_id: int = 1, 32 | eos_token_id: int = 2, 33 | tie_word_embeddings: bool = False, 34 | initializer_range: float = 0.02, 35 | fuse_norm: bool = True, 36 | fuse_swiglu: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | vocab_size: int = 32000, 39 | max_cg_step_training: int = 30, 40 | max_cg_step_decoding: int = 30, 41 | **kwargs 42 | ): 43 | self.attn_mode = attn_mode 44 | self.hidden_size = hidden_size 45 | self.expand_v = expand_v 46 | self.use_gate = use_gate 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.num_heads = num_heads 50 | self.max_position_embeddings = max_position_embeddings 51 | 52 | self.hidden_ratio = hidden_ratio 53 | self.intermediate_size = intermediate_size 54 | self.hidden_act = hidden_act 55 | self.num_hidden_layers = num_hidden_layers 56 | self.norm_eps = norm_eps 57 | self.attn = attn 58 | self.use_cache = use_cache 59 | self.initializer_range = initializer_range 60 | self.lambda_lower_bound = lambda_lower_bound 61 | 62 | self.fuse_norm = fuse_norm 63 | self.fuse_swiglu = fuse_swiglu 64 | self.fuse_cross_entropy = fuse_cross_entropy 65 | self.vocab_size = vocab_size 66 | self.max_cg_step_training = max_cg_step_training 67 | self.max_cg_step_decoding = max_cg_step_decoding 68 | 69 | if attn is not None: 70 | if not isinstance(attn, Dict): 71 | raise ValueError("attn must be a dictionary") 72 | if 'layers' not in attn: 73 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 74 | if 'num_heads' not in attn: 75 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 76 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 77 | attn['qkv_bias'] = attn.get('qkv_bias', False) 78 | attn['window_size'] = attn.get('window_size', None) 79 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 
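# Editor's note (hedged, not part of the original file): `max_cg_step_training` and
# `max_cg_step_decoding` appear to cap the iterations of MesaNet's conjugate-gradient solve
# during chunked training and one-step decoding respectively, and `lambda_lower_bound`
# presumably bounds the regularization strength from below. A minimal config sketch:
#   config = MesaNetConfig(hidden_size=1024, num_hidden_layers=12, num_heads=8)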
80 | 81 | super().__init__( 82 | pad_token_id=pad_token_id, 83 | bos_token_id=bos_token_id, 84 | eos_token_id=eos_token_id, 85 | tie_word_embeddings=tie_word_embeddings, 86 | **kwargs, 87 | ) 88 | -------------------------------------------------------------------------------- /fla/models/nsa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.nsa.configuration_nsa import NSAConfig 6 | from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel 7 | 8 | AutoConfig.register(NSAConfig.model_type, NSAConfig, exist_ok=True) 9 | AutoModel.register(NSAConfig, NSAModel, exist_ok=True) 10 | AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = [ 14 | 'NSAConfig', 'NSAModel', 'NSAForCausalLM', 15 | ] 16 | -------------------------------------------------------------------------------- /fla/models/nsa/configuration_nsa.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class NSAConfig(PretrainedConfig): 9 | 10 | model_type = 'nsa' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | num_hidden_layers: int = 24, 17 | num_heads: int = 64, 18 | num_kv_heads: int = 4, 19 | head_dim: int = 32, 20 | qkv_bias: bool = False, 21 | block_size: int = 64, 22 | block_counts: Optional[int] = 16, 23 | window_size: Optional[int] = 512, 24 | rope_theta: Optional[float] = 10000., 25 | max_position_embeddings: int = 2048, 26 | hidden_ratio: Optional[int] = 4, 27 | intermediate_size: Optional[int] = None, 28 | hidden_act: str = "swish", 29 | initializer_range: float = 0.02, 30 | elementwise_affine: Optional[bool] = True, 31 | norm_eps: float = 1e-6, 32 | use_cache: bool = True, 33 | pad_token_id: int = None, 34 | bos_token_id: int = 1, 35 | eos_token_id: int = 2, 36 | tie_word_embeddings: bool = False, 37 | fuse_norm: bool = True, 38 | fuse_swiglu: bool = True, 39 | fuse_cross_entropy: bool = True, 40 | vocab_size: int = 32000, 41 | **kwargs, 42 | ): 43 | self.hidden_size = hidden_size 44 | self.num_hidden_layers = num_hidden_layers 45 | self.num_heads = num_heads 46 | self.num_kv_heads = num_kv_heads 47 | self.head_dim = head_dim 48 | self.qkv_bias = qkv_bias 49 | self.block_size = block_size 50 | self.block_counts = block_counts 51 | self.window_size = window_size 52 | self.rope_theta = rope_theta 53 | self.max_position_embeddings = max_position_embeddings 54 | 55 | self.hidden_ratio = hidden_ratio 56 | self.intermediate_size = intermediate_size 57 | self.hidden_act = hidden_act 58 | 59 | self.initializer_range = initializer_range 60 | self.elementwise_affine = elementwise_affine 61 | self.norm_eps = norm_eps 62 | self.use_cache = use_cache 63 | 64 | self.fuse_norm = fuse_norm 65 | self.fuse_swiglu = fuse_swiglu 66 | self.fuse_cross_entropy = fuse_cross_entropy 67 | self.vocab_size = vocab_size 68 | 69 | super().__init__( 70 | pad_token_id=pad_token_id, 71 | bos_token_id=bos_token_id, 72 | eos_token_id=eos_token_id, 73 | tie_word_embeddings=tie_word_embeddings, 74 | **kwargs, 75 | ) 76 | -------------------------------------------------------------------------------- /fla/models/path_attn/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.path_attn.configuration_path_attention import PaTHAttentionConfig 6 | from fla.models.path_attn.modeling_path_attention import PaTHAttentionForCausalLM, PaTHAttentionModel 7 | 8 | AutoConfig.register(PaTHAttentionConfig.model_type, PaTHAttentionConfig, exist_ok=True) 9 | AutoModel.register(PaTHAttentionConfig, PaTHAttentionModel, exist_ok=True) 10 | AutoModelForCausalLM.register(PaTHAttentionConfig, PaTHAttentionForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel'] 14 | -------------------------------------------------------------------------------- /fla/models/path_attn/configuration_path_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class PaTHAttentionConfig(PretrainedConfig): 9 | 10 | model_type = 'path_attn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | num_hidden_layers: int = 24, 17 | num_heads: int = 32, 18 | num_kv_heads: Optional[int] = None, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | hidden_act: str = "swish", 22 | initializer_range: float = 0.02, 23 | elementwise_affine: Optional[bool] = True, 24 | norm_eps: float = 1e-6, 25 | use_cache: bool = True, 26 | pad_token_id: Optional[int] = None, 27 | bos_token_id: int = 1, 28 | eos_token_id: int = 2, 29 | tie_word_embeddings: bool = False, 30 | fuse_norm: bool = True, 31 | fuse_swiglu: bool = True, 32 | fuse_cross_entropy: bool = True, 33 | vocab_size: int = 32000, 34 | use_forget_gate: bool = False, 35 | use_w_shortconv: bool = True, 36 | **kwargs, 37 | ): 38 | self.hidden_size = hidden_size 39 | self.num_hidden_layers = num_hidden_layers 40 | self.num_heads = num_heads 41 | self.num_kv_heads = num_kv_heads 42 | self.hidden_ratio = hidden_ratio 43 | self.intermediate_size = intermediate_size 44 | self.hidden_act = hidden_act 45 | 46 | self.initializer_range = initializer_range 47 | self.elementwise_affine = elementwise_affine 48 | self.norm_eps = norm_eps 49 | self.use_cache = use_cache 50 | 51 | self.fuse_norm = fuse_norm 52 | self.fuse_swiglu = fuse_swiglu 53 | self.fuse_cross_entropy = fuse_cross_entropy 54 | self.vocab_size = vocab_size 55 | 56 | self.use_forget_gate = use_forget_gate 57 | self.use_w_shortconv = use_w_shortconv 58 | 59 | super().__init__( 60 | pad_token_id=pad_token_id, 61 | bos_token_id=bos_token_id, 62 | eos_token_id=eos_token_id, 63 | tie_word_embeddings=tie_word_embeddings, 64 | **kwargs, 65 | ) 66 | -------------------------------------------------------------------------------- /fla/models/retnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.retnet.configuration_retnet import RetNetConfig 6 | from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel 7 | 8 | AutoConfig.register(RetNetConfig.model_type, RetNetConfig, exist_ok=True) 9 | AutoModel.register(RetNetConfig, RetNetModel, exist_ok=True) 10 | 
AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] 14 | -------------------------------------------------------------------------------- /fla/models/rodimus/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer 4 | 5 | from fla.models.rodimus.configuration_rodimus import RodimusConfig 6 | from fla.models.rodimus.modeling_rodimus import RodimusForCausalLM, RodimusModel 7 | from fla.models.rodimus.tokenization_rodimus_fast import RodimusTokenizer 8 | 9 | AutoConfig.register(RodimusConfig.model_type, RodimusConfig) 10 | AutoModel.register(RodimusConfig, RodimusModel) 11 | AutoModelForCausalLM.register(RodimusConfig, RodimusForCausalLM) 12 | AutoTokenizer.register(RodimusConfig, slow_tokenizer_class=None, fast_tokenizer_class=RodimusTokenizer) 13 | 14 | 15 | __all__ = ['RodimusConfig', 'RodimusForCausalLM', 'RodimusModel', 'RodimusTokenizer'] 16 | -------------------------------------------------------------------------------- /fla/models/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config 6 | from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model 7 | 8 | AutoConfig.register(RWKV6Config.model_type, RWKV6Config, exist_ok=True) 9 | AutoModel.register(RWKV6Config, RWKV6Model, exist_ok=True) 10 | AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] 14 | -------------------------------------------------------------------------------- /fla/models/rwkv6/configuration_rwkv6.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class RWKV6Config(PretrainedConfig): 9 | 10 | model_type = 'rwkv6' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | hidden_size: int = 2048, 17 | expand_k: int = 0.5, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 3.5, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | proj_low_rank_dim: int = 32, 24 | gate_low_rank_dim: int = 64, 25 | hidden_act: str = "sqrelu", 26 | max_position_embeddings: int = 2048, 27 | norm_first: bool = True, 28 | norm_bias: bool = True, 29 | norm_eps: float = 1e-5, 30 | attn: Optional[Dict] = None, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_norm: bool = True, 38 | fuse_cross_entropy: bool = True, 39 | vocab_size: int = 32000, 40 | **kwargs 41 | ): 42 | self.attn_mode = attn_mode 43 | self.hidden_size = hidden_size 44 | self.expand_k = expand_k 45 | self.expand_v = expand_v 46 | self.hidden_ratio = hidden_ratio 47 | self.intermediate_size = intermediate_size 48 | self.norm_first = norm_first 49 | self.num_hidden_layers = num_hidden_layers 50 | self.num_heads = num_heads 51 | 
self.proj_low_rank_dim = proj_low_rank_dim 52 | self.gate_low_rank_dim = gate_low_rank_dim 53 | self.hidden_act = hidden_act 54 | self.max_position_embeddings = max_position_embeddings 55 | self.norm_bias = norm_bias 56 | self.norm_eps = norm_eps 57 | self.attn = attn 58 | self.use_cache = use_cache 59 | self.initializer_range = initializer_range 60 | self.fuse_norm = fuse_norm 61 | self.fuse_cross_entropy = fuse_cross_entropy 62 | self.vocab_size = vocab_size 63 | 64 | if attn is not None: 65 | if not isinstance(attn, Dict): 66 | raise ValueError("attn must be a dictionary") 67 | if 'layers' not in attn: 68 | raise ValueError("Layer indices must be provided to initialize hybrid attention layers") 69 | if 'num_heads' not in attn: 70 | raise ValueError("Number of heads must be provided to initialize hybrid attention layers") 71 | attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) 72 | attn['qkv_bias'] = attn.get('qkv_bias', False) 73 | attn['window_size'] = attn.get('window_size', None) 74 | attn['rope_theta'] = attn.get('rope_theta', 10000.) 75 | 76 | super().__init__( 77 | pad_token_id=pad_token_id, 78 | bos_token_id=bos_token_id, 79 | eos_token_id=eos_token_id, 80 | tie_word_embeddings=tie_word_embeddings, 81 | **kwargs, 82 | ) 83 | -------------------------------------------------------------------------------- /fla/models/rwkv7/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config 6 | from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model 7 | 8 | AutoConfig.register(RWKV7Config.model_type, RWKV7Config, exist_ok=True) 9 | AutoModel.register(RWKV7Config, RWKV7Model, exist_ok=True) 10 | AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model'] 14 | -------------------------------------------------------------------------------- /fla/models/samba/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.samba.configuration_samba import SambaConfig 6 | from fla.models.samba.modeling_samba import SambaBlock, SambaForCausalLM, SambaModel 7 | 8 | AutoConfig.register(SambaConfig.model_type, SambaConfig, exist_ok=True) 9 | AutoModel.register(SambaConfig, SambaModel, exist_ok=True) 10 | AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock'] 14 | -------------------------------------------------------------------------------- /fla/models/samba/configuration_samba.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | from typing import Dict, Optional 5 | 6 | from transformers.configuration_utils import PretrainedConfig 7 | 8 | 9 | class SambaConfig(PretrainedConfig): 10 | 11 | model_type = "samba" 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2304, 16 | state_size: int = 16, 17 | num_hidden_layers: int = 18, 18 | norm_eps=1e-5, 19 | pad_token_id: int = 0, 20 | bos_token_id: int = 1, 21 | eos_token_id: int = 2, 22 | expand: int = 2, 23 | conv_kernel: int = 4, 24 | use_bias: bool = False, 25 | 
use_conv_bias: bool = True, 26 | hidden_act: str = "swish", 27 | initializer_range: str = 0.02, 28 | residual_in_fp32: bool = False, 29 | time_step_rank: str = "auto", 30 | time_step_scale: float = 1.0, 31 | time_step_min: float = 0.001, 32 | time_step_max: float = 0.1, 33 | time_step_init_scheme: str = "random", 34 | time_step_floor: float = 1e-4, 35 | max_position_embeddings: int = 2048, 36 | attn: Optional[Dict] = { 37 | 'layers': (1, 3, 5, 7, 9, 11, 13, 15, 17), 38 | 'num_heads': 18, 39 | 'num_kv_heads': 18, 40 | 'qkv_bias': False, 41 | 'window_size': 2048, 42 | 'rope_theta': 10000. 43 | }, 44 | hidden_ratio: Optional[int] = 4, 45 | rescale_prenorm_residual: bool = False, 46 | use_cache: bool = True, 47 | fuse_norm: bool = True, 48 | fuse_swiglu: bool = True, 49 | fuse_cross_entropy: bool = True, 50 | vocab_size: int = 32000, 51 | tie_word_embeddings: bool = False, 52 | **kwargs, 53 | ): 54 | self.hidden_size = hidden_size 55 | self.state_size = state_size 56 | self.num_hidden_layers = num_hidden_layers 57 | self.norm_eps = norm_eps 58 | self.conv_kernel = conv_kernel 59 | self.expand = expand 60 | self.intermediate_size = int(expand * self.hidden_size) 61 | self.bos_token_id = bos_token_id 62 | self.eos_token_id = eos_token_id 63 | self.pad_token_id = pad_token_id 64 | self.use_bias = use_bias 65 | self.use_conv_bias = use_conv_bias 66 | self.hidden_act = hidden_act 67 | self.initializer_range = initializer_range 68 | self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank 69 | self.time_step_scale = time_step_scale 70 | self.time_step_min = time_step_min 71 | self.time_step_max = time_step_max 72 | self.time_step_init_scheme = time_step_init_scheme 73 | self.time_step_floor = time_step_floor 74 | self.max_position_embeddings = max_position_embeddings 75 | self.attn = attn 76 | self.hidden_ratio = hidden_ratio 77 | self.rescale_prenorm_residual = rescale_prenorm_residual 78 | self.residual_in_fp32 = residual_in_fp32 79 | self.use_cache = use_cache 80 | 81 | self.fuse_norm = fuse_norm 82 | self.fuse_swiglu = fuse_swiglu 83 | self.fuse_cross_entropy = fuse_cross_entropy 84 | self.vocab_size = vocab_size 85 | 86 | super().__init__( 87 | bos_token_id=bos_token_id, 88 | eos_token_id=eos_token_id, 89 | pad_token_id=pad_token_id, 90 | tie_word_embeddings=tie_word_embeddings, 91 | **kwargs 92 | ) 93 | -------------------------------------------------------------------------------- /fla/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.transformer.configuration_transformer import TransformerConfig 6 | from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel 7 | 8 | AutoConfig.register(TransformerConfig.model_type, TransformerConfig, exist_ok=True) 9 | AutoModel.register(TransformerConfig, TransformerModel, exist_ok=True) 10 | AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM, exist_ok=True) 11 | 12 | 13 | __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] 14 | -------------------------------------------------------------------------------- /fla/models/transformer/configuration_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from 
transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class TransformerConfig(PretrainedConfig): 9 | 10 | model_type = 'transformer' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | hidden_size: int = 2048, 16 | num_hidden_layers: int = 24, 17 | num_heads: int = 32, 18 | num_kv_heads: int = None, 19 | qkv_bias: bool = False, 20 | qk_norm: bool = False, 21 | window_size: Optional[int] = None, 22 | rope_theta: Optional[float] = 10000., 23 | max_position_embeddings: int = 2048, 24 | hidden_ratio: Optional[int] = 4, 25 | intermediate_size: Optional[int] = None, 26 | hidden_act: str = "swish", 27 | initializer_range: float = 0.02, 28 | elementwise_affine: Optional[bool] = True, 29 | norm_eps: float = 1e-6, 30 | use_cache: bool = True, 31 | pad_token_id: int = None, 32 | bos_token_id: int = 1, 33 | eos_token_id: int = 2, 34 | tie_word_embeddings: bool = False, 35 | fuse_norm: bool = True, 36 | fuse_swiglu: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | vocab_size: int = 32000, 39 | **kwargs, 40 | ): 41 | self.hidden_size = hidden_size 42 | self.num_hidden_layers = num_hidden_layers 43 | self.num_heads = num_heads 44 | self.num_kv_heads = num_kv_heads 45 | self.qkv_bias = qkv_bias 46 | self.qk_norm = qk_norm 47 | self.window_size = window_size 48 | self.rope_theta = rope_theta 49 | self.max_position_embeddings = max_position_embeddings 50 | 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | 55 | self.initializer_range = initializer_range 56 | self.elementwise_affine = elementwise_affine 57 | self.norm_eps = norm_eps 58 | self.use_cache = use_cache 59 | 60 | self.fuse_norm = fuse_norm 61 | self.fuse_swiglu = fuse_swiglu 62 | self.fuse_cross_entropy = fuse_cross_entropy 63 | self.vocab_size = vocab_size 64 | 65 | super().__init__( 66 | pad_token_id=pad_token_id, 67 | bos_token_id=bos_token_id, 68 | eos_token_id=eos_token_id, 69 | tie_word_embeddings=tie_word_embeddings, 70 | **kwargs, 71 | ) 72 | -------------------------------------------------------------------------------- /fla/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution 4 | from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear 5 | from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss 6 | from fla.modules.fused_kl_div import FusedKLDivLoss 7 | from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss 8 | from fla.modules.fused_norm_gate import ( 9 | FusedLayerNormGated, 10 | FusedLayerNormSwishGate, 11 | FusedLayerNormSwishGateLinear, 12 | FusedRMSNormGated, 13 | FusedRMSNormSwishGate, 14 | FusedRMSNormSwishGateLinear 15 | ) 16 | from fla.modules.l2norm import L2Norm 17 | from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear 18 | from fla.modules.mlp import GatedMLP 19 | from fla.modules.rotary import RotaryEmbedding 20 | from fla.modules.token_shift import TokenShift 21 | 22 | __all__ = [ 23 | 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', 24 | 'BitLinear', 'FusedBitLinear', 25 | 'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss', 26 | 'L2Norm', 27 | 'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', 28 | 'FusedLayerNormGated', 
'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 29 | 'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', 30 | 'GatedMLP', 31 | 'RotaryEmbedding', 32 | 'TokenShift' 33 | ] 34 | -------------------------------------------------------------------------------- /fla/modules/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional 5 | 6 | import torch.nn as nn 7 | from torch.distributed import DeviceMesh 8 | from torch.distributed.tensor import DTensor, distribute_module 9 | from torch.distributed.tensor.parallel import ParallelStyle 10 | from torch.distributed.tensor.placement_types import Placement 11 | 12 | 13 | class PrepareModuleWeight(ParallelStyle): 14 | def __init__(self, *, layouts: Optional[Placement] = None): 15 | super().__init__() 16 | self.layouts = layouts 17 | 18 | def _replicate_module_fn( 19 | self, 20 | name: str, 21 | module: nn.Module, 22 | device_mesh: DeviceMesh 23 | ): 24 | for p_name, param in module.named_parameters(): 25 | replicated_param = nn.Parameter( 26 | DTensor.from_local(param, device_mesh, [self.layouts], run_check=False) 27 | ) 28 | module.register_parameter(p_name, replicated_param) 29 | 30 | def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 31 | return distribute_module( 32 | module, 33 | device_mesh, 34 | partition_fn=self._replicate_module_fn, 35 | input_fn=None, 36 | output_fn=None 37 | ) 38 | -------------------------------------------------------------------------------- /fla/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .abc import chunk_abc 4 | from .attn import parallel_attn 5 | from .based import fused_chunk_based, parallel_based 6 | from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule 7 | from .forgetting_attn import parallel_forgetting_attn 8 | from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule 9 | from .generalized_delta_rule import ( 10 | chunk_dplr_delta_rule, 11 | chunk_iplr_delta_rule, 12 | fused_recurrent_dplr_delta_rule, 13 | fused_recurrent_iplr_delta_rule 14 | ) 15 | from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 16 | from .gsa import chunk_gsa, fused_recurrent_gsa 17 | from .hgrn import fused_recurrent_hgrn 18 | from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn 19 | from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn 20 | from .mesa_net import chunk_mesa_net 21 | from .nsa import parallel_nsa 22 | from .path_attn import parallel_path_attention 23 | from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention 24 | from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 25 | from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7 26 | from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla 27 | 28 | __all__ = [ 29 | 'chunk_abc', 30 | 'parallel_attn', 31 | 'fused_chunk_based', 'parallel_based', 32 | 'chunk_delta_rule', 'fused_chunk_delta_rule', 'fused_recurrent_delta_rule', 33 | 'parallel_forgetting_attn', 34 | 'chunk_gated_delta_rule', 'fused_recurrent_gated_delta_rule', 35 | 'chunk_dplr_delta_rule', 'chunk_iplr_delta_rule', 36 | 'fused_recurrent_dplr_delta_rule', 
'fused_recurrent_iplr_delta_rule', 37 | 'chunk_gla', 'fused_chunk_gla', 'fused_recurrent_gla', 38 | 'chunk_gsa', 'fused_recurrent_gsa', 39 | 'fused_recurrent_hgrn', 40 | 'chunk_lightning_attn', 'fused_recurrent_lightning_attn', 41 | 'chunk_linear_attn', 'fused_chunk_linear_attn', 'fused_recurrent_linear_attn', 42 | 'parallel_nsa', 43 | 'chunk_retention', 'fused_chunk_retention', 'fused_recurrent_retention', 'parallel_retention', 44 | 'chunk_rwkv6', 'fused_recurrent_rwkv6', 45 | 'chunk_rwkv7', 'fused_recurrent_rwkv7', 46 | 'chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla', 47 | 'parallel_path_attention', 48 | 'chunk_mesa_net', 49 | ] 50 | -------------------------------------------------------------------------------- /fla/ops/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_abc 4 | 5 | __all__ = [ 6 | 'chunk_abc' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/abc/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from einops import repeat 7 | 8 | 9 | def naive_recurrent_abc( 10 | q: torch.Tensor, 11 | k: torch.Tensor, 12 | v: torch.Tensor, 13 | s: torch.Tensor, 14 | g: Optional[torch.Tensor] = None, 15 | scale: Optional[int] = None, 16 | initial_state: Optional[torch.Tensor] = None, 17 | output_final_state: Optional[bool] = False 18 | ) -> torch.Tensor: 19 | dtype = q.dtype 20 | 21 | NG = q.shape[1]//k.shape[1] 22 | # [batch_size, n_heads, seq_len, n_slots] 23 | if g is None: 24 | z = s.float().logcumsumexp(2) 25 | g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z 26 | s = torch.exp(s - z) 27 | q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) 28 | k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g)) 29 | if initial_state is not None: 30 | initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state)) 31 | 32 | B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] 33 | 34 | hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) 35 | ok = torch.zeros_like(s) 36 | 37 | if scale is None: 38 | scale = q.shape[-1] ** -0.5 39 | 40 | final_state = None 41 | if initial_state is not None: 42 | hk += initial_state[0] 43 | 44 | for i in range(T): 45 | q_i = q[:, :, i] * scale 46 | k_i = k[:, :, i] 47 | v_i = s[:, :, i] 48 | g_i = g[:, :, i].exp() 49 | hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] 50 | ok[:, :, i] = (q_i[..., None] * hk).sum(-2) 51 | 52 | qv = ok.softmax(-1) 53 | hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) 54 | ov = torch.zeros_like(v) 55 | if initial_state is not None: 56 | hv += initial_state[1] 57 | 58 | for i in range(T): 59 | q_i = qv[:, :, i] 60 | k_i = s[:, :, i] 61 | v_i = v[:, :, i] 62 | g_i = g[:, :, i].exp() 63 | hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] 64 | ov[:, :, i] = (q_i[..., None] * hv).sum(-2) 65 | 66 | if output_final_state: 67 | final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0]) 68 | return ov.to(dtype), final_state 69 | 70 | 71 | def naive_cumsum_abc( 72 | q: torch.Tensor, 73 | k: torch.Tensor, 74 | v: torch.Tensor, 75 | s: torch.Tensor 76 | ) -> torch.Tensor: 77 | """ 78 | A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. 
79 | This is just for demonstration purposes, with no numerical stabilities guaranteed. 80 | """ 81 | 82 | dtype = q.dtype 83 | q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) 84 | 85 | scale = q.shape[-1] ** -0.5 86 | # [batch_size, n_heads, seq_len, n_slots] 87 | s = (s - s.max(2, True)[0]).exp() 88 | z = s.cumsum(2) 89 | # [batch_size, n_heads, seq_len, n_slots, d_head] 90 | K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) 91 | V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) 92 | # [batch_size, n_heads, seq_len, n_slots] 93 | p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) 94 | # [batch_size, n_heads, seq_len, d_head] 95 | o = torch.einsum('...m,...md->...d', p, V) 96 | return o.to(dtype), None 97 | -------------------------------------------------------------------------------- /fla/ops/attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_attn 4 | 5 | __all__ = [ 6 | 'parallel_attn' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/based/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .fused_chunk import fused_chunk_based 4 | from .parallel import parallel_based 5 | 6 | __all__ = [ 7 | 'fused_chunk_based', 8 | 'parallel_based' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/based/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from einops import rearrange 7 | 8 | 9 | def naive_parallel_based( 10 | q: torch.Tensor, 11 | k: torch.Tensor, 12 | v: torch.Tensor, 13 | scale: Optional[float] = None, 14 | use_norm: bool = True 15 | ): 16 | if scale is None: 17 | scale = q.shape[-1] ** -0.5 18 | q = q * scale 19 | attn = q @ k.transpose(-2, -1) 20 | attn = 1 + attn + 1/2 * (attn ** 2) 21 | attn.masked_fill_(~torch.tril(torch.ones( 22 | q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) 23 | o = attn @ v 24 | if use_norm: 25 | z = attn.sum(-1) 26 | return o / (z[..., None] + 1e-6) 27 | else: 28 | return o 29 | 30 | 31 | def naive_chunk_based(q, k, v, chunk_size=256): 32 | q = q * (q.shape[-1] ** -0.5) 33 | # compute normalizer. 
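# Based approximates exp(q·k) by its 2nd-order Taylor expansion 1 + q·k + (q·k)^2 / 2,
# so the denominator z_t = sum_{s<=t} [1 + q_t·k_s + (q_t·k_s)^2 / 2] splits into a
# zeroth-order term (the prefix length t + 1), a first-order term q_t·cumsum(k)_t, and a
# second-order term 0.5 * (q_t ⊗ q_t)·cumsum(k ⊗ k)_t; the three pieces are accumulated below.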
34 | k_cumsum = torch.cumsum(k, dim=-2) 35 | kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) 36 | # first 37 | z = (q * k_cumsum).sum(-1) 38 | # second order 39 | z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 40 | # zero-th order 41 | z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] 42 | 43 | # compute o 44 | # constant term 45 | _o = v.cumsum(-2) 46 | 47 | q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) 48 | 49 | k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) 50 | v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) 51 | 52 | intra_chunk_attn = q @ k.transpose(-2, -1) 53 | intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) 54 | intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0) 55 | o = intra_chunk_attn @ v 56 | 57 | # quadractic term 58 | kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) 59 | kv = kv.cumsum(2) 60 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 61 | 62 | o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) 63 | 64 | # linear term 65 | kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) 66 | kv = kv.cumsum(2) 67 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 68 | o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) 69 | 70 | o = rearrange(o, 'b h n c d -> b h (n c) d') 71 | o = o + _o 72 | return o / (z[..., None] + 1e-6) 73 | -------------------------------------------------------------------------------- /fla/ops/common/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_delta_rule 4 | from .fused_chunk import fused_chunk_delta_rule 5 | from .fused_recurrent import fused_recurrent_delta_rule 6 | 7 | __all__ = [ 8 | 'fused_chunk_delta_rule', 9 | 'fused_recurrent_delta_rule', 10 | 'chunk_delta_rule' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/fused_chunk.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | def fused_chunk_delta_rule( 4 | **kwargs 5 | ): 6 | raise NotImplementedError("fused_chunk_delta_rule is deprecated. 
Please use chunk_delta_rule instead.") 7 | -------------------------------------------------------------------------------- /fla/ops/forgetting_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_forgetting_attn 4 | 5 | __all__ = [ 6 | 'parallel_forgetting_attn' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/forgetting_attn/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | import warnings 5 | from typing import Optional 6 | 7 | import torch 8 | from einops import rearrange 9 | 10 | from fla.ops.attn.parallel import parallel_attn 11 | 12 | 13 | def parallel_forgetting_attn( 14 | q: torch.Tensor, 15 | k: torch.Tensor, 16 | v: torch.Tensor, 17 | g: torch.Tensor, 18 | scale: Optional[float] = None, 19 | cu_seqlens: Optional[torch.LongTensor] = None, 20 | head_first: bool = False 21 | ) -> torch.Tensor: 22 | r""" 23 | Args: 24 | q (torch.Tensor): 25 | queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. 26 | k (torch.Tensor): 27 | keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 28 | GQA will be applied if HQ is divisible by H. 29 | v (torch.Tensor): 30 | values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 31 | g (torch.Tensor): 32 | Log decay at rach time step (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. 33 | scale (Optional[int]): 34 | Scale factor for attention scores. 35 | If not provided, it will default to `1 / sqrt(K)`. Default: `None`. 36 | cu_seqlens (torch.LongTensor): 37 | Cumulative sequence lengths of shape `[N+1]` used for variable-length training, 38 | consistent with the FlashAttention API. 39 | head_first (Optional[bool]): 40 | Whether the inputs are in the head-first format. Default: `False`. 41 | 42 | Returns: 43 | o (torch.Tensor): 44 | Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 45 | """ 46 | assert (g <= 0).all(), "g_cumsum must be in log space" 47 | if scale is None: 48 | scale = k.shape[-1] ** -0.5 49 | if cu_seqlens is not None: 50 | assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" 51 | if head_first: 52 | raise DeprecationWarning( 53 | "head_first is deprecated and will be removed in a future version. " 54 | "Please use head_first=False for now instead." 55 | ) 56 | q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g)) 57 | if not head_first and q.shape[1] < q.shape[2]: 58 | warnings.warn( 59 | f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " 60 | "This may indicate the inputs were passed in head-first format [B, H, T, ...] " 61 | "when head_first=False was specified. " 62 | "Please verify your input tensor format matches the expected shape [B, T, H, ...]." 63 | ) 64 | o = parallel_attn(q, k, v, g, scale, cu_seqlens) 65 | if head_first: 66 | o = rearrange(o, 'b t h ... 
-> b h t ...') 67 | return o 68 | -------------------------------------------------------------------------------- /fla/ops/gated_delta_rule/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk import chunk_gated_delta_rule 2 | from .fused_recurrent import fused_recurrent_gated_delta_rule 3 | 4 | __all__ = [ 5 | "chunk_gated_delta_rule", 6 | "fused_recurrent_gated_delta_rule" 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/generalized_delta_rule/README.md: -------------------------------------------------------------------------------- 1 | # Generalized Delta Rule 2 | 3 | In delta rule we have the recurrence: 4 | 5 | ```math 6 | \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T 7 | ``` 8 | 9 | This repository implements a delta rule variant where $\mathbf{I}$ is not necessarily an identity matrix; $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ might be different from input $\mathbf{k}_t$ in $\mathbf{v}_t\mathbf{k}_t^T$. 10 | 11 | ## IPLR (Identity Plus Low Rank) 12 | 13 | The first variant is IPLR, where we have: 14 | 15 | ```math 16 | \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T 17 | ``` 18 | 19 | When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$, $\mathbf{v}_t= \beta_t \mathbf{v}_t$, we recover the original delta rule. Since here the transition matrix is identity-plus-low-rank, we refer to this variant as IPLR. 20 | 21 | ### Numerical Stability 22 | 23 | $\mathbf{a}_t$ and $\mathbf{b}_t$ must be in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ where $\lambda_t < 0$. For an understanding of why this is necessary, you can derive the eigenvalues of the transition matrix. 24 | 25 | ## DPLR (Diagonal Plus Low Rank) 26 | 27 | The second variant is DPLR, where we have: 28 | 29 | ```math 30 | \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T 31 | ``` 32 | 33 | Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7. 34 | 35 | ## Efficient Chunkwise Implementation 36 | 37 | For detailed information about efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing). 
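## Naive Reference Sketch

For readers who want to sanity-check the recurrences above, here is a minimal, unoptimized sketch of the DPLR update written directly from the formula. The function name `naive_dplr_recurrence` and its argument layout are illustrative assumptions only: the library's own reference implementations live in `dplr/naive.py` and `iplr/naive.py`, and the fused kernels use the chunkwise algorithm from the technical note rather than this token-by-token loop.

```python
import torch


def naive_dplr_recurrence(q, k, v, a, b, d, initial_state=None):
    """Token-by-token DPLR recurrence: S_t = S_{t-1} (diag(d_t) + a_t b_t^T) + v_t k_t^T.

    q, k, a, b, d: [B, H, T, D_k]; v: [B, H, T, D_v].
    The state S is stored as [B, H, D_v, D_k] so the transition acts by right
    multiplication, matching the equations above; the readout o_t = S_t q_t follows
    the usual linear-attention convention (no 1/sqrt(D_k) scaling is applied here).
    """
    B, H, T, Dk = q.shape
    Dv = v.shape[-1]
    S = q.new_zeros(B, H, Dv, Dk, dtype=torch.float)
    if initial_state is not None:
        S = S + initial_state.float()
    o = torch.zeros_like(v, dtype=torch.float)
    for t in range(T):
        # transition matrix D_t + a_t b_t^T of shape [B, H, D_k, D_k]
        A_t = torch.diag_embed(d[:, :, t].float()) \
            + a[:, :, t, :, None].float() * b[:, :, t, None, :].float()
        # decay/rotate the state, then add the rank-1 write v_t k_t^T
        S = S @ A_t + v[:, :, t, :, None].float() * k[:, :, t, None, :].float()
        o[:, :, t] = torch.einsum('bhvk,bhk->bhv', S, q[:, :, t].float())
    return o.to(v.dtype), S
```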
38 | -------------------------------------------------------------------------------- /fla/ops/generalized_delta_rule/__init__.py: -------------------------------------------------------------------------------- 1 | from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule 2 | from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule 3 | 4 | __all__ = [ 5 | 'chunk_dplr_delta_rule', 6 | 'fused_recurrent_dplr_delta_rule', 7 | 'chunk_iplr_delta_rule', 8 | 'fused_recurrent_iplr_delta_rule' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/generalized_delta_rule/dplr/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk import chunk_dplr_delta_rule 2 | from .fused_recurrent import fused_recurrent_dplr_delta_rule 3 | 4 | __all__ = [ 5 | 'chunk_dplr_delta_rule', 6 | 'fused_recurrent_dplr_delta_rule' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/generalized_delta_rule/iplr/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk import chunk_iplr_delta_rule 2 | from .fused_recurrent import fused_recurrent_iplr_delta_rule 3 | 4 | __all__ = [ 5 | 'chunk_iplr_delta_rule', 6 | 'fused_recurrent_iplr_delta_rule' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/generalized_delta_rule/iplr/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T 8 | # q, k, alpha, beta [B, H, L, D_K] 9 | # v [B, H, L, D_V] 10 | def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True): 11 | orig_dtype = q.dtype 12 | b, h, l, d_k = q.shape 13 | q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta]) 14 | d_v = v.shape[-1] 15 | o = torch.zeros_like(v) 16 | S = torch.zeros(b, h, d_k, d_v).to(v) 17 | q = q * (d_k ** -0.5) 18 | 19 | if initial_state is not None: 20 | S += initial_state 21 | 22 | for i in range(l): 23 | _k = k[:, :, i] 24 | _q = q[:, :, i] 25 | _v = v[:, :, i] 26 | _alpha = alpha[:, :, i] 27 | _beta = beta[:, :, i] 28 | _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None] 29 | S = S + _kv 30 | o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) 31 | S = None if output_final_state is False else S 32 | return o.to(orig_dtype), S 33 | 34 | 35 | def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32): 36 | b, h, l, d_k = q.shape 37 | d_v = v.shape[-1] 38 | q = q * (d_k ** -0.5) 39 | v = v 40 | assert l % chunk_size == 0 41 | 42 | S = k.new_zeros(b, h, d_k, d_v) 43 | if initial_state is not None: 44 | S += initial_state 45 | 46 | # note that diagonal is masked. 
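# (`diagonal=0` keeps the diagonal inside the mask, so the masked_fill below zeroes the
# diagonal as well: only strictly-past positions within a chunk interact. The subsequent
# loop is a forward substitution on A = strict_tril(alpha @ beta^T) that turns `attn`
# into (I - A)^{-1} - I; the identity is added back explicitly right after the loop.)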
47 | mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) 48 | q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta]) 49 | 50 | v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v 51 | attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0) 52 | for i in range(1, chunk_size): 53 | attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) 54 | 55 | attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) 56 | u = attn @ v2 57 | w = attn @ alpha 58 | o = torch.zeros_like(v) 59 | mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) 60 | for i in range(0, l // chunk_size): 61 | q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] 62 | o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i 63 | v2_i = u_i + w_i @ S 64 | o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ (v2_i) 65 | o_3 = q_i @ S 66 | o[:, :, i] = o_1 + o_2 + o_3 67 | S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i 68 | S = None if output_final_state is False else S 69 | return rearrange(o, 'b h n c d -> b h (n c) d'), S 70 | -------------------------------------------------------------------------------- /fla/ops/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_gla 4 | from .fused_chunk import fused_chunk_gla 5 | from .fused_recurrent import fused_recurrent_gla 6 | 7 | __all__ = [ 8 | 'chunk_gla', 9 | 'fused_chunk_gla', 10 | 'fused_recurrent_gla' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def ceildiv(a, b): 9 | return -(a // -b) 10 | 11 | 12 | def naive_recurrent_gla( 13 | q: torch.Tensor, 14 | k: torch.Tensor, 15 | v: torch.Tensor, 16 | gk: torch.Tensor, 17 | initial_state: Optional[torch.Tensor] = None, 18 | output_final_state: bool = False 19 | ): 20 | dtype = q.dtype 21 | q, k, v, gk = map(lambda x: x.transpose(1, 2).float(), (q, k, v, gk)) 22 | B, H, T, K, V = *q.shape, v.shape[-1] 23 | o = torch.zeros_like(v) 24 | scale = K ** -0.5 25 | 26 | h = q.new_zeros(B, H, K, V, dtype=torch.float32) 27 | if initial_state is not None: 28 | h += initial_state.float() 29 | 30 | for i in range(T): 31 | q_i = q[:, :, i] * scale 32 | k_i = k[:, :, i] 33 | v_i = v[:, :, i] 34 | gk_i = gk[:, :, i].exp() 35 | kv_i = k_i[..., None] * v_i[..., None, :] 36 | h = h * gk_i[..., None] + kv_i 37 | o[:, :, i] = (q_i[..., None] * h).sum(-2) 38 | 39 | if not output_final_state: 40 | h = None 41 | return o.transpose(1, 2).to(dtype), h 42 | -------------------------------------------------------------------------------- /fla/ops/gsa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_gsa 4 | from .fused_recurrent import fused_recurrent_gsa 5 | 6 | __all__ = [ 7 | 'chunk_gsa', 8 | 'fused_recurrent_gsa' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/gsa/naive.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from einops import repeat 7 | 8 | 9 | def naive_recurrent_gsa( 10 | q: torch.Tensor, 11 | k: torch.Tensor, 12 | v: torch.Tensor, 13 | s: torch.Tensor, 14 | g: Optional[torch.Tensor] = None, 15 | scale: Optional[int] = None, 16 | initial_state: Optional[torch.Tensor] = None, 17 | output_final_state: Optional[bool] = False 18 | ) -> torch.Tensor: 19 | dtype = q.dtype 20 | q, k, v, s, g = map(lambda x: x.transpose(1, 2).contiguous().float(), (q, k, v, s, g)) 21 | 22 | NG = q.shape[1]//k.shape[1] 23 | # [batch_size, n_heads, seq_len, n_slots] 24 | if g is None: 25 | z = s.float().logcumsumexp(2) 26 | g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z 27 | s = torch.exp(s - z) 28 | k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g)) 29 | if initial_state is not None: 30 | initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state)) 31 | 32 | B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] 33 | 34 | hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) 35 | ok = torch.zeros_like(s) 36 | 37 | if scale is None: 38 | scale = q.shape[-1] ** -0.5 39 | 40 | final_state = None 41 | if initial_state is not None: 42 | hk += initial_state[0] 43 | 44 | for i in range(T): 45 | q_i = q[:, :, i] * scale 46 | k_i = k[:, :, i] 47 | v_i = s[:, :, i] 48 | g_i = g[:, :, i].exp() 49 | hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] 50 | ok[:, :, i] = (q_i[..., None] * hk).sum(-2) 51 | 52 | qv = ok.softmax(-1) 53 | hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) 54 | ov = torch.zeros_like(v) 55 | if initial_state is not None: 56 | hv += initial_state[1] 57 | 58 | for i in range(T): 59 | q_i = qv[:, :, i] 60 | k_i = s[:, :, i] 61 | v_i = v[:, :, i] 62 | g_i = g[:, :, i].exp() 63 | hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] 64 | ov[:, :, i] = (q_i[..., None] * hv).sum(-2) 65 | 66 | if output_final_state: 67 | final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0]) 68 | ov = ov.transpose(1, 2).contiguous() 69 | return ov.to(dtype), final_state 70 | -------------------------------------------------------------------------------- /fla/ops/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_hgrn 4 | from .fused_recurrent import fused_recurrent_hgrn 5 | 6 | __all__ = [ 7 | 'chunk_hgrn', 8 | 'fused_recurrent_hgrn' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/hgrn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_hgrn( 9 | x: torch.Tensor, 10 | g: torch.Tensor, 11 | initial_state: Optional[torch.Tensor] = None, 12 | output_final_state: Optional[bool] = False 13 | ) -> torch.Tensor: 14 | dtype = x.dtype 15 | x, g = map(lambda i: i.float(), (x, g)) 16 | B, T, D = x.shape 17 | 18 | h = torch.zeros(B, D, dtype=torch.float, device=x.device) 19 | o = torch.zeros_like(x) 20 | 21 | final_state = None 22 | if initial_state is not None: 23 | h += initial_state 24 | 25 | for i in range(T): 26 | h = g[:, i].exp() * h + x[:, i] 27 | o[:, i] = h 28 | 29 | if output_final_state: 30 | 
final_state = h 31 | return o.to(dtype), final_state 32 | 33 | 34 | def naive_chunk_hgrn( 35 | x: torch.Tensor, 36 | g: torch.Tensor, 37 | initial_state: Optional[torch.Tensor] = None, 38 | output_final_state: Optional[bool] = False, 39 | chunk_size: int = 64 40 | ) -> torch.Tensor: 41 | dtype = x.dtype 42 | x, g = map(lambda i: i.float(), (x, g)) 43 | B, T, D = x.shape 44 | 45 | gc = g.view(B, chunk_size, D).cumsum(-2).view_as(g) 46 | h = torch.zeros(B, D, dtype=torch.float, device=x.device) 47 | o = torch.zeros_like(x) 48 | 49 | final_state = None 50 | if initial_state is not None: 51 | h += initial_state 52 | 53 | for i in range(0, T, chunk_size): 54 | hp = h 55 | h = torch.zeros(B, D, dtype=torch.float, device=x.device) 56 | for j in range(i, i + chunk_size): 57 | h = g[:, j].exp() * h + x[:, j] 58 | o[:, j] = hp * gc[:, j].exp() + h 59 | h = o[:, j].clone() 60 | 61 | if output_final_state: 62 | final_state = h 63 | return o.to(dtype), final_state 64 | -------------------------------------------------------------------------------- /fla/ops/lightning_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_lightning_attn 4 | from .fused_recurrent import fused_recurrent_lightning_attn 5 | 6 | __all__ = [ 7 | 'chunk_lightning_attn', 8 | 'fused_recurrent_lightning_attn' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/lightning_attn/chunk.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | 8 | from fla.ops.simple_gla.chunk import chunk_simple_gla 9 | 10 | 11 | @torch.compiler.disable 12 | def chunk_lightning_attn( 13 | q: torch.Tensor, 14 | k: torch.Tensor, 15 | v: torch.Tensor, 16 | layer_idx: int, 17 | num_layers: int, 18 | scale: Optional[float] = None, 19 | initial_state: Optional[torch.Tensor] = None, 20 | output_final_state: bool = False, 21 | cu_seqlens: Optional[torch.LongTensor] = None, 22 | head_first: bool = False 23 | ) -> Tuple[torch.Tensor, torch.Tensor]: 24 | r""" 25 | Args: 26 | q (torch.Tensor): 27 | queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 28 | k (torch.Tensor): 29 | keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 30 | v (torch.Tensor): 31 | values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 32 | layer_idx (int): 33 | The index of the current layer. 34 | num_layers (int): 35 | The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. 36 | scale (Optional[int]): 37 | Scale factor for the attention scores. 38 | If not provided, it will default to `1 / sqrt(K)`. Default: `None`. 39 | initial_state (Optional[torch.Tensor]): 40 | Initial state of shape `[N, H, K, V]` for `N` input sequences. 41 | For equal-length input sequences, `N` equals the batch size `B`. 42 | Default: `None`. 43 | output_final_state (Optional[bool]): 44 | Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. 45 | cu_seqlens (torch.LongTensor): 46 | Cumulative sequence lengths of shape `[N+1]` used for variable-length training, 47 | consistent with the FlashAttention API. 48 | head_first (Optional[bool]): 49 | Whether the inputs are in the head-first format, which is not supported for variable-length inputs. 
50 | Default: `False`. 51 | 52 | Returns: 53 | o (torch.Tensor): 54 | Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 55 | final_state (torch.Tensor): 56 | Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 57 | """ 58 | H = q.shape[1] if head_first else q.shape[2] 59 | s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float) 60 | if head_first: 61 | g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() 62 | else: 63 | g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() 64 | return chunk_simple_gla( 65 | q=q, 66 | k=k, 67 | v=v, 68 | scale=scale, 69 | g=g, 70 | initial_state=initial_state, 71 | output_final_state=output_final_state, 72 | head_first=head_first, 73 | cu_seqlens=cu_seqlens 74 | ) 75 | -------------------------------------------------------------------------------- /fla/ops/lightning_attn/fused_recurrent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | 8 | from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla 9 | 10 | 11 | def fused_recurrent_lightning_attn( 12 | q: torch.Tensor, 13 | k: torch.Tensor, 14 | v: torch.Tensor, 15 | layer_idx: int, 16 | num_layers: int, 17 | scale: Optional[float] = None, 18 | initial_state: Optional[torch.Tensor] = None, 19 | output_final_state: bool = False, 20 | reverse: bool = False, 21 | cu_seqlens: Optional[torch.LongTensor] = None, 22 | head_first: bool = False 23 | ) -> Tuple[torch.Tensor, torch.Tensor]: 24 | r""" 25 | Args: 26 | q (torch.Tensor): 27 | queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 28 | k (torch.Tensor): 29 | keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 30 | v (torch.Tensor): 31 | values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 32 | layer_idx (int): 33 | The index of the current layer. 34 | num_layers (int): 35 | The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. 36 | scale (Optional[int]): 37 | Scale factor for the attention scores. 38 | If not provided, it will default to `1 / sqrt(K)`. Default: `None`. 39 | initial_state (Optional[torch.Tensor]): 40 | Initial state of shape `[N, H, K, V]` for `N` input sequences. 41 | For equal-length input sequences, `N` equals the batch size `B`. 42 | Default: `None`. 43 | output_final_state (Optional[bool]): 44 | Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. 45 | cu_seqlens (torch.LongTensor): 46 | Cumulative sequence lengths of shape `[N+1]` used for variable-length training, 47 | consistent with the FlashAttention API. 48 | head_first (Optional[bool]): 49 | Whether the inputs are in the head-first format, which is not supported for variable-length inputs. 50 | Default: `False`. 51 | 52 | Returns: 53 | o (torch.Tensor): 54 | Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 55 | final_state (torch.Tensor): 56 | Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
57 | """ 58 | H = q.shape[1] if head_first else q.shape[2] 59 | s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float) 60 | if head_first: 61 | g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() 62 | else: 63 | g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() 64 | return fused_recurrent_simple_gla( 65 | q=q, 66 | k=k, 67 | v=v, 68 | g=g, 69 | scale=scale, 70 | initial_state=initial_state, 71 | output_final_state=output_final_state, 72 | reverse=reverse, 73 | cu_seqlens=cu_seqlens, 74 | head_first=head_first 75 | ) 76 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_linear_attn 4 | from .fused_chunk import fused_chunk_linear_attn 5 | from .fused_recurrent import fused_recurrent_linear_attn 6 | 7 | __all__ = [ 8 | 'chunk_linear_attn', 9 | 'fused_chunk_linear_attn', 10 | 'fused_recurrent_linear_attn' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/chunk.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Yu Zhang, Songlin Yang 3 | 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | 8 | from fla.ops.linear_attn.utils import normalize_output 9 | from fla.ops.simple_gla import chunk_simple_gla 10 | 11 | 12 | @torch.compiler.disable 13 | def chunk_linear_attn( 14 | q: torch.Tensor, 15 | k: torch.Tensor, 16 | v: torch.Tensor, 17 | scale: Optional[float] = None, 18 | initial_state: Optional[torch.Tensor] = None, 19 | output_final_state: bool = False, 20 | normalize: bool = True, 21 | head_first: bool = False 22 | ) -> Tuple[torch.Tensor, torch.Tensor]: 23 | r""" 24 | Args: 25 | q (torch.Tensor): 26 | queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` 27 | k (torch.Tensor): 28 | keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` 29 | v (torch.Tensor): 30 | values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` 31 | scale (Optional[int]): 32 | Scale factor for the linear attention scores. 33 | If not provided, it will default to `1 / sqrt(K)`. Default: `None`. 34 | initial_state (Optional[torch.Tensor]): 35 | Initial state of shape `[B, H, K, V]`. Default: `None`. 36 | output_final_state (Optional[bool]): 37 | Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. 38 | normalize (bool): 39 | Whether to normalize the output. Default: `True`. 40 | head_first (Optional[bool]): 41 | Whether the inputs are in the head-first format. Default: `False`. 42 | 43 | Returns: 44 | o (torch.Tensor): 45 | Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` 46 | final_state (torch.Tensor): 47 | Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` 48 | """ 49 | 50 | if scale is None: 51 | scale = k.shape[-1] ** -0.5 52 | if head_first: 53 | raise DeprecationWarning( 54 | "head_first is deprecated and will be removed in a future version. " 55 | "Please use head_first=False for now instead." 
56 | ) 57 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 58 | if not head_first: 59 | if q.shape[1] < q.shape[2]: 60 | raise DeprecationWarning( 61 | f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " 62 | "This may indicate the inputs were passed in head-first format [B, H, T, ...] " 63 | "when head_first=False was specified. " 64 | "Please verify your input tensor format matches the expected shape [B, T, H, ...]." 65 | ) 66 | o, final_state = chunk_simple_gla( 67 | q=q, 68 | k=k, 69 | v=v, 70 | scale=scale, 71 | g=None, 72 | initial_state=initial_state, 73 | output_final_state=output_final_state 74 | ) 75 | if normalize: 76 | o = normalize_output(q * scale, k, o) 77 | if head_first: 78 | o = o.transpose(1, 2) 79 | return o, final_state 80 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from einops import rearrange 8 | 9 | from fla.ops.linear_attn.utils import normalize_output 10 | 11 | 12 | def naive_chunk_linear_attn( 13 | q: torch.Tensor, 14 | k: torch.Tensor, 15 | v: torch.Tensor, 16 | scale: Optional[float] = None, 17 | normalize: bool = False 18 | ) -> Tuple[torch.Tensor, torch.Tensor]: 19 | if scale is None: 20 | scale = q.shape[-1] ** -0.5 21 | chunk_size = 64 22 | q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale 23 | k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) 24 | v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) 25 | kv = k.transpose(-1, -2) @ v 26 | kv = kv.cumsum(2) 27 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 28 | inter = q @ kv 29 | intra = (( 30 | q @ k.transpose(-1, -2)).masked_fill_( 31 | torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 32 | 0 33 | )) @ v 34 | o = inter + intra 35 | if normalize: 36 | o = normalize_output(q * scale, k, o) 37 | return rearrange(o, 'b h n c d -> b (n c) h d') 38 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | @torch.jit.script 7 | def normalize_output(q, k, o): 8 | k = k.cumsum(-2) 9 | z = (q * k).sum(-1, keepdim=True) 10 | return o / (z + 1e-10) 11 | -------------------------------------------------------------------------------- /fla/ops/mesa_net/__init__.py: -------------------------------------------------------------------------------- 1 | from .chunk import chunk_mesa_net 2 | from .decoding_one_step import mesa_net_decoding_one_step 3 | from .naive import naive_mesa_net_decoding_one_step, naive_mesa_net_exact 4 | 5 | __all__ = ['chunk_mesa_net', 'naive_mesa_net_exact', 'mesa_net_decoding_one_step', 'naive_mesa_net_decoding_one_step'] 6 | -------------------------------------------------------------------------------- /fla/ops/nsa/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .naive import naive_nsa 4 | from .parallel import parallel_nsa 5 | 6 | __all__ = [ 7 | 'naive_nsa', 8 | 'parallel_nsa' 9 | ] 10 | 
-------------------------------------------------------------------------------- /fla/ops/path_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_path_attention 4 | 5 | __all__ = [ 6 | 'parallel_path_attention' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/path_attn/prepare_k_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets 6 | 7 | 8 | @triton.heuristics({ 9 | 'IS_VARLEN': lambda args: args['offsets'] is not None 10 | }) 11 | @triton.jit(do_not_specialize=['T']) 12 | def parallel_path_fwd_kernel_prepare_k_cache( 13 | k, k_new, h, 14 | offsets, indices, chunk_offsets, 15 | T, 16 | H: tl.constexpr, 17 | K: tl.constexpr, 18 | BT: tl.constexpr, BK: tl.constexpr, 19 | IS_VARLEN: tl.constexpr 20 | ): 21 | i_t, i_bh = tl.program_id(0), tl.program_id(1) 22 | i_b, i_h = i_bh // H, i_bh % H 23 | 24 | if IS_VARLEN: 25 | i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) 26 | boh = tl.load(chunk_offsets + i_n).to(tl.int32) 27 | bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) 28 | T = eos - bos 29 | else: 30 | i_n = i_b 31 | bos, eos = i_n * T, i_n * T + T 32 | NT = triton.cdiv(T, BT) 33 | boh = i_n * NT 34 | 35 | # offset calculations 36 | k += (bos * H + i_h) * K # GQA when H!=HQ 37 | k_new += (bos * H + i_h) * K # GQA when H!=HQ 38 | h += (boh * H + i_h) * K * K 39 | # constants 40 | stride_h = H * K * K 41 | p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) 42 | b_k = tl.zeros([BT, BK], dtype=tl.float32) 43 | b_k += tl.load(p_k, boundary_check=(0, 1)) 44 | for k_block_idx in range(i_t + 1, tl.cdiv(T, BT)): 45 | p_h = tl.make_block_ptr(h + k_block_idx * stride_h, (K, K), (1, K), (0, 0), (BK, BK), (0, 1)) 46 | b_h = tl.load(p_h, boundary_check=(0, 1)) 47 | b_k_minus = tl.dot(b_k.to(b_h.dtype), b_h) 48 | b_k = b_k - b_k_minus 49 | p_k_new = tl.make_block_ptr(k_new, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) 50 | tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) 51 | 52 | 53 | def prepare_k_cache_fn(k, h, cu_seqlens, BS, use_cache=False): 54 | if not use_cache: 55 | return None 56 | else: 57 | B, T, H, K = k.shape 58 | k_new = torch.empty_like(k) 59 | indices = prepare_chunk_indices(cu_seqlens, BS) if cu_seqlens is not None else None 60 | chunk_offsets = prepare_chunk_offsets(cu_seqlens, BS) if cu_seqlens is not None else None 61 | NT = triton.cdiv(T, BS) if cu_seqlens is None else len(indices) 62 | grid = (NT, B * H) 63 | parallel_path_fwd_kernel_prepare_k_cache[grid]( 64 | k=k, 65 | k_new=k_new, 66 | h=h, 67 | offsets=cu_seqlens, 68 | indices=indices, 69 | chunk_offsets=chunk_offsets, 70 | H=H, 71 | T=T, 72 | K=K, 73 | BT=BS, 74 | BK=triton.next_power_of_2(K) 75 | ) 76 | return k_new 77 | -------------------------------------------------------------------------------- /fla/ops/rebased/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_rebased 4 | 5 | __all__ = [ 6 | 'parallel_rebased' 7 | ] 8 | -------------------------------------------------------------------------------- 
/fla/ops/rebased/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | 8 | 9 | def naive_parallel_rebased( 10 | q: torch.Tensor, 11 | k: torch.Tensor, 12 | v: torch.Tensor, 13 | scale: Optional[float] = None, 14 | use_norm: bool = True, 15 | ) -> torch.Tensor: 16 | if scale is None: 17 | scale = q.shape[-1] ** -0.5 18 | q = q * scale 19 | attn = q @ k.transpose(-2, -1) 20 | attn = attn ** 2 21 | attn.masked_fill_(~torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) 22 | o = attn @ v 23 | if use_norm: 24 | z = attn.sum(-1) 25 | return o / (z[..., None] + 1e-6) 26 | else: 27 | return o 28 | -------------------------------------------------------------------------------- /fla/ops/retention/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_retention 4 | from .fused_chunk import fused_chunk_retention 5 | from .fused_recurrent import fused_recurrent_retention 6 | from .parallel import parallel_retention 7 | 8 | __all__ = [ 9 | 'chunk_retention', 10 | 'fused_chunk_retention', 11 | 'parallel_retention', 12 | 'fused_recurrent_retention' 13 | ] 14 | -------------------------------------------------------------------------------- /fla/ops/retention/fused_recurrent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | 8 | from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla 9 | 10 | 11 | def fused_recurrent_retention( 12 | q: torch.Tensor, 13 | k: torch.Tensor, 14 | v: torch.Tensor, 15 | scale: Optional[float] = None, 16 | initial_state: Optional[torch.Tensor] = None, 17 | output_final_state: bool = False, 18 | reverse: bool = False, 19 | cu_seqlens: Optional[torch.LongTensor] = None, 20 | ) -> Tuple[torch.Tensor, torch.Tensor]: 21 | s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() 22 | g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() 23 | o, final_state = fused_recurrent_simple_gla( 24 | q=q, 25 | k=k, 26 | v=v, 27 | g=g, 28 | scale=scale, 29 | initial_state=initial_state, 30 | output_final_state=output_final_state, 31 | reverse=reverse, 32 | cu_seqlens=cu_seqlens, 33 | ) 34 | return o, final_state 35 | -------------------------------------------------------------------------------- /fla/ops/retention/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def naive_retention(q, k, v): 7 | orig_type = q.dtype 8 | q, k, v = q.float(), k.float(), v.float() 9 | _, n_heads, seq_len, d_head = q.shape 10 | s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. 
- q.new_tensor(range(n_heads), dtype=torch.float))).log2() 11 | n = q.new_tensor(range(seq_len), dtype=torch.float) 12 | n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) 13 | s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) 14 | o = torch.einsum('bhqk,bhkd->bhqd', s, v) 15 | return o.to(orig_type) 16 | -------------------------------------------------------------------------------- /fla/ops/retention/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | import warnings 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | from einops import rearrange 9 | 10 | from fla.ops.simple_gla.parallel import parallel_simple_gla 11 | 12 | 13 | def parallel_retention( 14 | q: torch.Tensor, 15 | k: torch.Tensor, 16 | v: torch.Tensor, 17 | scale: Optional[float] = None, 18 | output_attentions: bool = False, 19 | cu_seqlens: Optional[torch.LongTensor] = None, 20 | head_first: bool = False 21 | ) -> Tuple[torch.Tensor, torch.Tensor]: 22 | r""" 23 | Args: 24 | q (torch.Tensor): 25 | queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` 26 | k (torch.Tensor): 27 | keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` 28 | v (torch.Tensor): 29 | values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` 30 | scale (Optional[int]): 31 | Scale factor for attention scores. 32 | If not provided, it will default to `1 / sqrt(K)`. Default: `None`. 33 | output_attentions (bool): 34 | Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`. 35 | cu_seqlens (torch.LongTensor): 36 | Cumulative sequence lengths of shape `[N+1]` used for variable-length training, 37 | consistent with the FlashAttention API. 38 | head_first (Optional[bool]): 39 | Whether the inputs are in the head-first format. Default: `False`. 40 | 41 | Returns: 42 | o (torch.Tensor): 43 | Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 44 | attn (torch.Tensor): 45 | Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None` 46 | """ 47 | if head_first: 48 | raise DeprecationWarning( 49 | "head_first is deprecated and will be removed in a future version. " 50 | "Please use head_first=False for now instead." 51 | ) 52 | q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) 53 | if not head_first and q.shape[1] < q.shape[2]: 54 | warnings.warn( 55 | f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " 56 | "This may indicate the inputs were passed in head-first format [B, H, T, ...] " 57 | "when head_first=False was specified. " 58 | "Please verify your input tensor format matches the expected shape [B, T, H, ...]." 59 | ) 60 | s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() 61 | g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]) 62 | 63 | o, attn = parallel_simple_gla( 64 | q=q, 65 | k=k, 66 | v=v, 67 | scale=scale, 68 | g=g, 69 | output_attentions=output_attentions, 70 | cu_seqlens=cu_seqlens 71 | ) 72 | if head_first: 73 | o = rearrange(o, 'b t h ... 
-> b h t ...') 74 | return o, attn 75 | -------------------------------------------------------------------------------- /fla/ops/rwkv4/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .fused_recurrent import fused_recurrent_rwkv4 4 | 5 | __all__ = [ 6 | 'fused_recurrent_rwkv4' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_rwkv6 4 | from .fused_recurrent import fused_recurrent_rwkv6 5 | from .recurrent_naive import native_recurrent_rwkv6 6 | __all__ = [ 7 | 'chunk_rwkv6', 8 | 'fused_recurrent_rwkv6', 9 | 'native_recurrent_rwkv6', 10 | ] 11 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/chunk_naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def naive_chunk_rwkv6( 8 | q: torch.Tensor, 9 | k: torch.Tensor, 10 | v: torch.Tensor, 11 | w: torch.Tensor, 12 | u: torch.Tensor, 13 | chunk_size: int = 32 14 | ): 15 | assert q.shape[-2] % chunk_size == 0 16 | orig_dtype = q.dtype 17 | num_chunk = q.shape[-2] // chunk_size 18 | u = u.unsqueeze(0) 19 | 20 | q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) 21 | 22 | w_cumsum = w.cumsum(-2) 23 | 24 | kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() 25 | wkv = kw.transpose(-1, -2) @ v 26 | 27 | wkv_new = torch.zeros_like(wkv) 28 | 29 | for i in range(num_chunk - 1): 30 | wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] 31 | 32 | o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) 33 | 34 | o_intra = torch.zeros_like(o_inter) 35 | for i in range(chunk_size): 36 | attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) 37 | mask = (torch.arange(0, chunk_size) < i).to(attn.device) 38 | attn.masked_fill_(~mask, 0) 39 | intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) 40 | intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] 41 | o_intra[:, :, :, i] = intra_inter_o + intra_intra_o 42 | o = o_inter + o_intra 43 | return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) 44 | -------------------------------------------------------------------------------- /fla/ops/rwkv7/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .channel_mixing import channel_mixing_rwkv7 4 | from .chunk import chunk_rwkv7 5 | from .fused_addcmul import fused_addcmul_rwkv7, torch_addcmul_rwkv7 6 | from .fused_k_update import fused_k_rwkv7 7 | from .fused_recurrent import fused_mul_recurrent_rwkv7, fused_recurrent_rwkv7 8 | from .recurrent_naive import native_recurrent_rwkv7 9 | 10 | __all__ = [ 11 | 'channel_mixing_rwkv7', 12 | 'chunk_rwkv7', 13 | 'fused_addcmul_rwkv7', 14 | 'torch_addcmul_rwkv7', 15 | 'fused_k_rwkv7', 16 | 'fused_recurrent_rwkv7', 17 | 'fused_mul_recurrent_rwkv7', 18 | 'native_recurrent_rwkv7', 19 | ] 20 | -------------------------------------------------------------------------------- /fla/ops/rwkv7/chunk.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | 8 | from fla.ops.generalized_delta_rule import chunk_dplr_delta_rule 9 | 10 | 11 | @torch.compile(fullgraph=True) 12 | def cal_log_w(w: torch.Tensor) -> torch.Tensor: 13 | return -torch.exp(w) 14 | 15 | 16 | def chunk_rwkv7( 17 | r: torch.Tensor, 18 | k: torch.Tensor, 19 | v: torch.Tensor, 20 | a: torch.Tensor, 21 | b: torch.Tensor, 22 | w: torch.Tensor = None, 23 | log_w: torch.Tensor = None, 24 | scale: float = 1.0, 25 | initial_state: torch.Tensor = None, 26 | output_final_state: bool = True, 27 | cu_seqlens: Optional[torch.LongTensor] = None, 28 | head_first: bool = False 29 | ): 30 | """ 31 | Args: 32 | r (torch.Tensor): 33 | r of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 34 | w (torch.Tensor): 35 | log decay of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 36 | k (torch.Tensor): 37 | k of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 38 | v (torch.Tensor): 39 | v of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 40 | a (torch.Tensor): 41 | a of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 42 | b (torch.Tensor): 43 | b of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 44 | scale (float): 45 | scale of the attention. 46 | initial_state (Optional[torch.Tensor]): 47 | Initial state of shape `[N, H, K, V]` for `N` input sequences. 48 | For equal-length input sequences, `N` equals the batch size `B`. 49 | Default: `None`. 50 | output_final_state (Optional[bool]): 51 | Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. 52 | cu_seqlens (torch.LongTensor): 53 | Cumulative sequence lengths of shape `[N+1]` used for variable-length training, 54 | consistent with the FlashAttention API. 55 | head_first (bool): 56 | whether to use head first. Recommended to be False to avoid extra transposes. 57 | Default: `False`. 58 | """ 59 | if w is not None: 60 | log_w = cal_log_w(w) 61 | else: 62 | assert log_w is not None, "Either w or log_w must be provided!" 63 | 64 | return chunk_dplr_delta_rule( 65 | q=r, 66 | k=k, 67 | v=v, 68 | a=a, 69 | b=b, 70 | gk=log_w, 71 | scale=scale, 72 | initial_state=initial_state, 73 | output_final_state=output_final_state, 74 | cu_seqlens=cu_seqlens, 75 | head_first=head_first 76 | ) 77 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/README.md: -------------------------------------------------------------------------------- 1 | # Simple GLA 2 | 3 | Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet). 4 | 5 | Compared to GLA, the gating is head-wise instead of elementwise. 6 | As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. 7 | It is faster than GLA but has less expressive power. 8 | I will use it as a baseline for the GLA. 9 | 10 | $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. 
11 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_simple_gla 4 | from .fused_recurrent import fused_recurrent_simple_gla 5 | from .parallel import parallel_simple_gla 6 | 7 | __all__ = [ 8 | 'chunk_simple_gla', 9 | 'fused_recurrent_simple_gla', 10 | 'parallel_simple_gla' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None): 8 | if scale is None: 9 | scale = (q.shape[-1] ** -0.5) 10 | q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale 11 | k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) 12 | v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) 13 | g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size) 14 | g = g.cumsum(-1) 15 | kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) 16 | S = torch.zeros_like(kv) 17 | 18 | for i in range(1, g.shape[-2]): 19 | S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] 20 | 21 | inter = (q * g[..., None].exp()) @ S 22 | attn = q @ k.transpose(-1, -2) 23 | attn = attn * (g[..., None] - g[..., None, :]).exp() 24 | attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) 25 | intra = attn @ v 26 | o = inter + intra 27 | return rearrange(o, 'b h n c d -> b h (n c) d') 28 | 29 | 30 | def torch_simple_gla_recurrent(q, k, v, g, scale=None, initial_state=None, output_final_state=True): 31 | B, H, T, DK = q.shape 32 | original_dtype = q.dtype 33 | q, k, v, g = q.float(), k.float(), v.float(), g.float() 34 | if scale is None: 35 | scale = DK ** -0.5 36 | q = q * scale 37 | _, _, _, DV = v.shape 38 | if initial_state is None: 39 | S = torch.zeros(B, H, DK, DV) 40 | else: 41 | S = initial_state 42 | o = torch.zeros(B, H, T, DV).to(q) 43 | for i in range(T): 44 | gate = g[:, :, i].exp() 45 | key = k[:, :, i] 46 | value = v[:, :, i] 47 | kv = key.unsqueeze(-1) * value.unsqueeze(-2) 48 | S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv 49 | q_i = q[:, :, i, :] 50 | o_i = (q_i.unsqueeze(-1) * S).sum(-2) 51 | o[:, :, i] = o_i 52 | if not output_final_state: 53 | S = None 54 | return o.to(original_dtype), S 55 | -------------------------------------------------------------------------------- /fla/ops/titans/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .naive import chunk_titans_linear 4 | 5 | __all__ = [ 6 | 'chunk_titans_linear' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/ttt/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_ttt_linear 4 | from .fused_chunk import fused_chunk_ttt_linear 5 | 6 | __all__ = [ 7 | 'fused_chunk_ttt_linear', 8 | 'chunk_ttt_linear' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 
-*- coding: utf-8 -*- 2 | 3 | from .asm import fp32_to_tf32_asm 4 | from .cumsum import ( 5 | chunk_global_cumsum, 6 | chunk_global_cumsum_scalar, 7 | chunk_global_cumsum_vector, 8 | chunk_local_cumsum, 9 | chunk_local_cumsum_scalar, 10 | chunk_local_cumsum_vector 11 | ) 12 | from .index import ( 13 | prepare_chunk_indices, 14 | prepare_chunk_offsets, 15 | prepare_cu_seqlens_from_mask, 16 | prepare_lens, 17 | prepare_lens_from_mask, 18 | prepare_position_ids, 19 | prepare_sequence_ids, 20 | prepare_token_indices 21 | ) 22 | from .logsumexp import logsumexp_fwd 23 | from .matmul import addmm, matmul 24 | from .pack import pack_sequence, unpack_sequence 25 | from .pooling import mean_pooling 26 | from .softmax import softmax_bwd, softmax_fwd 27 | from .solve_tril import solve_tril 28 | 29 | __all__ = [ 30 | 'chunk_global_cumsum', 31 | 'chunk_global_cumsum_scalar', 32 | 'chunk_global_cumsum_vector', 33 | 'chunk_local_cumsum', 34 | 'chunk_local_cumsum_scalar', 35 | 'chunk_local_cumsum_vector', 36 | 'pack_sequence', 37 | 'unpack_sequence', 38 | 'prepare_chunk_indices', 39 | 'prepare_chunk_offsets', 40 | 'prepare_cu_seqlens_from_mask', 41 | 'prepare_lens', 42 | 'prepare_lens_from_mask', 43 | 'prepare_position_ids', 44 | 'prepare_sequence_ids', 45 | 'prepare_token_indices', 46 | 'logsumexp_fwd', 47 | 'addmm', 48 | 'matmul', 49 | 'mean_pooling', 50 | 'softmax_bwd', 51 | 'softmax_fwd', 52 | 'fp32_to_tf32_asm', 53 | 'solve_tril', 54 | ] 55 | -------------------------------------------------------------------------------- /fla/ops/utils/asm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.utils import device_platform 4 | 5 | 6 | def fp32_to_tf32_asm() -> str: 7 | """ 8 | Get the assembly code for converting FP32 to TF32. 
9 | """ 10 | ASM_DICT = { 11 | 'nvidia': 'cvt.rna.tf32.f32 $0, $1;' 12 | } 13 | if device_platform in ASM_DICT: 14 | return ASM_DICT[device_platform] 15 | else: 16 | # return empty string if the device is not supported 17 | return "" 18 | -------------------------------------------------------------------------------- /fla/ops/utils/index.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import triton 7 | import triton.language as tl 8 | 9 | from fla.utils import tensor_cache 10 | 11 | 12 | @triton.autotune( 13 | configs=[ 14 | triton.Config({}, num_warps=num_warps) 15 | for num_warps in [4, 8, 16, 32] 16 | ], 17 | key=['B'], 18 | ) 19 | @triton.jit 20 | def prepare_position_ids_kernel( 21 | y, 22 | cu_seqlens, 23 | B: tl.constexpr 24 | ): 25 | i_n = tl.program_id(0) 26 | bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) 27 | T = eos - bos 28 | 29 | o = tl.arange(0, B) 30 | for i in range(0, tl.cdiv(T, B) * B, B): 31 | o_i = o + i 32 | tl.store(y + bos + o_i, o_i, o_i < T) 33 | 34 | 35 | @tensor_cache 36 | def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: 37 | return cu_seqlens[1:] - cu_seqlens[:-1] 38 | 39 | 40 | @tensor_cache 41 | def prepare_lens_from_mask(mask: torch.BoolTensor) -> torch.LongTensor: 42 | return mask.sum(dim=-1, dtype=torch.int32) 43 | 44 | 45 | @tensor_cache 46 | def prepare_cu_seqlens_from_mask(mask: torch.BoolTensor, out_dtype: torch.dtype = torch.int32) -> torch.LongTensor: 47 | return F.pad(prepare_lens_from_mask(mask).cumsum(dim=0, dtype=out_dtype), (1, 0)) 48 | 49 | 50 | @tensor_cache 51 | def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: 52 | return torch.cat([ 53 | torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) 54 | for n in prepare_lens(cu_seqlens).unbind() 55 | ]) 56 | 57 | 58 | @tensor_cache 59 | def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: 60 | return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 61 | 62 | 63 | @tensor_cache 64 | def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: 65 | position_ids = prepare_position_ids(cu_seqlens) 66 | return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) 67 | 68 | 69 | @tensor_cache 70 | def prepare_chunk_indices( 71 | cu_seqlens: torch.LongTensor, 72 | chunk_size: int 73 | ) -> torch.LongTensor: 74 | indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) 75 | return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) 76 | 77 | 78 | @tensor_cache 79 | def prepare_chunk_offsets( 80 | cu_seqlens: torch.LongTensor, 81 | chunk_size: int 82 | ) -> torch.LongTensor: 83 | return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) 84 | -------------------------------------------------------------------------------- /fla/ops/utils/logcumsumexp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 3 | 4 | import triton 5 | import triton.language as tl 6 | 7 | from fla.ops.utils.op import exp, log 8 | 9 | 10 | @triton.autotune( 11 | configs=[ 12 | triton.Config({'BT': BT}, num_warps=num_warps) 13 | for BT in [16, 32, 64] 14 | for 
num_warps in [2, 4, 8] 15 | ], 16 | key=['S'] 17 | ) 18 | @triton.jit(do_not_specialize=['T']) 19 | def logcumsumexp_fwd_kernel( 20 | s, 21 | z, 22 | T, 23 | S: tl.constexpr, 24 | BT: tl.constexpr 25 | ): 26 | i_bh = tl.program_id(0) 27 | o_i = tl.arange(0, BT) 28 | m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 29 | 30 | b_mp = tl.full([S,], float('-inf'), dtype=tl.float32) 31 | b_zp = tl.zeros([S,], dtype=tl.float32) 32 | for i_t in range(tl.cdiv(T, BT)): 33 | p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0)) 34 | p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0)) 35 | 36 | # [BT, S] 37 | b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) 38 | # [S,] 39 | b_mc = tl.max(b_s, 0) 40 | b_mc = tl.maximum(b_mp, b_mc) 41 | b_zp = b_zp * exp(b_mp - b_mc) 42 | # [BT, S] 43 | b_s = exp(b_s - b_mc) 44 | b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp 45 | # [S,] 46 | b_zc = tl.max(b_z, 0) 47 | b_mp = b_mc 48 | b_zp = b_zc 49 | # [BT, BS] 50 | # small eps to prevent underflows 51 | b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc 52 | tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1)) 53 | -------------------------------------------------------------------------------- /fla/ops/utils/logsumexp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | import triton 8 | import triton.language as tl 9 | 10 | from fla.ops.utils.op import exp, log 11 | 12 | 13 | @triton.heuristics({ 14 | 'HAS_SCALE': lambda args: args['scale'] is not None 15 | }) 16 | @triton.autotune( 17 | configs=[ 18 | triton.Config({}, num_warps=num_warps) 19 | for num_warps in [1, 2, 4, 8, 16, 32] 20 | ], 21 | key=['D'] 22 | ) 23 | @triton.jit 24 | def logsumexp_fwd_kernel( 25 | x, 26 | z, 27 | scale, 28 | D: tl.constexpr, 29 | B: tl.constexpr, 30 | HAS_SCALE: tl.constexpr 31 | ): 32 | i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) 33 | o_d = i_d * B + tl.arange(0, B) 34 | m_d = o_d < D 35 | 36 | b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')) 37 | if HAS_SCALE: 38 | b_x = b_x * scale 39 | b_m = tl.max(b_x, 0) 40 | b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m 41 | tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z) 42 | 43 | 44 | def logsumexp_fwd( 45 | x, 46 | scale: Optional[float] = None, 47 | dtype: Optional[torch.dtype] = None 48 | ): 49 | r""" 50 | Compute the logsumexp of the input tensor over the last dimension. 51 | 52 | Args: 53 | x (Tensor): 54 | The input tensor of any shape. 55 | scale (Optional[float]): 56 | The scale applied to the input tensor. Default: `None`. 57 | dtype (Optional[torch.dtype]): 58 | The data type of the output tensor. Default: `None`. 59 | Returns: 60 | Tensor: The logsumexp of the input tensor. 
61 | """ 62 | 63 | shape = x.shape 64 | x = x.view(-1, shape[-1]) 65 | N, D = x.shape 66 | B = min(triton.next_power_of_2(D), 64 * 1024) 67 | ND = triton.cdiv(D, B) 68 | 69 | z = x.new_empty(N, ND, dtype=torch.float) 70 | logsumexp_fwd_kernel[(N, ND)]( 71 | x=x, 72 | z=z, 73 | scale=scale, 74 | D=D, 75 | B=B 76 | ) 77 | z = z.logsumexp(-1).view(*shape[:-1]) 78 | if dtype is not None and dtype != torch.float: 79 | z = z.to(dtype) 80 | return z 81 | -------------------------------------------------------------------------------- /fla/ops/utils/op.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2024, Songlin Yang, Yu Zhang 3 | 4 | import os 5 | 6 | import triton 7 | import triton.language as tl 8 | import triton.language.extra.libdevice as tldevice 9 | 10 | from fla.utils import is_gather_supported 11 | 12 | if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': 13 | div = tldevice.fast_dividef 14 | exp = tldevice.fast_expf 15 | log = tldevice.fast_logf 16 | log2 = tldevice.fast_log2f 17 | else: 18 | @triton.jit 19 | def div_normal(x, y): 20 | return x / y 21 | div = div_normal 22 | exp = tl.exp 23 | log = tl.log 24 | log2 = tl.log2 25 | 26 | 27 | @triton.jit 28 | def safe_exp(x): 29 | return exp(tl.where(x <= 0, x, float('-inf'))) 30 | 31 | 32 | if not is_gather_supported: 33 | @triton.jit 34 | def gather(src, index, axis, _builder=None): 35 | # This is a fallback implementation when tl.gather is not supported 36 | # In order to pass triton compiler, there is no actual gather operation 37 | return src 38 | else: 39 | gather = tl.gather 40 | -------------------------------------------------------------------------------- /fla/ops/utils/softmax.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2023-2024, Songlin Yang, Yu Zhang 3 | 4 | from typing import Optional 5 | 6 | import torch 7 | import triton 8 | import triton.language as tl 9 | 10 | from fla.ops.utils.op import exp 11 | 12 | 13 | @triton.autotune( 14 | configs=[ 15 | triton.Config({}, num_warps=1), 16 | triton.Config({}, num_warps=2), 17 | triton.Config({}, num_warps=4), 18 | triton.Config({}, num_warps=8), 19 | triton.Config({}, num_warps=16), 20 | triton.Config({}, num_warps=32) 21 | ], 22 | key=['D'] 23 | ) 24 | @triton.jit 25 | def softmax_fwd_kernel( 26 | x, 27 | p, 28 | D: tl.constexpr, 29 | B: tl.constexpr 30 | ): 31 | i_n = tl.program_id(0) 32 | o_d = tl.arange(0, B) 33 | m_d = o_d < D 34 | 35 | b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')) 36 | b_m = tl.max(b_x, 0) 37 | b_x = exp(b_x - b_m) 38 | b_p = b_x / tl.sum(b_x, 0) 39 | 40 | tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d) 41 | 42 | 43 | @triton.autotune( 44 | configs=[ 45 | triton.Config({}, num_warps=1), 46 | triton.Config({}, num_warps=2), 47 | triton.Config({}, num_warps=4), 48 | triton.Config({}, num_warps=8), 49 | triton.Config({}, num_warps=16), 50 | triton.Config({}, num_warps=32) 51 | ], 52 | key=['D'] 53 | ) 54 | @triton.jit 55 | def softmax_bwd_kernel( 56 | p, 57 | dp, 58 | ds, 59 | D: tl.constexpr, 60 | B: tl.constexpr 61 | ): 62 | i_n = tl.program_id(0) 63 | o_d = tl.arange(0, B) 64 | m_d = o_d < D 65 | 66 | b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.) 67 | b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.) 
68 | b_pp = tl.sum(b_p * b_dp, 0) 69 | b_ds = b_p * b_dp - b_p * b_pp 70 | tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d) 71 | 72 | 73 | def softmax_fwd( 74 | x: torch.Tensor, 75 | dtype: Optional[torch.dtype] = torch.float 76 | ) -> torch.Tensor: 77 | shape = x.shape 78 | x = x.view(-1, x.shape[-1]) 79 | 80 | N, D = x.shape 81 | B = triton.next_power_of_2(D) 82 | 83 | p = torch.empty_like(x, dtype=dtype) 84 | softmax_fwd_kernel[(N,)]( 85 | x=x, 86 | p=p, 87 | D=D, 88 | B=B 89 | ) 90 | return p.view(*shape) 91 | 92 | 93 | def softmax_bwd( 94 | p: torch.Tensor, 95 | dp: torch.Tensor, 96 | dtype: Optional[torch.dtype] = torch.float 97 | ) -> torch.Tensor: 98 | shape = p.shape 99 | p = p.view(-1, p.shape[-1]) 100 | ds = torch.empty_like(p, dtype=dtype) 101 | 102 | N, D = p.shape 103 | B = triton.next_power_of_2(D) 104 | softmax_bwd_kernel[(N,)]( 105 | p=p, 106 | dp=dp, 107 | ds=ds, 108 | D=D, 109 | B=B 110 | ) 111 | return ds.view(*shape) 112 | -------------------------------------------------------------------------------- /legacy/training/configs/gla_1B.json: -------------------------------------------------------------------------------- 1 | { 2 | "attn_mode": "chunk", 3 | "bos_token_id": 1, 4 | "clamp_min": null, 5 | "eos_token_id": 2, 6 | "expand_k": 0.5, 7 | "expand_v": 1, 8 | "fuse_cross_entropy": true, 9 | "fuse_norm": true, 10 | "hidden_act": "swish", 11 | "hidden_ratio": 4, 12 | "hidden_size": 2048, 13 | "initializer_range": 0.02, 14 | "intermediate_size": null, 15 | "model_type": "gla", 16 | "num_heads": 4, 17 | "num_hidden_layers": 24, 18 | "norm_eps": 1e-06, 19 | "tie_word_embeddings": false, 20 | "transformers_version": "4.45.0", 21 | "use_cache": true, 22 | "use_gk": true, 23 | "use_gv": false, 24 | "vocab_size": 32000 25 | } -------------------------------------------------------------------------------- /legacy/training/configs/gla_340M.json: -------------------------------------------------------------------------------- 1 | { 2 | "attn_mode": "chunk", 3 | "bos_token_id": 1, 4 | "clamp_min": null, 5 | "eos_token_id": 2, 6 | "expand_k": 0.5, 7 | "expand_v": 1, 8 | "fuse_cross_entropy": true, 9 | "fuse_norm": true, 10 | "hidden_act": "swish", 11 | "hidden_ratio": 4, 12 | "hidden_size": 1024, 13 | "initializer_range": 0.02, 14 | "intermediate_size": null, 15 | "model_type": "gla", 16 | "num_heads": 4, 17 | "num_hidden_layers": 24, 18 | "norm_eps": 1e-06, 19 | "tie_word_embeddings": true, 20 | "use_cache": true, 21 | "use_gk": true, 22 | "use_gv": false, 23 | "vocab_size": 32000 24 | } -------------------------------------------------------------------------------- /legacy/training/configs/gla_7B.json: -------------------------------------------------------------------------------- 1 | { 2 | "attn_mode": "chunk", 3 | "bos_token_id": 1, 4 | "clamp_min": null, 5 | "eos_token_id": 2, 6 | "expand_k": 1, 7 | "expand_v": 1, 8 | "feature_map": "relu", 9 | "fuse_cross_entropy": true, 10 | "fuse_norm": true, 11 | "hidden_act": "swish", 12 | "hidden_ratio": 4, 13 | "hidden_size": 4096, 14 | "initializer_range": 0.02, 15 | "intermediate_size": 14336, 16 | "model_type": "gla", 17 | "num_heads": 32, 18 | "num_kv_heads": 8, 19 | "num_hidden_layers": 32, 20 | "norm_eps": 1e-05, 21 | "tie_word_embeddings": false, 22 | "transformers_version": "4.45.0", 23 | "use_cache": true, 24 | "use_output_gate": false, 25 | "use_gk": true, 26 | "use_gv": false, 27 | "vocab_size": 32000 28 | } -------------------------------------------------------------------------------- 
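These legacy JSON files are plain Hugging Face `PretrainedConfig` dumps; the legacy trainer (`legacy/training/run.py`, shown later in this listing) reads such a file via `--model_name_or_path` and, when `from_config` is enabled, initializes all weights from scratch. A rough sketch of the same flow, with an illustrative config path (importing `fla` is assumed to register custom `model_type` values such as `gla` with the `transformers` Auto* classes, mirroring the `import fla  # noqa` in run.py):

import fla  # noqa: registers fla configs/models with the transformers Auto* classes
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("legacy/training/configs/gla_340M.json")
model = AutoModelForCausalLM.from_config(config)  # random weights, ready for from-scratch training
print(f"{model.num_parameters():,} parameters")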
/legacy/training/configs/transformer_340M.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_bias": false, 3 | "bos_token_id": 1, 4 | "eos_token_id": 2, 5 | "fuse_cross_entropy": true, 6 | "fuse_norm": true, 7 | "hidden_act": "swish", 8 | "hidden_size": 1024, 9 | "initializer_range": 0.02, 10 | "max_position_embeddings": 8192, 11 | "model_type": "transformer", 12 | "num_heads": 16, 13 | "num_hidden_layers": 24, 14 | "norm_eps": 1e-06, 15 | "tie_word_embeddings": true, 16 | "use_cache": true, 17 | "vocab_size": 32000 18 | } -------------------------------------------------------------------------------- /legacy/training/flame/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchRWKV/flash-linear-attention/98b8f55bf23d43b024f4878aa7f8615bd040a453/legacy/training/flame/__init__.py -------------------------------------------------------------------------------- /legacy/training/flame/parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import transformers 9 | from transformers import HfArgumentParser, TrainingArguments 10 | 11 | from flame.logging import get_logger 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | @dataclass 17 | class TrainingArguments(TrainingArguments): 18 | 19 | model_name_or_path: str = field( 20 | default=None, 21 | metadata={ 22 | "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." 23 | }, 24 | ) 25 | tokenizer: str = field( 26 | default="fla-hub/gla-1.3B-100B", 27 | metadata={"help": "Name of the tokenizer to use."} 28 | ) 29 | use_fast_tokenizer: bool = field( 30 | default=False, 31 | metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, 32 | ) 33 | from_config: bool = field( 34 | default=True, 35 | metadata={"help": "Whether to initialize models from scratch."}, 36 | ) 37 | dataset: Optional[str] = field( 38 | default=None, 39 | metadata={"help": "The dataset(s) to use. 
Use commas to separate multiple datasets."}, 40 | ) 41 | dataset_name: Optional[str] = field( 42 | default=None, 43 | metadata={"help": "The name of provided dataset(s) to use."}, 44 | ) 45 | cache_dir: str = field( 46 | default=None, 47 | metadata={"help": "Path to the cached tokenized dataset."}, 48 | ) 49 | split: str = field( 50 | default="train", 51 | metadata={"help": "Which dataset split to use for training and evaluation."}, 52 | ) 53 | streaming: bool = field( 54 | default=False, 55 | metadata={"help": "Enable dataset streaming."}, 56 | ) 57 | hf_hub_token: Optional[str] = field( 58 | default=None, 59 | metadata={"help": "Auth token to log in with Hugging Face Hub."}, 60 | ) 61 | preprocessing_num_workers: Optional[int] = field( 62 | default=None, 63 | metadata={"help": "The number of processes to use for the pre-processing."}, 64 | ) 65 | buffer_size: int = field( 66 | default=2048, 67 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, 68 | ) 69 | context_length: int = field( 70 | default=2048, 71 | metadata={"help": "The context length of the tokenized inputs in the dataset."}, 72 | ) 73 | varlen: bool = field( 74 | default=False, 75 | metadata={"help": "Enable training with variable length inputs."}, 76 | ) 77 | 78 | 79 | def get_train_args(): 80 | parser = HfArgumentParser(TrainingArguments) 81 | args, unknown_args = parser.parse_args_into_dataclasses(return_remaining_strings=True) 82 | 83 | if unknown_args: 84 | print(parser.format_help()) 85 | print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) 86 | raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) 87 | 88 | if args.should_log: 89 | transformers.utils.logging.set_verbosity(args.get_process_log_level()) 90 | transformers.utils.logging.enable_default_handler() 91 | transformers.utils.logging.enable_explicit_format() 92 | # set seeds manually 93 | transformers.set_seed(args.seed) 94 | return args 95 | -------------------------------------------------------------------------------- /legacy/training/run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from datasets import load_from_disk 4 | from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, 5 | Trainer) 6 | 7 | import fla # noqa 8 | from flame.data import DataCollatorForLanguageModeling 9 | from flame.logging import LogCallback, get_logger 10 | from flame.parser import get_train_args 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def main(): 16 | args = get_train_args() 17 | logger.info(args) 18 | 19 | tokenizer = AutoTokenizer.from_pretrained( 20 | args.tokenizer, 21 | use_fast=args.use_fast_tokenizer, 22 | trust_remote_code=True, 23 | add_bos_token=True, 24 | add_eos_token=False 25 | ) 26 | if tokenizer.pad_token_id is None: 27 | tokenizer.pad_token = tokenizer.eos_token 28 | logger.info("Add pad token: {}".format(tokenizer.pad_token)) 29 | if args.from_config: 30 | logger.info("All model params are randomly initialized for from-scratch training.") 31 | model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path)) 32 | else: 33 | logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}") 34 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) 35 | model.train() 36 | 37 | trainable_params, all_param = model.num_parameters(only_trainable=True), model.num_parameters() 38 | 
logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}") 39 | logger.info(f"{tokenizer}\n{model}\n{model.config}") 40 | 41 | logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...") 42 | dataset = load_from_disk(args.cache_dir) 43 | logger.info(f"{dataset}") 44 | logger.info(f"Shuffling the dataset with seed {args.seed}") 45 | dataset = dataset.shuffle(seed=args.seed) 46 | logger.info("Creating the data collator") 47 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen) 48 | logger.info(f"{data_collator}") 49 | 50 | if args.lr_scheduler_type == 'cosine_with_min_lr': 51 | args.lr_scheduler_kwargs = {'min_lr_rate': 0.1} 52 | if args.lr_scheduler_type == 'warmup_stable_decay': 53 | args.lr_scheduler_kwargs = { 54 | 'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps, 55 | 'num_decay_steps': args.max_steps * 0.1 56 | } 57 | 58 | trainer = Trainer( 59 | model=model, 60 | args=args, 61 | processing_class=tokenizer, 62 | data_collator=data_collator, 63 | callbacks=[LogCallback()], 64 | train_dataset=dataset 65 | ) 66 | 67 | results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 68 | trainer.save_model() 69 | tokenizer.save_pretrained(trainer.args.output_dir) 70 | 71 | trainer.log_metrics("train", results.metrics) 72 | trainer.save_metrics("train", results.metrics) 73 | trainer.save_state() 74 | 75 | 76 | if __name__ == "__main__": 77 | main() 78 | -------------------------------------------------------------------------------- /tests/modules/test_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss 9 | from fla.utils import assert_close, device, device_platform 10 | 11 | 12 | @pytest.mark.parametrize("B", [2]) 13 | @pytest.mark.parametrize("T", [512, 1024]) 14 | @pytest.mark.parametrize("D", [1024, 2048]) 15 | @pytest.mark.parametrize("V", [32000, 100000]) 16 | @pytest.mark.parametrize("reduction", ['mean']) 17 | @pytest.mark.parametrize("dtype", [torch.bfloat16]) 18 | @pytest.mark.skipif( 19 | device_platform == 'intel', 20 | reason="Intel Triton Failure" 21 | ) 22 | def test_fused_cross_entropy(B: int, T: int, D: int, V: int, reduction: str, dtype: torch.dtype): 23 | torch.manual_seed(42) 24 | logits = torch.randn(B * T, V).to(device).to(dtype=dtype).requires_grad_() 25 | target = torch.randint(0, V, (B, T,)).to(device) 26 | target = torch.cat((target[..., 1:], torch.full_like(target[..., :1], -100)), -1) 27 | target = target.flatten() 28 | 29 | ref = nn.CrossEntropyLoss(reduction=reduction)(logits, target).to(dtype=dtype) 30 | do = torch.randn_like(ref).to(device).to(dtype=dtype) 31 | 32 | ref.backward(do) 33 | ref_d, logits.grad = logits.grad.clone(), None 34 | 35 | tri = FusedCrossEntropyLoss(reduction=reduction)(logits, target).to(dtype=dtype) 36 | tri.backward(do) 37 | tri_d, logits.grad = logits.grad.clone(), None 38 | 39 | assert_close(" o", ref, tri, ratio=1e-2) 40 | assert_close("dl", ref_d, tri_d, ratio=1e-2) 41 | 42 | 43 | @pytest.mark.parametrize("B", [2]) 44 | @pytest.mark.parametrize("T", [512, 1024]) 45 | @pytest.mark.parametrize("D", [1024, 2048]) 46 | @pytest.mark.parametrize("V", [32000, 100000]) 47 | @pytest.mark.parametrize("scale", [1., 0.5]) 48 | 
@pytest.mark.parametrize("reduction", ['mean']) 49 | @pytest.mark.parametrize("dtype", [torch.bfloat16]) 50 | @pytest.mark.skipif( 51 | device_platform == 'intel', 52 | reason="Intel Triton Failure" 53 | ) 54 | def test_fused_linear_cross_entropy(B: int, T: int, D: int, V: int, scale: float, reduction: str, dtype: torch.dtype): 55 | torch.manual_seed(42) 56 | 57 | x = torch.randn(B * T, D).to(device).to(dtype=dtype).requires_grad_() 58 | target = torch.randint(0, V, (B, T,)).to(device) 59 | target = torch.cat((target[..., 1:], torch.full_like(target[..., :1], -100)), -1) 60 | target = target.flatten() 61 | weight = torch.randn(V, D).to(device).to(dtype=dtype).requires_grad_() 62 | bias = torch.randn(V).to(device).to(dtype=dtype).requires_grad_() 63 | 64 | logits = F.linear(x, weight, bias) 65 | ref = FusedCrossEntropyLoss(logit_scale=scale, reduction=reduction)(logits, target) 66 | do = torch.randn_like(ref).to(device).to(dtype=dtype) 67 | 68 | ref.backward(do) 69 | ref_dx, x.grad = x.grad.clone(), None 70 | ref_dw, weight.grad = weight.grad.clone(), None 71 | ref_db, bias.grad = bias.grad.clone(), None 72 | 73 | tri = FusedLinearCrossEntropyLoss(logit_scale=scale, reduction=reduction)(x, target, weight, bias) 74 | tri.backward(do) 75 | tri_dx, x.grad = x.grad.clone(), None 76 | tri_dw, weight.grad = weight.grad.clone(), None 77 | tri_db, bias.grad = bias.grad.clone(), None 78 | 79 | assert_close(" o", ref, tri, ratio=1e-2) 80 | assert_close("dx", ref_dx, tri_dx, ratio=1e-2) 81 | assert_close("dw", ref_dw, tri_dw, ratio=1e-2) 82 | assert_close("db", ref_db, tri_db, ratio=1e-2) 83 | -------------------------------------------------------------------------------- /tests/modules/test_grpo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pytest 4 | import torch 5 | 6 | from fla.modules.grpo import fused_grpo_loss, grpo_loss_torch 7 | from fla.utils import assert_close, device, device_torch_lib, is_nvidia_hopper 8 | 9 | 10 | @pytest.mark.parametrize("B", [2]) 11 | @pytest.mark.parametrize("T", [16, 1024, 4096]) 12 | @pytest.mark.parametrize("V", [32000, 65536, 131072]) 13 | @pytest.mark.parametrize("dtype", [torch.bfloat16]) 14 | @pytest.mark.parametrize("inplace", [True, False]) 15 | @pytest.mark.parametrize("repeat", [100]) 16 | def test_fused_grpos(B: int, T: int, V: int, dtype: torch.dtype, inplace: bool, repeat: int): 17 | device_torch_lib.manual_seed(42) 18 | for i in range(repeat): 19 | if not is_nvidia_hopper and T == 4096: 20 | pytest.skip("Skip test for T=4096 on Intel Alchemist") 21 | 22 | def get_random_ref_log_probs(logits, input_ids): 23 | with torch.inference_mode(): 24 | logits = logits[:, :-1] 25 | per_token_logps = [] 26 | for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]): 27 | log_probs = torch.randn_like(logits_row).log_softmax(dim=-1) 28 | token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) 29 | per_token_logps.append(token_log_prob) 30 | device_torch_lib.empty_cache() 31 | return torch.stack(per_token_logps) 32 | 33 | logits = torch.randn(B, T + 1, V, device=device, dtype=dtype) 34 | logits.requires_grad_(True) 35 | advantages = torch.randn(B, device=device, dtype=torch.float32) 36 | input_ids = torch.randint(0, V-1, (B, T + 64), device=device) 37 | ref_logp = get_random_ref_log_probs(logits, input_ids) 38 | beta = 0.04 39 | completion_mask = torch.ones(B, T, dtype=torch.int32, device=device) 40 | completion_mask[::2, T//3: 
T//2] = 0 41 | save_kl = True 42 | 43 | gold_logits = logits.detach().clone().float() 44 | gold_logits.requires_grad_(True) 45 | gold_ref_logp = ref_logp.clone().float() 46 | device_torch_lib.empty_cache() 47 | y1 = fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl=save_kl, inplace=inplace) 48 | y2 = grpo_loss_torch(gold_logits, gold_ref_logp, input_ids, advantages, beta, completion_mask, save_kl) 49 | if save_kl: 50 | y1, kl2 = y1 51 | y2, kl3 = y2 52 | assert (kl2-kl3).abs().max() < 1e-3 53 | dy = torch.randn_like(y1) * 10 54 | y1.backward(dy) 55 | y2.backward(dy.float()) 56 | assert (y1-y2).abs().max() < 1e-3 57 | assert_close(" dlogits", gold_logits.grad, logits.grad, 3e-3) 58 | -------------------------------------------------------------------------------- /tests/modules/test_kl_div.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pytest 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from fla.modules import FusedKLDivLoss 8 | from fla.utils import assert_close, device, device_platform 9 | 10 | 11 | @pytest.mark.parametrize("B", [2]) 12 | @pytest.mark.parametrize("T", [16, 32]) 13 | @pytest.mark.parametrize("D", [1024, 2048]) 14 | @pytest.mark.parametrize("V", [32000, 100000]) 15 | @pytest.mark.parametrize("reduction", ["batchmean"]) 16 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) 17 | @pytest.mark.skipif( 18 | device_platform == 'intel', 19 | reason="Intel Triton Failure" 20 | ) 21 | def test_fused(B: int, T: int, D: int, V: int, reduction: str, dtype: torch.dtype): 22 | torch.manual_seed(42) 23 | x = torch.randn(B * T, D).to(device).to(dtype=dtype).requires_grad_() 24 | x_weight = torch.randn(V, D).to(device).to(dtype=dtype).requires_grad_() 25 | target_x = torch.randn(B * T, D).to(device).to(dtype=dtype) 26 | target_weight = torch.randn(V, D).to(device).to(dtype=dtype) 27 | 28 | ref = F.kl_div( 29 | F.linear(x, x_weight).log_softmax(-1), 30 | F.linear(target_x, target_weight).softmax(-1), 31 | reduction=reduction 32 | ).to(dtype) 33 | do = torch.randn_like(ref).to(device) 34 | ref.backward(do) 35 | ref_dx, x.grad = x.grad.clone(), None 36 | ref_dw, x_weight.grad = x_weight.grad.clone(), None 37 | 38 | tri = FusedKLDivLoss(reduction)(x, target_x, x_weight, target_weight).to(dtype=dtype) 39 | tri.backward(do) 40 | tri_dx, x.grad = x.grad.clone(), None 41 | tri_dw, x_weight.grad = x_weight.grad.clone(), None 42 | 43 | assert_close(" o", ref, tri, 1e-2) 44 | assert_close(" dx", ref_dx, tri_dx, 1e-2) 45 | assert_close(" dw", ref_dw, tri_dw, 1e-2) 46 | -------------------------------------------------------------------------------- /tests/modules/test_l2norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pytest 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from fla.modules.l2norm import l2_norm 8 | from fla.utils import assert_close, device 9 | 10 | 11 | @pytest.mark.parametrize("B", [2]) 12 | @pytest.mark.parametrize("T", [512]) 13 | @pytest.mark.parametrize("H", [2]) 14 | @pytest.mark.parametrize("D", [50, 64, 128, 1000, 2048]) 15 | def test_l2norm(B: int, H: int, T: int, D: int): 16 | torch.manual_seed(42) 17 | x = torch.randn(B, T, H, D).to(device).requires_grad_(True) 18 | x = x * 0.5 + 0.3 19 | 20 | ref_y = F.normalize(x, dim=-1, p=2) 21 | tri_y = l2_norm(x) 22 | ref_dx = torch.autograd.grad(ref_y.sum(), x)[0] 23 | tri_dx = 
torch.autograd.grad(tri_y.sum(), x)[0] 24 | 25 | assert_close(' y', ref_y, tri_y, ratio=1e-3) 26 | assert_close('dx', ref_dx, tri_dx, ratio=1e-3) 27 | -------------------------------------------------------------------------------- /tests/modules/test_token_shift.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from fla.modules.token_shift import token_shift, token_shift_ref 5 | from fla.utils import assert_close, device 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "test_case,batch_size,seq_len,hidden_size,cu_seqlens,dtype", 10 | [ 11 | ("fixed_length_standard", 8, 128, 128, None, torch.float32), 12 | ("fixed_length_different_dims", 4, 256, 64, None, torch.float32), 13 | ("var_length_standard", 1, 128, 128, [0, 4, 7, 40, 128], torch.float32), 14 | ("var_length_fewer_seqs", 1, 64, 64, [0, 10, 20, 64], torch.float32), 15 | ("var_length_single_seq", 1, 32, 32, [0, 32], torch.float32), 16 | ("edge_case_len_1", 1, 4, 64, [0, 1, 3, 4], torch.float32), 17 | ("dtype_float16", 2, 32, 64, None, torch.float16), 18 | ("dtype_bfloat16", 2, 32, 64, None, torch.bfloat16) 19 | ] 20 | ) 21 | def test_token_shift(test_case, batch_size, seq_len, hidden_size, cu_seqlens, dtype): 22 | """Comprehensive test for token shift operation""" 23 | 24 | # Set random seed for reproducibility 25 | torch.manual_seed(42) 26 | 27 | # Create test tensors 28 | x = torch.randn(batch_size, seq_len, hidden_size, device=device).to(dtype).requires_grad_(True) 29 | dy = torch.randn_like(x) 30 | 31 | cu_seqlens_tensor = None 32 | if cu_seqlens is not None: 33 | cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) 34 | 35 | # Forward pass 36 | ref = token_shift_ref(x, cu_seqlens_tensor) 37 | tri = token_shift(x, cu_seqlens_tensor) 38 | 39 | ref.backward(dy) 40 | ref_dx, x.grad = x.grad, None 41 | 42 | tri.backward(dy) 43 | tri_dx, x.grad = x.grad, None 44 | 45 | assert_close(' x', ref, tri, 1e-3) 46 | assert_close('dx', ref_dx, tri_dx, 1e-3) 47 | -------------------------------------------------------------------------------- /tests/ops/test_based.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import pytest 6 | import torch 7 | 8 | from fla.ops.based import fused_chunk_based, parallel_based 9 | from fla.ops.based.naive import naive_parallel_based 10 | from fla.utils import COMPILER_MODE, device 11 | 12 | if COMPILER_MODE: 13 | test_b_list = [1] 14 | test_t_list = [4096] 15 | test_d_list = [64, 128, 256] 16 | else: 17 | test_b_list = [2] 18 | test_t_list = [1, 15, 63, 300] 19 | test_d_list = [64, 32, 100, 256] 20 | test_h_list = [2] 21 | 22 | 23 | @pytest.mark.parametrize('B', test_b_list) 24 | @pytest.mark.parametrize('H', test_h_list) 25 | @pytest.mark.parametrize('T', test_t_list) 26 | @pytest.mark.parametrize('D', test_d_list) 27 | @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float32]) 28 | @pytest.mark.skipif( 29 | os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', 30 | reason='Skipping test because TEST_CHUNK_VARLEN is enabled' 31 | ) 32 | def test_based( 33 | B: int, 34 | H: int, 35 | T: int, 36 | D: int, 37 | dtype: torch.dtype 38 | ): 39 | torch.manual_seed(42) 40 | q = torch.randn((B, H, T, 16), dtype=dtype, device=device).requires_grad_() 41 | k = torch.randn((B, H, T, 16), dtype=dtype, device=device).requires_grad_() 42 | v = torch.randn((B, H, T, D), dtype=dtype, device=device).requires_grad_() 43 | do = torch.randn_like(v) 44 
| ref = naive_parallel_based(q, k, v, use_norm=True) 45 | ref.backward(do) 46 | ref_dq, q.grad = q.grad.clone(), None 47 | ref_dk, k.grad = k.grad.clone(), None 48 | ref_dv, v.grad = v.grad.clone(), None 49 | 50 | tri = parallel_based(q, k, v, use_norm=True) 51 | tri.backward(do) 52 | tri_dq, q.grad = q.grad.clone(), None 53 | tri_dk, k.grad = k.grad.clone(), None 54 | tri_dv, v.grad = v.grad.clone(), None 55 | 56 | if dtype == torch.float32: 57 | assert ref.allclose(tri, 0, 1e-4) 58 | assert ref_dq.allclose(tri_dq, 0, 1e-4) 59 | assert ref_dk.allclose(tri_dk, 0, 1e-4) 60 | assert ref_dv.allclose(tri_dv, 0, 1e-4) 61 | 62 | tri = fused_chunk_based(q, k, v, use_norm=True) 63 | tri.backward(do) 64 | tri_dq, q.grad = q.grad.clone(), None 65 | tri_dk, k.grad = k.grad.clone(), None 66 | tri_dv, v.grad = v.grad.clone(), None 67 | 68 | if dtype == torch.float32: 69 | assert ref.allclose(tri, 0, 1e-4) 70 | assert ref_dq.allclose(tri_dq, 0, 1e-4) 71 | assert ref_dk.allclose(tri_dk, 0, 1e-4) 72 | assert ref_dv.allclose(tri_dv, 0, 1e-4) 73 | -------------------------------------------------------------------------------- /tests/ops/test_solve_tril.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import pytest 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd 10 | from fla.ops.utils.solve_tril import solve_tril 11 | from fla.utils import COMPILER_MODE, assert_close, device, device_platform 12 | 13 | if COMPILER_MODE: 14 | test_b_list = [1] 15 | test_t_list = [4096] 16 | test_t_varlen_list = [[0, 64, 128, 256, 512]] 17 | else: 18 | test_b_list = [2] 19 | test_t_list = [128, 200, 300, 500] 20 | test_t_varlen_list = [[0, 63, 286, 300, 512], [0, 127, 246, 521, 1000], [0, 255, 492, 1042, 2000]] 21 | test_h_list = [2] 22 | 23 | 24 | @pytest.mark.parametrize('B', test_b_list) 25 | @pytest.mark.parametrize('T', test_t_list) 26 | @pytest.mark.parametrize('H', test_h_list) 27 | @pytest.mark.parametrize('chunk_size', [16, 32, 64]) 28 | @pytest.mark.skipif( 29 | os.getenv('SKIP_TEST_CHUNK_VARLEN') == '0', 30 | reason='Skipping test because TEST_CHUNK_VARLEN is enabled' 31 | ) 32 | @pytest.mark.skipif( 33 | device_platform == 'intel', 34 | reason='Intel Pytorch Failure' 35 | ) 36 | def test_solve_tril(B, T, H, chunk_size): 37 | # do not randomly intiialize A otherwise the inverse is not stable 38 | k = F.normalize(torch.randn((B, H, T, 64), dtype=torch.float32, device=device), dim=-1) 39 | # Pad the second-to-last dimension (T) to be a multiple of chunk_size 40 | padding_size = (chunk_size - T % chunk_size) % chunk_size 41 | k_padded = F.pad(k, (0, 0, 0, padding_size, 0, 0, 0, 0)) 42 | k_padded = k_padded.reshape(B, H, -1, chunk_size, 64) 43 | A = (k_padded @ k_padded.transpose(-1, -2)).tril(-1) 44 | Ai = solve_tril(A.reshape(B, H, -1, chunk_size)[:, :, :T, :].transpose(1, 2)).transpose(1, 2) 45 | 46 | Ai_ref = torch.inverse(A + torch.eye(A.shape[-1], device=A.device)[None, None, None, ...]) 47 | Ai_ref = Ai_ref.reshape(B, H, -1, chunk_size)[:, :, :T, :] 48 | assert_close('solve_tril', Ai, Ai_ref, 0.0001) 49 | 50 | 51 | @pytest.mark.parametrize('H', test_h_list) 52 | @pytest.mark.parametrize('cu_seqlens', test_t_varlen_list) 53 | @pytest.mark.parametrize('chunk_size', [64, 32, 16]) 54 | @pytest.mark.skipif( 55 | os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', 56 | reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set' 57 | ) 58 | 
@pytest.mark.skipif( 59 | device_platform == 'intel', 60 | reason='Intel Pytorch Failure' 61 | ) 62 | def test_solve_tril_varlen(H, cu_seqlens, chunk_size): 63 | T = cu_seqlens[-1] 64 | cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) 65 | # Construct the input. otherwise inverse's condition number might be too large to measure the error 66 | k = F.normalize(torch.randn((1, T, H, 64), dtype=torch.bfloat16, device=device), dim=-1) 67 | beta = torch.randn((1, T, H), dtype=torch.bfloat16, device=device).sigmoid() 68 | A, _ = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size) 69 | Ai = solve_tril(A, cu_seqlens=cu_seqlens) 70 | 71 | Ai_ref = torch.zeros_like(Ai) 72 | for i in range(len(cu_seqlens) - 1): 73 | for j in range(cu_seqlens[i], cu_seqlens[i+1], chunk_size): 74 | actual_size = min(chunk_size, cu_seqlens[i+1] - j) 75 | Ai_ref[:, j:j+actual_size, :, :actual_size] = torch.inverse( 76 | A[:, j:j+actual_size, :, :actual_size].transpose(1, 2) + 77 | torch.eye(actual_size, device=A.device, dtype=A.dtype)[None, None, ...] 78 | ).transpose(1, 2) 79 | assert_close('solve_tril_varlen', Ai, Ai_ref, 0.0001) 80 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pytest 4 | import torch 5 | from transformers import AutoConfig, AutoModelForCausalLM 6 | 7 | from fla.models import ( 8 | ABCConfig, 9 | BitNetConfig, 10 | DeltaNetConfig, 11 | ForgettingTransformerConfig, 12 | GatedDeltaNetConfig, 13 | GatedDeltaProductConfig, 14 | GLAConfig, 15 | GSAConfig, 16 | HGRN2Config, 17 | HGRNConfig, 18 | LightNetConfig, 19 | LinearAttentionConfig, 20 | Mamba2Config, 21 | MambaConfig, 22 | NSAConfig, 23 | RetNetConfig, 24 | RWKV6Config, 25 | RWKV7Config, 26 | SambaConfig, 27 | TransformerConfig 28 | ) 29 | from fla.utils import assert_close, device, is_nvidia_hopper 30 | 31 | 32 | @pytest.mark.parametrize("L", [4]) 33 | @pytest.mark.parametrize("B", [8]) 34 | @pytest.mark.parametrize("T", [2048]) 35 | @pytest.mark.parametrize("H", [16]) 36 | @pytest.mark.parametrize("D", [128]) 37 | @pytest.mark.parametrize("config_class", [ 38 | ABCConfig, 39 | BitNetConfig, 40 | DeltaNetConfig, 41 | ForgettingTransformerConfig, 42 | GatedDeltaNetConfig, 43 | GatedDeltaProductConfig, 44 | GLAConfig, 45 | GSAConfig, 46 | HGRN2Config, 47 | HGRNConfig, 48 | LightNetConfig, 49 | LinearAttentionConfig, 50 | Mamba2Config, 51 | MambaConfig, 52 | NSAConfig, 53 | RetNetConfig, 54 | RWKV6Config, 55 | RWKV7Config, 56 | SambaConfig, 57 | TransformerConfig 58 | ]) 59 | @pytest.mark.parametrize("dtype", [torch.bfloat16]) 60 | @pytest.mark.skipif( 61 | is_nvidia_hopper is False, 62 | reason="Only run on Hopper GPUs" 63 | ) 64 | def test_model( 65 | L: int, 66 | B: int, 67 | T: int, 68 | H: int, 69 | D: int, 70 | config_class: AutoConfig, 71 | dtype: torch.dtype 72 | ): 73 | if config_class in [ 74 | ABCConfig, LinearAttentionConfig, LightNetConfig, 75 | Mamba2Config, MambaConfig, SambaConfig, GatedDeltaProductConfig 76 | ]: 77 | pytest.skip("Variable length not supported yet") 78 | config = config_class(**{ 79 | 'hidden_size': int(H * D), 80 | 'num_hidden_layers': L, 81 | **({'num_heads': H} if config_class != NSAConfig else {}) 82 | }) 83 | model = AutoModelForCausalLM.from_config(config) 84 | model.to(dtype).to(device) 85 | 86 | cu_seqlens = torch.cat([ 87 | torch.arange(0, B * T, T), 88 | torch.tensor([B * T], 
dtype=torch.long) 89 | ], 0).to(device).to(torch.int32) 90 | 91 | input_ids = torch.randint(low=0, high=config.vocab_size, size=(1, B * T)).to(device) 92 | output = model(input_ids.view(B, T), output_hidden_states=True).hidden_states[-1] 93 | assert output.shape == (B, T, config.hidden_size) 94 | 95 | output_var = model(input_ids, output_hidden_states=True, cu_seqlens=cu_seqlens).hidden_states[-1] 96 | assert output_var.shape == (1, B * T, config.hidden_size) 97 | assert_close('output', output.view(1, B * T, -1), output_var, ratio=1e-3) 98 | -------------------------------------------------------------------------------- /tests/utils/test_rwkv7_conversion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import json 5 | 6 | try: 7 | from lm_eval import evaluator 8 | from lm_eval.models.huggingface import HFLM 9 | except ImportError: 10 | evaluator = None 11 | HFLM = None 12 | from tokenizers import Tokenizer 13 | from transformers import PreTrainedTokenizerFast 14 | 15 | from fla.models.rwkv7 import RWKV7ForCausalLM 16 | 17 | 18 | if __name__ == '__main__': 19 | def test_rwkv7_lm_eval(model, tokenizer, task_names=["lambada_openai"]): 20 | tokenizer1 = PreTrainedTokenizerFast( 21 | tokenizer_object=tokenizer, 22 | eos_token="<|endoftext|>", 23 | pad_token="<|padding|>" 24 | ) 25 | hf_model = HFLM(pretrained=model, tokenizer=tokenizer1) 26 | results = evaluator.simple_evaluate( 27 | model=hf_model, 28 | tasks=task_names, 29 | batch_size=1, 30 | ) 31 | # { 32 | # "lambada_openai": { 33 | # "perplexity,none": 14.457888475382047, 34 | # "perplexity_stderr,none": 0.4455143803996477, 35 | # "acc,none": 0.4585678245682127, 36 | # "acc_stderr,none": 0.006942020515885241, 37 | # "alias": "lambada_openai" 38 | # } 39 | # } 40 | print(json.dumps(results['results'], indent=2)) 41 | 42 | # official results: 43 | # pile 168M: lambada_openai ppl 14.2 acc 45.6% 44 | # pile 421M: lambada_openai ppl 8.14 acc 55.6% 45 | # pile 1.47B: lambada_openai ppl 5.04 acc 64.9% 46 | parser = argparse.ArgumentParser(description='Convert RWKV7') 47 | parser.add_argument('model', type=str, help='path to model') 48 | parser.add_argument('tokenizer', type=str, help='path to tokenizer') 49 | parser.add_argument('--tasks', type=str, nargs='*', 50 | default=['lambada_openai']) 51 | args = parser.parse_args() 52 | 53 | model = RWKV7ForCausalLM.from_pretrained( 54 | args.model, 55 | torch_dtype="auto", 56 | device_map="cuda", 57 | ).half().eval() 58 | tokenizer = Tokenizer.from_file(args.tokenizer) 59 | 60 | test_rwkv7_lm_eval(model, tokenizer, task_names=["lambada_openai"]) 61 | --------------------------------------------------------------------------------
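Unlike the other tests, `test_rwkv7_conversion.py` is meant to be run directly rather than through pytest, and it requires `lm-eval` to be installed (its imports are wrapped in a try/except that otherwise leaves `evaluator` and `HFLM` as `None`). A typical invocation, with illustrative paths, would be:

python tests/utils/test_rwkv7_conversion.py /path/to/converted-rwkv7 /path/to/tokenizer.json --tasks lambada_openai

Note that although a `--tasks` argument is parsed, the script currently calls `test_rwkv7_lm_eval` with `task_names=["lambada_openai"]` hardcoded, so the reported numbers should be compared against the official `lambada_openai` references quoted in its comments (e.g. ppl 14.2 / acc 45.6% for the 168M Pile model).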