├── INSTALL.md ├── LICENSE ├── README.md ├── figure ├── SPS.png └── flowchart.png ├── install.sh ├── ltr ├── README.md ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── actors │ ├── __init__.py │ ├── base_actor.py │ └── bbreg.py ├── admin │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── environment.cpython-36.pyc │ │ ├── loading.cpython-36.pyc │ │ ├── local.cpython-36.pyc │ │ ├── model_constructor.cpython-36.pyc │ │ ├── settings.cpython-36.pyc │ │ └── stats.cpython-36.pyc │ ├── environment.py │ ├── loading.py │ ├── local.py │ ├── model_constructor.py │ ├── settings.py │ ├── stats.py │ └── tensorboard.py ├── data │ ├── __init__.py │ ├── image_loader.py │ ├── loader.py │ ├── processing.py │ ├── processing_utils.py │ ├── sampler.py │ └── transforms.py ├── data_specs │ ├── got10k_train_split.txt │ ├── got10k_val_split.txt │ └── lasot_train_split.txt ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ ├── coco_seq.py │ ├── got10k.py │ ├── imagenetvid.py │ ├── lasot.py │ └── tracking_net.py ├── external │ └── PreciseRoIPooling │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── _assets │ │ └── prroi_visualization.png │ │ ├── pytorch │ │ ├── prroi_pool │ │ │ ├── .gitignore │ │ │ ├── __init__.py │ │ │ ├── build.py │ │ │ ├── functional.py │ │ │ ├── prroi_pool.py │ │ │ ├── src │ │ │ │ ├── prroi_pooling_gpu.c │ │ │ │ ├── prroi_pooling_gpu.h │ │ │ │ ├── prroi_pooling_gpu_impl.cu │ │ │ │ └── prroi_pooling_gpu_impl.cuh │ │ │ └── travis.sh │ │ └── tests │ │ │ └── test_prroi_pooling2d.py │ │ └── src │ │ ├── prroi_pooling_gpu_impl.cu │ │ └── prroi_pooling_gpu_impl.cuh ├── models │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── resnet.cpython-36.pyc │ │ │ └── resnet18_vggm.cpython-36.pyc │ │ ├── resnet.py │ │ └── resnet18_vggm.py │ ├── bbreg │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc 
│ │ │ ├── atom.cpython-36.pyc │ │ │ └── atom_iou_net.cpython-36.pyc │ │ ├── atom.py │ │ └── atom_iou_net.py │ └── layers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── blocks.cpython-36.pyc │ │ └── blocks.py ├── run_training.py ├── train_settings │ ├── __init__.py │ └── bbreg │ │ ├── __init__.py │ │ └── atom_default.py └── trainers │ ├── __init__.py │ ├── base_trainer.py │ └── ltr_trainer.py └── pytracking ├── README.md ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── run_tracker.cpython-36.pyc └── run_webcam.cpython-36.pyc ├── evaluation ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── data.cpython-36.pyc │ ├── environment.cpython-36.pyc │ ├── got10kdataset.cpython-36.pyc │ ├── lasotdataset.cpython-36.pyc │ ├── local.cpython-36.pyc │ ├── nfsdataset.cpython-36.pyc │ ├── otbdataset.cpython-36.pyc │ ├── running.cpython-36.pyc │ ├── tpldataset.cpython-36.pyc │ ├── tracker.cpython-36.pyc │ ├── trackingnetdataset.cpython-36.pyc │ ├── uavdataset.cpython-36.pyc │ └── votdataset.cpython-36.pyc ├── data.py ├── environment.py ├── got10kdataset.py ├── lasotdataset.py ├── local.py ├── nfsdataset.py ├── otbdataset.py ├── running.py ├── tpldataset.py ├── tracker.py ├── trackingnetdataset.py ├── uavdataset.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ ├── anchors.py │ ├── bbox_helper.py │ ├── benchmark_helper.py │ ├── config_helper.py │ ├── load_helper.py │ ├── log_helper.py │ ├── pysot │ │ ├── __init__.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── dataset.py │ │ │ ├── video.py │ │ │ └── vot.py │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── ar_benchmark.py │ │ │ └── eao_benchmark.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── build │ │ │ └── temp.linux-x86_64-3.6 │ │ │ │ ├── region.o │ │ │ │ └── src │ │ │ │ └── region.o │ │ │ ├── c_region.pxd │ │ │ ├── misc.py │ │ │ ├── region.c │ │ │ ├── region.cpython-36m-x86_64-linux-gnu.so │ │ │ ├── 
region.pyx │ │ │ ├── setup.py │ │ │ ├── src │ │ │ ├── buffer.h │ │ │ ├── region.c │ │ │ └── region.h │ │ │ └── statistics.py │ ├── pyvotkit │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── build │ │ │ └── temp.linux-x86_64-3.6 │ │ │ │ ├── region.o │ │ │ │ └── src │ │ │ │ └── region.o │ │ ├── c_region.pxd │ │ ├── region.c │ │ ├── region.cpython-36m-x86_64-linux-gnu.so │ │ ├── region.pyx │ │ ├── setup.py │ │ └── src │ │ │ ├── buffer.h │ │ │ ├── region.c │ │ │ └── region.h │ └── tracker_config.py └── votdataset.py ├── experiments ├── __init__.py └── myexperiments.py ├── features ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── augmentation.cpython-36.pyc │ ├── deep.cpython-36.pyc │ ├── extractor.cpython-36.pyc │ ├── featurebase.cpython-36.pyc │ └── preprocessing.cpython-36.pyc ├── augmentation.py ├── color.py ├── deep.py ├── extractor.py ├── featurebase.py ├── preprocessing.py └── util.py ├── libs ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── complex.cpython-36.pyc │ ├── dcf.cpython-36.pyc │ ├── fourier.cpython-36.pyc │ ├── operation.cpython-36.pyc │ ├── optimization.cpython-36.pyc │ ├── tensordict.cpython-36.pyc │ ├── tensorlist.cpython-36.pyc │ └── tensorlist.cpython-37.pyc ├── complex.py ├── dcf.py ├── fourier.py ├── operation.py ├── optimization.py ├── tensordict.py └── tensorlist.py ├── parameter ├── SPSTracker │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── default_vot.cpython-36.pyc │ ├── default.py │ └── default_vot.py ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc └── atom │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── default.cpython-36.pyc │ └── default_vot.cpython-36.pyc │ ├── default.py │ └── default_vot.py ├── run_experiment.py ├── run_tracker.py ├── run_video.py ├── run_webcam.py ├── tracker ├── SPSTracker │ ├── SPSTracker.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── 
SPSTracker.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── optim.cpython-36.pyc │ └── optim.py ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── atom │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── atom.cpython-36.pyc │ │ └── optim.cpython-36.pyc │ ├── atom.py │ └── optim.py └── base │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── basetracker.cpython-36.pyc │ └── basetracker.py ├── util ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── anchors.cpython-36.pyc │ ├── bbox_helper.cpython-36.pyc │ ├── config_helper.cpython-36.pyc │ ├── load_helper.cpython-36.pyc │ └── tracker_config.cpython-36.pyc ├── anchors.py ├── bbox_helper.py ├── benchmark_helper.py ├── config_helper.py ├── load_helper.py ├── log_helper.py ├── pysot │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── video.py │ │ └── vot.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── ar_benchmark.py │ │ └── eao_benchmark.py │ └── utils │ │ ├── __init__.py │ │ ├── build │ │ └── temp.linux-x86_64-3.6 │ │ │ ├── region.o │ │ │ └── src │ │ │ └── region.o │ │ ├── c_region.pxd │ │ ├── misc.py │ │ ├── region.c │ │ ├── region.cpython-36m-x86_64-linux-gnu.so │ │ ├── region.pyx │ │ ├── setup.py │ │ ├── src │ │ ├── buffer.h │ │ ├── region.c │ │ └── region.h │ │ └── statistics.py ├── pyvotkit │ ├── __init__.py │ ├── build │ │ └── temp.linux-x86_64-3.6 │ │ │ ├── region.o │ │ │ └── src │ │ │ └── region.o │ ├── c_region.pxd │ ├── region.c │ ├── region.cpython-36m-x86_64-linux-gnu.so │ ├── region.pyx │ ├── setup.py │ └── src │ │ ├── buffer.h │ │ ├── region.c │ │ └── region.h └── tracker_config.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── params.cpython-36.pyc └── plotting.cpython-36.pyc ├── gdrive_download ├── params.py └── plotting.py /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | This document contains 
detailed instructions for installing the necessary dependencies for PyTracking. The instructions have been tested on an Ubuntu 18.04 system. We recommend using the [install script](install.sh) if you have not already tried that. 4 | 5 | ### Requirements 6 | * Conda installation with Python 3.7. If not already installed, install from https://www.anaconda.com/distribution/. 7 | * Nvidia GPU. 8 | 9 | ## Step-by-step instructions 10 | #### Create and activate a conda environment 11 | ```bash 12 | conda create --name pytracking python=3.7 13 | conda activate pytracking 14 | ``` 15 | 16 | #### Install PyTorch 17 | Install PyTorch 0.4.1 with cuda92. 18 | ```bash 19 | conda install pytorch=0.4.1 torchvision cuda92 -c pytorch 20 | ``` 21 | 22 | **Note:** 23 | - PyTorch 1.0 should be supported, but is **not recommended** as it requires an [alternate compilation](https://github.com/vacancy/PreciseRoIPooling) of the PreciseRoIPooling module, which hasn't been tested. 24 | - It is possible to use any PyTorch-supported version of CUDA (not necessarily 9.2). 25 | - For more details about PyTorch installation, see https://pytorch.org/get-started/previous-versions/. 26 | 27 | #### Install matplotlib, pandas, opencv and tensorboardX 28 | ```bash 29 | conda install matplotlib=2.2.2 pandas 30 | pip install opencv-python tensorboardX 31 | ``` 32 | 33 | 34 | #### Install the coco toolkit 35 | If you want to use the COCO dataset for training, install the coco python toolkit. You additionally need to install cython to compile the coco toolkit. 36 | ```bash 37 | conda install cython 38 | pip install pycocotools 39 | ``` 40 | 41 | 42 | #### Compile Precise ROI pooling 43 | To compile the Precise ROI pooling module (https://github.com/vacancy/PreciseRoIPooling) for PyTorch 0.4.1, go to the directory "ltr/external/PreciseRoIPooling/pytorch/prroi_pool" and run the "travis.sh" script. 44 | You may additionally have to export the path to the cuda installation.
45 | ```bash 46 | cd ltr/external/PreciseRoIPooling/pytorch/prroi_pool 47 | 48 | # Export the path to the cuda installation 49 | PATH=/usr/local/cuda/bin/:$PATH 50 | 51 | # Compile Precise ROI Pool 52 | bash travis.sh 53 | ``` 54 | 55 | In case of issues, please refer to https://github.com/vacancy/PreciseRoIPooling. 56 | 57 | 58 | #### Install jpeg4py 59 | In order to use [jpeg4py](https://github.com/ajkxyz/jpeg4py) for loading the images instead of OpenCV's imread(), install jpeg4py as follows: 60 | ```bash 61 | sudo apt-get install libturbojpeg 62 | pip install jpeg4py 63 | ``` 64 | 65 | **Note:** The first step (```sudo apt-get install libturbojpeg```) is optional; if it is skipped, OpenCV's imread() will be used to read the images. However, the second step is required. 66 | 67 | In case of issues, please refer to https://github.com/ajkxyz/jpeg4py. 68 | 69 | 70 | #### Set up the environment 71 | Create the default environment setting files. 72 | ```bash 73 | # Environment settings for pytracking. Saved at pytracking/evaluation/local.py 74 | python -c "from pytracking.evaluation.environment import create_default_local_file; create_default_local_file()" 75 | 76 | # Environment settings for ltr. Saved at ltr/admin/local.py 77 | python -c "from ltr.admin.environment import create_default_local_file; create_default_local_file()" 78 | ``` 79 | 80 | You can modify these files to set the paths to datasets, results, etc. 81 | 82 | 83 | #### Download the pre-trained networks 84 | You can download the pre-trained networks from the [google drive folder](https://drive.google.com/drive/folders/1WVhJqvdu-_JG1U-V0IqfxTUa1SBPnL0O). The networks should be saved in the directory set by "network_path" in "pytracking/evaluation/local.py". By default, it is set to pytracking/networks. 85 | You can also download the networks using the gdrive_download bash script.
86 | 87 | ```bash 88 | # Download the default network for ATOM 89 | bash pytracking/utils/gdrive_download 1JUB3EucZfBk3rX7M3_q5w_dLBqsT7s-M pytracking/networks/atom_default.pth 90 | ``` 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [SPSTracker: Sub-Peak Suppression of Response Map for Robust Object Tracking](https://arxiv.org/abs/1912.00597). 2 | 3 | 4 | ## News 5 | * \[2019/11/11\] SPSTracker has been accepted as an oral paper at the Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI-20). 6 | ## Introduction 7 | This is the official code of SPSTracker: Sub-Peak Suppression of Response Map for Robust Object Tracking. We propose a simple-yet-effective approach, referred to as SPSTracker, for robust object tracking. Our motivation is based on the observation that most tracking failures are caused by interference around the target. Such interference produces a multi-peak tracking response, and a sub-peak may progressively "grow" and eventually cause model drift. Therefore, we propose suppressing the sub-peaks to aggregate a single-peak response, with the aim of preventing model drift from the perspective of tracking response regularization. 8 | ![comparison](figure/SPS.png) 9 | ![flowchart](figure/flowchart.png) 10 | 11 | ## Installation 12 | The code is built upon the ATOM architecture; check [INSTALL.md](INSTALL.md) for installation instructions. 13 | 14 | ## [Results](https://drive.google.com/drive/folders/1bX_5fcm2EfeZv5dx3L8CwhMEHGwrGXvV)[[Raw result](https://drive.google.com/open?id=1IPpOGYU6r5Dmz20PcrfOBtMCJ60C6q3Y)] 15 | 16 | | Tracker | VOT2016
EAO / A / R
| VOT2018
EAO / A / R
| 17 | |:----------------------------------------------------------------------:|:--------------------------------------------:|:--------------------------------------------:| 18 | | SPSTracker | 0.459/0.625/0.158 | 0.434/0.612/0.169 | 19 | 20 | ## Run 21 | SPSTracker default_vot --dataset vot --debug 1 --threads 0 22 | 23 | 24 | ## Citations 25 | Please consider citing our paper in your publications if the project helps your research. 26 | ``` 27 | @inproceedings{hu2020spstracker, 28 | title = {{SPSTracker}: Sub-Peak Suppression of Response Map for Robust Object Tracking}, 29 | author = {Qintao Hu and Lijun Zhou and Xiaoxiao Wang and Yao Mao and Jianlin Zhang and Qixiang Ye}, 30 | booktitle = {Thirty-Fourth AAAI Conference on Artificial Intelligence (AAAI)}, 31 | year = {2020} 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /figure/SPS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/figure/SPS.png -------------------------------------------------------------------------------- /figure/flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/figure/flowchart.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$#" -ne 2 ]; then 4 | echo "ERROR! Illegal number of parameters. 
Usage: bash install.sh conda_install_path environment_name" 5 | exit 0 6 | fi 7 | 8 | conda_install_path=$1 9 | conda_env_name=$2 10 | 11 | source $conda_install_path/etc/profile.d/conda.sh 12 | echo "****************** Creating conda environment ${conda_env_name} python=3.7 ******************" 13 | conda create -y --name $conda_env_name 14 | 15 | echo "" 16 | echo "" 17 | echo "****************** Activating conda environment ${conda_env_name} ******************" 18 | conda activate $conda_env_name 19 | 20 | echo "" 21 | echo "" 22 | echo "****************** Installing pytorch 0.4.1 with cuda92 ******************" 23 | conda install -y pytorch=0.4.1 torchvision cuda92 -c pytorch 24 | 25 | echo "" 26 | echo "" 27 | echo "****************** Installing matplotlib 2.2.2 ******************" 28 | conda install -y matplotlib=2.2.2 29 | 30 | echo "" 31 | echo "" 32 | echo "****************** Installing pandas ******************" 33 | conda install -y pandas 34 | 35 | echo "" 36 | echo "" 37 | echo "****************** Installing opencv ******************" 38 | pip install opencv-python 39 | 40 | echo "" 41 | echo "" 42 | echo "****************** Installing tensorboardX ******************" 43 | pip install tensorboardX 44 | 45 | echo "" 46 | echo "" 47 | echo "****************** Installing cython ******************" 48 | conda install -y cython 49 | 50 | echo "" 51 | echo "" 52 | echo "****************** Installing coco toolkit ******************" 53 | pip install pycocotools 54 | 55 | echo "" 56 | echo "" 57 | echo "****************** Installing jpeg4py python wrapper ******************" 58 | pip install jpeg4py 59 | 60 | echo "" 61 | echo "" 62 | echo "****************** Installing PreROIPooling ******************" 63 | base_dir=$(pwd) 64 | cd ltr/external/PreciseRoIPooling/pytorch/prroi_pool 65 | PATH=/usr/local/cuda/bin/:$PATH 66 | bash travis.sh 67 | cd $base_dir 68 | 69 | echo "" 70 | echo "" 71 | echo "****************** Downloading networks ******************" 72 | 
mkdir pytracking/networks 73 | 74 | echo "" 75 | echo "" 76 | echo "****************** ATOM Network ******************" 77 | bash pytracking/utils/gdrive_download 1ZTdQbZ1tyN27UIwUnUrjHChQb5ug2sxr pytracking/networks/atom_default.pth 78 | 79 | echo "" 80 | echo "" 81 | echo "****************** ECO Network ******************" 82 | bash pytracking/utils/gdrive_download 1aWC4waLv_te-BULoy0k-n_zS-ONms21S pytracking/networks/resnet18_vggmconv1.pth 83 | 84 | echo "" 85 | echo "" 86 | echo "****************** Setting up environment ******************" 87 | python -c "from pytracking.evaluation.environment import create_default_local_file; create_default_local_file()" 88 | python -c "from ltr.admin.environment import create_default_local_file; create_default_local_file()" 89 | 90 | 91 | echo "" 92 | echo "" 93 | echo "****************** Installing jpeg4py ******************" 94 | while true; do 95 | read -p "Install jpeg4py for reading images? This step required sudo privilege. Installing jpeg4py is optional, however recommended. [y,n] " install_flag 96 | case $install_flag in 97 | [Yy]* ) sudo apt-get install libturbojpeg; break;; 98 | [Nn]* ) echo "Skipping jpeg4py installation!"; break;; 99 | * ) echo "Please answer y or n ";; 100 | esac 101 | done 102 | 103 | echo "" 104 | echo "" 105 | echo "****************** Installation complete! ******************" 106 | -------------------------------------------------------------------------------- /ltr/README.md: -------------------------------------------------------------------------------- 1 | # LTR 2 | 3 | A general PyTorch based framework for learning tracking representations. The repository contains the code for training the [**ATOM**](https://arxiv.org/pdf/1811.07628.pdf) tracker. 
4 | 5 | ## Table of Contents 6 | 7 | * [Quick Start](#quick-start) 8 | * [Overview](#overview) 9 | * [Train Settings](#train-settings) 10 | * [Training your own networks](#training-your-own-networks) 11 | 12 | ## Quick Start 13 | The installation script will automatically generate a local configuration file "admin/local.py". In case the file was not generated, run ```admin.environment.create_default_local_file()``` to generate it. Next, set the paths to the training workspace, 14 | i.e. the directory where the checkpoints will be saved. Also set the paths to the datasets you want to use. If all the dependencies have been correctly installed, you can train a network using the run_training.py script in the correct conda environment. 15 | ```bash 16 | conda activate pytracking 17 | python run_training.py train_module train_name 18 | ``` 19 | Here, ```train_module``` is the sub-module inside ```train_settings``` and ```train_name``` is the name of the train setting file to be used. 20 | 21 | For example, you can train using the included default ATOM settings by running: 22 | ```bash 23 | python run_training.py bbreg atom_default 24 | ``` 25 | 26 | 27 | ## Overview 28 | The framework consists of the following sub-modules. 29 | - [actors](actors): Contains the actor classes for different trainings. The actor class is responsible for passing the input data through the network and calculating losses. 30 | - [admin](admin): Includes functions for loading networks, tensorboard etc. and also contains environment settings. 31 | - [dataset](dataset): Contains integration of a number of training datasets, namely [TrackingNet](https://tracking-net.org/), [GOT-10k](http://got-10k.aitestunion.com/), [LaSOT](https://cis.temple.edu/lasot/), 32 | [ImageNet-VID](http://image-net.org/), and [COCO](http://cocodataset.org/#home). 33 | - [data_specs](data_specs): Information about train/val splits of different datasets. 34 | - [data](data): Contains functions for processing data, e.g.
loading images, data augmentations, sampling frames from videos. 35 | - [external](external): External libraries needed for training. Added as submodules. 36 | - [models](models): Contains different layers and network definitions. 37 | - [trainers](trainers): The main class which runs the training. 38 | - [train_settings](train_settings): Contains settings files, specifying the training of a network. 39 | 40 | ## Train Settings 41 | The framework currently contains the following training settings: 42 | - [bbreg.atom_default](train_settings/bbreg/atom_default.py): The default settings used for training the network in [ATOM](https://arxiv.org/pdf/1811.07628.pdf). 43 | 44 | 45 | ## Training your own networks 46 | To train a custom network using the toolkit, the following components need to be specified in the train settings. For reference, see [atom_default.py](train_settings/bbreg/atom_default.py). 47 | - Datasets: The datasets to be used for training. A number of standard tracking datasets are already available in ```dataset``` module. 48 | - Processing: This function should perform the necessary post-processing of the data, e.g. cropping of target region, data augmentations etc. 49 | - Sampler: Determines how the frames are sampled from a video sequence to form the batches. 50 | - Network: The network module to be trained. 51 | - Objective: The training objective. 52 | - Actor: The trainer passes the training batch to the actor who is responsible for passing the data through the network correctly, and calculating the training loss. 53 | - Optimizer: Optimizer to be used, e.g. Adam. 54 | - Trainer: The main class which runs the epochs and saves checkpoints. 
55 | 56 | 57 | -------------------------------------------------------------------------------- /ltr/__init__.py: -------------------------------------------------------------------------------- 1 | from .admin.loading import load_network 2 | from .admin.model_constructor import model_constructor -------------------------------------------------------------------------------- /ltr/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/actors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_actor import BaseActor 2 | from .bbreg import AtomActor 3 | -------------------------------------------------------------------------------- /ltr/actors/base_actor.py: -------------------------------------------------------------------------------- 1 | from pytracking import TensorDict 2 | 3 | 4 | class BaseActor: 5 | """ Base class for actor. The actor class handles the passing of the data through the network 6 | and calculating the loss""" 7 | def __init__(self, net, objective): 8 | """ 9 | args: 10 | net - The network to train 11 | objective - The loss function 12 | """ 13 | self.net = net 14 | self.objective = objective 15 | 16 | def __call__(self, data: TensorDict): 17 | """ Called in each training iteration. Should pass the input data through the network, calculate the loss, and 18 | return the training stats for the input data 19 | args: 20 | data - A TensorDict containing all the necessary data blocks.
21 | 22 | returns: 23 | loss - loss for the input data 24 | stats - a dict containing detailed losses 25 | """ 26 | raise NotImplementedError 27 | 28 | def to(self, device): 29 | """ Move the network to device 30 | args: 31 | device - device to use. 'cpu' or 'cuda' 32 | """ 33 | self.net.to(device) 34 | 35 | def train(self, mode=True): 36 | """ Set whether the network is in train mode. 37 | args: 38 | mode (True) - Bool specifying whether in training mode. 39 | """ 40 | self.net.train(mode) 41 | 42 | def eval(self): 43 | """ Set network to eval mode""" 44 | self.train(False) -------------------------------------------------------------------------------- /ltr/actors/bbreg.py: -------------------------------------------------------------------------------- 1 | from . import BaseActor 2 | 3 | 4 | class AtomActor(BaseActor): 5 | """ Actor for training the IoU-Net in ATOM""" 6 | def __call__(self, data): 7 | """ 8 | args: 9 | data - The input data, should contain the fields 'train_images', 'test_images', 'train_anno', 10 | 'test_proposals' and 'proposal_iou'.
11 | 12 | returns: 13 | loss - the training loss 14 | stats - dict containing detailed losses 15 | """ 16 | # Run network to obtain IoU prediction for each proposal in 'test_proposals' 17 | iou_pred = self.net(data['train_images'], data['test_images'], data['train_anno'], data['test_proposals']) 18 | # Collapse dims 0-1 into a single row axis, keeping dim 2; assumes both tensors are 3-D (TODO confirm against the network output) 19 | iou_pred = iou_pred.view(-1, iou_pred.shape[2]) 20 | iou_gt = data['proposal_iou'].view(-1, data['proposal_iou'].shape[2]) 21 | 22 | # Compute loss 23 | loss = self.objective(iou_pred, iou_gt) 24 | 25 | # Return training stats 26 | stats = {'Loss/total': loss.item(), 27 | 'Loss/iou': loss.item()} 28 | 29 | return loss, stats -------------------------------------------------------------------------------- /ltr/admin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__init__.py -------------------------------------------------------------------------------- /ltr/admin/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/admin/__pycache__/environment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/environment.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/admin/__pycache__/loading.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/loading.cpython-36.pyc
-------------------------------------------------------------------------------- /ltr/admin/__pycache__/local.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/local.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/admin/__pycache__/model_constructor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/model_constructor.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/admin/__pycache__/settings.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/settings.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/admin/__pycache__/stats.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/admin/__pycache__/stats.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/admin/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from collections import OrderedDict 4 | 5 | 6 | def create_default_local_file(): 7 | path = os.path.join(os.path.dirname(__file__), 'local.py') 8 | 9 | empty_str = '\'\'' 10 | default_settings = OrderedDict({ 11 | 'workspace_dir': empty_str, 12 | 'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'', 13 | 'lasot_dir': empty_str, 14 | 'got10k_dir': 
empty_str, 15 | 'trackingnet_dir': empty_str, 16 | 'coco_dir': empty_str, 17 | 'imagenet_dir': empty_str, 18 | 'imagenetdet_dir': empty_str}) 19 | 20 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.', 21 | 'tensorboard_dir': 'Directory for tensorboard files.'} 22 | 23 | with open(path, 'w') as f: 24 | f.write('class EnvironmentSettings:\n') 25 | f.write(' def __init__(self):\n') 26 | 27 | for attr, attr_val in default_settings.items(): 28 | comment_str = None 29 | if attr in comment: 30 | comment_str = comment[attr] 31 | if comment_str is None: 32 | f.write(' self.{} = {}\n'.format(attr, attr_val)) 33 | else: 34 | f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str)) 35 | 36 | 37 | def env_settings(): 38 | env_module_name = 'ltr.admin.local' 39 | try: 40 | env_module = importlib.import_module(env_module_name) 41 | return env_module.EnvironmentSettings() 42 | except: 43 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 44 | 45 | create_default_local_file() 46 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file)) 47 | -------------------------------------------------------------------------------- /ltr/admin/local.py: -------------------------------------------------------------------------------- 1 | class EnvironmentSettings: 2 | def __init__(self): 3 | self.workspace_dir = '' # Base directory for saving network checkpoints. 4 | self.tensorboard_dir = self.workspace_dir + '/tensorboard/' # Directory for tensorboard files. 
5 | self.lasot_dir = '' 6 | self.got10k_dir = '' 7 | self.trackingnet_dir = '' 8 | self.coco_dir = '' 9 | self.imagenet_dir = '' 10 | self.imagenetdet_dir = '' 11 | -------------------------------------------------------------------------------- /ltr/admin/model_constructor.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import importlib 3 | 4 | 5 | def model_constructor(f): 6 | """ Wraps the function 'f' which returns the network. An extra field 'constructor' is added to the network returned 7 | by 'f'. This field contains an instance of the 'NetConstructor' class, which contains the information needed to 8 | re-construct the network, such as the name of the function 'f', the function arguments etc. Thus, the network can 9 | be easily constructed from a saved checkpoint by calling NetConstructor.get() function. 10 | """ 11 | @wraps(f) 12 | def f_wrapper(*args, **kwds): 13 | net_constr = NetConstructor(f.__name__, f.__module__, args, kwds) 14 | output = f(*args, **kwds) 15 | if isinstance(output, (tuple, list)): 16 | # Assume first argument is the network 17 | output[0].constructor = net_constr 18 | else: 19 | output.constructor = net_constr 20 | return output 21 | return f_wrapper 22 | 23 | 24 | class NetConstructor: 25 | """ Class to construct networks. Takes as input the function name (e.g. atom_resnet18), the name of the module 26 | which contains the network function (e.g. ltr.models.bbreg.atom) and the arguments for the network 27 | function. 
The class object can then be stored along with the network weights to re-construct the network.""" 28 | def __init__(self, fun_name, fun_module, args, kwds): 29 | """ 30 | args: 31 | fun_name - The function which returns the network 32 | fun_module - the module which contains the network function 33 | args - arguments which are passed to the network function 34 | kwds - arguments which are passed to the network function 35 | """ 36 | self.fun_name = fun_name 37 | self.fun_module = fun_module 38 | self.args = args 39 | self.kwds = kwds 40 | 41 | def get(self): 42 | """ Rebuild the network by calling the network function with the correct arguments. """ 43 | net_module = importlib.import_module(self.fun_module) 44 | net_fun = getattr(net_module, self.fun_name) 45 | return net_fun(*self.args, **self.kwds) 46 | -------------------------------------------------------------------------------- /ltr/admin/settings.py: -------------------------------------------------------------------------------- 1 | from ltr.admin.environment import env_settings 2 | 3 | 4 | class Settings: 5 | """ Training settings, e.g. 
the paths to datasets and networks.""" 6 | def __init__(self): 7 | self.set_default() 8 | 9 | def set_default(self): 10 | self.env = env_settings() 11 | self.use_gpu = True 12 | 13 | 14 | -------------------------------------------------------------------------------- /ltr/admin/stats.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class StatValue: 4 | def __init__(self): 5 | self.clear() 6 | 7 | def reset(self): 8 | self.val = 0 9 | 10 | def clear(self): 11 | self.reset() 12 | self.history = [] 13 | 14 | def update(self, val): 15 | self.val = val 16 | self.history.append(self.val) 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | def __init__(self): 22 | self.clear() 23 | self.has_new_data = False 24 | 25 | def reset(self): 26 | self.avg = 0 27 | self.val = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def clear(self): 32 | self.reset() 33 | self.history = [] 34 | 35 | def update(self, val, n=1): 36 | self.val = val 37 | self.sum += val * n 38 | self.count += n 39 | self.avg = self.sum / self.count 40 | 41 | def new_epoch(self): 42 | if self.count > 0: 43 | self.history.append(self.avg) 44 | self.reset() 45 | self.has_new_data = True 46 | else: 47 | self.has_new_data = False 48 | 49 | 50 | def topk_accuracy(output, target, topk=(1,)): 51 | """Computes the precision@k for the specified values of k""" 52 | single_input = not isinstance(topk, (tuple, list)) 53 | if single_input: 54 | topk = (topk,) 55 | 56 | maxk = max(topk) 57 | batch_size = target.size(0) 58 | 59 | _, pred = output.topk(maxk, 1, True, True) 60 | pred = pred.t() 61 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 62 | 63 | res = [] 64 | for k in topk: 65 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)[0] 66 | res.append(correct_k * 100.0 / batch_size) 67 | 68 | if single_input: 69 | return res[0] 70 | 71 | return res 72 | 
-------------------------------------------------------------------------------- /ltr/admin/tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from tensorboardX import SummaryWriter 4 | 5 | 6 | class TensorboardWriter: 7 | def __init__(self, directory, loader_names): 8 | self.directory = directory 9 | self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names}) 10 | 11 | def write_info(self, module_name, script_name, description): 12 | tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info')) 13 | tb_info_writer.add_text('Modulet_name', module_name) 14 | tb_info_writer.add_text('Script_name', script_name) 15 | tb_info_writer.add_text('Description', description) 16 | tb_info_writer.close() 17 | 18 | def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1): 19 | for loader_name, loader_stats in stats.items(): 20 | if loader_stats is None: 21 | continue 22 | for var_name, val in loader_stats.items(): 23 | if hasattr(val, 'history') and getattr(val, 'has_new_data', True): 24 | self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch) -------------------------------------------------------------------------------- /ltr/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import LTRLoader -------------------------------------------------------------------------------- /ltr/data/image_loader.py: -------------------------------------------------------------------------------- 1 | import jpeg4py 2 | import cv2 as cv 3 | 4 | 5 | def default_image_loader(path): 6 | """The default image loader, reads the image from the given path. 
It first tries to use the jpeg4py_loader, 7 | but reverts to the opencv_loader if the former is not available.""" 8 | if default_image_loader.use_jpeg4py is None: 9 | # Try using jpeg4py 10 | im = jpeg4py_loader(path) 11 | if im is None: 12 | default_image_loader.use_jpeg4py = False 13 | print('Using opencv_loader instead.') 14 | else: 15 | default_image_loader.use_jpeg4py = True 16 | return im 17 | if default_image_loader.use_jpeg4py: 18 | return jpeg4py_loader(path) 19 | return opencv_loader(path) 20 | 21 | default_image_loader.use_jpeg4py = None 22 | 23 | 24 | def jpeg4py_loader(path): 25 | """ Image reading using jpeg4py (https://github.com/ajkxyz/jpeg4py)""" 26 | try: 27 | return jpeg4py.JPEG(path).decode() 28 | except Exception as e: 29 | print('ERROR: Could not read image "{}"'.format(path)) 30 | print(e) 31 | return None 32 | 33 | 34 | def opencv_loader(path): 35 | """ Read image using opencv's imread function and returns it in rgb format""" 36 | try: 37 | im = cv.imread(path, cv.IMREAD_COLOR) 38 | # convert to rgb and return 39 | return cv.cvtColor(im, cv.COLOR_BGR2RGB) 40 | except Exception as e: 41 | print('ERROR: Could not read image "{}"'.format(path)) 42 | print(e) 43 | return None 44 | -------------------------------------------------------------------------------- /ltr/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .lasot import Lasot 2 | from .got10k import Got10k 3 | from .tracking_net import TrackingNet 4 | from .imagenetvid import ImagenetVID 5 | from .coco_seq import MSCOCOSeq 6 | -------------------------------------------------------------------------------- /ltr/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from ltr.data.image_loader import default_image_loader 3 | 4 | 5 | class BaseDataset(torch.utils.data.Dataset): 6 | """ Base class for datasets """ 7 | 8 | def __init__(self, root, 
image_loader=default_image_loader): 9 | """ 10 | args: 11 | root - The root path to the dataset 12 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 13 | is used by default. 14 | """ 15 | if root == '': 16 | raise Exception('The dataset path is not setup. Check your "ltr/admin/local.py".') 17 | self.root = root 18 | self.image_loader = image_loader 19 | 20 | self.sequence_list = [] # Contains the list of sequences. 21 | 22 | def __len__(self): 23 | """ Returns size of the dataset 24 | returns: 25 | int - number of samples in the dataset 26 | """ 27 | return self.get_num_sequences() 28 | 29 | def __getitem__(self, index): 30 | """ Not to be used! Check get_frames() instead. 31 | """ 32 | return None 33 | 34 | def is_video_sequence(self): 35 | """ Returns whether the dataset is a video dataset or an image dataset 36 | 37 | returns: 38 | bool - True if a video dataset 39 | """ 40 | return True 41 | 42 | def get_name(self): 43 | """ Name of the dataset 44 | 45 | returns: 46 | string - Name of the dataset 47 | """ 48 | raise NotImplementedError 49 | 50 | def get_num_sequences(self): 51 | """ Number of sequences in a dataset 52 | 53 | returns: 54 | int - number of sequences in the dataset.""" 55 | return len(self.sequence_list) 56 | 57 | def get_sequence_info(self, seq_id): 58 | """ Returns information about a particular sequences, 59 | 60 | args: 61 | seq_id - index of the sequence 62 | 63 | returns: 64 | Tensor - Annotation for the sequence. A 2d tensor of shape (num_frames, 4). 65 | Format [top_left_x, top_left_y, width, height] 66 | Tensor - 1d Tensor specifying whether target is present (=1 )for each frame. 
shape (num_frames,) 67 | """ 68 | raise NotImplementedError 69 | 70 | def get_frames(self, seq_id, frame_ids, anno=None): 71 | """ Get a set of frames from a particular sequence 72 | 73 | args: 74 | seq_id - index of sequence 75 | frame_ids - a list of frame numbers 76 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded. 77 | 78 | returns: 79 | list - List of frames corresponding to frame_ids 80 | list - List of annotations (tensor of shape (4,)) for each frame 81 | dict - A dict containing meta information about the sequence, e.g. class of the target object. 82 | 83 | """ 84 | raise NotImplementedError 85 | 86 | -------------------------------------------------------------------------------- /ltr/dataset/coco_seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .base_dataset import BaseDataset 3 | from ltr.data.image_loader import default_image_loader 4 | import torch 5 | from pycocotools.coco import COCO 6 | from collections import OrderedDict 7 | from ltr.admin.environment import env_settings 8 | 9 | 10 | class MSCOCOSeq(BaseDataset): 11 | """ The COCO dataset. COCO is an image dataset. Thus, we treat each image as a sequence of length 1. 12 | 13 | Publication: 14 | Microsoft COCO: Common Objects in Context. 15 | Tsung-Yi Lin, Michael Maire, Serge J. Belongie, Lubomir D. Bourdev, Ross B. Girshick, James Hays, Pietro Perona, 16 | Deva Ramanan, Piotr Dollar and C. Lawrence Zitnick 17 | ECCV, 2014 18 | https://arxiv.org/pdf/1405.0312.pdf 19 | 20 | Download the images along with annotations from http://cocodataset.org/#download. The root folder should be 21 | organized as follows. 22 | - coco_root 23 | - annotations 24 | - instances_train2014.json 25 | - images 26 | - train2014 27 | 28 | Note: You also have to install the coco pythonAPI from https://github.com/cocodataset/cocoapi. 
29 | """ 30 | 31 | def __init__(self, root=None, image_loader=default_image_loader): 32 | root = env_settings().coco_dir if root is None else root 33 | super().__init__(root, image_loader) 34 | 35 | self.img_pth = os.path.join(root, 'train2014/') 36 | self.anno_path = os.path.join(root, 'annotations/instances_train2014.json') 37 | 38 | # Load the COCO set. 39 | self.coco_set = COCO(self.anno_path) 40 | 41 | self.cats = self.coco_set.cats 42 | self.sequence_list = self._get_sequence_list() 43 | 44 | def _get_sequence_list(self): 45 | ann_list = list(self.coco_set.anns.keys()) 46 | seq_list = [a for a in ann_list if self.coco_set.anns[a]['iscrowd'] == 0] 47 | 48 | return seq_list 49 | 50 | def is_video_sequence(self): 51 | return False 52 | 53 | def get_name(self): 54 | return 'coco' 55 | 56 | def get_num_sequences(self): 57 | return len(self.sequence_list) 58 | 59 | def get_sequence_info(self, seq_id): 60 | anno = self._get_anno(seq_id) 61 | 62 | return anno, torch.Tensor([1]) 63 | 64 | def _get_anno(self, seq_id): 65 | anno = self.coco_set.anns[self.sequence_list[seq_id]]['bbox'] 66 | return torch.Tensor(anno).view(1, 4) 67 | 68 | def _get_frames(self, seq_id): 69 | path = self.coco_set.loadImgs([self.coco_set.anns[self.sequence_list[seq_id]]['image_id']])[0]['file_name'] 70 | img = self.image_loader(os.path.join(self.img_pth, path)) 71 | return img 72 | 73 | def get_meta_info(self, seq_id): 74 | try: 75 | cat_dict_current = self.cats[self.coco_set.anns[self.sequence_list[seq_id]]['category_id']] 76 | object_meta = OrderedDict({'object_class': cat_dict_current['name'], 77 | 'motion_class': None, 78 | 'major_class': cat_dict_current['supercategory'], 79 | 'root_class': None, 80 | 'motion_adverb': None}) 81 | except: 82 | object_meta = OrderedDict({'object_class': None, 83 | 'motion_class': None, 84 | 'major_class': None, 85 | 'root_class': None, 86 | 'motion_adverb': None}) 87 | return object_meta 88 | 89 | def get_frames(self, seq_id=None, frame_ids=None, 
anno=None): 90 | # COCO is an image dataset. Thus we replicate the image denoted by seq_id len(frame_ids) times, and return a 91 | # list containing these replicated images. 92 | frame = self._get_frames(seq_id) 93 | 94 | frame_list = [frame.copy() for _ in frame_ids] 95 | 96 | if anno is None: 97 | anno = self._get_anno(seq_id) 98 | 99 | anno_frames = [anno.clone()[0, :] for _ in frame_ids] 100 | 101 | object_meta = self.get_meta_info(seq_id) 102 | 103 | return frame_list, anno_frames, object_meta 104 | -------------------------------------------------------------------------------- /ltr/dataset/tracking_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import os.path 4 | import numpy as np 5 | import pandas 6 | from collections import OrderedDict 7 | 8 | from ltr.data.image_loader import default_image_loader 9 | from .base_dataset import BaseDataset 10 | from ltr.admin.environment import env_settings 11 | 12 | 13 | def list_sequences(root, set_ids): 14 | """ Lists all the videos in the input set_ids. Returns a list of tuples (set_id, video_name) 15 | 16 | args: 17 | root: Root directory to TrackingNet 18 | set_ids: Sets (0-11) which are to be used 19 | 20 | returns: 21 | list - list of tuples (set_id, video_name) containing the set_id and video_name for each sequence 22 | """ 23 | sequence_list = [] 24 | 25 | for s in set_ids: 26 | anno_dir = os.path.join(root, "TRAIN_" + str(s), "anno") 27 | 28 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')] 29 | sequence_list += sequences_cur_set 30 | 31 | return sequence_list 32 | 33 | 34 | class TrackingNet(BaseDataset): 35 | """ TrackingNet dataset. 36 | 37 | Publication: 38 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild. 
39 | Matthias Mueller,Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem 40 | ECCV, 2018 41 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf 42 | 43 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit. 44 | """ 45 | def __init__(self, root=None, image_loader=default_image_loader, set_ids=None): 46 | """ 47 | args: 48 | root - The path to the TrackingNet folder, containing the training sets. 49 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py) 50 | is used by default. 51 | set_ids (None) - List containing the ids of the TrackingNet sets to be used for training. If None, all the 52 | sets (0 - 11) will be used. 53 | """ 54 | root = env_settings().trackingnet_dir if root is None else root 55 | super().__init__(root, image_loader) 56 | 57 | if set_ids is None: 58 | set_ids = [i for i in range(12)] 59 | 60 | self.set_ids = set_ids 61 | 62 | # Keep a list of all videos. 
Sequence list is a list of tuples (set_id, video_name) containing the set_id and 63 | # video_name for each sequence 64 | self.sequence_list = list_sequences(self.root, self.set_ids) 65 | 66 | def get_name(self): 67 | return 'trackingnet' 68 | 69 | def _read_anno(self, seq_id): 70 | set_id = self.sequence_list[seq_id][0] 71 | vid_name = self.sequence_list[seq_id][1] 72 | anno_file = os.path.join(self.root, "TRAIN_" + str(set_id), "anno", vid_name + ".txt") 73 | gt = pandas.read_csv(anno_file, delimiter=',', header=None, dtype=np.float32, na_filter=False, low_memory=False).values 74 | return torch.tensor(gt) 75 | 76 | def get_sequence_info(self, seq_id): 77 | anno = self._read_anno(seq_id) 78 | target_visible = (anno[:,2]>0) & (anno[:,3]>0) 79 | return anno, target_visible 80 | 81 | def _get_frame(self, seq_id, frame_id): 82 | set_id = self.sequence_list[seq_id][0] 83 | vid_name = self.sequence_list[seq_id][1] 84 | frame_path = os.path.join(self.root, "TRAIN_" + str(set_id), "frames", vid_name, str(frame_id) + ".jpg") 85 | return self.image_loader(frame_path) 86 | 87 | def get_frames(self, seq_id, frame_ids, anno=None): 88 | frame_list = [self._get_frame(seq_id, f) for f in frame_ids] 89 | 90 | if anno is None: 91 | anno = self._read_anno(seq_id) 92 | 93 | # Return as list of tensors 94 | anno_frames = [anno[f_id, :] for f_id in frame_ids] 95 | 96 | object_meta = OrderedDict({'object_class': None, 97 | 'motion_class': None, 98 | 'major_class': None, 99 | 'root_class': None, 100 | 'motion_adverb': None}) 101 | 102 | return frame_list, anno_frames, object_meta 103 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .vim-template* 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | 
build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Jiayuan Mao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/README.md: -------------------------------------------------------------------------------- 1 | # PreciseRoIPooling 2 | This repo implements the **Precise RoI Pooling** (PrRoI Pooling), proposed in the paper **Acquisition of Localization Confidence for Accurate Object Detection** published at ECCV 2018 (Oral Presentation). 3 | 4 | **Acquisition of Localization Confidence for Accurate Object Detection** 5 | 6 | _Borui Jiang*, Ruixuan Luo*, Jiayuan Mao*, Tete Xiao, Yuning Jiang_ (* indicates equal contribution.) 7 | 8 | https://arxiv.org/abs/1807.11590 9 | 10 | ## Brief 11 | 12 | In short, Precise RoI Pooling is an integration-based (bilinear interpolation) average pooling method for RoI Pooling. It avoids any quantization and has a continuous gradient on bounding box coordinates. It is: 13 | 14 | - different from the original RoI Pooling proposed in [Fast R-CNN](https://arxiv.org/abs/1504.08083). 
PrRoI Pooling uses average pooling instead of max pooling for each bin and has a continuous gradient on bounding box coordinates. That is, one can take the derivatives of some loss function w.r.t. the coordinates of each RoI and optimize the RoI coordinates. 15 | - different from the RoI Align proposed in [Mask R-CNN](https://arxiv.org/abs/1703.06870). PrRoI Pooling uses a full integration-based average pooling instead of sampling a constant number of points. This makes the gradient w.r.t. the coordinates continuous. 16 | 17 | For a better illustration, we show RoI Pooling, RoI Align and PrRoI Pooling in the following figure. More details, including the gradient computation, can be found in our paper. 18 | 19 |
20 | 21 | ## Implementation 22 | 23 | PrRoI Pooling was originally implemented by [Tete Xiao](http://tetexiao.com/) based on MegBrain, an (internal) deep learning framework built by Megvii Inc. It was later adapted into open-source deep learning frameworks. Currently, we only support PyTorch. Unfortunately, we don't have any specific plans for adapting it to other frameworks such as TensorFlow, but any contributions (pull requests) will be more than welcome. 24 | 25 | ## Usage (PyTorch) 26 | 27 | In the directory `pytorch/`, we provide a PyTorch-based implementation of PrRoI Pooling. It requires PyTorch 0.4 and only supports CUDA (CPU mode is not implemented). To use the PrRoI Pooling module, first go to `pytorch/prroi_pool` and execute `./travis.sh` to compile the essential components (you may need `nvcc` for this step). To use the module in your code, simply do: 28 | 29 | ``` 30 | from prroi_pool import PrRoIPool2D 31 | 32 | avg_pool = PrRoIPool2D(window_height, window_width, spatial_scale) 33 | roi_features = avg_pool(features, rois) 34 | 35 | # for those who want to use the "functional" 36 | 37 | from prroi_pool.functional import prroi_pool2d 38 | roi_features = prroi_pool2d(features, rois, window_height, window_width, spatial_scale) 39 | ``` 40 | 41 | Here, 42 | 43 | - RoI is an `m * 5` float tensor of format `(batch_index, x0, y0, x1, y1)`, following the convention in the original Caffe implementation of RoI Pooling, although in some frameworks the batch indices are provided by an integer tensor. 44 | - `spatial_scale` is multiplied with the RoI coordinates. For example, if your feature maps are down-sampled by a factor of 16 (w.r.t. the input image), you should use a spatial scale of `1/16`. 45 | - RoI coordinates follow the [L, R) convention. That is, `(0, 0, 4, 4)` denotes a box of size `4x4`.
46 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/_assets/prroi_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/external/PreciseRoIPooling/_assets/prroi_visualization.png -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | /_prroi_pooling 3 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : __init__.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 11 | 12 | from .prroi_pool import * 13 | 14 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/build.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : build.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 
11 | 12 | import os 13 | import torch 14 | 15 | from torch.utils.ffi import create_extension 16 | 17 | headers = [] 18 | sources = [] 19 | defines = [] 20 | extra_objects = [] 21 | with_cuda = False 22 | 23 | if torch.cuda.is_available(): 24 | with_cuda = True 25 | 26 | headers+= ['src/prroi_pooling_gpu.h'] 27 | sources += ['src/prroi_pooling_gpu.c'] 28 | defines += [('WITH_CUDA', None)] 29 | 30 | this_file = os.path.dirname(os.path.realpath(__file__)) 31 | extra_objects_cuda = ['src/prroi_pooling_gpu_impl.cu.o'] 32 | extra_objects_cuda = [os.path.join(this_file, fname) for fname in extra_objects_cuda] 33 | extra_objects.extend(extra_objects_cuda) 34 | else: 35 | # TODO(Jiayuan Mao @ 07/13): remove this restriction after we support the cpu implementation. 36 | raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.') 37 | 38 | ffi = create_extension( 39 | '_prroi_pooling', 40 | headers=headers, 41 | sources=sources, 42 | define_macros=defines, 43 | relative_to=__file__, 44 | with_cuda=with_cuda, 45 | extra_objects=extra_objects 46 | ) 47 | 48 | if __name__ == '__main__': 49 | ffi.build() 50 | 51 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/functional.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : functional.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 11 | 12 | import torch 13 | import torch.autograd as ag 14 | 15 | try: 16 | from . import _prroi_pooling 17 | except ImportError: 18 | raise ImportError('Can not found the compiled Precise RoI Pooling library. 
Run ./travis.sh in the directory first.') 19 | 20 | __all__ = ['prroi_pool2d'] 21 | 22 | 23 | class PrRoIPool2DFunction(ag.Function): 24 | @staticmethod 25 | def forward(ctx, features, rois, pooled_height, pooled_width, spatial_scale): 26 | assert 'FloatTensor' in features.type() and 'FloatTensor' in rois.type(), \ 27 | 'Precise RoI Pooling only takes float input, got {} for features and {} for rois.'.format(features.type(), rois.type()) 28 | 29 | features = features.contiguous() 30 | rois = rois.contiguous() 31 | pooled_height = int(pooled_height) 32 | pooled_width = int(pooled_width) 33 | spatial_scale = float(spatial_scale) 34 | 35 | params = (pooled_height, pooled_width, spatial_scale) 36 | batch_size, nr_channels, data_height, data_width = features.size() 37 | nr_rois = rois.size(0) 38 | output = torch.zeros( 39 | (nr_rois, nr_channels, pooled_height, pooled_width), 40 | dtype=features.dtype, device=features.device 41 | ) 42 | 43 | if features.is_cuda: 44 | _prroi_pooling.prroi_pooling_forward_cuda(features, rois, output, *params) 45 | ctx.params = params 46 | # everything here is contiguous. 
47 | ctx.save_for_backward(features, rois, output) 48 | else: 49 | raise NotImplementedError('Precise RoI Pooling only supports GPU (cuda) implememtations.') 50 | 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | features, rois, output = ctx.saved_tensors 56 | grad_input = grad_coor = None 57 | 58 | if features.requires_grad: 59 | grad_output = grad_output.contiguous() 60 | grad_input = torch.zeros_like(features) 61 | _prroi_pooling.prroi_pooling_backward_cuda(features, rois, output, grad_output, grad_input, *ctx.params) 62 | if rois.requires_grad: 63 | grad_output = grad_output.contiguous() 64 | grad_coor = torch.zeros_like(rois) 65 | _prroi_pooling.prroi_pooling_coor_backward_cuda(features, rois, output, grad_output, grad_coor, *ctx.params) 66 | 67 | return grad_input, grad_coor, None, None, None 68 | 69 | 70 | prroi_pool2d = PrRoIPool2DFunction.apply 71 | 72 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/prroi_pool.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : prroi_pool.py 4 | # Author : Jiayuan Mao, Tete Xiao 5 | # Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 6 | # Date : 07/13/2018 7 | # 8 | # This file is part of PreciseRoIPooling. 9 | # Distributed under terms of the MIT license. 10 | # Copyright (c) 2017 Megvii Technology Limited. 
11 | 12 | import torch.nn as nn 13 | 14 | from .functional import prroi_pool2d 15 | 16 | __all__ = ['PrRoIPool2D'] 17 | 18 | 19 | class PrRoIPool2D(nn.Module): 20 | def __init__(self, pooled_height, pooled_width, spatial_scale): 21 | super().__init__() 22 | 23 | self.pooled_height = int(pooled_height) 24 | self.pooled_width = int(pooled_width) 25 | self.spatial_scale = float(spatial_scale) 26 | 27 | def forward(self, features, rois): 28 | return prroi_pool2d(features, rois, self.pooled_height, self.pooled_width, self.spatial_scale) 29 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu.c: -------------------------------------------------------------------------------- 1 | /* 2 | * File : prroi_pooling_gpu.c 3 | * Author : Jiayuan Mao, Tete Xiao 4 | * Email : maojiayuan@gmail.com, jasonhsiao97@gmail.com 5 | * Date : 07/13/2018 6 | * 7 | * Distributed under terms of the MIT license. 8 | * Copyright (c) 2017 Megvii Technology Limited. 
9 | */ 10 | /* NOTE(review): both #include directives below lost their header names in this dump — TODO restore from upstream */ 11 | #include 12 | #include 13 | 14 | #include "prroi_pooling_gpu_impl.cuh" 15 | 16 | extern THCState *state; /* defined in another translation unit */ 17 | /* Forward: launches PrRoIPoolingForwardGpu on the current THC stream, filling output with nr_rois * C * PH * PW values; always returns 1. */ 18 | int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale) { 19 | const float *data_ptr = THCudaTensor_data(state, features); 20 | const float *rois_ptr = THCudaTensor_data(state, rois); 21 | float *output_ptr = THCudaTensor_data(state, output); 22 | 23 | int nr_rois = THCudaTensor_size(state, rois, 0); 24 | int nr_channels = THCudaTensor_size(state, features, 1); 25 | int height = THCudaTensor_size(state, features, 2); 26 | int width = THCudaTensor_size(state, features, 3); 27 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width; 28 | 29 | cudaStream_t stream = THCState_getCurrentStream(state); 30 | 31 | PrRoIPoolingForwardGpu( 32 | stream, data_ptr, rois_ptr, output_ptr, 33 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale, 34 | top_count 35 | ); 36 | 37 | return 1; 38 | } 39 | /* Backward w.r.t. features: launches PrRoIPoolingBackwardGpu, writing into features_diff; always returns 1. */ 40 | int prroi_pooling_backward_cuda( 41 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff, 42 | int pooled_height, int pooled_width, float spatial_scale) { 43 | 44 | const float *data_ptr = THCudaTensor_data(state, features); 45 | const float *rois_ptr = THCudaTensor_data(state, rois); 46 | const float *output_ptr = THCudaTensor_data(state, output); 47 | const float *output_diff_ptr = THCudaTensor_data(state, output_diff); 48 | float *features_diff_ptr = THCudaTensor_data(state, features_diff); 49 | 50 | int nr_rois = THCudaTensor_size(state, rois, 0); 51 | int batch_size = THCudaTensor_size(state, features, 0); 52 | int nr_channels = THCudaTensor_size(state, features, 1); 53 | int height = THCudaTensor_size(state, features, 2); 54 | int width = THCudaTensor_size(state, features, 3); 55 | int top_count = nr_rois * nr_channels * pooled_height *
pooled_width; 56 | int bottom_count = batch_size * nr_channels * height * width; /* element count of features (batch * C * H * W) */ 57 | 58 | cudaStream_t stream = THCState_getCurrentStream(state); 59 | 60 | PrRoIPoolingBackwardGpu( 61 | stream, data_ptr, rois_ptr, output_ptr, output_diff_ptr, features_diff_ptr, 62 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale, 63 | top_count, bottom_count 64 | ); 65 | 66 | return 1; 67 | } 68 | /* Backward w.r.t. ROI coordinates: launches PrRoIPoolingCoorBackwardGpu, writing into coor_diff; always returns 1. */ 69 | int prroi_pooling_coor_backward_cuda( 70 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *coor_diff, 71 | int pooled_height, int pooled_width, float spatial_scale) { 72 | 73 | const float *data_ptr = THCudaTensor_data(state, features); 74 | const float *rois_ptr = THCudaTensor_data(state, rois); 75 | const float *output_ptr = THCudaTensor_data(state, output); 76 | const float *output_diff_ptr = THCudaTensor_data(state, output_diff); 77 | float *coor_diff_ptr= THCudaTensor_data(state, coor_diff); 78 | 79 | int nr_rois = THCudaTensor_size(state, rois, 0); 80 | int nr_channels = THCudaTensor_size(state, features, 1); 81 | int height = THCudaTensor_size(state, features, 2); 82 | int width = THCudaTensor_size(state, features, 3); 83 | int top_count = nr_rois * nr_channels * pooled_height * pooled_width; 84 | int bottom_count = nr_rois * 5; /* bottom_count for coor_diff: 5 values per ROI */ 85 | 86 | cudaStream_t stream = THCState_getCurrentStream(state); 87 | 88 | PrRoIPoolingCoorBackwardGpu( 89 | stream, data_ptr, rois_ptr, output_ptr, output_diff_ptr, coor_diff_ptr, 90 | nr_channels, height, width, pooled_height, pooled_width, spatial_scale, 91 | top_count, bottom_count 92 | ); 93 | 94 | return 1; 95 | } 96 | 97 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu.h: -------------------------------------------------------------------------------- 1 | /* 2 | * File : prroi_pooling_gpu.h 3 | * Author : Jiayuan Mao, Tete Xiao 4 | * Email :
maojiayuan@gmail.com, jasonhsiao97@gmail.com 5 | * Date : 07/13/2018 6 | * 7 | * Distributed under terms of the MIT license. 8 | * Copyright (c) 2017 Megvii Technology Limited. 9 | */ 10 | /* Host entry points defined in prroi_pooling_gpu.c; each always returns 1. */ 11 | int prroi_pooling_forward_cuda(THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, int pooled_height, int pooled_width, float spatial_scale); 12 | 13 | int prroi_pooling_backward_cuda( 14 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *features_diff, 15 | int pooled_height, int pooled_width, float spatial_scale 16 | ); 17 | /* The ROI-coordinate gradient is written to coor_diff; parameter names match the definition in prroi_pooling_gpu.c. */ 18 | int prroi_pooling_coor_backward_cuda( 19 | THCudaTensor *features, THCudaTensor *rois, THCudaTensor *output, THCudaTensor *output_diff, THCudaTensor *coor_diff, 20 | int pooled_height, int pooled_width, float spatial_scale 21 | ); 22 | 23 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu_impl.cu: -------------------------------------------------------------------------------- 1 | ../../../src/prroi_pooling_gpu_impl.cu -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/src/prroi_pooling_gpu_impl.cuh: -------------------------------------------------------------------------------- 1 | ../../../src/prroi_pooling_gpu_impl.cuh -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/prroi_pool/travis.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash -e 2 | # File : travis.sh 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # 6 | # Distributed under terms of the MIT license. 7 | # Copyright (c) 2017 Megvii Technology Limited. 8 | 9 | cd src 10 | echo "Working directory: " `pwd` 11 | echo "Compiling prroi_pooling kernels by nvcc..."
12 | nvcc -c -o prroi_pooling_gpu_impl.cu.o prroi_pooling_gpu_impl.cu -x cu -Xcompiler -fPIC -arch=sm_52 13 | 14 | cd ../ 15 | echo "Working directory: " `pwd` 16 | echo "Building python libraries..." 17 | python3 build.py 18 | 19 | -------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/pytorch/tests/test_prroi_pooling2d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_prroi_pooling2d.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 18/02/2018 6 | # 7 | # This file is part of Jacinle. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from jactorch.utils.unittest import TorchTestCase 16 | 17 | from prroi_pool import PrRoIPool2D 18 | 19 | 20 | class TestPrRoIPool2D(TorchTestCase): # requires a CUDA device: every tensor is moved with .cuda() 21 | def test_forward(self): # checks pooling of two ROIs against a 2x2, stride-1 average pool of the features 22 | pool = PrRoIPool2D(7, 7, spatial_scale=0.5) 23 | features = torch.rand((4, 16, 24, 32)).cuda() 24 | rois = torch.tensor([ 25 | [0, 0, 0, 14, 14], 26 | [1, 14, 14, 28, 28], 27 | ]).float().cuda() 28 | 29 | out = pool(features, rois) 30 | out_gold = F.avg_pool2d(features, kernel_size=2, stride=1) 31 | 32 | self.assertTensorClose(out, torch.stack(( 33 | out_gold[0, :, :7, :7], 34 | out_gold[1, :, 7:14, 7:14], 35 | ), dim=0)) 36 | 37 | def test_backward_shapeonly(self): # only checks that features and rois receive gradients of matching shape 38 | pool = PrRoIPool2D(2, 2, spatial_scale=0.5) 39 | 40 | features = torch.rand((4, 2, 24, 32)).cuda() 41 | rois = torch.tensor([ 42 | [0, 0, 0, 4, 4], 43 | [1, 14, 14, 18, 18], 44 | ]).float().cuda() 45 | features.requires_grad = rois.requires_grad = True 46 | out = pool(features, rois) 47 | 48 | loss = out.sum() 49 | loss.backward() 50 | 51 | self.assertTupleEqual(features.size(), features.grad.size()) 52 | self.assertTupleEqual(rois.size(), rois.grad.size()) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 |
-------------------------------------------------------------------------------- /ltr/external/PreciseRoIPooling/src/prroi_pooling_gpu_impl.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * File : prroi_pooling_gpu_impl.cuh 3 | * Author : Tete Xiao, Jiayuan Mao 4 | * Email : jasonhsiao97@gmail.com 5 | * 6 | * Distributed under terms of the MIT license. 7 | * Copyright (c) 2017 Megvii Technology Limited. 8 | */ 9 | 10 | #ifndef PRROI_POOLING_GPU_IMPL_CUH 11 | #define PRROI_POOLING_GPU_IMPL_CUH 12 | /* C linkage so the plain-C host wrapper (prroi_pooling_gpu.c) can call these launchers. */ 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | /* Pointer aliases that mark read-only inputs vs. written outputs. */ 17 | #define F_DEVPTR_IN const float * 18 | #define F_DEVPTR_OUT float * 19 | /* Each launcher takes the caller's cudaStream_t; top_count/bottom_count are element counts computed by the host wrapper. */ 20 | void PrRoIPoolingForwardGpu( 21 | cudaStream_t stream, 22 | F_DEVPTR_IN bottom_data, 23 | F_DEVPTR_IN bottom_rois, 24 | F_DEVPTR_OUT top_data, 25 | const int channels_, const int height_, const int width_, 26 | const int pooled_height_, const int pooled_width_, 27 | const float spatial_scale_, 28 | const int top_count); 29 | 30 | void PrRoIPoolingBackwardGpu( 31 | cudaStream_t stream, 32 | F_DEVPTR_IN bottom_data, 33 | F_DEVPTR_IN bottom_rois, 34 | F_DEVPTR_IN top_data, 35 | F_DEVPTR_IN top_diff, 36 | F_DEVPTR_OUT bottom_diff, 37 | const int channels_, const int height_, const int width_, 38 | const int pooled_height_, const int pooled_width_, 39 | const float spatial_scale_, 40 | const int top_count, const int bottom_count); 41 | /* Same signature as the backward launcher; the host wrapper passes the ROI-coordinate gradient buffer as bottom_diff. */ 42 | void PrRoIPoolingCoorBackwardGpu( 43 | cudaStream_t stream, 44 | F_DEVPTR_IN bottom_data, 45 | F_DEVPTR_IN bottom_rois, 46 | F_DEVPTR_IN top_data, 47 | F_DEVPTR_IN top_diff, 48 | F_DEVPTR_OUT bottom_diff, 49 | const int channels_, const int height_, const int width_, 50 | const int pooled_height_, const int pooled_width_, 51 | const float spatial_scale_, 52 | const int top_count, const int bottom_count); 53 | 54 | #ifdef __cplusplus 55 | } /* !extern "C" */ 56 | #endif 57 | 58 | #endif /* !PRROI_POOLING_GPU_IMPL_CUH */ 59 | 60 |
-------------------------------------------------------------------------------- /ltr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/__init__.py -------------------------------------------------------------------------------- /ltr/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .resnet18_vggm import * 3 | -------------------------------------------------------------------------------- /ltr/models/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/backbone/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/backbone/__pycache__/resnet18_vggm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/backbone/__pycache__/resnet18_vggm.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/bbreg/__init__.py: -------------------------------------------------------------------------------- 1 | from .atom_iou_net import AtomIoUNet 2 | -------------------------------------------------------------------------------- /ltr/models/bbreg/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/bbreg/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/bbreg/__pycache__/atom.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/bbreg/__pycache__/atom.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/bbreg/__pycache__/atom_iou_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/bbreg/__pycache__/atom_iou_net.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/bbreg/atom.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import ltr.models.backbone as backbones 3 | import ltr.models.bbreg as bbmodels 4 | from ltr import model_constructor 5 | 6 | 7 | class ATOMnet(nn.Module): 8 | """ ATOM network module""" 9 | def __init__(self, feature_extractor, bb_regressor, bb_regressor_layer, extractor_grad=True): 10 
| """ 11 | args: 12 | feature_extractor - backbone feature extractor 13 | bb_regressor - IoU prediction module 14 | bb_regressor_layer - List containing the name of the layers from feature_extractor, which are input to 15 | bb_regressor 16 | extractor_grad - Bool indicating whether backbone feature extractor requires gradients 17 | """ 18 | super(ATOMnet, self).__init__() 19 | 20 | self.feature_extractor = feature_extractor 21 | self.bb_regressor = bb_regressor 22 | self.bb_regressor_layer = bb_regressor_layer 23 | 24 | if not extractor_grad: 25 | for p in self.feature_extractor.parameters(): 26 | p.requires_grad_(False) 27 | 28 | def forward(self, train_imgs, test_imgs, train_bb, test_proposals): 29 | """ Forward pass 30 | Note: If the training is done in sequence mode, that is, test_imgs.dim() == 5, then the batch dimension 31 | corresponds to the first dimensions. test_imgs is thus of the form [sequence, batch, feature, row, col] 32 | """ 33 | num_sequences = train_imgs.shape[-4] 34 | num_train_images = train_imgs.shape[0] if train_imgs.dim() == 5 else 1 35 | num_test_images = test_imgs.shape[0] if test_imgs.dim() == 5 else 1 36 | 37 | # Extract backbone features 38 | train_feat = self.extract_backbone_features( 39 | train_imgs.view(-1, train_imgs.shape[-3], train_imgs.shape[-2], train_imgs.shape[-1])) 40 | test_feat = self.extract_backbone_features( 41 | test_imgs.view(-1, test_imgs.shape[-3], test_imgs.shape[-2], test_imgs.shape[-1])) 42 | 43 | # For clarity, send the features to bb_regressor in sequence form, i.e. 
[sequence, batch, feature, row, col] 44 | train_feat_iou = [feat.view(num_train_images, num_sequences, feat.shape[-3], feat.shape[-2], feat.shape[-1]) 45 | for feat in train_feat.values()] 46 | test_feat_iou = [feat.view(num_test_images, num_sequences, feat.shape[-3], feat.shape[-2], feat.shape[-1]) 47 | for feat in test_feat.values()] 48 | 49 | # Obtain iou prediction 50 | iou_pred = self.bb_regressor(train_feat_iou, test_feat_iou, 51 | train_bb.view(num_train_images, num_sequences, 4), 52 | test_proposals.view(num_train_images, num_sequences, -1, 4)) 53 | return iou_pred 54 | 55 | def extract_backbone_features(self, im, layers=None): 56 | if layers is None: 57 | layers = self.bb_regressor_layer 58 | return self.feature_extractor(im, layers) 59 | 60 | def extract_features(self, im, layers): 61 | return self.feature_extractor(im, layers) 62 | 63 | 64 | 65 | @model_constructor 66 | def atom_resnet18(iou_input_dim=(256,256), iou_inter_dim=(256,256), backbone_pretrained=True): 67 | # backbone 68 | backbone_net = backbones.resnet18(pretrained=backbone_pretrained) 69 | 70 | # Bounding box regressor 71 | iou_predictor = bbmodels.AtomIoUNet(pred_input_dim=iou_input_dim, pred_inter_dim=iou_inter_dim) 72 | 73 | net = ATOMnet(feature_extractor=backbone_net, bb_regressor=iou_predictor, bb_regressor_layer=['layer2', 'layer3'], 74 | extractor_grad=False) 75 | 76 | return net 77 | -------------------------------------------------------------------------------- /ltr/models/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/layers/__init__.py -------------------------------------------------------------------------------- /ltr/models/layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/layers/__pycache__/blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/models/layers/__pycache__/blocks.cpython-36.pyc -------------------------------------------------------------------------------- /ltr/models/layers/blocks.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def conv_block(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=True, # Conv2d, optionally followed by BatchNorm2d and ReLU, wrapped in nn.Sequential 5 | batch_norm=True, relu=True): 6 | layers = [nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 7 | padding=padding, dilation=dilation, bias=bias)] 8 | if batch_norm: 9 | layers.append(nn.BatchNorm2d(out_planes)) 10 | if relu: 11 | layers.append(nn.ReLU(inplace=True)) 12 | return nn.Sequential(*layers) 13 | 14 | 15 | class LinearBlock(nn.Module): # flatten -> Linear -> optional BatchNorm2d -> optional ReLU; returns an (N, out_planes) tensor 16 | def __init__(self, in_planes, out_planes, input_sz, bias=True, batch_norm=True, relu=True): 17 | super().__init__() 18 | self.linear = nn.Linear(in_planes*input_sz*input_sz, out_planes, bias=bias) # input must flatten to in_planes * input_sz * input_sz features per sample 19 | self.bn = nn.BatchNorm2d(out_planes) if batch_norm else None 20 | self.relu = nn.ReLU(inplace=True) if relu else None 21 | 22 | def forward(self, x): 23 | x = self.linear(x.view(x.shape[0], -1)) 24 | if self.bn is not None: 25 | x = self.bn(x.view(x.shape[0], x.shape[1], 1, 1)) # BatchNorm2d needs a 4-D input, so view as (N, C, 1, 1) 26 | if self.relu is not None: 27 | x = self.relu(x) 28 | return x.view(x.shape[0], -1) -------------------------------------------------------------------------------- /ltr/run_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys
3 | import argparse 4 | import importlib 5 | import multiprocessing 6 | import cv2 as cv 7 | import torch.backends.cudnn 8 | 9 | env_path = os.path.join(os.path.dirname(__file__), '..') 10 | if env_path not in sys.path: 11 | sys.path.append(env_path) 12 | 13 | import ltr.admin.settings as ws_settings 14 | 15 | 16 | def run_training(train_module, train_name, cudnn_benchmark=True): 17 | """Run a train scripts in train_settings. 18 | args: 19 | train_module: Name of module in the "train_settings/" folder. 20 | train_name: Name of the train settings file. 21 | cudnn_benchmark: Use cudnn benchmark or not (default is True). 22 | """ 23 | 24 | # This is needed to avoid strange crashes related to opencv 25 | cv.setNumThreads(0) 26 | 27 | torch.backends.cudnn.benchmark = cudnn_benchmark 28 | 29 | print('Training: {} {}'.format(train_module, train_name)) 30 | 31 | settings = ws_settings.Settings() 32 | 33 | if settings.env.workspace_dir == '': 34 | raise Exception('Setup your workspace_dir in "ltr/admin/local.py".') 35 | 36 | settings.module_name = train_module 37 | settings.script_name = train_name 38 | settings.project_path = 'ltr/{}/{}'.format(train_module, train_name) 39 | 40 | expr_module = importlib.import_module('ltr.train_settings.{}.{}'.format(train_module, train_name)) 41 | expr_func = getattr(expr_module, 'run') 42 | 43 | expr_func(settings) 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.') 48 | parser.add_argument('train_module', type=str, help='Name of module in the "train_settings/" folder.') 49 | parser.add_argument('train_name', type=str, help='Name of the train settings file.') 50 | parser.add_argument('--cudnn_benchmark', type=bool, default=True, help='Set cudnn benchmark on (1) or off (0) (default is on).') 51 | 52 | args = parser.parse_args() 53 | 54 | run_training(args.train_module, args.train_name, args.cudnn_benchmark) 55 | 56 | 57 | if __name__ == '__main__': 58 | 
multiprocessing.set_start_method('spawn', force=True) 59 | main() 60 | -------------------------------------------------------------------------------- /ltr/train_settings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/train_settings/__init__.py -------------------------------------------------------------------------------- /ltr/train_settings/bbreg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/ltr/train_settings/bbreg/__init__.py -------------------------------------------------------------------------------- /ltr/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .ltr_trainer import LTRTrainer -------------------------------------------------------------------------------- /pytracking/__init__.py: -------------------------------------------------------------------------------- 1 | from pytracking.libs import TensorList, TensorDict 2 | import pytracking.libs.complex as complex 3 | import pytracking.libs.operation as operation 4 | import pytracking.libs.fourier as fourier 5 | import pytracking.libs.dcf as dcf 6 | import pytracking.libs.optimization as optimization 7 | from pytracking.run_tracker import run_tracker 8 | from pytracking.run_webcam import run_webcam 9 | -------------------------------------------------------------------------------- /pytracking/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /pytracking/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pytracking/__pycache__/run_tracker.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/run_tracker.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/__pycache__/run_webcam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/__pycache__/run_webcam.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .otbdataset import OTBDataset 2 | from .nfsdataset import NFSDataset 3 | from .uavdataset import UAVDataset 4 | from .tpldataset import TPLDataset 5 | from .votdataset import VOTDataset 6 | from .trackingnetdataset import TrackingNetDataset 7 | from .got10kdataset import GOT10KDatasetTest, GOT10KDatasetVal, GOT10KDatasetLTRVal 8 | from .lasotdataset import LaSOTDataset 9 | from .data import Sequence 10 | from .tracker import Tracker 11 | -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/environment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/environment.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/got10kdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/got10kdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/lasotdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/lasotdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/local.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/local.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/nfsdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/nfsdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/otbdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/otbdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/running.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/running.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/tpldataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/tpldataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/tracker.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/tracker.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/trackingnetdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/trackingnetdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/uavdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/uavdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/__pycache__/votdataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/__pycache__/votdataset.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/data.py: -------------------------------------------------------------------------------- 1 | from pytracking.evaluation.environment import env_settings 2 | 3 | 4 | class BaseDataset: 5 | """Base class for all datasets.""" 6 | def __init__(self): 7 | self.env_settings = env_settings() # local paths loaded via pytracking.evaluation.environment 8 | 9 | def __len__(self): 10 | """Overload this function in your dataset. This should return number of sequences in the dataset.""" 11 | raise NotImplementedError 12 | 13 | def get_sequence_list(self): 14 | """Overload this in your dataset.
Should return the list of sequences in the dataset.""" 15 | raise NotImplementedError 16 | 17 | 18 | class Sequence: 19 | """Class for the sequence in an evaluation.""" 20 | def __init__(self, name, frames, ground_truth_rect, gt, object_class=None): 21 | self.name = name 22 | self.frames = frames 23 | self.ground_truth_rect = ground_truth_rect 24 | self.gt = gt 25 | self.init_state = list(self.ground_truth_rect[0,:]) # initial state is the first row of ground_truth_rect (assumes a 2-D array — TODO confirm type) 26 | self.object_class = object_class 27 | 28 | 29 | class SequenceList(list): 30 | """List of sequences. Supports the addition operator to concatenate sequence lists.""" 31 | def __getitem__(self, item): # str -> lookup by name, int -> single Sequence, tuple/list -> SequenceList subset, anything else (e.g. slice) -> SequenceList 32 | if isinstance(item, str): 33 | for seq in self: 34 | if seq.name == item: 35 | return seq 36 | raise IndexError('Sequence name not in the dataset.') # IndexError (not KeyError) for an unknown name 37 | elif isinstance(item, int): 38 | return super(SequenceList, self).__getitem__(item) 39 | elif isinstance(item, (tuple, list)): 40 | return SequenceList([super(SequenceList, self).__getitem__(i) for i in item]) 41 | else: 42 | return SequenceList(super(SequenceList, self).__getitem__(item)) 43 | 44 | def __add__(self, other): 45 | return SequenceList(super(SequenceList, self).__add__(other)) 46 | 47 | def copy(self): 48 | return SequenceList(super(SequenceList, self).copy()) -------------------------------------------------------------------------------- /pytracking/evaluation/environment.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | class EnvSettings: # defaults: results/networks dirs under the pytracking package, dataset paths empty 6 | def __init__(self): 7 | pytracking_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 8 | 9 | self.results_path = '{}/tracking_results/'.format(pytracking_path) 10 | self.network_path = '{}/networks/'.format(pytracking_path) 11 | self.otb_path = '' 12 | self.nfs_path = '' 13 | self.uav_path = '' 14 | self.tpl_path = '' 15 | self.vot_path = '' 16 | self.got10k_path = '' 17 | self.lasot_path = '' 18 | self.trackingnet_path = '' 19 | 20 | 21 |
def create_default_local_file(): 22 | comment = {'results_path': 'Where to store tracking results', 23 | 'network_path': 'Where tracking networks are stored.'} 24 | 25 | path = os.path.join(os.path.dirname(__file__), 'local.py') 26 | with open(path, 'w') as f: 27 | settings = EnvSettings() 28 | 29 | f.write('from pytracking.evaluation.environment import EnvSettings\n\n') 30 | f.write('def local_env_settings():\n') 31 | f.write(' settings = EnvSettings()\n\n') 32 | f.write(' # Set your local paths here.\n\n') 33 | 34 | for attr in dir(settings): 35 | comment_str = None 36 | if attr in comment: 37 | comment_str = comment[attr] 38 | attr_val = getattr(settings, attr) 39 | if not attr.startswith('__') and not callable(attr_val): 40 | if comment_str is None: 41 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val)) 42 | else: 43 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str)) 44 | f.write('\n return settings\n\n') 45 | 46 | 47 | def env_settings(): 48 | env_module_name = 'pytracking.evaluation.local' 49 | try: 50 | env_module = importlib.import_module(env_module_name) 51 | return env_module.local_env_settings() 52 | except: 53 | env_file = os.path.join(os.path.dirname(__file__), 'local.py') 54 | 55 | # Create a default file 56 | create_default_local_file() 57 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. 
' 58 | 'Then try to run again.'.format(env_file)) 59 | -------------------------------------------------------------------------------- /pytracking/evaluation/got10kdataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pytracking.evaluation.data import Sequence, BaseDataset, SequenceList 3 | import os 4 | 5 | 6 | def GOT10KDatasetTest(): 7 | """ GOT-10k official test set""" 8 | return GOT10KDatasetClass('test').get_sequence_list() 9 | 10 | 11 | def GOT10KDatasetVal(): 12 | """ GOT-10k official val set""" 13 | return GOT10KDatasetClass('val').get_sequence_list() 14 | 15 | 16 | def GOT10KDatasetLTRVal(): 17 | """ GOT-10k val split from LTR (a subset of GOT-10k official train set)""" 18 | return GOT10KDatasetClass('ltrval').get_sequence_list() 19 | 20 | 21 | class GOT10KDatasetClass(BaseDataset): 22 | """ GOT-10k dataset. 23 | 24 | Publication: 25 | GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild 26 | Lianghua Huang, Xin Zhao, and Kaiqi Huang 27 | arXiv:1810.11981, 2018 28 | https://arxiv.org/pdf/1810.11981.pdf 29 | 30 | Download dataset from http://got-10k.aitestunion.com/downloads 31 | """ 32 | def __init__(self, split): 33 | """ 34 | args: 35 | split - Split to use. Can be i) 'test': official test set, ii) 'val': official val set, and iii) 'ltrval': 36 | a custom validation set, a subset of the official train set. 
37 | """ 38 | super().__init__() 39 | # Split can be test, val, or ltrval 40 | if split == 'test' or split == 'val': 41 | self.base_path = os.path.join(self.env_settings.got10k_path, split) 42 | else: 43 | self.base_path = os.path.join(self.env_settings.got10k_path, 'train') 44 | 45 | self.sequence_list = self._get_sequence_list(split) 46 | self.split = split 47 | 48 | def get_sequence_list(self): 49 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list]) 50 | 51 | def _construct_sequence(self, sequence_name): 52 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name) 53 | try: 54 | ground_truth_rect = np.loadtxt(str(anno_path), dtype=np.float64) 55 | except: 56 | ground_truth_rect = np.loadtxt(str(anno_path), delimiter=',', dtype=np.float64) 57 | 58 | frames_path = '{}/{}'.format(self.base_path, sequence_name) 59 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")] 60 | frame_list.sort(key=lambda f: int(f[:-4])) 61 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list] 62 | 63 | return Sequence(sequence_name, frames_list, ground_truth_rect.reshape(-1, 4)) 64 | 65 | def __len__(self): 66 | '''Overload this function in your evaluation. 
This should return number of sequences in the evaluation ''' 67 | return len(self.sequence_list) 68 | 69 | def _get_sequence_list(self, split): 70 | with open('{}/list.txt'.format(self.base_path)) as f: 71 | sequence_list = f.read().splitlines() 72 | 73 | if split == 'ltrval': 74 | with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f: 75 | seq_ids = f.read().splitlines() 76 | 77 | sequence_list = [sequence_list[int(x)] for x in seq_ids] 78 | return sequence_list 79 | -------------------------------------------------------------------------------- /pytracking/evaluation/local.py: -------------------------------------------------------------------------------- 1 | from evaluation.environment import EnvSettings 2 | 3 | def local_env_settings(): 4 | settings = EnvSettings() 5 | 6 | # Set your local paths here. 7 | 8 | settings.got10k_path = '' 9 | settings.lasot_path = '' 10 | settings.network_path = '/home/.../tracking/SPSTracker/pytracking/network/' # Where tracking networks are stored. 
11 | settings.nfs_path = '' 12 | settings.otb_path = '' 13 | settings.results_path = '//home/.../tracking/SPSTracker/pytracking/tracking_results' # Where to store tracking results 14 | settings.tpl_path = '' 15 | settings.trackingnet_path = '' 16 | settings.uav_path = '' 17 | settings.vot_path = '/home/.../data/VOT2018' 18 | 19 | return settings 20 | 21 | -------------------------------------------------------------------------------- /pytracking/evaluation/running.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import multiprocessing 3 | import os 4 | from itertools import product 5 | from pytracking.evaluation import Sequence, Tracker 6 | from os import makedirs 7 | from os.path import join, isdir, isfile 8 | from pytracking.evaluation.utils.pyvotkit.region import vot_float2str 9 | from os.path import split 10 | 11 | def run_sequence(seq: Sequence, tracker: Tracker, debug=False): 12 | """Runs a tracker on a sequence.""" 13 | base_results_path = '{}/baseline/{}'.format(tracker.results_dir, seq.name) 14 | results_path = '{}.txt'.format(base_results_path) 15 | times_path = '{}_time.txt'.format(base_results_path) 16 | 17 | if os.path.isfile(results_path) and not debug: 18 | return 19 | 20 | print('Tracker: {} {} {} , Sequence: {}'.format(tracker.name, tracker.parameter_name, tracker.run_id, seq.name)) 21 | 22 | if debug: 23 | tracked_bb, exec_times = tracker.run(seq, debug=debug) 24 | else: 25 | try: 26 | tracked_bb, exec_times = tracker.run(seq, debug=debug) 27 | except Exception as e: 28 | print(e) 29 | return 30 | path = base_results_path.split('/') 31 | path = path[-1] 32 | 33 | if not isdir(base_results_path): makedirs(base_results_path) 34 | results_path = join(base_results_path, '{:s}_001.txt').format(path) 35 | #'{}.txt'.format(base_results_path) 36 | with open(results_path, "w") as fin: 37 | for x in tracked_bb: 38 | fin.write("{:d}\n".format(x)) if isinstance(x, int) else \ 39 | 
fin.write(','.join([vot_float2str("%.4f", i) for i in x]) + '\n') 40 | # tracked_bb = np.array(tracked_bb).astype(int) 41 | exec_times = np.array(exec_times).astype(float) 42 | 43 | print('FPS: {}'.format(len(exec_times) / exec_times.sum())) 44 | # if not debug: 45 | # np.savetxt(results_path, tracked_bb, delimiter='\t', fmt='%d') 46 | # np.savetxt(times_path, exec_times, delimiter='\t', fmt='%f') 47 | 48 | 49 | def run_dataset(dataset, trackers, debug=False, threads=0): 50 | """Runs a list of trackers on a dataset. 51 | args: 52 | dataset: List of Sequence instances, forming a dataset. 53 | trackers: List of Tracker instances. 54 | debug: Debug level. 55 | threads: Number of threads to use (default 0). 56 | """ 57 | if threads == 0: 58 | mode = 'sequential' 59 | else: 60 | mode = 'parallel' 61 | 62 | if mode == 'sequential': 63 | for seq in dataset: 64 | for tracker_info in trackers: 65 | run_sequence(seq, tracker_info, debug=debug) 66 | elif mode == 'parallel': 67 | param_list = [(seq, tracker_info, debug) for seq, tracker_info in product(dataset, trackers)] 68 | with multiprocessing.Pool(processes=threads) as pool: 69 | pool.starmap(run_sequence, param_list) 70 | print('Done') 71 | -------------------------------------------------------------------------------- /pytracking/evaluation/tracker.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import pickle 4 | from pytracking.evaluation.environment import env_settings 5 | 6 | 7 | class Tracker: 8 | """Wraps the tracker for evaluation and running purposes. 9 | args: 10 | name: Name of tracking method. 11 | parameter_name: Name of parameter file. 12 | run_id: The run id. 
13 | """ 14 | 15 | def __init__(self, name: str, parameter_name: str, run_id: int = None): 16 | self.name = name 17 | self.parameter_name = parameter_name 18 | self.run_id = run_id 19 | 20 | env = env_settings() 21 | if self.run_id is None: 22 | self.results_dir = '{}/{}/{}'.format(env.results_path, self.name, self.parameter_name) 23 | else: 24 | self.results_dir = '{}/{}/{}_{:03d}'.format(env.results_path, self.name, self.parameter_name, self.run_id) 25 | if not os.path.exists(self.results_dir): 26 | os.makedirs(self.results_dir) 27 | 28 | tracker_module = importlib.import_module('pytracking.tracker.{}'.format(self.name)) 29 | 30 | self.parameters = self.get_parameters() 31 | self.tracker_class = tracker_module.get_tracker_class() 32 | 33 | self.default_visualization = getattr(self.parameters, 'visualization', False) 34 | self.default_debug = getattr(self.parameters, 'debug', 0) 35 | 36 | def run(self, seq, visualization=None, debug=None): 37 | """Run tracker on sequence. 38 | args: 39 | seq: Sequence to run the tracker on. 40 | visualization: Set visualization flag (None means default value specified in the parameters). 41 | debug: Set debug level (None means default value specified in the parameters). 42 | """ 43 | visualization_ = visualization 44 | debug_ = debug 45 | if debug is None: 46 | debug_ = self.default_debug 47 | if visualization is None: 48 | if debug is None: 49 | visualization_ = self.default_visualization 50 | else: 51 | visualization_ = True if debug else False 52 | 53 | self.parameters.visualization = visualization_ 54 | self.parameters.debug = debug_ 55 | 56 | tracker = self.tracker_class(self.parameters) 57 | 58 | output_bb, execution_times = tracker.track_sequence(seq) 59 | 60 | self.parameters.free_memory() 61 | 62 | return output_bb, execution_times 63 | def run_video(self, videofilepath, optional_box=None, debug=None): 64 | """Run the tracker with the vieofile. 65 | args: 66 | debug: Debug level. 
67 | """ 68 | 69 | debug_ = debug 70 | if debug is None: 71 | debug_ = self.default_debug 72 | self.parameters.debug = debug_ 73 | 74 | self.parameters.tracker_name = self.name 75 | self.parameters.param_name = self.parameter_name 76 | tracker = self.tracker_class(self.parameters) 77 | tracker.track_videofile(videofilepath, optional_box) 78 | 79 | def run_webcam(self, debug=None): 80 | """Run the tracker with the webcam. 81 | args: 82 | debug: Debug level. 83 | """ 84 | 85 | debug_ = debug 86 | if debug is None: 87 | debug_ = self.default_debug 88 | self.parameters.debug = debug_ 89 | 90 | self.parameters.tracker_name = self.name 91 | self.parameters.param_name = self.parameter_name 92 | tracker = self.tracker_class(self.parameters) 93 | 94 | tracker.track_webcam() 95 | 96 | def get_parameters(self): 97 | """Get parameters.""" 98 | 99 | parameter_file = '{}/parameters.pkl'.format(self.results_dir) 100 | if os.path.isfile(parameter_file): 101 | return pickle.load(open(parameter_file, 'rb')) 102 | 103 | param_module = importlib.import_module('pytracking.parameter.{}.{}'.format(self.name, self.parameter_name)) 104 | params = param_module.parameters() 105 | 106 | if self.run_id is not None: 107 | pickle.dump(params, open(parameter_file, 'wb')) 108 | 109 | return params 110 | 111 | 112 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/__init__.py -------------------------------------------------------------------------------- /pytracking/evaluation/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/utils/anchors.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | import numpy as np 7 | import math 8 | from utils.bbox_helper import center2corner, corner2center 9 | 10 | 11 | class Anchors: 12 | def __init__(self, cfg): 13 | self.stride = 8 14 | self.ratios = [0.33, 0.5, 1, 2, 3] 15 | self.scales = [8] 16 | self.round_dight = 0 17 | self.image_center = 0 18 | self.size = 0 19 | 20 | self.__dict__.update(cfg) 21 | 22 | self.anchor_num = len(self.scales) * len(self.ratios) 23 | self.anchors = None # in single position (anchor_num*4) 24 | self.all_anchors = None # in all position 2*(4*anchor_num*h*w) 25 | self.generate_anchors() 26 | 27 | def generate_anchors(self): 28 | self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32) 29 | 30 | size = self.stride * self.stride 31 | count = 0 32 | for r in self.ratios: 33 | if self.round_dight > 0: 34 | ws = round(math.sqrt(size*1. / r), self.round_dight) 35 | hs = round(ws * r, self.round_dight) 36 | else: 37 | ws = int(math.sqrt(size*1. 
/ r)) 38 | hs = int(ws * r) 39 | 40 | for s in self.scales: 41 | w = ws * s 42 | h = hs * s 43 | self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:] 44 | count += 1 45 | 46 | def generate_all_anchors(self, im_c, size): 47 | if self.image_center == im_c and self.size == size: 48 | return False 49 | self.image_center = im_c 50 | self.size = size 51 | 52 | a0x = im_c - size // 2 * self.stride 53 | ori = np.array([a0x] * 4, dtype=np.float32) 54 | zero_anchors = self.anchors + ori 55 | 56 | x1 = zero_anchors[:, 0] 57 | y1 = zero_anchors[:, 1] 58 | x2 = zero_anchors[:, 2] 59 | y2 = zero_anchors[:, 3] 60 | 61 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), [x1, y1, x2, y2]) 62 | cx, cy, w, h = corner2center([x1, y1, x2, y2]) 63 | 64 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride 65 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride 66 | 67 | cx = cx + disp_x 68 | cy = cy + disp_y 69 | 70 | # broadcast 71 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32) 72 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h]) 73 | x1, y1, x2, y2 = center2corner([cx, cy, w, h]) 74 | 75 | self.all_anchors = np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h]) 76 | return True 77 | 78 | 79 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/bbox_helper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | import numpy as np 7 | from collections import namedtuple 8 | 9 | Corner = namedtuple('Corner', 'x1 y1 x2 y2') 10 | BBox = Corner 11 | Center = namedtuple('Center', 'x y w h') 12 | 13 | 14 | def corner2center(corner): 15 | """ 16 | :param corner: Corner or np.array 4*N 17 | :return: Center or 4 
np.array N 18 | """ 19 | if isinstance(corner, Corner): 20 | x1, y1, x2, y2 = corner 21 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) 22 | else: 23 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] 24 | x = (x1 + x2) * 0.5 25 | y = (y1 + y2) * 0.5 26 | w = x2 - x1 27 | h = y2 - y1 28 | return x, y, w, h 29 | 30 | 31 | def center2corner(center): 32 | """ 33 | :param center: Center or np.array 4*N 34 | :return: Corner or np.array 4*N 35 | """ 36 | if isinstance(center, Center): 37 | x, y, w, h = center 38 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) 39 | else: 40 | x, y, w, h = center[0], center[1], center[2], center[3] 41 | x1 = x - w * 0.5 42 | y1 = y - h * 0.5 43 | x2 = x + w * 0.5 44 | y2 = y + h * 0.5 45 | return x1, y1, x2, y2 46 | 47 | 48 | def cxy_wh_2_rect(pos, sz): 49 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) # 0-index 50 | 51 | 52 | def get_axis_aligned_bbox(region): 53 | nv = region.size 54 | if nv == 8: 55 | cx = np.mean(region[0::2]) 56 | cy = np.mean(region[1::2]) 57 | x1 = min(region[0::2]) 58 | x2 = max(region[0::2]) 59 | y1 = min(region[1::2]) 60 | y2 = max(region[1::2]) 61 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6]) 62 | A2 = (x2 - x1) * (y2 - y1) 63 | s = np.sqrt(A1 / A2) 64 | w = s * (x2 - x1) + 1 65 | h = s * (y2 - y1) + 1 66 | else: 67 | x = region[0] 68 | y = region[1] 69 | w = region[2] 70 | h = region[3] 71 | cx = x+w/2 72 | cy = y+h/2 73 | 74 | return cx, cy, w, h 75 | 76 | 77 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/benchmark_helper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # 
-------------------------------------------------------- 6 | from os.path import join, realpath, dirname, exists, isdir 7 | from os import listdir 8 | import logging 9 | import glob 10 | import numpy as np 11 | import json 12 | from collections import OrderedDict 13 | 14 | 15 | def get_dataset_zoo(): 16 | root = realpath(join(dirname(__file__), '../data')) 17 | zoos = listdir(root) 18 | 19 | def valid(x): 20 | y = join(root, x) 21 | if not isdir(y): return False 22 | 23 | return exists(join(y, 'list.txt')) \ 24 | or exists(join(y, 'train', 'meta.json'))\ 25 | or exists(join(y, 'ImageSets', '2016', 'val.txt')) 26 | 27 | zoos = list(filter(valid, zoos)) 28 | return zoos 29 | 30 | 31 | dataset_zoo = get_dataset_zoo() 32 | 33 | 34 | def load_dataset(dataset): 35 | info = OrderedDict() 36 | if 'VOT' in dataset: 37 | 38 | base_path = join(realpath(dirname(__file__)), '../data', dataset) 39 | print(base_path) 40 | if not exists(base_path): 41 | print('aaaa',base_path) 42 | logging.error("Please download test dataset!!!") 43 | exit() 44 | list_path = join(base_path, 'list.txt') 45 | with open(list_path) as f: 46 | videos = [v.strip() for v in f.readlines()] 47 | for video in videos: 48 | video_path = join(base_path, video) 49 | image_path = join(video_path, '*.jpg') 50 | image_files = sorted(glob.glob(image_path)) 51 | if len(image_files) == 0: # VOT2018 52 | image_path = join(video_path, 'color', '*.jpg') 53 | image_files = sorted(glob.glob(image_path)) 54 | gt_path = join(video_path, 'groundtruth.txt') 55 | gt = np.loadtxt(gt_path, delimiter=',').astype(np.float64) 56 | if gt.shape[1] == 4: 57 | gt = np.column_stack((gt[:, 0], gt[:, 1], gt[:, 0], gt[:, 1] + gt[:, 3]-1, 58 | gt[:, 0] + gt[:, 2]-1, gt[:, 1] + gt[:, 3]-1, gt[:, 0] + gt[:, 2]-1, gt[:, 1])) 59 | info[video] = {'image_files': image_files, 'gt': gt, 'name': video} 60 | elif 'DAVIS' in dataset: 61 | base_path = join(realpath(dirname(__file__)), '../data', 'DAVIS') 62 | list_path = 
join(realpath(dirname(__file__)), '../data', 'DAVIS', 'ImageSets', dataset[-4:], 'val.txt') 63 | with open(list_path) as f: 64 | videos = [v.strip() for v in f.readlines()] 65 | for video in videos: 66 | info[video] = {} 67 | info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png'))) 68 | info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg'))) 69 | info[video]['name'] = video 70 | elif 'ytb_vos' in dataset: 71 | base_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid') 72 | json_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid', 'meta.json') 73 | meta = json.load(open(json_path, 'r')) 74 | meta = meta['videos'] 75 | info = dict() 76 | for v in meta.keys(): 77 | objects = meta[v]['objects'] 78 | frames = [] 79 | anno_frames = [] 80 | info[v] = dict() 81 | for obj in objects: 82 | frames += objects[obj]['frames'] 83 | anno_frames += [objects[obj]['frames'][0]] 84 | frames = sorted(np.unique(frames)) 85 | info[v]['anno_files'] = [join(base_path, 'Annotations', v, im_f+'.png') for im_f in frames] 86 | info[v]['anno_init_files'] = [join(base_path, 'Annotations', v, im_f + '.png') for im_f in anno_frames] 87 | info[v]['image_files'] = [join(base_path, 'JPEGImages', v, im_f+'.jpg') for im_f in frames] 88 | info[v]['name'] = v 89 | 90 | info[v]['start_frame'] = dict() 91 | info[v]['end_frame'] = dict() 92 | for obj in objects: 93 | start_file = objects[obj]['frames'][0] 94 | end_file = objects[obj]['frames'][-1] 95 | info[v]['start_frame'][obj] = frames.index(start_file) 96 | info[v]['end_frame'][obj] = frames.index(end_file) 97 | else: 98 | logging.error('Not support') 99 | exit() 100 | return info 101 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/config_helper.py: -------------------------------------------------------------------------------- 1 | # 
-------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | import json 7 | from os.path import exists 8 | 9 | 10 | def load_config(args): 11 | assert exists(args.config), '"{}" not exists'.format(args.config) 12 | config = json.load(open(args.config)) 13 | 14 | # deal with network 15 | if 'network' not in config: 16 | print('Warning: network lost in config. This will be error in next version') 17 | 18 | config['network'] = {} 19 | 20 | if not args.arch: 21 | raise Exception('no arch provided') 22 | args.arch = config['network']['arch'] 23 | 24 | return config 25 | 26 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/load_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | logger = logging.getLogger('global') 4 | 5 | 6 | def check_keys(model, pretrained_state_dict): 7 | ckpt_keys = set(pretrained_state_dict.keys()) 8 | model_keys = set(model.state_dict().keys()) 9 | used_pretrained_keys = model_keys & ckpt_keys 10 | unused_pretrained_keys = ckpt_keys - model_keys 11 | missing_keys = model_keys - ckpt_keys 12 | if len(missing_keys) > 0: 13 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 14 | logger.info('missing keys:{}'.format(len(missing_keys))) 15 | if len(unused_pretrained_keys) > 0: 16 | logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys)) 17 | logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) 18 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 19 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 20 | return True 21 | 22 | 23 | def remove_prefix(state_dict, prefix): 24 | ''' Old style model is stored with all names of parameters share 
common prefix 'module.' ''' 25 | logger.info('remove prefix \'{}\''.format(prefix)) 26 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 27 | return {f(key): value for key, value in state_dict.items()} 28 | 29 | 30 | def load_pretrain(model, pretrained_path): 31 | logger.info('load pretrained model from {}'.format(pretrained_path)) 32 | device = torch.cuda.current_device() 33 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) 34 | if "state_dict" in pretrained_dict.keys(): 35 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') 36 | else: 37 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 38 | 39 | try: 40 | check_keys(model, pretrained_dict) 41 | except: 42 | logger.info('[Warning]: using pretrain as features. Adding "features." as prefix') 43 | new_dict = {} 44 | for k, v in pretrained_dict.items(): 45 | k = 'features.' + k 46 | new_dict[k] = v 47 | pretrained_dict = new_dict 48 | check_keys(model, pretrained_dict) 49 | model.load_state_dict(pretrained_dict, strict=False) 50 | return model 51 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/log_helper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | from __future__ import division 7 | 8 | import os 9 | import logging 10 | import sys 11 | 12 | if hasattr(sys, 'frozen'): # support for py2exe 13 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:]) 14 | elif __file__[-4:].lower() in ['.pyc', '.pyo']: 15 | _srcfile = __file__[:-4] + '.py' 16 | else: 17 | _srcfile = __file__ 18 | _srcfile = os.path.normcase(_srcfile) 19 | 20 | 21 | logs = set() 22 | 23 | 24 | class Filter: 
25 | def __init__(self, flag): 26 | self.flag = flag 27 | 28 | def filter(self, x): return self.flag 29 | 30 | 31 | class Dummy: 32 | def __init__(self, *arg, **kwargs): 33 | pass 34 | 35 | def __getattr__(self, arg): 36 | def dummy(*args, **kwargs): pass 37 | return dummy 38 | 39 | 40 | def get_format(logger, level): 41 | if 'SLURM_PROCID' in os.environ: 42 | rank = int(os.environ['SLURM_PROCID']) 43 | 44 | if level == logging.INFO: 45 | logger.addFilter(Filter(rank == 0)) 46 | else: 47 | rank = 0 48 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank) 49 | formatter = logging.Formatter(format_str) 50 | return formatter 51 | 52 | 53 | def init_log(name, level = logging.INFO, format_func=get_format): 54 | if (name, level) in logs: return 55 | logs.add((name, level)) 56 | logger = logging.getLogger(name) 57 | logger.setLevel(level) 58 | ch = logging.StreamHandler() 59 | ch.setLevel(level) 60 | formatter = format_func(logger, level) 61 | ch.setFormatter(formatter) 62 | logger.addHandler(ch) 63 | return logger 64 | 65 | 66 | def add_file_handler(name, log_file, level = logging.INFO): 67 | logger = logging.getLogger(name) 68 | fh = logging.FileHandler(log_file) 69 | fh.setFormatter(get_format(logger, level)) 70 | logger.addHandler(fh) 71 | 72 | 73 | init_log('global') 74 | 75 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/__init__.py -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking 
Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from .vot import VOTDataset 10 | 11 | 12 | class DatasetFactory(object): 13 | @staticmethod 14 | def create_dataset(**kwargs): 15 | """ 16 | Args: 17 | name: dataset name 'VOT2018', 'VOT2016' 18 | dataset_root: dataset root 19 | Return: 20 | dataset 21 | """ 22 | assert 'name' in kwargs, "should provide dataset name" 23 | name = kwargs['name'] 24 | if 'VOT2018' == name or 'VOT2016' == name: 25 | dataset = VOTDataset(**kwargs) 26 | else: 27 | raise Exception("unknow dataset {}".format(kwargs['name'])) 28 | return dataset 29 | 30 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | class Dataset(object): 10 | def __init__(self, name, dataset_root): 11 | self.name = name 12 | self.dataset_root = dataset_root 13 | self.videos = None 14 | 15 | def __getitem__(self, idx): 16 | if isinstance(idx, str): 17 | return self.videos[idx] 18 | elif isinstance(idx, int): 19 | return self.videos[sorted(list(self.videos.keys()))[idx]] 20 | 21 | def __len__(self): 22 | return len(self.videos) 23 | 24 | def __iter__(self): 25 | keys = sorted(list(self.videos.keys())) 26 | for key in keys: 27 | yield 
self.videos[key] 28 | 29 | def set_tracker(self, path, tracker_names): 30 | """ 31 | Args: 32 | path: path to tracker results, 33 | tracker_names: list of tracker name 34 | """ 35 | self.tracker_path = path 36 | self.tracker_names = tracker_names 37 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/datasets/video.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | import os 10 | import cv2 11 | 12 | from glob import glob 13 | 14 | 15 | class Video(object): 16 | def __init__(self, name, root, video_dir, init_rect, img_names, 17 | gt_rect, attr): 18 | self.name = name 19 | self.video_dir = video_dir 20 | self.init_rect = init_rect 21 | self.gt_traj = gt_rect 22 | self.attr = attr 23 | self.pred_trajs = {} 24 | self.img_names = [os.path.join(root, x) for x in img_names] 25 | self.imgs = None 26 | 27 | def load_tracker(self, path, tracker_names=None, store=True): 28 | """ 29 | Args: 30 | path(str): path to result 31 | tracker_name(list): name of tracker 32 | """ 33 | if not tracker_names: 34 | tracker_names = [x.split('/')[-1] for x in glob(path) 35 | if os.path.isdir(x)] 36 | if isinstance(tracker_names, str): 37 | tracker_names = [tracker_names] 38 | for name in tracker_names: 39 | traj_file = os.path.join(path, name, self.name+'.txt') 40 | if os.path.exists(traj_file): 41 | with open(traj_file, 'r') as f : 42 | pred_traj = [list(map(float, x.strip().split(','))) 43 | for x in f.readlines()] 44 | if len(pred_traj) != len(self.gt_traj): 45 | print(name, 
len(pred_traj), len(self.gt_traj), self.name) 46 | if store: 47 | self.pred_trajs[name] = pred_traj 48 | else: 49 | return pred_traj 50 | else: 51 | print(traj_file) 52 | self.tracker_names = list(self.pred_trajs.keys()) 53 | 54 | def load_img(self): 55 | if self.imgs is None: 56 | self.imgs = [cv2.imread(x) 57 | for x in self.img_names] 58 | self.width = self.imgs[0].shape[1] 59 | self.height = self.imgs[0].shape[0] 60 | 61 | def free_img(self): 62 | self.imgs = None 63 | 64 | def __len__(self): 65 | return len(self.img_names) 66 | 67 | def __getitem__(self, idx): 68 | if self.imgs is None: 69 | return cv2.imread(self.img_names[idx]), \ 70 | self.gt_traj[idx] 71 | else: 72 | return self.imgs[idx], self.gt_traj[idx] 73 | 74 | def __iter__(self): 75 | for i in range(len(self.img_names)): 76 | if self.imgs is not None: 77 | yield self.imgs[i], self.gt_traj[i] 78 | else: 79 | yield cv2.imread(self.img_names[i]), \ 80 | self.gt_traj[i] 81 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from .ar_benchmark import AccuracyRobustnessBenchmark 10 | from .eao_benchmark import EAOBenchmark 11 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object 
Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from . import region 10 | from .statistics import * 11 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/region.o -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/c_region.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "src/region.h": 2 | ctypedef enum region_type "RegionType": 3 | EMTPY 4 | SPECIAL 5 | RECTANGEL 6 | POLYGON 7 | MASK 8 | 9 | ctypedef struct region_bounds: 10 | float top 11 | float bottom 12 | float left 13 | float right 14 | 15 | ctypedef struct region_rectangle: 16 | float x 17 | float y 18 | float width 19 | float height 20 | 21 | # ctypedef struct region_mask: 22 | # int x 23 | # int y 24 | # int width 25 | # int height 26 | # char *data 27 | 28 | ctypedef struct region_polygon: 29 | int count 30 | float *x 31 | float *y 32 | 33 | 
ctypedef union region_container_data: 34 | region_rectangle rectangle 35 | region_polygon polygon 36 | # region_mask mask 37 | int special 38 | 39 | ctypedef struct region_container: 40 | region_type type 41 | region_container_data data 42 | 43 | # ctypedef struct region_overlap: 44 | # float overlap 45 | # float only1 46 | # float only2 47 | 48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds) 49 | 50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds) 51 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | import numpy as np 10 | 11 | def determine_thresholds(confidence, resolution=100): 12 | """choose threshold according to confidence 13 | 14 | Args: 15 | confidence: list or numpy array or numpy array 16 | reolution: number of threshold to choose 17 | 18 | Restures: 19 | threshold: numpy array 20 | """ 21 | if isinstance(confidence, list): 22 | confidence = np.array(confidence) 23 | confidence = confidence.flatten() 24 | confidence = confidence[~np.isnan(confidence)] 25 | confidence.sort() 26 | 27 | assert len(confidence) > resolution and resolution > 2 28 | 29 | thresholds = np.ones((resolution)) 30 | thresholds[0] = - np.inf 31 | thresholds[-1] = np.inf 32 | delta = np.floor(len(confidence) / (resolution - 2)) 33 | idxs = np.linspace(delta, 
len(confidence)-delta, resolution-2, dtype=np.int32) 34 | thresholds[1:-1] = confidence[idxs] 35 | return thresholds 36 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/setup.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from distutils.core import setup 10 | from distutils.extension import Extension 11 | from Cython.Build import cythonize 12 | 13 | setup( 14 | ext_modules = cythonize([Extension("region", ["region.pyx", "src/region.c"])]), 15 | ) 16 | 17 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pysot/utils/src/region.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */ 2 | 3 | #ifndef _REGION_H_ 4 | #define _REGION_H_ 5 | 6 | #ifdef TRAX_STATIC_DEFINE 7 | # define __TRAX_EXPORT 8 | #else 9 | # ifndef __TRAX_EXPORT 10 | # if defined(_MSC_VER) 11 | # ifdef trax_EXPORTS 12 | /* We are building this library */ 13 | # define __TRAX_EXPORT __declspec(dllexport) 
14 | # else 15 | /* We are using this library */ 16 | # define __TRAX_EXPORT __declspec(dllimport) 17 | # endif 18 | # elif defined(__GNUC__) 19 | # ifdef trax_EXPORTS 20 | /* We are building this library */ 21 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 22 | # else 23 | /* We are using this library */ 24 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 25 | # endif 26 | # endif 27 | # endif 28 | #endif 29 | 30 | #ifndef MAX 31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b)) 32 | #endif 33 | 34 | #ifndef MIN 35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b)) 36 | #endif 37 | 38 | #define TRAX_DEFAULT_CODE 0 39 | 40 | #define REGION_LEGACY_RASTERIZATION 1 41 | 42 | #ifdef __cplusplus 43 | extern "C" { 44 | #endif 45 | 46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type; 47 | 48 | typedef struct region_bounds { 49 | 50 | float top; 51 | float bottom; 52 | float left; 53 | float right; 54 | 55 | } region_bounds; 56 | 57 | typedef struct region_polygon { 58 | 59 | int count; 60 | 61 | float* x; 62 | float* y; 63 | 64 | } region_polygon; 65 | 66 | typedef struct region_mask { 67 | 68 | int x; 69 | int y; 70 | 71 | int width; 72 | int height; 73 | 74 | char* data; 75 | 76 | } region_mask; 77 | 78 | typedef struct region_rectangle { 79 | 80 | float x; 81 | float y; 82 | float width; 83 | float height; 84 | 85 | } region_rectangle; 86 | 87 | typedef struct region_container { 88 | enum region_type type; 89 | union { 90 | region_rectangle rectangle; 91 | region_polygon polygon; 92 | region_mask mask; 93 | int special; 94 | } data; 95 | } region_container; 96 | 97 | typedef struct region_overlap { 98 | 99 | float overlap; 100 | float only1; 101 | float only2; 102 | 103 | } region_overlap; 104 | 105 | extern const region_bounds region_no_bounds; 106 | 107 | __TRAX_EXPORT int region_set_flags(int mask); 108 | 109 | __TRAX_EXPORT int region_clear_flags(int mask); 110 | 111 | __TRAX_EXPORT region_overlap 
region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds); 112 | 113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds); 114 | 115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom); 116 | 117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region); 118 | 119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region); 120 | 121 | __TRAX_EXPORT char* region_string(region_container* region); 122 | 123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region); 124 | 125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type); 126 | 127 | __TRAX_EXPORT void region_release(region_container** region); 128 | 129 | __TRAX_EXPORT region_container* region_create_special(int code); 130 | 131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height); 132 | 133 | __TRAX_EXPORT region_container* region_create_polygon(int count); 134 | 135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y); 136 | 137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height); 138 | 139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height); 140 | 141 | #ifdef __cplusplus 142 | } 143 | #endif 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author 
fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from . import region 10 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/region.o -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/c_region.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "region.h": 2 | ctypedef enum region_type "RegionType": 3 | EMTPY 4 | SPECIAL 5 | RECTANGEL 6 | POLYGON 7 | MASK 8 | 9 | ctypedef struct region_bounds: 10 | float top 11 | float bottom 12 | float left 13 | float right 14 | 15 | ctypedef struct region_rectangle: 16 | float x 17 | float y 18 | float width 19 
| float height 20 | 21 | # ctypedef struct region_mask: 22 | # int x 23 | # int y 24 | # int width 25 | # int height 26 | # char *data 27 | 28 | ctypedef struct region_polygon: 29 | int count 30 | float *x 31 | float *y 32 | 33 | ctypedef union region_container_data: 34 | region_rectangle rectangle 35 | region_polygon polygon 36 | # region_mask mask 37 | int special 38 | 39 | ctypedef struct region_container: 40 | region_type type 41 | region_container_data data 42 | 43 | # ctypedef struct region_overlap: 44 | # float overlap 45 | # float only1 46 | # float only2 47 | 48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds) 49 | 50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds) 51 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/evaluation/utils/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/setup.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from distutils.core import setup 10 | from distutils.extension import Extension 11 | from Cython.Build import 
cythonize 12 | 13 | setup( 14 | ext_modules = cythonize([Extension("region", ["region.pyx"])]) 15 | ) 16 | 17 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/pyvotkit/src/region.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */ 2 | 3 | #ifndef _REGION_H_ 4 | #define _REGION_H_ 5 | 6 | #ifdef TRAX_STATIC_DEFINE 7 | # define __TRAX_EXPORT 8 | #else 9 | # ifndef __TRAX_EXPORT 10 | # if defined(_MSC_VER) 11 | # ifdef trax_EXPORTS 12 | /* We are building this library */ 13 | # define __TRAX_EXPORT __declspec(dllexport) 14 | # else 15 | /* We are using this library */ 16 | # define __TRAX_EXPORT __declspec(dllimport) 17 | # endif 18 | # elif defined(__GNUC__) 19 | # ifdef trax_EXPORTS 20 | /* We are building this library */ 21 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 22 | # else 23 | /* We are using this library */ 24 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 25 | # endif 26 | # endif 27 | # endif 28 | #endif 29 | 30 | #ifndef MAX 31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b)) 32 | #endif 33 | 34 | #ifndef MIN 35 | #define MIN(a,b) (((a) < (b)) ? 
(a) : (b)) 36 | #endif 37 | 38 | #define TRAX_DEFAULT_CODE 0 39 | 40 | #define REGION_LEGACY_RASTERIZATION 1 41 | 42 | #ifdef __cplusplus 43 | extern "C" { 44 | #endif 45 | 46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type; 47 | 48 | typedef struct region_bounds { 49 | 50 | float top; 51 | float bottom; 52 | float left; 53 | float right; 54 | 55 | } region_bounds; 56 | 57 | typedef struct region_polygon { 58 | 59 | int count; 60 | 61 | float* x; 62 | float* y; 63 | 64 | } region_polygon; 65 | 66 | typedef struct region_mask { 67 | 68 | int x; 69 | int y; 70 | 71 | int width; 72 | int height; 73 | 74 | char* data; 75 | 76 | } region_mask; 77 | 78 | typedef struct region_rectangle { 79 | 80 | float x; 81 | float y; 82 | float width; 83 | float height; 84 | 85 | } region_rectangle; 86 | 87 | typedef struct region_container { 88 | enum region_type type; 89 | union { 90 | region_rectangle rectangle; 91 | region_polygon polygon; 92 | region_mask mask; 93 | int special; 94 | } data; 95 | } region_container; 96 | 97 | typedef struct region_overlap { 98 | 99 | float overlap; 100 | float only1; 101 | float only2; 102 | 103 | } region_overlap; 104 | 105 | extern const region_bounds region_no_bounds; 106 | 107 | __TRAX_EXPORT int region_set_flags(int mask); 108 | 109 | __TRAX_EXPORT int region_clear_flags(int mask); 110 | 111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds); 112 | 113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds); 114 | 115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom); 116 | 117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region); 118 | 119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region); 120 | 121 | __TRAX_EXPORT char* 
region_string(region_container* region); 122 | 123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region); 124 | 125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type); 126 | 127 | __TRAX_EXPORT void region_release(region_container** region); 128 | 129 | __TRAX_EXPORT region_container* region_create_special(int code); 130 | 131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height); 132 | 133 | __TRAX_EXPORT region_container* region_create_polygon(int count); 134 | 135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y); 136 | 137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height); 138 | 139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height); 140 | 141 | #ifdef __cplusplus 142 | } 143 | #endif 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /pytracking/evaluation/utils/tracker_config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | from __future__ import division 7 | from utils.anchors import Anchors 8 | 9 | 10 | class TrackerConfig(object): 11 | # These are the default hyper-params for SiamMask 12 | penalty_k = 0.04 13 | window_influence = 0.42 14 | lr = 0.25 15 | seg_thr = 0.3 # for mask 16 | windowing = 'cosine' # to penalize large displacements [cosine/uniform] 17 | # Params from the network architecture, have to be consistent with the training 18 | exemplar_size = 127 # input z size 19 | instance_size = 255 # input x size (search region) 20 | total_stride = 8 21 | out_size = 63 # for mask 22 | 
base_size = 8 23 | score_size = (instance_size-exemplar_size)//total_stride+1 + base_size 24 | context_amount = 0.5 # context amount for the exemplar 25 | ratios = [0.33, 0.5, 1, 2, 3] 26 | scales = [8, ] 27 | anchor_num = len(ratios) * len(scales) 28 | round_dight = 0 29 | anchor = [] 30 | 31 | def update(self, newparam=None, anchors=None): 32 | if newparam: 33 | for key, value in newparam.items(): 34 | setattr(self, key, value) 35 | if anchors is not None: 36 | if isinstance(anchors, dict): 37 | anchors = Anchors(anchors) 38 | if isinstance(anchors, Anchors): 39 | self.total_stride = anchors.stride 40 | self.ratios = anchors.ratios 41 | self.scales = anchors.scales 42 | self.round_dight = anchors.round_dight 43 | self.renew() 44 | 45 | def renew(self): 46 | self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1 + self.base_size 47 | self.anchor_num = len(self.ratios) * len(self.scales) 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /pytracking/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/experiments/__init__.py -------------------------------------------------------------------------------- /pytracking/experiments/myexperiments.py: -------------------------------------------------------------------------------- 1 | from pytracking.evaluation import Tracker, OTBDataset, NFSDataset, UAVDataset, TPLDataset, VOTDataset, TrackingNetDataset, LaSOTDataset 2 | 3 | 4 | def atom_nfs_uav(): 5 | # Run three runs of ATOM on NFS and UAV datasets 6 | trackers = [Tracker('atom', 'default', i) for i in range(3)] 7 | 8 | dataset = NFSDataset() + UAVDataset() 9 | return trackers, dataset 10 | 11 | 12 | def uav_test(): 13 | # Run ATOM and ECO on the UAV dataset 14 | trackers = [Tracker('atom', 'default', i) for i in range(1)] + \ 
15 | [Tracker('eco', 'default', i) for i in range(1)] 16 | 17 | dataset = UAVDataset() 18 | return trackers, dataset 19 | -------------------------------------------------------------------------------- /pytracking/features/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__init__.py -------------------------------------------------------------------------------- /pytracking/features/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/features/__pycache__/augmentation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/augmentation.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/features/__pycache__/deep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/deep.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/features/__pycache__/extractor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/extractor.cpython-36.pyc 
-------------------------------------------------------------------------------- /pytracking/features/__pycache__/featurebase.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/featurebase.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/features/__pycache__/preprocessing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/features/__pycache__/preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/features/color.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytracking.features.featurebase import FeatureBase 3 | 4 | 5 | class RGB(FeatureBase): 6 | """RGB feature normalized to [-0.5, 0.5].""" 7 | def dim(self): 8 | return 3 9 | 10 | def stride(self): 11 | return self.pool_stride 12 | 13 | def extract(self, im: torch.Tensor): 14 | return im/255 - 0.5 15 | 16 | 17 | class Grayscale(FeatureBase): 18 | """Grayscale feature normalized to [-0.5, 0.5].""" 19 | def dim(self): 20 | return 1 21 | 22 | def stride(self): 23 | return self.pool_stride 24 | 25 | def extract(self, im: torch.Tensor): 26 | return torch.mean(im/255 - 0.5, 1, keepdim=True) 27 | -------------------------------------------------------------------------------- /pytracking/features/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def numpy_to_torch(a: np.ndarray): 7 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0) 8 | 9 | 10 | def 
torch_to_numpy(a: torch.Tensor): 11 | return a.squeeze(0).permute(1,2,0).numpy() 12 | 13 | 14 | def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None): 15 | """Sample an image patch. 16 | 17 | args: 18 | im: Image 19 | pos: center position of crop 20 | sample_sz: size to crop 21 | output_sz: size to resize to 22 | """ 23 | 24 | # copy and convert 25 | posl = pos.long().clone() 26 | 27 | # Compute pre-downsampling factor 28 | if output_sz is not None: 29 | resize_factor = torch.min(sample_sz.float() / output_sz.float()).item() 30 | df = int(max(int(resize_factor - 0.1), 1)) 31 | else: 32 | df = int(1) 33 | 34 | sz = sample_sz.float() / df # new size 35 | 36 | # Do downsampling 37 | if df > 1: 38 | os = posl % df # offset 39 | posl = (posl - os) / df # new position 40 | im2 = im[..., os[0].item()::df, os[1].item()::df] # downsample 41 | else: 42 | im2 = im 43 | 44 | # compute size to crop 45 | szl = torch.max(sz.round(), torch.Tensor([2])).long() 46 | 47 | # Extract top and bottom coordinates 48 | tl = posl - (szl - 1)/2 49 | br = posl + szl/2 50 | 51 | # Get image patch 52 | im_patch = F.pad(im2, (-tl[1].item(), br[1].item() - im2.shape[3] + 1, -tl[0].item(), br[0].item() - im2.shape[2] + 1), 'replicate') 53 | 54 | if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]): 55 | return im_patch 56 | 57 | # Resample 58 | im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear') 59 | 60 | return im_patch 61 | -------------------------------------------------------------------------------- /pytracking/features/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytracking.features.featurebase import FeatureBase 3 | 4 | 5 | class Concatenate(FeatureBase): 6 | """A feature that concatenates other features. 7 | args: 8 | features: List of features to concatenate. 
9 | """ 10 | def __init__(self, features, pool_stride = None, normalize_power = None, use_for_color = True, use_for_gray = True): 11 | super(Concatenate, self).__init__(pool_stride, normalize_power, use_for_color, use_for_gray) 12 | self.features = features 13 | 14 | self.input_stride = self.features[0].stride() 15 | 16 | for feat in self.features: 17 | if self.input_stride != feat.stride(): 18 | raise ValueError('Strides for the features must be the same for a bultiresolution feature.') 19 | 20 | def dim(self): 21 | return sum([f.dim() for f in self.features]) 22 | 23 | def stride(self): 24 | return self.pool_stride * self.input_stride 25 | 26 | def extract(self, im: torch.Tensor): 27 | return torch.cat([f.get_feature(im) for f in self.features], 1) 28 | -------------------------------------------------------------------------------- /pytracking/libs/__init__.py: -------------------------------------------------------------------------------- 1 | from .tensorlist import TensorList 2 | from .tensordict import TensorDict -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/complex.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/complex.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/dcf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/dcf.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/fourier.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/fourier.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/operation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/operation.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/optimization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/optimization.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/libs/__pycache__/tensordict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/tensordict.cpython-36.pyc 
# ---------- /pytracking/libs/__pycache__/tensorlist.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/tensorlist.cpython-36.pyc
# ---------- /pytracking/libs/__pycache__/tensorlist.cpython-37.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/libs/__pycache__/tensorlist.cpython-37.pyc
# ---------- /pytracking/libs/operation.py ----------
import torch
import torch.nn.functional as F
from pytracking.libs.tensorlist import tensor_operation, TensorList


@tensor_operation
def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, stride=1, padding=0, dilation=1, groups=1, mode=None):
    """Standard conv2d with an optional named padding mode.
    Returns the input if weight=None.

    args:
        input, weight, bias, stride, padding, dilation, groups: Forwarded to F.conv2d.
        mode: None (use ``padding`` as given), 'same', 'valid' or 'full'. Can only be
              combined with the default ``padding=0``.
    raises:
        ValueError: If a non-zero padding is given together with a mode, or the mode is unknown.
    """

    if weight is None:
        return input

    # Optional output crop; only set for 'same' mode with an even-sized kernel.
    ind = None
    if mode is not None:
        if padding != 0:
            raise ValueError('Cannot input both padding and mode.')
        if mode == 'same':
            padding = (weight.shape[2]//2, weight.shape[3]//2)
            # For an even kernel size (stride 1, dilation 1), padding k//2 on both sides yields
            # one extra row/column; slice(-1) drops the last one so the output keeps the input size.
            if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0:
                ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None),
                       slice(-1) if weight.shape[3] % 2 == 0 else slice(None))
        elif mode == 'valid':
            padding = (0, 0)
        elif mode == 'full':
            padding = (weight.shape[2]-1, weight.shape[3]-1)
        else:
            raise ValueError('Unknown mode for padding.')

    out = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    if ind is None:
        return out
    return out[:,:,ind[0],ind[1]]


@tensor_operation
def conv1x1(input: torch.Tensor, weight: torch.Tensor):
    """Do a convolution with a 1x1 kernel weights. Implemented with matmul, which can be faster than using conv.
    Returns the input if weight=None.
    """

    if weight is None:
        return input

    # weight is viewed as (out_channels, in_channels) -- presumably it is (out, in, 1, 1) -- and
    # input as (N, in_channels, H*W); the broadcast matmul gives (N, out_channels, H*W), which is
    # viewed back to (N, out_channels, H, W).
    return torch.matmul(weight.view(weight.shape[0], weight.shape[1]),
                        input.view(input.shape[0], input.shape[1], -1)).view(input.shape[0], weight.shape[0], input.shape[2], input.shape[3])
# ---------- /pytracking/libs/tensordict.py ----------
from collections import OrderedDict
import torch


class TensorDict(OrderedDict):
    """Container mainly used for dicts of torch tensors. Extends OrderedDict with pytorch functionality.

    Accessing the name of any torch.Tensor method as an attribute returns a function that calls
    that method on every value that has it; values without it are passed through unchanged.
    """

    def concat(self, other):
        """Concatenates two dicts without copying internal data.

        Entries of ``other`` come after (and override) the entries of ``self``. Keys may be of
        any hashable type.
        """
        # update() instead of TensorDict(self, **other): keyword expansion only accepts str keys.
        result = TensorDict(self)
        result.update(other)
        return result

    def copy(self):
        """Shallow copy that keeps the TensorDict type."""
        return TensorDict(super(TensorDict, self).copy())

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; forwards torch.Tensor method names.
        if not hasattr(torch.Tensor, name):
            raise AttributeError('\'TensorDict\' object has no attribute \'{}\''.format(name))

        def apply_attr(*args, **kwargs):
            return TensorDict({n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e for n, e in self.items()})
        return apply_attr

    def attribute(self, attr: str, *args):
        """Return a TensorDict holding attribute ``attr`` of every value (``*args`` is getattr's default)."""
        return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()})

    def apply(self, fn, *args, **kwargs):
        """Return a TensorDict holding ``fn(value, *args, **kwargs)`` for every value."""
        return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()})

    @staticmethod
    def _iterable(a):
        return isinstance(a, (TensorDict, list))

