├── mamba └── mamba-main │ ├── mamba_ssm │ ├── ops │ │ ├── __init__.py │ │ ├── triton │ │ │ ├── __init__.py │ │ │ └── selective_state_update.py │ │ └── selective_scan_interface.py │ ├── models │ │ ├── __init__.py │ │ ├── config_mamba.py │ │ └── mixer_seq_simple.py │ ├── modules │ │ ├── __init__.py │ │ └── mamba_simple.py │ ├── utils │ │ ├── __init__.py │ │ ├── hf.py │ │ └── generation.py │ └── __init__.py │ ├── .gitignore │ ├── AUTHORS │ ├── assets │ └── selection.png │ ├── .gitmodules │ ├── csrc │ └── selective_scan │ │ ├── selective_scan_bwd_fp16_real.cu │ │ ├── selective_scan_bwd_fp32_real.cu │ │ ├── selective_scan_bwd_bf16_real.cu │ │ ├── selective_scan_bwd_fp32_complex.cu │ │ ├── selective_scan_bwd_bf16_complex.cu │ │ ├── selective_scan_bwd_fp16_complex.cu │ │ ├── selective_scan_fwd_fp32.cu │ │ ├── selective_scan_fwd_fp16.cu │ │ ├── selective_scan_fwd_bf16.cu │ │ ├── static_switch.h │ │ ├── uninitialized_copy.cuh │ │ ├── selective_scan.h │ │ └── selective_scan_common.h │ ├── evals │ └── lm_harness_eval.py │ ├── tests │ └── ops │ │ ├── triton │ │ └── test_selective_state_update.py │ │ └── test_selective_scan.py │ ├── benchmarks │ └── benchmark_generation_mamba_simple.py │ ├── README.md │ ├── .github │ └── workflows │ │ └── publish.yaml │ ├── setup.py │ └── LICENSE ├── causal-conv1d └── causal-conv1d-main │ ├── AUTHORS │ ├── README.md │ ├── causal_conv1d │ ├── __init__.py │ └── causal_conv1d_interface.py │ ├── README copy.md │ ├── csrc │ ├── static_switch.h │ ├── causal_conv1d.h │ ├── causal_conv1d_common.h │ ├── causal_conv1d_update.cu │ └── causal_conv1d.cpp │ ├── LICENSE │ ├── setup.py │ └── tests │ └── test_causal_conv1d.py └── README.md /mamba/mamba-main/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | -------------------------------------------------------------------------------- /mamba/mamba-main/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | *.egg-info/ 3 | build/ 4 | **.so 5 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/README.md: -------------------------------------------------------------------------------- 1 | # causal-conv1d 2 | Add support for ROCm 3 | -------------------------------------------------------------------------------- /mamba/mamba-main/AUTHORS: 
-------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /mamba/mamba-main/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bigyu-777/env-mamba/HEAD/mamba/mamba-main/assets/selection.png -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.3.post1" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /mamba/mamba-main/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/lm-evaluation-harness"] 2 | path = 3rdparty/lm-evaluation-harness 3 | url = https://github.com/EleutherAI/lm-evaluation-harness/ 4 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.2.0.post1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/models/config_mamba.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MambaConfig: 6 | 7 | d_model: int = 2560 8 | n_layer: int = 64 9 | vocab_size: int = 50277 10 | ssm_cfg: dict = field(default_factory=dict) 11 | rms_norm: bool = True 12 | residual_in_fp32: bool = True 13 | fused_add_norm: bool = True 14 | pad_vocab_size_multiple: int = 8 15 | tie_embeddings: bool = True 16 | -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_fwd_bf16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/README copy.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | 3 | Features: 4 | - Support fp32, fp16, bf16. 5 | - Kernel size 2, 3, 4. 
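For orientation, a minimal call of the interface documented in the next section could look like the sketch below (an illustrative example, not part of the original README; it assumes the package is installed with its CUDA extension and a GPU is available):

```
import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# depthwise causal conv over the sequence dimension, fused with SiLU
out = causal_conv1d_fn(x, weight, bias, activation="silu")
assert out.shape == (batch, dim, seqlen)
```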
6 | 7 | ## How to use 8 | 9 | ``` 10 | from causal_conv1d import causal_conv1d_fn 11 | ``` 12 | 13 | ``` 14 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 15 | """ 16 | x: (batch, dim, seqlen) 17 | weight: (dim, width) 18 | bias: (dim,) 19 | activation: either None or "silu" or "swish" 20 | 21 | out: (batch, dim, seqlen) 22 | """ 23 | ``` 24 | 25 | Equivalent to: 26 | ``` 27 | import torch.nn.functional as F 28 | 29 | F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen] 30 | ``` 31 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | return torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | static constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | static constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... 
- code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba/mamba-main/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = int(batch_size) if batch_size is not None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 29 | void *__restrict__ x_ptr; 30 | void *__restrict__ weight_ptr; 31 | void *__restrict__ bias_ptr; 32 | void *__restrict__ out_ptr; 33 | 34 | void *__restrict__ conv_state_ptr; 35 | 36 | void *__restrict__ seq_idx_ptr; 37 | }; 38 | 39 | struct ConvParamsBwd: public ConvParamsBase { 40 | index_t dx_batch_stride; 41 | index_t dx_c_stride; 42 | index_t dx_l_stride; 43 | index_t dweight_c_stride; 44 | index_t dweight_width_stride; 45 | index_t dout_batch_stride; 46 | index_t dout_c_stride; 47 | index_t dout_l_stride; 48 | 49 | // Common data pointers. 50 | void *__restrict__ dx_ptr; 51 | void *__restrict__ dweight_ptr; 52 | void *__restrict__ dbias_ptr; 53 | void *__restrict__ dout_ptr; 54 | }; 55 | 56 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/csrc/causal_conv1d_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | //////////////////////////////////////////////////////////////////////////////////////////////////// 11 | 12 | template struct BytesToType {}; 13 | 14 | template<> struct BytesToType<16> { 15 | using Type = uint4; 16 | static_assert(sizeof(Type) == 16); 17 | }; 18 | 19 | template<> struct BytesToType<8> { 20 | using Type = uint64_t; 21 | static_assert(sizeof(Type) == 8); 22 | }; 23 | 24 | template<> struct BytesToType<4> { 25 | using Type = uint32_t; 26 | static_assert(sizeof(Type) == 4); 27 | }; 28 | 29 | template<> struct BytesToType<2> { 30 | using Type = uint16_t; 31 | static_assert(sizeof(Type) == 2); 32 | }; 33 | 34 | template<> struct BytesToType<1> { 35 | using Type = uint8_t; 36 | static_assert(sizeof(Type) == 1); 37 | }; 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | template 42 | struct SumOp { 43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 44 | }; 45 | 46 | template 47 | struct Allreduce { 48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 49 | template 50 | static __device__ inline T run(T x, Operator &op) { 51 | constexpr int OFFSET = THREADS / 2; 52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 53 | return Allreduce::run(x, op); 54 | } 55 | }; 56 | 57 | template<> 58 | struct Allreduce<2> { 59 | template 60 | static __device__ inline T run(T x, Operator &op) { 61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 62 | return x; 63 | } 64 | }; 65 | -------------------------------------------------------------------------------- /mamba/mamba-main/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 
2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_selective_state_update(dim, dstate, has_z, itype): 23 |     device = "cuda" 24 |     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 |     if itype == torch.bfloat16: 26 |         rtol, atol = 1e-2, 5e-2 27 |     # set seed 28 |     torch.random.manual_seed(0) 29 |     batch_size = 2 30 |     state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 |     x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 |     dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 |     dt_bias = torch.rand(dim, device=device) - 4.0 34 |     A = -torch.rand(dim, dstate, device=device) - 1.0 35 |     B = torch.randn(batch_size, dstate, device=device) 36 |     C = torch.randn(batch_size, dstate, device=device) 37 |     D = torch.randn(dim, device=device) 38 |     if has_z: 39 |         z = torch.randn_like(x) 40 |     else: 41 |         z = None 42 |     state_ref = state.detach().clone() 43 |     out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 |     out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 |     print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 |     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 |     assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 |     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # env-mamba 2 | Setting up a mamba environment on Windows 3 | This repo reproduces the mamba environments of two excellent projects: 4 | 5 | https://github.com/walking-shadow/Official_Remote_Sensing_Mamba 6 | 7 | https://github.com/MzeroMiko/VMamba 8 | 9 | Both environments cost me quite a few pitfalls, so here is a short walkthrough of how to get them set up properly. 10 | 11 | My machine: Windows 11, i7-14700KF, 2080 Ti 22G 12 | 13 | PyTorch 2.2 + CUDA 12.2 14 | 15 | The environment is based on Python 3.11; we start with conda: 16 | 17 | ```shell 18 | 19 | conda create -n vmamba python==3.11 20 | 21 | ``` 22 | 23 | ```shell 24 | 25 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 26 | 27 | ``` 28 | 29 | To make the setup easier, I downloaded the two required packages from those two repositories. 30 | 31 | 32 | # Building https://github.com/state-spaces/mamba 33 | First, the environment needs triton support. 34 | The triton wheel comes from https://github.com/jakaline-dev/Triton_win/releases/tag/3.0.0 35 | 36 | After downloading this offline wheel, place it in a local folder. 37 | 38 | 39 | ```shell 40 | 41 | pip install triton-3.0.0-cp311-cp311-win_amd64.whl 42 | 43 | ``` 44 | 45 | ```shell 46 | 47 | cd mamba\mamba-main 48 | 49 | pip install -r requestment.txt 50 | 51 | pip install . 52 | 53 | ``` 54 | 55 | 56 | 57 | # Building https://github.com/SeanSong-amd/causal-conv1d 58 | 59 | 60 | This step likewise builds on that author's repository. 61 | 62 | ```shell 63 | 64 | cd causal-conv1d\causal-conv1d-main 65 | 66 | pip install .
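# (added check, not part of the original guide) after both installs finish,
# confirm that the two packages are importable from the environment:
pip show mamba-ssm causal-conv1d
python -c "import mamba_ssm, causal_conv1d; print(mamba_ssm.__version__, causal_conv1d.__version__)"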
67 | ``` 68 | 69 | At this point the pitfalls of https://github.com/walking-shadow/Official_Remote_Sensing_Mamba have been dealt with. 70 | 71 | Check that mamba-ssm and causal-conv1d are present in the environment. 72 | 73 | Then you can use the mamba model. 74 | 75 | # Next: the pitfalls of https://github.com/MzeroMiko/VMamba 76 | 77 | First, the kernels would not compile. 78 | 79 | The problem lies with the C++ compiler. 80 | 81 | You first need a C++ compiler; if you already have Visual Studio you can skip this step. 82 | 83 | See https://github.com/MzeroMiko/VMamba/issues/95 84 | 85 | The constexpr in static_switch.h triggers a non-constant expression error. 86 | 87 | Manually add static in front of constexpr. 88 | 89 | ![image](https://github.com/learning-mamba/env-mamba/assets/66856290/ea4ae859-9fb6-4e6a-9c2a-11ce88acac8a) 90 | 91 | Then add the following snippet 92 | 93 | ```shell 94 | 95 | #ifndef M_LOG2E 96 | #define M_LOG2E 1.4426950408889634074 97 | #endif 98 | 99 | ``` 100 | to these files: 101 | kernels/selective_scan/csrc/selective_scan/cus/selective_scan_bwd_kernel.cuh 102 | kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh 103 | kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_bwd_kernel_ndstate.cuh 104 | kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh 105 | kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_bwd_kernel_oflex.cuh 106 | kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh 107 | 108 | Like this: 109 | ![image](https://github.com/learning-mamba/env-mamba/assets/66856290/f8f23011-9043-4529-ab0d-a20cb6dde64f) 110 | 111 | 112 | If you did not install mamba in the earlier steps, this project will not run either! 113 | 114 | -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 |  * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 |  * 4 |  * Redistribution and use in source and binary forms, with or without 5 |  * modification, are permitted provided that the following conditions are met: 6 |  *     * Redistributions of source code must retain the above copyright 7 |  *       notice, this list of conditions and the following disclaimer. 8 |  *     * Redistributions in binary form must reproduce the above copyright 9 |  *       notice, this list of conditions and the following disclaimer in the 10 |  *       documentation and/or other materials provided with the distribution. 11 |  *     * Neither the name of the NVIDIA CORPORATION nor the 12 |  *       names of its contributors may be used to endorse or promote products 13 |  *       derived from this software without specific prior written permission. 14 |  * 15 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 |  * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 |  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 |  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 |  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 |  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 |  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 |  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 
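    // (Added note, based on how these fields are used elsewhere in the repo) They point to the
    // selective-scan operands: the state matrices A/B/C and the skip term D, the input u, the
    // timestep delta (plus optional delta_bias), the optional gate z, the outputs out/out_z,
    // and x, which stores per-chunk intermediate states reused by the backward pass.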
58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /mamba/mamba-main/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--minp", type=float, default=0.0) 26 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 27 | parser.add_argument("--batch", type=int, default=1) 28 | args = parser.parse_args() 29 | 30 | repeats = 3 31 | device = "cuda" 32 | dtype = torch.float16 33 | 34 | print(f"Loading model {args.model_name}") 35 | is_mamba = args.model_name.startswith("state-spaces/mamba-") 36 | if is_mamba: 37 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 38 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 39 | else: 40 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 41 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 42 | model.eval() 43 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 44 | 45 | torch.random.manual_seed(0) 46 | if args.prompt is None: 47 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 48 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, 
device="cuda") 49 | else: 50 | tokens = tokenizer(args.prompt, return_tensors="pt") 51 | input_ids = tokens.input_ids.to(device=device) 52 | attn_mask = tokens.attention_mask.to(device=device) 53 | max_length = input_ids.shape[1] + args.genlen 54 | 55 | if is_mamba: 56 | fn = lambda: model.generate( 57 | input_ids=input_ids, 58 | max_length=max_length, 59 | cg=True, 60 | return_dict_in_generate=True, 61 | output_scores=True, 62 | enable_timing=False, 63 | temperature=args.temperature, 64 | top_k=args.topk, 65 | top_p=args.topp, 66 | min_p=args.minp, 67 | repetition_penalty=args.repetition_penalty, 68 | ) 69 | else: 70 | fn = lambda: model.generate( 71 | input_ids=input_ids, 72 | attention_mask=attn_mask, 73 | max_length=max_length, 74 | return_dict_in_generate=True, 75 | pad_token_id=tokenizer.eos_token_id, 76 | do_sample=True, 77 | temperature=args.temperature, 78 | top_k=args.topk, 79 | top_p=args.topp, 80 | repetition_penalty=args.repetition_penalty, 81 | ) 82 | out = fn() 83 | if args.prompt is not None: 84 | print(tokenizer.batch_decode(out.sequences.tolist())) 85 | 86 | torch.cuda.synchronize() 87 | start = time.time() 88 | for _ in range(repeats): 89 | fn() 90 | torch.cuda.synchronize() 91 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 92 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 93 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | 12 | @staticmethod 13 | def forward(ctx, x, weight, bias=None, seq_idx=None, activation=None): 14 | if activation not in [None, "silu", "swish"]: 15 | raise NotImplementedError("activation must be None, silu, or swish") 16 | if x.stride(2) != 1 and x.stride(1) != 1: 17 | x = x.contiguous() 18 | bias = bias.contiguous() if bias is not None else None 19 | seq_idx = seq_idx.contiguous() if seq_idx is not None else None 20 | ctx.save_for_backward(x, weight, bias, seq_idx) 21 | ctx.activation = activation in ["silu", "swish"] 22 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, seq_idx, ctx.activation) 23 | return out 24 | 25 | @staticmethod 26 | def backward(ctx, dout): 27 | x, weight, bias, seq_idx = ctx.saved_tensors 28 | if dout.stride(2) != 1 and dout.stride(1) != 1: 29 | dout = dout.contiguous() 30 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 31 | # backward of conv1d with the backward of chunk). 32 | # Here we just pass in None and dx will be allocated in the C++ code. 
33 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 34 | x, weight, bias, dout, seq_idx, None, ctx.activation 35 | ) 36 | return dx, dweight, dbias if bias is not None else None, None, None 37 | 38 | 39 | def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None): 40 | """ 41 | x: (batch, dim, seqlen) 42 | weight: (dim, width) 43 | bias: (dim,) 44 | seq_idx: (batch, seqlen) 45 | activation: either None or "silu" or "swish" 46 | 47 | out: (batch, dim, seqlen) 48 | """ 49 | return CausalConv1dFn.apply(x, weight, bias, seq_idx, activation) 50 | 51 | 52 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 53 | """ 54 | x: (batch, dim, seqlen) 55 | weight: (dim, width) 56 | bias: (dim,) 57 | 58 | out: (batch, dim, seqlen) 59 | """ 60 | if activation not in [None, "silu", "swish"]: 61 | raise NotImplementedError("activation must be None, silu, or swish") 62 | dtype_in = x.dtype 63 | x = x.to(weight.dtype) 64 | seqlen = x.shape[-1] 65 | dim, width = weight.shape 66 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 67 | out = out[..., :seqlen] 68 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 69 | 70 | 71 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 72 | """ 73 | x: (batch, dim) 74 | conv_state: (batch, dim, width) 75 | weight: (dim, width) 76 | bias: (dim,) 77 | 78 | out: (batch, dim) 79 | """ 80 | if activation not in [None, "silu", "swish"]: 81 | raise NotImplementedError("activation must be None, silu, or swish") 82 | activation = activation in ["silu", "swish"] 83 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 84 | 85 | 86 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 87 | """ 88 | x: (batch, dim) 89 | conv_state: (batch, dim, width) 90 | weight: (dim, width) 91 | bias: (dim,) 92 | 93 | out: (batch, dim) 94 | """ 95 | if activation not in [None, "silu", "swish"]: 96 | raise NotImplementedError("activation must be None, silu, or swish") 97 | dtype_in = x.dtype 98 | batch, dim = x.shape 99 | width = weight.shape[1] 100 | assert conv_state.shape == (batch, dim, width) 101 | assert weight.shape == (dim, width) 102 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 103 | conv_state[:, :, -1] = x 104 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 105 | if bias is not None: 106 | out += bias 107 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 108 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/csrc/causal_conv1d_update.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #include 10 | #include 11 | 12 | #include "causal_conv1d.h" 13 | #include "causal_conv1d_common.h" 14 | #include "static_switch.h" 15 | 16 | template 17 | struct Causal_conv1d_update_kernel_traits { 18 | using input_t = input_t_; 19 | using weight_t = weight_t_; 20 | static constexpr int kNThreads = kNThreads_; 21 | static constexpr int kWidth = kWidth_; 22 | static constexpr int kNBytes = sizeof(input_t); 23 | static_assert(kNBytes == 2 || kNBytes == 4); 24 | }; 25 | 26 | template 27 | __global__ __launch_bounds__(Ktraits::kNThreads) 28 | void causal_conv1d_update_kernel(ConvParamsBase params) { 29 | constexpr int kWidth = Ktraits::kWidth; 30 | constexpr int kNThreads = Ktraits::kNThreads; 31 | using input_t = typename Ktraits::input_t; 32 | using weight_t = typename Ktraits::weight_t; 33 | 34 | const int tidx = threadIdx.x; 35 | const int batch_id = blockIdx.x; 36 | const int channel_id = blockIdx.y * kNThreads + tidx; 37 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 38 | + channel_id * params.x_c_stride; 39 | input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride 40 | + channel_id * params.conv_state_c_stride; 41 | weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; 42 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 43 | + channel_id * params.out_c_stride; 44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); 45 | 46 | float weight_vals[kWidth] = {0}; 47 | if (channel_id < params.dim) { 48 | #pragma unroll 49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 50 | } 51 | 52 | float x_vals[kWidth] = {0}; 53 | if (channel_id < params.dim) { 54 | #pragma unroll 55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } 56 | x_vals[kWidth - 1] = float(x[0]); 57 | #pragma unroll 58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } 59 | } 60 | 61 | float out_val = bias_val; 62 | #pragma unroll 63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } 64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } 65 | if (channel_id < params.dim) { out[0] = input_t(out_val); } 66 | } 67 | 68 | template 69 | void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { 70 | using Ktraits = Causal_conv1d_update_kernel_traits; 71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); 72 | auto kernel = &causal_conv1d_update_kernel; 73 | kernel<<>>(params); 74 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 75 | } 76 | 77 | template 78 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { 79 | if (params.width == 2) { 80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); 81 | } else if (params.width == 3) { 82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); 83 | } else if (params.width == 4) { 84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); 85 | } 86 | } 87 | 88 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 89 | template 
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 90 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 91 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 92 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 93 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 94 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 95 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 96 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/ops/triton/selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), 
other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 
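    # (Added note) The hand tuning below shrinks BLOCK_SIZE_M as dstate grows (32 -> 16 -> 8 -> 4),
    # keeping each program's (BLOCK_SIZE_M x dstate) tile roughly constant; only the largest
    # dstate bucket also raises num_warps from 4 to 8.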
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | -------------------------------------------------------------------------------- /mamba/mamba-main/README.md: -------------------------------------------------------------------------------- 1 | # Mamba 2 | 3 | ![Mamba](assets/selection.png "Selective State Space") 4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ 5 | > Albert Gu*, Tri Dao*\ 6 | > Paper: https://arxiv.org/abs/2312.00752 7 | 8 | ## About 9 | 10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. 11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), 12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). 13 | 14 | ## Installation 15 | 16 | - [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. 17 | - `pip install mamba-ssm`: the core Mamba package. 18 | 19 | It can also be built from source with `pip install .` from this repository. 20 | 21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. 
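Put together, a typical install sequence looks like this (a consolidated sketch of the options above, not an additional requirement; adjust the pins to your environment):

```
pip install "causal-conv1d>=1.2.0"   # optional: fused causal Conv1d kernels
pip install mamba-ssm                # core package
# or, from a checkout of this repository:
pip install .
# if pip complains about PyTorch versions during the build:
pip install mamba-ssm --no-build-isolation
```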
22 | 23 | Other requirements: 24 | - Linux 25 | - NVIDIA GPU 26 | - PyTorch 1.12+ 27 | - CUDA 11.6+ 28 | 29 | ## Usage 30 | 31 | We expose several levels of interface with the Mamba model. 32 | 33 | ### Selective SSM 34 | 35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). 36 | 37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). 38 | 39 | ### Mamba Block 40 | 41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM. 42 | 43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). 44 | 45 | Usage: 46 | ``` 47 | import torch 48 | from mamba_ssm import Mamba 49 | 50 | batch, length, dim = 2, 64, 16 51 | x = torch.randn(batch, length, dim).to("cuda") 52 | model = Mamba( 53 | # This module uses roughly 3 * expand * d_model^2 parameters 54 | d_model=dim, # Model dimension d_model 55 | d_state=16, # SSM state expansion factor 56 | d_conv=4, # Local convolution width 57 | expand=2, # Block expansion factor 58 | ).to("cuda") 59 | y = model(x) 60 | assert y.shape == x.shape 61 | ``` 62 | 63 | ### Mamba Language Model 64 | 65 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. 66 | 67 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). 68 | 69 | This is an example of how to integrate Mamba into an end-to-end neural network. 70 | This example is used in the generation scripts below. 71 | 72 | 73 | 74 | ## Pretrained Models 75 | 76 | Pretrained models are uploaded to 77 | [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, 78 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj` 79 | (trained on 600B tokens on the SlimPajama dataset). 80 | 81 | 82 | The models will be autodownloaded by the generation script below. 83 | 84 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: 85 | 86 | | Parameters | Layers | Model dim. | 87 | |------------|--------|------------| 88 | | 130M | 24 | 768 | 89 | | 370M | 48 | 1024 | 90 | | 790M | 48 | 1536 | 91 | | 1.4B | 48 | 2048 | 92 | | 2.8B | 64 | 2560 | 93 | 94 | (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) 95 | 96 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). 97 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. 98 | 99 | 100 | ## Evaluations 101 | 102 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper), 103 | we use the 104 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 105 | library. 106 | 107 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init 108 | --recursive`. We use the `big-refactor` branch. 109 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`. 110 | On Python 3.10 you might need to manually install the latest version of `promptsource`: `pip install git+https://github.com/bigscience-workshop/promptsource.git`. 111 | 3. 
Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): 112 | ``` 113 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 114 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 115 | ``` 116 | 117 | To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts: 118 | ``` 119 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64 120 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64 121 | ``` 122 | 123 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. 124 | 125 | ## Inference 126 | 127 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 128 | 1. autoloads a model from the Hugging Face Hub, 129 | 2. generates completions of a user-specified prompt, 130 | 3. benchmarks the inference speed of this generation. 131 | 132 | Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. 133 | 134 | ### Examples 135 | 136 | To test generation latency (e.g. batch size = 1) with different sampling strategies: 137 | 138 | ``` 139 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 140 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 141 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2 142 | ``` 143 | 144 | To test generation throughput with random prompts (e.g. large batch size): 145 | ``` 146 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 147 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 148 | ``` 149 | 150 | 151 | ## Troubleshooting 152 | 153 | ### Precision 154 | Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary. 155 | On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcasts when necessary (e.g. for optimizer accumulation). 156 | 157 | We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities, 158 | as a first step please try a framework storing parameters in fp32 (such as AMP). 
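With AMP this amounts to keeping the module parameters in float32 and letting `autocast` downcast the compute on the fly. The snippet below is a minimal sketch of such a setup around `MambaLMHeadModel` (from `mamba_ssm/models/mixer_seq_simple.py`); the optimizer, batch, and loss here are placeholders for illustration, not code from this repository:
```
import torch
import torch.nn.functional as F
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Parameters are created and kept in float32; autocast only affects the compute.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float32)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

input_ids = torch.randint(0, model.config.vocab_size, (2, 512), device="cuda")  # placeholder batch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(input_ids).logits
    # Simple next-token loss over the placeholder batch.
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), input_ids[:, 1:].reshape(-1))
loss.backward()
optimizer.step()
```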
159 | 160 | ### Initialization 161 | Some parts of the model have initializations inherited from prior work on S4 models. 162 | For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection. 163 | However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero). 164 | If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework) 165 | that is specific to the training framework. 166 | 167 | 168 | ## Citation 169 | 170 | If you use this codebase, or otherwise found our work valuable, please cite Mamba: 171 | ``` 172 | @article{mamba, 173 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, 174 | author={Gu, Albert and Dao, Tri}, 175 | journal={arXiv preprint arXiv:2312.00752}, 176 | year={2023} 177 | } 178 | ``` 179 | -------------------------------------------------------------------------------- /mamba/mamba-main/csrc/selective_scan/selective_scan_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For scalar_value_type 10 | 11 | #define MAX_DSTATE 256 12 | 13 | using complex_t = c10::complex; 14 | 15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){ 16 | return {a.x + b.x, a.y + b.y}; 17 | } 18 | 19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) { 20 | return {a.x + b.x, a.y + b.y, a.z + b.z}; 21 | } 22 | 23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){ 24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; 25 | } 26 | 27 | //////////////////////////////////////////////////////////////////////////////////////////////////// 28 | 29 | template struct BytesToType {}; 30 | 31 | template<> struct BytesToType<16> { 32 | using Type = uint4; 33 | static_assert(sizeof(Type) == 16); 34 | }; 35 | 36 | template<> struct BytesToType<8> { 37 | using Type = uint64_t; 38 | static_assert(sizeof(Type) == 8); 39 | }; 40 | 41 | template<> struct BytesToType<4> { 42 | using Type = uint32_t; 43 | static_assert(sizeof(Type) == 4); 44 | }; 45 | 46 | template<> struct BytesToType<2> { 47 | using Type = uint16_t; 48 | static_assert(sizeof(Type) == 2); 49 | }; 50 | 51 | template<> struct BytesToType<1> { 52 | using Type = uint8_t; 53 | static_assert(sizeof(Type) == 1); 54 | }; 55 | 56 | //////////////////////////////////////////////////////////////////////////////////////////////////// 57 | 58 | template 59 | struct Converter{ 60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { 61 | #pragma unroll 62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; } 63 | } 64 | }; 65 | 66 | template 67 | struct Converter{ 68 | static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { 69 | static_assert(N % 2 == 0); 70 | auto &src2 = reinterpret_cast(src); 71 | auto &dst2 = reinterpret_cast(dst); 72 | #pragma unroll 73 | for 
(int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } 74 | } 75 | }; 76 | 77 | #if __CUDA_ARCH__ >= 800 78 | template 79 | struct Converter{ 80 | static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { 81 | static_assert(N % 2 == 0); 82 | auto &src2 = reinterpret_cast(src); 83 | auto &dst2 = reinterpret_cast(dst); 84 | #pragma unroll 85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } 86 | } 87 | }; 88 | #endif 89 | 90 | //////////////////////////////////////////////////////////////////////////////////////////////////// 91 | 92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp 93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) { 95 | float t = exp2f(z.real_); 96 | float c, s; 97 | sincosf(z.imag_, &s, &c); 98 | return complex_t(c * t, s * t); 99 | } 100 | 101 | __device__ __forceinline__ complex_t cexpf(complex_t z) { 102 | float t = expf(z.real_); 103 | float c, s; 104 | sincosf(z.imag_, &s, &c); 105 | return complex_t(c * t, s * t); 106 | } 107 | 108 | template struct SSMScanOp; 109 | 110 | template<> 111 | struct SSMScanOp { 112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { 113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); 114 | } 115 | }; 116 | 117 | template<> 118 | struct SSMScanOp { 119 | __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { 120 | complex_t a0 = complex_t(ab0.x, ab0.y); 121 | complex_t b0 = complex_t(ab0.z, ab0.w); 122 | complex_t a1 = complex_t(ab1.x, ab1.y); 123 | complex_t b1 = complex_t(ab1.z, ab1.w); 124 | complex_t out_a = a1 * a0; 125 | complex_t out_b = a1 * b0 + b1; 126 | return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); 127 | } 128 | }; 129 | 130 | // A stateful callback functor that maintains a running prefix to be applied 131 | // during consecutive scan operations. 132 | template struct SSMScanPrefixCallbackOp { 133 | using scan_t = std::conditional_t, float2, float4>; 134 | scan_t running_prefix; 135 | // Constructor 136 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} 137 | // Callback operator to be entered by the first warp of threads in the block. 138 | // Thread-0 is responsible for returning a value for seeding the block-wide scan. 
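// Note: SSMScanOp above composes the affine maps x -> a*x + b; combining (a0, b0) with (a1, b1)
// gives (a1*a0, a1*b0 + b1), the map for applying step 0 and then step 1. Since this composition
// is associative, a block-wide prefix scan with this operator evaluates the SSM recurrence
// h_t = a_t * h_{t-1} + b_t in parallel, and the callback below carries the running prefix
// between consecutive scan chunks.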
139 | __device__ scan_t operator()(scan_t block_aggregate) { 140 | scan_t old_prefix = running_prefix; 141 | running_prefix = SSMScanOp()(running_prefix, block_aggregate); 142 | return old_prefix; 143 | } 144 | }; 145 | 146 | //////////////////////////////////////////////////////////////////////////////////////////////////// 147 | 148 | template 149 | inline __device__ void load_input(typename Ktraits::input_t *u, 150 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], 151 | typename Ktraits::BlockLoadT::TempStorage &smem_load, 152 | int seqlen) { 153 | if constexpr (Ktraits::kIsEvenLen) { 154 | auto& smem_load_vec = reinterpret_cast(smem_load); 155 | using vec_t = typename Ktraits::vec_t; 156 | Ktraits::BlockLoadVecT(smem_load_vec).Load( 157 | reinterpret_cast(u), 158 | reinterpret_cast(u_vals) 159 | ); 160 | } else { 161 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); 162 | } 163 | } 164 | 165 | template 166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar, 167 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], 168 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, 169 | int seqlen) { 170 | constexpr int kNItems = Ktraits::kNItems; 171 | if constexpr (!Ktraits::kIsComplex) { 172 | typename Ktraits::input_t B_vals_load[kNItems]; 173 | if constexpr (Ktraits::kIsEvenLen) { 174 | auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); 175 | using vec_t = typename Ktraits::vec_t; 176 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 177 | reinterpret_cast(Bvar), 178 | reinterpret_cast(B_vals_load) 179 | ); 180 | } else { 181 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 182 | } 183 | // #pragma unroll 184 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } 185 | Converter::to_float(B_vals_load, B_vals); 186 | } else { 187 | typename Ktraits::input_t B_vals_load[kNItems * 2]; 188 | if constexpr (Ktraits::kIsEvenLen) { 189 | auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); 190 | using vec_t = typename Ktraits::vec_t; 191 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 192 | reinterpret_cast(Bvar), 193 | reinterpret_cast(B_vals_load) 194 | ); 195 | } else { 196 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 197 | } 198 | #pragma unroll 199 | for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } 200 | } 201 | } 202 | 203 | template 204 | inline __device__ void store_output(typename Ktraits::input_t *out, 205 | const float (&out_vals)[Ktraits::kNItems], 206 | typename Ktraits::BlockStoreT::TempStorage &smem_store, 207 | int seqlen) { 208 | typename Ktraits::input_t write_vals[Ktraits::kNItems]; 209 | #pragma unroll 210 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } 211 | if constexpr (Ktraits::kIsEvenLen) { 212 | auto& smem_store_vec = reinterpret_cast(smem_store); 213 | using vec_t = typename Ktraits::vec_t; 214 | Ktraits::BlockStoreVecT(smem_store_vec).Store( 215 | reinterpret_cast(out), 216 | reinterpret_cast(write_vals) 217 | ); 218 | } else { 219 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /mamba/mamba-main/.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will: 2 | # - Create a new Github release 3 | # - 
Build wheels for supported architectures 4 | # - Deploy the wheels to the Github release 5 | # - Release the static code to PyPi 6 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 7 | 8 | name: Build wheels and deploy 9 | 10 | on: 11 | create: 12 | tags: 13 | - v* 14 | 15 | jobs: 16 | 17 | setup_release: 18 | name: Create Release 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Get the tag version 22 | id: extract_branch 23 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} 24 | shell: bash 25 | 26 | - name: Create Release 27 | id: create_release 28 | uses: actions/create-release@v1 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 31 | with: 32 | tag_name: ${{ steps.extract_branch.outputs.branch }} 33 | release_name: ${{ steps.extract_branch.outputs.branch }} 34 | 35 | build_wheels: 36 | name: Build Wheel 37 | needs: setup_release 38 | runs-on: ${{ matrix.os }} 39 | 40 | strategy: 41 | fail-fast: false 42 | matrix: 43 | # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the 44 | # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. 45 | os: [ubuntu-20.04] 46 | python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] 47 | torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105'] 48 | cuda-version: ['11.8.0', '12.2.2'] 49 | # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. 50 | # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 51 | # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) 52 | # when building without C++11 ABI and using it on nvcr images. 53 | cxx11_abi: ['FALSE', 'TRUE'] 54 | exclude: 55 | # Pytorch < 2.2 does not support Python 3.12 56 | - torch-version: '1.12.1' 57 | python-version: '3.12' 58 | - torch-version: '1.13.1' 59 | python-version: '3.12' 60 | - torch-version: '2.0.1' 61 | python-version: '3.12' 62 | - torch-version: '2.1.2' 63 | python-version: '3.12' 64 | # Pytorch <= 1.12 does not support Python 3.11 65 | - torch-version: '1.12.1' 66 | python-version: '3.11' 67 | # Pytorch >= 2.0 only supports Python >= 3.8 68 | - torch-version: '2.0.1' 69 | python-version: '3.7' 70 | - torch-version: '2.1.2' 71 | python-version: '3.7' 72 | - torch-version: '2.2.0' 73 | python-version: '3.7' 74 | - torch-version: '2.3.0.dev20240105' 75 | python-version: '3.7' 76 | # Pytorch <= 2.0 only supports CUDA <= 11.8 77 | - torch-version: '1.12.1' 78 | cuda-version: '12.2.2' 79 | - torch-version: '1.13.1' 80 | cuda-version: '12.2.2' 81 | - torch-version: '2.0.1' 82 | cuda-version: '12.2.2' 83 | 84 | steps: 85 | - name: Checkout 86 | uses: actions/checkout@v3 87 | 88 | - name: Set up Python 89 | uses: actions/setup-python@v4 90 | with: 91 | python-version: ${{ matrix.python-version }} 92 | 93 | - name: Set CUDA and PyTorch versions 94 | run: | 95 | echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV 96 | echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." 
$2'})" >> $GITHUB_ENV 97 | 98 | - name: Free up disk space 99 | if: ${{ runner.os == 'Linux' }} 100 | # https://github.com/easimon/maximize-build-space/blob/master/action.yml 101 | # https://github.com/easimon/maximize-build-space/tree/test-report 102 | run: | 103 | sudo rm -rf /usr/share/dotnet 104 | sudo rm -rf /opt/ghc 105 | sudo rm -rf /opt/hostedtoolcache/CodeQL 106 | 107 | - name: Set up swap space 108 | if: runner.os == 'Linux' 109 | uses: pierotofy/set-swap-space@v1.0 110 | with: 111 | swap-size-gb: 10 112 | 113 | - name: Install CUDA ${{ matrix.cuda-version }} 114 | if: ${{ matrix.cuda-version != 'cpu' }} 115 | uses: Jimver/cuda-toolkit@v0.2.14 116 | id: cuda-toolkit 117 | with: 118 | cuda: ${{ matrix.cuda-version }} 119 | linux-local-args: '["--toolkit"]' 120 | # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 121 | # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} 122 | method: 'network' 123 | # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions, 124 | # not just nvcc 125 | # sub-packages: '["nvcc"]' 126 | 127 | - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} 128 | run: | 129 | pip install --upgrade pip 130 | # If we don't install before installing Pytorch, we get error for torch 2.0.1 131 | # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) 132 | pip install lit 133 | # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools 134 | pip install setuptools 135 | # We want to figure out the CUDA version to download pytorch 136 | # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 137 | # This code is ugly, maybe there's a better way to do this. 
138 | export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ 139 | minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \ 140 | maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \ 141 | print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \ 142 | ) 143 | if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then 144 | if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then 145 | # --no-deps because we can't install old versions of pytorch-triton 146 | pip install typing-extensions jinja2 147 | pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl 148 | else 149 | pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} 150 | fi 151 | else 152 | pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} 153 | fi 154 | nvcc --version 155 | python --version 156 | python -c "import torch; print('PyTorch:', torch.__version__)" 157 | python -c "import torch; print('CUDA:', torch.version.cuda)" 158 | python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" 159 | shell: 160 | bash 161 | 162 | - name: Build wheel 163 | run: | 164 | # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 165 | # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 166 | # However this still fails so I'm using a newer version of setuptools 167 | pip install setuptools==68.0.0 168 | pip install ninja packaging wheel 169 | export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH 170 | export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH 171 | # Limit MAX_JOBS otherwise the github runner goes OOM 172 | MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist 173 | tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} 174 | wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") 175 | ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} 176 | echo "wheel_name=${wheel_name}" >> $GITHUB_ENV 177 | 178 | - name: Log Built Wheels 179 | run: | 180 | ls dist 181 | 182 | - name: Get the tag version 183 | id: extract_branch 184 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} 185 | 186 | - name: Get Release with tag 187 | id: get_current_release 188 | uses: joutvhu/get-release@v1 189 | with: 190 | tag_name: ${{ steps.extract_branch.outputs.branch }} 191 | env: 192 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 193 | 194 | - name: Upload Release Asset 195 | id: upload_release_asset 196 | uses: actions/upload-release-asset@v1 197 | env: 198 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 199 | with: 200 | upload_url: ${{ steps.get_current_release.outputs.upload_url }} 201 | asset_path: ./dist/${{env.wheel_name}} 202 | asset_name: ${{env.wheel_name}} 203 | asset_content_type: application/* 204 | 205 | publish_package: 206 | name: Publish package 207 | needs: [build_wheels] 208 | 209 | runs-on: ubuntu-latest 210 | 211 | steps: 212 | 
- uses: actions/checkout@v3 213 | 214 | - uses: actions/setup-python@v4 215 | with: 216 | python-version: '3.10' 217 | 218 | - name: Install dependencies 219 | run: | 220 | pip install ninja packaging setuptools wheel twine 221 | # We don't want to download anything CUDA-related here 222 | pip install torch --index-url https://download.pytorch.org/whl/cpu 223 | 224 | - name: Build core package 225 | env: 226 | MAMBA_SKIP_CUDA_BUILD: "TRUE" 227 | run: | 228 | python setup.py sdist --dist-dir=dist 229 | 230 | - name: Deploy 231 | env: 232 | TWINE_USERNAME: "__token__" 233 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 234 | run: | 235 | python -m twine upload dist/* 236 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Tri Dao. 2 | 3 | import sys 4 | import warnings 5 | import os 6 | import re 7 | import shutil 8 | import ast 9 | from pathlib import Path 10 | from packaging.version import parse, Version 11 | import platform 12 | 13 | from setuptools import setup, find_packages 14 | import subprocess 15 | 16 | import urllib.request 17 | import urllib.error 18 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 19 | 20 | import torch 21 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME 22 | 23 | 24 | with open("README.md", "r", encoding="utf-8") as fh: 25 | long_description = fh.read() 26 | 27 | 28 | # ninja build does not work unless include_dirs are abs path 29 | this_dir = os.path.dirname(os.path.abspath(__file__)) 30 | 31 | PACKAGE_NAME = "causal_conv1d" 32 | 33 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}" 34 | 35 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 36 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 37 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE" 38 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 39 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 40 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE" 41 | 42 | 43 | def get_platform(): 44 | """ 45 | Returns the platform name as used in wheel filenames. 46 | """ 47 | if sys.platform.startswith("linux"): 48 | return "linux_x86_64" 49 | elif sys.platform == "darwin": 50 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 51 | return f"macosx_{mac_version}_x86_64" 52 | elif sys.platform == "win32": 53 | return "win_amd64" 54 | else: 55 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 56 | 57 | 58 | def get_cuda_bare_metal_version(cuda_dir): 59 | raw_output = subprocess.check_output( 60 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 61 | ) 62 | output = raw_output.split() 63 | release_idx = output.index("release") + 1 64 | bare_metal_version = parse(output[release_idx].split(",")[0]) 65 | 66 | return raw_output, bare_metal_version 67 | 68 | 69 | def check_if_cuda_home_none(global_option: str) -> None: 70 | if CUDA_HOME is not None: 71 | return 72 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 73 | # in that case. 
74 | warnings.warn( 75 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 76 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 77 | "only images whose names contain 'devel' will provide nvcc." 78 | ) 79 | 80 | 81 | def append_nvcc_threads(nvcc_extra_args): 82 | return nvcc_extra_args + ["--threads", "4"] 83 | 84 | 85 | cmdclass = {} 86 | ext_modules = [] 87 | 88 | if not SKIP_CUDA_BUILD: 89 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 90 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 91 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 92 | 93 | check_if_cuda_home_none("causal_conv1d") 94 | # Check, if CUDA11 is installed for compute capability 8.0 95 | cc_flag = [] 96 | if CUDA_HOME is not None: 97 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 98 | if bare_metal_version < Version("11.6"): 99 | raise RuntimeError( 100 | "causal_conv1d is only supported on CUDA 11.6 and above. " 101 | "Note: make sure nvcc has a supported version by running nvcc -V." 102 | ) 103 | 104 | cc_flag.append("-gencode") 105 | cc_flag.append("arch=compute_70,code=sm_70") 106 | cc_flag.append("-gencode") 107 | cc_flag.append("arch=compute_80,code=sm_80") 108 | if bare_metal_version >= Version("11.8"): 109 | cc_flag.append("-gencode") 110 | cc_flag.append("arch=compute_90,code=sm_90") 111 | 112 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 113 | # torch._C._GLIBCXX_USE_CXX11_ABI 114 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 115 | if FORCE_CXX11_ABI: 116 | torch._C._GLIBCXX_USE_CXX11_ABI = True 117 | 118 | ext_modules.append( 119 | CUDAExtension( 120 | name="causal_conv1d_cuda", 121 | sources=[ 122 | "csrc/causal_conv1d.cpp", 123 | "csrc/causal_conv1d_fwd.cu", 124 | "csrc/causal_conv1d_bwd.cu", 125 | "csrc/causal_conv1d_update.cu", 126 | ], 127 | extra_compile_args={ 128 | "cxx": ["-O3"], 129 | "nvcc": append_nvcc_threads( 130 | [ 131 | "-O3", 132 | "-U__CUDA_NO_HALF_OPERATORS__", 133 | "-U__CUDA_NO_HALF_CONVERSIONS__", 134 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 135 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 136 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 137 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 138 | "--expt-relaxed-constexpr", 139 | "--expt-extended-lambda", 140 | "--use_fast_math", 141 | "--ptxas-options=-v", 142 | "-lineinfo", 143 | ] 144 | + cc_flag 145 | ), 146 | }, 147 | include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"], 148 | ) 149 | ) 150 | 151 | 152 | def get_package_version(): 153 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: 154 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 155 | public_version = ast.literal_eval(version_match.group(1)) 156 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") 157 | if local_version: 158 | return f"{public_version}+{local_version}" 159 | else: 160 | return str(public_version) 161 | 162 | 163 | def get_wheel_url(): 164 | # Determine the version numbers that will be used to determine the correct wheel 165 | # We're using the CUDA version used to build torch, not the one currently installed 166 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 167 | torch_cuda_version = parse(torch.version.cuda) 168 | torch_version_raw = parse(torch.__version__) 169 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for 
CUDA 12.2 170 | # to save CI time. Minor versions should be compatible. 171 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 172 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 173 | platform_name = get_platform() 174 | causal_conv1d_version = get_package_version() 175 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 176 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 177 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 178 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 179 | 180 | # Determine wheel URL based on CUDA version, torch version, python version and OS 181 | wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 182 | wheel_url = BASE_WHEEL_URL.format( 183 | tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename 184 | ) 185 | return wheel_url, wheel_filename 186 | 187 | 188 | class CachedWheelsCommand(_bdist_wheel): 189 | """ 190 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 191 | find an existing wheel (which is currently the case for all installs). We use 192 | the environment parameters to detect whether there is already a pre-built version of a compatible 193 | wheel available and short-circuits the standard full build pipeline. 194 | """ 195 | 196 | def run(self): 197 | if FORCE_BUILD: 198 | return super().run() 199 | 200 | wheel_url, wheel_filename = get_wheel_url() 201 | print("Guessing wheel URL: ", wheel_url) 202 | try: 203 | urllib.request.urlretrieve(wheel_url, wheel_filename) 204 | 205 | # Make the archive 206 | # Lifted from the root wheel processing command 207 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 208 | if not os.path.exists(self.dist_dir): 209 | os.makedirs(self.dist_dir) 210 | 211 | impl_tag, abi_tag, plat_tag = self.get_tag() 212 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 213 | 214 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 215 | print("Raw wheel path", wheel_path) 216 | shutil.move(wheel_filename, wheel_path) 217 | except urllib.error.HTTPError: 218 | print("Precompiled wheel not found. 
Building from source...") 219 | # If the wheel could not be downloaded, build from source 220 | super().run() 221 | 222 | 223 | setup( 224 | name=PACKAGE_NAME, 225 | version=get_package_version(), 226 | packages=find_packages( 227 | exclude=( 228 | "build", 229 | "csrc", 230 | "include", 231 | "tests", 232 | "dist", 233 | "docs", 234 | "benchmarks", 235 | "causal_conv1d.egg-info", 236 | ) 237 | ), 238 | author="Tri Dao", 239 | author_email="tri@tridao.me", 240 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface", 241 | long_description=long_description, 242 | long_description_content_type="text/markdown", 243 | url="https://github.com/Dao-AILab/causal-conv1d", 244 | classifiers=[ 245 | "Programming Language :: Python :: 3", 246 | "License :: OSI Approved :: BSD License", 247 | "Operating System :: Unix", 248 | ], 249 | ext_modules=ext_modules, 250 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 251 | if ext_modules 252 | else { 253 | "bdist_wheel": CachedWheelsCommand, 254 | }, 255 | python_requires=">=3.7", 256 | install_requires=[ 257 | "torch", 258 | "packaging", 259 | "buildtools", 260 | "ninja", 261 | ], 262 | ) -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/models/mixer_seq_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 2 | 3 | import math 4 | from functools import partial 5 | import json 6 | import os 7 | 8 | from collections import namedtuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from mamba_ssm.models.config_mamba import MambaConfig 14 | from mamba_ssm.modules.mamba_simple import Mamba, Block 15 | from mamba_ssm.utils.generation import GenerationMixin 16 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf 17 | 18 | try: 19 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 20 | except ImportError: 21 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 22 | 23 | 24 | def create_block( 25 | d_model, 26 | ssm_cfg=None, 27 | norm_epsilon=1e-5, 28 | rms_norm=False, 29 | residual_in_fp32=False, 30 | fused_add_norm=False, 31 | layer_idx=None, 32 | device=None, 33 | dtype=None, 34 | ): 35 | if ssm_cfg is None: 36 | ssm_cfg = {} 37 | factory_kwargs = {"device": device, "dtype": dtype} 38 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 39 | norm_cls = partial( 40 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 41 | ) 42 | block = Block( 43 | d_model, 44 | mixer_cls, 45 | norm_cls=norm_cls, 46 | fused_add_norm=fused_add_norm, 47 | residual_in_fp32=residual_in_fp32, 48 | ) 49 | block.layer_idx = layer_idx 50 | return block 51 | 52 | 53 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 54 | def _init_weights( 55 | module, 56 | n_layer, 57 | initializer_range=0.02, # Now only used for embedding layer. 
58 | rescale_prenorm_residual=True, 59 | n_residuals_per_layer=1, # Change to 2 if we have MLP 60 | ): 61 | if isinstance(module, nn.Linear): 62 | if module.bias is not None: 63 | if not getattr(module.bias, "_no_reinit", False): 64 | nn.init.zeros_(module.bias) 65 | elif isinstance(module, nn.Embedding): 66 | nn.init.normal_(module.weight, std=initializer_range) 67 | 68 | if rescale_prenorm_residual: 69 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 70 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 71 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 72 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 73 | # 74 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 75 | for name, p in module.named_parameters(): 76 | if name in ["out_proj.weight", "fc2.weight"]: 77 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 78 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 79 | # We need to reinit p since this code could be called multiple times 80 | # Having just p *= scale would repeatedly scale it down 81 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 82 | with torch.no_grad(): 83 | p /= math.sqrt(n_residuals_per_layer * n_layer) 84 | 85 | 86 | class MixerModel(nn.Module): 87 | def __init__( 88 | self, 89 | d_model: int, 90 | n_layer: int, 91 | vocab_size: int, 92 | ssm_cfg=None, 93 | norm_epsilon: float = 1e-5, 94 | rms_norm: bool = False, 95 | initializer_cfg=None, 96 | fused_add_norm=False, 97 | residual_in_fp32=False, 98 | device=None, 99 | dtype=None, 100 | ) -> None: 101 | factory_kwargs = {"device": device, "dtype": dtype} 102 | super().__init__() 103 | self.residual_in_fp32 = residual_in_fp32 104 | 105 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) 106 | 107 | # We change the order of residual and layer norm: 108 | # Instead of LN -> Attn / MLP -> Add, we do: 109 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 110 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 111 | # This is for performance reason: we can fuse add + layer_norm. 
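        # Concretely (see forward() below): the unfused path computes
        #     residual = hidden_states + residual; hidden_states = norm_f(residual)
        # as two separate ops, while the fused path makes a single layer_norm_fn / rms_norm_fn call
        # with residual= passed in, avoiding an extra read/write of the intermediate sum.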
112 | self.fused_add_norm = fused_add_norm 113 | if self.fused_add_norm: 114 | if layer_norm_fn is None or rms_norm_fn is None: 115 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 116 | 117 | self.layers = nn.ModuleList( 118 | [ 119 | create_block( 120 | d_model, 121 | ssm_cfg=ssm_cfg, 122 | norm_epsilon=norm_epsilon, 123 | rms_norm=rms_norm, 124 | residual_in_fp32=residual_in_fp32, 125 | fused_add_norm=fused_add_norm, 126 | layer_idx=i, 127 | **factory_kwargs, 128 | ) 129 | for i in range(n_layer) 130 | ] 131 | ) 132 | 133 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 134 | d_model, eps=norm_epsilon, **factory_kwargs 135 | ) 136 | 137 | self.apply( 138 | partial( 139 | _init_weights, 140 | n_layer=n_layer, 141 | **(initializer_cfg if initializer_cfg is not None else {}), 142 | ) 143 | ) 144 | 145 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 146 | return { 147 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 148 | for i, layer in enumerate(self.layers) 149 | } 150 | 151 | def forward(self, input_ids, inference_params=None): 152 | hidden_states = self.embedding(input_ids) 153 | residual = None 154 | for layer in self.layers: 155 | hidden_states, residual = layer( 156 | hidden_states, residual, inference_params=inference_params 157 | ) 158 | if not self.fused_add_norm: 159 | residual = (hidden_states + residual) if residual is not None else hidden_states 160 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 161 | else: 162 | # Set prenorm=False here since we don't need the residual 163 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 164 | hidden_states = fused_add_norm_fn( 165 | hidden_states, 166 | self.norm_f.weight, 167 | self.norm_f.bias, 168 | eps=self.norm_f.eps, 169 | residual=residual, 170 | prenorm=False, 171 | residual_in_fp32=self.residual_in_fp32, 172 | ) 173 | return hidden_states 174 | 175 | 176 | class MambaLMHeadModel(nn.Module, GenerationMixin): 177 | 178 | def __init__( 179 | self, 180 | config: MambaConfig, 181 | initializer_cfg=None, 182 | device=None, 183 | dtype=None, 184 | ) -> None: 185 | self.config = config 186 | d_model = config.d_model 187 | n_layer = config.n_layer 188 | vocab_size = config.vocab_size 189 | ssm_cfg = config.ssm_cfg 190 | rms_norm = config.rms_norm 191 | residual_in_fp32 = config.residual_in_fp32 192 | fused_add_norm = config.fused_add_norm 193 | pad_vocab_size_multiple = config.pad_vocab_size_multiple 194 | factory_kwargs = {"device": device, "dtype": dtype} 195 | 196 | super().__init__() 197 | if vocab_size % pad_vocab_size_multiple != 0: 198 | vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) 199 | self.backbone = MixerModel( 200 | d_model=d_model, 201 | n_layer=n_layer, 202 | vocab_size=vocab_size, 203 | ssm_cfg=ssm_cfg, 204 | rms_norm=rms_norm, 205 | initializer_cfg=initializer_cfg, 206 | fused_add_norm=fused_add_norm, 207 | residual_in_fp32=residual_in_fp32, 208 | **factory_kwargs, 209 | ) 210 | self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) 211 | 212 | # Initialize weights and apply final processing 213 | self.apply( 214 | partial( 215 | _init_weights, 216 | n_layer=n_layer, 217 | **(initializer_cfg if initializer_cfg is not None else {}), 218 | ) 219 | ) 220 | self.tie_weights() 221 | 222 | def tie_weights(self): 223 | if self.config.tie_embeddings: 224 | self.lm_head.weight = 
self.backbone.embedding.weight 225 | 226 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 227 | return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 228 | 229 | def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): 230 | """ 231 | "position_ids" is just to be compatible with Transformer generation. We don't use it. 232 | num_last_tokens: if > 0, only return the logits for the last n tokens 233 | """ 234 | hidden_states = self.backbone(input_ids, inference_params=inference_params) 235 | if num_last_tokens > 0: 236 | hidden_states = hidden_states[:, -num_last_tokens:] 237 | lm_logits = self.lm_head(hidden_states) 238 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) 239 | return CausalLMOutput(logits=lm_logits) 240 | 241 | @classmethod 242 | def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): 243 | config_data = load_config_hf(pretrained_model_name) 244 | config = MambaConfig(**config_data) 245 | model = cls(config, device=device, dtype=dtype, **kwargs) 246 | model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) 247 | return model 248 | 249 | def save_pretrained(self, save_directory): 250 | """ 251 | Minimal implementation of save_pretrained for MambaLMHeadModel. 252 | Save the model and its configuration file to a directory. 253 | """ 254 | # Ensure save_directory exists 255 | os.makedirs(save_directory, exist_ok=True) 256 | 257 | # Save the model's state_dict 258 | model_path = os.path.join(save_directory, 'pytorch_model.bin') 259 | torch.save(self.state_dict(), model_path) 260 | 261 | # Save the configuration of the model 262 | config_path = os.path.join(save_directory, 'config.json') 263 | with open(config_path, 'w') as f: 264 | json.dump(self.config.__dict__, f) 265 | -------------------------------------------------------------------------------- /mamba/mamba-main/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 
2 | import sys 3 | import warnings 4 | import os 5 | import re 6 | import ast 7 | from pathlib import Path 8 | from packaging.version import parse, Version 9 | import platform 10 | import shutil 11 | 12 | from setuptools import setup, find_packages 13 | import subprocess 14 | 15 | import urllib.request 16 | import urllib.error 17 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 18 | 19 | import torch 20 | from torch.utils.cpp_extension import ( 21 | BuildExtension, 22 | CppExtension, 23 | CUDAExtension, 24 | CUDA_HOME, 25 | ) 26 | 27 | 28 | with open("README.md", "r", encoding="utf-8") as fh: 29 | long_description = fh.read() 30 | 31 | 32 | # ninja build does not work unless include_dirs are abs path 33 | this_dir = os.path.dirname(os.path.abspath(__file__)) 34 | 35 | PACKAGE_NAME = "mamba_ssm" 36 | 37 | BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" 38 | 39 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 40 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 41 | FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE" 42 | SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 43 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 44 | FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" 45 | 46 | 47 | def get_platform(): 48 | """ 49 | Returns the platform name as used in wheel filenames. 50 | """ 51 | if sys.platform.startswith("linux"): 52 | return "linux_x86_64" 53 | elif sys.platform == "darwin": 54 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 55 | return f"macosx_{mac_version}_x86_64" 56 | elif sys.platform == "win32": 57 | return "win_amd64" 58 | else: 59 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 60 | 61 | 62 | def get_cuda_bare_metal_version(cuda_dir): 63 | raw_output = subprocess.check_output( 64 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 65 | ) 66 | output = raw_output.split() 67 | release_idx = output.index("release") + 1 68 | bare_metal_version = parse(output[release_idx].split(",")[0]) 69 | 70 | return raw_output, bare_metal_version 71 | 72 | 73 | def check_if_cuda_home_none(global_option: str) -> None: 74 | if CUDA_HOME is not None: 75 | return 76 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 77 | # in that case. 78 | warnings.warn( 79 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 80 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 81 | "only images whose names contain 'devel' will provide nvcc." 
82 | ) 83 | 84 | 85 | def append_nvcc_threads(nvcc_extra_args): 86 | return nvcc_extra_args + ["--threads", "4"] 87 | 88 | 89 | cmdclass = {} 90 | ext_modules = [] 91 | 92 | if not SKIP_CUDA_BUILD: 93 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 94 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 95 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 96 | 97 | check_if_cuda_home_none(PACKAGE_NAME) 98 | # Check, if CUDA11 is installed for compute capability 8.0 99 | cc_flag = [] 100 | if CUDA_HOME is not None: 101 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 102 | if bare_metal_version < Version("11.6"): 103 | raise RuntimeError( 104 | f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " 105 | "Note: make sure nvcc has a supported version by running nvcc -V." 106 | ) 107 | 108 | cc_flag.append("-gencode") 109 | cc_flag.append("arch=compute_53,code=sm_53") 110 | cc_flag.append("-gencode") 111 | cc_flag.append("arch=compute_62,code=sm_62") 112 | cc_flag.append("-gencode") 113 | cc_flag.append("arch=compute_70,code=sm_70") 114 | cc_flag.append("-gencode") 115 | cc_flag.append("arch=compute_72,code=sm_72") 116 | cc_flag.append("-gencode") 117 | cc_flag.append("arch=compute_80,code=sm_80") 118 | cc_flag.append("-gencode") 119 | cc_flag.append("arch=compute_87,code=sm_87") 120 | if bare_metal_version >= Version("11.8"): 121 | cc_flag.append("-gencode") 122 | cc_flag.append("arch=compute_90,code=sm_90") 123 | 124 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 125 | # torch._C._GLIBCXX_USE_CXX11_ABI 126 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 127 | if FORCE_CXX11_ABI: 128 | torch._C._GLIBCXX_USE_CXX11_ABI = True 129 | 130 | ext_modules.append( 131 | CUDAExtension( 132 | name="selective_scan_cuda", 133 | sources=[ 134 | "csrc/selective_scan/selective_scan.cpp", 135 | "csrc/selective_scan/selective_scan_fwd_fp32.cu", 136 | "csrc/selective_scan/selective_scan_fwd_fp16.cu", 137 | "csrc/selective_scan/selective_scan_fwd_bf16.cu", 138 | "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", 139 | "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", 140 | "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", 141 | "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", 142 | "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", 143 | "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", 144 | ], 145 | extra_compile_args={ 146 | "cxx": ["-O3", "-std=c++17"], 147 | "nvcc": append_nvcc_threads( 148 | [ 149 | "-O3", 150 | "-std=c++17", 151 | "-U__CUDA_NO_HALF_OPERATORS__", 152 | "-U__CUDA_NO_HALF_CONVERSIONS__", 153 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 154 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 155 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 156 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 157 | "--expt-relaxed-constexpr", 158 | "--expt-extended-lambda", 159 | "--use_fast_math", 160 | "--ptxas-options=-v", 161 | "-lineinfo", 162 | ] 163 | + cc_flag 164 | ), 165 | }, 166 | include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], 167 | ) 168 | ) 169 | 170 | 171 | def get_package_version(): 172 | with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: 173 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 174 | public_version = ast.literal_eval(version_match.group(1)) 175 | local_version = os.environ.get("MAMBA_LOCAL_VERSION") 176 | if local_version: 177 | return f"{public_version}+{local_version}" 
178 | else: 179 | return str(public_version) 180 | 181 | 182 | def get_wheel_url(): 183 | # Determine the version numbers that will be used to determine the correct wheel 184 | # We're using the CUDA version used to build torch, not the one currently installed 185 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 186 | torch_cuda_version = parse(torch.version.cuda) 187 | torch_version_raw = parse(torch.__version__) 188 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 189 | # to save CI time. Minor versions should be compatible. 190 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 191 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 192 | platform_name = get_platform() 193 | mamba_ssm_version = get_package_version() 194 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 195 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 196 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 197 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 198 | 199 | # Determine wheel URL based on CUDA version, torch version, python version and OS 200 | wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 201 | wheel_url = BASE_WHEEL_URL.format( 202 | tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename 203 | ) 204 | return wheel_url, wheel_filename 205 | 206 | 207 | class CachedWheelsCommand(_bdist_wheel): 208 | """ 209 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 210 | find an existing wheel (which is currently the case for all installs). We use 211 | the environment parameters to detect whether there is already a pre-built version of a compatible 212 | wheel available and short-circuits the standard full build pipeline. 213 | """ 214 | 215 | def run(self): 216 | if FORCE_BUILD: 217 | return super().run() 218 | 219 | wheel_url, wheel_filename = get_wheel_url() 220 | print("Guessing wheel URL: ", wheel_url) 221 | try: 222 | urllib.request.urlretrieve(wheel_url, wheel_filename) 223 | 224 | # Make the archive 225 | # Lifted from the root wheel processing command 226 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 227 | if not os.path.exists(self.dist_dir): 228 | os.makedirs(self.dist_dir) 229 | 230 | impl_tag, abi_tag, plat_tag = self.get_tag() 231 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 232 | 233 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 234 | print("Raw wheel path", wheel_path) 235 | shutil.move(wheel_filename, wheel_path) 236 | except urllib.error.HTTPError: 237 | print("Precompiled wheel not found. 
Building from source...") 238 | # If the wheel could not be downloaded, build from source 239 | super().run() 240 | 241 | 242 | setup( 243 | name=PACKAGE_NAME, 244 | version=get_package_version(), 245 | packages=find_packages( 246 | exclude=( 247 | "build", 248 | "csrc", 249 | "include", 250 | "tests", 251 | "dist", 252 | "docs", 253 | "benchmarks", 254 | "mamba_ssm.egg-info", 255 | ) 256 | ), 257 | author="Tri Dao, Albert Gu", 258 | author_email="tri@tridao.me, agu@cs.cmu.edu", 259 | description="Mamba state-space model", 260 | long_description=long_description, 261 | long_description_content_type="text/markdown", 262 | url="https://github.com/state-spaces/mamba", 263 | classifiers=[ 264 | "Programming Language :: Python :: 3", 265 | "License :: OSI Approved :: BSD License", 266 | "Operating System :: Unix", 267 | ], 268 | ext_modules=ext_modules, 269 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 270 | if ext_modules 271 | else { 272 | "bdist_wheel": CachedWheelsCommand, 273 | }, 274 | python_requires=">=3.7", 275 | install_requires=[ 276 | "torch", 277 | "packaging", 278 | "ninja", 279 | "einops", 280 | "triton", 281 | "transformers", 282 | # "causal_conv1d>=1.2.0", 283 | ], 284 | ) 285 | -------------------------------------------------------------------------------- /mamba/mamba-main/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Tri Dao, Albert Gu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/tests/test_causal_conv1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 
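# NOTE: the tests in this file check the fused CUDA kernels against the pure-PyTorch
# references (causal_conv1d_ref / causal_conv1d_update_ref) imported below, first on the
# forward output and then on the gradients of x, weight and bias.
# As a rough sketch of the semantics under test -- assuming x of shape (batch, dim, seqlen)
# and a depthwise weight of shape (dim, width) -- the reference amounts to a left-padded
# grouped convolution with an optional SiLU (the name `causal_conv1d_sketch` is made up
# here purely for illustration; the real reference may differ in dtype handling):
#
#     import torch.nn.functional as F
#     def causal_conv1d_sketch(x, weight, bias=None, activation=None):
#         seqlen = x.shape[-1]
#         out = F.conv1d(x, weight.unsqueeze(1), bias,
#                        padding=weight.shape[-1] - 1, groups=x.shape[1])[..., :seqlen]
#         return out if activation is None else F.silu(out)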
2 | 3 | import math 4 | 5 | import torch 6 | import pytest 7 | 8 | from einops import rearrange, repeat 9 | 10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref 11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("channel_last", [False, True]) 15 | # @pytest.mark.parametrize('channel_last', [True]) 16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 17 | # @pytest.mark.parametrize('itype', [torch.float16]) 18 | @pytest.mark.parametrize("silu_activation", [False, True]) 19 | # @pytest.mark.parametrize('silu_activation', [True]) 20 | @pytest.mark.parametrize("has_bias", [False, True]) 21 | # @pytest.mark.parametrize('has_bias', [True]) 22 | @pytest.mark.parametrize("width", [2, 3, 4]) 23 | # @pytest.mark.parametrize('width', [2]) 24 | @pytest.mark.parametrize( 25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 26 | ) 27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 28 | # @pytest.mark.parametrize('seqlen', [128]) 29 | @pytest.mark.parametrize('dim', [64, 4096 + 32]) 30 | def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last): 31 | device = "cuda" 32 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 33 | if itype == torch.bfloat16: 34 | rtol, atol = 1e-2, 5e-2 35 | rtolw, atolw = (1e-3, 1e-3) 36 | # set seed 37 | torch.random.manual_seed(0) 38 | batch = 2 39 | # batch = 1 40 | if not channel_last: 41 | x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 42 | else: 43 | x = rearrange( 44 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 45 | ).requires_grad_() 46 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 47 | if has_bias: 48 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 49 | else: 50 | bias = None 51 | x_ref = x.detach().clone().requires_grad_() 52 | weight_ref = weight.detach().clone().requires_grad_() 53 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 54 | activation = None if not silu_activation else "silu" 55 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 56 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) 57 | 58 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 59 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 60 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 61 | 62 | g = torch.randn_like(out) 63 | out_ref.backward(g) 64 | out.backward(g) 65 | 66 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 67 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 68 | if has_bias: 69 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 70 | 71 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 72 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 73 | if has_bias: 74 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 75 | 76 | 77 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 78 | # @pytest.mark.parametrize('itype', [torch.float16]) 79 | 
@pytest.mark.parametrize("silu_activation", [False, True]) 80 | # @pytest.mark.parametrize('silu_activation', [False]) 81 | @pytest.mark.parametrize("has_bias", [False, True]) 82 | # @pytest.mark.parametrize('has_bias', [True]) 83 | @pytest.mark.parametrize("width", [2, 3, 4]) 84 | # @pytest.mark.parametrize('width', [2]) 85 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 86 | # @pytest.mark.parametrize("dim", [2048]) 87 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): 88 | device = "cuda" 89 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 90 | if itype == torch.bfloat16: 91 | rtol, atol = 1e-2, 5e-2 92 | rtolw, atolw = (1e-3, 1e-3) 93 | # set seed 94 | torch.random.manual_seed(0) 95 | batch = 2 96 | # batch = 1 97 | # dim = 64 98 | x = torch.randn(batch, dim, device=device, dtype=itype) 99 | conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) 100 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 101 | if has_bias: 102 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 103 | else: 104 | bias = None 105 | conv_state_ref = conv_state.detach().clone() 106 | activation = None if not silu_activation else "silu" 107 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) 108 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) 109 | 110 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 111 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 112 | assert torch.equal(conv_state, conv_state_ref) 113 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 114 | 115 | 116 | # @pytest.mark.parametrize("channel_last", [False, True]) 117 | @pytest.mark.parametrize('channel_last', [True]) 118 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 119 | @pytest.mark.parametrize('itype', [torch.bfloat16]) 120 | # @pytest.mark.parametrize("silu_activation", [False, True]) 121 | @pytest.mark.parametrize('silu_activation', [True]) 122 | # @pytest.mark.parametrize("has_bias", [False, True]) 123 | @pytest.mark.parametrize('has_bias', [True]) 124 | # @pytest.mark.parametrize("width", [2, 3, 4]) 125 | @pytest.mark.parametrize('width', [4]) 126 | @pytest.mark.parametrize( 127 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 128 | "seqlen", [2048] 129 | ) 130 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 131 | # @pytest.mark.parametrize('seqlen', [128]) 132 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): 133 | device = "cuda" 134 | # set seed 135 | torch.random.manual_seed(0) 136 | batch = 2 137 | # batch = 1 138 | dim = 4096 + 32 # Try dim not divisible by 64 139 | # dim = 64 140 | if not channel_last: 141 | x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 142 | else: 143 | x = rearrange( 144 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 145 | ).requires_grad_() 146 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 147 | if has_bias: 148 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 149 | else: 150 | bias = None 151 | activation = None if not silu_activation else "silu" 152 | out0 = 
causal_conv1d_fn(x, weight, bias, activation=activation) 153 | g = torch.randn_like(out0) 154 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g) 155 | dw_atol = 1e-4 156 | db_atol = 1e-4 157 | 158 | for i in range(10000): 159 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 160 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g) 161 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol) 162 | # if not dw_equal: 163 | # breakpoint() 164 | if has_bias: 165 | db_equal = torch.allclose(db, db0, atol=db_atol) 166 | # if not db_equal: 167 | # breakpoint() 168 | assert torch.equal(out, out0) 169 | assert torch.equal(dx, dx0) 170 | assert dw_equal 171 | if has_bias: 172 | assert db_equal 173 | 174 | 175 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 176 | # @pytest.mark.parametrize('itype', [torch.float16]) 177 | @pytest.mark.parametrize("silu_activation", [False, True]) 178 | # @pytest.mark.parametrize('silu_activation', [False]) 179 | @pytest.mark.parametrize("has_bias", [False, True]) 180 | # @pytest.mark.parametrize('has_bias', [False]) 181 | @pytest.mark.parametrize("width", [2, 3, 4]) 182 | # @pytest.mark.parametrize('width', [2]) 183 | @pytest.mark.parametrize( 184 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 185 | ) 186 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 187 | # @pytest.mark.parametrize('seqlen', [2048]) 188 | @pytest.mark.parametrize('dim', [64, 4096 + 32]) 189 | # @pytest.mark.parametrize('dim', [64]) 190 | def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype): 191 | device = "cuda" 192 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 193 | if itype == torch.bfloat16: 194 | rtol, atol = 1e-2, 5e-2 195 | rtolw, atolw = (1e-3, 1e-3) 196 | # set seed 197 | torch.random.manual_seed(seqlen + dim + width) 198 | batch = 3 199 | seqlens = [] 200 | for b in range(batch): 201 | nsplits = torch.randint(1, 5, (1,)).item() 202 | eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values 203 | seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) 204 | assert sum(seqlens[-1]) == seqlen 205 | assert all(s > 0 for s in seqlens[-1]) 206 | # Only support channel_last 207 | x = rearrange( 208 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 209 | ).requires_grad_() 210 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 211 | if has_bias: 212 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 213 | else: 214 | bias = None 215 | seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0) 216 | for sl in seqlens], dim=0) 217 | x_ref = x.detach().clone().requires_grad_() 218 | weight_ref = weight.detach().clone().requires_grad_() 219 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 220 | activation = None if not silu_activation else "silu" 221 | out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation) 222 | out_ref = [] 223 | for b in range(batch): 224 | out_ref_b = [] 225 | for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2): 226 | out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation)) 227 | out_ref.append(torch.cat(out_ref_b, dim=2)) 228 | out_ref =
torch.cat(out_ref, dim=0) 229 | 230 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 231 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 232 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 233 | 234 | g = torch.randn_like(out) 235 | out_ref.backward(g) 236 | out.backward(g) 237 | 238 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 239 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 240 | if has_bias: 241 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 242 | 243 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 244 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 245 | if has_bias: 246 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 247 | -------------------------------------------------------------------------------- /mamba/mamba-main/tests/ops/test_selective_scan.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref 12 | from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref 13 | 14 | 15 | # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) 16 | @pytest.mark.parametrize('wtype', [torch.float32]) 17 | # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) 18 | @pytest.mark.parametrize('itype', [torch.float32]) 19 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) 20 | @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) 21 | # @pytest.mark.parametrize('seqlen', [128]) 22 | # @pytest.mark.parametrize("return_last_state", [False, True]) 23 | @pytest.mark.parametrize("return_last_state", [True]) 24 | # @pytest.mark.parametrize('has_delta_bias', [False, True]) 25 | @pytest.mark.parametrize('has_delta_bias', [True]) 26 | # @pytest.mark.parametrize('delta_softplus', [False, True]) 27 | @pytest.mark.parametrize('delta_softplus', [True]) 28 | # @pytest.mark.parametrize('has_z', [False, True]) 29 | @pytest.mark.parametrize('has_z', [True]) 30 | # @pytest.mark.parametrize('has_D', [False, True]) 31 | @pytest.mark.parametrize('has_D', [True]) 32 | @pytest.mark.parametrize("varBC_groups", [1, 2]) 33 | # @pytest.mark.parametrize("varBC_groups", [1]) 34 | # @pytest.mark.parametrize("is_variable_C", [False, True]) 35 | @pytest.mark.parametrize("is_variable_C", [True]) 36 | # @pytest.mark.parametrize("is_variable_B", [False, True]) 37 | @pytest.mark.parametrize("is_variable_B", [True]) 38 | def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, 39 | delta_softplus, return_last_state, seqlen, itype, wtype): 40 | if varBC_groups > 1 and (not is_variable_B or not is_variable_C): 41 | pytest.skip() # This config is not applicable 42 | device = 'cuda' 43 | rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) 44 | if itype == torch.bfloat16: 45 | rtol, atol = 3e-2, 5e-2 46 | rtolw, atolw = (1e-3, 1e-3) 47 | if has_z: # If we have z, the errors on the weights seem higher 48 | rtolw = max(rtolw, rtol) 49 | atolw = max(atolw, atol) 50 | # set seed 51 | torch.random.manual_seed(0) 52 | batch_size = 2 53 | 
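# Rough sketch of the recurrence that selective_scan_fn / selective_scan_ref implement,
# per batch element and channel d (delta may first pass through softplus and delta_bias;
# weights are real here, the complex wtype case is analogous):
#   h_t = exp(delta_t * A_d) * h_{t-1} + (delta_t * B_t) * u_t    # hidden state of size dstate
#   y_t = <C_t, h_t> + D_d * u_t                                  # scalar output per channel
# and, when z is provided, the output is additionally gated by silu(z).
# The tensors built below exercise both shared (dim, dstate) and input-dependent
# (batch, [groups,] dstate, seqlen) layouts for B and C.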
dim = 4 54 | dstate = 8 55 | is_complex = wtype == torch.complex64 56 | A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() 57 | if not is_variable_B: 58 | B_shape = (dim, dstate) 59 | elif varBC_groups == 1: 60 | B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) 61 | else: 62 | B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) 63 | B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, 64 | requires_grad=True) 65 | if not is_variable_C: 66 | C_shape = (dim, dstate) 67 | elif varBC_groups == 1: 68 | C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) 69 | else: 70 | C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) 71 | C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, 72 | requires_grad=True) 73 | if has_D: 74 | D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 75 | else: 76 | D = None 77 | if has_z: 78 | z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) 79 | else: 80 | z = None 81 | if has_delta_bias: 82 | delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() 83 | else: 84 | delta_bias = None 85 | u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) 86 | delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() 87 | A_ref = A.detach().clone().requires_grad_() 88 | B_ref = B.detach().clone().requires_grad_() 89 | C_ref = C.detach().clone().requires_grad_() 90 | D_ref = D.detach().clone().requires_grad_() if D is not None else None 91 | z_ref = z.detach().clone().requires_grad_() if z is not None else None 92 | u_ref = u.detach().clone().requires_grad_() 93 | delta_ref = delta.detach().clone().requires_grad_() 94 | delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None 95 | out, *rest = selective_scan_fn( 96 | u, delta, A, B, C, D, z=z, 97 | delta_bias=delta_bias, delta_softplus=delta_softplus, 98 | return_last_state=return_last_state 99 | ) 100 | if return_last_state: 101 | state = rest[0] 102 | out_ref, *rest = selective_scan_ref( 103 | u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, 104 | delta_bias=delta_bias_ref, delta_softplus=delta_softplus, 105 | return_last_state=return_last_state 106 | ) 107 | if return_last_state: 108 | state_ref = rest[0] 109 | # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 110 | # dt_u = delta * u 111 | 112 | print(f'Output max diff: {(out - out_ref).abs().max().item()}') 113 | print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') 114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 115 | if return_last_state: 116 | print(f'State max diff: {(state - state_ref).abs().max().item()}') 117 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 118 | 119 | g = torch.randn_like(out) 120 | out_ref.backward(g) 121 | out.backward(g) 122 | 123 | print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') 124 | print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') 125 | print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') 126 | print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') 127 | print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') 128 | if has_D: 129 | print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') 130 
| if has_z: 131 | print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') 132 | if has_delta_bias: 133 | print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') 134 | 135 | assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) 136 | assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) 137 | assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) 138 | assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, 139 | atol=atolw if not is_variable_B else atol) 140 | assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, 141 | atol=atolw if not is_variable_C else atol) 142 | if has_D: 143 | assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) 144 | if has_z: 145 | assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) 146 | if has_delta_bias: 147 | assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) 148 | 149 | 150 | @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) 151 | # @pytest.mark.parametrize('wtype', [torch.complex64]) 152 | # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) 153 | @pytest.mark.parametrize('itype', [torch.float32]) 154 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) 155 | @pytest.mark.parametrize('seqlen', [128]) 156 | @pytest.mark.parametrize("is_variable_C", [False, True]) 157 | # @pytest.mark.parametrize("is_variable_C", [False]) 158 | @pytest.mark.parametrize("is_variable_B", [False, True]) 159 | # @pytest.mark.parametrize("is_variable_B", [True]) 160 | def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): 161 | device = 'cuda' 162 | rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) 163 | if itype == torch.bfloat16: 164 | rtol, atol = 3e-2, 5e-2 165 | rtolw, atolw = (1e-3, 1e-3) 166 | # If we have z, the errors on the weights seem higher 167 | rtolw = max(rtolw, rtol) 168 | atolw = max(atolw, atol) 169 | # set seed 170 | torch.random.manual_seed(0) 171 | batch_size = 2 172 | dim = 768 173 | dstate = 8 174 | dt_rank = 48 175 | is_complex = wtype == torch.complex64 176 | xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) 177 | conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) 178 | conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 179 | x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate 180 | * (1 if not is_complex else 2), 181 | dim, device=device, dtype=itype, requires_grad=True) 182 | delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) 183 | out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) 184 | out_proj_bias = None 185 | A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() 186 | B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) 187 | if not is_variable_B else None) 188 | C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) 189 | if not is_variable_C else None) 190 | D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 191 | delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() 192 | 
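# Note on what is being compared in this test: mamba_inner_fn fuses, roughly, everything a
# Mamba block does after in_proj -- split xz into (x, z), causal depthwise conv + SiLU on x,
# x_proj to produce (dt, B, C), dt_proj, the selective scan, gating by silu(z), and finally
# out_proj -- while mamba_inner_ref below spells the same pipeline out as separate PyTorch
# ops. The weights created above mirror the parameters of the Mamba module.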
B_proj_bias = None 193 | C_proj_bias = None 194 | xz_ref = xz.detach().clone().requires_grad_() 195 | conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() 196 | conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() 197 | x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() 198 | delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() 199 | out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() 200 | out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() 201 | if out_proj_bias is not None else None) 202 | A_ref = A.detach().clone().requires_grad_() 203 | B_ref = B.detach().clone().requires_grad_() if B is not None else None 204 | C_ref = C.detach().clone().requires_grad_() if C is not None else None 205 | D_ref = D.detach().clone().requires_grad_() 206 | delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None 207 | out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, 208 | out_proj_weight, out_proj_bias, 209 | A, B, C, D, delta_bias=delta_bias, delta_softplus=True) 210 | out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, 211 | delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, 212 | A_ref, B_ref, C_ref, D_ref, 213 | delta_bias=delta_bias_ref, delta_softplus=True) 214 | # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 215 | # dt_u = delta * u 216 | 217 | print(f'Output max diff: {(out - out_ref).abs().max().item()}') 218 | print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') 219 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 220 | 221 | g = torch.randn_like(out) 222 | out_ref.backward(g) 223 | out.backward(g) 224 | 225 | print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') 226 | print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') 227 | if not is_variable_B: 228 | print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') 229 | if not is_variable_C: 230 | print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') 231 | print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') 232 | print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') 233 | print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}') 234 | print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') 235 | print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') 236 | print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') 237 | print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') 238 | 239 | # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) 240 | # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) 241 | # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) 242 | # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, 243 | # atol=atolw if not is_variable_B else atol) 244 | # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, 245 | # atol=atolw if not is_variable_C else atol) 246 | # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) 247 | # assert 
torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) 248 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/modules/mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import math 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from einops import rearrange, repeat 12 | 13 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 14 | 15 | try: 16 | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 17 | except ImportError: 18 | causal_conv1d_fn, causal_conv1d_update = None, None 19 | 20 | try: 21 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update 22 | except ImportError: 23 | selective_state_update = None 24 | 25 | try: 26 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 27 | except ImportError: 28 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 29 | 30 | 31 | class Mamba(nn.Module): 32 | def __init__( 33 | self, 34 | d_model, 35 | d_state=16, 36 | d_conv=4, 37 | expand=2, 38 | dt_rank="auto", 39 | dt_min=0.001, 40 | dt_max=0.1, 41 | dt_init="random", 42 | dt_scale=1.0, 43 | dt_init_floor=1e-4, 44 | conv_bias=True, 45 | bias=False, 46 | use_fast_path=True, # Fused kernel options 47 | layer_idx=None, 48 | device=None, 49 | dtype=None, 50 | ): 51 | factory_kwargs = {"device": device, "dtype": dtype} 52 | super().__init__() 53 | self.d_model = d_model 54 | self.d_state = d_state 55 | self.d_conv = d_conv 56 | self.expand = expand 57 | self.d_inner = int(self.expand * self.d_model) 58 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 59 | self.use_fast_path = use_fast_path 60 | self.layer_idx = layer_idx 61 | 62 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 63 | 64 | self.conv1d = nn.Conv1d( 65 | in_channels=self.d_inner, 66 | out_channels=self.d_inner, 67 | bias=conv_bias, 68 | kernel_size=d_conv, 69 | groups=self.d_inner, 70 | padding=d_conv - 1, 71 | **factory_kwargs, 72 | ) 73 | 74 | self.activation = "silu" 75 | self.act = nn.SiLU() 76 | 77 | self.x_proj = nn.Linear( 78 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 79 | ) 80 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 81 | 82 | # Initialize special dt projection to preserve variance at initialization 83 | dt_init_std = self.dt_rank**-0.5 * dt_scale 84 | if dt_init == "constant": 85 | nn.init.constant_(self.dt_proj.weight, dt_init_std) 86 | elif dt_init == "random": 87 | nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) 88 | else: 89 | raise NotImplementedError 90 | 91 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 92 | dt = torch.exp( 93 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 94 | + math.log(dt_min) 95 | ).clamp(min=dt_init_floor) 96 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 97 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 98 | with torch.no_grad(): 99 | self.dt_proj.bias.copy_(inv_dt) 100 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 101 | self.dt_proj.bias._no_reinit = True 102 | 103 | # S4D real initialization 104 | A = repeat( 
105 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 106 | "n -> d n", 107 | d=self.d_inner, 108 | ).contiguous() 109 | A_log = torch.log(A) # Keep A_log in fp32 110 | self.A_log = nn.Parameter(A_log) 111 | self.A_log._no_weight_decay = True 112 | 113 | # D "skip" parameter 114 | self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 115 | self.D._no_weight_decay = True 116 | 117 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 118 | 119 | def forward(self, hidden_states, inference_params=None): 120 | """ 121 | hidden_states: (B, L, D) 122 | Returns: same shape as hidden_states 123 | """ 124 | batch, seqlen, dim = hidden_states.shape 125 | 126 | conv_state, ssm_state = None, None 127 | if inference_params is not None: 128 | conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) 129 | if inference_params.seqlen_offset > 0: 130 | # The states are updated inplace 131 | out, _, _ = self.step(hidden_states, conv_state, ssm_state) 132 | return out 133 | 134 | # We do matmul and transpose BLH -> HBL at the same time 135 | xz = rearrange( 136 | self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 137 | "d (b l) -> b d l", 138 | l=seqlen, 139 | ) 140 | if self.in_proj.bias is not None: 141 | xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") 142 | 143 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 144 | # In the backward pass we write dx and dz next to each other to avoid torch.cat 145 | if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states 146 | out = mamba_inner_fn( 147 | xz, 148 | self.conv1d.weight, 149 | self.conv1d.bias, 150 | self.x_proj.weight, 151 | self.dt_proj.weight, 152 | self.out_proj.weight, 153 | self.out_proj.bias, 154 | A, 155 | None, # input-dependent B 156 | None, # input-dependent C 157 | self.D.float(), 158 | delta_bias=self.dt_proj.bias.float(), 159 | delta_softplus=True, 160 | ) 161 | else: 162 | x, z = xz.chunk(2, dim=1) 163 | # Compute short convolution 164 | if conv_state is not None: 165 | # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv 166 | # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 167 | conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) 168 | if causal_conv1d_fn is None: 169 | x = self.act(self.conv1d(x)[..., :seqlen]) 170 | else: 171 | assert self.activation in ["silu", "swish"] 172 | x = causal_conv1d_fn( 173 | x=x, 174 | weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), 175 | bias=self.conv1d.bias, 176 | activation=self.activation, 177 | ) 178 | 179 | # We're careful here about the layout, to avoid extra transposes. 180 | # We want dt to have d as the slowest moving dimension 181 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
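# Shape summary for the slow path below: x_dbl is (b*l, dt_rank + 2*d_state) and is split
# into dt, B, C; dt is projected to d_inner and reshaped to (b, d_inner, l); B and C become
# (b, d_state, l) and are made contiguous, which is the layout the scan kernel expects.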
182 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 183 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 184 | dt = self.dt_proj.weight @ dt.t() 185 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 186 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 187 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 188 | assert self.activation in ["silu", "swish"] 189 | y = selective_scan_fn( 190 | x, 191 | dt, 192 | A, 193 | B, 194 | C, 195 | self.D.float(), 196 | z=z, 197 | delta_bias=self.dt_proj.bias.float(), 198 | delta_softplus=True, 199 | return_last_state=ssm_state is not None, 200 | ) 201 | if ssm_state is not None: 202 | y, last_state = y 203 | ssm_state.copy_(last_state) 204 | y = rearrange(y, "b d l -> b l d") 205 | out = self.out_proj(y) 206 | return out 207 | 208 | def step(self, hidden_states, conv_state, ssm_state): 209 | dtype = hidden_states.dtype 210 | assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" 211 | xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) 212 | x, z = xz.chunk(2, dim=-1) # (B D) 213 | 214 | # Conv step 215 | if causal_conv1d_update is None: 216 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 217 | conv_state[:, :, -1] = x 218 | x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) 219 | if self.conv1d.bias is not None: 220 | x = x + self.conv1d.bias 221 | x = self.act(x).to(dtype=dtype) 222 | else: 223 | x = causal_conv1d_update( 224 | x, 225 | conv_state, 226 | rearrange(self.conv1d.weight, "d 1 w -> d w"), 227 | self.conv1d.bias, 228 | self.activation, 229 | ) 230 | 231 | x_db = self.x_proj(x) # (B dt_rank+2*d_state) 232 | dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) 233 | # Don't add dt_bias here 234 | dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) 235 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 236 | 237 | # SSM step 238 | if selective_state_update is None: 239 | # Discretize A and B 240 | dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) 241 | dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) 242 | dB = torch.einsum("bd,bn->bdn", dt, B) 243 | ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) 244 | y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) 245 | y = y + self.D.to(dtype) * x 246 | y = y * self.act(z) # (B D) 247 | else: 248 | y = selective_state_update( 249 | ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True 250 | ) 251 | 252 | out = self.out_proj(y) 253 | return out.unsqueeze(1), conv_state, ssm_state 254 | 255 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 256 | device = self.out_proj.weight.device 257 | conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype 258 | conv_state = torch.zeros( 259 | batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype 260 | ) 261 | ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype 262 | # ssm_dtype = torch.float32 263 | ssm_state = torch.zeros( 264 | batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype 265 | ) 266 | return conv_state, ssm_state 267 | 268 | def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): 269 | assert self.layer_idx is not None 270 | if self.layer_idx not in inference_params.key_value_memory_dict: 271 | 
batch_shape = (batch_size,) 272 | conv_state = torch.zeros( 273 | batch_size, 274 | self.d_model * self.expand, 275 | self.d_conv, 276 | device=self.conv1d.weight.device, 277 | dtype=self.conv1d.weight.dtype, 278 | ) 279 | ssm_state = torch.zeros( 280 | batch_size, 281 | self.d_model * self.expand, 282 | self.d_state, 283 | device=self.dt_proj.weight.device, 284 | dtype=self.dt_proj.weight.dtype, 285 | # dtype=torch.float32, 286 | ) 287 | inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) 288 | else: 289 | conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] 290 | # TODO: What if batch size changes between generation, and we reuse the same states? 291 | if initialize_states: 292 | conv_state.zero_() 293 | ssm_state.zero_() 294 | return conv_state, ssm_state 295 | 296 | 297 | class Block(nn.Module): 298 | def __init__( 299 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False 300 | ): 301 | """ 302 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 303 | 304 | This Block has a slightly different structure compared to a regular 305 | prenorm Transformer block. 306 | The standard block is: LN -> MHA/MLP -> Add. 307 | [Ref: https://arxiv.org/abs/2002.04745] 308 | Here we have: Add -> LN -> Mixer, returning both 309 | the hidden_states (output of the mixer) and the residual. 310 | This is purely for performance reasons, as we can fuse add and LayerNorm. 311 | The residual needs to be provided (except for the very first block). 312 | """ 313 | super().__init__() 314 | self.residual_in_fp32 = residual_in_fp32 315 | self.fused_add_norm = fused_add_norm 316 | self.mixer = mixer_cls(dim) 317 | self.norm = norm_cls(dim) 318 | if self.fused_add_norm: 319 | assert RMSNorm is not None, "RMSNorm import fails" 320 | assert isinstance( 321 | self.norm, (nn.LayerNorm, RMSNorm) 322 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 323 | 324 | def forward( 325 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 326 | ): 327 | r"""Pass the input through the encoder layer. 328 | 329 | Args: 330 | hidden_states: the sequence to the encoder layer (required). 331 | residual: hidden_states = Mixer(LN(residual)) 332 | """ 333 | if not self.fused_add_norm: 334 | residual = (hidden_states + residual) if residual is not None else hidden_states 335 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 336 | if self.residual_in_fp32: 337 | residual = residual.to(torch.float32) 338 | else: 339 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 340 | hidden_states, residual = fused_add_norm_fn( 341 | hidden_states, 342 | self.norm.weight, 343 | self.norm.bias, 344 | residual=residual, 345 | prenorm=True, 346 | residual_in_fp32=self.residual_in_fp32, 347 | eps=self.norm.eps, 348 | ) 349 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 350 | return hidden_states, residual 351 | 352 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 353 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 354 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/utils/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 
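# Overview: `decode` below generates autoregressively in two phases -- one forward pass over
# the full prompt, then single-token steps that reuse the fixed-size conv/ssm states cached
# in InferenceParams (no growing KV cache, unlike attention). With cg=True the single-token
# step is captured into a CUDA graph and replayed. A minimal usage sketch, assuming a
# MambaLMHeadModel and `input_ids` already on the GPU (argument values are illustrative):
#
#     out = model.generate(input_ids, max_length=128, top_k=1, cg=True,
#                          return_dict_in_generate=True, output_scores=True)
#     print(out.sequences.shape)  # (batch, <= max_length)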
2 | import gc 3 | import time 4 | from collections import namedtuple 5 | from dataclasses import dataclass, field 6 | from functools import partial 7 | from typing import Callable, Optional, Sequence, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from einops import rearrange, repeat 12 | from torch import Tensor 13 | from torch.profiler import ProfilerActivity, profile, record_function 14 | from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer 15 | 16 | 17 | @dataclass 18 | class InferenceParams: 19 | """Inference parameters that are passed to the main model in order 20 | to efficiently calculate and store the context during inference.""" 21 | 22 | max_seqlen: int 23 | max_batch_size: int 24 | seqlen_offset: int = 0 25 | batch_size_offset: int = 0 26 | key_value_memory_dict: dict = field(default_factory=dict) 27 | lengths_per_sample: Optional[Tensor] = None 28 | 29 | def reset(self, max_seqlen, max_batch_size): 30 | self.max_seqlen = max_seqlen 31 | self.max_batch_size = max_batch_size 32 | self.seqlen_offset = 0 33 | if self.lengths_per_sample is not None: 34 | self.lengths_per_sample.zero_() 35 | 36 | 37 | def modify_logits_for_min_p_filtering(logits, min_p): 38 | """Set the logits below the min_p threshold to -inf. Done in-place.""" 39 | if min_p <= 0.0 or min_p >= 1.0: 40 | return 41 | indices_to_remove = logits < min_p 42 | logits.masked_fill_(indices_to_remove, float("-Inf")) 43 | # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py 44 | # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 45 | def modify_logits_for_top_k_filtering(logits, top_k): 46 | """Set the logits for non-top-k values to -inf. Done in-place.""" 47 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 48 | logits.masked_fill_(indices_to_remove, float("-Inf")) 49 | 50 | 51 | # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py 52 | # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 53 | def modify_logits_for_top_p_filtering(logits, top_p): 54 | """Set the logits for non-top-p values to -inf. Done in-place.""" 55 | if top_p <= 0.0 or top_p >= 1.0: 56 | return 57 | # First sort and calculate cumulative sum of probabilities. 58 | sorted_logits, sorted_indices = torch.sort(logits, descending=False) 59 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 60 | # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept) 61 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p) 62 | # scatter sorted tensors to original indexing 63 | indices_to_remove = sorted_indices_to_remove.scatter( 64 | 1, sorted_indices, sorted_indices_to_remove 65 | ) 66 | logits.masked_fill_(indices_to_remove, float("-inf")) 67 | 68 | 69 | def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0): 70 | """Apply repetition penalty.
See https://arxiv.org/abs/1909.05858 71 | logits: (batch_size, vocab_size) 72 | prev_output_tokens: (batch_size, seq_len) 73 | """ 74 | if repetition_penalty == 1.0: 75 | return logits 76 | score = torch.gather(logits, 1, prev_output_tokens) 77 | # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability 78 | score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty) 79 | logits.scatter_(1, prev_output_tokens, score) 80 | return logits 81 | 82 | 83 | def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0): 84 | """Sample from top-k logits. 85 | Arguments: 86 | logits: Tensor of shape (batch_size, vocab_size) 87 | """ 88 | if top_k == 1: # Short-circuit for greedy decoding 89 | return logits.argmax(dim=-1) 90 | else: 91 | if top_p > 0.0: 92 | assert top_p <= 1.0, "top-p should be in (0, 1]." 93 | if top_k > 0: 94 | top_k = min(top_k, logits.size(-1)) # Safety check 95 | logits_top, indices = torch.topk(logits, top_k, dim=-1) 96 | if temperature != 1.0: 97 | logits_top /= temperature 98 | modify_logits_for_top_p_filtering(logits_top, top_p) 99 | return indices[ 100 | torch.arange(indices.shape[0], device=indices.device), 101 | torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), 102 | ] 103 | else: 104 | if min_p > 0.0: 105 | logits_top = logits.clone() 106 | max_prob = logits_top[..., 0].item() 107 | min_prob = max_prob * min_p 108 | modify_logits_for_min_p_filtering(logits_top, min_p) 109 | if temperature != 1.0: 110 | logits_top /= temperature 111 | return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1) 112 | # Clone so that when we modify for top_p we don't change the original logits 113 | logits_top = logits / temperature if temperature != 1.0 else logits.clone() 114 | modify_logits_for_top_p_filtering(logits_top, top_p) 115 | return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( 116 | dim=-1 117 | ) 118 | 119 | 120 | @torch.inference_mode() 121 | def decode( 122 | input_ids, 123 | model, 124 | max_length, 125 | top_k=1, 126 | top_p=0.0, 127 | min_p=0.0, 128 | temperature=1.0, 129 | repetition_penalty=1.0, 130 | eos_token_id=None, 131 | teacher_outputs=None, 132 | vocab_size=None, 133 | cg=False, 134 | enable_timing=False, 135 | streamer: Optional[TextStreamer] = None 136 | ): 137 | """Decoding, either greedy or with top-k or top-p sampling. 138 | If top-k = 0, don't limit the number of candidates (pure sampling). 139 | Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, 140 | then top-p. 141 | We assume that all sequences in the same batch have the same length. 142 | 143 | Arguments: 144 | input_ids: (batch, seq_len) 145 | max_length: int 146 | teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the 147 | logits, the next token is taken from the teacher_outputs. Useful for testing. 
148 | Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: 149 | sequences: (batch, max_length) 150 | scores: tuples of (batch, vocab_size) 151 | """ 152 | if streamer is not None: 153 | streamer.put(input_ids.cpu()) 154 | 155 | batch_size, seqlen_og = input_ids.shape 156 | teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 157 | if cg: 158 | if not hasattr(model, "_decoding_cache"): 159 | model._decoding_cache = None 160 | model._decoding_cache = update_graph_cache( 161 | model, 162 | model._decoding_cache, 163 | batch_size, 164 | seqlen_og, 165 | max_length, 166 | ) 167 | inference_params = model._decoding_cache.inference_params 168 | inference_params.reset(max_length, batch_size) 169 | else: 170 | inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) 171 | 172 | def get_logits(input_ids, inference_params): 173 | decoding = inference_params.seqlen_offset > 0 174 | if decoding: 175 | position_ids = torch.full( 176 | (batch_size, 1), 177 | inference_params.seqlen_offset, 178 | dtype=torch.long, 179 | device=input_ids.device, 180 | ) 181 | else: 182 | position_ids = None 183 | if not cg or not decoding: 184 | logits = model( 185 | input_ids, 186 | position_ids=position_ids, 187 | inference_params=inference_params, 188 | num_last_tokens=1, 189 | ).logits.squeeze(dim=1) 190 | else: 191 | logits = model._decoding_cache.run( 192 | input_ids, position_ids, inference_params.seqlen_offset 193 | ).squeeze(dim=1) 194 | return logits[..., :vocab_size] if vocab_size is not None else logits 195 | 196 | def sample_tokens(logits, inference_params): 197 | if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: 198 | token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature) 199 | else: 200 | token = teacher_outputs[:, inference_params.seqlen_offset] 201 | # return rearrange(token, "b -> b 1") 202 | return token.unsqueeze(1) 203 | 204 | def should_stop(current_token, inference_params): 205 | if inference_params.seqlen_offset == 0: 206 | return False 207 | if eos_token_id is not None and (current_token == eos_token_id).all(): 208 | return True 209 | if inference_params.seqlen_offset >= max_length - 1: 210 | return True 211 | return False 212 | 213 | start = torch.cuda.Event(enable_timing=enable_timing) 214 | end = torch.cuda.Event(enable_timing=enable_timing) 215 | 216 | if enable_timing: 217 | start.record() 218 | scores, sequences = [], [input_ids] 219 | sequences_cat = input_ids 220 | while not should_stop(sequences[-1], inference_params): 221 | scores.append(get_logits(sequences[-1], inference_params)) 222 | inference_params.seqlen_offset += sequences[-1].shape[1] 223 | if repetition_penalty == 1.0: 224 | sampled_tokens = sample_tokens(scores[-1], inference_params) 225 | else: 226 | logits = modify_logit_for_repetition_penalty( 227 | scores[-1].clone(), sequences_cat, repetition_penalty 228 | ) 229 | sampled_tokens = sample_tokens(logits, inference_params) 230 | sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1) 231 | sequences.append(sampled_tokens) 232 | if streamer is not None: 233 | streamer.put(sampled_tokens.cpu()) 234 | if streamer is not None: 235 | streamer.end() 236 | if enable_timing: 237 | end.record() 238 | torch.cuda.synchronize() 239 | print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") 240 | output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput 241 | return 
output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) 242 | 243 | 244 | class GenerationMixin: 245 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 246 | raise NotImplementedError 247 | 248 | def generate( 249 | self, 250 | input_ids, 251 | max_length, 252 | top_k=1, 253 | top_p=0.0, 254 | min_p=0.0, 255 | temperature=1.0, 256 | return_dict_in_generate=False, 257 | output_scores=False, 258 | **kwargs, 259 | ): 260 | output = decode( 261 | input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs 262 | ) 263 | if not output_scores: 264 | output.scores = None 265 | return output if return_dict_in_generate else output.sequences 266 | 267 | 268 | @dataclass 269 | class DecodingCGCache: 270 | max_batch_size: int = 0 271 | max_seqlen: int = 0 272 | device = None 273 | dtype = None 274 | callables: dict = field(default_factory=dict) 275 | mempool = None 276 | inference_params: Optional[InferenceParams] = None 277 | run: Optional[Callable] = None 278 | 279 | 280 | @torch.inference_mode() 281 | def update_graph_cache( 282 | model, 283 | cache, 284 | batch_size, 285 | seqlen_og, 286 | max_seqlen, 287 | decoding_seqlens=(1,), 288 | dtype=None, 289 | n_warmups=2, 290 | ): 291 | if cache is None: 292 | cache = DecodingCGCache() 293 | param_example = next(iter(model.parameters())) 294 | device = param_example.device 295 | if dtype is None: 296 | dtype = param_example.dtype 297 | if ( 298 | (device, dtype) != (cache.device, cache.dtype) 299 | or batch_size > cache.max_batch_size 300 | or max_seqlen > cache.max_seqlen 301 | ): # Invalidate the cache 302 | cache.callables = {} 303 | cache.mempool = None 304 | cache.inference_params = None 305 | gc.collect() 306 | cache.device, cache.dtype = device, dtype 307 | cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen 308 | assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache" 309 | inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) 310 | lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) 311 | cache.inference_params = InferenceParams( 312 | max_seqlen=max_seqlen, 313 | max_batch_size=batch_size, 314 | seqlen_offset=seqlen_og, 315 | key_value_memory_dict=inf_cache, 316 | lengths_per_sample=lengths_per_sample, 317 | ) 318 | cache.mempool = torch.cuda.graphs.graph_pool_handle() 319 | for decoding_seqlen in decoding_seqlens: 320 | if (batch_size, decoding_seqlen) not in cache.callables: 321 | cache.callables[batch_size, decoding_seqlen] = capture_graph( 322 | model, 323 | cache.inference_params, 324 | batch_size, 325 | max_seqlen, 326 | decoding_seqlen=decoding_seqlen, 327 | mempool=cache.mempool, 328 | n_warmups=n_warmups, 329 | ) 330 | 331 | def dispatch(input_ids, position_ids, seqlen): 332 | batch_size, decoding_seqlen = input_ids.shape[:2] 333 | return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) 334 | 335 | cache.run = dispatch 336 | cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing 337 | return cache 338 | 339 | 340 | def capture_graph( 341 | model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 342 | ): 343 | device = next(iter(model.parameters())).device 344 | input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) 345 | position_ids = torch.full((batch_size, decoding_seqlen), 0, 
dtype=torch.long, device=device) 346 | seqlen_offset_og = inference_params.seqlen_offset 347 | inference_params.seqlen_offset = max_seqlen - decoding_seqlen 348 | inference_params.lengths_per_sample[:] = inference_params.seqlen_offset 349 | 350 | # Warmup before capture 351 | s = torch.cuda.Stream() 352 | s.wait_stream(torch.cuda.current_stream()) 353 | with torch.cuda.stream(s): 354 | for _ in range(n_warmups): 355 | logits = model( 356 | input_ids, 357 | position_ids=position_ids, 358 | inference_params=inference_params, 359 | num_last_tokens=decoding_seqlen, 360 | ).logits 361 | s.synchronize() 362 | # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, 363 | # which requires that graph launch and non-captured launch to not overlap (I think, 364 | # that's how I interpret the documentation). I'm not sure if this is required. 365 | if torch.distributed.is_initialized(): 366 | torch.distributed.barrier() 367 | torch.cuda.current_stream().wait_stream(s) 368 | # Captures the graph 369 | # To allow capture, automatically sets a side stream as the current stream in the context 370 | graph = torch.cuda.CUDAGraph() 371 | with torch.cuda.graph(graph, pool=mempool): 372 | logits = model( 373 | input_ids, 374 | position_ids=position_ids, 375 | inference_params=inference_params, 376 | num_last_tokens=decoding_seqlen, 377 | ).logits 378 | 379 | def run(new_input_ids, new_position_ids, seqlen): 380 | inference_params.lengths_per_sample[:] = seqlen 381 | input_ids.copy_(new_input_ids) 382 | position_ids.copy_(new_position_ids) 383 | graph.replay() 384 | return logits.clone() 385 | 386 | inference_params.seqlen_offset = seqlen_offset_og 387 | return run 388 | -------------------------------------------------------------------------------- /causal-conv1d/causal-conv1d-main/csrc/causal_conv1d.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "causal_conv1d.h" 11 | 12 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 13 | 14 | #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ 15 | if (ITYPE == at::ScalarType::Half) { \ 16 | using input_t = at::Half; \ 17 | __VA_ARGS__(); \ 18 | } else if (ITYPE == at::ScalarType::BFloat16) { \ 19 | using input_t = at::BFloat16; \ 20 | __VA_ARGS__(); \ 21 | } else if (ITYPE == at::ScalarType::Float) { \ 22 | using input_t = float; \ 23 | __VA_ARGS__(); \ 24 | } else { \ 25 | AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ 26 | } 27 | 28 | #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ 29 | if (WTYPE == at::ScalarType::Half) { \ 30 | using weight_t = at::Half; \ 31 | __VA_ARGS__(); \ 32 | } else if (WTYPE == at::ScalarType::BFloat16) { \ 33 | using weight_t = at::BFloat16; \ 34 | __VA_ARGS__(); \ 35 | } else if (WTYPE == at::ScalarType::Float) { \ 36 | using weight_t = float; \ 37 | __VA_ARGS__(); \ 38 | } else { \ 39 | AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ 40 | } 41 | 42 | template 43 | void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 44 | template 45 | void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 46 | 47 | template 48 | void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 49 | template 50 | void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 51 | 52 | template 53 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 54 | 55 | void set_conv_params_fwd(ConvParamsBase ¶ms, 56 | // sizes 57 | const size_t batch, 58 | const size_t dim, 59 | const size_t seqlen, 60 | const size_t width, 61 | // device pointers 62 | const at::Tensor x, 63 | const at::Tensor weight, 64 | const at::Tensor out, 65 | void* bias_ptr, 66 | bool silu_activation) { 67 | 68 | // Reset the parameters 69 | memset(¶ms, 0, sizeof(params)); 70 | 71 | params.batch = batch; 72 | params.dim = dim; 73 | params.seqlen = seqlen; 74 | params.width = width; 75 | 76 | params.silu_activation = silu_activation; 77 | 78 | // Set the pointers and strides. 79 | params.x_ptr = x.data_ptr(); 80 | params.weight_ptr = weight.data_ptr(); 81 | params.bias_ptr = bias_ptr; 82 | params.out_ptr = out.data_ptr(); 83 | // All stride are in elements, not bytes. 84 | params.x_batch_stride = x.stride(0); 85 | params.x_c_stride = x.stride(1); 86 | params.x_l_stride = x.stride(-1); 87 | params.weight_c_stride = weight.stride(0); 88 | params.weight_width_stride = weight.stride(1); 89 | params.out_batch_stride = out.stride(0); 90 | params.out_c_stride = out.stride(1); 91 | params.out_l_stride = out.stride(-1); 92 | } 93 | 94 | 95 | void set_conv_params_bwd(ConvParamsBwd ¶ms, 96 | // sizes 97 | const size_t batch, 98 | const size_t dim, 99 | const size_t seqlen, 100 | const size_t width, 101 | // device pointers 102 | const at::Tensor x, 103 | const at::Tensor weight, 104 | void* bias_ptr, 105 | const at::Tensor dout, 106 | const at::Tensor dx, 107 | const at::Tensor dweight, 108 | void* dbias_ptr, 109 | bool silu_activation) { 110 | // Pass in "dout" instead of "out", we're not gonna use "out" at all. 111 | set_conv_params_fwd(params, batch, dim, seqlen, width, 112 | x, weight, dout, bias_ptr, silu_activation); 113 | 114 | // Set the pointers and strides. 115 | params.dout_ptr = dout.data_ptr(); 116 | params.dx_ptr = dx.data_ptr(); 117 | params.dweight_ptr = dweight.data_ptr(); 118 | params.dbias_ptr = dbias_ptr; 119 | // All stride are in elements, not bytes. 
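    // The dout / dx / dweight strides recorded below are kept separate from the forward strides
    // set above, so the backward kernels can consume gradient tensors whose layout differs from
    // x (causal_conv1d_bwd only forces dout to be contiguous along the layout-critical dimension).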
120 | params.dout_batch_stride = dout.stride(0); 121 | params.dout_c_stride = dout.stride(1); 122 | params.dout_l_stride = dout.stride(2); 123 | params.dweight_c_stride = dweight.stride(0); 124 | params.dweight_width_stride = dweight.stride(1); 125 | params.dx_batch_stride = dx.stride(0); 126 | params.dx_c_stride = dx.stride(1); 127 | params.dx_l_stride = dx.stride(2); 128 | } 129 | 130 | at::Tensor 131 | causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, 132 | const c10::optional &bias_, 133 | const c10::optional &seq_idx_, 134 | bool silu_activation) { 135 | auto input_type = x.scalar_type(); 136 | auto weight_type = weight.scalar_type(); 137 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 138 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 139 | 140 | TORCH_CHECK(x.is_cuda()); 141 | TORCH_CHECK(weight.is_cuda()); 142 | 143 | const auto sizes = x.sizes(); 144 | const int batch_size = sizes[0]; 145 | const int dim = sizes[1]; 146 | const int seqlen = sizes[2]; 147 | const int width = weight.size(-1); 148 | 149 | CHECK_SHAPE(x, batch_size, dim, seqlen); 150 | CHECK_SHAPE(weight, dim, width); 151 | 152 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 153 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 154 | 155 | if (is_channel_last) { 156 | TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); 157 | } 158 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 159 | 160 | if (bias_.has_value()) { 161 | auto bias = bias_.value(); 162 | TORCH_CHECK(bias.scalar_type() == weight_type); 163 | TORCH_CHECK(bias.is_cuda()); 164 | TORCH_CHECK(bias.stride(-1) == 1); 165 | CHECK_SHAPE(bias, dim); 166 | } 167 | 168 | if (seq_idx_.has_value()) { 169 | TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout"); 170 | auto seq_idx = seq_idx_.value(); 171 | TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); 172 | TORCH_CHECK(seq_idx.is_cuda()); 173 | TORCH_CHECK(seq_idx.is_contiguous()); 174 | CHECK_SHAPE(seq_idx, batch_size, seqlen); 175 | } 176 | 177 | at::Tensor out = torch::empty_like(x); 178 | 179 | ConvParamsBase params; 180 | set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, 181 | bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, 182 | silu_activation); 183 | 184 | if (seq_idx_.has_value()) { 185 | params.seq_idx_ptr = seq_idx_.value().data_ptr(); 186 | } else { 187 | params.seq_idx_ptr = nullptr; 188 | } 189 | 190 | // Otherwise the kernel will be launched from cuda:0 device 191 | // Cast to char to avoid compiler warning about narrowing 192 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 193 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 194 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { 195 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { 196 | if (!is_channel_last) { 197 | causal_conv1d_fwd_cuda(params, stream); 198 | } else { 199 | causal_conv1d_channellast_fwd_cuda(params, stream); 200 | } 201 | }); 202 | }); 203 | return out; 204 | } 205 | 206 | std::vector 207 | causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight, 208 | const c10::optional &bias_, 209 | at::Tensor &dout, 210 | c10::optional &seq_idx_, 211 | c10::optional &dx_, 212 | bool silu_activation) { 213 | auto input_type = x.scalar_type(); 214 | auto weight_type = weight.scalar_type(); 215 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 216 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 217 | 218 | TORCH_CHECK(x.is_cuda()); 219 | TORCH_CHECK(weight.is_cuda()); 220 | TORCH_CHECK(dout.is_cuda()); 221 | 222 | const auto sizes = x.sizes(); 223 | const int batch_size = sizes[0]; 224 | const int dim = sizes[1]; 225 | const int seqlen = sizes[2]; 226 | const int width = weight.size(-1); 227 | 228 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 229 | 230 | CHECK_SHAPE(x, batch_size, dim, seqlen); 231 | CHECK_SHAPE(weight, dim, width); 232 | CHECK_SHAPE(dout, batch_size, dim, seqlen); 233 | 234 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 235 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 236 | if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } 237 | if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } 238 | 239 | if (bias_.has_value()) { 240 | auto bias = bias_.value(); 241 | TORCH_CHECK(bias.scalar_type() == weight_type); 242 | TORCH_CHECK(bias.is_cuda()); 243 | TORCH_CHECK(bias.stride(-1) == 1); 244 | CHECK_SHAPE(bias, dim); 245 | } 246 | 247 | if (seq_idx_.has_value()) { 248 | TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout"); 249 | auto seq_idx = seq_idx_.value(); 250 | TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); 251 | TORCH_CHECK(seq_idx.is_cuda()); 252 | TORCH_CHECK(seq_idx.is_contiguous()); 253 | CHECK_SHAPE(seq_idx, batch_size, seqlen); 254 | } 255 | 256 | at::Tensor dx; 257 | if (dx_.has_value()) { 258 | dx = dx_.value(); 259 | TORCH_CHECK(dx.scalar_type() == input_type); 260 | TORCH_CHECK(dx.is_cuda()); 261 | CHECK_SHAPE(dx, batch_size, dim, seqlen); 262 | if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } 263 | if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } 264 | } else { 265 | dx = torch::empty_like(x); 266 | } 267 | 268 | // Otherwise the kernel will be launched from cuda:0 device 269 | // Cast to char to avoid compiler warning about narrowing 270 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 271 | 272 | 
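    // Note: dweight (and dbias, when a bias is present) are accumulated in fp32 below regardless
    // of the parameter dtype, and only cast back to the weight/bias dtype when returned, so the
    // many per-position contributions to each filter tap are not summed in half precision.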
at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat)); 273 | at::Tensor dbias; 274 | if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); } 275 | 276 | ConvParamsBwd params; 277 | set_conv_params_bwd(params, batch_size, dim, seqlen, width, 278 | x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, 279 | dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr, 280 | silu_activation); 281 | 282 | if (seq_idx_.has_value()) { 283 | params.seq_idx_ptr = seq_idx_.value().data_ptr(); 284 | } else { 285 | params.seq_idx_ptr = nullptr; 286 | } 287 | 288 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 289 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { 290 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { 291 | if (!is_channel_last) { 292 | causal_conv1d_bwd_cuda(params, stream); 293 | } else { 294 | causal_conv1d_channellast_bwd_cuda(params, stream); 295 | } 296 | }); 297 | }); 298 | return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias}; 299 | } 300 | 301 | at::Tensor 302 | causal_conv1d_update(const at::Tensor &x, 303 | const at::Tensor &conv_state, 304 | const at::Tensor &weight, 305 | const c10::optional &bias_, 306 | bool silu_activation) { 307 | auto input_type = x.scalar_type(); 308 | auto weight_type = weight.scalar_type(); 309 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 310 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 311 | TORCH_CHECK(conv_state.scalar_type() == input_type); 312 | 313 | TORCH_CHECK(x.is_cuda()); 314 | TORCH_CHECK(conv_state.is_cuda()); 315 | TORCH_CHECK(weight.is_cuda()); 316 | 317 | const auto sizes = x.sizes(); 318 | const int batch_size = sizes[0]; 319 | const int dim = sizes[1]; 320 | const int width = weight.size(-1); 321 | 322 | CHECK_SHAPE(x, batch_size, dim); 323 | CHECK_SHAPE(conv_state, batch_size, dim, width); 324 | CHECK_SHAPE(weight, dim, width); 325 | 326 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 327 | 328 | if (bias_.has_value()) { 329 | auto bias = bias_.value(); 330 | TORCH_CHECK(bias.scalar_type() == weight_type); 331 | TORCH_CHECK(bias.is_cuda()); 332 | TORCH_CHECK(bias.stride(-1) == 1); 333 | CHECK_SHAPE(bias, dim); 334 | } 335 | 336 | at::Tensor out = torch::empty_like(x); 337 | 338 | ConvParamsBase params; 339 | set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, 340 | bias_.has_value() ? bias_.value().data_ptr() : nullptr, 341 | silu_activation); 342 | params.conv_state_ptr = conv_state.data_ptr(); 343 | // All stride are in elements, not bytes. 
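    // conv_state has shape (batch, dim, width), as checked above; its strides are passed so the
    // update kernel can read the cached last `width` inputs per channel and refresh the state
    // in place during single-token decoding.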
344 | params.conv_state_batch_stride = conv_state.stride(0); 345 | params.conv_state_c_stride = conv_state.stride(1); 346 | params.conv_state_l_stride = conv_state.stride(2); 347 | 348 | // Otherwise the kernel will be launched from cuda:0 device 349 | // Cast to char to avoid compiler warning about narrowing 350 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 351 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 352 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { 353 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { 354 | causal_conv1d_update_cuda(params, stream); 355 | }); 356 | }); 357 | return out; 358 | } 359 | 360 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 361 | m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); 362 | m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); 363 | m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); 364 | } 365 | -------------------------------------------------------------------------------- /mamba/mamba-main/mamba_ssm/ops/selective_scan_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.cuda.amp import custom_bwd, custom_fwd 6 | 7 | from einops import rearrange, repeat 8 | 9 | try: 10 | from causal_conv1d import causal_conv1d_fn 11 | import causal_conv1d_cuda 12 | except ImportError: 13 | causal_conv1d_fn = None 14 | causal_conv1d_cuda = None 15 | 16 | import selective_scan_cuda 17 | 18 | 19 | class SelectiveScanFn(torch.autograd.Function): 20 | 21 | @staticmethod 22 | def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 23 | return_last_state=False): 24 | if u.stride(-1) != 1: 25 | u = u.contiguous() 26 | if delta.stride(-1) != 1: 27 | delta = delta.contiguous() 28 | if D is not None: 29 | D = D.contiguous() 30 | if B.stride(-1) != 1: 31 | B = B.contiguous() 32 | if C.stride(-1) != 1: 33 | C = C.contiguous() 34 | if z is not None and z.stride(-1) != 1: 35 | z = z.contiguous() 36 | if B.dim() == 3: 37 | B = rearrange(B, "b dstate l -> b 1 dstate l") 38 | ctx.squeeze_B = True 39 | if C.dim() == 3: 40 | C = rearrange(C, "b dstate l -> b 1 dstate l") 41 | ctx.squeeze_C = True 42 | out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) 43 | ctx.delta_softplus = delta_softplus 44 | ctx.has_z = z is not None 45 | last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) 46 | if not ctx.has_z: 47 | ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) 48 | return out if not return_last_state else (out, last_state) 49 | else: 50 | ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) 51 | out_z = rest[0] 52 | return out_z if not return_last_state else (out_z, last_state) 53 | 54 | @staticmethod 55 | def backward(ctx, dout, *args): 56 | if not ctx.has_z: 57 | u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors 58 | z = None 59 | out = None 60 | else: 61 | u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors 62 | if dout.stride(-1) != 1: 63 | dout = dout.contiguous() 64 | # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the 65 | # backward of selective_scan_cuda with the backward of chunk). 66 | # Here we just pass in None and dz will be allocated in the C++ code. 
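        # selective_scan_cuda.bwd returns gradients for u, delta, A, B, C, D and delta_bias,
        # in that order, with dz following in `rest` when z was given (out_z recomputation is
        # disabled for this call).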
67 | du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( 68 | u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, 69 | False # option to recompute out_z, not used here 70 | ) 71 | dz = rest[0] if ctx.has_z else None 72 | dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB 73 | dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC 74 | return (du, ddelta, dA, dB, dC, 75 | dD if D is not None else None, 76 | dz, 77 | ddelta_bias if delta_bias is not None else None, 78 | None, 79 | None) 80 | 81 | 82 | def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 83 | return_last_state=False): 84 | """if return_last_state is True, returns (out, last_state) 85 | last_state has shape (batch, dim, dstate). Note that the gradient of the last state is 86 | not considered in the backward pass. 87 | """ 88 | return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) 89 | 90 | 91 | def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, 92 | return_last_state=False): 93 | """ 94 | u: r(B D L) 95 | delta: r(B D L) 96 | A: c(D N) or r(D N) 97 | B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) 98 | C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) 99 | D: r(D) 100 | z: r(B D L) 101 | delta_bias: r(D), fp32 102 | 103 | out: r(B D L) 104 | last_state (optional): r(B D dstate) or c(B D dstate) 105 | """ 106 | dtype_in = u.dtype 107 | u = u.float() 108 | delta = delta.float() 109 | if delta_bias is not None: 110 | delta = delta + delta_bias[..., None].float() 111 | if delta_softplus: 112 | delta = F.softplus(delta) 113 | batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] 114 | is_variable_B = B.dim() >= 3 115 | is_variable_C = C.dim() >= 3 116 | if A.is_complex(): 117 | if is_variable_B: 118 | B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) 119 | if is_variable_C: 120 | C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) 121 | else: 122 | B = B.float() 123 | C = C.float() 124 | x = A.new_zeros((batch, dim, dstate)) 125 | ys = [] 126 | deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) 127 | if not is_variable_B: 128 | deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) 129 | else: 130 | if B.dim() == 3: 131 | deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) 132 | else: 133 | B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) 134 | deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) 135 | if is_variable_C and C.dim() == 4: 136 | C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) 137 | last_state = None 138 | for i in range(u.shape[2]): 139 | x = deltaA[:, :, i] * x + deltaB_u[:, :, i] 140 | if not is_variable_C: 141 | y = torch.einsum('bdn,dn->bd', x, C) 142 | else: 143 | if C.dim() == 3: 144 | y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) 145 | else: 146 | y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) 147 | if i == u.shape[2] - 1: 148 | last_state = x 149 | if y.is_complex(): 150 | y = y.real * 2 151 | ys.append(y) 152 | y = torch.stack(ys, dim=2) # (batch dim L) 153 | out = y if D is None else y + u * rearrange(D, "d -> d 1") 154 | if z is not None: 155 | out = out * F.silu(z) 156 | out = out.to(dtype=dtype_in) 157 | return out if not return_last_state else (out, last_state) 158 | 159 | 160 | class MambaInnerFn(torch.autograd.Function): 161 | 162 | @staticmethod 163 | @custom_fwd 164 | def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, 165 | out_proj_weight, out_proj_bias, 166 | A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, 167 | C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): 168 | """ 169 | xz: (batch, dim, seqlen) 170 | """ 171 | assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 172 | assert checkpoint_lvl in [0, 1] 173 | L = xz.shape[-1] 174 | delta_rank = delta_proj_weight.shape[1] 175 | d_state = A.shape[-1] * (1 if not A.is_complex() else 2) 176 | if torch.is_autocast_enabled(): 177 | x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) 178 | delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) 179 | out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) 180 | out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) 181 | if out_proj_bias is not None else None) 182 | if xz.stride(-1) != 1: 183 | xz = xz.contiguous() 184 | conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") 185 | x, z = xz.chunk(2, dim=1) 186 | conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None 187 | conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( 188 | x, conv1d_weight, conv1d_bias, None, None, None, True 189 | ) 190 | # We're being very careful here about the layout, to avoid extra transposes. 191 | # We want delta to have d as the slowest moving dimension 192 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
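        # x_dbl packs the three data-dependent projections along its feature dimension:
        #   x_dbl[:, :delta_rank]                      -> low-rank input to the delta projection
        #   x_dbl[:, delta_rank:delta_rank + d_state]  -> B (only used when B is not passed in)
        #   x_dbl[:, -d_state:]                        -> C (only used when C is not passed in)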
193 | x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) 194 | delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) 195 | ctx.is_variable_B = B is None 196 | ctx.is_variable_C = C is None 197 | ctx.B_proj_bias_is_None = B_proj_bias is None 198 | ctx.C_proj_bias_is_None = C_proj_bias is None 199 | if B is None: # variable B 200 | B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) 201 | if B_proj_bias is not None: 202 | B = B + B_proj_bias.to(dtype=B.dtype) 203 | if not A.is_complex(): 204 | # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() 205 | B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() 206 | else: 207 | B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() 208 | else: 209 | if B.stride(-1) != 1: 210 | B = B.contiguous() 211 | if C is None: # variable C 212 | C = x_dbl[:, -d_state:] # (bl dstate) 213 | if C_proj_bias is not None: 214 | C = C + C_proj_bias.to(dtype=C.dtype) 215 | if not A.is_complex(): 216 | # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() 217 | C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() 218 | else: 219 | C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() 220 | else: 221 | if C.stride(-1) != 1: 222 | C = C.contiguous() 223 | if D is not None: 224 | D = D.contiguous() 225 | out, scan_intermediates, out_z = selective_scan_cuda.fwd( 226 | conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus 227 | ) 228 | ctx.delta_softplus = delta_softplus 229 | ctx.out_proj_bias_is_None = out_proj_bias is None 230 | ctx.checkpoint_lvl = checkpoint_lvl 231 | if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass 232 | conv1d_out, delta = None, None 233 | ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, 234 | delta_proj_weight, out_proj_weight, conv1d_out, delta, 235 | A, B, C, D, delta_bias, scan_intermediates, out) 236 | return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) 237 | 238 | @staticmethod 239 | @custom_bwd 240 | def backward(ctx, dout): 241 | # dout: (batch, seqlen, dim) 242 | assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 243 | (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, 244 | conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors 245 | L = xz.shape[-1] 246 | delta_rank = delta_proj_weight.shape[1] 247 | d_state = A.shape[-1] * (1 if not A.is_complex() else 2) 248 | x, z = xz.chunk(2, dim=1) 249 | if dout.stride(-1) != 1: 250 | dout = dout.contiguous() 251 | if ctx.checkpoint_lvl == 1: 252 | conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( 253 | x, conv1d_weight, conv1d_bias, None, None, None, True 254 | ) 255 | delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), 256 | "d (b l) -> b d l", l = L) 257 | # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the 258 | # backward of selective_scan_cuda with the backward of chunk). 
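        # dxz is allocated once for the packed (x, z) input; dx and dz below are chunked views
        # into it, so both the selective-scan backward (which fills dz) and the conv1d backward
        # (which fills dx) write straight into the buffer that is returned as the gradient of xz.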
259 | dxz = torch.empty_like(xz) # (batch, dim, seqlen) 260 | dx, dz = dxz.chunk(2, dim=1) 261 | dout = rearrange(dout, "b l e -> e (b l)") 262 | dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) 263 | dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( 264 | conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, 265 | ctx.delta_softplus, 266 | True # option to recompute out_z 267 | ) 268 | dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) 269 | dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None 270 | dD = dD if D is not None else None 271 | dx_dbl = torch.empty_like(x_dbl) 272 | dB_proj_bias = None 273 | if ctx.is_variable_B: 274 | if not A.is_complex(): 275 | dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() 276 | else: 277 | dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() 278 | dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None 279 | dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) 280 | dB = None 281 | dC_proj_bias = None 282 | if ctx.is_variable_C: 283 | if not A.is_complex(): 284 | dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() 285 | else: 286 | dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() 287 | dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None 288 | dx_dbl[:, -d_state:] = dC # (bl d) 289 | dC = None 290 | ddelta = rearrange(ddelta, "b d l -> d (b l)") 291 | ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) 292 | dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) 293 | dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") 294 | dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) 295 | dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) 296 | dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) 297 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 298 | # backward of conv1d with the backward of chunk). 
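        # At this point dconv1d_out carries the gradient w.r.t. the conv1d output from both of
        # its consumers: the selective scan (via selective_scan_cuda.bwd above) and the x_proj
        # projection (accumulated with torch.addmm), so a single conv1d backward call suffices.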
299 |         dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
300 |             x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
301 |         )
302 |         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
303 |         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
304 |         return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
305 |                 dout_proj_weight, dout_proj_bias,
306 |                 dA, dB, dC, dD,
307 |                 ddelta_bias if delta_bias is not None else None,
308 |                 dB_proj_bias, dC_proj_bias, None)
309 | 
310 | 
311 | def mamba_inner_fn(
312 |     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
313 |     out_proj_weight, out_proj_bias,
314 |     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
315 |     C_proj_bias=None, delta_softplus=True
316 | ):
317 |     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
318 |                               out_proj_weight, out_proj_bias,
319 |                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
320 | 
321 | 
322 | def mamba_inner_ref(
323 |     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
324 |     out_proj_weight, out_proj_bias,
325 |     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
326 |     C_proj_bias=None, delta_softplus=True
327 | ):
328 |     assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
329 |     L = xz.shape[-1]
330 |     delta_rank = delta_proj_weight.shape[1]
331 |     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
332 |     x, z = xz.chunk(2, dim=1)
333 |     x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
334 |     # We're being very careful here about the layout, to avoid extra transposes.
335 |     # We want delta to have d as the slowest moving dimension
336 |     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
337 |     x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
338 |     delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
339 |     delta = rearrange(delta, "d (b l) -> b d l", l=L)
340 |     if B is None:  # variable B
341 |         B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
342 |         if B_proj_bias is not None:
343 |             B = B + B_proj_bias.to(dtype=B.dtype)
344 |         if not A.is_complex():
345 |             B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
346 |         else:
347 |             B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
348 |     if C is None:  # variable C
349 |         C = x_dbl[:, -d_state:]  # (bl dstate)
350 |         if C_proj_bias is not None:
351 |             C = C + C_proj_bias.to(dtype=C.dtype)
352 |         if not A.is_complex():
353 |             C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
354 |         else:
355 |             C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
356 |     y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
357 |     return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
358 | 
--------------------------------------------------------------------------------
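
For reference, selective_scan_fn above is intended as a fused, drop-in replacement for the pure-PyTorch selective_scan_ref, following the shape conventions in the selective_scan_ref docstring (u/delta/z: (B, D, L); A: (D, N); variable B/C: (B, N, L); D and delta_bias: (D,)). The snippet below is a minimal sketch of how the two paths can be compared numerically; it assumes a CUDA device and that the package was installed with its compiled selective_scan_cuda extension, and the shapes are illustrative rather than taken from the repository's test suite.

# Minimal numerical sanity check: fused CUDA selective scan vs. the pure-PyTorch reference.
# Assumes a CUDA device and that mamba_ssm was built with its selective_scan_cuda extension.
import torch

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

torch.manual_seed(0)
device, dtype = "cuda", torch.float32
batch, dim, dstate, seqlen = 2, 4, 8, 128

u = torch.randn(batch, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
delta = torch.randn(batch, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
A = (-torch.rand(dim, dstate, device=device, dtype=dtype)).requires_grad_()              # real, negative A
B = torch.randn(batch, dstate, seqlen, device=device, dtype=dtype, requires_grad=True)   # variable B: (B, N, L)
C = torch.randn(batch, dstate, seqlen, device=device, dtype=dtype, requires_grad=True)   # variable C: (B, N, L)
D = torch.randn(dim, device=device, dtype=dtype, requires_grad=True)
z = torch.randn(batch, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
delta_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)

out, last_state = selective_scan_fn(
    u, delta, A, B, C, D, z=z, delta_bias=delta_bias,
    delta_softplus=True, return_last_state=True,
)
out_ref, last_state_ref = selective_scan_ref(
    u, delta, A, B, C, D, z=z, delta_bias=delta_bias,
    delta_softplus=True, return_last_state=True,
)

print("forward max abs diff   :", (out - out_ref).abs().max().item())
print("last_state max abs diff:", (last_state - last_state_ref).abs().max().item())

# Gradients flow through the fused kernel as well; note that, per the selective_scan_fn
# docstring, the gradient of last_state is not propagated.
out.sum().backward()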