├── README.md ├── assets ├── 1.png ├── 2.png ├── MambaXAI.pdf ├── notebook.png ├── pdf.png └── xai_gradmethod.jpg ├── causal-conv1d ├── AUTHORS ├── LICENSE ├── README.md ├── causal_conv1d │ ├── __init__.py │ └── causal_conv1d_interface.py ├── csrc │ ├── causal_conv1d.cpp │ ├── causal_conv1d.h │ ├── causal_conv1d_bwd.cu │ ├── causal_conv1d_common.h │ ├── causal_conv1d_fwd.cu │ ├── causal_conv1d_update.cu │ └── static_switch.h ├── setup.py └── tests │ └── test_causal_conv1d.py ├── mamba-1p1p1 ├── AUTHORS ├── LICENSE ├── README.md ├── assets │ └── selection.png ├── benchmarks │ └── benchmark_generation_mamba_simple.py ├── csrc │ └── selective_scan │ │ ├── reverse_scan.cuh │ │ ├── selective_scan.cpp │ │ ├── selective_scan.h │ │ ├── selective_scan_bwd_bf16_complex.cu │ │ ├── selective_scan_bwd_bf16_real.cu │ │ ├── selective_scan_bwd_fp16_complex.cu │ │ ├── selective_scan_bwd_fp16_real.cu │ │ ├── selective_scan_bwd_fp32_complex.cu │ │ ├── selective_scan_bwd_fp32_real.cu │ │ ├── selective_scan_bwd_kernel.cuh │ │ ├── selective_scan_common.h │ │ ├── selective_scan_fwd_bf16.cu │ │ ├── selective_scan_fwd_fp16.cu │ │ ├── selective_scan_fwd_fp32.cu │ │ ├── selective_scan_fwd_kernel.cuh │ │ ├── static_switch.h │ │ └── uninitialized_copy.cuh ├── evals │ └── lm_harness_eval.py ├── mamba_ssm │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── config_mamba.py │ │ └── mixer_seq_simple.py │ ├── modules │ │ ├── __init__.py │ │ └── mamba_simple.py │ ├── ops │ │ ├── __init__.py │ │ ├── selective_scan_interface.py │ │ └── triton │ │ │ ├── __init__.py │ │ │ ├── layernorm.py │ │ │ └── selective_state_update.py │ └── utils │ │ ├── __init__.py │ │ ├── generation.py │ │ └── hf.py ├── setup.py └── tests │ └── ops │ ├── test_selective_scan.py │ └── triton │ └── test_selective_state_update.py └── vim ├── augment.py ├── class_mapper.py ├── datasets.py ├── engine.py ├── hubconf.py ├── images ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg └── 5.jpg ├── losses.py ├── main.py ├── models_mamba.py ├── 
rope.py ├── run_with_submitit.py ├── samplers.py ├── scripts ├── ft-vim-s.sh ├── ft-vim-t.sh ├── pt-vim-s.sh └── pt-vim-t.sh ├── utils.py ├── vim_requirements.txt ├── vmamba_xai.ipynb └── xai_utils.py /README.md: -------------------------------------------------------------------------------- 1 |
2 |

🐍 The Hidden Attention of Mamba Models 🐍