# ---------- /pytracking/parameter/SPSTracker/__init__.py ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/SPSTracker/__init__.py
# ---------- /pytracking/parameter/SPSTracker/__pycache__/__init__.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/SPSTracker/__pycache__/__init__.cpython-36.pyc
# ---------- /pytracking/parameter/SPSTracker/__pycache__/default_vot.cpython-36.pyc ----------
https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/SPSTracker/__pycache__/default_vot.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/parameter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/__init__.py -------------------------------------------------------------------------------- /pytracking/parameter/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/parameter/atom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__init__.py -------------------------------------------------------------------------------- /pytracking/parameter/atom/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytracking/parameter/atom/__pycache__/default.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__pycache__/default.cpython-36.pyc 
# ---------- /pytracking/parameter/atom/__pycache__/default_vot.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/parameter/atom/__pycache__/default_vot.cpython-36.pyc
# ---------- /pytracking/run_experiment.py ----------
import os
import sys
import argparse
import importlib

env_path = os.path.join(os.path.dirname(__file__), '..')
if env_path not in sys.path:
    sys.path.append(env_path)

from pytracking.evaluation.running import run_dataset


def run_experiment(experiment_module: str, experiment_name: str, debug=0, threads=0):
    """Run experiment.
    args:
        experiment_module: Name of experiment module in the experiments/ folder.
        experiment_name: Name of the experiment function.
        debug: Debug level.
        threads: Number of threads.
    """
    # The experiment function takes no arguments and returns (trackers, dataset).
    expr_module = importlib.import_module('pytracking.experiments.{}'.format(experiment_module))
    expr_func = getattr(expr_module, experiment_name)
    trackers, dataset = expr_func()
    print('Running: {} {}'.format(experiment_module, experiment_name))
    run_dataset(dataset, trackers, debug, threads)


