├── lib ├── __init__.py ├── config │ ├── __init__.py │ ├── mambatrack │ │ └── config.py │ └── mambatrack_motion │ │ └── config.py ├── test │ ├── __init__.py │ ├── tracker │ │ ├── __init__.py │ │ ├── vis_utils.py │ │ ├── data_utils.py │ │ └── basetracker.py │ ├── analysis │ │ └── __init__.py │ ├── parameter │ │ ├── __init__.py │ │ ├── mambatrack.py │ │ └── mambatrack_motion.py │ ├── utils │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── params.py │ │ ├── transform_trackingnet.py │ │ ├── load_text.py │ │ ├── transform_got10k.py │ │ └── hann.py │ └── evaluation │ │ ├── __init__.py │ │ ├── tc128dataset.py │ │ ├── tnl2kdataset.py │ │ ├── tc128cedataset.py │ │ ├── got10kdataset.py │ │ ├── datasets.py │ │ ├── trackingnetdataset.py │ │ ├── itbdataset.py │ │ └── environment.py ├── vis │ ├── __init__.py │ └── utils.py ├── models │ ├── layers │ │ ├── __init__.py │ │ ├── patch_embed.py │ │ ├── frozen_bn.py │ │ ├── rpe.py │ │ └── attn.py │ ├── mambatrack │ │ ├── __init__.py │ │ └── utils.py │ └── __init__.py ├── train │ ├── __init__.py │ ├── trainers │ │ ├── __init__.py │ │ └── misc.py │ ├── data_specs │ │ ├── depthtrack_val.txt │ │ ├── lasher_motion_val.txt │ │ ├── lasher_val.txt │ │ ├── depthtrack_train.txt │ │ └── depthtrack_all.txt │ ├── data │ │ ├── __init__.py │ │ ├── wandb_logger.py │ │ ├── bounding_box_utils.py │ │ └── image_loader.py │ ├── actors │ │ ├── __init__.py │ │ └── base_actor.py │ ├── admin │ │ ├── __init__.py │ │ ├── settings.py │ │ ├── multigpu.py │ │ ├── tensorboard.py │ │ ├── stats.py │ │ └── environment.py │ ├── _init_paths.py │ ├── dataset │ │ ├── __init__.py │ │ ├── base_image_dataset.py │ │ ├── base_video_dataset.py │ │ ├── coesot.py │ │ ├── imagenetvid_lmdb.py │ │ └── lasher.py │ ├── train_script.py │ ├── train_script_distill.py │ └── run_training.py └── utils │ ├── __init__.py │ ├── merge.py │ ├── variable_hook.py │ ├── lmdb_utils.py │ ├── focal_loss.py │ ├── ce_utils.py │ └── box_ops.py ├── checkpoints └── .gitkeep ├── pretrained_models └── .gitkeep ├── mamba-1p1p1 ├── mamba_ssm │ ├── ops │ │ ├── __init__.py │ │ └── triton │ │ │ └── __init__.py │ ├── models │ │ ├── __init__.py │ │ └── config_mamba.py │ ├── modules │ │ └── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── hf.py │ └── __init__.py ├── .gitignore ├── AUTHORS ├── assets │ └── selection.png ├── .gitmodules ├── csrc │ └── selective_scan │ │ ├── selective_scan_bwd_fp16_real.cu │ │ ├── selective_scan_bwd_fp32_real.cu │ │ ├── selective_scan_bwd_bf16_real.cu │ │ ├── selective_scan_bwd_fp32_complex.cu │ │ ├── selective_scan_bwd_bf16_complex.cu │ │ ├── selective_scan_bwd_fp16_complex.cu │ │ ├── selective_scan_fwd_fp32.cu │ │ ├── selective_scan_fwd_fp16.cu │ │ ├── selective_scan_fwd_bf16.cu │ │ ├── static_switch.h │ │ ├── uninitialized_copy.cuh │ │ └── selective_scan.h ├── evals │ └── lm_harness_eval.py ├── tests │ └── ops │ │ └── triton │ │ └── test_selective_state_update.py └── benchmarks │ └── benchmark_generation_mamba_simple.py ├── assets └── framework.png ├── tracking ├── _init_paths.py ├── create_default_local_file.py ├── train.py └── profile_model.py ├── xtrain.sh ├── xtrain_motion.sh ├── ytest.sh ├── ytest_motion.sh ├── experiments ├── mambatrack │ ├── mambavt_m256_ep20.yaml │ ├── mambavt_s256_ep20.yaml │ ├── mambavt_m256_ep20_lasher.yaml │ └── mambavt_s256_ep20_lasher.yaml └── mambatrack_motion │ ├── mambavt_motion_m256_ep10_lasher.yaml │ ├── mambavt_motion_s256_ep10_lasher.yaml │ ├── mambavt_motion_m256_ep10.yaml │ └── mambavt_motion_s256_ep10.yaml ├── install.sh └── README.md /lib/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/vis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/test/tracker/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrained_models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/test/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/test/parameter/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .admin.multigpu import MultiGPU 2 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensor import TensorDict, TensorList 2 | -------------------------------------------------------------------------------- /mamba-1p1p1/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | 
*.egg-info/ 3 | build/ 4 | **.so 5 | -------------------------------------------------------------------------------- /lib/test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .params import TrackerParams, FeatureParams, Choice -------------------------------------------------------------------------------- /mamba-1p1p1/AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | Albert Gu, agu@andrew.cmu.edu 3 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laisimiao/MambaVT/HEAD/assets/framework.png -------------------------------------------------------------------------------- /lib/train/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .ltr_trainer import LTRTrainer 3 | -------------------------------------------------------------------------------- /mamba-1p1p1/assets/selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laisimiao/MambaVT/HEAD/mamba-1p1p1/assets/selection.png -------------------------------------------------------------------------------- /lib/models/mambatrack/__init__.py: -------------------------------------------------------------------------------- 1 | from .mambatrack import build_mambatrack 2 | from .mambatrack_motion import build_mambatrack_motion 3 | -------------------------------------------------------------------------------- /lib/train/data_specs/depthtrack_val.txt: -------------------------------------------------------------------------------- 1 | toy03_indoor 2 | pigeon05_wild 3 | bottle03_indoor 4 | ball16_indoor 5 | bag04_indoor 6 | flower03_indoor -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mambatrack.mambatrack import build_mambatrack 2 | from .mambatrack.mambatrack_motion import build_mambatrack_motion 3 | -------------------------------------------------------------------------------- /lib/train/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import LTRLoader 2 | from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader 3 | -------------------------------------------------------------------------------- /lib/vis/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def numpy_to_torch(a: np.ndarray): 6 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) -------------------------------------------------------------------------------- /lib/train/actors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_actor import BaseActor 2 | from .mambatrack_actor import MambaTrackActor 3 | from .mambatrack_motion_actor import MambaTrackMotionActor 4 | -------------------------------------------------------------------------------- /mamba-1p1p1/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule 
"3rdparty/lm-evaluation-harness"] 2 | path = 3rdparty/lm-evaluation-harness 3 | url = https://github.com/EleutherAI/lm-evaluation-harness/ 4 | -------------------------------------------------------------------------------- /lib/train/admin/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_settings, create_default_local_file_ITP_train 2 | from .stats import AverageMeter, StatValue 3 | from .tensorboard import TensorboardWriter 4 | -------------------------------------------------------------------------------- /lib/test/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import Sequence 2 | from .tracker import Tracker, trackerlist 3 | from .datasets import get_dataset 4 | from .environment import create_default_local_file_ITP_test -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.1" 2 | 3 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 4 | from mamba_ssm.modules.mamba_simple import Mamba 5 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /lib/train/admin/settings.py: -------------------------------------------------------------------------------- 1 | from lib.train.admin.environment import env_settings 2 | 3 | 4 | class Settings: 5 | """ Training settings, e.g. the paths to datasets and networks.""" 6 | def __init__(self): 7 | self.set_default() 8 | 9 | def set_default(self): 10 | self.env = env_settings() 11 | self.use_gpu = True 12 | 13 | 14 | -------------------------------------------------------------------------------- /tracking/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /lib/train/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '../..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/models/config_mamba.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass 5 | class MambaConfig: 6 | 7 | d_model: int = 2560 8 | n_layer: int = 64 9 | vocab_size: int = 50277 10 | ssm_cfg: dict = field(default_factory=dict) 11 | rms_norm: bool = True 12 | residual_in_fp32: bool = True 13 | fused_add_norm: bool = True 14 | pad_vocab_size_multiple: int = 8 15 | 
-------------------------------------------------------------------------------- /lib/test/utils/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '..', '..', '..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_real.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp32_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_bf16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_bwd_fp16_complex.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_bwd_kernel.cuh" 8 | 9 | template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /lib/train/admin/multigpu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training 3 | 4 | 5 | def is_multi_gpu(net): 6 | return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel)) 7 | 8 | 9 | class MultiGPU(nn.parallel.distributed.DistributedDataParallel): 10 | def __getattr__(self, item): 11 | try: 12 | return super().__getattr__(item) 13 | except: 14 | pass 15 | return getattr(self.module, item) 16 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp32.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_fp16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan_fwd_bf16.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | // Split into multiple files to compile in paralell 6 | 7 | #include "selective_scan_fwd_kernel.cuh" 8 | 9 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); 10 | template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /lib/train/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # from .lasot import Lasot 2 | # from .got10k import Got10k 3 | # from .tracking_net import TrackingNet 4 | # from .imagenetvid import ImagenetVID 5 | # from .coco import MSCOCO 6 | # from .coco_seq import MSCOCOSeq 7 | # from .got10k_lmdb import Got10k_lmdb 8 | # from .lasot_lmdb import Lasot_lmdb 9 | # from .imagenetvid_lmdb import ImagenetVID_lmdb 10 | # from .coco_seq_lmdb import MSCOCOSeq_lmdb 11 | # from .tracking_net_lmdb import TrackingNet_lmdb 12 | # RGBT dataloader 13 | from .lasher import LasHeR 14 | from .visevent import VisEvent 15 | from .depthtrack import DepthTrack 16 | from .coesot import COESOT 17 | from .lasher_motion import LasHeRMotion -------------------------------------------------------------------------------- /xtrain.sh: -------------------------------------------------------------------------------- 1 | 2 | # MambaVT-Small 3 | ## train for non-lasher evaluation 4 | python tracking/train.py --script mambatrack --config mambavt_s256_ep20 --save_dir ./output --mode multiple --nproc_per_node 2 5 | ## train for lasher evaluation 6 | python tracking/train.py --script mambatrack --config mambavt_s256_ep20_lasher --save_dir ./output --mode multiple --nproc_per_node 2 7 | 8 | 9 | # MambaVT-Middle 10 | ## train for non-lasher evaluation 11 | python tracking/train.py --script mambatrack --config mambavt_m256_ep20 --save_dir ./output --mode multiple --nproc_per_node 2 12 | ## train for lasher evaluation 13 | python tracking/train.py --script mambatrack --config mambavt_m256_ep20_lasher --save_dir ./output --mode multiple --nproc_per_node 2 14 | -------------------------------------------------------------------------------- /xtrain_motion.sh: -------------------------------------------------------------------------------- 1 | 2 | # MambaVT-Small 3 | ## train for non-lasher evaluation 4 | python tracking/train.py --script mambatrack_motion --config mambavt_motion_s256_ep10 --save_dir ./output --mode multiple --nproc_per_node 2 5 | ## train for lasher evaluation 6 | python tracking/train.py --script mambatrack_motion --config mambavt_motion_s256_ep10_lasher --save_dir ./output --mode multiple --nproc_per_node 2 7 | 8 | 9 | # MambaVT-Middle 10 | ## train for non-lasher evaluation 11 | python tracking/train.py --script mambatrack_motion --config mambavt_motion_m256_ep10 --save_dir ./output --mode multiple --nproc_per_node 2 12 | ## train for lasher evaluation 13 | python tracking/train.py --script mambatrack_motion --config mambavt_motion_m256_ep10_lasher --save_dir ./output --mode multiple --nproc_per_node 2 14 | -------------------------------------------------------------------------------- /tracking/create_default_local_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import _init_paths 4 | from lib.train.admin import create_default_local_file_ITP_train 5 | from lib.test.evaluation import create_default_local_file_ITP_test 6 | 7 | 8 | def 
parse_args(): 9 | parser = argparse.ArgumentParser(description='Create default local file on ITP or PAI') 10 | parser.add_argument("--workspace_dir", type=str, required=True) # workspace dir 11 | parser.add_argument("--data_dir", type=str, required=True) 12 | parser.add_argument("--save_dir", type=str, required=True) 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | if __name__ == "__main__": 18 | args = parse_args() 19 | workspace_dir = os.path.realpath(args.workspace_dir) 20 | data_dir = os.path.realpath(args.data_dir) 21 | save_dir = os.path.realpath(args.save_dir) 22 | create_default_local_file_ITP_train(workspace_dir, data_dir) 23 | create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir) 24 | -------------------------------------------------------------------------------- /mamba-1p1p1/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | return torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /ytest.sh: -------------------------------------------------------------------------------- 1 | # test for MambaVT-Small 2 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_s256_ep20 --dataset_name GTOT --epoch 20 3 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_s256_ep20 --dataset_name RGBT234 --epoch 20 4 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_s256_ep20 --dataset_name RGBT210 --epoch 20 5 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_s256_ep20_lasher --dataset_name LasHeR --epoch 20 6 | 7 | # test for MambaVT-Middle 8 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_m256_ep20 --dataset_name GTOT --epoch 20 9 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_m256_ep20 --dataset_name RGBT234 --epoch 20 10 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_m256_ep20 --dataset_name RGBT210 --epoch 20 11 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack --yaml_name mambavt_m256_ep20_lasher --dataset_name LasHeR --epoch 20 -------------------------------------------------------------------------------- /ytest_motion.sh: -------------------------------------------------------------------------------- 1 | # test for MambaVT-Small 2 | python ./RGBT_workspace/test_rgbt_mgpus.py 
--script_name mambatrack_motion --yaml_name mambavt_motion_s256_ep10 --dataset_name GTOT --epoch 10 3 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_s256_ep10 --dataset_name RGBT234 --epoch 10 4 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_s256_ep10 --dataset_name RGBT210 --epoch 10 5 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_s256_ep10_lasher --dataset_name LasHeR --epoch 10 6 | 7 | # test for MambaVT-Middle 8 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_m256_ep10 --dataset_name GTOT --epoch 10 9 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_m256_ep10 --dataset_name RGBT234 --epoch 10 10 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_m256_ep10 --dataset_name RGBT210 --epoch 10 11 | python ./RGBT_workspace/test_rgbt_mgpus.py --script_name mambatrack_motion --yaml_name mambavt_motion_m256_ep10_lasher --dataset_name LasHeR --epoch 10 -------------------------------------------------------------------------------- /lib/train/data/wandb_logger.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | try: 4 | import wandb 5 | except ImportError: 6 | raise ImportError( 7 | 'Please run "pip install wandb" to install wandb') 8 | 9 | 10 | class WandbWriter: 11 | def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0): 12 | self.wandb = wandb 13 | self.step = cur_step 14 | self.interval = step_interval 15 | wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir) 16 | 17 | def write_log(self, stats: OrderedDict, epoch=-1): 18 | self.step += 1 19 | for loader_name, loader_stats in stats.items(): 20 | if loader_stats is None: 21 | continue 22 | 23 | log_dict = {} 24 | for var_name, val in loader_stats.items(): 25 | if hasattr(val, 'avg'): 26 | log_dict.update({loader_name + '/' + var_name: val.avg}) 27 | else: 28 | log_dict.update({loader_name + '/' + var_name: val.val}) 29 | 30 | if epoch >= 0: 31 | log_dict.update({loader_name + '/epoch': epoch}) 32 | 33 | self.wandb.log(log_dict, step=self.step*self.interval) 34 | -------------------------------------------------------------------------------- /lib/train/admin/tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | try: 4 | from torch.utils.tensorboard import SummaryWriter 5 | except: 6 | print('WARNING: You are using tensorboardX instead sis you have a too old pytorch version.') 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | class TensorboardWriter: 11 | def __init__(self, directory, loader_names): 12 | self.directory = directory 13 | self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names}) 14 | 15 | def write_info(self, script_name, description): 16 | tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info')) 17 | tb_info_writer.add_text('Script_name', script_name) 18 | tb_info_writer.add_text('Description', description) 19 | tb_info_writer.close() 20 | 21 | def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1): 22 | for loader_name, loader_stats in stats.items(): 23 | if loader_stats is None: 24 | continue 25 | for 
var_name, val in loader_stats.items(): 26 | if hasattr(val, 'history') and getattr(val, 'has_new_data', True): 27 | self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch) -------------------------------------------------------------------------------- /lib/test/parameter/mambatrack.py: -------------------------------------------------------------------------------- 1 | from lib.test.utils import TrackerParams 2 | import os 3 | from lib.test.evaluation.environment import env_settings 4 | from lib.config.mambatrack.config import cfg, update_config_from_file 5 | 6 | 7 | def parameters(yaml_name: str, epoch=300, debug=False): 8 | params = TrackerParams() 9 | prj_dir = env_settings().prj_dir 10 | save_dir = env_settings().save_dir 11 | # update default config from yaml file 12 | yaml_file = os.path.join(prj_dir, 'experiments/mambatrack/%s.yaml' % yaml_name) 13 | update_config_from_file(yaml_file) 14 | params.cfg = cfg 15 | # if debug: 16 | params.debug = debug 17 | # print("test config: ", cfg) 18 | 19 | # template and search region 20 | params.template_factor = cfg.TEST.TEMPLATE_FACTOR 21 | params.template_size = cfg.TEST.TEMPLATE_SIZE 22 | params.search_factor = cfg.TEST.SEARCH_FACTOR 23 | params.search_size = cfg.TEST.SEARCH_SIZE 24 | 25 | # Network checkpoint path 26 | params.checkpoint = os.path.join(save_dir, "checkpoints/train/mambatrack/%s/MambaTrack_ep%04d.pth.tar" % 27 | (yaml_name, epoch)) # cfg.TEST.EPOCH 28 | 29 | # whether to save boxes from all queries 30 | params.save_all_boxes = False 31 | 32 | return params 33 | -------------------------------------------------------------------------------- /lib/utils/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def merge_template_search(inp_list, return_search=False, return_template=False): 5 | """NOTICE: search region related features must be in the last place""" 6 | seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 7 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 8 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} 9 | if return_search: 10 | x = inp_list[-1] 11 | seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]}) 12 | if return_template: 13 | z = inp_list[0] 14 | seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]}) 15 | return seq_dict 16 | 17 | 18 | def get_qkv(inp_list): 19 | """The 1st element of the inp_list is about the template, 20 | the 2nd (the last) element is about the search region""" 21 | dict_x = inp_list[-1] 22 | dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 23 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 24 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict 25 | q = dict_x["feat"] + dict_x["pos"] 26 | k = dict_c["feat"] + dict_c["pos"] 27 | v = dict_c["feat"] 28 | key_padding_mask = dict_c["mask"] 29 | return q, k, v, key_padding_mask 30 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 
8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \ 17 | [&] { \ 18 | if (COND) { \ 19 | constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /lib/models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from timm.models.layers import to_2tuple 4 | 5 | 6 | class PatchEmbed(nn.Module): 7 | """ 2D Image to Patch Embedding 8 | """ 9 | 10 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 11 | super().__init__() 12 | img_size = to_2tuple(img_size) 13 | patch_size = to_2tuple(patch_size) 14 | self.img_size = img_size 15 | self.patch_size = patch_size 16 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 17 | self.num_patches = self.grid_size[0] * self.grid_size[1] 18 | self.flatten = flatten 19 | 20 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 21 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 22 | 23 | def forward(self, x): 24 | # allow different input size 25 | # B, C, H, W = x.shape 26 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 27 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 28 | x = self.proj(x) 29 | if self.flatten: 30 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 31 | x = self.norm(x) 32 | return x 33 | -------------------------------------------------------------------------------- /lib/train/data_specs/lasher_motion_val.txt: -------------------------------------------------------------------------------- 1 | boywalkinginsnow3 2 | leftgirlunderthelamp 3 | girlridesbike 4 | midboyplayingphone 5 | boywithumbrella 6 | manrun 7 | whitecarturnl 8 | girltakemoto 9 | rightgirlatbike 10 | easy_blackboy 11 | man_with_black_clothes2 12 | 7runone 13 | turnblkbike 14 | motobesidescar 15 | bikeafterwhitecar 16 | 2runsix 17 | rightboy_1227 18 | whitesuvcome 19 | AQrightofcomingmotos 20 | 7one 21 | blackman_0115 22 | rightmirrornotshining 23 | AQmanfromdarktrees 24 | orangegirl 25 | girlturnbike 26 | blackman2 27 | blackcarback 28 | rightof2cupsattached 29 | whitecar2west 30 | hatboy`shead 31 | whitebetweenblackandblue 32 | 2rdcarcome 33 | whitemancome 34 | nearmangotoD 35 | farmanrightwhitesmallhouse 36 | lightmotocoming 37 | boymototakesgirl 38 | leftblackboy 39 | righttallholdball 40 | blackcarcome 41 | lowerfoam2throw 42 | Awhitecargo 43 | car2north3 44 | girltakingplate 45 | ab_bolster 46 | 9hatboy 47 | whitecarturn2 48 | basketboywhite 49 | nightmototurn 50 | girlbike 51 | mantoground 52 | 8lastone 53 | AQbikeback 54 | blkbikefromnorth 55 | whitecar 56 | Amidredgirl 57 | AQblkgirlbike 58 | browncar2north 59 | carstop 60 | whiteboywithbag 61 | girlafterglassdoor2 62 | midboy 63 | bikecome 64 | Agirlrideback 65 | rightgirl 66 | moto2north1 67 | truckk 68 | whiteboy 69 | truckwhite 70 | AQgirlbiketurns 71 | left2ndboy 72 | whitegirl2right 73 | girlplayingphone 74 | girlumbrella 75 | truck 76 | manfarbesidespool 77 | 
-------------------------------------------------------------------------------- /mamba-1p1p1/evals/lm_harness_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import transformers 4 | from transformers import AutoTokenizer 5 | 6 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 7 | 8 | from lm_eval.api.model import LM 9 | from lm_eval.models.huggingface import HFLM 10 | from lm_eval.api.registry import register_model 11 | from lm_eval.__main__ import cli_evaluate 12 | 13 | 14 | @register_model("mamba") 15 | class MambaEvalWrapper(HFLM): 16 | 17 | AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM 18 | 19 | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", 20 | dtype=torch.float16): 21 | LM.__init__(self) 22 | self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) 23 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 25 | self.vocab_size = self.tokenizer.vocab_size 26 | self._batch_size = int(batch_size) if batch_size is not None else 64 27 | self._max_length = max_length 28 | self._device = torch.device(device) 29 | 30 | @property 31 | def batch_size(self): 32 | return self._batch_size 33 | 34 | def _model_generate(self, context, max_length, stop, **generation_kwargs): 35 | raise NotImplementedError() 36 | 37 | 38 | if __name__ == "__main__": 39 | cli_evaluate() 40 | -------------------------------------------------------------------------------- /lib/test/utils/params.py: -------------------------------------------------------------------------------- 1 | from lib.utils import TensorList 2 | import random 3 | 4 | 5 | class TrackerParams: 6 | """Class for tracker parameters.""" 7 | def set_default_values(self, default_vals: dict): 8 | for name, val in default_vals.items(): 9 | if not hasattr(self, name): 10 | setattr(self, name, val) 11 | 12 | def get(self, name: str, *default): 13 | """Get a parameter value with the given name. If it does not exists, it return the default value given as a 14 | second argument or returns an error if no default value is given.""" 15 | if len(default) > 1: 16 | raise ValueError('Can only give one default value.') 17 | 18 | if not default: 19 | return getattr(self, name) 20 | 21 | return getattr(self, name, default[0]) 22 | 23 | def has(self, name: str): 24 | """Check if there exist a parameter with the given name.""" 25 | return hasattr(self, name) 26 | 27 | 28 | class FeatureParams: 29 | """Class for feature specific parameters""" 30 | def __init__(self, *args, **kwargs): 31 | if len(args) > 0: 32 | raise ValueError 33 | 34 | for name, val in kwargs.items(): 35 | if isinstance(val, list): 36 | setattr(self, name, TensorList(val)) 37 | else: 38 | setattr(self, name, val) 39 | 40 | 41 | def Choice(*args): 42 | """Can be used to sample random parameter values.""" 43 | return random.choice(args) 44 | -------------------------------------------------------------------------------- /lib/train/actors/base_actor.py: -------------------------------------------------------------------------------- 1 | from lib.utils import TensorDict 2 | 3 | 4 | class BaseActor: 5 | """ Base class for actor. 
The actor class handles the passing of the data through the network 6 | and calculation the loss""" 7 | def __init__(self, net, objective): 8 | """ 9 | args: 10 | net - The network to train 11 | objective - The loss function 12 | """ 13 | self.net = net 14 | self.objective = objective 15 | 16 | def __call__(self, data: TensorDict): 17 | """ Called in each training iteration. Should pass in input data through the network, calculate the loss, and 18 | return the training stats for the input data 19 | args: 20 | data - A TensorDict containing all the necessary data blocks. 21 | 22 | returns: 23 | loss - loss for the input data 24 | stats - a dict containing detailed losses 25 | """ 26 | raise NotImplementedError 27 | 28 | def to(self, device): 29 | """ Move the network to device 30 | args: 31 | device - device to use. 'cpu' or 'cuda' 32 | """ 33 | self.net.to(device) 34 | 35 | def train(self, mode=True): 36 | """ Set whether the network is in train mode. 37 | args: 38 | mode (True) - Bool specifying whether in training mode. 39 | """ 40 | self.net.train(mode) 41 | 42 | def eval(self): 43 | """ Set network to eval mode""" 44 | self.train(False) -------------------------------------------------------------------------------- /lib/test/utils/transform_trackingnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import argparse 5 | import _init_paths 6 | from lib.test.evaluation.environment import env_settings 7 | 8 | 9 | def transform_trackingnet(tracker_name, cfg_name): 10 | env = env_settings() 11 | result_dir = env.results_path 12 | src_dir = os.path.join(result_dir, "%s/%s/trackingnet/" % (tracker_name, cfg_name)) 13 | dest_dir = os.path.join(result_dir, "%s/%s/trackingnet_submit/" % (tracker_name, cfg_name)) 14 | if not os.path.exists(dest_dir): 15 | os.makedirs(dest_dir) 16 | items = os.listdir(src_dir) 17 | for item in items: 18 | if "all" in item: 19 | continue 20 | if "time" not in item: 21 | src_path = os.path.join(src_dir, item) 22 | dest_path = os.path.join(dest_dir, item) 23 | bbox_arr = np.loadtxt(src_path, dtype=int, delimiter='\t') 24 | np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',') 25 | # make zip archive 26 | shutil.make_archive(src_dir, "zip", src_dir) 27 | shutil.make_archive(dest_dir, "zip", dest_dir) 28 | # Remove the original files 29 | shutil.rmtree(src_dir) 30 | shutil.rmtree(dest_dir) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser(description='transform trackingnet results.') 35 | parser.add_argument('--tracker_name', type=str, help='Name of tracking method.') 36 | parser.add_argument('--cfg_name', type=str, help='Name of config file.') 37 | 38 | args = parser.parse_args() 39 | transform_trackingnet(args.tracker_name, args.cfg_name) 40 | -------------------------------------------------------------------------------- /experiments/mambatrack/mambavt_m256_ep20.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | NUMBER: 7 23 | TRAIN: 24 | DATASETS_NAME: 25 | - LasHeR_train 26 | DATASETS_RATIO: 27 | - 1 28 | SAMPLE_PER_EPOCH: 60000 29 | VAL: 30 | DATASETS_NAME: 31 | - LasHeR_val 
32 | DATASETS_RATIO: 33 | - 1 34 | SAMPLE_PER_EPOCH: 10000 35 | MODEL: 36 | PRETRAIN_FILE: "/your_path/pretrained_models/OSTrack_videomambam_ep300.pth.tar" 37 | EXTRA_MERGER: False 38 | RETURN_INTER: False 39 | BACKBONE: 40 | TYPE: videomamba_middle_576 41 | STRIDE: 16 42 | HEAD: 43 | TYPE: CENTER 44 | NUM_CHANNELS: 256 45 | TRAIN: 46 | BACKBONE_MULTIPLIER: 0.05 47 | DROP_PATH_RATE: 0.1 48 | BATCH_SIZE: 32 49 | EPOCH: 20 50 | GIOU_WEIGHT: 2.0 51 | L1_WEIGHT: 5.0 52 | GRAD_CLIP_NORM: 0.1 53 | LR: 0.0004 54 | LR_DROP_EPOCH: 16 55 | NUM_WORKER: 10 56 | OPTIMIZER: ADAMW 57 | PRINT_INTERVAL: 100 58 | SCHEDULER: 59 | TYPE: step 60 | DECAY_RATE: 0.1 61 | VAL_EPOCH_INTERVAL: 5 62 | WEIGHT_DECAY: 0.0001 63 | AMP: False 64 | SAVE_EPOCH_INTERVAL: 5 # 1 means save model each epoch 65 | SAVE_LAST_N_EPOCH: 5 66 | TEST: 67 | EPOCH: 20 68 | SEARCH_FACTOR: 4.0 69 | SEARCH_SIZE: 256 70 | TEMPLATE_FACTOR: 2.0 71 | TEMPLATE_SIZE: 128 72 | TEMPLATE_NUMBER: 6 -------------------------------------------------------------------------------- /experiments/mambatrack/mambavt_s256_ep20.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | NUMBER: 7 23 | TRAIN: 24 | DATASETS_NAME: 25 | - LasHeR_train 26 | DATASETS_RATIO: 27 | - 1 28 | SAMPLE_PER_EPOCH: 60000 29 | VAL: 30 | DATASETS_NAME: 31 | - LasHeR_val 32 | DATASETS_RATIO: 33 | - 1 34 | SAMPLE_PER_EPOCH: 10000 35 | MODEL: 36 | PRETRAIN_FILE: "/your_path/pretrained_models/OSTrack_videomambas_ep300.pth.tar" 37 | EXTRA_MERGER: False 38 | RETURN_INTER: False 39 | BACKBONE: 40 | TYPE: videomamba_small_576 41 | STRIDE: 16 42 | HEAD: 43 | TYPE: CENTER 44 | NUM_CHANNELS: 256 45 | TRAIN: 46 | BACKBONE_MULTIPLIER: 0.05 47 | DROP_PATH_RATE: 0.1 48 | BATCH_SIZE: 32 49 | EPOCH: 20 50 | GIOU_WEIGHT: 2.0 51 | L1_WEIGHT: 5.0 52 | GRAD_CLIP_NORM: 0.1 53 | LR: 0.0003 54 | LR_DROP_EPOCH: 16 55 | NUM_WORKER: 10 56 | OPTIMIZER: ADAMW 57 | PRINT_INTERVAL: 100 58 | SCHEDULER: 59 | TYPE: step 60 | DECAY_RATE: 0.1 61 | VAL_EPOCH_INTERVAL: 5 62 | WEIGHT_DECAY: 0.0001 63 | AMP: False 64 | SAVE_EPOCH_INTERVAL: 5 # 1 means save model each epoch 65 | SAVE_LAST_N_EPOCH: 5 66 | TEST: 67 | EPOCH: 20 68 | SEARCH_FACTOR: 4.0 69 | SEARCH_SIZE: 256 70 | TEMPLATE_FACTOR: 2.0 71 | TEMPLATE_SIZE: 128 72 | TEMPLATE_NUMBER: 6 -------------------------------------------------------------------------------- /experiments/mambatrack/mambavt_m256_ep20_lasher.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | NUMBER: 7 23 | TRAIN: 24 | DATASETS_NAME: 25 | - LasHeR_train 26 | DATASETS_RATIO: 27 | - 1 28 | SAMPLE_PER_EPOCH: 60000 29 | VAL: 30 | DATASETS_NAME: 31 | - LasHeR_val 32 | DATASETS_RATIO: 33 | - 1 34 | SAMPLE_PER_EPOCH: 10000 35 | MODEL: 36 | PRETRAIN_FILE: "/your_path/pretrained_models/OSTrack_videomambam_ep300.pth.tar" 37 | EXTRA_MERGER: False 38 | RETURN_INTER: False 39 | BACKBONE: 
40 | TYPE: videomamba_middle_576 41 | STRIDE: 16 42 | HEAD: 43 | TYPE: CENTER 44 | NUM_CHANNELS: 256 45 | TRAIN: 46 | BACKBONE_MULTIPLIER: 0.05 47 | DROP_PATH_RATE: 0.1 48 | BATCH_SIZE: 32 49 | EPOCH: 20 50 | GIOU_WEIGHT: 2.0 51 | L1_WEIGHT: 5.0 52 | GRAD_CLIP_NORM: 0.1 53 | LR: 0.0008 54 | LR_DROP_EPOCH: 16 55 | NUM_WORKER: 10 56 | OPTIMIZER: ADAMW 57 | PRINT_INTERVAL: 100 58 | SCHEDULER: 59 | TYPE: step 60 | DECAY_RATE: 0.1 61 | VAL_EPOCH_INTERVAL: 5 62 | WEIGHT_DECAY: 0.0001 63 | AMP: False 64 | SAVE_EPOCH_INTERVAL: 5 # 1 means save model each epoch 65 | SAVE_LAST_N_EPOCH: 5 66 | TEST: 67 | EPOCH: 20 68 | SEARCH_FACTOR: 4.0 69 | SEARCH_SIZE: 256 70 | TEMPLATE_FACTOR: 2.0 71 | TEMPLATE_SIZE: 128 72 | TEMPLATE_NUMBER: 6 -------------------------------------------------------------------------------- /experiments/mambatrack/mambavt_s256_ep20_lasher.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | NUMBER: 7 23 | TRAIN: 24 | DATASETS_NAME: 25 | - LasHeR_train 26 | DATASETS_RATIO: 27 | - 1 28 | SAMPLE_PER_EPOCH: 60000 29 | VAL: 30 | DATASETS_NAME: 31 | - LasHeR_val 32 | DATASETS_RATIO: 33 | - 1 34 | SAMPLE_PER_EPOCH: 10000 35 | MODEL: 36 | PRETRAIN_FILE: "/your_path/pretrained_models/OSTrack_videomambas_ep300.pth.tar" 37 | EXTRA_MERGER: False 38 | RETURN_INTER: False 39 | BACKBONE: 40 | TYPE: videomamba_small_576 41 | STRIDE: 16 42 | HEAD: 43 | TYPE: CENTER 44 | NUM_CHANNELS: 256 45 | TRAIN: 46 | BACKBONE_MULTIPLIER: 0.05 47 | DROP_PATH_RATE: 0.1 48 | BATCH_SIZE: 32 49 | EPOCH: 20 50 | GIOU_WEIGHT: 2.0 51 | L1_WEIGHT: 5.0 52 | GRAD_CLIP_NORM: 0.1 53 | LR: 0.0008 54 | LR_DROP_EPOCH: 16 55 | NUM_WORKER: 10 56 | OPTIMIZER: ADAMW 57 | PRINT_INTERVAL: 100 58 | SCHEDULER: 59 | TYPE: step 60 | DECAY_RATE: 0.1 61 | VAL_EPOCH_INTERVAL: 5 62 | WEIGHT_DECAY: 0.0001 63 | AMP: False 64 | SAVE_EPOCH_INTERVAL: 5 # 1 means save model each epoch 65 | SAVE_LAST_N_EPOCH: 5 66 | TEST: 67 | EPOCH: 20 68 | SEARCH_FACTOR: 4.0 69 | SEARCH_SIZE: 256 70 | TEMPLATE_FACTOR: 2.0 71 | TEMPLATE_SIZE: 128 72 | TEMPLATE_NUMBER: 6 -------------------------------------------------------------------------------- /lib/test/parameter/mambatrack_motion.py: -------------------------------------------------------------------------------- 1 | from lib.test.utils import TrackerParams 2 | import os 3 | from lib.test.evaluation.environment import env_settings 4 | from lib.config.mambatrack_motion.config import cfg, update_config_from_file 5 | 6 | 7 | def parameters(yaml_name: str, epoch=300, debug=False): 8 | params = TrackerParams() 9 | prj_dir = env_settings().prj_dir 10 | save_dir = env_settings().save_dir 11 | # update default config from yaml file 12 | yaml_file = os.path.join(prj_dir, 'experiments/mambatrack_motion/%s.yaml' % yaml_name) 13 | update_config_from_file(yaml_file) 14 | params.cfg = cfg 15 | # if debug: 16 | params.debug = debug 17 | # print("test config: ", cfg) 18 | 19 | # template and search region 20 | params.template_factor = cfg.TEST.TEMPLATE_FACTOR 21 | params.template_size = cfg.TEST.TEMPLATE_SIZE 22 | params.search_factor = cfg.TEST.SEARCH_FACTOR 23 | params.search_size = cfg.TEST.SEARCH_SIZE 24 | 25 | # Network checkpoint path 26 | # 
params.checkpoint = os.path.join(save_dir, "checkpoints/train/mambatrack_motion/%s/MambaTrackMotion_ep%04d.pth.tar" % 27 | # (yaml_name, epoch)) # cfg.TEST.EPOCH 28 | 29 | params.checkpoint = os.path.join(prj_dir, "checkpoints/MambaT-M-LasHeR.pth.tar") 30 | params.checkpoint = os.path.join(prj_dir, "checkpoints/MambaT-M-rgbt234-rgbt210-gtot.pth.tar") 31 | params.checkpoint = os.path.join(prj_dir, "checkpoints/MambaT-S-LasHeR.pth.tar") 32 | params.checkpoint = os.path.join(prj_dir, "checkpoints/MambaT-S-rgbt234-rgbt210-gtot.pth.tar") 33 | 34 | # whether to save boxes from all queries 35 | params.save_all_boxes = False 36 | 37 | return params 38 | -------------------------------------------------------------------------------- /lib/utils/variable_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bytecode import Bytecode, Instr 3 | 4 | 5 | class get_local(object): 6 | cache = {} 7 | is_activate = False 8 | 9 | def __init__(self, varname): 10 | self.varname = varname 11 | 12 | def __call__(self, func): 13 | if not type(self).is_activate: 14 | return func 15 | 16 | type(self).cache[func.__qualname__] = [] 17 | c = Bytecode.from_code(func.__code__) 18 | extra_code = [ 19 | Instr('STORE_FAST', '_res'), 20 | Instr('LOAD_FAST', self.varname), 21 | Instr('STORE_FAST', '_value'), 22 | Instr('LOAD_FAST', '_res'), 23 | Instr('LOAD_FAST', '_value'), 24 | Instr('BUILD_TUPLE', 2), 25 | Instr('STORE_FAST', '_result_tuple'), 26 | Instr('LOAD_FAST', '_result_tuple'), 27 | ] 28 | c[-1:-1] = extra_code 29 | func.__code__ = c.to_code() 30 | 31 | def wrapper(*args, **kwargs): 32 | res, values = func(*args, **kwargs) 33 | if isinstance(values, torch.Tensor): 34 | type(self).cache[func.__qualname__].append(values.detach().cpu().numpy()) 35 | elif isinstance(values, list): # list of Tensor 36 | type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values]) 37 | else: 38 | raise NotImplementedError 39 | return res 40 | 41 | return wrapper 42 | 43 | @classmethod 44 | def clear(cls): 45 | for key in cls.cache.keys(): 46 | cls.cache[key] = [] 47 | 48 | @classmethod 49 | def activate(cls): 50 | cls.is_activate = True 51 | -------------------------------------------------------------------------------- /lib/models/layers/frozen_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FrozenBatchNorm2d(torch.nn.Module): 5 | """ 6 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 7 | 8 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 9 | without which any other models than torchvision.models.resnet[18,34,50,101] 10 | produce nans. 
11 | """ 12 | 13 | def __init__(self, n): 14 | super(FrozenBatchNorm2d, self).__init__() 15 | self.register_buffer("weight", torch.ones(n)) 16 | self.register_buffer("bias", torch.zeros(n)) 17 | self.register_buffer("running_mean", torch.zeros(n)) 18 | self.register_buffer("running_var", torch.ones(n)) 19 | 20 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 21 | missing_keys, unexpected_keys, error_msgs): 22 | num_batches_tracked_key = prefix + 'num_batches_tracked' 23 | if num_batches_tracked_key in state_dict: 24 | del state_dict[num_batches_tracked_key] 25 | 26 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 27 | state_dict, prefix, local_metadata, strict, 28 | missing_keys, unexpected_keys, error_msgs) 29 | 30 | def forward(self, x): 31 | # move reshapes to the beginning 32 | # to make it fuser-friendly 33 | w = self.weight.reshape(1, -1, 1, 1) 34 | b = self.bias.reshape(1, -1, 1, 1) 35 | rv = self.running_var.reshape(1, -1, 1, 1) 36 | rm = self.running_mean.reshape(1, -1, 1, 1) 37 | eps = 1e-5 38 | scale = w * (rv + eps).rsqrt() # rsqrt(x): 1/sqrt(x), r: reciprocal 39 | bias = b - rm * scale 40 | return x * scale + bias 41 | -------------------------------------------------------------------------------- /lib/test/utils/load_text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def load_text_numpy(path, delimiter, dtype): 6 | if isinstance(delimiter, (tuple, list)): 7 | for d in delimiter: 8 | try: 9 | ground_truth_rect = np.loadtxt(path, delimiter=d, dtype=dtype) 10 | return ground_truth_rect 11 | except: 12 | pass 13 | 14 | raise Exception('Could not read file {}'.format(path)) 15 | else: 16 | ground_truth_rect = np.loadtxt(path, delimiter=delimiter, dtype=dtype) 17 | return ground_truth_rect 18 | 19 | 20 | def load_text_pandas(path, delimiter, dtype): 21 | if isinstance(delimiter, (tuple, list)): 22 | for d in delimiter: 23 | try: 24 | ground_truth_rect = pd.read_csv(path, delimiter=d, header=None, dtype=dtype, na_filter=False, 25 | low_memory=False).values 26 | return ground_truth_rect 27 | except Exception as e: 28 | pass 29 | 30 | raise Exception('Could not read file {}'.format(path)) 31 | else: 32 | ground_truth_rect = pd.read_csv(path, delimiter=delimiter, header=None, dtype=dtype, na_filter=False, 33 | low_memory=False).values 34 | return ground_truth_rect 35 | 36 | 37 | def load_text(path, delimiter=' ', dtype=np.float32, backend='numpy'): 38 | if backend == 'numpy': 39 | return load_text_numpy(path, delimiter, dtype) 40 | elif backend == 'pandas': 41 | return load_text_pandas(path, delimiter, dtype) 42 | 43 | 44 | def load_str(path): 45 | with open(path, "r") as f: 46 | text_str = f.readline().strip().lower() 47 | return text_str 48 | -------------------------------------------------------------------------------- /lib/utils/lmdb_utils.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import numpy as np 3 | import cv2 4 | import json 5 | 6 | LMDB_ENVS = dict() 7 | LMDB_HANDLES = dict() 8 | LMDB_FILELISTS = dict() 9 | 10 | 11 | def get_lmdb_handle(name): 12 | global LMDB_HANDLES, LMDB_FILELISTS 13 | item = LMDB_HANDLES.get(name, None) 14 | if item is None: 15 | env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False) 16 | LMDB_ENVS[name] = env 17 | item = env.begin(write=False) 18 | LMDB_HANDLES[name] = item 19 | 20 | return item 21 | 22 | 23 | def 
decode_img(lmdb_fname, key_name): 24 | handle = get_lmdb_handle(lmdb_fname) 25 | binfile = handle.get(key_name.encode()) 26 | if binfile is None: 27 | print("Illegal data detected. %s %s" % (lmdb_fname, key_name)) 28 | s = np.frombuffer(binfile, np.uint8) 29 | x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 30 | return x 31 | 32 | 33 | def decode_str(lmdb_fname, key_name): 34 | handle = get_lmdb_handle(lmdb_fname) 35 | binfile = handle.get(key_name.encode()) 36 | string = binfile.decode() 37 | return string 38 | 39 | 40 | def decode_json(lmdb_fname, key_name): 41 | return json.loads(decode_str(lmdb_fname, key_name)) 42 | 43 | 44 | if __name__ == "__main__": 45 | lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb" 46 | '''Decode image''' 47 | # key_name = "test/GOT-10k_Test_000001/00000001.jpg" 48 | # img = decode_img(lmdb_fname, key_name) 49 | # cv2.imwrite("001.jpg", img) 50 | '''Decode str''' 51 | # key_name = "test/list.txt" 52 | # key_name = "train/GOT-10k_Train_000001/groundtruth.txt" 53 | key_name = "train/GOT-10k_Train_000001/absence.label" 54 | str_ = decode_str(lmdb_fname, key_name) 55 | print(str_) 56 | -------------------------------------------------------------------------------- /experiments/mambatrack_motion/mambavt_motion_m256_ep10_lasher.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | PRE_MOTION_NUM: 31 3 | MAX_SAMPLE_INTERVAL: 200 4 | MEAN: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | SEARCH: 9 | CENTER_JITTER: 0 # 3 10 | FACTOR: 4.0 11 | SCALE_JITTER: 0.0 # 0.25 12 | SIZE: 256 13 | NUMBER: 1 14 | STD: 15 | - 0.229 16 | - 0.224 17 | - 0.225 18 | TEMPLATE: 19 | CENTER_JITTER: 0 20 | FACTOR: 2.0 21 | SCALE_JITTER: 0 22 | SIZE: 128 23 | NUMBER: 7 24 | TRAIN: 25 | DATASETS_NAME: 26 | - LasHeR_motion_train 27 | DATASETS_RATIO: 28 | - 1 29 | SAMPLE_PER_EPOCH: 60000 30 | VAL: 31 | DATASETS_NAME: 32 | - LasHeR_motion_val 33 | DATASETS_RATIO: 34 | - 1 35 | SAMPLE_PER_EPOCH: 10000 36 | MODEL: 37 | BINS: 400 38 | RANGE: 2 39 | PRETRAIN_FILE: "/your_stage1_trained_model_path/MambaTrack_ep00??.pth.tar" 40 | EXTRA_MERGER: False 41 | RETURN_INTER: False 42 | BACKBONE: 43 | TYPE: videomamba_middle_576 44 | STRIDE: 16 45 | ADD_MOTION_PRED: True 46 | PROMPT_EMBED_TYPE: vocab 47 | HEAD: 48 | TYPE: CENTER 49 | NUM_CHANNELS: 256 50 | TRAIN: 51 | BACKBONE_MULTIPLIER: 0.05 52 | DROP_PATH_RATE: 0.1 53 | BATCH_SIZE: 32 54 | EPOCH: 10 55 | GIOU_WEIGHT: 2.0 56 | L1_WEIGHT: 5.0 57 | GRAD_CLIP_NORM: 0.1 58 | LR: 0.00008 59 | LR_DROP_EPOCH: 8 60 | NUM_WORKER: 10 61 | OPTIMIZER: ADAMW 62 | PRINT_INTERVAL: 100 63 | SCHEDULER: 64 | TYPE: step 65 | DECAY_RATE: 0.1 66 | VAL_EPOCH_INTERVAL: 1 67 | WEIGHT_DECAY: 0.0001 68 | AMP: False 69 | SAVE_EPOCH_INTERVAL: 2 # 1 means save model each epoch 70 | SAVE_LAST_N_EPOCH: 1 71 | TEST: 72 | EPOCH: 10 73 | SEARCH_FACTOR: 4.0 74 | SEARCH_SIZE: 256 75 | TEMPLATE_FACTOR: 2.0 76 | TEMPLATE_SIZE: 128 77 | TEMPLATE_NUMBER: 6 78 | TEST_PRE_NUM: 79 | LasHeR: 2 -------------------------------------------------------------------------------- /experiments/mambatrack_motion/mambavt_motion_s256_ep10_lasher.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | PRE_MOTION_NUM: 31 3 | MAX_SAMPLE_INTERVAL: 200 4 | MEAN: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | SEARCH: 9 | CENTER_JITTER: 0 # 3 10 | FACTOR: 4.0 11 | SCALE_JITTER: 0.0 # 0.25 12 | SIZE: 256 13 | NUMBER: 1 14 | STD: 15 | - 0.229 16 | - 0.224 17 | - 0.225 18 | TEMPLATE: 19 | 
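    # Template crop settings (same convention as SEARCH above): FACTOR is the amount of
    # context cropped around the target box, SIZE the resized patch resolution, and
    # NUMBER presumably the number of template frames kept per sequence.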
CENTER_JITTER: 0 20 | FACTOR: 2.0 21 | SCALE_JITTER: 0 22 | SIZE: 128 23 | NUMBER: 7 24 | TRAIN: 25 | DATASETS_NAME: 26 | - LasHeR_motion_train 27 | DATASETS_RATIO: 28 | - 1 29 | SAMPLE_PER_EPOCH: 60000 30 | VAL: 31 | DATASETS_NAME: 32 | - LasHeR_motion_val 33 | DATASETS_RATIO: 34 | - 1 35 | SAMPLE_PER_EPOCH: 10000 36 | MODEL: 37 | BINS: 400 38 | RANGE: 2 39 | PRETRAIN_FILE: "/your_stage1_trained_model_path/MambaTrack_ep00??.pth.tar" 40 | EXTRA_MERGER: False 41 | RETURN_INTER: False 42 | BACKBONE: 43 | TYPE: videomamba_small_576 44 | STRIDE: 16 45 | ADD_MOTION_PRED: True 46 | PROMPT_EMBED_TYPE: vocab 47 | HEAD: 48 | TYPE: CENTER 49 | NUM_CHANNELS: 256 50 | TRAIN: 51 | BACKBONE_MULTIPLIER: 0.05 52 | DROP_PATH_RATE: 0.1 53 | BATCH_SIZE: 32 54 | EPOCH: 10 55 | GIOU_WEIGHT: 2.0 56 | L1_WEIGHT: 5.0 57 | GRAD_CLIP_NORM: 0.1 58 | LR: 0.00008 59 | LR_DROP_EPOCH: 8 60 | NUM_WORKER: 10 61 | OPTIMIZER: ADAMW 62 | PRINT_INTERVAL: 100 63 | SCHEDULER: 64 | TYPE: step 65 | DECAY_RATE: 0.1 66 | VAL_EPOCH_INTERVAL: 1 67 | WEIGHT_DECAY: 0.0001 68 | AMP: False 69 | SAVE_EPOCH_INTERVAL: 2 # 1 means save model each epoch 70 | SAVE_LAST_N_EPOCH: 1 71 | TEST: 72 | EPOCH: 10 73 | SEARCH_FACTOR: 4.0 74 | SEARCH_SIZE: 256 75 | TEMPLATE_FACTOR: 2.0 76 | TEMPLATE_SIZE: 128 77 | TEMPLATE_NUMBER: 6 78 | TEST_PRE_NUM: 79 | LasHeR: 10 80 | -------------------------------------------------------------------------------- /experiments/mambatrack_motion/mambavt_motion_m256_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | PRE_MOTION_NUM: 31 3 | MAX_SAMPLE_INTERVAL: 200 4 | MEAN: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | SEARCH: 9 | CENTER_JITTER: 0 # 3 10 | FACTOR: 4.0 11 | SCALE_JITTER: 0.0 # 0.25 12 | SIZE: 256 13 | NUMBER: 1 14 | STD: 15 | - 0.229 16 | - 0.224 17 | - 0.225 18 | TEMPLATE: 19 | CENTER_JITTER: 0 20 | FACTOR: 2.0 21 | SCALE_JITTER: 0 22 | SIZE: 128 23 | NUMBER: 7 24 | TRAIN: 25 | DATASETS_NAME: 26 | - LasHeR_motion_train 27 | DATASETS_RATIO: 28 | - 1 29 | SAMPLE_PER_EPOCH: 60000 30 | VAL: 31 | DATASETS_NAME: 32 | - LasHeR_motion_val 33 | DATASETS_RATIO: 34 | - 1 35 | SAMPLE_PER_EPOCH: 10000 36 | MODEL: 37 | BINS: 400 38 | RANGE: 2 39 | PRETRAIN_FILE: "/your_stage1_trained_model_path/MambaTrack_ep00??.pth.tar" 40 | EXTRA_MERGER: False 41 | RETURN_INTER: False 42 | BACKBONE: 43 | TYPE: videomamba_middle_576 44 | STRIDE: 16 45 | ADD_MOTION_PRED: True 46 | PROMPT_EMBED_TYPE: vocab 47 | HEAD: 48 | TYPE: CENTER 49 | NUM_CHANNELS: 256 50 | TRAIN: 51 | BACKBONE_MULTIPLIER: 0.05 52 | DROP_PATH_RATE: 0.1 53 | BATCH_SIZE: 32 54 | EPOCH: 10 55 | GIOU_WEIGHT: 2.0 56 | L1_WEIGHT: 5.0 57 | GRAD_CLIP_NORM: 0.1 58 | LR: 0.00004 59 | LR_DROP_EPOCH: 8 60 | NUM_WORKER: 10 61 | OPTIMIZER: ADAMW 62 | PRINT_INTERVAL: 100 63 | SCHEDULER: 64 | TYPE: step 65 | DECAY_RATE: 0.1 66 | VAL_EPOCH_INTERVAL: 1 67 | WEIGHT_DECAY: 0.0001 68 | AMP: False 69 | SAVE_EPOCH_INTERVAL: 2 # 1 means save model each epoch 70 | SAVE_LAST_N_EPOCH: 1 71 | TEST: 72 | EPOCH: 10 73 | SEARCH_FACTOR: 4.0 74 | SEARCH_SIZE: 256 75 | TEMPLATE_FACTOR: 2.0 76 | TEMPLATE_SIZE: 128 77 | TEMPLATE_NUMBER: 6 78 | TEST_PRE_NUM: 79 | GTOT: 5 80 | RGBT234: 4 81 | RGBT210: 5 82 | -------------------------------------------------------------------------------- /experiments/mambatrack_motion/mambavt_motion_s256_ep10.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | PRE_MOTION_NUM: 31 3 | MAX_SAMPLE_INTERVAL: 200 4 | MEAN: 5 | - 0.485 6 | - 0.456 7 | - 0.406 8 | SEARCH: 9 | 
CENTER_JITTER: 0 # 3 10 | FACTOR: 4.0 11 | SCALE_JITTER: 0.0 # 0.25 12 | SIZE: 256 13 | NUMBER: 1 14 | STD: 15 | - 0.229 16 | - 0.224 17 | - 0.225 18 | TEMPLATE: 19 | CENTER_JITTER: 0 20 | FACTOR: 2.0 21 | SCALE_JITTER: 0 22 | SIZE: 128 23 | NUMBER: 7 24 | TRAIN: 25 | DATASETS_NAME: 26 | - LasHeR_motion_train 27 | DATASETS_RATIO: 28 | - 1 29 | SAMPLE_PER_EPOCH: 60000 30 | VAL: 31 | DATASETS_NAME: 32 | - LasHeR_motion_val 33 | DATASETS_RATIO: 34 | - 1 35 | SAMPLE_PER_EPOCH: 10000 36 | MODEL: 37 | BINS: 400 38 | RANGE: 2 39 | PRETRAIN_FILE: "/your_stage1_trained_model_path/MambaTrack_ep00??.pth.tar" 40 | EXTRA_MERGER: False 41 | RETURN_INTER: False 42 | BACKBONE: 43 | TYPE: videomamba_small_576 44 | STRIDE: 16 45 | ADD_MOTION_PRED: True 46 | PROMPT_EMBED_TYPE: vocab 47 | HEAD: 48 | TYPE: CENTER 49 | NUM_CHANNELS: 256 50 | TRAIN: 51 | BACKBONE_MULTIPLIER: 0.05 52 | DROP_PATH_RATE: 0.1 53 | BATCH_SIZE: 32 54 | EPOCH: 10 55 | GIOU_WEIGHT: 2.0 56 | L1_WEIGHT: 5.0 57 | GRAD_CLIP_NORM: 0.1 58 | LR: 0.00003 59 | LR_DROP_EPOCH: 8 60 | NUM_WORKER: 10 61 | OPTIMIZER: ADAMW 62 | PRINT_INTERVAL: 100 63 | SCHEDULER: 64 | TYPE: step 65 | DECAY_RATE: 0.1 66 | VAL_EPOCH_INTERVAL: 1 67 | WEIGHT_DECAY: 0.0001 68 | AMP: False 69 | SAVE_EPOCH_INTERVAL: 2 # 1 means save model each epoch 70 | SAVE_LAST_N_EPOCH: 1 71 | TEST: 72 | EPOCH: 10 73 | SEARCH_FACTOR: 4.0 74 | SEARCH_SIZE: 256 75 | TEMPLATE_FACTOR: 2.0 76 | TEMPLATE_SIZE: 128 77 | TEMPLATE_NUMBER: 6 78 | TEST_PRE_NUM: 79 | GTOT: 10 80 | RGBT234: 5 81 | RGBT210: 2 82 | -------------------------------------------------------------------------------- /lib/train/data_specs/lasher_val.txt: -------------------------------------------------------------------------------- 1 | boywalkinginsnow3 2 | leftdrillmasterstanding 3 | leftgirlunderthelamp 4 | girlridesbike 5 | midboyplayingphone 6 | boywithumbrella 7 | manrun 8 | ab_pingpongball 9 | whitecarturnl 10 | girltakemoto 11 | rightgirlatbike 12 | easy_blackboy 13 | man_with_black_clothes2 14 | 7runone 15 | turnblkbike 16 | motobesidescar 17 | bikeafterwhitecar 18 | 2runsix 19 | rightboy_1227 20 | whitesuvcome 21 | AQrightofcomingmotos 22 | 7one 23 | blackman_0115 24 | rightmirrornotshining 25 | AQmanfromdarktrees 26 | bikeboy128 27 | orangegirl 28 | girlturnbike 29 | blackman2 30 | blackcarback 31 | rightof2cupsattached 32 | whitecar2west 33 | hatboy`shead 34 | whitebetweenblackandblue 35 | 2rdcarcome 36 | whitemancome 37 | nearmangotoD 38 | farmanrightwhitesmallhouse 39 | lightmotocoming 40 | boymototakesgirl 41 | leftblackboy 42 | righttallholdball 43 | blackcarcome 44 | twolinefirstone-gai 45 | lowerfoam2throw 46 | Awhitecargo 47 | car2north3 48 | rightfirstboy-ly 49 | girltakingplate 50 | left2ndgreenboy 51 | ab_bolster 52 | 9hatboy 53 | whitecarturn2 54 | midboyblue 55 | basketboywhite 56 | nightmototurn 57 | girlbike 58 | mantoground 59 | pickuptheyellowbook 60 | 8lastone 61 | AQbikeback 62 | girlsquattingbesidesleftbar 63 | blkbikefromnorth 64 | whitecar 65 | Amidredgirl 66 | blackbag 67 | AQblkgirlbike 68 | manwithyellowumbrella 69 | browncar2north 70 | carstop 71 | whiteboywithbag 72 | theleftestrunningboy 73 | girlafterglassdoor2 74 | rightmirrorlikesky 75 | redgirl1497 76 | midboy 77 | folderatlefthand 78 | bikecome 79 | leftfallenchair_inf_white 80 | Agirlrideback 81 | rightgirl 82 | belowrightwhiteboy 83 | moto2north1 84 | truckk 85 | highright2ndboy 86 | girl`sheadoncall 87 | whiteboy 88 | truckwhite 89 | AQgirlbiketurns 90 | left2ndboy 91 | whitegirl2right 92 | rightboywithwhite 93 | 
girlplayingphone 94 | girlumbrella 95 | truck 96 | manfarbesidespool 97 | dotat43 -------------------------------------------------------------------------------- /lib/train/admin/stats.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class StatValue: 4 | def __init__(self): 5 | self.clear() 6 | 7 | def reset(self): 8 | self.val = 0 9 | 10 | def clear(self): 11 | self.reset() 12 | self.history = [] 13 | 14 | def update(self, val): 15 | self.val = val 16 | self.history.append(self.val) 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | def __init__(self): 22 | self.clear() 23 | self.has_new_data = False 24 | 25 | def reset(self): 26 | self.avg = 0 27 | self.val = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def clear(self): 32 | self.reset() 33 | self.history = [] 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def new_epoch(self): 42 | if self.count > 0: 43 | self.history.append(self.avg) 44 | self.reset() 45 | self.has_new_data = True 46 | else: 47 | self.has_new_data = False 48 | 49 | 50 | def topk_accuracy(output, target, topk=(1,)): 51 | """Computes the precision@k for the specified values of k""" 52 | single_input = not isinstance(topk, (tuple, list)) 53 | if single_input: 54 | topk = (topk,) 55 | 56 | maxk = max(topk) 57 | batch_size = target.size(0) 58 | 59 | _, pred = output.topk(maxk, 1, True, True) 60 | pred = pred.t() 61 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 62 | 63 | res = [] 64 | for k in topk: 65 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0] 66 | res.append(correct_k * 100.0 / batch_size) 67 | 68 | if single_input: 69 | return res[0] 70 | 71 | return res 72 | -------------------------------------------------------------------------------- /lib/train/trainers/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | inf = math.inf 4 | 5 | class NativeScalerWithGradNormCount: 6 | state_dict_key = "amp_scaler" 7 | 8 | def __init__(self): 9 | self._scaler = torch.cuda.amp.GradScaler() 10 | 11 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 12 | self._scaler.scale(loss).backward(create_graph=create_graph) 13 | if update_grad: 14 | if clip_grad is not None: 15 | assert parameters is not None 16 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 17 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 18 | else: 19 | self._scaler.unscale_(optimizer) 20 | norm = get_grad_norm_(parameters) 21 | self._scaler.step(optimizer) 22 | self._scaler.update() 23 | else: 24 | norm = None 25 | return norm 26 | 27 | def state_dict(self): 28 | return self._scaler.state_dict() 29 | 30 | def load_state_dict(self, state_dict): 31 | self._scaler.load_state_dict(state_dict) 32 | 33 | 34 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 35 | if isinstance(parameters, torch.Tensor): 36 | parameters = [parameters] 37 | parameters = [p for p in parameters if p.grad is not None] 38 | norm_type = float(norm_type) 39 | if len(parameters) == 0: 40 | return torch.tensor(0.) 
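    # Every remaining parameter carries a gradient; accumulate the norms on one device.
    # norm_type == inf yields the largest absolute gradient entry, otherwise the p-norm
    # of the stacked per-parameter norms (equivalent to the p-norm of all gradients).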
41 | device = parameters[0].grad.device 42 | if norm_type == inf: 43 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 44 | else: 45 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 46 | return total_norm 47 | -------------------------------------------------------------------------------- /lib/test/evaluation/tc128dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 3 | import os 4 | import glob 5 | import six 6 | 7 | 8 | class TC128Dataset(BaseDataset): 9 | """ 10 | TC-128 Dataset 11 | modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit) 12 | """ 13 | def __init__(self): 14 | super().__init__() 15 | self.base_path = self.env_settings.tc128_path 16 | self.anno_files = sorted(glob.glob( 17 | os.path.join(self.base_path, '*/*_gt.txt'))) 18 | self.seq_dirs = [os.path.dirname(f) for f in self.anno_files] 19 | self.seq_names = [os.path.basename(d) for d in self.seq_dirs] 20 | # valid frame range for each sequence 21 | self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs] 22 | 23 | def get_sequence_list(self): 24 | return SequenceList([self._construct_sequence(s) for s in self.seq_names]) 25 | 26 | def _construct_sequence(self, sequence_name): 27 | if isinstance(sequence_name, six.string_types): 28 | if not sequence_name in self.seq_names: 29 | raise Exception('Sequence {} not found.'.format(sequence_name)) 30 | index = self.seq_names.index(sequence_name) 31 | # load valid frame range 32 | frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',') 33 | img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)] 34 | 35 | # load annotations 36 | anno = np.loadtxt(self.anno_files[index], delimiter=',') 37 | assert len(img_files) == len(anno) 38 | assert anno.shape[1] == 4 39 | 40 | # return img_files, anno 41 | return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4)) 42 | 43 | def __len__(self): 44 | return len(self.seq_names) 45 | -------------------------------------------------------------------------------- /lib/test/evaluation/tnl2kdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 5 | from lib.test.utils.load_text import load_text, load_str 6 | 7 | ############ 8 | # current 00000492.png of test_015_Sord_video_Q01_done is damaged and replaced by a copy of 00000491.png 9 | ############ 10 | 11 | 12 | class TNL2kDataset(BaseDataset): 13 | """ 14 | TNL2k test set 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | self.base_path = self.env_settings.tnl2k_path 19 | self.sequence_list = self._get_sequence_list() 20 | 21 | def get_sequence_list(self): 22 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list]) 23 | 24 | def _construct_sequence(self, sequence_name): 25 | # class_name = sequence_name.split('-')[0] 26 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name) 27 | 28 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64) 29 | 30 | text_dsp_path = '{}/{}/language.txt'.format(self.base_path, sequence_name) 31 | text_dsp = load_str(text_dsp_path) 32 | 33 | frames_path 
= '{}/{}/imgs'.format(self.base_path, sequence_name) 34 | frames_list = [f for f in os.listdir(frames_path)] 35 | frames_list = sorted(frames_list) 36 | frames_list = ['{}/{}'.format(frames_path, frame_i) for frame_i in frames_list] 37 | 38 | # target_class = class_name 39 | return Sequence(sequence_name, frames_list, 'tnl2k', ground_truth_rect.reshape(-1, 4), text_dsp=text_dsp) 40 | 41 | def __len__(self): 42 | return len(self.sequence_list) 43 | 44 | def _get_sequence_list(self): 45 | sequence_list = [] 46 | for seq in os.listdir(self.base_path): 47 | if os.path.isdir(os.path.join(self.base_path, seq)): 48 | sequence_list.append(seq) 49 | 50 | return sequence_list 51 | -------------------------------------------------------------------------------- /lib/test/evaluation/tc128cedataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 3 | import os 4 | import glob 5 | import six 6 | 7 | 8 | class TC128CEDataset(BaseDataset): 9 | """ 10 | TC-128 Dataset (78 newly added sequences) 11 | modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit) 12 | """ 13 | def __init__(self): 14 | super().__init__() 15 | self.base_path = self.env_settings.tc128_path 16 | self.anno_files = sorted(glob.glob( 17 | os.path.join(self.base_path, '*/*_gt.txt'))) 18 | """filter the newly added sequences (_ce)""" 19 | self.anno_files = [s for s in self.anno_files if "_ce" in s] 20 | self.seq_dirs = [os.path.dirname(f) for f in self.anno_files] 21 | self.seq_names = [os.path.basename(d) for d in self.seq_dirs] 22 | # valid frame range for each sequence 23 | self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs] 24 | 25 | def get_sequence_list(self): 26 | return SequenceList([self._construct_sequence(s) for s in self.seq_names]) 27 | 28 | def _construct_sequence(self, sequence_name): 29 | if isinstance(sequence_name, six.string_types): 30 | if not sequence_name in self.seq_names: 31 | raise Exception('Sequence {} not found.'.format(sequence_name)) 32 | index = self.seq_names.index(sequence_name) 33 | # load valid frame range 34 | frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',') 35 | img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)] 36 | 37 | # load annotations 38 | anno = np.loadtxt(self.anno_files[index], delimiter=',') 39 | assert len(img_files) == len(anno) 40 | assert anno.shape[1] == 4 41 | 42 | # return img_files, anno 43 | return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4)) 44 | 45 | def __len__(self): 46 | return len(self.seq_names) 47 | -------------------------------------------------------------------------------- /lib/test/tracker/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ############## used for visulize eliminated tokens ################# 5 | def get_keep_indices(decisions): 6 | keep_indices = [] 7 | for i in range(3): 8 | if i == 0: 9 | keep_indices.append(decisions[i]) 10 | else: 11 | keep_indices.append(keep_indices[-1][decisions[i]]) 12 | return keep_indices 13 | 14 | 15 | def gen_masked_tokens(tokens, indices, alpha=0.2): 16 | # indices = [i for i in range(196) if i not in indices] 17 | indices = indices[0].astype(int) 18 | tokens = tokens.copy() 19 | tokens[indices] = alpha * tokens[indices] + (1 - alpha) * 255 20 | 
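    # The given indices mark tokens eliminated at this stage; they are blended toward
    # white (255) so the kept tokens remain visually prominent.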
return tokens 21 | 22 | 23 | def recover_image(tokens, H, W, Hp, Wp, patch_size): 24 | # image: (C, 196, 16, 16) 25 | image = tokens.reshape(Hp, Wp, patch_size, patch_size, 3).swapaxes(1, 2).reshape(H, W, 3) 26 | return image 27 | 28 | 29 | def pad_img(img): 30 | height, width, channels = img.shape 31 | im_bg = np.ones((height, width + 8, channels)) * 255 32 | im_bg[0:height, 0:width, :] = img 33 | return im_bg 34 | 35 | 36 | def gen_visualization(image, mask_indices, patch_size=16): 37 | # image [224, 224, 3] 38 | # mask_indices, list of masked token indices 39 | 40 | # mask mask_indices need to cat 41 | # mask_indices = mask_indices[::-1] 42 | num_stages = len(mask_indices) 43 | for i in range(1, num_stages): 44 | mask_indices[i] = np.concatenate([mask_indices[i-1], mask_indices[i]], axis=1) 45 | 46 | # keep_indices = get_keep_indices(decisions) 47 | image = np.asarray(image) 48 | H, W, C = image.shape 49 | Hp, Wp = H // patch_size, W // patch_size 50 | image_tokens = image.reshape(Hp, patch_size, Wp, patch_size, 3).swapaxes(1, 2).reshape(Hp * Wp, patch_size, patch_size, 3) 51 | 52 | stages = [ 53 | recover_image(gen_masked_tokens(image_tokens, mask_indices[i]), H, W, Hp, Wp, patch_size) 54 | for i in range(num_stages) 55 | ] 56 | imgs = [image] + stages 57 | imgs = [pad_img(img) for img in imgs] 58 | viz = np.concatenate(imgs, axis=1) 59 | return viz 60 | -------------------------------------------------------------------------------- /lib/test/utils/transform_got10k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import argparse 5 | import _init_paths 6 | from lib.test.evaluation.environment import env_settings 7 | 8 | 9 | def transform_got10k(tracker_name, cfg_name): 10 | env = env_settings() 11 | result_dir = env.results_path 12 | src_dir = os.path.join(result_dir, "%s/%s/got10k/" % (tracker_name, cfg_name)) 13 | dest_dir = os.path.join(result_dir, "%s/%s/got10k_submit/" % (tracker_name, cfg_name)) 14 | if not os.path.exists(dest_dir): 15 | os.makedirs(dest_dir) 16 | items = os.listdir(src_dir) 17 | for item in items: 18 | if "all" in item: 19 | continue 20 | src_path = os.path.join(src_dir, item) 21 | if "time" not in item: 22 | seq_name = item.replace(".txt", '') 23 | seq_dir = os.path.join(dest_dir, seq_name) 24 | if not os.path.exists(seq_dir): 25 | os.makedirs(seq_dir) 26 | new_item = item.replace(".txt", '_001.txt') 27 | dest_path = os.path.join(seq_dir, new_item) 28 | bbox_arr = np.loadtxt(src_path, dtype=int, delimiter='\t') 29 | np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',') 30 | else: 31 | seq_name = item.replace("_time.txt", '') 32 | seq_dir = os.path.join(dest_dir, seq_name) 33 | if not os.path.exists(seq_dir): 34 | os.makedirs(seq_dir) 35 | dest_path = os.path.join(seq_dir, item) 36 | os.system("cp %s %s" % (src_path, dest_path)) 37 | # make zip archive 38 | shutil.make_archive(src_dir, "zip", src_dir) 39 | shutil.make_archive(dest_dir, "zip", dest_dir) 40 | # Remove the original files 41 | shutil.rmtree(src_dir) 42 | shutil.rmtree(dest_dir) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser(description='transform got10k results.') 47 | parser.add_argument('--tracker_name', type=str, help='Name of tracking method.') 48 | parser.add_argument('--cfg_name', type=str, help='Name of config file.') 49 | 50 | args = parser.parse_args() 51 | transform_got10k(args.tracker_name, args.cfg_name) 52 | 53 | 
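Note: transform_got10k.py above is meant to be run as a script, e.g.
python lib/test/utils/transform_got10k.py --tracker_name mambatrack --cfg_name mambavt_s256_ep20
(the tracker and config names here are illustrative). It repackages the per-sequence results under
<results_path>/<tracker_name>/<cfg_name>/got10k/ into the one-directory-per-sequence layout expected by
the GOT-10k evaluation server, zips both the raw and repackaged results, and removes the unzipped copies.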
-------------------------------------------------------------------------------- /mamba-1p1p1/tests/ops/triton/test_selective_state_update.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2023, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import pytest 8 | 9 | from einops import rearrange 10 | 11 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref 12 | 13 | 14 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 15 | # @pytest.mark.parametrize('itype', [torch.float16]) 16 | @pytest.mark.parametrize("has_z", [False, True]) 17 | # @pytest.mark.parametrize('has_z', [True]) 18 | @pytest.mark.parametrize("dstate", [16, 32, 64]) 19 | # @pytest.mark.parametrize("dstate", [16]) 20 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 21 | # @pytest.mark.parametrize("dim", [2048]) 22 | def test_causal_conv1d_update(dim, dstate, has_z, itype): 23 | device = "cuda" 24 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) 25 | if itype == torch.bfloat16: 26 | rtol, atol = 1e-2, 5e-2 27 | # set seed 28 | torch.random.manual_seed(0) 29 | batch_size = 2 30 | state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) 31 | x = torch.randn(batch_size, dim, device=device, dtype=itype) 32 | dt = torch.randn(batch_size, dim, device=device, dtype=itype) 33 | dt_bias = torch.rand(dim, device=device) - 4.0 34 | A = -torch.rand(dim, dstate, device=device) - 1.0 35 | B = torch.randn(batch_size, dstate, device=device) 36 | C = torch.randn(batch_size, dstate, device=device) 37 | D = torch.randn(dim, device=device) 38 | if has_z: 39 | z = torch.randn_like(x) 40 | else: 41 | z = None 42 | state_ref = state.detach().clone() 43 | out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 44 | out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) 45 | 46 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 47 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 48 | assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) 49 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 50 | -------------------------------------------------------------------------------- /lib/test/evaluation/got10kdataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 3 | from lib.test.utils.load_text import load_text 4 | import os 5 | 6 | 7 | class GOT10KDataset(BaseDataset): 8 | """ GOT-10k dataset. 
9 | 10 | Publication: 11 | GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild 12 | Lianghua Huang, Xin Zhao, and Kaiqi Huang 13 | arXiv:1810.11981, 2018 14 | https://arxiv.org/pdf/1810.11981.pdf 15 | 16 | Download dataset from http://got-10k.aitestunion.com/downloads 17 | """ 18 | def __init__(self, split): 19 | super().__init__() 20 | # Split can be test, val, or ltrval (a validation split consisting of videos from the official train set) 21 | if split == 'test' or split == 'val': 22 | self.base_path = os.path.join(self.env_settings.got10k_path, split) 23 | else: 24 | self.base_path = os.path.join(self.env_settings.got10k_path, 'train') 25 | 26 | self.sequence_list = self._get_sequence_list(split) 27 | self.split = split 28 | 29 | def get_sequence_list(self): 30 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list]) 31 | 32 | def _construct_sequence(self, sequence_name): 33 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name) 34 | 35 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64) 36 | 37 | frames_path = '{}/{}'.format(self.base_path, sequence_name) 38 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")] 39 | frame_list.sort(key=lambda f: int(f[:-4])) 40 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list] 41 | 42 | return Sequence(sequence_name, frames_list, 'got10k', ground_truth_rect.reshape(-1, 4)) 43 | 44 | def __len__(self): 45 | return len(self.sequence_list) 46 | 47 | def _get_sequence_list(self, split): 48 | with open('{}/list.txt'.format(self.base_path)) as f: 49 | sequence_list = f.read().splitlines() 50 | 51 | if split == 'ltrval': 52 | with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f: 53 | seq_ids = f.read().splitlines() 54 | 55 | sequence_list = [sequence_list[int(x)] for x in seq_ids] 56 | return sequence_list 57 | -------------------------------------------------------------------------------- /lib/test/evaluation/datasets.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import importlib 3 | from lib.test.evaluation.data import SequenceList 4 | 5 | DatasetInfo = namedtuple('DatasetInfo', ['module', 'class_name', 'kwargs']) 6 | 7 | pt = "lib.test.evaluation.%sdataset" # Useful abbreviations to reduce the clutter 8 | 9 | dataset_dict = dict( 10 | otb=DatasetInfo(module=pt % "otb", class_name="OTBDataset", kwargs=dict()), 11 | nfs=DatasetInfo(module=pt % "nfs", class_name="NFSDataset", kwargs=dict()), 12 | uav=DatasetInfo(module=pt % "uav", class_name="UAVDataset", kwargs=dict()), 13 | tc128=DatasetInfo(module=pt % "tc128", class_name="TC128Dataset", kwargs=dict()), 14 | tc128ce=DatasetInfo(module=pt % "tc128ce", class_name="TC128CEDataset", kwargs=dict()), 15 | trackingnet=DatasetInfo(module=pt % "trackingnet", class_name="TrackingNetDataset", kwargs=dict()), 16 | got10k_test=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='test')), 17 | got10k_val=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='val')), 18 | got10k_ltrval=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='ltrval')), 19 | lasot=DatasetInfo(module=pt % "lasot", class_name="LaSOTDataset", kwargs=dict()), 20 | lasot_lmdb=DatasetInfo(module=pt % "lasot_lmdb", class_name="LaSOTlmdbDataset", kwargs=dict()), 21 | 22 | vot18=DatasetInfo(module=pt % 
"vot", class_name="VOTDataset", kwargs=dict()), 23 | vot22=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict(year=22)), 24 | itb=DatasetInfo(module=pt % "itb", class_name="ITBDataset", kwargs=dict()), 25 | tnl2k=DatasetInfo(module=pt % "tnl2k", class_name="TNL2kDataset", kwargs=dict()), 26 | lasot_extension_subset=DatasetInfo(module=pt % "lasotextensionsubset", class_name="LaSOTExtensionSubsetDataset", 27 | kwargs=dict()), 28 | ) 29 | 30 | 31 | def load_dataset(name: str): 32 | """ Import and load a single dataset.""" 33 | name = name.lower() 34 | dset_info = dataset_dict.get(name) 35 | if dset_info is None: 36 | raise ValueError('Unknown dataset \'%s\'' % name) 37 | 38 | m = importlib.import_module(dset_info.module) 39 | dataset = getattr(m, dset_info.class_name)(**dset_info.kwargs) # Call the constructor 40 | return dataset.get_sequence_list() 41 | 42 | 43 | def get_dataset(*args): 44 | """ Get a single or set of datasets.""" 45 | dset = SequenceList() 46 | for name in args: 47 | dset.extend(load_dataset(name)) 48 | return dset -------------------------------------------------------------------------------- /lib/test/evaluation/trackingnetdataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 3 | import os 4 | from lib.test.utils.load_text import load_text 5 | 6 | 7 | class TrackingNetDataset(BaseDataset): 8 | """ TrackingNet test set. 9 | 10 | Publication: 11 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 12 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 13 | ECCV, 2018 14 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 15 | 16 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 
17 | """ 18 | def __init__(self): 19 | super().__init__() 20 | self.base_path = self.env_settings.trackingnet_path 21 | 22 | sets = 'TEST' 23 | if not isinstance(sets, (list, tuple)): 24 | if sets == 'TEST': 25 | sets = ['TEST'] 26 | elif sets == 'TRAIN': 27 | sets = ['TRAIN_{}'.format(i) for i in range(5)] 28 | 29 | self.sequence_list = self._list_sequences(self.base_path, sets) 30 | 31 | def get_sequence_list(self): 32 | return SequenceList([self._construct_sequence(set, seq_name) for set, seq_name in self.sequence_list]) 33 | 34 | def _construct_sequence(self, set, sequence_name): 35 | anno_path = '{}/{}/anno/{}.txt'.format(self.base_path, set, sequence_name) 36 | 37 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy') 38 | 39 | frames_path = '{}/{}/frames/{}'.format(self.base_path, set, sequence_name) 40 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")] 41 | frame_list.sort(key=lambda f: int(f[:-4])) 42 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list] 43 | 44 | return Sequence(sequence_name, frames_list, 'trackingnet', ground_truth_rect.reshape(-1, 4)) 45 | 46 | def __len__(self): 47 | return len(self.sequence_list) 48 | 49 | def _list_sequences(self, root, set_ids): 50 | sequence_list = [] 51 | 52 | for s in set_ids: 53 | anno_dir = os.path.join(root, s, "anno") 54 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')] 55 | 56 | sequence_list += sequences_cur_set 57 | 58 | return sequence_list 59 | -------------------------------------------------------------------------------- /lib/utils/focal_loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FocalLoss(nn.Module, ABC): 9 | def __init__(self, alpha=2, beta=4): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.beta = beta 13 | 14 | def forward(self, prediction, target): 15 | positive_index = target.eq(1).float() 16 | negative_index = target.lt(1).float() 17 | 18 | negative_weights = torch.pow(1 - target, self.beta) 19 | # clamp min value is set to 1e-12 to maintain the numerical stability 20 | prediction = torch.clamp(prediction, 1e-12) 21 | 22 | positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index 23 | negative_loss = torch.log(1 - prediction) * torch.pow(prediction, 24 | self.alpha) * negative_weights * negative_index 25 | 26 | num_positive = positive_index.float().sum() 27 | positive_loss = positive_loss.sum() 28 | negative_loss = negative_loss.sum() 29 | 30 | if num_positive == 0: 31 | loss = -negative_loss 32 | else: 33 | loss = -(positive_loss + negative_loss) / num_positive 34 | 35 | return loss 36 | 37 | 38 | class LBHinge(nn.Module): 39 | """Loss that uses a 'hinge' on the lower bound. 40 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is 41 | also smaller than that threshold. 42 | args: 43 | error_matric: What base loss to use (MSE by default). 44 | threshold: Threshold to use for the hinge. 45 | clip: Clip the loss if it is above this value. 
46 | """ 47 | def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None): 48 | super().__init__() 49 | self.error_metric = error_metric 50 | self.threshold = threshold if threshold is not None else -100 51 | self.clip = clip 52 | 53 | def forward(self, prediction, label, target_bb=None): 54 | negative_mask = (label < self.threshold).float() 55 | positive_mask = (1.0 - negative_mask) 56 | 57 | prediction = negative_mask * F.relu(prediction) + positive_mask * prediction 58 | 59 | loss = self.error_metric(prediction, positive_mask * label) 60 | 61 | if self.clip is not None: 62 | loss = torch.min(loss, torch.tensor([self.clip], device=loss.device)) 63 | return loss -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | echo "****************** Installing pytorch locally and manually !!! ******************" 2 | 3 | echo "" 4 | echo "" 5 | echo "****************** Installing yaml ******************" 6 | pip install PyYAML 7 | 8 | echo "" 9 | echo "" 10 | echo "****************** Installing easydict ******************" 11 | pip install easydict 12 | 13 | echo "" 14 | echo "" 15 | echo "****************** Installing cython ******************" 16 | pip install cython 17 | 18 | echo "" 19 | echo "" 20 | echo "****************** Installing opencv-python ******************" 21 | pip install opencv-python 22 | 23 | echo "" 24 | echo "" 25 | echo "****************** Installing pandas ******************" 26 | pip install pandas 27 | 28 | echo "" 29 | echo "" 30 | echo "****************** Installing tqdm ******************" 31 | conda install -y tqdm 32 | 33 | echo "" 34 | echo "" 35 | echo "****************** Installing coco toolkit ******************" 36 | pip install pycocotools 37 | 38 | echo "" 39 | echo "" 40 | echo "****************** Installing jpeg4py python wrapper ******************" 41 | pip install jpeg4py 42 | 43 | echo "" 44 | echo "" 45 | echo "****************** Installing tensorboard ******************" 46 | pip install tb-nightly 47 | 48 | echo "" 49 | echo "" 50 | echo "****************** Installing tikzplotlib ******************" 51 | pip install tikzplotlib 52 | 53 | echo "" 54 | echo "" 55 | echo "****************** Installing thop tool for FLOPs and Params computing ******************" 56 | pip install thop 57 | 58 | echo "" 59 | echo "" 60 | echo "****************** Installing colorama ******************" 61 | pip install colorama 62 | 63 | echo "" 64 | echo "" 65 | echo "****************** Installing lmdb ******************" 66 | pip install lmdb 67 | 68 | echo "" 69 | echo "" 70 | echo "****************** Installing scipy ******************" 71 | pip install scipy 72 | 73 | echo "" 74 | echo "" 75 | echo "****************** Installing visdom ******************" 76 | pip install visdom 77 | 78 | 79 | echo "" 80 | echo "" 81 | echo "****************** Installing tensorboardX ******************" 82 | pip install tensorboardX 83 | 84 | 85 | echo "" 86 | echo "" 87 | echo "****************** Downgrade setuptools ******************" 88 | pip install setuptools==59.5.0 89 | 90 | 91 | echo "" 92 | echo "" 93 | echo "****************** Installing wandb ******************" 94 | pip install wandb 95 | 96 | echo "" 97 | echo "" 98 | echo "****************** Installing timm ******************" 99 | pip install timm 100 | 101 | echo "" 102 | echo "" 103 | echo "****************** Installation complete! 
******************" 104 | -------------------------------------------------------------------------------- /lib/train/dataset/base_image_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from lib.train.data.image_loader import jpeg4py_loader 3 | 4 | 5 | class BaseImageDataset(torch.utils.data.Dataset): 6 | """ Base class for image datasets """ 7 | 8 | def __init__(self, name, root, image_loader=jpeg4py_loader): 9 | """ 10 | args: 11 | root - The root path to the dataset 12 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 13 | is used by default. 14 | """ 15 | self.name = name 16 | self.root = root 17 | self.image_loader = image_loader 18 | 19 | self.image_list = [] # Contains the list of sequences. 20 | self.class_list = [] 21 | 22 | def __len__(self): 23 | """ Returns size of the dataset 24 | returns: 25 | int - number of samples in the dataset 26 | """ 27 | return self.get_num_images() 28 | 29 | def __getitem__(self, index): 30 | """ Not to be used! Check get_frames() instead. 31 | """ 32 | return None 33 | 34 | def get_name(self): 35 | """ Name of the dataset 36 | 37 | returns: 38 | string - Name of the dataset 39 | """ 40 | raise NotImplementedError 41 | 42 | def get_num_images(self): 43 | """ Number of sequences in a dataset 44 | 45 | returns: 46 | int - number of sequences in the dataset.""" 47 | return len(self.image_list) 48 | 49 | def has_class_info(self): 50 | return False 51 | 52 | def get_class_name(self, image_id): 53 | return None 54 | 55 | def get_num_classes(self): 56 | return len(self.class_list) 57 | 58 | def get_class_list(self): 59 | return self.class_list 60 | 61 | def get_images_in_class(self, class_name): 62 | raise NotImplementedError 63 | 64 | def has_segmentation_info(self): 65 | return False 66 | 67 | def get_image_info(self, seq_id): 68 | """ Returns information about a particular image, 69 | 70 | args: 71 | seq_id - index of the image 72 | 73 | returns: 74 | Dict 75 | """ 76 | raise NotImplementedError 77 | 78 | def get_image(self, image_id, anno=None): 79 | """ Get a image 80 | 81 | args: 82 | image_id - index of image 83 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 84 | 85 | returns: 86 | image - 87 | anno - 88 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 89 | 90 | """ 91 | raise NotImplementedError 92 | 93 | -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/uninitialized_copy.cuh: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 
11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 | * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | #include 31 | 32 | #include 33 | 34 | 35 | namespace detail 36 | { 37 | 38 | #if defined(_NVHPC_CUDA) 39 | template 40 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 41 | { 42 | // NVBug 3384810 43 | new (ptr) T(::cuda::std::forward(val)); 44 | } 45 | #else 46 | template ::value, 50 | int 51 | >::type = 0> 52 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 53 | { 54 | *ptr = ::cuda::std::forward(val); 55 | } 56 | 57 | template ::value, 61 | int 62 | >::type = 0> 63 | __host__ __device__ void uninitialized_copy(T *ptr, U &&val) 64 | { 65 | new (ptr) T(::cuda::std::forward(val)); 66 | } 67 | #endif 68 | 69 | } // namespace detail 70 | -------------------------------------------------------------------------------- /lib/train/data_specs/depthtrack_train.txt: -------------------------------------------------------------------------------- 1 | adapter02_indoor 2 | bag03_indoor 3 | bag05_indoor 4 | ball02_indoor 5 | ball03_indoor 6 | ball04_indoor 7 | ball05_indoor 8 | ball07_indoor 9 | ball08_wild 10 | ball09_wild 11 | ball12_wild 12 | ball13_indoor 13 | ball14_wild 14 | ball17_wild 15 | ball19_indoor 16 | ball21_indoor 17 | basket_indoor 18 | beautifullight01_indoor 19 | bike01_wild 20 | bike02_wild 21 | bike03_wild 22 | book01_indoor 23 | book02_indoor 24 | book04_indoor 25 | book05_indoor 26 | book06_indoor 27 | bottle01_indoor 28 | bottle02_indoor 29 | bottle05_indoor 30 | bottle06_indoor 31 | box_indoor 32 | candlecup_indoor 33 | car01_indoor 34 | car02_indoor 35 | cart_indoor 36 | cat02_indoor 37 | cat03_indoor 38 | cat04_indoor 39 | cat05_indoor 40 | chair01_indoor 41 | chair02_indoor 42 | clothes_indoor 43 | colacan01_indoor 44 | colacan02_indoor 45 | colacan04_indoor 46 | container01_indoor 47 | container02_indoor 48 | cube01_indoor 49 | cube04_indoor 50 | cube06_indoor 51 | cup03_indoor 52 | cup05_indoor 53 | cup06_indoor 54 | cup07_indoor 55 | cup08_indoor 56 | cup09_indoor 57 | cup10_indoor 58 | cup11_indoor 59 | cup13_indoor 60 | cup14_indoor 61 | duck01_wild 62 | duck02_wild 63 | duck04_wild 64 | duck05_wild 65 | duck06_wild 66 | dumbbells02_indoor 67 | earphone02_indoor 68 | egg_indoor 69 | file02_indoor 70 | flower01_indoor 71 | flower02_wild 72 | flowerbasket_indoor 73 | ghostmask_indoor 74 | glass02_indoor 75 | glass03_indoor 76 | glass04_indoor 77 | glass05_indoor 78 | guitarbag_indoor 79 | gymring_wild 80 | 
hand02_indoor 81 | hat01_indoor 82 | hat02_indoor_320 83 | hat03_indoor 84 | hat04_indoor 85 | human01_indoor 86 | human03_wild 87 | human04_wild 88 | human05_wild 89 | human06_indoor 90 | leaves01_wild 91 | leaves02_indoor 92 | leaves03_wild 93 | leaves04_indoor 94 | leaves05_indoor 95 | leaves06_wild 96 | lock01_wild 97 | mac_indoor 98 | milkbottle_indoor 99 | mirror_indoor 100 | mobilephone01_indoor 101 | mobilephone02_indoor 102 | mobilephone04_indoor 103 | mobilephone05_indoor 104 | mobilephone06_indoor 105 | mushroom01_indoor 106 | mushroom02_wild 107 | mushroom03_wild 108 | mushroom04_indoor 109 | mushroom05_indoor 110 | notebook02_indoor 111 | notebook03_indoor 112 | paintbottle_indoor 113 | painting_indoor_320 114 | parkingsign_wild 115 | pigeon03_wild 116 | pigeon06_wild 117 | pigeon07_wild 118 | pine01_indoor 119 | pine02_wild_320 120 | shoes01_indoor 121 | shoes03_indoor 122 | skateboard01_indoor 123 | skateboard02_indoor 124 | speaker_indoor 125 | stand_indoor 126 | suitcase_indoor 127 | swing01_wild 128 | swing02_wild 129 | teacup_indoor 130 | thermos01_indoor 131 | thermos02_indoor 132 | toiletpaper02_indoor 133 | toiletpaper03_indoor 134 | toiletpaper04_indoor 135 | toy01_indoor 136 | toy04_indoor 137 | toy05_indoor 138 | toy06_indoor 139 | toy07_indoor_320 140 | toy08_indoor 141 | toy10_indoor 142 | toydog_indoor 143 | trashbin_indoor 144 | tree_wild 145 | trophy_indoor 146 | ukulele02_indoor 147 | -------------------------------------------------------------------------------- /lib/test/tracker/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from lib.utils.misc import NestedTensor 4 | 5 | 6 | class Preprocessor(object): 7 | def __init__(self): 8 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda() 9 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda() 10 | 11 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 12 | # Deal with the image patch 13 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0) 14 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) 15 | # Deal with the attention mask 16 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W) 17 | return NestedTensor(img_tensor_norm, amask_tensor) 18 | 19 | class PreprocessorMM(object): 20 | def __init__(self): 21 | self.mean = torch.tensor([0.485, 0.456, 0.406, 0.485, 0.456, 0.406]).view((1, 6, 1, 1)).cuda() 22 | self.std = torch.tensor([0.229, 0.224, 0.225, 0.229, 0.224, 0.225]).view((1, 6, 1, 1)).cuda() 23 | 24 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 25 | # Deal with the image patch 26 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0) 27 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,6,H,W) 28 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W) 29 | return NestedTensor(img_tensor_norm, amask_tensor) 30 | 31 | 32 | class PreprocessorX(object): 33 | def __init__(self): 34 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda() 35 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda() 36 | 37 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 38 | # Deal with the image patch 39 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0) 40 | img_tensor_norm = ((img_tensor / 
255.0) - self.mean) / self.std # (1,3,H,W) 41 | # Deal with the attention mask 42 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W) 43 | return img_tensor_norm, amask_tensor 44 | 45 | 46 | class PreprocessorX_onnx(object): 47 | def __init__(self): 48 | self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)) 49 | self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)) 50 | 51 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 52 | """img_arr: (H,W,3), amask_arr: (H,W)""" 53 | # Deal with the image patch 54 | img_arr_4d = img_arr[np.newaxis, :, :, :].transpose(0, 3, 1, 2) 55 | img_arr_4d = (img_arr_4d / 255.0 - self.mean) / self.std # (1, 3, H, W) 56 | # Deal with the attention mask 57 | amask_arr_3d = amask_arr[np.newaxis, :, :] # (1,H,W) 58 | return img_arr_4d.astype(np.float32), amask_arr_3d.astype(np.bool) 59 | -------------------------------------------------------------------------------- /lib/train/data_specs/depthtrack_all.txt: -------------------------------------------------------------------------------- 1 | adapter02_indoor 2 | bag03_indoor 3 | bag05_indoor 4 | ball02_indoor 5 | ball03_indoor 6 | ball04_indoor 7 | ball05_indoor 8 | ball07_indoor 9 | ball08_wild 10 | ball09_wild 11 | ball12_wild 12 | ball13_indoor 13 | ball14_wild 14 | ball17_wild 15 | ball19_indoor 16 | ball21_indoor 17 | basket_indoor 18 | beautifullight01_indoor 19 | bike01_wild 20 | bike02_wild 21 | bike03_wild 22 | book01_indoor 23 | book02_indoor 24 | book04_indoor 25 | book05_indoor 26 | book06_indoor 27 | bottle01_indoor 28 | bottle02_indoor 29 | bottle05_indoor 30 | bottle06_indoor 31 | box_indoor 32 | candlecup_indoor 33 | car01_indoor 34 | car02_indoor 35 | cart_indoor 36 | cat02_indoor 37 | cat03_indoor 38 | cat04_indoor 39 | cat05_indoor 40 | chair01_indoor 41 | chair02_indoor 42 | clothes_indoor 43 | colacan01_indoor 44 | colacan02_indoor 45 | colacan04_indoor 46 | container01_indoor 47 | container02_indoor 48 | cube01_indoor 49 | cube04_indoor 50 | cube06_indoor 51 | cup03_indoor 52 | cup05_indoor 53 | cup06_indoor 54 | cup07_indoor 55 | cup08_indoor 56 | cup09_indoor 57 | cup10_indoor 58 | cup11_indoor 59 | cup13_indoor 60 | cup14_indoor 61 | duck01_wild 62 | duck02_wild 63 | duck04_wild 64 | duck05_wild 65 | duck06_wild 66 | dumbbells02_indoor 67 | earphone02_indoor 68 | egg_indoor 69 | file02_indoor 70 | flower01_indoor 71 | flower02_wild 72 | flowerbasket_indoor 73 | ghostmask_indoor 74 | glass02_indoor 75 | glass03_indoor 76 | glass04_indoor 77 | glass05_indoor 78 | guitarbag_indoor 79 | gymring_wild 80 | hand02_indoor 81 | hat01_indoor 82 | hat02_indoor_320 83 | hat03_indoor 84 | hat04_indoor 85 | human01_indoor 86 | human03_wild 87 | human04_wild 88 | human05_wild 89 | human06_indoor 90 | leaves01_wild 91 | leaves02_indoor 92 | leaves03_wild 93 | leaves04_indoor 94 | leaves05_indoor 95 | leaves06_wild 96 | lock01_wild 97 | mac_indoor 98 | milkbottle_indoor 99 | mirror_indoor 100 | mobilephone01_indoor 101 | mobilephone02_indoor 102 | mobilephone04_indoor 103 | mobilephone05_indoor 104 | mobilephone06_indoor 105 | mushroom01_indoor 106 | mushroom02_wild 107 | mushroom03_wild 108 | mushroom04_indoor 109 | mushroom05_indoor 110 | notebook02_indoor 111 | notebook03_indoor 112 | paintbottle_indoor 113 | painting_indoor_320 114 | parkingsign_wild 115 | pigeon03_wild 116 | pigeon06_wild 117 | pigeon07_wild 118 | pine01_indoor 119 | pine02_wild_320 120 | shoes01_indoor 121 | shoes03_indoor 122 | skateboard01_indoor 123 | 
skateboard02_indoor 124 | speaker_indoor 125 | stand_indoor 126 | suitcase_indoor 127 | swing01_wild 128 | swing02_wild 129 | teacup_indoor 130 | thermos01_indoor 131 | thermos02_indoor 132 | toiletpaper02_indoor 133 | toiletpaper03_indoor 134 | toiletpaper04_indoor 135 | toy01_indoor 136 | toy04_indoor 137 | toy05_indoor 138 | toy06_indoor 139 | toy07_indoor_320 140 | toy08_indoor 141 | toy10_indoor 142 | toydog_indoor 143 | trashbin_indoor 144 | tree_wild 145 | trophy_indoor 146 | ukulele02_indoor 147 | toy03_indoor 148 | pigeon05_wild 149 | bottle03_indoor 150 | ball16_indoor 151 | bag04_indoor 152 | flower03_indoor -------------------------------------------------------------------------------- /mamba-1p1p1/csrc/selective_scan/selective_scan.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct SSMScanParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, seqlen, n_chunks; 13 | index_t a_batch_stride; 14 | index_t b_batch_stride; 15 | index_t out_batch_stride; 16 | 17 | // Common data pointers. 18 | void *__restrict__ a_ptr; 19 | void *__restrict__ b_ptr; 20 | void *__restrict__ out_ptr; 21 | void *__restrict__ x_ptr; 22 | }; 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | struct SSMParamsBase { 27 | using index_t = uint32_t; 28 | 29 | int batch, dim, seqlen, dstate, n_groups, n_chunks; 30 | int dim_ngroups_ratio; 31 | bool is_variable_B; 32 | bool is_variable_C; 33 | 34 | bool delta_softplus; 35 | 36 | index_t A_d_stride; 37 | index_t A_dstate_stride; 38 | index_t B_batch_stride; 39 | index_t B_d_stride; 40 | index_t B_dstate_stride; 41 | index_t B_group_stride; 42 | index_t C_batch_stride; 43 | index_t C_d_stride; 44 | index_t C_dstate_stride; 45 | index_t C_group_stride; 46 | index_t u_batch_stride; 47 | index_t u_d_stride; 48 | index_t delta_batch_stride; 49 | index_t delta_d_stride; 50 | index_t z_batch_stride; 51 | index_t z_d_stride; 52 | index_t out_batch_stride; 53 | index_t out_d_stride; 54 | index_t out_z_batch_stride; 55 | index_t out_z_d_stride; 56 | 57 | // Common data pointers. 58 | void *__restrict__ A_ptr; 59 | void *__restrict__ B_ptr; 60 | void *__restrict__ C_ptr; 61 | void *__restrict__ D_ptr; 62 | void *__restrict__ u_ptr; 63 | void *__restrict__ delta_ptr; 64 | void *__restrict__ delta_bias_ptr; 65 | void *__restrict__ out_ptr; 66 | void *__restrict__ x_ptr; 67 | void *__restrict__ z_ptr; 68 | void *__restrict__ out_z_ptr; 69 | }; 70 | 71 | struct SSMParamsBwd: public SSMParamsBase { 72 | index_t dout_batch_stride; 73 | index_t dout_d_stride; 74 | index_t dA_d_stride; 75 | index_t dA_dstate_stride; 76 | index_t dB_batch_stride; 77 | index_t dB_group_stride; 78 | index_t dB_d_stride; 79 | index_t dB_dstate_stride; 80 | index_t dC_batch_stride; 81 | index_t dC_group_stride; 82 | index_t dC_d_stride; 83 | index_t dC_dstate_stride; 84 | index_t du_batch_stride; 85 | index_t du_d_stride; 86 | index_t dz_batch_stride; 87 | index_t dz_d_stride; 88 | index_t ddelta_batch_stride; 89 | index_t ddelta_d_stride; 90 | 91 | // Common data pointers. 
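    // For the backward pass these hold the incoming gradient (dout) and the computed
    // gradients with respect to the forward inputs (A, B, C, D, u, delta, z, delta_bias).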
92 | void *__restrict__ dout_ptr; 93 | void *__restrict__ dA_ptr; 94 | void *__restrict__ dB_ptr; 95 | void *__restrict__ dC_ptr; 96 | void *__restrict__ dD_ptr; 97 | void *__restrict__ du_ptr; 98 | void *__restrict__ dz_ptr; 99 | void *__restrict__ ddelta_ptr; 100 | void *__restrict__ ddelta_bias_ptr; 101 | }; 102 | -------------------------------------------------------------------------------- /lib/train/data/bounding_box_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rect_to_rel(bb, sz_norm=None): 5 | """Convert standard rectangular parametrization of the bounding box [x, y, w, h] 6 | to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate. 7 | args: 8 | bb - N x 4 tensor of boxes. 9 | sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given. 10 | """ 11 | 12 | c = bb[...,:2] + 0.5 * bb[...,2:] 13 | if sz_norm is None: 14 | c_rel = c / bb[...,2:] 15 | else: 16 | c_rel = c / sz_norm 17 | sz_rel = torch.log(bb[...,2:]) 18 | return torch.cat((c_rel, sz_rel), dim=-1) 19 | 20 | 21 | def rel_to_rect(bb, sz_norm=None): 22 | """Inverts the effect of rect_to_rel. See above.""" 23 | 24 | sz = torch.exp(bb[...,2:]) 25 | if sz_norm is None: 26 | c = bb[...,:2] * sz 27 | else: 28 | c = bb[...,:2] * sz_norm 29 | tl = c - 0.5 * sz 30 | return torch.cat((tl, sz), dim=-1) 31 | 32 | 33 | def masks_to_bboxes(mask, fmt='c'): 34 | 35 | """ Convert a mask tensor to one or more bounding boxes. 36 | Note: This function is a bit new, make sure it does what it says. /Andreas 37 | :param mask: Tensor of masks, shape = (..., H, W) 38 | :param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height) 39 | 't' => "top left + size" or (x_left, y_top, width, height) 40 | 'v' => "vertices" or (x_left, y_top, x_right, y_bottom) 41 | :return: tensor containing a batch of bounding boxes, shape = (..., 4) 42 | """ 43 | batch_shape = mask.shape[:-2] 44 | mask = mask.reshape((-1, *mask.shape[-2:])) 45 | bboxes = [] 46 | 47 | for m in mask: 48 | mx = m.sum(dim=-2).nonzero() 49 | my = m.sum(dim=-1).nonzero() 50 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0] 51 | bboxes.append(bb) 52 | 53 | bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device) 54 | bboxes = bboxes.reshape(batch_shape + (4,)) 55 | 56 | if fmt == 'v': 57 | return bboxes 58 | 59 | x1 = bboxes[..., :2] 60 | s = bboxes[..., 2:] - x1 + 1 61 | 62 | if fmt == 'c': 63 | return torch.cat((x1 + 0.5 * s, s), dim=-1) 64 | elif fmt == 't': 65 | return torch.cat((x1, s), dim=-1) 66 | 67 | raise ValueError("Undefined bounding box layout '%s'" % fmt) 68 | 69 | 70 | def masks_to_bboxes_multi(mask, ids, fmt='c'): 71 | assert mask.dim() == 2 72 | bboxes = [] 73 | 74 | for id in ids: 75 | mx = (mask == id).sum(dim=-2).nonzero() 76 | my = (mask == id).float().sum(dim=-1).nonzero() 77 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0] 78 | 79 | bb = torch.tensor(bb, dtype=torch.float32, device=mask.device) 80 | 81 | x1 = bb[:2] 82 | s = bb[2:] - x1 + 1 83 | 84 | if fmt == 'v': 85 | pass 86 | elif fmt == 'c': 87 | bb = torch.cat((x1 + 0.5 * s, s), dim=-1) 88 | elif fmt == 't': 89 | bb = torch.cat((x1, s), dim=-1) 90 | else: 91 | raise ValueError("Undefined bounding box layout '%s'" % fmt) 92 | bboxes.append(bb) 93 | 94 | return bboxes 95 | 
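# --- Illustrative usage sketch (added during editing; not part of the original file). ---
# It relies only on rect_to_rel / rel_to_rect defined above and the module-level `import torch`;
# the box values are made up. Converting to the relative parametrization and back should
# recover the original [x, y, w, h] boxes up to floating-point error.
if __name__ == '__main__':
    boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0],
                          [5.0, 5.0, 16.0, 8.0]])  # N x 4, [x, y, w, h]
    rel = rect_to_rel(boxes)     # [cx/w, cy/h, log(w), log(h)]
    restored = rel_to_rect(rel)  # inverse mapping
    assert torch.allclose(restored, boxes, atol=1e-5)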
-------------------------------------------------------------------------------- /lib/test/evaluation/itbdataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 3 | from lib.test.utils.load_text import load_text 4 | import os 5 | 6 | 7 | class ITBDataset(BaseDataset): 8 | """ NUS-PRO dataset 9 | """ 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.base_path = self.env_settings.itb_path 14 | self.sequence_info_list = self._get_sequence_info_list(self.base_path) 15 | 16 | def get_sequence_list(self): 17 | return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list]) 18 | 19 | def _construct_sequence(self, sequence_info): 20 | sequence_path = sequence_info['path'] 21 | nz = sequence_info['nz'] 22 | ext = sequence_info['ext'] 23 | start_frame = sequence_info['startFrame'] 24 | end_frame = sequence_info['endFrame'] 25 | 26 | init_omit = 0 27 | if 'initOmit' in sequence_info: 28 | init_omit = sequence_info['initOmit'] 29 | 30 | frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path, 31 | sequence_path=sequence_path, frame=frame_num, 32 | nz=nz, ext=ext) for frame_num in 33 | range(start_frame + init_omit, end_frame + 1)] 34 | 35 | anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path']) 36 | 37 | # NOTE: NUS has some weird annos which panda cannot handle 38 | ground_truth_rect = load_text(str(anno_path), delimiter=(',', None), dtype=np.float64, backend='numpy') 39 | return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:, :], 40 | object_class=sequence_info['object_class']) 41 | 42 | def __len__(self): 43 | return len(self.sequence_info_list) 44 | 45 | def get_fileNames(self, rootdir): 46 | fs = [] 47 | fs_all = [] 48 | for root, dirs, files in os.walk(rootdir, topdown=True): 49 | files.sort() 50 | files.sort(key=len) 51 | if files is not None: 52 | for name in files: 53 | _, ending = os.path.splitext(name) 54 | if ending == ".jpg": 55 | _, root_ = os.path.split(root) 56 | fs.append(os.path.join(root_, name)) 57 | fs_all.append(os.path.join(root, name)) 58 | 59 | return fs_all, fs 60 | 61 | def _get_sequence_info_list(self, base_path): 62 | sequence_info_list = [] 63 | for scene in os.listdir(base_path): 64 | if '.' in scene: 65 | continue 66 | videos = os.listdir(os.path.join(base_path, scene)) 67 | for video in videos: 68 | _, fs = self.get_fileNames(os.path.join(base_path, scene, video)) 69 | video_tmp = {"name": video, "path": scene + '/' + video, "startFrame": 1, "endFrame": len(fs), 70 | "nz": len(fs[0].split('/')[-1].split('.')[0]), "ext": "jpg", 71 | "anno_path": scene + '/' + video + "/groundtruth.txt", 72 | "object_class": "unknown"} 73 | sequence_info_list.append(video_tmp) 74 | 75 | return sequence_info_list # sequence_info_list_50 # 76 | -------------------------------------------------------------------------------- /mamba-1p1p1/benchmarks/benchmark_generation_mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
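# (Summary comment added during editing.) This script benchmarks autoregressive generation:
# it loads either a Mamba checkpoint (model names starting with "state-spaces/mamba-") or a
# Hugging Face causal LM, generates --genlen tokens from a random or user-supplied prompt,
# and reports the average prompt-processing + decoding latency over three timed runs.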
2 | 3 | import argparse 4 | import time 5 | import json 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from einops import rearrange 11 | 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | 14 | from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Generation benchmarking") 18 | parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") 19 | parser.add_argument("--prompt", type=str, default=None) 20 | parser.add_argument("--promptlen", type=int, default=100) 21 | parser.add_argument("--genlen", type=int, default=100) 22 | parser.add_argument("--temperature", type=float, default=1.0) 23 | parser.add_argument("--topk", type=int, default=1) 24 | parser.add_argument("--topp", type=float, default=1.0) 25 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 26 | parser.add_argument("--batch", type=int, default=1) 27 | args = parser.parse_args() 28 | 29 | repeats = 3 30 | device = "cuda" 31 | dtype = torch.float16 32 | 33 | print(f"Loading model {args.model_name}") 34 | is_mamba = args.model_name.startswith("state-spaces/mamba-") 35 | if is_mamba: 36 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 37 | model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) 38 | else: 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) 41 | model.eval() 42 | print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 43 | 44 | torch.random.manual_seed(0) 45 | if args.prompt is None: 46 | input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") 47 | attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") 48 | else: 49 | tokens = tokenizer(args.prompt, return_tensors="pt") 50 | input_ids = tokens.input_ids.to(device=device) 51 | attn_mask = tokens.attention_mask.to(device=device) 52 | max_length = input_ids.shape[1] + args.genlen 53 | 54 | if is_mamba: 55 | fn = lambda: model.generate( 56 | input_ids=input_ids, 57 | max_length=max_length, 58 | cg=True, 59 | return_dict_in_generate=True, 60 | output_scores=True, 61 | enable_timing=False, 62 | temperature=args.temperature, 63 | top_k=args.topk, 64 | top_p=args.topp, 65 | repetition_penalty=args.repetition_penalty, 66 | ) 67 | else: 68 | fn = lambda: model.generate( 69 | input_ids=input_ids, 70 | attention_mask=attn_mask, 71 | max_length=max_length, 72 | return_dict_in_generate=True, 73 | pad_token_id=tokenizer.eos_token_id, 74 | do_sample=True, 75 | temperature=args.temperature, 76 | top_k=args.topk, 77 | top_p=args.topp, 78 | repetition_penalty=args.repetition_penalty, 79 | ) 80 | out = fn() 81 | if args.prompt is not None: 82 | print(tokenizer.batch_decode(out.sequences.tolist())) 83 | 84 | torch.cuda.synchronize() 85 | start = time.time() 86 | for _ in range(repeats): 87 | fn() 88 | torch.cuda.synchronize() 89 | print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") 90 | print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") 91 | -------------------------------------------------------------------------------- /lib/utils/ce_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 
4 | import torch.nn.functional as F 5 | 6 | 7 | def generate_bbox_mask(bbox_mask, bbox): 8 | b, h, w = bbox_mask.shape 9 | for i in range(b): 10 | bbox_i = bbox[i].cpu().tolist() 11 | bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1 12 | return bbox_mask 13 | 14 | 15 | def generate_mask_cond(cfg, bs, device, gt_bbox): 16 | template_size = cfg.DATA.TEMPLATE.SIZE 17 | stride = cfg.MODEL.BACKBONE.STRIDE 18 | template_feat_size = template_size // stride 19 | 20 | if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL': 21 | box_mask_z = None 22 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': 23 | if template_feat_size == 8: 24 | index = slice(3, 4) 25 | elif template_feat_size == 12: 26 | index = slice(5, 6) 27 | elif template_feat_size == 7: 28 | index = slice(3, 4) 29 | elif template_feat_size == 14: 30 | index = slice(6, 7) 31 | else: 32 | raise NotImplementedError 33 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 34 | box_mask_z[:, index, index] = 1 35 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 36 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC': 37 | # use fixed 4x4 region, 3:5 for 8x8 38 | # use fixed 4x4 region 5:6 for 12x12 39 | if template_feat_size == 8: 40 | index = slice(3, 5) 41 | elif template_feat_size == 12: 42 | index = slice(5, 7) 43 | elif template_feat_size == 7: 44 | index = slice(3, 4) 45 | else: 46 | raise NotImplementedError 47 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 48 | box_mask_z[:, index, index] = 1 49 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 50 | 51 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX': 52 | box_mask_z = torch.zeros([bs, template_size, template_size], device=device) 53 | # box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128) 54 | box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to( 55 | torch.float) # (batch, 1, 128, 128) 56 | # box_mask_z_vis = box_mask_z.cpu().numpy() 57 | box_mask_z = F.interpolate(box_mask_z, scale_factor=1. 
/ cfg.MODEL.BACKBONE.STRIDE, mode='bilinear', 58 | align_corners=False) 59 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 60 | # box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy() 61 | # gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy() 62 | else: 63 | raise NotImplementedError 64 | 65 | return box_mask_z 66 | 67 | 68 | def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1): 69 | if epoch < warmup_epochs: 70 | return 1 71 | if epoch >= total_epochs: 72 | return base_keep_rate 73 | if iters == -1: 74 | iters = epoch * ITERS_PER_EPOCH 75 | total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs) 76 | iters = iters - ITERS_PER_EPOCH * warmup_epochs 77 | keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \ 78 | * (math.cos(iters / total_iters * math.pi) + 1) * 0.5 79 | 80 | return keep_rate 81 | -------------------------------------------------------------------------------- /lib/train/dataset/base_video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | # 2021.1.5 use jpeg4py_loader_w_failsafe as default 3 | from lib.train.data.image_loader import jpeg4py_loader_w_failsafe 4 | 5 | 6 | class BaseVideoDataset(torch.utils.data.Dataset): 7 | """ Base class for video datasets """ 8 | 9 | def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe): 10 | """ 11 | args: 12 | root - The root path to the dataset 13 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 14 | is used by default. 15 | """ 16 | self.name = name 17 | self.root = root 18 | self.image_loader = image_loader 19 | 20 | self.sequence_list = [] # Contains the list of sequences. 21 | self.class_list = [] 22 | 23 | def __len__(self): 24 | """ Returns size of the dataset 25 | returns: 26 | int - number of samples in the dataset 27 | """ 28 | return self.get_num_sequences() 29 | 30 | def __getitem__(self, index): 31 | """ Not to be used! Check get_frames() instead. 
32 | """ 33 | return None 34 | 35 | def is_video_sequence(self): 36 | """ Returns whether the dataset is a video dataset or an image dataset 37 | 38 | returns: 39 | bool - True if a video dataset 40 | """ 41 | return True 42 | 43 | def is_synthetic_video_dataset(self): 44 | """ Returns whether the dataset contains real videos or synthetic 45 | 46 | returns: 47 | bool - True if a video dataset 48 | """ 49 | return False 50 | 51 | def get_name(self): 52 | """ Name of the dataset 53 | 54 | returns: 55 | string - Name of the dataset 56 | """ 57 | raise NotImplementedError 58 | 59 | def get_num_sequences(self): 60 | """ Number of sequences in a dataset 61 | 62 | returns: 63 | int - number of sequences in the dataset.""" 64 | return len(self.sequence_list) 65 | 66 | def has_class_info(self): 67 | return False 68 | 69 | def has_occlusion_info(self): 70 | return False 71 | 72 | def get_num_classes(self): 73 | return len(self.class_list) 74 | 75 | def get_class_list(self): 76 | return self.class_list 77 | 78 | def get_sequences_in_class(self, class_name): 79 | raise NotImplementedError 80 | 81 | def has_segmentation_info(self): 82 | return False 83 | 84 | def get_sequence_info(self, seq_id): 85 | """ Returns information about a particular sequences, 86 | 87 | args: 88 | seq_id - index of the sequence 89 | 90 | returns: 91 | Dict 92 | """ 93 | raise NotImplementedError 94 | 95 | def get_frames(self, seq_id, frame_ids, anno=None): 96 | """ Get a set of frames from a particular sequence 97 | 98 | args: 99 | seq_id - index of sequence 100 | frame_ids - a list of frame numbers 101 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 102 | 103 | returns: 104 | list - List of frames corresponding to frame_ids 105 | list - List of dicts for each frame 106 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 107 | 108 | """ 109 | raise NotImplementedError 110 | 111 | -------------------------------------------------------------------------------- /lib/train/data/image_loader.py: -------------------------------------------------------------------------------- 1 | import jpeg4py 2 | import cv2 as cv 3 | from PIL import Image 4 | import numpy as np 5 | 6 | davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8) 7 | davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 8 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 9 | [64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0], 10 | [64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128], 11 | [0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0], 12 | [0, 64, 128], [128, 64, 128]] 13 | 14 | 15 | def default_image_loader(path): 16 | """The default image loader, reads the image from the given path. 
It first tries to use the jpeg4py_loader, 17 | but reverts to the opencv_loader if the former is not available.""" 18 | if default_image_loader.use_jpeg4py is None: 19 | # Try using jpeg4py 20 | im = jpeg4py_loader(path) 21 | if im is None: 22 | default_image_loader.use_jpeg4py = False 23 | print('Using opencv_loader instead.') 24 | else: 25 | default_image_loader.use_jpeg4py = True 26 | return im 27 | if default_image_loader.use_jpeg4py: 28 | return jpeg4py_loader(path) 29 | return opencv_loader(path) 30 | 31 | default_image_loader.use_jpeg4py = None 32 | 33 | 34 | def jpeg4py_loader(path): 35 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py""" 36 | try: 37 | return jpeg4py.JPEG(path).decode() 38 | except Exception as e: 39 | print('ERROR: Could not read image "{}"'.format(path)) 40 | print(e) 41 | return None 42 | 43 | 44 | def opencv_loader(path): 45 | """ Read image using opencv's imread function and returns it in rgb format""" 46 | try: 47 | im = cv.imread(path, cv.IMREAD_COLOR) 48 | 49 | # convert to rgb and return 50 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 51 | except Exception as e: 52 | print('ERROR: Could not read image "{}"'.format(path)) 53 | print(e) 54 | return None 55 | 56 | 57 | def jpeg4py_loader_w_failsafe(path): 58 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py""" 59 | try: 60 | return jpeg4py.JPEG(path).decode() 61 | except: 62 | try: 63 | im = cv.imread(path, cv.IMREAD_COLOR) 64 | 65 | # convert to rgb and return 66 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 67 | except Exception as e: 68 | print('ERROR: Could not read image "{}"'.format(path)) 69 | print(e) 70 | return None 71 | 72 | 73 | def opencv_seg_loader(path): 74 | """ Read segmentation annotation using opencv's imread function""" 75 | try: 76 | return cv.imread(path) 77 | except Exception as e: 78 | print('ERROR: Could not read image "{}"'.format(path)) 79 | print(e) 80 | return None 81 | 82 | 83 | def imread_indexed(filename): 84 | """ Load indexed image with given filename. Used to read segmentation annotations.""" 85 | 86 | im = Image.open(filename) 87 | 88 | annotation = np.atleast_3d(im)[...,0] 89 | return annotation 90 | 91 | 92 | def imwrite_indexed(filename, array, color_palette=None): 93 | """ Save indexed image as png. 
Used to save segmentation annotation.""" 94 | 95 | if color_palette is None: 96 | color_palette = davis_palette 97 | 98 | if np.atleast_3d(array).shape[2] != 1: 99 | raise Exception("Saving indexed PNGs requires 2D array.") 100 | 101 | im = Image.fromarray(array) 102 | im.putpalette(color_palette.ravel()) 103 | im.save(filename, format='PNG') -------------------------------------------------------------------------------- /lib/train/dataset/coesot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import torch 5 | import csv 6 | import pandas 7 | import random 8 | from collections import OrderedDict 9 | from .base_video_dataset import BaseVideoDataset 10 | from lib.train.data import jpeg4py_loader_w_failsafe 11 | from lib.train.admin import env_settings 12 | from lib.train.dataset.depth_utils import get_x_frame 13 | 14 | 15 | class COESOT(BaseVideoDataset): 16 | def __init__(self, root=None, split='train', dtype='rgbrgb', image_loader=jpeg4py_loader_w_failsafe,): 17 | 18 | root = env_settings().coesot_train_dir if root is None else root 19 | super().__init__('COESOT', root, image_loader) 20 | 21 | self.dtype = dtype # colormap or depth 22 | self.split = split 23 | self.sequence_list = self._get_sequence_list() 24 | 25 | def _get_sequence_list(self): 26 | dir_list = [i for i in os.listdir(os.path.join(self.root)) 27 | if os.path.isdir(os.path.join(self.root, i))] 28 | return dir_list 29 | 30 | def get_name(self): 31 | return 'coesot_' + self.split 32 | 33 | def get_num_sequences(self): 34 | return len(self.sequence_list) 35 | 36 | def _read_bb_anno(self, seq_path): 37 | bb_anno_file = os.path.join(seq_path, "groundtruth.txt") 38 | gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values 39 | return torch.tensor(gt) 40 | 41 | def _get_sequence_path(self, seq_id): 42 | return os.path.join(self.root, self.sequence_list[seq_id]) 43 | 44 | def get_sequence_info(self, seq_id): 45 | bbox_path = self._get_sequence_path(seq_id) 46 | bbox = self._read_bb_anno(bbox_path) 47 | 48 | # valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 49 | valid = (bbox[:, 2] > 4.0) & (bbox[:, 3] > 4.0) 50 | visible = valid.clone().byte().bool() 51 | return {'bbox': bbox, 'valid': valid, 'visible': visible, } 52 | 53 | def _get_frame_path(self, seq_path, frame_id): 54 | seq_name = seq_path.split('/')[-1] 55 | aps_dir = os.path.join(seq_path, seq_name + '_aps') 56 | dvs_dir = os.path.join(seq_path, seq_name + '_dvs') 57 | if os.path.exists(os.path.join(aps_dir, 'frame{:04}.png'.format(frame_id))): 58 | vis_path = os.path.join(aps_dir, 'frame{:04}.png'.format(frame_id)) 59 | else: 60 | vis_path = os.path.join(aps_dir, 'frame{:04}.bmp'.format(frame_id)) 61 | 62 | if os.path.exists(os.path.join(dvs_dir, 'frame{:04}.bmp'.format(frame_id))): 63 | event_path = os.path.join(dvs_dir, 'frame{:04}.bmp'.format(frame_id)) 64 | else: 65 | event_path = os.path.join(dvs_dir, 'frame{:04}.png'.format(frame_id)) 66 | 67 | return vis_path, event_path 68 | 69 | def _get_frame(self, seq_path, frame_id): 70 | color_path, event_path = self._get_frame_path(seq_path, frame_id) 71 | img = get_x_frame(color_path, event_path, dtype=self.dtype, depth_clip=False) 72 | return img # (h,w,6) 73 | 74 | 75 | def get_frames(self, seq_id, frame_ids, anno=None): 76 | seq_path = self._get_sequence_path(seq_id) 77 | 78 | if anno is None: 79 | anno = self.get_sequence_info(seq_id) 80 | 81 | anno_frames = {} 82 
| for key, value in anno.items(): 83 | anno_frames[key] = [value[f_id, ...].clone() for ii, f_id in enumerate(frame_ids)] 84 | 85 | frame_list = [self._get_frame(seq_path, f_id) for ii, f_id in enumerate(frame_ids)] 86 | 87 | object_meta = OrderedDict({'object_class_name': None, 88 | 'motion_class': None, 89 | 'major_class': None, 90 | 'root_class': None, 91 | 'motion_adverb': None, 92 | }) 93 | 94 | return frame_list, anno_frames, object_meta 95 | 96 | -------------------------------------------------------------------------------- /lib/test/tracker/basetracker.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from _collections import OrderedDict 5 | 6 | from lib.train.data.processing_utils import transform_image_to_crop 7 | # from lib.vis.visdom_cus import Visdom 8 | 9 | 10 | class BaseTracker: 11 | """Base class for all trackers.""" 12 | 13 | def __init__(self, params): 14 | self.params = params 15 | self.visdom = None 16 | 17 | def predicts_segmentation_mask(self): 18 | return False 19 | 20 | def initialize(self, image, info: dict) -> dict: 21 | """Overload this function in your tracker. This should initialize the model.""" 22 | raise NotImplementedError 23 | 24 | def track(self, image, info: dict = None) -> dict: 25 | """Overload this function in your tracker. This should track in the frame and update the model.""" 26 | raise NotImplementedError 27 | 28 | def visdom_draw_tracking(self, image, box, segmentation=None): 29 | if isinstance(box, OrderedDict): 30 | box = [v for k, v in box.items()] 31 | else: 32 | box = (box,) 33 | if segmentation is None: 34 | self.visdom.register((image, *box), 'Tracking', 1, 'Tracking') 35 | else: 36 | self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking') 37 | 38 | def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'): 39 | # box_in: list [x1, y1, w, h], not normalized 40 | # box_extract: same as box_in 41 | # out bbox: Torch.tensor [1, 1, 4], x1y1wh, normalized 42 | if crop_type == 'template': 43 | crop_sz = torch.Tensor([self.params.template_size, self.params.template_size]) 44 | elif crop_type == 'search': 45 | crop_sz = torch.Tensor([self.params.search_size, self.params.search_size]) 46 | else: 47 | raise NotImplementedError 48 | 49 | box_in = torch.tensor(box_in) 50 | if box_extract is None: 51 | box_extract = box_in 52 | else: 53 | box_extract = torch.tensor(box_extract) 54 | template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True) 55 | template_bbox = template_bbox.view(1, 1, 4).to(device) 56 | 57 | return template_bbox 58 | 59 | def _init_visdom(self, visdom_info, debug): 60 | visdom_info = {} if visdom_info is None else visdom_info 61 | self.pause_mode = False 62 | self.step = False 63 | self.next_seq = False 64 | if debug > 0 and visdom_info.get('use_visdom', True): 65 | try: 66 | self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'}, 67 | visdom_info=visdom_info) 68 | 69 | # # Show help 70 | # help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \ 71 | # 'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \ 72 | # 'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \ 73 | # 'block list.' 74 | # self.visdom.register(help_text, 'text', 1, 'Help') 75 | except: 76 | time.sleep(0.5) 77 | print('!!! 
WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n' 78 | '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!') 79 | 80 | def _visdom_ui_handler(self, data): 81 | if data['event_type'] == 'KeyPress': 82 | if data['key'] == ' ': 83 | self.pause_mode = not self.pause_mode 84 | 85 | elif data['key'] == 'ArrowRight' and self.pause_mode: 86 | self.step = True 87 | 88 | elif data['key'] == 'n': 89 | self.next_seq = True 90 | -------------------------------------------------------------------------------- /lib/train/admin/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from collections import OrderedDict 4 | 5 | 6 | def create_default_local_file(): 7 | path = os.path.join(os.path.dirname(__file__), 'local.py') 8 | 9 | empty_str = '\'\'' 10 | default_settings = OrderedDict({ 11 | 'workspace_dir': empty_str, 12 | 'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'', 13 | 'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'', 14 | 'lasot_dir': empty_str, 15 | 'got10k_dir': empty_str, 16 | 'trackingnet_dir': empty_str, 17 | 'coco_dir': empty_str, 18 | 'lvis_dir': empty_str, 19 | 'sbd_dir': empty_str, 20 | 'imagenet_dir': empty_str, 21 | 'imagenetdet_dir': empty_str, 22 | 'ecssd_dir': empty_str, 23 | 'hkuis_dir': empty_str, 24 | 'msra10k_dir': empty_str, 25 | 'davis_dir': empty_str, 26 | 'youtubevos_dir': empty_str}) 27 | 28 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 29 | 'tensorboard_dir': 'Directory for tensorboard files.'} 30 | 31 | with open(path, 'w') as f: 32 | f.write('class EnvironmentSettings:\n') 33 | f.write(' def __init__(self):\n') 34 | 35 | for attr, attr_val in default_settings.items(): 36 | comment_str = None 37 | if attr in comment: 38 | comment_str = comment[attr] 39 | if comment_str is None: 40 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 41 | else: 42 | f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str)) 43 | 44 | 45 | def create_default_local_file_ITP_train(workspace_dir, data_dir): 46 | path = os.path.join(os.path.dirname(__file__), 'local.py') 47 | 48 | empty_str = '\'\'' 49 | default_settings = OrderedDict({ 50 | 'workspace_dir': workspace_dir, 51 | 'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files. 
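        # (Comment added during editing.) Each entry in this dict is written into
        # lib/train/admin/local.py as an attribute assignment; the dataset entries below are
        # joined onto the caller-supplied data_dir, so adjust the sub-paths if the local
        # dataset layout differs.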
52 | 'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'), 53 | 'lasot_dir': os.path.join(data_dir, 'lasot'), 54 | 'got10k_dir': os.path.join(data_dir, 'got10k/train'), 55 | 'got10k_val_dir': os.path.join(data_dir, 'got10k/val'), 56 | 'coco_dir': os.path.join(data_dir, 'coco'), 57 | 'imagenet_dir': os.path.join(data_dir, 'vid'), 58 | 'depthtrack_dir': os.path.join(data_dir, 'depthtrack/train'), 59 | 'lasher_dir': os.path.join(data_dir, 'lasher/trainingset'), 60 | 'visevent_dir': os.path.join(data_dir, 'visevent/train'), 61 | }) 62 | 63 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 64 | 'tensorboard_dir': 'Directory for tensorboard files.'} 65 | 66 | with open(path, 'w') as f: 67 | f.write('class EnvironmentSettings:\n') 68 | f.write(' def __init__(self):\n') 69 | 70 | for attr, attr_val in default_settings.items(): 71 | comment_str = None 72 | if attr in comment: 73 | comment_str = comment[attr] 74 | if comment_str is None: 75 | if attr_val == empty_str: 76 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 77 | else: 78 | f.write(' self.{} = \'{}\'\n'.format(attr, attr_val)) 79 | else: 80 | f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 81 | 82 | 83 | def env_settings(): 84 | env_module_name = 'lib.train.admin.local' 85 | try: 86 | env_module = importlib.import_module(env_module_name) 87 | return env_module.EnvironmentSettings() 88 | except: 89 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 90 | 91 | create_default_local_file() 92 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file)) 93 | -------------------------------------------------------------------------------- /lib/train/dataset/imagenetvid_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_video_dataset import BaseVideoDataset 3 | from lib.train.data import jpeg4py_loader 4 | import torch 5 | from collections import OrderedDict 6 | from lib.train.admin import env_settings 7 | from lib.utils.lmdb_utils import decode_img, decode_json 8 | 9 | 10 | def get_target_to_image_ratio(seq): 11 | anno = torch.Tensor(seq['anno']) 12 | img_sz = torch.Tensor(seq['image_size']) 13 | return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt() 14 | 15 | 16 | class ImagenetVID_lmdb(BaseVideoDataset): 17 | """ Imagenet VID dataset. 18 | 19 | Publication: 20 | ImageNet Large Scale Visual Recognition Challenge 21 | Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, 22 | Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei 23 | IJCV, 2015 24 | https://arxiv.org/pdf/1409.0575.pdf 25 | 26 | Download the dataset from http://image-net.org/ 27 | """ 28 | def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1): 29 | """ 30 | args: 31 | root - path to the imagenet vid dataset. 32 | image_loader (default_image_loader) - The function to read the images. If installed, 33 | jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else, 34 | opencv's imread is used. 35 | min_length - Minimum allowed sequence length. 36 | max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets 37 | which cover complete image. 
38 | """ 39 | root = env_settings().imagenet_dir if root is None else root 40 | super().__init__("imagenetvid_lmdb", root, image_loader) 41 | 42 | sequence_list_dict = decode_json(root, "cache.json") 43 | self.sequence_list = sequence_list_dict 44 | 45 | # Filter the sequences based on min_length and max_target_area in the first frame 46 | self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and 47 | get_target_to_image_ratio(x) < max_target_area] 48 | 49 | def get_name(self): 50 | return 'imagenetvid_lmdb' 51 | 52 | def get_num_sequences(self): 53 | return len(self.sequence_list) 54 | 55 | def get_sequence_info(self, seq_id): 56 | bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno']) 57 | valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0) 58 | visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte() 59 | return {'bbox': bb_anno, 'valid': valid, 'visible': visible} 60 | 61 | def _get_frame(self, sequence, frame_id): 62 | set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id']) 63 | vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id']) 64 | frame_number = frame_id + sequence['start_frame'] 65 | frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name, 66 | '{:06d}.JPEG'.format(frame_number)) 67 | return decode_img(self.root, frame_path) 68 | 69 | def get_frames(self, seq_id, frame_ids, anno=None): 70 | sequence = self.sequence_list[seq_id] 71 | 72 | frame_list = [self._get_frame(sequence, f) for f in frame_ids] 73 | 74 | if anno is None: 75 | anno = self.get_sequence_info(seq_id) 76 | 77 | # Create anno dict 78 | anno_frames = {} 79 | for key, value in anno.items(): 80 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 81 | 82 | # added the class info to the meta info 83 | object_meta = OrderedDict({'object_class': sequence['class_name'], 84 | 'motion_class': None, 85 | 'major_class': None, 86 | 'root_class': None, 87 | 'motion_adverb': None}) 88 | 89 | return frame_list, anno_frames, object_meta 90 | 91 | -------------------------------------------------------------------------------- /tracking/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | 5 | 6 | def parse_args(): 7 | """ 8 | args for training. 
9 | """ 10 | parser = argparse.ArgumentParser(description='Parse args for training') 11 | # for train 12 | parser.add_argument('--script', type=str, help='training script name') 13 | parser.add_argument('--config', type=str, default='baseline', help='yaml configure file name') 14 | parser.add_argument('--save_dir', type=str, help='root directory to save checkpoints, logs, and tensorboard') 15 | parser.add_argument('--mode', type=str, choices=["single", "multiple", "multi_node"], default="multiple", 16 | help="train on single gpu or multiple gpus") 17 | parser.add_argument('--nproc_per_node', type=int, help="number of GPUs per node") # specify when mode is multiple 18 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format 19 | parser.add_argument('--script_prv', type=str, help='training script name') 20 | parser.add_argument('--config_prv', type=str, default='baseline', help='yaml configure file name') 21 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb 22 | parser.add_argument('--seed', type=int, default=0, help='seed for random numbers') 23 | # for knowledge distillation 24 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation 25 | parser.add_argument('--script_teacher', type=str, help='teacher script name') 26 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name') 27 | 28 | # for multiple machines 29 | parser.add_argument('--rank', type=int, help='Rank of the current process.') 30 | parser.add_argument('--world-size', type=int, help='Number of processes participating in the job.') 31 | parser.add_argument('--ip', type=str, default='127.0.0.1', help='IP of the current rank 0.') 32 | parser.add_argument('--port', type=int, default='20000', help='Port of the current rank 0.') 33 | 34 | args = parser.parse_args() 35 | 36 | return args 37 | 38 | 39 | def main(): 40 | args = parse_args() 41 | if args.mode == "single": 42 | train_cmd = "python lib/train/run_training.py --script %s --config %s --save_dir %s --use_lmdb %d --seed %d " \ 43 | "--script_prv %s --config_prv %s --distill %d --script_teacher %s --config_teacher %s --use_wandb %d"\ 44 | % (args.script, args.config, args.save_dir, args.use_lmdb, args.seed, args.script_prv, args.config_prv, 45 | args.distill, args.script_teacher, args.config_teacher, args.use_wandb) 46 | elif args.mode == "multiple": 47 | # python -m torch.distributed.launch --use_env 48 | train_cmd = "torchrun --nproc_per_node %d --master_port %d lib/train/run_training.py " \ 49 | "--script %s --config %s --save_dir %s --use_lmdb %d --seed %d --script_prv %s --config_prv %s --use_wandb %d " \ 50 | "--distill %d --script_teacher %s --config_teacher %s" \ 51 | % (args.nproc_per_node, random.randint(10000, 50000), args.script, args.config, args.save_dir, args.use_lmdb, args.seed, args.script_prv, args.config_prv, args.use_wandb, 52 | args.distill, args.script_teacher, args.config_teacher) 53 | elif args.mode == "multi_node": 54 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_addr %s --master_port %d --nnodes %d --node_rank %d lib/train/run_training.py " \ 55 | "--script %s --config %s --save_dir %s --use_lmdb %d --seed %d --script_prv %s --config_prv %s --use_wandb %d " \ 56 | "--distill %d --script_teacher %s --config_teacher %s" \ 57 | % (args.nproc_per_node, args.ip, args.port, args.world_size, args.rank, args.script, args.config, 
args.save_dir, args.use_lmdb, args.seed, args.script_prv, args.config_prv, args.use_wandb, 58 | args.distill, args.script_teacher, args.config_teacher) 59 | else: 60 | raise ValueError("mode should be 'single' or 'multiple'.") 61 | os.system(train_cmd) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /lib/test/utils/hann.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | 6 | def hann1d(sz: int, centered = True) -> torch.Tensor: 7 | """1D cosine window.""" 8 | if centered: 9 | return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float())) 10 | w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float())) 11 | return torch.cat([w, w[1:sz-sz//2].flip((0,))]) 12 | 13 | 14 | def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor: 15 | """2D cosine window.""" 16 | return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1) 17 | 18 | 19 | def hann2d_bias(sz: torch.Tensor, ctr_point: torch.Tensor, centered = True) -> torch.Tensor: 20 | """2D cosine window.""" 21 | distance = torch.stack([ctr_point, sz-ctr_point], dim=0) 22 | max_distance, _ = distance.max(dim=0) 23 | 24 | hann1d_x = hann1d(max_distance[0].item() * 2, centered) 25 | hann1d_x = hann1d_x[max_distance[0] - distance[0, 0]: max_distance[0] + distance[1, 0]] 26 | hann1d_y = hann1d(max_distance[1].item() * 2, centered) 27 | hann1d_y = hann1d_y[max_distance[1] - distance[0, 1]: max_distance[1] + distance[1, 1]] 28 | 29 | return hann1d_y.reshape(1, 1, -1, 1) * hann1d_x.reshape(1, 1, 1, -1) 30 | 31 | 32 | 33 | def hann2d_clipped(sz: torch.Tensor, effective_sz: torch.Tensor, centered = True) -> torch.Tensor: 34 | """1D clipped cosine window.""" 35 | 36 | # Ensure that the difference is even 37 | effective_sz += (effective_sz - sz) % 2 38 | effective_window = hann1d(effective_sz[0].item(), True).reshape(1, 1, -1, 1) * hann1d(effective_sz[1].item(), True).reshape(1, 1, 1, -1) 39 | 40 | pad = (sz - effective_sz) // 2 41 | 42 | window = F.pad(effective_window, (pad[1].item(), pad[1].item(), pad[0].item(), pad[0].item()), 'replicate') 43 | 44 | if centered: 45 | return window 46 | else: 47 | mid = (sz / 2).int() 48 | window_shift_lr = torch.cat((window[:, :, :, mid[1]:], window[:, :, :, :mid[1]]), 3) 49 | return torch.cat((window_shift_lr[:, :, mid[0]:, :], window_shift_lr[:, :, :mid[0], :]), 2) 50 | 51 | 52 | def gauss_fourier(sz: int, sigma: float, half: bool = False) -> torch.Tensor: 53 | if half: 54 | k = torch.arange(0, int(sz/2+1)) 55 | else: 56 | k = torch.arange(-int((sz-1)/2), int(sz/2+1)) 57 | return (math.sqrt(2*math.pi) * sigma / sz) * torch.exp(-2 * (math.pi * sigma * k.float() / sz)**2) 58 | 59 | 60 | def gauss_spatial(sz, sigma, center=0, end_pad=0): 61 | k = torch.arange(-(sz-1)/2, (sz+1)/2+end_pad) 62 | return torch.exp(-1.0/(2*sigma**2) * (k - center)**2) 63 | 64 | 65 | def label_function(sz: torch.Tensor, sigma: torch.Tensor): 66 | return gauss_fourier(sz[0].item(), sigma[0].item()).reshape(1, 1, -1, 1) * gauss_fourier(sz[1].item(), sigma[1].item(), True).reshape(1, 1, 1, -1) 67 | 68 | def label_function_spatial(sz: torch.Tensor, sigma: torch.Tensor, center: torch.Tensor = torch.zeros(2), end_pad: torch.Tensor = torch.zeros(2)): 69 | """The origin is in the middle of the image.""" 70 | return gauss_spatial(sz[0].item(), 
sigma[0].item(), center[0], end_pad[0].item()).reshape(1, 1, -1, 1) * \ 71 | gauss_spatial(sz[1].item(), sigma[1].item(), center[1], end_pad[1].item()).reshape(1, 1, 1, -1) 72 | 73 | 74 | def cubic_spline_fourier(f, a): 75 | """The continuous Fourier transform of a cubic spline kernel.""" 76 | 77 | bf = (6*(1 - torch.cos(2 * math.pi * f)) + 3*a*(1 - torch.cos(4 * math.pi * f)) 78 | - (6 + 8*a)*math.pi*f*torch.sin(2 * math.pi * f) - 2*a*math.pi*f*torch.sin(4 * math.pi * f)) \ 79 | / (4 * math.pi**4 * f**4) 80 | 81 | bf[f == 0] = 1 82 | 83 | return bf 84 | 85 | def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor): 86 | """Computes maximum and argmax in the last two dimensions.""" 87 | 88 | max_val_row, argmax_row = torch.max(a, dim=-2) 89 | max_val, argmax_col = torch.max(max_val_row, dim=-1) 90 | argmax_row = argmax_row.view(argmax_col.numel(),-1)[torch.arange(argmax_col.numel()), argmax_col.view(-1)] 91 | argmax_row = argmax_row.reshape(argmax_col.shape) 92 | argmax = torch.cat((argmax_row.unsqueeze(-1), argmax_col.unsqueeze(-1)), -1) 93 | return max_val, argmax 94 | -------------------------------------------------------------------------------- /lib/models/layers/rpe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | 5 | 6 | def generate_2d_relative_positional_encoding_index(z_shape, x_shape): 7 | ''' 8 | z_shape: (z_h, z_w) 9 | x_shape: (x_h, x_w) 10 | ''' 11 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 12 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 13 | 14 | z_2d_index_h = z_2d_index_h.flatten(0) 15 | z_2d_index_w = z_2d_index_w.flatten(0) 16 | x_2d_index_h = x_2d_index_h.flatten(0) 17 | x_2d_index_w = x_2d_index_w.flatten(0) 18 | 19 | diff_h = z_2d_index_h[:, None] - x_2d_index_h[None, :] 20 | diff_w = z_2d_index_w[:, None] - x_2d_index_w[None, :] 21 | 22 | diff = torch.stack((diff_h, diff_w), dim=-1) 23 | _, indices = torch.unique(diff.view(-1, 2), return_inverse=True, dim=0) 24 | return indices.view(z_shape[0] * z_shape[1], x_shape[0] * x_shape[1]) 25 | 26 | 27 | def generate_2d_concatenated_self_attention_relative_positional_encoding_index(z_shape, x_shape): 28 | ''' 29 | z_shape: (z_h, z_w) 30 | x_shape: (x_h, x_w) 31 | ''' 32 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 33 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 34 | 35 | z_2d_index_h = z_2d_index_h.flatten(0) 36 | z_2d_index_w = z_2d_index_w.flatten(0) 37 | x_2d_index_h = x_2d_index_h.flatten(0) 38 | x_2d_index_w = x_2d_index_w.flatten(0) 39 | 40 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 41 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 42 | 43 | diff_h = concatenated_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 44 | diff_w = concatenated_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 45 | 46 | z_len = z_shape[0] * z_shape[1] 47 | x_len = x_shape[0] * x_shape[1] 48 | a = torch.empty((z_len + x_len), dtype=torch.int64) 49 | a[:z_len] = 0 50 | a[z_len:] = 1 51 | b=a[:, None].repeat(1, z_len + x_len) 52 | c=a[None, :].repeat(z_len + x_len, 1) 53 | 54 | diff = torch.stack((diff_h, diff_w, b, c), dim=-1) 55 | _, indices = torch.unique(diff.view((z_len + x_len) * (z_len + x_len), 4), return_inverse=True, dim=0) 56 | return 
indices.view((z_len + x_len), (z_len + x_len)) 57 | 58 | 59 | def generate_2d_concatenated_cross_attention_relative_positional_encoding_index(z_shape, x_shape): 60 | ''' 61 | z_shape: (z_h, z_w) 62 | x_shape: (x_h, x_w) 63 | ''' 64 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 65 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 66 | 67 | z_2d_index_h = z_2d_index_h.flatten(0) 68 | z_2d_index_w = z_2d_index_w.flatten(0) 69 | x_2d_index_h = x_2d_index_h.flatten(0) 70 | x_2d_index_w = x_2d_index_w.flatten(0) 71 | 72 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 73 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 74 | 75 | diff_h = x_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 76 | diff_w = x_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 77 | 78 | z_len = z_shape[0] * z_shape[1] 79 | x_len = x_shape[0] * x_shape[1] 80 | 81 | a = torch.empty(z_len + x_len, dtype=torch.int64) 82 | a[: z_len] = 0 83 | a[z_len:] = 1 84 | c = a[None, :].repeat(x_len, 1) 85 | 86 | diff = torch.stack((diff_h, diff_w, c), dim=-1) 87 | _, indices = torch.unique(diff.view(x_len * (z_len + x_len), 3), return_inverse=True, dim=0) 88 | return indices.view(x_len, (z_len + x_len)) 89 | 90 | 91 | class RelativePosition2DEncoder(nn.Module): 92 | def __init__(self, num_heads, embed_size): 93 | super(RelativePosition2DEncoder, self).__init__() 94 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, embed_size))) 95 | trunc_normal_(self.relative_position_bias_table, std=0.02) 96 | 97 | def forward(self, attn_rpe_index): 98 | ''' 99 | Args: 100 | attn_rpe_index (torch.Tensor): (*), any shape containing indices, max(attn_rpe_index) < embed_size 101 | Returns: 102 | torch.Tensor: (1, num_heads, *) 103 | ''' 104 | return self.relative_position_bias_table[:, attn_rpe_index].unsqueeze(0) 105 | -------------------------------------------------------------------------------- /lib/models/mambatrack/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def combine_tokens(template_tokens, search_tokens, mode='direct', return_res=False): 8 | # [B, HW, C] 9 | len_t = template_tokens.shape[1] 10 | len_s = search_tokens.shape[1] 11 | 12 | if mode == 'direct': 13 | merged_feature = torch.cat((template_tokens, search_tokens), dim=1) 14 | elif mode == 'template_central': 15 | central_pivot = len_s // 2 16 | first_half = search_tokens[:, :central_pivot, :] 17 | second_half = search_tokens[:, central_pivot:, :] 18 | merged_feature = torch.cat((first_half, template_tokens, second_half), dim=1) 19 | elif mode == 'partition': 20 | feat_size_s = int(math.sqrt(len_s)) 21 | feat_size_t = int(math.sqrt(len_t)) 22 | window_size = math.ceil(feat_size_t / 2.) 
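        # (Comment added during editing.) The steps below pad the template feature map so its
        # height is a multiple of `window_size`, split it into two horizontal bands of
        # `window_size` rows, place the two bands side by side along the width axis, and
        # finally flatten the result and prepend it to the search tokens, so the merged
        # sequence keeps a single contiguous template block.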
23 | # pad feature maps to multiples of window size 24 | B, _, C = template_tokens.shape 25 | H = W = feat_size_t 26 | template_tokens = template_tokens.view(B, H, W, C) 27 | pad_l = pad_b = pad_r = 0 28 | # pad_r = (window_size - W % window_size) % window_size 29 | pad_t = (window_size - H % window_size) % window_size 30 | template_tokens = F.pad(template_tokens, (0, 0, pad_l, pad_r, pad_t, pad_b)) 31 | _, Hp, Wp, _ = template_tokens.shape 32 | template_tokens = template_tokens.view(B, Hp // window_size, window_size, W, C) 33 | template_tokens = torch.cat([template_tokens[:, 0, ...], template_tokens[:, 1, ...]], dim=2) 34 | _, Hc, Wc, _ = template_tokens.shape 35 | template_tokens = template_tokens.view(B, -1, C) 36 | merged_feature = torch.cat([template_tokens, search_tokens], dim=1) 37 | 38 | # calculate new h and w, which may be useful for SwinT or others 39 | merged_h, merged_w = feat_size_s + Hc, feat_size_s 40 | if return_res: 41 | return merged_feature, merged_h, merged_w 42 | 43 | else: 44 | raise NotImplementedError 45 | 46 | return merged_feature 47 | 48 | 49 | def recover_tokens(merged_tokens, len_template_token, len_search_token, mode='direct'): 50 | if mode == 'direct': 51 | recovered_tokens = merged_tokens 52 | elif mode == 'template_central': 53 | central_pivot = len_search_token // 2 54 | len_remain = len_search_token - central_pivot 55 | len_half_and_t = central_pivot + len_template_token 56 | 57 | first_half = merged_tokens[:, :central_pivot, :] 58 | second_half = merged_tokens[:, -len_remain:, :] 59 | template_tokens = merged_tokens[:, central_pivot:len_half_and_t, :] 60 | 61 | recovered_tokens = torch.cat((template_tokens, first_half, second_half), dim=1) 62 | elif mode == 'partition': 63 | recovered_tokens = merged_tokens 64 | else: 65 | raise NotImplementedError 66 | 67 | return recovered_tokens 68 | 69 | 70 | def window_partition(x, window_size: int): 71 | """ 72 | Args: 73 | x: (B, H, W, C) 74 | window_size (int): window size 75 | 76 | Returns: 77 | windows: (num_windows*B, window_size, window_size, C) 78 | """ 79 | B, H, W, C = x.shape 80 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 81 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 82 | return windows 83 | 84 | 85 | def window_reverse(windows, window_size: int, H: int, W: int): 86 | """ 87 | Args: 88 | windows: (num_windows*B, window_size, window_size, C) 89 | window_size (int): Window size 90 | H (int): Height of image 91 | W (int): Width of image 92 | 93 | Returns: 94 | x: (B, H, W, C) 95 | """ 96 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 97 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 98 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 99 | return x 100 | 101 | 102 | 103 | 104 | ''' 105 | add token transfer to feature 106 | ''' 107 | def token2feature(tokens): 108 | B,L,D=tokens.shape 109 | H=W=int(L**0.5) 110 | x = tokens.permute(0, 2, 1).reshape(B, D, W, H).contiguous() 111 | return x 112 | 113 | 114 | ''' 115 | feature2token 116 | ''' 117 | def feature2token(x): 118 | B,C,W,H = x.shape 119 | L = W*H 120 | tokens = x.reshape(B, C, L).permute(0, 2, 1).contiguous() 121 | return tokens -------------------------------------------------------------------------------- /tracking/profile_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | prj_path = 
os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 5 | if prj_path not in sys.path: 6 | sys.path.append(prj_path) 7 | 8 | import argparse 9 | import time 10 | import torch 11 | import torch.nn as nn 12 | from thop import profile 13 | from thop.utils import clever_format 14 | 15 | from lib.models.mambatrack.mamba_simple import Mamba 16 | from lib.models.mambatrack.vit import Attention 17 | 18 | def parse_args(): 19 | """ 20 | args for training. 21 | """ 22 | parser = argparse.ArgumentParser(description='Parse args for training') 23 | # for train 24 | parser.add_argument('--script', type=str, default='mambatrack', help='training script name') 25 | parser.add_argument('--config', type=str, default='mambavt_m256_ep20', help='yaml configure file name') 26 | args = parser.parse_args() 27 | 28 | return args 29 | 30 | def get_complexity_MAMBA(m: Mamba, x, y): 31 | """(B, L, D): batch size, sequence length, dimension""" 32 | batch, seqlen, d_model = x[0].shape 33 | d_inner = m.d_inner 34 | d_state = m.d_state 35 | # import ipdb; ipdb.set_trace() 36 | """compute flops""" 37 | total_ops = 0 38 | # in_proj 39 | total_ops += 2 * d_inner * d_model * batch * seqlen 40 | # out_proj 41 | total_ops += d_inner * d_model * batch * seqlen 42 | # compute bi-SSM 43 | # see: https://github.com/state-spaces/mamba/issues/110 44 | total_ops += 2 * (9 * batch * seqlen * d_model * d_state + batch * d_model * seqlen) 45 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 46 | 47 | def get_complexity_SelfAttn(m: Attention, x, y): 48 | """(B, L, D): batch size, sequence length, dimension""" 49 | total_ops = 0 50 | B, N, C = x[0].shape 51 | total_ops += B * 4 * N * C ** 2 52 | total_ops += B * 2 * C * N ** 2 53 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 54 | 55 | def get_data(bs, sz): 56 | img_patch = torch.randn(bs, 6, sz, sz) 57 | return img_patch 58 | 59 | def evaluate_mambat(model, template_list, search_list, custom_ops, verbose=False): 60 | '''Speed Test''' 61 | macs1, params1 = profile(model, inputs=(template_list, search_list), 62 | custom_ops=custom_ops, verbose=verbose) 63 | macs, params = clever_format([macs1, params1], "%.3f") 64 | print('overall macs is ', macs) 65 | print('overall params is ', params) 66 | 67 | T_w = 50 68 | T_t = 100 69 | print("testing speed ...") 70 | torch.cuda.synchronize() 71 | with torch.no_grad(): 72 | # overall 73 | for i in range(T_w): 74 | _ = model(template_list, search_list) 75 | start = time.time() 76 | for i in range(T_t): 77 | _ = model(template_list, search_list) 78 | torch.cuda.synchronize() 79 | end = time.time() 80 | avg_lat = (end - start) / T_t 81 | print("The average overall latency is %.2f ms" % (avg_lat * 1000)) 82 | print("FPS is %.2f fps" % (1. 
/ avg_lat)) 83 | 84 | if __name__ == "__main__": 85 | import importlib 86 | device = "cuda:1" 87 | torch.cuda.set_device(device) 88 | # Compute the Flops and Params of our STARK-S model 89 | args = parse_args() 90 | '''update cfg''' 91 | yaml_fname = f"{prj_path}/experiments/{args.script}/{args.config}.yaml" 92 | config_module = importlib.import_module('lib.config.%s.config' % args.script) 93 | cfg = config_module.cfg 94 | config_module.update_config_from_file(yaml_fname) 95 | '''import mambatrack network module''' 96 | model_module = importlib.import_module('lib.models.mambatrack') 97 | model_constructor = model_module.build_mambatrack 98 | model = model_constructor(cfg, training=False) 99 | # for name, module in model.named_modules(): 100 | # print(name) 101 | '''set some values''' 102 | bs = 1 103 | z_sz = cfg.TEST.TEMPLATE_SIZE 104 | x_sz = cfg.TEST.SEARCH_SIZE 105 | # z_num = 10 # cfg.TEST.TEMPLATE_NUMBER + 1 106 | 107 | model = model.to(device) 108 | template = get_data(bs, z_sz).to(device) 109 | search = get_data(bs, x_sz).to(device) 110 | 111 | for z_num in [7]: 112 | template_list = [template] * z_num 113 | search_list = [search] 114 | custom_ops = {Attention: get_complexity_SelfAttn} if 'vit' in args.config else {Mamba: get_complexity_MAMBA} 115 | verbose = True 116 | token_length = int(z_sz//16)**2 * z_num * 2 + int(x_sz//16)**2 * 2 117 | print(f"bs:{bs}, z_sz:{z_sz}, x_sz:{x_sz}, z_num:{z_num}, token_length:{token_length}") 118 | 119 | evaluate_mambat(model, template_list, search_list, custom_ops, verbose=verbose) 120 | 121 | -------------------------------------------------------------------------------- /lib/train/train_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | # loss function related 3 | from lib.utils.box_ops import giou_loss 4 | from torch.nn.functional import l1_loss 5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss 6 | # train pipeline related 7 | from lib.train.trainers import LTRTrainer 8 | # distributed training related 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | # some more advanced functions 11 | from .base_functions import * 12 | # network related 13 | from lib.models.mambatrack import build_mambatrack, build_mambatrack_motion 14 | # forward propagation related 15 | from lib.train.actors import MambaTrackActor, MambaTrackMotionActor 16 | # for import modules 17 | import importlib 18 | 19 | from ..utils.focal_loss import FocalLoss 20 | 21 | 22 | def run(settings): 23 | settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2' 24 | 25 | # update the default configs with config file 26 | if not os.path.exists(settings.cfg_file): 27 | raise ValueError("%s doesn't exist." 
% settings.cfg_file) 28 | config_module = importlib.import_module("lib.config.%s.config" % settings.script_name) 29 | cfg = config_module.cfg 30 | config_module.update_config_from_file(settings.cfg_file) 31 | if settings.local_rank in [-1, 0]: 32 | print("New configuration is shown below.") 33 | for key in cfg.keys(): 34 | print("%s configuration:" % key, cfg[key]) 35 | print('\n') 36 | 37 | # update settings based on cfg 38 | update_settings(settings, cfg) 39 | 40 | # Record the training log 41 | log_dir = os.path.join(settings.save_dir, 'logs') 42 | if settings.local_rank in [-1, 0]: 43 | if not os.path.exists(log_dir): 44 | os.makedirs(log_dir) 45 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name)) 46 | 47 | # Build dataloaders 48 | if settings.script_name == "mambatrack": 49 | loader_train, loader_val = build_dataloaders(cfg, settings) 50 | elif settings.script_name == "mambatrack_motion": 51 | loader_train, loader_val = build_dataloaders_wo_flip(cfg, settings) 52 | 53 | 54 | if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE: 55 | cfg.ckpt_dir = settings.save_dir 56 | 57 | # Create network 58 | if settings.script_name == "mambatrack": 59 | net = build_mambatrack(cfg) 60 | elif settings.script_name == "mambatrack_motion": 61 | net = build_mambatrack_motion(cfg) 62 | else: 63 | raise ValueError("illegal script name") 64 | 65 | # wrap networks to distributed one 66 | net.cuda() 67 | if settings.local_rank != -1: 68 | # net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter 69 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True) 70 | settings.device = torch.device("cuda:%d" % settings.local_rank) 71 | else: 72 | settings.device = torch.device("cuda:0") 73 | settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False) 74 | settings.distill = getattr(cfg.TRAIN, "DISTILL", False) 75 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL") 76 | # Loss functions and Actors 77 | if settings.script_name == "mambatrack": 78 | focal_loss = FocalLoss() 79 | objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss(), 'iou': l1_loss,} 80 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0, 'iou':1.0} 81 | actor = MambaTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg) 82 | elif settings.script_name == "mambatrack_motion": 83 | focal_loss = FocalLoss() 84 | objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss(), 'iou': l1_loss} 85 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0, 'iou': 1.0, 86 | 'l1_motion': cfg.TRAIN.L1_WEIGHT*2,} 87 | actor = MambaTrackMotionActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg) 88 | else: 89 | raise ValueError("illegal script name") 90 | 91 | # if cfg.TRAIN.DEEP_SUPERVISION: 92 | # raise ValueError("Deep supervision is not supported now.") 93 | 94 | # Optimizer, parameters, and learning rates 95 | optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg, motion='motion' in settings.script_name) 96 | use_amp = getattr(cfg.TRAIN, "AMP", False) 97 | if loader_val is None: 98 | trainer = LTRTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp) 99 | else: 100 | trainer = LTRTrainer(actor, 
[loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp) 101 | 102 | # train process 103 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True) 104 | -------------------------------------------------------------------------------- /lib/train/train_script_distill.py: -------------------------------------------------------------------------------- 1 | import os 2 | # loss function related 3 | from lib.utils.box_ops import giou_loss 4 | from torch.nn.functional import l1_loss 5 | from torch.nn import BCEWithLogitsLoss 6 | # train pipeline related 7 | from lib.train.trainers import LTRTrainer 8 | # distributed training related 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | # some more advanced functions 11 | from .base_functions import * 12 | # network related 13 | from lib.models.stark import build_starks, build_starkst 14 | from lib.models.stark import build_stark_lightning_x_trt 15 | # forward propagation related 16 | from lib.train.actors import STARKLightningXtrtdistillActor 17 | # for import modules 18 | import importlib 19 | 20 | 21 | def build_network(script_name, cfg): 22 | # Create network 23 | if script_name == "stark_s": 24 | net = build_starks(cfg) 25 | elif script_name == "stark_st1" or script_name == "stark_st2": 26 | net = build_starkst(cfg) 27 | elif script_name == "stark_lightning_X_trt": 28 | net = build_stark_lightning_x_trt(cfg, phase="train") 29 | else: 30 | raise ValueError("illegal script name") 31 | return net 32 | 33 | 34 | def run(settings): 35 | settings.description = 'Training script for STARK-S, STARK-ST stage1, and STARK-ST stage2' 36 | 37 | # update the default configs with config file 38 | if not os.path.exists(settings.cfg_file): 39 | raise ValueError("%s doesn't exist." % settings.cfg_file) 40 | config_module = importlib.import_module("lib.config.%s.config" % settings.script_name) 41 | cfg = config_module.cfg 42 | config_module.update_config_from_file(settings.cfg_file) 43 | if settings.local_rank in [-1, 0]: 44 | print("New configuration is shown below.") 45 | for key in cfg.keys(): 46 | print("%s configuration:" % key, cfg[key]) 47 | print('\n') 48 | 49 | # update the default teacher configs with teacher config file 50 | if not os.path.exists(settings.cfg_file_teacher): 51 | raise ValueError("%s doesn't exist." 
% settings.cfg_file_teacher) 52 | config_module_teacher = importlib.import_module("lib.config.%s.config" % settings.script_teacher) 53 | cfg_teacher = config_module_teacher.cfg 54 | config_module_teacher.update_config_from_file(settings.cfg_file_teacher) 55 | if settings.local_rank in [-1, 0]: 56 | print("New teacher configuration is shown below.") 57 | for key in cfg_teacher.keys(): 58 | print("%s configuration:" % key, cfg_teacher[key]) 59 | print('\n') 60 | 61 | # update settings based on cfg 62 | update_settings(settings, cfg) 63 | 64 | # Record the training log 65 | log_dir = os.path.join(settings.save_dir, 'logs') 66 | if settings.local_rank in [-1, 0]: 67 | if not os.path.exists(log_dir): 68 | os.makedirs(log_dir) 69 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name)) 70 | 71 | # Build dataloaders 72 | loader_train, loader_val = build_dataloaders(cfg, settings) 73 | 74 | if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE: 75 | cfg.ckpt_dir = settings.save_dir 76 | """turn on the distillation mode""" 77 | cfg.TRAIN.DISTILL = True 78 | cfg_teacher.TRAIN.DISTILL = True 79 | net = build_network(settings.script_name, cfg) 80 | net_teacher = build_network(settings.script_teacher, cfg_teacher) 81 | 82 | # wrap networks to distributed one 83 | net.cuda() 84 | net_teacher.cuda() 85 | net_teacher.eval() 86 | 87 | if settings.local_rank != -1: 88 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True) 89 | net_teacher = DDP(net_teacher, device_ids=[settings.local_rank], find_unused_parameters=True) 90 | settings.device = torch.device("cuda:%d" % settings.local_rank) 91 | else: 92 | settings.device = torch.device("cuda:0") 93 | # settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False) 94 | # settings.distill = getattr(cfg.TRAIN, "DISTILL", False) 95 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "L1") 96 | # Loss functions and Actors 97 | if settings.script_name == "stark_lightning_X_trt": 98 | objective = {'giou': giou_loss, 'l1': l1_loss} 99 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT} 100 | actor = STARKLightningXtrtdistillActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, 101 | net_teacher=net_teacher) 102 | else: 103 | raise ValueError("illegal script name") 104 | 105 | # Optimizer, parameters, and learning rates 106 | optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg) 107 | use_amp = getattr(cfg.TRAIN, "AMP", False) 108 | trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp) 109 | 110 | # train process 111 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True, distill=True) 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MambaVT 2 | The official implementation for the TCSVT.2025 paper: ["MambaVT: Spatio-Temporal Contextual Modeling for robust RGB-T Tracking"](https://ieeexplore.ieee.org/document/10949219) 3 | 4 | :rocket: Update Models and Results (2024/08/07) 5 | [Models & Raw Results](https://drive.google.com/drive/folders/1Ww-cMuzJ-6XcnTSsPuedlyR_LWSgwAaR?usp=sharing) [Google Driver] 6 | [Models & Raw Results](https://pan.baidu.com/s/1XaXsSrToDqbLLAStJcMr9g) [Baidu Driver: iiau] 7 | [Models & Raw Results](https://pan.quark.cn/s/f0d7ac3f9974) [Quark Driver: neYH] 8 | 9 |

10 | ![Framework](assets/framework.png) 11 |

12 | 13 | ## Highlights 14 | ### :star2: New Unified Mamba-based Tracking Framework 15 | MambaVT is a simple, neat, high-performance **unified Mamba-based tracking framework** for global long-range and local short-term spatio-temporal contextual modeling for robust RGB-T Tracking. MambaVT achieves SOTA performance on multiple RGB-T benchmarks with fewer FLOPs and Params. MambaVT can serve as a strong baseline for further research. 16 | 17 | | Tracker | GTOT (SR) | RGBT210 (SR) | RGBT234 (MSR) | LasHeR(SR) | 18 | |:-----------:|:------------:|:-----------:|:-----------------:|:-----------:| 19 | | MambaVT-S256 | 75.3 | 63.7 | 65.8 | 57.9 | 20 | | MambaVT-M256 | 78.6 | 64.4 | 67.5 | 57.5 | 21 | 22 | ## Install the environment 23 | We've tested the results on the PyTorch2.1.1+cuda11.8+Python3.9+causal-conv1d==1.1.0 24 | 25 | :warning: Cuda version must be strictly met, but PyTorch>=2.0 will be fine 26 | 27 | 28 | **Option1**: Use the Anaconda (CUDA 11.8) 29 | ```bash 30 | conda create -n mambavt python=3.9 31 | conda activate mambavt 32 | bash install.sh 33 | ``` 34 | 35 | And we strongly recommend installing torch/torchvision/causal-conv1d manually by: 36 | ```bash 37 | # Download torch from: https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp39-cp39-linux_x86_64.whl 38 | pip install torch-2.1.1%2Bcu118-cp39-cp39-linux_x86_64.whl 39 | 40 | # Download torchvision from: https://download.pytorch.org/whl/cu118/torchvision-0.16.1%2Bcu118-cp39-cp39-linux_x86_64.whl 41 | pip install torchvision-0.16.1%2Bcu118-cp39-cp39-linux_x86_64.whl 42 | 43 | # Download causal-conv1d from: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.1.0/causal_conv1d-1.1.0+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl 44 | pip install causal_conv1d-1.1.0+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl 45 | ``` 46 | 47 | Install `mamba` 48 | ``` 49 | cd mamba-1p1p1 50 | pip install -e . # in editor mode 51 | 52 | # If compile successfully, selective_scan_cuda.cpython-39-x86_64-linux-gnu.so will be generated 53 | ``` 54 | ## Set project paths 55 | Run the following command to set paths for this project 56 | ``` 57 | python tracking/create_default_local_file.py --workspace_dir . --data_dir ./data --save_dir ./output 58 | ``` 59 | After running this command, you can also modify paths by editing these two files 60 | ``` 61 | lib/train/admin/local.py # paths about training 62 | lib/test/evaluation/local.py # paths about testing 63 | ``` 64 | 65 | ## Data Preparation 66 | prepare the LasHeR dataset. It should look like: 67 | ``` 68 | ${PROJECT_ROOT} 69 | -- LasHeR 70 | -- train/ 71 | |-- trainingset 72 | |-- 1boygo 73 | |-- ... 74 | |-- trainingsetList.txt 75 | |-- tracker_predicted 76 | ... 77 | -- test 78 | |-- testingset 79 | |-- 1blackteacher 80 | |-- ... 81 | |-- testingsetList.txt 82 | ``` 83 | 84 | ## Training 85 | Download pre-trained `OSTrack_videomambas_ep300.pth.tar` or `OSTrack_videomambam_ep300.pth.tar` from above driver link and put it under `$PROJECT_ROOT$/pretrained_models` . Then 86 | 87 | ``` 88 | bash xtrain.sh 89 | bash xtrain_motion.sh 90 | ``` 91 | 92 | Replace `--config` with the desired model config under `experiments/mambatrack` or `experiments/mambatrack_motion`. 93 | 94 | 95 | ## Evaluation 96 | Download the model weights from above driver link 97 | 98 | Put the downloaded weights on `$PROJECT_ROOT$/checkpoints/` 99 | 100 | Change the corresponding values of `lib/test/parameter/mambatrack_motion.py` to the actual checkpoint paths. 
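A minimal sketch of that edit, assuming the parameter file exposes an OSTrack-style `params.checkpoint` attribute (the attribute name and the filename below are placeholders, not the repo's exact code):

```python
# lib/test/parameter/mambatrack_motion.py -- hypothetical lines, adapt to the actual file
params.checkpoint = "./checkpoints/<downloaded_weights>.pth.tar"  # path to the downloaded model weights
```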
Then 101 | 102 | ``` 103 | bash ytest.sh 104 | bash ytest_motion.sh 105 | ``` 106 | 107 | ## Test FLOPs, and Speed 108 | *Note:* The speeds reported in our paper were tested on a single 3090 GPU. 109 | 110 | ```bash 111 | # Profiling mambavt_s256_ep20 112 | python tracking/profile_model.py --script mambatrack --config mambavt_s256_ep20 113 | # Profiling mambavt_m256_ep20 114 | python tracking/profile_model.py --script mambatrack --config mambavt_m256_ep20 115 | ``` 116 | 117 | ## Acknowledgments 118 | * Thanks for the [OSTrack](https://github.com/botaoye/OSTrack), [Vim](https://github.com/hustvl/Vim) and [VideoMamba](https://github.com/OpenGVLab/VideoMamba) library, which helps us to quickly implement our ideas. 119 | 120 | 121 | ## Contact 122 | If you have any question, feel free to email laisimiao@mail.dlut.edu.cn. ^_^ 123 | -------------------------------------------------------------------------------- /lib/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_area 3 | import numpy as np 4 | 5 | 6 | def box_cxcywh_to_xyxy(x): 7 | x_c, y_c, w, h = x.unbind(-1) 8 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 9 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 10 | return torch.stack(b, dim=-1) 11 | 12 | 13 | def box_xywh_to_xyxy(x): 14 | x1, y1, w, h = x.unbind(-1) 15 | b = [x1, y1, x1 + w, y1 + h] 16 | return torch.stack(b, dim=-1) 17 | 18 | 19 | def box_xyxy_to_xywh(x): 20 | x1, y1, x2, y2 = x.unbind(-1) 21 | b = [x1, y1, x2 - x1, y2 - y1] 22 | return torch.stack(b, dim=-1) 23 | 24 | 25 | def box_xyxy_to_cxcywh(x): 26 | x0, y0, x1, y1 = x.unbind(-1) 27 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 28 | (x1 - x0), (y1 - y0)] 29 | return torch.stack(b, dim=-1) 30 | 31 | 32 | # modified from torchvision to also return the union 33 | '''Note that this function only supports shape (N,4)''' 34 | 35 | 36 | def box_iou(boxes1, boxes2): 37 | """ 38 | 39 | :param boxes1: (N, 4) (x1,y1,x2,y2) 40 | :param boxes2: (N, 4) (x1,y1,x2,y2) 41 | :return: 42 | """ 43 | area1 = box_area(boxes1) # (N,) 44 | area2 = box_area(boxes2) # (N,) 45 | 46 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2) 47 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2) 48 | 49 | wh = (rb - lt).clamp(min=0) # (N,2) 50 | inter = wh[:, 0] * wh[:, 1] # (N,) 51 | 52 | union = area1 + area2 - inter 53 | 54 | iou = inter / union 55 | return iou, union 56 | 57 | 58 | '''Note that this implementation is different from DETR's''' 59 | 60 | 61 | def generalized_box_iou(boxes1, boxes2): 62 | """ 63 | Generalized IoU from https://giou.stanford.edu/ 64 | 65 | The boxes should be in [x0, y0, x1, y1] format 66 | 67 | boxes1: (N, 4) 68 | boxes2: (N, 4) 69 | """ 70 | # degenerate boxes gives inf / nan results 71 | # so do an early check 72 | # try: 73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 75 | iou, union = box_iou(boxes1, boxes2) # (N,) 76 | 77 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 78 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 79 | 80 | wh = (rb - lt).clamp(min=0) # (N,2) 81 | area = wh[:, 0] * wh[:, 1] # (N,) 82 | 83 | return iou - (area - union) / area, iou 84 | 85 | 86 | def giou_loss(boxes1, boxes2): 87 | """ 88 | 89 | :param boxes1: (N, 4) (x1,y1,x2,y2) 90 | :param boxes2: (N, 4) (x1,y1,x2,y2) 91 | :return: 92 | """ 93 | giou, iou = generalized_box_iou(boxes1, boxes2) 94 | return (1 - giou).mean(), iou 95 | 96 | 97 | def clip_box(box: list, H, W, margin=0): 98 | x1, y1, w, h = box 99 
| x2, y2 = x1 + w, y1 + h 100 | x1 = min(max(0, x1), W-margin) 101 | x2 = min(max(margin, x2), W) 102 | y1 = min(max(0, y1), H-margin) 103 | y2 = min(max(margin, y2), H) 104 | w = max(margin, x2-x1) 105 | h = max(margin, y2-y1) 106 | return [x1, y1, w, h] 107 | 108 | 109 | 110 | def fp16_clamp(x, min=None, max=None): 111 | if not x.is_cuda and x.dtype == torch.float16: 112 | # clamp for cpu float16, tensor fp16 has no clamp implementation 113 | return x.float().clamp(min, max).half() 114 | 115 | return x.clamp(min, max) 116 | 117 | # angle cost 118 | def SIoU_loss(test1, test2, theta=4): 119 | # must xyxy format input 120 | eps = 1e-7 121 | cx_pred = (test1[:, 0] + test1[:, 2]) / 2 122 | cy_pred = (test1[:, 1] + test1[:, 3]) / 2 123 | cx_gt = (test2[:, 0] + test2[:, 2]) / 2 124 | cy_gt = (test2[:, 1] + test2[:, 3]) / 2 125 | 126 | dist = ((cx_pred - cx_gt)**2 + (cy_pred - cy_gt)**2) ** 0.5 127 | ch = torch.max(cy_gt, cy_pred) - torch.min(cy_gt, cy_pred) 128 | x = ch / (dist + eps) 129 | 130 | angle = 1 - 2*torch.sin(torch.arcsin(x)-torch.pi/4)**2 131 | # distance cost 132 | xmin = torch.min(test1[:, 0], test2[:, 0]) 133 | xmax = torch.max(test1[:, 2], test2[:, 2]) 134 | ymin = torch.min(test1[:, 1], test2[:, 1]) 135 | ymax = torch.max(test1[:, 3], test2[:, 3]) 136 | cw = xmax - xmin 137 | ch = ymax - ymin 138 | px = ((cx_gt - cx_pred) / (cw+eps))**2 139 | py = ((cy_gt - cy_pred) / (ch+eps))**2 140 | gama = 2 - angle 141 | dis = (1 - torch.exp(-1 * gama * px)) + (1 - torch.exp(-1 * gama * py)) 142 | 143 | #shape cost 144 | w_pred = test1[:, 2] - test1[:, 0] 145 | h_pred = test1[:, 3] - test1[:, 1] 146 | w_gt = test2[:, 2] - test2[:, 0] 147 | h_gt = test2[:, 3] - test2[:, 1] 148 | ww = torch.abs(w_pred - w_gt) / (torch.max(w_pred, w_gt) + eps) 149 | wh = torch.abs(h_gt - h_pred) / (torch.max(h_gt, h_pred) + eps) 150 | omega = (1 - torch.exp(-1 * wh)) ** theta + (1 - torch.exp(-1 * ww)) ** theta 151 | 152 | #IoU loss 153 | lt = torch.max(test1[..., :2], test2[..., :2]) # [B, rows, 2] 154 | rb = torch.min(test1[..., 2:], test2[..., 2:]) # [B, rows, 2] 155 | 156 | wh = fp16_clamp(rb - lt, min=0) 157 | overlap = wh[..., 0] * wh[..., 1] 158 | area1 = (test1[..., 2] - test1[..., 0]) * ( 159 | test1[..., 3] - test1[..., 1]) 160 | area2 = (test2[..., 2] - test2[..., 0]) * ( 161 | test2[..., 3] - test2[..., 1]) 162 | iou = overlap / (area1 + area2 - overlap) 163 | 164 | SIoU = 1 - iou + (omega + dis) / 2 165 | return SIoU, iou -------------------------------------------------------------------------------- /lib/config/mambatrack/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import yaml 3 | 4 | """ 5 | Add default config for OSTrack. 
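These values are defaults only; update_config_from_file() recursively overrides them with the experiment YAML, and any YAML key not already defined here raises a ValueError.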
6 | """ 7 | cfg = edict() 8 | 9 | # MODEL 10 | cfg.MODEL = edict() 11 | cfg.MODEL.PRETRAIN_FILE = "mae_pretrain_vit_base.pth" 12 | cfg.MODEL.EXTRA_MERGER = False 13 | 14 | cfg.MODEL.RETURN_INTER = False 15 | cfg.MODEL.RETURN_STAGES = [] 16 | 17 | # MODEL.BACKBONE 18 | cfg.MODEL.BACKBONE = edict() 19 | cfg.MODEL.BACKBONE.TYPE = "vit_base_patch16_224" 20 | cfg.MODEL.BACKBONE.STRIDE = 16 21 | cfg.MODEL.BACKBONE.CONCAT_MODE = 'tsts' # tsts: t_rgb, s_rgb, t_X, s_X | ttss: t_rgb, t_X, s_rgb, s_X 22 | cfg.MODEL.BACKBONE.SCAN_MODE = 'spatial_first' # spatial_first temporal_first 23 | cfg.MODEL.BACKBONE.Z_SEG = False # add template segment embedding (as multiple templates) 24 | cfg.MODEL.BACKBONE.Z_VOCAB_SIZE = 200 # 25 | cfg.MODEL.BACKBONE.ADD_CLS_TOKEN = False 26 | 27 | cfg.MODEL.BACKBONE.MID_PE = False 28 | cfg.MODEL.BACKBONE.SEP_SEG = False 29 | cfg.MODEL.BACKBONE.CAT_MODE = 'direct' 30 | cfg.MODEL.BACKBONE.MERGE_LAYER = 0 31 | cfg.MODEL.BACKBONE.CLS_TOKEN_USE_MODE = 'ignore' 32 | 33 | cfg.MODEL.BACKBONE.CE_LOC = [] 34 | cfg.MODEL.BACKBONE.CE_KEEP_RATIO = [] 35 | cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'ALL' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 36 | 37 | # MODEL.HEAD 38 | cfg.MODEL.HEAD = edict() 39 | cfg.MODEL.HEAD.TYPE = "CENTER" 40 | cfg.MODEL.HEAD.NUM_CHANNELS = 256 41 | cfg.MODEL.HEAD.FUSE = False 42 | 43 | # TRAIN 44 | cfg.TRAIN = edict() 45 | cfg.TRAIN.LR = 0.0001 46 | cfg.TRAIN.WEIGHT_DECAY = 0.0001 47 | cfg.TRAIN.EPOCH = 500 48 | cfg.TRAIN.LR_DROP_EPOCH = 400 49 | cfg.TRAIN.BATCH_SIZE = 16 50 | cfg.TRAIN.NUM_WORKER = 8 51 | cfg.TRAIN.OPTIMIZER = "ADAMW" 52 | cfg.TRAIN.BACKBONE_MULTIPLIER = 0.1 53 | cfg.TRAIN.GIOU_WEIGHT = 2.0 54 | cfg.TRAIN.L1_WEIGHT = 5.0 55 | cfg.TRAIN.FREEZE_LAYERS = [0, ] 56 | cfg.TRAIN.PRINT_INTERVAL = 50 57 | cfg.TRAIN.VAL_PRINT_INTERVAL = 10 58 | cfg.TRAIN.VAL_EPOCH_INTERVAL = 20 59 | cfg.TRAIN.GRAD_CLIP_NORM = 0.1 60 | cfg.TRAIN.AMP = False 61 | cfg.TRAIN.ACCUM_ITER = 1 62 | 63 | ## TRAIN save cfgs 64 | cfg.TRAIN.SAVE_EPOCH_INTERVAL = 1 # 1 means save model each epoch 65 | cfg.TRAIN.SAVE_LAST_N_EPOCH = 1 # besides, last n epoch model will be saved 66 | 67 | 68 | cfg.TRAIN.CE_START_EPOCH = 20 # candidate elimination start epoch 69 | cfg.TRAIN.CE_WARM_EPOCH = 80 # candidate elimination warm up epoch 70 | cfg.TRAIN.DROP_PATH_RATE = 0.1 # drop path rate for ViT backbone 71 | 72 | # TRAIN.SCHEDULER 73 | cfg.TRAIN.SCHEDULER = edict() 74 | cfg.TRAIN.SCHEDULER.TYPE = "step" 75 | cfg.TRAIN.SCHEDULER.DECAY_RATE = 0.1 76 | 77 | # DATA 78 | cfg.DATA = edict() 79 | cfg.DATA.REVERSE_PROB = 0.0 # reverse video-level training augmentation < this 80 | cfg.DATA.SAMPLER_MODE = "causal" # uniform sampling methods 81 | cfg.DATA.MEAN = [0.485, 0.456, 0.406] 82 | cfg.DATA.STD = [0.229, 0.224, 0.225] 83 | cfg.DATA.MAX_SAMPLE_INTERVAL = 200 84 | # DATA.TRAIN 85 | cfg.DATA.TRAIN = edict() 86 | cfg.DATA.TRAIN.DATASETS_NAME = ["LASOT", "GOT10K_vottrain"] 87 | cfg.DATA.TRAIN.DATASETS_RATIO = [1, 1] 88 | cfg.DATA.TRAIN.SAMPLE_PER_EPOCH = 60000 89 | # DATA.VAL 90 | cfg.DATA.VAL = edict() 91 | cfg.DATA.VAL.DATASETS_NAME = ["GOT10K_votval"] 92 | cfg.DATA.VAL.DATASETS_RATIO = [1] 93 | cfg.DATA.VAL.SAMPLE_PER_EPOCH = 10000 94 | # DATA.SEARCH 95 | cfg.DATA.SEARCH = edict() 96 | cfg.DATA.SEARCH.SIZE = 320 97 | cfg.DATA.SEARCH.FACTOR = 5.0 98 | cfg.DATA.SEARCH.CENTER_JITTER = 4.5 99 | cfg.DATA.SEARCH.SCALE_JITTER = 0.5 100 | cfg.DATA.SEARCH.NUMBER = 1 101 | # DATA.TEMPLATE 102 | cfg.DATA.TEMPLATE = edict() 103 | cfg.DATA.TEMPLATE.NUMBER = 1 104 | cfg.DATA.TEMPLATE.SIZE = 128 105 | 
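# With MODEL.BACKBONE.STRIDE = 16, a 128x128 template crop is split into (128 // 16) ** 2 = 64
# tokens and a 320x320 search crop into (320 // 16) ** 2 = 400 tokens; each count is doubled for
# the RGB and thermal inputs (the same token-length arithmetic is used in tracking/profile_model.py).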
cfg.DATA.TEMPLATE.FACTOR = 2.0 106 | cfg.DATA.TEMPLATE.CENTER_JITTER = 0 107 | cfg.DATA.TEMPLATE.SCALE_JITTER = 0 108 | 109 | # TEST 110 | cfg.TEST = edict() 111 | cfg.TEST.TEMPLATE_FACTOR = 2.0 112 | cfg.TEST.TEMPLATE_SIZE = 128 113 | cfg.TEST.SEARCH_FACTOR = 5.0 114 | cfg.TEST.SEARCH_SIZE = 320 115 | cfg.TEST.EPOCH = 500 116 | cfg.TEST.TEMPLATE_NUMBER = 3 117 | 118 | 119 | def _edict2dict(dest_dict, src_edict): 120 | if isinstance(dest_dict, dict) and isinstance(src_edict, dict): 121 | for k, v in src_edict.items(): 122 | if not isinstance(v, edict): 123 | dest_dict[k] = v 124 | else: 125 | dest_dict[k] = {} 126 | _edict2dict(dest_dict[k], v) 127 | else: 128 | return 129 | 130 | 131 | def gen_config(config_file): 132 | cfg_dict = {} 133 | _edict2dict(cfg_dict, cfg) 134 | with open(config_file, 'w') as f: 135 | yaml.dump(cfg_dict, f, default_flow_style=False) 136 | 137 | 138 | def _update_config(base_cfg, exp_cfg): 139 | if isinstance(base_cfg, dict) and isinstance(exp_cfg, edict): 140 | for k, v in exp_cfg.items(): 141 | if k in base_cfg: 142 | if not isinstance(v, dict): 143 | base_cfg[k] = v 144 | else: 145 | _update_config(base_cfg[k], v) 146 | else: 147 | raise ValueError("{} not exist in config.py".format(k)) 148 | else: 149 | return 150 | 151 | 152 | def update_config_from_file(filename, base_cfg=None): 153 | exp_config = None 154 | with open(filename) as f: 155 | exp_config = edict(yaml.safe_load(f)) 156 | if base_cfg is not None: 157 | _update_config(base_cfg, exp_config) 158 | else: 159 | _update_config(cfg, exp_config) 160 | -------------------------------------------------------------------------------- /lib/train/dataset/lasher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import torch 5 | import csv 6 | import pandas 7 | import random 8 | from collections import OrderedDict 9 | from .base_video_dataset import BaseVideoDataset 10 | from lib.train.admin import env_settings 11 | from lib.train.dataset.depth_utils import get_x_frame 12 | 13 | 14 | class LasHeR(BaseVideoDataset): 15 | """ LasHeR dataset(aligned version). 16 | 17 | Publication: 18 | A Large-scale High-diversity Benchmark for RGBT Tracking 19 | Chenglong Li, Wanlin Xue, Yaqing Jia, Zhichen Qu, Bin Luo, Jin Tang, and Dengdi Sun 20 | https://arxiv.org/pdf/2104.13202.pdf 21 | 22 | Download dataset from https://github.com/BUGPLEASEOUT/LasHeR 23 | """ 24 | 25 | def __init__(self, root=None, split='train', dtype='rgbrgb', seq_ids=None, data_fraction=None): 26 | """ 27 | args: 28 | root - path to the LasHeR trainingset. 29 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 30 | is used by default. 31 | seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids' 32 | options can be used at the same time. 33 | data_fraction - Fraction of dataset to be used. 
The complete dataset is used by default 34 | """ 35 | root = env_settings().lasher_dir if root is None else root 36 | assert split in ['train', 'val','all'], 'Only support all, train or val split in LasHeR, got {}'.format(split) 37 | super().__init__('LasHeR', root) 38 | self.dtype = dtype 39 | 40 | # all folders inside the root 41 | self.sequence_list = self._get_sequence_list(split) 42 | 43 | # seq_id is the index of the folder inside the got10k root path 44 | if seq_ids is None: 45 | seq_ids = list(range(0, len(self.sequence_list))) 46 | 47 | self.sequence_list = [self.sequence_list[i] for i in seq_ids] 48 | 49 | if data_fraction is not None: 50 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction)) 51 | 52 | def get_name(self): 53 | return 'lasher' 54 | 55 | def has_class_info(self): 56 | return True 57 | 58 | def has_occlusion_info(self): 59 | return True # w=h=0 in visible.txt and infrared.txt is occlusion/oov 60 | 61 | def _get_sequence_list(self, split): 62 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 63 | file_path = os.path.join(ltr_path, 'data_specs', 'lasher_{}.txt'.format(split)) 64 | with open(file_path, 'r') as f: 65 | dir_list = f.read().splitlines() 66 | return dir_list 67 | 68 | def _read_bb_anno(self, seq_path): 69 | # in lasher dataset, visible.txt is same as infrared.txt 70 | rgb_bb_anno_file = os.path.join(seq_path, "visible.txt") 71 | # ir_bb_anno_file = os.path.join(seq_path, "infrared.txt") 72 | rgb_gt = pandas.read_csv(rgb_bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values 73 | # ir_gt = pandas.read_csv(ir_bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values 74 | return torch.tensor(rgb_gt) 75 | 76 | def _get_sequence_path(self, seq_id): 77 | return os.path.join(self.root, self.sequence_list[seq_id]) 78 | 79 | def get_sequence_info(self, seq_id): 80 | """2022/8/10 ir and rgb have synchronous w=h=0 frame_index""" 81 | seq_path = self._get_sequence_path(seq_id) 82 | bbox = self._read_bb_anno(seq_path) 83 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 84 | visible = valid.clone().byte() 85 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 86 | 87 | def _get_frame_path(self, seq_path, frame_id): 88 | # Note original filename is chaotic, we rename them 89 | rgb_frame_path = os.path.join(seq_path, 'visible', '{:06d}.jpg'.format(frame_id)) # frames start from 0 90 | ir_frame_path = os.path.join(seq_path, 'infrared', '{:06d}.jpg'.format(frame_id)) 91 | return (rgb_frame_path, ir_frame_path) # jpg jpg 92 | 93 | def _get_frame(self, seq_path, frame_id): 94 | rgb_frame_path, ir_frame_path = self._get_frame_path(seq_path, frame_id) 95 | img = get_x_frame(rgb_frame_path, ir_frame_path, dtype=self.dtype) 96 | return img # (h,w,6) 97 | 98 | def get_frames(self, seq_id, frame_ids, anno=None, is_search=False, pre_motion_num=0, reverse=False): 99 | seq_path = self._get_sequence_path(seq_id) 100 | 101 | frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids] 102 | 103 | if anno is None: 104 | anno = self.get_sequence_info(seq_id) 105 | 106 | anno_frames = {} 107 | for key, value in anno.items(): 108 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 109 | 110 | object_meta = OrderedDict({'object_class_name': None, 111 | 'motion_class': None, 112 | 'major_class': None, 113 | 'root_class': None, 114 | 'motion_adverb': None}) 115 | 116 | return frame_list, anno_frames, 
object_meta 117 | -------------------------------------------------------------------------------- /lib/models/layers/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import trunc_normal_ 5 | 6 | from lib.models.layers.rpe import generate_2d_concatenated_self_attention_relative_positional_encoding_index 7 | 8 | 9 | class Attention(nn.Module): 10 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., 11 | rpe=False, z_size=7, x_size=14): 12 | super().__init__() 13 | self.num_heads = num_heads 14 | head_dim = dim // num_heads 15 | self.scale = head_dim ** -0.5 16 | 17 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 18 | self.attn_drop = nn.Dropout(attn_drop) 19 | self.proj = nn.Linear(dim, dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | self.rpe =rpe 23 | if self.rpe: 24 | relative_position_index = \ 25 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size], 26 | [x_size, x_size]) 27 | self.register_buffer("relative_position_index", relative_position_index) 28 | # define a parameter table of relative position bias 29 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, 30 | relative_position_index.max() + 1))) 31 | trunc_normal_(self.relative_position_bias_table, std=0.02) 32 | 33 | def forward(self, x, mask=None, return_attention=False): 34 | # x: B, N, C 35 | # mask: [B, N, ] torch.bool 36 | B, N, C = x.shape 37 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 38 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 39 | 40 | attn = (q @ k.transpose(-2, -1)) * self.scale 41 | 42 | if self.rpe: 43 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0) 44 | attn += relative_position_bias 45 | 46 | if mask is not None: 47 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'),) 48 | 49 | attn = attn.softmax(dim=-1) 50 | attn = self.attn_drop(attn) 51 | 52 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 53 | x = self.proj(x) 54 | x = self.proj_drop(x) 55 | 56 | if return_attention: 57 | return x, attn 58 | else: 59 | return x 60 | 61 | 62 | class Attention_talking_head(nn.Module): 63 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 64 | # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) 65 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 66 | rpe=True, z_size=7, x_size=14): 67 | super().__init__() 68 | 69 | self.num_heads = num_heads 70 | 71 | head_dim = dim // num_heads 72 | 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | 78 | self.proj = nn.Linear(dim, dim) 79 | 80 | self.proj_l = nn.Linear(num_heads, num_heads) 81 | self.proj_w = nn.Linear(num_heads, num_heads) 82 | 83 | self.proj_drop = nn.Dropout(proj_drop) 84 | 85 | self.rpe = rpe 86 | if self.rpe: 87 | relative_position_index = \ 88 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size], 89 | [x_size, x_size]) 90 | self.register_buffer("relative_position_index", relative_position_index) 91 | # define a parameter table of relative position bias 92 
| self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, 93 | relative_position_index.max() + 1))) 94 | trunc_normal_(self.relative_position_bias_table, std=0.02) 95 | 96 | def forward(self, x, mask=None): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 100 | 101 | attn = (q @ k.transpose(-2, -1)) 102 | 103 | if self.rpe: 104 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0) 105 | attn += relative_position_bias 106 | 107 | if mask is not None: 108 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), 109 | float('-inf'),) 110 | 111 | attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 112 | 113 | attn = attn.softmax(dim=-1) 114 | 115 | attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 116 | attn = self.attn_drop(attn) 117 | 118 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 119 | x = self.proj(x) 120 | x = self.proj_drop(x) 121 | return x -------------------------------------------------------------------------------- /lib/config/mambatrack_motion/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import yaml 3 | 4 | """ 5 | Add default config for OSTrack. 6 | """ 7 | cfg = edict() 8 | 9 | # MODEL 10 | cfg.MODEL = edict() 11 | cfg.MODEL.PRETRAIN_FILE = "mae_pretrain_vit_base.pth" 12 | cfg.MODEL.BINS = 400 13 | cfg.MODEL.RANGE = 2 14 | cfg.MODEL.EXTRA_MERGER = False 15 | 16 | cfg.MODEL.RETURN_INTER = False 17 | cfg.MODEL.RETURN_STAGES = [] 18 | 19 | # MODEL.BACKBONE 20 | cfg.MODEL.BACKBONE = edict() 21 | cfg.MODEL.BACKBONE.TYPE = "vit_base_patch16_224" 22 | cfg.MODEL.BACKBONE.STRIDE = 16 23 | cfg.MODEL.BACKBONE.CONCAT_MODE = 'tsts' # tsts: t_rgb, s_rgb, t_X, s_X | ttss: t_rgb, t_X, s_rgb, s_X 24 | cfg.MODEL.BACKBONE.Z_SEG = False # add template segment embedding (as multiple templates) 25 | cfg.MODEL.BACKBONE.ADD_CLS_TOKEN = False 26 | cfg.MODEL.BACKBONE.ADD_MOTION_PRED = False 27 | cfg.MODEL.BACKBONE.PROMPT_EMBED_TYPE = "vocab" # sam vocab 28 | 29 | cfg.MODEL.BACKBONE.MID_PE = False 30 | cfg.MODEL.BACKBONE.SEP_SEG = False 31 | cfg.MODEL.BACKBONE.CAT_MODE = 'direct' 32 | cfg.MODEL.BACKBONE.MERGE_LAYER = 0 33 | cfg.MODEL.BACKBONE.CLS_TOKEN_USE_MODE = 'ignore' 34 | 35 | cfg.MODEL.BACKBONE.CE_LOC = [] 36 | cfg.MODEL.BACKBONE.CE_KEEP_RATIO = [] 37 | cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'ALL' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 38 | 39 | # MODEL.HEAD 40 | cfg.MODEL.HEAD = edict() 41 | cfg.MODEL.HEAD.TYPE = "CENTER" 42 | cfg.MODEL.HEAD.NUM_CHANNELS = 256 43 | cfg.MODEL.HEAD.FUSE = False 44 | 45 | # TRAIN 46 | cfg.TRAIN = edict() 47 | cfg.TRAIN.LR = 0.0001 48 | cfg.TRAIN.WEIGHT_DECAY = 0.0001 49 | cfg.TRAIN.EPOCH = 500 50 | cfg.TRAIN.LR_DROP_EPOCH = 400 51 | cfg.TRAIN.BATCH_SIZE = 16 52 | cfg.TRAIN.NUM_WORKER = 8 53 | cfg.TRAIN.OPTIMIZER = "ADAMW" 54 | cfg.TRAIN.BACKBONE_MULTIPLIER = 0.1 55 | cfg.TRAIN.GIOU_WEIGHT = 2.0 56 | cfg.TRAIN.L1_WEIGHT = 5.0 57 | cfg.TRAIN.FREEZE_LAYERS = [0, ] 58 | cfg.TRAIN.PRINT_INTERVAL = 50 59 | cfg.TRAIN.VAL_PRINT_INTERVAL = 10 60 | cfg.TRAIN.VAL_EPOCH_INTERVAL = 20 61 | cfg.TRAIN.GRAD_CLIP_NORM = 0.1 62 | cfg.TRAIN.AMP = False 63 | cfg.TRAIN.ACCUM_ITER = 1 64 | 65 | ## TRAIN save cfgs 66 | cfg.TRAIN.SAVE_EPOCH_INTERVAL = 1 # 1 means save model each epoch 67 | cfg.TRAIN.SAVE_LAST_N_EPOCH = 1 # besides, last n epoch model will be 
saved 68 | 69 | 70 | cfg.TRAIN.CE_START_EPOCH = 20 # candidate elimination start epoch 71 | cfg.TRAIN.CE_WARM_EPOCH = 80 # candidate elimination warm up epoch 72 | cfg.TRAIN.DROP_PATH_RATE = 0.1 # drop path rate for ViT backbone 73 | 74 | # TRAIN.SCHEDULER 75 | cfg.TRAIN.SCHEDULER = edict() 76 | cfg.TRAIN.SCHEDULER.TYPE = "step" 77 | cfg.TRAIN.SCHEDULER.DECAY_RATE = 0.1 78 | 79 | # DATA 80 | cfg.DATA = edict() 81 | cfg.DATA.PRE_MOTION_NUM = 0 82 | cfg.DATA.REVERSE_PROB = 0.0 # reverse video-level training augmentation 83 | cfg.DATA.SAMPLER_MODE = "causal" # sampling methods 84 | cfg.DATA.MEAN = [0.485, 0.456, 0.406] 85 | cfg.DATA.STD = [0.229, 0.224, 0.225] 86 | cfg.DATA.MAX_SAMPLE_INTERVAL = 200 87 | # DATA.TRAIN 88 | cfg.DATA.TRAIN = edict() 89 | cfg.DATA.TRAIN.DATASETS_NAME = ["LASOT", "GOT10K_vottrain"] 90 | cfg.DATA.TRAIN.DATASETS_RATIO = [1, 1] 91 | cfg.DATA.TRAIN.SAMPLE_PER_EPOCH = 60000 92 | # DATA.VAL 93 | cfg.DATA.VAL = edict() 94 | cfg.DATA.VAL.DATASETS_NAME = ["GOT10K_votval"] 95 | cfg.DATA.VAL.DATASETS_RATIO = [1] 96 | cfg.DATA.VAL.SAMPLE_PER_EPOCH = 10000 97 | # DATA.SEARCH 98 | cfg.DATA.SEARCH = edict() 99 | cfg.DATA.SEARCH.SIZE = 320 100 | cfg.DATA.SEARCH.FACTOR = 5.0 101 | cfg.DATA.SEARCH.CENTER_JITTER = 4.5 102 | cfg.DATA.SEARCH.SCALE_JITTER = 0.5 103 | cfg.DATA.SEARCH.NUMBER = 1 104 | # DATA.TEMPLATE 105 | cfg.DATA.TEMPLATE = edict() 106 | cfg.DATA.TEMPLATE.NUMBER = 1 107 | cfg.DATA.TEMPLATE.SIZE = 128 108 | cfg.DATA.TEMPLATE.FACTOR = 2.0 109 | cfg.DATA.TEMPLATE.CENTER_JITTER = 0 110 | cfg.DATA.TEMPLATE.SCALE_JITTER = 0 111 | 112 | # TEST 113 | cfg.TEST = edict() 114 | cfg.TEST.TEMPLATE_FACTOR = 2.0 115 | cfg.TEST.TEMPLATE_SIZE = 128 116 | cfg.TEST.SEARCH_FACTOR = 5.0 117 | cfg.TEST.SEARCH_SIZE = 320 118 | cfg.TEST.EPOCH = 500 119 | cfg.TEST.TEMPLATE_NUMBER = 3 120 | 121 | cfg.TEST.TEST_PRE_NUM = edict() 122 | cfg.TEST.TEST_PRE_NUM.GTOT = 6 123 | cfg.TEST.TEST_PRE_NUM.RGBT210 = 6 124 | cfg.TEST.TEST_PRE_NUM.RGBT234 = 6 125 | cfg.TEST.TEST_PRE_NUM.LasHeR = 6 126 | 127 | 128 | def _edict2dict(dest_dict, src_edict): 129 | if isinstance(dest_dict, dict) and isinstance(src_edict, dict): 130 | for k, v in src_edict.items(): 131 | if not isinstance(v, edict): 132 | dest_dict[k] = v 133 | else: 134 | dest_dict[k] = {} 135 | _edict2dict(dest_dict[k], v) 136 | else: 137 | return 138 | 139 | 140 | def gen_config(config_file): 141 | cfg_dict = {} 142 | _edict2dict(cfg_dict, cfg) 143 | with open(config_file, 'w') as f: 144 | yaml.dump(cfg_dict, f, default_flow_style=False) 145 | 146 | 147 | def _update_config(base_cfg, exp_cfg): 148 | if isinstance(base_cfg, dict) and isinstance(exp_cfg, edict): 149 | for k, v in exp_cfg.items(): 150 | if k in base_cfg: 151 | if not isinstance(v, dict): 152 | base_cfg[k] = v 153 | else: 154 | _update_config(base_cfg[k], v) 155 | else: 156 | raise ValueError("{} not exist in config.py".format(k)) 157 | else: 158 | return 159 | 160 | 161 | def update_config_from_file(filename, base_cfg=None): 162 | exp_config = None 163 | with open(filename) as f: 164 | exp_config = edict(yaml.safe_load(f)) 165 | if base_cfg is not None: 166 | _update_config(base_cfg, exp_config) 167 | else: 168 | _update_config(cfg, exp_config) 169 | -------------------------------------------------------------------------------- /lib/test/evaluation/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | class EnvSettings: 6 | def __init__(self): 7 | test_path = 
os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 8 | 9 | self.results_path = '{}/tracking_results/'.format(test_path) 10 | self.segmentation_path = '{}/segmentation_results/'.format(test_path) 11 | self.network_path = '{}/networks/'.format(test_path) 12 | self.result_plot_path = '{}/result_plots/'.format(test_path) 13 | self.otb_path = '' 14 | self.nfs_path = '' 15 | self.uav_path = '' 16 | self.tpl_path = '' 17 | self.vot_path = '' 18 | self.got10k_path = '' 19 | self.lasot_path = '' 20 | self.trackingnet_path = '' 21 | self.davis_dir = '' 22 | self.youtubevos_dir = '' 23 | 24 | self.got_packed_results_path = '' 25 | self.got_reports_path = '' 26 | self.tn_packed_results_path = '' 27 | 28 | 29 | def create_default_local_file(): 30 | comment = {'results_path': 'Where to store tracking results', 31 | 'network_path': 'Where tracking networks are stored.'} 32 | 33 | path = os.path.join(os.path.dirname(__file__), 'local.py') 34 | with open(path, 'w') as f: 35 | settings = EnvSettings() 36 | 37 | f.write('from test.evaluation.environment import EnvSettings\n\n') 38 | f.write('def local_env_settings():\n') 39 | f.write(' settings = EnvSettings()\n\n') 40 | f.write(' # Set your local paths here.\n\n') 41 | 42 | for attr in dir(settings): 43 | comment_str = None 44 | if attr in comment: 45 | comment_str = comment[attr] 46 | attr_val = getattr(settings, attr) 47 | if not attr.startswith('__') and not callable(attr_val): 48 | if comment_str is None: 49 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 50 | else: 51 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 52 | f.write('\n return settings\n\n') 53 | 54 | 55 | class EnvSettings_ITP: 56 | def __init__(self, workspace_dir, data_dir, save_dir): 57 | self.prj_dir = workspace_dir 58 | self.save_dir = save_dir 59 | self.results_path = os.path.join(save_dir, 'test/tracking_results') 60 | self.segmentation_path = os.path.join(save_dir, 'test/segmentation_results') 61 | self.network_path = os.path.join(save_dir, 'test/networks') 62 | self.result_plot_path = os.path.join(save_dir, 'test/result_plots') 63 | self.otb_path = os.path.join(data_dir, 'otb') 64 | self.nfs_path = os.path.join(data_dir, 'nfs') 65 | self.uav_path = os.path.join(data_dir, 'uav') 66 | self.tc128_path = os.path.join(data_dir, 'TC128') 67 | self.tpl_path = '' 68 | self.vot_path = os.path.join(data_dir, 'VOT2019') 69 | self.got10k_path = os.path.join(data_dir, 'got10k') 70 | self.got10k_lmdb_path = os.path.join(data_dir, 'got10k_lmdb') 71 | self.lasot_path = os.path.join(data_dir, 'lasot') 72 | self.lasot_lmdb_path = os.path.join(data_dir, 'lasot_lmdb') 73 | self.trackingnet_path = os.path.join(data_dir, 'trackingnet') 74 | self.vot18_path = os.path.join(data_dir, 'vot2018') 75 | self.vot22_path = os.path.join(data_dir, 'vot2022') 76 | self.itb_path = os.path.join(data_dir, 'itb') 77 | self.tnl2k_path = os.path.join(data_dir, 'tnl2k') 78 | self.lasot_extension_subset_path_path = os.path.join(data_dir, 'lasot_extension_subset') 79 | self.davis_dir = '' 80 | self.youtubevos_dir = '' 81 | 82 | self.got_packed_results_path = '' 83 | self.got_reports_path = '' 84 | self.tn_packed_results_path = '' 85 | 86 | 87 | def create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir): 88 | comment = {'results_path': 'Where to store tracking results', 89 | 'network_path': 'Where tracking networks are stored.'} 90 | 91 | path = os.path.join(os.path.dirname(__file__), 'local.py') 92 | with open(path, 'w') as f: 93 | settings = 
EnvSettings_ITP(workspace_dir, data_dir, save_dir) 94 | 95 | f.write('from lib.test.evaluation.environment import EnvSettings\n\n') 96 | f.write('def local_env_settings():\n') 97 | f.write(' settings = EnvSettings()\n\n') 98 | f.write(' # Set your local paths here.\n\n') 99 | 100 | for attr in dir(settings): 101 | comment_str = None 102 | if attr in comment: 103 | comment_str = comment[attr] 104 | attr_val = getattr(settings, attr) 105 | if not attr.startswith('__') and not callable(attr_val): 106 | if comment_str is None: 107 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 108 | else: 109 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 110 | f.write('\n return settings\n\n') 111 | 112 | 113 | def env_settings(): 114 | env_module_name = 'lib.test.evaluation.local' 115 | try: 116 | env_module = importlib.import_module(env_module_name) 117 | return env_module.local_env_settings() 118 | except: 119 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 120 | 121 | # Create a default file 122 | create_default_local_file() 123 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. ' 124 | 'Then try to run again.'.format(env_file)) -------------------------------------------------------------------------------- /lib/train/run_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import importlib 6 | import cv2 as cv 7 | import torch.backends.cudnn 8 | import torch.distributed as dist 9 | 10 | import random 11 | import numpy as np 12 | torch.backends.cudnn.benchmark = False 13 | 14 | import _init_paths 15 | import lib.train.admin.settings as ws_settings 16 | 17 | 18 | def init_seeds(seed): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | 26 | 27 | def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None, 28 | use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False, 29 | distill=None, script_teacher=None, config_teacher=None): 30 | """Run the train script. 31 | args: 32 | script_name: Name of emperiment in the "experiments/" folder. 33 | config_name: Name of the yaml file in the "experiments/". 34 | cudnn_benchmark: Use cudnn benchmark or not (default is True). 35 | """ 36 | if save_dir is None: 37 | print("save_dir dir is not given. 
Use the default dir instead.") 38 | # This is needed to avoid strange crashes related to opencv 39 | cv.setNumThreads(0) 40 | 41 | torch.backends.cudnn.benchmark = cudnn_benchmark 42 | 43 | print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name)) 44 | 45 | '''2021.1.5 set seed for different process''' 46 | if base_seed is not None: 47 | if local_rank != -1: 48 | init_seeds(base_seed + local_rank) 49 | else: 50 | init_seeds(base_seed) 51 | 52 | settings = ws_settings.Settings() 53 | settings.script_name = script_name 54 | settings.config_name = config_name 55 | settings.project_path = 'train/{}/{}'.format(script_name, config_name) 56 | if script_name_prv is not None and config_name_prv is not None: 57 | settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv) 58 | settings.local_rank = local_rank 59 | settings.save_dir = os.path.abspath(save_dir) 60 | settings.use_lmdb = use_lmdb 61 | prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) 62 | settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name)) 63 | settings.use_wandb = use_wandb 64 | if distill: 65 | settings.distill = distill 66 | settings.script_teacher = script_teacher 67 | settings.config_teacher = config_teacher 68 | if script_teacher is not None and config_teacher is not None: 69 | settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher) 70 | settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher)) 71 | expr_module = importlib.import_module('lib.train.train_script_distill') 72 | else: 73 | expr_module = importlib.import_module('lib.train.train_script') 74 | expr_func = getattr(expr_module, 'run') 75 | 76 | expr_func(settings) 77 | 78 | 79 | def main(): 80 | parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.') 81 | parser.add_argument('--script', type=str, required=True, help='Name of the train script.') 82 | parser.add_argument('--config', type=str, required=True, help="Name of the config file.") 83 | parser.add_argument('--cudnn_benchmark', type=bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).') 84 | # parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') 85 | parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs') 86 | parser.add_argument('--seed', type=int, default=42, help='seed for random numbers') 87 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format 88 | parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.') 89 | parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.") 90 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb 91 | # for knowledge distillation 92 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation 93 | parser.add_argument('--script_teacher', type=str, help='teacher script name') 94 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name') 95 | 96 | # print(f"os.environ['LOCAL_RANK']: {os.environ['LOCAL_RANK']}") 97 | 98 | args = parser.parse_args() 99 | args.local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else -1 100 | 101 | if 
args.local_rank != -1: 102 | dist.init_process_group(backend='nccl') 103 | torch.cuda.set_device(args.local_rank) 104 | else: 105 | torch.cuda.set_device(0) 106 | run_training(args.script, args.config, cudnn_benchmark=args.cudnn_benchmark, 107 | local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed, 108 | use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv, 109 | use_wandb=args.use_wandb, 110 | distill=args.distill, script_teacher=args.script_teacher, config_teacher=args.config_teacher) 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | --------------------------------------------------------------------------------