├── .gitignore ├── BENCHMARKS.md ├── README.md ├── assets ├── Bwd_Comparisson.png ├── Causal_Comparison.png ├── Fwd_Comparison.png └── Naive_Comparison.png ├── benchmark ├── benchmark_bwd_all.py ├── benchmark_bwd_causal.py ├── benchmark_bwd_naive.py ├── benchmark_fwd_all.py ├── benchmark_fwd_causal.py └── benchmark_fwd_naive.py ├── requirements.txt ├── setup.py ├── src └── ssd │ ├── __init__.py │ ├── bi │ ├── __init__.py │ ├── k_activations.py │ ├── layer_norm.py │ ├── layernorm_gated.py │ ├── selective_state_update.py │ ├── softplus.py │ ├── ssd_bmm.py │ ├── ssd_chunk_scan.py │ ├── ssd_chunk_state.py │ ├── ssd_combined.py │ └── ssd_state_passing.py │ ├── modules.py │ ├── uni │ ├── __init__.py │ ├── k_activations.py │ ├── layer_norm.py │ ├── layernorm_gated.py │ ├── selective_state_update.py │ ├── softplus.py │ ├── ssd_bmm.py │ ├── ssd_chunk_scan.py │ ├── ssd_chunk_state.py │ ├── ssd_combined.py │ └── ssd_state_passing.py │ └── version.py └── tests ├── conftest.py ├── fixtures ├── bwd.py ├── chunk_scan_bwd_dc.py ├── chunk_scan_bwd_dcb.py ├── chunk_scan_bwd_ddAcs_stable_bwd.py ├── chunk_scan_bwd_dstates.py ├── chunk_scan_chunk_state_bwd_dx.py ├── chunk_state_bwd_db.py ├── chunk_state_fwd.py ├── cumsum_fwd.py ├── fwd.py ├── state_passing_bwd.py ├── state_passing_fwd.py └── state_scan_fwd.py ├── test_bwd_chunk_scan.py ├── test_bwd_chunk_state.py ├── test_bwd_scan.py ├── test_bwd_state_passing.py ├── test_fwd_chunk_cumsum.py ├── test_fwd_chunk_state.py ├── test_fwd_scan.py ├── test_fwd_state_passing.py └── test_fwd_state_scan.py /.gitignore: -------------------------------------------------------------------------------- 1 | src/ssd.egg-info/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /BENCHMARKS.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/BENCHMARKS.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/README.md -------------------------------------------------------------------------------- /assets/Bwd_Comparisson.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/assets/Bwd_Comparisson.png -------------------------------------------------------------------------------- /assets/Causal_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/assets/Causal_Comparison.png -------------------------------------------------------------------------------- /assets/Fwd_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/assets/Fwd_Comparison.png -------------------------------------------------------------------------------- /assets/Naive_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/assets/Naive_Comparison.png -------------------------------------------------------------------------------- /benchmark/benchmark_bwd_all.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/benchmark/benchmark_bwd_all.py -------------------------------------------------------------------------------- /benchmark/benchmark_bwd_causal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/benchmark/benchmark_bwd_causal.py -------------------------------------------------------------------------------- /benchmark/benchmark_bwd_naive.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/benchmark/benchmark_bwd_naive.py -------------------------------------------------------------------------------- /benchmark/benchmark_fwd_all.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/benchmark/benchmark_fwd_all.py -------------------------------------------------------------------------------- /benchmark/benchmark_fwd_causal.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/benchmark/benchmark_fwd_causal.py -------------------------------------------------------------------------------- /benchmark/benchmark_fwd_naive.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/benchmark/benchmark_fwd_naive.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | triton 2 | torch 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/setup.py -------------------------------------------------------------------------------- /src/ssd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/__init__.py -------------------------------------------------------------------------------- /src/ssd/bi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ssd/bi/k_activations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/k_activations.py -------------------------------------------------------------------------------- /src/ssd/bi/layer_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/layer_norm.py -------------------------------------------------------------------------------- /src/ssd/bi/layernorm_gated.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/layernorm_gated.py -------------------------------------------------------------------------------- /src/ssd/bi/selective_state_update.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/selective_state_update.py -------------------------------------------------------------------------------- /src/ssd/bi/softplus.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/softplus.py -------------------------------------------------------------------------------- /src/ssd/bi/ssd_bmm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/ssd_bmm.py -------------------------------------------------------------------------------- /src/ssd/bi/ssd_chunk_scan.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/ssd_chunk_scan.py -------------------------------------------------------------------------------- /src/ssd/bi/ssd_chunk_state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/ssd_chunk_state.py -------------------------------------------------------------------------------- /src/ssd/bi/ssd_combined.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/ssd_combined.py -------------------------------------------------------------------------------- /src/ssd/bi/ssd_state_passing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/bi/ssd_state_passing.py -------------------------------------------------------------------------------- /src/ssd/modules.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/modules.py -------------------------------------------------------------------------------- /src/ssd/uni/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/ssd/uni/k_activations.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/k_activations.py -------------------------------------------------------------------------------- /src/ssd/uni/layer_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/layer_norm.py -------------------------------------------------------------------------------- /src/ssd/uni/layernorm_gated.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/layernorm_gated.py -------------------------------------------------------------------------------- /src/ssd/uni/selective_state_update.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/selective_state_update.py -------------------------------------------------------------------------------- /src/ssd/uni/softplus.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/softplus.py -------------------------------------------------------------------------------- /src/ssd/uni/ssd_bmm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/ssd_bmm.py -------------------------------------------------------------------------------- /src/ssd/uni/ssd_chunk_scan.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/ssd_chunk_scan.py -------------------------------------------------------------------------------- /src/ssd/uni/ssd_chunk_state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/ssd_chunk_state.py -------------------------------------------------------------------------------- /src/ssd/uni/ssd_combined.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/ssd_combined.py -------------------------------------------------------------------------------- /src/ssd/uni/ssd_state_passing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/src/ssd/uni/ssd_state_passing.py -------------------------------------------------------------------------------- /src/ssd/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/conftest.py -------------------------------------------------------------------------------- /tests/fixtures/bwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/bwd.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_scan_bwd_dc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_scan_bwd_dc.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_scan_bwd_dcb.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_scan_bwd_dcb.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_scan_bwd_ddAcs_stable_bwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_scan_bwd_ddAcs_stable_bwd.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_scan_bwd_dstates.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_scan_bwd_dstates.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_scan_chunk_state_bwd_dx.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_scan_chunk_state_bwd_dx.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_state_bwd_db.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_state_bwd_db.py -------------------------------------------------------------------------------- /tests/fixtures/chunk_state_fwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/chunk_state_fwd.py -------------------------------------------------------------------------------- /tests/fixtures/cumsum_fwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/cumsum_fwd.py -------------------------------------------------------------------------------- /tests/fixtures/fwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/fwd.py -------------------------------------------------------------------------------- /tests/fixtures/state_passing_bwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/state_passing_bwd.py -------------------------------------------------------------------------------- /tests/fixtures/state_passing_fwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/state_passing_fwd.py -------------------------------------------------------------------------------- /tests/fixtures/state_scan_fwd.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/fixtures/state_scan_fwd.py -------------------------------------------------------------------------------- /tests/test_bwd_chunk_scan.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_bwd_chunk_scan.py -------------------------------------------------------------------------------- /tests/test_bwd_chunk_state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_bwd_chunk_state.py -------------------------------------------------------------------------------- /tests/test_bwd_scan.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_bwd_scan.py -------------------------------------------------------------------------------- /tests/test_bwd_state_passing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_bwd_state_passing.py -------------------------------------------------------------------------------- /tests/test_fwd_chunk_cumsum.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_fwd_chunk_cumsum.py -------------------------------------------------------------------------------- /tests/test_fwd_chunk_state.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_fwd_chunk_state.py -------------------------------------------------------------------------------- /tests/test_fwd_scan.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_fwd_scan.py -------------------------------------------------------------------------------- /tests/test_fwd_state_passing.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hprairie/Bi-Mamba2/HEAD/tests/test_fwd_state_passing.py -------------------------------------------------------------------------------- /tests/test_fwd_state_scan.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------