def main():
    parser = argparse.ArgumentParser(description='Run tracker.')
    parser.add_argument('experiment_module', type=str, help='Name of experiment module in the experiments/ folder.')
    parser.add_argument('experiment_name', type=str, help='Name of the experiment function.')
    parser.add_argument('--debug', type=int, default=0, help='Debug level.')
    parser.add_argument('--threads', type=int, default=0, help='Number of threads.')

    args = parser.parse_args()

    run_experiment(args.experiment_module, args.experiment_name, args.debug, args.threads)


if __name__ == '__main__':
    main()
# ---------- /pytracking/run_tracker.py ----------
import os
import sys
import argparse

env_path = os.path.join(os.path.dirname(__file__), '..')
if env_path not in sys.path:
    sys.path.append(env_path)

from pytracking.evaluation.otbdataset import OTBDataset
from pytracking.evaluation.nfsdataset import NFSDataset
from pytracking.evaluation.uavdataset import UAVDataset
from pytracking.evaluation.tpldataset import TPLDataset
from pytracking.evaluation.votdataset import VOTDataset
from pytracking.evaluation.lasotdataset import LaSOTDataset
from pytracking.evaluation.trackingnetdataset import TrackingNetDataset
from pytracking.evaluation.got10kdataset import GOT10KDatasetTest, GOT10KDatasetVal, GOT10KDatasetLTRVal
from pytracking.evaluation.running import run_dataset
from pytracking.evaluation import Tracker