3 | 4 | Ameen Ali1 \*, Itamar Zimerman1 \* and Lior Wolf1 5 |
6 | ameenali023@gmail.com, itamarzimm@gmail.com, liorwolf@gmail.com 7 |
8 | 1 Tel Aviv University 9 | (\*) equal contribution 10 | 11 | 12 | 13 |
14 | 15 | ## Official PyTorch Implementation of "The Hidden Attention of Mamba Models" 16 | 17 | The Mamba layer offers an efficient state space model (SSM) that is highly effective in modeling multiple domains including long-range sequences and images. SSMs are viewed as dual models, in which one trains in parallel on the entire sequence using convolutions, and deploys in an autoregressive manner. We add a third view and show that such models can be viewed as attention-driven models. This new perspective enables us to compare the underlying mechanisms to that of the self-attention layers in transformers and allows us to peer inside the inner workings of the Mamba model with explainability methods. 18 |
19 | You can access the paper here: The Hidden Attention of Mamba Models 20 | 21 |
22 | Left Image 23 |
24 | 25 | ## Set Up Environment 26 | 27 | - Python 3.10.13 28 | 29 | - `conda create -n your_env_name python=3.10.13` 30 | - Activate Env 31 | - `conda activate your_env_name` 32 | - CUDA TOOLKIT 11.8 33 | - `conda install nvidia/label/cuda-11.8.0::cuda-toolkit` 34 | - torch 2.1.1 + cu118 35 | - `pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118` 36 | 37 | - Requirements: vim_requirements.txt 38 | - `pip install -r vim/vim_requirements.txt` 39 | 40 | - Install jupyter 41 | - `pip install jupyter` 42 | 43 | - Install ``causal_conv1d`` and ``mamba`` from *our source* 44 | - `cd causal-conv1d` 45 | - `pip install --editable .` 46 | - `cd ..` 47 | - `pip install --editable mamba-1p1p1` 48 | 49 | 50 | 51 | 52 | ## Pre-Trained Weights 53 | 54 | We have used the official weights provided by [Vim](https://github.com/hustvl/Vim), which can be downloaded here: 55 | 56 | | Model | #param. | Top-1 Acc. | Top-5 Acc. | Hugging Face Repo | 57 | |:------------------------------------------------------------------:|:-------------:|:----------:|:----------:|:----------:| 58 | | [Vim-tiny](https://huggingface.co/hustvl/Vim-tiny-midclstok) | 7M | 76.1 | 93.0 | https://huggingface.co/hustvl/Vim-tiny-midclstok | 59 | | [Vim-tiny+](https://huggingface.co/hustvl/Vim-tiny-midclstok) | 7M | 78.3 | 94.2 | https://huggingface.co/hustvl/Vim-tiny-midclstok | 60 | | [Vim-small](https://huggingface.co/hustvl/Vim-small-midclstok) | 26M | 80.5 | 95.1 | https://huggingface.co/hustvl/Vim-small-midclstok | 61 | | [Vim-small+](https://huggingface.co/hustvl/Vim-small-midclstok) | 26M | 81.6 | 95.4 | https://huggingface.co/hustvl/Vim-small-midclstok | 62 | 63 | **Notes:** 64 | - In all of our experiments, we have worked with [Vim-small](https://huggingface.co/hustvl/Vim-small-midclstok). 65 | 66 | ## Vision-Mamba Explainability Notebook: 67 |
68 | Left Image 69 |
70 |
71 | Follow the instructions in the vim/vmamba_xai.ipynb notebook to run single-image inference with the three methods introduced in the paper. 72 |
73 |
74 | Left Image 75 |
76 | 77 | ## To-Do 78 | For the segmentation experiment, please check out our [follow-up work](https://github.com/Itamarzimm/UnifiedImplicitAttnRepr/tree/main). 79 |
80 | 84 | 85 | ## Citation 86 | If you find our work useful, please consider citing us: 87 | ```latex 88 | @misc{ali2024hidden, 89 | title={The Hidden Attention of Mamba Models}, 90 | author={Ameen Ali and Itamar Zimerman and Lior Wolf}, 91 | year={2024}, 92 | eprint={2403.01590}, 93 | archivePrefix={arXiv}, 94 | primaryClass={cs.LG} 95 | } 96 | ``` 97 | ## Acknowledgement 98 | This repository is heavily based on [Vim](https://github.com/hustvl/Vim), [Mamba](https://github.com/state-spaces/mamba) and [Transformer-Explainability](https://github.com/hila-chefer/Transformer-Explainability). Thanks to the authors for their wonderful work. 99 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/1.png -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/2.png -------------------------------------------------------------------------------- /assets/MambaXAI.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/MambaXAI.pdf -------------------------------------------------------------------------------- /assets/notebook.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/notebook.png -------------------------------------------------------------------------------- /assets/pdf.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/pdf.png -------------------------------------------------------------------------------- /assets/xai_gradmethod.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/xai_gradmethod.jpg -------------------------------------------------------------------------------- /causal-conv1d/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | -------------------------------------------------------------------------------- /causal-conv1d/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /causal-conv1d/README.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or 
swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "causal_conv1d.h" 11 | 12 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 13 | 14 | #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ 15 | if (ITYPE == at::ScalarType::Half) { \ 16 | using input_t = at::Half; \ 17 | __VA_ARGS__(); \ 18 | } else if (ITYPE == at::ScalarType::BFloat16) { \ 19 | using input_t = at::BFloat16; \ 20 | __VA_ARGS__(); \ 21 | } else if (ITYPE == at::ScalarType::Float) { \ 22 | using input_t = float; \ 23 | __VA_ARGS__(); \ 24 | } else { \ 25 | AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ 26 | } 27 | 28 | #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ 29 | if (WTYPE == at::ScalarType::Half) { \ 30 | using weight_t = at::Half; \ 31 | __VA_ARGS__(); \ 32 | } else if (WTYPE == at::ScalarType::BFloat16) { \ 33 | using weight_t = at::BFloat16; \ 34 | __VA_ARGS__(); \ 35 | } else if (WTYPE == at::ScalarType::Float) { \ 36 | using weight_t = float; \ 37 | __VA_ARGS__(); \ 38 | } else { \ 39 | AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ 40 | } 41 | 42 | template 43 | void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 44 | template 45 | void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 46 | 47 | template 48 | void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 49 | template 50 | void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 51 | 52 | template 53 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 54 | 55 | void set_conv_params_fwd(ConvParamsBase ¶ms, 56 | // sizes 57 | const size_t batch, 58 | const size_t dim, 59 | const size_t seqlen, 60 | const size_t width, 61 | // device pointers 62 | const at::Tensor x, 63 | const at::Tensor weight, 64 | const at::Tensor out, 65 | void* bias_ptr, 66 | bool silu_activation) { 67 | 68 | // Reset the parameters 69 | memset(¶ms, 0, sizeof(params)); 70 | 71 | params.batch = batch; 72 | params.dim = dim; 73 | params.seqlen = seqlen; 74 | params.width = width; 75 | 76 | params.silu_activation = silu_activation; 77 | 78 | // Set the pointers and strides. 79 | params.x_ptr = x.data_ptr(); 80 | params.weight_ptr = weight.data_ptr(); 81 | params.bias_ptr = bias_ptr; 82 | params.out_ptr = out.data_ptr(); 83 | // All stride are in elements, not bytes. 
84 | params.x_batch_stride = x.stride(0); 85 | params.x_c_stride = x.stride(1); 86 | params.x_l_stride = x.stride(-1); 87 | params.weight_c_stride = weight.stride(0); 88 | params.weight_width_stride = weight.stride(1); 89 | params.out_batch_stride = out.stride(0); 90 | params.out_c_stride = out.stride(1); 91 | params.out_l_stride = out.stride(-1); 92 | } 93 | 94 | 95 | void set_conv_params_bwd(ConvParamsBwd ¶ms, 96 | // sizes 97 | const size_t batch, 98 | const size_t dim, 99 | const size_t seqlen, 100 | const size_t width, 101 | // device pointers 102 | const at::Tensor x, 103 | const at::Tensor weight, 104 | void* bias_ptr, 105 | const at::Tensor dout, 106 | const at::Tensor dx, 107 | const at::Tensor dweight, 108 | void* dbias_ptr, 109 | bool silu_activation) { 110 | // Pass in "dout" instead of "out", we're not gonna use "out" at all. 111 | set_conv_params_fwd(params, batch, dim, seqlen, width, 112 | x, weight, dout, bias_ptr, silu_activation); 113 | 114 | // Set the pointers and strides. 115 | params.dout_ptr = dout.data_ptr(); 116 | params.dx_ptr = dx.data_ptr(); 117 | params.dweight_ptr = dweight.data_ptr(); 118 | params.dbias_ptr = dbias_ptr; 119 | // All stride are in elements, not bytes. 
120 | params.dout_batch_stride = dout.stride(0); 121 | params.dout_c_stride = dout.stride(1); 122 | params.dout_l_stride = dout.stride(2); 123 | params.dweight_c_stride = dweight.stride(0); 124 | params.dweight_width_stride = dweight.stride(1); 125 | params.dx_batch_stride = dx.stride(0); 126 | params.dx_c_stride = dx.stride(1); 127 | params.dx_l_stride = dx.stride(2); 128 | } 129 | 130 | at::Tensor 131 | causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, 132 | const c10::optional &bias_, 133 | bool silu_activation) { 134 | auto input_type = x.scalar_type(); 135 | auto weight_type = weight.scalar_type(); 136 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 137 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 138 | 139 | TORCH_CHECK(x.is_cuda()); 140 | TORCH_CHECK(weight.is_cuda()); 141 | 142 | const auto sizes = x.sizes(); 143 | const int batch_size = sizes[0]; 144 | const int dim = sizes[1]; 145 | const int seqlen = sizes[2]; 146 | const int width = weight.size(-1); 147 | 148 | CHECK_SHAPE(x, batch_size, dim, seqlen); 149 | CHECK_SHAPE(weight, dim, width); 150 | 151 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 152 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 153 | 154 | if (is_channel_last) { 155 | TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); 156 | } 157 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 158 | 159 | 160 | if (bias_.has_value()) { 161 | auto bias = bias_.value(); 162 | TORCH_CHECK(bias.scalar_type() == weight_type); 163 | TORCH_CHECK(bias.is_cuda()); 164 | TORCH_CHECK(bias.stride(-1) == 1); 165 | CHECK_SHAPE(bias, dim); 166 | } 167 | 168 | at::Tensor out = torch::empty_like(x); 169 | 170 | ConvParamsBase params; 171 | 
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, 172 | bias_.has_value() ? bias_.value().data_ptr() : nullptr, 173 | silu_activation); 174 | 175 | // Otherwise the kernel will be launched from cuda:0 device 176 | // Cast to char to avoid compiler warning about narrowing 177 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 178 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 179 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { 180 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { 181 | if (!is_channel_last) { 182 | causal_conv1d_fwd_cuda(params, stream); 183 | } else { 184 | causal_conv1d_channellast_fwd_cuda(params, stream); 185 | } 186 | }); 187 | }); 188 | return out; 189 | } 190 | 191 | std::vector 192 | causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight, 193 | const c10::optional &bias_, 194 | at::Tensor &dout, 195 | c10::optional &dx_, 196 | bool silu_activation) { 197 | auto input_type = x.scalar_type(); 198 | auto weight_type = weight.scalar_type(); 199 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 200 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 201 | 202 | TORCH_CHECK(x.is_cuda()); 203 | TORCH_CHECK(weight.is_cuda()); 204 | TORCH_CHECK(dout.is_cuda()); 205 | 206 | const auto sizes = x.sizes(); 207 | const int batch_size = sizes[0]; 208 | const int dim = sizes[1]; 209 | const int seqlen = sizes[2]; 210 | const int width = weight.size(-1); 211 | 212 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 213 | 214 | CHECK_SHAPE(x, batch_size, dim, seqlen); 215 | CHECK_SHAPE(weight, dim, width); 216 | CHECK_SHAPE(dout, batch_size, dim, seqlen); 217 | 218 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 219 | 
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 220 | if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } 221 | if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } 222 | 223 | if (bias_.has_value()) { 224 | auto bias = bias_.value(); 225 | TORCH_CHECK(bias.scalar_type() == weight_type); 226 | TORCH_CHECK(bias.is_cuda()); 227 | TORCH_CHECK(bias.stride(-1) == 1); 228 | CHECK_SHAPE(bias, dim); 229 | } 230 | 231 | at::Tensor dx; 232 | if (dx_.has_value()) { 233 | dx = dx_.value(); 234 | TORCH_CHECK(dx.scalar_type() == input_type); 235 | TORCH_CHECK(dx.is_cuda()); 236 | CHECK_SHAPE(dx, batch_size, dim, seqlen); 237 | if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } 238 | if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } 239 | } else { 240 | dx = torch::empty_like(x); 241 | } 242 | 243 | // Otherwise the kernel will be launched from cuda:0 device 244 | // Cast to char to avoid compiler warning about narrowing 245 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 246 | 247 | at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat)); 248 | at::Tensor dbias; 249 | if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); } 250 | 251 | ConvParamsBwd params; 252 | set_conv_params_bwd(params, batch_size, dim, seqlen, width, 253 | x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, 254 | dout, dx, dweight, bias_.has_value() ? 
dbias.data_ptr() : nullptr, 255 | silu_activation); 256 | 257 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 258 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { 259 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { 260 | if (!is_channel_last) { 261 | causal_conv1d_bwd_cuda(params, stream); 262 | } else { 263 | causal_conv1d_channellast_bwd_cuda(params, stream); 264 | } 265 | }); 266 | }); 267 | return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias}; 268 | } 269 | 270 | at::Tensor 271 | causal_conv1d_update(const at::Tensor &x, 272 | const at::Tensor &conv_state, 273 | const at::Tensor &weight, 274 | const c10::optional &bias_, 275 | bool silu_activation) { 276 | auto input_type = x.scalar_type(); 277 | auto weight_type = weight.scalar_type(); 278 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 279 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 280 | TORCH_CHECK(conv_state.scalar_type() == input_type); 281 | 282 | TORCH_CHECK(x.is_cuda()); 283 | TORCH_CHECK(conv_state.is_cuda()); 284 | TORCH_CHECK(weight.is_cuda()); 285 | 286 | const auto sizes = x.sizes(); 287 | const int batch_size = sizes[0]; 288 | const int dim = sizes[1]; 289 | const int width = weight.size(-1); 290 | 291 | CHECK_SHAPE(x, batch_size, dim); 292 | CHECK_SHAPE(conv_state, batch_size, dim, width); 293 | CHECK_SHAPE(weight, dim, width); 294 | 295 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 296 | 297 | if (bias_.has_value()) { 298 | auto bias = bias_.value(); 299 | TORCH_CHECK(bias.scalar_type() == weight_type); 300 | TORCH_CHECK(bias.is_cuda()); 301 | TORCH_CHECK(bias.stride(-1) == 1); 302 | CHECK_SHAPE(bias, dim); 303 | } 304 | 305 | at::Tensor out 
= torch::empty_like(x); 306 | 307 | ConvParamsBase params; 308 | set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, 309 | bias_.has_value() ? bias_.value().data_ptr() : nullptr, 310 | silu_activation); 311 | params.conv_state_ptr = conv_state.data_ptr(); 312 | // All stride are in elements, not bytes. 313 | params.conv_state_batch_stride = conv_state.stride(0); 314 | params.conv_state_c_stride = conv_state.stride(1); 315 | params.conv_state_l_stride = conv_state.stride(2); 316 | 317 | // Otherwise the kernel will be launched from cuda:0 device 318 | // Cast to char to avoid compiler warning about narrowing 319 | at::cuda::CUDAGuard device_guard{(char)x.get_device()}; 320 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 321 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { 322 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { 323 | causal_conv1d_update_cuda(params, stream); 324 | }); 325 | }); 326 | return out; 327 | } 328 | 329 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 330 | m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); 331 | m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); 332 | m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); 333 | } 334 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 29 | void *__restrict__ x_ptr; 30 | void *__restrict__ weight_ptr; 31 | void *__restrict__ bias_ptr; 32 | void *__restrict__ out_ptr; 33 | 34 | void *__restrict__ conv_state_ptr; 35 | }; 36 | 37 | struct ConvParamsBwd: public ConvParamsBase { 38 | index_t dx_batch_stride; 39 | index_t dx_c_stride; 40 | index_t dx_l_stride; 41 | index_t dweight_c_stride; 42 | index_t dweight_width_stride; 43 | index_t dout_batch_stride; 44 | index_t dout_c_stride; 45 | index_t dout_l_stride; 46 | 47 | // Common data pointers. 48 | void *__restrict__ dx_ptr; 49 | void *__restrict__ dweight_ptr; 50 | void *__restrict__ dbias_ptr; 51 | void *__restrict__ dout_ptr; 52 | }; 53 | 54 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | //////////////////////////////////////////////////////////////////////////////////////////////////// 11 | 12 | template struct BytesToType {}; 13 | 14 | template<> struct BytesToType<16> { 15 | using Type = uint4; 16 | static_assert(sizeof(Type) == 16); 17 | }; 18 | 19 | template<> struct BytesToType<8> { 20 | using Type = uint64_t; 21 | static_assert(sizeof(Type) == 8); 22 | }; 23 | 24 | template<> struct BytesToType<4> { 25 | using Type = uint32_t; 26 | static_assert(sizeof(Type) == 4); 27 | }; 28 | 29 | template<> struct BytesToType<2> { 30 | using Type = uint16_t; 31 | static_assert(sizeof(Type) == 2); 32 | }; 33 | 34 | template<> struct BytesToType<1> { 35 | using Type = uint8_t; 36 | static_assert(sizeof(Type) == 1); 37 | }; 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | template 42 | struct SumOp { 43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 44 | }; 45 | 46 | template 47 | struct Allreduce { 48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 49 | template 50 | static __device__ inline T run(T x, Operator &op) { 51 | constexpr int OFFSET = THREADS / 2; 52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); 53 | return Allreduce::run(x, op); 54 | } 55 | }; 56 | 57 | template<> 58 | struct Allreduce<2> { 59 | template 60 | static __device__ inline T run(T x, Operator &op) { 61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); 62 | return x; 63 | } 64 | }; 65 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d_update.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri 
Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #include 10 | #include 11 | 12 | #include "causal_conv1d.h" 13 | #include "causal_conv1d_common.h" 14 | #include "static_switch.h" 15 | 16 | template 17 | struct Causal_conv1d_update_kernel_traits { 18 | using input_t = input_t_; 19 | using weight_t = weight_t_; 20 | static constexpr int kNThreads = kNThreads_; 21 | static constexpr int kWidth = kWidth_; 22 | static constexpr int kNBytes = sizeof(input_t); 23 | static_assert(kNBytes == 2 || kNBytes == 4); 24 | }; 25 | 26 | template 27 | __global__ __launch_bounds__(Ktraits::kNThreads) 28 | void causal_conv1d_update_kernel(ConvParamsBase params) { 29 | constexpr int kWidth = Ktraits::kWidth; 30 | constexpr int kNThreads = Ktraits::kNThreads; 31 | using input_t = typename Ktraits::input_t; 32 | using weight_t = typename Ktraits::weight_t; 33 | 34 | const int tidx = threadIdx.x; 35 | const int batch_id = blockIdx.x; 36 | const int channel_id = blockIdx.y * kNThreads + tidx; 37 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 38 | + channel_id * params.x_c_stride; 39 | input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride 40 | + channel_id * params.conv_state_c_stride; 41 | weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; 42 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 43 | + channel_id * params.out_c_stride; 44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); 45 | 46 | float weight_vals[kWidth] = {0}; 47 | if (channel_id < params.dim) { 48 | #pragma unroll 49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 50 | } 51 | 52 | float x_vals[kWidth] = {0}; 53 | if (channel_id < params.dim) { 54 | #pragma unroll 55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } 56 | x_vals[kWidth - 1] = float(x[0]); 57 | #pragma unroll 58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } 59 | } 60 | 61 | float out_val = bias_val; 62 | #pragma unroll 63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } 64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } 65 | if (channel_id < params.dim) { out[0] = input_t(out_val); } 66 | } 67 | 68 | template 69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) { 70 | using Ktraits = Causal_conv1d_update_kernel_traits; 71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); 72 | auto kernel = &causal_conv1d_update_kernel; 73 | kernel<<>>(params); 74 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 75 | } 76 | 77 | template 78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) { 79 | if (params.width == 2) { 80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); 81 | } else if (params.width == 3) { 82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); 83 | } else if (params.width == 4) { 84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); 85 | } 86 | } 87 | 88 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 89 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 90 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 
91 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 92 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 93 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 94 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 95 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); 96 | template void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /causal-conv1d/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | static constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | static constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /causal-conv1d/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 
2 | import sys 3 | import warnings 4 | import os 5 | import re 6 | import ast 7 | from pathlib import Path 8 | from packaging.version import parse, Version 9 | import platform 10 | 11 | from setuptools import setup, find_packages 12 | import subprocess 13 | 14 | import urllib.request 15 | import urllib.error 16 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 17 | 18 | import torch 19 | from torch.utils.cpp_extension import ( 20 | BuildExtension, 21 | CppExtension, 22 | CUDAExtension, 23 | CUDA_HOME, 24 | ) 25 | 26 | 27 | with open("README.md", "r", encoding="utf-8") as fh: 28 | long_description = fh.read() 29 | 30 | 31 | # ninja build does not work unless include_dirs are abs path 32 | this_dir = os.path.dirname(os.path.abspath(__file__)) 33 | 34 | PACKAGE_NAME = "causal_conv1d" 35 | 36 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}" 37 | 38 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 39 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 40 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE" 41 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 42 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 43 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE" 44 | 45 | 46 | def get_platform(): 47 | """ 48 | Returns the platform name as used in wheel filenames. 
49 | """ 50 | if sys.platform.startswith("linux"): 51 | return "linux_x86_64" 52 | elif sys.platform == "darwin": 53 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 54 | return f"macosx_{mac_version}_x86_64" 55 | elif sys.platform == "win32": 56 | return "win_amd64" 57 | else: 58 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 59 | 60 | 61 | def get_cuda_bare_metal_version(cuda_dir): 62 | raw_output = subprocess.check_output( 63 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 64 | ) 65 | output = raw_output.split() 66 | release_idx = output.index("release") + 1 67 | bare_metal_version = parse(output[release_idx].split(",")[0]) 68 | 69 | return raw_output, bare_metal_version 70 | 71 | 72 | def check_if_cuda_home_none(global_option: str) -> None: 73 | if CUDA_HOME is not None: 74 | return 75 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 76 | # in that case. 77 | warnings.warn( 78 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 79 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 80 | "only images whose names contain 'devel' will provide nvcc." 
81 | ) 82 | 83 | 84 | def append_nvcc_threads(nvcc_extra_args): 85 | return nvcc_extra_args + ["--threads", "4"] 86 | 87 | 88 | cmdclass = {} 89 | ext_modules = [] 90 | 91 | if not SKIP_CUDA_BUILD: 92 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 93 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 94 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 95 | 96 | check_if_cuda_home_none("causal_conv1d") 97 | # Check, if CUDA11 is installed for compute capability 8.0 98 | cc_flag = [] 99 | if CUDA_HOME is not None: 100 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 101 | if bare_metal_version < Version("11.6"): 102 | raise RuntimeError( 103 | "causal_conv1d is only supported on CUDA 11.6 and above. " 104 | "Note: make sure nvcc has a supported version by running nvcc -V." 105 | ) 106 | 107 | cc_flag.append("-gencode") 108 | cc_flag.append("arch=compute_70,code=sm_70") 109 | cc_flag.append("-gencode") 110 | cc_flag.append("arch=compute_80,code=sm_80") 111 | if bare_metal_version >= Version("11.8"): 112 | cc_flag.append("-gencode") 113 | cc_flag.append("arch=compute_90,code=sm_90") 114 | 115 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 116 | # torch._C._GLIBCXX_USE_CXX11_ABI 117 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 118 | if FORCE_CXX11_ABI: 119 | torch._C._GLIBCXX_USE_CXX11_ABI = True 120 | 121 | ext_modules.append( 122 | CUDAExtension( 123 | name="causal_conv1d_cuda", 124 | sources=[ 125 | "csrc/causal_conv1d.cpp", 126 | "csrc/causal_conv1d_fwd.cu", 127 | "csrc/causal_conv1d_bwd.cu", 128 | "csrc/causal_conv1d_update.cu", 129 | ], 130 | extra_compile_args={ 131 | "cxx": ["-O3"], 132 | "nvcc": append_nvcc_threads( 133 | [ 134 | "-O3", 135 | "-U__CUDA_NO_HALF_OPERATORS__", 136 | "-U__CUDA_NO_HALF_CONVERSIONS__", 137 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 138 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 139 | 
"-U__CUDA_NO_BFLOAT162_OPERATORS__", 140 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 141 | "--expt-relaxed-constexpr", 142 | "--expt-extended-lambda", 143 | "--use_fast_math", 144 | "--ptxas-options=-v", 145 | "-lineinfo", 146 | ] 147 | + cc_flag 148 | ), 149 | }, 150 | include_dirs=[this_dir], 151 | ) 152 | ) 153 | 154 | 155 | def get_package_version(): 156 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: 157 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 158 | public_version = ast.literal_eval(version_match.group(1)) 159 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") 160 | if local_version: 161 | return f"{public_version}+{local_version}" 162 | else: 163 | return str(public_version) 164 | 165 | 166 | def get_wheel_url(): 167 | # Determine the version numbers that will be used to determine the correct wheel 168 | # We're using the CUDA version used to build torch, not the one currently installed 169 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 170 | torch_cuda_version = parse(torch.version.cuda) 171 | torch_version_raw = parse(torch.__version__) 172 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 173 | # to save CI time. Minor versions should be compatible. 
174 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") 175 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 176 | platform_name = get_platform() 177 | causal_conv1d_version = get_package_version() 178 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" 179 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" 180 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 181 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 182 | 183 | # Determine wheel URL based on CUDA version, torch version, python version and OS 184 | wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 185 | wheel_url = BASE_WHEEL_URL.format( 186 | tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename 187 | ) 188 | return wheel_url, wheel_filename 189 | 190 | 191 | class CachedWheelsCommand(_bdist_wheel): 192 | """ 193 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot 194 | find an existing wheel (which is currently the case for all installs). We use 195 | the environment parameters to detect whether there is already a pre-built version of a compatible 196 | wheel available and short-circuits the standard full build pipeline. 
197 | """ 198 | 199 | def run(self): 200 | if FORCE_BUILD: 201 | return super().run() 202 | 203 | wheel_url, wheel_filename = get_wheel_url() 204 | print("Guessing wheel URL: ", wheel_url) 205 | try: 206 | urllib.request.urlretrieve(wheel_url, wheel_filename) 207 | 208 | # Make the archive 209 | # Lifted from the root wheel processing command 210 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 211 | if not os.path.exists(self.dist_dir): 212 | os.makedirs(self.dist_dir) 213 | 214 | impl_tag, abi_tag, plat_tag = self.get_tag() 215 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 216 | 217 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 218 | print("Raw wheel path", wheel_path) 219 | os.rename(wheel_filename, wheel_path) 220 | except urllib.error.HTTPError: 221 | print("Precompiled wheel not found. Building from source...") 222 | # If the wheel could not be downloaded, build from source 223 | super().run() 224 | 225 | 226 | setup( 227 | name=PACKAGE_NAME, 228 | version=get_package_version(), 229 | packages=find_packages( 230 | exclude=( 231 | "build", 232 | "csrc", 233 | "include", 234 | "tests", 235 | "dist", 236 | "docs", 237 | "benchmarks", 238 | "causal_conv1d.egg-info", 239 | ) 240 | ), 241 | author="Tri Dao", 242 | author_email="tri@tridao.me", 243 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface", 244 | long_description=long_description, 245 | long_description_content_type="text/markdown", 246 | url="https://github.com/Dao-AILab/causal-conv1d", 247 | classifiers=[ 248 | "Programming Language :: Python :: 3", 249 | "License :: OSI Approved :: BSD License", 250 | "Operating System :: Unix", 251 | ], 252 | ext_modules=ext_modules, 253 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 254 | if ext_modules 255 | else { 256 | "bdist_wheel": CachedWheelsCommand, 257 | }, 258 | 
python_requires=">=3.7", 259 | install_requires=[ 260 | "torch", 261 | "packaging", 262 | "ninja", 263 | ], 264 | ) 265 | -------------------------------------------------------------------------------- /causal-conv1d/tests/test_causal_conv1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import pytest 7 | 8 | from einops import rearrange 9 | 10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref 11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("channel_last", [False, True]) 15 | # @pytest.mark.parametrize('channel_last', [True]) 16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 17 | # @pytest.mark.parametrize('itype', [torch.float16]) 18 | @pytest.mark.parametrize("silu_activation", [False, True]) 19 | # @pytest.mark.parametrize('silu_activation', [True]) 20 | @pytest.mark.parametrize("has_bias", [False, True]) 21 | # @pytest.mark.parametrize('has_bias', [True]) 22 | @pytest.mark.parametrize("width", [2, 3, 4]) 23 | # @pytest.mark.parametrize('width', [2]) 24 | @pytest.mark.parametrize( 25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 26 | ) 27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 28 | # @pytest.mark.parametrize('seqlen', [128]) 29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last): 30 | device = "cuda" 31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 32 | if itype == torch.bfloat16: 33 | rtol, atol = 1e-2, 5e-2 34 | rtolw, atolw = (1e-3, 1e-3) 35 | # set seed 36 | torch.random.manual_seed(0) 37 | batch_size = 2 38 | # batch_size = 1 39 | dim = 4096 + 32 # Try dim not divisible by 64 40 | # dim = 64 41 | if not channel_last: 42 | x 
= torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 43 | else: 44 | x = rearrange( 45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 46 | ).requires_grad_() 47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 48 | if has_bias: 49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 50 | else: 51 | bias = None 52 | x_ref = x.detach().clone().requires_grad_() 53 | weight_ref = weight.detach().clone().requires_grad_() 54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 55 | activation = None if not silu_activation else "silu" 56 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) 58 | 59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 62 | 63 | g = torch.randn_like(out) 64 | out_ref.backward(g) 65 | out.backward(g) 66 | 67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 69 | if has_bias: 70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 71 | 72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 74 | if has_bias: 75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 76 | 77 | 78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 79 | # @pytest.mark.parametrize('itype', [torch.float16]) 80 | @pytest.mark.parametrize("silu_activation", [False, True]) 81 | # 
@pytest.mark.parametrize('silu_activation', [False]) 82 | @pytest.mark.parametrize("has_bias", [False, True]) 83 | # @pytest.mark.parametrize('has_bias', [True]) 84 | @pytest.mark.parametrize("width", [2, 3, 4]) 85 | # @pytest.mark.parametrize('width', [2]) 86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 87 | # @pytest.mark.parametrize("dim", [2048]) 88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): 89 | device = "cuda" 90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 91 | if itype == torch.bfloat16: 92 | rtol, atol = 1e-2, 5e-2 93 | rtolw, atolw = (1e-3, 1e-3) 94 | # set seed 95 | torch.random.manual_seed(0) 96 | batch_size = 2 97 | # batch_size = 1 98 | # dim = 64 99 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype) 101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 102 | if has_bias: 103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 104 | else: 105 | bias = None 106 | conv_state_ref = conv_state.detach().clone() 107 | activation = None if not silu_activation else "silu" 108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) 109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) 110 | 111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 113 | assert torch.equal(conv_state, conv_state_ref) 114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 115 | 116 | 117 | # @pytest.mark.parametrize("channel_last", [False, True]) 118 | @pytest.mark.parametrize('channel_last', [True]) 119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 120 | @pytest.mark.parametrize('itype', [torch.bfloat16]) 121 | # 
@pytest.mark.parametrize("silu_activation", [False, True]) 122 | @pytest.mark.parametrize('silu_activation', [True]) 123 | # @pytest.mark.parametrize("has_bias", [False, True]) 124 | @pytest.mark.parametrize('has_bias', [True]) 125 | # @pytest.mark.parametrize("width", [2, 3, 4]) 126 | @pytest.mark.parametrize('width', [4]) 127 | @pytest.mark.parametrize( 128 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 129 | "seqlen", [2048] 130 | ) 131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 132 | # @pytest.mark.parametrize('seqlen', [128]) 133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): 134 | device = "cuda" 135 | # set seed 136 | torch.random.manual_seed(0) 137 | batch_size = 2 138 | # batch_size = 1 139 | dim = 4096 + 32 # Try dim not divisible by 64 140 | # dim = 64 141 | if not channel_last: 142 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 143 | else: 144 | x = rearrange( 145 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 146 | ).requires_grad_() 147 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 148 | if has_bias: 149 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 150 | else: 151 | bias = None 152 | activation = None if not silu_activation else "silu" 153 | out0 = causal_conv1d_fn(x, weight, bias, activation=activation) 154 | g = torch.randn_like(out0) 155 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g) 156 | dw_atol = 1e-4 157 | db_atol = 1e-4 158 | 159 | for i in range(10000): 160 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 161 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g) 162 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol) 163 | # if not 
dw_equal: 164 | # breakpoint() 165 | if has_bias: 166 | db_equal = torch.allclose(db, db0, atol=db_atol) 167 | # if not db_equal: 168 | # breakpoint() 169 | assert torch.equal(out, out0) 170 | assert torch.equal(dx, dx0) 171 | assert dw_equal 172 | if has_bias: 173 | assert dw_equal 174 | -------------------------------------------------------------------------------- /mamba-1p1p1/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /mamba-1p1p1/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Tri Dao, Albert Gu 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /mamba-1p1p1/README.md: -------------------------------------------------------------------------------- 1 | # Mamba 2 | 3 | ![Mamba](assets/selection.png "Selective State Space") 4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ 5 | > Albert Gu*, Tri Dao*\ 6 | > Paper: https://arxiv.org/abs/2312.00752 7 | 8 | ## About 9 | 10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. 11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), 12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). 13 | 14 | ## Installation 15 | 16 | - `pip install "causal-conv1d>=1.1.0"`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. 17 | - `pip install mamba-ssm`: the core Mamba package.
18 | 19 | It can also be built from source with `pip install .` from this repository. 20 | 21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. 22 | 23 | Other requirements: 24 | - Linux 25 | - NVIDIA GPU 26 | - PyTorch 1.12+ 27 | - CUDA 11.6+ 28 | 29 | ## Usage 30 | 31 | We expose several levels of interface with the Mamba model. 32 | 33 | ### Selective SSM 34 | 35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). 36 | 37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). 38 | 39 | ### Mamba Block 40 | 41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM. 42 | 43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). 44 | 45 | Usage: 46 | ``` 47 | import torch 48 | from mamba_ssm import Mamba 49 | 50 | batch, length, dim = 2, 64, 16 51 | x = torch.randn(batch, length, dim).to("cuda") 52 | model = Mamba( 53 | # This module uses roughly 3 * expand * d_model^2 parameters 54 | d_model=dim, # Model dimension d_model 55 | d_state=16, # SSM state expansion factor 56 | d_conv=4, # Local convolution width 57 | expand=2, # Block expansion factor 58 | ).to("cuda") 59 | y = model(x) 60 | assert y.shape == x.shape 61 | ``` 62 | 63 | ### Mamba Language Model 64 | 65 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. 66 | 67 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). 68 | 69 | This is an example of how to integrate Mamba into an end-to-end neural network. 70 | This example is used in the generation scripts below. 
71 | 72 | 73 | 74 | ## Pretrained Models 75 | 76 | Pretrained models are uploaded to 77 | [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, 78 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj` 79 | (trained on 600B tokens on the SlimPajama dataset). 80 | 81 | 82 | The models will be autodownloaded by the generation script below. 83 | 84 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: 85 | 86 | | Parameters | Layers | Model dim. | 87 | |------------|--------|------------| 88 | | 130M | 24 | 768 | 89 | | 370M | 48 | 1024 | 90 | | 790M | 48 | 1536 | 91 | | 1.4B | 48 | 2048 | 92 | | 2.8B | 64 | 2560 | 93 | 94 | (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) 95 | 96 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). 97 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. 98 | 99 | 100 | ## Evaluations 101 | 102 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper), 103 | we use the 104 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 105 | library. 106 | 107 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init 108 | --recursive`. We use the `big-refactor` branch. 109 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`. 110 | On Python 3.10 you might need to manually install the latest version of `promptsource`: `pip install git+https://github.com/bigscience-workshop/promptsource.git`. 111 | 3. 
Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): 112 | ``` 113 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 114 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 115 | ``` 116 | 117 | To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts: 118 | ``` 119 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64 120 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64 121 | ``` 122 | 123 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. 124 | 125 | ## Inference 126 | 127 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 128 | 1. autoloads a model from the Hugging Face Hub, 129 | 2. generates completions of a user-specified prompt, 130 | 3. benchmarks the inference speed of this generation. 131 | 132 | Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. 133 | 134 | ### Examples 135 | 136 | To test generation latency (e.g. 
batch size = 1) with different sampling strategies: 137 | 138 | ``` 139 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 140 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 141 | ``` 142 | 143 | To test generation throughput with random prompts (e.g. large batch size): 144 | ``` 145 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 146 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 147 | ``` 148 | 149 | 150 | ## Troubleshooting 151 | 152 | ### Precision 153 | Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary. 154 | On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation). 155 | 156 | We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities, 157 | as a first step please try a framework storing parameters in fp32 (such as AMP). 158 | 159 | ### Initialization 160 | Some parts of the model have initializations inherited from prior work on S4 models. 161 | For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection. 162 | However, some frameworks may have post-initialization hooks (e.g.
setting all bias terms in `nn.Linear` modules to zero). 163 | If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework) 164 | that is specific to the training framework. 165 | 166 | 167 | ## Citation 168 | 169 | If you use this codebase, or otherwise found our work valuable, please cite Mamba: 170 | ``` 171 | @article{mamba, 172 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, 173 | author={Gu, Albert and Dao, Tri}, 174 | journal={arXiv preprint arXiv:2312.00752}, 175 | year={2023} 176 | } 177 | ``` 178 | -------------------------------------------------------------------------------- /mamba-1p1p1/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/assets/selection.png -------------------------------------------------------------------------------- /mamba-1p1p1/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 26 | parser.add_argument("--batch", type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | repeats = 3 30 | device = "cuda" 31 | dtype = torch.float16 32 | 33 | print(f"Loading model {args.model_name}") 34 | is_mamba = args.model_name.startswith("state-spaces/mamba-") 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | 
input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | repetition_penalty=args.repetition_penalty, 66 | ) 67 | else: 68 | fn = lambda: model.generate( 69 | input_ids=input_ids, 70 | attention_mask=attn_mask, 71 | max_length=max_length, 72 | return_dict_in_generate=True, 73 | pad_token_id=tokenizer.eos_token_id, 74 | do_sample=True, 75 | temperature=args.temperature, 76 | top_k=args.topk, 77 | top_p=args.topp, 78 | repetition_penalty=args.repetition_penalty, 79 | ) 80 | out = fn() 81 | if args.prompt is not None: 82 | print(tokenizer.batch_decode(out.sequences.tolist())) 83 | 84 | torch.cuda.synchronize() 85 | start = time.time() 86 | for _ in range(repeats): 87 | fn() 88 | torch.cuda.synchronize() 89 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 90 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 91 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 
58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For scalar_value_type 10 | 11 | #define MAX_DSTATE 256 12 | 13 | using complex_t = c10::complex; 14 | 15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){ 16 | return {a.x + b.x, a.y + b.y}; 17 | } 18 | 19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) { 20 | return {a.x + b.x, a.y + b.y, a.z + b.z}; 21 | } 22 | 23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){ 24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; 25 | } 26 | 27 | //////////////////////////////////////////////////////////////////////////////////////////////////// 28 | 29 | template struct BytesToType {}; 30 | 31 | template<> struct BytesToType<16> { 32 | using Type = uint4; 33 | static_assert(sizeof(Type) == 16); 34 | }; 35 | 36 | template<> struct BytesToType<8> { 37 | using Type = uint64_t; 38 | static_assert(sizeof(Type) == 8); 39 | }; 40 | 41 | template<> struct BytesToType<4> { 42 | using Type = uint32_t; 43 | static_assert(sizeof(Type) == 4); 44 | }; 45 | 46 | template<> struct BytesToType<2> { 47 | using Type = uint16_t; 48 | static_assert(sizeof(Type) == 2); 49 | }; 50 | 51 | template<> struct BytesToType<1> { 52 | using Type = uint8_t; 53 | static_assert(sizeof(Type) == 1); 54 | }; 55 | 56 | //////////////////////////////////////////////////////////////////////////////////////////////////// 57 | 58 | template 59 | struct Converter{ 60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { 61 | #pragma unroll 62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; } 63 | } 64 | }; 65 | 66 | template 67 | struct Converter{ 68 | static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { 69 | static_assert(N % 2 == 0); 70 | auto &src2 = reinterpret_cast(src); 71 | auto &dst2 = reinterpret_cast(dst); 72 | #pragma 
unroll 73 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } 74 | } 75 | }; 76 | 77 | #if __CUDA_ARCH__ >= 800 78 | template 79 | struct Converter{ 80 | static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { 81 | static_assert(N % 2 == 0); 82 | auto &src2 = reinterpret_cast(src); 83 | auto &dst2 = reinterpret_cast(dst); 84 | #pragma unroll 85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } 86 | } 87 | }; 88 | #endif 89 | 90 | //////////////////////////////////////////////////////////////////////////////////////////////////// 91 | 92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp 93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) { 95 | float t = exp2f(z.real_); 96 | float c, s; 97 | sincosf(z.imag_, &s, &c); 98 | return complex_t(c * t, s * t); 99 | } 100 | 101 | __device__ __forceinline__ complex_t cexpf(complex_t z) { 102 | float t = expf(z.real_); 103 | float c, s; 104 | sincosf(z.imag_, &s, &c); 105 | return complex_t(c * t, s * t); 106 | } 107 | 108 | template struct SSMScanOp; 109 | 110 | template<> 111 | struct SSMScanOp { 112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { 113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); 114 | } 115 | }; 116 | 117 | template<> 118 | struct SSMScanOp { 119 | __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { 120 | complex_t a0 = complex_t(ab0.x, ab0.y); 121 | complex_t b0 = complex_t(ab0.z, ab0.w); 122 | complex_t a1 = complex_t(ab1.x, ab1.y); 123 | complex_t b1 = complex_t(ab1.z, ab1.w); 124 | complex_t out_a = a1 * a0; 125 | complex_t out_b = a1 * b0 + b1; 126 | return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); 127 | } 128 | }; 129 | 130 | // A stateful callback functor that maintains a 
running prefix to be applied 131 | // during consecutive scan operations. 132 | template struct SSMScanPrefixCallbackOp { 133 | using scan_t = std::conditional_t, float2, float4>; 134 | scan_t running_prefix; 135 | // Constructor 136 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} 137 | // Callback operator to be entered by the first warp of threads in the block. 138 | // Thread-0 is responsible for returning a value for seeding the block-wide scan. 139 | __device__ scan_t operator()(scan_t block_aggregate) { 140 | scan_t old_prefix = running_prefix; 141 | running_prefix = SSMScanOp()(running_prefix, block_aggregate); 142 | return old_prefix; 143 | } 144 | }; 145 | 146 | //////////////////////////////////////////////////////////////////////////////////////////////////// 147 | 148 | template 149 | inline __device__ void load_input(typename Ktraits::input_t *u, 150 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], 151 | typename Ktraits::BlockLoadT::TempStorage &smem_load, 152 | int seqlen) { 153 | if constexpr (Ktraits::kIsEvenLen) { 154 | auto& smem_load_vec = reinterpret_cast(smem_load); 155 | using vec_t = typename Ktraits::vec_t; 156 | Ktraits::BlockLoadVecT(smem_load_vec).Load( 157 | reinterpret_cast(u), 158 | reinterpret_cast(u_vals) 159 | ); 160 | } else { 161 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); 162 | } 163 | } 164 | 165 | template 166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar, 167 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], 168 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, 169 | int seqlen) { 170 | constexpr int kNItems = Ktraits::kNItems; 171 | if constexpr (!Ktraits::kIsComplex) { 172 | typename Ktraits::input_t B_vals_load[kNItems]; 173 | if constexpr (Ktraits::kIsEvenLen) { 174 | auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); 175 | using vec_t = typename Ktraits::vec_t; 176 | 
Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 177 | reinterpret_cast(Bvar), 178 | reinterpret_cast(B_vals_load) 179 | ); 180 | } else { 181 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 182 | } 183 | // #pragma unroll 184 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } 185 | Converter::to_float(B_vals_load, B_vals); 186 | } else { 187 | typename Ktraits::input_t B_vals_load[kNItems * 2]; 188 | if constexpr (Ktraits::kIsEvenLen) { 189 | auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); 190 | using vec_t = typename Ktraits::vec_t; 191 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 192 | reinterpret_cast(Bvar), 193 | reinterpret_cast(B_vals_load) 194 | ); 195 | } else { 196 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 197 | } 198 | #pragma unroll 199 | for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } 200 | } 201 | } 202 | 203 | template 204 | inline __device__ void store_output(typename Ktraits::input_t *out, 205 | const float (&out_vals)[Ktraits::kNItems], 206 | typename Ktraits::BlockStoreT::TempStorage &smem_store, 207 | int seqlen) { 208 | typename Ktraits::input_t write_vals[Ktraits::kNItems]; 209 | #pragma unroll 210 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } 211 | if constexpr (Ktraits::kIsEvenLen) { 212 | auto& smem_store_vec = reinterpret_cast(smem_store); 213 | using vec_t = typename Ktraits::vec_t; 214 | Ktraits::BlockStoreVecT(smem_store_vec).Store( 215 | reinterpret_cast(out), 216 | reinterpret_cast(write_vals) 217 | ); 218 | } else { 219 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_bf16.cu: 
-------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /mamba-1p1p1/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = int(batch_size) if batch_size is not None 
@dataclass
class MambaConfig:
    """Hyper-parameters consumed by ``MambaLMHeadModel``.

    ``MambaLMHeadModel.from_pretrained`` builds an instance by splatting the
    checkpoint's JSON config into this constructor, and ``save_pretrained``
    serialises ``__dict__`` back to JSON, so every field must stay
    JSON-serialisable.
    """

    # Width of the token embedding and of every block (passed as ``d_model``).
    d_model: int = 2560
    # Number of blocks stacked in the backbone.
    n_layer: int = 64
    # Tokenizer vocabulary size; the LM head rounds it up to a multiple of
    # ``pad_vocab_size_multiple`` before allocating the embedding.
    vocab_size: int = 50277
    # Extra keyword arguments forwarded verbatim to each ``Mamba`` mixer.
    ssm_cfg: dict = field(default_factory=dict)
    # Use RMSNorm instead of ``nn.LayerNorm`` for the block and final norms.
    rms_norm: bool = True
    # Forwarded to every block and to the final fused norm call.
    residual_in_fp32: bool = True
    # Use the fused add+norm path (the backbone raises ImportError when the
    # Triton norm kernels could not be imported).
    fused_add_norm: bool = True
    # Granularity to which ``vocab_size`` is padded.
    pad_vocab_size_multiple: int = 8
2 | 3 | import math 4 | from functools import partial 5 | import json 6 | import os 7 | 8 | from collections import namedtuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from mamba_ssm.models.config_mamba import MambaConfig 14 | from mamba_ssm.modules.mamba_simple import Mamba, Block 15 | from mamba_ssm.utils.generation import GenerationMixin 16 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf 17 | 18 | try: 19 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 20 | except ImportError: 21 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 22 | 23 | 24 | def create_block( 25 | d_model, 26 | ssm_cfg=None, 27 | norm_epsilon=1e-5, 28 | rms_norm=False, 29 | residual_in_fp32=False, 30 | fused_add_norm=False, 31 | layer_idx=None, 32 | device=None, 33 | dtype=None, 34 | ): 35 | if ssm_cfg is None: 36 | ssm_cfg = {} 37 | factory_kwargs = {"device": device, "dtype": dtype} 38 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 39 | norm_cls = partial( 40 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 41 | ) 42 | block = Block( 43 | d_model, 44 | mixer_cls, 45 | norm_cls=norm_cls, 46 | fused_add_norm=fused_add_norm, 47 | residual_in_fp32=residual_in_fp32, 48 | ) 49 | block.layer_idx = layer_idx 50 | return block 51 | 52 | 53 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 54 | def _init_weights( 55 | module, 56 | n_layer, 57 | initializer_range=0.02, # Now only used for embedding layer. 
58 | rescale_prenorm_residual=True, 59 | n_residuals_per_layer=1, # Change to 2 if we have MLP 60 | ): 61 | if isinstance(module, nn.Linear): 62 | if module.bias is not None: 63 | if not getattr(module.bias, "_no_reinit", False): 64 | nn.init.zeros_(module.bias) 65 | elif isinstance(module, nn.Embedding): 66 | nn.init.normal_(module.weight, std=initializer_range) 67 | 68 | if rescale_prenorm_residual: 69 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 70 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 71 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 72 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 73 | # 74 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 75 | for name, p in module.named_parameters(): 76 | if name in ["out_proj.weight", "fc2.weight"]: 77 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 78 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 79 | # We need to reinit p since this code could be called multiple times 80 | # Having just p *= scale would repeatedly scale it down 81 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 82 | with torch.no_grad(): 83 | p /= math.sqrt(n_residuals_per_layer * n_layer) 84 | 85 | 86 | class MixerModel(nn.Module): 87 | def __init__( 88 | self, 89 | d_model: int, 90 | n_layer: int, 91 | vocab_size: int, 92 | ssm_cfg=None, 93 | norm_epsilon: float = 1e-5, 94 | rms_norm: bool = False, 95 | initializer_cfg=None, 96 | fused_add_norm=False, 97 | residual_in_fp32=False, 98 | device=None, 99 | dtype=None, 100 | ) -> None: 101 | factory_kwargs = {"device": device, "dtype": dtype} 102 | super().__init__() 103 | self.residual_in_fp32 = residual_in_fp32 104 | 105 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) 106 | 107 | # We change 
the order of residual and layer norm: 108 | # Instead of LN -> Attn / MLP -> Add, we do: 109 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 110 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 111 | # This is for performance reason: we can fuse add + layer_norm. 112 | self.fused_add_norm = fused_add_norm 113 | if self.fused_add_norm: 114 | if layer_norm_fn is None or rms_norm_fn is None: 115 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 116 | 117 | self.layers = nn.ModuleList( 118 | [ 119 | create_block( 120 | d_model, 121 | ssm_cfg=ssm_cfg, 122 | norm_epsilon=norm_epsilon, 123 | rms_norm=rms_norm, 124 | residual_in_fp32=residual_in_fp32, 125 | fused_add_norm=fused_add_norm, 126 | layer_idx=i, 127 | **factory_kwargs, 128 | ) 129 | for i in range(n_layer) 130 | ] 131 | ) 132 | 133 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 134 | d_model, eps=norm_epsilon, **factory_kwargs 135 | ) 136 | 137 | self.apply( 138 | partial( 139 | _init_weights, 140 | n_layer=n_layer, 141 | **(initializer_cfg if initializer_cfg is not None else {}), 142 | ) 143 | ) 144 | 145 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 146 | return { 147 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 148 | for i, layer in enumerate(self.layers) 149 | } 150 | 151 | def forward(self, input_ids, inference_params=None): 152 | hidden_states = self.embedding(input_ids) 153 | residual = None 154 | for layer in self.layers: 155 | hidden_states, residual = layer( 156 | hidden_states, residual, inference_params=inference_params 157 | ) 158 | if not self.fused_add_norm: 159 | residual = (hidden_states + residual) if residual is not None else hidden_states 160 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 161 | else: 162 | # Set prenorm=False here since we don't need the residual 163 | 
class MambaLMHeadModel(nn.Module, GenerationMixin):
    """``MixerModel`` backbone followed by a language-model head.

    The head is a bias-free ``nn.Linear`` whose weight is tied to the
    backbone's token embedding.
    """

    # Created once so every ``forward`` call returns the same result type
    # instead of manufacturing a brand-new namedtuple class per call (which is
    # costly during token-by-token generation).
    _CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        """Build the backbone and the tied LM head described by ``config``.

        Args:
            config: model hyper-parameters.
            initializer_cfg: optional dict of keyword overrides for
                ``_init_weights``.
            device, dtype: forwarded to every parameter-creating layer.
        """
        # Initialise nn.Module's bookkeeping before assigning any attribute.
        super().__init__()
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        # Round the vocabulary up to the next multiple of the padding unit.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Make the LM head share its weight tensor with the token embedding."""
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone."""
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens

        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        return self._CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Build a model from a pretrained config and load its weights.

        Extra ``kwargs`` are passed to the constructor.
        """
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f)
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, dim, dstate,
    # Strides
    stride_state_batch, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_dim,
    stride_dt_batch, stride_dt_dim,
    stride_dt_bias_dim,
    stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_dstate,
    stride_C_batch, stride_C_dstate,
    stride_D_dim,
    stride_z_batch, stride_z_dim,
    stride_out_batch, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    """Single-step selective-scan state update, written in place into ``state``.

    Program axis 0 selects a tile of ``BLOCK_SIZE_M`` rows along ``dim``;
    program axis 1 selects the batch element. For its tile each program does:

        dt    = dt (+ dt_bias) (softplus if DT_SOFTPLUS)
        state = state * exp(A * dt) + (B * dt) * x      (stored back)
        out   = sum_n(state * C) (+ x * D) (* z * sigmoid(z))

    ``HAS_*`` and ``BLOCK_SIZE_DSTATE`` are filled in by the heuristics above
    from the pointer arguments and ``dstate``.
    """
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    # Advance every per-batch pointer to this program's batch element.
    state_ptr += pid_b * stride_state_batch
    x_ptr += pid_b * stride_x_batch
    dt_ptr += pid_b * stride_dt_batch
    B_ptr += pid_b * stride_B_batch
    C_ptr += pid_b * stride_C_batch
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch
    out_ptr += pid_b * stride_out_batch

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # Masks guard the tail tile along ``dim`` and the padding introduced by
    # rounding ``dstate`` up to a power of two.
    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if DT_SOFTPLUS:
        # Softplus with the same cut-off as F.softplus (threshold=20), which
        # the reference implementation uses: above it softplus(dt) == dt to
        # float precision, and the naive log(1 + exp(dt)) would overflow to
        # +inf for large dt.
        dt = tl.where(dt <= 20.0, tl.log(1.0 + tl.exp(dt)), dt)
    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    dA = tl.exp(A * dt[:, None])
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None]
    state = state * dA + dB * x[:, None]
    # The recurrent state is updated in place; this is why the launcher avoids
    # autotuning (repeated trial runs would overwrite the state).
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        # SiLU gate: z * sigmoid(z)
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
(batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | 
def load_config_hf(model_name):
    """Return the parsed JSON config of ``model_name`` resolved through ``cached_file``."""
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    # Context manager so the file handle is closed deterministically.
    with open(resolved_archive_file, encoding="utf-8") as f:
        return json.load(f)


def load_state_dict_hf(model_name, device=None, dtype=None):
    """Load the weights of ``model_name`` and return them as a state dict.

    Args:
        model_name: identifier resolved through ``cached_file``.
        device: target device of the returned tensors (``None`` leaves them
            where ``torch.load`` put them).
        dtype: if given, every tensor is cast to this dtype.

    Returns:
        dict mapping parameter names to tensors, converted to ``dtype`` and
        moved to ``device``.
    """
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    # Bind the result instead of returning it: the early return used to make
    # the conversion below unreachable, so callers received unconverted CPU
    # tensors whenever a non-fp32 dtype was requested.
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    Raises ValueError for any ``sys.platform`` other than Linux, macOS or
    Windows.
    """
    current = sys.platform
    if current.startswith("linux"):
        return "linux_x86_64"
    if current == "win32":
        return "win_amd64"
    if current == "darwin":
        # Keep only "major.minor" of the macOS release string.
        release_parts = platform.mac_ver()[0].split(".")
        mac_version = ".".join(release_parts[:2])
        return f"macosx_{mac_version}_x86_64"
    raise ValueError("Unsupported platform: {}".format(current))
def append_nvcc_threads(nvcc_extra_args, threads=4):
    """Return a new list: ``nvcc_extra_args`` followed by nvcc's ``--threads`` option.

    Args:
        nvcc_extra_args: list of nvcc command-line flags; it is not modified.
        threads: value passed to ``--threads``. Defaults to 4, the value that
            used to be hard-coded, so existing callers are unaffected.
    """
    return nvcc_extra_args + ["--threads", str(threads)]
def get_package_version():
    """Return the package version string used for this build.

    The public version is read from the ``__version__`` assignment in the
    package's ``__init__.py``. When the ``MAMBA_LOCAL_VERSION`` environment
    variable is set to a non-empty value it is appended as a ``+local`` suffix.
    """
    init_path = Path(this_dir) / PACKAGE_NAME / "__init__.py"
    with open(init_path, "r") as handle:
        init_source = handle.read()
    found = re.search(r"^__version__\s*=\s*(.*)$", init_source, re.MULTILINE)
    # The right-hand side is a Python literal (a quoted string); evaluate it
    # safely rather than importing the package.
    public_version = ast.literal_eval(found.group(1))
    local_version = os.environ.get("MAMBA_LOCAL_VERSION")
    return f"{public_version}+{local_version}" if local_version else str(public_version)
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        """Download a matching prebuilt wheel, or build from source if that fails.

        A source build happens when ``FORCE_BUILD`` is set or when the wheel
        cannot be fetched (HTTP error, or any other URL/network failure).
        """
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except (urllib.error.HTTPError, urllib.error.URLError):
            # URLError covers DNS/connection failures (offline or firewalled
            # builds), which previously aborted the install instead of falling
            # back to compilation.
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()
            return

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        shutil.move(wheel_filename, wheel_path)
# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
@pytest.mark.parametrize('wtype', [torch.float32])
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('itype', [torch.float32])
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
# @pytest.mark.parametrize('seqlen', [128])
# @pytest.mark.parametrize("return_last_state", [False, True])
@pytest.mark.parametrize("return_last_state", [True])
# @pytest.mark.parametrize('has_delta_bias', [False, True])
@pytest.mark.parametrize('has_delta_bias', [True])
# @pytest.mark.parametrize('delta_softplus', [False, True])
@pytest.mark.parametrize('delta_softplus', [True])
# @pytest.mark.parametrize('has_z', [False, True])
@pytest.mark.parametrize('has_z', [True])
# @pytest.mark.parametrize('has_D', [False, True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
# @pytest.mark.parametrize("varBC_groups", [1])
# @pytest.mark.parametrize("is_variable_C", [False, True])
@pytest.mark.parametrize("is_variable_C", [True])
# @pytest.mark.parametrize("is_variable_B", [False, True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="selective scan tests allocate on 'cuda'")
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
                        delta_softplus, return_last_state, seqlen, itype, wtype):
    """Compare ``selective_scan_fn`` with ``selective_scan_ref`` on outputs and gradients.

    Both implementations receive identical (cloned) inputs. The forward
    output, the optional last state and the gradients of every input are
    compared with ``torch.allclose`` under dtype-dependent tolerances.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = 'cuda'
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    rtolw, atolw = (1e-3, 1e-3)
    if has_z:  # If we have z, the errors on the weights seem higher
        rtolw = max(rtolw, rtol)
        atolw = max(atolw, atol)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    dim = 4
    dstate = 8
    is_complex = wtype == torch.complex64
    # The tensor creation order below is kept fixed so the seeded random
    # stream yields the same inputs as before.
    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
    if not is_variable_B:
        B_shape = (dim, dstate)
    elif varBC_groups == 1:
        B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
    else:
        B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
    B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
                    requires_grad=True)
    if not is_variable_C:
        C_shape = (dim, dstate)
    elif varBC_groups == 1:
        C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
    else:
        C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
    C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
                    requires_grad=True)
    if has_D:
        D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
    else:
        D = None
    if has_z:
        z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
    else:
        z = None
    if has_delta_bias:
        delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
    else:
        delta_bias = None
    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
    delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
    # Independent leaf copies so each implementation accumulates its own grads.
    A_ref = A.detach().clone().requires_grad_()
    B_ref = B.detach().clone().requires_grad_()
    C_ref = C.detach().clone().requires_grad_()
    D_ref = D.detach().clone().requires_grad_() if D is not None else None
    z_ref = z.detach().clone().requires_grad_() if z is not None else None
    u_ref = u.detach().clone().requires_grad_()
    delta_ref = delta.detach().clone().requires_grad_()
    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
    out, *rest = selective_scan_fn(
        u, delta, A, B, C, D, z=z,
        delta_bias=delta_bias, delta_softplus=delta_softplus,
        return_last_state=return_last_state
    )
    if return_last_state:
        state = rest[0]
    out_ref, *rest = selective_scan_ref(
        u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
        delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
        return_last_state=return_last_state
    )
    if return_last_state:
        state_ref = rest[0]

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    if return_last_state:
        print(f'State max diff: {(state - state_ref).abs().max().item()}')
        assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)

    # Same upstream gradient for both graphs.
    g = torch.randn_like(out)
    out_ref.backward(g)
    out.backward(g)

    print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
    print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
    print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
    print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
    if has_D:
        print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
    if has_z:
        print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
    if has_delta_bias:
        print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')

    assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
    assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
    assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
    assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
                          atol=atolw if not is_variable_B else atol)
    assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
                          atol=atolw if not is_variable_C else atol)
    if has_D:
        assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
    if has_z:
        assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
    if has_delta_bias:
        assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)


