├── LICENSE ├── README.md ├── RGBT_workspace └── test_rgbt_mgpus.py ├── eval_rgbt.sh ├── experiments └── bat │ ├── rgbe.yaml │ └── rgbt.yaml ├── install_bat.sh ├── lib ├── __init__.py ├── config │ └── bat │ │ └── config.py ├── models │ ├── __init__.py │ ├── bat │ │ ├── __init__.py │ │ ├── base_backbone.py │ │ ├── ostrack.py │ │ ├── ostrack_adapter.py │ │ ├── utils.py │ │ ├── vit.py │ │ ├── vit_ce.py │ │ └── vit_ce_adapter.py │ └── layers │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── attn.py │ │ ├── attn_adapt_blocks.py │ │ ├── attn_blocks.py │ │ ├── dualstream_attn_blocks.py │ │ ├── frozen_bn.py │ │ ├── head.py │ │ ├── max_head.py │ │ ├── patch_embed.py │ │ └── rpe.py ├── test │ ├── evaluation │ │ ├── __init__.py │ │ ├── data.py │ │ ├── datasets.py │ │ ├── environment.py │ │ ├── local.py │ │ ├── running.py │ │ ├── tracker.py │ │ ├── votdataset.py │ │ └── vtuavdataset.py │ ├── parameter │ │ ├── __init__.py │ │ └── bat.py │ ├── tracker │ │ ├── __init__.py │ │ ├── basetracker.py │ │ ├── bat.py │ │ ├── data_utils.py │ │ ├── ostrack.py │ │ └── vis_utils.py │ ├── utils │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── hann.py │ │ ├── load_text.py │ │ └── params.py │ └── vot │ │ ├── bat_baseline.py │ │ ├── bat_class.py │ │ ├── vot.py │ │ └── vot22_utils.py ├── train │ ├── __init__.py │ ├── _init_paths.py │ ├── actors │ │ ├── __init__.py │ │ ├── base_actor.py │ │ └── bat.py │ ├── admin │ │ ├── __init__.py │ │ ├── environment.py │ │ ├── local.py │ │ ├── multigpu.py │ │ ├── settings.py │ │ ├── stats.py │ │ └── tensorboard.py │ ├── base_functions.py │ ├── data │ │ ├── __init__.py │ │ ├── bounding_box_utils.py │ │ ├── image_loader.py │ │ ├── loader.py │ │ ├── processing.py │ │ ├── processing_utils.py │ │ ├── sampler.py │ │ └── transforms.py │ ├── data_specs │ │ ├── README.md │ │ ├── depthtrack_train.txt │ │ ├── depthtrack_val.txt │ │ ├── got10k_train_full_split.txt │ │ ├── got10k_train_split.txt │ │ ├── got10k_val_split.txt │ │ ├── got10k_vot_exclude.txt │ │ ├── got10k_vot_train_split.txt │ │ ├── got10k_vot_val_split.txt │ │ ├── lasher_all.txt │ │ ├── lasher_train.txt │ │ ├── lasher_val.txt │ │ ├── lasot_train_split.txt │ │ └── trackingnet_classmap.txt │ ├── dataset │ │ ├── COCO_tool.py │ │ ├── __init__.py │ │ ├── base_image_dataset.py │ │ ├── base_video_dataset.py │ │ ├── coco.py │ │ ├── coco_seq.py │ │ ├── coco_seq_lmdb.py │ │ ├── depth_utils.py │ │ ├── depthtrack.py │ │ ├── got10k.py │ │ ├── got10k_lmdb.py │ │ ├── imagenetvid.py │ │ ├── imagenetvid_lmdb.py │ │ ├── lasher.py │ │ ├── lasot.py │ │ ├── lasot_lmdb.py │ │ ├── tracking_net.py │ │ ├── tracking_net_lmdb.py │ │ └── visevent.py │ ├── run_training.py │ ├── train_script.py │ └── trainers │ │ ├── __init__.py │ │ ├── base_trainer.py │ │ └── ltr_trainer.py ├── utils │ ├── __init__.py │ ├── box_ops.py │ ├── ce_utils.py │ ├── focal_loss.py │ ├── heapmap_utils.py │ ├── lmdb_utils.py │ ├── merge.py │ ├── misc.py │ └── tensor.py └── vis │ ├── __init__.py │ ├── plotting.py │ ├── utils.py │ └── visdom_cus.py ├── main.py ├── tracking ├── __pycache__ │ └── _init_paths.cpython-37.pyc ├── _init_paths.py ├── create_default_local_file.py └── train.py └── train_bat.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SparkTempest 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | 
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bi-directional Adapter for Multimodal Tracking 2 | The official implementation of the AAAI 2024 paper [**Bi-directional Adapter for Multimodal Tracking**](https://arxiv.org/abs/2312.10611). 3 | 4 | 5 | 6 | ## Models 7 | 8 | [Models & Raw Results](https://pan.baidu.com/s/1Fcv2BX2HTb8M8u2IRJ75aQ?pwd=ak66) 9 | (Baidu Drive, extraction code: ak66) 10 | 11 | [Models & Raw Results](https://drive.google.com/drive/folders/1l8j8Ns8dGyrKrFrmetHPdqKPO0wNrZ1n?usp=sharing) 12 | (Google Drive) 13 | 14 | 15 | ## Usage 16 | ### Installation 17 | Create and activate a conda environment: 18 | ``` 19 | conda create -n bat python=3.7 20 | conda activate bat 21 | ``` 22 | Install the required packages: 23 | ``` 24 | bash install_bat.sh 25 | ``` 26 | 27 | ### Data Preparation 28 | Download the training datasets. The directory structure should look like: 29 | ``` 30 | $<PATH_of_Datasets> 31 | -- LasHeR/TrainingSet 32 | |-- 1boygo 33 | |-- 1handsth 34 | ... 35 | -- VisEvent/train 36 | |-- 00142_tank_outdoor2 37 | |-- 00143_tank_outdoor2 38 | ... 39 | |-- trainlist.txt 40 | ``` 41 | 42 | ### Path Setting 43 | Run the following command to set paths: 44 | ``` 45 | cd <PATH_of_BAT> 46 | python tracking/create_default_local_file.py --workspace_dir . --data_dir <PATH_of_Datasets> --save_dir ./output 47 | ``` 48 | You can also modify the paths directly in these two files: 49 | ``` 50 | ./lib/train/admin/local.py # paths for training 51 | ./lib/test/evaluation/local.py # paths for testing 52 | ``` 53 | 54 | ### Training 55 | Download the pretrained [foundation model](https://pan.baidu.com/s/1JX7xUlr-XutcsDsOeATU1A?pwd=4lvo) (OSTrack) (Baidu Drive, extraction code: 4lvo) / [foundation model](https://drive.google.com/file/d/1WSkrdJu3OEBekoRz8qnDpnvEXhdr7Oec/view?usp=sharing) (Google Drive) 56 | and put it under ./pretrained/, then run: 57 | ``` 58 | bash train_bat.sh 59 | ``` 60 | You can train models with various modalities and variants by modifying ```train_bat.sh```. 61 | 62 | ### Testing 63 | 64 | #### For RGB-T benchmarks 65 | [LasHeR & RGBT234] \ 66 | Modify the dataset path and save path in ```./RGBT_workspace/test_rgbt_mgpus.py```, then run: 67 | ``` 68 | bash eval_rgbt.sh 69 | ``` 70 | We recommend the [LasHeR Toolkit](https://github.com/BUGPLEASEOUT/LasHeR) for LasHeR evaluation, 71 | and [MPR_MSR_Evaluation](https://sites.google.com/view/ahutracking001/) for RGBT234 evaluation.
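Each RGB-T benchmark corresponds to one command in ```eval_rgbt.sh```; for example, its LasHeR entry (commented out by default) is:
```
CUDA_VISIBLE_DEVICES=1 NCCL_P2P_LEVEL=NVL python ./RGBT_workspace/test_rgbt_mgpus.py --script_name bat --dataset_name LasHeR --yaml_name rgbt
```
Comment or uncomment the corresponding lines in ```eval_rgbt.sh``` to choose which benchmark to evaluate.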
72 | 73 | 74 | #### For RGB-E benchmark 75 | [VisEvent]\ 76 | Modify the dataset path and save path in ```./RGBE_workspace/test_rgbe_mgpus.py```, then run: 77 | ``` 78 | bash eval_rgbe.sh 79 | ``` 80 | We recommend [VisEvent_SOT_Benchmark](https://github.com/wangxiao5791509/VisEvent_SOT_Benchmark) for evaluation. 81 | 82 | ## Citation 83 | Please cite our work if you find it useful for your research. 84 | 85 | ```bibtex 86 | @inproceedings{BAT, 87 | title={Bi-directional Adapter for Multimodal Tracking}, 88 | author={Bing Cao and Junliang Guo and Pengfei Zhu and Qinghua Hu}, 89 | booktitle={AAAI Conference on Artificial Intelligence}, 90 | year={2024} 91 | } 92 | ``` 93 | 94 | 95 | 96 | 97 | 98 | ## Acknowledgment 99 | - This repo is based on [ViPT](https://github.com/jiawen-zhu/ViPT), an excellent work that helped us implement our ideas quickly. 100 | - Thanks to the [OSTrack](https://github.com/botaoye/OSTrack) and [PyTracking](https://github.com/visionml/pytracking) libraries. 101 | 102 | -------------------------------------------------------------------------------- /eval_rgbt.sh: -------------------------------------------------------------------------------- 1 | # test lasher 2 | #CUDA_VISIBLE_DEVICES=1 NCCL_P2P_LEVEL=NVL python ./RGBT_workspace/test_rgbt_mgpus.py --script_name bat --dataset_name LasHeR --yaml_name rgbt 3 | 4 | # test rgbt234 5 | CUDA_VISIBLE_DEVICES=0 NCCL_P2P_LEVEL=NVL python ./RGBT_workspace/test_rgbt_mgpus.py --script_name bat --dataset_name RGBT234 --yaml_name rgbt 6 | 7 | 8 | 9 | #CUDA_VISIBLE_DEVICES=0 NCCL_P2P_LEVEL=NVL python ./RGBT_workspace/test_rgbt_mgpus.py --script_name bat --dataset_name DroneT --yaml_name rgbt 10 | 11 | 12 | #CUDA_VISIBLE_DEVICES=0,1,2,3 NCCL_P2P_LEVEL=NVL python ./RGBT_workspace/test_rgbt_mgpus.py --script_name bat --dataset_name VTUAVST --yaml_name rgbt -------------------------------------------------------------------------------- /experiments/bat/rgbe.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | TRAIN: 23 | DATASETS_NAME: 24 | - VisEvent 25 | DATASETS_RATIO: 26 | - 1 27 | SAMPLE_PER_EPOCH: 60000 28 | VAL: 29 | DATASETS_NAME: 30 | - 31 | DATASETS_RATIO: 32 | - 1 33 | SAMPLE_PER_EPOCH: 10000 34 | MODEL: 35 | PRETRAIN_FILE: "./pretrained/OSTrack_ep0300.pth.tar" 36 | EXTRA_MERGER: False 37 | RETURN_INTER: False 38 | BACKBONE: 39 | TYPE: vit_base_patch16_224_ce_adapter 40 | STRIDE: 16 41 | CE_LOC: [3, 6, 9] 42 | CE_KEEP_RATIO: [1, 1, 1] 43 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 44 | HEAD: 45 | TYPE: CENTER 46 | NUM_CHANNELS: 256 47 | TRAIN: 48 | BACKBONE_MULTIPLIER: 0.1 49 | DROP_PATH_RATE: 0.1 50 | CE_START_EPOCH: 4 # candidate elimination start epoch 1/15 51 | CE_WARM_EPOCH: 16 # candidate elimination warm up epoch 4/15 52 | BATCH_SIZE: 32 53 | EPOCH: 60 54 | GIOU_WEIGHT: 2.0 55 | L1_WEIGHT: 5.0 56 | GRAD_CLIP_NORM: 0.1 57 | LR: 0.0004 58 | LR_DROP_EPOCH: 48 # 4/5 59 | NUM_WORKER: 10 60 | OPTIMIZER: ADAMW 61 | PRINT_INTERVAL: 50 62 | SCHEDULER: 63 | TYPE: step 64 | DECAY_RATE: 0.1 65 | VAL_EPOCH_INTERVAL: 5 66 | WEIGHT_DECAY: 0.0001 67 | AMP: False 68 | PROMPT: 69 | TYPE: bat 70 | FIX_BN: true 71 | SAVE_EPOCH_INTERVAL: 5 72 |
SAVE_LAST_N_EPOCH: 1 73 | TEST: 74 | EPOCH: 60 75 | SEARCH_FACTOR: 4.0 76 | SEARCH_SIZE: 256 77 | TEMPLATE_FACTOR: 2.0 78 | TEMPLATE_SIZE: 128 79 | -------------------------------------------------------------------------------- /experiments/bat/rgbt.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | TRAIN: 23 | DATASETS_NAME: 24 | - LasHeR_train 25 | DATASETS_RATIO: 26 | - 1 27 | SAMPLE_PER_EPOCH: 60000 28 | VAL: 29 | DATASETS_NAME: 30 | - LasHeR_val 31 | DATASETS_RATIO: 32 | - 1 33 | SAMPLE_PER_EPOCH: 10000 34 | MODEL: 35 | PRETRAIN_FILE: "./pretrained/OSTrack_ep0300.pth.tar" 36 | EXTRA_MERGER: False 37 | RETURN_INTER: False 38 | BACKBONE: 39 | TYPE: vit_base_patch16_224_ce_adapter 40 | STRIDE: 16 41 | CE_LOC: [3, 6, 9] 42 | CE_KEEP_RATIO: [1, 1, 1] #[0.7, 0.7, 0.7] 43 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 44 | HEAD: 45 | TYPE: CENTER 46 | NUM_CHANNELS: 256 47 | TRAIN: 48 | BACKBONE_MULTIPLIER: 0.1 49 | DROP_PATH_RATE: 0.1 50 | CE_START_EPOCH: 4 # candidate elimination start epoch 1/15 ######################################## 51 | CE_WARM_EPOCH: 16 # candidate elimination warm up epoch 4/15 52 | BATCH_SIZE: 32 53 | EPOCH: 60 54 | GIOU_WEIGHT: 2.0 55 | L1_WEIGHT: 5.0 56 | GRAD_CLIP_NORM: 0.1 57 | LR: 0.0004 58 | LR_DROP_EPOCH: 48 # 4/5 10 59 | NUM_WORKER: 10 60 | OPTIMIZER: ADAMW 61 | PRINT_INTERVAL: 50 62 | SCHEDULER: 63 | TYPE: step 64 | DECAY_RATE: 0.1 65 | VAL_EPOCH_INTERVAL: 5 66 | WEIGHT_DECAY: 0.0001 67 | AMP: False 68 | PROMPT: 69 | TYPE: bat #bat_12 70 | FIX_BN: true #true #false #============================= 71 | SAVE_EPOCH_INTERVAL: 5 72 | SAVE_LAST_N_EPOCH: 1 73 | TEST: 74 | EPOCH: 60 75 | SEARCH_FACTOR: 4.0 76 | SEARCH_SIZE: 256 77 | TEMPLATE_FACTOR: 2.0 78 | TEMPLATE_SIZE: 128 79 | -------------------------------------------------------------------------------- /install_bat.sh: -------------------------------------------------------------------------------- 1 | echo "****************** Installing pytorch ******************" 2 | #conda install -y pytorch==1.7.0 torchvision==0.8.1 cudatoolkit=10.2 -c pytorch 3 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia 4 | 5 | echo "" 6 | echo "" 7 | echo "****************** Installing yaml ******************" 8 | pip install PyYAML 9 | 10 | echo "" 11 | echo "" 12 | echo "****************** Installing easydict ******************" 13 | pip install easydict 14 | 15 | echo "" 16 | echo "" 17 | echo "****************** Installing cython ******************" 18 | pip install cython 19 | 20 | echo "" 21 | echo "" 22 | echo "****************** Installing opencv-python ******************" 23 | pip install opencv-python 24 | 25 | echo "" 26 | echo "" 27 | echo "****************** Installing pandas ******************" 28 | pip install pandas 29 | 30 | echo "" 31 | echo "" 32 | echo "****************** Installing tqdm ******************" 33 | conda install -y tqdm 34 | 35 | echo "" 36 | echo "" 37 | echo "****************** Installing coco toolkit ******************" 38 | pip install pycocotools 39 | 40 | echo "" 41 | echo "" 42 | echo "****************** Installing jpeg4py python 
wrapper ******************" 43 | apt-get install libturbojpeg 44 | pip install jpeg4py 45 | 46 | echo "" 47 | echo "" 48 | echo "****************** Installing scipy ******************" 49 | pip install scipy 50 | 51 | echo "" 52 | echo "" 53 | echo "****************** Installing timm ******************" 54 | pip install timm==0.5.4 55 | 56 | echo "" 57 | echo "" 58 | echo "****************** Installing tensorboard ******************" 59 | pip install tb-nightly 60 | 61 | echo "" 62 | echo "" 63 | echo "****************** Installing lmdb ******************" 64 | pip install lmdb 65 | 66 | echo "" 67 | echo "" 68 | echo "****************** Installing visdom ******************" 69 | pip install visdom 70 | 71 | echo "" 72 | echo "" 73 | echo "****************** Installing vot-toolkit python ******************" 74 | # Hi~ We employ the vot-toolkit==0.5.3 with vot-trax==3.0.3 75 | pip install git+https://github.com/votchallenge/vot-toolkit-python 76 | 77 | echo "****************** Installation complete! ******************" -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkTempest/BAT/ccf9b2f6ae3e810f4e7318c9d0b62083deb7ec89/lib/__init__.py -------------------------------------------------------------------------------- /lib/config/bat/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import yaml 3 | 4 | """ 5 | Add default config for BAT. 6 | """ 7 | cfg = edict() 8 | 9 | # MODEL 10 | cfg.MODEL = edict() 11 | cfg.MODEL.PRETRAIN_FILE = "" 12 | cfg.MODEL.EXTRA_MERGER = False 13 | cfg.MODEL.RETURN_INTER = False 14 | cfg.MODEL.RETURN_STAGES = [] 15 | 16 | # MODEL.BACKBONE 17 | cfg.MODEL.BACKBONE = edict() 18 | cfg.MODEL.BACKBONE.TYPE = "vit_base_patch16_224" 19 | cfg.MODEL.BACKBONE.STRIDE = 16 20 | cfg.MODEL.BACKBONE.MID_PE = False 21 | cfg.MODEL.BACKBONE.SEP_SEG = False 22 | cfg.MODEL.BACKBONE.CAT_MODE = 'direct' 23 | cfg.MODEL.BACKBONE.MERGE_LAYER = 0 24 | cfg.MODEL.BACKBONE.ADD_CLS_TOKEN = False 25 | cfg.MODEL.BACKBONE.CLS_TOKEN_USE_MODE = 'ignore' 26 | 27 | cfg.MODEL.BACKBONE.CE_LOC = [] 28 | cfg.MODEL.BACKBONE.CE_KEEP_RATIO = [] 29 | cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'ALL' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX 30 | 31 | # MODEL.HEAD 32 | cfg.MODEL.HEAD = edict() 33 | cfg.MODEL.HEAD.TYPE = "CENTER" 34 | cfg.MODEL.HEAD.NUM_CHANNELS = 256 35 | 36 | # TRAIN 37 | cfg.TRAIN = edict() 38 | cfg.TRAIN.PROMPT = edict() 39 | cfg.TRAIN.PROMPT.TYPE = 'bat' # bat_12 40 | cfg.TRAIN.LR = 0.0001 41 | cfg.TRAIN.WEIGHT_DECAY = 0.0001 42 | cfg.TRAIN.EPOCH = 500 43 | cfg.TRAIN.LR_DROP_EPOCH = 400 44 | cfg.TRAIN.BATCH_SIZE = 16 45 | cfg.TRAIN.NUM_WORKER = 8 46 | cfg.TRAIN.OPTIMIZER = "ADAMW" 47 | cfg.TRAIN.BACKBONE_MULTIPLIER = 0.1 48 | cfg.TRAIN.GIOU_WEIGHT = 2.0 49 | cfg.TRAIN.L1_WEIGHT = 5.0 50 | cfg.TRAIN.FREEZE_LAYERS = [0, ] 51 | cfg.TRAIN.PRINT_INTERVAL = 50 52 | cfg.TRAIN.VAL_EPOCH_INTERVAL = 20 53 | cfg.TRAIN.GRAD_CLIP_NORM = 0.1 54 | cfg.TRAIN.AMP = False 55 | ## TRAIN save cfgs 56 | cfg.TRAIN.FIX_BN = True ###### 57 | cfg.TRAIN.SAVE_EPOCH_INTERVAL = 1 # 1 means save model each epoch 58 | cfg.TRAIN.SAVE_LAST_N_EPOCH = 1 # besides, last n epoch model will be saved 59 | 60 | cfg.TRAIN.CE_START_EPOCH = 20 # candidate elimination start epoch 61 | cfg.TRAIN.CE_WARM_EPOCH = 80 # candidate elimination warm up epoch 62 | 
cfg.TRAIN.DROP_PATH_RATE = 0.1 # drop path rate for ViT backbone 63 | 64 | # TRAIN.SCHEDULER 65 | cfg.TRAIN.SCHEDULER = edict() 66 | cfg.TRAIN.SCHEDULER.TYPE = "step" 67 | cfg.TRAIN.SCHEDULER.DECAY_RATE = 0.1 68 | 69 | # DATA 70 | cfg.DATA = edict() 71 | cfg.DATA.SAMPLER_MODE = "causal" # sampling methods 72 | cfg.DATA.MEAN = [0.485, 0.456, 0.406] 73 | cfg.DATA.STD = [0.229, 0.224, 0.225] 74 | cfg.DATA.MAX_SAMPLE_INTERVAL = 200 75 | # DATA.TRAIN 76 | cfg.DATA.TRAIN = edict() 77 | cfg.DATA.TRAIN.DATASETS_NAME = ["LASOT", "GOT10K_vottrain"] 78 | cfg.DATA.TRAIN.DATASETS_RATIO = [1, 1] 79 | cfg.DATA.TRAIN.SAMPLE_PER_EPOCH = 60000 80 | # DATA.VAL 81 | cfg.DATA.VAL = edict() 82 | cfg.DATA.VAL.DATASETS_NAME = [] 83 | cfg.DATA.VAL.DATASETS_RATIO = [1] 84 | cfg.DATA.VAL.SAMPLE_PER_EPOCH = 10000 85 | # DATA.SEARCH 86 | cfg.DATA.SEARCH = edict() 87 | cfg.DATA.SEARCH.SIZE = 320 88 | cfg.DATA.SEARCH.FACTOR = 5.0 89 | cfg.DATA.SEARCH.CENTER_JITTER = 4.5 90 | cfg.DATA.SEARCH.SCALE_JITTER = 0.5 91 | cfg.DATA.SEARCH.NUMBER = 1 92 | # DATA.TEMPLATE 93 | cfg.DATA.TEMPLATE = edict() 94 | cfg.DATA.TEMPLATE.NUMBER = 1 95 | cfg.DATA.TEMPLATE.SIZE = 128 96 | cfg.DATA.TEMPLATE.FACTOR = 2.0 97 | cfg.DATA.TEMPLATE.CENTER_JITTER = 0 98 | cfg.DATA.TEMPLATE.SCALE_JITTER = 0 99 | 100 | # TEST 101 | cfg.TEST = edict() 102 | cfg.TEST.TEMPLATE_FACTOR = 2.0 103 | cfg.TEST.TEMPLATE_SIZE = 128 104 | cfg.TEST.SEARCH_FACTOR = 5.0 105 | cfg.TEST.SEARCH_SIZE = 320 106 | cfg.TEST.EPOCH = 500 107 | 108 | 109 | def _edict2dict(dest_dict, src_edict): 110 | if isinstance(dest_dict, dict) and isinstance(src_edict, dict): 111 | for k, v in src_edict.items(): 112 | if not isinstance(v, edict): 113 | dest_dict[k] = v 114 | else: 115 | dest_dict[k] = {} 116 | _edict2dict(dest_dict[k], v) 117 | else: 118 | return 119 | 120 | 121 | def gen_config(config_file): 122 | cfg_dict = {} 123 | _edict2dict(cfg_dict, cfg) 124 | with open(config_file, 'w') as f: 125 | yaml.dump(cfg_dict, f, default_flow_style=False) 126 | 127 | 128 | def _update_config(base_cfg, exp_cfg): 129 | if isinstance(base_cfg, dict) and isinstance(exp_cfg, edict): 130 | for k, v in exp_cfg.items(): 131 | if k in base_cfg: 132 | if not isinstance(v, dict): 133 | base_cfg[k] = v 134 | else: 135 | _update_config(base_cfg[k], v) 136 | else: 137 | raise ValueError("{} not exist in config.py".format(k)) 138 | else: 139 | return 140 | 141 | 142 | def update_config_from_file(filename, base_cfg=None): 143 | exp_config = None 144 | with open(filename) as f: 145 | exp_config = edict(yaml.safe_load(f)) 146 | if base_cfg is not None: 147 | _update_config(base_cfg, exp_config) 148 | else: 149 | _update_config(cfg, exp_config) 150 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .bat.ostrack_adapter import build_batrack -------------------------------------------------------------------------------- /lib/models/bat/__init__.py: -------------------------------------------------------------------------------- 1 | from .ostrack import build_ostrack 2 | from .ostrack_adapter import build_batrack -------------------------------------------------------------------------------- /lib/models/bat/ostrack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic OSTrack model. 
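A ViT backbone (optionally with candidate elimination) jointly encodes the template and search region, and a corner/center box head decodes the target box from the search-region tokens.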
3 | """ 4 | import math 5 | import os 6 | from typing import List 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn.modules.transformer import _get_clones 11 | 12 | from lib.models.layers.head import build_box_head 13 | from lib.models.bat.vit import vit_base_patch16_224 14 | from lib.models.bat.vit_ce import vit_large_patch16_224_ce, vit_base_patch16_224_ce 15 | from lib.utils.box_ops import box_xyxy_to_cxcywh 16 | 17 | 18 | class OSTrack(nn.Module): 19 | """ This is the base class for OSTrack """ 20 | 21 | def __init__(self, transformer, box_head, aux_loss=False, head_type="CORNER"): 22 | """ Initializes the model. 23 | Parameters: 24 | transformer: torch module of the transformer architecture. 25 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 26 | """ 27 | super().__init__() 28 | self.backbone = transformer 29 | self.box_head = box_head 30 | 31 | self.aux_loss = aux_loss 32 | self.head_type = head_type 33 | if head_type == "CORNER" or head_type == "CENTER": 34 | self.feat_sz_s = int(box_head.feat_sz) 35 | self.feat_len_s = int(box_head.feat_sz ** 2) 36 | 37 | if self.aux_loss: 38 | self.box_head = _get_clones(self.box_head, 6) 39 | 40 | def forward(self, template: torch.Tensor, 41 | search: torch.Tensor, 42 | ce_template_mask=None, 43 | ce_keep_rate=None, 44 | return_last_attn=False, 45 | ): 46 | x, aux_dict = self.backbone(z=template, x=search, 47 | ce_template_mask=ce_template_mask, 48 | ce_keep_rate=ce_keep_rate, 49 | return_last_attn=return_last_attn, ) 50 | 51 | # Forward head 52 | feat_last = x 53 | if isinstance(x, list): 54 | feat_last = x[-1] 55 | out = self.forward_head(feat_last, None) 56 | 57 | out.update(aux_dict) 58 | out['backbone_feat'] = x 59 | return out 60 | 61 | def forward_head(self, cat_feature, gt_score_map=None): 62 | """ 63 | cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) 64 | """ 65 | enc_opt = cat_feature[:, -self.feat_len_s:] # encoder output for the search region (B, HW, C) 66 | opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() 67 | bs, Nq, C, HW = opt.size() 68 | opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) 69 | 70 | if self.head_type == "CORNER": 71 | # run the corner head 72 | pred_box, score_map = self.box_head(opt_feat, True) 73 | outputs_coord = box_xyxy_to_cxcywh(pred_box) 74 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 75 | out = {'pred_boxes': outputs_coord_new, 76 | 'score_map': score_map, 77 | } 78 | return out 79 | 80 | elif self.head_type == "CENTER": 81 | # run the center head 82 | score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map) 83 | # outputs_coord = box_xyxy_to_cxcywh(bbox) 84 | outputs_coord = bbox 85 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 86 | out = {'pred_boxes': outputs_coord_new, 87 | 'score_map': score_map_ctr, 88 | 'size_map': size_map, 89 | 'offset_map': offset_map} 90 | return out 91 | else: 92 | raise NotImplementedError 93 | 94 | 95 | def build_ostrack(cfg, training=True): 96 | current_dir = os.path.dirname(os.path.abspath(__file__)) # This is your Project Root 97 | pretrained_path = os.path.join(current_dir, '../../../pretrained_models') 98 | if cfg.MODEL.PRETRAIN_FILE and ('OSTrack' not in cfg.MODEL.PRETRAIN_FILE) and training: 99 | pretrained = os.path.join(pretrained_path, cfg.MODEL.PRETRAIN_FILE) 100 | else: 101 | pretrained = '' 102 | 103 | if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224': 104 | backbone = vit_base_patch16_224(pretrained, 
drop_path_rate=cfg.TRAIN.DROP_PATH_RATE) 105 | hidden_dim = backbone.embed_dim 106 | patch_start_index = 1 107 | 108 | elif cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce': 109 | backbone = vit_base_patch16_224_ce(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE, 110 | ce_loc=cfg.MODEL.BACKBONE.CE_LOC, 111 | ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO, 112 | ) 113 | hidden_dim = backbone.embed_dim 114 | patch_start_index = 1 115 | 116 | elif cfg.MODEL.BACKBONE.TYPE == 'vit_large_patch16_224_ce': 117 | backbone = vit_large_patch16_224_ce(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE, 118 | ce_loc=cfg.MODEL.BACKBONE.CE_LOC, 119 | ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO, 120 | ) 121 | 122 | hidden_dim = backbone.embed_dim 123 | patch_start_index = 1 124 | 125 | else: 126 | raise NotImplementedError 127 | 128 | backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) 129 | 130 | box_head = build_box_head(cfg, hidden_dim) 131 | 132 | model = OSTrack( 133 | backbone, 134 | box_head, 135 | aux_loss=False, 136 | head_type=cfg.MODEL.HEAD.TYPE, 137 | ) 138 | 139 | if 'OSTrack' in cfg.MODEL.PRETRAIN_FILE and training: 140 | checkpoint = torch.load(cfg.MODEL.PRETRAIN_FILE, map_location="cpu") 141 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["net"], strict=False) 142 | print('Load pretrained model from: ' + cfg.MODEL.PRETRAIN_FILE) 143 | 144 | return model 145 | -------------------------------------------------------------------------------- /lib/models/bat/ostrack_adapter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic BAT model. 3 | """ 4 | import math 5 | import os 6 | from typing import List 7 | from timm.models.layers import to_2tuple 8 | import torch 9 | from torch import nn 10 | from torch.nn.modules.transformer import _get_clones 11 | from lib.models.layers.head import build_box_head 12 | #from lib.models.bat.vit_adapter import vit_base_patch16_224_adapter 13 | from lib.models.bat.vit_ce_adapter import vit_base_patch16_224_ce_adapter 14 | from lib.utils.box_ops import box_xyxy_to_cxcywh 15 | 16 | 17 | class BATrack(nn.Module): 18 | """ This is the base class for BATrack """ 19 | 20 | def __init__(self, transformer, box_head, aux_loss=False, head_type="CORNER"): 21 | """ Initializes the model. 22 | Parameters: 23 | transformer: torch module of the transformer architecture. 24 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
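box_head: box prediction head applied to the search-region features.
head_type: "CORNER" or "CENTER"; selects how the box head outputs are decoded in forward_head.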
25 | """ 26 | super().__init__() 27 | self.backbone = transformer 28 | self.box_head = box_head 29 | 30 | self.aux_loss = aux_loss 31 | self.head_type = head_type 32 | if head_type == "CORNER" or head_type == "CENTER": 33 | self.feat_sz_s = int(box_head.feat_sz) 34 | self.feat_len_s = int(box_head.feat_sz ** 2) 35 | 36 | if self.aux_loss: 37 | self.box_head = _get_clones(self.box_head, 6) 38 | 39 | def forward(self, template: torch.Tensor, 40 | search: torch.Tensor, 41 | ce_template_mask=None, 42 | ce_keep_rate=None, 43 | return_last_attn=False, 44 | ): 45 | x, aux_dict = self.backbone(z=template, x=search, 46 | ce_template_mask=ce_template_mask, 47 | ce_keep_rate=ce_keep_rate, 48 | return_last_attn=return_last_attn, ) 49 | 50 | # Forward head 51 | feat_last = x 52 | if isinstance(x, list): 53 | feat_last = x[-1] 54 | out = self.forward_head(feat_last, None) 55 | 56 | out.update(aux_dict) 57 | out['backbone_feat'] = x 58 | return out 59 | 60 | def forward_head(self, cat_feature, gt_score_map=None): 61 | """ 62 | cat_feature: output embeddings of the backbone, it can be (HW1+HW2, B, C) or (HW2, B, C) 63 | """ 64 | #print("cat_feature",cat_feature.shape) 65 | enc_opt = cat_feature[:, -self.feat_len_s:] # encoder output for the search region (B, HW, C) 66 | opt = (enc_opt.unsqueeze(-1)).permute((0, 3, 2, 1)).contiguous() 67 | bs, Nq, C, HW = opt.size() 68 | opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s) 69 | #print("opt_feat", opt_feat.shape) 70 | 71 | if self.head_type == "CORNER": 72 | # run the corner head 73 | pred_box, score_map = self.box_head(opt_feat, True) 74 | outputs_coord = box_xyxy_to_cxcywh(pred_box) 75 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 76 | out = {'pred_boxes': outputs_coord_new, 77 | 'score_map': score_map, 78 | } 79 | return out 80 | 81 | elif self.head_type == "CENTER": 82 | # run the center head 83 | score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map) 84 | # outputs_coord = box_xyxy_to_cxcywh(bbox) 85 | outputs_coord = bbox 86 | #print("outputs_coord", outputs_coord.shape) 87 | outputs_coord_new = outputs_coord.view(bs, Nq, 4) 88 | out = {'pred_boxes': outputs_coord_new, 89 | 'score_map': score_map_ctr, 90 | 'size_map': size_map, 91 | 'offset_map': offset_map} 92 | return out 93 | else: 94 | raise NotImplementedError 95 | 96 | 97 | def build_batrack(cfg, training=True): 98 | current_dir = os.path.dirname(os.path.abspath(__file__)) # This is your Project Root 99 | pretrained_path = os.path.join(current_dir, '../../../pretrained_models') # use pretrained OSTrack as initialization 100 | if cfg.MODEL.PRETRAIN_FILE and ('OSTrack' not in cfg.MODEL.PRETRAIN_FILE) and training: 101 | pretrained = os.path.join(pretrained_path, cfg.MODEL.PRETRAIN_FILE) 102 | else: 103 | pretrained = '' 104 | 105 | if cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_adapter': 106 | backbone = vit_base_patch16_224_adapter(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE, 107 | search_size=to_2tuple(cfg.DATA.SEARCH.SIZE), 108 | template_size=to_2tuple(cfg.DATA.TEMPLATE.SIZE), 109 | new_patch_size=cfg.MODEL.BACKBONE.STRIDE, 110 | adapter_type=cfg.TRAIN.PROMPT.TYPE 111 | ) 112 | hidden_dim = backbone.embed_dim 113 | patch_start_index = 1 114 | 115 | elif cfg.MODEL.BACKBONE.TYPE == 'vit_base_patch16_224_ce_adapter': 116 | backbone = vit_base_patch16_224_ce_adapter(pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE, 117 | ce_loc=cfg.MODEL.BACKBONE.CE_LOC, 118 | ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO, 119 | 
search_size=to_2tuple(cfg.DATA.SEARCH.SIZE), 120 | template_size=to_2tuple(cfg.DATA.TEMPLATE.SIZE), 121 | new_patch_size=cfg.MODEL.BACKBONE.STRIDE, 122 | adapter_type=cfg.TRAIN.PROMPT.TYPE 123 | ) 124 | hidden_dim = backbone.embed_dim 125 | patch_start_index = 1 126 | 127 | else: 128 | raise NotImplementedError 129 | """For adapter no need, because we have OSTrack as initialization""" 130 | # backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index) 131 | 132 | box_head = build_box_head(cfg, hidden_dim) 133 | 134 | model = BATrack( 135 | backbone, 136 | box_head, 137 | aux_loss=False, 138 | head_type=cfg.MODEL.HEAD.TYPE, 139 | ) 140 | 141 | if 'OSTrack' in cfg.MODEL.PRETRAIN_FILE and training: 142 | checkpoint = torch.load(cfg.MODEL.PRETRAIN_FILE, map_location="cpu") 143 | missing_keys, unexpected_keys = model.load_state_dict(checkpoint["net"], strict=False) 144 | print('Load pretrained model from: ' + cfg.MODEL.PRETRAIN_FILE) 145 | #print(f"missing_keys: {missing_keys}") 146 | #print(f"unexpected_keys: {unexpected_keys}") 147 | 148 | return model 149 | -------------------------------------------------------------------------------- /lib/models/bat/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def combine_tokens(template_tokens, search_tokens, mode='direct', return_res=False): 8 | # [B, HW, C] 9 | len_t = template_tokens.shape[1] 10 | len_s = search_tokens.shape[1] 11 | 12 | if mode == 'direct': 13 | merged_feature = torch.cat((template_tokens, search_tokens), dim=1) 14 | elif mode == 'template_central': 15 | central_pivot = len_s // 2 16 | first_half = search_tokens[:, :central_pivot, :] 17 | second_half = search_tokens[:, central_pivot:, :] 18 | merged_feature = torch.cat((first_half, template_tokens, second_half), dim=1) 19 | elif mode == 'partition': 20 | feat_size_s = int(math.sqrt(len_s)) 21 | feat_size_t = int(math.sqrt(len_t)) 22 | window_size = math.ceil(feat_size_t / 2.) 
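# 'partition' mode: pad the template feature map to a multiple of window_size, split it into two half-height strips placed side by side, then prepend the result to the search tokens.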
23 | # pad feature maps to multiples of window size 24 | B, _, C = template_tokens.shape 25 | H = W = feat_size_t 26 | template_tokens = template_tokens.view(B, H, W, C) 27 | pad_l = pad_b = pad_r = 0 28 | # pad_r = (window_size - W % window_size) % window_size 29 | pad_t = (window_size - H % window_size) % window_size 30 | template_tokens = F.pad(template_tokens, (0, 0, pad_l, pad_r, pad_t, pad_b)) 31 | _, Hp, Wp, _ = template_tokens.shape 32 | template_tokens = template_tokens.view(B, Hp // window_size, window_size, W, C) 33 | template_tokens = torch.cat([template_tokens[:, 0, ...], template_tokens[:, 1, ...]], dim=2) 34 | _, Hc, Wc, _ = template_tokens.shape 35 | template_tokens = template_tokens.view(B, -1, C) 36 | merged_feature = torch.cat([template_tokens, search_tokens], dim=1) 37 | 38 | # calculate new h and w, which may be useful for SwinT or others 39 | merged_h, merged_w = feat_size_s + Hc, feat_size_s 40 | if return_res: 41 | return merged_feature, merged_h, merged_w 42 | 43 | else: 44 | raise NotImplementedError 45 | 46 | return merged_feature 47 | 48 | 49 | def recover_tokens(merged_tokens, len_template_token, len_search_token, mode='direct'): 50 | if mode == 'direct': 51 | recovered_tokens = merged_tokens 52 | elif mode == 'template_central': 53 | central_pivot = len_search_token // 2 54 | len_remain = len_search_token - central_pivot 55 | len_half_and_t = central_pivot + len_template_token 56 | 57 | first_half = merged_tokens[:, :central_pivot, :] 58 | second_half = merged_tokens[:, -len_remain:, :] 59 | template_tokens = merged_tokens[:, central_pivot:len_half_and_t, :] 60 | 61 | recovered_tokens = torch.cat((template_tokens, first_half, second_half), dim=1) 62 | elif mode == 'partition': 63 | recovered_tokens = merged_tokens 64 | else: 65 | raise NotImplementedError 66 | 67 | return recovered_tokens 68 | 69 | 70 | def window_partition(x, window_size: int): 71 | """ 72 | Args: 73 | x: (B, H, W, C) 74 | window_size (int): window size 75 | 76 | Returns: 77 | windows: (num_windows*B, window_size, window_size, C) 78 | """ 79 | B, H, W, C = x.shape 80 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 81 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 82 | return windows 83 | 84 | 85 | def window_reverse(windows, window_size: int, H: int, W: int): 86 | """ 87 | Args: 88 | windows: (num_windows*B, window_size, window_size, C) 89 | window_size (int): Window size 90 | H (int): Height of image 91 | W (int): Width of image 92 | 93 | Returns: 94 | x: (B, H, W, C) 95 | """ 96 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 97 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 98 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 99 | return x 100 | 101 | 102 | ''' 103 | add token transfer to feature 104 | ''' 105 | def token2feature(tokens): 106 | B,L,D=tokens.shape 107 | H=W=int(L**0.5) 108 | x = tokens.permute(0, 2, 1).view(B, D, W, H).contiguous() 109 | return x 110 | 111 | 112 | ''' 113 | feature2token 114 | ''' 115 | def feature2token(x): 116 | B,C,W,H = x.shape 117 | L = W*H 118 | tokens = x.view(B, C, L).permute(0, 2, 1).contiguous() 119 | return tokens -------------------------------------------------------------------------------- /lib/models/layers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SparkTempest/BAT/ccf9b2f6ae3e810f4e7318c9d0b62083deb7ec89/lib/models/layers/__init__.py -------------------------------------------------------------------------------- /lib/models/layers/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import timm 4 | import math 5 | 6 | 7 | ''' 8 | def forward_block(self, x): 9 | x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.adapter_attn(self.norm1(x))) * self.s 10 | x = x + self.drop_path(self.mlp(self.norm2(x))) + self.drop_path(self.adapter_mlp(self.norm2(x))) * self.s 11 | return x 12 | 13 | 14 | def forward_block_attn(self, x): 15 | x = x + self.drop_path(self.attn(self.norm1(x))) + self.drop_path(self.adapter_attn(self.norm1(x))) * self.s 16 | x = x + self.drop_path(self.mlp(self.norm2(x))) 17 | return x 18 | ''' 19 | 20 | 21 | class QuickGELU(nn.Module): 22 | def forward(self, x: torch.Tensor): 23 | return x * torch.sigmoid(1.702 * x) 24 | 25 | 26 | 27 | class Bi_direct_adapter(nn.Module): 28 | def __init__(self, dim=8, xavier_init=False): 29 | super().__init__() 30 | 31 | self.adapter_down = nn.Linear(768, dim) 32 | self.adapter_up = nn.Linear(dim, 768) 33 | self.adapter_mid = nn.Linear(dim, dim) 34 | 35 | #nn.init.xavier_uniform_(self.adapter_down.weight) 36 | nn.init.zeros_(self.adapter_mid.bias) 37 | nn.init.zeros_(self.adapter_mid.weight) 38 | nn.init.zeros_(self.adapter_down.weight) 39 | nn.init.zeros_(self.adapter_down.bias) 40 | nn.init.zeros_(self.adapter_up.weight) 41 | nn.init.zeros_(self.adapter_up.bias) 42 | 43 | #self.act = QuickGELU() 44 | self.dropout = nn.Dropout(0.1) 45 | self.dim = dim 46 | 47 | def forward(self, x): 48 | B, N, C = x.shape 49 | x_down = self.adapter_down(x) 50 | #x_down = self.act(x_down) 51 | x_down = self.adapter_mid(x_down) 52 | #x_down = self.act(x_down) 53 | x_down = self.dropout(x_down) 54 | x_up = self.adapter_up(x_down) 55 | #print("return adap x", x_up.size()) 56 | return x_up 57 | 58 | """ 59 | 60 | 61 | class Convpass(nn.Module): 62 | def __init__(self, dim=8, xavier_init=False): 63 | super().__init__() 64 | 65 | self.adapter_conv = nn.Conv2d(dim, dim, 3, 1, 1) 66 | if xavier_init: 67 | nn.init.xavier_uniform_(self.adapter_conv.weight) 68 | else: 69 | nn.init.zeros_(self.adapter_conv.weight) 70 | self.adapter_conv.weight.data[:, :, 1, 1] += torch.eye(8, dtype=torch.float) 71 | nn.init.zeros_(self.adapter_conv.bias) 72 | 73 | self.adapter_down = nn.Linear(768, dim) # equivalent to 1 * 1 Conv 74 | self.adapter_up = nn.Linear(dim, 768) # equivalent to 1 * 1 Conv 75 | nn.init.xavier_uniform_(self.adapter_down.weight) 76 | nn.init.zeros_(self.adapter_down.bias) 77 | nn.init.zeros_(self.adapter_up.weight) 78 | nn.init.zeros_(self.adapter_up.bias) 79 | 80 | self.act = QuickGELU() 81 | self.dropout = nn.Dropout(0.1) 82 | self.dim = dim 83 | 84 | def forward(self, x): 85 | B, N, C = x.shape 86 | #print(x.shape) 87 | x_down = self.adapter_down(x) # equivalent to 1 * 1 Conv 88 | x_down = self.act(x_down) 89 | 90 | #print(x_down.shape) 91 | 92 | x_patch = x_down[:, 64:].reshape(B, 16, 16, self.dim).permute(0, 3, 1, 2) ############ 93 | x_patch = self.adapter_conv(x_patch) 94 | x_patch = x_patch.permute(0, 2, 3, 1).reshape(B, 16 * 16, self.dim) 95 | 96 | 97 | #x_down = torch.cat([x_cls, x_patch], dim=1) 98 | 99 | x_down = self.act(x_down) 100 | x_down = self.dropout(x_down) 101 | x_up = self.adapter_up(x_down) # equivalent to 1 * 1 Conv 102 | 103 | return x_up 104 | 
""" -------------------------------------------------------------------------------- /lib/models/layers/attn_blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_ 5 | 6 | from lib.models.layers.attn import Attention 7 | 8 | 9 | def candidate_elimination_adapter(tokens: torch.Tensor, lens_t: int, global_index: torch.Tensor): 10 | # tokens: actually adapter_parameters 11 | tokens_t = tokens[:, :lens_t] 12 | tokens_s = tokens[:, lens_t:] 13 | 14 | B, L, C = tokens_s.shape 15 | attentive_tokens = tokens_s.gather(dim=1, index=global_index.unsqueeze(-1).expand(B, -1, C)) 16 | tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1) 17 | 18 | return tokens_new 19 | 20 | 21 | def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, lens_t: int, keep_ratio: float, global_index: torch.Tensor, box_mask_z: torch.Tensor): 22 | """ 23 | Eliminate potential background candidates for computation reduction and noise cancellation. 24 | Args: 25 | attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights 26 | tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens 27 | lens_t (int): length of template 28 | keep_ratio (float): keep ratio of search region tokens (candidates) 29 | global_index (torch.Tensor): global index of search region tokens 30 | box_mask_z (torch.Tensor): template mask used to accumulate attention weights 31 | 32 | Returns: 33 | tokens_new (torch.Tensor): tokens after candidate elimination 34 | keep_index (torch.Tensor): indices of kept search region tokens 35 | removed_index (torch.Tensor): indices of removed search region tokens 36 | """ 37 | lens_s = attn.shape[-1] - lens_t 38 | bs, hn, _, _ = attn.shape 39 | 40 | lens_keep = math.ceil(keep_ratio * lens_s) 41 | if lens_keep == lens_s: 42 | return tokens, global_index, None 43 | 44 | attn_t = attn[:, :, :lens_t, lens_t:] 45 | 46 | if box_mask_z is not None: 47 | box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand(-1, attn_t.shape[1], -1, attn_t.shape[-1]) 48 | # attn_t = attn_t[:, :, box_mask_z, :] 49 | attn_t = attn_t[box_mask_z] 50 | attn_t = attn_t.view(bs, hn, -1, lens_s) 51 | attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s 52 | else: 53 | attn_t = attn_t.mean(dim=2).mean(dim=1) # B, H, L-T, L_s --> B, L_s 54 | 55 | # use sort instead of topk, due to the speed issue 56 | # https://github.com/pytorch/pytorch/issues/22812 57 | sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True) 58 | 59 | topk_attn, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep] 60 | non_topk_attn, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:] 61 | 62 | keep_index = global_index.gather(dim=1, index=topk_idx) 63 | removed_index = global_index.gather(dim=1, index=non_topk_idx) 64 | 65 | # separate template and search tokens 66 | tokens_t = tokens[:, :lens_t] 67 | tokens_s = tokens[:, lens_t:] 68 | 69 | # obtain the attentive and inattentive tokens 70 | B, L, C = tokens_s.shape 71 | # topk_idx_ = topk_idx.unsqueeze(-1).expand(B, lens_keep, C) 72 | attentive_tokens = tokens_s.gather(dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C)) 73 | tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1) 74 | 75 | return tokens_new, keep_index, removed_index 76 | 77 | 78 | class CEBlock(nn.Module): 79 | 80 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 
attn_drop=0., 81 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, keep_ratio_search=1.0,): 82 | super().__init__() 83 | self.norm1 = norm_layer(dim) 84 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 85 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 86 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 87 | self.norm2 = norm_layer(dim) 88 | mlp_hidden_dim = int(dim * mlp_ratio) 89 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 90 | 91 | self.keep_ratio_search = keep_ratio_search 92 | 93 | def forward(self, x, global_index_template, global_index_search, mask=None, ce_template_mask=None, keep_ratio_search=None): 94 | x_attn, attn = self.attn(self.norm1(x), mask, True) 95 | x = x + self.drop_path(x_attn) 96 | lens_t = global_index_template.shape[1] 97 | 98 | removed_index_search = None 99 | if self.keep_ratio_search < 1 and (keep_ratio_search is None or keep_ratio_search < 1): 100 | keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search 101 | x, global_index_search, removed_index_search = candidate_elimination(attn, x, lens_t, keep_ratio_search, global_index_search, ce_template_mask) 102 | 103 | x = x + self.drop_path(self.mlp(self.norm2(x))) 104 | return x, global_index_template, global_index_search, removed_index_search, attn 105 | 106 | 107 | class Block(nn.Module): 108 | 109 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 110 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 111 | super().__init__() 112 | self.norm1 = norm_layer(dim) 113 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 114 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 115 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 116 | self.norm2 = norm_layer(dim) 117 | mlp_hidden_dim = int(dim * mlp_ratio) 118 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 119 | 120 | def forward(self, x, mask=None): 121 | x = x + self.drop_path(self.attn(self.norm1(x), mask)) 122 | x = x + self.drop_path(self.mlp(self.norm2(x))) 123 | return x 124 | -------------------------------------------------------------------------------- /lib/models/layers/frozen_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FrozenBatchNorm2d(torch.nn.Module): 5 | """ 6 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 7 | 8 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 9 | without which any other models than torchvision.models.resnet[18,34,50,101] 10 | produce nans. 
11 | """ 12 | 13 | def __init__(self, n): 14 | super(FrozenBatchNorm2d, self).__init__() 15 | self.register_buffer("weight", torch.ones(n)) 16 | self.register_buffer("bias", torch.zeros(n)) 17 | self.register_buffer("running_mean", torch.zeros(n)) 18 | self.register_buffer("running_var", torch.ones(n)) 19 | 20 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 21 | missing_keys, unexpected_keys, error_msgs): 22 | num_batches_tracked_key = prefix + 'num_batches_tracked' 23 | if num_batches_tracked_key in state_dict: 24 | del state_dict[num_batches_tracked_key] 25 | 26 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 27 | state_dict, prefix, local_metadata, strict, 28 | missing_keys, unexpected_keys, error_msgs) 29 | 30 | def forward(self, x): 31 | # move reshapes to the beginning 32 | # to make it fuser-friendly 33 | w = self.weight.reshape(1, -1, 1, 1) 34 | b = self.bias.reshape(1, -1, 1, 1) 35 | rv = self.running_var.reshape(1, -1, 1, 1) 36 | rm = self.running_mean.reshape(1, -1, 1, 1) 37 | eps = 1e-5 38 | scale = w * (rv + eps).rsqrt() # rsqrt(x): 1/sqrt(x), r: reciprocal 39 | bias = b - rm * scale 40 | return x * scale + bias 41 | -------------------------------------------------------------------------------- /lib/models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from timm.models.layers import to_2tuple 4 | 5 | 6 | class PatchEmbed(nn.Module): 7 | """ 2D Image to Patch Embedding 8 | """ 9 | 10 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 11 | super().__init__() 12 | img_size = to_2tuple(img_size) 13 | patch_size = to_2tuple(patch_size) 14 | self.img_size = img_size 15 | self.patch_size = patch_size 16 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 17 | self.num_patches = self.grid_size[0] * self.grid_size[1] 18 | self.flatten = flatten 19 | 20 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 21 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 22 | 23 | def forward(self, x): 24 | # allow different input size 25 | # B, C, H, W = x.shape 26 | # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 27 | # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 28 | #print("start",x.size()) #start torch.Size([1, 3, 256, 256]) 29 | x = self.proj(x) #flatten before torch.Size([1, 768, 16, 16]) 30 | if self.flatten: 31 | #print("flatten before",x.size()) #flatten before torch.Size([1, 768, 16, 16]) 32 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 33 | #print("flatten transpose",x.size()) #flatten transpose torch.Size([1, 256, 768]) 34 | x = self.norm(x) 35 | #print("after",x.size()) #after torch.Size([1, 256, 768]) 36 | return x 37 | -------------------------------------------------------------------------------- /lib/models/layers/rpe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | 5 | 6 | def generate_2d_relative_positional_encoding_index(z_shape, x_shape): 7 | ''' 8 | z_shape: (z_h, z_w) 9 | x_shape: (x_h, x_w) 10 | ''' 11 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 12 | x_2d_index_h, x_2d_index_w = 
torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 13 | 14 | z_2d_index_h = z_2d_index_h.flatten(0) 15 | z_2d_index_w = z_2d_index_w.flatten(0) 16 | x_2d_index_h = x_2d_index_h.flatten(0) 17 | x_2d_index_w = x_2d_index_w.flatten(0) 18 | 19 | diff_h = z_2d_index_h[:, None] - x_2d_index_h[None, :] 20 | diff_w = z_2d_index_w[:, None] - x_2d_index_w[None, :] 21 | 22 | diff = torch.stack((diff_h, diff_w), dim=-1) 23 | _, indices = torch.unique(diff.view(-1, 2), return_inverse=True, dim=0) 24 | return indices.view(z_shape[0] * z_shape[1], x_shape[0] * x_shape[1]) 25 | 26 | 27 | def generate_2d_concatenated_self_attention_relative_positional_encoding_index(z_shape, x_shape): 28 | ''' 29 | z_shape: (z_h, z_w) 30 | x_shape: (x_h, x_w) 31 | ''' 32 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 33 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 34 | 35 | z_2d_index_h = z_2d_index_h.flatten(0) 36 | z_2d_index_w = z_2d_index_w.flatten(0) 37 | x_2d_index_h = x_2d_index_h.flatten(0) 38 | x_2d_index_w = x_2d_index_w.flatten(0) 39 | 40 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 41 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 42 | 43 | diff_h = concatenated_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 44 | diff_w = concatenated_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 45 | 46 | z_len = z_shape[0] * z_shape[1] 47 | x_len = x_shape[0] * x_shape[1] 48 | a = torch.empty((z_len + x_len), dtype=torch.int64) 49 | a[:z_len] = 0 50 | a[z_len:] = 1 51 | b=a[:, None].repeat(1, z_len + x_len) 52 | c=a[None, :].repeat(z_len + x_len, 1) 53 | 54 | diff = torch.stack((diff_h, diff_w, b, c), dim=-1) 55 | _, indices = torch.unique(diff.view((z_len + x_len) * (z_len + x_len), 4), return_inverse=True, dim=0) 56 | return indices.view((z_len + x_len), (z_len + x_len)) 57 | 58 | 59 | def generate_2d_concatenated_cross_attention_relative_positional_encoding_index(z_shape, x_shape): 60 | ''' 61 | z_shape: (z_h, z_w) 62 | x_shape: (x_h, x_w) 63 | ''' 64 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1])) 65 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1])) 66 | 67 | z_2d_index_h = z_2d_index_h.flatten(0) 68 | z_2d_index_w = z_2d_index_w.flatten(0) 69 | x_2d_index_h = x_2d_index_h.flatten(0) 70 | x_2d_index_w = x_2d_index_w.flatten(0) 71 | 72 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h)) 73 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w)) 74 | 75 | diff_h = x_2d_index_h[:, None] - concatenated_2d_index_h[None, :] 76 | diff_w = x_2d_index_w[:, None] - concatenated_2d_index_w[None, :] 77 | 78 | z_len = z_shape[0] * z_shape[1] 79 | x_len = x_shape[0] * x_shape[1] 80 | 81 | a = torch.empty(z_len + x_len, dtype=torch.int64) 82 | a[: z_len] = 0 83 | a[z_len:] = 1 84 | c = a[None, :].repeat(x_len, 1) 85 | 86 | diff = torch.stack((diff_h, diff_w, c), dim=-1) 87 | _, indices = torch.unique(diff.view(x_len * (z_len + x_len), 3), return_inverse=True, dim=0) 88 | return indices.view(x_len, (z_len + x_len)) 89 | 90 | 91 | class RelativePosition2DEncoder(nn.Module): 92 | def __init__(self, num_heads, embed_size): 93 | super(RelativePosition2DEncoder, self).__init__() 94 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, embed_size))) 95 | trunc_normal_(self.relative_position_bias_table, std=0.02) 96 | 97 | 
def forward(self, attn_rpe_index): 98 | ''' 99 | Args: 100 | attn_rpe_index (torch.Tensor): (*), any shape containing indices, max(attn_rpe_index) < embed_size 101 | Returns: 102 | torch.Tensor: (1, num_heads, *) 103 | ''' 104 | return self.relative_position_bias_table[:, attn_rpe_index].unsqueeze(0) 105 | -------------------------------------------------------------------------------- /lib/test/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import Sequence 2 | from .tracker import Tracker, trackerlist 3 | from .datasets import get_dataset 4 | from .environment import create_default_local_file_test -------------------------------------------------------------------------------- /lib/test/evaluation/datasets.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import importlib 3 | from lib.test.evaluation.data import SequenceList 4 | 5 | DatasetInfo = namedtuple('DatasetInfo', ['module', 'class_name', 'kwargs']) 6 | 7 | pt = "lib.test.evaluation.%sdataset" # Useful abbreviations to reduce the clutter 8 | 9 | dataset_dict = dict( 10 | otb=DatasetInfo(module=pt % "otb", class_name="OTBDataset", kwargs=dict()), 11 | nfs=DatasetInfo(module=pt % "nfs", class_name="NFSDataset", kwargs=dict()), 12 | uav=DatasetInfo(module=pt % "uav", class_name="UAVDataset", kwargs=dict()), 13 | tc128=DatasetInfo(module=pt % "tc128", class_name="TC128Dataset", kwargs=dict()), 14 | tc128ce=DatasetInfo(module=pt % "tc128ce", class_name="TC128CEDataset", kwargs=dict()), 15 | trackingnet=DatasetInfo(module=pt % "trackingnet", class_name="TrackingNetDataset", kwargs=dict()), 16 | got10k_test=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='test')), 17 | got10k_val=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='val')), 18 | got10k_ltrval=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='ltrval')), 19 | lasot=DatasetInfo(module=pt % "lasot", class_name="LaSOTDataset", kwargs=dict()), 20 | lasot_lmdb=DatasetInfo(module=pt % "lasot_lmdb", class_name="LaSOTlmdbDataset", kwargs=dict()), 21 | 22 | vot18=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict()), 23 | vot22=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict(year=22)), 24 | itb=DatasetInfo(module=pt % "itb", class_name="ITBDataset", kwargs=dict()), 25 | tnl2k=DatasetInfo(module=pt % "tnl2k", class_name="TNL2kDataset", kwargs=dict()), 26 | lasot_extension_subset=DatasetInfo(module=pt % "lasotextensionsubset", class_name="LaSOTExtensionSubsetDataset", 27 | kwargs=dict()), 28 | vtuav_st=DatasetInfo(module=pt % "vtuav", class_name="VTUAVDataset", kwargs=dict(subset='st')), 29 | vtuav_lt=DatasetInfo(module=pt % "vtuav", class_name="VTUAVDataset", kwargs=dict(subset='lt')), 30 | lasher=DatasetInfo(module=pt % "lasher", class_name="LasHeRDataset", kwargs=dict()), 31 | ) 32 | 33 | 34 | def load_dataset(name: str): 35 | """ Import and load a single dataset.""" 36 | name = name.lower() 37 | dset_info = dataset_dict.get(name) 38 | if dset_info is None: 39 | raise ValueError('Unknown dataset \'%s\'' % name) 40 | 41 | m = importlib.import_module(dset_info.module) 42 | dataset = getattr(m, dset_info.class_name)(**dset_info.kwargs) # Call the constructor 43 | return dataset.get_sequence_list() 44 | 45 | 46 | def get_dataset(*args): 47 | """ Get a single or set of datasets.""" 48 | dset = SequenceList() 49 | 
for name in args: 50 | dset.extend(load_dataset(name)) 51 | return dset -------------------------------------------------------------------------------- /lib/test/evaluation/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | class EnvSettings: 6 | def __init__(self): 7 | test_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 8 | 9 | self.results_path = '{}/tracking_results/'.format(test_path) 10 | self.segmentation_path = '{}/segmentation_results/'.format(test_path) 11 | self.network_path = '{}/networks/'.format(test_path) 12 | self.result_plot_path = '{}/result_plots/'.format(test_path) 13 | self.otb_path = '' 14 | self.nfs_path = '' 15 | self.uav_path = '' 16 | self.tpl_path = '' 17 | self.vot_path = '' 18 | self.got10k_path = '' 19 | self.lasot_path = '' 20 | self.trackingnet_path = '' 21 | self.davis_dir = '' 22 | self.youtubevos_dir = '' 23 | 24 | self.got_packed_results_path = '' 25 | self.got_reports_path = '' 26 | self.tn_packed_results_path = '' 27 | 28 | 29 | def create_default_local_file(): 30 | comment = {'results_path': 'Where to store tracking results', 31 | 'network_path': 'Where tracking networks are stored.'} 32 | 33 | path = os.path.join(os.path.dirname(__file__), 'local.py') 34 | with open(path, 'w') as f: 35 | settings = EnvSettings() 36 | 37 | f.write('from test.evaluation.environment import EnvSettings\n\n') 38 | f.write('def local_env_settings():\n') 39 | f.write(' settings = EnvSettings()\n\n') 40 | f.write(' # Set your local paths here.\n\n') 41 | 42 | for attr in dir(settings): 43 | comment_str = None 44 | if attr in comment: 45 | comment_str = comment[attr] 46 | attr_val = getattr(settings, attr) 47 | if not attr.startswith('__') and not callable(attr_val): 48 | if comment_str is None: 49 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 50 | else: 51 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 52 | f.write('\n return settings\n\n') 53 | 54 | 55 | class EnvSettings_ITP: 56 | def __init__(self, workspace_dir, data_dir, save_dir): 57 | self.prj_dir = workspace_dir 58 | self.save_dir = save_dir 59 | self.results_path = os.path.join(save_dir, 'test/tracking_results') 60 | self.segmentation_path = os.path.join(save_dir, 'test/segmentation_results') 61 | self.network_path = os.path.join(save_dir, 'test/networks') 62 | self.result_plot_path = os.path.join(save_dir, 'test/result_plots') 63 | self.otb_path = os.path.join(data_dir, 'otb') 64 | self.nfs_path = os.path.join(data_dir, 'nfs') 65 | self.uav_path = os.path.join(data_dir, 'uav') 66 | self.tc128_path = os.path.join(data_dir, 'TC128') 67 | self.tpl_path = '' 68 | self.vot_path = os.path.join(data_dir, 'VOT2019') 69 | self.got10k_path = os.path.join(data_dir, 'got10k') 70 | self.got10k_lmdb_path = os.path.join(data_dir, 'got10k_lmdb') 71 | self.lasot_path = os.path.join(data_dir, 'lasot') 72 | self.lasot_lmdb_path = os.path.join(data_dir, 'lasot_lmdb') 73 | self.trackingnet_path = os.path.join(data_dir, 'trackingnet') 74 | self.vot18_path = os.path.join(data_dir, 'vot2018') 75 | self.vot22_path = os.path.join(data_dir, 'vot2022') 76 | self.itb_path = os.path.join(data_dir, 'itb') 77 | self.tnl2k_path = os.path.join(data_dir, 'tnl2k') 78 | self.lasot_extension_subset_path_path = os.path.join(data_dir, 'lasot_extension_subset') 79 | self.davis_dir = '' 80 | self.youtubevos_dir = '' 81 | 82 | self.got_packed_results_path = '' 83 | self.got_reports_path = '' 84 | 
self.tn_packed_results_path = '' 85 | 86 | 87 | def create_default_local_file_test(workspace_dir, data_dir, save_dir): 88 | comment = {'results_path': 'Where to store tracking results', 89 | 'network_path': 'Where tracking networks are stored.'} 90 | 91 | path = os.path.join(os.path.dirname(__file__), 'local.py') 92 | with open(path, 'w') as f: 93 | settings = EnvSettings_ITP(workspace_dir, data_dir, save_dir) 94 | 95 | f.write('from lib.test.evaluation.environment import EnvSettings\n\n') 96 | f.write('def local_env_settings():\n') 97 | f.write(' settings = EnvSettings()\n\n') 98 | f.write(' # Set your local paths here.\n\n') 99 | 100 | for attr in dir(settings): 101 | comment_str = None 102 | if attr in comment: 103 | comment_str = comment[attr] 104 | attr_val = getattr(settings, attr) 105 | if not attr.startswith('__') and not callable(attr_val): 106 | if comment_str is None: 107 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 108 | else: 109 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 110 | f.write('\n return settings\n\n') 111 | 112 | 113 | def env_settings(): 114 | env_module_name = 'lib.test.evaluation.local' 115 | try: 116 | env_module = importlib.import_module(env_module_name) 117 | return env_module.local_env_settings() 118 | except: 119 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 120 | 121 | # Create a default file 122 | create_default_local_file() 123 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. ' 124 | 'Then try to run again.'.format(env_file)) -------------------------------------------------------------------------------- /lib/test/evaluation/local.py: -------------------------------------------------------------------------------- 1 | from lib.test.evaluation.environment import EnvSettings 2 | 3 | def local_env_settings(): 4 | settings = EnvSettings() 5 | 6 | # Set your local paths here. 7 | 8 | settings.davis_dir = '' 9 | settings.got10k_lmdb_path = '/root/BAT-main/data/got10k_lmdb' 10 | settings.got10k_path = '/root/BAT-main/data/got10k' 11 | settings.got_packed_results_path = '' 12 | settings.got_reports_path = '' 13 | settings.itb_path = '/root/BAT-main/data/itb' 14 | settings.lasot_extension_subset_path_path = '/root/BAT-main/data/lasot_extension_subset' 15 | settings.lasot_lmdb_path = '/root/BAT-main/data/lasot_lmdb' 16 | settings.lasot_path = '/root/BAT-main/data/lasot' 17 | settings.network_path = '/root/BAT-main/output/test/networks' # Where tracking networks are stored. 
18 | settings.nfs_path = '/root/BAT-main/data/nfs' 19 | settings.otb_path = '/root/BAT-main/data/otb' 20 | settings.prj_dir = '/root/BAT-main' 21 | settings.result_plot_path = '/root/BAT-main/output/test/result_plots' 22 | settings.results_path = '/root/BAT-main/output/test/tracking_results' # Where to store tracking results 23 | settings.save_dir = '/root/BAT-main/output' 24 | settings.segmentation_path = '/root/BAT-main/output/test/segmentation_results' 25 | settings.tc128_path = '/root/BAT-main/data/TC128' 26 | settings.tn_packed_results_path = '' 27 | settings.tnl2k_path = '/root/BAT-main/data/tnl2k' 28 | settings.tpl_path = '' 29 | settings.trackingnet_path = '/root/BAT-main/data/trackingnet' 30 | settings.uav_path = '/root/BAT-main/data/uav' 31 | settings.vot18_path = '/root/BAT-main/data/vot2018' 32 | settings.vot22_path = '/root/BAT-main/data/vot2022' 33 | settings.vot_path = '/root/BAT-main/data/VOT2019' 34 | settings.youtubevos_dir = '' 35 | 36 | return settings 37 | 38 | -------------------------------------------------------------------------------- /lib/test/parameter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkTempest/BAT/ccf9b2f6ae3e810f4e7318c9d0b62083deb7ec89/lib/test/parameter/__init__.py -------------------------------------------------------------------------------- /lib/test/parameter/bat.py: -------------------------------------------------------------------------------- 1 | from lib.test.utils import TrackerParams 2 | import os 3 | from lib.test.evaluation.environment import env_settings 4 | from lib.config.bat.config import cfg, update_config_from_file 5 | 6 | 7 | def parameters(yaml_name: str, epoch=None): 8 | params = TrackerParams() 9 | prj_dir = env_settings().prj_dir 10 | save_dir = env_settings().save_dir 11 | # update default config from yaml file 12 | yaml_file = os.path.join(prj_dir, 'experiments/bat/%s.yaml' % yaml_name) 13 | update_config_from_file(yaml_file) 14 | params.cfg = cfg 15 | print("test config: ", cfg) 16 | 17 | # template and search region 18 | params.template_factor = cfg.TEST.TEMPLATE_FACTOR 19 | params.template_size = cfg.TEST.TEMPLATE_SIZE 20 | params.search_factor = cfg.TEST.SEARCH_FACTOR 21 | params.search_size = cfg.TEST.SEARCH_SIZE 22 | 23 | # Network checkpoint path 24 | # params.checkpoint = os.path.join(save_dir, "checkpoints/train/bat/%s/BATrack_ep%04d.pth.tar" % (yaml_name, cfg.TEST.EPOCH)) 25 | params.checkpoint = os.path.join(prj_dir, "./models/BAT_%s.pth"%yaml_name) 26 | # whether to save boxes from all queries 27 | params.save_all_boxes = False 28 | 29 | return params 30 | -------------------------------------------------------------------------------- /lib/test/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkTempest/BAT/ccf9b2f6ae3e810f4e7318c9d0b62083deb7ec89/lib/test/tracker/__init__.py -------------------------------------------------------------------------------- /lib/test/tracker/basetracker.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from _collections import OrderedDict 5 | 6 | from lib.train.data.processing_utils import transform_image_to_crop 7 | from lib.vis.visdom_cus import Visdom 8 | 9 | 10 | class BaseTracker: 11 | """Base class for all trackers.""" 12 | 13 | def __init__(self, params): 14 | self.params = params 15 | self.visdom = None 16 | 17 | def 
predicts_segmentation_mask(self): 18 | return False 19 | 20 | def initialize(self, image, info: dict) -> dict: 21 | """Overload this function in your tracker. This should initialize the model.""" 22 | raise NotImplementedError 23 | 24 | def track(self, image, info: dict = None) -> dict: 25 | """Overload this function in your tracker. This should track in the frame and update the model.""" 26 | raise NotImplementedError 27 | 28 | def visdom_draw_tracking(self, image, box, segmentation=None): 29 | if isinstance(box, OrderedDict): 30 | box = [v for k, v in box.items()] 31 | else: 32 | box = (box,) 33 | if segmentation is None: 34 | self.visdom.register((image, *box), 'Tracking', 1, 'Tracking') 35 | else: 36 | self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking') 37 | 38 | def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'): 39 | # box_in: list [x1, y1, w, h], not normalized 40 | # box_extract: same as box_in 41 | # out bbox: Torch.tensor [1, 1, 4], x1y1wh, normalized 42 | if crop_type == 'template': 43 | crop_sz = torch.Tensor([self.params.template_size, self.params.template_size]) 44 | elif crop_type == 'search': 45 | crop_sz = torch.Tensor([self.params.search_size, self.params.search_size]) 46 | else: 47 | raise NotImplementedError 48 | 49 | box_in = torch.tensor(box_in) 50 | if box_extract is None: 51 | box_extract = box_in 52 | else: 53 | box_extract = torch.tensor(box_extract) 54 | template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True) 55 | template_bbox = template_bbox.view(1, 1, 4).to(device) 56 | 57 | return template_bbox 58 | 59 | def _init_visdom(self, visdom_info, debug): 60 | visdom_info = {} if visdom_info is None else visdom_info 61 | self.pause_mode = False 62 | self.step = False 63 | self.next_seq = False 64 | if debug > 0 and visdom_info.get('use_visdom', True): 65 | try: 66 | self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'}, 67 | visdom_info=visdom_info) 68 | except: 69 | time.sleep(0.5) 70 | print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n' 71 | '!!! 
Start Visdom in a separate terminal window by typing \'visdom\' !!!') 72 | 73 | def _visdom_ui_handler(self, data): 74 | if data['event_type'] == 'KeyPress': 75 | if data['key'] == ' ': 76 | self.pause_mode = not self.pause_mode 77 | 78 | elif data['key'] == 'ArrowRight' and self.pause_mode: 79 | self.step = True 80 | 81 | elif data['key'] == 'n': 82 | self.next_seq = True 83 | -------------------------------------------------------------------------------- /lib/test/tracker/bat.py: -------------------------------------------------------------------------------- 1 | import math 2 | from lib.models.bat import build_batrack 3 | from lib.test.tracker.basetracker import BaseTracker 4 | import torch 5 | from lib.test.tracker.vis_utils import gen_visualization 6 | from lib.test.utils.hann import hann2d 7 | from lib.train.data.processing_utils import sample_target 8 | # for debug 9 | import cv2 10 | import os 11 | import vot 12 | from lib.test.tracker.data_utils import PreprocessorMM 13 | from lib.utils.box_ops import clip_box 14 | from lib.utils.ce_utils import generate_mask_cond 15 | 16 | 17 | class BATTrack(BaseTracker): 18 | def __init__(self, params): 19 | super(BATTrack, self).__init__(params) 20 | network = build_batrack(params.cfg, training=False) 21 | network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True) 22 | self.cfg = params.cfg 23 | self.network = network.cuda() 24 | self.network.eval() 25 | self.preprocessor = PreprocessorMM() 26 | self.state = None 27 | 28 | self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE 29 | # motion constrain 30 | self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda() 31 | 32 | # for debug 33 | if getattr(params, 'debug', None) is None: 34 | setattr(params, 'debug', 0) 35 | self.use_visdom = True #params.debug 36 | #self._init_visdom(None, 1) 37 | self.debug = params.debug 38 | self.frame_id = 0 39 | # for save boxes from all queries 40 | self.save_all_boxes = params.save_all_boxes 41 | 42 | def initialize(self, image, info: dict): 43 | # forward the template once 44 | z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor, 45 | output_sz=self.params.template_size) 46 | self.z_patch_arr = z_patch_arr 47 | template = self.preprocessor.process(z_patch_arr) 48 | with torch.no_grad(): 49 | self.z_tensor = template 50 | 51 | self.box_mask_z = None 52 | if self.cfg.MODEL.BACKBONE.CE_LOC: 53 | template_bbox = self.transform_bbox_to_crop(info['init_bbox'], resize_factor, 54 | template.device).squeeze(1) 55 | self.box_mask_z = generate_mask_cond(self.cfg, 1, template.device, template_bbox) 56 | 57 | # save states 58 | self.state = info['init_bbox'] 59 | self.frame_id = 0 60 | if self.save_all_boxes: 61 | '''save all predicted boxes''' 62 | all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES 63 | return {"all_boxes": all_boxes_save} 64 | 65 | def track(self, image, info: dict = None): 66 | H, W, _ = image.shape 67 | self.frame_id += 1 68 | x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor, 69 | output_sz=self.params.search_size) # (x1, y1, w, h) 70 | search = self.preprocessor.process(x_patch_arr) 71 | 72 | with torch.no_grad(): 73 | x_tensor = search 74 | # merge the template and the search 75 | # run the transformer 76 | out_dict = self.network.forward( 77 | template=self.z_tensor, search=x_tensor, ce_template_mask=self.box_mask_z) 
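# The forward pass above returns a dict that is expected to contain 'score_map', 'size_map' and
# 'offset_map' from the prediction head; the score map is weighted element-wise by the pre-computed
# Hann window below, so candidates far from the centre of the search region (i.e. far from the
# previous target position) are penalised before cal_bbox decodes the final box.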
78 | 79 | # add hann windows 80 | pred_score_map = out_dict['score_map'] 81 | response = self.output_window * pred_score_map 82 | pred_boxes, best_score = self.network.box_head.cal_bbox(response, out_dict['size_map'], out_dict['offset_map'], return_score=True) 83 | max_score = best_score[0][0].item() 84 | pred_boxes = pred_boxes.view(-1, 4) 85 | # Baseline: Take the mean of all pred boxes as the final result 86 | pred_box = (pred_boxes.mean( 87 | dim=0) * self.params.search_size / resize_factor).tolist() # (cx, cy, w, h) [0,1] 88 | # get the final box result 89 | self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10) 90 | 91 | #self.debug = 1 92 | 93 | # for debug 94 | if self.debug == 1: 95 | x1, y1, w, h = self.state 96 | image_BGR = cv2.cvtColor(image[:,:,:3], cv2.COLOR_RGB2BGR) 97 | cv2.rectangle(image_BGR, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=2) 98 | cv2.putText(image_BGR, 'max_score:' + str(round(max_score, 3)), (40, 40), 99 | cv2.FONT_HERSHEY_SIMPLEX, 1, 100 | (0, 255, 255), 2) 101 | cv2.imshow('debug_vis', image_BGR) 102 | cv2.waitKey(1) 103 | 104 | 105 | if self.save_all_boxes: 106 | '''save all predictions''' 107 | all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor) 108 | all_boxes_save = all_boxes.view(-1).tolist() # (4N, ) 109 | return {"target_bbox": self.state, 110 | "all_boxes": all_boxes_save, 111 | "best_score": max_score} 112 | else: 113 | return {"target_bbox": self.state, 114 | "best_score": max_score} 115 | 116 | def map_box_back(self, pred_box: list, resize_factor: float): 117 | cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3] 118 | cx, cy, w, h = pred_box 119 | half_side = 0.5 * self.params.search_size / resize_factor 120 | cx_real = cx + (cx_prev - half_side) 121 | cy_real = cy + (cy_prev - half_side) 122 | return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h] 123 | 124 | def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float): 125 | cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3] 126 | cx, cy, w, h = pred_box.unbind(-1) # (N,4) --> (N,) 127 | half_side = 0.5 * self.params.search_size / resize_factor 128 | cx_real = cx + (cx_prev - half_side) 129 | cy_real = cy + (cy_prev - half_side) 130 | return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1) 131 | 132 | 133 | def get_tracker_class(): 134 | return BATTrack 135 | -------------------------------------------------------------------------------- /lib/test/tracker/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class Preprocessor(object): 5 | def __init__(self): 6 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda() 7 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda() 8 | 9 | def process(self, img_arr: np.ndarray): 10 | # Deal with the image patch 11 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0) 12 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) 13 | return img_tensor_norm 14 | 15 | class PreprocessorMM(object): 16 | def __init__(self): 17 | self.mean = torch.tensor([0.485, 0.456, 0.406, 0.485, 0.456, 0.406]).view((1, 6, 1, 1)).cuda() 18 | self.std = torch.tensor([0.229, 0.224, 0.225, 0.229, 0.224, 0.225]).view((1, 6, 1, 1)).cuda() 19 | 20 | def process(self, img_arr: 
np.ndarray): 21 | # Deal with the image patch 22 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0) 23 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,6,H,W) 24 | return img_tensor_norm 25 | 26 | 27 | class PreprocessorX(object): 28 | def __init__(self): 29 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda() 30 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda() 31 | 32 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 33 | # Deal with the image patch 34 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0) 35 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) 36 | # Deal with the attention mask 37 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W) 38 | return img_tensor_norm, amask_tensor 39 | 40 | 41 | class PreprocessorX_onnx(object): 42 | def __init__(self): 43 | self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)) 44 | self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)) 45 | 46 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 47 | """img_arr: (H,W,3), amask_arr: (H,W)""" 48 | # Deal with the image patch 49 | img_arr_4d = img_arr[np.newaxis, :, :, :].transpose(0, 3, 1, 2) 50 | img_arr_4d = (img_arr_4d / 255.0 - self.mean) / self.std # (1, 3, H, W) 51 | # Deal with the attention mask 52 | amask_arr_3d = amask_arr[np.newaxis, :, :] # (1,H,W) 53 | return img_arr_4d.astype(np.float32), amask_arr_3d.astype(bool) 54 | -------------------------------------------------------------------------------- /lib/test/tracker/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ############## used to visualize eliminated tokens ################# 5 | def get_keep_indices(decisions): 6 | keep_indices = [] 7 | for i in range(3): 8 | if i == 0: 9 | keep_indices.append(decisions[i]) 10 | else: 11 | keep_indices.append(keep_indices[-1][decisions[i]]) 12 | return keep_indices 13 | 14 | 15 | def gen_masked_tokens(tokens, indices, alpha=0.2): 16 | # indices = [i for i in range(196) if i not in indices] 17 | indices = indices[0].astype(int) 18 | tokens = tokens.copy() 19 | tokens[indices] = alpha * tokens[indices] + (1 - alpha) * 255 20 | return tokens 21 | 22 | 23 | def recover_image(tokens, H, W, Hp, Wp, patch_size): 24 | # image: (C, 196, 16, 16) 25 | image = tokens.reshape(Hp, Wp, patch_size, patch_size, 3).swapaxes(1, 2).reshape(H, W, 3) 26 | return image 27 | 28 | 29 | def pad_img(img): 30 | height, width, channels = img.shape 31 | im_bg = np.ones((height, width + 8, channels)) * 255 32 | im_bg[0:height, 0:width, :] = img 33 | return im_bg 34 | 35 | 36 | def gen_visualization(image, mask_indices, patch_size=16): 37 | # image [224, 224, 3] 38 | # mask_indices, list of masked token indices 39 | 40 | # mask mask_indices need to cat 41 | # mask_indices = mask_indices[::-1] 42 | num_stages = len(mask_indices) 43 | for i in range(1, num_stages): 44 | mask_indices[i] = np.concatenate([mask_indices[i-1], mask_indices[i]], axis=1) 45 | 46 | # keep_indices = get_keep_indices(decisions) 47 | image = np.asarray(image) 48 | H, W, C = image.shape 49 | Hp, Wp = H // patch_size, W // patch_size 50 | image_tokens = image.reshape(Hp, patch_size, Wp, patch_size, 3).swapaxes(1, 2).reshape(Hp * Wp, patch_size, patch_size, 3) 51 | 52 | stages = [ 53 | 
recover_image(gen_masked_tokens(image_tokens, mask_indices[i]), H, W, Hp, Wp, patch_size) 54 | for i in range(num_stages) 55 | ] 56 | imgs = [image] + stages 57 | imgs = [pad_img(img) for img in imgs] 58 | viz = np.concatenate(imgs, axis=1) 59 | return viz 60 | -------------------------------------------------------------------------------- /lib/test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .params import TrackerParams, FeatureParams, Choice -------------------------------------------------------------------------------- /lib/test/utils/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '..', '..', '..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /lib/test/utils/hann.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | 6 | def hann1d(sz: int, centered = True) -> torch.Tensor: 7 | """1D cosine window.""" 8 | if centered: 9 | return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float())) 10 | w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float())) 11 | return torch.cat([w, w[1:sz-sz//2].flip((0,))]) 12 | 13 | 14 | def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor: 15 | """2D cosine window.""" 16 | return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1) 17 | 18 | 19 | def hann2d_bias(sz: torch.Tensor, ctr_point: torch.Tensor, centered = True) -> torch.Tensor: 20 | """2D cosine window.""" 21 | distance = torch.stack([ctr_point, sz-ctr_point], dim=0) 22 | max_distance, _ = distance.max(dim=0) 23 | 24 | hann1d_x = hann1d(max_distance[0].item() * 2, centered) 25 | hann1d_x = hann1d_x[max_distance[0] - distance[0, 0]: max_distance[0] + distance[1, 0]] 26 | hann1d_y = hann1d(max_distance[1].item() * 2, centered) 27 | hann1d_y = hann1d_y[max_distance[1] - distance[0, 1]: max_distance[1] + distance[1, 1]] 28 | 29 | return hann1d_y.reshape(1, 1, -1, 1) * hann1d_x.reshape(1, 1, 1, -1) 30 | 31 | 32 | 33 | def hann2d_clipped(sz: torch.Tensor, effective_sz: torch.Tensor, centered = True) -> torch.Tensor: 34 | """1D clipped cosine window.""" 35 | 36 | # Ensure that the difference is even 37 | effective_sz += (effective_sz - sz) % 2 38 | effective_window = hann1d(effective_sz[0].item(), True).reshape(1, 1, -1, 1) * hann1d(effective_sz[1].item(), True).reshape(1, 1, 1, -1) 39 | 40 | pad = (sz - effective_sz) // 2 41 | 42 | window = F.pad(effective_window, (pad[1].item(), pad[1].item(), pad[0].item(), pad[0].item()), 'replicate') 43 | 44 | if centered: 45 | return window 46 | else: 47 | mid = (sz / 2).int() 48 | window_shift_lr = torch.cat((window[:, :, :, mid[1]:], window[:, :, :, :mid[1]]), 3) 49 | return torch.cat((window_shift_lr[:, :, mid[0]:, :], window_shift_lr[:, :, :mid[0], :]), 2) 50 | 51 | 52 | def gauss_fourier(sz: int, sigma: float, half: bool = False) -> torch.Tensor: 53 | if half: 54 | k = torch.arange(0, int(sz/2+1)) 55 | else: 56 | k = 
torch.arange(-int((sz-1)/2), int(sz/2+1)) 57 | return (math.sqrt(2*math.pi) * sigma / sz) * torch.exp(-2 * (math.pi * sigma * k.float() / sz)**2) 58 | 59 | 60 | def gauss_spatial(sz, sigma, center=0, end_pad=0): 61 | k = torch.arange(-(sz-1)/2, (sz+1)/2+end_pad) 62 | return torch.exp(-1.0/(2*sigma**2) * (k - center)**2) 63 | 64 | 65 | def label_function(sz: torch.Tensor, sigma: torch.Tensor): 66 | return gauss_fourier(sz[0].item(), sigma[0].item()).reshape(1, 1, -1, 1) * gauss_fourier(sz[1].item(), sigma[1].item(), True).reshape(1, 1, 1, -1) 67 | 68 | def label_function_spatial(sz: torch.Tensor, sigma: torch.Tensor, center: torch.Tensor = torch.zeros(2), end_pad: torch.Tensor = torch.zeros(2)): 69 | """The origin is in the middle of the image.""" 70 | return gauss_spatial(sz[0].item(), sigma[0].item(), center[0], end_pad[0].item()).reshape(1, 1, -1, 1) * \ 71 | gauss_spatial(sz[1].item(), sigma[1].item(), center[1], end_pad[1].item()).reshape(1, 1, 1, -1) 72 | 73 | 74 | def cubic_spline_fourier(f, a): 75 | """The continuous Fourier transform of a cubic spline kernel.""" 76 | 77 | bf = (6*(1 - torch.cos(2 * math.pi * f)) + 3*a*(1 - torch.cos(4 * math.pi * f)) 78 | - (6 + 8*a)*math.pi*f*torch.sin(2 * math.pi * f) - 2*a*math.pi*f*torch.sin(4 * math.pi * f)) \ 79 | / (4 * math.pi**4 * f**4) 80 | 81 | bf[f == 0] = 1 82 | 83 | return bf 84 | 85 | def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor): 86 | """Computes maximum and argmax in the last two dimensions.""" 87 | 88 | max_val_row, argmax_row = torch.max(a, dim=-2) 89 | max_val, argmax_col = torch.max(max_val_row, dim=-1) 90 | argmax_row = argmax_row.view(argmax_col.numel(),-1)[torch.arange(argmax_col.numel()), argmax_col.view(-1)] 91 | argmax_row = argmax_row.reshape(argmax_col.shape) 92 | argmax = torch.cat((argmax_row.unsqueeze(-1), argmax_col.unsqueeze(-1)), -1) 93 | return max_val, argmax 94 | -------------------------------------------------------------------------------- /lib/test/utils/load_text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def load_text_numpy(path, delimiter, dtype): 6 | if isinstance(delimiter, (tuple, list)): 7 | for d in delimiter: 8 | try: 9 | ground_truth_rect = np.loadtxt(path, delimiter=d, dtype=dtype) 10 | return ground_truth_rect 11 | except: 12 | pass 13 | 14 | raise Exception('Could not read file {}'.format(path)) 15 | else: 16 | ground_truth_rect = np.loadtxt(path, delimiter=delimiter, dtype=dtype) 17 | return ground_truth_rect 18 | 19 | 20 | def load_text_pandas(path, delimiter, dtype): 21 | if isinstance(delimiter, (tuple, list)): 22 | for d in delimiter: 23 | try: 24 | ground_truth_rect = pd.read_csv(path, delimiter=d, header=None, dtype=dtype, na_filter=False, 25 | low_memory=False).values 26 | return ground_truth_rect 27 | except Exception as e: 28 | pass 29 | 30 | raise Exception('Could not read file {}'.format(path)) 31 | else: 32 | ground_truth_rect = pd.read_csv(path, delimiter=delimiter, header=None, dtype=dtype, na_filter=False, 33 | low_memory=False).values 34 | return ground_truth_rect 35 | 36 | 37 | def load_text(path, delimiter=' ', dtype=np.float32, backend='numpy'): 38 | if backend == 'numpy': 39 | return load_text_numpy(path, delimiter, dtype) 40 | elif backend == 'pandas': 41 | return load_text_pandas(path, delimiter, dtype) 42 | 43 | 44 | def load_str(path): 45 | with open(path, "r") as f: 46 | text_str = f.readline().strip().lower() 47 | return text_str 48 | 
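A minimal usage sketch for the text loader above (illustrative only, not part of the repository): the file path is a placeholder, and it assumes a ground-truth text file in which each row stores one x,y,w,h box. Passing a tuple of delimiters lets load_text fall back from comma- to tab-separated formats, as implemented in load_text_numpy.

from lib.test.utils.load_text import load_text
import numpy as np

# Hypothetical ground-truth file; each row is "x,y,w,h" or tab-separated "x y w h".
gt_path = '/path/to/sequence/groundtruth.txt'
gt_rects = load_text(gt_path, delimiter=(',', '\t'), dtype=np.float64, backend='numpy')
print(gt_rects.shape)  # expected (num_frames, 4)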
-------------------------------------------------------------------------------- /lib/test/utils/params.py: -------------------------------------------------------------------------------- 1 | from lib.utils import TensorList 2 | import random 3 | 4 | 5 | class TrackerParams: 6 | """Class for tracker parameters.""" 7 | def set_default_values(self, default_vals: dict): 8 | for name, val in default_vals.items(): 9 | if not hasattr(self, name): 10 | setattr(self, name, val) 11 | 12 | def get(self, name: str, *default): 13 | """Get a parameter value with the given name. If it does not exist, it returns the default value given as the 14 | second argument, or raises an error if no default value is given.""" 15 | if len(default) > 1: 16 | raise ValueError('Can only give one default value.') 17 | 18 | if not default: 19 | return getattr(self, name) 20 | 21 | return getattr(self, name, default[0]) 22 | 23 | def has(self, name: str): 24 | """Check if there exists a parameter with the given name.""" 25 | return hasattr(self, name) 26 | 27 | 28 | class FeatureParams: 29 | """Class for feature specific parameters""" 30 | def __init__(self, *args, **kwargs): 31 | if len(args) > 0: 32 | raise ValueError 33 | 34 | for name, val in kwargs.items(): 35 | if isinstance(val, list): 36 | setattr(self, name, TensorList(val)) 37 | else: 38 | setattr(self, name, val) 39 | 40 | 41 | def Choice(*args): 42 | """Can be used to sample random parameter values.""" 43 | return random.choice(args) 44 | -------------------------------------------------------------------------------- /lib/test/vot/bat_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | env_path = os.path.join(os.path.dirname(__file__), '../../..') 4 | if env_path not in sys.path: 5 | sys.path.append(env_path) 6 | from lib.test.vot.bat_class import run_vot_exp 7 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 8 | 9 | 10 | run_vot_exp('bat', 'rgbd', vis=False, out_conf=True, channel_type='rgbd') 11 | -------------------------------------------------------------------------------- /lib/test/vot/bat_class.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import pdb 7 | import cv2 8 | import torch 9 | # import vot 10 | import sys 11 | import time 12 | import os 13 | from lib.test.evaluation import Tracker 14 | import lib.test.vot.vot as vot 15 | from lib.test.vot.vot22_utils import * 16 | from lib.train.dataset.depth_utils import get_rgbd_frame 17 | 18 | 19 | class bat(object): 20 | def __init__(self, tracker_name='', para_name=''): 21 | # create tracker 22 | tracker_info = Tracker(tracker_name, para_name, "vot22", None) 23 | params = tracker_info.get_parameters() 24 | params.visualization = False 25 | params.debug = False 26 | self.tracker = tracker_info.create_tracker(params) 27 | 28 | def write(self, str): 29 | txt_path = "" 30 | file = open(txt_path, 'a') 31 | file.write(str) 32 | 33 | def initialize(self, img_rgb, selection): 34 | # init on the 1st frame 35 | # region = rect_from_mask(mask) 36 | x, y, w, h = selection 37 | bbox = [x,y,w,h] 38 | self.H, self.W, _ = img_rgb.shape 39 | init_info = {'init_bbox': bbox} 40 | _ = self.tracker.initialize(img_rgb, init_info) 41 | 42 | def track(self, img_rgb): 43 | # track 44 | outputs = self.tracker.track(img_rgb) 45 | pred_bbox = 
outputs['target_bbox'] 46 | max_score = outputs['best_score'] #.max().cpu().numpy() 47 | return pred_bbox, max_score 48 | 49 | 50 | def run_vot_exp(tracker_name, para_name, vis=False, out_conf=False, channel_type='color'): 51 | 52 | torch.set_num_threads(1) 53 | save_root = os.path.join('', para_name) 54 | if vis and (not os.path.exists(save_root)): 55 | os.mkdir(save_root) 56 | tracker = bat(tracker_name=tracker_name, para_name=para_name) 57 | 58 | if channel_type=='rgb': 59 | channel_type=None 60 | handle = vot.VOT("rectangle", channels=channel_type) 61 | 62 | selection = handle.region() 63 | imagefile = handle.frame() 64 | if not imagefile: 65 | sys.exit(0) 66 | if vis: 67 | '''for vis''' 68 | seq_name = imagefile.split('/')[-3] 69 | save_v_dir = os.path.join(save_root,seq_name) 70 | if not os.path.exists(save_v_dir): 71 | os.mkdir(save_v_dir) 72 | cur_time = int(time.time() % 10000) 73 | save_dir = os.path.join(save_v_dir, str(cur_time)) 74 | if not os.path.exists(save_dir): 75 | os.makedirs(save_dir) 76 | 77 | # read rgbd data 78 | if isinstance(imagefile, list) and len(imagefile)==2: 79 | image = get_rgbd_frame(imagefile[0], imagefile[1], dtype='rgbcolormap', depth_clip=True) 80 | else: 81 | image = cv2.cvtColor(cv2.imread(imagefile), cv2.COLOR_BGR2RGB) # Right 82 | 83 | tracker.initialize(image, selection) 84 | 85 | while True: 86 | imagefile = handle.frame() 87 | if not imagefile: 88 | break 89 | 90 | # read rgbd data 91 | if isinstance(imagefile, list) and len(imagefile) == 2: 92 | image = get_rgbd_frame(imagefile[0], imagefile[1], dtype='rgbcolormap', depth_clip=True) 93 | else: 94 | image = cv2.cvtColor(cv2.imread(imagefile), cv2.COLOR_BGR2RGB) # Right 95 | 96 | b1, max_score = tracker.track(image) 97 | 98 | if out_conf: 99 | handle.report(vot.Rectangle(*b1), max_score) 100 | else: 101 | handle.report(vot.Rectangle(*b1)) 102 | if vis: 103 | '''Visualization''' 104 | # original image 105 | image_ori = image[:,:,::-1].copy() # RGB --> BGR 106 | image_name = imagefile.split('/')[-1] 107 | save_path = os.path.join(save_dir, image_name) 108 | image_b = image_ori.copy() 109 | cv2.rectangle(image_b, (int(b1[0]), int(b1[1])), 110 | (int(b1[0] + b1[2]), int(b1[1] + b1[3])), (0, 0, 255), 2) 111 | image_b_name = image_name.replace('.jpg','_bbox.jpg') 112 | save_path = os.path.join(save_dir, image_b_name) 113 | cv2.imwrite(save_path, image_b) 114 | 115 | -------------------------------------------------------------------------------- /lib/test/vot/vot.py: -------------------------------------------------------------------------------- 1 | """ 2 | \file vot.py 3 | @brief Python utility functions for VOT integration 4 | @author Luka Cehovin, Alessio Dore 5 | @date 2016 6 | """ 7 | 8 | import sys 9 | import copy 10 | import collections 11 | import numpy as np 12 | 13 | try: 14 | import trax 15 | except ImportError: 16 | raise Exception('TraX support not found. 
Please add trax module to Python path.') 17 | 18 | Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height']) 19 | Point = collections.namedtuple('Point', ['x', 'y']) 20 | Polygon = collections.namedtuple('Polygon', ['points']) 21 | 22 | class VOT(object): 23 | """ Base class for Python VOT integration """ 24 | def __init__(self, region_format, channels=None): 25 | """ Constructor 26 | Args: 27 | region_format: Region format options 28 | """ 29 | assert(region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON, trax.Region.MASK]) 30 | 31 | if channels is None: 32 | channels = ['color'] 33 | elif channels == 'rgbd': 34 | channels = ['color', 'depth'] 35 | elif channels == 'rgbt': 36 | channels = ['color', 'ir'] 37 | elif channels == 'ir': 38 | channels = ['ir'] 39 | else: 40 | raise Exception('Illegal configuration {}.'.format(channels)) 41 | 42 | self._trax = trax.Server([region_format], [trax.Image.PATH], channels, customMetadata=dict(vot="python")) 43 | 44 | request = self._trax.wait() 45 | assert(request.type == 'initialize') 46 | if isinstance(request.region, trax.Polygon): 47 | self._region = Polygon([Point(x[0], x[1]) for x in request.region]) 48 | elif isinstance(request.region, trax.Mask): 49 | self._region = request.region.array(True) 50 | else: 51 | self._region = Rectangle(*request.region.bounds()) 52 | self._image = [x.path() for k, x in request.image.items()] 53 | if len(self._image) == 1: 54 | self._image = self._image[0] 55 | 56 | self._trax.status(request.region) 57 | 58 | def region(self): 59 | """ 60 | Send configuration message to the client and receive the initialization 61 | region and the path of the first image 62 | Returns: 63 | initialization region 64 | """ 65 | 66 | return self._region 67 | 68 | def report(self, region, confidence = None): 69 | """ 70 | Report the tracking results to the client 71 | Arguments: 72 | region: region for the frame 73 | """ 74 | assert(isinstance(region, (Rectangle, Polygon, np.ndarray))) 75 | if isinstance(region, Polygon): 76 | tregion = trax.Polygon.create([(x.x, x.y) for x in region.points]) 77 | elif isinstance(region, np.ndarray): 78 | tregion = trax.Mask.create(region) 79 | else: 80 | tregion = trax.Rectangle.create(region.x, region.y, region.width, region.height) 81 | properties = {} 82 | if not confidence is None: 83 | properties['confidence'] = confidence 84 | self._trax.status(tregion, properties) 85 | 86 | def frame(self): 87 | """ 88 | Get a frame (image path) from client 89 | Returns: 90 | absolute path of the image 91 | """ 92 | if hasattr(self, "_image"): 93 | image = self._image 94 | del self._image 95 | return image 96 | 97 | request = self._trax.wait() 98 | 99 | if request.type == 'frame': 100 | image = [x.path() for k, x in request.image.items()] 101 | if len(image) == 1: 102 | return image[0] 103 | return image 104 | else: 105 | return None 106 | 107 | 108 | def quit(self): 109 | if hasattr(self, '_trax'): 110 | self._trax.quit() 111 | 112 | def __del__(self): 113 | self.quit() 114 | -------------------------------------------------------------------------------- /lib/test/vot/vot22_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def make_full_size(x, output_sz): 5 | """ 6 | zero-pad input x (right and down) to match output_sz 7 | x: numpy array e.g., binary mask 8 | output_sz: size of the output [width, height] 9 | """ 10 | if x.shape[0] == output_sz[1] and x.shape[1] == output_sz[0]: 11 | return x 12 | pad_x = 
output_sz[0] - x.shape[1] 13 | if pad_x < 0: 14 | x = x[:, :x.shape[1] + pad_x] 15 | # padding has to be set to zero, otherwise pad function fails 16 | pad_x = 0 17 | pad_y = output_sz[1] - x.shape[0] 18 | if pad_y < 0: 19 | x = x[:x.shape[0] + pad_y, :] 20 | # padding has to be set to zero, otherwise pad function fails 21 | pad_y = 0 22 | return np.pad(x, ((0, pad_y), (0, pad_x)), 'constant', constant_values=0) 23 | 24 | 25 | def rect_from_mask(mask): 26 | """ 27 | create an axis-aligned rectangle from a given binary mask 28 | the rectangle is created as the minimal rectangle containing all non-zero pixels 29 | """ 30 | x_ = np.sum(mask, axis=0) 31 | y_ = np.sum(mask, axis=1) 32 | x0 = np.min(np.nonzero(x_)) 33 | x1 = np.max(np.nonzero(x_)) 34 | y0 = np.min(np.nonzero(y_)) 35 | y1 = np.max(np.nonzero(y_)) 36 | return [x0, y0, x1 - x0 + 1, y1 - y0 + 1] 37 | 38 | 39 | def mask_from_rect(rect, output_sz): 40 | """ 41 | create a binary mask from a given rectangle 42 | rect: axis-aligned rectangle [x0, y0, width, height] 43 | output_sz: size of the output [width, height] 44 | """ 45 | mask = np.zeros((output_sz[1], output_sz[0]), dtype=np.uint8) 46 | x0 = max(int(round(rect[0])), 0) 47 | y0 = max(int(round(rect[1])), 0) 48 | x1 = min(int(round(rect[0] + rect[2])), output_sz[0]) 49 | y1 = min(int(round(rect[1] + rect[3])), output_sz[1]) 50 | mask[y0:y1, x0:x1] = 1 51 | return mask 52 | 53 | 54 | def bbox_clip(x1, y1, x2, y2, boundary, min_sz=10): 55 | """boundary (H,W)""" 56 | x1_new = max(0, min(x1, boundary[1] - min_sz)) 57 | y1_new = max(0, min(y1, boundary[0] - min_sz)) 58 | x2_new = max(min_sz, min(x2, boundary[1])) 59 | y2_new = max(min_sz, min(y2, boundary[0])) 60 | return x1_new, y1_new, x2_new, y2_new -------------------------------------------------------------------------------- /lib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .admin.multigpu import MultiGPU 2 | -------------------------------------------------------------------------------- /lib/train/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '../..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /lib/train/actors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_actor import BaseActor 2 | from .bat import BATActor 3 | -------------------------------------------------------------------------------- /lib/train/actors/base_actor.py: -------------------------------------------------------------------------------- 1 | from lib.utils import TensorDict 2 | 3 | 4 | class BaseActor: 5 | """ Base class for actor. The actor class handles the passing of the data through the network 6 | and calculating the loss""" 7 | def __init__(self, net, objective): 8 | """ 9 | args: 10 | net - The network to train 11 | objective - The loss function 12 | """ 13 | self.net = net 14 | self.objective = objective 15 | 16 | def __call__(self, data: TensorDict): 17 | """ Called in each training iteration. 
Should pass in input data through the network, calculate the loss, and 18 | return the training stats for the input data 19 | args: 20 | data - A TensorDict containing all the necessary data blocks. 21 | 22 | returns: 23 | loss - loss for the input data 24 | stats - a dict containing detailed losses 25 | """ 26 | raise NotImplementedError 27 | 28 | def to(self, device): 29 | """ Move the network to device 30 | args: 31 | device - device to use. 'cpu' or 'cuda' 32 | """ 33 | self.net.to(device) 34 | 35 | def train(self, mode=True): 36 | """ Set whether the network is in train mode. 37 | args: 38 | mode (True) - Bool specifying whether in training mode. 39 | """ 40 | self.net.train(mode) 41 | 42 | def eval(self): 43 | """ Set network to eval mode""" 44 | self.train(False) -------------------------------------------------------------------------------- /lib/train/actors/bat.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | from . import BaseActor 4 | from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy 5 | import torch 6 | from ...utils.heapmap_utils import generate_heatmap 7 | from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate 8 | from lib.train.admin import multigpu 9 | 10 | 11 | class BATActor(BaseActor): 12 | """ Actor for training BAT models """ 13 | 14 | def __init__(self, net, objective, loss_weight, settings, cfg=None): 15 | super().__init__(net, objective) 16 | self.loss_weight = loss_weight 17 | self.settings = settings 18 | self.bs = self.settings.batchsize # batch size 19 | self.cfg = cfg 20 | 21 | def fix_bns(self): 22 | net = self.net.module if multigpu.is_multi_gpu(self.net) else self.net 23 | net.box_head.apply(self.fix_bn) 24 | 25 | def fix_bn(self, m): 26 | classname = m.__class__.__name__ 27 | if classname.find('BatchNorm') != -1: 28 | m.eval() 29 | 30 | def __call__(self, data): 31 | """ 32 | args: 33 | data - The input data, should contain the fields 'template', 'search', 'gt_bbox'. 
34 | template_images: (N_t, batch, 3, H, W) 35 | search_images: (N_s, batch, 3, H, W) 36 | returns: 37 | loss - the training loss 38 | status - dict containing detailed losses 39 | """ 40 | # forward pass 41 | out_dict = self.forward_pass(data) 42 | 43 | # compute losses 44 | loss, status = self.compute_losses(out_dict, data) 45 | 46 | return loss, status 47 | 48 | def forward_pass(self, data): 49 | # currently only support 1 template and 1 search region 50 | assert len(data['template_images']) == 1 51 | assert len(data['search_images']) == 1 52 | 53 | template_list = [] 54 | for i in range(self.settings.num_template): 55 | template_img_i = data['template_images'][i].view(-1, 56 | *data['template_images'].shape[2:]) # (batch, 6, 128, 128) 57 | template_list.append(template_img_i) 58 | 59 | search_img = data['search_images'][0].view(-1, *data['search_images'].shape[2:]) # (batch, 6, 320, 320) 60 | 61 | box_mask_z = None 62 | ce_keep_rate = None 63 | if self.cfg.MODEL.BACKBONE.CE_LOC: 64 | box_mask_z = generate_mask_cond(self.cfg, template_list[0].shape[0], template_list[0].device, 65 | data['template_anno'][0]) 66 | 67 | ce_start_epoch = self.cfg.TRAIN.CE_START_EPOCH 68 | ce_warm_epoch = self.cfg.TRAIN.CE_WARM_EPOCH 69 | ce_keep_rate = adjust_keep_rate(data['epoch'], warmup_epochs=ce_start_epoch, 70 | total_epochs=ce_start_epoch + ce_warm_epoch, 71 | ITERS_PER_EPOCH=1, 72 | base_keep_rate=self.cfg.MODEL.BACKBONE.CE_KEEP_RATIO[0]) 73 | # ce_keep_rate = 0.7 74 | 75 | if len(template_list) == 1: 76 | template_list = template_list[0] 77 | 78 | out_dict = self.net(template=template_list, 79 | search=search_img, 80 | ce_template_mask=box_mask_z, 81 | ce_keep_rate=ce_keep_rate, 82 | return_last_attn=False) 83 | 84 | return out_dict 85 | 86 | def compute_losses(self, pred_dict, gt_dict, return_status=True): 87 | # gt gaussian map 88 | gt_bbox = gt_dict['search_anno'][-1] # (Ns, batch, 4) (x1,y1,w,h) -> (batch, 4) 89 | gt_gaussian_maps = generate_heatmap(gt_dict['search_anno'], self.cfg.DATA.SEARCH.SIZE, self.cfg.MODEL.BACKBONE.STRIDE) 90 | gt_gaussian_maps = gt_gaussian_maps[-1].unsqueeze(1) # (B,1,H,W) 91 | 92 | # Get boxes 93 | pred_boxes = pred_dict['pred_boxes'] 94 | if torch.isnan(pred_boxes).any(): 95 | raise ValueError("Network outputs is NAN! 
Stop Training") 96 | num_queries = pred_boxes.size(1) 97 | pred_boxes_vec = box_cxcywh_to_xyxy(pred_boxes).view(-1, 4) # (B,N,4) --> (BN,4) (x1,y1,x2,y2) 98 | gt_boxes_vec = box_xywh_to_xyxy(gt_bbox)[:, None, :].repeat((1, num_queries, 1)).view(-1, 4).clamp(min=0.0, 99 | max=1.0) # (B,4) --> (B,1,4) --> (B,N,4) 100 | # compute giou and iou 101 | try: 102 | giou_loss, iou = self.objective['giou'](pred_boxes_vec, gt_boxes_vec) # (BN,4) (BN,4) 103 | except: 104 | giou_loss, iou = torch.tensor(0.0).cuda(), torch.tensor(0.0).cuda() 105 | # compute l1 loss 106 | l1_loss = self.objective['l1'](pred_boxes_vec, gt_boxes_vec) # (BN,4) (BN,4) 107 | # compute location loss 108 | if 'score_map' in pred_dict: 109 | location_loss = self.objective['focal'](pred_dict['score_map'], gt_gaussian_maps) 110 | else: 111 | location_loss = torch.tensor(0.0, device=l1_loss.device) 112 | # weighted sum 113 | loss = self.loss_weight['giou'] * giou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight['focal'] * location_loss 114 | if return_status: 115 | # status for log 116 | mean_iou = iou.detach().mean() 117 | status = {"Loss/total": loss.item(), 118 | "Loss/giou": giou_loss.item(), 119 | "Loss/l1": l1_loss.item(), 120 | "Loss/location": location_loss.item(), 121 | "IoU": mean_iou.item()} 122 | return loss, status 123 | else: 124 | return loss -------------------------------------------------------------------------------- /lib/train/admin/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_settings, create_default_local_file_train 2 | from .tensorboard import TensorboardWriter 3 | from .stats import AverageMeter, StatValue -------------------------------------------------------------------------------- /lib/train/admin/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from collections import OrderedDict 4 | 5 | 6 | def create_default_local_file(): 7 | path = os.path.join(os.path.dirname(__file__), 'local.py') 8 | 9 | empty_str = '\'\'' 10 | default_settings = OrderedDict({ 11 | 'workspace_dir': empty_str, 12 | 'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'', 13 | 'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'', 14 | 'lasot_dir': empty_str, 15 | 'got10k_dir': empty_str, 16 | 'trackingnet_dir': empty_str, 17 | 'coco_dir': empty_str, 18 | 'lvis_dir': empty_str, 19 | 'sbd_dir': empty_str, 20 | 'imagenet_dir': empty_str, 21 | 'imagenetdet_dir': empty_str, 22 | 'ecssd_dir': empty_str, 23 | 'hkuis_dir': empty_str, 24 | 'msra10k_dir': empty_str, 25 | 'davis_dir': empty_str, 26 | 'youtubevos_dir': empty_str}) 27 | 28 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 29 | 'tensorboard_dir': 'Directory for tensorboard files.'} 30 | 31 | with open(path, 'w') as f: 32 | f.write('class EnvironmentSettings:\n') 33 | f.write(' def __init__(self):\n') 34 | 35 | for attr, attr_val in default_settings.items(): 36 | comment_str = None 37 | if attr in comment: 38 | comment_str = comment[attr] 39 | if comment_str is None: 40 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 41 | else: 42 | f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str)) 43 | 44 | 45 | def create_default_local_file_train(workspace_dir, data_dir): 46 | path = os.path.join(os.path.dirname(__file__), 'local.py') 47 | 48 | empty_str = '\'\'' 49 | default_settings = OrderedDict({ 50 | 'workspace_dir': workspace_dir, 51 | 
'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files. 52 | 'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'), 53 | 'got10k_val_dir': os.path.join(data_dir, 'got10k/val'), 54 | 'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'), 55 | 'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'), 56 | 'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'), 57 | 'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'), 58 | 'coco_dir': os.path.join(data_dir, 'coco'), 59 | 'lasot_dir': os.path.join(data_dir, 'lasot'), 60 | 'got10k_dir': os.path.join(data_dir, 'got10k/train'), 61 | 'trackingnet_dir': os.path.join(data_dir, 'trackingnet'), 62 | 'depthtrack_dir': os.path.join(data_dir, 'depthtrack/train'), 63 | 'lasher_dir': os.path.join(data_dir, 'lasher/trainingset'), 64 | 'visevent_dir': os.path.join(data_dir, 'visevent/train'), 65 | }) 66 | 67 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 68 | 'tensorboard_dir': 'Directory for tensorboard files.'} 69 | 70 | with open(path, 'w') as f: 71 | f.write('class EnvironmentSettings:\n') 72 | f.write(' def __init__(self):\n') 73 | 74 | for attr, attr_val in default_settings.items(): 75 | comment_str = None 76 | if attr in comment: 77 | comment_str = comment[attr] 78 | if comment_str is None: 79 | if attr_val == empty_str: 80 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 81 | else: 82 | f.write(' self.{} = \'{}\'\n'.format(attr, attr_val)) 83 | else: 84 | f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 85 | 86 | 87 | def env_settings(): 88 | env_module_name = 'lib.train.admin.local' 89 | try: 90 | env_module = importlib.import_module(env_module_name) 91 | return env_module.EnvironmentSettings() 92 | except: 93 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 94 | 95 | create_default_local_file() 96 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file)) 97 | -------------------------------------------------------------------------------- /lib/train/admin/local.py: -------------------------------------------------------------------------------- 1 | class EnvironmentSettings: 2 | def __init__(self): 3 | self.workspace_dir = '/root/BAT-main' # Base directory for saving network checkpoints. 4 | self.tensorboard_dir = '/root/BAT-main/tensorboard' # Directory for tensorboard files. 
5 | self.pretrained_networks = '/root/BAT-main/pretrained_networks' 6 | self.got10k_val_dir = '/root/BAT-main/data/got10k/val' 7 | self.lasot_lmdb_dir = '/root/BAT-main/data/lasot_lmdb' 8 | self.got10k_lmdb_dir = '/root/BAT-main/data/got10k_lmdb' 9 | self.trackingnet_lmdb_dir = '/root/BAT-main/data/trackingnet_lmdb' 10 | self.coco_lmdb_dir = '/root/BAT-main/data/coco_lmdb' 11 | self.coco_dir = '/root/BAT-main/data/coco' 12 | self.lasot_dir = '/root/BAT-main/data/lasot' 13 | self.got10k_dir = '/root/BAT-main/data/got10k/train' 14 | self.trackingnet_dir = '/root/BAT-main/data/trackingnet' 15 | self.depthtrack_dir = '/root/BAT-main/data/depthtrack/train' 16 | self.lasher_dir = '/root/LasHeR/TrainingSet' 17 | self.visevent_dir = '/root/BAT-main/data/visevent/train' 18 | -------------------------------------------------------------------------------- /lib/train/admin/multigpu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training 3 | 4 | 5 | def is_multi_gpu(net): 6 | return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel)) 7 | 8 | 9 | class MultiGPU(nn.parallel.distributed.DistributedDataParallel): 10 | def __getattr__(self, item): 11 | try: 12 | return super().__getattr__(item) 13 | except: 14 | pass 15 | return getattr(self.module, item) 16 | -------------------------------------------------------------------------------- /lib/train/admin/settings.py: -------------------------------------------------------------------------------- 1 | from lib.train.admin.environment import env_settings 2 | 3 | 4 | class Settings: 5 | """ Training settings, e.g. the paths to datasets and networks.""" 6 | def __init__(self): 7 | self.set_default() 8 | 9 | def set_default(self): 10 | self.env = env_settings() 11 | self.use_gpu = True 12 | 13 | 14 | -------------------------------------------------------------------------------- /lib/train/admin/stats.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class StatValue: 4 | def __init__(self): 5 | self.clear() 6 | 7 | def reset(self): 8 | self.val = 0 9 | 10 | def clear(self): 11 | self.reset() 12 | self.history = [] 13 | 14 | def update(self, val): 15 | self.val = val 16 | self.history.append(self.val) 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | def __init__(self): 22 | self.clear() 23 | self.has_new_data = False 24 | 25 | def reset(self): 26 | self.avg = 0 27 | self.val = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def clear(self): 32 | self.reset() 33 | self.history = [] 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def new_epoch(self): 42 | if self.count > 0: 43 | self.history.append(self.avg) 44 | self.reset() 45 | self.has_new_data = True 46 | else: 47 | self.has_new_data = False 48 | 49 | 50 | def topk_accuracy(output, target, topk=(1,)): 51 | """Computes the precision@k for the specified values of k""" 52 | single_input = not isinstance(topk, (tuple, list)) 53 | if single_input: 54 | topk = (topk,) 55 | 56 | maxk = max(topk) 57 | batch_size = target.size(0) 58 | 59 | _, pred = output.topk(maxk, 1, True, True) 60 | pred = pred.t() 61 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 62 | 63 | res = [] 64 | for k in topk: 65 | correct_k = 
correct[:k].view(-1).float().sum(0, keepdim=True)[0] 66 | res.append(correct_k * 100.0 / batch_size) 67 | 68 | if single_input: 69 | return res[0] 70 | 71 | return res 72 | -------------------------------------------------------------------------------- /lib/train/admin/tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | try: 4 | from torch.utils.tensorboard import SummaryWriter 5 | except: 6 | print('WARNING: You are using tensorboardX instead, since you have an old pytorch version.') 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | class TensorboardWriter: 11 | def __init__(self, directory, loader_names): 12 | self.directory = directory 13 | self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names}) 14 | 15 | def write_info(self, script_name, description): 16 | tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info')) 17 | tb_info_writer.add_text('Script_name', script_name) 18 | tb_info_writer.add_text('Description', description) 19 | tb_info_writer.close() 20 | 21 | def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1): 22 | for loader_name, loader_stats in stats.items(): 23 | if loader_stats is None: 24 | continue 25 | for var_name, val in loader_stats.items(): 26 | if hasattr(val, 'history') and getattr(val, 'has_new_data', True): 27 | self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch) -------------------------------------------------------------------------------- /lib/train/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import LTRLoader 2 | from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader 3 | -------------------------------------------------------------------------------- /lib/train/data/bounding_box_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rect_to_rel(bb, sz_norm=None):  5 | """Convert standard rectangular parametrization of the bounding box [x, y, w, h] 6 | to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate. 7 | args: 8 | bb - N x 4 tensor of boxes. 9 | sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given. 10 | """ 11 | 12 | c = bb[...,:2] + 0.5 * bb[...,2:] 13 | if sz_norm is None: 14 | c_rel = c / bb[...,2:] 15 | else: 16 | c_rel = c / sz_norm 17 | sz_rel = torch.log(bb[...,2:]) 18 | return torch.cat((c_rel, sz_rel), dim=-1) 19 | 20 | 21 | def rel_to_rect(bb, sz_norm=None): 22 | """Inverts the effect of rect_to_rel. See above.""" 23 | 24 | sz = torch.exp(bb[...,2:]) 25 | if sz_norm is None: 26 | c = bb[...,:2] * sz 27 | else: 28 | c = bb[...,:2] * sz_norm 29 | tl = c - 0.5 * sz 30 | return torch.cat((tl, sz), dim=-1) 31 | 32 | 33 | def masks_to_bboxes(mask, fmt='c'): 34 | 35 | """ Convert a mask tensor to one or more bounding boxes. 36 | Note: This function is a bit new, make sure it does what it says. /Andreas 37 | :param mask: Tensor of masks, shape = (..., H, W) 38 | :param fmt: bbox layout. 
'c' => "center + size" or (x_center, y_center, width, height) 39 | 't' => "top left + size" or (x_left, y_top, width, height) 40 | 'v' => "vertices" or (x_left, y_top, x_right, y_bottom) 41 | :return: tensor containing a batch of bounding boxes, shape = (..., 4) 42 | """ 43 | batch_shape = mask.shape[:-2] 44 | mask = mask.reshape((-1, *mask.shape[-2:])) 45 | bboxes = [] 46 | 47 | for m in mask: 48 | mx = m.sum(dim=-2).nonzero() 49 | my = m.sum(dim=-1).nonzero() 50 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0] 51 | bboxes.append(bb) 52 | 53 | bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device) 54 | bboxes = bboxes.reshape(batch_shape + (4,)) 55 | 56 | if fmt == 'v': 57 | return bboxes 58 | 59 | x1 = bboxes[..., :2] 60 | s = bboxes[..., 2:] - x1 + 1 61 | 62 | if fmt == 'c': 63 | return torch.cat((x1 + 0.5 * s, s), dim=-1) 64 | elif fmt == 't': 65 | return torch.cat((x1, s), dim=-1) 66 | 67 | raise ValueError("Undefined bounding box layout '%s'" % fmt) 68 | 69 | 70 | def masks_to_bboxes_multi(mask, ids, fmt='c'): 71 | assert mask.dim() == 2 72 | bboxes = [] 73 | 74 | for id in ids: 75 | mx = (mask == id).sum(dim=-2).nonzero() 76 | my = (mask == id).float().sum(dim=-1).nonzero() 77 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0] 78 | 79 | bb = torch.tensor(bb, dtype=torch.float32, device=mask.device) 80 | 81 | x1 = bb[:2] 82 | s = bb[2:] - x1 + 1 83 | 84 | if fmt == 'v': 85 | pass 86 | elif fmt == 'c': 87 | bb = torch.cat((x1 + 0.5 * s, s), dim=-1) 88 | elif fmt == 't': 89 | bb = torch.cat((x1, s), dim=-1) 90 | else: 91 | raise ValueError("Undefined bounding box layout '%s'" % fmt) 92 | bboxes.append(bb) 93 | 94 | return bboxes 95 | -------------------------------------------------------------------------------- /lib/train/data/image_loader.py: -------------------------------------------------------------------------------- 1 | import jpeg4py 2 | import cv2 as cv 3 | from PIL import Image 4 | import numpy as np 5 | 6 | davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8) 7 | davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 8 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 9 | [64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0], 10 | [64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128], 11 | [0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0], 12 | [0, 64, 128], [128, 64, 128]] 13 | 14 | 15 | def default_image_loader(path): 16 | """The default image loader, reads the image from the given path. 
It first tries to use the jpeg4py_loader, 17 | but reverts to the opencv_loader if the former is not available.""" 18 | if default_image_loader.use_jpeg4py is None: 19 | # Try using jpeg4py 20 | im = jpeg4py_loader(path) 21 | if im is None: 22 | default_image_loader.use_jpeg4py = False 23 | print('Using opencv_loader instead.') 24 | else: 25 | default_image_loader.use_jpeg4py = True 26 | return im 27 | if default_image_loader.use_jpeg4py: 28 | return jpeg4py_loader(path) 29 | return opencv_loader(path) 30 | 31 | default_image_loader.use_jpeg4py = None 32 | 33 | 34 | def jpeg4py_loader(path): 35 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py""" 36 | try: 37 | return jpeg4py.JPEG(path).decode() 38 | except Exception as e: 39 | print('ERROR: Could not read image "{}"'.format(path)) 40 | print(e) 41 | return None 42 | 43 | 44 | def opencv_loader(path): 45 | """ Read image using opencv's imread function and returns it in rgb format""" 46 | try: 47 | im = cv.imread(path, cv.IMREAD_COLOR) 48 | 49 | # convert to rgb and return 50 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 51 | except Exception as e: 52 | print('ERROR: Could not read image "{}"'.format(path)) 53 | print(e) 54 | return None 55 | 56 | 57 | def jpeg4py_loader_w_failsafe(path): 58 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py""" 59 | try: 60 | return jpeg4py.JPEG(path).decode() 61 | except: 62 | try: 63 | im = cv.imread(path, cv.IMREAD_COLOR) 64 | 65 | # convert to rgb and return 66 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 67 | except Exception as e: 68 | print('ERROR: Could not read image "{}"'.format(path)) 69 | print(e) 70 | return None 71 | 72 | 73 | def opencv_seg_loader(path): 74 | """ Read segmentation annotation using opencv's imread function""" 75 | try: 76 | return cv.imread(path) 77 | except Exception as e: 78 | print('ERROR: Could not read image "{}"'.format(path)) 79 | print(e) 80 | return None 81 | 82 | 83 | def imread_indexed(filename): 84 | """ Load indexed image with given filename. Used to read segmentation annotations.""" 85 | 86 | im = Image.open(filename) 87 | 88 | annotation = np.atleast_3d(im)[...,0] 89 | return annotation 90 | 91 | 92 | def imwrite_indexed(filename, array, color_palette=None): 93 | """ Save indexed image as png. Used to save segmentation annotation.""" 94 | 95 | if color_palette is None: 96 | color_palette = davis_palette 97 | 98 | if np.atleast_3d(array).shape[2] != 1: 99 | raise Exception("Saving indexed PNGs requires 2D array.") 100 | 101 | im = Image.fromarray(array) 102 | im.putpalette(color_palette.ravel()) 103 | im.save(filename, format='PNG') -------------------------------------------------------------------------------- /lib/train/data_specs/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | ## Description for different text files 4 | GOT10K 5 | - got10k_train_full_split.txt: the complete GOT-10K training set. 
(9335 videos) 6 | - got10k_train_split.txt: part of videos from the GOT-10K training set 7 | - got10k_val_split.txt: another part of videos from the GOT-10K training set 8 | - got10k_vot_exclude.txt: 1k videos that are forbidden from "using to train models then testing on VOT" (as required by [VOT Challenge](https://www.votchallenge.net/vot2020/participation.html)) 9 | - got10k_vot_train_split.txt: part of videos from the "VOT-permitted" GOT-10K training set 10 | - got10k_vot_val_split.txt: another part of videos from the "VOT-permitted" GOT-10K training set 11 | 12 | LaSOT 13 | - lasot_train_split.txt: the complete LaSOT training set 14 | 15 | TrackingNnet 16 | - trackingnet_classmap.txt: The map from the sequence name to the target class for the TrackingNet -------------------------------------------------------------------------------- /lib/train/data_specs/depthtrack_train.txt: -------------------------------------------------------------------------------- 1 | adapter02_indoor 2 | bag03_indoor 3 | bag05_indoor 4 | ball02_indoor 5 | ball03_indoor 6 | ball04_indoor 7 | ball05_indoor 8 | ball07_indoor 9 | ball08_wild 10 | ball09_wild 11 | ball12_wild 12 | ball13_indoor 13 | ball14_wild 14 | ball17_wild 15 | ball19_indoor 16 | ball21_indoor 17 | basket_indoor 18 | beautifullight01_indoor 19 | bike01_wild 20 | bike02_wild 21 | bike03_wild 22 | book01_indoor 23 | book02_indoor 24 | book04_indoor 25 | book05_indoor 26 | book06_indoor 27 | bottle01_indoor 28 | bottle02_indoor 29 | bottle05_indoor 30 | bottle06_indoor 31 | box_indoor 32 | candlecup_indoor 33 | car01_indoor 34 | car02_indoor 35 | cart_indoor 36 | cat02_indoor 37 | cat03_indoor 38 | cat04_indoor 39 | cat05_indoor 40 | chair01_indoor 41 | chair02_indoor 42 | clothes_indoor 43 | colacan01_indoor 44 | colacan02_indoor 45 | colacan04_indoor 46 | container01_indoor 47 | container02_indoor 48 | cube01_indoor 49 | cube04_indoor 50 | cube06_indoor 51 | cup03_indoor 52 | cup05_indoor 53 | cup06_indoor 54 | cup07_indoor 55 | cup08_indoor 56 | cup09_indoor 57 | cup10_indoor 58 | cup11_indoor 59 | cup13_indoor 60 | cup14_indoor 61 | duck01_wild 62 | duck02_wild 63 | duck04_wild 64 | duck05_wild 65 | duck06_wild 66 | dumbbells02_indoor 67 | earphone02_indoor 68 | egg_indoor 69 | file02_indoor 70 | flower01_indoor 71 | flower02_wild 72 | flowerbasket_indoor 73 | ghostmask_indoor 74 | glass02_indoor 75 | glass03_indoor 76 | glass04_indoor 77 | glass05_indoor 78 | guitarbag_indoor 79 | gymring_wild 80 | hand02_indoor 81 | hat01_indoor 82 | hat02_indoor_320 83 | hat03_indoor 84 | hat04_indoor 85 | human01_indoor 86 | human03_wild 87 | human04_wild 88 | human05_wild 89 | human06_indoor 90 | leaves01_wild 91 | leaves02_indoor 92 | leaves03_wild 93 | leaves04_indoor 94 | leaves05_indoor 95 | leaves06_wild 96 | lock01_wild 97 | mac_indoor 98 | milkbottle_indoor 99 | mirror_indoor 100 | mobilephone01_indoor 101 | mobilephone02_indoor 102 | mobilephone04_indoor 103 | mobilephone05_indoor 104 | mobilephone06_indoor 105 | mushroom01_indoor 106 | mushroom02_wild 107 | mushroom03_wild 108 | mushroom04_indoor 109 | mushroom05_indoor 110 | notebook02_indoor 111 | notebook03_indoor 112 | paintbottle_indoor 113 | painting_indoor_320 114 | parkingsign_wild 115 | pigeon03_wild 116 | pigeon06_wild 117 | pigeon07_wild 118 | pine01_indoor 119 | pine02_wild_320 120 | shoes01_indoor 121 | shoes03_indoor 122 | skateboard01_indoor 123 | skateboard02_indoor 124 | speaker_indoor 125 | stand_indoor 126 | suitcase_indoor 127 | swing01_wild 128 | swing02_wild 129 | 
teacup_indoor 130 | thermos01_indoor 131 | thermos02_indoor 132 | toiletpaper02_indoor 133 | toiletpaper03_indoor 134 | toiletpaper04_indoor 135 | toy01_indoor 136 | toy04_indoor 137 | toy05_indoor 138 | toy06_indoor 139 | toy07_indoor_320 140 | toy08_indoor 141 | toy10_indoor 142 | toydog_indoor 143 | trashbin_indoor 144 | tree_wild 145 | trophy_indoor 146 | ukulele02_indoor 147 | -------------------------------------------------------------------------------- /lib/train/data_specs/depthtrack_val.txt: -------------------------------------------------------------------------------- 1 | toy03_indoor 2 | pigeon05_wild 3 | bottle03_indoor 4 | ball16_indoor 5 | bag04_indoor 6 | flower03_indoor -------------------------------------------------------------------------------- /lib/train/data_specs/lasher_val.txt: -------------------------------------------------------------------------------- 1 | boywalkinginsnow3 2 | leftdrillmasterstanding 3 | leftgirlunderthelamp 4 | girlridesbike 5 | midboyplayingphone 6 | boywithumbrella 7 | manrun 8 | ab_pingpongball 9 | whitecarturnl 10 | girltakemoto 11 | rightgirlatbike 12 | easy_blackboy 13 | man_with_black_clothes2 14 | 7runone 15 | turnblkbike 16 | motobesidescar 17 | bikeafterwhitecar 18 | 2runsix 19 | rightboy_1227 20 | whitesuvcome 21 | AQrightofcomingmotos 22 | 7one 23 | blackman_0115 24 | rightmirrornotshining 25 | AQmanfromdarktrees 26 | bikeboy128 27 | orangegirl 28 | girlturnbike 29 | blackman2 30 | blackcarback 31 | rightof2cupsattached 32 | whitecar2west 33 | hatboy`shead 34 | whitebetweenblackandblue 35 | 2rdcarcome 36 | whitemancome 37 | nearmangotoD 38 | farmanrightwhitesmallhouse 39 | lightmotocoming 40 | boymototakesgirl 41 | leftblackboy 42 | righttallholdball 43 | blackcarcome 44 | twolinefirstone-gai 45 | lowerfoam2throw 46 | Awhitecargo 47 | car2north3 48 | rightfirstboy-ly 49 | girltakingplate 50 | left2ndgreenboy 51 | ab_bolster 52 | 9hatboy 53 | whitecarturn2 54 | midboyblue 55 | basketboywhite 56 | nightmototurn 57 | girlbike 58 | mantoground 59 | pickuptheyellowbook 60 | 8lastone 61 | AQbikeback 62 | girlsquattingbesidesleftbar 63 | blkbikefromnorth 64 | whitecar 65 | Amidredgirl 66 | blackbag 67 | AQblkgirlbike 68 | manwithyellowumbrella 69 | browncar2north 70 | carstop 71 | whiteboywithbag 72 | theleftestrunningboy 73 | girlafterglassdoor2 74 | rightmirrorlikesky 75 | redgirl1497 76 | midboy 77 | folderatlefthand 78 | bikecome 79 | leftfallenchair_inf_white 80 | Agirlrideback 81 | rightgirl 82 | belowrightwhiteboy 83 | moto2north1 84 | truckk 85 | highright2ndboy 86 | girl`sheadoncall 87 | whiteboy 88 | truckwhite 89 | AQgirlbiketurns 90 | left2ndboy 91 | whitegirl2right 92 | rightboywithwhite 93 | girlplayingphone 94 | girlumbrella 95 | truck 96 | manfarbesidespool 97 | dotat43 -------------------------------------------------------------------------------- /lib/train/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .lasot import Lasot 2 | from .got10k import Got10k 3 | from .tracking_net import TrackingNet 4 | from .imagenetvid import ImagenetVID 5 | from .coco import MSCOCO 6 | from .coco_seq import MSCOCOSeq 7 | from .got10k_lmdb import Got10k_lmdb 8 | from .lasot_lmdb import Lasot_lmdb 9 | from .imagenetvid_lmdb import ImagenetVID_lmdb 10 | from .coco_seq_lmdb import MSCOCOSeq_lmdb 11 | from .tracking_net_lmdb import TrackingNet_lmdb 12 | # RGBT dataloader 13 | from .lasher import LasHeR 14 | # RGBD dataloader 15 | from .depthtrack import DepthTrack 16 | # 
Event dataloader 17 | from .visevent import VisEvent -------------------------------------------------------------------------------- /lib/train/dataset/base_image_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from lib.train.data.image_loader import jpeg4py_loader 3 | 4 | 5 | class BaseImageDataset(torch.utils.data.Dataset): 6 | """ Base class for image datasets """ 7 | 8 | def __init__(self, name, root, image_loader=jpeg4py_loader): 9 | """ 10 | args: 11 | root - The root path to the dataset 12 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 13 | is used by default. 14 | """ 15 | self.name = name 16 | self.root = root 17 | self.image_loader = image_loader 18 | 19 | self.image_list = [] # Contains the list of sequences. 20 | self.class_list = [] 21 | 22 | def __len__(self): 23 | """ Returns size of the dataset 24 | returns: 25 | int - number of samples in the dataset 26 | """ 27 | return self.get_num_images() 28 | 29 | def __getitem__(self, index): 30 | """ Not to be used! Check get_frames() instead. 31 | """ 32 | return None 33 | 34 | def get_name(self): 35 | """ Name of the dataset 36 | 37 | returns: 38 | string - Name of the dataset 39 | """ 40 | raise NotImplementedError 41 | 42 | def get_num_images(self): 43 | """ Number of sequences in a dataset 44 | 45 | returns: 46 | int - number of sequences in the dataset.""" 47 | return len(self.image_list) 48 | 49 | def has_class_info(self): 50 | return False 51 | 52 | def get_class_name(self, image_id): 53 | return None 54 | 55 | def get_num_classes(self): 56 | return len(self.class_list) 57 | 58 | def get_class_list(self): 59 | return self.class_list 60 | 61 | def get_images_in_class(self, class_name): 62 | raise NotImplementedError 63 | 64 | def has_segmentation_info(self): 65 | return False 66 | 67 | def get_image_info(self, seq_id): 68 | """ Returns information about a particular image, 69 | 70 | args: 71 | seq_id - index of the image 72 | 73 | returns: 74 | Dict 75 | """ 76 | raise NotImplementedError 77 | 78 | def get_image(self, image_id, anno=None): 79 | """ Get a image 80 | 81 | args: 82 | image_id - index of image 83 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 84 | 85 | returns: 86 | image - 87 | anno - 88 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 89 | 90 | """ 91 | raise NotImplementedError 92 | 93 | -------------------------------------------------------------------------------- /lib/train/dataset/base_video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | # 2021.1.5 use jpeg4py_loader_w_failsafe as default 3 | from lib.train.data.image_loader import jpeg4py_loader_w_failsafe 4 | 5 | 6 | class BaseVideoDataset(torch.utils.data.Dataset): 7 | """ Base class for video datasets """ 8 | 9 | def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe): 10 | """ 11 | args: 12 | root - The root path to the dataset 13 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 14 | is used by default. 15 | """ 16 | self.name = name 17 | self.root = root 18 | self.image_loader = image_loader 19 | 20 | self.sequence_list = [] # Contains the list of sequences. 
21 | self.class_list = [] 22 | 23 | def __len__(self): 24 | """ Returns size of the dataset 25 | returns: 26 | int - number of samples in the dataset 27 | """ 28 | return self.get_num_sequences() 29 | 30 | def __getitem__(self, index): 31 | """ Not to be used! Check get_frames() instead. 32 | """ 33 | return None 34 | 35 | def is_video_sequence(self): 36 | """ Returns whether the dataset is a video dataset or an image dataset 37 | 38 | returns: 39 | bool - True if a video dataset 40 | """ 41 | return True 42 | 43 | def is_synthetic_video_dataset(self): 44 | """ Returns whether the dataset contains real videos or synthetic 45 | 46 | returns: 47 | bool - True if a video dataset 48 | """ 49 | return False 50 | 51 | def get_name(self): 52 | """ Name of the dataset 53 | 54 | returns: 55 | string - Name of the dataset 56 | """ 57 | raise NotImplementedError 58 | 59 | def get_num_sequences(self): 60 | """ Number of sequences in a dataset 61 | 62 | returns: 63 | int - number of sequences in the dataset.""" 64 | return len(self.sequence_list) 65 | 66 | def has_class_info(self): 67 | return False 68 | 69 | def has_occlusion_info(self): 70 | return False 71 | 72 | def get_num_classes(self): 73 | return len(self.class_list) 74 | 75 | def get_class_list(self): 76 | return self.class_list 77 | 78 | def get_sequences_in_class(self, class_name): 79 | raise NotImplementedError 80 | 81 | def has_segmentation_info(self): 82 | return False 83 | 84 | def get_sequence_info(self, seq_id): 85 | """ Returns information about a particular sequences, 86 | 87 | args: 88 | seq_id - index of the sequence 89 | 90 | returns: 91 | Dict 92 | """ 93 | raise NotImplementedError 94 | 95 | def get_frames(self, seq_id, frame_ids, anno=None): 96 | """ Get a set of frames from a particular sequence 97 | 98 | args: 99 | seq_id - index of sequence 100 | frame_ids - a list of frame numbers 101 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 102 | 103 | returns: 104 | list - List of frames corresponding to frame_ids 105 | list - List of dicts for each frame 106 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 107 | 108 | """ 109 | raise NotImplementedError 110 | 111 | -------------------------------------------------------------------------------- /lib/train/dataset/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_image_dataset import BaseImageDataset 3 | import torch 4 | import random 5 | from collections import OrderedDict 6 | from lib.train.data import jpeg4py_loader 7 | from lib.train.admin import env_settings 8 | from pycocotools.coco import COCO 9 | 10 | 11 | class MSCOCO(BaseImageDataset): 12 | """ The COCO object detection dataset. 13 | 14 | Publication: 15 | Microsoft COCO: Common Objects in Context. 16 | Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona, 17 | Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick 18 | ECCV, 2014 19 | https://arxiv.org/pdf/1405.0312.pdf 20 | 21 | Download the images along with annotations from http://cocodataset.org/#download. The root folder should be 22 | organized as follows. 23 | - coco_root 24 | - annotations 25 | - instances_train2014.json 26 | - instances_train2017.json 27 | - images 28 | - train2014 29 | - train2017 30 | 31 | Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi. 
32 | """ 33 | 34 | def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None, 35 | split="train", version="2014"): 36 | """ 37 | args: 38 | root - path to coco root folder 39 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 40 | is used by default. 41 | data_fraction - Fraction of dataset to be used. The complete dataset is used by default 42 | min_area - Objects with area less than min_area are filtered out. Default is 0.0 43 | split - 'train' or 'val'. 44 | version - version of coco dataset (2014 or 2017) 45 | """ 46 | 47 | root = env_settings().coco_dir if root is None else root 48 | super().__init__('COCO', root, image_loader) 49 | 50 | self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version)) 51 | self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version)) 52 | 53 | self.coco_set = COCO(self.anno_path) 54 | 55 | self.cats = self.coco_set.cats 56 | 57 | self.class_list = self.get_class_list() # the parent class thing would happen in the sampler 58 | 59 | self.image_list = self._get_image_list(min_area=min_area) 60 | 61 | if data_fraction is not None: 62 | self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction)) 63 | self.im_per_class = self._build_im_per_class() 64 | 65 | def _get_image_list(self, min_area=None): 66 | ann_list = list(self.coco_set.anns.keys()) 67 | image_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0] 68 | 69 | if min_area is not None: 70 | image_list = [a for a in image_list if self.coco_set.anns[a]['area'] > min_area] 71 | 72 | return image_list 73 | 74 | def get_num_classes(self): 75 | return len(self.class_list) 76 | 77 | def get_name(self): 78 | return 'coco' 79 | 80 | def has_class_info(self): 81 | return True 82 | 83 | def has_segmentation_info(self): 84 | return True 85 | 86 | def get_class_list(self): 87 | class_list = [] 88 | for cat_id in self.cats.keys(): 89 | class_list.append(self.cats[cat_id]['name']) 90 | return class_list 91 | 92 | def _build_im_per_class(self): 93 | im_per_class = {} 94 | for i, im in enumerate(self.image_list): 95 | class_name = self.cats[self.coco_set.anns[im]['category_id']]['name'] 96 | if class_name not in im_per_class: 97 | im_per_class[class_name] = [i] 98 | else: 99 | im_per_class[class_name].append(i) 100 | 101 | return im_per_class 102 | 103 | def get_images_in_class(self, class_name): 104 | return self.im_per_class[class_name] 105 | 106 | def get_image_info(self, im_id): 107 | anno = self._get_anno(im_id) 108 | 109 | bbox = torch.Tensor(anno['bbox']).view(4,) 110 | 111 | mask = torch.Tensor(self.coco_set.annToMask(anno)) 112 | 113 | valid = (bbox[2] > 0) & (bbox[3] > 0) 114 | visible = valid.clone().byte() 115 | 116 | return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible} 117 | 118 | def _get_anno(self, im_id): 119 | anno = self.coco_set.anns[self.image_list[im_id]] 120 | 121 | return anno 122 | 123 | def _get_image(self, im_id): 124 | path = self.coco_set.loadImgs([self.coco_set.anns[self.image_list[im_id]]['image_id']])[0]['file_name'] 125 | img = self.image_loader(os.path.join(self.img_pth, path)) 126 | return img 127 | 128 | def get_meta_info(self, im_id): 129 | try: 130 | cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']] 131 | object_meta = OrderedDict({'object_class_name': cat_dict_current['name'], 132 | 'motion_class': None, 133 | 'major_class': 
cat_dict_current['supercategory'], 134 | 'root_class': None, 135 | 'motion_adverb': None}) 136 | except Exception: 137 | object_meta = OrderedDict({'object_class_name': None, 138 | 'motion_class': None, 139 | 'major_class': None, 140 | 'root_class': None, 141 | 'motion_adverb': None}) 142 | return object_meta 143 | 144 | def get_class_name(self, im_id): 145 | cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']] 146 | return cat_dict_current['name'] 147 | 148 | def get_image(self, image_id, anno=None): 149 | frame = self._get_image(image_id) 150 | 151 | if anno is None: 152 | anno = self.get_image_info(image_id) 153 | 154 | object_meta = self.get_meta_info(image_id) 155 | 156 | return frame, anno, object_meta 157 | -------------------------------------------------------------------------------- /lib/train/dataset/depthtrack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch 4 | import numpy as np 5 | import pandas 6 | import csv 7 | from collections import OrderedDict 8 | from .base_video_dataset import BaseVideoDataset 9 | from lib.train.data import jpeg4py_loader_w_failsafe 10 | from lib.train.admin import env_settings 11 | from lib.train.dataset.depth_utils import get_x_frame 12 | 13 | class DepthTrack(BaseVideoDataset): 14 | """ DepthTrack dataset. 15 | """ 16 | 17 | def __init__(self, root=None, dtype='rgbcolormap', split='train', image_loader=jpeg4py_loader_w_failsafe): # vid_ids=None, split=None, data_fraction=None 18 | """ 19 | args: 20 | 21 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 22 | is used by default. 23 | vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the 24 | videos with subscripts -1, -3, and -5 from each class will be used for training. 25 | # split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of 26 | # vid_ids or split option can be used at a time. 27 | # data_fraction - Fraction of dataset to be used. The complete dataset is used by default 28 | 29 | root - path to the DepthTrack dataset.
30 | dtype - colormap or depth,, colormap + depth 31 | if colormap, it returns the colormap by cv2, 32 | if depth, it returns [depth, depth, depth] 33 | """ 34 | root = env_settings().depthtrack_dir if root is None else root 35 | super().__init__('DepthTrack', root, image_loader) 36 | 37 | self.dtype = dtype # colormap or depth 38 | self.split = split 39 | self.sequence_list = self._build_sequence_list() 40 | 41 | self.seq_per_class, self.class_list = self._build_class_list() 42 | self.class_list.sort() 43 | self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)} 44 | 45 | def _build_sequence_list(self): 46 | 47 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 48 | file_path = os.path.join(ltr_path, 'data_specs', 'depthtrack_%s.txt'%self.split) 49 | sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist() 50 | return sequence_list 51 | 52 | def _build_class_list(self): 53 | seq_per_class = {} 54 | class_list = [] 55 | for seq_id, seq_name in enumerate(self.sequence_list): 56 | class_name = seq_name.split('_')[0] 57 | 58 | if class_name not in class_list: 59 | class_list.append(class_name) 60 | 61 | if class_name in seq_per_class: 62 | seq_per_class[class_name].append(seq_id) 63 | else: 64 | seq_per_class[class_name] = [seq_id] 65 | 66 | return seq_per_class, class_list 67 | 68 | def get_name(self): 69 | return 'depthtrack' 70 | 71 | def has_class_info(self): 72 | return True 73 | 74 | def has_occlusion_info(self): 75 | return True 76 | 77 | def get_num_sequences(self): 78 | return len(self.sequence_list) 79 | 80 | def get_num_classes(self): 81 | return len(self.class_list) 82 | 83 | def get_sequences_in_class(self, class_name): 84 | return self.seq_per_class[class_name] 85 | 86 | def _read_bb_anno(self, seq_path): 87 | bb_anno_file = os.path.join(seq_path, "groundtruth.txt") 88 | gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=True, low_memory=False).values 89 | return torch.tensor(gt) 90 | 91 | def _get_sequence_path(self, seq_id): 92 | seq_name = self.sequence_list[seq_id] 93 | return os.path.join(self.root, seq_name) 94 | 95 | def get_sequence_info(self, seq_id): 96 | seq_path = self._get_sequence_path(seq_id) 97 | bbox = self._read_bb_anno(seq_path) # xywh just one kind label 98 | ''' 99 | if the box is too small, it will be ignored 100 | ''' 101 | # valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 102 | valid = (bbox[:, 2] > 10.0) & (bbox[:, 3] > 10.0) 103 | visible = valid.clone().byte() 104 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 105 | 106 | def _get_frame_path(self, seq_path, frame_id): 107 | ''' 108 | return depth image path 109 | ''' 110 | return os.path.join(seq_path, 'color', '{:08}.jpg'.format(frame_id+1)) , os.path.join(seq_path, 'depth', '{:08}.png'.format(frame_id+1)) # frames start from 1 111 | 112 | def _get_frame(self, seq_path, frame_id): 113 | ''' 114 | Return : 115 | - colormap from depth image 116 | - 3xD = [depth, depth, depth], 255 117 | - rgbcolormap 118 | - rgb3d 119 | - color 120 | - raw_depth 121 | ''' 122 | color_path, depth_path = self._get_frame_path(seq_path, frame_id) 123 | img = get_x_frame(color_path, depth_path, dtype=self.dtype, depth_clip=True) 124 | 125 | return img 126 | 127 | def _get_class(self, seq_path): 128 | # raw_class = seq_path.split('/')[-2] 129 | # return raw_class 130 | return self.split 131 | 132 | def get_class_name(self, seq_id): 133 | depth_path = self._get_sequence_path(seq_id) 134 | 
obj_class = self._get_class(depth_path) 135 | 136 | return obj_class 137 | 138 | def get_frames(self, seq_id, frame_ids, anno=None): 139 | seq_path = self._get_sequence_path(seq_id) 140 | 141 | obj_class = self._get_class(seq_path) 142 | 143 | if anno is None: 144 | anno = self.get_sequence_info(seq_id) 145 | 146 | anno_frames = {} 147 | for key, value in anno.items(): 148 | anno_frames[key] = [value[f_id, ...].clone() for ii, f_id in enumerate(frame_ids)] 149 | 150 | frame_list = [self._get_frame(seq_path, f_id) for ii, f_id in enumerate(frame_ids)] 151 | 152 | object_meta = OrderedDict({'object_class_name': obj_class, 153 | 'motion_class': None, 154 | 'major_class': None, 155 | 'root_class': None, 156 | 'motion_adverb': None}) 157 | 158 | return frame_list, anno_frames, object_meta 159 | -------------------------------------------------------------------------------- /lib/train/dataset/imagenetvid_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_video_dataset import BaseVideoDataset 3 | from lib.train.data import jpeg4py_loader 4 | import torch 5 | from collections import OrderedDict 6 | from lib.train.admin import env_settings 7 | from lib.utils.lmdb_utils import decode_img, decode_json 8 | 9 | 10 | def get_target_to_image_ratio(seq): 11 | anno = torch.Tensor(seq['anno']) 12 | img_sz = torch.Tensor(seq['image_size']) 13 | return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt() 14 | 15 | 16 | class ImagenetVID_lmdb(BaseVideoDataset): 17 | """ Imagenet VID dataset. 18 | 19 | Publication: 20 | ImageNet Large Scale Visual Recognition Challenge 21 | Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, 22 | Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei 23 | IJCV, 2015 24 | https://arxiv.org/pdf/1409.0575.pdf 25 | 26 | Download the dataset from http://image-net.org/ 27 | """ 28 | def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1): 29 | """ 30 | args: 31 | root - path to the imagenet vid dataset. 32 | image_loader (default_image_loader) - The function to read the images. If installed, 33 | jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else, 34 | opencv's imread is used. 35 | min_length - Minimum allowed sequence length. 36 | max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets 37 | which cover complete image. 
38 | """ 39 | root = env_settings().imagenet_dir if root is None else root 40 | super().__init__("imagenetvid_lmdb", root, image_loader) 41 | 42 | sequence_list_dict = decode_json(root, "cache.json") 43 | self.sequence_list = sequence_list_dict 44 | 45 | # Filter the sequences based on min_length and max_target_area in the first frame 46 | self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and 47 | get_target_to_image_ratio(x) < max_target_area] 48 | 49 | def get_name(self): 50 | return 'imagenetvid_lmdb' 51 | 52 | def get_num_sequences(self): 53 | return len(self.sequence_list) 54 | 55 | def get_sequence_info(self, seq_id): 56 | bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno']) 57 | valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0) 58 | visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte() 59 | return {'bbox': bb_anno, 'valid': valid, 'visible': visible} 60 | 61 | def _get_frame(self, sequence, frame_id): 62 | set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id']) 63 | vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id']) 64 | frame_number = frame_id + sequence['start_frame'] 65 | frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name, 66 | '{:06d}.JPEG'.format(frame_number)) 67 | return decode_img(self.root, frame_path) 68 | 69 | def get_frames(self, seq_id, frame_ids, anno=None): 70 | sequence = self.sequence_list[seq_id] 71 | 72 | frame_list = [self._get_frame(sequence, f) for f in frame_ids] 73 | 74 | if anno is None: 75 | anno = self.get_sequence_info(seq_id) 76 | 77 | # Create anno dict 78 | anno_frames = {} 79 | for key, value in anno.items(): 80 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 81 | 82 | # added the class info to the meta info 83 | object_meta = OrderedDict({'object_class': sequence['class_name'], 84 | 'motion_class': None, 85 | 'major_class': None, 86 | 'root_class': None, 87 | 'motion_adverb': None}) 88 | 89 | return frame_list, anno_frames, object_meta 90 | 91 | -------------------------------------------------------------------------------- /lib/train/dataset/lasher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import torch 5 | import csv 6 | import pandas 7 | import random 8 | from collections import OrderedDict 9 | from .base_video_dataset import BaseVideoDataset 10 | from lib.train.admin import env_settings 11 | from lib.train.dataset.depth_utils import get_x_frame 12 | 13 | 14 | class LasHeR(BaseVideoDataset): 15 | """ LasHeR dataset(aligned version). 16 | 17 | Publication: 18 | A Large-scale High-diversity Benchmark for RGBT Tracking 19 | Chenglong Li, Wanlin Xue, Yaqing Jia, Zhichen Qu, Bin Luo, Jin Tang, and Dengdi Sun 20 | https://arxiv.org/pdf/2104.13202.pdf 21 | 22 | Download dataset from https://github.com/BUGPLEASEOUT/LasHeR 23 | """ 24 | 25 | def __init__(self, root=None, split='train', dtype='rgbrgb', seq_ids=None, data_fraction=None): 26 | """ 27 | args: 28 | root - path to the LasHeR trainingset. 29 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 30 | is used by default. 31 | seq_ids - List containing the ids of the videos to be used for training. Note: Only one of 'split' or 'seq_ids' 32 | options can be used at the same time. 33 | data_fraction - Fraction of dataset to be used. 
The complete dataset is used by default 34 | """ 35 | root = env_settings().lasher_dir if root is None else root 36 | assert split in ['train', 'val','all'], 'Only support all, train or val split in LasHeR, got {}'.format(split) 37 | super().__init__('LasHeR', root) 38 | self.dtype = dtype 39 | 40 | # all folders inside the root 41 | self.sequence_list = self._get_sequence_list(split) 42 | 43 | # seq_id is the index of the folder inside the got10k root path 44 | if seq_ids is None: 45 | seq_ids = list(range(0, len(self.sequence_list))) 46 | 47 | self.sequence_list = [self.sequence_list[i] for i in seq_ids] 48 | 49 | if data_fraction is not None: 50 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction)) 51 | 52 | def get_name(self): 53 | return 'lasher' 54 | 55 | def has_class_info(self): 56 | return True 57 | 58 | def has_occlusion_info(self): 59 | return True # w=h=0 in visible.txt and infrared.txt is occlusion/oov 60 | 61 | def _get_sequence_list(self, split): 62 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 63 | file_path = os.path.join(ltr_path, 'data_specs', 'lasher_{}.txt'.format(split)) 64 | with open(file_path, 'r') as f: 65 | dir_list = f.read().splitlines() 66 | return dir_list 67 | 68 | def _read_bb_anno(self, seq_path): 69 | # in lasher dataset, visible.txt is same as infrared.txt 70 | rgb_bb_anno_file = os.path.join(seq_path, "init.txt") 71 | # ir_bb_anno_file = os.path.join(seq_path, "infrared.txt") 72 | rgb_gt = pandas.read_csv(rgb_bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values 73 | # ir_gt = pandas.read_csv(ir_bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values 74 | return torch.tensor(rgb_gt) 75 | 76 | def _get_sequence_path(self, seq_id): 77 | return os.path.join(self.root, self.sequence_list[seq_id]) 78 | 79 | def get_sequence_info(self, seq_id): 80 | """2022/8/10 ir and rgb have synchronous w=h=0 frame_index""" 81 | seq_path = self._get_sequence_path(seq_id) 82 | bbox = self._read_bb_anno(seq_path) 83 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 84 | visible = valid.clone().byte() 85 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 86 | 87 | def _get_frame_path(self, seq_path, frame_id): 88 | # Note original filename is chaotic, we rename them 89 | rgb_frame_path = os.path.join(seq_path, 'visible', '{:06d}.jpg'.format(frame_id)) # frames start from 0 90 | ir_frame_path = os.path.join(seq_path, 'infrared', '{:06d}.jpg'.format(frame_id)) 91 | return (rgb_frame_path, ir_frame_path) # jpg jpg 92 | 93 | def _get_frame(self, seq_path, frame_id): 94 | rgb_frame_path, ir_frame_path = self._get_frame_path(seq_path, frame_id) 95 | img = get_x_frame(rgb_frame_path, ir_frame_path, dtype=self.dtype) 96 | return img # (h,w,6) 97 | 98 | def get_frames(self, seq_id, frame_ids, anno=None): 99 | seq_path = self._get_sequence_path(seq_id) 100 | 101 | frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids] 102 | 103 | if anno is None: 104 | anno = self.get_sequence_info(seq_id) 105 | 106 | anno_frames = {} 107 | for key, value in anno.items(): 108 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 109 | 110 | object_meta = OrderedDict({'object_class_name': None, 111 | 'motion_class': None, 112 | 'major_class': None, 113 | 'root_class': None, 114 | 'motion_adverb': None}) 115 | 116 | return frame_list, anno_frames, object_meta 117 | 
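# --- Editor's note: illustrative usage sketch, not part of the original lasher.py. ---
# A minimal example of how the LasHeR class above is typically exercised, assuming the
# dataset lives at the placeholder path below. get_sequence_info() supplies per-frame
# validity/visibility flags, and get_frames() returns 6-channel arrays (visible RGB
# stacked with infrared) produced by get_x_frame when dtype='rgbrgb'.
if __name__ == '__main__':
    dataset = LasHeR(root='/path/to/LasHeR/TrainingSet', split='train', dtype='rgbrgb')
    seq_id = 0
    info = dataset.get_sequence_info(seq_id)                    # {'bbox', 'valid', 'visible'}
    frame_ids = info['visible'].nonzero(as_tuple=True)[0][:2].tolist()
    frames, annos, meta = dataset.get_frames(seq_id, frame_ids, info)
    print(frames[0].shape, annos['bbox'][0])                    # (H, W, 6) and an [x, y, w, h] box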
-------------------------------------------------------------------------------- /lib/train/dataset/tracking_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import os.path 4 | import numpy as np 5 | import pandas 6 | import random 7 | from collections import OrderedDict 8 | 9 | from lib.train.data import jpeg4py_loader 10 | from .base_video_dataset import BaseVideoDataset 11 | from lib.train.admin import env_settings 12 | 13 | 14 | def list_sequences(root, set_ids): 15 | """ Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name) 16 | 17 | args: 18 | root: Root directory to TrackingNet 19 | set_ids: Sets (0-11) which are to be used 20 | 21 | returns: 22 | list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence 23 | """ 24 | sequence_list = [] 25 | 26 | for s in set_ids: 27 | anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno") 28 | 29 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')] 30 | sequence_list += sequences_cur_set 31 | 32 | return sequence_list 33 | 34 | 35 | class TrackingNet(BaseVideoDataset): 36 | """ TrackingNet dataset. 37 | 38 | Publication: 39 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 40 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 41 | ECCV, 2018 42 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 43 | 44 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 45 | """ 46 | def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None): 47 | """ 48 | args: 49 | root - The path to the TrackingNet folder, containing the training sets. 50 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 51 | is used by default. 52 | set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the 53 | sets (0 - 11) will be used. 54 | data_fraction - Fraction of dataset to be used. The complete dataset is used by default 55 | """ 56 | root = env_settings().trackingnet_dir if root is None else root 57 | super().__init__('TrackingNet', root, image_loader) 58 | 59 | if set_ids is None: 60 | set_ids = [i for i in range(12)] 61 | 62 | self.set_ids = set_ids 63 | 64 | # Keep a list of all videos. 
Sequence list is a list of tuples (set_id, video_name) containing the set_id and 65 | # video_name for each sequence 66 | self.sequence_list = list_sequences(self.root, self.set_ids) 67 | 68 | if data_fraction is not None: 69 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction)) 70 | 71 | self.seq_to_class_map, self.seq_per_class = self._load_class_info() 72 | 73 | # we do not have the class_lists for the tracking net 74 | self.class_list = list(self.seq_per_class.keys()) 75 | self.class_list.sort() 76 | 77 | def _load_class_info(self): 78 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 79 | class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt') 80 | 81 | with open(class_map_path, 'r') as f: 82 | seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f} 83 | 84 | seq_per_class = {} 85 | for i, seq in enumerate(self.sequence_list): 86 | class_name = seq_to_class_map.get(seq[1], 'Unknown') 87 | if class_name not in seq_per_class: 88 | seq_per_class[class_name] = [i] 89 | else: 90 | seq_per_class[class_name].append(i) 91 | 92 | return seq_to_class_map, seq_per_class 93 | 94 | def get_name(self): 95 | return 'trackingnet' 96 | 97 | def has_class_info(self): 98 | return True 99 | 100 | def get_sequences_in_class(self, class_name): 101 | return self.seq_per_class[class_name] 102 | 103 | def _read_bb_anno(self, seq_id): 104 | set_id = self.sequence_list[seq_id][0] 105 | vid_name = self.sequence_list[seq_id][1] 106 | bb_anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt") 107 | gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, 108 | low_memory=False).values 109 | return torch.tensor(gt) 110 | 111 | def get_sequence_info(self, seq_id): 112 | bbox = self._read_bb_anno(seq_id) 113 | 114 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 115 | visible = valid.clone().byte() 116 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 117 | 118 | def _get_frame(self, seq_id, frame_id): 119 | set_id = self.sequence_list[seq_id][0] 120 | vid_name = self.sequence_list[seq_id][1] 121 | frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg") 122 | return self.image_loader(frame_path) 123 | 124 | def _get_class(self, seq_id): 125 | seq_name = self.sequence_list[seq_id][1] 126 | return self.seq_to_class_map[seq_name] 127 | 128 | def get_class_name(self, seq_id): 129 | obj_class = self._get_class(seq_id) 130 | 131 | return obj_class 132 | 133 | def get_frames(self, seq_id, frame_ids, anno=None): 134 | frame_list = [self._get_frame(seq_id, f) for f in frame_ids] 135 | 136 | if anno is None: 137 | anno = self.get_sequence_info(seq_id) 138 | 139 | anno_frames = {} 140 | for key, value in anno.items(): 141 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 142 | 143 | obj_class = self._get_class(seq_id) 144 | 145 | object_meta = OrderedDict({'object_class_name': obj_class, 146 | 'motion_class': None, 147 | 'major_class': None, 148 | 'root_class': None, 149 | 'motion_adverb': None}) 150 | 151 | return frame_list, anno_frames, object_meta 152 | -------------------------------------------------------------------------------- /lib/train/dataset/tracking_net_lmdb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import os.path 4 | import numpy as np 5 | import 
random 6 | from collections import OrderedDict 7 | 8 | from lib.train.data import jpeg4py_loader 9 | from .base_video_dataset import BaseVideoDataset 10 | from lib.train.admin import env_settings 11 | import json 12 | from lib.utils.lmdb_utils import decode_img, decode_str 13 | 14 | 15 | def list_sequences(root): 16 | """ Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name) 17 | 18 | args: 19 | root: Root directory to TrackingNet 20 | 21 | returns: 22 | list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence 23 | """ 24 | fname = os.path.join(root, "seq_list.json") 25 | with open(fname, "r") as f: 26 | sequence_list = json.loads(f.read()) 27 | return sequence_list 28 | 29 | 30 | class TrackingNet_lmdb(BaseVideoDataset): 31 | """ TrackingNet dataset. 32 | 33 | Publication: 34 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 35 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 36 | ECCV, 2018 37 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 38 | 39 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 40 | """ 41 | def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None): 42 | """ 43 | args: 44 | root - The path to the TrackingNet folder, containing the training sets. 45 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 46 | is used by default. 47 | set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the 48 | sets (0 - 11) will be used. 49 | data_fraction - Fraction of dataset to be used. The complete dataset is used by default 50 | """ 51 | root = env_settings().trackingnet_lmdb_dir if root is None else root 52 | super().__init__('TrackingNet_lmdb', root, image_loader) 53 | 54 | if set_ids is None: 55 | set_ids = [i for i in range(12)] 56 | 57 | self.set_ids = set_ids 58 | 59 | # Keep a list of all videos. 
Sequence list is a list of tuples (set_id, video_name) containing the set_id and 60 | # video_name for each sequence 61 | self.sequence_list = list_sequences(self.root) 62 | 63 | if data_fraction is not None: 64 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction)) 65 | 66 | self.seq_to_class_map, self.seq_per_class = self._load_class_info() 67 | 68 | # we do not have the class_lists for the tracking net 69 | self.class_list = list(self.seq_per_class.keys()) 70 | self.class_list.sort() 71 | 72 | def _load_class_info(self): 73 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 74 | class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt') 75 | 76 | with open(class_map_path, 'r') as f: 77 | seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f} 78 | 79 | seq_per_class = {} 80 | for i, seq in enumerate(self.sequence_list): 81 | class_name = seq_to_class_map.get(seq[1], 'Unknown') 82 | if class_name not in seq_per_class: 83 | seq_per_class[class_name] = [i] 84 | else: 85 | seq_per_class[class_name].append(i) 86 | 87 | return seq_to_class_map, seq_per_class 88 | 89 | def get_name(self): 90 | return 'trackingnet_lmdb' 91 | 92 | def has_class_info(self): 93 | return True 94 | 95 | def get_sequences_in_class(self, class_name): 96 | return self.seq_per_class[class_name] 97 | 98 | def _read_bb_anno(self, seq_id): 99 | set_id = self.sequence_list[seq_id][0] 100 | vid_name = self.sequence_list[seq_id][1] 101 | gt_str_list = decode_str(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id), 102 | os.path.join("anno", vid_name + ".txt")).split('\n')[:-1] 103 | gt_list = [list(map(float, line.split(','))) for line in gt_str_list] 104 | gt_arr = np.array(gt_list).astype(np.float32) 105 | return torch.tensor(gt_arr) 106 | 107 | def get_sequence_info(self, seq_id): 108 | bbox = self._read_bb_anno(seq_id) 109 | 110 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 111 | visible = valid.clone().byte() 112 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 113 | 114 | def _get_frame(self, seq_id, frame_id): 115 | set_id = self.sequence_list[seq_id][0] 116 | vid_name = self.sequence_list[seq_id][1] 117 | return decode_img(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id), 118 | os.path.join("frames", vid_name, str(frame_id) + ".jpg")) 119 | 120 | def _get_class(self, seq_id): 121 | seq_name = self.sequence_list[seq_id][1] 122 | return self.seq_to_class_map[seq_name] 123 | 124 | def get_class_name(self, seq_id): 125 | obj_class = self._get_class(seq_id) 126 | 127 | return obj_class 128 | 129 | def get_frames(self, seq_id, frame_ids, anno=None): 130 | frame_list = [self._get_frame(seq_id, f) for f in frame_ids] 131 | 132 | if anno is None: 133 | anno = self.get_sequence_info(seq_id) 134 | 135 | anno_frames = {} 136 | for key, value in anno.items(): 137 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 138 | 139 | obj_class = self._get_class(seq_id) 140 | 141 | object_meta = OrderedDict({'object_class_name': obj_class, 142 | 'motion_class': None, 143 | 'major_class': None, 144 | 'root_class': None, 145 | 'motion_adverb': None}) 146 | 147 | return frame_list, anno_frames, object_meta 148 | -------------------------------------------------------------------------------- /lib/train/dataset/visevent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import torch 4 | import numpy as np 5 
| import pandas 6 | import csv 7 | from glob import glob 8 | from collections import OrderedDict 9 | from .base_video_dataset import BaseVideoDataset 10 | from lib.train.data import jpeg4py_loader_w_failsafe 11 | from lib.train.admin import env_settings 12 | from lib.train.dataset.depth_utils import get_x_frame 13 | 14 | 15 | class VisEvent(BaseVideoDataset): 16 | """ VisEvent dataset. 17 | """ 18 | 19 | def __init__(self, root=None, dtype='rgbrgb', split='train', image_loader=jpeg4py_loader_w_failsafe): # vid_ids=None, split=None, data_fraction=None 20 | """ 21 | args: 22 | 23 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 24 | is used by default. 25 | vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the 26 | videos with subscripts -1, -3, and -5 from each class will be used for training. 27 | # split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of 28 | # vid_ids or split option can be used at a time. 29 | # data_fraction - Fraction of dataset to be used. The complete dataset is used by default 30 | 31 | root - path to the lasot depth dataset. 32 | dtype - colormap or depth,, colormap + depth 33 | if colormap, it returns the colormap by cv2, 34 | if depth, it returns [depth, depth, depth] 35 | """ 36 | root = env_settings().visevent_dir if root is None else root 37 | assert split in ['train'], 'Only support train split in VisEvent, got {}'.format(split) 38 | super().__init__('VisEvent', root, image_loader) 39 | 40 | self.dtype = dtype # colormap or depth 41 | self.split = split 42 | self.sequence_list = self._build_sequence_list() 43 | 44 | 45 | def _build_sequence_list(self): 46 | 47 | file_path = os.path.join(self.root, '{}list.txt'.format(self.split)) 48 | sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist() 49 | return sequence_list 50 | 51 | def get_name(self): 52 | return 'visevent' 53 | 54 | def has_class_info(self): 55 | return False 56 | 57 | def has_occlusion_info(self): 58 | return True 59 | 60 | def get_num_sequences(self): 61 | return len(self.sequence_list) 62 | 63 | def _read_bb_anno(self, seq_path): 64 | bb_anno_file = os.path.join(seq_path, "groundtruth.txt") 65 | gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=True, low_memory=False).values 66 | return torch.tensor(gt) 67 | 68 | def _read_target_visible(self, seq_path): 69 | # Read full occlusion and out_of_view 70 | occlusion_file = os.path.join(seq_path, "absent_label.txt") 71 | 72 | with open(occlusion_file, 'r', newline='') as f: 73 | occlusion = torch.ByteTensor([int(v[0]) for v in list(csv.reader(f))]) 74 | 75 | target_visible = occlusion 76 | 77 | return target_visible 78 | 79 | def _get_sequence_path(self, seq_id): 80 | seq_name = self.sequence_list[seq_id] 81 | return os.path.join(self.root, seq_name) 82 | 83 | def get_sequence_info(self, seq_id): 84 | seq_path = self._get_sequence_path(seq_id) 85 | bbox = self._read_bb_anno(seq_path) # xywh just one kind label 86 | ''' 87 | if the box is too small, it will be ignored 88 | ''' 89 | # valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 90 | valid = (bbox[:, 2] > 5.0) & (bbox[:, 3] > 5.0) 91 | visible = self._read_target_visible(seq_path) & valid.byte() 92 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 93 | 94 | def _get_frame_path(self, seq_path, frame_id): 95 | ''' 96 | return rgb event image path 97 | ''' 98 | 
vis_img_files = sorted(glob(os.path.join(seq_path, 'vis_imgs', '*.bmp'))) 99 | 100 | try: 101 | vis_path = vis_img_files[frame_id] 102 | except: 103 | print(f"seq_path: {seq_path}") 104 | print(f"vis_img_files: {vis_img_files}") 105 | print(f"frame_id: {frame_id}") 106 | 107 | event_path = vis_path.replace('vis_imgs', 'event_imgs') 108 | 109 | 110 | return vis_path, event_path # frames start irregularly 111 | 112 | def _get_frame(self, seq_path, frame_id): 113 | ''' 114 | Return : 115 | - rgb+event_colormap 116 | ''' 117 | color_path, event_path = self._get_frame_path(seq_path, frame_id) 118 | img = get_x_frame(color_path, event_path, dtype=self.dtype, depth_clip=False) 119 | return img # (h,w,6) 120 | 121 | def get_frames(self, seq_id, frame_ids, anno=None): 122 | seq_path = self._get_sequence_path(seq_id) 123 | 124 | if anno is None: 125 | anno = self.get_sequence_info(seq_id) 126 | 127 | anno_frames = {} 128 | for key, value in anno.items(): 129 | anno_frames[key] = [value[f_id, ...].clone() for ii, f_id in enumerate(frame_ids)] 130 | 131 | frame_list = [self._get_frame(seq_path, f_id) for ii, f_id in enumerate(frame_ids)] 132 | 133 | object_meta = OrderedDict({'object_class_name': None, 134 | 'motion_class': None, 135 | 'major_class': None, 136 | 'root_class': None, 137 | 'motion_adverb': None}) 138 | 139 | return frame_list, anno_frames, object_meta 140 | -------------------------------------------------------------------------------- /lib/train/run_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import importlib 5 | import cv2 as cv 6 | import torch.backends.cudnn 7 | import torch.distributed as dist 8 | 9 | import random 10 | import numpy as np 11 | torch.backends.cudnn.benchmark = False 12 | 13 | import _init_paths 14 | import lib.train.admin.settings as ws_settings 15 | 16 | 17 | def init_seeds(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | 25 | 26 | def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None, 27 | use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False, 28 | distill=None, script_teacher=None, config_teacher=None): 29 | """Run the train script. 30 | args: 31 | script_name: Name of the experiment in the "experiments/" folder. 32 | config_name: Name of the yaml config file in the "experiments/" folder. 33 | cudnn_benchmark: Use cudnn benchmark or not (default is True). 34 | """ 35 | if save_dir is None: 36 | print("save_dir is not given.
Use the default dir instead.") 37 | # This is needed to avoid strange crashes related to opencv 38 | cv.setNumThreads(0) 39 | 40 | torch.backends.cudnn.benchmark = cudnn_benchmark 41 | 42 | print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name)) 43 | 44 | '''2021.1.5 set seed for different process''' 45 | if base_seed is not None: 46 | if local_rank != -1: 47 | init_seeds(base_seed + local_rank) 48 | else: 49 | init_seeds(base_seed) 50 | 51 | settings = ws_settings.Settings() 52 | settings.script_name = script_name 53 | settings.config_name = config_name 54 | settings.project_path = 'train/{}/{}'.format(script_name, config_name) 55 | if script_name_prv is not None and config_name_prv is not None: 56 | settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv) 57 | settings.local_rank = local_rank 58 | settings.save_dir = os.path.abspath(save_dir) 59 | settings.use_lmdb = use_lmdb 60 | prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) 61 | settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name)) 62 | settings.use_wandb = use_wandb 63 | if distill: 64 | settings.distill = distill 65 | settings.script_teacher = script_teacher 66 | settings.config_teacher = config_teacher 67 | if script_teacher is not None and config_teacher is not None: 68 | settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher) 69 | settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher)) 70 | expr_module = importlib.import_module('lib.train.train_script_distill') 71 | else: 72 | expr_module = importlib.import_module('lib.train.train_script') 73 | expr_func = getattr(expr_module, 'run') 74 | 75 | expr_func(settings) 76 | 77 | 78 | def main(): 79 | parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.') 80 | parser.add_argument('--script', type=str, required=True, help='Name of the train script.') 81 | parser.add_argument('--config', type=str, required=True, help="Name of the config file.") 82 | parser.add_argument('--cudnn_benchmark', type=bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).') 83 | parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') 84 | parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs') # ./output 85 | parser.add_argument('--seed', type=int, default=0, help='seed for random numbers') 86 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format 87 | parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.') 88 | parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.") 89 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb 90 | # for knowledge distillation 91 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation 92 | parser.add_argument('--script_teacher', type=str, help='teacher script name') 93 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name') 94 | 95 | args = parser.parse_args() 96 | if args.local_rank != -1: 97 | dist.init_process_group(backend='nccl') 98 | torch.cuda.set_device(args.local_rank) 99 | else: 100 | torch.cuda.set_device(0) 101 | 
run_training(args.script, args.config, cudnn_benchmark=args.cudnn_benchmark, 102 | local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed, 103 | use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv, 104 | use_wandb=args.use_wandb, 105 | distill=args.distill, script_teacher=args.script_teacher, config_teacher=args.config_teacher) 106 | 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /lib/train/train_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | # loss function related 3 | from lib.utils.box_ops import giou_loss 4 | from torch.nn.functional import l1_loss 5 | from torch.nn import BCEWithLogitsLoss 6 | # train pipeline related 7 | from lib.train.trainers import LTRTrainer 8 | # distributed training related 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | # some more advanced functions 11 | from .base_functions import * 12 | # network related 13 | from lib.models.bat import build_ostrack 14 | from lib.models.bat import build_batrack 15 | # forward propagation related 16 | from lib.train.actors import BATActor 17 | # for import modules 18 | import importlib 19 | 20 | from ..utils.focal_loss import FocalLoss 21 | 22 | 23 | def run(settings): 24 | settings.description = 'Training script for bat' 25 | 26 | # update the default configs with config file 27 | if not os.path.exists(settings.cfg_file): 28 | raise ValueError("%s doesn't exist." % settings.cfg_file) 29 | config_module = importlib.import_module("lib.config.%s.config" % settings.script_name) 30 | cfg = config_module.cfg 31 | config_module.update_config_from_file(settings.cfg_file) 32 | if settings.local_rank in [-1, 0]: 33 | print("New configuration is shown below.") 34 | for key in cfg.keys(): 35 | print("%s configuration:" % key, cfg[key]) 36 | print('\n') 37 | 38 | # update settings based on cfg 39 | update_settings(settings, cfg) 40 | 41 | # Record the training log 42 | log_dir = os.path.join(settings.save_dir, 'logs') 43 | if settings.local_rank in [-1, 0]: 44 | if not os.path.exists(log_dir): 45 | os.makedirs(log_dir) 46 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name)) 47 | 48 | # Build dataloaders 49 | loader_train, loader_val = build_dataloaders(cfg, settings) 50 | 51 | # Create network 52 | if settings.script_name == "bat": 53 | net = build_batrack(cfg) 54 | else: 55 | raise ValueError("illegal script name") 56 | 57 | # wrap networks to distributed one 58 | net.cuda() 59 | if settings.local_rank != -1: 60 | # net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter 61 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True) 62 | settings.device = torch.device("cuda:%d" % settings.local_rank) 63 | else: 64 | settings.device = torch.device("cuda:0") 65 | settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False) 66 | settings.distill = getattr(cfg.TRAIN, "DISTILL", False) 67 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL") 68 | # Loss functions and Actors 69 | if settings.script_name == "bat": 70 | # here cls loss and cls weight are not use 71 | focal_loss = FocalLoss() 72 | objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()} 73 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0} 74 | 
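# The giou and l1 terms supervise the predicted boxes and the focal term supervises the center score map; as noted above, the 'cls' entry is kept only for interface compatibility and is not used for BAT.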
actor = BATActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg) 75 | else: 76 | raise ValueError("illegal script name") 77 | 78 | # Optimizer, parameters, and learning rates 79 | optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg) 80 | use_amp = getattr(cfg.TRAIN, "AMP", False) 81 | settings.save_epoch_interval = getattr(cfg.TRAIN, "SAVE_EPOCH_INTERVAL", 1) 82 | settings.save_last_n_epoch = getattr(cfg.TRAIN, "SAVE_LAST_N_EPOCH", 1) 83 | 84 | if loader_val is None: 85 | trainer = LTRTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp) 86 | else: 87 | trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp) 88 | 89 | # train process 90 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True) 91 | -------------------------------------------------------------------------------- /lib/train/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .ltr_trainer import LTRTrainer 3 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensor import TensorDict, TensorList 2 | -------------------------------------------------------------------------------- /lib/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_area 3 | import numpy as np 4 | 5 | 6 | def box_cxcywh_to_xyxy(x): 7 | x_c, y_c, w, h = x.unbind(-1) 8 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 9 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 10 | return torch.stack(b, dim=-1) 11 | 12 | 13 | def box_xywh_to_xyxy(x): 14 | x1, y1, w, h = x.unbind(-1) 15 | b = [x1, y1, x1 + w, y1 + h] 16 | return torch.stack(b, dim=-1) 17 | 18 | 19 | def box_xyxy_to_xywh(x): 20 | x1, y1, x2, y2 = x.unbind(-1) 21 | b = [x1, y1, x2 - x1, y2 - y1] 22 | return torch.stack(b, dim=-1) 23 | 24 | 25 | def box_xyxy_to_cxcywh(x): 26 | x0, y0, x1, y1 = x.unbind(-1) 27 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 28 | (x1 - x0), (y1 - y0)] 29 | return torch.stack(b, dim=-1) 30 | 31 | 32 | # modified from torchvision to also return the union 33 | '''Note that this function only supports shape (N,4)''' 34 | 35 | 36 | def box_iou(boxes1, boxes2): 37 | """ 38 | 39 | :param boxes1: (N, 4) (x1,y1,x2,y2) 40 | :param boxes2: (N, 4) (x1,y1,x2,y2) 41 | :return: 42 | """ 43 | area1 = box_area(boxes1) # (N,) 44 | area2 = box_area(boxes2) # (N,) 45 | 46 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2) 47 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2) 48 | 49 | wh = (rb - lt).clamp(min=0) # (N,2) 50 | inter = wh[:, 0] * wh[:, 1] # (N,) 51 | 52 | union = area1 + area2 - inter 53 | 54 | iou = inter / union 55 | return iou, union 56 | 57 | 58 | '''Note that this implementation is different from DETR's''' 59 | 60 | 61 | def generalized_box_iou(boxes1, boxes2): 62 | """ 63 | Generalized IoU from https://giou.stanford.edu/ 64 | 65 | The boxes should be in [x0, y0, x1, y1] format 66 | 67 | boxes1: (N, 4) 68 | boxes2: (N, 4) 69 | """ 70 | # degenerate boxes gives inf / nan results 71 | # so do an early check 72 | # try: 73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 75 | iou, union = box_iou(boxes1, boxes2) # (N,) 76 | 77 | lt = torch.min(boxes1[:, :2], boxes2[:, :2]) 78 
| rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) 79 | 80 | wh = (rb - lt).clamp(min=0) # (N,2) 81 | area = wh[:, 0] * wh[:, 1] # (N,) 82 | 83 | return iou - (area - union) / area, iou 84 | 85 | 86 | def giou_loss(boxes1, boxes2): 87 | """ 88 | 89 | :param boxes1: (N, 4) (x1,y1,x2,y2) 90 | :param boxes2: (N, 4) (x1,y1,x2,y2) 91 | :return: 92 | """ 93 | giou, iou = generalized_box_iou(boxes1, boxes2) 94 | return (1 - giou).mean(), iou 95 | 96 | 97 | def clip_box(box: list, H, W, margin=0): 98 | x1, y1, w, h = box 99 | x2, y2 = x1 + w, y1 + h 100 | x1 = min(max(0, x1), W-margin) 101 | x2 = min(max(margin, x2), W) 102 | y1 = min(max(0, y1), H-margin) 103 | y2 = min(max(margin, y2), H) 104 | w = max(margin, x2-x1) 105 | h = max(margin, y2-y1) 106 | return [x1, y1, w, h] 107 | -------------------------------------------------------------------------------- /lib/utils/ce_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def generate_bbox_mask(bbox_mask, bbox): 8 | b, h, w = bbox_mask.shape 9 | for i in range(b): 10 | bbox_i = bbox[i].cpu().tolist() 11 | bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1 12 | return bbox_mask 13 | 14 | 15 | def generate_mask_cond(cfg, bs, device, gt_bbox): 16 | template_size = cfg.DATA.TEMPLATE.SIZE 17 | stride = cfg.MODEL.BACKBONE.STRIDE 18 | template_feat_size = template_size // stride 19 | 20 | if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL': 21 | box_mask_z = None 22 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': 23 | if template_feat_size == 8: 24 | index = slice(3, 4) 25 | elif template_feat_size == 12: 26 | index = slice(5, 6) 27 | elif template_feat_size == 7: 28 | index = slice(3, 4) 29 | elif template_feat_size == 14: 30 | index = slice(6, 7) 31 | else: 32 | raise NotImplementedError 33 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 34 | box_mask_z[:, index, index] = 1 35 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 36 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC': 37 | # use fixed 4x4 region, 3:5 for 8x8 38 | # use fixed 4x4 region 5:6 for 12x12 39 | if template_feat_size == 8: 40 | index = slice(3, 5) 41 | elif template_feat_size == 12: 42 | index = slice(5, 7) 43 | elif template_feat_size == 7: 44 | index = slice(3, 4) 45 | else: 46 | raise NotImplementedError 47 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 48 | box_mask_z[:, index, index] = 1 49 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 50 | 51 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX': 52 | box_mask_z = torch.zeros([bs, template_size, template_size], device=device) 53 | # box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128) 54 | box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to( 55 | torch.float) # (batch, 1, 128, 128) 56 | # box_mask_z_vis = box_mask_z.cpu().numpy() 57 | box_mask_z = F.interpolate(box_mask_z, scale_factor=1. 
/ cfg.MODEL.BACKBONE.STRIDE, mode='bilinear', 58 | align_corners=False) 59 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 60 | # box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy() 61 | # gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy() 62 | else: 63 | raise NotImplementedError 64 | 65 | return box_mask_z 66 | 67 | 68 | def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1): 69 | if epoch < warmup_epochs: 70 | return 1 71 | if epoch >= total_epochs: 72 | return base_keep_rate 73 | if iters == -1: 74 | iters = epoch * ITERS_PER_EPOCH 75 | total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs) 76 | iters = iters - ITERS_PER_EPOCH * warmup_epochs 77 | keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \ 78 | * (math.cos(iters / total_iters * math.pi) + 1) * 0.5 79 | 80 | return keep_rate 81 | -------------------------------------------------------------------------------- /lib/utils/focal_loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FocalLoss(nn.Module, ABC): 9 | def __init__(self, alpha=2, beta=4): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.beta = beta 13 | 14 | def forward(self, prediction, target): 15 | positive_index = target.eq(1).float() 16 | negative_index = target.lt(1).float() 17 | 18 | negative_weights = torch.pow(1 - target, self.beta) 19 | # the minimum is clamped to 1e-12 to maintain numerical stability 20 | prediction = torch.clamp(prediction, 1e-12) 21 | 22 | positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index 23 | negative_loss = torch.log(1 - prediction) * torch.pow(prediction, 24 | self.alpha) * negative_weights * negative_index 25 | 26 | num_positive = positive_index.float().sum() 27 | positive_loss = positive_loss.sum() 28 | negative_loss = negative_loss.sum() 29 | 30 | if num_positive == 0: 31 | loss = -negative_loss 32 | else: 33 | loss = -(positive_loss + negative_loss) / num_positive 34 | 35 | return loss 36 | 37 | 38 | class LBHinge(nn.Module): 39 | """Loss that uses a 'hinge' on the lower bound. 40 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is 41 | also smaller than that threshold. 42 | args: 43 | error_metric: What base loss to use (MSE by default). 44 | threshold: Threshold to use for the hinge. 45 | clip: Clip the loss if it is above this value. 
46 | """ 47 | def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None): 48 | super().__init__() 49 | self.error_metric = error_metric 50 | self.threshold = threshold if threshold is not None else -100 51 | self.clip = clip 52 | 53 | def forward(self, prediction, label, target_bb=None): 54 | negative_mask = (label < self.threshold).float() 55 | positive_mask = (1.0 - negative_mask) 56 | 57 | prediction = negative_mask * F.relu(prediction) + positive_mask * prediction 58 | 59 | loss = self.error_metric(prediction, positive_mask * label) 60 | 61 | if self.clip is not None: 62 | loss = torch.min(loss, torch.tensor([self.clip], device=loss.device)) 63 | return loss -------------------------------------------------------------------------------- /lib/utils/heapmap_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def generate_heatmap(bboxes, patch_size=320, stride=16): 6 | """ 7 | Generate the ground-truth heatmap in the same way as CenterNet 8 | Args: 9 | bboxes (torch.Tensor): shape of [num_search, bs, 4] 10 | 11 | Returns: 12 | gaussian_maps: list of generated heatmaps 13 | 14 | """ 15 | gaussian_maps = [] 16 | heatmap_size = patch_size // stride 17 | for single_patch_bboxes in bboxes: 18 | bs = single_patch_bboxes.shape[0] 19 | gt_scoremap = torch.zeros(bs, heatmap_size, heatmap_size) 20 | classes = torch.arange(bs).to(torch.long) 21 | bbox = single_patch_bboxes * heatmap_size 22 | wh = bbox[:, 2:] 23 | centers_int = (bbox[:, :2] + wh / 2).round() 24 | CenterNetHeatMap.generate_score_map(gt_scoremap, classes, wh, centers_int, 0.7) 25 | gaussian_maps.append(gt_scoremap.to(bbox.device)) 26 | return gaussian_maps 27 | 28 | 29 | class CenterNetHeatMap(object): 30 | @staticmethod 31 | def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap): 32 | radius = CenterNetHeatMap.get_gaussian_radius(gt_wh, min_overlap) 33 | radius = torch.clamp_min(radius, 0) 34 | radius = radius.type(torch.int).cpu().numpy() 35 | for i in range(gt_class.shape[0]): 36 | channel_index = gt_class[i] 37 | CenterNetHeatMap.draw_gaussian(fmap[channel_index], centers_int[i], radius[i]) 38 | 39 | @staticmethod 40 | def get_gaussian_radius(box_size, min_overlap): 41 | """ 42 | copied from CornerNet 43 | box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple 44 | note: this keeps CornerNet's original (buggy) radius formula; refer to CornerNet for the fixed version 45 | """ 46 | # box_tensor = torch.Tensor(box_size) 47 | box_tensor = box_size 48 | width, height = box_tensor[..., 0], box_tensor[..., 1] 49 | 50 | a1 = 1 51 | b1 = height + width 52 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 53 | sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1) 54 | r1 = (b1 + sq1) / 2 55 | 56 | a2 = 4 57 | b2 = 2 * (height + width) 58 | c2 = (1 - min_overlap) * width * height 59 | sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2) 60 | r2 = (b2 + sq2) / 2 61 | 62 | a3 = 4 * min_overlap 63 | b3 = -2 * min_overlap * (height + width) 64 | c3 = (min_overlap - 1) * width * height 65 | sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3) 66 | r3 = (b3 + sq3) / 2 67 | 68 | return torch.min(r1, torch.min(r2, r3)) 69 | 70 | @staticmethod 71 | def gaussian2D(radius, sigma=1): 72 | # m, n = [(s - 1.) / 2. 
for s in shape] 73 | m, n = radius 74 | y, x = np.ogrid[-m: m + 1, -n: n + 1] 75 | 76 | gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 77 | gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0 78 | return gauss 79 | 80 | @staticmethod 81 | def draw_gaussian(fmap, center, radius, k=1): 82 | diameter = 2 * radius + 1 83 | gaussian = CenterNetHeatMap.gaussian2D((radius, radius), sigma=diameter / 6) 84 | gaussian = torch.Tensor(gaussian) 85 | x, y = int(center[0]), int(center[1]) 86 | height, width = fmap.shape[:2] 87 | 88 | left, right = min(x, radius), min(width - x, radius + 1) 89 | top, bottom = min(y, radius), min(height - y, radius + 1) 90 | 91 | masked_fmap = fmap[y - top: y + bottom, x - left: x + right] 92 | masked_gaussian = gaussian[radius - top: radius + bottom, radius - left: radius + right] 93 | if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0: 94 | masked_fmap = torch.max(masked_fmap, masked_gaussian * k) 95 | fmap[y - top: y + bottom, x - left: x + right] = masked_fmap 96 | # return fmap 97 | 98 | 99 | def compute_grids(features, strides): 100 | """ 101 | grid locations are computed with respect to the input image size 102 | """ 103 | grids = [] 104 | for level, feature in enumerate(features): 105 | h, w = feature.size()[-2:] 106 | shifts_x = torch.arange( 107 | 0, w * strides[level], 108 | step=strides[level], 109 | dtype=torch.float32, device=feature.device) 110 | shifts_y = torch.arange( 111 | 0, h * strides[level], 112 | step=strides[level], 113 | dtype=torch.float32, device=feature.device) 114 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 115 | shift_x = shift_x.reshape(-1) 116 | shift_y = shift_y.reshape(-1) 117 | grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \ 118 | strides[level] // 2 119 | grids.append(grids_per_level) 120 | return grids 121 | 122 | 123 | def get_center3x3(locations, centers, strides, range=3): 124 | ''' 125 | Inputs: 126 | locations: M x 2 127 | centers: N x 2 128 | strides: M 129 | ''' 130 | range = (range - 1) / 2 131 | M, N = locations.shape[0], centers.shape[0] 132 | locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2 133 | centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 134 | strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N x 2 135 | centers_discret = ((centers_expanded / strides_expanded).int() * strides_expanded).float() + \ 136 | strides_expanded / 2 # M x N x 2 137 | dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs() 138 | dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs() 139 | return (dist_x <= strides_expanded[:, :, 0] * range) & \ 140 | (dist_y <= strides_expanded[:, :, 0] * range) 141 | 142 | 143 | def get_pred(score_map_ctr, size_map, offset_map, feat_size): 144 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) 145 | 146 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) 147 | size = size_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 148 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 149 | 150 | return size * feat_size, offset 151 | -------------------------------------------------------------------------------- /lib/utils/lmdb_utils.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import numpy as np 3 | import cv2 4 | import json 5 | 6 | LMDB_ENVS = dict() 7 | LMDB_HANDLES = dict() 8 | LMDB_FILELISTS = dict() 9 | 10 | 11 | def get_lmdb_handle(name): 12 | global LMDB_HANDLES, LMDB_FILELISTS 13 | item = 
LMDB_HANDLES.get(name, None) 14 | if item is None: 15 | env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False) 16 | LMDB_ENVS[name] = env 17 | item = env.begin(write=False) 18 | LMDB_HANDLES[name] = item 19 | 20 | return item 21 | 22 | 23 | def decode_img(lmdb_fname, key_name): 24 | handle = get_lmdb_handle(lmdb_fname) 25 | binfile = handle.get(key_name.encode()) 26 | if binfile is None: 27 | print("Illegal data detected. %s %s" % (lmdb_fname, key_name)) 28 | s = np.frombuffer(binfile, np.uint8) 29 | x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 30 | return x 31 | 32 | 33 | def decode_str(lmdb_fname, key_name): 34 | handle = get_lmdb_handle(lmdb_fname) 35 | binfile = handle.get(key_name.encode()) 36 | string = binfile.decode() 37 | return string 38 | 39 | 40 | def decode_json(lmdb_fname, key_name): 41 | return json.loads(decode_str(lmdb_fname, key_name)) 42 | 43 | 44 | if __name__ == "__main__": 45 | lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb" 46 | '''Decode image''' 47 | # key_name = "test/GOT-10k_Test_000001/00000001.jpg" 48 | # img = decode_img(lmdb_fname, key_name) 49 | # cv2.imwrite("001.jpg", img) 50 | '''Decode str''' 51 | # key_name = "test/list.txt" 52 | # key_name = "train/GOT-10k_Train_000001/groundtruth.txt" 53 | key_name = "train/GOT-10k_Train_000001/absence.label" 54 | str_ = decode_str(lmdb_fname, key_name) 55 | print(str_) 56 | -------------------------------------------------------------------------------- /lib/utils/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def merge_template_search(inp_list, return_search=False, return_template=False): 5 | """NOTICE: search region related features must be in the last place""" 6 | seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 7 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 8 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} 9 | if return_search: 10 | x = inp_list[-1] 11 | seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]}) 12 | if return_template: 13 | z = inp_list[0] 14 | seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]}) 15 | return seq_dict 16 | 17 | 18 | def get_qkv(inp_list): 19 | """The 1st element of the inp_list is about the template, 20 | the 2nd (the last) element is about the search region""" 21 | dict_x = inp_list[-1] 22 | dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 23 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 24 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict 25 | q = dict_x["feat"] + dict_x["pos"] 26 | k = dict_c["feat"] + dict_c["pos"] 27 | v = dict_c["feat"] 28 | key_padding_mask = dict_c["mask"] 29 | return q, k, v, key_padding_mask 30 | -------------------------------------------------------------------------------- /lib/vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkTempest/BAT/ccf9b2f6ae3e810f4e7318c9d0b62083deb7ec89/lib/vis/__init__.py -------------------------------------------------------------------------------- /lib/vis/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import cv2 5 | 6 | 7 | def draw_figure(fig): 8 | fig.canvas.draw() 9 | fig.canvas.flush_events() 10 | plt.pause(0.001) 
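# Typical call pattern: create a figure, draw into it (e.g. with plt.imshow), then call draw_figure(fig) each iteration to refresh the window without blocking the caller.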
11 | 12 | 13 | def show_tensor(a: torch.Tensor, fig_num = None, title = None, range=(None, None), ax=None): 14 | """Display a 2D tensor. 15 | args: 16 | fig_num: Figure number. 17 | title: Title of figure. 18 | """ 19 | a_np = a.squeeze().cpu().clone().detach().numpy() 20 | if a_np.ndim == 3: 21 | a_np = np.transpose(a_np, (1, 2, 0)) 22 | 23 | if ax is None: 24 | fig = plt.figure(fig_num) 25 | plt.tight_layout() 26 | plt.cla() 27 | plt.imshow(a_np, vmin=range[0], vmax=range[1]) 28 | plt.axis('off') 29 | plt.axis('equal') 30 | if title is not None: 31 | plt.title(title) 32 | draw_figure(fig) 33 | else: 34 | ax.cla() 35 | ax.imshow(a_np, vmin=range[0], vmax=range[1]) 36 | ax.set_axis_off() 37 | ax.axis('equal') 38 | if title is not None: 39 | ax.set_title(title) 40 | draw_figure(plt.gcf()) 41 | 42 | 43 | def plot_graph(a: torch.Tensor, fig_num = None, title = None): 44 | """Plot graph. Data is a 1D tensor. 45 | args: 46 | fig_num: Figure number. 47 | title: Title of figure. 48 | """ 49 | a_np = a.squeeze().cpu().clone().detach().numpy() 50 | if a_np.ndim > 1: 51 | raise ValueError 52 | fig = plt.figure(fig_num) 53 | # plt.tight_layout() 54 | plt.cla() 55 | plt.plot(a_np) 56 | if title is not None: 57 | plt.title(title) 58 | draw_figure(fig) 59 | 60 | 61 | def show_image_with_boxes(im, boxes, iou_pred=None, disp_ids=None): 62 | im_np = im.clone().cpu().squeeze().numpy() 63 | im_np = np.ascontiguousarray(im_np.transpose(1, 2, 0).astype(np.uint8)) 64 | 65 | boxes = boxes.view(-1, 4).cpu().numpy().round().astype(int) 66 | 67 | # Draw proposals 68 | for i_ in range(boxes.shape[0]): 69 | if disp_ids is None or disp_ids[i_]: 70 | bb = boxes[i_, :] 71 | disp_color = (i_*38 % 256, (255 - i_*97) % 256, (123 + i_*66) % 256) 72 | cv2.rectangle(im_np, (bb[0], bb[1]), (bb[0] + bb[2], bb[1] + bb[3]), 73 | disp_color, 1) 74 | 75 | if iou_pred is not None: 76 | text_pos = (bb[0], bb[1] - 5) 77 | cv2.putText(im_np, 'ID={} IOU = {:3.2f}'.format(i_, iou_pred[i_]), text_pos, 78 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, bottomLeftOrigin=False) 79 | 80 | im_tensor = torch.from_numpy(im_np.transpose(2, 0, 1)).float() 81 | 82 | return im_tensor 83 | 84 | 85 | 86 | def _pascal_color_map(N=256, normalized=False): 87 | """ 88 | Python implementation of the color map function for the PASCAL VOC data set. 89 | Official Matlab version can be found in the PASCAL VOC devkit 90 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 91 | """ 92 | 93 | def bitget(byteval, idx): 94 | return (byteval & (1 << idx)) != 0 95 | 96 | dtype = 'float32' if normalized else 'uint8' 97 | cmap = np.zeros((N, 3), dtype=dtype) 98 | for i in range(N): 99 | r = g = b = 0 100 | c = i 101 | for j in range(8): 102 | r = r | (bitget(c, 0) << 7 - j) 103 | g = g | (bitget(c, 1) << 7 - j) 104 | b = b | (bitget(c, 2) << 7 - j) 105 | c = c >> 3 106 | 107 | cmap[i] = np.array([r, g, b]) 108 | 109 | cmap = cmap / 255 if normalized else cmap 110 | return cmap 111 | 112 | 113 | def overlay_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 114 | """ Overlay mask over image. 115 | Source: https://github.com/albertomontesg/davis-interactive/blob/master/davisinteractive/utils/visualization.py 116 | This function allows you to overlay a mask over an image with some 117 | transparency. 118 | # Arguments 119 | im: Numpy Array. Array with the image. The shape must be (H, W, 3) and 120 | the pixels must be represented as `np.uint8` data type. 121 | ann: Numpy Array. Array with the mask. 
The shape must be (H, W) and the 122 | values must be integers 123 | alpha: Float. Proportion of alpha to apply at the overlaid mask. 124 | colors: Numpy Array. Optional custom colormap. It must have shape (N, 3) 125 | where N is the maximum number of colors to represent. 126 | contour_thickness: Integer. Thickness of each object index contour drawn 127 | over the overlay. This option requires the 128 | package `opencv-python` to be installed. 129 | # Returns 130 | Numpy Array: Image of the overlay with shape (H, W, 3) and data type 131 | `np.uint8`. 132 | """ 133 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int) 134 | if im.shape[:-1] != ann.shape: 135 | raise ValueError('First two dimensions of `im` and `ann` must match') 136 | if im.shape[-1] != 3: 137 | raise ValueError('im must have three channels in the third dimension') 138 | 139 | colors = colors or _pascal_color_map() 140 | colors = np.asarray(colors, dtype=np.uint8) 141 | 142 | mask = colors[ann] 143 | fg = im * alpha + (1 - alpha) * mask 144 | 145 | img = im.copy() 146 | img[ann > 0] = fg[ann > 0] 147 | 148 | if contour_thickness: # pragma: no cover 149 | import cv2 150 | for obj_id in np.unique(ann[ann > 0]): 151 | contours = cv2.findContours((ann == obj_id).astype( 152 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 153 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 154 | contour_thickness) 155 | return img 156 | -------------------------------------------------------------------------------- /lib/vis/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def numpy_to_torch(a: np.ndarray): 6 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # This is a sample Python script. 2 | 3 | # Press Shift+F10 to execute it or replace it with your code. 4 | # Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. 5 | 6 | 7 | def print_hi(name): 8 | # Use a breakpoint in the code line below to debug your script. 9 | print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. 10 | 11 | 12 | # Press the green button in the gutter to run the script. 
13 | if __name__ == '__main__': 14 | print_hi('PyCharm') 15 | 16 | # See PyCharm help at https://www.jetbrains.com/help/pycharm/ 17 | -------------------------------------------------------------------------------- /tracking/__pycache__/_init_paths.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SparkTempest/BAT/ccf9b2f6ae3e810f4e7318c9d0b62083deb7ec89/tracking/__pycache__/_init_paths.cpython-37.pyc -------------------------------------------------------------------------------- /tracking/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /tracking/create_default_local_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import _init_paths 4 | from lib.train.admin import create_default_local_file_train 5 | from lib.test.evaluation import create_default_local_file_test 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='Create default local file on ITP or PAI') 10 | parser.add_argument("--workspace_dir", type=str, required=True) 11 | parser.add_argument("--data_dir", type=str, required=True) 12 | parser.add_argument("--save_dir", type=str, required=True) 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | if __name__ == "__main__": 18 | args = parse_args() 19 | workspace_dir = os.path.realpath(args.workspace_dir) 20 | data_dir = os.path.realpath(args.data_dir) 21 | save_dir = os.path.realpath(args.save_dir) 22 | create_default_local_file_train(workspace_dir, data_dir) 23 | create_default_local_file_test(workspace_dir, data_dir, save_dir) 24 | -------------------------------------------------------------------------------- /tracking/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import torch 5 | 6 | 7 | def parse_args(): 8 | """ 9 | args for training. 
10 | """ 11 | parser = argparse.ArgumentParser(description='Parse args for training') 12 | # for train 13 | parser.add_argument('--script', type=str, default='bat', help='training script name') 14 | parser.add_argument('--config', type=str, help='yaml configure file name') 15 | parser.add_argument('--save_dir', type=str, default='./output', help='root directory to save checkpoints, logs, and tensorboard') 16 | parser.add_argument('--mode', type=str, choices=["single", "multiple", "multi_node"], default="multiple", 17 | help="train on single gpu or multiple gpus") 18 | parser.add_argument('--nproc_per_node', type=int, default=torch.cuda.device_count(), help="number of GPUs per node") # specify when mode is multiple 19 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format 20 | parser.add_argument('--script_prv', type=str, help='training script name') 21 | parser.add_argument('--config_prv', type=str, default='baseline', help='yaml configure file name') 22 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb 23 | # for knowledge distillation 24 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation 25 | parser.add_argument('--script_teacher', type=str, help='teacher script name') 26 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name') 27 | 28 | # for multiple machines 29 | parser.add_argument('--rank', type=int, help='Rank of the current process.') 30 | parser.add_argument('--world-size', type=int, help='Number of processes participating in the job.') 31 | parser.add_argument('--ip', type=str, default='127.0.0.1', help='IP of the current rank 0.') 32 | parser.add_argument('--port', type=int, default='20000', help='Port of the current rank 0.') 33 | 34 | args = parser.parse_args() 35 | 36 | return args 37 | 38 | 39 | def main(): 40 | args = parse_args() 41 | print('args.config ', args.config) 42 | if args.mode == "single": 43 | train_cmd = "python lib/train/run_training.py --script %s --config %s --save_dir %s --use_lmdb %d " \ 44 | "--script_prv %s --config_prv %s --distill %d --script_teacher %s --config_teacher %s --use_wandb %d"\ 45 | % (args.script, args.config, args.save_dir, args.use_lmdb, args.script_prv, args.config_prv, 46 | args.distill, args.script_teacher, args.config_teacher, args.use_wandb) 47 | elif args.mode == "multiple": 48 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_port %d lib/train/run_training.py " \ 49 | "--script %s --config %s --save_dir %s --use_lmdb %d --script_prv %s --config_prv %s --use_wandb %d " \ 50 | "--distill %d --script_teacher %s --config_teacher %s" \ 51 | % (args.nproc_per_node, random.randint(10000, 50000), args.script, args.config, args.save_dir, args.use_lmdb, args.script_prv, args.config_prv, args.use_wandb, 52 | args.distill, args.script_teacher, args.config_teacher) 53 | elif args.mode == "multi_node": 54 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_addr %s --master_port %d --nnodes %d --node_rank %d lib/train/run_training.py " \ 55 | "--script %s --config %s --save_dir %s --use_lmdb %d --script_prv %s --config_prv %s --use_wandb %d " \ 56 | "--distill %d --script_teacher %s --config_teacher %s" \ 57 | % (args.nproc_per_node, args.ip, args.port, args.world_size, args.rank, args.script, args.config, args.save_dir, args.use_lmdb, args.script_prv, args.config_prv, args.use_wandb, 58 | 
args.distill, args.script_teacher, args.config_teacher) 59 | else: 60 | raise ValueError("mode should be 'single' or 'multiple' or 'multi_node'.") 61 | 62 | 63 | os.system(train_cmd) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /train_bat.sh: -------------------------------------------------------------------------------- 1 | # Training BAT 2 | #NCCL_P2P_LEVEL=NVL python tracking/train.py --script bat --config rgbd --save_dir ./output --mode multiple --nproc_per_node 4 3 | NCCL_P2P_LEVEL=NVL python tracking/train.py --script bat --config rgbt --save_dir ./output --mode multiple --nproc_per_node 3 4 | #python tracking/train.py --script bat --config rgbt --save_dir ./output --mode multiple --nproc_per_node 1 --use_wandb 1 5 | #NCCL_P2P_LEVEL=NVL python tracking/train.py --script bat --config rgbe --save_dir ./output --mode multiple --nproc_per_node 4 6 | --------------------------------------------------------------------------------