# Short dataset name (as accepted by --dataset) -> zero-argument dataset factory.
_DATASET_FACTORIES = {
    'otb': OTBDataset,
    'nfs': NFSDataset,
    'uav': UAVDataset,
    'tpl': TPLDataset,
    'vot': VOTDataset,
    'tn': TrackingNetDataset,
    'gott': GOT10KDatasetTest,
    'gotv': GOT10KDatasetVal,
    'gotlv': GOT10KDatasetLTRVal,
    'lasot': LaSOTDataset,
}


def _parse_sequence(sequence):
    """Turn a numeric --sequence string into an int index; names and None are returned unchanged."""
    if sequence is None:
        return None
    try:
        return int(sequence)
    except ValueError:
        return sequence


def run_tracker(tracker_name, tracker_param, run_id=None, dataset_name='otb', sequence=None, debug=0, threads=0):
    """Run tracker on sequence or dataset.
    args:
        tracker_name: Name of tracking method.
        tracker_param: Name of parameter file.
        run_id: The run id.
        dataset_name: Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, gotlv, lasot).
        sequence: Sequence number (int) or name (str); None runs the whole dataset.
        debug: Debug level.
        threads: Number of threads.
    raises:
        ValueError: If dataset_name is not one of the names above.
    """
    try:
        dataset_factory = _DATASET_FACTORIES[dataset_name]
    except KeyError:
        raise ValueError('Unknown dataset name') from None
    dataset = dataset_factory()

    if sequence is not None:
        dataset = [dataset[sequence]]

    trackers = [Tracker(tracker_name, tracker_param, run_id)]

    run_dataset(dataset, trackers, debug, threads)