@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
# @pytest.mark.parametrize('wtype', [torch.complex64])
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('itype', [torch.float32])
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
@pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("is_variable_C", [False, True])
# @pytest.mark.parametrize("is_variable_C", [False])
@pytest.mark.parametrize("is_variable_B", [False, True])
# @pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="mamba inner tests allocate on 'cuda'")
def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
    """Compare ``mamba_inner_fn`` with ``mamba_inner_ref`` on the forward output.

    Gradients of both implementations are computed and their max differences
    printed, but only the forward output is asserted; the gradient
    assertions at the end are disabled.
    """
    device = 'cuda'
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    rtolw, atolw = (1e-3, 1e-3)
    # If we have z, the errors on the weights seem higher
    rtolw = max(rtolw, rtol)
    atolw = max(atolw, atol)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    dim = 768
    dstate = 8
    dt_rank = 48
    is_complex = wtype == torch.complex64
    xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
    conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
    conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
    # x_proj produces dt_rank rows plus dstate rows for each of B / C that is
    # input-dependent (doubled for complex weights).
    x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
                                * (1 if not is_complex else 2),
                                dim, device=device, dtype=itype, requires_grad=True)
    delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
    out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
    out_proj_bias = None
    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
    B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
         if not is_variable_B else None)
    C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
         if not is_variable_C else None)
    D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
    B_proj_bias = None
    C_proj_bias = None
    xz_ref = xz.detach().clone().requires_grad_()
    conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
    conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
    x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
    delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
    out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
    out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
                         if out_proj_bias is not None else None)
    A_ref = A.detach().clone().requires_grad_()
    B_ref = B.detach().clone().requires_grad_() if B is not None else None
    C_ref = C.detach().clone().requires_grad_() if C is not None else None
    D_ref = D.detach().clone().requires_grad_()
    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
    out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                         out_proj_weight, out_proj_bias,
                         A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
    out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
                              delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
                              A_ref, B_ref, C_ref, D_ref,
                              delta_bias=delta_bias_ref, delta_softplus=True)

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    g = torch.randn_like(out)
    out_ref.backward(g)
    out.backward(g)

    print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
    if not is_variable_B:
        print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
    if not is_variable_C:
        print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
    print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
    print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
    print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
    print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
    print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
    print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
    print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')

    # NOTE: gradient assertions are intentionally left disabled, as in the
    # original test; the diffs above are informational only.
    # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
    # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
    # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
    # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
    #                       atol=atolw if not is_variable_B else atol)
    # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
    #                       atol=atolw if not is_variable_C else atol)
    # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
    # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.parametrize('itype', [torch.float16])
