├── .gitignore ├── LICENSE ├── README.md ├── experiments └── lightfc │ ├── baseline_v1_release_backbone_tinyvit.yaml │ └── mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou.yaml ├── external └── vot20st │ └── lightfc │ ├── config.yaml │ ├── exp.sh │ └── trackers.ini ├── install.sh ├── lib ├── __init__.py ├── models │ ├── __init__.py │ ├── lightfc │ │ ├── __init__.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── mobilnetv2.py │ │ │ └── tiny_vit.py │ │ ├── fusion │ │ │ ├── __init__.py │ │ │ └── ecm.py │ │ └── head │ │ │ ├── __init__.py │ │ │ └── erh.py │ └── tracker_model.py ├── test │ ├── __init__.py │ ├── analysis │ │ ├── extract_results.py │ │ └── plot_results.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── data.py │ │ ├── datasets.py │ │ ├── dtbdataset.py │ │ ├── environment.py │ │ ├── got10kdataset.py │ │ ├── itbdataset.py │ │ ├── lasot_extdataset.py │ │ ├── lasot_lmdbdataset.py │ │ ├── lasotdataset.py │ │ ├── lasotextensionsubsetdataset.py │ │ ├── local.py │ │ ├── nfsdataset.py │ │ ├── otbdataset.py │ │ ├── running.py │ │ ├── tc128cedataset.py │ │ ├── tc128dataset.py │ │ ├── tnl2kdataset.py │ │ ├── tracker.py │ │ ├── trackingnetdataset.py │ │ ├── uavdataset.py │ │ ├── uotdataset.py │ │ ├── utbdataset.py │ │ └── votdataset.py │ ├── parameter │ │ └── lightfc.py │ ├── tracker │ │ ├── __init__.py │ │ ├── basetracker.py │ │ ├── data_utils.py │ │ └── lightfc.py │ ├── utils │ │ ├── __init__.py │ │ ├── _init_paths.py │ │ ├── deploy.py │ │ ├── hann.py │ │ ├── load_text.py │ │ ├── params.py │ │ ├── transform_got10k.py │ │ └── transform_trackingnet.py │ └── vot_utils │ │ ├── lightfc.py │ │ ├── lightfc_vot.py │ │ ├── utils.py │ │ └── vot.py ├── train │ ├── __init__.py │ ├── _init_paths.py │ ├── actors │ │ ├── __init__.py │ │ ├── base_actor.py │ │ └── lightfc.py │ ├── admin │ │ ├── __init__.py │ │ ├── environment.py │ │ ├── local.py │ │ ├── multigpu.py │ │ ├── settings.py │ │ ├── stats.py │ │ └── tensorboard.py │ ├── data │ │ ├── __init__.py │ │ ├── base_functions.py │ │ ├── bounding_box_utils.py │ │ ├── image_loader.py │ │ ├── loader.py │ │ ├── processing.py │ │ ├── processing_utils.py │ │ ├── sampler.py │ │ ├── sequence_loader.py │ │ ├── sequence_sampler.py │ │ ├── transforms.py │ │ └── wandb_logger.py │ ├── data_specs │ │ └── README.md │ ├── dataset │ │ ├── COCO_tool.py │ │ ├── __init__.py │ │ ├── base_image_dataset.py │ │ ├── base_video_dataset.py │ │ ├── coco.py │ │ ├── coco_seq.py │ │ ├── coco_seq_lmdb.py │ │ ├── got10k.py │ │ ├── got10k_lmdb.py │ │ ├── imagenetvid.py │ │ ├── imagenetvid_lmdb.py │ │ ├── lasot.py │ │ ├── lasot_lmdb.py │ │ ├── tracking_net.py │ │ └── tracking_net_lmdb.py │ ├── loss │ │ ├── __init__.py │ │ ├── box_loss.py │ │ ├── cos_sim_loss.py │ │ ├── focal_loss.py │ │ ├── gfocal_loss.py │ │ ├── objective.py │ │ ├── siamcar_loss.py │ │ └── varifocal_loss.py │ ├── optimizer │ │ ├── anan.py │ │ └── lion.py │ ├── run_training.py │ ├── train_script.py │ └── trainers │ │ ├── __init__.py │ │ ├── base_trainer.py │ │ └── ltr_trainer.py ├── utils │ ├── __init__.py │ ├── box_ops.py │ ├── ce_utils.py │ ├── heapmap_utils.py │ ├── list_tools.py │ ├── lmdb_utils.py │ ├── load.py │ ├── merge.py │ ├── misc.py │ ├── registry.py │ ├── tensor.py │ └── variable_hook.py └── vis │ ├── __init__.py │ ├── plotting.py │ ├── utils.py │ └── visdom_cus.py └── tracking ├── analysis_results.py ├── profile_model.py ├── speed.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.idea 2 | *~ 3 | *.jpg 4 | 
*__pycache__* 5 | *.pyc 6 | *.txt 7 | *.vscode 8 | *.zip 9 | *.7z 10 | *.pytest_cache 11 | *.bin 12 | *.csv 13 | *.value 14 | #**/local.py 15 | *.log 16 | */checkpoints 17 | /data 18 | /.tag 19 | /tensorboard 20 | /logs 21 | /models 22 | /output 23 | /checkpoints 24 | /test 25 | /pretrained_models 26 | /external/vot20st/lightfc/sequences 27 | /external/vot20st/lightfc/results 28 | /external/vot19st 29 | /external/draw 30 | /dev -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LiYunfengLYF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightFC 2 | 3 | The official implementation of LightFC 4 | 5 | ## News 6 | 7 | - 12 Jan 2024: lightfc-vit, with higher performance, is released! 8 | - 14 Oct 2023: our code is available now 9 | - 09 Oct 2023: our manuscript has been submitted to [arxiv](https://arxiv.org/abs/2310.05392) 10 | ## Install the environment 11 | 12 | **Option 1**: Use Anaconda 13 | ``` 14 | conda create -n lightfc python=3.9 15 | conda activate lightfc 16 | bash install.sh 17 | ``` 18 | 19 | ## Data Preparation 20 | Follow the [stark](https://github.com/researchmm/Stark) and [ostrack](https://github.com/botaoye/OSTrack) frameworks to set up your datasets. 21 | 22 | ## File directory 23 | 24 | The project directory should look like: 25 | 26 | ``` 27 | ${YOUR_PROJECT_ROOT} 28 | -- experiments 29 | |-- lightfc 30 | -- external 31 | |-- vot20st 32 | -- lib 33 | |--models 34 | ... 35 | -- outputs (download and unzip the output.zip to obtain our checkpoints and raw results) 36 | |--checkpoints 37 | |--... 38 | |--test 39 | |--... 40 | -- pretrained_models (if you want to train lightfc, put the pretrained model here) 41 | |--mobilenetv2.pth (from torchvision model) 42 | ... 43 | -- tracking 44 | ...
45 | ``` 46 | 47 | Download the lightfc checkpoint and raw results at [Google Drive](https://drive.google.com/file/d/1ns7NQJCt078547X483skqjX1qM1rBqLP/view) 48 | 49 | Download the lightfc-vit checkpoint and raw results at [Google Drive](https://drive.google.com/file/d/1tckIW9P0RFheAAoGoSZR9Lgnet7-HNOL/view?usp=sharing) 50 | 51 | 52 | Then go to these two files and modify the paths: 53 | ``` 54 | lib/train/admin/local.py # paths about training 55 | lib/test/evaluation/local.py # paths about testing 56 | ``` 57 | 58 | 59 | ## Train LightFC 60 | Train with multiple GPUs using DDP: 61 | ``` 62 | python tracking/train.py --script LightFC --config mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou --save_dir . --mode multiple --nproc_per_node 2 63 | ``` 64 | If you want to train lightfc, please download https://download.pytorch.org/models/mobilenet_v2-b0353104.pth rather than https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth 65 | 66 | If you want to train lightfc-vit, please download https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth 67 | 68 | ## Test and evaluate LightFC on benchmarks 69 | Go to **tracking/test.py** and modify the parameters (a minimal single-video demo is also sketched at the end of this README) 70 | ``` 71 | python tracking/test.py 72 | ``` 73 | 74 | Then go to **tracking/analysis_results.py** and modify the parameters 75 | ``` 76 | python tracking/analysis_results.py 77 | ``` 78 | ## Test FLOPs, Params, and Speed 79 | ``` 80 | # Params and FLOPs 81 | python tracking/profile_model.py 82 | # Speed 83 | python tracking/speed.py 84 | ``` 85 | 86 | ## Acknowledgments 87 | * Thanks to the great [stark](https://github.com/researchmm/Stark) and [ostrack](https://github.com/botaoye/OSTrack) libraries, which helped us to quickly implement our ideas.
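## Run LightFC on a single video (minimal sketch)

The benchmark scripts above are the supported entry points. The snippet below is only a rough, untested sketch of how the test-time classes (`lib/test/parameter/lightfc.py` and `lib/test/tracker/lightfc.py`) fit together on one video; it assumes `lib/test/evaluation/local.py` has been filled in, the lightfc checkpoint from the link above sits at the path built by `lib/test/parameter/lightfc.py`, and a CUDA device is available. The frame paths, the `env_num` value, and the initial box are placeholders.

```python
import cv2

from lib.test.parameter.lightfc import parameters
from lib.test.tracker.lightfc import lightFC

# Build TrackerParams from the experiment yaml; env_num=0 is a placeholder profile index for local.py.
params = parameters('mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou', 0)
tracker = lightFC(params, dataset_name=None)  # builds the network, loads params.checkpoint, moves it to the GPU

frames = ['seq/0001.jpg', 'seq/0002.jpg', 'seq/0003.jpg']  # placeholder frame list
init_bbox = [100, 120, 60, 80]                             # placeholder [x, y, w, h] of the target in frame 1

# Assumption: the tracker is fed RGB frames, so convert from OpenCV's BGR here.
first = cv2.cvtColor(cv2.imread(frames[0]), cv2.COLOR_BGR2RGB)
tracker.initialize(first, {'init_bbox': init_bbox})

for path in frames[1:]:
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    out = tracker.track(img)
    print(path, out['target_bbox'])  # predicted [x, y, w, h]
```

For full benchmark runs and result analysis, keep using `tracking/test.py` and `tracking/analysis_results.py` as described above.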
88 | -------------------------------------------------------------------------------- /experiments/lightfc/baseline_v1_release_backbone_tinyvit.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | NUMBER: 1 23 | TRAIN: 24 | DATASETS_NAME: 25 | - LASOT 26 | - GOT10K_vottrain 27 | - COCO17 28 | - TRACKINGNET 29 | DATASETS_RATIO: 30 | - 1 31 | - 1 32 | - 1 33 | - 1 34 | SAMPLE_PER_EPOCH: 60000 35 | VAL: 36 | DATASETS_NAME: 37 | - GOT10K_votval 38 | DATASETS_RATIO: 39 | - 1 40 | SAMPLE_PER_EPOCH: 10000 41 | MODEL: 42 | BACKBONE: 43 | TYPE: 'tiny_vit_5m_224' 44 | STRIDE: 16 45 | CHANNEL: 160 46 | USE_PRETRAINED: True 47 | PRETRAIN_FILE: 'tiny_vit_5m_22k_distill.pth' 48 | LOAD_MODE: 3 49 | PARAMS: 50 | pretrained: False 51 | NECK: 52 | USE_NECK: False 53 | FUSION: 54 | TYPE: 'PWCorr_SE_SCF31_IAB11_Concat_Release' 55 | CHANNEL: 160 56 | PARAMS: 57 | num_kernel: 64 58 | adj_channel: 96 59 | HEAD: 60 | TYPE: 'RepN33_SE_Center_Concat' 61 | CHANNEL: 160 62 | PARAMS: 63 | inplanes: 256 64 | channel: 256 65 | feat_sz: 16 66 | stride: 16 67 | freeze_bn: False 68 | TRAIN: 69 | # core 70 | EPOCH: 400 71 | BATCH_SIZE: 32 72 | NUM_WORKER: 8 73 | LR: 0.001 74 | SEED: 42 75 | # loss weight 76 | GIOU_WEIGHT: 2.0 77 | L1_WEIGHT: 5.0 78 | LOC_WEIGHT: 1.0 79 | # optimizer 80 | OPTIMIZER: ADAMW 81 | # scheduler 82 | SCHEDULER: 83 | TYPE: step 84 | DECAY_RATE: 0.1 85 | # Mstep 86 | # MILESTONES: [20] 87 | # GAMMA: 0.1 88 | LR_DROP_EPOCH: 160 89 | WEIGHT_DECAY: 0.0001 90 | MAX_GRAD_NORM: 5.0 91 | # other 92 | BACKBONE_MULTIPLIER: 0.1 93 | DROP_PATH_RATE: 0.1 94 | GRAD_CLIP_NORM: 0.1 95 | # trainer 96 | SAVE_INTERVAL: 5 97 | VAL_EPOCH_INTERVAL: 1 98 | PRINT_INTERVAL: 50 99 | AMP: False 100 | # objective-loss 101 | L_LOSS: 'l1' 102 | BOX_LOSS: 'wiou' 103 | CLS_LOSS: 'focal' 104 | # SWA 105 | # USE_SWA: True 106 | # SWA_EPOCH: 195 107 | TEST: 108 | EPOCH: 400 109 | SEARCH_FACTOR: 4.0 110 | SEARCH_SIZE: 256 111 | TEMPLATE_FACTOR: 2.0 112 | TEMPLATE_SIZE: 128 # 128 -------------------------------------------------------------------------------- /experiments/lightfc/mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | MAX_SAMPLE_INTERVAL: 200 3 | MEAN: 4 | - 0.485 5 | - 0.456 6 | - 0.406 7 | SEARCH: 8 | CENTER_JITTER: 3 9 | FACTOR: 4.0 10 | SCALE_JITTER: 0.25 11 | SIZE: 256 12 | NUMBER: 1 13 | STD: 14 | - 0.229 15 | - 0.224 16 | - 0.225 17 | TEMPLATE: 18 | CENTER_JITTER: 0 19 | FACTOR: 2.0 20 | SCALE_JITTER: 0 21 | SIZE: 128 22 | NUMBER: 1 23 | TRAIN: 24 | DATASETS_NAME: 25 | - LASOT 26 | - GOT10K_vottrain 27 | - COCO17 28 | - TRACKINGNET 29 | DATASETS_RATIO: 30 | - 1 31 | - 1 32 | - 1 33 | - 1 34 | SAMPLE_PER_EPOCH: 60000 35 | VAL: 36 | DATASETS_NAME: 37 | - GOT10K_votval 38 | DATASETS_RATIO: 39 | - 1 40 | SAMPLE_PER_EPOCH: 10000 41 | MODEL: 42 | BACKBONE: 43 | TYPE: 'MobileNetV2' 44 | STRIDE: 16 45 | CHANNEL: 96 46 | USE_PRETRAINED: True 47 | PRETRAIN_FILE: 'mobilenet_v2.pth' 48 | LOAD_MODE: 1 49 | 50 | NECK: 51 | USE_NECK: False 52 | FUSION: 53 | TYPE: 'pwcorr_se_scf_sc_iab_sc_concat' 54 | CHANNEL: 96 55 | PARAMS: 56 | num_kernel: 64 57 
| adj_channel: 96 58 | HEAD: 59 | TYPE: 'repn33_se_center_concat' 60 | CHANNEL: 96 61 | PARAMS: 62 | inplanes: 192 63 | channel: 256 64 | feat_sz: 16 65 | stride: 16 66 | freeze_bn: False 67 | TRAIN: 68 | # core 69 | EPOCH: 400 70 | BATCH_SIZE: 32 71 | NUM_WORKER: 8 72 | LR: 0.001 73 | SEED: 42 74 | # loss weight 75 | GIOU_WEIGHT: 2.0 76 | L1_WEIGHT: 5.0 77 | LOC_WEIGHT: 1.0 78 | # optimizer 79 | OPTIMIZER: ADAMW 80 | # scheduler 81 | SCHEDULER: 82 | TYPE: step 83 | DECAY_RATE: 0.1 84 | # Mstep 85 | # MILESTONES: [20] 86 | # GAMMA: 0.1 87 | LR_DROP_EPOCH: 160 88 | WEIGHT_DECAY: 0.0001 89 | MAX_GRAD_NORM: 5.0 90 | # other 91 | BACKBONE_MULTIPLIER: 0.1 92 | DROP_PATH_RATE: 0.1 93 | GRAD_CLIP_NORM: 0.1 94 | # trainer 95 | SAVE_INTERVAL: 5 96 | VAL_EPOCH_INTERVAL: 1 97 | PRINT_INTERVAL: 50 98 | AMP: False 99 | # objective-loss 100 | L_LOSS: 'l1' 101 | BOX_LOSS: 'wiou' 102 | CLS_LOSS: 'focal' 103 | # SWA 104 | # USE_SWA: True 105 | # SWA_EPOCH: 195 106 | TEST: 107 | EPOCH: 400 108 | SEARCH_FACTOR: 4.0 109 | SEARCH_SIZE: 256 110 | TEMPLATE_FACTOR: 2.0 111 | TEMPLATE_SIZE: 128 # 128 -------------------------------------------------------------------------------- /external/vot20st/lightfc/config.yaml: -------------------------------------------------------------------------------- 1 | registry: 2 | - ./trackers.ini 3 | stack: vot2020st 4 | -------------------------------------------------------------------------------- /external/vot20st/lightfc/exp.sh: -------------------------------------------------------------------------------- 1 | 2 | export PYTHONPATH=/home/liyunfeng/code/project2/LightFC:$PYTHONPATH 3 | #vot test lightfc 4 | #vot evaluate --workspace . lightfc 5 | vot analysis --workspace . lightfc --format html 6 | 7 | 8 | -------------------------------------------------------------------------------- /external/vot20st/lightfc/trackers.ini: -------------------------------------------------------------------------------- 1 | [lightfc] # 2 | label = lightfc 3 | protocol = traxpython 4 | command = import os; sys.path.append('/home/liyunfeng/code/project2/LightFC');import lib.test.vot_utils.lightfc as lightfc 5 | paths = /home/liyunfeng/code/project2/LightFC ;/home/liyunfeng/code/project2/LightFC/lib/test/vot_utils 6 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | echo "****************** Installing pytorch ******************" 2 | pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116 3 | 4 | echo "" 5 | echo "" 6 | echo "****************** Installing yaml ******************" 7 | pip install PyYAML 8 | 9 | echo "" 10 | echo "" 11 | echo "****************** Installing easydict ******************" 12 | pip install easydict 13 | 14 | echo "****************** Installing opencv-python ******************" 15 | pip install opencv-python==4.5.5.64 16 | 17 | echo "" 18 | echo "" 19 | echo "****************** Installing pandas ******************" 20 | pip install pandas 21 | 22 | echo "****************** Installing jpeg4py python wrapper ******************" 23 | apt-get install libturbojpeg 24 | pip install jpeg4py 25 | 26 | 27 | echo "****************** Installing thop ******************" 28 | pip install thop 29 | 30 | echo "" 31 | echo "" 32 | echo "****************** Installing lmdb ******************" 33 | pip install lmdb 34 | 35 | echo "" 36 | echo "" 37 | echo "****************** 
Installing scipy ******************" 38 | pip install scipy 39 | 40 | echo "" 41 | echo "" 42 | echo "****************** Installing visdom ******************" 43 | pip install visdom 44 | 45 | 46 | echo "****************** Installing vot-toolkit python ******************" 47 | pip install git+https://github.com/votchallenge/vot-toolkit-python 48 | 49 | echo "" 50 | echo "" 51 | echo "****************** Installation complete! ******************" 52 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiYunfengLYF/LightFC/97dc3405ec8e8c5ad3d3ad95cae7f12e4f17b5b0/lib/__init__.py -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracker_model import LightFC -------------------------------------------------------------------------------- /lib/models/lightfc/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import MobileNetV2,tiny_vit_5m_224 2 | from .fusion import pwcorr_se_scf_sc_iab_sc_concat 3 | from .head import repn33_se_center_concat -------------------------------------------------------------------------------- /lib/models/lightfc/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilnetv2 import MobileNetV2 2 | from .tiny_vit import tiny_vit_5m_224 -------------------------------------------------------------------------------- /lib/models/lightfc/backbone/mobilnetv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def _make_divisible(v, divisor, min_value=None): 4 | """ 5 | This function is taken from the original tf repo. 6 | It ensures that all layers have a channel number that is divisible by 8 7 | It can be seen here: 8 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 9 | :param v: 10 | :param divisor:a 11 | :param min_value: 12 | :return: 13 | """ 14 | if min_value is None: 15 | min_value = divisor 16 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 17 | # Make sure that round down does not go down by more than 10%. 
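    # For example, with divisor=8: _make_divisible(32, 8) -> 32 and _make_divisible(20, 8) -> 24,
    # while _make_divisible(9, 8) first rounds to 8, which is below 0.9 * 9, so the check below bumps it to 16.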
18 | if new_v < 0.9 * v: 19 | new_v += divisor 20 | return new_v 21 | 22 | 23 | class ConvBNReLU(nn.Sequential): 24 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 25 | padding = (kernel_size - 1) // 2 26 | if norm_layer is None: 27 | norm_layer = nn.BatchNorm2d 28 | 29 | super(ConvBNReLU, self).__init__( 30 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 31 | norm_layer(out_planes), 32 | nn.ReLU6(inplace=True) 33 | ) 34 | 35 | 36 | class InvertedResidual(nn.Module): 37 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 38 | super(InvertedResidual, self).__init__() 39 | self.stride = stride 40 | assert stride in [1, 2] 41 | 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | 45 | hidden_dim = int(round(inp * expand_ratio)) 46 | self.use_res_connect = self.stride == 1 and inp == oup 47 | 48 | layers = [] 49 | if expand_ratio != 1: 50 | # pw 51 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 52 | layers.extend([ 53 | # dw 54 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 57 | norm_layer(oup), 58 | ]) 59 | self.conv = nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class MobileNetV2(nn.Module): 69 | def __init__(self, 70 | num_classes=1000, 71 | width_mult=1.0, 72 | inverted_residual_setting=None, 73 | round_nearest=8, 74 | block=None, 75 | norm_layer=None, 76 | use_mid_feat=False): 77 | super(MobileNetV2, self).__init__() 78 | 79 | if block is None: 80 | block = InvertedResidual 81 | 82 | if norm_layer is None: 83 | norm_layer = nn.BatchNorm2d 84 | 85 | input_channel = 32 86 | last_channel = 1280 87 | 88 | self.use_mid_feat = use_mid_feat 89 | if self.use_mid_feat: 90 | self.middle_feat = [] 91 | 92 | if inverted_residual_setting is None: 93 | inverted_residual_setting = [ 94 | # t, c, n, s 95 | # 128 96 | [1, 16, 1, 1], 97 | # 64 98 | [6, 24, 2, 2], 99 | # 32 100 | [6, 32, 3, 2], 101 | # 16 102 | [6, 64, 4, 2], 103 | [6, 96, 3, 1], 104 | # 8 105 | # [6, 160, 3, 2], 106 | # [6, 320, 1, 1], 107 | ] 108 | 109 | # only check the first element, assuming user knows t,c,n,s are required 110 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 111 | raise ValueError("inverted_residual_setting should be non-empty " 112 | "or a 4-element list, got {}".format(inverted_residual_setting)) 113 | 114 | if self.use_mid_feat: 115 | def middle_feat_hook(module, fea_in, fea_out): 116 | self.middle_feat.append(fea_out) 117 | return None 118 | 119 | # building first layer 120 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 121 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 122 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 123 | # building inverted residual blocks 124 | for t, c, n, s in inverted_residual_setting: 125 | output_channel = _make_divisible(c * width_mult, round_nearest) 126 | for i in range(n): 127 | stride = s if i == 0 else 1 128 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 129 | input_channel = output_channel 130 | if self.use_mid_feat: 131 | if t == 6 and c == 32 and n == 3 and s == 2: 132 | 
features[-1].register_forward_hook(middle_feat_hook) 133 | # building last several layers 134 | # features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 135 | 136 | # make it nn.Sequential 137 | self.features = nn.Sequential(*features) 138 | 139 | def forward(self, x): 140 | if self.use_mid_feat: 141 | self.middle_feat = [] 142 | x = self.features(x) 143 | return {'16': x, '8': self.middle_feat[0]} 144 | else: 145 | x = self.features(x) 146 | return x 147 | -------------------------------------------------------------------------------- /lib/models/lightfc/fusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .ecm import pwcorr_se_scf_sc_iab_sc_concat 2 | -------------------------------------------------------------------------------- /lib/models/lightfc/fusion/ecm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from lib.models.lightfc.head.erh import repN31 5 | 6 | 7 | def pixel_wise_corr(z, x): 8 | ''' 9 | z is kernel ([32, 96, 8, 8]) 10 | x is search ([32, 96, 16, 16]) 11 | 12 | z -> (32, 64, 96) 13 | x -> (32, 96, 256) 14 | ''' 15 | b, c, h, w = x.size() 16 | z_mat = z.contiguous().view((b, c, -1)).transpose(1, 2) # (b,64,c) 17 | x_mat = x.contiguous().view((b, c, -1)) # (b,c,256) 18 | return torch.matmul(z_mat, x_mat).view((b, -1, h, w)) 19 | 20 | 21 | class SE(nn.Module): 22 | 23 | def __init__(self, channels=64, reduction=1): 24 | super(SE, self).__init__() 25 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 26 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0) 29 | self.sigmoid = nn.Sigmoid() 30 | 31 | def forward(self, x): 32 | module_input = x 33 | x = self.avg_pool(x) 34 | x = self.fc1(x) 35 | x = self.relu(x) 36 | x = self.fc2(x) 37 | x = self.sigmoid(x) # nn.silu() 38 | return module_input * x 39 | 40 | 41 | class pwcorr_se_scf_sc_iab_sc_concat(nn.Module): 42 | def __init__(self, num_kernel=64, adj_channel=96): 43 | super().__init__() 44 | 45 | # pw-corr 46 | self.pw_corr = pixel_wise_corr 47 | self.ca = SE() 48 | 49 | # SCF 50 | self.conv33 = nn.Conv2d(in_channels=num_kernel, out_channels=num_kernel, kernel_size=3, stride=1, padding=1, 51 | groups=num_kernel) 52 | self.bn33 = nn.BatchNorm2d(num_kernel, eps=0.00001, momentum=0.1, affine=True, track_running_stats=True) 53 | 54 | self.conv11 = nn.Conv2d(in_channels=num_kernel, out_channels=num_kernel, kernel_size=1, stride=1, padding=0, 55 | groups=num_kernel) 56 | self.bn11 = nn.BatchNorm2d(num_kernel, eps=0.00001, momentum=0.1, affine=True, track_running_stats=True) 57 | 58 | # IAB 59 | self.conv_up = nn.Conv2d(in_channels=num_kernel, out_channels=num_kernel * 2, kernel_size=1, stride=1, 60 | padding=0) 61 | self.bn_up = nn.BatchNorm2d(num_kernel * 2, eps=0.00001, momentum=0.1, affine=True, track_running_stats=True) 62 | self.act = nn.GELU() 63 | 64 | self.conv_down = nn.Conv2d(in_channels=num_kernel * 2, out_channels=num_kernel, kernel_size=1, stride=1, 65 | padding=0) 66 | self.bn_down = nn.BatchNorm2d(num_kernel, eps=0.00001, momentum=0.1, affine=True, track_running_stats=True) 67 | 68 | self.adjust = nn.Conv2d(num_kernel, adj_channel, 1) 69 | 70 | def forward(self, z, x): 71 | corr = self.ca(self.pw_corr(z, x)) 72 | 73 | # scf + skip-connection 74 | corr = corr + self.bn11(self.conv11(corr)) 
+ self.bn33(self.conv33(corr)) 75 | 76 | # iab + skip-connection 77 | corr = corr + self.bn_down(self.conv_down(self.act(self.bn_up(self.conv_up(corr))))) 78 | 79 | corr = self.adjust(corr) 80 | 81 | corr = torch.cat((corr, x), dim=1) 82 | 83 | return corr 84 | 85 | 86 | class pwcorr_se_repn31_sc_iab_sc_adj_concat(nn.Module): 87 | def __init__(self, num_kernel=64, adj_channel=96): 88 | super().__init__() 89 | self.pw_corr = pixel_wise_corr 90 | self.ca = SE() 91 | 92 | # SCF reparam structure 93 | self.repn31 = repN31(num_kernel, num_kernel, kernel_size=3, padding=1, groups=num_kernel,nonlinearity=nn.Identity) 94 | 95 | # IAB 96 | self.conv_up = nn.Conv2d(in_channels=num_kernel, out_channels=num_kernel * 2, kernel_size=1, stride=1, 97 | padding=0) 98 | self.bn_up = nn.BatchNorm2d(num_kernel * 2, eps=0.00001, momentum=0.1, affine=True, track_running_stats=True) 99 | self.act = nn.GELU() 100 | self.conv_down = nn.Conv2d(in_channels=num_kernel * 2, out_channels=num_kernel, kernel_size=1, stride=1, 101 | padding=0) 102 | self.bn_down = nn.BatchNorm2d(num_kernel, eps=0.00001, momentum=0.1, affine=True, track_running_stats=True) 103 | 104 | # adj layer 105 | self.adjust = nn.Conv2d(num_kernel, adj_channel, 1) 106 | 107 | def forward(self, z, x): 108 | corr = self.ca(self.pw_corr(z, x)) 109 | 110 | corr = corr + self.repn31(corr) 111 | 112 | corr = corr + self.bn_down(self.conv_down(self.act(self.bn_up(self.conv_up(corr))))) 113 | 114 | corr = self.adjust(corr) 115 | 116 | corr = torch.cat((corr, x), dim=1) 117 | return corr 118 | -------------------------------------------------------------------------------- /lib/models/lightfc/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .erh import repn33_se_center_concat -------------------------------------------------------------------------------- /lib/models/tracker_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from lib.models.lightfc import MobileNetV2, tiny_vit_5m_224, pwcorr_se_scf_sc_iab_sc_concat, repn33_se_center_concat 3 | from lib.utils.load import load_pretrain 4 | 5 | 6 | class LightFC(nn.Module): 7 | def __init__(self, cfg, env_num=0, training=False, ): 8 | super(LightFC, self).__init__() 9 | 10 | if cfg.MODEL.BACKBONE.TYPE == 'MobileNetV2': 11 | self.backbone = MobileNetV2() 12 | elif cfg.MODEL.BACKBONE.TYPE == 'tiny_vit_5m_224': 13 | self.backbone = tiny_vit_5m_224() 14 | self.training = training 15 | if self.train: 16 | load_pretrain(self.backbone, env_num=env_num, training=training, cfg=cfg, mode=cfg.MODEL.BACKBONE.LOAD_MODE) 17 | 18 | self.fusion = pwcorr_se_scf_sc_iab_sc_concat(num_kernel=cfg.MODEL.FUSION.PARAMS.num_kernel, 19 | adj_channel=cfg.MODEL.FUSION.PARAMS.adj_channel 20 | ) 21 | 22 | self.head = repn33_se_center_concat(inplanes=cfg.MODEL.HEAD.PARAMS.inplanes, 23 | channel=cfg.MODEL.HEAD.PARAMS.channel, 24 | feat_sz=cfg.MODEL.HEAD.PARAMS.feat_sz, 25 | stride=cfg.MODEL.HEAD.PARAMS.stride, 26 | freeze_bn=cfg.MODEL.HEAD.PARAMS.freeze_bn, 27 | ) 28 | 29 | def forward(self, z, x): 30 | if self.training: 31 | z = self.backbone(z) 32 | x = self.backbone(x) 33 | 34 | opt = self.fusion(z, x) 35 | 36 | out = self.head(opt) 37 | else: 38 | return self.forward_tracking(z, x) 39 | return out 40 | 41 | # 42 | def forward_backbone(self, z): 43 | z = self.backbone(z) 44 | return z 45 | 46 | def forward_tracking(self, z_feat, x): 47 | x = self.backbone(x) 48 | opt = self.fusion(z_feat, x) 49 | out = 
self.head(opt) 50 | return out 51 | -------------------------------------------------------------------------------- /lib/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiYunfengLYF/LightFC/97dc3405ec8e8c5ad3d3ad95cae7f12e4f17b5b0/lib/test/__init__.py -------------------------------------------------------------------------------- /lib/test/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import Sequence 2 | from .tracker import Tracker, trackerlist 3 | from .datasets import get_dataset 4 | from .environment import create_default_local_file_ITP_test -------------------------------------------------------------------------------- /lib/test/evaluation/datasets.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from collections import namedtuple 3 | 4 | from lib.test.evaluation.data import SequenceList 5 | 6 | DatasetInfo = namedtuple('DatasetInfo', ['module', 'class_name', 'kwargs']) 7 | 8 | pt = "lib.test.evaluation.%sdataset" # Useful abbreviations to reduce the clutter 9 | 10 | dataset_dict = dict( 11 | otb=DatasetInfo(module=pt % "otb", class_name="OTBDataset", kwargs=dict()), 12 | nfs=DatasetInfo(module=pt % "nfs", class_name="NFSDataset", kwargs=dict()), 13 | uav=DatasetInfo(module=pt % "uav", class_name="UAVDataset", kwargs=dict()), 14 | tc128=DatasetInfo(module=pt % "tc128", class_name="TC128Dataset", kwargs=dict()), 15 | tc128ce=DatasetInfo(module=pt % "tc128ce", class_name="TC128CEDataset", kwargs=dict()), 16 | trackingnet=DatasetInfo(module=pt % "trackingnet", class_name="TrackingNetDataset", kwargs=dict()), 17 | got10k_test=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='test')), 18 | got10k_val=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='val')), 19 | got10k_ltrval=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='ltrval')), 20 | lasot=DatasetInfo(module=pt % "lasot", class_name="LaSOTDataset", kwargs=dict()), 21 | lasot_lmdb=DatasetInfo(module=pt % "lasot_lmdb", class_name="LaSOTlmdbDataset", kwargs=dict()), 22 | 23 | vot18=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict()), 24 | vot22=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict(year=22)), 25 | itb=DatasetInfo(module=pt % "itb", class_name="ITBDataset", kwargs=dict()), 26 | tnl2k=DatasetInfo(module=pt % "tnl2k", class_name="TNL2kDataset", kwargs=dict()), 27 | lasot_ext=DatasetInfo(module=pt % "lasot_ext", class_name="LaSOTExtDataset",kwargs=dict()), 28 | 29 | uot=DatasetInfo(module=pt % 'uot', class_name='UOTDataset', kwargs=dict()), 30 | uot_sim=DatasetInfo(module=pt % 'uot', class_name='UOTDataset_SimSubset', kwargs=dict()), 31 | uot_unsim=DatasetInfo(module=pt % 'uot', class_name='UOTDataset_unSimSubset', kwargs=dict()), 32 | 33 | utb=DatasetInfo(module=pt % 'utb', class_name='UTBDataset', kwargs=dict()), 34 | utb_sim=DatasetInfo(module=pt % 'utb', class_name='UTBSimDataset', kwargs=dict()), 35 | utb_unsim=DatasetInfo(module=pt % 'utb', class_name='UTBunSimDataset', kwargs=dict()), 36 | 37 | dtb = DatasetInfo(module=pt % "dtb", class_name="DTBDataset", kwargs=dict()), 38 | 39 | ) 40 | 41 | 42 | def load_dataset(name: str, env_num: int): 43 | """ Import and load a single dataset.""" 44 | name = name.lower() 45 | dset_info = dataset_dict.get(name) 46 | if dset_info is None: 
47 | raise ValueError('Unknown dataset \'%s\'' % name) 48 | 49 | m = importlib.import_module(dset_info.module) 50 | 51 | dset_info.kwargs['env_num'] = env_num 52 | 53 | dataset = getattr(m, dset_info.class_name)(**dset_info.kwargs) # Call the constructor 54 | return dataset.get_sequence_list() 55 | 56 | 57 | def get_dataset(*args, env_num): 58 | """ Get a single or set of datasets.""" 59 | dset = SequenceList() 60 | 61 | for name in args: 62 | dset.extend(load_dataset(name, env_num)) 63 | return dset 64 | -------------------------------------------------------------------------------- /lib/test/evaluation/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | class EnvSettings: 6 | def __init__(self): 7 | test_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 8 | 9 | self.results_path = '{}/tracking_results/'.format(test_path) 10 | self.segmentation_path = '{}/segmentation_results/'.format(test_path) 11 | self.network_path = '{}/networks/'.format(test_path) 12 | self.result_plot_path = '{}/result_plots/'.format(test_path) 13 | self.otb_path = '' 14 | self.nfs_path = '' 15 | self.uav_path = '' 16 | self.tpl_path = '' 17 | self.vot_path = '' 18 | self.got10k_path = '' 19 | self.lasot_path = '' 20 | self.trackingnet_path = '' 21 | self.davis_dir = '' 22 | self.youtubevos_dir = '' 23 | 24 | self.got_packed_results_path = '' 25 | self.got_reports_path = '' 26 | self.tn_packed_results_path = '' 27 | 28 | 29 | def create_default_local_file(): 30 | comment = {'results_path': 'Where to store tracking results', 31 | 'network_path': 'Where tracking networks are stored.'} 32 | 33 | path = os.path.join(os.path.dirname(__file__), 'local.py') 34 | with open(path, 'w') as f: 35 | settings = EnvSettings() 36 | 37 | f.write('from test.evaluation.environment import EnvSettings\n\n') 38 | f.write('def local_env_settings():\n') 39 | f.write(' settings = EnvSettings()\n\n') 40 | f.write(' # Set your local paths here.\n\n') 41 | 42 | for attr in dir(settings): 43 | comment_str = None 44 | if attr in comment: 45 | comment_str = comment[attr] 46 | attr_val = getattr(settings, attr) 47 | if not attr.startswith('__') and not callable(attr_val): 48 | if comment_str is None: 49 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 50 | else: 51 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 52 | f.write('\n return settings\n\n') 53 | 54 | 55 | class EnvSettings_ITP: 56 | def __init__(self, workspace_dir, data_dir, save_dir): 57 | self.prj_dir = workspace_dir 58 | self.save_dir = save_dir 59 | self.results_path = os.path.join(save_dir, 'test/tracking_results') 60 | self.segmentation_path = os.path.join(save_dir, 'test/segmentation_results') 61 | self.network_path = os.path.join(save_dir, 'test/networks') 62 | self.result_plot_path = os.path.join(save_dir, 'test/result_plots') 63 | self.otb_path = os.path.join(data_dir, 'otb') 64 | self.nfs_path = os.path.join(data_dir, 'nfs') 65 | self.uav_path = os.path.join(data_dir, 'uav') 66 | self.tc128_path = os.path.join(data_dir, 'TC128') 67 | self.tpl_path = '' 68 | self.vot_path = os.path.join(data_dir, 'VOT2019') 69 | self.got10k_path = os.path.join(data_dir, 'got10k') 70 | self.got10k_lmdb_path = os.path.join(data_dir, 'got10k_lmdb') 71 | self.lasot_path = os.path.join(data_dir, 'lasot') 72 | self.lasot_lmdb_path = os.path.join(data_dir, 'lasot_lmdb') 73 | self.trackingnet_path = os.path.join(data_dir, 'trackingnet') 74 | 
self.vot18_path = os.path.join(data_dir, 'vot2018') 75 | self.vot22_path = os.path.join(data_dir, 'vot2022') 76 | self.itb_path = os.path.join(data_dir, 'itb') 77 | self.tnl2k_path = os.path.join(data_dir, 'tnl2k') 78 | self.lasot_extension_subset_path_path = os.path.join(data_dir, 'lasot_extension_subset') 79 | self.davis_dir = '' 80 | self.youtubevos_dir = '' 81 | 82 | self.got_packed_results_path = '' 83 | self.got_reports_path = '' 84 | self.tn_packed_results_path = '' 85 | 86 | 87 | def create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir): 88 | comment = {'results_path': 'Where to store tracking results', 89 | 'network_path': 'Where tracking networks are stored.'} 90 | 91 | path = os.path.join(os.path.dirname(__file__), 'local.py') 92 | with open(path, 'w') as f: 93 | settings = EnvSettings_ITP(workspace_dir, data_dir, save_dir) 94 | 95 | f.write('from lib.test.evaluation.environment import EnvSettings\n\n') 96 | f.write('def local_env_settings():\n') 97 | f.write(' settings = EnvSettings()\n\n') 98 | f.write(' # Set your local paths here.\n\n') 99 | 100 | for attr in dir(settings): 101 | comment_str = None 102 | if attr in comment: 103 | comment_str = comment[attr] 104 | attr_val = getattr(settings, attr) 105 | if not attr.startswith('__') and not callable(attr_val): 106 | if comment_str is None: 107 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 108 | else: 109 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 110 | f.write('\n return settings\n\n') 111 | 112 | 113 | def env_settings(env_num): 114 | env_module_name = 'lib.test.evaluation.local' 115 | try: 116 | env_module = importlib.import_module(env_module_name) 117 | return env_module.local_env_settings(env_num) 118 | except: 119 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 120 | 121 | # Create a default file 122 | # create_default_local_file() 123 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. ' 124 | 'Then try to run again.'.format(env_file)) -------------------------------------------------------------------------------- /lib/test/evaluation/got10kdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 6 | from lib.test.utils.load_text import load_text 7 | 8 | 9 | class GOT10KDataset(BaseDataset): 10 | """ GOT-10k dataset. 
11 | 12 | Publication: 13 | GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild 14 | Lianghua Huang, Xin Zhao, and Kaiqi Huang 15 | arXiv:1810.11981, 2018 16 | https://arxiv.org/pdf/1810.11981.pdf 17 | 18 | Download dataset from http://got-10k.aitestunion.com/downloads 19 | """ 20 | 21 | def __init__(self, split, env_num): 22 | super().__init__(env_num) 23 | # Split can be test, val, or ltrval (a validation split consisting of videos from the official train set) 24 | if split == 'test' or split == 'val': 25 | self.base_path = os.path.join(self.env_settings.got10k_path, split) 26 | else: 27 | self.base_path = os.path.join(self.env_settings.got10k_path, 'train') 28 | 29 | self.sequence_list = self._get_sequence_list(split) 30 | self.split = split 31 | 32 | def get_sequence_list(self): 33 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list]) 34 | 35 | def _construct_sequence(self, sequence_name): 36 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name) 37 | 38 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64) 39 | 40 | frames_path = '{}/{}'.format(self.base_path, sequence_name) 41 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")] 42 | frame_list.sort(key=lambda f: int(f[:-4])) 43 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list] 44 | 45 | return Sequence(sequence_name, frames_list, 'got10k', ground_truth_rect.reshape(-1, 4)) 46 | 47 | def __len__(self): 48 | return len(self.sequence_list) 49 | 50 | def _get_sequence_list(self, split): 51 | with open('{}/list.txt'.format(self.base_path)) as f: 52 | sequence_list = f.read().splitlines() 53 | 54 | if split == 'ltrval': 55 | with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f: 56 | seq_ids = f.read().splitlines() 57 | 58 | sequence_list = [sequence_list[int(x)] for x in seq_ids] 59 | return sequence_list 60 | -------------------------------------------------------------------------------- /lib/test/evaluation/itbdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 6 | from lib.test.utils.load_text import load_text 7 | 8 | 9 | class ITBDataset(BaseDataset): 10 | """ NUS-PRO dataset 11 | """ 12 | 13 | def __init__(self, env_num): 14 | super().__init__(env_num) 15 | self.base_path = self.env_settings.itb_path 16 | self.sequence_info_list = self._get_sequence_info_list(self.base_path) 17 | 18 | def get_sequence_list(self): 19 | return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list]) 20 | 21 | def _construct_sequence(self, sequence_info): 22 | sequence_path = sequence_info['path'] 23 | nz = sequence_info['nz'] 24 | ext = sequence_info['ext'] 25 | start_frame = sequence_info['startFrame'] 26 | end_frame = sequence_info['endFrame'] 27 | 28 | init_omit = 0 29 | if 'initOmit' in sequence_info: 30 | init_omit = sequence_info['initOmit'] 31 | 32 | frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path, 33 | sequence_path=sequence_path, frame=frame_num, 34 | nz=nz, ext=ext) for frame_num in 35 | range(start_frame + init_omit, end_frame + 1)] 36 | 37 | anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path']) 38 | 39 | # NOTE: NUS has some weird annos which panda cannot handle 40 | ground_truth_rect = load_text(str(anno_path), 
delimiter=(',', None), dtype=np.float64, backend='numpy') 41 | return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:, :], 42 | object_class=sequence_info['object_class']) 43 | 44 | def __len__(self): 45 | return len(self.sequence_info_list) 46 | 47 | def get_fileNames(self, rootdir): 48 | fs = [] 49 | fs_all = [] 50 | for root, dirs, files in os.walk(rootdir, topdown=True): 51 | files.sort() 52 | files.sort(key=len) 53 | if files is not None: 54 | for name in files: 55 | _, ending = os.path.splitext(name) 56 | if ending == ".jpg": 57 | _, root_ = os.path.split(root) 58 | fs.append(os.path.join(root_, name)) 59 | fs_all.append(os.path.join(root, name)) 60 | 61 | return fs_all, fs 62 | 63 | def _get_sequence_info_list(self, base_path): 64 | sequence_info_list = [] 65 | for scene in os.listdir(base_path): 66 | if '.' in scene: 67 | continue 68 | videos = os.listdir(os.path.join(base_path, scene)) 69 | for video in videos: 70 | _, fs = self.get_fileNames(os.path.join(base_path, scene, video)) 71 | video_tmp = {"name": video, "path": scene + '/' + video, "startFrame": 1, "endFrame": len(fs), 72 | "nz": len(fs[0].split('/')[-1].split('.')[0]), "ext": "jpg", 73 | "anno_path": scene + '/' + video + "/groundtruth.txt", 74 | "object_class": "unknown"} 75 | sequence_info_list.append(video_tmp) 76 | 77 | return sequence_info_list # sequence_info_list_50 # 78 | -------------------------------------------------------------------------------- /lib/test/evaluation/local.py: -------------------------------------------------------------------------------- 1 | from lib.test.evaluation.environment import EnvSettings 2 | 3 | 4 | def local_env_settings(env_num): 5 | settings = EnvSettings() 6 | 7 | settings.davis_dir = r'' 8 | settings.got10k_lmdb_path = r'' 9 | settings.got10k_path = r'' 10 | settings.got_packed_results_path = r'' 11 | settings.got_reports_path = r'' 12 | settings.itb_path = r'' 13 | settings.lasot_extension_subset_path = '' 14 | settings.lasot_lmdb_path = r'' 15 | settings.lasot_path = '/media/liyunfeng/CV2/data/sot/lasot' 16 | settings.network_path = r'' 17 | settings.nfs_path = r'' 18 | settings.otb_path = r'/media/liyunfeng/CV2/data/sot/otb' 19 | settings.dtb_path = r'' 20 | settings.prj_dir = r'/home/liyunfeng/code/project2/LightFC' 21 | settings.result_plot_path = r'/home/liyunfeng/code/project2/LightFC/output/test/result_plots' 22 | # Where to store tracking results 23 | settings.results_path = r'/home/liyunfeng/code/project2/LightFC/output/test/tracking_results' 24 | settings.save_dir = r'/home/liyunfeng/code/project2/LightFC/output' 25 | settings.segmentation_path = r'/home/liyunfeng/code/project2/LightFC/output/test/segmentation_results' 26 | settings.tc128_path = r'/media/liyunfeng/CV2/data/sot/tc128' 27 | settings.tn_packed_results_path = r'' 28 | settings.tnl2k_path = r'/media/liyunfeng/CV2/data/sot/tnl2k/test' 29 | settings.tpl_path = r'' 30 | settings.trackingnet_path = r'' 31 | settings.uav_path = r'/media/liyunfeng/CV2/data/uav/uav123' 32 | settings.vot18_path = r'' 33 | settings.vot22_path = r'' 34 | settings.vot_path = r'' 35 | settings.youtubevos_dir = r'' 36 | settings.uot_path = r'/media/liyunfeng/CV2/data/uot/uot100' 37 | settings.utb_path = r'/media/liyunfeng/CV2/data/uot/utb180' 38 | 39 | return settings 40 | -------------------------------------------------------------------------------- /lib/test/evaluation/tc128cedataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 
4 | import numpy as np 5 | import six 6 | 7 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 8 | 9 | 10 | class TC128CEDataset(BaseDataset): 11 | """ 12 | TC-128 Dataset (78 newly added sequences) 13 | modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit) 14 | """ 15 | 16 | def __init__(self, env_num): 17 | super().__init__(env_num) 18 | self.base_path = self.env_settings.tc128_path 19 | self.anno_files = sorted(glob.glob( 20 | os.path.join(self.base_path, '*/*_gt.txt'))) 21 | """filter the newly added sequences (_ce)""" 22 | self.anno_files = [s for s in self.anno_files if "_ce" in s] 23 | self.seq_dirs = [os.path.dirname(f) for f in self.anno_files] 24 | self.seq_names = [os.path.basename(d) for d in self.seq_dirs] 25 | # valid frame range for each sequence 26 | self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs] 27 | 28 | def get_sequence_list(self): 29 | return SequenceList([self._construct_sequence(s) for s in self.seq_names]) 30 | 31 | def _construct_sequence(self, sequence_name): 32 | if isinstance(sequence_name, six.string_types): 33 | if not sequence_name in self.seq_names: 34 | raise Exception('Sequence {} not found.'.format(sequence_name)) 35 | index = self.seq_names.index(sequence_name) 36 | # load valid frame range 37 | frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',') 38 | img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)] 39 | 40 | # load annotations 41 | anno = np.loadtxt(self.anno_files[index], delimiter=',') 42 | assert len(img_files) == len(anno) 43 | assert anno.shape[1] == 4 44 | 45 | # return img_files, anno 46 | return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4)) 47 | 48 | def __len__(self): 49 | return len(self.seq_names) 50 | -------------------------------------------------------------------------------- /lib/test/evaluation/tc128dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 3 | import os 4 | import glob 5 | import six 6 | 7 | 8 | class TC128Dataset(BaseDataset): 9 | """ 10 | TC-128 Dataset 11 | modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit) 12 | """ 13 | def __init__(self,env_num): 14 | super().__init__(env_num) 15 | self.base_path = self.env_settings.tc128_path 16 | self.anno_files = sorted(glob.glob( 17 | os.path.join(self.base_path, '*/*_gt.txt'))) 18 | self.seq_dirs = [os.path.dirname(f) for f in self.anno_files] 19 | self.seq_names = [os.path.basename(d) for d in self.seq_dirs] 20 | # valid frame range for each sequence 21 | self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs] 22 | 23 | def get_sequence_list(self): 24 | return SequenceList([self._construct_sequence(s) for s in self.seq_names]) 25 | 26 | def _construct_sequence(self, sequence_name): 27 | if isinstance(sequence_name, six.string_types): 28 | if not sequence_name in self.seq_names: 29 | raise Exception('Sequence {} not found.'.format(sequence_name)) 30 | index = self.seq_names.index(sequence_name) 31 | # load valid frame range 32 | frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',') 33 | img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)] 34 | 35 | # load annotations 36 | anno = 
np.loadtxt(self.anno_files[index], delimiter=',') 37 | assert len(img_files) == len(anno) 38 | assert anno.shape[1] == 4 39 | 40 | # return img_files, anno 41 | return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4)) 42 | 43 | def __len__(self): 44 | return len(self.seq_names) 45 | -------------------------------------------------------------------------------- /lib/test/evaluation/tnl2kdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 6 | from lib.test.utils.load_text import load_text, load_str 7 | 8 | 9 | ############ 10 | # current 00000492.png of test_015_Sord_video_Q01_done is damaged and replaced by a copy of 00000491.png 11 | ############ 12 | 13 | 14 | class TNL2kDataset(BaseDataset): 15 | """ 16 | TNL2k test set 17 | """ 18 | 19 | def __init__(self, env_num): 20 | super().__init__(env_num) 21 | self.base_path = self.env_settings.tnl2k_path 22 | self.sequence_list = self._get_sequence_list() 23 | 24 | def get_sequence_list(self): 25 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list]) 26 | 27 | def _construct_sequence(self, sequence_name): 28 | # class_name = sequence_name.split('-')[0] 29 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name) 30 | 31 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64) 32 | 33 | text_dsp_path = '{}/{}/language.txt'.format(self.base_path, sequence_name) 34 | text_dsp = load_str(text_dsp_path) 35 | 36 | frames_path = '{}/{}/imgs'.format(self.base_path, sequence_name) 37 | frames_list = [f for f in os.listdir(frames_path)] 38 | frames_list = sorted(frames_list) 39 | frames_list = ['{}/{}'.format(frames_path, frame_i) for frame_i in frames_list] 40 | 41 | # target_class = class_name 42 | return Sequence(sequence_name, frames_list, 'tnl2k', ground_truth_rect.reshape(-1, 4)) 43 | 44 | def __len__(self): 45 | return len(self.sequence_list) 46 | 47 | def _get_sequence_list(self): 48 | sequence_list = [] 49 | for seq in os.listdir(self.base_path): 50 | if os.path.isdir(os.path.join(self.base_path, seq)): 51 | sequence_list.append(seq) 52 | 53 | return sequence_list 54 | -------------------------------------------------------------------------------- /lib/test/evaluation/trackingnetdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList 6 | from lib.test.utils.load_text import load_text 7 | 8 | 9 | class TrackingNetDataset(BaseDataset): 10 | """ TrackingNet test set. 11 | 12 | Publication: 13 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 14 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 15 | ECCV, 2018 16 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 17 | 18 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 
19 | """ 20 | 21 | def __init__(self, env_num): 22 | super().__init__(env_num) 23 | self.base_path = self.env_settings.trackingnet_path 24 | 25 | sets = 'TEST' 26 | if not isinstance(sets, (list, tuple)): 27 | if sets == 'TEST': 28 | sets = ['TEST'] 29 | elif sets == 'TRAIN': 30 | sets = ['TRAIN_{}'.format(i) for i in range(5)] 31 | 32 | self.sequence_list = self._list_sequences(self.base_path, sets) 33 | 34 | def get_sequence_list(self): 35 | return SequenceList([self._construct_sequence(set, seq_name) for set, seq_name in self.sequence_list]) 36 | 37 | def _construct_sequence(self, set, sequence_name): 38 | anno_path = '{}/{}/anno/{}.txt'.format(self.base_path, set, sequence_name) 39 | 40 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy') 41 | 42 | frames_path = '{}/{}/frames/{}'.format(self.base_path, set, sequence_name) 43 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")] 44 | frame_list.sort(key=lambda f: int(f[:-4])) 45 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list] 46 | 47 | return Sequence(sequence_name, frames_list, 'trackingnet', ground_truth_rect.reshape(-1, 4)) 48 | 49 | def __len__(self): 50 | return len(self.sequence_list) 51 | 52 | def _list_sequences(self, root, set_ids): 53 | sequence_list = [] 54 | 55 | for s in set_ids: 56 | anno_dir = os.path.join(root, s, "anno") 57 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')] 58 | 59 | sequence_list += sequences_cur_set 60 | 61 | return sequence_list 62 | -------------------------------------------------------------------------------- /lib/test/parameter/lightfc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from lib.utils.load import load_yaml 3 | from lib.test.utils import TrackerParams 4 | from lib.test.evaluation.environment import env_settings 5 | 6 | 7 | def parameters(yaml_name: str, env_num: int): 8 | params = TrackerParams() 9 | params.env_num = env_num 10 | prj_dir = env_settings(env_num).prj_dir 11 | save_dir = env_settings(env_num).save_dir 12 | # update default config from yaml file 13 | yaml_file = os.path.join(prj_dir, 'experiments/lightfc/%s.yaml' % yaml_name) 14 | params.cfg = load_yaml(yaml_file) 15 | print("test config: ", params.cfg) 16 | params.tracker_param = yaml_name 17 | 18 | # template and search region 19 | params.template_factor = params.cfg.TEST.TEMPLATE_FACTOR 20 | params.template_size = params.cfg.TEST.TEMPLATE_SIZE 21 | params.search_factor = params.cfg.TEST.SEARCH_FACTOR 22 | params.search_size = params.cfg.TEST.SEARCH_SIZE 23 | 24 | # Network checkpoint path 25 | params.checkpoint = os.path.join(save_dir, "checkpoints/train/lightfc/%s/lightfc_ep%04d.pth.tar" % 26 | (yaml_name, params.cfg.TEST.EPOCH)) 27 | 28 | params.save_all_boxes = False 29 | 30 | return params 31 | -------------------------------------------------------------------------------- /lib/test/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiYunfengLYF/LightFC/97dc3405ec8e8c5ad3d3ad95cae7f12e4f17b5b0/lib/test/tracker/__init__.py -------------------------------------------------------------------------------- /lib/test/tracker/basetracker.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from _collections import OrderedDict 5 | 6 | from 
lib.train.data.processing_utils import transform_image_to_crop 7 | 8 | 9 | # from lib.vis.visdom_cus import Visdom 10 | 11 | 12 | class BaseTracker: 13 | """Base class for all trackers.""" 14 | 15 | def __init__(self, params, dataset_name=None): 16 | self.params = params 17 | self.visdom = None 18 | 19 | def predicts_segmentation_mask(self): 20 | return False 21 | 22 | def initialize(self, image, info: dict) -> dict: 23 | """Overload this function in your tracker. This should initialize the model.""" 24 | raise NotImplementedError 25 | 26 | def track(self, image, info: dict = None) -> dict: 27 | """Overload this function in your tracker. This should track in the frame and update the model.""" 28 | raise NotImplementedError 29 | 30 | def visdom_draw_tracking(self, image, box, segmentation=None): 31 | if isinstance(box, OrderedDict): 32 | box = [v for k, v in box.items()] 33 | else: 34 | box = (box,) 35 | if segmentation is None: 36 | self.visdom.register((image, *box), 'Tracking', 1, 'Tracking') 37 | else: 38 | self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking') 39 | 40 | def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'): 41 | # box_in: list [x1, y1, w, h], not normalized 42 | # box_extract: same as box_in 43 | # out bbox: Torch.tensor [1, 1, 4], x1y1wh, normalized 44 | if crop_type == 'template': 45 | crop_sz = torch.Tensor([self.params.template_size, self.params.template_size]) 46 | elif crop_type == 'search': 47 | crop_sz = torch.Tensor([self.params.search_size, self.params.search_size]) 48 | else: 49 | raise NotImplementedError 50 | 51 | box_in = torch.tensor(box_in) 52 | if box_extract is None: 53 | box_extract = box_in 54 | else: 55 | box_extract = torch.tensor(box_extract) 56 | template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True) 57 | template_bbox = template_bbox.view(1, 1, 4).to(device) 58 | 59 | return template_bbox 60 | 61 | def _init_visdom(self, visdom_info, debug): 62 | visdom_info = {} if visdom_info is None else visdom_info 63 | self.pause_mode = False 64 | self.step = False 65 | self.next_seq = False 66 | if debug > 0 and visdom_info.get('use_visdom', True): 67 | try: 68 | # self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'}, 69 | # visdom_info=visdom_info) 70 | pass 71 | # # Show help 72 | # help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \ 73 | # 'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \ 74 | # 'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \ 75 | # 'block list.' 76 | # self.visdom.register(help_text, 'text', 1, 'Help') 77 | except: 78 | time.sleep(0.5) 79 | print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n' 80 | '!!! 
Start Visdom in a separate terminal window by typing \'visdom\' !!!') 81 | 82 | def _visdom_ui_handler(self, data): 83 | if data['event_type'] == 'KeyPress': 84 | if data['key'] == ' ': 85 | self.pause_mode = not self.pause_mode 86 | 87 | elif data['key'] == 'ArrowRight' and self.pause_mode: 88 | self.step = True 89 | 90 | elif data['key'] == 'n': 91 | self.next_seq = True 92 | -------------------------------------------------------------------------------- /lib/test/tracker/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from lib.utils.misc import NestedTensor 4 | 5 | 6 | class Preprocessor(object): 7 | def __init__(self): 8 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda() 9 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda() 10 | 11 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 12 | # Deal with the image patch 13 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2, 0, 1)).unsqueeze(dim=0) 14 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) 15 | # Deal with the attention mask 16 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W) 17 | return NestedTensor(img_tensor_norm, amask_tensor) 18 | 19 | 20 | class PreprocessorX(object): 21 | def __init__(self): 22 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda() 23 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda() 24 | 25 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 26 | # Deal with the image patch 27 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2, 0, 1)).unsqueeze(dim=0) 28 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W) 29 | # Deal with the attention mask 30 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W) 31 | return img_tensor_norm, amask_tensor 32 | 33 | 34 | class PreprocessorX_onnx(object): 35 | def __init__(self): 36 | self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1)) 37 | self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1)) 38 | 39 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray): 40 | """img_arr: (H,W,3), amask_arr: (H,W)""" 41 | # Deal with the image patch 42 | img_arr_4d = img_arr[np.newaxis, :, :, :].transpose(0, 3, 1, 2) 43 | img_arr_4d = (img_arr_4d / 255.0 - self.mean) / self.std # (1, 3, H, W) 44 | # Deal with the attention mask 45 | amask_arr_3d = amask_arr[np.newaxis, :, :] # (1,H,W) 46 | return img_arr_4d.astype(np.float32), amask_arr_3d.astype(np.bool) 47 | -------------------------------------------------------------------------------- /lib/test/tracker/lightfc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lib.models import LightFC 4 | from lib.utils.box_ops import clip_box, box_xywh_to_xyxy, box_iou, box_xyxy_to_xywh 5 | from lib.test.utils.hann import hann2d 6 | from lib.test.tracker.basetracker import BaseTracker 7 | from lib.test.tracker.data_utils import Preprocessor 8 | from lib.train.data.processing_utils import sample_target 9 | 10 | 11 | class lightFC(BaseTracker): 12 | def __init__(self, params, dataset_name): 13 | super(lightFC, self).__init__(params) 14 | 15 | network = LightFC(cfg=params.cfg, env_num=None, training=False) 16 | network.load_state_dict(torch.load(self.params.checkpoint, 
map_location='cpu')['net'], strict=True) 17 | 18 | for module in network.backbone.modules(): 19 | if hasattr(module, 'switch_to_deploy'): 20 | module.switch_to_deploy() 21 | for module in network.head.modules(): 22 | if hasattr(module, 'switch_to_deploy'): 23 | module.switch_to_deploy() 24 | 25 | self.cfg = params.cfg 26 | self.network = network.cuda() 27 | self.network.eval() 28 | self.preprocessor = Preprocessor() 29 | self.state = None 30 | 31 | self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE 32 | 33 | # motion constrain 34 | self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda() 35 | 36 | self.frame_id = 0 37 | 38 | def initialize(self, image, info: dict): 39 | H, W, _ = image.shape 40 | 41 | z_patch_arr, resize_factor, z_amask_arr = sample_target(image, info['init_bbox'], self.params.template_factor, 42 | output_sz=self.params.template_size) 43 | 44 | template = self.preprocessor.process(z_patch_arr, z_amask_arr) 45 | 46 | with torch.no_grad(): 47 | self.z_feat = self.network.forward_backbone(template.tensors) 48 | 49 | self.state = info['init_bbox'] 50 | self.frame_id = 0 51 | 52 | def track(self, image, info: dict = None): 53 | H, W, _ = image.shape 54 | self.frame_id += 1 55 | x_patch_arr, resize_factor, x_amask_arr = sample_target(image, self.state, self.params.search_factor, 56 | output_sz=self.params.search_size) # (x1, y1, w, h) 57 | 58 | search = self.preprocessor.process(x_patch_arr, x_amask_arr) 59 | 60 | with torch.no_grad(): 61 | x_dict = search 62 | out_dict = self.network.forward_tracking(z_feat=self.z_feat, x=x_dict.tensors) 63 | 64 | response_origin = self.output_window * out_dict['score_map'] 65 | 66 | pred_box_origin = self.compute_box(response_origin, out_dict, 67 | resize_factor).tolist() # .unsqueeze(dim=0) # tolist() 68 | 69 | self.state = clip_box(self.map_box_back(pred_box_origin, resize_factor), H, W, margin=2) 70 | 71 | return {"target_bbox": self.state} 72 | 73 | def compute_box(self, response, out_dict, resize_factor): 74 | pred_boxes = self.network.head.cal_bbox(response, out_dict['size_map'], out_dict['offset_map']) 75 | pred_boxes = pred_boxes.view(-1, 4) 76 | pred_boxes = (pred_boxes.mean(dim=0) * self.params.search_size / resize_factor) 77 | return pred_boxes 78 | 79 | def map_box_back(self, pred_box: list, resize_factor: float): 80 | cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3] 81 | cx, cy, w, h = pred_box 82 | half_side = 0.5 * self.params.search_size / resize_factor 83 | cx_real = cx + (cx_prev - half_side) 84 | cy_real = cy + (cy_prev - half_side) 85 | return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h] 86 | 87 | def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float): 88 | cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3] 89 | cx, cy, w, h = pred_box.unbind(-1) # (N,4) --> (N,) 90 | half_side = 0.5 * self.params.search_size / resize_factor 91 | cx_real = cx + (cx_prev - half_side) 92 | cy_real = cy + (cy_prev - half_side) 93 | return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1) 94 | 95 | 96 | def get_tracker_class(): 97 | return lightFC 98 | -------------------------------------------------------------------------------- /lib/test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .params import TrackerParams, FeatureParams, Choice 
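For reference, the lightFC tracker defined above in lib/test/tracker/lightfc.py can be driven directly, outside the benchmark runners under tracking/. The snippet below is a minimal illustrative sketch rather than repository code: the frame paths, the initial box and the dataset_name value are placeholders, and it assumes a configured lib/test/evaluation/local.py, a downloaded checkpoint and a CUDA device.

```
# Hypothetical driver script, not part of the repository
import cv2
from lib.test.parameter.lightfc import parameters
from lib.test.tracker.lightfc import lightFC

params = parameters('mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou', env_num=0)
tracker = lightFC(params, dataset_name='otb')  # dataset_name is unused by lightFC itself

frames = ['0001.jpg', '0002.jpg', '0003.jpg']  # placeholder frame paths
first = cv2.cvtColor(cv2.imread(frames[0]), cv2.COLOR_BGR2RGB)
tracker.initialize(first, {'init_bbox': [100, 80, 50, 60]})  # [x, y, w, h] in pixels

for path in frames[1:]:
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    print(tracker.track(img)['target_bbox'])
```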
-------------------------------------------------------------------------------- /lib/test/utils/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '../..', '..', '..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /lib/test/utils/deploy.py: -------------------------------------------------------------------------------- 1 | import os 2 | from lib.test.evaluation.environment import env_settings 3 | 4 | 5 | def get_onnx_save_name(params): 6 | save_dir = env_settings(params.env_num).save_dir 7 | onnx_save_dir = os.path.join(save_dir, "checkpoints/train/lightfc/%s" % (params.tracker_param)) 8 | 9 | backbone_save_name = os.path.join(onnx_save_dir, 'deploy_lightTrack_ep%04d_backbone.onnx' % (params.cfg.TEST.EPOCH)) 10 | network_save_name = os.path.join(onnx_save_dir, 'deploy_lightTrack_ep%04d_network.onnx' % (params.cfg.TEST.EPOCH)) 11 | return {'backbone': backbone_save_name, 'network': network_save_name} 12 | 13 | 14 | def to_numpy(tensor): 15 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() -------------------------------------------------------------------------------- /lib/test/utils/hann.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn.functional as F 4 | 5 | 6 | def hann1d(sz: int, centered = True) -> torch.Tensor: 7 | """1D cosine window.""" 8 | if centered: 9 | return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float())) 10 | w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float())) 11 | return torch.cat([w, w[1:sz-sz//2].flip((0,))]) 12 | 13 | 14 | def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor: 15 | """2D cosine window.""" 16 | return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1) 17 | 18 | 19 | def hann2d_bias(sz: torch.Tensor, ctr_point: torch.Tensor, centered = True) -> torch.Tensor: 20 | """2D cosine window.""" 21 | distance = torch.stack([ctr_point, sz-ctr_point], dim=0) 22 | max_distance, _ = distance.max(dim=0) 23 | 24 | hann1d_x = hann1d(max_distance[0].item() * 2, centered) 25 | hann1d_x = hann1d_x[max_distance[0] - distance[0, 0]: max_distance[0] + distance[1, 0]] 26 | hann1d_y = hann1d(max_distance[1].item() * 2, centered) 27 | hann1d_y = hann1d_y[max_distance[1] - distance[0, 1]: max_distance[1] + distance[1, 1]] 28 | 29 | return hann1d_y.reshape(1, 1, -1, 1) * hann1d_x.reshape(1, 1, 1, -1) 30 | 31 | 32 | 33 | def hann2d_clipped(sz: torch.Tensor, effective_sz: torch.Tensor, centered = True) -> torch.Tensor: 34 | """1D clipped cosine window.""" 35 | 36 | # Ensure that the difference is even 37 | effective_sz += (effective_sz - sz) % 2 38 | effective_window = hann1d(effective_sz[0].item(), True).reshape(1, 1, -1, 1) * hann1d(effective_sz[1].item(), True).reshape(1, 1, 1, -1) 39 | 40 | pad = (sz - effective_sz) // 2 41 | 42 | window = F.pad(effective_window, (pad[1].item(), pad[1].item(), pad[0].item(), pad[0].item()), 'replicate') 43 | 44 | if centered: 45 | return 
window 46 | else: 47 | mid = (sz / 2).int() 48 | window_shift_lr = torch.cat((window[:, :, :, mid[1]:], window[:, :, :, :mid[1]]), 3) 49 | return torch.cat((window_shift_lr[:, :, mid[0]:, :], window_shift_lr[:, :, :mid[0], :]), 2) 50 | 51 | 52 | def gauss_fourier(sz: int, sigma: float, half: bool = False) -> torch.Tensor: 53 | if half: 54 | k = torch.arange(0, int(sz/2+1)) 55 | else: 56 | k = torch.arange(-int((sz-1)/2), int(sz/2+1)) 57 | return (math.sqrt(2*math.pi) * sigma / sz) * torch.exp(-2 * (math.pi * sigma * k.float() / sz)**2) 58 | 59 | 60 | def gauss_spatial(sz, sigma, center=0, end_pad=0): 61 | k = torch.arange(-(sz-1)/2, (sz+1)/2+end_pad) 62 | return torch.exp(-1.0/(2*sigma**2) * (k - center)**2) 63 | 64 | 65 | def label_function(sz: torch.Tensor, sigma: torch.Tensor): 66 | return gauss_fourier(sz[0].item(), sigma[0].item()).reshape(1, 1, -1, 1) * gauss_fourier(sz[1].item(), sigma[1].item(), True).reshape(1, 1, 1, -1) 67 | 68 | def label_function_spatial(sz: torch.Tensor, sigma: torch.Tensor, center: torch.Tensor = torch.zeros(2), end_pad: torch.Tensor = torch.zeros(2)): 69 | """The origin is in the middle of the image.""" 70 | return gauss_spatial(sz[0].item(), sigma[0].item(), center[0], end_pad[0].item()).reshape(1, 1, -1, 1) * \ 71 | gauss_spatial(sz[1].item(), sigma[1].item(), center[1], end_pad[1].item()).reshape(1, 1, 1, -1) 72 | 73 | 74 | def cubic_spline_fourier(f, a): 75 | """The continuous Fourier transform of a cubic spline kernel.""" 76 | 77 | bf = (6*(1 - torch.cos(2 * math.pi * f)) + 3*a*(1 - torch.cos(4 * math.pi * f)) 78 | - (6 + 8*a)*math.pi*f*torch.sin(2 * math.pi * f) - 2*a*math.pi*f*torch.sin(4 * math.pi * f)) \ 79 | / (4 * math.pi**4 * f**4) 80 | 81 | bf[f == 0] = 1 82 | 83 | return bf 84 | 85 | def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor): 86 | """Computes maximum and argmax in the last two dimensions.""" 87 | 88 | max_val_row, argmax_row = torch.max(a, dim=-2) 89 | max_val, argmax_col = torch.max(max_val_row, dim=-1) 90 | argmax_row = argmax_row.view(argmax_col.numel(),-1)[torch.arange(argmax_col.numel()), argmax_col.view(-1)] 91 | argmax_row = argmax_row.reshape(argmax_col.shape) 92 | argmax = torch.cat((argmax_row.unsqueeze(-1), argmax_col.unsqueeze(-1)), -1) 93 | return max_val, argmax 94 | -------------------------------------------------------------------------------- /lib/test/utils/load_text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def load_text_numpy(path, delimiter, dtype): 6 | if isinstance(delimiter, (tuple, list)): 7 | for d in delimiter: 8 | try: 9 | ground_truth_rect = np.loadtxt(path, delimiter=d, dtype=dtype) 10 | return ground_truth_rect 11 | except: 12 | pass 13 | 14 | raise Exception('Could not read file {}'.format(path)) 15 | else: 16 | ground_truth_rect = np.loadtxt(path, delimiter=delimiter, dtype=dtype) 17 | return ground_truth_rect 18 | 19 | 20 | def load_text_pandas(path, delimiter, dtype): 21 | if isinstance(delimiter, (tuple, list)): 22 | for d in delimiter: 23 | try: 24 | ground_truth_rect = pd.read_csv(path, delimiter=d, header=None, dtype=dtype, na_filter=False, 25 | low_memory=False).values 26 | return ground_truth_rect 27 | except Exception as e: 28 | pass 29 | 30 | raise Exception('Could not read file {}'.format(path)) 31 | else: 32 | ground_truth_rect = pd.read_csv(path, delimiter=delimiter, header=None, dtype=dtype, na_filter=False, 33 | low_memory=False).values 34 | return ground_truth_rect 35 | 36 | 37 
| def load_text(path, delimiter=' ', dtype=np.float32, backend='numpy'): 38 | if backend == 'numpy': 39 | return load_text_numpy(path, delimiter, dtype) 40 | elif backend == 'pandas': 41 | return load_text_pandas(path, delimiter, dtype) 42 | 43 | 44 | def load_str(path): 45 | with open(path, "r") as f: 46 | text_str = f.readline().strip().lower() 47 | return text_str 48 | -------------------------------------------------------------------------------- /lib/test/utils/params.py: -------------------------------------------------------------------------------- 1 | from lib.utils import TensorList 2 | import random 3 | 4 | 5 | class TrackerParams: 6 | """Class for tracker parameters.""" 7 | def set_default_values(self, default_vals: dict): 8 | for name, val in default_vals.items(): 9 | if not hasattr(self, name): 10 | setattr(self, name, val) 11 | 12 | def get(self, name: str, *default): 13 | """Get a parameter value with the given name. If it does not exists, it return the default value given as a 14 | second argument or returns an error if no default value is given.""" 15 | if len(default) > 1: 16 | raise ValueError('Can only give one default value.') 17 | 18 | if not default: 19 | return getattr(self, name) 20 | 21 | return getattr(self, name, default[0]) 22 | 23 | def has(self, name: str): 24 | """Check if there exist a parameter with the given name.""" 25 | return hasattr(self, name) 26 | 27 | 28 | class FeatureParams: 29 | """Class for feature specific parameters""" 30 | def __init__(self, *args, **kwargs): 31 | if len(args) > 0: 32 | raise ValueError 33 | 34 | for name, val in kwargs.items(): 35 | if isinstance(val, list): 36 | setattr(self, name, TensorList(val)) 37 | else: 38 | setattr(self, name, val) 39 | 40 | 41 | def Choice(*args): 42 | """Can be used to sample random parameter values.""" 43 | return random.choice(args) 44 | -------------------------------------------------------------------------------- /lib/test/utils/transform_got10k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import argparse 5 | import _init_paths 6 | from lib.test.evaluation.environment import env_settings 7 | 8 | 9 | def transform_got10k(tracker_name, cfg_name): 10 | env = env_settings(env_num=0) 11 | result_dir = env.results_path 12 | src_dir = os.path.join(result_dir, "%s/%s/got10k/" % (tracker_name, cfg_name)) 13 | dest_dir = os.path.join(result_dir, "%s/%s/got10k_submit/" % (tracker_name, cfg_name)) 14 | if not os.path.exists(dest_dir): 15 | os.makedirs(dest_dir) 16 | items = os.listdir(src_dir) 17 | for item in items: 18 | if "all" in item: 19 | continue 20 | src_path = os.path.join(src_dir, item) 21 | if "time" not in item: 22 | seq_name = item.replace(".txt", '') 23 | seq_dir = os.path.join(dest_dir, seq_name) 24 | if not os.path.exists(seq_dir): 25 | os.makedirs(seq_dir) 26 | new_item = item.replace(".txt", '_001.txt') 27 | dest_path = os.path.join(seq_dir, new_item) 28 | bbox_arr = np.loadtxt(src_path, dtype=np.int64, delimiter='\t') 29 | np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',') 30 | else: 31 | seq_name = item.replace("_time.txt", '') 32 | seq_dir = os.path.join(dest_dir, seq_name) 33 | if not os.path.exists(seq_dir): 34 | os.makedirs(seq_dir) 35 | dest_path = os.path.join(seq_dir, item) 36 | os.system("cp %s %s" % (src_path, dest_path)) 37 | # make zip archive 38 | shutil.make_archive(src_dir, "zip", src_dir) 39 | shutil.make_archive(dest_dir, "zip", dest_dir) 40 | # Remove the 
original files 41 | shutil.rmtree(src_dir) 42 | shutil.rmtree(dest_dir) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser(description='transform got10k results.') 47 | parser.add_argument('--tracker_name', type=str, help='Name of tracking method.') 48 | parser.add_argument('--cfg_name', type=str, help='Name of config file.') 49 | 50 | args = parser.parse_args() 51 | transform_got10k('lightfc', 52 | 'mobilenetv2_p_pwcorr_se_ffn_next_concat_repn_33_se_center_concat_adamw_wiou') 53 | -------------------------------------------------------------------------------- /lib/test/utils/transform_trackingnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import argparse 5 | import _init_paths 6 | from lib.test.evaluation.environment import env_settings 7 | 8 | 9 | def transform_trackingnet(tracker_name, cfg_name): 10 | env = env_settings(env_num=0) 11 | result_dir = env.results_path 12 | src_dir = os.path.join(result_dir, "%s/%s/trackingnet/" % (tracker_name, cfg_name)) 13 | dest_dir = os.path.join(result_dir, "%s/%s/trackingnet_submit/" % (tracker_name, cfg_name)) 14 | if not os.path.exists(dest_dir): 15 | os.makedirs(dest_dir) 16 | items = os.listdir(src_dir) 17 | for item in items: 18 | if "all" in item: 19 | continue 20 | if "time" not in item: 21 | src_path = os.path.join(src_dir, item) 22 | dest_path = os.path.join(dest_dir, item) 23 | try: 24 | bbox_arr = np.loadtxt(src_path, dtype=np.int64, delimiter='\t') 25 | except: 26 | print(src_path) 27 | np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',') 28 | # make zip archive 29 | shutil.make_archive(src_dir, "zip", src_dir) 30 | shutil.make_archive(dest_dir, "zip", dest_dir) 31 | # Remove the original files 32 | shutil.rmtree(src_dir) 33 | shutil.rmtree(dest_dir) 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description='transform trackingnet results.') 38 | parser.add_argument('--tracker_name', type=str, help='Name of tracking method.') 39 | parser.add_argument('--cfg_name', type=str, help='Name of config file.') 40 | 41 | args = parser.parse_args() 42 | transform_trackingnet('lightfc', 'mobilenetv2_p_pwcorr_se_ffn_next_concat_repn_33_se_center_concat_adamw_wiou') 43 | -------------------------------------------------------------------------------- /lib/test/vot_utils/lightfc.py: -------------------------------------------------------------------------------- 1 | from lib.test.vot_utils.lightfc_vot import run_vot_exp 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | run_vot_exp('lightfc', 'mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou', vis=False) -------------------------------------------------------------------------------- /lib/test/vot_utils/lightfc_vot.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import cv2 7 | import torch 8 | import vot 9 | import sys 10 | import time 11 | import os 12 | from lib.test.evaluation import Tracker 13 | from lib.test.vot_utils.vot import VOT 14 | from lib.test.vot_utils.utils import * 15 | 16 | '''lightfc class''' 17 | 18 | 19 | class lightfc_vot20(object): 20 | def __init__(self, tracker_name='lightfc', 21 |
para_name='mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou'): 22 | # create tracker 23 | tracker_info = Tracker(tracker_name, para_name, "vot20", None, env_num=0) 24 | params = tracker_info.get_parameters() 25 | params.visualization = False 26 | params.debug = False 27 | self.tracker = tracker_info.create_tracker(params) 28 | 29 | def initialize(self, img_rgb, mask): 30 | # VOT20 31 | # init on the 1st frame 32 | region = rect_from_mask(mask) 33 | self.H, self.W, _ = img_rgb.shape 34 | init_info = {'init_bbox': region} 35 | _ = self.tracker.initialize(img_rgb, init_info) 36 | 37 | 38 | def track(self, img_rgb): 39 | # track 40 | outputs = self.tracker.track(img_rgb) 41 | pred_bbox = outputs['target_bbox'] 42 | final_mask = mask_from_rect(pred_bbox, (self.W, self.H)) 43 | return pred_bbox, final_mask 44 | 45 | 46 | def run_vot_exp(tracker_name, para_name, vis=False): 47 | torch.set_num_threads(1) 48 | save_root = os.path.join('/data/sda/v-yanbi/iccv21/LittleBoy/vot20_debug', para_name) 49 | if vis and (not os.path.exists(save_root)): 50 | os.mkdir(save_root) 51 | tracker = lightfc_vot20(tracker_name=tracker_name, para_name=para_name) 52 | handle = VOT("mask") 53 | selection = handle.region() 54 | imagefile = handle.frame() 55 | if not imagefile: 56 | sys.exit(0) 57 | if vis: 58 | '''for vis''' 59 | seq_name = imagefile.split('/')[-3] 60 | save_v_dir = os.path.join(save_root, seq_name) 61 | if not os.path.exists(save_v_dir): 62 | os.mkdir(save_v_dir) 63 | cur_time = int(time.time() % 10000) 64 | save_dir = os.path.join(save_v_dir, str(cur_time)) 65 | if not os.path.exists(save_dir): 66 | os.makedirs(save_dir) 67 | 68 | image = cv2.cvtColor(cv2.imread(imagefile), cv2.COLOR_BGR2RGB) # Right 69 | # mask given by the toolkit ends with the target (zero-padding to the right and down is needed) 70 | # mask = make_full_size(selection, (image.shape[1], image.shape[0])) 71 | tracker.initialize(image, selection) 72 | 73 | while True: 74 | imagefile = handle.frame() 75 | if not imagefile: 76 | break 77 | image = cv2.cvtColor(cv2.imread(imagefile), cv2.COLOR_BGR2RGB) # Right 78 | b1, m = tracker.track(image) 79 | handle.report(m) 80 | if vis: 81 | '''Visualization''' 82 | # original image 83 | image_ori = image[:, :, ::-1].copy() # RGB --> BGR 84 | image_name = imagefile.split('/')[-1] 85 | save_path = os.path.join(save_dir, image_name) 86 | cv2.imwrite(save_path, image_ori) 87 | # tracker box 88 | image_b = image_ori.copy() 89 | cv2.rectangle(image_b, (int(b1[0]), int(b1[1])), 90 | (int(b1[0] + b1[2]), int(b1[1] + b1[3])), (0, 0, 255), 2) 91 | image_b_name = image_name.replace('.jpg', '_bbox.jpg') 92 | save_path = os.path.join(save_dir, image_b_name) 93 | cv2.imwrite(save_path, image_b) 94 | # original image + mask 95 | image_m = image_ori.copy().astype(np.float32) 96 | image_m[:, :, 1] += 127.0 * m 97 | image_m[:, :, 2] += 127.0 * m 98 | contours, _ = cv2.findContours(m, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 99 | image_m = cv2.drawContours(image_m, contours, -1, (0, 255, 255), 2) 100 | image_m = image_m.clip(0, 255).astype(np.uint8) 101 | image_mask_name_m = image_name.replace('.jpg', '_mask.jpg') 102 | save_path = os.path.join(save_dir, image_mask_name_m) 103 | cv2.imwrite(save_path, image_m) 104 | -------------------------------------------------------------------------------- /lib/test/vot_utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def make_full_size(x, output_sz): 5 | """ 6 | 
zero-pad input x (right and down) to match output_sz 7 | x: numpy array e.g., binary mask 8 | output_sz: size of the output [width, height] 9 | """ 10 | if x.shape[0] == output_sz[1] and x.shape[1] == output_sz[0]: 11 | return x 12 | pad_x = output_sz[0] - x.shape[1] 13 | if pad_x < 0: 14 | x = x[:, :x.shape[1] + pad_x] 15 | # padding has to be set to zero, otherwise pad function fails 16 | pad_x = 0 17 | pad_y = output_sz[1] - x.shape[0] 18 | if pad_y < 0: 19 | x = x[:x.shape[0] + pad_y, :] 20 | # padding has to be set to zero, otherwise pad function fails 21 | pad_y = 0 22 | return np.pad(x, ((0, pad_y), (0, pad_x)), 'constant', constant_values=0) 23 | 24 | 25 | def rect_from_mask(mask): 26 | """ 27 | create an axis-aligned rectangle from a given binary mask 28 | mask in created as a minimal rectangle containing all non-zero pixels 29 | """ 30 | x_ = np.sum(mask, axis=0) 31 | y_ = np.sum(mask, axis=1) 32 | x0 = np.min(np.nonzero(x_)) 33 | x1 = np.max(np.nonzero(x_)) 34 | y0 = np.min(np.nonzero(y_)) 35 | y1 = np.max(np.nonzero(y_)) 36 | return [x0, y0, x1 - x0 + 1, y1 - y0 + 1] 37 | 38 | 39 | def mask_from_rect(rect, output_sz): 40 | """ 41 | create a binary mask from a given rectangle 42 | rect: axis-aligned rectangle [x0, y0, width, height] 43 | output_sz: size of the output [width, height] 44 | """ 45 | mask = np.zeros((output_sz[1], output_sz[0]), dtype=np.uint8) 46 | x0 = max(int(round(rect[0])), 0) 47 | y0 = max(int(round(rect[1])), 0) 48 | x1 = min(int(round(rect[0] + rect[2])), output_sz[0]) 49 | y1 = min(int(round(rect[1] + rect[3])), output_sz[1]) 50 | mask[y0:y1, x0:x1] = 1 51 | return mask 52 | 53 | 54 | def bbox_clip(x1, y1, x2, y2, boundary, min_sz=10): 55 | """boundary (H,W)""" 56 | x1_new = max(0, min(x1, boundary[1] - min_sz)) 57 | y1_new = max(0, min(y1, boundary[0] - min_sz)) 58 | x2_new = max(min_sz, min(x2, boundary[1])) 59 | y2_new = max(min_sz, min(y2, boundary[0])) 60 | return x1_new, y1_new, x2_new, y2_new -------------------------------------------------------------------------------- /lib/test/vot_utils/vot.py: -------------------------------------------------------------------------------- 1 | """ 2 | \file vot.py 3 | @brief Python utility functions for VOT integration 4 | @author Luka Cehovin, Alessio Dore 5 | @date 2016 6 | """ 7 | 8 | import sys 9 | import copy 10 | import collections 11 | import numpy as np 12 | 13 | try: 14 | import trax 15 | except ImportError: 16 | raise Exception('TraX support not found. 
Please add trax module to Python path.') 17 | 18 | Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height']) 19 | Point = collections.namedtuple('Point', ['x', 'y']) 20 | Polygon = collections.namedtuple('Polygon', ['points']) 21 | Empty = collections.namedtuple('Empty', []) 22 | 23 | class VOT(object): 24 | """ Base class for Python VOT integration """ 25 | 26 | def __init__(self, region_format, channels=None): 27 | """ Constructor 28 | Args: 29 | region_format: Region format options 30 | """ 31 | assert (region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON, trax.Region.MASK]) 32 | 33 | if channels is None: 34 | channels = ['color'] 35 | elif channels == 'rgbd': 36 | channels = ['color', 'depth'] 37 | elif channels == 'rgbt': 38 | channels = ['color', 'ir'] 39 | elif channels == 'ir': 40 | channels = ['ir'] 41 | else: 42 | raise Exception('Illegal configuration {}.'.format(channels)) 43 | 44 | # self._trax = trax.Server([region_format], [trax.Image.PATH], channels, customMetadata=dict(vot="python")) 45 | self._trax = trax.Server([region_format], [trax.Image.PATH], channels, metadata=dict(vot="python")) 46 | request = self._trax.wait() 47 | assert (request.type == 'initialize') 48 | 49 | for object, _ in request.objects: 50 | if isinstance(object, trax.Polygon): 51 | self._region = Polygon([Point(x[0], x[1]) for x in object]) 52 | elif isinstance(object, trax.Mask): 53 | self._region = object.array(True) 54 | else: 55 | self._region = Rectangle(*object.bounds()) 56 | 57 | self._image = [x.path() for k, x in request.image.items()] 58 | if len(self._image) == 1: 59 | self._image = self._image[0] 60 | self._trax.status(request.objects) 61 | 62 | def region(self): 63 | """ 64 | Send configuration message to the client and receive the initialization 65 | region and the path of the first image 66 | Returns: 67 | initialization region 68 | """ 69 | 70 | return self._region 71 | 72 | def report(self, region, confidence=None): 73 | """ 74 | Report the tracking results to the client 75 | Arguments: 76 | region: region for the frame 77 | """ 78 | def convert(a): 79 | """ Convert region to TraX format """ 80 | # If region is None, return empty region 81 | if region is None: return trax.Rectangle.create(0, 0, 0, 0) 82 | assert isinstance(region, (Empty, Rectangle, Polygon, np.ndarray)) 83 | if isinstance(region, Empty): 84 | return trax.Rectangle.create(0, 0, 0, 0) 85 | elif isinstance(region, Polygon): 86 | return trax.Polygon.create([(x.x, x.y) for x in region.points]) 87 | elif isinstance(region, np.ndarray): 88 | return trax.Mask.create(region) 89 | else: 90 | return trax.Rectangle.create(region.x, region.y, region.width, region.height) 91 | 92 | 93 | properties = {} 94 | if not confidence is None: 95 | properties['confidence'] = confidence 96 | status = [(convert(region), properties)] 97 | 98 | self._trax.status(status, {}) 99 | 100 | def frame(self): 101 | """ 102 | Get a frame (image path) from client 103 | Returns: 104 | absolute path of the image 105 | """ 106 | if hasattr(self, "_image"): 107 | image = self._image 108 | del self._image 109 | return image 110 | 111 | request = self._trax.wait() 112 | 113 | if request.type == 'frame': 114 | image = [x.path() for k, x in request.image.items()] 115 | if len(image) == 1: 116 | return image[0] 117 | return image 118 | else: 119 | return None 120 | 121 | def quit(self): 122 | if hasattr(self, '_trax'): 123 | self._trax.quit() 124 | 125 | def __del__(self): 126 | self.quit() 127 | 
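Because the VOT toolkit integration above exchanges segmentation masks while lightFC itself predicts axis-aligned boxes, rect_from_mask and mask_from_rect in lib/test/vot_utils/utils.py translate between the two representations. A small self-contained sketch of that round trip, with a toy mask invented purely for illustration:

```
# Toy example, not part of the repository
import numpy as np
from lib.test.vot_utils.utils import rect_from_mask, mask_from_rect

mask = np.zeros((120, 160), dtype=np.uint8)   # binary mask, shape (H, W)
mask[40:80, 30:90] = 1                        # a 60x40 target region

rect = rect_from_mask(mask)                   # -> [30, 40, 60, 40] as [x, y, w, h]
restored = mask_from_rect(rect, (160, 120))   # output_sz is given as (width, height)
assert (restored == mask).all()
```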
-------------------------------------------------------------------------------- /lib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .admin.multigpu import MultiGPU -------------------------------------------------------------------------------- /lib/train/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, path) 12 | 13 | 14 | this_dir = osp.dirname(__file__) 15 | 16 | prj_path = osp.join(this_dir, '../..') 17 | add_path(prj_path) 18 | -------------------------------------------------------------------------------- /lib/train/actors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_actor import BaseActor 2 | from .lightfc import lightTrackActor 3 | from .lightfc_st import lightTrackSTActor 4 | -------------------------------------------------------------------------------- /lib/train/actors/base_actor.py: -------------------------------------------------------------------------------- 1 | from lib.utils import TensorDict 2 | 3 | 4 | class BaseActor: 5 | """ Base class for actor. The actor class handles the passing of the data through the network 6 | and calculation the loss""" 7 | def __init__(self, net, objective): 8 | """ 9 | args: 10 | net - The network to train 11 | objective - The loss function 12 | """ 13 | self.net = net 14 | self.objective = objective 15 | 16 | def __call__(self, data: TensorDict): 17 | """ Called in each training iteration. Should pass in input data through the network, calculate the loss, and 18 | return the training stats for the input data 19 | args: 20 | data - A TensorDict containing all the necessary data blocks. 21 | 22 | returns: 23 | loss - loss for the input data 24 | stats - a dict containing detailed losses 25 | """ 26 | raise NotImplementedError 27 | 28 | def to(self, device): 29 | """ Move the network to device 30 | args: 31 | device - device to use. 'cpu' or 'cuda' 32 | """ 33 | self.net.to(device) 34 | 35 | def train(self, mode=True): 36 | """ Set whether the network is in train mode. 37 | args: 38 | mode (True) - Bool specifying whether in training mode. 39 | """ 40 | self.net.train(mode) 41 | 42 | def eval(self): 43 | """ Set network to eval mode""" 44 | self.train(False) -------------------------------------------------------------------------------- /lib/train/actors/lightfc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import BaseActor 4 | from ..loss.cos_sim_loss import cosine_similarity_loss 5 | from ...utils.box_ops import box_xywh_to_xyxy, box_cxcywh_to_xyxy 6 | from ...utils.heapmap_utils import generate_heatmap 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class lightTrackActor(BaseActor): 11 | def __init__(self, net, objective, loss_weight, settings, cfg=None): 12 | super().__init__(net, objective) 13 | self.loss_weight = loss_weight 14 | self.settings = settings 15 | self.bs = self.settings.batchsize # batch size 16 | self.cfg = cfg 17 | 18 | # triple loss 19 | # self.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1)) 20 | # self.triple = nn.TripletMarginLoss(margin=1, p=2, reduction='mean') 21 | 22 | # self.transform = transforms.RandomErasing(p=0.05, scale=(0.02, 0.4), ratio=(0.3, 3.3), value=0, inplace=False) 23 | 24 | def __call__(self, data): 25 | 26 | out_dict = self.forward_pass(data) 27 | 28 | loss, status = self.compute_losses(out_dict, data) 29 | return loss, status 30 | 31 | def forward_pass(self, data): 32 | template_list = [] 33 | for i in range(self.settings.num_template): 34 | template_img_i = data['template_images'][:, i, :].view(-1, *data['template_images'].shape[2:]) 35 | template_list.append(template_img_i) 36 | 37 | search_img = data['search_images'][:, 0, :].view(-1, *data['search_images'].shape[2:]) 38 | # search_img = self.transform(search_img) 39 | 40 | if len(template_list) == 1: 41 | template_list = template_list[0] 42 | 43 | out_dict = self.net(z=template_list, x=search_img) 44 | return out_dict 45 | 46 | def compute_losses(self, pred_dict, gt_dict, return_status=True): 47 | bs, n, _ = gt_dict['search_anno'].shape 48 | gt_bbox = gt_dict['search_anno'].view(bs, 4) 49 | 50 | gt_gaussian_maps = generate_heatmap(gt_dict['search_anno'].view(n, bs, 4), self.cfg.DATA.SEARCH.SIZE, 51 | self.cfg.MODEL.BACKBONE.STRIDE) 52 | gt_gaussian_maps_flatten = gt_gaussian_maps[-1].unsqueeze(1) 53 | 54 | pred_boxes = pred_dict['pred_boxes'] 55 | if torch.isnan(pred_boxes).any(): 56 | raise ValueError("Network outputs is NAN! 
Stop Training") 57 | 58 | pred_boxes_vec = box_cxcywh_to_xyxy(pred_boxes).view(-1, 4) # (B,N,4) --> (BN,4) (x1,y1,x2,y2) 59 | gt_boxes_vec = box_xywh_to_xyxy(gt_bbox).view(-1, 4).clamp(min=0.0, max=1.0) # (B,4) --> (B,1,4) --> (B,N,4) 60 | 61 | # locate box 62 | try: 63 | iou_loss, iou = self.objective.iou(pred_boxes_vec, gt_boxes_vec) # (BN,4) (BN,4) 64 | except: 65 | iou_loss, iou = torch.tensor(0.0).cuda(), torch.tensor(0.0).cuda() 66 | 67 | # l1 loss 68 | l1_loss = self.objective.l1(pred_boxes_vec, gt_boxes_vec) # (BN,4) (BN,4) 69 | 70 | if 'score_map' in pred_dict: 71 | location_loss = self.objective.focal_loss(pred_dict['score_map'], gt_gaussian_maps_flatten) 72 | else: 73 | location_loss = torch.tensor(0.0, device=l1_loss.device) 74 | 75 | 76 | 77 | # weighted sum 78 | loss = self.loss_weight['iou'] * iou_loss + self.loss_weight['l1'] * l1_loss + self.loss_weight[ 79 | 'focal'] * location_loss # + compute_tri_loss * 0.05 80 | # * location_loss # cos_sim_loss * 0.1 81 | 82 | # return 83 | if return_status: 84 | # status for log 85 | mean_iou = iou.detach().mean() 86 | status = {"Loss/total": loss.item(), 87 | "Loss/giou": iou_loss.item(), 88 | "Loss/l1": l1_loss.item(), 89 | "Loss/location": location_loss.item(), 90 | # "Loss/cossim": cos_sim_loss.item(), 91 | # "Loss/triple": compute_tri_loss.item(), 92 | "mean_IoU": mean_iou.item(), 93 | } 94 | return loss, status 95 | else: 96 | return loss 97 | -------------------------------------------------------------------------------- /lib/train/admin/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import env_settings, create_default_local_file_ITP_train 2 | from .stats import AverageMeter, StatValue 3 | from .tensorboard import TensorboardWriter 4 | -------------------------------------------------------------------------------- /lib/train/admin/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from collections import OrderedDict 4 | 5 | 6 | def create_default_local_file(): 7 | path = os.path.join(os.path.dirname(__file__), 'local.py') 8 | 9 | empty_str = '\'\'' 10 | default_settings = OrderedDict({ 11 | 'workspace_dir': empty_str, 12 | 'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'', 13 | 'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'', 14 | 'lasot_dir': empty_str, 15 | 'got10k_dir': empty_str, 16 | 'trackingnet_dir': empty_str, 17 | 'coco_dir': empty_str, 18 | 'lvis_dir': empty_str, 19 | 'sbd_dir': empty_str, 20 | 'imagenet_dir': empty_str, 21 | 'imagenetdet_dir': empty_str, 22 | 'ecssd_dir': empty_str, 23 | 'hkuis_dir': empty_str, 24 | 'msra10k_dir': empty_str, 25 | 'davis_dir': empty_str, 26 | 'youtubevos_dir': empty_str}) 27 | 28 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 29 | 'tensorboard_dir': 'Directory for tensorboard files.'} 30 | 31 | with open(path, 'w') as f: 32 | f.write('class EnvironmentSettings:\n') 33 | f.write(' def __init__(self):\n') 34 | 35 | for attr, attr_val in default_settings.items(): 36 | comment_str = None 37 | if attr in comment: 38 | comment_str = comment[attr] 39 | if comment_str is None: 40 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 41 | else: 42 | f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str)) 43 | 44 | 45 | def create_default_local_file_ITP_train(workspace_dir, data_dir): 46 | path = os.path.join(os.path.dirname(__file__), 'local.py') 47 | 48 | 
empty_str = '\'\'' 49 | default_settings = OrderedDict({ 50 | 'workspace_dir': workspace_dir, 51 | 'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files. 52 | 'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'), 53 | 'lasot_dir': os.path.join(data_dir, 'lasot'), 54 | 'got10k_dir': os.path.join(data_dir, 'got10k/train'), 55 | 'got10k_val_dir': os.path.join(data_dir, 'got10k/val'), 56 | 'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'), 57 | 'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'), 58 | 'trackingnet_dir': os.path.join(data_dir, 'trackingnet'), 59 | 'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'), 60 | 'coco_dir': os.path.join(data_dir, 'coco'), 61 | 'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'), 62 | 'lvis_dir': empty_str, 63 | 'sbd_dir': empty_str, 64 | 'imagenet_dir': os.path.join(data_dir, 'vid'), 65 | 'imagenet_lmdb_dir': os.path.join(data_dir, 'vid_lmdb'), 66 | 'imagenetdet_dir': empty_str, 67 | 'ecssd_dir': empty_str, 68 | 'hkuis_dir': empty_str, 69 | 'msra10k_dir': empty_str, 70 | 'davis_dir': empty_str, 71 | 'youtubevos_dir': empty_str}) 72 | 73 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 74 | 'tensorboard_dir': 'Directory for tensorboard files.'} 75 | 76 | with open(path, 'w') as f: 77 | f.write('class EnvironmentSettings:\n') 78 | f.write(' def __init__(self):\n') 79 | 80 | for attr, attr_val in default_settings.items(): 81 | comment_str = None 82 | if attr in comment: 83 | comment_str = comment[attr] 84 | if comment_str is None: 85 | if attr_val == empty_str: 86 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 87 | else: 88 | f.write(' self.{} = \'{}\'\n'.format(attr, attr_val)) 89 | else: 90 | f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 91 | 92 | 93 | def env_settings(env_num): 94 | env_module_name = 'lib.train.admin.local' 95 | try: 96 | env_module = importlib.import_module(env_module_name) 97 | return env_module.EnvironmentSettings(env_num) 98 | except: 99 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 100 | 101 | # create_default_local_file() 102 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file)) 103 | -------------------------------------------------------------------------------- /lib/train/admin/local.py: -------------------------------------------------------------------------------- 1 | class EnvironmentSettings: 2 | def __init__(self, env_num=0): 3 | self.workspace_dir = r'/home/liyunfeng/code/project2/LightFC' # Base directory for saving network checkpoints. 4 | self.tensorboard_dir = r'/home/liyunfeng/code/project2/LightFC/tensorboard' # Directory for tensorboard files. 
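        # The remaining entries locate the pretrained backbone weights and the individual
        # training datasets; paths left as empty strings have not been configured in this
        # environment.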
5 | self.pretrained_networks = r'/home/liyunfeng/code/project2/LightFC/pretrained_models' 6 | 7 | self.lasot_dir = '' 8 | self.got10k_dir = '' 9 | self.got10k_val_dir = '' 10 | self.lasot_lmdb_dir = '' 11 | self.got10k_lmdb_dir = '' 12 | self.trackingnet_dir = '' 13 | self.trackingnet_lmdb_dir = '' 14 | self.coco_dir = '' 15 | self.coco_lmdb_dir = '' 16 | self.lvis_dir = '' 17 | self.sbd_dir = '' 18 | 19 | self.imagenet_dir = '' 20 | self.imagenet_lmdb_dir = '' 21 | self.imagenetdet_dir = '' 22 | self.ecssd_dir = '' 23 | self.hkuis_dir = '' 24 | self.msra10k_dir = '' 25 | self.davis_dir = '' 26 | self.youtubevos_dir = '' 27 | -------------------------------------------------------------------------------- /lib/train/admin/multigpu.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # Here we use DistributedDataParallel(DDP) rather than DataParallel(DP) for multiple GPUs training 3 | 4 | 5 | def is_multi_gpu(net): 6 | return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel)) 7 | 8 | 9 | class MultiGPU(nn.parallel.distributed.DistributedDataParallel): 10 | def __getattr__(self, item): 11 | try: 12 | return super().__getattr__(item) 13 | except: 14 | pass 15 | return getattr(self.module, item) 16 | -------------------------------------------------------------------------------- /lib/train/admin/settings.py: -------------------------------------------------------------------------------- 1 | from lib.train.admin.environment import env_settings 2 | 3 | 4 | class Settings: 5 | """ Training settings, e.g. the paths to datasets and networks.""" 6 | 7 | def __init__(self, env_num): 8 | self.set_default(env_num) 9 | 10 | def set_default(self, env_num): 11 | self.env = env_settings(env_num) 12 | self.use_gpu = True 13 | -------------------------------------------------------------------------------- /lib/train/admin/stats.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class StatValue: 4 | def __init__(self): 5 | self.clear() 6 | 7 | def reset(self): 8 | self.val = 0 9 | 10 | def clear(self): 11 | self.reset() 12 | self.history = [] 13 | 14 | def update(self, val): 15 | self.val = val 16 | self.history.append(self.val) 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | def __init__(self): 22 | self.clear() 23 | self.has_new_data = False 24 | 25 | def reset(self): 26 | self.avg = 0 27 | self.val = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def clear(self): 32 | self.reset() 33 | self.history = [] 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def new_epoch(self): 42 | if self.count > 0: 43 | self.history.append(self.avg) 44 | self.reset() 45 | self.has_new_data = True 46 | else: 47 | self.has_new_data = False 48 | 49 | 50 | def topk_accuracy(output, target, topk=(1,)): 51 | """Computes the precision@k for the specified values of k""" 52 | single_input = not isinstance(topk, (tuple, list)) 53 | if single_input: 54 | topk = (topk,) 55 | 56 | maxk = max(topk) 57 | batch_size = target.size(0) 58 | 59 | _, pred = output.topk(maxk, 1, True, True) 60 | pred = pred.t() 61 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 62 | 63 | res = [] 64 | for k in topk: 65 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0] 66 | res.append(correct_k * 100.0 / batch_size) 67 | 68 | if single_input: 
69 | return res[0] 70 | 71 | return res 72 | -------------------------------------------------------------------------------- /lib/train/admin/tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | try: 4 | from torch.utils.tensorboard import SummaryWriter 5 | except: 6 | print('WARNING: You are using tensorboardX instead since your PyTorch version is too old.') 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | class TensorboardWriter: 11 | def __init__(self, directory, loader_names): 12 | self.directory = directory 13 | self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names}) 14 | 15 | def write_info(self, script_name, description): 16 | tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info')) 17 | tb_info_writer.add_text('Script_name', script_name) 18 | tb_info_writer.add_text('Description', description) 19 | tb_info_writer.close() 20 | 21 | def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1): 22 | for loader_name, loader_stats in stats.items(): 23 | if loader_stats is None: 24 | continue 25 | for var_name, val in loader_stats.items(): 26 | if hasattr(val, 'history') and getattr(val, 'has_new_data', True): 27 | self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch) -------------------------------------------------------------------------------- /lib/train/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import LTRLoader 2 | from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader 3 | -------------------------------------------------------------------------------- /lib/train/data/bounding_box_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rect_to_rel(bb, sz_norm=None): 5 | """Convert standard rectangular parametrization of the bounding box [x, y, w, h] 6 | to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate. 7 | args: 8 | bb - N x 4 tensor of boxes. 9 | sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given. 10 | """ 11 | 12 | c = bb[...,:2] + 0.5 * bb[...,2:] 13 | if sz_norm is None: 14 | c_rel = c / bb[...,2:] 15 | else: 16 | c_rel = c / sz_norm 17 | sz_rel = torch.log(bb[...,2:]) 18 | return torch.cat((c_rel, sz_rel), dim=-1) 19 | 20 | 21 | def rel_to_rect(bb, sz_norm=None): 22 | """Inverts the effect of rect_to_rel. See above.""" 23 | 24 | sz = torch.exp(bb[...,2:]) 25 | if sz_norm is None: 26 | c = bb[...,:2] * sz 27 | else: 28 | c = bb[...,:2] * sz_norm 29 | tl = c - 0.5 * sz 30 | return torch.cat((tl, sz), dim=-1) 31 | 32 | 33 | def masks_to_bboxes(mask, fmt='c'): 34 | 35 | """ Convert a mask tensor to one or more bounding boxes. 36 | Note: This function is a bit new, make sure it does what it says. /Andreas 37 | :param mask: Tensor of masks, shape = (..., H, W) 38 | :param fmt: bbox layout.
'c' => "center + size" or (x_center, y_center, width, height) 39 | 't' => "top left + size" or (x_left, y_top, width, height) 40 | 'v' => "vertices" or (x_left, y_top, x_right, y_bottom) 41 | :return: tensor containing a batch of bounding boxes, shape = (..., 4) 42 | """ 43 | batch_shape = mask.shape[:-2] 44 | mask = mask.reshape((-1, *mask.shape[-2:])) 45 | bboxes = [] 46 | 47 | for m in mask: 48 | mx = m.sum(dim=-2).nonzero() 49 | my = m.sum(dim=-1).nonzero() 50 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0] 51 | bboxes.append(bb) 52 | 53 | bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device) 54 | bboxes = bboxes.reshape(batch_shape + (4,)) 55 | 56 | if fmt == 'v': 57 | return bboxes 58 | 59 | x1 = bboxes[..., :2] 60 | s = bboxes[..., 2:] - x1 + 1 61 | 62 | if fmt == 'c': 63 | return torch.cat((x1 + 0.5 * s, s), dim=-1) 64 | elif fmt == 't': 65 | return torch.cat((x1, s), dim=-1) 66 | 67 | raise ValueError("Undefined bounding box layout '%s'" % fmt) 68 | 69 | 70 | def masks_to_bboxes_multi(mask, ids, fmt='c'): 71 | assert mask.dim() == 2 72 | bboxes = [] 73 | 74 | for id in ids: 75 | mx = (mask == id).sum(dim=-2).nonzero() 76 | my = (mask == id).float().sum(dim=-1).nonzero() 77 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0] 78 | 79 | bb = torch.tensor(bb, dtype=torch.float32, device=mask.device) 80 | 81 | x1 = bb[:2] 82 | s = bb[2:] - x1 + 1 83 | 84 | if fmt == 'v': 85 | pass 86 | elif fmt == 'c': 87 | bb = torch.cat((x1 + 0.5 * s, s), dim=-1) 88 | elif fmt == 't': 89 | bb = torch.cat((x1, s), dim=-1) 90 | else: 91 | raise ValueError("Undefined bounding box layout '%s'" % fmt) 92 | bboxes.append(bb) 93 | 94 | return bboxes 95 | -------------------------------------------------------------------------------- /lib/train/data/image_loader.py: -------------------------------------------------------------------------------- 1 | import jpeg4py 2 | import cv2 as cv 3 | from PIL import Image 4 | import numpy as np 5 | 6 | davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8) 7 | davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 8 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 9 | [64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0], 10 | [64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128], 11 | [0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0], 12 | [0, 64, 128], [128, 64, 128]] 13 | 14 | 15 | def default_image_loader(path): 16 | """The default image loader, reads the image from the given path. 
It first tries to use the jpeg4py_loader, 17 | but reverts to the opencv_loader if the former is not available.""" 18 | if default_image_loader.use_jpeg4py is None: 19 | # Try using jpeg4py 20 | im = jpeg4py_loader(path) 21 | if im is None: 22 | default_image_loader.use_jpeg4py = False 23 | print('Using opencv_loader instead.') 24 | else: 25 | default_image_loader.use_jpeg4py = True 26 | return im 27 | if default_image_loader.use_jpeg4py: 28 | return jpeg4py_loader(path) 29 | return opencv_loader(path) 30 | 31 | default_image_loader.use_jpeg4py = None 32 | 33 | 34 | def jpeg4py_loader(path): 35 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py""" 36 | try: 37 | return jpeg4py.JPEG(path).decode() 38 | except Exception as e: 39 | print('ERROR: Could not read image "{}"'.format(path)) 40 | print(e) 41 | return None 42 | 43 | 44 | def opencv_loader(path): 45 | """ Read image using opencv's imread function and returns it in rgb format""" 46 | try: 47 | im = cv.imread(path, cv.IMREAD_COLOR) 48 | 49 | # convert to rgb and return 50 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 51 | except Exception as e: 52 | print('ERROR: Could not read image "{}"'.format(path)) 53 | print(e) 54 | return None 55 | 56 | 57 | def jpeg4py_loader_w_failsafe(path): 58 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py""" 59 | try: 60 | return jpeg4py.JPEG(path).decode() 61 | except: 62 | try: 63 | im = cv.imread(path, cv.IMREAD_COLOR) 64 | 65 | # convert to rgb and return 66 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 67 | except Exception as e: 68 | print('ERROR: Could not read image "{}"'.format(path)) 69 | print(e) 70 | return None 71 | 72 | 73 | def opencv_seg_loader(path): 74 | """ Read segmentation annotation using opencv's imread function""" 75 | try: 76 | return cv.imread(path) 77 | except Exception as e: 78 | print('ERROR: Could not read image "{}"'.format(path)) 79 | print(e) 80 | return None 81 | 82 | 83 | def imread_indexed(filename): 84 | """ Load indexed image with given filename. Used to read segmentation annotations.""" 85 | 86 | im = Image.open(filename) 87 | 88 | annotation = np.atleast_3d(im)[...,0] 89 | return annotation 90 | 91 | 92 | def imwrite_indexed(filename, array, color_palette=None): 93 | """ Save indexed image as png. Used to save segmentation annotation.""" 94 | 95 | if color_palette is None: 96 | color_palette = davis_palette 97 | 98 | if np.atleast_3d(array).shape[2] != 1: 99 | raise Exception("Saving indexed PNGs requires 2D array.") 100 | 101 | im = Image.fromarray(array) 102 | im.putpalette(color_palette.ravel()) 103 | im.save(filename, format='PNG') -------------------------------------------------------------------------------- /lib/train/data/sequence_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data.dataloader 2 | 3 | if float(torch.__version__[:3]) >= 1.9 or len('.'.join((torch.__version__).split('.')[0:2])) > 3: 4 | int_classes = int 5 | else: 6 | from torch._six import int_classes 7 | 8 | 9 | def slt_collate(batch): 10 | ret = {} 11 | for k in batch[0].keys(): 12 | here_list = [] 13 | for ex in batch: 14 | here_list.append(ex[k]) 15 | ret[k] = here_list 16 | return ret 17 | 18 | 19 | class SLTLoader(torch.utils.data.dataloader.DataLoader): 20 | """ 21 | Data loader. Combines a dataset and a sampler, and provides 22 | single- or multi-process iterators over the dataset. 
23 | """ 24 | 25 | __initialized = False 26 | 27 | def __init__(self, name, dataset, training=True, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 28 | num_workers=0, epoch_interval=1, collate_fn=None, stack_dim=0, pin_memory=False, drop_last=False, 29 | timeout=0, worker_init_fn=None): 30 | if collate_fn is None: 31 | collate_fn = slt_collate 32 | 33 | super(SLTLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, 34 | num_workers, collate_fn, pin_memory, drop_last, 35 | timeout, worker_init_fn) 36 | 37 | self.name = name 38 | self.training = training 39 | self.epoch_interval = epoch_interval 40 | self.stack_dim = stack_dim 41 | -------------------------------------------------------------------------------- /lib/train/data/wandb_logger.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | # try: 4 | # import wandb 5 | # except ImportError: 6 | # raise ImportError( 7 | # 'Please run "pip install wandb" to install wandb') 8 | 9 | 10 | # class WandbWriter: 11 | # def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0): 12 | # self.wandb = wandb 13 | # self.step = cur_step 14 | # self.interval = step_interval 15 | # wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir) 16 | # 17 | # def write_log(self, stats: OrderedDict, epoch=-1): 18 | # self.step += 1 19 | # for loader_name, loader_stats in stats.items(): 20 | # if loader_stats is None: 21 | # continue 22 | # 23 | # log_dict = {} 24 | # for var_name, val in loader_stats.items(): 25 | # if hasattr(val, 'avg'): 26 | # log_dict.update({loader_name + '/' + var_name: val.avg}) 27 | # else: 28 | # log_dict.update({loader_name + '/' + var_name: val.val}) 29 | # 30 | # if epoch >= 0: 31 | # log_dict.update({loader_name + '/epoch': epoch}) 32 | # 33 | # self.wandb.log(log_dict, step=self.step*self.interval) 34 | -------------------------------------------------------------------------------- /lib/train/data_specs/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | ## Description for different text files 4 | GOT10K 5 | - got10k_train_full_split.txt: the complete GOT-10K training set. 
(9335 videos) 6 | - got10k_train_split.txt: part of videos from the GOT-10K training set 7 | - got10k_val_split.txt: another part of videos from the GOT-10K training set 8 | - got10k_vot_exclude.txt: 1k videos that are forbidden from being used to train models that are then tested on VOT (as required by the [VOT Challenge](https://www.votchallenge.net/vot2020/participation.html)) 9 | - got10k_vot_train_split.txt: part of videos from the "VOT-permitted" GOT-10K training set 10 | - got10k_vot_val_split.txt: another part of videos from the "VOT-permitted" GOT-10K training set 11 | 12 | LaSOT 13 | - lasot_train_split.txt: the complete LaSOT training set 14 | 15 | TrackingNet 16 | - trackingnet_classmap.txt: the map from sequence name to target class for the TrackingNet dataset -------------------------------------------------------------------------------- /lib/train/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .lasot import Lasot 2 | from .got10k import Got10k 3 | from .tracking_net import TrackingNet 4 | from .imagenetvid import ImagenetVID 5 | from .coco import MSCOCO 6 | from .coco_seq import MSCOCOSeq 7 | from .got10k_lmdb import Got10k_lmdb 8 | from .lasot_lmdb import Lasot_lmdb 9 | from .imagenetvid_lmdb import ImagenetVID_lmdb 10 | from .coco_seq_lmdb import MSCOCOSeq_lmdb 11 | from .tracking_net_lmdb import TrackingNet_lmdb 12 | -------------------------------------------------------------------------------- /lib/train/dataset/base_image_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from lib.train.data.image_loader import jpeg4py_loader 3 | 4 | 5 | class BaseImageDataset(torch.utils.data.Dataset): 6 | """ Base class for image datasets """ 7 | 8 | def __init__(self, name, root, image_loader=jpeg4py_loader): 9 | """ 10 | args: 11 | root - The root path to the dataset 12 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 13 | is used by default. 14 | """ 15 | self.name = name 16 | self.root = root 17 | self.image_loader = image_loader 18 | 19 | self.image_list = [] # Contains the list of images. 20 | self.class_list = [] 21 | 22 | def __len__(self): 23 | """ Returns size of the dataset 24 | returns: 25 | int - number of samples in the dataset 26 | """ 27 | return self.get_num_images() 28 | 29 | def __getitem__(self, index): 30 | """ Not to be used! Check get_image() instead. 
31 | """ 32 | return None 33 | 34 | def get_name(self): 35 | """ Name of the dataset 36 | 37 | returns: 38 | string - Name of the dataset 39 | """ 40 | raise NotImplementedError 41 | 42 | def get_num_images(self): 43 | """ Number of sequences in a dataset 44 | 45 | returns: 46 | int - number of sequences in the dataset.""" 47 | return len(self.image_list) 48 | 49 | def has_class_info(self): 50 | return False 51 | 52 | def get_class_name(self, image_id): 53 | return None 54 | 55 | def get_num_classes(self): 56 | return len(self.class_list) 57 | 58 | def get_class_list(self): 59 | return self.class_list 60 | 61 | def get_images_in_class(self, class_name): 62 | raise NotImplementedError 63 | 64 | def has_segmentation_info(self): 65 | return False 66 | 67 | def get_image_info(self, seq_id): 68 | """ Returns information about a particular image, 69 | 70 | args: 71 | seq_id - index of the image 72 | 73 | returns: 74 | Dict 75 | """ 76 | raise NotImplementedError 77 | 78 | def get_image(self, image_id, anno=None): 79 | """ Get a image 80 | 81 | args: 82 | image_id - index of image 83 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 84 | 85 | returns: 86 | image - 87 | anno - 88 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 89 | 90 | """ 91 | raise NotImplementedError 92 | 93 | -------------------------------------------------------------------------------- /lib/train/dataset/base_video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | # 2021.1.5 use jpeg4py_loader_w_failsafe as default 3 | from lib.train.data.image_loader import jpeg4py_loader_w_failsafe 4 | 5 | 6 | class BaseVideoDataset(torch.utils.data.Dataset): 7 | """ Base class for video datasets """ 8 | 9 | def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe): 10 | """ 11 | args: 12 | root - The root path to the dataset 13 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 14 | is used by default. 15 | """ 16 | self.name = name 17 | self.root = root 18 | self.image_loader = image_loader 19 | 20 | self.sequence_list = [] # Contains the list of sequences. 21 | self.class_list = [] 22 | 23 | def __len__(self): 24 | """ Returns size of the dataset 25 | returns: 26 | int - number of samples in the dataset 27 | """ 28 | return self.get_num_sequences() 29 | 30 | def __getitem__(self, index): 31 | """ Not to be used! Check get_frames() instead. 
32 | """ 33 | return None 34 | 35 | def is_video_sequence(self): 36 | """ Returns whether the dataset is a video dataset or an image dataset 37 | 38 | returns: 39 | bool - True if a video dataset 40 | """ 41 | return True 42 | 43 | def is_synthetic_video_dataset(self): 44 | """ Returns whether the dataset contains real videos or synthetic 45 | 46 | returns: 47 | bool - True if a video dataset 48 | """ 49 | return False 50 | 51 | def get_name(self): 52 | """ Name of the dataset 53 | 54 | returns: 55 | string - Name of the dataset 56 | """ 57 | raise NotImplementedError 58 | 59 | def get_num_sequences(self): 60 | """ Number of sequences in a dataset 61 | 62 | returns: 63 | int - number of sequences in the dataset.""" 64 | return len(self.sequence_list) 65 | 66 | def has_class_info(self): 67 | return False 68 | 69 | def has_occlusion_info(self): 70 | return False 71 | 72 | def get_num_classes(self): 73 | return len(self.class_list) 74 | 75 | def get_class_list(self): 76 | return self.class_list 77 | 78 | def get_sequences_in_class(self, class_name): 79 | raise NotImplementedError 80 | 81 | def has_segmentation_info(self): 82 | return False 83 | 84 | def get_sequence_info(self, seq_id): 85 | """ Returns information about a particular sequences, 86 | 87 | args: 88 | seq_id - index of the sequence 89 | 90 | returns: 91 | Dict 92 | """ 93 | raise NotImplementedError 94 | 95 | def get_frames(self, seq_id, frame_ids, anno=None): 96 | """ Get a set of frames from a particular sequence 97 | 98 | args: 99 | seq_id - index of sequence 100 | frame_ids - a list of frame numbers 101 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 102 | 103 | returns: 104 | list - List of frames corresponding to frame_ids 105 | list - List of dicts for each frame 106 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 107 | 108 | """ 109 | raise NotImplementedError 110 | 111 | -------------------------------------------------------------------------------- /lib/train/dataset/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_image_dataset import BaseImageDataset 3 | import torch 4 | import random 5 | from collections import OrderedDict 6 | from lib.train.data import jpeg4py_loader 7 | from lib.train.admin import env_settings 8 | from pycocotools.coco import COCO 9 | 10 | 11 | class MSCOCO(BaseImageDataset): 12 | """ The COCO object detection dataset. 13 | 14 | Publication: 15 | Microsoft COCO: Common Objects in Context. 16 | Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona, 17 | Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick 18 | ECCV, 2014 19 | https://arxiv.org/pdf/1405.0312.pdf 20 | 21 | Download the images along with annotations from http://cocodataset.org/#download. The root folder should be 22 | organized as follows. 23 | - coco_root 24 | - annotations 25 | - instances_train2014.json 26 | - instances_train2017.json 27 | - images 28 | - train2014 29 | - train2017 30 | 31 | Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi. 32 | """ 33 | 34 | def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, min_area=None, 35 | split="train", version="2014"): 36 | """ 37 | args: 38 | root - path to coco root folder 39 | image_loader (jpeg4py_loader) - The function to read the images. 
jpeg4py (https://github.com/ajkxyz/jpeg4py) 40 | is used by default. 41 | data_fraction - Fraction of dataset to be used. The complete dataset is used by default 42 | min_area - Objects with area less than min_area are filtered out. Default is 0.0 43 | split - 'train' or 'val'. 44 | version - version of coco dataset (2014 or 2017) 45 | """ 46 | 47 | root = env_settings().coco_dir if root is None else root 48 | super().__init__('COCO', root, image_loader) 49 | 50 | self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version)) 51 | self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version)) 52 | 53 | self.coco_set = COCO(self.anno_path) 54 | 55 | self.cats = self.coco_set.cats 56 | 57 | self.class_list = self.get_class_list() # the parent class thing would happen in the sampler 58 | 59 | self.image_list = self._get_image_list(min_area=min_area) 60 | 61 | if data_fraction is not None: 62 | self.image_list = random.sample(self.image_list, int(len(self.image_list) * data_fraction)) 63 | self.im_per_class = self._build_im_per_class() 64 | 65 | def _get_image_list(self, min_area=None): 66 | ann_list = list(self.coco_set.anns.keys()) 67 | image_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0] 68 | 69 | if min_area is not None: 70 | image_list = [a for a in image_list if self.coco_set.anns[a]['area'] > min_area] 71 | 72 | return image_list 73 | 74 | def get_num_classes(self): 75 | return len(self.class_list) 76 | 77 | def get_name(self): 78 | return 'coco' 79 | 80 | def has_class_info(self): 81 | return True 82 | 83 | def has_segmentation_info(self): 84 | return True 85 | 86 | def get_class_list(self): 87 | class_list = [] 88 | for cat_id in self.cats.keys(): 89 | class_list.append(self.cats[cat_id]['name']) 90 | return class_list 91 | 92 | def _build_im_per_class(self): 93 | im_per_class = {} 94 | for i, im in enumerate(self.image_list): 95 | class_name = self.cats[self.coco_set.anns[im]['category_id']]['name'] 96 | if class_name not in im_per_class: 97 | im_per_class[class_name] = [i] 98 | else: 99 | im_per_class[class_name].append(i) 100 | 101 | return im_per_class 102 | 103 | def get_images_in_class(self, class_name): 104 | return self.im_per_class[class_name] 105 | 106 | def get_image_info(self, im_id): 107 | anno = self._get_anno(im_id) 108 | 109 | bbox = torch.Tensor(anno['bbox']).view(4,) 110 | 111 | mask = torch.Tensor(self.coco_set.annToMask(anno)) 112 | 113 | valid = (bbox[2] > 0) & (bbox[3] > 0) 114 | visible = valid.clone().byte() 115 | 116 | return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible} 117 | 118 | def _get_anno(self, im_id): 119 | anno = self.coco_set.anns[self.image_list[im_id]] 120 | 121 | return anno 122 | 123 | def _get_image(self, im_id): 124 | path = self.coco_set.loadImgs([self.coco_set.anns[self.image_list[im_id]]['image_id']])[0]['file_name'] 125 | img = self.image_loader(os.path.join(self.img_pth, path)) 126 | return img 127 | 128 | def get_meta_info(self, im_id): 129 | try: 130 | cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']] 131 | object_meta = OrderedDict({'object_class_name': cat_dict_current['name'], 132 | 'motion_class': None, 133 | 'major_class': cat_dict_current['supercategory'], 134 | 'root_class': None, 135 | 'motion_adverb': None}) 136 | except: 137 | object_meta = OrderedDict({'object_class_name': None, 138 | 'motion_class': None, 139 | 'major_class': None, 140 | 'root_class': None, 141 | 'motion_adverb': None}) 142 | return 
object_meta 143 | 144 | def get_class_name(self, im_id): 145 | cat_dict_current = self.cats[self.coco_set.anns[self.image_list[im_id]]['category_id']] 146 | return cat_dict_current['name'] 147 | 148 | def get_image(self, image_id, anno=None): 149 | frame = self._get_image(image_id) 150 | 151 | if anno is None: 152 | anno = self.get_image_info(image_id) 153 | 154 | object_meta = self.get_meta_info(image_id) 155 | 156 | return frame, anno, object_meta 157 | -------------------------------------------------------------------------------- /lib/train/dataset/coco_seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_video_dataset import BaseVideoDataset 3 | from lib.train.data import jpeg4py_loader 4 | import torch 5 | import random 6 | from pycocotools.coco import COCO 7 | from collections import OrderedDict 8 | from lib.train.admin import env_settings 9 | 10 | 11 | class MSCOCOSeq(BaseVideoDataset): 12 | """ The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1. 13 | 14 | Publication: 15 | Microsoft COCO: Common Objects in Context. 16 | Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona, 17 | Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick 18 | ECCV, 2014 19 | https://arxiv.org/pdf/1405.0312.pdf 20 | 21 | Download the images along with annotations from http://cocodataset.org/#download. The root folder should be 22 | organized as follows. 23 | - coco_root 24 | - annotations 25 | - instances_train2014.json 26 | - instances_train2017.json 27 | - images 28 | - train2014 29 | - train2017 30 | 31 | Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi. 32 | """ 33 | 34 | def __init__(self, root=None, image_loader=jpeg4py_loader, data_fraction=None, split="train", version="2014",env_num=None): 35 | """ 36 | args: 37 | root - path to the coco dataset. 38 | image_loader (default_image_loader) - The function to read the images. If installed, 39 | jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else, 40 | opencv's imread is used. 41 | data_fraction (None) - Fraction of images to be used. The images are selected randomly. If None, all the 42 | images will be used 43 | split - 'train' or 'val'. 44 | version - version of coco dataset (2014 or 2017) 45 | """ 46 | root = env_settings(env_num).coco_dir if root is None else root 47 | super().__init__('COCO', root, image_loader) 48 | 49 | self.img_pth = os.path.join(root, 'images/{}{}/'.format(split, version)) 50 | self.anno_path = os.path.join(root, 'annotations/instances_{}{}.json'.format(split, version)) 51 | 52 | # Load the COCO set. 
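        # (Note: pycocotools' COCO() parses the whole instances_{split}{version}.json
        # into memory; its .anns dict is keyed by annotation id, and
        # _get_sequence_list() below filters those ids to build the per-annotation
        # "sequence" list.)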
53 | self.coco_set = COCO(self.anno_path) 54 | 55 | self.cats = self.coco_set.cats 56 | 57 | self.class_list = self.get_class_list() 58 | 59 | self.sequence_list = self._get_sequence_list() 60 | 61 | if data_fraction is not None: 62 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list)*data_fraction)) 63 | self.seq_per_class = self._build_seq_per_class() 64 | 65 | def _get_sequence_list(self): 66 | ann_list = list(self.coco_set.anns.keys()) 67 | seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0] 68 | 69 | return seq_list 70 | 71 | def is_video_sequence(self): 72 | return False 73 | 74 | def get_num_classes(self): 75 | return len(self.class_list) 76 | 77 | def get_name(self): 78 | return 'coco' 79 | 80 | def has_class_info(self): 81 | return True 82 | 83 | def get_class_list(self): 84 | class_list = [] 85 | for cat_id in self.cats.keys(): 86 | class_list.append(self.cats[cat_id]['name']) 87 | return class_list 88 | 89 | def has_segmentation_info(self): 90 | return True 91 | 92 | def get_num_sequences(self): 93 | return len(self.sequence_list) 94 | 95 | def _build_seq_per_class(self): 96 | seq_per_class = {} 97 | for i, seq in enumerate(self.sequence_list): 98 | class_name = self.cats[self.coco_set.anns[seq]['category_id']]['name'] 99 | if class_name not in seq_per_class: 100 | seq_per_class[class_name] = [i] 101 | else: 102 | seq_per_class[class_name].append(i) 103 | 104 | return seq_per_class 105 | 106 | def get_sequences_in_class(self, class_name): 107 | return self.seq_per_class[class_name] 108 | 109 | def get_sequence_info(self, seq_id): 110 | anno = self._get_anno(seq_id) 111 | 112 | bbox = torch.Tensor(anno['bbox']).view(1, 4) 113 | 114 | mask = torch.Tensor(self.coco_set.annToMask(anno)).unsqueeze(dim=0) 115 | 116 | '''2021.1.3 To avoid too small bounding boxes. Here we change the threshold to 50 pixels''' 117 | valid = (bbox[:, 2] > 50) & (bbox[:, 3] > 50) 118 | 119 | visible = valid.clone().byte() 120 | 121 | return {'bbox': bbox, 'mask': mask, 'valid': valid, 'visible': visible} 122 | 123 | def _get_anno(self, seq_id): 124 | anno = self.coco_set.anns[self.sequence_list[seq_id]] 125 | 126 | return anno 127 | 128 | def _get_frames(self, seq_id): 129 | path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name'] 130 | img = self.image_loader(os.path.join(self.img_pth, path)) 131 | return img 132 | 133 | def get_meta_info(self, seq_id): 134 | try: 135 | cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']] 136 | object_meta = OrderedDict({'object_class_name': cat_dict_current['name'], 137 | 'motion_class': None, 138 | 'major_class': cat_dict_current['supercategory'], 139 | 'root_class': None, 140 | 'motion_adverb': None}) 141 | except: 142 | object_meta = OrderedDict({'object_class_name': None, 143 | 'motion_class': None, 144 | 'major_class': None, 145 | 'root_class': None, 146 | 'motion_adverb': None}) 147 | return object_meta 148 | 149 | 150 | def get_class_name(self, seq_id): 151 | cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']] 152 | return cat_dict_current['name'] 153 | 154 | def get_frames(self, seq_id=None, frame_ids=None, anno=None): 155 | # COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a 156 | # list containing these replicated images. 
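        # For example (hypothetical call, assuming a configured COCO root):
        #   frames, annos, meta = dataset.get_frames(seq_id=3, frame_ids=[0, 0, 0])
        # returns three copies of the same image, and annos['bbox'] is the single
        # ground-truth box repeated once per requested frame id.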
157 | frame = self._get_frames(seq_id) 158 | 159 | frame_list = [frame.copy() for _ in frame_ids] 160 | 161 | if anno is None: 162 | anno = self.get_sequence_info(seq_id) 163 | 164 | anno_frames = {} 165 | for key, value in anno.items(): 166 | anno_frames[key] = [value[0, ...] for _ in frame_ids] 167 | 168 | object_meta = self.get_meta_info(seq_id) 169 | 170 | return frame_list, anno_frames, object_meta 171 | -------------------------------------------------------------------------------- /lib/train/dataset/imagenetvid_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_video_dataset import BaseVideoDataset 3 | from lib.train.data import jpeg4py_loader 4 | import torch 5 | from collections import OrderedDict 6 | from lib.train.admin import env_settings 7 | from lib.utils.lmdb_utils import decode_img, decode_json 8 | 9 | 10 | def get_target_to_image_ratio(seq): 11 | anno = torch.Tensor(seq['anno']) 12 | img_sz = torch.Tensor(seq['image_size']) 13 | return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt() 14 | 15 | 16 | class ImagenetVID_lmdb(BaseVideoDataset): 17 | """ Imagenet VID dataset. 18 | 19 | Publication: 20 | ImageNet Large Scale Visual Recognition Challenge 21 | Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, 22 | Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei 23 | IJCV, 2015 24 | https://arxiv.org/pdf/1409.0575.pdf 25 | 26 | Download the dataset from http://image-net.org/ 27 | """ 28 | def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1,env_num=None): 29 | """ 30 | args: 31 | root - path to the imagenet vid dataset. 32 | image_loader (default_image_loader) - The function to read the images. If installed, 33 | jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else, 34 | opencv's imread is used. 35 | min_length - Minimum allowed sequence length. 36 | max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets 37 | which cover complete image. 
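        A hypothetical construction (the path is an assumption; the lmdb cache and its
        cache.json index must already exist under root):

            dataset = ImagenetVID_lmdb(root='/path/to/vid_lmdb', min_length=2, max_target_area=0.5)

        Sequences with fewer than min_length annotated frames, or whose first-frame box
        fails the max_target_area test (the code compares sqrt(box_area / image_area)
        against it), are filtered out.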
38 | """ 39 | root = env_settings(env_num).imagenet_dir if root is None else root 40 | super().__init__("imagenetvid_lmdb", root, image_loader) 41 | 42 | sequence_list_dict = decode_json(root, "cache.json") 43 | self.sequence_list = sequence_list_dict 44 | 45 | # Filter the sequences based on min_length and max_target_area in the first frame 46 | self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and 47 | get_target_to_image_ratio(x) < max_target_area] 48 | 49 | def get_name(self): 50 | return 'imagenetvid_lmdb' 51 | 52 | def get_num_sequences(self): 53 | return len(self.sequence_list) 54 | 55 | def get_sequence_info(self, seq_id): 56 | bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno']) 57 | valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0) 58 | visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte() 59 | return {'bbox': bb_anno, 'valid': valid, 'visible': visible} 60 | 61 | def _get_frame(self, sequence, frame_id): 62 | set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id']) 63 | vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id']) 64 | frame_number = frame_id + sequence['start_frame'] 65 | frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name, 66 | '{:06d}.JPEG'.format(frame_number)) 67 | return decode_img(self.root, frame_path) 68 | 69 | def get_frames(self, seq_id, frame_ids, anno=None): 70 | sequence = self.sequence_list[seq_id] 71 | 72 | frame_list = [self._get_frame(sequence, f) for f in frame_ids] 73 | 74 | if anno is None: 75 | anno = self.get_sequence_info(seq_id) 76 | 77 | # Create anno dict 78 | anno_frames = {} 79 | for key, value in anno.items(): 80 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 81 | 82 | # added the class info to the meta info 83 | object_meta = OrderedDict({'object_class': sequence['class_name'], 84 | 'motion_class': None, 85 | 'major_class': None, 86 | 'root_class': None, 87 | 'motion_adverb': None}) 88 | 89 | return frame_list, anno_frames, object_meta 90 | 91 | -------------------------------------------------------------------------------- /lib/train/dataset/lasot_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | from collections import OrderedDict 5 | 6 | import pandas 7 | import torch 8 | 9 | from lib.train.admin import env_settings 10 | from lib.train.data import jpeg4py_loader 11 | from .base_video_dataset import BaseVideoDataset 12 | 13 | '''2021.1.16 Lasot for loading lmdb dataset''' 14 | from lib.utils.lmdb_utils import * 15 | 16 | 17 | class Lasot_lmdb(BaseVideoDataset): 18 | 19 | def __init__(self, root=None, image_loader=jpeg4py_loader, vid_ids=None, split=None, data_fraction=None, 20 | env_num=None): 21 | """ 22 | args: 23 | root - path to the lasot dataset. 24 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 25 | is used by default. 26 | vid_ids - List containing the ids of the videos (1 - 20) used for training. If vid_ids = [1, 3, 5], then the 27 | videos with subscripts -1, -3, and -5 from each class will be used for training. 28 | split - If split='train', the official train split (protocol-II) is used for training. Note: Only one of 29 | vid_ids or split option can be used at a time. 30 | data_fraction - Fraction of dataset to be used. 
The complete dataset is used by default 31 | """ 32 | root = env_settings(env_num).lasot_lmdb_dir if root is None else root 33 | super().__init__('LaSOT_lmdb', root, image_loader) 34 | 35 | self.sequence_list = self._build_sequence_list(vid_ids, split) 36 | class_list = [seq_name.split('-')[0] for seq_name in self.sequence_list] 37 | self.class_list = [] 38 | for ele in class_list: 39 | if ele not in self.class_list: 40 | self.class_list.append(ele) 41 | # Keep a list of all classes 42 | self.class_to_id = {cls_name: cls_id for cls_id, cls_name in enumerate(self.class_list)} 43 | 44 | if data_fraction is not None: 45 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction)) 46 | 47 | self.seq_per_class = self._build_class_list() 48 | 49 | def _build_sequence_list(self, vid_ids=None, split=None): 50 | if split is not None: 51 | if vid_ids is not None: 52 | raise ValueError('Cannot set both split_name and vid_ids.') 53 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 54 | if split == 'train': 55 | file_path = os.path.join(ltr_path, 'data_specs', 'lasot_train_split.txt') 56 | else: 57 | raise ValueError('Unknown split name.') 58 | sequence_list = pandas.read_csv(file_path, header=None, squeeze=True).values.tolist() 59 | elif vid_ids is not None: 60 | sequence_list = [c + '-' + str(v) for c in self.class_list for v in vid_ids] 61 | else: 62 | raise ValueError('Set either split_name or vid_ids.') 63 | 64 | return sequence_list 65 | 66 | def _build_class_list(self): 67 | seq_per_class = {} 68 | for seq_id, seq_name in enumerate(self.sequence_list): 69 | class_name = seq_name.split('-')[0] 70 | if class_name in seq_per_class: 71 | seq_per_class[class_name].append(seq_id) 72 | else: 73 | seq_per_class[class_name] = [seq_id] 74 | 75 | return seq_per_class 76 | 77 | def get_name(self): 78 | return 'lasot_lmdb' 79 | 80 | def has_class_info(self): 81 | return True 82 | 83 | def has_occlusion_info(self): 84 | return True 85 | 86 | def get_num_sequences(self): 87 | return len(self.sequence_list) 88 | 89 | def get_num_classes(self): 90 | return len(self.class_list) 91 | 92 | def get_sequences_in_class(self, class_name): 93 | return self.seq_per_class[class_name] 94 | 95 | def _read_bb_anno(self, seq_path): 96 | bb_anno_file = os.path.join(seq_path, "groundtruth.txt") 97 | gt_str_list = decode_str(self.root, bb_anno_file).split('\n')[:-1] # the last line is empty 98 | gt_list = [list(map(float, line.split(','))) for line in gt_str_list] 99 | gt_arr = np.array(gt_list).astype(np.float32) 100 | return torch.tensor(gt_arr) 101 | 102 | def _read_target_visible(self, seq_path): 103 | # Read full occlusion and out_of_view 104 | occlusion_file = os.path.join(seq_path, "full_occlusion.txt") 105 | out_of_view_file = os.path.join(seq_path, "out_of_view.txt") 106 | 107 | occ_list = list(map(int, decode_str(self.root, occlusion_file).split(','))) 108 | occlusion = torch.ByteTensor(occ_list) 109 | out_view_list = list(map(int, decode_str(self.root, out_of_view_file).split(','))) 110 | out_of_view = torch.ByteTensor(out_view_list) 111 | 112 | target_visible = ~occlusion & ~out_of_view 113 | 114 | return target_visible 115 | 116 | def _get_sequence_path(self, seq_id): 117 | seq_name = self.sequence_list[seq_id] 118 | class_name = seq_name.split('-')[0] 119 | vid_id = seq_name.split('-')[1] 120 | 121 | return os.path.join(class_name, class_name + '-' + vid_id) 122 | 123 | def get_sequence_info(self, seq_id): 124 | seq_path = 
self._get_sequence_path(seq_id) 125 | bbox = self._read_bb_anno(seq_path) 126 | 127 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 128 | visible = self._read_target_visible(seq_path) & valid.byte() 129 | 130 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 131 | 132 | def _get_frame_path(self, seq_path, frame_id): 133 | return os.path.join(seq_path, 'img', '{:08}.jpg'.format(frame_id + 1)) # frames start from 1 134 | 135 | def _get_frame(self, seq_path, frame_id): 136 | return decode_img(self.root, self._get_frame_path(seq_path, frame_id)) 137 | 138 | def _get_class(self, seq_path): 139 | raw_class = seq_path.split('/')[-2] 140 | return raw_class 141 | 142 | def get_class_name(self, seq_id): 143 | seq_path = self._get_sequence_path(seq_id) 144 | obj_class = self._get_class(seq_path) 145 | 146 | return obj_class 147 | 148 | def get_frames(self, seq_id, frame_ids, anno=None): 149 | seq_path = self._get_sequence_path(seq_id) 150 | 151 | obj_class = self._get_class(seq_path) 152 | frame_list = [self._get_frame(seq_path, f_id) for f_id in frame_ids] 153 | 154 | if anno is None: 155 | anno = self.get_sequence_info(seq_id) 156 | 157 | anno_frames = {} 158 | for key, value in anno.items(): 159 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 160 | 161 | object_meta = OrderedDict({'object_class_name': obj_class, 162 | 'motion_class': None, 163 | 'major_class': None, 164 | 'root_class': None, 165 | 'motion_adverb': None}) 166 | 167 | return frame_list, anno_frames, object_meta 168 | -------------------------------------------------------------------------------- /lib/train/dataset/tracking_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import random 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import pandas 8 | import torch 9 | 10 | from lib.train.admin import env_settings 11 | from lib.train.data import jpeg4py_loader 12 | from .base_video_dataset import BaseVideoDataset 13 | 14 | 15 | def list_sequences(root, set_ids): 16 | """ Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name) 17 | 18 | args: 19 | root: Root directory to TrackingNet 20 | set_ids: Sets (0-11) which are to be used 21 | 22 | returns: 23 | list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence 24 | """ 25 | sequence_list = [] 26 | 27 | for s in set_ids: 28 | anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno") 29 | 30 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')] 31 | sequence_list += sequences_cur_set 32 | 33 | return sequence_list 34 | 35 | 36 | class TrackingNet(BaseVideoDataset): 37 | """ TrackingNet dataset. 38 | 39 | Publication: 40 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 41 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 42 | ECCV, 2018 43 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 44 | 45 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 46 | """ 47 | 48 | def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None, env_num=None): 49 | """ 50 | args: 51 | root - The path to the TrackingNet folder, containing the training sets. 
52 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 53 | is used by default. 54 | set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the 55 | sets (0 - 11) will be used. 56 | data_fraction - Fraction of dataset to be used. The complete dataset is used by default 57 | """ 58 | root = env_settings(env_num).trackingnet_dir if root is None else root 59 | super().__init__('TrackingNet', root, image_loader) 60 | 61 | if set_ids is None: 62 | set_ids = [i for i in range(12)] 63 | 64 | self.set_ids = set_ids 65 | 66 | # Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and 67 | # video_name for each sequence 68 | self.sequence_list = list_sequences(self.root, self.set_ids) 69 | 70 | if data_fraction is not None: 71 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction)) 72 | 73 | self.seq_to_class_map, self.seq_per_class = self._load_class_info() 74 | 75 | # we do not have the class_lists for the tracking net 76 | self.class_list = list(self.seq_per_class.keys()) 77 | self.class_list.sort() 78 | 79 | def _load_class_info(self): 80 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 81 | class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt') 82 | 83 | with open(class_map_path, 'r') as f: 84 | seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f} 85 | 86 | seq_per_class = {} 87 | for i, seq in enumerate(self.sequence_list): 88 | class_name = seq_to_class_map.get(seq[1], 'Unknown') 89 | if class_name not in seq_per_class: 90 | seq_per_class[class_name] = [i] 91 | else: 92 | seq_per_class[class_name].append(i) 93 | 94 | return seq_to_class_map, seq_per_class 95 | 96 | def get_name(self): 97 | return 'trackingnet' 98 | 99 | def has_class_info(self): 100 | return True 101 | 102 | def get_sequences_in_class(self, class_name): 103 | return self.seq_per_class[class_name] 104 | 105 | def _read_bb_anno(self, seq_id): 106 | set_id = self.sequence_list[seq_id][0] 107 | vid_name = self.sequence_list[seq_id][1] 108 | bb_anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt") 109 | gt = pandas.read_csv(bb_anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, 110 | low_memory=False).values 111 | return torch.tensor(gt) 112 | 113 | def get_sequence_info(self, seq_id): 114 | bbox = self._read_bb_anno(seq_id) 115 | 116 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 117 | visible = valid.clone().byte() 118 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 119 | 120 | def _get_frame(self, seq_id, frame_id): 121 | set_id = self.sequence_list[seq_id][0] 122 | vid_name = self.sequence_list[seq_id][1] 123 | frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg") 124 | return self.image_loader(frame_path) 125 | 126 | def _get_class(self, seq_id): 127 | seq_name = self.sequence_list[seq_id][1] 128 | return self.seq_to_class_map[seq_name] 129 | 130 | def get_class_name(self, seq_id): 131 | obj_class = self._get_class(seq_id) 132 | 133 | return obj_class 134 | 135 | def get_frames(self, seq_id, frame_ids, anno=None): 136 | frame_list = [self._get_frame(seq_id, f) for f in frame_ids] 137 | 138 | if anno is None: 139 | anno = self.get_sequence_info(seq_id) 140 | 141 | anno_frames = {} 142 | for key, value in 
anno.items(): 143 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 144 | 145 | obj_class = self._get_class(seq_id) 146 | 147 | object_meta = OrderedDict({'object_class_name': obj_class, 148 | 'motion_class': None, 149 | 'major_class': None, 150 | 'root_class': None, 151 | 'motion_adverb': None}) 152 | 153 | return frame_list, anno_frames, object_meta 154 | -------------------------------------------------------------------------------- /lib/train/dataset/tracking_net_lmdb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import os.path 4 | import numpy as np 5 | import random 6 | from collections import OrderedDict 7 | 8 | from lib.train.data import jpeg4py_loader 9 | from .base_video_dataset import BaseVideoDataset 10 | from lib.train.admin import env_settings 11 | import json 12 | from lib.utils.lmdb_utils import decode_img, decode_str 13 | 14 | 15 | def list_sequences(root): 16 | """ Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name) 17 | 18 | args: 19 | root: Root directory to TrackingNet 20 | 21 | returns: 22 | list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence 23 | """ 24 | fname = os.path.join(root, "seq_list.json") 25 | with open(fname, "r") as f: 26 | sequence_list = json.loads(f.read()) 27 | return sequence_list 28 | 29 | 30 | class TrackingNet_lmdb(BaseVideoDataset): 31 | """ TrackingNet dataset. 32 | 33 | Publication: 34 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 35 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 36 | ECCV, 2018 37 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 38 | 39 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 40 | """ 41 | def __init__(self, root=None, image_loader=jpeg4py_loader, set_ids=None, data_fraction=None,env_num=None): 42 | """ 43 | args: 44 | root - The path to the TrackingNet folder, containing the training sets. 45 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 46 | is used by default. 47 | set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the 48 | sets (0 - 11) will be used. 49 | data_fraction - Fraction of dataset to be used. The complete dataset is used by default 50 | """ 51 | root = env_settings(env_num).trackingnet_lmdb_dir if root is None else root 52 | super().__init__('TrackingNet_lmdb', root, image_loader) 53 | 54 | if set_ids is None: 55 | set_ids = [i for i in range(12)] 56 | 57 | self.set_ids = set_ids 58 | 59 | # Keep a list of all videos. 
Sequence list is a list of tuples (set_id, video_name) containing the set_id and 60 | # video_name for each sequence 61 | self.sequence_list = list_sequences(self.root) 62 | 63 | if data_fraction is not None: 64 | self.sequence_list = random.sample(self.sequence_list, int(len(self.sequence_list) * data_fraction)) 65 | 66 | self.seq_to_class_map, self.seq_per_class = self._load_class_info() 67 | 68 | # we do not have the class_lists for the tracking net 69 | self.class_list = list(self.seq_per_class.keys()) 70 | self.class_list.sort() 71 | 72 | def _load_class_info(self): 73 | ltr_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..') 74 | class_map_path = os.path.join(ltr_path, 'data_specs', 'trackingnet_classmap.txt') 75 | 76 | with open(class_map_path, 'r') as f: 77 | seq_to_class_map = {seq_class.split('\t')[0]: seq_class.rstrip().split('\t')[1] for seq_class in f} 78 | 79 | seq_per_class = {} 80 | for i, seq in enumerate(self.sequence_list): 81 | class_name = seq_to_class_map.get(seq[1], 'Unknown') 82 | if class_name not in seq_per_class: 83 | seq_per_class[class_name] = [i] 84 | else: 85 | seq_per_class[class_name].append(i) 86 | 87 | return seq_to_class_map, seq_per_class 88 | 89 | def get_name(self): 90 | return 'trackingnet_lmdb' 91 | 92 | def has_class_info(self): 93 | return True 94 | 95 | def get_sequences_in_class(self, class_name): 96 | return self.seq_per_class[class_name] 97 | 98 | def _read_bb_anno(self, seq_id): 99 | set_id = self.sequence_list[seq_id][0] 100 | vid_name = self.sequence_list[seq_id][1] 101 | gt_str_list = decode_str(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id), 102 | os.path.join("anno", vid_name + ".txt")).split('\n')[:-1] 103 | gt_list = [list(map(float, line.split(','))) for line in gt_str_list] 104 | gt_arr = np.array(gt_list).astype(np.float32) 105 | return torch.tensor(gt_arr) 106 | 107 | def get_sequence_info(self, seq_id): 108 | bbox = self._read_bb_anno(seq_id) 109 | 110 | valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0) 111 | visible = valid.clone().byte() 112 | return {'bbox': bbox, 'valid': valid, 'visible': visible} 113 | 114 | def _get_frame(self, seq_id, frame_id): 115 | set_id = self.sequence_list[seq_id][0] 116 | vid_name = self.sequence_list[seq_id][1] 117 | return decode_img(os.path.join(self.root, "TRAIN_%d_lmdb" % set_id), 118 | os.path.join("frames", vid_name, str(frame_id) + ".jpg")) 119 | 120 | def _get_class(self, seq_id): 121 | seq_name = self.sequence_list[seq_id][1] 122 | return self.seq_to_class_map[seq_name] 123 | 124 | def get_class_name(self, seq_id): 125 | obj_class = self._get_class(seq_id) 126 | 127 | return obj_class 128 | 129 | def get_frames(self, seq_id, frame_ids, anno=None): 130 | frame_list = [self._get_frame(seq_id, f) for f in frame_ids] 131 | 132 | if anno is None: 133 | anno = self.get_sequence_info(seq_id) 134 | 135 | anno_frames = {} 136 | for key, value in anno.items(): 137 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids] 138 | 139 | obj_class = self._get_class(seq_id) 140 | 141 | object_meta = OrderedDict({'object_class_name': obj_class, 142 | 'motion_class': None, 143 | 'major_class': None, 144 | 'root_class': None, 145 | 'motion_adverb': None}) 146 | 147 | return frame_list, anno_frames, object_meta 148 | -------------------------------------------------------------------------------- /lib/train/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # __ coding: utf-8 __ 2 | # author: Li Yunfeng 3 | # data: 2022/12/12 
11:36 4 | 5 | 6 | from .objective import lightTrackObjective -------------------------------------------------------------------------------- /lib/train/loss/cos_sim_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def cosine_similarity_loss(anchor, negative, margin): 6 | # Compute the cosine similarity between the embeddings 7 | # similarity_pos = F.cosine_similarity(anchor, positive) 8 | similarity_neg = F.cosine_similarity(anchor, negative) 9 | 10 | # Cosine similarity loss: penalize only when the negative similarity exceeds -margin 11 | loss = torch.clamp(similarity_neg - 0 + margin, min=0) 12 | 13 | return loss.mean() 14 | -------------------------------------------------------------------------------- /lib/train/loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class FocalLoss(nn.Module, ABC): 9 | def __init__(self, alpha=2, beta=4): 10 | super(FocalLoss, self).__init__() 11 | self.alpha = alpha 12 | self.beta = beta 13 | 14 | def forward(self, prediction, target): 15 | positive_index = target.eq(1).float() 16 | negative_index = target.lt(1).float() 17 | 18 | negative_weights = torch.pow(1 - target, self.beta) 19 | # clamp min value is set to 1e-12 to maintain numerical stability 20 | prediction = torch.clamp(prediction, 1e-12) 21 | 22 | positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index 23 | negative_loss = torch.log(1 - prediction) * torch.pow(prediction, 24 | self.alpha) * negative_weights * negative_index 25 | 26 | num_positive = positive_index.float().sum() 27 | positive_loss = positive_loss.sum() 28 | negative_loss = negative_loss.sum() 29 | 30 | if num_positive == 0: 31 | loss = -negative_loss 32 | else: 33 | loss = -(positive_loss + negative_loss) / num_positive 34 | 35 | return loss 36 | 37 | 38 | class LBHinge(nn.Module): 39 | """Loss that uses a 'hinge' on the lower bound. 40 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is 41 | also smaller than that threshold. 42 | args: 43 | error_metric: What base loss to use (MSE by default). 44 | threshold: Threshold to use for the hinge. 45 | clip: Clip the loss if it is above this value. 
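    A small illustrative sketch (the tensors and values are made up):

        loss_fn = LBHinge(error_metric=nn.MSELoss(), threshold=0.05)
        loss = loss_fn(prediction_scores, label_scores)

    For samples whose label is below the threshold, the prediction is passed through
    a ReLU and the target is zeroed, so predictions that are already non-positive
    contribute no error.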
46 | """ 47 | def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None): 48 | super().__init__() 49 | self.error_metric = error_metric 50 | self.threshold = threshold if threshold is not None else -100 51 | self.clip = clip 52 | 53 | def forward(self, prediction, label, target_bb=None): 54 | negative_mask = (label < self.threshold).float() 55 | positive_mask = (1.0 - negative_mask) 56 | 57 | prediction = negative_mask * F.relu(prediction) + positive_mask * prediction 58 | 59 | loss = self.error_metric(prediction, positive_mask * label) 60 | 61 | if self.clip is not None: 62 | loss = torch.min(loss, torch.tensor([self.clip], device=loss.device)) 63 | return loss -------------------------------------------------------------------------------- /lib/train/loss/objective.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .focal_loss import FocalLoss 3 | from .box_loss import giou_loss, ciou_loss, siou_loss, eiou_loss, wiou_loss 4 | from torch.nn.functional import l1_loss, smooth_l1_loss 5 | import torch.nn as nn 6 | 7 | from .gfocal_loss import DistributionFocalLoss 8 | from .varifocal_loss import VarifocalLoss 9 | 10 | 11 | class lightTrackObjective(object): 12 | def __init__(self, cfg): 13 | super(lightTrackObjective, self).__init__() 14 | 15 | # l loss 16 | if cfg.TRAIN.L_LOSS == 'l1': 17 | self.l1 = l1_loss 18 | elif cfg.TRAIN.L_LOSS == 'smooth_l1': 19 | self.smooth_l1 = smooth_l1_loss 20 | else: 21 | pass 22 | 23 | # box iou 24 | if cfg.TRAIN.BOX_LOSS == 'giou': 25 | self.iou = giou_loss 26 | elif cfg.TRAIN.BOX_LOSS == 'ciou': 27 | self.iou = ciou_loss 28 | elif cfg.TRAIN.BOX_LOSS == 'siou': 29 | self.iou = siou_loss 30 | elif cfg.TRAIN.BOX_LOSS == 'wiou': 31 | self.iou = wiou_loss 32 | elif cfg.TRAIN.BOX_LOSS == 'eiou': 33 | self.iou = eiou_loss 34 | else: 35 | pass 36 | 37 | # cls iou 38 | if cfg.TRAIN.CLS_LOSS == 'focal': 39 | self.focal_loss = FocalLoss() 40 | elif cfg.TRAIN.CLS_LOSS == 'varifocal': 41 | self.focal_loss = VarifocalLoss() 42 | 43 | -------------------------------------------------------------------------------- /lib/train/optimizer/lion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """PyTorch implementation of the Lion optimizer.""" 16 | import torch 17 | from torch.optim.optimizer import Optimizer 18 | 19 | 20 | class Lion(Optimizer): 21 | r"""Implements Lion algorithm.""" 22 | 23 | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): 24 | """Initialize the hyperparameters. 
25 | 26 | Args: 27 | params (iterable): iterable of parameters to optimize or dicts defining 28 | parameter groups 29 | lr (float, optional): learning rate (default: 1e-4) 30 | betas (Tuple[float, float], optional): coefficients used for computing 31 | running averages of gradient and its square (default: (0.9, 0.99)) 32 | weight_decay (float, optional): weight decay coefficient (default: 0) 33 | """ 34 | 35 | if not 0.0 <= lr: 36 | raise ValueError('Invalid learning rate: {}'.format(lr)) 37 | if not 0.0 <= betas[0] < 1.0: 38 | raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) 39 | if not 0.0 <= betas[1] < 1.0: 40 | raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) 41 | defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) 42 | super().__init__(params, defaults) 43 | 44 | @torch.no_grad() 45 | def step(self, closure=None): 46 | """Performs a single optimization step. 47 | 48 | Args: 49 | closure (callable, optional): A closure that reevaluates the model 50 | and returns the loss. 51 | 52 | Returns: 53 | the loss. 54 | """ 55 | loss = None 56 | if closure is not None: 57 | with torch.enable_grad(): 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | # Perform stepweight decay 66 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 67 | 68 | grad = p.grad 69 | state = self.state[p] 70 | # State initialization 71 | if len(state) == 0: 72 | # Exponential moving average of gradient values 73 | state['exp_avg'] = torch.zeros_like(p) 74 | 75 | exp_avg = state['exp_avg'] 76 | beta1, beta2 = group['betas'] 77 | 78 | # Weight update 79 | update = exp_avg * beta1 + grad * (1 - beta1) 80 | p.add_(torch.sign(update), alpha=-group['lr']) 81 | # Decay the momentum running average coefficient 82 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 83 | 84 | return loss -------------------------------------------------------------------------------- /lib/train/run_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import importlib 5 | import cv2 as cv 6 | import torch.backends.cudnn 7 | import torch.distributed as dist 8 | 9 | import random 10 | import numpy as np 11 | 12 | torch.backends.cudnn.benchmark = False 13 | 14 | import _init_paths 15 | import lib.train.admin.settings as ws_settings 16 | 17 | 18 | def init_seeds(seed): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | 26 | 27 | def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None, 28 | use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False, env_num=None, 29 | distill=None, script_teacher=None, config_teacher=None): 30 | """Run the train script. 31 | args: 32 | script_name: Name of emperiment in the "experiments/" folder. 33 | config_name: Name of the yaml file in the "experiments/". 34 | cudnn_benchmark: Use cudnn benchmark or not (default is True). 35 | """ 36 | if save_dir is None: 37 | print("save_dir dir is not given. 
Use the default dir instead.") 38 | # This is needed to avoid strange crashes related to opencv 39 | cv.setNumThreads(0) 40 | 41 | torch.backends.cudnn.benchmark = cudnn_benchmark 42 | 43 | print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name)) 44 | 45 | '''2021.1.5 set seed for different process''' 46 | if base_seed is not None: 47 | if local_rank != -1: 48 | init_seeds(base_seed + local_rank) 49 | else: 50 | init_seeds(base_seed) 51 | 52 | settings = ws_settings.Settings(env_num) 53 | settings.env_num = env_num 54 | settings.script_name = script_name 55 | settings.config_name = config_name 56 | settings.project_path = 'train/{}/{}'.format(script_name, config_name) 57 | if script_name_prv is not None and config_name_prv is not None: 58 | settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv) 59 | settings.local_rank = local_rank 60 | settings.save_dir = os.path.abspath(save_dir) 61 | settings.use_lmdb = use_lmdb 62 | prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) 63 | settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name)) 64 | settings.use_wandb = use_wandb 65 | if distill: 66 | settings.distill = distill 67 | settings.script_teacher = script_teacher 68 | settings.config_teacher = config_teacher 69 | if script_teacher is not None and config_teacher is not None: 70 | settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher) 71 | settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher)) 72 | expr_module = importlib.import_module('lib.train.train_script_distill') 73 | else: 74 | expr_module = importlib.import_module('lib.train.train_script') 75 | expr_func = getattr(expr_module, 'run') 76 | 77 | expr_func(settings) 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.') 82 | parser.add_argument('--script', type=str, required=True, help='Name of the train script.') 83 | parser.add_argument('--config', type=str, required=True, help="Name of the config file.") 84 | parser.add_argument('--cudnn_benchmark', type=bool, default=True, 85 | help='Set cudnn benchmark on (1) or off (0) (default is on).') 86 | parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training') 87 | parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs') 88 | parser.add_argument('--seed', type=int, default=42, help='seed for random numbers') 89 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format 90 | parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.') 91 | parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.") 92 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb 93 | parser.add_argument('--env_num', type=int, default=None, 94 | help='Use for multi environment developing, support: 0,1,2') 95 | # for knowledge distillation 96 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation 97 | parser.add_argument('--script_teacher', type=str, help='teacher script name') 98 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name') 99 | 100 | args = parser.parse_args() 101 | if 
args.local_rank != -1: 102 | dist.init_process_group(backend='nccl') 103 | torch.cuda.set_device(args.local_rank) 104 | else: 105 | torch.cuda.set_device(0) 106 | run_training(args.script, args.config, cudnn_benchmark=args.cudnn_benchmark, 107 | local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed, 108 | use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv, 109 | use_wandb=args.use_wandb, env_num=args.env_num, 110 | distill=args.distill, script_teacher=args.script_teacher, config_teacher=args.config_teacher) 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | -------------------------------------------------------------------------------- /lib/train/train_script.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | 7 | from lib.models import * 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | 10 | from lib.train.actors import * 11 | from lib.train.data.base_functions import * 12 | from lib.train.trainers import LTRTrainer 13 | from lib.utils.load import load_yaml 14 | from lib.train.loss import lightTrackObjective 15 | 16 | 17 | def run(settings): 18 | settings.description = 'Training script' 19 | 20 | cfg = load_yaml(settings.cfg_file) 21 | print('CFG', cfg) 22 | update_settings(settings, cfg) 23 | 24 | # init seed 25 | random.seed(cfg.TRAIN.SEED) 26 | np.random.seed(cfg.TRAIN.SEED) 27 | torch.manual_seed(cfg.TRAIN.SEED) 28 | torch.cuda.manual_seed(cfg.TRAIN.SEED) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | # Record the training log 33 | log_dir = os.path.join(settings.save_dir, 'logs') 34 | if settings.local_rank in [-1, 0]: 35 | if not os.path.exists(log_dir): 36 | os.makedirs(log_dir) 37 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name)) 38 | 39 | # Build dataloaders 40 | loader_train, loader_val = build_dataloaders(cfg, settings) 41 | _, loader_val = build_dataloaders(cfg, settings) 42 | 43 | # Create network 44 | if settings.script_name == "lightfc": 45 | net = LightFC(cfg, env_num=settings.env_num, training=True) 46 | 47 | else: 48 | raise ValueError("illegal script name") 49 | 50 | # wrap networks to distributed one 51 | net.cuda() 52 | if settings.local_rank != -1: 53 | # net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter 54 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True) 55 | settings.device = torch.device("cuda:%d" % settings.local_rank) 56 | else: 57 | settings.device = torch.device("cuda:0") 58 | 59 | settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False) 60 | settings.distill = getattr(cfg.TRAIN, "DISTILL", False) 61 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL") 62 | 63 | # Actors 64 | if settings.script_name == "lightfc": 65 | objective = lightTrackObjective(cfg) 66 | loss_weight = {'iou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': cfg.TRAIN.LOC_WEIGHT, } 67 | actor = lightTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg) 68 | 69 | else: 70 | raise ValueError("illegal script name") 71 | 72 | # SWA 73 | settings.use_swa = getattr(cfg.TRAIN, 'USE_SWA', False) 74 | settings.swa_epoch = getattr(cfg.TRAIN, 'SWA_EPOCH', None) 75 | 76 | # Optimizer, parameters, and learning rates 77 | optimizer, lr_scheduler = 
get_optimizer_scheduler(net, cfg) 78 | use_amp = getattr(cfg.TRAIN, "AMP", False) 79 | trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp, ) 80 | 81 | # train process 82 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True) 83 | -------------------------------------------------------------------------------- /lib/train/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .ltr_trainer import LTRTrainer 3 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensor import TensorDict, TensorList 2 | -------------------------------------------------------------------------------- /lib/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_area 3 | import numpy as np 4 | 5 | 6 | def box_cxcywh_to_xyxy(x): 7 | x_c, y_c, w, h = x.unbind(-1) 8 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 9 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 10 | return torch.stack(b, dim=-1) 11 | 12 | 13 | def box_xywh_to_xyxy(x): 14 | x1, y1, w, h = x.unbind(-1) 15 | b = [x1, y1, x1 + w, y1 + h] 16 | return torch.stack(b, dim=-1) 17 | 18 | 19 | def box_xyxy_to_xywh(x): 20 | x1, y1, x2, y2 = x.unbind(-1) 21 | b = [x1, y1, x2 - x1, y2 - y1] 22 | return torch.stack(b, dim=-1) 23 | 24 | 25 | def box_xyxy_to_cxcywh(x): 26 | x0, y0, x1, y1 = x.unbind(-1) 27 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 28 | (x1 - x0), (y1 - y0)] 29 | return torch.stack(b, dim=-1) 30 | 31 | 32 | # modified from torchvision to also return the union 33 | '''Note that this function only supports shape (N,4)''' 34 | 35 | 36 | def box_iou(boxes1, boxes2): 37 | """ 38 | 39 | :param boxes1: (N, 4) (x1,y1,x2,y2) 40 | :param boxes2: (N, 4) (x1,y1,x2,y2) 41 | :return: 42 | """ 43 | area1 = box_area(boxes1) # (N,) 44 | area2 = box_area(boxes2) # (N,) 45 | 46 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2) 47 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2) 48 | 49 | wh = (rb - lt).clamp(min=0) # (N,2) 50 | inter = wh[:, 0] * wh[:, 1] # (N,) 51 | 52 | union = area1 + area2 - inter 53 | 54 | iou = inter / union 55 | return iou, union 56 | 57 | 58 | def clip_box(box: list, H, W, margin=0): 59 | x1, y1, w, h = box 60 | x2, y2 = x1 + w, y1 + h 61 | x1 = min(max(0, x1), W - margin) 62 | x2 = min(max(margin, x2), W) 63 | y1 = min(max(0, y1), H - margin) 64 | y2 = min(max(margin, y2), H) 65 | w = max(margin, x2 - x1) 66 | h = max(margin, y2 - y1) 67 | return [x1, y1, w, h] 68 | 69 | 70 | def batch_xywh2center(boxes): 71 | cx = boxes[:, 0] + (boxes[:, 2] - 1) / 2 72 | cy = boxes[:, 1] + (boxes[:, 3] - 1) / 2 73 | w = boxes[:, 2] 74 | h = boxes[:, 3] 75 | 76 | if isinstance(boxes, np.ndarray): 77 | return np.stack([cx, cy, w, h], 1) 78 | else: 79 | return torch.stack([cx, cy, w, h], 1) 80 | 81 | 82 | def batch_xywh2center2(boxes): 83 | cx = boxes[:, 0] + boxes[:, 2] / 2 84 | cy = boxes[:, 1] + boxes[:, 3] / 2 85 | w = boxes[:, 2] 86 | h = boxes[:, 3] 87 | 88 | if isinstance(boxes, np.ndarray): 89 | return np.stack([cx, cy, w, h], 1) 90 | else: 91 | return torch.stack([cx, cy, w, h], 1) 92 | 93 | def batch_xywh2corner(boxes): 94 | xmin = boxes[:, 0] 95 | ymin = boxes[:, 1] 96 | xmax = boxes[:, 0] + boxes[:, 2] 97 | ymax = boxes[:, 1] + boxes[:, 3] 98 | 99 | if 
isinstance(boxes, np.ndarray): 100 | return np.stack([xmin, ymin, xmax, ymax], 1) 101 | else: 102 | return torch.stack([xmin, ymin, xmax, ymax], 1) -------------------------------------------------------------------------------- /lib/utils/ce_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def generate_bbox_mask(bbox_mask, bbox): 8 | b, h, w = bbox_mask.shape 9 | for i in range(b): 10 | bbox_i = bbox[i].cpu().tolist() 11 | bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1 12 | return bbox_mask 13 | 14 | 15 | def generate_mask_cond(cfg, bs, device, gt_bbox): 16 | template_size = cfg.DATA.TEMPLATE.SIZE 17 | stride = cfg.MODEL.BACKBONE.STRIDE 18 | template_feat_size = template_size // stride 19 | 20 | if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL': 21 | box_mask_z = None 22 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT': 23 | if template_feat_size == 8: 24 | index = slice(3, 4) 25 | elif template_feat_size == 12: 26 | index = slice(5, 6) 27 | elif template_feat_size == 7: 28 | index = slice(3, 4) 29 | elif template_feat_size == 14: 30 | index = slice(6, 7) 31 | else: 32 | raise NotImplementedError 33 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 34 | box_mask_z[:, index, index] = 1 35 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 36 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC': 37 | # use fixed 4x4 region, 3:5 for 8x8 38 | # use fixed 4x4 region 5:6 for 12x12 39 | if template_feat_size == 8: 40 | index = slice(3, 5) 41 | elif template_feat_size == 12: 42 | index = slice(5, 7) 43 | elif template_feat_size == 7: 44 | index = slice(3, 4) 45 | else: 46 | raise NotImplementedError 47 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device) 48 | box_mask_z[:, index, index] = 1 49 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 50 | 51 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX': 52 | box_mask_z = torch.zeros([bs, template_size, template_size], device=device) 53 | # box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128) 54 | box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to( 55 | torch.float) # (batch, 1, 128, 128) 56 | # box_mask_z_vis = box_mask_z.cpu().numpy() 57 | box_mask_z = F.interpolate(box_mask_z, scale_factor=1. 
/ cfg.MODEL.BACKBONE.STRIDE, mode='bilinear', 58 | align_corners=False) 59 | box_mask_z = box_mask_z.flatten(1).to(torch.bool) 60 | # box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy() 61 | # gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy() 62 | else: 63 | raise NotImplementedError 64 | 65 | return box_mask_z 66 | 67 | 68 | def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1): 69 | if epoch < warmup_epochs: 70 | return 1 71 | if epoch >= total_epochs: 72 | return base_keep_rate 73 | if iters == -1: 74 | iters = epoch * ITERS_PER_EPOCH 75 | total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs) 76 | iters = iters - ITERS_PER_EPOCH * warmup_epochs 77 | keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \ 78 | * (math.cos(iters / total_iters * math.pi) + 1) * 0.5 79 | 80 | return keep_rate 81 | -------------------------------------------------------------------------------- /lib/utils/heapmap_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def generate_heatmap(bboxes, patch_size=320, stride=16, heatmap_size=None): 6 | """ 7 | Generate ground truth heatmap same as CenterNet 8 | Args: 9 | bboxes (torch.Tensor): shape of [num_search, bs, 4] 10 | 11 | Returns: 12 | gaussian_maps: list of generated heatmap 13 | 14 | """ 15 | gaussian_maps = [] 16 | if heatmap_size is None: 17 | heatmap_size = patch_size // stride 18 | for single_patch_bboxes in bboxes: 19 | bs = single_patch_bboxes.shape[0] 20 | gt_scoremap = torch.zeros(bs, heatmap_size, heatmap_size) 21 | classes = torch.arange(bs).to(torch.long) 22 | bbox = single_patch_bboxes * heatmap_size 23 | wh = bbox[:, 2:] 24 | centers_int = (bbox[:, :2] + wh / 2).round() 25 | CenterNetHeatMap.generate_score_map(gt_scoremap, classes, wh, centers_int, 0.7) 26 | gaussian_maps.append(gt_scoremap.to(bbox.device)) 27 | return gaussian_maps 28 | 29 | 30 | class CenterNetHeatMap(object): 31 | @staticmethod 32 | def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap): 33 | radius = CenterNetHeatMap.get_gaussian_radius(gt_wh, min_overlap) 34 | radius = torch.clamp_min(radius, 0) 35 | radius = radius.type(torch.int).cpu().numpy() 36 | for i in range(gt_class.shape[0]): 37 | channel_index = gt_class[i] 38 | CenterNetHeatMap.draw_gaussian(fmap[channel_index], centers_int[i], radius[i]) 39 | 40 | @staticmethod 41 | def get_gaussian_radius(box_size, min_overlap): 42 | """ 43 | copyed from CornerNet 44 | box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple 45 | notice: we are using a bug-version, please refer to fix bug version in CornerNet 46 | """ 47 | # box_tensor = torch.Tensor(box_size) 48 | box_tensor = box_size 49 | width, height = box_tensor[..., 0], box_tensor[..., 1] 50 | 51 | a1 = 1 52 | b1 = height + width 53 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 54 | sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1) 55 | r1 = (b1 + sq1) / 2 56 | 57 | a2 = 4 58 | b2 = 2 * (height + width) 59 | c2 = (1 - min_overlap) * width * height 60 | sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2) 61 | r2 = (b2 + sq2) / 2 62 | 63 | a3 = 4 * min_overlap 64 | b3 = -2 * min_overlap * (height + width) 65 | c3 = (min_overlap - 1) * width * height 66 | sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3) 67 | r3 = (b3 + sq3) / 2 68 | 69 | return torch.min(r1, torch.min(r2, r3)) 70 | 71 | @staticmethod 72 
| def gaussian2D(radius, sigma=1): 73 | # m, n = [(s - 1.) / 2. for s in shape] 74 | m, n = radius 75 | y, x = np.ogrid[-m: m + 1, -n: n + 1] 76 | 77 | gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 78 | gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0 79 | return gauss 80 | 81 | @staticmethod 82 | def draw_gaussian(fmap, center, radius, k=1): 83 | diameter = 2 * radius + 1 84 | gaussian = CenterNetHeatMap.gaussian2D((radius, radius), sigma=diameter / 6) 85 | gaussian = torch.Tensor(gaussian) 86 | x, y = int(center[0]), int(center[1]) 87 | height, width = fmap.shape[:2] 88 | 89 | left, right = min(x, radius), min(width - x, radius + 1) 90 | top, bottom = min(y, radius), min(height - y, radius + 1) 91 | 92 | masked_fmap = fmap[y - top: y + bottom, x - left: x + right] 93 | masked_gaussian = gaussian[radius - top: radius + bottom, radius - left: radius + right] 94 | if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0: 95 | masked_fmap = torch.max(masked_fmap, masked_gaussian * k) 96 | fmap[y - top: y + bottom, x - left: x + right] = masked_fmap 97 | # return fmap 98 | 99 | 100 | def compute_grids(features, strides): 101 | """ 102 | grids regret to the input image size 103 | """ 104 | grids = [] 105 | for level, feature in enumerate(features): 106 | h, w = feature.size()[-2:] 107 | shifts_x = torch.arange( 108 | 0, w * strides[level], 109 | step=strides[level], 110 | dtype=torch.float32, device=feature.device) 111 | shifts_y = torch.arange( 112 | 0, h * strides[level], 113 | step=strides[level], 114 | dtype=torch.float32, device=feature.device) 115 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 116 | shift_x = shift_x.reshape(-1) 117 | shift_y = shift_y.reshape(-1) 118 | grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \ 119 | strides[level] // 2 120 | grids.append(grids_per_level) 121 | return grids 122 | 123 | 124 | def get_center3x3(locations, centers, strides, range=3): 125 | ''' 126 | Inputs: 127 | locations: M x 2 128 | centers: N x 2 129 | strides: M 130 | ''' 131 | range = (range - 1) / 2 132 | M, N = locations.shape[0], centers.shape[0] 133 | locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2 134 | centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2 135 | strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N 136 | centers_discret = ((centers_expanded / strides_expanded).int() * strides_expanded).float() + \ 137 | strides_expanded / 2 # M x N x 2 138 | dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs() 139 | dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs() 140 | return (dist_x <= strides_expanded[:, :, 0] * range) & \ 141 | (dist_y <= strides_expanded[:, :, 0] * range) 142 | 143 | 144 | def get_pred(score_map_ctr, size_map, offset_map, feat_size): 145 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True) 146 | 147 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1) 148 | size = size_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 149 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1) 150 | 151 | return size * feat_size, offset 152 | -------------------------------------------------------------------------------- /lib/utils/list_tools.py: -------------------------------------------------------------------------------- 1 | def split_list(obj_lisat, node=4): 2 | num = int(len(obj_lisat) / node) + 1 3 | sublist = [] 4 | for i in range(node): 5 | sublist.append(obj_lisat[i * num:(i + 1) * num]) 6 | 
return sublist 7 | -------------------------------------------------------------------------------- /lib/utils/lmdb_utils.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import numpy as np 3 | import cv2 4 | import json 5 | 6 | LMDB_ENVS = dict() 7 | LMDB_HANDLES = dict() 8 | LMDB_FILELISTS = dict() 9 | 10 | 11 | def get_lmdb_handle(name): 12 | global LMDB_HANDLES, LMDB_FILELISTS 13 | item = LMDB_HANDLES.get(name, None) 14 | if item is None: 15 | env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False) 16 | LMDB_ENVS[name] = env 17 | item = env.begin(write=False) 18 | LMDB_HANDLES[name] = item 19 | 20 | return item 21 | 22 | 23 | def decode_img(lmdb_fname, key_name): 24 | handle = get_lmdb_handle(lmdb_fname) 25 | binfile = handle.get(key_name.encode()) 26 | if binfile is None: 27 | print("Illegal data detected. %s %s" % (lmdb_fname, key_name)) 28 | s = np.frombuffer(binfile, np.uint8) 29 | x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 30 | return x 31 | 32 | 33 | def decode_str(lmdb_fname, key_name): 34 | handle = get_lmdb_handle(lmdb_fname) 35 | binfile = handle.get(key_name.encode()) 36 | string = binfile.decode() 37 | return string 38 | 39 | 40 | def decode_json(lmdb_fname, key_name): 41 | return json.loads(decode_str(lmdb_fname, key_name)) 42 | 43 | 44 | if __name__ == "__main__": 45 | lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb" 46 | '''Decode image''' 47 | # key_name = "test/GOT-10k_Test_000001/00000001.jpg" 48 | # img = decode_img(lmdb_fname, key_name) 49 | # cv2.imwrite("001.jpg", img) 50 | '''Decode str''' 51 | # key_name = "test/list.txt" 52 | # key_name = "train/GOT-10k_Train_000001/groundtruth.txt" 53 | key_name = "train/GOT-10k_Train_000001/absence.label" 54 | str_ = decode_str(lmdb_fname, key_name) 55 | print(str_) 56 | -------------------------------------------------------------------------------- /lib/utils/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | from easydict import EasyDict as edict 5 | from lib.train.admin.local import EnvironmentSettings as env 6 | 7 | 8 | def load_pretrain(backbone, env_num=0, training=True, mode=1, cfg=None): 9 | pretrained_path = env(env_num=env_num).pretrained_networks 10 | if cfg.MODEL.BACKBONE.PRETRAIN_FILE is not None and training: 11 | pretrained = os.path.join(pretrained_path, cfg.MODEL.BACKBONE.PRETRAIN_FILE) 12 | else: 13 | pretrained = '' 14 | 15 | if training and cfg.MODEL.BACKBONE.USE_PRETRAINED: 16 | print(f'Try Loading Pretrained Model, using mode {mode}') 17 | try: 18 | 19 | if mode == 1: 20 | checkpoint = torch.load(pretrained, map_location="cpu") 21 | elif mode == 2: 22 | checkpoint = torch.load(pretrained, map_location="cpu")['state_dict'] 23 | elif mode == 3: 24 | checkpoint = torch.load(pretrained, map_location="cpu")['model'] 25 | else: 26 | raise 27 | missing_keys, unexpected_keys = backbone.load_state_dict(checkpoint, strict=True) 28 | print('Load pretrained model from: ' + pretrained) 29 | except Exception as e: 30 | print(e) 31 | print('Loading Finish ....') 32 | 33 | 34 | def load_yaml(yaml_file): 35 | with open(yaml_file) as f: 36 | exp_config = yaml.safe_load(f) 37 | exp_config = edict(exp_config) 38 | return exp_config 39 | -------------------------------------------------------------------------------- /lib/utils/merge.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def merge_template_search(inp_list, return_search=False, return_template=False): 5 | """NOTICE: search region related features must be in the last place""" 6 | seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 7 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 8 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} 9 | if return_search: 10 | x = inp_list[-1] 11 | seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]}) 12 | if return_template: 13 | z = inp_list[0] 14 | seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]}) 15 | return seq_dict 16 | 17 | 18 | def get_qkv(inp_list): 19 | """The 1st element of the inp_list is about the template, 20 | the 2nd (the last) element is about the search region""" 21 | dict_x = inp_list[-1] 22 | dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0), 23 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1), 24 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict 25 | q = dict_x["feat"] + dict_x["pos"] 26 | k = dict_c["feat"] + dict_c["pos"] 27 | v = dict_c["feat"] 28 | key_padding_mask = dict_c["mask"] 29 | return q, k, v, key_padding_mask 30 | 31 | 32 | def merge_feature_sequence(inp_list): 33 | # Used for AiaTrack 34 | return {'feat': torch.cat([x['feat'] for x in inp_list], dim=0), 35 | 'mask': torch.cat([x['mask'] for x in inp_list], dim=1), 36 | 'pos': torch.cat([x['pos'] for x in inp_list], dim=0), 37 | 'inr': torch.cat([x['inr'] for x in inp_list], dim=0)} 38 | -------------------------------------------------------------------------------- /lib/utils/registry.py: -------------------------------------------------------------------------------- 1 | from lib.models import * 2 | class Registry(object): 3 | """ 4 | The registry that provides name -> object mapping, to support third-party 5 | users' custom modules. 6 | To create a registry (e.g. a backbone registry): 7 | .. code-block:: python 8 | BACKBONE_REGISTRY = Registry('BACKBONE') 9 | To register an object: 10 | .. code-block:: python 11 | @BACKBONE_REGISTRY.register() 12 | class MyBackbone(): 13 | ... 14 | Or: 15 | .. code-block:: python 16 | BACKBONE_REGISTRY.register(MyBackbone) 17 | """ 18 | 19 | def __init__(self, name): 20 | """ 21 | Args: 22 | name (str): the name of this registry 23 | """ 24 | self._name = name 25 | self._obj_map = {} 26 | 27 | def _do_register(self, name, obj, suffix=None): 28 | if isinstance(suffix, str): 29 | name = name + '_' + suffix 30 | 31 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 32 | f"in '{self._name}' registry!") 33 | self._obj_map[name] = obj 34 | 35 | def register(self, obj=None, suffix=None): 36 | """ 37 | Register the given object under the the name `obj.__name__`. 38 | Can be used as either a decorator or not. 39 | See docstring of this class for usage. 
40 | """ 41 | if obj is None: 42 | # used as a decorator 43 | def deco(func_or_class): 44 | name = func_or_class.__name__ 45 | self._do_register(name, func_or_class, suffix) 46 | return func_or_class 47 | 48 | return deco 49 | 50 | # used as a function call 51 | name = obj.__name__ 52 | self._do_register(name, obj, suffix) 53 | 54 | def get(self, name, suffix='basicsr'): 55 | ret = self._obj_map.get(name) 56 | if ret is None: 57 | ret = self._obj_map.get(name + '_' + suffix) 58 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 59 | if ret is None: 60 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 61 | return ret 62 | 63 | def __contains__(self, name): 64 | return name in self._obj_map 65 | 66 | def __iter__(self): 67 | return iter(self._obj_map.items()) 68 | 69 | def keys(self): 70 | return self._obj_map.keys() 71 | 72 | 73 | MODEL_REGISTRY = Registry('model') 74 | TRACKER_REGISTRY = Registry('tracker') 75 | -------------------------------------------------------------------------------- /lib/utils/variable_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from bytecode import Bytecode, Instr 3 | 4 | 5 | class get_local(object): 6 | cache = {} 7 | is_activate = False 8 | 9 | def __init__(self, varname): 10 | self.varname = varname 11 | 12 | def __call__(self, func): 13 | if not type(self).is_activate: 14 | return func 15 | 16 | type(self).cache[func.__qualname__] = [] 17 | c = Bytecode.from_code(func.__code__) 18 | extra_code = [ 19 | Instr('STORE_FAST', '_res'), 20 | Instr('LOAD_FAST', self.varname), 21 | Instr('STORE_FAST', '_value'), 22 | Instr('LOAD_FAST', '_res'), 23 | Instr('LOAD_FAST', '_value'), 24 | Instr('BUILD_TUPLE', 2), 25 | Instr('STORE_FAST', '_result_tuple'), 26 | Instr('LOAD_FAST', '_result_tuple'), 27 | ] 28 | c[-1:-1] = extra_code 29 | func.__code__ = c.to_code() 30 | 31 | def wrapper(*args, **kwargs): 32 | res, values = func(*args, **kwargs) 33 | if isinstance(values, torch.Tensor): 34 | type(self).cache[func.__qualname__].append(values.detach().cpu().numpy()) 35 | elif isinstance(values, list): # list of Tensor 36 | type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values]) 37 | else: 38 | raise NotImplementedError 39 | return res 40 | 41 | return wrapper 42 | 43 | @classmethod 44 | def clear(cls): 45 | for key in cls.cache.keys(): 46 | cls.cache[key] = [] 47 | 48 | @classmethod 49 | def activate(cls): 50 | cls.is_activate = True 51 | -------------------------------------------------------------------------------- /lib/vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiYunfengLYF/LightFC/97dc3405ec8e8c5ad3d3ad95cae7f12e4f17b5b0/lib/vis/__init__.py -------------------------------------------------------------------------------- /lib/vis/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import cv2 5 | 6 | 7 | def draw_figure(fig): 8 | fig.canvas.draw() 9 | fig.canvas.flush_events() 10 | plt.pause(0.001) 11 | 12 | 13 | def show_tensor(a: torch.Tensor, fig_num = None, title = None, range=(None, None), ax=None): 14 | """Display a 2D tensor. 15 | args: 16 | fig_num: Figure number. 17 | title: Title of figure. 
18 | """ 19 | a_np = a.squeeze().cpu().clone().detach().numpy() 20 | if a_np.ndim == 3: 21 | a_np = np.transpose(a_np, (1, 2, 0)) 22 | 23 | if ax is None: 24 | fig = plt.figure(fig_num) 25 | plt.tight_layout() 26 | plt.cla() 27 | plt.imshow(a_np, vmin=range[0], vmax=range[1]) 28 | plt.axis('off') 29 | plt.axis('equal') 30 | if title is not None: 31 | plt.title(title) 32 | draw_figure(fig) 33 | else: 34 | ax.cla() 35 | ax.imshow(a_np, vmin=range[0], vmax=range[1]) 36 | ax.set_axis_off() 37 | ax.axis('equal') 38 | if title is not None: 39 | ax.set_title(title) 40 | draw_figure(plt.gcf()) 41 | 42 | 43 | def plot_graph(a: torch.Tensor, fig_num = None, title = None): 44 | """Plot graph. Data is a 1D tensor. 45 | args: 46 | fig_num: Figure number. 47 | title: Title of figure. 48 | """ 49 | a_np = a.squeeze().cpu().clone().detach().numpy() 50 | if a_np.ndim > 1: 51 | raise ValueError 52 | fig = plt.figure(fig_num) 53 | # plt.tight_layout() 54 | plt.cla() 55 | plt.plot(a_np) 56 | if title is not None: 57 | plt.title(title) 58 | draw_figure(fig) 59 | 60 | 61 | def show_image_with_boxes(im, boxes, iou_pred=None, disp_ids=None): 62 | im_np = im.clone().cpu().squeeze().numpy() 63 | im_np = np.ascontiguousarray(im_np.transpose(1, 2, 0).astype(np.uint8)) 64 | 65 | boxes = boxes.view(-1, 4).cpu().numpy().round().astype(int) 66 | 67 | # Draw proposals 68 | for i_ in range(boxes.shape[0]): 69 | if disp_ids is None or disp_ids[i_]: 70 | bb = boxes[i_, :] 71 | disp_color = (i_*38 % 256, (255 - i_*97) % 256, (123 + i_*66) % 256) 72 | cv2.rectangle(im_np, (bb[0], bb[1]), (bb[0] + bb[2], bb[1] + bb[3]), 73 | disp_color, 1) 74 | 75 | if iou_pred is not None: 76 | text_pos = (bb[0], bb[1] - 5) 77 | cv2.putText(im_np, 'ID={} IOU = {:3.2f}'.format(i_, iou_pred[i_]), text_pos, 78 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, bottomLeftOrigin=False) 79 | 80 | im_tensor = torch.from_numpy(im_np.transpose(2, 0, 1)).float() 81 | 82 | return im_tensor 83 | 84 | 85 | 86 | def _pascal_color_map(N=256, normalized=False): 87 | """ 88 | Python implementation of the color map function for the PASCAL VOC data set. 89 | Official Matlab version can be found in the PASCAL VOC devkit 90 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 91 | """ 92 | 93 | def bitget(byteval, idx): 94 | return (byteval & (1 << idx)) != 0 95 | 96 | dtype = 'float32' if normalized else 'uint8' 97 | cmap = np.zeros((N, 3), dtype=dtype) 98 | for i in range(N): 99 | r = g = b = 0 100 | c = i 101 | for j in range(8): 102 | r = r | (bitget(c, 0) << 7 - j) 103 | g = g | (bitget(c, 1) << 7 - j) 104 | b = b | (bitget(c, 2) << 7 - j) 105 | c = c >> 3 106 | 107 | cmap[i] = np.array([r, g, b]) 108 | 109 | cmap = cmap / 255 if normalized else cmap 110 | return cmap 111 | 112 | 113 | def overlay_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 114 | """ Overlay mask over image. 115 | Source: https://github.com/albertomontesg/davis-interactive/blob/master/davisinteractive/utils/visualization.py 116 | This function allows you to overlay a mask over an image with some 117 | transparency. 118 | # Arguments 119 | im: Numpy Array. Array with the image. The shape must be (H, W, 3) and 120 | the pixels must be represented as `np.uint8` data type. 121 | ann: Numpy Array. Array with the mask. The shape must be (H, W) and the 122 | values must be intergers 123 | alpha: Float. Proportion of alpha to apply at the overlaid mask. 124 | colors: Numpy Array. Optional custom colormap. 
It must have shape (N, 3) 125 | where N is the maximum number of colors to represent. 126 | contour_thickness: Integer. Thickness of each object index contour drawn 127 | over the overlay. This function requires the 128 | `opencv-python` package to be installed. 129 | # Returns 130 | Numpy Array: Image of the overlay with shape (H, W, 3) and data type 131 | `np.uint8`. 132 | """ 133 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int) 134 | if im.shape[:-1] != ann.shape: 135 | raise ValueError('First two dimensions of `im` and `ann` must match') 136 | if im.shape[-1] != 3: 137 | raise ValueError('im must have three channels in the third dimension') 138 | 139 | colors = _pascal_color_map() if colors is None else colors 140 | colors = np.asarray(colors, dtype=np.uint8) 141 | 142 | mask = colors[ann] 143 | fg = im * alpha + (1 - alpha) * mask 144 | 145 | img = im.copy() 146 | img[ann > 0] = fg[ann > 0] 147 | 148 | if contour_thickness: # pragma: no cover 149 | import cv2 150 | for obj_id in np.unique(ann[ann > 0]): 151 | contours = cv2.findContours((ann == obj_id).astype( 152 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 153 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 154 | contour_thickness) 155 | return img 156 | -------------------------------------------------------------------------------- /lib/vis/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def numpy_to_torch(a: np.ndarray): 6 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) -------------------------------------------------------------------------------- /tracking/analysis_results.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.rcParams['figure.figsize'] = [8, 8] 4 | env_num = 0 5 | from lib.test.analysis.plot_results import print_results 6 | from lib.test.evaluation import get_dataset, trackerlist 7 | 8 | trackers = [] 9 | dataset_name = 'utb' 10 | 11 | 12 | parameter_name = r'mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou' 13 | trackers.extend( 14 | trackerlist(name='lightfc', parameter_name=parameter_name, dataset_name=dataset_name, 15 | run_ids=None, env_num=env_num, display_name=parameter_name)) 16 | 17 | 18 | dataset = get_dataset(dataset_name, env_num=env_num) 19 | print_results(trackers, dataset, dataset_name, merge_results=True, plot_types=('success', 'norm_prec', 'prec'), 20 | env_num=env_num) 21 | 22 | -------------------------------------------------------------------------------- /tracking/profile_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | from thop import profile, clever_format 6 | 7 | from lib.models.lightfc import MobileNetV2, repn33_se_center_concat 8 | from lib.models.lightfc.fusion.ecm import pwcorr_se_repn31_sc_iab_sc_adj_concat 9 | 10 | 11 | class lightTrack_track(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | self.backbone = MobileNetV2() 16 | self.fusion = pwcorr_se_repn31_sc_iab_sc_adj_concat() 17 | self.head = repn33_se_center_concat(inplanes=192, channel=256) 18 | 19 | for module in self.backbone.modules(): 20 | if hasattr(module, 'switch_to_deploy'): 21 | module.switch_to_deploy() 22 | for module in self.fusion.modules(): 23 | if hasattr(module, 'switch_to_deploy'): 24 | module.switch_to_deploy() 25 | for module
in self.head.modules(): 26 | if hasattr(module, 'switch_to_deploy'): 27 | module.switch_to_deploy() 28 | 29 | def forward(self, z, x): 30 | x = self.backbone(x) 31 | opt = self.fusion(z, x) 32 | out = self.head(opt) 33 | 34 | return out 35 | 36 | 37 | 38 | model = lightTrack_track().cuda().eval() 39 | 40 | if __name__ == '__main__': 41 | z_feat = torch.rand(1, 96, 8, 8).cuda() 42 | x = torch.rand(1, 3, 256, 256).cuda() 43 | macs, params = profile(model, inputs=(z_feat, x), custom_ops=None, verbose=False) 44 | macs, params = clever_format([macs, params], "%.3f") 45 | print('overall macs is ', macs) 46 | print('overall params is ', params) 47 | -------------------------------------------------------------------------------- /tracking/speed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from torch import nn 5 | 6 | from lib.models.lightfc import MobileNetV2, repn33_se_center_concat 7 | from lib.models.lightfc.fusion.ecm import pwcorr_se_repn31_sc_iab_sc_adj_concat 8 | 9 | 10 | class lightTrack_track(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | self.backbone = MobileNetV2() 15 | self.fusion = pwcorr_se_repn31_sc_iab_sc_adj_concat() 16 | self.head = repn33_se_center_concat(inplanes=192, channel=256) 17 | 18 | for module in self.backbone.modules(): 19 | if hasattr(module, 'switch_to_deploy'): 20 | module.switch_to_deploy() 21 | for module in self.fusion.modules(): 22 | if hasattr(module, 'switch_to_deploy'): 23 | module.switch_to_deploy() 24 | for module in self.head.modules(): 25 | if hasattr(module, 'switch_to_deploy'): 26 | module.switch_to_deploy() 27 | 28 | def forward(self, z, x): 29 | x = self.backbone(x) 30 | opt = self.fusion(z, x) 31 | out = self.head(opt) 32 | 33 | return out 34 | 35 | 36 | if __name__ == "__main__": 37 | # test the running speed 38 | 39 | use_gpu = True 40 | z_feat = torch.rand(1, 96, 8, 8).cuda() 41 | x = torch.rand(1, 3, 256, 256).cuda() 42 | 43 | if use_gpu: 44 | model = lightTrack_track().cuda() 45 | x = x.cuda() 46 | z_feat = z_feat.cuda() 47 | # oup = model(x, zf) 48 | 49 | T_w = 10 # warmup 50 | T_t = 100 # test 51 | with torch.no_grad(): 52 | for i in range(T_w): 53 | oup = model(z_feat, x) 54 | t_s = time.time() 55 | for i in range(T_t): 56 | oup = model(z_feat, x) 57 | t_e = time.time() 58 | print('speed: %.2f FPS' % (T_t / (t_e - t_s))) 59 | -------------------------------------------------------------------------------- /tracking/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | import vot 6 | 7 | prj_path = os.path.join(os.path.dirname(__file__), '..') 8 | if prj_path not in sys.path: 9 | sys.path.append(prj_path) 10 | 11 | from lib.test.evaluation import get_dataset 12 | from lib.test.evaluation.running import run_dataset 13 | from lib.test.evaluation.tracker import Tracker 14 | 15 | 16 | def run_tracker(tracker_name, tracker_param, run_id=None, dataset_name='otb', sequence=None, debug=0, threads=0, 17 | num_gpus=8, env_num=0): 18 | """Run tracker on sequence or dataset. 19 | args: 20 | tracker_name: Name of tracking method. 21 | tracker_param: Name of parameter file. 22 | run_id: The run id. 23 | dataset_name: Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot). 24 | sequence: Sequence number or name. 25 | debug: Debug level. 26 | threads: Number of threads. 
27 | """ 28 | 29 | dataset = get_dataset(dataset_name, env_num=env_num) 30 | 31 | if sequence is not None: 32 | dataset = [dataset[sequence]] 33 | 34 | trackers = [Tracker(tracker_name, tracker_param, dataset_name, run_id, env_num, deploy=False)] 35 | run_dataset(dataset, trackers, debug, threads, num_gpus=num_gpus, env_num=env_num) 36 | 37 | 38 | def for_test(): 39 | parser = argparse.ArgumentParser(description='Run tracker on sequence or dataset.') 40 | parser.add_argument('--runid', type=int, default=None, help='The run id.') 41 | parser.add_argument('--sequence', type=str, default=None, help='Sequence number or name.') 42 | # parser.add_argument('--debug', type=int, default=1, help='Debug level.') 43 | parser.add_argument('--debug', type=int, default=0, help='Debug level.') 44 | # parser.add_argument('--threads', type=int, default=0, help='Number of threads.') 45 | parser.add_argument('--threads', type=int, default=0, help='Number of threads.') 46 | parser.add_argument('--num_gpus', type=int, default=1) 47 | parser.add_argument('--env_num', type=int, default=0, help='Use for multi environment developing, support: 0,1,2') 48 | parser.add_argument('--deploy', type=int, default=False, help='') 49 | args = parser.parse_args() 50 | 51 | try: 52 | seq_name = int(args.sequence) 53 | except: 54 | seq_name = args.sequence 55 | 56 | tracker_list = [ 57 | # {'name': 'lightfc', 58 | # 'param': 'mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou' 59 | # }, 60 | {'name': 'lightfc', 61 | 'param': 'baseline_v1_release_backbone_tinyvit' 62 | }, 63 | ] 64 | 65 | dataset_list = [ 66 | {'name': 'otb'}, 67 | # {'name': 'uav'}, 68 | # {'name': 'lasot'}, 69 | # {'name': 'tc128'}, 70 | # {'name': 'utb'}, 71 | # {'name': 'uot'}, 72 | # {'name': 'tnl2k'}, 73 | # {'name': 'trackingnet'}, 74 | ] 75 | 76 | for dataset_item in dataset_list: 77 | for tracker_item in tracker_list: 78 | trk_name, trk_param, data_name = tracker_item['name'], tracker_item['param'], dataset_item['name'] 79 | run_tracker(trk_name, trk_param, args.runid, data_name, seq_name, args.debug, args.threads, 80 | num_gpus=args.num_gpus, env_num=args.env_num) 81 | print(F'Tracker {trk_name} {trk_param} on {data_name} dataset is OK! \n\n') 82 | 83 | 84 | if __name__ == '__main__': 85 | for_test() 86 | -------------------------------------------------------------------------------- /tracking/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | 6 | def parse_args(): 7 | """ 8 | args for training. 
9 | """ 10 | parser = argparse.ArgumentParser(description='Parse args for training') 11 | # for train 12 | parser.add_argument('--script', type=str, default='emptytrack', help='training script name') 13 | parser.add_argument('--config', type=str, default='detection_datasets', help='yaml configure file name') 14 | parser.add_argument('--save_dir', type=str, help='root directory to save checkpoints, logs, and tensorboard') 15 | parser.add_argument('--mode', type=str, choices=["single", "multiple", "multi_node"], default="single", 16 | help="train on single gpu or multiple gpus") 17 | parser.add_argument('--nproc_per_node', type=int, help="number of GPUs per node") # specify when mode is multiple 18 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format 19 | parser.add_argument('--script_prv', type=str, help='training script name') 20 | parser.add_argument('--config_prv', type=str, default='baseline', help='yaml configure file name') 21 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb 22 | parser.add_argument('--env_num', type=int, default=0, 23 | help='Use for multi environment developing, support: 0,1,2') 24 | 25 | # for knowledge distillation 26 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation 27 | parser.add_argument('--script_teacher', type=str, help='teacher script name') 28 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name') 29 | 30 | # for multiple machines 31 | parser.add_argument('--rank', type=int, help='Rank of the current process.') 32 | parser.add_argument('--world-size', type=int, help='Number of processes participating in the job.') 33 | parser.add_argument('--ip', type=str, default='127.0.0.1', help='IP of the current rank 0.') 34 | parser.add_argument('--port', type=int, default='20000', help='Port of the current rank 0.') 35 | 36 | args = parser.parse_args() 37 | 38 | return args 39 | 40 | 41 | def main(): 42 | args = parse_args() 43 | if args.mode == "single": 44 | train_cmd = "python lib/train/run_training.py --script %s --c onfig %s --save_dir %s " \ 45 | "--use_lmdb %d --script_prv %s --config_prv %s --distill %d --script_teacher %s " \ 46 | "--config_teacher %s --use_wandb %d --env_num %d" \ 47 | % (args.script, args.config, args.save_dir, args.use_lmdb, args.script_prv, args.config_prv, 48 | args.distill, args.script_teacher, args.config_teacher, args.use_wandb, args.env_num) 49 | 50 | elif args.mode == "multiple": 51 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_port %d " \ 52 | "lib/train/run_training.py --script %s --config %s --save_dir %s --use_lmdb %d " \ 53 | "--script_prv %s --config_prv %s --use_wandb %d --env_num %d " \ 54 | "--distill %d --script_teacher %s --config_teacher %s" \ 55 | % (args.nproc_per_node, random.randint(10000, 50000), args.script, args.config, args.save_dir, 56 | args.use_lmdb, args.script_prv, args.config_prv, args.use_wandb, args.env_num, 57 | args.distill, args.script_teacher, args.config_teacher,) 58 | 59 | elif args.mode == "multi_node": 60 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_addr %s --master_port %d " \ 61 | "--nnodes %d --node_rank %d lib/train/run_training.py --script %s --config %s --save_dir %s " \ 62 | "--use_lmdb %d --script_prv %s --config_prv %s --use_wandb %d --env_num %d --distill %d " \ 63 | "--script_teacher %s --config_teacher %s" \ 64 | % 
(args.nproc_per_node, args.ip, args.port, args.world_size, args.rank, args.script, args.config, 65 | args.save_dir, args.use_lmdb, args.script_prv, args.config_prv, args.use_wandb, args.env_num, 66 | args.distill, args.script_teacher, args.config_teacher) 67 | else: 68 | raise ValueError("mode should be 'single', 'multiple' or 'multi_node'.") 69 | os.system(train_cmd) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | --------------------------------------------------------------------------------
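A minimal launch sketch, assuming the lightfc training script and the released config from experiments/lightfc; it is not a file of the repository, and the save_dir value is only an illustrative placeholder. It mirrors the single-GPU command string that tracking/train.py assembles for --mode single and passes to lib/train/run_training.py.

```
# Hypothetical launch sketch (not part of the repo): issues the same single-GPU
# training command that tracking/train.py builds when --mode single is used.
import subprocess

script = 'lightfc'
config = 'mobilnetv2_p_pwcorr_se_scf_sc_iab_sc_adj_concat_repn33_se_conv33_center_wiou'
save_dir = './output'  # assumed placeholder; point it at your own checkpoint/log directory

subprocess.run(['python', 'lib/train/run_training.py',
                '--script', script,
                '--config', config,
                '--save_dir', save_dir],
               check=True)  # raises CalledProcessError if training exits with a non-zero code
```

A similar run can be started from the command line with `python tracking/train.py --script lightfc --config <the same config name> --save_dir ./output --mode single`, which builds the corresponding command through main() in tracking/train.py.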