├── .gitignore
├── INSTALL.md
├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── TRAIN.md
├── experiments
│   ├── siamfc_alex_upxcorr_otb
│   │   ├── config.yaml
│   │   └── convert_model.py
│   ├── siamfc_alex_upxcorr_vot
│   │   ├── config.yaml
│   │   └── convert_model.py
│   ├── siammask_r50_l3
│   │   ├── config.yaml
│   │   └── convert_model.py
│   ├── siamrpn_alex_dwxcorr_otb
│   │   └── config.yaml
│   ├── siamrpn_alex_dwxcorr_vot
│   │   └── config.yaml
│   ├── siamrpn_r50_l234_dwxcorr
│   │   └── config.yaml
│   ├── siamrpn_r50_l234_dwxcorr_otb
│   │   └── config.yaml
│   ├── siamrpn_r50_l234_dwxcorr_vot
│   │   └── config.yaml
│   └── siamrpn_r50_l234_dwxcorr_votlt
│       └── config.yaml
├── install.sh
├── pysot
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── config.py
│   │   └── xcorr.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── anchor_target.py
│   │   ├── augmentation.py
│   │   └── dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── alexnet.py
│   │   │   ├── mobile_v2.py
│   │   │   └── resnet_atrous.py
│   │   ├── head
│   │   │   ├── __init__.py
│   │   │   ├── mask.py
│   │   │   └── rpn.py
│   │   ├── init_weight.py
│   │   ├── loss.py
│   │   ├── model_builder.py
│   │   └── neck
│   │       ├── __init__.py
│   │       └── neck.py
│   ├── tracker
│   │   ├── __init__.py
│   │   ├── base_tracker.py
│   │   ├── classifier
│   │   │   ├── __init__.py
│   │   │   ├── base_classifier.py
│   │   │   ├── libs
│   │   │   │   ├── __init__.py
│   │   │   │   ├── attention.py
│   │   │   │   ├── augmentation.py
│   │   │   │   ├── complex.py
│   │   │   │   ├── dcf.py
│   │   │   │   ├── fourier.py
│   │   │   │   ├── operation.py
│   │   │   │   ├── optimization.py
│   │   │   │   ├── params.py
│   │   │   │   ├── plotting.py
│   │   │   │   ├── preprocessing.py
│   │   │   │   ├── tensordict.py
│   │   │   │   └── tensorlist.py
│   │   │   └── optim.py
│   │   ├── siamfc_tracker.py
│   │   ├── siammask_tracker.py
│   │   ├── siamrpn_tracker.py
│   │   ├── siamrpnlt_tracker.py
│   │   └── tracker_builder.py
│   └── utils
│       ├── __init__.py
│       ├── anchor.py
│       ├── average_meter.py
│       ├── bbox.py
│       ├── distributed.py
│       ├── log_helper.py
│       ├── lr_scheduler.py
│       ├── misc.py
│       └── model_load.py
├── requirements.txt
├── setup.py
├── testing_dataset
│   └── README.md
├── toolkit
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── got10k.py
│   │   ├── lasot.py
│   │   ├── nfs.py
│   │   ├── otb.py
│   │   ├── trackingnet.py
│   │   ├── uav.py
│   │   ├── video.py
│   │   ├── visdrone.py
│   │   └── vot.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── ar_benchmark.py
│   │   ├── eao_benchmark.py
│   │   ├── f1_benchmark.py
│   │   └── ope_benchmark.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── c_region.pxd
│   │   ├── misc.py
│   │   ├── region.c
│   │   ├── region.pyx
│   │   ├── src
│   │   │   ├── buffer.h
│   │   │   ├── region.c
│   │   │   └── region.h
│   │   └── statistics.py
│   └── visualization
│       ├── __init__.py
│       ├── draw_eao.py
│       ├── draw_f1.py
│       ├── draw_success_precision.py
│       └── draw_utils.py
├── tools
│   ├── demo.py
│   ├── eval.py
│   ├── test.py
│   └── train.py
└── vot_iter
    ├── __init__.py
    ├── tracker_SiamRPNpp.m
    ├── vot.py
    └── vot_iter.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # custom files
10 | .idea/
11 |
12 | # dataset
13 | training_dataset/*
14 | testing_dataset/*
15 | !testing_dataset/README.md
16 | !training_dataset/README.md
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | This document contains detailed instructions for installing the dependencies of PySOT. We recommend using [install.sh](install.sh). The code is tested on Ubuntu 16.04 with an Nvidia GPU (we recommend a 1080 Ti or TITAN Xp).
4 |
5 | ### Requirements
6 | * Conda with Python 3.7.
7 | * Nvidia GPU.
8 | * PyTorch 0.4.1
9 | * yacs
10 | * pyyaml
11 | * matplotlib
12 | * tqdm
13 | * OpenCV
14 |
15 | ## Step-by-step instructions
16 |
17 | #### Create environment and activate
18 | ```bash
19 | conda create --name pysot python=3.7
20 | conda activate pysot
21 | ```
22 |
23 | #### Install numpy/pytorch/opencv
24 | ```
25 | conda install numpy
26 | conda install pytorch=0.4.1 torchvision cuda90 -c pytorch
27 | pip install opencv-python
28 | ```
29 |
30 | #### Install other requirements
31 | ```
32 | pip install pyyaml yacs tqdm colorama matplotlib cython tensorboardX
33 | ```
34 |
35 | #### Build extensions
36 | ```
37 | python setup.py build_ext --inplace
38 | ```
39 |
40 |
41 | ## Try with scripts
42 | ```
43 | bash install.sh /path/to/your/conda pysot
44 | ```
45 |
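46 | #### Verify the installation (optional)
47 | A minimal sanity check, assuming the `pysot` environment is active and you run it from the repository root: it confirms that PyTorch sees a GPU and that the Cython region extension was built.
48 | ```bash
49 | python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
50 | python -c "import toolkit.utils.region"  # fails if `python setup.py build_ext --inplace` was skipped
51 | ```
52 | 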
--------------------------------------------------------------------------------
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # PySOT Model Zoo
2 |
3 | ## Introduction
4 |
5 | This file documents a large collection of baselines trained with PySOT. All configurations for these baselines are located in the [`experiments`](experiments) directory. The tables below summarize inference results; links to the trained models as well as their raw outputs are provided.
6 |
7 | ## Visual Tracking Baselines
8 |
9 | ### Short-term Tracking
10 |
11 | | Model(arch+backbone+xcorr) | VOT16 (EAO/A/R) | VOT18 (EAO/A/R) | VOT19 (EAO/A/R) | OTB2015 (AUC/Prec.) | VOT18-LT(F1) | Speed (fps) | url |
12 | |:---------------------------------:|:-:|:------------------------:|:--------------------:|:----------------:|:--------------:|:------------:|:-----------:|
13 | | siamrpn_alex_dwxcorr | 0.393/0.618/0.238 | 0.352/0.576/0.290 | 0.260/0.573/0.547| - | - | 180 | [link](https://drive.google.com/open?id=1t62x56Jl7baUzPTo0QrC4jJnwvPZm-2m) |
14 | | siamrpn_alex_dwxcorr_otb | - | - | - |0.666/0.876 | - | 180 | [link](https://drive.google.com/open?id=1gCpmR85Qno3C-naR3SLqRNpVfU7VJ2W0) |
15 | | siamrpn_r50_l234_dwxcorr | 0.464/0.642/0.196 | 0.415/0.601/0.234 | 0.287/0.595/0.467 | - | - | 35 | [link](https://drive.google.com/open?id=1Q4-1563iPwV6wSf_lBHDj5CPFiGSlEPG) |
16 | | siamrpn_r50_l234_dwxcorr_otb | - | - | - |0.696/0.914 | - | 35 | [link](https://drive.google.com/open?id=1Cx_oHu6o0gNeH7F9zZrgevfAGdyWC4D5) |
17 | |siamrpn_mobilev2_l234_dwxcorr| 0.455/0.624/0.214 | 0.410/0.586/0.229 | 0.292/0.580/0.446| - | - | 75 | [link](https://drive.google.com/open?id=1JB94pZTvB1ZByU-qSJn4ZAIfjLWE5EBJ) |
18 | | siammask_r50_l3 | 0.455/0.634/0.219 | 0.423/0.615/0.248 | 0.283/0.597/0.461 | - | - | 56 | [link](https://drive.google.com/open?id=1YbPUQVTYw_slAvk_DchvRY-7B6rnSXP9) |
19 | | siamrpn_r50_l234_dwxcorr_lt | - | - | - | - | 0.629 | 20 | [link](https://drive.google.com/open?id=1lOOTedwGLbGZ7MAbqJimIcET3ANJd29A) |
20 |
21 | The models can also be downloaded from [Baidu Yun](https://pan.baidu.com/s/1GB9-aTtjG57SebraVoBfuQ) Extraction Code: j9yb
22 |
23 | Note:
24 |
25 | - speed tested on GTX-1080Ti
26 | - `alex` denotes [AlexNet](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks), `r50_lxyz` denotes the outputs of stage x, y, and z in [ResNet-50](https://arxiv.org/abs/1512.03385), and `mobilev2` denotes [MobileNetV2](https://arxiv.org/abs/1801.04381).
27 | - `dwxcorr` denotes Depth-wise Cross Correlation. See more in [SiamRPN++ Section 3.4](https://arxiv.org/abs/1812.11703).
28 | - The suffixes `otb` and `lt` are designed for the [OTB](http://cvlab.hanyang.ac.kr/tracker_benchmark/benchmark.html) benchmark and the [VOT long-term tracking challenge](http://www.votchallenge.net/vot2018/), respectively; the default (no suffix) is designed for the [VOT short-term tracking challenge](http://www.votchallenge.net/index.html).
29 | - All of the above models are trained on VID, YouTube-BB, COCO, and ImageNet DET, the same training data as DaSiamRPN.
30 | - The `SiamFC` model is converted from the authors' [Matlab](https://github.com/bertinetto/siamese-fc) implementation; the conversion code is available in the corresponding experiment folders.
31 |
32 | ## License
33 |
34 | All models available for download through this document are licensed under the [Creative Commons Attribution-ShareAlike 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/).
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DROL
2 | This is the repo for the paper "Discriminative and Robust Online Learning for Siamese Visual Tracking" [[paper](https://arxiv.org/abs/1909.02959)] [[results](https://drive.google.com/open?id=1iXtaxr1zkWvKf6AAMwN98ymaId9ixU4z)], presented as a poster at AAAI 2020.
3 |
4 | ## Introduction
5 |
6 | The proposed Discriminative and Robust Online Learning (DROL) module is designed to work with a variety of off-the-shelf siamese trackers. Our method is extensively evaluated on several mainstream benchmarks and yields a consistent performance gain over the given baseline. The baselines evaluated in the paper include, but are not limited to:
7 |
8 | - [SiamRPN++](https://arxiv.org/abs/1812.11703) (DROL-RPN)
9 | - [SiamMask](https://arxiv.org/abs/1812.05050) (DROL-MASK)
10 | - [SiamFC](https://arxiv.org/abs/1606.09549) (DROL-FC)
11 |
12 | ## Model Zoo
13 |
14 | The corresponding offline-trained models are available at the [PySOT Model Zoo](MODEL_ZOO.md).
15 |
16 |
17 | ## Get Started
18 |
19 | ### Installation
20 |
21 | - Please find installation instructions for PyTorch and PySOT in [`INSTALL.md`](INSTALL.md).
22 | - Add DROL to your PYTHONPATH
23 | ```bash
24 | export PYTHONPATH=/path/to/drol:$PYTHONPATH
25 | ```
26 |
27 | ### Download models
28 | Download the models from the [PySOT Model Zoo](MODEL_ZOO.md) and put each `model.pth` into the corresponding directory under `experiments`.
29 |
30 | ### Test tracker
31 | ```bash
32 | cd experiments/siamrpn_r50_l234_dwxcorr
33 | python -u ../../tools/test.py \
34 | --snapshot model.pth \ # model path
35 | --dataset VOT2018 \ # dataset name
36 | --config config.yaml # config file
37 | ```
38 |
39 | ### Eval tracker
40 | Assuming you are still in `experiments/siamrpn_r50_l234_dwxcorr`:
41 | ``` bash
42 | python ../../tools/eval.py \
43 | --tracker_path ./results \ # result path
44 | --dataset VOT2018 \ # dataset name
45 | --num 1 \ # number thread to eval
46 | --tracker_prefix 'model' # tracker_name
47 | ```
48 |
49 | ### Others
50 | - For `DROL-RPN`, we provide a separate config file (and thus a separate experiment folder) for each of `vot`/`votlt`/`otb`/`others`, where `vot` is used for the `VOT-20XX-baseline` benchmark, `votlt` for the `VOT-20XX-longterm` benchmark, `otb` for the `OTB2013/15` benchmark, and `others` is the default setting used for all remaining benchmarks, including but not limited to `LaSOT`/`TrackingNet`/`UAV123`.
51 | - For `DROL-FC`/`DROL-Mask`, only the `vot`/`otb` experiments are evaluated, as described in the paper. As in the `PySOT` repo, the `vot` config is used as the default setting.
52 |
53 | - Since this repo is built as a modification of [PySOT](https://github.com/STVIR/pysot), we recommend referring to PySOT for further technical issues.
54 |
55 |
56 | ## References
57 | - Jinghao Zhou, Peng Wang, Haoyang Sun, '[Discriminative and Robust Online Learning For Siamese Visual Tracking](http://arxiv.org/abs/1909.02959)', Proc. AAAI Conference on Artificial Intelligence (AAAI), 2020.
58 |
59 | ### Acknowledgement
60 | - [pysot](https://github.com/STVIR/pysot)
61 | - [pytracking](https://github.com/visionml/pytracking)
62 |
--------------------------------------------------------------------------------
/TRAIN.md:
--------------------------------------------------------------------------------
1 | # PySOT Training Tutorial
2 |
3 | This document describes how to train SiamRPN with different backbone architectures, such as ResNet and AlexNet.
4 | ### Add PySOT to your PYTHONPATH
5 | ```bash
6 | export PYTHONPATH=/path/to/pysot:$PYTHONPATH
7 | ```
8 |
9 | ## Prepare training dataset
10 | Prepare the training datasets; detailed preparation steps are listed in the [training_dataset](training_dataset) directory.
11 | * [VID](http://image-net.org/challenges/LSVRC/2017/)
12 | * [YOUTUBEBB](https://research.google.com/youtube-bb/)
13 | * [DET](http://image-net.org/challenges/LSVRC/2017/)
14 | * [COCO](http://cocodataset.org)
15 |
16 | ## Download pretrained backbones
17 | Download the pretrained backbones from [Google Drive](https://drive.google.com/drive/folders/1DuXVWVYIeynAcvt9uxtkuleV6bs6e3T9) and put them in the `pretrained_models` directory.
18 |
19 | ## Training
20 |
21 | To train a model (SiamRPN++), run `train.py` with the desired configs:
22 |
23 | ```bash
24 | cd experiments/siamrpn_r50_l234_dwxcorr_8gpu
25 | ```
26 |
27 | ### Multi-processing Distributed Data Parallel Training
28 |
29 | #### Single node, multiple GPUs:
30 | ```bash
31 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
32 | python -m torch.distributed.launch \
33 | --nproc_per_node=8 \
34 | --master_port=2333 \
35 | ../../tools/train.py --cfg config.yaml
36 | ```
37 |
38 | #### Multiple nodes:
39 | Node 1 (master node, IP: 192.168.1.1, with a free port 2333):
40 | ```bash
41 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
42 | python -m torch.distributed.launch \
43 | --nnodes=2 \
44 | --node_rank=0 \
45 | --nproc_per_node=8 \
46 | --master_addr=192.168.1.1 \ # adjust your ip here
47 | --master_port=2333 \
48 | ../../tools/train.py
49 | ```
50 | Node 2:
51 | ```bash
52 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
53 | python -m torch.distributed.launch \
54 | --nnodes=2 \
55 | --node_rank=1 \
56 | --nproc_per_node=8 \
57 | --master_addr=192.168.1.1 \
58 | --master_port=2333 \
59 | ../../tools/train.py
60 | ```
61 |
62 | ## Testing
63 | After training, you can test the snapshots on the VOT dataset.
64 | For `AlexNet`, you need to test snapshots from epoch 35 to 50.
65 | For `ResNet`, you need to test snapshots from epoch 10 to 20.
66 |
67 | ```bash
68 | START=10
69 | END=20
70 | seq $START 1 $END | \
71 | xargs -I {} echo "snapshot/checkpoint_e{}.pth" | \
72 | xargs -I {} \
73 | python -u ../../tools/test.py \
74 | --snapshot {} \
75 | --config config.yaml \
76 | --dataset VOT2018 2>&1 | tee logs/test_dataset.log
77 | ```
78 |
79 | ## Evaluation
80 | ```
81 | python ../../tools/eval.py \
82 | --tracker_path ./results \ # result path
83 | --dataset VOT2018 \ # dataset name
84 | --num 4 \ # number thread to eval
85 | --tracker_prefix 'ch*' # tracker_name
86 | ```
87 |
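88 | ## Notes
89 | 
90 | The snapshot-testing pipeline in the Testing section above can also be written as a plain loop, which may be easier to adapt. A minimal equivalent sketch (with `-a` added to `tee` so successive checkpoints append to one log; make sure the `logs` directory exists first):
91 | 
92 | ```bash
93 | mkdir -p logs
94 | for e in $(seq 10 20); do
95 |     python -u ../../tools/test.py \
96 |         --snapshot snapshot/checkpoint_e${e}.pth \
97 |         --config config.yaml \
98 |         --dataset VOT2018 2>&1 | tee -a logs/test_dataset.log
99 | done
100 | ```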
--------------------------------------------------------------------------------
/experiments/siamfc_alex_upxcorr_otb/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamfc_alex_upxcorr"
2 |
3 | BACKBONE:
4 | TYPE: "alexnetlegacy2"
5 |
6 | ADJUST:
7 | ADJUST: False
8 |
9 | RPN:
10 | RPN: False
11 |
12 | MASK:
13 | MASK: False
14 |
15 | TRACK:
16 | TYPE: 'SiamFCTracker'
17 | TOTAL_STRIDE: 8
18 | SCALE_NUM: 3
19 | SCALE_STEP: 1.0375
20 | PENALTY_K: 0.9745
21 | WINDOW_INFLUENCE: 0.15 # 0.176
22 | LR: 0.59
23 | EXEMPLAR_SIZE: 127
24 | INSTANCE_SIZE: 255
25 | BASE_SIZE: 0
26 | CONTEXT_AMOUNT: 0.5
27 | TEMPLATE_UPDATE: True
28 | TAU_REGRESSION: 0.6
29 | TAU_CLASSIFICATION: 0.1
30 |
31 | # classifier
32 | USE_CLASSIFIER: True
33 | SEED: 12345
34 | COEE_CLASS: 0.8
35 | USE_ATTENTION_LAYER: True
36 | CHANNEL_ATTENTION: True
37 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
38 |
--------------------------------------------------------------------------------
/experiments/siamfc_alex_upxcorr_otb/convert_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import re
3 | import numpy as np
4 | import argparse
5 |
6 | from scipy import io as sio
7 | from tqdm import tqdm
8 |
9 | # code adapted from https://github.com/bilylee/SiamFC-TensorFlow/blob/master/utils/train_utils.py
10 | def convert(mat_path):
11 |     """Get parameters from the .mat file into a params dict."""
12 |
13 | def squeeze(vars_):
14 | # Matlab save some params with shape (*, 1)
15 | # However, we don't need the trailing dimension in TensorFlow.
16 | if isinstance(vars_, (list, tuple)):
17 | return [np.squeeze(v, 1) for v in vars_]
18 | else:
19 | return np.squeeze(vars_, 1)
20 |
21 | netparams = sio.loadmat(mat_path)["net"]["params"][0][0]
22 | params = dict()
23 |
24 | name_map = {(1, 'conv'): 0, (1, 'bn'): 1,
25 | (2, 'conv'): 4, (2, 'bn'): 5,
26 | (3, 'conv'): 8, (3, 'bn'): 9,
27 | (4, 'conv'): 11, (4, 'bn'): 12,
28 | (5, 'conv'): 14}
29 | for i in tqdm(range(netparams.size)):
30 | param = netparams[0][i]
31 | name = param["name"][0]
32 | value = param["value"]
33 | value_size = param["value"].shape[0]
34 |
35 | match = re.match(r"([a-z]+)([0-9]+)([a-z]+)", name, re.I)
36 | if match:
37 | items = match.groups()
38 | elif name == 'adjust_f':
39 | continue
40 | elif name == 'adjust_b':
41 | params['backbone.corr_bias'] = torch.from_numpy(squeeze(value))
42 | continue
43 |
44 |
45 | op, layer, types = items
46 | layer = int(layer)
47 | if layer in [1, 2, 3, 4, 5]:
48 | idx = name_map[(layer, op)]
49 | if op == 'conv': # convolution
50 | if types == 'f':
51 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value.transpose(3, 2, 0, 1))
52 | elif types == 'b':# and layer == 5:
53 | value = squeeze(value)
54 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value)
55 | elif op == 'bn': # batch normalization
56 | if types == 'x':
57 | m, v = squeeze(np.split(value, 2, 1))
58 | params['backbone.features.{}.running_mean'.format(idx)] = torch.from_numpy(m)
59 | params['backbone.features.{}.running_var'.format(idx)] = torch.from_numpy(np.square(v))
60 | # params['features.{}.num_batches_tracked'.format(idx)] = torch.zeros(0)
61 | elif types == 'm':
62 | value = squeeze(value)
63 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value)
64 | elif types == 'b':
65 | value = squeeze(value)
66 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value)
67 | else:
68 | raise Exception
69 | return params
70 |
71 | if __name__ == '__main__':
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--mat_path', type=str, default="./2016-08-17_gray025.net.mat")
74 | args = parser.parse_args()
75 | params = convert(args.mat_path)
76 | torch.save(params, "./model.pth")
77 |
--------------------------------------------------------------------------------
/experiments/siamfc_alex_upxcorr_vot/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamfc_alex_upxcorr"
2 |
3 | BACKBONE:
4 | TYPE: "alexnetlegacy2"
5 |
6 | ADJUST:
7 | ADJUST: False
8 |
9 | RPN:
10 | RPN: False
11 |
12 | MASK:
13 | MASK: False
14 |
15 | TRACK:
16 | TYPE: 'SiamFCTracker'
17 | TOTAL_STRIDE: 8
18 | SCALE_NUM: 3
19 | SCALE_STEP: 1.0375
20 | PENALTY_K: 0.9745
21 | WINDOW_INFLUENCE: 0.15 #0.176
22 | LR: 0.59
23 | EXEMPLAR_SIZE: 127
24 | INSTANCE_SIZE: 255
25 | BASE_SIZE: 0
26 | CONTEXT_AMOUNT: 0.5
27 |
28 | # classifier & updater
29 | USE_CLASSIFIER: True
30 | TEMPLATE_UPDATE: False
31 | TAU_REGRESSION: 0.6
32 | TAU_CLASSIFICATION: 0.1
33 | SEED: 12345
34 | COEE_CLASS: 0.6
35 | USE_ATTENTION_LAYER: False
36 | CHANNEL_ATTENTION: True
37 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
38 |
--------------------------------------------------------------------------------
/experiments/siamfc_alex_upxcorr_vot/convert_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import re
3 | import numpy as np
4 | import argparse
5 |
6 | from scipy import io as sio
7 | from tqdm import tqdm
8 |
9 | # code adapted from https://github.com/bilylee/SiamFC-TensorFlow/blob/master/utils/train_utils.py
10 | def convert(mat_path):
11 |     """Get parameters from the .mat file into a params dict."""
12 |
13 | def squeeze(vars_):
14 | # Matlab save some params with shape (*, 1)
15 | # However, we don't need the trailing dimension in TensorFlow.
16 | if isinstance(vars_, (list, tuple)):
17 | return [np.squeeze(v, 1) for v in vars_]
18 | else:
19 | return np.squeeze(vars_, 1)
20 |
21 | netparams = sio.loadmat(mat_path)["net"]["params"][0][0]
22 | params = dict()
23 |
24 | name_map = {(1, 'conv'): 0, (1, 'bn'): 1,
25 | (2, 'conv'): 4, (2, 'bn'): 5,
26 | (3, 'conv'): 8, (3, 'bn'): 9,
27 | (4, 'conv'): 11, (4, 'bn'): 12,
28 | (5, 'conv'): 14}
29 | for i in tqdm(range(netparams.size)):
30 | param = netparams[0][i]
31 | name = param["name"][0]
32 | value = param["value"]
33 | value_size = param["value"].shape[0]
34 |
35 | match = re.match(r"([a-z]+)([0-9]+)([a-z]+)", name, re.I)
36 | if match:
37 | items = match.groups()
38 | elif name == 'adjust_f':
39 | continue
40 | elif name == 'adjust_b':
41 | params['backbone.corr_bias'] = torch.from_numpy(squeeze(value))
42 | continue
43 |
44 |
45 | op, layer, types = items
46 | layer = int(layer)
47 | if layer in [1, 2, 3, 4, 5]:
48 | idx = name_map[(layer, op)]
49 | if op == 'conv': # convolution
50 | if types == 'f':
51 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value.transpose(3, 2, 0, 1))
52 | elif types == 'b':# and layer == 5:
53 | value = squeeze(value)
54 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value)
55 | elif op == 'bn': # batch normalization
56 | if types == 'x':
57 | m, v = squeeze(np.split(value, 2, 1))
58 | params['backbone.features.{}.running_mean'.format(idx)] = torch.from_numpy(m)
59 | params['backbone.features.{}.running_var'.format(idx)] = torch.from_numpy(np.square(v))
60 | # params['features.{}.num_batches_tracked'.format(idx)] = torch.zeros(0)
61 | elif types == 'm':
62 | value = squeeze(value)
63 | params['backbone.features.{}.weight'.format(idx)] = torch.from_numpy(value)
64 | elif types == 'b':
65 | value = squeeze(value)
66 | params['backbone.features.{}.bias'.format(idx)] = torch.from_numpy(value)
67 | else:
68 | raise Exception
69 | return params
70 |
71 | if __name__ == '__main__':
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--mat_path', type=str, default="./2016-08-17.net.mat")
74 | args = parser.parse_args()
75 | params = convert(args.mat_path)
76 | torch.save(params, "./model.pth")
--------------------------------------------------------------------------------
/experiments/siammask_r50_l3/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siammask_r50_l3_dwxcorr"
2 |
3 | BACKBONE:
4 | TYPE: "resnet50"
5 | KWARGS:
6 | used_layers: [0, 1, 2, 3]
7 |
8 | ADJUST:
9 | ADJUST: true
10 | TYPE: "AdjustAllLayer"
11 | LAYER: 3
12 | KWARGS:
13 | in_channels: [1024]
14 | out_channels: [256]
15 |
16 | RPN:
17 | TYPE: 'DepthwiseRPN'
18 | KWARGS:
19 | anchor_num: 5
20 | in_channels: 256
21 | out_channels: 256
22 |
23 | MASK:
24 | MASK: True
25 | TYPE: 'MaskCorr'
26 | KWARGS:
27 | in_channels: 256
28 | hidden: 256
29 | out_channels: 3969
30 |
31 | REFINE:
32 | REFINE: True
33 | TYPE: 'Refine'
34 |
35 | ANCHOR:
36 | STRIDE: 8
37 | RATIOS: [0.33, 0.5, 1, 2, 3]
38 | SCALES: [8]
39 | ANCHOR_NUM: 5
40 |
41 | TRACK:
42 | TYPE: 'SiamMaskTracker'
43 | PENALTY_K: 0.10
44 | WINDOW_INFLUENCE: 0.37 # 0.41
45 | LR: 0.32
46 | EXEMPLAR_SIZE: 127
47 | INSTANCE_SIZE: 255
48 | BASE_SIZE: 8
49 | CONTEXT_AMOUNT: 0.5
50 | MASK_THERSHOLD: 0.15
51 |
52 | # classifier
53 | USE_CLASSIFIER: True
54 | TEMPLATE_UPDATE: False
55 | SEED: 12345
56 | COEE_CLASS: 0.8
57 | USE_ATTENTION_LAYER: True
58 | CHANNEL_ATTENTION: True
59 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
60 |
--------------------------------------------------------------------------------
/experiments/siammask_r50_l3/convert_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 |
4 | model = torch.load('SiamMask_VOT_LD.pth', map_location=lambda storage, loc: storage)
5 |
6 | new_model = OrderedDict()
7 |
8 | for k, v in model.items():
9 | if k.startswith('features.features'):
10 | k = k.replace('features.features', 'backbone')
11 | elif k.startswith('features'):
12 | k = k.replace('features', 'neck')
13 | elif k.startswith('rpn_model'):
14 | k = k.replace('rpn_model', 'rpn_head')
15 | elif k.startswith('mask_model'):
16 | k = k.replace('mask_model.mask', 'mask_head')
17 | elif k.startswith('refine_model'):
18 | k = k.replace('refine_model', 'refine_head')
19 | new_model[k] = v
20 |
21 | torch.save(new_model, 'model.pth')
22 |
--------------------------------------------------------------------------------
/experiments/siamrpn_alex_dwxcorr_otb/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamrpn_alex_dwxcorr_otb"
2 |
3 | BACKBONE:
4 | TYPE: "alexnetlegacy"
5 | KWARGS:
6 | width_mult: 1.0
7 |
8 | ADJUST:
9 | ADJUST: False
10 |
11 | RPN:
12 | TYPE: 'DepthwiseRPN'
13 | KWARGS:
14 | anchor_num: 5
15 | in_channels: 256
16 | out_channels: 256
17 |
18 | MASK:
19 | MASK: False
20 |
21 | ANCHOR:
22 | STRIDE: 8
23 | RATIOS: [0.33, 0.5, 1, 2, 3]
24 | SCALES: [8]
25 | ANCHOR_NUM: 5
26 |
27 | TRACK:
28 | TYPE: 'SiamRPNTracker'
29 | PENALTY_K: 0.16
30 | WINDOW_INFLUENCE: 0.3 # 0.40
31 | LR: 0.30
32 | EXEMPLAR_SIZE: 127
33 | INSTANCE_SIZE: 287
34 | BASE_SIZE: 0
35 | CONTEXT_AMOUNT: 0.5
36 | SHORT_TERM_DRIFT: True
37 |
38 | # classifier & updater
39 | USE_CLASSIFIER: True
40 | TEMPLATE_UPDATE: False
41 | SEED: 123456
42 | COEE_CLASS: 0.8
43 | USE_ATTENTION_LAYER: True
44 | CHANNEL_ATTENTION: True
45 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
46 |
--------------------------------------------------------------------------------
/experiments/siamrpn_alex_dwxcorr_vot/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamrpn_alex_dwxcorr_vot"
2 |
3 | BACKBONE:
4 | TYPE: "alexnetlegacy"
5 | KWARGS:
6 | width_mult: 1.0
7 |
8 | ADJUST:
9 | ADJUST: False
10 |
11 | RPN:
12 | TYPE: 'DepthwiseRPN'
13 | KWARGS:
14 | anchor_num: 5
15 | in_channels: 256
16 | out_channels: 256
17 |
18 | MASK:
19 | MASK: False
20 |
21 | ANCHOR:
22 | STRIDE: 8
23 | RATIOS: [0.33, 0.5, 1, 2, 3]
24 | SCALES: [8]
25 | ANCHOR_NUM: 5
26 |
27 | TRACK:
28 | TYPE: 'SiamRPNTracker'
29 | PENALTY_K: 0.16
30 | WINDOW_INFLUENCE: 0.3 #0.40
31 | LR: 0.30
32 | EXEMPLAR_SIZE: 127
33 | INSTANCE_SIZE: 287
34 | BASE_SIZE: 0
35 | CONTEXT_AMOUNT: 0.5
36 |
37 | # classifier & updater
38 | USE_CLASSIFIER: True
39 | TEMPLATE_UPDATE: False
40 | SEED: 12
41 | COEE_CLASS: 0.8
42 | USE_ATTENTION_LAYER: False
43 | CHANNEL_ATTENTION: True
44 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
45 |
--------------------------------------------------------------------------------
/experiments/siamrpn_r50_l234_dwxcorr/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamrpn_r50_l234_dwxcorr"
2 |
3 | BACKBONE:
4 | TYPE: "resnet50"
5 | KWARGS:
6 | used_layers: [2, 3, 4]
7 |
8 | ADJUST:
9 | ADJUST: True
10 | TYPE: "AdjustAllLayer"
11 | LAYER: 0
12 | FUSE: 'avg'
13 | KWARGS:
14 | in_channels: [512, 1024, 2048]
15 | out_channels: [256, 256, 256]
16 |
17 | RPN:
18 | TYPE: 'MultiRPN'
19 | KWARGS:
20 | anchor_num: 5
21 | in_channels: [256, 256, 256]
22 | weighted: True
23 |
24 | MASK:
25 | MASK: False
26 |
27 | ANCHOR:
28 | STRIDE: 8
29 | RATIOS: [0.33, 0.5, 1, 2, 3]
30 | SCALES: [8]
31 | ANCHOR_NUM: 5
32 |
33 | TRACK:
34 | # matcher
35 | TYPE: 'SiamRPNTracker'
36 | PENALTY_K: 0.05
37 | WINDOW_INFLUENCE: 0.25
38 | LR: 0.38
39 | EXEMPLAR_SIZE: 127
40 | INSTANCE_SIZE: 255
41 | BASE_SIZE: 8
42 | CONTEXT_AMOUNT: 0.5
43 |
44 | # classifier & updater
45 | USE_CLASSIFIER: True
46 | TEMPLATE_UPDATE: True
47 | SEED: 12345
48 | COEE_CLASS: 0.8
49 | USE_ATTENTION_LAYER: True
50 | CHANNEL_ATTENTION: True
51 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
--------------------------------------------------------------------------------
/experiments/siamrpn_r50_l234_dwxcorr_otb/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamrpn_r50_l234_dwxcorr_otb"
2 |
3 | BACKBONE:
4 | TYPE: "resnet50"
5 | KWARGS:
6 | used_layers: [2, 3, 4]
7 |
8 | ADJUST:
9 | ADJUST: True
10 | TYPE: "AdjustAllLayer"
11 | LAYER: 0
12 | FUSE: 'avg' # ['avg', 'wavg', 'con']
13 | KWARGS:
14 | in_channels: [512, 1024, 2048]
15 | out_channels: [256, 256, 256]
16 |
17 | RPN:
18 | TYPE: 'MultiRPN'
19 | KWARGS:
20 | anchor_num: 5
21 | in_channels: [256, 256, 256]
22 | weighted: False
23 |
24 | MASK:
25 | MASK: False
26 |
27 | ANCHOR:
28 | STRIDE: 8
29 | RATIOS: [0.33, 0.5, 1, 2, 3]
30 | SCALES: [8]
31 | ANCHOR_NUM: 5
32 |
33 | TRACK:
34 | # matcher
35 | TYPE: 'SiamRPNTracker'
36 | PENALTY_K: 0.24
37 | WINDOW_INFLUENCE: 0.25
38 | LR: 0.25
39 | EXEMPLAR_SIZE: 127
40 | INSTANCE_SIZE: 255
41 | BASE_SIZE: 8
42 | CONTEXT_AMOUNT: 0.5
43 | SHORT_TERM_DRIFT: True
44 |
45 | # classifier & updater
46 | USE_CLASSIFIER: True
47 | TEMPLATE_UPDATE: True
48 | SEED: 123456
49 | COEE_CLASS: 0.8
50 | USE_ATTENTION_LAYER: True
51 | CHANNEL_ATTENTION: True
52 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
53 |
--------------------------------------------------------------------------------
/experiments/siamrpn_r50_l234_dwxcorr_vot/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamrpn_r50_l234_dwxcorr_vot"
2 |
3 | BACKBONE:
4 | TYPE: "resnet50"
5 | KWARGS:
6 | used_layers: [2, 3, 4]
7 |
8 | ADJUST:
9 | ADJUST: True
10 | TYPE: "AdjustAllLayer"
11 | LAYER: 0
12 | FUSE: 'avg'
13 | KWARGS:
14 | in_channels: [512, 1024, 2048]
15 | out_channels: [256, 256, 256]
16 |
17 | RPN:
18 | TYPE: 'MultiRPN'
19 | KWARGS:
20 | anchor_num: 5
21 | in_channels: [256, 256, 256]
22 | weighted: True
23 |
24 | MASK:
25 | MASK: False
26 |
27 | ANCHOR:
28 | STRIDE: 8
29 | RATIOS: [0.33, 0.5, 1, 2, 3]
30 | SCALES: [8]
31 | ANCHOR_NUM: 5
32 |
33 | TRACK:
34 | # matcher
35 | TYPE: 'SiamRPNTracker'
36 | PENALTY_K: 0.05
37 | WINDOW_INFLUENCE: 0.35
38 | LR: 0.38
39 | EXEMPLAR_SIZE: 127
40 | INSTANCE_SIZE: 255
41 | BASE_SIZE: 8
42 | CONTEXT_AMOUNT: 0.5
43 |
44 | # classifier & updater
45 | USE_CLASSIFIER: True
46 | TEMPLATE_UPDATE: False
47 | SEED: 123
48 | COEE_CLASS: 0.8
49 | USE_ATTENTION_LAYER: False
50 | CHANNEL_ATTENTION: True
51 | SPATIAL_ATTENTION: 'pool' # ['none', 'conv', 'pool']
52 |
--------------------------------------------------------------------------------
/experiments/siamrpn_r50_l234_dwxcorr_votlt/config.yaml:
--------------------------------------------------------------------------------
1 | META_ARC: "siamrpn_r50_l234_dwxcorr_votlt"
2 |
3 | BACKBONE:
4 | TYPE: "resnet50"
5 | KWARGS:
6 | used_layers: [2, 3, 4]
7 |
8 | ADJUST:
9 | ADJUST: True
10 | TYPE: "AdjustAllLayer"
11 | LAYER: 1
12 | KWARGS:
13 | in_channels: [512, 1024, 2048]
14 | out_channels: [128, 256, 512] #[128, 256, 512]
15 |
16 | RPN:
17 | TYPE: 'MultiRPN'
18 | KWARGS:
19 | anchor_num: 5
20 | in_channels: [128, 256, 512] #[128, 256, 512]
21 | weighted: True
22 |
23 | MASK:
24 | MASK: False
25 |
26 | ANCHOR:
27 | STRIDE: 8
28 | RATIOS: [0.33, 0.5, 1, 2, 3]
29 | SCALES: [8]
30 | ANCHOR_NUM: 5
31 |
32 | TRACK:
33 | TYPE: 'SiamRPNLTTracker'
34 | PENALTY_K: 0.05
35 | WINDOW_INFLUENCE: 0.26
36 | LR: 0.22
37 | EXEMPLAR_SIZE: 127
38 | INSTANCE_SIZE: 255
39 | BASE_SIZE: 8
40 | CONTEXT_AMOUNT: 0.5
41 |
42 | # classifier & updater
43 | USE_CLASSIFIER: True
44 | TEMPLATE_UPDATE: True
45 | SEED: 123
46 | COEE_CLASS: 0.8
47 | USE_ATTENTION_LAYER: False
48 | CHANNEL_ATTENTION: True
49 | SPATIAL_ATTENTION: 'pool' # ['none', 'pool', 'residual', 'gaussian']
50 | TARGET_NOT_FOUND_THRESHOLD: 0.2
51 |
52 |
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ $# -lt 2 ]; then
4 | echo "ARGS ERROR!"
5 | echo " bash install.sh /path/to/your/conda env_name"
6 | exit 1
7 | fi
8 |
9 | set -e
10 |
11 | conda_path=$1
12 | env_name=$2
13 |
14 | source $conda_path/etc/profile.d/conda.sh
15 |
16 | echo "****** create environment " $env_name "*****"
17 | # create environment
18 | conda create -y --name $env_name python=3.7
19 | conda activate $env_name
20 |
21 | echo "***** install numpy pytorch opencv *****"
22 | # numpy
23 | conda install -y numpy
24 | # pytorch
25 | # pytorch with cuda80/cuda90 is tested
26 | conda install -y pytorch=1.0.0 torchvision cuda90 -c pytorch
27 | # opencv
28 | pip install opencv-python
29 | # tensorboardX
30 |
31 | echo "***** install other libs *****"
32 | pip install tensorboardX
33 | # libs
34 | pip install pyyaml yacs tqdm colorama matplotlib cython
35 |
36 |
37 | echo "***** build extensions *****"
38 | python setup.py build_ext --inplace
39 |
--------------------------------------------------------------------------------
/pysot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/__init__.py
--------------------------------------------------------------------------------
/pysot/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/core/__init__.py
--------------------------------------------------------------------------------
/pysot/core/xcorr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | def xcorr_slow(x, kernel):
13 | """for loop to calculate cross correlation, slow version
14 | """
15 | batch = x.size()[0]
16 | out = []
17 | for i in range(batch):
18 | px = x[i]
19 | pk = kernel[i]
20 | px = px.view(1, px.size()[0], px.size()[1], px.size()[2])
21 | pk = pk.view(-1, px.size()[1], pk.size()[1], pk.size()[2])
22 | po = F.conv2d(px, pk)
23 | out.append(po)
24 | out = torch.cat(out, 0)
25 | return out
26 |
27 |
28 | def xcorr_fast(x, kernel):
29 | """group conv2d to calculate cross correlation, fast version
30 | """
31 | batch = kernel.size()[0]
32 | pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3])
33 | px = x.view(1, -1, x.size()[2], x.size()[3])
34 | po = F.conv2d(px, pk, groups=batch)
35 | po = po.view(batch, -1, po.size()[2], po.size()[3])
36 | return po
37 |
38 |
39 | def xcorr_depthwise(x, kernel):
40 | """depthwise cross correlation
41 | """
42 | batch = kernel.size(0)
43 | channel = kernel.size(1)
44 | x = x.view(1, batch*channel, x.size(2), x.size(3))
45 | kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3))
46 | out = F.conv2d(x, kernel, groups=batch*channel)
47 | out = out.view(batch, channel, out.size(2), out.size(3))
48 | return out
49 |
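50 | 
51 | if __name__ == '__main__':
52 |     # Illustrative shape check (not part of the original file): correlating
53 |     # 7x7 template kernels with 31x31 search features yields 25x25 responses.
54 |     x = torch.randn(2, 256, 31, 31)
55 |     kernel = torch.randn(2, 256, 7, 7)
56 |     print(xcorr_slow(x, kernel).shape)       # torch.Size([2, 1, 25, 25])
57 |     print(xcorr_fast(x, kernel).shape)       # torch.Size([2, 1, 25, 25])
58 |     print(xcorr_depthwise(x, kernel).shape)  # torch.Size([2, 256, 25, 25])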
--------------------------------------------------------------------------------
/pysot/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/datasets/__init__.py
--------------------------------------------------------------------------------
/pysot/datasets/anchor_target.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import numpy as np
9 |
10 | from pysot.core.config import cfg
11 | from pysot.utils.bbox import IoU, corner2center
12 | from pysot.utils.anchor import Anchors
13 |
14 |
15 | class AnchorTarget:
16 | def __init__(self,):
17 | self.anchors = Anchors(cfg.ANCHOR.STRIDE,
18 | cfg.ANCHOR.RATIOS,
19 | cfg.ANCHOR.SCALES)
20 |
21 | self.anchors.generate_all_anchors(im_c=cfg.TRAIN.SEARCH_SIZE//2,
22 | size=cfg.TRAIN.OUTPUT_SIZE)
23 |
24 | def __call__(self, target, size, neg=False):
25 | anchor_num = len(cfg.ANCHOR.RATIOS) * len(cfg.ANCHOR.SCALES)
26 |
27 | # -1 ignore 0 negative 1 positive
28 | cls = -1 * np.ones((anchor_num, size, size), dtype=np.int64)
29 | delta = np.zeros((4, anchor_num, size, size), dtype=np.float32)
30 | delta_weight = np.zeros((anchor_num, size, size), dtype=np.float32)
31 |
32 | def select(position, keep_num=16):
33 | num = position[0].shape[0]
34 | if num <= keep_num:
35 | return position, num
36 | slt = np.arange(num)
37 | np.random.shuffle(slt)
38 | slt = slt[:keep_num]
39 | return tuple(p[slt] for p in position), keep_num
40 |
41 | tcx, tcy, tw, th = corner2center(target)
42 |
43 | if neg:
44 | # l = size // 2 - 3
45 | # r = size // 2 + 3 + 1
46 | # cls[:, l:r, l:r] = 0
47 |
48 | cx = size // 2
49 | cy = size // 2
50 | cx += int(np.ceil((tcx - cfg.TRAIN.SEARCH_SIZE // 2) /
51 | cfg.ANCHOR.STRIDE + 0.5))
52 | cy += int(np.ceil((tcy - cfg.TRAIN.SEARCH_SIZE // 2) /
53 | cfg.ANCHOR.STRIDE + 0.5))
54 | l = max(0, cx - 3)
55 | r = min(size, cx + 4)
56 | u = max(0, cy - 3)
57 | d = min(size, cy + 4)
58 | cls[:, u:d, l:r] = 0
59 |
60 | neg, neg_num = select(np.where(cls == 0), cfg.TRAIN.NEG_NUM)
61 | cls[:] = -1
62 | cls[neg] = 0
63 |
64 | overlap = np.zeros((anchor_num, size, size), dtype=np.float32)
65 | return cls, delta, delta_weight, overlap
66 |
67 | anchor_box = self.anchors.all_anchors[0]
68 | anchor_center = self.anchors.all_anchors[1]
69 | x1, y1, x2, y2 = anchor_box[0], anchor_box[1], \
70 | anchor_box[2], anchor_box[3]
71 | cx, cy, w, h = anchor_center[0], anchor_center[1], \
72 | anchor_center[2], anchor_center[3]
73 |
74 | delta[0] = (tcx - cx) / w
75 | delta[1] = (tcy - cy) / h
76 | delta[2] = np.log(tw / w)
77 | delta[3] = np.log(th / h)
78 |
79 | overlap = IoU([x1, y1, x2, y2], target)
80 |
81 | pos = np.where(overlap > cfg.TRAIN.THR_HIGH)
82 | neg = np.where(overlap < cfg.TRAIN.THR_LOW)
83 |
84 | pos, pos_num = select(pos, cfg.TRAIN.POS_NUM)
85 | neg, neg_num = select(neg, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM)
86 |
87 | cls[pos] = 1
88 | delta_weight[pos] = 1. / (pos_num + 1e-6)
89 |
90 | cls[neg] = 0
91 | return cls, delta, delta_weight, overlap
92 |
--------------------------------------------------------------------------------
/pysot/datasets/augmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import numpy as np
9 | import cv2
10 |
11 | from pysot.utils.bbox import corner2center, \
12 | Center, center2corner, Corner
13 |
14 |
15 | class Augmentation:
16 | def __init__(self, shift, scale, blur, flip, color):
17 | self.shift = shift
18 | self.scale = scale
19 | self.blur = blur
20 | self.flip = flip
21 | self.color = color
22 | self.rgbVar = np.array(
23 | [[-0.55919361, 0.98062831, - 0.41940627],
24 | [1.72091413, 0.19879334, - 1.82968581],
25 | [4.64467907, 4.73710203, 4.88324118]], dtype=np.float32)
26 |
27 | @staticmethod
28 | def random():
29 | return np.random.random() * 2 - 1.0
30 |
31 | def _crop_roi(self, image, bbox, out_sz, padding=(0, 0, 0)):
32 | bbox = [float(x) for x in bbox]
33 | a = (out_sz-1) / (bbox[2]-bbox[0])
34 | b = (out_sz-1) / (bbox[3]-bbox[1])
35 | c = -a * bbox[0]
36 | d = -b * bbox[1]
37 | mapping = np.array([[a, 0, c],
38 | [0, b, d]]).astype(np.float)
39 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz),
40 | borderMode=cv2.BORDER_CONSTANT,
41 | borderValue=padding)
42 | return crop
43 |
44 | def _blur_aug(self, image):
45 | def rand_kernel():
46 | sizes = np.arange(5, 46, 2)
47 | size = np.random.choice(sizes)
48 | kernel = np.zeros((size, size))
49 | c = int(size/2)
50 | wx = np.random.random()
51 | kernel[:, c] += 1. / size * wx
52 | kernel[c, :] += 1. / size * (1-wx)
53 | return kernel
54 | kernel = rand_kernel()
55 | image = cv2.filter2D(image, -1, kernel)
56 | return image
57 |
58 | def _color_aug(self, image):
59 | offset = np.dot(self.rgbVar, np.random.randn(3, 1))
60 | offset = offset[::-1] # bgr 2 rgb
61 | offset = offset.reshape(3)
62 | image = image - offset
63 | return image
64 |
65 | def _gray_aug(self, image):
66 | grayed = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
67 | image = cv2.cvtColor(grayed, cv2.COLOR_GRAY2BGR)
68 | return image
69 |
70 | def _shift_scale_aug(self, image, bbox, crop_bbox, size):
71 | im_h, im_w = image.shape[:2]
72 |
73 | # adjust crop bounding box
74 | crop_bbox_center = corner2center(crop_bbox)
75 | if self.scale:
76 | scale_x = (1.0 + Augmentation.random() * self.scale)
77 | scale_y = (1.0 + Augmentation.random() * self.scale)
78 | h, w = crop_bbox_center.h, crop_bbox_center.w
79 | scale_x = min(scale_x, float(im_w) / w)
80 | scale_y = min(scale_y, float(im_h) / h)
81 | crop_bbox_center = Center(crop_bbox_center.x,
82 | crop_bbox_center.y,
83 | crop_bbox_center.w * scale_x,
84 | crop_bbox_center.h * scale_y)
85 |
86 | crop_bbox = center2corner(crop_bbox_center)
87 | if self.shift:
88 | sx = Augmentation.random() * self.shift
89 | sy = Augmentation.random() * self.shift
90 |
91 | x1, y1, x2, y2 = crop_bbox
92 |
93 | sx = max(-x1, min(im_w - 1 - x2, sx))
94 | sy = max(-y1, min(im_h - 1 - y2, sy))
95 |
96 | crop_bbox = Corner(x1 + sx, y1 + sy, x2 + sx, y2 + sy)
97 |
98 | # adjust target bounding box
99 | x1, y1 = crop_bbox.x1, crop_bbox.y1
100 | bbox = Corner(bbox.x1 - x1, bbox.y1 - y1,
101 | bbox.x2 - x1, bbox.y2 - y1)
102 |
103 | if self.scale:
104 | bbox = Corner(bbox.x1 / scale_x, bbox.y1 / scale_y,
105 | bbox.x2 / scale_x, bbox.y2 / scale_y)
106 |
107 | image = self._crop_roi(image, crop_bbox, size)
108 | return image, bbox
109 |
110 | def _flip_aug(self, image, bbox):
111 | image = cv2.flip(image, 1)
112 | width = image.shape[1]
113 | bbox = Corner(width - 1 - bbox.x2, bbox.y1,
114 | width - 1 - bbox.x1, bbox.y2)
115 | return image, bbox
116 |
117 | def __call__(self, image, bbox, size, gray=False):
118 | shape = image.shape
119 | crop_bbox = center2corner(Center(shape[0]//2, shape[1]//2,
120 | size-1, size-1))
121 | # gray augmentation
122 | if gray:
123 | image = self._gray_aug(image)
124 |
125 | # shift scale augmentation
126 | image, bbox = self._shift_scale_aug(image, bbox, crop_bbox, size)
127 |
128 | # color augmentation
129 | if self.color > np.random.random():
130 | image = self._color_aug(image)
131 |
132 | # blur augmentation
133 | if self.blur > np.random.random():
134 | image = self._blur_aug(image)
135 |
136 | # flip augmentation
137 | if self.flip and self.flip > np.random.random():
138 | image, bbox = self._flip_aug(image, bbox)
139 | return image, bbox
140 |
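141 | 
142 | if __name__ == '__main__':
143 |     # Illustrative usage (not part of the original file); the augmentation
144 |     # strengths below are placeholders, not the training defaults.
145 |     aug = Augmentation(shift=4, scale=0.05, blur=0.2, flip=0.0, color=0.0)
146 |     image = np.random.randint(0, 255, (255, 255, 3), dtype=np.uint8)
147 |     bbox = Corner(96.0, 96.0, 160.0, 160.0)
148 |     out_image, out_bbox = aug(image, bbox, size=127)
149 |     print(out_image.shape, out_bbox)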
--------------------------------------------------------------------------------
/pysot/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/models/__init__.py
--------------------------------------------------------------------------------
/pysot/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | from pysot.models.backbone.alexnet import alexnetlegacy, alexnetlegacy2, alexnet
9 | from pysot.models.backbone.mobile_v2 import mobilenetv2
10 | from pysot.models.backbone.resnet_atrous import resnet18, resnet34, resnet50
11 |
12 | BACKBONES = {
13 | 'alexnetlegacy': alexnetlegacy,
14 | 'alexnetlegacy2': alexnetlegacy2,
15 | 'mobilenetv2': mobilenetv2,
16 | 'resnet18': resnet18,
17 | 'resnet34': resnet34,
18 | 'resnet50': resnet50,
19 | 'alexnet': alexnet,
20 | }
21 |
22 |
23 | def get_backbone(name, **kwargs):
24 | return BACKBONES[name](**kwargs)
25 |
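26 | 
27 | # Illustrative mapping (not part of the original file) from the BACKBONE section
28 | # of the experiment configs to this registry:
29 | #   get_backbone('resnet50', used_layers=[2, 3, 4])  # siamrpn_r50_l234_* configs
30 | #   get_backbone('alexnetlegacy', width_mult=1.0)    # siamrpn_alex_* configs
31 | #   get_backbone('alexnetlegacy2')                    # siamfc_alex_* configs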
--------------------------------------------------------------------------------
/pysot/models/backbone/alexnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 | import torch.nn as nn
7 | import torch
8 |
9 | class AlexNetLegacy(nn.Module):
10 | configs = [3, 96, 256, 384, 384, 256]
11 |
12 | def __init__(self, width_mult=1):
13 | configs = list(map(lambda x: 3 if x == 3 else
14 | int(x*width_mult), AlexNet.configs))
15 | super(AlexNetLegacy, self).__init__()
16 | self.features = nn.Sequential(
17 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2),
18 | nn.BatchNorm2d(configs[1]),
19 | nn.MaxPool2d(kernel_size=3, stride=2),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(configs[1], configs[2], kernel_size=5),
22 | nn.BatchNorm2d(configs[2]),
23 | nn.MaxPool2d(kernel_size=3, stride=2),
24 | nn.ReLU(inplace=True),
25 | nn.Conv2d(configs[2], configs[3], kernel_size=3),
26 | nn.BatchNorm2d(configs[3]),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(configs[3], configs[4], kernel_size=3),
29 | nn.BatchNorm2d(configs[4]),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(configs[4], configs[5], kernel_size=3),
32 | nn.BatchNorm2d(configs[5]),
33 | )
34 | self.feature_size = configs[5]
35 |
36 | def forward(self, x):
37 | x = self.features(x)
38 | return x
39 |
40 | class AlexNetLegacy2(nn.Module):
41 | def __init__(self):
42 | super(AlexNetLegacy2, self).__init__()
43 | self.features = nn.Sequential(
44 | nn.Conv2d(3, 96, 11, 2),
45 | nn.BatchNorm2d(96),
46 | nn.ReLU(inplace=True),
47 | nn.MaxPool2d(3, 2),
48 | nn.Conv2d(96, 256, 5, 1, groups=2),
49 | nn.BatchNorm2d(256),
50 | nn.ReLU(inplace=True),
51 | nn.MaxPool2d(3, 2),
52 | nn.Conv2d(256, 384, 3, 1),
53 | nn.BatchNorm2d(384),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(384, 384, 3, 1, groups=2),
56 | nn.BatchNorm2d(384),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(384, 256, 3, 1, groups=2)
59 | )
60 | self.corr_bias = nn.Parameter(torch.zeros(1))
61 |
62 | def forward(self, x):
63 | x = self.features(x)
64 | return x
65 |
66 | class AlexNet(nn.Module):
67 | configs = [3, 96, 256, 384, 384, 256]
68 |
69 | def __init__(self, width_mult=1):
70 | configs = list(map(lambda x: 3 if x == 3 else
71 | int(x*width_mult), AlexNet.configs))
72 | super(AlexNet, self).__init__()
73 | self.layer1 = nn.Sequential(
74 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2),
75 | nn.BatchNorm2d(configs[1]),
76 | nn.MaxPool2d(kernel_size=3, stride=2),
77 | nn.ReLU(inplace=True),
78 | )
79 | self.layer2 = nn.Sequential(
80 | nn.Conv2d(configs[1], configs[2], kernel_size=5),
81 | nn.BatchNorm2d(configs[2]),
82 | nn.MaxPool2d(kernel_size=3, stride=2),
83 | nn.ReLU(inplace=True),
84 | )
85 | self.layer3 = nn.Sequential(
86 | nn.Conv2d(configs[2], configs[3], kernel_size=3),
87 | nn.BatchNorm2d(configs[3]),
88 | nn.ReLU(inplace=True),
89 | )
90 | self.layer4 = nn.Sequential(
91 | nn.Conv2d(configs[3], configs[4], kernel_size=3),
92 | nn.BatchNorm2d(configs[4]),
93 | nn.ReLU(inplace=True),
94 | )
95 |
96 | self.layer5 = nn.Sequential(
97 | nn.Conv2d(configs[4], configs[5], kernel_size=3),
98 | nn.BatchNorm2d(configs[5]),
99 | )
100 | self.feature_size = configs[5]
101 |
102 | def forward(self, x):
103 | x = self.layer1(x)
104 | x = self.layer2(x)
105 | x = self.layer3(x)
106 | x = self.layer4(x)
107 | x = self.layer5(x)
108 | return x
109 |
110 |
111 | def alexnetlegacy(**kwargs):
112 | return AlexNetLegacy(**kwargs)
113 |
114 | def alexnetlegacy2(**kwargs):
115 | return AlexNetLegacy2()
116 |
117 | def alexnet(**kwargs):
118 | return AlexNet(**kwargs)
119 |
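120 | 
121 | if __name__ == '__main__':
122 |     # Illustrative shape check (not part of the original file): with the crop
123 |     # sizes used in the configs, a 127x127 exemplar maps to 6x6 features and
124 |     # a 255x255 search region to 22x22 features.
125 |     net = alexnet()
126 |     print(net(torch.randn(1, 3, 127, 127)).shape)  # torch.Size([1, 256, 6, 6])
127 |     print(net(torch.randn(1, 3, 255, 255)).shape)  # torch.Size([1, 256, 22, 22])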
--------------------------------------------------------------------------------
/pysot/models/backbone/mobile_v2.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def conv_bn(inp, oup, stride, padding=1):
11 | return nn.Sequential(
12 | nn.Conv2d(inp, oup, 3, stride, padding, bias=False),
13 | nn.BatchNorm2d(oup),
14 | nn.ReLU6(inplace=True)
15 | )
16 |
17 |
18 | def conv_1x1_bn(inp, oup):
19 | return nn.Sequential(
20 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
21 | nn.BatchNorm2d(oup),
22 | nn.ReLU6(inplace=True)
23 | )
24 |
25 |
26 | class InvertedResidual(nn.Module):
27 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
28 | super(InvertedResidual, self).__init__()
29 | self.stride = stride
30 |
31 | self.use_res_connect = self.stride == 1 and inp == oup
32 |
33 | padding = 2 - stride
34 | if dilation > 1:
35 | padding = dilation
36 |
37 | self.conv = nn.Sequential(
38 | # pw
39 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
40 | nn.BatchNorm2d(inp * expand_ratio),
41 | nn.ReLU6(inplace=True),
42 | # dw
43 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3,
44 | stride, padding, dilation=dilation,
45 | groups=inp * expand_ratio, bias=False),
46 | nn.BatchNorm2d(inp * expand_ratio),
47 | nn.ReLU6(inplace=True),
48 | # pw-linear
49 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
50 | nn.BatchNorm2d(oup),
51 | )
52 |
53 | def forward(self, x):
54 | if self.use_res_connect:
55 | return x + self.conv(x)
56 | else:
57 | return self.conv(x)
58 |
59 |
60 | class MobileNetV2(nn.Sequential):
61 | def __init__(self, width_mult=1.0, used_layers=[3, 5, 7]):
62 | super(MobileNetV2, self).__init__()
63 |
64 | self.interverted_residual_setting = [
65 |             # t, c, n, s, d
66 | [1, 16, 1, 1, 1],
67 | [6, 24, 2, 2, 1],
68 | [6, 32, 3, 2, 1],
69 | [6, 64, 4, 2, 1],
70 | [6, 96, 3, 1, 1],
71 | [6, 160, 3, 2, 1],
72 | [6, 320, 1, 1, 1],
73 | ]
74 | # 0,2,3,4,6
75 |
76 | self.interverted_residual_setting = [
77 |             # t, c, n, s, d
78 | [1, 16, 1, 1, 1],
79 | [6, 24, 2, 2, 1],
80 | [6, 32, 3, 2, 1],
81 | [6, 64, 4, 1, 2],
82 | [6, 96, 3, 1, 2],
83 | [6, 160, 3, 1, 4],
84 | [6, 320, 1, 1, 4],
85 | ]
86 |
87 | self.channels = [24, 32, 96, 320]
88 | self.channels = [int(c * width_mult) for c in self.channels]
89 |
90 | input_channel = int(32 * width_mult)
91 | self.last_channel = int(1280 * width_mult) \
92 | if width_mult > 1.0 else 1280
93 |
94 | self.add_module('layer0', conv_bn(3, input_channel, 2, 0))
95 |
96 | last_dilation = 1
97 |
98 | self.used_layers = used_layers
99 |
100 | for idx, (t, c, n, s, d) in \
101 | enumerate(self.interverted_residual_setting, start=1):
102 | output_channel = int(c * width_mult)
103 |
104 | layers = []
105 |
106 | for i in range(n):
107 | if i == 0:
108 | if d == last_dilation:
109 | dd = d
110 | else:
111 | dd = max(d // 2, 1)
112 | layers.append(InvertedResidual(input_channel,
113 | output_channel, s, t, dd))
114 | else:
115 | layers.append(InvertedResidual(input_channel,
116 | output_channel, 1, t, d))
117 | input_channel = output_channel
118 |
119 | last_dilation = d
120 |
121 | self.add_module('layer%d' % (idx), nn.Sequential(*layers))
122 |
123 | def forward(self, x):
124 | outputs = []
125 | for idx in range(8):
126 | name = "layer%d" % idx
127 | x = getattr(self, name)(x)
128 | outputs.append(x)
129 | p0, p1, p2, p3, p4 = [outputs[i] for i in [1, 2, 3, 5, 7]]
130 | out = [outputs[i] for i in self.used_layers]
131 | return out
132 |
133 |
134 | def mobilenetv2(**kwargs):
135 | model = MobileNetV2(**kwargs)
136 | return model
137 |
138 |
139 | if __name__ == '__main__':
140 | net = mobilenetv2()
141 |
142 | print(net)
143 |
144 | from torch.autograd import Variable
145 | tensor = Variable(torch.Tensor(1, 3, 255, 255)).cuda()
146 |
147 | net = net.cuda()
148 |
149 | out = net(tensor)
150 |
151 | for i, p in enumerate(out):
152 | print(i, p.size())
153 |
--------------------------------------------------------------------------------
/pysot/models/head/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | from pysot.models.head.mask import MaskCorr, Refine
9 | from pysot.models.head.rpn import UPChannelRPN, DepthwiseRPN, MultiRPN
10 | # from pysot.models.head.rpn import DeformableRPN, MultiDeformableRPN
11 |
12 | RPNS = {
13 | 'UPChannelRPN': UPChannelRPN,
14 | 'DepthwiseRPN': DepthwiseRPN,
15 | 'MultiRPN': MultiRPN,
16 | }
17 |
18 | MASKS = {
19 | 'MaskCorr': MaskCorr,
20 | }
21 |
22 | REFINE = {
23 | 'Refine': Refine,
24 | }
25 |
26 | def get_rpn_head(name, **kwargs):
27 | return RPNS[name](**kwargs)
28 |
29 |
30 | def get_mask_head(name, **kwargs):
31 | return MASKS[name](**kwargs)
32 |
33 | def get_refine_head(name):
34 | return REFINE[name]()
35 |
--------------------------------------------------------------------------------
/pysot/models/head/mask.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from pysot.models.head.rpn import DepthwiseXCorr
12 | from pysot.core.xcorr import xcorr_depthwise
13 |
14 |
15 | class MaskCorr(DepthwiseXCorr):
16 | def __init__(self, in_channels, hidden, out_channels,
17 | fused='none', kernel_size=3, hidden_kernel_size=5):
18 | super(MaskCorr, self).__init__(in_channels, hidden,
19 | out_channels, fused,
20 | kernel_size, hidden_kernel_size)
21 |
22 | def forward(self, kernel, search):
23 | kernel = self.conv_kernel(kernel)
24 | search = self.conv_search(search)
25 | feature = xcorr_depthwise(search, kernel)
26 | out = self.head(feature)
27 | return out, feature
28 |
29 |
30 | class Refine(nn.Module):
31 | def __init__(self):
32 | super(Refine, self).__init__()
33 | self.v0 = nn.Sequential(
34 | nn.Conv2d(64, 16, 3, padding=1),
35 | nn.ReLU(inplace=True),
36 | nn.Conv2d(16, 4, 3, padding=1),
37 | nn.ReLU(inplace=True),
38 | )
39 | self.v1 = nn.Sequential(
40 | nn.Conv2d(256, 64, 3, padding=1),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(64, 16, 3, padding=1),
43 | nn.ReLU(inplace=True),
44 | )
45 | self.v2 = nn.Sequential(
46 | nn.Conv2d(512, 128, 3, padding=1),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(128, 32, 3, padding=1),
49 | nn.ReLU(inplace=True),
50 | )
51 | self.h2 = nn.Sequential(
52 | nn.Conv2d(32, 32, 3, padding=1),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(32, 32, 3, padding=1),
55 | nn.ReLU(inplace=True),
56 | )
57 | self.h1 = nn.Sequential(
58 | nn.Conv2d(16, 16, 3, padding=1),
59 | nn.ReLU(inplace=True),
60 | nn.Conv2d(16, 16, 3, padding=1),
61 | nn.ReLU(inplace=True),
62 | )
63 | self.h0 = nn.Sequential(
64 | nn.Conv2d(4, 4, 3, padding=1),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(4, 4, 3, padding=1),
67 | nn.ReLU(inplace=True),
68 | )
69 |
70 | self.deconv = nn.ConvTranspose2d(256, 32, 15, 15)
71 | self.post0 = nn.Conv2d(32, 16, 3, padding=1)
72 | self.post1 = nn.Conv2d(16, 4, 3, padding=1)
73 | self.post2 = nn.Conv2d(4, 1, 3, padding=1)
74 |
75 | def forward(self, f, corr_feature, pos):
76 | p0 = F.pad(f[0], [16, 16, 16, 16])[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61]
77 | p1 = F.pad(f[1], [8, 8, 8, 8])[:, :, 2*pos[0]:2*pos[0]+31, 2*pos[1]:2*pos[1]+31]
78 | p2 = F.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0]+15, pos[1]:pos[1]+15]
79 |
80 | p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1)
81 |
82 | out = self.deconv(p3)
83 | out = self.post0(F.upsample(self.h2(out) + self.v2(p2), size=(31, 31)))
84 | out = self.post1(F.upsample(self.h1(out) + self.v1(p1), size=(61, 61)))
85 | out = self.post2(F.upsample(self.h0(out) + self.v0(p0), size=(127, 127)))
86 | out = out.view(-1, 127*127)
87 | return out
88 |
--------------------------------------------------------------------------------
/pysot/models/head/rpn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from pysot.core.xcorr import xcorr_fast, xcorr_depthwise
13 | from pysot.models.init_weight import init_weights
14 |
15 | class RPN(nn.Module):
16 | def __init__(self):
17 | super(RPN, self).__init__()
18 |
19 | def forward(self, z_f, x_f):
20 | raise NotImplementedError
21 |
22 | class UPChannelRPN(RPN):
23 | def __init__(self, anchor_num=5, feature_in=256):
24 | super(UPChannelRPN, self).__init__()
25 |
26 | cls_output = 2 * anchor_num
27 | loc_output = 4 * anchor_num
28 |
29 | self.template_cls_conv = nn.Conv2d(feature_in,
30 | feature_in * cls_output, kernel_size=3)
31 | self.template_loc_conv = nn.Conv2d(feature_in,
32 | feature_in * loc_output, kernel_size=3)
33 |
34 | self.search_cls_conv = nn.Conv2d(feature_in,
35 | feature_in, kernel_size=3)
36 | self.search_loc_conv = nn.Conv2d(feature_in,
37 | feature_in, kernel_size=3)
38 |
39 | self.loc_adjust = nn.Conv2d(loc_output, loc_output, kernel_size=1)
40 |
41 |
42 | def forward(self, z_f, x_f):
43 | cls_kernel = self.template_cls_conv(z_f)
44 | loc_kernel = self.template_loc_conv(z_f)
45 |
46 | cls_feature = self.search_cls_conv(x_f)
47 | loc_feature = self.search_loc_conv(x_f)
48 |
49 | cls = xcorr_fast(cls_feature, cls_kernel)
50 | loc = self.loc_adjust(xcorr_fast(loc_feature, loc_kernel))
51 | return cls, loc
52 |
53 | class DepthwiseXCorr(nn.Module):
54 | def __init__(self, in_channels, hidden, out_channels, fused='none', kernel_size=3, hidden_kernel_size=5):
55 | super(DepthwiseXCorr, self).__init__()
56 | self.fused = fused
57 | self.conv_kernel = nn.Sequential(
58 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
59 | nn.BatchNorm2d(hidden),
60 | nn.ReLU(inplace=True),
61 | )
62 | self.conv_search = nn.Sequential(
63 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
64 | nn.BatchNorm2d(hidden),
65 | nn.ReLU(inplace=True),
66 | )
67 | if self.fused == 'con':
68 | self.head = nn.Sequential(
69 | nn.Conv2d(hidden * 2, hidden, kernel_size=1, bias=False),
70 | nn.BatchNorm2d(hidden),
71 | nn.ReLU(inplace=True),
72 | nn.Conv2d(hidden, out_channels, kernel_size=1)
73 | )
74 | else:
75 | self.head = nn.Sequential(
76 | nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),
77 | nn.BatchNorm2d(hidden),
78 | nn.ReLU(inplace=True),
79 | nn.Conv2d(hidden, out_channels, kernel_size=1)
80 | )
81 | if self.fused != 'none':
82 | self.conv_raw = nn.Conv2d(hidden, hidden, kernel_size=5)
83 |
84 | def forward(self, kernel, search):
85 | kernel = self.conv_kernel(kernel)
86 | search = self.conv_search(search)
87 | feature = xcorr_depthwise(search, kernel)
88 | if self.fused != 'none':
89 | raw = self.conv_raw(search)
90 | if self.fused == 'con':
91 | feature = torch.cat((feature, raw), dim=1)
92 | elif self.fused == 'mod':
93 | feature = torch.matmul(feature, raw)
94 | elif self.fused == 'avg':
95 | feature = feature + raw
96 | else:
97 | raise NotImplementedError()
98 | out = self.head(feature)
99 | return out
100 |
101 | class DepthwiseRPN(RPN):
102 | def __init__(self, anchor_num=5, in_channels=256, out_channels=256, fused='none'):
103 | super(DepthwiseRPN, self).__init__()
104 | self.cls = DepthwiseXCorr(in_channels, out_channels, 2 * anchor_num, fused)
105 | self.loc = DepthwiseXCorr(in_channels, out_channels, 4 * anchor_num, fused)
106 |
107 | def forward(self, z_f, x_f):
108 | cls = self.cls(z_f, x_f)
109 | loc = self.loc(z_f, x_f)
110 | return cls, loc
111 |
112 | class MultiRPN(RPN):
113 | def __init__(self, anchor_num, in_channels, weighted=False, fused='none'):
114 | super(MultiRPN, self).__init__()
115 | self.weighted = weighted
116 | for i in range(len(in_channels)):
117 | self.add_module('rpn'+str(i+2),
118 | DepthwiseRPN(anchor_num, in_channels[i], in_channels[i], fused))
119 | if self.weighted:
120 | self.cls_weight = nn.Parameter(torch.ones(len(in_channels)))
121 | self.loc_weight = nn.Parameter(torch.ones(len(in_channels)))
122 |
123 | def forward(self, z_fs, x_fs):
124 | cls = []
125 | loc = []
126 | for idx, (z_f, x_f) in enumerate(zip(z_fs, x_fs), start=2):
127 | rpn = getattr(self, 'rpn'+str(idx))
128 | c, l = rpn(z_f, x_f)
129 | cls.append(c)
130 | loc.append(l)
131 |
132 | if self.weighted:
133 | cls_weight = F.softmax(self.cls_weight, 0)
134 | loc_weight = F.softmax(self.loc_weight, 0)
135 |
136 | def avg(lst):
137 | return sum(lst) / len(lst)
138 |
139 | def weighted_avg(lst, weight):
140 | s = 0
141 | for i in range(len(weight)):
142 | s += lst[i] * weight[i]
143 | return s
144 |
145 | if self.weighted:
146 | return weighted_avg(cls, cls_weight), weighted_avg(loc, loc_weight)
147 | else:
148 | return avg(cls), avg(loc)
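
As a quick sanity check of the output shapes, a small sketch assuming torch is installed; the 7x7 template and 31x31 search maps mirror the usual SiamRPN++ feature sizes and are illustrative:

import torch
from pysot.models.head.rpn import DepthwiseRPN

rpn = DepthwiseRPN(anchor_num=5, in_channels=256, out_channels=256)
z_f = torch.randn(1, 256, 7, 7)    # template features
x_f = torch.randn(1, 256, 31, 31)  # search features
cls, loc = rpn(z_f, x_f)
print(cls.shape, loc.shape)  # expected: (1, 10, 25, 25) and (1, 20, 25, 25)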
--------------------------------------------------------------------------------
/pysot/models/init_weight.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def init_weights(model):
5 | for m in model.modules():
6 | if isinstance(m, nn.Conv2d):
7 | nn.init.kaiming_normal_(m.weight.data,
8 | mode='fan_out',
9 | nonlinearity='relu')
10 | elif isinstance(m, nn.BatchNorm2d):
11 | m.weight.data.fill_(1)
12 | m.bias.data.zero_()
13 |
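
A minimal usage sketch, applying the initialization to a freshly built module (the small Sequential is only an example):

import torch.nn as nn
from pysot.models.init_weight import init_weights

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(inplace=True))
init_weights(net)  # Kaiming init for convs, constant init for BatchNorm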
--------------------------------------------------------------------------------
/pysot/models/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | def get_cls_loss(pred, label, select):
13 | if len(select.size()) == 0:
14 | return 0
15 | pred = torch.index_select(pred, 0, select)
16 | label = torch.index_select(label, 0, select)
17 | return F.nll_loss(pred, label)
18 |
19 |
20 | def select_cross_entropy_loss(pred, label):
21 | pred = pred.view(-1, 2)
22 | label = label.view(-1)
23 | pos = label.data.eq(1).nonzero().squeeze().cuda()
24 | neg = label.data.eq(0).nonzero().squeeze().cuda()
25 | loss_pos = get_cls_loss(pred, label, pos)
26 | loss_neg = get_cls_loss(pred, label, neg)
27 | return loss_pos * 0.5 + loss_neg * 0.5
28 |
29 |
30 | def weight_l1_loss(pred_loc, label_loc, loss_weight):
31 | b, _, sh, sw = pred_loc.size()
32 | pred_loc = pred_loc.view(b, 4, -1, sh, sw)
33 | diff = (pred_loc - label_loc).abs()
34 | diff = diff.sum(dim=1).view(b, -1, sh, sw)
35 | loss = diff * loss_weight
36 | return loss.sum().div(b)
37 |
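
Note that select_cross_entropy_loss moves its index tensors to the GPU via .cuda(), so it needs CUDA; weight_l1_loss, by contrast, can be exercised on CPU. A shape sketch with random tensors (the sizes are illustrative; in training they come from the anchor target generator):

import torch
from pysot.models.loss import weight_l1_loss

b, anchor_num, sh, sw = 2, 5, 25, 25
pred_loc = torch.randn(b, 4 * anchor_num, sh, sw)
label_loc = torch.randn(b, 4, anchor_num, sh, sw)
loss_weight = torch.rand(b, anchor_num, sh, sw)
print(weight_l1_loss(pred_loc, label_loc, loss_weight))  # scalar loss tensor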
--------------------------------------------------------------------------------
/pysot/models/neck/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from pysot.models.neck.neck import AdjustLayer, AdjustAllLayer
13 |
14 | NECKS = {
15 | 'AdjustLayer': AdjustLayer,
16 | 'AdjustAllLayer': AdjustAllLayer
17 | }
18 |
19 | def get_neck(name, **kwargs):
20 | return NECKS[name](**kwargs)
21 |
--------------------------------------------------------------------------------
/pysot/models/neck/neck.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import torch.nn as nn
9 |
10 |
11 | class AdjustLayer(nn.Module):
12 | def __init__(self, in_channels, out_channels):
13 | super(AdjustLayer, self).__init__()
14 | self.downsample = nn.Sequential(
15 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
16 | nn.BatchNorm2d(out_channels),
17 | )
18 |
19 | def forward(self, x):
20 | x = self.downsample(x)
21 | if x.size(3) < 20:
22 | l = 4
23 | r = l + 7
24 | x = x[:, :, l:r, l:r]
25 | return x
26 |
27 |
28 | class AdjustAllLayer(nn.Module):
29 | def __init__(self, in_channels, out_channels):
30 | super(AdjustAllLayer, self).__init__()
31 | self.num = len(out_channels)
32 | if self.num == 1:
33 | self.downsample = AdjustLayer(in_channels[0], out_channels[0])
34 | else:
35 | for i in range(self.num):
36 | self.add_module('downsample'+str(i+2),
37 | AdjustLayer(in_channels[i], out_channels[i]))
38 |
39 | def forward(self, features):
40 | if self.num == 1:
41 | return self.downsample(features)
42 | else:
43 | out = []
44 | for i in range(self.num):
45 | adj_layer = getattr(self, 'downsample'+str(i+2))
46 | out.append(adj_layer(features[i]))
47 | return out
48 |
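
A usage sketch with illustrative sizes: the 512/1024/2048 inputs mirror ResNet-50 stages, and 15x15 maps (smaller than 20) are center-cropped to 7x7 by AdjustLayer:

import torch
from pysot.models.neck.neck import AdjustAllLayer

neck = AdjustAllLayer([512, 1024, 2048], [256, 256, 256])
feats = [torch.randn(1, c, 15, 15) for c in (512, 1024, 2048)]
outs = neck(feats)
print([tuple(o.shape) for o in outs])  # each (1, 256, 7, 7)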
--------------------------------------------------------------------------------
/pysot/tracker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/tracker/__init__.py
--------------------------------------------------------------------------------
/pysot/tracker/base_tracker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 | # Modified by Jinghao Zhou
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | from __future__ import unicode_literals
8 |
9 | import cv2
10 | import numpy as np
11 | import torch
12 |
13 | from pysot.core.config import cfg
14 |
15 |
16 | class BaseTracker(object):
17 |
18 | def init(self, img, bbox):
19 | raise NotImplementedError
20 |
21 | def track(self, img):
22 | raise NotImplementedError
23 |
24 |
25 | class SiameseTracker(BaseTracker):
26 | def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans):
27 |
28 | if isinstance(pos, float):
29 | pos = [pos, pos]
30 | sz = original_sz
31 | im_sz = im.shape
32 | c = (original_sz + 1) / 2
33 | # context_xmin = round(pos[0] - c) # py2 and py3 round
34 | context_xmin = np.floor(pos[0] - c + 0.5)
35 | context_xmax = context_xmin + sz - 1
36 | # context_ymin = round(pos[1] - c)
37 | context_ymin = np.floor(pos[1] - c + 0.5)
38 | context_ymax = context_ymin + sz - 1
39 | left_pad = int(max(0., -context_xmin))
40 | top_pad = int(max(0., -context_ymin))
41 | right_pad = int(max(0., context_xmax - im_sz[1] + 1))
42 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
43 |
44 | context_xmin = context_xmin + left_pad
45 | context_xmax = context_xmax + left_pad
46 | context_ymin = context_ymin + top_pad
47 | context_ymax = context_ymax + top_pad
48 |
49 | r, c, k = im.shape
50 | if any([top_pad, bottom_pad, left_pad, right_pad]):
51 | size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k)
52 | te_im = np.zeros(size, np.uint8)
53 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
54 | if top_pad:
55 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
56 | if bottom_pad:
57 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
58 | if left_pad:
59 | te_im[:, 0:left_pad, :] = avg_chans
60 | if right_pad:
61 | te_im[:, c + left_pad:, :] = avg_chans
62 | im_patch = te_im[int(context_ymin):int(context_ymax + 1),
63 | int(context_xmin):int(context_xmax + 1), :]
64 | else:
65 | im_patch = im[int(context_ymin):int(context_ymax + 1),
66 | int(context_xmin):int(context_xmax + 1), :]
67 |
68 | if not np.array_equal(model_sz, original_sz):
69 | im_patch = cv2.resize(im_patch, (model_sz, model_sz))
70 | im_patch = im_patch.transpose(2, 0, 1)
71 | im_patch = im_patch[np.newaxis, :, :, :]
72 | im_patch = im_patch.astype(np.float32)
73 | im_patch = torch.from_numpy(im_patch)
74 | if cfg.CUDA:
75 | im_patch = im_patch.cuda()
76 | return im_patch
77 |
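
A minimal CPU sketch of the crop-and-pad logic, assuming the default config defines cfg.CUDA; it is forced off here so the patch stays a CPU tensor, and the image, position and sizes are made up for illustration:

import numpy as np
from pysot.core.config import cfg
from pysot.tracker.base_tracker import SiameseTracker

cfg.CUDA = False
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
avg_chans = np.mean(img, axis=(0, 1))  # used to fill out-of-image regions
patch = SiameseTracker().get_subwindow(img, pos=(320, 240), model_sz=127,
                                       original_sz=200, avg_chans=avg_chans)
print(patch.shape)  # torch.Size([1, 3, 127, 127])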
--------------------------------------------------------------------------------
/pysot/tracker/classifier/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/tracker/classifier/__init__.py
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/__init__.py:
--------------------------------------------------------------------------------
1 | from .tensorlist import TensorList
2 | from .tensordict import TensorDict
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 |
4 | def normalize(score):
5 | score = (score - np.min(score)) / (np.max(score) - np.min(score))
6 | return score
7 |
8 | def normfun(x, mu, sigma):
9 | pdf = np.exp(-((x - mu)**2) / (2* sigma**2))
10 | return pdf
11 |
12 | def generate_xy_attention(center, size):
13 |
14 | a = np.linspace(-size//2+1, size//2, size)
15 | x = - normfun(a, center[1], 10).reshape((size,1)) + 2
16 | y = - normfun(a, center[0], 10).reshape((1,size)) + 2
17 | z = normalize(1. / np.dot(np.abs(x), np.abs(y)))
18 | return z
19 |
20 | if __name__ == '__main__':
21 | generate_xy_attention([0,0], 31)
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | import cv2 as cv
6 | from .preprocessing import numpy_to_torch, torch_to_numpy
7 |
8 |
9 | class Transform:
10 | """Base data augmentation transform class."""
11 |
12 | def __init__(self, output_sz = None, shift = None):
13 | self.output_sz = output_sz
14 | self.shift = (0,0) if shift is None else shift
15 |
16 | def __call__(self, image):
17 | raise NotImplementedError
18 |
19 | def crop_to_output(self, image):
20 | if isinstance(image, torch.Tensor):
21 | imsz = image.shape[2:]
22 | if self.output_sz is None:
23 | pad_h = 0
24 | pad_w = 0
25 | else:
26 | pad_h = (self.output_sz[0] - imsz[0]) / 2
27 | pad_w = (self.output_sz[1] - imsz[1]) / 2
28 |
29 | pad_left = math.floor(pad_w) + self.shift[1]
30 | pad_right = math.ceil(pad_w) - self.shift[1]
31 | pad_top = math.floor(pad_h) + self.shift[0]
32 | pad_bottom = math.ceil(pad_h) - self.shift[0]
33 |
34 | return F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), 'replicate')
35 | else:
36 | raise NotImplementedError
37 |
38 | class Identity(Transform):
39 | """Identity transformation."""
40 | def __call__(self, image):
41 | return self.crop_to_output(image)
42 |
43 | class FlipHorizontal(Transform):
44 | """Flip along horizontal axis."""
45 | def __call__(self, image):
46 | if isinstance(image, torch.Tensor):
47 | return self.crop_to_output(image.flip((3,)))
48 | else:
49 | return np.fliplr(image)
50 |
51 | class FlipVertical(Transform):
52 | """Flip along vertical axis."""
53 | def __call__(self, image: torch.Tensor):
54 | if isinstance(image, torch.Tensor):
55 | return self.crop_to_output(image.flip((2,)))
56 | else:
57 | return np.flipud(image)
58 |
59 | class Translation(Transform):
60 | """Translate."""
61 | def __init__(self, translation, output_sz = None, shift = None):
62 | super().__init__(output_sz, shift)
63 | self.shift = (self.shift[0] + translation[0], self.shift[1] + translation[1])
64 |
65 | def __call__(self, image):
66 | if isinstance(image, torch.Tensor):
67 | return self.crop_to_output(image)
68 | else:
69 | raise NotImplementedError
70 |
71 | class Scale(Transform):
72 | """Scale."""
73 | def __init__(self, scale_factor, output_sz = None, shift = None):
74 | super().__init__(output_sz, shift)
75 | self.scale_factor = scale_factor
76 |
77 | def __call__(self, image):
78 | if isinstance(image, torch.Tensor):
79 | # Calculate new size. Ensure that it is even so that crop/pad becomes easier
80 | h_orig, w_orig = image.shape[2:]
81 |
82 | if h_orig != w_orig:
83 | raise NotImplementedError
84 |
85 | h_new = round(h_orig /self.scale_factor)
86 | h_new += (h_new - h_orig) % 2
87 | w_new = round(w_orig /self.scale_factor)
88 | w_new += (w_new - w_orig) % 2
89 |
90 | image_resized = F.interpolate(image, [h_new, w_new], mode='bilinear')
91 |
92 | return self.crop_to_output(image_resized)
93 | else:
94 | raise NotImplementedError
95 |
96 |
97 | class Affine(Transform):
98 | """Affine transformation."""
99 | def __init__(self, transform_matrix, output_sz = None, shift = None):
100 | super().__init__(output_sz, shift)
101 | self.transform_matrix = transform_matrix
102 |
103 | def __call__(self, image):
104 | if isinstance(image, torch.Tensor):
105 | return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image))))
106 | else:
107 | return cv.warpAffine(image, self.transform_matrix, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE)
108 |
109 |
110 | class Rotate(Transform):
111 | """Rotate with given angle."""
112 | def __init__(self, angle, output_sz = None, shift = None):
113 | super().__init__(output_sz, shift)
114 | self.angle = math.pi * angle/180
115 |
116 | def __call__(self, image):
117 | if isinstance(image, torch.Tensor):
118 | return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image))))
119 | else:
120 | c = (np.expand_dims(np.array(image.shape[:2]),1)-1)/2
121 | R = np.array([[math.cos(self.angle), math.sin(self.angle)],
122 | [-math.sin(self.angle), math.cos(self.angle)]])
123 | H =np.concatenate([R, c - R @ c], 1)
124 | return cv.warpAffine(image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE)
125 |
126 |
127 | class Blur(Transform):
128 | """Blur with given sigma (can be axis dependent)."""
129 | def __init__(self, sigma, output_sz = None, shift = None):
130 | super().__init__(output_sz, shift)
131 | if isinstance(sigma, (float, int)):
132 | sigma = (sigma, sigma)
133 | self.sigma = sigma
134 | self.filter_size = [math.ceil(2*s) for s in self.sigma]
135 | x_coord = [torch.arange(-sz, sz+1, dtype=torch.float32) for sz in self.filter_size]
136 | self.filter = [torch.exp(-(x**2)/(2*s**2)) for x, s in zip(x_coord, self.sigma)]
137 | self.filter[0] = self.filter[0].view(1,1,-1,1) / self.filter[0].sum()
138 | self.filter[1] = self.filter[1].view(1,1,1,-1) / self.filter[1].sum()
139 |
140 | def __call__(self, image):
141 | if isinstance(image, torch.Tensor):
142 | sz = image.shape[2:]
143 | im1 = F.conv2d(image.view(-1,1,sz[0],sz[1]), self.filter[0], padding=(self.filter_size[0],0))
144 | return self.crop_to_output(F.conv2d(im1, self.filter[1], padding=(0,self.filter_size[1])).view(1,-1,sz[0],sz[1]))
145 | else:
146 | raise NotImplementedError
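
A quick sketch of applying two of these transforms to a 4-D (batch, channel, H, W) tensor; output_sz is left at None, so crop_to_output only applies the (zero) shift:

import torch
from pysot.tracker.classifier.libs.augmentation import FlipHorizontal, Blur

im = torch.rand(1, 3, 64, 64)
flipped = FlipHorizontal()(im)
blurred = Blur(sigma=2)(im)
print(flipped.shape, blurred.shape)  # both stay (1, 3, 64, 64)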
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/complex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .tensorlist import tensor_operation
3 |
4 |
5 | def is_complex(a: torch.Tensor) -> bool:
6 | return a.dim() >= 4 and a.shape[-1] == 2
7 |
8 |
9 | def is_real(a: torch.Tensor) -> bool:
10 | return not is_complex(a)
11 |
12 |
13 | @tensor_operation
14 | def mult(a: torch.Tensor, b: torch.Tensor):
15 | """Pointwise complex multiplication of complex tensors."""
16 |
17 | if is_real(a):
18 | if a.dim() >= b.dim():
19 | raise ValueError('Incorrect dimensions.')
20 | # a is real
21 | return mult_real_cplx(a, b)
22 | if is_real(b):
23 | if b.dim() >= a.dim():
24 | raise ValueError('Incorrect dimensions.')
25 | # b is real
26 | return mult_real_cplx(b, a)
27 |
28 | # Both complex
29 | c = mult_real_cplx(a[..., 0], b)
30 | c[..., 0] -= a[..., 1] * b[..., 1]
31 | c[..., 1] += a[..., 1] * b[..., 0]
32 | return c
33 |
34 |
35 | @tensor_operation
36 | def mult_conj(a: torch.Tensor, b: torch.Tensor):
37 | """Pointwise complex multiplication of complex tensors, with conjugate on b: a*conj(b)."""
38 |
39 | if is_real(a):
40 | if a.dim() >= b.dim():
41 | raise ValueError('Incorrect dimensions.')
42 | # a is real
43 | return mult_real_cplx(a, conj(b))
44 | if is_real(b):
45 | if b.dim() >= a.dim():
46 | raise ValueError('Incorrect dimensions.')
47 | # b is real
48 | return mult_real_cplx(b, a)
49 |
50 | # Both complex
51 | c = mult_real_cplx(b[...,0], a)
52 | c[..., 0] += a[..., 1] * b[..., 1]
53 | c[..., 1] -= a[..., 0] * b[..., 1]
54 | return c
55 |
56 |
57 | @tensor_operation
58 | def mult_real_cplx(a: torch.Tensor, b: torch.Tensor):
59 | """Pointwise complex multiplication of real tensor a with complex tensor b."""
60 |
61 | if is_real(b):
62 | raise ValueError('Last dimension must have length 2.')
63 |
64 | return a.unsqueeze(-1) * b
65 |
66 |
67 | @tensor_operation
68 | def div(a: torch.Tensor, b: torch.Tensor):
69 | """Pointwise complex division of complex tensors."""
70 |
71 | if is_real(b):
72 | if b.dim() >= a.dim():
73 | raise ValueError('Incorrect dimensions.')
74 | # b is real
75 | return div_cplx_real(a, b)
76 |
77 | return div_cplx_real(mult_conj(a, b), abs_sqr(b))
78 |
79 |
80 | @tensor_operation
81 | def div_cplx_real(a: torch.Tensor, b: torch.Tensor):
82 | """Pointwise complex division of complex tensor a with real tensor b."""
83 |
84 | if is_real(a):
85 | raise ValueError('Last dimension must have length 2.')
86 |
87 | return a / b.unsqueeze(-1)
88 |
89 |
90 | @tensor_operation
91 | def abs_sqr(a: torch.Tensor):
92 | """Squared absolute value."""
93 |
94 | if is_real(a):
95 | raise ValueError('Last dimension must have length 2.')
96 |
97 | return torch.sum(a*a, -1)
98 |
99 |
100 | @tensor_operation
101 | def abs(a: torch.Tensor):
102 | """Absolute value."""
103 |
104 | if is_real(a):
105 | raise ValueError('Last dimension must have length 2.')
106 |
107 | return torch.sqrt(abs_sqr(a))
108 |
109 |
110 | @tensor_operation
111 | def conj(a: torch.Tensor):
112 | """Complex conjugate."""
113 |
114 | if is_real(a):
115 | raise ValueError('Last dimension must have length 2.')
116 |
117 | # return a * torch.Tensor([1, -1], device=a.device)
118 | return complex(a[...,0], -a[...,1])
119 |
120 |
121 | @tensor_operation
122 | def real(a: torch.Tensor):
123 | """Real part."""
124 |
125 | if is_real(a):
126 | raise ValueError('Last dimension must have length 2.')
127 |
128 | return a[..., 0]
129 |
130 |
131 | @tensor_operation
132 | def imag(a: torch.Tensor):
133 | """Imaginary part."""
134 |
135 | if is_real(a):
136 | raise ValueError('Last dimension must have length 2.')
137 |
138 | return a[..., 1]
139 |
140 |
141 | @tensor_operation
142 | def complex(a: torch.Tensor, b: torch.Tensor = None):
143 | """Create complex tensor from real and imaginary part."""
144 |
145 | if b is None:
146 | b = a.new_zeros(a.shape)
147 | elif a is None:
148 | a = b.new_zeros(b.shape)
149 |
150 | return torch.cat((a.unsqueeze(-1), b.unsqueeze(-1)), -1)
151 |
152 |
153 | @tensor_operation
154 | def mtimes(a: torch.Tensor, b: torch.Tensor, conj_a=False, conj_b=False):
155 | """Complex matrix multiplication of complex tensors.
156 | The dimensions (-3, -2) are matrix multiplied. -1 is the complex dimension."""
157 |
158 | if is_real(a):
159 | if a.dim() >= b.dim():
160 | raise ValueError('Incorrect dimensions.')
161 | return mtimes_real_complex(a, b, conj_b=conj_b)
162 | if is_real(b):
163 | if b.dim() >= a.dim():
164 | raise ValueError('Incorrect dimensions.')
165 | return mtimes_complex_real(a, b, conj_a=conj_a)
166 |
167 | if not conj_a and not conj_b:
168 | return complex(torch.matmul(a[..., 0], b[..., 0]) - torch.matmul(a[..., 1], b[..., 1]),
169 | torch.matmul(a[..., 0], b[..., 1]) + torch.matmul(a[..., 1], b[..., 0]))
170 | if conj_a and not conj_b:
171 | return complex(torch.matmul(a[..., 0], b[..., 0]) + torch.matmul(a[..., 1], b[..., 1]),
172 | torch.matmul(a[..., 0], b[..., 1]) - torch.matmul(a[..., 1], b[..., 0]))
173 | if not conj_a and conj_b:
174 | return complex(torch.matmul(a[..., 0], b[..., 0]) + torch.matmul(a[..., 1], b[..., 1]),
175 | torch.matmul(a[..., 1], b[..., 0]) - torch.matmul(a[..., 0], b[..., 1]))
176 | if conj_a and conj_b:
177 | return complex(torch.matmul(a[..., 0], b[..., 0]) - torch.matmul(a[..., 1], b[..., 1]),
178 | -torch.matmul(a[..., 0], b[..., 1]) - torch.matmul(a[..., 1], b[..., 0]))
179 |
180 |
181 | @tensor_operation
182 | def mtimes_real_complex(a: torch.Tensor, b: torch.Tensor, conj_b=False):
183 | if is_real(b):
184 | raise ValueError('Incorrect dimensions.')
185 |
186 | if not conj_b:
187 | return complex(torch.matmul(a, b[..., 0]), torch.matmul(a, b[..., 1]))
188 | if conj_b:
189 | return complex(torch.matmul(a, b[..., 0]), -torch.matmul(a, b[..., 1]))
190 |
191 |
192 | @tensor_operation
193 | def mtimes_complex_real(a: torch.Tensor, b: torch.Tensor, conj_a=False):
194 | if is_real(a):
195 | raise ValueError('Incorrect dimensions.')
196 |
197 | if not conj_a:
198 | return complex(torch.matmul(a[..., 0], b), torch.matmul(a[..., 1], b))
199 | if conj_a:
200 | return complex(torch.matmul(a[..., 0], b), -torch.matmul(a[..., 1], b))
201 |
202 |
203 | @tensor_operation
204 | def exp_imag(a: torch.Tensor):
205 | """Complex exponential with imaginary input: e^(i*a)"""
206 |
207 | a = a.unsqueeze(-1)
208 | return torch.cat((torch.cos(a), torch.sin(a)), -1)
209 |
210 |
211 |
212 |
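
In this module a complex tensor is a real tensor of at least 4 dims whose last dimension holds (real, imag). A small round-trip sketch on plain tensors (TensorList inputs go through the same functions via the tensor_operation decorator):

import torch
import pysot.tracker.classifier.libs.complex as cplx

a = cplx.complex(torch.ones(1, 1, 3, 3), torch.zeros(1, 1, 3, 3))  # 1 + 0i
b = cplx.complex(torch.zeros(1, 1, 3, 3), torch.ones(1, 1, 3, 3))  # 0 + 1i
prod = cplx.mult(a, b)  # elementwise (1)(i) = i
print(cplx.real(prod).sum().item(), cplx.imag(prod).sum().item())  # 0.0 9.0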
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/fourier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import pysot.tracker.classifier.libs.complex as complex
4 | from .tensorlist import tensor_operation, TensorList
5 |
6 |
7 | @tensor_operation
8 | def rfftshift2(a: torch.Tensor):
9 | h = a.shape[2] + 2
10 | return torch.cat((a[:,:,(h-1)//2:,...], a[:,:,:h//2,...]), 2)
11 |
12 |
13 | @tensor_operation
14 | def irfftshift2(a: torch.Tensor):
15 | mid = int((a.shape[2]-1)/2)
16 | return torch.cat((a[:,:,mid:,...], a[:,:,:mid,...]), 2)
17 |
18 |
19 | @tensor_operation
20 | def cfft2(a):
21 | """Do FFT and center the low frequency component.
22 | Always produces odd (full) output sizes."""
23 |
24 | return rfftshift2(torch.rfft(a, 2))
25 |
26 |
27 | @tensor_operation
28 | def cifft2(a, signal_sizes=None):
29 | """Do inverse FFT corresponding to cfft2."""
30 |
31 | return torch.irfft(irfftshift2(a), 2, signal_sizes=signal_sizes)
32 |
33 |
34 | @tensor_operation
35 | def sample_fs(a: torch.Tensor, grid_sz: torch.Tensor = None, rescale = True):
36 | """Samples the Fourier series."""
37 |
38 | # Size of the fourier series
39 | sz = torch.Tensor([a.shape[2], 2*a.shape[3]-1]).float()
40 |
41 | # Default grid
42 | if grid_sz is None or sz[0] == grid_sz[0] and sz[1] == grid_sz[1]:
43 | if rescale:
44 | return sz.prod().item() * cifft2(a)
45 | return cifft2(a)
46 |
47 | if sz[0] > grid_sz[0] or sz[1] > grid_sz[1]:
48 | raise ValueError("Only grid sizes that are smaller than the Fourier series size are supported.")
49 |
50 | tot_pad = (grid_sz - sz).tolist()
51 | is_even = [s.item() % 2 == 0 for s in sz]
52 |
53 | # Compute paddings
54 | pad_top = int((tot_pad[0]+1)/2) if is_even[0] else int(tot_pad[0]/2)
55 | pad_bottom = int(tot_pad[0] - pad_top)
56 | pad_right = int((tot_pad[1]+1)/2)
57 |
58 | if rescale:
59 | return grid_sz.prod().item() * cifft2(F.pad(a, (0, 0, 0, pad_right, pad_top, pad_bottom)), signal_sizes=grid_sz.long().tolist())
60 | else:
61 | return cifft2(F.pad(a, (0, 0, 0, pad_right, pad_top, pad_bottom)), signal_sizes=grid_sz.long().tolist())
62 |
63 |
64 | def get_frequency_coord(sz, add_complex_dim = False, device='cpu'):
65 | """Frequency coordinates."""
66 |
67 | ky = torch.arange(-int((sz[0]-1)/2), int(sz[0]/2+1), dtype=torch.float32, device=device).view(1,1,-1,1)
68 | kx = torch.arange(0, int(sz[1]/2+1), dtype=torch.float32, device=device).view(1,1,1,-1)
69 |
70 | if add_complex_dim:
71 | ky = ky.unsqueeze(-1)
72 | kx = kx.unsqueeze(-1)
73 |
74 | return ky, kx
75 |
76 |
77 | @tensor_operation
78 | def shift_fs(a: torch.Tensor, shift: torch.Tensor):
79 | """Shift a sample a in the Fourier domain.
80 | Params:
81 | a : The fourier coefficients of the sample.
82 | shift : The shift to be performed normalized to the range [-pi, pi]."""
83 |
84 | if a.dim() != 5:
85 | raise ValueError('a must be the Fourier coefficients, a 5-dimensional tensor.')
86 |
87 | if shift[0] == 0 and shift[1] == 0:
88 | return a
89 |
90 | ky, kx = get_frequency_coord((a.shape[2], 2*a.shape[3]-1), device=a.device)
91 |
92 | return complex.mult(complex.mult(a, complex.exp_imag(shift[0].item() * ky)), complex.exp_imag(shift[1].item() * kx))
93 |
94 |
95 | def sum_fs(a: TensorList) -> torch.Tensor:
96 | """Sum a list of Fourier series expansions."""
97 |
98 | s = None
99 | mid = None
100 |
101 | for e in sorted(a, key=lambda elem: elem.shape[-3], reverse=True):
102 | if s is None:
103 | s = e.clone()
104 | mid = int((s.shape[-3] - 1) / 2)
105 | else:
106 | # Compute coordinates
107 | top = mid - int((e.shape[-3] - 1) / 2)
108 | bottom = mid + int(e.shape[-3] / 2) + 1
109 | right = e.shape[-2]
110 |
111 | # Add the data
112 | s[..., top:bottom, :right, :] += e
113 |
114 | return s
115 |
116 |
117 | def sum_fs12(a: TensorList) -> torch.Tensor:
118 | """Sum a list of Fourier series expansions."""
119 |
120 | s = None
121 | mid = None
122 |
123 | for e in sorted(a, key=lambda elem: elem.shape[0], reverse=True):
124 | if s is None:
125 | s = e.clone()
126 | mid = int((s.shape[0] - 1) / 2)
127 | else:
128 | # Compute coordinates
129 | top = mid - int((e.shape[0] - 1) / 2)
130 | bottom = mid + int(e.shape[0] / 2) + 1
131 | right = e.shape[1]
132 |
133 | # Add the data
134 | s[top:bottom, :right, ...] += e
135 |
136 | return s
137 |
138 |
139 | @tensor_operation
140 | def inner_prod_fs(a: torch.Tensor, b: torch.Tensor):
141 | if complex.is_complex(a) and complex.is_complex(b):
142 | return 2 * (a.reshape(-1) @ b.reshape(-1)) - a[:, :, :, 0, :].reshape(-1) @ b[:, :, :, 0, :].reshape(-1)
143 | elif complex.is_real(a) and complex.is_real(b):
144 | return 2 * (a.reshape(-1) @ b.reshape(-1)) - a[:, :, :, 0].reshape(-1) @ b[:, :, :, 0].reshape(-1)
145 | else:
146 | raise NotImplementedError('Not implemented for mixed real and complex.')
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/operation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .tensorlist import tensor_operation, TensorList
4 |
5 |
6 | @tensor_operation
7 | def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, stride=1, padding=0, dilation=1, groups=1, mode=None):
8 | """Standard conv2d. Returns the input if weight=None."""
9 |
10 | if weight is None:
11 | return input
12 |
13 | ind = None
14 | if mode is not None:
15 | if padding != 0:
16 | raise ValueError('Cannot input both padding and mode.')
17 | if mode == 'same':
18 | padding = (weight.shape[2]//2, weight.shape[3]//2)
19 | if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0:
20 | ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None),
21 | slice(-1) if weight.shape[3] % 2 == 0 else slice(None))
22 | elif mode == 'valid':
23 | padding = (0, 0)
24 | elif mode == 'full':
25 | padding = (weight.shape[2]-1, weight.shape[3]-1)
26 | else:
27 | raise ValueError('Unknown mode for padding.')
28 |
29 | out = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
30 | if ind is None:
31 | return out
32 | return out[:,:,ind[0],ind[1]]
33 |
34 |
35 | @tensor_operation
36 | def conv1x1(input: torch.Tensor, weight: torch.Tensor):
37 | """Do a convolution with a 1x1 kernel weights. Implemented with matmul, which can be faster than using conv."""
38 |
39 | if weight is None:
40 | return input
41 |
42 | return torch.matmul(weight.view(weight.shape[0], weight.shape[1]),
43 | input.view(input.shape[0], input.shape[1], -1)).view(input.shape[0], weight.shape[0], input.shape[2], input.shape[3])
44 |
45 | @tensor_operation
46 | def spatial_attention(input: torch.Tensor, dim: int=0, keepdim: bool=True):
47 | return torch.sigmoid(torch.mean(input, dim, keepdim))
48 |
49 | @tensor_operation
50 | def adaptive_avg_pool2d(input: torch.Tensor, shape):
51 | return F.adaptive_avg_pool2d(input, shape)
52 |
53 | @tensor_operation
54 | def sigmoid(input: torch.Tensor):
55 | return torch.sigmoid(input)
56 |
57 | @tensor_operation
58 | def softmax(input: torch.Tensor):
59 | x_shape = input.size()
60 | return F.softmax(input.reshape(x_shape[0], -1), dim=1).reshape(x_shape)
61 |
62 | @tensor_operation
63 | def matmul(a: torch.Tensor, b: torch.Tensor):
64 | return a * b.expand_as(a)
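
A sketch of the 'same' padding mode on plain tensors; like the other helpers here, it also accepts TensorList inputs via the tensor_operation decorator:

import torch
from pysot.tracker.classifier.libs import operation

x = torch.randn(1, 8, 18, 18)
w = torch.randn(4, 8, 5, 5)
y = operation.conv2d(x, w, mode='same')
print(y.shape)  # (1, 4, 18, 18): the output keeps the input's spatial size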
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/params.py:
--------------------------------------------------------------------------------
1 | from .tensorlist import TensorList
2 | import random
3 |
4 |
5 | class TrackerParams:
6 | """Class for tracker parameters."""
7 | def free_memory(self):
8 | for a in dir(self):
9 | if not a.startswith('__') and hasattr(getattr(self, a), 'free_memory'):
10 | getattr(self, a).free_memory()
11 |
12 |
13 | class FeatureParams:
14 | """Class for feature specific parameters"""
15 | def __init__(self, *args, **kwargs):
16 | if len(args) > 0:
17 | raise ValueError
18 |
19 | for name, val in kwargs.items():
20 | if isinstance(val, list):
21 | setattr(self, name, TensorList(val))
22 | else:
23 | setattr(self, name, val)
24 |
25 |
26 | def Choice(*args):
27 | """Can be used to sample random parameter values."""
28 | return random.choice(args)
29 |
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('TkAgg')
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def show_tensor(a: torch.Tensor, fig_num = None, title = None):
9 | """Display a 2D tensor.
10 | args:
11 | fig_num: Figure number.
12 | title: Title of figure.
13 | """
14 | a_np = a.squeeze().cpu().clone().detach().numpy()
15 | if a_np.ndim == 3:
16 | a_np = np.transpose(a_np, (1, 2, 0))
17 | plt.figure(fig_num)
18 | plt.tight_layout()
19 | plt.cla()
20 | plt.imshow(a_np)
21 | plt.axis('off')
22 | plt.axis('equal')
23 | if title is not None:
24 | plt.title(title)
25 | plt.draw()
26 | plt.pause(0.001)
27 |
28 |
29 | def plot_graph(a: torch.Tensor, fig_num = None, title = None):
30 | """Plot graph. Data is a 1D tensor.
31 | args:
32 | fig_num: Figure number.
33 | title: Title of figure.
34 | """
35 | a_np = a.squeeze().cpu().clone().detach().numpy()
36 | if a_np.ndim > 1:
37 | raise ValueError
38 | plt.figure(fig_num)
39 | # plt.tight_layout()
40 | plt.cla()
41 | plt.plot(a_np)
42 | if title is not None:
43 | plt.title(title)
44 | plt.draw()
45 | plt.pause(0.001)
46 |
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/preprocessing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 |
6 | def numpy_to_torch(a: np.ndarray):
7 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0)
8 |
9 |
10 | def torch_to_numpy(a: torch.Tensor):
11 | return a.squeeze(0).permute(1,2,0).numpy()
12 |
13 |
14 | def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None):
15 | """Sample an image patch.
16 |
17 | args:
18 | im: Image
19 | pos: center position of crop
20 | sample_sz: size to crop
21 | output_sz: size to resize to
22 | """
23 |
24 | # copy and convert
25 | posl = pos.long().clone()
26 |
27 | # Compute pre-downsampling factor
28 | if output_sz is not None:
29 | resize_factor = torch.min(sample_sz.float() / output_sz.float()).item()
30 | df = int(max(int(resize_factor - 0.1), 1))
31 | else:
32 | df = int(1)
33 |
34 | sz = sample_sz.float() / df # new size
35 |
36 | # Do downsampling
37 | if df > 1:
38 | os = posl % df # offset
39 | posl = (posl - os) / df # new position
40 | im2 = im[..., os[0].item()::df, os[1].item()::df] # downsample
41 | else:
42 | im2 = im
43 |
44 | # compute size to crop
45 | szl = torch.max(sz.round(), torch.Tensor([2])).long()
46 |
47 | # Extract top and bottom coordinates
48 | tl = posl - (szl - 1)/2
49 | br = posl + szl/2
50 |
51 | # Get image patch
52 | im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3] + 1, -tl[0].item(), br[0].item() - im2.shape[2] + 1), 'replicate')
53 |
54 | if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]):
55 | return im_patch
56 |
57 | # Resample
58 | im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear')
59 |
60 | return im_patch
61 |
--------------------------------------------------------------------------------
/pysot/tracker/classifier/libs/tensordict.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 |
4 |
5 | class TensorDict(OrderedDict):
6 | """Container mainly used for dicts of torch tensors. Extends OrderedDict with pytorch functionality."""
7 |
8 | def concat(self, other):
9 | """Concatenates two dicts without copying internal data."""
10 | return TensorDict(self, **other)
11 |
12 | def copy(self):
13 | return TensorDict(super(TensorDict, self).copy())
14 |
15 | def __getattr__(self, name):
16 | if not hasattr(torch.Tensor, name):
17 | raise AttributeError('\'TensorDict\' object has no attribute \'{}\''.format(name))
18 |
19 | def apply_attr(*args, **kwargs):
20 | return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()})
21 | return apply_attr
22 |
23 | def attribute(self, attr: str, *args):
24 | return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()})
25 |
26 | def apply(self, fn, *args, **kwargs):
27 | return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()})
28 |
29 | @staticmethod
30 | def _iterable(a):
31 | return isinstance(a, (TensorDict, list))
32 |
33 |
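
A small sketch of the dict-of-tensors behaviour: unknown attributes are forwarded to every value that supports them, and apply maps an arbitrary function over the values:

import torch
from pysot.tracker.classifier.libs import TensorDict

d = TensorDict({'score': torch.ones(2, 3), 'bbox': torch.zeros(4)})
doubled = d.apply(lambda t: t * 2)  # multiply every value by 2
summed = d.sum()                    # forwarded to torch.Tensor.sum
print(doubled['score'].mean().item(), summed['bbox'].item())  # 2.0 0.0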
--------------------------------------------------------------------------------
/pysot/tracker/classifier/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pysot.tracker.classifier.libs import optimization, TensorList, operation
3 | import math
4 | from pysot.core.config import cfg
5 |
6 |
7 | class FactorizedConvProblem(optimization.L2Problem):
8 | def __init__(self, training_samples: TensorList, y: TensorList, use_attention: bool,
9 | filter_reg: torch.Tensor, projection_reg, sample_weights: TensorList,
10 | projection_activation, att_activation, response_activation):
11 |
12 | self.training_samples = training_samples
13 | self.y = y
14 | self.sample_weights = sample_weights
15 | self.use_attetion = use_attention
16 | self.filter_reg = filter_reg
17 | self.projection_reg = projection_reg
18 | self.projection_activation = projection_activation
19 | self.att_activation = att_activation
20 | self.response_activation = response_activation
21 |
22 | if self.use_attetion:
23 | self.diag_M = self.filter_reg.concat(projection_reg).concat(projection_reg).concat(projection_reg)
24 | else:
25 | self.diag_M = self.filter_reg.concat(projection_reg)
26 |
27 | def __call__(self, x: TensorList):
28 |
29 | if self.use_attetion:
30 | filter = x[:1]
31 | fc2 = x[1:2]
32 | fc1 = x[2:3]
33 | P = x[3:4]
34 | else:
35 | filter = x[:len(x)//2] # w2 in paper
36 | P = x[len(x)//2:] # w1 in paper
37 |
38 | # Compression module
39 | compressed_samples = operation.conv1x1(self.training_samples, P).apply(self.projection_activation)
40 |
41 | # Attention module
42 | if self.use_attetion:
43 | if cfg.TRACK.CHANNEL_ATTENTION:
44 | global_average = operation.adaptive_avg_pool2d(compressed_samples, 1)
45 | temp_variables = operation.conv1x1(global_average, fc1).apply(self.att_activation)
46 | channel_attention = operation.sigmoid(operation.conv1x1(temp_variables, fc2))
47 | else:
48 | channel_attention = TensorList(
49 | [torch.zeros(compressed_samples[0].size(0), compressed_samples[0].size(1), 1, 1).cuda()]
50 | )
51 |
52 | if cfg.TRACK.SPATIAL_ATTENTION == 'none':
53 | spatial_attention = TensorList(
54 | [torch.zeros(compressed_samples[0].size(0), 1, compressed_samples[0].size(2),
55 | compressed_samples[0].size(3)).cuda()]
56 | )
57 | elif cfg.TRACK.SPATIAL_ATTENTION == 'pool':
58 | spatial_attention = operation.spatial_attention(compressed_samples, dim=1, keepdim=True)
59 | else:
60 | raise NotImplementedError('No such spatial attention implemented')
61 |
62 | compressed_samples = operation.matmul(compressed_samples, spatial_attention) + \
63 | operation.matmul(compressed_samples, channel_attention)
64 |
65 | # Filter module
66 | residuals = operation.conv2d(compressed_samples, filter, mode='same').apply(self.response_activation)
67 | residuals = residuals - self.y
68 | residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals
69 |
70 | residuals.extend(self.filter_reg.apply(math.sqrt) * filter)
71 | if self.use_attetion:
72 | residuals.extend(self.projection_reg.apply(math.sqrt) * fc2)
73 | residuals.extend(self.projection_reg.apply(math.sqrt) * fc1)
74 | residuals.extend(self.projection_reg.apply(math.sqrt) * P)
75 |
76 | return residuals
77 |
78 |
79 | def ip_input(self, a: TensorList, b: TensorList):
80 |
81 | if self.use_attetion:
82 | a_filter = a[:1]
83 | a_f2 = a[1:2]
84 | a_f1 = a[2:3]
85 | a_P = a[3:]
86 | b_filter = b[:1]
87 | b_f2 = b[1:2]
88 | b_f1 = b[2:3]
89 | b_P = b[3:]
90 |
91 | ip_out = operation.conv2d(a_filter, b_filter).view(-1)
92 | ip_out += operation.conv2d(a_f2.view(1, -1, 1, 1), b_f2.view(1, -1, 1, 1)).view(-1)
93 | ip_out += operation.conv2d(a_f1.view(1, -1, 1, 1), b_f1.view(1, -1, 1, 1)).view(-1)
94 | ip_out += operation.conv2d(a_P.view(1, -1, 1, 1), b_P.view(1, -1, 1, 1)).view(-1)
95 |
96 | return ip_out.concat(ip_out.clone()).concat(ip_out.clone()).concat(ip_out.clone())
97 |
98 | else:
99 | num = len(a) // 2 # Number of filters
100 | a_filter = a[:num]
101 | b_filter = b[:num]
102 | a_P = a[num:]
103 | b_P = b[num:]
104 |
105 | # Filter inner product
106 | # ip_out = a_filter.reshape(-1) @ b_filter.reshape(-1)
107 | ip_out = operation.conv2d(a_filter, b_filter).view(-1)
108 |
109 | # Add projection matrix part
110 | # ip_out += a_P.reshape(-1) @ b_P.reshape(-1)
111 | ip_out += operation.conv2d(a_P.view(1, -1, 1, 1), b_P.view(1, -1, 1, 1)).view(-1)
112 |
113 | # Have independent inner products for each filter
114 | return ip_out.concat(ip_out.clone())
115 |
116 | def M1(self, x: TensorList):
117 | # factorized convolution
118 | return x / self.diag_M
119 |
120 | class ConvProblem(optimization.L2Problem):
121 | def __init__(self, training_samples: TensorList, y: TensorList, filter_reg: torch.Tensor, sample_weights: TensorList, response_activation):
122 | self.training_samples = training_samples
123 | self.y = y
124 | self.filter_reg = filter_reg
125 | self.sample_weights = sample_weights
126 | self.response_activation = response_activation
127 |
128 | def __call__(self, x: TensorList):
129 | """
130 | Compute residuals
131 | :param x: [filters]
132 | :return: [data_terms, filter_regularizations]
133 | """
134 | # Do convolution and compute residuals
135 | residuals = operation.conv2d(self.training_samples, x, mode='same').apply(self.response_activation)
136 | residuals = residuals - self.y
137 | residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals
138 |
139 | # Add regularization for projection matrix
140 | residuals.extend(self.filter_reg.apply(math.sqrt) * x)
141 |
142 | return residuals
143 |
144 | def ip_input(self, a: TensorList, b: TensorList):
145 | # return a.reshape(-1) @ b.reshape(-1)
146 | # return (a * b).sum()
147 | return operation.conv2d(a, b).view(-1)
148 |
--------------------------------------------------------------------------------
/pysot/tracker/siamrpnlt_tracker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 | # Modified by Jinghao Zhou
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | from __future__ import unicode_literals
8 |
9 | import numpy as np
10 | import torch
11 | from PIL import Image
12 |
13 | from pysot.core.config import cfg
14 | from pysot.tracker.classifier.libs.plotting import show_tensor
15 | from pysot.tracker.siamrpn_tracker import SiamRPNTracker
16 |
17 | class SiamRPNLTTracker(SiamRPNTracker):
18 | def __init__(self, model):
19 | super(SiamRPNLTTracker, self).__init__(model)
20 | self.longterm_state = False
21 |
22 | def track(self, img):
23 | w_z = self.size[0] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size)
24 | h_z = self.size[1] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size)
25 | s_z = np.sqrt(w_z * h_z)
26 | scale_z = cfg.TRACK.EXEMPLAR_SIZE / s_z
27 |
28 | if self.longterm_state:
29 | instance_size = cfg.TRACK.LOST_INSTANCE_SIZE
30 | else:
31 | instance_size = cfg.TRACK.INSTANCE_SIZE
32 |
33 | score_size = (instance_size - cfg.TRACK.EXEMPLAR_SIZE) // \
34 | cfg.ANCHOR.STRIDE + 1 + cfg.TRACK.BASE_SIZE
35 | hanning = np.hanning(score_size)
36 | window = np.outer(hanning, hanning)
37 | window = np.tile(window.flatten(), self.anchor_num)
38 | anchors = self.generate_anchor(score_size)
39 |
40 | s_x = s_z * (instance_size / cfg.TRACK.EXEMPLAR_SIZE)
41 | x_crop = self.get_subwindow(img, self.center_pos, instance_size,
42 | round(s_x), self.channel_average)
43 | with torch.no_grad():
44 | outputs = self.model.track(x_crop)
45 |
46 | score = self._convert_score(outputs['cls'])
47 | pred_bbox = self._convert_bbox(outputs['loc'], anchors)
48 |
49 | def change(r):
50 | return np.maximum(r, 1. / r)
51 |
52 | def sz(w, h):
53 | pad = (w + h) * 0.5
54 | return np.sqrt((w + pad) * (h + pad))
55 |
56 | s_c = change(sz(pred_bbox[2, :], pred_bbox[3, :]) /
57 | (sz(self.size[0] * scale_z, self.size[1] * scale_z)))
58 | r_c = change((self.size[0] / self.size[1]) /
59 | (pred_bbox[2, :] / pred_bbox[3, :]))
60 | penalty = np.exp(-(r_c * s_c - 1) * cfg.TRACK.PENALTY_K)
61 | pscore = penalty * score
62 |
63 | def normalize(score):
64 | score = (score - np.min(score)) / (np.max(score) - np.min(score))
65 | return score
66 |
67 | if cfg.TRACK.USE_CLASSIFIER:
68 |
69 | flag, s = self.classifier.track()
70 | confidence = Image.fromarray(s.detach().cpu().numpy())
71 | confidence = np.array(confidence.resize((score_size, score_size))).flatten()
72 | pscore = pscore.reshape(5, -1) * (1 - cfg.TRACK.COEE_CLASS) + \
73 | normalize(confidence) * cfg.TRACK.COEE_CLASS
74 | pscore = pscore.flatten()
75 |
76 | if not self.longterm_state:
77 | pscore = pscore * (1 - cfg.TRACK.WINDOW_INFLUENCE) + \
78 | window * cfg.TRACK.WINDOW_INFLUENCE
79 | else:
80 | pscore = pscore * (1 - 0.001) + window * 0.001
81 |
82 | best_idx = np.argmax(pscore)
83 | bbox = pred_bbox[:, best_idx] / scale_z
84 | lr = penalty[best_idx] * score[best_idx] * cfg.TRACK.LR
85 | best_score = score[best_idx]
86 | if best_score >= cfg.TRACK.CONFIDENCE_LOW:
87 | cx = bbox[0] + self.center_pos[0]
88 | cy = bbox[1] + self.center_pos[1]
89 | width = self.size[0] * (1 - lr) + bbox[2] * lr
90 | height = self.size[1] * (1 - lr) + bbox[3] * lr
91 | else:
92 | cx = self.center_pos[0]
93 | cy = self.center_pos[1]
94 | width = self.size[0]
95 | height = self.size[1]
96 |
97 | self.center_pos = np.array([cx, cy])
98 | self.size = np.array([width, height])
99 | cx, cy, width, height = self._bbox_clip(cx, cy, width, height, img.shape[:2])
100 | bbox = [cx - width / 2, cy - height / 2, width, height]
101 |
102 | if not self.longterm_state:
103 | if cfg.TRACK.USE_CLASSIFIER:
104 | self.classifier.update(bbox, scale_z, flag)
105 |
106 | if best_score < cfg.TRACK.CONFIDENCE_LOW:
107 | self.longterm_state = True
108 | elif best_score > cfg.TRACK.CONFIDENCE_HIGH:
109 | self.longterm_state = False
110 |
111 | if cfg.TRACK.USE_CLASSIFIER:
112 | return {
113 | 'bbox': bbox,
114 | 'best_score': best_score,
115 | 'flag': flag
116 | }
117 | else:
118 | return {
119 | 'bbox': bbox,
120 | 'best_score': best_score
121 | }
122 |
123 |
--------------------------------------------------------------------------------
/pysot/tracker/tracker_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 | # Modified by Jinghao Zhou
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | from __future__ import unicode_literals
8 |
9 | from pysot.core.config import cfg
10 | from pysot.tracker.siamfc_tracker import SiamFCTracker
11 | from pysot.tracker.siamrpn_tracker import SiamRPNTracker
12 | from pysot.tracker.siamrpnlt_tracker import SiamRPNLTTracker
13 | from pysot.tracker.siammask_tracker import SiamMaskTracker
14 |
15 | TRACKS = {
16 | 'SiamFCTracker': SiamFCTracker,
17 | 'SiamRPNTracker': SiamRPNTracker,
18 | 'SiamMaskTracker': SiamMaskTracker,
19 | 'SiamRPNLTTracker': SiamRPNLTTracker
20 | }
21 |
22 |
23 | def build_tracker(model):
24 | return TRACKS[cfg.TRACK.TYPE](model)
25 |
--------------------------------------------------------------------------------
/pysot/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/pysot/utils/__init__.py
--------------------------------------------------------------------------------
/pysot/utils/anchor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import math
9 |
10 | import numpy as np
11 |
12 | from pysot.utils.bbox import corner2center, center2corner
13 |
14 |
15 | class Anchors:
16 | """
17 | This class generates anchors.
18 | """
19 | def __init__(self, stride, ratios, scales, image_center=0, size=0):
20 | self.stride = stride
21 | self.ratios = ratios
22 | self.scales = scales
23 | self.image_center = 0
24 | self.size = 0
25 |
26 | self.anchor_num = len(self.scales) * len(self.ratios)
27 |
28 | self.anchors = None
29 |
30 | self.generate_anchors()
31 |
32 | def generate_anchors(self):
33 | """
34 | generate anchors based on predefined configuration
35 | """
36 | self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)
37 | size = self.stride * self.stride
38 | count = 0
39 | for r in self.ratios:
40 | ws = int(math.sqrt(size*1. / r))
41 | hs = int(ws * r)
42 |
43 | for s in self.scales:
44 | w = ws * s
45 | h = hs * s
46 | self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:]
47 | count += 1
48 |
49 | def generate_all_anchors(self, im_c, size):
50 | """
51 | im_c: image center
52 | size: image size
53 | """
54 | if self.image_center == im_c and self.size == size:
55 | return False
56 | self.image_center = im_c
57 | self.size = size
58 |
59 | a0x = im_c - size // 2 * self.stride
60 | ori = np.array([a0x] * 4, dtype=np.float32)
61 | zero_anchors = self.anchors + ori
62 |
63 | x1 = zero_anchors[:, 0]
64 | y1 = zero_anchors[:, 1]
65 | x2 = zero_anchors[:, 2]
66 | y2 = zero_anchors[:, 3]
67 |
68 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1),
69 | [x1, y1, x2, y2])
70 | cx, cy, w, h = corner2center([x1, y1, x2, y2])
71 |
72 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride
73 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride
74 |
75 | cx = cx + disp_x
76 | cy = cy + disp_y
77 |
78 | # broadcast
79 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32)
80 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h])
81 | x1, y1, x2, y2 = center2corner([cx, cy, w, h])
82 |
83 | self.all_anchors = (np.stack([x1, y1, x2, y2]).astype(np.float32),
84 | np.stack([cx, cy, w, h]).astype(np.float32))
85 | return True
86 |
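
A usage sketch that generates per-location anchors for a 25x25 score map with stride 8; the ratios and scales mirror the common SiamRPN settings and are illustrative:

from pysot.utils.anchor import Anchors

anchors = Anchors(stride=8, ratios=[0.33, 0.5, 1, 2, 3], scales=[8])
anchors.generate_all_anchors(im_c=255 // 2, size=25)
corner, center = anchors.all_anchors
print(corner.shape, center.shape)  # both (4, 5, 25, 25)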
--------------------------------------------------------------------------------
/pysot/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 |
9 | class Meter(object):
10 | def __init__(self, name, val, avg):
11 | self.name = name
12 | self.val = val
13 | self.avg = avg
14 |
15 | def __repr__(self):
16 | return "{name}: {val:.6f} ({avg:.6f})".format(
17 | name=self.name, val=self.val, avg=self.avg
18 | )
19 |
20 | def __format__(self, *tuples, **kwargs):
21 | return self.__repr__()
22 |
23 |
24 | class AverageMeter:
25 | """Computes and stores the average and current value"""
26 | def __init__(self, num=100):
27 | self.num = num
28 | self.reset()
29 |
30 | def reset(self):
31 | self.val = {}
32 | self.sum = {}
33 | self.count = {}
34 | self.history = {}
35 |
36 | def update(self, batch=1, **kwargs):
37 | val = {}
38 | for k in kwargs:
39 | val[k] = kwargs[k] / float(batch)
40 | self.val.update(val)
41 | for k in kwargs:
42 | if k not in self.sum:
43 | self.sum[k] = 0
44 | self.count[k] = 0
45 | self.history[k] = []
46 | self.sum[k] += kwargs[k]
47 | self.count[k] += batch
48 | for _ in range(batch):
49 | self.history[k].append(val[k])
50 |
51 | if self.num <= 0:
52 | # < 0, average all
53 | self.history[k] = []
54 |
55 | # == 0: no average
56 | if self.num == 0:
57 | self.sum[k] = self.val[k]
58 | self.count[k] = 1
59 |
60 | elif len(self.history[k]) > self.num:
61 | pop_num = len(self.history[k]) - self.num
62 | for _ in range(pop_num):
63 | self.sum[k] -= self.history[k][0]
64 | del self.history[k][0]
65 | self.count[k] -= 1
66 |
67 | def __repr__(self):
68 | s = ''
69 | for k in self.sum:
70 | s += self.format_str(k)
71 | return s
72 |
73 | def format_str(self, attr):
74 | return "{name}: {val:.6f} ({avg:.6f}) ".format(
75 | name=attr,
76 | val=float(self.val[attr]),
77 | avg=float(self.sum[attr]) / self.count[attr])
78 |
79 | def __getattr__(self, attr):
80 | if attr in self.__dict__:
81 | return super(AverageMeter, self).__getattr__(attr)
82 | if attr not in self.sum:
83 | print("invalid key '{}'".format(attr))
84 | return Meter(attr, 0, 0)
85 | return Meter(attr, self.val[attr], self.avg(attr))
86 |
87 | def avg(self, attr):
88 | return float(self.sum[attr]) / self.count[attr]
89 |
90 |
91 | if __name__ == '__main__':
92 | avg1 = AverageMeter(10)
93 | avg2 = AverageMeter(0)
94 | avg3 = AverageMeter(-1)
95 |
96 | for i in range(20):
97 | avg1.update(s=i)
98 | avg2.update(s=i)
99 | avg3.update(s=i)
100 |
101 | print('iter {}'.format(i))
102 | print(avg1.s)
103 | print(avg2.s)
104 | print(avg3.s)
105 |
--------------------------------------------------------------------------------
/pysot/utils/bbox.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | from collections import namedtuple
9 |
10 | import numpy as np
11 |
12 |
13 | Corner = namedtuple('Corner', 'x1 y1 x2 y2')
14 | # alias
15 | BBox = Corner
16 | Center = namedtuple('Center', 'x y w h')
17 |
18 |
19 | def corner2center(corner):
20 | """ convert (x1, y1, x2, y2) to (cx, cy, w, h)
21 | Args:
22 |         corner: Corner or np.array (4*N)
23 | Return:
24 | Center or np.array (4 * N)
25 | """
26 | if isinstance(corner, Corner):
27 | x1, y1, x2, y2 = corner
28 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1))
29 | else:
30 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
31 | x = (x1 + x2) * 0.5
32 | y = (y1 + y2) * 0.5
33 | w = x2 - x1
34 | h = y2 - y1
35 | return x, y, w, h
36 |
37 |
38 | def center2corner(center):
39 | """ convert (cx, cy, w, h) to (x1, y1, x2, y2)
40 | Args:
41 | center: Center or np.array (4 * N)
42 | Return:
43 | center or np.array (4 * N)
44 | """
45 | if isinstance(center, Center):
46 | x, y, w, h = center
47 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5)
48 | else:
49 | x, y, w, h = center[0], center[1], center[2], center[3]
50 | x1 = x - w * 0.5
51 | y1 = y - h * 0.5
52 | x2 = x + w * 0.5
53 | y2 = y + h * 0.5
54 | return x1, y1, x2, y2
55 |
56 |
57 | def IoU(rect1, rect2):
58 |     """ calculate intersection over union
59 | Args:
60 | rect1: (x1, y1, x2, y2)
61 | rect2: (x1, y1, x2, y2)
62 | Returns:
63 | iou
64 | """
65 | # overlap
66 | x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3]
67 | tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3]
68 |
69 | xx1 = np.maximum(tx1, x1)
70 | yy1 = np.maximum(ty1, y1)
71 | xx2 = np.minimum(tx2, x2)
72 | yy2 = np.minimum(ty2, y2)
73 |
74 | ww = np.maximum(0, xx2 - xx1)
75 | hh = np.maximum(0, yy2 - yy1)
76 |
77 | area = (x2-x1) * (y2-y1)
78 | target_a = (tx2-tx1) * (ty2 - ty1)
79 | inter = ww * hh
80 | iou = inter / (area + target_a - inter)
81 | return iou
82 |
83 |
84 | def cxy_wh_2_rect(pos, sz):
85 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 0-index
86 | """
87 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]])
88 |
89 |
90 | def rect_2_cxy_wh(rect):
91 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 0-index
92 | """
93 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), \
94 | np.array([rect[2], rect[3]])
95 |
96 |
97 | def cxy_wh_2_rect1(pos, sz):
98 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 1-index
99 | """
100 | return np.array([pos[0]-sz[0]/2+1, pos[1]-sz[1]/2+1, sz[0], sz[1]])
101 |
102 |
103 | def rect1_2_cxy_wh(rect):
104 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 1-index
105 | """
106 | return np.array([rect[0]+rect[2]/2-1, rect[1]+rect[3]/2-1]), \
107 | np.array([rect[2], rect[3]])
108 |
109 |
110 | def get_axis_aligned_bbox(region):
111 |     """ convert region to (cx, cy, w, h) represented by an axis-aligned box
112 | """
113 | nv = region.size
114 | if nv == 8:
115 | cx = np.mean(region[0::2])
116 | cy = np.mean(region[1::2])
117 | x1 = min(region[0::2])
118 | x2 = max(region[0::2])
119 | y1 = min(region[1::2])
120 | y2 = max(region[1::2])
121 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * \
122 | np.linalg.norm(region[2:4] - region[4:6])
123 | A2 = (x2 - x1) * (y2 - y1)
124 | s = np.sqrt(A1 / A2)
125 | w = s * (x2 - x1) + 1
126 | h = s * (y2 - y1) + 1
127 | else:
128 | x = region[0]
129 | y = region[1]
130 | w = region[2]
131 | h = region[3]
132 | cx = x+w/2
133 | cy = y+h/2
134 | return cx, cy, w, h
135 |
136 |
137 | def get_min_max_bbox(region):
138 |     """ convert region to (cx, cy, w, h) represented by a min-max box
139 | """
140 | nv = region.size
141 | if nv == 8:
142 | cx = np.mean(region[0::2])
143 | cy = np.mean(region[1::2])
144 | x1 = min(region[0::2])
145 | x2 = max(region[0::2])
146 | y1 = min(region[1::2])
147 | y2 = max(region[1::2])
148 | w = x2 - x1
149 | h = y2 - y1
150 | else:
151 | x = region[0]
152 | y = region[1]
153 | w = region[2]
154 | h = region[3]
155 | cx = x+w/2
156 | cy = y+h/2
157 | return cx, cy, w, h
158 |
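A quick sketch of the conversion helpers and IoU above; the numbers are arbitrary examples. The same functions accept either the namedtuples or 4xN numpy arrays.

import numpy as np
from pysot.utils.bbox import Corner, Center, corner2center, center2corner, IoU

c = corner2center(Corner(10, 20, 110, 220))   # Center(x=60.0, y=120.0, w=100, h=200)
box = center2corner(c)                        # back to Corner(10.0, 20.0, 110.0, 220.0)

# the array form broadcasts over 4xN inputs (here N=2 boxes)
boxes = np.array([[0., 10.], [0., 10.], [50., 60.], [50., 60.]])
cx, cy, w, h = corner2center(boxes)

print(IoU([0, 0, 10, 10], [5, 5, 15, 15]))    # ~0.143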
--------------------------------------------------------------------------------
/pysot/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import os
9 | import socket
10 | import logging
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.distributed as dist
15 |
16 | from pysot.utils.log_helper import log_once
17 |
18 | logger = logging.getLogger('global')
19 |
20 |
21 | def average_reduce(v):
22 | if get_world_size() == 1:
23 | return v
24 | tensor = torch.cuda.FloatTensor(1)
25 | tensor[0] = v
26 | dist.all_reduce(tensor)
27 | v = tensor[0] / get_world_size()
28 | return v
29 |
30 |
31 | class DistModule(nn.Module):
32 | def __init__(self, module, bn_method=0):
33 | super(DistModule, self).__init__()
34 | self.module = module
35 | self.bn_method = bn_method
36 | if get_world_size() > 1:
37 | broadcast_params(self.module)
38 | else:
39 |             self.bn_method = 0  # single process
40 |
41 | def forward(self, *args, **kwargs):
42 | broadcast_buffers(self.module, self.bn_method)
43 | return self.module(*args, **kwargs)
44 |
45 | def train(self, mode=True):
46 | super(DistModule, self).train(mode)
47 | self.module.train(mode)
48 | return self
49 |
50 |
51 | def broadcast_params(model):
52 | """ broadcast model parameters """
53 | for p in model.state_dict().values():
54 | dist.broadcast(p, 0)
55 |
56 |
57 | def broadcast_buffers(model, method=0):
58 | """ broadcast model buffers """
59 | if method == 0:
60 | return
61 |
62 | world_size = get_world_size()
63 |
64 | for b in model._all_buffers():
65 |         if method == 1:  # broadcast from the main process
66 | dist.broadcast(b, 0)
67 | elif method == 2: # average
68 | dist.all_reduce(b)
69 | b /= world_size
70 | else:
71 | raise Exception('Invalid buffer broadcast code {}'.format(method))
72 |
73 |
74 | inited = False
75 |
76 |
77 | def _dist_init():
78 |     '''
79 |     Initialize the default process group from environment variables:
80 |         RANK gives this process's rank;
81 |         the world size is read back from the initialized process group.
82 |     '''
83 | rank = int(os.environ['RANK'])
84 | num_gpus = torch.cuda.device_count()
85 | torch.cuda.set_device(rank % num_gpus)
86 | dist.init_process_group(backend='nccl')
87 | world_size = dist.get_world_size()
88 | return rank, world_size
89 |
90 |
91 | def _get_local_ip():
92 | try:
93 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
94 | s.connect(('8.8.8.8', 80))
95 | ip = s.getsockname()[0]
96 | finally:
97 | s.close()
98 | return ip
99 |
100 |
101 | def dist_init():
102 | global rank, world_size, inited
103 | try:
104 | rank, world_size = _dist_init()
105 | except RuntimeError as e:
106 | if 'public' in e.args[0]:
107 | logger.info(e)
108 | logger.info('Warning: use single process')
109 | rank, world_size = 0, 1
110 | else:
111 | raise RuntimeError(*e.args)
112 | inited = True
113 | return rank, world_size
114 |
115 |
116 | def get_rank():
117 | if not inited:
118 | raise(Exception('dist not inited'))
119 | return rank
120 |
121 |
122 | def get_world_size():
123 | if not inited:
124 | raise(Exception('dist not inited'))
125 | return world_size
126 |
127 |
128 | def reduce_gradients(model, _type='sum'):
129 | types = ['sum', 'avg']
130 | assert _type in types, 'gradients method must be in "{}"'.format(types)
131 | log_once("gradients method is {}".format(_type))
132 | if get_world_size() > 1:
133 | for param in model.parameters():
134 | if param.requires_grad:
135 | dist.all_reduce(param.grad.data)
136 | if _type == 'avg':
137 | param.grad.data /= get_world_size()
138 | else:
139 | return None
140 |
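A hedged sketch of how these helpers might be wired into a training script. `build_model` and `loader` are placeholders, and the launcher is assumed to set RANK plus the usual NCCL environment variables.

import torch
from pysot.utils.distributed import dist_init, DistModule, reduce_gradients

rank, world_size = dist_init()                   # RANK must be set by the launcher
model = DistModule(build_model().cuda())         # broadcasts parameters when world_size > 1
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for data in loader:                              # `loader` is assumed to exist
    loss = model(data)
    optimizer.zero_grad()
    loss.backward()
    reduce_gradients(model, _type='avg')         # all-reduce gradients across processes
    optimizer.step()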
--------------------------------------------------------------------------------
/pysot/utils/log_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import os
9 | import logging
10 | import math
11 | import sys
12 |
13 |
14 | if hasattr(sys, 'frozen'): # support for py2exe
15 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:])
16 | elif __file__[-4:].lower() in ['.pyc', '.pyo']:
17 | _srcfile = __file__[:-4] + '.py'
18 | else:
19 | _srcfile = __file__
20 | _srcfile = os.path.normcase(_srcfile)
21 |
22 | logs = set()
23 |
24 |
25 | class Filter:
26 | def __init__(self, flag):
27 | self.flag = flag
28 |
29 | def filter(self, x):
30 | return self.flag
31 |
32 |
33 | class Dummy:
34 | def __init__(self, *arg, **kwargs):
35 | pass
36 |
37 | def __getattr__(self, arg):
38 | def dummy(*args, **kwargs): pass
39 | return dummy
40 |
41 |
42 | def get_format(logger, level):
43 | if 'RANK' in os.environ:
44 | rank = int(os.environ['RANK'])
45 |
46 | if level == logging.INFO:
47 | logger.addFilter(Filter(rank == 0))
48 | else:
49 | rank = 0
50 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank)
51 | formatter = logging.Formatter(format_str)
52 | return formatter
53 |
54 |
55 | def get_format_custom(logger, level):
56 | if 'RANK' in os.environ:
57 | rank = int(os.environ['RANK'])
58 | if level == logging.INFO:
59 | logger.addFilter(Filter(rank == 0))
60 | else:
61 | rank = 0
62 | format_str = '[%(asctime)s-rk{}-%(message)s'.format(rank)
63 | formatter = logging.Formatter(format_str)
64 | return formatter
65 |
66 |
67 | def init_log(name, level=logging.INFO, format_func=get_format):
68 | if (name, level) in logs:
69 | return
70 | logs.add((name, level))
71 | logger = logging.getLogger(name)
72 | logger.setLevel(level)
73 | ch = logging.StreamHandler()
74 | ch.setLevel(level)
75 | formatter = format_func(logger, level)
76 | ch.setFormatter(formatter)
77 | logger.addHandler(ch)
78 | return logger
79 |
80 |
81 | def add_file_handler(name, log_file, level=logging.INFO):
82 | logger = logging.getLogger(name)
83 | fh = logging.FileHandler(log_file)
84 | fh.setFormatter(get_format(logger, level))
85 | logger.addHandler(fh)
86 |
87 |
88 | init_log('global')
89 |
90 |
91 | def print_speed(i, i_time, n):
92 | """print_speed(index, index_time, total_iteration)"""
93 | logger = logging.getLogger('global')
94 | average_time = i_time
95 | remaining_time = (n - i) * average_time
96 | remaining_day = math.floor(remaining_time / 86400)
97 | remaining_hour = math.floor(remaining_time / 3600 -
98 | remaining_day * 24)
99 | remaining_min = math.floor(remaining_time / 60 -
100 | remaining_day * 1440 -
101 | remaining_hour * 60)
102 | logger.info('Progress: %d / %d [%d%%], Speed: %.3f s/iter, ETA %d:%02d:%02d (D:H:M)\n' %
103 | (i, n, i / n * 100,
104 | average_time,
105 | remaining_day, remaining_hour, remaining_min))
106 |
107 |
108 | def find_caller():
109 | def current_frame():
110 | try:
111 | raise Exception
112 | except:
113 | return sys.exc_info()[2].tb_frame.f_back
114 |
115 | f = current_frame()
116 | if f is not None:
117 | f = f.f_back
118 | rv = "(unknown file)", 0, "(unknown function)"
119 | while hasattr(f, "f_code"):
120 | co = f.f_code
121 | filename = os.path.normcase(co.co_filename)
122 | rv = (co.co_filename, f.f_lineno, co.co_name)
123 | if filename == _srcfile:
124 | f = f.f_back
125 | continue
126 | break
127 | rv = list(rv)
128 | rv[0] = os.path.basename(rv[0])
129 | return rv
130 |
131 |
132 | class LogOnce:
133 | def __init__(self):
134 | self.logged = set()
135 | self.logger = init_log('log_once', format_func=get_format_custom)
136 |
137 | def log(self, strings):
138 | fn, lineno, caller = find_caller()
139 | key = (fn, lineno, caller, strings)
140 | if key in self.logged:
141 | return
142 | self.logged.add(key)
143 | message = "{filename:s}<{caller}>#{lineno:3d}] {strings}".format(
144 | filename=fn, lineno=lineno, strings=strings, caller=caller)
145 | self.logger.info(message)
146 |
147 |
148 | once_logger = LogOnce()
149 |
150 |
151 | def log_once(strings):
152 | once_logger.log(strings)
153 |
154 |
155 | def main():
156 | for i, lvl in enumerate([logging.DEBUG, logging.INFO,
157 | logging.WARNING, logging.ERROR,
158 | logging.CRITICAL]):
159 | log_name = str(lvl)
160 | init_log(log_name, lvl)
161 | logger = logging.getLogger(log_name)
162 | print('****cur lvl:{}'.format(lvl))
163 | logger.debug('debug')
164 | logger.info('info')
165 | logger.warning('warning')
166 | logger.error('error')
167 |         logger.critical('critical')
168 |
169 |
170 | if __name__ == '__main__':
171 | main()
172 | for i in range(10):
173 | log_once('xxx')
174 |
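A small usage sketch for the logging helpers above ('train.log' is just an example path):

import logging
from pysot.utils.log_helper import init_log, add_file_handler, print_speed, log_once

init_log('global', logging.INFO)              # rank-aware console handler (idempotent)
add_file_handler('global', 'train.log', logging.INFO)

logger = logging.getLogger('global')
logger.info('start training')
log_once('printed only once per call site')
print_speed(10, 0.5, 1000)                    # progress %, s/iter and a D:H:M ETA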
--------------------------------------------------------------------------------
/pysot/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import os
9 |
10 | from colorama import Fore, Style
11 |
12 |
13 | __all__ = ['commit', 'describe']
14 |
15 |
16 | def _exec(cmd):
17 | f = os.popen(cmd, 'r', 1)
18 | return f.read().strip()
19 |
20 |
21 | def _bold(s):
22 | return "\033[1m%s\033[0m" % s
23 |
24 |
25 | def _color(s):
26 | return f'{Fore.RED}{s}{Style.RESET_ALL}'
27 |
28 |
29 | def _describe(model, lines=None, spaces=0):
30 | head = " " * spaces
31 | for name, p in model.named_parameters():
32 | if '.' in name:
33 | continue
34 | if p.requires_grad:
35 | name = _color(name)
36 | line = "{head}- {name}".format(head=head, name=name)
37 | lines.append(line)
38 |
39 | for name, m in model.named_children():
40 | space_num = len(name) + spaces + 1
41 | if m.training:
42 | name = _color(name)
43 | line = "{head}.{name} ({type})".format(
44 | head=head,
45 | name=name,
46 | type=m.__class__.__name__)
47 | lines.append(line)
48 | _describe(m, lines, space_num)
49 |
50 |
51 | def commit():
52 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
53 | cmd = "cd {}; git log | head -n1 | awk '{{print $2}}'".format(root)
54 | commit = _exec(cmd)
55 | cmd = "cd {}; git log --oneline | head -n1".format(root)
56 | commit_log = _exec(cmd)
57 | return "commit : {}\n log : {}".format(commit, commit_log)
58 |
59 |
60 | def describe(net, name=None):
61 | num = 0
62 | lines = []
63 | if name is not None:
64 | lines.append(name)
65 | num = len(name)
66 | _describe(net, lines, num)
67 | return "\n".join(lines)
68 |
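A sketch of how commit() and describe() above might be used at startup; `model` is a hypothetical torch.nn.Module.

import logging
from pysot.utils.misc import commit, describe

logger = logging.getLogger('global')
logger.info(commit())                         # current git commit hash and one-line log
logger.info(describe(model, name='model'))    # module tree; trainable parts shown in red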
--------------------------------------------------------------------------------
/pysot/utils/model_load.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) SenseTime. All Rights Reserved.
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | from __future__ import unicode_literals
7 |
8 | import logging
9 |
10 | import torch
11 | import torch._utils
12 | try:
13 | torch._utils._rebuild_tensor_v2
14 | except AttributeError:
15 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
16 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
17 | tensor.requires_grad = requires_grad
18 | tensor._backward_hooks = backward_hooks
19 | return tensor
20 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
21 |
22 | logger = logging.getLogger('global')
23 |
24 |
25 | def check_keys(model, pretrained_state_dict):
26 | ckpt_keys = set(pretrained_state_dict.keys())
27 | model_keys = set(model.state_dict().keys())
28 | used_pretrained_keys = model_keys & ckpt_keys
29 | unused_pretrained_keys = ckpt_keys - model_keys
30 | missing_keys = model_keys - ckpt_keys
31 | # filter 'num_batches_tracked'
32 | missing_keys = [x for x in missing_keys
33 | if not x.endswith('num_batches_tracked')]
34 | if len(missing_keys) > 0:
35 | logger.info('[Warning] missing keys: {}'.format(missing_keys))
36 | logger.info('missing keys:{}'.format(len(missing_keys)))
37 | if len(unused_pretrained_keys) > 0:
38 | logger.info('[Warning] unused_pretrained_keys: {}'.format(
39 | unused_pretrained_keys))
40 | logger.info('unused checkpoint keys:{}'.format(
41 | len(unused_pretrained_keys)))
42 | logger.info('used keys:{}'.format(len(used_pretrained_keys)))
43 | assert len(used_pretrained_keys) > 0, \
44 | 'load NONE from pretrained checkpoint'
45 | return True
46 |
47 |
48 | def remove_prefix(state_dict, prefix):
49 |     ''' Old-style models are stored with all parameter names
50 |     sharing the common prefix 'module.' '''
51 | logger.info('remove prefix \'{}\''.format(prefix))
52 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
53 | return {f(key): value for key, value in state_dict.items()}
54 |
55 |
56 | def load_pretrain(model, pretrained_path):
57 | logger.info('load pretrained model from {}'.format(pretrained_path))
58 | device = torch.cuda.current_device()
59 | pretrained_dict = torch.load(pretrained_path,
60 | map_location=lambda storage, loc: storage.cuda(device))
61 | if "state_dict" in pretrained_dict.keys():
62 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'],
63 | 'module.')
64 | else:
65 | pretrained_dict = remove_prefix(pretrained_dict, 'module.')
66 |
67 | try:
68 | check_keys(model, pretrained_dict)
69 |     except Exception:
70 |         logger.info('[Warning]: using pretrained weights as "features", '
71 |                     'adding the "features." prefix to the keys')
72 | new_dict = {}
73 | for k, v in pretrained_dict.items():
74 | k = 'features.' + k
75 | new_dict[k] = v
76 | pretrained_dict = new_dict
77 | check_keys(model, pretrained_dict)
78 | model.load_state_dict(pretrained_dict, strict=False)
79 | return model
80 |
81 |
82 | def restore_from(model, optimizer, ckpt_path):
83 | device = torch.cuda.current_device()
84 | ckpt = torch.load(ckpt_path,
85 | map_location=lambda storage, loc: storage.cuda(device))
86 | epoch = ckpt['epoch']
87 |
88 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.')
89 | check_keys(model, ckpt_model_dict)
90 | model.load_state_dict(ckpt_model_dict, strict=False)
91 |
92 | check_keys(optimizer, ckpt['optimizer'])
93 | optimizer.load_state_dict(ckpt['optimizer'])
94 | return model, optimizer, epoch
95 |
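A minimal sketch of the two entry points above; the model, optimizer and paths are placeholders.

from pysot.utils.model_load import load_pretrain, restore_from

# load backbone/tracker weights (non-strict, 'module.' prefix stripped automatically)
model = load_pretrain(model, 'pretrained_models/model.pth')

# resume interrupted training: weights, optimizer state and the epoch counter
model, optimizer, start_epoch = restore_from(model, optimizer, 'snapshot/checkpoint_e10.pth')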
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | yacs
3 | tqdm
4 | pyyaml
5 | matplotlib
6 | colorama
7 | cython
8 | tensorboardX
9 | filterpy
10 | pillow
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from distutils.extension import Extension
3 | from Cython.Build import cythonize
4 |
5 | ext_modules = [
6 | Extension(
7 | name='toolkit.utils.region',
8 | sources=[
9 | 'toolkit/utils/region.pyx',
10 | 'toolkit/utils/src/region.c',
11 | ],
12 | include_dirs=[
13 | 'toolkit/utils/src'
14 | ]
15 | )
16 | ]
17 |
18 | setup(
19 | name='toolkit',
20 | packages=['toolkit'],
21 | ext_modules=cythonize(ext_modules)
22 | )
23 |
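The extension above compiles the Cython region-overlap code under toolkit/utils into the toolkit.utils.region module used by the evaluation benchmarks; it is normally built in place with the standard Cython workflow, e.g. `python setup.py build_ext --inplace` (presumably what install.sh invokes).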
--------------------------------------------------------------------------------
/testing_dataset/README.md:
--------------------------------------------------------------------------------
1 | # Testing dataset directory
2 | Put your testing datasets here.
3 | - [x] [VOT2016](http://www.votchallenge.net/vot2016/dataset.html)
4 | - [x] [VOT2018](http://www.votchallenge.net/vot2018/dataset.html)
5 | - [x] [VOT2018-LT](http://www.votchallenge.net/vot2018/dataset.html)
6 | - [x] [OTB100(OTB2015)](http://cvlab.hanyang.ac.kr/tracker_benchmark/datasets.html)
7 | - [x] [UAV123](https://ivul.kaust.edu.sa/Pages/Dataset-UAV123.aspx)
8 | - [x] [NFS](http://ci2cv.net/nfs/index.html)
9 | - [x] [LaSOT](https://cis.temple.edu/lasot/)
10 | - [ ] [TrackingNet (Evaluation on Server)](https://tracking-net.org)
11 | - [ ] [GOT-10k (Evaluation on Server)](http://got-10k.aitestunion.com)
12 |
13 | ## Download Dataset
14 | Download the json files used by our toolkit from [baidu pan](https://pan.baidu.com/s/1js0Qhykqqur7_lNRtle1tA).
15 |
16 | 1. Put CVPR13.json, OTB100.json and OTB50.json in the OTB100 dataset directory (you need to copy Jogging to Jogging-1 and Jogging-2, and Skating2 to Skating2-1 and Skating2-2, or use symlinks)
17 |
18 | The directory should have the following structure:
19 |
20 | | -- OTB100/
21 |
22 | | -- Basketball
23 |
24 | | ......
25 |
26 | | -- Woman
27 |
28 | | -- OTB100.json
29 |
30 | | -- OTB50.json
31 |
32 | | -- CVPR13.json
33 |
34 | 2. Put all other json files in the corresponding dataset directory, as in step 1.
35 |
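A small sanity-check sketch (paths are examples) that the layout above is in place; it relies on the 'video_dir' field that the toolkit dataset classes read from the json.

import json
import os

root = 'testing_dataset/OTB100'
with open(os.path.join(root, 'OTB100.json')) as f:
    meta = json.load(f)

missing = [v for v in meta
           if not os.path.isdir(os.path.join(root, meta[v]['video_dir']))]
print('videos in json: {}, missing on disk: {}'.format(len(meta), missing))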
--------------------------------------------------------------------------------
/toolkit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/toolkit/__init__.py
--------------------------------------------------------------------------------
/toolkit/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .vot import VOTDataset, VOTLTDataset
2 | from .otb import OTBDataset
3 | from .uav import UAVDataset
4 | from .lasot import LaSOTDataset
5 | from .nfs import NFSDataset
6 | from .trackingnet import TrackingNetDataset
7 | from .got10k import GOT10kDataset
8 | from .visdrone import VisDroneDataset
9 |
10 | class DatasetFactory(object):
11 | @staticmethod
12 | def create_dataset(**kwargs):
13 | """
14 | Args:
15 | name: dataset name 'OTB2015', 'LaSOT', 'UAV123', 'NFS240', 'NFS30',
16 | 'VOT2018', 'VOT2016', 'VOT2018-LT'
17 | dataset_root: dataset root
18 |             load_img: whether to load images
19 | Return:
20 | dataset
21 | """
22 | assert 'name' in kwargs, "should provide dataset name"
23 | name = kwargs['name']
24 | if 'OTB' in name:
25 | dataset = OTBDataset(**kwargs)
26 | elif 'LaSOT' == name:
27 | dataset = LaSOTDataset(**kwargs)
28 | elif 'UAV' in name:
29 | dataset = UAVDataset(**kwargs)
30 | elif 'NFS' in name:
31 | dataset = NFSDataset(**kwargs)
32 | elif 'VOT2018' == name or 'VOT2016' == name or 'VOT2019' == name:
33 | dataset = VOTDataset(**kwargs)
34 | elif 'VOT2018-LT' == name:
35 | dataset = VOTLTDataset(**kwargs)
36 | elif 'TrackingNet' == name:
37 | dataset = TrackingNetDataset(**kwargs)
38 | elif 'carplate' == name:
39 | dataset = TrackingNetDataset(**kwargs)
40 | elif 'GOT-10k' == name:
41 | dataset = GOT10kDataset(**kwargs)
42 | elif 'VisDrone' in name:
43 | dataset = VisDroneDataset(**kwargs)
44 | else:
45 |             raise Exception("unknown dataset {}".format(kwargs['name']))
46 | return dataset
47 |
48 |
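A hedged sketch of creating and iterating a dataset with the factory above (dataset_root is an example; the per-video iteration comes from toolkit/datasets/video.py):

from toolkit.datasets import DatasetFactory

dataset = DatasetFactory.create_dataset(name='OTB100',
                                        dataset_root='testing_dataset/OTB100',
                                        load_img=False)

for video in dataset:                        # videos are yielded in sorted name order
    for idx, (img, gt_bbox) in enumerate(video):
        if idx == 0:
            pass                             # initialize a tracker with gt_bbox here
        # else: run the tracker on img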
--------------------------------------------------------------------------------
/toolkit/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | class Dataset(object):
4 | def __init__(self, name, dataset_root):
5 | self.name = name
6 | self.dataset_root = dataset_root
7 | self.videos = None
8 |
9 | def __getitem__(self, idx):
10 | if isinstance(idx, str):
11 | return self.videos[idx]
12 | elif isinstance(idx, int):
13 | return self.videos[sorted(list(self.videos.keys()))[idx]]
14 |
15 | def __len__(self):
16 | return len(self.videos)
17 |
18 | def __iter__(self):
19 | keys = sorted(list(self.videos.keys()))
20 | for key in keys:
21 | yield self.videos[key]
22 |
23 | def set_tracker(self, path, tracker_names):
24 | """
25 | Args:
26 | path: path to tracker results,
27 | tracker_names: list of tracker name
28 | """
29 | self.tracker_path = path
30 | self.tracker_names = tracker_names
31 | # for video in tqdm(self.videos.values(),
32 | # desc='loading tacker result', ncols=100):
33 | # video.load_tracker(path, tracker_names)
34 |
--------------------------------------------------------------------------------
/toolkit/datasets/got10k.py:
--------------------------------------------------------------------------------
1 |
2 | import json
3 | import os
4 |
5 | from tqdm import tqdm
6 |
7 | from .dataset import Dataset
8 | from .video import Video
9 |
10 | class GOT10kVideo(Video):
11 | """
12 | Args:
13 | name: video name
14 | root: dataset root
15 | video_dir: video directory
16 | init_rect: init rectangle
17 | img_names: image names
18 | gt_rect: groundtruth rectangle
19 | attr: attribute of video
20 | """
21 | def __init__(self, name, root, video_dir, init_rect, img_names,
22 | gt_rect, attr, load_img=False):
23 | super(GOT10kVideo, self).__init__(name, root, video_dir,
24 | init_rect, img_names, gt_rect, attr, load_img)
25 |
26 | # def load_tracker(self, path, tracker_names=None):
27 | # """
28 | # Args:
29 | # path(str): path to result
30 | # tracker_name(list): name of tracker
31 | # """
32 | # if not tracker_names:
33 | # tracker_names = [x.split('/')[-1] for x in glob(path)
34 | # if os.path.isdir(x)]
35 | # if isinstance(tracker_names, str):
36 | # tracker_names = [tracker_names]
37 | # # self.pred_trajs = {}
38 | # for name in tracker_names:
39 | # traj_file = os.path.join(path, name, self.name+'.txt')
40 | # if os.path.exists(traj_file):
41 | # with open(traj_file, 'r') as f :
42 | # self.pred_trajs[name] = [list(map(float, x.strip().split(',')))
43 | # for x in f.readlines()]
44 | # if len(self.pred_trajs[name]) != len(self.gt_traj):
45 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name)
46 | # else:
47 |
48 | # self.tracker_names = list(self.pred_trajs.keys())
49 |
50 | class GOT10kDataset(Dataset):
51 | """
52 | Args:
53 |         name: dataset name, should be "GOT-10k"
54 |         dataset_root: dataset root dir
55 | """
56 | def __init__(self, name, dataset_root, load_img=False):
57 | super(GOT10kDataset, self).__init__(name, dataset_root)
58 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
59 | meta_data = json.load(f)
60 |
61 | # load videos
62 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
63 | self.videos = {}
64 | for video in pbar:
65 | pbar.set_postfix_str(video)
66 | self.videos[video] = GOT10kVideo(video,
67 | dataset_root,
68 | meta_data[video]['video_dir'],
69 | meta_data[video]['init_rect'],
70 | meta_data[video]['img_names'],
71 | meta_data[video]['gt_rect'],
72 | None)
73 | self.attr = {}
74 | self.attr['ALL'] = list(self.videos.keys())
75 |
--------------------------------------------------------------------------------
/toolkit/datasets/lasot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 | from tqdm import tqdm
6 | from glob import glob
7 |
8 | from .dataset import Dataset
9 | from .video import Video
10 |
11 | class LaSOTVideo(Video):
12 | """
13 | Args:
14 | name: video name
15 | root: dataset root
16 | video_dir: video directory
17 | init_rect: init rectangle
18 | img_names: image names
19 | gt_rect: groundtruth rectangle
20 | attr: attribute of video
21 | """
22 | def __init__(self, name, root, video_dir, init_rect, img_names,
23 | gt_rect, attr, absent, load_img=False):
24 | super(LaSOTVideo, self).__init__(name, root, video_dir,
25 | init_rect, img_names, gt_rect, attr, load_img)
26 | self.absent = np.array(absent, np.int8)
27 |
28 | def load_tracker(self, path, tracker_names=None, store=True):
29 | """
30 | Args:
31 | path(str): path to result
32 | tracker_name(list): name of tracker
33 | """
34 | if not tracker_names:
35 | tracker_names = [x.split('/')[-1] for x in glob(path)
36 | if os.path.isdir(x)]
37 | if isinstance(tracker_names, str):
38 | tracker_names = [tracker_names]
39 | for name in tracker_names:
40 | traj_file = os.path.join(path, name, self.name+'.txt')
41 | if os.path.exists(traj_file):
42 | with open(traj_file, 'r') as f :
43 | pred_traj = [list(map(float, x.strip().split(',')))
44 | for x in f.readlines()]
45 | else:
46 |                 print("File not found:", traj_file)
47 | if self.name == 'monkey-17':
48 | pred_traj = pred_traj[:len(self.gt_traj)]
49 | if store:
50 | self.pred_trajs[name] = pred_traj
51 | else:
52 | return pred_traj
53 | self.tracker_names = list(self.pred_trajs.keys())
54 |
55 |
56 |
57 | class LaSOTDataset(Dataset):
58 | """
59 | Args:
60 |         name: dataset name, should be 'LaSOT'
61 |         dataset_root: dataset root
62 |         load_img: whether to load all imgs
63 | """
64 | def __init__(self, name, dataset_root, load_img=False):
65 | super(LaSOTDataset, self).__init__(name, dataset_root)
66 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
67 | meta_data = json.load(f)
68 |
69 | # load videos
70 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
71 | self.videos = {}
72 | for video in pbar:
73 | pbar.set_postfix_str(video)
74 | self.videos[video] = LaSOTVideo(video,
75 | dataset_root,
76 | meta_data[video]['video_dir'],
77 | meta_data[video]['init_rect'],
78 | meta_data[video]['img_names'],
79 | meta_data[video]['gt_rect'],
80 | meta_data[video]['attr'],
81 | meta_data[video]['absent'])
82 |
83 | # set attr
84 | attr = []
85 | for x in self.videos.values():
86 | attr += x.attr
87 | attr = set(attr)
88 | self.attr = {}
89 | self.attr['ALL'] = list(self.videos.keys())
90 | for x in attr:
91 | self.attr[x] = []
92 | for k, v in self.videos.items():
93 | for attr_ in v.attr:
94 | self.attr[attr_].append(k)
95 |
96 |
97 |
--------------------------------------------------------------------------------
/toolkit/datasets/nfs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 |
5 | from tqdm import tqdm
6 | from glob import glob
7 |
8 | from .dataset import Dataset
9 | from .video import Video
10 |
11 |
12 | class NFSVideo(Video):
13 | """
14 | Args:
15 | name: video name
16 | root: dataset root
17 | video_dir: video directory
18 | init_rect: init rectangle
19 | img_names: image names
20 | gt_rect: groundtruth rectangle
21 | attr: attribute of video
22 | """
23 | def __init__(self, name, root, video_dir, init_rect, img_names,
24 | gt_rect, attr, load_img=False):
25 | super(NFSVideo, self).__init__(name, root, video_dir,
26 | init_rect, img_names, gt_rect, attr, load_img)
27 |
28 | # def load_tracker(self, path, tracker_names=None):
29 | # """
30 | # Args:
31 | # path(str): path to result
32 | # tracker_name(list): name of tracker
33 | # """
34 | # if not tracker_names:
35 | # tracker_names = [x.split('/')[-1] for x in glob(path)
36 | # if os.path.isdir(x)]
37 | # if isinstance(tracker_names, str):
38 | # tracker_names = [tracker_names]
39 | # # self.pred_trajs = {}
40 | # for name in tracker_names:
41 | # traj_file = os.path.join(path, name, self.name+'.txt')
42 | # if os.path.exists(traj_file):
43 | # with open(traj_file, 'r') as f :
44 | # self.pred_trajs[name] = [list(map(float, x.strip().split(',')))
45 | # for x in f.readlines()]
46 | # if len(self.pred_trajs[name]) != len(self.gt_traj):
47 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name)
48 | # else:
49 |
50 | # self.tracker_names = list(self.pred_trajs.keys())
51 |
52 | class NFSDataset(Dataset):
53 | """
54 | Args:
55 | name: dataset name, should be "NFS30" or "NFS240"
56 |         dataset_root: dataset root dir
57 | """
58 | def __init__(self, name, dataset_root, load_img=False):
59 | super(NFSDataset, self).__init__(name, dataset_root)
60 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
61 | meta_data = json.load(f)
62 |
63 | # load videos
64 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
65 | self.videos = {}
66 | for video in pbar:
67 | pbar.set_postfix_str(video)
68 | self.videos[video] = NFSVideo(video,
69 | dataset_root,
70 | meta_data[video]['video_dir'],
71 | meta_data[video]['init_rect'],
72 | meta_data[video]['img_names'],
73 | meta_data[video]['gt_rect'],
74 | None)
75 |
76 | self.attr = {}
77 | self.attr['ALL'] = list(self.videos.keys())
78 |
--------------------------------------------------------------------------------
/toolkit/datasets/otb.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 |
5 | from PIL import Image
6 | from tqdm import tqdm
7 | from glob import glob
8 |
9 | from .dataset import Dataset
10 | from .video import Video
11 |
12 |
13 | class OTBVideo(Video):
14 | """
15 | Args:
16 | name: video name
17 | root: dataset root
18 | video_dir: video directory
19 | init_rect: init rectangle
20 | img_names: image names
21 | gt_rect: groundtruth rectangle
22 | attr: attribute of video
23 | """
24 | def __init__(self, name, root, video_dir, init_rect, img_names,
25 | gt_rect, attr, load_img=False):
26 | super(OTBVideo, self).__init__(name, root, video_dir,
27 | init_rect, img_names, gt_rect, attr, load_img)
28 |
29 | def load_tracker(self, path, tracker_names=None, store=True):
30 | """
31 | Args:
32 | path(str): path to result
33 | tracker_name(list): name of tracker
34 | """
35 | if not tracker_names:
36 | tracker_names = [x.split('/')[-1] for x in glob(path)
37 | if os.path.isdir(x)]
38 | if isinstance(tracker_names, str):
39 | tracker_names = [tracker_names]
40 | for name in tracker_names:
41 | traj_file = os.path.join(path, name, self.name+'.txt')
42 | if not os.path.exists(traj_file):
43 | if self.name == 'FleetFace':
44 | txt_name = 'fleetface.txt'
45 | elif self.name == 'Jogging-1':
46 | txt_name = 'jogging_1.txt'
47 | elif self.name == 'Jogging-2':
48 | txt_name = 'jogging_2.txt'
49 | elif self.name == 'Skating2-1':
50 | txt_name = 'skating2_1.txt'
51 | elif self.name == 'Skating2-2':
52 | txt_name = 'skating2_2.txt'
53 | elif self.name == 'FaceOcc1':
54 | txt_name = 'faceocc1.txt'
55 | elif self.name == 'FaceOcc2':
56 | txt_name = 'faceocc2.txt'
57 | elif self.name == 'Human4-2':
58 | txt_name = 'human4_2.txt'
59 | else:
60 | txt_name = self.name[0].lower()+self.name[1:]+'.txt'
61 | traj_file = os.path.join(path, name, txt_name)
62 | if os.path.exists(traj_file):
63 | with open(traj_file, 'r') as f :
64 | pred_traj = [list(map(float, x.strip().split(',')))
65 | for x in f.readlines()]
66 | if len(pred_traj) != len(self.gt_traj):
67 | print(name, len(pred_traj), len(self.gt_traj), self.name)
68 | if store:
69 | self.pred_trajs[name] = pred_traj
70 | else:
71 | return pred_traj
72 | else:
73 | print(traj_file)
74 | self.tracker_names = list(self.pred_trajs.keys())
75 |
76 |
77 |
78 | class OTBDataset(Dataset):
79 | """
80 | Args:
81 | name: dataset name, should be 'OTB100', 'CVPR13', 'OTB50'
82 | dataset_root: dataset root
83 |         load_img: whether to load all imgs
84 | """
85 | def __init__(self, name, dataset_root, load_img=False):
86 | super(OTBDataset, self).__init__(name, dataset_root)
87 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
88 | meta_data = json.load(f)
89 |
90 | # load videos
91 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
92 | self.videos = {}
93 | for video in pbar:
94 | pbar.set_postfix_str(video)
95 | self.videos[video] = OTBVideo(video,
96 | dataset_root,
97 | meta_data[video]['video_dir'],
98 | meta_data[video]['init_rect'],
99 | meta_data[video]['img_names'],
100 | meta_data[video]['gt_rect'],
101 | meta_data[video]['attr'],
102 | load_img)
103 |
104 | # set attr
105 | attr = []
106 | for x in self.videos.values():
107 | attr += x.attr
108 | attr = set(attr)
109 | self.attr = {}
110 | self.attr['ALL'] = list(self.videos.keys())
111 | for x in attr:
112 | self.attr[x] = []
113 | for k, v in self.videos.items():
114 | for attr_ in v.attr:
115 | self.attr[attr_].append(k)
116 |
--------------------------------------------------------------------------------
/toolkit/datasets/trackingnet.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 |
5 | from tqdm import tqdm
6 | from glob import glob
7 |
8 | from .dataset import Dataset
9 | from .video import Video
10 |
11 | class TrackingNetVideo(Video):
12 | """
13 | Args:
14 | name: video name
15 | root: dataset root
16 | video_dir: video directory
17 | init_rect: init rectangle
18 | img_names: image names
19 | gt_rect: groundtruth rectangle
20 | attr: attribute of video
21 | """
22 | def __init__(self, name, root, video_dir, init_rect, img_names,
23 | gt_rect, attr, load_img=False):
24 | super(TrackingNetVideo, self).__init__(name, root, video_dir,
25 | init_rect, img_names, gt_rect, attr, load_img)
26 |
27 | # def load_tracker(self, path, tracker_names=None):
28 | # """
29 | # Args:
30 | # path(str): path to result
31 | # tracker_name(list): name of tracker
32 | # """
33 | # if not tracker_names:
34 | # tracker_names = [x.split('/')[-1] for x in glob(path)
35 | # if os.path.isdir(x)]
36 | # if isinstance(tracker_names, str):
37 | # tracker_names = [tracker_names]
38 | # # self.pred_trajs = {}
39 | # for name in tracker_names:
40 | # traj_file = os.path.join(path, name, self.name+'.txt')
41 | # if os.path.exists(traj_file):
42 | # with open(traj_file, 'r') as f :
43 | # self.pred_trajs[name] = [list(map(float, x.strip().split(',')))
44 | # for x in f.readlines()]
45 | # if len(self.pred_trajs[name]) != len(self.gt_traj):
46 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name)
47 | # else:
48 |
49 | # self.tracker_names = list(self.pred_trajs.keys())
50 |
51 | class TrackingNetDataset(Dataset):
52 | """
53 | Args:
54 |         name: dataset name, should be "TrackingNet"
55 |         dataset_root: dataset root dir
56 | """
57 | def __init__(self, name, dataset_root, load_img=False):
58 | super(TrackingNetDataset, self).__init__(name, dataset_root)
59 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
60 | meta_data = json.load(f)
61 |
62 | # load videos
63 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
64 | self.videos = {}
65 | for video in pbar:
66 | pbar.set_postfix_str(video)
67 | self.videos[video] = TrackingNetVideo(video,
68 | dataset_root,
69 | meta_data[video]['video_dir'],
70 | meta_data[video]['init_rect'],
71 | meta_data[video]['img_names'],
72 | meta_data[video]['gt_rect'],
73 | None)
74 | self.attr = {}
75 | self.attr['ALL'] = list(self.videos.keys())
76 |
--------------------------------------------------------------------------------
/toolkit/datasets/uav.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | from tqdm import tqdm
5 | from glob import glob
6 |
7 | from .dataset import Dataset
8 | from .video import Video
9 |
10 | class UAVVideo(Video):
11 | """
12 | Args:
13 | name: video name
14 | root: dataset root
15 | video_dir: video directory
16 | init_rect: init rectangle
17 | img_names: image names
18 | gt_rect: groundtruth rectangle
19 | attr: attribute of video
20 | """
21 | def __init__(self, name, root, video_dir, init_rect, img_names,
22 | gt_rect, attr, load_img=False):
23 | super(UAVVideo, self).__init__(name, root, video_dir,
24 | init_rect, img_names, gt_rect, attr, load_img)
25 |
26 |
27 | class UAVDataset(Dataset):
28 | """
29 | Args:
30 | name: dataset name, should be 'UAV123', 'UAV20L'
31 | dataset_root: dataset root
32 |         load_img: whether to load all imgs
33 | """
34 | def __init__(self, name, dataset_root, load_img=False):
35 | super(UAVDataset, self).__init__(name, dataset_root)
36 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
37 | meta_data = json.load(f)
38 |
39 | # load videos
40 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
41 | self.videos = {}
42 | for video in pbar:
43 | pbar.set_postfix_str(video)
44 | self.videos[video] = UAVVideo(video,
45 | dataset_root,
46 | meta_data[video]['video_dir'],
47 | meta_data[video]['init_rect'],
48 | meta_data[video]['img_names'],
49 | meta_data[video]['gt_rect'],
50 | meta_data[video]['attr'])
51 |
52 | # set attr
53 | attr = []
54 | for x in self.videos.values():
55 | attr += x.attr
56 | attr = set(attr)
57 | self.attr = {}
58 | self.attr['ALL'] = list(self.videos.keys())
59 | for x in attr:
60 | self.attr[x] = []
61 | for k, v in self.videos.items():
62 | for attr_ in v.attr:
63 | self.attr[attr_].append(k)
64 |
65 |
--------------------------------------------------------------------------------
/toolkit/datasets/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import re
4 | import numpy as np
5 | import json
6 |
7 | from glob import glob
8 |
9 | class Video(object):
10 | def __init__(self, name, root, video_dir, init_rect, img_names,
11 | gt_rect, attr, load_img=False):
12 | self.name = name
13 | self.video_dir = video_dir
14 | self.init_rect = init_rect
15 | self.gt_traj = gt_rect
16 | self.attr = attr
17 | self.pred_trajs = {}
18 | self.img_names = [os.path.join(root, x) for x in img_names]
19 | self.imgs = None
20 |
21 | if load_img:
22 | self.imgs = [cv2.imread(x) for x in self.img_names]
23 | self.width = self.imgs[0].shape[1]
24 | self.height = self.imgs[0].shape[0]
25 | else:
26 | img = cv2.imread(self.img_names[0])
27 | assert img is not None, self.img_names[0]
28 | self.width = img.shape[1]
29 | self.height = img.shape[0]
30 |
31 | def load_tracker(self, path, tracker_names=None, store=True):
32 | """
33 | Args:
34 | path(str): path to result
35 | tracker_name(list): name of tracker
36 | """
37 | if not tracker_names:
38 | tracker_names = [x.split('/')[-1] for x in glob(path)
39 | if os.path.isdir(x)]
40 | if isinstance(tracker_names, str):
41 | tracker_names = [tracker_names]
42 | for name in tracker_names:
43 | traj_file = os.path.join(path, name, self.name+'.txt')
44 | if os.path.exists(traj_file):
45 | with open(traj_file, 'r') as f :
46 | pred_traj = [list(map(float, x.strip().split(',')))
47 | for x in f.readlines()]
48 | if len(pred_traj) != len(self.gt_traj):
49 | print(name, len(pred_traj), len(self.gt_traj), self.name)
50 | if store:
51 | self.pred_trajs[name] = pred_traj
52 | else:
53 | return pred_traj
54 | else:
55 | print(traj_file)
56 | self.tracker_names = list(self.pred_trajs.keys())
57 |
58 | def load_img(self):
59 | if self.imgs is None:
60 | self.imgs = [cv2.imread(x) for x in self.img_names]
61 | self.width = self.imgs[0].shape[1]
62 | self.height = self.imgs[0].shape[0]
63 |
64 | def free_img(self):
65 | self.imgs = None
66 |
67 | def __len__(self):
68 | return len(self.img_names)
69 |
70 | def __getitem__(self, idx):
71 | if self.imgs is None:
72 | return cv2.imread(self.img_names[idx]), self.gt_traj[idx]
73 | else:
74 | return self.imgs[idx], self.gt_traj[idx]
75 |
76 | def __iter__(self):
77 | for i in range(len(self.img_names)):
78 | if self.imgs is not None:
79 | yield self.imgs[i], self.gt_traj[i]
80 | else:
81 | yield cv2.imread(self.img_names[i]), self.gt_traj[i]
82 |
83 | def draw_box(self, roi, img, linewidth, color, name=None):
84 | """
85 | roi: rectangle or polygon
86 | img: numpy array img
87 |         linewidth: line width of the bbox
88 | """
89 | if len(roi) > 6 and len(roi) % 2 == 0:
90 | pts = np.array(roi, np.int32).reshape(-1, 1, 2)
91 | color = tuple(map(int, color))
92 | img = cv2.polylines(img, [pts], True, color, linewidth)
93 | pt = (pts[0, 0, 0], pts[0, 0, 1]-5)
94 | if name:
95 | img = cv2.putText(img, name, pt, cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1)
96 | elif len(roi) == 4:
97 | if not np.isnan(roi[0]):
98 | roi = list(map(int, roi))
99 | color = tuple(map(int, color))
100 | img = cv2.rectangle(img, (roi[0], roi[1]), (roi[0]+roi[2], roi[1]+roi[3]),
101 | color, linewidth)
102 | if name:
103 | img = cv2.putText(img, name, (roi[0], roi[1]-5), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1)
104 | return img
105 |
106 | def show(self, pred_trajs={}, linewidth=2, show_name=False):
107 | """
108 | pred_trajs: dict of pred_traj, {'tracker_name': list of traj}
109 | pred_traj should contain polygon or rectangle(x, y, width, height)
110 |         linewidth: line width of the bbox
111 | """
112 | assert self.imgs is not None
113 | video = []
114 | cv2.namedWindow(self.name, cv2.WINDOW_NORMAL)
115 | colors = {}
116 | if len(pred_trajs) == 0 and len(self.pred_trajs) > 0:
117 | pred_trajs = self.pred_trajs
118 | for i, (roi, img) in enumerate(zip(self.gt_traj,
119 | self.imgs[self.start_frame:self.end_frame+1])):
120 | img = img.copy()
121 | if len(img.shape) == 2:
122 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
123 | else:
124 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
125 | img = self.draw_box(roi, img, linewidth, (0, 255, 0),
126 | 'gt' if show_name else None)
127 | for name, trajs in pred_trajs.items():
128 | if name not in colors:
129 | color = tuple(np.random.randint(0, 256, 3))
130 | colors[name] = color
131 | else:
132 | color = colors[name]
133 | img = self.draw_box(trajs[0][i], img, linewidth, color,
134 | name if show_name else None)
135 | cv2.putText(img, str(i+self.start_frame), (5, 20),
136 | cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (255, 255, 0), 2)
137 | cv2.imshow(self.name, img)
138 | cv2.waitKey(40)
139 | video.append(img.copy())
140 | return video
141 |
--------------------------------------------------------------------------------
/toolkit/datasets/visdrone.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | from glob import glob
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | from .dataset import Dataset
9 | from .video import Video
10 |
11 |
12 | class VisDroneVideo(Video):
13 | """
14 | Args:
15 | name: video name
16 | root: dataset root
17 | video_dir: video directory
18 | init_rect: init rectangle
19 | img_names: image names
20 | gt_rect: groundtruth rectangle
21 | attr: attribute of video
22 | """
23 | def __init__(self, name, root, video_dir, init_rect, img_names,
24 | gt_rect, attr, load_img=False):
25 | super(VisDroneVideo, self).__init__(name, root, video_dir,
26 | init_rect, img_names, gt_rect, attr, load_img)
27 |
28 | def load_tracker(self, path, tracker_names=None, store=True):
29 | """
30 | Args:
31 | path(str): path to result
32 | tracker_name(list): name of tracker
33 | """
34 | if not tracker_names:
35 | tracker_names = [x.split('/')[-1] for x in glob(path)
36 | if os.path.isdir(x)]
37 | if isinstance(tracker_names, str):
38 | tracker_names = [tracker_names]
39 | for name in tracker_names:
40 | traj_file = os.path.join(path, name, self.name+'.txt')
41 | # if not os.path.exists(traj_file):
42 | # if self.name == 'FleetFace':
43 | # txt_name = 'fleetface.txt'
44 | # elif self.name == 'Jogging-1':
45 | # txt_name = 'jogging_1.txt'
46 | # elif self.name == 'Jogging-2':
47 | # txt_name = 'jogging_2.txt'
48 | # elif self.name == 'Skating2-1':
49 | # txt_name = 'skating2_1.txt'
50 | # elif self.name == 'Skating2-2':
51 | # txt_name = 'skating2_2.txt'
52 | # elif self.name == 'FaceOcc1':
53 | # txt_name = 'faceocc1.txt'
54 | # elif self.name == 'FaceOcc2':
55 | # txt_name = 'faceocc2.txt'
56 | # elif self.name == 'Human4-2':
57 | # txt_name = 'human4_2.txt'
58 | # else:
59 | # txt_name = self.name[0].lower()+self.name[1:]+'.txt'
60 | # traj_file = os.path.join(path, name, txt_name)
61 | if os.path.exists(traj_file):
62 | with open(traj_file, 'r') as f :
63 | pred_traj = [list(map(float, x.strip().split(',')))
64 | for x in f.readlines()]
65 | if len(pred_traj) != len(self.gt_traj):
66 | print(name, len(pred_traj), len(self.gt_traj), self.name)
67 | if store:
68 | self.pred_trajs[name] = pred_traj
69 | else:
70 | return pred_traj
71 | else:
72 | print(traj_file)
73 | self.tracker_names = list(self.pred_trajs.keys())
74 |
75 |
76 |
77 | class VisDroneDataset(Dataset):
78 | """
79 | Args:
80 | name: dataset name, should be 'VisDrone2019'
81 | dataset_root: dataset root
82 | load_img: wether to load all imgs
83 | """
84 | def __init__(self, name, dataset_root, load_img=False):
85 | super(VisDroneDataset, self).__init__(name, dataset_root)
86 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f:
87 | meta_data = json.load(f)
88 |
89 | # load videos
90 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100)
91 | self.videos = {}
92 | for video in pbar:
93 | pbar.set_postfix_str(video)
94 | self.videos[video] = VisDroneVideo(video,
95 | dataset_root,
96 | meta_data[video]['video_dir'],
97 | meta_data[video]['init_rect'],
98 | meta_data[video]['img_names'],
99 | meta_data[video]['gt_rect'],
100 | meta_data[video]['attr'],
101 | load_img)
102 |
103 | # set attr
104 | attr = []
105 | for x in self.videos.values():
106 | if x.attr:
107 | attr += x.attr
108 | attr = set(attr)
109 | self.attr = {}
110 | self.attr['ALL'] = list(self.videos.keys())
111 | for x in attr:
112 | self.attr[x] = []
113 | for k, v in self.videos.items():
114 | if v.attr:
115 | for attr_ in v.attr:
116 | self.attr[attr_].append(k)
--------------------------------------------------------------------------------
/toolkit/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .ar_benchmark import AccuracyRobustnessBenchmark
2 | from .eao_benchmark import EAOBenchmark
3 | from .ope_benchmark import OPEBenchmark
4 | from .f1_benchmark import F1Benchmark
5 |
--------------------------------------------------------------------------------
/toolkit/evaluation/ar_benchmark.py:
--------------------------------------------------------------------------------
1 | """
2 | @author
3 | """
4 |
5 | import warnings
6 | import itertools
7 | import numpy as np
8 |
9 | from colorama import Style, Fore
10 | from ..utils import calculate_failures, calculate_accuracy
11 |
12 | class AccuracyRobustnessBenchmark:
13 | """
14 | Args:
15 | dataset:
16 | burnin:
17 | """
18 | def __init__(self, dataset, burnin=10):
19 | self.dataset = dataset
20 | self.burnin = burnin
21 |
22 | def eval(self, eval_trackers=None):
23 | """
24 | Args:
25 | eval_tags: list of tag
26 | eval_trackers: list of tracker name
27 | Returns:
28 | ret: dict of results
29 | """
30 | if eval_trackers is None:
31 | eval_trackers = self.dataset.tracker_names
32 | if isinstance(eval_trackers, str):
33 | eval_trackers = [eval_trackers]
34 |
35 | result = {}
36 | for tracker_name in eval_trackers:
37 | accuracy, failures = self._calculate_accuracy_robustness(tracker_name)
38 | result[tracker_name] = {'overlaps': accuracy,
39 | 'failures': failures}
40 | return result
41 |
42 | def show_result(self, result, eao_result=None, show_video_level=False, helight_threshold=0.5):
43 | """pretty print result
44 | Args:
45 | result: returned dict from function eval
46 | """
47 | tracker_name_len = max((max([len(x) for x in result.keys()])+2), 12)
48 | if eao_result is not None:
49 | header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|{:^7}|"
50 | header = header.format('Tracker Name',
51 | 'Accuracy', 'Robustness', 'Lost Number', 'EAO')
52 | formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|{:^7.3f}|"
53 | else:
54 | header = "|{:^"+str(tracker_name_len)+"}|{:^10}|{:^12}|{:^13}|"
55 | header = header.format('Tracker Name',
56 | 'Accuracy', 'Robustness', 'Lost Number')
57 | formatter = "|{:^"+str(tracker_name_len)+"}|{:^10.3f}|{:^12.3f}|{:^13.1f}|"
58 | bar = '-'*len(header)
59 | print(bar)
60 | print(header)
61 | print(bar)
62 | if eao_result is not None:
63 | tracker_eao = sorted(eao_result.items(),
64 | key=lambda x:x[1]['all'],
65 | reverse=True)[:20]
66 | tracker_names = [x[0] for x in tracker_eao]
67 | else:
68 | tracker_names = list(result.keys())
69 | for tracker_name in tracker_names:
70 | # for tracker_name, ret in result.items():
71 | ret = result[tracker_name]
72 | overlaps = list(itertools.chain(*ret['overlaps'].values()))
73 | accuracy = np.nanmean(overlaps)
74 | length = sum([len(x) for x in ret['overlaps'].values()])
75 | failures = list(ret['failures'].values())
76 | lost_number = np.mean(np.sum(failures, axis=0))
77 | robustness = np.mean(np.sum(np.array(failures), axis=0) / length) * 100
78 | if eao_result is None:
79 | print(formatter.format(tracker_name, accuracy, robustness, lost_number))
80 | else:
81 | print(formatter.format(tracker_name, accuracy, robustness, lost_number, eao_result[tracker_name]['all']))
82 | print(bar)
83 |
84 | if show_video_level and len(result) < 10:
85 | print('\n\n')
86 | header1 = "|{:^14}|".format("Tracker name")
87 | header2 = "|{:^14}|".format("Video name")
88 | for tracker_name in result.keys():
89 | header1 += ("{:^17}|").format(tracker_name)
90 | header2 += "{:^8}|{:^8}|".format("Acc", "LN")
91 | print('-'*len(header1))
92 | print(header1)
93 | print('-'*len(header1))
94 | print(header2)
95 | print('-'*len(header1))
96 | videos = list(result[tracker_name]['overlaps'].keys())
97 | for video in videos:
98 | row = "|{:^14}|".format(video)
99 | for tracker_name in result.keys():
100 | overlaps = result[tracker_name]['overlaps'][video]
101 | accuracy = np.nanmean(overlaps)
102 | failures = result[tracker_name]['failures'][video]
103 | lost_number = np.mean(failures)
104 |
105 | accuracy_str = "{:^8.3f}".format(accuracy)
106 | if accuracy < helight_threshold:
107 | row += f'{Fore.RED}{accuracy_str}{Style.RESET_ALL}|'
108 | else:
109 | row += accuracy_str+'|'
110 | lost_num_str = "{:^8.3f}".format(lost_number)
111 | if lost_number > 0:
112 | row += f'{Fore.RED}{lost_num_str}{Style.RESET_ALL}|'
113 | else:
114 | row += lost_num_str+'|'
115 | print(row)
116 | print('-'*len(header1))
117 |
118 | def _calculate_accuracy_robustness(self, tracker_name):
119 | overlaps = {}
120 | failures = {}
121 | all_length = {}
122 | for i in range(len(self.dataset)):
123 | video = self.dataset[i]
124 | gt_traj = video.gt_traj
125 | if tracker_name not in video.pred_trajs:
126 | tracker_trajs = video.load_tracker(self.dataset.tracker_path, tracker_name, False)
127 | else:
128 | tracker_trajs = video.pred_trajs[tracker_name]
129 | overlaps_group = []
130 | num_failures_group = []
131 | for tracker_traj in tracker_trajs:
132 | num_failures = calculate_failures(tracker_traj)[0]
133 | overlaps_ = calculate_accuracy(tracker_traj, gt_traj,
134 | burnin=10, bound=(video.width, video.height))[1]
135 | overlaps_group.append(overlaps_)
136 | num_failures_group.append(num_failures)
137 | with warnings.catch_warnings():
138 | warnings.simplefilter("ignore", category=RuntimeWarning)
139 | overlaps[video.name] = np.nanmean(overlaps_group, axis=0).tolist()
140 | failures[video.name] = num_failures_group
141 | return overlaps, failures
142 |
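A hedged sketch of running the benchmark above on VOT-style results already written to disk; the tracker name and result path are placeholders, and results are registered through Dataset.set_tracker.

from toolkit.datasets import DatasetFactory
from toolkit.evaluation import AccuracyRobustnessBenchmark

dataset = DatasetFactory.create_dataset(name='VOT2018',
                                        dataset_root='testing_dataset/VOT2018')
dataset.set_tracker('results/VOT2018', ['my_tracker'])

benchmark = AccuracyRobustnessBenchmark(dataset)
result = benchmark.eval()          # {tracker: {'overlaps': {...}, 'failures': {...}}}
benchmark.show_result(result)      # table with Accuracy / Robustness / Lost Number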
--------------------------------------------------------------------------------
/toolkit/evaluation/f1_benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from glob import glob
5 | from tqdm import tqdm
6 | from colorama import Style, Fore
7 |
8 | from ..utils import determine_thresholds, calculate_accuracy, calculate_f1
9 |
10 | class F1Benchmark:
11 | def __init__(self, dataset):
12 | """
13 | Args:
14 |             dataset: dataset to evaluate on
15 | """
16 | self.dataset = dataset
17 |
18 | def eval(self, eval_trackers=None):
19 | """
20 | Args:
21 | eval_tags: list of tag
22 | eval_trackers: list of tracker name
23 | Returns:
24 | eao: dict of results
25 | """
26 | if eval_trackers is None:
27 | eval_trackers = self.dataset.tracker_names
28 | if isinstance(eval_trackers, str):
29 | eval_trackers = [eval_trackers]
30 |
31 | ret = {}
32 | for tracker_name in eval_trackers:
33 |             precision, recall, f1 = self._cal_precision_recall(tracker_name)
34 | ret[tracker_name] = {"precision": precision,
35 | "recall": recall,
36 | "f1": f1
37 | }
38 | return ret
39 |
40 |     def _cal_precision_recall(self, tracker_name):
41 | score = []
42 | # for i in range(len(self.dataset)):
43 | # video = self.dataset[i]
44 | for video in self.dataset:
45 | if tracker_name not in video.confidence:
46 | score += video.load_tracker(self.dataset.tracker_path, tracker_name, False)[1]
47 | else:
48 | score += video.confidence[tracker_name]
49 | score = np.array(score)
50 | thresholds = determine_thresholds(score)[::-1]
51 |
52 | precision = {}
53 | recall = {}
54 | f1 = {}
55 | for i in range(len(self.dataset)):
56 | video = self.dataset[i]
57 | gt_traj = video.gt_traj
58 | N = sum([1 for x in gt_traj if len(x) > 1])
59 | if tracker_name not in video.pred_trajs:
60 | tracker_traj, score = video.load_tracker(self.dataset.tracker_path, tracker_name, False)
61 | else:
62 | tracker_traj = video.pred_trajs[tracker_name]
63 | score = video.confidence[tracker_name]
64 | overlaps = calculate_accuracy(tracker_traj, gt_traj, \
65 | bound=(video.width,video.height))[1]
66 | f1[video.name], precision[video.name], recall[video.name] = \
67 | calculate_f1(overlaps, score, (video.width,video.height),thresholds, N)
68 | return precision, recall, f1
69 |
70 | def show_result(self, result, show_video_level=False, helight_threshold=0.5):
71 |         """Pretty print the evaluation result
72 |         Args:
73 |             result: dict returned by eval()
74 |         """
75 | # sort tracker according to f1
76 | sorted_tracker = {}
77 | for tracker_name, ret in result.items():
78 | precision = np.mean(list(ret['precision'].values()), axis=0)
79 | recall = np.mean(list(ret['recall'].values()), axis=0)
80 | f1 = 2 * precision * recall / (precision + recall)
81 | max_idx = np.argmax(f1)
82 | sorted_tracker[tracker_name] = (precision[max_idx], recall[max_idx],
83 | f1[max_idx])
84 | sorted_tracker_ = sorted(sorted_tracker.items(),
85 | key=lambda x:x[1][2],
86 | reverse=True)[:20]
87 | tracker_names = [x[0] for x in sorted_tracker_]
88 |
89 | tracker_name_len = max((max([len(x) for x in result.keys()])+2), 12)
90 | header = "|{:^"+str(tracker_name_len)+"}|{:^11}|{:^8}|{:^7}|"
91 | header = header.format('Tracker Name',
92 | 'Precision', 'Recall', 'F1')
93 | bar = '-' * len(header)
94 | formatter = "|{:^"+str(tracker_name_len)+"}|{:^11.3f}|{:^8.3f}|{:^7.3f}|"
95 | print(bar)
96 | print(header)
97 | print(bar)
98 | # for tracker_name, ret in result.items():
99 | # precision = np.mean(list(ret['precision'].values()), axis=0)
100 | # recall = np.mean(list(ret['recall'].values()), axis=0)
101 | # f1 = 2 * precision * recall / (precision + recall)
102 | # max_idx = np.argmax(f1)
103 | for tracker_name in tracker_names:
104 | precision = sorted_tracker[tracker_name][0]
105 | recall = sorted_tracker[tracker_name][1]
106 | f1 = sorted_tracker[tracker_name][2]
107 | print(formatter.format(tracker_name, precision, recall, f1))
108 | print(bar)
109 |
110 | if show_video_level and len(result) < 10:
111 | print('\n\n')
112 | header1 = "|{:^14}|".format("Tracker name")
113 | header2 = "|{:^14}|".format("Video name")
114 | for tracker_name in result.keys():
115 | # col_len = max(20, len(tracker_name))
116 | header1 += ("{:^28}|").format(tracker_name)
117 | header2 += "{:^11}|{:^8}|{:^7}|".format("Precision", "Recall", "F1")
118 | print('-'*len(header1))
119 | print(header1)
120 | print('-'*len(header1))
121 | print(header2)
122 | print('-'*len(header1))
123 | videos = list(result[tracker_name]['precision'].keys())
124 | for video in videos:
125 | row = "|{:^14}|".format(video)
126 | for tracker_name in result.keys():
127 | precision = result[tracker_name]['precision'][video]
128 | recall = result[tracker_name]['recall'][video]
129 | f1 = result[tracker_name]['f1'][video]
130 | max_idx = np.argmax(f1)
131 | precision_str = "{:^11.3f}".format(precision[max_idx])
132 | if precision[max_idx] < helight_threshold:
133 | row += f'{Fore.RED}{precision_str}{Style.RESET_ALL}|'
134 | else:
135 | row += precision_str+'|'
136 | recall_str = "{:^8.3f}".format(recall[max_idx])
137 | if recall[max_idx] < helight_threshold:
138 | row += f'{Fore.RED}{recall_str}{Style.RESET_ALL}|'
139 | else:
140 | row += recall_str+'|'
141 | f1_str = "{:^7.3f}".format(f1[max_idx])
142 | if f1[max_idx] < helight_threshold:
143 | row += f'{Fore.RED}{f1_str}{Style.RESET_ALL}|'
144 | else:
145 | row += f1_str+'|'
146 | print(row)
147 | print('-'*len(header1))
148 |
--------------------------------------------------------------------------------
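For reference, a minimal sketch of driving F1Benchmark directly, mirroring how tools/eval.py wires the VOT-LT evaluation together; the dataset root, result directory and tracker name below are placeholders.

    from toolkit.datasets import VOTLTDataset
    from toolkit.evaluation import F1Benchmark

    # Placeholder paths and tracker name; adjust to your local layout.
    dataset = VOTLTDataset('VOT2018-LT', 'testing_dataset/VOT2018-LT')
    dataset.set_tracker('results/VOT2018-LT', ['siamrpn_r50_l234_dwxcorr_votlt'])

    benchmark = F1Benchmark(dataset)
    result = benchmark.eval(['siamrpn_r50_l234_dwxcorr_votlt'])
    benchmark.show_result(result, show_video_level=False)
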
/toolkit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import region
2 | from .statistics import *
3 |
--------------------------------------------------------------------------------
/toolkit/utils/c_region.pxd:
--------------------------------------------------------------------------------
1 | cdef extern from "src/region.h":
2 | ctypedef enum region_type "RegionType":
3 |         EMPTY
4 |         SPECIAL
5 |         RECTANGLE
6 | POLYGON
7 | MASK
8 |
9 | ctypedef struct region_bounds:
10 | float top
11 | float bottom
12 | float left
13 | float right
14 |
15 | ctypedef struct region_rectangle:
16 | float x
17 | float y
18 | float width
19 | float height
20 |
21 | # ctypedef struct region_mask:
22 | # int x
23 | # int y
24 | # int width
25 | # int height
26 | # char *data
27 |
28 | ctypedef struct region_polygon:
29 | int count
30 | float *x
31 | float *y
32 |
33 | ctypedef union region_container_data:
34 | region_rectangle rectangle
35 | region_polygon polygon
36 | # region_mask mask
37 | int special
38 |
39 | ctypedef struct region_container:
40 | region_type type
41 | region_container_data data
42 |
43 | # ctypedef struct region_overlap:
44 | # float overlap
45 | # float only1
46 | # float only2
47 |
48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds)
49 |
50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds)
51 |
--------------------------------------------------------------------------------
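These declarations are what region.pyx wraps into the compiled toolkit.utils.region module. Below is a minimal sketch of the Python-side overlap call, assuming vot_overlap(region1, region2, bounds) as it is used in tools/test.py; the boxes are hypothetical [x, y, w, h] values and the bound is the frame size.

    from toolkit.utils.region import vot_overlap

    pred_bbox = [100, 120, 60, 90]   # hypothetical prediction
    gt_bbox = [105, 118, 58, 92]     # hypothetical ground truth
    iou = vot_overlap(pred_bbox, gt_bbox, (640, 480))
    print('overlap: {:.3f}'.format(iou))
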
/toolkit/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | @author fangyi.zhang@vipl.ict.ac.cn
3 | """
4 | import numpy as np
5 |
6 | def determine_thresholds(confidence, resolution=100):
7 |     """Choose thresholds according to the confidence distribution
8 |
9 |     Args:
10 |         confidence: list or numpy array of confidence scores
11 |         resolution: number of thresholds to choose
12 |
13 |     Returns:
14 |         thresholds: numpy array of thresholds
15 |     """
16 | if isinstance(confidence, list):
17 | confidence = np.array(confidence)
18 | confidence = confidence.flatten()
19 | confidence = confidence[~np.isnan(confidence)]
20 | confidence.sort()
21 |
22 | assert len(confidence) > resolution and resolution > 2
23 |
24 | thresholds = np.ones((resolution))
25 | thresholds[0] = - np.inf
26 | thresholds[-1] = np.inf
27 | delta = np.floor(len(confidence) / (resolution - 2))
28 | idxs = np.linspace(delta, len(confidence)-delta, resolution-2, dtype=np.int32)
29 | thresholds[1:-1] = confidence[idxs]
30 | return thresholds
31 |
--------------------------------------------------------------------------------
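A quick sketch of determine_thresholds on hypothetical scores: the first and last thresholds are forced to -inf/inf and the rest are sampled evenly from the sorted confidences.

    import numpy as np
    from toolkit.utils.misc import determine_thresholds

    scores = np.random.rand(1000)               # hypothetical confidence values
    thresholds = determine_thresholds(scores, resolution=100)
    assert len(thresholds) == 100
    assert thresholds[0] == -np.inf and thresholds[-1] == np.inf
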
/toolkit/utils/src/buffer.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef __STRING_BUFFER_H
3 | #define __STRING_BUFFER_H
4 |
5 | // Enable MinGW secure API for _snprintf_s
6 | #define MINGW_HAS_SECURE_API 1
7 |
8 | #ifdef _MSC_VER
9 | #define __INLINE __inline
10 | #else
11 | #define __INLINE inline
12 | #endif
13 |
14 | #include <stdlib.h>
15 | #include <string.h>
16 | #include <stdarg.h>
17 |
18 | typedef struct string_buffer {
19 | char* buffer;
20 | int position;
21 | int size;
22 | } string_buffer;
23 |
24 | typedef struct string_list {
25 | char** buffer;
26 | int position;
27 | int size;
28 | } string_list;
29 |
30 | #define BUFFER_INCREMENT_STEP 4096
31 |
32 | static __INLINE string_buffer* buffer_create(int L) {
33 | string_buffer* B = (string_buffer*) malloc(sizeof(string_buffer));
34 | B->size = L;
35 | B->buffer = (char*) malloc(sizeof(char) * B->size);
36 | B->position = 0;
37 | return B;
38 | }
39 |
40 | static __INLINE void buffer_reset(string_buffer* B) {
41 | B->position = 0;
42 | }
43 |
44 | static __INLINE void buffer_destroy(string_buffer** B) {
45 | if (!(*B)) return;
46 | if ((*B)->buffer) {
47 | free((*B)->buffer);
48 | (*B)->buffer = NULL;
49 | }
50 | free((*B));
51 | (*B) = NULL;
52 | }
53 |
54 | static __INLINE char* buffer_extract(const string_buffer* B) {
55 | char *S = (char*) malloc(sizeof(char) * (B->position + 1));
56 | memcpy(S, B->buffer, B->position);
57 | S[B->position] = '\0';
58 | return S;
59 | }
60 |
61 | static __INLINE int buffer_size(const string_buffer* B) {
62 | return B->position;
63 | }
64 |
65 | static __INLINE void buffer_push(string_buffer* B, char C) {
66 | int required = 1;
67 | if (required > B->size - B->position) {
68 | B->size = B->position + BUFFER_INCREMENT_STEP;
69 | B->buffer = (char*) realloc(B->buffer, sizeof(char) * B->size);
70 | }
71 | B->buffer[B->position] = C;
72 | B->position += required;
73 | }
74 |
75 | static __INLINE void buffer_append(string_buffer* B, const char *format, ...) {
76 |
77 | int required;
78 | va_list args;
79 |
80 | #if defined(__OS2__) || defined(__WINDOWS__) || defined(WIN32) || defined(_MSC_VER)
81 |
82 | va_start(args, format);
83 | required = _vscprintf(format, args) + 1;
84 | va_end(args);
85 | if (required >= B->size - B->position) {
86 | B->size = B->position + required + 1;
87 | B->buffer = (char*) realloc(B->buffer, sizeof(char) * B->size);
88 | }
89 | va_start(args, format);
90 | required = _vsnprintf_s(&(B->buffer[B->position]), B->size - B->position, _TRUNCATE, format, args);
91 | va_end(args);
92 | B->position += required;
93 |
94 | #else
95 | va_start(args, format);
96 | required = vsnprintf(&(B->buffer[B->position]), B->size - B->position, format, args);
97 | va_end(args);
98 | if (required >= B->size - B->position) {
99 | B->size = B->position + required + 1;
100 | B->buffer = (char*) realloc(B->buffer, sizeof(char) * B->size);
101 | va_start(args, format);
102 | required = vsnprintf(&(B->buffer[B->position]), B->size - B->position, format, args);
103 | va_end(args);
104 | }
105 | B->position += required;
106 | #endif
107 |
108 | }
109 |
110 | static __INLINE string_list* list_create(int L) {
111 | string_list* B = (string_list*) malloc(sizeof(string_list));
112 | B->size = L;
113 | B->buffer = (char**) malloc(sizeof(char*) * B->size);
114 | memset(B->buffer, 0, sizeof(char*) * B->size);
115 | B->position = 0;
116 | return B;
117 | }
118 |
119 | static __INLINE void list_reset(string_list* B) {
120 | int i;
121 | for (i = 0; i < B->position; i++) {
122 | if (B->buffer[i]) free(B->buffer[i]);
123 | B->buffer[i] = NULL;
124 | }
125 | B->position = 0;
126 | }
127 |
128 | static __INLINE void list_destroy(string_list **B) {
129 | int i;
130 |
131 | if (!(*B)) return;
132 |
133 | for (i = 0; i < (*B)->position; i++) {
134 | if ((*B)->buffer[i]) free((*B)->buffer[i]); (*B)->buffer[i] = NULL;
135 | }
136 |
137 | if ((*B)->buffer) {
138 | free((*B)->buffer); (*B)->buffer = NULL;
139 | }
140 |
141 | free((*B));
142 | (*B) = NULL;
143 | }
144 |
145 | static __INLINE char* list_get(const string_list *B, int I) {
146 | if (I < 0 || I >= B->position) {
147 | return NULL;
148 | } else {
149 | if (!B->buffer[I]) {
150 | return NULL;
151 | } else {
152 | char *S;
153 | int length = strlen(B->buffer[I]);
154 | S = (char*) malloc(sizeof(char) * (length + 1));
155 | memcpy(S, B->buffer[I], length + 1);
156 | return S;
157 | }
158 | }
159 | }
160 |
161 | static __INLINE int list_size(const string_list *B) {
162 | return B->position;
163 | }
164 |
165 | static __INLINE void list_append(string_list *B, char* S) {
166 | int required = 1;
167 | int length = strlen(S);
168 | if (required > B->size - B->position) {
169 | B->size = B->position + 16;
170 | B->buffer = (char**) realloc(B->buffer, sizeof(char*) * B->size);
171 | }
172 | B->buffer[B->position] = (char*) malloc(sizeof(char) * (length + 1));
173 | memcpy(B->buffer[B->position], S, length + 1);
174 | B->position += required;
175 | }
176 |
177 | // This version of the append does not copy the string but simply takes the control of its allocation
178 | static __INLINE void list_append_direct(string_list *B, char* S) {
179 | int required = 1;
180 | // int length = strlen(S);
181 | if (required > B->size - B->position) {
182 | B->size = B->position + 16;
183 | B->buffer = (char**) realloc(B->buffer, sizeof(char*) * B->size);
184 | }
185 | B->buffer[B->position] = S;
186 | B->position += required;
187 | }
188 |
189 |
190 | #endif
191 |
--------------------------------------------------------------------------------
/toolkit/utils/src/region.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */
2 |
3 | #ifndef _REGION_H_
4 | #define _REGION_H_
5 |
6 | #ifdef TRAX_STATIC_DEFINE
7 | # define __TRAX_EXPORT
8 | #else
9 | # ifndef __TRAX_EXPORT
10 | # if defined(_MSC_VER)
11 | # ifdef trax_EXPORTS
12 | /* We are building this library */
13 | # define __TRAX_EXPORT __declspec(dllexport)
14 | # else
15 | /* We are using this library */
16 | # define __TRAX_EXPORT __declspec(dllimport)
17 | # endif
18 | # elif defined(__GNUC__)
19 | # ifdef trax_EXPORTS
20 | /* We are building this library */
21 | # define __TRAX_EXPORT __attribute__((visibility("default")))
22 | # else
23 | /* We are using this library */
24 | # define __TRAX_EXPORT __attribute__((visibility("default")))
25 | # endif
26 | # endif
27 | # endif
28 | #endif
29 |
30 | #ifndef MAX
31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b))
32 | #endif
33 |
34 | #ifndef MIN
35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b))
36 | #endif
37 |
38 | #define TRAX_DEFAULT_CODE 0
39 |
40 | #define REGION_LEGACY_RASTERIZATION 1
41 |
42 | #ifdef __cplusplus
43 | extern "C" {
44 | #endif
45 |
46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type;
47 |
48 | typedef struct region_bounds {
49 |
50 | float top;
51 | float bottom;
52 | float left;
53 | float right;
54 |
55 | } region_bounds;
56 |
57 | typedef struct region_polygon {
58 |
59 | int count;
60 |
61 | float* x;
62 | float* y;
63 |
64 | } region_polygon;
65 |
66 | typedef struct region_mask {
67 |
68 | int x;
69 | int y;
70 |
71 | int width;
72 | int height;
73 |
74 | char* data;
75 |
76 | } region_mask;
77 |
78 | typedef struct region_rectangle {
79 |
80 | float x;
81 | float y;
82 | float width;
83 | float height;
84 |
85 | } region_rectangle;
86 |
87 | typedef struct region_container {
88 | enum region_type type;
89 | union {
90 | region_rectangle rectangle;
91 | region_polygon polygon;
92 | region_mask mask;
93 | int special;
94 | } data;
95 | } region_container;
96 |
97 | typedef struct region_overlap {
98 |
99 | float overlap;
100 | float only1;
101 | float only2;
102 |
103 | } region_overlap;
104 |
105 | extern const region_bounds region_no_bounds;
106 |
107 | __TRAX_EXPORT int region_set_flags(int mask);
108 |
109 | __TRAX_EXPORT int region_clear_flags(int mask);
110 |
111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds);
112 |
113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds);
114 |
115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom);
116 |
117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region);
118 |
119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region);
120 |
121 | __TRAX_EXPORT char* region_string(region_container* region);
122 |
123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region);
124 |
125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type);
126 |
127 | __TRAX_EXPORT void region_release(region_container** region);
128 |
129 | __TRAX_EXPORT region_container* region_create_special(int code);
130 |
131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height);
132 |
133 | __TRAX_EXPORT region_container* region_create_polygon(int count);
134 |
135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y);
136 |
137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height);
138 |
139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height);
140 |
141 | #ifdef __cplusplus
142 | }
143 | #endif
144 |
145 | #endif
146 |
--------------------------------------------------------------------------------
/toolkit/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | from .draw_f1 import draw_f1
2 | from .draw_success_precision import draw_success_precision
3 | from .draw_eao import draw_eao
4 |
--------------------------------------------------------------------------------
/toolkit/visualization/draw_eao.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import pickle
4 | from functools import cmp_to_key
5 | from matplotlib import rc
6 | from .draw_utils import COLOR, MARKER_STYLE
7 |
8 | rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
9 | rc('text', usetex=True)
10 |
11 | def draw_eao(result):
12 | fig = plt.figure()
13 | ax = fig.add_subplot(111, projection='polar')
14 | angles = np.linspace(0, 2*np.pi, 8, endpoint=True)
15 |
16 | attr2value = []
17 | for i, (tracker_name, ret) in enumerate(result.items()):
18 | value = list(ret.values())
19 | attr2value.append(value)
20 | value.append(value[0])
21 | attr2value = np.array(attr2value)
22 | max_value = np.max(attr2value, axis=0)
23 | min_value = np.min(attr2value, axis=0)
24 | result = {key:value for key,value in sorted(result.items(), key=lambda x:x[1]['all'], reverse=True)}
25 | for i, (tracker_name, ret) in enumerate(result.items()):
26 | value = list(ret.values())
27 | value.append(value[0])
28 | value = np.array(value)
29 | value *= (1 / max_value)
30 | plt.plot(angles, value, linestyle='-', color=COLOR[i], marker=MARKER_STYLE[i],
31 | label=tracker_name, linewidth=1.5, markersize=6)
32 |
33 |     attrs = ["Overall", "Camera motion", "Illumination change", "Motion change",
34 | "Size change", "Occlusion", "Unassigned"]
35 | attr_value = []
36 | for attr, maxv, minv in zip(attrs, max_value, min_value):
37 | attr_value.append(attr + "\n({:.3f},{:.3f})".format(minv, maxv))
38 | ax.set_thetagrids(angles[:-1] * 180/np.pi, attr_value)
39 | ax.spines['polar'].set_visible(False)
40 | ax.legend(loc='upper center', bbox_to_anchor=(0.5,-0.07), frameon=False, ncol=5)
41 | # ax.grid(b=False)
42 | ax.set_ylim(0, 1.18)
43 | ax.set_yticks([])
44 | plt.show()
45 |
46 | if __name__ == '__main__':
47 | result = pickle.load(open("../../result.pkl", 'rb'))
48 | draw_eao(result)
49 |
--------------------------------------------------------------------------------
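draw_eao expects, per tracker, a dict of per-attribute EAO values; the 'all' key is required (it drives the sorting) and the remaining keys are assumed here to follow the axis order shown above (camera motion, illumination change, ...), since the real keys come from EAOBenchmark.eval. A hedged sketch with hypothetical numbers:

    from toolkit.visualization import draw_eao

    # Hypothetical EAO results for two trackers; attribute keys other than
    # 'all' are assumptions and must match what the EAO benchmark emits.
    result = {
        'SiamRPN++': {'all': 0.414, 'camera_motion': 0.43, 'illum_change': 0.45,
                      'motion_change': 0.40, 'size_change': 0.38,
                      'occlusion': 0.30, 'empty': 0.44},
        'SiamMask':  {'all': 0.380, 'camera_motion': 0.40, 'illum_change': 0.42,
                      'motion_change': 0.37, 'size_change': 0.36,
                      'occlusion': 0.28, 'empty': 0.41},
    }
    draw_eao(result)
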
/toolkit/visualization/draw_f1.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | from matplotlib import rc
5 | from .draw_utils import COLOR, LINE_STYLE
6 |
7 | rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
8 | rc('text', usetex=True)
9 |
10 | def draw_f1(result, bold_name=None):
11 | # drawing f1 contour
12 | fig, ax = plt.subplots()
13 | for f1 in np.arange(0.1, 1, 0.1):
14 | recall = np.arange(f1, 1+0.01, 0.01)
15 | precision = f1 * recall / (2 * recall - f1)
16 | ax.plot(recall, precision, color=[0,1,0], linestyle='-', linewidth=0.5)
17 | ax.plot(precision, recall, color=[0,1,0], linestyle='-', linewidth=0.5)
18 | ax.grid(b=True)
19 | ax.set_aspect(1)
20 | plt.xlabel('Recall')
21 | plt.ylabel('Precision')
22 | plt.axis([0, 1, 0, 1])
23 | plt.title(r'\textbf{VOT2018-LT Precision vs Recall}')
24 |
25 | # draw result line
26 | all_precision = {}
27 | all_recall = {}
28 | best_f1 = {}
29 | best_idx = {}
30 | for tracker_name, ret in result.items():
31 | precision = np.mean(list(ret['precision'].values()), axis=0)
32 | recall = np.mean(list(ret['recall'].values()), axis=0)
33 | f1 = 2 * precision * recall / (precision + recall)
34 | max_idx = np.argmax(f1)
35 | all_precision[tracker_name] = precision
36 | all_recall[tracker_name] = recall
37 | best_f1[tracker_name] = f1[max_idx]
38 | best_idx[tracker_name] = max_idx
39 |
40 | for idx, (tracker_name, best_f1) in \
41 | enumerate(sorted(best_f1.items(), key=lambda x:x[1], reverse=True)):
42 | if tracker_name == bold_name:
43 | label = r"\textbf{[%.3f] Ours}" % (best_f1)
44 | else:
45 | label = "[%.3f] " % (best_f1) + tracker_name
46 | recall = all_recall[tracker_name][:-1]
47 | precision = all_precision[tracker_name][:-1]
48 | ax.plot(recall, precision, color=COLOR[idx], linestyle='-',
49 | label=label)
50 | f1_idx = best_idx[tracker_name]
51 | ax.plot(recall[f1_idx], precision[f1_idx], color=[0,0,0], marker='o',
52 | markerfacecolor=COLOR[idx], markersize=5)
53 | ax.legend(loc='lower right', labelspacing=0.2)
54 | plt.xticks(np.arange(0, 1+0.1, 0.1))
55 | plt.yticks(np.arange(0, 1+0.1, 0.1))
56 | plt.show()
57 |
58 | if __name__ == '__main__':
59 | draw_f1(None)
60 |
--------------------------------------------------------------------------------
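The result argument here is the per-video dict produced by F1Benchmark.eval; a minimal sketch wiring the two together, using the same placeholder dataset root and tracker name as in the earlier F1Benchmark sketch:

    from toolkit.datasets import VOTLTDataset
    from toolkit.evaluation import F1Benchmark
    from toolkit.visualization import draw_f1

    dataset = VOTLTDataset('VOT2018-LT', 'testing_dataset/VOT2018-LT')   # placeholder
    dataset.set_tracker('results/VOT2018-LT', ['siamrpn_r50_l234_dwxcorr_votlt'])
    result = F1Benchmark(dataset).eval()
    draw_f1(result, bold_name='siamrpn_r50_l234_dwxcorr_votlt')
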
/toolkit/visualization/draw_success_precision.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | from matplotlib import rc
5 | from .draw_utils import COLOR, LINE_STYLE
6 |
7 | # rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
8 | # rc('text', usetex=True)
9 |
10 | def draw_success_precision(success_ret, name, videos, attr, precision_ret=None,
11 | norm_precision_ret=None, bold_name=None, axis=[0, 1]):
12 | # success plot
13 | fig, ax = plt.subplots()
14 | ax.grid(b=True)
15 | ax.set_aspect(1)
16 | plt.xlabel('Overlap threshold')
17 | plt.ylabel('Success rate')
18 | if attr == 'ALL':
19 | plt.title(r'\textbf{Success plots of OPE on %s}' % (name))
20 | else:
21 | plt.title(r'\textbf{Success plots of OPE - %s}' % (attr))
22 | plt.axis([0, 1]+axis)
23 | success = {}
24 | thresholds = np.arange(0, 1.05, 0.05)
25 | for tracker_name in success_ret.keys():
26 | value = [v for k, v in success_ret[tracker_name].items() if k in videos]
27 | success[tracker_name] = np.mean(value)
28 | for idx, (tracker_name, auc) in \
29 | enumerate(sorted(success.items(), key=lambda x:x[1], reverse=True)):
30 | if tracker_name == bold_name:
31 | label = r"\textbf{[%.3f] %s}" % (auc, tracker_name)
32 | else:
33 | label = "[%.3f] " % (auc) + tracker_name
34 | value = [v for k, v in success_ret[tracker_name].items() if k in videos]
35 | plt.plot(thresholds, np.mean(value, axis=0),
36 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2)
37 | ax.legend(loc='lower left', labelspacing=0.2)
38 | ax.autoscale(enable=True, axis='both', tight=True)
39 | xmin, xmax, ymin, ymax = plt.axis()
40 | ax.autoscale(enable=False)
41 | ymax += 0.03
42 | plt.axis([xmin, xmax, ymin, ymax])
43 | plt.xticks(np.arange(xmin, xmax+0.01, 0.1))
44 | plt.yticks(np.arange(ymin, ymax, 0.1))
45 | ax.set_aspect((xmax - xmin)/(ymax-ymin))
46 | plt.show()
47 |
48 | if precision_ret:
49 |         # precision plot
50 | fig, ax = plt.subplots()
51 | ax.grid(b=True)
52 | ax.set_aspect(50)
53 | plt.xlabel('Location error threshold')
54 | plt.ylabel('Precision')
55 | if attr == 'ALL':
56 | plt.title(r'\textbf{Precision plots of OPE on %s}' % (name))
57 | else:
58 | plt.title(r'\textbf{Precision plots of OPE - %s}' % (attr))
59 | plt.axis([0, 50]+axis)
60 | precision = {}
61 | thresholds = np.arange(0, 51, 1)
62 | for tracker_name in precision_ret.keys():
63 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos]
64 | precision[tracker_name] = np.mean(value, axis=0)[20]
65 | for idx, (tracker_name, pre) in \
66 | enumerate(sorted(precision.items(), key=lambda x:x[1], reverse=True)):
67 | if tracker_name == bold_name:
68 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name)
69 | else:
70 | label = "[%.3f] " % (pre) + tracker_name
71 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos]
72 | plt.plot(thresholds, np.mean(value, axis=0),
73 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2)
74 | ax.legend(loc='lower right', labelspacing=0.2)
75 | ax.autoscale(enable=True, axis='both', tight=True)
76 | xmin, xmax, ymin, ymax = plt.axis()
77 | ax.autoscale(enable=False)
78 | ymax += 0.03
79 | plt.axis([xmin, xmax, ymin, ymax])
80 | plt.xticks(np.arange(xmin, xmax+0.01, 5))
81 | plt.yticks(np.arange(ymin, ymax, 0.1))
82 | ax.set_aspect((xmax - xmin)/(ymax-ymin))
83 | plt.show()
84 |
85 | # norm precision plot
86 | if norm_precision_ret:
87 | fig, ax = plt.subplots()
88 | ax.grid(b=True)
89 | plt.xlabel('Location error threshold')
90 | plt.ylabel('Precision')
91 | if attr == 'ALL':
92 | plt.title(r'\textbf{Normalized Precision plots of OPE on %s}' % (name))
93 | else:
94 | plt.title(r'\textbf{Normalized Precision plots of OPE - %s}' % (attr))
95 | norm_precision = {}
96 | thresholds = np.arange(0, 51, 1) / 100
97 |         for tracker_name in norm_precision_ret.keys():
98 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos]
99 | norm_precision[tracker_name] = np.mean(value, axis=0)[20]
100 | for idx, (tracker_name, pre) in \
101 | enumerate(sorted(norm_precision.items(), key=lambda x:x[1], reverse=True)):
102 | if tracker_name == bold_name:
103 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name)
104 | else:
105 | label = "[%.3f] " % (pre) + tracker_name
106 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos]
107 | plt.plot(thresholds, np.mean(value, axis=0),
108 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2)
109 | ax.legend(loc='lower right', labelspacing=0.2)
110 | ax.autoscale(enable=True, axis='both', tight=True)
111 | xmin, xmax, ymin, ymax = plt.axis()
112 | ax.autoscale(enable=False)
113 | ymax += 0.03
114 | plt.axis([xmin, xmax, ymin, ymax])
115 | plt.xticks(np.arange(xmin, xmax+0.01, 0.05))
116 | plt.yticks(np.arange(ymin, ymax, 0.1))
117 | ax.set_aspect((xmax - xmin)/(ymax-ymin))
118 | plt.show()
119 |
--------------------------------------------------------------------------------
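A minimal sketch of producing these plots for an OTB-style dataset, following how tools/eval.py combines OPEBenchmark with this function; the class and method names used here (eval_success, eval_precision, dataset.attr) mirror that usage and should be treated as assumptions if your checkout differs, and the paths are placeholders.

    from toolkit.datasets import OTBDataset
    from toolkit.evaluation import OPEBenchmark
    from toolkit.visualization import draw_success_precision

    trackers = ['siamrpn_r50_l234_dwxcorr_otb']                  # placeholder name
    dataset = OTBDataset('OTB100', 'testing_dataset/OTB100')     # placeholder root
    dataset.set_tracker('results/OTB100', trackers)
    benchmark = OPEBenchmark(dataset)
    success_ret = benchmark.eval_success(trackers)
    precision_ret = benchmark.eval_precision(trackers)
    for attr, videos in dataset.attr.items():
        draw_success_precision(success_ret, name=dataset.name, videos=videos,
                               attr=attr, precision_ret=precision_ret)
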
/toolkit/visualization/draw_utils.py:
--------------------------------------------------------------------------------
1 |
2 | COLOR = ((1, 0, 0),
3 | (0, 1, 0),
4 | (1, 0, 1),
5 | (1, 1, 0),
6 | (0, 162/255, 232/255),
7 | (0.5, 0.5, 0.5),
8 | (0, 0, 1),
9 | (0, 1, 1),
10 | (136/255, 0 , 21/255),
11 | (255/255, 127/255, 39/255),
12 | (0, 0, 0))
13 |
14 | LINE_STYLE = ['-', '--', ':', '-', '--', ':', '-', '--', ':', '-', '--']
15 |
16 | MARKER_STYLE = ['o', 'v', '<', '*', 'D', 'x', '.', 'x', '<', '.', 'o']
17 |
--------------------------------------------------------------------------------
/tools/demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 |
6 | import os
7 | import argparse
8 |
9 | import cv2
10 | import torch
11 | import numpy as np
12 | from glob import glob
13 |
14 | from pysot.core.config import cfg
15 | from pysot.models.model_builder import ModelBuilder
16 | from pysot.tracker.tracker_builder import build_tracker
17 |
18 | torch.set_num_threads(1)
19 |
20 | parser = argparse.ArgumentParser(description='tracking demo')
21 | parser.add_argument('--config', type=str, help='config file')
22 | parser.add_argument('--snapshot', type=str, help='model name')
23 | parser.add_argument('--video_name', default='', type=str,
24 | help='videos or image files')
25 | args = parser.parse_args()
26 |
27 |
28 | def get_frames(video_name):
29 | if not video_name:
30 | cap = cv2.VideoCapture(0)
31 | # warmup
32 | for i in range(5):
33 | cap.read()
34 | while True:
35 | ret, frame = cap.read()
36 | if ret:
37 | yield frame
38 | else:
39 | break
40 | elif video_name.endswith('avi') or \
41 | video_name.endswith('mp4'):
42 | cap = cv2.VideoCapture(args.video_name)
43 | while True:
44 | ret, frame = cap.read()
45 | if ret:
46 | yield frame
47 | else:
48 | break
49 | else:
50 | images = glob(os.path.join(video_name, '*.jp*'))
51 | images = sorted(images,
52 | key=lambda x: int(x.split('/')[-1].split('.')[0]))
53 | for img in images:
54 | frame = cv2.imread(img)
55 | yield frame
56 |
57 |
58 | def main():
59 | # load config
60 | cfg.merge_from_file(args.config)
61 | cfg.CUDA = torch.cuda.is_available()
62 | device = torch.device('cuda' if cfg.CUDA else 'cpu')
63 |
64 | # create model
65 | model = ModelBuilder()
66 |
67 | # load model
68 | model.load_state_dict(torch.load(args.snapshot,
69 | map_location=lambda storage, loc: storage.cpu()))
70 | model.eval().to(device)
71 |
72 | # build tracker
73 | tracker = build_tracker(model)
74 |
75 | first_frame = True
76 | if args.video_name:
77 | video_name = args.video_name.split('/')[-1].split('.')[0]
78 | else:
79 | video_name = 'webcam'
80 | cv2.namedWindow(video_name, cv2.WND_PROP_FULLSCREEN)
81 | for frame in get_frames(args.video_name):
82 | if first_frame:
83 | try:
84 | init_rect = cv2.selectROI(video_name, frame, False, False)
85 | except:
86 | exit()
87 | tracker.init(frame, init_rect)
88 | first_frame = False
89 | else:
90 | outputs = tracker.track(frame)
91 | if 'polygon' in outputs:
92 | polygon = np.array(outputs['polygon']).astype(np.int32)
93 | cv2.polylines(frame, [polygon.reshape((-1, 1, 2))],
94 | True, (0, 255, 0), 3)
95 | mask = ((outputs['mask'] > cfg.TRACK.MASK_THERSHOLD) * 255)
96 | mask = mask.astype(np.uint8)
97 | mask = np.stack([mask, mask*255, mask]).transpose(1, 2, 0)
98 | frame = cv2.addWeighted(frame, 0.77, mask, 0.23, -1)
99 | else:
100 | bbox = list(map(int, outputs['bbox']))
101 | cv2.rectangle(frame, (bbox[0], bbox[1]),
102 | (bbox[0]+bbox[2], bbox[1]+bbox[3]),
103 | (0, 255, 0), 3)
104 | cv2.imshow(video_name, frame)
105 | cv2.waitKey(40)
106 |
107 |
108 | if __name__ == '__main__':
109 | main()
110 |
--------------------------------------------------------------------------------
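A typical invocation of the demo, derived from the argparse options above (the config/snapshot paths are placeholders and must point to a matching experiment and checkpoint; omitting --video_name switches to the webcam):

    python tools/demo.py \
        --config experiments/siamrpn_r50_l234_dwxcorr/config.yaml \
        --snapshot path/to/model.pth \
        --video_name path/to/video.mp4
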
/vot_iter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jensenzhoujh/DROL/4aebe575394bc035e9924c8711c7d5d76bfef37a/vot_iter/__init__.py
--------------------------------------------------------------------------------
/vot_iter/tracker_SiamRPNpp.m:
--------------------------------------------------------------------------------
1 |
2 | % error('Tracker not configured! Please edit the tracker_test.m file.'); % Remove this line after proper configuration
3 |
4 | % The human readable label for the tracker, used to identify the tracker in reports
5 | % If not set, it will be set to the same value as the identifier.
6 | % It does not have to be unique, but it is best that it is.
7 | tracker_label = ['SiamRPNpp'];
8 |
9 | % For Python implementations we have created a handy function that generates the appropriate
10 | % command that will run the python executable and execute the given script that includes your
11 | % tracker implementation.
12 | %
13 | % Please customize the line below by substituting the first argument with the name of the
14 | % script of your tracker (not the .py file but just the name of the script) and also provide the
15 | % path (or multiple paths) where the tracker sources are found as the elements of the cell
16 | % array (second argument).
17 | setenv('MKL_NUM_THREADS','1');
18 | pysot_root = 'path/to/pysot';
19 | track_build_path = 'path/to/track/build';
20 | tracker_command = generate_python_command('vot_iter.vot_iter', {pysot_root; [track_build_path '/python/lib']})
21 |
22 | tracker_interpreter = 'python';
23 |
24 | tracker_linkpath = {track_build_path};
25 |
26 | % tracker_linkpath = {}; % A cell array of custom library directories used by the tracker executable (optional)
27 |
28 |
--------------------------------------------------------------------------------
/vot_iter/vot.py:
--------------------------------------------------------------------------------
1 | """
2 | \file vot.py
3 |
4 | @brief Python utility functions for VOT integration
5 |
6 | @author Luka Cehovin, Alessio Dore
7 |
8 | @date 2016
9 |
10 | """
11 |
12 | import sys
13 | import copy
14 | import collections
15 |
16 | try:
17 | import trax
18 | except ImportError:
19 | raise Exception('TraX support not found. Please add trax module to Python path.')
20 |
21 | Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height'])
22 | Point = collections.namedtuple('Point', ['x', 'y'])
23 | Polygon = collections.namedtuple('Polygon', ['points'])
24 |
25 | class VOT(object):
26 | """ Base class for Python VOT integration """
27 | def __init__(self, region_format, channels=None):
28 | """ Constructor
29 |
30 | Args:
31 | region_format: Region format options
32 | """
33 | assert(region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON])
34 |
35 | if channels is None:
36 | channels = ['color']
37 | elif channels == 'rgbd':
38 | channels = ['color', 'depth']
39 | elif channels == 'rgbt':
40 | channels = ['color', 'ir']
41 | elif channels == 'ir':
42 | channels = ['ir']
43 | else:
44 | raise Exception('Illegal configuration {}.'.format(channels))
45 |
46 | self._trax = trax.Server([region_format], [trax.Image.PATH], channels)
47 |
48 | request = self._trax.wait()
49 | assert(request.type == 'initialize')
50 | if isinstance(request.region, trax.Polygon):
51 | self._region = Polygon([Point(x[0], x[1]) for x in request.region])
52 | else:
53 | self._region = Rectangle(*request.region.bounds())
54 | self._image = [x.path() for k, x in request.image.items()]
55 | if len(self._image) == 1:
56 | self._image = self._image[0]
57 |
58 | self._trax.status(request.region)
59 |
60 | def region(self):
61 | """
62 | Send configuration message to the client and receive the initialization
63 | region and the path of the first image
64 |
65 | Returns:
66 | initialization region
67 | """
68 |
69 | return self._region
70 |
71 | def report(self, region, confidence = None):
72 | """
73 | Report the tracking results to the client
74 |
75 | Arguments:
76 | region: region for the frame
77 | """
78 | assert(isinstance(region, Rectangle) or isinstance(region, Polygon))
79 | if isinstance(region, Polygon):
80 | tregion = trax.Polygon.create([(x.x, x.y) for x in region.points])
81 | else:
82 | tregion = trax.Rectangle.create(region.x, region.y, region.width, region.height)
83 | properties = {}
84 |         if confidence is not None:
85 | properties['confidence'] = confidence
86 | self._trax.status(tregion, properties)
87 |
88 | def frame(self):
89 | """
90 | Get a frame (image path) from client
91 |
92 | Returns:
93 | absolute path of the image
94 | """
95 | if hasattr(self, "_image"):
96 | image = self._image
97 | del self._image
98 | return image
99 |
100 | request = self._trax.wait()
101 |
102 | if request.type == 'frame':
103 | image = [x.path() for k, x in request.image.items()]
104 | if len(image) == 1:
105 | return image[0]
106 | return image
107 | else:
108 | return None
109 |
110 |
111 | def quit(self):
112 | if hasattr(self, '_trax'):
113 | self._trax.quit()
114 |
115 | def __del__(self):
116 | self.quit()
117 |
118 |
--------------------------------------------------------------------------------
/vot_iter/vot_iter.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import os
6 | from os.path import join
7 |
8 | from pysot.core.config import cfg
9 | from pysot.models.model_builder import ModelBuilder
10 | from pysot.tracker.tracker_builder import build_tracker
11 | from pysot.utils.bbox import get_axis_aligned_bbox
12 | from pysot.utils.model_load import load_pretrain
13 | from toolkit.datasets import DatasetFactory
14 | from toolkit.utils.region import vot_overlap, vot_float2str
15 |
16 | from . import vot
17 | from .vot import Rectangle, Polygon, Point
18 |
19 |
20 | # modify root
21 |
22 | cfg_root = "path/to/expr"
23 | model_file = join(cfg_root, 'model.pth')
24 | cfg_file = join(cfg_root, 'config.yaml')
25 |
26 | def warmup(model):
27 | for i in range(10):
28 | model.template(torch.FloatTensor(1,3,127,127).cuda())
29 |
30 | def setup_tracker():
31 | cfg.merge_from_file(cfg_file)
32 |
33 | model = ModelBuilder()
34 | model = load_pretrain(model, model_file).cuda().eval()
35 |
36 | tracker = build_tracker(model)
37 | warmup(model)
38 | return tracker
39 |
40 |
41 | tracker = setup_tracker()
42 |
43 | handle = vot.VOT("polygon")
44 | region = handle.region()
45 | try:
46 | region = np.array([region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1],
47 | region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]])
48 | except:
49 | region = np.array(region)
50 |
51 | cx, cy, w, h = get_axis_aligned_bbox(region)
52 |
53 | image_file = handle.frame()
54 | if not image_file:
55 | sys.exit(0)
56 |
57 | im = cv2.imread(image_file) # HxWxC
58 | # init
59 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
60 | gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h]
61 | tracker.init(im, gt_bbox_)
62 |
63 | while True:
64 | img_file = handle.frame()
65 | if not img_file:
66 | break
67 | im = cv2.imread(img_file)
68 | outputs = tracker.track(im)
69 | pred_bbox = outputs['bbox']
70 | result = Rectangle(*pred_bbox)
71 | score = outputs['best_score']
72 | if cfg.MASK.MASK:
73 | pred_bbox = outputs['polygon']
74 | result = Polygon(Point(x[0], x[1]) for x in pred_bbox)
75 |
76 | handle.report(result, score)
77 |
--------------------------------------------------------------------------------