@pytest.mark.parametrize("has_z", [False, True])
# @pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
# @pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# @pytest.mark.parametrize("dim", [2048])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="selective_state_update test allocates on 'cuda'")
def test_causal_conv1d_update(dim, dstate, has_z, itype):
    """Compare ``selective_state_update`` with ``selective_state_update_ref``.

    Despite its name, this test exercises the selective state update, not a
    causal conv1d; the name is kept so existing test ids stay stable. Both
    the returned output and the ``state`` tensor passed in are compared.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    if has_z:
        z = torch.randn_like(x)
    else:
        z = None
    # Separate copy so the reference call starts from the same initial state.
    state_ref = state.detach().clone()
    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image with probability ``p``.

    The blur radius is drawn uniformly from ``[radius_min, radius_max]``.
    """
    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        img = img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )
        return img


class Solarization(object):
    """
    Apply Solarization to the PIL image with probability ``p``.
    """
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class gray_scale(object):
    """
    Convert the image to 3-channel grayscale with probability ``p``.
    """
    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img


class horizontal_flip(object):
    """
    Flip the image horizontally with probability ``p``.

    ``activate_pred`` is accepted for signature compatibility but is not used.
    """
    def __init__(self, p=0.2,activate_pred=False):
        self.p = p
        # The wrapped transform always flips; the probability is applied in __call__.
        self.transf = transforms.RandomHorizontalFlip(p=1.0)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img


