├── README.md
├── assets
│   ├── 1.png
│   ├── 2.png
│   ├── MambaXAI.pdf
│   ├── notebook.png
│   ├── pdf.png
│   └── xai_gradmethod.jpg
├── causal-conv1d
│   ├── AUTHORS
│   ├── LICENSE
│   ├── README.md
│   ├── causal_conv1d
│   │   ├── __init__.py
│   │   └── causal_conv1d_interface.py
│   ├── csrc
│   │   ├── causal_conv1d.cpp
│   │   ├── causal_conv1d.h
│   │   ├── causal_conv1d_bwd.cu
│   │   ├── causal_conv1d_common.h
│   │   ├── causal_conv1d_fwd.cu
│   │   ├── causal_conv1d_update.cu
│   │   └── static_switch.h
│   ├── setup.py
│   └── tests
│       └── test_causal_conv1d.py
├── mamba-1p1p1
│   ├── AUTHORS
│   ├── LICENSE
│   ├── README.md
│   ├── assets
│   │   └── selection.png
│   ├── benchmarks
│   │   └── benchmark_generation_mamba_simple.py
│   ├── csrc
│   │   └── selective_scan
│   │       ├── reverse_scan.cuh
│   │       ├── selective_scan.cpp
│   │       ├── selective_scan.h
│   │       ├── selective_scan_bwd_bf16_complex.cu
│   │       ├── selective_scan_bwd_bf16_real.cu
│   │       ├── selective_scan_bwd_fp16_complex.cu
│   │       ├── selective_scan_bwd_fp16_real.cu
│   │       ├── selective_scan_bwd_fp32_complex.cu
│   │       ├── selective_scan_bwd_fp32_real.cu
│   │       ├── selective_scan_bwd_kernel.cuh
│   │       ├── selective_scan_common.h
│   │       ├── selective_scan_fwd_bf16.cu
│   │       ├── selective_scan_fwd_fp16.cu
│   │       ├── selective_scan_fwd_fp32.cu
│   │       ├── selective_scan_fwd_kernel.cuh
│   │       ├── static_switch.h
│   │       └── uninitialized_copy.cuh
│   ├── evals
│   │   └── lm_harness_eval.py
│   ├── mamba_ssm
│   │   ├── __init__.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── config_mamba.py
│   │   │   └── mixer_seq_simple.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── mamba_simple.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── selective_scan_interface.py
│   │   │   └── triton
│   │   │       ├── __init__.py
│   │   │       ├── layernorm.py
│   │   │       └── selective_state_update.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── generation.py
│   │       └── hf.py
│   ├── setup.py
│   └── tests
│       └── ops
│           ├── test_selective_scan.py
│           └── triton
│               └── test_selective_state_update.py
└── vim
    ├── augment.py
    ├── class_mapper.py
    ├── datasets.py
    ├── engine.py
    ├── hubconf.py
    ├── images
    │   ├── 1.jpg
    │   ├── 2.jpg
    │   ├── 3.jpg
    │   ├── 4.jpg
    │   └── 5.jpg
    ├── losses.py
    ├── main.py
    ├── models_mamba.py
    ├── rope.py
    ├── run_with_submitit.py
    ├── samplers.py
    ├── scripts
    │   ├── ft-vim-s.sh
    │   ├── ft-vim-t.sh
    │   ├── pt-vim-s.sh
    │   └── pt-vim-t.sh
    ├── utils.py
    ├── vim_requirements.txt
    ├── vmamba_xai.ipynb
    └── xai_utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # 🐍 The Hidden Attention of Mamba Models 🐍
3 |
4 | Ameen Ali¹\*, Itamar Zimerman¹\* and Lior Wolf¹
5 |
6 | ameenali023@gmail.com, itamarzimm@gmail.com, liorwolf@gmail.com
7 |
8 | ¹ Tel Aviv University
9 | (\*) equal contribution
10 |
11 |
12 |
13 |
14 |
15 | ## Official PyTorch Implementation of "The Hidden Attention of Mamba Models"
16 |
17 | The Mamba layer offers an efficient state space model (SSM) that is highly effective in modeling multiple domains, including long-range sequences and images. SSMs are viewed as dual models, in which one trains in parallel on the entire sequence using convolutions, and deploys in an autoregressive manner. We add a third view and show that such models can be viewed as attention-driven models. This new perspective enables us to compare the underlying mechanisms to those of the self-attention layers in transformers and allows us to peer inside the inner workings of the Mamba model with explainability methods.
18 |
19 | You can access the paper here: [The Hidden Attention of Mamba Models](https://arxiv.org/abs/2403.01590)
20 |
21 |
22 |
23 |
24 |
25 | ## Set Up Environment
26 |
27 | - Python 3.10.13
28 |
29 | - `conda create -n your_env_name python=3.10.13`
30 | - Activate Env
31 | - `conda activate your_env_name`
32 | - CUDA TOOLKIT 11.8
33 | - `conda install nvidia/label/cuda-11.8.0::cuda-toolkit`
34 | - torch 2.1.1 + cu118
35 | - `pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118`
36 |
37 | - Requirements: vim_requirements.txt
38 | - `pip install -r vim/vim_requirements.txt`
39 |
40 | - Install jupyter
41 | - `pip install jupyter`
42 |
43 | - Install ``causal_conv1d`` and ``mamba`` from *our source*
44 | - `cd causal-conv1d`
45 | - `pip install --editable .`
46 | - `cd ..`
47 | - `pip install --editable mamba-1p1p1`
48 |
49 |
50 |
51 |
52 | ## Pre-Trained Weights
53 |
54 | We have used the official weights provided by [Vim](https://github.com/hustvl/Vim), which can be downloaded from here:
55 |
56 | | Model | #param. | Top-1 Acc. | Top-5 Acc. | Hugging Face Repo |
57 | |:------------------------------------------------------------------:|:-------------:|:----------:|:----------:|:----------:|
58 | | [Vim-tiny](https://huggingface.co/hustvl/Vim-tiny-midclstok) | 7M | 76.1 | 93.0 | https://huggingface.co/hustvl/Vim-tiny-midclstok |
59 | | [Vim-tiny+](https://huggingface.co/hustvl/Vim-tiny-midclstok) | 7M | 78.3 | 94.2 | https://huggingface.co/hustvl/Vim-tiny-midclstok |
60 | | [Vim-small](https://huggingface.co/hustvl/Vim-small-midclstok) | 26M | 80.5 | 95.1 | https://huggingface.co/hustvl/Vim-small-midclstok |
61 | | [Vim-small+](https://huggingface.co/hustvl/Vim-small-midclstok) | 26M | 81.6 | 95.4 | https://huggingface.co/hustvl/Vim-small-midclstok |
62 |
63 | **Notes:**
64 | - In all of our experiments, we used [Vim-small](https://huggingface.co/hustvl/Vim-small-midclstok).
65 |
66 | ## Vision-Mamba Explainability Notebook
67 |
68 |
69 |
70 |
71 | Follow the instructions in the vim/vmamba_xai.ipynb notebook to run single-image inference with each of the three explainability methods introduced in the paper.
72 |
73 |
74 |
75 |
76 |
77 | ## To-Do
78 | For the segmentation experiment, please check out our [follow-up work](https://github.com/Itamarzimm/UnifiedImplicitAttnRepr/tree/main).
79 |
80 |
84 |
85 | ## Citation
86 | If you find our work useful, please consider citing us:
87 | ```latex
88 | @misc{ali2024hidden,
89 |       title={The Hidden Attention of Mamba Models},
90 |       author={Ameen Ali and Itamar Zimerman and Lior Wolf},
91 |       year={2024},
92 |       eprint={2403.01590},
93 |       archivePrefix={arXiv},
94 |       primaryClass={cs.LG}
95 | }
96 | ```
97 | ## Acknowledgement
98 | This repository is heavily based on [Vim](https://github.com/hustvl/Vim), [Mamba](https://github.com/state-spaces/mamba) and [Transformer-Explainability](https://github.com/hila-chefer/Transformer-Explainability). Thanks for their wonderful work.
99 |
--------------------------------------------------------------------------------
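Editor's note: the notebook section above is the entry point for the explainability demos. For orientation only, the sketch below is a generic gradient-times-input saliency driver in plain PyTorch; it is *not* the paper's Mamba-attention-based methods (those operate on the hidden attention matrices and live in `vim/xai_utils.py`), and the preprocessing uses standard ImageNet statistics, which may differ from the notebook's exact transforms.

```python
import torch
from PIL import Image
from torchvision import transforms

# Standard ImageNet preprocessing (assumption: the notebook uses 224x224 crops).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def gradient_saliency(model: torch.nn.Module, image_path: str) -> torch.Tensor:
    """Generic stand-in: gradient-times-input relevance for the top-1 class."""
    model.eval()
    x = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0)
    x.requires_grad_(True)
    logits = model(x)
    logits[0, logits.argmax()].backward()    # backprop the winning logit
    return (x.grad * x).abs().sum(dim=1)[0]  # (224, 224) relevance map
```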
/assets/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/1.png
--------------------------------------------------------------------------------
/assets/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/2.png
--------------------------------------------------------------------------------
/assets/MambaXAI.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/MambaXAI.pdf
--------------------------------------------------------------------------------
/assets/notebook.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/notebook.png
--------------------------------------------------------------------------------
/assets/pdf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/pdf.png
--------------------------------------------------------------------------------
/assets/xai_gradmethod.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/assets/xai_gradmethod.jpg
--------------------------------------------------------------------------------
/causal-conv1d/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 |
--------------------------------------------------------------------------------
/causal-conv1d/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/causal-conv1d/README.md:
--------------------------------------------------------------------------------
1 | # Causal depthwise conv1d in CUDA with a PyTorch interface
2 |
--------------------------------------------------------------------------------
/causal-conv1d/causal_conv1d/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
4 |
--------------------------------------------------------------------------------
/causal-conv1d/causal_conv1d/causal_conv1d_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | import causal_conv1d_cuda
8 |
9 |
10 | class CausalConv1dFn(torch.autograd.Function):
11 | @staticmethod
12 | def forward(ctx, x, weight, bias=None, activation=None):
13 | if activation not in [None, "silu", "swish"]:
14 | raise NotImplementedError("activation must be None, silu, or swish")
15 | if x.stride(2) != 1 and x.stride(1) != 1:
16 | x = x.contiguous()
17 | bias = bias.contiguous() if bias is not None else None
18 | ctx.save_for_backward(x, weight, bias)
19 | ctx.activation = activation in ["silu", "swish"]
20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
21 | return out
22 |
23 | @staticmethod
24 | def backward(ctx, dout):
25 | x, weight, bias = ctx.saved_tensors
26 | if dout.stride(2) != 1 and dout.stride(1) != 1:
27 | dout = dout.contiguous()
28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
29 | # backward of conv1d with the backward of chunk).
30 | # Here we just pass in None and dx will be allocated in the C++ code.
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
32 | x, weight, bias, dout, None, ctx.activation
33 | )
34 | return dx, dweight, dbias if bias is not None else None, None
35 |
36 |
37 | def causal_conv1d_fn(x, weight, bias=None, activation=None):
38 | """
39 | x: (batch, dim, seqlen)
40 | weight: (dim, width)
41 | bias: (dim,)
42 | activation: either None or "silu" or "swish"
43 |
44 | out: (batch, dim, seqlen)
45 | """
46 | return CausalConv1dFn.apply(x, weight, bias, activation)
47 |
48 |
49 | def causal_conv1d_ref(x, weight, bias=None, activation=None):
50 | """
51 | x: (batch, dim, seqlen)
52 | weight: (dim, width)
53 | bias: (dim,)
54 |
55 | out: (batch, dim, seqlen)
56 | """
57 | if activation not in [None, "silu", "swish"]:
58 | raise NotImplementedError("activation must be None, silu, or swish")
59 | dtype_in = x.dtype
60 | x = x.to(weight.dtype)
61 | seqlen = x.shape[-1]
62 | dim, width = weight.shape
63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
64 | out = out[..., :seqlen]
65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
66 |
67 |
68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
69 | """
70 | x: (batch, dim)
71 | conv_state: (batch, dim, width)
72 | weight: (dim, width)
73 | bias: (dim,)
74 |
75 | out: (batch, dim)
76 | """
77 | if activation not in [None, "silu", "swish"]:
78 | raise NotImplementedError("activation must be None, silu, or swish")
79 | activation = activation in ["silu", "swish"]
80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
81 |
82 |
83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
84 | """
85 | x: (batch, dim)
86 | conv_state: (batch, dim, width)
87 | weight: (dim, width)
88 | bias: (dim,)
89 |
90 | out: (batch, dim)
91 | """
92 | if activation not in [None, "silu", "swish"]:
93 | raise NotImplementedError("activation must be None, silu, or swish")
94 | dtype_in = x.dtype
95 | batch, dim = x.shape
96 | width = weight.shape[1]
97 | assert conv_state.shape == (batch, dim, width)
98 | assert weight.shape == (dim, width)
99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
100 | conv_state[:, :, -1] = x
101 | out = torch.sum(conv_state * weight, dim=-1) # (B D)
102 | if bias is not None:
103 | out += bias
104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
105 |
--------------------------------------------------------------------------------
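Editor's note: the docstrings above pin down the expected shapes, and `causal_conv1d_ref` gives a pure-PyTorch reference to check the fused kernel against. A minimal usage sketch, assuming a CUDA device and a successful build of the `causal_conv1d_cuda` extension:

```python
import torch
from causal_conv1d import causal_conv1d_fn, causal_conv1d_ref

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float32)
bias = torch.randn(dim, device="cuda", dtype=torch.float32)

out = causal_conv1d_fn(x, weight, bias, activation="silu")       # fused CUDA path
out_ref = causal_conv1d_ref(x, weight, bias, activation="silu")  # pure-PyTorch reference
print((out - out_ref).abs().max())  # small fp16-level difference expected
```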
/causal-conv1d/csrc/causal_conv1d.cpp:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <c10/cuda/CUDAGuard.h>
7 | #include <torch/extension.h>
8 | #include <vector>
9 |
10 | #include "causal_conv1d.h"
11 |
12 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13 |
14 | #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15 | if (ITYPE == at::ScalarType::Half) { \
16 | using input_t = at::Half; \
17 | __VA_ARGS__(); \
18 | } else if (ITYPE == at::ScalarType::BFloat16) { \
19 | using input_t = at::BFloat16; \
20 | __VA_ARGS__(); \
21 | } else if (ITYPE == at::ScalarType::Float) { \
22 | using input_t = float; \
23 | __VA_ARGS__(); \
24 | } else { \
25 | AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26 | }
27 |
28 | #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29 | if (WTYPE == at::ScalarType::Half) { \
30 | using weight_t = at::Half; \
31 | __VA_ARGS__(); \
32 | } else if (WTYPE == at::ScalarType::BFloat16) { \
33 | using weight_t = at::BFloat16; \
34 | __VA_ARGS__(); \
35 | } else if (WTYPE == at::ScalarType::Float) { \
36 | using weight_t = float; \
37 | __VA_ARGS__(); \
38 | } else { \
39 | AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40 | }
41 |
42 | template<typename input_t, typename weight_t>
43 | void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
44 | template<typename input_t, typename weight_t>
45 | void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
46 |
47 | template<typename input_t, typename weight_t>
48 | void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
49 | template<typename input_t, typename weight_t>
50 | void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
51 |
52 | template<typename input_t, typename weight_t>
53 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
54 |
55 | void set_conv_params_fwd(ConvParamsBase &params,
56 | // sizes
57 | const size_t batch,
58 | const size_t dim,
59 | const size_t seqlen,
60 | const size_t width,
61 | // device pointers
62 | const at::Tensor x,
63 | const at::Tensor weight,
64 | const at::Tensor out,
65 | void* bias_ptr,
66 | bool silu_activation) {
67 |
68 | // Reset the parameters
69 | memset(&params, 0, sizeof(params));
70 |
71 | params.batch = batch;
72 | params.dim = dim;
73 | params.seqlen = seqlen;
74 | params.width = width;
75 |
76 | params.silu_activation = silu_activation;
77 |
78 | // Set the pointers and strides.
79 | params.x_ptr = x.data_ptr();
80 | params.weight_ptr = weight.data_ptr();
81 | params.bias_ptr = bias_ptr;
82 | params.out_ptr = out.data_ptr();
83 | // All stride are in elements, not bytes.
84 | params.x_batch_stride = x.stride(0);
85 | params.x_c_stride = x.stride(1);
86 | params.x_l_stride = x.stride(-1);
87 | params.weight_c_stride = weight.stride(0);
88 | params.weight_width_stride = weight.stride(1);
89 | params.out_batch_stride = out.stride(0);
90 | params.out_c_stride = out.stride(1);
91 | params.out_l_stride = out.stride(-1);
92 | }
93 |
94 |
95 | void set_conv_params_bwd(ConvParamsBwd &params,
96 | // sizes
97 | const size_t batch,
98 | const size_t dim,
99 | const size_t seqlen,
100 | const size_t width,
101 | // device pointers
102 | const at::Tensor x,
103 | const at::Tensor weight,
104 | void* bias_ptr,
105 | const at::Tensor dout,
106 | const at::Tensor dx,
107 | const at::Tensor dweight,
108 | void* dbias_ptr,
109 | bool silu_activation) {
110 | // Pass in "dout" instead of "out", we're not gonna use "out" at all.
111 | set_conv_params_fwd(params, batch, dim, seqlen, width,
112 | x, weight, dout, bias_ptr, silu_activation);
113 |
114 | // Set the pointers and strides.
115 | params.dout_ptr = dout.data_ptr();
116 | params.dx_ptr = dx.data_ptr();
117 | params.dweight_ptr = dweight.data_ptr();
118 | params.dbias_ptr = dbias_ptr;
119 | // All stride are in elements, not bytes.
120 | params.dout_batch_stride = dout.stride(0);
121 | params.dout_c_stride = dout.stride(1);
122 | params.dout_l_stride = dout.stride(2);
123 | params.dweight_c_stride = dweight.stride(0);
124 | params.dweight_width_stride = dweight.stride(1);
125 | params.dx_batch_stride = dx.stride(0);
126 | params.dx_c_stride = dx.stride(1);
127 | params.dx_l_stride = dx.stride(2);
128 | }
129 |
130 | at::Tensor
131 | causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
132 | const c10::optional<at::Tensor> &bias_,
133 | bool silu_activation) {
134 | auto input_type = x.scalar_type();
135 | auto weight_type = weight.scalar_type();
136 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
137 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
138 |
139 | TORCH_CHECK(x.is_cuda());
140 | TORCH_CHECK(weight.is_cuda());
141 |
142 | const auto sizes = x.sizes();
143 | const int batch_size = sizes[0];
144 | const int dim = sizes[1];
145 | const int seqlen = sizes[2];
146 | const int width = weight.size(-1);
147 |
148 | CHECK_SHAPE(x, batch_size, dim, seqlen);
149 | CHECK_SHAPE(weight, dim, width);
150 |
151 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
152 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
153 |
154 | if (is_channel_last) {
155 | TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
156 | }
157 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
158 |
159 |
160 | if (bias_.has_value()) {
161 | auto bias = bias_.value();
162 | TORCH_CHECK(bias.scalar_type() == weight_type);
163 | TORCH_CHECK(bias.is_cuda());
164 | TORCH_CHECK(bias.stride(-1) == 1);
165 | CHECK_SHAPE(bias, dim);
166 | }
167 |
168 | at::Tensor out = torch::empty_like(x);
169 |
170 | ConvParamsBase params;
171 | set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
172 | bias_.has_value() ? bias_.value().data_ptr() : nullptr,
173 | silu_activation);
174 |
175 | // Otherwise the kernel will be launched from cuda:0 device
176 | // Cast to char to avoid compiler warning about narrowing
177 | at::cuda::CUDAGuard device_guard{(char)x.get_device()};
178 | auto stream = at::cuda::getCurrentCUDAStream().stream();
179 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
180 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
181 | if (!is_channel_last) {
182 | causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
183 | } else {
184 | causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
185 | }
186 | });
187 | });
188 | return out;
189 | }
190 |
191 | std::vector<at::Tensor>
192 | causal_conv1d_bwd(const at::Tensor &x, const at::Tensor &weight,
193 | const c10::optional<at::Tensor> &bias_,
194 | at::Tensor &dout,
195 | c10::optional<at::Tensor> &dx_,
196 | bool silu_activation) {
197 | auto input_type = x.scalar_type();
198 | auto weight_type = weight.scalar_type();
199 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
200 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
201 |
202 | TORCH_CHECK(x.is_cuda());
203 | TORCH_CHECK(weight.is_cuda());
204 | TORCH_CHECK(dout.is_cuda());
205 |
206 | const auto sizes = x.sizes();
207 | const int batch_size = sizes[0];
208 | const int dim = sizes[1];
209 | const int seqlen = sizes[2];
210 | const int width = weight.size(-1);
211 |
212 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
213 |
214 | CHECK_SHAPE(x, batch_size, dim, seqlen);
215 | CHECK_SHAPE(weight, dim, width);
216 | CHECK_SHAPE(dout, batch_size, dim, seqlen);
217 |
218 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
219 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
220 | if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
221 | if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
222 |
223 | if (bias_.has_value()) {
224 | auto bias = bias_.value();
225 | TORCH_CHECK(bias.scalar_type() == weight_type);
226 | TORCH_CHECK(bias.is_cuda());
227 | TORCH_CHECK(bias.stride(-1) == 1);
228 | CHECK_SHAPE(bias, dim);
229 | }
230 |
231 | at::Tensor dx;
232 | if (dx_.has_value()) {
233 | dx = dx_.value();
234 | TORCH_CHECK(dx.scalar_type() == input_type);
235 | TORCH_CHECK(dx.is_cuda());
236 | CHECK_SHAPE(dx, batch_size, dim, seqlen);
237 | if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
238 | if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
239 | } else {
240 | dx = torch::empty_like(x);
241 | }
242 |
243 | // Otherwise the kernel will be launched from cuda:0 device
244 | // Cast to char to avoid compiler warning about narrowing
245 | at::cuda::CUDAGuard device_guard{(char)x.get_device()};
246 |
247 | at::Tensor dweight = torch::zeros_like(weight, weight.options().dtype(at::kFloat));
248 | at::Tensor dbias;
249 | if (bias_.has_value()) { dbias = torch::zeros_like(bias_.value(), bias_.value().options().dtype(at::kFloat)); }
250 |
251 | ConvParamsBwd params;
252 | set_conv_params_bwd(params, batch_size, dim, seqlen, width,
253 | x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
254 | dout, dx, dweight, bias_.has_value() ? dbias.data_ptr() : nullptr,
255 | silu_activation);
256 |
257 | auto stream = at::cuda::getCurrentCUDAStream().stream();
258 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
259 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
260 | if (!is_channel_last) {
261 | causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
262 | } else {
263 | causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
264 | }
265 | });
266 | });
267 | return {dx, dweight.to(weight.dtype()), bias_.has_value() ? dbias.to(bias_.value().dtype()) : dbias};
268 | }
269 |
270 | at::Tensor
271 | causal_conv1d_update(const at::Tensor &x,
272 | const at::Tensor &conv_state,
273 | const at::Tensor &weight,
274 | const c10::optional<at::Tensor> &bias_,
275 | bool silu_activation) {
276 | auto input_type = x.scalar_type();
277 | auto weight_type = weight.scalar_type();
278 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
279 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
280 | TORCH_CHECK(conv_state.scalar_type() == input_type);
281 |
282 | TORCH_CHECK(x.is_cuda());
283 | TORCH_CHECK(conv_state.is_cuda());
284 | TORCH_CHECK(weight.is_cuda());
285 |
286 | const auto sizes = x.sizes();
287 | const int batch_size = sizes[0];
288 | const int dim = sizes[1];
289 | const int width = weight.size(-1);
290 |
291 | CHECK_SHAPE(x, batch_size, dim);
292 | CHECK_SHAPE(conv_state, batch_size, dim, width);
293 | CHECK_SHAPE(weight, dim, width);
294 |
295 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
296 |
297 | if (bias_.has_value()) {
298 | auto bias = bias_.value();
299 | TORCH_CHECK(bias.scalar_type() == weight_type);
300 | TORCH_CHECK(bias.is_cuda());
301 | TORCH_CHECK(bias.stride(-1) == 1);
302 | CHECK_SHAPE(bias, dim);
303 | }
304 |
305 | at::Tensor out = torch::empty_like(x);
306 |
307 | ConvParamsBase params;
308 | set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
309 | bias_.has_value() ? bias_.value().data_ptr() : nullptr,
310 | silu_activation);
311 | params.conv_state_ptr = conv_state.data_ptr();
312 | // All stride are in elements, not bytes.
313 | params.conv_state_batch_stride = conv_state.stride(0);
314 | params.conv_state_c_stride = conv_state.stride(1);
315 | params.conv_state_l_stride = conv_state.stride(2);
316 |
317 | // Otherwise the kernel will be launched from cuda:0 device
318 | // Cast to char to avoid compiler warning about narrowing
319 | at::cuda::CUDAGuard device_guard{(char)x.get_device()};
320 | auto stream = at::cuda::getCurrentCUDAStream().stream();
321 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
322 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
323 | causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
324 | });
325 | });
326 | return out;
327 | }
328 |
329 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
330 | m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
331 | m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
332 | m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
333 | }
334 |
--------------------------------------------------------------------------------
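Editor's note: one detail worth highlighting in `causal_conv1d_fwd` above is that the kernel choice is driven purely by strides. `x` is logically `(batch, dim, seqlen)` either way; when `stride(1) == 1` and `stride(2) > 1` the tensor is channel-last in memory and the `channellast` kernels are used (these additionally require `dim % 8 == 0`). A plain-PyTorch illustration of how that layout arises:

```python
import torch

batch, seqlen, dim = 2, 128, 64
x = torch.randn(batch, seqlen, dim)  # contiguous (batch, seqlen, dim)
x_cl = x.transpose(1, 2)             # logical (batch, dim, seqlen), channel-last in memory

print(x_cl.shape, x_cl.stride())     # torch.Size([2, 64, 128]) (8192, 1, 64)
# stride(1) == 1 and stride(2) > 1, so causal_conv1d_fwd would route this tensor
# to the channellast kernel path (dim = 64 is divisible by 8).
```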
/causal-conv1d/csrc/causal_conv1d.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct ConvParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, dim, seqlen, width;
13 | bool silu_activation;
14 |
15 | index_t x_batch_stride;
16 | index_t x_c_stride;
17 | index_t x_l_stride;
18 | index_t weight_c_stride;
19 | index_t weight_width_stride;
20 | index_t out_batch_stride;
21 | index_t out_c_stride;
22 | index_t out_l_stride;
23 |
24 | index_t conv_state_batch_stride;
25 | index_t conv_state_c_stride;
26 | index_t conv_state_l_stride;
27 |
28 | // Common data pointers.
29 | void *__restrict__ x_ptr;
30 | void *__restrict__ weight_ptr;
31 | void *__restrict__ bias_ptr;
32 | void *__restrict__ out_ptr;
33 |
34 | void *__restrict__ conv_state_ptr;
35 | };
36 |
37 | struct ConvParamsBwd: public ConvParamsBase {
38 | index_t dx_batch_stride;
39 | index_t dx_c_stride;
40 | index_t dx_l_stride;
41 | index_t dweight_c_stride;
42 | index_t dweight_width_stride;
43 | index_t dout_batch_stride;
44 | index_t dout_c_stride;
45 | index_t dout_l_stride;
46 |
47 | // Common data pointers.
48 | void *__restrict__ dx_ptr;
49 | void *__restrict__ dweight_ptr;
50 | void *__restrict__ dbias_ptr;
51 | void *__restrict__ dout_ptr;
52 | };
53 |
54 |
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 |
10 | ////////////////////////////////////////////////////////////////////////////////////////////////////
11 |
12 | template<int BYTES> struct BytesToType {};
13 |
14 | template<> struct BytesToType<16> {
15 | using Type = uint4;
16 | static_assert(sizeof(Type) == 16);
17 | };
18 |
19 | template<> struct BytesToType<8> {
20 | using Type = uint64_t;
21 | static_assert(sizeof(Type) == 8);
22 | };
23 |
24 | template<> struct BytesToType<4> {
25 | using Type = uint32_t;
26 | static_assert(sizeof(Type) == 4);
27 | };
28 |
29 | template<> struct BytesToType<2> {
30 | using Type = uint16_t;
31 | static_assert(sizeof(Type) == 2);
32 | };
33 |
34 | template<> struct BytesToType<1> {
35 | using Type = uint8_t;
36 | static_assert(sizeof(Type) == 1);
37 | };
38 |
39 | ////////////////////////////////////////////////////////////////////////////////////////////////////
40 |
41 | template<typename T>
42 | struct SumOp {
43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44 | };
45 |
46 | template<int THREADS>
47 | struct Allreduce {
48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49 | template<typename T, typename Operator>
50 | static __device__ inline T run(T x, Operator &op) {
51 | constexpr int OFFSET = THREADS / 2;
52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53 | return Allreduce<OFFSET>::run(x, op);
54 | }
55 | };
56 |
57 | template<>
58 | struct Allreduce<2> {
59 | template<typename T, typename Operator>
60 | static __device__ inline T run(T x, Operator &op) {
61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62 | return x;
63 | }
64 | };
65 |
--------------------------------------------------------------------------------
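Editor's note: `Allreduce` above performs a butterfly reduction. Each lane exchanges its value with the lane whose index differs in one bit (`__shfl_xor_sync` with a halving offset), so after log2(THREADS) rounds every lane holds the full sum. A small Python model of the same exchange pattern, for intuition only:

```python
def butterfly_allreduce(vals, op=lambda a, b: a + b):
    """Mirrors the XOR-shuffle Allreduce: len(vals) plays the THREADS role (power of two)."""
    n = len(vals)
    offset = n // 2
    while offset >= 1:  # the header does this by template recursion
        vals = [op(vals[i], vals[i ^ offset]) for i in range(n)]
        offset //= 2
    return vals         # every "lane" now holds the total

print(butterfly_allreduce([1, 2, 3, 4]))  # [10, 10, 10, 10]
```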
/causal-conv1d/csrc/causal_conv1d_update.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #include <c10/util/BFloat16.h>
6 | #include <c10/util/Half.h>
7 | #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8 |
9 | #include <cub/block/block_load.cuh>
10 | #include <cub/block/block_store.cuh>
11 |
12 | #include "causal_conv1d.h"
13 | #include "causal_conv1d_common.h"
14 | #include "static_switch.h"
15 |
16 | template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17 | struct Causal_conv1d_update_kernel_traits {
18 | using input_t = input_t_;
19 | using weight_t = weight_t_;
20 | static constexpr int kNThreads = kNThreads_;
21 | static constexpr int kWidth = kWidth_;
22 | static constexpr int kNBytes = sizeof(input_t);
23 | static_assert(kNBytes == 2 || kNBytes == 4);
24 | };
25 |
26 | template<typename Ktraits>
27 | __global__ __launch_bounds__(Ktraits::kNThreads)
28 | void causal_conv1d_update_kernel(ConvParamsBase params) {
29 | constexpr int kWidth = Ktraits::kWidth;
30 | constexpr int kNThreads = Ktraits::kNThreads;
31 | using input_t = typename Ktraits::input_t;
32 | using weight_t = typename Ktraits::weight_t;
33 |
34 | const int tidx = threadIdx.x;
35 | const int batch_id = blockIdx.x;
36 | const int channel_id = blockIdx.y * kNThreads + tidx;
37 | input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38 | + channel_id * params.x_c_stride;
39 | input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40 | + channel_id * params.conv_state_c_stride;
41 | weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42 | input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43 | + channel_id * params.out_c_stride;
44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45 |
46 | float weight_vals[kWidth] = {0};
47 | if (channel_id < params.dim) {
48 | #pragma unroll
49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50 | }
51 |
52 | float x_vals[kWidth] = {0};
53 | if (channel_id < params.dim) {
54 | #pragma unroll
55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56 | x_vals[kWidth - 1] = float(x[0]);
57 | #pragma unroll
58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59 | }
60 |
61 | float out_val = bias_val;
62 | #pragma unroll
63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65 | if (channel_id < params.dim) { out[0] = input_t(out_val); }
66 | }
67 |
68 | template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70 | using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72 | auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73 | kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74 | C10_CUDA_KERNEL_LAUNCH_CHECK();
75 | }
76 |
77 | template<typename input_t, typename weight_t>
78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79 | if (params.width == 2) {
80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81 | } else if (params.width == 3) {
82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83 | } else if (params.width == 4) {
84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85 | }
86 | }
87 |
88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
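Editor's note: the update kernel above implements one autoregressive decoding step per channel: shift the state window left, append the incoming sample, then dot with the filter taps (plus bias and optional SiLU). The same arithmetic in plain PyTorch, mirroring `causal_conv1d_update_ref` from the Python interface:

```python
import torch

state = torch.tensor([[[1., 2., 3.]]])      # conv_state: (batch=1, dim=1, width=3)
weight = torch.tensor([[0.5, 0.25, 0.25]])  # (dim, width) filter taps
x_new = torch.tensor([[4.]])                # one new timestep: (batch, dim)

state = torch.roll(state, shifts=-1, dims=-1)  # drop the oldest sample
state[:, :, -1] = x_new                        # window is now [2, 3, 4]
out = (state * weight).sum(dim=-1)             # 0.5*2 + 0.25*3 + 0.25*4 = 2.75
print(out)                                     # tensor([[2.7500]])
```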
/causal-conv1d/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | static constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | static constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/causal-conv1d/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 | import sys
3 | import warnings
4 | import os
5 | import re
6 | import ast
7 | from pathlib import Path
8 | from packaging.version import parse, Version
9 | import platform
10 |
11 | from setuptools import setup, find_packages
12 | import subprocess
13 |
14 | import urllib.request
15 | import urllib.error
16 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
17 |
18 | import torch
19 | from torch.utils.cpp_extension import (
20 | BuildExtension,
21 | CppExtension,
22 | CUDAExtension,
23 | CUDA_HOME,
24 | )
25 |
26 |
27 | with open("README.md", "r", encoding="utf-8") as fh:
28 | long_description = fh.read()
29 |
30 |
31 | # ninja build does not work unless include_dirs are abs path
32 | this_dir = os.path.dirname(os.path.abspath(__file__))
33 |
34 | PACKAGE_NAME = "causal_conv1d"
35 |
36 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
37 |
38 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
39 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
40 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
41 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
42 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
43 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
44 |
45 |
46 | def get_platform():
47 | """
48 | Returns the platform name as used in wheel filenames.
49 | """
50 | if sys.platform.startswith("linux"):
51 | return "linux_x86_64"
52 | elif sys.platform == "darwin":
53 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
54 | return f"macosx_{mac_version}_x86_64"
55 | elif sys.platform == "win32":
56 | return "win_amd64"
57 | else:
58 | raise ValueError("Unsupported platform: {}".format(sys.platform))
59 |
60 |
61 | def get_cuda_bare_metal_version(cuda_dir):
62 | raw_output = subprocess.check_output(
63 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
64 | )
65 | output = raw_output.split()
66 | release_idx = output.index("release") + 1
67 | bare_metal_version = parse(output[release_idx].split(",")[0])
68 |
69 | return raw_output, bare_metal_version
70 |
71 |
72 | def check_if_cuda_home_none(global_option: str) -> None:
73 | if CUDA_HOME is not None:
74 | return
75 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
76 | # in that case.
77 | warnings.warn(
78 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
79 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
80 | "only images whose names contain 'devel' will provide nvcc."
81 | )
82 |
83 |
84 | def append_nvcc_threads(nvcc_extra_args):
85 | return nvcc_extra_args + ["--threads", "4"]
86 |
87 |
88 | cmdclass = {}
89 | ext_modules = []
90 |
91 | if not SKIP_CUDA_BUILD:
92 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
93 | TORCH_MAJOR = int(torch.__version__.split(".")[0])
94 | TORCH_MINOR = int(torch.__version__.split(".")[1])
95 |
96 | check_if_cuda_home_none("causal_conv1d")
97 | # Check, if CUDA11 is installed for compute capability 8.0
98 | cc_flag = []
99 | if CUDA_HOME is not None:
100 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
101 | if bare_metal_version < Version("11.6"):
102 | raise RuntimeError(
103 | "causal_conv1d is only supported on CUDA 11.6 and above. "
104 | "Note: make sure nvcc has a supported version by running nvcc -V."
105 | )
106 |
107 | cc_flag.append("-gencode")
108 | cc_flag.append("arch=compute_70,code=sm_70")
109 | cc_flag.append("-gencode")
110 | cc_flag.append("arch=compute_80,code=sm_80")
111 | if bare_metal_version >= Version("11.8"):
112 | cc_flag.append("-gencode")
113 | cc_flag.append("arch=compute_90,code=sm_90")
114 |
115 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
116 | # torch._C._GLIBCXX_USE_CXX11_ABI
117 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
118 | if FORCE_CXX11_ABI:
119 | torch._C._GLIBCXX_USE_CXX11_ABI = True
120 |
121 | ext_modules.append(
122 | CUDAExtension(
123 | name="causal_conv1d_cuda",
124 | sources=[
125 | "csrc/causal_conv1d.cpp",
126 | "csrc/causal_conv1d_fwd.cu",
127 | "csrc/causal_conv1d_bwd.cu",
128 | "csrc/causal_conv1d_update.cu",
129 | ],
130 | extra_compile_args={
131 | "cxx": ["-O3"],
132 | "nvcc": append_nvcc_threads(
133 | [
134 | "-O3",
135 | "-U__CUDA_NO_HALF_OPERATORS__",
136 | "-U__CUDA_NO_HALF_CONVERSIONS__",
137 | "-U__CUDA_NO_BFLOAT16_OPERATORS__",
138 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
139 | "-U__CUDA_NO_BFLOAT162_OPERATORS__",
140 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
141 | "--expt-relaxed-constexpr",
142 | "--expt-extended-lambda",
143 | "--use_fast_math",
144 | "--ptxas-options=-v",
145 | "-lineinfo",
146 | ]
147 | + cc_flag
148 | ),
149 | },
150 | include_dirs=[this_dir],
151 | )
152 | )
153 |
154 |
155 | def get_package_version():
156 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
157 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
158 | public_version = ast.literal_eval(version_match.group(1))
159 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
160 | if local_version:
161 | return f"{public_version}+{local_version}"
162 | else:
163 | return str(public_version)
164 |
165 |
166 | def get_wheel_url():
167 | # Determine the version numbers that will be used to determine the correct wheel
168 | # We're using the CUDA version used to build torch, not the one currently installed
169 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
170 | torch_cuda_version = parse(torch.version.cuda)
171 | torch_version_raw = parse(torch.__version__)
172 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
173 | # to save CI time. Minor versions should be compatible.
174 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
175 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
176 | platform_name = get_platform()
177 | causal_conv1d_version = get_package_version()
178 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
179 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
180 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
181 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
182 |
183 | # Determine wheel URL based on CUDA version, torch version, python version and OS
184 | wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
185 | wheel_url = BASE_WHEEL_URL.format(
186 | tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
187 | )
188 | return wheel_url, wheel_filename
189 |
190 |
191 | class CachedWheelsCommand(_bdist_wheel):
192 | """
193 | The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
194 | find an existing wheel (which is currently the case for all installs). We use
195 | the environment parameters to detect whether there is already a pre-built version of a compatible
196 | wheel available and short-circuits the standard full build pipeline.
197 | """
198 |
199 | def run(self):
200 | if FORCE_BUILD:
201 | return super().run()
202 |
203 | wheel_url, wheel_filename = get_wheel_url()
204 | print("Guessing wheel URL: ", wheel_url)
205 | try:
206 | urllib.request.urlretrieve(wheel_url, wheel_filename)
207 |
208 | # Make the archive
209 | # Lifted from the root wheel processing command
210 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
211 | if not os.path.exists(self.dist_dir):
212 | os.makedirs(self.dist_dir)
213 |
214 | impl_tag, abi_tag, plat_tag = self.get_tag()
215 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
216 |
217 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
218 | print("Raw wheel path", wheel_path)
219 | os.rename(wheel_filename, wheel_path)
220 | except urllib.error.HTTPError:
221 | print("Precompiled wheel not found. Building from source...")
222 | # If the wheel could not be downloaded, build from source
223 | super().run()
224 |
225 |
226 | setup(
227 | name=PACKAGE_NAME,
228 | version=get_package_version(),
229 | packages=find_packages(
230 | exclude=(
231 | "build",
232 | "csrc",
233 | "include",
234 | "tests",
235 | "dist",
236 | "docs",
237 | "benchmarks",
238 | "causal_conv1d.egg-info",
239 | )
240 | ),
241 | author="Tri Dao",
242 | author_email="tri@tridao.me",
243 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
244 | long_description=long_description,
245 | long_description_content_type="text/markdown",
246 | url="https://github.com/Dao-AILab/causal-conv1d",
247 | classifiers=[
248 | "Programming Language :: Python :: 3",
249 | "License :: OSI Approved :: BSD License",
250 | "Operating System :: Unix",
251 | ],
252 | ext_modules=ext_modules,
253 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
254 | if ext_modules
255 | else {
256 | "bdist_wheel": CachedWheelsCommand,
257 | },
258 | python_requires=">=3.7",
259 | install_requires=[
260 | "torch",
261 | "packaging",
262 | "ninja",
263 | ],
264 | )
265 |
--------------------------------------------------------------------------------
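Editor's note: `get_wheel_url` above guesses the name of a prebuilt wheel from the local toolchain, and `CachedWheelsCommand` falls back to a source build when the download fails (setting `CAUSAL_CONV1D_FORCE_BUILD=TRUE` skips the lookup entirely). A sketch of the naming scheme with example values filled in; the real code derives each tag from `torch.version.cuda`, `torch.__version__`, `sys.version_info`, and the platform:

```python
# Example tags only -- your environment will produce different ones.
package, version = "causal_conv1d", "1.0.0"
cuda, torch_v, abi = "118", "2.1", "FALSE"
py, plat = "cp310", "linux_x86_64"

wheel = f"{package}-{version}+cu{cuda}torch{torch_v}cxx11abi{abi}-{py}-{py}-{plat}.whl"
print(wheel)
# causal_conv1d-1.0.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```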
/causal-conv1d/tests/test_causal_conv1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import pytest
7 |
8 | from einops import rearrange
9 |
10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
12 |
13 |
14 | @pytest.mark.parametrize("channel_last", [False, True])
15 | # @pytest.mark.parametrize('channel_last', [True])
16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
17 | # @pytest.mark.parametrize('itype', [torch.float16])
18 | @pytest.mark.parametrize("silu_activation", [False, True])
19 | # @pytest.mark.parametrize('silu_activation', [True])
20 | @pytest.mark.parametrize("has_bias", [False, True])
21 | # @pytest.mark.parametrize('has_bias', [True])
22 | @pytest.mark.parametrize("width", [2, 3, 4])
23 | # @pytest.mark.parametrize('width', [2])
24 | @pytest.mark.parametrize(
25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
26 | )
27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
28 | # @pytest.mark.parametrize('seqlen', [128])
29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
30 | device = "cuda"
31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
32 | if itype == torch.bfloat16:
33 | rtol, atol = 1e-2, 5e-2
34 | rtolw, atolw = (1e-3, 1e-3)
35 | # set seed
36 | torch.random.manual_seed(0)
37 | batch_size = 2
38 | # batch_size = 1
39 | dim = 4096 + 32 # Try dim not divisible by 64
40 | # dim = 64
41 | if not channel_last:
42 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
43 | else:
44 | x = rearrange(
45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
46 | ).requires_grad_()
47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
48 | if has_bias:
49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
50 | else:
51 | bias = None
52 | x_ref = x.detach().clone().requires_grad_()
53 | weight_ref = weight.detach().clone().requires_grad_()
54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
55 | activation = None if not silu_activation else "silu"
56 | out = causal_conv1d_fn(x, weight, bias, activation=activation)
57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
58 |
59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
62 |
63 | g = torch.randn_like(out)
64 | out_ref.backward(g)
65 | out.backward(g)
66 |
67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
69 | if has_bias:
70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
71 |
72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
74 | if has_bias:
75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
76 |
77 |
78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
79 | # @pytest.mark.parametrize('itype', [torch.float16])
80 | @pytest.mark.parametrize("silu_activation", [False, True])
81 | # @pytest.mark.parametrize('silu_activation', [False])
82 | @pytest.mark.parametrize("has_bias", [False, True])
83 | # @pytest.mark.parametrize('has_bias', [True])
84 | @pytest.mark.parametrize("width", [2, 3, 4])
85 | # @pytest.mark.parametrize('width', [2])
86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
87 | # @pytest.mark.parametrize("dim", [2048])
88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
89 | device = "cuda"
90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
91 | if itype == torch.bfloat16:
92 | rtol, atol = 1e-2, 5e-2
93 | rtolw, atolw = (1e-3, 1e-3)
94 | # set seed
95 | torch.random.manual_seed(0)
96 | batch_size = 2
97 | # batch_size = 1
98 | # dim = 64
99 | x = torch.randn(batch_size, dim, device=device, dtype=itype)
100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
102 | if has_bias:
103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
104 | else:
105 | bias = None
106 | conv_state_ref = conv_state.detach().clone()
107 | activation = None if not silu_activation else "silu"
108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
110 |
111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
113 | assert torch.equal(conv_state, conv_state_ref)
114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115 |
116 |
117 | # @pytest.mark.parametrize("channel_last", [False, True])
118 | @pytest.mark.parametrize('channel_last', [True])
119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
120 | @pytest.mark.parametrize('itype', [torch.bfloat16])
121 | # @pytest.mark.parametrize("silu_activation", [False, True])
122 | @pytest.mark.parametrize('silu_activation', [True])
123 | # @pytest.mark.parametrize("has_bias", [False, True])
124 | @pytest.mark.parametrize('has_bias', [True])
125 | # @pytest.mark.parametrize("width", [2, 3, 4])
126 | @pytest.mark.parametrize('width', [4])
127 | @pytest.mark.parametrize(
128 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
129 | "seqlen", [2048]
130 | )
131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
132 | # @pytest.mark.parametrize('seqlen', [128])
133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
134 | device = "cuda"
135 | # set seed
136 | torch.random.manual_seed(0)
137 | batch_size = 2
138 | # batch_size = 1
139 | dim = 4096 + 32 # Try dim not divisible by 64
140 | # dim = 64
141 | if not channel_last:
142 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
143 | else:
144 | x = rearrange(
145 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
146 | ).requires_grad_()
147 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
148 | if has_bias:
149 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
150 | else:
151 | bias = None
152 | activation = None if not silu_activation else "silu"
153 | out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154 | g = torch.randn_like(out0)
155 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156 | dw_atol = 1e-4
157 | db_atol = 1e-4
158 |
159 | for i in range(10000):
160 | out = causal_conv1d_fn(x, weight, bias, activation=activation)
161 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163 | # if not dw_equal:
164 | # breakpoint()
165 | if has_bias:
166 | db_equal = torch.allclose(db, db0, atol=db_atol)
167 | # if not db_equal:
168 | # breakpoint()
169 | assert torch.equal(out, out0)
170 | assert torch.equal(dx, dx0)
171 | assert dw_equal
172 | if has_bias:
173 | assert db_equal
174 |
--------------------------------------------------------------------------------
/mamba-1p1p1/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 | Albert Gu, agu@andrew.cmu.edu
3 |
--------------------------------------------------------------------------------
/mamba-1p1p1/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Tri Dao, Albert Gu
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/mamba-1p1p1/README.md:
--------------------------------------------------------------------------------
1 | # Mamba
2 |
3 | 
4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
5 | > Albert Gu*, Tri Dao*\
6 | > Paper: https://arxiv.org/abs/2312.00752
7 |
8 | ## About
9 |
10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
13 |
14 | ## Installation
15 |
16 | - `pip install causal-conv1d>=1.1.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
17 | - `pip install mamba-ssm`: the core Mamba package.
18 |
19 | It can also be built from source with `pip install .` from this repository.
20 |
21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
22 |
23 | Other requirements:
24 | - Linux
25 | - NVIDIA GPU
26 | - PyTorch 1.12+
27 | - CUDA 11.6+
28 |
29 | ## Usage
30 |
31 | We expose several levels of interface with the Mamba model.
32 |
33 | ### Selective SSM
34 |
35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
36 |
37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
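
In pseudocode, the selective scan is the input-dependent linear recurrence h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t u_t with readout y_t = C_t h_t. A minimal, unoptimized sketch of this recurrence (shapes and names here are illustrative, not the exact kernel interface):

```
import torch

def selective_scan_sketch(u, delta, A, B, C):
    # u, delta: (batch, dim, seqlen); A: (dim, dstate)
    # B, C: (batch, dstate, seqlen) -- input-dependent ("selective") parameters
    batch, dim, seqlen = u.shape
    h = torch.zeros(batch, dim, A.shape[1], device=u.device, dtype=u.dtype)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, :, t, None] * A)            # discretize A with step delta_t
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        h = dA * h + dBu                                    # linear recurrence on the state
        ys.append((h * C[:, None, :, t]).sum(dim=-1))       # y_t = C_t h_t
    return torch.stack(ys, dim=-1)                          # (batch, dim, seqlen)
```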
38 |
39 | ### Mamba Block
40 |
41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM.
42 |
43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
44 |
45 | Usage:
46 | ```
47 | import torch
48 | from mamba_ssm import Mamba
49 |
50 | batch, length, dim = 2, 64, 16
51 | x = torch.randn(batch, length, dim).to("cuda")
52 | model = Mamba(
53 | # This module uses roughly 3 * expand * d_model^2 parameters
54 | d_model=dim, # Model dimension d_model
55 | d_state=16, # SSM state expansion factor
56 | d_conv=4, # Local convolution width
57 | expand=2, # Block expansion factor
58 | ).to("cuda")
59 | y = model(x)
60 | assert y.shape == x.shape
61 | ```
62 |
63 | ### Mamba Language Model
64 |
65 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
66 |
67 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
68 |
69 | This is an example of how to integrate Mamba into an end-to-end neural network.
70 | This example is used in the generation scripts below.
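
For instance, a small generation sketch in the style of the benchmark script below (the tokenizer choice and the `generate` arguments mirror that script; treat this as illustrative rather than a fixed API):

```
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

input_ids = tokenizer("Mamba is", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids=input_ids,
    max_length=32,
    temperature=0.7,
    top_p=0.9,
    return_dict_in_generate=True,
)
print(tokenizer.batch_decode(out.sequences.tolist()))
```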
71 |
72 |
73 |
74 | ## Pretrained Models
75 |
76 | Pretrained models are uploaded to
77 | [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
78 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
79 | (trained on 600B tokens on the SlimPajama dataset).
80 |
81 |
82 | The models will be autodownloaded by the generation script below.
83 |
84 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:
85 |
86 | | Parameters | Layers | Model dim. |
87 | |------------|--------|------------|
88 | | 130M | 24 | 768 |
89 | | 370M | 48 | 1024 |
90 | | 790M | 48 | 1536 |
91 | | 1.4B | 48 | 2048 |
92 | | 2.8B | 64 | 2560 |
93 |
94 | (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
95 |
96 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
97 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
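
As a rough sanity check, the ~3 · expand · d_model² per-block estimate quoted in the usage snippet above, plus a tied embedding/LM head, roughly reproduces these parameter counts (a back-of-the-envelope sketch, not an exact count):

```
def approx_params(n_layer, d_model, vocab_size=50280, expand=2):
    blocks = n_layer * 3 * expand * d_model**2   # ~3 * expand * d_model^2 per Mamba block
    embedding = vocab_size * d_model             # tied with the LM head (50277 padded to 50280)
    return blocks + embedding

print(f"{approx_params(24, 768) / 1e6:.0f}M")    # ~124M, listed as 130M
print(f"{approx_params(64, 2560) / 1e9:.2f}B")   # ~2.65B, listed as 2.8B
```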
98 |
99 |
100 | ## Evaluations
101 |
102 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
103 | we use the
104 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
105 | library.
106 |
107 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
108 | --recursive`. We use the `big-refactor` branch.
109 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`.
110 | On Python 3.10 you might need to manually install the latest version of `promptsource`: `pip install git+https://github.com/bigscience-workshop/promptsource.git`.
111 | 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
112 | ```
113 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
114 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
115 | ```
116 |
117 | To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
118 | ```
119 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64
120 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64
121 | ```
122 |
123 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
124 |
125 | ## Inference
126 |
127 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
128 | 1. autoloads a model from the Hugging Face Hub,
129 | 2. generates completions of a user-specified prompt,
130 | 3. benchmarks the inference speed of this generation.
131 |
132 | Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
133 |
134 | ### Examples
135 |
136 | To test generation latency (e.g. batch size = 1) with different sampling strategies:
137 |
138 | ```
139 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
140 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
141 | ```
142 |
143 | To test generation throughput with random prompts (e.g. large batch size):
144 | ```
145 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
146 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
147 | ```
148 |
149 |
150 | ## Troubleshooting
151 |
152 | ### Precision
153 | Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
154 | On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).
155 |
156 | We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
157 | as a first step please try a framework storing parameters in fp32 (such as AMP).
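
For reference, a minimal AMP-style training step (an illustrative skeleton with a stand-in model, not this repo's trainer) that keeps master weights in fp32 while computing in half precision:

```
import torch

model = torch.nn.Linear(16, 16).cuda()           # stand-in for a Mamba model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 16, device="cuda")
optimizer.zero_grad()
with torch.cuda.amp.autocast(dtype=torch.float16):
    loss = model(x).pow(2).mean()                # forward pass runs in fp16
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
assert model.weight.dtype == torch.float32       # parameters remain in fp32
```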
158 |
159 | ### Initialization
160 | Some parts of the model have initializations inherited from prior work on S4 models.
161 | For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection.
162 | However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
163 | If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
164 | that is specific to the training framework.
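
A sketch of such a guard, following the `_no_reinit` convention that `_init_weights` in `mixer_seq_simple.py` checks (the hook name here is hypothetical; adapt it to your framework):

```
import torch.nn as nn

def zero_linear_biases(module):
    # Typical framework post-init hook, amended to skip parameters that the
    # model flagged with `_no_reinit` (the convention _init_weights respects).
    if isinstance(module, nn.Linear) and module.bias is not None:
        if not getattr(module.bias, "_no_reinit", False):
            nn.init.zeros_(module.bias)

layer = nn.Linear(8, 8)
layer.bias._no_reinit = True       # e.g. set on the Δ projection bias
layer.apply(zero_linear_biases)    # this bias is left untouched
```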
165 |
166 |
167 | ## Citation
168 |
169 | If you use this codebase, or otherwise find our work valuable, please cite Mamba:
170 | ```
171 | @article{mamba,
172 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
173 | author={Gu, Albert and Dao, Tri},
174 | journal={arXiv preprint arXiv:2312.00752},
175 | year={2023}
176 | }
177 | ```
178 |
--------------------------------------------------------------------------------
/mamba-1p1p1/assets/selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/assets/selection.png
--------------------------------------------------------------------------------
/mamba-1p1p1/benchmarks/benchmark_generation_mamba_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 |
3 | import argparse
4 | import time
5 | import json
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from einops import rearrange
11 |
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 |
14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Generation benchmarking")
18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19 | parser.add_argument("--prompt", type=str, default=None)
20 | parser.add_argument("--promptlen", type=int, default=100)
21 | parser.add_argument("--genlen", type=int, default=100)
22 | parser.add_argument("--temperature", type=float, default=1.0)
23 | parser.add_argument("--topk", type=int, default=1)
24 | parser.add_argument("--topp", type=float, default=1.0)
25 | parser.add_argument("--repetition-penalty", type=float, default=1.0)
26 | parser.add_argument("--batch", type=int, default=1)
27 | args = parser.parse_args()
28 |
29 | repeats = 3
30 | device = "cuda"
31 | dtype = torch.float16
32 |
33 | print(f"Loading model {args.model_name}")
34 | is_mamba = args.model_name.startswith("state-spaces/mamba-")
35 | if is_mamba:
36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
38 | else:
39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
41 | model.eval()
42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
43 |
44 | torch.random.manual_seed(0)
45 | if args.prompt is None:
46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
48 | else:
49 | tokens = tokenizer(args.prompt, return_tensors="pt")
50 | input_ids = tokens.input_ids.to(device=device)
51 | attn_mask = tokens.attention_mask.to(device=device)
52 | max_length = input_ids.shape[1] + args.genlen
53 |
54 | if is_mamba:
55 | fn = lambda: model.generate(
56 | input_ids=input_ids,
57 | max_length=max_length,
58 | cg=True,
59 | return_dict_in_generate=True,
60 | output_scores=True,
61 | enable_timing=False,
62 | temperature=args.temperature,
63 | top_k=args.topk,
64 | top_p=args.topp,
65 | repetition_penalty=args.repetition_penalty,
66 | )
67 | else:
68 | fn = lambda: model.generate(
69 | input_ids=input_ids,
70 | attention_mask=attn_mask,
71 | max_length=max_length,
72 | return_dict_in_generate=True,
73 | pad_token_id=tokenizer.eos_token_id,
74 | do_sample=True,
75 | temperature=args.temperature,
76 | top_k=args.topk,
77 | top_p=args.topp,
78 | repetition_penalty=args.repetition_penalty,
79 | )
80 | out = fn()
81 | if args.prompt is not None:
82 | print(tokenizer.batch_decode(out.sequences.tolist()))
83 |
84 | torch.cuda.synchronize()
85 | start = time.time()
86 | for _ in range(repeats):
87 | fn()
88 | torch.cuda.synchronize()
89 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
90 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
91 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct SSMScanParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, seqlen, n_chunks;
13 | index_t a_batch_stride;
14 | index_t b_batch_stride;
15 | index_t out_batch_stride;
16 |
17 | // Common data pointers.
18 | void *__restrict__ a_ptr;
19 | void *__restrict__ b_ptr;
20 | void *__restrict__ out_ptr;
21 | void *__restrict__ x_ptr;
22 | };
23 |
24 | ////////////////////////////////////////////////////////////////////////////////////////////////////
25 |
26 | struct SSMParamsBase {
27 | using index_t = uint32_t;
28 |
29 | int batch, dim, seqlen, dstate, n_groups, n_chunks;
30 | int dim_ngroups_ratio;
31 | bool is_variable_B;
32 | bool is_variable_C;
33 |
34 | bool delta_softplus;
35 |
36 | index_t A_d_stride;
37 | index_t A_dstate_stride;
38 | index_t B_batch_stride;
39 | index_t B_d_stride;
40 | index_t B_dstate_stride;
41 | index_t B_group_stride;
42 | index_t C_batch_stride;
43 | index_t C_d_stride;
44 | index_t C_dstate_stride;
45 | index_t C_group_stride;
46 | index_t u_batch_stride;
47 | index_t u_d_stride;
48 | index_t delta_batch_stride;
49 | index_t delta_d_stride;
50 | index_t z_batch_stride;
51 | index_t z_d_stride;
52 | index_t out_batch_stride;
53 | index_t out_d_stride;
54 | index_t out_z_batch_stride;
55 | index_t out_z_d_stride;
56 |
57 | // Common data pointers.
58 | void *__restrict__ A_ptr;
59 | void *__restrict__ B_ptr;
60 | void *__restrict__ C_ptr;
61 | void *__restrict__ D_ptr;
62 | void *__restrict__ u_ptr;
63 | void *__restrict__ delta_ptr;
64 | void *__restrict__ delta_bias_ptr;
65 | void *__restrict__ out_ptr;
66 | void *__restrict__ x_ptr;
67 | void *__restrict__ z_ptr;
68 | void *__restrict__ out_z_ptr;
69 | };
70 |
71 | struct SSMParamsBwd: public SSMParamsBase {
72 | index_t dout_batch_stride;
73 | index_t dout_d_stride;
74 | index_t dA_d_stride;
75 | index_t dA_dstate_stride;
76 | index_t dB_batch_stride;
77 | index_t dB_group_stride;
78 | index_t dB_d_stride;
79 | index_t dB_dstate_stride;
80 | index_t dC_batch_stride;
81 | index_t dC_group_stride;
82 | index_t dC_d_stride;
83 | index_t dC_dstate_stride;
84 | index_t du_batch_stride;
85 | index_t du_d_stride;
86 | index_t dz_batch_stride;
87 | index_t dz_d_stride;
88 | index_t ddelta_batch_stride;
89 | index_t ddelta_d_stride;
90 |
91 | // Common data pointers.
92 | void *__restrict__ dout_ptr;
93 | void *__restrict__ dA_ptr;
94 | void *__restrict__ dB_ptr;
95 | void *__restrict__ dC_ptr;
96 | void *__restrict__ dD_ptr;
97 | void *__restrict__ du_ptr;
98 | void *__restrict__ dz_ptr;
99 | void *__restrict__ ddelta_ptr;
100 | void *__restrict__ ddelta_bias_ptr;
101 | };
102 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 | #include <c10/util/complex.h>  // For scalar_value_type
10 |
11 | #define MAX_DSTATE 256
12 |
13 | using complex_t = c10::complex<float>;
14 |
15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){
16 | return {a.x + b.x, a.y + b.y};
17 | }
18 |
19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20 | return {a.x + b.x, a.y + b.y, a.z + b.z};
21 | }
22 |
23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){
24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25 | }
26 |
27 | ////////////////////////////////////////////////////////////////////////////////////////////////////
28 |
29 | template<int BYTES> struct BytesToType {};
30 |
31 | template<> struct BytesToType<16> {
32 | using Type = uint4;
33 | static_assert(sizeof(Type) == 16);
34 | };
35 |
36 | template<> struct BytesToType<8> {
37 | using Type = uint64_t;
38 | static_assert(sizeof(Type) == 8);
39 | };
40 |
41 | template<> struct BytesToType<4> {
42 | using Type = uint32_t;
43 | static_assert(sizeof(Type) == 4);
44 | };
45 |
46 | template<> struct BytesToType<2> {
47 | using Type = uint16_t;
48 | static_assert(sizeof(Type) == 2);
49 | };
50 |
51 | template<> struct BytesToType<1> {
52 | using Type = uint8_t;
53 | static_assert(sizeof(Type) == 1);
54 | };
55 |
56 | ////////////////////////////////////////////////////////////////////////////////////////////////////
57 |
58 | template<typename scalar_t, int N>
59 | struct Converter{
60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61 | #pragma unroll
62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63 | }
64 | };
65 |
66 | template<int N>
67 | struct Converter<at::Half, N>{
68 |     static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69 |         static_assert(N % 2 == 0);
70 |         auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71 |         auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72 | #pragma unroll
73 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74 | }
75 | };
76 |
77 | #if __CUDA_ARCH__ >= 800
78 | template<int N>
79 | struct Converter<at::BFloat16, N>{
80 |     static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81 |         static_assert(N % 2 == 0);
82 |         auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83 |         auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84 | #pragma unroll
85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86 | }
87 | };
88 | #endif
89 |
90 | ////////////////////////////////////////////////////////////////////////////////////////////////////
91 |
92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95 | float t = exp2f(z.real_);
96 | float c, s;
97 | sincosf(z.imag_, &s, &c);
98 | return complex_t(c * t, s * t);
99 | }
100 |
101 | __device__ __forceinline__ complex_t cexpf(complex_t z) {
102 | float t = expf(z.real_);
103 | float c, s;
104 | sincosf(z.imag_, &s, &c);
105 | return complex_t(c * t, s * t);
106 | }
107 |
108 | template<typename scalar_t> struct SSMScanOp;
109 |
110 | template<>
111 | struct SSMScanOp<float> {
112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114 | }
115 | };
116 |
117 | template<>
118 | struct SSMScanOp<complex_t> {
119 | __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120 | complex_t a0 = complex_t(ab0.x, ab0.y);
121 | complex_t b0 = complex_t(ab0.z, ab0.w);
122 | complex_t a1 = complex_t(ab1.x, ab1.y);
123 | complex_t b1 = complex_t(ab1.z, ab1.w);
124 | complex_t out_a = a1 * a0;
125 | complex_t out_b = a1 * b0 + b1;
126 | return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127 | }
128 | };
129 |
130 | // A stateful callback functor that maintains a running prefix to be applied
131 | // during consecutive scan operations.
132 | template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133 |     using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134 | scan_t running_prefix;
135 | // Constructor
136 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137 | // Callback operator to be entered by the first warp of threads in the block.
138 | // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139 | __device__ scan_t operator()(scan_t block_aggregate) {
140 | scan_t old_prefix = running_prefix;
141 |         running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
142 | return old_prefix;
143 | }
144 | };
145 |
146 | ////////////////////////////////////////////////////////////////////////////////////////////////////
147 |
148 | template<typename Ktraits>
149 | inline __device__ void load_input(typename Ktraits::input_t *u,
150 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151 | typename Ktraits::BlockLoadT::TempStorage &smem_load,
152 | int seqlen) {
153 | if constexpr (Ktraits::kIsEvenLen) {
154 |         auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155 |         using vec_t = typename Ktraits::vec_t;
156 |         Ktraits::BlockLoadVecT(smem_load_vec).Load(
157 |             reinterpret_cast<vec_t*>(u),
158 |             reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159 | );
160 | } else {
161 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162 | }
163 | }
164 |
165 | template<typename Ktraits>
166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169 | int seqlen) {
170 | constexpr int kNItems = Ktraits::kNItems;
171 | if constexpr (!Ktraits::kIsComplex) {
172 | typename Ktraits::input_t B_vals_load[kNItems];
173 | if constexpr (Ktraits::kIsEvenLen) {
174 |             auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175 |             using vec_t = typename Ktraits::vec_t;
176 |             Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177 |                 reinterpret_cast<vec_t*>(Bvar),
178 |                 reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179 | );
180 | } else {
181 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182 | }
183 | // #pragma unroll
184 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185 |         Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
186 | } else {
187 | typename Ktraits::input_t B_vals_load[kNItems * 2];
188 | if constexpr (Ktraits::kIsEvenLen) {
189 |             auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190 |             using vec_t = typename Ktraits::vec_t;
191 |             Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192 |                 reinterpret_cast<vec_t*>(Bvar),
193 |                 reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194 | );
195 | } else {
196 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197 | }
198 | #pragma unroll
199 | for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200 | }
201 | }
202 |
203 | template<typename Ktraits>
204 | inline __device__ void store_output(typename Ktraits::input_t *out,
205 | const float (&out_vals)[Ktraits::kNItems],
206 | typename Ktraits::BlockStoreT::TempStorage &smem_store,
207 | int seqlen) {
208 | typename Ktraits::input_t write_vals[Ktraits::kNItems];
209 | #pragma unroll
210 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211 | if constexpr (Ktraits::kIsEvenLen) {
212 |         auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213 |         using vec_t = typename Ktraits::vec_t;
214 |         Ktraits::BlockStoreVecT(smem_store_vec).Store(
215 |             reinterpret_cast<vec_t*>(out),
216 |             reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217 | );
218 | } else {
219 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220 | }
221 | }
222 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_bf16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp32.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/uninitialized_copy.cuh:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Redistribution and use in source and binary forms, with or without
5 | * modification, are permitted provided that the following conditions are met:
6 | * * Redistributions of source code must retain the above copyright
7 | * notice, this list of conditions and the following disclaimer.
8 | * * Redistributions in binary form must reproduce the above copyright
9 | * notice, this list of conditions and the following disclaimer in the
10 | * documentation and/or other materials provided with the distribution.
11 | * * Neither the name of the NVIDIA CORPORATION nor the
12 | * names of its contributors may be used to endorse or promote products
13 | * derived from this software without specific prior written permission.
14 | *
15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | *
26 | ******************************************************************************/
27 |
28 | #pragma once
29 |
30 | #include <cub/config.cuh>
31 |
32 | #include <cuda/std/type_traits>
33 |
34 |
35 | namespace detail
36 | {
37 |
38 | #if defined(_NVHPC_CUDA)
39 | template <typename T, typename U>
40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41 | {
42 | // NVBug 3384810
43 |   new (ptr) T(::cuda::std::forward<U>(val));
44 | }
45 | #else
46 | template <typename T,
47 |           typename U,
48 |           typename ::cuda::std::enable_if<
49 |             ::cuda::std::is_trivially_copyable<T>::value,
50 |             int
51 |           >::type = 0>
52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53 | {
54 |   *ptr = ::cuda::std::forward<U>(val);
55 | }
56 |
57 | template <typename T,
58 |           typename U,
59 |           typename ::cuda::std::enable_if<
60 |             !::cuda::std::is_trivially_copyable<T>::value,
61 |             int
62 |           >::type = 0>
63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64 | {
65 | new (ptr) T(::cuda::std::forward(val));
66 | }
67 | #endif
68 |
69 | } // namespace detail
70 |
--------------------------------------------------------------------------------
/mamba-1p1p1/evals/lm_harness_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import transformers
4 | from transformers import AutoTokenizer
5 |
6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7 |
8 | from lm_eval.api.model import LM
9 | from lm_eval.models.huggingface import HFLM
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.__main__ import cli_evaluate
12 |
13 |
14 | @register_model("mamba")
15 | class MambaEvalWrapper(HFLM):
16 |
17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18 |
19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20 | dtype=torch.float16):
21 | LM.__init__(self)
22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25 | self.vocab_size = self.tokenizer.vocab_size
26 | self._batch_size = int(batch_size) if batch_size is not None else 64
27 | self._max_length = max_length
28 | self._device = torch.device(device)
29 |
30 | @property
31 | def batch_size(self):
32 | return self._batch_size
33 |
34 | def _model_generate(self, context, max_length, stop, **generation_kwargs):
35 | raise NotImplementedError()
36 |
37 |
38 | if __name__ == "__main__":
39 | cli_evaluate()
40 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.1.1"
2 |
3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4 | from mamba_ssm.modules.mamba_simple import Mamba
5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
6 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/mamba_ssm/models/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/models/config_mamba.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 |
4 | @dataclass
5 | class MambaConfig:
6 |
7 | d_model: int = 2560
8 | n_layer: int = 64
9 | vocab_size: int = 50277
10 | ssm_cfg: dict = field(default_factory=dict)
11 | rms_norm: bool = True
12 | residual_in_fp32: bool = True
13 | fused_add_norm: bool = True
14 | pad_vocab_size_multiple: int = 8
15 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/models/mixer_seq_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Albert Gu, Tri Dao.
2 |
3 | import math
4 | from functools import partial
5 | import json
6 | import os
7 |
8 | from collections import namedtuple
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from mamba_ssm.models.config_mamba import MambaConfig
14 | from mamba_ssm.modules.mamba_simple import Mamba, Block
15 | from mamba_ssm.utils.generation import GenerationMixin
16 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
17 |
18 | try:
19 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
20 | except ImportError:
21 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
22 |
23 |
24 | def create_block(
25 | d_model,
26 | ssm_cfg=None,
27 | norm_epsilon=1e-5,
28 | rms_norm=False,
29 | residual_in_fp32=False,
30 | fused_add_norm=False,
31 | layer_idx=None,
32 | device=None,
33 | dtype=None,
34 | ):
35 | if ssm_cfg is None:
36 | ssm_cfg = {}
37 | factory_kwargs = {"device": device, "dtype": dtype}
38 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
39 | norm_cls = partial(
40 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
41 | )
42 | block = Block(
43 | d_model,
44 | mixer_cls,
45 | norm_cls=norm_cls,
46 | fused_add_norm=fused_add_norm,
47 | residual_in_fp32=residual_in_fp32,
48 | )
49 | block.layer_idx = layer_idx
50 | return block
51 |
52 |
53 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
54 | def _init_weights(
55 | module,
56 | n_layer,
57 | initializer_range=0.02, # Now only used for embedding layer.
58 | rescale_prenorm_residual=True,
59 | n_residuals_per_layer=1, # Change to 2 if we have MLP
60 | ):
61 | if isinstance(module, nn.Linear):
62 | if module.bias is not None:
63 | if not getattr(module.bias, "_no_reinit", False):
64 | nn.init.zeros_(module.bias)
65 | elif isinstance(module, nn.Embedding):
66 | nn.init.normal_(module.weight, std=initializer_range)
67 |
68 | if rescale_prenorm_residual:
69 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
70 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
71 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
72 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/
73 | #
74 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
75 | for name, p in module.named_parameters():
76 | if name in ["out_proj.weight", "fc2.weight"]:
77 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
78 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
79 | # We need to reinit p since this code could be called multiple times
80 | # Having just p *= scale would repeatedly scale it down
81 | nn.init.kaiming_uniform_(p, a=math.sqrt(5))
82 | with torch.no_grad():
83 | p /= math.sqrt(n_residuals_per_layer * n_layer)
84 |
85 |
86 | class MixerModel(nn.Module):
87 | def __init__(
88 | self,
89 | d_model: int,
90 | n_layer: int,
91 | vocab_size: int,
92 | ssm_cfg=None,
93 | norm_epsilon: float = 1e-5,
94 | rms_norm: bool = False,
95 | initializer_cfg=None,
96 | fused_add_norm=False,
97 | residual_in_fp32=False,
98 | device=None,
99 | dtype=None,
100 | ) -> None:
101 | factory_kwargs = {"device": device, "dtype": dtype}
102 | super().__init__()
103 | self.residual_in_fp32 = residual_in_fp32
104 |
105 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
106 |
107 | # We change the order of residual and layer norm:
108 | # Instead of LN -> Attn / MLP -> Add, we do:
109 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
110 | # the main branch (output of MLP / Mixer). The model definition is unchanged.
111 | # This is for performance reason: we can fuse add + layer_norm.
112 | self.fused_add_norm = fused_add_norm
113 | if self.fused_add_norm:
114 | if layer_norm_fn is None or rms_norm_fn is None:
115 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
116 |
117 | self.layers = nn.ModuleList(
118 | [
119 | create_block(
120 | d_model,
121 | ssm_cfg=ssm_cfg,
122 | norm_epsilon=norm_epsilon,
123 | rms_norm=rms_norm,
124 | residual_in_fp32=residual_in_fp32,
125 | fused_add_norm=fused_add_norm,
126 | layer_idx=i,
127 | **factory_kwargs,
128 | )
129 | for i in range(n_layer)
130 | ]
131 | )
132 |
133 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
134 | d_model, eps=norm_epsilon, **factory_kwargs
135 | )
136 |
137 | self.apply(
138 | partial(
139 | _init_weights,
140 | n_layer=n_layer,
141 | **(initializer_cfg if initializer_cfg is not None else {}),
142 | )
143 | )
144 |
145 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
146 | return {
147 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
148 | for i, layer in enumerate(self.layers)
149 | }
150 |
151 | def forward(self, input_ids, inference_params=None):
152 | hidden_states = self.embedding(input_ids)
153 | residual = None
154 | for layer in self.layers:
155 | hidden_states, residual = layer(
156 | hidden_states, residual, inference_params=inference_params
157 | )
158 | if not self.fused_add_norm:
159 | residual = (hidden_states + residual) if residual is not None else hidden_states
160 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
161 | else:
162 | # Set prenorm=False here since we don't need the residual
163 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
164 | hidden_states = fused_add_norm_fn(
165 | hidden_states,
166 | self.norm_f.weight,
167 | self.norm_f.bias,
168 | eps=self.norm_f.eps,
169 | residual=residual,
170 | prenorm=False,
171 | residual_in_fp32=self.residual_in_fp32,
172 | )
173 | return hidden_states
174 |
175 |
176 | class MambaLMHeadModel(nn.Module, GenerationMixin):
177 |
178 | def __init__(
179 | self,
180 | config: MambaConfig,
181 | initializer_cfg=None,
182 | device=None,
183 | dtype=None,
184 | ) -> None:
185 | self.config = config
186 | d_model = config.d_model
187 | n_layer = config.n_layer
188 | vocab_size = config.vocab_size
189 | ssm_cfg = config.ssm_cfg
190 | rms_norm = config.rms_norm
191 | residual_in_fp32 = config.residual_in_fp32
192 | fused_add_norm = config.fused_add_norm
193 | pad_vocab_size_multiple = config.pad_vocab_size_multiple
194 | factory_kwargs = {"device": device, "dtype": dtype}
195 |
196 | super().__init__()
197 | if vocab_size % pad_vocab_size_multiple != 0:
198 | vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
199 | self.backbone = MixerModel(
200 | d_model=d_model,
201 | n_layer=n_layer,
202 | vocab_size=vocab_size,
203 | ssm_cfg=ssm_cfg,
204 | rms_norm=rms_norm,
205 | initializer_cfg=initializer_cfg,
206 | fused_add_norm=fused_add_norm,
207 | residual_in_fp32=residual_in_fp32,
208 | **factory_kwargs,
209 | )
210 | self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
211 |
212 | # Initialize weights and apply final processing
213 | self.apply(
214 | partial(
215 | _init_weights,
216 | n_layer=n_layer,
217 | **(initializer_cfg if initializer_cfg is not None else {}),
218 | )
219 | )
220 | self.tie_weights()
221 |
222 | def tie_weights(self):
223 | self.lm_head.weight = self.backbone.embedding.weight
224 |
225 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
226 | return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
227 |
228 | def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
229 | """
230 | "position_ids" is just to be compatible with Transformer generation. We don't use it.
231 | num_last_tokens: if > 0, only return the logits for the last n tokens
232 | """
233 | hidden_states = self.backbone(input_ids, inference_params=inference_params)
234 | if num_last_tokens > 0:
235 | hidden_states = hidden_states[:, -num_last_tokens:]
236 | lm_logits = self.lm_head(hidden_states)
237 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
238 | return CausalLMOutput(logits=lm_logits)
239 |
240 | @classmethod
241 | def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
242 | config_data = load_config_hf(pretrained_model_name)
243 | config = MambaConfig(**config_data)
244 | model = cls(config, device=device, dtype=dtype, **kwargs)
245 | model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
246 | return model
247 |
248 | def save_pretrained(self, save_directory):
249 | """
250 | Minimal implementation of save_pretrained for MambaLMHeadModel.
251 | Save the model and its configuration file to a directory.
252 | """
253 | # Ensure save_directory exists
254 | if not os.path.exists(save_directory):
255 | os.makedirs(save_directory)
256 |
257 | # Save the model's state_dict
258 | model_path = os.path.join(save_directory, 'pytorch_model.bin')
259 | torch.save(self.state_dict(), model_path)
260 |
261 | # Save the configuration of the model
262 | config_path = os.path.join(save_directory, 'config.json')
263 | with open(config_path, 'w') as f:
264 | json.dump(self.config.__dict__, f)
265 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/mamba_ssm/modules/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/mamba_ssm/ops/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/ops/triton/selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | """We want triton==2.1.0 for this
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import triton
11 | import triton.language as tl
12 |
13 | from einops import rearrange, repeat
14 |
15 |
16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20 | @triton.jit
21 | def _selective_scan_update_kernel(
22 | # Pointers to matrices
23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24 | # Matrix dimensions
25 | batch, dim, dstate,
26 | # Strides
27 | stride_state_batch, stride_state_dim, stride_state_dstate,
28 | stride_x_batch, stride_x_dim,
29 | stride_dt_batch, stride_dt_dim,
30 | stride_dt_bias_dim,
31 | stride_A_dim, stride_A_dstate,
32 | stride_B_batch, stride_B_dstate,
33 | stride_C_batch, stride_C_dstate,
34 | stride_D_dim,
35 | stride_z_batch, stride_z_dim,
36 | stride_out_batch, stride_out_dim,
37 | # Meta-parameters
38 | DT_SOFTPLUS: tl.constexpr,
39 | BLOCK_SIZE_M: tl.constexpr,
40 | HAS_DT_BIAS: tl.constexpr,
41 | HAS_D: tl.constexpr,
42 | HAS_Z: tl.constexpr,
43 | BLOCK_SIZE_DSTATE: tl.constexpr,
44 | ):
45 | pid_m = tl.program_id(axis=0)
46 | pid_b = tl.program_id(axis=1)
47 | state_ptr += pid_b * stride_state_batch
48 | x_ptr += pid_b * stride_x_batch
49 | dt_ptr += pid_b * stride_dt_batch
50 | B_ptr += pid_b * stride_B_batch
51 | C_ptr += pid_b * stride_C_batch
52 | if HAS_Z:
53 | z_ptr += pid_b * stride_z_batch
54 | out_ptr += pid_b * stride_out_batch
55 |
56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
59 | x_ptrs = x_ptr + offs_m * stride_x_dim
60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim
61 | if HAS_DT_BIAS:
62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
64 | B_ptrs = B_ptr + offs_n * stride_B_dstate
65 | C_ptrs = C_ptr + offs_n * stride_C_dstate
66 | if HAS_D:
67 | D_ptrs = D_ptr + offs_m * stride_D_dim
68 | if HAS_Z:
69 | z_ptrs = z_ptr + offs_m * stride_z_dim
70 | out_ptrs = out_ptr + offs_m * stride_out_dim
71 |
72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
75 | if HAS_DT_BIAS:
76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
77 | if DT_SOFTPLUS:
78 | dt = tl.log(1.0 + tl.exp(dt))
79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
80 | dA = tl.exp(A * dt[:, None])
81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
83 | if HAS_D:
84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85 | if HAS_Z:
86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
87 |
88 | dB = B[None, :] * dt[:, None]
89 | state = state * dA + dB * x[:, None]
90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
91 | out = tl.sum(state * C[None, :], axis=1)
92 | if HAS_D:
93 | out += x * D
94 | if HAS_Z:
95 | out *= z * tl.sigmoid(z)
96 | tl.store(out_ptrs, out, mask=offs_m < dim)
97 |
98 |
99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
100 | """
101 | Argument:
102 | state: (batch, dim, dstate)
103 | x: (batch, dim)
104 | dt: (batch, dim)
105 | A: (dim, dstate)
106 | B: (batch, dstate)
107 | C: (batch, dstate)
108 | D: (dim,)
109 | z: (batch, dim)
110 | dt_bias: (dim,)
111 | Return:
112 | out: (batch, dim)
113 | """
114 | batch, dim, dstate = state.shape
115 | assert x.shape == (batch, dim)
116 | assert dt.shape == x.shape
117 | assert A.shape == (dim, dstate)
118 | assert B.shape == (batch, dstate)
119 | assert C.shape == B.shape
120 | if D is not None:
121 | assert D.shape == (dim,)
122 | if z is not None:
123 | assert z.shape == x.shape
124 | if dt_bias is not None:
125 | assert dt_bias.shape == (dim,)
126 | out = torch.empty_like(x)
127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
129 | # We don't want autotune since it will overwrite the state
130 | # We instead tune by hand.
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
132 | else (16, 4) if dstate <= 32
133 | else (8, 4) if dstate <= 64
134 | else (4, 4) if dstate <= 128
135 | else (4, 8))
136 | with torch.cuda.device(x.device.index):
137 | _selective_scan_update_kernel[grid](
138 | state, x, dt, dt_bias, A, B, C, D, z, out,
139 | batch, dim, dstate,
140 | state.stride(0), state.stride(1), state.stride(2),
141 | x.stride(0), x.stride(1),
142 | dt.stride(0), dt.stride(1),
143 | dt_bias.stride(0) if dt_bias is not None else 0,
144 | A.stride(0), A.stride(1),
145 | B.stride(0), B.stride(1),
146 | C.stride(0), C.stride(1),
147 | D.stride(0) if D is not None else 0,
148 | z_strides[0], z_strides[1],
149 | out.stride(0), out.stride(1),
150 | dt_softplus,
151 | BLOCK_SIZE_M,
152 | num_warps=num_warps,
153 | )
154 | return out
155 |
156 |
157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
158 | """
159 | Argument:
160 | state: (batch, dim, dstate)
161 | x: (batch, dim)
162 | dt: (batch, dim)
163 | A: (dim, dstate)
164 | B: (batch, dstate)
165 | C: (batch, dstate)
166 | D: (dim,)
167 | z: (batch, dim)
168 | dt_bias: (dim,)
169 | Return:
170 | out: (batch, dim)
171 | """
172 | batch, dim, dstate = state.shape
173 | assert x.shape == (batch, dim)
174 | assert dt.shape == x.shape
175 | assert A.shape == (dim, dstate)
176 | assert B.shape == (batch, dstate)
177 | assert C.shape == B.shape
178 | if D is not None:
179 | assert D.shape == (dim,)
180 | if z is not None:
181 | assert z.shape == x.shape
182 | if dt_bias is not None:
183 | assert dt_bias.shape == (dim,)
184 | dt = dt + dt_bias
185 | dt = F.softplus(dt) if dt_softplus else dt
186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate)
189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
190 | if D is not None:
191 | out += (x * D).to(out.dtype)
192 | return (out if z is None else out * F.silu(z)).to(x.dtype)
193 |
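194 | # A minimal usage sketch of the single-step recurrence above, using the
195 | # pure-PyTorch reference (shapes follow the docstring; values are illustrative):
196 | #
197 | # batch, dim, dstate = 2, 4, 8
198 | # state = torch.zeros(batch, dim, dstate)
199 | # x = torch.randn(batch, dim)
200 | # dt = torch.randn(batch, dim)
201 | # A = -torch.rand(dim, dstate)  # negative entries keep exp(dt * A) in (0, 1)
202 | # B = torch.randn(batch, dstate)
203 | # C = torch.randn(batch, dstate)
204 | # dt_bias = torch.rand(dim) - 4.0
205 | # out = selective_state_update_ref(state, x, dt, A, B, C, dt_bias=dt_bias, dt_softplus=True)
206 | # # `state` is updated in place: state <- state * exp(dt * A) + (dt * B) * x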
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/mamba-1p1p1/mamba_ssm/utils/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/utils/hf.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 |
8 |
9 | def load_config_hf(model_name):
10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 | return json.load(open(resolved_archive_file))
12 |
13 |
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 | # If not fp32, then we don't want to load directly to the GPU
16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 | # Convert dtype before moving to GPU to save memory
20 | if dtype is not None:
21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 | return state_dict
24 |
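25 | # A minimal usage sketch (assumes "state-spaces/mamba-130m" is a valid model id
26 | # on the Hugging Face Hub):
27 | #
28 | # cfg = load_config_hf("state-spaces/mamba-130m")
29 | # sd = load_state_dict_hf("state-spaces/mamba-130m", device="cuda", dtype=torch.bfloat16)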
--------------------------------------------------------------------------------
/mamba-1p1p1/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Albert Gu, Tri Dao.
2 | import sys
3 | import warnings
4 | import os
5 | import re
6 | import ast
7 | from pathlib import Path
8 | from packaging.version import parse, Version
9 | import platform
10 | import shutil
11 |
12 | from setuptools import setup, find_packages
13 | import subprocess
14 |
15 | import urllib.request
16 | import urllib.error
17 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
18 |
19 | import torch
20 | from torch.utils.cpp_extension import (
21 | BuildExtension,
22 | CppExtension,
23 | CUDAExtension,
24 | CUDA_HOME,
25 | )
26 |
27 |
28 | with open("README.md", "r", encoding="utf-8") as fh:
29 | long_description = fh.read()
30 |
31 |
32 | # ninja build does not work unless include_dirs are abs path
33 | this_dir = os.path.dirname(os.path.abspath(__file__))
34 |
35 | PACKAGE_NAME = "mamba_ssm"
36 |
37 | BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"
38 |
39 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
40 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
41 | FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
42 | SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
43 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
44 | FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
45 |
46 |
47 | def get_platform():
48 | """
49 | Returns the platform name as used in wheel filenames.
50 | """
51 | if sys.platform.startswith("linux"):
52 | return "linux_x86_64"
53 | elif sys.platform == "darwin":
54 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
55 | return f"macosx_{mac_version}_x86_64"
56 | elif sys.platform == "win32":
57 | return "win_amd64"
58 | else:
59 | raise ValueError("Unsupported platform: {}".format(sys.platform))
60 |
61 |
62 | def get_cuda_bare_metal_version(cuda_dir):
63 | raw_output = subprocess.check_output(
64 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
65 | )
66 | output = raw_output.split()
67 | release_idx = output.index("release") + 1
68 | bare_metal_version = parse(output[release_idx].split(",")[0])
69 |
70 | return raw_output, bare_metal_version
71 |
72 |
73 | def check_if_cuda_home_none(global_option: str) -> None:
74 | if CUDA_HOME is not None:
75 | return
76 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
77 | # in that case.
78 | warnings.warn(
79 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
80 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
81 | "only images whose names contain 'devel' will provide nvcc."
82 | )
83 |
84 |
85 | def append_nvcc_threads(nvcc_extra_args):
86 | return nvcc_extra_args + ["--threads", "4"]
87 |
88 |
89 | cmdclass = {}
90 | ext_modules = []
91 |
92 | if not SKIP_CUDA_BUILD:
93 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
94 | TORCH_MAJOR = int(torch.__version__.split(".")[0])
95 | TORCH_MINOR = int(torch.__version__.split(".")[1])
96 |
97 | check_if_cuda_home_none(PACKAGE_NAME)
98 | # Check if CUDA 11 is installed for compute capability 8.0
99 | cc_flag = []
100 | if CUDA_HOME is not None:
101 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
102 | if bare_metal_version < Version("11.6"):
103 | raise RuntimeError(
104 | f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
105 | "Note: make sure nvcc has a supported version by running nvcc -V."
106 | )
107 |
108 | cc_flag.append("-gencode")
109 | cc_flag.append("arch=compute_70,code=sm_70")
110 | cc_flag.append("-gencode")
111 | cc_flag.append("arch=compute_80,code=sm_80")
112 | if bare_metal_version >= Version("11.8"):
113 | cc_flag.append("-gencode")
114 | cc_flag.append("arch=compute_90,code=sm_90")
115 |
116 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
117 | # torch._C._GLIBCXX_USE_CXX11_ABI
118 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
119 | if FORCE_CXX11_ABI:
120 | torch._C._GLIBCXX_USE_CXX11_ABI = True
121 |
122 | ext_modules.append(
123 | CUDAExtension(
124 | name="selective_scan_cuda",
125 | sources=[
126 | "csrc/selective_scan/selective_scan.cpp",
127 | "csrc/selective_scan/selective_scan_fwd_fp32.cu",
128 | "csrc/selective_scan/selective_scan_fwd_fp16.cu",
129 | "csrc/selective_scan/selective_scan_fwd_bf16.cu",
130 | "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
131 | "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
132 | "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
133 | "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
134 | "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
135 | "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
136 | ],
137 | extra_compile_args={
138 | "cxx": ["-O3", "-std=c++17"],
139 | "nvcc": append_nvcc_threads(
140 | [
141 | "-O3",
142 | "-std=c++17",
143 | "-U__CUDA_NO_HALF_OPERATORS__",
144 | "-U__CUDA_NO_HALF_CONVERSIONS__",
145 | "-U__CUDA_NO_BFLOAT16_OPERATORS__",
146 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
147 | "-U__CUDA_NO_BFLOAT162_OPERATORS__",
148 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
149 | "--expt-relaxed-constexpr",
150 | "--expt-extended-lambda",
151 | "--use_fast_math",
152 | "--ptxas-options=-v",
153 | "-lineinfo",
154 | ]
155 | + cc_flag
156 | ),
157 | },
158 | include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
159 | )
160 | )
161 |
162 |
163 | def get_package_version():
164 | with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
165 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
166 | public_version = ast.literal_eval(version_match.group(1))
167 | local_version = os.environ.get("MAMBA_LOCAL_VERSION")
168 | if local_version:
169 | return f"{public_version}+{local_version}"
170 | else:
171 | return str(public_version)
172 |
173 |
174 | def get_wheel_url():
175 | # Determine the version numbers that will be used to determine the correct wheel
176 | # We're using the CUDA version used to build torch, not the one currently installed
177 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
178 | torch_cuda_version = parse(torch.version.cuda)
179 | torch_version_raw = parse(torch.__version__)
180 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
181 | # to save CI time. Minor versions should be compatible.
182 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
183 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
184 | platform_name = get_platform()
185 | mamba_ssm_version = get_package_version()
186 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
187 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
188 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
189 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
190 |
191 | # Determine wheel URL based on CUDA version, torch version, python version and OS
192 | wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
193 | wheel_url = BASE_WHEEL_URL.format(
194 | tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename
195 | )
196 | return wheel_url, wheel_filename
197 |
198 |
199 | class CachedWheelsCommand(_bdist_wheel):
200 | """
201 | The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it
202 | cannot find an existing wheel (which is currently the case for all installs). We use
203 | the environment parameters to detect whether a compatible pre-built wheel is already
204 | available and, if so, short-circuit the standard full build pipeline.
205 | """
206 |
207 | def run(self):
208 | if FORCE_BUILD:
209 | return super().run()
210 |
211 | wheel_url, wheel_filename = get_wheel_url()
212 | print("Guessing wheel URL: ", wheel_url)
213 | try:
214 | urllib.request.urlretrieve(wheel_url, wheel_filename)
215 |
216 | # Make the archive
217 | # Lifted from the root wheel processing command
218 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
219 | if not os.path.exists(self.dist_dir):
220 | os.makedirs(self.dist_dir)
221 |
222 | impl_tag, abi_tag, plat_tag = self.get_tag()
223 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
224 |
225 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
226 | print("Raw wheel path", wheel_path)
227 | shutil.move(wheel_filename, wheel_path)
228 | except urllib.error.HTTPError:
229 | print("Precompiled wheel not found. Building from source...")
230 | # If the wheel could not be downloaded, build from source
231 | super().run()
232 |
233 |
234 | setup(
235 | name=PACKAGE_NAME,
236 | version=get_package_version(),
237 | packages=find_packages(
238 | exclude=(
239 | "build",
240 | "csrc",
241 | "include",
242 | "tests",
243 | "dist",
244 | "docs",
245 | "benchmarks",
246 | "mamba_ssm.egg-info",
247 | )
248 | ),
249 | author="Tri Dao, Albert Gu",
250 | author_email="tri@tridao.me, agu@cs.cmu.edu",
251 | description="Mamba state-space model",
252 | long_description=long_description,
253 | long_description_content_type="text/markdown",
254 | url="https://github.com/state-spaces/mamba",
255 | classifiers=[
256 | "Programming Language :: Python :: 3",
257 | "License :: OSI Approved :: BSD License",
258 | "Operating System :: Unix",
259 | ],
260 | ext_modules=ext_modules,
261 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
262 | if ext_modules
263 | else {
264 | "bdist_wheel": CachedWheelsCommand,
265 | },
266 | python_requires=">=3.7",
267 | install_requires=[
268 | "torch",
269 | "packaging",
270 | "ninja",
271 | "einops",
272 | "triton",
273 | "transformers",
274 | # "causal_conv1d>=1.1.0",
275 | ],
276 | )
277 |
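278 | # For reference, the wheel filename produced by get_wheel_url() would look like
279 | # this (assuming package version 1.1.1, torch 2.1 built against CUDA 11.8,
280 | # CPython 3.10, Linux, and the pre-C++11 ABI):
281 | #
282 | #   mamba_ssm-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl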
--------------------------------------------------------------------------------
/mamba-1p1p1/tests/ops/test_selective_scan.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import pytest
8 |
9 | from einops import rearrange
10 |
11 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
12 | from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
13 |
14 |
15 | # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
16 | @pytest.mark.parametrize('wtype', [torch.float32])
17 | # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
18 | @pytest.mark.parametrize('itype', [torch.float32])
19 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
20 | @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
21 | # @pytest.mark.parametrize('seqlen', [128])
22 | # @pytest.mark.parametrize("return_last_state", [False, True])
23 | @pytest.mark.parametrize("return_last_state", [True])
24 | # @pytest.mark.parametrize('has_delta_bias', [False, True])
25 | @pytest.mark.parametrize('has_delta_bias', [True])
26 | # @pytest.mark.parametrize('delta_softplus', [False, True])
27 | @pytest.mark.parametrize('delta_softplus', [True])
28 | # @pytest.mark.parametrize('has_z', [False, True])
29 | @pytest.mark.parametrize('has_z', [True])
30 | # @pytest.mark.parametrize('has_D', [False, True])
31 | @pytest.mark.parametrize('has_D', [True])
32 | @pytest.mark.parametrize("varBC_groups", [1, 2])
33 | # @pytest.mark.parametrize("varBC_groups", [1])
34 | # @pytest.mark.parametrize("is_variable_C", [False, True])
35 | @pytest.mark.parametrize("is_variable_C", [True])
36 | # @pytest.mark.parametrize("is_variable_B", [False, True])
37 | @pytest.mark.parametrize("is_variable_B", [True])
38 | def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
39 | delta_softplus, return_last_state, seqlen, itype, wtype):
40 | if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
41 | pytest.skip() # This config is not applicable
42 | device = 'cuda'
43 | rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
44 | if itype == torch.bfloat16:
45 | rtol, atol = 3e-2, 5e-2
46 | rtolw, atolw = (1e-3, 1e-3)
47 | if has_z: # If we have z, the errors on the weights seem higher
48 | rtolw = max(rtolw, rtol)
49 | atolw = max(atolw, atol)
50 | # set seed
51 | torch.random.manual_seed(0)
52 | batch_size = 2
53 | dim = 4
54 | dstate = 8
55 | is_complex = wtype == torch.complex64
56 | A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
57 | if not is_variable_B:
58 | B_shape = (dim, dstate)
59 | elif varBC_groups == 1:
60 | B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
61 | else:
62 | B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
63 | B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
64 | requires_grad=True)
65 | if not is_variable_C:
66 | C_shape = (dim, dstate)
67 | elif varBC_groups == 1:
68 | C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
69 | else:
70 | C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
71 | C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
72 | requires_grad=True)
73 | if has_D:
74 | D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
75 | else:
76 | D = None
77 | if has_z:
78 | z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
79 | else:
80 | z = None
81 | if has_delta_bias:
82 | delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
83 | else:
84 | delta_bias = None
85 | u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
86 | delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
87 | A_ref = A.detach().clone().requires_grad_()
88 | B_ref = B.detach().clone().requires_grad_()
89 | C_ref = C.detach().clone().requires_grad_()
90 | D_ref = D.detach().clone().requires_grad_() if D is not None else None
91 | z_ref = z.detach().clone().requires_grad_() if z is not None else None
92 | u_ref = u.detach().clone().requires_grad_()
93 | delta_ref = delta.detach().clone().requires_grad_()
94 | delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
95 | out, *rest = selective_scan_fn(
96 | u, delta, A, B, C, D, z=z,
97 | delta_bias=delta_bias, delta_softplus=delta_softplus,
98 | return_last_state=return_last_state
99 | )
100 | if return_last_state:
101 | state = rest[0]
102 | out_ref, *rest = selective_scan_ref(
103 | u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
104 | delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
105 | return_last_state=return_last_state
106 | )
107 | if return_last_state:
108 | state_ref = rest[0]
109 | # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
110 | # dt_u = delta * u
111 |
112 | print(f'Output max diff: {(out - out_ref).abs().max().item()}')
113 | print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115 | if return_last_state:
116 | print(f'State max diff: {(state - state_ref).abs().max().item()}')
117 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
118 |
119 | g = torch.randn_like(out)
120 | out_ref.backward(g)
121 | out.backward(g)
122 |
123 | print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
124 | print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
125 | print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
126 | print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
127 | print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
128 | if has_D:
129 | print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
130 | if has_z:
131 | print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
132 | if has_delta_bias:
133 | print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
134 |
135 | assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
136 | assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
137 | assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
138 | assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
139 | atol=atolw if not is_variable_B else atol)
140 | assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
141 | atol=atolw if not is_variable_C else atol)
142 | if has_D:
143 | assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
144 | if has_z:
145 | assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
146 | if has_delta_bias:
147 | assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
148 |
149 |
150 | @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
151 | # @pytest.mark.parametrize('wtype', [torch.complex64])
152 | # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
153 | @pytest.mark.parametrize('itype', [torch.float32])
154 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
155 | @pytest.mark.parametrize('seqlen', [128])
156 | @pytest.mark.parametrize("is_variable_C", [False, True])
157 | # @pytest.mark.parametrize("is_variable_C", [False])
158 | @pytest.mark.parametrize("is_variable_B", [False, True])
159 | # @pytest.mark.parametrize("is_variable_B", [True])
160 | def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
161 | device = 'cuda'
162 | rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
163 | if itype == torch.bfloat16:
164 | rtol, atol = 3e-2, 5e-2
165 | rtolw, atolw = (1e-3, 1e-3)
166 | # If we have z, the errors on the weights seem higher
167 | rtolw = max(rtolw, rtol)
168 | atolw = max(atolw, atol)
169 | # set seed
170 | torch.random.manual_seed(0)
171 | batch_size = 2
172 | dim = 768
173 | dstate = 8
174 | dt_rank = 48
175 | is_complex = wtype == torch.complex64
176 | xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
177 | conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
178 | conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
179 | x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
180 | * (1 if not is_complex else 2),
181 | dim, device=device, dtype=itype, requires_grad=True)
182 | delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
183 | out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
184 | out_proj_bias = None
185 | A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
186 | B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
187 | if not is_variable_B else None)
188 | C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
189 | if not is_variable_C else None)
190 | D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
191 | delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
192 | B_proj_bias = None
193 | C_proj_bias = None
194 | xz_ref = xz.detach().clone().requires_grad_()
195 | conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
196 | conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
197 | x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
198 | delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
199 | out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
200 | out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
201 | if out_proj_bias is not None else None)
202 | A_ref = A.detach().clone().requires_grad_()
203 | B_ref = B.detach().clone().requires_grad_() if B is not None else None
204 | C_ref = C.detach().clone().requires_grad_() if C is not None else None
205 | D_ref = D.detach().clone().requires_grad_()
206 | delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
207 | out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
208 | out_proj_weight, out_proj_bias,
209 | A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
210 | out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
211 | delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
212 | A_ref, B_ref, C_ref, D_ref,
213 | delta_bias=delta_bias_ref, delta_softplus=True)
214 | # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
215 | # dt_u = delta * u
216 |
217 | print(f'Output max diff: {(out - out_ref).abs().max().item()}')
218 | print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
219 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
220 |
221 | g = torch.randn_like(out)
222 | out_ref.backward(g)
223 | out.backward(g)
224 |
225 | print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
226 | print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
227 | if not is_variable_B:
228 | print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
229 | if not is_variable_C:
230 | print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
231 | print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
232 | print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
233 | print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
234 | print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
235 | print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
236 | print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
237 | print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
238 |
239 | # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
240 | # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
241 | # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
242 | # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
243 | # atol=atolw if not is_variable_B else atol)
244 | # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
245 | # atol=atolw if not is_variable_C else atol)
246 | # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
247 | # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
248 |
--------------------------------------------------------------------------------
/mamba-1p1p1/tests/ops/triton/test_selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import pytest
8 |
9 | from einops import rearrange
10 |
11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
12 |
13 |
14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
15 | # @pytest.mark.parametrize('itype', [torch.float16])
16 | @pytest.mark.parametrize("has_z", [False, True])
17 | # @pytest.mark.parametrize('has_z', [True])
18 | @pytest.mark.parametrize("dstate", [16, 32, 64])
19 | # @pytest.mark.parametrize("dstate", [16])
20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
21 | # @pytest.mark.parametrize("dim", [2048])
22 | def test_selective_state_update(dim, dstate, has_z, itype):
23 | device = "cuda"
24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
25 | if itype == torch.bfloat16:
26 | rtol, atol = 1e-2, 5e-2
27 | # set seed
28 | torch.random.manual_seed(0)
29 | batch_size = 2
30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
31 | x = torch.randn(batch_size, dim, device=device, dtype=itype)
32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype)
33 | dt_bias = torch.rand(dim, device=device) - 4.0
34 | A = -torch.rand(dim, dstate, device=device) - 1.0
35 | B = torch.randn(batch_size, dstate, device=device)
36 | C = torch.randn(batch_size, dstate, device=device)
37 | D = torch.randn(dim, device=device)
38 | if has_z:
39 | z = torch.randn_like(x)
40 | else:
41 | z = None
42 | state_ref = state.detach().clone()
43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
45 |
46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
50 |
--------------------------------------------------------------------------------
/vim/augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | """
5 | 3Augment implementation
6 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
7 | and timm DA(https://github.com/rwightman/pytorch-image-models)
8 | """
9 | import torch
10 | from torchvision import transforms
11 |
12 | # error: cannot import name '_pil_interp' from 'timm.data.transforms'
13 | # from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
14 |
15 | # fix: timm version problem
16 | # from timm.data.transforms import str_pil_interp as _pil_interp
17 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
18 |
19 | import numpy as np
20 | from torchvision import datasets, transforms
21 | import random
22 |
23 |
24 |
25 | from PIL import ImageFilter, ImageOps
26 | import torchvision.transforms.functional as TF
27 |
28 |
29 | class GaussianBlur(object):
30 | """
31 | Apply Gaussian Blur to the PIL image.
32 | """
33 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
34 | self.prob = p
35 | self.radius_min = radius_min
36 | self.radius_max = radius_max
37 |
38 | def __call__(self, img):
39 | do_it = random.random() <= self.prob
40 | if not do_it:
41 | return img
42 |
43 | img = img.filter(
44 | ImageFilter.GaussianBlur(
45 | radius=random.uniform(self.radius_min, self.radius_max)
46 | )
47 | )
48 | return img
49 |
50 | class Solarization(object):
51 | """
52 | Apply Solarization to the PIL image.
53 | """
54 | def __init__(self, p=0.2):
55 | self.p = p
56 |
57 | def __call__(self, img):
58 | if random.random() < self.p:
59 | return ImageOps.solarize(img)
60 | else:
61 | return img
62 |
63 | class gray_scale(object):
64 | """
65 | Apply Grayscale to the PIL image.
66 | """
67 | def __init__(self, p=0.2):
68 | self.p = p
69 | self.transf = transforms.Grayscale(3)
70 |
71 | def __call__(self, img):
72 | if random.random() < self.p:
73 | return self.transf(img)
74 | else:
75 | return img
76 |
77 |
78 |
79 | class horizontal_flip(object):
80 | """
81 | Apply a random horizontal flip to the PIL image.
82 | """
83 | def __init__(self, p=0.2,activate_pred=False):
84 | self.p = p
85 | self.transf = transforms.RandomHorizontalFlip(p=1.0)
86 |
87 | def __call__(self, img):
88 | if random.random() < self.p:
89 | return self.transf(img)
90 | else:
91 | return img
92 |
93 |
94 |
95 | def new_data_aug_generator(args = None):
96 | img_size = args.input_size
97 | remove_random_resized_crop = args.src
98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
99 | primary_tfl = []
100 | scale=(0.08, 1.0)
101 | interpolation='bicubic'
102 | if remove_random_resized_crop:
103 | primary_tfl = [
104 | transforms.Resize(img_size, interpolation=3),
105 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
106 | transforms.RandomHorizontalFlip()
107 | ]
108 | else:
109 | primary_tfl = [
110 | RandomResizedCropAndInterpolation(
111 | img_size, scale=scale, interpolation=interpolation),
112 | transforms.RandomHorizontalFlip()
113 | ]
114 |
115 |
116 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
117 | Solarization(p=1.0),
118 | GaussianBlur(p=1.0)])]
119 |
120 | if args.color_jitter is not None and args.color_jitter != 0:
121 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
122 | final_tfl = [
123 | transforms.ToTensor(),
124 | transforms.Normalize(
125 | mean=torch.tensor(mean),
126 | std=torch.tensor(std))
127 | ]
128 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
129 |
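130 | # A minimal usage sketch (the Namespace fields are exactly the ones read above;
131 | # values are illustrative):
132 | #
133 | # from argparse import Namespace
134 | # from PIL import Image
135 | # args = Namespace(input_size=224, src=False, color_jitter=0.3)
136 | # transform = new_data_aug_generator(args)
137 | # x = transform(Image.open("img.jpg").convert("RGB"))  # tensor of shape (3, 224, 224)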
--------------------------------------------------------------------------------
/vim/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import os
4 | import json
5 |
6 | from torchvision import datasets, transforms
7 | from torchvision.datasets.folder import ImageFolder, default_loader
8 |
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 | from timm.data import create_transform
11 |
12 |
13 | class INatDataset(ImageFolder):
14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
15 | category='name', loader=default_loader):
16 | self.transform = transform
17 | self.loader = loader
18 | self.target_transform = target_transform
19 | self.year = year
20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
22 | with open(path_json) as json_file:
23 | data = json.load(json_file)
24 |
25 | with open(os.path.join(root, 'categories.json')) as json_file:
26 | data_catg = json.load(json_file)
27 |
28 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
29 |
30 | with open(path_json_for_targeter) as json_file:
31 | data_for_targeter = json.load(json_file)
32 |
33 | targeter = {}
34 | indexer = 0
35 | for elem in data_for_targeter['annotations']:
36 | king = []
37 | king.append(data_catg[int(elem['category_id'])][category])
38 | if king[0] not in targeter.keys():
39 | targeter[king[0]] = indexer
40 | indexer += 1
41 | self.nb_classes = len(targeter)
42 |
43 | self.samples = []
44 | for elem in data['images']:
45 | cut = elem['file_name'].split('/')
46 | target_current = int(cut[2])
47 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
48 |
49 | categors = data_catg[target_current]
50 | target_current_true = targeter[categors[category]]
51 | self.samples.append((path_current, target_current_true))
52 |
53 | # __getitem__ and __len__ inherited from ImageFolder
54 |
55 |
56 | def build_dataset(is_train, args):
57 | transform = build_transform(is_train, args)
58 |
59 | if args.data_set == 'CIFAR':
60 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
61 | nb_classes = 100
62 | elif args.data_set == 'IMNET':
63 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
64 | dataset = datasets.ImageFolder(root, transform=transform)
65 | nb_classes = 1000
66 | elif args.data_set == 'INAT':
67 | dataset = INatDataset(args.data_path, train=is_train, year=2018,
68 | category=args.inat_category, transform=transform)
69 | nb_classes = dataset.nb_classes
70 | elif args.data_set == 'INAT19':
71 | dataset = INatDataset(args.data_path, train=is_train, year=2019,
72 | category=args.inat_category, transform=transform)
73 | nb_classes = dataset.nb_classes
74 |
75 | return dataset, nb_classes
76 |
77 |
78 | def build_transform(is_train, args):
79 | resize_im = args.input_size > 32
80 | if is_train:
81 | # this should always dispatch to transforms_imagenet_train
82 | transform = create_transform(
83 | input_size=args.input_size,
84 | is_training=True,
85 | color_jitter=args.color_jitter,
86 | auto_augment=args.aa,
87 | interpolation=args.train_interpolation,
88 | re_prob=args.reprob,
89 | re_mode=args.remode,
90 | re_count=args.recount,
91 | )
92 | if not resize_im:
93 | # replace RandomResizedCropAndInterpolation with
94 | # RandomCrop
95 | transform.transforms[0] = transforms.RandomCrop(
96 | args.input_size, padding=4)
97 | return transform
98 |
99 | t = []
100 | if resize_im:
101 | size = int(args.input_size / args.eval_crop_ratio)
102 | t.append(
103 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
104 | )
105 | t.append(transforms.CenterCrop(args.input_size))
106 |
107 | t.append(transforms.ToTensor())
108 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
109 | return transforms.Compose(t)
110 |
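111 | # A minimal usage sketch for the evaluation pipeline (illustrative values; only
112 | # input_size and eval_crop_ratio are read when is_train=False):
113 | #
114 | # from argparse import Namespace
115 | # args = Namespace(input_size=224, eval_crop_ratio=0.875)
116 | # eval_transform = build_transform(is_train=False, args=args)
117 | # # resizes the short side to 224 / 0.875 = 256, then center-crops to 224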
--------------------------------------------------------------------------------
/vim/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Train and eval functions used in main.py
5 | """
6 | import math
7 | import sys
8 | from typing import Iterable, Optional
9 |
10 | import torch
11 |
12 | import timm
13 | from timm.data import Mixup
14 | from timm.utils import accuracy, ModelEma
15 |
16 | from losses import DistillationLoss
17 | import utils
18 |
19 |
20 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
21 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
22 | device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0,
23 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
24 | set_training_mode=True, args = None):
25 | model.train(set_training_mode)
26 | metric_logger = utils.MetricLogger(delimiter=" ")
27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
28 | header = 'Epoch: [{}]'.format(epoch)
29 | print_freq = 10
30 |
31 | if args.cosub:
32 | criterion = torch.nn.BCEWithLogitsLoss()
33 |
34 | # debug
35 | # count = 0
36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
37 | # count += 1
38 | # if count > 20:
39 | # break
40 |
41 | samples = samples.to(device, non_blocking=True)
42 | targets = targets.to(device, non_blocking=True)
43 |
44 | if mixup_fn is not None:
45 | samples, targets = mixup_fn(samples, targets)
46 |
47 | if args.cosub:
48 | samples = torch.cat((samples,samples),dim=0)
49 |
50 | if args.bce_loss:
51 | targets = targets.gt(0.0).type(targets.dtype)
52 |
53 | with amp_autocast():
54 | outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
55 | # outputs = model(samples)
56 | if not args.cosub:
57 | loss = criterion(samples, outputs, targets)
58 | else:
59 | outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
60 | loss = 0.25 * criterion(outputs[0], targets)
61 | loss = loss + 0.25 * criterion(outputs[1], targets)
62 | loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
63 | loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())
64 |
65 | if args.if_nan2num:
66 | with amp_autocast():
67 | loss = torch.nan_to_num(loss)
68 |
69 | loss_value = loss.item()
70 |
71 | if not math.isfinite(loss_value):
72 | print("Loss is {}, stopping training".format(loss_value))
73 | if args.if_continue_inf:
74 | optimizer.zero_grad()
75 | continue
76 | else:
77 | sys.exit(1)
78 |
79 | optimizer.zero_grad()
80 |
81 | # this attribute is added by timm on one optimizer (adahessian)
82 | if isinstance(loss_scaler, timm.utils.NativeScaler):
83 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
84 | loss_scaler(loss, optimizer, clip_grad=max_norm,
85 | parameters=model.parameters(), create_graph=is_second_order)
86 | else:
87 | loss.backward()
88 | if max_norm is not None:
89 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
90 | optimizer.step()
91 |
92 | torch.cuda.synchronize()
93 | if model_ema is not None:
94 | model_ema.update(model)
95 |
96 | metric_logger.update(loss=loss_value)
97 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
98 | # gather the stats from all processes
99 | metric_logger.synchronize_between_processes()
100 | print("Averaged stats:", metric_logger)
101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
102 |
103 |
104 | @torch.no_grad()
105 | def evaluate(data_loader, model, device, amp_autocast):
106 | criterion = torch.nn.CrossEntropyLoss()
107 |
108 | metric_logger = utils.MetricLogger(delimiter=" ")
109 | header = 'Test:'
110 |
111 | # switch to evaluation mode
112 | model.eval()
113 |
114 | for images, target in metric_logger.log_every(data_loader, 10, header):
115 | images = images.to(device, non_blocking=True)
116 | target = target.to(device, non_blocking=True)
117 |
118 | # compute output
119 | with amp_autocast():
120 | output = model(images)
121 | loss = criterion(output, target)
122 |
123 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
124 |
125 | batch_size = images.shape[0]
126 | metric_logger.update(loss=loss.item())
127 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
128 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
129 | # gather the stats from all processes
130 | metric_logger.synchronize_between_processes()
131 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
132 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
133 |
134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
135 |
--------------------------------------------------------------------------------
/vim/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | from models import *
4 | from cait_models import *
5 | from resmlp_models import *
6 | #from patchconvnet_models import *
7 |
8 | dependencies = ["torch", "torchvision", "timm"]
9 |
--------------------------------------------------------------------------------
/vim/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/1.jpg
--------------------------------------------------------------------------------
/vim/images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/2.jpg
--------------------------------------------------------------------------------
/vim/images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/3.jpg
--------------------------------------------------------------------------------
/vim/images/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/4.jpg
--------------------------------------------------------------------------------
/vim/images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AmeenAli/HiddenMambaAttn/77739efbf4f29cd61eaa381a21446744e069dac4/vim/images/5.jpg
--------------------------------------------------------------------------------
/vim/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Implements the knowledge distillation loss
5 | """
6 | import torch
7 | from torch.nn import functional as F
8 |
9 |
10 | class DistillationLoss(torch.nn.Module):
11 | """
12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
13 | taking a teacher model prediction and using it as additional supervision.
14 | """
15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
16 | distillation_type: str, alpha: float, tau: float):
17 | super().__init__()
18 | self.base_criterion = base_criterion
19 | self.teacher_model = teacher_model
20 | assert distillation_type in ['none', 'soft', 'hard']
21 | self.distillation_type = distillation_type
22 | self.alpha = alpha
23 | self.tau = tau
24 |
25 | def forward(self, inputs, outputs, labels):
26 | """
27 | Args:
28 | inputs: The original inputs that are fed to the teacher model
29 | outputs: the outputs of the model to be trained. It is expected to be
30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
31 | in the first position and the distillation predictions as the second output
32 | labels: the labels for the base criterion
33 | """
34 | outputs_kd = None
35 | if not isinstance(outputs, torch.Tensor):
36 | # assume that the model outputs a tuple of [outputs, outputs_kd]
37 | outputs, outputs_kd = outputs
38 | base_loss = self.base_criterion(outputs, labels)
39 | if self.distillation_type == 'none':
40 | return base_loss
41 |
42 | if outputs_kd is None:
43 | raise ValueError("When knowledge distillation is enabled, the model is "
44 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
45 | "class_token and the dist_token")
46 | # don't backprop through the teacher
47 | with torch.no_grad():
48 | teacher_outputs = self.teacher_model(inputs)
49 |
50 | if self.distillation_type == 'soft':
51 | T = self.tau
52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
53 | # with slight modifications
54 | distillation_loss = F.kl_div(
55 | F.log_softmax(outputs_kd / T, dim=1),
56 | #We provide the teacher's targets in log probability because we use log_target=True
57 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
58 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
59 | F.log_softmax(teacher_outputs / T, dim=1),
60 | reduction='sum',
61 | log_target=True
62 | ) * (T * T) / outputs_kd.numel()
63 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
64 | #But we also experimented with dividing by output_kd.size(0);
65 | #see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details
66 | elif self.distillation_type == 'hard':
67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
68 |
69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
70 | return loss
71 |
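72 | # A minimal usage sketch (illustrative values; with 'soft' or 'hard' distillation
73 | # the student model is expected to return a (class_logits, dist_logits) tuple):
74 | #
75 | # criterion = DistillationLoss(torch.nn.CrossEntropyLoss(), teacher_model,
76 | #                              distillation_type='soft', alpha=0.5, tau=3.0)
77 | # loss = criterion(images, student_model(images), labels)
78 | # # soft variant: (1 - alpha) * CE + alpha * tau^2 * KL(student/tau || teacher/tau) / numel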
--------------------------------------------------------------------------------
/vim/rope.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # EVA-02: A Visual Representation for Neon Genesis
3 | # Github source: https://github.com/baaivision/EVA/EVA02
4 | # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Yuxin Fang
7 | #
8 | # Based on https://github.com/lucidrains/rotary-embedding-torch
9 | # --------------------------------------------------------'
10 |
11 | from math import pi
12 |
13 | import torch
14 | from torch import nn
15 |
16 | from einops import rearrange, repeat
17 |
18 |
19 |
20 | def broadcat(tensors, dim = -1):
21 | num_tensors = len(tensors)
22 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
23 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
24 | shape_len = list(shape_lens)[0]
25 | dim = (dim + shape_len) if dim < 0 else dim
26 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
27 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
28 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
29 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
30 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
31 | expanded_dims.insert(dim, (dim, dims[dim]))
32 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
33 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
34 | return torch.cat(tensors, dim = dim)
35 |
36 |
37 |
38 | def rotate_half(x):
39 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
40 | x1, x2 = x.unbind(dim = -1)
41 | x = torch.stack((-x2, x1), dim = -1)
42 | return rearrange(x, '... d r -> ... (d r)')
43 |
44 |
45 |
46 | class VisionRotaryEmbedding(nn.Module):
47 | def __init__(
48 | self,
49 | dim,
50 | pt_seq_len,
51 | ft_seq_len=None,
52 | custom_freqs = None,
53 | freqs_for = 'lang',
54 | theta = 10000,
55 | max_freq = 10,
56 | num_freqs = 1,
57 | ):
58 | super().__init__()
59 | if custom_freqs:
60 | freqs = custom_freqs
61 | elif freqs_for == 'lang':
62 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
63 | elif freqs_for == 'pixel':
64 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
65 | elif freqs_for == 'constant':
66 | freqs = torch.ones(num_freqs).float()
67 | else:
68 | raise ValueError(f'unknown modality {freqs_for}')
69 |
70 | if ft_seq_len is None: ft_seq_len = pt_seq_len
71 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
72 |
73 | freqs_h = torch.einsum('..., f -> ... f', t, freqs)
74 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
75 |
76 | freqs_w = torch.einsum('..., f -> ... f', t, freqs)
77 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
78 |
79 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
80 |
81 | self.register_buffer("freqs_cos", freqs.cos())
82 | self.register_buffer("freqs_sin", freqs.sin())
83 |
84 | print('======== shape of rope freq', self.freqs_cos.shape, '========')
85 |
86 | def forward(self, t, start_index = 0):
87 | rot_dim = self.freqs_cos.shape[-1]
88 | end_index = start_index + rot_dim
89 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is too small to rotate all {rot_dim} positions'
90 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
91 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
92 | return torch.cat((t_left, t, t_right), dim = -1)
93 |
94 |
95 |
96 | class VisionRotaryEmbeddingFast(nn.Module):
97 | def __init__(
98 | self,
99 | dim,
100 | pt_seq_len=16,
101 | ft_seq_len=None,
102 | custom_freqs = None,
103 | freqs_for = 'lang',
104 | theta = 10000,
105 | max_freq = 10,
106 | num_freqs = 1,
107 | ):
108 | super().__init__()
109 |         if custom_freqs is not None:
110 | freqs = custom_freqs
111 | elif freqs_for == 'lang':
112 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
113 | elif freqs_for == 'pixel':
114 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
115 | elif freqs_for == 'constant':
116 | freqs = torch.ones(num_freqs).float()
117 | else:
118 | raise ValueError(f'unknown modality {freqs_for}')
119 |
120 | if ft_seq_len is None: ft_seq_len = pt_seq_len
121 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
122 |
123 | freqs = torch.einsum('..., f -> ... f', t, freqs)
124 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
125 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
126 |
127 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
128 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
129 |
130 | self.register_buffer("freqs_cos", freqs_cos)
131 | self.register_buffer("freqs_sin", freqs_sin)
132 |
133 | print('======== shape of rope freq', self.freqs_cos.shape, '========')
134 |
135 | def forward(self, t):
136 | if t.shape[1] % 2 != 0:
137 | t_spatial = t[:, 1:, :]
138 | t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin
139 | return torch.cat((t[:, :1, :], t_spatial), dim=1)
140 | else:
141 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
--------------------------------------------------------------------------------
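As a quick sanity check for the module above, here is a minimal usage sketch. It is not part of the repository: the embedding size, the sequence length, and the `from rope import ...` path (which assumes you run from `vim/`) are all assumptions.

```python
import torch
from rope import VisionRotaryEmbeddingFast

# dim is half the per-axis feature size: the cos/sin buffers come out as
# (pt_seq_len**2, 2*dim) = (196, 192) for a 14x14 patch grid.
rope = VisionRotaryEmbeddingFast(dim=96, pt_seq_len=14)

# Odd sequence length: forward() leaves token 0 untouched and rotates the rest.
tokens = torch.randn(2, 197, 192)
rotated = rope(tokens)
assert rotated.shape == tokens.shape
```

Note that the odd-length branch of `forward` passes the first token through unrotated, so it is intended for sequences that carry one extra (class) token alongside the patch grid.
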
/vim/run_with_submitit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | A script to run multinode training with submitit.
5 | """
6 | import argparse
7 | import os
8 | import uuid
9 | from pathlib import Path
10 |
11 | import main as classification
12 | import submitit
13 |
14 |
15 | def parse_args():
16 | classification_parser = classification.get_args_parser()
17 | parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser])
18 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
19 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
20 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
21 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
22 |
23 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
24 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
25 | parser.add_argument('--comment', default="", type=str,
26 | help='Comment to pass to scheduler, e.g. priority message')
27 | return parser.parse_args()
28 |
29 |
30 | def get_shared_folder() -> Path:
31 | user = os.getenv("USER")
32 | if Path("/checkpoint/").is_dir():
33 | p = Path(f"/checkpoint/{user}/experiments")
34 | p.mkdir(exist_ok=True)
35 | return p
36 | raise RuntimeError("No shared folder available")
37 |
38 |
39 | def get_init_file():
40 |     # Init file must not exist, but its parent dir must exist.
41 | os.makedirs(str(get_shared_folder()), exist_ok=True)
42 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
43 | if init_file.exists():
44 | os.remove(str(init_file))
45 | return init_file
46 |
47 |
48 | class Trainer(object):
49 | def __init__(self, args):
50 | self.args = args
51 |
52 | def __call__(self):
53 | import main as classification
54 |
55 | self._setup_gpu_args()
56 | classification.main(self.args)
57 |
58 | def checkpoint(self):
59 | import os
60 | import submitit
61 |
62 | self.args.dist_url = get_init_file().as_uri()
63 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
64 | if os.path.exists(checkpoint_file):
65 | self.args.resume = checkpoint_file
66 | print("Requeuing ", self.args)
67 | empty_trainer = type(self)(self.args)
68 | return submitit.helpers.DelayedSubmission(empty_trainer)
69 |
70 | def _setup_gpu_args(self):
71 | import submitit
72 | from pathlib import Path
73 |
74 | job_env = submitit.JobEnvironment()
75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
76 | self.args.gpu = job_env.local_rank
77 | self.args.rank = job_env.global_rank
78 | self.args.world_size = job_env.num_tasks
79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
80 |
81 |
82 | def main():
83 | args = parse_args()
84 | if args.job_dir == "":
85 | args.job_dir = get_shared_folder() / "%j"
86 |
87 | # Note that the folder will depend on the job_id, to easily track experiments
88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
89 |
90 | num_gpus_per_node = args.ngpus
91 | nodes = args.nodes
92 | timeout_min = args.timeout
93 |
94 | partition = args.partition
95 | kwargs = {}
96 | if args.use_volta32:
97 | kwargs['slurm_constraint'] = 'volta32gb'
98 | if args.comment:
99 | kwargs['slurm_comment'] = args.comment
100 |
101 | executor.update_parameters(
102 | mem_gb=40 * num_gpus_per_node,
103 | gpus_per_node=num_gpus_per_node,
104 | tasks_per_node=num_gpus_per_node, # one task per GPU
105 | cpus_per_task=10,
106 | nodes=nodes,
107 | timeout_min=timeout_min, # max is 60 * 72
108 | # Below are cluster dependent parameters
109 | slurm_partition=partition,
110 | slurm_signal_delay_s=120,
111 | **kwargs
112 | )
113 |
114 | executor.update_parameters(name="deit")
115 |
116 | args.dist_url = get_init_file().as_uri()
117 | args.output_dir = args.job_dir
118 |
119 | trainer = Trainer(args)
120 | job = executor.submit(trainer)
121 |
122 | print("Submitted job_id:", job.job_id)
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
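For context, the script above follows the standard submitit pattern: build an `AutoExecutor`, set Slurm parameters, submit a callable. A minimal sketch (the resource numbers and the partition name here are placeholders, not values from this repository):

```python
import submitit

def train():
    # stand-in for classification.main(args)
    print("training step goes here")

# %j in the folder name is replaced by the Slurm job id
executor = submitit.AutoExecutor(folder="logs/%j")
executor.update_parameters(
    gpus_per_node=1,
    tasks_per_node=1,   # one task per GPU, as in the script above
    nodes=1,
    timeout_min=60,
    slurm_partition="dev",  # cluster dependent
)

job = executor.submit(train)
print("Submitted job_id:", job.job_id)
```

The `Trainer.checkpoint` method above is what makes requeuing work: submitit calls it on timeout or preemption and resubmits the returned `DelayedSubmission`, with `args.resume` pointing at the last saved checkpoint.
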
/vim/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.distributed as dist
5 | import math
6 |
7 |
8 | class RASampler(torch.utils.data.Sampler):
9 | """Sampler that restricts data loading to a subset of the dataset for distributed,
10 | with repeated augmentation.
11 | It ensures that different each augmented version of a sample will be visible to a
12 | different process (GPU)
13 | Heavily based on torch.utils.data.DistributedSampler
14 | """
15 |
16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
17 | if num_replicas is None:
18 | if not dist.is_available():
19 | raise RuntimeError("Requires distributed package to be available")
20 | num_replicas = dist.get_world_size()
21 | if rank is None:
22 | if not dist.is_available():
23 | raise RuntimeError("Requires distributed package to be available")
24 | rank = dist.get_rank()
25 | if num_repeats < 1:
26 | raise ValueError("num_repeats should be greater than 0")
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.num_repeats = num_repeats
31 | self.epoch = 0
32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
33 | self.total_size = self.num_samples * self.num_replicas
34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
36 | self.shuffle = shuffle
37 |
38 | def __iter__(self):
39 | if self.shuffle:
40 | # deterministically shuffle based on epoch
41 | g = torch.Generator()
42 | g.manual_seed(self.epoch)
43 | indices = torch.randperm(len(self.dataset), generator=g)
44 | else:
45 | indices = torch.arange(start=0, end=len(self.dataset))
46 |
47 | # add extra samples to make it evenly divisible
48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
49 | padding_size: int = self.total_size - len(indices)
50 | if padding_size > 0:
51 | indices += indices[:padding_size]
52 | assert len(indices) == self.total_size
53 |
54 | # subsample
55 | indices = indices[self.rank:self.total_size:self.num_replicas]
56 | assert len(indices) == self.num_samples
57 |
58 | return iter(indices[:self.num_selected_samples])
59 |
60 | def __len__(self):
61 | return self.num_selected_samples
62 |
63 | def set_epoch(self, epoch):
64 | self.epoch = epoch
65 |
--------------------------------------------------------------------------------
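To see what the sampler yields without a full distributed launch, `num_replicas` and `rank` can be passed explicitly. A minimal sketch (the dataset is a placeholder, and the import assumes you run from `vim/`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from samplers import RASampler

dataset = TensorDataset(torch.arange(1024).float())
sampler = RASampler(dataset, num_replicas=2, rank=0, shuffle=True, num_repeats=3)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)

sampler.set_epoch(0)  # reseeds the shuffle each epoch, as with DistributedSampler
print(len(sampler))   # 512 = floor(1024 // 256 * 256 / 2) selected samples per rank
```

Each sample is repeated `num_repeats` times and the copies are strided across ranks, so different processes see different augmented versions of the same image, which is the "repeated augmentation" trick.
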
/vim/scripts/ft-vim-s.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | conda activate
3 | cd /vim;
4 |
5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --lr 5e-6 --min-lr 1e-5 --warmup-lr 1e-5 --drop-path 0.0 --weight-decay 1e-8 --num_workers 25 --data-path --output_dir ./output/vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --epochs 30 --finetune --no_amp
6 |
--------------------------------------------------------------------------------
/vim/scripts/ft-vim-t.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | conda activate
3 | cd /vim;
4 |
5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --lr 5e-6 --min-lr 1e-5 --warmup-lr 1e-5 --drop-path 0.0 --weight-decay 1e-8 --num_workers 25 --data-path --output_dir ./output/vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --epochs 30 --finetune --no_amp
6 |
--------------------------------------------------------------------------------
/vim/scripts/pt-vim-s.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | conda activate
3 | cd /vim;
4 |
5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 64 --drop-path 0.05 --weight-decay 0.05 --lr 1e-3 --num_workers 25 --data-path --output_dir ./output/vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp
6 |
--------------------------------------------------------------------------------
/vim/scripts/pt-vim-t.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | conda activate
3 | cd /vim;
4 |
5 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-path --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp
6 |
--------------------------------------------------------------------------------
/vim/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Misc functions, including distributed helpers.
5 |
6 | Mostly copy-paste from torchvision references.
7 | """
8 | import io
9 | import os
10 | import time
11 | from collections import defaultdict, deque
12 | import datetime
13 |
14 | import torch
15 | import torch.distributed as dist
16 |
17 |
18 | class SmoothedValue(object):
19 | """Track a series of values and provide access to smoothed values over a
20 | window or the global series average.
21 | """
22 |
23 | def __init__(self, window_size=20, fmt=None):
24 | if fmt is None:
25 | fmt = "{median:.4f} ({global_avg:.4f})"
26 | self.deque = deque(maxlen=window_size)
27 | self.total = 0.0
28 | self.count = 0
29 | self.fmt = fmt
30 |
31 | def update(self, value, n=1):
32 | self.deque.append(value)
33 | self.count += n
34 | self.total += value * n
35 |
36 | def synchronize_between_processes(self):
37 | """
38 | Warning: does not synchronize the deque!
39 | """
40 | if not is_dist_avail_and_initialized():
41 | return
42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
43 | dist.barrier()
44 | dist.all_reduce(t)
45 | t = t.tolist()
46 | self.count = int(t[0])
47 | self.total = t[1]
48 |
49 | @property
50 | def median(self):
51 | d = torch.tensor(list(self.deque))
52 | return d.median().item()
53 |
54 | @property
55 | def avg(self):
56 | d = torch.tensor(list(self.deque), dtype=torch.float32)
57 | return d.mean().item()
58 |
59 | @property
60 | def global_avg(self):
61 | return self.total / self.count
62 |
63 | @property
64 | def max(self):
65 | return max(self.deque)
66 |
67 | @property
68 | def value(self):
69 | return self.deque[-1]
70 |
71 | def __str__(self):
72 | return self.fmt.format(
73 | median=self.median,
74 | avg=self.avg,
75 | global_avg=self.global_avg,
76 | max=self.max,
77 | value=self.value)
78 |
79 |
80 | class MetricLogger(object):
81 | def __init__(self, delimiter="\t"):
82 | self.meters = defaultdict(SmoothedValue)
83 | self.delimiter = delimiter
84 |
85 | def update(self, **kwargs):
86 | for k, v in kwargs.items():
87 | if isinstance(v, torch.Tensor):
88 | v = v.item()
89 | assert isinstance(v, (float, int))
90 | self.meters[k].update(v)
91 |
92 | def __getattr__(self, attr):
93 | if attr in self.meters:
94 | return self.meters[attr]
95 | if attr in self.__dict__:
96 | return self.__dict__[attr]
97 | raise AttributeError("'{}' object has no attribute '{}'".format(
98 | type(self).__name__, attr))
99 |
100 | def __str__(self):
101 | loss_str = []
102 | for name, meter in self.meters.items():
103 | loss_str.append(
104 | "{}: {}".format(name, str(meter))
105 | )
106 | return self.delimiter.join(loss_str)
107 |
108 | def synchronize_between_processes(self):
109 | for meter in self.meters.values():
110 | meter.synchronize_between_processes()
111 |
112 | def add_meter(self, name, meter):
113 | self.meters[name] = meter
114 |
115 | def log_every(self, iterable, print_freq, header=None):
116 | i = 0
117 | if not header:
118 | header = ''
119 | start_time = time.time()
120 | end = time.time()
121 | iter_time = SmoothedValue(fmt='{avg:.4f}')
122 | data_time = SmoothedValue(fmt='{avg:.4f}')
123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
124 | log_msg = [
125 | header,
126 | '[{0' + space_fmt + '}/{1}]',
127 | 'eta: {eta}',
128 | '{meters}',
129 | 'time: {time}',
130 | 'data: {data}'
131 | ]
132 | if torch.cuda.is_available():
133 | log_msg.append('max mem: {memory:.0f}')
134 | log_msg = self.delimiter.join(log_msg)
135 | MB = 1024.0 * 1024.0
136 | for obj in iterable:
137 | data_time.update(time.time() - end)
138 | yield obj
139 | iter_time.update(time.time() - end)
140 | if i % print_freq == 0 or i == len(iterable) - 1:
141 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
143 | if torch.cuda.is_available():
144 | print(log_msg.format(
145 | i, len(iterable), eta=eta_string,
146 | meters=str(self),
147 | time=str(iter_time), data=str(data_time),
148 | memory=torch.cuda.max_memory_allocated() / MB))
149 | else:
150 | print(log_msg.format(
151 | i, len(iterable), eta=eta_string,
152 | meters=str(self),
153 | time=str(iter_time), data=str(data_time)))
154 | i += 1
155 | end = time.time()
156 | total_time = time.time() - start_time
157 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
158 | print('{} Total time: {} ({:.4f} s / it)'.format(
159 | header, total_time_str, total_time / len(iterable)))
160 |
161 |
162 | def _load_checkpoint_for_ema(model_ema, checkpoint):
163 | """
164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
165 | """
166 | mem_file = io.BytesIO()
167 | torch.save({'state_dict_ema':checkpoint}, mem_file)
168 | mem_file.seek(0)
169 | model_ema._load_checkpoint(mem_file)
170 |
171 |
172 | def setup_for_distributed(is_master):
173 | """
174 |     This function disables printing when not in the master process.
175 | """
176 | import builtins as __builtin__
177 | builtin_print = __builtin__.print
178 |
179 | def print(*args, **kwargs):
180 | force = kwargs.pop('force', False)
181 | if is_master or force:
182 | builtin_print(*args, **kwargs)
183 |
184 | __builtin__.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 | if not dist.is_available():
189 | return False
190 | if not dist.is_initialized():
191 | return False
192 | return True
193 |
194 |
195 | def get_world_size():
196 | if not is_dist_avail_and_initialized():
197 | return 1
198 | return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 | if not is_dist_avail_and_initialized():
203 | return 0
204 | return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 | return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 | if is_main_process():
213 | torch.save(*args, **kwargs)
214 |
215 |
216 | def init_distributed_mode(args):
217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
218 | args.rank = int(os.environ["RANK"])
219 | args.world_size = int(os.environ['WORLD_SIZE'])
220 | args.gpu = int(os.environ['LOCAL_RANK'])
221 | elif 'SLURM_PROCID' in os.environ:
222 | args.rank = int(os.environ['SLURM_PROCID'])
223 | args.gpu = args.rank % torch.cuda.device_count()
224 | else:
225 | print('Not using distributed mode')
226 | args.distributed = False
227 | return
228 |
229 | args.distributed = True
230 |
231 | torch.cuda.set_device(args.gpu)
232 | args.dist_backend = 'nccl'
233 | print('| distributed init (rank {}): {}'.format(
234 | args.rank, args.dist_url), flush=True)
235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
236 | world_size=args.world_size, rank=args.rank)
237 | torch.distributed.barrier()
238 | setup_for_distributed(args.rank == 0)
239 |
240 |
241 | # call only if 'pos_embed' is in state_dict, to interpolate position embeddings when fine-tuning at a different resolution
242 | def interpolate_pos_embed(model, state_dict):
243 | pos_embed_checkpoint = state_dict['pos_embed']
244 | embedding_size = pos_embed_checkpoint.shape[-1]
245 | num_patches = model.patch_embed.num_patches
246 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
247 | # height (== width) for the checkpoint position embedding
248 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
249 | # height (== width) for the new position embedding
250 | new_size = int(num_patches ** 0.5)
251 | # class_token and dist_token are kept unchanged
252 | if orig_size != new_size:
253 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
254 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
255 | # only the position tokens are interpolated
256 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
257 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
258 | pos_tokens = torch.nn.functional.interpolate(
259 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
260 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
261 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
262 | state_dict['pos_embed'] = new_pos_embed
--------------------------------------------------------------------------------
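A minimal sketch of the logging helpers above in use, outside distributed mode (the loss is a stand-in, and the import assumes you run from `vim/`):

```python
import time
import torch
from utils import MetricLogger

logger = MetricLogger(delimiter="  ")
for step in logger.log_every(range(100), print_freq=20, header="Epoch: [0]"):
    loss = torch.rand(1)  # stand-in for a real training loss
    logger.update(loss=loss.item(), lr=1e-3)
    time.sleep(0.01)

print("Averaged stats:", logger)
```

`log_every` yields the items of the wrapped iterable and prints ETA, meter values, and (on GPU) peak memory every `print_freq` iterations; `update` feeds each named value into a windowed `SmoothedValue` meter.
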
/vim/vim_requirements.txt:
--------------------------------------------------------------------------------
1 | addict==2.4.0
2 | aiohttp==3.9.1
3 | aiosignal==1.3.1
4 | alembic==1.13.0
5 | async-timeout==4.0.3
6 | attrs==23.1.0
7 | blinker==1.7.0
8 | # causal-conv1d @ file:///home/zhulianghui/VisionProjects/mamba/lib/causal_conv1d-1.0.0%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=79a4bab633ebff031e615d5e8ba396b0dc0c046f4406980ee238fb86a9090038
9 | certifi==2023.11.17
10 | charset-normalizer==3.3.2
11 | click==8.1.7
12 | cloudpickle==3.0.0
13 | contourpy==1.2.0
14 | cycler==0.12.1
15 | databricks-cli==0.18.0
16 | datasets==2.15.0
17 | dill==0.3.7
18 | docker==6.1.3
19 | einops==0.7.0
20 | entrypoints==0.4
21 | filelock==3.13.1
22 | Flask==3.0.0
23 | fonttools==4.46.0
24 | frozenlist==1.4.0
25 | fsspec==2023.10.0
26 | gitdb==4.0.11
27 | GitPython==3.1.40
28 | greenlet==3.0.2
29 | gunicorn==21.2.0
30 | huggingface-hub==0.19.4
31 | idna==3.6
32 | importlib-metadata==7.0.0
33 | itsdangerous==2.1.2
34 | Jinja2==3.1.2
35 | joblib==1.3.2
36 | kiwisolver==1.4.5
37 | Mako==1.3.0
38 | # mamba-ssm @ file:///home/zhulianghui/VisionProjects/mamba/lib/mamba_ssm-1.0.1%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=71ad1b1eafb05a6e8a41fd82e046fe85511d6378fa3a583e55215b6aa1d65ab9
39 | Markdown==3.5.1
40 | MarkupSafe==2.1.3
41 | matplotlib==3.8.2
42 | mlflow==2.9.1
43 | mmcv==1.3.8
44 | mmsegmentation==0.14.1
45 | mpmath==1.3.0
46 | multidict==6.0.4
47 | multiprocess==0.70.15
48 | networkx==3.2.1
49 | ninja==1.11.1.1
50 | numpy==1.26.2
51 | # nvidia-cublas-cu12==12.1.3.1
52 | # nvidia-cuda-cupti-cu12==12.1.105
53 | # nvidia-cuda-nvrtc-cu12==12.1.105
54 | # nvidia-cuda-runtime-cu12==12.1.105
55 | # nvidia-cudnn-cu12==8.9.2.26
56 | # nvidia-cufft-cu12==11.0.2.54
57 | # nvidia-curand-cu12==10.3.2.106
58 | # nvidia-cusolver-cu12==11.4.5.107
59 | # nvidia-cusparse-cu12==12.1.0.106
60 | # nvidia-nccl-cu12==2.18.1
61 | # nvidia-nvjitlink-cu12==12.3.101
62 | # nvidia-nvtx-cu12==12.1.105
63 | oauthlib==3.2.2
64 | opencv-python==4.8.1.78
65 | packaging==23.2
66 | pandas==2.1.3
67 | Pillow==10.1.0
68 | platformdirs==4.1.0
69 | prettytable==3.9.0
70 | protobuf==4.25.1
71 | pyarrow==14.0.1
72 | pyarrow-hotfix==0.6
73 | PyJWT==2.8.0
74 | pyparsing==3.1.1
75 | python-dateutil==2.8.2
76 | python-hostlist==1.23.0
77 | pytz==2023.3.post1
78 | PyYAML==6.0.1
79 | querystring-parser==1.2.4
80 | regex==2023.10.3
81 | requests==2.31.0
82 | safetensors==0.4.1
83 | scikit-learn==1.3.2
84 | scipy==1.11.4
85 | six==1.16.0
86 | smmap==5.0.1
87 | SQLAlchemy==2.0.23
88 | sqlparse==0.4.4
89 | sympy==1.12
90 | tabulate==0.9.0
91 | threadpoolctl==3.2.0
92 | timm==0.4.12
93 | tokenizers==0.15.0
94 | tomli==2.0.1
95 | # torch==2.1.1+cu118
96 | # torchvision==0.16.1+cu118
97 | tqdm==4.66.1
98 | transformers==4.35.2
99 | triton==2.1.0
100 | typing_extensions==4.8.0
101 | tzdata==2023.3
102 | urllib3==2.1.0
103 | wcwidth==0.2.12
104 | websocket-client==1.7.0
105 | Werkzeug==3.0.1
106 | xxhash==3.4.1
107 | yapf==0.40.2
108 | yarl==1.9.4
109 | zipp==3.17.0
110 |
--------------------------------------------------------------------------------
/vim/vmamba_xai.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "/media/data1/ameenali/miniconda3/envs/github/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 | " from .autonotebook import tqdm as notebook_tqdm\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "import numpy as np\n",
19 | "import torch\n",
20 | "import torch.backends.cudnn as cudnn\n",
21 | "from timm.models import create_model\n",
22 | "import models_mamba\n",
23 | "import utils\n",
24 | "import os\n",
25 | "from xai_utils import *\n",
26 | "from class_mapper import CLS2IDX\n",
27 | "import matplotlib.pyplot as plt"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "Load Model
\n",
35 | "Make sure to speiciy the model checkpoint path"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [
43 | {
44 | "data": {
45 | "text/plain": [
46 | ""
47 | ]
48 | },
49 | "execution_count": 2,
50 | "metadata": {},
51 | "output_type": "execute_result"
52 | }
53 | ],
54 | "source": [
55 | "model_type = 'vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2'\n",
56 | "model_path = './vim_s_midclstok_80p5acc.pth'\n",
57 | "num_classes = 1000\n",
58 | "model = create_model(\n",
59 | " model_type,\n",
60 | " pretrained=False,\n",
61 | " num_classes=num_classes,\n",
62 | " drop_rate=0,\n",
63 | " drop_path_rate=0,\n",
64 | " drop_block_rate=None,\n",
65 | " img_size=224\n",
66 | ")\n",
67 | "checkpoint = torch.load(model_path, map_location='cpu')\n",
68 | "model.load_state_dict(checkpoint['model'])"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "Auxiliary Functions
"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 3,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "from PIL import Image\n",
85 | "import torchvision.transforms as transforms\n",
86 | "\n",
87 | "IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]\n",
88 | "IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]\n",
89 | "\n",
90 | "def transform_for_eval(image_path, input_size=224):\n",
91 | " transform_eval = transforms.Compose([\n",
92 | " transforms.Resize(int(input_size)),\n",
93 | " transforms.CenterCrop(input_size),\n",
94 | " transforms.ToTensor(),\n",
95 | " transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),\n",
96 | " ])\n",
97 | " img = Image.open(image_path).convert('RGB')\n",
98 | " transformed_img = transform_eval(img)\n",
99 | " return transformed_img\n",
100 | "\n",
101 | "import cv2\n",
102 | "\n",
103 | "invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],\n",
104 | " std = [ 1/0.229, 1/0.224, 1/0.225 ]),\n",
105 | " transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],\n",
106 | " std = [ 1., 1., 1. ]),\n",
107 | " ])\n",
108 | "\n",
109 | "def show_cam_on_image(img, mask):\n",
110 | " heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)\n",
111 | " heatmap = np.float32(heatmap) / 255\n",
112 | " cam = heatmap + np.float32(img)\n",
113 | " cam = cam / np.max(cam)\n",
114 | " return cam\n",
115 | "\n",
116 | "\n",
117 | "def generate_visualization(original_image, transformer_attribution):\n",
118 | " transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)\n",
119 | " transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')\n",
120 | " transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()\n",
121 | " transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())\n",
122 | " image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()\n",
123 | " image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())\n",
124 | " vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)\n",
125 | " vis = np.uint8(255 * vis)\n",
126 | " vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)\n",
127 | " return vis\n",
128 | "\n",
129 | "def print_preds(logits):\n",
130 | " prob = torch.softmax(logits, dim=1)\n",
131 | " class_indices = logits.data.topk(5, dim=1)[1][0].tolist()\n",
132 | " max_str_len = 0\n",
133 | " class_names = []\n",
134 | " for cls_idx in class_indices:\n",
135 | " class_names.append(CLS2IDX[cls_idx])\n",
136 | " if len(CLS2IDX[cls_idx]) > max_str_len:\n",
137 | " max_str_len = len(CLS2IDX[cls_idx])\n",
138 | "\n",
139 | " print('Top 5 classes:')\n",
140 | " for cls_idx in class_indices:\n",
141 | " output_string = '\\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])\n",
142 | " output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\\t\\t'\n",
143 | " output_string += 'value = {:.3f}\\t prob = {:.1f}%'.format(logits[0, cls_idx], 100 * prob[0, cls_idx])\n",
144 | " print(output_string)"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 4,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "model = model.cuda()"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "metadata": {},
160 | "outputs": [],
161 | "source": [
162 | "image = transform_for_eval('./images/1.jpg').unsqueeze(0).cuda()\n",
163 | "raw_image = Image.open('./images/1.jpg')\n",
164 | "map_raw_atten, logits = generate_raw_attn(model, image)\n",
165 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n",
166 | "map_rollout, _ = generate_rollout(model, image)\n",
167 | "image = image.squeeze()\n",
168 | "\n",
169 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n",
170 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n",
171 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n",
172 | "print_preds(logits)\n",
173 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n",
174 | "axs[0].imshow(raw_image)\n",
175 | "axs[0].axis('off')\n",
176 | "axs[1].imshow(raw_attn)\n",
177 | "axs[1].axis('off')\n",
178 | "axs[2].imshow(rollout)\n",
179 | "axs[2].axis('off')\n",
180 | "axs[3].imshow(mamba_attr)\n",
181 | "axs[3].axis('off')\n"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "# Attention Matrices:\n",
191 | "# Load an image and prepare it for model evaluation\n",
192 | "image = transform_for_eval('./images/1.jpg').unsqueeze(0).cuda()\n",
193 | "\n",
194 | "# Specify the layer and channel to analyze\n",
195 | "selected_layer = 6\n",
196 | "selected_channel = 30\n",
197 | "\n",
198 | "# Enable computation of attention matrices in the model\n",
199 | "model.layers[selected_layer].mixer.compute_attn_matrix = True\n",
200 | "# Pass the image through the model\n",
201 | "out = model(image)\n",
202 | "\n",
203 | "# Extract and normalize attention matrices\n",
204 | "attn_matrix_a = model.layers[selected_layer].mixer.attn_matrix_a.abs()\n",
205 | "attn_matrix_b = model.layers[selected_layer].mixer.attn_matrix_b.abs()\n",
206 | "normalize_attn_mat = lambda attn_mat : (attn_mat.abs() - torch.min(attn_mat.abs())) / (torch.max(attn_mat.abs()) - torch.min(attn_mat.abs()))\n",
207 | "attn_matrix_a_normalize = normalize_attn_mat(attn_matrix_a)\n",
208 | "attn_matrix_b_normalize = normalize_attn_mat(attn_matrix_b)\n",
209 | "\n",
210 | "# Plot each attention matrix\n",
211 | "fig, axs = plt.subplots(1, 6, figsize=(10,10))\n",
212 | "for i in range(3):\n",
213 | " axs[i].imshow(attn_matrix_a.cpu().detach().numpy()[0, selected_channel+i, :, :])\n",
214 | " axs[i].axis('off')\n",
215 | " axs[i+3].imshow(attn_matrix_b.cpu().detach().numpy()[0, selected_channel+i, :, :])\n",
216 | " axs[i+3].axis('off')\n",
217 | " model.layers[selected_layer].mixer.compute_attn_matrix = False"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "image = transform_for_eval('./images/2.jpg').unsqueeze(0).cuda()\n",
227 | "raw_image = Image.open('./images/2.jpg')\n",
228 | "map_raw_atten, logits = generate_raw_attn(model, image)\n",
229 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n",
230 | "map_rollout, _ = generate_rollout(model, image)\n",
231 | "image = image.squeeze()\n",
232 | "\n",
233 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n",
234 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n",
235 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n",
236 | "print_preds(logits)\n",
237 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n",
238 | "axs[0].imshow(raw_image)\n",
239 | "axs[0].axis('off')\n",
240 | "axs[1].imshow(raw_attn)\n",
241 | "axs[1].axis('off')\n",
242 | "axs[2].imshow(rollout)\n",
243 | "axs[2].axis('off')\n",
244 | "axs[3].imshow(mamba_attr)\n",
245 | "axs[3].axis('off')\n"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "image = transform_for_eval('./images/3.jpg').unsqueeze(0).cuda()\n",
255 | "raw_image = Image.open('./images/3.jpg')\n",
256 | "map_raw_atten, logits = generate_raw_attn(model, image)\n",
257 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n",
258 | "map_rollout, _ = generate_rollout(model, image)\n",
259 | "image = image.squeeze()\n",
260 | "\n",
261 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n",
262 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n",
263 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n",
264 | "print_preds(logits)\n",
265 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n",
266 | "axs[0].imshow(raw_image)\n",
267 | "axs[0].axis('off')\n",
268 | "axs[1].imshow(raw_attn)\n",
269 | "axs[1].axis('off')\n",
270 | "axs[2].imshow(rollout)\n",
271 | "axs[2].axis('off')\n",
272 | "axs[3].imshow(mamba_attr)\n",
273 | "axs[3].axis('off')\n"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "image = transform_for_eval('./images/4.jpg').unsqueeze(0).cuda()\n",
283 | "raw_image = Image.open('./images/4.jpg')\n",
284 | "map_raw_atten, logits = generate_raw_attn(model, image)\n",
285 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n",
286 | "map_rollout, _ = generate_rollout(model, image)\n",
287 | "image = image.squeeze()\n",
288 | "\n",
289 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n",
290 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n",
291 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n",
292 | "print_preds(logits)\n",
293 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n",
294 | "axs[0].imshow(raw_image)\n",
295 | "axs[0].axis('off')\n",
296 | "axs[1].imshow(raw_attn)\n",
297 | "axs[1].axis('off')\n",
298 | "axs[2].imshow(rollout)\n",
299 | "axs[2].axis('off')\n",
300 | "axs[3].imshow(mamba_attr)\n",
301 | "axs[3].axis('off')\n"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "metadata": {},
308 | "outputs": [],
309 | "source": [
310 | "image = transform_for_eval('./images/5.jpg').unsqueeze(0).cuda()\n",
311 | "raw_image = Image.open('./images/5.jpg')\n",
312 | "map_raw_atten, logits = generate_raw_attn(model, image)\n",
313 | "map_mamba_attr, _ = generate_mamba_attr(model, image)\n",
314 | "map_rollout, _ = generate_rollout(model, image)\n",
315 | "image = image.squeeze()\n",
316 | "\n",
317 | "raw_attn = generate_visualization(invTrans(image).detach().cpu(), map_raw_atten)\n",
318 | "mamba_attr = generate_visualization(invTrans(image).detach().cpu(), map_mamba_attr)\n",
319 | "rollout = generate_visualization(invTrans(image).detach().cpu(), map_rollout)\n",
320 | "print_preds(logits)\n",
321 | "fig, axs = plt.subplots(1, 4, figsize=(10,10))\n",
322 | "axs[0].imshow(raw_image)\n",
323 | "axs[0].axis('off')\n",
324 | "axs[1].imshow(raw_attn)\n",
325 | "axs[1].axis('off')\n",
326 | "axs[2].imshow(rollout)\n",
327 | "axs[2].axis('off')\n",
328 | "axs[3].imshow(mamba_attr)\n",
329 | "axs[3].axis('off')\n"
330 | ]
331 | }
332 | ],
333 | "metadata": {
334 | "kernelspec": {
335 | "display_name": "mamba",
336 | "language": "python",
337 | "name": "python3"
338 | },
339 | "language_info": {
340 | "codemirror_mode": {
341 | "name": "ipython",
342 | "version": 3
343 | },
344 | "file_extension": ".py",
345 | "mimetype": "text/x-python",
346 | "name": "python",
347 | "nbconvert_exporter": "python",
348 | "pygments_lexer": "ipython3",
349 | "version": "3.10.13"
350 | }
351 | },
352 | "nbformat": 4,
353 | "nbformat_minor": 2
354 | }
355 |
--------------------------------------------------------------------------------
/vim/xai_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def compute_rollout_attention(all_layer_matrices, start_layer=0):
6 |     # account for residual connections - code adapted from https://github.com/samiraabnar/attention_flow
7 | num_tokens = all_layer_matrices[0].shape[1]
8 | batch_size = all_layer_matrices[0].shape[0]
9 | eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
10 |
11 | all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
12 | matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
13 | for i in range(len(all_layer_matrices))]
14 | joint_attention = matrices_aug[start_layer]
15 | for i in range(start_layer+1, len(matrices_aug)):
16 | joint_attention = matrices_aug[i].bmm(joint_attention)
17 | return joint_attention
18 |
19 | def generate_raw_attn(model, image, start_layer=15):
20 | image.requires_grad_()
21 | logits = model(image)
22 | all_layer_attentions = []
23 |     cls_pos = 98  # index of the mid-sequence class token (196 patch tokens, [CLS] inserted in the middle)
24 | for layeridx in range(len(model.layers)):
25 | attn_heads = model.layers[layeridx].mixer.xai_b
26 | attn_heads = (attn_heads - attn_heads.min()) / (attn_heads.max() - attn_heads.min())
27 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
28 | all_layer_attentions.append(avg_heads)
29 | p = torch.cat(all_layer_attentions[start_layer:], dim=0).mean(dim=0).unsqueeze(0)
30 | p = torch.cat([p[:,:cls_pos], p[:,(cls_pos+1):]], dim=-1)
31 | return p.clamp(min=0).squeeze().unsqueeze(0), logits
32 |
33 |
34 | def generate_mamba_attr(model, image, start_layer=15):
35 | image.requires_grad_()
36 | logits = model(image)
37 | index = np.argmax(logits.cpu().data.numpy(), axis=-1)
38 | one_hot = np.zeros((1, logits.size()[-1]), dtype=np.float32)
39 | one_hot[0, index] = 1
40 | one_hot = torch.from_numpy(one_hot).requires_grad_(True)
41 | one_hot = torch.sum(one_hot.cuda() * logits)
42 | model.zero_grad()
43 | one_hot.backward(retain_graph=True)
44 | all_layer_attentions = []
45 | cls_pos = 98
46 | for layeridx in range(len(model.layers)):
47 | attn_heads = model.layers[layeridx].mixer.xai_b.clamp(min=0)
48 |         s = model.layers[layeridx].get_gradients().squeeze().detach()
49 | s = s.clamp(min=0).max(dim=1)[0].unsqueeze(0)
50 | s = (s - s.min()) / (s.max() - s.min())
51 | attn_heads = (attn_heads - attn_heads.min()) / (attn_heads.max() - attn_heads.min())
52 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
53 | fused = avg_heads * s
54 | all_layer_attentions.append(fused)
55 | rollout = compute_rollout_attention(all_layer_attentions, start_layer)
56 | p = rollout[0 , cls_pos , :].unsqueeze(0)
57 | p = torch.cat([p[:,:cls_pos], p[:,(cls_pos+1):]], dim=-1)
58 | return p.clamp(min=0).squeeze().unsqueeze(0), logits
59 |
60 |
61 | def generate_rollout(model, image, start_layer=15, num_layers=24):
62 | image.requires_grad_()
63 | logits = model(image)
64 | all_layer_attentions = []
65 | cls_pos = 98
66 | for layer in range(num_layers):
67 | attn_heads = model.layers[layer].mixer.xai_b
68 | avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
69 | all_layer_attentions.append(avg_heads)
70 | rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
71 | p = rollout[0 , cls_pos , :].unsqueeze(0)
72 | p = torch.cat([p[:,:cls_pos], p[:,(cls_pos+1):]], dim=-1)
73 | return p.clamp(min=0).squeeze().unsqueeze(0), logits
74 |
--------------------------------------------------------------------------------
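The three explanation methods above share the same backbone: per-layer attention maps are averaged over channels, optionally fused with gradients, and then multiplied together by `compute_rollout_attention` (the attention rollout of Abnar & Zuidema, with the identity added for the residual path and rows renormalized). A minimal sketch on random inputs (the shapes are placeholders; in the notebook these matrices come from `mixer.xai_b`, and the import assumes you run from `vim/`):

```python
import torch
from xai_utils import compute_rollout_attention

num_layers, batch, tokens = 4, 1, 197
layer_mats = [torch.rand(batch, tokens, tokens) for _ in range(num_layers)]

rollout = compute_rollout_attention(layer_mats, start_layer=0)
print(rollout.shape)               # torch.Size([1, 197, 197])
print(rollout.sum(dim=-1)[0, :3])  # rows stay normalized: each sums to ~1
```

Because each per-layer matrix is made row-stochastic before the chained `bmm`, the rollout itself stays row-stochastic, and the row at `cls_pos` gives the relevance of every patch token to the class token.
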