├── .gitignore ├── =1.1.0 ├── README.md ├── assets ├── FambaOverview.png ├── Strategies.png └── logo.jpg ├── causal-conv1d ├── =1.1.0 ├── AUTHORS ├── LICENSE ├── README.md ├── [13 ├── causal_conv1d │ ├── =1.1.0 │ ├── __init__.py │ └── causal_conv1d_interface.py ├── csrc │ ├── causal_conv1d.cpp │ ├── causal_conv1d.h │ ├── causal_conv1d_bwd.cu │ ├── causal_conv1d_common.h │ ├── causal_conv1d_fwd.cu │ ├── causal_conv1d_update.cu │ └── static_switch.h ├── setup.py └── tests │ └── test_causal_conv1d.py ├── fambav ├── .gitignore ├── LICENSE ├── augment.py ├── datasets.py ├── engine.py ├── hubconf.py ├── losses.py ├── main.py ├── mlruns │ └── 0 │ │ └── meta.yaml ├── models_mamba.py ├── rope.py ├── run_with_submitit.py ├── samplers.py ├── scripts │ ├── mlruns │ │ └── 0 │ │ │ └── meta.yaml │ ├── vim-s-cifar-all-merge-r7.sh │ ├── vim-s-cifar-alternate-merge-r14.sh │ ├── vim-s-cifar-lower-merge-r9-l4.sh │ ├── vim-s-cifar-lower-merge-r9-l5.sh │ ├── vim-s-cifar-non-merge.sh │ ├── vim-s-cifar-upper-merge-r9-l18.sh │ ├── vim-t-cifar-all-merge-r7-cls.sh │ ├── vim-t-cifar-all-merge-r7.sh │ ├── vim-t-cifar-alternate-merge-r14.sh │ ├── vim-t-cifar-lower-merge-r1-l4.sh │ ├── vim-t-cifar-lower-merge-r10-l6.sh │ ├── vim-t-cifar-lower-merge-r10-l7.sh │ ├── vim-t-cifar-lower-merge-r11-l8.sh │ ├── vim-t-cifar-lower-merge-r12-l10.sh │ ├── vim-t-cifar-lower-merge-r12-l9.sh │ ├── vim-t-cifar-lower-merge-r13-l11.sh │ ├── vim-t-cifar-lower-merge-r14-l12.sh │ ├── vim-t-cifar-lower-merge-r15-l13.sh │ ├── vim-t-cifar-lower-merge-r2-l4.sh │ ├── vim-t-cifar-lower-merge-r3-l4.sh │ ├── vim-t-cifar-lower-merge-r4-l4.sh │ ├── vim-t-cifar-lower-merge-r5-l4.sh │ ├── vim-t-cifar-lower-merge-r6-l4.sh │ ├── vim-t-cifar-lower-merge-r7-l4.sh │ ├── vim-t-cifar-lower-merge-r8-l3.sh │ ├── vim-t-cifar-lower-merge-r8-l4.sh │ ├── vim-t-cifar-lower-merge-r9-l4.sh │ ├── vim-t-cifar-lower-merge-r9-l5.sh │ ├── vim-t-cifar-non-merge-visual.sh │ ├── vim-t-cifar-non-merge.sh │ ├── vim-t-cifar-upper-merge-r9-l18.sh │ ├── vim-t-imagenet-all-merge-r7.sh │ ├── vim-t-imagenet-alternate-merge-r14.sh │ ├── vim-t-imagenet-lower-merge-r9-l5.sh │ ├── vim-t-imagenet-non-merge.sh │ └── vim-t-imagenet-upper-merge-r9-l18.sh ├── test.sh ├── token_merge.py ├── utils.py ├── vim_requirements.txt └── visualization │ ├── cosine_similarity_layer_0.png │ ├── cosine_similarity_layer_1.png │ ├── cosine_similarity_layer_10.png │ ├── cosine_similarity_layer_11.png │ ├── cosine_similarity_layer_12.png │ ├── cosine_similarity_layer_13.png │ ├── cosine_similarity_layer_14.png │ ├── cosine_similarity_layer_15.png │ ├── cosine_similarity_layer_16.png │ ├── cosine_similarity_layer_17.png │ ├── cosine_similarity_layer_18.png │ ├── cosine_similarity_layer_19.png │ ├── cosine_similarity_layer_2.png │ ├── cosine_similarity_layer_20.png │ ├── cosine_similarity_layer_21.png │ ├── cosine_similarity_layer_22.png │ ├── cosine_similarity_layer_23.png │ ├── cosine_similarity_layer_3.png │ ├── cosine_similarity_layer_4.png │ ├── cosine_similarity_layer_5.png │ ├── cosine_similarity_layer_6.png │ ├── cosine_similarity_layer_7.png │ ├── cosine_similarity_layer_8.png │ ├── cosine_similarity_layer_9.png │ ├── hidden_states_20240714_113249.npz │ ├── hidden_states_20240714_113250.npz │ ├── hidden_states_20240714_113251.npz │ ├── hidden_states_20240714_113252.npz │ ├── hidden_states_20240714_113253.npz │ ├── hidden_states_20240714_113254.npz │ ├── hidden_states_20240714_113255.npz │ ├── npz_viewer.py │ └── test_visualize.py └── mamba-1p1p1 ├── .github └── workflows │ └── publish.yaml ├── 
.gitignore ├── .gitmodules ├── AUTHORS ├── LICENSE ├── README.md ├── assets └── selection.png ├── benchmarks └── benchmark_generation_mamba_simple.py ├── csrc └── selective_scan │ ├── reverse_scan.cuh │ ├── selective_scan.cpp │ ├── selective_scan.h │ ├── selective_scan_bwd_bf16_complex.cu │ ├── selective_scan_bwd_bf16_real.cu │ ├── selective_scan_bwd_fp16_complex.cu │ ├── selective_scan_bwd_fp16_real.cu │ ├── selective_scan_bwd_fp32_complex.cu │ ├── selective_scan_bwd_fp32_real.cu │ ├── selective_scan_bwd_kernel.cuh │ ├── selective_scan_common.h │ ├── selective_scan_fwd_bf16.cu │ ├── selective_scan_fwd_fp16.cu │ ├── selective_scan_fwd_fp32.cu │ ├── selective_scan_fwd_kernel.cuh │ ├── static_switch.h │ └── uninitialized_copy.cuh ├── evals └── lm_harness_eval.py ├── mamba_ssm ├── __init__.py ├── models │ ├── __init__.py │ ├── config_mamba.py │ └── mixer_seq_simple.py ├── modules │ ├── __init__.py │ └── mamba_simple.py ├── ops │ ├── __init__.py │ ├── selective_scan_interface.py │ └── triton │ │ ├── __init__.py │ │ ├── layernorm.py │ │ └── selective_state_update.py └── utils │ ├── __init__.py │ ├── generation.py │ └── hf.py ├── setup.py └── tests └── ops ├── test_selective_scan.py └── triton └── test_selective_state_update.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | causal-conv1d/causal_conv1d/__pycache__/* 3 | mamba/mamba_ssm/utils/__pycache__/* 4 | mamba/mamba_ssm/ops/triton/__pycache__/* 5 | mamba/mamba_ssm/ops/__pycache__/* 6 | mamba/mamba_ssm/modules/__pycache__/* 7 | mamba/mamba_ssm/__pycache__/* 8 | mamba/mamba_ssm/models/__pycache__/* 9 | causal-conv1d/build/ 10 | causal-conv1d/causal_conv1d.egg-info/ 11 | causal-conv1d/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so 12 | mamba/build/ 13 | mamba/mamba_ssm.egg-info/ 14 | mamba/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /=1.1.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/=1.1.0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | <div align="center">
3 | <h1>Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion</h1>
4 |
5 | License: Apache 2.0
6 | </div>
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Introduction
15 |
16 | > **[Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion](https://arxiv.org/abs/2409.09808)** [[arXiv]](https://arxiv.org/abs/2409.09808)
17 | > *Hui Shen, Zhongwei Wan, Xin Wang, Mi Zhang*
18 | > *The Ohio State University*
19 | > *ECCV 2024 Workshop on Computational Aspects of Deep Learning*
20 |
21 | ### ⚡ News: Famba-V won the Best Paper Award of the ECCV 2024 Workshop on Computational Aspects of Deep Learning.
22 |
23 | ## Abstract
24 | Mamba and Vision Mamba (Vim) models have shown their potential as an alternative to methods based on the Transformer architecture. This work introduces Fast Mamba for Vision (Famba-V), a cross-layer token fusion technique that enhances the training efficiency of Vim models. The key idea of Famba-V is to identify and fuse similar tokens across different Vim layers based on a suite of cross-layer strategies, rather than applying token fusion uniformly across all layers as existing works propose. We evaluate the performance of Famba-V on CIFAR-100. Our results show that Famba-V enhances the training efficiency of Vim models by reducing both training time and peak memory usage. Moreover, the proposed cross-layer strategies allow Famba-V to deliver superior accuracy-efficiency trade-offs. Together, these results demonstrate that Famba-V is a promising efficiency enhancement technique for Vim models.
25 |
26 |
27 | ## Quick Start
28 |
29 | - Python 3.10.13
30 |
31 |   - `conda create -n your_env_name python=3.10.13`
32 |
33 | - torch 2.1.1 + cu118
34 |   - `pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118`
35 |
36 | - Requirements: vim_requirements.txt
37 |   - `pip install -r fambav/vim_requirements.txt`
38 |
39 | - Install `causal-conv1d` and `mamba`
40 |   - `pip install -e causal-conv1d`
41 |   - `pip install -e mamba-1p1p1`
42 |
43 |
44 |
45 | ## Train Your Famba-V with Upper-layer Fusion Strategy
46 | ```bash
47 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 4 --fusion-token 8
48 | ```
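## Train with Other Fusion Strategies
The bundled launch scripts (e.g. `fambav/scripts/vim-t-cifar-lower-merge-r9-l5.sh`) suggest that the lower-layer, all-layer, and alternate-layer strategies are driven by the same flags. Below is a hedged sketch of a lower-layer run; the mapping of the script suffix `r9-l5` to `--fusion-token 9 --fusion-layer 5`, and the output directory name, are assumptions, so consult the scripts for the exact values:

```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py \
    --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \
    --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 \
    --data-set CIFAR --data-path ./datasets/cifar-100-python \
    --output_dir ./output/famba_v_lower_r9_l5 \
    --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9
```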
49 | ## :heart: Acknowledgement
50 | This project is based on Vision Mamba ([paper](https://arxiv.org/abs/2401.09417), [code](https://github.com/hustvl/Vim?tab=readme-ov-file)), Mamba ([paper](https://arxiv.org/abs/2312.00752), [code](https://github.com/state-spaces/mamba)), Causal-Conv1d ([code](https://github.com/Dao-AILab/causal-conv1d)), and DeiT ([paper](https://arxiv.org/abs/2012.12877), [code](https://github.com/facebookresearch/deit)). Thanks for their wonderful work.
51 |
52 | ## 🥳 Citation
53 | If you find Famba-V useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
54 |
55 | ```bibtex
56 | @inproceedings{fambav2024eccvw,
57 |   title={Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion},
58 |   author={Shen, Hui and Wan, Zhongwei and Wang, Xin and Zhang, Mi},
59 |   booktitle={European Conference on Computer Vision (ECCV) Workshop on Computational Aspects of Deep Learning},
60 |   year={2024}
61 | }
62 | ```
63 |
--------------------------------------------------------------------------------
/assets/FambaOverview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/assets/FambaOverview.png
--------------------------------------------------------------------------------
/assets/Strategies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/assets/Strategies.png
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/assets/logo.jpg
--------------------------------------------------------------------------------
/causal-conv1d/=1.1.0:
--------------------------------------------------------------------------------
1 | Obtaining file:///users/PAS2490/marcusshen/Vim/causal-conv1d/causal_conv1d
2 |
--------------------------------------------------------------------------------
/causal-conv1d/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 |
--------------------------------------------------------------------------------
/causal-conv1d/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /causal-conv1d/README.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | -------------------------------------------------------------------------------- /causal-conv1d/[13: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/causal-conv1d/[13 -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/=1.1.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/causal-conv1d/causal_conv1d/=1.1.0 -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /causal-conv1d/causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | import causal_conv1d_cuda 8 | 9 | 10 | class CausalConv1dFn(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, x, weight, bias=None, activation=None): 13 | if activation not in [None, "silu", "swish"]: 14 | raise NotImplementedError("activation must be None, silu, or swish") 15 | if x.stride(2) != 1 and x.stride(1) != 1: 16 | x = x.contiguous() 17 | bias = bias.contiguous() if bias is not None else None 18 | ctx.save_for_backward(x, weight, bias) 19 | ctx.activation = activation in ["silu", "swish"] 20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation) 21 | return out 22 | 23 | @staticmethod 24 | def backward(ctx, dout): 25 | x, weight, bias = ctx.saved_tensors 26 | if dout.stride(2) != 1 and dout.stride(1) != 1: 27 | dout = dout.contiguous() 28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 29 | # backward of conv1d with the backward of chunk). 30 | # Here we just pass in None and dx will be allocated in the C++ code. 
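# Assumed signature of the CUDA binding, inferred from this call site only (the
# extension itself is undocumented here, so treat this as an assumption):
#   causal_conv1d_cuda.causal_conv1d_bwd(x, weight, bias, dout, dx_or_None, silu_activation)
#   -> (dx, dweight, dbias)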
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd( 32 | x, weight, bias, dout, None, ctx.activation 33 | ) 34 | return dx, dweight, dbias if bias is not None else None, None 35 | 36 | 37 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 38 | """ 39 | x: (batch, dim, seqlen) 40 | weight: (dim, width) 41 | bias: (dim,) 42 | activation: either None or "silu" or "swish" 43 | 44 | out: (batch, dim, seqlen) 45 | """ 46 | return CausalConv1dFn.apply(x, weight, bias, activation) 47 | 48 | 49 | def causal_conv1d_ref(x, weight, bias=None, activation=None): 50 | """ 51 | x: (batch, dim, seqlen) 52 | weight: (dim, width) 53 | bias: (dim,) 54 | 55 | out: (batch, dim, seqlen) 56 | """ 57 | if activation not in [None, "silu", "swish"]: 58 | raise NotImplementedError("activation must be None, silu, or swish") 59 | dtype_in = x.dtype 60 | x = x.to(weight.dtype) 61 | seqlen = x.shape[-1] 62 | dim, width = weight.shape 63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 64 | out = out[..., :seqlen] 65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 66 | 67 | 68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None): 69 | """ 70 | x: (batch, dim) 71 | conv_state: (batch, dim, width) 72 | weight: (dim, width) 73 | bias: (dim,) 74 | 75 | out: (batch, dim) 76 | """ 77 | if activation not in [None, "silu", "swish"]: 78 | raise NotImplementedError("activation must be None, silu, or swish") 79 | activation = activation in ["silu", "swish"] 80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation) 81 | 82 | 83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None): 84 | """ 85 | x: (batch, dim) 86 | conv_state: (batch, dim, width) 87 | weight: (dim, width) 88 | bias: (dim,) 89 | 90 | out: (batch, dim) 91 | """ 92 | if activation not in [None, "silu", "swish"]: 93 | raise NotImplementedError("activation must be None, silu, or swish") 94 | dtype_in = x.dtype 95 | batch, dim = x.shape 96 | width = weight.shape[1] 97 | assert conv_state.shape == (batch, dim, width) 98 | assert weight.shape == (dim, width) 99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 100 | conv_state[:, :, -1] = x 101 | out = torch.sum(conv_state * weight, dim=-1) # (B D) 102 | if bias is not None: 103 | out += bias 104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 105 | -------------------------------------------------------------------------------- /causal-conv1d/csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | index_t conv_state_batch_stride; 25 | index_t conv_state_c_stride; 26 | index_t conv_state_l_stride; 27 | 28 | // Common data pointers. 
29 |     void *__restrict__ x_ptr;
30 |     void *__restrict__ weight_ptr;
31 |     void *__restrict__ bias_ptr;
32 |     void *__restrict__ out_ptr;
33 |
34 |     void *__restrict__ conv_state_ptr;
35 | };
36 |
37 | struct ConvParamsBwd: public ConvParamsBase {
38 |     index_t dx_batch_stride;
39 |     index_t dx_c_stride;
40 |     index_t dx_l_stride;
41 |     index_t dweight_c_stride;
42 |     index_t dweight_width_stride;
43 |     index_t dout_batch_stride;
44 |     index_t dout_c_stride;
45 |     index_t dout_l_stride;
46 |
47 |     // Common data pointers.
48 |     void *__restrict__ dx_ptr;
49 |     void *__restrict__ dweight_ptr;
50 |     void *__restrict__ dbias_ptr;
51 |     void *__restrict__ dout_ptr;
52 | };
53 |
54 |
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 |
10 | ////////////////////////////////////////////////////////////////////////////////////////////////////
11 |
12 | template<int BYTES> struct BytesToType {};
13 |
14 | template<> struct BytesToType<16> {
15 |     using Type = uint4;
16 |     static_assert(sizeof(Type) == 16);
17 | };
18 |
19 | template<> struct BytesToType<8> {
20 |     using Type = uint64_t;
21 |     static_assert(sizeof(Type) == 8);
22 | };
23 |
24 | template<> struct BytesToType<4> {
25 |     using Type = uint32_t;
26 |     static_assert(sizeof(Type) == 4);
27 | };
28 |
29 | template<> struct BytesToType<2> {
30 |     using Type = uint16_t;
31 |     static_assert(sizeof(Type) == 2);
32 | };
33 |
34 | template<> struct BytesToType<1> {
35 |     using Type = uint8_t;
36 |     static_assert(sizeof(Type) == 1);
37 | };
38 |
39 | ////////////////////////////////////////////////////////////////////////////////////////////////////
40 |
41 | template<typename T>
42 | struct SumOp {
43 |     __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44 | };
45 |
46 | template<int THREADS>
47 | struct Allreduce {
48 |     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49 |     template<typename T, typename Operator>
50 |     static __device__ inline T run(T x, Operator &op) {
51 |         constexpr int OFFSET = THREADS / 2;
52 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53 |         return Allreduce<OFFSET>::run(x, op);
54 |     }
55 | };
56 |
57 | template<>
58 | struct Allreduce<2> {
59 |     template<typename T, typename Operator>
60 |     static __device__ inline T run(T x, Operator &op) {
61 |         x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62 |         return x;
63 |     }
64 | };
65 |
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_update.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 |
5 | #include <c10/util/BFloat16.h>
6 | #include <c10/util/Half.h>
7 | #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8 |
9 | #include <cub/block/block_load.cuh>
10 | #include <cub/block/block_store.cuh>
11 |
12 | #include "causal_conv1d.h"
13 | #include "causal_conv1d_common.h"
14 | #include "static_switch.h"
15 |
16 | template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17 | struct Causal_conv1d_update_kernel_traits {
18 |     using input_t = input_t_;
19 |     using weight_t = weight_t_;
20 |     static constexpr int kNThreads = kNThreads_;
21 |     static constexpr int kWidth = kWidth_;
22 |     static constexpr int kNBytes = sizeof(input_t);
23 |     static_assert(kNBytes == 2 || kNBytes == 4);
24 | };
25 |
26 | template<typename Ktraits>
27 | __global__ __launch_bounds__(Ktraits::kNThreads)
28 | void causal_conv1d_update_kernel(ConvParamsBase params) {
29 |     constexpr int kWidth = Ktraits::kWidth;
30 |     constexpr int kNThreads = Ktraits::kNThreads;
31 |     using input_t = typename Ktraits::input_t;
32 |     using weight_t = typename Ktraits::weight_t;
33 |
34 |     const int tidx = threadIdx.x;
35 |     const int batch_id = blockIdx.x;
36 |     const int channel_id = blockIdx.y * kNThreads + tidx;
37 |     input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38 |         + channel_id * params.x_c_stride;
39 |     input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40 |         + channel_id * params.conv_state_c_stride;
41 |     weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42 |     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43 |         + channel_id * params.out_c_stride;
44 |     float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45 |
46 |     float weight_vals[kWidth] = {0};
47 |     if (channel_id < params.dim) {
48 |         #pragma unroll
49 |         for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50 |     }
51 |
52 |     float x_vals[kWidth] = {0};
53 |     if (channel_id < params.dim) {
54 |         #pragma unroll
55 |         for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56 |         x_vals[kWidth - 1] = float(x[0]);
57 |         #pragma unroll
58 |         for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59 |     }
60 |
61 |     float out_val = bias_val;
62 |     #pragma unroll
63 |     for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64 |     if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65 |     if (channel_id < params.dim) { out[0] = input_t(out_val); }
66 | }
67 |
68 | template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70 |     using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71 |     dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72 |     auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73 |     kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74 |     C10_CUDA_KERNEL_LAUNCH_CHECK();
75 | }
76 |
77 | template<typename input_t, typename weight_t>
78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79 |     if (params.width == 2) {
80 |         causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81 |     } else if (params.width == 3) {
82 |         causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83 |     } else if (params.width == 4) {
84 |         causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85 |     }
86 | }
87 |
88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/causal-conv1d/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND       - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ...        - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | ///     some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...)              \
17 |     [&] {                                               \
18 |         if (COND) {                                     \
19 |             static constexpr bool CONST_NAME = true;    \
20 |             return __VA_ARGS__();                       \
21 |         } else {                                        \
22 |             static constexpr bool CONST_NAME = false;   \
23 |             return __VA_ARGS__();                       \
24 |         }                                               \
25 |     }()
26 |
--------------------------------------------------------------------------------
/causal-conv1d/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 | import sys
3 | import warnings
4 | import os
5 | import re
6 | import ast
7 | from pathlib import Path
8 | from packaging.version import parse, Version
9 | import platform
10 |
11 | from setuptools import setup, find_packages
12 | import subprocess
13 |
14 | import urllib.request
15 | import urllib.error
16 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
17 |
18 | import torch
19 | from torch.utils.cpp_extension import (
20 |     BuildExtension,
21 |     CppExtension,
22 |     CUDAExtension,
23 |     CUDA_HOME,
24 | )
25 |
26 |
27 | with open("README.md", "r", encoding="utf-8") as fh:
28 |     long_description = fh.read()
29 |
30 |
31 | # ninja build does not work unless include_dirs are abs path
32 | this_dir = os.path.dirname(os.path.abspath(__file__))
33 |
34 | PACKAGE_NAME = "causal_conv1d"
35 |
36 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
37 |
38 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
39 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
40 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
41 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
42 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
43 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
44 |
45 |
46 | def get_platform():
47 |     """
48 |     Returns the platform name as used in wheel filenames.
49 | """ 50 | if sys.platform.startswith("linux"): 51 | return "linux_x86_64" 52 | elif sys.platform == "darwin": 53 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 54 | return f"macosx_{mac_version}_x86_64" 55 | elif sys.platform == "win32": 56 | return "win_amd64" 57 | else: 58 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 59 | 60 | 61 | def get_cuda_bare_metal_version(cuda_dir): 62 | raw_output = subprocess.check_output( 63 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 64 | ) 65 | output = raw_output.split() 66 | release_idx = output.index("release") + 1 67 | bare_metal_version = parse(output[release_idx].split(",")[0]) 68 | 69 | return raw_output, bare_metal_version 70 | 71 | 72 | def check_if_cuda_home_none(global_option: str) -> None: 73 | if CUDA_HOME is not None: 74 | return 75 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 76 | # in that case. 77 | warnings.warn( 78 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 79 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 80 | "only images whose names contain 'devel' will provide nvcc." 81 | ) 82 | 83 | 84 | def append_nvcc_threads(nvcc_extra_args): 85 | return nvcc_extra_args + ["--threads", "4"] 86 | 87 | 88 | cmdclass = {} 89 | ext_modules = [] 90 | 91 | if not SKIP_CUDA_BUILD: 92 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 93 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 94 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 95 | 96 | check_if_cuda_home_none("causal_conv1d") 97 | # Check, if CUDA11 is installed for compute capability 8.0 98 | cc_flag = [] 99 | if CUDA_HOME is not None: 100 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 101 | if bare_metal_version < Version("11.6"): 102 | raise RuntimeError( 103 | "causal_conv1d is only supported on CUDA 11.6 and above. " 104 | "Note: make sure nvcc has a supported version by running nvcc -V." 
105 | ) 106 | 107 | cc_flag.append("-gencode") 108 | cc_flag.append("arch=compute_70,code=sm_70") 109 | cc_flag.append("-gencode") 110 | cc_flag.append("arch=compute_80,code=sm_80") 111 | if bare_metal_version >= Version("11.8"): 112 | cc_flag.append("-gencode") 113 | cc_flag.append("arch=compute_90,code=sm_90") 114 | 115 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 116 | # torch._C._GLIBCXX_USE_CXX11_ABI 117 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 118 | if FORCE_CXX11_ABI: 119 | torch._C._GLIBCXX_USE_CXX11_ABI = True 120 | 121 | ext_modules.append( 122 | CUDAExtension( 123 | name="causal_conv1d_cuda", 124 | sources=[ 125 | "csrc/causal_conv1d.cpp", 126 | "csrc/causal_conv1d_fwd.cu", 127 | "csrc/causal_conv1d_bwd.cu", 128 | "csrc/causal_conv1d_update.cu", 129 | ], 130 | extra_compile_args={ 131 | "cxx": ["-O3"], 132 | "nvcc": append_nvcc_threads( 133 | [ 134 | "-O3", 135 | "-U__CUDA_NO_HALF_OPERATORS__", 136 | "-U__CUDA_NO_HALF_CONVERSIONS__", 137 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 138 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 139 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 140 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 141 | "--expt-relaxed-constexpr", 142 | "--expt-extended-lambda", 143 | "--use_fast_math", 144 | "--ptxas-options=-v", 145 | "-lineinfo", 146 | ] 147 | + cc_flag 148 | ), 149 | }, 150 | include_dirs=[this_dir], 151 | ) 152 | ) 153 | 154 | 155 | def get_package_version(): 156 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: 157 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 158 | public_version = ast.literal_eval(version_match.group(1)) 159 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") 160 | if local_version: 161 | return f"{public_version}+{local_version}" 162 | else: 163 | return str(public_version) 164 | 165 | 166 | def get_wheel_url(): 167 | # Determine the version numbers that will be used to determine the correct wheel 168 | # We're using the CUDA version used to build torch, not the one currently installed 169 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 170 | torch_cuda_version = parse(torch.version.cuda) 171 | torch_version_raw = parse(torch.__version__) 172 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 173 | # to save CI time. Minor versions should be compatible. 
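    # Worked example of the naming scheme below (all values assumed for illustration):
    # with the bundled package version 1.0.0, torch 2.1 built against CUDA 11.8,
    # CPython 3.10, and the pre-C++11 ABI on Linux, get_wheel_url() would look for
    #   causal_conv1d-1.0.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl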
174 |     torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
175 |     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
176 |     platform_name = get_platform()
177 |     causal_conv1d_version = get_package_version()
178 |     # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
179 |     cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
180 |     torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
181 |     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
182 |
183 |     # Determine wheel URL based on CUDA version, torch version, python version and OS
184 |     wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
185 |     wheel_url = BASE_WHEEL_URL.format(
186 |         tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
187 |     )
188 |     return wheel_url, wheel_filename
189 |
190 |
191 | class CachedWheelsCommand(_bdist_wheel):
192 |     """
193 |     The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
194 |     find an existing wheel (which is currently the case for all installs). We use
195 |     the environment parameters to detect whether there is already a pre-built version of a compatible
196 |     wheel available and short-circuit the standard full build pipeline.
197 |     """
198 |
199 |     def run(self):
200 |         if FORCE_BUILD:
201 |             return super().run()
202 |
203 |         wheel_url, wheel_filename = get_wheel_url()
204 |         print("Guessing wheel URL: ", wheel_url)
205 |         try:
206 |             urllib.request.urlretrieve(wheel_url, wheel_filename)
207 |
208 |             # Make the archive
209 |             # Lifted from the root wheel processing command
210 |             # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
211 |             if not os.path.exists(self.dist_dir):
212 |                 os.makedirs(self.dist_dir)
213 |
214 |             impl_tag, abi_tag, plat_tag = self.get_tag()
215 |             archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
216 |
217 |             wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
218 |             print("Raw wheel path", wheel_path)
219 |             os.rename(wheel_filename, wheel_path)
220 |         except urllib.error.HTTPError:
221 |             print("Precompiled wheel not found.
Building from source...") 222 | # If the wheel could not be downloaded, build from source 223 | super().run() 224 | 225 | 226 | setup( 227 | name=PACKAGE_NAME, 228 | version=get_package_version(), 229 | packages=find_packages( 230 | exclude=( 231 | "build", 232 | "csrc", 233 | "include", 234 | "tests", 235 | "dist", 236 | "docs", 237 | "benchmarks", 238 | "causal_conv1d.egg-info", 239 | ) 240 | ), 241 | author="Tri Dao", 242 | author_email="tri@tridao.me", 243 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface", 244 | long_description=long_description, 245 | long_description_content_type="text/markdown", 246 | url="https://github.com/Dao-AILab/causal-conv1d", 247 | classifiers=[ 248 | "Programming Language :: Python :: 3", 249 | "License :: OSI Approved :: BSD License", 250 | "Operating System :: Unix", 251 | ], 252 | ext_modules=ext_modules, 253 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 254 | if ext_modules 255 | else { 256 | "bdist_wheel": CachedWheelsCommand, 257 | }, 258 | python_requires=">=3.7", 259 | install_requires=[ 260 | "torch", 261 | "packaging", 262 | "ninja", 263 | ], 264 | ) 265 | -------------------------------------------------------------------------------- /causal-conv1d/tests/test_causal_conv1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import pytest 7 | 8 | from einops import rearrange 9 | 10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref 11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("channel_last", [False, True]) 15 | # @pytest.mark.parametrize('channel_last', [True]) 16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 17 | # @pytest.mark.parametrize('itype', [torch.float16]) 18 | @pytest.mark.parametrize("silu_activation", [False, True]) 19 | # @pytest.mark.parametrize('silu_activation', [True]) 20 | @pytest.mark.parametrize("has_bias", [False, True]) 21 | # @pytest.mark.parametrize('has_bias', [True]) 22 | @pytest.mark.parametrize("width", [2, 3, 4]) 23 | # @pytest.mark.parametrize('width', [2]) 24 | @pytest.mark.parametrize( 25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 26 | ) 27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 28 | # @pytest.mark.parametrize('seqlen', [128]) 29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last): 30 | device = "cuda" 31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 32 | if itype == torch.bfloat16: 33 | rtol, atol = 1e-2, 5e-2 34 | rtolw, atolw = (1e-3, 1e-3) 35 | # set seed 36 | torch.random.manual_seed(0) 37 | batch_size = 2 38 | # batch_size = 1 39 | dim = 4096 + 32 # Try dim not divisible by 64 40 | # dim = 64 41 | if not channel_last: 42 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 43 | else: 44 | x = rearrange( 45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 46 | ).requires_grad_() 47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 48 | if has_bias: 49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 50 | else: 
51 | bias = None 52 | x_ref = x.detach().clone().requires_grad_() 53 | weight_ref = weight.detach().clone().requires_grad_() 54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 55 | activation = None if not silu_activation else "silu" 56 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation) 58 | 59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 62 | 63 | g = torch.randn_like(out) 64 | out_ref.backward(g) 65 | out.backward(g) 66 | 67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 69 | if has_bias: 70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 71 | 72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 74 | if has_bias: 75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 76 | 77 | 78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 79 | # @pytest.mark.parametrize('itype', [torch.float16]) 80 | @pytest.mark.parametrize("silu_activation", [False, True]) 81 | # @pytest.mark.parametrize('silu_activation', [False]) 82 | @pytest.mark.parametrize("has_bias", [False, True]) 83 | # @pytest.mark.parametrize('has_bias', [True]) 84 | @pytest.mark.parametrize("width", [2, 3, 4]) 85 | # @pytest.mark.parametrize('width', [2]) 86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 87 | # @pytest.mark.parametrize("dim", [2048]) 88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype): 89 | device = "cuda" 90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 91 | if itype == torch.bfloat16: 92 | rtol, atol = 1e-2, 5e-2 93 | rtolw, atolw = (1e-3, 1e-3) 94 | # set seed 95 | torch.random.manual_seed(0) 96 | batch_size = 2 97 | # batch_size = 1 98 | # dim = 64 99 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype) 101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 102 | if has_bias: 103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 104 | else: 105 | bias = None 106 | conv_state_ref = conv_state.detach().clone() 107 | activation = None if not silu_activation else "silu" 108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) 109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation) 110 | 111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 113 | assert torch.equal(conv_state, conv_state_ref) 114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 115 | 116 | 117 | # @pytest.mark.parametrize("channel_last", [False, True]) 118 | @pytest.mark.parametrize('channel_last', [True]) 119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 120 | @pytest.mark.parametrize('itype', [torch.bfloat16]) 121 | # @pytest.mark.parametrize("silu_activation", [False, True]) 122 | @pytest.mark.parametrize('silu_activation', [True]) 123 
| # @pytest.mark.parametrize("has_bias", [False, True])
124 | @pytest.mark.parametrize('has_bias', [True])
125 | # @pytest.mark.parametrize("width", [2, 3, 4])
126 | @pytest.mark.parametrize('width', [4])
127 | @pytest.mark.parametrize(
128 |     # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
129 |     "seqlen", [2048]
130 | )
131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
132 | # @pytest.mark.parametrize('seqlen', [128])
133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
134 |     device = "cuda"
135 |     # set seed
136 |     torch.random.manual_seed(0)
137 |     batch_size = 2
138 |     # batch_size = 1
139 |     dim = 4096 + 32  # Try dim not divisible by 64
140 |     # dim = 64
141 |     if not channel_last:
142 |         x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
143 |     else:
144 |         x = rearrange(
145 |             torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
146 |         ).requires_grad_()
147 |     weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
148 |     if has_bias:
149 |         bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
150 |     else:
151 |         bias = None
152 |     activation = None if not silu_activation else "silu"
153 |     out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154 |     g = torch.randn_like(out0)
155 |     dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156 |     dw_atol = 1e-4
157 |     db_atol = 1e-4
158 |
159 |     for i in range(10000):
160 |         out = causal_conv1d_fn(x, weight, bias, activation=activation)
161 |         dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162 |         dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163 |         # if not dw_equal:
164 |         #     breakpoint()
165 |         if has_bias:
166 |             db_equal = torch.allclose(db, db0, atol=db_atol)
167 |             # if not db_equal:
168 |             #     breakpoint()
169 |         assert torch.equal(out, out0)
170 |         assert torch.equal(dx, dx0)
171 |         assert dw_equal
172 |         if has_bias:
173 |             assert db_equal
174 |
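# Usage note (an assumption about the environment, not part of the upstream suite):
# these tests need pytest, einops, and a CUDA device, with the compiled extension
# importable, e.g.:
#   pytest tests/test_causal_conv1d.py -v -k "not race_condition"
# test_causal_conv1d_race_condition repeats forward+backward 10,000 times and is
# best run on its own.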
--------------------------------------------------------------------------------
/fambav/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | imnet_resnet50_scratch/timm_temp/
4 | .dumbo.json
5 | checkpoints/
6 |
--------------------------------------------------------------------------------
/fambav/augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | """
5 | 3Augment implementation
6 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
7 | and timm DA (https://github.com/rwightman/pytorch-image-models)
8 | """
9 | import torch
10 | from torchvision import transforms
11 |
12 | # error: cannot import name '_pil_interp' from 'timm.data.transforms'
13 | # from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
14 |
15 | # fix: timm version problem
16 | # from timm.data.transforms import str_pil_interp as _pil_interp
17 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
18 |
19 | import numpy as np
20 | from torchvision import datasets, transforms
21 | import random
22 |
23 |
24 |
25 | from PIL import ImageFilter, ImageOps
26 | import torchvision.transforms.functional as TF
27 |
28 |
29 | class GaussianBlur(object):
30 |     """
31 |     Apply Gaussian Blur to the PIL image.
32 |     """
33 |     def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
34 |         self.prob = p
35 |         self.radius_min = radius_min
36 |         self.radius_max = radius_max
37 |
38 |     def __call__(self, img):
39 |         do_it = random.random() <= self.prob
40 |         if not do_it:
41 |             return img
42 |
43 |         img = img.filter(
44 |             ImageFilter.GaussianBlur(
45 |                 radius=random.uniform(self.radius_min, self.radius_max)
46 |             )
47 |         )
48 |         return img
49 |
50 | class Solarization(object):
51 |     """
52 |     Apply Solarization to the PIL image.
53 |     """
54 |     def __init__(self, p=0.2):
55 |         self.p = p
56 |
57 |     def __call__(self, img):
58 |         if random.random() < self.p:
59 |             return ImageOps.solarize(img)
60 |         else:
61 |             return img
62 |
63 | class gray_scale(object):
64 |     """
65 |     Apply grayscale conversion to the PIL image.
66 |     """
67 |     def __init__(self, p=0.2):
68 |         self.p = p
69 |         self.transf = transforms.Grayscale(3)
70 |
71 |     def __call__(self, img):
72 |         if random.random() < self.p:
73 |             return self.transf(img)
74 |         else:
75 |             return img
76 |
77 |
78 |
79 | class horizontal_flip(object):
80 |     """
81 |     Apply horizontal flip to the PIL image.
82 | """ 83 | def __init__(self, p=0.2,activate_pred=False): 84 | self.p = p 85 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 86 | 87 | def __call__(self, img): 88 | if random.random() < self.p: 89 | return self.transf(img) 90 | else: 91 | return img 92 | 93 | 94 | 95 | def new_data_aug_generator(args = None): 96 | img_size = args.input_size 97 | remove_random_resized_crop = args.src 98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 99 | primary_tfl = [] 100 | scale=(0.08, 1.0) 101 | interpolation='bicubic' 102 | if remove_random_resized_crop: 103 | primary_tfl = [ 104 | transforms.Resize(img_size, interpolation=3), 105 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'), 106 | transforms.RandomHorizontalFlip() 107 | ] 108 | else: 109 | primary_tfl = [ 110 | RandomResizedCropAndInterpolation( 111 | img_size, scale=scale, interpolation=interpolation), 112 | transforms.RandomHorizontalFlip() 113 | ] 114 | 115 | 116 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 117 | Solarization(p=1.0), 118 | GaussianBlur(p=1.0)])] 119 | 120 | if args.color_jitter is not None and not args.color_jitter==0: 121 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 122 | final_tfl = [ 123 | transforms.ToTensor(), 124 | transforms.Normalize( 125 | mean=torch.tensor(mean), 126 | std=torch.tensor(std)) 127 | ] 128 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl) 129 | -------------------------------------------------------------------------------- /fambav/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | 12 | 13 | class INatDataset(ImageFolder): 14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 15 | category='name', loader=default_loader): 16 | self.transform = transform 17 | self.loader = loader 18 | self.target_transform = target_transform 19 | self.year = year 20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 22 | with open(path_json) as json_file: 23 | data = json.load(json_file) 24 | 25 | with open(os.path.join(root, 'categories.json')) as json_file: 26 | data_catg = json.load(json_file) 27 | 28 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 29 | 30 | with open(path_json_for_targeter) as json_file: 31 | data_for_targeter = json.load(json_file) 32 | 33 | targeter = {} 34 | indexer = 0 35 | for elem in data_for_targeter['annotations']: 36 | king = [] 37 | king.append(data_catg[int(elem['category_id'])][category]) 38 | if king[0] not in targeter.keys(): 39 | targeter[king[0]] = indexer 40 | indexer += 1 41 | self.nb_classes = len(targeter) 42 | 43 | self.samples = [] 44 | for elem in data['images']: 45 | cut = elem['file_name'].split('/') 46 | target_current = int(cut[2]) 47 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 48 | 49 | categors = data_catg[target_current] 50 | target_current_true = targeter[categors[category]] 51 | self.samples.append((path_current, target_current_true)) 52 | 
53 | # __getitem__ and __len__ inherited from ImageFolder 54 | 55 | 56 | def build_dataset(is_train, args): 57 | transform = build_transform(is_train, args) 58 | 59 | if args.data_set == 'CIFAR': 60 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 61 | nb_classes = 100 62 | elif args.data_set == 'IMNET': 63 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 64 | dataset = datasets.ImageFolder(root, transform=transform) 65 | nb_classes = 1000 66 | elif args.data_set == 'INAT': 67 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 68 | category=args.inat_category, transform=transform) 69 | nb_classes = dataset.nb_classes 70 | elif args.data_set == 'INAT19': 71 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 72 | category=args.inat_category, transform=transform) 73 | nb_classes = dataset.nb_classes 74 | 75 | return dataset, nb_classes 76 | 77 | 78 | def build_transform(is_train, args): 79 | resize_im = args.input_size > 32 80 | if is_train: 81 | # this should always dispatch to transforms_imagenet_train 82 | transform = create_transform( 83 | input_size=args.input_size, 84 | is_training=True, 85 | color_jitter=args.color_jitter, 86 | auto_augment=args.aa, 87 | interpolation=args.train_interpolation, 88 | re_prob=args.reprob, 89 | re_mode=args.remode, 90 | re_count=args.recount, 91 | ) 92 | if not resize_im: 93 | # replace RandomResizedCropAndInterpolation with 94 | # RandomCrop 95 | transform.transforms[0] = transforms.RandomCrop( 96 | args.input_size, padding=4) 97 | return transform 98 | 99 | t = [] 100 | if resize_im: 101 | size = int(args.input_size / args.eval_crop_ratio) 102 | t.append( 103 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 104 | ) 105 | t.append(transforms.CenterCrop(args.input_size)) 106 | 107 | t.append(transforms.ToTensor()) 108 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 109 | return transforms.Compose(t) 110 | -------------------------------------------------------------------------------- /fambav/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | import timm 13 | from timm.data import Mixup 14 | from timm.utils import accuracy, ModelEma 15 | 16 | from losses import DistillationLoss 17 | import utils 18 | 19 | 20 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 21 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 22 | device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0, 23 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 24 | set_training_mode=True, args = None): 25 | model.train(set_training_mode) 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | header = 'Epoch: [{}]'.format(epoch) 29 | print_freq = 10 30 | 31 | if args.cosub: 32 | criterion = torch.nn.BCEWithLogitsLoss() 33 | 34 | # debug 35 | # count = 0 36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 37 | # count += 1 38 | # if count > 20: 39 | # break 40 | 41 | samples = samples.to(device, non_blocking=True) 42 | targets = targets.to(device, non_blocking=True) 43 | 44 | if mixup_fn is not None: 45 | samples, targets = mixup_fn(samples, targets) 46 | 47 | if args.cosub: 48 | samples = torch.cat((samples,samples),dim=0) 49 | 50 | if args.bce_loss: 51 | targets = targets.gt(0.0).type(targets.dtype) 52 | 53 | with amp_autocast(): 54 | outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank) 55 | # outputs = model(samples) 56 | if not args.cosub: 57 | loss = criterion(samples, outputs, targets) 58 | else: 59 | outputs = torch.split(outputs, outputs.shape[0]//2, dim=0) 60 | loss = 0.25 * criterion(outputs[0], targets) 61 | loss = loss + 0.25 * criterion(outputs[1], targets) 62 | loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid()) 63 | loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid()) 64 | 65 | if args.if_nan2num: 66 | with amp_autocast(): 67 | loss = torch.nan_to_num(loss) 68 | 69 | loss_value = loss.item() 70 | 71 | if not math.isfinite(loss_value): 72 | print("Loss is {}, stopping training".format(loss_value)) 73 | if args.if_continue_inf: 74 | optimizer.zero_grad() 75 | continue 76 | else: 77 | sys.exit(1) 78 | 79 | optimizer.zero_grad() 80 | 81 | # this attribute is added by timm on one optimizer (adahessian) 82 | if isinstance(loss_scaler, timm.utils.NativeScaler): 83 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 84 | loss_scaler(loss, optimizer, clip_grad=max_norm, 85 | parameters=model.parameters(), create_graph=is_second_order) 86 | else: 87 | loss.backward() 88 | if max_norm != None: 89 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 90 | optimizer.step() 91 | 92 | torch.cuda.synchronize() 93 | if model_ema is not None: 94 | model_ema.update(model) 95 | 96 | metric_logger.update(loss=loss_value) 97 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 98 | # gather the stats from all processes 99 | metric_logger.synchronize_between_processes() 100 | print("Averaged stats:", metric_logger) 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | 103 | 104 | @torch.no_grad() 105 | def evaluate(data_loader, model, device, amp_autocast): 106 | criterion = 
104 | @torch.no_grad() 105 | def evaluate(data_loader, model, device, amp_autocast): 106 | criterion = torch.nn.CrossEntropyLoss() 107 | 108 | metric_logger = utils.MetricLogger(delimiter=" ") 109 | header = 'Test:' 110 | 111 | # switch to evaluation mode 112 | model.eval() 113 | 114 | for images, target in metric_logger.log_every(data_loader, 10, header): 115 | images = images.to(device, non_blocking=True) 116 | target = target.to(device, non_blocking=True) 117 | 118 | # compute output 119 | with amp_autocast(): 120 | output = model(images) 121 | loss = criterion(output, target) 122 | 123 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 124 | 125 | batch_size = images.shape[0] 126 | metric_logger.update(loss=loss.item()) 127 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 128 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 129 | # gather the stats from all processes 130 | metric_logger.synchronize_between_processes() 131 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 132 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 133 | 134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 135 | -------------------------------------------------------------------------------- /fambav/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | from models import * 4 | from cait_models import * 5 | from resmlp_models import * 6 | #from patchconvnet_models import * 7 | 8 | dependencies = ["torch", "torchvision", "timm"] 9 | -------------------------------------------------------------------------------- /fambav/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are fed to the teacher model 29 | outputs: the outputs of the model to be trained. It is expected to be
It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop throught the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | #We provide the teacher's targets in log probability because we use log_target=True 57 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719) 58 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. 59 | F.log_softmax(teacher_outputs / T, dim=1), 60 | reduction='sum', 61 | log_target=True 62 | ) * (T * T) / outputs_kd.numel() 63 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 64 | #But we also experiments output_kd.size(0) 65 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details 66 | elif self.distillation_type == 'hard': 67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 68 | 69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 70 | return loss 71 | -------------------------------------------------------------------------------- /fambav/mlruns/0/meta.yaml: -------------------------------------------------------------------------------- 1 | artifact_location: file:///users/PAS2490/marcusshen/Vim/vim/mlruns/0 2 | creation_time: 1715856885741 3 | experiment_id: '0' 4 | last_update_time: 1715856885741 5 | lifecycle_stage: active 6 | name: Default 7 | -------------------------------------------------------------------------------- /fambav/rope.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # EVA-02: A Visual Representation for Neon Genesis 3 | # Github source: https://github.com/baaivision/EVA/EVA02 4 | # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI) 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Yuxin Fang 7 | # 8 | # Based on https://github.com/lucidrains/rotary-embedding-torch 9 | # --------------------------------------------------------' 10 | 11 | from math import pi 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | 20 | def broadcat(tensors, dim = -1): 21 | num_tensors = len(tensors) 22 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 23 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 24 | shape_len = list(shape_lens)[0] 25 | dim = (dim + 
-------------------------------------------------------------------------------- /fambav/rope.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # EVA-02: A Visual Representation for Neon Genesis 3 | # Github source: https://github.com/baaivision/EVA/EVA02 4 | # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI) 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Yuxin Fang 7 | # 8 | # Based on https://github.com/lucidrains/rotary-embedding-torch 9 | # -------------------------------------------------------- 10 | 11 | from math import pi 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | 20 | def broadcat(tensors, dim = -1): 21 | num_tensors = len(tensors) 22 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 23 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 24 | shape_len = list(shape_lens)[0] 25 | dim = (dim + shape_len) if dim < 0 else dim 26 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 27 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 28 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation' 29 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 30 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 31 | expanded_dims.insert(dim, (dim, dims[dim])) 32 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 33 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 34 | return torch.cat(tensors, dim = dim) 35 | 36 | 37 | 38 | def rotate_half(x): 39 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 40 | x1, x2 = x.unbind(dim = -1) 41 | x = torch.stack((-x2, x1), dim = -1) 42 | return rearrange(x, '... d r -> ... (d r)') 43 | 44 | 45 | 46 | class VisionRotaryEmbedding(nn.Module): 47 | def __init__( 48 | self, 49 | dim, 50 | pt_seq_len, 51 | ft_seq_len=None, 52 | custom_freqs = None, 53 | freqs_for = 'lang', 54 | theta = 10000, 55 | max_freq = 10, 56 | num_freqs = 1, 57 | ): 58 | super().__init__() 59 | if custom_freqs: 60 | freqs = custom_freqs 61 | elif freqs_for == 'lang': 62 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 63 | elif freqs_for == 'pixel': 64 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 65 | elif freqs_for == 'constant': 66 | freqs = torch.ones(num_freqs).float() 67 | else: 68 | raise ValueError(f'unknown modality {freqs_for}') 69 | 70 | if ft_seq_len is None: ft_seq_len = pt_seq_len 71 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 72 | 73 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 74 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 75 | 76 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 77 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 78 | 79 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 80 | 81 | self.register_buffer("freqs_cos", freqs.cos()) 82 | self.register_buffer("freqs_sin", freqs.sin()) 83 | 84 | print('======== shape of rope freq', self.freqs_cos.shape, '========') 85 | 86 | def forward(self, t, start_index = 0): 87 | rot_dim = self.freqs_cos.shape[-1] 88 | end_index = start_index + rot_dim 89 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 90 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 91 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 92 | return torch.cat((t_left, t, t_right), dim = -1) 93 | 94 | 95 |
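# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original rope.py): a minimal shape check for
# VisionRotaryEmbedding above. With dim=32 each spatial axis contributes
# 2 * (32 // 2) = 32 rotated channels, so the h/w concatenation rotates 64.
def _example_2d_rope():
    rope = VisionRotaryEmbedding(dim=32, pt_seq_len=14)  # 14x14 patch grid
    tokens = torch.randn(14, 14, 64)                     # (h, w, channels)
    out = rope(tokens)                                   # rotates the first 64 channels
    assert out.shape == tokens.shape
# ---------------------------------------------------------------------------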
96 | class VisionRotaryEmbeddingFast(nn.Module): 97 | def __init__( 98 | self, 99 | dim, 100 | pt_seq_len=16, 101 | ft_seq_len=None, 102 | custom_freqs = None, 103 | freqs_for = 'lang', 104 | theta = 10000, 105 | max_freq = 10, 106 | num_freqs = 1, 107 | ): 108 | super().__init__() 109 | if custom_freqs: 110 | freqs = custom_freqs 111 | elif freqs_for == 'lang': 112 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 113 | elif freqs_for == 'pixel': 114 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 115 | elif freqs_for == 'constant': 116 | freqs = torch.ones(num_freqs).float() 117 | else: 118 | raise ValueError(f'unknown modality {freqs_for}') 119 | 120 | if ft_seq_len is None: ft_seq_len = pt_seq_len 121 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 122 | 123 | freqs = torch.einsum('..., f -> ... f', t, freqs) 124 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2) 125 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 126 | 127 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 128 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 129 | 130 | self.register_buffer("freqs_cos", freqs_cos) 131 | self.register_buffer("freqs_sin", freqs_sin) 132 | 133 | print('======== shape of rope freq', self.freqs_cos.shape, '========') 134 | 135 | def forward(self, t): 136 | if t.shape[1] % 2 != 0: 137 | t_spatial = t[:, 1:, :] 138 | t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin 139 | return torch.cat((t[:, :1, :], t_spatial), dim=1) 140 | else: 141 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /fambav/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | A script to run multinode training with submitit. 5 | """ 6 | import argparse 7 | import os 8 | import uuid 9 | from pathlib import Path 10 | 11 | import main as classification 12 | import submitit 13 | 14 | 15 | def parse_args(): 16 | classification_parser = classification.get_args_parser() 17 | parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser]) 18 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 19 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 20 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 21 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 22 | 23 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 24 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 25 | parser.add_argument('--comment', default="", type=str, 26 | help='Comment to pass to scheduler, e.g. priority message') 27 | return parser.parse_args() 28 | 29 | 30 | def get_shared_folder() -> Path: 31 | user = os.getenv("USER") 32 | if Path("/checkpoint/").is_dir(): 33 | p = Path(f"/checkpoint/{user}/experiments") 34 | p.mkdir(exist_ok=True) 35 | return p 36 | raise RuntimeError("No shared folder available") 37 | 38 | 39 | def get_init_file(): 40 | # Init file must not exist, but its parent dir must exist.
41 | os.makedirs(str(get_shared_folder()), exist_ok=True) 42 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 43 | if init_file.exists(): 44 | os.remove(str(init_file)) 45 | return init_file 46 | 47 | 48 | class Trainer(object): 49 | def __init__(self, args): 50 | self.args = args 51 | 52 | def __call__(self): 53 | import main as classification 54 | 55 | self._setup_gpu_args() 56 | classification.main(self.args) 57 | 58 | def checkpoint(self): 59 | import os 60 | import submitit 61 | 62 | self.args.dist_url = get_init_file().as_uri() 63 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 64 | if os.path.exists(checkpoint_file): 65 | self.args.resume = checkpoint_file 66 | print("Requeuing ", self.args) 67 | empty_trainer = type(self)(self.args) 68 | return submitit.helpers.DelayedSubmission(empty_trainer) 69 | 70 | def _setup_gpu_args(self): 71 | import submitit 72 | from pathlib import Path 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 76 | self.args.gpu = job_env.local_rank 77 | self.args.rank = job_env.global_rank 78 | self.args.world_size = job_env.num_tasks 79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 80 | 81 | 82 | def main(): 83 | args = parse_args() 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | # Note that the folder will depend on the job_id, to easily track experiments 88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 89 | 90 | num_gpus_per_node = args.ngpus 91 | nodes = args.nodes 92 | timeout_min = args.timeout 93 | 94 | partition = args.partition 95 | kwargs = {} 96 | if args.use_volta32: 97 | kwargs['slurm_constraint'] = 'volta32gb' 98 | if args.comment: 99 | kwargs['slurm_comment'] = args.comment 100 | 101 | executor.update_parameters( 102 | mem_gb=40 * num_gpus_per_node, 103 | gpus_per_node=num_gpus_per_node, 104 | tasks_per_node=num_gpus_per_node, # one task per GPU 105 | cpus_per_task=10, 106 | nodes=nodes, 107 | timeout_min=timeout_min, # max is 60 * 72 108 | # Below are cluster dependent parameters 109 | slurm_partition=partition, 110 | slurm_signal_delay_s=120, 111 | **kwargs 112 | ) 113 | 114 | executor.update_parameters(name="deit") 115 | 116 | args.dist_url = get_init_file().as_uri() 117 | args.output_dir = args.job_dir 118 | 119 | trainer = Trainer(args) 120 | job = executor.submit(trainer) 121 | 122 | print("Submitted job_id:", job.job_id) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /fambav/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /fambav/scripts/mlruns/0/meta.yaml: -------------------------------------------------------------------------------- 1 | artifact_location: file:///users/PAS2490/marcusshen/Vim/vim/scripts/mlruns/0 2 | creation_time: 1718284550737 3 | experiment_id: '0' 4 | last_update_time: 1718284550737 5 | lifecycle_stage: active 6 | name: Default 7 |
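As a usage note for RASampler above (an editorial sketch with placeholder sizes; main.py wires this up from the real dataset and the distributed state):

import torch
from samplers import RASampler

dataset = torch.utils.data.TensorDataset(torch.randn(512, 3, 224, 224))  # stand-in dataset
sampler = RASampler(dataset, num_replicas=4, rank=0, shuffle=True, num_repeats=3)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=128)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the shuffle so every epoch differs
    for (batch,) in loader:
        pass  # each rank sees ~len(dataset)/num_replicas samples; repeats land on different ranks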
load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_all_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7 -------------------------------------------------------------------------------- /fambav/scripts/vim-s-cifar-alternate-merge-r14.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-s-cifar-alternate-merge-r14 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-alternate-merge-r14.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-alternate-merge-r14_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=35:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_alternate_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy alternate --fusion-layer 5 --fusion-token 14 -------------------------------------------------------------------------------- /fambav/scripts/vim-s-cifar-lower-merge-r9-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-s-cifar-lower-merge-r9-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=35:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir 
../output/merge_cifar_lower_r9l4_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-s-cifar-lower-merge-r9-l5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-s-cifar-lower-merge-r9-l5 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l5.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l5_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=35:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r9l5_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-s-cifar-non-merge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-s-cifar-non-merge # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-non-merge.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-non-merge_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=35:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_non_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 5 --fusion-token 7 -------------------------------------------------------------------------------- /fambav/scripts/vim-s-cifar-upper-merge-r9-l18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH 
--job-name=vim-s-cifar-upper-merge-r9-l18 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-upper-merge-r9-l18.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-upper-merge-r9-l18_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=35:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_upper_r9l18_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 18 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-all-merge-r7-cls.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-all-merge-r7-cls # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7-cls.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7-cls_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_all_cls_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-all-merge-r7.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-all-merge-r7 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | 
#SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_all_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-alternate-merge-r14.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-alternate-merge-r14 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-alternate-merge-r14.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-alternate-merge-r14_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_alternate_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy alternate --fusion-layer 5 --fusion-token 14 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r1-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r1-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r1-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r1-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m 
torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r1l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 1 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r10-l6.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r10-l6 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l6.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l6_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r10l6_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 6 --fusion-token 10 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r10-l7.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r10-l7 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l7.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l7_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r10l7_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 
--no_amp --fusion-strategy lower --fusion-layer 7 --fusion-token 10 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r11-l8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r11-l8 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r11-l8.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r11-l8_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r11l8_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 8 --fusion-token 11 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r12-l10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r12-l10 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l10.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l10_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r12l10_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 10 --fusion-token 12 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r12-l9.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r12-l9 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 
| #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l9.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l9_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r12l9_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 9 --fusion-token 12 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r13-l11.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r13-l11 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r13-l11.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r13-l11_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r13l11_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 11 --fusion-token 13 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r14-l12.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r14-l12 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r14-l12.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r14-l12_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH 
--mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r14l12_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 12 --fusion-token 14 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r15-l13.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r15-l13 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r15-l13.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r15-l13_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r15l13_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 13 --fusion-token 15 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r2-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r2-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r2-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r2-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env 
../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r2l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 2 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r3-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r3-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r3-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r3-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r3l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 3 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r4-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r4-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r4-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r4-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r4l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 4 
-------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r5-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r5-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r5-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r5-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r5l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 5 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r6-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r6-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r6-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r6-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r6l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 6 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r7-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r7-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH 
--output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r7-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r7-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r7l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 7 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r8-l3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r8-l3 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l3.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l3_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH --time=6:00:00 # 作业运行时间限制 13 | 14 | # 运行命令或脚本 wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r8l3_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 3 --fusion-token 8 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r8-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r8-l4 # 作业名称 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l4.log # 输出日志文件 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l4_error.log # 错误日志文件 7 | #SBATCH --nodes=1 # 节点数 8 | #SBATCH --ntasks-per-node=1 # 每个节点的任务数 9 | #SBATCH --cpus-per-task=4 # 每个任务使用的 CPU 核心数 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # 内存限制 12 | #SBATCH 
--time=6:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r8l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 8 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r9-l4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r9-l4 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l4.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l4_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=6:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r9l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-lower-merge-r9-l5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r9-l5 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l5.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l5_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=6:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model
vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-non-merge-visual.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-non-merge-visual # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge-visual.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge-visual_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=6:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/visual_merge_cifar_non_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 5 --fusion-token 7 --visualize -------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-non-merge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-non-merge # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=6:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_non_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 5 --fusion-token 7
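
A minimal sketch (not a repository file) of how the --fusion-strategy and --fusion-layer flags swept by these scripts could select the ViM blocks that perform token fusion. The mapping below is inferred from the script names (lower-merge-r9-l4, upper-merge-r9-l18, alternate-merge-r14, all-merge-r7) rather than taken from fambav/main.py, so treat it as illustrative only; --fusion-token is the number of tokens merged per fusing block (the "r" in the names).

```
# Illustrative only: plausible layer selection per --fusion-strategy,
# inferred from the script names; the authoritative logic lives in
# fambav/main.py and fambav/models_mamba.py.
def fusion_layers(strategy: str, fusion_layer: int, depth: int = 24) -> set:
    # depth=24 assumes the Vim-tiny backbone used by these scripts
    if strategy == "no":
        return set()                            # baseline: never fuse
    if strategy == "lower":
        return set(range(fusion_layer))         # the first `fusion_layer` blocks
    if strategy == "upper":
        return set(range(fusion_layer, depth))  # blocks `fusion_layer`..depth-1
    if strategy == "alternate":
        return set(range(0, depth, 2))          # every other block
    if strategy == "all":
        return set(range(depth))                # every block
    raise ValueError(f"unknown strategy: {strategy}")
```
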
-------------------------------------------------------------------------------- /fambav/scripts/vim-t-cifar-upper-merge-r9-l18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-cifar-upper-merge-r9-l18 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-upper-merge-r9-l18.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-upper-merge-r9-l18_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=6:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_upper_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 18 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-imagenet-all-merge-r7.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-imagenet-all-merge-r7-2 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-all-merge-r7-2.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-all-merge-r7-2_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=120:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_all_2_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-imagenet-alternate-merge-r14.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-imagenet-alternate-merge-r14-2 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH
--output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-alternate-merge-r14-2.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-alternate-merge-r14-2_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=120:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_alternate_2_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy alternate --fusion-layer 5 --fusion-token 14 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-imagenet-lower-merge-r9-l5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-imagenet-lower-merge-r9-l5 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-lower-merge-r9-l5.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-lower-merge-r9-l5_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=120:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_lower_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-imagenet-non-merge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-imagenet-non-merge # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-non-merge.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-non-merge_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12
| #SBATCH --time=120:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_non_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 8 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/scripts/vim-t-imagenet-upper-merge-r9-l18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim-t-imagenet-upper-merge-r9-l18 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-upper-merge-r9-l18.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-upper-merge-r9-l18_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=120:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_upper_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 18 --fusion-token 9 -------------------------------------------------------------------------------- /fambav/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=vim_1 # Job name 4 | #SBATCH --account=PAS2490 # Project ID 5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim_1.log # Output log file 6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim_1_error.log # Error log file 7 | #SBATCH --nodes=1 # Number of nodes 8 | #SBATCH --ntasks-per-node=1 # Tasks per node 9 | #SBATCH --cpus-per-task=4 # CPU cores per task 10 | #SBATCH --gpus-per-node=4 # GPU per node 11 | #SBATCH --mem=80G # Memory limit 12 | #SBATCH --time=04:00:00 # Job time limit 13 | 14 | # Command or script to run wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh 15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim 16 | # module load cuda 17 | export CUDA_VISIBLE_DEVICES=0,1,2,3 18 | 19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0
--weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp 20 | -------------------------------------------------------------------------------- /fambav/token_merge.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | from typing import Callable, Tuple 10 | 11 | import torch 12 | 13 | 14 | def do_nothing(x, mode=None): 15 | return x 16 | 17 | 18 | def bipartite_soft_matching( 19 | metric: torch.Tensor, 20 | r: int, 21 | class_token: bool = False, 22 | distill_token: bool = False, 23 | ) -> Tuple[Callable, Callable]: 24 | """ 25 | Applies ToMe with a balanced matching set (50%, 50%). 26 | 27 | Input size is [batch, tokens, channels]. 28 | r indicates the number of tokens to remove (max 50% of tokens). 29 | 30 | Extra args: 31 | - class_token: Whether or not there's a class token. 32 | - distill_token: Whether or not there's also a distillation token. 33 | 34 | When enabled, the class token and distillation tokens won't get merged. 35 | """ 36 | protected = 0 37 | if class_token: 38 | protected += 1 39 | if distill_token: 40 | protected += 1 41 | 42 | # We can only reduce by a maximum of 50% tokens 43 | t = metric.shape[1] 44 | r = min(r, (t - protected) // 2) 45 | 46 | if r <= 0: 47 | return do_nothing, do_nothing 48 | 49 | with torch.no_grad(): 50 | metric = metric / metric.norm(dim=-1, keepdim=True) 51 | a, b = metric[..., ::2, :], metric[..., 1::2, :] 52 | scores = a @ b.transpose(-1, -2) 53 | 54 | if class_token: 55 | scores[..., 0, :] = -math.inf 56 | if distill_token: 57 | scores[..., :, 0] = -math.inf 58 | 59 | node_max, node_idx = scores.max(dim=-1) 60 | edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] 61 | 62 | unm_idx = edge_idx[..., r:, :] # Unmerged Tokens 63 | src_idx = edge_idx[..., :r, :] # Merged Tokens 64 | dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) 65 | 66 | if class_token: 67 | # Sort to ensure the class token is at the start 68 | unm_idx = unm_idx.sort(dim=1)[0] 69 | 70 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: 71 | src, dst = x[..., ::2, :], x[..., 1::2, :] 72 | n, t1, c = src.shape 73 | unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) 74 | src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) 75 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) 76 | 77 | if distill_token: 78 | return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1) 79 | else: 80 | return torch.cat([unm, dst], dim=1) 81 | 82 | def unmerge(x: torch.Tensor) -> torch.Tensor: 83 | unm_len = unm_idx.shape[1] 84 | unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] 85 | n, _, c = unm.shape 86 | 87 | src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)) 88 | 89 | out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype) 90 | 91 | out[..., 1::2, :] = dst 92 | out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm) 93 | out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src) 94 | 95 | return out 96 | 97 | return merge, unmerge 98 | 99 | 100 | def kth_bipartite_soft_matching( 101 | metric: 
torch.Tensor, k: int 102 | ) -> Tuple[Callable, Callable]: 103 | """ 104 | Applies ToMe with the two sets as (every kth element, the rest). 105 | If n is the number of tokens, resulting number of tokens will be n // k. 106 | 107 | Input size is [batch, tokens, channels]. 108 | k indicates the stride for the first set. 109 | k = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * N 110 | """ 111 | if k <= 1: 112 | return do_nothing, do_nothing 113 | 114 | def split(x): 115 | t_rnd = (x.shape[1] // k) * k 116 | x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2]) 117 | a, b = ( 118 | x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]), 119 | x[:, :, (k - 1), :], 120 | ) 121 | return a, b 122 | 123 | with torch.no_grad(): 124 | metric = metric / metric.norm(dim=-1, keepdim=True) 125 | a, b = split(metric) 126 | r = a.shape[1] 127 | scores = a @ b.transpose(-1, -2) 128 | 129 | _, dst_idx = scores.max(dim=-1) 130 | dst_idx = dst_idx[..., None] 131 | 132 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: 133 | src, dst = split(x) 134 | n, _, c = src.shape 135 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) 136 | 137 | return dst 138 | 139 | def unmerge(x: torch.Tensor) -> torch.Tensor: 140 | n, _, c = x.shape 141 | dst = x 142 | 143 | src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype) 144 | 145 | src = src.view(n, -1, (k - 1), c) 146 | dst = dst.view(n, -1, 1, c) 147 | 148 | out = torch.cat([src, dst], dim=-2) 149 | out = out.contiguous().view(n, -1, c) 150 | 151 | return out 152 | 153 | return merge, unmerge 154 | 155 | 156 | def random_bipartite_soft_matching( 157 | metric: torch.Tensor, r: int 158 | ) -> Tuple[Callable, Callable]: 159 | """ 160 | Applies ToMe with the two sets as (r chosen randomly, the rest). 161 | Input size is [batch, tokens, channels]. 162 | 163 | This will reduce the number of tokens by r. 164 | """ 165 | if r <= 0: 166 | return do_nothing, do_nothing 167 | 168 | with torch.no_grad(): 169 | B, N, _ = metric.shape 170 | rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1) 171 | 172 | a_idx = rand_idx[:, :r, :] 173 | b_idx = rand_idx[:, r:, :] 174 | 175 | def split(x): 176 | C = x.shape[-1] 177 | a = x.gather(dim=1, index=a_idx.expand(B, r, C)) 178 | b = x.gather(dim=1, index=b_idx.expand(B, N - r, C)) 179 | return a, b 180 | 181 | metric = metric / metric.norm(dim=-1, keepdim=True) 182 | a, b = split(metric) 183 | scores = a @ b.transpose(-1, -2) 184 | 185 | _, dst_idx = scores.max(dim=-1) 186 | dst_idx = dst_idx[..., None] 187 | 188 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: 189 | src, dst = split(x) 190 | C = src.shape[-1] 191 | dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode) 192 | 193 | return dst 194 | 195 | def unmerge(x: torch.Tensor) -> torch.Tensor: 196 | C = x.shape[-1] 197 | dst = x 198 | src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C)) 199 | 200 | out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype) 201 | 202 | out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src) 203 | out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst) 204 | 205 | return out 206 | 207 | return merge, unmerge 208 | 209 | 210 | def merge_wavg( 211 | merge: Callable, x: torch.Tensor, size: torch.Tensor = None 212 | ) -> Tuple[torch.Tensor, torch.Tensor]: 213 | """ 214 | Applies the merge function by taking a weighted average based on token size. 215 | Returns the merged tensor and the new token sizes.
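Merging x * size and size with mode="sum" and then dividing makes the result a size-weighted mean, so a token that already represents several merged patches keeps its average semantics across repeated merges.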
216 | """ 217 | if size is None: 218 | size = torch.ones_like(x[..., 0, None]) 219 | 220 | x = merge(x * size, mode="sum") 221 | size = merge(size, mode="sum") 222 | 223 | x = x / size 224 | return x, size 225 | 226 | 227 | def merge_source( 228 | merge: Callable, x: torch.Tensor, source: torch.Tensor = None 229 | ) -> torch.Tensor: 230 | """ 231 | For source tracking. Source is an adjacency matrix between the initial tokens and final merged groups. 232 | x is used to find out how many tokens there are in case the source is None. 233 | """ 234 | if source is None: 235 | n, t, _ = x.shape 236 | source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t) 237 | 238 | source = merge(source, mode="amax") 239 | return source -------------------------------------------------------------------------------- /fambav/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 
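Only `count` and `total` are all-reduced, so `global_avg` becomes a cross-process statistic while `median` and `avg` still reflect the local window.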
39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = f"{total_time:.6f}" # keep six decimal places, i.e. microsecond precision 158 | print(f'{header} Total time: {total_time_str} s ({total_time / len(iterable):.4f} s / it)') 159 | # total_time_str =
str(datetime.timedelta(seconds=int(total_time))) 160 | # print('{} Total time: {} ({:.4f} s / it)'.format( 161 | # header, total_time_str, total_time / len(iterable))) 162 | 163 | 164 | 165 | def _load_checkpoint_for_ema(model_ema, checkpoint): 166 | """ 167 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 168 | """ 169 | mem_file = io.BytesIO() 170 | torch.save({'state_dict_ema':checkpoint}, mem_file) 171 | mem_file.seek(0) 172 | model_ema._load_checkpoint(mem_file) 173 | 174 | 175 | def setup_for_distributed(is_master): 176 | """ 177 | This function disables printing when not in master process 178 | """ 179 | import builtins as __builtin__ 180 | builtin_print = __builtin__.print 181 | 182 | def print(*args, **kwargs): 183 | force = kwargs.pop('force', False) 184 | if is_master or force: 185 | builtin_print(*args, **kwargs) 186 | 187 | __builtin__.print = print 188 | 189 | 190 | def is_dist_avail_and_initialized(): 191 | if not dist.is_available(): 192 | return False 193 | if not dist.is_initialized(): 194 | return False 195 | return True 196 | 197 | 198 | def get_world_size(): 199 | if not is_dist_avail_and_initialized(): 200 | return 1 201 | return dist.get_world_size() 202 | 203 | 204 | def get_rank(): 205 | if not is_dist_avail_and_initialized(): 206 | return 0 207 | return dist.get_rank() 208 | 209 | 210 | def is_main_process(): 211 | return get_rank() == 0 212 | 213 | 214 | def save_on_master(*args, **kwargs): 215 | if is_main_process(): 216 | torch.save(*args, **kwargs) 217 | 218 | 219 | def init_distributed_mode(args): 220 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 221 | args.rank = int(os.environ["RANK"]) 222 | args.world_size = int(os.environ['WORLD_SIZE']) 223 | args.gpu = int(os.environ['LOCAL_RANK']) 224 | elif 'SLURM_PROCID' in os.environ: 225 | args.rank = int(os.environ['SLURM_PROCID']) 226 | args.gpu = args.rank % torch.cuda.device_count() 227 | else: 228 | print('Not using distributed mode') 229 | args.distributed = False 230 | return 231 | 232 | args.distributed = True 233 | 234 | torch.cuda.set_device(args.gpu) 235 | args.dist_backend = 'nccl' 236 | print('| distributed init (rank {}): {}'.format( 237 | args.rank, args.dist_url), flush=True) 238 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 239 | world_size=args.world_size, rank=args.rank) 240 | torch.distributed.barrier() 241 | setup_for_distributed(args.rank == 0) 242 | 243 | 244 | # if 'pos_embed' in state_dict: 245 | def interpolate_pos_embed(model, state_dict): 246 | pos_embed_checkpoint = state_dict['pos_embed'] 247 | embedding_size = pos_embed_checkpoint.shape[-1] 248 | num_patches = model.patch_embed.num_patches 249 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 250 | # height (== width) for the checkpoint position embedding 251 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 252 | # height (== width) for the new position embedding 253 | new_size = int(num_patches ** 0.5) 254 | # class_token and dist_token are kept unchanged 255 | if orig_size != new_size: 256 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 257 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 258 | # only the position tokens are interpolated 259 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 260 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 261 | pos_tokens = torch.nn.functional.interpolate( 
262 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 263 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 264 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 265 | state_dict['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /fambav/vim_requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | aiohttp==3.9.1 3 | aiosignal==1.3.1 4 | alembic==1.13.0 5 | async-timeout==4.0.3 6 | attrs==23.1.0 7 | blinker==1.7.0 8 | # causal-conv1d @ file:///home/zhulianghui/VisionProjects/mamba/lib/causal_conv1d-1.0.0%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=79a4bab633ebff031e615d5e8ba396b0dc0c046f4406980ee238fb86a9090038 9 | certifi==2023.11.17 10 | charset-normalizer==3.3.2 11 | click==8.1.7 12 | cloudpickle==3.0.0 13 | contourpy==1.2.0 14 | cycler==0.12.1 15 | databricks-cli==0.18.0 16 | datasets==2.15.0 17 | dill==0.3.7 18 | docker==6.1.3 19 | einops==0.7.0 20 | entrypoints==0.4 21 | filelock==3.13.1 22 | Flask==3.0.0 23 | fonttools==4.46.0 24 | frozenlist==1.4.0 25 | fsspec==2023.10.0 26 | gitdb==4.0.11 27 | GitPython==3.1.40 28 | greenlet==3.0.2 29 | gunicorn==21.2.0 30 | huggingface-hub==0.19.4 31 | idna==3.6 32 | importlib-metadata==7.0.0 33 | itsdangerous==2.1.2 34 | Jinja2==3.1.2 35 | joblib==1.3.2 36 | kiwisolver==1.4.5 37 | Mako==1.3.0 38 | # mamba-ssm @ file:///home/zhulianghui/VisionProjects/mamba/lib/mamba_ssm-1.0.1%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=71ad1b1eafb05a6e8a41fd82e046fe85511d6378fa3a583e55215b6aa1d65ab9 39 | Markdown==3.5.1 40 | MarkupSafe==2.1.3 41 | matplotlib==3.8.2 42 | mlflow==2.9.1 43 | mmcv==1.3.8 44 | mmsegmentation==0.14.1 45 | mpmath==1.3.0 46 | multidict==6.0.4 47 | multiprocess==0.70.15 48 | networkx==3.2.1 49 | ninja==1.11.1.1 50 | numpy==1.26.2 51 | # nvidia-cublas-cu12==12.1.3.1 52 | # nvidia-cuda-cupti-cu12==12.1.105 53 | # nvidia-cuda-nvrtc-cu12==12.1.105 54 | # nvidia-cuda-runtime-cu12==12.1.105 55 | # nvidia-cudnn-cu12==8.9.2.26 56 | # nvidia-cufft-cu12==11.0.2.54 57 | # nvidia-curand-cu12==10.3.2.106 58 | # nvidia-cusolver-cu12==11.4.5.107 59 | # nvidia-cusparse-cu12==12.1.0.106 60 | # nvidia-nccl-cu12==2.18.1 61 | # nvidia-nvjitlink-cu12==12.3.101 62 | # nvidia-nvtx-cu12==12.1.105 63 | oauthlib==3.2.2 64 | opencv-python==4.8.1.78 65 | packaging==23.2 66 | pandas==2.1.3 67 | Pillow==10.1.0 68 | platformdirs==4.1.0 69 | prettytable==3.9.0 70 | protobuf==4.25.1 71 | pyarrow==14.0.1 72 | pyarrow-hotfix==0.6 73 | PyJWT==2.8.0 74 | pyparsing==3.1.1 75 | python-dateutil==2.8.2 76 | python-hostlist==1.23.0 77 | pytz==2023.3.post1 78 | PyYAML==6.0.1 79 | querystring-parser==1.2.4 80 | regex==2023.10.3 81 | requests==2.31.0 82 | safetensors==0.4.1 83 | scikit-learn==1.3.2 84 | scipy==1.11.4 85 | six==1.16.0 86 | smmap==5.0.1 87 | SQLAlchemy==2.0.23 88 | sqlparse==0.4.4 89 | sympy==1.12 90 | tabulate==0.9.0 91 | threadpoolctl==3.2.0 92 | timm==0.4.12 93 | tokenizers==0.15.0 94 | tomli==2.0.1 95 | # torch==2.1.1+cu118 96 | # torchvision==0.16.1+cu118 97 | tqdm==4.66.1 98 | transformers==4.35.2 99 | triton==2.1.0 100 | typing_extensions==4.8.0 101 | tzdata==2023.3 102 | urllib3==2.1.0 103 | wcwidth==0.2.12 104 | websocket-client==1.7.0 105 | Werkzeug==3.0.1 106 | xxhash==3.4.1 107 | yapf==0.40.2 108 | yarl==1.9.4 109 | zipp==3.17.0 110 | -------------------------------------------------------------------------------- 
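
A short, self-contained usage sketch for the ToMe helpers defined in fambav/token_merge.py above. The shapes and r=8 are arbitrary demo values; the import assumes the snippet runs from inside fambav/, and class_token=True protects index 0 (note that ViM actually keeps its class token mid-sequence).

```
import torch
from token_merge import bipartite_soft_matching, merge_wavg

x = torch.randn(2, 197, 192)           # [batch, tokens, channels] hidden states
merge, unmerge = bipartite_soft_matching(x, r=8, class_token=True)
x_merged, size = merge_wavg(merge, x)  # size = tokens represented by each output
print(x_merged.shape)                  # torch.Size([2, 189, 192]): 8 tokens merged away
print(unmerge(x_merged).shape)         # torch.Size([2, 197, 192]): scattered back
```
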
/fambav/visualization/cosine_similarity_layer_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_0.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_1.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_10.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_11.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_12.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_13.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_14.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_15.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_16.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_17.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_17.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_18.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_19.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_2.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_20.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_21.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_22.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_23.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_3.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_4.png -------------------------------------------------------------------------------- 
/fambav/visualization/cosine_similarity_layer_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_5.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_6.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_7.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_8.png -------------------------------------------------------------------------------- /fambav/visualization/cosine_similarity_layer_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_9.png -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113249.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113249.npz -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113250.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113250.npz -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113251.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113251.npz -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113252.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113252.npz -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113253.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113253.npz -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113254.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113254.npz -------------------------------------------------------------------------------- /fambav/visualization/hidden_states_20240714_113255.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113255.npz -------------------------------------------------------------------------------- /fambav/visualization/npz_viewer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | data = np.load('./hidden_states_20240714_053015.npz') 4 | for key in data.files: 5 | print(data[key]) -------------------------------------------------------------------------------- /fambav/visualization/test_visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | from sklearn.metrics.pairwise import cosine_similarity 8 | 9 | def visualize_cosine_similarity_from_npz(file_path): 10 | print("Start to visualize cosine similarity from .npz file") 11 | 12 | # Load the .npz file 13 | loaded_data = np.load(file_path) 14 | 15 | # Iterate over all layers 16 | for key in loaded_data.files: 17 | if key.startswith('layer_'): 18 | hidden_states = loaded_data[key] 19 | 20 | if isinstance(hidden_states, np.ndarray): 21 | hidden_states = torch.from_numpy(hidden_states) 22 | 23 | # If the data is on the GPU, make sure to move it to the CPU 24 | if hidden_states.device.type != 'cpu': 25 | hidden_states = hidden_states.cpu() 26 | 27 | # Convert to a NumPy array (if not already converted) 28 | hidden_states = hidden_states.numpy() 29 | 30 | # Compute the cosine similarity matrix 31 | cosine_sim_matrix = cosine_similarity(hidden_states) 32 | 33 | print(f"cosine_sim_matrix for {key} success") 34 | 35 | plt.figure(figsize=(12, 10)) # Enlarge the figure size 36 | sns.heatmap(cosine_sim_matrix, cmap='Blues') 37 | 38 | # Set the title and labels with larger font sizes 39 | plt.suptitle(f"Cosine Similarity Matrix of Hidden States for {key}", fontsize=26, y=0.98) 40 | plt.xlabel("Hidden State Index", fontsize=26) 41 | plt.ylabel("Hidden State Index", fontsize=26) 42 | 43 | # Compute new tick positions and labels 44 | n = cosine_sim_matrix.shape[0] # Assumes the matrix is square 45 | ticks = np.arange(0, n, 8) 46 | tick_labels = ticks 47 | 48 | # Apply the new ticks 49 | plt.xticks(ticks, tick_labels, fontsize=16) 50 | plt.yticks(ticks, tick_labels, fontsize=16) 51 | 52 | # Adjust the colorbar font size 53 | cbar = plt.gcf().axes[-1] 54 | cbar.tick_params(labelsize=12) 55 | 56 | plt.tight_layout() # Automatically adjust subplot params to fill the figure area 57 | plt.savefig(f"./cosine_similarity_{key}.png", dpi=600) # Higher DPI for better image quality 58 | print(f"Heatmap for {key} saved.") 59 | 60 | # Example .npz file path 61 | file_path = './hidden_states_20240714_113251.npz' 62 | 63 | # Call the function 64 | visualize_cosine_similarity_from_npz(file_path) -------------------------------------------------------------------------------- /mamba-1p1p1/.github/workflows/publish.yaml: 
-------------------------------------------------------------------------------- 1 | # This workflow will: 2 | # - Create a new Github release 3 | # - Build wheels for supported architectures 4 | # - Deploy the wheels to the Github release 5 | # - Release the static code to PyPi 6 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 7 | 8 | name: Build wheels and deploy 9 | 10 | on: 11 | create: 12 | tags: 13 | - v* 14 | 15 | jobs: 16 | 17 | setup_release: 18 | name: Create Release 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Get the tag version 22 | id: extract_branch 23 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} 24 | shell: bash 25 | 26 | - name: Create Release 27 | id: create_release 28 | uses: actions/create-release@v1 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 31 | with: 32 | tag_name: ${{ steps.extract_branch.outputs.branch }} 33 | release_name: ${{ steps.extract_branch.outputs.branch }} 34 | 35 | build_wheels: 36 | name: Build Wheel 37 | needs: setup_release 38 | runs-on: ${{ matrix.os }} 39 | 40 | strategy: 41 | fail-fast: false 42 | matrix: 43 | # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the 44 | # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. 45 | os: [ubuntu-20.04] 46 | python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] 47 | torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231106'] 48 | cuda-version: ['11.8.0', '12.2.0'] 49 | # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. 50 | # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 51 | # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) 52 | # when building without C++11 ABI and using it on nvcr images. 53 | cxx11_abi: ['FALSE', 'TRUE'] 54 | exclude: 55 | # Pytorch <= 1.12 does not support Python 3.11 56 | - torch-version: '1.12.1' 57 | python-version: '3.11' 58 | # Pytorch >= 2.0 only supports Python >= 3.8 59 | - torch-version: '2.0.1' 60 | python-version: '3.7' 61 | - torch-version: '2.1.1' 62 | python-version: '3.7' 63 | - torch-version: '2.2.0.dev20231106' 64 | python-version: '3.7' 65 | # Pytorch <= 2.0 only supports CUDA <= 11.8 66 | - torch-version: '1.12.1' 67 | cuda-version: '12.2.0' 68 | - torch-version: '1.13.1' 69 | cuda-version: '12.2.0' 70 | - torch-version: '2.0.1' 71 | cuda-version: '12.2.0' 72 | 73 | steps: 74 | - name: Checkout 75 | uses: actions/checkout@v3 76 | 77 | - name: Set up Python 78 | uses: actions/setup-python@v4 79 | with: 80 | python-version: ${{ matrix.python-version }} 81 | 82 | - name: Set CUDA and PyTorch versions 83 | run: | 84 | echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV 85 | echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." 
$2'})" >> $GITHUB_ENV 86 | 87 | - name: Free up disk space 88 | if: ${{ runner.os == 'Linux' }} 89 | # https://github.com/easimon/maximize-build-space/blob/master/action.yml 90 | # https://github.com/easimon/maximize-build-space/tree/test-report 91 | run: | 92 | sudo rm -rf /usr/share/dotnet 93 | sudo rm -rf /opt/ghc 94 | sudo rm -rf /opt/hostedtoolcache/CodeQL 95 | 96 | - name: Set up swap space 97 | if: runner.os == 'Linux' 98 | uses: pierotofy/set-swap-space@v1.0 99 | with: 100 | swap-size-gb: 10 101 | 102 | - name: Install CUDA ${{ matrix.cuda-version }} 103 | if: ${{ matrix.cuda-version != 'cpu' }} 104 | uses: Jimver/cuda-toolkit@v0.2.11 105 | id: cuda-toolkit 106 | with: 107 | cuda: ${{ matrix.cuda-version }} 108 | linux-local-args: '["--toolkit"]' 109 | # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 110 | # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} 111 | method: 'network' 112 | # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions, 113 | # not just nvcc 114 | # sub-packages: '["nvcc"]' 115 | 116 | - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} 117 | run: | 118 | pip install --upgrade pip 119 | # If we don't install before installing Pytorch, we get error for torch 2.0.1 120 | # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) 121 | pip install lit 122 | # We want to figure out the CUDA version to download pytorch 123 | # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 124 | # This code is ugly, maybe there's a better way to do this. 125 | export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") 126 | if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then 127 | pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} 128 | else 129 | pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} 130 | fi 131 | nvcc --version 132 | python --version 133 | python -c "import torch; print('PyTorch:', torch.__version__)" 134 | python -c "import torch; print('CUDA:', torch.version.cuda)" 135 | python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" 136 | shell: 137 | bash 138 | 139 | - name: Build wheel 140 | run: | 141 | # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 142 | # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 143 | # However this still fails so I'm using a newer version of setuptools 144 | pip install setuptools==68.0.0 145 | pip install ninja packaging wheel 146 | export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH 147 | export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH 148 | # Limit MAX_JOBS otherwise the github runner goes OOM 149 | MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py 
bdist_wheel --dist-dir=dist 150 | tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} 151 | wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") 152 | ls dist/*whl | xargs -I {} mv {} dist/${wheel_name} 153 | echo "wheel_name=${wheel_name}" >> $GITHUB_ENV 154 | 155 | - name: Log Built Wheels 156 | run: | 157 | ls dist 158 | 159 | - name: Get the tag version 160 | id: extract_branch 161 | run: echo "branch=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT 162 | 163 | - name: Get Release with tag 164 | id: get_current_release 165 | uses: joutvhu/get-release@v1 166 | with: 167 | tag_name: ${{ steps.extract_branch.outputs.branch }} 168 | env: 169 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 170 | 171 | - name: Upload Release Asset 172 | id: upload_release_asset 173 | uses: actions/upload-release-asset@v1 174 | env: 175 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 176 | with: 177 | upload_url: ${{ steps.get_current_release.outputs.upload_url }} 178 | asset_path: ./dist/${{env.wheel_name}} 179 | asset_name: ${{env.wheel_name}} 180 | asset_content_type: application/* 181 | 182 | publish_package: 183 | name: Publish package 184 | needs: [build_wheels] 185 | 186 | runs-on: ubuntu-latest 187 | 188 | steps: 189 | - uses: actions/checkout@v3 190 | 191 | - uses: actions/setup-python@v4 192 | with: 193 | python-version: '3.10' 194 | 195 | - name: Install dependencies 196 | run: | 197 | pip install ninja packaging setuptools wheel twine 198 | # We don't want to download anything CUDA-related here 199 | pip install torch --index-url https://download.pytorch.org/whl/cpu 200 | 201 | - name: Build core package 202 | env: 203 | MAMBA_SKIP_CUDA_BUILD: "TRUE" 204 | run: | 205 | python setup.py sdist --dist-dir=dist 206 | 207 | - name: Deploy 208 | env: 209 | TWINE_USERNAME: "__token__" 210 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 211 | run: | 212 | python -m twine upload dist/* 213 | -------------------------------------------------------------------------------- /mamba-1p1p1/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | *.egg-info/ 3 | build/ 4 | **.so 5 | -------------------------------------------------------------------------------- /mamba-1p1p1/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/lm-evaluation-harness"] 2 | path = 3rdparty/lm-evaluation-harness 3 | url = https://github.com/EleutherAI/lm-evaluation-harness/ 4 | -------------------------------------------------------------------------------- /mamba-1p1p1/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /mamba-1p1p1/README.md: -------------------------------------------------------------------------------- 1 | # Mamba 2 | 3 | ![Mamba](assets/selection.png "Selective State Space") 4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ 5 | > Albert Gu*, Tri Dao*\ 6 | > Paper: https://arxiv.org/abs/2312.00752 7 | 8 | ## About 9 | 10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), 12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). 13 | 14 | ## Installation 15 | 16 | - `pip install causal-conv1d>=1.1.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. 17 | - `pip install mamba-ssm`: the core Mamba package. 18 | 19 | It can also be built from source with `pip install .` from this repository. 20 | 21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`. 22 | 23 | Other requirements: 24 | - Linux 25 | - NVIDIA GPU 26 | - PyTorch 1.12+ 27 | - CUDA 11.6+ 28 | 29 | ## Usage 30 | 31 | We expose several levels of interface with the Mamba model. 32 | 33 | ### Selective SSM 34 | 35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). 36 | 37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). 38 | 39 | ### Mamba Block 40 | 41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM. 42 | 43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). 44 | 45 | Usage: 46 | ``` 47 | import torch 48 | from mamba_ssm import Mamba 49 | 50 | batch, length, dim = 2, 64, 16 51 | x = torch.randn(batch, length, dim).to("cuda") 52 | model = Mamba( 53 | # This module uses roughly 3 * expand * d_model^2 parameters 54 | d_model=dim, # Model dimension d_model 55 | d_state=16, # SSM state expansion factor 56 | d_conv=4, # Local convolution width 57 | expand=2, # Block expansion factor 58 | ).to("cuda") 59 | y = model(x) 60 | assert y.shape == x.shape 61 | ``` 62 | 63 | ### Mamba Language Model 64 | 65 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. 66 | 67 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). 68 | 69 | This is an example of how to integrate Mamba into an end-to-end neural network. 70 | This example is used in the generation scripts below. 71 | 72 | 73 | 74 | ## Pretrained Models 75 | 76 | Pretrained models are uploaded to 77 | [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, 78 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj` 79 | (trained on 600B tokens on the SlimPajama dataset). 80 | 81 | 82 | The models will be autodownloaded by the generation script below. 83 | 84 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: 85 | 86 | | Parameters | Layers | Model dim. | 87 | |------------|--------|------------| 88 | | 130M | 24 | 768 | 89 | | 370M | 48 | 1024 | 90 | | 790M | 48 | 1536 | 91 | | 1.4B | 48 | 2048 | 92 | | 2.8B | 64 | 2560 | 93 | 94 | (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) 95 | 96 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). 97 | Performance is expected to be comparable or better than other architectures trained on similar data, but not to match larger or fine-tuned models. 
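For reference, a checkpoint can be loaded directly with `MambaLMHeadModel.from_pretrained`, the same API used by the generation script below. A minimal sketch, assuming a CUDA machine with the packages above installed:

```
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Downloads the config and weights from the Hugging Face Hub on first use.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
```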
98 | 99 | 100 | ## Evaluations 101 | 102 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper), 103 | we use the 104 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) 105 | library. 106 | 107 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init 108 | --recursive`. We use the `big-refactor` branch. 109 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`. 110 | On Python 3.10 you might need to manually install the latest version of `promptsource`: `pip install git+https://github.com/bigscience-workshop/promptsource.git`. 111 | 3. Run evaluations with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): 112 | ``` 113 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 114 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 115 | ``` 116 | 117 | To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blog posts: 118 | ``` 119 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64 120 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64 121 | ``` 122 | 123 | Note that the result of each task might differ from reported values by 0.1-0.3 points due to noise in the evaluation process. 124 | 125 | ## Inference 126 | 127 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 128 | 1. autoloads a model from the Hugging Face Hub, 129 | 2. generates completions of a user-specified prompt, 130 | 3. benchmarks the inference speed of this generation. 131 | 132 | Other configurable options include the top-p (nucleus sampling) probability and the softmax temperature. 133 | 134 | ### Examples 135 | 136 | To test generation latency (e.g. batch size = 1) with different sampling strategies: 137 | 138 | ``` 139 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 140 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 141 | ``` 142 | 143 | To test generation throughput with random prompts (e.g. large batch size): 144 | ``` 145 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128 146 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128 147 | ``` 148 | 149 | 150 | ## Troubleshooting 151 | 152 | ### Precision 153 | Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
154 | On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation). 155 | 156 | We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities, 157 | as a first step please try a framework storing parameters in fp32 (such as AMP). 158 | 159 | ### Initialization 160 | Some parts of the model have initializations inherited from prior work on S4 models. 161 | For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter is given a targeted range by initializing the bias of its linear projection. 162 | However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero). 163 | If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework) 164 | that is specific to the training framework. 165 | 166 | 167 | ## Citation 168 | 169 | If you use this codebase, or otherwise find our work valuable, please cite Mamba: 170 | ``` 171 | @article{mamba, 172 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, 173 | author={Gu, Albert and Dao, Tri}, 174 | journal={arXiv preprint arXiv:2312.00752}, 175 | year={2023} 176 | } 177 | ``` 178 | -------------------------------------------------------------------------------- /mamba-1p1p1/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/assets/selection.png -------------------------------------------------------------------------------- /mamba-1p1p1/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
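# This script loads either a Mamba checkpoint (via MambaLMHeadModel) or a
# Hugging Face causal LM, generates completions from a user prompt or from
# random token ids, and reports the average generation latency over `repeats` timed runs.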
2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 26 | parser.add_argument("--batch", type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | repeats = 3 30 | device = "cuda" 31 | dtype = torch.float16 32 | 33 | print(f"Loading model {args.model_name}") 34 | is_mamba = args.model_name.startswith("state-spaces/mamba-") 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | repetition_penalty=args.repetition_penalty, 66 | ) 67 | else: 68 | fn = lambda: model.generate( 69 | input_ids=input_ids, 70 | attention_mask=attn_mask, 71 | max_length=max_length, 72 | return_dict_in_generate=True, 73 | pad_token_id=tokenizer.eos_token_id, 74 | do_sample=True, 75 | temperature=args.temperature, 76 | top_k=args.topk, 77 | top_p=args.topp, 78 | repetition_penalty=args.repetition_penalty, 79 | ) 80 | out = fn() 81 | if args.prompt is not None: 82 | print(tokenizer.batch_decode(out.sequences.tolist())) 83 | 84 | torch.cuda.synchronize() 85 | start = time.time() 86 | for _ in range(repeats): 87 | fn() 88 | torch.cuda.synchronize() 89 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 90 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 91 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | 
/****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include <cuda_bf16.h> 8 | #include <cuda_fp16.h> 9 | #include <c10/util/complex.h> // For scalar_value_type 10 | 11 | #define MAX_DSTATE 256 12 | 13 | using complex_t = c10::complex<float>; 14 | 15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){ 16 | return {a.x + b.x, a.y + b.y}; 17 | } 18 | 19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) { 20 | return {a.x + b.x, a.y + b.y, a.z + b.z}; 21 | } 22 | 23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){ 24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; 25 | } 26 | 27 | //////////////////////////////////////////////////////////////////////////////////////////////////// 28 | 29 | template<int BYTES> struct BytesToType {}; 30 | 31 | template<> struct BytesToType<16> { 32 | using Type = uint4; 33 | static_assert(sizeof(Type) == 16); 34 | }; 35 | 36 | template<> struct BytesToType<8> { 37 | using Type = uint64_t; 38 | static_assert(sizeof(Type) == 8); 39 | }; 40 | 41 | template<> struct BytesToType<4> { 42 | using Type = uint32_t; 43 | static_assert(sizeof(Type) == 4); 44 | }; 45 | 46 | template<> struct BytesToType<2> { 47 | using Type = uint16_t; 48 | static_assert(sizeof(Type) == 2); 49 | }; 50 | 51 | template<> struct BytesToType<1> { 52 | using Type = uint8_t; 53 | static_assert(sizeof(Type) == 1); 54 | }; 55 | 56 | //////////////////////////////////////////////////////////////////////////////////////////////////// 57 | 58 | template<typename scalar_t, int N> 59 | struct Converter{ 60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { 61 | #pragma unroll 62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; } 63 | } 64 | }; 65 | 66 | template<int N> 67 | struct Converter<at::Half, N>{ 68 | static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { 69 | static_assert(N % 2 == 0); 70 | auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src); 71 | auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst); 72 | #pragma unroll 73 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } 74 | } 75 | }; 76 | 77 | #if __CUDA_ARCH__ >= 800 78 | template<int N> 79 | struct Converter<at::BFloat16, N>{ 80 | static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { 81 | static_assert(N % 2 == 0); 82 | auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src); 83 | auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst); 84 | #pragma unroll 85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } 86 | } 87 | }; 88 | #endif 89 | 90 | //////////////////////////////////////////////////////////////////////////////////////////////////// 91 | 92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp 93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) { 95 | float t = exp2f(z.real_); 96 | float c, s; 97 | sincosf(z.imag_, &s, &c); 98 |
return complex_t(c * t, s * t); 99 | } 100 | 101 | __device__ __forceinline__ complex_t cexpf(complex_t z) { 102 | float t = expf(z.real_); 103 | float c, s; 104 | sincosf(z.imag_, &s, &c); 105 | return complex_t(c * t, s * t); 106 | } 107 | 108 | template<typename scalar_t> struct SSMScanOp; 109 | 110 | template<> 111 | struct SSMScanOp<float> { 112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { 113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); 114 | } 115 | }; 116 | 117 | template<> 118 | struct SSMScanOp<complex_t> { 119 | __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { 120 | complex_t a0 = complex_t(ab0.x, ab0.y); 121 | complex_t b0 = complex_t(ab0.z, ab0.w); 122 | complex_t a1 = complex_t(ab1.x, ab1.y); 123 | complex_t b1 = complex_t(ab1.z, ab1.w); 124 | complex_t out_a = a1 * a0; 125 | complex_t out_b = a1 * b0 + b1; 126 | return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); 127 | } 128 | }; 129 | 130 | // A stateful callback functor that maintains a running prefix to be applied 131 | // during consecutive scan operations. 132 | template<typename scalar_t> struct SSMScanPrefixCallbackOp { 133 | using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>; 134 | scan_t running_prefix; 135 | // Constructor 136 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} 137 | // Callback operator to be entered by the first warp of threads in the block. 138 | // Thread-0 is responsible for returning a value for seeding the block-wide scan. 139 | __device__ scan_t operator()(scan_t block_aggregate) { 140 | scan_t old_prefix = running_prefix; 141 | running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate); 142 | return old_prefix; 143 | } 144 | }; 145 | 146 | //////////////////////////////////////////////////////////////////////////////////////////////////// 147 | 148 | template<typename Ktraits> 149 | inline __device__ void load_input(typename Ktraits::input_t *u, 150 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], 151 | typename Ktraits::BlockLoadT::TempStorage &smem_load, 152 | int seqlen) { 153 | if constexpr (Ktraits::kIsEvenLen) { 154 | auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load); 155 | using vec_t = typename Ktraits::vec_t; 156 | Ktraits::BlockLoadVecT(smem_load_vec).Load( 157 | reinterpret_cast<vec_t*>(u), 158 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals) 159 | ); 160 | } else { 161 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); 162 | } 163 | } 164 | 165 | template<typename Ktraits> 166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar, 167 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], 168 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, 169 | int seqlen) { 170 | constexpr int kNItems = Ktraits::kNItems; 171 | if constexpr (!Ktraits::kIsComplex) { 172 | typename Ktraits::input_t B_vals_load[kNItems]; 173 | if constexpr (Ktraits::kIsEvenLen) { 174 | auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight); 175 | using vec_t = typename Ktraits::vec_t; 176 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 177 | reinterpret_cast<vec_t*>(Bvar), 178 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load) 179 | ); 180 | } else { 181 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 182 | } 183 | // #pragma unroll 184 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } 185 | Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals); 186 | } else { 187 | typename Ktraits::input_t B_vals_load[kNItems * 2]; 188 | if constexpr (Ktraits::kIsEvenLen) { 189 |
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight); 190 | using vec_t = typename Ktraits::vec_t; 191 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 192 | reinterpret_cast<vec_t*>(Bvar), 193 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load) 194 | ); 195 | } else { 196 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 197 | } 198 | #pragma unroll 199 | for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } 200 | } 201 | } 202 | 203 | template<typename Ktraits> 204 | inline __device__ void store_output(typename Ktraits::input_t *out, 205 | const float (&out_vals)[Ktraits::kNItems], 206 | typename Ktraits::BlockStoreT::TempStorage &smem_store, 207 | int seqlen) { 208 | typename Ktraits::input_t write_vals[Ktraits::kNItems]; 209 | #pragma unroll 210 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } 211 | if constexpr (Ktraits::kIsEvenLen) { 212 | auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store); 213 | using vec_t = typename Ktraits::vec_t; 214 | Ktraits::BlockStoreVecT(smem_store_vec).Store( 215 | reinterpret_cast<vec_t*>(out), 216 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals) 217 | ); 218 | } else { 219 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_bf16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in parallel 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include <cub/config.cuh> 31 | 32 | #include <cuda/std/type_traits> 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template <typename T, typename U> 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward<U>(val)); 44 | } 45 | #else 46 | template <typename T, 47 | typename U, 48 | typename ::cuda::std::enable_if< 49 | ::cuda::std::is_trivially_copyable<T>::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward<U>(val); 55 | } 56 | 57 | template <typename T, 58 | typename U, 59 | typename ::cuda::std::enable_if< 60 | !::cuda::std::is_trivially_copyable<T>::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward<U>(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /mamba-1p1p1/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = int(batch_size) if batch_size is not None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/models/__init__.py -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/models/config_mamba.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MambaConfig: 6 | 7 | d_model: int = 2560 8 | n_layer: int = 64 9 | vocab_size: int = 50277 10 | ssm_cfg: dict =
field(default_factory=dict) 11 | rms_norm: bool = True 12 | residual_in_fp32: bool = True 13 | fused_add_norm: bool = True 14 | pad_vocab_size_multiple: int = 8 15 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/modules/__init__.py -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/ops/__init__.py -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/ops/triton/selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | 
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.log(1.0 + tl.exp(dt)) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 
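    # Heuristic: each program instance covers a (BLOCK_SIZE_M, dstate) tile of the state,
    # so BLOCK_SIZE_M shrinks as dstate grows to keep per-program register/shared-memory use bounded.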
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias if dt_bias is not None else dt 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate) 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/utils/__init__.py -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 |
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /mamba-1p1p1/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_selective_state_update(dim, dstate, has_z, itype): 23 | device = "cuda" 24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 | if itype == torch.bfloat16: 26 | rtol, atol = 1e-2, 5e-2 27 | # set seed 28 | torch.random.manual_seed(0) 29 | batch_size = 2 30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 | dt_bias = torch.rand(dim, device=device) - 4.0 34 | A = -torch.rand(dim, dstate, device=device) - 1.0 35 | B = torch.randn(batch_size, dstate, device=device) 36 | C = torch.randn(batch_size, dstate, device=device) 37 | D = torch.randn(dim, device=device) 38 | if has_z: 39 | z = torch.randn_like(x) 40 | else: 41 | z = None 42 | state_ref = state.detach().clone() 43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | --------------------------------------------------------------------------------
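For orientation, a minimal usage sketch of `selective_state_update` (an illustrative example, not a file in the repository; shapes follow the docstring above). It performs one recurrent decoding step and updates `state` in place:

```
import torch
from mamba_ssm.ops.triton.selective_state_update import selective_state_update

batch, dim, dstate = 2, 2048, 16
device, dtype = "cuda", torch.float16

state = torch.zeros(batch, dim, dstate, device=device, dtype=dtype)  # recurrent SSM state, updated in place
x = torch.randn(batch, dim, device=device, dtype=dtype)              # current-step input
dt = torch.randn(batch, dim, device=device, dtype=dtype)             # per-channel step size (pre-softplus)
A = -torch.rand(dim, dstate, device=device)                          # negative entries give a decaying state
B = torch.randn(batch, dstate, device=device)
C = torch.randn(batch, dstate, device=device)

out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)  # -> (batch, dim)
```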