def new_data_aug_generator(args = None):
    """Build the 3-Augment training transform pipeline.

    Reads ``args.input_size``, ``args.src`` and ``args.color_jitter``.

    The pipeline is: a crop stage (resize + reflect-padded random crop when
    ``args.src`` is truthy, otherwise a random resized crop) followed by a
    random horizontal flip; one of grayscale / solarization / Gaussian blur
    chosen at random; optional colour jitter; tensor conversion and
    normalisation with the mean/std constants defined below.

    Raises:
        ValueError: if ``args`` is None.
    """
    if args is None:
        # The default exists only for signature compatibility; every field is
        # read from args, so fail clearly instead of with AttributeError.
        raise ValueError("new_data_aug_generator requires an args namespace")

    img_size = args.input_size
    remove_random_resized_crop = args.src
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    scale=(0.08, 1.0)
    interpolation='bicubic'
    if remove_random_resized_crop:
        primary_tfl = [
            transforms.Resize(img_size, interpolation=3),
            transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
            transforms.RandomHorizontalFlip()
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=scale, interpolation=interpolation),
            transforms.RandomHorizontalFlip()
        ]

    # Exactly one of the three augmentations is applied per sample (each with p=1).
    secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
                                              Solarization(p=1.0),
                                              GaussianBlur(p=1.0)])]

    if args.color_jitter is not None and not args.color_jitter==0:
        secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
    final_tfl = [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ]
    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