def main():
    parser = argparse.ArgumentParser(description='Run tracker on sequence or dataset.')
    parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
    parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
    parser.add_argument('--runid', type=int, default=None, help='The run id.')
    parser.add_argument('--dataset', type=str, default='otb', help='Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, gotlv, lasot).')
    parser.add_argument('--sequence', type=str, default=None, help='Sequence number or name.')
    parser.add_argument('--debug', type=int, default=0, help='Debug level.')
    parser.add_argument('--threads', type=int, default=0, help='Number of threads.')

    args = parser.parse_args()

    # argparse always yields a str; a numeric value is meant as a sequence index.
    sequence = _parse_sequence(args.sequence)

    run_tracker(args.tracker_name, args.tracker_param, args.runid, args.dataset, sequence, args.debug, args.threads)


if __name__ == '__main__':
    main()
# ---------- /pytracking/run_video.py ----------
import os
import sys
import argparse

env_path = os.path.join(os.path.dirname(__file__), '..')
if env_path not in sys.path:
    sys.path.append(env_path)

from pytracking.evaluation import Tracker


def _parse_box(box):
    """Parse an 'x,y,w,h' string into a list of four floats; non-string values pass through unchanged."""
    if not isinstance(box, str):
        return box
    values = [float(v) for v in box.split(',')]
    if len(values) != 4:
        raise ValueError('optional_box must have the format x,y,w,h, got {!r}'.format(box))
    return values


