├── lib
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── .DS_Store
│   │   └── ostrack
│   │       ├── .DS_Store
│   │       └── config.py
│   ├── test
│   │   ├── __init__.py
│   │   ├── tracker
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── ostrack.cpython-38.pyc
│   │   │   │   ├── vis_utils.cpython-38.pyc
│   │   │   │   ├── basetracker.cpython-38.pyc
│   │   │   │   └── data_utils.cpython-38.pyc
│   │   │   ├── vis_utils.py
│   │   │   ├── data_utils.py
│   │   │   └── basetracker.py
│   │   ├── analysis
│   │   │   ├── __init__.py
│   │   │   └── __pycache__
│   │   │       ├── __init__.cpython-37.pyc
│   │   │       ├── __init__.cpython-38.pyc
│   │   │       ├── plot_results.cpython-37.pyc
│   │   │       ├── plot_results.cpython-38.pyc
│   │   │       ├── extract_results.cpython-37.pyc
│   │   │       └── extract_results.cpython-38.pyc
│   │   ├── parameter
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── ostrack.cpython-38.pyc
│   │   │   │   └── __init__.cpython-38.pyc
│   │   │   └── ostrack.py
│   │   ├── .DS_Store
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── hann.cpython-38.pyc
│   │   │   │   ├── params.cpython-37.pyc
│   │   │   │   ├── params.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── _init_paths.cpython-38.pyc
│   │   │   │   ├── load_text.cpython-37.pyc
│   │   │   │   └── load_text.cpython-38.pyc
│   │   │   ├── _init_paths.py
│   │   │   ├── params.py
│   │   │   ├── transform_trackingnet.py
│   │   │   ├── load_text.py
│   │   │   ├── transform_got10k.py
│   │   │   └── hann.py
│   │   └── evaluation
│   │       ├── .DS_Store
│   │       ├── __init__.py
│   │       ├── tc128dataset.py
│   │       ├── tnl2kdataset.py
│   │       ├── tc128cedataset.py
│   │       ├── got10kdataset.py
│   │       ├── datasets.py
│   │       ├── trackingnetdataset.py
│   │       ├── itbdataset.py
│   │       └── environment.py
│   ├── vis
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── utils.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── plotting.cpython-38.pyc
│   │   │   └── visdom_cus.cpython-38.pyc
│   │   ├── utils.py
│   │   └── plotting.py
│   ├── models
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── .DS_Store
│   │   │   ├── frozen_bn.py
│   │   │   ├── rpe.py
│   │   │   ├── attn.py
│   │   │   └── attn_blocks.py
│   │   ├── __init__.py
│   │   ├── ostrack
│   │   │   ├── __init__.py
│   │   │   ├── .DS_Store
│   │   │   └── utils.py
│   │   └── .DS_Store
│   ├── train
│   │   ├── __init__.py
│   │   ├── .DS_Store
│   │   ├── admin
│   │   │   ├── .DS_Store
│   │   │   ├── __init__.py
│   │   │   ├── settings.py
│   │   │   ├── multigpu.py
│   │   │   ├── tensorboard.py
│   │   │   ├── stats.py
│   │   │   └── environment.py
│   │   ├── data
│   │   │   ├── .DS_Store
│   │   │   ├── __pycache__
│   │   │   │   ├── loader.cpython-37.pyc
│   │   │   │   ├── loader.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── sampler.cpython-38.pyc
│   │   │   │   ├── processing.cpython-38.pyc
│   │   │   │   ├── transforms.cpython-38.pyc
│   │   │   │   ├── image_loader.cpython-37.pyc
│   │   │   │   ├── image_loader.cpython-38.pyc
│   │   │   │   ├── sampler_datr.cpython-38.pyc
│   │   │   │   ├── sampler_imix.cpython-38.pyc
│   │   │   │   ├── wandb_logger.cpython-38.pyc
│   │   │   │   ├── processing_datr.cpython-38.pyc
│   │   │   │   ├── processing_imix.cpython-38.pyc
│   │   │   │   └── processing_utils.cpython-38.pyc
│   │   │   ├── __init__.py
│   │   │   ├── wandb_logger.py
│   │   │   ├── bounding_box_utils.py
│   │   │   └── image_loader.py
│   │   ├── actors
│   │   │   ├── .DS_Store
│   │   │   ├── __init__.py
│   │   │   └── base_actor.py
│   │   ├── trainers
│   │   │   ├── .DS_Store
│   │   │   └── __init__.py
│   │   ├── dataset
│   │   │   ├── __pycache__
│   │   │   │   ├── coco.cpython-38.pyc
│   │   │   │   ├── lasot.cpython-38.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── coco_seq.cpython-38.pyc
│   │   │   │   ├── got10k.cpython-38.pyc
│   │   │   │   ├── COCO_tool.cpython-38.pyc
│   │   │   │   ├── lasot_lmdb.cpython-38.pyc
│   │   │   │   ├── coco_seq_lmdb.cpython-38.pyc
│   │   │   │   ├── got10k_lmdb.cpython-38.pyc
│   │   │   │   ├── imagenetvid.cpython-38.pyc
│   │   │   │   ├── tracking_net.cpython-38.pyc
│   │   │   │   ├── base_image_dataset.cpython-38.pyc
│   │   │   │   ├── base_video_dataset.cpython-38.pyc
│   │   │   │   ├── imagenetvid_lmdb.cpython-38.pyc
│   │   │   │   └── tracking_net_lmdb.cpython-38.pyc
│   │   │   ├── __init__.py
│   │   │   ├── base_image_dataset.py
│   │   │   ├── base_video_dataset.py
│   │   │   └── imagenetvid_lmdb.py
│   │   ├── _init_paths.py
│   │   ├── data_specs
│   │   │   └── README.md
│   │   ├── train_script.py
│   │   ├── train_script_datr.py
│   │   ├── train_script_distill.py
│   │   └── run_training.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── merge.cpython-38.pyc
│   │   │   ├── misc.cpython-38.pyc
│   │   │   ├── tensor.cpython-37.pyc
│   │   │   ├── tensor.cpython-38.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── box_ops.cpython-38.pyc
│   │   │   ├── ce_utils.cpython-38.pyc
│   │   │   ├── focal_loss.cpython-38.pyc
│   │   │   ├── lmdb_utils.cpython-37.pyc
│   │   │   ├── lmdb_utils.cpython-38.pyc
│   │   │   └── heapmap_utils.cpython-38.pyc
│   │   ├── merge.py
│   │   ├── variable_hook.py
│   │   ├── lmdb_utils.py
│   │   ├── focal_loss.py
│   │   ├── box_ops.py
│   │   ├── ce_utils.py
│   │   └── heapmap_utils.py
│   └── .DS_Store
├── .idea
│   ├── .name
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── DATr-code.iml
│   └── deployment.xml
├── .DS_Store
├── framework.png
├── results.png
├── experiments.png
├── experiments
│   ├── .DS_Store
│   └── ostrack
│       ├── .DS_Store
│       ├── datr_vitb_384_mae_ce_32x4_ep300.yaml
│       └── datr_vitb_256_mae_ce_32x4_ep300.yaml
├── tracking
│   ├── __pycache__
│   │   ├── test.cpython-38.pyc
│   │   ├── _init_paths.cpython-37.pyc
│   │   └── _init_paths.cpython-38.pyc
│   ├── _init_paths.py
│   ├── convert_transt.py
│   ├── analysis_results_ITP.py
│   ├── create_default_local_file.py
│   ├── analysis_results.py
│   ├── video_demo.py
│   ├── test.py
│   ├── test_exp.py
│   ├── pre_read_datasets.py
│   ├── analysis_results.ipynb
│   ├── train.py
│   └── profile_model.py
├── eval.sh
├── LICENSE
├── README.md
├── install.sh
└── ostrack_cuda113_env.yaml
/lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.idea/.name:
--------------------------------------------------------------------------------
1 | README.md
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/vis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/test/tracker/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/test/analysis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/test/parameter/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .admin.multigpu import MultiGPU
2 |
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/.DS_Store
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .tensor import TensorDict, TensorList
2 |
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/framework.png
--------------------------------------------------------------------------------
/lib/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/.DS_Store
--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ostrack.ostrack import build_ostrack
2 |
--------------------------------------------------------------------------------
/lib/models/ostrack/__init__.py:
--------------------------------------------------------------------------------
1 | from .ostrack import build_ostrack
2 |
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/results.png
--------------------------------------------------------------------------------
/experiments.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/experiments.png
--------------------------------------------------------------------------------
/lib/test/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/.DS_Store
--------------------------------------------------------------------------------
/lib/test/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .params import TrackerParams, FeatureParams, Choice
--------------------------------------------------------------------------------
/lib/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/config/.DS_Store
--------------------------------------------------------------------------------
/lib/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/models/.DS_Store
--------------------------------------------------------------------------------
/lib/train/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/.DS_Store
--------------------------------------------------------------------------------
/experiments/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/experiments/.DS_Store
--------------------------------------------------------------------------------
/lib/train/admin/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/admin/.DS_Store
--------------------------------------------------------------------------------
/lib/train/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/.DS_Store
--------------------------------------------------------------------------------
/lib/config/ostrack/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/config/ostrack/.DS_Store
--------------------------------------------------------------------------------
/lib/models/layers/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/models/layers/.DS_Store
--------------------------------------------------------------------------------
/lib/models/ostrack/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/models/ostrack/.DS_Store
--------------------------------------------------------------------------------
/lib/train/actors/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/actors/.DS_Store
--------------------------------------------------------------------------------
/lib/train/trainers/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/trainers/.DS_Store
--------------------------------------------------------------------------------
/experiments/ostrack/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/experiments/ostrack/.DS_Store
--------------------------------------------------------------------------------
/lib/test/evaluation/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/evaluation/.DS_Store
--------------------------------------------------------------------------------
/lib/train/actors/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_actor import BaseActor
2 | from .ostrack import OSTrackActor
3 |
--------------------------------------------------------------------------------
/lib/train/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_trainer import BaseTrainer
2 | from .ltr_trainer import LTRTrainer
3 |
--------------------------------------------------------------------------------
/lib/vis/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/vis/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tracking/__pycache__/test.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/tracking/__pycache__/test.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/merge.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/merge.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/misc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/misc.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/tensor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/tensor.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/tensor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/tensor.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/vis/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/vis/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/vis/__pycache__/plotting.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/vis/__pycache__/plotting.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/box_ops.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/box_ops.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/ce_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/ce_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/vis/__pycache__/visdom_cus.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/vis/__pycache__/visdom_cus.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/hann.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/hann.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/params.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/params.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/params.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/params.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/loader.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/loader.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/focal_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/focal_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/lmdb_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/lmdb_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/lmdb_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/lmdb_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/tracking/__pycache__/_init_paths.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/tracking/__pycache__/_init_paths.cpython-37.pyc
--------------------------------------------------------------------------------
/tracking/__pycache__/_init_paths.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/tracking/__pycache__/_init_paths.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/coco.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/coco.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/lasot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/lasot.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/utils/__pycache__/heapmap_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/utils/__pycache__/heapmap_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/analysis/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/analysis/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/test/analysis/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/analysis/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/parameter/__pycache__/ostrack.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/parameter/__pycache__/ostrack.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/tracker/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/tracker/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/tracker/__pycache__/ostrack.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/tracker/__pycache__/ostrack.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/tracker/__pycache__/vis_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/tracker/__pycache__/vis_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/_init_paths.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/_init_paths.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/load_text.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/load_text.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/test/utils/__pycache__/load_text.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/utils/__pycache__/load_text.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/processing.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/processing.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/coco_seq.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/coco_seq.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/got10k.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/got10k.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/parameter/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/parameter/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/tracker/__pycache__/basetracker.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/tracker/__pycache__/basetracker.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/tracker/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/tracker/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/image_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/image_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/image_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/image_loader.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/sampler_datr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/sampler_datr.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/sampler_imix.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/sampler_imix.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/wandb_logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/wandb_logger.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/COCO_tool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/COCO_tool.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/lasot_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/lasot_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/analysis/__pycache__/plot_results.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/analysis/__pycache__/plot_results.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/test/analysis/__pycache__/plot_results.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/analysis/__pycache__/plot_results.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/processing_datr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/processing_datr.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/processing_imix.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/processing_imix.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__pycache__/processing_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/data/__pycache__/processing_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/coco_seq_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/coco_seq_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/got10k_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/got10k_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/imagenetvid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/imagenetvid.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/tracking_net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/tracking_net.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/test/analysis/__pycache__/extract_results.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/analysis/__pycache__/extract_results.cpython-37.pyc
--------------------------------------------------------------------------------
/lib/test/analysis/__pycache__/extract_results.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/test/analysis/__pycache__/extract_results.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .loader import LTRLoader
2 | from .image_loader import jpeg4py_loader, opencv_loader, jpeg4py_loader_w_failsafe, default_image_loader
3 |
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/base_image_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/base_image_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/base_video_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/base_video_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/imagenetvid_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/imagenetvid_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/train/dataset/__pycache__/tracking_net_lmdb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zj5559/DATr/HEAD/lib/train/dataset/__pycache__/tracking_net_lmdb.cpython-38.pyc
--------------------------------------------------------------------------------
/lib/vis/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def numpy_to_torch(a: np.ndarray):
6 | return torch.from_numpy(a).float().permute(2, 0, 1).unsqueeze(0)
--------------------------------------------------------------------------------
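A quick shape check for numpy_to_torch above, as a hedged sketch: the standalone import assumes the repo root is on sys.path (e.g. via one of the _init_paths.py helpers), and the zero image is purely illustrative.

import numpy as np
import torch

from lib.vis.utils import numpy_to_torch  # assumes repo root on sys.path

img = np.zeros((480, 640, 3), dtype=np.uint8)  # H x W x C image
t = numpy_to_torch(img)
assert t.shape == (1, 3, 480, 640)             # 1 x C x H x W float tensor
--------------------------------------------------------------------------------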
/lib/train/admin/__init__.py:
--------------------------------------------------------------------------------
1 | from .environment import env_settings, create_default_local_file_ITP_train
2 | from .stats import AverageMeter, StatValue
3 | from .tensorboard import TensorboardWriter
4 |
--------------------------------------------------------------------------------
/lib/test/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import Sequence
2 | from .tracker import Tracker, trackerlist
3 | from .datasets import get_dataset
4 | from .environment import create_default_local_file_ITP_test
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # train a model with our DATr, then evaluate it
4 | config='datr_vitb_256_mae_ce_32x4_ep300'
5 | python tracking/train.py --script ostrack --config ${config} --save_dir /path/of/model --mode multiple --nproc_per_node 4 --datr 1
6 | python tracking/test.py ostrack ${config} --dataset lasot_extension_subset --threads 4 --num_gpus 4
7 |
8 |
9 |
--------------------------------------------------------------------------------
/lib/train/admin/settings.py:
--------------------------------------------------------------------------------
1 | from lib.train.admin.environment import env_settings
2 |
3 |
4 | class Settings:
5 | """ Training settings, e.g. the paths to datasets and networks."""
6 | def __init__(self):
7 | self.set_default()
8 |
9 | def set_default(self):
10 | self.env = env_settings()
11 | self.use_gpu = True
12 |
13 |
14 |
--------------------------------------------------------------------------------
/tracking/_init_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os.path as osp
6 | import sys
7 |
8 |
9 | def add_path(path):
10 | if path not in sys.path:
11 | sys.path.insert(0, path)
12 |
13 |
14 | this_dir = osp.dirname(__file__)
15 |
16 | prj_path = osp.join(this_dir, '..')
17 | add_path(prj_path)
18 |
--------------------------------------------------------------------------------
/lib/train/_init_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os.path as osp
6 | import sys
7 |
8 |
9 | def add_path(path):
10 | if path not in sys.path:
11 | sys.path.insert(0, path)
12 |
13 |
14 | this_dir = osp.dirname(__file__)
15 |
16 | prj_path = osp.join(this_dir, '../..')
17 | add_path(prj_path)
18 |
--------------------------------------------------------------------------------
/lib/test/utils/_init_paths.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os.path as osp
6 | import sys
7 |
8 |
9 | def add_path(path):
10 | if path not in sys.path:
11 | sys.path.insert(0, path)
12 |
13 |
14 | this_dir = osp.dirname(__file__)
15 |
16 | prj_path = osp.join(this_dir, '..', '..', '..')
17 | add_path(prj_path)
18 |
--------------------------------------------------------------------------------
/lib/train/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .lasot import Lasot
2 | from .got10k import Got10k
3 | from .tracking_net import TrackingNet
4 | from .imagenetvid import ImagenetVID
5 | from .coco import MSCOCO
6 | from .coco_seq import MSCOCOSeq
7 | from .got10k_lmdb import Got10k_lmdb
8 | from .lasot_lmdb import Lasot_lmdb
9 | from .imagenetvid_lmdb import ImagenetVID_lmdb
10 | from .coco_seq_lmdb import MSCOCOSeq_lmdb
11 | from .tracking_net_lmdb import TrackingNet_lmdb
12 |
--------------------------------------------------------------------------------
/lib/train/admin/multigpu.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | # Here we use DistributedDataParallel (DDP) rather than DataParallel (DP) for multi-GPU training
3 |
4 |
5 | def is_multi_gpu(net):
6 | return isinstance(net, (MultiGPU, nn.parallel.distributed.DistributedDataParallel))
7 |
8 |
9 | class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
10 |     def __getattr__(self, item):
11 |         try:
12 |             return super().__getattr__(item)
13 |         except AttributeError:
14 |             # fall back to the attributes of the wrapped module
15 |             return getattr(self.module, item)
16 |
--------------------------------------------------------------------------------
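A minimal sketch of the __getattr__ fallback above: attributes missing on the DDP wrapper are looked up on the wrapped module. The single-process "gloo" group exists only so DistributedDataParallel can be constructed on CPU; some_flag is a hypothetical attribute, not part of the repo.

import os
import torch.distributed as dist
import torch.nn as nn

from lib.train.admin.multigpu import MultiGPU, is_multi_gpu

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # single-process group for the demo

net = nn.Linear(4, 2)
net.some_flag = True          # hypothetical attribute set on the raw module
wrapped = MultiGPU(net)
assert is_multi_gpu(wrapped)
print(wrapped.some_flag)      # True: __getattr__ falls through to self.module
dist.destroy_process_group()
--------------------------------------------------------------------------------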
/.idea/DATr-code.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/tracking/convert_transt.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 | import os
3 | from lib.test.evaluation import get_dataset
4 | import shutil
5 |
6 | trackers = []
7 | # dataset_name = 'uav'
8 | dataset_name = 'nfs'
9 |
10 |
11 | root_dir = "/data/sda/v-yanbi/iccv21/STARK_Latest/Stark"
12 | base_dir = os.path.join(root_dir, "test/tracking_results/TransT_N2")
13 | dataset = get_dataset(dataset_name)
14 | for x in dataset:
15 | seq_name = x.name
16 | file_name = "%s.txt" % (seq_name.replace("nfs_", ""))
17 | file_path = os.path.join(base_dir, file_name)
18 | file_path_new = os.path.join(base_dir, "%s.txt" % seq_name)
19 | if os.path.exists(file_path):
20 | shutil.move(file_path, file_path_new)
21 |
22 |
--------------------------------------------------------------------------------
/lib/train/data_specs/README.md:
--------------------------------------------------------------------------------
1 | # README
2 |
3 | ## Description for different text files
4 | GOT10K
5 | - got10k_train_full_split.txt: the complete GOT-10K training set (9335 videos)
6 | - got10k_train_split.txt: a subset of videos from the GOT-10K training set
7 | - got10k_val_split.txt: another subset of videos from the GOT-10K training set, used for validation
8 | - got10k_vot_exclude.txt: 1k videos that may not be used to train models that are then tested on VOT (as required by the [VOT Challenge](https://www.votchallenge.net/vot2020/participation.html))
9 | - got10k_vot_train_split.txt: a subset of videos from the "VOT-permitted" GOT-10K training set
10 | - got10k_vot_val_split.txt: another subset of videos from the "VOT-permitted" GOT-10K training set, used for validation
11 |
12 | LaSOT
13 | - lasot_train_split.txt: the complete LaSOT training set
14 |
15 | TrackingNet
16 | - trackingnet_classmap.txt: the map from sequence name to target class for TrackingNet
--------------------------------------------------------------------------------
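For illustration, a hedged sketch of consuming one of these split files, assuming the usual one-sequence-name-per-line layout (the authoritative parsing lives in the dataset classes under lib/train/dataset); read_split is a hypothetical helper:

import os

def read_split(data_specs_dir: str, name: str) -> list:
    # assumes one sequence name per line, blank lines ignored
    with open(os.path.join(data_specs_dir, name)) as f:
        return [line.strip() for line in f if line.strip()]

seqs = read_split("lib/train/data_specs", "got10k_train_split.txt")
print(len(seqs), seqs[:3])
--------------------------------------------------------------------------------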
/tracking/analysis_results_ITP.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 | import argparse
3 | from lib.test.analysis.plot_results import print_results
4 | from lib.test.evaluation import get_dataset, trackerlist
5 |
6 |
7 | def parse_args():
8 | """
9 | args for evaluation.
10 | """
11 | parser = argparse.ArgumentParser(description='Parse args for evaluation')
12 | # identify the trained model (script and config) to evaluate
13 | parser.add_argument('--script', type=str, help='training script name')
14 | parser.add_argument('--config', type=str, default='baseline', help='yaml configure file name')
15 |
16 | args = parser.parse_args()
17 |
18 | return args
19 |
20 |
21 | if __name__ == "__main__":
22 | args = parse_args()
23 | trackers = []
24 | trackers.extend(trackerlist(args.script, args.config, "None", None, args.config))
25 |
26 | dataset = get_dataset('lasot')
27 |
28 | print_results(trackers, dataset, 'LaSOT', merge_results=True, plot_types=('success', 'prec', 'norm_prec'))
--------------------------------------------------------------------------------
/tracking/create_default_local_file.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import _init_paths
4 | from lib.train.admin import create_default_local_file_ITP_train
5 | from lib.test.evaluation import create_default_local_file_ITP_test
6 |
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser(description='Create default local file on ITP or PAI')
10 | parser.add_argument("--workspace_dir", type=str, required=True) # workspace dir
11 | parser.add_argument("--data_dir", type=str, required=True)
12 | parser.add_argument("--save_dir", type=str, required=True)
13 | args = parser.parse_args()
14 | return args
15 |
16 |
17 | if __name__ == "__main__":
18 | args = parse_args()
19 | workspace_dir = os.path.realpath(args.workspace_dir)
20 | data_dir = os.path.realpath(args.data_dir)
21 | save_dir = os.path.realpath(args.save_dir)
22 | create_default_local_file_ITP_train(workspace_dir, data_dir)
23 | create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir)
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Botao Ye
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lib/test/parameter/ostrack.py:
--------------------------------------------------------------------------------
1 | from lib.test.utils import TrackerParams
2 | import os
3 | from lib.test.evaluation.environment import env_settings
4 | from lib.config.ostrack.config import cfg, update_config_from_file
5 |
6 |
7 | def parameters(yaml_name: str):
8 | params = TrackerParams()
9 | prj_dir = env_settings().prj_dir
10 | save_dir = env_settings().save_dir
11 | # update default config from yaml file
12 | yaml_file = os.path.join(prj_dir, 'experiments/ostrack/%s.yaml' % yaml_name)
13 | update_config_from_file(yaml_file)
14 | params.cfg = cfg
15 | print("test config: ", cfg)
16 |
17 | # template and search region
18 | params.template_factor = cfg.TEST.TEMPLATE_FACTOR
19 | params.template_size = cfg.TEST.TEMPLATE_SIZE
20 | params.search_factor = cfg.TEST.SEARCH_FACTOR
21 | params.search_size = cfg.TEST.SEARCH_SIZE
22 |
23 | # Network checkpoint path
24 | params.checkpoint = os.path.join(save_dir, "checkpoints/train/ostrack/%s/OSTrack_ep%04d.pth.tar" %
25 | (yaml_name, cfg.TEST.EPOCH))
26 |
27 | # whether to save boxes from all queries
28 | params.save_all_boxes = False
29 |
30 | return params
31 |
--------------------------------------------------------------------------------
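How this factory is typically invoked, as a hedged sketch: the yaml name matches the files under experiments/ostrack/, and the prj_dir/save_dir paths come from the local environment files created by tracking/create_default_local_file.py. Run from the tracking/ directory so _init_paths puts the repo root on sys.path.

import _init_paths  # when run from the tracking/ directory
from lib.test.parameter.ostrack import parameters

params = parameters("datr_vitb_256_mae_ce_32x4_ep300")
print(params.template_size, params.search_size)  # sizes read from the yaml
print(params.checkpoint)                         # expected checkpoint location
--------------------------------------------------------------------------------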
/lib/train/data/wandb_logger.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | try:
4 | import wandb
5 | except ImportError:
6 | raise ImportError(
7 | 'Please run "pip install wandb" to install wandb')
8 |
9 |
10 | class WandbWriter:
11 | def __init__(self, exp_name, cfg, output_dir, cur_step=0, step_interval=0):
12 | self.wandb = wandb
13 | self.step = cur_step
14 | self.interval = step_interval
15 | wandb.init(project="tracking", name=exp_name, config=cfg, dir=output_dir)
16 |
17 | def write_log(self, stats: OrderedDict, epoch=-1):
18 | self.step += 1
19 | for loader_name, loader_stats in stats.items():
20 | if loader_stats is None:
21 | continue
22 |
23 | log_dict = {}
24 | for var_name, val in loader_stats.items():
25 | if hasattr(val, 'avg'):
26 | log_dict.update({loader_name + '/' + var_name: val.avg})
27 | else:
28 | log_dict.update({loader_name + '/' + var_name: val.val})
29 |
30 | if epoch >= 0:
31 | log_dict.update({loader_name + '/epoch': epoch})
32 |
33 | self.wandb.log(log_dict, step=self.step*self.interval)
34 |
--------------------------------------------------------------------------------
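A hedged sketch of feeding write_log above. Meter is a hypothetical stand-in for the stat objects the trainers pass (anything exposing .avg or .val works), and running this for real requires a configured wandb account.

from collections import OrderedDict

from lib.train.data.wandb_logger import WandbWriter

class Meter:          # hypothetical stand-in for an AverageMeter-style stat
    avg = 0.5

writer = WandbWriter("demo_run", {"lr": 1e-4}, "./wandb_out", cur_step=0, step_interval=1)
writer.write_log(OrderedDict(train={"Loss/total": Meter()}), epoch=1)
--------------------------------------------------------------------------------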
/lib/train/admin/tensorboard.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | try:
4 | from torch.utils.tensorboard import SummaryWriter
5 | except ImportError:
6 |     print('WARNING: You are using tensorboardX instead, since your PyTorch version is too old to provide torch.utils.tensorboard.')
7 | from tensorboardX import SummaryWriter
8 |
9 |
10 | class TensorboardWriter:
11 | def __init__(self, directory, loader_names):
12 | self.directory = directory
13 | self.writer = OrderedDict({name: SummaryWriter(os.path.join(self.directory, name)) for name in loader_names})
14 |
15 | def write_info(self, script_name, description):
16 | tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
17 | tb_info_writer.add_text('Script_name', script_name)
18 | tb_info_writer.add_text('Description', description)
19 | tb_info_writer.close()
20 |
21 | def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
22 | for loader_name, loader_stats in stats.items():
23 | if loader_stats is None:
24 | continue
25 | for var_name, val in loader_stats.items():
26 | if hasattr(val, 'history') and getattr(val, 'has_new_data', True):
27 | self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)
--------------------------------------------------------------------------------
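Usage sketch for TensorboardWriter. Stat is a hypothetical stand-in for the objects from lib/train/admin/stats.py, which expose a history list (and optionally a has_new_data flag), matching what write_epoch reads.

from collections import OrderedDict

from lib.train.admin.tensorboard import TensorboardWriter

class Stat:                     # hypothetical stand-in for StatValue
    history = [0.42]
    has_new_data = True

tb = TensorboardWriter("./tb_logs", ["train", "val"])
tb.write_info("ostrack", "demo run")
tb.write_epoch(OrderedDict(train={"Loss/total": Stat()}, val=None), epoch=1)
--------------------------------------------------------------------------------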
/tracking/analysis_results.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 | import matplotlib.pyplot as plt
3 | plt.rcParams['figure.figsize'] = [8, 8]
4 |
5 | from lib.test.analysis.plot_results import plot_results, print_results, print_per_sequence_results
6 | from lib.test.evaluation import get_dataset, trackerlist
7 |
8 | trackers = []
9 | dataset_name = 'lasot_extension_subset'
10 | # dataset_name = 'lasot'
11 |
12 | """ostrack"""
13 | trackers.extend(trackerlist(name='ostrack', parameter_name='datr_vitb_256_mae_ce_32x4_ep300', dataset_name=dataset_name,
14 | run_ids=None, display_name='datr_vitb_256'))
15 |
16 |
17 | trackers.extend(trackerlist(name='ostrack', parameter_name='vitb_256_mae_ce_32x4_ep300', dataset_name=dataset_name,
18 | run_ids=None, display_name='vitb_256'))
19 |
20 |
21 |
22 | dataset = get_dataset(dataset_name)
23 | # dataset = get_dataset('otb', 'nfs', 'uav', 'tc128ce')
24 | # plot_results(trackers, dataset, 'OTB2015', merge_results=True, plot_types=('success', 'norm_prec'),
25 | # skip_missing_seq=False, force_evaluation=True, plot_bin_gap=0.05)
26 | print_results(trackers, dataset, dataset_name, merge_results=True, seq_eval=False, plot_types=('success', 'norm_prec', 'prec'))
27 | # print_results(trackers, dataset, 'UNO', merge_results=True, plot_types=('success', 'prec'))
28 |
--------------------------------------------------------------------------------
/lib/utils/merge.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def merge_template_search(inp_list, return_search=False, return_template=False):
5 | """NOTICE: search region related features must be in the last place"""
6 | seq_dict = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0),
7 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1),
8 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)}
9 | if return_search:
10 | x = inp_list[-1]
11 | seq_dict.update({"feat_x": x["feat"], "mask_x": x["mask"], "pos_x": x["pos"]})
12 | if return_template:
13 | z = inp_list[0]
14 | seq_dict.update({"feat_z": z["feat"], "mask_z": z["mask"], "pos_z": z["pos"]})
15 | return seq_dict
16 |
17 |
18 | def get_qkv(inp_list):
19 | """The 1st element of the inp_list is about the template,
20 | the 2nd (the last) element is about the search region"""
21 | dict_x = inp_list[-1]
22 | dict_c = {"feat": torch.cat([x["feat"] for x in inp_list], dim=0),
23 | "mask": torch.cat([x["mask"] for x in inp_list], dim=1),
24 | "pos": torch.cat([x["pos"] for x in inp_list], dim=0)} # concatenated dict
25 | q = dict_x["feat"] + dict_x["pos"]
26 | k = dict_c["feat"] + dict_c["pos"]
27 | v = dict_c["feat"]
28 | key_padding_mask = dict_c["mask"]
29 | return q, k, v, key_padding_mask
30 |
--------------------------------------------------------------------------------
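A shape sketch for get_qkv: the dim=0/dim=1 concatenations imply sequence-first features (L x B x C) and batch-first masks (B x L), with the template dict first and the search dict last. The token counts (64/256) and sizes below are illustrative only.

import torch

from lib.utils.merge import get_qkv

B, C = 2, 256
template = {"feat": torch.randn(64, B, C), "pos": torch.randn(64, B, C),
            "mask": torch.zeros(B, 64, dtype=torch.bool)}
search = {"feat": torch.randn(256, B, C), "pos": torch.randn(256, B, C),
          "mask": torch.zeros(B, 256, dtype=torch.bool)}

q, k, v, pad = get_qkv([template, search])
print(q.shape, k.shape, v.shape, pad.shape)
# torch.Size([256, 2, 256]) torch.Size([320, 2, 256]) torch.Size([320, 2, 256]) torch.Size([2, 320])
--------------------------------------------------------------------------------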
/lib/test/utils/params.py:
--------------------------------------------------------------------------------
1 | from lib.utils import TensorList
2 | import random
3 |
4 |
5 | class TrackerParams:
6 | """Class for tracker parameters."""
7 | def set_default_values(self, default_vals: dict):
8 | for name, val in default_vals.items():
9 | if not hasattr(self, name):
10 | setattr(self, name, val)
11 |
12 | def get(self, name: str, *default):
13 | """Get a parameter value with the given name. If it does not exists, it return the default value given as a
14 | second argument or returns an error if no default value is given."""
15 | if len(default) > 1:
16 | raise ValueError('Can only give one default value.')
17 |
18 | if not default:
19 | return getattr(self, name)
20 |
21 | return getattr(self, name, default[0])
22 |
23 | def has(self, name: str):
24 | """Check if there exist a parameter with the given name."""
25 | return hasattr(self, name)
26 |
27 |
28 | class FeatureParams:
29 | """Class for feature specific parameters"""
30 | def __init__(self, *args, **kwargs):
31 | if len(args) > 0:
32 | raise ValueError
33 |
34 | for name, val in kwargs.items():
35 | if isinstance(val, list):
36 | setattr(self, name, TensorList(val))
37 | else:
38 | setattr(self, name, val)
39 |
40 |
41 | def Choice(*args):
42 | """Can be used to sample random parameter values."""
43 | return random.choice(args)
44 |
--------------------------------------------------------------------------------
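A quick sketch of the defaulting behaviour of TrackerParams (attribute names here are illustrative, not repo-defined defaults):

from lib.test.utils import TrackerParams

p = TrackerParams()
p.set_default_values({"search_size": 256, "debug": 0})
print(p.get("search_size"))           # 256
print(p.get("missing", "fallback"))   # "fallback": the provided default is used
print(p.has("missing"))               # False
# p.get("missing")                    # would raise AttributeError: no default given
--------------------------------------------------------------------------------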
/lib/train/actors/base_actor.py:
--------------------------------------------------------------------------------
1 | from lib.utils import TensorDict
2 |
3 |
4 | class BaseActor:
5 | """ Base class for actor. The actor class handles the passing of the data through the network
6 | and calculation the loss"""
7 | def __init__(self, net, objective):
8 | """
9 | args:
10 | net - The network to train
11 | objective - The loss function
12 | """
13 | self.net = net
14 | self.objective = objective
15 |
16 | def __call__(self, data: TensorDict):
17 | """ Called in each training iteration. Should pass in input data through the network, calculate the loss, and
18 | return the training stats for the input data
19 | args:
20 | data - A TensorDict containing all the necessary data blocks.
21 |
22 | returns:
23 | loss - loss for the input data
24 | stats - a dict containing detailed losses
25 | """
26 | raise NotImplementedError
27 |
28 | def to(self, device):
29 | """ Move the network to device
30 | args:
31 | device - device to use. 'cpu' or 'cuda'
32 | """
33 | self.net.to(device)
34 |
35 | def train(self, mode=True):
36 | """ Set whether the network is in train mode.
37 | args:
38 | mode (True) - Bool specifying whether in training mode.
39 | """
40 | self.net.train(mode)
41 |
42 | def eval(self):
43 | """ Set network to eval mode"""
44 | self.train(False)
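
A hypothetical subclass sketch showing the contract; the data keys and stat names are illustrative, not taken from a real actor:

class ToyActor(BaseActor):
    def __call__(self, data: TensorDict):
        pred = self.net(data['search_images'])     # forward pass through the network
        loss = self.objective(pred, data['anno'])  # loss against ground-truth annotations
        stats = {'Loss/total': loss.item()}        # detailed losses for logging
        return loss, stats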
--------------------------------------------------------------------------------
/lib/test/utils/transform_trackingnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import shutil
4 | import argparse
5 | import _init_paths
6 | from lib.test.evaluation.environment import env_settings
7 |
8 |
9 | def transform_trackingnet(tracker_name, cfg_name):
10 | env = env_settings()
11 | result_dir = env.results_path
12 | src_dir = os.path.join(result_dir, "%s/%s/trackingnet/" % (tracker_name, cfg_name))
13 | dest_dir = os.path.join(result_dir, "%s/%s/trackingnet_submit/" % (tracker_name, cfg_name))
14 | if not os.path.exists(dest_dir):
15 | os.makedirs(dest_dir)
16 | items = os.listdir(src_dir)
17 | for item in items:
18 | if "all" in item:
19 | continue
20 | if "time" not in item:
21 | src_path = os.path.join(src_dir, item)
22 | dest_path = os.path.join(dest_dir, item)
23 |                 bbox_arr = np.loadtxt(src_path, dtype=int, delimiter='\t')  # np.int was removed in NumPy >= 1.24
24 | np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',')
25 | # make zip archive
26 | shutil.make_archive(src_dir, "zip", src_dir)
27 | shutil.make_archive(dest_dir, "zip", dest_dir)
28 | # Remove the original files
29 | shutil.rmtree(src_dir)
30 | shutil.rmtree(dest_dir)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser(description='transform trackingnet results.')
35 | parser.add_argument('--tracker_name', type=str, help='Name of tracking method.')
36 | parser.add_argument('--cfg_name', type=str, help='Name of config file.')
37 |
38 | args = parser.parse_args()
39 | transform_trackingnet(args.tracker_name, args.cfg_name)
40 |
--------------------------------------------------------------------------------
/tracking/video_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | prj_path = os.path.join(os.path.dirname(__file__), '..')
6 | if prj_path not in sys.path:
7 | sys.path.append(prj_path)
8 |
9 | from lib.test.evaluation import Tracker
10 |
11 |
12 | def run_video(tracker_name, tracker_param, videofile, optional_box=None, debug=None, save_results=False):
13 | """Run the tracker on your webcam.
14 | args:
15 | tracker_name: Name of tracking method.
16 | tracker_param: Name of parameter file.
17 | debug: Debug level.
18 | """
19 | tracker = Tracker(tracker_name, tracker_param, "video")
20 | tracker.run_video(videofilepath=videofile, optional_box=optional_box, debug=debug, save_results=save_results)
21 |
22 |
23 | def main():
24 |     parser = argparse.ArgumentParser(description='Run the tracker on a video file.')
25 | parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
26 | parser.add_argument('tracker_param', type=str, help='Name of parameter file.')
27 | parser.add_argument('videofile', type=str, help='path to a video file.')
28 | parser.add_argument('--optional_box', type=float, default=None, nargs="+", help='optional_box with format x y w h.')
29 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
30 | parser.add_argument('--save_results', dest='save_results', action='store_true', help='Save bounding boxes')
31 | parser.set_defaults(save_results=False)
32 |
33 | args = parser.parse_args()
34 |
35 | run_video(args.tracker_name, args.tracker_param, args.videofile, args.optional_box, args.debug, args.save_results)
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
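
A programmatic equivalent of the CLI, as a sketch (the video path and box are made up):

run_video('ostrack', 'datr_vitb_256_mae_ce_32x4_ep300', '/path/to/video.mp4',
          optional_box=[100.0, 150.0, 60.0, 80.0],  # x y w h; omit to select interactively
          debug=0, save_results=True)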
--------------------------------------------------------------------------------
/lib/utils/variable_hook.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from bytecode import Bytecode, Instr
3 |
4 |
5 | class get_local(object):
6 | cache = {}
7 | is_activate = False
8 |
9 | def __init__(self, varname):
10 | self.varname = varname
11 |
12 | def __call__(self, func):
13 | if not type(self).is_activate:
14 | return func
15 |
16 | type(self).cache[func.__qualname__] = []
17 | c = Bytecode.from_code(func.__code__)
18 | extra_code = [
19 | Instr('STORE_FAST', '_res'),
20 | Instr('LOAD_FAST', self.varname),
21 | Instr('STORE_FAST', '_value'),
22 | Instr('LOAD_FAST', '_res'),
23 | Instr('LOAD_FAST', '_value'),
24 | Instr('BUILD_TUPLE', 2),
25 | Instr('STORE_FAST', '_result_tuple'),
26 | Instr('LOAD_FAST', '_result_tuple'),
27 | ]
28 | c[-1:-1] = extra_code
29 | func.__code__ = c.to_code()
30 |
31 | def wrapper(*args, **kwargs):
32 | res, values = func(*args, **kwargs)
33 | if isinstance(values, torch.Tensor):
34 | type(self).cache[func.__qualname__].append(values.detach().cpu().numpy())
35 | elif isinstance(values, list): # list of Tensor
36 | type(self).cache[func.__qualname__].append([value.detach().cpu().numpy() for value in values])
37 | else:
38 | raise NotImplementedError
39 | return res
40 |
41 | return wrapper
42 |
43 | @classmethod
44 | def clear(cls):
45 | for key in cls.cache.keys():
46 | cls.cache[key] = []
47 |
48 | @classmethod
49 | def activate(cls):
50 | cls.is_activate = True
51 |
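
Intended usage, as a sketch: activate() must run before the decorated function is defined, because the bytecode is patched at decoration time (the Instr sequence above assumes a compatible CPython bytecode layout, e.g. Python 3.7/3.8):

import torch

get_local.activate()  # must precede the decorated def below

@get_local('attn')    # 'attn' must be a local variable inside the function
def toy_forward(x):
    attn = x.softmax(dim=-1)
    return attn @ x

_ = toy_forward(torch.randn(4, 4))
print(len(get_local.cache['toy_forward']))  # 1: one captured attention map, stored as numpy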
--------------------------------------------------------------------------------
/lib/models/layers/frozen_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class FrozenBatchNorm2d(torch.nn.Module):
5 | """
6 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
7 |
8 |     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
9 | without which any other models than torchvision.models.resnet[18,34,50,101]
10 | produce nans.
11 | """
12 |
13 | def __init__(self, n):
14 | super(FrozenBatchNorm2d, self).__init__()
15 | self.register_buffer("weight", torch.ones(n))
16 | self.register_buffer("bias", torch.zeros(n))
17 | self.register_buffer("running_mean", torch.zeros(n))
18 | self.register_buffer("running_var", torch.ones(n))
19 |
20 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
21 | missing_keys, unexpected_keys, error_msgs):
22 | num_batches_tracked_key = prefix + 'num_batches_tracked'
23 | if num_batches_tracked_key in state_dict:
24 | del state_dict[num_batches_tracked_key]
25 |
26 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
27 | state_dict, prefix, local_metadata, strict,
28 | missing_keys, unexpected_keys, error_msgs)
29 |
30 | def forward(self, x):
31 | # move reshapes to the beginning
32 | # to make it fuser-friendly
33 | w = self.weight.reshape(1, -1, 1, 1)
34 | b = self.bias.reshape(1, -1, 1, 1)
35 | rv = self.running_var.reshape(1, -1, 1, 1)
36 | rm = self.running_mean.reshape(1, -1, 1, 1)
37 | eps = 1e-5
38 | scale = w * (rv + eps).rsqrt() # rsqrt(x): 1/sqrt(x), r: reciprocal
39 | bias = b - rm * scale
40 | return x * scale + bias
41 |
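
A quick sanity check, as a sketch: with the same statistics, the frozen module should match BatchNorm2d in eval mode, since both use eps = 1e-5 here.

import torch

bn = torch.nn.BatchNorm2d(8).eval()                    # default eps is also 1e-5
frozen = FrozenBatchNorm2d(8)
frozen.load_state_dict(bn.state_dict(), strict=False)  # num_batches_tracked is dropped on load

x = torch.randn(2, 8, 4, 4)
print(torch.allclose(bn(x), frozen(x), atol=1e-5))     # True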
--------------------------------------------------------------------------------
/lib/test/utils/load_text.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def load_text_numpy(path, delimiter, dtype):
6 | if isinstance(delimiter, (tuple, list)):
7 | for d in delimiter:
8 | try:
9 | ground_truth_rect = np.loadtxt(path, delimiter=d, dtype=dtype)
10 | return ground_truth_rect
11 |             except Exception:  # try the next delimiter
12 | pass
13 |
14 | raise Exception('Could not read file {}'.format(path))
15 | else:
16 | ground_truth_rect = np.loadtxt(path, delimiter=delimiter, dtype=dtype)
17 | return ground_truth_rect
18 |
19 |
20 | def load_text_pandas(path, delimiter, dtype):
21 | if isinstance(delimiter, (tuple, list)):
22 | for d in delimiter:
23 | try:
24 | ground_truth_rect = pd.read_csv(path, delimiter=d, header=None, dtype=dtype, na_filter=False,
25 | low_memory=False).values
26 | return ground_truth_rect
27 |             except Exception:  # try the next delimiter
28 | pass
29 |
30 | raise Exception('Could not read file {}'.format(path))
31 | else:
32 | ground_truth_rect = pd.read_csv(path, delimiter=delimiter, header=None, dtype=dtype, na_filter=False,
33 | low_memory=False).values
34 | return ground_truth_rect
35 |
36 |
37 | def load_text(path, delimiter=' ', dtype=np.float32, backend='numpy'):
38 | if backend == 'numpy':
39 | return load_text_numpy(path, delimiter, dtype)
40 | elif backend == 'pandas':
41 | return load_text_pandas(path, delimiter, dtype)
42 |
43 |
44 | def load_str(path):
45 | with open(path, "r") as f:
46 | text_str = f.readline().strip().lower()
47 | return text_str
48 |
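
Usage sketch (the path is hypothetical): passing a tuple of delimiters lets one call handle benchmarks whose ground-truth files are comma- or tab-separated.

rects = load_text('/path/to/groundtruth.txt', delimiter=(',', '\t'), dtype=np.float64)
print(rects.shape)  # (num_frames, 4): one x, y, w, h box per frame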
--------------------------------------------------------------------------------
/lib/utils/lmdb_utils.py:
--------------------------------------------------------------------------------
1 | import lmdb
2 | import numpy as np
3 | import cv2
4 | import json
5 |
6 | LMDB_ENVS = dict()
7 | LMDB_HANDLES = dict()
8 | LMDB_FILELISTS = dict()
9 |
10 |
11 | def get_lmdb_handle(name):
12 | global LMDB_HANDLES, LMDB_FILELISTS
13 | item = LMDB_HANDLES.get(name, None)
14 | if item is None:
15 | env = lmdb.open(name, readonly=True, lock=False, readahead=False, meminit=False)
16 | LMDB_ENVS[name] = env
17 | item = env.begin(write=False)
18 | LMDB_HANDLES[name] = item
19 |
20 | return item
21 |
22 |
23 | def decode_img(lmdb_fname, key_name):
24 | handle = get_lmdb_handle(lmdb_fname)
25 | binfile = handle.get(key_name.encode())
26 | if binfile is None:
27 |         raise ValueError("Illegal data detected. %s %s" % (lmdb_fname, key_name))
28 | s = np.frombuffer(binfile, np.uint8)
29 | x = cv2.cvtColor(cv2.imdecode(s, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
30 | return x
31 |
32 |
33 | def decode_str(lmdb_fname, key_name):
34 | handle = get_lmdb_handle(lmdb_fname)
35 | binfile = handle.get(key_name.encode())
36 | string = binfile.decode()
37 | return string
38 |
39 |
40 | def decode_json(lmdb_fname, key_name):
41 | return json.loads(decode_str(lmdb_fname, key_name))
42 |
43 |
44 | if __name__ == "__main__":
45 | lmdb_fname = "/data/sda/v-yanbi/iccv21/LittleBoy_clean/data/got10k_lmdb"
46 | '''Decode image'''
47 | # key_name = "test/GOT-10k_Test_000001/00000001.jpg"
48 | # img = decode_img(lmdb_fname, key_name)
49 | # cv2.imwrite("001.jpg", img)
50 | '''Decode str'''
51 | # key_name = "test/list.txt"
52 | # key_name = "train/GOT-10k_Train_000001/groundtruth.txt"
53 | key_name = "train/GOT-10k_Train_000001/absence.label"
54 | str_ = decode_str(lmdb_fname, key_name)
55 | print(str_)
56 |
--------------------------------------------------------------------------------
/lib/train/admin/stats.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class StatValue:
4 | def __init__(self):
5 | self.clear()
6 |
7 | def reset(self):
8 | self.val = 0
9 |
10 | def clear(self):
11 | self.reset()
12 | self.history = []
13 |
14 | def update(self, val):
15 | self.val = val
16 | self.history.append(self.val)
17 |
18 |
19 | class AverageMeter(object):
20 | """Computes and stores the average and current value"""
21 | def __init__(self):
22 | self.clear()
23 | self.has_new_data = False
24 |
25 | def reset(self):
26 | self.avg = 0
27 | self.val = 0
28 | self.sum = 0
29 | self.count = 0
30 |
31 | def clear(self):
32 | self.reset()
33 | self.history = []
34 |
35 | def update(self, val, n=1):
36 | self.val = val
37 | self.sum += val * n
38 | self.count += n
39 | self.avg = self.sum / self.count
40 |
41 | def new_epoch(self):
42 | if self.count > 0:
43 | self.history.append(self.avg)
44 | self.reset()
45 | self.has_new_data = True
46 | else:
47 | self.has_new_data = False
48 |
49 |
50 | def topk_accuracy(output, target, topk=(1,)):
51 | """Computes the precision@k for the specified values of k"""
52 | single_input = not isinstance(topk, (tuple, list))
53 | if single_input:
54 | topk = (topk,)
55 |
56 | maxk = max(topk)
57 | batch_size = target.size(0)
58 |
59 | _, pred = output.topk(maxk, 1, True, True)
60 | pred = pred.t()
61 | correct = pred.eq(target.view(1, -1).expand_as(pred))
62 |
63 | res = []
64 | for k in topk:
65 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)[0]  # reshape: the slice may be non-contiguous
66 | res.append(correct_k * 100.0 / batch_size)
67 |
68 | if single_input:
69 | return res[0]
70 |
71 | return res
72 |
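
A worked sketch of AverageMeter with made-up batch losses:

meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(batch_loss, n=batch_size)
print(meter.avg)   # weighted mean: (0.9*32 + 0.7*32 + 0.5*16) / 80 = 0.74
meter.new_epoch()  # appends 0.74 to meter.history and resets the running sums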
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DATr
2 | PyTorch implementation of "Leveraging the Power of Data Augmentation for Transformer-based Tracking" (WACV2024).
3 |
4 | Please find the paper [here](https://arxiv.org/pdf/2309.08264.pdf).
5 |
6 | ## Introduction
7 | In this paper, we perform systematic experiments to explore the impact of General Data Augmentations (GDA) on transformer trackers, including pure transformer trackers and hybrid CNN-Transformer trackers. The results below show that GDAs have only limited effects on SOTA trackers.
8 | 
9 |
10 | Then, we propose two data augmentation methods, DATr for short, designed around the challenges faced by Transformer-based trackers. They improve trackers in terms of adaptability to different scales, flexibility for boundary targets, and robustness to interference.
11 | 
12 |
13 | Extensive experiments on different baseline trackers and benchmarks demonstrate the effectiveness and generalization of our DATr, especially on challenging sequences and unseen classes.
14 | 
15 |
16 | ## Installation
17 | The environment installation and training configurations (e.g., project paths, pretrained models) follow the baseline trackers; please refer to [OSTrack](https://github.com/botaoye/OSTrack).
18 |
19 | ## Training and Testing
20 | Please see eval.sh for the training and testing commands.
21 |
22 | ## Models and Results
23 | Models and results can be found [here](https://drive.google.com/drive/folders/19-jBvfFVZxPcvZmy6ZXwtyW6NCnQ6bjY?usp=share_link).
24 |
25 | ## Acknowledgments
26 | Our work is mainly implemented on three different Transformer trackers, i.e., [OSTrack](https://github.com/botaoye/OSTrack), [MixFormer](https://github.com/MCG-NJU/MixFormer), and [STARK](https://github.com/MasterBin-IIAU/Stark-1). Thanks for these concise and effective SOT frameworks.
27 |
--------------------------------------------------------------------------------
/lib/test/evaluation/tc128dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
3 | import os
4 | import glob
5 | import six
6 |
7 |
8 | class TC128Dataset(BaseDataset):
9 | """
10 | TC-128 Dataset
11 | modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit)
12 | """
13 | def __init__(self):
14 | super().__init__()
15 | self.base_path = self.env_settings.tc128_path
16 | self.anno_files = sorted(glob.glob(
17 | os.path.join(self.base_path, '*/*_gt.txt')))
18 | self.seq_dirs = [os.path.dirname(f) for f in self.anno_files]
19 | self.seq_names = [os.path.basename(d) for d in self.seq_dirs]
20 | # valid frame range for each sequence
21 | self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs]
22 |
23 | def get_sequence_list(self):
24 | return SequenceList([self._construct_sequence(s) for s in self.seq_names])
25 |
26 | def _construct_sequence(self, sequence_name):
27 | if isinstance(sequence_name, six.string_types):
28 |             if sequence_name not in self.seq_names:
29 | raise Exception('Sequence {} not found.'.format(sequence_name))
30 | index = self.seq_names.index(sequence_name)
31 | # load valid frame range
32 | frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',')
33 | img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)]
34 |
35 | # load annotations
36 | anno = np.loadtxt(self.anno_files[index], delimiter=',')
37 | assert len(img_files) == len(anno)
38 | assert anno.shape[1] == 4
39 |
40 | # return img_files, anno
41 | return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4))
42 |
43 | def __len__(self):
44 | return len(self.seq_names)
45 |
--------------------------------------------------------------------------------
/lib/test/evaluation/tnl2kdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
5 | from lib.test.utils.load_text import load_text, load_str
6 |
7 | ############
8 | # current 00000492.png of test_015_Sord_video_Q01_done is damaged and replaced by a copy of 00000491.png
9 | ############
10 |
11 |
12 | class TNL2kDataset(BaseDataset):
13 | """
14 | TNL2k test set
15 | """
16 | def __init__(self):
17 | super().__init__()
18 | self.base_path = self.env_settings.tnl2k_path
19 | self.sequence_list = self._get_sequence_list()
20 |
21 | def get_sequence_list(self):
22 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
23 |
24 | def _construct_sequence(self, sequence_name):
25 | # class_name = sequence_name.split('-')[0]
26 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
27 |
28 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
29 |
30 | text_dsp_path = '{}/{}/language.txt'.format(self.base_path, sequence_name)
31 | text_dsp = load_str(text_dsp_path)
32 |
33 | frames_path = '{}/{}/imgs'.format(self.base_path, sequence_name)
34 |         frames_list = os.listdir(frames_path)
35 | frames_list = sorted(frames_list)
36 | frames_list = ['{}/{}'.format(frames_path, frame_i) for frame_i in frames_list]
37 |
38 | # target_class = class_name
39 | return Sequence(sequence_name, frames_list, 'tnl2k', ground_truth_rect.reshape(-1, 4), text_dsp=text_dsp)
40 |
41 | def __len__(self):
42 | return len(self.sequence_list)
43 |
44 | def _get_sequence_list(self):
45 | sequence_list = []
46 | for seq in os.listdir(self.base_path):
47 | if os.path.isdir(os.path.join(self.base_path, seq)):
48 | sequence_list.append(seq)
49 |
50 | return sequence_list
51 |
--------------------------------------------------------------------------------
/experiments/ostrack/datr_vitb_384_mae_ce_32x4_ep300.yaml:
--------------------------------------------------------------------------------
1 | DA:
2 | border_prob: 0.05
3 | sfactor: [2.0,6.0]
4 | imix: True
5 | imix_epoch: 1
6 | imix_interval: 11
7 | imix_occrate: 0.5
8 | imix_reverse_prob: 0.0
9 | mix_type: 1
10 | norm_type: 'global'
11 | DATA:
12 | MAX_SAMPLE_INTERVAL: 200
13 | MEAN:
14 | - 0.485
15 | - 0.456
16 | - 0.406
17 | SEARCH:
18 | CENTER_JITTER: 4.5
19 | FACTOR: 5.0
20 | SCALE_JITTER: 0.5
21 | SIZE: 384
22 | STD:
23 | - 0.229
24 | - 0.224
25 | - 0.225
26 | TEMPLATE:
27 | CENTER_JITTER: 0
28 | FACTOR: 2.0
29 | SCALE_JITTER: 0
30 | SIZE: 192
31 | # TRAIN:
32 | # DATASETS_NAME:
33 | # - GOT10K_train_full
34 | # DATASETS_RATIO:
35 | # - 1
36 | # SAMPLE_PER_EPOCH: 60000
37 |
38 | TRAIN:
39 | DATASETS_NAME:
40 | - LASOT
41 | - GOT10K_vottrain
42 | - COCO17
43 | - TRACKINGNET
44 | DATASETS_RATIO:
45 | - 1
46 | - 1
47 | - 1
48 | - 1
49 | SAMPLE_PER_EPOCH: 60000
50 | VAL:
51 | DATASETS_NAME:
52 | - GOT10K_votval
53 | DATASETS_RATIO:
54 | - 1
55 | SAMPLE_PER_EPOCH: 10000
56 | MODEL:
57 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth"
58 | EXTRA_MERGER: False
59 | RETURN_INTER: False
60 | BACKBONE:
61 | TYPE: vit_base_patch16_224_ce
62 | STRIDE: 16
63 | CE_LOC: [3, 6, 9]
64 | CE_KEEP_RATIO: [0.7, 0.7, 0.7]
65 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX
66 | HEAD:
67 | TYPE: CENTER
68 | NUM_CHANNELS: 256
69 | TRAIN:
70 | BACKBONE_MULTIPLIER: 0.1
71 | DROP_PATH_RATE: 0.1
72 | BATCH_SIZE: 32
73 | EPOCH: 300
74 | GIOU_WEIGHT: 2.0
75 | L1_WEIGHT: 5.0
76 | GRAD_CLIP_NORM: 0.1
77 | LR: 0.0004
78 | LR_DROP_EPOCH: 240
79 | NUM_WORKER: 10
80 | OPTIMIZER: ADAMW
81 | PRINT_INTERVAL: 50
82 | SCHEDULER:
83 | TYPE: step
84 | DECAY_RATE: 0.1
85 | VAL_EPOCH_INTERVAL: 20
86 | WEIGHT_DECAY: 0.0001
87 | AMP: False
88 | TEST:
89 | EPOCH: 300
90 | SEARCH_FACTOR: 5.0
91 | SEARCH_SIZE: 384
92 | TEMPLATE_FACTOR: 2.0
93 | TEMPLATE_SIZE: 192
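
These experiment files are plain YAML; a sketch of inspecting one with PyYAML (the project's own loader under lib/config presumably merges it into a default config instead):

import yaml

with open('experiments/ostrack/datr_vitb_384_mae_ce_32x4_ep300.yaml') as f:
    cfg = yaml.safe_load(f)
print(cfg['TRAIN']['BATCH_SIZE'])  # 32
print(cfg['DA']['sfactor'])        # [2.0, 6.0]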
--------------------------------------------------------------------------------
/lib/test/evaluation/tc128cedataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
3 | import os
4 | import glob
5 | import six
6 |
7 |
8 | class TC128CEDataset(BaseDataset):
9 | """
10 | TC-128 Dataset (78 newly added sequences)
11 | modified from the implementation in got10k-toolkit (https://github.com/got-10k/toolkit)
12 | """
13 | def __init__(self):
14 | super().__init__()
15 | self.base_path = self.env_settings.tc128_path
16 | self.anno_files = sorted(glob.glob(
17 | os.path.join(self.base_path, '*/*_gt.txt')))
18 | """filter the newly added sequences (_ce)"""
19 | self.anno_files = [s for s in self.anno_files if "_ce" in s]
20 | self.seq_dirs = [os.path.dirname(f) for f in self.anno_files]
21 | self.seq_names = [os.path.basename(d) for d in self.seq_dirs]
22 | # valid frame range for each sequence
23 | self.range_files = [glob.glob(os.path.join(d, '*_frames.txt'))[0] for d in self.seq_dirs]
24 |
25 | def get_sequence_list(self):
26 | return SequenceList([self._construct_sequence(s) for s in self.seq_names])
27 |
28 | def _construct_sequence(self, sequence_name):
29 | if isinstance(sequence_name, six.string_types):
30 |             if sequence_name not in self.seq_names:
31 | raise Exception('Sequence {} not found.'.format(sequence_name))
32 | index = self.seq_names.index(sequence_name)
33 | # load valid frame range
34 | frames = np.loadtxt(self.range_files[index], dtype=int, delimiter=',')
35 | img_files = [os.path.join(self.seq_dirs[index], 'img/%04d.jpg' % f) for f in range(frames[0], frames[1] + 1)]
36 |
37 | # load annotations
38 | anno = np.loadtxt(self.anno_files[index], delimiter=',')
39 | assert len(img_files) == len(anno)
40 | assert anno.shape[1] == 4
41 |
42 | # return img_files, anno
43 | return Sequence(sequence_name, img_files, 'tc128', anno.reshape(-1, 4))
44 |
45 | def __len__(self):
46 | return len(self.seq_names)
47 |
--------------------------------------------------------------------------------
/lib/test/tracker/vis_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | ############## used to visualize eliminated tokens #################
5 | def get_keep_indices(decisions):
6 | keep_indices = []
7 | for i in range(3):
8 | if i == 0:
9 | keep_indices.append(decisions[i])
10 | else:
11 | keep_indices.append(keep_indices[-1][decisions[i]])
12 | return keep_indices
13 |
14 |
15 | def gen_masked_tokens(tokens, indices, alpha=0.2):
16 | # indices = [i for i in range(196) if i not in indices]
17 | indices = indices[0].astype(int)
18 | tokens = tokens.copy()
19 | tokens[indices] = alpha * tokens[indices] + (1 - alpha) * 255
20 | return tokens
21 |
22 |
23 | def recover_image(tokens, H, W, Hp, Wp, patch_size):
24 | # image: (C, 196, 16, 16)
25 | image = tokens.reshape(Hp, Wp, patch_size, patch_size, 3).swapaxes(1, 2).reshape(H, W, 3)
26 | return image
27 |
28 |
29 | def pad_img(img):
30 | height, width, channels = img.shape
31 | im_bg = np.ones((height, width + 8, channels)) * 255
32 | im_bg[0:height, 0:width, :] = img
33 | return im_bg
34 |
35 |
36 | def gen_visualization(image, mask_indices, patch_size=16):
37 | # image [224, 224, 3]
38 | # mask_indices, list of masked token indices
39 |
40 |     # mask indices from each stage need to be concatenated cumulatively
41 | # mask_indices = mask_indices[::-1]
42 | num_stages = len(mask_indices)
43 | for i in range(1, num_stages):
44 | mask_indices[i] = np.concatenate([mask_indices[i-1], mask_indices[i]], axis=1)
45 |
46 | # keep_indices = get_keep_indices(decisions)
47 | image = np.asarray(image)
48 | H, W, C = image.shape
49 | Hp, Wp = H // patch_size, W // patch_size
50 | image_tokens = image.reshape(Hp, patch_size, Wp, patch_size, 3).swapaxes(1, 2).reshape(Hp * Wp, patch_size, patch_size, 3)
51 |
52 | stages = [
53 | recover_image(gen_masked_tokens(image_tokens, mask_indices[i]), H, W, Hp, Wp, patch_size)
54 | for i in range(num_stages)
55 | ]
56 | imgs = [image] + stages
57 | imgs = [pad_img(img) for img in imgs]
58 | viz = np.concatenate(imgs, axis=1)
59 | return viz
60 |
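
A toy sketch of gen_visualization; the index arrays are made up and must have shape (1, n) per elimination stage:

import numpy as np

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
mask_indices = [np.array([[0, 1, 2]]), np.array([[10, 11]])]  # tokens eliminated per stage
viz = gen_visualization(image, mask_indices, patch_size=16)
print(viz.shape)  # (224, 696, 3): original plus two masked stages, each padded 8 px wide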
--------------------------------------------------------------------------------
/lib/test/utils/transform_got10k.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import shutil
4 | import argparse
5 | import _init_paths
6 | from lib.test.evaluation.environment import env_settings
7 |
8 |
9 | def transform_got10k(tracker_name, cfg_name):
10 | env = env_settings()
11 | result_dir = env.results_path
12 | src_dir = os.path.join(result_dir, "%s/%s/got10k/" % (tracker_name, cfg_name))
13 | dest_dir = os.path.join(result_dir, "%s/%s/got10k_submit/" % (tracker_name, cfg_name))
14 | if not os.path.exists(dest_dir):
15 | os.makedirs(dest_dir)
16 | items = os.listdir(src_dir)
17 | for item in items:
18 | if "all" in item:
19 | continue
20 | src_path = os.path.join(src_dir, item)
21 | if "time" not in item:
22 | seq_name = item.replace(".txt", '')
23 | seq_dir = os.path.join(dest_dir, seq_name)
24 | if not os.path.exists(seq_dir):
25 | os.makedirs(seq_dir)
26 | new_item = item.replace(".txt", '_001.txt')
27 | dest_path = os.path.join(seq_dir, new_item)
28 |                 bbox_arr = np.loadtxt(src_path, dtype=int, delimiter='\t')  # np.int was removed in NumPy >= 1.24
29 | np.savetxt(dest_path, bbox_arr, fmt='%d', delimiter=',')
30 | else:
31 | seq_name = item.replace("_time.txt", '')
32 | seq_dir = os.path.join(dest_dir, seq_name)
33 | if not os.path.exists(seq_dir):
34 | os.makedirs(seq_dir)
35 | dest_path = os.path.join(seq_dir, item)
36 |             shutil.copy(src_path, dest_path)  # portable copy instead of shelling out
37 | # make zip archive
38 | shutil.make_archive(src_dir, "zip", src_dir)
39 | shutil.make_archive(dest_dir, "zip", dest_dir)
40 | # Remove the original files
41 | shutil.rmtree(src_dir)
42 | shutil.rmtree(dest_dir)
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser(description='transform got10k results.')
47 | parser.add_argument('--tracker_name', type=str, help='Name of tracking method.')
48 | parser.add_argument('--cfg_name', type=str, help='Name of config file.')
49 |
50 | args = parser.parse_args()
51 | transform_got10k(args.tracker_name, args.cfg_name)
52 |
53 |
--------------------------------------------------------------------------------
/experiments/ostrack/datr_vitb_256_mae_ce_32x4_ep300.yaml:
--------------------------------------------------------------------------------
1 | DA:
2 | border_prob: 0.05
3 | sfactor: [2.0,6.0]
4 | imix: True
5 | imix_epoch: 1
6 | imix_interval: 11
7 | imix_occrate: 0.5
8 | imix_reverse_prob: 0.0
9 | mix_type: 1
10 | norm_type: 'global'
11 | DATA:
12 | MAX_SAMPLE_INTERVAL: 200
13 | MEAN:
14 | - 0.485
15 | - 0.456
16 | - 0.406
17 | SEARCH:
18 | CENTER_JITTER: 3
19 | FACTOR: 4.0
20 | SCALE_JITTER: 0.25
21 | SIZE: 256
22 | NUMBER: 1
23 | STD:
24 | - 0.229
25 | - 0.224
26 | - 0.225
27 | TEMPLATE:
28 | CENTER_JITTER: 0
29 | FACTOR: 2.0
30 | SCALE_JITTER: 0
31 | SIZE: 128
32 | # TRAIN:
33 | # DATASETS_NAME:
34 | # - GOT10K_train_full
35 | # DATASETS_RATIO:
36 | # - 1
37 | # SAMPLE_PER_EPOCH: 60000
38 |
39 | TRAIN:
40 | DATASETS_NAME:
41 | - LASOT
42 | - GOT10K_vottrain
43 | - COCO17
44 | - TRACKINGNET
45 | DATASETS_RATIO:
46 | - 1
47 | - 1
48 | - 1
49 | - 1
50 | SAMPLE_PER_EPOCH: 60000
51 | VAL:
52 | DATASETS_NAME:
53 | - GOT10K_votval
54 | DATASETS_RATIO:
55 | - 1
56 | SAMPLE_PER_EPOCH: 10000
57 | MODEL:
58 | PRETRAIN_FILE: "mae_pretrain_vit_base.pth"
59 | EXTRA_MERGER: False
60 | RETURN_INTER: False
61 | BACKBONE:
62 | TYPE: vit_base_patch16_224_ce
63 | STRIDE: 16
64 | CE_LOC: [3, 6, 9]
65 | CE_KEEP_RATIO: [0.7, 0.7, 0.7]
66 | CE_TEMPLATE_RANGE: 'CTR_POINT' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX
67 | HEAD:
68 | TYPE: CENTER
69 | NUM_CHANNELS: 256
70 | TRAIN:
71 | BACKBONE_MULTIPLIER: 0.1
72 | DROP_PATH_RATE: 0.1
73 | CE_START_EPOCH: 20 # candidate elimination start epoch
74 | CE_WARM_EPOCH: 80 # candidate elimination warm up epoch
75 | BATCH_SIZE: 32
76 | EPOCH: 300
77 | GIOU_WEIGHT: 2.0
78 | L1_WEIGHT: 5.0
79 | GRAD_CLIP_NORM: 0.1
80 | LR: 0.0004
81 | LR_DROP_EPOCH: 240
82 | NUM_WORKER: 10
83 | OPTIMIZER: ADAMW
84 | PRINT_INTERVAL: 50
85 | SCHEDULER:
86 | TYPE: step
87 | DECAY_RATE: 0.1
88 | VAL_EPOCH_INTERVAL: 20
89 | WEIGHT_DECAY: 0.0001
90 | AMP: False
91 | TEST:
92 | EPOCH: 300
93 | SEARCH_FACTOR: 4.0
94 | SEARCH_SIZE: 256
95 | TEMPLATE_FACTOR: 2.0
96 | TEMPLATE_SIZE: 128
--------------------------------------------------------------------------------
/lib/test/tracker/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from lib.utils.misc import NestedTensor
4 |
5 |
6 | class Preprocessor(object):
7 | def __init__(self):
8 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
9 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
10 |
11 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
12 | # Deal with the image patch
13 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
14 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)
15 | # Deal with the attention mask
16 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W)
17 | return NestedTensor(img_tensor_norm, amask_tensor)
18 |
19 |
20 | class PreprocessorX(object):
21 | def __init__(self):
22 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).cuda()
23 | self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).cuda()
24 |
25 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
26 | # Deal with the image patch
27 | img_tensor = torch.tensor(img_arr).cuda().float().permute((2,0,1)).unsqueeze(dim=0)
28 | img_tensor_norm = ((img_tensor / 255.0) - self.mean) / self.std # (1,3,H,W)
29 | # Deal with the attention mask
30 | amask_tensor = torch.from_numpy(amask_arr).to(torch.bool).cuda().unsqueeze(dim=0) # (1,H,W)
31 | return img_tensor_norm, amask_tensor
32 |
33 |
34 | class PreprocessorX_onnx(object):
35 | def __init__(self):
36 | self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
37 | self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
38 |
39 | def process(self, img_arr: np.ndarray, amask_arr: np.ndarray):
40 | """img_arr: (H,W,3), amask_arr: (H,W)"""
41 | # Deal with the image patch
42 | img_arr_4d = img_arr[np.newaxis, :, :, :].transpose(0, 3, 1, 2)
43 | img_arr_4d = (img_arr_4d / 255.0 - self.mean) / self.std # (1, 3, H, W)
44 | # Deal with the attention mask
45 | amask_arr_3d = amask_arr[np.newaxis, :, :] # (1,H,W)
46 |         return img_arr_4d.astype(np.float32), amask_arr_3d.astype(bool)  # np.bool was removed in NumPy >= 1.24
47 |
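
Usage sketch for Preprocessor (a CUDA device is assumed, as in the class itself; the patch is a toy array):

import numpy as np

pre = Preprocessor()
patch = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # toy search-region crop
amask = np.zeros((256, 256))                                  # 0 = valid pixel
nested = pre.process(patch, amask)  # NestedTensor: (1,3,256,256) image, (1,256,256) mask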
--------------------------------------------------------------------------------
/tracking/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | prj_path = os.path.join(os.path.dirname(__file__), '..')
6 | if prj_path not in sys.path:
7 | sys.path.append(prj_path)
8 |
9 | from lib.test.evaluation import get_dataset
10 | from lib.test.evaluation.running import run_dataset
11 | from lib.test.evaluation.tracker import Tracker
12 |
13 |
14 | def run_tracker(tracker_name, tracker_param, run_id=None, dataset_name='otb', sequence=None, debug=0, threads=0,
15 | num_gpus=8):
16 | """Run tracker on sequence or dataset.
17 | args:
18 | tracker_name: Name of tracking method.
19 | tracker_param: Name of parameter file.
20 | run_id: The run id.
21 | dataset_name: Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot).
22 | sequence: Sequence number or name.
23 | debug: Debug level.
24 | threads: Number of threads.
25 | """
26 |
27 | dataset = get_dataset(dataset_name)
28 |
29 | if sequence is not None:
30 | dataset = [dataset[sequence]]
31 |
32 | trackers = [Tracker(tracker_name, tracker_param, dataset_name, run_id)]
33 |
34 | run_dataset(dataset, trackers, debug, threads, num_gpus=num_gpus)
35 |
36 |
37 | def main():
38 | parser = argparse.ArgumentParser(description='Run tracker on sequence or dataset.')
39 | parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
40 | parser.add_argument('tracker_param', type=str, help='Name of config file.')
41 | parser.add_argument('--runid', type=int, default=None, help='The run id.')
42 | parser.add_argument('--dataset_name', type=str, default='otb', help='Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot).')
43 | parser.add_argument('--sequence', type=str, default=None, help='Sequence number or name.')
44 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
45 | parser.add_argument('--threads', type=int, default=0, help='Number of threads.')
46 | parser.add_argument('--num_gpus', type=int, default=8)
47 |
48 | args = parser.parse_args()
49 |
50 | try:
51 | seq_name = int(args.sequence)
52 |     except (ValueError, TypeError):  # the sequence is a name (or None), not an index
53 | seq_name = args.sequence
54 |
55 | run_tracker(args.tracker_name, args.tracker_param, args.runid, args.dataset_name, seq_name, args.debug,
56 | args.threads, num_gpus=args.num_gpus)
57 |
58 |
59 | if __name__ == '__main__':
60 | main()
61 |
--------------------------------------------------------------------------------
/lib/test/evaluation/got10kdataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
3 | from lib.test.utils.load_text import load_text
4 | import os
5 |
6 |
7 | class GOT10KDataset(BaseDataset):
8 | """ GOT-10k dataset.
9 |
10 | Publication:
11 | GOT-10k: A Large High-Diversity Benchmark for Generic Object Tracking in the Wild
12 | Lianghua Huang, Xin Zhao, and Kaiqi Huang
13 | arXiv:1810.11981, 2018
14 | https://arxiv.org/pdf/1810.11981.pdf
15 |
16 | Download dataset from http://got-10k.aitestunion.com/downloads
17 | """
18 | def __init__(self, split):
19 | super().__init__()
20 | # Split can be test, val, or ltrval (a validation split consisting of videos from the official train set)
21 | if split == 'test' or split == 'val':
22 | self.base_path = os.path.join(self.env_settings.got10k_path, split)
23 | else:
24 | self.base_path = os.path.join(self.env_settings.got10k_path, 'train')
25 |
26 | self.sequence_list = self._get_sequence_list(split)
27 | self.split = split
28 |
29 | def get_sequence_list(self):
30 | return SequenceList([self._construct_sequence(s) for s in self.sequence_list])
31 |
32 | def _construct_sequence(self, sequence_name):
33 | anno_path = '{}/{}/groundtruth.txt'.format(self.base_path, sequence_name)
34 |
35 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64)
36 |
37 | frames_path = '{}/{}'.format(self.base_path, sequence_name)
38 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
39 | frame_list.sort(key=lambda f: int(f[:-4]))
40 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
41 |
42 | return Sequence(sequence_name, frames_list, 'got10k', ground_truth_rect.reshape(-1, 4))
43 |
44 | def __len__(self):
45 | return len(self.sequence_list)
46 |
47 | def _get_sequence_list(self, split):
48 | with open('{}/list.txt'.format(self.base_path)) as f:
49 | sequence_list = f.read().splitlines()
50 |
51 | if split == 'ltrval':
52 | with open('{}/got10k_val_split.txt'.format(self.env_settings.dataspec_path)) as f:
53 | seq_ids = f.read().splitlines()
54 |
55 | sequence_list = [sequence_list[int(x)] for x in seq_ids]
56 | return sequence_list
57 |
--------------------------------------------------------------------------------
/tracking/test_exp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 |
5 | prj_path = os.path.join(os.path.dirname(__file__), '..')
6 | if prj_path not in sys.path:
7 | sys.path.append(prj_path)
8 |
9 | from lib.test.evaluation import get_dataset
10 | from lib.test.evaluation.running import run_dataset
11 | from lib.test.evaluation.tracker import Tracker
12 |
13 |
14 | def run_tracker(tracker_name, tracker_param, run_id=None, dataset_name='otb', sequence=None, debug=0, threads=0,
15 | num_gpus=8):
16 | """Run tracker on sequence or dataset.
17 | args:
18 | tracker_name: Name of tracking method.
19 | tracker_param: Name of parameter file.
20 | run_id: The run id.
21 | dataset_name: Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot).
22 | sequence: Sequence number or name.
23 | debug: Debug level.
24 | threads: Number of threads.
25 | """
26 |
27 | dataset = get_dataset(*dataset_name)
28 |
29 | if sequence is not None:
30 | dataset = [dataset[sequence]]
31 |
32 | trackers = [Tracker(tracker_name, tracker_param, dataset_name, run_id)]
33 |
34 | run_dataset(dataset, trackers, debug, threads, num_gpus=num_gpus)
35 |
36 |
37 | def main():
38 | parser = argparse.ArgumentParser(description='Run tracker on sequence or dataset.')
39 | parser.add_argument('tracker_name', type=str, help='Name of tracking method.')
40 | parser.add_argument('tracker_param', type=str, help='Name of config file.')
41 | parser.add_argument('--runid', type=int, default=None, help='The run id.')
42 | parser.add_argument('--dataset_name', type=str, default='otb', help='Name of dataset (otb, nfs, uav, tpl, vot, tn, gott, gotv, lasot).')
43 | parser.add_argument('--sequence', type=str, default=None, help='Sequence number or name.')
44 | parser.add_argument('--debug', type=int, default=0, help='Debug level.')
45 | parser.add_argument('--threads', type=int, default=0, help='Number of threads.')
46 | parser.add_argument('--num_gpus', type=int, default=8)
47 |
48 | args = parser.parse_args()
49 |
50 | try:
51 | seq_name = int(args.sequence)
52 |     except (ValueError, TypeError):  # the sequence is a name (or None), not an index
53 | seq_name = args.sequence
54 |
55 | args.dataset_name = ['trackingnet', 'got10k_test', 'lasot']
56 |
57 | run_tracker(args.tracker_name, args.tracker_param, args.runid, args.dataset_name, seq_name, args.debug,
58 | args.threads, num_gpus=args.num_gpus)
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/lib/test/evaluation/datasets.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import importlib
3 | from lib.test.evaluation.data import SequenceList
4 |
5 | DatasetInfo = namedtuple('DatasetInfo', ['module', 'class_name', 'kwargs'])
6 |
7 | pt = "lib.test.evaluation.%sdataset" # Useful abbreviations to reduce the clutter
8 |
9 | dataset_dict = dict(
10 | otb=DatasetInfo(module=pt % "otb", class_name="OTBDataset", kwargs=dict()),
11 | nfs=DatasetInfo(module=pt % "nfs", class_name="NFSDataset", kwargs=dict()),
12 | uav=DatasetInfo(module=pt % "uav", class_name="UAVDataset", kwargs=dict()),
13 | tc128=DatasetInfo(module=pt % "tc128", class_name="TC128Dataset", kwargs=dict()),
14 | tc128ce=DatasetInfo(module=pt % "tc128ce", class_name="TC128CEDataset", kwargs=dict()),
15 | trackingnet=DatasetInfo(module=pt % "trackingnet", class_name="TrackingNetDataset", kwargs=dict()),
16 | got10k_test=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='test')),
17 | got10k_val=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='val')),
18 | got10k_ltrval=DatasetInfo(module=pt % "got10k", class_name="GOT10KDataset", kwargs=dict(split='ltrval')),
19 | lasot=DatasetInfo(module=pt % "lasot", class_name="LaSOTDataset", kwargs=dict()),
20 | lasot_lmdb=DatasetInfo(module=pt % "lasot_lmdb", class_name="LaSOTlmdbDataset", kwargs=dict()),
21 |
22 | vot18=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict()),
23 | vot22=DatasetInfo(module=pt % "vot", class_name="VOTDataset", kwargs=dict(year=22)),
24 | itb=DatasetInfo(module=pt % "itb", class_name="ITBDataset", kwargs=dict()),
25 | tnl2k=DatasetInfo(module=pt % "tnl2k", class_name="TNL2kDataset", kwargs=dict()),
26 | lasot_extension_subset=DatasetInfo(module=pt % "lasotextensionsubset", class_name="LaSOTExtensionSubsetDataset",
27 | kwargs=dict()),
28 | )
29 |
30 |
31 | def load_dataset(name: str):
32 | """ Import and load a single dataset."""
33 | name = name.lower()
34 | dset_info = dataset_dict.get(name)
35 | if dset_info is None:
36 | raise ValueError('Unknown dataset \'%s\'' % name)
37 |
38 | m = importlib.import_module(dset_info.module)
39 | dataset = getattr(m, dset_info.class_name)(**dset_info.kwargs) # Call the constructor
40 | return dataset.get_sequence_list()
41 |
42 |
43 | def get_dataset(*args):
44 | """ Get a single or set of datasets."""
45 | dset = SequenceList()
46 | for name in args:
47 | dset.extend(load_dataset(name))
48 | return dset
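
Usage sketch (assuming the dataset paths are configured in the local environment settings): any keys from dataset_dict can be combined into one SequenceList.

dataset = get_dataset('otb', 'nfs', 'uav')  # concatenates the three benchmarks
print(len(dataset))                         # total number of sequences
seq = dataset[0]                            # a Sequence with frame paths and ground truth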
--------------------------------------------------------------------------------
/lib/test/evaluation/trackingnetdataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
3 | import os
4 | from lib.test.utils.load_text import load_text
5 |
6 |
7 | class TrackingNetDataset(BaseDataset):
8 | """ TrackingNet test set.
9 |
10 | Publication:
11 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild.
12 |         Matthias Mueller, Adel Bibi, Silvio Giancola, Salman Al-Subaihi and Bernard Ghanem
13 | ECCV, 2018
14 | https://ivul.kaust.edu.sa/Documents/Publications/2018/TrackingNet%20A%20Large%20Scale%20Dataset%20and%20Benchmark%20for%20Object%20Tracking%20in%20the%20Wild.pdf
15 |
16 | Download the dataset using the toolkit https://github.com/SilvioGiancola/TrackingNet-devkit.
17 | """
18 | def __init__(self):
19 | super().__init__()
20 | self.base_path = self.env_settings.trackingnet_path
21 |
22 | sets = 'TEST'
23 | if not isinstance(sets, (list, tuple)):
24 | if sets == 'TEST':
25 | sets = ['TEST']
26 | elif sets == 'TRAIN':
27 | sets = ['TRAIN_{}'.format(i) for i in range(5)]
28 |
29 | self.sequence_list = self._list_sequences(self.base_path, sets)
30 |
31 | def get_sequence_list(self):
32 | return SequenceList([self._construct_sequence(set, seq_name) for set, seq_name in self.sequence_list])
33 |
34 | def _construct_sequence(self, set, sequence_name):
35 | anno_path = '{}/{}/anno/{}.txt'.format(self.base_path, set, sequence_name)
36 |
37 | ground_truth_rect = load_text(str(anno_path), delimiter=',', dtype=np.float64, backend='numpy')
38 |
39 | frames_path = '{}/{}/frames/{}'.format(self.base_path, set, sequence_name)
40 | frame_list = [frame for frame in os.listdir(frames_path) if frame.endswith(".jpg")]
41 | frame_list.sort(key=lambda f: int(f[:-4]))
42 | frames_list = [os.path.join(frames_path, frame) for frame in frame_list]
43 |
44 | return Sequence(sequence_name, frames_list, 'trackingnet', ground_truth_rect.reshape(-1, 4))
45 |
46 | def __len__(self):
47 | return len(self.sequence_list)
48 |
49 | def _list_sequences(self, root, set_ids):
50 | sequence_list = []
51 |
52 | for s in set_ids:
53 | anno_dir = os.path.join(root, s, "anno")
54 | sequences_cur_set = [(s, os.path.splitext(f)[0]) for f in os.listdir(anno_dir) if f.endswith('.txt')]
55 |
56 | sequence_list += sequences_cur_set
57 |
58 | return sequence_list
59 |
--------------------------------------------------------------------------------
/lib/utils/focal_loss.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class FocalLoss(nn.Module, ABC):
9 | def __init__(self, alpha=2, beta=4):
10 | super(FocalLoss, self).__init__()
11 | self.alpha = alpha
12 | self.beta = beta
13 |
14 | def forward(self, prediction, target):
15 | positive_index = target.eq(1).float()
16 | negative_index = target.lt(1).float()
17 |
18 | negative_weights = torch.pow(1 - target, self.beta)
19 | # clamp min value is set to 1e-12 to maintain the numerical stability
20 | prediction = torch.clamp(prediction, 1e-12)
21 |
22 | positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index
23 | negative_loss = torch.log(1 - prediction) * torch.pow(prediction,
24 | self.alpha) * negative_weights * negative_index
25 |
26 | num_positive = positive_index.float().sum()
27 | positive_loss = positive_loss.sum()
28 | negative_loss = negative_loss.sum()
29 |
30 | if num_positive == 0:
31 | loss = -negative_loss
32 | else:
33 | loss = -(positive_loss + negative_loss) / num_positive
34 |
35 | return loss
36 |
37 |
38 | class LBHinge(nn.Module):
39 | """Loss that uses a 'hinge' on the lower bound.
40 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is
41 | also smaller than that threshold.
42 | args:
43 |         error_metric: What base loss to use (MSE by default).
44 | threshold: Threshold to use for the hinge.
45 | clip: Clip the loss if it is above this value.
46 | """
47 | def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None):
48 | super().__init__()
49 | self.error_metric = error_metric
50 | self.threshold = threshold if threshold is not None else -100
51 | self.clip = clip
52 |
53 | def forward(self, prediction, label, target_bb=None):
54 | negative_mask = (label < self.threshold).float()
55 | positive_mask = (1.0 - negative_mask)
56 |
57 | prediction = negative_mask * F.relu(prediction) + positive_mask * prediction
58 |
59 | loss = self.error_metric(prediction, positive_mask * label)
60 |
61 | if self.clip is not None:
62 | loss = torch.min(loss, torch.tensor([self.clip], device=loss.device))
63 | return loss
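
A toy check of FocalLoss, as a sketch: a prediction heatmap that is confident at the single positive location yields a small loss.

import torch

target = torch.zeros(1, 1, 4, 4)
target[0, 0, 1, 1] = 1.0                 # one positive cell, the rest are negatives
pred = torch.full((1, 1, 4, 4), 0.1)
pred[0, 0, 1, 1] = 0.9                   # confident at the positive
print(FocalLoss()(pred, target).item())  # small value; grows as pred[..., 1, 1] drops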
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | echo "****************** Installing pytorch ******************"
2 | conda install pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=10.2 -c pytorch
3 |
4 | echo ""
5 | echo ""
6 | echo "****************** Installing yaml ******************"
7 | pip install PyYAML
8 |
9 | echo ""
10 | echo ""
11 | echo "****************** Installing easydict ******************"
12 | pip install easydict
13 |
14 | echo ""
15 | echo ""
16 | echo "****************** Installing cython ******************"
17 | pip install cython
18 |
19 | echo ""
20 | echo ""
21 | echo "****************** Installing opencv-python ******************"
22 | pip install opencv-python
23 |
24 | echo ""
25 | echo ""
26 | echo "****************** Installing pandas ******************"
27 | pip install pandas
28 |
29 | echo ""
30 | echo ""
31 | echo "****************** Installing tqdm ******************"
32 | conda install -y tqdm
33 |
34 | echo ""
35 | echo ""
36 | echo "****************** Installing coco toolkit ******************"
37 | pip install pycocotools
38 |
39 | echo ""
40 | echo ""
41 | echo "****************** Installing jpeg4py python wrapper ******************"
42 | pip install jpeg4py
43 |
44 | echo ""
45 | echo ""
46 | echo "****************** Installing tensorboard ******************"
47 | pip install tb-nightly
48 |
49 | echo ""
50 | echo ""
51 | echo "****************** Installing tikzplotlib ******************"
52 | pip install tikzplotlib
53 |
54 | echo ""
55 | echo ""
56 | echo "****************** Installing thop tool for FLOPs and Params computing ******************"
57 | pip install thop==0.0.31.post2005241907
58 |
59 | echo ""
60 | echo ""
61 | echo "****************** Installing colorama ******************"
62 | pip install colorama
63 |
64 | echo ""
65 | echo ""
66 | echo "****************** Installing lmdb ******************"
67 | pip install lmdb
68 |
69 | echo ""
70 | echo ""
71 | echo "****************** Installing scipy ******************"
72 | pip install scipy
73 |
74 | echo ""
75 | echo ""
76 | echo "****************** Installing visdom ******************"
77 | pip install visdom
78 |
79 |
80 | echo ""
81 | echo ""
82 | echo "****************** Installing tensorboardX ******************"
83 | pip install tensorboardX
84 |
85 |
86 | echo ""
87 | echo ""
88 | echo "****************** Downgrade setuptools ******************"
89 | pip install setuptools==59.5.0
90 |
91 |
92 | echo ""
93 | echo ""
94 | echo "****************** Installing wandb ******************"
95 | pip install wandb
96 |
97 | echo ""
98 | echo ""
99 | echo "****************** Installing timm ******************"
100 | pip install timm
101 |
102 | echo ""
103 | echo ""
104 | echo "****************** Installation complete! ******************"
105 |
--------------------------------------------------------------------------------
/lib/train/dataset/base_image_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from lib.train.data.image_loader import jpeg4py_loader
3 |
4 |
5 | class BaseImageDataset(torch.utils.data.Dataset):
6 | """ Base class for image datasets """
7 |
8 | def __init__(self, name, root, image_loader=jpeg4py_loader):
9 | """
10 | args:
11 | root - The root path to the dataset
12 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
13 | is used by default.
14 | """
15 | self.name = name
16 | self.root = root
17 | self.image_loader = image_loader
18 |
19 |         self.image_list = []  # Contains the list of images.
20 | self.class_list = []
21 |
22 | def __len__(self):
23 | """ Returns size of the dataset
24 | returns:
25 | int - number of samples in the dataset
26 | """
27 | return self.get_num_images()
28 |
29 | def __getitem__(self, index):
30 |         """ Not to be used! Check get_image() instead.
31 | """
32 | return None
33 |
34 | def get_name(self):
35 | """ Name of the dataset
36 |
37 | returns:
38 | string - Name of the dataset
39 | """
40 | raise NotImplementedError
41 |
42 | def get_num_images(self):
43 |         """ Number of images in the dataset
44 | 
45 |         returns:
46 |             int - number of images in the dataset."""
47 | return len(self.image_list)
48 |
49 | def has_class_info(self):
50 | return False
51 |
52 | def get_class_name(self, image_id):
53 | return None
54 |
55 | def get_num_classes(self):
56 | return len(self.class_list)
57 |
58 | def get_class_list(self):
59 | return self.class_list
60 |
61 | def get_images_in_class(self, class_name):
62 | raise NotImplementedError
63 |
64 | def has_segmentation_info(self):
65 | return False
66 |
67 | def get_image_info(self, seq_id):
68 | """ Returns information about a particular image,
69 |
70 | args:
71 | seq_id - index of the image
72 |
73 | returns:
74 | Dict
75 | """
76 | raise NotImplementedError
77 |
78 | def get_image(self, image_id, anno=None):
79 |         """ Get an image
80 |
81 | args:
82 | image_id - index of image
83 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
84 |
85 | returns:
86 | image -
87 | anno -
88 | dict - A dict containing meta information about the sequence, e.g. class of the target object.
89 |
90 | """
91 | raise NotImplementedError
92 |
93 |
--------------------------------------------------------------------------------
/tracking/pre_read_datasets.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 | import multiprocessing as mp
3 | import argparse
4 | import os
5 | from lib.utils.lmdb_utils import decode_str
6 | import time
7 | import json
8 |
9 |
10 | def parse_args():
11 | """
12 |     args for pre-reading the lmdb datasets.
13 |     """
14 |     parser = argparse.ArgumentParser(description='Parse args for pre-reading the lmdb datasets')
15 | parser.add_argument('--data_dir', type=str, help='directory where lmdb data is located')
16 | parser.add_argument('--dataset_str', type=str, help="which datasets to use")
17 | args = parser.parse_args()
18 |
19 | return args
20 |
21 |
22 | def get_trknet_dict(trknet_dir):
23 | with open(os.path.join(trknet_dir, "seq_list.json"), "r") as f:
24 | seq_list = json.loads(f.read())
25 | res_dict = {}
26 | set_idx_pre = -1
27 | for set_idx, seq_name in seq_list:
28 | if set_idx != set_idx_pre:
29 | res_dict[set_idx] = "anno/%s.txt" % seq_name
30 | set_idx_pre = set_idx
31 | return res_dict
32 |
33 |
34 | def target(lmdb_dir, key_name):
35 | _ = decode_str(lmdb_dir, key_name)
36 |
37 |
38 | if __name__ == "__main__":
39 | args = parse_args()
40 | data_dir = args.data_dir
41 | dataset_str = args.dataset_str
42 | key_dict = {"got10k_lmdb": "train/list.txt",
43 | "lasot_lmdb": "LaSOTBenchmark.json",
44 | "coco_lmdb": "annotations/instances_train2017.json",
45 | "vid_lmdb": "cache.json"}
46 | print("Ready to pre load datasets")
47 | start = time.time()
48 | ps = []
49 | datasets = []
50 | if 'g' in dataset_str:
51 | datasets.append("got10k_lmdb")
52 | if 'l' in dataset_str:
53 | datasets.append("lasot_lmdb")
54 | if 'c' in dataset_str:
55 | datasets.append("coco_lmdb")
56 | if 'v' in dataset_str:
57 | datasets.append("vid_lmdb")
58 | for dataset in datasets:
59 | lmdb_dir = os.path.join(data_dir, dataset)
60 | p = mp.Process(target=target, args=(lmdb_dir, key_dict[dataset]))
61 | print("add %s %s to job queue" % (lmdb_dir, key_dict[dataset]))
62 | ps.append(p)
63 | # deal with trackingnet
64 | if 't' in dataset_str:
65 | trknet_dict = get_trknet_dict(os.path.join(data_dir, "trackingnet_lmdb"))
66 | for set_idx, seq_path in trknet_dict.items():
67 | lmdb_dir = os.path.join(data_dir, "trackingnet_lmdb", "TRAIN_%d_lmdb" % set_idx)
68 | p = mp.Process(target=target, args=(lmdb_dir, seq_path))
69 | print("add %s %s to job queue" % (lmdb_dir, seq_path))
70 | ps.append(p)
71 | for p in ps:
72 | p.start()
73 | for p in ps:
74 | p.join()
75 |
76 | print("Pre read over")
77 | end = time.time()
78 | hour = (end - start) / 3600
79 | print("it takes %.2f hours to pre-read data" % hour)
80 |
--------------------------------------------------------------------------------
/tracking/analysis_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "outputs": [],
7 | "source": [
8 | "%load_ext autoreload\n",
9 | "%autoreload 2\n",
10 | "%matplotlib inline\n",
11 | "import os\n",
12 | "import sys\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "plt.rcParams['figure.figsize'] = [14, 8]\n",
15 | "\n",
16 | "sys.path.append('/home/yebotao/OSTrack')\n",
17 | "from lib.test.analysis.plot_results import plot_results, print_results, print_per_sequence_results, print_results_per_video\n",
18 | "from lib.test.evaluation import get_dataset, trackerlist"
19 | ],
20 | "metadata": {
21 | "collapsed": false,
22 | "pycharm": {
23 | "name": "#%%\n"
24 | }
25 | }
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "outputs": [],
31 | "source": [
32 | "dataset_name = 'lasot'\n",
33 | "\n",
34 | "trackers = []\n",
35 | "trackers.extend(trackerlist(name='ostrack', parameter_name='vitb_256_mae_ce_32x4_ep300', dataset_name=dataset_name,\n",
36 | " run_ids=None, display_name='OSTrack256'))\n",
37 | "trackers.extend(trackerlist(name='ostrack', parameter_name='vitb_384_mae_ce_32x4_ep300', dataset_name=dataset_name,\n",
38 | " run_ids=None, display_name='OSTrack384'))\n",
39 | "\n",
40 | "dataset = get_dataset(dataset_name)\n",
41 | "# plot_results(trackers, dataset, dataset_name, merge_results=True, plot_types=('success', 'prec'),\n",
42 | "# skip_missing_seq=False, force_evaluation=True, plot_bin_gap=0.05, exclude_invalid_frames=False)\n",
43 | "print_results(trackers, dataset, dataset_name, merge_results=True, plot_types=('success', 'prec', 'norm_prec'))\n",
44 | "# print_results_per_video(trackers, dataset, dataset_name, merge_results=True, plot_types=('success', 'prec', 'norm_prec'),\n",
45 | "# per_video=True, force_evaluation=True)\n",
46 | "# print_per_sequence_results(trackers, dataset, dataset_name, merge_results=True, plot_types=('success', 'prec', 'norm_prec'))"
47 | ],
48 | "metadata": {
49 | "collapsed": false,
50 | "pycharm": {
51 | "name": "#%%\n"
52 | }
53 | }
54 | }
55 | ],
56 | "metadata": {
57 | "kernelspec": {
58 | "display_name": "Python 3",
59 | "language": "python",
60 | "name": "python3"
61 | },
62 | "language_info": {
63 |    "codemirror_mode": {
64 |     "name": "ipython",
65 |     "version": 3
66 |    },
67 |    "file_extension": ".py",
68 |    "mimetype": "text/x-python",
69 |    "name": "python",
70 |    "nbconvert_exporter": "python",
71 |    "pygments_lexer": "ipython3",
72 |    "version": "3.8.13"
73 | }
74 | },
75 | "nbformat": 4,
76 | "nbformat_minor": 0
77 | }
--------------------------------------------------------------------------------
/lib/utils/box_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.ops.boxes import box_area
3 | import numpy as np
4 |
5 |
6 | def box_cxcywh_to_xyxy(x):
7 | x_c, y_c, w, h = x.unbind(-1)
8 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
9 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
10 | return torch.stack(b, dim=-1)
11 |
12 |
13 | def box_xywh_to_xyxy(x):
14 | x1, y1, w, h = x.unbind(-1)
15 | b = [x1, y1, x1 + w, y1 + h]
16 | return torch.stack(b, dim=-1)
17 |
18 |
19 | def box_xyxy_to_xywh(x):
20 | x1, y1, x2, y2 = x.unbind(-1)
21 | b = [x1, y1, x2 - x1, y2 - y1]
22 | return torch.stack(b, dim=-1)
23 |
24 |
25 | def box_xyxy_to_cxcywh(x):
26 | x0, y0, x1, y1 = x.unbind(-1)
27 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
28 | (x1 - x0), (y1 - y0)]
29 | return torch.stack(b, dim=-1)
30 |
31 |
32 | # modified from torchvision to also return the union
33 | '''Note that this function only supports shape (N,4)'''
34 |
35 |
36 | def box_iou(boxes1, boxes2):
37 | """
38 |
39 | :param boxes1: (N, 4) (x1,y1,x2,y2)
40 | :param boxes2: (N, 4) (x1,y1,x2,y2)
41 | :return:
42 | """
43 | area1 = box_area(boxes1) # (N,)
44 | area2 = box_area(boxes2) # (N,)
45 |
46 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # (N,2)
47 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # (N,2)
48 |
49 | wh = (rb - lt).clamp(min=0) # (N,2)
50 | inter = wh[:, 0] * wh[:, 1] # (N,)
51 |
52 | union = area1 + area2 - inter
53 |
54 | iou = inter / union
55 | return iou, union
56 |
57 |
58 | '''Note that this implementation is different from DETR's'''
59 |
60 |
61 | def generalized_box_iou(boxes1, boxes2):
62 | """
63 | Generalized IoU from https://giou.stanford.edu/
64 |
65 | The boxes should be in [x0, y0, x1, y1] format
66 |
67 | boxes1: (N, 4)
68 | boxes2: (N, 4)
69 | """
70 |     # degenerate boxes give inf / nan results,
71 |     # so do an early sanity check that every box
72 |     # has non-negative width and height
73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
75 | iou, union = box_iou(boxes1, boxes2) # (N,)
76 |
77 | lt = torch.min(boxes1[:, :2], boxes2[:, :2])
78 | rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
79 |
80 | wh = (rb - lt).clamp(min=0) # (N,2)
81 | area = wh[:, 0] * wh[:, 1] # (N,)
82 |
83 | return iou - (area - union) / area, iou
84 |
85 |
86 | def giou_loss(boxes1, boxes2):
87 | """
88 |
89 | :param boxes1: (N, 4) (x1,y1,x2,y2)
90 | :param boxes2: (N, 4) (x1,y1,x2,y2)
91 | :return:
92 | """
93 | giou, iou = generalized_box_iou(boxes1, boxes2)
94 | return (1 - giou).mean(), iou
95 |
96 |
97 | def clip_box(box: list, H, W, margin=0):
98 | x1, y1, w, h = box
99 | x2, y2 = x1 + w, y1 + h
100 | x1 = min(max(0, x1), W-margin)
101 | x2 = min(max(margin, x2), W)
102 | y1 = min(max(0, y1), H-margin)
103 | y2 = min(max(margin, y2), H)
104 | w = max(margin, x2-x1)
105 | h = max(margin, y2-y1)
106 | return [x1, y1, w, h]
107 |
--------------------------------------------------------------------------------
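A minimal usage sketch for lib/utils/box_ops.py above (toy boxes; the expected values in the comments follow from the formulas in the file):

# Convert [x, y, w, h] boxes to xyxy and compute the GIoU loss.
import torch
from lib.utils.box_ops import box_xywh_to_xyxy, giou_loss

pred = box_xywh_to_xyxy(torch.tensor([[10., 10., 20., 20.]]))  # -> [[10., 10., 30., 30.]]
gt = box_xywh_to_xyxy(torch.tensor([[12., 12., 20., 20.]]))    # -> [[12., 12., 32., 32.]]
loss, iou = giou_loss(pred, gt)  # GIoU lies in (-1, 1], so the loss lies in [0, 2)
print(loss.item(), iou)          # ~0.336, tensor([0.6807])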
/lib/train/data/bounding_box_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def rect_to_rel(bb, sz_norm=None):
5 | """Convert standard rectangular parametrization of the bounding box [x, y, w, h]
6 | to relative parametrization [cx/sw, cy/sh, log(w), log(h)], where [cx, cy] is the center coordinate.
7 | args:
8 | bb - N x 4 tensor of boxes.
9 | sz_norm - [N] x 2 tensor of value of [sw, sh] (optional). sw=w and sh=h if not given.
10 | """
11 |
12 | c = bb[...,:2] + 0.5 * bb[...,2:]
13 | if sz_norm is None:
14 | c_rel = c / bb[...,2:]
15 | else:
16 | c_rel = c / sz_norm
17 | sz_rel = torch.log(bb[...,2:])
18 | return torch.cat((c_rel, sz_rel), dim=-1)
19 |
20 |
21 | def rel_to_rect(bb, sz_norm=None):
22 | """Inverts the effect of rect_to_rel. See above."""
23 |
24 | sz = torch.exp(bb[...,2:])
25 | if sz_norm is None:
26 | c = bb[...,:2] * sz
27 | else:
28 | c = bb[...,:2] * sz_norm
29 | tl = c - 0.5 * sz
30 | return torch.cat((tl, sz), dim=-1)
31 |
32 |
33 | def masks_to_bboxes(mask, fmt='c'):
34 |
35 | """ Convert a mask tensor to one or more bounding boxes.
36 | Note: This function is a bit new, make sure it does what it says. /Andreas
37 | :param mask: Tensor of masks, shape = (..., H, W)
38 | :param fmt: bbox layout. 'c' => "center + size" or (x_center, y_center, width, height)
39 | 't' => "top left + size" or (x_left, y_top, width, height)
40 | 'v' => "vertices" or (x_left, y_top, x_right, y_bottom)
41 | :return: tensor containing a batch of bounding boxes, shape = (..., 4)
42 | """
43 | batch_shape = mask.shape[:-2]
44 | mask = mask.reshape((-1, *mask.shape[-2:]))
45 | bboxes = []
46 |
47 | for m in mask:
48 | mx = m.sum(dim=-2).nonzero()
49 | my = m.sum(dim=-1).nonzero()
50 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
51 | bboxes.append(bb)
52 |
53 | bboxes = torch.tensor(bboxes, dtype=torch.float32, device=mask.device)
54 | bboxes = bboxes.reshape(batch_shape + (4,))
55 |
56 | if fmt == 'v':
57 | return bboxes
58 |
59 | x1 = bboxes[..., :2]
60 | s = bboxes[..., 2:] - x1 + 1
61 |
62 | if fmt == 'c':
63 | return torch.cat((x1 + 0.5 * s, s), dim=-1)
64 | elif fmt == 't':
65 | return torch.cat((x1, s), dim=-1)
66 |
67 | raise ValueError("Undefined bounding box layout '%s'" % fmt)
68 |
69 |
70 | def masks_to_bboxes_multi(mask, ids, fmt='c'):
71 | assert mask.dim() == 2
72 | bboxes = []
73 |
74 | for id in ids:
75 | mx = (mask == id).sum(dim=-2).nonzero()
76 | my = (mask == id).float().sum(dim=-1).nonzero()
77 | bb = [mx.min(), my.min(), mx.max(), my.max()] if (len(mx) > 0 and len(my) > 0) else [0, 0, 0, 0]
78 |
79 | bb = torch.tensor(bb, dtype=torch.float32, device=mask.device)
80 |
81 | x1 = bb[:2]
82 | s = bb[2:] - x1 + 1
83 |
84 | if fmt == 'v':
85 | pass
86 | elif fmt == 'c':
87 | bb = torch.cat((x1 + 0.5 * s, s), dim=-1)
88 | elif fmt == 't':
89 | bb = torch.cat((x1, s), dim=-1)
90 | else:
91 | raise ValueError("Undefined bounding box layout '%s'" % fmt)
92 | bboxes.append(bb)
93 |
94 | return bboxes
95 |
--------------------------------------------------------------------------------
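A minimal sketch for lib/train/data/bounding_box_utils.py above: rect_to_rel and rel_to_rect are exact inverses of each other.

import torch
from lib.train.data.bounding_box_utils import rect_to_rel, rel_to_rect

bb = torch.tensor([[10., 20., 30., 40.]])  # [x, y, w, h]
rel = rect_to_rel(bb)                      # [cx/w, cy/h, log(w), log(h)]
assert torch.allclose(rel_to_rect(rel), bb)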
/lib/test/evaluation/itbdataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.test.evaluation.data import Sequence, BaseDataset, SequenceList
3 | from lib.test.utils.load_text import load_text
4 | import os
5 |
6 |
7 | class ITBDataset(BaseDataset):
8 |     """ ITB dataset
9 |     """
10 |
11 | def __init__(self):
12 | super().__init__()
13 | self.base_path = self.env_settings.itb_path
14 | self.sequence_info_list = self._get_sequence_info_list(self.base_path)
15 |
16 | def get_sequence_list(self):
17 | return SequenceList([self._construct_sequence(s) for s in self.sequence_info_list])
18 |
19 | def _construct_sequence(self, sequence_info):
20 | sequence_path = sequence_info['path']
21 | nz = sequence_info['nz']
22 | ext = sequence_info['ext']
23 | start_frame = sequence_info['startFrame']
24 | end_frame = sequence_info['endFrame']
25 |
26 | init_omit = 0
27 | if 'initOmit' in sequence_info:
28 | init_omit = sequence_info['initOmit']
29 |
30 | frames = ['{base_path}/{sequence_path}/{frame:0{nz}}.{ext}'.format(base_path=self.base_path,
31 | sequence_path=sequence_path, frame=frame_num,
32 | nz=nz, ext=ext) for frame_num in
33 | range(start_frame + init_omit, end_frame + 1)]
34 |
35 | anno_path = '{}/{}'.format(self.base_path, sequence_info['anno_path'])
36 |
37 |         # NOTE: some annotation files are malformed and cannot be parsed by pandas
38 | ground_truth_rect = load_text(str(anno_path), delimiter=(',', None), dtype=np.float64, backend='numpy')
39 | return Sequence(sequence_info['name'], frames, 'otb', ground_truth_rect[init_omit:, :],
40 | object_class=sequence_info['object_class'])
41 |
42 | def __len__(self):
43 | return len(self.sequence_info_list)
44 |
45 | def get_fileNames(self, rootdir):
46 | fs = []
47 | fs_all = []
48 | for root, dirs, files in os.walk(rootdir, topdown=True):
49 | files.sort()
50 | files.sort(key=len)
51 | if files is not None:
52 | for name in files:
53 | _, ending = os.path.splitext(name)
54 | if ending == ".jpg":
55 | _, root_ = os.path.split(root)
56 | fs.append(os.path.join(root_, name))
57 | fs_all.append(os.path.join(root, name))
58 |
59 | return fs_all, fs
60 |
61 | def _get_sequence_info_list(self, base_path):
62 | sequence_info_list = []
63 | for scene in os.listdir(base_path):
64 | if '.' in scene:
65 | continue
66 | videos = os.listdir(os.path.join(base_path, scene))
67 | for video in videos:
68 | _, fs = self.get_fileNames(os.path.join(base_path, scene, video))
69 | video_tmp = {"name": video, "path": scene + '/' + video, "startFrame": 1, "endFrame": len(fs),
70 | "nz": len(fs[0].split('/')[-1].split('.')[0]), "ext": "jpg",
71 | "anno_path": scene + '/' + video + "/groundtruth.txt",
72 | "object_class": "unknown"}
73 | sequence_info_list.append(video_tmp)
74 |
75 |         return sequence_info_list
76 |
--------------------------------------------------------------------------------
/lib/utils/ce_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def generate_bbox_mask(bbox_mask, bbox):
8 | b, h, w = bbox_mask.shape
9 | for i in range(b):
10 | bbox_i = bbox[i].cpu().tolist()
11 | bbox_mask[i, int(bbox_i[1]):int(bbox_i[1] + bbox_i[3] - 1), int(bbox_i[0]):int(bbox_i[0] + bbox_i[2] - 1)] = 1
12 | return bbox_mask
13 |
14 |
15 | def generate_mask_cond(cfg, bs, device, gt_bbox):
16 | template_size = cfg.DATA.TEMPLATE.SIZE
17 | stride = cfg.MODEL.BACKBONE.STRIDE
18 | template_feat_size = template_size // stride
19 |
20 | if cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'ALL':
21 | box_mask_z = None
22 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_POINT':
23 | if template_feat_size == 8:
24 | index = slice(3, 4)
25 | elif template_feat_size == 12:
26 | index = slice(5, 6)
27 | elif template_feat_size == 7:
28 | index = slice(3, 4)
29 | elif template_feat_size == 14:
30 | index = slice(6, 7)
31 | else:
32 | raise NotImplementedError
33 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
34 | box_mask_z[:, index, index] = 1
35 | box_mask_z = box_mask_z.flatten(1).to(torch.bool)
36 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'CTR_REC':
37 |         # use a fixed 2x2 central region: 3:5 for 8x8 feature maps,
38 |         # 5:7 for 12x12 feature maps
39 | if template_feat_size == 8:
40 | index = slice(3, 5)
41 | elif template_feat_size == 12:
42 | index = slice(5, 7)
43 | elif template_feat_size == 7:
44 | index = slice(3, 4)
45 | else:
46 | raise NotImplementedError
47 | box_mask_z = torch.zeros([bs, template_feat_size, template_feat_size], device=device)
48 | box_mask_z[:, index, index] = 1
49 | box_mask_z = box_mask_z.flatten(1).to(torch.bool)
50 |
51 | elif cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE == 'GT_BOX':
52 | box_mask_z = torch.zeros([bs, template_size, template_size], device=device)
53 | # box_mask_z_ori = data['template_seg'][0].view(-1, 1, *data['template_seg'].shape[2:]) # (batch, 1, 128, 128)
54 | box_mask_z = generate_bbox_mask(box_mask_z, gt_bbox * template_size).unsqueeze(1).to(
55 | torch.float) # (batch, 1, 128, 128)
56 | # box_mask_z_vis = box_mask_z.cpu().numpy()
57 | box_mask_z = F.interpolate(box_mask_z, scale_factor=1. / cfg.MODEL.BACKBONE.STRIDE, mode='bilinear',
58 | align_corners=False)
59 | box_mask_z = box_mask_z.flatten(1).to(torch.bool)
60 | # box_mask_z_vis = box_mask_z[:, 0, ...].cpu().numpy()
61 | # gaussian_maps_vis = generate_heatmap(data['template_anno'], self.cfg.DATA.TEMPLATE.SIZE, self.cfg.MODEL.STRIDE)[0].cpu().numpy()
62 | else:
63 | raise NotImplementedError
64 |
65 | return box_mask_z
66 |
67 |
68 | def adjust_keep_rate(epoch, warmup_epochs, total_epochs, ITERS_PER_EPOCH, base_keep_rate=0.5, max_keep_rate=1, iters=-1):
69 | if epoch < warmup_epochs:
70 | return 1
71 | if epoch >= total_epochs:
72 | return base_keep_rate
73 | if iters == -1:
74 | iters = epoch * ITERS_PER_EPOCH
75 | total_iters = ITERS_PER_EPOCH * (total_epochs - warmup_epochs)
76 | iters = iters - ITERS_PER_EPOCH * warmup_epochs
77 | keep_rate = base_keep_rate + (max_keep_rate - base_keep_rate) \
78 | * (math.cos(iters / total_iters * math.pi) + 1) * 0.5
79 |
80 | return keep_rate
81 |
--------------------------------------------------------------------------------
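A minimal sketch of the schedule implemented by adjust_keep_rate above: the keep rate stays at 1 during warmup, then follows a half-cosine from 1 down to base_keep_rate over the remaining epochs.

from lib.utils.ce_utils import adjust_keep_rate

for epoch in (0, 3, 4, 10, 30):
    rate = adjust_keep_rate(epoch, warmup_epochs=4, total_epochs=20,
                            ITERS_PER_EPOCH=100, base_keep_rate=0.7)
    print(epoch, rate)  # 1, 1, 1.0, ~0.907, 0.7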
/lib/train/dataset/base_video_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | # 2021.1.5 use jpeg4py_loader_w_failsafe as default
3 | from lib.train.data.image_loader import jpeg4py_loader_w_failsafe
4 |
5 |
6 | class BaseVideoDataset(torch.utils.data.Dataset):
7 | """ Base class for video datasets """
8 |
9 | def __init__(self, name, root, image_loader=jpeg4py_loader_w_failsafe):
10 | """
11 | args:
12 | root - The root path to the dataset
13 | image_loader (jpeg4py_loader) - The function to read the images. jpeg4py (https://github.com/ajkxyz/jpeg4py)
14 | is used by default.
15 | """
16 | self.name = name
17 | self.root = root
18 | self.image_loader = image_loader
19 |
20 | self.sequence_list = [] # Contains the list of sequences.
21 | self.class_list = []
22 |
23 | def __len__(self):
24 | """ Returns size of the dataset
25 | returns:
26 | int - number of samples in the dataset
27 | """
28 | return self.get_num_sequences()
29 |
30 | def __getitem__(self, index):
31 | """ Not to be used! Check get_frames() instead.
32 | """
33 | return None
34 |
35 | def is_video_sequence(self):
36 | """ Returns whether the dataset is a video dataset or an image dataset
37 |
38 | returns:
39 | bool - True if a video dataset
40 | """
41 | return True
42 |
43 | def is_synthetic_video_dataset(self):
44 |         """ Returns whether the dataset contains synthetic rather than real videos
45 |
46 |         returns:
47 |             bool - True if the dataset is synthetic
48 | """
49 | return False
50 |
51 | def get_name(self):
52 | """ Name of the dataset
53 |
54 | returns:
55 | string - Name of the dataset
56 | """
57 | raise NotImplementedError
58 |
59 | def get_num_sequences(self):
60 | """ Number of sequences in a dataset
61 |
62 | returns:
63 | int - number of sequences in the dataset."""
64 | return len(self.sequence_list)
65 |
66 | def has_class_info(self):
67 | return False
68 |
69 | def has_occlusion_info(self):
70 | return False
71 |
72 | def get_num_classes(self):
73 | return len(self.class_list)
74 |
75 | def get_class_list(self):
76 | return self.class_list
77 |
78 | def get_sequences_in_class(self, class_name):
79 | raise NotImplementedError
80 |
81 | def has_segmentation_info(self):
82 | return False
83 |
84 | def get_sequence_info(self, seq_id):
85 |         """ Returns information about a particular sequence.
86 |
87 | args:
88 | seq_id - index of the sequence
89 |
90 | returns:
91 | Dict
92 | """
93 | raise NotImplementedError
94 |
95 | def get_frames(self, seq_id, frame_ids, anno=None):
96 | """ Get a set of frames from a particular sequence
97 |
98 | args:
99 | seq_id - index of sequence
100 | frame_ids - a list of frame numbers
101 | anno(None) - The annotation for the sequence (see get_sequence_info). If None, they will be loaded.
102 |
103 | returns:
104 | list - List of frames corresponding to frame_ids
105 | list - List of dicts for each frame
106 | dict - A dict containing meta information about the sequence, e.g. class of the target object.
107 |
108 | """
109 | raise NotImplementedError
110 |
111 |
--------------------------------------------------------------------------------
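A minimal sketch of the contract a concrete subclass fulfils (a hypothetical in-memory toy dataset; real subclasses such as ImagenetVID_lmdb below follow the same pattern with actual image decoding):

import torch
from collections import OrderedDict
from lib.train.dataset.base_video_dataset import BaseVideoDataset

class ToyVideoDataset(BaseVideoDataset):
    def __init__(self):
        super().__init__('toy', root='', image_loader=lambda path: None)
        self.sequence_list = [{'anno': [[0., 0., 10., 10.]] * 5}]  # one 5-frame sequence

    def get_name(self):
        return 'toy'

    def get_sequence_info(self, seq_id):
        bbox = torch.Tensor(self.sequence_list[seq_id]['anno'])
        valid = (bbox[:, 2] > 0) & (bbox[:, 3] > 0)
        return {'bbox': bbox, 'valid': valid, 'visible': valid.byte()}

    def get_frames(self, seq_id, frame_ids, anno=None):
        if anno is None:
            anno = self.get_sequence_info(seq_id)
        frames = [None for _ in frame_ids]  # a real dataset decodes images here
        anno_frames = {k: [v[f_id, ...].clone() for f_id in frame_ids] for k, v in anno.items()}
        return frames, anno_frames, OrderedDict({'object_class': None})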
/lib/train/data/image_loader.py:
--------------------------------------------------------------------------------
1 | import jpeg4py
2 | import cv2 as cv
3 | from PIL import Image
4 | import numpy as np
5 |
6 | davis_palette = np.repeat(np.expand_dims(np.arange(0,256), 1), 3, 1).astype(np.uint8)
7 | davis_palette[:22, :] = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
8 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
9 | [64, 0, 0], [191, 0, 0], [64, 128, 0], [191, 128, 0],
10 | [64, 0, 128], [191, 0, 128], [64, 128, 128], [191, 128, 128],
11 | [0, 64, 0], [128, 64, 0], [0, 191, 0], [128, 191, 0],
12 | [0, 64, 128], [128, 64, 128]]
13 |
14 |
15 | def default_image_loader(path):
16 | """The default image loader, reads the image from the given path. It first tries to use the jpeg4py_loader,
17 | but reverts to the opencv_loader if the former is not available."""
18 | if default_image_loader.use_jpeg4py is None:
19 | # Try using jpeg4py
20 | im = jpeg4py_loader(path)
21 | if im is None:
22 | default_image_loader.use_jpeg4py = False
23 | print('Using opencv_loader instead.')
24 | else:
25 | default_image_loader.use_jpeg4py = True
26 | return im
27 | if default_image_loader.use_jpeg4py:
28 | return jpeg4py_loader(path)
29 | return opencv_loader(path)
30 |
31 | default_image_loader.use_jpeg4py = None
32 |
33 |
34 | def jpeg4py_loader(path):
35 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
36 | try:
37 | return jpeg4py.JPEG(path).decode()
38 | except Exception as e:
39 | print('ERROR: Could not read image "{}"'.format(path))
40 | print(e)
41 | return None
42 |
43 |
44 | def opencv_loader(path):
45 |     """ Read an image using OpenCV's imread function and return it in RGB format."""
46 | try:
47 | im = cv.imread(path, cv.IMREAD_COLOR)
48 |
49 | # convert to rgb and return
50 | return cv.cvtColor(im, cv.COLOR_BGR2RGB)
51 | except Exception as e:
52 | print('ERROR: Could not read image "{}"'.format(path))
53 | print(e)
54 | return None
55 |
56 |
57 | def jpeg4py_loader_w_failsafe(path):
58 | """ Image reading using jpeg4py https://github.com/ajkxyz/jpeg4py"""
59 | try:
60 | return jpeg4py.JPEG(path).decode()
61 |     except Exception:
62 | try:
63 | im = cv.imread(path, cv.IMREAD_COLOR)
64 |
65 | # convert to rgb and return
66 | return cv.cvtColor(im, cv.COLOR_BGR2RGB)
67 | except Exception as e:
68 | print('ERROR: Could not read image "{}"'.format(path))
69 | print(e)
70 | return None
71 |
72 |
73 | def opencv_seg_loader(path):
74 | """ Read segmentation annotation using opencv's imread function"""
75 | try:
76 | return cv.imread(path)
77 | except Exception as e:
78 | print('ERROR: Could not read image "{}"'.format(path))
79 | print(e)
80 | return None
81 |
82 |
83 | def imread_indexed(filename):
84 | """ Load indexed image with given filename. Used to read segmentation annotations."""
85 |
86 | im = Image.open(filename)
87 |
88 | annotation = np.atleast_3d(im)[...,0]
89 | return annotation
90 |
91 |
92 | def imwrite_indexed(filename, array, color_palette=None):
93 | """ Save indexed image as png. Used to save segmentation annotation."""
94 |
95 | if color_palette is None:
96 | color_palette = davis_palette
97 |
98 | if np.atleast_3d(array).shape[2] != 1:
99 | raise Exception("Saving indexed PNGs requires 2D array.")
100 |
101 | im = Image.fromarray(array)
102 | im.putpalette(color_palette.ravel())
103 | im.save(filename, format='PNG')
--------------------------------------------------------------------------------
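A minimal sketch for the indexed-PNG helpers above: write a toy mask with the DAVIS palette and read it back (the temporary path is illustrative):

import os
import tempfile
import numpy as np
from lib.train.data.image_loader import imwrite_indexed, imread_indexed

mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:6, 2:6] = 1  # object id 1 on background 0
path = os.path.join(tempfile.mkdtemp(), 'mask.png')
imwrite_indexed(path, mask)                  # saved as an indexed PNG
assert (imread_indexed(path) == mask).all()  # indices survive the round trip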
/lib/train/train_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | # loss function related
3 | from lib.utils.box_ops import giou_loss
4 | from torch.nn.functional import l1_loss
5 | from torch.nn import BCEWithLogitsLoss
6 | # train pipeline related
7 | from lib.train.trainers import LTRTrainer
8 | # distributed training related
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | # some more advanced functions
11 | from .base_functions import *
12 | # network related
13 | from lib.models.ostrack import build_ostrack
14 | # forward propagation related
15 | from lib.train.actors import OSTrackActor
16 | # for import modules
17 | import importlib
18 |
19 | from ..utils.focal_loss import FocalLoss
20 |
21 |
22 | def run(settings):
23 |     settings.description = 'Training script for OSTrack'
24 |
25 | # update the default configs with config file
26 | if not os.path.exists(settings.cfg_file):
27 | raise ValueError("%s doesn't exist." % settings.cfg_file)
28 | config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
29 | cfg = config_module.cfg
30 | config_module.update_config_from_file(settings.cfg_file)
31 | if settings.local_rank in [-1, 0]:
32 | print("New configuration is shown below.")
33 | for key in cfg.keys():
34 | print("%s configuration:" % key, cfg[key])
35 | print('\n')
36 |
37 | # update settings based on cfg
38 | update_settings(settings, cfg)
39 |
40 | # Record the training log
41 | log_dir = os.path.join(settings.save_dir, 'logs')
42 | if settings.local_rank in [-1, 0]:
43 | if not os.path.exists(log_dir):
44 | os.makedirs(log_dir)
45 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
46 |
47 | # Build dataloaders
48 | loader_train, loader_val = build_dataloaders(cfg, settings)
49 |
50 | if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE:
51 | cfg.ckpt_dir = settings.save_dir
52 |
53 | # Create network
54 | if settings.script_name == "ostrack":
55 | net = build_ostrack(cfg)
56 | else:
57 | raise ValueError("illegal script name")
58 |
59 | # wrap networks to distributed one
60 | net.cuda()
61 | if settings.local_rank != -1:
62 | # net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter
63 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
64 | settings.device = torch.device("cuda:%d" % settings.local_rank)
65 | else:
66 | settings.device = torch.device("cuda:0")
67 | settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
68 | settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
69 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
70 | # Loss functions and Actors
71 | if settings.script_name == "ostrack":
72 | focal_loss = FocalLoss()
73 | objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()}
74 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0}
75 | actor = OSTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg)
76 | else:
77 | raise ValueError("illegal script name")
78 |
79 | # if cfg.TRAIN.DEEP_SUPERVISION:
80 | # raise ValueError("Deep supervision is not supported now.")
81 |
82 | # Optimizer, parameters, and learning rates
83 | optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
84 | use_amp = getattr(cfg.TRAIN, "AMP", False)
85 | trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
86 |
87 | # train process
88 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
89 |
--------------------------------------------------------------------------------
/lib/test/tracker/basetracker.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | from collections import OrderedDict
5 |
6 | from lib.train.data.processing_utils import transform_image_to_crop
7 | from lib.vis.visdom_cus import Visdom
8 |
9 |
10 | class BaseTracker:
11 | """Base class for all trackers."""
12 |
13 | def __init__(self, params):
14 | self.params = params
15 | self.visdom = None
16 |
17 | def predicts_segmentation_mask(self):
18 | return False
19 |
20 | def initialize(self, image, info: dict) -> dict:
21 | """Overload this function in your tracker. This should initialize the model."""
22 | raise NotImplementedError
23 |
24 | def track(self, image, info: dict = None) -> dict:
25 | """Overload this function in your tracker. This should track in the frame and update the model."""
26 | raise NotImplementedError
27 |
28 | def visdom_draw_tracking(self, image, box, segmentation=None):
29 | if isinstance(box, OrderedDict):
30 | box = [v for k, v in box.items()]
31 | else:
32 | box = (box,)
33 | if segmentation is None:
34 | self.visdom.register((image, *box), 'Tracking', 1, 'Tracking')
35 | else:
36 | self.visdom.register((image, *box, segmentation), 'Tracking', 1, 'Tracking')
37 |
38 | def transform_bbox_to_crop(self, box_in, resize_factor, device, box_extract=None, crop_type='template'):
39 | # box_in: list [x1, y1, w, h], not normalized
40 | # box_extract: same as box_in
41 |         # out bbox: torch.Tensor [1, 1, 4], x1y1wh, normalized
42 | if crop_type == 'template':
43 | crop_sz = torch.Tensor([self.params.template_size, self.params.template_size])
44 | elif crop_type == 'search':
45 | crop_sz = torch.Tensor([self.params.search_size, self.params.search_size])
46 | else:
47 | raise NotImplementedError
48 |
49 | box_in = torch.tensor(box_in)
50 | if box_extract is None:
51 | box_extract = box_in
52 | else:
53 | box_extract = torch.tensor(box_extract)
54 | template_bbox = transform_image_to_crop(box_in, box_extract, resize_factor, crop_sz, normalize=True)
55 | template_bbox = template_bbox.view(1, 1, 4).to(device)
56 |
57 | return template_bbox
58 |
59 | def _init_visdom(self, visdom_info, debug):
60 | visdom_info = {} if visdom_info is None else visdom_info
61 | self.pause_mode = False
62 | self.step = False
63 | self.next_seq = False
64 | if debug > 0 and visdom_info.get('use_visdom', True):
65 | try:
66 | self.visdom = Visdom(debug, {'handler': self._visdom_ui_handler, 'win_id': 'Tracking'},
67 | visdom_info=visdom_info)
68 |
69 | # # Show help
70 | # help_text = 'You can pause/unpause the tracker by pressing ''space'' with the ''Tracking'' window ' \
71 | # 'selected. During paused mode, you can track for one frame by pressing the right arrow key.' \
72 | # 'To enable/disable plotting of a data block, tick/untick the corresponding entry in ' \
73 | # 'block list.'
74 | # self.visdom.register(help_text, 'text', 1, 'Help')
75 |             except Exception:
76 | time.sleep(0.5)
77 | print('!!! WARNING: Visdom could not start, so using matplotlib visualization instead !!!\n'
78 | '!!! Start Visdom in a separate terminal window by typing \'visdom\' !!!')
79 |
80 | def _visdom_ui_handler(self, data):
81 | if data['event_type'] == 'KeyPress':
82 | if data['key'] == ' ':
83 | self.pause_mode = not self.pause_mode
84 |
85 | elif data['key'] == 'ArrowRight' and self.pause_mode:
86 | self.step = True
87 |
88 | elif data['key'] == 'n':
89 | self.next_seq = True
90 |
--------------------------------------------------------------------------------
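A minimal sketch of a BaseTracker subclass showing the two methods every tracker overloads (the 'init_bbox'/'target_bbox' dict keys are the pytracking-style convention assumed here, not defined in this file):

from lib.test.tracker.basetracker import BaseTracker

class StaticTracker(BaseTracker):
    """Toy tracker that always reports the initial box."""

    def initialize(self, image, info: dict) -> dict:
        self.state = info['init_bbox']  # [x, y, w, h] of the target in frame 0
        return {}

    def track(self, image, info: dict = None) -> dict:
        # A real tracker crops a search region around self.state, runs the
        # network, and updates self.state before returning it.
        return {'target_bbox': self.state}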
/lib/train/train_script_datr.py:
--------------------------------------------------------------------------------
1 | import os
2 | # loss function related
3 | from lib.utils.box_ops import giou_loss
4 | from torch.nn.functional import l1_loss
5 | from torch.nn import BCEWithLogitsLoss
6 | # train pipeline related
7 | from lib.train.trainers import LTRTrainer
8 | # distributed training related
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | # some more advanced functions
11 | from .base_functions import *
12 | # network related
13 | from lib.models.ostrack import build_ostrack
14 | # forward propagation related
15 | from lib.train.actors import OSTrackActor
16 | # for import modules
17 | import importlib
18 |
19 | from ..utils.focal_loss import FocalLoss
20 |
21 |
22 | def run(settings):
23 |     settings.description = 'Training script for OSTrack with the DATR data pipeline'
24 |
25 | # update the default configs with config file
26 | if not os.path.exists(settings.cfg_file):
27 | raise ValueError("%s doesn't exist." % settings.cfg_file)
28 | config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
29 | cfg = config_module.cfg
30 | config_module.update_config_from_file(settings.cfg_file)
31 | if settings.local_rank in [-1, 0]:
32 | print("New configuration is shown below.")
33 | for key in cfg.keys():
34 | print("%s configuration:" % key, cfg[key])
35 | print('\n')
36 |
37 | # update settings based on cfg
38 | update_settings(settings, cfg)
39 |
40 | # Record the training log
41 | log_dir = os.path.join(settings.save_dir, 'logs')
42 | if settings.local_rank in [-1, 0]:
43 | if not os.path.exists(log_dir):
44 | os.makedirs(log_dir)
45 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
46 |
47 | # Build dataloaders
48 | loader_train, loader_val = build_dataloaders_datr(cfg, settings)
49 |
50 | if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE or "LightTrack" in cfg.MODEL.BACKBONE.TYPE:
51 | cfg.ckpt_dir = settings.save_dir
52 |
53 | # Create network
54 | if settings.script_name == "ostrack":
55 | net = build_ostrack(cfg)
56 | else:
57 | raise ValueError("illegal script name")
58 |
59 | # wrap networks to distributed one
60 | net.cuda()
61 | if settings.local_rank != -1:
62 | # net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) # add syncBN converter
63 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
64 | settings.device = torch.device("cuda:%d" % settings.local_rank)
65 | else:
66 | settings.device = torch.device("cuda:0")
67 | settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
68 | settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
69 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
70 | # Loss functions and Actors
71 | if settings.script_name == "ostrack":
72 | focal_loss = FocalLoss()
73 | objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()}
74 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0}
75 | actor = OSTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg)
76 | else:
77 | raise ValueError("illegal script name")
78 |
79 | # if cfg.TRAIN.DEEP_SUPERVISION:
80 | # raise ValueError("Deep supervision is not supported now.")
81 |
82 | # Optimizer, parameters, and learning rates
83 | optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
84 | use_amp = getattr(cfg.TRAIN, "AMP", False)
85 | trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler=lr_scheduler,
86 |                          imix=cfg.DA.imix, imix_epoch=cfg.DA.imix_epoch, imix_interval=cfg.DA.imix_interval, use_amp=use_amp)
87 |
88 | # train process
89 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
90 |
--------------------------------------------------------------------------------
/ostrack_cuda113_env.yaml:
--------------------------------------------------------------------------------
1 | name: ostrack
2 | channels:
3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=4.5=1_gnu
10 | - blas=1.0=mkl
11 | - bzip2=1.0.8=h7f98852_4
12 | - ca-certificates=2022.3.29=h06a4308_0
13 | - certifi=2021.10.8=py38h578d9bd_2
14 | - colorama=0.4.4=pyhd3eb1b0_0
15 | - cudatoolkit=11.3.1=h2bc3f7f_2
16 | - ffmpeg=4.3=hf484d3e_0
17 | - freetype=2.10.4=h0708190_1
18 | - gmp=6.2.1=h58526e2_0
19 | - gnutls=3.6.13=h85f3911_1
20 | - intel-openmp=2021.4.0=h06a4308_3561
21 | - jpeg=9d=h7f8727e_0
22 | - lame=3.100=h7f98852_1001
23 | - ld_impl_linux-64=2.35.1=h7274673_9
24 | - libffi=3.3=he6710b0_2
25 | - libgcc-ng=9.3.0=h5101ec6_17
26 | - libgomp=9.3.0=h5101ec6_17
27 | - libiconv=1.16=h516909a_0
28 | - libpng=1.6.37=h21135ba_2
29 | - libstdcxx-ng=9.3.0=hd4cf53a_17
30 | - libtiff=4.0.10=hc3755c2_1005
31 | - libuv=1.40.0=h7b6447c_0
32 | - lz4-c=1.9.3=h9c3ff4c_1
33 | - mkl=2021.4.0=h06a4308_640
34 | - mkl-service=2.4.0=py38h497a2fe_0
35 | - mkl_fft=1.3.1=py38hd3c417c_0
36 | - mkl_random=1.2.2=py38h1abd341_0
37 | - ncurses=6.3=h7f8727e_2
38 | - nettle=3.6=he412f7d_0
39 | - numpy=1.21.2=py38h20f2e39_0
40 | - numpy-base=1.21.2=py38h79a1101_0
41 | - olefile=0.46=pyh9f0ad1d_1
42 | - openh264=2.1.1=h780b84a_0
43 | - openssl=1.1.1n=h7f8727e_0
44 | - pillow=6.2.1=py38h6b7be26_0
45 | - pip=21.2.4=py38h06a4308_0
46 | - python=3.8.13=h12debd9_0
47 | - python_abi=3.8=2_cp38
48 | - pytorch=1.10.0=py3.8_cuda11.3_cudnn8.2.0_0
49 | - pytorch-mutex=1.0=cuda
50 | - readline=8.1.2=h7f8727e_1
51 | - setuptools=58.0.4=py38h06a4308_0
52 | - six=1.16.0=pyh6c4a22f_0
53 | - sqlite=3.38.2=hc218d9a_0
54 | - tk=8.6.11=h1ccaba5_0
55 | - torchaudio=0.10.0=py38_cu113
56 | - torchvision=0.11.0=py38_cu113
57 | - tqdm=4.63.0=pyhd3eb1b0_0
58 | - typing_extensions=4.1.1=pyha770c72_0
59 | - wheel=0.37.1=pyhd3eb1b0_0
60 | - xz=5.2.5=h7b6447c_0
61 | - zlib=1.2.11=h7f8727e_4
62 | - zstd=1.4.9=ha95c52a_0
63 | - pip:
64 | - absl-py==1.0.0
65 | - cachetools==5.0.0
66 | - cffi==1.15.0
67 | - charset-normalizer==2.0.12
68 | - click==8.1.2
69 | - cycler==0.11.0
70 | - cython==0.29.28
71 | - docker-pycreds==0.4.0
72 | - easydict==1.9
73 | - fonttools==4.31.2
74 | - gitdb==4.0.9
75 | - gitpython==3.1.27
76 | - google-auth==2.6.2
77 | - google-auth-oauthlib==0.4.6
78 | - grpcio==1.45.0
79 | - idna==3.3
80 | - importlib-metadata==4.11.3
81 | - jpeg4py==0.1.4
82 | - jsonpatch==1.32
83 | - jsonpointer==2.2
84 | - kiwisolver==1.4.2
85 | - lmdb==1.3.0
86 | - markdown==3.3.6
87 | - matplotlib==3.5.1
88 | - oauthlib==3.2.0
89 | - opencv-python==4.5.5.64
90 | - packaging==21.3
91 | - pandas==1.4.2
92 | - pathtools==0.1.2
93 | - promise==2.3
94 | - protobuf==3.20.0
95 | - psutil==5.9.0
96 | - pyasn1==0.4.8
97 | - pyasn1-modules==0.2.8
98 | - pycocotools==2.0.4
99 | - pycparser==2.21
100 | - pyparsing==3.0.7
101 | - python-dateutil==2.8.2
102 | - pytz==2022.1
103 | - pyyaml==6.0
104 | - pyzmq==22.3.0
105 | - requests==2.27.1
106 | - requests-oauthlib==1.3.1
107 | - rsa==4.8
108 | - scipy==1.8.0
109 | - sentry-sdk==1.5.8
110 | - setproctitle==1.2.2
111 | - shortuuid==1.0.8
112 | - smmap==5.0.0
113 | - tb-nightly==2.9.0a20220403
114 | - tensorboard-data-server==0.6.1
115 | - tensorboard-plugin-wit==1.8.1
116 | - termcolor==1.1.0
117 | - thop==0.0.31-2005241907
118 | - tikzplotlib==0.10.1
119 | - timm==0.5.4
120 | - torchfile==0.1.0
121 | - tornado==6.1
122 | - urllib3==1.26.9
123 | - visdom==0.1.8.9
124 | - wandb==0.12.11
125 | - webcolors==1.11.1
126 | - websocket-client==1.3.2
127 | - werkzeug==2.1.1
128 | - yaspin==2.1.0
129 | - zipp==3.7.0
130 | prefix: /public/yebotao/conda_envs/ostrack
131 |
--------------------------------------------------------------------------------
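Usage note for the environment file above: assuming conda is available, the environment can be recreated with conda env create -f ostrack_cuda113_env.yaml; the machine-specific prefix line can be deleted or overridden with -p <path>. The file pins CUDA 11.3 with PyTorch 1.10 / torchvision 0.11 on Python 3.8.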
/lib/train/dataset/imagenetvid_lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .base_video_dataset import BaseVideoDataset
3 | from lib.train.data import jpeg4py_loader
4 | import torch
5 | from collections import OrderedDict
6 | from lib.train.admin import env_settings
7 | from lib.utils.lmdb_utils import decode_img, decode_json
8 |
9 |
10 | def get_target_to_image_ratio(seq):
11 | anno = torch.Tensor(seq['anno'])
12 | img_sz = torch.Tensor(seq['image_size'])
13 | return (anno[0, 2:4].prod() / (img_sz.prod())).sqrt()
14 |
15 |
16 | class ImagenetVID_lmdb(BaseVideoDataset):
17 | """ Imagenet VID dataset.
18 |
19 | Publication:
20 | ImageNet Large Scale Visual Recognition Challenge
21 | Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy,
22 | Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei
23 | IJCV, 2015
24 | https://arxiv.org/pdf/1409.0575.pdf
25 |
26 | Download the dataset from http://image-net.org/
27 | """
28 | def __init__(self, root=None, image_loader=jpeg4py_loader, min_length=0, max_target_area=1):
29 | """
30 | args:
31 | root - path to the imagenet vid dataset.
32 |             image_loader (jpeg4py_loader) - The function to read the images. If installed,
33 | jpeg4py (https://github.com/ajkxyz/jpeg4py) is used by default. Else,
34 | opencv's imread is used.
35 | min_length - Minimum allowed sequence length.
36 | max_target_area - max allowed ratio between target area and image area. Can be used to filter out targets
37 | which cover complete image.
38 | """
39 | root = env_settings().imagenet_dir if root is None else root
40 | super().__init__("imagenetvid_lmdb", root, image_loader)
41 |
42 | sequence_list_dict = decode_json(root, "cache.json")
43 | self.sequence_list = sequence_list_dict
44 |
45 | # Filter the sequences based on min_length and max_target_area in the first frame
46 | self.sequence_list = [x for x in self.sequence_list if len(x['anno']) >= min_length and
47 | get_target_to_image_ratio(x) < max_target_area]
48 |
49 | def get_name(self):
50 | return 'imagenetvid_lmdb'
51 |
52 | def get_num_sequences(self):
53 | return len(self.sequence_list)
54 |
55 | def get_sequence_info(self, seq_id):
56 | bb_anno = torch.Tensor(self.sequence_list[seq_id]['anno'])
57 | valid = (bb_anno[:, 2] > 0) & (bb_anno[:, 3] > 0)
58 | visible = torch.ByteTensor(self.sequence_list[seq_id]['target_visible']) & valid.byte()
59 | return {'bbox': bb_anno, 'valid': valid, 'visible': visible}
60 |
61 | def _get_frame(self, sequence, frame_id):
62 | set_name = 'ILSVRC2015_VID_train_{:04d}'.format(sequence['set_id'])
63 | vid_name = 'ILSVRC2015_train_{:08d}'.format(sequence['vid_id'])
64 | frame_number = frame_id + sequence['start_frame']
65 | frame_path = os.path.join('Data', 'VID', 'train', set_name, vid_name,
66 | '{:06d}.JPEG'.format(frame_number))
67 | return decode_img(self.root, frame_path)
68 |
69 | def get_frames(self, seq_id, frame_ids, anno=None):
70 | sequence = self.sequence_list[seq_id]
71 |
72 | frame_list = [self._get_frame(sequence, f) for f in frame_ids]
73 |
74 | if anno is None:
75 | anno = self.get_sequence_info(seq_id)
76 |
77 | # Create anno dict
78 | anno_frames = {}
79 | for key, value in anno.items():
80 | anno_frames[key] = [value[f_id, ...].clone() for f_id in frame_ids]
81 |
82 | # added the class info to the meta info
83 | object_meta = OrderedDict({'object_class': sequence['class_name'],
84 | 'motion_class': None,
85 | 'major_class': None,
86 | 'root_class': None,
87 | 'motion_adverb': None})
88 |
89 | return frame_list, anno_frames, object_meta
90 |
91 |
--------------------------------------------------------------------------------
/lib/models/ostrack/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def combine_tokens(template_tokens, search_tokens, mode='direct', return_res=False):
8 | # [B, HW, C]
9 | len_t = template_tokens.shape[1]
10 | len_s = search_tokens.shape[1]
11 |
12 | if mode == 'direct':
13 | merged_feature = torch.cat((template_tokens, search_tokens), dim=1)
14 | elif mode == 'template_central':
15 | central_pivot = len_s // 2
16 | first_half = search_tokens[:, :central_pivot, :]
17 | second_half = search_tokens[:, central_pivot:, :]
18 | merged_feature = torch.cat((first_half, template_tokens, second_half), dim=1)
19 | elif mode == 'partition':
20 | feat_size_s = int(math.sqrt(len_s))
21 | feat_size_t = int(math.sqrt(len_t))
22 | window_size = math.ceil(feat_size_t / 2.)
23 | # pad feature maps to multiples of window size
24 | B, _, C = template_tokens.shape
25 | H = W = feat_size_t
26 | template_tokens = template_tokens.view(B, H, W, C)
27 | pad_l = pad_b = pad_r = 0
28 | # pad_r = (window_size - W % window_size) % window_size
29 | pad_t = (window_size - H % window_size) % window_size
30 | template_tokens = F.pad(template_tokens, (0, 0, pad_l, pad_r, pad_t, pad_b))
31 | _, Hp, Wp, _ = template_tokens.shape
32 | template_tokens = template_tokens.view(B, Hp // window_size, window_size, W, C)
33 | template_tokens = torch.cat([template_tokens[:, 0, ...], template_tokens[:, 1, ...]], dim=2)
34 | _, Hc, Wc, _ = template_tokens.shape
35 | template_tokens = template_tokens.view(B, -1, C)
36 | merged_feature = torch.cat([template_tokens, search_tokens], dim=1)
37 |
38 | # calculate new h and w, which may be useful for SwinT or others
39 | merged_h, merged_w = feat_size_s + Hc, feat_size_s
40 | if return_res:
41 | return merged_feature, merged_h, merged_w
42 |
43 | else:
44 | raise NotImplementedError
45 |
46 | return merged_feature
47 |
48 |
49 | def recover_tokens(merged_tokens, len_template_token, len_search_token, mode='direct'):
50 | if mode == 'direct':
51 | recovered_tokens = merged_tokens
52 | elif mode == 'template_central':
53 | central_pivot = len_search_token // 2
54 | len_remain = len_search_token - central_pivot
55 | len_half_and_t = central_pivot + len_template_token
56 |
57 | first_half = merged_tokens[:, :central_pivot, :]
58 | second_half = merged_tokens[:, -len_remain:, :]
59 | template_tokens = merged_tokens[:, central_pivot:len_half_and_t, :]
60 |
61 | recovered_tokens = torch.cat((template_tokens, first_half, second_half), dim=1)
62 | elif mode == 'partition':
63 | recovered_tokens = merged_tokens
64 | else:
65 | raise NotImplementedError
66 |
67 | return recovered_tokens
68 |
69 |
70 | def window_partition(x, window_size: int):
71 | """
72 | Args:
73 | x: (B, H, W, C)
74 | window_size (int): window size
75 |
76 | Returns:
77 | windows: (num_windows*B, window_size, window_size, C)
78 | """
79 | B, H, W, C = x.shape
80 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
81 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
82 | return windows
83 |
84 |
85 | def window_reverse(windows, window_size: int, H: int, W: int):
86 | """
87 | Args:
88 | windows: (num_windows*B, window_size, window_size, C)
89 | window_size (int): Window size
90 | H (int): Height of image
91 | W (int): Width of image
92 |
93 | Returns:
94 | x: (B, H, W, C)
95 | """
96 | B = int(windows.shape[0] / (H * W / window_size / window_size))
97 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
98 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
99 | return x
100 |
--------------------------------------------------------------------------------
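A minimal sketch for window_partition / window_reverse above: they are exact inverses whenever H and W are multiples of the window size.

import torch
from lib.models.ostrack.utils import window_partition, window_reverse

x = torch.randn(2, 8, 8, 16)               # (B, H, W, C)
wins = window_partition(x, window_size=4)  # (2 * 4, 4, 4, 16): four windows per image
assert wins.shape == (8, 4, 4, 16)
assert torch.equal(window_reverse(wins, 4, 8, 8), x)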
/lib/test/utils/hann.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn.functional as F
4 |
5 |
6 | def hann1d(sz: int, centered = True) -> torch.Tensor:
7 | """1D cosine window."""
8 | if centered:
9 | return 0.5 * (1 - torch.cos((2 * math.pi / (sz + 1)) * torch.arange(1, sz + 1).float()))
10 | w = 0.5 * (1 + torch.cos((2 * math.pi / (sz + 2)) * torch.arange(0, sz//2 + 1).float()))
11 | return torch.cat([w, w[1:sz-sz//2].flip((0,))])
12 |
13 |
14 | def hann2d(sz: torch.Tensor, centered = True) -> torch.Tensor:
15 | """2D cosine window."""
16 | return hann1d(sz[0].item(), centered).reshape(1, 1, -1, 1) * hann1d(sz[1].item(), centered).reshape(1, 1, 1, -1)
17 |
18 |
19 | def hann2d_bias(sz: torch.Tensor, ctr_point: torch.Tensor, centered = True) -> torch.Tensor:
20 | """2D cosine window."""
21 | distance = torch.stack([ctr_point, sz-ctr_point], dim=0)
22 | max_distance, _ = distance.max(dim=0)
23 |
24 | hann1d_x = hann1d(max_distance[0].item() * 2, centered)
25 | hann1d_x = hann1d_x[max_distance[0] - distance[0, 0]: max_distance[0] + distance[1, 0]]
26 | hann1d_y = hann1d(max_distance[1].item() * 2, centered)
27 | hann1d_y = hann1d_y[max_distance[1] - distance[0, 1]: max_distance[1] + distance[1, 1]]
28 |
29 | return hann1d_y.reshape(1, 1, -1, 1) * hann1d_x.reshape(1, 1, 1, -1)
30 |
31 |
32 |
33 | def hann2d_clipped(sz: torch.Tensor, effective_sz: torch.Tensor, centered = True) -> torch.Tensor:
34 |     """2D clipped cosine window."""
35 |
36 | # Ensure that the difference is even
37 | effective_sz += (effective_sz - sz) % 2
38 | effective_window = hann1d(effective_sz[0].item(), True).reshape(1, 1, -1, 1) * hann1d(effective_sz[1].item(), True).reshape(1, 1, 1, -1)
39 |
40 | pad = (sz - effective_sz) // 2
41 |
42 | window = F.pad(effective_window, (pad[1].item(), pad[1].item(), pad[0].item(), pad[0].item()), 'replicate')
43 |
44 | if centered:
45 | return window
46 | else:
47 | mid = (sz / 2).int()
48 | window_shift_lr = torch.cat((window[:, :, :, mid[1]:], window[:, :, :, :mid[1]]), 3)
49 | return torch.cat((window_shift_lr[:, :, mid[0]:, :], window_shift_lr[:, :, :mid[0], :]), 2)
50 |
51 |
52 | def gauss_fourier(sz: int, sigma: float, half: bool = False) -> torch.Tensor:
53 | if half:
54 | k = torch.arange(0, int(sz/2+1))
55 | else:
56 | k = torch.arange(-int((sz-1)/2), int(sz/2+1))
57 | return (math.sqrt(2*math.pi) * sigma / sz) * torch.exp(-2 * (math.pi * sigma * k.float() / sz)**2)
58 |
59 |
60 | def gauss_spatial(sz, sigma, center=0, end_pad=0):
61 | k = torch.arange(-(sz-1)/2, (sz+1)/2+end_pad)
62 | return torch.exp(-1.0/(2*sigma**2) * (k - center)**2)
63 |
64 |
65 | def label_function(sz: torch.Tensor, sigma: torch.Tensor):
66 | return gauss_fourier(sz[0].item(), sigma[0].item()).reshape(1, 1, -1, 1) * gauss_fourier(sz[1].item(), sigma[1].item(), True).reshape(1, 1, 1, -1)
67 |
68 | def label_function_spatial(sz: torch.Tensor, sigma: torch.Tensor, center: torch.Tensor = torch.zeros(2), end_pad: torch.Tensor = torch.zeros(2)):
69 | """The origin is in the middle of the image."""
70 | return gauss_spatial(sz[0].item(), sigma[0].item(), center[0], end_pad[0].item()).reshape(1, 1, -1, 1) * \
71 | gauss_spatial(sz[1].item(), sigma[1].item(), center[1], end_pad[1].item()).reshape(1, 1, 1, -1)
72 |
73 |
74 | def cubic_spline_fourier(f, a):
75 | """The continuous Fourier transform of a cubic spline kernel."""
76 |
77 | bf = (6*(1 - torch.cos(2 * math.pi * f)) + 3*a*(1 - torch.cos(4 * math.pi * f))
78 | - (6 + 8*a)*math.pi*f*torch.sin(2 * math.pi * f) - 2*a*math.pi*f*torch.sin(4 * math.pi * f)) \
79 | / (4 * math.pi**4 * f**4)
80 |
81 | bf[f == 0] = 1
82 |
83 | return bf
84 |
85 | def max2d(a: torch.Tensor) -> (torch.Tensor, torch.Tensor):
86 | """Computes maximum and argmax in the last two dimensions."""
87 |
88 | max_val_row, argmax_row = torch.max(a, dim=-2)
89 | max_val, argmax_col = torch.max(max_val_row, dim=-1)
90 | argmax_row = argmax_row.view(argmax_col.numel(),-1)[torch.arange(argmax_col.numel()), argmax_col.view(-1)]
91 | argmax_row = argmax_row.reshape(argmax_col.shape)
92 | argmax = torch.cat((argmax_row.unsqueeze(-1), argmax_col.unsqueeze(-1)), -1)
93 | return max_val, argmax
94 |
--------------------------------------------------------------------------------
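A minimal sketch for hann2d above: build a window matching a 16x16 score map, of the kind trackers commonly multiply into the score map to penalize large displacements.

import torch
from lib.test.utils.hann import hann2d

window = hann2d(torch.tensor([16, 16]))  # (1, 1, 16, 16), values in (0, 1)
assert window.shape == (1, 1, 16, 16)
score_map = torch.rand(1, 1, 16, 16)
penalized = score_map * window           # suppresses responses near the borders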
/tracking/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 |
5 |
6 | def parse_args():
7 | """
8 | args for training.
9 | """
10 | parser = argparse.ArgumentParser(description='Parse args for training')
11 | # for train
12 | parser.add_argument('--script', type=str, help='training script name')
13 | parser.add_argument('--config', type=str, default='baseline', help='yaml configure file name')
14 | parser.add_argument('--save_dir', type=str, help='root directory to save checkpoints, logs, and tensorboard')
15 | parser.add_argument('--seed', type=int, default=42, help='seed for random numbers')
16 |     parser.add_argument('--mode', type=str, choices=["single", "multiple", "multi_node"], default="multiple",
17 |                         help="train on a single GPU, multiple GPUs, or multiple nodes")
18 | parser.add_argument('--nproc_per_node', type=int, help="number of GPUs per node") # specify when mode is multiple
19 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format
20 | parser.add_argument('--script_prv', type=str, help='training script name')
21 | parser.add_argument('--config_prv', type=str, default='baseline', help='yaml configure file name')
22 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb
23 | # for knowledge distillation
24 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation
25 |     parser.add_argument('--datr', type=int, choices=[0, 1, 2], default=0)  # whether to use our DATR data augmentation
26 | parser.add_argument('--script_teacher', type=str, help='teacher script name')
27 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name')
28 |
29 | # for multiple machines
30 | parser.add_argument('--rank', type=int, help='Rank of the current process.')
31 | parser.add_argument('--world-size', type=int, help='Number of processes participating in the job.')
32 | parser.add_argument('--ip', type=str, default='127.0.0.1', help='IP of the current rank 0.')
33 |     parser.add_argument('--port', type=int, default=20000, help='Port of the current rank 0.')
34 |
35 | args = parser.parse_args()
36 |
37 | return args
38 |
39 |
40 | def main():
41 | args = parse_args()
42 | if args.mode == "single":
43 | train_cmd = "python lib/train/run_training.py --script %s --config %s --save_dir %s --seed %d --use_lmdb %d " \
44 | "--script_prv %s --config_prv %s --distill %d --datr %d --script_teacher %s --config_teacher %s --use_wandb %d"\
45 |                     % (args.script, args.config, args.save_dir, args.seed, args.use_lmdb, args.script_prv, args.config_prv,
46 |                        args.distill, args.datr, args.script_teacher, args.config_teacher, args.use_wandb)
47 | elif args.mode == "multiple":
48 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_port %d lib/train/run_training.py " \
49 | "--script %s --config %s --save_dir %s --seed %d --use_lmdb %d --script_prv %s --config_prv %s --use_wandb %d " \
50 | "--distill %d --datr %d --script_teacher %s --config_teacher %s" \
51 |                     % (args.nproc_per_node, random.randint(10000, 50000), args.script, args.config, args.save_dir, args.seed,
52 |                        args.use_lmdb, args.script_prv, args.config_prv, args.use_wandb, args.distill, args.datr, args.script_teacher, args.config_teacher)
53 | elif args.mode == "multi_node":
54 | train_cmd = "python -m torch.distributed.launch --nproc_per_node %d --master_addr %s --master_port %d --nnodes %d --node_rank %d lib/train/run_training.py " \
55 | "--script %s --config %s --save_dir %s --use_lmdb %d --script_prv %s --config_prv %s --use_wandb %d " \
56 | "--distill %d --datr %d --script_teacher %s --config_teacher %s" \
57 |                     % (args.nproc_per_node, args.ip, args.port, args.world_size, args.rank, args.script, args.config, args.save_dir,
58 |                        args.use_lmdb, args.script_prv, args.config_prv, args.use_wandb, args.distill, args.datr, args.script_teacher, args.config_teacher)
59 | else:
60 |         raise ValueError("mode should be 'single', 'multiple', or 'multi_node'.")
61 | os.system(train_cmd)
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
66 |
--------------------------------------------------------------------------------
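Usage note for tracking/train.py above (values illustrative): python tracking/train.py --script ostrack --config vitb_256_mae_ce_32x4_ep300 --save_dir ./output --mode multiple --nproc_per_node 4 builds the corresponding torch.distributed.launch command with a random master port and executes it via os.system.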
/lib/models/layers/rpe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import trunc_normal_
4 |
5 |
6 | def generate_2d_relative_positional_encoding_index(z_shape, x_shape):
7 | '''
8 | z_shape: (z_h, z_w)
9 | x_shape: (x_h, x_w)
10 | '''
11 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1]))
12 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1]))
13 |
14 | z_2d_index_h = z_2d_index_h.flatten(0)
15 | z_2d_index_w = z_2d_index_w.flatten(0)
16 | x_2d_index_h = x_2d_index_h.flatten(0)
17 | x_2d_index_w = x_2d_index_w.flatten(0)
18 |
19 | diff_h = z_2d_index_h[:, None] - x_2d_index_h[None, :]
20 | diff_w = z_2d_index_w[:, None] - x_2d_index_w[None, :]
21 |
22 | diff = torch.stack((diff_h, diff_w), dim=-1)
23 | _, indices = torch.unique(diff.view(-1, 2), return_inverse=True, dim=0)
24 | return indices.view(z_shape[0] * z_shape[1], x_shape[0] * x_shape[1])
25 |
26 |
27 | def generate_2d_concatenated_self_attention_relative_positional_encoding_index(z_shape, x_shape):
28 | '''
29 | z_shape: (z_h, z_w)
30 | x_shape: (x_h, x_w)
31 | '''
32 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1]))
33 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1]))
34 |
35 | z_2d_index_h = z_2d_index_h.flatten(0)
36 | z_2d_index_w = z_2d_index_w.flatten(0)
37 | x_2d_index_h = x_2d_index_h.flatten(0)
38 | x_2d_index_w = x_2d_index_w.flatten(0)
39 |
40 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h))
41 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w))
42 |
43 | diff_h = concatenated_2d_index_h[:, None] - concatenated_2d_index_h[None, :]
44 | diff_w = concatenated_2d_index_w[:, None] - concatenated_2d_index_w[None, :]
45 |
46 | z_len = z_shape[0] * z_shape[1]
47 | x_len = x_shape[0] * x_shape[1]
48 | a = torch.empty((z_len + x_len), dtype=torch.int64)
49 | a[:z_len] = 0
50 | a[z_len:] = 1
51 |     b = a[:, None].repeat(1, z_len + x_len)
52 |     c = a[None, :].repeat(z_len + x_len, 1)
53 |
54 | diff = torch.stack((diff_h, diff_w, b, c), dim=-1)
55 | _, indices = torch.unique(diff.view((z_len + x_len) * (z_len + x_len), 4), return_inverse=True, dim=0)
56 | return indices.view((z_len + x_len), (z_len + x_len))
57 |
58 |
59 | def generate_2d_concatenated_cross_attention_relative_positional_encoding_index(z_shape, x_shape):
60 | '''
61 | z_shape: (z_h, z_w)
62 | x_shape: (x_h, x_w)
63 | '''
64 | z_2d_index_h, z_2d_index_w = torch.meshgrid(torch.arange(z_shape[0]), torch.arange(z_shape[1]))
65 | x_2d_index_h, x_2d_index_w = torch.meshgrid(torch.arange(x_shape[0]), torch.arange(x_shape[1]))
66 |
67 | z_2d_index_h = z_2d_index_h.flatten(0)
68 | z_2d_index_w = z_2d_index_w.flatten(0)
69 | x_2d_index_h = x_2d_index_h.flatten(0)
70 | x_2d_index_w = x_2d_index_w.flatten(0)
71 |
72 | concatenated_2d_index_h = torch.cat((z_2d_index_h, x_2d_index_h))
73 | concatenated_2d_index_w = torch.cat((z_2d_index_w, x_2d_index_w))
74 |
75 | diff_h = x_2d_index_h[:, None] - concatenated_2d_index_h[None, :]
76 | diff_w = x_2d_index_w[:, None] - concatenated_2d_index_w[None, :]
77 |
78 | z_len = z_shape[0] * z_shape[1]
79 | x_len = x_shape[0] * x_shape[1]
80 |
81 | a = torch.empty(z_len + x_len, dtype=torch.int64)
82 | a[: z_len] = 0
83 | a[z_len:] = 1
84 | c = a[None, :].repeat(x_len, 1)
85 |
86 | diff = torch.stack((diff_h, diff_w, c), dim=-1)
87 | _, indices = torch.unique(diff.view(x_len * (z_len + x_len), 3), return_inverse=True, dim=0)
88 | return indices.view(x_len, (z_len + x_len))
89 |
90 |
91 | class RelativePosition2DEncoder(nn.Module):
92 | def __init__(self, num_heads, embed_size):
93 | super(RelativePosition2DEncoder, self).__init__()
94 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads, embed_size)))
95 | trunc_normal_(self.relative_position_bias_table, std=0.02)
96 |
97 | def forward(self, attn_rpe_index):
98 | '''
99 | Args:
100 | attn_rpe_index (torch.Tensor): (*), any shape containing indices, max(attn_rpe_index) < embed_size
101 | Returns:
102 | torch.Tensor: (1, num_heads, *)
103 | '''
104 | return self.relative_position_bias_table[:, attn_rpe_index].unsqueeze(0)
105 |
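Usage sketch (illustration only, not part of the file above): the index generator pairs with RelativePosition2DEncoder to produce an additive per-head attention bias; grid sizes here are deliberately small for readability.

from lib.models.layers.rpe import (
    RelativePosition2DEncoder,
    generate_2d_concatenated_self_attention_relative_positional_encoding_index,
)

# a 2x2 template grid and a 3x3 search grid give N = 4 + 9 = 13 tokens
index = generate_2d_concatenated_self_attention_relative_positional_encoding_index(
    (2, 2), (3, 3))                                    # (13, 13), int64
encoder = RelativePosition2DEncoder(num_heads=8, embed_size=index.max().item() + 1)
bias = encoder(index)                                  # (1, 8, 13, 13)
# bias broadcast-adds onto attention logits of shape (B, 8, 13, 13)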
--------------------------------------------------------------------------------
/lib/train/admin/environment.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | from collections import OrderedDict
4 |
5 |
6 | def create_default_local_file():
7 | path = os.path.join(os.path.dirname(__file__), 'local.py')
8 |
9 | empty_str = '\'\''
10 | default_settings = OrderedDict({
11 | 'workspace_dir': empty_str,
12 | 'tensorboard_dir': 'self.workspace_dir + \'/tensorboard/\'',
13 | 'pretrained_networks': 'self.workspace_dir + \'/pretrained_networks/\'',
14 | 'lasot_dir': empty_str,
15 | 'got10k_dir': empty_str,
16 | 'trackingnet_dir': empty_str,
17 | 'coco_dir': empty_str,
18 | 'lvis_dir': empty_str,
19 | 'sbd_dir': empty_str,
20 | 'imagenet_dir': empty_str,
21 | 'imagenetdet_dir': empty_str,
22 | 'ecssd_dir': empty_str,
23 | 'hkuis_dir': empty_str,
24 | 'msra10k_dir': empty_str,
25 | 'davis_dir': empty_str,
26 | 'youtubevos_dir': empty_str})
27 |
28 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
29 | 'tensorboard_dir': 'Directory for tensorboard files.'}
30 |
31 | with open(path, 'w') as f:
32 | f.write('class EnvironmentSettings:\n')
33 | f.write(' def __init__(self):\n')
34 |
35 | for attr, attr_val in default_settings.items():
36 | comment_str = None
37 | if attr in comment:
38 | comment_str = comment[attr]
39 | if comment_str is None:
40 | f.write(' self.{} = {}\n'.format(attr, attr_val))
41 | else:
42 | f.write(' self.{} = {} # {}\n'.format(attr, attr_val, comment_str))
43 |
44 |
45 | def create_default_local_file_ITP_train(workspace_dir, data_dir):
46 | path = os.path.join(os.path.dirname(__file__), 'local.py')
47 |
48 | empty_str = '\'\''
49 | default_settings = OrderedDict({
50 | 'workspace_dir': workspace_dir,
51 | 'tensorboard_dir': os.path.join(workspace_dir, 'tensorboard'), # Directory for tensorboard files.
52 | 'pretrained_networks': os.path.join(workspace_dir, 'pretrained_networks'),
53 | 'lasot_dir': os.path.join(data_dir, 'lasot'),
54 | 'got10k_dir': os.path.join(data_dir, 'got10k/train'),
55 | 'got10k_val_dir': os.path.join(data_dir, 'got10k/val'),
56 | 'lasot_lmdb_dir': os.path.join(data_dir, 'lasot_lmdb'),
57 | 'got10k_lmdb_dir': os.path.join(data_dir, 'got10k_lmdb'),
58 | 'trackingnet_dir': os.path.join(data_dir, 'trackingnet'),
59 | 'trackingnet_lmdb_dir': os.path.join(data_dir, 'trackingnet_lmdb'),
60 | 'coco_dir': os.path.join(data_dir, 'coco'),
61 | 'coco_lmdb_dir': os.path.join(data_dir, 'coco_lmdb'),
62 | 'lvis_dir': empty_str,
63 | 'sbd_dir': empty_str,
64 | 'imagenet_dir': os.path.join(data_dir, 'vid'),
65 | 'imagenet_lmdb_dir': os.path.join(data_dir, 'vid_lmdb'),
66 | 'imagenetdet_dir': empty_str,
67 | 'ecssd_dir': empty_str,
68 | 'hkuis_dir': empty_str,
69 | 'msra10k_dir': empty_str,
70 | 'davis_dir': empty_str,
71 | 'youtubevos_dir': empty_str})
72 |
73 | comment = {'workspace_dir': 'Base directory for saving network checkpoints.',
74 | 'tensorboard_dir': 'Directory for tensorboard files.'}
75 |
76 | with open(path, 'w') as f:
77 | f.write('class EnvironmentSettings:\n')
78 | f.write(' def __init__(self):\n')
79 |
80 | for attr, attr_val in default_settings.items():
81 | comment_str = None
82 | if attr in comment:
83 | comment_str = comment[attr]
84 | if comment_str is None:
85 | if attr_val == empty_str:
86 | f.write(' self.{} = {}\n'.format(attr, attr_val))
87 | else:
88 | f.write(' self.{} = \'{}\'\n'.format(attr, attr_val))
89 | else:
90 | f.write(' self.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
91 |
92 |
93 | def env_settings():
94 | env_module_name = 'lib.train.admin.local'
95 | try:
96 | env_module = importlib.import_module(env_module_name)
97 | return env_module.EnvironmentSettings()
98 |     except ImportError:  # narrow the bare except so real errors inside local.py still surface
99 | env_file = os.path.join(os.path.dirname(__file__), 'local.py')
100 |
101 | create_default_local_file()
102 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. Then try to run again.'.format(env_file))
103 |
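Usage sketch (illustration only): on a fresh checkout, the first call writes a template local.py next to this file and raises; once its paths are filled in, the call returns the settings object.

from lib.train.admin.environment import env_settings

# Raises RuntimeError the first time, after writing local.py with empty paths.
# Edit e.g. workspace_dir and the dataset dirs you need, then rerun:
settings = env_settings()
print(settings.workspace_dir, settings.got10k_dir)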
--------------------------------------------------------------------------------
/tracking/profile_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | prj_path = os.path.join(os.path.dirname(__file__), '..')
5 | if prj_path not in sys.path:
6 | sys.path.append(prj_path)
7 |
8 | import argparse
9 | import torch
10 | from lib.utils.misc import NestedTensor
11 | from thop import profile
12 | from thop.utils import clever_format
13 | import time
14 | import importlib
15 |
16 |
17 | def parse_args():
18 | """
19 |     args for profiling.
20 |     """
21 |     parser = argparse.ArgumentParser(description='Parse args for model profiling')
22 |     # script / config selection
23 | parser.add_argument('--script', type=str, default='ostrack', choices=['ostrack'],
24 | help='training script name')
25 | parser.add_argument('--config', type=str, default='vitb_256_mae_ce_32x4_ep300', help='yaml configure file name')
26 | args = parser.parse_args()
27 |
28 | return args
29 |
30 |
31 | def evaluate_vit(model, template, search):
32 | '''Speed Test'''
33 | macs1, params1 = profile(model, inputs=(template, search),
34 | custom_ops=None, verbose=False)
35 | macs, params = clever_format([macs1, params1], "%.3f")
36 | print('overall macs is ', macs)
37 | print('overall params is ', params)
38 |
39 | T_w = 500
40 | T_t = 1000
41 | print("testing speed ...")
42 | torch.cuda.synchronize()
43 | with torch.no_grad():
44 | # overall
45 | for i in range(T_w):
46 | _ = model(template, search)
47 |         torch.cuda.synchronize(); start = time.time()  # ensure warm-up kernels have finished before timing
48 | for i in range(T_t):
49 | _ = model(template, search)
50 | torch.cuda.synchronize()
51 | end = time.time()
52 | avg_lat = (end - start) / T_t
53 | print("The average overall latency is %.2f ms" % (avg_lat * 1000))
54 | print("FPS is %.2f fps" % (1. / avg_lat))
55 | # for i in range(T_w):
56 | # _ = model(template, search)
57 | # start = time.time()
58 | # for i in range(T_t):
59 | # _ = model(template, search)
60 | # end = time.time()
61 | # avg_lat = (end - start) / T_t
62 | # print("The average backbone latency is %.2f ms" % (avg_lat * 1000))
63 |
64 |
65 | def evaluate_vit_separate(model, template, search):
66 | '''Speed Test'''
67 | T_w = 50
68 | T_t = 1000
69 | print("testing speed ...")
70 | z = model.forward_backbone(template, image_type='template')
71 | x = model.forward_backbone(search, image_type='search')
72 | with torch.no_grad():
73 | # overall
74 | for i in range(T_w):
75 | _ = model.forward_backbone(search, image_type='search')
76 | _ = model.forward_cat(z, x)
77 |         torch.cuda.synchronize(); start = time.time()  # sync before timing
78 | for i in range(T_t):
79 | _ = model.forward_backbone(search, image_type='search')
80 | _ = model.forward_cat(z, x)
81 |         torch.cuda.synchronize(); end = time.time()  # sync so all queued kernels are counted
82 | avg_lat = (end - start) / T_t
83 | print("The average overall latency is %.2f ms" % (avg_lat * 1000))
84 |
85 |
86 | def get_data(bs, sz):
87 | img_patch = torch.randn(bs, 3, sz, sz)
88 | att_mask = torch.rand(bs, sz, sz) > 0.5
89 | return NestedTensor(img_patch, att_mask)
90 |
91 |
92 | if __name__ == "__main__":
93 | device = "cuda:0"
94 | torch.cuda.set_device(device)
95 |     # Compute the FLOPs and Params of the OSTrack model
96 | args = parse_args()
97 | '''update cfg'''
98 | yaml_fname = 'experiments/%s/%s.yaml' % (args.script, args.config)
99 | config_module = importlib.import_module('lib.config.%s.config' % args.script)
100 | cfg = config_module.cfg
101 | config_module.update_config_from_file(yaml_fname)
102 | '''set some values'''
103 | bs = 1
104 | z_sz = cfg.TEST.TEMPLATE_SIZE
105 | x_sz = cfg.TEST.SEARCH_SIZE
106 |
107 | if args.script == "ostrack":
108 | model_module = importlib.import_module('lib.models')
109 | model_constructor = model_module.build_ostrack
110 | model = model_constructor(cfg, training=False)
111 | # get the template and search
112 | template = torch.randn(bs, 3, z_sz, z_sz)
113 | search = torch.randn(bs, 3, x_sz, x_sz)
114 | # transfer to device
115 | model = model.to(device)
116 | template = template.to(device)
117 | search = search.to(device)
118 |
119 | merge_layer = cfg.MODEL.BACKBONE.MERGE_LAYER
120 | if merge_layer <= 0:
121 | evaluate_vit(model, template, search)
122 | else:
123 | evaluate_vit_separate(model, template, search)
124 |
125 | else:
126 | raise NotImplementedError
127 |
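Usage sketch (illustration only): run from the repository root so the relative 'experiments/%s/%s.yaml' path resolves; a CUDA device and the thop package are required.

python tracking/profile_model.py --script ostrack --config vitb_256_mae_ce_32x4_ep300

The two arguments shown are the argparse defaults, so a bare invocation profiles the same model.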
--------------------------------------------------------------------------------
/lib/train/train_script_distill.py:
--------------------------------------------------------------------------------
1 | import os
2 | # loss function related
3 | from lib.utils.box_ops import giou_loss
4 | from torch.nn.functional import l1_loss
5 | from torch.nn import BCEWithLogitsLoss
6 | # train pipeline related
7 | from lib.train.trainers import LTRTrainer
8 | # distributed training related
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 | # some more advanced functions
11 | from .base_functions import *
12 | # network related
13 | from lib.models.stark import build_starks, build_starkst
14 | from lib.models.stark import build_stark_lightning_x_trt
15 | # forward propagation related
16 | from lib.train.actors import STARKLightningXtrtdistillActor
17 | # for import modules
18 | import importlib
19 |
20 |
21 | def build_network(script_name, cfg):
22 | # Create network
23 | if script_name == "stark_s":
24 | net = build_starks(cfg)
25 | elif script_name == "stark_st1" or script_name == "stark_st2":
26 | net = build_starkst(cfg)
27 | elif script_name == "stark_lightning_X_trt":
28 | net = build_stark_lightning_x_trt(cfg, phase="train")
29 | else:
30 | raise ValueError("illegal script name")
31 | return net
32 |
33 |
34 | def run(settings):
35 |     settings.description = 'Distillation training script (STARK-Lightning student with a STARK teacher)'
36 |
37 | # update the default configs with config file
38 | if not os.path.exists(settings.cfg_file):
39 | raise ValueError("%s doesn't exist." % settings.cfg_file)
40 | config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
41 | cfg = config_module.cfg
42 | config_module.update_config_from_file(settings.cfg_file)
43 | if settings.local_rank in [-1, 0]:
44 | print("New configuration is shown below.")
45 | for key in cfg.keys():
46 | print("%s configuration:" % key, cfg[key])
47 | print('\n')
48 |
49 | # update the default teacher configs with teacher config file
50 | if not os.path.exists(settings.cfg_file_teacher):
51 | raise ValueError("%s doesn't exist." % settings.cfg_file_teacher)
52 | config_module_teacher = importlib.import_module("lib.config.%s.config" % settings.script_teacher)
53 | cfg_teacher = config_module_teacher.cfg
54 | config_module_teacher.update_config_from_file(settings.cfg_file_teacher)
55 | if settings.local_rank in [-1, 0]:
56 | print("New teacher configuration is shown below.")
57 | for key in cfg_teacher.keys():
58 | print("%s configuration:" % key, cfg_teacher[key])
59 | print('\n')
60 |
61 | # update settings based on cfg
62 | update_settings(settings, cfg)
63 |
64 | # Record the training log
65 | log_dir = os.path.join(settings.save_dir, 'logs')
66 | if settings.local_rank in [-1, 0]:
67 | if not os.path.exists(log_dir):
68 | os.makedirs(log_dir)
69 | settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))
70 |
71 | # Build dataloaders
72 | loader_train, loader_val = build_dataloaders(cfg, settings)
73 |
74 | if "RepVGG" in cfg.MODEL.BACKBONE.TYPE or "swin" in cfg.MODEL.BACKBONE.TYPE:
75 | cfg.ckpt_dir = settings.save_dir
76 | """turn on the distillation mode"""
77 | cfg.TRAIN.DISTILL = True
78 | cfg_teacher.TRAIN.DISTILL = True
79 | net = build_network(settings.script_name, cfg)
80 | net_teacher = build_network(settings.script_teacher, cfg_teacher)
81 |
82 | # wrap networks to distributed one
83 | net.cuda()
84 | net_teacher.cuda()
85 | net_teacher.eval()
86 |
87 | if settings.local_rank != -1:
88 | net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
89 | net_teacher = DDP(net_teacher, device_ids=[settings.local_rank], find_unused_parameters=True)
90 | settings.device = torch.device("cuda:%d" % settings.local_rank)
91 | else:
92 | settings.device = torch.device("cuda:0")
93 | # settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
94 | # settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
95 | settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "L1")
96 | # Loss functions and Actors
97 | if settings.script_name == "stark_lightning_X_trt":
98 | objective = {'giou': giou_loss, 'l1': l1_loss}
99 | loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT}
100 | actor = STARKLightningXtrtdistillActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings,
101 | net_teacher=net_teacher)
102 | else:
103 | raise ValueError("illegal script name")
104 |
105 | # Optimizer, parameters, and learning rates
106 | optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
107 | use_amp = getattr(cfg.TRAIN, "AMP", False)
108 | trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
109 |
110 | # train process
111 | trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True, distill=True)
112 |
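How this script is reached (see run_training.py below): passing --distill 1 makes run_training import lib.train.train_script_distill and fill in the teacher settings. A sketch with placeholder config names:

python lib/train/run_training.py --script stark_lightning_X_trt --config <student_cfg> \
    --distill 1 --script_teacher stark_st2 --config_teacher <teacher_cfg> \
    --save_dir ./output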
--------------------------------------------------------------------------------
/lib/config/ostrack/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 | import yaml
3 |
4 | """
5 | Add default config for OSTrack.
6 | """
7 | cfg = edict()
8 |
9 | # data augmentation
10 | cfg.DA = edict()
11 | cfg.DA.border_prob = 0.0
12 | cfg.DA.sfactor = [4, 4]
13 | cfg.DA.imix = False
14 | cfg.DA.imix_occrate = 0.5
15 | cfg.DA.imix_reverse_prob = 0.3
16 | cfg.DA.imix_epoch = 40
17 | cfg.DA.imix_interval = 1
18 | cfg.DA.mix_type = 1
19 | cfg.DA.norm_type = 'channel'
20 | cfg.DA.update_bbox = False
21 |
22 | cfg.CUTMIX = edict()
23 | cfg.CUTMIX.prob = 0.0
24 | cfg.CUTMIX.occrate = 0.5
25 | # MODEL
26 | cfg.MODEL = edict()
27 | cfg.MODEL.PRETRAIN_FILE = "mae_pretrain_vit_base.pth"
28 | cfg.MODEL.EXTRA_MERGER = False
29 |
30 | cfg.MODEL.RETURN_INTER = False
31 | cfg.MODEL.RETURN_STAGES = []
32 |
33 | # MODEL.BACKBONE
34 | cfg.MODEL.BACKBONE = edict()
35 | cfg.MODEL.BACKBONE.TYPE = "vit_base_patch16_224"
36 | cfg.MODEL.BACKBONE.STRIDE = 16
37 | cfg.MODEL.BACKBONE.MID_PE = False
38 | cfg.MODEL.BACKBONE.SEP_SEG = False
39 | cfg.MODEL.BACKBONE.CAT_MODE = 'direct'
40 | cfg.MODEL.BACKBONE.MERGE_LAYER = 0
41 | cfg.MODEL.BACKBONE.ADD_CLS_TOKEN = False
42 | cfg.MODEL.BACKBONE.CLS_TOKEN_USE_MODE = 'ignore'
43 |
44 | cfg.MODEL.BACKBONE.CE_LOC = []
45 | cfg.MODEL.BACKBONE.CE_KEEP_RATIO = []
46 | cfg.MODEL.BACKBONE.CE_TEMPLATE_RANGE = 'ALL' # choose between ALL, CTR_POINT, CTR_REC, GT_BOX
47 |
48 | # MODEL.HEAD
49 | cfg.MODEL.HEAD = edict()
50 | cfg.MODEL.HEAD.TYPE = "CENTER"
51 | cfg.MODEL.HEAD.NUM_CHANNELS = 256
52 |
53 | # TRAIN
54 | cfg.TRAIN = edict()
55 | cfg.TRAIN.LR = 0.0001
56 | cfg.TRAIN.WEIGHT_DECAY = 0.0001
57 | cfg.TRAIN.EPOCH = 500
58 | cfg.TRAIN.LR_DROP_EPOCH = 400
59 | cfg.TRAIN.BATCH_SIZE = 16
60 | cfg.TRAIN.NUM_WORKER = 8
61 | cfg.TRAIN.OPTIMIZER = "ADAMW"
62 | cfg.TRAIN.BACKBONE_MULTIPLIER = 0.1
63 | cfg.TRAIN.GIOU_WEIGHT = 2.0
64 | cfg.TRAIN.L1_WEIGHT = 5.0
65 | cfg.TRAIN.FREEZE_LAYERS = [0, ]
66 | cfg.TRAIN.PRINT_INTERVAL = 50
67 | cfg.TRAIN.VAL_EPOCH_INTERVAL = 20
68 | cfg.TRAIN.GRAD_CLIP_NORM = 0.1
69 | cfg.TRAIN.AMP = False
70 |
71 | cfg.TRAIN.CE_START_EPOCH = 20 # candidate elimination start epoch
72 | cfg.TRAIN.CE_WARM_EPOCH = 80 # candidate elimination warm up epoch
73 | cfg.TRAIN.DROP_PATH_RATE = 0.1 # drop path rate for ViT backbone
74 |
75 | # TRAIN.SCHEDULER
76 | cfg.TRAIN.SCHEDULER = edict()
77 | cfg.TRAIN.SCHEDULER.TYPE = "step"
78 | cfg.TRAIN.SCHEDULER.DECAY_RATE = 0.1
79 |
80 | # DATA
81 | cfg.DATA = edict()
82 | cfg.DATA.SAMPLER_MODE = "causal" # sampling methods
83 | cfg.DATA.DATA_FRACTION = None  # None means use all the data; 1.0 is equivalent
84 | cfg.DATA.MEAN = [0.485, 0.456, 0.406]
85 | cfg.DATA.STD = [0.229, 0.224, 0.225]
86 | cfg.DATA.MAX_SAMPLE_INTERVAL = 200
87 | # DATA.TRAIN
88 | cfg.DATA.TRAIN = edict()
89 | cfg.DATA.TRAIN.DATASETS_NAME = ["LASOT", "GOT10K_vottrain"]
90 | cfg.DATA.TRAIN.DATASETS_RATIO = [1, 1]
91 | cfg.DATA.TRAIN.SAMPLE_PER_EPOCH = 60000
92 | # DATA.VAL
93 | cfg.DATA.VAL = edict()
94 | cfg.DATA.VAL.DATASETS_NAME = ["GOT10K_votval"]
95 | cfg.DATA.VAL.DATASETS_RATIO = [1]
96 | cfg.DATA.VAL.SAMPLE_PER_EPOCH = 10000
97 | # DATA.SEARCH
98 | cfg.DATA.SEARCH = edict()
99 | cfg.DATA.SEARCH.SIZE = 320
100 | cfg.DATA.SEARCH.FACTOR = 5.0
101 | cfg.DATA.SEARCH.CENTER_JITTER = 4.5
102 | cfg.DATA.SEARCH.SCALE_JITTER = 0.5
103 | cfg.DATA.SEARCH.NUMBER = 1
104 | # DATA.TEMPLATE
105 | cfg.DATA.TEMPLATE = edict()
106 | cfg.DATA.TEMPLATE.NUMBER = 1
107 | cfg.DATA.TEMPLATE.SIZE = 128
108 | cfg.DATA.TEMPLATE.FACTOR = 2.0
109 | cfg.DATA.TEMPLATE.CENTER_JITTER = 0
110 | cfg.DATA.TEMPLATE.SCALE_JITTER = 0
111 |
112 | # TEST
113 | cfg.TEST = edict()
114 | cfg.TEST.TEMPLATE_FACTOR = 2.0
115 | cfg.TEST.TEMPLATE_SIZE = 128
116 | cfg.TEST.SEARCH_FACTOR = 5.0
117 | cfg.TEST.SEARCH_SIZE = 320
118 | cfg.TEST.EPOCH = 500
119 |
120 |
121 | def _edict2dict(dest_dict, src_edict):
122 | if isinstance(dest_dict, dict) and isinstance(src_edict, dict):
123 | for k, v in src_edict.items():
124 | if not isinstance(v, edict):
125 | dest_dict[k] = v
126 | else:
127 | dest_dict[k] = {}
128 | _edict2dict(dest_dict[k], v)
129 | else:
130 | return
131 |
132 |
133 | def gen_config(config_file):
134 | cfg_dict = {}
135 | _edict2dict(cfg_dict, cfg)
136 | with open(config_file, 'w') as f:
137 | yaml.dump(cfg_dict, f, default_flow_style=False)
138 |
139 |
140 | def _update_config(base_cfg, exp_cfg):
141 | if isinstance(base_cfg, dict) and isinstance(exp_cfg, edict):
142 | for k, v in exp_cfg.items():
143 | if k in base_cfg:
144 | if not isinstance(v, dict):
145 | base_cfg[k] = v
146 | else:
147 | _update_config(base_cfg[k], v)
148 | else:
149 | raise ValueError("{} not exist in config.py".format(k))
150 | else:
151 | return
152 |
153 |
154 | def update_config_from_file(filename, base_cfg=None):
155 | exp_config = None
156 | with open(filename) as f:
157 | exp_config = edict(yaml.safe_load(f))
158 | if base_cfg is not None:
159 | _update_config(base_cfg, exp_config)
160 | else:
161 | _update_config(cfg, exp_config)
162 |
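Usage sketch (illustration only): merging a minimal experiment YAML into the defaults above. Every overridden key must already exist in cfg, otherwise _update_config raises ValueError.

import tempfile
from lib.config.ostrack import config as config_module

yaml_text = '''
TRAIN:
  BATCH_SIZE: 32
DATA:
  SEARCH:
    SIZE: 256
'''
with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:
    f.write(yaml_text)
config_module.update_config_from_file(f.name)
assert config_module.cfg.TRAIN.BATCH_SIZE == 32
assert config_module.cfg.DATA.SEARCH.SIZE == 256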
--------------------------------------------------------------------------------
/lib/models/layers/attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import trunc_normal_
5 |
6 | from lib.models.layers.rpe import generate_2d_concatenated_self_attention_relative_positional_encoding_index
7 |
8 |
9 | class Attention(nn.Module):
10 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
11 | rpe=False, z_size=7, x_size=14):
12 | super().__init__()
13 | self.num_heads = num_heads
14 | head_dim = dim // num_heads
15 | self.scale = head_dim ** -0.5
16 |
17 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
18 | self.attn_drop = nn.Dropout(attn_drop)
19 | self.proj = nn.Linear(dim, dim)
20 | self.proj_drop = nn.Dropout(proj_drop)
21 |
22 |         self.rpe = rpe
23 | if self.rpe:
24 | relative_position_index = \
25 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size],
26 | [x_size, x_size])
27 | self.register_buffer("relative_position_index", relative_position_index)
28 | # define a parameter table of relative position bias
29 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads,
30 | relative_position_index.max() + 1)))
31 | trunc_normal_(self.relative_position_bias_table, std=0.02)
32 |
33 | def forward(self, x, mask=None, return_attention=False):
34 | # x: B, N, C
35 | # mask: [B, N, ] torch.bool
36 | B, N, C = x.shape
37 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
38 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
39 |
40 | attn = (q @ k.transpose(-2, -1)) * self.scale
41 |
42 | if self.rpe:
43 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0)
44 | attn += relative_position_bias
45 |
46 | if mask is not None:
47 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'),)
48 |
49 | attn = attn.softmax(dim=-1)
50 | attn = self.attn_drop(attn)
51 |
52 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
53 | x = self.proj(x)
54 | x = self.proj_drop(x)
55 |
56 | if return_attention:
57 | return x, attn
58 | else:
59 | return x
60 |
61 |
62 | class Attention_talking_head(nn.Module):
63 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
64 | # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
65 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
66 | rpe=True, z_size=7, x_size=14):
67 | super().__init__()
68 |
69 | self.num_heads = num_heads
70 |
71 | head_dim = dim // num_heads
72 |
73 | self.scale = qk_scale or head_dim ** -0.5
74 |
75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76 | self.attn_drop = nn.Dropout(attn_drop)
77 |
78 | self.proj = nn.Linear(dim, dim)
79 |
80 | self.proj_l = nn.Linear(num_heads, num_heads)
81 | self.proj_w = nn.Linear(num_heads, num_heads)
82 |
83 | self.proj_drop = nn.Dropout(proj_drop)
84 |
85 | self.rpe = rpe
86 | if self.rpe:
87 | relative_position_index = \
88 | generate_2d_concatenated_self_attention_relative_positional_encoding_index([z_size, z_size],
89 | [x_size, x_size])
90 | self.register_buffer("relative_position_index", relative_position_index)
91 | # define a parameter table of relative position bias
92 | self.relative_position_bias_table = nn.Parameter(torch.empty((num_heads,
93 | relative_position_index.max() + 1)))
94 | trunc_normal_(self.relative_position_bias_table, std=0.02)
95 |
96 | def forward(self, x, mask=None):
97 | B, N, C = x.shape
98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
99 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
100 |
101 | attn = (q @ k.transpose(-2, -1))
102 |
103 | if self.rpe:
104 | relative_position_bias = self.relative_position_bias_table[:, self.relative_position_index].unsqueeze(0)
105 | attn += relative_position_bias
106 |
107 | if mask is not None:
108 | attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2),
109 | float('-inf'),)
110 |
111 | attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
112 |
113 | attn = attn.softmax(dim=-1)
114 |
115 | attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
116 | attn = self.attn_drop(attn)
117 |
118 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
119 | x = self.proj(x)
120 | x = self.proj_drop(x)
121 | return x
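Shape-check sketch (illustration only; z_size/x_size here correspond to 128- and 256-pixel patches at stride 16 rather than the defaults):

import torch
from lib.models.layers.attn import Attention

attn = Attention(dim=768, num_heads=12, rpe=True, z_size=8, x_size=16)
x = torch.randn(2, 8 * 8 + 16 * 16, 768)         # template tokens first, then search
out = attn(x)                                    # (2, 320, 768)
out, attn_map = attn(x, return_attention=True)   # attn_map: (2, 12, 320, 320)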
--------------------------------------------------------------------------------
/lib/test/evaluation/environment.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 |
5 | class EnvSettings:
6 | def __init__(self):
7 | test_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
8 |
9 | self.results_path = '{}/tracking_results/'.format(test_path)
10 | self.segmentation_path = '{}/segmentation_results/'.format(test_path)
11 | self.network_path = '{}/networks/'.format(test_path)
12 | self.result_plot_path = '{}/result_plots/'.format(test_path)
13 | self.otb_path = ''
14 | self.nfs_path = ''
15 | self.uav_path = ''
16 | self.tpl_path = ''
17 | self.vot_path = ''
18 | self.got10k_path = ''
19 | self.lasot_path = ''
20 | self.trackingnet_path = ''
21 | self.davis_dir = ''
22 | self.youtubevos_dir = ''
23 |
24 | self.got_packed_results_path = ''
25 | self.got_reports_path = ''
26 | self.tn_packed_results_path = ''
27 |
28 |
29 | def create_default_local_file():
30 | comment = {'results_path': 'Where to store tracking results',
31 | 'network_path': 'Where tracking networks are stored.'}
32 |
33 | path = os.path.join(os.path.dirname(__file__), 'local.py')
34 | with open(path, 'w') as f:
35 | settings = EnvSettings()
36 |
37 |         f.write('from lib.test.evaluation.environment import EnvSettings\n\n')
38 | f.write('def local_env_settings():\n')
39 | f.write(' settings = EnvSettings()\n\n')
40 | f.write(' # Set your local paths here.\n\n')
41 |
42 | for attr in dir(settings):
43 | comment_str = None
44 | if attr in comment:
45 | comment_str = comment[attr]
46 | attr_val = getattr(settings, attr)
47 | if not attr.startswith('__') and not callable(attr_val):
48 | if comment_str is None:
49 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
50 | else:
51 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
52 | f.write('\n return settings\n\n')
53 |
54 |
55 | class EnvSettings_ITP:
56 | def __init__(self, workspace_dir, data_dir, save_dir):
57 | self.prj_dir = workspace_dir
58 | self.save_dir = save_dir
59 | self.results_path = os.path.join(save_dir, 'test/tracking_results')
60 | self.segmentation_path = os.path.join(save_dir, 'test/segmentation_results')
61 | self.network_path = os.path.join(save_dir, 'test/networks')
62 | self.result_plot_path = os.path.join(save_dir, 'test/result_plots')
63 | self.otb_path = os.path.join(data_dir, 'otb')
64 | self.nfs_path = os.path.join(data_dir, 'nfs')
65 | self.uav_path = os.path.join(data_dir, 'uav')
66 | self.tc128_path = os.path.join(data_dir, 'TC128')
67 | self.tpl_path = ''
68 | self.vot_path = os.path.join(data_dir, 'VOT2019')
69 | self.got10k_path = os.path.join(data_dir, 'got10k')
70 | self.got10k_lmdb_path = os.path.join(data_dir, 'got10k_lmdb')
71 | self.lasot_path = os.path.join(data_dir, 'lasot')
72 | self.lasot_lmdb_path = os.path.join(data_dir, 'lasot_lmdb')
73 | self.trackingnet_path = os.path.join(data_dir, 'trackingnet')
74 | self.vot18_path = os.path.join(data_dir, 'vot2018')
75 | self.vot22_path = os.path.join(data_dir, 'vot2022')
76 | self.itb_path = os.path.join(data_dir, 'itb')
77 | self.tnl2k_path = os.path.join(data_dir, 'tnl2k')
78 |         self.lasot_extension_subset_path = os.path.join(data_dir, 'lasot_extension_subset')
79 | self.davis_dir = ''
80 | self.youtubevos_dir = ''
81 |
82 | self.got_packed_results_path = ''
83 | self.got_reports_path = ''
84 | self.tn_packed_results_path = ''
85 |
86 |
87 | def create_default_local_file_ITP_test(workspace_dir, data_dir, save_dir):
88 | comment = {'results_path': 'Where to store tracking results',
89 | 'network_path': 'Where tracking networks are stored.'}
90 |
91 | path = os.path.join(os.path.dirname(__file__), 'local.py')
92 | with open(path, 'w') as f:
93 | settings = EnvSettings_ITP(workspace_dir, data_dir, save_dir)
94 |
95 | f.write('from lib.test.evaluation.environment import EnvSettings\n\n')
96 | f.write('def local_env_settings():\n')
97 | f.write(' settings = EnvSettings()\n\n')
98 | f.write(' # Set your local paths here.\n\n')
99 |
100 | for attr in dir(settings):
101 | comment_str = None
102 | if attr in comment:
103 | comment_str = comment[attr]
104 | attr_val = getattr(settings, attr)
105 | if not attr.startswith('__') and not callable(attr_val):
106 | if comment_str is None:
107 | f.write(' settings.{} = \'{}\'\n'.format(attr, attr_val))
108 | else:
109 | f.write(' settings.{} = \'{}\' # {}\n'.format(attr, attr_val, comment_str))
110 | f.write('\n return settings\n\n')
111 |
112 |
113 | def env_settings():
114 | env_module_name = 'lib.test.evaluation.local'
115 | try:
116 | env_module = importlib.import_module(env_module_name)
117 | return env_module.local_env_settings()
118 |     except ImportError:  # narrow the bare except so real errors inside local.py still surface
119 | env_file = os.path.join(os.path.dirname(__file__), 'local.py')
120 |
121 | # Create a default file
122 | create_default_local_file()
123 | raise RuntimeError('YOU HAVE NOT SETUP YOUR local.py!!!\n Go to "{}" and set all the paths you need. '
124 | 'Then try to run again.'.format(env_file))
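Usage sketch (illustration only), mirroring the training-side helper: the first call generates local.py and raises; afterwards it returns the filled-in settings.

from lib.test.evaluation.environment import env_settings

settings = env_settings()
print(settings.results_path)    # defaults to <repo>/lib/test/tracking_results/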
--------------------------------------------------------------------------------
/lib/train/run_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import importlib
5 | import cv2 as cv
6 | import torch.backends.cudnn
7 | import torch.distributed as dist
8 |
9 | import random
10 | import numpy as np
11 | torch.backends.cudnn.benchmark = False
12 |
13 | import _init_paths
14 | import lib.train.admin.settings as ws_settings
15 |
16 |
17 | def init_seeds(seed):
18 |     print('seed:', seed)
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 | torch.backends.cudnn.benchmark = False
25 |
26 |
27 | def run_training(script_name, config_name, cudnn_benchmark=True, local_rank=-1, save_dir=None, base_seed=None,
28 |                  use_lmdb=False, script_name_prv=None, config_name_prv=None, use_wandb=False,
29 |                  distill=None, datr=None, script_teacher=None, config_teacher=None):
30 | """Run the train script.
31 | args:
32 |         script_name: Name of the experiment in the "experiments/" folder.
33 |         config_name: Name of the yaml config file inside that experiment folder.
34 | cudnn_benchmark: Use cudnn benchmark or not (default is True).
35 | """
36 |     if save_dir is None:
37 |         print("save_dir is not given. Use the current directory instead."); save_dir = '.'  # avoid os.path.abspath(None) below
38 | # This is needed to avoid strange crashes related to opencv
39 | cv.setNumThreads(0)
40 |
41 | torch.backends.cudnn.benchmark = cudnn_benchmark
42 |
43 | print('script_name: {}.py config_name: {}.yaml'.format(script_name, config_name))
44 |
45 | '''2021.1.5 set seed for different process'''
46 | if base_seed is not None:
47 | if local_rank != -1:
48 | init_seeds(base_seed + local_rank)
49 | else:
50 | init_seeds(base_seed)
51 |
52 | settings = ws_settings.Settings()
53 | settings.script_name = script_name
54 | settings.config_name = config_name
55 | settings.project_path = 'train/{}/{}'.format(script_name, config_name)
56 | if script_name_prv is not None and config_name_prv is not None:
57 | settings.project_path_prv = 'train/{}/{}'.format(script_name_prv, config_name_prv)
58 | settings.local_rank = local_rank
59 | settings.save_dir = os.path.abspath(save_dir)
60 | settings.use_lmdb = use_lmdb
61 | prj_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
62 | settings.cfg_file = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_name, config_name))
63 | settings.use_wandb = use_wandb
64 |     if datr == 1:
65 | expr_module = importlib.import_module('lib.train.train_script_datr')
66 | elif distill:
67 | settings.distill = distill
68 | settings.script_teacher = script_teacher
69 | settings.config_teacher = config_teacher
70 | if script_teacher is not None and config_teacher is not None:
71 | settings.project_path_teacher = 'train/{}/{}'.format(script_teacher, config_teacher)
72 | settings.cfg_file_teacher = os.path.join(prj_dir, 'experiments/%s/%s.yaml' % (script_teacher, config_teacher))
73 | expr_module = importlib.import_module('lib.train.train_script_distill')
74 | else:
75 | expr_module = importlib.import_module('lib.train.train_script')
76 | expr_func = getattr(expr_module, 'run')
77 | expr_func(settings)
78 |
79 |
80 | def main():
81 |     parser = argparse.ArgumentParser(description='Run a training script from train_settings.')
82 | parser.add_argument('--script', type=str, required=True, help='Name of the train script.')
83 | parser.add_argument('--config', type=str, required=True, help="Name of the config file.")
84 |     parser.add_argument('--cudnn_benchmark', type=int, choices=[0, 1], default=1, help='Set cudnn benchmark on (1) or off (0) (default is on).')  # type=bool would treat any non-empty string as True
85 | parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
86 | parser.add_argument('--save_dir', type=str, help='the directory to save checkpoints and logs')
87 | parser.add_argument('--seed', type=int, default=42, help='seed for random numbers')
88 | parser.add_argument('--use_lmdb', type=int, choices=[0, 1], default=0) # whether datasets are in lmdb format
89 | parser.add_argument('--script_prv', type=str, default=None, help='Name of the train script of previous model.')
90 | parser.add_argument('--config_prv', type=str, default=None, help="Name of the config file of previous model.")
91 | parser.add_argument('--use_wandb', type=int, choices=[0, 1], default=0) # whether to use wandb
92 | # for knowledge distillation
93 | parser.add_argument('--distill', type=int, choices=[0, 1], default=0) # whether to use knowledge distillation
94 | parser.add_argument('--datr', type=int, choices=[0, 1,2], default=0) # whether to use our data augmentation datr
95 | parser.add_argument('--script_teacher', type=str, help='teacher script name')
96 | parser.add_argument('--config_teacher', type=str, help='teacher yaml configure file name')
97 |
98 | args = parser.parse_args()
99 | if args.local_rank != -1:
100 | dist.init_process_group(backend='nccl')
101 | torch.cuda.set_device(args.local_rank)
102 | else:
103 | torch.cuda.set_device(0)
104 |     run_training(args.script, args.config, cudnn_benchmark=bool(args.cudnn_benchmark),
105 | local_rank=args.local_rank, save_dir=args.save_dir, base_seed=args.seed,
106 | use_lmdb=args.use_lmdb, script_name_prv=args.script_prv, config_name_prv=args.config_prv,
107 | use_wandb=args.use_wandb,
108 |                  distill=args.distill, datr=args.datr, script_teacher=args.script_teacher, config_teacher=args.config_teacher)
109 |
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
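Usage sketches (illustration only; the save_dir is hypothetical):

# single-GPU run
python lib/train/run_training.py --script ostrack \
    --config vitb_256_mae_ce_32x4_ep300 --save_dir ./output

# multi-GPU run: the legacy launcher passes --local_rank to every process,
# which triggers the NCCL process-group setup in main()
python -m torch.distributed.launch --nproc_per_node 4 lib/train/run_training.py \
    --script ostrack --config vitb_256_mae_ce_32x4_ep300 --save_dir ./output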
--------------------------------------------------------------------------------
/lib/vis/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import cv2
5 |
6 |
7 | def draw_figure(fig):
8 | fig.canvas.draw()
9 | fig.canvas.flush_events()
10 | plt.pause(0.001)
11 |
12 |
13 | def show_tensor(a: torch.Tensor, fig_num = None, title = None, range=(None, None), ax=None):
14 | """Display a 2D tensor.
15 | args:
16 | fig_num: Figure number.
17 | title: Title of figure.
18 | """
19 | a_np = a.squeeze().cpu().clone().detach().numpy()
20 | if a_np.ndim == 3:
21 | a_np = np.transpose(a_np, (1, 2, 0))
22 |
23 | if ax is None:
24 | fig = plt.figure(fig_num)
25 | plt.tight_layout()
26 | plt.cla()
27 | plt.imshow(a_np, vmin=range[0], vmax=range[1])
28 | plt.axis('off')
29 | plt.axis('equal')
30 | if title is not None:
31 | plt.title(title)
32 | draw_figure(fig)
33 | else:
34 | ax.cla()
35 | ax.imshow(a_np, vmin=range[0], vmax=range[1])
36 | ax.set_axis_off()
37 | ax.axis('equal')
38 | if title is not None:
39 | ax.set_title(title)
40 | draw_figure(plt.gcf())
41 |
42 |
43 | def plot_graph(a: torch.Tensor, fig_num = None, title = None):
44 | """Plot graph. Data is a 1D tensor.
45 | args:
46 | fig_num: Figure number.
47 | title: Title of figure.
48 | """
49 | a_np = a.squeeze().cpu().clone().detach().numpy()
50 | if a_np.ndim > 1:
51 | raise ValueError
52 | fig = plt.figure(fig_num)
53 | # plt.tight_layout()
54 | plt.cla()
55 | plt.plot(a_np)
56 | if title is not None:
57 | plt.title(title)
58 | draw_figure(fig)
59 |
60 |
61 | def show_image_with_boxes(im, boxes, iou_pred=None, disp_ids=None):
62 | im_np = im.clone().cpu().squeeze().numpy()
63 | im_np = np.ascontiguousarray(im_np.transpose(1, 2, 0).astype(np.uint8))
64 |
65 | boxes = boxes.view(-1, 4).cpu().numpy().round().astype(int)
66 |
67 | # Draw proposals
68 | for i_ in range(boxes.shape[0]):
69 | if disp_ids is None or disp_ids[i_]:
70 | bb = boxes[i_, :]
71 | disp_color = (i_*38 % 256, (255 - i_*97) % 256, (123 + i_*66) % 256)
72 | cv2.rectangle(im_np, (bb[0], bb[1]), (bb[0] + bb[2], bb[1] + bb[3]),
73 | disp_color, 1)
74 |
75 | if iou_pred is not None:
76 | text_pos = (bb[0], bb[1] - 5)
77 | cv2.putText(im_np, 'ID={} IOU = {:3.2f}'.format(i_, iou_pred[i_]), text_pos,
78 | cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, bottomLeftOrigin=False)
79 |
80 | im_tensor = torch.from_numpy(im_np.transpose(2, 0, 1)).float()
81 |
82 | return im_tensor
83 |
84 |
85 |
86 | def _pascal_color_map(N=256, normalized=False):
87 | """
88 | Python implementation of the color map function for the PASCAL VOC data set.
89 | Official Matlab version can be found in the PASCAL VOC devkit
90 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
91 | """
92 |
93 | def bitget(byteval, idx):
94 | return (byteval & (1 << idx)) != 0
95 |
96 | dtype = 'float32' if normalized else 'uint8'
97 | cmap = np.zeros((N, 3), dtype=dtype)
98 | for i in range(N):
99 | r = g = b = 0
100 | c = i
101 | for j in range(8):
102 | r = r | (bitget(c, 0) << 7 - j)
103 | g = g | (bitget(c, 1) << 7 - j)
104 | b = b | (bitget(c, 2) << 7 - j)
105 | c = c >> 3
106 |
107 | cmap[i] = np.array([r, g, b])
108 |
109 | cmap = cmap / 255 if normalized else cmap
110 | return cmap
111 |
112 |
113 | def overlay_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None):
114 | """ Overlay mask over image.
115 | Source: https://github.com/albertomontesg/davis-interactive/blob/master/davisinteractive/utils/visualization.py
116 | This function allows you to overlay a mask over an image with some
117 | transparency.
118 | # Arguments
119 | im: Numpy Array. Array with the image. The shape must be (H, W, 3) and
120 | the pixels must be represented as `np.uint8` data type.
121 | ann: Numpy Array. Array with the mask. The shape must be (H, W) and the
122 |         values must be integers
123 | alpha: Float. Proportion of alpha to apply at the overlaid mask.
124 | colors: Numpy Array. Optional custom colormap. It must have shape (N, 3)
125 | being N the maximum number of colors to represent.
126 | contour_thickness: Integer. Thickness of each object index contour draw
127 | over the overlay. This function requires to have installed the
128 | package `opencv-python`.
129 | # Returns
130 | Numpy Array: Image of the overlay with shape (H, W, 3) and data type
131 | `np.uint8`.
132 | """
133 |     im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=int)  # np.int was removed in NumPy 1.24
134 | if im.shape[:-1] != ann.shape:
135 | raise ValueError('First two dimensions of `im` and `ann` must match')
136 | if im.shape[-1] != 3:
137 |         raise ValueError('im must have three channels at the third dimension')
138 |
139 |     colors = _pascal_color_map() if colors is None else colors  # avoid ambiguous truth value for array inputs
140 | colors = np.asarray(colors, dtype=np.uint8)
141 |
142 | mask = colors[ann]
143 | fg = im * alpha + (1 - alpha) * mask
144 |
145 | img = im.copy()
146 | img[ann > 0] = fg[ann > 0]
147 |
148 | if contour_thickness: # pragma: no cover
149 | import cv2
150 | for obj_id in np.unique(ann[ann > 0]):
151 | contours = cv2.findContours((ann == obj_id).astype(
152 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
153 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(),
154 | contour_thickness)
155 | return img
156 |
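Usage sketch (illustration only): overlay a toy two-object mask on a random image.

import numpy as np
from lib.vis.plotting import overlay_mask

im = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
ann = np.zeros((64, 64), dtype=np.int64)        # 0 = background
ann[8:24, 8:24] = 1
ann[32:56, 32:56] = 2
out = overlay_mask(im, ann, alpha=0.5, contour_thickness=1)   # (64, 64, 3) uint8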
--------------------------------------------------------------------------------
/lib/utils/heapmap_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def generate_heatmap(bboxes, patch_size=320, stride=16):
6 | """
7 |     Generate ground-truth heatmaps in the same way as CenterNet
8 | Args:
9 | bboxes (torch.Tensor): shape of [num_search, bs, 4]
10 |
11 | Returns:
12 | gaussian_maps: list of generated heatmap
13 |
14 | """
15 | gaussian_maps = []
16 | heatmap_size = patch_size // stride
17 | for single_patch_bboxes in bboxes:
18 | bs = single_patch_bboxes.shape[0]
19 | gt_scoremap = torch.zeros(bs, heatmap_size, heatmap_size)
20 | classes = torch.arange(bs).to(torch.long)
21 | bbox = single_patch_bboxes * heatmap_size
22 | wh = bbox[:, 2:]
23 | centers_int = (bbox[:, :2] + wh / 2).round()
24 | CenterNetHeatMap.generate_score_map(gt_scoremap, classes, wh, centers_int, 0.7)
25 | gaussian_maps.append(gt_scoremap.to(bbox.device))
26 | return gaussian_maps
27 |
28 |
29 | class CenterNetHeatMap(object):
30 | @staticmethod
31 | def generate_score_map(fmap, gt_class, gt_wh, centers_int, min_overlap):
32 | radius = CenterNetHeatMap.get_gaussian_radius(gt_wh, min_overlap)
33 | radius = torch.clamp_min(radius, 0)
34 | radius = radius.type(torch.int).cpu().numpy()
35 | for i in range(gt_class.shape[0]):
36 | channel_index = gt_class[i]
37 | CenterNetHeatMap.draw_gaussian(fmap[channel_index], centers_int[i], radius[i])
38 |
39 | @staticmethod
40 | def get_gaussian_radius(box_size, min_overlap):
41 | """
42 |     copied from CornerNet
43 | box_size (w, h), it could be a torch.Tensor, numpy.ndarray, list or tuple
44 | notice: we are using a bug-version, please refer to fix bug version in CornerNet
45 | """
46 | # box_tensor = torch.Tensor(box_size)
47 | box_tensor = box_size
48 | width, height = box_tensor[..., 0], box_tensor[..., 1]
49 |
50 | a1 = 1
51 | b1 = height + width
52 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
53 | sq1 = torch.sqrt(b1 ** 2 - 4 * a1 * c1)
54 | r1 = (b1 + sq1) / 2
55 |
56 | a2 = 4
57 | b2 = 2 * (height + width)
58 | c2 = (1 - min_overlap) * width * height
59 | sq2 = torch.sqrt(b2 ** 2 - 4 * a2 * c2)
60 | r2 = (b2 + sq2) / 2
61 |
62 | a3 = 4 * min_overlap
63 | b3 = -2 * min_overlap * (height + width)
64 | c3 = (min_overlap - 1) * width * height
65 | sq3 = torch.sqrt(b3 ** 2 - 4 * a3 * c3)
66 | r3 = (b3 + sq3) / 2
67 |
68 | return torch.min(r1, torch.min(r2, r3))
69 |
70 | @staticmethod
71 | def gaussian2D(radius, sigma=1):
72 | # m, n = [(s - 1.) / 2. for s in shape]
73 | m, n = radius
74 | y, x = np.ogrid[-m: m + 1, -n: n + 1]
75 |
76 | gauss = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
77 | gauss[gauss < np.finfo(gauss.dtype).eps * gauss.max()] = 0
78 | return gauss
79 |
80 | @staticmethod
81 | def draw_gaussian(fmap, center, radius, k=1):
82 | diameter = 2 * radius + 1
83 | gaussian = CenterNetHeatMap.gaussian2D((radius, radius), sigma=diameter / 6)
84 | gaussian = torch.Tensor(gaussian)
85 | x, y = int(center[0]), int(center[1])
86 | height, width = fmap.shape[:2]
87 |
88 | left, right = min(x, radius), min(width - x, radius + 1)
89 | top, bottom = min(y, radius), min(height - y, radius + 1)
90 |
91 | masked_fmap = fmap[y - top: y + bottom, x - left: x + right]
92 | masked_gaussian = gaussian[radius - top: radius + bottom, radius - left: radius + right]
93 | if min(masked_gaussian.shape) > 0 and min(masked_fmap.shape) > 0:
94 | masked_fmap = torch.max(masked_fmap, masked_gaussian * k)
95 | fmap[y - top: y + bottom, x - left: x + right] = masked_fmap
96 | # return fmap
97 |
98 |
99 | def compute_grids(features, strides):
100 | """
101 |     grids with respect to the input image size
102 | """
103 | grids = []
104 | for level, feature in enumerate(features):
105 | h, w = feature.size()[-2:]
106 | shifts_x = torch.arange(
107 | 0, w * strides[level],
108 | step=strides[level],
109 | dtype=torch.float32, device=feature.device)
110 | shifts_y = torch.arange(
111 | 0, h * strides[level],
112 | step=strides[level],
113 | dtype=torch.float32, device=feature.device)
114 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
115 | shift_x = shift_x.reshape(-1)
116 | shift_y = shift_y.reshape(-1)
117 | grids_per_level = torch.stack((shift_x, shift_y), dim=1) + \
118 | strides[level] // 2
119 | grids.append(grids_per_level)
120 | return grids
121 |
122 |
123 | def get_center3x3(locations, centers, strides, range=3):
124 | '''
125 | Inputs:
126 | locations: M x 2
127 | centers: N x 2
128 | strides: M
129 | '''
130 | range = (range - 1) / 2
131 | M, N = locations.shape[0], centers.shape[0]
132 | locations_expanded = locations.view(M, 1, 2).expand(M, N, 2) # M x N x 2
133 | centers_expanded = centers.view(1, N, 2).expand(M, N, 2) # M x N x 2
134 | strides_expanded = strides.view(M, 1, 1).expand(M, N, 2) # M x N
135 | centers_discret = ((centers_expanded / strides_expanded).int() * strides_expanded).float() + \
136 | strides_expanded / 2 # M x N x 2
137 | dist_x = (locations_expanded[:, :, 0] - centers_discret[:, :, 0]).abs()
138 | dist_y = (locations_expanded[:, :, 1] - centers_discret[:, :, 1]).abs()
139 | return (dist_x <= strides_expanded[:, :, 0] * range) & \
140 | (dist_y <= strides_expanded[:, :, 0] * range)
141 |
142 |
143 | def get_pred(score_map_ctr, size_map, offset_map, feat_size):
144 | max_score, idx = torch.max(score_map_ctr.flatten(1), dim=1, keepdim=True)
145 |
146 | idx = idx.unsqueeze(1).expand(idx.shape[0], 2, 1)
147 | size = size_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)
148 | offset = offset_map.flatten(2).gather(dim=2, index=idx).squeeze(-1)
149 |
150 | return size * feat_size, offset
151 |
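Usage sketch (illustration only): boxes are (x, y, w, h) normalized to the search patch, as implied by the center computation above.

import torch
from lib.utils.heapmap_utils import generate_heatmap

bboxes = torch.tensor([[[0.25, 0.25, 0.2, 0.2],
                        [0.50, 0.40, 0.3, 0.3]]])   # [num_search=1, bs=2, 4]
maps = generate_heatmap(bboxes, patch_size=320, stride=16)
print(maps[0].shape)   # torch.Size([2, 20, 20]): one 20x20 Gaussian per sample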
--------------------------------------------------------------------------------
/lib/models/layers/attn_blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_
5 |
6 | from lib.models.layers.attn import Attention
7 |
8 |
9 | def candidate_elimination(attn: torch.Tensor, tokens: torch.Tensor, lens_t: int, keep_ratio: float, global_index: torch.Tensor, box_mask_z: torch.Tensor):
10 | """
11 | Eliminate potential background candidates for computation reduction and noise cancellation.
12 | Args:
13 | attn (torch.Tensor): [B, num_heads, L_t + L_s, L_t + L_s], attention weights
14 | tokens (torch.Tensor): [B, L_t + L_s, C], template and search region tokens
15 | lens_t (int): length of template
16 | keep_ratio (float): keep ratio of search region tokens (candidates)
17 | global_index (torch.Tensor): global index of search region tokens
18 | box_mask_z (torch.Tensor): template mask used to accumulate attention weights
19 |
20 | Returns:
21 | tokens_new (torch.Tensor): tokens after candidate elimination
22 | keep_index (torch.Tensor): indices of kept search region tokens
23 | removed_index (torch.Tensor): indices of removed search region tokens
24 | """
25 | lens_s = attn.shape[-1] - lens_t
26 | bs, hn, _, _ = attn.shape
27 |
28 | lens_keep = math.ceil(keep_ratio * lens_s)
29 | if lens_keep == lens_s:
30 | return tokens, global_index, None
31 |
32 | attn_t = attn[:, :, :lens_t, lens_t:]
33 |
34 | if box_mask_z is not None:
35 | box_mask_z = box_mask_z.unsqueeze(1).unsqueeze(-1).expand(-1, attn_t.shape[1], -1, attn_t.shape[-1])
36 | # attn_t = attn_t[:, :, box_mask_z, :]
37 | attn_t = attn_t[box_mask_z]
38 | attn_t = attn_t.view(bs, hn, -1, lens_s)
39 |         attn_t = attn_t.mean(dim=2).mean(dim=1)  # (B, H, L_t, L_s) --> (B, L_s)
40 |
41 | # attn_t = [attn_t[i, :, box_mask_z[i, :], :] for i in range(attn_t.size(0))]
42 | # attn_t = [attn_t[i].mean(dim=1).mean(dim=0) for i in range(len(attn_t))]
43 | # attn_t = torch.stack(attn_t, dim=0)
44 | else:
45 |         attn_t = attn_t.mean(dim=2).mean(dim=1)  # (B, H, L_t, L_s) --> (B, L_s)
46 |
47 | # use sort instead of topk, due to the speed issue
48 | # https://github.com/pytorch/pytorch/issues/22812
49 | sorted_attn, indices = torch.sort(attn_t, dim=1, descending=True)
50 |
51 | topk_attn, topk_idx = sorted_attn[:, :lens_keep], indices[:, :lens_keep]
52 | non_topk_attn, non_topk_idx = sorted_attn[:, lens_keep:], indices[:, lens_keep:]
53 |
54 | keep_index = global_index.gather(dim=1, index=topk_idx)
55 | removed_index = global_index.gather(dim=1, index=non_topk_idx)
56 |
57 | # separate template and search tokens
58 | tokens_t = tokens[:, :lens_t]
59 | tokens_s = tokens[:, lens_t:]
60 |
61 | # obtain the attentive and inattentive tokens
62 | B, L, C = tokens_s.shape
63 | # topk_idx_ = topk_idx.unsqueeze(-1).expand(B, lens_keep, C)
64 | attentive_tokens = tokens_s.gather(dim=1, index=topk_idx.unsqueeze(-1).expand(B, -1, C))
65 | # inattentive_tokens = tokens_s.gather(dim=1, index=non_topk_idx.unsqueeze(-1).expand(B, -1, C))
66 |
67 | # compute the weighted combination of inattentive tokens
68 | # fused_token = non_topk_attn @ inattentive_tokens
69 |
70 | # concatenate these tokens
71 | # tokens_new = torch.cat([tokens_t, attentive_tokens, fused_token], dim=0)
72 | tokens_new = torch.cat([tokens_t, attentive_tokens], dim=1)
73 |
74 | return tokens_new, keep_index, removed_index
75 |
76 |
77 | class CEBlock(nn.Module):
78 |
79 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
80 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, keep_ratio_search=1.0,):
81 | super().__init__()
82 | self.norm1 = norm_layer(dim)
83 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
84 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
85 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
86 | self.norm2 = norm_layer(dim)
87 | mlp_hidden_dim = int(dim * mlp_ratio)
88 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
89 |
90 | self.keep_ratio_search = keep_ratio_search
91 |
92 | def forward(self, x, global_index_template, global_index_search, mask=None, ce_template_mask=None, keep_ratio_search=None):
93 | x_attn, attn = self.attn(self.norm1(x), mask, True)
94 | x = x + self.drop_path(x_attn)
95 | lens_t = global_index_template.shape[1]
96 |
97 | removed_index_search = None
98 | if self.keep_ratio_search < 1 and (keep_ratio_search is None or keep_ratio_search < 1):
99 | keep_ratio_search = self.keep_ratio_search if keep_ratio_search is None else keep_ratio_search
100 | x, global_index_search, removed_index_search = candidate_elimination(attn, x, lens_t, keep_ratio_search, global_index_search, ce_template_mask)
101 |
102 | x = x + self.drop_path(self.mlp(self.norm2(x)))
103 | return x, global_index_template, global_index_search, removed_index_search, attn
104 |
105 |
106 | class Block(nn.Module):
107 |
108 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
109 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
110 | super().__init__()
111 | self.norm1 = norm_layer(dim)
112 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
113 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
114 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
115 | self.norm2 = norm_layer(dim)
116 | mlp_hidden_dim = int(dim * mlp_ratio)
117 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
118 |
119 | def forward(self, x, mask=None):
120 | x = x + self.drop_path(self.attn(self.norm1(x), mask))
121 | x = x + self.drop_path(self.mlp(self.norm2(x)))
122 | return x
123 |
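Usage sketch (illustration only): one CEBlock keeping 70% of 256 search tokens. The global indices here are float tensors of original token positions (an assumption about the caller, consistent with the gather calls above).

import torch
from lib.models.layers.attn_blocks import CEBlock

blk = CEBlock(dim=768, num_heads=12, keep_ratio_search=0.7)
lens_t, lens_s = 64, 256                             # 8x8 template, 16x16 search
x = torch.randn(2, lens_t + lens_s, 768)
gi_t = torch.arange(lens_t, dtype=torch.float)[None].repeat(2, 1)
gi_s = torch.arange(lens_s, dtype=torch.float)[None].repeat(2, 1)
x, gi_t, gi_s, removed, attn = blk(x, gi_t, gi_s)
print(x.shape, gi_s.shape, removed.shape)            # (2, 244, 768) (2, 180) (2, 76)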
--------------------------------------------------------------------------------