├── INSTALL.md
├── LICENSE
├── README.md
├── figure
├── SPS.png
└── flowchart.png
├── install.sh
├── ltr
├── README.md
├── __init__.py
├── __pycache__
│ └── __init__.cpython-36.pyc
├── actors
│ ├── __init__.py
│ ├── base_actor.py
│ └── bbreg.py
├── admin
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── environment.cpython-36.pyc
│ │ ├── loading.cpython-36.pyc
│ │ ├── local.cpython-36.pyc
│ │ ├── model_constructor.cpython-36.pyc
│ │ ├── settings.cpython-36.pyc
│ │ └── stats.cpython-36.pyc
│ ├── environment.py
│ ├── loading.py
│ ├── local.py
│ ├── model_constructor.py
│ ├── settings.py
│ ├── stats.py
│ └── tensorboard.py
├── data
│ ├── __init__.py
│ ├── image_loader.py
│ ├── loader.py
│ ├── processing.py
│ ├── processing_utils.py
│ ├── sampler.py
│ └── transforms.py
├── data_specs
│ ├── got10k_train_split.txt
│ ├── got10k_val_split.txt
│ └── lasot_train_split.txt
├── dataset
│ ├── __init__.py
│ ├── base_dataset.py
│ ├── coco_seq.py
│ ├── got10k.py
│ ├── imagenetvid.py
│ ├── lasot.py
│ └── tracking_net.py
├── external
│ └── PreciseRoIPooling
│ │ ├── .gitignore
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── _assets
│ │ └── prroi_visualization.png
│ │ ├── pytorch
│ │ ├── prroi_pool
│ │ │ ├── .gitignore
│ │ │ ├── __init__.py
│ │ │ ├── build.py
│ │ │ ├── functional.py
│ │ │ ├── prroi_pool.py
│ │ │ ├── src
│ │ │ │ ├── prroi_pooling_gpu.c
│ │ │ │ ├── prroi_pooling_gpu.h
│ │ │ │ ├── prroi_pooling_gpu_impl.cu
│ │ │ │ └── prroi_pooling_gpu_impl.cuh
│ │ │ └── travis.sh
│ │ └── tests
│ │ │ └── test_prroi_pooling2d.py
│ │ └── src
│ │ ├── prroi_pooling_gpu_impl.cu
│ │ └── prroi_pooling_gpu_impl.cuh
├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-36.pyc
│ ├── backbone
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── resnet.cpython-36.pyc
│ │ │ └── resnet18_vggm.cpython-36.pyc
│ │ ├── resnet.py
│ │ └── resnet18_vggm.py
│ ├── bbreg
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── atom.cpython-36.pyc
│ │ │ └── atom_iou_net.cpython-36.pyc
│ │ ├── atom.py
│ │ └── atom_iou_net.py
│ └── layers
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ └── blocks.cpython-36.pyc
│ │ └── blocks.py
├── run_training.py
├── train_settings
│ ├── __init__.py
│ └── bbreg
│ │ ├── __init__.py
│ │ └── atom_default.py
└── trainers
│ ├── __init__.py
│ ├── base_trainer.py
│ └── ltr_trainer.py
└── pytracking
├── README.md
├── __init__.py
├── __pycache__
├── __init__.cpython-36.pyc
├── __init__.cpython-37.pyc
├── run_tracker.cpython-36.pyc
└── run_webcam.cpython-36.pyc
├── evaluation
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── data.cpython-36.pyc
│ ├── environment.cpython-36.pyc
│ ├── got10kdataset.cpython-36.pyc
│ ├── lasotdataset.cpython-36.pyc
│ ├── local.cpython-36.pyc
│ ├── nfsdataset.cpython-36.pyc
│ ├── otbdataset.cpython-36.pyc
│ ├── running.cpython-36.pyc
│ ├── tpldataset.cpython-36.pyc
│ ├── tracker.cpython-36.pyc
│ ├── trackingnetdataset.cpython-36.pyc
│ ├── uavdataset.cpython-36.pyc
│ └── votdataset.cpython-36.pyc
├── data.py
├── environment.py
├── got10kdataset.py
├── lasotdataset.py
├── local.py
├── nfsdataset.py
├── otbdataset.py
├── running.py
├── tpldataset.py
├── tracker.py
├── trackingnetdataset.py
├── uavdataset.py
├── utils
│ ├── __init__.py
│ ├── __pycache__
│ │ └── __init__.cpython-36.pyc
│ ├── anchors.py
│ ├── bbox_helper.py
│ ├── benchmark_helper.py
│ ├── config_helper.py
│ ├── load_helper.py
│ ├── log_helper.py
│ ├── pysot
│ │ ├── __init__.py
│ │ ├── datasets
│ │ │ ├── __init__.py
│ │ │ ├── dataset.py
│ │ │ ├── video.py
│ │ │ └── vot.py
│ │ ├── evaluation
│ │ │ ├── __init__.py
│ │ │ ├── ar_benchmark.py
│ │ │ └── eao_benchmark.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── build
│ │ │ └── temp.linux-x86_64-3.6
│ │ │ │ ├── region.o
│ │ │ │ └── src
│ │ │ │ └── region.o
│ │ │ ├── c_region.pxd
│ │ │ ├── misc.py
│ │ │ ├── region.c
│ │ │ ├── region.cpython-36m-x86_64-linux-gnu.so
│ │ │ ├── region.pyx
│ │ │ ├── setup.py
│ │ │ ├── src
│ │ │ ├── buffer.h
│ │ │ ├── region.c
│ │ │ └── region.h
│ │ │ └── statistics.py
│ ├── pyvotkit
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ └── __init__.cpython-36.pyc
│ │ ├── build
│ │ │ └── temp.linux-x86_64-3.6
│ │ │ │ ├── region.o
│ │ │ │ └── src
│ │ │ │ └── region.o
│ │ ├── c_region.pxd
│ │ ├── region.c
│ │ ├── region.cpython-36m-x86_64-linux-gnu.so
│ │ ├── region.pyx
│ │ ├── setup.py
│ │ └── src
│ │ │ ├── buffer.h
│ │ │ ├── region.c
│ │ │ └── region.h
│ └── tracker_config.py
└── votdataset.py
├── experiments
├── __init__.py
└── myexperiments.py
├── features
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── augmentation.cpython-36.pyc
│ ├── deep.cpython-36.pyc
│ ├── extractor.cpython-36.pyc
│ ├── featurebase.cpython-36.pyc
│ └── preprocessing.cpython-36.pyc
├── augmentation.py
├── color.py
├── deep.py
├── extractor.py
├── featurebase.py
├── preprocessing.py
└── util.py
├── libs
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── __init__.cpython-37.pyc
│ ├── complex.cpython-36.pyc
│ ├── dcf.cpython-36.pyc
│ ├── fourier.cpython-36.pyc
│ ├── operation.cpython-36.pyc
│ ├── optimization.cpython-36.pyc
│ ├── tensordict.cpython-36.pyc
│ ├── tensorlist.cpython-36.pyc
│ └── tensorlist.cpython-37.pyc
├── complex.py
├── dcf.py
├── fourier.py
├── operation.py
├── optimization.py
├── tensordict.py
└── tensorlist.py
├── parameter
├── SPSTracker
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ └── default_vot.cpython-36.pyc
│ ├── default.py
│ └── default_vot.py
├── __init__.py
├── __pycache__
│ └── __init__.cpython-36.pyc
└── atom
│ ├── __init__.py
│ ├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── default.cpython-36.pyc
│ └── default_vot.cpython-36.pyc
│ ├── default.py
│ └── default_vot.py
├── run_experiment.py
├── run_tracker.py
├── run_video.py
├── run_webcam.py
├── tracker
├── SPSTracker
│ ├── SPSTracker.py
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── SPSTracker.cpython-36.pyc
│ │ ├── __init__.cpython-36.pyc
│ │ └── optim.cpython-36.pyc
│ └── optim.py
├── __init__.py
├── __pycache__
│ └── __init__.cpython-36.pyc
├── atom
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── atom.cpython-36.pyc
│ │ └── optim.cpython-36.pyc
│ ├── atom.py
│ └── optim.py
└── base
│ ├── __init__.py
│ ├── __pycache__
│ ├── __init__.cpython-36.pyc
│ └── basetracker.cpython-36.pyc
│ └── basetracker.py
├── util
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── anchors.cpython-36.pyc
│ ├── bbox_helper.cpython-36.pyc
│ ├── config_helper.cpython-36.pyc
│ ├── load_helper.cpython-36.pyc
│ └── tracker_config.cpython-36.pyc
├── anchors.py
├── bbox_helper.py
├── benchmark_helper.py
├── config_helper.py
├── load_helper.py
├── log_helper.py
├── pysot
│ ├── __init__.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ ├── video.py
│ │ └── vot.py
│ ├── evaluation
│ │ ├── __init__.py
│ │ ├── ar_benchmark.py
│ │ └── eao_benchmark.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── build
│ │ └── temp.linux-x86_64-3.6
│ │ │ ├── region.o
│ │ │ └── src
│ │ │ └── region.o
│ │ ├── c_region.pxd
│ │ ├── misc.py
│ │ ├── region.c
│ │ ├── region.cpython-36m-x86_64-linux-gnu.so
│ │ ├── region.pyx
│ │ ├── setup.py
│ │ ├── src
│ │ ├── buffer.h
│ │ ├── region.c
│ │ └── region.h
│ │ └── statistics.py
├── pyvotkit
│ ├── __init__.py
│ ├── build
│ │ └── temp.linux-x86_64-3.6
│ │ │ ├── region.o
│ │ │ └── src
│ │ │ └── region.o
│ ├── c_region.pxd
│ ├── region.c
│ ├── region.cpython-36m-x86_64-linux-gnu.so
│ ├── region.pyx
│ ├── setup.py
│ └── src
│ │ ├── buffer.h
│ │ ├── region.c
│ │ └── region.h
└── tracker_config.py
└── utils
├── __init__.py
├── __pycache__
├── __init__.cpython-36.pyc
├── params.cpython-36.pyc
└── plotting.cpython-36.pyc
├── gdrive_download
├── params.py
└── plotting.py
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | This document contains detailed instructions for installing the necessary dependencies for PyTracking. The instructions have been tested on an Ubuntu 18.04 system. We recommend using the [install script](install.sh) if you have not already tried it.
4 |
5 | ### Requirements
6 | * Conda installation with Python 3.7. If not already installed, install from https://www.anaconda.com/distribution/.
7 | * Nvidia GPU.
8 |
9 | ## Step-by-step instructions
10 | #### Create and activate a conda environment
11 | ```bash
12 | conda create --name pytracking python=3.7
13 | conda activate pytracking
14 | ```
15 |
16 | #### Install PyTorch
17 | Install PyTorch 0.4.1 with cuda92.
18 | ```bash
19 | conda install pytorch=0.4.1 torchvision cuda92 -c pytorch
20 | ```
21 |
22 | **Note:**
23 | - PyTorch 1.0 should be supported, but **not recommended** as it requires an [alternate compilation](https://github.com/vacancy/PreciseRoIPooling) of the PreciseRoIPooling module which hasn't been tested.
24 | - It is possible to use any PyTorch supported version of CUDA (not necessarily 9.2).
25 | - For more details about PyTorch installation, see https://pytorch.org/get-started/previous-versions/.
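A quick sanity check of the PyTorch installation, run inside the activated environment:
```python
import torch

print(torch.__version__)           # expect 0.4.1
print(torch.cuda.is_available())   # must print True, since a GPU is required
```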
26 |
27 | #### Install matplotlib, pandas, opencv and tensorboardX
28 | ```bash
29 | conda install matplotlib=2.2.2 pandas
30 | pip install opencv-python tensorboardX
31 | ```
32 |
33 |
34 | #### Install the coco toolkit
35 | If you want to use the COCO dataset for training, install the COCO python toolkit. You additionally need to install cython to compile the COCO toolkit.
36 | ```bash
37 | conda install cython
38 | pip install pycocotools
39 | ```
40 |
41 |
42 | #### Compile Precise ROI pooling
43 | To compile the Precise ROI pooling module (https://github.com/vacancy/PreciseRoIPooling) for PyTorch 0.4.1, go to the directory "ltr/external/PreciseRoIPooling/pytorch/prroi_pool" and run the "travis.sh" script.
44 | You may additionally have to add the CUDA installation's bin directory to your PATH.
45 | ```bash
46 | cd ltr/external/PreciseRoIPooling/pytorch/prroi_pool
47 |
48 | # Export the path to the cuda installation
49 | PATH=/usr/local/cuda/bin/:$PATH
50 |
51 | # Compile Precise ROI Pool
52 | bash travis.sh
53 | ```
54 |
55 | In case of issues, we refer to https://github.com/vacancy/PreciseRoIPooling.
56 |
57 |
58 | #### Install jpeg4py
59 | In order to use [jpeg4py](https://github.com/ajkxyz/jpeg4py) for loading the images instead of OpenCV's imread(), install jpeg4py as follows:
60 | ```bash
61 | sudo apt-get install libturbojpeg
62 | pip install jpeg4py
63 | ```
64 |
65 | **Note:** The first step (```sudo apt-get install libturbojpeg```) is optional and can be skipped, in which case OpenCV's imread() will be used to read the images. However, the second step is required.
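To quickly verify that jpeg4py can decode images (the image path below is only a placeholder), you can run:
```python
import jpeg4py

img = jpeg4py.JPEG('/path/to/any_image.jpg').decode()   # decoded RGB numpy array
print(img.shape)
```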
66 |
67 | In case of issues, we refer to https://github.com/ajkxyz/jpeg4py.
68 |
69 |
70 | #### Setup the environment
71 | Create the default environment setting files.
72 | ```bash
73 | # Environment settings for pytracking. Saved at pytracking/evaluation/local.py
74 | python -c "from pytracking.evaluation.environment import create_default_local_file; create_default_local_file()"
75 |
76 | # Environment settings for ltr. Saved at ltr/admin/local.py
77 | python -c "from ltr.admin.environment import create_default_local_file; create_default_local_file()"
78 | ```
79 |
80 | You can modify these files to set the paths to datasets, results paths etc.
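For example, an edited "ltr/admin/local.py" could look as follows. The attribute names match the generated file; the paths are placeholders that you should replace with your own:
```python
class EnvironmentSettings:
    def __init__(self):
        self.workspace_dir = '/home/user/pytracking_workspace'   # checkpoints are saved here
        self.tensorboard_dir = self.workspace_dir + '/tensorboard/'
        self.lasot_dir = '/data/LaSOT'
        self.got10k_dir = '/data/GOT-10k'
        self.trackingnet_dir = '/data/TrackingNet'
        self.coco_dir = '/data/coco'
        self.imagenet_dir = ''        # leave empty if a dataset is not used
        self.imagenetdet_dir = ''
```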
81 |
82 |
83 | #### Download the pre-trained networks
84 | You can download the pre-trained networks from the [google drive folder](https://drive.google.com/drive/folders/1WVhJqvdu-_JG1U-V0IqfxTUa1SBPnL0O). The networks should be saved in the directory set by "network_path" in "pytracking/evaluation/local.py". By default, it is set to pytracking/networks.
85 | You can also download the networks using the gdrive_download bash script.
86 |
87 | ```bash
88 | # Download the default network for ATOM
89 | bash pytracking/utils/gdrive_download 1JUB3EucZfBk3rX7M3_q5w_dLBqsT7s-M pytracking/networks/atom_default.pth
90 |
91 | ```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [SPSTracker: Sub-Peak Suppression of Response Map for Robust Object Tracking](https://arxiv.org/abs/1912.00597).
2 |
3 |
4 | ## News
5 | * \[2019/11/11\] SPSTracker has been accepted as an oral paper at the Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI-20).
6 | ## Introduction
7 | This is the official code of SPSTracker: Sub-Peak Suppression of Response Map for Robust Object Tracking. We propose a simple yet effective approach, referred to as SPSTracker, for robust object tracking. Our motivation is based on the observation that most tracking failures are caused by interference around the target. Such interference produces a multi-peak tracking response, and a sub-peak may progressively "grow" and eventually cause model drift. We therefore propose suppressing the sub-peaks to aggregate a single-peak response, with the aim of preventing model drift from the perspective of tracking-response regularization.
8 | ![](figure/SPS.png)
9 | ![](figure/flowchart.png)
10 |
11 | ## Installation
12 | Code is implemented upon the ATOM architecture, check [INSTALL.md](INSTALL.md) for installation instructions.
13 |
14 | ## [Results](https://drive.google.com/drive/folders/1bX_5fcm2EfeZv5dx3L8CwhMEHGwrGXvV)[[Raw result](https://drive.google.com/open?id=1IPpOGYU6r5Dmz20PcrfOBtMCJ60C6q3Y)]
15 |
16 | | Tracker | VOT2016 EAO / A / R | VOT2018 EAO / A / R |
17 | |:----------------------------------------------------------------------:|:--------------------------------------------:|:--------------------------------------------:|
18 | | SPSTracker | 0.459/0.625/0.158 | 0.434/0.612/0.169 |
19 |
20 | ## Run
21 | python pytracking/run_tracker.py SPSTracker default_vot --dataset vot --debug 1 --threads 0
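The evaluation can presumably also be launched from Python. The sketch below assumes the ATOM-style helpers in pytracking/evaluation (`Tracker`, `VOTDataset` and `run_dataset`); their exact import paths and signatures may differ slightly in this repository:
```python
# Rough sketch only -- assumes the ATOM/pytracking-style evaluation API.
from pytracking.evaluation.tracker import Tracker
from pytracking.evaluation.votdataset import VOTDataset
from pytracking.evaluation.running import run_dataset

trackers = [Tracker('SPSTracker', 'default_vot')]   # tracker module and parameter file
dataset = VOTDataset()                              # VOT sequences (paths come from local.py)
run_dataset(dataset, trackers, debug=1, threads=0)
```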
22 |
23 |
24 | ## Citations
25 | Please consider citing our paper in your publications if the project helps your research.
26 | ```
27 | @inproceedings{hu2020spstracker,
28 | title = {{SPSTracker}: Sub-Peak Suppression of Response Map for Robust Object Tracking},
29 | author = {Qintao Hu and Lijun Zhou and Xiaoxiao Wang and Yao Mao and Jianlin Zhang and Qixiang Ye},
30 | booktitle = {Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI)},
31 | year = {2020}
32 | }
33 | ```
34 |
--------------------------------------------------------------------------------
/figure/SPS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/figure/SPS.png
--------------------------------------------------------------------------------
/figure/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/figure/flowchart.png
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ "$#" -ne 2 ]; then
4 | echo "ERROR! Illegal number of parameters. Usage: bash install.sh conda_install_path environment_name"
5 | exit 1
6 | fi
7 |
8 | conda_install_path=$1
9 | conda_env_name=$2
10 |
11 | source $conda_install_path/etc/profile.d/conda.sh
12 | echo "****************** Creating conda environment ${conda_env_name} python=3.7 ******************"
13 | conda create -y --name $conda_env_name python=3.7
14 |
15 | echo ""
16 | echo ""
17 | echo "****************** Activating conda environment ${conda_env_name} ******************"
18 | conda activate $conda_env_name
19 |
20 | echo ""
21 | echo ""
22 | echo "****************** Installing pytorch 0.4.1 with cuda92 ******************"
23 | conda install -y pytorch=0.4.1 torchvision cuda92 -c pytorch
24 |
25 | echo ""
26 | echo ""
27 | echo "****************** Installing matplotlib 2.2.2 ******************"
28 | conda install -y matplotlib=2.2.2
29 |
30 | echo ""
31 | echo ""
32 | echo "****************** Installing pandas ******************"
33 | conda install -y pandas
34 |
35 | echo ""
36 | echo ""
37 | echo "****************** Installing opencv ******************"
38 | pip install opencv-python
39 |
40 | echo ""
41 | echo ""
42 | echo "****************** Installing tensorboardX ******************"
43 | pip install tensorboardX
44 |
45 | echo ""
46 | echo ""
47 | echo "****************** Installing cython ******************"
48 | conda install -y cython
49 |
50 | echo ""
51 | echo ""
52 | echo "****************** Installing coco toolkit ******************"
53 | pip install pycocotools
54 |
55 | echo ""
56 | echo ""
57 | echo "****************** Installing jpeg4py python wrapper ******************"
58 | pip install jpeg4py
59 |
60 | echo ""
61 | echo ""
62 | echo "****************** Installing PreROIPooling ******************"
63 | base_dir=$(pwd)
64 | cd ltr/external/PreciseRoIPooling/pytorch/prroi_pool
65 | PATH=/usr/local/cuda/bin/:$PATH
66 | bash travis.sh
67 | cd $base_dir
68 |
69 | echo ""
70 | echo ""
71 | echo "****************** Downloading networks ******************"
72 | mkdir pytracking/networks
73 |
74 | echo ""
75 | echo ""
76 | echo "****************** ATOM Network ******************"
77 | bash pytracking/utils/gdrive_download 1ZTdQbZ1tyN27UIwUnUrjHChQb5ug2sxr pytracking/networks/atom_default.pth
78 |
79 | echo ""
80 | echo ""
81 | echo "****************** ECO Network ******************"
82 | bash pytracking/utils/gdrive_download 1aWC4waLv_te-BULoy0k-n_zS-ONms21S pytracking/networks/resnet18_vggmconv1.pth
83 |
84 | echo ""
85 | echo ""
86 | echo "****************** Setting up environment ******************"
87 | python -c "from pytracking.evaluation.environment import create_default_local_file; create_default_local_file()"
88 | python -c "from ltr.admin.environment import create_default_local_file; create_default_local_file()"
89 |
90 |
91 | echo ""
92 | echo ""
93 | echo "****************** Installing jpeg4py ******************"
94 | while true; do
95 | read -p "Install libturbojpeg so that jpeg4py can be used for reading images? This step requires sudo privileges. It is optional, but recommended. [y,n] " install_flag
96 | case $install_flag in
97 | [Yy]* ) sudo apt-get install libturbojpeg; break;;
98 | [Nn]* ) echo "Skipping jpeg4py installation!"; break;;
99 | * ) echo "Please answer y or n ";;
100 | esac
101 | done
102 |
103 | echo ""
104 | echo ""
105 | echo "****************** Installation complete! ******************"
106 |
--------------------------------------------------------------------------------
/ltr/README.md:
--------------------------------------------------------------------------------
1 | # LTR
2 |
3 | A general PyTorch based framework for learning tracking representations. The repository contains the code for training the [**ATOM**](https://arxiv.org/pdf/1811.07628.pdf) tracker.
4 |
5 | ## Table of Contents
6 |
7 | * [Quick Start](#quick-start)
8 | * [Overview](#overview)
9 | * [Train Settings](#train-settings)
10 | * [Training your own networks](#training-your-own-networks)
11 |
12 | ## Quick Start
13 | The installation script will automatically generate a local configuration file "admin/local.py". In case the file was not generated, run ```admin.environment.create_default_local_file()``` to generate it. Next, set the paths to the training workspace,
14 | i.e. the directory where the checkpoints will be saved. Also set the paths to the datasets you want to use. If all the dependencies have been correctly installed, you can train a network using the run_training.py script in the correct conda environment.
15 | ```bash
16 | conda activate pytracking
17 | python run_training.py train_module train_name
18 | ```
19 | Here, ```train_module``` is the sub-module inside ```train_settings``` and ```train_name``` is the name of the train setting file to be used.
20 |
21 | For example, you can train using the included default ATOM settings by running:
22 | ```bash
23 | python run_training.py bbreg atom_default
24 | ```
25 |
26 |
27 | ## Overview
28 | The framework consists of the following sub-modules.
29 | - [actors](actors): Contains the actor classes for the different training procedures. The actor class is responsible for passing the input data through the network and calculating the losses.
30 | - [admin](admin): Includes functions for loading networks, tensorboard etc. and also contains environment settings.
31 | - [dataset](dataset): Contains integration of a number of training datasets, namely [TrackingNet](https://tracking-net.org/), [GOT-10k](http://got-10k.aitestunion.com/), [LaSOT](https://cis.temple.edu/lasot/),
32 | [ImageNet-VID](http://image-net.org/), and [COCO](http://cocodataset.org/#home).
33 | - [data_specs](data_specs): Information about train/val splits of different datasets.
34 | - [data](data): Contains functions for processing data, e.g. loading images, data augmentations, sampling frames from videos.
35 | - [external](external): External libraries needed for training. Added as submodules.
36 | - [models](models): Contains different layers and network definitions.
37 | - [trainers](trainers): The main class which runs the training.
38 | - [train_settings](train_settings): Contains settings files, specifying the training of a network.
39 |
40 | ## Train Settings
41 | The framework currently contains the following training settings:
42 | - [bbreg.atom_default](train_settings/bbreg/atom_default.py): The default settings used for training the network in [ATOM](https://arxiv.org/pdf/1811.07628.pdf).
43 |
44 |
45 | ## Training your own networks
46 | To train a custom network using the toolkit, the following components need to be specified in the train settings. For reference, see [atom_default.py](train_settings/bbreg/atom_default.py).
47 | - Datasets: The datasets to be used for training. A number of standard tracking datasets are already available in ```dataset``` module.
48 | - Processing: This function should perform the necessary post-processing of the data, e.g. cropping of target region, data augmentations etc.
49 | - Sampler: Determines how the frames are sampled from a video sequence to form the batches.
50 | - Network: The network module to be trained.
51 | - Objective: The training objective.
52 | - Actor: The trainer passes the training batch to the actor, which is responsible for passing the data through the network correctly and calculating the training loss (a minimal sketch follows this list).
53 | - Optimizer: Optimizer to be used, e.g. Adam.
54 | - Trainer: The main class which runs the epochs and saves checkpoints.
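As a minimal sketch of the actor component, the example below mirrors the [`AtomActor`](actors/bbreg.py) shipped with the framework. The data fields it reads ('train_images', 'test_proposals', 'proposal_iou', ...) are those produced by the ATOM processing and sampler, so adapt them to your own data pipeline:
```python
from ltr.actors import BaseActor


class MyIoUActor(BaseActor):
    """Example actor: forwards a batch through the network and computes the loss."""
    def __call__(self, data):
        # Predict an IoU value for every proposal in the batch.
        iou_pred = self.net(data['train_images'], data['test_images'],
                            data['train_anno'], data['test_proposals'])

        iou_pred = iou_pred.view(-1, iou_pred.shape[2])
        iou_gt = data['proposal_iou'].view(-1, data['proposal_iou'].shape[2])

        loss = self.objective(iou_pred, iou_gt)

        # The trainer logs everything returned in this stats dict.
        stats = {'Loss/total': loss.item()}
        return loss, stats
```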
55 |
56 |
57 |
--------------------------------------------------------------------------------
/ltr/__init__.py:
--------------------------------------------------------------------------------
1 | from .admin.loading import load_network
2 | from .admin.model_constructor import model_constructor
--------------------------------------------------------------------------------
/ltr/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/actors/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_actor import BaseActor
2 | from .bbreg import AtomActor
3 |
--------------------------------------------------------------------------------
/ltr/actors/base_actor.py:
--------------------------------------------------------------------------------
1 | from pytracking import TensorDict
2 |
3 |
4 | class BaseActor:
5 | """ Base class for actor. The actor class handles the passing of the data through the network
6 | and the calculation of the loss"""
7 | def __init__(self, net, objective):
8 | """
9 | args:
10 | net - The network to train
11 | objective - The loss function
12 | """
13 | self.net = net
14 | self.objective = objective
15 |
16 | def __call__(self, data: TensorDict):
17 | """ Called in each training iteration. Should pass in input data through the network, calculate the loss, and
18 | return the training stats for the input data
19 | args:
20 | data - A TensorDict containing all the necessary data blocks.
21 |
22 | returns:
23 | loss - loss for the input data
24 | stats - a dict containing detailed losses
25 | """
26 | raise NotImplementedError
27 |
28 | def to(self, device):
29 | """ Move the network to device
30 | args:
31 | device - device to use. 'cpu' or 'cuda'
32 | """
33 | self.net.to(device)
34 |
35 | def train(self, mode=True):
36 | """ Set whether the network is in train mode.
37 | args:
38 | mode (True) - Bool specifying whether in training mode.
39 | """
40 | self.net.train(mode)
41 |
42 | def eval(self):
43 | """ Set network to eval mode"""
44 | self.train(False)
--------------------------------------------------------------------------------
/ltr/actors/bbreg.py:
--------------------------------------------------------------------------------
1 | from . import BaseActor
2 |
3 |
4 | class AtomActor(BaseActor):
5 | """ Actor for training the IoU-Net in ATOM"""
6 | def __call__(self, data):
7 | """
8 | args:
9 | data - The input data, should contain the fields 'train_images', 'test_images', 'train_anno',
10 | 'test_proposals' and 'proposal_iou'.
11 |
12 | returns:
13 | loss - the training loss
14 | stats - dict containing detailed losses
15 | """
16 | # Run network to obtain IoU prediction for each proposal in 'test_proposals'
17 | iou_pred = self.net(data['train_images'], data['test_images'], data['train_anno'], data['test_proposals'])
18 |
19 | iou_pred = iou_pred.view(-1, iou_pred.shape[2])
20 | iou_gt = data['proposal_iou'].view(-1, data['proposal_iou'].shape[2])
21 |
22 | # Compute loss
23 | loss = self.objective(iou_pred, iou_gt)
24 |
25 | # Return training stats
26 | stats = {'Loss/total': loss.item(),
27 | 'Loss/iou': loss.item()}
28 |
29 | return loss, stats
--------------------------------------------------------------------------------
/ltr/admin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__init__.py
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/environment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/environment.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/loading.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/loading.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/local.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/local.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/model_constructor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/model_constructor.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/settings.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/settings.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/__pycache__/stats.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/stats.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/admin/environment.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | from collections import OrderedDict
4 |
5 |
6 | def create_default_local_file():
7 | path = os.path.join(os.path.dirname(__file__), 'local.py')
8 |
9 | empty_str = '\'\''
10 | default_settings = OrderedDict({
11 | 'workspace_dir': empty_str,
12 | 'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
13 | 'lasot_dir': empty_str,
14 | 'got10k_dir': empty_str,
15 | 'trackingnet_dir': empty_str,
16 | 'coco_dir': empty_str,
17 | 'imagenet_dir': empty_str,
18 | 'imagenetdet_dir': empty_str})
19 |
20 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
21 | 'tensorboard_dir': 'Directory for tensorboard files.'}
22 |
23 | with open(path, 'w') as f:
24 | f.write('class EnvironmentSettings:\n')
25 | f.write(' def __init__(self):\n')
26 |
27 | for attr, attr_val in default_settings.items():
28 | comment_str = None
29 | if attr in comment:
30 | comment_str = comment[attr]
31 | if comment_str is None:
32 | f.write(' self.{} = {}\n'.format(attr, attr_val))
33 | else:
34 | f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str))
35 |
36 |
37 | def env_settings():
38 | env_module_name = 'ltr.admin.local'
39 | try:
40 | env_module = importlib.import_module(env_module_name)
41 | return env_module.EnvironmentSettings()
42 | except:
43 | env_file = os.path.join(os.path.dirname(__file__), 'local.py')
44 |
45 | create_default_local_file()
46 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file))
47 |
--------------------------------------------------------------------------------
/ltr/admin/local.py:
--------------------------------------------------------------------------------
1 | class EnvironmentSettings:
2 | def __init__(self):
3 | self.workspace_dir = '' # Base directory for saving network checkpoints.
4 | self.tensorboard_dir = self.workspace_dir + '/tensorboard/' # Directory for tensorboard files.
5 | self.lasot_dir = ''
6 | self.got10k_dir = ''
7 | self.trackingnet_dir = ''
8 | self.coco_dir = ''
9 | self.imagenet_dir = ''
10 | self.imagenetdet_dir = ''
11 |
--------------------------------------------------------------------------------
/ltr/admin/model_constructor.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import importlib
3 |
4 |
5 | def model_constructor(f):
6 | """ Wraps the function 'f' which returns the network. An extra field 'constructor' is added to the network returned
7 | by 'f'. This field contains an instance of the 'NetConstructor' class, which contains the information needed to
8 | re-construct the network, such as the name of the function 'f', the function arguments etc. Thus, the network can
9 | be easily constructed from a saved checkpoint by calling NetConstructor.get() function.
10 | """
11 | @wraps(f)
12 | def f_wrapper(*args, **kwds):
13 | net_constr = NetConstructor(f.__name__, f.__module__, args, kwds)
14 | output = f(*args, **kwds)
15 | if isinstance(output, (tuple, list)):
16 | # Assume first argument is the network
17 | output[0].constructor = net_constr
18 | else:
19 | output.constructor = net_constr
20 | return output
21 | return f_wrapper
22 |
23 |
24 | class NetConstructor:
25 | """ Class to construct networks. Takes as input the function name (e.g. atom_resnet18), the name of the module
26 | which contains the network function (e.g. ltr.models.bbreg.atom) and the arguments for the network
27 | function. The class object can then be stored along with the network weights to re-construct the network."""
28 | def __init__(self, fun_name, fun_module, args, kwds):
29 | """
30 | args:
31 | fun_name - The function which returns the network
32 | fun_module - the module which contains the network function
33 | args - arguments which are passed to the network function
34 | kwds - arguments which are passed to the network function
35 | """
36 | self.fun_name = fun_name
37 | self.fun_module = fun_module
38 | self.args = args
39 | self.kwds = kwds
40 |
41 | def get(self):
42 | """ Rebuild the network by calling the network function with the correct arguments. """
43 | net_module = importlib.import_module(self.fun_module)
44 | net_fun = getattr(net_module, self.fun_name)
45 | return net_fun(*self.args, **self.kwds)
46 |
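# Illustrative usage sketch: `tiny_net` below is a hypothetical constructor, included only to
# show how the decorator is used and how a network is re-built from its stored NetConstructor.
if __name__ == '__main__':
    import torch.nn as nn

    @model_constructor
    def tiny_net(hidden_dim=8):
        return nn.Linear(4, hidden_dim)

    net = tiny_net(hidden_dim=16)
    rebuilt = net.constructor.get()   # calls tiny_net(hidden_dim=16) again
    print(rebuilt)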
--------------------------------------------------------------------------------
/ltr/admin/settings.py:
--------------------------------------------------------------------------------
1 | from ltr.admin.environment import env_settings
2 |
3 |
4 | class Settings:
5 | """ Training settings, e.g. the paths to datasets and networks."""
6 | def __init__(self):
7 | self.set_default()
8 |
9 | def set_default(self):
10 | self.env = env_settings()
11 | self.use_gpu = True
12 |
13 |
14 |
--------------------------------------------------------------------------------
/ltr/admin/stats.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class StatValue:
4 | def __init__(self):
5 | self.clear()
6 |
7 | def reset(self):
8 | self.val = 0
9 |
10 | def clear(self):
11 | self.reset()
12 | self.history = []
13 |
14 | def update(self, val):
15 | self.val = val
16 | self.history.append(self.val)
17 |
18 |
19 | class AverageMeter(object):
20 | """Computes and stores the average and current value"""
21 | def __init__(self):
22 | self.clear()
23 | self.has_new_data = False
24 |
25 | def reset(self):
26 | self.avg = 0
27 | self.val = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def clear(self):
32 | self.reset()
33 | self.history = []
34 |
35 | def update(self, val, n=1):
36 | self.val = val
37 | self.sum += val * n
38 | self.count += n
39 | self.avg = self.sum / self.count
40 |
41 | def new_epoch(self):
42 | if self.count > 0:
43 | self.history.append(self.avg)
44 | self.reset()
45 | self.has_new_data = True
46 | else:
47 | self.has_new_data = False
48 |
49 |
50 | def topk_accuracy(output, target, topk=(1,)):
51 | """Computes the precision@k for the specified values of k"""
52 | single_input = not isinstance(topk, (tuple, list))
53 | if single_input:
54 | topk = (topk,)
55 |
56 | maxk = max(topk)
57 | batch_size = target.size(0)
58 |
59 | _, pred = output.topk(maxk, 1, True, True)
60 | pred = pred.t()
61 | correct = pred.eq(target.view(1, -1).expand_as(pred))
62 |
63 | res = []
64 | for k in topk:
65 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0]
66 | res.append(correct_k * 100.0 / batch_size)
67 |
68 | if single_input:
69 | return res[0]
70 |
71 | return res
72 |
--------------------------------------------------------------------------------
/ltr/admin/tensorboard.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from tensorboardX import SummaryWriter
4 |
5 |
6 | class TensorboardWriter:
7 | def __init__(self, directory, loader_names):
8 | self.directory = directory
9 | self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names})
10 |
11 | def write_info(self, module_name, script_name, description):
12 | tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
13 | tb_info_writer.add_text('Module_name', module_name)
14 | tb_info_writer.add_text('Script_name', script_name)
15 | tb_info_writer.add_text('Description', description)
16 | tb_info_writer.close()
17 |
18 | def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
19 | for loader_name, loader_stats in stats.items():
20 | if loader_stats is None:
21 | continue
22 | for var_name, val in loader_stats.items():
23 | if hasattr(val, 'history') and getattr(val, 'has_new_data', True):
24 | self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)
--------------------------------------------------------------------------------
/ltr/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .loader import LTRLoader
--------------------------------------------------------------------------------
/ltr/data/image_loader.py:
--------------------------------------------------------------------------------
1 | import jpeg4py
2 | import cv2 as cv
3 |
4 |
5 | def default_image_loader(path):
6 | """The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
7 | but reverts to the opencv_loader if the former is not available."""
8 | if default_image_loader.use_jpeg4py is None:
9 | # Try using jpeg4py
10 | im = jpeg4py_loader(path)
11 | if im is None:
12 | default_image_loader.use_jpeg4py = False
13 | print('Using opencv_loader instead.')
14 | else:
15 | default_image_loader.use_jpeg4py = True
16 | return im
17 | if default_image_loader.use_jpeg4py:
18 | return jpeg4py_loader(path)
19 | return opencv_loader(path)
20 |
21 | default_image_loader.use_jpeg4py = None
22 |
23 |
24 | def jpeg4py_loader(path):
25 | """ Image reading using jpeg4py (https://github.com/ajkxyz/jpeg4py)"""
26 | try:
27 | return jpeg4py.JPEG(path).decode()
28 | except Exception as e:
29 | print('ERROR: Could not read image "{}"'.format(path))
30 | print(e)
31 | return None
32 |
33 |
34 | def opencv_loader(path):
35 | """ Read image using opencv's imread function and returns it in rgb format"""
36 | try:
37 | im = cv.imread(path, cv.IMREAD_COLOR)
38 | # convert to rgb and return
39 | return cv.cvtColor(im, cv.COLOR_BGR2RGB)
40 | except Exception as e:
41 | print('ERROR: Could not read image "{}"'.format(path))
42 | print(e)
43 | return None
44 |
--------------------------------------------------------------------------------
/ltr/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .lasot import Lasot
2 | from .got10k import Got10k
3 | from .tracking_net import TrackingNet
4 | from .imagenetvid import ImagenetVID
5 | from .coco_seq import MSCOCOSeq
6 |
--------------------------------------------------------------------------------
/ltr/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from ltr.data.image_loader import default_image_loader
3 |
4 |
5 | class BaseDataset(torch.utils.data.Dataset):
6 | """ Base class for datasets """
7 |
8 | def __init__(self, root, image_loader=default_image_loader):
9 | """
10 | args:
11 | root - The root path to the dataset
12 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
13 | is used by default.
14 | """
15 | if root == '':
16 | raise Exception('The dataset path is not set up. Check your "ltr/admin/local.py".')
17 | self.root = root
18 | self.image_loader = image_loader
19 |
20 | self.sequence_list = [] # Contains the list of sequences.
21 |
22 | def __len__(self):
23 | """ Returns size of the dataset
24 | returns:
25 | int - number of samples in the dataset
26 | """
27 | return self.get_num_sequences()
28 |
29 | def __getitem__(self, index):
30 | """ Not to be used! Check get_frames() instead.
31 | """
32 | return None
33 |
34 | def is_video_sequence(self):
35 | """ Returns whether the dataset is a video dataset or an image dataset
36 |
37 | returns:
38 | bool - True if a video dataset
39 | """
40 | return True
41 |
42 | def get_name(self):
43 | """ Name of the dataset
44 |
45 | returns:
46 | string - Name of the dataset
47 | """
48 | raise NotImplementedError
49 |
50 | def get_num_sequences(self):
51 | """ Number of sequences in a dataset
52 |
53 | returns:
54 | int - number of sequences in the dataset."""
55 | return len(self.sequence_list)
56 |
57 | def get_sequence_info(self, seq_id):
58 | """ Returns information about a particular sequences,
59 |
60 | args:
61 | seq_id - index of the sequence
62 |
63 | returns:
64 | Tensor - Annotation for the sequence. A 2d tensor of shape (num_frames, 4).
65 | Format [top_left_x, top_left_y, width, height]
66 | Tensor - 1d Tensor specifying whether the target is present (=1) for each frame. Shape (num_frames,)
67 | """
68 | raise NotImplementedError
69 |
70 | def get_frames(self, seq_id, frame_ids, anno=None):
71 | """ Get a set of frames from a particular sequence
72 |
73 | args:
74 | seq_id - index of sequence
75 | frame_ids - a list of frame numbers
76 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
77 |
78 | returns:
79 | list - List of frames corresponding to frame_ids
80 | list - List of annotations (tensor of shape (4,)) for each frame
81 | dict - A dict containing meta information about the sequence, e.g. class of the target object.
82 |
83 | """
84 | raise NotImplementedError
85 |
86 |
--------------------------------------------------------------------------------
/ltr/dataset/coco_seq.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .base_dataset import BaseDataset
3 | from ltr.data.image_loader import default_image_loader
4 | import torch
5 | from pycocotools.coco import COCO
6 | from collections import OrderedDict
7 | from ltr.admin.environment import env_settings
8 |
9 |
10 | class MSCOCOSeq(BaseDataset):
11 | """ The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1.
12 |
13 | Publication:
14 | Microsoft COCO: Common Objects in Context.
15 | Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona,
16 | Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick
17 | ECCV, 2014
18 | https://arxiv.org/pdf/1405.0312.pdf
19 |
20 | Download the images along with annotations from http://cocodataset.org/#download. The root folder should be
21 | organized as follows.
22 | - coco_root
23 | - annotations
24 | - instances_train2014.json
25 | - images
26 | - train2014
27 |
28 | Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi.
29 | """
30 |
31 | def __init__(self, root=None, image_loader=default_image_loader):
32 | root = env_settings().coco_dir if root is None else root
33 | super().__init__(root, image_loader)
34 |
35 | self.img_pth = os.path.join(root, 'train2014/')
36 | self.anno_path = os.path.join(root, 'annotations/instances_train2014.json')
37 |
38 | # Load the COCO set.
39 | self.coco_set = COCO(self.anno_path)
40 |
41 | self.cats = self.coco_set.cats
42 | self.sequence_list = self._get_sequence_list()
43 |
44 | def _get_sequence_list(self):
45 | ann_list = list(self.coco_set.anns.keys())
46 | seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0]
47 |
48 | return seq_list
49 |
50 | def is_video_sequence(self):
51 | return False
52 |
53 | def get_name(self):
54 | return 'coco'
55 |
56 | def get_num_sequences(self):
57 | return len(self.sequence_list)
58 |
59 | def get_sequence_info(self, seq_id):
60 | anno = self._get_anno(seq_id)
61 |
62 | return anno, torch.Tensor([1])
63 |
64 | def _get_anno(self, seq_id):
65 | anno = self.coco_set.anns[self.sequence_list[seq_id]]['bbox']
66 | return torch.Tensor(anno).view(1, 4)
67 |
68 | def _get_frames(self, seq_id):
69 | path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name']
70 | img = self.image_loader(os.path.join(self.img_pth, path))
71 | return img
72 |
73 | def get_meta_info(self, seq_id):
74 | try:
75 | cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']]
76 | object_meta = OrderedDict({'object_class': cat_dict_current['name'],
77 | 'motion_class': None,
78 | 'major_class': cat_dict_current['supercategory'],
79 | 'root_class': None,
80 | 'motion_adverb': None})
81 | except:
82 | object_meta = OrderedDict({'object_class': None,
83 | 'motion_class': None,
84 | 'major_class': None,
85 | 'root_class': None,
86 | 'motion_adverb': None})
87 | return object_meta
88 |
89 | def get_frames(self, seq_id=None, frame_ids=None, anno=None):
90 | # COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a
91 | # list containing these replicated images.
92 | frame = self._get_frames(seq_id)
93 |
94 | frame_list = [frame.copy() for _ in frame_ids]
95 |
96 | if anno is None:
97 | anno = self._get_anno(seq_id)
98 |
99 | anno_frames = [anno.clone()[0, :] for _ in frame_ids]
100 |
101 | object_meta = self.get_meta_info(seq_id)
102 |
103 | return frame_list, anno_frames, object_meta
104 |
--------------------------------------------------------------------------------
/ltr/dataset/tracking_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import os.path
4 | import numpy as np
5 | import pandas
6 | from collections import OrderedDict
7 |
8 | from ltr.data.image_loader import default_image_loader
9 | from .base_dataset import BaseDataset
10 | from ltr.admin.environment import env_settings
11 |
12 |
13 | def list_sequences(root, set_ids):
14 | """ Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name)
15 |
16 | args:
17 | root: Root directory to TrackingNet
18 | set_ids: Sets (0-11) which are to be used
19 |
20 | returns:
21 | list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence
22 | """
23 | sequence_list = []
24 |
25 | for s in set_ids:
26 | anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno")
27 |
28 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
29 | sequence_list += sequences_cur_set
30 |
31 | return sequence_list
32 |
33 |
34 | class TrackingNet(BaseDataset):
35 | """ TrackingNet dataset.
36 |
37 | Publication:
38 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
39 | Matthias Mueller, Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
40 | ECCV, 2018
41 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
42 |
43 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
44 | """
45 | def __init__(self, root=None, image_loader=default_image_loader, set_ids=None):
46 | """
47 | args:
48 | root - The path to the TrackingNet folder, containing the training sets.
49 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
50 | is used by default.
51 | set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the
52 | sets (0 - 11) will be used.
53 | """
54 | root = env_settings().trackingnet_dir if root is None else root
55 | super().__init__(root, image_loader)
56 |
57 | if set_ids is None:
58 | set_ids = [i for i in range(12)]
59 |
60 | self.set_ids = set_ids
61 |
62 | # Keep a list of all videos. Sequence list is a list of tuples (set_id, video_name) containing the set_id and
63 | # video_name for each sequence
64 | self.sequence_list = list_sequences(self.root, self.set_ids)
65 |
66 | def get_name(self):
67 | return 'trackingnet'
68 |
69 | def _read_anno(self, seq_id):
70 | set_id = self.sequence_list[seq_id][0]
71 | vid_name = self.sequence_list[seq_id][1]
72 | anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt")
73 | gt = pandas.read_csv(anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values
74 | return torch.tensor(gt)
75 |
76 | def get_sequence_info(self, seq_id):
77 | anno = self._read_anno(seq_id)
78 | target_visible = (anno[:,2]>0) & (anno[:,3]>0)
79 | return anno, target_visible
80 |
81 | def _get_frame(self, seq_id, frame_id):
82 | set_id = self.sequence_list[seq_id][0]
83 | vid_name = self.sequence_list[seq_id][1]
84 | frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg")
85 | return self.image_loader(frame_path)
86 |
87 | def get_frames(self, seq_id, frame_ids, anno=None):
88 | frame_list = [self._get_frame(seq_id, f) for f in frame_ids]
89 |
90 | if anno is None:
91 | anno = self._read_anno(seq_id)
92 |
93 | # Return as list of tensors
94 | anno_frames = [anno[f_id, :] for f_id in frame_ids]
95 |
96 | object_meta = OrderedDict({'object_class': None,
97 | 'motion_class': None,
98 | 'major_class': None,
99 | 'root_class': None,
100 | 'motion_adverb': None})
101 |
102 | return frame_list, anno_frames, object_meta
103 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | .vim-template*
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Jiayuan Mao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/README.md:
--------------------------------------------------------------------------------
1 | # PreciseRoIPooling
2 | This repo implements the **Precise RoI Pooling** (PrRoI Pooling), proposed in the paper **Acquisition of Localization Confidence for Accurate Object Detection** published at ECCV 2018 (Oral Presentation).
3 |
4 | **Acquisition of Localization Confidence for Accurate Object Detection**
5 |
6 | _Borui Jiang*, Ruixuan Luo*, Jiayuan Mao*, Tete Xiao, Yuning Jiang_ (* indicates equal contribution.)
7 |
8 | https://arxiv.org/abs/1807.11590
9 |
10 | ## Brief
11 |
12 | In short, Precise RoI Pooling is an integration-based (bilinear interpolation) average pooling method for RoI Pooling. It avoids any quantization and has a continuous gradient on bounding box coordinates. It is:
13 |
14 | - different from the original RoI Pooling proposed in [Fast R-CNN](https://arxiv.org/abs/1504.08083). PrRoI Pooling uses average pooling instead of max pooling for each bin and has a continuous gradient on bounding box coordinates. That is, one can take the derivatives of some loss function w.r.t the coordinates of each RoI and optimize the RoI coordinates.
15 | - different from the RoI Align proposed in [Mask R-CNN](https://arxiv.org/abs/1703.06870). PrRoI Pooling uses a full integration-based average pooling instead of sampling a constant number of points. This makes the gradient w.r.t. the coordinates continuous.
16 |
17 | For a better comparison, we illustrate RoI Pooling, RoI Align and PrRoI Pooling in the following figure. More details including the gradient computation can be found in our paper.
18 |
19 |
20 |
21 | ## Implementation
22 |
23 | PrRoI Pooling was originally implemented by [Tete Xiao](http://tetexiao.com/) based on MegBrain, an (internal) deep learning framework built by Megvii Inc. It was later adapted into open-source deep learning frameworks. Currently, we only support PyTorch. Unfortunately, we don't have any specific plan for the adaptation into other frameworks such as TensorFlow, but any contributions (pull requests) will be more than welcome.
24 |
25 | ## Usage (PyTorch)
26 |
27 | In the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 0.4 and only supports CUDA (CPU mode is not implemented). To use the PrRoI Pooling module, first go to `pytorch/prroi_pool` and execute `./travis.sh` to compile the essential components (you may need `nvcc` for this step). To use the module in your code, simply do:
28 |
29 | ```
30 | from prroi_pool import PrRoIPool2D
31 |
32 | avg_pool = PrRoIPool2D(window_height, window_width, spatial_scale)
33 | roi_features = avg_pool(features, rois)
34 |
35 | # for those who want to use the "functional"
36 |
37 | from prroi_pool.functional import prroi_pool2d
38 | roi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale)
39 | ```
40 |
41 | Here,
42 |
43 | - RoI is an `m * 5` float tensor of format `(batch_index, x0, y0, x1, y1)`, following the convention in the original Caffe implementation of RoI Pooling, although in some frameworks the batch indices are provided by an integer tensor.
44 | - `spatial_scale` is multiplied with the RoIs. For example, if your feature maps are down-sampled by a factor of 16 (w.r.t. the input image), you should use a spatial scale of `1/16`.
45 | - The RoI coordinates follow the [L, R) convention. That is, `(0, 0, 4, 4)` denotes a box of size `4x4`. A short usage sketch follows below.
46 |
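For instance, the following sketch (assuming a CUDA device and that `./travis.sh` has already been run) pools two RoIs from a batch of feature maps; the tensor sizes are arbitrary illustration values:
```python
import torch
from prroi_pool import PrRoIPool2D

avg_pool = PrRoIPool2D(7, 7, 1.0 / 16)             # 7x7 bins, features at 1/16 input resolution

features = torch.rand(2, 64, 32, 32).cuda()
# Each RoI is (batch_index, x0, y0, x1, y1) in input-image coordinates;
# (0, 0, 64, 64) therefore covers a 4x4 region on the 1/16-scale feature map.
rois = torch.tensor([[0., 0., 0., 64., 64.],
                     [1., 16., 16., 96., 96.]]).cuda()
rois.requires_grad_()                              # gradients w.r.t. box coordinates are defined

roi_features = avg_pool(features, rois)            # shape: (2, 64, 7, 7)
roi_features.sum().backward()
print(rois.grad.shape)                             # torch.Size([2, 5])
```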
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/_assets/prroi_visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/external/PreciseRoIPooling/_assets/prroi_visualization.png
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/.gitignore:
--------------------------------------------------------------------------------
1 | *.o
2 | /_prroi_pooling
3 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : __init__.py
4 | # Author : Jiayuan Mao, Tete Xiao
5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
6 | # Date : 07/13/2018
7 | #
8 | # This file is part of PreciseRoIPooling.
9 | # Distributed under terms of the MIT license.
10 | # Copyright (c) 2017 Megvii Technology Limited.
11 |
12 | from .prroi_pool import *
13 |
14 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/build.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : build.py
4 | # Author : Jiayuan Mao, Tete Xiao
5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
6 | # Date : 07/13/2018
7 | #
8 | # This file is part of PreciseRoIPooling.
9 | # Distributed under terms of the MIT license.
10 | # Copyright (c) 2017 Megvii Technology Limited.
11 |
12 | import os
13 | import torch
14 |
15 | from torch.utils.ffi import create_extension
16 |
17 | headers = []
18 | sources = []
19 | defines = []
20 | extra_objects = []
21 | with_cuda = False
22 |
23 | if torch.cuda.is_available():
24 | with_cuda = True
25 |
26 | headers+= ['src/prroi_pooling_gpu.h']
27 | sources += ['src/prroi_pooling_gpu.c']
28 | defines += [('WITH_CUDA', None)]
29 |
30 | this_file = os.path.dirname(os.path.realpath(__file__))
31 | extra_objects_cuda = ['src/prroi_pooling_gpu_impl.cu.o']
32 | extra_objects_cuda = [os.path.join(this_file, fname) for fname in extra_objects_cuda]
33 | extra_objects.extend(extra_objects_cuda)
34 | else:
35 | # TODO(Jiayuan Mao @ 07/13): remove this restriction after we support the cpu implementation.
36 |     raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implementations.')
37 |
38 | ffi = create_extension(
39 | '_prroi_pooling',
40 | headers=headers,
41 | sources=sources,
42 | define_macros=defines,
43 | relative_to=__file__,
44 | with_cuda=with_cuda,
45 | extra_objects=extra_objects
46 | )
47 |
48 | if __name__ == '__main__':
49 | ffi.build()
50 |
51 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/functional.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : functional.py
4 | # Author : Jiayuan Mao, Tete Xiao
5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
6 | # Date : 07/13/2018
7 | #
8 | # This file is part of PreciseRoIPooling.
9 | # Distributed under terms of the MIT license.
10 | # Copyright (c) 2017 Megvii Technology Limited.
11 |
12 | import torch
13 | import torch.autograd as ag
14 |
15 | try:
16 | from . import _prroi_pooling
17 | except ImportError:
18 |     raise ImportError('Cannot find the compiled Precise RoI Pooling library. Run ./travis.sh in the directory first.')
19 |
20 | __all__ = ['prroi_pool2d']
21 |
22 |
23 | class PrRoIPool2DFunction(ag.Function):
24 | @staticmethod
25 | def forward(ctx, features, rois, pooled_height, pooled_width, spatial_scale):
26 | assert 'FloatTensor' in features.type() and 'FloatTensor' in rois.type(), \
27 | 'Precise RoI Pooling only takes float input, got {} for features and {} for rois.'.format(features.type(), rois.type())
28 |
29 | features = features.contiguous()
30 | rois = rois.contiguous()
31 | pooled_height = int(pooled_height)
32 | pooled_width = int(pooled_width)
33 | spatial_scale = float(spatial_scale)
34 |
35 | params = (pooled_height, pooled_width, spatial_scale)
36 | batch_size, nr_channels, data_height, data_width = features.size()
37 | nr_rois = rois.size(0)
38 | output = torch.zeros(
39 | (nr_rois, nr_channels, pooled_height, pooled_width),
40 | dtype=features.dtype, device=features.device
41 | )
42 |
43 | if features.is_cuda:
44 | _prroi_pooling.prroi_pooling_forward_cuda(features, rois, output, *params)
45 | ctx.params = params
46 | # everything here is contiguous.
47 | ctx.save_for_backward(features, rois, output)
48 | else:
49 |             raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implementations.')
50 |
51 | return output
52 |
53 | @staticmethod
54 | def backward(ctx, grad_output):
55 | features, rois, output = ctx.saved_tensors
56 | grad_input = grad_coor = None
57 |
58 | if features.requires_grad:
59 | grad_output = grad_output.contiguous()
60 | grad_input = torch.zeros_like(features)
61 | _prroi_pooling.prroi_pooling_backward_cuda(features, rois, output, grad_output, grad_input, *ctx.params)
62 | if rois.requires_grad:
63 | grad_output = grad_output.contiguous()
64 | grad_coor = torch.zeros_like(rois)
65 | _prroi_pooling.prroi_pooling_coor_backward_cuda(features, rois, output, grad_output, grad_coor, *ctx.params)
66 |
67 | return grad_input, grad_coor, None, None, None
68 |
69 |
70 | prroi_pool2d = PrRoIPool2DFunction.apply
71 |
72 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/prroi_pool.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : prroi_pool.py
4 | # Author : Jiayuan Mao, Tete Xiao
5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
6 | # Date : 07/13/2018
7 | #
8 | # This file is part of PreciseRoIPooling.
9 | # Distributed under terms of the MIT license.
10 | # Copyright (c) 2017 Megvii Technology Limited.
11 |
12 | import torch.nn as nn
13 |
14 | from .functional import prroi_pool2d
15 |
16 | __all__ = ['PrRoIPool2D']
17 |
18 |
19 | class PrRoIPool2D(nn.Module):
20 | def __init__(self, pooled_height, pooled_width, spatial_scale):
21 | super().__init__()
22 |
23 | self.pooled_height = int(pooled_height)
24 | self.pooled_width = int(pooled_width)
25 | self.spatial_scale = float(spatial_scale)
26 |
27 | def forward(self, features, rois):
28 | return prroi_pool2d(features, rois, self.pooled_height, self.pooled_width, self.spatial_scale)
29 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu.c:
--------------------------------------------------------------------------------
1 | /*
2 | * File : prroi_pooling_gpu.c
3 | * Author : Jiayuan Mao, Tete Xiao
4 | * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
5 | * Date : 07/13/2018
6 | *
7 | * Distributed under terms of the MIT license.
8 | * Copyright (c) 2017 Megvii Technology Limited.
9 | */
10 |
11 | #include <math.h>
12 | #include <THC/THC.h>
13 |
14 | #include "prroi_pooling_gpu_impl.cuh"
15 |
16 | extern THCState *state;
17 |
18 | int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale) {
19 | const float *data_ptr = THCudaTensor_data(state, features);
20 | const float *rois_ptr = THCudaTensor_data(state, rois);
21 | float *output_ptr = THCudaTensor_data(state, output);
22 |
23 | int nr_rois = THCudaTensor_size(state, rois, 0);
24 | int nr_channels = THCudaTensor_size(state, features, 1);
25 | int height = THCudaTensor_size(state, features, 2);
26 | int width = THCudaTensor_size(state, features, 3);
27 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width;
28 |
29 | cudaStream_t stream = THCState_getCurrentStream(state);
30 |
31 | PrRoIPoolingForwardGpu(
32 | stream, data_ptr, rois_ptr, output_ptr,
33 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
34 | top_count
35 | );
36 |
37 | return 1;
38 | }
39 |
40 | int prroi_pooling_backward_cuda(
41 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff,
42 | int pooled_height, int pooled_width, float spatial_scale) {
43 |
44 | const float *data_ptr = THCudaTensor_data(state, features);
45 | const float *rois_ptr = THCudaTensor_data(state, rois);
46 | const float *output_ptr = THCudaTensor_data(state, output);
47 | const float *output_diff_ptr = THCudaTensor_data(state, output_diff);
48 | float *features_diff_ptr = THCudaTensor_data(state, features_diff);
49 |
50 | int nr_rois = THCudaTensor_size(state, rois, 0);
51 | int batch_size = THCudaTensor_size(state, features, 0);
52 | int nr_channels = THCudaTensor_size(state, features, 1);
53 | int height = THCudaTensor_size(state, features, 2);
54 | int width = THCudaTensor_size(state, features, 3);
55 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width;
56 | int bottom_count = batch_size * nr_channels * height * width;
57 |
58 | cudaStream_t stream = THCState_getCurrentStream(state);
59 |
60 | PrRoIPoolingBackwardGpu(
61 | stream, data_ptr, rois_ptr, output_ptr, output_diff_ptr, features_diff_ptr,
62 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
63 | top_count, bottom_count
64 | );
65 |
66 | return 1;
67 | }
68 |
69 | int prroi_pooling_coor_backward_cuda(
70 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *coor_diff,
71 | int pooled_height, int pooled_width, float spatial_scale) {
72 |
73 | const float *data_ptr = THCudaTensor_data(state, features);
74 | const float *rois_ptr = THCudaTensor_data(state, rois);
75 | const float *output_ptr = THCudaTensor_data(state, output);
76 | const float *output_diff_ptr = THCudaTensor_data(state, output_diff);
77 | float *coor_diff_ptr= THCudaTensor_data(state, coor_diff);
78 |
79 | int nr_rois = THCudaTensor_size(state, rois, 0);
80 | int nr_channels = THCudaTensor_size(state, features, 1);
81 | int height = THCudaTensor_size(state, features, 2);
82 | int width = THCudaTensor_size(state, features, 3);
83 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width;
84 | int bottom_count = nr_rois * 5;
85 |
86 | cudaStream_t stream = THCState_getCurrentStream(state);
87 |
88 | PrRoIPoolingCoorBackwardGpu(
89 | stream, data_ptr, rois_ptr, output_ptr, output_diff_ptr, coor_diff_ptr,
90 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale,
91 | top_count, bottom_count
92 | );
93 |
94 | return 1;
95 | }
96 |
97 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu.h:
--------------------------------------------------------------------------------
1 | /*
2 | * File : prroi_pooling_gpu.h
3 | * Author : Jiayuan Mao, Tete Xiao
4 | * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com
5 | * Date : 07/13/2018
6 | *
7 | * Distributed under terms of the MIT license.
8 | * Copyright (c) 2017 Megvii Technology Limited.
9 | */
10 |
11 | int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale);
12 |
13 | int prroi_pooling_backward_cuda(
14 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff,
15 | int pooled_height, int pooled_width, float spatial_scale
16 | );
17 |
18 | int prroi_pooling_coor_backward_cuda(
19 |     THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *coor_diff,
20 |     int pooled_height, int pooled_width, float spatial_scale
21 | );
22 |
23 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu_impl.cu:
--------------------------------------------------------------------------------
1 | ../../../src/prroi_pooling_gpu_impl.cu
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu_impl.cuh:
--------------------------------------------------------------------------------
1 | ../../../src/prroi_pooling_gpu_impl.cuh
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/prroi_pool/travis.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash -e
2 | # File : travis.sh
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | #
6 | # Distributed under terms of the MIT license.
7 | # Copyright (c) 2017 Megvii Technology Limited.
8 |
9 | cd src
10 | echo "Working directory: " `pwd`
11 | echo "Compiling prroi_pooling kernels by nvcc..."
12 | nvcc -c -o prroi_pooling_gpu_impl.cu.o prroi_pooling_gpu_impl.cu -x cu -Xcompiler -fPIC -arch=sm_52
13 |
14 | cd ../
15 | echo "Working directory: " `pwd`
16 | echo "Building python libraries..."
17 | python3 build.py
18 |
19 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/pytorch/tests/test_prroi_pooling2d.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_prroi_pooling2d.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 18/02/2018
6 | #
7 | # This file is part of Jacinle.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from jactorch.utils.unittest import TorchTestCase
16 |
17 | from prroi_pool import PrRoIPool2D
18 |
19 |
20 | class TestPrRoIPool2D(TorchTestCase):
21 | def test_forward(self):
22 | pool = PrRoIPool2D(7, 7, spatial_scale=0.5)
23 | features = torch.rand((4, 16, 24, 32)).cuda()
24 | rois = torch.tensor([
25 | [0, 0, 0, 14, 14],
26 | [1, 14, 14, 28, 28],
27 | ]).float().cuda()
28 |
29 | out = pool(features, rois)
30 | out_gold = F.avg_pool2d(features, kernel_size=2, stride=1)
31 |
32 | self.assertTensorClose(out, torch.stack((
33 | out_gold[0, :, :7, :7],
34 | out_gold[1, :, 7:14, 7:14],
35 | ), dim=0))
36 |
37 | def test_backward_shapeonly(self):
38 | pool = PrRoIPool2D(2, 2, spatial_scale=0.5)
39 |
40 | features = torch.rand((4, 2, 24, 32)).cuda()
41 | rois = torch.tensor([
42 | [0, 0, 0, 4, 4],
43 | [1, 14, 14, 18, 18],
44 | ]).float().cuda()
45 | features.requires_grad = rois.requires_grad = True
46 | out = pool(features, rois)
47 |
48 | loss = out.sum()
49 | loss.backward()
50 |
51 | self.assertTupleEqual(features.size(), features.grad.size())
52 | self.assertTupleEqual(rois.size(), rois.grad.size())
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/ltr/external/PreciseRoIPooling/src/prroi_pooling_gpu_impl.cuh:
--------------------------------------------------------------------------------
1 | /*
2 | * File : prroi_pooling_gpu_impl.cuh
3 | * Author : Tete Xiao, Jiayuan Mao
4 | * Email : jasonhsiao97@gmail.com
5 | *
6 | * Distributed under terms of the MIT license.
7 | * Copyright (c) 2017 Megvii Technology Limited.
8 | */
9 |
10 | #ifndef PRROI_POOLING_GPU_IMPL_CUH
11 | #define PRROI_POOLING_GPU_IMPL_CUH
12 |
13 | #ifdef __cplusplus
14 | extern "C" {
15 | #endif
16 |
17 | #define F_DEVPTR_IN const float *
18 | #define F_DEVPTR_OUT float *
19 |
20 | void PrRoIPoolingForwardGpu(
21 | cudaStream_t stream,
22 | F_DEVPTR_IN bottom_data,
23 | F_DEVPTR_IN bottom_rois,
24 | F_DEVPTR_OUT top_data,
25 | const int channels_, const int height_, const int width_,
26 | const int pooled_height_, const int pooled_width_,
27 | const float spatial_scale_,
28 | const int top_count);
29 |
30 | void PrRoIPoolingBackwardGpu(
31 | cudaStream_t stream,
32 | F_DEVPTR_IN bottom_data,
33 | F_DEVPTR_IN bottom_rois,
34 | F_DEVPTR_IN top_data,
35 | F_DEVPTR_IN top_diff,
36 | F_DEVPTR_OUT bottom_diff,
37 | const int channels_, const int height_, const int width_,
38 | const int pooled_height_, const int pooled_width_,
39 | const float spatial_scale_,
40 | const int top_count, const int bottom_count);
41 |
42 | void PrRoIPoolingCoorBackwardGpu(
43 | cudaStream_t stream,
44 | F_DEVPTR_IN bottom_data,
45 | F_DEVPTR_IN bottom_rois,
46 | F_DEVPTR_IN top_data,
47 | F_DEVPTR_IN top_diff,
48 | F_DEVPTR_OUT bottom_diff,
49 | const int channels_, const int height_, const int width_,
50 | const int pooled_height_, const int pooled_width_,
51 | const float spatial_scale_,
52 | const int top_count, const int bottom_count);
53 |
54 | #ifdef __cplusplus
55 | } /* !extern "C" */
56 | #endif
57 |
58 | #endif /* !PRROI_POOLING_GPU_IMPL_CUH */
59 |
60 |
--------------------------------------------------------------------------------
/ltr/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/__init__.py
--------------------------------------------------------------------------------
/ltr/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .resnet18_vggm import *
3 |
--------------------------------------------------------------------------------
/ltr/models/backbone/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/backbone/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/backbone/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/backbone/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/backbone/__pycache__/resnet18_vggm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/backbone/__pycache__/resnet18_vggm.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/bbreg/__init__.py:
--------------------------------------------------------------------------------
1 | from .atom_iou_net import AtomIoUNet
2 |
--------------------------------------------------------------------------------
/ltr/models/bbreg/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/bbreg/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/bbreg/__pycache__/atom.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/bbreg/__pycache__/atom.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/bbreg/__pycache__/atom_iou_net.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/bbreg/__pycache__/atom_iou_net.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/bbreg/atom.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import ltr.models.backbone as backbones
3 | import ltr.models.bbreg as bbmodels
4 | from ltr import model_constructor
5 |
6 |
7 | class ATOMnet(nn.Module):
8 | """ ATOM network module"""
9 | def __init__(self, feature_extractor, bb_regressor, bb_regressor_layer, extractor_grad=True):
10 | """
11 | args:
12 | feature_extractor - backbone feature extractor
13 | bb_regressor - IoU prediction module
14 | bb_regressor_layer - List containing the name of the layers from feature_extractor, which are input to
15 | bb_regressor
16 | extractor_grad - Bool indicating whether backbone feature extractor requires gradients
17 | """
18 | super(ATOMnet, self).__init__()
19 |
20 | self.feature_extractor = feature_extractor
21 | self.bb_regressor = bb_regressor
22 | self.bb_regressor_layer = bb_regressor_layer
23 |
24 | if not extractor_grad:
25 | for p in self.feature_extractor.parameters():
26 | p.requires_grad_(False)
27 |
28 | def forward(self, train_imgs, test_imgs, train_bb, test_proposals):
29 | """ Forward pass
30 |         Note: If the training is done in sequence mode, that is, test_imgs.dim() == 5, then the tensors have an
31 |         extra leading dimension, and test_imgs is of the form [sequence, batch, feature, row, col]
32 | """
33 | num_sequences = train_imgs.shape[-4]
34 | num_train_images = train_imgs.shape[0] if train_imgs.dim() == 5 else 1
35 | num_test_images = test_imgs.shape[0] if test_imgs.dim() == 5 else 1
36 |
37 | # Extract backbone features
38 | train_feat = self.extract_backbone_features(
39 | train_imgs.view(-1, train_imgs.shape[-3], train_imgs.shape[-2], train_imgs.shape[-1]))
40 | test_feat = self.extract_backbone_features(
41 | test_imgs.view(-1, test_imgs.shape[-3], test_imgs.shape[-2], test_imgs.shape[-1]))
42 |
43 | # For clarity, send the features to bb_regressor in sequence form, i.e. [sequence, batch, feature, row, col]
44 | train_feat_iou = [feat.view(num_train_images, num_sequences, feat.shape[-3], feat.shape[-2], feat.shape[-1])
45 | for feat in train_feat.values()]
46 | test_feat_iou = [feat.view(num_test_images, num_sequences, feat.shape[-3], feat.shape[-2], feat.shape[-1])
47 | for feat in test_feat.values()]
48 |
49 | # Obtain iou prediction
50 | iou_pred = self.bb_regressor(train_feat_iou, test_feat_iou,
51 | train_bb.view(num_train_images, num_sequences, 4),
52 | test_proposals.view(num_train_images, num_sequences, -1, 4))
53 | return iou_pred
54 |
55 | def extract_backbone_features(self, im, layers=None):
56 | if layers is None:
57 | layers = self.bb_regressor_layer
58 | return self.feature_extractor(im, layers)
59 |
60 | def extract_features(self, im, layers):
61 | return self.feature_extractor(im, layers)
62 |
63 |
64 |
65 | @model_constructor
66 | def atom_resnet18(iou_input_dim=(256,256), iou_inter_dim=(256,256), backbone_pretrained=True):
67 | # backbone
68 | backbone_net = backbones.resnet18(pretrained=backbone_pretrained)
69 |
70 | # Bounding box regressor
71 | iou_predictor = bbmodels.AtomIoUNet(pred_input_dim=iou_input_dim, pred_inter_dim=iou_inter_dim)
72 |
73 | net = ATOMnet(feature_extractor=backbone_net, bb_regressor=iou_predictor, bb_regressor_layer=['layer2', 'layer3'],
74 | extractor_grad=False)
75 |
76 | return net
77 |
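A small sketch (not part of the repository) exercising the shape conventions described in the `forward` docstring above; it assumes the PrRoI Pooling extension has been compiled, a CUDA device is available, and pretrained backbone weights are not needed:

```
import torch
from ltr.models.bbreg.atom import atom_resnet18

net = atom_resnet18(backbone_pretrained=False).cuda()

# Sequence mode: tensors are [num_images, num_sequences, channels, rows, cols].
train_imgs = torch.rand(1, 2, 3, 288, 288).cuda()
test_imgs = torch.rand(1, 2, 3, 288, 288).cuda()

# One (x, y, w, h) ground-truth box per train image and sequence: shape (1, 2, 4).
train_bb = torch.tensor([[[64., 64., 96., 96.],
                          [32., 32., 128., 128.]]]).cuda()

# 16 proposals per test image, jittered around the ground truth: shape (1, 2, 16, 4).
test_proposals = train_bb.unsqueeze(2) + torch.randn(1, 2, 16, 4).cuda() * 8

iou_pred = net(train_imgs, test_imgs, train_bb, test_proposals)
print(iou_pred.shape)   # one predicted IoU score per proposal
```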
--------------------------------------------------------------------------------
/ltr/models/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/layers/__init__.py
--------------------------------------------------------------------------------
/ltr/models/layers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/layers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/layers/__pycache__/blocks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/layers/__pycache__/blocks.cpython-36.pyc
--------------------------------------------------------------------------------
/ltr/models/layers/blocks.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def conv_block(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=True,
5 | batch_norm=True, relu=True):
6 | layers = [nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
7 | padding=padding, dilation=dilation, bias=bias)]
8 | if batch_norm:
9 | layers.append(nn.BatchNorm2d(out_planes))
10 | if relu:
11 | layers.append(nn.ReLU(inplace=True))
12 | return nn.Sequential(*layers)
13 |
14 |
15 | class LinearBlock(nn.Module):
16 | def __init__(self, in_planes, out_planes, input_sz, bias=True, batch_norm=True, relu=True):
17 | super().__init__()
18 | self.linear = nn.Linear(in_planes*input_sz*input_sz, out_planes, bias=bias)
19 | self.bn = nn.BatchNorm2d(out_planes) if batch_norm else None
20 | self.relu = nn.ReLU(inplace=True) if relu else None
21 |
22 | def forward(self, x):
23 | x = self.linear(x.view(x.shape[0], -1))
24 | if self.bn is not None:
25 | x = self.bn(x.view(x.shape[0], x.shape[1], 1, 1))
26 | if self.relu is not None:
27 | x = self.relu(x)
28 | return x.view(x.shape[0], -1)
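A short sketch (editorial) of how these two helpers are typically chained; the sizes are arbitrary:

```
import torch
from ltr.models.layers.blocks import conv_block, LinearBlock

x = torch.rand(2, 16, 24, 24)

conv = conv_block(16, 32, kernel_size=3, stride=1, padding=1)
y = conv(x)                             # (2, 32, 24, 24): Conv2d -> BatchNorm2d -> ReLU

fc = LinearBlock(32, 64, input_sz=24)   # flattens the (32, 24, 24) feature map
z = fc(y)                               # (2, 64)
print(y.shape, z.shape)
```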
--------------------------------------------------------------------------------
/ltr/run_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 | import multiprocessing
6 | import cv2 as cv
7 | import torch.backends.cudnn
8 |
9 | env_path = os.path.join(os.path.dirname(__file__), '..')
10 | if env_path not in sys.path:
11 | sys.path.append(env_path)
12 |
13 | import ltr.admin.settings as ws_settings
14 |
15 |
16 | def run_training(train_module, train_name, cudnn_benchmark=True):
17 | """Run a train scripts in train_settings.
18 | args:
19 | train_module: Name of module in the "train_settings/" folder.
20 | train_name: Name of the train settings file.
21 | cudnn_benchmark: Use cudnn benchmark or not (default is True).
22 | """
23 |
24 | # This is needed to avoid strange crashes related to opencv
25 | cv.setNumThreads(0)
26 |
27 | torch.backends.cudnn.benchmark = cudnn_benchmark
28 |
29 | print('Training: {} {}'.format(train_module, train_name))
30 |
31 | settings = ws_settings.Settings()
32 |
33 | if settings.env.workspace_dir == '':
34 | raise Exception('Setup your workspace_dir in "ltr/admin/local.py".')
35 |
36 | settings.module_name = train_module
37 | settings.script_name = train_name
38 | settings.project_path = 'ltr/{}/{}'.format(train_module, train_name)
39 |
40 | expr_module = importlib.import_module('ltr.train_settings.{}.{}'.format(train_module, train_name))
41 | expr_func = getattr(expr_module, 'run')
42 |
43 | expr_func(settings)
44 |
45 |
46 | def main():
47 |     parser = argparse.ArgumentParser(description='Run a training script from train_settings.')
48 | parser.add_argument('train_module', type=str, help='Name of module in the "train_settings/" folder.')
49 | parser.add_argument('train_name', type=str, help='Name of the train settings file.')
50 | parser.add_argument('--cudnn_benchmark', type=bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).')
51 |
52 | args = parser.parse_args()
53 |
54 | run_training(args.train_module, args.train_name, args.cudnn_benchmark)
55 |
56 |
57 | if __name__ == '__main__':
58 | multiprocessing.set_start_method('spawn', force=True)
59 | main()
60 |
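For reference, the bundled `bbreg/atom_default` settings can be launched as follows (a sketch, assuming the workspace and dataset paths in `ltr/admin/local.py` have been filled in):

```
# Equivalent to:  python ltr/run_training.py bbreg atom_default
import multiprocessing
from ltr.run_training import run_training

if __name__ == '__main__':
    multiprocessing.set_start_method('spawn', force=True)  # mirrors run_training.py's __main__
    run_training('bbreg', 'atom_default')
```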
--------------------------------------------------------------------------------
/ltr/train_settings/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/train_settings/__init__.py
--------------------------------------------------------------------------------
/ltr/train_settings/bbreg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/train_settings/bbreg/__init__.py
--------------------------------------------------------------------------------
/ltr/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_trainer import BaseTrainer
2 | from .ltr_trainer import LTRTrainer
--------------------------------------------------------------------------------
/pytracking/__init__.py:
--------------------------------------------------------------------------------
1 | from pytracking.libs import TensorList, TensorDict
2 | import pytracking.libs.complex as complex
3 | import pytracking.libs.operation as operation
4 | import pytracking.libs.fourier as fourier
5 | import pytracking.libs.dcf as dcf
6 | import pytracking.libs.optimization as optimization
7 | from pytracking.run_tracker import run_tracker
8 | from pytracking.run_webcam import run_webcam
9 |
--------------------------------------------------------------------------------
/pytracking/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/pytracking/__pycache__/run_tracker.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/run_tracker.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/__pycache__/run_webcam.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/run_webcam.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .otbdataset import OTBDataset
2 | from .nfsdataset import NFSDataset
3 | from .uavdataset import UAVDataset
4 | from .tpldataset import TPLDataset
5 | from .votdataset import VOTDataset
6 | from .trackingnetdataset import TrackingNetDataset
7 | from .got10kdataset import GOT10KDatasetTest, GOT10KDatasetVal, GOT10KDatasetLTRVal
8 | from .lasotdataset import LaSOTDataset
9 | from .data import Sequence
10 | from .tracker import Tracker
11 |
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/data.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/data.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/environment.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/environment.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/got10kdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/got10kdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/lasotdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/lasotdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/local.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/local.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/nfsdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/nfsdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/otbdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/otbdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/running.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/running.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/tpldataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/tpldataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/tracker.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/tracker.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/trackingnetdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/trackingnetdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/uavdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/uavdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/__pycache__/votdataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/votdataset.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/data.py:
--------------------------------------------------------------------------------
1 | from pytracking.evaluation.environment import env_settings
2 |
3 |
4 | class BaseDataset:
5 | """Base class for all datasets."""
6 | def __init__(self):
7 | self.env_settings = env_settings()
8 |
9 | def __len__(self):
10 | """Overload this function in your dataset. This should return number of sequences in the dataset."""
11 | raise NotImplementedError
12 |
13 | def get_sequence_list(self):
14 | """Overload this in your dataset. Should return the list of sequences in the dataset."""
15 | raise NotImplementedError
16 |
17 |
18 | class Sequence:
19 | """Class for the sequence in an evaluation."""
20 |     def __init__(self, name, frames, ground_truth_rect, gt=None, object_class=None):
21 | self.name = name
22 | self.frames = frames
23 | self.ground_truth_rect = ground_truth_rect
24 | self.gt = gt
25 | self.init_state = list(self.ground_truth_rect[0,:])
26 | self.object_class = object_class
27 |
28 |
29 | class SequenceList(list):
30 | """List of sequences. Supports the addition operator to concatenate sequence lists."""
31 | def __getitem__(self, item):
32 | if isinstance(item, str):
33 | for seq in self:
34 | if seq.name == item:
35 | return seq
36 | raise IndexError('Sequence name not in the dataset.')
37 | elif isinstance(item, int):
38 | return super(SequenceList, self).__getitem__(item)
39 | elif isinstance(item, (tuple, list)):
40 | return SequenceList([super(SequenceList, self).__getitem__(i) for i in item])
41 | else:
42 | return SequenceList(super(SequenceList, self).__getitem__(item))
43 |
44 | def __add__(self, other):
45 | return SequenceList(super(SequenceList, self).__add__(other))
46 |
47 | def copy(self):
48 | return SequenceList(super(SequenceList, self).copy())
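A self-contained sketch (editorial) of the `Sequence`/`SequenceList` behaviour documented above, using dummy frame paths and ground-truth rectangles:

```
import numpy as np
from pytracking.evaluation.data import Sequence, SequenceList

gt = np.array([[10., 20., 50., 80.],
               [12., 22., 50., 80.]])

seqs = SequenceList([
    Sequence('toy_a', ['a/0001.jpg', 'a/0002.jpg'], gt, gt),
    Sequence('toy_b', ['b/0001.jpg', 'b/0002.jpg'], gt, gt),
])

print(seqs['toy_a'].init_state)   # first ground-truth rect, as a list
print(len(seqs[[0, 1]]))          # indexing with a list returns a SequenceList
merged = seqs + seqs              # '+' concatenates into a new SequenceList
print(len(merged))                # 4
```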
--------------------------------------------------------------------------------
/pytracking/evaluation/environment.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 |
5 | class EnvSettings:
6 | def __init__(self):
7 | pytracking_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
8 |
9 | self.results_path = '{}/tracking_results/'.format(pytracking_path)
10 | self.network_path = '{}/networks/'.format(pytracking_path)
11 | self.otb_path = ''
12 | self.nfs_path = ''
13 | self.uav_path = ''
14 | self.tpl_path = ''
15 | self.vot_path = ''
16 | self.got10k_path = ''
17 | self.lasot_path = ''
18 | self.trackingnet_path = ''
19 |
20 |
21 | def create_default_local_file():
22 | comment = {'results_path': 'Where to store tracking results',
23 | 'network_path': 'Where tracking networks are stored.'}
24 |
25 | path = os.path.join(os.path.dirname(__file__), 'local.py')
26 | with open(path, 'w') as f:
27 | settings = EnvSettings()
28 |
29 | f.write('from pytracking.evaluation.environment import EnvSettings\n\n')
30 | f.write('def local_env_settings():\n')
31 | f.write(' settings = EnvSettings()\n\n')
32 | f.write(' # Set your local paths here.\n\n')
33 |
34 | for attr in dir(settings):
35 | comment_str = None
36 | if attr in comment:
37 | comment_str = comment[attr]
38 | attr_val = getattr(settings, attr)
39 | if not attr.startswith('__') and not callable(attr_val):
40 | if comment_str is None:
41 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
42 | else:
43 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
44 | f.write('\n return settings\n\n')
45 |
46 |
47 | def env_settings():
48 | env_module_name = 'pytracking.evaluation.local'
49 | try:
50 | env_module = importlib.import_module(env_module_name)
51 | return env_module.local_env_settings()
52 | except:
53 | env_file = os.path.join(os.path.dirname(__file__), 'local.py')
54 |
55 | # Create a default file
56 | create_default_local_file()
57 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. '
58 | 'Then try to run again.'.format(env_file))
59 |
--------------------------------------------------------------------------------
/pytracking/evaluation/got10kdataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pytracking.evaluation.data import Sequence, BaseDataset, SequenceList
3 | import os
4 |
5 |
6 | def GOT10KDatasetTest():
7 | """ GOT-10k official test set"""
8 | return GOT10KDatasetClass('test').get_sequence_list()
9 |
10 |
11 | def GOT10KDatasetVal():
12 | """ GOT-10k official val set"""
13 | return GOT10KDatasetClass('val').get_sequence_list()
14 |
15 |
16 | def GOT10KDatasetLTRVal():
17 | """ GOT-10k val split from LTR (a subset of GOT-10k official train set)"""
18 | return GOT10KDatasetClass('ltrval').get_sequence_list()
19 |
20 |
21 | class GOT10KDatasetClass(BaseDataset):
22 | """ GOT-10k dataset.
23 |
24 | Publication:
25 | GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
26 | Lianghua Huang, Xin Zhao, and Kaiqi Huang
27 | arXiv:1810.11981, 2018
28 | https://arxiv.org/pdf/1810.11981.pdf
29 |
30 | Download dataset from http://got-10k.aitestunion.com/downloads
31 | """
32 | def __init__(self, split):
33 | """
34 | args:
35 | split - Split to use. Can be i) 'test': official test set, ii) 'val': official val set, and iii) 'ltrval':
36 | a custom validation set, a subset of the official train set.
37 | """
38 | super().__init__()
39 | # Split can be test, val, or ltrval
40 | if split == 'test' or split == 'val':
41 | self.base_path = os.path.join(self.env_settings.got10k_path, split)
42 | else:
43 | self.base_path = os.path.join(self.env_settings.got10k_path, 'train')
44 |
45 | self.sequence_list = self._get_sequence_list(split)
46 | self.split = split
47 |
48 | def get_sequence_list(self):
49 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
50 |
51 | def _construct_sequence(self, sequence_name):
52 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
53 | try:
54 | ground_truth_rect = np.loadtxt(str(anno_path), dtype=np.float64)
55 | except:
56 | ground_truth_rect = np.loadtxt(str(anno_path), delimiter=',', dtype=np.float64)
57 |
58 | frames_path = '{}/{}'.format(self.base_path, sequence_name)
59 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
60 | frame_list.sort(key=lambda f: int(f[:-4]))
61 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
62 |
63 | return Sequence(sequence_name, frames_list, ground_truth_rect.reshape(-1, 4))
64 |
65 | def __len__(self):
66 |         '''Return the number of sequences in the dataset.'''
67 | return len(self.sequence_list)
68 |
69 | def _get_sequence_list(self, split):
70 | with open('{}/list.txt'.format(self.base_path)) as f:
71 | sequence_list = f.read().splitlines()
72 |
73 | if split == 'ltrval':
74 | with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f:
75 | seq_ids = f.read().splitlines()
76 |
77 | sequence_list = [sequence_list[int(x)] for x in seq_ids]
78 | return sequence_list
79 |
--------------------------------------------------------------------------------
/pytracking/evaluation/local.py:
--------------------------------------------------------------------------------
1 | from pytracking.evaluation.environment import EnvSettings
2 |
3 | def local_env_settings():
4 | settings = EnvSettings()
5 |
6 | # Set your local paths here.
7 |
8 | settings.got10k_path = ''
9 | settings.lasot_path = ''
10 | settings.network_path = '/home/.../tracking/SPSTracker/pytracking/network/' # Where tracking networks are stored.
11 | settings.nfs_path = ''
12 | settings.otb_path = ''
13 | settings.results_path = '//home/.../tracking/SPSTracker/pytracking/tracking_results' # Where to store tracking results
14 | settings.tpl_path = ''
15 | settings.trackingnet_path = ''
16 | settings.uav_path = ''
17 | settings.vot_path = '/home/.../data/VOT2018'
18 |
19 | return settings
20 |
21 |
--------------------------------------------------------------------------------
/pytracking/evaluation/running.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import multiprocessing
3 | import os
4 | from itertools import product
5 | from pytracking.evaluation import Sequence, Tracker
6 | from os import makedirs
7 | from os.path import join, isdir, isfile
8 | from pytracking.evaluation.utils.pyvotkit.region import vot_float2str
9 | from os.path import split
10 |
11 | def run_sequence(seq: Sequence, tracker: Tracker, debug=False):
12 | """Runs a tracker on a sequence."""
13 | base_results_path = '{}/baseline/{}'.format(tracker.results_dir, seq.name)
14 | results_path = '{}.txt'.format(base_results_path)
15 | times_path = '{}_time.txt'.format(base_results_path)
16 |
17 | if os.path.isfile(results_path) and not debug:
18 | return
19 |
20 | print('Tracker: {} {} {} , Sequence: {}'.format(tracker.name, tracker.parameter_name, tracker.run_id, seq.name))
21 |
22 | if debug:
23 | tracked_bb, exec_times = tracker.run(seq, debug=debug)
24 | else:
25 | try:
26 | tracked_bb, exec_times = tracker.run(seq, debug=debug)
27 | except Exception as e:
28 | print(e)
29 | return
30 | path = base_results_path.split('/')
31 | path = path[-1]
32 |
33 | if not isdir(base_results_path): makedirs(base_results_path)
34 | results_path = join(base_results_path, '{:s}_001.txt').format(path)
35 | #'{}.txt'.format(base_results_path)
36 | with open(results_path, "w") as fin:
37 | for x in tracked_bb:
38 | fin.write("{:d}\n".format(x)) if isinstance(x, int) else \
39 | fin.write(','.join([vot_float2str("%.4f", i) for i in x]) + '\n')
40 | # tracked_bb = np.array(tracked_bb).astype(int)
41 | exec_times = np.array(exec_times).astype(float)
42 |
43 | print('FPS: {}'.format(len(exec_times) / exec_times.sum()))
44 | # if not debug:
45 | # np.savetxt(results_path, tracked_bb, delimiter='\t', fmt='%d')
46 | # np.savetxt(times_path, exec_times, delimiter='\t', fmt='%f')
47 |
48 |
49 | def run_dataset(dataset, trackers, debug=False, threads=0):
50 | """Runs a list of trackers on a dataset.
51 | args:
52 | dataset: List of Sequence instances, forming a dataset.
53 | trackers: List of Tracker instances.
54 | debug: Debug level.
55 | threads: Number of threads to use (default 0).
56 | """
57 | if threads == 0:
58 | mode = 'sequential'
59 | else:
60 | mode = 'parallel'
61 |
62 | if mode == 'sequential':
63 | for seq in dataset:
64 | for tracker_info in trackers:
65 | run_sequence(seq, tracker_info, debug=debug)
66 | elif mode == 'parallel':
67 | param_list = [(seq, tracker_info, debug) for seq, tracker_info in product(dataset, trackers)]
68 | with multiprocessing.Pool(processes=threads) as pool:
69 | pool.starmap(run_sequence, param_list)
70 | print('Done')
71 |
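A sketch (editorial) of driving `run_dataset`; the tracker and parameter names below are hypothetical placeholders and must match modules that exist under `pytracking/tracker` and `pytracking/parameter`, and the dataset path must be set in `pytracking/evaluation/local.py`:

```
from pytracking.evaluation import OTBDataset, Tracker
from pytracking.evaluation.running import run_dataset

dataset = OTBDataset()                       # a SequenceList of OTB sequences
trackers = [Tracker('atom', 'default', 0)]   # hypothetical tracker / parameter names

# threads=0 runs the sequences sequentially; threads>0 uses a multiprocessing pool.
run_dataset(dataset, trackers, debug=False, threads=0)
```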
--------------------------------------------------------------------------------
/pytracking/evaluation/tracker.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import pickle
4 | from pytracking.evaluation.environment import env_settings
5 |
6 |
7 | class Tracker:
8 | """Wraps the tracker for evaluation and running purposes.
9 | args:
10 | name: Name of tracking method.
11 | parameter_name: Name of parameter file.
12 | run_id: The run id.
13 | """
14 |
15 | def __init__(self, name: str, parameter_name: str, run_id: int = None):
16 | self.name = name
17 | self.parameter_name = parameter_name
18 | self.run_id = run_id
19 |
20 | env = env_settings()
21 | if self.run_id is None:
22 | self.results_dir = '{}/{}/{}'.format(env.results_path, self.name, self.parameter_name)
23 | else:
24 | self.results_dir = '{}/{}/{}_{:03d}'.format(env.results_path, self.name, self.parameter_name, self.run_id)
25 | if not os.path.exists(self.results_dir):
26 | os.makedirs(self.results_dir)
27 |
28 | tracker_module = importlib.import_module('pytracking.tracker.{}'.format(self.name))
29 |
30 | self.parameters = self.get_parameters()
31 | self.tracker_class = tracker_module.get_tracker_class()
32 |
33 | self.default_visualization = getattr(self.parameters, 'visualization', False)
34 | self.default_debug = getattr(self.parameters, 'debug', 0)
35 |
36 | def run(self, seq, visualization=None, debug=None):
37 | """Run tracker on sequence.
38 | args:
39 | seq: Sequence to run the tracker on.
40 | visualization: Set visualization flag (None means default value specified in the parameters).
41 | debug: Set debug level (None means default value specified in the parameters).
42 | """
43 | visualization_ = visualization
44 | debug_ = debug
45 | if debug is None:
46 | debug_ = self.default_debug
47 | if visualization is None:
48 | if debug is None:
49 | visualization_ = self.default_visualization
50 | else:
51 | visualization_ = True if debug else False
52 |
53 | self.parameters.visualization = visualization_
54 | self.parameters.debug = debug_
55 |
56 | tracker = self.tracker_class(self.parameters)
57 |
58 | output_bb, execution_times = tracker.track_sequence(seq)
59 |
60 | self.parameters.free_memory()
61 |
62 | return output_bb, execution_times
63 | def run_video(self, videofilepath, optional_box=None, debug=None):
64 | """Run the tracker with the vieofile.
65 | args:
66 | debug: Debug level.
67 | """
68 |
69 | debug_ = debug
70 | if debug is None:
71 | debug_ = self.default_debug
72 | self.parameters.debug = debug_
73 |
74 | self.parameters.tracker_name = self.name
75 | self.parameters.param_name = self.parameter_name
76 | tracker = self.tracker_class(self.parameters)
77 | tracker.track_videofile(videofilepath, optional_box)
78 |
79 | def run_webcam(self, debug=None):
80 | """Run the tracker with the webcam.
81 | args:
82 | debug: Debug level.
83 | """
84 |
85 | debug_ = debug
86 | if debug is None:
87 | debug_ = self.default_debug
88 | self.parameters.debug = debug_
89 |
90 | self.parameters.tracker_name = self.name
91 | self.parameters.param_name = self.parameter_name
92 | tracker = self.tracker_class(self.parameters)
93 |
94 | tracker.track_webcam()
95 |
96 | def get_parameters(self):
97 | """Get parameters."""
98 |
99 | parameter_file = '{}/parameters.pkl'.format(self.results_dir)
100 | if os.path.isfile(parameter_file):
101 | return pickle.load(open(parameter_file, 'rb'))
102 |
103 | param_module = importlib.import_module('pytracking.parameter.{}.{}'.format(self.name, self.parameter_name))
104 | params = param_module.parameters()
105 |
106 | if self.run_id is not None:
107 | pickle.dump(params, open(parameter_file, 'wb'))
108 |
109 | return params
110 |
111 |
112 |
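The same wrapper can also be driven interactively; `'atom'` / `'default'` below are hypothetical placeholders for a tracker module and parameter file present in the repository, and the video path is a placeholder:

```
from pytracking.evaluation import Tracker

tracker = Tracker('atom', 'default')   # hypothetical tracker / parameter names

# Track a local video file; an initial box can optionally be passed via `optional_box`.
tracker.run_video('/path/to/video.mp4')

# Or track from the webcam (the initial box is drawn in the UI).
tracker.run_webcam()
```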
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/__init__.py
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/anchors.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import numpy as np
7 | import math
8 | from pytracking.evaluation.utils.bbox_helper import center2corner, corner2center
9 |
10 |
11 | class Anchors:
12 | def __init__(self, cfg):
13 | self.stride = 8
14 | self.ratios = [0.33, 0.5, 1, 2, 3]
15 | self.scales = [8]
16 | self.round_dight = 0
17 | self.image_center = 0
18 | self.size = 0
19 |
20 | self.__dict__.update(cfg)
21 |
22 | self.anchor_num = len(self.scales) * len(self.ratios)
23 | self.anchors = None # in single position (anchor_num*4)
24 | self.all_anchors = None # in all position 2*(4*anchor_num*h*w)
25 | self.generate_anchors()
26 |
27 | def generate_anchors(self):
28 | self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)
29 |
30 | size = self.stride * self.stride
31 | count = 0
32 | for r in self.ratios:
33 | if self.round_dight > 0:
34 | ws = round(math.sqrt(size*1. / r), self.round_dight)
35 | hs = round(ws * r, self.round_dight)
36 | else:
37 | ws = int(math.sqrt(size*1. / r))
38 | hs = int(ws * r)
39 |
40 | for s in self.scales:
41 | w = ws * s
42 | h = hs * s
43 | self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:]
44 | count += 1
45 |
46 | def generate_all_anchors(self, im_c, size):
47 | if self.image_center == im_c and self.size == size:
48 | return False
49 | self.image_center = im_c
50 | self.size = size
51 |
52 | a0x = im_c - size // 2 * self.stride
53 | ori = np.array([a0x] * 4, dtype=np.float32)
54 | zero_anchors = self.anchors + ori
55 |
56 | x1 = zero_anchors[:, 0]
57 | y1 = zero_anchors[:, 1]
58 | x2 = zero_anchors[:, 2]
59 | y2 = zero_anchors[:, 3]
60 |
61 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), [x1, y1, x2, y2])
62 | cx, cy, w, h = corner2center([x1, y1, x2, y2])
63 |
64 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride
65 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride
66 |
67 | cx = cx + disp_x
68 | cy = cy + disp_y
69 |
70 | # broadcast
71 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32)
72 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h])
73 | x1, y1, x2, y2 = center2corner([cx, cy, w, h])
74 |
75 | self.all_anchors = np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h])
76 | return True
77 |
78 |
79 |
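A sketch (editorial) of what the class produces with its default SiamMask-style configuration; the shapes follow the comments on `anchors` and `all_anchors` above:

```
from pytracking.evaluation.utils.anchors import Anchors

anchors = Anchors({'stride': 8, 'ratios': [0.33, 0.5, 1, 2, 3], 'scales': [8]})
print(anchors.anchors.shape)         # (5, 4): one (x1, y1, x2, y2) box per ratio/scale pair

# Tile the anchors over a 25x25 score map centred at image coordinate 127.
anchors.generate_all_anchors(im_c=127, size=25)
corners, centers = anchors.all_anchors
print(corners.shape, centers.shape)  # (4, 5, 25, 25) each
```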
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/bbox_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import numpy as np
7 | from collections import namedtuple
8 |
9 | Corner = namedtuple('Corner', 'x1 y1 x2 y2')
10 | BBox = Corner
11 | Center = namedtuple('Center', 'x y w h')
12 |
13 |
14 | def corner2center(corner):
15 | """
16 | :param corner: Corner or np.array 4*N
17 | :return: Center or 4 np.array N
18 | """
19 | if isinstance(corner, Corner):
20 | x1, y1, x2, y2 = corner
21 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1))
22 | else:
23 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
24 | x = (x1 + x2) * 0.5
25 | y = (y1 + y2) * 0.5
26 | w = x2 - x1
27 | h = y2 - y1
28 | return x, y, w, h
29 |
30 |
31 | def center2corner(center):
32 | """
33 | :param center: Center or np.array 4*N
34 | :return: Corner or np.array 4*N
35 | """
36 | if isinstance(center, Center):
37 | x, y, w, h = center
38 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5)
39 | else:
40 | x, y, w, h = center[0], center[1], center[2], center[3]
41 | x1 = x - w * 0.5
42 | y1 = y - h * 0.5
43 | x2 = x + w * 0.5
44 | y2 = y + h * 0.5
45 | return x1, y1, x2, y2
46 |
47 |
48 | def cxy_wh_2_rect(pos, sz):
49 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) # 0-index
50 |
51 |
52 | def get_axis_aligned_bbox(region):
53 | nv = region.size
54 | if nv == 8:
55 | cx = np.mean(region[0::2])
56 | cy = np.mean(region[1::2])
57 | x1 = min(region[0::2])
58 | x2 = max(region[0::2])
59 | y1 = min(region[1::2])
60 | y2 = max(region[1::2])
61 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6])
62 | A2 = (x2 - x1) * (y2 - y1)
63 | s = np.sqrt(A1 / A2)
64 | w = s * (x2 - x1) + 1
65 | h = s * (y2 - y1) + 1
66 | else:
67 | x = region[0]
68 | y = region[1]
69 | w = region[2]
70 | h = region[3]
71 | cx = x+w/2
72 | cy = y+h/2
73 |
74 | return cx, cy, w, h
75 |
76 |
77 |
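A self-contained sketch of how these helpers compose: a VOT-style 8-point polygon region is reduced to an axis-aligned (cx, cy, w, h) box and then converted to rect and corner forms. The values and the import path are illustrative only.

import numpy as np
from bbox_helper import Center, center2corner, cxy_wh_2_rect, get_axis_aligned_bbox  # assumes this directory is on sys.path

# A slightly rotated box given as 8 polygon coordinates (x1, y1, ..., x4, y4)
region = np.array([10., 20., 50., 15., 55., 45., 15., 50.])

cx, cy, w, h = get_axis_aligned_bbox(region)   # axis-aligned approximation of the polygon
rect = cxy_wh_2_rect((cx, cy), (w, h))         # [x, y, w, h], 0-indexed
corner = center2corner(Center(cx, cy, w, h))   # Corner(x1, y1, x2, y2)
print(rect, corner)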
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/benchmark_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | from os.path import join, realpath, dirname, exists, isdir
7 | from os import listdir
8 | import logging
9 | import glob
10 | import numpy as np
11 | import json
12 | from collections import OrderedDict
13 |
14 |
15 | def get_dataset_zoo():
16 | root = realpath(join(dirname(__file__), '../data'))
17 | zoos = listdir(root)
18 |
19 | def valid(x):
20 | y = join(root, x)
21 | if not isdir(y): return False
22 |
23 | return exists(join(y, 'list.txt')) \
24 | or exists(join(y, 'train', 'meta.json'))\
25 | or exists(join(y, 'ImageSets', '2016', 'val.txt'))
26 |
27 | zoos = list(filter(valid, zoos))
28 | return zoos
29 |
30 |
31 | dataset_zoo = get_dataset_zoo()
32 |
33 |
34 | def load_dataset(dataset):
35 | info = OrderedDict()
36 | if 'VOT' in dataset:
37 |
38 | base_path = join(realpath(dirname(__file__)), '../data', dataset)
39 | print(base_path)
40 | if not exists(base_path):
41 | print('Dataset path does not exist:', base_path)
42 | logging.error("Please download test dataset!!!")
43 | exit()
44 | list_path = join(base_path, 'list.txt')
45 | with open(list_path) as f:
46 | videos = [v.strip() for v in f.readlines()]
47 | for video in videos:
48 | video_path = join(base_path, video)
49 | image_path = join(video_path, '*.jpg')
50 | image_files = sorted(glob.glob(image_path))
51 | if len(image_files) == 0: # VOT2018
52 | image_path = join(video_path, 'color', '*.jpg')
53 | image_files = sorted(glob.glob(image_path))
54 | gt_path = join(video_path, 'groundtruth.txt')
55 | gt = np.loadtxt(gt_path, delimiter=',').astype(np.float64)
56 | if gt.shape[1] == 4:
57 | gt = np.column_stack((gt[:, 0], gt[:, 1], gt[:, 0], gt[:, 1] + gt[:, 3]-1,
58 | gt[:, 0] + gt[:, 2]-1, gt[:, 1] + gt[:, 3]-1, gt[:, 0] + gt[:, 2]-1, gt[:, 1]))
59 | info[video] = {'image_files': image_files, 'gt': gt, 'name': video}
60 | elif 'DAVIS' in dataset:
61 | base_path = join(realpath(dirname(__file__)), '../data', 'DAVIS')
62 | list_path = join(realpath(dirname(__file__)), '../data', 'DAVIS', 'ImageSets', dataset[-4:], 'val.txt')
63 | with open(list_path) as f:
64 | videos = [v.strip() for v in f.readlines()]
65 | for video in videos:
66 | info[video] = {}
67 | info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
68 | info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
69 | info[video]['name'] = video
70 | elif 'ytb_vos' in dataset:
71 | base_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid')
72 | json_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid', 'meta.json')
73 | meta = json.load(open(json_path, 'r'))
74 | meta = meta['videos']
75 | info = dict()
76 | for v in meta.keys():
77 | objects = meta[v]['objects']
78 | frames = []
79 | anno_frames = []
80 | info[v] = dict()
81 | for obj in objects:
82 | frames += objects[obj]['frames']
83 | anno_frames += [objects[obj]['frames'][0]]
84 | frames = sorted(np.unique(frames))
85 | info[v]['anno_files'] = [join(base_path, 'Annotations', v, im_f+'.png') for im_f in frames]
86 | info[v]['anno_init_files'] = [join(base_path, 'Annotations', v, im_f + '.png') for im_f in anno_frames]
87 | info[v]['image_files'] = [join(base_path, 'JPEGImages', v, im_f+'.jpg') for im_f in frames]
88 | info[v]['name'] = v
89 |
90 | info[v]['start_frame'] = dict()
91 | info[v]['end_frame'] = dict()
92 | for obj in objects:
93 | start_file = objects[obj]['frames'][0]
94 | end_file = objects[obj]['frames'][-1]
95 | info[v]['start_frame'][obj] = frames.index(start_file)
96 | info[v]['end_frame'][obj] = frames.index(end_file)
97 | else:
98 | logging.error('Dataset not supported')
99 | exit()
100 | return info
101 |
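An illustrative loop over the structure returned by load_dataset above, assuming the VOT2018 data has already been downloaded to ../data/VOT2018 relative to this file (otherwise the function logs an error and exits).

dataset = load_dataset('VOT2018')     # {video_name: {'image_files', 'gt', 'name'}}
for name, video in dataset.items():
    # per video: a sorted list of frame paths and an (N, 8) ground-truth polygon array
    print(name, len(video['image_files']), video['gt'].shape)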
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/config_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import json
7 | from os.path import exists
8 |
9 |
10 | def load_config(args):
11 | assert exists(args.config), '"{}" does not exist'.format(args.config)
12 | config = json.load(open(args.config))
13 |
14 | # deal with network
15 | if 'network' not in config:
16 | print('Warning: network missing in config. This will be an error in the next version')
17 |
18 | config['network'] = {}
19 |
20 | if not args.arch:
21 | raise Exception('no arch provided')
22 | config['network']['arch'] = args.arch
23 |
24 | return config
25 |
26 |
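A sketch of how load_config is driven from an argparse-style namespace; the config path and arch value are placeholders.

import argparse

args = argparse.Namespace(config='config.json', arch='Custom')  # placeholder values
config = load_config(args)
print(config.get('network'))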
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/load_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | logger = logging.getLogger('global')
4 |
5 |
6 | def check_keys(model, pretrained_state_dict):
7 | ckpt_keys = set(pretrained_state_dict.keys())
8 | model_keys = set(model.state_dict().keys())
9 | used_pretrained_keys = model_keys & ckpt_keys
10 | unused_pretrained_keys = ckpt_keys - model_keys
11 | missing_keys = model_keys - ckpt_keys
12 | if len(missing_keys) > 0:
13 | logger.info('[Warning] missing keys: {}'.format(missing_keys))
14 | logger.info('missing keys:{}'.format(len(missing_keys)))
15 | if len(unused_pretrained_keys) > 0:
16 | logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys))
17 | logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
18 | logger.info('used keys:{}'.format(len(used_pretrained_keys)))
19 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
20 | return True
21 |
22 |
23 | def remove_prefix(state_dict, prefix):
24 | ''' Old-style models are stored with all parameter names sharing the common prefix 'module.' '''
25 | logger.info('remove prefix \'{}\''.format(prefix))
26 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
27 | return {f(key): value for key, value in state_dict.items()}
28 |
29 |
30 | def load_pretrain(model, pretrained_path):
31 | logger.info('load pretrained model from {}'.format(pretrained_path))
32 | device = torch.cuda.current_device()
33 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
34 | if "state_dict" in pretrained_dict.keys():
35 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
36 | else:
37 | pretrained_dict = remove_prefix(pretrained_dict, 'module.')
38 |
39 | try:
40 | check_keys(model, pretrained_dict)
41 | except:
42 | logger.info('[Warning]: using pretrain as features. Adding "features." as prefix')
43 | new_dict = {}
44 | for k, v in pretrained_dict.items():
45 | k = 'features.' + k
46 | new_dict[k] = v
47 | pretrained_dict = new_dict
48 | check_keys(model, pretrained_dict)
49 | model.load_state_dict(pretrained_dict, strict=False)
50 | return model
51 |
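A usage sketch for load_pretrain. The checkpoint path is a placeholder, and a CUDA device is required because the map_location above moves every tensor to the current GPU.

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.ReLU())  # stand-in for the tracking backbone
model = load_pretrain(model, 'checkpoint.pth')                     # logs missing/unused keys, loads non-strictly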
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/log_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | from __future__ import division
7 |
8 | import os
9 | import logging
10 | import sys
11 |
12 | if hasattr(sys, 'frozen'): # support for py2exe
13 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:])
14 | elif __file__[-4:].lower() in ['.pyc', '.pyo']:
15 | _srcfile = __file__[:-4] + '.py'
16 | else:
17 | _srcfile = __file__
18 | _srcfile = os.path.normcase(_srcfile)
19 |
20 |
21 | logs = set()
22 |
23 |
24 | class Filter:
25 | def __init__(self, flag):
26 | self.flag = flag
27 |
28 | def filter(self, x): return self.flag
29 |
30 |
31 | class Dummy:
32 | def __init__(self, *arg, **kwargs):
33 | pass
34 |
35 | def __getattr__(self, arg):
36 | def dummy(*args, **kwargs): pass
37 | return dummy
38 |
39 |
40 | def get_format(logger, level):
41 | if 'SLURM_PROCID' in os.environ:
42 | rank = int(os.environ['SLURM_PROCID'])
43 |
44 | if level == logging.INFO:
45 | logger.addFilter(Filter(rank == 0))
46 | else:
47 | rank = 0
48 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank)
49 | formatter = logging.Formatter(format_str)
50 | return formatter
51 |
52 |
53 | def init_log(name, level = logging.INFO, format_func=get_format):
54 | if (name, level) in logs: return
55 | logs.add((name, level))
56 | logger = logging.getLogger(name)
57 | logger.setLevel(level)
58 | ch = logging.StreamHandler()
59 | ch.setLevel(level)
60 | formatter = format_func(logger, level)
61 | ch.setFormatter(formatter)
62 | logger.addHandler(ch)
63 | return logger
64 |
65 |
66 | def add_file_handler(name, log_file, level = logging.INFO):
67 | logger = logging.getLogger(name)
68 | fh = logging.FileHandler(log_file)
69 | fh.setFormatter(get_format(logger, level))
70 | logger.addHandler(fh)
71 |
72 |
73 | init_log('global')
74 |
75 |
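A short sketch of the intended use of the helpers above: create a named logger once, optionally mirror it to a file, and log through the standard logging API. The logger and file names are placeholders.

import logging

logger = init_log('example', logging.INFO)   # returns None on repeated (name, level) pairs
add_file_handler('example', 'example.log')   # also write the same records to a file
logging.getLogger('example').info('tracker started')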
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/__init__.py
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from .vot import VOTDataset
10 |
11 |
12 | class DatasetFactory(object):
13 | @staticmethod
14 | def create_dataset(**kwargs):
15 | """
16 | Args:
17 | name: dataset name 'VOT2018', 'VOT2016'
18 | dataset_root: dataset root
19 | Return:
20 | dataset
21 | """
22 | assert 'name' in kwargs, "should provide dataset name"
23 | name = kwargs['name']
24 | if 'VOT2018' == name or 'VOT2016' == name:
25 | dataset = VOTDataset(**kwargs)
26 | else:
27 | raise Exception("unknow dataset {}".format(kwargs['name']))
28 | return dataset
29 |
30 |
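A usage sketch for the factory; the dataset root is a placeholder, and the full keyword set accepted by VOTDataset lives in vot.py, which is not shown in this dump.

dataset = DatasetFactory.create_dataset(name='VOT2018', dataset_root='/path/to/VOT2018')
for video in dataset:        # Dataset.__iter__ yields Video objects (see dataset.py / video.py)
    print(video.name, len(video))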
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | class Dataset(object):
10 | def __init__(self, name, dataset_root):
11 | self.name = name
12 | self.dataset_root = dataset_root
13 | self.videos = None
14 |
15 | def __getitem__(self, idx):
16 | if isinstance(idx, str):
17 | return self.videos[idx]
18 | elif isinstance(idx, int):
19 | return self.videos[sorted(list(self.videos.keys()))[idx]]
20 |
21 | def __len__(self):
22 | return len(self.videos)
23 |
24 | def __iter__(self):
25 | keys = sorted(list(self.videos.keys()))
26 | for key in keys:
27 | yield self.videos[key]
28 |
29 | def set_tracker(self, path, tracker_names):
30 | """
31 | Args:
32 | path: path to tracker results
33 | tracker_names: list of tracker names
34 | """
35 | self.tracker_path = path
36 | self.tracker_names = tracker_names
37 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/datasets/video.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | import os
10 | import cv2
11 |
12 | from glob import glob
13 |
14 |
15 | class Video(object):
16 | def __init__(self, name, root, video_dir, init_rect, img_names,
17 | gt_rect, attr):
18 | self.name = name
19 | self.video_dir = video_dir
20 | self.init_rect = init_rect
21 | self.gt_traj = gt_rect
22 | self.attr = attr
23 | self.pred_trajs = {}
24 | self.img_names = [os.path.join(root, x) for x in img_names]
25 | self.imgs = None
26 |
27 | def load_tracker(self, path, tracker_names=None, store=True):
28 | """
29 | Args:
30 | path(str): path to result
31 | tracker_names(list): names of the trackers
32 | """
33 | if not tracker_names:
34 | tracker_names = [x.split('/')[-1] for x in glob(path)
35 | if os.path.isdir(x)]
36 | if isinstance(tracker_names, str):
37 | tracker_names = [tracker_names]
38 | for name in tracker_names:
39 | traj_file = os.path.join(path, name, self.name+'.txt')
40 | if os.path.exists(traj_file):
41 | with open(traj_file, 'r') as f :
42 | pred_traj = [list(map(float, x.strip().split(',')))
43 | for x in f.readlines()]
44 | if len(pred_traj) != len(self.gt_traj):
45 | print(name, len(pred_traj), len(self.gt_traj), self.name)
46 | if store:
47 | self.pred_trajs[name] = pred_traj
48 | else:
49 | return pred_traj
50 | else:
51 | print(traj_file)
52 | self.tracker_names = list(self.pred_trajs.keys())
53 |
54 | def load_img(self):
55 | if self.imgs is None:
56 | self.imgs = [cv2.imread(x)
57 | for x in self.img_names]
58 | self.width = self.imgs[0].shape[1]
59 | self.height = self.imgs[0].shape[0]
60 |
61 | def free_img(self):
62 | self.imgs = None
63 |
64 | def __len__(self):
65 | return len(self.img_names)
66 |
67 | def __getitem__(self, idx):
68 | if self.imgs is None:
69 | return cv2.imread(self.img_names[idx]), \
70 | self.gt_traj[idx]
71 | else:
72 | return self.imgs[idx], self.gt_traj[idx]
73 |
74 | def __iter__(self):
75 | for i in range(len(self.img_names)):
76 | if self.imgs is not None:
77 | yield self.imgs[i], self.gt_traj[i]
78 | else:
79 | yield cv2.imread(self.img_names[i]), \
80 | self.gt_traj[i]
81 |
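An illustrative construction of a Video; the paths are placeholders and must exist on disk for cv2.imread to return actual frames (it returns None otherwise).

video = Video(name='seq01', root='/path/to/VOT2018', video_dir='seq01',
              init_rect=[0, 0, 50, 50], img_names=['seq01/00000001.jpg'],
              gt_rect=[[0, 0, 50, 50]], attr=None)
for frame, gt in video:   # yields (image, ground-truth) pairs frame by frame
    print(gt, None if frame is None else frame.shape)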
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from .ar_benchmark import AccuracyRobustnessBenchmark
10 | from .eao_benchmark import EAOBenchmark
11 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from . import region
10 | from .statistics import *
11 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/region.o
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/c_region.pxd:
--------------------------------------------------------------------------------
1 | cdef extern from "src/region.h":
2 | ctypedef enum region_type "RegionType":
3 | EMTPY
4 | SPECIAL
5 | RECTANGEL
6 | POLYGON
7 | MASK
8 |
9 | ctypedef struct region_bounds:
10 | float top
11 | float bottom
12 | float left
13 | float right
14 |
15 | ctypedef struct region_rectangle:
16 | float x
17 | float y
18 | float width
19 | float height
20 |
21 | # ctypedef struct region_mask:
22 | # int x
23 | # int y
24 | # int width
25 | # int height
26 | # char *data
27 |
28 | ctypedef struct region_polygon:
29 | int count
30 | float *x
31 | float *y
32 |
33 | ctypedef union region_container_data:
34 | region_rectangle rectangle
35 | region_polygon polygon
36 | # region_mask mask
37 | int special
38 |
39 | ctypedef struct region_container:
40 | region_type type
41 | region_container_data data
42 |
43 | # ctypedef struct region_overlap:
44 | # float overlap
45 | # float only1
46 | # float only2
47 |
48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds)
49 |
50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds)
51 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/misc.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | import numpy as np
10 |
11 | def determine_thresholds(confidence, resolution=100):
12 | """choose threshold according to confidence
13 |
14 | Args:
15 | confidence: list or numpy array or numpy array
16 | reolution: number of threshold to choose
17 |
18 | Restures:
19 | threshold: numpy array
20 | """
21 | if isinstance(confidence, list):
22 | confidence = np.array(confidence)
23 | confidence = confidence.flatten()
24 | confidence = confidence[~np.isnan(confidence)]
25 | confidence.sort()
26 |
27 | assert len(confidence) > resolution and resolution > 2
28 |
29 | thresholds = np.ones((resolution))
30 | thresholds[0] = - np.inf
31 | thresholds[-1] = np.inf
32 | delta = np.floor(len(confidence) / (resolution - 2))
33 | idxs = np.linspace(delta, len(confidence)-delta, resolution-2, dtype=np.int32)
34 | thresholds[1:-1] = confidence[idxs]
35 | return thresholds
36 |
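A quick numerical check of determine_thresholds: given enough scores it returns a sorted threshold vector bracketed by -inf and +inf.

import numpy as np

scores = np.random.rand(500)                        # fake confidence scores
ths = determine_thresholds(scores, resolution=100)
print(ths.shape, ths[0], ths[-1])                   # (100,) -inf inf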
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/setup.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from distutils.core import setup
10 | from distutils.extension import Extension
11 | from Cython.Build import cythonize
12 |
13 | setup(
14 | ext_modules = cythonize([Extension("region", ["region.pyx", "src/region.c"])]),
15 | )
16 |
17 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pysot/utils/src/region.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */
2 |
3 | #ifndef _REGION_H_
4 | #define _REGION_H_
5 |
6 | #ifdef TRAX_STATIC_DEFINE
7 | # define __TRAX_EXPORT
8 | #else
9 | # ifndef __TRAX_EXPORT
10 | # if defined(_MSC_VER)
11 | # ifdef trax_EXPORTS
12 | /* We are building this library */
13 | # define __TRAX_EXPORT __declspec(dllexport)
14 | # else
15 | /* We are using this library */
16 | # define __TRAX_EXPORT __declspec(dllimport)
17 | # endif
18 | # elif defined(__GNUC__)
19 | # ifdef trax_EXPORTS
20 | /* We are building this library */
21 | # define __TRAX_EXPORT __attribute__((visibility("default")))
22 | # else
23 | /* We are using this library */
24 | # define __TRAX_EXPORT __attribute__((visibility("default")))
25 | # endif
26 | # endif
27 | # endif
28 | #endif
29 |
30 | #ifndef MAX
31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b))
32 | #endif
33 |
34 | #ifndef MIN
35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b))
36 | #endif
37 |
38 | #define TRAX_DEFAULT_CODE 0
39 |
40 | #define REGION_LEGACY_RASTERIZATION 1
41 |
42 | #ifdef __cplusplus
43 | extern "C" {
44 | #endif
45 |
46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type;
47 |
48 | typedef struct region_bounds {
49 |
50 | float top;
51 | float bottom;
52 | float left;
53 | float right;
54 |
55 | } region_bounds;
56 |
57 | typedef struct region_polygon {
58 |
59 | int count;
60 |
61 | float* x;
62 | float* y;
63 |
64 | } region_polygon;
65 |
66 | typedef struct region_mask {
67 |
68 | int x;
69 | int y;
70 |
71 | int width;
72 | int height;
73 |
74 | char* data;
75 |
76 | } region_mask;
77 |
78 | typedef struct region_rectangle {
79 |
80 | float x;
81 | float y;
82 | float width;
83 | float height;
84 |
85 | } region_rectangle;
86 |
87 | typedef struct region_container {
88 | enum region_type type;
89 | union {
90 | region_rectangle rectangle;
91 | region_polygon polygon;
92 | region_mask mask;
93 | int special;
94 | } data;
95 | } region_container;
96 |
97 | typedef struct region_overlap {
98 |
99 | float overlap;
100 | float only1;
101 | float only2;
102 |
103 | } region_overlap;
104 |
105 | extern const region_bounds region_no_bounds;
106 |
107 | __TRAX_EXPORT int region_set_flags(int mask);
108 |
109 | __TRAX_EXPORT int region_clear_flags(int mask);
110 |
111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds);
112 |
113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds);
114 |
115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom);
116 |
117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region);
118 |
119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region);
120 |
121 | __TRAX_EXPORT char* region_string(region_container* region);
122 |
123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region);
124 |
125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type);
126 |
127 | __TRAX_EXPORT void region_release(region_container** region);
128 |
129 | __TRAX_EXPORT region_container* region_create_special(int code);
130 |
131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height);
132 |
133 | __TRAX_EXPORT region_container* region_create_polygon(int count);
134 |
135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y);
136 |
137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height);
138 |
139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height);
140 |
141 | #ifdef __cplusplus
142 | }
143 | #endif
144 |
145 | #endif
146 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from . import region
10 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/region.o
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/c_region.pxd:
--------------------------------------------------------------------------------
1 | cdef extern from "region.h":
2 | ctypedef enum region_type "RegionType":
3 | EMTPY
4 | SPECIAL
5 | RECTANGEL
6 | POLYGON
7 | MASK
8 |
9 | ctypedef struct region_bounds:
10 | float top
11 | float bottom
12 | float left
13 | float right
14 |
15 | ctypedef struct region_rectangle:
16 | float x
17 | float y
18 | float width
19 | float height
20 |
21 | # ctypedef struct region_mask:
22 | # int x
23 | # int y
24 | # int width
25 | # int height
26 | # char *data
27 |
28 | ctypedef struct region_polygon:
29 | int count
30 | float *x
31 | float *y
32 |
33 | ctypedef union region_container_data:
34 | region_rectangle rectangle
35 | region_polygon polygon
36 | # region_mask mask
37 | int special
38 |
39 | ctypedef struct region_container:
40 | region_type type
41 | region_container_data data
42 |
43 | # ctypedef struct region_overlap:
44 | # float overlap
45 | # float only1
46 | # float only2
47 |
48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds)
49 |
50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds)
51 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/setup.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from distutils.core import setup
10 | from distutils.extension import Extension
11 | from Cython.Build import cythonize
12 |
13 | setup(
14 | ext_modules = cythonize([Extension("region", ["region.pyx"])])
15 | )
16 |
17 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/pyvotkit/src/region.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */
2 |
3 | #ifndef _REGION_H_
4 | #define _REGION_H_
5 |
6 | #ifdef TRAX_STATIC_DEFINE
7 | # define __TRAX_EXPORT
8 | #else
9 | # ifndef __TRAX_EXPORT
10 | # if defined(_MSC_VER)
11 | # ifdef trax_EXPORTS
12 | /* We are building this library */
13 | # define __TRAX_EXPORT __declspec(dllexport)
14 | # else
15 | /* We are using this library */
16 | # define __TRAX_EXPORT __declspec(dllimport)
17 | # endif
18 | # elif defined(__GNUC__)
19 | # ifdef trax_EXPORTS
20 | /* We are building this library */
21 | # define __TRAX_EXPORT __attribute__((visibility("default")))
22 | # else
23 | /* We are using this library */
24 | # define __TRAX_EXPORT __attribute__((visibility("default")))
25 | # endif
26 | # endif
27 | # endif
28 | #endif
29 |
30 | #ifndef MAX
31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b))
32 | #endif
33 |
34 | #ifndef MIN
35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b))
36 | #endif
37 |
38 | #define TRAX_DEFAULT_CODE 0
39 |
40 | #define REGION_LEGACY_RASTERIZATION 1
41 |
42 | #ifdef __cplusplus
43 | extern "C" {
44 | #endif
45 |
46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type;
47 |
48 | typedef struct region_bounds {
49 |
50 | float top;
51 | float bottom;
52 | float left;
53 | float right;
54 |
55 | } region_bounds;
56 |
57 | typedef struct region_polygon {
58 |
59 | int count;
60 |
61 | float* x;
62 | float* y;
63 |
64 | } region_polygon;
65 |
66 | typedef struct region_mask {
67 |
68 | int x;
69 | int y;
70 |
71 | int width;
72 | int height;
73 |
74 | char* data;
75 |
76 | } region_mask;
77 |
78 | typedef struct region_rectangle {
79 |
80 | float x;
81 | float y;
82 | float width;
83 | float height;
84 |
85 | } region_rectangle;
86 |
87 | typedef struct region_container {
88 | enum region_type type;
89 | union {
90 | region_rectangle rectangle;
91 | region_polygon polygon;
92 | region_mask mask;
93 | int special;
94 | } data;
95 | } region_container;
96 |
97 | typedef struct region_overlap {
98 |
99 | float overlap;
100 | float only1;
101 | float only2;
102 |
103 | } region_overlap;
104 |
105 | extern const region_bounds region_no_bounds;
106 |
107 | __TRAX_EXPORT int region_set_flags(int mask);
108 |
109 | __TRAX_EXPORT int region_clear_flags(int mask);
110 |
111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds);
112 |
113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds);
114 |
115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom);
116 |
117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region);
118 |
119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region);
120 |
121 | __TRAX_EXPORT char* region_string(region_container* region);
122 |
123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region);
124 |
125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type);
126 |
127 | __TRAX_EXPORT void region_release(region_container** region);
128 |
129 | __TRAX_EXPORT region_container* region_create_special(int code);
130 |
131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height);
132 |
133 | __TRAX_EXPORT region_container* region_create_polygon(int count);
134 |
135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y);
136 |
137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height);
138 |
139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height);
140 |
141 | #ifdef __cplusplus
142 | }
143 | #endif
144 |
145 | #endif
146 |
--------------------------------------------------------------------------------
/pytracking/evaluation/utils/tracker_config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | from __future__ import division
7 | from utils.anchors import Anchors
8 |
9 |
10 | class TrackerConfig(object):
11 | # These are the default hyper-params for SiamMask
12 | penalty_k = 0.04
13 | window_influence = 0.42
14 | lr = 0.25
15 | seg_thr = 0.3 # for mask
16 | windowing = 'cosine' # to penalize large displacements [cosine/uniform]
17 | # Params from the network architecture, have to be consistent with the training
18 | exemplar_size = 127 # input z size
19 | instance_size = 255 # input x size (search region)
20 | total_stride = 8
21 | out_size = 63 # for mask
22 | base_size = 8
23 | score_size = (instance_size-exemplar_size)//total_stride+1 + base_size
24 | context_amount = 0.5 # context amount for the exemplar
25 | ratios = [0.33, 0.5, 1, 2, 3]
26 | scales = [8, ]
27 | anchor_num = len(ratios) * len(scales)
28 | round_dight = 0
29 | anchor = []
30 |
31 | def update(self, newparam=None, anchors=None):
32 | if newparam:
33 | for key, value in newparam.items():
34 | setattr(self, key, value)
35 | if anchors is not None:
36 | if isinstance(anchors, dict):
37 | anchors = Anchors(anchors)
38 | if isinstance(anchors, Anchors):
39 | self.total_stride = anchors.stride
40 | self.ratios = anchors.ratios
41 | self.scales = anchors.scales
42 | self.round_dight = anchors.round_dight
43 | self.renew()
44 |
45 | def renew(self):
46 | self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1 + self.base_size
47 | self.anchor_num = len(self.ratios) * len(self.scales)
48 |
49 |
50 |
51 |
52 |
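A sketch of the update path above: override a hyper-parameter and rebuild the anchor-dependent fields from an anchor specification. The anchor dict keys mirror the attributes used by the Anchors class in utils/anchors.py.

p = TrackerConfig()
p.update(newparam={'penalty_k': 0.1},
         anchors={'stride': 8, 'ratios': [0.33, 0.5, 1, 2, 3], 'scales': [8], 'round_dight': 0})
print(p.penalty_k, p.total_stride, p.score_size, p.anchor_num)   # 0.1 8 25 5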
--------------------------------------------------------------------------------
/pytracking/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/experiments/__init__.py
--------------------------------------------------------------------------------
/pytracking/experiments/myexperiments.py:
--------------------------------------------------------------------------------
1 | from pytracking.evaluation import Tracker, OTBDataset, NFSDataset, UAVDataset, TPLDataset, VOTDataset, TrackingNetDataset, LaSOTDataset
2 |
3 |
4 | def atom_nfs_uav():
5 | # Run three runs of ATOM on NFS and UAV datasets
6 | trackers = [Tracker('atom', 'default', i) for i in range(3)]
7 |
8 | dataset = NFSDataset() + UAVDataset()
9 | return trackers, dataset
10 |
11 |
12 | def uav_test():
13 | # Run ATOM and ECO on the UAV dataset
14 | trackers = [Tracker('atom', 'default', i) for i in range(1)] + \
15 | [Tracker('eco', 'default', i) for i in range(1)]
16 |
17 | dataset = UAVDataset()
18 | return trackers, dataset
19 |
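New experiments follow the same two-line pattern: build a list of trackers and a dataset, and return both. A hypothetical SPSTracker-on-VOT experiment, assuming the 'SPSTracker' tracker module and its 'default_vot' parameter file (see pytracking/parameter/SPSTracker) are present, could look like this; it would then be launched through run_experiment.py (shown later in this dump).

def spstracker_vot():
    # Hypothetical experiment: one run of SPSTracker with the default_vot parameters on VOT
    trackers = [Tracker('SPSTracker', 'default_vot', 0)]
    dataset = VOTDataset()
    return trackers, dataset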
--------------------------------------------------------------------------------
/pytracking/features/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__init__.py
--------------------------------------------------------------------------------
/pytracking/features/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/features/__pycache__/augmentation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/augmentation.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/features/__pycache__/deep.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/deep.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/features/__pycache__/extractor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/extractor.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/features/__pycache__/featurebase.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/featurebase.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/features/__pycache__/preprocessing.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/preprocessing.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/features/color.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytracking.features.featurebase import FeatureBase
3 |
4 |
5 | class RGB(FeatureBase):
6 | """RGB feature normalized to [-0.5, 0.5]."""
7 | def dim(self):
8 | return 3
9 |
10 | def stride(self):
11 | return self.pool_stride
12 |
13 | def extract(self, im: torch.Tensor):
14 | return im/255 - 0.5
15 |
16 |
17 | class Grayscale(FeatureBase):
18 | """Grayscale feature normalized to [-0.5, 0.5]."""
19 | def dim(self):
20 | return 1
21 |
22 | def stride(self):
23 | return self.pool_stride
24 |
25 | def extract(self, im: torch.Tensor):
26 | return torch.mean(im/255 - 0.5, 1, keepdim=True)
27 |
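Both features operate on image batches of shape (B, 3, H, W) with values in [0, 255]. A small sketch, assuming the FeatureBase constructor (featurebase.py, not shown in this dump) gives all of its arguments defaults:

import torch

im = torch.randint(0, 256, (1, 3, 64, 64)).float()
print(RGB().extract(im).shape)        # torch.Size([1, 3, 64, 64]), values in [-0.5, 0.5]
print(Grayscale().extract(im).shape)  # torch.Size([1, 1, 64, 64])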
--------------------------------------------------------------------------------
/pytracking/features/preprocessing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 |
6 | def numpy_to_torch(a: np.ndarray):
7 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0)
8 |
9 |
10 | def torch_to_numpy(a: torch.Tensor):
11 | return a.squeeze(0).permute(1,2,0).numpy()
12 |
13 |
14 | def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None):
15 | """Sample an image patch.
16 |
17 | args:
18 | im: Image
19 | pos: center position of crop
20 | sample_sz: size to crop
21 | output_sz: size to resize to
22 | """
23 |
24 | # copy and convert
25 | posl = pos.long().clone()
26 |
27 | # Compute pre-downsampling factor
28 | if output_sz is not None:
29 | resize_factor = torch.min(sample_sz.float() / output_sz.float()).item()
30 | df = int(max(int(resize_factor - 0.1), 1))
31 | else:
32 | df = int(1)
33 |
34 | sz = sample_sz.float() / df # new size
35 |
36 | # Do downsampling
37 | if df > 1:
38 | os = posl % df # offset
39 | posl = (posl - os) / df # new position
40 | im2 = im[..., os[0].item()::df, os[1].item()::df] # downsample
41 | else:
42 | im2 = im
43 |
44 | # compute size to crop
45 | szl = torch.max(sz.round(), torch.Tensor([2])).long()
46 |
47 | # Extract top and bottom coordinates
48 | tl = posl - (szl - 1)/2
49 | br = posl + szl/2
50 |
51 | # Get image patch
52 | im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3] + 1, -tl[0].item(), br[0].item() - im2.shape[2] + 1), 'replicate')
53 |
54 | if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]):
55 | return im_patch
56 |
57 | # Resample
58 | im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear')
59 |
60 | return im_patch
61 |
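A worked example of sample_patch under the older PyTorch versions this repository targets (where integer tensor division floors): crop a 288x288 region around a center position and resize it to 224x224.

import numpy as np
import torch

im = numpy_to_torch((np.random.rand(480, 640, 3) * 255).astype(np.float32))  # 1 x 3 x 480 x 640

pos = torch.Tensor([240, 320])        # (row, col) center of the crop
sample_sz = torch.Tensor([288, 288])  # region to crop around pos
output_sz = torch.Tensor([224, 224])  # size the crop is resized to
patch = sample_patch(im, pos, sample_sz, output_sz)
print(patch.shape)                    # torch.Size([1, 3, 224, 224])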
--------------------------------------------------------------------------------
/pytracking/features/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytracking.features.featurebase import FeatureBase
3 |
4 |
5 | class Concatenate(FeatureBase):
6 | """A feature that concatenates other features.
7 | args:
8 | features: List of features to concatenate.
9 | """
10 | def __init__(self, features, pool_stride = None, normalize_power = None, use_for_color = True, use_for_gray = True):
11 | super(Concatenate, self).__init__(pool_stride, normalize_power, use_for_color, use_for_gray)
12 | self.features = features
13 |
14 | self.input_stride = self.features[0].stride()
15 |
16 | for feat in self.features:
17 | if self.input_stride != feat.stride():
18 | raise ValueError('Strides for the features must be the same for a multiresolution feature.')
19 |
20 | def dim(self):
21 | return sum([f.dim() for f in self.features])
22 |
23 | def stride(self):
24 | return self.pool_stride * self.input_stride
25 |
26 | def extract(self, im: torch.Tensor):
27 | return torch.cat([f.get_feature(im) for f in self.features], 1)
28 |
--------------------------------------------------------------------------------
/pytracking/libs/__init__.py:
--------------------------------------------------------------------------------
1 | from .tensorlist import TensorList
2 | from .tensordict import TensorDict
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/complex.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/complex.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/dcf.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/dcf.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/fourier.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/fourier.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/operation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/operation.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/optimization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/optimization.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/tensordict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/tensordict.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/tensorlist.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/tensorlist.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/libs/__pycache__/tensorlist.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/tensorlist.cpython-37.pyc
--------------------------------------------------------------------------------
/pytracking/libs/operation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from pytracking.libs.tensorlist import tensor_operation, TensorList
4 |
5 |
6 | @tensor_operation
7 | def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, stride=1, padding=0, dilation=1, groups=1, mode=None):
8 | """Standard conv2d. Returns the input if weight=None."""
9 |
10 | if weight is None:
11 | return input
12 |
13 | ind = None
14 | if mode is not None:
15 | if padding != 0:
16 | raise ValueError('Cannot input both padding and mode.')
17 | if mode == 'same':
18 | padding = (weight.shape[2]//2, weight.shape[3]//2)
19 | if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0:
20 | ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None),
21 | slice(-1) if weight.shape[3] % 2 == 0 else slice(None))
22 | elif mode == 'valid':
23 | padding = (0, 0)
24 | elif mode == 'full':
25 | padding = (weight.shape[2]-1, weight.shape[3]-1)
26 | else:
27 | raise ValueError('Unknown mode for padding.')
28 |
29 | out = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
30 | if ind is None:
31 | return out
32 | return out[:,:,ind[0],ind[1]]
33 |
34 |
35 | @tensor_operation
36 | def conv1x1(input: torch.Tensor, weight: torch.Tensor):
37 | """Do a convolution with a 1x1 kernel weights. Implemented with matmul, which can be faster than using conv."""
38 |
39 | if weight is None:
40 | return input
41 |
42 | return torch.matmul(weight.view(weight.shape[0], weight.shape[1]),
43 | input.view(input.shape[0], input.shape[1], -1)).view(input.shape[0], weight.shape[0], input.shape[2], input.shape[3])
44 |
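A quick check of the two operations on plain tensors (the tensor_operation decorator, defined in tensorlist.py and not shown here, is assumed to pass non-TensorList inputs straight through): 'same' mode preserves the spatial size, and conv1x1 matches an ordinary 1x1 convolution.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 32, 32)
print(conv2d(x, torch.randn(16, 8, 3, 3), mode='same').shape)  # torch.Size([2, 16, 32, 32])

w1 = torch.randn(4, 8, 1, 1)
print((conv1x1(x, w1) - F.conv2d(x, w1)).abs().max())          # ~0 up to floating point error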
--------------------------------------------------------------------------------
/pytracking/libs/tensordict.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 |
4 |
5 | class TensorDict(OrderedDict):
6 | """Container mainly used for dicts of torch tensors. Extends OrderedDict with pytorch functionality."""
7 |
8 | def concat(self, other):
9 | """Concatenates two dicts without copying internal data."""
10 | return TensorDict(self, **other)
11 |
12 | def copy(self):
13 | return TensorDict(super(TensorDict, self).copy())
14 |
15 | def __getattr__(self, name):
16 | if not hasattr(torch.Tensor, name):
17 | raise AttributeError('\'TensorDict\' object has no attribute \'{}\''.format(name))
18 |
19 | def apply_attr(*args, **kwargs):
20 | return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()})
21 | return apply_attr
22 |
23 | def attribute(self, attr: str, *args):
24 | return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()})
25 |
26 | def apply(self, fn, *args, **kwargs):
27 | return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()})
28 |
29 | @staticmethod
30 | def _iterable(a):
31 | return isinstance(a, (TensorDict, list))
32 |
33 |
--------------------------------------------------------------------------------
/pytracking/parameter/SPSTracker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/SPSTracker/__init__.py
--------------------------------------------------------------------------------
/pytracking/parameter/SPSTracker/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/SPSTracker/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/parameter/SPSTracker/__pycache__/default_vot.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/SPSTracker/__pycache__/default_vot.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/parameter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/__init__.py
--------------------------------------------------------------------------------
/pytracking/parameter/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/parameter/atom/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__init__.py
--------------------------------------------------------------------------------
/pytracking/parameter/atom/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/parameter/atom/__pycache__/default.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__pycache__/default.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/parameter/atom/__pycache__/default_vot.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__pycache__/default_vot.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 |
6 | env_path = os.path.join(os.path.dirname(__file__), '..')
7 | if env_path not in sys.path:
8 | sys.path.append(env_path)
9 |
10 | from pytracking.evaluation.running import run_dataset
11 |
12 |
13 | def run_experiment(experiment_module: str, experiment_name: str, debug=0, threads=0):
14 | """Run experiment.
15 | args:
16 | experiment_module: Name of experiment module in the experiments/ folder.
17 | experiment_name: Name of the experiment function.
18 | debug: Debug level.
19 | threads: Number of threads.
20 | """
21 | expr_module = importlib.import_module('pytracking.experiments.{}'.format(experiment_module))
22 | expr_func = getattr(expr_module, experiment_name)
23 | trackers, dataset = expr_func()
24 | print('Running: {} {}'.format(experiment_module, experiment_name))
25 | run_dataset(dataset, trackers, debug, threads)
26 |
27 |
28 | def main():
29 | parser = argparse.ArgumentParser(description='Run tracker.')
30 | parser.add_argument('experiment_module', type=str, help='Name of experiment module in the experiments/ folder.')
31 | parser.add_argument('experiment_name', type=str, help='Name of the experiment function.')
32 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
33 | parser.add_argument('--threads', type=int, default=0, help='Number of threads.')
34 |
35 | args = parser.parse_args()
36 |
37 | run_experiment(args.experiment_module, args.experiment_name, args.debug, args.threads)
38 |
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
--------------------------------------------------------------------------------
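Example of the contract run_experiment() relies on: a function inside the pytracking/experiments package that returns (trackers, dataset). The module name "myexperiments" and the 'atom'/'default' tracker and parameter names below are placeholders for illustration, not files guaranteed to ship with this repository.

    # pytracking/experiments/myexperiments.py (hypothetical)
    from pytracking.evaluation import Tracker
    from pytracking.evaluation.otbdataset import OTBDataset

    def atom_on_otb():
        # run_experiment() calls this function and forwards the result to run_dataset()
        trackers = [Tracker('atom', 'default', 0)]
        dataset = OTBDataset()
        return trackers, dataset

    # Invoked as: python pytracking/run_experiment.py myexperiments atom_on_otb --threads 4
--------------------------------------------------------------------------------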
/pytracking/run_tracker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | env_path = os.path.join(os.path.dirname(__file__), '..')
6 | if env_path not in sys.path:
7 | sys.path.append(env_path)
8 |
9 | from pytracking.evaluation.otbdataset import OTBDataset
10 | from pytracking.evaluation.nfsdataset import NFSDataset
11 | from pytracking.evaluation.uavdataset import UAVDataset
12 | from pytracking.evaluation.tpldataset import TPLDataset
13 | from pytracking.evaluation.votdataset import VOTDataset
14 | from pytracking.evaluation.lasotdataset import LaSOTDataset
15 | from pytracking.evaluation.trackingnetdataset import TrackingNetDataset
16 | from pytracking.evaluation.got10kdataset import GOT10KDatasetTest, GOT10KDatasetVal, GOT10KDatasetLTRVal
17 | from pytracking.evaluation.running import run_dataset
18 | from pytracking.evaluation import Tracker
19 |
20 |
21 | def run_tracker(tracker_name, tracker_param, run_id=None, dataset_name='otb', sequence=None, debug=0, threads=0):
22 | """Run tracker on sequence or dataset.
23 | args:
24 | tracker_name: Name of tracking method.
25 | tracker_param: Name of parameter file.
26 | run_id: The run id.
27 | dataset_name: Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, gotlv, lasot).
28 | sequence: Sequence number or name.
29 | debug: Debug level.
30 | threads: Number of threads.
31 | """
32 | if dataset_name == 'otb':
33 | dataset = OTBDataset()
34 | elif dataset_name == 'nfs':
35 | dataset = NFSDataset()
36 | elif dataset_name == 'uav':
37 | dataset = UAVDataset()
38 | elif dataset_name == 'tpl':
39 | dataset = TPLDataset()
40 | elif dataset_name == 'vot':
41 | dataset = VOTDataset()
42 | elif dataset_name == 'tn':
43 | dataset = TrackingNetDataset()
44 | elif dataset_name == 'gott':
45 | dataset = GOT10KDatasetTest()
46 | elif dataset_name == 'gotv':
47 | dataset = GOT10KDatasetVal()
48 | elif dataset_name == 'gotlv':
49 | dataset = GOT10KDatasetLTRVal()
50 | elif dataset_name == 'lasot':
51 | dataset = LaSOTDataset()
52 | else:
53 | raise ValueError('Unknown dataset name')
54 |
55 | if sequence is not None:
56 | dataset = [dataset[sequence]]
57 |
58 | trackers = [Tracker(tracker_name, tracker_param, run_id)]
59 |
60 | run_dataset(dataset, trackers, debug, threads)
61 |
62 |
63 | def main():
64 | parser = argparse.ArgumentParser(description='Run tracker on sequence or dataset.')
65 | parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
66 | parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
67 | parser.add_argument('--runid', type=int, default=None, help='The run id.')
68 | parser.add_argument('--dataset', type=str, default='otb', help='Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, gotlv, lasot).')
69 | parser.add_argument('--sequence', type=str, default=None, help='Sequence number or name.')
70 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
71 | parser.add_argument('--threads', type=int, default=0, help='Number of threads.')
72 |
73 | args = parser.parse_args()
74 |
75 | run_tracker(args.tracker_name, args.tracker_param, args.runid, args.dataset, args.sequence, args.debug, args.threads)
76 |
77 |
78 | if __name__ == '__main__':
79 | main()
80 |
--------------------------------------------------------------------------------
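A minimal sketch of driving run_tracker() from Python rather than the command line; the tracker/parameter names and the OTB sequence are assumptions chosen to match modules visible elsewhere in this repository.

    from pytracking.run_tracker import run_tracker

    # Evaluate SPSTracker with its default parameter file on a single OTB
    # sequence ('Basketball' is assumed to exist in the local OTB copy).
    run_tracker('SPSTracker', 'default', run_id=None, dataset_name='otb',
                sequence='Basketball', debug=0, threads=0)
--------------------------------------------------------------------------------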
/pytracking/run_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | env_path = os.path.join(os.path.dirname(__file__), '..')
6 | if env_path not in sys.path:
7 | sys.path.append(env_path)
8 |
9 | from pytracking.evaluation import Tracker
10 |
11 |
12 | def run_video(tracker_name, tracker_param, videofile, optional_box=None, debug=None):
13 | """Run the tracker on your webcam.
14 | args:
15 | tracker_name: Name of tracking method.
16 | tracker_param: Name of parameter file.
17 | debug: Debug level.
18 | """
19 | tracker = Tracker(tracker_name, tracker_param)
20 | tracker.run_video(videofilepath=videofile, optional_box=optional_box, debug=debug)
21 |
22 | def main():
23 | parser = argparse.ArgumentParser(description='Run the tracker on a video file.')
24 | parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
25 | parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
26 | parser.add_argument('videofile', type=str, help='Path to a video file.')
27 | parser.add_argument('--optional_box', default=None, help='Optional initial box in the format x,y,w,h.')
28 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
29 |
30 | args = parser.parse_args()
31 |
32 | run_video(args.tracker_name, args.tracker_param, args.videofile, args.optional_box, args.debug)
33 |
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
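A hedged example of calling run_video() directly; the video path and starting box are placeholders. The box follows the x,y,w,h convention from the argument help above and is passed through unchanged to Tracker.run_video(), whose expected type (list or comma-separated string) is defined by the tracker implementation, not here.

    from pytracking.run_video import run_video

    # Skip the interactive box selection by supplying the initial box.
    run_video('SPSTracker', 'default', 'videos/example.mp4',
              optional_box=[100, 150, 80, 120], debug=0)
--------------------------------------------------------------------------------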
/pytracking/run_webcam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | env_path = os.path.join(os.path.dirname(__file__), '..')
6 | if env_path not in sys.path:
7 | sys.path.append(env_path)
8 |
9 | from pytracking.evaluation import Tracker
10 |
11 |
12 | def run_webcam(tracker_name, tracker_param, debug=None):
13 | """Run the tracker on your webcam.
14 | args:
15 | tracker_name: Name of tracking method.
16 | tracker_param: Name of parameter file.
17 | debug: Debug level.
18 | """
19 | tracker = Tracker(tracker_name, tracker_param)
20 | tracker.run_webcam(debug)
21 |
22 |
23 | def main():
24 | parser = argparse.ArgumentParser(description='Run the tracker on your webcam.')
25 | parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
26 | parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
27 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
28 |
29 | args = parser.parse_args()
30 |
31 | run_webcam(args.tracker_name, args.tracker_param, args.debug)
32 |
33 |
34 | if __name__ == '__main__':
35 | main()
--------------------------------------------------------------------------------
/pytracking/tracker/SPSTracker/__init__.py:
--------------------------------------------------------------------------------
1 | from .SPSTracker import SPSTracker
2 |
3 | def get_tracker_class():
4 | return SPSTracker
5 |
--------------------------------------------------------------------------------
/pytracking/tracker/SPSTracker/__pycache__/SPSTracker.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/SPSTracker/__pycache__/SPSTracker.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/SPSTracker/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/SPSTracker/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/SPSTracker/__pycache__/optim.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/SPSTracker/__pycache__/optim.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/__init__.py
--------------------------------------------------------------------------------
/pytracking/tracker/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/atom/__init__.py:
--------------------------------------------------------------------------------
1 | from .atom import ATOM
2 |
3 | def get_tracker_class():
4 | return ATOM
--------------------------------------------------------------------------------
/pytracking/tracker/atom/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/atom/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/atom/__pycache__/atom.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/atom/__pycache__/atom.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/atom/__pycache__/optim.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/atom/__pycache__/optim.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/atom/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytracking import optimization, TensorList, operation
3 | import math
4 |
5 |
6 | class FactorizedConvProblem(optimization.L2Problem):
7 | def __init__(self, training_samples: TensorList, y: TensorList, filter_reg: torch.Tensor, projection_reg, params, sample_weights: TensorList,
8 | projection_activation, response_activation):
9 | self.training_samples = training_samples
10 | self.y = y
11 | self.filter_reg = filter_reg
12 | self.sample_weights = sample_weights
13 | self.params = params
14 | self.projection_reg = projection_reg
15 | self.projection_activation = projection_activation
16 | self.response_activation = response_activation
17 |
18 | self.diag_M = self.filter_reg.concat(projection_reg)
19 |
20 | def __call__(self, x: TensorList):
21 | """
22 | Compute residuals
23 | :param x: [filters, projection_matrices]
24 | :return: [data_terms, filter_regularizations, proj_mat_regularizations]
25 | """
26 | filter = x[:len(x)//2] # w2 in paper
27 | P = x[len(x)//2:] # w1 in paper
28 |
29 | # Do first convolution
30 | compressed_samples = operation.conv1x1(self.training_samples, P).apply(self.projection_activation)
31 |
32 | # Do second convolution
33 | residuals = operation.conv2d(compressed_samples, filter, mode='same').apply(self.response_activation)
34 |
35 | # Compute data residuals
36 | residuals = residuals - self.y
37 |
38 | residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals
39 |
40 | # Add regularization for the filter
41 | residuals.extend(self.filter_reg.apply(math.sqrt) * filter)
42 |
43 | # Add regularization for projection matrix
44 | residuals.extend(self.projection_reg.apply(math.sqrt) * P)
45 |
46 | return residuals
47 |
48 |
49 | def ip_input(self, a: TensorList, b: TensorList):
50 | num = len(a) // 2 # Number of filters
51 | a_filter = a[:num]
52 | b_filter = b[:num]
53 | a_P = a[num:]
54 | b_P = b[num:]
55 |
56 | # Filter inner product
57 | # ip_out = a_filter.reshape(-1) @ b_filter.reshape(-1)
58 | ip_out = operation.conv2d(a_filter, b_filter).view(-1)
59 |
60 | # Add projection matrix part
61 | # ip_out += a_P.reshape(-1) @ b_P.reshape(-1)
62 | ip_out += operation.conv2d(a_P.view(1,-1,1,1), b_P.view(1,-1,1,1)).view(-1)
63 |
64 | # Have independent inner products for each filter
65 | return ip_out.concat(ip_out.clone())
66 |
67 | def M1(self, x: TensorList):
68 | return x / self.diag_M
69 |
70 |
71 | class ConvProblem(optimization.L2Problem):
72 | def __init__(self, training_samples: TensorList, y: TensorList, filter_reg: torch.Tensor, sample_weights: TensorList, response_activation):
73 | self.training_samples = training_samples
74 | self.y = y
75 | self.filter_reg = filter_reg
76 | self.sample_weights = sample_weights
77 | self.response_activation = response_activation
78 |
79 | def __call__(self, x: TensorList):
80 | """
81 | Compute residuals
82 | :param x: [filters]
83 | :return: [data_terms, filter_regularizations]
84 | """
85 | # Do convolution and compute residuals
86 | residuals = operation.conv2d(self.training_samples, x, mode='same').apply(self.response_activation)
87 | residuals = residuals - self.y
88 |
89 | residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals
90 |
91 | # Add regularization for the filter
92 | residuals.extend(self.filter_reg.apply(math.sqrt) * x)
93 |
94 | return residuals
95 |
96 | def ip_input(self, a: TensorList, b: TensorList):
97 | # return a.reshape(-1) @ b.reshape(-1)
98 | # return (a * b).sum()
99 | return operation.conv2d(a, b).view(-1)
100 |
--------------------------------------------------------------------------------
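The residual stacking above corresponds to a weighted least-squares objective: a data term over all samples plus quadratic penalties on the filter and the projection matrix. Below is a plain-PyTorch sketch of that objective for a single output filter (illustrative names and shapes only, not the repo's TensorList/operation API).

    import torch.nn.functional as F

    def factorized_loss(samples, y, w1, w2, sample_weights, filter_reg, proj_reg,
                        proj_act, resp_act):
        # samples: (N, C, H, W); w1: 1x1 projection conv weights; w2: the filter.
        compressed = proj_act(F.conv2d(samples, w1))
        responses = resp_act(F.conv2d(compressed, w2, padding=w2.shape[-1] // 2))
        data_term = (sample_weights.view(-1, 1, 1, 1) * (responses - y) ** 2).sum()
        return data_term + filter_reg * (w2 ** 2).sum() + proj_reg * (w1 ** 2).sum()
--------------------------------------------------------------------------------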
/pytracking/tracker/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .basetracker import BaseTracker
--------------------------------------------------------------------------------
/pytracking/tracker/base/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/base/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/tracker/base/__pycache__/basetracker.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/base/__pycache__/basetracker.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__init__.py
--------------------------------------------------------------------------------
/pytracking/util/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/__pycache__/anchors.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/anchors.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/__pycache__/bbox_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/bbox_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/__pycache__/config_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/config_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/__pycache__/load_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/load_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/__pycache__/tracker_config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/tracker_config.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/util/anchors.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import numpy as np
7 | import math
8 | from util.bbox_helper import center2corner, corner2center
9 |
10 |
11 | class Anchors:
12 | def __init__(self, cfg):
13 | self.stride = 8
14 | self.ratios = [0.33, 0.5, 1, 2, 3]
15 | self.scales = [8]
16 | self.round_dight = 0
17 | self.image_center = 0
18 | self.size = 0
19 |
20 | self.__dict__.update(cfg)
21 |
22 | self.anchor_num = len(self.scales) * len(self.ratios)
23 | self.anchors = None # in single position (anchor_num*4)
24 | self.all_anchors = None # in all position 2*(4*anchor_num*h*w)
25 | self.generate_anchors()
26 |
27 | def generate_anchors(self):
28 | self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)
29 |
30 | size = self.stride * self.stride
31 | count = 0
32 | for r in self.ratios:
33 | if self.round_dight > 0:
34 | ws = round(math.sqrt(size*1. / r), self.round_dight)
35 | hs = round(ws * r, self.round_dight)
36 | else:
37 | ws = int(math.sqrt(size*1. / r))
38 | hs = int(ws * r)
39 |
40 | for s in self.scales:
41 | w = ws * s
42 | h = hs * s
43 | self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:]
44 | count += 1
45 |
46 | def generate_all_anchors(self, im_c, size):
47 | if self.image_center == im_c and self.size == size:
48 | return False
49 | self.image_center = im_c
50 | self.size = size
51 |
52 | a0x = im_c - size // 2 * self.stride
53 | ori = np.array([a0x] * 4, dtype=np.float32)
54 | zero_anchors = self.anchors + ori
55 |
56 | x1 = zero_anchors[:, 0]
57 | y1 = zero_anchors[:, 1]
58 | x2 = zero_anchors[:, 2]
59 | y2 = zero_anchors[:, 3]
60 |
61 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), [x1, y1, x2, y2])
62 | cx, cy, w, h = corner2center([x1, y1, x2, y2])
63 |
64 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride
65 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride
66 |
67 | cx = cx + disp_x
68 | cy = cy + disp_y
69 |
70 | # broadcast
71 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32)
72 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h])
73 | x1, y1, x2, y2 = center2corner([cx, cy, w, h])
74 |
75 | self.all_anchors = np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h])
76 | return True
77 |
78 |
79 |
--------------------------------------------------------------------------------
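A short usage sketch (run from the repository root so that 'util' resolves as a package, since this module imports util.bbox_helper); the cfg values mirror the defaults set in __init__.

    import sys
    sys.path.append('pytracking')
    from util.anchors import Anchors

    anchors = Anchors({'stride': 8, 'ratios': [0.33, 0.5, 1, 2, 3], 'scales': [8]})
    anchors.generate_all_anchors(im_c=127, size=25)
    corners, centers = anchors.all_anchors   # two arrays, each (4, anchor_num, 25, 25)
--------------------------------------------------------------------------------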
/pytracking/util/bbox_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import numpy as np
7 | from collections import namedtuple
8 |
9 | Corner = namedtuple('Corner', 'x1 y1 x2 y2')
10 | BBox = Corner
11 | Center = namedtuple('Center', 'x y w h')
12 |
13 |
14 | def corner2center(corner):
15 | """
16 | :param corner: Corner or np.array 4*N
17 | :return: Center or 4 np.array N
18 | """
19 | if isinstance(corner, Corner):
20 | x1, y1, x2, y2 = corner
21 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1))
22 | else:
23 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
24 | x = (x1 + x2) * 0.5
25 | y = (y1 + y2) * 0.5
26 | w = x2 - x1
27 | h = y2 - y1
28 | return x, y, w, h
29 |
30 |
31 | def center2corner(center):
32 | """
33 | :param center: Center or np.array 4*N
34 | :return: Corner or np.array 4*N
35 | """
36 | if isinstance(center, Center):
37 | x, y, w, h = center
38 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5)
39 | else:
40 | x, y, w, h = center[0], center[1], center[2], center[3]
41 | x1 = x - w * 0.5
42 | y1 = y - h * 0.5
43 | x2 = x + w * 0.5
44 | y2 = y + h * 0.5
45 | return x1, y1, x2, y2
46 |
47 |
48 | def cxy_wh_2_rect(pos, sz):
49 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) # 0-index
50 |
51 |
52 | def get_axis_aligned_bbox(region):
53 | nv = region.size
54 | if nv == 8:
55 | cx = np.mean(region[0::2])
56 | cy = np.mean(region[1::2])
57 | x1 = min(region[0::2])
58 | x2 = max(region[0::2])
59 | y1 = min(region[1::2])
60 | y2 = max(region[1::2])
61 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6])
62 | A2 = (x2 - x1) * (y2 - y1)
63 | s = np.sqrt(A1 / A2)
64 | w = s * (x2 - x1) + 1
65 | h = s * (y2 - y1) + 1
66 | else:
67 | x = region[0]
68 | y = region[1]
69 | w = region[2]
70 | h = region[3]
71 | cx = x+w/2
72 | cy = y+h/2
73 |
74 | return cx, cy, w, h
75 |
76 |
77 |
--------------------------------------------------------------------------------
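A quick round-trip check of the two conversions above; it relies only on the functions defined in this file (the sys.path tweak assumes the repository root as working directory).

    import sys
    sys.path.append('pytracking')
    from util.bbox_helper import Corner, corner2center, center2corner

    c = corner2center(Corner(10, 20, 50, 100))     # Center(x=30.0, y=60.0, w=40, h=80)
    assert center2corner(c) == Corner(10.0, 20.0, 50.0, 100.0)
--------------------------------------------------------------------------------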
/pytracking/util/benchmark_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | from os.path import join, realpath, dirname, exists, isdir
7 | from os import listdir
8 | import logging
9 | import glob
10 | import numpy as np
11 | import json
12 | from collections import OrderedDict
13 |
14 |
15 | def get_dataset_zoo():
16 | root = realpath(join(dirname(__file__), '../data'))
17 | zoos = listdir(root)
18 |
19 | def valid(x):
20 | y = join(root, x)
21 | if not isdir(y): return False
22 |
23 | return exists(join(y, 'list.txt')) \
24 | or exists(join(y, 'train', 'meta.json'))\
25 | or exists(join(y, 'ImageSets', '2016', 'val.txt'))
26 |
27 | zoos = list(filter(valid, zoos))
28 | return zoos
29 |
30 |
31 | dataset_zoo = get_dataset_zoo()
32 |
33 |
34 | def load_dataset(dataset):
35 | info = OrderedDict()
36 | if 'VOT' in dataset:
37 | base_path = join(realpath(dirname(__file__)), '../data', dataset)
38 | if not exists(base_path):
39 | logging.error("Please download test dataset!!!")
40 | exit()
41 | list_path = join(base_path, 'list.txt')
42 | with open(list_path) as f:
43 | videos = [v.strip() for v in f.readlines()]
44 | for video in videos:
45 | video_path = join(base_path, video)
46 | image_path = join(video_path, '*.jpg')
47 | image_files = sorted(glob.glob(image_path))
48 | if len(image_files) == 0: # VOT2018
49 | image_path = join(video_path, 'color', '*.jpg')
50 | image_files = sorted(glob.glob(image_path))
51 | gt_path = join(video_path, 'groundtruth.txt')
52 | gt = np.loadtxt(gt_path, delimiter=',').astype(np.float64)
53 | if gt.shape[1] == 4:
54 | gt = np.column_stack((gt[:, 0], gt[:, 1], gt[:, 0], gt[:, 1] + gt[:, 3]-1,
55 | gt[:, 0] + gt[:, 2]-1, gt[:, 1] + gt[:, 3]-1, gt[:, 0] + gt[:, 2]-1, gt[:, 1]))
56 | info[video] = {'image_files': image_files, 'gt': gt, 'name': video}
57 | elif 'DAVIS' in dataset:
58 | base_path = join(realpath(dirname(__file__)), '../data', 'DAVIS')
59 | list_path = join(realpath(dirname(__file__)), '../data', 'DAVIS', 'ImageSets', dataset[-4:], 'val.txt')
60 | with open(list_path) as f:
61 | videos = [v.strip() for v in f.readlines()]
62 | for video in videos:
63 | info[video] = {}
64 | info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
65 | info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
66 | info[video]['name'] = video
67 | elif 'ytb_vos' in dataset:
68 | base_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid')
69 | json_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid', 'meta.json')
70 | meta = json.load(open(json_path, 'r'))
71 | meta = meta['videos']
72 | info = dict()
73 | for v in meta.keys():
74 | objects = meta[v]['objects']
75 | frames = []
76 | anno_frames = []
77 | info[v] = dict()
78 | for obj in objects:
79 | frames += objects[obj]['frames']
80 | anno_frames += [objects[obj]['frames'][0]]
81 | frames = sorted(np.unique(frames))
82 | info[v]['anno_files'] = [join(base_path, 'Annotations', v, im_f+'.png') for im_f in frames]
83 | info[v]['anno_init_files'] = [join(base_path, 'Annotations', v, im_f + '.png') for im_f in anno_frames]
84 | info[v]['image_files'] = [join(base_path, 'JPEGImages', v, im_f+'.jpg') for im_f in frames]
85 | info[v]['name'] = v
86 |
87 | info[v]['start_frame'] = dict()
88 | info[v]['end_frame'] = dict()
89 | for obj in objects:
90 | start_file = objects[obj]['frames'][0]
91 | end_file = objects[obj]['frames'][-1]
92 | info[v]['start_frame'][obj] = frames.index(start_file)
93 | info[v]['end_frame'][obj] = frames.index(end_file)
94 | else:
95 | logging.error('Dataset not supported')
96 | exit()
97 | return info
98 |
--------------------------------------------------------------------------------
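A sketch of how the dictionary returned by load_dataset() is typically consumed; the dataset name assumes a VOT copy laid out as the branch above expects (note the module also scans ../data at import time via get_dataset_zoo()).

    info = load_dataset('VOT2018')
    for name, video in info.items():
        frames = video['image_files']   # sorted list of frame paths
        gt = video['gt']                # (num_frames, 8) polygon ground truth
        print(name, len(frames), gt.shape)
--------------------------------------------------------------------------------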
/pytracking/util/config_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import json
7 | from os.path import exists
8 |
9 |
10 | def load_config(args):
11 | assert exists(args.config), '"{}" not exists'.format(args.config)
12 | config = json.load(open(args.config))
13 |
14 | # deal with network
15 | if 'network' not in config:
16 | print('Warning: "network" missing in config. This will be an error in the next version')
17 |
18 | config['network'] = {}
19 |
20 | if not args.arch:
21 | raise Exception('no arch provided')
22 | args.arch = config['network']['arch']
23 |
24 | return config
25 |
26 |
--------------------------------------------------------------------------------
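A hedged sketch of the smallest JSON file load_config() accepts; any further keys depend on the chosen network and are omitted, and 'Custom' is a placeholder architecture name.

    import json

    with open('config.json', 'w') as f:
        json.dump({'network': {'arch': 'Custom'}}, f, indent=4)
--------------------------------------------------------------------------------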
/pytracking/util/load_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | logger = logging.getLogger('global')
4 |
5 |
6 | def check_keys(model, pretrained_state_dict):
7 | ckpt_keys = set(pretrained_state_dict.keys())
8 | model_keys = set(model.state_dict().keys())
9 | used_pretrained_keys = model_keys & ckpt_keys
10 | unused_pretrained_keys = ckpt_keys - model_keys
11 | missing_keys = model_keys - ckpt_keys
12 | if len(missing_keys) > 0:
13 | logger.info('[Warning] missing keys: {}'.format(missing_keys))
14 | logger.info('missing keys:{}'.format(len(missing_keys)))
15 | if len(unused_pretrained_keys) > 0:
16 | logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys))
17 | logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
18 | logger.info('used keys:{}'.format(len(used_pretrained_keys)))
19 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
20 | return True
21 |
22 |
23 | def remove_prefix(state_dict, prefix):
24 | ''' Old-style models store all parameter names with a shared 'module.' prefix '''
25 | logger.info('remove prefix \'{}\''.format(prefix))
26 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
27 | return {f(key): value for key, value in state_dict.items()}
28 |
29 |
30 | def load_pretrain(model, pretrained_path):
31 | logger.info('load pretrained model from {}'.format(pretrained_path))
32 | device = torch.cuda.current_device()
33 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
34 | if "state_dict" in pretrained_dict.keys():
35 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
36 | else:
37 | pretrained_dict = remove_prefix(pretrained_dict, 'module.')
38 |
39 | try:
40 | check_keys(model, pretrained_dict)
41 | except AssertionError:
42 | logger.info('[Warning]: using pretrain as features. Adding "features." as prefix')
43 | new_dict = {}
44 | for k, v in pretrained_dict.items():
45 | k = 'features.' + k
46 | new_dict[k] = v
47 | pretrained_dict = new_dict
48 | check_keys(model, pretrained_dict)
49 | model.load_state_dict(pretrained_dict, strict=False)
50 | return model
51 |
--------------------------------------------------------------------------------
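A minimal sketch of restoring a checkpoint with load_pretrain(); the network and checkpoint path are placeholders, and a CUDA device is assumed because the helper maps the weights onto torch.cuda.current_device().

    import torchvision

    model = torchvision.models.resnet18()
    model = load_pretrain(model, 'checkpoints/resnet18.pth')   # hypothetical path
    model.eval()
--------------------------------------------------------------------------------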
/pytracking/util/log_helper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | from __future__ import division
7 |
8 | import os
9 | import logging
10 | import sys
11 |
12 | if hasattr(sys, 'frozen'): # support for py2exe
13 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:])
14 | elif __file__[-4:].lower() in ['.pyc', '.pyo']:
15 | _srcfile = __file__[:-4] + '.py'
16 | else:
17 | _srcfile = __file__
18 | _srcfile = os.path.normcase(_srcfile)
19 |
20 |
21 | logs = set()
22 |
23 |
24 | class Filter:
25 | def __init__(self, flag):
26 | self.flag = flag
27 |
28 | def filter(self, x): return self.flag
29 |
30 |
31 | class Dummy:
32 | def __init__(self, *arg, **kwargs):
33 | pass
34 |
35 | def __getattr__(self, arg):
36 | def dummy(*args, **kwargs): pass
37 | return dummy
38 |
39 |
40 | def get_format(logger, level):
41 | if 'SLURM_PROCID' in os.environ:
42 | rank = int(os.environ['SLURM_PROCID'])
43 |
44 | if level == logging.INFO:
45 | logger.addFilter(Filter(rank == 0))
46 | else:
47 | rank = 0
48 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank)
49 | formatter = logging.Formatter(format_str)
50 | return formatter
51 |
52 |
53 | def init_log(name, level = logging.INFO, format_func=get_format):
54 | if (name, level) in logs: return
55 | logs.add((name, level))
56 | logger = logging.getLogger(name)
57 | logger.setLevel(level)
58 | ch = logging.StreamHandler()
59 | ch.setLevel(level)
60 | formatter = format_func(logger, level)
61 | ch.setFormatter(formatter)
62 | logger.addHandler(ch)
63 | return logger
64 |
65 |
66 | def add_file_handler(name, log_file, level = logging.INFO):
67 | logger = logging.getLogger(name)
68 | fh = logging.FileHandler(log_file)
69 | fh.setFormatter(get_format(logger, level))
70 | logger.addHandler(fh)
71 |
72 |
73 | init_log('global')
74 |
75 |
--------------------------------------------------------------------------------
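Typical use of the helpers above; init_log('global') already runs at import time, so the named logger can be fetched directly (the log file name is arbitrary).

    import logging

    logger = logging.getLogger('global')
    add_file_handler('global', 'train.log', logging.INFO)   # additionally log to a file
    logger.info('logging initialised')
--------------------------------------------------------------------------------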
/pytracking/util/pysot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/__init__.py
--------------------------------------------------------------------------------
/pytracking/util/pysot/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from .vot import VOTDataset
10 |
11 |
12 | class DatasetFactory(object):
13 | @staticmethod
14 | def create_dataset(**kwargs):
15 | """
16 | Args:
17 | name: dataset name 'VOT2018', 'VOT2016'
18 | dataset_root: dataset root
19 | Return:
20 | dataset
21 | """
22 | assert 'name' in kwargs, "should provide dataset name"
23 | name = kwargs['name']
24 | if 'VOT2018' == name or 'VOT2016' == name:
25 | dataset = VOTDataset(**kwargs)
26 | else:
27 | raise Exception("unknow dataset {}".format(kwargs['name']))
28 | return dataset
29 |
30 |
--------------------------------------------------------------------------------
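A hedged example of building a dataset through the factory; the dataset_root is a placeholder, and the keyword arguments are assumed to match what VOTDataset (imported from .vot, not shown here) expects in the upstream pysot toolkit.

    dataset = DatasetFactory.create_dataset(name='VOT2018',
                                            dataset_root='data/VOT2018')
    for video in dataset:
        print(video.name, len(video))
--------------------------------------------------------------------------------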
/pytracking/util/pysot/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | class Dataset(object):
10 | def __init__(self, name, dataset_root):
11 | self.name = name
12 | self.dataset_root = dataset_root
13 | self.videos = None
14 |
15 | def __getitem__(self, idx):
16 | if isinstance(idx, str):
17 | return self.videos[idx]
18 | elif isinstance(idx, int):
19 | return self.videos[sorted(list(self.videos.keys()))[idx]]
20 |
21 | def __len__(self):
22 | return len(self.videos)
23 |
24 | def __iter__(self):
25 | keys = sorted(list(self.videos.keys()))
26 | for key in keys:
27 | yield self.videos[key]
28 |
29 | def set_tracker(self, path, tracker_names):
30 | """
31 | Args:
32 | path: path to tracker results,
33 | tracker_names: list of tracker name
34 | """
35 | self.tracker_path = path
36 | self.tracker_names = tracker_names
37 |
--------------------------------------------------------------------------------
/pytracking/util/pysot/datasets/video.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | import os
10 | import cv2
11 |
12 | from glob import glob
13 |
14 |
15 | class Video(object):
16 | def __init__(self, name, root, video_dir, init_rect, img_names,
17 | gt_rect, attr):
18 | self.name = name
19 | self.video_dir = video_dir
20 | self.init_rect = init_rect
21 | self.gt_traj = gt_rect
22 | self.attr = attr
23 | self.pred_trajs = {}
24 | self.img_names = [os.path.join(root, x) for x in img_names]
25 | self.imgs = None
26 |
27 | def load_tracker(self, path, tracker_names=None, store=True):
28 | """
29 | Args:
30 | path(str): path to result
31 | tracker_names(list): names of trackers
32 | """
33 | if not tracker_names:
34 | tracker_names = [x.split('/')[-1] for x in glob(path)
35 | if os.path.isdir(x)]
36 | if isinstance(tracker_names, str):
37 | tracker_names = [tracker_names]
38 | for name in tracker_names:
39 | traj_file = os.path.join(path, name, self.name+'.txt')
40 | if os.path.exists(traj_file):
41 | with open(traj_file, 'r') as f :
42 | pred_traj = [list(map(float, x.strip().split(',')))
43 | for x in f.readlines()]
44 | if len(pred_traj) != len(self.gt_traj):
45 | print(name, len(pred_traj), len(self.gt_traj), self.name)
46 | if store:
47 | self.pred_trajs[name] = pred_traj
48 | else:
49 | return pred_traj
50 | else:
51 | print(traj_file)
52 | self.tracker_names = list(self.pred_trajs.keys())
53 |
54 | def load_img(self):
55 | if self.imgs is None:
56 | self.imgs = [cv2.imread(x)
57 | for x in self.img_names]
58 | self.width = self.imgs[0].shape[1]
59 | self.height = self.imgs[0].shape[0]
60 |
61 | def free_img(self):
62 | self.imgs = None
63 |
64 | def __len__(self):
65 | return len(self.img_names)
66 |
67 | def __getitem__(self, idx):
68 | if self.imgs is None:
69 | return cv2.imread(self.img_names[idx]), \
70 | self.gt_traj[idx]
71 | else:
72 | return self.imgs[idx], self.gt_traj[idx]
73 |
74 | def __iter__(self):
75 | for i in range(len(self.img_names)):
76 | if self.imgs is not None:
77 | yield self.imgs[i], self.gt_traj[i]
78 | else:
79 | yield cv2.imread(self.img_names[i]), \
80 | self.gt_traj[i]
81 |
--------------------------------------------------------------------------------
/pytracking/util/pysot/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from .ar_benchmark import AccuracyRobustnessBenchmark
10 | from .eao_benchmark import EAOBenchmark
11 |
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from . import region
10 | from .statistics import *
11 |
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/region.o
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/c_region.pxd:
--------------------------------------------------------------------------------
1 | cdef extern from "src/region.h":
2 | ctypedef enum region_type "RegionType":
3 | EMPTY
4 | SPECIAL
5 | RECTANGLE
6 | POLYGON
7 | MASK
8 |
9 | ctypedef struct region_bounds:
10 | float top
11 | float bottom
12 | float left
13 | float right
14 |
15 | ctypedef struct region_rectangle:
16 | float x
17 | float y
18 | float width
19 | float height
20 |
21 | # ctypedef struct region_mask:
22 | # int x
23 | # int y
24 | # int width
25 | # int height
26 | # char *data
27 |
28 | ctypedef struct region_polygon:
29 | int count
30 | float *x
31 | float *y
32 |
33 | ctypedef union region_container_data:
34 | region_rectangle rectangle
35 | region_polygon polygon
36 | # region_mask mask
37 | int special
38 |
39 | ctypedef struct region_container:
40 | region_type type
41 | region_container_data data
42 |
43 | # ctypedef struct region_overlap:
44 | # float overlap
45 | # float only1
46 | # float only2
47 |
48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds)
49 |
50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds)
51 |
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/misc.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | import numpy as np
10 |
11 | def determine_thresholds(confidence, resolution=100):
12 | """choose threshold according to confidence
13 |
14 | Args:
15 | confidence: list or numpy array or numpy array
16 | reolution: number of threshold to choose
17 |
18 | Restures:
19 | threshold: numpy array
20 | """
21 | if isinstance(confidence, list):
22 | confidence = np.array(confidence)
23 | confidence = confidence.flatten()
24 | confidence = confidence[~np.isnan(confidence)]
25 | confidence.sort()
26 |
27 | assert len(confidence) > resolution and resolution > 2
28 |
29 | thresholds = np.ones((resolution))
30 | thresholds[0] = - np.inf
31 | thresholds[-1] = np.inf
32 | delta = np.floor(len(confidence) / (resolution - 2))
33 | idxs = np.linspace(delta, len(confidence)-delta, resolution-2, dtype=np.int32)
34 | thresholds[1:-1] = confidence[idxs]
35 | return thresholds
36 |
--------------------------------------------------------------------------------
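A quick synthetic check of determine_thresholds(): it returns resolution-many cut points spanning the sorted confidences, with -inf and +inf at the ends.

    import numpy as np

    scores = np.random.rand(1000)
    th = determine_thresholds(scores, resolution=100)
    assert th.shape == (100,) and th[0] == -np.inf and th[-1] == np.inf
--------------------------------------------------------------------------------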
/pytracking/util/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/setup.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from distutils.core import setup
10 | from distutils.extension import Extension
11 | from Cython.Build import cythonize
12 |
13 | setup(
14 | ext_modules = cythonize([Extension("region", ["region.pyx", "src/region.c"])]),
15 | )
16 |
17 |
--------------------------------------------------------------------------------
/pytracking/util/pysot/utils/src/region.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */
2 |
3 | #ifndef _REGION_H_
4 | #define _REGION_H_
5 |
6 | #ifdef TRAX_STATIC_DEFINE
7 | # define __TRAX_EXPORT
8 | #else
9 | # ifndef __TRAX_EXPORT
10 | # if defined(_MSC_VER)
11 | # ifdef trax_EXPORTS
12 | /* We are building this library */
13 | # define __TRAX_EXPORT __declspec(dllexport)
14 | # else
15 | /* We are using this library */
16 | # define __TRAX_EXPORT __declspec(dllimport)
17 | # endif
18 | # elif defined(__GNUC__)
19 | # ifdef trax_EXPORTS
20 | /* We are building this library */
21 | # define __TRAX_EXPORT __attribute__((visibility("default")))
22 | # else
23 | /* We are using this library */
24 | # define __TRAX_EXPORT __attribute__((visibility("default")))
25 | # endif
26 | # endif
27 | # endif
28 | #endif
29 |
30 | #ifndef MAX
31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b))
32 | #endif
33 |
34 | #ifndef MIN
35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b))
36 | #endif
37 |
38 | #define TRAX_DEFAULT_CODE 0
39 |
40 | #define REGION_LEGACY_RASTERIZATION 1
41 |
42 | #ifdef __cplusplus
43 | extern "C" {
44 | #endif
45 |
46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type;
47 |
48 | typedef struct region_bounds {
49 |
50 | float top;
51 | float bottom;
52 | float left;
53 | float right;
54 |
55 | } region_bounds;
56 |
57 | typedef struct region_polygon {
58 |
59 | int count;
60 |
61 | float* x;
62 | float* y;
63 |
64 | } region_polygon;
65 |
66 | typedef struct region_mask {
67 |
68 | int x;
69 | int y;
70 |
71 | int width;
72 | int height;
73 |
74 | char* data;
75 |
76 | } region_mask;
77 |
78 | typedef struct region_rectangle {
79 |
80 | float x;
81 | float y;
82 | float width;
83 | float height;
84 |
85 | } region_rectangle;
86 |
87 | typedef struct region_container {
88 | enum region_type type;
89 | union {
90 | region_rectangle rectangle;
91 | region_polygon polygon;
92 | region_mask mask;
93 | int special;
94 | } data;
95 | } region_container;
96 |
97 | typedef struct region_overlap {
98 |
99 | float overlap;
100 | float only1;
101 | float only2;
102 |
103 | } region_overlap;
104 |
105 | extern const region_bounds region_no_bounds;
106 |
107 | __TRAX_EXPORT int region_set_flags(int mask);
108 |
109 | __TRAX_EXPORT int region_clear_flags(int mask);
110 |
111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds);
112 |
113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds);
114 |
115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom);
116 |
117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region);
118 |
119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region);
120 |
121 | __TRAX_EXPORT char* region_string(region_container* region);
122 |
123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region);
124 |
125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type);
126 |
127 | __TRAX_EXPORT void region_release(region_container** region);
128 |
129 | __TRAX_EXPORT region_container* region_create_special(int code);
130 |
131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height);
132 |
133 | __TRAX_EXPORT region_container* region_create_polygon(int count);
134 |
135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y);
136 |
137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height);
138 |
139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height);
140 |
141 | #ifdef __cplusplus
142 | }
143 | #endif
144 |
145 | #endif
146 |
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from . import region
10 |
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/region.o
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/c_region.pxd:
--------------------------------------------------------------------------------
1 | cdef extern from "region.h":
2 | ctypedef enum region_type "RegionType":
3 | EMPTY
4 | SPECIAL
5 | RECTANGLE
6 | POLYGON
7 | MASK
8 |
9 | ctypedef struct region_bounds:
10 | float top
11 | float bottom
12 | float left
13 | float right
14 |
15 | ctypedef struct region_rectangle:
16 | float x
17 | float y
18 | float width
19 | float height
20 |
21 | # ctypedef struct region_mask:
22 | # int x
23 | # int y
24 | # int width
25 | # int height
26 | # char *data
27 |
28 | ctypedef struct region_polygon:
29 | int count
30 | float *x
31 | float *y
32 |
33 | ctypedef union region_container_data:
34 | region_rectangle rectangle
35 | region_polygon polygon
36 | # region_mask mask
37 | int special
38 |
39 | ctypedef struct region_container:
40 | region_type type
41 | region_container_data data
42 |
43 | # ctypedef struct region_overlap:
44 | # float overlap
45 | # float only1
46 | # float only2
47 |
48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds)
49 |
50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds)
51 |
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/setup.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Python Single Object Tracking Evaluation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Fangyi Zhang
5 | # @author fangyi.zhang@vipl.ict.ac.cn
6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git
7 | # Revised for SiamMask by foolwood
8 | # --------------------------------------------------------
9 | from distutils.core import setup
10 | from distutils.extension import Extension
11 | from Cython.Build import cythonize
12 |
13 | setup(
14 | ext_modules = cythonize([Extension("region", ["region.pyx"])])
15 | )
16 |
17 |
--------------------------------------------------------------------------------
/pytracking/util/pyvotkit/src/region.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */
2 |
3 | #ifndef _REGION_H_
4 | #define _REGION_H_
5 |
6 | #ifdef TRAX_STATIC_DEFINE
7 | # define __TRAX_EXPORT
8 | #else
9 | # ifndef __TRAX_EXPORT
10 | # if defined(_MSC_VER)
11 | # ifdef trax_EXPORTS
12 | /* We are building this library */
13 | # define __TRAX_EXPORT __declspec(dllexport)
14 | # else
15 | /* We are using this library */
16 | # define __TRAX_EXPORT __declspec(dllimport)
17 | # endif
18 | # elif defined(__GNUC__)
19 | # ifdef trax_EXPORTS
20 | /* We are building this library */
21 | # define __TRAX_EXPORT __attribute__((visibility("default")))
22 | # else
23 | /* We are using this library */
24 | # define __TRAX_EXPORT __attribute__((visibility("default")))
25 | # endif
26 | # endif
27 | # endif
28 | #endif
29 |
30 | #ifndef MAX
31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b))
32 | #endif
33 |
34 | #ifndef MIN
35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b))
36 | #endif
37 |
38 | #define TRAX_DEFAULT_CODE 0
39 |
40 | #define REGION_LEGACY_RASTERIZATION 1
41 |
42 | #ifdef __cplusplus
43 | extern "C" {
44 | #endif
45 |
46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type;
47 |
48 | typedef struct region_bounds {
49 |
50 | float top;
51 | float bottom;
52 | float left;
53 | float right;
54 |
55 | } region_bounds;
56 |
57 | typedef struct region_polygon {
58 |
59 | int count;
60 |
61 | float* x;
62 | float* y;
63 |
64 | } region_polygon;
65 |
66 | typedef struct region_mask {
67 |
68 | int x;
69 | int y;
70 |
71 | int width;
72 | int height;
73 |
74 | char* data;
75 |
76 | } region_mask;
77 |
78 | typedef struct region_rectangle {
79 |
80 | float x;
81 | float y;
82 | float width;
83 | float height;
84 |
85 | } region_rectangle;
86 |
87 | typedef struct region_container {
88 | enum region_type type;
89 | union {
90 | region_rectangle rectangle;
91 | region_polygon polygon;
92 | region_mask mask;
93 | int special;
94 | } data;
95 | } region_container;
96 |
97 | typedef struct region_overlap {
98 |
99 | float overlap;
100 | float only1;
101 | float only2;
102 |
103 | } region_overlap;
104 |
105 | extern const region_bounds region_no_bounds;
106 |
107 | __TRAX_EXPORT int region_set_flags(int mask);
108 |
109 | __TRAX_EXPORT int region_clear_flags(int mask);
110 |
111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds);
112 |
113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds);
114 |
115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom);
116 |
117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region);
118 |
119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region);
120 |
121 | __TRAX_EXPORT char* region_string(region_container* region);
122 |
123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region);
124 |
125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type);
126 |
127 | __TRAX_EXPORT void region_release(region_container** region);
128 |
129 | __TRAX_EXPORT region_container* region_create_special(int code);
130 |
131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height);
132 |
133 | __TRAX_EXPORT region_container* region_create_polygon(int count);
134 |
135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y);
136 |
137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height);
138 |
139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height);
140 |
141 | #ifdef __cplusplus
142 | }
143 | #endif
144 |
145 | #endif
146 |
--------------------------------------------------------------------------------
/pytracking/util/tracker_config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SiamMask
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | from __future__ import division
7 | from util.anchors import Anchors
8 |
9 |
10 | class TrackerConfig(object):
11 | # These are the default hyper-params for SiamMask
12 | penalty_k = 0.04
13 | window_influence = 0.42
14 | lr = 0.25
15 | seg_thr = 0.3 # for mask
16 | windowing = 'cosine' # to penalize large displacements [cosine/uniform]
17 | # Params from the network architecture, have to be consistent with the training
18 | exemplar_size = 127 # input z size
19 | instance_size = 255 # input x size (search region)
20 | total_stride = 8
21 | out_size = 63 # for mask
22 | base_size = 8
23 | score_size = (instance_size-exemplar_size)//total_stride+1+base_size
24 | context_amount = 0.5 # context amount for the exemplar
25 | ratios = [0.33, 0.5, 1, 2, 3]
26 | scales = [8, ]
27 | anchor_num = len(ratios) * len(scales)
28 | round_dight = 0
29 | anchor = []
30 |
31 | def update(self, newparam=None, anchors=None):
32 | if newparam:
33 | for key, value in newparam.items():
34 | setattr(self, key, value)
35 | if anchors is not None:
36 | if isinstance(anchors, dict):
37 | anchors = Anchors(anchors)
38 | if isinstance(anchors, Anchors):
39 | self.total_stride = anchors.stride
40 | self.ratios = anchors.ratios
41 | self.scales = anchors.scales
42 | self.round_dight = anchors.round_dight
43 | self.renew()
44 |
45 | def renew(self):
46 | self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1 + self.base_size
47 | self.anchor_num = len(self.ratios) * len(self.scales)
48 |
49 |
50 |
51 |
52 |
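A short sketch of how these defaults interact: `update()` overwrites any hyper-parameter passed via `newparam`, optionally pulls the anchor geometry from an `Anchors` object, and finishes by calling `renew()`, which recomputes the derived `score_size` and `anchor_num`. The import below is illustrative and assumes `pytracking/` is on `PYTHONPATH`, so that `util` resolves to `pytracking/util` (matching the module's own `from util.anchors import Anchors`).

```python
# Illustrative usage sketch; import path assumes pytracking/ is on PYTHONPATH.
from util.tracker_config import TrackerConfig

cfg = TrackerConfig()
print(cfg.score_size)                  # (255 - 127) // 8 + 1 + 8 = 25 with the defaults

# Override a hyper-parameter; update() ends by calling renew(), so the
# derived quantities are recomputed automatically.
cfg.update(newparam={'instance_size': 287})
print(cfg.score_size, cfg.anchor_num)  # (287 - 127) // 8 + 1 + 8 = 29, and 5 * 1 = 5 anchors
```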
--------------------------------------------------------------------------------
/pytracking/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # from .evaluation import *
2 | from .params import *
--------------------------------------------------------------------------------
/pytracking/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/utils/__pycache__/params.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/utils/__pycache__/params.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/utils/__pycache__/plotting.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/utils/__pycache__/plotting.cpython-36.pyc
--------------------------------------------------------------------------------
/pytracking/utils/gdrive_download:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script is taken from https://www.matthuisman.nz/2019/01/download-google-drive-files-wget-curl.html
4 |
5 | url=$1
6 | filename=$2
7 |
8 | [ -z "$url" ] && echo A URL or ID is required first argument && exit 1
9 |
10 | fileid=""
11 | declare -a patterns=("s/.*\/file\/d\/\(.*\)\/.*/\1/p" "s/.*id\=\(.*\)/\1/p" "s/\(.*\)/\1/p")
12 | for i in "${patterns[@]}"
13 | do
14 | fileid=$(echo $url | sed -n $i)
15 | [ ! -z "$fileid" ] && break
16 | done
17 |
18 | [ -z "$fileid" ] && echo Could not find Google ID && exit 1
19 |
20 | echo File ID: $fileid
21 |
22 | tmp_file="$filename.$$.file"
23 | tmp_cookies="$filename.$$.cookies"
24 | tmp_headers="$filename.$$.headers"
25 |
26 | url='https://docs.google.com/uc?export=download&id='$fileid
27 | echo Downloading: "$url > $tmp_file"
28 | wget --save-cookies "$tmp_cookies" -q -S -O - $url 2> "$tmp_headers" 1> "$tmp_file"
29 |
30 | if [[ ! $(find "$tmp_file" -type f -size +10000c 2>/dev/null) ]]; then
31 | confirm=$(cat "$tmp_file" | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1/p')
32 | fi
33 |
34 | if [ ! -z "$confirm" ]; then
35 | url='https://docs.google.com/uc?export=download&id='$fileid'&confirm='$confirm
36 | echo Downloading: "$url > $tmp_file"
37 | wget --load-cookies "$tmp_cookies" -q -S -O - $url 2> "$tmp_headers" 1> "$tmp_file"
38 | fi
39 |
40 | [ -z "$filename" ] && filename=$(cat "$tmp_headers" | sed -rn 's/.*filename=\"(.*)\".*/\1/p')
41 | [ -z "$filename" ] && filename="google_drive.file"
42 |
43 | echo Moving: "$tmp_file > $filename"
44 |
45 | mv "$tmp_file" "$filename"
46 |
47 | rm -f "$tmp_cookies" "$tmp_headers"
48 |
49 | echo Saved: "$filename"
50 | echo DONE!
51 |
52 | exit 0
53 |
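The script implements the two-step Google Drive download flow: request the file, and if the response is the small HTML warning page, extract the confirm token and re-request with it appended. Below is a minimal Python sketch of the same flow, assuming the `requests` package is available; the function name and behaviour are illustrative and not part of the repository.

```python
import re
import requests

def gdrive_download(file_id: str, filename: str) -> None:
    """Download a (possibly large) Google Drive file, handling the confirm token."""
    url = 'https://docs.google.com/uc?export=download&id=' + file_id
    with requests.Session() as session:
        resp = session.get(url)
        # Large files come back as an HTML warning page containing a confirm token.
        match = re.search(r'confirm=([0-9A-Za-z_]+)', resp.text)
        if match:
            resp = session.get(url + '&confirm=' + match.group(1), stream=True)
        with open(filename, 'wb') as f:
            for chunk in resp.iter_content(chunk_size=32768):
                f.write(chunk)
```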
--------------------------------------------------------------------------------
/pytracking/utils/params.py:
--------------------------------------------------------------------------------
1 | from pytracking import TensorList
2 | import random
3 |
4 |
5 | class TrackerParams:
6 | """Class for tracker parameters."""
7 | def free_memory(self):
8 | for a in dir(self):
9 | if not a.startswith('__') and hasattr(getattr(self, a), 'free_memory'):
10 | getattr(self, a).free_memory()
11 |
12 |
13 | class FeatureParams:
14 | """Class for feature specific parameters"""
15 | def __init__(self, *args, **kwargs):
16 | if len(args) > 0:
17 |             raise ValueError('FeatureParams does not take positional arguments')
18 |
19 | for name, val in kwargs.items():
20 | if isinstance(val, list):
21 | setattr(self, name, TensorList(val))
22 | else:
23 | setattr(self, name, val)
24 |
25 |
26 | def Choice(*args):
27 | """Can be used to sample random parameter values."""
28 | return random.choice(args)
29 |
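A brief sketch of how these helpers are typically combined in a parameter file: plain attributes are set directly on `TrackerParams`, list-valued per-feature options are wrapped into a `TensorList` by `FeatureParams`, and `Choice` picks one value at parameter-construction time. The attribute names below are arbitrary examples, not fields required by the classes.

```python
# Illustrative sketch; attribute names are examples only, not required fields.
from pytracking.utils.params import TrackerParams, FeatureParams, Choice

params = TrackerParams()
params.image_sample_size = 18 * 16
params.scale_factors = Choice(1.0, 1.02)        # sampled once, at definition time

fparams = FeatureParams(kernel_size=[4, 4], learning_rate=[0.01, 0.015])
print(type(fparams.kernel_size).__name__)       # TensorList: list kwargs are wrapped
print(fparams.learning_rate)                    # behaves like a list of per-feature values
```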
--------------------------------------------------------------------------------
/pytracking/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('TkAgg')  # interactive backend; requires a display
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def show_tensor(a: torch.Tensor, fig_num = None, title = None):
9 | """Display a 2D tensor.
10 | args:
11 | fig_num: Figure number.
12 | title: Title of figure.
13 | """
14 | a_np = a.squeeze().cpu().clone().detach().numpy()
15 | if a_np.ndim == 3:
16 | a_np = np.transpose(a_np, (1, 2, 0))
17 | plt.figure(fig_num)
18 | plt.tight_layout()
19 | plt.cla()
20 | plt.imshow(a_np)
21 | plt.axis('off')
22 | plt.axis('equal')
23 | if title is not None:
24 | plt.title(title)
25 | plt.draw()
26 | plt.pause(0.001)
27 |
28 |
29 | def plot_graph(a: torch.Tensor, fig_num = None, title = None):
30 | """Plot graph. Data is a 1D tensor.
31 | args:
32 | fig_num: Figure number.
33 | title: Title of figure.
34 | """
35 | a_np = a.squeeze().cpu().clone().detach().numpy()
36 | if a_np.ndim > 1:
37 |         raise ValueError('plot_graph expects a 1D tensor')
38 | plt.figure(fig_num)
39 | # plt.tight_layout()
40 | plt.cla()
41 | plt.plot(a_np)
42 | if title is not None:
43 | plt.title(title)
44 | plt.draw()
45 | plt.pause(0.001)
46 |
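A quick interactive check of the two helpers, assuming a display is available for the TkAgg backend; the tensors below are random placeholders standing in for real score maps and loss curves.

```python
# Illustrative only: random tensors stand in for real tracker outputs.
import torch
from pytracking.utils.plotting import show_tensor, plot_graph

score_map = torch.rand(1, 1, 19, 19)            # e.g. a correlation score map
show_tensor(score_map, fig_num=1, title='score map')

loss_curve = torch.randn(100).abs().cumsum(0)   # dummy monotone curve
plot_graph(loss_curve, fig_num=2, title='dummy loss curve')
```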
--------------------------------------------------------------------------------