def run_video(tracker_name, tracker_param, videofile, optional_box=None, debug=None):
    """Run the tracker on a video file.
    args:
        tracker_name: Name of tracking method.
        tracker_param: Name of parameter file.
        videofile: Path to the video file.
        optional_box: Optional initial box [x, y, w, h], given as a sequence or an 'x,y,w,h' string.
        debug: Debug level.
    raises:
        ValueError: If optional_box is a string that is not four comma-separated numbers.
    """
    tracker = Tracker(tracker_name, tracker_param)
    tracker.run_video(videofilepath=videofile, optional_box=_parse_box(optional_box), debug=debug)

def main():
    parser = argparse.ArgumentParser(description='Run the tracker on a video file.')
    parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
    parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
    parser.add_argument('videofile', type=str, help='path to a video file.')
    parser.add_argument('--optional_box', default=None, help='optional_box with format x,y,w,h.')
    parser.add_argument('--debug', type=int, default=0, help='Debug level.')

    args = parser.parse_args()

    run_video(args.tracker_name, args.tracker_param, args.videofile, args.optional_box, args.debug)


if __name__ == '__main__':
    main()
# ---------- /pytracking/run_webcam.py ----------
import os
import sys
import argparse

env_path = os.path.join(os.path.dirname(__file__), '..')
if env_path not in sys.path:
    sys.path.append(env_path)

from pytracking.evaluation import Tracker


def run_webcam(tracker_name, tracker_param, debug=None):
    """Run the tracker on your webcam.
    args:
        tracker_name: Name of tracking method.
        tracker_param: Name of parameter file.
        debug: Debug level.
    """
    tracker = Tracker(tracker_name, tracker_param)
    tracker.run_webcam(debug)


def main():
    parser = argparse.ArgumentParser(description='Run the tracker on your webcam.')
    parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
    parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
    parser.add_argument('--debug', type=int, default=0, help='Debug level.')

    args = parser.parse_args()

    run_webcam(args.tracker_name, args.tracker_param, args.debug)


if __name__ == '__main__':
    main()
# ---------- /pytracking/tracker/SPSTracker/__init__.py ----------
from .SPSTracker import SPSTracker

def get_tracker_class():
    """Return the tracker class exported by this package (SPSTracker)."""
    return SPSTracker

# ---------- /pytracking/tracker/SPSTracker/__pycache__/SPSTracker.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/SPSTracker/__pycache__/SPSTracker.cpython-36.pyc
# ---------- /pytracking/tracker/SPSTracker/__pycache__/__init__.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/SPSTracker/__pycache__/__init__.cpython-36.pyc
# ---------- /pytracking/tracker/SPSTracker/__pycache__/optim.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/SPSTracker/__pycache__/optim.cpython-36.pyc
# ---------- /pytracking/tracker/__init__.py ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/__init__.py
# ---------- /pytracking/tracker/__pycache__/__init__.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/__pycache__/__init__.cpython-36.pyc
# ---------- /pytracking/tracker/atom/__init__.py ----------
from .atom import ATOM