class INatDataset(ImageFolder):
    """iNaturalist dataset indexed from the JSON annotation files under ``root``.

    Class indices are assigned in order of first appearance of each
    ``category`` value among the *training* annotations, so the train and
    validation splits share one label mapping. ``nb_classes`` holds the
    number of distinct labels and ``samples`` the ``(path, target)`` pairs.
    """
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        # ImageFolder.__init__ is deliberately not called: samples are built
        # from the annotation files rather than by scanning directories.
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        if train:
            # The split file is already the train file; avoid parsing it twice.
            data_for_targeter = data
        else:
            path_json_for_targeter = os.path.join(root, f"train{year}.json")
            with open(path_json_for_targeter) as json_file:
                data_for_targeter = json.load(json_file)

        # Map each category label to a dense index, in first-seen order.
        targeter = {}
        for elem in data_for_targeter['annotations']:
            label = data_catg[int(elem['category_id'])][category]
            targeter.setdefault(label, len(targeter))
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            # The third path component of file_name is parsed as the numeric
            # category id; the path is rebuilt without the second component.
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    """Create the dataset selected by ``args.data_set``.

    Supported values are ``'CIFAR'``, ``'IMNET'``, ``'INAT'`` and ``'INAT19'``.

    Returns:
        tuple: ``(dataset, nb_classes)``.

    Raises:
        ValueError: if ``args.data_set`` is not one of the supported values.
    """
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:
        # Previously fell through to an UnboundLocalError on `dataset`.
        raise ValueError(f"Unsupported data_set: {args.data_set!r}")

    return dataset, nb_classes


def build_transform(is_train, args):
    """Build the input transform for training or evaluation.

    Training uses timm's ``create_transform`` configured from ``args``; for
    inputs of 32 pixels or less its first transform is replaced by a padded
    ``RandomCrop``. Evaluation resizes to ``input_size / eval_crop_ratio``
    and center-crops (only when ``input_size > 32``), then converts to a
    tensor and normalises with the ImageNet default mean/std.
    """
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int(args.input_size / args.eval_crop_ratio)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Run one training epoch and return the averaged logged metrics.

    ``args`` must provide ``cosub``, ``bce_loss``,
    ``if_random_cls_token_position``, ``if_random_token_rank``,
    ``if_nan2num`` and ``if_continue_inf``. With ``args.cosub`` the passed
    ``criterion`` is replaced by ``BCEWithLogitsLoss`` and every batch is
    duplicated so the two halves of the output can supervise each other.

    Returns:
        dict: meter name -> global average over the epoch.

    Raises:
        ValueError: if ``args`` is None.
        SystemExit: on a non-finite loss when ``args.if_continue_inf`` is false.
    """
    if args is None:
        # Every option is read from args; fail clearly instead of with
        # AttributeError on the first attribute access.
        raise ValueError("train_one_epoch requires an args namespace")

    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if args.cosub:
        criterion = torch.nn.BCEWithLogitsLoss()

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if args.cosub:
            samples = torch.cat((samples,samples),dim=0)

        if args.bce_loss:
            # Binarise the (possibly mixed) targets.
            targets = targets.gt(0.0).type(targets.dtype)

        with amp_autocast():
            outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
            if not args.cosub:
                loss = criterion(samples, outputs, targets)
            else:
                # Each half is trained on the targets and on the other
                # half's detached sigmoid predictions, equally weighted.
                outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
                loss = 0.25 * criterion(outputs[0], targets)
                loss = loss + 0.25 * criterion(outputs[1], targets)
                loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

        if args.if_nan2num:
            with amp_autocast():
                loss = torch.nan_to_num(loss)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            if args.if_continue_inf:
                # Skip this batch and keep training.
                optimizer.zero_grad()
                continue
            else:
                sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        if isinstance(loss_scaler, timm.utils.NativeScaler):
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=is_second_order)
        else:
            loss.backward()
            # NOTE(review): the default max_norm=0 is not None, so clipping
            # runs with a zero norm unless the caller passes None - confirm
            # this is intended.
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        # Only synchronise when CUDA is present so CPU runs do not fail here.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, amp_autocast):
    """Evaluate ``model`` on ``data_loader`` and return averaged metrics.

    Puts the model in eval mode, computes cross-entropy loss and top-1 /
    top-5 accuracy per batch (accuracies weighted by batch size), prints a
    summary line and returns a dict of meter name -> global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with amp_autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
3 | from models import * 4 | from cait_models import * 5 | from resmlp_models import * 6 | #from patchconvnet_models import * 7 | 8 | dependencies = ["torch", "torchvision", "timm"] 9 | -------------------------------------------------------------------------------- /vim/images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/1.jpg -------------------------------------------------------------------------------- /vim/images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/2.jpg -------------------------------------------------------------------------------- /vim/images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/3.jpg -------------------------------------------------------------------------------- /vim/images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/4.jpg -------------------------------------------------------------------------------- /vim/images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/5.jpg -------------------------------------------------------------------------------- /vim/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
class DistillationLoss(torch.nn.Module):
    """
    Combine a base criterion with an optional knowledge-distillation term.

    The teacher model is run on the raw inputs (without gradient tracking) and
    its predictions supervise the student's distillation output.

    Args:
        base_criterion: loss applied to the student's main output and the labels.
        teacher_model: module called on ``inputs`` to produce teacher logits.
        distillation_type: one of ``'none'``, ``'soft'`` or ``'hard'``.
        alpha: weight of the distillation term; the base loss gets ``1 - alpha``.
        tau: temperature used by the ``'soft'`` variant.
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        assert distillation_type in ['none', 'soft', 'hard']
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Compute the (optionally distilled) training loss.

        Args:
            inputs: the original inputs that are fed to the teacher model.
            outputs: the outputs of the model being trained. Either a Tensor, or
                a pair ``(outputs, outputs_kd)`` with the regular prediction first
                and the distillation prediction second.
            labels: the labels for the base criterion.

        Returns:
            The base loss when distillation is ``'none'``; otherwise
            ``base_loss * (1 - alpha) + distillation_loss * alpha``.

        Raises:
            ValueError: if distillation is enabled but ``outputs`` is a single Tensor.
        """
        if isinstance(outputs, torch.Tensor):
            student_logits, distill_logits = outputs, None
        else:
            # assume that the model outputs a pair of [outputs, outputs_kd]
            student_logits, distill_logits = outputs

        supervised_loss = self.base_criterion(student_logits, labels)
        mode = self.distillation_type
        if mode == 'none':
            return supervised_loss

        if distill_logits is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")

        # don't backprop through the teacher
        with torch.no_grad():
            teacher_logits = self.teacher_model(inputs)

        if mode == 'soft':
            temperature = self.tau
            student_log_probs = F.log_softmax(distill_logits / temperature, dim=1)
            # The teacher targets are given as log-probabilities because
            # log_target=True is used below.
            teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=1)
            summed_kl = F.kl_div(
                student_log_probs,
                teacher_log_probs,
                reduction='sum',
                log_target=True,
            )
            # Dividing by numel() (rather than the batch size) reproduces the
            # legacy PyTorch averaging behaviour.
            distillation_loss = summed_kl * (temperature * temperature) / distill_logits.numel()
        elif mode == 'hard':
            # The teacher's arg-max class acts as a hard label.
            distillation_loss = F.cross_entropy(distill_logits, teacher_logits.argmax(dim=1))

        return supervised_loss * (1 - self.alpha) + distillation_loss * self.alpha
def rotate_half(x):
    """
    Rotate adjacent feature pairs of the last dimension by 90 degrees.

    The last dimension is split into consecutive pairs ``(x1, x2)`` and every
    pair is replaced by ``(-x2, x1)``; the result has the same shape as ``x``.
    """
    paired = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = paired.unbind(dim = -1)
    rotated = torch.stack((-second, first), dim = -1)
    return rearrange(rotated, '... d r -> ... (d r)')
