├── .gitignore ├── INSTALL.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── TRAIN.md ├── experiments ├── siamfc_alex_upxcorr_otb │ ├── config.yaml │ └── convert_model.py ├── siamfc_alex_upxcorr_vot │ ├── config.yaml │ └── convert_model.py ├── siammask_r50_l3 │ ├── config.yaml │ └── convert_model.py ├── siamrpn_alex_dwxcorr_otb │ └── config.yaml ├── siamrpn_alex_dwxcorr_vot │ └── config.yaml ├── siamrpn_r50_l234_dwxcorr │ └── config.yaml ├── siamrpn_r50_l234_dwxcorr_otb │ └── config.yaml ├── siamrpn_r50_l234_dwxcorr_vot │ └── config.yaml └── siamrpn_r50_l234_dwxcorr_votlt │ └── config.yaml ├── install.sh ├── pysot ├── __init__.py ├── core │ ├── __init__.py │ ├── config.py │ └── xcorr.py ├── datasets │ ├── __init__.py │ ├── anchor_target.py │ ├── augmentation.py │ └── dataset.py ├── models │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── mobile_v2.py │ │ └── resnet_atrous.py │ ├── head │ │ ├── __init__.py │ │ ├── mask.py │ │ └── rpn.py │ ├── init_weight.py │ ├── loss.py │ ├── model_builder.py │ └── neck │ │ ├── __init__.py │ │ └── neck.py ├── tracker │ ├── __init__.py │ ├── base_tracker.py │ ├── classifier │ │ ├── __init__.py │ │ ├── base_classifier.py │ │ ├── libs │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── augmentation.py │ │ │ ├── complex.py │ │ │ ├── dcf.py │ │ │ ├── fourier.py │ │ │ ├── operation.py │ │ │ ├── optimization.py │ │ │ ├── params.py │ │ │ ├── plotting.py │ │ │ ├── preprocessing.py │ │ │ ├── tensordict.py │ │ │ └── tensorlist.py │ │ └── optim.py │ ├── siamfc_tracker.py │ ├── siammask_tracker.py │ ├── siamrpn_tracker.py │ ├── siamrpnlt_tracker.py │ └── tracker_builder.py └── utils │ ├── __init__.py │ ├── anchor.py │ ├── average_meter.py │ ├── bbox.py │ ├── distributed.py │ ├── log_helper.py │ ├── lr_scheduler.py │ ├── misc.py │ └── model_load.py ├── requirements.txt ├── setup.py ├── testing_dataset └── README.md ├── toolkit ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── got10k.py │ ├── lasot.py │ ├── nfs.py │ ├── otb.py │ ├── trackingnet.py │ ├── uav.py │ ├── video.py │ ├── visdrone.py │ └── vot.py ├── evaluation │ ├── __init__.py │ ├── ar_benchmark.py │ ├── eao_benchmark.py │ ├── f1_benchmark.py │ └── ope_benchmark.py ├── utils │ ├── __init__.py │ ├── c_region.pxd │ ├── misc.py │ ├── region.c │ ├── region.pyx │ ├── src │ │ ├── buffer.h │ │ ├── region.c │ │ └── region.h │ └── statistics.py └── visualization │ ├── __init__.py │ ├── draw_eao.py │ ├── draw_f1.py │ ├── draw_success_precision.py │ └── draw_utils.py ├── tools ├── demo.py ├── eval.py ├── test.py └── train.py └── vot_iter ├── __init__.py ├── tracker_SiamRPNpp.m ├── vot.py └── vot_iter.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # custom files 10 | .idea/ 11 | 12 | # dataset 13 | training_dataset/* 14 | testing_dataset/* 15 | !testing_dataset/README.md 16 | !training_dataset/README.md 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | This document contains detailed instructions for installing the dependencies of PySOT. We recommend using the provided [install.sh](install.sh) script. The code is tested on an Ubuntu 16.04 system with an Nvidia GPU (we recommend a 1080 Ti / TITAN Xp). 4 | 5 | ### Requirements 6 | * Conda with Python 3.7. 7 | * Nvidia GPU. 8 | * PyTorch 0.4.1 9 | * yacs 10 | * pyyaml 11 | * matplotlib 12 | * tqdm 13 | * OpenCV 14 | 15 | ## Step-by-step instructions 16 | 17 | #### Create environment and activate 18 | ```bash 19 | conda create --name pysot python=3.7 20 | conda activate pysot 21 | ``` 22 | 23 | #### Install numpy/pytorch/opencv 24 | ``` 25 | conda install numpy 26 | conda install pytorch=0.4.1 torchvision cuda90 -c pytorch 27 | pip install opencv-python 28 | ``` 29 | 30 | #### Install other requirements 31 | ``` 32 | pip install pyyaml yacs tqdm colorama matplotlib cython tensorboardX 33 | ``` 34 | 35 | #### Build extensions 36 | ``` 37 | python setup.py build_ext --inplace 38 | ``` 39 | 40 | 41 | ## Try with scripts 42 | ``` 43 | bash install.sh /path/to/your/conda pysot 44 | ``` 45 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # PySOT Model Zoo 2 | 3 | ## Introduction 4 | 5 | This file documents a large collection of baselines trained with PySOT. All configurations for these baselines are located in the [`experiments`](experiments) directory. The tables below report inference results. Links to the trained models as well as their outputs are provided. 6 | 7 | ## Visual Tracking Baselines 8 | 9 | ### Short-term Tracking 10 | 11 | | Model
(arch+backbone+xcorr)
| VOT16
(EAO/A/R)
| VOT18
(EAO/A/R)
| VOT19
(EAO/A/R)
| OTB2015
(AUC/Prec.)
| VOT18-LT
(F1)
| Speed
(fps)
| url | 12 | |:---------------------------------:|:-:|:------------------------:|:--------------------:|:----------------:|:--------------:|:------------:|:-----------:| 13 | | siamrpn_alex_dwxcorr | 0.393/0.618/0.238 | 0.352/0.576/0.290 | 0.260/0.573/0.547| - | - | 180 | [link](https://drive.google.com/open?id=1t62x56Jl7baUzPTo0QrC4jJnwvPZm-2m) | 14 | | siamrpn_alex_dwxcorr_otb | - | - | - |0.666/0.876 | - | 180 | [link](https://drive.google.com/open?id=1gCpmR85Qno3C-naR3SLqRNpVfU7VJ2W0) | 15 | | siamrpn_r50_l234_dwxcorr | 0.464/0.642/0.196 | 0.415/0.601/0.234 | 0.287/0.595/0.467 | - | - | 35 | [link](https://drive.google.com/open?id=1Q4-1563iPwV6wSf_lBHDj5CPFiGSlEPG) | 16 | | siamrpn_r50_l234_dwxcorr_otb | - | - | - |0.696/0.914 | - | 35 | [link](https://drive.google.com/open?id=1Cx_oHu6o0gNeH7F9zZrgevfAGdyWC4D5) | 17 | |siamrpn_mobilev2_l234_dwxcorr| 0.455/0.624/0.214 | 0.410/0.586/0.229 | 0.292/0.580/0.446| - | - | 75 | [link](https://drive.google.com/open?id=1JB94pZTvB1ZByU-qSJn4ZAIfjLWE5EBJ) | 18 | | siammask_r50_l3 | 0.455/0.634/0.219 | 0.423/0.615/0.248 | 0.283/0.597/0.461 | - | - | 56 | [link](https://drive.google.com/open?id=1YbPUQVTYw_slAvk_DchvRY-7B6rnSXP9) | 19 | | siamrpn_r50_l234_dwxcorr_lt | - | - | - | - | 0.629 | 20 | [link](https://drive.google.com/open?id=1lOOTedwGLbGZ7MAbqJimIcET3ANJd29A) | 20 | 21 | The models can also be downloaded from [Baidu Yun](https://pan.baidu.com/s/1GB9-aTtjG57SebraVoBfuQ) Extraction Code: j9yb 22 | 23 | Note: 24 | 25 | - speed tested on GTX-1080Ti 26 | - `alex` denotes [AlexNet](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks), `r50_lxyz` denotes the outputs of stage x, y, and z in [ResNet-50](https://arxiv.org/abs/1512.03385), and `mobilev2` denotes [MobileNetV2](https://arxiv.org/abs/1801.04381). 27 | - `dwxcorr` denotes Depth-wise Cross Correlation. See more in [SiamRPN++ Section 3.4](https://arxiv.org/abs/1812.11703). 28 | - The suffixes `otb` and `lt` are designed for the [OTB](http://cvlab.hanyang.ac.kr/tracker_benchmark/benchmark.html) and [VOT long-term tracking challenge](http://www.votchallenge.net/vot2018/), the default (without suffix) is designed for [VOT short-term tracking challenge](http://www.votchallenge.net/index.html). 29 | - All above models are trained on VID,YoutubeBB,COCO,ImageNetDet which are the same as DaSiamRPN. 30 | - The model of `SiamFC` is from the author's [Matlab](https://github.com/bertinetto/siamese-fc) implementation. The code for further model converting is available in the corresponding experiment file folders. 31 | 32 | ## License 33 | 34 | All models available for download through this document are licensed under the [Creative Commons Attribution-ShareAlike 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/). 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DROL 2 | This is the repo for paper "Discriminative and Robust Online Learning for Siamese Visual Tracking" [[paper](https://arxiv.org/abs/1909.02959)] [[results](https://drive.google.com/open?id=1iXtaxr1zkWvKf6AAMwN98ymaId9ixU4z)], presented as poster at AAAI 2020. 3 | 4 | ## Introduction 5 | 6 | The proposed Discriminative and Robust Online Learning (DROL) module is designed to work with a variety of off-the-shelf siamese trackers. 
Our method is extensively evaluated on several mainstream benchmarks and yields a consistent performance gain over the given baseline. The baselines evaluated in the paper include, but are not limited to: 7 | 8 | - [SiamRPN++](https://arxiv.org/abs/1812.11703) (DROL-RPN) 9 | - [SiamMask](https://arxiv.org/abs/1812.05050) (DROL-MASK) 10 | - [SiamFC](https://arxiv.org/abs/1606.09549) (DROL-FC) 11 | 12 | ## Model Zoo 13 | 14 | The corresponding offline-trained models are available at the [PySOT Model Zoo](MODEL_ZOO.md). 15 | 16 | 17 | ## Get Started 18 | 19 | ### Installation 20 | 21 | - Please find installation instructions for PyTorch and PySOT in [`INSTALL.md`](INSTALL.md). 22 | - Add DROL to your PYTHONPATH 23 | ```bash 24 | export PYTHONPATH=/path/to/drol:$PYTHONPATH 25 | ``` 26 | 27 | ### Download models 28 | Download the models from the [PySOT Model Zoo](MODEL_ZOO.md) and put each model.pth into the corresponding directory under `experiments`. 29 | 30 | ### Test tracker 31 | ```bash 32 | cd experiments/siamrpn_r50_l234_dwxcorr 33 | python -u ../../tools/test.py \ 34 | --snapshot model.pth \ # model path 35 | --dataset VOT2018 \ # dataset name 36 | --config config.yaml # config file 37 | ``` 38 | 39 | ### Eval tracker 40 | Assuming you are still in experiments/siamrpn_r50_l234_dwxcorr: 41 | ```bash 42 | python ../../tools/eval.py \ 43 | --tracker_path ./results \ # result path 44 | --dataset VOT2018 \ # dataset name 45 | --num 1 \ # number of eval threads 46 | --tracker_prefix 'model' # tracker name 47 | ``` 48 | 49 | ### Others 50 | - For `DROL-RPN`, there is a separate config file, and thus a separate experiment folder, for each of `vot`/`votlt`/`otb`/`others`: `vot` is used for the `VOT-20XX-baseline` benchmarks, `votlt` for the `VOT-20XX-longterm` benchmarks, `otb` for the `OTB2013/15` benchmarks, and `others` is the default setting used for all other benchmarks, including but not limited to `LaSOT`/`TrackingNet`/`UAV123`. 51 | - For `DROL-FC`/`DROL-Mask`, only the `vot`/`otb` experiments are evaluated, as described in the paper. As in the `PySOT` repo, the `vot` config file is used as the default setting. 52 | 53 | - Since this repo is a modification grown out of [PySOT](https://github.com/STVIR/pysot), we recommend referring to PySOT for further technical issues. 54 | 55 | 56 | ## References 57 | - Jinghao Zhou, Peng Wang, Haoyang Sun, '[Discriminative and Robust Online Learning For Siamese Visual Tracking](http://arxiv.org/abs/1909.02959)', Proc. AAAI Conference on Artificial Intelligence (AAAI), 2020. 58 | 59 | ### Acknowledgement 60 | - [pysot](https://github.com/STVIR/pysot) 61 | - [pytracking](https://github.com/visionml/pytracking) 62 | -------------------------------------------------------------------------------- /TRAIN.md: -------------------------------------------------------------------------------- 1 | # PySOT Training Tutorial 2 | 3 | This implements training of SiamRPN with different backbone architectures, such as ResNet and AlexNet. 4 | ### Add PySOT to your PYTHONPATH 5 | ```bash 6 | export PYTHONPATH=/path/to/pysot:$PYTHONPATH 7 | ``` 8 | 9 | ## Prepare training dataset 10 | Prepare the training datasets; detailed preparation steps are listed in the [training_dataset](training_dataset) directory. 
11 | * [VID](http://image-net.org/challenges/LSVRC/2017/) 12 | * [YOUTUBEBB](https://research.google.com/youtube-bb/) 13 | * [DET](http://image-net.org/challenges/LSVRC/2017/) 14 | * [COCO](http://cocodataset.org) 15 | 16 | ## Download pretrained backbones 17 | Download pretrained backbones from [Google Drive](https://drive.google.com/drive/folders/1DuXVWVYIeynAcvt9uxtkuleV6bs6e3T9) and put them in `pretrained_models` directory 18 | 19 | ## Training 20 | 21 | To train a model (SiamRPN++), run `train.py` with the desired configs: 22 | 23 | ```bash 24 | cd experiments/siamrpn_r50_l234_dwxcorr_8gpu 25 | ``` 26 | 27 | ### Multi-processing Distributed Data Parallel Training 28 | 29 | #### Single node, multiple GPUs: 30 | ```bash 31 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 32 | python -m torch.distributed.launch \ 33 | --nproc_per_node=8 \ 34 | --master_port=2333 \ 35 | ../../tools/train.py --cfg config.yaml 36 | ``` 37 | 38 | #### Multiple nodes: 39 | Node 1: (IP: 192.168.1.1, and has a free port: 2333) master node 40 | ```bash 41 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 42 | python -m torch.distributed.launch \ 43 | --nnodes=2 \ 44 | --node_rank=0 \ 45 | --nproc_per_node=8 \ 46 | --master_addr=192.168.1.1 \ # adjust your ip here 47 | --master_port=2333 \ 48 | ../../tools/train.py 49 | ``` 50 | Node 2: 51 | ```bash 52 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 53 | python -m torch.distributed.launch \ 54 | --nnodes=2 \ 55 | --node_rank=1 \ 56 | --nproc_per_node=8 \ 57 | --master_addr=192.168.1.1 \ 58 | --master_port=2333 \ 59 | ../../tools/train.py 60 | ``` 61 | 62 | ## Testing 63 | After training, you can test snapshots on VOT dataset. 64 | For `AlexNet`, you need to test snapshots from 35 to 50 epoch. 65 | For `ResNet`, you need to test snapshots from 10 to 20 epoch. 
66 | 67 | ```bash 68 | START=10 69 | END=20 70 | seq $START 1 $END | \ 71 | xargs -I {} echo "snapshot/checkpoint_e{}.pth" | \ 72 | xargs -I {} \ 73 | python -u ../../tools/test.py \ 74 | --snapshot {} \ 75 | --config config.yaml \ 76 | --dataset VOT2018 2>&1 | tee logs/test_dataset.log 77 | ``` 78 | 79 | ## Evaluation 80 | ``` 81 | python ../../tools/eval.py \ 82 | --tracker_path ./results \ # result path 83 | --dataset VOT2018 \ # dataset name 84 | --num 4 \ # number thread to eval 85 | --tracker_prefix 'ch*' # tracker_name 86 | ``` 87 | -------------------------------------------------------------------------------- /experiments/siamfc_alex_upxcorr_otb/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamfc_alex_upxcorr" 2 | 3 | BACKBONE: 4 | TYPE: "alexnetlegacy2" 5 | 6 | ADJUST: 7 | ADJUST: False 8 | 9 | RPN: 10 | RPN: False 11 | 12 | MASK: 13 | MASK: False 14 | 15 | TRACK: 16 | TYPE: 'SiamFCTracker' 17 | TOTAL_STRIDE: 8 18 | SCALE_NUM: 3 19 | SCALE_STEP: 1.0375 20 | PENALTY_K: 0.9745 21 | WINDOW_INFLUENCE: 0.15 # 0.176 22 | LR: 0.59 23 | EXEMPLAR_SIZE: 127 24 | INSTANCE_SIZE: 255 25 | BASE_SIZE: 0 26 | CONTEXT_AMOUNT: 0.5 27 | TEMPLATE_UPDATE: True 28 | TAU_REGRESSION: 0.6 29 | TAU_CLASSIFICATION: 0.1 30 | 31 | # classifier 32 | USE_CLASSIFIER: True 33 | SEED: 12345 34 | COEE_CLASS: 0.8 35 | USE_ATTENTION_LAYER: True 36 | CHANNEL_ATTENTION: True 37 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 38 | -------------------------------------------------------------------------------- /experiments/siamfc_alex_upxcorr_otb/convert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | import numpy as np 4 | import argparse 5 | 6 | from scipy import io as sio 7 | from tqdm import tqdm 8 | 9 | # code adapted from https://github.com/bilylee/SiamFC-TensorFlow/blob/master/utils/train_utils.py 10 | def convert(mat_path): 11 | """Get parameter from .mat file into parms(dict)""" 12 | 13 | def squeeze(vars_): 14 | # Matlab save some params with shape (*, 1) 15 | # However, we don't need the trailing dimension in TensorFlow. 
16 | if isinstance(vars_, (list, tuple)): 17 | return [np.squeeze(v, 1) for v in vars_] 18 | else: 19 | return np.squeeze(vars_, 1) 20 | 21 | netparams = sio.loadmat(mat_path)["net"]["params"][0][0] 22 | params = dict() 23 | 24 | name_map = {(1, 'conv'): 0, (1, 'bn'): 1, 25 | (2, 'conv'): 4, (2, 'bn'): 5, 26 | (3, 'conv'): 8, (3, 'bn'): 9, 27 | (4, 'conv'): 11, (4, 'bn'): 12, 28 | (5, 'conv'): 14} 29 | for i in tqdm(range(netparams.size)): 30 | param = netparams[0][i] 31 | name = param["name"][0] 32 | value = param["value"] 33 | value_size = param["value"].shape[0] 34 | 35 | match = re.match(r"([a-z]+)([0-9]+)([a-z]+)", name, re.I) 36 | if match: 37 | items = match.groups() 38 | elif name == 'adjust_f': 39 | continue 40 | elif name == 'adjust_b': 41 | params['backbone.corr_bias'] = torch.from_numpy(squeeze(value)) 42 | continue 43 | 44 | 45 | op, layer, types = items 46 | layer = int(layer) 47 | if layer in [1, 2, 3, 4, 5]: 48 | idx = name_map[(layer, op)] 49 | if op == 'conv': # convolution 50 | if types == 'f': 51 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value.transpose(3, 2, 0, 1)) 52 | elif types == 'b':# and layer == 5: 53 | value = squeeze(value) 54 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value) 55 | elif op == 'bn': # batch normalization 56 | if types == 'x': 57 | m, v = squeeze(np.split(value, 2, 1)) 58 | params['backbone.features.{}.running_mean'.format(idx)] = torch.from_numpy(m) 59 | params['backbone.features.{}.running_var'.format(idx)] = torch.from_numpy(np.square(v)) 60 | # params['features.{}.num_batches_tracked'.format(idx)] = torch.zeros(0) 61 | elif types == 'm': 62 | value = squeeze(value) 63 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value) 64 | elif types == 'b': 65 | value = squeeze(value) 66 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value) 67 | else: 68 | raise Exception 69 | return params 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--mat_path', type=str, default="./2016-08-17_gray025.net.mat") 74 | args = parser.parse_args() 75 | params = convert(args.mat_path) 76 | torch.save(params, "./model.pth") 77 | -------------------------------------------------------------------------------- /experiments/siamfc_alex_upxcorr_vot/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamfc_alex_upxcorr" 2 | 3 | BACKBONE: 4 | TYPE: "alexnetlegacy2" 5 | 6 | ADJUST: 7 | ADJUST: False 8 | 9 | RPN: 10 | RPN: False 11 | 12 | MASK: 13 | MASK: False 14 | 15 | TRACK: 16 | TYPE: 'SiamFCTracker' 17 | TOTAL_STRIDE: 8 18 | SCALE_NUM: 3 19 | SCALE_STEP: 1.0375 20 | PENALTY_K: 0.9745 21 | WINDOW_INFLUENCE: 0.15 #0.176 22 | LR: 0.59 23 | EXEMPLAR_SIZE: 127 24 | INSTANCE_SIZE: 255 25 | BASE_SIZE: 0 26 | CONTEXT_AMOUNT: 0.5 27 | 28 | # classifier & updater 29 | USE_CLASSIFIER: True 30 | TEMPLATE_UPDATE: False 31 | TAU_REGRESSION: 0.6 32 | TAU_CLASSIFICATION: 0.1 33 | SEED: 12345 34 | COEE_CLASS: 0.6 35 | USE_ATTENTION_LAYER: False 36 | CHANNEL_ATTENTION: True 37 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 38 | -------------------------------------------------------------------------------- /experiments/siamfc_alex_upxcorr_vot/convert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | import numpy as np 4 | import argparse 5 | 6 | from scipy import io as sio 7 | from tqdm import tqdm 8 
| 9 | # code adapted from https://github.com/bilylee/SiamFC-TensorFlow/blob/master/utils/train_utils.py 10 | def convert(mat_path): 11 | """Get parameter from .mat file into parms(dict)""" 12 | 13 | def squeeze(vars_): 14 | # Matlab save some params with shape (*, 1) 15 | # However, we don't need the trailing dimension in TensorFlow. 16 | if isinstance(vars_, (list, tuple)): 17 | return [np.squeeze(v, 1) for v in vars_] 18 | else: 19 | return np.squeeze(vars_, 1) 20 | 21 | netparams = sio.loadmat(mat_path)["net"]["params"][0][0] 22 | params = dict() 23 | 24 | name_map = {(1, 'conv'): 0, (1, 'bn'): 1, 25 | (2, 'conv'): 4, (2, 'bn'): 5, 26 | (3, 'conv'): 8, (3, 'bn'): 9, 27 | (4, 'conv'): 11, (4, 'bn'): 12, 28 | (5, 'conv'): 14} 29 | for i in tqdm(range(netparams.size)): 30 | param = netparams[0][i] 31 | name = param["name"][0] 32 | value = param["value"] 33 | value_size = param["value"].shape[0] 34 | 35 | match = re.match(r"([a-z]+)([0-9]+)([a-z]+)", name, re.I) 36 | if match: 37 | items = match.groups() 38 | elif name == 'adjust_f': 39 | continue 40 | elif name == 'adjust_b': 41 | params['backbone.corr_bias'] = torch.from_numpy(squeeze(value)) 42 | continue 43 | 44 | 45 | op, layer, types = items 46 | layer = int(layer) 47 | if layer in [1, 2, 3, 4, 5]: 48 | idx = name_map[(layer, op)] 49 | if op == 'conv': # convolution 50 | if types == 'f': 51 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value.transpose(3, 2, 0, 1)) 52 | elif types == 'b':# and layer == 5: 53 | value = squeeze(value) 54 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value) 55 | elif op == 'bn': # batch normalization 56 | if types == 'x': 57 | m, v = squeeze(np.split(value, 2, 1)) 58 | params['backbone.features.{}.running_mean'.format(idx)] = torch.from_numpy(m) 59 | params['backbone.features.{}.running_var'.format(idx)] = torch.from_numpy(np.square(v)) 60 | # params['features.{}.num_batches_tracked'.format(idx)] = torch.zeros(0) 61 | elif types == 'm': 62 | value = squeeze(value) 63 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value) 64 | elif types == 'b': 65 | value = squeeze(value) 66 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value) 67 | else: 68 | raise Exception 69 | return params 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--mat_path', type=str, default="./2016-08-17.net.mat") 74 | args = parser.parse_args() 75 | params = convert(args.mat_path) 76 | torch.save(params, "./model.pth") -------------------------------------------------------------------------------- /experiments/siammask_r50_l3/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siammask_r50_l3_dwxcorr" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [0, 1, 2, 3] 7 | 8 | ADJUST: 9 | ADJUST: true 10 | TYPE: "AdjustAllLayer" 11 | LAYER: 3 12 | KWARGS: 13 | in_channels: [1024] 14 | out_channels: [256] 15 | 16 | RPN: 17 | TYPE: 'DepthwiseRPN' 18 | KWARGS: 19 | anchor_num: 5 20 | in_channels: 256 21 | out_channels: 256 22 | 23 | MASK: 24 | MASK: True 25 | TYPE: 'MaskCorr' 26 | KWARGS: 27 | in_channels: 256 28 | hidden: 256 29 | out_channels: 3969 30 | 31 | REFINE: 32 | REFINE: True 33 | TYPE: 'Refine' 34 | 35 | ANCHOR: 36 | STRIDE: 8 37 | RATIOS: [0.33, 0.5, 1, 2, 3] 38 | SCALES: [8] 39 | ANCHOR_NUM: 5 40 | 41 | TRACK: 42 | TYPE: 'SiamMaskTracker' 43 | PENALTY_K: 0.10 44 | WINDOW_INFLUENCE: 0.37 # 0.41 45 | LR: 0.32 46 | 
EXEMPLAR_SIZE: 127 47 | INSTANCE_SIZE: 255 48 | BASE_SIZE: 8 49 | CONTEXT_AMOUNT: 0.5 50 | MASK_THERSHOLD: 0.15 51 | 52 | # classifier 53 | USE_CLASSIFIER: True 54 | TEMPLATE_UPDATE: False 55 | SEED: 12345 56 | COEE_CLASS: 0.8 57 | USE_ATTENTION_LAYER: True 58 | CHANNEL_ATTENTION: True 59 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 60 | -------------------------------------------------------------------------------- /experiments/siammask_r50_l3/convert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | model = torch.load('SiamMask_VOT_LD.pth', map_location=lambda storage, loc: storage) 5 | 6 | new_model = OrderedDict() 7 | 8 | for k, v in model.items(): 9 | if k.startswith('features.features'): 10 | k = k.replace('features.features', 'backbone') 11 | elif k.startswith('features'): 12 | k = k.replace('features', 'neck') 13 | elif k.startswith('rpn_model'): 14 | k = k.replace('rpn_model', 'rpn_head') 15 | elif k.startswith('mask_model'): 16 | k = k.replace('mask_model.mask', 'mask_head') 17 | elif k.startswith('refine_model'): 18 | k = k.replace('refine_model', 'refine_head') 19 | new_model[k] = v 20 | 21 | torch.save(new_model, 'model.pth') 22 | -------------------------------------------------------------------------------- /experiments/siamrpn_alex_dwxcorr_otb/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_alex_dwxcorr_otb" 2 | 3 | BACKBONE: 4 | TYPE: "alexnetlegacy" 5 | KWARGS: 6 | width_mult: 1.0 7 | 8 | ADJUST: 9 | ADJUST: False 10 | 11 | RPN: 12 | TYPE: 'DepthwiseRPN' 13 | KWARGS: 14 | anchor_num: 5 15 | in_channels: 256 16 | out_channels: 256 17 | 18 | MASK: 19 | MASK: False 20 | 21 | ANCHOR: 22 | STRIDE: 8 23 | RATIOS: [0.33, 0.5, 1, 2, 3] 24 | SCALES: [8] 25 | ANCHOR_NUM: 5 26 | 27 | TRACK: 28 | TYPE: 'SiamRPNTracker' 29 | PENALTY_K: 0.16 30 | WINDOW_INFLUENCE: 0.3 # 0.40 31 | LR: 0.30 32 | EXEMPLAR_SIZE: 127 33 | INSTANCE_SIZE: 287 34 | BASE_SIZE: 0 35 | CONTEXT_AMOUNT: 0.5 36 | SHORT_TERM_DRIFT: True 37 | 38 | # classifier & updater 39 | USE_CLASSIFIER: True 40 | TEMPLATE_UPDATE: False 41 | SEED: 123456 42 | COEE_CLASS: 0.8 43 | USE_ATTENTION_LAYER: True 44 | CHANNEL_ATTENTION: True 45 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 46 | -------------------------------------------------------------------------------- /experiments/siamrpn_alex_dwxcorr_vot/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_alex_dwxcorr_vot" 2 | 3 | BACKBONE: 4 | TYPE: "alexnetlegacy" 5 | KWARGS: 6 | width_mult: 1.0 7 | 8 | ADJUST: 9 | ADJUST: False 10 | 11 | RPN: 12 | TYPE: 'DepthwiseRPN' 13 | KWARGS: 14 | anchor_num: 5 15 | in_channels: 256 16 | out_channels: 256 17 | 18 | MASK: 19 | MASK: False 20 | 21 | ANCHOR: 22 | STRIDE: 8 23 | RATIOS: [0.33, 0.5, 1, 2, 3] 24 | SCALES: [8] 25 | ANCHOR_NUM: 5 26 | 27 | TRACK: 28 | TYPE: 'SiamRPNTracker' 29 | PENALTY_K: 0.16 30 | WINDOW_INFLUENCE: 0.3 #0.40 31 | LR: 0.30 32 | EXEMPLAR_SIZE: 127 33 | INSTANCE_SIZE: 287 34 | BASE_SIZE: 0 35 | CONTEXT_AMOUNT: 0.5 36 | 37 | # classifier & updater 38 | USE_CLASSIFIER: True 39 | TEMPLATE_UPDATE: False 40 | SEED: 12 41 | COEE_CLASS: 0.8 42 | USE_ATTENTION_LAYER: False 43 | CHANNEL_ATTENTION: True 44 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 45 | -------------------------------------------------------------------------------- 
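For context on how the per-experiment `config.yaml` files above and below (together with the downloaded `model.pth` snapshots) are consumed at test time, here is a minimal sketch. It assumes the upstream-PySOT-style entry points implied by the directory tree — a yacs-based `cfg` with `merge_from_file` in `pysot/core/config.py`, `ModelBuilder` in `pysot/models/model_builder.py`, and `build_tracker` in `pysot/tracker/tracker_builder.py`; `tools/test.py` and `tools/demo.py` remain the authoritative loading code.

```python
# Minimal sketch (not the repo's own script): load one experiment's config.yaml
# and model.pth, then build the corresponding tracker. The APIs below
# (cfg.merge_from_file, ModelBuilder, build_tracker) are assumed to follow
# upstream PySOT, as suggested by the directory tree.
import torch

from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker

# Merge an experiment config on top of the defaults in pysot/core/config.py.
cfg.merge_from_file('experiments/siamrpn_r50_l234_dwxcorr/config.yaml')

# Build the offline-trained siamese model and load its snapshot on CPU.
model = ModelBuilder()
state = torch.load('experiments/siamrpn_r50_l234_dwxcorr/model.pth',
                   map_location='cpu')
model.load_state_dict(state)
model.eval()

# cfg.TRACK.TYPE (e.g. 'SiamRPNTracker') selects the tracker class, while
# TRACK.USE_CLASSIFIER / TRACK.TEMPLATE_UPDATE toggle the DROL components.
tracker = build_tracker(model)
```

From there, `tracker.init(first_frame, gt_bbox)` and `tracker.track(frame)` would drive per-frame inference, as wired up in `tools/test.py`.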
/experiments/siamrpn_r50_l234_dwxcorr/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_r50_l234_dwxcorr" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [2, 3, 4] 7 | 8 | ADJUST: 9 | ADJUST: True 10 | TYPE: "AdjustAllLayer" 11 | LAYER: 0 12 | FUSE: 'avg' 13 | KWARGS: 14 | in_channels: [512, 1024, 2048] 15 | out_channels: [256, 256, 256] 16 | 17 | RPN: 18 | TYPE: 'MultiRPN' 19 | KWARGS: 20 | anchor_num: 5 21 | in_channels: [256, 256, 256] 22 | weighted: True 23 | 24 | MASK: 25 | MASK: False 26 | 27 | ANCHOR: 28 | STRIDE: 8 29 | RATIOS: [0.33, 0.5, 1, 2, 3] 30 | SCALES: [8] 31 | ANCHOR_NUM: 5 32 | 33 | TRACK: 34 | # matcher 35 | TYPE: 'SiamRPNTracker' 36 | PENALTY_K: 0.05 37 | WINDOW_INFLUENCE: 0.25 38 | LR: 0.38 39 | EXEMPLAR_SIZE: 127 40 | INSTANCE_SIZE: 255 41 | BASE_SIZE: 8 42 | CONTEXT_AMOUNT: 0.5 43 | 44 | # classifier & updater 45 | USE_CLASSIFIER: True 46 | TEMPLATE_UPDATE: True 47 | SEED: 12345 48 | COEE_CLASS: 0.8 49 | USE_ATTENTION_LAYER: True 50 | CHANNEL_ATTENTION: True 51 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] -------------------------------------------------------------------------------- /experiments/siamrpn_r50_l234_dwxcorr_otb/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_r50_l234_dwxcorr_otb" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [2, 3, 4] 7 | 8 | ADJUST: 9 | ADJUST: True 10 | TYPE: "AdjustAllLayer" 11 | LAYER: 0 12 | FUSE: 'avg' # ['avg', 'wavg', 'con'] 13 | KWARGS: 14 | in_channels: [512, 1024, 2048] 15 | out_channels: [256, 256, 256] 16 | 17 | RPN: 18 | TYPE: 'MultiRPN' 19 | KWARGS: 20 | anchor_num: 5 21 | in_channels: [256, 256, 256] 22 | weighted: False 23 | 24 | MASK: 25 | MASK: False 26 | 27 | ANCHOR: 28 | STRIDE: 8 29 | RATIOS: [0.33, 0.5, 1, 2, 3] 30 | SCALES: [8] 31 | ANCHOR_NUM: 5 32 | 33 | TRACK: 34 | # matcher 35 | TYPE: 'SiamRPNTracker' 36 | PENALTY_K: 0.24 37 | WINDOW_INFLUENCE: 0.25 38 | LR: 0.25 39 | EXEMPLAR_SIZE: 127 40 | INSTANCE_SIZE: 255 41 | BASE_SIZE: 8 42 | CONTEXT_AMOUNT: 0.5 43 | SHORT_TERM_DRIFT: True 44 | 45 | # classifier & updater 46 | USE_CLASSIFIER: True 47 | TEMPLATE_UPDATE: True 48 | SEED: 123456 49 | COEE_CLASS: 0.8 50 | USE_ATTENTION_LAYER: True 51 | CHANNEL_ATTENTION: True 52 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 53 | -------------------------------------------------------------------------------- /experiments/siamrpn_r50_l234_dwxcorr_vot/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_r50_l234_dwxcorr_vot" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [2, 3, 4] 7 | 8 | ADJUST: 9 | ADJUST: True 10 | TYPE: "AdjustAllLayer" 11 | LAYER: 0 12 | FUSE: 'avg' 13 | KWARGS: 14 | in_channels: [512, 1024, 2048] 15 | out_channels: [256, 256, 256] 16 | 17 | RPN: 18 | TYPE: 'MultiRPN' 19 | KWARGS: 20 | anchor_num: 5 21 | in_channels: [256, 256, 256] 22 | weighted: True 23 | 24 | MASK: 25 | MASK: False 26 | 27 | ANCHOR: 28 | STRIDE: 8 29 | RATIOS: [0.33, 0.5, 1, 2, 3] 30 | SCALES: [8] 31 | ANCHOR_NUM: 5 32 | 33 | TRACK: 34 | # matcher 35 | TYPE: 'SiamRPNTracker' 36 | PENALTY_K: 0.05 37 | WINDOW_INFLUENCE: 0.35 38 | LR: 0.38 39 | EXEMPLAR_SIZE: 127 40 | INSTANCE_SIZE: 255 41 | BASE_SIZE: 8 42 | CONTEXT_AMOUNT: 0.5 43 | 44 | # classifier & updater 45 | USE_CLASSIFIER: True 46 | TEMPLATE_UPDATE: False 47 | SEED: 123 48 | COEE_CLASS: 0.8 49 | 
USE_ATTENTION_LAYER: False 50 | CHANNEL_ATTENTION: True 51 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool'] 52 | -------------------------------------------------------------------------------- /experiments/siamrpn_r50_l234_dwxcorr_votlt/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_r50_l234_dwxcorr_votlt" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [2, 3, 4] 7 | 8 | ADJUST: 9 | ADJUST: True 10 | TYPE: "AdjustAllLayer" 11 | LAYER: 1 12 | KWARGS: 13 | in_channels: [512, 1024, 2048] 14 | out_channels: [128, 256, 512] #[128, 256, 512] 15 | 16 | RPN: 17 | TYPE: 'MultiRPN' 18 | KWARGS: 19 | anchor_num: 5 20 | in_channels: [128, 256, 512] #[128, 256, 512] 21 | weighted: True 22 | 23 | MASK: 24 | MASK: False 25 | 26 | ANCHOR: 27 | STRIDE: 8 28 | RATIOS: [0.33, 0.5, 1, 2, 3] 29 | SCALES: [8] 30 | ANCHOR_NUM: 5 31 | 32 | TRACK: 33 | TYPE: 'SiamRPNLTTracker' 34 | PENALTY_K: 0.05 35 | WINDOW_INFLUENCE: 0.26 36 | LR: 0.22 37 | EXEMPLAR_SIZE: 127 38 | INSTANCE_SIZE: 255 39 | BASE_SIZE: 8 40 | CONTEXT_AMOUNT: 0.5 41 | 42 | # classifier & updater 43 | USE_CLASSIFIER: True 44 | TEMPLATE_UPDATE: True 45 | SEED: 123 46 | COEE_CLASS: 0.8 47 | USE_ATTENTION_LAYER: False 48 | CHANNEL_ATTENTION: True 49 | SPATIAL_ATTENTION: 'pool' # ['none', 'pool', 'residual', 'gaussian'] 50 | TARGET_NOT_FOUND_THRESHOLD: 0.2 51 | 52 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -lt 2 ]; then 4 | echo "ARGS ERROR!" 5 | echo " bash install.sh /path/to/your/conda env_name" 6 | exit 1 7 | fi 8 | 9 | set -e 10 | 11 | conda_path=$1 12 | env_name=$2 13 | 14 | source $conda_path/etc/profile.d/conda.sh 15 | 16 | echo "****** create environment " $env_name "*****" 17 | # create environment 18 | conda create -y --name $env_name python=3.7 19 | conda activate $env_name 20 | 21 | echo "***** install numpy pytorch opencv *****" 22 | # numpy 23 | conda install -y numpy 24 | # pytorch 25 | # pytorch with cuda80/cuda90 is tested 26 | conda install -y pytorch=1.0.0 torchvision cuda90 -c pytorch 27 | # opencv 28 | pip install opencv-python 29 | # tensorboardX 30 | 31 | echo "***** install other libs *****" 32 | pip install tensorboardX 33 | # libs 34 | pip install pyyaml yacs tqdm colorama matplotlib cython 35 | 36 | 37 | echo "***** build extensions *****" 38 | python setup.py build_ext --inplace 39 | -------------------------------------------------------------------------------- /pysot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/__init__.py -------------------------------------------------------------------------------- /pysot/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/core/__init__.py -------------------------------------------------------------------------------- /pysot/core/xcorr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def xcorr_slow(x, kernel): 13 | """for loop to calculate cross correlation, slow version 14 | """ 15 | batch = x.size()[0] 16 | out = [] 17 | for i in range(batch): 18 | px = x[i] 19 | pk = kernel[i] 20 | px = px.view(1, px.size()[0], px.size()[1], px.size()[2]) 21 | pk = pk.view(-1, px.size()[1], pk.size()[1], pk.size()[2]) 22 | po = F.conv2d(px, pk) 23 | out.append(po) 24 | out = torch.cat(out, 0) 25 | return out 26 | 27 | 28 | def xcorr_fast(x, kernel): 29 | """group conv2d to calculate cross correlation, fast version 30 | """ 31 | batch = kernel.size()[0] 32 | pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3]) 33 | px = x.view(1, -1, x.size()[2], x.size()[3]) 34 | po = F.conv2d(px, pk, groups=batch) 35 | po = po.view(batch, -1, po.size()[2], po.size()[3]) 36 | return po 37 | 38 | 39 | def xcorr_depthwise(x, kernel): 40 | """depthwise cross correlation 41 | """ 42 | batch = kernel.size(0) 43 | channel = kernel.size(1) 44 | x = x.view(1, batch*channel, x.size(2), x.size(3)) 45 | kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3)) 46 | out = F.conv2d(x, kernel, groups=batch*channel) 47 | out = out.view(batch, channel, out.size(2), out.size(3)) 48 | return out 49 | -------------------------------------------------------------------------------- /pysot/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/datasets/__init__.py -------------------------------------------------------------------------------- /pysot/datasets/anchor_target.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import numpy as np 9 | 10 | from pysot.core.config import cfg 11 | from pysot.utils.bbox import IoU, corner2center 12 | from pysot.utils.anchor import Anchors 13 | 14 | 15 | class AnchorTarget: 16 | def __init__(self,): 17 | self.anchors = Anchors(cfg.ANCHOR.STRIDE, 18 | cfg.ANCHOR.RATIOS, 19 | cfg.ANCHOR.SCALES) 20 | 21 | self.anchors.generate_all_anchors(im_c=cfg.TRAIN.SEARCH_SIZE//2, 22 | size=cfg.TRAIN.OUTPUT_SIZE) 23 | 24 | def __call__(self, target, size, neg=False): 25 | anchor_num = len(cfg.ANCHOR.RATIOS) * len(cfg.ANCHOR.SCALES) 26 | 27 | # -1 ignore 0 negative 1 positive 28 | cls = -1 * np.ones((anchor_num, size, size), dtype=np.int64) 29 | delta = np.zeros((4, anchor_num, size, size), dtype=np.float32) 30 | delta_weight = np.zeros((anchor_num, size, size), dtype=np.float32) 31 | 32 | def select(position, keep_num=16): 33 | num = position[0].shape[0] 34 | if num <= keep_num: 35 | return position, num 36 | slt = np.arange(num) 37 | np.random.shuffle(slt) 38 | slt = slt[:keep_num] 39 | return tuple(p[slt] for p in position), keep_num 40 | 41 | tcx, tcy, tw, th = corner2center(target) 42 | 43 | if neg: 44 | # l = size // 2 - 3 45 | # r = size // 2 + 3 + 1 46 | # cls[:, l:r, l:r] = 0 47 | 48 | cx = size // 2 49 | cy = size // 2 50 | cx += int(np.ceil((tcx - cfg.TRAIN.SEARCH_SIZE // 2) / 51 | cfg.ANCHOR.STRIDE + 0.5)) 52 | cy += int(np.ceil((tcy - cfg.TRAIN.SEARCH_SIZE // 2) / 53 | cfg.ANCHOR.STRIDE + 0.5)) 54 | l = max(0, cx - 3) 55 | r = min(size, cx + 4) 56 | u = max(0, cy - 3) 57 | d = min(size, cy + 4) 58 | cls[:, u:d, l:r] = 0 59 | 60 | neg, neg_num = select(np.where(cls == 0), cfg.TRAIN.NEG_NUM) 61 | cls[:] = -1 62 | cls[neg] = 0 63 | 64 | overlap = np.zeros((anchor_num, size, size), dtype=np.float32) 65 | return cls, delta, delta_weight, overlap 66 | 67 | anchor_box = self.anchors.all_anchors[0] 68 | anchor_center = self.anchors.all_anchors[1] 69 | x1, y1, x2, y2 = anchor_box[0], anchor_box[1], \ 70 | anchor_box[2], anchor_box[3] 71 | cx, cy, w, h = anchor_center[0], anchor_center[1], \ 72 | anchor_center[2], anchor_center[3] 73 | 74 | delta[0] = (tcx - cx) / w 75 | delta[1] = (tcy - cy) / h 76 | delta[2] = np.log(tw / w) 77 | delta[3] = np.log(th / h) 78 | 79 | overlap = IoU([x1, y1, x2, y2], target) 80 | 81 | pos = np.where(overlap > cfg.TRAIN.THR_HIGH) 82 | neg = np.where(overlap < cfg.TRAIN.THR_LOW) 83 | 84 | pos, pos_num = select(pos, cfg.TRAIN.POS_NUM) 85 | neg, neg_num = select(neg, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM) 86 | 87 | cls[pos] = 1 88 | delta_weight[pos] = 1. / (pos_num + 1e-6) 89 | 90 | cls[neg] = 0 91 | return cls, delta, delta_weight, overlap 92 | -------------------------------------------------------------------------------- /pysot/datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import numpy as np 9 | import cv2 10 | 11 | from pysot.utils.bbox import corner2center, \ 12 | Center, center2corner, Corner 13 | 14 | 15 | class Augmentation: 16 | def __init__(self, shift, scale, blur, flip, color): 17 | self.shift = shift 18 | self.scale = scale 19 | self.blur = blur 20 | self.flip = flip 21 | self.color = color 22 | self.rgbVar = np.array( 23 | [[-0.55919361, 0.98062831, - 0.41940627], 24 | [1.72091413, 0.19879334, - 1.82968581], 25 | [4.64467907, 4.73710203, 4.88324118]], dtype=np.float32) 26 | 27 | @staticmethod 28 | def random(): 29 | return np.random.random() * 2 - 1.0 30 | 31 | def _crop_roi(self, image, bbox, out_sz, padding=(0, 0, 0)): 32 | bbox = [float(x) for x in bbox] 33 | a = (out_sz-1) / (bbox[2]-bbox[0]) 34 | b = (out_sz-1) / (bbox[3]-bbox[1]) 35 | c = -a * bbox[0] 36 | d = -b * bbox[1] 37 | mapping = np.array([[a, 0, c], 38 | [0, b, d]]).astype(np.float) 39 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), 40 | borderMode=cv2.BORDER_CONSTANT, 41 | borderValue=padding) 42 | return crop 43 | 44 | def _blur_aug(self, image): 45 | def rand_kernel(): 46 | sizes = np.arange(5, 46, 2) 47 | size = np.random.choice(sizes) 48 | kernel = np.zeros((size, size)) 49 | c = int(size/2) 50 | wx = np.random.random() 51 | kernel[:, c] += 1. / size * wx 52 | kernel[c, :] += 1. / size * (1-wx) 53 | return kernel 54 | kernel = rand_kernel() 55 | image = cv2.filter2D(image, -1, kernel) 56 | return image 57 | 58 | def _color_aug(self, image): 59 | offset = np.dot(self.rgbVar, np.random.randn(3, 1)) 60 | offset = offset[::-1] # bgr 2 rgb 61 | offset = offset.reshape(3) 62 | image = image - offset 63 | return image 64 | 65 | def _gray_aug(self, image): 66 | grayed = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 67 | image = cv2.cvtColor(grayed, cv2.COLOR_GRAY2BGR) 68 | return image 69 | 70 | def _shift_scale_aug(self, image, bbox, crop_bbox, size): 71 | im_h, im_w = image.shape[:2] 72 | 73 | # adjust crop bounding box 74 | crop_bbox_center = corner2center(crop_bbox) 75 | if self.scale: 76 | scale_x = (1.0 + Augmentation.random() * self.scale) 77 | scale_y = (1.0 + Augmentation.random() * self.scale) 78 | h, w = crop_bbox_center.h, crop_bbox_center.w 79 | scale_x = min(scale_x, float(im_w) / w) 80 | scale_y = min(scale_y, float(im_h) / h) 81 | crop_bbox_center = Center(crop_bbox_center.x, 82 | crop_bbox_center.y, 83 | crop_bbox_center.w * scale_x, 84 | crop_bbox_center.h * scale_y) 85 | 86 | crop_bbox = center2corner(crop_bbox_center) 87 | if self.shift: 88 | sx = Augmentation.random() * self.shift 89 | sy = Augmentation.random() * self.shift 90 | 91 | x1, y1, x2, y2 = crop_bbox 92 | 93 | sx = max(-x1, min(im_w - 1 - x2, sx)) 94 | sy = max(-y1, min(im_h - 1 - y2, sy)) 95 | 96 | crop_bbox = Corner(x1 + sx, y1 + sy, x2 + sx, y2 + sy) 97 | 98 | # adjust target bounding box 99 | x1, y1 = crop_bbox.x1, crop_bbox.y1 100 | bbox = Corner(bbox.x1 - x1, bbox.y1 - y1, 101 | bbox.x2 - x1, bbox.y2 - y1) 102 | 103 | if self.scale: 104 | bbox = Corner(bbox.x1 / scale_x, bbox.y1 / scale_y, 105 | bbox.x2 / scale_x, bbox.y2 / scale_y) 106 | 107 | image = self._crop_roi(image, crop_bbox, size) 108 | return image, bbox 109 | 110 | def _flip_aug(self, image, bbox): 111 | image = cv2.flip(image, 1) 112 | width = image.shape[1] 113 | bbox = Corner(width - 1 - bbox.x2, bbox.y1, 114 | width - 1 - bbox.x1, bbox.y2) 115 | return 
image, bbox 116 | 117 | def __call__(self, image, bbox, size, gray=False): 118 | shape = image.shape 119 | crop_bbox = center2corner(Center(shape[0]//2, shape[1]//2, 120 | size-1, size-1)) 121 | # gray augmentation 122 | if gray: 123 | image = self._gray_aug(image) 124 | 125 | # shift scale augmentation 126 | image, bbox = self._shift_scale_aug(image, bbox, crop_bbox, size) 127 | 128 | # color augmentation 129 | if self.color > np.random.random(): 130 | image = self._color_aug(image) 131 | 132 | # blur augmentation 133 | if self.blur > np.random.random(): 134 | image = self._blur_aug(image) 135 | 136 | # flip augmentation 137 | if self.flip and self.flip > np.random.random(): 138 | image, bbox = self._flip_aug(image, bbox) 139 | return image, bbox 140 | -------------------------------------------------------------------------------- /pysot/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/models/__init__.py -------------------------------------------------------------------------------- /pysot/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from pysot.models.backbone.alexnet import alexnetlegacy, alexnetlegacy2, alexnet 9 | from pysot.models.backbone.mobile_v2 import mobilenetv2 10 | from pysot.models.backbone.resnet_atrous import resnet18, resnet34, resnet50 11 | 12 | BACKBONES = { 13 | 'alexnetlegacy': alexnetlegacy, 14 | 'alexnetlegacy2': alexnetlegacy2, 15 | 'mobilenetv2': mobilenetv2, 16 | 'resnet18': resnet18, 17 | 'resnet34': resnet34, 18 | 'resnet50': resnet50, 19 | 'alexnet': alexnet, 20 | } 21 | 22 | 23 | def get_backbone(name, **kwargs): 24 | return BACKBONES[name](**kwargs) 25 | -------------------------------------------------------------------------------- /pysot/models/backbone/alexnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch.nn as nn 7 | import torch 8 | 9 | class AlexNetLegacy(nn.Module): 10 | configs = [3, 96, 256, 384, 384, 256] 11 | 12 | def __init__(self, width_mult=1): 13 | configs = list(map(lambda x: 3 if x == 3 else 14 | int(x*width_mult), AlexNet.configs)) 15 | super(AlexNetLegacy, self).__init__() 16 | self.features = nn.Sequential( 17 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2), 18 | nn.BatchNorm2d(configs[1]), 19 | nn.MaxPool2d(kernel_size=3, stride=2), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(configs[1], configs[2], kernel_size=5), 22 | nn.BatchNorm2d(configs[2]), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(configs[2], configs[3], kernel_size=3), 26 | nn.BatchNorm2d(configs[3]), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(configs[3], configs[4], kernel_size=3), 29 | nn.BatchNorm2d(configs[4]), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(configs[4], configs[5], kernel_size=3), 32 | nn.BatchNorm2d(configs[5]), 33 | ) 34 | self.feature_size = configs[5] 35 | 36 | def forward(self, x): 37 | x = self.features(x) 38 | return x 39 | 40 | class 
AlexNetLegacy2(nn.Module): 41 | def __init__(self): 42 | super(AlexNetLegacy2, self).__init__() 43 | self.features = nn.Sequential( 44 | nn.Conv2d(3, 96, 11, 2), 45 | nn.BatchNorm2d(96), 46 | nn.ReLU(inplace=True), 47 | nn.MaxPool2d(3, 2), 48 | nn.Conv2d(96, 256, 5, 1, groups=2), 49 | nn.BatchNorm2d(256), 50 | nn.ReLU(inplace=True), 51 | nn.MaxPool2d(3, 2), 52 | nn.Conv2d(256, 384, 3, 1), 53 | nn.BatchNorm2d(384), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(384, 384, 3, 1, groups=2), 56 | nn.BatchNorm2d(384), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(384, 256, 3, 1, groups=2) 59 | ) 60 | self.corr_bias = nn.Parameter(torch.zeros(1)) 61 | 62 | def forward(self, x): 63 | x = self.features(x) 64 | return x 65 | 66 | class AlexNet(nn.Module): 67 | configs = [3, 96, 256, 384, 384, 256] 68 | 69 | def __init__(self, width_mult=1): 70 | configs = list(map(lambda x: 3 if x == 3 else 71 | int(x*width_mult), AlexNet.configs)) 72 | super(AlexNet, self).__init__() 73 | self.layer1 = nn.Sequential( 74 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2), 75 | nn.BatchNorm2d(configs[1]), 76 | nn.MaxPool2d(kernel_size=3, stride=2), 77 | nn.ReLU(inplace=True), 78 | ) 79 | self.layer2 = nn.Sequential( 80 | nn.Conv2d(configs[1], configs[2], kernel_size=5), 81 | nn.BatchNorm2d(configs[2]), 82 | nn.MaxPool2d(kernel_size=3, stride=2), 83 | nn.ReLU(inplace=True), 84 | ) 85 | self.layer3 = nn.Sequential( 86 | nn.Conv2d(configs[2], configs[3], kernel_size=3), 87 | nn.BatchNorm2d(configs[3]), 88 | nn.ReLU(inplace=True), 89 | ) 90 | self.layer4 = nn.Sequential( 91 | nn.Conv2d(configs[3], configs[4], kernel_size=3), 92 | nn.BatchNorm2d(configs[4]), 93 | nn.ReLU(inplace=True), 94 | ) 95 | 96 | self.layer5 = nn.Sequential( 97 | nn.Conv2d(configs[4], configs[5], kernel_size=3), 98 | nn.BatchNorm2d(configs[5]), 99 | ) 100 | self.feature_size = configs[5] 101 | 102 | def forward(self, x): 103 | x = self.layer1(x) 104 | x = self.layer2(x) 105 | x = self.layer3(x) 106 | x = self.layer4(x) 107 | x = self.layer5(x) 108 | return x 109 | 110 | 111 | def alexnetlegacy(**kwargs): 112 | return AlexNetLegacy(**kwargs) 113 | 114 | def alexnetlegacy2(**kwargs): 115 | return AlexNetLegacy2() 116 | 117 | def alexnet(**kwargs): 118 | return AlexNet(**kwargs) 119 | -------------------------------------------------------------------------------- /pysot/models/backbone/mobile_v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def conv_bn(inp, oup, stride, padding=1): 11 | return nn.Sequential( 12 | nn.Conv2d(inp, oup, 3, stride, padding, bias=False), 13 | nn.BatchNorm2d(oup), 14 | nn.ReLU6(inplace=True) 15 | ) 16 | 17 | 18 | def conv_1x1_bn(inp, oup): 19 | return nn.Sequential( 20 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 21 | nn.BatchNorm2d(oup), 22 | nn.ReLU6(inplace=True) 23 | ) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | 33 | padding = 2 - stride 34 | if dilation > 1: 35 | padding = dilation 36 | 37 | self.conv = nn.Sequential( 38 | # pw 39 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 40 | nn.BatchNorm2d(inp * expand_ratio), 41 | 
nn.ReLU6(inplace=True), 42 | # dw 43 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, 44 | stride, padding, dilation=dilation, 45 | groups=inp * expand_ratio, bias=False), 46 | nn.BatchNorm2d(inp * expand_ratio), 47 | nn.ReLU6(inplace=True), 48 | # pw-linear 49 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(oup), 51 | ) 52 | 53 | def forward(self, x): 54 | if self.use_res_connect: 55 | return x + self.conv(x) 56 | else: 57 | return self.conv(x) 58 | 59 | 60 | class MobileNetV2(nn.Sequential): 61 | def __init__(self, width_mult=1.0, used_layers=[3, 5, 7]): 62 | super(MobileNetV2, self).__init__() 63 | 64 | self.interverted_residual_setting = [ 65 | # t, c, n, s 66 | [1, 16, 1, 1, 1], 67 | [6, 24, 2, 2, 1], 68 | [6, 32, 3, 2, 1], 69 | [6, 64, 4, 2, 1], 70 | [6, 96, 3, 1, 1], 71 | [6, 160, 3, 2, 1], 72 | [6, 320, 1, 1, 1], 73 | ] 74 | # 0,2,3,4,6 75 | 76 | self.interverted_residual_setting = [ 77 | # t, c, n, s 78 | [1, 16, 1, 1, 1], 79 | [6, 24, 2, 2, 1], 80 | [6, 32, 3, 2, 1], 81 | [6, 64, 4, 1, 2], 82 | [6, 96, 3, 1, 2], 83 | [6, 160, 3, 1, 4], 84 | [6, 320, 1, 1, 4], 85 | ] 86 | 87 | self.channels = [24, 32, 96, 320] 88 | self.channels = [int(c * width_mult) for c in self.channels] 89 | 90 | input_channel = int(32 * width_mult) 91 | self.last_channel = int(1280 * width_mult) \ 92 | if width_mult > 1.0 else 1280 93 | 94 | self.add_module('layer0', conv_bn(3, input_channel, 2, 0)) 95 | 96 | last_dilation = 1 97 | 98 | self.used_layers = used_layers 99 | 100 | for idx, (t, c, n, s, d) in \ 101 | enumerate(self.interverted_residual_setting, start=1): 102 | output_channel = int(c * width_mult) 103 | 104 | layers = [] 105 | 106 | for i in range(n): 107 | if i == 0: 108 | if d == last_dilation: 109 | dd = d 110 | else: 111 | dd = max(d // 2, 1) 112 | layers.append(InvertedResidual(input_channel, 113 | output_channel, s, t, dd)) 114 | else: 115 | layers.append(InvertedResidual(input_channel, 116 | output_channel, 1, t, d)) 117 | input_channel = output_channel 118 | 119 | last_dilation = d 120 | 121 | self.add_module('layer%d' % (idx), nn.Sequential(*layers)) 122 | 123 | def forward(self, x): 124 | outputs = [] 125 | for idx in range(8): 126 | name = "layer%d" % idx 127 | x = getattr(self, name)(x) 128 | outputs.append(x) 129 | p0, p1, p2, p3, p4 = [outputs[i] for i in [1, 2, 3, 5, 7]] 130 | out = [outputs[i] for i in self.used_layers] 131 | return out 132 | 133 | 134 | def mobilenetv2(**kwargs): 135 | model = MobileNetV2(**kwargs) 136 | return model 137 | 138 | 139 | if __name__ == '__main__': 140 | net = mobilenetv2() 141 | 142 | print(net) 143 | 144 | from torch.autograd import Variable 145 | tensor = Variable(torch.Tensor(1, 3, 255, 255)).cuda() 146 | 147 | net = net.cuda() 148 | 149 | out = net(tensor) 150 | 151 | for i, p in enumerate(out): 152 | print(i, p.size()) 153 | -------------------------------------------------------------------------------- /pysot/models/head/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from pysot.models.head.mask import MaskCorr, Refine 9 | from pysot.models.head.rpn import UPChannelRPN, DepthwiseRPN, MultiRPN 10 | # from pysot.models.head.rpn import DeformableRPN, MultiDeformableRPN 11 | 12 | RPNS = { 13 | 'UPChannelRPN': UPChannelRPN, 14 | 'DepthwiseRPN': DepthwiseRPN, 15 | 'MultiRPN': MultiRPN, 16 | } 17 | 18 | MASKS = { 19 | 'MaskCorr': MaskCorr, 20 | } 21 | 22 | REFINE = { 23 | 'Refine': Refine, 24 | } 25 | 26 | def get_rpn_head(name, **kwargs): 27 | return RPNS[name](**kwargs) 28 | 29 | 30 | def get_mask_head(name, **kwargs): 31 | return MASKS[name](**kwargs) 32 | 33 | def get_refine_head(name): 34 | return REFINE[name]() 35 | -------------------------------------------------------------------------------- /pysot/models/head/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from pysot.models.head.rpn import DepthwiseXCorr 12 | from pysot.core.xcorr import xcorr_depthwise 13 | 14 | 15 | class MaskCorr(DepthwiseXCorr): 16 | def __init__(self, in_channels, hidden, out_channels, 17 | fused='none', kernel_size=3, hidden_kernel_size=5): 18 | super(MaskCorr, self).__init__(in_channels, hidden, 19 | out_channels, fused, 20 | kernel_size, hidden_kernel_size) 21 | 22 | def forward(self, kernel, search): 23 | kernel = self.conv_kernel(kernel) 24 | search = self.conv_search(search) 25 | feature = xcorr_depthwise(search, kernel) 26 | out = self.head(feature) 27 | return out, feature 28 | 29 | 30 | class Refine(nn.Module): 31 | def __init__(self): 32 | super(Refine, self).__init__() 33 | self.v0 = nn.Sequential( 34 | nn.Conv2d(64, 16, 3, padding=1), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(16, 4, 3, padding=1), 37 | nn.ReLU(inplace=True), 38 | ) 39 | self.v1 = nn.Sequential( 40 | nn.Conv2d(256, 64, 3, padding=1), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(64, 16, 3, padding=1), 43 | nn.ReLU(inplace=True), 44 | ) 45 | self.v2 = nn.Sequential( 46 | nn.Conv2d(512, 128, 3, padding=1), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(128, 32, 3, padding=1), 49 | nn.ReLU(inplace=True), 50 | ) 51 | self.h2 = nn.Sequential( 52 | nn.Conv2d(32, 32, 3, padding=1), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(32, 32, 3, padding=1), 55 | nn.ReLU(inplace=True), 56 | ) 57 | self.h1 = nn.Sequential( 58 | nn.Conv2d(16, 16, 3, padding=1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(16, 16, 3, padding=1), 61 | nn.ReLU(inplace=True), 62 | ) 63 | self.h0 = nn.Sequential( 64 | nn.Conv2d(4, 4, 3, padding=1), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(4, 4, 3, padding=1), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | self.deconv = nn.ConvTranspose2d(256, 32, 15, 15) 71 | self.post0 = nn.Conv2d(32, 16, 3, padding=1) 72 | self.post1 = nn.Conv2d(16, 4, 3, padding=1) 73 | self.post2 = nn.Conv2d(4, 1, 3, padding=1) 74 | 75 | def forward(self, f, corr_feature, pos): 76 | p0 = F.pad(f[0], [16, 16, 16, 16])[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61] 77 | p1 = F.pad(f[1], [8, 8, 8, 8])[:, :, 2*pos[0]:2*pos[0]+31, 2*pos[1]:2*pos[1]+31] 78 | p2 = F.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0]+15, 
pos[1]:pos[1]+15] 79 | 80 | p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1) 81 | 82 | out = self.deconv(p3) 83 | out = self.post0(F.upsample(self.h2(out) + self.v2(p2), size=(31, 31))) 84 | out = self.post1(F.upsample(self.h1(out) + self.v1(p1), size=(61, 61))) 85 | out = self.post2(F.upsample(self.h0(out) + self.v0(p0), size=(127, 127))) 86 | out = out.view(-1, 127*127) 87 | return out 88 | -------------------------------------------------------------------------------- /pysot/models/head/rpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from pysot.core.xcorr import xcorr_fast, xcorr_depthwise 13 | from pysot.models.init_weight import init_weights 14 | 15 | class RPN(nn.Module): 16 | def __init__(self): 17 | super(RPN, self).__init__() 18 | 19 | def forward(self, z_f, x_f): 20 | raise NotImplementedError 21 | 22 | class UPChannelRPN(RPN): 23 | def __init__(self, anchor_num=5, feature_in=256): 24 | super(UPChannelRPN, self).__init__() 25 | 26 | cls_output = 2 * anchor_num 27 | loc_output = 4 * anchor_num 28 | 29 | self.template_cls_conv = nn.Conv2d(feature_in, 30 | feature_in * cls_output, kernel_size=3) 31 | self.template_loc_conv = nn.Conv2d(feature_in, 32 | feature_in * loc_output, kernel_size=3) 33 | 34 | self.search_cls_conv = nn.Conv2d(feature_in, 35 | feature_in, kernel_size=3) 36 | self.search_loc_conv = nn.Conv2d(feature_in, 37 | feature_in, kernel_size=3) 38 | 39 | self.loc_adjust = nn.Conv2d(loc_output, loc_output, kernel_size=1) 40 | 41 | 42 | def forward(self, z_f, x_f): 43 | cls_kernel = self.template_cls_conv(z_f) 44 | loc_kernel = self.template_loc_conv(z_f) 45 | 46 | cls_feature = self.search_cls_conv(x_f) 47 | loc_feature = self.search_loc_conv(x_f) 48 | 49 | cls = xcorr_fast(cls_feature, cls_kernel) 50 | loc = self.loc_adjust(xcorr_fast(loc_feature, loc_kernel)) 51 | return cls, loc 52 | 53 | class DepthwiseXCorr(nn.Module): 54 | def __init__(self, in_channels, hidden, out_channels, fused='none', kernel_size=3, hidden_kernel_size=5): 55 | super(DepthwiseXCorr, self).__init__() 56 | self.fused = fused 57 | self.conv_kernel = nn.Sequential( 58 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 59 | nn.BatchNorm2d(hidden), 60 | nn.ReLU(inplace=True), 61 | ) 62 | self.conv_search = nn.Sequential( 63 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 64 | nn.BatchNorm2d(hidden), 65 | nn.ReLU(inplace=True), 66 | ) 67 | if self.fused == 'con': 68 | self.head = nn.Sequential( 69 | nn.Conv2d(hidden * 2, hidden, kernel_size=1, bias=False), 70 | nn.BatchNorm2d(hidden), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(hidden, out_channels, kernel_size=1) 73 | ) 74 | else: 75 | self.head = nn.Sequential( 76 | nn.Conv2d(hidden, hidden, kernel_size=1, bias=False), 77 | nn.BatchNorm2d(hidden), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(hidden, out_channels, kernel_size=1) 80 | ) 81 | if self.fused != 'none': 82 | self.conv_raw = nn.Conv2d(hidden, hidden, kernel_size=5) 83 | 84 | def forward(self, kernel, search): 85 | kernel = self.conv_kernel(kernel) 86 | search = self.conv_search(search) 87 | feature = xcorr_depthwise(search, kernel) 88 | if self.fused != 'none': 89 | raw = 
self.conv_raw(search) 90 | if self.fused == 'con': 91 | feature = torch.cat((feature, raw), dim=1) 92 | elif self.fused == 'mod': 93 | feature = torch.matmul(feature, raw) 94 | elif self.fused == 'avg': 95 | feature = feature + raw 96 | else: 97 | raise NotImplementedError() 98 | out = self.head(feature) 99 | return out 100 | 101 | class DepthwiseRPN(RPN): 102 | def __init__(self, anchor_num=5, in_channels=256, out_channels=256, fused='none'): 103 | super(DepthwiseRPN, self).__init__() 104 | self.cls = DepthwiseXCorr(in_channels, out_channels, 2 * anchor_num, fused) 105 | self.loc = DepthwiseXCorr(in_channels, out_channels, 4 * anchor_num, fused) 106 | 107 | def forward(self, z_f, x_f): 108 | cls = self.cls(z_f, x_f) 109 | loc = self.loc(z_f, x_f) 110 | return cls, loc 111 | 112 | class MultiRPN(RPN): 113 | def __init__(self, anchor_num, in_channels, weighted=False, fused='none'): 114 | super(MultiRPN, self).__init__() 115 | self.weighted = weighted 116 | for i in range(len(in_channels)): 117 | self.add_module('rpn'+str(i+2), 118 | DepthwiseRPN(anchor_num, in_channels[i], in_channels[i], fused)) 119 | if self.weighted: 120 | self.cls_weight = nn.Parameter(torch.ones(len(in_channels))) 121 | self.loc_weight = nn.Parameter(torch.ones(len(in_channels))) 122 | 123 | def forward(self, z_fs, x_fs): 124 | cls = [] 125 | loc = [] 126 | for idx, (z_f, x_f) in enumerate(zip(z_fs, x_fs), start=2): 127 | rpn = getattr(self, 'rpn'+str(idx)) 128 | c, l = rpn(z_f, x_f) 129 | cls.append(c) 130 | loc.append(l) 131 | 132 | if self.weighted: 133 | cls_weight = F.softmax(self.cls_weight, 0) 134 | loc_weight = F.softmax(self.loc_weight, 0) 135 | 136 | def avg(lst): 137 | return sum(lst) / len(lst) 138 | 139 | def weighted_avg(lst, weight): 140 | s = 0 141 | for i in range(len(weight)): 142 | s += lst[i] * weight[i] 143 | return s 144 | 145 | if self.weighted: 146 | return weighted_avg(cls, cls_weight), weighted_avg(loc, loc_weight) 147 | else: 148 | return avg(cls), avg(loc) -------------------------------------------------------------------------------- /pysot/models/init_weight.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def init_weights(model): 5 | for m in model.modules(): 6 | if isinstance(m, nn.Conv2d): 7 | nn.init.kaiming_normal_(m.weight.data, 8 | mode='fan_out', 9 | nonlinearity='relu') 10 | elif isinstance(m, nn.BatchNorm2d): 11 | m.weight.data.fill_(1) 12 | m.bias.data.zero_() 13 | -------------------------------------------------------------------------------- /pysot/models/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
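# Loss helpers used by the model builder. select_cross_entropy_loss feeds `pred`
# straight into F.nll_loss, so it expects log-probabilities (the classification
# output after a log-softmax) and weights the positive- and negative-anchor
# terms equally (0.5 each). weight_l1_loss is a per-anchor L1 loss on the 4 box
# offsets, weighted by `loss_weight` and normalised by batch size.
# Assumed shapes, inferred from the reshapes below (not stated in this file):
#   pred_loc:    (b, 4*k, h, w)   raw regression output, k anchors per location
#   label_loc:   (b, 4, k, h, w)  regression targets
#   loss_weight: (b, k, h, w)     typically non-zero only for positive anchors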
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def get_cls_loss(pred, label, select): 13 | if len(select.size()) == 0: 14 | return 0 15 | pred = torch.index_select(pred, 0, select) 16 | label = torch.index_select(label, 0, select) 17 | return F.nll_loss(pred, label) 18 | 19 | 20 | def select_cross_entropy_loss(pred, label): 21 | pred = pred.view(-1, 2) 22 | label = label.view(-1) 23 | pos = label.data.eq(1).nonzero().squeeze().cuda() 24 | neg = label.data.eq(0).nonzero().squeeze().cuda() 25 | loss_pos = get_cls_loss(pred, label, pos) 26 | loss_neg = get_cls_loss(pred, label, neg) 27 | return loss_pos * 0.5 + loss_neg * 0.5 28 | 29 | 30 | def weight_l1_loss(pred_loc, label_loc, loss_weight): 31 | b, _, sh, sw = pred_loc.size() 32 | pred_loc = pred_loc.view(b, 4, -1, sh, sw) 33 | diff = (pred_loc - label_loc).abs() 34 | diff = diff.sum(dim=1).view(b, -1, sh, sw) 35 | loss = diff * loss_weight 36 | return loss.sum().div(b) 37 | -------------------------------------------------------------------------------- /pysot/models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from pysot.models.neck.neck import AdjustLayer, AdjustAllLayer 13 | 14 | NECKS = { 15 | 'AdjustLayer': AdjustLayer, 16 | 'AdjustAllLayer': AdjustAllLayer 17 | } 18 | 19 | def get_neck(name, **kwargs): 20 | return NECKS[name](**kwargs) 21 | -------------------------------------------------------------------------------- /pysot/models/neck/neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
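# AdjustLayer is a 1x1 conv + BatchNorm that projects a backbone feature map to
# the channel count expected by the heads; when the spatial size is below 20
# (i.e. the template branch) it also centre-crops a 7x7 window (indices 4:11,
# the centre of a 15x15 map). AdjustAllLayer applies one AdjustLayer per
# backbone level, registered as downsample2, downsample3, ... to mirror the
# layer numbering. Illustrative instantiation -- the channel values are the
# usual SiamRPN++/ResNet-50 ones, assumed here rather than read from this file:
#   neck = AdjustAllLayer(in_channels=[512, 1024, 2048],
#                         out_channels=[256, 256, 256])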
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class AdjustLayer(nn.Module): 12 | def __init__(self, in_channels, out_channels): 13 | super(AdjustLayer, self).__init__() 14 | self.downsample = nn.Sequential( 15 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 16 | nn.BatchNorm2d(out_channels), 17 | ) 18 | 19 | def forward(self, x): 20 | x = self.downsample(x) 21 | if x.size(3) < 20: 22 | l = 4 23 | r = l + 7 24 | x = x[:, :, l:r, l:r] 25 | return x 26 | 27 | 28 | class AdjustAllLayer(nn.Module): 29 | def __init__(self, in_channels, out_channels): 30 | super(AdjustAllLayer, self).__init__() 31 | self.num = len(out_channels) 32 | if self.num == 1: 33 | self.downsample = AdjustLayer(in_channels[0], out_channels[0]) 34 | else: 35 | for i in range(self.num): 36 | self.add_module('downsample'+str(i+2), 37 | AdjustLayer(in_channels[i], out_channels[i])) 38 | 39 | def forward(self, features): 40 | if self.num == 1: 41 | return self.downsample(features) 42 | else: 43 | out = [] 44 | for i in range(self.num): 45 | adj_layer = getattr(self, 'downsample'+str(i+2)) 46 | out.append(adj_layer(features[i])) 47 | return out 48 | -------------------------------------------------------------------------------- /pysot/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/tracker/__init__.py -------------------------------------------------------------------------------- /pysot/tracker/base_tracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
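# SiameseTracker.get_subwindow crops a square patch of side `original_sz`
# centred at `pos`, pads any region that falls outside the image with the
# per-channel mean colour `avg_chans`, resizes the result to
# `model_sz` x `model_sz`, and returns a (1, 3, model_sz, model_sz) float32
# tensor for a colour frame (moved to the GPU when cfg.CUDA is set).
# Illustrative call, with the exemplar size given only as an example value:
#   z_crop = self.get_subwindow(img, self.center_pos,
#                               cfg.TRACK.EXEMPLAR_SIZE,   # e.g. 127
#                               round(s_z), self.channel_average)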
2 | # Modified by Jinghao Zhou 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | import cv2 10 | import numpy as np 11 | import torch 12 | 13 | from pysot.core.config import cfg 14 | 15 | 16 | class BaseTracker(object): 17 | 18 | def init(self, img, bbox): 19 | raise NotImplementedError 20 | 21 | def track(self, img): 22 | raise NotImplementedError 23 | 24 | 25 | class SiameseTracker(BaseTracker): 26 | def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans): 27 | 28 | if isinstance(pos, float): 29 | pos = [pos, pos] 30 | sz = original_sz 31 | im_sz = im.shape 32 | c = (original_sz + 1) / 2 33 | # context_xmin = round(pos[0] - c) # py2 and py3 round 34 | context_xmin = np.floor(pos[0] - c + 0.5) 35 | context_xmax = context_xmin + sz - 1 36 | # context_ymin = round(pos[1] - c) 37 | context_ymin = np.floor(pos[1] - c + 0.5) 38 | context_ymax = context_ymin + sz - 1 39 | left_pad = int(max(0., -context_xmin)) 40 | top_pad = int(max(0., -context_ymin)) 41 | right_pad = int(max(0., context_xmax - im_sz[1] + 1)) 42 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1)) 43 | 44 | context_xmin = context_xmin + left_pad 45 | context_xmax = context_xmax + left_pad 46 | context_ymin = context_ymin + top_pad 47 | context_ymax = context_ymax + top_pad 48 | 49 | r, c, k = im.shape 50 | if any([top_pad, bottom_pad, left_pad, right_pad]): 51 | size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k) 52 | te_im = np.zeros(size, np.uint8) 53 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im 54 | if top_pad: 55 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans 56 | if bottom_pad: 57 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans 58 | if left_pad: 59 | te_im[:, 0:left_pad, :] = avg_chans 60 | if right_pad: 61 | te_im[:, c + left_pad:, :] = avg_chans 62 | im_patch = te_im[int(context_ymin):int(context_ymax + 1), 63 | int(context_xmin):int(context_xmax + 1), :] 64 | else: 65 | im_patch = im[int(context_ymin):int(context_ymax + 1), 66 | int(context_xmin):int(context_xmax + 1), :] 67 | 68 | if not np.array_equal(model_sz, original_sz): 69 | im_patch = cv2.resize(im_patch, (model_sz, model_sz)) 70 | im_patch = im_patch.transpose(2, 0, 1) 71 | im_patch = im_patch[np.newaxis, :, :, :] 72 | im_patch = im_patch.astype(np.float32) 73 | im_patch = torch.from_numpy(im_patch) 74 | if cfg.CUDA: 75 | im_patch = im_patch.cuda() 76 | return im_patch 77 | -------------------------------------------------------------------------------- /pysot/tracker/classifier/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/tracker/classifier/__init__.py -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensorlist import TensorList 2 | from .tensordict import TensorDict -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | def normalize(score): 5 | score = (score - np.min(score)) / (np.max(score) - np.min(score)) 6 | return score 7 | 8 | def 
normfun(x, mu, sigma): 9 | pdf = np.exp(-((x - mu)**2) / (2* sigma**2)) 10 | return pdf 11 | 12 | def generate_xy_attention(center, size): 13 | 14 | a = np.linspace(-size//2+1, size//2, size) 15 | x = - normfun(a, center[1], 10).reshape((size,1)) + 2 16 | y = - normfun(a, center[0], 10).reshape((1,size)) + 2 17 | z = normalize(1. / np.dot(np.abs(x), np.abs(y))) 18 | return z 19 | 20 | if __name__ == '__main__': 21 | generate_xy_attention([0,0], 31) -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | import cv2 as cv 6 | from .preprocessing import numpy_to_torch, torch_to_numpy 7 | 8 | 9 | class Transform: 10 | """Base data augmentation transform class.""" 11 | 12 | def __init__(self, output_sz = None, shift = None): 13 | self.output_sz = output_sz 14 | self.shift = (0,0) if shift is None else shift 15 | 16 | def __call__(self, image): 17 | raise NotImplementedError 18 | 19 | def crop_to_output(self, image): 20 | if isinstance(image, torch.Tensor): 21 | imsz = image.shape[2:] 22 | if self.output_sz is None: 23 | pad_h = 0 24 | pad_w = 0 25 | else: 26 | pad_h = (self.output_sz[0] - imsz[0]) / 2 27 | pad_w = (self.output_sz[1] - imsz[1]) / 2 28 | 29 | pad_left = math.floor(pad_w) + self.shift[1] 30 | pad_right = math.ceil(pad_w) - self.shift[1] 31 | pad_top = math.floor(pad_h) + self.shift[0] 32 | pad_bottom = math.ceil(pad_h) - self.shift[0] 33 | 34 | return F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), 'replicate') 35 | else: 36 | raise NotImplementedError 37 | 38 | class Identity(Transform): 39 | """Identity transformation.""" 40 | def __call__(self, image): 41 | return self.crop_to_output(image) 42 | 43 | class FlipHorizontal(Transform): 44 | """Flip along horizontal axis.""" 45 | def __call__(self, image): 46 | if isinstance(image, torch.Tensor): 47 | return self.crop_to_output(image.flip((3,))) 48 | else: 49 | return np.fliplr(image) 50 | 51 | class FlipVertical(Transform): 52 | """Flip along vertical axis.""" 53 | def __call__(self, image: torch.Tensor): 54 | if isinstance(image, torch.Tensor): 55 | return self.crop_to_output(image.flip((2,))) 56 | else: 57 | return np.flipud(image) 58 | 59 | class Translation(Transform): 60 | """Translate.""" 61 | def __init__(self, translation, output_sz = None, shift = None): 62 | super().__init__(output_sz, shift) 63 | self.shift = (self.shift[0] + translation[0], self.shift[1] + translation[1]) 64 | 65 | def __call__(self, image): 66 | if isinstance(image, torch.Tensor): 67 | return self.crop_to_output(image) 68 | else: 69 | raise NotImplementedError 70 | 71 | class Scale(Transform): 72 | """Scale.""" 73 | def __init__(self, scale_factor, output_sz = None, shift = None): 74 | super().__init__(output_sz, shift) 75 | self.scale_factor = scale_factor 76 | 77 | def __call__(self, image): 78 | if isinstance(image, torch.Tensor): 79 | # Calculate new size. 
Ensure that it is even so that crop/pad becomes easier 80 | h_orig, w_orig = image.shape[2:] 81 | 82 | if h_orig != w_orig: 83 | raise NotImplementedError 84 | 85 | h_new = round(h_orig /self.scale_factor) 86 | h_new += (h_new - h_orig) % 2 87 | w_new = round(w_orig /self.scale_factor) 88 | w_new += (w_new - w_orig) % 2 89 | 90 | image_resized = F.interpolate(image, [h_new, w_new], mode='bilinear') 91 | 92 | return self.crop_to_output(image_resized) 93 | else: 94 | raise NotImplementedError 95 | 96 | 97 | class Affine(Transform): 98 | """Affine transformation.""" 99 | def __init__(self, transform_matrix, output_sz = None, shift = None): 100 | super().__init__(output_sz, shift) 101 | self.transform_matrix = transform_matrix 102 | 103 | def __call__(self, image): 104 | if isinstance(image, torch.Tensor): 105 | return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image)))) 106 | else: 107 | return cv.warpAffine(image, self.transform_matrix, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE) 108 | 109 | 110 | class Rotate(Transform): 111 | """Rotate with given angle.""" 112 | def __init__(self, angle, output_sz = None, shift = None): 113 | super().__init__(output_sz, shift) 114 | self.angle = math.pi * angle/180 115 | 116 | def __call__(self, image): 117 | if isinstance(image, torch.Tensor): 118 | return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image)))) 119 | else: 120 | c = (np.expand_dims(np.array(image.shape[:2]),1)-1)/2 121 | R = np.array([[math.cos(self.angle), math.sin(self.angle)], 122 | [-math.sin(self.angle), math.cos(self.angle)]]) 123 | H =np.concatenate([R, c - R @ c], 1) 124 | return cv.warpAffine(image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE) 125 | 126 | 127 | class Blur(Transform): 128 | """Blur with given sigma (can be axis dependent).""" 129 | def __init__(self, sigma, output_sz = None, shift = None): 130 | super().__init__(output_sz, shift) 131 | if isinstance(sigma, (float, int)): 132 | sigma = (sigma, sigma) 133 | self.sigma = sigma 134 | self.filter_size = [math.ceil(2*s) for s in self.sigma] 135 | x_coord = [torch.arange(-sz, sz+1, dtype=torch.float32) for sz in self.filter_size] 136 | self.filter = [torch.exp(-(x**2)/(2*s**2)) for x, s in zip(x_coord, self.sigma)] 137 | self.filter[0] = self.filter[0].view(1,1,-1,1) / self.filter[0].sum() 138 | self.filter[1] = self.filter[1].view(1,1,1,-1) / self.filter[1].sum() 139 | 140 | def __call__(self, image): 141 | if isinstance(image, torch.Tensor): 142 | sz = image.shape[2:] 143 | im1 = F.conv2d(image.view(-1,1,sz[0],sz[1]), self.filter[0], padding=(self.filter_size[0],0)) 144 | return self.crop_to_output(F.conv2d(im1, self.filter[1], padding=(0,self.filter_size[1])).view(1,-1,sz[0],sz[1])) 145 | else: 146 | raise NotImplementedError -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/complex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .tensorlist import tensor_operation 3 | 4 | 5 | def is_complex(a: torch.Tensor) -> bool: 6 | return a.dim() >= 4 and a.shape[-1] == 2 7 | 8 | 9 | def is_real(a: torch.Tensor) -> bool: 10 | return not is_complex(a) 11 | 12 | 13 | @tensor_operation 14 | def mult(a: torch.Tensor, b: torch.Tensor): 15 | """Pointwise complex multiplication of complex tensors.""" 16 | 17 | if is_real(a): 18 | if a.dim() >= b.dim(): 19 | raise ValueError('Incorrect dimensions.') 20 | # a is real 21 | return mult_real_cplx(a, b) 22 | if is_real(b): 23 | if 
b.dim() >= a.dim(): 24 | raise ValueError('Incorrect dimensions.') 25 | # b is real 26 | return mult_real_cplx(b, a) 27 | 28 | # Both complex 29 | c = mult_real_cplx(a[..., 0], b) 30 | c[..., 0] -= a[..., 1] * b[..., 1] 31 | c[..., 1] += a[..., 1] * b[..., 0] 32 | return c 33 | 34 | 35 | @tensor_operation 36 | def mult_conj(a: torch.Tensor, b: torch.Tensor): 37 | """Pointwise complex multiplication of complex tensors, with conjugate on b: a*conj(b).""" 38 | 39 | if is_real(a): 40 | if a.dim() >= b.dim(): 41 | raise ValueError('Incorrect dimensions.') 42 | # a is real 43 | return mult_real_cplx(a, conj(b)) 44 | if is_real(b): 45 | if b.dim() >= a.dim(): 46 | raise ValueError('Incorrect dimensions.') 47 | # b is real 48 | return mult_real_cplx(b, a) 49 | 50 | # Both complex 51 | c = mult_real_cplx(b[...,0], a) 52 | c[..., 0] += a[..., 1] * b[..., 1] 53 | c[..., 1] -= a[..., 0] * b[..., 1] 54 | return c 55 | 56 | 57 | @tensor_operation 58 | def mult_real_cplx(a: torch.Tensor, b: torch.Tensor): 59 | """Pointwise complex multiplication of real tensor a with complex tensor b.""" 60 | 61 | if is_real(b): 62 | raise ValueError('Last dimension must have length 2.') 63 | 64 | return a.unsqueeze(-1) * b 65 | 66 | 67 | @tensor_operation 68 | def div(a: torch.Tensor, b: torch.Tensor): 69 | """Pointwise complex division of complex tensors.""" 70 | 71 | if is_real(b): 72 | if b.dim() >= a.dim(): 73 | raise ValueError('Incorrect dimensions.') 74 | # b is real 75 | return div_cplx_real(a, b) 76 | 77 | return div_cplx_real(mult_conj(a, b), abs_sqr(b)) 78 | 79 | 80 | @tensor_operation 81 | def div_cplx_real(a: torch.Tensor, b: torch.Tensor): 82 | """Pointwise complex division of complex tensor a with real tensor b.""" 83 | 84 | if is_real(a): 85 | raise ValueError('Last dimension must have length 2.') 86 | 87 | return a / b.unsqueeze(-1) 88 | 89 | 90 | @tensor_operation 91 | def abs_sqr(a: torch.Tensor): 92 | """Squared absolute value.""" 93 | 94 | if is_real(a): 95 | raise ValueError('Last dimension must have length 2.') 96 | 97 | return torch.sum(a*a, -1) 98 | 99 | 100 | @tensor_operation 101 | def abs(a: torch.Tensor): 102 | """Absolute value.""" 103 | 104 | if is_real(a): 105 | raise ValueError('Last dimension must have length 2.') 106 | 107 | return torch.sqrt(abs_sqr(a)) 108 | 109 | 110 | @tensor_operation 111 | def conj(a: torch.Tensor): 112 | """Complex conjugate.""" 113 | 114 | if is_real(a): 115 | raise ValueError('Last dimension must have length 2.') 116 | 117 | # return a * torch.Tensor([1, -1], device=a.device) 118 | return complex(a[...,0], -a[...,1]) 119 | 120 | 121 | @tensor_operation 122 | def real(a: torch.Tensor): 123 | """Real part.""" 124 | 125 | if is_real(a): 126 | raise ValueError('Last dimension must have length 2.') 127 | 128 | return a[..., 0] 129 | 130 | 131 | @tensor_operation 132 | def imag(a: torch.Tensor): 133 | """Imaginary part.""" 134 | 135 | if is_real(a): 136 | raise ValueError('Last dimension must have length 2.') 137 | 138 | return a[..., 1] 139 | 140 | 141 | @tensor_operation 142 | def complex(a: torch.Tensor, b: torch.Tensor = None): 143 | """Create complex tensor from real and imaginary part.""" 144 | 145 | if b is None: 146 | b = a.new_zeros(a.shape) 147 | elif a is None: 148 | a = b.new_zeros(b.shape) 149 | 150 | return torch.cat((a.unsqueeze(-1), b.unsqueeze(-1)), -1) 151 | 152 | 153 | @tensor_operation 154 | def mtimes(a: torch.Tensor, b: torch.Tensor, conj_a=False, conj_b=False): 155 | """Complex matrix multiplication of complex tensors. 
156 | The dimensions (-3, -2) are matrix multiplied. -1 is the complex dimension.""" 157 | 158 | if is_real(a): 159 | if a.dim() >= b.dim(): 160 | raise ValueError('Incorrect dimensions.') 161 | return mtimes_real_complex(a, b, conj_b=conj_b) 162 | if is_real(b): 163 | if b.dim() >= a.dim(): 164 | raise ValueError('Incorrect dimensions.') 165 | return mtimes_complex_real(a, b, conj_a=conj_a) 166 | 167 | if not conj_a and not conj_b: 168 | return complex(torch.matmul(a[..., 0], b[..., 0]) - torch.matmul(a[..., 1], b[..., 1]), 169 | torch.matmul(a[..., 0], b[..., 1]) + torch.matmul(a[..., 1], b[..., 0])) 170 | if conj_a and not conj_b: 171 | return complex(torch.matmul(a[..., 0], b[..., 0]) + torch.matmul(a[..., 1], b[..., 1]), 172 | torch.matmul(a[..., 0], b[..., 1]) - torch.matmul(a[..., 1], b[..., 0])) 173 | if not conj_a and conj_b: 174 | return complex(torch.matmul(a[..., 0], b[..., 0]) + torch.matmul(a[..., 1], b[..., 1]), 175 | torch.matmul(a[..., 1], b[..., 0]) - torch.matmul(a[..., 0], b[..., 1])) 176 | if conj_a and conj_b: 177 | return complex(torch.matmul(a[..., 0], b[..., 0]) - torch.matmul(a[..., 1], b[..., 1]), 178 | -torch.matmul(a[..., 0], b[..., 1]) - torch.matmul(a[..., 1], b[..., 0])) 179 | 180 | 181 | @tensor_operation 182 | def mtimes_real_complex(a: torch.Tensor, b: torch.Tensor, conj_b=False): 183 | if is_real(b): 184 | raise ValueError('Incorrect dimensions.') 185 | 186 | if not conj_b: 187 | return complex(torch.matmul(a, b[..., 0]), torch.matmul(a, b[..., 1])) 188 | if conj_b: 189 | return complex(torch.matmul(a, b[..., 0]), -torch.matmul(a, b[..., 1])) 190 | 191 | 192 | @tensor_operation 193 | def mtimes_complex_real(a: torch.Tensor, b: torch.Tensor, conj_a=False): 194 | if is_real(a): 195 | raise ValueError('Incorrect dimensions.') 196 | 197 | if not conj_a: 198 | return complex(torch.matmul(a[..., 0], b), torch.matmul(a[..., 1], b)) 199 | if conj_a: 200 | return complex(torch.matmul(a[..., 0], b), -torch.matmul(a[..., 1], b)) 201 | 202 | 203 | @tensor_operation 204 | def exp_imag(a: torch.Tensor): 205 | """Complex exponential with imaginary input: e^(i*a)""" 206 | 207 | a = a.unsqueeze(-1) 208 | return torch.cat((torch.cos(a), torch.sin(a)), -1) 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/fourier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pysot.tracker.classifier.libs.complex as complex 4 | from .tensorlist import tensor_operation, TensorList 5 | 6 | 7 | @tensor_operation 8 | def rfftshift2(a: torch.Tensor): 9 | h = a.shape[2] + 2 10 | return torch.cat((a[:,:,(h-1)//2:,...], a[:,:,:h//2,...]), 2) 11 | 12 | 13 | @tensor_operation 14 | def irfftshift2(a: torch.Tensor): 15 | mid = int((a.shape[2]-1)/2) 16 | return torch.cat((a[:,:,mid:,...], a[:,:,:mid,...]), 2) 17 | 18 | 19 | @tensor_operation 20 | def cfft2(a): 21 | """Do FFT and center the low frequency component. 
22 | Always produces odd (full) output sizes.""" 23 | 24 | return rfftshift2(torch.rfft(a, 2)) 25 | 26 | 27 | @tensor_operation 28 | def cifft2(a, signal_sizes=None): 29 | """Do inverse FFT corresponding to cfft2.""" 30 | 31 | return torch.irfft(irfftshift2(a), 2, signal_sizes=signal_sizes) 32 | 33 | 34 | @tensor_operation 35 | def sample_fs(a: torch.Tensor, grid_sz: torch.Tensor = None, rescale = True): 36 | """Samples the Fourier series.""" 37 | 38 | # Size of the fourier series 39 | sz = torch.Tensor([a.shape[2], 2*a.shape[3]-1]).float() 40 | 41 | # Default grid 42 | if grid_sz is None or sz[0] == grid_sz[0] and sz[1] == grid_sz[1]: 43 | if rescale: 44 | return sz.prod().item() * cifft2(a) 45 | return cifft2(a) 46 | 47 | if sz[0] > grid_sz[0] or sz[1] > grid_sz[1]: 48 | raise ValueError("Only grid sizes that are smaller than the Fourier series size are supported.") 49 | 50 | tot_pad = (grid_sz - sz).tolist() 51 | is_even = [s.item() % 2 == 0 for s in sz] 52 | 53 | # Compute paddings 54 | pad_top = int((tot_pad[0]+1)/2) if is_even[0] else int(tot_pad[0]/2) 55 | pad_bottom = int(tot_pad[0] - pad_top) 56 | pad_right = int((tot_pad[1]+1)/2) 57 | 58 | if rescale: 59 | return grid_sz.prod().item() * cifft2(F.pad(a, (0, 0, 0, pad_right, pad_top, pad_bottom)), signal_sizes=grid_sz.long().tolist()) 60 | else: 61 | return cifft2(F.pad(a, (0, 0, 0, pad_right, pad_top, pad_bottom)), signal_sizes=grid_sz.long().tolist()) 62 | 63 | 64 | def get_frequency_coord(sz, add_complex_dim = False, device='cpu'): 65 | """Frequency coordinates.""" 66 | 67 | ky = torch.arange(-int((sz[0]-1)/2), int(sz[0]/2+1), dtype=torch.float32, device=device).view(1,1,-1,1) 68 | kx = torch.arange(0, int(sz[1]/2+1), dtype=torch.float32, device=device).view(1,1,1,-1) 69 | 70 | if add_complex_dim: 71 | ky = ky.unsqueeze(-1) 72 | kx = kx.unsqueeze(-1) 73 | 74 | return ky, kx 75 | 76 | 77 | @tensor_operation 78 | def shift_fs(a: torch.Tensor, shift: torch.Tensor): 79 | """Shift a sample a in the Fourier domain. 80 | Params: 81 | a : The fourier coefficiens of the sample. 
82 | shift : The shift to be performed normalized to the range [-pi, pi].""" 83 | 84 | if a.dim() != 5: 85 | raise ValueError('a must be the Fourier coefficients, a 5-dimensional tensor.') 86 | 87 | if shift[0] == 0 and shift[1] == 0: 88 | return a 89 | 90 | ky, kx = get_frequency_coord((a.shape[2], 2*a.shape[3]-1), device=a.device) 91 | 92 | return complex.mult(complex.mult(a, complex.exp_imag(shift[0].item() * ky)), complex.exp_imag(shift[1].item() * kx)) 93 | 94 | 95 | def sum_fs(a: TensorList) -> torch.Tensor: 96 | """Sum a list of Fourier series expansions.""" 97 | 98 | s = None 99 | mid = None 100 | 101 | for e in sorted(a, key=lambda elem: elem.shape[-3], reverse=True): 102 | if s is None: 103 | s = e.clone() 104 | mid = int((s.shape[-3] - 1) / 2) 105 | else: 106 | # Compute coordinates 107 | top = mid - int((e.shape[-3] - 1) / 2) 108 | bottom = mid + int(e.shape[-3] / 2) + 1 109 | right = e.shape[-2] 110 | 111 | # Add the data 112 | s[..., top:bottom, :right, :] += e 113 | 114 | return s 115 | 116 | 117 | def sum_fs12(a: TensorList) -> torch.Tensor: 118 | """Sum a list of Fourier series expansions.""" 119 | 120 | s = None 121 | mid = None 122 | 123 | for e in sorted(a, key=lambda elem: elem.shape[0], reverse=True): 124 | if s is None: 125 | s = e.clone() 126 | mid = int((s.shape[0] - 1) / 2) 127 | else: 128 | # Compute coordinates 129 | top = mid - int((e.shape[0] - 1) / 2) 130 | bottom = mid + int(e.shape[0] / 2) + 1 131 | right = e.shape[1] 132 | 133 | # Add the data 134 | s[top:bottom, :right, ...] += e 135 | 136 | return s 137 | 138 | 139 | @tensor_operation 140 | def inner_prod_fs(a: torch.Tensor, b: torch.Tensor): 141 | if complex.is_complex(a) and complex.is_complex(b): 142 | return 2 * (a.reshape(-1) @ b.reshape(-1)) - a[:, :, :, 0, :].reshape(-1) @ b[:, :, :, 0, :].reshape(-1) 143 | elif complex.is_real(a) and complex.is_real(b): 144 | return 2 * (a.reshape(-1) @ b.reshape(-1)) - a[:, :, :, 0].reshape(-1) @ b[:, :, :, 0].reshape(-1) 145 | else: 146 | raise NotImplementedError('Not implemented for mixed real and complex.') -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/operation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .tensorlist import tensor_operation, TensorList 4 | 5 | 6 | @tensor_operation 7 | def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, stride=1, padding=0, dilation=1, groups=1, mode=None): 8 | """Standard conv2d. 
Returns the input if weight=None.""" 9 | 10 | if weight is None: 11 | return input 12 | 13 | ind = None 14 | if mode is not None: 15 | if padding != 0: 16 | raise ValueError('Cannot input both padding and mode.') 17 | if mode == 'same': 18 | padding = (weight.shape[2]//2, weight.shape[3]//2) 19 | if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0: 20 | ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None), 21 | slice(-1) if weight.shape[3] % 2 == 0 else slice(None)) 22 | elif mode == 'valid': 23 | padding = (0, 0) 24 | elif mode == 'full': 25 | padding = (weight.shape[2]-1, weight.shape[3]-1) 26 | else: 27 | raise ValueError('Unknown mode for padding.') 28 | 29 | out = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 30 | if ind is None: 31 | return out 32 | return out[:,:,ind[0],ind[1]] 33 | 34 | 35 | @tensor_operation 36 | def conv1x1(input: torch.Tensor, weight: torch.Tensor): 37 | """Do a convolution with a 1x1 kernel weights. Implemented with matmul, which can be faster than using conv.""" 38 | 39 | if weight is None: 40 | return input 41 | 42 | return torch.matmul(weight.view(weight.shape[0], weight.shape[1]), 43 | input.view(input.shape[0], input.shape[1], -1)).view(input.shape[0], weight.shape[0], input.shape[2], input.shape[3]) 44 | 45 | @tensor_operation 46 | def spatial_attention(input: torch.Tensor, dim: int=0, keepdim: bool=True): 47 | return torch.sigmoid(torch.mean(input, dim, keepdim)) 48 | 49 | @tensor_operation 50 | def adaptive_avg_pool2d(input: torch.Tensor, shape): 51 | return F.adaptive_avg_pool2d(input, shape) 52 | 53 | @tensor_operation 54 | def sigmoid(input: torch.Tensor): 55 | return torch.sigmoid(input) 56 | 57 | @tensor_operation 58 | def softmax(input: torch.Tensor): 59 | x_shape = input.size() 60 | return F.softmax(input.reshape(x_shape[0], -1), dim=1).reshape(x_shape) 61 | 62 | @tensor_operation 63 | def matmul(a: torch.Tensor, b: torch.Tensor): 64 | return a * b.expand_as(a) -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/params.py: -------------------------------------------------------------------------------- 1 | from .tensorlist import TensorList 2 | import random 3 | 4 | 5 | class TrackerParams: 6 | """Class for tracker parameters.""" 7 | def free_memory(self): 8 | for a in dir(self): 9 | if not a.startswith('__') and hasattr(getattr(self, a), 'free_memory'): 10 | getattr(self, a).free_memory() 11 | 12 | 13 | class FeatureParams: 14 | """Class for feature specific parameters""" 15 | def __init__(self, *args, **kwargs): 16 | if len(args) > 0: 17 | raise ValueError 18 | 19 | for name, val in kwargs.items(): 20 | if isinstance(val, list): 21 | setattr(self, name, TensorList(val)) 22 | else: 23 | setattr(self, name, val) 24 | 25 | 26 | def Choice(*args): 27 | """Can be used to sample random parameter values.""" 28 | return random.choice(args) 29 | -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('TkAgg') 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def show_tensor(a: torch.Tensor, fig_num = None, title = None): 9 | """Display a 2D tensor. 10 | args: 11 | fig_num: Figure number. 12 | title: Title of figure. 
13 | """ 14 | a_np = a.squeeze().cpu().clone().detach().numpy() 15 | if a_np.ndim == 3: 16 | a_np = np.transpose(a_np, (1, 2, 0)) 17 | plt.figure(fig_num) 18 | plt.tight_layout() 19 | plt.cla() 20 | plt.imshow(a_np) 21 | plt.axis('off') 22 | plt.axis('equal') 23 | if title is not None: 24 | plt.title(title) 25 | plt.draw() 26 | plt.pause(0.001) 27 | 28 | 29 | def plot_graph(a: torch.Tensor, fig_num = None, title = None): 30 | """Plot graph. Data is a 1D tensor. 31 | args: 32 | fig_num: Figure number. 33 | title: Title of figure. 34 | """ 35 | a_np = a.squeeze().cpu().clone().detach().numpy() 36 | if a_np.ndim > 1: 37 | raise ValueError 38 | plt.figure(fig_num) 39 | # plt.tight_layout() 40 | plt.cla() 41 | plt.plot(a_np) 42 | if title is not None: 43 | plt.title(title) 44 | plt.draw() 45 | plt.pause(0.001) 46 | -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def numpy_to_torch(a: np.ndarray): 7 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) 8 | 9 | 10 | def torch_to_numpy(a: torch.Tensor): 11 | return a.squeeze(0).permute(1,2,0).numpy() 12 | 13 | 14 | def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None): 15 | """Sample an image patch. 16 | 17 | args: 18 | im: Image 19 | pos: center position of crop 20 | sample_sz: size to crop 21 | output_sz: size to resize to 22 | """ 23 | 24 | # copy and convert 25 | posl = pos.long().clone() 26 | 27 | # Compute pre-downsampling factor 28 | if output_sz is not None: 29 | resize_factor = torch.min(sample_sz.float() / output_sz.float()).item() 30 | df = int(max(int(resize_factor - 0.1), 1)) 31 | else: 32 | df = int(1) 33 | 34 | sz = sample_sz.float() / df # new size 35 | 36 | # Do downsampling 37 | if df > 1: 38 | os = posl % df # offset 39 | posl = (posl - os) / df # new position 40 | im2 = im[..., os[0].item()::df, os[1].item()::df] # downsample 41 | else: 42 | im2 = im 43 | 44 | # compute size to crop 45 | szl = torch.max(sz.round(), torch.Tensor([2])).long() 46 | 47 | # Extract top and bottom coordinates 48 | tl = posl - (szl - 1)/2 49 | br = posl + szl/2 50 | 51 | # Get image patch 52 | im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3] + 1, -tl[0].item(), br[0].item() - im2.shape[2] + 1), 'replicate') 53 | 54 | if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]): 55 | return im_patch 56 | 57 | # Resample 58 | im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear') 59 | 60 | return im_patch 61 | -------------------------------------------------------------------------------- /pysot/tracker/classifier/libs/tensordict.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | 4 | 5 | class TensorDict(OrderedDict): 6 | """Container mainly used for dicts of torch tensors. 
Extends OrderedDict with pytorch functionality.""" 7 | 8 | def concat(self, other): 9 | """Concatenates two dicts without copying internal data.""" 10 | return TensorDict(self, **other) 11 | 12 | def copy(self): 13 | return TensorDict(super(TensorDict, self).copy()) 14 | 15 | def __getattr__(self, name): 16 | if not hasattr(torch.Tensor, name): 17 | raise AttributeError('\'TensorDict\' object has not attribute \'{}\''.format(name)) 18 | 19 | def apply_attr(*args, **kwargs): 20 | return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()}) 21 | return apply_attr 22 | 23 | def attribute(self, attr: str, *args): 24 | return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()}) 25 | 26 | def apply(self, fn, *args, **kwargs): 27 | return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()}) 28 | 29 | @staticmethod 30 | def _iterable(a): 31 | return isinstance(a, (TensorDict, list)) 32 | 33 | -------------------------------------------------------------------------------- /pysot/tracker/classifier/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pysot.tracker.classifier.libs import optimization, TensorList, operation 3 | import math 4 | from pysot.core.config import cfg 5 | 6 | 7 | class FactorizedConvProblem(optimization.L2Problem): 8 | def __init__(self, training_samples: TensorList, y: TensorList, use_attention: bool, 9 | filter_reg: torch.Tensor, projection_reg, sample_weights: TensorList, 10 | projection_activation, att_activation, response_activation): 11 | 12 | self.training_samples = training_samples 13 | self.y = y 14 | self.sample_weights = sample_weights 15 | self.use_attetion = use_attention 16 | self.filter_reg = filter_reg 17 | self.projection_reg = projection_reg 18 | self.projection_activation = projection_activation 19 | self.att_activation = att_activation 20 | self.response_activation = response_activation 21 | 22 | if self.use_attetion: 23 | self.diag_M = self.filter_reg.concat(projection_reg).concat(projection_reg).concat(projection_reg) 24 | else: 25 | self.diag_M = self.filter_reg.concat(projection_reg) 26 | 27 | def __call__(self, x: TensorList): 28 | 29 | if self.use_attetion: 30 | filter = x[:1] 31 | fc2 = x[1:2] 32 | fc1 = x[2:3] 33 | P = x[3:4] 34 | else: 35 | filter = x[:len(x)//2] # w2 in paper 36 | P = x[len(x)//2:] # w1 in paper 37 | 38 | # Compression module 39 | compressed_samples = operation.conv1x1(self.training_samples, P).apply(self.projection_activation) 40 | 41 | # Attention module 42 | if self.use_attetion: 43 | if cfg.TRACK.CHANNEL_ATTENTION: 44 | global_average = operation.adaptive_avg_pool2d(compressed_samples, 1) 45 | temp_variables = operation.conv1x1(global_average, fc1).apply(self.att_activation) 46 | channel_attention = operation.sigmoid(operation.conv1x1(temp_variables, fc2)) 47 | else: 48 | channel_attention = TensorList( 49 | [torch.zeros(compressed_samples[0].size(0), compressed_samples[0].size(1), 1, 1).cuda()] 50 | ) 51 | 52 | if cfg.TRACK.SPATIAL_ATTENTION == 'none': 53 | spatial_attention = TensorList( 54 | [torch.zeros(compressed_samples[0].size(0), 1, compressed_samples[0].size(2), 55 | compressed_samples[0].size(3)).cuda()] 56 | ) 57 | elif cfg.TRACK.SPATIAL_ATTENTION == 'pool': 58 | spatial_attention = operation.spatial_attention(compressed_samples, dim=1, keepdim=True) 59 | else: 60 | raise NotImplementedError('No spatial attention Implemented') 61 | 62 | compressed_samples = 
operation.matmul(compressed_samples, spatial_attention) + \ 63 | operation.matmul(compressed_samples, channel_attention) 64 | 65 | # Filter module 66 | residuals = operation.conv2d(compressed_samples, filter, mode='same').apply(self.response_activation) 67 | residuals = residuals - self.y 68 | residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals 69 | 70 | residuals.extend(self.filter_reg.apply(math.sqrt) * filter) 71 | if self.use_attetion: 72 | residuals.extend(self.projection_reg.apply(math.sqrt) * fc2) 73 | residuals.extend(self.projection_reg.apply(math.sqrt) * fc1) 74 | residuals.extend(self.projection_reg.apply(math.sqrt) * P) 75 | 76 | return residuals 77 | 78 | 79 | def ip_input(self, a: TensorList, b: TensorList): 80 | 81 | if self.use_attetion: 82 | a_filter = a[:1] 83 | a_f2 = a[1:2] 84 | a_f1 = a[2:3] 85 | a_P = a[3:] 86 | b_filter = b[:1] 87 | b_f2 = b[1:2] 88 | b_f1 = b[2:3] 89 | b_P = b[3:] 90 | 91 | ip_out = operation.conv2d(a_filter, b_filter).view(-1) 92 | ip_out += operation.conv2d(a_f2.view(1, -1, 1, 1), b_f2.view(1, -1, 1, 1)).view(-1) 93 | ip_out += operation.conv2d(a_f1.view(1, -1, 1, 1), b_f1.view(1, -1, 1, 1)).view(-1) 94 | ip_out += operation.conv2d(a_P.view(1, -1, 1, 1), b_P.view(1, -1, 1, 1)).view(-1) 95 | 96 | return ip_out.concat(ip_out.clone()).concat(ip_out.clone()).concat(ip_out.clone()) 97 | 98 | else: 99 | num = len(a) // 2 # Number of filters 100 | a_filter = a[:num] 101 | b_filter = b[:num] 102 | a_P = a[num:] 103 | b_P = b[num:] 104 | 105 | # Filter inner product 106 | # ip_out = a_filter.reshape(-1) @ b_filter.reshape(-1) 107 | ip_out = operation.conv2d(a_filter, b_filter).view(-1) 108 | 109 | # Add projection matrix part 110 | # ip_out += a_P.reshape(-1) @ b_P.reshape(-1) 111 | ip_out += operation.conv2d(a_P.view(1, -1, 1, 1), b_P.view(1, -1, 1, 1)).view(-1) 112 | 113 | # Have independent inner products for each filter 114 | return ip_out.concat(ip_out.clone()) 115 | 116 | def M1(self, x: TensorList): 117 | # factorized convolution 118 | return x / self.diag_M 119 | 120 | class ConvProblem(optimization.L2Problem): 121 | def __init__(self, training_samples: TensorList, y: TensorList, filter_reg: torch.Tensor, sample_weights: TensorList, response_activation): 122 | self.training_samples = training_samples 123 | self.y = y 124 | self.filter_reg = filter_reg 125 | self.sample_weights = sample_weights 126 | self.response_activation = response_activation 127 | 128 | def __call__(self, x: TensorList): 129 | """ 130 | Compute residuals 131 | :param x: [filters] 132 | :return: [data_terms, filter_regularizations] 133 | """ 134 | # Do convolution and compute residuals 135 | residuals = operation.conv2d(self.training_samples, x, mode='same').apply(self.response_activation) 136 | residuals = residuals - self.y 137 | residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals 138 | 139 | # Add regularization for projection matrix 140 | residuals.extend(self.filter_reg.apply(math.sqrt) * x) 141 | 142 | return residuals 143 | 144 | def ip_input(self, a: TensorList, b: TensorList): 145 | # return a.reshape(-1) @ b.reshape(-1) 146 | # return (a * b).sum() 147 | return operation.conv2d(a, b).view(-1) 148 | -------------------------------------------------------------------------------- /pysot/tracker/siamrpnlt_tracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
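# Long-term variant of the SiamRPN tracker. When the best detection score falls
# below cfg.TRACK.CONFIDENCE_LOW the tracker enters a "lost" state: it keeps the
# previous box, searches over the larger cfg.TRACK.LOST_INSTANCE_SIZE window and
# all but disables the cosine-window penalty (influence 0.001). It leaves the
# lost state once the score rises above cfg.TRACK.CONFIDENCE_HIGH. When the
# online classifier is enabled, its response map is resized to the score map,
# fused with weight cfg.TRACK.COEE_CLASS, and the classifier is only updated
# while the tracker is not in the lost state.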
2 | # Modified by Jinghao Zhou 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | from pysot.core.config import cfg 14 | from pysot.tracker.classifier.libs.plotting import show_tensor 15 | from pysot.tracker.siamrpn_tracker import SiamRPNTracker 16 | 17 | class SiamRPNLTTracker(SiamRPNTracker): 18 | def __init__(self, model): 19 | super(SiamRPNLTTracker, self).__init__(model) 20 | self.longterm_state = False 21 | 22 | def track(self, img): 23 | w_z = self.size[0] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size) 24 | h_z = self.size[1] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size) 25 | s_z = np.sqrt(w_z * h_z) 26 | scale_z = cfg.TRACK.EXEMPLAR_SIZE / s_z 27 | 28 | if self.longterm_state: 29 | instance_size = cfg.TRACK.LOST_INSTANCE_SIZE 30 | else: 31 | instance_size = cfg.TRACK.INSTANCE_SIZE 32 | 33 | score_size = (instance_size - cfg.TRACK.EXEMPLAR_SIZE) // \ 34 | cfg.ANCHOR.STRIDE + 1 + cfg.TRACK.BASE_SIZE 35 | hanning = np.hanning(score_size) 36 | window = np.outer(hanning, hanning) 37 | window = np.tile(window.flatten(), self.anchor_num) 38 | anchors = self.generate_anchor(score_size) 39 | 40 | s_x = s_z * (instance_size / cfg.TRACK.EXEMPLAR_SIZE) 41 | x_crop = self.get_subwindow(img, self.center_pos, instance_size, 42 | round(s_x), self.channel_average) 43 | with torch.no_grad(): 44 | outputs = self.model.track(x_crop) 45 | 46 | score = self._convert_score(outputs['cls']) 47 | pred_bbox = self._convert_bbox(outputs['loc'], anchors) 48 | 49 | def change(r): 50 | return np.maximum(r, 1. / r) 51 | 52 | def sz(w, h): 53 | pad = (w + h) * 0.5 54 | return np.sqrt((w + pad) * (h + pad)) 55 | 56 | s_c = change(sz(pred_bbox[2, :], pred_bbox[3, :]) / 57 | (sz(self.size[0] * scale_z, self.size[1] * scale_z))) 58 | r_c = change((self.size[0] / self.size[1]) / 59 | (pred_bbox[2, :] / pred_bbox[3, :])) 60 | penalty = np.exp(-(r_c * s_c - 1) * cfg.TRACK.PENALTY_K) 61 | pscore = penalty * score 62 | 63 | def normalize(score): 64 | score = (score - np.min(score)) / (np.max(score) - np.min(score)) 65 | return score 66 | 67 | if cfg.TRACK.USE_CLASSIFIER: 68 | 69 | flag, s = self.classifier.track() 70 | confidence = Image.fromarray(s.detach().cpu().numpy()) 71 | confidence = np.array(confidence.resize((score_size, score_size))).flatten() 72 | pscore = pscore.reshape(5, -1) * (1 - cfg.TRACK.COEE_CLASS) + \ 73 | normalize(confidence) * cfg.TRACK.COEE_CLASS 74 | pscore = pscore.flatten() 75 | 76 | if not self.longterm_state: 77 | pscore = pscore * (1 - cfg.TRACK.WINDOW_INFLUENCE) + \ 78 | window * cfg.TRACK.WINDOW_INFLUENCE 79 | else: 80 | pscore = pscore * (1 - 0.001) + window * 0.001 81 | 82 | best_idx = np.argmax(pscore) 83 | bbox = pred_bbox[:, best_idx] / scale_z 84 | lr = penalty[best_idx] * score[best_idx] * cfg.TRACK.LR 85 | best_score = score[best_idx] 86 | if best_score >= cfg.TRACK.CONFIDENCE_LOW: 87 | cx = bbox[0] + self.center_pos[0] 88 | cy = bbox[1] + self.center_pos[1] 89 | width = self.size[0] * (1 - lr) + bbox[2] * lr 90 | height = self.size[1] * (1 - lr) + bbox[3] * lr 91 | else: 92 | cx = self.center_pos[0] 93 | cy = self.center_pos[1] 94 | width = self.size[0] 95 | height = self.size[1] 96 | 97 | self.center_pos = np.array([cx, cy]) 98 | self.size = np.array([width, height]) 99 | cx, cy, width, height = self._bbox_clip(cx, cy, width, height, img.shape[:2]) 100 | bbox = [cx - width / 2, cy - 
height / 2, width, height] 101 | 102 | if not self.longterm_state: 103 | if cfg.TRACK.USE_CLASSIFIER: 104 | self.classifier.update(bbox, scale_z, flag) 105 | 106 | if best_score < cfg.TRACK.CONFIDENCE_LOW: 107 | self.longterm_state = True 108 | elif best_score > cfg.TRACK.CONFIDENCE_HIGH: 109 | self.longterm_state = False 110 | 111 | if cfg.TRACK.USE_CLASSIFIER: 112 | return { 113 | 'bbox': bbox, 114 | 'best_score': best_score, 115 | 'flag': flag 116 | } 117 | else: 118 | return { 119 | 'bbox': bbox, 120 | 'best_score': best_score 121 | } 122 | 123 | -------------------------------------------------------------------------------- /pysot/tracker/tracker_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | # Modified by Jinghao Zhou 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | from pysot.core.config import cfg 10 | from pysot.tracker.siamfc_tracker import SiamFCTracker 11 | from pysot.tracker.siamrpn_tracker import SiamRPNTracker 12 | from pysot.tracker.siamrpnlt_tracker import SiamRPNLTTracker 13 | from pysot.tracker.siammask_tracker import SiamMaskTracker 14 | 15 | TRACKS = { 16 | 'SiamFCTracker': SiamFCTracker, 17 | 'SiamRPNTracker': SiamRPNTracker, 18 | 'SiamMaskTracker': SiamMaskTracker, 19 | 'SiamRPNLTTracker': SiamRPNLTTracker 20 | } 21 | 22 | 23 | def build_tracker(model): 24 | return TRACKS[cfg.TRACK.TYPE](model) 25 | -------------------------------------------------------------------------------- /pysot/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/utils/__init__.py -------------------------------------------------------------------------------- /pysot/utils/anchor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import math 9 | 10 | import numpy as np 11 | 12 | from pysot.utils.bbox import corner2center, center2corner 13 | 14 | 15 | class Anchors: 16 | """ 17 | This class generate anchors. 18 | """ 19 | def __init__(self, stride, ratios, scales, image_center=0, size=0): 20 | self.stride = stride 21 | self.ratios = ratios 22 | self.scales = scales 23 | self.image_center = 0 24 | self.size = 0 25 | 26 | self.anchor_num = len(self.scales) * len(self.ratios) 27 | 28 | self.anchors = None 29 | 30 | self.generate_anchors() 31 | 32 | def generate_anchors(self): 33 | """ 34 | generate anchors based on predefined configuration 35 | """ 36 | self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32) 37 | size = self.stride * self.stride 38 | count = 0 39 | for r in self.ratios: 40 | ws = int(math.sqrt(size*1. 
/ r)) 41 | hs = int(ws * r) 42 | 43 | for s in self.scales: 44 | w = ws * s 45 | h = hs * s 46 | self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:] 47 | count += 1 48 | 49 | def generate_all_anchors(self, im_c, size): 50 | """ 51 | im_c: image center 52 | size: image size 53 | """ 54 | if self.image_center == im_c and self.size == size: 55 | return False 56 | self.image_center = im_c 57 | self.size = size 58 | 59 | a0x = im_c - size // 2 * self.stride 60 | ori = np.array([a0x] * 4, dtype=np.float32) 61 | zero_anchors = self.anchors + ori 62 | 63 | x1 = zero_anchors[:, 0] 64 | y1 = zero_anchors[:, 1] 65 | x2 = zero_anchors[:, 2] 66 | y2 = zero_anchors[:, 3] 67 | 68 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), 69 | [x1, y1, x2, y2]) 70 | cx, cy, w, h = corner2center([x1, y1, x2, y2]) 71 | 72 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride 73 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride 74 | 75 | cx = cx + disp_x 76 | cy = cy + disp_y 77 | 78 | # broadcast 79 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32) 80 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h]) 81 | x1, y1, x2, y2 = center2corner([cx, cy, w, h]) 82 | 83 | self.all_anchors = (np.stack([x1, y1, x2, y2]).astype(np.float32), 84 | np.stack([cx, cy, w, h]).astype(np.float32)) 85 | return True 86 | -------------------------------------------------------------------------------- /pysot/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | 9 | class Meter(object): 10 | def __init__(self, name, val, avg): 11 | self.name = name 12 | self.val = val 13 | self.avg = avg 14 | 15 | def __repr__(self): 16 | return "{name}: {val:.6f} ({avg:.6f})".format( 17 | name=self.name, val=self.val, avg=self.avg 18 | ) 19 | 20 | def __format__(self, *tuples, **kwargs): 21 | return self.__repr__() 22 | 23 | 24 | class AverageMeter: 25 | """Computes and stores the average and current value""" 26 | def __init__(self, num=100): 27 | self.num = num 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = {} 32 | self.sum = {} 33 | self.count = {} 34 | self.history = {} 35 | 36 | def update(self, batch=1, **kwargs): 37 | val = {} 38 | for k in kwargs: 39 | val[k] = kwargs[k] / float(batch) 40 | self.val.update(val) 41 | for k in kwargs: 42 | if k not in self.sum: 43 | self.sum[k] = 0 44 | self.count[k] = 0 45 | self.history[k] = [] 46 | self.sum[k] += kwargs[k] 47 | self.count[k] += batch 48 | for _ in range(batch): 49 | self.history[k].append(val[k]) 50 | 51 | if self.num <= 0: 52 | # < 0, average all 53 | self.history[k] = [] 54 | 55 | # == 0: no average 56 | if self.num == 0: 57 | self.sum[k] = self.val[k] 58 | self.count[k] = 1 59 | 60 | elif len(self.history[k]) > self.num: 61 | pop_num = len(self.history[k]) - self.num 62 | for _ in range(pop_num): 63 | self.sum[k] -= self.history[k][0] 64 | del self.history[k][0] 65 | self.count[k] -= 1 66 | 67 | def __repr__(self): 68 | s = '' 69 | for k in self.sum: 70 | s += self.format_str(k) 71 | return s 72 | 73 | def format_str(self, attr): 74 | return "{name}: {val:.6f} ({avg:.6f}) ".format( 75 | name=attr, 76 | val=float(self.val[attr]), 77 | avg=float(self.sum[attr]) / self.count[attr]) 78 | 79 | def __getattr__(self, attr): 80 | if attr in 
self.__dict__: 81 | return super(AverageMeter, self).__getattr__(attr) 82 | if attr not in self.sum: 83 | print("invalid key '{}'".format(attr)) 84 | return Meter(attr, 0, 0) 85 | return Meter(attr, self.val[attr], self.avg(attr)) 86 | 87 | def avg(self, attr): 88 | return float(self.sum[attr]) / self.count[attr] 89 | 90 | 91 | if __name__ == '__main__': 92 | avg1 = AverageMeter(10) 93 | avg2 = AverageMeter(0) 94 | avg3 = AverageMeter(-1) 95 | 96 | for i in range(20): 97 | avg1.update(s=i) 98 | avg2.update(s=i) 99 | avg3.update(s=i) 100 | 101 | print('iter {}'.format(i)) 102 | print(avg1.s) 103 | print(avg2.s) 104 | print(avg3.s) 105 | -------------------------------------------------------------------------------- /pysot/utils/bbox.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from collections import namedtuple 9 | 10 | import numpy as np 11 | 12 | 13 | Corner = namedtuple('Corner', 'x1 y1 x2 y2') 14 | # alias 15 | BBox = Corner 16 | Center = namedtuple('Center', 'x y w h') 17 | 18 | 19 | def corner2center(corner): 20 | """ convert (x1, y1, x2, y2) to (cx, cy, w, h) 21 | Args: 22 | conrner: Corner or np.array (4*N) 23 | Return: 24 | Center or np.array (4 * N) 25 | """ 26 | if isinstance(corner, Corner): 27 | x1, y1, x2, y2 = corner 28 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) 29 | else: 30 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] 31 | x = (x1 + x2) * 0.5 32 | y = (y1 + y2) * 0.5 33 | w = x2 - x1 34 | h = y2 - y1 35 | return x, y, w, h 36 | 37 | 38 | def center2corner(center): 39 | """ convert (cx, cy, w, h) to (x1, y1, x2, y2) 40 | Args: 41 | center: Center or np.array (4 * N) 42 | Return: 43 | center or np.array (4 * N) 44 | """ 45 | if isinstance(center, Center): 46 | x, y, w, h = center 47 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) 48 | else: 49 | x, y, w, h = center[0], center[1], center[2], center[3] 50 | x1 = x - w * 0.5 51 | y1 = y - h * 0.5 52 | x2 = x + w * 0.5 53 | y2 = y + h * 0.5 54 | return x1, y1, x2, y2 55 | 56 | 57 | def IoU(rect1, rect2): 58 | """ caculate interection over union 59 | Args: 60 | rect1: (x1, y1, x2, y2) 61 | rect2: (x1, y1, x2, y2) 62 | Returns: 63 | iou 64 | """ 65 | # overlap 66 | x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3] 67 | tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3] 68 | 69 | xx1 = np.maximum(tx1, x1) 70 | yy1 = np.maximum(ty1, y1) 71 | xx2 = np.minimum(tx2, x2) 72 | yy2 = np.minimum(ty2, y2) 73 | 74 | ww = np.maximum(0, xx2 - xx1) 75 | hh = np.maximum(0, yy2 - yy1) 76 | 77 | area = (x2-x1) * (y2-y1) 78 | target_a = (tx2-tx1) * (ty2 - ty1) 79 | inter = ww * hh 80 | iou = inter / (area + target_a - inter) 81 | return iou 82 | 83 | 84 | def cxy_wh_2_rect(pos, sz): 85 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 0-index 86 | """ 87 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) 88 | 89 | 90 | def rect_2_cxy_wh(rect): 91 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 0-index 92 | """ 93 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), \ 94 | np.array([rect[2], rect[3]]) 95 | 96 | 97 | def cxy_wh_2_rect1(pos, sz): 98 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 1-index 99 | """ 100 | return np.array([pos[0]-sz[0]/2+1, pos[1]-sz[1]/2+1, sz[0], sz[1]]) 101 
| 102 | 103 | def rect1_2_cxy_wh(rect): 104 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 1-index 105 | """ 106 | return np.array([rect[0]+rect[2]/2-1, rect[1]+rect[3]/2-1]), \ 107 | np.array([rect[2], rect[3]]) 108 | 109 | 110 | def get_axis_aligned_bbox(region): 111 | """ convert region to (cx, cy, w, h) that represent by axis aligned box 112 | """ 113 | nv = region.size 114 | if nv == 8: 115 | cx = np.mean(region[0::2]) 116 | cy = np.mean(region[1::2]) 117 | x1 = min(region[0::2]) 118 | x2 = max(region[0::2]) 119 | y1 = min(region[1::2]) 120 | y2 = max(region[1::2]) 121 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * \ 122 | np.linalg.norm(region[2:4] - region[4:6]) 123 | A2 = (x2 - x1) * (y2 - y1) 124 | s = np.sqrt(A1 / A2) 125 | w = s * (x2 - x1) + 1 126 | h = s * (y2 - y1) + 1 127 | else: 128 | x = region[0] 129 | y = region[1] 130 | w = region[2] 131 | h = region[3] 132 | cx = x+w/2 133 | cy = y+h/2 134 | return cx, cy, w, h 135 | 136 | 137 | def get_min_max_bbox(region): 138 | """ convert region to (cx, cy, w, h) that represent by mim-max box 139 | """ 140 | nv = region.size 141 | if nv == 8: 142 | cx = np.mean(region[0::2]) 143 | cy = np.mean(region[1::2]) 144 | x1 = min(region[0::2]) 145 | x2 = max(region[0::2]) 146 | y1 = min(region[1::2]) 147 | y2 = max(region[1::2]) 148 | w = x2 - x1 149 | h = y2 - y1 150 | else: 151 | x = region[0] 152 | y = region[1] 153 | w = region[2] 154 | h = region[3] 155 | cx = x+w/2 156 | cy = y+h/2 157 | return cx, cy, w, h 158 | -------------------------------------------------------------------------------- /pysot/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | import socket 10 | import logging 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.distributed as dist 15 | 16 | from pysot.utils.log_helper import log_once 17 | 18 | logger = logging.getLogger('global') 19 | 20 | 21 | def average_reduce(v): 22 | if get_world_size() == 1: 23 | return v 24 | tensor = torch.cuda.FloatTensor(1) 25 | tensor[0] = v 26 | dist.all_reduce(tensor) 27 | v = tensor[0] / get_world_size() 28 | return v 29 | 30 | 31 | class DistModule(nn.Module): 32 | def __init__(self, module, bn_method=0): 33 | super(DistModule, self).__init__() 34 | self.module = module 35 | self.bn_method = bn_method 36 | if get_world_size() > 1: 37 | broadcast_params(self.module) 38 | else: 39 | self.bn_method = 0 # single proccess 40 | 41 | def forward(self, *args, **kwargs): 42 | broadcast_buffers(self.module, self.bn_method) 43 | return self.module(*args, **kwargs) 44 | 45 | def train(self, mode=True): 46 | super(DistModule, self).train(mode) 47 | self.module.train(mode) 48 | return self 49 | 50 | 51 | def broadcast_params(model): 52 | """ broadcast model parameters """ 53 | for p in model.state_dict().values(): 54 | dist.broadcast(p, 0) 55 | 56 | 57 | def broadcast_buffers(model, method=0): 58 | """ broadcast model buffers """ 59 | if method == 0: 60 | return 61 | 62 | world_size = get_world_size() 63 | 64 | for b in model._all_buffers(): 65 | if method == 1: # broadcast from main proccess 66 | dist.broadcast(b, 0) 67 | elif method == 2: # average 68 | dist.all_reduce(b) 69 | b /= world_size 70 | else: 71 | raise Exception('Invalid buffer broadcast code {}'.format(method)) 72 | 73 
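For reference, the helpers in this module are usually wired together in a training entry point roughly as below. This is a sketch only: it assumes the process was started by a distributed launcher with `RANK` exported, an NCCL-capable CUDA device, and uses a placeholder model in place of the real tracker.

```python
# Sketch of how dist_init / DistModule / reduce_gradients are typically combined.
# Assumes RANK is set in the environment and a CUDA device is available.
import torch
import torch.nn as nn
from pysot.utils.distributed import dist_init, DistModule, reduce_gradients

rank, world_size = dist_init()         # initializes the NCCL process group
model = nn.Linear(8, 2).cuda()         # placeholder model for illustration
model = DistModule(model)              # broadcasts parameters when world_size > 1
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

data = torch.randn(4, 8).cuda()
target = torch.randn(4, 2).cuda()
loss = nn.functional.mse_loss(model(data), target)

optimizer.zero_grad()
loss.backward()
reduce_gradients(model, _type='avg')   # all-reduce gradients across processes
optimizer.step()
```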
| 74 | inited = False 75 | 76 | 77 | def _dist_init(): 78 | ''' 79 | if guess right: 80 | ntasks: world_size (process num) 81 | proc_id: rank 82 | ''' 83 | rank = int(os.environ['RANK']) 84 | num_gpus = torch.cuda.device_count() 85 | torch.cuda.set_device(rank % num_gpus) 86 | dist.init_process_group(backend='nccl') 87 | world_size = dist.get_world_size() 88 | return rank, world_size 89 | 90 | 91 | def _get_local_ip(): 92 | try: 93 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 94 | s.connect(('8.8.8.8', 80)) 95 | ip = s.getsockname()[0] 96 | finally: 97 | s.close() 98 | return ip 99 | 100 | 101 | def dist_init(): 102 | global rank, world_size, inited 103 | try: 104 | rank, world_size = _dist_init() 105 | except RuntimeError as e: 106 | if 'public' in e.args[0]: 107 | logger.info(e) 108 | logger.info('Warning: use single process') 109 | rank, world_size = 0, 1 110 | else: 111 | raise RuntimeError(*e.args) 112 | inited = True 113 | return rank, world_size 114 | 115 | 116 | def get_rank(): 117 | if not inited: 118 | raise(Exception('dist not inited')) 119 | return rank 120 | 121 | 122 | def get_world_size(): 123 | if not inited: 124 | raise(Exception('dist not inited')) 125 | return world_size 126 | 127 | 128 | def reduce_gradients(model, _type='sum'): 129 | types = ['sum', 'avg'] 130 | assert _type in types, 'gradients method must be in "{}"'.format(types) 131 | log_once("gradients method is {}".format(_type)) 132 | if get_world_size() > 1: 133 | for param in model.parameters(): 134 | if param.requires_grad: 135 | dist.all_reduce(param.grad.data) 136 | if _type == 'avg': 137 | param.grad.data /= get_world_size() 138 | else: 139 | return None 140 | -------------------------------------------------------------------------------- /pysot/utils/log_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
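The logging helpers defined below are used throughout the training and testing scripts. A minimal sketch of the typical calls (the log file path is illustrative):

```python
# Minimal sketch of the logging helpers defined in this module.
import logging
from pysot.utils.log_helper import init_log, add_file_handler, log_once, print_speed

init_log('global', logging.INFO)                   # console handler, rank-aware filter
add_file_handler('global', 'train.log', logging.INFO)   # 'train.log' is a placeholder

logger = logging.getLogger('global')
logger.info('start training')
log_once('printed only once per call site')
print_speed(10, 0.5, 1000)                         # iter 10/1000 at 0.5 s/iter, logs ETA
```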
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | import logging 10 | import math 11 | import sys 12 | 13 | 14 | if hasattr(sys, 'frozen'): # support for py2exe 15 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:]) 16 | elif __file__[-4:].lower() in ['.pyc', '.pyo']: 17 | _srcfile = __file__[:-4] + '.py' 18 | else: 19 | _srcfile = __file__ 20 | _srcfile = os.path.normcase(_srcfile) 21 | 22 | logs = set() 23 | 24 | 25 | class Filter: 26 | def __init__(self, flag): 27 | self.flag = flag 28 | 29 | def filter(self, x): 30 | return self.flag 31 | 32 | 33 | class Dummy: 34 | def __init__(self, *arg, **kwargs): 35 | pass 36 | 37 | def __getattr__(self, arg): 38 | def dummy(*args, **kwargs): pass 39 | return dummy 40 | 41 | 42 | def get_format(logger, level): 43 | if 'RANK' in os.environ: 44 | rank = int(os.environ['RANK']) 45 | 46 | if level == logging.INFO: 47 | logger.addFilter(Filter(rank == 0)) 48 | else: 49 | rank = 0 50 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank) 51 | formatter = logging.Formatter(format_str) 52 | return formatter 53 | 54 | 55 | def get_format_custom(logger, level): 56 | if 'RANK' in os.environ: 57 | rank = int(os.environ['RANK']) 58 | if level == logging.INFO: 59 | logger.addFilter(Filter(rank == 0)) 60 | else: 61 | rank = 0 62 | format_str = '[%(asctime)s-rk{}-%(message)s'.format(rank) 63 | formatter = logging.Formatter(format_str) 64 | return formatter 65 | 66 | 67 | def init_log(name, level=logging.INFO, format_func=get_format): 68 | if (name, level) in logs: 69 | return 70 | logs.add((name, level)) 71 | logger = logging.getLogger(name) 72 | logger.setLevel(level) 73 | ch = logging.StreamHandler() 74 | ch.setLevel(level) 75 | formatter = format_func(logger, level) 76 | ch.setFormatter(formatter) 77 | logger.addHandler(ch) 78 | return logger 79 | 80 | 81 | def add_file_handler(name, log_file, level=logging.INFO): 82 | logger = logging.getLogger(name) 83 | fh = logging.FileHandler(log_file) 84 | fh.setFormatter(get_format(logger, level)) 85 | logger.addHandler(fh) 86 | 87 | 88 | init_log('global') 89 | 90 | 91 | def print_speed(i, i_time, n): 92 | """print_speed(index, index_time, total_iteration)""" 93 | logger = logging.getLogger('global') 94 | average_time = i_time 95 | remaining_time = (n - i) * average_time 96 | remaining_day = math.floor(remaining_time / 86400) 97 | remaining_hour = math.floor(remaining_time / 3600 - 98 | remaining_day * 24) 99 | remaining_min = math.floor(remaining_time / 60 - 100 | remaining_day * 1440 - 101 | remaining_hour * 60) 102 | logger.info('Progress: %d / %d [%d%%], Speed: %.3f s/iter, ETA %d:%02d:%02d (D:H:M)\n' % 103 | (i, n, i / n * 100, 104 | average_time, 105 | remaining_day, remaining_hour, remaining_min)) 106 | 107 | 108 | def find_caller(): 109 | def current_frame(): 110 | try: 111 | raise Exception 112 | except: 113 | return sys.exc_info()[2].tb_frame.f_back 114 | 115 | f = current_frame() 116 | if f is not None: 117 | f = f.f_back 118 | rv = "(unknown file)", 0, "(unknown function)" 119 | while hasattr(f, "f_code"): 120 | co = f.f_code 121 | filename = os.path.normcase(co.co_filename) 122 | rv = (co.co_filename, f.f_lineno, co.co_name) 123 | if filename == _srcfile: 124 | f = f.f_back 125 | continue 126 | break 127 | rv = list(rv) 128 | rv[0] = os.path.basename(rv[0]) 129 | return rv 130 | 131 | 132 | class LogOnce: 133 | def 
__init__(self): 134 | self.logged = set() 135 | self.logger = init_log('log_once', format_func=get_format_custom) 136 | 137 | def log(self, strings): 138 | fn, lineno, caller = find_caller() 139 | key = (fn, lineno, caller, strings) 140 | if key in self.logged: 141 | return 142 | self.logged.add(key) 143 | message = "{filename:s}<{caller}>#{lineno:3d}] {strings}".format( 144 | filename=fn, lineno=lineno, strings=strings, caller=caller) 145 | self.logger.info(message) 146 | 147 | 148 | once_logger = LogOnce() 149 | 150 | 151 | def log_once(strings): 152 | once_logger.log(strings) 153 | 154 | 155 | def main(): 156 | for i, lvl in enumerate([logging.DEBUG, logging.INFO, 157 | logging.WARNING, logging.ERROR, 158 | logging.CRITICAL]): 159 | log_name = str(lvl) 160 | init_log(log_name, lvl) 161 | logger = logging.getLogger(log_name) 162 | print('****cur lvl:{}'.format(lvl)) 163 | logger.debug('debug') 164 | logger.info('info') 165 | logger.warning('warning') 166 | logger.error('error') 167 | logger.critical('critiacal') 168 | 169 | 170 | if __name__ == '__main__': 171 | main() 172 | for i in range(10): 173 | log_once('xxx') 174 | -------------------------------------------------------------------------------- /pysot/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | 10 | from colorama import Fore, Style 11 | 12 | 13 | __all__ = ['commit', 'describe'] 14 | 15 | 16 | def _exec(cmd): 17 | f = os.popen(cmd, 'r', 1) 18 | return f.read().strip() 19 | 20 | 21 | def _bold(s): 22 | return "\033[1m%s\033[0m" % s 23 | 24 | 25 | def _color(s): 26 | return f'{Fore.RED}{s}{Style.RESET_ALL}' 27 | 28 | 29 | def _describe(model, lines=None, spaces=0): 30 | head = " " * spaces 31 | for name, p in model.named_parameters(): 32 | if '.' in name: 33 | continue 34 | if p.requires_grad: 35 | name = _color(name) 36 | line = "{head}- {name}".format(head=head, name=name) 37 | lines.append(line) 38 | 39 | for name, m in model.named_children(): 40 | space_num = len(name) + spaces + 1 41 | if m.training: 42 | name = _color(name) 43 | line = "{head}.{name} ({type})".format( 44 | head=head, 45 | name=name, 46 | type=m.__class__.__name__) 47 | lines.append(line) 48 | _describe(m, lines, space_num) 49 | 50 | 51 | def commit(): 52 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) 53 | cmd = "cd {}; git log | head -n1 | awk '{{print $2}}'".format(root) 54 | commit = _exec(cmd) 55 | cmd = "cd {}; git log --oneline | head -n1".format(root) 56 | commit_log = _exec(cmd) 57 | return "commit : {}\n log : {}".format(commit, commit_log) 58 | 59 | 60 | def describe(net, name=None): 61 | num = 0 62 | lines = [] 63 | if name is not None: 64 | lines.append(name) 65 | num = len(name) 66 | _describe(net, lines, num) 67 | return "\n".join(lines) 68 | -------------------------------------------------------------------------------- /pysot/utils/model_load.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
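A sketch of how the checkpoint helpers defined below are typically called. The paths are placeholders, a CUDA device is assumed, and a small placeholder network stands in for the tracker built by `pysot.models`:

```python
# Sketch: loading pretrained weights / resuming training with the helpers below.
# Paths are illustrative; a CUDA device is assumed.
import torch
import torch.nn as nn
from pysot.utils.model_load import load_pretrain, restore_from

# Placeholder network; in practice this is the tracker model built by pysot.models.
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU()).cuda()

# Start from a released snapshot (the 'module.' prefix is stripped automatically).
model = load_pretrain(model, 'snapshot/model.pth')

# Or resume a full training state, including the optimizer and epoch counter.
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
model, optimizer, epoch = restore_from(model, optimizer, 'snapshot/checkpoint_e10.pth')
```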
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import logging 9 | 10 | import torch 11 | import torch._utils 12 | try: 13 | torch._utils._rebuild_tensor_v2 14 | except AttributeError: 15 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 16 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 17 | tensor.requires_grad = requires_grad 18 | tensor._backward_hooks = backward_hooks 19 | return tensor 20 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 21 | 22 | logger = logging.getLogger('global') 23 | 24 | 25 | def check_keys(model, pretrained_state_dict): 26 | ckpt_keys = set(pretrained_state_dict.keys()) 27 | model_keys = set(model.state_dict().keys()) 28 | used_pretrained_keys = model_keys & ckpt_keys 29 | unused_pretrained_keys = ckpt_keys - model_keys 30 | missing_keys = model_keys - ckpt_keys 31 | # filter 'num_batches_tracked' 32 | missing_keys = [x for x in missing_keys 33 | if not x.endswith('num_batches_tracked')] 34 | if len(missing_keys) > 0: 35 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 36 | logger.info('missing keys:{}'.format(len(missing_keys))) 37 | if len(unused_pretrained_keys) > 0: 38 | logger.info('[Warning] unused_pretrained_keys: {}'.format( 39 | unused_pretrained_keys)) 40 | logger.info('unused checkpoint keys:{}'.format( 41 | len(unused_pretrained_keys))) 42 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 43 | assert len(used_pretrained_keys) > 0, \ 44 | 'load NONE from pretrained checkpoint' 45 | return True 46 | 47 | 48 | def remove_prefix(state_dict, prefix): 49 | ''' Old style model is stored with all names of parameters 50 | share common prefix 'module.' ''' 51 | logger.info('remove prefix \'{}\''.format(prefix)) 52 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 53 | return {f(key): value for key, value in state_dict.items()} 54 | 55 | 56 | def load_pretrain(model, pretrained_path): 57 | logger.info('load pretrained model from {}'.format(pretrained_path)) 58 | device = torch.cuda.current_device() 59 | pretrained_dict = torch.load(pretrained_path, 60 | map_location=lambda storage, loc: storage.cuda(device)) 61 | if "state_dict" in pretrained_dict.keys(): 62 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 63 | 'module.') 64 | else: 65 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 66 | 67 | try: 68 | check_keys(model, pretrained_dict) 69 | except: 70 | logger.info('[Warning]: using pretrain as features.\ 71 | Adding "features." as prefix') 72 | new_dict = {} 73 | for k, v in pretrained_dict.items(): 74 | k = 'features.' 
+ k 75 | new_dict[k] = v 76 | pretrained_dict = new_dict 77 | check_keys(model, pretrained_dict) 78 | model.load_state_dict(pretrained_dict, strict=False) 79 | return model 80 | 81 | 82 | def restore_from(model, optimizer, ckpt_path): 83 | device = torch.cuda.current_device() 84 | ckpt = torch.load(ckpt_path, 85 | map_location=lambda storage, loc: storage.cuda(device)) 86 | epoch = ckpt['epoch'] 87 | 88 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.') 89 | check_keys(model, ckpt_model_dict) 90 | model.load_state_dict(ckpt_model_dict, strict=False) 91 | 92 | check_keys(optimizer, ckpt['optimizer']) 93 | optimizer.load_state_dict(ckpt['optimizer']) 94 | return model, optimizer, epoch 95 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | yacs 3 | tqdm 4 | pyyaml 5 | matplotlib 6 | colorama 7 | cython 8 | tensorboardX 9 | filterpy 10 | pillow -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | 5 | ext_modules = [ 6 | Extension( 7 | name='toolkit.utils.region', 8 | sources=[ 9 | 'toolkit/utils/region.pyx', 10 | 'toolkit/utils/src/region.c', 11 | ], 12 | include_dirs=[ 13 | 'toolkit/utils/src' 14 | ] 15 | ) 16 | ] 17 | 18 | setup( 19 | name='toolkit', 20 | packages=['toolkit'], 21 | ext_modules=cythonize(ext_modules) 22 | ) 23 | -------------------------------------------------------------------------------- /testing_dataset/README.md: -------------------------------------------------------------------------------- 1 | # Testing dataset directory 2 | # putting your testing dataset here 3 | - [x] [VOT2016](http://www.votchallenge.net/vot2016/dataset.html) 4 | - [x] [VOT2018](http://www.votchallenge.net/vot2018/dataset.html) 5 | - [x] [VOT2018-LT](http://www.votchallenge.net/vot2018/dataset.html) 6 | - [x] [OTB100(OTB2015)](http://cvlab.hanyang.ac.kr/tracker_benchmark/datasets.html) 7 | - [x] [UAV123](https://ivul.kaust.edu.sa/Pages/Dataset-UAV123.aspx) 8 | - [x] [NFS](http://ci2cv.net/nfs/index.html) 9 | - [x] [LaSOT](https://cis.temple.edu/lasot/) 10 | - [ ] [TrackingNet (Evaluation on Server)](https://tracking-net.org) 11 | - [ ] [GOT-10k (Evaluation on Server)](http://got-10k.aitestunion.com) 12 | 13 | ## Download Dataset 14 | Download json files used in our toolkit [baidu pan](https://pan.baidu.com/s/1js0Qhykqqur7_lNRtle1tA) 15 | 16 | 1. Put CVRP13.json, OTB100.json, OTB50.json in OTB100 dataset directory (you need to copy Jogging to Jogging-1 and Jogging-2, and copy Skating2 to Skating2-1 and Skating2-2 or using softlink) 17 | 18 | The directory should have the below format 19 | 20 | | -- OTB100/ 21 | 22 | ​ | -- Basketball 23 | 24 | ​ | ...... 25 | 26 | ​ | -- Woman 27 | 28 | ​ | -- OTB100.json 29 | 30 | ​ | -- OTB50.json 31 | 32 | ​ | -- CVPR13.json 33 | 34 | 2. 
Put all other jsons in the dataset directory like in step 1 35 | -------------------------------------------------------------------------------- /toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/toolkit/__init__.py -------------------------------------------------------------------------------- /toolkit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vot import VOTDataset, VOTLTDataset 2 | from .otb import OTBDataset 3 | from .uav import UAVDataset 4 | from .lasot import LaSOTDataset 5 | from .nfs import NFSDataset 6 | from .trackingnet import TrackingNetDataset 7 | from .got10k import GOT10kDataset 8 | from .visdrone import VisDroneDataset 9 | 10 | class DatasetFactory(object): 11 | @staticmethod 12 | def create_dataset(**kwargs): 13 | """ 14 | Args: 15 | name: dataset name 'OTB2015', 'LaSOT', 'UAV123', 'NFS240', 'NFS30', 16 | 'VOT2018', 'VOT2016', 'VOT2018-LT' 17 | dataset_root: dataset root 18 | load_img: wether to load image 19 | Return: 20 | dataset 21 | """ 22 | assert 'name' in kwargs, "should provide dataset name" 23 | name = kwargs['name'] 24 | if 'OTB' in name: 25 | dataset = OTBDataset(**kwargs) 26 | elif 'LaSOT' == name: 27 | dataset = LaSOTDataset(**kwargs) 28 | elif 'UAV' in name: 29 | dataset = UAVDataset(**kwargs) 30 | elif 'NFS' in name: 31 | dataset = NFSDataset(**kwargs) 32 | elif 'VOT2018' == name or 'VOT2016' == name or 'VOT2019' == name: 33 | dataset = VOTDataset(**kwargs) 34 | elif 'VOT2018-LT' == name: 35 | dataset = VOTLTDataset(**kwargs) 36 | elif 'TrackingNet' == name: 37 | dataset = TrackingNetDataset(**kwargs) 38 | elif 'carplate' == name: 39 | dataset = TrackingNetDataset(**kwargs) 40 | elif 'GOT-10k' == name: 41 | dataset = GOT10kDataset(**kwargs) 42 | elif 'VisDrone' in name: 43 | dataset = VisDroneDataset(**kwargs) 44 | else: 45 | raise Exception("unknow dataset {}".format(kwargs['name'])) 46 | return dataset 47 | 48 | -------------------------------------------------------------------------------- /toolkit/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | class Dataset(object): 4 | def __init__(self, name, dataset_root): 5 | self.name = name 6 | self.dataset_root = dataset_root 7 | self.videos = None 8 | 9 | def __getitem__(self, idx): 10 | if isinstance(idx, str): 11 | return self.videos[idx] 12 | elif isinstance(idx, int): 13 | return self.videos[sorted(list(self.videos.keys()))[idx]] 14 | 15 | def __len__(self): 16 | return len(self.videos) 17 | 18 | def __iter__(self): 19 | keys = sorted(list(self.videos.keys())) 20 | for key in keys: 21 | yield self.videos[key] 22 | 23 | def set_tracker(self, path, tracker_names): 24 | """ 25 | Args: 26 | path: path to tracker results, 27 | tracker_names: list of tracker name 28 | """ 29 | self.tracker_path = path 30 | self.tracker_names = tracker_names 31 | # for video in tqdm(self.videos.values(), 32 | # desc='loading tacker result', ncols=100): 33 | # video.load_tracker(path, tracker_names) 34 | -------------------------------------------------------------------------------- /toolkit/datasets/got10k.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | from .dataset import Dataset 8 | from .video import Video 9 | 10 | 
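The `DatasetFactory` above is how the evaluation scripts construct benchmark datasets. A short usage sketch, assuming the dataset directory and JSON files are laid out as described in `testing_dataset/README.md` (the path below is illustrative):

```python
# Sketch: building a benchmark dataset through DatasetFactory and iterating it.
from toolkit.datasets import DatasetFactory

dataset = DatasetFactory.create_dataset(name='OTB100',
                                        dataset_root='testing_dataset/OTB100',
                                        load_img=False)

for video in dataset:                  # videos are yielded in sorted name order
    img, gt_bbox = video[0]            # first frame and its ground-truth box
    print(video.name, len(video), gt_bbox)
```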
class GOT10kVideo(Video): 11 | """ 12 | Args: 13 | name: video name 14 | root: dataset root 15 | video_dir: video directory 16 | init_rect: init rectangle 17 | img_names: image names 18 | gt_rect: groundtruth rectangle 19 | attr: attribute of video 20 | """ 21 | def __init__(self, name, root, video_dir, init_rect, img_names, 22 | gt_rect, attr, load_img=False): 23 | super(GOT10kVideo, self).__init__(name, root, video_dir, 24 | init_rect, img_names, gt_rect, attr, load_img) 25 | 26 | # def load_tracker(self, path, tracker_names=None): 27 | # """ 28 | # Args: 29 | # path(str): path to result 30 | # tracker_name(list): name of tracker 31 | # """ 32 | # if not tracker_names: 33 | # tracker_names = [x.split('/')[-1] for x in glob(path) 34 | # if os.path.isdir(x)] 35 | # if isinstance(tracker_names, str): 36 | # tracker_names = [tracker_names] 37 | # # self.pred_trajs = {} 38 | # for name in tracker_names: 39 | # traj_file = os.path.join(path, name, self.name+'.txt') 40 | # if os.path.exists(traj_file): 41 | # with open(traj_file, 'r') as f : 42 | # self.pred_trajs[name] = [list(map(float, x.strip().split(','))) 43 | # for x in f.readlines()] 44 | # if len(self.pred_trajs[name]) != len(self.gt_traj): 45 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name) 46 | # else: 47 | 48 | # self.tracker_names = list(self.pred_trajs.keys()) 49 | 50 | class GOT10kDataset(Dataset): 51 | """ 52 | Args: 53 | name: dataset name, should be "NFS30" or "NFS240" 54 | dataset_root, dataset root dir 55 | """ 56 | def __init__(self, name, dataset_root, load_img=False): 57 | super(GOT10kDataset, self).__init__(name, dataset_root) 58 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 59 | meta_data = json.load(f) 60 | 61 | # load videos 62 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 63 | self.videos = {} 64 | for video in pbar: 65 | pbar.set_postfix_str(video) 66 | self.videos[video] = GOT10kVideo(video, 67 | dataset_root, 68 | meta_data[video]['video_dir'], 69 | meta_data[video]['init_rect'], 70 | meta_data[video]['img_names'], 71 | meta_data[video]['gt_rect'], 72 | None) 73 | self.attr = {} 74 | self.attr['ALL'] = list(self.videos.keys()) 75 | -------------------------------------------------------------------------------- /toolkit/datasets/lasot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | from .dataset import Dataset 9 | from .video import Video 10 | 11 | class LaSOTVideo(Video): 12 | """ 13 | Args: 14 | name: video name 15 | root: dataset root 16 | video_dir: video directory 17 | init_rect: init rectangle 18 | img_names: image names 19 | gt_rect: groundtruth rectangle 20 | attr: attribute of video 21 | """ 22 | def __init__(self, name, root, video_dir, init_rect, img_names, 23 | gt_rect, attr, absent, load_img=False): 24 | super(LaSOTVideo, self).__init__(name, root, video_dir, 25 | init_rect, img_names, gt_rect, attr, load_img) 26 | self.absent = np.array(absent, np.int8) 27 | 28 | def load_tracker(self, path, tracker_names=None, store=True): 29 | """ 30 | Args: 31 | path(str): path to result 32 | tracker_name(list): name of tracker 33 | """ 34 | if not tracker_names: 35 | tracker_names = [x.split('/')[-1] for x in glob(path) 36 | if os.path.isdir(x)] 37 | if isinstance(tracker_names, str): 38 | tracker_names = [tracker_names] 39 | for name in tracker_names: 40 | traj_file = os.path.join(path, name, 
self.name+'.txt') 41 | if os.path.exists(traj_file): 42 | with open(traj_file, 'r') as f : 43 | pred_traj = [list(map(float, x.strip().split(','))) 44 | for x in f.readlines()] 45 | else: 46 | print("File not exists: ", traj_file) 47 | if self.name == 'monkey-17': 48 | pred_traj = pred_traj[:len(self.gt_traj)] 49 | if store: 50 | self.pred_trajs[name] = pred_traj 51 | else: 52 | return pred_traj 53 | self.tracker_names = list(self.pred_trajs.keys()) 54 | 55 | 56 | 57 | class LaSOTDataset(Dataset): 58 | """ 59 | Args: 60 | name: dataset name, should be 'OTB100', 'CVPR13', 'OTB50' 61 | dataset_root: dataset root 62 | load_img: wether to load all imgs 63 | """ 64 | def __init__(self, name, dataset_root, load_img=False): 65 | super(LaSOTDataset, self).__init__(name, dataset_root) 66 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 67 | meta_data = json.load(f) 68 | 69 | # load videos 70 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 71 | self.videos = {} 72 | for video in pbar: 73 | pbar.set_postfix_str(video) 74 | self.videos[video] = LaSOTVideo(video, 75 | dataset_root, 76 | meta_data[video]['video_dir'], 77 | meta_data[video]['init_rect'], 78 | meta_data[video]['img_names'], 79 | meta_data[video]['gt_rect'], 80 | meta_data[video]['attr'], 81 | meta_data[video]['absent']) 82 | 83 | # set attr 84 | attr = [] 85 | for x in self.videos.values(): 86 | attr += x.attr 87 | attr = set(attr) 88 | self.attr = {} 89 | self.attr['ALL'] = list(self.videos.keys()) 90 | for x in attr: 91 | self.attr[x] = [] 92 | for k, v in self.videos.items(): 93 | for attr_ in v.attr: 94 | self.attr[attr_].append(k) 95 | 96 | 97 | -------------------------------------------------------------------------------- /toolkit/datasets/nfs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | from .dataset import Dataset 9 | from .video import Video 10 | 11 | 12 | class NFSVideo(Video): 13 | """ 14 | Args: 15 | name: video name 16 | root: dataset root 17 | video_dir: video directory 18 | init_rect: init rectangle 19 | img_names: image names 20 | gt_rect: groundtruth rectangle 21 | attr: attribute of video 22 | """ 23 | def __init__(self, name, root, video_dir, init_rect, img_names, 24 | gt_rect, attr, load_img=False): 25 | super(NFSVideo, self).__init__(name, root, video_dir, 26 | init_rect, img_names, gt_rect, attr, load_img) 27 | 28 | # def load_tracker(self, path, tracker_names=None): 29 | # """ 30 | # Args: 31 | # path(str): path to result 32 | # tracker_name(list): name of tracker 33 | # """ 34 | # if not tracker_names: 35 | # tracker_names = [x.split('/')[-1] for x in glob(path) 36 | # if os.path.isdir(x)] 37 | # if isinstance(tracker_names, str): 38 | # tracker_names = [tracker_names] 39 | # # self.pred_trajs = {} 40 | # for name in tracker_names: 41 | # traj_file = os.path.join(path, name, self.name+'.txt') 42 | # if os.path.exists(traj_file): 43 | # with open(traj_file, 'r') as f : 44 | # self.pred_trajs[name] = [list(map(float, x.strip().split(','))) 45 | # for x in f.readlines()] 46 | # if len(self.pred_trajs[name]) != len(self.gt_traj): 47 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name) 48 | # else: 49 | 50 | # self.tracker_names = list(self.pred_trajs.keys()) 51 | 52 | class NFSDataset(Dataset): 53 | """ 54 | Args: 55 | name: dataset name, should be "NFS30" or "NFS240" 56 | dataset_root, dataset root dir 57 | 
""" 58 | def __init__(self, name, dataset_root, load_img=False): 59 | super(NFSDataset, self).__init__(name, dataset_root) 60 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 61 | meta_data = json.load(f) 62 | 63 | # load videos 64 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 65 | self.videos = {} 66 | for video in pbar: 67 | pbar.set_postfix_str(video) 68 | self.videos[video] = NFSVideo(video, 69 | dataset_root, 70 | meta_data[video]['video_dir'], 71 | meta_data[video]['init_rect'], 72 | meta_data[video]['img_names'], 73 | meta_data[video]['gt_rect'], 74 | None) 75 | 76 | self.attr = {} 77 | self.attr['ALL'] = list(self.videos.keys()) 78 | -------------------------------------------------------------------------------- /toolkit/datasets/otb.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | class OTBVideo(Video): 14 | """ 15 | Args: 16 | name: video name 17 | root: dataset root 18 | video_dir: video directory 19 | init_rect: init rectangle 20 | img_names: image names 21 | gt_rect: groundtruth rectangle 22 | attr: attribute of video 23 | """ 24 | def __init__(self, name, root, video_dir, init_rect, img_names, 25 | gt_rect, attr, load_img=False): 26 | super(OTBVideo, self).__init__(name, root, video_dir, 27 | init_rect, img_names, gt_rect, attr, load_img) 28 | 29 | def load_tracker(self, path, tracker_names=None, store=True): 30 | """ 31 | Args: 32 | path(str): path to result 33 | tracker_name(list): name of tracker 34 | """ 35 | if not tracker_names: 36 | tracker_names = [x.split('/')[-1] for x in glob(path) 37 | if os.path.isdir(x)] 38 | if isinstance(tracker_names, str): 39 | tracker_names = [tracker_names] 40 | for name in tracker_names: 41 | traj_file = os.path.join(path, name, self.name+'.txt') 42 | if not os.path.exists(traj_file): 43 | if self.name == 'FleetFace': 44 | txt_name = 'fleetface.txt' 45 | elif self.name == 'Jogging-1': 46 | txt_name = 'jogging_1.txt' 47 | elif self.name == 'Jogging-2': 48 | txt_name = 'jogging_2.txt' 49 | elif self.name == 'Skating2-1': 50 | txt_name = 'skating2_1.txt' 51 | elif self.name == 'Skating2-2': 52 | txt_name = 'skating2_2.txt' 53 | elif self.name == 'FaceOcc1': 54 | txt_name = 'faceocc1.txt' 55 | elif self.name == 'FaceOcc2': 56 | txt_name = 'faceocc2.txt' 57 | elif self.name == 'Human4-2': 58 | txt_name = 'human4_2.txt' 59 | else: 60 | txt_name = self.name[0].lower()+self.name[1:]+'.txt' 61 | traj_file = os.path.join(path, name, txt_name) 62 | if os.path.exists(traj_file): 63 | with open(traj_file, 'r') as f : 64 | pred_traj = [list(map(float, x.strip().split(','))) 65 | for x in f.readlines()] 66 | if len(pred_traj) != len(self.gt_traj): 67 | print(name, len(pred_traj), len(self.gt_traj), self.name) 68 | if store: 69 | self.pred_trajs[name] = pred_traj 70 | else: 71 | return pred_traj 72 | else: 73 | print(traj_file) 74 | self.tracker_names = list(self.pred_trajs.keys()) 75 | 76 | 77 | 78 | class OTBDataset(Dataset): 79 | """ 80 | Args: 81 | name: dataset name, should be 'OTB100', 'CVPR13', 'OTB50' 82 | dataset_root: dataset root 83 | load_img: wether to load all imgs 84 | """ 85 | def __init__(self, name, dataset_root, load_img=False): 86 | super(OTBDataset, self).__init__(name, dataset_root) 87 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 
88 | meta_data = json.load(f) 89 | 90 | # load videos 91 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 92 | self.videos = {} 93 | for video in pbar: 94 | pbar.set_postfix_str(video) 95 | self.videos[video] = OTBVideo(video, 96 | dataset_root, 97 | meta_data[video]['video_dir'], 98 | meta_data[video]['init_rect'], 99 | meta_data[video]['img_names'], 100 | meta_data[video]['gt_rect'], 101 | meta_data[video]['attr'], 102 | load_img) 103 | 104 | # set attr 105 | attr = [] 106 | for x in self.videos.values(): 107 | attr += x.attr 108 | attr = set(attr) 109 | self.attr = {} 110 | self.attr['ALL'] = list(self.videos.keys()) 111 | for x in attr: 112 | self.attr[x] = [] 113 | for k, v in self.videos.items(): 114 | for attr_ in v.attr: 115 | self.attr[attr_].append(k) 116 | -------------------------------------------------------------------------------- /toolkit/datasets/trackingnet.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | from .dataset import Dataset 9 | from .video import Video 10 | 11 | class TrackingNetVideo(Video): 12 | """ 13 | Args: 14 | name: video name 15 | root: dataset root 16 | video_dir: video directory 17 | init_rect: init rectangle 18 | img_names: image names 19 | gt_rect: groundtruth rectangle 20 | attr: attribute of video 21 | """ 22 | def __init__(self, name, root, video_dir, init_rect, img_names, 23 | gt_rect, attr, load_img=False): 24 | super(TrackingNetVideo, self).__init__(name, root, video_dir, 25 | init_rect, img_names, gt_rect, attr, load_img) 26 | 27 | # def load_tracker(self, path, tracker_names=None): 28 | # """ 29 | # Args: 30 | # path(str): path to result 31 | # tracker_name(list): name of tracker 32 | # """ 33 | # if not tracker_names: 34 | # tracker_names = [x.split('/')[-1] for x in glob(path) 35 | # if os.path.isdir(x)] 36 | # if isinstance(tracker_names, str): 37 | # tracker_names = [tracker_names] 38 | # # self.pred_trajs = {} 39 | # for name in tracker_names: 40 | # traj_file = os.path.join(path, name, self.name+'.txt') 41 | # if os.path.exists(traj_file): 42 | # with open(traj_file, 'r') as f : 43 | # self.pred_trajs[name] = [list(map(float, x.strip().split(','))) 44 | # for x in f.readlines()] 45 | # if len(self.pred_trajs[name]) != len(self.gt_traj): 46 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name) 47 | # else: 48 | 49 | # self.tracker_names = list(self.pred_trajs.keys()) 50 | 51 | class TrackingNetDataset(Dataset): 52 | """ 53 | Args: 54 | name: dataset name, should be "NFS30" or "NFS240" 55 | dataset_root, dataset root dir 56 | """ 57 | def __init__(self, name, dataset_root, load_img=False): 58 | super(TrackingNetDataset, self).__init__(name, dataset_root) 59 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 60 | meta_data = json.load(f) 61 | 62 | # load videos 63 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 64 | self.videos = {} 65 | for video in pbar: 66 | pbar.set_postfix_str(video) 67 | self.videos[video] = TrackingNetVideo(video, 68 | dataset_root, 69 | meta_data[video]['video_dir'], 70 | meta_data[video]['init_rect'], 71 | meta_data[video]['img_names'], 72 | meta_data[video]['gt_rect'], 73 | None) 74 | self.attr = {} 75 | self.attr['ALL'] = list(self.videos.keys()) 76 | -------------------------------------------------------------------------------- /toolkit/datasets/uav.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from tqdm import tqdm 5 | from glob import glob 6 | 7 | from .dataset import Dataset 8 | from .video import Video 9 | 10 | class UAVVideo(Video): 11 | """ 12 | Args: 13 | name: video name 14 | root: dataset root 15 | video_dir: video directory 16 | init_rect: init rectangle 17 | img_names: image names 18 | gt_rect: groundtruth rectangle 19 | attr: attribute of video 20 | """ 21 | def __init__(self, name, root, video_dir, init_rect, img_names, 22 | gt_rect, attr, load_img=False): 23 | super(UAVVideo, self).__init__(name, root, video_dir, 24 | init_rect, img_names, gt_rect, attr, load_img) 25 | 26 | 27 | class UAVDataset(Dataset): 28 | """ 29 | Args: 30 | name: dataset name, should be 'UAV123', 'UAV20L' 31 | dataset_root: dataset root 32 | load_img: wether to load all imgs 33 | """ 34 | def __init__(self, name, dataset_root, load_img=False): 35 | super(UAVDataset, self).__init__(name, dataset_root) 36 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 37 | meta_data = json.load(f) 38 | 39 | # load videos 40 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 41 | self.videos = {} 42 | for video in pbar: 43 | pbar.set_postfix_str(video) 44 | self.videos[video] = UAVVideo(video, 45 | dataset_root, 46 | meta_data[video]['video_dir'], 47 | meta_data[video]['init_rect'], 48 | meta_data[video]['img_names'], 49 | meta_data[video]['gt_rect'], 50 | meta_data[video]['attr']) 51 | 52 | # set attr 53 | attr = [] 54 | for x in self.videos.values(): 55 | attr += x.attr 56 | attr = set(attr) 57 | self.attr = {} 58 | self.attr['ALL'] = list(self.videos.keys()) 59 | for x in attr: 60 | self.attr[x] = [] 61 | for k, v in self.videos.items(): 62 | for attr_ in v.attr: 63 | self.attr[attr_].append(k) 64 | 65 | -------------------------------------------------------------------------------- /toolkit/datasets/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import re 4 | import numpy as np 5 | import json 6 | 7 | from glob import glob 8 | 9 | class Video(object): 10 | def __init__(self, name, root, video_dir, init_rect, img_names, 11 | gt_rect, attr, load_img=False): 12 | self.name = name 13 | self.video_dir = video_dir 14 | self.init_rect = init_rect 15 | self.gt_traj = gt_rect 16 | self.attr = attr 17 | self.pred_trajs = {} 18 | self.img_names = [os.path.join(root, x) for x in img_names] 19 | self.imgs = None 20 | 21 | if load_img: 22 | self.imgs = [cv2.imread(x) for x in self.img_names] 23 | self.width = self.imgs[0].shape[1] 24 | self.height = self.imgs[0].shape[0] 25 | else: 26 | img = cv2.imread(self.img_names[0]) 27 | assert img is not None, self.img_names[0] 28 | self.width = img.shape[1] 29 | self.height = img.shape[0] 30 | 31 | def load_tracker(self, path, tracker_names=None, store=True): 32 | """ 33 | Args: 34 | path(str): path to result 35 | tracker_name(list): name of tracker 36 | """ 37 | if not tracker_names: 38 | tracker_names = [x.split('/')[-1] for x in glob(path) 39 | if os.path.isdir(x)] 40 | if isinstance(tracker_names, str): 41 | tracker_names = [tracker_names] 42 | for name in tracker_names: 43 | traj_file = os.path.join(path, name, self.name+'.txt') 44 | if os.path.exists(traj_file): 45 | with open(traj_file, 'r') as f : 46 | pred_traj = [list(map(float, x.strip().split(','))) 47 | for x in f.readlines()] 48 | if len(pred_traj) != len(self.gt_traj): 49 | print(name, 
len(pred_traj), len(self.gt_traj), self.name) 50 | if store: 51 | self.pred_trajs[name] = pred_traj 52 | else: 53 | return pred_traj 54 | else: 55 | print(traj_file) 56 | self.tracker_names = list(self.pred_trajs.keys()) 57 | 58 | def load_img(self): 59 | if self.imgs is None: 60 | self.imgs = [cv2.imread(x) for x in self.img_names] 61 | self.width = self.imgs[0].shape[1] 62 | self.height = self.imgs[0].shape[0] 63 | 64 | def free_img(self): 65 | self.imgs = None 66 | 67 | def __len__(self): 68 | return len(self.img_names) 69 | 70 | def __getitem__(self, idx): 71 | if self.imgs is None: 72 | return cv2.imread(self.img_names[idx]), self.gt_traj[idx] 73 | else: 74 | return self.imgs[idx], self.gt_traj[idx] 75 | 76 | def __iter__(self): 77 | for i in range(len(self.img_names)): 78 | if self.imgs is not None: 79 | yield self.imgs[i], self.gt_traj[i] 80 | else: 81 | yield cv2.imread(self.img_names[i]), self.gt_traj[i] 82 | 83 | def draw_box(self, roi, img, linewidth, color, name=None): 84 | """ 85 | roi: rectangle or polygon 86 | img: numpy array img 87 | linewith: line width of the bbox 88 | """ 89 | if len(roi) > 6 and len(roi) % 2 == 0: 90 | pts = np.array(roi, np.int32).reshape(-1, 1, 2) 91 | color = tuple(map(int, color)) 92 | img = cv2.polylines(img, [pts], True, color, linewidth) 93 | pt = (pts[0, 0, 0], pts[0, 0, 1]-5) 94 | if name: 95 | img = cv2.putText(img, name, pt, cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1) 96 | elif len(roi) == 4: 97 | if not np.isnan(roi[0]): 98 | roi = list(map(int, roi)) 99 | color = tuple(map(int, color)) 100 | img = cv2.rectangle(img, (roi[0], roi[1]), (roi[0]+roi[2], roi[1]+roi[3]), 101 | color, linewidth) 102 | if name: 103 | img = cv2.putText(img, name, (roi[0], roi[1]-5), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1) 104 | return img 105 | 106 | def show(self, pred_trajs={}, linewidth=2, show_name=False): 107 | """ 108 | pred_trajs: dict of pred_traj, {'tracker_name': list of traj} 109 | pred_traj should contain polygon or rectangle(x, y, width, height) 110 | linewith: line width of the bbox 111 | """ 112 | assert self.imgs is not None 113 | video = [] 114 | cv2.namedWindow(self.name, cv2.WINDOW_NORMAL) 115 | colors = {} 116 | if len(pred_trajs) == 0 and len(self.pred_trajs) > 0: 117 | pred_trajs = self.pred_trajs 118 | for i, (roi, img) in enumerate(zip(self.gt_traj, 119 | self.imgs[self.start_frame:self.end_frame+1])): 120 | img = img.copy() 121 | if len(img.shape) == 2: 122 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 123 | else: 124 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 125 | img = self.draw_box(roi, img, linewidth, (0, 255, 0), 126 | 'gt' if show_name else None) 127 | for name, trajs in pred_trajs.items(): 128 | if name not in colors: 129 | color = tuple(np.random.randint(0, 256, 3)) 130 | colors[name] = color 131 | else: 132 | color = colors[name] 133 | img = self.draw_box(trajs[0][i], img, linewidth, color, 134 | name if show_name else None) 135 | cv2.putText(img, str(i+self.start_frame), (5, 20), 136 | cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (255, 255, 0), 2) 137 | cv2.imshow(self.name, img) 138 | cv2.waitKey(40) 139 | video.append(img.copy()) 140 | return video 141 | -------------------------------------------------------------------------------- /toolkit/datasets/visdrone.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from .dataset import Dataset 9 | from .video import Video 10 | 11 | 12 | 
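The `Video` helpers above can also be used directly, for example to overlay the ground-truth box on each frame. A minimal sketch (the dataset name and path are illustrative):

```python
# Sketch: drawing ground-truth boxes with Video.draw_box on one sequence.
import cv2
from toolkit.datasets import DatasetFactory

dataset = DatasetFactory.create_dataset(name='OTB100',
                                        dataset_root='testing_dataset/OTB100')
video = dataset[0]                                   # first video by sorted name

for frame, gt_bbox in video:                         # yields (image, ground-truth box)
    frame = video.draw_box(gt_bbox, frame, linewidth=2, color=(0, 255, 0), name='gt')
    cv2.imshow(video.name, frame)
    if cv2.waitKey(40) & 0xFF == 27:                 # press Esc to stop early
        break
cv2.destroyAllWindows()
```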
class VisDroneVideo(Video): 13 | """ 14 | Args: 15 | name: video name 16 | root: dataset root 17 | video_dir: video directory 18 | init_rect: init rectangle 19 | img_names: image names 20 | gt_rect: groundtruth rectangle 21 | attr: attribute of video 22 | """ 23 | def __init__(self, name, root, video_dir, init_rect, img_names, 24 | gt_rect, attr, load_img=False): 25 | super(VisDroneVideo, self).__init__(name, root, video_dir, 26 | init_rect, img_names, gt_rect, attr, load_img) 27 | 28 | def load_tracker(self, path, tracker_names=None, store=True): 29 | """ 30 | Args: 31 | path(str): path to result 32 | tracker_name(list): name of tracker 33 | """ 34 | if not tracker_names: 35 | tracker_names = [x.split('/')[-1] for x in glob(path) 36 | if os.path.isdir(x)] 37 | if isinstance(tracker_names, str): 38 | tracker_names = [tracker_names] 39 | for name in tracker_names: 40 | traj_file = os.path.join(path, name, self.name+'.txt') 41 | # if not os.path.exists(traj_file): 42 | # if self.name == 'FleetFace': 43 | # txt_name = 'fleetface.txt' 44 | # elif self.name == 'Jogging-1': 45 | # txt_name = 'jogging_1.txt' 46 | # elif self.name == 'Jogging-2': 47 | # txt_name = 'jogging_2.txt' 48 | # elif self.name == 'Skating2-1': 49 | # txt_name = 'skating2_1.txt' 50 | # elif self.name == 'Skating2-2': 51 | # txt_name = 'skating2_2.txt' 52 | # elif self.name == 'FaceOcc1': 53 | # txt_name = 'faceocc1.txt' 54 | # elif self.name == 'FaceOcc2': 55 | # txt_name = 'faceocc2.txt' 56 | # elif self.name == 'Human4-2': 57 | # txt_name = 'human4_2.txt' 58 | # else: 59 | # txt_name = self.name[0].lower()+self.name[1:]+'.txt' 60 | # traj_file = os.path.join(path, name, txt_name) 61 | if os.path.exists(traj_file): 62 | with open(traj_file, 'r') as f : 63 | pred_traj = [list(map(float, x.strip().split(','))) 64 | for x in f.readlines()] 65 | if len(pred_traj) != len(self.gt_traj): 66 | print(name, len(pred_traj), len(self.gt_traj), self.name) 67 | if store: 68 | self.pred_trajs[name] = pred_traj 69 | else: 70 | return pred_traj 71 | else: 72 | print(traj_file) 73 | self.tracker_names = list(self.pred_trajs.keys()) 74 | 75 | 76 | 77 | class VisDroneDataset(Dataset): 78 | """ 79 | Args: 80 | name: dataset name, should be 'VisDrone2019' 81 | dataset_root: dataset root 82 | load_img: wether to load all imgs 83 | """ 84 | def __init__(self, name, dataset_root, load_img=False): 85 | super(VisDroneDataset, self).__init__(name, dataset_root) 86 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 87 | meta_data = json.load(f) 88 | 89 | # load videos 90 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 91 | self.videos = {} 92 | for video in pbar: 93 | pbar.set_postfix_str(video) 94 | self.videos[video] = VisDroneVideo(video, 95 | dataset_root, 96 | meta_data[video]['video_dir'], 97 | meta_data[video]['init_rect'], 98 | meta_data[video]['img_names'], 99 | meta_data[video]['gt_rect'], 100 | meta_data[video]['attr'], 101 | load_img) 102 | 103 | # set attr 104 | attr = [] 105 | for x in self.videos.values(): 106 | if x.attr: 107 | attr += x.attr 108 | attr = set(attr) 109 | self.attr = {} 110 | self.attr['ALL'] = list(self.videos.keys()) 111 | for x in attr: 112 | self.attr[x] = [] 113 | for k, v in self.videos.items(): 114 | if v.attr: 115 | for attr_ in v.attr: 116 | self.attr[attr_].append(k) -------------------------------------------------------------------------------- /toolkit/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.ar_benchmark import AccuracyRobustnessBenchmark 2 | from .eao_benchmark import EAOBenchmark 3 | from .ope_benchmark import OPEBenchmark 4 | from .f1_benchmark import F1Benchmark 5 | -------------------------------------------------------------------------------- /toolkit/evaluation/ar_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author 3 | """ 4 | 5 | import warnings 6 | import itertools 7 | import numpy as np 8 | 9 | from colorama import Style, Fore 10 | from ..utils import calculate_failures, calculate_accuracy 11 | 12 | class AccuracyRobustnessBenchmark: 13 | """ 14 | Args: 15 | dataset: 16 | burnin: 17 | """ 18 | def __init__(self, dataset, burnin=10): 19 | self.dataset = dataset 20 | self.burnin = burnin 21 | 22 | def eval(self, eval_trackers=None): 23 | """ 24 | Args: 25 | eval_tags: list of tag 26 | eval_trackers: list of tracker name 27 | Returns: 28 | ret: dict of results 29 | """ 30 | if eval_trackers is None: 31 | eval_trackers = self.dataset.tracker_names 32 | if isinstance(eval_trackers, str): 33 | eval_trackers = [eval_trackers] 34 | 35 | result = {} 36 | for tracker_name in eval_trackers: 37 | accuracy, failures = self._calculate_accuracy_robustness(tracker_name) 38 | result[tracker_name] = {'overlaps': accuracy, 39 | 'failures': failures} 40 | return result 41 | 42 | def show_result(self, result, eao_result=None, show_video_level=False, helight_threshold=0.5): 43 | """pretty print result 44 | Args: 45 | result: returned dict from function eval 46 | """ 47 | tracker_name_len = max((max([len(x) for x in result.keys()])+2), 12) 48 | if eao_result is not None: 49 | header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|{:^7}|" 50 | header = header.format('Tracker Name', 51 | 'Accuracy', 'Robustness', 'Lost Number', 'EAO') 52 | formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|{:^7.3f}|" 53 | else: 54 | header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|" 55 | header = header.format('Tracker Name', 56 | 'Accuracy', 'Robustness', 'Lost Number') 57 | formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|" 58 | bar = '-'*len(header) 59 | print(bar) 60 | print(header) 61 | print(bar) 62 | if eao_result is not None: 63 | tracker_eao = sorted(eao_result.items(), 64 | key=lambda x:x[1]['all'], 65 | reverse=True)[:20] 66 | tracker_names = [x[0] for x in tracker_eao] 67 | else: 68 | tracker_names = list(result.keys()) 69 | for tracker_name in tracker_names: 70 | # for tracker_name, ret in result.items(): 71 | ret = result[tracker_name] 72 | overlaps = list(itertools.chain(*ret['overlaps'].values())) 73 | accuracy = np.nanmean(overlaps) 74 | length = sum([len(x) for x in ret['overlaps'].values()]) 75 | failures = list(ret['failures'].values()) 76 | lost_number = np.mean(np.sum(failures, axis=0)) 77 | robustness = np.mean(np.sum(np.array(failures), axis=0) / length) * 100 78 | if eao_result is None: 79 | print(formatter.format(tracker_name, accuracy, robustness, lost_number)) 80 | else: 81 | print(formatter.format(tracker_name, accuracy, robustness, lost_number, eao_result[tracker_name]['all'])) 82 | print(bar) 83 | 84 | if show_video_level and len(result) < 10: 85 | print('\n\n') 86 | header1 = "|{:^14}|".format("Tracker name") 87 | header2 = "|{:^14}|".format("Video name") 88 | for tracker_name in result.keys(): 89 | header1 += ("{:^17}|").format(tracker_name) 90 | header2 += "{:^8}|{:^8}|".format("Acc", "LN") 91 | print('-'*len(header1)) 92 | print(header1) 93 | 
print('-'*len(header1)) 94 | print(header2) 95 | print('-'*len(header1)) 96 | videos = list(result[tracker_name]['overlaps'].keys()) 97 | for video in videos: 98 | row = "|{:^14}|".format(video) 99 | for tracker_name in result.keys(): 100 | overlaps = result[tracker_name]['overlaps'][video] 101 | accuracy = np.nanmean(overlaps) 102 | failures = result[tracker_name]['failures'][video] 103 | lost_number = np.mean(failures) 104 | 105 | accuracy_str = "{:^8.3f}".format(accuracy) 106 | if accuracy < helight_threshold: 107 | row += f'{Fore.RED}{accuracy_str}{Style.RESET_ALL}|' 108 | else: 109 | row += accuracy_str+'|' 110 | lost_num_str = "{:^8.3f}".format(lost_number) 111 | if lost_number > 0: 112 | row += f'{Fore.RED}{lost_num_str}{Style.RESET_ALL}|' 113 | else: 114 | row += lost_num_str+'|' 115 | print(row) 116 | print('-'*len(header1)) 117 | 118 | def _calculate_accuracy_robustness(self, tracker_name): 119 | overlaps = {} 120 | failures = {} 121 | all_length = {} 122 | for i in range(len(self.dataset)): 123 | video = self.dataset[i] 124 | gt_traj = video.gt_traj 125 | if tracker_name not in video.pred_trajs: 126 | tracker_trajs = video.load_tracker(self.dataset.tracker_path, tracker_name, False) 127 | else: 128 | tracker_trajs = video.pred_trajs[tracker_name] 129 | overlaps_group = [] 130 | num_failures_group = [] 131 | for tracker_traj in tracker_trajs: 132 | num_failures = calculate_failures(tracker_traj)[0] 133 | overlaps_ = calculate_accuracy(tracker_traj, gt_traj, 134 | burnin=10, bound=(video.width, video.height))[1] 135 | overlaps_group.append(overlaps_) 136 | num_failures_group.append(num_failures) 137 | with warnings.catch_warnings(): 138 | warnings.simplefilter("ignore", category=RuntimeWarning) 139 | overlaps[video.name] = np.nanmean(overlaps_group, axis=0).tolist() 140 | failures[video.name] = num_failures_group 141 | return overlaps, failures 142 | -------------------------------------------------------------------------------- /toolkit/evaluation/f1_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from glob import glob 5 | from tqdm import tqdm 6 | from colorama import Style, Fore 7 | 8 | from ..utils import determine_thresholds, calculate_accuracy, calculate_f1 9 | 10 | class F1Benchmark: 11 | def __init__(self, dataset): 12 | """ 13 | Args: 14 | result_path: 15 | """ 16 | self.dataset = dataset 17 | 18 | def eval(self, eval_trackers=None): 19 | """ 20 | Args: 21 | eval_tags: list of tag 22 | eval_trackers: list of tracker name 23 | Returns: 24 | eao: dict of results 25 | """ 26 | if eval_trackers is None: 27 | eval_trackers = self.dataset.tracker_names 28 | if isinstance(eval_trackers, str): 29 | eval_trackers = [eval_trackers] 30 | 31 | ret = {} 32 | for tracker_name in eval_trackers: 33 | precision, recall, f1 = self._cal_precision_reall(tracker_name) 34 | ret[tracker_name] = {"precision": precision, 35 | "recall": recall, 36 | "f1": f1 37 | } 38 | return ret 39 | 40 | def _cal_precision_reall(self, tracker_name): 41 | score = [] 42 | # for i in range(len(self.dataset)): 43 | # video = self.dataset[i] 44 | for video in self.dataset: 45 | if tracker_name not in video.confidence: 46 | score += video.load_tracker(self.dataset.tracker_path, tracker_name, False)[1] 47 | else: 48 | score += video.confidence[tracker_name] 49 | score = np.array(score) 50 | thresholds = determine_thresholds(score)[::-1] 51 | 52 | precision = {} 53 | recall = {} 54 | f1 = {} 55 | for i in 
range(len(self.dataset)): 56 | video = self.dataset[i] 57 | gt_traj = video.gt_traj 58 | N = sum([1 for x in gt_traj if len(x) > 1]) 59 | if tracker_name not in video.pred_trajs: 60 | tracker_traj, score = video.load_tracker(self.dataset.tracker_path, tracker_name, False) 61 | else: 62 | tracker_traj = video.pred_trajs[tracker_name] 63 | score = video.confidence[tracker_name] 64 | overlaps = calculate_accuracy(tracker_traj, gt_traj, \ 65 | bound=(video.width,video.height))[1] 66 | f1[video.name], precision[video.name], recall[video.name] = \ 67 | calculate_f1(overlaps, score, (video.width,video.height),thresholds, N) 68 | return precision, recall, f1 69 | 70 | def show_result(self, result, show_video_level=False, helight_threshold=0.5): 71 | """pretty print result 72 | Args: 73 | result: returned dict from function eval 74 | """ 75 | # sort tracker according to f1 76 | sorted_tracker = {} 77 | for tracker_name, ret in result.items(): 78 | precision = np.mean(list(ret['precision'].values()), axis=0) 79 | recall = np.mean(list(ret['recall'].values()), axis=0) 80 | f1 = 2 * precision * recall / (precision + recall) 81 | max_idx = np.argmax(f1) 82 | sorted_tracker[tracker_name] = (precision[max_idx], recall[max_idx], 83 | f1[max_idx]) 84 | sorted_tracker_ = sorted(sorted_tracker.items(), 85 | key=lambda x:x[1][2], 86 | reverse=True)[:20] 87 | tracker_names = [x[0] for x in sorted_tracker_] 88 | 89 | tracker_name_len = max((max([len(x) for x in result.keys()])+2), 12) 90 | header = "|{:^"+str(tracker_name_len)+"}|{:^11}|{:^8}|{:^7}|" 91 | header = header.format('Tracker Name', 92 | 'Precision', 'Recall', 'F1') 93 | bar = '-' * len(header) 94 | formatter = "|{:^"+str(tracker_name_len)+"}|{:^11.3f}|{:^8.3f}|{:^7.3f}|" 95 | print(bar) 96 | print(header) 97 | print(bar) 98 | # for tracker_name, ret in result.items(): 99 | # precision = np.mean(list(ret['precision'].values()), axis=0) 100 | # recall = np.mean(list(ret['recall'].values()), axis=0) 101 | # f1 = 2 * precision * recall / (precision + recall) 102 | # max_idx = np.argmax(f1) 103 | for tracker_name in tracker_names: 104 | precision = sorted_tracker[tracker_name][0] 105 | recall = sorted_tracker[tracker_name][1] 106 | f1 = sorted_tracker[tracker_name][2] 107 | print(formatter.format(tracker_name, precision, recall, f1)) 108 | print(bar) 109 | 110 | if show_video_level and len(result) < 10: 111 | print('\n\n') 112 | header1 = "|{:^14}|".format("Tracker name") 113 | header2 = "|{:^14}|".format("Video name") 114 | for tracker_name in result.keys(): 115 | # col_len = max(20, len(tracker_name)) 116 | header1 += ("{:^28}|").format(tracker_name) 117 | header2 += "{:^11}|{:^8}|{:^7}|".format("Precision", "Recall", "F1") 118 | print('-'*len(header1)) 119 | print(header1) 120 | print('-'*len(header1)) 121 | print(header2) 122 | print('-'*len(header1)) 123 | videos = list(result[tracker_name]['precision'].keys()) 124 | for video in videos: 125 | row = "|{:^14}|".format(video) 126 | for tracker_name in result.keys(): 127 | precision = result[tracker_name]['precision'][video] 128 | recall = result[tracker_name]['recall'][video] 129 | f1 = result[tracker_name]['f1'][video] 130 | max_idx = np.argmax(f1) 131 | precision_str = "{:^11.3f}".format(precision[max_idx]) 132 | if precision[max_idx] < helight_threshold: 133 | row += f'{Fore.RED}{precision_str}{Style.RESET_ALL}|' 134 | else: 135 | row += precision_str+'|' 136 | recall_str = "{:^8.3f}".format(recall[max_idx]) 137 | if recall[max_idx] < helight_threshold: 138 | row += 
f'{Fore.RED}{recall_str}{Style.RESET_ALL}|' 139 | else: 140 | row += recall_str+'|' 141 | f1_str = "{:^7.3f}".format(f1[max_idx]) 142 | if f1[max_idx] < helight_threshold: 143 | row += f'{Fore.RED}{f1_str}{Style.RESET_ALL}|' 144 | else: 145 | row += f1_str+'|' 146 | print(row) 147 | print('-'*len(header1)) 148 | -------------------------------------------------------------------------------- /toolkit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import region 2 | from .statistics import * 3 | -------------------------------------------------------------------------------- /toolkit/utils/c_region.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "src/region.h": 2 | ctypedef enum region_type "RegionType": 3 | EMPTY 4 | SPECIAL 5 | RECTANGLE 6 | POLYGON 7 | MASK 8 | 9 | ctypedef struct region_bounds: 10 | float top 11 | float bottom 12 | float left 13 | float right 14 | 15 | ctypedef struct region_rectangle: 16 | float x 17 | float y 18 | float width 19 | float height 20 | 21 | # ctypedef struct region_mask: 22 | # int x 23 | # int y 24 | # int width 25 | # int height 26 | # char *data 27 | 28 | ctypedef struct region_polygon: 29 | int count 30 | float *x 31 | float *y 32 | 33 | ctypedef union region_container_data: 34 | region_rectangle rectangle 35 | region_polygon polygon 36 | # region_mask mask 37 | int special 38 | 39 | ctypedef struct region_container: 40 | region_type type 41 | region_container_data data 42 | 43 | # ctypedef struct region_overlap: 44 | # float overlap 45 | # float only1 46 | # float only2 47 | 48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds) 49 | 50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds) 51 | -------------------------------------------------------------------------------- /toolkit/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author fangyi.zhang@vipl.ict.ac.cn 3 | """ 4 | import numpy as np 5 | 6 | def determine_thresholds(confidence, resolution=100): 7 | """choose thresholds according to confidence scores 8 | 9 | Args: 10 | confidence: list or numpy array of confidence scores 11 | resolution: number of thresholds to choose 12 | 13 | Returns: 14 | thresholds: numpy array of thresholds 15 | """ 16 | if isinstance(confidence, list): 17 | confidence = np.array(confidence) 18 | confidence = confidence.flatten() 19 | confidence = confidence[~np.isnan(confidence)] 20 | confidence.sort() 21 | 22 | assert len(confidence) > resolution and resolution > 2 23 | 24 | thresholds = np.ones((resolution)) 25 | thresholds[0] = - np.inf 26 | thresholds[-1] = np.inf 27 | delta = np.floor(len(confidence) / (resolution - 2)) 28 | idxs = np.linspace(delta, len(confidence)-delta, resolution-2, dtype=np.int32) 29 | thresholds[1:-1] = confidence[idxs] 30 | return thresholds 31 | -------------------------------------------------------------------------------- /toolkit/utils/src/buffer.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __STRING_BUFFER_H 3 | #define __STRING_BUFFER_H 4 | 5 | // Enable MinGW secure API for _snprintf_s 6 | #define MINGW_HAS_SECURE_API 1 7 | 8 | #ifdef _MSC_VER 9 | #define __INLINE __inline 10 | #else 11 | #define __INLINE inline 12 | #endif 13 | 14 | #include <stdlib.h> 15 | #include <string.h> 16 | #include <stdarg.h> 17 | 18 |
typedef struct string_buffer { 19 | char* buffer; 20 | int position; 21 | int size; 22 | } string_buffer; 23 | 24 | typedef struct string_list { 25 | char** buffer; 26 | int position; 27 | int size; 28 | } string_list; 29 | 30 | #define BUFFER_INCREMENT_STEP 4096 31 | 32 | static __INLINE string_buffer* buffer_create(int L) { 33 | string_buffer* B = (string_buffer*) malloc(sizeof(string_buffer)); 34 | B->size = L; 35 | B->buffer = (char*) malloc(sizeof(char) * B->size); 36 | B->position = 0; 37 | return B; 38 | } 39 | 40 | static __INLINE void buffer_reset(string_buffer* B) { 41 | B->position = 0; 42 | } 43 | 44 | static __INLINE void buffer_destroy(string_buffer** B) { 45 | if (!(*B)) return; 46 | if ((*B)->buffer) { 47 | free((*B)->buffer); 48 | (*B)->buffer = NULL; 49 | } 50 | free((*B)); 51 | (*B) = NULL; 52 | } 53 | 54 | static __INLINE char* buffer_extract(const string_buffer* B) { 55 | char *S = (char*) malloc(sizeof(char) * (B->position + 1)); 56 | memcpy(S, B->buffer, B->position); 57 | S[B->position] = '\0'; 58 | return S; 59 | } 60 | 61 | static __INLINE int buffer_size(const string_buffer* B) { 62 | return B->position; 63 | } 64 | 65 | static __INLINE void buffer_push(string_buffer* B, char C) { 66 | int required = 1; 67 | if (required > B->size - B->position) { 68 | B->size = B->position + BUFFER_INCREMENT_STEP; 69 | B->buffer = (char*) realloc(B->buffer, sizeof(char) * B->size); 70 | } 71 | B->buffer[B->position] = C; 72 | B->position += required; 73 | } 74 | 75 | static __INLINE void buffer_append(string_buffer* B, const char *format, ...) { 76 | 77 | int required; 78 | va_list args; 79 | 80 | #if defined(__OS2__) || defined(__WINDOWS__) || defined(WIN32) || defined(_MSC_VER) 81 | 82 | va_start(args, format); 83 | required = _vscprintf(format, args) + 1; 84 | va_end(args); 85 | if (required >= B->size - B->position) { 86 | B->size = B->position + required + 1; 87 | B->buffer = (char*) realloc(B->buffer, sizeof(char) * B->size); 88 | } 89 | va_start(args, format); 90 | required = _vsnprintf_s(&(B->buffer[B->position]), B->size - B->position, _TRUNCATE, format, args); 91 | va_end(args); 92 | B->position += required; 93 | 94 | #else 95 | va_start(args, format); 96 | required = vsnprintf(&(B->buffer[B->position]), B->size - B->position, format, args); 97 | va_end(args); 98 | if (required >= B->size - B->position) { 99 | B->size = B->position + required + 1; 100 | B->buffer = (char*) realloc(B->buffer, sizeof(char) * B->size); 101 | va_start(args, format); 102 | required = vsnprintf(&(B->buffer[B->position]), B->size - B->position, format, args); 103 | va_end(args); 104 | } 105 | B->position += required; 106 | #endif 107 | 108 | } 109 | 110 | static __INLINE string_list* list_create(int L) { 111 | string_list* B = (string_list*) malloc(sizeof(string_list)); 112 | B->size = L; 113 | B->buffer = (char**) malloc(sizeof(char*) * B->size); 114 | memset(B->buffer, 0, sizeof(char*) * B->size); 115 | B->position = 0; 116 | return B; 117 | } 118 | 119 | static __INLINE void list_reset(string_list* B) { 120 | int i; 121 | for (i = 0; i < B->position; i++) { 122 | if (B->buffer[i]) free(B->buffer[i]); 123 | B->buffer[i] = NULL; 124 | } 125 | B->position = 0; 126 | } 127 | 128 | static __INLINE void list_destroy(string_list **B) { 129 | int i; 130 | 131 | if (!(*B)) return; 132 | 133 | for (i = 0; i < (*B)->position; i++) { 134 | if ((*B)->buffer[i]) free((*B)->buffer[i]); (*B)->buffer[i] = NULL; 135 | } 136 | 137 | if ((*B)->buffer) { 138 | free((*B)->buffer); (*B)->buffer = NULL; 139 | } 
140 | 141 | free((*B)); 142 | (*B) = NULL; 143 | } 144 | 145 | static __INLINE char* list_get(const string_list *B, int I) { 146 | if (I < 0 || I >= B->position) { 147 | return NULL; 148 | } else { 149 | if (!B->buffer[I]) { 150 | return NULL; 151 | } else { 152 | char *S; 153 | int length = strlen(B->buffer[I]); 154 | S = (char*) malloc(sizeof(char) * (length + 1)); 155 | memcpy(S, B->buffer[I], length + 1); 156 | return S; 157 | } 158 | } 159 | } 160 | 161 | static __INLINE int list_size(const string_list *B) { 162 | return B->position; 163 | } 164 | 165 | static __INLINE void list_append(string_list *B, char* S) { 166 | int required = 1; 167 | int length = strlen(S); 168 | if (required > B->size - B->position) { 169 | B->size = B->position + 16; 170 | B->buffer = (char**) realloc(B->buffer, sizeof(char*) * B->size); 171 | } 172 | B->buffer[B->position] = (char*) malloc(sizeof(char) * (length + 1)); 173 | memcpy(B->buffer[B->position], S, length + 1); 174 | B->position += required; 175 | } 176 | 177 | // This version of the append does not copy the string but simply takes the control of its allocation 178 | static __INLINE void list_append_direct(string_list *B, char* S) { 179 | int required = 1; 180 | // int length = strlen(S); 181 | if (required > B->size - B->position) { 182 | B->size = B->position + 16; 183 | B->buffer = (char**) realloc(B->buffer, sizeof(char*) * B->size); 184 | } 185 | B->buffer[B->position] = S; 186 | B->position += required; 187 | } 188 | 189 | 190 | #endif 191 | -------------------------------------------------------------------------------- /toolkit/utils/src/region.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */ 2 | 3 | #ifndef _REGION_H_ 4 | #define _REGION_H_ 5 | 6 | #ifdef TRAX_STATIC_DEFINE 7 | # define __TRAX_EXPORT 8 | #else 9 | # ifndef __TRAX_EXPORT 10 | # if defined(_MSC_VER) 11 | # ifdef trax_EXPORTS 12 | /* We are building this library */ 13 | # define __TRAX_EXPORT __declspec(dllexport) 14 | # else 15 | /* We are using this library */ 16 | # define __TRAX_EXPORT __declspec(dllimport) 17 | # endif 18 | # elif defined(__GNUC__) 19 | # ifdef trax_EXPORTS 20 | /* We are building this library */ 21 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 22 | # else 23 | /* We are using this library */ 24 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 25 | # endif 26 | # endif 27 | # endif 28 | #endif 29 | 30 | #ifndef MAX 31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b)) 32 | #endif 33 | 34 | #ifndef MIN 35 | #define MIN(a,b) (((a) < (b)) ? 
(a) : (b)) 36 | #endif 37 | 38 | #define TRAX_DEFAULT_CODE 0 39 | 40 | #define REGION_LEGACY_RASTERIZATION 1 41 | 42 | #ifdef __cplusplus 43 | extern "C" { 44 | #endif 45 | 46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type; 47 | 48 | typedef struct region_bounds { 49 | 50 | float top; 51 | float bottom; 52 | float left; 53 | float right; 54 | 55 | } region_bounds; 56 | 57 | typedef struct region_polygon { 58 | 59 | int count; 60 | 61 | float* x; 62 | float* y; 63 | 64 | } region_polygon; 65 | 66 | typedef struct region_mask { 67 | 68 | int x; 69 | int y; 70 | 71 | int width; 72 | int height; 73 | 74 | char* data; 75 | 76 | } region_mask; 77 | 78 | typedef struct region_rectangle { 79 | 80 | float x; 81 | float y; 82 | float width; 83 | float height; 84 | 85 | } region_rectangle; 86 | 87 | typedef struct region_container { 88 | enum region_type type; 89 | union { 90 | region_rectangle rectangle; 91 | region_polygon polygon; 92 | region_mask mask; 93 | int special; 94 | } data; 95 | } region_container; 96 | 97 | typedef struct region_overlap { 98 | 99 | float overlap; 100 | float only1; 101 | float only2; 102 | 103 | } region_overlap; 104 | 105 | extern const region_bounds region_no_bounds; 106 | 107 | __TRAX_EXPORT int region_set_flags(int mask); 108 | 109 | __TRAX_EXPORT int region_clear_flags(int mask); 110 | 111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds); 112 | 113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds); 114 | 115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom); 116 | 117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region); 118 | 119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region); 120 | 121 | __TRAX_EXPORT char* region_string(region_container* region); 122 | 123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region); 124 | 125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type); 126 | 127 | __TRAX_EXPORT void region_release(region_container** region); 128 | 129 | __TRAX_EXPORT region_container* region_create_special(int code); 130 | 131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height); 132 | 133 | __TRAX_EXPORT region_container* region_create_polygon(int count); 134 | 135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y); 136 | 137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height); 138 | 139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height); 140 | 141 | #ifdef __cplusplus 142 | } 143 | #endif 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /toolkit/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .draw_f1 import draw_f1 2 | from .draw_success_precision import draw_success_precision 3 | from .draw_eao import draw_eao 4 | -------------------------------------------------------------------------------- /toolkit/visualization/draw_eao.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as 
np 3 | import pickle 4 | from functools import cmp_to_key 5 | from matplotlib import rc 6 | from .draw_utils import COLOR, MARKER_STYLE 7 | 8 | rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) 9 | rc('text', usetex=True) 10 | 11 | def draw_eao(result): 12 | fig = plt.figure() 13 | ax = fig.add_subplot(111, projection='polar') 14 | angles = np.linspace(0, 2*np.pi, 8, endpoint=True) 15 | 16 | attr2value = [] 17 | for i, (tracker_name, ret) in enumerate(result.items()): 18 | value = list(ret.values()) 19 | attr2value.append(value) 20 | value.append(value[0]) 21 | attr2value = np.array(attr2value) 22 | max_value = np.max(attr2value, axis=0) 23 | min_value = np.min(attr2value, axis=0) 24 | result = {key:value for key,value in sorted(result.items(), key=lambda x:x[1]['all'], reverse=True)} 25 | for i, (tracker_name, ret) in enumerate(result.items()): 26 | value = list(ret.values()) 27 | value.append(value[0]) 28 | value = np.array(value) 29 | value *= (1 / max_value) 30 | plt.plot(angles, value, linestyle='-', color=COLOR[i], marker=MARKER_STYLE[i], 31 | label=tracker_name, linewidth=1.5, markersize=6) 32 | 33 | attrs = ["Overall", "Camera motion", "Illumination change","Motion Change", 34 | "Size change", "Occlusion", "Unassigned"] 35 | attr_value = [] 36 | for attr, maxv, minv in zip(attrs, max_value, min_value): 37 | attr_value.append(attr + "\n({:.3f},{:.3f})".format(minv, maxv)) 38 | ax.set_thetagrids(angles[:-1] * 180/np.pi, attr_value) 39 | ax.spines['polar'].set_visible(False) 40 | ax.legend(loc='upper center', bbox_to_anchor=(0.5,-0.07), frameon=False, ncol=5) 41 | # ax.grid(b=False) 42 | ax.set_ylim(0, 1.18) 43 | ax.set_yticks([]) 44 | plt.show() 45 | 46 | if __name__ == '__main__': 47 | result = pickle.load(open("../../result.pkl", 'rb')) 48 | draw_eao(result) 49 | -------------------------------------------------------------------------------- /toolkit/visualization/draw_f1.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from matplotlib import rc 5 | from .draw_utils import COLOR, LINE_STYLE 6 | 7 | rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) 8 | rc('text', usetex=True) 9 | 10 | def draw_f1(result, bold_name=None): 11 | # drawing f1 contour 12 | fig, ax = plt.subplots() 13 | for f1 in np.arange(0.1, 1, 0.1): 14 | recall = np.arange(f1, 1+0.01, 0.01) 15 | precision = f1 * recall / (2 * recall - f1) 16 | ax.plot(recall, precision, color=[0,1,0], linestyle='-', linewidth=0.5) 17 | ax.plot(precision, recall, color=[0,1,0], linestyle='-', linewidth=0.5) 18 | ax.grid(b=True) 19 | ax.set_aspect(1) 20 | plt.xlabel('Recall') 21 | plt.ylabel('Precision') 22 | plt.axis([0, 1, 0, 1]) 23 | plt.title(r'\textbf{VOT2018-LT Precision vs Recall}') 24 | 25 | # draw result line 26 | all_precision = {} 27 | all_recall = {} 28 | best_f1 = {} 29 | best_idx = {} 30 | for tracker_name, ret in result.items(): 31 | precision = np.mean(list(ret['precision'].values()), axis=0) 32 | recall = np.mean(list(ret['recall'].values()), axis=0) 33 | f1 = 2 * precision * recall / (precision + recall) 34 | max_idx = np.argmax(f1) 35 | all_precision[tracker_name] = precision 36 | all_recall[tracker_name] = recall 37 | best_f1[tracker_name] = f1[max_idx] 38 | best_idx[tracker_name] = max_idx 39 | 40 | for idx, (tracker_name, best_f1) in \ 41 | enumerate(sorted(best_f1.items(), key=lambda x:x[1], reverse=True)): 42 | if tracker_name == bold_name: 43 | label = r"\textbf{[%.3f] Ours}" % 
(best_f1) 44 | else: 45 | label = "[%.3f] " % (best_f1) + tracker_name 46 | recall = all_recall[tracker_name][:-1] 47 | precision = all_precision[tracker_name][:-1] 48 | ax.plot(recall, precision, color=COLOR[idx], linestyle='-', 49 | label=label) 50 | f1_idx = best_idx[tracker_name] 51 | ax.plot(recall[f1_idx], precision[f1_idx], color=[0,0,0], marker='o', 52 | markerfacecolor=COLOR[idx], markersize=5) 53 | ax.legend(loc='lower right', labelspacing=0.2) 54 | plt.xticks(np.arange(0, 1+0.1, 0.1)) 55 | plt.yticks(np.arange(0, 1+0.1, 0.1)) 56 | plt.show() 57 | 58 | if __name__ == '__main__': 59 | draw_f1(None) 60 | -------------------------------------------------------------------------------- /toolkit/visualization/draw_success_precision.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from matplotlib import rc 5 | from .draw_utils import COLOR, LINE_STYLE 6 | 7 | # rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) 8 | # rc('text', usetex=True) 9 | 10 | def draw_success_precision(success_ret, name, videos, attr, precision_ret=None, 11 | norm_precision_ret=None, bold_name=None, axis=[0, 1]): 12 | # success plot 13 | fig, ax = plt.subplots() 14 | ax.grid(b=True) 15 | ax.set_aspect(1) 16 | plt.xlabel('Overlap threshold') 17 | plt.ylabel('Success rate') 18 | if attr == 'ALL': 19 | plt.title(r'\textbf{Success plots of OPE on %s}' % (name)) 20 | else: 21 | plt.title(r'\textbf{Success plots of OPE - %s}' % (attr)) 22 | plt.axis([0, 1]+axis) 23 | success = {} 24 | thresholds = np.arange(0, 1.05, 0.05) 25 | for tracker_name in success_ret.keys(): 26 | value = [v for k, v in success_ret[tracker_name].items() if k in videos] 27 | success[tracker_name] = np.mean(value) 28 | for idx, (tracker_name, auc) in \ 29 | enumerate(sorted(success.items(), key=lambda x:x[1], reverse=True)): 30 | if tracker_name == bold_name: 31 | label = r"\textbf{[%.3f] %s}" % (auc, tracker_name) 32 | else: 33 | label = "[%.3f] " % (auc) + tracker_name 34 | value = [v for k, v in success_ret[tracker_name].items() if k in videos] 35 | plt.plot(thresholds, np.mean(value, axis=0), 36 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 37 | ax.legend(loc='lower left', labelspacing=0.2) 38 | ax.autoscale(enable=True, axis='both', tight=True) 39 | xmin, xmax, ymin, ymax = plt.axis() 40 | ax.autoscale(enable=False) 41 | ymax += 0.03 42 | plt.axis([xmin, xmax, ymin, ymax]) 43 | plt.xticks(np.arange(xmin, xmax+0.01, 0.1)) 44 | plt.yticks(np.arange(ymin, ymax, 0.1)) 45 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 46 | plt.show() 47 | 48 | if precision_ret: 49 | # precision plot 50 | fig, ax = plt.subplots() 51 | ax.grid(b=True) 52 | ax.set_aspect(50) 53 | plt.xlabel('Location error threshold') 54 | plt.ylabel('Precision') 55 | if attr == 'ALL': 56 | plt.title(r'\textbf{Precision plots of OPE on %s}' % (name)) 57 | else: 58 | plt.title(r'\textbf{Precision plots of OPE - %s}' % (attr)) 59 | plt.axis([0, 50]+axis) 60 | precision = {} 61 | thresholds = np.arange(0, 51, 1) 62 | for tracker_name in precision_ret.keys(): 63 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos] 64 | precision[tracker_name] = np.mean(value, axis=0)[20] 65 | for idx, (tracker_name, pre) in \ 66 | enumerate(sorted(precision.items(), key=lambda x:x[1], reverse=True)): 67 | if tracker_name == bold_name: 68 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name) 69 | else: 70 | label = "[%.3f] " % (pre) +
tracker_name 71 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos] 72 | plt.plot(thresholds, np.mean(value, axis=0), 73 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 74 | ax.legend(loc='lower right', labelspacing=0.2) 75 | ax.autoscale(enable=True, axis='both', tight=True) 76 | xmin, xmax, ymin, ymax = plt.axis() 77 | ax.autoscale(enable=False) 78 | ymax += 0.03 79 | plt.axis([xmin, xmax, ymin, ymax]) 80 | plt.xticks(np.arange(xmin, xmax+0.01, 5)) 81 | plt.yticks(np.arange(ymin, ymax, 0.1)) 82 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 83 | plt.show() 84 | 85 | # norm precision plot 86 | if norm_precision_ret: 87 | fig, ax = plt.subplots() 88 | ax.grid(b=True) 89 | plt.xlabel('Location error threshold') 90 | plt.ylabel('Precision') 91 | if attr == 'ALL': 92 | plt.title(r'\textbf{Normalized Precision plots of OPE on %s}' % (name)) 93 | else: 94 | plt.title(r'\textbf{Normalized Precision plots of OPE - %s}' % (attr)) 95 | norm_precision = {} 96 | thresholds = np.arange(0, 51, 1) / 100 97 | for tracker_name in norm_precision_ret.keys(): 98 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos] 99 | norm_precision[tracker_name] = np.mean(value, axis=0)[20] 100 | for idx, (tracker_name, pre) in \ 101 | enumerate(sorted(norm_precision.items(), key=lambda x:x[1], reverse=True)): 102 | if tracker_name == bold_name: 103 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name) 104 | else: 105 | label = "[%.3f] " % (pre) + tracker_name 106 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos] 107 | plt.plot(thresholds, np.mean(value, axis=0), 108 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 109 | ax.legend(loc='lower right', labelspacing=0.2) 110 | ax.autoscale(enable=True, axis='both', tight=True) 111 | xmin, xmax, ymin, ymax = plt.axis() 112 | ax.autoscale(enable=False) 113 | ymax += 0.03 114 | plt.axis([xmin, xmax, ymin, ymax]) 115 | plt.xticks(np.arange(xmin, xmax+0.01, 0.05)) 116 | plt.yticks(np.arange(ymin, ymax, 0.1)) 117 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 118 | plt.show() 119 | -------------------------------------------------------------------------------- /toolkit/visualization/draw_utils.py: -------------------------------------------------------------------------------- 1 | 2 | COLOR = ((1, 0, 0), 3 | (0, 1, 0), 4 | (1, 0, 1), 5 | (1, 1, 0), 6 | (0, 162/255, 232/255), 7 | (0.5, 0.5, 0.5), 8 | (0, 0, 1), 9 | (0, 1, 1), 10 | (136/255, 0 , 21/255), 11 | (255/255, 127/255, 39/255), 12 | (0, 0, 0)) 13 | 14 | LINE_STYLE = ['-', '--', ':', '-', '--', ':', '-', '--', ':', '-', '--'] 15 | 16 | MARKER_STYLE = ['o', 'v', '<', '*', 'D', 'x', '.', 'x', '<', '.', 'o'] 17 | -------------------------------------------------------------------------------- /tools/demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | import argparse 8 | 9 | import cv2 10 | import torch 11 | import numpy as np 12 | from glob import glob 13 | 14 | from pysot.core.config import cfg 15 | from pysot.models.model_builder import ModelBuilder 16 | from pysot.tracker.tracker_builder import build_tracker 17 | 18 | torch.set_num_threads(1) 19 | 20 | parser = argparse.ArgumentParser(description='tracking demo') 21 | parser.add_argument('--config', type=str, help='config file') 22 |
parser.add_argument('--snapshot', type=str, help='model name') 23 | parser.add_argument('--video_name', default='', type=str, 24 | help='videos or image files') 25 | args = parser.parse_args() 26 | 27 | 28 | def get_frames(video_name): 29 | if not video_name: 30 | cap = cv2.VideoCapture(0) 31 | # warmup 32 | for i in range(5): 33 | cap.read() 34 | while True: 35 | ret, frame = cap.read() 36 | if ret: 37 | yield frame 38 | else: 39 | break 40 | elif video_name.endswith('avi') or \ 41 | video_name.endswith('mp4'): 42 | cap = cv2.VideoCapture(args.video_name) 43 | while True: 44 | ret, frame = cap.read() 45 | if ret: 46 | yield frame 47 | else: 48 | break 49 | else: 50 | images = glob(os.path.join(video_name, '*.jp*')) 51 | images = sorted(images, 52 | key=lambda x: int(x.split('/')[-1].split('.')[0])) 53 | for img in images: 54 | frame = cv2.imread(img) 55 | yield frame 56 | 57 | 58 | def main(): 59 | # load config 60 | cfg.merge_from_file(args.config) 61 | cfg.CUDA = torch.cuda.is_available() 62 | device = torch.device('cuda' if cfg.CUDA else 'cpu') 63 | 64 | # create model 65 | model = ModelBuilder() 66 | 67 | # load model 68 | model.load_state_dict(torch.load(args.snapshot, 69 | map_location=lambda storage, loc: storage.cpu())) 70 | model.eval().to(device) 71 | 72 | # build tracker 73 | tracker = build_tracker(model) 74 | 75 | first_frame = True 76 | if args.video_name: 77 | video_name = args.video_name.split('/')[-1].split('.')[0] 78 | else: 79 | video_name = 'webcam' 80 | cv2.namedWindow(video_name, cv2.WND_PROP_FULLSCREEN) 81 | for frame in get_frames(args.video_name): 82 | if first_frame: 83 | try: 84 | init_rect = cv2.selectROI(video_name, frame, False, False) 85 | except: 86 | exit() 87 | tracker.init(frame, init_rect) 88 | first_frame = False 89 | else: 90 | outputs = tracker.track(frame) 91 | if 'polygon' in outputs: 92 | polygon = np.array(outputs['polygon']).astype(np.int32) 93 | cv2.polylines(frame, [polygon.reshape((-1, 1, 2))], 94 | True, (0, 255, 0), 3) 95 | mask = ((outputs['mask'] > cfg.TRACK.MASK_THERSHOLD) * 255) 96 | mask = mask.astype(np.uint8) 97 | mask = np.stack([mask, mask*255, mask]).transpose(1, 2, 0) 98 | frame = cv2.addWeighted(frame, 0.77, mask, 0.23, -1) 99 | else: 100 | bbox = list(map(int, outputs['bbox'])) 101 | cv2.rectangle(frame, (bbox[0], bbox[1]), 102 | (bbox[0]+bbox[2], bbox[1]+bbox[3]), 103 | (0, 255, 0), 3) 104 | cv2.imshow(video_name, frame) 105 | cv2.waitKey(40) 106 | 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /vot_iter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/vot_iter/__init__.py -------------------------------------------------------------------------------- /vot_iter/tracker_SiamRPNpp.m: -------------------------------------------------------------------------------- 1 | 2 | % error('Tracker not configured! Please edit the tracker_test.m file.'); % Remove this line after proper configuration 3 | 4 | % The human readable label for the tracker, used to identify the tracker in reports 5 | % If not set, it will be set to the same value as the identifier. 6 | % It does not have to be unique, but it is best that it is. 
7 | tracker_label = ['SiamRPNpp']; 8 | 9 | % For Python implementations we have created a handy function that generates the appropriate 10 | % command that will run the python executable and execute the given script that includes your 11 | % tracker implementation. 12 | % 13 | % Please customize the line below by substituting the first argument with the name of the 14 | % script of your tracker (not the .py file but just the name of the script) and also provide the 15 | % path (or multiple paths) where the tracker sources are found as the elements of the cell 16 | % array (second argument). 17 | setenv('MKL_NUM_THREADS','1'); 18 | pysot_root = 'path/to/pysot'; 19 | track_build_path = 'path/to/track/build'; 20 | tracker_command = generate_python_command('vot_iter.vot_iter', {pysot_root; [track_build_path '/python/lib']}) 21 | 22 | tracker_interpreter = 'python'; 23 | 24 | tracker_linkpath = {track_build_path}; 25 | 26 | % tracker_linkpath = {}; % A cell array of custom library directories used by the tracker executable (optional) 27 | 28 | -------------------------------------------------------------------------------- /vot_iter/vot.py: -------------------------------------------------------------------------------- 1 | """ 2 | \file vot.py 3 | 4 | @brief Python utility functions for VOT integration 5 | 6 | @author Luka Cehovin, Alessio Dore 7 | 8 | @date 2016 9 | 10 | """ 11 | 12 | import sys 13 | import copy 14 | import collections 15 | 16 | try: 17 | import trax 18 | except ImportError: 19 | raise Exception('TraX support not found. Please add trax module to Python path.') 20 | 21 | Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height']) 22 | Point = collections.namedtuple('Point', ['x', 'y']) 23 | Polygon = collections.namedtuple('Polygon', ['points']) 24 | 25 | class VOT(object): 26 | """ Base class for Python VOT integration """ 27 | def __init__(self, region_format, channels=None): 28 | """ Constructor 29 | 30 | Args: 31 | region_format: Region format options 32 | """ 33 | assert(region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON]) 34 | 35 | if channels is None: 36 | channels = ['color'] 37 | elif channels == 'rgbd': 38 | channels = ['color', 'depth'] 39 | elif channels == 'rgbt': 40 | channels = ['color', 'ir'] 41 | elif channels == 'ir': 42 | channels = ['ir'] 43 | else: 44 | raise Exception('Illegal configuration {}.'.format(channels)) 45 | 46 | self._trax = trax.Server([region_format], [trax.Image.PATH], channels) 47 | 48 | request = self._trax.wait() 49 | assert(request.type == 'initialize') 50 | if isinstance(request.region, trax.Polygon): 51 | self._region = Polygon([Point(x[0], x[1]) for x in request.region]) 52 | else: 53 | self._region = Rectangle(*request.region.bounds()) 54 | self._image = [x.path() for k, x in request.image.items()] 55 | if len(self._image) == 1: 56 | self._image = self._image[0] 57 | 58 | self._trax.status(request.region) 59 | 60 | def region(self): 61 | """ 62 | Send configuration message to the client and receive the initialization 63 | region and the path of the first image 64 | 65 | Returns: 66 | initialization region 67 | """ 68 | 69 | return self._region 70 | 71 | def report(self, region, confidence = None): 72 | """ 73 | Report the tracking results to the client 74 | 75 | Arguments: 76 | region: region for the frame 77 | """ 78 | assert(isinstance(region, Rectangle) or isinstance(region, Polygon)) 79 | if isinstance(region, Polygon): 80 | tregion = trax.Polygon.create([(x.x, x.y) for x in region.points]) 81 |
else: 82 | tregion = trax.Rectangle.create(region.x, region.y, region.width, region.height) 83 | properties = {} 84 | if not confidence is None: 85 | properties['confidence'] = confidence 86 | self._trax.status(tregion, properties) 87 | 88 | def frame(self): 89 | """ 90 | Get a frame (image path) from client 91 | 92 | Returns: 93 | absolute path of the image 94 | """ 95 | if hasattr(self, "_image"): 96 | image = self._image 97 | del self._image 98 | return image 99 | 100 | request = self._trax.wait() 101 | 102 | if request.type == 'frame': 103 | image = [x.path() for k, x in request.image.items()] 104 | if len(image) == 1: 105 | return image[0] 106 | return image 107 | else: 108 | return None 109 | 110 | 111 | def quit(self): 112 | if hasattr(self, '_trax'): 113 | self._trax.quit() 114 | 115 | def __del__(self): 116 | self.quit() 117 | 118 | -------------------------------------------------------------------------------- /vot_iter/vot_iter.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import os 6 | from os.path import join 7 | 8 | from pysot.core.config import cfg 9 | from pysot.models.model_builder import ModelBuilder 10 | from pysot.tracker.tracker_builder import build_tracker 11 | from pysot.utils.bbox import get_axis_aligned_bbox 12 | from pysot.utils.model_load import load_pretrain 13 | from toolkit.datasets import DatasetFactory 14 | from toolkit.utils.region import vot_overlap, vot_float2str 15 | 16 | from . import vot 17 | from .vot import Rectangle, Polygon, Point 18 | 19 | 20 | # modify root 21 | 22 | cfg_root = "path/to/expr" 23 | model_file = join(cfg_root, 'model.pth') 24 | cfg_file = join(cfg_root, 'config.yaml') 25 | 26 | def warmup(model): 27 | for i in range(10): 28 | model.template(torch.FloatTensor(1,3,127,127).cuda()) 29 | 30 | def setup_tracker(): 31 | cfg.merge_from_file(cfg_file) 32 | 33 | model = ModelBuilder() 34 | model = load_pretrain(model, model_file).cuda().eval() 35 | 36 | tracker = build_tracker(model) 37 | warmup(model) 38 | return tracker 39 | 40 | 41 | tracker = setup_tracker() 42 | 43 | handle = vot.VOT("polygon") 44 | region = handle.region() 45 | try: 46 | region = np.array([region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1], 47 | region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]]) 48 | except: 49 | region = np.array(region) 50 | 51 | cx, cy, w, h = get_axis_aligned_bbox(region) 52 | 53 | image_file = handle.frame() 54 | if not image_file: 55 | sys.exit(0) 56 | 57 | im = cv2.imread(image_file) # HxWxC 58 | # init 59 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) 60 | gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h] 61 | tracker.init(im, gt_bbox_) 62 | 63 | while True: 64 | img_file = handle.frame() 65 | if not img_file: 66 | break 67 | im = cv2.imread(img_file) 68 | outputs = tracker.track(im) 69 | pred_bbox = outputs['bbox'] 70 | result = Rectangle(*pred_bbox) 71 | score = outputs['best_score'] 72 | if cfg.MASK.MASK: 73 | pred_bbox = outputs['polygon'] 74 | result = Polygon(Point(x[0], x[1]) for x in pred_bbox) 75 | 76 | handle.report(result, score) 77 | --------------------------------------------------------------------------------
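For reference, the precision/recall machinery in `f1_benchmark.py` above works as a confidence-threshold sweep: `_cal_precision_recall` pools per-frame confidences, `determine_thresholds` (shown in `toolkit/utils/misc.py`) picks a fixed number of thresholds from their sorted values, and `calculate_f1` (imported from `toolkit.utils`, body not shown in this excerpt) yields per-threshold precision and recall from which the best F1 = 2PR/(P+R) is selected, as seen in `show_result` and `draw_f1`. The snippet below is a minimal, self-contained sketch of that sweep, not the toolkit's implementation; the helper name `pr_f1_sweep`, the toy numbers, and the simplified rule that a reported frame counts as correct when its overlap with the ground truth exceeds 0.5 are assumptions made for illustration only.

```python
import numpy as np

def pr_f1_sweep(overlaps, confidence, thresholds, overlap_thr=0.5):
    """Illustrative precision/recall/F1 sweep over confidence thresholds.

    overlaps:    per-frame IoU with the ground truth (NaN where the target is absent)
    confidence:  per-frame tracker confidence
    thresholds:  confidence thresholds to evaluate (e.g. from determine_thresholds)
    overlap_thr: assumed IoU above which a reported frame counts as correct
    """
    overlaps = np.asarray(overlaps, dtype=float)
    confidence = np.asarray(confidence, dtype=float)
    valid_gt = ~np.isnan(overlaps)                    # frames where the target is annotated
    safe_overlaps = np.where(valid_gt, overlaps, 0.0) # avoid comparisons against NaN
    precision, recall = [], []
    for t in thresholds:
        reported = confidence >= t                    # frames the tracker reports at this threshold
        correct = reported & (safe_overlaps > overlap_thr)
        precision.append(correct.sum() / max(reported.sum(), 1))
        recall.append(correct.sum() / max(valid_gt.sum(), 1))
    precision, recall = np.array(precision), np.array(recall)
    f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    best = int(np.argmax(f1))
    return precision[best], recall[best], f1[best]

# Toy data: six frames, the last two with low confidence and poor or missing overlap.
overlaps = [0.8, 0.7, 0.9, 0.6, 0.1, np.nan]
confidence = [0.9, 0.8, 0.95, 0.7, 0.3, 0.2]
print(pr_f1_sweep(overlaps, confidence, np.linspace(0.1, 0.9, 9)))
```

Running the toy example prints the best (precision, recall, F1) triple for the synthetic sequence; in the real toolkit the same roles are played by `determine_thresholds` and `calculate_f1` applied per video, with the per-video curves averaged before the best-F1 operating point is chosen.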