def get_tracker_class():
    """Return the tracker class exported by this package (ATOM)."""
    return ATOM
# ---------- /pytracking/tracker/atom/__pycache__/__init__.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/atom/__pycache__/__init__.cpython-36.pyc
# ---------- /pytracking/tracker/atom/__pycache__/atom.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/atom/__pycache__/atom.cpython-36.pyc
# ---------- /pytracking/tracker/atom/__pycache__/optim.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/atom/__pycache__/optim.cpython-36.pyc
# ---------- /pytracking/tracker/atom/optim.py ----------
import torch
from pytracking import optimization, TensorList, operation
import math


class FactorizedConvProblem(optimization.L2Problem):
    """Residual formulation for jointly optimizing the filters and the 1x1 projection matrices.

    The optimization variable is a TensorList [filters, projection_matrices]: the first half
    holds the filters and the second half the projection matrices.
    """
    def __init__(self, training_samples: TensorList, y: TensorList, filter_reg: torch.Tensor, projection_reg, params, sample_weights: TensorList,
                 projection_activation, response_activation):
        self.training_samples = training_samples
        self.y = y
        # NOTE(review): annotated as torch.Tensor, but .concat() and .apply() are used on it below,
        # which are TensorList methods -- presumably a TensorList is passed; confirm with callers.
        self.filter_reg = filter_reg
        self.sample_weights = sample_weights
        self.params = params
        self.projection_reg = projection_reg
        self.projection_activation = projection_activation
        self.response_activation = response_activation

        # Diagonal preconditioner used by M1: filter regularization weights, then projection ones.
        self.diag_M = self.filter_reg.concat(projection_reg)

    def __call__(self, x: TensorList):
        """
        Compute residuals
        :param x: [filters, projection_matrices]
        :return: [data_terms, filter_regularizations, proj_mat_regularizations]
        """
        filters = x[:len(x)//2]  # w2 in paper
        P = x[len(x)//2:]        # w1 in paper

        # Do first convolution (1x1 projection of the training samples)
        compressed_samples = operation.conv1x1(self.training_samples, P).apply(self.projection_activation)

        # Do second convolution
        residuals = operation.conv2d(compressed_samples, filters, mode='same').apply(self.response_activation)

        # Compute data residuals
        residuals = residuals - self.y

        # Scale by sqrt(weight) so that, presumably squared by the L2 objective, each sample's
        # error is weighted by sample_weights.
        residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals

        # Add regularization for the filters
        residuals.extend(self.filter_reg.apply(math.sqrt) * filters)

        # Add regularization for projection matrix
        residuals.extend(self.projection_reg.apply(math.sqrt) * P)

        return residuals


    def ip_input(self, a: TensorList, b: TensorList):
        """Inner product in the variable space, computed separately for each filter."""
        num = len(a) // 2 # Number of filters
        a_filter = a[:num]
        b_filter = b[:num]
        a_P = a[num:]
        b_P = b[num:]

        # Filter inner product via an unpadded conv2d of equally sized tensors
        # (intended as the equivalent of a_filter.reshape(-1) @ b_filter.reshape(-1)).
        ip_out = operation.conv2d(a_filter, b_filter).view(-1)

        # Add projection matrix part
        # (intended as the equivalent of a_P.reshape(-1) @ b_P.reshape(-1)).
        ip_out += operation.conv2d(a_P.view(1,-1,1,1), b_P.view(1,-1,1,1)).view(-1)

        # Have independent inner products for each filter: duplicated so the projection half
        # of the variable gets the same value as its filter.
        return ip_out.concat(ip_out.clone())

    def M1(self, x: TensorList):
        """Apply the diagonal preconditioner (divide by the regularization weights)."""
        return x / self.diag_M


class ConvProblem(optimization.L2Problem):
    """Residual formulation for optimizing only the filters (no projection matrix)."""
    def __init__(self, training_samples: TensorList, y: TensorList, filter_reg: torch.Tensor, sample_weights: TensorList, response_activation):
        self.training_samples = training_samples
        self.y = y
        # NOTE(review): used via .apply(), a TensorList method -- presumably a TensorList.
        self.filter_reg = filter_reg
        self.sample_weights = sample_weights
        self.response_activation = response_activation

    def __call__(self, x: TensorList):
        """
        Compute residuals
        :param x: [filters]
        :return: [data_terms, filter_regularizations]
        """
        # Do convolution and compute residuals
        residuals = operation.conv2d(self.training_samples, x, mode='same').apply(self.response_activation)
        residuals = residuals - self.y

        # Scale by sqrt(weight) so the squared error is weighted by sample_weights.
        residuals = self.sample_weights.sqrt().view(-1, 1, 1, 1) * residuals

        # Add regularization for the filters
        residuals.extend(self.filter_reg.apply(math.sqrt) * x)

        return residuals

    def ip_input(self, a: TensorList, b: TensorList):
        """Inner product in the variable space via an unpadded conv2d of equally sized tensors
        (intended as the equivalent of (a * b).sum() per filter)."""
        return operation.conv2d(a, b).view(-1)

# ---------- /pytracking/tracker/base/__init__.py ----------
from .basetracker import BaseTracker
# ---------- /pytracking/tracker/base/__pycache__/__init__.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/base/__pycache__/__init__.cpython-36.pyc
# ---------- /pytracking/tracker/base/__pycache__/basetracker.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/tracker/base/__pycache__/basetracker.cpython-36.pyc
# ---------- /pytracking/util/__init__.py ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__init__.py
# ---------- /pytracking/util/__pycache__/__init__.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/__init__.cpython-36.pyc
# ---------- /pytracking/util/__pycache__/anchors.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/anchors.cpython-36.pyc
# ---------- /pytracking/util/__pycache__/bbox_helper.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/bbox_helper.cpython-36.pyc
# ---------- /pytracking/util/__pycache__/config_helper.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/config_helper.cpython-36.pyc
# ---------- /pytracking/util/__pycache__/load_helper.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/load_helper.cpython-36.pyc
# ---------- /pytracking/util/__pycache__/tracker_config.cpython-36.pyc ----------
# https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/__pycache__/tracker_config.cpython-36.pyc
# ---------- /pytracking/util/anchors.py ----------
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import numpy as np
import math

try:
    # Package-qualified path, consistent with the rest of pytracking.
    from pytracking.util.bbox_helper import center2corner, corner2center
except ImportError:
    # Original SiamMask layout, where the parent of util/ is on sys.path.
    from util.bbox_helper import center2corner, corner2center


class Anchors:
    """Anchor boxes for a single position and tiled over a square grid.

    The defaults set below are overridden by the entries of ``cfg`` (copied into the instance dict).
    """
    def __init__(self, cfg):
        self.stride = 8
        self.ratios = [0.33, 0.5, 1, 2, 3]
        self.scales = [8]
        self.round_dight = 0  # decimals for rounding anchor sizes; 0 truncates to int (name kept: it is a cfg key)
        self.image_center = 0
        self.size = 0

        self.__dict__.update(cfg)

        self.anchor_num = len(self.scales) * len(self.ratios)
        self.anchors = None  # in single position (anchor_num*4)
        self.all_anchors = None  # in all position 2*(4*anchor_num*h*w)
        self.generate_anchors()

    def generate_anchors(self):
        """Fill self.anchors with one origin-centered (x1, y1, x2, y2) box per (ratio, scale) pair."""
        self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)

        # Base area is stride**2; ws * hs keeps (approximately, after rounding) that area with hs / ws == r.
        size = self.stride * self.stride
        count = 0
        for r in self.ratios:
            if self.round_dight > 0:
                ws = round(math.sqrt(size*1. / r), self.round_dight)
                hs = round(ws * r, self.round_dight)
            else:
                ws = int(math.sqrt(size*1. / r))
                hs = int(ws * r)

            for s in self.scales:
                w = ws * s
                h = hs * s
                self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:]
                count += 1

    def generate_all_anchors(self, im_c, size):
        """Tile the anchors over a size x size grid (spacing self.stride) centered at im_c.

        Stores (corners, centers) in self.all_anchors, each a stack of 4 arrays of shape
        (anchor_num, size, size). Returns False without recomputing if the grid for
        (im_c, size) is already cached, True otherwise.
        """
        if self.image_center == im_c and self.size == size:
            return False
        self.image_center = im_c
        self.size = size

        # Coordinate of the top-left grid cell; the same offset is used for x and y.
        a0x = im_c - size // 2 * self.stride
        ori = np.array([a0x] * 4, dtype=np.float32)
        zero_anchors = self.anchors + ori

        x1 = zero_anchors[:, 0]
        y1 = zero_anchors[:, 1]
        x2 = zero_anchors[:, 2]
        y2 = zero_anchors[:, 3]

        x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), [x1, y1, x2, y2])
        cx, cy, w, h = corner2center([x1, y1, x2, y2])

        disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride
        disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride

        cx = cx + disp_x
        cy = cy + disp_y

        # broadcast
        zero = np.zeros((self.anchor_num, size, size), dtype=np.float32)
        cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h])
        x1, y1, x2, y2 = center2corner([cx, cy, w, h])

        self.all_anchors = np.stack([x1, y1, x2, y2]), np.stack([cx, cy, w, h])
        return True


# ---------- /pytracking/util/bbox_helper.py ----------
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import numpy as np
from collections import namedtuple

Corner = namedtuple('Corner', 'x1 y1 x2 y2')
BBox = Corner
Center = namedtuple('Center', 'x y w h')


def corner2center(corner):
    """
    Convert (x1, y1, x2, y2) corners to (cx, cy, w, h).
    :param corner: Corner or np.array 4*N
    :return: Center or 4 np.array N
    """
    if isinstance(corner, Corner):
        x1, y1, x2, y2 = corner
        return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1))
    else:
        x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3]
        x = (x1 + x2) * 0.5
        y = (y1 + y2) * 0.5
        w = x2 - x1
        h = y2 - y1
        return x, y, w, h


def center2corner(center):
    """
    Convert (cx, cy, w, h) to (x1, y1, x2, y2) corners.
    :param center: Center or np.array 4*N
    :return: Corner or np.array 4*N
    """
    if isinstance(center, Center):
        x, y, w, h = center
        return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5)
    else:
        x, y, w, h = center[0], center[1], center[2], center[3]
        x1 = x - w * 0.5
        y1 = y - h * 0.5
        x2 = x + w * 0.5
        y2 = y + h * 0.5
        return x1, y1, x2, y2


def cxy_wh_2_rect(pos, sz):
    """Convert a center position and (w, h) size to an [x, y, w, h] rectangle."""
    return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]])  # 0-index


def get_axis_aligned_bbox(region):
    """Return (cx, cy, w, h) for an 8-value polygon or a 4-value [x, y, w, h] region.

    For a polygon, w and h are the bounding-box size scaled so that the box area matches the
    area estimated from the polygon's first two edges, plus one.
    """
    nv = region.size
    if nv == 8:
        cx = np.mean(region[0::2])
        cy = np.mean(region[1::2])
        x1 = min(region[0::2])
        x2 = max(region[0::2])
        y1 = min(region[1::2])
        y2 = max(region[1::2])
        A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6])
        A2 = (x2 - x1) * (y2 - y1)
        s = np.sqrt(A1 / A2)
        w = s * (x2 - x1) + 1
        h = s * (y2 - y1) + 1
    else:
        x = region[0]
        y = region[1]
        w = region[2]
        h = region[3]
        cx = x+w/2
        cy = y+h/2

    return cx, cy, w, h


# ---------- /pytracking/util/benchmark_helper.py ----------
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
from os.path import join, realpath, dirname, exists, isdir
from os import listdir
import logging
import glob
import sys
import numpy as np
import json
from collections import OrderedDict


def get_dataset_zoo():
    """Return the names of folders under ../data that look like supported benchmarks.

    A folder qualifies if it contains list.txt, train/meta.json or ImageSets/2016/val.txt.
    Returns an empty list when ../data does not exist.
    """
    root = realpath(join(dirname(__file__), '../data'))
    # This runs at import time (dataset_zoo below), so a missing data folder must not raise.
    if not isdir(root):
        return []
    zoos = listdir(root)

    def valid(x):
        y = join(root, x)
        if not isdir(y): return False

        return exists(join(y, 'list.txt')) \
            or exists(join(y, 'train', 'meta.json'))\
            or exists(join(y, 'ImageSets', '2016', 'val.txt'))

    zoos = list(filter(valid, zoos))
    return zoos


dataset_zoo = get_dataset_zoo()


def load_dataset(dataset):
    """Return a mapping from video name to its file lists / annotations for a VOT*, DAVIS* or ytb_vos dataset.

    Logs an error and exits with status 1 if a VOT dataset is missing or the dataset is unsupported.
    """
    info = OrderedDict()
    if 'VOT' in dataset:
        # Same ../data root as get_dataset_zoo and the other branches (this used to be a hard-coded
        # absolute path to one developer's machine, which overrode the first join() argument).
        base_path = join(realpath(dirname(__file__)), '../data', dataset)
        if not exists(base_path):
            logging.error("Please download test dataset!!!")
            sys.exit(1)
        list_path = join(base_path, 'list.txt')
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            video_path = join(base_path, video)
            image_path = join(video_path, '*.jpg')
            image_files = sorted(glob.glob(image_path))
            if len(image_files) == 0:  # VOT2018
                image_path = join(video_path, 'color', '*.jpg')
                image_files = sorted(glob.glob(image_path))
            gt_path = join(video_path, 'groundtruth.txt')
            # ndmin=2 keeps a single-frame groundtruth 2-D so gt.shape[1] below is valid.
            gt = np.loadtxt(gt_path, delimiter=',', ndmin=2).astype(np.float64)
            if gt.shape[1] == 4:
                # Convert [x, y, w, h] rectangles to 8-value corner polygons.
                gt = np.column_stack((gt[:, 0], gt[:, 1], gt[:, 0], gt[:, 1] + gt[:, 3]-1,
                                      gt[:, 0] + gt[:, 2]-1, gt[:, 1] + gt[:, 3]-1, gt[:, 0] + gt[:, 2]-1, gt[:, 1]))
            info[video] = {'image_files': image_files, 'gt': gt, 'name': video}
    elif 'DAVIS' in dataset:
        base_path = join(realpath(dirname(__file__)), '../data', 'DAVIS')
        list_path = join(realpath(dirname(__file__)), '../data', 'DAVIS', 'ImageSets', dataset[-4:], 'val.txt')
        with open(list_path) as f:
            videos = [v.strip() for v in f.readlines()]
        for video in videos:
            info[video] = {}
            info[video]['anno_files'] = sorted(glob.glob(join(base_path, 'Annotations/480p', video, '*.png')))
            info[video]['image_files'] = sorted(glob.glob(join(base_path, 'JPEGImages/480p', video, '*.jpg')))
            info[video]['name'] = video
    elif 'ytb_vos' in dataset:
        base_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid')
        json_path = join(realpath(dirname(__file__)), '../data', 'ytb_vos', 'valid', 'meta.json')
        with open(json_path, 'r') as f:
            meta = json.load(f)
        meta = meta['videos']
        info = dict()
        for v in meta.keys():
            objects = meta[v]['objects']
            frames = []
            anno_frames = []
            info[v] = dict()
            for obj in objects:
                frames += objects[obj]['frames']
                anno_frames += [objects[obj]['frames'][0]]
            frames = sorted(np.unique(frames))
            info[v]['anno_files'] = [join(base_path, 'Annotations', v, im_f+'.png') for im_f in frames]
            info[v]['anno_init_files'] = [join(base_path, 'Annotations', v, im_f + '.png') for im_f in anno_frames]
            info[v]['image_files'] = [join(base_path, 'JPEGImages', v, im_f+'.jpg') for im_f in frames]
            info[v]['name'] = v

            info[v]['start_frame'] = dict()
            info[v]['end_frame'] = dict()
            for obj in objects:
                start_file = objects[obj]['frames'][0]
                end_file = objects[obj]['frames'][-1]
                info[v]['start_frame'][obj] = frames.index(start_file)
                info[v]['end_frame'][obj] = frames.index(end_file)
    else:
        logging.error('Not support')
        sys.exit(1)
    return info