class VisionRotaryEmbeddingFast(nn.Module):
    """
    Rotary position embedding with pre-computed, flattened cos/sin tables.

    The cos/sin tables are built once in ``__init__`` for a square grid of
    ``ft_seq_len x ft_seq_len`` positions and stored as the buffers
    ``freqs_cos`` / ``freqs_sin`` (flattened to two dimensions).

    Args:
        dim: number of rotary features per axis used to build the frequencies.
        pt_seq_len: reference sequence length the positions are rescaled to.
        ft_seq_len: number of positions per axis; defaults to ``pt_seq_len``.
        custom_freqs: optional user-supplied frequencies; when given they are
            used as-is and ``freqs_for`` is ignored.
        freqs_for: ``'lang'``, ``'pixel'`` or ``'constant'`` frequency schedule.
        theta: base of the ``'lang'`` schedule.
        max_freq: upper bound used by the ``'pixel'`` schedule.
        num_freqs: number of frequencies for the ``'constant'`` schedule.

    Raises:
        ValueError: if ``freqs_for`` is not one of the supported schedules.
    """

    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    ):
        super().__init__()
        # Compare against None explicitly: `if custom_freqs:` raises
        # "Boolean value of Tensor with more than one value is ambiguous"
        # as soon as a tensor with several frequencies is passed.
        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions 0..ft_seq_len-1 rescaled into the pt_seq_len range.
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        # Outer product position x frequency, each frequency duplicated so it
        # lines up with the feature pairs handled by rotate_half.
        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
        # Combine the per-row and per-column tables along the feature axis.
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

        # Flatten the two grid axes into a single leading axis.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        """
        Apply the rotary embedding to ``t``.

        When ``t.shape[1]`` is odd, the first entry along dim 1 is passed
        through unchanged and only the remaining entries are rotated
        (presumably a leading class token -- TODO confirm with the caller).
        Otherwise every entry is rotated.
        """
        if t.shape[1] % 2 != 0:
            t_spatial = t[:, 1:, :]
            t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin
            return torch.cat((t[:, :1, :], t_spatial), dim=1)
        else:
            return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
def get_shared_folder() -> Path:
    """
    Return the per-user experiments folder under ``/checkpoint``, creating it.

    Returns:
        ``/checkpoint/<USER>/experiments`` as a ``Path``; the directory (and any
        missing parent such as ``/checkpoint/<USER>``) is created on demand.

    Raises:
        RuntimeError: if ``/checkpoint/`` is not an existing directory.
    """
    if not Path("/checkpoint/").is_dir():
        raise RuntimeError("No shared folder available")

    # NOTE(review): if USER is unset this yields "/checkpoint/None/experiments";
    # kept as in the original -- confirm whether a fallback is wanted.
    user = os.getenv("USER")
    p = Path(f"/checkpoint/{user}/experiments")
    # parents=True: a first-time user has no /checkpoint/<user> directory yet,
    # and mkdir without it fails with FileNotFoundError.
    p.mkdir(parents=True, exist_ok=True)
    return p
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation.

    Every dataset index is repeated ``num_repeats`` times and the repeated
    list is dealt out round-robin to the replicas, so the repeats of one sample
    land on different processes (GPUs). Each replica then yields only the
    first ``num_selected_samples`` indices of its share.
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
        unavailable_msg = "Requires distributed package to be available"
        # Fall back to the default process group for whatever was not given.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(unavailable_msg)
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(unavailable_msg)
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError("num_repeats should be greater than 0")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_repeats = num_repeats
        self.epoch = 0
        self.shuffle = shuffle

        dataset_len = len(self.dataset)
        # Per-replica share of the repeated index list, rounded up.
        self.num_samples = int(math.ceil(dataset_len * self.num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of indices actually yielded per replica: the dataset length
        # rounded down to a multiple of 256, divided among the replicas.
        self.num_selected_samples = int(math.floor(dataset_len // 256 * 256 / self.num_replicas))

    def __iter__(self):
        dataset_len = len(self.dataset)
        if self.shuffle:
            # deterministically shuffle based on epoch
            rng = torch.Generator()
            rng.manual_seed(self.epoch)
            order = torch.randperm(dataset_len, generator=rng)
        else:
            order = torch.arange(start=0, end=dataset_len)

        # repeat every index, then pad so the list divides evenly
        repeated = torch.repeat_interleave(order, repeats=self.num_repeats, dim=0).tolist()
        shortfall: int = self.total_size - len(repeated)
        if shortfall > 0:
            repeated += repeated[:shortfall]
        assert len(repeated) == self.total_size

        # round-robin subsample for this rank
        share = repeated[self.rank:self.total_size:self.num_replicas]
        assert len(share) == self.num_samples

        return iter(share[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle in ``__iter__``."""
        self.epoch = epoch
class SmoothedValue:
    """Track a stream of values and expose both windowed and global statistics.

    The most recent ``window_size`` values are kept in a deque for
    ``median`` / ``avg`` / ``max`` / ``value``; ``total`` and ``count``
    accumulate over the whole series for ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(packed)
            synced_count, synced_total = packed.tolist()
            self.count = int(synced_count)
            self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update seen so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, 
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and its process group is initialized."""
    return dist.is_available() and dist.is_initialized()
