├── assets ├── sts.png └── main.png ├── requirements.txt ├── classification ├── utils │ ├── __pycache__ │ │ ├── logger.cpython-38.pyc │ │ ├── utils.cpython-38.pyc │ │ ├── optimizer.cpython-38.pyc │ │ └── lr_scheduler.cpython-38.pyc │ ├── logger.py │ ├── cosine_lr.py │ ├── lr_scheduler.py │ ├── optimizer.py │ └── utils.py ├── readme.md ├── train_classification.sh ├── data │ ├── __init__.py │ ├── samplers.py │ ├── imagenet22k_dataset.py │ ├── zipreader.py │ ├── data_simmim_pt.py │ ├── data_simmim_ft.py │ ├── build.py │ ├── cached_image_folder.py │ └── map22kto1k.txt ├── configs │ └── vssm │ │ ├── spectral_vmamba_small_224.yaml │ │ ├── spectral_vmamba_tiny_224.yaml │ │ └── spectral_vmamba_base_224.yaml ├── config.py └── models │ └── csms6s.py ├── kernels └── selective_scan │ ├── csrc │ └── selective_scan │ │ ├── cusnrow │ │ ├── selective_scan_core_bwd3.cu │ │ ├── selective_scan_core_bwd4.cu │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_bwd2.cu │ │ ├── selective_scan_core_fwd.cu │ │ ├── selective_scan_core_fwd2.cu │ │ ├── selective_scan_core_fwd3.cu │ │ ├── selective_scan_core_fwd4.cu │ │ └── selective_scan_fwd_kernel_nrow.cuh │ │ ├── cus │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_fwd.cu │ │ └── selective_scan_fwd_kernel.cuh │ │ ├── cusndstate │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_fwd.cu │ │ ├── selective_scan_ndstate.h │ │ ├── selective_scan_fwd_kernel_ndstate.cuh │ │ └── selective_scan_ndstate.cpp │ │ ├── cusoflex │ │ ├── selective_scan_core_bwd.cu │ │ ├── selective_scan_core_fwd.cu │ │ └── selective_scan_fwd_kernel_oflex.cuh │ │ ├── static_switch.h │ │ ├── cub_extra.cuh │ │ ├── selective_scan.h │ │ ├── uninitialized_copy.cuh │ │ └── selective_scan_common.h │ ├── README.md │ └── setup.py ├── LICENSE ├── .gitignore └── README.md /assets/sts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sahardastani/spectral_vmamba/HEAD/assets/sts.png -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sahardastani/spectral_vmamba/HEAD/assets/main.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | triton 3 | timm==0.4.12 4 | pytest 5 | chardet 6 | yacs 7 | termcolor 8 | submitit 9 | tensorboardX 10 | fvcore 11 | seaborn -------------------------------------------------------------------------------- /classification/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sahardastani/spectral_vmamba/HEAD/classification/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /classification/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sahardastani/spectral_vmamba/HEAD/classification/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /classification/utils/__pycache__/optimizer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Sahardastani/spectral_vmamba/HEAD/classification/utils/__pycache__/optimizer.cpython-38.pyc -------------------------------------------------------------------------------- /classification/utils/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sahardastani/spectral_vmamba/HEAD/classification/utils/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /classification/readme.md: -------------------------------------------------------------------------------- 1 | ## origins 2 | 3 | based on https://github.com/microsoft/Swin-Transformer#20240103 4 | 5 | `main.py` and `utils/utils_ema.py` are modified from https://github.com/microsoft/Swin-Transformer#20240103, based on https://github.com/facebookresearch/ConvNeXt#20240103 6 | 7 | -------------------------------------------------------------------------------- /classification/train_classification.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1 \ 3 | python -m torch.distributed.launch \ 4 | --nnodes=1 \ 5 | --node_rank=0 \ 6 | --nproc_per_node=2 \ 7 | --master_addr="127.0.0.1" \ 8 | --master_port=21495 \ 9 | main.py \ 10 | --cfg configs/vssm/spectral_vmamba_tiny_224.yaml \ 11 | --batch-size 128 \ 12 | --data-path /data/shared/mini-imagenet \ 13 | --output output -------------------------------------------------------------------------------- /classification/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader as _build_loader 2 | from .data_simmim_pt import build_loader_simmim 3 | from .data_simmim_ft import build_loader_finetune 4 | 5 | 6 | def build_loader(config, simmim=False, is_pretrain=False): 7 | if not simmim: 8 | return _build_loader(config) 9 | if is_pretrain: 10 | return build_loader_simmim(config) 11 | else: 12 | return build_loader_finetune(config) 13 | -------------------------------------------------------------------------------- /classification/configs/vssm/spectral_vmamba_small_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_small_0229 4 | DROP_PATH_RATE: 0.6 5 | VSSM: 6 | EMBED_DIM: 256 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | MODE: "RFN" -------------------------------------------------------------------------------- /classification/configs/vssm/spectral_vmamba_tiny_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 3 | NAME: vssm1_tiny_0230s 4 | DROP_PATH_RATE: 0.2 5 | VSSM: 6 | EMBED_DIM: 256 7 | DEPTHS: [ 2, 2, 5, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | MODE: "RFN" -------------------------------------------------------------------------------- /classification/configs/vssm/spectral_vmamba_base_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vssm 
3 | NAME: vssm1_base_0229 4 | DROP_PATH_RATE: 0.3 5 | VSSM: 6 | EMBED_DIM: 256 7 | DEPTHS: [ 2, 2, 15, 2 ] 8 | SSM_D_STATE: 1 9 | SSM_DT_RANK: "auto" 10 | SSM_RATIO: 2.0 11 | SSM_CONV: 1 12 | SSM_CONV_BIAS: false 13 | SSM_FORWARDTYPE: "v05_noz" # v3_noz 14 | MLP_RATIO: 4.0 15 | DOWNSAMPLE: "v3" 16 | PATCHEMBED: "v2" 17 | NORM_LAYER: "ln2d" 18 | MODE: "RFN" 19 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<3, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<3, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<3, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<4, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<4, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<4, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cus/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_bwd_cuda<2, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<2, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<2, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_ndstate.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<2, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<2, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<2, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<3, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<3, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<3, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_nrow.cuh" 5 | 6 | template void selective_scan_fwd_cuda<4, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<4, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<4, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_ndstate.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | 10 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_bwd_kernel_oflex.cuh" 5 | 6 | template void selective_scan_bwd_cuda<1, float, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 7 | template void selective_scan_bwd_cuda<1, at::Half, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 8 | template void selective_scan_bwd_cuda<1, at::BFloat16, float, float>(SSMParamsBwd ¶ms, cudaStream_t stream); 9 | template void selective_scan_bwd_cuda<1, at::Half, float, at::Half>(SSMParamsBwd ¶ms, cudaStream_t stream); 10 | template void selective_scan_bwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBwd ¶ms, cudaStream_t stream); 11 | 12 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | #include "selective_scan_fwd_kernel_oflex.cuh" 5 | 6 | template void selective_scan_fwd_cuda<1, float, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 7 | template void selective_scan_fwd_cuda<1, at::Half, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 8 | template void selective_scan_fwd_cuda<1, at::BFloat16, float, float>(SSMParamsBase ¶ms, cudaStream_t stream); 9 | template void selective_scan_fwd_cuda<1, at::Half, float, at::Half>(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda<1, at::BFloat16, float, at::BFloat16>(SSMParamsBase ¶ms, cudaStream_t stream); 11 | 12 | -------------------------------------------------------------------------------- /classification/data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 
13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MzeroMiko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **/_ignore 3 | **/.dist_test 4 | **/.pytest_cache 5 | **/*.egg-info 6 | **/*.TAG 7 | **/dist 8 | **/*.so 9 | *.so 10 | **/build 11 | **/tmp 12 | **/output 13 | **/work_dirs 14 | logs 15 | ckpts 16 | log 17 | ckpt 18 | 19 | analyze/show 20 | analyze/features 21 | cross_selective_scan 22 | classification/models/pscan.py 23 | classification/configs/heat 24 | classification/models/heat 25 | classification/models/heat.py 26 | classification/models/vim.py 27 | classification/models/pscan.py 28 | 29 | detection/data 30 | segmentation/data 31 | 32 | kernels/selective_scan/1.log 33 | kernels/csmcuda 34 | classification/dev.py 35 | classification/models/triton_parts.py 36 | classification/1.log 37 | kernels/selective_scan/test_selective_scan_speed.py 38 | kernels/selective_scan/test_selective_scan_easy.py 39 | kernels/selective_scan/ssmtriton.py 40 | test.sh 41 | classification/models/vmamba copy.py 42 | kernels/selective_scan/ssmjax.py 43 | classification/models/vheat_wzz.py 44 | classification/main_simmim_pt.py 45 | 46 | 47 | classification/main_simmim_pt.py 48 | classification/main_simmim_ft.py 49 | classification/utils/utils_simmim.py 50 | classification/models/simmim.py 51 | classification/models/pvmamba.py 52 | classification/models/simvmamba.py 53 | classification/models/csm_triton_n.py 54 | detection/vis/bears.jpg 55 | detection/1.ipynb 56 | detection/bears.jpg 57 | classification/1.ipynb 58 | 1.txt -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/static_switch.h: 
-------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /classification/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cub_extra.cuh: -------------------------------------------------------------------------------- 1 | // WarpMask is copied from /usr/local/cuda-12.1/include/cub/util_ptx.cuh 2 | // PowerOfTwo is copied from /usr/local/cuda-12.1/include/cub/util_type.cuh 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | /** 12 | * \brief Statically determine if N is a power-of-two 13 | */ 14 | template 15 | struct PowerOfTwo 16 | { 17 | enum { VALUE = ((N & (N - 1)) == 0) }; 18 | }; 19 | 20 | 21 | /** 22 | * @brief Returns the warp mask for a warp of @p LOGICAL_WARP_THREADS threads 23 | * 24 | * @par 25 | * If the number of threads assigned to the virtual warp is not a power of two, 26 | * it's assumed that only one virtual warp exists. 
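 * (Worked example of the formula below: with LOGICAL_WARP_THREADS = 16 on a 32-thread hardware warp, member_mask starts as 0xFFFFFFFFu >> 16 = 0x0000FFFF and is then shifted left by warp_id * 16, so warp_id = 1 yields 0xFFFF0000.)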
27 | * 28 | * @tparam LOGICAL_WARP_THREADS [optional] The number of threads per 29 | * "logical" warp (may be less than the number of 30 | * hardware warp threads). 31 | * @param warp_id Id of virtual warp within architectural warp 32 | */ 33 | template 34 | __host__ __device__ __forceinline__ 35 | unsigned int WarpMask(unsigned int warp_id) 36 | { 37 | constexpr bool is_pow_of_two = PowerOfTwo::VALUE; 38 | constexpr bool is_arch_warp = LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0); 39 | 40 | unsigned int member_mask = 0xFFFFFFFFu >> 41 | (CUB_WARP_THREADS(0) - LOGICAL_WARP_THREADS); 42 | 43 | if (is_pow_of_two && !is_arch_warp) 44 | { 45 | member_mask <<= warp_id * LOGICAL_WARP_THREADS; 46 | } 47 | 48 | return member_mask; 49 | } 50 | -------------------------------------------------------------------------------- /classification/data/imagenet22k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch.utils.data as data 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 10 | 11 | 12 | class IN22KDATASET(data.Dataset): 13 | def __init__(self, root, ann_file='', transform=None, target_transform=None): 14 | super(IN22KDATASET, self).__init__() 15 | 16 | self.data_path = root 17 | self.ann_path = os.path.join(self.data_path, ann_file) 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | # id & label: https://github.com/google-research/big_transfer/issues/7 21 | # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027 22 | self.database = json.load(open(self.ann_path)) 23 | 24 | def _load_image(self, path): 25 | try: 26 | im = Image.open(path) 27 | except: 28 | print("ERROR IMG LOADED: ", path) 29 | random_img = np.random.rand(224, 224, 3) * 255 30 | im = Image.fromarray(np.uint8(random_img)) 31 | return im 32 | 33 | def __getitem__(self, index): 34 | """ 35 | Args: 36 | index (int): Index 37 | Returns: 38 | tuple: (image, target) where target is class_index of the target class. 39 | """ 40 | idb = self.database[index] 41 | 42 | # images 43 | images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') 44 | if self.transform is not None: 45 | images = self.transform(images) 46 | 47 | # target 48 | target = int(idb[1]) 49 | if self.target_transform is not None: 50 | target = self.target_transform(target) 51 | 52 | return images, target 53 | 54 | def __len__(self): 55 | return len(self.database) 56 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 
18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | 32 | bool delta_softplus; 33 | 34 | index_t A_d_stride; 35 | index_t B_batch_stride; 36 | index_t B_d_stride; 37 | index_t B_group_stride; 38 | index_t C_batch_stride; 39 | index_t C_d_stride; 40 | index_t C_group_stride; 41 | index_t u_batch_stride; 42 | index_t u_d_stride; 43 | index_t delta_batch_stride; 44 | index_t delta_d_stride; 45 | index_t out_batch_stride; 46 | index_t out_d_stride; 47 | 48 | // Common data pointers. 49 | void *__restrict__ A_ptr; 50 | void *__restrict__ B_ptr; 51 | void *__restrict__ C_ptr; 52 | void *__restrict__ D_ptr; 53 | void *__restrict__ u_ptr; 54 | void *__restrict__ delta_ptr; 55 | void *__restrict__ delta_bias_ptr; 56 | void *__restrict__ out_ptr; 57 | void *__restrict__ x_ptr; 58 | }; 59 | 60 | struct SSMParamsBwd: public SSMParamsBase { 61 | index_t dout_batch_stride; 62 | index_t dout_d_stride; 63 | index_t dA_d_stride; 64 | index_t dB_batch_stride; 65 | index_t dB_group_stride; 66 | index_t dB_d_stride; 67 | index_t dC_batch_stride; 68 | index_t dC_group_stride; 69 | index_t dC_d_stride; 70 | index_t du_batch_stride; 71 | index_t du_d_stride; 72 | index_t ddelta_batch_stride; 73 | index_t ddelta_d_stride; 74 | 75 | // Common data pointers. 76 | void *__restrict__ dout_ptr; 77 | void *__restrict__ dA_ptr; 78 | void *__restrict__ dB_ptr; 79 | void *__restrict__ dC_ptr; 80 | void *__restrict__ dD_ptr; 81 | void *__restrict__ du_ptr; 82 | void *__restrict__ ddelta_ptr; 83 | void *__restrict__ ddelta_bias_ptr; 84 | }; 85 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 
18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | 32 | bool delta_softplus; 33 | 34 | index_t A_d_stride; 35 | index_t A_dstate_stride; 36 | index_t B_batch_stride; 37 | index_t B_d_stride; 38 | index_t B_dstate_stride; 39 | index_t B_group_stride; 40 | index_t C_batch_stride; 41 | index_t C_d_stride; 42 | index_t C_dstate_stride; 43 | index_t C_group_stride; 44 | index_t u_batch_stride; 45 | index_t u_d_stride; 46 | index_t delta_batch_stride; 47 | index_t delta_d_stride; 48 | index_t out_batch_stride; 49 | index_t out_d_stride; 50 | 51 | // Common data pointers. 52 | void *__restrict__ A_ptr; 53 | void *__restrict__ B_ptr; 54 | void *__restrict__ C_ptr; 55 | void *__restrict__ D_ptr; 56 | void *__restrict__ u_ptr; 57 | void *__restrict__ delta_ptr; 58 | void *__restrict__ delta_bias_ptr; 59 | void *__restrict__ out_ptr; 60 | void *__restrict__ x_ptr; 61 | }; 62 | 63 | struct SSMParamsBwd: public SSMParamsBase { 64 | index_t dout_batch_stride; 65 | index_t dout_d_stride; 66 | index_t dA_d_stride; 67 | index_t dA_dstate_stride; 68 | index_t dB_batch_stride; 69 | index_t dB_group_stride; 70 | index_t dB_d_stride; 71 | index_t dB_dstate_stride; 72 | index_t dC_batch_stride; 73 | index_t dC_group_stride; 74 | index_t dC_d_stride; 75 | index_t dC_dstate_stride; 76 | index_t du_batch_stride; 77 | index_t du_d_stride; 78 | index_t ddelta_batch_stride; 79 | index_t ddelta_d_stride; 80 | 81 | // Common data pointers. 82 | void *__restrict__ dout_ptr; 83 | void *__restrict__ dA_ptr; 84 | void *__restrict__ dB_ptr; 85 | void *__restrict__ dC_ptr; 86 | void *__restrict__ dD_ptr; 87 | void *__restrict__ du_ptr; 88 | void *__restrict__ ddelta_ptr; 89 | void *__restrict__ ddelta_bias_ptr; 90 | }; 91 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /kernels/selective_scan/README.md: -------------------------------------------------------------------------------- 1 | # mamba-mini 2 | An efficient implementation of selective scan in one file, works with both cpu and gpu, with corresponding mathematical derivation. It is probably the code which is the most close to selective_scan_cuda in mamba. 3 | 4 | ### mathematical derivation 5 | ![image](../assets/derivation.png) 6 | 7 | ### code 8 | ```python 9 | import torch 10 | def selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64): 11 | """ 12 | # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen 13 | us: B, G * D, L 14 | dts: B, G * D, L 15 | As: G * D, N 16 | Bs: B, G, N, L 17 | Cs: B, G, N, L 18 | Ds: G * D 19 | delta_bias: G * D 20 | # chunksize can be any as you like. But as the chunksize raises, hs may get None, as exp(sum(delta) A) is really small 21 | """ 22 | def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix): 23 | """ 24 | partial(h) / partial(t) = Ah + Bu; y = Ch + Du; 25 | => partial(h*exp(-At)) / partial(t) = Bu*exp(-At); 26 | => h_t = h_0 + sum_{0}_{t}_{Bu*exp(A(t-v)) dv}; 27 | => h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... 
+ dt_i)) dt_i}); 28 | y_i = C_i*h_i + D*u_i 29 | """ 30 | """ 31 | us, dts: (L, B, G, D) # L is chunk_size 32 | As: (G, D, N) 33 | Bs, Cs: (L, B, G, N) 34 | Ds: (G, D) 35 | hprefix: (B, G, D, N) 36 | """ 37 | ts = dts.cumsum(dim=0) 38 | Ats = torch.einsum("gdn,lbgd->lbgdn", As, ts).exp() 39 | scale = Ats[-1].detach() 40 | rAts = Ats / scale 41 | duts = dts * us 42 | dtBus = torch.einsum("lbgd,lbgn->lbgdn", duts, Bs) 43 | hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0) 44 | hs = hs_tmp + Ats * hprefix.unsqueeze(0) 45 | ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs) 46 | return ys, hs 47 | 48 | inp_dtype = us.dtype 49 | has_D = Ds is not None 50 | 51 | dts = dts.float() 52 | if delta_bias is not None: 53 | dts = dts + delta_bias.view(1, -1, 1).float() 54 | if delta_softplus: 55 | dts = torch.nn.functional.softplus(dts) 56 | 57 | if len(Bs.shape) == 3: 58 | Bs = Bs.unsqueeze(1) 59 | if len(Cs.shape) == 3: 60 | Cs = Cs.unsqueeze(1) 61 | B, G, N, L = Bs.shape 62 | us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float() 63 | dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float() 64 | As = As.view(G, -1, N).float() 65 | Bs = Bs.permute(3, 0, 1, 2).float() 66 | Cs = Cs.permute(3, 0, 1, 2).float() 67 | Ds = Ds.view(G, -1).float() if has_D else None 68 | D = As.shape[1] 69 | 70 | oys = [] 71 | # ohs = [] 72 | hprefix = us.new_zeros((B, G, D, N), dtype=torch.float) 73 | for i in range(0, L - 1, chunksize): 74 | ys, hs = selective_scan_chunk( 75 | us[i:i + chunksize], dts[i:i + chunksize], 76 | As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix, 77 | ) 78 | oys.append(ys) 79 | # ohs.append(hs) 80 | hprefix = hs[-1] 81 | 82 | oys = torch.cat(oys, dim=0) 83 | # ohs = torch.cat(ohs, dim=0) 84 | if has_D: 85 | oys = oys + Ds * us 86 | oys = oys.permute(1, 2, 3, 0).view(B, -1, L) 87 | oys = oys.to(inp_dtype) 88 | # hprefix = hprefix.to(inp_dtype) 89 | 90 | return oys if not return_last_state else (oys, hprefix.view(B, G * D, N)) 91 | 92 | ``` 93 | 94 | ### to test 95 | ```bash 96 | pytest test_selective_scan.py 97 | ``` 98 | -------------------------------------------------------------------------------- /classification/data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | @staticmethod 
49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_foler_name in zfile.namelist(): 55 | file_foler_name = str.strip(file_foler_name, '/') 56 | if file_foler_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 58 | file_foler_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_foler_name) 61 | else: 62 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_foler_name) 80 | else: 81 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /classification/data/data_simmim_pt.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torchvision.transforms as T 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from torch.utils.data._utils.collate import default_collate 17 | from torchvision.datasets import ImageFolder 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | 21 | class MaskGenerator: 22 | def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): 23 | self.input_size = input_size 24 | self.mask_patch_size = mask_patch_size 25 | self.model_patch_size = model_patch_size 26 | self.mask_ratio = mask_ratio 27 | 28 | assert self.input_size % self.mask_patch_size == 0 29 | assert self.mask_patch_size % self.model_patch_size == 0 30 | 31 | self.rand_size = self.input_size // self.mask_patch_size 32 | self.scale = self.mask_patch_size // self.model_patch_size 33 | 34 | self.token_count = self.rand_size ** 2 35 | self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) 36 | 37 | def __call__(self): 38 | mask_idx = np.random.permutation(self.token_count)[:self.mask_count] 39 | mask = np.zeros(self.token_count, 
dtype=int) 40 | mask[mask_idx] = 1 41 | 42 | mask = mask.reshape((self.rand_size, self.rand_size)) 43 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) 44 | 45 | return mask 46 | 47 | 48 | class SimMIMTransform: 49 | def __init__(self, config): 50 | self.transform_img = T.Compose([ 51 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 52 | T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)), 53 | T.RandomHorizontalFlip(), 54 | T.ToTensor(), 55 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), 56 | ]) 57 | 58 | if config.MODEL.TYPE in ['swin', 'swinv2']: 59 | model_patch_size=config.MODEL.SWIN.PATCH_SIZE 60 | else: 61 | raise NotImplementedError 62 | 63 | self.mask_generator = MaskGenerator( 64 | input_size=config.DATA.IMG_SIZE, 65 | mask_patch_size=config.DATA.MASK_PATCH_SIZE, 66 | model_patch_size=model_patch_size, 67 | mask_ratio=config.DATA.MASK_RATIO, 68 | ) 69 | 70 | def __call__(self, img): 71 | img = self.transform_img(img) 72 | mask = self.mask_generator() 73 | 74 | return img, mask 75 | 76 | 77 | def collate_fn(batch): 78 | if not isinstance(batch[0][0], tuple): 79 | return default_collate(batch) 80 | else: 81 | batch_num = len(batch) 82 | ret = [] 83 | for item_idx in range(len(batch[0][0])): 84 | if batch[0][0][item_idx] is None: 85 | ret.append(None) 86 | else: 87 | ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) 88 | ret.append(default_collate([batch[i][1] for i in range(batch_num)])) 89 | return ret 90 | 91 | 92 | def build_loader_simmim(config): 93 | transform = SimMIMTransform(config) 94 | dataset = ImageFolder(config.DATA.DATA_PATH, transform) 95 | 96 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) 97 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) 98 | 99 | return dataloader -------------------------------------------------------------------------------- /classification/utils/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from timm.scheduler.scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class CosineLRScheduler(Scheduler): 19 | """ 20 | Cosine decay with restarts. 21 | This is described in the paper https://arxiv.org/abs/1608.03983. 
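    Concretely, within a cycle of length t_i the code below sets lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t_curr / t_i)), where lr_max and lr_min are both scaled by decay_rate ** i after the i-th restart; the first warmup_t steps instead ramp linearly from warmup_lr_init.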
22 | 23 | Inspiration from 24 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 25 | """ 26 | 27 | def __init__(self, 28 | optimizer: torch.optim.Optimizer, 29 | t_initial: int, 30 | t_mul: float = 1., 31 | lr_min: float = 0., 32 | decay_rate: float = 1., 33 | warmup_t=0, 34 | warmup_lr_init=0, 35 | warmup_prefix=False, 36 | cycle_limit=0, 37 | t_in_epochs=True, 38 | noise_range_t=None, 39 | noise_pct=0.67, 40 | noise_std=1.0, 41 | noise_seed=42, 42 | initialize=True) -> None: 43 | super().__init__( 44 | optimizer, param_group_field="lr", 45 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 46 | initialize=initialize) 47 | 48 | assert t_initial > 0 49 | assert lr_min >= 0 50 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 51 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 52 | "rate since t_initial = t_mul = eta_mul = 1.") 53 | self.t_initial = t_initial 54 | self.t_mul = t_mul 55 | self.lr_min = lr_min 56 | self.decay_rate = decay_rate 57 | self.cycle_limit = cycle_limit 58 | self.warmup_t = warmup_t 59 | self.warmup_lr_init = warmup_lr_init 60 | self.warmup_prefix = warmup_prefix 61 | self.t_in_epochs = t_in_epochs 62 | if self.warmup_t: 63 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 64 | super().update_groups(self.warmup_lr_init) 65 | else: 66 | self.warmup_steps = [1 for _ in self.base_values] 67 | 68 | def _get_lr(self, t): 69 | if t < self.warmup_t: 70 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 71 | else: 72 | if self.warmup_prefix: 73 | t = t - self.warmup_t 74 | 75 | if self.t_mul != 1: 76 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 77 | t_i = self.t_mul ** i * self.t_initial 78 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 79 | else: 80 | i = t // self.t_initial 81 | t_i = self.t_initial 82 | t_curr = t - (self.t_initial * i) 83 | 84 | gamma = self.decay_rate ** i 85 | lr_min = self.lr_min * gamma 86 | lr_max_values = [v * gamma for v in self.base_values] 87 | 88 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 89 | lrs = [ 90 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 91 | ] 92 | else: 93 | lrs = [self.lr_min for _ in self.base_values] 94 | 95 | return lrs 96 | 97 | def get_epoch_values(self, epoch: int): 98 | if self.t_in_epochs: 99 | return self._get_lr(epoch) 100 | else: 101 | return None 102 | 103 | def get_update_values(self, num_updates: int): 104 | if not self.t_in_epochs: 105 | return self._get_lr(num_updates) 106 | else: 107 | return None 108 | 109 | def get_cycle_length(self, cycles=0): 110 | if not cycles: 111 | cycles = self.cycle_limit 112 | cycles = max(1, cycles) 113 | if self.t_mul == 1.0: 114 | return self.t_initial * cycles 115 | else: 116 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 117 | -------------------------------------------------------------------------------- /classification/data/data_simmim_ft.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | 
import os 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, DistributedSampler 11 | from torchvision import datasets, transforms 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import Mixup 14 | from timm.data import create_transform 15 | from timm.data.transforms import _pil_interp 16 | 17 | 18 | def build_loader_finetune(config): 19 | config.defrost() 20 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 21 | config.freeze() 22 | dataset_val, _ = build_dataset(is_train=False, config=config) 23 | 24 | num_tasks = dist.get_world_size() 25 | global_rank = dist.get_rank() 26 | sampler_train = DistributedSampler( 27 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 28 | ) 29 | sampler_val = DistributedSampler( 30 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 31 | ) 32 | 33 | data_loader_train = DataLoader( 34 | dataset_train, sampler=sampler_train, 35 | batch_size=config.DATA.BATCH_SIZE, 36 | num_workers=config.DATA.NUM_WORKERS, 37 | pin_memory=config.DATA.PIN_MEMORY, 38 | drop_last=True, 39 | ) 40 | 41 | data_loader_val = DataLoader( 42 | dataset_val, sampler=sampler_val, 43 | batch_size=config.DATA.BATCH_SIZE, 44 | num_workers=config.DATA.NUM_WORKERS, 45 | pin_memory=config.DATA.PIN_MEMORY, 46 | drop_last=False, 47 | ) 48 | 49 | # setup mixup / cutmix 50 | mixup_fn = None 51 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None 52 | if mixup_active: 53 | mixup_fn = Mixup( 54 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 55 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 56 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 57 | 58 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 59 | 60 | 61 | def build_dataset(is_train, config): 62 | transform = build_transform(is_train, config) 63 | 64 | if config.DATA.DATASET == 'imagenet': 65 | prefix = 'train' if is_train else 'val' 66 | root = os.path.join(config.DATA.DATA_PATH, prefix) 67 | dataset = datasets.ImageFolder(root, transform=transform) 68 | nb_classes = 1000 69 | else: 70 | raise NotImplementedError("We only support ImageNet Now.") 71 | 72 | return dataset, nb_classes 73 | 74 | 75 | def build_transform(is_train, config): 76 | resize_im = config.DATA.IMG_SIZE > 32 77 | if is_train: 78 | # this should always dispatch to transforms_imagenet_train 79 | transform = create_transform( 80 | input_size=config.DATA.IMG_SIZE, 81 | is_training=True, 82 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 83 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 84 | re_prob=config.AUG.REPROB, 85 | re_mode=config.AUG.REMODE, 86 | re_count=config.AUG.RECOUNT, 87 | interpolation=config.DATA.INTERPOLATION, 88 | ) 89 | if not resize_im: 90 | # replace RandomResizedCropAndInterpolation with 91 | # RandomCrop 92 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 93 | return transform 94 | 95 | t = [] 96 | if resize_im: 97 | if config.TEST.CROP: 98 | size = int((256 / 224) * config.DATA.IMG_SIZE) 99 | t.append( 100 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 101 | # to maintain same ratio w.r.t. 
224 images 102 | ) 103 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 104 | else: 105 | t.append( 106 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 107 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 108 | ) 109 | 110 | t.append(transforms.ToTensor()) 111 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 112 | return transforms.Compose(t) 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | [CVPR 2025 🎉] Spectral State Space Model for 4 | 5 | Rotation-Invariant Visual Representation Learning 6 | 7 | 8 |
9 | 10 | Official Implementation of Spectral VMamba: 11 | [![arXiv](https://img.shields.io/badge/arXiv-2503.06369-B31B1B.svg)](https://arxiv.org/pdf/2503.06369) 12 | 13 | ## Table of Contents 14 | - [1. Abstract](#1-abstract) 15 | - [2. Overview](#2-overview) 16 | - [3. Main Results](#3-main-results) 17 | - [**Classification on mini-ImageNet**](#classification-on-mini-imagenet) 18 | - [4. Getting Started](#4-getting-started) 19 | - [4.1. Installation](#41-installation) 20 | - [4.2. Model Training and Inference](#42-model-training-and-inference) 21 | - [5. Acknowledgment](#5-acknowledgment) 22 | 23 | 24 | ## 1. Abstract 25 | 26 | State Space Models (SSMs) have recently emerged as an alternative to Vision Transformers (ViTs) due to their unique ability to model global relationships with linear complexity. SSMs are specifically designed to capture spatially proximate relationships of image patches. However, they fail to identify relationships between conceptually related yet non-adjacent patches. This limitation arises from the non-causal nature of image data, which lacks inherent directional relationships. Additionally, current vision-based SSMs are highly sensitive to transformations such as rotation. Their predefined scanning directions depend on the original image orientation, which can cause the model to produce inconsistent patch-processing sequences after rotation. 27 | 28 | To address these limitations, we introduce Spectral VMamba, a novel approach that effectively captures the global structure within an image by leveraging spectral information derived from the graph Laplacian of image patches. Through spectral decomposition, our approach encodes patch relationships independently of image orientation, achieving rotation invariance with the aid of our Rotational Feature Normalizer (RFN) module. Our experiments on classification tasks show that Spectral VMamba outperforms the leading SSM models in vision, such as VMamba, while maintaining invariance to rotations and providing similar runtime efficiency. 29 | 30 | ## 2. Overview 31 | 32 |

33 | architecture 34 | architecture 35 |

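To make the idea above concrete, here is a minimal, illustrative sketch of how a rotation-agnostic patch ordering can be derived from the graph Laplacian of patch embeddings. This is not the code used in this repository: the function name `spectral_patch_order`, the k-NN cosine-similarity graph, and the use of the Fiedler vector are assumptions made purely for illustration.

```python
import torch
import torch.nn.functional as F

def spectral_patch_order(patch_feats: torch.Tensor, knn: int = 5) -> torch.Tensor:
    """Order the patches of one image by the Fiedler vector of their affinity graph.

    patch_feats: (N, C) patch embeddings. The returned permutation depends only on
    patch affinities, not on the spatial orientation of the image.
    """
    f = F.normalize(patch_feats, dim=1)
    sim = f @ f.t()                                   # (N, N) cosine similarities
    nn_idx = sim.topk(knn + 1, dim=1).indices[:, 1:]  # k nearest neighbours, self dropped
    adj = torch.zeros_like(sim)
    adj.scatter_(1, nn_idx, 1.0)
    adj = torch.maximum(adj, adj.t())                 # symmetrise the adjacency matrix
    lap = torch.diag(adj.sum(dim=1)) - adj            # unnormalised graph Laplacian L = D - A
    _, eigvecs = torch.linalg.eigh(lap)               # eigenvalues in ascending order
    fiedler = eigvecs[:, 1]                           # eigenvector of the 2nd smallest eigenvalue
    # Note: the eigenvector sign is arbitrary; a real implementation would fix a convention.
    return torch.argsort(fiedler)

# Toy usage: 196 patches (a 14x14 grid) with 96-dim embeddings.
order = spectral_patch_order(torch.randn(196, 96))
```
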
36 | 37 | ## 3. Main Results 38 | 39 | 40 | ### **Classification on mini-ImageNet** 41 | | name | pretrain | resolution |acc@1 | FLOPs | configs/logs/ckpts | 42 | | :---: | :---: | :---: | :---: | :---: | :---: | 43 | | VMamba-T[`s1l8`] | mini-ImageNet | 224x224 | 86.25 | 4.9G | | 44 | | VMamba-S[`s2l15`] | mini-ImageNet | 224x224 | 86.48 | 8.7G| | 45 | | VMamba-B[`s2l15`] | mini-ImageNet | 224x224 | 87.17 | 8.7G | | 46 | | Ours-T[`s1l8`] | mini-ImageNet | 224x224 | 87.86 | 3.9G | [config](classification/configs/vssm/spectral_vmamba_tiny_224.yaml) | 47 | | Ours-S[`s2l15`] | mini-ImageNet | 224x224 | 88.09 | 6.3G | [config](classification/configs/vssm/spectral_vmamba_small_224.yaml) | 48 | | Ours-B[`s2l15`] | mini-ImageNet | 224x224 | 88.17 | 6.3G | [config](classification/configs/vssm/spectral_vmamba_base_224.yaml) | 49 | 50 | ## 4. Getting Started 51 | 52 | ### 4.1. Installation 53 | 54 | **Step 1: Clone the repository:** 55 | 56 | To get started, first clone the project repository and navigate to the project directory. 57 | 58 | **Step 2: Environment Setup:** 59 | 60 | 61 | ***Create and activate a new conda environment*** 62 | 63 | ```bash 64 | conda create -n spectral_vmamba 65 | conda activate spectral_vmamba 66 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia 67 | ``` 68 | 69 | ***Install Dependencies*** 70 | 71 | ```bash 72 | pip install -r requirements.txt 73 | cd kernels/selective_scan && pip install . 74 | ``` 75 | 76 | 77 | ### 4.2. Model Training and Inference 78 | 79 | **Classification** 80 | 81 | To train Our models for classification on ImageNet, use the following commands for different configurations: 82 | 83 | ```bash 84 | python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=8 --master_addr="127.0.0.1" --master_port=29501 main.py --cfg --batch-size 128 --data-path --output /tmp 85 | ``` 86 | 87 | If you only want to test the performance (together with params and flops): 88 | 89 | ```bash 90 | python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=1 --master_addr="127.0.0.1" --master_port=29501 main.py --cfg --batch-size 128 --data-path --output /tmp --pretrained 91 | ``` 92 | 93 | ## 5. Acknowledgment 94 | 95 | This project is based on VMamba ([paper](https://arxiv.org/abs/2401.10166), [code](https://github.com/MzeroMiko/VMamba/tree/main)), thanks for their excellent works. 
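
For reference, here is the training command from Section 4.2 with a concrete config filled in for the tiny model, assuming it is launched from the `classification` directory; the dataset path is a placeholder to adapt to your environment.

```bash
cd classification
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=8 \
    --master_addr="127.0.0.1" --master_port=29501 main.py \
    --cfg configs/vssm/spectral_vmamba_tiny_224.yaml \
    --batch-size 128 --data-path /path/to/mini-imagenet --output /tmp
```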
96 | -------------------------------------------------------------------------------- /classification/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import bisect 9 | 10 | import torch 11 | from timm.scheduler.cosine_lr import CosineLRScheduler 12 | from timm.scheduler.step_lr import StepLRScheduler 13 | from timm.scheduler.scheduler import Scheduler 14 | 15 | import timm 16 | if timm.__version__ != "0.4.12": 17 | from .cosine_lr import CosineLRScheduler 18 | 19 | 20 | def build_scheduler(config, optimizer, n_iter_per_epoch): 21 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 22 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 23 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 24 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] 25 | 26 | lr_scheduler = None 27 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 28 | lr_scheduler = CosineLRScheduler( 29 | optimizer, 30 | t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps, 31 | t_mul=1., 32 | lr_min=config.TRAIN.MIN_LR, 33 | warmup_lr_init=config.TRAIN.WARMUP_LR, 34 | warmup_t=warmup_steps, 35 | cycle_limit=1, 36 | t_in_epochs=False, 37 | warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX, 38 | ) 39 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 40 | lr_scheduler = LinearLRScheduler( 41 | optimizer, 42 | t_initial=num_steps, 43 | lr_min_rate=0.01, 44 | warmup_lr_init=config.TRAIN.WARMUP_LR, 45 | warmup_t=warmup_steps, 46 | t_in_epochs=False, 47 | ) 48 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 49 | lr_scheduler = StepLRScheduler( 50 | optimizer, 51 | decay_t=decay_steps, 52 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 53 | warmup_lr_init=config.TRAIN.WARMUP_LR, 54 | warmup_t=warmup_steps, 55 | t_in_epochs=False, 56 | ) 57 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': 58 | lr_scheduler = MultiStepLRScheduler( 59 | optimizer, 60 | milestones=multi_steps, 61 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA, 62 | warmup_lr_init=config.TRAIN.WARMUP_LR, 63 | warmup_t=warmup_steps, 64 | t_in_epochs=False, 65 | ) 66 | 67 | return lr_scheduler 68 | 69 | 70 | class LinearLRScheduler(Scheduler): 71 | def __init__(self, 72 | optimizer: torch.optim.Optimizer, 73 | t_initial: int, 74 | lr_min_rate: float, 75 | warmup_t=0, 76 | warmup_lr_init=0., 77 | t_in_epochs=True, 78 | noise_range_t=None, 79 | noise_pct=0.67, 80 | noise_std=1.0, 81 | noise_seed=42, 82 | initialize=True, 83 | ) -> None: 84 | super().__init__( 85 | optimizer, param_group_field="lr", 86 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 87 | initialize=initialize) 88 | 89 | self.t_initial = t_initial 90 | self.lr_min_rate = lr_min_rate 91 | self.warmup_t = warmup_t 92 | self.warmup_lr_init = warmup_lr_init 93 | self.t_in_epochs = t_in_epochs 94 | if self.warmup_t: 95 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 96 | super().update_groups(self.warmup_lr_init) 97 | else: 98 | self.warmup_steps = [1 for _ in self.base_values] 99 | 100 | def _get_lr(self, t): 101 | if t < self.warmup_t: 102 | lrs = [self.warmup_lr_init + t * 
s for s in self.warmup_steps] 103 | else: 104 | t = t - self.warmup_t 105 | total_t = self.t_initial - self.warmup_t 106 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 107 | return lrs 108 | 109 | def get_epoch_values(self, epoch: int): 110 | if self.t_in_epochs: 111 | return self._get_lr(epoch) 112 | else: 113 | return None 114 | 115 | def get_update_values(self, num_updates: int): 116 | if not self.t_in_epochs: 117 | return self._get_lr(num_updates) 118 | else: 119 | return None 120 | 121 | 122 | class MultiStepLRScheduler(Scheduler): 123 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: 124 | super().__init__(optimizer, param_group_field="lr") 125 | 126 | self.milestones = milestones 127 | self.gamma = gamma 128 | self.warmup_t = warmup_t 129 | self.warmup_lr_init = warmup_lr_init 130 | self.t_in_epochs = t_in_epochs 131 | if self.warmup_t: 132 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 133 | super().update_groups(self.warmup_lr_init) 134 | else: 135 | self.warmup_steps = [1 for _ in self.base_values] 136 | 137 | assert self.warmup_t <= min(self.milestones) 138 | 139 | def _get_lr(self, t): 140 | if t < self.warmup_t: 141 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 142 | else: 143 | lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values] 144 | return lrs 145 | 146 | def get_epoch_values(self, epoch: int): 147 | if self.t_in_epochs: 148 | return self._get_lr(epoch) 149 | else: 150 | return None 151 | 152 | def get_update_values(self, num_updates: int): 153 | if not self.t_in_epochs: 154 | return self._get_lr(num_updates) 155 | else: 156 | return None 157 | -------------------------------------------------------------------------------- /kernels/selective_scan/setup.py: -------------------------------------------------------------------------------- 1 | # Modified by $@#Anonymous#@$ #20240123 2 | # Copyright (c) 2023, Albert Gu, Tri Dao. 
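# This setup script compiles the selective-scan CUDA extensions listed in MODES below
# (only the "oflex" variant by default) with torch.utils.cpp_extension.CUDAExtension;
# it probes nvcc via get_cuda_bare_metal_version() to pick -gencode targets
# (sm_70 and sm_80 always, sm_90 when the detected CUDA version is >= 11.8).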
3 | import sys 4 | import warnings 5 | import os 6 | import re 7 | import ast 8 | from pathlib import Path 9 | from packaging.version import parse, Version 10 | import platform 11 | import shutil 12 | 13 | from setuptools import setup, find_packages 14 | import subprocess 15 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 16 | 17 | import torch 18 | from torch.utils.cpp_extension import ( 19 | BuildExtension, 20 | CppExtension, 21 | CUDAExtension, 22 | CUDA_HOME, 23 | ) 24 | 25 | # ninja build does not work unless include_dirs are abs path 26 | this_dir = os.path.dirname(os.path.abspath(__file__)) 27 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 28 | FORCE_CXX11_ABI = os.getenv("FORCE_CXX11_ABI", "FALSE") == "TRUE" 29 | 30 | def get_cuda_bare_metal_version(cuda_dir): 31 | raw_output = subprocess.check_output( 32 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 33 | ) 34 | output = raw_output.split() 35 | release_idx = output.index("release") + 1 36 | bare_metal_version = parse(output[release_idx].split(",")[0]) 37 | 38 | return raw_output, bare_metal_version 39 | 40 | MODES = ["oflex"] 41 | # MODES = ["core", "ndstate", "oflex"] 42 | # MODES = ["core", "ndstate", "oflex", "nrow"] 43 | 44 | def get_ext(): 45 | cc_flag = [] 46 | 47 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 48 | print("\n\nCUDA_HOME = {}\n\n".format(CUDA_HOME)) 49 | 50 | # Check, if CUDA11 is installed for compute capability 8.0 51 | multi_threads = True 52 | gencode_sm90 = False 53 | if CUDA_HOME is not None: 54 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 55 | print("CUDA version: ", bare_metal_version, flush=True) 56 | if bare_metal_version >= Version("11.8"): 57 | gencode_sm90 = True 58 | if bare_metal_version < Version("11.6"): 59 | warnings.warn("CUDA version ealier than 11.6 may leads to performance mismatch.") 60 | if bare_metal_version < Version("11.2"): 61 | multi_threads = False 62 | 63 | cc_flag.extend(["-gencode", "arch=compute_70,code=sm_70"]) 64 | cc_flag.extend(["-gencode", "arch=compute_80,code=sm_80"]) 65 | if gencode_sm90: 66 | cc_flag.extend(["-gencode", "arch=compute_90,code=sm_90"]) 67 | if multi_threads: 68 | cc_flag.extend(["--threads", "4"]) 69 | 70 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 71 | # torch._C._GLIBCXX_USE_CXX11_ABI 72 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 73 | if FORCE_CXX11_ABI: 74 | torch._C._GLIBCXX_USE_CXX11_ABI = True 75 | 76 | sources = dict( 77 | core=[ 78 | "csrc/selective_scan/cus/selective_scan.cpp", 79 | "csrc/selective_scan/cus/selective_scan_core_fwd.cu", 80 | "csrc/selective_scan/cus/selective_scan_core_bwd.cu", 81 | ], 82 | nrow=[ 83 | "csrc/selective_scan/cusnrow/selective_scan_nrow.cpp", 84 | "csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu", 85 | "csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu", 86 | "csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu", 87 | "csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu", 88 | "csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu", 89 | "csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu", 90 | "csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu", 91 | "csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu", 92 | ], 93 | ndstate=[ 94 | "csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp", 95 | "csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu", 96 | 
"csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu", 97 | ], 98 | oflex=[ 99 | "csrc/selective_scan/cusoflex/selective_scan_oflex.cpp", 100 | "csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu", 101 | "csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu", 102 | ], 103 | ) 104 | 105 | names = dict( 106 | core="selective_scan_cuda_core", 107 | nrow="selective_scan_cuda_nrow", 108 | ndstate="selective_scan_cuda_ndstate", 109 | oflex="selective_scan_cuda_oflex", 110 | ) 111 | 112 | ext_modules = [ 113 | CUDAExtension( 114 | name=names.get(MODE, None), 115 | sources=sources.get(MODE, None), 116 | extra_compile_args={ 117 | "cxx": ["-O3", "-std=c++17"], 118 | "nvcc": [ 119 | "-O3", 120 | "-std=c++17", 121 | "-U__CUDA_NO_HALF_OPERATORS__", 122 | "-U__CUDA_NO_HALF_CONVERSIONS__", 123 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 124 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 125 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 126 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 127 | "--expt-relaxed-constexpr", 128 | "--expt-extended-lambda", 129 | "--use_fast_math", 130 | "--ptxas-options=-v", 131 | "-lineinfo", 132 | ] 133 | + cc_flag 134 | }, 135 | include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], 136 | ) 137 | for MODE in MODES 138 | ] 139 | 140 | return ext_modules 141 | 142 | ext_modules = get_ext() 143 | setup( 144 | name="selective_scan", 145 | version="0.0.2", 146 | packages=[], 147 | author="Tri Dao, Albert Gu, $@#Anonymous#@$ ", 148 | author_email="tri@tridao.me, agu@cs.cmu.edu, $@#Anonymous#EMAIL@$", 149 | description="selective scan", 150 | long_description="", 151 | long_description_content_type="text/markdown", 152 | url="https://github.com/state-spaces/mamba", 153 | classifiers=[ 154 | "Programming Language :: Python :: 3", 155 | "License :: OSI Approved :: BSD License", 156 | "Operating System :: Unix", 157 | ], 158 | ext_modules=ext_modules, 159 | cmdclass={"bdist_wheel": _bdist_wheel, "build_ext": BuildExtension} if ext_modules else {"bdist_wheel": _bdist_wheel,}, 160 | python_requires=">=3.7", 161 | install_requires=[ 162 | "torch", 163 | "packaging", 164 | "ninja", 165 | "einops", 166 | ], 167 | ) 168 | -------------------------------------------------------------------------------- /classification/data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | import torch.distributed as dist 12 | from torchvision import datasets, transforms 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.data import Mixup 15 | from timm.data import create_transform 16 | 17 | from .cached_image_folder import CachedImageFolder 18 | from .imagenet22k_dataset import IN22KDATASET 19 | from .samplers import SubsetRandomSampler 20 | 21 | try: 22 | from torchvision.transforms import InterpolationMode 23 | 24 | 25 | def _pil_interp(method): 26 | if method == 'bicubic': 27 | return InterpolationMode.BICUBIC 28 | elif method == 'lanczos': 29 | return InterpolationMode.LANCZOS 30 | elif method == 'hamming': 31 | return InterpolationMode.HAMMING 32 | else: 33 | # default bilinear, do we want to allow nearest? 
34 | return InterpolationMode.BILINEAR 35 | 36 | 37 | import timm.data.transforms as timm_transforms 38 | 39 | timm_transforms._pil_interp = _pil_interp 40 | except: 41 | from timm.data.transforms import _pil_interp 42 | 43 | 44 | def build_loader(config): 45 | config.defrost() 46 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 47 | config.freeze() 48 | print(f"rank {dist.get_rank()} successfully build train dataset") 49 | dataset_val, _ = build_dataset(is_train=False, config=config) 50 | print(f"rank {dist.get_rank()} successfully build val dataset") 51 | 52 | num_tasks = dist.get_world_size() 53 | global_rank = dist.get_rank() 54 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 55 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 56 | sampler_train = SubsetRandomSampler(indices) 57 | else: 58 | sampler_train = torch.utils.data.DistributedSampler( 59 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 60 | ) 61 | 62 | if config.TEST.SEQUENTIAL: 63 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 64 | else: 65 | sampler_val = torch.utils.data.distributed.DistributedSampler( 66 | dataset_val, shuffle=config.TEST.SHUFFLE 67 | ) 68 | 69 | data_loader_train = torch.utils.data.DataLoader( 70 | dataset_train, sampler=sampler_train, 71 | batch_size=config.DATA.BATCH_SIZE, 72 | num_workers=config.DATA.NUM_WORKERS, 73 | pin_memory=config.DATA.PIN_MEMORY, 74 | drop_last=True, 75 | ) 76 | 77 | data_loader_val = torch.utils.data.DataLoader( 78 | dataset_val, sampler=sampler_val, 79 | batch_size=config.DATA.BATCH_SIZE, 80 | shuffle=False, 81 | num_workers=config.DATA.NUM_WORKERS, 82 | pin_memory=config.DATA.PIN_MEMORY, 83 | drop_last=False 84 | ) 85 | 86 | # setup mixup / cutmix 87 | mixup_fn = None 88 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 89 | if mixup_active: 90 | mixup_fn = Mixup( 91 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 92 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 93 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 94 | 95 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 96 | 97 | 98 | def build_dataset(is_train, config): 99 | transform = build_transform(is_train, config) 100 | if config.DATA.DATASET == 'imagenet': 101 | prefix = 'train' if is_train else 'val' 102 | if config.DATA.ZIP_MODE: 103 | ann_file = prefix + "_map.txt" 104 | prefix = prefix + ".zip@/" 105 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 106 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 107 | else: 108 | root = os.path.join(config.DATA.DATA_PATH, prefix) 109 | dataset = datasets.ImageFolder(root, transform=transform) 110 | 111 | # ============================================================================= 112 | # # JUST for test 113 | if False: 114 | from torch.utils.data import Dataset 115 | class FDataset(Dataset): 116 | def __init__(self, *args, **kwargs): 117 | pass 118 | 119 | def __len__(self): 120 | return 1000 121 | 122 | def __getitem__(self, *args,**kwargs): 123 | return torch.randn((3, 224, 224)), 0 124 | 125 | dataset = FDataset() 126 | 127 | # ============================================================================= 128 | 129 | nb_classes = 1000 130 | elif config.DATA.DATASET == 'imagenet22K': 131 | prefix = 'ILSVRC2011fall_whole' 132 | if is_train: 133 | ann_file = prefix + "_map_train.txt" 134 | else: 135 | ann_file = prefix + "_map_val.txt" 136 | dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform) 137 | nb_classes = 21841 138 | else: 139 | raise NotImplementedError("We only support ImageNet Now.") 140 | 141 | return dataset, nb_classes 142 | 143 | 144 | def build_transform(is_train, config): 145 | resize_im = config.DATA.IMG_SIZE > 32 146 | if is_train: 147 | # this should always dispatch to transforms_imagenet_train 148 | transform = create_transform( 149 | input_size=config.DATA.IMG_SIZE, 150 | is_training=True, 151 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 152 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 153 | re_prob=config.AUG.REPROB, 154 | re_mode=config.AUG.REMODE, 155 | re_count=config.AUG.RECOUNT, 156 | interpolation=config.DATA.INTERPOLATION, 157 | ) 158 | if not resize_im: 159 | # replace RandomResizedCropAndInterpolation with 160 | # RandomCrop 161 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 162 | return transform 163 | 164 | t = [] 165 | if resize_im: 166 | if config.TEST.CROP: 167 | size = int((256 / 224) * config.DATA.IMG_SIZE) 168 | t.append( 169 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 170 | # to maintain same ratio w.r.t. 
224 images 171 | ) 172 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 173 | else: 174 | t.append( 175 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 176 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 177 | ) 178 | 179 | t.append(transforms.ToTensor()) 180 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 181 | return transforms.Compose(t) 182 | -------------------------------------------------------------------------------- /classification/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Modified by $@#Anonymous#@$ 3 | # -------------------------------------------------------- 4 | # Swin Transformer 5 | # Copyright (c) 2021 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ze Liu 8 | # -------------------------------------------------------- 9 | 10 | from functools import partial 11 | from torch import optim as optim 12 | 13 | 14 | def build_optimizer(config, model, logger, **kwargs): 15 | """ 16 | Build optimizer, set weight decay of normalization to 0 by default. 17 | """ 18 | logger.info(f"==============> building optimizer {config.TRAIN.OPTIMIZER.NAME}....................") 19 | skip = {} 20 | skip_keywords = {} 21 | if hasattr(model, 'no_weight_decay'): 22 | skip = model.no_weight_decay() 23 | if hasattr(model, 'no_weight_decay_keywords'): 24 | skip_keywords = model.no_weight_decay_keywords() 25 | parameters, no_decay_names = set_weight_decay(model, skip, skip_keywords) 26 | logger.info(f"No weight decay list: {no_decay_names}") 27 | 28 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 29 | optimizer = None 30 | if opt_lower == 'sgd': 31 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 32 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 33 | elif opt_lower == 'adamw': 34 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 35 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 36 | else: 37 | raise NotImplementedError 38 | 39 | return optimizer 40 | 41 | 42 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 43 | has_decay = [] 44 | no_decay = [] 45 | no_decay_names = [] 46 | 47 | for name, param in model.named_parameters(): 48 | if not param.requires_grad: 49 | continue # frozen weights 50 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 51 | check_keywords_in_name(name, skip_keywords): 52 | no_decay.append(param) 53 | no_decay_names.append(name) 54 | # print(f"{name} has no weight decay") 55 | else: 56 | has_decay.append(param) 57 | return [{'params': has_decay}, 58 | {'params': no_decay, 'weight_decay': 0.}], no_decay_names 59 | 60 | 61 | def check_keywords_in_name(name, keywords=()): 62 | isin = False 63 | for keyword in keywords: 64 | if keyword in name: 65 | isin = True 66 | return isin 67 | 68 | 69 | # ========================== 70 | # for mim, currently not used, and may have bugs... 71 | 72 | def build_optimizer_swimmim(config, model, logger, simmim=True, is_pretrain=False): 73 | """ 74 | Build optimizer, set weight decay of normalization to 0 by default. 
75 | """ 76 | skip = {} 77 | skip_keywords = {} 78 | if hasattr(model, 'no_weight_decay'): 79 | skip = model.no_weight_decay() 80 | if hasattr(model, 'no_weight_decay_keywords'): 81 | skip_keywords = model.no_weight_decay_keywords() 82 | if is_pretrain: 83 | parameters = get_pretrain_param_groups(model, skip, skip_keywords) 84 | else: 85 | depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS 86 | num_layers = sum(depths) 87 | get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) 88 | scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) 89 | parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) 90 | 91 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 92 | optimizer = None 93 | if opt_lower == 'sgd': 94 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 95 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 96 | elif opt_lower == 'adamw': 97 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 98 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 99 | else: 100 | raise NotImplementedError 101 | 102 | return optimizer 103 | 104 | 105 | def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): 106 | has_decay = [] 107 | no_decay = [] 108 | has_decay_name = [] 109 | no_decay_name = [] 110 | 111 | for name, param in model.named_parameters(): 112 | if not param.requires_grad: 113 | continue 114 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 115 | check_keywords_in_name(name, skip_keywords): 116 | no_decay.append(param) 117 | no_decay_name.append(name) 118 | else: 119 | has_decay.append(param) 120 | has_decay_name.append(name) 121 | return [{'params': has_decay}, 122 | {'params': no_decay, 'weight_decay': 0.}] 123 | 124 | 125 | def get_swin_layer(name, num_layers, depths): 126 | if name in ("mask_token"): 127 | return 0 128 | elif name.startswith("patch_embed"): 129 | return 0 130 | elif name.startswith("layers"): 131 | layer_id = int(name.split('.')[1]) 132 | block_id = name.split('.')[3] 133 | if block_id == 'reduction' or block_id == 'norm': 134 | return sum(depths[:layer_id + 1]) 135 | layer_id = sum(depths[:layer_id]) + int(block_id) 136 | return layer_id + 1 137 | else: 138 | return num_layers - 1 139 | 140 | 141 | def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 142 | parameter_group_names = {} 143 | parameter_group_vars = {} 144 | 145 | for name, param in model.named_parameters(): 146 | if not param.requires_grad: 147 | continue 148 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 149 | check_keywords_in_name(name, skip_keywords): 150 | group_name = "no_decay" 151 | this_weight_decay = 0. 152 | else: 153 | group_name = "decay" 154 | this_weight_decay = weight_decay 155 | if get_layer_func is not None: 156 | layer_id = get_layer_func(name) 157 | group_name = "layer_%d_%s" % (layer_id, group_name) 158 | else: 159 | layer_id = None 160 | 161 | if group_name not in parameter_group_names: 162 | if scales is not None: 163 | scale = scales[layer_id] 164 | else: 165 | scale = 1. 
166 | 167 | parameter_group_names[group_name] = { 168 | "group_name": group_name, 169 | "weight_decay": this_weight_decay, 170 | "params": [], 171 | "lr": lr * scale, 172 | "lr_scale": scale, 173 | } 174 | parameter_group_vars[group_name] = { 175 | "group_name": group_name, 176 | "weight_decay": this_weight_decay, 177 | "params": [], 178 | "lr": lr * scale, 179 | "lr_scale": scale 180 | } 181 | 182 | parameter_group_vars[group_name]["params"].append(param) 183 | parameter_group_names[group_name]["params"].append(name) 184 | return list(parameter_group_vars.values()) 185 | 186 | -------------------------------------------------------------------------------- /classification/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Modified By $@#Anonymous#@$ 3 | # -------------------------------------------------------- 4 | # Swin Transformer 5 | # Copyright (c) 2021 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ze Liu 8 | # -------------------------------------------------------- 9 | 10 | import os 11 | from math import inf 12 | import torch 13 | import torch.distributed as dist 14 | from timm.utils import ModelEma as ModelEma 15 | 16 | 17 | def load_checkpoint_ema(config, model, optimizer, lr_scheduler, loss_scaler, logger, model_ema: ModelEma=None): 18 | logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................") 19 | if config.MODEL.RESUME.startswith('https'): 20 | checkpoint = torch.hub.load_state_dict_from_url( 21 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 22 | else: 23 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 24 | 25 | if 'model' in checkpoint: 26 | msg = model.load_state_dict(checkpoint['model'], strict=False) 27 | logger.info(f"resuming model: {msg}") 28 | else: 29 | logger.warning(f"No 'model' found in {config.MODEL.RESUME}! ") 30 | 31 | if model_ema is not None: 32 | if 'model_ema' in checkpoint: 33 | msg = model_ema.ema.load_state_dict(checkpoint['model_ema'], strict=False) 34 | logger.info(f"resuming model_ema: {msg}") 35 | else: 36 | logger.warning(f"No 'model_ema' found in {config.MODEL.RESUME}! 
") 37 | 38 | max_accuracy = 0.0 39 | max_accuracy_ema = 0.0 40 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 41 | optimizer.load_state_dict(checkpoint['optimizer']) 42 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 43 | config.defrost() 44 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 45 | config.freeze() 46 | if 'scaler' in checkpoint: 47 | loss_scaler.load_state_dict(checkpoint['scaler']) 48 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 49 | if 'max_accuracy' in checkpoint: 50 | max_accuracy = checkpoint['max_accuracy'] 51 | if 'max_accuracy_ema' in checkpoint: 52 | max_accuracy_ema = checkpoint['max_accuracy_ema'] 53 | 54 | del checkpoint 55 | torch.cuda.empty_cache() 56 | return max_accuracy, max_accuracy_ema 57 | 58 | 59 | def load_pretrained_ema(config, model, logger, model_ema: ModelEma=None): 60 | logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") 61 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') 62 | 63 | if 'model' in checkpoint: 64 | msg = model.load_state_dict(checkpoint['model'], strict=False) 65 | logger.warning(msg) 66 | logger.info(f"=> loaded 'model' successfully from '{config.MODEL.PRETRAINED}'") 67 | else: 68 | logger.warning(f"No 'model' found in {config.MODEL.PRETRAINED}! ") 69 | 70 | if model_ema is not None: 71 | if "model_ema" in checkpoint: 72 | logger.info(f"=> loading 'model_ema' separately...") 73 | key = "model_ema" if ("model_ema" in checkpoint) else "model" 74 | if key in checkpoint: 75 | msg = model_ema.ema.load_state_dict(checkpoint[key], strict=False) 76 | logger.warning(msg) 77 | logger.info(f"=> loaded '{key}' successfully from '{config.MODEL.PRETRAINED}' for model_ema") 78 | else: 79 | logger.warning(f"No '{key}' found in {config.MODEL.PRETRAINED}! 
") 80 | 81 | del checkpoint 82 | torch.cuda.empty_cache() 83 | 84 | 85 | # def save_checkpoint_ema(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, model_ema: ModelEma=None, max_accuracy_ema=None): 86 | # save_state = {'model': model.state_dict(), 87 | # 'optimizer': optimizer.state_dict(), 88 | # 'lr_scheduler': lr_scheduler.state_dict(), 89 | # 'max_accuracy': max_accuracy, 90 | # 'scaler': loss_scaler.state_dict(), 91 | # 'epoch': epoch, 92 | # 'config': config} 93 | 94 | # if model_ema is not None: 95 | # save_state.update({'model_ema': model_ema.ema.state_dict(), 96 | # 'max_accuray_ema': max_accuracy_ema}) 97 | 98 | # save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 99 | # logger.info(f"{save_path} saving......") 100 | # torch.save(save_state, save_path) 101 | # logger.info(f"{save_path} saved !!!") 102 | 103 | 104 | def save_checkpoint_ema(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, model_ema: ModelEma=None, max_accuracy_ema=None): 105 | save_state = {'model': model.state_dict(), 106 | 'optimizer': optimizer.state_dict(), 107 | 'lr_scheduler': lr_scheduler.state_dict(), 108 | 'max_accuracy': max_accuracy, 109 | 'scaler': loss_scaler.state_dict(), 110 | 'epoch': epoch, 111 | 'config': config} 112 | 113 | if model_ema is not None: 114 | save_state.update({'model_ema': model_ema.ema.state_dict(), 115 | 'max_accuray_ema': max_accuracy_ema}) 116 | 117 | #save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 118 | save_path = os.path.join(config.OUTPUT, f'ckpt_best.pth') 119 | logger.info(f"{save_path} saving best model......") 120 | torch.save(save_state, save_path) 121 | logger.info(f"{save_path} saved best model!!!") 122 | 123 | 124 | def get_grad_norm(parameters, norm_type=2): 125 | if isinstance(parameters, torch.Tensor): 126 | parameters = [parameters] 127 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 128 | norm_type = float(norm_type) 129 | total_norm = 0 130 | for p in parameters: 131 | param_norm = p.grad.data.norm(norm_type) 132 | total_norm += param_norm.item() ** norm_type 133 | total_norm = total_norm ** (1. / norm_type) 134 | return total_norm 135 | 136 | 137 | def auto_resume_helper(output_dir): 138 | checkpoints = os.listdir(output_dir) 139 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 140 | print(f"All checkpoints founded in {output_dir}: {checkpoints}") 141 | if len(checkpoints) > 0: 142 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 143 | print(f"The latest checkpoint founded: {latest_checkpoint}") 144 | resume_file = latest_checkpoint 145 | else: 146 | resume_file = None 147 | return resume_file 148 | 149 | 150 | def reduce_tensor(tensor): 151 | rt = tensor.clone() 152 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 153 | rt /= dist.get_world_size() 154 | return rt 155 | 156 | 157 | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: 158 | if isinstance(parameters, torch.Tensor): 159 | parameters = [parameters] 160 | parameters = [p for p in parameters if p.grad is not None] 161 | norm_type = float(norm_type) 162 | if len(parameters) == 0: 163 | return torch.tensor(0.) 
164 | device = parameters[0].grad.device 165 | if norm_type == inf: 166 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 167 | else: 168 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 169 | norm_type).to(device) for p in parameters]), norm_type) 170 | return total_norm 171 | 172 | 173 | class NativeScalerWithGradNormCount: 174 | state_dict_key = "amp_scaler" 175 | 176 | def __init__(self): 177 | self._scaler = torch.cuda.amp.GradScaler() 178 | 179 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 180 | self._scaler.scale(loss).backward(create_graph=create_graph) 181 | if update_grad: 182 | if clip_grad is not None: 183 | assert parameters is not None 184 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 185 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 186 | else: 187 | self._scaler.unscale_(optimizer) 188 | norm = ampscaler_get_grad_norm(parameters) 189 | self._scaler.step(optimizer) 190 | self._scaler.update() 191 | else: 192 | norm = None 193 | return norm 194 | 195 | def state_dict(self): 196 | return self._scaler.state_dict() 197 | 198 | def load_state_dict(self, state_dict): 199 | self._scaler.load_state_dict(state_dict) 200 | 201 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/selective_scan_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For scalar_value_type 10 | 11 | #define MAX_DSTATE 256 12 | 13 | inline __device__ float2 operator+(const float2 & a, const float2 & b){ 14 | return {a.x + b.x, a.y + b.y}; 15 | } 16 | 17 | inline __device__ float3 operator+(const float3 &a, const float3 &b) { 18 | return {a.x + b.x, a.y + b.y, a.z + b.z}; 19 | } 20 | 21 | inline __device__ float4 operator+(const float4 & a, const float4 & b){ 22 | return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; 23 | } 24 | 25 | //////////////////////////////////////////////////////////////////////////////////////////////////// 26 | 27 | template struct BytesToType {}; 28 | 29 | template<> struct BytesToType<16> { 30 | using Type = uint4; 31 | static_assert(sizeof(Type) == 16); 32 | }; 33 | 34 | template<> struct BytesToType<8> { 35 | using Type = uint64_t; 36 | static_assert(sizeof(Type) == 8); 37 | }; 38 | 39 | template<> struct BytesToType<4> { 40 | using Type = uint32_t; 41 | static_assert(sizeof(Type) == 4); 42 | }; 43 | 44 | template<> struct BytesToType<2> { 45 | using Type = uint16_t; 46 | static_assert(sizeof(Type) == 2); 47 | }; 48 | 49 | template<> struct BytesToType<1> { 50 | using Type = uint8_t; 51 | static_assert(sizeof(Type) == 1); 52 | }; 53 | 54 | //////////////////////////////////////////////////////////////////////////////////////////////////// 55 | 56 | template 57 | struct Converter{ 58 | static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { 59 | #pragma unroll 60 | for (int i = 0; i < N; ++i) { dst[i] = src[i]; } 61 | } 62 | }; 63 | 64 | template 65 | struct Converter{ 66 | static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { 67 | static_assert(N % 2 == 0); 68 | auto 
&src2 = reinterpret_cast(src); 69 | auto &dst2 = reinterpret_cast(dst); 70 | #pragma unroll 71 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } 72 | } 73 | }; 74 | 75 | #if __CUDA_ARCH__ >= 800 76 | template 77 | struct Converter{ 78 | static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { 79 | static_assert(N % 2 == 0); 80 | auto &src2 = reinterpret_cast(src); 81 | auto &dst2 = reinterpret_cast(dst); 82 | #pragma unroll 83 | for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } 84 | } 85 | }; 86 | #endif 87 | 88 | //////////////////////////////////////////////////////////////////////////////////////////////////// 89 | template struct SSMScanOp; 90 | 91 | template<> 92 | struct SSMScanOp { 93 | __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { 94 | return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); 95 | } 96 | }; 97 | 98 | // A stateful callback functor that maintains a running prefix to be applied 99 | // during consecutive scan operations. 100 | template struct SSMScanPrefixCallbackOp { 101 | using scan_t = std::conditional_t, float2, float4>; 102 | scan_t running_prefix; 103 | // Constructor 104 | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} 105 | // Callback operator to be entered by the first warp of threads in the block. 106 | // Thread-0 is responsible for returning a value for seeding the block-wide scan. 107 | __device__ scan_t operator()(scan_t block_aggregate) { 108 | scan_t old_prefix = running_prefix; 109 | running_prefix = SSMScanOp()(running_prefix, block_aggregate); 110 | return old_prefix; 111 | } 112 | }; 113 | 114 | //////////////////////////////////////////////////////////////////////////////////////////////////// 115 | 116 | template 117 | inline __device__ void load_input(typename Ktraits::input_t *u, 118 | typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], 119 | typename Ktraits::BlockLoadT::TempStorage &smem_load, 120 | int seqlen) { 121 | if constexpr (Ktraits::kIsEvenLen) { 122 | auto& smem_load_vec = reinterpret_cast(smem_load); 123 | using vec_t = typename Ktraits::vec_t; 124 | Ktraits::BlockLoadVecT(smem_load_vec).Load( 125 | reinterpret_cast(u), 126 | reinterpret_cast(u_vals) 127 | ); 128 | } else { 129 | Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); 130 | } 131 | } 132 | 133 | template 134 | inline __device__ void load_weight(typename Ktraits::input_t *Bvar, 135 | typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], 136 | typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, 137 | int seqlen) { 138 | constexpr int kNItems = Ktraits::kNItems; 139 | typename Ktraits::input_t B_vals_load[kNItems]; 140 | if constexpr (Ktraits::kIsEvenLen) { 141 | auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); 142 | using vec_t = typename Ktraits::vec_t; 143 | Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( 144 | reinterpret_cast(Bvar), 145 | reinterpret_cast(B_vals_load) 146 | ); 147 | } else { 148 | Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); 149 | } 150 | // #pragma unroll 151 | // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } 152 | Converter::to_float(B_vals_load, B_vals); 153 | } 154 | 155 | template 156 | inline __device__ void store_output(typename Ktraits::input_t *out, 157 | const float (&out_vals)[Ktraits::kNItems], 158 | typename Ktraits::BlockStoreT::TempStorage &smem_store, 
159 | int seqlen) { 160 | typename Ktraits::input_t write_vals[Ktraits::kNItems]; 161 | #pragma unroll 162 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } 163 | if constexpr (Ktraits::kIsEvenLen) { 164 | auto& smem_store_vec = reinterpret_cast(smem_store); 165 | using vec_t = typename Ktraits::vec_t; 166 | Ktraits::BlockStoreVecT(smem_store_vec).Store( 167 | reinterpret_cast(out), 168 | reinterpret_cast(write_vals) 169 | ); 170 | } else { 171 | Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); 172 | } 173 | } 174 | 175 | template 176 | inline __device__ void store_output1(typename Ktraits::output_t *out, 177 | const float (&out_vals)[Ktraits::kNItems], 178 | typename Ktraits::BlockStoreOutputT::TempStorage &smem_store, 179 | int seqlen) { 180 | typename Ktraits::output_t write_vals[Ktraits::kNItems]; 181 | #pragma unroll 182 | for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } 183 | if constexpr (Ktraits::kIsEvenLen) { 184 | auto& smem_store_vec = reinterpret_cast(smem_store); 185 | using vec_t = typename Ktraits::vec_t; 186 | Ktraits::BlockStoreOutputVecT(smem_store_vec).Store( 187 | reinterpret_cast(out), 188 | reinterpret_cast(write_vals) 189 | ); 190 | } else { 191 | Ktraits::BlockStoreOutputT(smem_store).Store(out, write_vals, seqlen); 192 | } 193 | } 194 | 195 | template 196 | inline __device__ void load_output(typename Ktraits::output_t *u, 197 | typename Ktraits::output_t (&u_vals)[Ktraits::kNItems], 198 | typename Ktraits::BlockLoadOutputT::TempStorage &smem_load, 199 | int seqlen) { 200 | if constexpr (Ktraits::kIsEvenLen) { 201 | auto& smem_load_vec = reinterpret_cast(smem_load); 202 | using vec_t = typename Ktraits::vec_t; 203 | Ktraits::BlockLoadOutputVecT(smem_load_vec).Load( 204 | reinterpret_cast(u), 205 | reinterpret_cast(u_vals) 206 | ); 207 | } else { 208 | Ktraits::BlockLoadOutputT(smem_load).Load(u, u_vals, seqlen, 0.f); 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /classification/data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 
20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | return img.convert('RGB') 190 | 191 | 192 | def accimage_loader(path): 193 | import accimage 194 | try: 195 | return accimage.Image(path) 196 | except IOError: 197 | # Potentially a decoding problem, fall back to PIL.Image 198 | return pil_loader(path) 199 | 200 | 201 | def default_img_loader(path): 202 | from torchvision import get_image_backend 203 | if get_image_backend() == 'accimage': 204 | return accimage_loader(path) 205 | else: 206 | return pil_loader(path) 207 | 208 | 209 | class CachedImageFolder(DatasetFolder): 210 | """A generic data loader where the images are arranged in this way: :: 211 | root/dog/xxx.png 212 | root/dog/xxy.png 213 | root/dog/xxz.png 214 | root/cat/123.png 215 | root/cat/nsdf3.png 216 | root/cat/asd932_.png 217 | Args: 218 | root (string): Root directory path. 219 | transform (callable, optional): A function/transform that takes in an PIL image 220 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 221 | target_transform (callable, optional): A function/transform that takes in the 222 | target and transforms it. 223 | loader (callable, optional): A function to load an image given its path. 224 | Attributes: 225 | imgs (list): List of (image path, class_index) tuples 226 | """ 227 | 228 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 229 | loader=default_img_loader, cache_mode="no"): 230 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 231 | ann_file=ann_file, img_prefix=img_prefix, 232 | transform=transform, target_transform=target_transform, 233 | cache_mode=cache_mode) 234 | self.imgs = self.samples 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index (int): Index 240 | Returns: 241 | tuple: (image, target) where target is class_index of the target class. 
242 | """ 243 | path, target = self.samples[index] 244 | image = self.loader(path) 245 | if self.transform is not None: 246 | img = self.transform(image) 247 | else: 248 | img = image 249 | if self.target_transform is not None: 250 | target = self.target_transform(target) 251 | 252 | return img, target 253 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_fwd_kernel_ndstate.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "selective_scan_ndstate.h" 16 | #include "selective_scan_common.h" 17 | #include "static_switch.h" 18 | 19 | template 20 | struct Selective_Scan_fwd_kernel_traits { 21 | static_assert(kNItems_ % 4 == 0); 22 | using input_t = input_t_; 23 | using weight_t = weight_t_; 24 | static constexpr int kNThreads = kNThreads_; 25 | // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. 26 | static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; 27 | static constexpr int kNItems = kNItems_; 28 | static constexpr int kNBytes = sizeof(input_t); 29 | static_assert(kNBytes == 2 || kNBytes == 4); 30 | static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); 31 | static_assert(kNItems % kNElts == 0); 32 | static constexpr int kNLoads = kNItems / kNElts; 33 | static constexpr bool kIsEvenLen = kIsEvenLen_; 34 | 35 | static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; 36 | 37 | using vec_t = typename BytesToType::Type; 38 | using scan_t = float2; 39 | using BlockLoadT = cub::BlockLoad; 40 | using BlockLoadVecT = cub::BlockLoad; 42 | using BlockLoadWeightT = cub::BlockLoad; 43 | using BlockLoadWeightVecT = cub::BlockLoad; 45 | using BlockStoreT = cub::BlockStore; 46 | using BlockStoreVecT = cub::BlockStore; 48 | // using BlockScanT = cub::BlockScan; 49 | // using BlockScanT = cub::BlockScan; 50 | using BlockScanT = cub::BlockScan; 51 | static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), 52 | sizeof(typename BlockLoadVecT::TempStorage), 53 | 2 * sizeof(typename BlockLoadWeightT::TempStorage), 54 | 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), 55 | sizeof(typename BlockStoreT::TempStorage), 56 | sizeof(typename BlockStoreVecT::TempStorage)}); 57 | static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); 58 | }; 59 | 60 | template 61 | __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) 62 | void selective_scan_fwd_kernel(SSMParamsBase params) { 63 | constexpr int kNThreads = Ktraits::kNThreads; 64 | constexpr int kNItems = Ktraits::kNItems; 65 | constexpr bool kDirectIO = Ktraits::kDirectIO; 66 | using input_t = typename Ktraits::input_t; 67 | using weight_t = typename Ktraits::weight_t; 68 | using scan_t = typename Ktraits::scan_t; 69 | 70 | // Shared memory. 
71 | extern __shared__ char smem_[]; 72 | auto& smem_load = reinterpret_cast(smem_); 73 | auto& smem_load_weight = reinterpret_cast(smem_); 74 | auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); 75 | auto& smem_store = reinterpret_cast(smem_); 76 | auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 77 | scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); 78 | 79 | const int batch_id = blockIdx.x; 80 | const int dim_id = blockIdx.y; 81 | const int group_id = dim_id / (params.dim_ngroups_ratio); 82 | input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride 83 | + dim_id * params.u_d_stride; 84 | input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride 85 | + dim_id * params.delta_d_stride; 86 | constexpr float kLog2e = M_LOG2E; 87 | weight_t A_val = reinterpret_cast(params.A_ptr)[dim_id] * kLog2e; 88 | input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; 89 | input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; 90 | scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks; 91 | 92 | float D_val = 0; // attention! 93 | if (params.D_ptr != nullptr) { 94 | D_val = reinterpret_cast(params.D_ptr)[dim_id]; 95 | } 96 | float delta_bias = 0; 97 | if (params.delta_bias_ptr != nullptr) { 98 | delta_bias = reinterpret_cast(params.delta_bias_ptr)[dim_id]; 99 | } 100 | 101 | constexpr int kChunkSize = kNThreads * kNItems; 102 | for (int chunk = 0; chunk < params.n_chunks; ++chunk) { 103 | input_t u_vals[kNItems], delta_vals_load[kNItems]; 104 | __syncthreads(); 105 | load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); 106 | if constexpr (!kDirectIO) { __syncthreads(); } 107 | load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); 108 | u += kChunkSize; 109 | delta += kChunkSize; 110 | 111 | float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems]; 112 | #pragma unroll 113 | for (int i = 0; i < kNItems; ++i) { 114 | float u_val = float(u_vals[i]); 115 | delta_vals[i] = float(delta_vals_load[i]) + delta_bias; 116 | if (params.delta_softplus) { 117 | delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; 118 | } 119 | delta_u_vals[i] = delta_vals[i] * u_val; 120 | out_vals[i] = D_val * u_val; 121 | } 122 | 123 | __syncthreads(); 124 | { 125 | weight_t B_vals[kNItems], C_vals[kNItems]; 126 | load_weight(Bvar, B_vals, 127 | smem_load_weight, (params.seqlen - chunk * kChunkSize)); 128 | auto &smem_load_weight_C = smem_load_weight1; 129 | load_weight(Cvar, C_vals, 130 | smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); 131 | __syncthreads(); 132 | scan_t thread_data[kNItems]; 133 | #pragma unroll 134 | for (int i = 0; i < kNItems; ++i) { 135 | thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]); 136 | if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct 137 | if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { 138 | thread_data[i] = make_float2(1.f, 0.f); 139 | } 140 | } 141 | } 142 | // Initialize running total 143 | scan_t running_prefix; 144 | // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read 145 | running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? 
smem_running_prefix[0] : make_float2(1.f, 0.f); 146 | SSMScanPrefixCallbackOp prefix_op(running_prefix); 147 | Ktraits::BlockScanT(smem_scan).InclusiveScan( 148 | thread_data, thread_data, SSMScanOp(), prefix_op 149 | ); 150 | // There's a syncthreads in the scan op, so we don't need to sync here. 151 | // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 152 | if (threadIdx.x == 0) { 153 | smem_running_prefix[0] = prefix_op.running_prefix; 154 | x[chunk] = prefix_op.running_prefix; 155 | } 156 | #pragma unroll 157 | for (int i = 0; i < kNItems; ++i) { 158 | out_vals[i] += thread_data[i].y * C_vals[i]; 159 | } 160 | } 161 | 162 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 163 | + dim_id * params.out_d_stride + chunk * kChunkSize; 164 | __syncthreads(); 165 | store_output(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize); 166 | Bvar += kChunkSize; 167 | Cvar += kChunkSize; 168 | } 169 | } 170 | 171 | template 172 | void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { 173 | BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { 174 | using Ktraits = Selective_Scan_fwd_kernel_traits; 175 | constexpr int kSmemSize = Ktraits::kSmemSize + sizeof(typename Ktraits::scan_t); 176 | // printf("smem_size = %d\n", kSmemSize); 177 | dim3 grid(params.batch, params.dim); 178 | auto kernel = &selective_scan_fwd_kernel; 179 | if (kSmemSize >= 48 * 1024) { 180 | C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 181 | } 182 | kernel<<>>(params); 183 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 184 | }); 185 | } 186 | 187 | template 188 | void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { 189 | if (params.seqlen <= 128) { 190 | selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); 191 | } else if (params.seqlen <= 256) { 192 | selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); 193 | } else if (params.seqlen <= 512) { 194 | selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); 195 | } else if (params.seqlen <= 1024) { 196 | selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); 197 | } else { 198 | selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /classification/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Modified by $@#Anonymous#@$ 3 | # -------------------------------------------------------- 4 | # Swin Transformer 5 | # Copyright (c) 2021 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ze Liu 8 | # --------------------------------------------------------' 9 | 10 | import os 11 | import yaml 12 | from yacs.config import CfgNode as CN 13 | 14 | _C = CN() 15 | 16 | # Base config files 17 | _C.BASE = [''] 18 | 19 | # ----------------------------------------------------------------------------- 20 | # Data settings 21 | # ----------------------------------------------------------------------------- 22 | _C.DATA = CN() 23 | # Batch size for a single GPU, could be overwritten by command line argument 24 | _C.DATA.BATCH_SIZE = 128 25 | # Path to dataset, could be overwritten by command line argument 26 | _C.DATA.DATA_PATH = '' 27 | # Dataset name 28 | _C.DATA.DATASET = 'imagenet' 29 | # 
Input image size 30 | _C.DATA.IMG_SIZE = 224 31 | # Interpolation to resize image (random, bilinear, bicubic) 32 | _C.DATA.INTERPOLATION = 'bicubic' 33 | # Use zipped dataset instead of folder dataset 34 | # could be overwritten by command line argument 35 | _C.DATA.ZIP_MODE = False 36 | # Cache Data in Memory, could be overwritten by command line argument 37 | _C.DATA.CACHE_MODE = 'part' 38 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 39 | _C.DATA.PIN_MEMORY = True 40 | # Number of data loading threads 41 | _C.DATA.NUM_WORKERS = 8 42 | 43 | # [SimMIM] Mask patch size for MaskGenerator 44 | _C.DATA.MASK_PATCH_SIZE = 32 45 | # [SimMIM] Mask ratio for MaskGenerator 46 | _C.DATA.MASK_RATIO = 0.6 47 | 48 | # ----------------------------------------------------------------------------- 49 | # Model settings 50 | # ----------------------------------------------------------------------------- 51 | _C.MODEL = CN() 52 | # Model type 53 | _C.MODEL.TYPE = 'vssm' 54 | # Model name 55 | _C.MODEL.NAME = 'vssm_tiny_224' 56 | # Pretrained weight from checkpoint, could be imagenet22k pretrained weight 57 | # could be overwritten by command line argument 58 | _C.MODEL.PRETRAINED = '' 59 | # Checkpoint to resume, could be overwritten by command line argument 60 | _C.MODEL.RESUME = '' 61 | # Number of classes, overwritten in data preparation 62 | _C.MODEL.NUM_CLASSES = 100 63 | # Dropout rate 64 | _C.MODEL.DROP_RATE = 0.0 65 | # Drop path rate 66 | _C.MODEL.DROP_PATH_RATE = 0.1 67 | # Label Smoothing 68 | _C.MODEL.LABEL_SMOOTHING = 0.1 69 | 70 | # MMpretrain models for test 71 | _C.MODEL.MMCKPT = False 72 | 73 | # VSSM parameters 74 | _C.MODEL.VSSM = CN() 75 | _C.MODEL.VSSM.PATCH_SIZE = 16 76 | _C.MODEL.VSSM.IN_CHANS = 3 77 | _C.MODEL.VSSM.DEPTHS = [2, 2, 9, 2] 78 | _C.MODEL.VSSM.EMBED_DIM = 96  # NOTE: the model yaml config overrides this value; check it before tuning this hyperparameter 79 | _C.MODEL.VSSM.SSM_D_STATE = 16 80 | _C.MODEL.VSSM.SSM_RATIO = 2.0 81 | _C.MODEL.VSSM.SSM_RANK_RATIO = 2.0 82 | _C.MODEL.VSSM.SSM_DT_RANK = "auto" 83 | _C.MODEL.VSSM.SSM_ACT_LAYER = "silu" 84 | _C.MODEL.VSSM.SSM_CONV = 1 85 | _C.MODEL.VSSM.SSM_CONV_BIAS = True 86 | _C.MODEL.VSSM.SSM_DROP_RATE = 0.0 87 | _C.MODEL.VSSM.SSM_INIT = "v0" 88 | _C.MODEL.VSSM.SSM_FORWARDTYPE = "v2" 89 | _C.MODEL.VSSM.MLP_RATIO = 4.0 90 | _C.MODEL.VSSM.MLP_ACT_LAYER = "gelu" 91 | _C.MODEL.VSSM.MLP_DROP_RATE = 0.0 92 | _C.MODEL.VSSM.PATCH_NORM = True 93 | _C.MODEL.VSSM.NORM_LAYER = "ln" 94 | _C.MODEL.VSSM.DOWNSAMPLE = "v2" 95 | _C.MODEL.VSSM.PATCHEMBED = "v2" 96 | _C.MODEL.VSSM.POSEMBED = False 97 | _C.MODEL.VSSM.GMLP = False 98 | _C.MODEL.VSSM.TOP_K = 4 99 | _C.MODEL.VSSM.KNN = 5 100 | _C.MODEL.VSSM.ALPHA = 100 101 | _C.MODEL.VSSM.AMBIGUITY = False 102 | _C.MODEL.VSSM.BINARY = False 103 | _C.MODEL.VSSM.K_GROUP = 8 104 | _C.MODEL.VSSM.DIVISION_RATE = 4 105 | _C.MODEL.VSSM.MODE = "RFN" 106 | _C.MODEL.VSSM.DIMENSION = "INCREASE" 107 | _C.MODEL.VSSM.CSMS6S_MODE = "NORMAL" 108 | # one of these three (MODE / DIMENSION / CSMS6S_MODE) 109 | 110 | # ----------------------------------------------------------------------------- 111 | # Training settings 112 | # ----------------------------------------------------------------------------- 113 | _C.TRAIN = CN() 114 | _C.TRAIN.START_EPOCH = 0 115 | _C.TRAIN.EPOCHS = 300 116 | _C.TRAIN.WARMUP_EPOCHS = 20 117 | _C.TRAIN.WEIGHT_DECAY = 0.05 118 | _C.TRAIN.BASE_LR = 5e-4 119 | _C.TRAIN.WARMUP_LR = 5e-7 120 | _C.TRAIN.MIN_LR = 5e-6 121 | # Clip gradient norm 122 | _C.TRAIN.CLIP_GRAD = 5.0 123 | #
Auto resume from latest checkpoint 124 | _C.TRAIN.AUTO_RESUME = True 125 | # Gradient accumulation steps 126 | # could be overwritten by command line argument 127 | _C.TRAIN.ACCUMULATION_STEPS = 1 128 | # Whether to use gradient checkpointing to save memory 129 | # could be overwritten by command line argument 130 | _C.TRAIN.USE_CHECKPOINT = False 131 | 132 | # LR scheduler 133 | _C.TRAIN.LR_SCHEDULER = CN() 134 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 135 | # Epoch interval to decay LR, used in StepLRScheduler 136 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 137 | # LR decay rate, used in StepLRScheduler 138 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 139 | # warmup_prefix used in CosineLRScheduler 140 | _C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True 141 | # [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler 142 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 143 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] 144 | 145 | # Optimizer 146 | _C.TRAIN.OPTIMIZER = CN() 147 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 148 | # Optimizer Epsilon 149 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 150 | # Optimizer Betas 151 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 152 | # SGD momentum 153 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 154 | 155 | # [SimMIM] Layer decay for fine-tuning 156 | _C.TRAIN.LAYER_DECAY = 1.0 157 | 158 | # MoE 159 | _C.TRAIN.MOE = CN() 160 | # Only save model on master device 161 | _C.TRAIN.MOE.SAVE_MASTER = False 162 | # ----------------------------------------------------------------------------- 163 | # Augmentation settings 164 | # ----------------------------------------------------------------------------- 165 | _C.AUG = CN() 166 | # Color jitter factor 167 | _C.AUG.COLOR_JITTER = 0.4 168 | # Use AutoAugment policy. "v0" or "original" 169 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 170 | # Random erase prob 171 | _C.AUG.REPROB = 0.25 172 | # Random erase mode 173 | _C.AUG.REMODE = 'pixel' 174 | # Random erase count 175 | _C.AUG.RECOUNT = 1 176 | # Mixup alpha, mixup enabled if > 0 177 | _C.AUG.MIXUP = 0.8 178 | # Cutmix alpha, cutmix enabled if > 0 179 | _C.AUG.CUTMIX = 1.0 180 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 181 | _C.AUG.CUTMIX_MINMAX = None 182 | # Probability of performing mixup or cutmix when either/both is enabled 183 | _C.AUG.MIXUP_PROB = 1.0 184 | # Probability of switching to cutmix when both mixup and cutmix enabled 185 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 186 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 187 | _C.AUG.MIXUP_MODE = 'batch' 188 | 189 | # ----------------------------------------------------------------------------- 190 | # Testing settings 191 | # ----------------------------------------------------------------------------- 192 | _C.TEST = CN() 193 | # Whether to use center crop when testing 194 | _C.TEST.CROP = True 195 | # Whether to use SequentialSampler as validation sampler 196 | _C.TEST.SEQUENTIAL = False 197 | _C.TEST.SHUFFLE = False 198 | 199 | # ----------------------------------------------------------------------------- 200 | # Misc 201 | # ----------------------------------------------------------------------------- 202 | # [SimMIM] Whether to enable pytorch amp, overwritten by command line argument 203 | _C.ENABLE_AMP = False 204 | 205 | # Enable Pytorch automatic mixed precision (amp). 
206 | _C.AMP_ENABLE = True 207 | # [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2') 208 | _C.AMP_OPT_LEVEL = '' 209 | # Path to output folder, overwritten by command line argument 210 | _C.OUTPUT = '' 211 | # Tag of experiment, overwritten by command line argument 212 | _C.TAG = 'default' 213 | # Frequency to save checkpoint 214 | _C.SAVE_FREQ = 1 215 | # Frequency to logging info 216 | _C.PRINT_FREQ = 10 217 | # Fixed random seed 218 | _C.SEED = 0 219 | # Perform evaluation only, overwritten by command line argument 220 | _C.EVAL_MODE = False 221 | # Test throughput only, overwritten by command line argument 222 | _C.THROUGHPUT_MODE = False 223 | # Test traincost only, overwritten by command line argument 224 | _C.TRAINCOST_MODE = False 225 | # for acceleration 226 | _C.FUSED_LAYERNORM = False 227 | 228 | 229 | def _update_config_from_file(config, cfg_file): 230 | config.defrost() 231 | with open(cfg_file, 'r') as f: 232 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 233 | 234 | for cfg in yaml_cfg.setdefault('BASE', ['']): 235 | if cfg: 236 | _update_config_from_file( 237 | config, os.path.join(os.path.dirname(cfg_file), cfg) 238 | ) 239 | print('=> merge config from {}'.format(cfg_file)) 240 | config.merge_from_file(cfg_file) 241 | config.freeze() 242 | 243 | 244 | def update_config(config, args): 245 | _update_config_from_file(config, args.cfg) 246 | 247 | config.defrost() 248 | if args.opts: 249 | config.merge_from_list(args.opts) 250 | 251 | def _check_args(name): 252 | if hasattr(args, name) and eval(f'args.{name}'): 253 | return True 254 | return False 255 | 256 | # merge from specific arguments 257 | if _check_args('batch_size'): 258 | config.DATA.BATCH_SIZE = args.batch_size 259 | if _check_args('data_path'): 260 | config.DATA.DATA_PATH = args.data_path 261 | if _check_args('zip'): 262 | config.DATA.ZIP_MODE = True 263 | if _check_args('cache_mode'): 264 | config.DATA.CACHE_MODE = args.cache_mode 265 | if _check_args('pretrained'): 266 | config.MODEL.PRETRAINED = args.pretrained 267 | if _check_args('resume'): 268 | config.MODEL.RESUME = args.resume 269 | if _check_args('accumulation_steps'): 270 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 271 | if _check_args('use_checkpoint'): 272 | config.TRAIN.USE_CHECKPOINT = True 273 | if _check_args('disable_amp'): 274 | config.AMP_ENABLE = False 275 | if _check_args('output'): 276 | config.OUTPUT = args.output 277 | if _check_args('tag'): 278 | config.TAG = args.tag 279 | if _check_args('eval'): 280 | config.EVAL_MODE = True 281 | if _check_args('throughput'): 282 | config.THROUGHPUT_MODE = True 283 | if _check_args('traincost'): 284 | config.TRAINCOST_MODE = True 285 | 286 | # [SimMIM] 287 | if _check_args('enable_amp'): 288 | config.ENABLE_AMP = args.enable_amp 289 | 290 | # for acceleration 291 | if _check_args('fused_layernorm'): 292 | config.FUSED_LAYERNORM = True 293 | ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb] 294 | if _check_args('optim'): 295 | config.TRAIN.OPTIMIZER.NAME = args.optim 296 | 297 | # output folder 298 | 299 | config.OUTPUT = os.path.join(config.OUTPUT, f"mode({config.MODEL.VSSM.MODE})", config.MODEL.NAME, config.TAG) 300 | 301 | config.freeze() 302 | 303 | 304 | def get_config(args): 305 | """Get a yacs CfgNode object with default values.""" 306 | # Return a clone so that the defaults will not be altered 307 | # This is for the "local variable" use pattern 308 | config = _C.clone() 309 | update_config(config, 
args) 310 | 311 | return config 312 | -------------------------------------------------------------------------------- /classification/data/map22kto1k.txt: -------------------------------------------------------------------------------- 1 | 359 2 | 368 3 | 460 4 | 475 5 | 486 6 | 492 7 | 496 8 | 514 9 | 516 10 | 525 11 | 547 12 | 548 13 | 556 14 | 563 15 | 575 16 | 641 17 | 648 18 | 723 19 | 733 20 | 765 21 | 801 22 | 826 23 | 852 24 | 858 25 | 878 26 | 896 27 | 900 28 | 905 29 | 908 30 | 910 31 | 935 32 | 946 33 | 947 34 | 994 35 | 999 36 | 1003 37 | 1005 38 | 1010 39 | 1027 40 | 1029 41 | 1048 42 | 1055 43 | 1064 44 | 1065 45 | 1069 46 | 1075 47 | 1079 48 | 1081 49 | 1085 50 | 1088 51 | 1093 52 | 1106 53 | 1143 54 | 1144 55 | 1145 56 | 1147 57 | 1168 58 | 1171 59 | 1178 60 | 1187 61 | 1190 62 | 1197 63 | 1205 64 | 1216 65 | 1223 66 | 1230 67 | 1236 68 | 1241 69 | 1245 70 | 1257 71 | 1259 72 | 1260 73 | 1267 74 | 1268 75 | 1269 76 | 1271 77 | 1272 78 | 1273 79 | 1277 80 | 1303 81 | 1344 82 | 1349 83 | 1355 84 | 1357 85 | 1384 86 | 1388 87 | 1391 88 | 1427 89 | 1429 90 | 1432 91 | 1437 92 | 1450 93 | 1461 94 | 1462 95 | 1474 96 | 1502 97 | 1503 98 | 1512 99 | 1552 100 | 1555 101 | 1577 102 | 1584 103 | 1587 104 | 1589 105 | 1599 106 | 1615 107 | 1616 108 | 1681 109 | 1692 110 | 1701 111 | 1716 112 | 1729 113 | 1757 114 | 1759 115 | 1764 116 | 1777 117 | 1786 118 | 1822 119 | 1841 120 | 1842 121 | 1848 122 | 1850 123 | 1856 124 | 1860 125 | 1861 126 | 1864 127 | 1876 128 | 1897 129 | 1898 130 | 1910 131 | 1913 132 | 1918 133 | 1922 134 | 1928 135 | 1932 136 | 1935 137 | 1947 138 | 1951 139 | 1953 140 | 1970 141 | 1977 142 | 1979 143 | 2001 144 | 2017 145 | 2067 146 | 2081 147 | 2087 148 | 2112 149 | 2128 150 | 2135 151 | 2147 152 | 2174 153 | 2175 154 | 2176 155 | 2177 156 | 2178 157 | 2181 158 | 2183 159 | 2184 160 | 2187 161 | 2189 162 | 2190 163 | 2191 164 | 2192 165 | 2193 166 | 2197 167 | 2202 168 | 2203 169 | 2206 170 | 2208 171 | 2209 172 | 2211 173 | 2212 174 | 2213 175 | 2214 176 | 2215 177 | 2216 178 | 2217 179 | 2219 180 | 2222 181 | 2223 182 | 2224 183 | 2225 184 | 2226 185 | 2227 186 | 2228 187 | 2229 188 | 2230 189 | 2236 190 | 2238 191 | 2240 192 | 2241 193 | 2242 194 | 2243 195 | 2244 196 | 2245 197 | 2247 198 | 2248 199 | 2249 200 | 2250 201 | 2251 202 | 2252 203 | 2255 204 | 2256 205 | 2257 206 | 2262 207 | 2263 208 | 2264 209 | 2265 210 | 2266 211 | 2268 212 | 2270 213 | 2271 214 | 2272 215 | 2273 216 | 2275 217 | 2276 218 | 2279 219 | 2280 220 | 2281 221 | 2282 222 | 2285 223 | 2289 224 | 2292 225 | 2295 226 | 2296 227 | 2297 228 | 2298 229 | 2299 230 | 2300 231 | 2301 232 | 2302 233 | 2303 234 | 2304 235 | 2305 236 | 2306 237 | 2309 238 | 2310 239 | 2312 240 | 2313 241 | 2314 242 | 2315 243 | 2316 244 | 2318 245 | 2319 246 | 2321 247 | 2322 248 | 2326 249 | 2329 250 | 2330 251 | 2331 252 | 2332 253 | 2334 254 | 2335 255 | 2336 256 | 2337 257 | 2338 258 | 2339 259 | 2341 260 | 2342 261 | 2343 262 | 2344 263 | 2346 264 | 2348 265 | 2349 266 | 2351 267 | 2352 268 | 2353 269 | 2355 270 | 2357 271 | 2358 272 | 2359 273 | 2360 274 | 2364 275 | 2365 276 | 2368 277 | 2369 278 | 2377 279 | 2382 280 | 2383 281 | 2385 282 | 2397 283 | 2398 284 | 2400 285 | 2402 286 | 2405 287 | 2412 288 | 2421 289 | 2428 290 | 2431 291 | 2432 292 | 2433 293 | 2436 294 | 2441 295 | 2445 296 | 2450 297 | 2453 298 | 2454 299 | 2465 300 | 2469 301 | 2532 302 | 2533 303 | 2538 304 | 2544 305 | 2547 306 | 2557 307 | 2565 308 | 2578 309 | 2612 310 | 2658 311 | 2702 312 | 2722 313 | 2731 314 | 2738 315 | 
2741 316 | 2747 317 | 2810 318 | 2818 319 | 2833 320 | 2844 321 | 2845 322 | 2867 323 | 2874 324 | 2882 325 | 2884 326 | 2888 327 | 2889 328 | 3008 329 | 3012 330 | 3019 331 | 3029 332 | 3033 333 | 3042 334 | 3091 335 | 3106 336 | 3138 337 | 3159 338 | 3164 339 | 3169 340 | 3280 341 | 3296 342 | 3311 343 | 3318 344 | 3320 345 | 3324 346 | 3330 347 | 3366 348 | 3375 349 | 3381 350 | 3406 351 | 3419 352 | 3432 353 | 3434 354 | 3435 355 | 3493 356 | 3495 357 | 3503 358 | 3509 359 | 3511 360 | 3513 361 | 3517 362 | 3521 363 | 3526 364 | 3546 365 | 3554 366 | 3600 367 | 3601 368 | 3606 369 | 3612 370 | 3613 371 | 3616 372 | 3622 373 | 3623 374 | 3627 375 | 3632 376 | 3634 377 | 3636 378 | 3638 379 | 3644 380 | 3646 381 | 3649 382 | 3650 383 | 3651 384 | 3656 385 | 3663 386 | 3673 387 | 3674 388 | 3689 389 | 3690 390 | 3702 391 | 3733 392 | 3769 393 | 3971 394 | 3974 395 | 4065 396 | 4068 397 | 4073 398 | 4102 399 | 4136 400 | 4140 401 | 4151 402 | 4159 403 | 4165 404 | 4207 405 | 4219 406 | 4226 407 | 4249 408 | 4256 409 | 4263 410 | 4270 411 | 4313 412 | 4321 413 | 4378 414 | 4386 415 | 4478 416 | 4508 417 | 4512 418 | 4536 419 | 4542 420 | 4550 421 | 4560 422 | 4562 423 | 4570 424 | 4571 425 | 4572 426 | 4583 427 | 4588 428 | 4594 429 | 4604 430 | 4608 431 | 4623 432 | 4634 433 | 4636 434 | 4646 435 | 4651 436 | 4652 437 | 4686 438 | 4688 439 | 4691 440 | 4699 441 | 4724 442 | 4727 443 | 4737 444 | 4770 445 | 4774 446 | 4789 447 | 4802 448 | 4807 449 | 4819 450 | 4880 451 | 4886 452 | 4908 453 | 4927 454 | 4931 455 | 4936 456 | 4964 457 | 4976 458 | 4993 459 | 5028 460 | 5033 461 | 5043 462 | 5046 463 | 5096 464 | 5111 465 | 5114 466 | 5131 467 | 5132 468 | 5183 469 | 5199 470 | 5235 471 | 5275 472 | 5291 473 | 5293 474 | 5294 475 | 5343 476 | 5360 477 | 5362 478 | 5364 479 | 5390 480 | 5402 481 | 5418 482 | 5428 483 | 5430 484 | 5437 485 | 5443 486 | 5473 487 | 5484 488 | 5486 489 | 5505 490 | 5507 491 | 5508 492 | 5510 493 | 5567 494 | 5578 495 | 5580 496 | 5584 497 | 5606 498 | 5613 499 | 5629 500 | 5672 501 | 5676 502 | 5692 503 | 5701 504 | 5760 505 | 5769 506 | 5770 507 | 5779 508 | 5814 509 | 5850 510 | 5871 511 | 5893 512 | 5911 513 | 5949 514 | 5954 515 | 6005 516 | 6006 517 | 6012 518 | 6017 519 | 6023 520 | 6024 521 | 6040 522 | 6050 523 | 6054 524 | 6087 525 | 6105 526 | 6157 527 | 6235 528 | 6237 529 | 6256 530 | 6259 531 | 6286 532 | 6291 533 | 6306 534 | 6339 535 | 6341 536 | 6343 537 | 6379 538 | 6383 539 | 6393 540 | 6405 541 | 6479 542 | 6511 543 | 6517 544 | 6541 545 | 6561 546 | 6608 547 | 6611 548 | 6615 549 | 6678 550 | 6682 551 | 6707 552 | 6752 553 | 6798 554 | 6850 555 | 6880 556 | 6885 557 | 6890 558 | 6920 559 | 6981 560 | 7000 561 | 7009 562 | 7038 563 | 7049 564 | 7050 565 | 7052 566 | 7073 567 | 7078 568 | 7098 569 | 7111 570 | 7165 571 | 7198 572 | 7204 573 | 7280 574 | 7283 575 | 7286 576 | 7287 577 | 7293 578 | 7294 579 | 7305 580 | 7318 581 | 7341 582 | 7346 583 | 7354 584 | 7382 585 | 7427 586 | 7428 587 | 7435 588 | 7445 589 | 7450 590 | 7455 591 | 7467 592 | 7469 593 | 7497 594 | 7502 595 | 7506 596 | 7514 597 | 7523 598 | 7651 599 | 7661 600 | 7664 601 | 7672 602 | 7679 603 | 7685 604 | 7696 605 | 7730 606 | 7871 607 | 7873 608 | 7895 609 | 7914 610 | 7915 611 | 7920 612 | 7934 613 | 7935 614 | 7949 615 | 8009 616 | 8036 617 | 8051 618 | 8065 619 | 8074 620 | 8090 621 | 8112 622 | 8140 623 | 8164 624 | 8168 625 | 8178 626 | 8182 627 | 8198 628 | 8212 629 | 8216 630 | 8230 631 | 8242 632 | 8288 633 | 8289 634 | 8295 635 | 8318 636 | 8352 637 | 8368 638 | 
8371 639 | 8375 640 | 8376 641 | 8401 642 | 8416 643 | 8419 644 | 8436 645 | 8460 646 | 8477 647 | 8478 648 | 8482 649 | 8498 650 | 8500 651 | 8539 652 | 8543 653 | 8552 654 | 8555 655 | 8580 656 | 8584 657 | 8586 658 | 8594 659 | 8598 660 | 8601 661 | 8606 662 | 8610 663 | 8611 664 | 8622 665 | 8627 666 | 8639 667 | 8649 668 | 8650 669 | 8653 670 | 8654 671 | 8667 672 | 8672 673 | 8673 674 | 8674 675 | 8676 676 | 8684 677 | 8720 678 | 8723 679 | 8750 680 | 8753 681 | 8801 682 | 8815 683 | 8831 684 | 8835 685 | 8842 686 | 8845 687 | 8858 688 | 8897 689 | 8916 690 | 8951 691 | 8954 692 | 8959 693 | 8970 694 | 8976 695 | 8981 696 | 8983 697 | 8989 698 | 8991 699 | 8993 700 | 9019 701 | 9039 702 | 9042 703 | 9043 704 | 9056 705 | 9057 706 | 9070 707 | 9087 708 | 9098 709 | 9106 710 | 9130 711 | 9131 712 | 9155 713 | 9171 714 | 9183 715 | 9198 716 | 9199 717 | 9201 718 | 9204 719 | 9212 720 | 9221 721 | 9225 722 | 9229 723 | 9250 724 | 9260 725 | 9271 726 | 9279 727 | 9295 728 | 9300 729 | 9310 730 | 9322 731 | 9345 732 | 9352 733 | 9376 734 | 9377 735 | 9382 736 | 9392 737 | 9401 738 | 9405 739 | 9441 740 | 9449 741 | 9464 742 | 9475 743 | 9502 744 | 9505 745 | 9514 746 | 9515 747 | 9545 748 | 9567 749 | 9576 750 | 9608 751 | 9609 752 | 9624 753 | 9633 754 | 9639 755 | 9643 756 | 9656 757 | 9674 758 | 9740 759 | 9752 760 | 9760 761 | 9767 762 | 9778 763 | 9802 764 | 9820 765 | 9839 766 | 9879 767 | 9924 768 | 9956 769 | 9961 770 | 9963 771 | 9970 772 | 9997 773 | 10010 774 | 10031 775 | 10040 776 | 10052 777 | 10073 778 | 10075 779 | 10078 780 | 10094 781 | 10097 782 | 10109 783 | 10118 784 | 10121 785 | 10124 786 | 10158 787 | 10226 788 | 10276 789 | 10304 790 | 10307 791 | 10314 792 | 10315 793 | 10332 794 | 10337 795 | 10338 796 | 10413 797 | 10423 798 | 10451 799 | 10463 800 | 10465 801 | 10487 802 | 10519 803 | 10522 804 | 10523 805 | 10532 806 | 10534 807 | 10535 808 | 10551 809 | 10559 810 | 10574 811 | 10583 812 | 10586 813 | 10589 814 | 10612 815 | 10626 816 | 10635 817 | 10638 818 | 10677 819 | 10683 820 | 10726 821 | 10776 822 | 10782 823 | 10783 824 | 10807 825 | 10837 826 | 10840 827 | 10848 828 | 10859 829 | 10871 830 | 10881 831 | 10884 832 | 10908 833 | 10914 834 | 10921 835 | 10936 836 | 10947 837 | 10951 838 | 10952 839 | 10957 840 | 10999 841 | 11003 842 | 11018 843 | 11023 844 | 11025 845 | 11027 846 | 11045 847 | 11055 848 | 11095 849 | 11110 850 | 11137 851 | 5564 852 | 11168 853 | 11186 854 | 11221 855 | 11223 856 | 11242 857 | 11255 858 | 11259 859 | 11279 860 | 11306 861 | 11311 862 | 11331 863 | 11367 864 | 11377 865 | 11389 866 | 11392 867 | 11401 868 | 11407 869 | 11437 870 | 11449 871 | 11466 872 | 11469 873 | 11473 874 | 11478 875 | 11483 876 | 11484 877 | 11507 878 | 11536 879 | 11558 880 | 11566 881 | 11575 882 | 11584 883 | 11594 884 | 11611 885 | 11612 886 | 11619 887 | 11621 888 | 11640 889 | 11643 890 | 11664 891 | 11674 892 | 11689 893 | 11709 894 | 11710 895 | 11716 896 | 11721 897 | 11726 898 | 11729 899 | 11743 900 | 11760 901 | 11771 902 | 11837 903 | 11839 904 | 11856 905 | 11876 906 | 11878 907 | 11884 908 | 11889 909 | 11896 910 | 11917 911 | 11923 912 | 11930 913 | 11944 914 | 11952 915 | 11980 916 | 11984 917 | 12214 918 | 12229 919 | 12239 920 | 12241 921 | 12242 922 | 12247 923 | 12283 924 | 12349 925 | 12369 926 | 12373 927 | 12422 928 | 12560 929 | 12566 930 | 12575 931 | 12688 932 | 12755 933 | 12768 934 | 12778 935 | 12780 936 | 12812 937 | 12832 938 | 12835 939 | 12836 940 | 12843 941 | 12847 942 | 12849 943 | 12850 944 | 12856 945 | 12858 
946 | 12873 947 | 12938 948 | 12971 949 | 13017 950 | 13038 951 | 13046 952 | 13059 953 | 13085 954 | 13086 955 | 13088 956 | 13094 957 | 13134 958 | 13182 959 | 13230 960 | 13406 961 | 13444 962 | 13614 963 | 13690 964 | 13698 965 | 13709 966 | 13749 967 | 13804 968 | 13982 969 | 14051 970 | 14059 971 | 14219 972 | 14246 973 | 14256 974 | 14264 975 | 14294 976 | 14324 977 | 14367 978 | 14389 979 | 14394 980 | 14438 981 | 14442 982 | 14965 983 | 15732 984 | 16744 985 | 18037 986 | 18205 987 | 18535 988 | 18792 989 | 19102 990 | 20019 991 | 20462 992 | 21026 993 | 21045 994 | 21163 995 | 21171 996 | 21181 997 | 21196 998 | 21200 999 | 21369 1000 | 21817 -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cus/selective_scan_fwd_kernel.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "selective_scan.h" 16 | #include "selective_scan_common.h" 17 | #include "static_switch.h" 18 | 19 | template 20 | struct Selective_Scan_fwd_kernel_traits { 21 | static_assert(kNItems_ % 4 == 0); 22 | using input_t = input_t_; 23 | using weight_t = weight_t_; 24 | static constexpr int kNThreads = kNThreads_; 25 | // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. 26 | static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; 27 | static constexpr int kNItems = kNItems_; 28 | static constexpr int MaxDState = MAX_DSTATE; 29 | static constexpr int kNBytes = sizeof(input_t); 30 | static_assert(kNBytes == 2 || kNBytes == 4); 31 | static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); 32 | static_assert(kNItems % kNElts == 0); 33 | static constexpr int kNLoads = kNItems / kNElts; 34 | static constexpr bool kIsEvenLen = kIsEvenLen_; 35 | 36 | static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; 37 | 38 | using vec_t = typename BytesToType::Type; 39 | using scan_t = float2; 40 | using BlockLoadT = cub::BlockLoad; 41 | using BlockLoadVecT = cub::BlockLoad; 43 | using BlockLoadWeightT = cub::BlockLoad; 44 | using BlockLoadWeightVecT = cub::BlockLoad; 46 | using BlockStoreT = cub::BlockStore; 47 | using BlockStoreVecT = cub::BlockStore; 49 | // using BlockScanT = cub::BlockScan; 50 | // using BlockScanT = cub::BlockScan; 51 | using BlockScanT = cub::BlockScan; 52 | static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), 53 | sizeof(typename BlockLoadVecT::TempStorage), 54 | 2 * sizeof(typename BlockLoadWeightT::TempStorage), 55 | 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), 56 | sizeof(typename BlockStoreT::TempStorage), 57 | sizeof(typename BlockStoreVecT::TempStorage)}); 58 | static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); 59 | }; 60 | 61 | template 62 | __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) 63 | void selective_scan_fwd_kernel(SSMParamsBase params) { 64 | constexpr int kNThreads = Ktraits::kNThreads; 65 | constexpr int kNItems = Ktraits::kNItems; 66 | constexpr bool kDirectIO = Ktraits::kDirectIO; 67 | using input_t = typename Ktraits::input_t; 68 | using weight_t = typename Ktraits::weight_t; 69 | using scan_t = typename Ktraits::scan_t; 70 | 71 | // Shared memory. 72 | extern __shared__ char smem_[]; 73 | auto& smem_load = reinterpret_cast(smem_); 74 | auto& smem_load_weight = reinterpret_cast(smem_); 75 | auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); 76 | auto& smem_store = reinterpret_cast(smem_); 77 | auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 78 | scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); 79 | 80 | const int batch_id = blockIdx.x; 81 | const int dim_id = blockIdx.y; 82 | const int group_id = dim_id / (params.dim_ngroups_ratio); 83 | input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride 84 | + dim_id * params.u_d_stride; 85 | input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride 86 | + dim_id * params.delta_d_stride; 87 | weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; 88 | input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; 89 | input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; 90 | scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate; 91 | 92 | float D_val = 0; // attention! 
93 | if (params.D_ptr != nullptr) { 94 | D_val = reinterpret_cast(params.D_ptr)[dim_id]; 95 | } 96 | float delta_bias = 0; 97 | if (params.delta_bias_ptr != nullptr) { 98 | delta_bias = reinterpret_cast(params.delta_bias_ptr)[dim_id]; 99 | } 100 | 101 | constexpr int kChunkSize = kNThreads * kNItems; 102 | for (int chunk = 0; chunk < params.n_chunks; ++chunk) { 103 | input_t u_vals[kNItems], delta_vals_load[kNItems]; 104 | __syncthreads(); 105 | load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); 106 | if constexpr (!kDirectIO) { __syncthreads(); } 107 | load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); 108 | u += kChunkSize; 109 | delta += kChunkSize; 110 | 111 | float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems]; 112 | #pragma unroll 113 | for (int i = 0; i < kNItems; ++i) { 114 | float u_val = float(u_vals[i]); 115 | delta_vals[i] = float(delta_vals_load[i]) + delta_bias; 116 | if (params.delta_softplus) { 117 | delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; 118 | } 119 | delta_u_vals[i] = delta_vals[i] * u_val; 120 | out_vals[i] = D_val * u_val; 121 | } 122 | 123 | __syncthreads(); 124 | for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { 125 | constexpr float kLog2e = M_LOG2E; 126 | weight_t A_val = A[state_idx * params.A_dstate_stride]; 127 | A_val *= kLog2e; 128 | weight_t B_vals[kNItems], C_vals[kNItems]; 129 | load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, 130 | smem_load_weight, (params.seqlen - chunk * kChunkSize)); 131 | load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, 132 | smem_load_weight1, (params.seqlen - chunk * kChunkSize)); 133 | __syncthreads(); 134 | scan_t thread_data[kNItems]; 135 | #pragma unroll 136 | for (int i = 0; i < kNItems; ++i) { 137 | thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]); 138 | if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct 139 | if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { 140 | thread_data[i] = make_float2(1.f, 0.f); 141 | } 142 | } 143 | } 144 | // Initialize running total 145 | scan_t running_prefix; 146 | // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read 147 | running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); 148 | // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); 149 | SSMScanPrefixCallbackOp prefix_op(running_prefix); 150 | Ktraits::BlockScanT(smem_scan).InclusiveScan( 151 | thread_data, thread_data, SSMScanOp(), prefix_op 152 | ); 153 | // There's a syncthreads in the scan op, so we don't need to sync here. 154 | // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
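// What the InclusiveScan above computes: each thread_data element packs
// (a_i, b_i) = (exp2f(delta_i * A_val), B_i * delta_i * u_i), with A_val pre-scaled by M_LOG2E so
// that exp2f evaluates exp(delta_i * A). The scan realises the first-order recurrence
// h_i = a_i * h_{i-1} + b_i; the .y component read back below is the hidden state h_i, which is
// combined with C_i on top of the D * u skip term already placed in out_vals. SSMScanOp itself is
// defined in selective_scan_common.h and is not reproduced in this dump; the sequential sketch
// below only illustrates the associative combine such an operator is assumed to implement
// (helper names are illustrative, not part of the codebase):
//
//   inline float2 ssm_combine(float2 left, float2 right) {
//       // applying (a_l, b_l) and then (a_r, b_r) collapses to (a_r * a_l, a_r * b_l + b_r);
//       // the identity is (1, 0), which is why padded items and the chunk-0 prefix use make_float2(1.f, 0.f)
//       return make_float2(right.x * left.x, right.x * left.y + right.y);
//   }
//   inline float2 ssm_scan_reference(const float2 *in, float2 *out, int n, float2 prefix) {
//       float2 acc = prefix;                  // running prefix carried over from the previous chunk
//       for (int i = 0; i < n; ++i) {
//           acc = ssm_combine(acc, in[i]);    // h_i = a_i * h_{i-1} + b_i
//           out[i] = acc;
//       }
//       return acc;                           // stored as the running prefix for the next chunk
//   }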
155 | if (threadIdx.x == 0) { 156 | smem_running_prefix[state_idx] = prefix_op.running_prefix; 157 | x[chunk * params.dstate + state_idx] = prefix_op.running_prefix; 158 | } 159 | #pragma unroll 160 | for (int i = 0; i < kNItems; ++i) { 161 | out_vals[i] += thread_data[i].y * C_vals[i]; 162 | } 163 | } 164 | 165 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 166 | + dim_id * params.out_d_stride + chunk * kChunkSize; 167 | __syncthreads(); 168 | store_output(out, out_vals, smem_store, params.seqlen - chunk * kChunkSize); 169 | Bvar += kChunkSize; 170 | Cvar += kChunkSize; 171 | } 172 | } 173 | 174 | template 175 | void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { 176 | BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { 177 | using Ktraits = Selective_Scan_fwd_kernel_traits; 178 | constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); 179 | // printf("smem_size = %d\n", kSmemSize); 180 | dim3 grid(params.batch, params.dim); 181 | auto kernel = &selective_scan_fwd_kernel; 182 | if (kSmemSize >= 48 * 1024) { 183 | C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 184 | } 185 | kernel<<>>(params); 186 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 187 | }); 188 | } 189 | 190 | template 191 | void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { 192 | if (params.seqlen <= 128) { 193 | selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); 194 | } else if (params.seqlen <= 256) { 195 | selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); 196 | } else if (params.seqlen <= 512) { 197 | selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); 198 | } else if (params.seqlen <= 1024) { 199 | selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); 200 | } else { 201 | selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusoflex/selective_scan_fwd_kernel_oflex.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "selective_scan.h" 16 | #include "selective_scan_common.h" 17 | #include "static_switch.h" 18 | 19 | template 20 | struct Selective_Scan_fwd_kernel_traits { 21 | static_assert(kNItems_ % 4 == 0); 22 | using input_t = input_t_; 23 | using weight_t = weight_t_; 24 | using output_t = output_t_; 25 | static constexpr int kNThreads = kNThreads_; 26 | // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. 27 | static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; 28 | static constexpr int kNItems = kNItems_; 29 | static constexpr int MaxDState = MAX_DSTATE; 30 | static constexpr int kNBytes = sizeof(input_t); 31 | static_assert(kNBytes == 2 || kNBytes == 4); 32 | static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); 33 | static_assert(kNItems % kNElts == 0); 34 | static constexpr int kNLoads = kNItems / kNElts; 35 | static constexpr bool kIsEvenLen = kIsEvenLen_; 36 | static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; 37 | static constexpr int kNLoadsOutput = sizeof(output_t) * kNLoads / kNBytes; 38 | static constexpr bool kDirectIOOutput = kDirectIO && (kNLoadsOutput == 1); 39 | using vec_t = typename BytesToType::Type; 40 | using scan_t = float2; 41 | using BlockLoadT = cub::BlockLoad; 42 | using BlockLoadVecT = cub::BlockLoad; 44 | using BlockLoadWeightT = cub::BlockLoad; 45 | using BlockLoadWeightVecT = cub::BlockLoad; 47 | using BlockStoreT = cub::BlockStore; 48 | using BlockStoreVecT = cub::BlockStore; 50 | using BlockStoreOutputT = cub::BlockStore; 51 | using BlockStoreOutputVecT = cub::BlockStore; 53 | // using BlockScanT = cub::BlockScan; 54 | // using BlockScanT = cub::BlockScan; 55 | using BlockScanT = cub::BlockScan; 56 | static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), 57 | sizeof(typename BlockLoadVecT::TempStorage), 58 | 2 * sizeof(typename BlockLoadWeightT::TempStorage), 59 | 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), 60 | sizeof(typename BlockStoreT::TempStorage), 61 | sizeof(typename BlockStoreVecT::TempStorage), 62 | sizeof(typename BlockStoreOutputT::TempStorage), 63 | sizeof(typename BlockStoreOutputVecT::TempStorage)}); 64 | static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); 65 | }; 66 | 67 | template 68 | __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) 69 | void selective_scan_fwd_kernel(SSMParamsBase params) { 70 | constexpr int kNThreads = Ktraits::kNThreads; 71 | constexpr int kNItems = Ktraits::kNItems; 72 | constexpr bool kDirectIO = Ktraits::kDirectIO; 73 | using input_t = typename Ktraits::input_t; 74 | using weight_t = typename Ktraits::weight_t; 75 | using output_t = typename Ktraits::output_t; 76 | using scan_t = typename Ktraits::scan_t; 77 | 78 | // Shared memory. 79 | extern __shared__ char smem_[]; 80 | auto& smem_load = reinterpret_cast(smem_); 81 | auto& smem_load_weight = reinterpret_cast(smem_); 82 | auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); 83 | auto& smem_store = reinterpret_cast(smem_); 84 | auto& smem_store1 = reinterpret_cast(smem_); 85 | auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 86 | scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); 87 | 88 | const int batch_id = blockIdx.x; 89 | const int dim_id = blockIdx.y; 90 | const int group_id = dim_id / (params.dim_ngroups_ratio); 91 | input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride 92 | + dim_id * params.u_d_stride; 93 | input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride 94 | + dim_id * params.delta_d_stride; 95 | weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; 96 | input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; 97 | input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; 98 | scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * params.n_chunks * params.dstate; 99 | 100 | float D_val = 0; // attention! 
101 | if (params.D_ptr != nullptr) { 102 | D_val = reinterpret_cast(params.D_ptr)[dim_id]; 103 | } 104 | float delta_bias = 0; 105 | if (params.delta_bias_ptr != nullptr) { 106 | delta_bias = reinterpret_cast(params.delta_bias_ptr)[dim_id]; 107 | } 108 | 109 | constexpr int kChunkSize = kNThreads * kNItems; 110 | for (int chunk = 0; chunk < params.n_chunks; ++chunk) { 111 | input_t u_vals[kNItems], delta_vals_load[kNItems]; 112 | __syncthreads(); 113 | load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); 114 | if constexpr (!kDirectIO) { __syncthreads(); } 115 | load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); 116 | u += kChunkSize; 117 | delta += kChunkSize; 118 | 119 | float delta_vals[kNItems], delta_u_vals[kNItems], out_vals[kNItems]; 120 | #pragma unroll 121 | for (int i = 0; i < kNItems; ++i) { 122 | float u_val = float(u_vals[i]); 123 | delta_vals[i] = float(delta_vals_load[i]) + delta_bias; 124 | if (params.delta_softplus) { 125 | delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; 126 | } 127 | delta_u_vals[i] = delta_vals[i] * u_val; 128 | out_vals[i] = D_val * u_val; 129 | } 130 | 131 | __syncthreads(); 132 | for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { 133 | constexpr float kLog2e = M_LOG2E; 134 | weight_t A_val = A[state_idx * params.A_dstate_stride]; 135 | A_val *= kLog2e; 136 | weight_t B_vals[kNItems], C_vals[kNItems]; 137 | load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, 138 | smem_load_weight, (params.seqlen - chunk * kChunkSize)); 139 | load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, 140 | smem_load_weight1, (params.seqlen - chunk * kChunkSize)); 141 | __syncthreads(); 142 | scan_t thread_data[kNItems]; 143 | #pragma unroll 144 | for (int i = 0; i < kNItems; ++i) { 145 | thread_data[i] = make_float2(exp2f(delta_vals[i] * A_val), B_vals[i] * delta_u_vals[i]); 146 | if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct 147 | if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { 148 | thread_data[i] = make_float2(1.f, 0.f); 149 | } 150 | } 151 | } 152 | // Initialize running total 153 | scan_t running_prefix; 154 | // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read 155 | running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); 156 | // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); 157 | SSMScanPrefixCallbackOp prefix_op(running_prefix); 158 | Ktraits::BlockScanT(smem_scan).InclusiveScan( 159 | thread_data, thread_data, SSMScanOp(), prefix_op 160 | ); 161 | // There's a syncthreads in the scan op, so we don't need to sync here. 162 | // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
163 | if (threadIdx.x == 0) { 164 | smem_running_prefix[state_idx] = prefix_op.running_prefix; 165 | x[chunk * params.dstate + state_idx] = prefix_op.running_prefix; 166 | } 167 | #pragma unroll 168 | for (int i = 0; i < kNItems; ++i) { 169 | out_vals[i] += thread_data[i].y * C_vals[i]; 170 | } 171 | } 172 | 173 | output_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 174 | + dim_id * params.out_d_stride + chunk * kChunkSize; 175 | __syncthreads(); 176 | store_output1(out, out_vals, smem_store1, params.seqlen - chunk * kChunkSize); 177 | Bvar += kChunkSize; 178 | Cvar += kChunkSize; 179 | } 180 | } 181 | 182 | template 183 | void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { 184 | BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { 185 | using Ktraits = Selective_Scan_fwd_kernel_traits; 186 | constexpr int kSmemSize = Ktraits::kSmemSize + Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); 187 | // printf("smem_size = %d\n", kSmemSize); 188 | dim3 grid(params.batch, params.dim); 189 | auto kernel = &selective_scan_fwd_kernel; 190 | if (kSmemSize >= 48 * 1024) { 191 | C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 192 | } 193 | kernel<<>>(params); 194 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 195 | }); 196 | } 197 | 198 | template 199 | void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { 200 | if (params.seqlen <= 128) { 201 | selective_scan_fwd_launch<32, 4, input_t, weight_t, output_t>(params, stream); 202 | } else if (params.seqlen <= 256) { 203 | selective_scan_fwd_launch<32, 8, input_t, weight_t, output_t>(params, stream); 204 | } else if (params.seqlen <= 512) { 205 | selective_scan_fwd_launch<32, 16, input_t, weight_t, output_t>(params, stream); 206 | } else if (params.seqlen <= 1024) { 207 | selective_scan_fwd_launch<64, 16, input_t, weight_t, output_t>(params, stream); 208 | } else { 209 | selective_scan_fwd_launch<128, 16, input_t, weight_t, output_t>(params, stream); 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /kernels/selective_scan/csrc/selective_scan/cusnrow/selective_scan_fwd_kernel_nrow.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "selective_scan.h" 16 | #include "selective_scan_common.h" 17 | #include "static_switch.h" 18 | 19 | template 20 | struct Selective_Scan_fwd_kernel_traits { 21 | static_assert(kNItems_ % 4 == 0); 22 | using input_t = input_t_; 23 | using weight_t = weight_t_; 24 | static constexpr int kNThreads = kNThreads_; 25 | // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. 26 | static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; 27 | static constexpr int kNItems = kNItems_; 28 | static constexpr int kNRows = kNRows_; 29 | static constexpr int MaxDState = MAX_DSTATE / kNRows; 30 | static constexpr int kNBytes = sizeof(input_t); 31 | static_assert(kNBytes == 2 || kNBytes == 4); 32 | static constexpr int kNElts = kNBytes == 4 ? 
4 : std::min(8, kNItems); 33 | static_assert(kNItems % kNElts == 0); 34 | static constexpr int kNLoads = kNItems / kNElts; 35 | static constexpr bool kIsEvenLen = kIsEvenLen_; 36 | 37 | static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; 38 | 39 | using vec_t = typename BytesToType::Type; 40 | using scan_t = float2; 41 | using BlockLoadT = cub::BlockLoad; 42 | using BlockLoadVecT = cub::BlockLoad; 44 | using BlockLoadWeightT = cub::BlockLoad; 45 | using BlockLoadWeightVecT = cub::BlockLoad; 47 | using BlockStoreT = cub::BlockStore; 48 | using BlockStoreVecT = cub::BlockStore; 50 | // using BlockScanT = cub::BlockScan; 51 | // using BlockScanT = cub::BlockScan; 52 | using BlockScanT = cub::BlockScan; 53 | static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), 54 | sizeof(typename BlockLoadVecT::TempStorage), 55 | 2 * sizeof(typename BlockLoadWeightT::TempStorage), 56 | 2 * sizeof(typename BlockLoadWeightVecT::TempStorage), 57 | sizeof(typename BlockStoreT::TempStorage), 58 | sizeof(typename BlockStoreVecT::TempStorage)}); 59 | static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); 60 | }; 61 | 62 | template 63 | __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) 64 | void selective_scan_fwd_kernel(SSMParamsBase params) { 65 | constexpr int kNThreads = Ktraits::kNThreads; 66 | constexpr int kNItems = Ktraits::kNItems; 67 | constexpr int kNRows = Ktraits::kNRows; 68 | constexpr bool kDirectIO = Ktraits::kDirectIO; 69 | using input_t = typename Ktraits::input_t; 70 | using weight_t = typename Ktraits::weight_t; 71 | using scan_t = typename Ktraits::scan_t; 72 | 73 | // Shared memory. 74 | extern __shared__ char smem_[]; 75 | auto& smem_load = reinterpret_cast(smem_); 76 | auto& smem_load_weight = reinterpret_cast(smem_); 77 | auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); 78 | auto& smem_store = reinterpret_cast(smem_); 79 | auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 80 | scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); 81 | 82 | const int batch_id = blockIdx.x; 83 | const int dim_id = blockIdx.y; 84 | const int dim_id_nrow = dim_id * kNRows; 85 | const int group_id = dim_id_nrow / (params.dim_ngroups_ratio); 86 | input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride 87 | + dim_id_nrow * params.u_d_stride; 88 | input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride 89 | + dim_id_nrow * params.delta_d_stride; 90 | weight_t *A = reinterpret_cast(params.A_ptr) + dim_id_nrow * params.A_d_stride; 91 | input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; 92 | input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; 93 | scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id_nrow) * params.n_chunks * params.dstate; 94 | 95 | float D_val[kNRows] = {0}; 96 | if (params.D_ptr != nullptr) { 97 | #pragma unroll 98 | for (int r = 0; r < kNRows; ++r) { 99 | D_val[r] = reinterpret_cast(params.D_ptr)[dim_id_nrow + r]; 100 | } 101 | } 102 | float delta_bias[kNRows] = {0}; 103 | if (params.delta_bias_ptr != nullptr) { 104 | #pragma unroll 105 | for (int r = 0; r < kNRows; ++r) { 106 | delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id_nrow + r]; 107 | } 108 | } 109 | 110 | constexpr int 
kChunkSize = kNThreads * kNItems; 111 | for (int chunk = 0; chunk < params.n_chunks; ++chunk) { 112 | input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; 113 | __syncthreads(); 114 | #pragma unroll 115 | for (int r = 0; r < kNRows; ++r) { 116 | if constexpr (!kDirectIO) { 117 | if (r > 0) { __syncthreads(); } 118 | } 119 | load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); 120 | if constexpr (!kDirectIO) { __syncthreads(); } 121 | load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); 122 | } 123 | u += kChunkSize; 124 | delta += kChunkSize; 125 | 126 | float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; 127 | #pragma unroll 128 | for (int r = 0; r < kNRows; ++r) { 129 | #pragma unroll 130 | for (int i = 0; i < kNItems; ++i) { 131 | float u_val = float(u_vals[r][i]); 132 | delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; 133 | if (params.delta_softplus) { 134 | delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; 135 | } 136 | delta_u_vals[r][i] = delta_vals[r][i] * u_val; 137 | out_vals[r][i] = D_val[r] * u_val; 138 | } 139 | } 140 | 141 | __syncthreads(); 142 | for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { 143 | weight_t A_val[kNRows]; 144 | #pragma unroll 145 | for (int r = 0; r < kNRows; ++r) { 146 | A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; 147 | // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. 148 | constexpr float kLog2e = M_LOG2E; 149 | A_val[r] *= kLog2e; 150 | } 151 | weight_t B_vals[kNItems], C_vals[kNItems]; 152 | load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, 153 | smem_load_weight, (params.seqlen - chunk * kChunkSize)); 154 | auto &smem_load_weight_C = smem_load_weight1; 155 | load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, 156 | smem_load_weight_C, (params.seqlen - chunk * kChunkSize)); 157 | #pragma unroll 158 | for (int r = 0; r < kNRows; ++r) { 159 | if (r > 0) { __syncthreads(); } // Scan could be using the same smem 160 | scan_t thread_data[kNItems]; 161 | #pragma unroll 162 | for (int i = 0; i < kNItems; ++i) { 163 | thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), 164 | B_vals[i] * delta_u_vals[r][i]); 165 | if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct 166 | if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { 167 | thread_data[i] = make_float2(1.f, 0.f); 168 | } 169 | } 170 | } 171 | // Initialize running total 172 | scan_t running_prefix; 173 | // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read 174 | running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * Ktraits::MaxDState] : make_float2(1.f, 0.f); 175 | // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); 176 | SSMScanPrefixCallbackOp prefix_op(running_prefix); 177 | Ktraits::BlockScanT(smem_scan).InclusiveScan( 178 | thread_data, thread_data, SSMScanOp(), prefix_op 179 | ); 180 | // There's a syncthreads in the scan op, so we don't need to sync here. 181 | // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
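// In this kNRows variant a single thread block owns kNRows consecutive channels
// (dim_id_nrow = blockIdx.y * kNRows, with grid.y = params.dim / kNRows), so the running-prefix
// scratch past kSmemSize holds kNRows rows of MaxDState = MAX_DSTATE / kNRows states each, and the
// launcher reserves kNRows * MaxDState * sizeof(scan_t) bytes for it. A minimal sketch of the two
// index computations used just below (hypothetical helper names, shown only to make the layout
// explicit):
//
//   __device__ __forceinline__ int prefix_index(int state_idx, int r, int max_dstate) {
//       return state_idx + r * max_dstate;                    // layout of smem_running_prefix
//   }
//   __device__ __forceinline__ int checkpoint_index(int r, int n_chunks, int chunk, int dstate, int state_idx) {
//       return (r * n_chunks + chunk) * dstate + state_idx;   // layout of the x checkpoint buffer
//   }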
182 | if (threadIdx.x == 0) { 183 | smem_running_prefix[state_idx + r * Ktraits::MaxDState] = prefix_op.running_prefix; 184 | x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; 185 | } 186 | #pragma unroll 187 | for (int i = 0; i < kNItems; ++i) { 188 | out_vals[r][i] += thread_data[i].y * C_vals[i]; 189 | } 190 | } 191 | } 192 | 193 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 194 | + dim_id_nrow * params.out_d_stride + chunk * kChunkSize; 195 | __syncthreads(); 196 | #pragma unroll 197 | for (int r = 0; r < kNRows; ++r) { 198 | if constexpr (!kDirectIO) { 199 | if (r > 0) { __syncthreads(); } 200 | } 201 | store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); 202 | } 203 | 204 | Bvar += kChunkSize; 205 | Cvar += kChunkSize; 206 | } 207 | } 208 | 209 | template 210 | void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { 211 | BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { 212 | using Ktraits = Selective_Scan_fwd_kernel_traits; 213 | constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * Ktraits::MaxDState * sizeof(typename Ktraits::scan_t); 214 | // printf("smem_size = %d\n", kSmemSize); 215 | dim3 grid(params.batch, params.dim / kNRows); 216 | auto kernel = &selective_scan_fwd_kernel; 217 | if (kSmemSize >= 48 * 1024) { 218 | C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 219 | } 220 | kernel<<>>(params); 221 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 222 | }); 223 | } 224 | 225 | template 226 | void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { 227 | if (params.seqlen <= 128) { 228 | selective_scan_fwd_launch<32, 4, knrows, input_t, weight_t>(params, stream); 229 | } else if (params.seqlen <= 256) { 230 | selective_scan_fwd_launch<32, 8, knrows, input_t, weight_t>(params, stream); 231 | } else if (params.seqlen <= 512) { 232 | selective_scan_fwd_launch<32, 16, knrows, input_t, weight_t>(params, stream); 233 | } else if (params.seqlen <= 1024) { 234 | selective_scan_fwd_launch<64, 16, knrows, input_t, weight_t>(params, stream); 235 | } else { 236 | selective_scan_fwd_launch<128, 16, knrows, input_t, weight_t>(params, stream); 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /classification/models/csms6s.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import torchvision.transforms as transforms 11 | import torchvision.models as models 12 | from PIL import Image 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from einops.layers.torch import Rearrange 16 | 17 | import torchvision.transforms.functional as TF 18 | import matplotlib.pyplot as plt 19 | 20 | class CrossScan(nn.Module): 21 | 22 | def __init__(self, csms6s_mode, top_k, knn, alpha, ambiguity, binary, division_rate, device, weights, topk): 23 | super().__init__() 24 | self.csms6s_mode = csms6s_mode 25 | self.top_k = top_k 26 | self.knn = knn 27 | self.alpha = alpha 28 | self.ambiguity = ambiguity 29 | self.device = device 30 | self.binary = binary 31 | self.division_rate = division_rate 32 | 33 | self.weights = weights 34 | self.topk = topk 35 | 36 | def adjacency(self, feature_vector): 37 | 38 | batch_size, num_nodes, 
feature_dim = feature_vector.shape 39 | distances = torch.cdist(feature_vector, feature_vector, p=2) 40 | 41 | if self.weights == 'old': 42 | distances = torch.exp(-1 * self.alpha * (distances)**2) 43 | elif self.weights == 'new': 44 | sigma = torch.mean(distances) 45 | distances = torch.exp(-distances ** 2 / (2 * sigma ** 2)) 46 | 47 | if self.topk == 'yes': 48 | value, indices = torch.topk(distances, self.knn, dim=2, largest=True) 49 | 50 | adjacency_matrix = torch.zeros(batch_size, num_nodes, num_nodes, device=self.device) 51 | b_idx = torch.arange(batch_size, device='cuda')[:, None, None] 52 | n_idx = torch.arange(num_nodes, device='cuda')[:, None] 53 | 54 | # Use gathered distances as weights 55 | adjacency_matrix[b_idx, n_idx, indices] = value 56 | adjacency_matrix[b_idx, indices, n_idx] = value # Ensure symmetry 57 | 58 | return adjacency_matrix 59 | elif self.topk == 'no': 60 | return distances 61 | 62 | def compute_symmetric_laplacian(self, adjacency): 63 | 64 | degree = torch.sum(adjacency, dim=2) 65 | 66 | eps = 1e-5 67 | D_inv_sqrt = torch.pow(degree, -0.5) 68 | D_inv_sqrt = torch.diag_embed(D_inv_sqrt) 69 | 70 | I = torch.eye(adjacency.size(1), device=adjacency.device).unsqueeze(0) 71 | I = I.repeat(adjacency.size(0), 1, 1) 72 | 73 | laplacian = I - torch.bmm(torch.bmm(D_inv_sqrt, adjacency), D_inv_sqrt) 74 | 75 | return laplacian 76 | 77 | def topk_eigenvectors(self, eigenvalues, eigenvectors): 78 | 79 | # sort eigenvalues and corresponding eigenvectors + topk eigenvectors 80 | sorted_eigenvalues, indices = torch.sort(eigenvalues, dim=1) 81 | sorted_eigenvectors = torch.gather(eigenvectors, 2, indices.unsqueeze(1).expand(-1, eigenvectors.shape[1], -1)) 82 | smallest_eigenvectors = sorted_eigenvectors[:, :, 1:self.top_k+1] 83 | 84 | with torch.no_grad(): 85 | mean_vals = torch.mean(smallest_eigenvectors, dim=2) 86 | signs = mean_vals.sign() 87 | signs = signs.unsqueeze(2) 88 | smallest_eigenvectors *= signs 89 | 90 | sorted_smallest_eigenvectors, new_indices = torch.sort(smallest_eigenvectors, dim=1) 91 | return sorted_smallest_eigenvectors, new_indices 92 | 93 | def forward(self, features): 94 | 95 | if (self.csms6s_mode == "NORMAL"): 96 | B, D, H, W = features.shape 97 | features = features.view(B, D, -1).transpose(1, 2) 98 | 99 | w_matrix = self.adjacency(features) 100 | L_sym = self.compute_symmetric_laplacian(w_matrix) 101 | 102 | eigenvalues, eigenvectors = torch.linalg.eigh(L_sym) 103 | sorted_smallest_eigenvectors, topk_eigenvector_indexes = self.topk_eigenvectors(eigenvalues, eigenvectors) 104 | 105 | return sorted_smallest_eigenvectors, topk_eigenvector_indexes 106 | 107 | class CrossMerge(nn.Module): 108 | 109 | def __init__(self): 110 | super().__init__() 111 | 112 | def forward(self, ys, vec_indices): 113 | B, K, D, H, W = ys.shape 114 | 115 | argsorted_vec_indices = torch.argsort(vec_indices, 1).permute(0, 2, 1) 116 | argsorted_vec_indices = argsorted_vec_indices.unsqueeze(2).expand(-1, -1, D, -1) 117 | ys_partial = ys[:, :int(K/2), :, :].reshape(B, int(K/2), D, -1) 118 | 119 | result = torch.gather(ys_partial, dim=-1, index=argsorted_vec_indices) 120 | 121 | result_flip = ys[:, int(K/2):, :, :].reshape(B, int(K/2), D, -1).flip(-1) 122 | result_flip = torch.gather(result_flip, dim=-1, index=argsorted_vec_indices) 123 | 124 | ys = result + result_flip 125 | 126 | ys_ = 0 127 | for i in range(ys.shape[1]): 128 | ys_ += ys[:, i] 129 | 130 | return ys_ 131 | 132 | # these are for ablations ============= 133 | class CrossScan_Ab_2direction(torch.autograd.Function): 
134 |     @staticmethod
135 |     def forward(ctx, x: torch.Tensor):
136 |         B, C, H, W = x.shape
137 |         ctx.shape = (B, C, H, W)
138 |         x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
139 |         x = torch.cat([x, x.flip(dims=[-1])], dim=1)
140 |         return x
141 | 
142 |     @staticmethod
143 |     def backward(ctx, ys: torch.Tensor):
144 |         B, C, H, W = ctx.shape
145 |         L = H * W
146 |         ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
147 |         return ys.sum(1).view(B, -1, H, W)
148 | 
149 | class CrossMerge_Ab_2direction(torch.autograd.Function):
150 |     @staticmethod
151 |     def forward(ctx, ys: torch.Tensor):
152 |         B, K, D, H, W = ys.shape
153 |         ctx.shape = (H, W)
154 |         ys = ys.view(B, K, D, -1)
155 |         ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
156 |         return ys.contiguous().sum(1)
157 | 
158 |     @staticmethod
159 |     def backward(ctx, x: torch.Tensor):
160 |         H, W = ctx.shape
161 |         B, C, L = x.shape
162 |         x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
163 |         x = torch.cat([x, x.flip(dims=[-1])], dim=1)
164 |         return x.view(B, 4, C, H, W)
165 | 
166 | class CrossScan_Ab_1direction(torch.autograd.Function):
167 |     @staticmethod
168 |     def forward(ctx, x: torch.Tensor):
169 |         B, C, H, W = x.shape
170 |         ctx.shape = (B, C, H, W)
171 |         x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
172 |         return x
173 | 
174 | 
175 |     @staticmethod
176 |     def backward(ctx, ys: torch.Tensor):
177 |         B, C, H, W = ctx.shape
178 |         return ys.view(B, 4, -1, H, W).sum(1)
179 | 
180 | class CrossMerge_Ab_1direction(torch.autograd.Function):
181 |     @staticmethod
182 |     def forward(ctx, ys: torch.Tensor):
183 |         B, K, C, H, W = ys.shape
184 |         ctx.shape = (B, C, H, W)
185 |         return ys.view(B, 4, -1, H * W).sum(1)
186 | 
187 |     @staticmethod
188 |     def backward(ctx, x: torch.Tensor):
189 |         B, C, H, W = ctx.shape
190 |         return x.view(B, 1, C, H, W).repeat(1, 4, 1, 1, 1)
191 | 
192 | # import selective scan ==============================
193 | try:
194 |     import selective_scan_cuda_oflex
195 | except Exception as e:
196 |     ...
197 |     # print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True)
198 |     # print(e, flush=True)
199 | 
200 | try:
201 |     import selective_scan_cuda_core
202 | except Exception as e:
203 |     ...
204 |     # print(f"WARNING: can not import selective_scan_cuda_core.", flush=True)
205 |     # print(e, flush=True)
206 | 
207 | try:
208 |     import selective_scan_cuda
209 | except Exception as e:
210 |     ...
211 |     # print(f"WARNING: can not import selective_scan_cuda.", flush=True)
212 |     # print(e, flush=True)
213 | 
214 | 
215 | def check_nan_inf(tag: str, x: torch.Tensor, enable=True):
216 |     if enable:
217 |         if torch.isinf(x).any() or torch.isnan(x).any():
218 |             print(tag, torch.isinf(x).any(), torch.isnan(x).any(), flush=True)
219 |             import pdb; pdb.set_trace()
220 | 
221 | 
222 | # fvcore flops =======================================
223 | def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
224 |     """
225 |     u: r(B D L)
226 |     delta: r(B D L)
227 |     A: r(D N)
228 |     B: r(B N L)
229 |     C: r(B N L)
230 |     D: r(D)
231 |     z: r(B D L)
232 |     delta_bias: r(D), fp32
233 | 
234 |     ignores:
235 |         [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
236 |     """
237 |     assert not with_complex
238 |     # https://github.com/state-spaces/mamba/issues/110
239 |     flops = 9 * B * L * D * N
240 |     if with_D:
241 |         flops += B * D * L
242 |     if with_Z:
243 |         flops += B * D * L
244 |     return flops
245 | 
246 | # this is only for selective_scan_ref...
247 | def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
248 |     """
249 |     u: r(B D L)
250 |     delta: r(B D L)
251 |     A: r(D N)
252 |     B: r(B N L)
253 |     C: r(B N L)
254 |     D: r(D)
255 |     z: r(B D L)
256 |     delta_bias: r(D), fp32
257 | 
258 |     ignores:
259 |         [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
260 |     """
261 |     import numpy as np
262 | 
263 |     # fvcore.nn.jit_handles
264 |     def get_flops_einsum(input_shapes, equation):
265 |         np_arrs = [np.zeros(s) for s in input_shapes]
266 |         optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
267 |         for line in optim.split("\n"):
268 |             if "optimized flop" in line.lower():
269 |                 # divided by 2 because we count MAC (multiply-add counted as one flop)
270 |                 flop = float(np.floor(float(line.split(":")[-1]) / 2))
271 |                 return flop
272 | 
273 | 
274 |     assert not with_complex
275 | 
276 |     flops = 0 # below code flops = 0
277 | 
278 |     flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
279 |     if with_Group:
280 |         flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
281 |     else:
282 |         flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
283 | 
284 |     in_for_flops = B * D * N
285 |     if with_Group:
286 |         in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
287 |     else:
288 |         in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
289 |     flops += L * in_for_flops
290 |     if with_D:
291 |         flops += B * D * L
292 |     if with_Z:
293 |         flops += B * D * L
294 |     return flops
295 | 
296 | 
297 | def print_jit_input_names(inputs):
298 |     print("input params: ", end=" ", flush=True)
299 |     try:
300 |         for i in range(10):
301 |             print(inputs[i].debugName(), end=" ", flush=True)
302 |     except Exception as e:
303 |         pass
304 |     print("", flush=True)
305 | 
306 | # cross selective scan ===============================
307 | # comment all checks if inside cross_selective_scan
308 | class SelectiveScanMamba(torch.autograd.Function):
309 |     @staticmethod
310 |     @torch.cuda.amp.custom_fwd
311 |     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
312 |         ctx.delta_softplus = delta_softplus
313 |         out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
314 |         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
315 |         return out
316 | 
317 |     @staticmethod
318 |     @torch.cuda.amp.custom_bwd
319 |     def backward(ctx, dout, *args):
320 |         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
321 |         if dout.stride(-1) != 1:
322 |             dout = dout.contiguous()
323 | 
324 |         du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
325 |             u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
326 |             False
327 |         )
328 |         return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
329 | 
330 | 
331 | class SelectiveScanCore(torch.autograd.Function):
332 |     @staticmethod
333 |     @torch.cuda.amp.custom_fwd
334 |     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
335 |         ctx.delta_softplus = delta_softplus
336 |         out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
337 |         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
338 |         return out
339 | 
340 |     @staticmethod
341 |     @torch.cuda.amp.custom_bwd
342 |     def backward(ctx, dout, *args):
343 |         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
344 |         if dout.stride(-1) != 1:
345 |             dout = dout.contiguous()
346 |         du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
347 |             u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
348 |         )
349 |         return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
350 | 
351 | 
352 | class SelectiveScanOflex(torch.autograd.Function):
353 |     @staticmethod
354 |     @torch.cuda.amp.custom_fwd
355 |     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
356 |         ctx.delta_softplus = delta_softplus
357 |         out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
358 |         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
359 |         return out
360 | 
361 |     @staticmethod
362 |     @torch.cuda.amp.custom_bwd
363 |     def backward(ctx, dout, *args):
364 |         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
365 |         if dout.stride(-1) != 1:
366 |             dout = dout.contiguous()
367 |         du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
368 |             u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
369 |         )
370 |         return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
371 | 
372 | 
373 | def selective_scan_flop_jit(inputs, outputs, flops_fn=flops_selective_scan_fn, verbose=True):
374 |     if verbose:
375 |         print_jit_input_names(inputs)
376 |     B, D, L = inputs[0].type().sizes()
377 |     N = inputs[2].type().sizes()[1]
378 |     flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
379 |     return flops
380 | 
381 | 
382 | 
383 | 
384 | 
--------------------------------------------------------------------------------
/kernels/selective_scan/csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 |  * Copyright (c) 2023, Tri Dao.
3 |  ******************************************************************************/
4 | 
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <c10/cuda/CUDAGuard.h>
7 | #include <torch/extension.h>
8 | #include <vector>
9 | 
10 | #include "selective_scan_ndstate.h"
11 | #define MAX_DSTATE 256
12 | 
13 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
14 | using weight_t = float;
15 | 
16 | #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
17 |     if (ITYPE == at::ScalarType::Half) {                                            \
18 |         using input_t = at::Half;                                                    \
19 |         __VA_ARGS__();                                                               \
20 |     } else if (ITYPE == at::ScalarType::BFloat16) {                                  \
21 |         using input_t = at::BFloat16;                                                \
22 |         __VA_ARGS__();                                                               \
23 |     } else if (ITYPE == at::ScalarType::Float) {                                     \
24 |         using input_t = float;                                                       \
25 |         __VA_ARGS__();                                                               \
26 |     } else {                                                                         \
27 |         AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'");  \
28 |     }
29 | 
30 | template<int knrows, typename input_t, typename weight_t>
31 | void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
32 | 
33 | template<int knrows, typename input_t, typename weight_t>
34 | void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
35 | 
36 | void set_ssm_params_fwd(SSMParamsBase &params,
37 |                         // sizes
38 |                         const size_t batch,
39 |                         const size_t dim,
40 |                         const size_t seqlen,
41 |                         const size_t n_groups,
42 |                         const size_t n_chunks,
43 |                         // device pointers
44 |                         const at::Tensor u,
45 |                         const at::Tensor delta,
46 |                         const at::Tensor A,
47 |                         const at::Tensor B,
48 |                         const at::Tensor C,
49 |                         const at::Tensor out,
50 |                         void* D_ptr,
51 |                         void* delta_bias_ptr,
52 |                         void* x_ptr,
53 |                         bool delta_softplus) {
54 | 
55 |     // Reset the parameters
56 |     memset(&params, 0, sizeof(params));
57 | 
58 |     params.batch = batch;
59 |     params.dim = dim;
60 |     params.seqlen = seqlen;
61 |     params.n_groups = n_groups;
62 |     params.n_chunks = n_chunks;
63 |     params.dim_ngroups_ratio = dim / n_groups;
64 | 
65 |     params.delta_softplus = delta_softplus;
66 | 
67 |     // Set the pointers and strides.
68 |     params.u_ptr = u.data_ptr();
69 |     params.delta_ptr = delta.data_ptr();
70 |     params.A_ptr = A.data_ptr();
71 |     params.B_ptr = B.data_ptr();
72 |     params.C_ptr = C.data_ptr();
73 |     params.D_ptr = D_ptr;
74 |     params.delta_bias_ptr = delta_bias_ptr;
75 |     params.out_ptr = out.data_ptr();
76 |     params.x_ptr = x_ptr;
77 | 
78 |     // All stride are in elements, not bytes.
79 |     params.A_d_stride = A.stride(0);
80 |     params.B_batch_stride = B.stride(0);
81 |     params.B_group_stride = B.stride(1);
82 |     params.C_batch_stride = C.stride(0);
83 |     params.C_group_stride = C.stride(1);
84 |     params.u_batch_stride = u.stride(0);
85 |     params.u_d_stride = u.stride(1);
86 |     params.delta_batch_stride = delta.stride(0);
87 |     params.delta_d_stride = delta.stride(1);
88 | 
89 |     params.out_batch_stride = out.stride(0);
90 |     params.out_d_stride = out.stride(1);
91 | }
92 | 
93 | void set_ssm_params_bwd(SSMParamsBwd &params,
94 |                         // sizes
95 |                         const size_t batch,
96 |                         const size_t dim,
97 |                         const size_t seqlen,
98 |                         const size_t n_groups,
99 |                         const size_t n_chunks,
100 |                         // device pointers
101 |                         const at::Tensor u,
102 |                         const at::Tensor delta,
103 |                         const at::Tensor A,
104 |                         const at::Tensor B,
105 |                         const at::Tensor C,
106 |                         const at::Tensor out,
107 |                         void* D_ptr,
108 |                         void* delta_bias_ptr,
109 |                         void* x_ptr,
110 |                         const at::Tensor dout,
111 |                         const at::Tensor du,
112 |                         const at::Tensor ddelta,
113 |                         const at::Tensor dA,
114 |                         const at::Tensor dB,
115 |                         const at::Tensor dC,
116 |                         void* dD_ptr,
117 |                         void* ddelta_bias_ptr,
118 |                         bool delta_softplus) {
119 |     // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
120 |     set_ssm_params_fwd(params, batch, dim, seqlen, n_groups, n_chunks,
121 |                        u, delta, A, B, C, dout,
122 |                        D_ptr, delta_bias_ptr, x_ptr, delta_softplus);
123 | 
124 |     // Set the pointers and strides.
125 |     params.dout_ptr = dout.data_ptr();
126 |     params.du_ptr = du.data_ptr();
127 |     params.dA_ptr = dA.data_ptr();
128 |     params.dB_ptr = dB.data_ptr();
129 |     params.dC_ptr = dC.data_ptr();
130 |     params.dD_ptr = dD_ptr;
131 |     params.ddelta_ptr = ddelta.data_ptr();
132 |     params.ddelta_bias_ptr = ddelta_bias_ptr;
133 |     // All stride are in elements, not bytes.
134 |     params.dout_batch_stride = dout.stride(0);
135 |     params.dout_d_stride = dout.stride(1);
136 |     params.dA_d_stride = dA.stride(0);
137 |     params.dB_batch_stride = dB.stride(0);
138 |     params.dB_group_stride = dB.stride(1);
139 |     params.dC_batch_stride = dC.stride(0);
140 |     params.dC_group_stride = dC.stride(1);
141 |     params.du_batch_stride = du.stride(0);
142 |     params.du_d_stride = du.stride(1);
143 |     params.ddelta_batch_stride = ddelta.stride(0);
144 |     params.ddelta_d_stride = ddelta.stride(1);
145 | 
146 | }
147 | 
148 | std::vector<at::Tensor>
149 | selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
150 |                    const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
151 |                    const c10::optional<at::Tensor> &D_,
152 |                    const c10::optional<at::Tensor> &delta_bias_,
153 |                    bool delta_softplus,
154 |                    int nrows
155 |                    ) {
156 |     auto input_type = u.scalar_type();
157 |     auto weight_type = A.scalar_type();
158 |     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
159 |     TORCH_CHECK(weight_type == at::ScalarType::Float);
160 | 
161 |     TORCH_CHECK(delta.scalar_type() == input_type);
162 |     TORCH_CHECK(B.scalar_type() == input_type);
163 |     TORCH_CHECK(C.scalar_type() == input_type);
164 | 
165 |     TORCH_CHECK(u.is_cuda());
166 |     TORCH_CHECK(delta.is_cuda());
167 |     TORCH_CHECK(A.is_cuda());
168 |     TORCH_CHECK(B.is_cuda());
169 |     TORCH_CHECK(C.is_cuda());
170 | 
171 |     TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
172 |     TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
173 | 
174 |     const auto sizes = u.sizes();
175 |     const int batch_size = sizes[0];
176 |     const int dim = sizes[1];
177 |     const int seqlen = sizes[2];
178 |     const int n_groups = B.size(1);
179 | 
180 |     TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");
181 | 
182 |     CHECK_SHAPE(u, batch_size, dim, seqlen);
183 |     CHECK_SHAPE(delta, batch_size, dim, seqlen);
184 |     CHECK_SHAPE(A, dim);
185 |     CHECK_SHAPE(B, batch_size, n_groups, seqlen);
186 |     TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
187 |     CHECK_SHAPE(C, batch_size, n_groups, seqlen);
188 |     TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
189 | 
190 |     if (D_.has_value()) {
191 |         auto D = D_.value();
192 |         TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
193 |         TORCH_CHECK(D.is_cuda());
194 |         TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
195 |         CHECK_SHAPE(D, dim);
196 |     }
197 | 
198 |     if (delta_bias_.has_value()) {
199 |         auto delta_bias = delta_bias_.value();
200 |         TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
201 |         TORCH_CHECK(delta_bias.is_cuda());
202 |         TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
203 |         CHECK_SHAPE(delta_bias, dim);
204 |     }
205 | 
206 |     const int n_chunks = (seqlen + 2048 - 1) / 2048; // max is 128 * 16 = 2048 in fwd_kernel
207 |     at::Tensor out = torch::empty_like(delta);
208 |     at::Tensor x;
209 |     x = torch::empty({batch_size, dim, n_chunks, 1 * 2}, u.options().dtype(weight_type));
210 | 
211 |     SSMParamsBase params;
212 |     set_ssm_params_fwd(params, batch_size, dim, seqlen, n_groups, n_chunks,
213 |                        u, delta, A, B, C, out,
214 |                        D_.has_value() ? D_.value().data_ptr() : nullptr,
215 |                        delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
216 |                        x.data_ptr(),
217 |                        delta_softplus);
218 | 
219 |     // Otherwise the kernel will be launched from cuda:0 device
220 |     // Cast to char to avoid compiler warning about narrowing
221 |     at::cuda::CUDAGuard device_guard{(char)u.get_device()};
222 |     auto stream = at::cuda::getCurrentCUDAStream().stream();
223 |     DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
224 |         selective_scan_fwd_cuda<1, input_t, weight_t>(params, stream);
225 |     });
226 |     std::vector<at::Tensor> result = {out, x};
227 |     return result;
228 | }
229 | 
230 | std::vector<at::Tensor>
231 | selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
232 |                    const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
233 |                    const c10::optional<at::Tensor> &D_,
234 |                    const c10::optional<at::Tensor> &delta_bias_,
235 |                    const at::Tensor &dout,
236 |                    const c10::optional<at::Tensor> &x_,
237 |                    bool delta_softplus,
238 |                    int nrows
239 |                    ) {
240 |     auto input_type = u.scalar_type();
241 |     auto weight_type = A.scalar_type();
242 |     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
243 |     TORCH_CHECK(weight_type == at::ScalarType::Float);
244 | 
245 |     TORCH_CHECK(delta.scalar_type() == input_type);
246 |     TORCH_CHECK(B.scalar_type() == input_type);
247 |     TORCH_CHECK(C.scalar_type() == input_type);
248 |     TORCH_CHECK(dout.scalar_type() == input_type);
249 | 
250 |     TORCH_CHECK(u.is_cuda());
251 |     TORCH_CHECK(delta.is_cuda());
252 |     TORCH_CHECK(A.is_cuda());
253 |     TORCH_CHECK(B.is_cuda());
254 |     TORCH_CHECK(C.is_cuda());
255 |     TORCH_CHECK(dout.is_cuda());
256 | 
257 |     TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
258 |     TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
259 |     TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
260 | 
261 |     const auto sizes = u.sizes();
262 |     const int batch_size = sizes[0];
263 |     const int dim = sizes[1];
264 |     const int seqlen = sizes[2];
265 |     const int n_groups = B.size(1);
266 | 
267 |     TORCH_CHECK(dim % n_groups == 0, "dims should be dividable by n_groups");
268 | 
269 |     CHECK_SHAPE(u, batch_size, dim, seqlen);
270 |     CHECK_SHAPE(delta, batch_size, dim, seqlen);
271 |     CHECK_SHAPE(A, dim);
272 |     CHECK_SHAPE(B, batch_size, n_groups, seqlen);
273 |     TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
274 |     CHECK_SHAPE(C, batch_size, n_groups, seqlen);
275 |     TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
276 |     CHECK_SHAPE(dout, batch_size, dim, seqlen);
277 | 
278 |     if (D_.has_value()) {
279 |         auto D = D_.value();
280 |         TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
281 |         TORCH_CHECK(D.is_cuda());
282 |         TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
283 |         CHECK_SHAPE(D, dim);
284 |     }
285 | 
286 |     if (delta_bias_.has_value()) {
287 |         auto delta_bias = delta_bias_.value();
288 |         TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
289 |         TORCH_CHECK(delta_bias.is_cuda());
290 |         TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
291 |         CHECK_SHAPE(delta_bias, dim);
292 |     }
293 | 
294 |     at::Tensor out;
295 |     const int n_chunks = (seqlen + 2048 - 1) / 2048;
296 |     // const int n_chunks = (seqlen + 1024 - 1) / 1024;
297 |     if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
298 |     if (x_.has_value()) {
299 |         auto x = x_.value();
300 |         TORCH_CHECK(x.scalar_type() == weight_type);
301 |         TORCH_CHECK(x.is_cuda());
302 |         TORCH_CHECK(x.is_contiguous());
303 |         CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * 1);
304 |     }
305 | 
306 |     at::Tensor du = torch::empty_like(u);
307 |     at::Tensor ddelta = torch::empty_like(delta);
308 |     at::Tensor dA = torch::zeros_like(A);
309 |     at::Tensor dB = torch::zeros_like(B, B.options().dtype(torch::kFloat32));
310 |     at::Tensor dC = torch::zeros_like(C, C.options().dtype(torch::kFloat32));
311 |     at::Tensor dD;
312 |     if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
313 |     at::Tensor ddelta_bias;
314 |     if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
315 | 
316 |     SSMParamsBwd params;
317 |     set_ssm_params_bwd(params, batch_size, dim, seqlen, n_groups, n_chunks,
318 |                        u, delta, A, B, C, out,
319 |                        D_.has_value() ? D_.value().data_ptr() : nullptr,
320 |                        delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
321 |                        x_.has_value() ? x_.value().data_ptr() : nullptr,
322 |                        dout, du, ddelta, dA, dB, dC,
323 |                        D_.has_value() ? dD.data_ptr() : nullptr,
324 |                        delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
325 |                        delta_softplus);
326 | 
327 |     // Otherwise the kernel will be launched from cuda:0 device
328 |     // Cast to char to avoid compiler warning about narrowing
329 |     at::cuda::CUDAGuard device_guard{(char)u.get_device()};
330 |     auto stream = at::cuda::getCurrentCUDAStream().stream();
331 |     DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
332 |         selective_scan_bwd_cuda<1, input_t, weight_t>(params, stream);
333 |     });
334 |     std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
335 |     return result;
336 | }
337 | 
338 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
339 |     m.def("fwd", &selective_scan_fwd, "Selective scan forward");
340 |     m.def("bwd", &selective_scan_bwd, "Selective scan backward");
341 | }
342 | 
--------------------------------------------------------------------------------
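
The PYBIND11_MODULE block above exposes `fwd` and `bwd` from the compiled extension. As a minimal sketch of how such bindings are typically consumed on the Python side, mirroring the SelectiveScanCore/SelectiveScanOflex wrappers in classification/models/csms6s.py, the snippet below wraps them in a torch.autograd.Function. This is not part of the repository: the module name `selective_scan_cuda_ndstate` is an assumption (the actual import name is whatever kernels/selective_scan/setup.py builds for the cusndstate sources).

import torch

try:
    import selective_scan_cuda_ndstate  # assumed extension name; check kernels/selective_scan/setup.py
except ImportError:
    selective_scan_cuda_ndstate = None

class SelectiveScanNdState(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1):
        ctx.delta_softplus = delta_softplus
        # fwd returns {out, x}; x holds the per-chunk running prefixes needed by bwd
        out, x = selective_scan_cuda_ndstate.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout):
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # bwd returns {du, ddelta, dA, dB, dC, dD, ddelta_bias}, matching selective_scan_bwd above
        du, ddelta, dA, dB, dC, dD, ddelta_bias = selective_scan_cuda_ndstate.bwd(
            u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
        )
        # one gradient per forward input; delta_softplus and nrows get None
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None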
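
Similarly, here is a shape-level usage sketch (also not from the repository) for the CrossScan / CrossMerge pair defined in classification/models/csms6s.py above. It assumes the repository root is on PYTHONPATH so the import resolves, that a CUDA device is available, and uses purely illustrative hyperparameters; with top_k spectral orderings, CrossMerge expects ys with K = 2 * top_k scan directions.

import torch
from classification.models.csms6s import CrossScan, CrossMerge  # adjust to the actual package layout

B, D, H, W, top_k = 2, 96, 14, 14, 4
scan = CrossScan(csms6s_mode="NORMAL", top_k=top_k, knn=10, alpha=1.0, ambiguity=None,
                 binary=None, division_rate=None, device="cuda", weights="new", topk="yes")

feats = torch.randn(B, D, H, W, device="cuda")
eigvecs, order = scan(feats)            # (B, H*W, top_k): sorted Laplacian eigenvectors and their orderings

# The selective scan itself runs over the spectrally reordered sequences and yields
# ys of shape (B, 2 * top_k, D, H, W); a random stand-in is used here.
ys = torch.randn(B, 2 * top_k, D, H, W, device="cuda")
merged = CrossMerge()(ys, order)        # (B, D, H*W): orderings inverted, forward/flipped directions summed
print(eigvecs.shape, order.shape, merged.shape)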