# ---------- /pytracking/util/config_helper.py ----------
# --------------------------------------------------------
# SiamMask
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
#
-------------------------------------------------------- 6 | import json 7 | from os.path import exists 8 | 9 | 10 | def load_config(args): 11 | assert exists(args.config), '"{}" not exists'.format(args.config) 12 | config = json.load(open(args.config)) 13 | 14 | # deal with network 15 | if 'network' not in config: 16 | print('Warning: network lost in config. This will be error in next version') 17 | 18 | config['network'] = {} 19 | 20 | if not args.arch: 21 | raise Exception('no arch provided') 22 | args.arch = config['network']['arch'] 23 | 24 | return config 25 | 26 | -------------------------------------------------------------------------------- /pytracking/util/load_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | logger = logging.getLogger('global') 4 | 5 | 6 | def check_keys(model, pretrained_state_dict): 7 | ckpt_keys = set(pretrained_state_dict.keys()) 8 | model_keys = set(model.state_dict().keys()) 9 | used_pretrained_keys = model_keys & ckpt_keys 10 | unused_pretrained_keys = ckpt_keys - model_keys 11 | missing_keys = model_keys - ckpt_keys 12 | if len(missing_keys) > 0: 13 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 14 | logger.info('missing keys:{}'.format(len(missing_keys))) 15 | if len(unused_pretrained_keys) > 0: 16 | logger.info('[Warning] unused_pretrained_keys: {}'.format(unused_pretrained_keys)) 17 | logger.info('unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) 18 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 19 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 20 | return True 21 | 22 | 23 | def remove_prefix(state_dict, prefix): 24 | ''' Old style model is stored with all names of parameters share common prefix 'module.' 
''' 25 | logger.info('remove prefix \'{}\''.format(prefix)) 26 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 27 | return {f(key): value for key, value in state_dict.items()} 28 | 29 | 30 | def load_pretrain(model, pretrained_path): 31 | logger.info('load pretrained model from {}'.format(pretrained_path)) 32 | device = torch.cuda.current_device() 33 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) 34 | if "state_dict" in pretrained_dict.keys(): 35 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') 36 | else: 37 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 38 | 39 | try: 40 | check_keys(model, pretrained_dict) 41 | except: 42 | logger.info('[Warning]: using pretrain as features. Adding "features." as prefix') 43 | new_dict = {} 44 | for k, v in pretrained_dict.items(): 45 | k = 'features.' + k 46 | new_dict[k] = v 47 | pretrained_dict = new_dict 48 | check_keys(model, pretrained_dict) 49 | model.load_state_dict(pretrained_dict, strict=False) 50 | return model 51 | -------------------------------------------------------------------------------- /pytracking/util/log_helper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | from __future__ import division 7 | 8 | import os 9 | import logging 10 | import sys 11 | 12 | if hasattr(sys, 'frozen'): # support for py2exe 13 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:]) 14 | elif __file__[-4:].lower() in ['.pyc', '.pyo']: 15 | _srcfile = __file__[:-4] + '.py' 16 | else: 17 | _srcfile = __file__ 18 | _srcfile = os.path.normcase(_srcfile) 19 | 20 | 21 | logs = set() 22 | 23 | 24 | class Filter: 25 | def __init__(self, flag): 26 | 
self.flag = flag 27 | 28 | def filter(self, x): return self.flag 29 | 30 | 31 | class Dummy: 32 | def __init__(self, *arg, **kwargs): 33 | pass 34 | 35 | def __getattr__(self, arg): 36 | def dummy(*args, **kwargs): pass 37 | return dummy 38 | 39 | 40 | def get_format(logger, level): 41 | if 'SLURM_PROCID' in os.environ: 42 | rank = int(os.environ['SLURM_PROCID']) 43 | 44 | if level == logging.INFO: 45 | logger.addFilter(Filter(rank == 0)) 46 | else: 47 | rank = 0 48 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank) 49 | formatter = logging.Formatter(format_str) 50 | return formatter 51 | 52 | 53 | def init_log(name, level = logging.INFO, format_func=get_format): 54 | if (name, level) in logs: return 55 | logs.add((name, level)) 56 | logger = logging.getLogger(name) 57 | logger.setLevel(level) 58 | ch = logging.StreamHandler() 59 | ch.setLevel(level) 60 | formatter = format_func(logger, level) 61 | ch.setFormatter(formatter) 62 | logger.addHandler(ch) 63 | return logger 64 | 65 | 66 | def add_file_handler(name, log_file, level = logging.INFO): 67 | logger = logging.getLogger(name) 68 | fh = logging.FileHandler(log_file) 69 | fh.setFormatter(get_format(logger, level)) 70 | logger.addHandler(fh) 71 | 72 | 73 | init_log('global') 74 | 75 | -------------------------------------------------------------------------------- /pytracking/util/pysot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/__init__.py -------------------------------------------------------------------------------- /pytracking/util/pysot/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 
4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from .vot import VOTDataset 10 | 11 | 12 | class DatasetFactory(object): 13 | @staticmethod 14 | def create_dataset(**kwargs): 15 | """ 16 | Args: 17 | name: dataset name 'VOT2018', 'VOT2016' 18 | dataset_root: dataset root 19 | Return: 20 | dataset 21 | """ 22 | assert 'name' in kwargs, "should provide dataset name" 23 | name = kwargs['name'] 24 | if 'VOT2018' == name or 'VOT2016' == name: 25 | dataset = VOTDataset(**kwargs) 26 | else: 27 | raise Exception("unknow dataset {}".format(kwargs['name'])) 28 | return dataset 29 | 30 | -------------------------------------------------------------------------------- /pytracking/util/pysot/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | class Dataset(object): 10 | def __init__(self, name, dataset_root): 11 | self.name = name 12 | self.dataset_root = dataset_root 13 | self.videos = None 14 | 15 | def __getitem__(self, idx): 16 | if isinstance(idx, str): 17 | return self.videos[idx] 18 | elif isinstance(idx, int): 19 | return self.videos[sorted(list(self.videos.keys()))[idx]] 20 | 21 | def __len__(self): 22 | return len(self.videos) 23 | 24 | def __iter__(self): 25 | keys = sorted(list(self.videos.keys())) 26 | for key in keys: 27 | yield self.videos[key] 28 | 29 | def set_tracker(self, path, tracker_names): 30 | """ 31 | 
Args: 32 | path: path to tracker results, 33 | tracker_names: list of tracker name 34 | """ 35 | self.tracker_path = path 36 | self.tracker_names = tracker_names 37 | -------------------------------------------------------------------------------- /pytracking/util/pysot/datasets/video.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | import os 10 | import cv2 11 | 12 | from glob import glob 13 | 14 | 15 | class Video(object): 16 | def __init__(self, name, root, video_dir, init_rect, img_names, 17 | gt_rect, attr): 18 | self.name = name 19 | self.video_dir = video_dir 20 | self.init_rect = init_rect 21 | self.gt_traj = gt_rect 22 | self.attr = attr 23 | self.pred_trajs = {} 24 | self.img_names = [os.path.join(root, x) for x in img_names] 25 | self.imgs = None 26 | 27 | def load_tracker(self, path, tracker_names=None, store=True): 28 | """ 29 | Args: 30 | path(str): path to result 31 | tracker_name(list): name of tracker 32 | """ 33 | if not tracker_names: 34 | tracker_names = [x.split('/')[-1] for x in glob(path) 35 | if os.path.isdir(x)] 36 | if isinstance(tracker_names, str): 37 | tracker_names = [tracker_names] 38 | for name in tracker_names: 39 | traj_file = os.path.join(path, name, self.name+'.txt') 40 | if os.path.exists(traj_file): 41 | with open(traj_file, 'r') as f : 42 | pred_traj = [list(map(float, x.strip().split(','))) 43 | for x in f.readlines()] 44 | if len(pred_traj) != len(self.gt_traj): 45 | print(name, len(pred_traj), len(self.gt_traj), self.name) 46 | if store: 47 | self.pred_trajs[name] = pred_traj 48 
| else: 49 | return pred_traj 50 | else: 51 | print(traj_file) 52 | self.tracker_names = list(self.pred_trajs.keys()) 53 | 54 | def load_img(self): 55 | if self.imgs is None: 56 | self.imgs = [cv2.imread(x) 57 | for x in self.img_names] 58 | self.width = self.imgs[0].shape[1] 59 | self.height = self.imgs[0].shape[0] 60 | 61 | def free_img(self): 62 | self.imgs = None 63 | 64 | def __len__(self): 65 | return len(self.img_names) 66 | 67 | def __getitem__(self, idx): 68 | if self.imgs is None: 69 | return cv2.imread(self.img_names[idx]), \ 70 | self.gt_traj[idx] 71 | else: 72 | return self.imgs[idx], self.gt_traj[idx] 73 | 74 | def __iter__(self): 75 | for i in range(len(self.img_names)): 76 | if self.imgs is not None: 77 | yield self.imgs[i], self.gt_traj[i] 78 | else: 79 | yield cv2.imread(self.img_names[i]), \ 80 | self.gt_traj[i] 81 | -------------------------------------------------------------------------------- /pytracking/util/pysot/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from .ar_benchmark import AccuracyRobustnessBenchmark 10 | from .eao_benchmark import EAOBenchmark 11 | -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author 
fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from . import region 10 | from .statistics import * 11 | -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/region.o -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/utils/build/temp.linux-x86_64-3.6/src/region.o -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/c_region.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "src/region.h": 2 | ctypedef enum region_type "RegionType": 3 | EMTPY 4 | SPECIAL 5 | RECTANGEL 6 | POLYGON 7 | MASK 8 | 9 | ctypedef struct region_bounds: 10 | float top 11 | float bottom 12 | float left 13 | float right 14 | 15 | ctypedef struct region_rectangle: 16 | float x 17 | float y 18 | float width 19 | float height 20 | 21 | # ctypedef struct region_mask: 22 | # int x 23 | # int y 24 | # int width 25 | # int height 26 | # char *data 27 | 28 | ctypedef struct region_polygon: 29 | int count 30 | float *x 31 | float *y 32 | 33 | ctypedef union region_container_data: 34 | region_rectangle rectangle 35 | region_polygon polygon 36 | # region_mask mask 37 | int special 38 | 39 | ctypedef struct region_container: 40 
| region_type type 41 | region_container_data data 42 | 43 | # ctypedef struct region_overlap: 44 | # float overlap 45 | # float only1 46 | # float only2 47 | 48 | # region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds) 49 | 50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds) 51 | -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | import numpy as np 10 | 11 | def determine_thresholds(confidence, resolution=100): 12 | """choose threshold according to confidence 13 | 14 | Args: 15 | confidence: list or numpy array or numpy array 16 | reolution: number of threshold to choose 17 | 18 | Restures: 19 | threshold: numpy array 20 | """ 21 | if isinstance(confidence, list): 22 | confidence = np.array(confidence) 23 | confidence = confidence.flatten() 24 | confidence = confidence[~np.isnan(confidence)] 25 | confidence.sort() 26 | 27 | assert len(confidence) > resolution and resolution > 2 28 | 29 | thresholds = np.ones((resolution)) 30 | thresholds[0] = - np.inf 31 | thresholds[-1] = np.inf 32 | delta = np.floor(len(confidence) / (resolution - 2)) 33 | idxs = np.linspace(delta, len(confidence)-delta, resolution-2, dtype=np.int32) 34 | thresholds[1:-1] = confidence[idxs] 35 | return thresholds 36 | 
-------------------------------------------------------------------------------- /pytracking/util/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pysot/utils/region.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/setup.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from distutils.core import setup 10 | from distutils.extension import Extension 11 | from Cython.Build import cythonize 12 | 13 | setup( 14 | ext_modules = cythonize([Extension("region", ["region.pyx", "src/region.c"])]), 15 | ) 16 | 17 | -------------------------------------------------------------------------------- /pytracking/util/pysot/utils/src/region.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */ 2 | 3 | #ifndef _REGION_H_ 4 | #define _REGION_H_ 5 | 6 | #ifdef TRAX_STATIC_DEFINE 7 | # define __TRAX_EXPORT 8 | #else 9 | # ifndef __TRAX_EXPORT 10 | # if defined(_MSC_VER) 11 | # ifdef trax_EXPORTS 12 | /* We are building this library */ 13 | # define __TRAX_EXPORT __declspec(dllexport) 14 | # else 15 | /* We are using this library */ 16 | # define __TRAX_EXPORT __declspec(dllimport) 17 | # endif 18 | # elif defined(__GNUC__) 19 | # ifdef trax_EXPORTS 
20 | /* We are building this library */ 21 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 22 | # else 23 | /* We are using this library */ 24 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 25 | # endif 26 | # endif 27 | # endif 28 | #endif 29 | 30 | #ifndef MAX 31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b)) 32 | #endif 33 | 34 | #ifndef MIN 35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b)) 36 | #endif 37 | 38 | #define TRAX_DEFAULT_CODE 0 39 | 40 | #define REGION_LEGACY_RASTERIZATION 1 41 | 42 | #ifdef __cplusplus 43 | extern "C" { 44 | #endif 45 | 46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type; 47 | 48 | typedef struct region_bounds { 49 | 50 | float top; 51 | float bottom; 52 | float left; 53 | float right; 54 | 55 | } region_bounds; 56 | 57 | typedef struct region_polygon { 58 | 59 | int count; 60 | 61 | float* x; 62 | float* y; 63 | 64 | } region_polygon; 65 | 66 | typedef struct region_mask { 67 | 68 | int x; 69 | int y; 70 | 71 | int width; 72 | int height; 73 | 74 | char* data; 75 | 76 | } region_mask; 77 | 78 | typedef struct region_rectangle { 79 | 80 | float x; 81 | float y; 82 | float width; 83 | float height; 84 | 85 | } region_rectangle; 86 | 87 | typedef struct region_container { 88 | enum region_type type; 89 | union { 90 | region_rectangle rectangle; 91 | region_polygon polygon; 92 | region_mask mask; 93 | int special; 94 | } data; 95 | } region_container; 96 | 97 | typedef struct region_overlap { 98 | 99 | float overlap; 100 | float only1; 101 | float only2; 102 | 103 | } region_overlap; 104 | 105 | extern const region_bounds region_no_bounds; 106 | 107 | __TRAX_EXPORT int region_set_flags(int mask); 108 | 109 | __TRAX_EXPORT int region_clear_flags(int mask); 110 | 111 | __TRAX_EXPORT region_overlap region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds); 112 | 113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* 
p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds); 114 | 115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom); 116 | 117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region); 118 | 119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region); 120 | 121 | __TRAX_EXPORT char* region_string(region_container* region); 122 | 123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region); 124 | 125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type); 126 | 127 | __TRAX_EXPORT void region_release(region_container** region); 128 | 129 | __TRAX_EXPORT region_container* region_create_special(int code); 130 | 131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height); 132 | 133 | __TRAX_EXPORT region_container* region_create_polygon(int count); 134 | 135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y); 136 | 137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height); 138 | 139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height); 140 | 141 | #ifdef __cplusplus 142 | } 143 | #endif 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 
9 | from . import region 10 | -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/region.o -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pyvotkit/build/temp.linux-x86_64-3.6/src/region.o -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/c_region.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "region.h": 2 | ctypedef enum region_type "RegionType": 3 | EMTPY 4 | SPECIAL 5 | RECTANGEL 6 | POLYGON 7 | MASK 8 | 9 | ctypedef struct region_bounds: 10 | float top 11 | float bottom 12 | float left 13 | float right 14 | 15 | ctypedef struct region_rectangle: 16 | float x 17 | float y 18 | float width 19 | float height 20 | 21 | # ctypedef struct region_mask: 22 | # int x 23 | # int y 24 | # int width 25 | # int height 26 | # char *data 27 | 28 | ctypedef struct region_polygon: 29 | int count 30 | float *x 31 | float *y 32 | 33 | ctypedef union region_container_data: 34 | region_rectangle rectangle 35 | region_polygon polygon 36 | # region_mask mask 37 | int special 38 | 39 | ctypedef struct region_container: 40 | region_type type 41 | region_container_data data 42 | 43 | # ctypedef struct region_overlap: 44 | # float overlap 45 | # float only1 46 | # float only2 47 | 48 | # region_overlap region_compute_overlap(const region_container* ra, const 
region_container* rb, region_bounds bounds) 49 | 50 | float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds) 51 | -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrackerLB/SPSTracker/67857a3bee35c9f00b0830d40f893194426d60a1/pytracking/util/pyvotkit/region.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/setup.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Python Single Object Tracking Evaluation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Fangyi Zhang 5 | # @author fangyi.zhang@vipl.ict.ac.cn 6 | # @project https://github.com/StrangerZhang/pysot-toolkit.git 7 | # Revised for SiamMask by foolwood 8 | # -------------------------------------------------------- 9 | from distutils.core import setup 10 | from distutils.extension import Extension 11 | from Cython.Build import cythonize 12 | 13 | setup( 14 | ext_modules = cythonize([Extension("region", ["region.pyx"])]) 15 | ) 16 | 17 | -------------------------------------------------------------------------------- /pytracking/util/pyvotkit/src/region.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C; indent-tabs-mode: nil; c-basic-offset: 4; tab-width: 4 -*- */ 2 | 3 | #ifndef _REGION_H_ 4 | #define _REGION_H_ 5 | 6 | #ifdef TRAX_STATIC_DEFINE 7 | # define __TRAX_EXPORT 8 | #else 9 | # ifndef __TRAX_EXPORT 10 | # if defined(_MSC_VER) 11 | # ifdef trax_EXPORTS 12 | /* We are building this library */ 13 | # define __TRAX_EXPORT __declspec(dllexport) 14 | # 
else 15 | /* We are using this library */ 16 | # define __TRAX_EXPORT __declspec(dllimport) 17 | # endif 18 | # elif defined(__GNUC__) 19 | # ifdef trax_EXPORTS 20 | /* We are building this library */ 21 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 22 | # else 23 | /* We are using this library */ 24 | # define __TRAX_EXPORT __attribute__((visibility("default"))) 25 | # endif 26 | # endif 27 | # endif 28 | #endif 29 | 30 | #ifndef MAX 31 | #define MAX(a,b) (((a) > (b)) ? (a) : (b)) 32 | #endif 33 | 34 | #ifndef MIN 35 | #define MIN(a,b) (((a) < (b)) ? (a) : (b)) 36 | #endif 37 | 38 | #define TRAX_DEFAULT_CODE 0 39 | 40 | #define REGION_LEGACY_RASTERIZATION 1 41 | 42 | #ifdef __cplusplus 43 | extern "C" { 44 | #endif 45 | 46 | typedef enum region_type {EMPTY, SPECIAL, RECTANGLE, POLYGON, MASK} region_type; 47 | 48 | typedef struct region_bounds { 49 | 50 | float top; 51 | float bottom; 52 | float left; 53 | float right; 54 | 55 | } region_bounds; 56 | 57 | typedef struct region_polygon { 58 | 59 | int count; 60 | 61 | float* x; 62 | float* y; 63 | 64 | } region_polygon; 65 | 66 | typedef struct region_mask { 67 | 68 | int x; 69 | int y; 70 | 71 | int width; 72 | int height; 73 | 74 | char* data; 75 | 76 | } region_mask; 77 | 78 | typedef struct region_rectangle { 79 | 80 | float x; 81 | float y; 82 | float width; 83 | float height; 84 | 85 | } region_rectangle; 86 | 87 | typedef struct region_container { 88 | enum region_type type; 89 | union { 90 | region_rectangle rectangle; 91 | region_polygon polygon; 92 | region_mask mask; 93 | int special; 94 | } data; 95 | } region_container; 96 | 97 | typedef struct region_overlap { 98 | 99 | float overlap; 100 | float only1; 101 | float only2; 102 | 103 | } region_overlap; 104 | 105 | extern const region_bounds region_no_bounds; 106 | 107 | __TRAX_EXPORT int region_set_flags(int mask); 108 | 109 | __TRAX_EXPORT int region_clear_flags(int mask); 110 | 111 | __TRAX_EXPORT region_overlap 
region_compute_overlap(const region_container* ra, const region_container* rb, region_bounds bounds); 112 | 113 | __TRAX_EXPORT float compute_polygon_overlap(const region_polygon* p1, const region_polygon* p2, float *only1, float *only2, region_bounds bounds); 114 | 115 | __TRAX_EXPORT region_bounds region_create_bounds(float left, float top, float right, float bottom); 116 | 117 | __TRAX_EXPORT region_bounds region_compute_bounds(const region_container* region); 118 | 119 | __TRAX_EXPORT int region_parse(const char* buffer, region_container** region); 120 | 121 | __TRAX_EXPORT char* region_string(region_container* region); 122 | 123 | __TRAX_EXPORT void region_print(FILE* out, region_container* region); 124 | 125 | __TRAX_EXPORT region_container* region_convert(const region_container* region, region_type type); 126 | 127 | __TRAX_EXPORT void region_release(region_container** region); 128 | 129 | __TRAX_EXPORT region_container* region_create_special(int code); 130 | 131 | __TRAX_EXPORT region_container* region_create_rectangle(float x, float y, float width, float height); 132 | 133 | __TRAX_EXPORT region_container* region_create_polygon(int count); 134 | 135 | __TRAX_EXPORT int region_contains_point(region_container* r, float x, float y); 136 | 137 | __TRAX_EXPORT void region_get_mask(region_container* r, char* mask, int width, int height); 138 | 139 | __TRAX_EXPORT void region_get_mask_offset(region_container* r, char* mask, int x, int y, int width, int height); 140 | 141 | #ifdef __cplusplus 142 | } 143 | #endif 144 | 145 | #endif 146 | -------------------------------------------------------------------------------- /pytracking/util/tracker_config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SiamMask 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 
from __future__ import division
from util.anchors import Anchors


class TrackerConfig(object):
    """SiamMask tracker configuration.

    Class attributes hold the default hyper-parameters. ``update`` overrides
    them on an instance and ``renew`` recomputes the derived values
    (``score_size`` and ``anchor_num``).
    """

    # ---- Tracking hyper-parameters (defaults for SiamMask) ----
    penalty_k = 0.04
    window_influence = 0.42
    lr = 0.25
    seg_thr = 0.3  # threshold used for the mask
    windowing = 'cosine'  # penalises large displacements: 'cosine' or 'uniform'

    # ---- Network-architecture parameters; must stay consistent with training ----
    exemplar_size = 127  # input z size
    instance_size = 255  # input x size (search region)
    total_stride = 8
    out_size = 63  # mask output size
    base_size = 8
    score_size = (instance_size - exemplar_size) // total_stride + 1 + base_size
    context_amount = 0.5  # context amount for the exemplar
    ratios = [0.33, 0.5, 1, 2, 3]
    scales = [8, ]
    anchor_num = len(ratios) * len(scales)
    round_dight = 0  # spelling kept: mirrors Anchors.round_dight, copied in update()
    anchor = []

    def update(self, newparam=None, anchors=None):
        """Override settings on this instance, then refresh derived values.

        Args:
            newparam: optional mapping; every key/value pair is set as an
                attribute on this instance.
            anchors: optional ``Anchors`` instance, or a dict that is first
                passed to ``Anchors(...)``. Its ``stride``, ``ratios``,
                ``scales`` and ``round_dight`` replace the corresponding
                settings. Values of any other type are ignored.
        """
        if newparam:
            for attr_name, attr_value in newparam.items():
                setattr(self, attr_name, attr_value)

        if anchors is not None:
            anchor_cfg = Anchors(anchors) if isinstance(anchors, dict) else anchors
            if isinstance(anchor_cfg, Anchors):
                self.total_stride = anchor_cfg.stride
                self.ratios = anchor_cfg.ratios
                self.scales = anchor_cfg.scales
                self.round_dight = anchor_cfg.round_dight

        # Always re-derive, since newparam may have changed sizes or stride.
        self.renew()

    def renew(self):
        """Recompute ``score_size`` and ``anchor_num`` from the current settings."""
        crop_gap = self.instance_size - self.exemplar_size
        self.score_size = crop_gap // self.total_stride + 1 + self.base_size
        self.anchor_num = len(self.ratios) * len(self.scales)


# ---- pytracking/utils/__init__.py ----
# from .evaluation import *
from .params import *
#!/bin/bash

# The script taken from https://www.matthuisman.nz/2019/01/download-google-drive-files-wget-curl.html
#
# Usage: gdrive_download <url-or-file-id> [output-filename]
#   $1  Google Drive URL or bare file ID (required).
#   $2  Output filename (optional). When empty, the name is taken from the
#       response headers' filename="..." field, falling back to
#       "google_drive.file".

url=$1
filename=$2

[ -z "$url" ] && echo A URL or ID is required first argument && exit 1

# Try, in order: .../file/d/<id>/..., ...id=<id>, and finally the whole argument.
fileid=""
declare -a patterns=("s/.*\/file\/d\/\(.*\)\/.*/\1/p" "s/.*id\=\(.*\)/\1/p" "s/\(.*\)/\1/p")
for i in "${patterns[@]}"
do
    # Quote both expansions: the sed programs contain '*' and the URL may
    # contain '?', both of which would otherwise be subject to globbing.
    fileid=$(echo "$url" | sed -n "$i")
    [ -n "$fileid" ] && break
done

[ -z "$fileid" ] && echo Could not find Google ID && exit 1

echo File ID: "$fileid"

tmp_file="$filename.$$.file"
tmp_cookies="$filename.$$.cookies"
tmp_headers="$filename.$$.headers"

url='https://docs.google.com/uc?export=download&id='$fileid
echo Downloading: "$url > $tmp_file"
wget --save-cookies "$tmp_cookies" -q -S -O - "$url" 2> "$tmp_headers" 1> "$tmp_file"

# A small first response (<= 10000 bytes) is scanned for a confirm token.
# Initialise explicitly so an inherited environment variable is never used.
confirm=""
if [[ ! $(find "$tmp_file" -type f -size +10000c 2>/dev/null) ]]; then
    confirm=$(sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1/p' "$tmp_file")
fi

if [ -n "$confirm" ]; then
    url='https://docs.google.com/uc?export=download&id='$fileid'&confirm='$confirm
    echo Downloading: "$url > $tmp_file"
    wget --load-cookies "$tmp_cookies" -q -S -O - "$url" 2> "$tmp_headers" 1> "$tmp_file"
fi

[ -z "$filename" ] && filename=$(sed -rn 's/.*filename=\"(.*)\".*/\1/p' "$tmp_headers")
[ -z "$filename" ] && filename="google_drive.file"

echo Moving: "$tmp_file > $filename"

mv "$tmp_file" "$filename"

rm -f "$tmp_cookies" "$tmp_headers"

echo Saved: "$filename"
echo DONE!

exit 0
# ---- pytracking/utils/params.py ----
from pytracking import TensorList
import random


class TrackerParams:
    """Class for tracker parameters."""
    def free_memory(self):
        """Call ``free_memory()`` on every non-dunder attribute that provides it."""
        for name in dir(self):
            if name.startswith('__'):
                continue
            attr = getattr(self, name)  # look the attribute up only once
            if hasattr(attr, 'free_memory'):
                attr.free_memory()


class FeatureParams:
    """Class for feature specific parameters.

    Only keyword arguments are accepted; each one becomes an attribute.
    List values are wrapped in a ``TensorList``; other values are stored as-is.

    Raises:
        ValueError: if any positional argument is given.
    """
    def __init__(self, *args, **kwargs):
        if len(args) > 0:
            raise ValueError(
                f'FeatureParams accepts keyword arguments only, got {len(args)} positional')

        for name, val in kwargs.items():
            setattr(self, name, TensorList(val) if isinstance(val, list) else val)


def Choice(*args):
    """Can be used to sample random parameter values.

    Returns one of the positional arguments picked with ``random.choice``;
    raises ``IndexError`` when called without arguments.
    """
    return random.choice(args)


# ---- pytracking/utils/plotting.py ----
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
import torch


def _tensor_to_numpy(a: torch.Tensor) -> np.ndarray:
    """Squeeze *a* and return an independent CPU numpy copy of it."""
    # detach() first so the copy is not recorded by autograd; clone() keeps the
    # result from aliasing the caller's tensor when it already lives on the CPU.
    return a.squeeze().detach().cpu().clone().numpy()


def show_tensor(a: torch.Tensor, fig_num=None, title=None):
    """Display a 2D tensor (or a 3D tensor) as an image.

    args:
        a: Tensor to display. Singleton dimensions are squeezed away; if three
           dimensions remain, the first axis is moved last before ``imshow``
           (presumably channels-first input).
        fig_num: Figure number.
        title: Title of figure.
    """
    a_np = _tensor_to_numpy(a)
    if a_np.ndim == 3:
        a_np = np.transpose(a_np, (1, 2, 0))
    plt.figure(fig_num)
    plt.tight_layout()
    plt.cla()
    plt.imshow(a_np)
    plt.axis('off')
    plt.axis('equal')
    if title is not None:
        plt.title(title)
    plt.draw()
    plt.pause(0.001)  # give the GUI event loop a chance to render


def plot_graph(a: torch.Tensor, fig_num=None, title=None):
    """Plot graph. Data is a 1D tensor.

    args:
        a: Tensor to plot; singleton dimensions are squeezed away first.
        fig_num: Figure number.
        title: Title of figure.

    raises:
        ValueError: if more than one dimension remains after squeezing.
    """
    a_np = _tensor_to_numpy(a)
    if a_np.ndim > 1:
        raise ValueError(
            f'plot_graph expects 1D data, got shape {tuple(a_np.shape)}')
    plt.figure(fig_num)
    plt.cla()
    plt.plot(a_np)
    if title is not None:
        plt.title(title)
    plt.draw()
    plt.pause(0.001)  # give the GUI event loop a chance to render