def interpolate_pos_embed(model, state_dict):
    """
    Resize the checkpoint's position embedding to the model's patch grid.

    ``state_dict['pos_embed']`` is read directly, so the key must be present.
    When the checkpoint grid side differs from the model's, the grid tokens are
    resampled with bicubic interpolation and ``state_dict['pos_embed']`` is
    replaced in place; the leading extra tokens are kept unchanged.
    Both grids are treated as square (side = int(sqrt(number of tokens))).

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        state_dict: checkpoint mapping that is modified in place.
    """
    ckpt_embed = state_dict['pos_embed']
    embed_dim = ckpt_embed.shape[-1]
    patch_count = model.patch_embed.num_patches
    # Tokens in the model's embedding that are not patches.
    extra_count = model.pos_embed.shape[-2] - patch_count
    # height (== width) of the checkpoint grid and of the target grid
    old_side = int((ckpt_embed.shape[-2] - extra_count) ** 0.5)
    new_side = int(patch_count ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    kept_tokens = ckpt_embed[:, :extra_count]
    # only the position tokens are interpolated
    grid = ckpt_embed[:, extra_count:].reshape(-1, old_side, old_side, embed_dim)
    grid = grid.permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
    grid_tokens = grid.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict['pos_embed'] = torch.cat((kept_tokens, grid_tokens), dim=1)
# causal-conv1d @ file:///home/zhulianghui/VisionProjects/mamba/lib/causal_conv1d-1.0.0%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=79a4bab633ebff031e615d5e8ba396b0dc0c046f4406980ee238fb86a9090038 9 | certifi==2023.11.17 10 | charset-normalizer==3.3.2 11 | click==8.1.7 12 | cloudpickle==3.0.0 13 | contourpy==1.2.0 14 | cycler==0.12.1 15 | databricks-cli==0.18.0 16 | datasets==2.15.0 17 | dill==0.3.7 18 | docker==6.1.3 19 | einops==0.7.0 20 | entrypoints==0.4 21 | filelock==3.13.1 22 | Flask==3.0.0 23 | fonttools==4.46.0 24 | frozenlist==1.4.0 25 | fsspec==2023.10.0 26 | gitdb==4.0.11 27 | GitPython==3.1.40 28 | greenlet==3.0.2 29 | gunicorn==21.2.0 30 | huggingface-hub==0.19.4 31 | idna==3.6 32 | importlib-metadata==7.0.0 33 | itsdangerous==2.1.2 34 | Jinja2==3.1.2 35 | joblib==1.3.2 36 | kiwisolver==1.4.5 37 | Mako==1.3.0 38 | # mamba-ssm @ file:///home/zhulianghui/VisionProjects/mamba/lib/mamba_ssm-1.0.1%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=71ad1b1eafb05a6e8a41fd82e046fe85511d6378fa3a583e55215b6aa1d65ab9 39 | Markdown==3.5.1 40 | MarkupSafe==2.1.3 41 | matplotlib==3.8.2 42 | mlflow==2.9.1 43 | mmcv==1.3.8 44 | mmsegmentation==0.14.1 45 | mpmath==1.3.0 46 | multidict==6.0.4 47 | multiprocess==0.70.15 48 | networkx==3.2.1 49 | ninja==1.11.1.1 50 | numpy==1.26.2 51 | # nvidia-cublas-cu12==12.1.3.1 52 | # nvidia-cuda-cupti-cu12==12.1.105 53 | # nvidia-cuda-nvrtc-cu12==12.1.105 54 | # nvidia-cuda-runtime-cu12==12.1.105 55 | # nvidia-cudnn-cu12==8.9.2.26 56 | # nvidia-cufft-cu12==11.0.2.54 57 | # nvidia-curand-cu12==10.3.2.106 58 | # nvidia-cusolver-cu12==11.4.5.107 59 | # nvidia-cusparse-cu12==12.1.0.106 60 | # nvidia-nccl-cu12==2.18.1 61 | # nvidia-nvjitlink-cu12==12.3.101 62 | # nvidia-nvtx-cu12==12.1.105 63 | oauthlib==3.2.2 64 | opencv-python==4.8.1.78 65 | packaging==23.2 66 | pandas==2.1.3 67 | Pillow==10.1.0 68 | platformdirs==4.1.0 69 | prettytable==3.9.0 70 | protobuf==4.25.1 71 | pyarrow==14.0.1 72 | 
pyarrow-hotfix==0.6 73 | PyJWT==2.8.0 74 | pyparsing==3.1.1 75 | python-dateutil==2.8.2 76 | python-hostlist==1.23.0 77 | pytz==2023.3.post1 78 | PyYAML==6.0.1 79 | querystring-parser==1.2.4 80 | regex==2023.10.3 81 | requests==2.31.0 82 | safetensors==0.4.1 83 | scikit-learn==1.3.2 84 | scipy==1.11.4 85 | six==1.16.0 86 | smmap==5.0.1 87 | SQLAlchemy==2.0.23 88 | sqlparse==0.4.4 89 | sympy==1.12 90 | tabulate==0.9.0 91 | threadpoolctl==3.2.0 92 | timm==0.4.12 93 | tokenizers==0.15.0 94 | tomli==2.0.1 95 | # torch==2.1.1+cu118 96 | # torchvision==0.16.1+cu118 97 | tqdm==4.66.1 98 | transformers==4.35.2 99 | triton==2.1.0 100 | typing_extensions==4.8.0 101 | tzdata==2023.3 102 | urllib3==2.1.0 103 | wcwidth==0.2.12 104 | websocket-client==1.7.0 105 | Werkzeug==3.0.1 106 | xxhash==3.4.1 107 | yapf==0.40.2 108 | yarl==1.9.4 109 | zipp==3.17.0 110 | -------------------------------------------------------------------------------- /vim/vmamba_xai.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/media/data1/ameenali/miniconda3/envs/github/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import numpy as np\n", 19 | "import torch\n", 20 | "import torch.backends.cudnn as cudnn\n", 21 | "from timm.models import create_model\n", 22 | "import models_mamba\n", 23 | "import utils\n", 24 | "import os\n", 25 | "from xai_utils import *\n", 26 | "from class_mapper import CLS2IDX\n", 27 | "import matplotlib.pyplot as plt" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "

Load Model

\n", 35 | "Make sure to specify the model checkpoint path" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "text/plain": [ 46 | "" 47 | ] 48 | }, 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "output_type": "execute_result" 52 | } 53 | ], 54 | "source": [ 55 | "model_type = 'vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2'\n", 56 | "model_path = './vim_s_midclstok_80p5acc.pth'\n", 57 | "num_classes = 1000\n", 58 | "model = create_model(\n", 59 | " model_type,\n", 60 | " pretrained=False,\n", 61 | " num_classes=num_classes,\n", 62 | " drop_rate=0,\n", 63 | " drop_path_rate=0,\n", 64 | " drop_block_rate=None,\n", 65 | " img_size=224\n", 66 | ")\n", 67 | "checkpoint = torch.load(model_path, map_location='cpu')\n", 68 | "model.load_state_dict(checkpoint['model'])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "

Auxiliary Functions

" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "from PIL import Image\n", 85 | "import torchvision.transforms as transforms\n", 86 | "\n", 87 | "IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]\n", 88 | "IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]\n", 89 | "\n", 90 | "def transform_for_eval(image_path, input_size=224):\n", 91 | " transform_eval = transforms.Compose([\n", 92 | " transforms.Resize(int(input_size)),\n", 93 | " transforms.CenterCrop(input_size),\n", 94 | " transforms.ToTensor(),\n", 95 | " transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),\n", 96 | " ])\n", 97 | " img = Image.open(image_path).convert('RGB')\n", 98 | " transformed_img = transform_eval(img)\n", 99 | " return transformed_img\n", 100 | "\n", 101 | "import cv2\n", 102 | "\n", 103 | "invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],\n", 104 | " std = [ 1/0.229, 1/0.224, 1/0.225 ]),\n", 105 | " transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],\n", 106 | " std = [ 1., 1., 1. 
]),\n", 107 | " ])\n", 108 | "\n", 109 | "def show_cam_on_image(img, mask):\n", 110 | " heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)\n", 111 | " heatmap = np.float32(heatmap) / 255\n", 112 | " cam = heatmap + np.float32(img)\n", 113 | " cam = cam / np.max(cam)\n", 114 | " return cam\n", 115 | "\n", 116 | "\n", 117 | "def generate_visualization(original_image, transformer_attribution):\n", 118 | " transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)\n", 119 | " transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')\n", 120 | " transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()\n", 121 | " transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())\n", 122 | " image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()\n", 123 | " image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())\n", 124 | " vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)\n", 125 | " vis = np.uint8(255 * vis)\n", 126 | " vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)\n", 127 | " return vis\n", 128 | "\n", 129 | "def print_preds(logits):\n", 130 | " prob = torch.softmax(logits, dim=1)\n", 131 | " class_indices = logits.data.topk(5, dim=1)[1][0].tolist()\n", 132 | " max_str_len = 0\n", 133 | " class_names = []\n", 134 | " for cls_idx in class_indices:\n", 135 | " class_names.append(CLS2IDX[cls_idx])\n", 136 | " if len(CLS2IDX[cls_idx]) > max_str_len:\n", 137 | " max_str_len = len(CLS2IDX[cls_idx])\n", 138 | "\n", 139 | " print('Top 5 classes:')\n", 140 | " for cls_idx in class_indices:\n", 141 | " output_string = '\\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])\n", 142 | " 
output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\\t\\t'\n", 143 | " output_string += 'value = {:.3f}\\t prob = {:.1f}%'.format(logits[0, cls_idx], 100 * prob[0, cls_idx])\n", 144 | " print(output_string)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 4, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "model = model.cuda()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "image = transform_for_eval('./images/1.jpg').unsqueeze(0).cuda()\n", 163 | "raw_image = Image.open('./images/1.jpg')\n", 164 | "map_raw_atten, logits = generate_raw_attn(model, image)\n", 165 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n", 166 | "map_rollout, _ = generate_rollout(model, image)\n", 167 | "image = image.squeeze()\n", 168 | "\n", 169 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n", 170 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n", 171 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n", 172 | "print_preds(logits)\n", 173 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n", 174 | "axs[0].imshow(raw_image)\n", 175 | "axs[0].axis('off')\n", 176 | "axs[1].imshow(raw_attn)\n", 177 | "axs[1].axis('off')\n", 178 | "axs[2].imshow(rollout)\n", 179 | "axs[2].axis('off')\n", 180 | "axs[3].imshow(mamba_attr)\n", 181 | "axs[3].axis('off')\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# Attention Matrices:\n", 191 | "# Load an image and prepare it for model evaluation\n", 192 | "image = transform_for_eval('./images/1.jpg').unsqueeze(0).cuda()\n", 193 | "\n", 194 | "# Specify the layer and channel to analyze\n", 195 | "selected_layer = 6\n", 196 | "selected_channel = 30\n", 197 | "\n", 198 | "# Enable 
computation of attention matrices in the model\n", 199 | "model.layers[selected_layer].mixer.compute_attn_matrix = True\n", 200 | "# Pass the image through the model\n", 201 | "out = model(image)\n", 202 | "\n", 203 | "# Extract and normalize attention matrices\n", 204 | "attn_matrix_a = model.layers[selected_layer].mixer.attn_matrix_a.abs()\n", 205 | "attn_matrix_b = model.layers[selected_layer].mixer.attn_matrix_b.abs()\n", 206 | "normalize_attn_mat = lambda attn_mat : (attn_mat.abs() - torch.min(attn_mat.abs())) / (torch.max(attn_mat.abs()) - torch.min(attn_mat.abs()))\n", 207 | "attn_matrix_a_normalize = normalize_attn_mat(attn_matrix_a)\n", 208 | "attn_matrix_b_normalize = normalize_attn_mat(attn_matrix_b)\n", 209 | "\n", 210 | "# Plot each attention matrix\n", 211 | "fig, axs = plt.subplots(1, 6, figsize=(10,10))\n", 212 | "for i in range(3):\n", 213 | " axs[i].imshow(attn_matrix_a.cpu().detach().numpy()[0, selected_channel+i, :, :])\n", 214 | " axs[i].axis('off')\n", 215 | " axs[i+3].imshow(attn_matrix_b.cpu().detach().numpy()[0, selected_channel+i, :, :])\n", 216 | " axs[i+3].axis('off')\n", 217 | " model.layers[selected_layer].mixer.compute_attn_matrix = False" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "image = transform_for_eval('./images/2.jpg').unsqueeze(0).cuda()\n", 227 | "raw_image = Image.open('./images/2.jpg')\n", 228 | "map_raw_atten, logits = generate_raw_attn(model, image)\n", 229 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n", 230 | "map_rollout, _ = generate_rollout(model, image)\n", 231 | "image = image.squeeze()\n", 232 | "\n", 233 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n", 234 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n", 235 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n", 236 | 
"print_preds(logits)\n", 237 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n", 238 | "axs[0].imshow(raw_image)\n", 239 | "axs[0].axis('off')\n", 240 | "axs[1].imshow(raw_attn)\n", 241 | "axs[1].axis('off')\n", 242 | "axs[2].imshow(rollout)\n", 243 | "axs[2].axis('off')\n", 244 | "axs[3].imshow(mamba_attr)\n", 245 | "axs[3].axis('off')\n" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "image = transform_for_eval('./images/3.jpg').unsqueeze(0).cuda()\n", 255 | "raw_image = Image.open('./images/3.jpg')\n", 256 | "map_raw_atten, logits = generate_raw_attn(model, image)\n", 257 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n", 258 | "map_rollout, _ = generate_rollout(model, image)\n", 259 | "image = image.squeeze()\n", 260 | "\n", 261 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n", 262 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n", 263 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n", 264 | "print_preds(logits)\n", 265 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n", 266 | "axs[0].imshow(raw_image)\n", 267 | "axs[0].axis('off')\n", 268 | "axs[1].imshow(raw_attn)\n", 269 | "axs[1].axis('off')\n", 270 | "axs[2].imshow(rollout)\n", 271 | "axs[2].axis('off')\n", 272 | "axs[3].imshow(mamba_attr)\n", 273 | "axs[3].axis('off')\n" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "image = transform_for_eval('./images/4.jpg').unsqueeze(0).cuda()\n", 283 | "raw_image = Image.open('./images/4.jpg')\n", 284 | "map_raw_atten, logits = generate_raw_attn(model, image)\n", 285 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n", 286 | "map_rollout, _ = generate_rollout(model, image)\n", 287 | "image = image.squeeze()\n", 288 | "\n", 289 | 
"raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n", 290 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n", 291 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n", 292 | "print_preds(logits)\n", 293 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n", 294 | "axs[0].imshow(raw_image)\n", 295 | "axs[0].axis('off')\n", 296 | "axs[1].imshow(raw_attn)\n", 297 | "axs[1].axis('off')\n", 298 | "axs[2].imshow(rollout)\n", 299 | "axs[2].axis('off')\n", 300 | "axs[3].imshow(mamba_attr)\n", 301 | "axs[3].axis('off')\n" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "image = transform_for_eval('./images/5.jpg').unsqueeze(0).cuda()\n", 311 | "raw_image = Image.open('./images/5.jpg')\n", 312 | "map_raw_atten, logits = generate_raw_attn(model, image)\n", 313 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n", 314 | "map_rollout, _ = generate_rollout(model, image)\n", 315 | "image = image.squeeze()\n", 316 | "\n", 317 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n", 318 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n", 319 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n", 320 | "print_preds(logits)\n", 321 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n", 322 | "axs[0].imshow(raw_image)\n", 323 | "axs[0].axis('off')\n", 324 | "axs[1].imshow(raw_attn)\n", 325 | "axs[1].axis('off')\n", 326 | "axs[2].imshow(rollout)\n", 327 | "axs[2].axis('off')\n", 328 | "axs[3].imshow(mamba_attr)\n", 329 | "axs[3].axis('off')\n" 330 | ] 331 | } 332 | ], 333 | "metadata": { 334 | "kernelspec": { 335 | "display_name": "mamba", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | 
import torch
import numpy as np


def _min_max_normalize(tensor):
    """Linearly rescale ``tensor`` so its global min maps to 0 and its max to 1."""
    # NOTE(review): a constant tensor gives a zero denominator (NaN output).
    # This matches the previous inline formula and is intentionally unguarded.
    return (tensor - tensor.min()) / (tensor.max() - tensor.min())


def _mean_over_dim1(tensor):
    """Average ``tensor`` over dim 1 and detach the result from autograd."""
    return (tensor.sum(dim=1) / tensor.shape[1]).detach()


def _drop_cls_token(p, cls_pos):
    """Cut index ``cls_pos`` out of dim 1 of ``p`` and re-join along the last dim.

    For a 2-D ``p`` this simply removes column ``cls_pos``.
    """
    return torch.cat([p[:, :cls_pos], p[:, (cls_pos + 1):]], dim=-1)


def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Fuse per-layer attention matrices into one joint matrix (rollout).

    The identity is added to every matrix, each row is divided by its sum,
    and the normalised matrices are multiplied together with ``bmm`` from
    ``start_layer`` up to the last layer.

    Args:
        all_layer_matrices: Sequence of tensors, one per layer. Dims 0 and 1
            of the first entry are read as batch size and token count.
        start_layer: Index of the first layer that enters the product.

    Returns:
        ``M[last] @ ... @ M[start_layer + 1] @ M[start_layer]`` where ``M``
        are the residual-augmented, row-normalised matrices.
    """
    # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
    reference = all_layer_matrices[0]
    batch_size, num_tokens = reference.shape[0], reference.shape[1]
    eye = torch.eye(num_tokens, device=reference.device).expand(
        batch_size, num_tokens, num_tokens
    )

    with_residual = [matrix + eye for matrix in all_layer_matrices]
    matrices_aug = [
        matrix / matrix.sum(dim=-1, keepdim=True) for matrix in with_residual
    ]
    joint_attention = matrices_aug[start_layer]
    # Index-based on purpose: keeps the original semantics for negative or
    # out-of-range ``start_layer`` values (a slice would behave differently).
    for i in range(start_layer + 1, len(matrices_aug)):
        joint_attention = matrices_aug[i].bmm(joint_attention)
    return joint_attention


def generate_raw_attn(model, image, start_layer=15, cls_pos=98):
    """Build a relevance map by averaging the layers' normalised ``xai_b``.

    Runs a forward pass, min-max normalises every layer's ``mixer.xai_b``,
    averages it over dim 1, then averages the layers from ``start_layer``
    onward and removes the entry at ``cls_pos``.

    Args:
        model: Callable model exposing ``layers[i].mixer.xai_b`` after a
            forward pass. NOTE(review): the shape of ``xai_b`` is not visible
            in this file -- confirm against the mixer implementation.
        image: Input tensor; ``requires_grad_()`` is called on it in place.
        start_layer: First layer included in the average.
        cls_pos: Index removed from the map (default 98, the value that was
            previously hard-coded; presumably the class-token position of the
            mid-cls-token checkpoint -- confirm against the model).

    Returns:
        Tuple ``(relevance, logits)``; ``relevance`` is clamped at zero and
        has a leading dimension of size 1.
    """
    image.requires_grad_()
    logits = model(image)
    all_layer_attentions = [
        _mean_over_dim1(_min_max_normalize(layer.mixer.xai_b))
        for layer in model.layers
    ]
    p = torch.cat(all_layer_attentions[start_layer:], dim=0).mean(dim=0).unsqueeze(0)
    p = _drop_cls_token(p, cls_pos)
    return p.clamp(min=0).squeeze().unsqueeze(0), logits


def generate_mamba_attr(model, image, start_layer=15, cls_pos=98):
    """Build a gradient-weighted rollout relevance map for the top class.

    Back-propagates the logit of the predicted class, combines each layer's
    normalised ``mixer.xai_b`` with a normalised score derived from
    ``layer.get_gradients()``, and rolls the fused maps out from
    ``start_layer``. The row at ``cls_pos`` is returned with the ``cls_pos``
    entry removed.

    Args:
        model: Callable model exposing ``zero_grad()``, ``layers[i].mixer.xai_b``
            and ``layers[i].get_gradients()``. NOTE(review): shapes of
            ``xai_b`` and of the stored gradients are not visible in this
            file -- confirm against the model implementation.
        image: Input tensor; ``requires_grad_()`` is called on it in place.
        start_layer: First layer that enters the rollout product.
        cls_pos: Row read from the rollout and index removed from the result
            (default 98, the value that was previously hard-coded).

    Returns:
        Tuple ``(relevance, logits)``; ``relevance`` is clamped at zero and
        has a leading dimension of size 1.
    """
    image.requires_grad_()
    logits = model(image)

    # One-hot mask over the predicted class; summing mask * logits selects
    # the scalar that is back-propagated.
    index = np.argmax(logits.detach().cpu().numpy(), axis=-1)
    one_hot = np.zeros((1, logits.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    # Follow the logits' device instead of hard-coding CUDA so the function
    # also runs on CPU (or any other device the model lives on).
    one_hot = torch.from_numpy(one_hot).to(logits.device)
    score = torch.sum(one_hot * logits)
    model.zero_grad()
    score.backward(retain_graph=True)

    all_layer_attentions = []
    for layer in model.layers:
        attn_heads = _min_max_normalize(layer.mixer.xai_b.clamp(min=0))
        # Positive part of the stored gradients, reduced with max over dim 1
        # and rescaled to [0, 1].
        s = layer.get_gradients().squeeze().detach()
        s = _min_max_normalize(s.clamp(min=0).max(dim=1)[0].unsqueeze(0))
        all_layer_attentions.append(_mean_over_dim1(attn_heads) * s)

    rollout = compute_rollout_attention(all_layer_attentions, start_layer)
    p = rollout[0, cls_pos, :].unsqueeze(0)
    p = _drop_cls_token(p, cls_pos)
    return p.clamp(min=0).squeeze().unsqueeze(0), logits


def generate_rollout(model, image, start_layer=15, num_layers=24, cls_pos=98):
    """Build a relevance map by rolling out the layers' raw ``xai_b``.

    Unlike the other generators, ``xai_b`` is averaged over dim 1 without
    min-max normalisation before the rollout.

    Args:
        model: Callable model exposing ``layers[i].mixer.xai_b`` after a
            forward pass. NOTE(review): the shape of ``xai_b`` is not visible
            in this file -- confirm against the mixer implementation.
        image: Input tensor; ``requires_grad_()`` is called on it in place.
        start_layer: First layer that enters the rollout product.
        num_layers: Number of leading layers read from ``model.layers``;
            indexing past the end raises ``IndexError``.
        cls_pos: Row read from the rollout and index removed from the result
            (default 98, the value that was previously hard-coded).

    Returns:
        Tuple ``(relevance, logits)``; ``relevance`` is clamped at zero and
        has a leading dimension of size 1.
    """
    image.requires_grad_()
    logits = model(image)
    all_layer_attentions = [
        _mean_over_dim1(model.layers[layer_idx].mixer.xai_b)
        for layer_idx in range(num_layers)
    ]
    rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
    p = rollout[0, cls_pos, :].unsqueeze(0)
    p = _drop_cls_token(p, cls_pos)
    return p.clamp(min=0).squeeze().unsqueeze(0), logits
--------------------------------------------------------------------------------