├── .gitignore
├── =1.1.0
├── README.md
├── assets
│   ├── FambaOverview.png
│   ├── Strategies.png
│   └── logo.jpg
├── causal-conv1d
│   ├── =1.1.0
│   ├── AUTHORS
│   ├── LICENSE
│   ├── README.md
│   ├── [13
│   ├── causal_conv1d
│   │   ├── =1.1.0
│   │   ├── __init__.py
│   │   └── causal_conv1d_interface.py
│   ├── csrc
│   │   ├── causal_conv1d.cpp
│   │   ├── causal_conv1d.h
│   │   ├── causal_conv1d_bwd.cu
│   │   ├── causal_conv1d_common.h
│   │   ├── causal_conv1d_fwd.cu
│   │   ├── causal_conv1d_update.cu
│   │   └── static_switch.h
│   ├── setup.py
│   └── tests
│       └── test_causal_conv1d.py
├── fambav
│   ├── .gitignore
│   ├── LICENSE
│   ├── augment.py
│   ├── datasets.py
│   ├── engine.py
│   ├── hubconf.py
│   ├── losses.py
│   ├── main.py
│   ├── mlruns
│   │   └── 0
│   │       └── meta.yaml
│   ├── models_mamba.py
│   ├── rope.py
│   ├── run_with_submitit.py
│   ├── samplers.py
│   ├── scripts
│   │   ├── mlruns
│   │   │   └── 0
│   │   │       └── meta.yaml
│   │   ├── vim-s-cifar-all-merge-r7.sh
│   │   ├── vim-s-cifar-alternate-merge-r14.sh
│   │   ├── vim-s-cifar-lower-merge-r9-l4.sh
│   │   ├── vim-s-cifar-lower-merge-r9-l5.sh
│   │   ├── vim-s-cifar-non-merge.sh
│   │   ├── vim-s-cifar-upper-merge-r9-l18.sh
│   │   ├── vim-t-cifar-all-merge-r7-cls.sh
│   │   ├── vim-t-cifar-all-merge-r7.sh
│   │   ├── vim-t-cifar-alternate-merge-r14.sh
│   │   ├── vim-t-cifar-lower-merge-r1-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r10-l6.sh
│   │   ├── vim-t-cifar-lower-merge-r10-l7.sh
│   │   ├── vim-t-cifar-lower-merge-r11-l8.sh
│   │   ├── vim-t-cifar-lower-merge-r12-l10.sh
│   │   ├── vim-t-cifar-lower-merge-r12-l9.sh
│   │   ├── vim-t-cifar-lower-merge-r13-l11.sh
│   │   ├── vim-t-cifar-lower-merge-r14-l12.sh
│   │   ├── vim-t-cifar-lower-merge-r15-l13.sh
│   │   ├── vim-t-cifar-lower-merge-r2-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r3-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r4-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r5-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r6-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r7-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r8-l3.sh
│   │   ├── vim-t-cifar-lower-merge-r8-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r9-l4.sh
│   │   ├── vim-t-cifar-lower-merge-r9-l5.sh
│   │   ├── vim-t-cifar-non-merge-visual.sh
│   │   ├── vim-t-cifar-non-merge.sh
│   │   ├── vim-t-cifar-upper-merge-r9-l18.sh
│   │   ├── vim-t-imagenet-all-merge-r7.sh
│   │   ├── vim-t-imagenet-alternate-merge-r14.sh
│   │   ├── vim-t-imagenet-lower-merge-r9-l5.sh
│   │   ├── vim-t-imagenet-non-merge.sh
│   │   └── vim-t-imagenet-upper-merge-r9-l18.sh
│   ├── test.sh
│   ├── token_merge.py
│   ├── utils.py
│   ├── vim_requirements.txt
│   └── visualization
│       ├── cosine_similarity_layer_0.png
│       ├── cosine_similarity_layer_1.png
│       ├── cosine_similarity_layer_10.png
│       ├── cosine_similarity_layer_11.png
│       ├── cosine_similarity_layer_12.png
│       ├── cosine_similarity_layer_13.png
│       ├── cosine_similarity_layer_14.png
│       ├── cosine_similarity_layer_15.png
│       ├── cosine_similarity_layer_16.png
│       ├── cosine_similarity_layer_17.png
│       ├── cosine_similarity_layer_18.png
│       ├── cosine_similarity_layer_19.png
│       ├── cosine_similarity_layer_2.png
│       ├── cosine_similarity_layer_20.png
│       ├── cosine_similarity_layer_21.png
│       ├── cosine_similarity_layer_22.png
│       ├── cosine_similarity_layer_23.png
│       ├── cosine_similarity_layer_3.png
│       ├── cosine_similarity_layer_4.png
│       ├── cosine_similarity_layer_5.png
│       ├── cosine_similarity_layer_6.png
│       ├── cosine_similarity_layer_7.png
│       ├── cosine_similarity_layer_8.png
│       ├── cosine_similarity_layer_9.png
│       ├── hidden_states_20240714_113249.npz
│       ├── hidden_states_20240714_113250.npz
│       ├── hidden_states_20240714_113251.npz
│       ├── hidden_states_20240714_113252.npz
│       ├── hidden_states_20240714_113253.npz
│       ├── hidden_states_20240714_113254.npz
│       ├── hidden_states_20240714_113255.npz
│       ├── npz_viewer.py
│       └── test_visualize.py
└── mamba-1p1p1
    ├── .github
    │   └── workflows
    │       └── publish.yaml
    ├── .gitignore
    ├── .gitmodules
    ├── AUTHORS
    ├── LICENSE
    ├── README.md
    ├── assets
    │   └── selection.png
    ├── benchmarks
    │   └── benchmark_generation_mamba_simple.py
    ├── csrc
    │   └── selective_scan
    │       ├── reverse_scan.cuh
    │       ├── selective_scan.cpp
    │       ├── selective_scan.h
    │       ├── selective_scan_bwd_bf16_complex.cu
    │       ├── selective_scan_bwd_bf16_real.cu
    │       ├── selective_scan_bwd_fp16_complex.cu
    │       ├── selective_scan_bwd_fp16_real.cu
    │       ├── selective_scan_bwd_fp32_complex.cu
    │       ├── selective_scan_bwd_fp32_real.cu
    │       ├── selective_scan_bwd_kernel.cuh
    │       ├── selective_scan_common.h
    │       ├── selective_scan_fwd_bf16.cu
    │       ├── selective_scan_fwd_fp16.cu
    │       ├── selective_scan_fwd_fp32.cu
    │       ├── selective_scan_fwd_kernel.cuh
    │       ├── static_switch.h
    │       └── uninitialized_copy.cuh
    ├── evals
    │   └── lm_harness_eval.py
    ├── mamba_ssm
    │   ├── __init__.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── config_mamba.py
    │   │   └── mixer_seq_simple.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   └── mamba_simple.py
    │   ├── ops
    │   │   ├── __init__.py
    │   │   ├── selective_scan_interface.py
    │   │   └── triton
    │   │       ├── __init__.py
    │   │       ├── layernorm.py
    │   │       └── selective_state_update.py
    │   └── utils
    │       ├── __init__.py
    │       ├── generation.py
    │       └── hf.py
    ├── setup.py
    └── tests
        └── ops
            ├── test_selective_scan.py
            └── triton
                └── test_selective_state_update.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | causal-conv1d/causal_conv1d/__pycache__/*
3 | mamba/mamba_ssm/utils/__pycache__/*
4 | mamba/mamba_ssm/ops/triton/__pycache__/*
5 | mamba/mamba_ssm/ops/__pycache__/*
6 | mamba/mamba_ssm/modules/__pycache__/*
7 | mamba/mamba_ssm/__pycache__/*
8 | mamba/mamba_ssm/models/__pycache__/*
9 | causal-conv1d/build/
10 | causal-conv1d/causal_conv1d.egg-info/
11 | causal-conv1d/causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so
12 | mamba/build/
13 | mamba/mamba_ssm.egg-info/
14 | mamba/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/=1.1.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/=1.1.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion
2 |
14 | ## Introduction
15 |
16 | > **[Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion](https://arxiv.org/abs/2409.09808)** [[arXiv]](https://arxiv.org/abs/2409.09808)
17 | > *Hui Shen, Zhongwei Wan, Xin Wang, Mi Zhang*
18 | > *The Ohio State University*
19 | > *ECCV 2024 Workshop on Computational Aspects of Deep Learning*
20 |
21 | ### ⚡News: Famba-V won the Best Paper Award of the ECCV 2024 Workshop on Computational Aspects of Deep Learning.
22 |
23 | ## Abstract
24 | Mamba and Vision Mamba (Vim) models have shown their potential as an alternative to methods based on the Transformer architecture. This work introduces Fast Mamba for Vision (Famba-V), a cross-layer token fusion technique to enhance the training efficiency of Vim models. The key idea of Famba-V is to identify and fuse similar tokens across different Vim layers based on a suite of cross-layer strategies, instead of simply applying token fusion uniformly across all layers as existing works propose. We evaluate the performance of Famba-V on CIFAR-100. Our results show that Famba-V is able to enhance the training efficiency of Vim models by reducing both training time and peak memory usage during training. Moreover, the proposed cross-layer strategies allow Famba-V to deliver superior accuracy-efficiency trade-offs. Together, these results demonstrate Famba-V as a promising efficiency enhancement technique for Vim models.
25 |
26 |
27 | ## Quick Start
28 |
29 | - Python 3.10.13
30 |
31 | - `conda create -n your_env_name python=3.10.13`
32 |
33 | - torch 2.1.1 + cu118
34 | - `pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118`
35 |
36 | - Requirements: vim_requirements.txt
37 | - `pip install -r fambav/vim_requirements.txt`
38 |
39 | - Install `causal-conv1d` and `mamba`
40 | - `pip install -e causal-conv1d` (note: the unquoted form `pip install -e causal_conv1d>=1.1.0` lets the shell treat `>` as output redirection, which is what created the stray `=1.1.0` files committed in this repo)
41 | - `pip install -e mamba-1p1p1`
42 |
43 |
44 |
45 | ## Train Your Famba-V with Upper-layer Fusion Strategy
46 | ```bash
47 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 4 --fusion-token 8
48 | ```
49 | ## :heart: Acknowledgement
50 | This project is based on Vision Mamba ([paper](https://arxiv.org/abs/2401.09417), [code](https://github.com/hustvl/Vim?tab=readme-ov-file)), Mamba ([paper](https://arxiv.org/abs/2312.00752), [code](https://github.com/state-spaces/mamba)), Causal-Conv1d ([code](https://github.com/Dao-AILab/causal-conv1d)), DeiT ([paper](https://arxiv.org/abs/2012.12877), [code](https://github.com/facebookresearch/deit)). Thanks for their wonderful works.
51 |
52 | ## 🥳 Citation
53 | If you find Famba-V useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
54 |
55 | ```bibtex
56 | @inproceedings{fambav2024eccvw,
57 | title={Famba-V: Fast Vision Mamba with Cross-Layer Token Fusion},
58 | author={Shen, Hui and Wan, Zhongwei and Wang, Xin and Zhang, Mi},
59 | booktitle={European Conference on Computer Vision (ECCV) Workshop on Computational Aspects of Deep Learning},
60 | year={2024}
61 | }
62 | ```
63 |
--------------------------------------------------------------------------------
/assets/FambaOverview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/assets/FambaOverview.png
--------------------------------------------------------------------------------
/assets/Strategies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/assets/Strategies.png
--------------------------------------------------------------------------------
/assets/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/assets/logo.jpg
--------------------------------------------------------------------------------
/causal-conv1d/=1.1.0:
--------------------------------------------------------------------------------
1 | Obtaining file:///users/PAS2490/marcusshen/Vim/causal-conv1d/causal_conv1d
2 |
--------------------------------------------------------------------------------
/causal-conv1d/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 |
--------------------------------------------------------------------------------
/causal-conv1d/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/causal-conv1d/README.md:
--------------------------------------------------------------------------------
1 | # Causal depthwise conv1d in CUDA with a PyTorch interface
2 |
--------------------------------------------------------------------------------
/causal-conv1d/[13:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/causal-conv1d/[13
--------------------------------------------------------------------------------
/causal-conv1d/causal_conv1d/=1.1.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/causal-conv1d/causal_conv1d/=1.1.0
--------------------------------------------------------------------------------
/causal-conv1d/causal_conv1d/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
4 |
--------------------------------------------------------------------------------
/causal-conv1d/causal_conv1d/causal_conv1d_interface.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | import causal_conv1d_cuda
8 |
9 |
10 | class CausalConv1dFn(torch.autograd.Function):
11 | @staticmethod
12 | def forward(ctx, x, weight, bias=None, activation=None):
13 | if activation not in [None, "silu", "swish"]:
14 | raise NotImplementedError("activation must be None, silu, or swish")
15 | if x.stride(2) != 1 and x.stride(1) != 1:
16 | x = x.contiguous()
17 | bias = bias.contiguous() if bias is not None else None
18 | ctx.save_for_backward(x, weight, bias)
19 | ctx.activation = activation in ["silu", "swish"]
20 | out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
21 | return out
22 |
23 | @staticmethod
24 | def backward(ctx, dout):
25 | x, weight, bias = ctx.saved_tensors
26 | if dout.stride(2) != 1 and dout.stride(1) != 1:
27 | dout = dout.contiguous()
28 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
29 | # backward of conv1d with the backward of chunk).
30 | # Here we just pass in None and dx will be allocated in the C++ code.
31 | dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
32 | x, weight, bias, dout, None, ctx.activation
33 | )
34 | return dx, dweight, (dbias if bias is not None else None), None
35 |
36 |
37 | def causal_conv1d_fn(x, weight, bias=None, activation=None):
38 | """
39 | x: (batch, dim, seqlen)
40 | weight: (dim, width)
41 | bias: (dim,)
42 | activation: either None or "silu" or "swish"
43 |
44 | out: (batch, dim, seqlen)
45 | """
46 | return CausalConv1dFn.apply(x, weight, bias, activation)
47 |
48 |
49 | def causal_conv1d_ref(x, weight, bias=None, activation=None):
50 | """
51 | x: (batch, dim, seqlen)
52 | weight: (dim, width)
53 | bias: (dim,)
54 |
55 | out: (batch, dim, seqlen)
56 | """
57 | if activation not in [None, "silu", "swish"]:
58 | raise NotImplementedError("activation must be None, silu, or swish")
59 | dtype_in = x.dtype
60 | x = x.to(weight.dtype)
61 | seqlen = x.shape[-1]
62 | dim, width = weight.shape
63 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
64 | out = out[..., :seqlen]
65 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
66 |
67 |
68 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
69 | """
70 | x: (batch, dim)
71 | conv_state: (batch, dim, width)
72 | weight: (dim, width)
73 | bias: (dim,)
74 |
75 | out: (batch, dim)
76 | """
77 | if activation not in [None, "silu", "swish"]:
78 | raise NotImplementedError("activation must be None, silu, or swish")
79 | activation = activation in ["silu", "swish"]
80 | return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
81 |
82 |
83 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
84 | """
85 | x: (batch, dim)
86 | conv_state: (batch, dim, width)
87 | weight: (dim, width)
88 | bias: (dim,)
89 |
90 | out: (batch, dim)
91 | """
92 | if activation not in [None, "silu", "swish"]:
93 | raise NotImplementedError("activation must be None, silu, or swish")
94 | dtype_in = x.dtype
95 | batch, dim = x.shape
96 | width = weight.shape[1]
97 | assert conv_state.shape == (batch, dim, width)
98 | assert weight.shape == (dim, width)
99 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
100 | conv_state[:, :, -1] = x
101 | out = torch.sum(conv_state * weight, dim=-1) # (B D)
102 | if bias is not None:
103 | out += bias
104 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
105 |
--------------------------------------------------------------------------------
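Each CUDA entry point above is paired with a pure-PyTorch reference (`causal_conv1d_ref`, `causal_conv1d_update_ref`). A minimal sketch of the interface semantics, using only the reference path so it runs on CPU without the compiled `causal_conv1d_cuda` extension (shapes follow the docstrings above):

```python
import torch

from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref

batch, dim, seqlen, width = 2, 8, 16, 4
x = torch.randn(batch, dim, seqlen)   # (batch, dim, seqlen)
weight = torch.randn(dim, width)      # (dim, width), per-channel taps
bias = torch.randn(dim)

out = causal_conv1d_ref(x, weight, bias, activation="silu")
assert out.shape == (batch, dim, seqlen)

# Causality check: zeroing future inputs must not change past outputs.
x_future_zeroed = x.clone()
x_future_zeroed[..., seqlen // 2:] = 0.0
out2 = causal_conv1d_ref(x_future_zeroed, weight, bias, activation="silu")
assert torch.allclose(out[..., : seqlen // 2], out2[..., : seqlen // 2])
```

`causal_conv1d_fn` has the same signature and should agree with the reference up to kernel-level numerics, which is exactly what `tests/test_causal_conv1d.py` asserts.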
/causal-conv1d/csrc/causal_conv1d.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct ConvParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, dim, seqlen, width;
13 | bool silu_activation;
14 |
15 | index_t x_batch_stride;
16 | index_t x_c_stride;
17 | index_t x_l_stride;
18 | index_t weight_c_stride;
19 | index_t weight_width_stride;
20 | index_t out_batch_stride;
21 | index_t out_c_stride;
22 | index_t out_l_stride;
23 |
24 | index_t conv_state_batch_stride;
25 | index_t conv_state_c_stride;
26 | index_t conv_state_l_stride;
27 |
28 | // Common data pointers.
29 | void *__restrict__ x_ptr;
30 | void *__restrict__ weight_ptr;
31 | void *__restrict__ bias_ptr;
32 | void *__restrict__ out_ptr;
33 |
34 | void *__restrict__ conv_state_ptr;
35 | };
36 |
37 | struct ConvParamsBwd: public ConvParamsBase {
38 | index_t dx_batch_stride;
39 | index_t dx_c_stride;
40 | index_t dx_l_stride;
41 | index_t dweight_c_stride;
42 | index_t dweight_width_stride;
43 | index_t dout_batch_stride;
44 | index_t dout_c_stride;
45 | index_t dout_l_stride;
46 |
47 | // Common data pointers.
48 | void *__restrict__ dx_ptr;
49 | void *__restrict__ dweight_ptr;
50 | void *__restrict__ dbias_ptr;
51 | void *__restrict__ dout_ptr;
52 | };
53 |
54 |
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 |
10 | ////////////////////////////////////////////////////////////////////////////////////////////////////
11 |
12 | template<int BYTES> struct BytesToType {};
13 |
14 | template<> struct BytesToType<16> {
15 | using Type = uint4;
16 | static_assert(sizeof(Type) == 16);
17 | };
18 |
19 | template<> struct BytesToType<8> {
20 | using Type = uint64_t;
21 | static_assert(sizeof(Type) == 8);
22 | };
23 |
24 | template<> struct BytesToType<4> {
25 | using Type = uint32_t;
26 | static_assert(sizeof(Type) == 4);
27 | };
28 |
29 | template<> struct BytesToType<2> {
30 | using Type = uint16_t;
31 | static_assert(sizeof(Type) == 2);
32 | };
33 |
34 | template<> struct BytesToType<1> {
35 | using Type = uint8_t;
36 | static_assert(sizeof(Type) == 1);
37 | };
38 |
39 | ////////////////////////////////////////////////////////////////////////////////////////////////////
40 |
41 | template<typename T>
42 | struct SumOp {
43 | __device__ inline T operator()(T const & x, T const & y) { return x + y; }
44 | };
45 |
46 | template<int THREADS>
47 | struct Allreduce {
48 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
49 | template<typename T, typename Operator>
50 | static __device__ inline T run(T x, Operator &op) {
51 | constexpr int OFFSET = THREADS / 2;
52 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
53 | return Allreduce<OFFSET>::run(x, op);
54 | }
55 | };
56 |
57 | template<>
58 | struct Allreduce<2> {
59 | template<typename T, typename Operator>
60 | static __device__ inline T run(T x, Operator &op) {
61 | x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
62 | return x;
63 | }
64 | };
65 |
--------------------------------------------------------------------------------
/causal-conv1d/csrc/causal_conv1d_update.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #include <c10/util/BFloat16.h>
6 | #include <c10/util/Half.h>
7 | #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8 |
9 | #include <cub/block/block_load.cuh>
10 | #include <cub/block/block_store.cuh>
11 |
12 | #include "causal_conv1d.h"
13 | #include "causal_conv1d_common.h"
14 | #include "static_switch.h"
15 |
16 | template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
17 | struct Causal_conv1d_update_kernel_traits {
18 | using input_t = input_t_;
19 | using weight_t = weight_t_;
20 | static constexpr int kNThreads = kNThreads_;
21 | static constexpr int kWidth = kWidth_;
22 | static constexpr int kNBytes = sizeof(input_t);
23 | static_assert(kNBytes == 2 || kNBytes == 4);
24 | };
25 |
26 | template<typename Ktraits>
27 | __global__ __launch_bounds__(Ktraits::kNThreads)
28 | void causal_conv1d_update_kernel(ConvParamsBase params) {
29 | constexpr int kWidth = Ktraits::kWidth;
30 | constexpr int kNThreads = Ktraits::kNThreads;
31 | using input_t = typename Ktraits::input_t;
32 | using weight_t = typename Ktraits::weight_t;
33 |
34 | const int tidx = threadIdx.x;
35 | const int batch_id = blockIdx.x;
36 | const int channel_id = blockIdx.y * kNThreads + tidx;
37 | input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
38 | + channel_id * params.x_c_stride;
39 | input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
40 | + channel_id * params.conv_state_c_stride;
41 | weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
42 | input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
43 | + channel_id * params.out_c_stride;
44 | float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
45 |
46 | float weight_vals[kWidth] = {0};
47 | if (channel_id < params.dim) {
48 | #pragma unroll
49 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
50 | }
51 |
52 | float x_vals[kWidth] = {0};
53 | if (channel_id < params.dim) {
54 | #pragma unroll
55 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
56 | x_vals[kWidth - 1] = float(x[0]);
57 | #pragma unroll
58 | for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
59 | }
60 |
61 | float out_val = bias_val;
62 | #pragma unroll
63 | for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
64 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
65 | if (channel_id < params.dim) { out[0] = input_t(out_val); }
66 | }
67 |
68 | template<int kNThreads, int kWidth, typename input_t, typename weight_t>
69 | void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
70 | using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
71 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
72 | auto kernel = &causal_conv1d_update_kernel<Ktraits>;
73 | kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
74 | C10_CUDA_KERNEL_LAUNCH_CHECK();
75 | }
76 |
77 | template<typename input_t, typename weight_t>
78 | void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
79 | if (params.width == 2) {
80 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
81 | } else if (params.width == 3) {
82 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
83 | } else if (params.width == 4) {
84 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
85 | }
86 | }
87 |
88 | template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
89 | template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
90 | template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
91 | template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
92 | template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
93 | template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
94 | template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
95 | template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
96 | template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
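The update kernel is the streaming counterpart of the full forward: it shifts the per-channel state left by one slot, appends the new timestep, and emits one output. A sketch (assumption: a zero-initialized state stands in for the causal padding) that checks this against the full-sequence reference, runnable on CPU via the PyTorch reference functions:

```python
import torch

from causal_conv1d.causal_conv1d_interface import (
    causal_conv1d_ref,
    causal_conv1d_update_ref,
)

batch, dim, seqlen, width = 2, 8, 12, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

out_full = causal_conv1d_ref(x, weight, bias, activation="silu")

# Stream one timestep at a time through the rolling (batch, dim, width) state.
conv_state = torch.zeros(batch, dim, width)
out_steps = [
    causal_conv1d_update_ref(x[..., t], conv_state, weight, bias, activation="silu")
    for t in range(seqlen)
]
assert torch.allclose(out_full, torch.stack(out_steps, dim=-1), atol=1e-5)
```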
/causal-conv1d/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | static constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | static constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/causal-conv1d/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 | import sys
3 | import warnings
4 | import os
5 | import re
6 | import ast
7 | from pathlib import Path
8 | from packaging.version import parse, Version
9 | import platform
10 |
11 | from setuptools import setup, find_packages
12 | import subprocess
13 |
14 | import urllib.request
15 | import urllib.error
16 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
17 |
18 | import torch
19 | from torch.utils.cpp_extension import (
20 | BuildExtension,
21 | CppExtension,
22 | CUDAExtension,
23 | CUDA_HOME,
24 | )
25 |
26 |
27 | with open("README.md", "r", encoding="utf-8") as fh:
28 | long_description = fh.read()
29 |
30 |
31 | # ninja build does not work unless include_dirs are abs path
32 | this_dir = os.path.dirname(os.path.abspath(__file__))
33 |
34 | PACKAGE_NAME = "causal_conv1d"
35 |
36 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"
37 |
38 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
39 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
40 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
41 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
42 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
43 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
44 |
45 |
46 | def get_platform():
47 | """
48 | Returns the platform name as used in wheel filenames.
49 | """
50 | if sys.platform.startswith("linux"):
51 | return "linux_x86_64"
52 | elif sys.platform == "darwin":
53 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
54 | return f"macosx_{mac_version}_x86_64"
55 | elif sys.platform == "win32":
56 | return "win_amd64"
57 | else:
58 | raise ValueError("Unsupported platform: {}".format(sys.platform))
59 |
60 |
61 | def get_cuda_bare_metal_version(cuda_dir):
62 | raw_output = subprocess.check_output(
63 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
64 | )
65 | output = raw_output.split()
66 | release_idx = output.index("release") + 1
67 | bare_metal_version = parse(output[release_idx].split(",")[0])
68 |
69 | return raw_output, bare_metal_version
70 |
71 |
72 | def check_if_cuda_home_none(global_option: str) -> None:
73 | if CUDA_HOME is not None:
74 | return
75 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
76 | # in that case.
77 | warnings.warn(
78 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
79 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
80 | "only images whose names contain 'devel' will provide nvcc."
81 | )
82 |
83 |
84 | def append_nvcc_threads(nvcc_extra_args):
85 | return nvcc_extra_args + ["--threads", "4"]
86 |
87 |
88 | cmdclass = {}
89 | ext_modules = []
90 |
91 | if not SKIP_CUDA_BUILD:
92 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
93 | TORCH_MAJOR = int(torch.__version__.split(".")[0])
94 | TORCH_MINOR = int(torch.__version__.split(".")[1])
95 |
96 | check_if_cuda_home_none("causal_conv1d")
97 | # Check, if CUDA11 is installed for compute capability 8.0
98 | cc_flag = []
99 | if CUDA_HOME is not None:
100 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
101 | if bare_metal_version < Version("11.6"):
102 | raise RuntimeError(
103 | "causal_conv1d is only supported on CUDA 11.6 and above. "
104 | "Note: make sure nvcc has a supported version by running nvcc -V."
105 | )
106 |
107 | cc_flag.append("-gencode")
108 | cc_flag.append("arch=compute_70,code=sm_70")
109 | cc_flag.append("-gencode")
110 | cc_flag.append("arch=compute_80,code=sm_80")
111 | if bare_metal_version >= Version("11.8"):
112 | cc_flag.append("-gencode")
113 | cc_flag.append("arch=compute_90,code=sm_90")
114 |
115 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
116 | # torch._C._GLIBCXX_USE_CXX11_ABI
117 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
118 | if FORCE_CXX11_ABI:
119 | torch._C._GLIBCXX_USE_CXX11_ABI = True
120 |
121 | ext_modules.append(
122 | CUDAExtension(
123 | name="causal_conv1d_cuda",
124 | sources=[
125 | "csrc/causal_conv1d.cpp",
126 | "csrc/causal_conv1d_fwd.cu",
127 | "csrc/causal_conv1d_bwd.cu",
128 | "csrc/causal_conv1d_update.cu",
129 | ],
130 | extra_compile_args={
131 | "cxx": ["-O3"],
132 | "nvcc": append_nvcc_threads(
133 | [
134 | "-O3",
135 | "-U__CUDA_NO_HALF_OPERATORS__",
136 | "-U__CUDA_NO_HALF_CONVERSIONS__",
137 | "-U__CUDA_NO_BFLOAT16_OPERATORS__",
138 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
139 | "-U__CUDA_NO_BFLOAT162_OPERATORS__",
140 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
141 | "--expt-relaxed-constexpr",
142 | "--expt-extended-lambda",
143 | "--use_fast_math",
144 | "--ptxas-options=-v",
145 | "-lineinfo",
146 | ]
147 | + cc_flag
148 | ),
149 | },
150 | include_dirs=[this_dir],
151 | )
152 | )
153 |
154 |
155 | def get_package_version():
156 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f:
157 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
158 | public_version = ast.literal_eval(version_match.group(1))
159 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
160 | if local_version:
161 | return f"{public_version}+{local_version}"
162 | else:
163 | return str(public_version)
164 |
165 |
166 | def get_wheel_url():
167 | # Determine the version numbers that will be used to determine the correct wheel
168 | # We're using the CUDA version used to build torch, not the one currently installed
169 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
170 | torch_cuda_version = parse(torch.version.cuda)
171 | torch_version_raw = parse(torch.__version__)
172 | # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
173 | # to save CI time. Minor versions should be compatible.
174 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
175 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
176 | platform_name = get_platform()
177 | causal_conv1d_version = get_package_version()
178 | # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
179 | cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
180 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
181 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
182 |
183 | # Determine wheel URL based on CUDA version, torch version, python version and OS
184 | wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
185 | wheel_url = BASE_WHEEL_URL.format(
186 | tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
187 | )
188 | return wheel_url, wheel_filename
189 |
190 |
191 | class CachedWheelsCommand(_bdist_wheel):
192 | """
193 | The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
194 | find an existing wheel (which is currently the case for all installs). We use
195 | the environment parameters to detect whether there is already a pre-built version of a compatible
196 | wheel available and short-circuits the standard full build pipeline.
197 | """
198 |
199 | def run(self):
200 | if FORCE_BUILD:
201 | return super().run()
202 |
203 | wheel_url, wheel_filename = get_wheel_url()
204 | print("Guessing wheel URL: ", wheel_url)
205 | try:
206 | urllib.request.urlretrieve(wheel_url, wheel_filename)
207 |
208 | # Make the archive
209 | # Lifted from the root wheel processing command
210 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
211 | if not os.path.exists(self.dist_dir):
212 | os.makedirs(self.dist_dir)
213 |
214 | impl_tag, abi_tag, plat_tag = self.get_tag()
215 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
216 |
217 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
218 | print("Raw wheel path", wheel_path)
219 | os.rename(wheel_filename, wheel_path)
220 | except urllib.error.HTTPError:
221 | print("Precompiled wheel not found. Building from source...")
222 | # If the wheel could not be downloaded, build from source
223 | super().run()
224 |
225 |
226 | setup(
227 | name=PACKAGE_NAME,
228 | version=get_package_version(),
229 | packages=find_packages(
230 | exclude=(
231 | "build",
232 | "csrc",
233 | "include",
234 | "tests",
235 | "dist",
236 | "docs",
237 | "benchmarks",
238 | "causal_conv1d.egg-info",
239 | )
240 | ),
241 | author="Tri Dao",
242 | author_email="tri@tridao.me",
243 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
244 | long_description=long_description,
245 | long_description_content_type="text/markdown",
246 | url="https://github.com/Dao-AILab/causal-conv1d",
247 | classifiers=[
248 | "Programming Language :: Python :: 3",
249 | "License :: OSI Approved :: BSD License",
250 | "Operating System :: Unix",
251 | ],
252 | ext_modules=ext_modules,
253 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
254 | if ext_modules
255 | else {
256 | "bdist_wheel": CachedWheelsCommand,
257 | },
258 | python_requires=">=3.7",
259 | install_requires=[
260 | "torch",
261 | "packaging",
262 | "ninja",
263 | ],
264 | )
265 |
--------------------------------------------------------------------------------
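The wheel lookup in `get_wheel_url` is keyed on five values. A sketch reproducing the filename format for one assumed environment (CPython 3.10, Linux, torch 2.1 built against CUDA 11.8, pre-C++11 ABI); the real values are read from the installed torch and `causal_conv1d/__init__.py`:

```python
# Assumed environment values, for illustration only.
package_name, package_version = "causal_conv1d", "1.0.0"
cuda_version = "118"    # torch.version.cuda, collapsed to 11.8
torch_version = "2.1"   # major.minor of torch.__version__
cxx11_abi = "FALSE"     # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{package_name}-{package_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# causal_conv1d-1.0.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```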
/causal-conv1d/tests/test_causal_conv1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import pytest
7 |
8 | from einops import rearrange
9 |
10 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref
11 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref
12 |
13 |
14 | @pytest.mark.parametrize("channel_last", [False, True])
15 | # @pytest.mark.parametrize('channel_last', [True])
16 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
17 | # @pytest.mark.parametrize('itype', [torch.float16])
18 | @pytest.mark.parametrize("silu_activation", [False, True])
19 | # @pytest.mark.parametrize('silu_activation', [True])
20 | @pytest.mark.parametrize("has_bias", [False, True])
21 | # @pytest.mark.parametrize('has_bias', [True])
22 | @pytest.mark.parametrize("width", [2, 3, 4])
23 | # @pytest.mark.parametrize('width', [2])
24 | @pytest.mark.parametrize(
25 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
26 | )
27 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
28 | # @pytest.mark.parametrize('seqlen', [128])
29 | def test_causal_conv1d(seqlen, width, has_bias, silu_activation, itype, channel_last):
30 | device = "cuda"
31 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
32 | if itype == torch.bfloat16:
33 | rtol, atol = 1e-2, 5e-2
34 | rtolw, atolw = (1e-3, 1e-3)
35 | # set seed
36 | torch.random.manual_seed(0)
37 | batch_size = 2
38 | # batch_size = 1
39 | dim = 4096 + 32 # Try dim not divisible by 64
40 | # dim = 64
41 | if not channel_last:
42 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
43 | else:
44 | x = rearrange(
45 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
46 | ).requires_grad_()
47 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
48 | if has_bias:
49 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
50 | else:
51 | bias = None
52 | x_ref = x.detach().clone().requires_grad_()
53 | weight_ref = weight.detach().clone().requires_grad_()
54 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
55 | activation = None if not silu_activation else "silu"
56 | out = causal_conv1d_fn(x, weight, bias, activation=activation)
57 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, activation=activation)
58 |
59 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
60 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
61 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
62 |
63 | g = torch.randn_like(out)
64 | out_ref.backward(g)
65 | out.backward(g)
66 |
67 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
68 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
69 | if has_bias:
70 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
71 |
72 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
73 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
74 | if has_bias:
75 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
76 |
77 |
78 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
79 | # @pytest.mark.parametrize('itype', [torch.float16])
80 | @pytest.mark.parametrize("silu_activation", [False, True])
81 | # @pytest.mark.parametrize('silu_activation', [False])
82 | @pytest.mark.parametrize("has_bias", [False, True])
83 | # @pytest.mark.parametrize('has_bias', [True])
84 | @pytest.mark.parametrize("width", [2, 3, 4])
85 | # @pytest.mark.parametrize('width', [2])
86 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
87 | # @pytest.mark.parametrize("dim", [2048])
88 | def test_causal_conv1d_update(dim, width, has_bias, silu_activation, itype):
89 | device = "cuda"
90 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
91 | if itype == torch.bfloat16:
92 | rtol, atol = 1e-2, 5e-2
93 | rtolw, atolw = (1e-3, 1e-3)
94 | # set seed
95 | torch.random.manual_seed(0)
96 | batch_size = 2
97 | # batch_size = 1
98 | # dim = 64
99 | x = torch.randn(batch_size, dim, device=device, dtype=itype)
100 | conv_state = torch.randn(batch_size, dim, width, device=device, dtype=itype)
101 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
102 | if has_bias:
103 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
104 | else:
105 | bias = None
106 | conv_state_ref = conv_state.detach().clone()
107 | activation = None if not silu_activation else "silu"
108 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
109 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation)
110 |
111 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
112 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
113 | assert torch.equal(conv_state, conv_state_ref)
114 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115 |
116 |
117 | # @pytest.mark.parametrize("channel_last", [False, True])
118 | @pytest.mark.parametrize('channel_last', [True])
119 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
120 | @pytest.mark.parametrize('itype', [torch.bfloat16])
121 | # @pytest.mark.parametrize("silu_activation", [False, True])
122 | @pytest.mark.parametrize('silu_activation', [True])
123 | # @pytest.mark.parametrize("has_bias", [False, True])
124 | @pytest.mark.parametrize('has_bias', [True])
125 | # @pytest.mark.parametrize("width", [2, 3, 4])
126 | @pytest.mark.parametrize('width', [4])
127 | @pytest.mark.parametrize(
128 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
129 | "seqlen", [2048]
130 | )
131 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
132 | # @pytest.mark.parametrize('seqlen', [128])
133 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
134 | device = "cuda"
135 | # set seed
136 | torch.random.manual_seed(0)
137 | batch_size = 2
138 | # batch_size = 1
139 | dim = 4096 + 32 # Try dim not divisible by 64
140 | # dim = 64
141 | if not channel_last:
142 | x = torch.randn(batch_size, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
143 | else:
144 | x = rearrange(
145 | torch.randn(batch_size, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
146 | ).requires_grad_()
147 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
148 | if has_bias:
149 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
150 | else:
151 | bias = None
152 | activation = None if not silu_activation else "silu"
153 | out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
154 | g = torch.randn_like(out0)
155 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
156 | dw_atol = 1e-4
157 | db_atol = 1e-4
158 |
159 | for i in range(10000):
160 | out = causal_conv1d_fn(x, weight, bias, activation=activation)
161 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
162 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
163 | # if not dw_equal:
164 | # breakpoint()
165 | if has_bias:
166 | db_equal = torch.allclose(db, db0, atol=db_atol)
167 | # if not db_equal:
168 | # breakpoint()
169 | assert torch.equal(out, out0)
170 | assert torch.equal(dx, dx0)
171 | assert dw_equal
172 | if has_bias:
173 | assert db_equal
174 |
--------------------------------------------------------------------------------
/fambav/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | imnet_resnet50_scratch/timm_temp/
4 | .dumbo.json
5 | checkpoints/
6 |
--------------------------------------------------------------------------------
/fambav/augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | """
5 | 3Augment implementation
6 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
7 | and timm DA(https://github.com/rwightman/pytorch-image-models)
8 | """
9 | import torch
10 | from torchvision import transforms
11 |
12 | # error: cannot import name '_pil_interp' from 'timm.data.transforms'
13 | # from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
14 |
15 | # fix: timm version problem
16 | # from timm.data.transforms import str_pil_interp as _pil_interp
17 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
18 |
19 | import numpy as np
20 | from torchvision import datasets, transforms
21 | import random
22 |
23 |
24 |
25 | from PIL import ImageFilter, ImageOps
26 | import torchvision.transforms.functional as TF
27 |
28 |
29 | class GaussianBlur(object):
30 | """
31 | Apply Gaussian Blur to the PIL image.
32 | """
33 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
34 | self.prob = p
35 | self.radius_min = radius_min
36 | self.radius_max = radius_max
37 |
38 | def __call__(self, img):
39 | do_it = random.random() <= self.prob
40 | if not do_it:
41 | return img
42 |
43 | img = img.filter(
44 | ImageFilter.GaussianBlur(
45 | radius=random.uniform(self.radius_min, self.radius_max)
46 | )
47 | )
48 | return img
49 |
50 | class Solarization(object):
51 | """
52 | Apply Solarization to the PIL image.
53 | """
54 | def __init__(self, p=0.2):
55 | self.p = p
56 |
57 | def __call__(self, img):
58 | if random.random() < self.p:
59 | return ImageOps.solarize(img)
60 | else:
61 | return img
62 |
63 | class gray_scale(object):
64 | """
65 | Apply Grayscale to the PIL image.
66 | """
67 | def __init__(self, p=0.2):
68 | self.p = p
69 | self.transf = transforms.Grayscale(3)
70 |
71 | def __call__(self, img):
72 | if random.random() < self.p:
73 | return self.transf(img)
74 | else:
75 | return img
76 |
77 |
78 |
79 | class horizontal_flip(object):
80 | """
81 | Apply random horizontal flip to the PIL image.
82 | """
83 | def __init__(self, p=0.2,activate_pred=False):
84 | self.p = p
85 | self.transf = transforms.RandomHorizontalFlip(p=1.0)
86 |
87 | def __call__(self, img):
88 | if random.random() < self.p:
89 | return self.transf(img)
90 | else:
91 | return img
92 |
93 |
94 |
95 | def new_data_aug_generator(args = None):
96 | img_size = args.input_size
97 | remove_random_resized_crop = args.src
98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
99 | primary_tfl = []
100 | scale=(0.08, 1.0)
101 | interpolation='bicubic'
102 | if remove_random_resized_crop:
103 | primary_tfl = [
104 | transforms.Resize(img_size, interpolation=3),
105 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
106 | transforms.RandomHorizontalFlip()
107 | ]
108 | else:
109 | primary_tfl = [
110 | RandomResizedCropAndInterpolation(
111 | img_size, scale=scale, interpolation=interpolation),
112 | transforms.RandomHorizontalFlip()
113 | ]
114 |
115 |
116 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
117 | Solarization(p=1.0),
118 | GaussianBlur(p=1.0)])]
119 |
120 | if args.color_jitter is not None and not args.color_jitter==0:
121 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
122 | final_tfl = [
123 | transforms.ToTensor(),
124 | transforms.Normalize(
125 | mean=torch.tensor(mean),
126 | std=torch.tensor(std))
127 | ]
128 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
129 |
--------------------------------------------------------------------------------
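`new_data_aug_generator` reads only `input_size`, `src`, and `color_jitter` from `args`, so the 3Augment pipeline can be built standalone. A sketch with a stand-in namespace (the field values are assumptions mirroring DeiT-style defaults; run from inside `fambav/` with timm installed):

```python
from types import SimpleNamespace

from PIL import Image

from augment import new_data_aug_generator  # fambav/augment.py

args = SimpleNamespace(
    input_size=224,    # target crop size
    src=False,         # False keeps RandomResizedCrop; True uses simple resized crop
    color_jitter=0.3,  # None or 0 disables the ColorJitter stage
)
transform = new_data_aug_generator(args)

img = Image.new("RGB", (256, 256))
tensor = transform(img)
print(tensor.shape)  # torch.Size([3, 224, 224])
```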
/fambav/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import os
4 | import json
5 |
6 | from torchvision import datasets, transforms
7 | from torchvision.datasets.folder import ImageFolder, default_loader
8 |
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 | from timm.data import create_transform
11 |
12 |
13 | class INatDataset(ImageFolder):
14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
15 | category='name', loader=default_loader):
16 | self.transform = transform
17 | self.loader = loader
18 | self.target_transform = target_transform
19 | self.year = year
20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
22 | with open(path_json) as json_file:
23 | data = json.load(json_file)
24 |
25 | with open(os.path.join(root, 'categories.json')) as json_file:
26 | data_catg = json.load(json_file)
27 |
28 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
29 |
30 | with open(path_json_for_targeter) as json_file:
31 | data_for_targeter = json.load(json_file)
32 |
33 | targeter = {}
34 | indexer = 0
35 | for elem in data_for_targeter['annotations']:
36 | king = []
37 | king.append(data_catg[int(elem['category_id'])][category])
38 | if king[0] not in targeter.keys():
39 | targeter[king[0]] = indexer
40 | indexer += 1
41 | self.nb_classes = len(targeter)
42 |
43 | self.samples = []
44 | for elem in data['images']:
45 | cut = elem['file_name'].split('/')
46 | target_current = int(cut[2])
47 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
48 |
49 | categors = data_catg[target_current]
50 | target_current_true = targeter[categors[category]]
51 | self.samples.append((path_current, target_current_true))
52 |
53 | # __getitem__ and __len__ inherited from ImageFolder
54 |
55 |
56 | def build_dataset(is_train, args):
57 | transform = build_transform(is_train, args)
58 |
59 | if args.data_set == 'CIFAR':
60 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
61 | nb_classes = 100
62 | elif args.data_set == 'IMNET':
63 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
64 | dataset = datasets.ImageFolder(root, transform=transform)
65 | nb_classes = 1000
66 | elif args.data_set == 'INAT':
67 | dataset = INatDataset(args.data_path, train=is_train, year=2018,
68 | category=args.inat_category, transform=transform)
69 | nb_classes = dataset.nb_classes
70 | elif args.data_set == 'INAT19':
71 | dataset = INatDataset(args.data_path, train=is_train, year=2019,
72 | category=args.inat_category, transform=transform)
73 | nb_classes = dataset.nb_classes
74 |
75 | return dataset, nb_classes
76 |
77 |
78 | def build_transform(is_train, args):
79 | resize_im = args.input_size > 32
80 | if is_train:
81 | # this should always dispatch to transforms_imagenet_train
82 | transform = create_transform(
83 | input_size=args.input_size,
84 | is_training=True,
85 | color_jitter=args.color_jitter,
86 | auto_augment=args.aa,
87 | interpolation=args.train_interpolation,
88 | re_prob=args.reprob,
89 | re_mode=args.remode,
90 | re_count=args.recount,
91 | )
92 | if not resize_im:
93 | # replace RandomResizedCropAndInterpolation with
94 | # RandomCrop
95 | transform.transforms[0] = transforms.RandomCrop(
96 | args.input_size, padding=4)
97 | return transform
98 |
99 | t = []
100 | if resize_im:
101 | size = int(args.input_size / args.eval_crop_ratio)
102 | t.append(
103 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
104 | )
105 | t.append(transforms.CenterCrop(args.input_size))
106 |
107 | t.append(transforms.ToTensor())
108 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
109 | return transforms.Compose(t)
110 |
--------------------------------------------------------------------------------
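`build_dataset` dispatches on `args.data_set` and otherwise only consumes transform-related fields. A sketch for the CIFAR-100 path used in the README training command (values besides `data_set`/`data_path` are assumed DeiT-style defaults; `download=True` in the CIFAR branch fetches the data on first use):

```python
from types import SimpleNamespace

from datasets import build_dataset  # fambav/datasets.py

args = SimpleNamespace(
    data_set="CIFAR",
    data_path="./datasets/cifar-100-python",
    input_size=224,
    color_jitter=0.3,
    aa="rand-m9-mstd0.5-inc1",      # timm AutoAugment policy string
    train_interpolation="bicubic",
    reprob=0.25, remode="pixel", recount=1,  # random-erasing settings
    eval_crop_ratio=0.875,
)

train_set, nb_classes = build_dataset(is_train=True, args=args)
val_set, _ = build_dataset(is_train=False, args=args)
print(nb_classes)  # 100
```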
/fambav/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Train and eval functions used in main.py
5 | """
6 | import math
7 | import sys
8 | from typing import Iterable, Optional
9 |
10 | import torch
11 |
12 | import timm
13 | from timm.data import Mixup
14 | from timm.utils import accuracy, ModelEma
15 |
16 | from losses import DistillationLoss
17 | import utils
18 |
19 |
20 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
21 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
22 | device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0,
23 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
24 | set_training_mode=True, args = None):
25 | model.train(set_training_mode)
26 | metric_logger = utils.MetricLogger(delimiter=" ")
27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
28 | header = 'Epoch: [{}]'.format(epoch)
29 | print_freq = 10
30 |
31 | if args.cosub:
32 | criterion = torch.nn.BCEWithLogitsLoss()
33 |
34 | # debug
35 | # count = 0
36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
37 | # count += 1
38 | # if count > 20:
39 | # break
40 |
41 | samples = samples.to(device, non_blocking=True)
42 | targets = targets.to(device, non_blocking=True)
43 |
44 | if mixup_fn is not None:
45 | samples, targets = mixup_fn(samples, targets)
46 |
47 | if args.cosub:
48 | samples = torch.cat((samples,samples),dim=0)
49 |
50 | if args.bce_loss:
51 | targets = targets.gt(0.0).type(targets.dtype)
52 |
53 | with amp_autocast():
54 | outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank)
55 | # outputs = model(samples)
56 | if not args.cosub:
57 | loss = criterion(samples, outputs, targets)
58 | else:
59 | outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
60 | loss = 0.25 * criterion(outputs[0], targets)
61 | loss = loss + 0.25 * criterion(outputs[1], targets)
62 | loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
63 | loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())
64 |
65 | if args.if_nan2num:
66 | with amp_autocast():
67 | loss = torch.nan_to_num(loss)
68 |
69 | loss_value = loss.item()
70 |
71 | if not math.isfinite(loss_value):
72 | print("Loss is {}, stopping training".format(loss_value))
73 | if args.if_continue_inf:
74 | optimizer.zero_grad()
75 | continue
76 | else:
77 | sys.exit(1)
78 |
79 | optimizer.zero_grad()
80 |
81 | # this attribute is added by timm on one optimizer (adahessian)
82 | if isinstance(loss_scaler, timm.utils.NativeScaler):
83 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
84 | loss_scaler(loss, optimizer, clip_grad=max_norm,
85 | parameters=model.parameters(), create_graph=is_second_order)
86 | else:
87 | loss.backward()
88 | if max_norm is not None:
89 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
90 | optimizer.step()
91 |
92 | torch.cuda.synchronize()
93 | if model_ema is not None:
94 | model_ema.update(model)
95 |
96 | metric_logger.update(loss=loss_value)
97 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
98 | # gather the stats from all processes
99 | metric_logger.synchronize_between_processes()
100 | print("Averaged stats:", metric_logger)
101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
102 |
103 |
104 | @torch.no_grad()
105 | def evaluate(data_loader, model, device, amp_autocast):
106 | criterion = torch.nn.CrossEntropyLoss()
107 |
108 | metric_logger = utils.MetricLogger(delimiter=" ")
109 | header = 'Test:'
110 |
111 | # switch to evaluation mode
112 | model.eval()
113 |
114 | for images, target in metric_logger.log_every(data_loader, 10, header):
115 | images = images.to(device, non_blocking=True)
116 | target = target.to(device, non_blocking=True)
117 |
118 | # compute output
119 | with amp_autocast():
120 | output = model(images)
121 | loss = criterion(output, target)
122 |
123 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
124 |
125 | batch_size = images.shape[0]
126 | metric_logger.update(loss=loss.item())
127 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
128 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
129 | # gather the stats from all processes
130 | metric_logger.synchronize_between_processes()
131 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
132 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
133 |
134 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
135 |
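A self-contained sketch of the cosub branch above: the batch is doubled, the outputs are split back into two views, and four BCEWithLogits terms are averaged, two against the (soft) targets and two cross-distilling each view toward the other's detached sigmoid. Shapes below are illustrative dummies, not values from the training script:

import torch

criterion = torch.nn.BCEWithLogitsLoss()
outputs = torch.randn(8, 100)          # logits for the doubled batch (2 x 4 samples)
targets = torch.rand(4, 100)           # soft targets in [0, 1] for one view
out_a, out_b = torch.split(outputs, outputs.shape[0] // 2, dim=0)
loss = 0.25 * (criterion(out_a, targets)
               + criterion(out_b, targets)
               + criterion(out_a, out_b.detach().sigmoid())
               + criterion(out_b, out_a.detach().sigmoid()))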
--------------------------------------------------------------------------------
/fambav/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | from models import *
4 | from cait_models import *
5 | from resmlp_models import *
6 | #from patchconvnet_models import *
7 |
8 | dependencies = ["torch", "torchvision", "timm"]
9 |
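Note: these wildcard imports mirror the upstream DeiT hubconf; the referenced models, cait_models, and resmlp_models modules are not present in this directory (the models here live in models_mamba.py). A hedged sketch of how a local hub entrypoint would be consumed, with 'some_entrypoint' as a placeholder for whatever function the imports actually expose:

import torch

# source='local' makes torch.hub read this hubconf.py from the given path;
# 'some_entrypoint' is hypothetical, not a name confirmed by this repo.
model = torch.hub.load('.', 'some_entrypoint', source='local')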
--------------------------------------------------------------------------------
/fambav/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Implements the knowledge distillation loss
5 | """
6 | import torch
7 | from torch.nn import functional as F
8 |
9 |
10 | class DistillationLoss(torch.nn.Module):
11 | """
12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
13 | taking a teacher model prediction and using it as additional supervision.
14 | """
15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
16 | distillation_type: str, alpha: float, tau: float):
17 | super().__init__()
18 | self.base_criterion = base_criterion
19 | self.teacher_model = teacher_model
20 | assert distillation_type in ['none', 'soft', 'hard']
21 | self.distillation_type = distillation_type
22 | self.alpha = alpha
23 | self.tau = tau
24 |
25 | def forward(self, inputs, outputs, labels):
26 | """
27 | Args:
28 |             inputs: The original inputs that are fed to the teacher model
29 | outputs: the outputs of the model to be trained. It is expected to be
30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
31 | in the first position and the distillation predictions as the second output
32 | labels: the labels for the base criterion
33 | """
34 | outputs_kd = None
35 | if not isinstance(outputs, torch.Tensor):
36 | # assume that the model outputs a tuple of [outputs, outputs_kd]
37 | outputs, outputs_kd = outputs
38 | base_loss = self.base_criterion(outputs, labels)
39 | if self.distillation_type == 'none':
40 | return base_loss
41 |
42 | if outputs_kd is None:
43 | raise ValueError("When knowledge distillation is enabled, the model is "
44 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
45 | "class_token and the dist_token")
46 |         # don't backprop through the teacher
47 | with torch.no_grad():
48 | teacher_outputs = self.teacher_model(inputs)
49 |
50 | if self.distillation_type == 'soft':
51 | T = self.tau
52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
53 | # with slight modifications
54 | distillation_loss = F.kl_div(
55 | F.log_softmax(outputs_kd / T, dim=1),
56 |             # We provide the teacher's targets as log probabilities because we use log_target=True
57 |             # (as recommended in PyTorch: https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719),
58 |             # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
59 | F.log_softmax(teacher_outputs / T, dim=1),
60 | reduction='sum',
61 | log_target=True
62 | ) * (T * T) / outputs_kd.numel()
63 |         # We divide by outputs_kd.numel() to keep the legacy PyTorch behavior,
64 |         # but we also experimented with dividing by output_kd.size(0);
65 |         # see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details.
66 | elif self.distillation_type == 'hard':
67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
68 |
69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
70 | return loss
71 |
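A self-contained numerical check of the 'soft' branch above, with dummy student and teacher logits; T plays the role of self.tau, and the (T * T) / numel() factor reproduces the legacy scaling discussed in the comments:

import torch
import torch.nn.functional as F

T = 3.0
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
kd = F.kl_div(
    F.log_softmax(student_logits / T, dim=1),
    F.log_softmax(teacher_logits / T, dim=1),   # teacher given as log-probs, hence log_target=True
    reduction='sum',
    log_target=True,
) * (T * T) / student_logits.numel()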
--------------------------------------------------------------------------------
/fambav/mlruns/0/meta.yaml:
--------------------------------------------------------------------------------
1 | artifact_location: file:///users/PAS2490/marcusshen/Vim/vim/mlruns/0
2 | creation_time: 1715856885741
3 | experiment_id: '0'
4 | last_update_time: 1715856885741
5 | lifecycle_stage: active
6 | name: Default
7 |
--------------------------------------------------------------------------------
/fambav/rope.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # EVA-02: A Visual Representation for Neon Genesis
3 | # Github source: https://github.com/baaivision/EVA/EVA02
4 | # Copyright (c) 2023 Beijing Academy of Artificial Intelligence (BAAI)
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Yuxin Fang
7 | #
8 | # Based on https://github.com/lucidrains/rotary-embedding-torch
9 | # --------------------------------------------------------
10 |
11 | from math import pi
12 |
13 | import torch
14 | from torch import nn
15 |
16 | from einops import rearrange, repeat
17 |
18 |
19 |
20 | def broadcat(tensors, dim = -1):
21 | num_tensors = len(tensors)
22 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
23 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
24 | shape_len = list(shape_lens)[0]
25 | dim = (dim + shape_len) if dim < 0 else dim
26 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
27 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
28 |     assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
29 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
30 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
31 | expanded_dims.insert(dim, (dim, dims[dim]))
32 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
33 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
34 | return torch.cat(tensors, dim = dim)
35 |
36 |
37 |
38 | def rotate_half(x):
39 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
40 | x1, x2 = x.unbind(dim = -1)
41 | x = torch.stack((-x2, x1), dim = -1)
42 | return rearrange(x, '... d r -> ... (d r)')
43 |
44 |
45 |
46 | class VisionRotaryEmbedding(nn.Module):
47 | def __init__(
48 | self,
49 | dim,
50 | pt_seq_len,
51 | ft_seq_len=None,
52 | custom_freqs = None,
53 | freqs_for = 'lang',
54 | theta = 10000,
55 | max_freq = 10,
56 | num_freqs = 1,
57 | ):
58 | super().__init__()
59 | if custom_freqs:
60 | freqs = custom_freqs
61 | elif freqs_for == 'lang':
62 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
63 | elif freqs_for == 'pixel':
64 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
65 | elif freqs_for == 'constant':
66 | freqs = torch.ones(num_freqs).float()
67 | else:
68 | raise ValueError(f'unknown modality {freqs_for}')
69 |
70 | if ft_seq_len is None: ft_seq_len = pt_seq_len
71 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
72 |
73 | freqs_h = torch.einsum('..., f -> ... f', t, freqs)
74 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
75 |
76 | freqs_w = torch.einsum('..., f -> ... f', t, freqs)
77 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
78 |
79 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
80 |
81 | self.register_buffer("freqs_cos", freqs.cos())
82 | self.register_buffer("freqs_sin", freqs.sin())
83 |
84 | print('======== shape of rope freq', self.freqs_cos.shape, '========')
85 |
86 | def forward(self, t, start_index = 0):
87 | rot_dim = self.freqs_cos.shape[-1]
88 | end_index = start_index + rot_dim
89 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
90 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
91 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
92 | return torch.cat((t_left, t, t_right), dim = -1)
93 |
94 |
95 |
96 | class VisionRotaryEmbeddingFast(nn.Module):
97 | def __init__(
98 | self,
99 | dim,
100 | pt_seq_len=16,
101 | ft_seq_len=None,
102 | custom_freqs = None,
103 | freqs_for = 'lang',
104 | theta = 10000,
105 | max_freq = 10,
106 | num_freqs = 1,
107 | ):
108 | super().__init__()
109 | if custom_freqs:
110 | freqs = custom_freqs
111 | elif freqs_for == 'lang':
112 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
113 | elif freqs_for == 'pixel':
114 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
115 | elif freqs_for == 'constant':
116 | freqs = torch.ones(num_freqs).float()
117 | else:
118 | raise ValueError(f'unknown modality {freqs_for}')
119 |
120 | if ft_seq_len is None: ft_seq_len = pt_seq_len
121 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
122 |
123 | freqs = torch.einsum('..., f -> ... f', t, freqs)
124 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
125 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
126 |
127 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
128 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
129 |
130 | self.register_buffer("freqs_cos", freqs_cos)
131 | self.register_buffer("freqs_sin", freqs_sin)
132 |
133 | print('======== shape of rope freq', self.freqs_cos.shape, '========')
134 |
135 | def forward(self, t):
136 | if t.shape[1] % 2 != 0:
137 | t_spatial = t[:, 1:, :]
138 | t_spatial = t_spatial * self.freqs_cos + rotate_half(t_spatial) * self.freqs_sin
139 | return torch.cat((t[:, :1, :], t_spatial), dim=1)
140 | else:
141 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
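A minimal shape sketch for VisionRotaryEmbeddingFast under assumed sizes: dim=32 expands to a 64-channel rotation table over the default 16x16 grid (freqs_cos ends up with shape (256, 64)); an even token count takes the plain rotary path, while an odd length (leading class token) rotates only the spatial tokens:

import torch

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)   # freqs_cos: (256, 64)
tokens = torch.randn(2, 256, 64)    # (batch, 16*16 patch tokens, channels)
rotated = rope(tokens)              # even length -> rotate every position
assert rotated.shape == tokens.shape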
--------------------------------------------------------------------------------
/fambav/run_with_submitit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | A script to run multinode training with submitit.
5 | """
6 | import argparse
7 | import os
8 | import uuid
9 | from pathlib import Path
10 |
11 | import main as classification
12 | import submitit
13 |
14 |
15 | def parse_args():
16 | classification_parser = classification.get_args_parser()
17 | parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser])
18 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
19 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
20 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
21 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
22 |
23 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
24 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
25 | parser.add_argument('--comment', default="", type=str,
26 | help='Comment to pass to scheduler, e.g. priority message')
27 | return parser.parse_args()
28 |
29 |
30 | def get_shared_folder() -> Path:
31 | user = os.getenv("USER")
32 | if Path("/checkpoint/").is_dir():
33 | p = Path(f"/checkpoint/{user}/experiments")
34 | p.mkdir(exist_ok=True)
35 | return p
36 | raise RuntimeError("No shared folder available")
37 |
38 |
39 | def get_init_file():
40 |     # Init file must not exist, but its parent dir must exist.
41 | os.makedirs(str(get_shared_folder()), exist_ok=True)
42 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
43 | if init_file.exists():
44 | os.remove(str(init_file))
45 | return init_file
46 |
47 |
48 | class Trainer(object):
49 | def __init__(self, args):
50 | self.args = args
51 |
52 | def __call__(self):
53 | import main as classification
54 |
55 | self._setup_gpu_args()
56 | classification.main(self.args)
57 |
58 | def checkpoint(self):
59 | import os
60 | import submitit
61 |
62 | self.args.dist_url = get_init_file().as_uri()
63 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth")
64 | if os.path.exists(checkpoint_file):
65 | self.args.resume = checkpoint_file
66 | print("Requeuing ", self.args)
67 | empty_trainer = type(self)(self.args)
68 | return submitit.helpers.DelayedSubmission(empty_trainer)
69 |
70 | def _setup_gpu_args(self):
71 | import submitit
72 | from pathlib import Path
73 |
74 | job_env = submitit.JobEnvironment()
75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
76 | self.args.gpu = job_env.local_rank
77 | self.args.rank = job_env.global_rank
78 | self.args.world_size = job_env.num_tasks
79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
80 |
81 |
82 | def main():
83 | args = parse_args()
84 | if args.job_dir == "":
85 | args.job_dir = get_shared_folder() / "%j"
86 |
87 | # Note that the folder will depend on the job_id, to easily track experiments
88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
89 |
90 | num_gpus_per_node = args.ngpus
91 | nodes = args.nodes
92 | timeout_min = args.timeout
93 |
94 | partition = args.partition
95 | kwargs = {}
96 | if args.use_volta32:
97 | kwargs['slurm_constraint'] = 'volta32gb'
98 | if args.comment:
99 | kwargs['slurm_comment'] = args.comment
100 |
101 | executor.update_parameters(
102 | mem_gb=40 * num_gpus_per_node,
103 | gpus_per_node=num_gpus_per_node,
104 | tasks_per_node=num_gpus_per_node, # one task per GPU
105 | cpus_per_task=10,
106 | nodes=nodes,
107 | timeout_min=timeout_min, # max is 60 * 72
108 | # Below are cluster dependent parameters
109 | slurm_partition=partition,
110 | slurm_signal_delay_s=120,
111 | **kwargs
112 | )
113 |
114 | executor.update_parameters(name="deit")
115 |
116 | args.dist_url = get_init_file().as_uri()
117 | args.output_dir = args.job_dir
118 |
119 | trainer = Trainer(args)
120 | job = executor.submit(trainer)
121 |
122 | print("Submitted job_id:", job.job_id)
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
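Usage note: the parser above inherits every flag from main.get_args_parser(), so a launch is just the usual training command with the scheduler flags added, e.g. (illustrative values) python run_with_submitit.py --ngpus 8 --nodes 2 --partition learnfair followed by whatever model/data flags main.py expects. The checkpoint() hook lets Slurm requeue a timed-out job, resuming from checkpoint.pth in the output dir.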
--------------------------------------------------------------------------------
/fambav/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.distributed as dist
5 | import math
6 |
7 |
8 | class RASampler(torch.utils.data.Sampler):
9 | """Sampler that restricts data loading to a subset of the dataset for distributed,
10 | with repeated augmentation.
11 |     It ensures that each augmented version of a sample is visible to a
12 |     different process (GPU).
13 | Heavily based on torch.utils.data.DistributedSampler
14 | """
15 |
16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
17 | if num_replicas is None:
18 | if not dist.is_available():
19 | raise RuntimeError("Requires distributed package to be available")
20 | num_replicas = dist.get_world_size()
21 | if rank is None:
22 | if not dist.is_available():
23 | raise RuntimeError("Requires distributed package to be available")
24 | rank = dist.get_rank()
25 | if num_repeats < 1:
26 | raise ValueError("num_repeats should be greater than 0")
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.num_repeats = num_repeats
31 | self.epoch = 0
32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
33 | self.total_size = self.num_samples * self.num_replicas
34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
36 | self.shuffle = shuffle
37 |
38 | def __iter__(self):
39 | if self.shuffle:
40 | # deterministically shuffle based on epoch
41 | g = torch.Generator()
42 | g.manual_seed(self.epoch)
43 | indices = torch.randperm(len(self.dataset), generator=g)
44 | else:
45 | indices = torch.arange(start=0, end=len(self.dataset))
46 |
47 | # add extra samples to make it evenly divisible
48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
49 | padding_size: int = self.total_size - len(indices)
50 | if padding_size > 0:
51 | indices += indices[:padding_size]
52 | assert len(indices) == self.total_size
53 |
54 | # subsample
55 | indices = indices[self.rank:self.total_size:self.num_replicas]
56 | assert len(indices) == self.num_samples
57 |
58 | return iter(indices[:self.num_selected_samples])
59 |
60 | def __len__(self):
61 | return self.num_selected_samples
62 |
63 | def set_epoch(self, epoch):
64 | self.epoch = epoch
65 |
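A hypothetical single-process usage sketch (num_replicas and rank are passed explicitly so no torch.distributed init is needed; with 512 samples, 2 replicas, and num_repeats=3, each rank yields floor(512 // 256 * 256 / 2) = 256 indices per epoch):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(512).float())
sampler = RASampler(dataset, num_replicas=2, rank=0, shuffle=True, num_repeats=3)
sampler.set_epoch(0)                 # reseed the shuffle each epoch
loader = DataLoader(dataset, sampler=sampler, batch_size=64)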
--------------------------------------------------------------------------------
/fambav/scripts/mlruns/0/meta.yaml:
--------------------------------------------------------------------------------
1 | artifact_location: file:///users/PAS2490/marcusshen/Vim/vim/scripts/mlruns/0
2 | creation_time: 1718284550737
3 | experiment_id: '0'
4 | last_update_time: 1718284550737
5 | lifecycle_stage: active
6 | name: Default
7 |
--------------------------------------------------------------------------------
/fambav/scripts/vim-s-cifar-all-merge-r7.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-s-cifar-all-merge-r7    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-all-merge-r7.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-all-merge-r7_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=35:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_all_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-s-cifar-alternate-merge-r14.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-s-cifar-alternate-merge-r14    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-alternate-merge-r14.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-alternate-merge-r14_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=35:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_alternate_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy alternate --fusion-layer 5 --fusion-token 14
--------------------------------------------------------------------------------
/fambav/scripts/vim-s-cifar-lower-merge-r9-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-s-cifar-lower-merge-r9-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=35:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r9l4_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-s-cifar-lower-merge-r9-l5.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-s-cifar-lower-merge-r9-l5    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l5.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-lower-merge-r9-l5_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=35:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r9l5_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-s-cifar-non-merge.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-s-cifar-non-merge    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-non-merge.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-non-merge_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=35:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_non_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 5 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-s-cifar-upper-merge-r9-l18.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-s-cifar-upper-merge-r9-l18    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-upper-merge-r9-l18.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-s-cifar-upper-merge-r9-l18_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=35:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_upper_r9l18_vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 18 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-all-merge-r7-cls.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-all-merge-r7-cls    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7-cls.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7-cls_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_all_cls_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-all-merge-r7.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-all-merge-r7    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-all-merge-r7_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_all_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-alternate-merge-r14.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-alternate-merge-r14    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-alternate-merge-r14.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-alternate-merge-r14_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_alternate_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy alternate --fusion-layer 5 --fusion-token 14
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r1-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r1-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r1-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r1-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r1l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 1
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r10-l6.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r10-l6    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l6.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l6_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r10l6_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 6 --fusion-token 10
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r10-l7.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r10-l7    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l7.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r10-l7_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r10l7_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 7 --fusion-token 10
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r11-l8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r11-l8    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r11-l8.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r11-l8_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r11l8_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 8 --fusion-token 11
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r12-l10.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r12-l10    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l10.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l10_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r12l10_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 10 --fusion-token 12
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r12-l9.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r12-l9    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l9.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r12-l9_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r12l9_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 9 --fusion-token 12
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r13-l11.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r13-l11    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r13-l11.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r13-l11_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r13l11_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 11 --fusion-token 13
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r14-l12.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r14-l12    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r14-l12.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r14-l12_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r14l12_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 12 --fusion-token 14
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r15-l13.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r15-l13    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r15-l13.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r15-l13_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r15l13_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 13 --fusion-token 15
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r2-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r2-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r2-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r2-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r2l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 2
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r3-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r3-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r3-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r3-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r3l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 3
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r4-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r4-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r4-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r4-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r4l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 4
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r5-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r5-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r5-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r5-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r5l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 5
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r6-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r6-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r6-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r6-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r6l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 6
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r7-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r7-l4    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r7-l4.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r7-l4_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r7l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r8-l3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r8-l3    # Job name
4 | #SBATCH --account=PAS2490    # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l3.log    # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l3_error.log    # Error log file
7 | #SBATCH --nodes=1    # Number of nodes
8 | #SBATCH --ntasks-per-node=1    # Tasks per node
9 | #SBATCH --cpus-per-task=4    # CPU cores per task
10 | #SBATCH --gpus-per-node=4    # GPUs per node
11 | #SBATCH --mem=80G    # Memory limit
12 | #SBATCH --time=6:00:00    # Job time limit
13 |
14 | # Run command or script: wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r8l3_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 3 --fusion-token 8
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r8-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r8-l4 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l4.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r8-l4_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=6:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r8l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 8
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r9-l4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r9-l4 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l4.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l4_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=6:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_r9l4_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 4 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-lower-merge-r9-l5.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-lower-merge-r9-l5 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l5.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-lower-merge-r9-l5_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=6:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_lower_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-non-merge-visual.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-non-merge-visual # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge-visual.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge-visual_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=6:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/visual_merge_cifar_non_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 5 --fusion-token 7 --visualize
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-non-merge.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-non-merge # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-non-merge_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=6:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_non_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 5 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-cifar-upper-merge-r9-l18.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-cifar-upper-merge-r9-l18 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-upper-merge-r9-l18.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-cifar-upper-merge-r9-l18_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=6:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ../output/merge_cifar_upper_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 18 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-imagenet-all-merge-r7.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-imagenet-all-merge-r7-2 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-all-merge-r7-2.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-all-merge-r7-2_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=120:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_all_2_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy all --fusion-layer 5 --fusion-token 7
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-imagenet-alternate-merge-r14.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-imagenet-alternate-merge-r14-2 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-alternate-merge-r14-2.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-alternate-merge-r14-2_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=120:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_alternate_2_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy alternate --fusion-layer 5 --fusion-token 14
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-imagenet-lower-merge-r9-l5.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-imagenet-lower-merge-r9-l5 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-lower-merge-r9-l5.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-lower-merge-r9-l5_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=120:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_lower_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy lower --fusion-layer 5 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-imagenet-non-merge.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-imagenet-non-merge # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-non-merge.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-non-merge_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=120:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_non_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy no --fusion-layer 8 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/scripts/vim-t-imagenet-upper-merge-r9-l18.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim-t-imagenet-upper-merge-r9-l18 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-upper-merge-r9-l18.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim-t-imagenet-upper-merge-r9-l18_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=120:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env ../main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set IMNET --data-path /fs/scratch/PAS2490/imagenet --output_dir ../output/merge_imnet_upper_vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp --fusion-strategy upper --fusion-layer 18 --fusion-token 9
--------------------------------------------------------------------------------
/fambav/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=vim_1 # Job name
4 | #SBATCH --account=PAS2490 # Project ID
5 | #SBATCH --output=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim_1.log # Output log file
6 | #SBATCH --error=/users/PAS2490/marcusshen/Vim/output_logs_vim/vim_1_error.log # Error log file
7 | #SBATCH --nodes=1 # Number of nodes
8 | #SBATCH --ntasks-per-node=1 # Tasks per node
9 | #SBATCH --cpus-per-task=4 # CPU cores per task
10 | #SBATCH --gpus-per-node=4 # GPUs per node
11 | #SBATCH --mem=80G # Memory limit
12 | #SBATCH --time=04:00:00 # Job time limit
13 | 
14 | # Command or script to run, e.g. wget https://repo.anaconda.com/archive/Anaconda3-2023.07-2-Linux-x86_64.sh
15 | source $HOME/miniconda3/bin/activate /users/PAS2490/marcusshen/miniconda3/envs/vim
16 | # module load cuda
17 | export CUDA_VISIBLE_DEVICES=0,1,2,3
18 |
19 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --batch-size 128 --drop-path 0.0 --weight-decay 0.1 --num_workers 25 --data-set CIFAR --data-path ./datasets/cifar-100-python --output_dir ./output/vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 --no_amp
20 |
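21 | # Usage sketch (editorial illustration, not part of the original script):
22 | # these SLURM scripts are submitted with sbatch and monitored with squeue.
23 | # test.sh calls main.py via a relative path, so presumably it is submitted
24 | # from the fambav/ directory (the scripts under scripts/ use ../main.py).
25 | #   sbatch test.sh
26 | #   squeue -u $USER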
--------------------------------------------------------------------------------
/fambav/token_merge.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 |
8 | import math
9 | from typing import Callable, Tuple
10 |
11 | import torch
12 |
13 |
14 | def do_nothing(x, mode=None):
15 | return x
16 |
17 |
18 | def bipartite_soft_matching(
19 | metric: torch.Tensor,
20 | r: int,
21 | class_token: bool = False,
22 | distill_token: bool = False,
23 | ) -> Tuple[Callable, Callable]:
24 | """
25 | Applies ToMe with a balanced matching set (50%, 50%).
26 |
27 | Input size is [batch, tokens, channels].
28 | r indicates the number of tokens to remove (max 50% of tokens).
29 |
30 | Extra args:
31 | - class_token: Whether or not there's a class token.
32 | - distill_token: Whether or not there's also a distillation token.
33 |
34 | When enabled, the class token and distillation tokens won't get merged.
35 | """
36 | protected = 0
37 | if class_token:
38 | protected += 1
39 | if distill_token:
40 | protected += 1
41 |
42 | # We can only reduce by a maximum of 50% tokens
43 | t = metric.shape[1]
44 | r = min(r, (t - protected) // 2)
45 |
46 | if r <= 0:
47 | return do_nothing, do_nothing
48 |
49 | with torch.no_grad():
50 | metric = metric / metric.norm(dim=-1, keepdim=True)
51 | a, b = metric[..., ::2, :], metric[..., 1::2, :]
52 | scores = a @ b.transpose(-1, -2)
53 |
54 | if class_token:
55 | scores[..., 0, :] = -math.inf
56 | if distill_token:
57 | scores[..., :, 0] = -math.inf
58 |
59 | node_max, node_idx = scores.max(dim=-1)
60 | edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
61 |
62 | unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
63 | src_idx = edge_idx[..., :r, :] # Merged Tokens
64 | dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
65 |
66 | if class_token:
67 | # Sort to ensure the class token is at the start
68 | unm_idx = unm_idx.sort(dim=1)[0]
69 |
70 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
71 | src, dst = x[..., ::2, :], x[..., 1::2, :]
72 | n, t1, c = src.shape
73 | unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
74 | src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
75 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
76 |
77 | if distill_token:
78 | return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
79 | else:
80 | return torch.cat([unm, dst], dim=1)
81 |
82 | def unmerge(x: torch.Tensor) -> torch.Tensor:
83 | unm_len = unm_idx.shape[1]
84 | unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
85 | n, _, c = unm.shape
86 |
87 | src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))
88 |
89 | out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)
90 |
91 | out[..., 1::2, :] = dst
92 | out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
93 | out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)
94 |
95 | return out
96 |
97 | return merge, unmerge
98 |
99 |
100 | def kth_bipartite_soft_matching(
101 | metric: torch.Tensor, k: int
102 | ) -> Tuple[Callable, Callable]:
103 | """
104 | Applies ToMe with the two sets as (every kth element, the rest).
105 | If n is the number of tokens, the resulting number of tokens will be n // k.
106 | 
107 | Input size is [batch, tokens, channels].
108 | k indicates the stride for the first set.
109 | k = 2 is equivalent to regular bipartite_soft_matching with r = 0.5 * n
110 | """
111 | if k <= 1:
112 | return do_nothing, do_nothing
113 |
114 | def split(x):
115 | t_rnd = (x.shape[1] // k) * k
116 | x = x[:, :t_rnd, :].view(x.shape[0], -1, k, x.shape[2])
117 | a, b = (
118 | x[:, :, : (k - 1), :].contiguous().view(x.shape[0], -1, x.shape[-1]),
119 | x[:, :, (k - 1), :],
120 | )
121 | return a, b
122 |
123 | with torch.no_grad():
124 | metric = metric / metric.norm(dim=-1, keepdim=True)
125 | a, b = split(metric)
126 | r = a.shape[1]
127 | scores = a @ b.transpose(-1, -2)
128 |
129 | _, dst_idx = scores.max(dim=-1)
130 | dst_idx = dst_idx[..., None]
131 |
132 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
133 | src, dst = split(x)
134 | n, _, c = src.shape
135 | dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
136 |
137 | return dst
138 |
139 | def unmerge(x: torch.Tensor) -> torch.Tensor:
140 | n, _, c = x.shape
141 | dst = x
142 |
143 | src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c)).to(x.dtype)
144 |
145 | src = src.view(n, -1, (k - 1), c)
146 | dst = dst.view(n, -1, 1, c)
147 |
148 | out = torch.cat([src, dst], dim=-2)
149 | out = out.contiguous().view(n, -1, c)
150 |
151 | return out
152 |
153 | return merge, unmerge
154 |
155 |
156 | def random_bipartite_soft_matching(
157 | metric: torch.Tensor, r: int
158 | ) -> Tuple[Callable, Callable]:
159 | """
160 | Applies ToMe with the two sets as (r chosen randomly, the rest).
161 | Input size is [batch, tokens, channels].
162 |
163 | This will reduce the number of tokens by r.
164 | """
165 | if r <= 0:
166 | return do_nothing, do_nothing
167 |
168 | with torch.no_grad():
169 | B, N, _ = metric.shape
170 | rand_idx = torch.rand(B, N, 1, device=metric.device).argsort(dim=1)
171 |
172 | a_idx = rand_idx[:, :r, :]
173 | b_idx = rand_idx[:, r:, :]
174 |
175 | def split(x):
176 | C = x.shape[-1]
177 | a = x.gather(dim=1, index=a_idx.expand(B, r, C))
178 | b = x.gather(dim=1, index=b_idx.expand(B, N - r, C))
179 | return a, b
180 |
181 | metric = metric / metric.norm(dim=-1, keepdim=True)
182 | a, b = split(metric)
183 | scores = a @ b.transpose(-1, -2)
184 |
185 | _, dst_idx = scores.max(dim=-1)
186 | dst_idx = dst_idx[..., None]
187 |
188 | def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
189 | src, dst = split(x)
190 | C = src.shape[-1]
191 | dst = dst.scatter_reduce(-2, dst_idx.expand(B, r, C), src, reduce=mode)
192 |
193 | return dst
194 |
195 | def unmerge(x: torch.Tensor) -> torch.Tensor:
196 | C = x.shape[-1]
197 | dst = x
198 | src = dst.gather(dim=-2, index=dst_idx.expand(B, r, C))
199 |
200 | out = torch.zeros(B, N, C, device=x.device, dtype=x.dtype)
201 |
202 | out.scatter_(dim=-2, index=a_idx.expand(B, r, C), src=src)
203 | out.scatter_(dim=-2, index=b_idx.expand(B, N - r, C), src=dst)
204 |
205 | return out
206 |
207 | return merge, unmerge
208 |
209 |
210 | def merge_wavg(
211 | merge: Callable, x: torch.Tensor, size: torch.Tensor = None
212 | ) -> Tuple[torch.Tensor, torch.Tensor]:
213 | """
214 | Applies the merge function by taking a weighted average based on token size.
215 | Returns the merged tensor and the new token sizes.
216 | """
217 | if size is None:
218 | size = torch.ones_like(x[..., 0, None])
219 |
220 | x = merge(x * size, mode="sum")
221 | size = merge(size, mode="sum")
222 |
223 | x = x / size
224 | return x, size
225 |
226 |
227 | def merge_source(
228 | merge: Callable, x: torch.Tensor, source: torch.Tensor = None
229 | ) -> torch.Tensor:
230 | """
231 | For source tracking. Source is an adjacency matrix between the initial tokens and final merged groups.
232 | x is used to find out how many tokens there are in case the source is None.
233 | """
234 | if source is None:
235 | n, t, _ = x.shape
236 | source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)
237 |
238 | source = merge(source, mode="amax")
239 | return source
--------------------------------------------------------------------------------
/fambav/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Misc functions, including distributed helpers.
5 |
6 | Mostly copy-paste from torchvision references.
7 | """
8 | import io
9 | import os
10 | import time
11 | from collections import defaultdict, deque
12 | import datetime
13 |
14 | import torch
15 | import torch.distributed as dist
16 |
17 |
18 | class SmoothedValue(object):
19 | """Track a series of values and provide access to smoothed values over a
20 | window or the global series average.
21 | """
22 |
23 | def __init__(self, window_size=20, fmt=None):
24 | if fmt is None:
25 | fmt = "{median:.4f} ({global_avg:.4f})"
26 | self.deque = deque(maxlen=window_size)
27 | self.total = 0.0
28 | self.count = 0
29 | self.fmt = fmt
30 |
31 | def update(self, value, n=1):
32 | self.deque.append(value)
33 | self.count += n
34 | self.total += value * n
35 |
36 | def synchronize_between_processes(self):
37 | """
38 | Warning: does not synchronize the deque!
39 | """
40 | if not is_dist_avail_and_initialized():
41 | return
42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
43 | dist.barrier()
44 | dist.all_reduce(t)
45 | t = t.tolist()
46 | self.count = int(t[0])
47 | self.total = t[1]
48 |
49 | @property
50 | def median(self):
51 | d = torch.tensor(list(self.deque))
52 | return d.median().item()
53 |
54 | @property
55 | def avg(self):
56 | d = torch.tensor(list(self.deque), dtype=torch.float32)
57 | return d.mean().item()
58 |
59 | @property
60 | def global_avg(self):
61 | return self.total / self.count
62 |
63 | @property
64 | def max(self):
65 | return max(self.deque)
66 |
67 | @property
68 | def value(self):
69 | return self.deque[-1]
70 |
71 | def __str__(self):
72 | return self.fmt.format(
73 | median=self.median,
74 | avg=self.avg,
75 | global_avg=self.global_avg,
76 | max=self.max,
77 | value=self.value)
78 |
79 |
80 | class MetricLogger(object):
81 | def __init__(self, delimiter="\t"):
82 | self.meters = defaultdict(SmoothedValue)
83 | self.delimiter = delimiter
84 |
85 | def update(self, **kwargs):
86 | for k, v in kwargs.items():
87 | if isinstance(v, torch.Tensor):
88 | v = v.item()
89 | assert isinstance(v, (float, int))
90 | self.meters[k].update(v)
91 |
92 | def __getattr__(self, attr):
93 | if attr in self.meters:
94 | return self.meters[attr]
95 | if attr in self.__dict__:
96 | return self.__dict__[attr]
97 | raise AttributeError("'{}' object has no attribute '{}'".format(
98 | type(self).__name__, attr))
99 |
100 | def __str__(self):
101 | loss_str = []
102 | for name, meter in self.meters.items():
103 | loss_str.append(
104 | "{}: {}".format(name, str(meter))
105 | )
106 | return self.delimiter.join(loss_str)
107 |
108 | def synchronize_between_processes(self):
109 | for meter in self.meters.values():
110 | meter.synchronize_between_processes()
111 |
112 | def add_meter(self, name, meter):
113 | self.meters[name] = meter
114 |
115 | def log_every(self, iterable, print_freq, header=None):
116 | i = 0
117 | if not header:
118 | header = ''
119 | start_time = time.time()
120 | end = time.time()
121 | iter_time = SmoothedValue(fmt='{avg:.4f}')
122 | data_time = SmoothedValue(fmt='{avg:.4f}')
123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
124 | log_msg = [
125 | header,
126 | '[{0' + space_fmt + '}/{1}]',
127 | 'eta: {eta}',
128 | '{meters}',
129 | 'time: {time}',
130 | 'data: {data}'
131 | ]
132 | if torch.cuda.is_available():
133 | log_msg.append('max mem: {memory:.0f}')
134 | log_msg = self.delimiter.join(log_msg)
135 | MB = 1024.0 * 1024.0
136 | for obj in iterable:
137 | data_time.update(time.time() - end)
138 | yield obj
139 | iter_time.update(time.time() - end)
140 | if i % print_freq == 0 or i == len(iterable) - 1:
141 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
143 | if torch.cuda.is_available():
144 | print(log_msg.format(
145 | i, len(iterable), eta=eta_string,
146 | meters=str(self),
147 | time=str(iter_time), data=str(data_time),
148 | memory=torch.cuda.max_memory_allocated() / MB))
149 | else:
150 | print(log_msg.format(
151 | i, len(iterable), eta=eta_string,
152 | meters=str(self),
153 | time=str(iter_time), data=str(data_time)))
154 | i += 1
155 | end = time.time()
156 | total_time = time.time() - start_time
157 | total_time_str = f"{total_time:.6f}" # keep six decimal places, i.e. microsecond precision
158 | print(f'{header} Total time: {total_time_str} s ({total_time / len(iterable):.4f} s / it)')
159 | # total_time_str = str(datetime.timedelta(seconds=int(total_time)))
160 | # print('{} Total time: {} ({:.4f} s / it)'.format(
161 | # header, total_time_str, total_time / len(iterable)))
162 |
163 |
164 |
165 | def _load_checkpoint_for_ema(model_ema, checkpoint):
166 | """
167 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
168 | """
169 | mem_file = io.BytesIO()
170 | torch.save({'state_dict_ema':checkpoint}, mem_file)
171 | mem_file.seek(0)
172 | model_ema._load_checkpoint(mem_file)
173 |
174 |
175 | def setup_for_distributed(is_master):
176 | """
177 | This function disables printing when not in master process
178 | """
179 | import builtins as __builtin__
180 | builtin_print = __builtin__.print
181 |
182 | def print(*args, **kwargs):
183 | force = kwargs.pop('force', False)
184 | if is_master or force:
185 | builtin_print(*args, **kwargs)
186 |
187 | __builtin__.print = print
188 |
189 |
190 | def is_dist_avail_and_initialized():
191 | if not dist.is_available():
192 | return False
193 | if not dist.is_initialized():
194 | return False
195 | return True
196 |
197 |
198 | def get_world_size():
199 | if not is_dist_avail_and_initialized():
200 | return 1
201 | return dist.get_world_size()
202 |
203 |
204 | def get_rank():
205 | if not is_dist_avail_and_initialized():
206 | return 0
207 | return dist.get_rank()
208 |
209 |
210 | def is_main_process():
211 | return get_rank() == 0
212 |
213 |
214 | def save_on_master(*args, **kwargs):
215 | if is_main_process():
216 | torch.save(*args, **kwargs)
217 |
218 |
219 | def init_distributed_mode(args):
220 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
221 | args.rank = int(os.environ["RANK"])
222 | args.world_size = int(os.environ['WORLD_SIZE'])
223 | args.gpu = int(os.environ['LOCAL_RANK'])
224 | elif 'SLURM_PROCID' in os.environ:
225 | args.rank = int(os.environ['SLURM_PROCID'])
226 | args.gpu = args.rank % torch.cuda.device_count()
227 | else:
228 | print('Not using distributed mode')
229 | args.distributed = False
230 | return
231 |
232 | args.distributed = True
233 |
234 | torch.cuda.set_device(args.gpu)
235 | args.dist_backend = 'nccl'
236 | print('| distributed init (rank {}): {}'.format(
237 | args.rank, args.dist_url), flush=True)
238 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
239 | world_size=args.world_size, rank=args.rank)
240 | torch.distributed.barrier()
241 | setup_for_distributed(args.rank == 0)
242 |
243 |
244 | # if 'pos_embed' in state_dict:
245 | def interpolate_pos_embed(model, state_dict):
246 | pos_embed_checkpoint = state_dict['pos_embed']
247 | embedding_size = pos_embed_checkpoint.shape[-1]
248 | num_patches = model.patch_embed.num_patches
249 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
250 | # height (== width) for the checkpoint position embedding
251 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
252 | # height (== width) for the new position embedding
253 | new_size = int(num_patches ** 0.5)
254 | # class_token and dist_token are kept unchanged
255 | if orig_size != new_size:
256 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
257 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
258 | # only the position tokens are interpolated
259 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
260 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
261 | pos_tokens = torch.nn.functional.interpolate(
262 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
263 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
264 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
265 | state_dict['pos_embed'] = new_pos_embed
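266 | 
267 | 
268 | # Usage sketch (editorial illustration, not part of the original module):
269 | # MetricLogger is meant to wrap a training loop; the metric values below
270 | # are placeholders rather than real training statistics.
271 | if __name__ == "__main__":
272 |     logger = MetricLogger(delimiter="  ")
273 |     for step in logger.log_every(range(100), 25, header='Epoch: [0]'):
274 |         logger.update(loss=0.5, lr=1e-4)
275 |     print("Averaged stats:", logger)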
--------------------------------------------------------------------------------
/fambav/vim_requirements.txt:
--------------------------------------------------------------------------------
1 | addict==2.4.0
2 | aiohttp==3.9.1
3 | aiosignal==1.3.1
4 | alembic==1.13.0
5 | async-timeout==4.0.3
6 | attrs==23.1.0
7 | blinker==1.7.0
8 | # causal-conv1d @ file:///home/zhulianghui/VisionProjects/mamba/lib/causal_conv1d-1.0.0%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=79a4bab633ebff031e615d5e8ba396b0dc0c046f4406980ee238fb86a9090038
9 | certifi==2023.11.17
10 | charset-normalizer==3.3.2
11 | click==8.1.7
12 | cloudpickle==3.0.0
13 | contourpy==1.2.0
14 | cycler==0.12.1
15 | databricks-cli==0.18.0
16 | datasets==2.15.0
17 | dill==0.3.7
18 | docker==6.1.3
19 | einops==0.7.0
20 | entrypoints==0.4
21 | filelock==3.13.1
22 | Flask==3.0.0
23 | fonttools==4.46.0
24 | frozenlist==1.4.0
25 | fsspec==2023.10.0
26 | gitdb==4.0.11
27 | GitPython==3.1.40
28 | greenlet==3.0.2
29 | gunicorn==21.2.0
30 | huggingface-hub==0.19.4
31 | idna==3.6
32 | importlib-metadata==7.0.0
33 | itsdangerous==2.1.2
34 | Jinja2==3.1.2
35 | joblib==1.3.2
36 | kiwisolver==1.4.5
37 | Mako==1.3.0
38 | # mamba-ssm @ file:///home/zhulianghui/VisionProjects/mamba/lib/mamba_ssm-1.0.1%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=71ad1b1eafb05a6e8a41fd82e046fe85511d6378fa3a583e55215b6aa1d65ab9
39 | Markdown==3.5.1
40 | MarkupSafe==2.1.3
41 | matplotlib==3.8.2
42 | mlflow==2.9.1
43 | mmcv==1.3.8
44 | mmsegmentation==0.14.1
45 | mpmath==1.3.0
46 | multidict==6.0.4
47 | multiprocess==0.70.15
48 | networkx==3.2.1
49 | ninja==1.11.1.1
50 | numpy==1.26.2
51 | # nvidia-cublas-cu12==12.1.3.1
52 | # nvidia-cuda-cupti-cu12==12.1.105
53 | # nvidia-cuda-nvrtc-cu12==12.1.105
54 | # nvidia-cuda-runtime-cu12==12.1.105
55 | # nvidia-cudnn-cu12==8.9.2.26
56 | # nvidia-cufft-cu12==11.0.2.54
57 | # nvidia-curand-cu12==10.3.2.106
58 | # nvidia-cusolver-cu12==11.4.5.107
59 | # nvidia-cusparse-cu12==12.1.0.106
60 | # nvidia-nccl-cu12==2.18.1
61 | # nvidia-nvjitlink-cu12==12.3.101
62 | # nvidia-nvtx-cu12==12.1.105
63 | oauthlib==3.2.2
64 | opencv-python==4.8.1.78
65 | packaging==23.2
66 | pandas==2.1.3
67 | Pillow==10.1.0
68 | platformdirs==4.1.0
69 | prettytable==3.9.0
70 | protobuf==4.25.1
71 | pyarrow==14.0.1
72 | pyarrow-hotfix==0.6
73 | PyJWT==2.8.0
74 | pyparsing==3.1.1
75 | python-dateutil==2.8.2
76 | python-hostlist==1.23.0
77 | pytz==2023.3.post1
78 | PyYAML==6.0.1
79 | querystring-parser==1.2.4
80 | regex==2023.10.3
81 | requests==2.31.0
82 | safetensors==0.4.1
83 | scikit-learn==1.3.2
84 | scipy==1.11.4
85 | six==1.16.0
86 | smmap==5.0.1
87 | SQLAlchemy==2.0.23
88 | sqlparse==0.4.4
89 | sympy==1.12
90 | tabulate==0.9.0
91 | threadpoolctl==3.2.0
92 | timm==0.4.12
93 | tokenizers==0.15.0
94 | tomli==2.0.1
95 | # torch==2.1.1+cu118
96 | # torchvision==0.16.1+cu118
97 | tqdm==4.66.1
98 | transformers==4.35.2
99 | triton==2.1.0
100 | typing_extensions==4.8.0
101 | tzdata==2023.3
102 | urllib3==2.1.0
103 | wcwidth==0.2.12
104 | websocket-client==1.7.0
105 | Werkzeug==3.0.1
106 | xxhash==3.4.1
107 | yapf==0.40.2
108 | yarl==1.9.4
109 | zipp==3.17.0
110 |
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_0.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_1.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_10.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_11.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_12.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_13.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_14.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_15.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_16.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_17.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_18.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_19.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_2.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_20.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_20.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_21.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_22.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_23.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_3.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_4.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_5.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_6.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_7.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_8.png
--------------------------------------------------------------------------------
/fambav/visualization/cosine_similarity_layer_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/cosine_similarity_layer_9.png
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113249.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113249.npz
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113250.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113250.npz
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113251.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113251.npz
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113252.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113252.npz
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113253.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113253.npz
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113254.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113254.npz
--------------------------------------------------------------------------------
/fambav/visualization/hidden_states_20240714_113255.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/fambav/visualization/hidden_states_20240714_113255.npz
--------------------------------------------------------------------------------
/fambav/visualization/npz_viewer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | data = np.load('./hidden_states_20240714_053015.npz')
4 | for key in data.files:
5 | print(data[key])
--------------------------------------------------------------------------------
/fambav/visualization/test_visualize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib
4 | matplotlib.use('Agg')
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 | from sklearn.metrics.pairwise import cosine_similarity
8 |
9 | def visualize_cosine_similarity_from_npz(file_path):
10 | print("Start to visualize cosine similarity from .npz file")
11 |
12 | # Load the .npz file
13 | loaded_data = np.load(file_path)
14 | 
15 | # Iterate over all layers
16 | for key in loaded_data.files:
17 | if key.startswith('layer_'):
18 | hidden_states = loaded_data[key]
19 | 
20 | if isinstance(hidden_states, np.ndarray):
21 | hidden_states = torch.from_numpy(hidden_states)
22 | 
23 | # If the data lives on a GPU, make sure to move it to the CPU
24 | if hidden_states.device.type != 'cpu':
25 | hidden_states = hidden_states.cpu()
26 | 
27 | # Convert to a NumPy array (if not already one)
28 | hidden_states = hidden_states.numpy()
29 | 
30 | # Compute the cosine similarity matrix
31 | cosine_sim_matrix = cosine_similarity(hidden_states)
32 |
33 | print(f"cosine_sim_matrix for {key} success")
34 |
35 | plt.figure(figsize=(12, 10)) # enlarge the figure
36 | sns.heatmap(cosine_sim_matrix, cmap='Blues')
37 | 
38 | # Set the title and labels, with larger font sizes
39 | plt.suptitle(f"Cosine Similarity Matrix of Hidden States for {key}", fontsize=26, y=0.98)
40 | plt.xlabel("Hidden State Index", fontsize=26)
41 | plt.ylabel("Hidden State Index", fontsize=26)
42 | 
43 | # Compute new tick positions and labels
44 | n = cosine_sim_matrix.shape[0] # the matrix is assumed to be square
45 | ticks = np.arange(0, n, 8)
46 | tick_labels = ticks
47 | 
48 | # Apply the new ticks
49 | plt.xticks(ticks, tick_labels, fontsize=16)
50 | plt.yticks(ticks, tick_labels, fontsize=16)
51 | 
52 | # Adjust the colorbar font size
53 | cbar = plt.gcf().axes[-1]
54 | cbar.tick_params(labelsize=12)
55 | 
56 | plt.tight_layout() # auto-adjust subplot parameters to fill the figure area
57 | plt.savefig(f"./cosine_similarity_{key}.png", dpi=600) # higher DPI for better image quality
58 | print(f"Heatmap for {key} saved.")
59 |
60 | # Example .npz file path
61 | file_path = './hidden_states_20240714_113251.npz'
62 | 
63 | # Run the visualization
64 | visualize_cosine_similarity_from_npz(file_path)
--------------------------------------------------------------------------------
/mamba-1p1p1/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will:
2 | # - Create a new Github release
3 | # - Build wheels for supported architectures
4 | # - Deploy the wheels to the Github release
5 | # - Release the static code to PyPI
6 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
7 |
8 | name: Build wheels and deploy
9 |
10 | on:
11 | create:
12 | tags:
13 | - v*
14 |
15 | jobs:
16 |
17 | setup_release:
18 | name: Create Release
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: Get the tag version
22 | id: extract_branch
23 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
24 | shell: bash
25 |
26 | - name: Create Release
27 | id: create_release
28 | uses: actions/create-release@v1
29 | env:
30 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
31 | with:
32 | tag_name: ${{ steps.extract_branch.outputs.branch }}
33 | release_name: ${{ steps.extract_branch.outputs.branch }}
34 |
35 | build_wheels:
36 | name: Build Wheel
37 | needs: setup_release
38 | runs-on: ${{ matrix.os }}
39 |
40 | strategy:
41 | fail-fast: false
42 | matrix:
43 | # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
44 | # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
45 | os: [ubuntu-20.04]
46 | python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
47 | torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231106']
48 | cuda-version: ['11.8.0', '12.2.0']
49 | # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
50 | # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
51 | # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
52 | # when building without C++11 ABI and using it on nvcr images.
53 | cxx11_abi: ['FALSE', 'TRUE']
54 | exclude:
55 | # Pytorch <= 1.12 does not support Python 3.11
56 | - torch-version: '1.12.1'
57 | python-version: '3.11'
58 | # Pytorch >= 2.0 only supports Python >= 3.8
59 | - torch-version: '2.0.1'
60 | python-version: '3.7'
61 | - torch-version: '2.1.1'
62 | python-version: '3.7'
63 | - torch-version: '2.2.0.dev20231106'
64 | python-version: '3.7'
65 | # Pytorch <= 2.0 only supports CUDA <= 11.8
66 | - torch-version: '1.12.1'
67 | cuda-version: '12.2.0'
68 | - torch-version: '1.13.1'
69 | cuda-version: '12.2.0'
70 | - torch-version: '2.0.1'
71 | cuda-version: '12.2.0'
72 |
73 | steps:
74 | - name: Checkout
75 | uses: actions/checkout@v3
76 |
77 | - name: Set up Python
78 | uses: actions/setup-python@v4
79 | with:
80 | python-version: ${{ matrix.python-version }}
81 |
82 | - name: Set CUDA and PyTorch versions
83 | run: |
84 | echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
85 | echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
86 |
87 | - name: Free up disk space
88 | if: ${{ runner.os == 'Linux' }}
89 | # https://github.com/easimon/maximize-build-space/blob/master/action.yml
90 | # https://github.com/easimon/maximize-build-space/tree/test-report
91 | run: |
92 | sudo rm -rf /usr/share/dotnet
93 | sudo rm -rf /opt/ghc
94 | sudo rm -rf /opt/hostedtoolcache/CodeQL
95 |
96 | - name: Set up swap space
97 | if: runner.os == 'Linux'
98 | uses: pierotofy/set-swap-space@v1.0
99 | with:
100 | swap-size-gb: 10
101 |
102 | - name: Install CUDA ${{ matrix.cuda-version }}
103 | if: ${{ matrix.cuda-version != 'cpu' }}
104 | uses: Jimver/cuda-toolkit@v0.2.11
105 | id: cuda-toolkit
106 | with:
107 | cuda: ${{ matrix.cuda-version }}
108 | linux-local-args: '["--toolkit"]'
109 | # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
110 | # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
111 | method: 'network'
112 | # We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
113 | # not just nvcc
114 | # sub-packages: '["nvcc"]'
115 |
116 | - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
117 | run: |
118 | pip install --upgrade pip
119 | # If we don't install before installing Pytorch, we get error for torch 2.0.1
120 | # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
121 | pip install lit
122 | # We want to figure out the CUDA version to download pytorch
123 | # e.g. the system CUDA version may be 11.7, but torch==1.12 wheels only go up to cu116, so we download that one
124 | # This code is ugly, maybe there's a better way to do this.
125 | export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
126 | if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
127 | pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
128 | else
129 | pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
130 | fi
131 | nvcc --version
132 | python --version
133 | python -c "import torch; print('PyTorch:', torch.__version__)"
134 | python -c "import torch; print('CUDA:', torch.version.cuda)"
135 | python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
136 | shell:
137 | bash
138 |
139 | - name: Build wheel
140 | run: |
141 | # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
142 | # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
143 | # However this still fails so I'm using a newer version of setuptools
144 | pip install setuptools==68.0.0
145 | pip install ninja packaging wheel
146 | export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
147 | export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
148 | # Limit MAX_JOBS otherwise the github runner goes OOM
149 | MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi }} python setup.py bdist_wheel --dist-dir=dist
150 | tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
151 | wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
152 | ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
153 | echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
154 |
155 | - name: Log Built Wheels
156 | run: |
157 | ls dist
158 |
159 | - name: Get the tag version
160 | id: extract_branch
161 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
162 |
163 | - name: Get Release with tag
164 | id: get_current_release
165 | uses: joutvhu/get-release@v1
166 | with:
167 | tag_name: ${{ steps.extract_branch.outputs.branch }}
168 | env:
169 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
170 |
171 | - name: Upload Release Asset
172 | id: upload_release_asset
173 | uses: actions/upload-release-asset@v1
174 | env:
175 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
176 | with:
177 | upload_url: ${{ steps.get_current_release.outputs.upload_url }}
178 | asset_path: ./dist/${{env.wheel_name}}
179 | asset_name: ${{env.wheel_name}}
180 | asset_content_type: application/*
181 |
182 | publish_package:
183 | name: Publish package
184 | needs: [build_wheels]
185 |
186 | runs-on: ubuntu-latest
187 |
188 | steps:
189 | - uses: actions/checkout@v3
190 |
191 | - uses: actions/setup-python@v4
192 | with:
193 | python-version: '3.10'
194 |
195 | - name: Install dependencies
196 | run: |
197 | pip install ninja packaging setuptools wheel twine
198 | # We don't want to download anything CUDA-related here
199 | pip install torch --index-url https://download.pytorch.org/whl/cpu
200 |
201 | - name: Build core package
202 | env:
203 | MAMBA_SKIP_CUDA_BUILD: "TRUE"
204 | run: |
205 | python setup.py sdist --dist-dir=dist
206 |
207 | - name: Deploy
208 | env:
209 | TWINE_USERNAME: "__token__"
210 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
211 | run: |
212 | python -m twine upload dist/*
213 |
--------------------------------------------------------------------------------
/mamba-1p1p1/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__/
2 | *.egg-info/
3 | build/
4 | **.so
5 |
--------------------------------------------------------------------------------
/mamba-1p1p1/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/lm-evaluation-harness"]
2 | path = 3rdparty/lm-evaluation-harness
3 | url = https://github.com/EleutherAI/lm-evaluation-harness/
4 |
--------------------------------------------------------------------------------
/mamba-1p1p1/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, tri@tridao.me
2 | Albert Gu, agu@andrew.cmu.edu
3 |
--------------------------------------------------------------------------------
/mamba-1p1p1/README.md:
--------------------------------------------------------------------------------
1 | # Mamba
2 |
3 | 
4 | > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
5 | > Albert Gu*, Tri Dao*\
6 | > Paper: https://arxiv.org/abs/2312.00752
7 |
8 | ## About
9 |
10 | Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
11 | It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
12 | with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
13 |
14 | ## Installation
15 |
16 | - `pip install causal-conv1d>=1.1.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
17 | - `pip install mamba-ssm`: the core Mamba package.
18 |
19 | It can also be built from source with `pip install .` from this repository.
20 |
21 | If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.
22 |
23 | Other requirements:
24 | - Linux
25 | - NVIDIA GPU
26 | - PyTorch 1.12+
27 | - CUDA 11.6+
28 |
29 | ## Usage
30 |
31 | We expose several levels of interface with the Mamba model.
32 |
33 | ### Selective SSM
34 |
35 | Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).
36 |
37 | Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
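
A minimal sketch of calling the op directly is below; the shapes follow the docstring of `selective_scan_fn` (`u`, `delta`: `(B, D, L)`; `A`: `(D, N)`; input-dependent `B`, `C`: `(B, N, L)`), and the sizes are illustrative.
```
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, dstate, seqlen = 2, 8, 16, 64
u = torch.randn(batch, dim, seqlen, device="cuda")     # input sequence (B, D, L)
delta = torch.rand(batch, dim, seqlen, device="cuda")  # per-step discretization (B, D, L)
A = -torch.rand(dim, dstate, device="cuda")            # state matrix (D, N); negative for stability
B = torch.randn(batch, dstate, seqlen, device="cuda")  # input-dependent B (B, N, L)
C = torch.randn(batch, dstate, seqlen, device="cuda")  # input-dependent C (B, N, L)
D = torch.randn(dim, device="cuda")                    # skip connection (D,)
y = selective_scan_fn(u, delta, A, B, C, D=D, delta_softplus=True)
assert y.shape == u.shape
```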
38 |
39 | ### Mamba Block
40 |
41 | The main module of this repository is the Mamba architecture block wrapping the selective SSM.
42 |
43 | Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
44 |
45 | Usage:
46 | ```
47 | import torch
48 | from mamba_ssm import Mamba
49 |
50 | batch, length, dim = 2, 64, 16
51 | x = torch.randn(batch, length, dim).to("cuda")
52 | model = Mamba(
53 | # This module uses roughly 3 * expand * d_model^2 parameters
54 | d_model=dim, # Model dimension d_model
55 | d_state=16, # SSM state expansion factor
56 | d_conv=4, # Local convolution width
57 | expand=2, # Block expansion factor
58 | ).to("cuda")
59 | y = model(x)
60 | assert y.shape == x.shape
61 | ```
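
As a rough sanity check of the parameter-count comment above (the estimate covers only the in/out projections, so the SSM-specific parameters add noticeably more at small `d_model`):
```
n_params = sum(p.numel() for p in model.parameters())
print(n_params, 3 * 2 * dim ** 2)  # actual count vs. the rough 3 * expand * d_model^2 estimate
```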
62 |
63 | ### Mamba Language Model
64 |
65 | Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
66 |
67 | Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).
68 |
69 | This is an example of how to integrate Mamba into an end-to-end neural network.
70 | This example is used in the generation scripts below.
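
As a rough sketch of standalone usage (`MambaLMHeadModel` is constructed from a `MambaConfig`, as in its `from_pretrained` path; the config values here are illustrative):
```
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(d_model=256, n_layer=4, vocab_size=50277)
model = MambaLMHeadModel(config).to("cuda")
input_ids = torch.randint(0, config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids).logits  # (batch, seqlen, padded vocab_size)
```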
71 |
72 |
73 |
74 | ## Pretrained Models
75 |
76 | Pretrained models are uploaded to
77 | [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
78 | `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
79 | (trained on 600B tokens on the SlimPajama dataset).
80 |
81 |
82 | The models will be autodownloaded by the generation script below.
83 |
84 | These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions established by GPT-3 and adopted by many open-source models:
85 |
86 | | Parameters | Layers | Model dim. |
87 | |------------|--------|------------|
88 | | 130M | 24 | 768 |
89 | | 370M | 48 | 1024 |
90 | | 790M | 48 | 1536 |
91 | | 1.4B | 48 | 2048 |
92 | | 2.8B | 64 | 2560 |
93 |
94 | (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
95 |
96 | Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
97 | Performance is expected to be comparable to or better than that of other architectures trained on similar data, but not to match larger or fine-tuned models.
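
For example, to load one of these checkpoints together with the GPT-NeoX-20B tokenizer used for training (this mirrors the generation script in benchmarks/):
```
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
```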
98 |
99 |
100 | ## Evaluations
101 |
102 | To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
103 | we use the
104 | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor)
105 | library.
106 |
107 | 1. Pull the `lm-evaluation-harness` repo by `git submodule update --init
108 | --recursive`. We use the `big-refactor` branch.
109 | 2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`.
110 | On Python 3.10 you might need to manually install the latest version of `promptsource`: `pip install git+https://github.com/bigscience-workshop/promptsource.git`.
111 | 3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
112 | ```
113 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
114 | python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
115 | ```
116 |
117 | To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blog posts:
118 | ```
119 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 64
120 | python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 64
121 | ```
122 |
123 | Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
124 |
125 | ## Inference
126 |
127 | The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
128 | 1. autoloads a model from the Hugging Face Hub,
129 | 2. generates completions of a user-specified prompt,
130 | 3. benchmarks the inference speed of this generation.
131 |
132 | Other configurable options include the top-p (nucleus sampling) probability and the softmax temperature.
133 |
134 | ### Examples
135 |
136 | To test generation latency (e.g. batch size = 1) with different sampling strategies:
137 |
138 | ```
139 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
140 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
141 | ```
142 |
143 | To test generation throughput with random prompts (e.g. large batch size):
144 | ```
145 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
146 | python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
147 | ```
148 |
149 |
150 | ## Troubleshooting
151 |
152 | ### Precision
153 | Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
154 | On the other hand, frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).
155 |
156 | We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
157 | as a first step please try a framework that stores parameters in fp32 (such as AMP).
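
A minimal AMP training step for illustration (the tiny model below is a stand-in, not a Mamba module):
```
import torch
import torch.nn as nn

model = nn.Linear(16, 16).to("cuda")  # stand-in model; parameters stay in float32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device="cuda")
optimizer.zero_grad()
with torch.cuda.amp.autocast(dtype=torch.float16):  # compute runs in half precision
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)
scaler.update()
assert model.weight.dtype == torch.float32  # master weights remain fp32
```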
158 |
159 | ### Initialization
160 | Some parts of the model have initializations inherited from prior work on S4 models.
161 | For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter is initialized to a targeted range by setting the bias of its linear projection.
162 | However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
163 | If this is the case, you may have to add custom logic specific to your training framework (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initialization in our trainer, but would be a no-op in any other framework).
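
For instance, a hypothetical post-initialization hook could skip parameters that the model has flagged (the linked line sets a `_no_reinit` attribute on the $\Delta$ projection bias):
```
import torch.nn as nn

def reinit_biases(model: nn.Module):
    # Zero all Linear biases except those the model marked as pre-initialized.
    for module in model.modules():
        if isinstance(module, nn.Linear) and module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
```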
165 |
166 |
167 | ## Citation
168 |
169 | If you use this codebase, or otherwise found our work valuable, please cite Mamba:
170 | ```
171 | @article{mamba,
172 | title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
173 | author={Gu, Albert and Dao, Tri},
174 | journal={arXiv preprint arXiv:2312.00752},
175 | year={2023}
176 | }
177 | ```
178 |
--------------------------------------------------------------------------------
/mamba-1p1p1/assets/selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/assets/selection.png
--------------------------------------------------------------------------------
/mamba-1p1p1/benchmarks/benchmark_generation_mamba_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 |
3 | import argparse
4 | import time
5 | import json
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from einops import rearrange
11 |
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 |
14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
15 |
16 |
17 | parser = argparse.ArgumentParser(description="Generation benchmarking")
18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m")
19 | parser.add_argument("--prompt", type=str, default=None)
20 | parser.add_argument("--promptlen", type=int, default=100)
21 | parser.add_argument("--genlen", type=int, default=100)
22 | parser.add_argument("--temperature", type=float, default=1.0)
23 | parser.add_argument("--topk", type=int, default=1)
24 | parser.add_argument("--topp", type=float, default=1.0)
25 | parser.add_argument("--repetition-penalty", type=float, default=1.0)
26 | parser.add_argument("--batch", type=int, default=1)
27 | args = parser.parse_args()
28 |
29 | repeats = 3
30 | device = "cuda"
31 | dtype = torch.float16
32 |
33 | print(f"Loading model {args.model_name}")
34 | is_mamba = args.model_name.startswith("state-spaces/mamba-")
35 | if is_mamba:
36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
38 | else:
39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype)
41 | model.eval()
42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
43 |
44 | torch.random.manual_seed(0)
45 | if args.prompt is None:
46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
48 | else:
49 | tokens = tokenizer(args.prompt, return_tensors="pt")
50 | input_ids = tokens.input_ids.to(device=device)
51 | attn_mask = tokens.attention_mask.to(device=device)
52 | max_length = input_ids.shape[1] + args.genlen
53 |
54 | if is_mamba:
55 | fn = lambda: model.generate(
56 | input_ids=input_ids,
57 | max_length=max_length,
58 | cg=True,
59 | return_dict_in_generate=True,
60 | output_scores=True,
61 | enable_timing=False,
62 | temperature=args.temperature,
63 | top_k=args.topk,
64 | top_p=args.topp,
65 | repetition_penalty=args.repetition_penalty,
66 | )
67 | else:
68 | fn = lambda: model.generate(
69 | input_ids=input_ids,
70 | attention_mask=attn_mask,
71 | max_length=max_length,
72 | return_dict_in_generate=True,
73 | pad_token_id=tokenizer.eos_token_id,
74 | do_sample=True,
75 | temperature=args.temperature,
76 | top_k=args.topk,
77 | top_p=args.topp,
78 | repetition_penalty=args.repetition_penalty,
79 | )
80 | out = fn()
81 | if args.prompt is not None:
82 | print(tokenizer.batch_decode(out.sequences.tolist()))
83 |
84 | torch.cuda.synchronize()
85 | start = time.time()
86 | for _ in range(repeats):
87 | fn()
88 | torch.cuda.synchronize()
89 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
90 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms")
91 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | ////////////////////////////////////////////////////////////////////////////////////////////////////
8 |
9 | struct SSMScanParamsBase {
10 | using index_t = uint32_t;
11 |
12 | int batch, seqlen, n_chunks;
13 | index_t a_batch_stride;
14 | index_t b_batch_stride;
15 | index_t out_batch_stride;
16 |
17 | // Common data pointers.
18 | void *__restrict__ a_ptr;
19 | void *__restrict__ b_ptr;
20 | void *__restrict__ out_ptr;
21 | void *__restrict__ x_ptr;
22 | };
23 |
24 | ////////////////////////////////////////////////////////////////////////////////////////////////////
25 |
26 | struct SSMParamsBase {
27 | using index_t = uint32_t;
28 |
29 | int batch, dim, seqlen, dstate, n_groups, n_chunks;
30 | int dim_ngroups_ratio;
31 | bool is_variable_B;
32 | bool is_variable_C;
33 |
34 | bool delta_softplus;
35 |
36 | index_t A_d_stride;
37 | index_t A_dstate_stride;
38 | index_t B_batch_stride;
39 | index_t B_d_stride;
40 | index_t B_dstate_stride;
41 | index_t B_group_stride;
42 | index_t C_batch_stride;
43 | index_t C_d_stride;
44 | index_t C_dstate_stride;
45 | index_t C_group_stride;
46 | index_t u_batch_stride;
47 | index_t u_d_stride;
48 | index_t delta_batch_stride;
49 | index_t delta_d_stride;
50 | index_t z_batch_stride;
51 | index_t z_d_stride;
52 | index_t out_batch_stride;
53 | index_t out_d_stride;
54 | index_t out_z_batch_stride;
55 | index_t out_z_d_stride;
56 |
57 | // Common data pointers.
58 | void *__restrict__ A_ptr;
59 | void *__restrict__ B_ptr;
60 | void *__restrict__ C_ptr;
61 | void *__restrict__ D_ptr;
62 | void *__restrict__ u_ptr;
63 | void *__restrict__ delta_ptr;
64 | void *__restrict__ delta_bias_ptr;
65 | void *__restrict__ out_ptr;
66 | void *__restrict__ x_ptr;
67 | void *__restrict__ z_ptr;
68 | void *__restrict__ out_z_ptr;
69 | };
70 |
71 | struct SSMParamsBwd: public SSMParamsBase {
72 | index_t dout_batch_stride;
73 | index_t dout_d_stride;
74 | index_t dA_d_stride;
75 | index_t dA_dstate_stride;
76 | index_t dB_batch_stride;
77 | index_t dB_group_stride;
78 | index_t dB_d_stride;
79 | index_t dB_dstate_stride;
80 | index_t dC_batch_stride;
81 | index_t dC_group_stride;
82 | index_t dC_d_stride;
83 | index_t dC_dstate_stride;
84 | index_t du_batch_stride;
85 | index_t du_d_stride;
86 | index_t dz_batch_stride;
87 | index_t dz_d_stride;
88 | index_t ddelta_batch_stride;
89 | index_t ddelta_d_stride;
90 |
91 | // Common data pointers.
92 | void *__restrict__ dout_ptr;
93 | void *__restrict__ dA_ptr;
94 | void *__restrict__ dB_ptr;
95 | void *__restrict__ dC_ptr;
96 | void *__restrict__ dD_ptr;
97 | void *__restrict__ du_ptr;
98 | void *__restrict__ dz_ptr;
99 | void *__restrict__ ddelta_ptr;
100 | void *__restrict__ ddelta_bias_ptr;
101 | };
102 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_real.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_bwd_kernel.cuh"
8 |
9 | template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_common.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | #include <cuda_bf16.h>
8 | #include <cuda_fp16.h>
9 | #include <c10/util/complex.h>  // For scalar_value_type
10 |
11 | #define MAX_DSTATE 256
12 |
13 | using complex_t = c10::complex<float>;
14 |
15 | inline __device__ float2 operator+(const float2 & a, const float2 & b){
16 | return {a.x + b.x, a.y + b.y};
17 | }
18 |
19 | inline __device__ float3 operator+(const float3 &a, const float3 &b) {
20 | return {a.x + b.x, a.y + b.y, a.z + b.z};
21 | }
22 |
23 | inline __device__ float4 operator+(const float4 & a, const float4 & b){
24 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
25 | }
26 |
27 | ////////////////////////////////////////////////////////////////////////////////////////////////////
28 |
29 | template<int BYTES> struct BytesToType {};
30 |
31 | template<> struct BytesToType<16> {
32 | using Type = uint4;
33 | static_assert(sizeof(Type) == 16);
34 | };
35 |
36 | template<> struct BytesToType<8> {
37 | using Type = uint64_t;
38 | static_assert(sizeof(Type) == 8);
39 | };
40 |
41 | template<> struct BytesToType<4> {
42 | using Type = uint32_t;
43 | static_assert(sizeof(Type) == 4);
44 | };
45 |
46 | template<> struct BytesToType<2> {
47 | using Type = uint16_t;
48 | static_assert(sizeof(Type) == 2);
49 | };
50 |
51 | template<> struct BytesToType<1> {
52 | using Type = uint8_t;
53 | static_assert(sizeof(Type) == 1);
54 | };
55 |
56 | ////////////////////////////////////////////////////////////////////////////////////////////////////
57 |
58 | template<typename scalar_t, int N>
59 | struct Converter{
60 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
61 | #pragma unroll
62 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
63 | }
64 | };
65 |
66 | template<int N>
67 | struct Converter<at::Half, N>{
68 | static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
69 | static_assert(N % 2 == 0);
70 | auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
71 | auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
72 | #pragma unroll
73 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
74 | }
75 | };
76 |
77 | #if __CUDA_ARCH__ >= 800
78 | template<int N>
79 | struct Converter<at::BFloat16, N>{
80 | static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
81 | static_assert(N % 2 == 0);
82 | auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
83 | auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
84 | #pragma unroll
85 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
86 | }
87 | };
88 | #endif
89 |
90 | ////////////////////////////////////////////////////////////////////////////////////////////////////
91 |
92 | // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
93 | // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
94 | __device__ __forceinline__ complex_t cexp2f(complex_t z) {
95 | float t = exp2f(z.real_);
96 | float c, s;
97 | sincosf(z.imag_, &s, &c);
98 | return complex_t(c * t, s * t);
99 | }
100 |
101 | __device__ __forceinline__ complex_t cexpf(complex_t z) {
102 | float t = expf(z.real_);
103 | float c, s;
104 | sincosf(z.imag_, &s, &c);
105 | return complex_t(c * t, s * t);
106 | }
107 |
108 | template<typename scalar_t> struct SSMScanOp;
109 |
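// Associative combine for the linear recurrence h_t = a_t * h_{t-1} + b_t:
// composing (a0, b0) followed by (a1, b1) gives (a1 * a0, a1 * b0 + b1),
// which is what lets the recurrence be evaluated as a parallel prefix scan.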
110 | template<>
111 | struct SSMScanOp<float> {
112 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
113 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
114 | }
115 | };
116 |
117 | template<>
118 | struct SSMScanOp<complex_t> {
119 | __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
120 | complex_t a0 = complex_t(ab0.x, ab0.y);
121 | complex_t b0 = complex_t(ab0.z, ab0.w);
122 | complex_t a1 = complex_t(ab1.x, ab1.y);
123 | complex_t b1 = complex_t(ab1.z, ab1.w);
124 | complex_t out_a = a1 * a0;
125 | complex_t out_b = a1 * b0 + b1;
126 | return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
127 | }
128 | };
129 |
130 | // A stateful callback functor that maintains a running prefix to be applied
131 | // during consecutive scan operations.
132 | template <typename scalar_t> struct SSMScanPrefixCallbackOp {
133 | using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
134 | scan_t running_prefix;
135 | // Constructor
136 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
137 | // Callback operator to be entered by the first warp of threads in the block.
138 | // Thread-0 is responsible for returning a value for seeding the block-wide scan.
139 | __device__ scan_t operator()(scan_t block_aggregate) {
140 | scan_t old_prefix = running_prefix;
141 | running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
142 | return old_prefix;
143 | }
144 | };
145 |
146 | ////////////////////////////////////////////////////////////////////////////////////////////////////
147 |
148 | template<typename Ktraits>
149 | inline __device__ void load_input(typename Ktraits::input_t *u,
150 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
151 | typename Ktraits::BlockLoadT::TempStorage &smem_load,
152 | int seqlen) {
153 | if constexpr (Ktraits::kIsEvenLen) {
154 | auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
155 | using vec_t = typename Ktraits::vec_t;
156 | Ktraits::BlockLoadVecT(smem_load_vec).Load(
157 | reinterpret_cast<vec_t*>(u),
158 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
159 | );
160 | } else {
161 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
162 | }
163 | }
164 |
165 | template<typename Ktraits>
166 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
167 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
168 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
169 | int seqlen) {
170 | constexpr int kNItems = Ktraits::kNItems;
171 | if constexpr (!Ktraits::kIsComplex) {
172 | typename Ktraits::input_t B_vals_load[kNItems];
173 | if constexpr (Ktraits::kIsEvenLen) {
174 | auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
175 | using vec_t = typename Ktraits::vec_t;
176 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
177 | reinterpret_cast<vec_t*>(Bvar),
178 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
179 | );
180 | } else {
181 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
182 | }
183 | // #pragma unroll
184 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
185 | Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
186 | } else {
187 | typename Ktraits::input_t B_vals_load[kNItems * 2];
188 | if constexpr (Ktraits::kIsEvenLen) {
189 | auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
190 | using vec_t = typename Ktraits::vec_t;
191 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
192 | reinterpret_cast<vec_t*>(Bvar),
193 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
194 | );
195 | } else {
196 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
197 | }
198 | #pragma unroll
199 | for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
200 | }
201 | }
202 |
203 | template<typename Ktraits>
204 | inline __device__ void store_output(typename Ktraits::input_t *out,
205 | const float (&out_vals)[Ktraits::kNItems],
206 | typename Ktraits::BlockStoreT::TempStorage &smem_store,
207 | int seqlen) {
208 | typename Ktraits::input_t write_vals[Ktraits::kNItems];
209 | #pragma unroll
210 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
211 | if constexpr (Ktraits::kIsEvenLen) {
212 | auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
213 | using vec_t = typename Ktraits::vec_t;
214 | Ktraits::BlockStoreVecT(smem_store_vec).Store(
215 | reinterpret_cast<vec_t*>(out),
216 | reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
217 | );
218 | } else {
219 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
220 | }
221 | }
222 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_bf16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp16.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp32.cu:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | // Split into multiple files to compile in parallel
6 |
7 | #include "selective_scan_fwd_kernel.cuh"
8 |
9 | template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10 | template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3 |
4 | #pragma once
5 |
6 | /// @param COND - a boolean expression to switch by
7 | /// @param CONST_NAME - a name given for the constexpr bool variable.
8 | /// @param ... - code to execute for true and false
9 | ///
10 | /// Usage:
11 | /// ```
12 | /// BOOL_SWITCH(flag, BoolConst, [&] {
13 | /// some_function(...);
14 | /// });
15 | /// ```
16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17 | [&] { \
18 | if (COND) { \
19 | constexpr bool CONST_NAME = true; \
20 | return __VA_ARGS__(); \
21 | } else { \
22 | constexpr bool CONST_NAME = false; \
23 | return __VA_ARGS__(); \
24 | } \
25 | }()
26 |
--------------------------------------------------------------------------------
/mamba-1p1p1/csrc/selective_scan/uninitialized_copy.cuh:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Redistribution and use in source and binary forms, with or without
5 | * modification, are permitted provided that the following conditions are met:
6 | * * Redistributions of source code must retain the above copyright
7 | * notice, this list of conditions and the following disclaimer.
8 | * * Redistributions in binary form must reproduce the above copyright
9 | * notice, this list of conditions and the following disclaimer in the
10 | * documentation and/or other materials provided with the distribution.
11 | * * Neither the name of the NVIDIA CORPORATION nor the
12 | * names of its contributors may be used to endorse or promote products
13 | * derived from this software without specific prior written permission.
14 | *
15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | *
26 | ******************************************************************************/
27 |
28 | #pragma once
29 |
30 | #include <cub/config.cuh>
31 |
32 | #include <cuda/std/type_traits>
33 |
34 |
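// Dispatch helper: copying into raw (uninitialized) storage can use plain
// assignment for trivially copyable types, but needs placement-new to run a
// constructor for everything else.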
35 | namespace detail
36 | {
37 |
38 | #if defined(_NVHPC_CUDA)
39 | template <typename T, typename U>
40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
41 | {
42 | // NVBug 3384810
43 | new (ptr) T(::cuda::std::forward<U>(val));
44 | }
45 | #else
46 | template <typename T,
47 |           typename U,
48 |           typename ::cuda::std::enable_if<
49 |             ::cuda::std::is_trivially_copyable<T>::value,
50 |             int
51 |           >::type = 0>
52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
53 | {
54 | *ptr = ::cuda::std::forward<U>(val);
55 | }
56 |
57 | template <typename T,
58 |           typename U,
59 |           typename ::cuda::std::enable_if<
60 |             !::cuda::std::is_trivially_copyable<T>::value,
61 |             int
62 |           >::type = 0>
63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
64 | {
65 | new (ptr) T(::cuda::std::forward<U>(val));
66 | }
67 | #endif
68 |
69 | } // namespace detail
70 |
--------------------------------------------------------------------------------
/mamba-1p1p1/evals/lm_harness_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import transformers
4 | from transformers import AutoTokenizer
5 |
6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
7 |
8 | from lm_eval.api.model import LM
9 | from lm_eval.models.huggingface import HFLM
10 | from lm_eval.api.registry import register_model
11 | from lm_eval.__main__ import cli_evaluate
12 |
13 |
14 | @register_model("mamba")
15 | class MambaEvalWrapper(HFLM):
16 |
17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
18 |
19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda",
20 | dtype=torch.float16):
21 | LM.__init__(self)
22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)
23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
25 | self.vocab_size = self.tokenizer.vocab_size
26 | self._batch_size = int(batch_size) if batch_size is not None else 64
27 | self._max_length = max_length
28 | self._device = torch.device(device)
29 |
30 | @property
31 | def batch_size(self):
32 | return self._batch_size
33 |
34 | def _model_generate(self, context, max_length, stop, **generation_kwargs):
35 | raise NotImplementedError()
36 |
37 |
38 | if __name__ == "__main__":
39 | cli_evaluate()
40 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.1.1"
2 |
3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4 | from mamba_ssm.modules.mamba_simple import Mamba
5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
6 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/models/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/models/config_mamba.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 |
3 |
4 | @dataclass
5 | class MambaConfig:
6 |
7 | d_model: int = 2560
8 | n_layer: int = 64
9 | vocab_size: int = 50277
10 | ssm_cfg: dict = field(default_factory=dict)
11 | rms_norm: bool = True
12 | residual_in_fp32: bool = True
13 | fused_add_norm: bool = True
14 | pad_vocab_size_multiple: int = 8
15 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/modules/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/ops/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/ops/triton/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/ops/triton/selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao.
2 |
3 | """We want triton==2.1.0 for this
4 | """
5 |
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | import triton
11 | import triton.language as tl
12 |
13 | from einops import rearrange, repeat
14 |
15 |
16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
20 | @triton.jit
21 | def _selective_scan_update_kernel(
22 | # Pointers to matrices
23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
24 | # Matrix dimensions
25 | batch, dim, dstate,
26 | # Strides
27 | stride_state_batch, stride_state_dim, stride_state_dstate,
28 | stride_x_batch, stride_x_dim,
29 | stride_dt_batch, stride_dt_dim,
30 | stride_dt_bias_dim,
31 | stride_A_dim, stride_A_dstate,
32 | stride_B_batch, stride_B_dstate,
33 | stride_C_batch, stride_C_dstate,
34 | stride_D_dim,
35 | stride_z_batch, stride_z_dim,
36 | stride_out_batch, stride_out_dim,
37 | # Meta-parameters
38 | DT_SOFTPLUS: tl.constexpr,
39 | BLOCK_SIZE_M: tl.constexpr,
40 | HAS_DT_BIAS: tl.constexpr,
41 | HAS_D: tl.constexpr,
42 | HAS_Z: tl.constexpr,
43 | BLOCK_SIZE_DSTATE: tl.constexpr,
44 | ):
45 | pid_m = tl.program_id(axis=0)
46 | pid_b = tl.program_id(axis=1)
47 | state_ptr += pid_b * stride_state_batch
48 | x_ptr += pid_b * stride_x_batch
49 | dt_ptr += pid_b * stride_dt_batch
50 | B_ptr += pid_b * stride_B_batch
51 | C_ptr += pid_b * stride_C_batch
52 | if HAS_Z:
53 | z_ptr += pid_b * stride_z_batch
54 | out_ptr += pid_b * stride_out_batch
55 |
56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
59 | x_ptrs = x_ptr + offs_m * stride_x_dim
60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim
61 | if HAS_DT_BIAS:
62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
64 | B_ptrs = B_ptr + offs_n * stride_B_dstate
65 | C_ptrs = C_ptr + offs_n * stride_C_dstate
66 | if HAS_D:
67 | D_ptrs = D_ptr + offs_m * stride_D_dim
68 | if HAS_Z:
69 | z_ptrs = z_ptr + offs_m * stride_z_dim
70 | out_ptrs = out_ptr + offs_m * stride_out_dim
71 |
72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
75 | if HAS_DT_BIAS:
76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
77 | if DT_SOFTPLUS:
78 | dt = tl.log(1.0 + tl.exp(dt))
79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
80 | dA = tl.exp(A * dt[:, None])
81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
83 | if HAS_D:
84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
85 | if HAS_Z:
86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
87 |
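# Single-step discretized SSM update (see selective_state_update_ref below):
# state_t = exp(dt * A) * state_{t-1} + (dt * B) * x_t, then out_t = <state_t, C>,
# plus the D * x skip and the z * sigmoid(z) (SiLU) gate when present.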
88 | dB = B[None, :] * dt[:, None]
89 | state = state * dA + dB * x[:, None]
90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
91 | out = tl.sum(state * C[None, :], axis=1)
92 | if HAS_D:
93 | out += x * D
94 | if HAS_Z:
95 | out *= z * tl.sigmoid(z)
96 | tl.store(out_ptrs, out, mask=offs_m < dim)
97 |
98 |
99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
100 | """
101 | Argument:
102 | state: (batch, dim, dstate)
103 | x: (batch, dim)
104 | dt: (batch, dim)
105 | A: (dim, dstate)
106 | B: (batch, dstate)
107 | C: (batch, dstate)
108 | D: (dim,)
109 | z: (batch, dim)
110 | dt_bias: (dim,)
111 | Return:
112 | out: (batch, dim)
113 | """
114 | batch, dim, dstate = state.shape
115 | assert x.shape == (batch, dim)
116 | assert dt.shape == x.shape
117 | assert A.shape == (dim, dstate)
118 | assert B.shape == (batch, dstate)
119 | assert C.shape == B.shape
120 | if D is not None:
121 | assert D.shape == (dim,)
122 | if z is not None:
123 | assert z.shape == x.shape
124 | if dt_bias is not None:
125 | assert dt_bias.shape == (dim,)
126 | out = torch.empty_like(x)
127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
129 | # We don't want autotune since it will overwrite the state
130 | # We instead tune by hand.
131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
132 | else ((16, 4) if dstate <= 32 else
133 | ((8, 4) if dstate <= 64 else
134 | ((4, 4) if dstate <= 128 else
135 | ((4, 8))))))
136 | with torch.cuda.device(x.device.index):
137 | _selective_scan_update_kernel[grid](
138 | state, x, dt, dt_bias, A, B, C, D, z, out,
139 | batch, dim, dstate,
140 | state.stride(0), state.stride(1), state.stride(2),
141 | x.stride(0), x.stride(1),
142 | dt.stride(0), dt.stride(1),
143 | dt_bias.stride(0) if dt_bias is not None else 0,
144 | A.stride(0), A.stride(1),
145 | B.stride(0), B.stride(1),
146 | C.stride(0), C.stride(1),
147 | D.stride(0) if D is not None else 0,
148 | z_strides[0], z_strides[1],
149 | out.stride(0), out.stride(1),
150 | dt_softplus,
151 | BLOCK_SIZE_M,
152 | num_warps=num_warps,
153 | )
154 | return out
155 |
156 |
157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
158 | """
159 | Argument:
160 | state: (batch, dim, dstate)
161 | x: (batch, dim)
162 | dt: (batch, dim)
163 | A: (dim, dstate)
164 | B: (batch, dstate)
165 | C: (batch, dstate)
166 | D: (dim,)
167 | z: (batch, dim)
168 | dt_bias: (dim,)
169 | Return:
170 | out: (batch, dim)
171 | """
172 | batch, dim, dstate = state.shape
173 | assert x.shape == (batch, dim)
174 | assert dt.shape == x.shape
175 | assert A.shape == (dim, dstate)
176 | assert B.shape == (batch, dstate)
177 | assert C.shape == B.shape
178 | if D is not None:
179 | assert D.shape == (dim,)
180 | if z is not None:
181 | assert z.shape == x.shape
182 | if dt_bias is not None:
183 | assert dt_bias.shape == (dim,)
184 | dt = dt + dt_bias
185 | dt = F.softplus(dt) if dt_softplus else dt
186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate)
188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate)
189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
190 | if D is not None:
191 | out += (x * D).to(out.dtype)
192 | return (out if z is None else out * F.silu(z)).to(x.dtype)
193 |
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/Famba-V/28f7236f0ab4acc4fbab611cbff6b5c5af0705c9/mamba-1p1p1/mamba_ssm/utils/__init__.py
--------------------------------------------------------------------------------
/mamba-1p1p1/mamba_ssm/utils/hf.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6 | from transformers.utils.hub import cached_file
7 |
8 |
9 | def load_config_hf(model_name):
10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11 | return json.load(open(resolved_archive_file))
12 |
13 |
14 | def load_state_dict_hf(model_name, device=None, dtype=None):
15 | # If not fp32, then we don't want to load directly to the GPU
16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19 | # Convert dtype before moving to GPU to save memory
20 | if dtype is not None:
21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23 | return state_dict
24 |
--------------------------------------------------------------------------------
/mamba-1p1p1/tests/ops/triton/test_selective_state_update.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2023, Tri Dao.
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import pytest
8 |
9 | from einops import rearrange
10 |
11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
12 |
13 |
14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
15 | # @pytest.mark.parametrize('itype', [torch.float16])
16 | @pytest.mark.parametrize("has_z", [False, True])
17 | # @pytest.mark.parametrize('has_z', [True])
18 | @pytest.mark.parametrize("dstate", [16, 32, 64])
19 | # @pytest.mark.parametrize("dstate", [16])
20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
21 | # @pytest.mark.parametrize("dim", [2048])
22 | def test_selective_state_update(dim, dstate, has_z, itype):
23 | device = "cuda"
24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
25 | if itype == torch.bfloat16:
26 | rtol, atol = 1e-2, 5e-2
27 | # set seed
28 | torch.random.manual_seed(0)
29 | batch_size = 2
30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
31 | x = torch.randn(batch_size, dim, device=device, dtype=itype)
32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype)
33 | dt_bias = torch.rand(dim, device=device) - 4.0
34 | A = -torch.rand(dim, dstate, device=device) - 1.0
35 | B = torch.randn(batch_size, dstate, device=device)
36 | C = torch.randn(batch_size, dstate, device=device)
37 | D = torch.randn(dim, device=device)
38 | if has_z:
39 | z = torch.randn_like(x)
40 | else:
41 | z = None
42 | state_ref = state.detach().clone()
43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
45 |
46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}")
47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
50 |
--------------------------------------------------------------------------------