├── .gitignore ├── README.md ├── experiments ├── siamgat_ct_googlenet_alldataset │ └── config.yaml ├── siamgat_ct_googlenet_got10k │ └── config.yaml ├── siamgat_googlenet │ └── config.yaml ├── siamgat_googlenet_got10k │ └── config.yaml ├── siamgat_googlenet_lasot │ └── config.yaml └── siamgat_googlenet_trackingnet │ └── config.yaml ├── pysot ├── __init__.py ├── __pycache__ │ └── __init__.cpython-35.pyc ├── core │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-35.pyc │ └── config.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── augmentation.cpython-35.pyc │ │ └── dataset.cpython-35.pyc │ ├── augmentation.py │ └── dataset.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ └── loss_car.cpython-35.pyc │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── googlenet.cpython-35.pyc │ │ │ └── googlenet_ou.cpython-35.pyc │ │ ├── googlenet.py │ │ ├── googlenet_ct.py │ │ └── googlenet_ou.py │ ├── head │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ └── car_head.cpython-35.pyc │ │ ├── car_head.py │ │ └── head_utils.py │ ├── init_weight.py │ ├── loss_car.py │ ├── model_builder_gat.py │ ├── model_builder_gat_ct.py │ ├── model_builder_gat_got.py │ └── neck │ │ ├── __init__.py │ │ └── neck.py ├── tracker │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── base_tracker.cpython-35.pyc │ │ ├── siamgat_tracker.cpython-35.pyc │ │ └── siamgat_tracker_hp_search.cpython-35.pyc │ ├── base_tracker.py │ └── siamgat_tracker.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── average_meter.cpython-35.pyc │ ├── bbox.cpython-35.pyc │ ├── distributed.cpython-35.pyc │ ├── location_grid.cpython-35.pyc │ ├── log_helper.cpython-35.pyc │ ├── lr_scheduler.cpython-35.pyc │ ├── misc.cpython-35.pyc │ └── model_load.cpython-35.pyc │ ├── average_meter.py │ ├── bbox.py │ ├── distributed.py │ ├── location_grid.py │ ├── log_helper.py │ ├── lr_scheduler.py │ ├── misc.py │ └── model_load.py ├── requirement.txt ├── toolkit ├── __init__.py ├── __pycache__ │ └── __init__.cpython-35.pyc ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── dataset.cpython-35.pyc │ │ ├── got10k.cpython-35.pyc │ │ ├── lasot.cpython-35.pyc │ │ ├── otb.cpython-35.pyc │ │ ├── uav.cpython-35.pyc │ │ ├── video.cpython-35.pyc │ │ └── vot.cpython-35.pyc │ ├── dataset.py │ ├── got10k.py │ ├── lasot.py │ ├── otb.py │ ├── uav.py │ └── video.py ├── evaluation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ └── ope_benchmark.cpython-35.pyc │ └── ope_benchmark.py ├── utils │ ├── __pycache__ │ │ └── statistics.cpython-35.pyc │ └── statistics.py └── visualization │ ├── __init__.py │ ├── draw_success_precision.py │ └── draw_utils.py ├── tools ├── eval.py ├── testTracker.py └── train.py └── training_dataset ├── coco ├── gen_json.py ├── par_crop.py ├── pycocotools │ ├── Makefile │ ├── __init__.py │ ├── _mask.c │ ├── _mask.pyx │ ├── coco.py │ ├── cocoeval.py │ ├── common │ │ ├── gason.cpp │ │ ├── gason.h │ │ ├── maskApi.c │ │ └── maskApi.h │ ├── mask.py │ └── setup.py └── readme.md ├── det ├── gen_json.py ├── par_crop.py └── readme.md ├── got10k ├── gen_json.py ├── par_crop.py └── readme.md ├── lasot ├── gen_json.py ├── par_crop.py └── readme.md ├── trackingnet ├── gen_json.py ├── par_crop.py └── readme.md ├── vid ├── gen_json.py ├── par_crop.py ├── parse_vid.py └── readme.md └── yt_bb ├── check.py 
├── checknum.py ├── gen_json.py ├── par_crop.py └── readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /pysot/core/configTest.py 3 | /tools/* 4 | *.txt 5 | !/tools/testTracker.py 6 | !/tools/eval.py 7 | !/tools/train.py 8 | !/tools/snapshot/*.pth 9 | /toolkit/datasets/trackingnet.py 10 | /toolkit/datasets/vot.py 11 | !/pretrained_models 12 | *.pth 13 | /pysot/models/model_builder_gat_Test.py 14 | /pysot/tracker/siamgat_tracker_upmis.py 15 | /setup.py 16 | !requirement.txt -------------------------------------------------------------------------------- /experiments/siamgat_ct_googlenet_alldataset/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_ct_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet_ct" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | TRAIN_EPOCH: 10 8 | LAYERS_LR: 0.1 9 | 10 | ADJUST: 11 | ADJUST: true 12 | TYPE: 'GoogLeNetAdjustLayer' 13 | KWARGS: 14 | in_channels: 768 15 | out_channels: 256 16 | crop_pad: 4 17 | 18 | TRACK: 19 | TYPE: 'SiamGATTracker' 20 | EXEMPLAR_SIZE: 127 21 | INSTANCE_SIZE: 287 22 | SCORE_SIZE: 25 23 | CONTEXT_AMOUNT: 0.5 24 | STRIDE: 8 25 | OFFSET: 45 26 | 27 | TRAIN: 28 | EPOCH: 20 29 | START_EPOCH: 0 30 | SEARCH_SIZE: 287 31 | BATCH_SIZE: 28 32 | CLS_WEIGHT: 1.0 33 | LOC_WEIGHT: 3.0 34 | CEN_WEIGHT: 1.0 35 | RESUME: '' 36 | PRETRAINED: '' 37 | NUM_CLASSES: 2 38 | NUM_CONVS: 4 39 | PRIOR_PROB: 0.01 40 | OUTPUT_SIZE: 25 41 | ATTENTION: True 42 | CHANNEL_NUM: 256 43 | 44 | LR: 45 | TYPE: 'log' 46 | KWARGS: 47 | start_lr: 0.01 48 | end_lr: 0.0005 49 | LR_WARMUP: 50 | TYPE: 'step' 51 | EPOCH: 5 52 | KWARGS: 53 | start_lr: 0.005 54 | end_lr: 0.01 55 | step: 1 56 | 57 | DATASET: 58 | NAMES: 59 | - 'VID' 60 | - 'DET' 61 | - 'COCO' 62 | - 'GOT' 63 | - 'LaSOT' 64 | - 'TrackingNet' 65 | 66 | VIDEOS_PER_EPOCH: 800000 67 | 68 | TEMPLATE: 69 | SHIFT: 4 70 | SCALE: 0.05 71 | BLUR: 0.0 72 | FLIP: 0.0 73 | COLOR: 1.0 74 | 75 | SEARCH: 76 | SHIFT: 64 77 | SCALE: 0.18 78 | BLUR: 0.2 79 | FLIP: 0.0 80 | COLOR: 1.0 81 | 82 | NEG: 0.0 83 | GRAY: 0.0 -------------------------------------------------------------------------------- /experiments/siamgat_ct_googlenet_got10k/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_ct_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet_ct" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | TRAIN_EPOCH: 10 8 | LAYERS_LR: 0.1 9 | 10 | ADJUST: 11 | ADJUST: true 12 | TYPE: 'GoogLeNetAdjustLayer' 13 | KWARGS: 14 | in_channels: 768 15 | out_channels: 256 16 | crop_pad: 4 17 | 18 | TRACK: 19 | TYPE: 'SiamGATTracker' 20 | EXEMPLAR_SIZE: 127 21 | INSTANCE_SIZE: 287 22 | SCORE_SIZE: 25 23 | CONTEXT_AMOUNT: 0.5 24 | STRIDE: 8 25 | OFFSET: 45 26 | 27 | TRAIN: 28 | EPOCH: 20 29 | START_EPOCH: 0 30 | SEARCH_SIZE: 287 31 | BATCH_SIZE: 28 32 | CLS_WEIGHT: 1.0 33 | LOC_WEIGHT: 3.0 34 | CEN_WEIGHT: 1.0 35 | RESUME: '' 36 | PRETRAINED: '' 37 | NUM_CLASSES: 2 38 | NUM_CONVS: 4 39 | PRIOR_PROB: 0.01 40 | OUTPUT_SIZE: 25 41 | ATTENTION: True 42 | CHANNEL_NUM: 256 43 | 44 | LR: 45 | TYPE: 'log' 46 | KWARGS: 47 | start_lr: 0.01 48 | end_lr: 0.0005 49 | LR_WARMUP: 50 | TYPE: 'step' 51 | 
EPOCH: 5 52 | KWARGS: 53 | start_lr: 0.005 54 | end_lr: 0.01 55 | step: 1 56 | 57 | DATASET: 58 | NAMES: 59 | - 'GOT' 60 | 61 | VIDEOS_PER_EPOCH: 600000 62 | 63 | TEMPLATE: 64 | SHIFT: 4 65 | SCALE: 0.05 66 | BLUR: 0.0 67 | FLIP: 0.0 68 | COLOR: 1.0 69 | 70 | SEARCH: 71 | SHIFT: 64 72 | SCALE: 0.18 73 | BLUR: 0.2 74 | FLIP: 0.0 75 | COLOR: 1.0 76 | 77 | NEG: 0.0 78 | GRAY: 0.0 79 | -------------------------------------------------------------------------------- /experiments/siamgat_googlenet/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet_ou" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | CHANNEL_REDUCE_LAYERS: ['channel_reduce'] 8 | TRAIN_EPOCH: 10 9 | CROP_PAD: 4 10 | LAYERS_LR: 0.1 11 | 12 | TRACK: 13 | TYPE: 'SiamGATTracker' 14 | EXEMPLAR_SIZE: 127 15 | INSTANCE_SIZE: 287 16 | SCORE_SIZE: 25 17 | CONTEXT_AMOUNT: 0.5 18 | STRIDE: 8 19 | OFFSET: 45 20 | 21 | TRAIN: 22 | EPOCH: 20 23 | START_EPOCH: 0 24 | SEARCH_SIZE: 287 25 | BATCH_SIZE: 76 26 | CLS_WEIGHT: 1.0 27 | LOC_WEIGHT: 3.0 28 | CEN_WEIGHT: 1.0 29 | RESUME: '' 30 | PRETRAINED: '' 31 | NUM_CLASSES: 2 32 | NUM_CONVS: 4 33 | PRIOR_PROB: 0.01 34 | OUTPUT_SIZE: 25 35 | ATTENTION: True 36 | 37 | LR: 38 | TYPE: 'log' 39 | KWARGS: 40 | start_lr: 0.01 41 | end_lr: 0.0005 42 | LR_WARMUP: 43 | TYPE: 'step' 44 | EPOCH: 5 45 | KWARGS: 46 | start_lr: 0.005 47 | end_lr: 0.01 48 | step: 1 49 | 50 | DATASET: 51 | NAMES: 52 | - 'VID' 53 | - 'YOUTUBEBB' 54 | - 'COCO' 55 | - 'DET' 56 | - 'GOT' 57 | 58 | VIDEOS_PER_EPOCH: 800000 59 | 60 | TEMPLATE: 61 | SHIFT: 4 62 | SCALE: 0.05 63 | BLUR: 0.0 64 | FLIP: 0.0 65 | COLOR: 1.0 66 | 67 | SEARCH: 68 | SHIFT: 64 69 | SCALE: 0.18 70 | BLUR: 0.2 71 | FLIP: 0.0 72 | COLOR: 1.0 73 | 74 | NEG: 0.2 75 | GRAY: 0.0 76 | -------------------------------------------------------------------------------- /experiments/siamgat_googlenet_got10k/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | CHANNEL_REDUCE_LAYERS: ['channel_reduce'] 8 | TRAIN_EPOCH: 10 9 | CROP_PAD: 4 10 | LAYERS_LR: 0.1 11 | 12 | TRACK: 13 | TYPE: 'SiamGATTracker' 14 | EXEMPLAR_SIZE: 127 15 | INSTANCE_SIZE: 287 16 | SCORE_SIZE: 25 17 | CONTEXT_AMOUNT: 0.5 18 | STRIDE: 8 19 | OFFSET: 45 20 | 21 | TRAIN: 22 | EPOCH: 20 23 | START_EPOCH: 0 24 | SEARCH_SIZE: 287 25 | BATCH_SIZE: 76 26 | CLS_WEIGHT: 1.0 27 | LOC_WEIGHT: 3.0 28 | CEN_WEIGHT: 1.0 29 | RESUME: '' 30 | PRETRAINED: '' 31 | NUM_CLASSES: 2 32 | NUM_CONVS: 4 33 | PRIOR_PROB: 0.01 34 | OUTPUT_SIZE: 25 35 | ATTENTION: True 36 | 37 | LR: 38 | TYPE: 'log' 39 | KWARGS: 40 | start_lr: 0.01 41 | end_lr: 0.0005 42 | LR_WARMUP: 43 | TYPE: 'step' 44 | EPOCH: 5 45 | KWARGS: 46 | start_lr: 0.005 47 | end_lr: 0.01 48 | step: 1 49 | 50 | DATASET: 51 | NAMES: 52 | - 'GOT' 53 | 54 | VIDEOS_PER_EPOCH: 600000 55 | 56 | TEMPLATE: 57 | SHIFT: 4 58 | SCALE: 0.05 59 | BLUR: 0.0 60 | FLIP: 0.0 61 | COLOR: 1.0 62 | 63 | SEARCH: 64 | SHIFT: 64 65 | SCALE: 0.18 66 | BLUR: 0.2 67 | FLIP: 0.0 68 | COLOR: 1.0 69 | 70 | NEG: 0.2 71 | GRAY: 0.0 72 | 
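Each experiment directory above pairs its config.yaml with the yacs defaults defined in pysot/core/config.py; the actual entry points are tools/train.py and tools/testTracker.py. Below is a minimal, non-authoritative sketch of how such a file is typically consumed — the config path simply reuses the GOT-10k example above, and the commented-out snapshot restore assumes the pysot-style load_pretrain helper in pysot/utils/model_load.py with a placeholder path, not a file shipped with the repo:

import torch

from pysot.core.config import cfg                        # global yacs CfgNode holding the defaults
from pysot.models.model_builder_gat import ModelBuilder  # backbone + graph-attention fusion + CAR head

# overlay an experiment config on top of the defaults (yacs CfgNode.merge_from_file)
cfg.merge_from_file('experiments/siamgat_googlenet_got10k/config.yaml')
cfg.CUDA = torch.cuda.is_available() and cfg.CUDA

model = ModelBuilder().eval()
# for tracking, a trained .pth snapshot would normally be restored here, e.g.
#   from pysot.utils.model_load import load_pretrain
#   model = load_pretrain(model, 'snapshot/checkpoint.pth')   # placeholder path
if cfg.CUDA:
    model = model.cuda()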
-------------------------------------------------------------------------------- /experiments/siamgat_googlenet_lasot/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | CHANNEL_REDUCE_LAYERS: ['channel_reduce'] 8 | TRAIN_EPOCH: 10 9 | CROP_PAD: 4 10 | LAYERS_LR: 0.1 11 | 12 | TRACK: 13 | TYPE: 'SiamGATTracker' 14 | EXEMPLAR_SIZE: 127 15 | INSTANCE_SIZE: 287 16 | SCORE_SIZE: 25 17 | CONTEXT_AMOUNT: 0.5 18 | STRIDE: 8 19 | OFFSET: 45 20 | 21 | TRAIN: 22 | EPOCH: 20 23 | START_EPOCH: 0 24 | SEARCH_SIZE: 287 25 | BATCH_SIZE: 76 26 | CLS_WEIGHT: 1.0 27 | LOC_WEIGHT: 3.0 28 | CEN_WEIGHT: 1.0 29 | RESUME: '' 30 | PRETRAINED: '' 31 | NUM_CLASSES: 2 32 | NUM_CONVS: 4 33 | PRIOR_PROB: 0.01 34 | OUTPUT_SIZE: 25 35 | ATTENTION: True 36 | 37 | LR: 38 | TYPE: 'log' 39 | KWARGS: 40 | start_lr: 0.01 41 | end_lr: 0.0005 42 | LR_WARMUP: 43 | TYPE: 'step' 44 | EPOCH: 5 45 | KWARGS: 46 | start_lr: 0.005 47 | end_lr: 0.01 48 | step: 1 49 | 50 | DATASET: 51 | NAMES: 52 | - 'LaSOT' 53 | 54 | VIDEOS_PER_EPOCH: 600000 55 | 56 | TEMPLATE: 57 | SHIFT: 4 58 | SCALE: 0.05 59 | BLUR: 0.0 60 | FLIP: 0.0 61 | COLOR: 1.0 62 | 63 | SEARCH: 64 | SHIFT: 64 65 | SCALE: 0.18 66 | BLUR: 0.2 67 | FLIP: 0.0 68 | COLOR: 1.0 69 | 70 | NEG: 0.2 71 | GRAY: 0.0 72 | -------------------------------------------------------------------------------- /experiments/siamgat_googlenet_trackingnet/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | CHANNEL_REDUCE_LAYERS: ['channel_reduce'] 8 | TRAIN_EPOCH: 10 9 | CROP_PAD: 4 10 | LAYERS_LR: 0.1 11 | 12 | TRACK: 13 | TYPE: 'SiamGATTracker' 14 | EXEMPLAR_SIZE: 127 15 | INSTANCE_SIZE: 287 16 | SCORE_SIZE: 25 17 | CONTEXT_AMOUNT: 0.5 18 | STRIDE: 8 19 | OFFSET: 45 20 | 21 | TRAIN: 22 | EPOCH: 20 23 | START_EPOCH: 0 24 | SEARCH_SIZE: 287 25 | BATCH_SIZE: 76 26 | CLS_WEIGHT: 1.0 27 | LOC_WEIGHT: 3.0 28 | CEN_WEIGHT: 1.0 29 | RESUME: '' 30 | PRETRAINED: '' 31 | NUM_CLASSES: 2 32 | NUM_CONVS: 4 33 | PRIOR_PROB: 0.01 34 | OUTPUT_SIZE: 25 35 | ATTENTION: True 36 | 37 | LR: 38 | TYPE: 'log' 39 | KWARGS: 40 | start_lr: 0.01 41 | end_lr: 0.0005 42 | LR_WARMUP: 43 | TYPE: 'step' 44 | EPOCH: 5 45 | KWARGS: 46 | start_lr: 0.005 47 | end_lr: 0.01 48 | step: 1 49 | 50 | DATASET: 51 | NAMES: 52 | - 'TrackingNet' 53 | 54 | VIDEOS_PER_EPOCH: 600000 55 | 56 | TEMPLATE: 57 | SHIFT: 4 58 | SCALE: 0.05 59 | BLUR: 0.0 60 | FLIP: 0.0 61 | COLOR: 1.0 62 | 63 | SEARCH: 64 | SHIFT: 64 65 | SCALE: 0.18 66 | BLUR: 0.2 67 | FLIP: 0.0 68 | COLOR: 1.0 69 | 70 | NEG: 0.2 71 | GRAY: 0.0 72 | -------------------------------------------------------------------------------- /pysot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/__init__.py -------------------------------------------------------------------------------- /pysot/__pycache__/__init__.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/core/__init__.py -------------------------------------------------------------------------------- /pysot/core/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/core/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/core/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from yacs.config import CfgNode as CN 9 | 10 | __C = CN() 11 | 12 | cfg = __C 13 | 14 | __C.META_ARC = "siamgat_googlenet" 15 | 16 | __C.CUDA = True 17 | 18 | # ------------------------------------------------------------------------ # 19 | # Training options 20 | # ------------------------------------------------------------------------ # 21 | __C.TRAIN = CN() 22 | 23 | __C.TRAIN.EXEMPLAR_SIZE = 127 24 | 25 | __C.TRAIN.SEARCH_SIZE = 287 26 | 27 | __C.TRAIN.OUTPUT_SIZE = 25 28 | 29 | __C.TRAIN.RESUME = '' 30 | 31 | __C.TRAIN.PRETRAINED = '' 32 | 33 | __C.TRAIN.LOG_DIR = './logs' 34 | 35 | __C.TRAIN.SNAPSHOT_DIR = './snapshot' 36 | 37 | __C.TRAIN.EPOCH = 20 38 | 39 | __C.TRAIN.START_EPOCH = 0 40 | 41 | __C.TRAIN.BATCH_SIZE = 32 42 | 43 | __C.TRAIN.NUM_WORKERS = 4 # 1 44 | 45 | __C.TRAIN.MOMENTUM = 0.9 46 | 47 | __C.TRAIN.WEIGHT_DECAY = 0.0001 48 | 49 | __C.TRAIN.CLS_WEIGHT = 1.0 50 | 51 | __C.TRAIN.LOC_WEIGHT = 3.0 52 | 53 | __C.TRAIN.CEN_WEIGHT = 1.0 54 | 55 | __C.TRAIN.PRINT_FREQ = 20 56 | 57 | __C.TRAIN.LOG_GRADS = False 58 | 59 | __C.TRAIN.GRAD_CLIP = 10.0 60 | 61 | __C.TRAIN.BASE_LR = 0.005 62 | 63 | __C.TRAIN.LR = CN() 64 | 65 | __C.TRAIN.LR.TYPE = 'log' 66 | 67 | __C.TRAIN.LR.KWARGS = CN(new_allowed=True) 68 | 69 | __C.TRAIN.LR_WARMUP = CN() 70 | 71 | __C.TRAIN.LR_WARMUP.WARMUP = True 72 | 73 | __C.TRAIN.LR_WARMUP.TYPE = 'step' 74 | 75 | __C.TRAIN.LR_WARMUP.EPOCH = 5 76 | 77 | __C.TRAIN.LR_WARMUP.KWARGS = CN(new_allowed=True) 78 | 79 | __C.TRAIN.NUM_CLASSES = 2 80 | 81 | __C.TRAIN.NUM_CONVS = 4 82 | 83 | __C.TRAIN.PRIOR_PROB = 0.01 84 | 85 | __C.TRAIN.LOSS_ALPHA = 0.25 86 | 87 | __C.TRAIN.LOSS_GAMMA = 2.0 88 | 89 | __C.TRAIN.CHANNEL_NUM = 256 90 | 91 | # ------------------------------------------------------------------------ # 92 | # Dataset options 93 | # ------------------------------------------------------------------------ # 94 | __C.DATASET = CN(new_allowed=True) 95 | 96 | # Augmentation 97 | # for template 98 | __C.DATASET.TEMPLATE = CN() 99 | 100 | # for detail discussion 101 | __C.DATASET.TEMPLATE.SHIFT = 4 102 | 103 | __C.DATASET.TEMPLATE.SCALE = 0.05 104 | 105 | __C.DATASET.TEMPLATE.BLUR = 0.0 106 | 107 | __C.DATASET.TEMPLATE.FLIP = 0.0 108 | 109 | __C.DATASET.TEMPLATE.COLOR = 1.0 110 | 111 | 
__C.DATASET.SEARCH = CN() 112 | 113 | __C.DATASET.SEARCH.SHIFT = 64 114 | 115 | __C.DATASET.SEARCH.SCALE = 0.18 116 | 117 | __C.DATASET.SEARCH.BLUR = 0.0 118 | 119 | __C.DATASET.SEARCH.FLIP = 0.0 120 | 121 | __C.DATASET.SEARCH.COLOR = 1.0 122 | 123 | # for detail discussion 124 | __C.DATASET.NEG = 0.0 125 | 126 | __C.DATASET.GRAY = 0.0 127 | 128 | __C.DATASET.NAMES = ('VID', 'COCO', 'DET', 'YOUTUBEBB', 'GOT', 'LaSOT', 'TrackingNet') 129 | 130 | __C.DATASET.VID = CN() 131 | __C.DATASET.VID.ROOT = '/PATH/TO/VID' 132 | __C.DATASET.VID.ANNO = 'training_dataset/vid/train.json' 133 | __C.DATASET.VID.FRAME_RANGE = 100 134 | __C.DATASET.VID.NUM_USE = 100000 # repeat until reach NUM_USE 135 | 136 | __C.DATASET.YOUTUBEBB = CN() 137 | __C.DATASET.YOUTUBEBB.ROOT = '/PATH/TO/YTBB' 138 | __C.DATASET.YOUTUBEBB.ANNO = 'training_dataset/yt_bb/train.json' 139 | __C.DATASET.YOUTUBEBB.FRAME_RANGE = 3 140 | __C.DATASET.YOUTUBEBB.NUM_USE = 200000 141 | 142 | __C.DATASET.COCO = CN() 143 | __C.DATASET.COCO.ROOT = '/PATH/TO/COCO' 144 | __C.DATASET.COCO.ANNO = 'training_dataset/coco/train2017.json' 145 | __C.DATASET.COCO.FRAME_RANGE = 1 146 | __C.DATASET.COCO.NUM_USE = 50000 147 | 148 | __C.DATASET.DET = CN() 149 | __C.DATASET.DET.ROOT = '/PATH/TO/DET' 150 | __C.DATASET.DET.ANNO = 'training_dataset/det/train.json' 151 | __C.DATASET.DET.FRAME_RANGE = 1 152 | __C.DATASET.DET.NUM_USE = 50000 153 | 154 | __C.DATASET.GOT = CN() 155 | __C.DATASET.GOT.ROOT = '/PATH/TO/GOT' 156 | __C.DATASET.GOT.ANNO = 'training_dataset/got10k/train.json' 157 | __C.DATASET.GOT.FRAME_RANGE = 50 158 | __C.DATASET.GOT.NUM_USE = 200000 159 | 160 | __C.DATASET.LaSOT = CN() 161 | __C.DATASET.LaSOT.ROOT = '/PATH/TO/LaSOT' 162 | __C.DATASET.LaSOT.ANNO = 'training_dataset/lasot/train.json' 163 | __C.DATASET.LaSOT.FRAME_RANGE = 100 164 | __C.DATASET.LaSOT.NUM_USE = 150000 165 | 166 | __C.DATASET.TrackingNet = CN() 167 | __C.DATASET.TrackingNet.ROOT = '/PATH/TO/TrackingNet' 168 | __C.DATASET.TrackingNet.ANNO = 'training_dataset/trackingnet/train.json' 169 | __C.DATASET.TrackingNet.FRAME_RANGE = 100 170 | __C.DATASET.TrackingNet.NUM_USE = 350000 171 | 172 | __C.DATASET.VIDEOS_PER_EPOCH = 800000 173 | 174 | # ------------------------------------------------------------------------ # 175 | # Backbone options 176 | # ------------------------------------------------------------------------ # 177 | __C.BACKBONE = CN() 178 | 179 | # Backbone type, current only support googlenet;alexnet; 180 | __C.BACKBONE.TYPE = 'googlenet' 181 | 182 | __C.BACKBONE.KWARGS = CN(new_allowed=True) 183 | 184 | # Pretrained backbone weights 185 | __C.BACKBONE.PRETRAINED = '' 186 | 187 | # Train backbone layers 188 | __C.BACKBONE.TRAIN_LAYERS = [] 189 | 190 | # Train channel_layer 191 | __C.BACKBONE.CHANNEL_REDUCE_LAYERS = [] 192 | 193 | # Layer LR 194 | __C.BACKBONE.LAYERS_LR = 0.1 195 | 196 | # Crop_pad 197 | __C.BACKBONE.CROP_PAD = 4 198 | 199 | # Switch to train layer 200 | __C.BACKBONE.TRAIN_EPOCH = 10 201 | 202 | # Backbone offset 203 | __C.BACKBONE.OFFSET = 13 204 | 205 | # Backbone stride 206 | __C.BACKBONE.STRIDE = 8 207 | 208 | # ------------------------------------------------------------------------ # 209 | # Adjust layer options 210 | # ------------------------------------------------------------------------ # 211 | __C.ADJUST = CN() 212 | 213 | # Adjust layer 214 | __C.ADJUST.ADJUST = True 215 | 216 | __C.ADJUST.KWARGS = CN(new_allowed=True) 217 | 218 | # Adjust layer type 219 | __C.ADJUST.TYPE = "GoogLeNetAdjustLayer" 220 | 221 | # 
------------------------------------------------------------------------ # 222 | # Tracker options 223 | # ------------------------------------------------------------------------ # 224 | __C.TRACK = CN() 225 | 226 | # SiamGAT 227 | __C.TRAIN.ATTENTION = True 228 | 229 | __C.TRACK.TYPE = 'SiamGATTracker' 230 | 231 | # Scale penalty 232 | __C.TRACK.PENALTY_K = 0.04 233 | 234 | # Window influence 235 | __C.TRACK.WINDOW_INFLUENCE = 0.44 236 | 237 | # Interpolation learning rate 238 | __C.TRACK.LR = 0.4 239 | 240 | # Exemplar size 241 | __C.TRACK.EXEMPLAR_SIZE = 127 242 | 243 | # Instance size 244 | __C.TRACK.INSTANCE_SIZE = 287 245 | 246 | # Context amount 247 | __C.TRACK.CONTEXT_AMOUNT = 0.5 248 | 249 | __C.TRACK.STRIDE = 8 250 | 251 | __C.TRACK.OFFSET = 45 252 | 253 | __C.TRACK.SCORE_SIZE = 25 254 | 255 | __C.TRACK.hanming = True 256 | 257 | __C.TRACK.REGION_S = 0.1 258 | 259 | __C.TRACK.REGION_L = 0.44 260 | 261 | # ------------------------------------------------------------------------ # 262 | # HP_SEARCH parameters 263 | # ------------------------------------------------------------------------ # 264 | __C.HP_SEARCH = CN() 265 | 266 | __C.HP_SEARCH.OTB100 = [0.28, 0.16, 0.4] 267 | 268 | # __C.HP_SEARCH.OTB100 = [0.32, 0.3, 0.38] 269 | 270 | __C.HP_SEARCH.GOT_10k = [0.7, 0.02, 0.35] 271 | 272 | # __C.HP_SEARCH.GOT_10k = [0.9, 0.25, 0.35] 273 | 274 | __C.HP_SEARCH.UAV123 = [0.24, 0.04, 0.04] 275 | 276 | __C.HP_SEARCH.LaSOT = [0.35, 0.05, 0.18] 277 | 278 | # __C.HP_SEARCH.LaSOT = [0.45, 0.05, 0.18] 279 | 280 | # __C.HP_SEARCH.TrackingNet = [0.4, 0.05, 0.4] -------------------------------------------------------------------------------- /pysot/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/datasets/__init__.py -------------------------------------------------------------------------------- /pysot/datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/datasets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/datasets/__pycache__/augmentation.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/datasets/__pycache__/augmentation.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/datasets/__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/datasets/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import numpy as np 9 | import cv2 10 | 11 | from pysot.utils.bbox import corner2center, \ 12 | Center, center2corner, Corner 13 | 14 | 15 | class Augmentation: 16 | def __init__(self, shift, scale, blur, flip, color): 17 | self.shift = shift 18 | self.scale = scale 19 | self.blur = blur 20 | self.flip = flip 21 | self.color = color 22 | self.rgbVar = np.array( 23 | [[-0.55919361, 0.98062831, - 0.41940627], 24 | [1.72091413, 0.19879334, - 1.82968581], 25 | [4.64467907, 4.73710203, 4.88324118]], dtype=np.float32) 26 | 27 | @staticmethod 28 | def random(): 29 | return np.random.random() * 2 - 1.0 30 | 31 | def _crop_roi(self, image, bbox, out_sz, padding=(0, 0, 0)): 32 | bbox = [float(x) for x in bbox] 33 | a = (out_sz-1) / (bbox[2]-bbox[0]) 34 | b = (out_sz-1) / (bbox[3]-bbox[1]) 35 | c = -a * bbox[0] 36 | d = -b * bbox[1] 37 | mapping = np.array([[a, 0, c], 38 | [0, b, d]]).astype(np.float) 39 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), 40 | borderMode=cv2.BORDER_CONSTANT, 41 | borderValue=padding) 42 | return crop 43 | 44 | def _blur_aug(self, image): 45 | def rand_kernel(): 46 | sizes = np.arange(5, 46, 2) 47 | size = np.random.choice(sizes) 48 | kernel = np.zeros((size, size)) 49 | c = int(size/2) 50 | wx = np.random.random() 51 | kernel[:, c] += 1. / size * wx 52 | kernel[c, :] += 1. / size * (1-wx) 53 | return kernel 54 | kernel = rand_kernel() 55 | image = cv2.filter2D(image, -1, kernel) 56 | return image 57 | 58 | def _color_aug(self, image): 59 | offset = np.dot(self.rgbVar, np.random.randn(3, 1)) 60 | offset = offset[::-1] # bgr 2 rgb 61 | offset = offset.reshape(3) 62 | image = image - offset 63 | return image 64 | 65 | def _gray_aug(self, image): 66 | grayed = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 67 | image = cv2.cvtColor(grayed, cv2.COLOR_GRAY2BGR) 68 | return image 69 | 70 | def _shift_scale_aug(self, image, bbox, crop_bbox, size): 71 | im_h, im_w = image.shape[:2] 72 | 73 | # adjust crop bounding box 74 | crop_bbox_center = corner2center(crop_bbox) 75 | if self.scale: 76 | scale_x = (1.0 + Augmentation.random() * self.scale) 77 | scale_y = (1.0 + Augmentation.random() * self.scale) 78 | h, w = crop_bbox_center.h, crop_bbox_center.w 79 | scale_x = min(scale_x, float(im_w) / w) 80 | scale_y = min(scale_y, float(im_h) / h) 81 | crop_bbox_center = Center(crop_bbox_center.x, 82 | crop_bbox_center.y, 83 | crop_bbox_center.w * scale_x, 84 | crop_bbox_center.h * scale_y) 85 | 86 | crop_bbox = center2corner(crop_bbox_center) 87 | if self.shift: 88 | sx = Augmentation.random() * self.shift 89 | sy = Augmentation.random() * self.shift 90 | 91 | x1, y1, x2, y2 = crop_bbox 92 | 93 | sx = max(-x1, min(im_w - 1 - x2, sx)) 94 | sy = max(-y1, min(im_h - 1 - y2, sy)) 95 | 96 | crop_bbox = Corner(x1 + sx, y1 + sy, x2 + sx, y2 + sy) 97 | 98 | # adjust target bounding box 99 | x1, y1 = crop_bbox.x1, crop_bbox.y1 100 | bbox = Corner(bbox.x1 - x1, bbox.y1 - y1, 101 | bbox.x2 - x1, bbox.y2 - y1) 102 | 103 | if self.scale: 104 | bbox = Corner(bbox.x1 / scale_x, bbox.y1 / scale_y, 105 | bbox.x2 / scale_x, bbox.y2 / scale_y) 106 | 107 | image = self._crop_roi(image, crop_bbox, size) 108 | return image, bbox 109 | 110 | def _flip_aug(self, image, bbox): 111 | image = cv2.flip(image, 1) 112 | width = image.shape[1] 113 | bbox = Corner(width - 1 - bbox.x2, bbox.y1, 114 | width - 1 - bbox.x1, bbox.y2) 115 | return 
image, bbox 116 | 117 | def __call__(self, image, bbox, size, gray=False): 118 | shape = image.shape 119 | crop_bbox = center2corner(Center(shape[0]//2, shape[1]//2, 120 | size-1, size-1)) 121 | # gray augmentation 122 | if gray: 123 | image = self._gray_aug(image) 124 | 125 | # shift scale augmentation 126 | image, bbox = self._shift_scale_aug(image, bbox, crop_bbox, size) 127 | 128 | # color augmentation 129 | if self.color > np.random.random(): 130 | image = self._color_aug(image) 131 | 132 | # blur augmentation 133 | if self.blur > np.random.random(): 134 | image = self._blur_aug(image) 135 | 136 | # flip augmentation 137 | if self.flip and self.flip > np.random.random(): 138 | image, bbox = self._flip_aug(image, bbox) 139 | return image, bbox 140 | -------------------------------------------------------------------------------- /pysot/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/__init__.py -------------------------------------------------------------------------------- /pysot/models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/__pycache__/loss_car.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/__pycache__/loss_car.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from pysot.models.backbone.googlenet import Inception3 9 | from pysot.models.backbone.googlenet_ct import Inception3_ct 10 | from pysot.models.backbone.googlenet_ou import Inception3_ou 11 | 12 | BACKBONES = { 13 | # 'alexnetlegacy': alexnetlegacy, 14 | # 'mobilenetv2': mobilenetv2, 15 | # 'resnet18': resnet18, 16 | # 'resnet34': resnet34, 17 | # 'resnet50': resnet50, 18 | # 'alexnet': alexnet, 19 | 'googlenet': Inception3, 20 | 'googlenet_ct': Inception3_ct, 21 | 'googlenet_ou': Inception3_ou, 22 | } 23 | 24 | 25 | def get_backbone(name, **kwargs): 26 | return BACKBONES[name](**kwargs) 27 | -------------------------------------------------------------------------------- /pysot/models/backbone/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/backbone/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/backbone/__pycache__/googlenet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/backbone/__pycache__/googlenet.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/backbone/__pycache__/googlenet_ou.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/backbone/__pycache__/googlenet_ou.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/head/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pysot/models/head/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/head/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/head/__pycache__/car_head.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/models/head/__pycache__/car_head.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/models/head/car_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | from pysot.models.head.head_utils import ConvMixerBlock 6 | 7 | 8 | 9 | class CARHead(torch.nn.Module): 10 | def __init__(self, cfg, in_channels): 11 | """ 12 | Arguments: 13 | in_channels (int): number of channels of the input feature 14 | """ 15 | super(CARHead, self).__init__() 16 | # TODO: Implement the sigmoid version first. 
17 | num_classes = cfg.TRAIN.NUM_CLASSES 18 | 19 | cls_tower = [] 20 | bbox_tower = [] 21 | for i in range(cfg.TRAIN.NUM_CONVS): 22 | cls_tower.append( 23 | nn.Conv2d( 24 | in_channels, 25 | in_channels, 26 | kernel_size=3, 27 | stride=1, 28 | padding=1 29 | ) 30 | ) 31 | cls_tower.append(nn.GroupNorm(32, in_channels)) 32 | cls_tower.append(nn.ReLU()) 33 | bbox_tower.append( 34 | nn.Conv2d( 35 | in_channels, 36 | in_channels, 37 | kernel_size=3, 38 | stride=1, 39 | padding=1 40 | ) 41 | ) 42 | bbox_tower.append(nn.GroupNorm(32, in_channels)) 43 | bbox_tower.append(nn.ReLU()) 44 | 45 | self.add_module('cls_tower', nn.Sequential(*cls_tower)) 46 | self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) 47 | self.cls_logits = nn.Conv2d( 48 | in_channels, num_classes, kernel_size=3, stride=1, 49 | padding=1 50 | ) 51 | self.bbox_pred = nn.Conv2d( 52 | in_channels, 4, kernel_size=3, stride=1, 53 | padding=1 54 | ) 55 | self.centerness = nn.Conv2d( 56 | in_channels, 1, kernel_size=3, stride=1, 57 | padding=1 58 | ) 59 | 60 | # initialization 61 | for modules in [self.cls_tower, self.bbox_tower, 62 | self.cls_logits, self.bbox_pred, 63 | self.centerness]: 64 | for l in modules.modules(): 65 | if isinstance(l, nn.Conv2d): 66 | torch.nn.init.normal_(l.weight, std=0.01) 67 | torch.nn.init.constant_(l.bias, 0) 68 | 69 | # initialize the bias for focal loss 70 | prior_prob = cfg.TRAIN.PRIOR_PROB 71 | bias_value = -math.log((1 - prior_prob) / prior_prob) 72 | torch.nn.init.constant_(self.cls_logits.bias, bias_value) 73 | 74 | def forward(self, x): 75 | cls_tower = self.cls_tower(x) 76 | logits = self.cls_logits(cls_tower) 77 | centerness = self.centerness(cls_tower) 78 | bbox_reg = torch.exp(self.bbox_pred(self.bbox_tower(x))) 79 | 80 | return logits, bbox_reg, centerness 81 | 82 | 83 | class CARHead_CT(torch.nn.Module): 84 | def __init__(self, cfg, kernel=3, spatial_num=2): 85 | """ 86 | Arguments: 87 | in_channels (int): number of channels of the input feature 88 | """ 89 | super(CARHead_CT, self).__init__() 90 | # TODO: Implement the sigmoid version first. 
91 | num_classes = cfg.TRAIN.NUM_CLASSES 92 | in_channels = cfg.TRAIN.CHANNEL_NUM 93 | 94 | cls_tower = [] 95 | bbox_tower = [] 96 | 97 | self.cls_spatial = ConvMixerBlock(in_channels, 3, depth=spatial_num) 98 | self.bbox_spatial = ConvMixerBlock(in_channels, 3, depth=spatial_num) 99 | 100 | for i in range(cfg.TRAIN.NUM_CONVS): 101 | cls_tower.append( 102 | nn.Conv2d( 103 | in_channels, 104 | in_channels, 105 | kernel_size=kernel, 106 | stride=1, 107 | padding=kernel // 2, 108 | ) 109 | ) 110 | cls_tower.append(nn.GroupNorm(32, in_channels)) 111 | cls_tower.append(nn.ReLU()) 112 | bbox_tower.append( 113 | nn.Conv2d( 114 | in_channels, 115 | in_channels, 116 | kernel_size=kernel, 117 | stride=1, 118 | padding=kernel // 2, 119 | ) 120 | ) 121 | bbox_tower.append(nn.GroupNorm(32, in_channels)) 122 | bbox_tower.append(nn.ReLU()) 123 | 124 | self.add_module('cls_tower', nn.Sequential(*cls_tower)) 125 | self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) 126 | self.cls_logits = nn.Conv2d( 127 | in_channels, num_classes, kernel_size=kernel, stride=kernel // 2, 128 | padding=1 129 | ) 130 | self.bbox_pred = nn.Conv2d( 131 | in_channels, 4, kernel_size=kernel, stride=kernel // 2, 132 | padding=1 133 | ) 134 | self.centerness = nn.Conv2d( 135 | in_channels, 1, kernel_size=kernel, stride=kernel // 2, 136 | padding=1 137 | ) 138 | 139 | # initialization 140 | for modules in [self.cls_tower, self.bbox_tower, 141 | self.cls_logits, self.bbox_pred, 142 | self.centerness]: 143 | for l in modules.modules(): 144 | if isinstance(l, nn.Conv2d): 145 | torch.nn.init.normal_(l.weight, std=0.01) 146 | torch.nn.init.constant_(l.bias, 0) 147 | 148 | # initialize the bias for focal loss 149 | prior_prob = cfg.TRAIN.PRIOR_PROB 150 | bias_value = -math.log((1 - prior_prob) / prior_prob) 151 | torch.nn.init.constant_(self.cls_logits.bias, bias_value) 152 | 153 | def forward(self, x): 154 | cls_x = self.cls_spatial(x) 155 | bbox_x = self.bbox_spatial(x) 156 | 157 | cls_tower = self.cls_tower(cls_x) 158 | bbox_tower = self.bbox_tower(bbox_x) 159 | 160 | logits = self.cls_logits(cls_tower) 161 | centerness = self.centerness(cls_tower) 162 | bbox_reg = torch.exp(self.bbox_pred(bbox_tower)) 163 | 164 | return logits, bbox_reg, centerness 165 | 166 | 167 | 168 | 169 | class Scale(nn.Module): 170 | def __init__(self, init_value=1.0): 171 | super(Scale, self).__init__() 172 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 173 | 174 | def forward(self, input): 175 | return input * self.scale 176 | 177 | -------------------------------------------------------------------------------- /pysot/models/head/head_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | class ConvMixerBlock(nn.Module): 7 | def __init__(self, dim, kernel_size, depth): 8 | super().__init__() 9 | self.convmixer = nn.Sequential( 10 | *[nn.Sequential( 11 | # token mixing 12 | Residual(nn.Sequential( 13 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding=kernel_size//2), 14 | nn.GELU(), 15 | nn.BatchNorm2d(dim), 16 | )), 17 | # channel mixing 18 | nn.Conv2d(dim, dim, kernel_size=1), 19 | nn.GELU(), 20 | nn.BatchNorm2d(dim), 21 | ) for i in range(depth)], 22 | ) 23 | 24 | def forward(self, x): 25 | output = self.convmixer(x) 26 | return output 27 | 28 | 29 | class Residual(nn.Module): 30 | def __init__(self, fn): 31 | super().__init__() 32 | self.fn = fn 33 | 34 | def forward(self, x): 35 | return 
self.fn(x)+x -------------------------------------------------------------------------------- /pysot/models/init_weight.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def init_weights(model): 5 | for m in model.modules(): 6 | if isinstance(m, nn.Conv2d): 7 | nn.init.kaiming_normal_(m.weight.data, 8 | mode='fan_out', 9 | nonlinearity='relu') 10 | elif isinstance(m, nn.BatchNorm2d): 11 | m.weight.data.fill_(1) 12 | m.bias.data.zero_() 13 | -------------------------------------------------------------------------------- /pysot/models/model_builder_gat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from pysot.core.config import cfg 13 | from pysot.models.loss_car import make_siamcar_loss_evaluator 14 | from pysot.models.backbone import get_backbone 15 | from pysot.models.head.car_head import CARHead 16 | from ..utils.location_grid import compute_locations 17 | 18 | 19 | class Graph_Attention_Union(nn.Module): 20 | def __init__(self, in_channel, out_channel): 21 | super(Graph_Attention_Union, self).__init__() 22 | 23 | # search region nodes linear transformation 24 | self.support = nn.Conv2d(in_channel, in_channel, 1, 1) 25 | 26 | # target template nodes linear transformation 27 | self.query = nn.Conv2d(in_channel, in_channel, 1, 1) 28 | 29 | # linear transformation for message passing 30 | self.g = nn.Sequential( 31 | nn.Conv2d(in_channel, in_channel, 1, 1), 32 | nn.BatchNorm2d(in_channel), 33 | nn.ReLU(inplace=True), 34 | ) 35 | 36 | # aggregated feature 37 | self.fi = nn.Sequential( 38 | nn.Conv2d(in_channel*2, out_channel, 1, 1), 39 | nn.BatchNorm2d(out_channel), 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | def forward(self, zf, xf): 44 | # linear transformation 45 | xf_trans = self.query(xf) 46 | zf_trans = self.support(zf) 47 | 48 | # linear transformation for message passing 49 | xf_g = self.g(xf) 50 | zf_g = self.g(zf) 51 | 52 | # calculate similarity 53 | shape_x = xf_trans.shape 54 | shape_z = zf_trans.shape 55 | 56 | zf_trans_plain = zf_trans.view(-1, shape_z[1], shape_z[2] * shape_z[3]) 57 | zf_g_plain = zf_g.view(-1, shape_z[1], shape_z[2] * shape_z[3]).permute(0, 2, 1) 58 | xf_trans_plain = xf_trans.view(-1, shape_x[1], shape_x[2] * shape_x[3]).permute(0, 2, 1) 59 | 60 | similar = torch.matmul(xf_trans_plain, zf_trans_plain) 61 | similar = F.softmax(similar, dim=2) 62 | 63 | embedding = torch.matmul(similar, zf_g_plain).permute(0, 2, 1) 64 | embedding = embedding.view(-1, shape_x[1], shape_x[2], shape_x[3]) 65 | 66 | # aggregated feature 67 | output = torch.cat([embedding, xf_g], 1) 68 | output = self.fi(output) 69 | return output 70 | 71 | 72 | class ModelBuilder(nn.Module): 73 | def __init__(self): 74 | super(ModelBuilder, self).__init__() 75 | 76 | # build backbone 77 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 78 | **cfg.BACKBONE.KWARGS) 79 | 80 | # build car head 81 | self.car_head = CARHead(cfg, 256) 82 | 83 | # build response map 84 | self.attention = Graph_Attention_Union(256, 256) 85 | 86 | # build loss 87 | self.loss_evaluator = make_siamcar_loss_evaluator(cfg) 88 | 89 | def template(self, z, roi): 90 | zf = self.backbone(z, roi) 91 | self.zf = 
zf 92 | 93 | def track(self, x): 94 | xf = self.backbone(x) 95 | 96 | features = self.attention(self.zf, xf) 97 | 98 | cls, loc, cen = self.car_head(features) 99 | return { 100 | 'cls': cls, 101 | 'loc': loc, 102 | 'cen': cen 103 | } 104 | 105 | def log_softmax(self, cls): 106 | b, a2, h, w = cls.size() 107 | cls = cls.view(b, 2, a2//2, h, w) 108 | cls = cls.permute(0, 2, 3, 4, 1).contiguous() 109 | cls = F.log_softmax(cls, dim=4) 110 | return cls 111 | 112 | def forward(self, data): 113 | """ only used in training 114 | """ 115 | template = data['template'].cuda() 116 | search = data['search'].cuda() 117 | label_cls = data['label_cls'].cuda() 118 | label_loc = data['bbox'].cuda() 119 | target_box = data['target_box'].cuda() 120 | neg = data['neg'].cuda() 121 | 122 | # get feature 123 | zf = self.backbone(template, target_box) 124 | xf = self.backbone(search) 125 | 126 | features = self.attention(zf, xf) 127 | 128 | cls, loc, cen = self.car_head(features) 129 | locations = compute_locations(cls, cfg.TRACK.STRIDE, cfg.TRACK.OFFSET) 130 | cls = self.log_softmax(cls) 131 | cls_loss, loc_loss, cen_loss = self.loss_evaluator( 132 | locations, 133 | cls, 134 | loc, 135 | cen, label_cls, label_loc, neg 136 | ) 137 | 138 | # get loss 139 | outputs = {} 140 | outputs['total_loss'] = cfg.TRAIN.CLS_WEIGHT * cls_loss + \ 141 | cfg.TRAIN.LOC_WEIGHT * loc_loss + cfg.TRAIN.CEN_WEIGHT * cen_loss 142 | outputs['cls_loss'] = cls_loss 143 | outputs['loc_loss'] = loc_loss 144 | outputs['cen_loss'] = cen_loss 145 | return outputs 146 | -------------------------------------------------------------------------------- /pysot/models/model_builder_gat_ct.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from pysot.core.config import cfg 13 | from pysot.models.loss_car import make_siamcar_loss_evaluator 14 | from pysot.models.backbone import get_backbone 15 | from pysot.models.head.car_head import CARHead_CT as head 16 | from ..utils.location_grid import compute_locations 17 | from pysot.models.neck import get_neck 18 | 19 | 20 | class Graph_Attention_Union(nn.Module): 21 | def __init__(self, in_channel, out_channel, fore_groups=1, back_groups=1): 22 | super(Graph_Attention_Union, self).__init__() 23 | 24 | self.fore_groups = fore_groups 25 | self.back_groups = back_groups 26 | for i in range(fore_groups): 27 | # foreground nodes linear transformation 28 | self.add_module('fore_support'+str(i),nn.Conv2d(in_channel, in_channel, 1, 1, bias=False)) 29 | # search nodes linear transformation 30 | self.add_module('fore_query'+str(i),nn.Conv2d(in_channel, in_channel, 1, 1, bias=False)) 31 | # foreground transformation for message passing 32 | self.add_module('fore_g'+str(i),nn.Sequential( 33 | nn.Conv2d(in_channel, in_channel, 1, 1), 34 | nn.BatchNorm2d(in_channel), 35 | nn.ReLU(inplace=True),) 36 | ) 37 | for i in range(back_groups): 38 | # background nodes linear transformation 39 | self.add_module('back_support'+str(i),nn.Conv2d(in_channel, in_channel, 1, 1, bias=False)) 40 | # search nodes linear transformation 41 | self.add_module('back_query'+str(i),nn.Conv2d(in_channel, in_channel, 1, 1, bias=False)) 42 | # background transformation for message passing 43 | self.add_module('back_g'+str(i),nn.Sequential( 44 | nn.Conv2d(in_channel, in_channel, 1, 1), 45 | nn.BatchNorm2d(in_channel), 46 | nn.ReLU(inplace=True),) 47 | ) 48 | 49 | # search transformation for message passing 50 | self.xf_g = nn.Sequential( 51 | nn.Conv2d(in_channel, in_channel, 1, 1), 52 | nn.BatchNorm2d(in_channel), 53 | nn.ReLU(inplace=True), 54 | ) 55 | # aggregated feature 56 | self.fi = nn.Sequential( 57 | nn.Conv2d(in_channel*(fore_groups+back_groups+1), out_channel, 1, 1), 58 | nn.BatchNorm2d(out_channel), 59 | nn.ReLU(inplace=True), 60 | ) 61 | 62 | def forward(self, zf, xf, zf_mask): 63 | zf_fore = zf * zf_mask 64 | zf_back = zf * (1 - zf_mask) 65 | 66 | similars = [] 67 | ebds = [] 68 | for idx in range(self.fore_groups): 69 | query_func = getattr(self, 'fore_query'+str(idx)) 70 | support_func = getattr(self, 'fore_support'+str(idx)) 71 | g_func = getattr(self, 'fore_g'+str(idx)) 72 | similar, ebd = self.calculate(support_func(zf_fore), query_func(xf), g_func(zf_fore)) 73 | similars.append(similar) 74 | ebds.append(ebd) 75 | 76 | for idx in range(self.back_groups): 77 | query_func = getattr(self, 'back_query'+str(idx)) 78 | support_func = getattr(self, 'back_support'+str(idx)) 79 | g_func = getattr(self, 'back_g'+str(idx)) 80 | similar, ebd = self.calculate(support_func(zf_back), query_func(xf), g_func(zf_back)) 81 | similars.append(similar) 82 | ebds.append(ebd) 83 | ebds = torch.cat(ebds, dim=1) 84 | 85 | # aggregated feature 86 | output = torch.cat([ebds, self.xf_g(xf)], 1) 87 | output = self.fi(output) 88 | return output 89 | 90 | def calculate(self, zf, xf, zf_g): 91 | xf = F.normalize(xf, dim=1) 92 | zf = F.normalize(zf, dim=1) 93 | zf_flatten = zf.flatten(2) 94 | xf_flatten = xf.flatten(2) 95 | zf_g_flatten = zf_g.flatten(2) 96 | similar = 
torch.einsum("bcn,bcm->bnm", xf_flatten, zf_flatten) 97 | bs, c, xw, xh = xf.shape 98 | embedding = torch.einsum("bcm, bnm->bcn", zf_g_flatten, similar) 99 | embedding = embedding.reshape(bs, c, xw, xh) 100 | return similar, embedding 101 | 102 | 103 | class ModelBuilder(nn.Module): 104 | def __init__(self): 105 | super(ModelBuilder, self).__init__() 106 | 107 | # build backbone 108 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 109 | **cfg.BACKBONE.KWARGS) 110 | if cfg.ADJUST.ADJUST: 111 | self.neck = get_neck(cfg.ADJUST.TYPE, 112 | **cfg.ADJUST.KWARGS) 113 | 114 | # build response map 115 | self.attention = Graph_Attention_Union(cfg.TRAIN.CHANNEL_NUM, cfg.TRAIN.CHANNEL_NUM, fore_groups=1, back_groups=1) 116 | 117 | # build car head 118 | self.car_head = head(cfg) 119 | 120 | # build loss 121 | self.loss_evaluator = make_siamcar_loss_evaluator(cfg) 122 | 123 | def template(self, z, mask): 124 | zf = self.backbone(z) 125 | self.zf_mask = F.interpolate(mask, size=zf.shape[-1], mode='bilinear') 126 | 127 | if cfg.ADJUST.ADJUST: 128 | zf = self.neck(zf) 129 | self.zf = zf 130 | 131 | def track(self, x): 132 | xf = self.backbone(x) 133 | 134 | if cfg.ADJUST.ADJUST: 135 | xf = self.neck(xf) 136 | 137 | features = self.attention(self.zf, xf, self.zf_mask) 138 | 139 | cls, loc, cen = self.car_head(features) 140 | return { 141 | 'cls': cls, 142 | 'loc': loc, 143 | 'cen': cen 144 | } 145 | 146 | def log_softmax(self, cls): 147 | b, a2, h, w = cls.size() 148 | cls = cls.view(b, 2, a2//2, h, w) 149 | cls = cls.permute(0, 2, 3, 4, 1).contiguous() 150 | cls = F.log_softmax(cls, dim=4) 151 | return cls 152 | 153 | def forward(self, data): 154 | """ only used in training 155 | """ 156 | template = data['template'].cuda() 157 | search = data['search'].cuda() 158 | label_cls = data['label_cls'].cuda() 159 | label_loc = data['bbox'].cuda() 160 | neg = data['neg'].cuda() 161 | mask = data['mask'].cuda() 162 | 163 | # get feature 164 | zf = self.backbone(template) 165 | xf = self.backbone(search) 166 | 167 | if cfg.ADJUST.ADJUST: 168 | zf = self.neck(zf) 169 | xf = self.neck(xf) 170 | 171 | zf_mask = F.interpolate(mask, size=zf.shape[-1], mode='bilinear') 172 | features = self.attention(zf, xf, zf_mask) 173 | 174 | cls, loc, cen = self.car_head(features) 175 | locations = compute_locations(cls, cfg.TRACK.STRIDE, cfg.TRACK.OFFSET) 176 | cls = self.log_softmax(cls) 177 | cls_loss, loc_loss, cen_loss = self.loss_evaluator( 178 | locations, 179 | cls, 180 | loc, 181 | cen, label_cls, label_loc, neg 182 | ) 183 | 184 | # get loss 185 | outputs = {} 186 | outputs['total_loss'] = cfg.TRAIN.CLS_WEIGHT * cls_loss + \ 187 | cfg.TRAIN.LOC_WEIGHT * loc_loss + cfg.TRAIN.CEN_WEIGHT * cen_loss 188 | outputs['cls_loss'] = cls_loss 189 | outputs['loc_loss'] = loc_loss 190 | outputs['cen_loss'] = cen_loss 191 | return outputs 192 | -------------------------------------------------------------------------------- /pysot/models/model_builder_gat_got.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from pysot.core.config import cfg 13 | from pysot.models.loss_car import make_siamcar_loss_evaluator 14 | from pysot.models.backbone import get_backbone 15 | from pysot.models.head.car_head import CARHead 16 | from ..utils.location_grid import compute_locations 17 | 18 | 19 | class Graph_Attention_Union(nn.Module): 20 | def __init__(self, in_channel, out_channel): 21 | super(Graph_Attention_Union, self).__init__() 22 | 23 | # search region nodes linear transformation 24 | self.support = nn.Conv2d(in_channel, in_channel, 1, 1, bias=False) 25 | 26 | # target template nodes linear transformation 27 | self.query = nn.Conv2d(in_channel, in_channel, 1, 1, bias=False) 28 | 29 | # linear transformation for message passing 30 | self.g = nn.Sequential( 31 | nn.Conv2d(in_channel, in_channel, 1, 1), 32 | nn.BatchNorm2d(in_channel), 33 | nn.ReLU(inplace=True), 34 | ) 35 | 36 | # aggregated feature 37 | self.fi = nn.Sequential( 38 | nn.Conv2d(in_channel*2, out_channel, 1, 1), 39 | nn.BatchNorm2d(out_channel), 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | def forward(self, zf, xf): 44 | # linear transformation 45 | xf_trans = self.query(xf) 46 | zf_trans = self.support(zf) 47 | 48 | # linear transformation for message passing 49 | xf_g = self.g(xf) 50 | zf_g = self.g(zf) 51 | 52 | # calculate similarity 53 | shape_x = xf_trans.shape 54 | shape_z = zf_trans.shape 55 | 56 | zf_trans_plain = zf_trans.view(-1, shape_z[1], shape_z[2] * shape_z[3]) 57 | zf_g_plain = zf_g.view(-1, shape_z[1], shape_z[2] * shape_z[3]).permute(0, 2, 1) 58 | xf_trans_plain = xf_trans.view(-1, shape_x[1], shape_x[2] * shape_x[3]).permute(0, 2, 1) 59 | 60 | similar = torch.matmul(xf_trans_plain, zf_trans_plain) 61 | similar = F.softmax(similar, dim=2) 62 | 63 | embedding = torch.matmul(similar, zf_g_plain).permute(0, 2, 1) 64 | embedding = embedding.view(-1, shape_x[1], shape_x[2], shape_x[3]) 65 | 66 | # aggregated feature 67 | output = torch.cat([embedding, xf_g], 1) 68 | output = self.fi(output) 69 | return output 70 | 71 | 72 | class ModelBuilder(nn.Module): 73 | def __init__(self): 74 | super(ModelBuilder, self).__init__() 75 | 76 | # build backbone 77 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 78 | **cfg.BACKBONE.KWARGS) 79 | 80 | # build car head 81 | self.car_head = CARHead(cfg, 256) 82 | 83 | # build response map 84 | self.attention = Graph_Attention_Union(256, 256) 85 | 86 | # build loss 87 | self.loss_evaluator = make_siamcar_loss_evaluator(cfg) 88 | 89 | def template(self, z, roi): 90 | zf = self.backbone(z, roi) 91 | self.zf = zf 92 | 93 | def track(self, x): 94 | xf = self.backbone(x) 95 | 96 | features = self.attention(self.zf, xf) 97 | 98 | cls, loc, cen = self.car_head(features) 99 | return { 100 | 'cls': cls, 101 | 'loc': loc, 102 | 'cen': cen 103 | } 104 | 105 | def log_softmax(self, cls): 106 | b, a2, h, w = cls.size() 107 | cls = cls.view(b, 2, a2//2, h, w) 108 | cls = cls.permute(0, 2, 3, 4, 1).contiguous() 109 | cls = F.log_softmax(cls, dim=4) 110 | return cls 111 | 112 | def forward(self, data): 113 | """ only used in training 114 | """ 115 | template = data['template'].cuda() 116 | search = data['search'].cuda() 117 | label_cls = data['label_cls'].cuda() 118 | label_loc = data['bbox'].cuda() 119 | target_box = data['target_box'].cuda() 120 | 
neg = data['neg'].cuda() 121 | 122 | # get feature 123 | zf = self.backbone(template, target_box) 124 | xf = self.backbone(search) 125 | 126 | features = self.attention(zf, xf) 127 | 128 | cls, loc, cen = self.car_head(features) 129 | locations = compute_locations(cls, cfg.TRACK.STRIDE, cfg.TRACK.OFFSET) 130 | cls = self.log_softmax(cls) 131 | cls_loss, loc_loss, cen_loss = self.loss_evaluator( 132 | locations, 133 | cls, 134 | loc, 135 | cen, label_cls, label_loc, neg 136 | ) 137 | 138 | # get loss 139 | outputs = {} 140 | outputs['total_loss'] = cfg.TRAIN.CLS_WEIGHT * cls_loss + \ 141 | cfg.TRAIN.LOC_WEIGHT * loc_loss + cfg.TRAIN.CEN_WEIGHT * cen_loss 142 | outputs['cls_loss'] = cls_loss 143 | outputs['loc_loss'] = loc_loss 144 | outputs['cen_loss'] = cen_loss 145 | return outputs 146 | -------------------------------------------------------------------------------- /pysot/models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from pysot.models.neck.neck import GoogLeNetAdjustLayer 9 | 10 | NECKS = { 11 | 'GoogLeNetAdjustLayer': GoogLeNetAdjustLayer, 12 | } 13 | 14 | def get_neck(name, **kwargs): 15 | return NECKS[name](**kwargs) -------------------------------------------------------------------------------- /pysot/models/neck/neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class GoogLeNetAdjustLayer(nn.Module): 13 | ''' 14 | with mask: F.interpolate 15 | ''' 16 | def __init__(self, in_channels, out_channels, crop_pad=0, kernel=1): 17 | super(GoogLeNetAdjustLayer, self).__init__() 18 | self.channel_reduce = nn.Sequential( 19 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel), 20 | nn.BatchNorm2d(out_channels, eps=0.001), 21 | ) 22 | self.crop_pad = crop_pad 23 | 24 | def forward(self, x): 25 | x = self.channel_reduce(x) 26 | 27 | if x.shape[-1] > 25 and self.crop_pad > 0: 28 | crop_pad = self.crop_pad 29 | x = x[:, :, crop_pad:-crop_pad, crop_pad:-crop_pad] 30 | 31 | return x 32 | 33 | -------------------------------------------------------------------------------- /pysot/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/tracker/__init__.py -------------------------------------------------------------------------------- /pysot/tracker/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/tracker/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/tracker/__pycache__/base_tracker.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/tracker/__pycache__/base_tracker.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/tracker/__pycache__/siamgat_tracker.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/tracker/__pycache__/siamgat_tracker.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/tracker/__pycache__/siamgat_tracker_hp_search.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/tracker/__pycache__/siamgat_tracker_hp_search.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/tracker/base_tracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | 12 | from pysot.core.config import cfg 13 | 14 | 15 | class BaseTracker(object): 16 | """ Base tracker for single object tracking 17 | """ 18 | def init(self, img, bbox): 19 | """ 20 | args: 21 | img(np.ndarray): BGR image 22 | bbox(list): [x, y, width, height] 23 | x, y need to be 0-based 24 | """ 25 | raise NotImplementedError 26 | 27 | def track(self, img): 28 | """ 29 | args: 30 | img(np.ndarray): BGR image 31 | return: 32 | bbox(list):[x, y, width, height] 33 | """ 34 | raise NotImplementedError 35 | 36 | 37 | class SiameseTracker(BaseTracker): 38 | def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans): 39 | """ 40 | args: 41 | im: BGR image 42 | pos: center position 43 | model_sz: exemplar size 44 | original_sz: original size 45 | avg_chans: channel average 46 | """ 47 | if isinstance(pos, float): 48 | pos = [pos, pos] 49 | sz = original_sz 50 | im_sz = im.shape 51 | c = (original_sz + 1) / 2 52 | # context_xmin = round(pos[0] - c) # py2 and py3 round 53 | context_xmin = np.floor(pos[0] - c + 0.5) 54 | context_xmax = context_xmin + sz - 1 55 | # context_ymin = round(pos[1] - c) 56 | context_ymin = np.floor(pos[1] - c + 0.5) 57 | context_ymax = context_ymin + sz - 1 58 | left_pad = int(max(0., -context_xmin)) 59 | top_pad = int(max(0., -context_ymin)) 60 | right_pad = int(max(0., context_xmax - im_sz[1] + 1)) 61 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1)) 62 | 63 | context_xmin = context_xmin + left_pad 64 | context_xmax = context_xmax + left_pad 65 | context_ymin = context_ymin + top_pad 66 | context_ymax = context_ymax + top_pad 67 | 68 | r, c, k = im.shape 69 | if any([top_pad, bottom_pad, left_pad, right_pad]): 70 | size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k) 71 | te_im = np.zeros(size, np.uint8) 72 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im 73 | if top_pad: 74 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans 75 | if bottom_pad: 76 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans 77 | if left_pad: 78 | te_im[:, 0:left_pad, :] = avg_chans 79 | if right_pad: 80 | te_im[:, c + left_pad:, :] = avg_chans 81 | im_patch =
te_im[int(context_ymin):int(context_ymax + 1), 82 | int(context_xmin):int(context_xmax + 1), :] 83 | else: 84 | im_patch = im[int(context_ymin):int(context_ymax + 1), 85 | int(context_xmin):int(context_xmax + 1), :] 86 | 87 | if not np.array_equal(model_sz, original_sz): 88 | im_patch = cv2.resize(im_patch, (model_sz, model_sz)) 89 | im_patch = im_patch.transpose(2, 0, 1) 90 | im_patch = im_patch[np.newaxis, :, :, :] 91 | im_patch = im_patch.astype(np.float32) 92 | im_patch = torch.from_numpy(im_patch) 93 | if cfg.CUDA: 94 | im_patch = im_patch.cuda() 95 | return im_patch 96 | -------------------------------------------------------------------------------- /pysot/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__init__.py -------------------------------------------------------------------------------- /pysot/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/average_meter.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/average_meter.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/bbox.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/bbox.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/distributed.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/distributed.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/location_grid.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/location_grid.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/log_helper.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/log_helper.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/lr_scheduler.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/lr_scheduler.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/misc.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/misc.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/__pycache__/model_load.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/pysot/utils/__pycache__/model_load.cpython-35.pyc -------------------------------------------------------------------------------- /pysot/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | 9 | class Meter(object): 10 | def __init__(self, name, val, avg): 11 | self.name = name 12 | self.val = val 13 | self.avg = avg 14 | 15 | def __repr__(self): 16 | return "{name}: {val:.6f} ({avg:.6f})".format( 17 | name=self.name, val=self.val, avg=self.avg 18 | ) 19 | 20 | def __format__(self, *tuples, **kwargs): 21 | return self.__repr__() 22 | 23 | 24 | class AverageMeter: 25 | """Computes and stores the average and current value""" 26 | def __init__(self, num=100): 27 | self.num = num 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = {} 32 | self.sum = {} 33 | self.count = {} 34 | self.history = {} 35 | 36 | def update(self, batch=1, **kwargs): 37 | val = {} 38 | for k in kwargs: 39 | val[k] = kwargs[k] / float(batch) 40 | self.val.update(val) 41 | for k in kwargs: 42 | if k not in self.sum: 43 | self.sum[k] = 0 44 | self.count[k] = 0 45 | self.history[k] = [] 46 | self.sum[k] += kwargs[k] 47 | self.count[k] += batch 48 | for _ in range(batch): 49 | self.history[k].append(val[k]) 50 | 51 | if self.num <= 0: 52 | # < 0, average all 53 | self.history[k] = [] 54 | 55 | # == 0: no average 56 | if self.num == 0: 57 | self.sum[k] = self.val[k] 58 | self.count[k] = 1 59 | 60 | elif len(self.history[k]) > self.num: 61 | pop_num = len(self.history[k]) - self.num 62 | for _ in range(pop_num): 63 | self.sum[k] -= self.history[k][0] 64 | del self.history[k][0] 65 | self.count[k] -= 1 66 | 67 | def __repr__(self): 68 | s = '' 69 | for k in self.sum: 70 | s += self.format_str(k) 71 | return s 72 | 73 | def format_str(self, attr): 74 | return "{name}: {val:.6f} ({avg:.6f}) ".format( 75 | name=attr, 76 | val=float(self.val[attr]), 77 | avg=float(self.sum[attr]) / self.count[attr]) 78 | 79 | def __getattr__(self, attr): 80 | if attr in self.__dict__: 81 | return super(AverageMeter, self).__getattr__(attr) 82 | if attr not in self.sum: 83 | print("invalid key '{}'".format(attr)) 84 | return Meter(attr, 0, 0) 85 | return Meter(attr, self.val[attr], self.avg(attr)) 86 | 87 | def avg(self, attr): 88 | return float(self.sum[attr]) / self.count[attr] 89 | 90 | 91 | if __name__ == '__main__': 92 | avg1 = AverageMeter(10) 93 | avg2 = AverageMeter(0) 94 | avg3 = AverageMeter(-1) 95 | 96 | for i in range(20): 97 | avg1.update(s=i) 98 | avg2.update(s=i) 99 | avg3.update(s=i) 100 | 101 | print('iter {}'.format(i)) 102 | print(avg1.s) 103 | print(avg2.s) 104 | print(avg3.s) 105 | -------------------------------------------------------------------------------- /pysot/utils/bbox.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | from collections import namedtuple 9 | 10 | import numpy as np 11 | 12 | 13 | Corner = namedtuple('Corner', 'x1 y1 x2 y2') 14 | # alias 15 | BBox = Corner 16 | Center = namedtuple('Center', 'x y w h') 17 | 18 | 19 | def corner2center(corner): 20 | """ convert (x1, y1, x2, y2) to (cx, cy, w, h) 21 | Args: 22 | corner: Corner or np.array (4*N) 23 | Return: 24 | Center or np.array (4 * N) 25 | """ 26 | if isinstance(corner, Corner): 27 | x1, y1, x2, y2 = corner 28 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) 29 | else: 30 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] 31 | x = (x1 + x2) * 0.5 32 | y = (y1 + y2) * 0.5 33 | w = x2 - x1 34 | h = y2 - y1 35 | return x, y, w, h 36 | 37 | 38 | def center2corner(center): 39 | """ convert (cx, cy, w, h) to (x1, y1, x2, y2) 40 | Args: 41 | center: Center or np.array (4 * N) 42 | Return: 43 | Corner or np.array (4 * N) 44 | """ 45 | if isinstance(center, Center): 46 | x, y, w, h = center 47 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) 48 | else: 49 | x, y, w, h = center[0], center[1], center[2], center[3] 50 | x1 = x - w * 0.5 51 | y1 = y - h * 0.5 52 | x2 = x + w * 0.5 53 | y2 = y + h * 0.5 54 | return x1, y1, x2, y2 55 | 56 | 57 | def IoU(rect1, rect2): 58 | """ calculate intersection over union 59 | Args: 60 | rect1: (x1, y1, x2, y2) 61 | rect2: (x1, y1, x2, y2) 62 | Returns: 63 | iou 64 | """ 65 | # overlap 66 | x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3] 67 | tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3] 68 | 69 | xx1 = np.maximum(tx1, x1) 70 | yy1 = np.maximum(ty1, y1) 71 | xx2 = np.minimum(tx2, x2) 72 | yy2 = np.minimum(ty2, y2) 73 | 74 | ww = np.maximum(0, xx2 - xx1) 75 | hh = np.maximum(0, yy2 - yy1) 76 | 77 | area = (x2-x1) * (y2-y1) 78 | target_a = (tx2-tx1) * (ty2 - ty1) 79 | inter = ww * hh 80 | iou = inter / (area + target_a - inter) 81 | return iou 82 | 83 | 84 | def cxy_wh_2_rect(pos, sz): 85 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 0-index 86 | """ 87 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) 88 | 89 | 90 | def rect_2_cxy_wh(rect): 91 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 0-index 92 | """ 93 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), \ 94 | np.array([rect[2], rect[3]]) 95 | 96 | 97 | def cxy_wh_2_rect1(pos, sz): 98 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 1-index 99 | """ 100 | return np.array([pos[0]-sz[0]/2+1, pos[1]-sz[1]/2+1, sz[0], sz[1]]) 101 | 102 | 103 | def rect1_2_cxy_wh(rect): 104 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 1-index 105 | """ 106 | return np.array([rect[0]+rect[2]/2-1, rect[1]+rect[3]/2-1]), \ 107 | np.array([rect[2], rect[3]]) 108 | 109 | 110 | def get_axis_aligned_bbox(region): 111 | """ convert region to (cx, cy, w, h) represented by an axis-aligned box 112 | """ 113 | nv = region.size 114 | if nv == 8: 115 | cx = np.mean(region[0::2]) 116 | cy = np.mean(region[1::2]) 117 | x1 = min(region[0::2]) 118 | x2 = max(region[0::2]) 119 | y1 = min(region[1::2]) 120 | y2 = max(region[1::2]) 121 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * \ 122 | np.linalg.norm(region[2:4] - region[4:6]) 123 | A2 = (x2 - x1) * (y2 - y1) 124 | s = np.sqrt(A1 / A2) 125
| w = s * (x2 - x1) + 1 126 | h = s * (y2 - y1) + 1 127 | else: 128 | x = region[0] 129 | y = region[1] 130 | w = region[2] 131 | h = region[3] 132 | cx = x+w/2 133 | cy = y+h/2 134 | return cx, cy, w, h 135 | 136 | 137 | def get_min_max_bbox(region): 138 | """ convert region to (cx, cy, w, h) that represent by mim-max box 139 | """ 140 | nv = region.size 141 | if nv == 8: 142 | cx = np.mean(region[0::2]) 143 | cy = np.mean(region[1::2]) 144 | x1 = min(region[0::2]) 145 | x2 = max(region[0::2]) 146 | y1 = min(region[1::2]) 147 | y2 = max(region[1::2]) 148 | w = x2 - x1 149 | h = y2 - y1 150 | else: 151 | x = region[0] 152 | y = region[1] 153 | w = region[2] 154 | h = region[3] 155 | cx = x+w/2 156 | cy = y+h/2 157 | return cx, cy, w, h 158 | -------------------------------------------------------------------------------- /pysot/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | import socket 10 | import logging 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.distributed as dist 15 | 16 | from pysot.utils.log_helper import log_once 17 | 18 | logger = logging.getLogger('global') 19 | 20 | 21 | def average_reduce(v): 22 | if get_world_size() == 1: 23 | return v 24 | tensor = torch.cuda.FloatTensor(1) 25 | tensor[0] = v 26 | dist.all_reduce(tensor) 27 | v = tensor[0] / get_world_size() 28 | return v 29 | 30 | 31 | class DistModule(nn.Module): 32 | def __init__(self, module, bn_method=0): 33 | super(DistModule, self).__init__() 34 | self.module = module 35 | self.bn_method = bn_method 36 | if get_world_size() > 1: 37 | broadcast_params(self.module) 38 | else: 39 | self.bn_method = 0 # single proccess 40 | 41 | def forward(self, *args, **kwargs): 42 | broadcast_buffers(self.module, self.bn_method) 43 | return self.module(*args, **kwargs) 44 | 45 | def train(self, mode=True): 46 | super(DistModule, self).train(mode) 47 | self.module.train(mode) 48 | return self 49 | 50 | 51 | def broadcast_params(model): 52 | """ broadcast model parameters """ 53 | for p in model.state_dict().values(): 54 | dist.broadcast(p, 0) 55 | 56 | 57 | def broadcast_buffers(model, method=0): 58 | """ broadcast model buffers """ 59 | if method == 0: 60 | return 61 | 62 | world_size = get_world_size() 63 | 64 | for b in model._all_buffers(): 65 | if method == 1: # broadcast from main proccess 66 | dist.broadcast(b, 0) 67 | elif method == 2: # average 68 | dist.all_reduce(b) 69 | b /= world_size 70 | else: 71 | raise Exception('Invalid buffer broadcast code {}'.format(method)) 72 | 73 | 74 | inited = False 75 | 76 | 77 | def _dist_init(): 78 | ''' 79 | if guess right: 80 | ntasks: world_size (process num) 81 | proc_id: rank 82 | ''' 83 | # rank = int(os.environ['RANK']) 84 | rank = 0 85 | num_gpus = torch.cuda.device_count() 86 | torch.cuda.set_device(rank % num_gpus) 87 | dist.init_process_group(backend='nccl') 88 | world_size = dist.get_world_size() 89 | return rank, world_size 90 | 91 | 92 | def _get_local_ip(): 93 | try: 94 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 95 | s.connect(('8.8.8.8', 80)) 96 | ip = s.getsockname()[0] 97 | finally: 98 | s.close() 99 | return ip 100 | 101 | 102 | def dist_init(): 103 | global rank, world_size, inited 104 | # try: 105 | # rank, world_size = _dist_init() 106 | # 
except RuntimeError as e: 107 | # if 'public' in e.args[0]: 108 | # logger.info(e) 109 | # logger.info('Warning: use single process') 110 | # rank, world_size = 0, 1 111 | # else: 112 | # raise RuntimeError(*e.args) 113 | rank, world_size = 0, 1 114 | inited = True 115 | return rank, world_size 116 | 117 | 118 | def get_rank(): 119 | if not inited: 120 | raise(Exception('dist not inited')) 121 | return rank 122 | 123 | 124 | def get_world_size(): 125 | if not inited: 126 | raise(Exception('dist not inited')) 127 | return world_size 128 | 129 | 130 | def reduce_gradients(model, _type='sum'): 131 | types = ['sum', 'avg'] 132 | assert _type in types, 'gradients method must be in "{}"'.format(types) 133 | log_once("gradients method is {}".format(_type)) 134 | if get_world_size() > 1: 135 | for param in model.parameters(): 136 | if param.requires_grad: 137 | dist.all_reduce(param.grad.data) 138 | if _type == 'avg': 139 | param.grad.data /= get_world_size() 140 | else: 141 | return None 142 | -------------------------------------------------------------------------------- /pysot/utils/location_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_locations(features, stride, offset): 5 | h, w = features.size()[-2:] 6 | locations_per_level = compute_locations_per_level( 7 | h, w, stride, offset, 8 | features.device 9 | ) 10 | return locations_per_level 11 | 12 | 13 | def compute_locations_per_level(h, w, stride, offset, device): 14 | shifts_x = torch.arange( 15 | 0, w * stride, step=stride, 16 | dtype=torch.float32, device=device 17 | ) 18 | shifts_y = torch.arange( 19 | 0, h * stride, step=stride, 20 | dtype=torch.float32, device=device 21 | ) 22 | shift_y, shift_x = torch.meshgrid((shifts_y, shifts_x)) 23 | shift_x = shift_x.reshape(-1) 24 | shift_y = shift_y.reshape(-1) 25 | locations = torch.stack((shift_x, shift_y), dim=1) + offset 26 | return locations -------------------------------------------------------------------------------- /pysot/utils/log_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
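# [Editor's note: illustrative usage sketch, not part of the original file.]
# The helpers defined below are typically wired up from the training/testing
# scripts roughly as follows; the log file path here is a placeholder:
#
#   import logging
#   from pysot.utils.log_helper import init_log, add_file_handler, log_once
#
#   init_log('global', logging.INFO)              # console handler with rank-aware format
#   add_file_handler('global', 'logs/train.log')  # optional: also write to a file
#   logging.getLogger('global').info('start training')
#   log_once('this message is printed only once per call site')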
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | import logging 10 | import math 11 | import sys 12 | 13 | 14 | if hasattr(sys, 'frozen'): # support for py2exe 15 | _srcfile = "logging%s__init__%s" % (os.sep, __file__[-4:]) 16 | elif __file__[-4:].lower() in ['.pyc', '.pyo']: 17 | _srcfile = __file__[:-4] + '.py' 18 | else: 19 | _srcfile = __file__ 20 | _srcfile = os.path.normcase(_srcfile) 21 | 22 | logs = set() 23 | 24 | 25 | class Filter: 26 | def __init__(self, flag): 27 | self.flag = flag 28 | 29 | def filter(self, x): 30 | return self.flag 31 | 32 | 33 | class Dummy: 34 | def __init__(self, *arg, **kwargs): 35 | pass 36 | 37 | def __getattr__(self, arg): 38 | def dummy(*args, **kwargs): pass 39 | return dummy 40 | 41 | 42 | def get_format(logger, level): 43 | if 'RANK' in os.environ: 44 | rank = int(os.environ['RANK']) 45 | 46 | if level == logging.INFO: 47 | logger.addFilter(Filter(rank == 0)) 48 | else: 49 | rank = 0 50 | format_str = '[%(asctime)s-rk{}-%(filename)s#%(lineno)3d] %(message)s'.format(rank) 51 | formatter = logging.Formatter(format_str) 52 | return formatter 53 | 54 | 55 | def get_format_custom(logger, level): 56 | if 'RANK' in os.environ: 57 | rank = int(os.environ['RANK']) 58 | if level == logging.INFO: 59 | logger.addFilter(Filter(rank == 0)) 60 | else: 61 | rank = 0 62 | format_str = '[%(asctime)s-rk{}-%(message)s'.format(rank) 63 | formatter = logging.Formatter(format_str) 64 | return formatter 65 | 66 | 67 | def init_log(name, level=logging.INFO, format_func=get_format): 68 | if (name, level) in logs: 69 | return 70 | logs.add((name, level)) 71 | logger = logging.getLogger(name) 72 | logger.setLevel(level) 73 | ch = logging.StreamHandler() 74 | ch.setLevel(level) 75 | formatter = format_func(logger, level) 76 | ch.setFormatter(formatter) 77 | logger.addHandler(ch) 78 | return logger 79 | 80 | 81 | def add_file_handler(name, log_file, level=logging.INFO): 82 | logger = logging.getLogger(name) 83 | fh = logging.FileHandler(log_file) 84 | fh.setFormatter(get_format(logger, level)) 85 | logger.addHandler(fh) 86 | 87 | 88 | init_log('global') 89 | 90 | 91 | def print_speed(i, i_time, n): 92 | """print_speed(index, index_time, total_iteration)""" 93 | logger = logging.getLogger('global') 94 | average_time = i_time 95 | remaining_time = (n - i) * average_time 96 | remaining_day = math.floor(remaining_time / 86400) 97 | remaining_hour = math.floor(remaining_time / 3600 - 98 | remaining_day * 24) 99 | remaining_min = math.floor(remaining_time / 60 - 100 | remaining_day * 1440 - 101 | remaining_hour * 60) 102 | logger.info('Progress: %d / %d [%d%%], Speed: %.3f s/iter, ETA %d:%02d:%02d (D:H:M)\n' % 103 | (i, n, i / n * 100, 104 | average_time, 105 | remaining_day, remaining_hour, remaining_min)) 106 | 107 | 108 | def find_caller(): 109 | def current_frame(): 110 | try: 111 | raise Exception 112 | except: 113 | return sys.exc_info()[2].tb_frame.f_back 114 | 115 | f = current_frame() 116 | if f is not None: 117 | f = f.f_back 118 | rv = "(unknown file)", 0, "(unknown function)" 119 | while hasattr(f, "f_code"): 120 | co = f.f_code 121 | filename = os.path.normcase(co.co_filename) 122 | rv = (co.co_filename, f.f_lineno, co.co_name) 123 | if filename == _srcfile: 124 | f = f.f_back 125 | continue 126 | break 127 | rv = list(rv) 128 | rv[0] = os.path.basename(rv[0]) 129 | return rv 130 | 131 | 132 | class LogOnce: 133 | def 
__init__(self): 134 | self.logged = set() 135 | self.logger = init_log('log_once', format_func=get_format_custom) 136 | 137 | def log(self, strings): 138 | fn, lineno, caller = find_caller() 139 | key = (fn, lineno, caller, strings) 140 | if key in self.logged: 141 | return 142 | self.logged.add(key) 143 | message = "{filename:s}<{caller}>#{lineno:3d}] {strings}".format( 144 | filename=fn, lineno=lineno, strings=strings, caller=caller) 145 | self.logger.info(message) 146 | 147 | 148 | once_logger = LogOnce() 149 | 150 | 151 | def log_once(strings): 152 | once_logger.log(strings) 153 | 154 | 155 | def main(): 156 | for i, lvl in enumerate([logging.DEBUG, logging.INFO, 157 | logging.WARNING, logging.ERROR, 158 | logging.CRITICAL]): 159 | log_name = str(lvl) 160 | init_log(log_name, lvl) 161 | logger = logging.getLogger(log_name) 162 | print('****cur lvl:{}'.format(lvl)) 163 | logger.debug('debug') 164 | logger.info('info') 165 | logger.warning('warning') 166 | logger.error('error') 167 | logger.critical('critiacal') 168 | 169 | 170 | if __name__ == '__main__': 171 | main() 172 | for i in range(10): 173 | log_once('xxx') 174 | -------------------------------------------------------------------------------- /pysot/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import math 9 | 10 | import numpy as np 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | 13 | from pysot.core.config import cfg 14 | 15 | 16 | class LRScheduler(_LRScheduler): 17 | def __init__(self, optimizer, last_epoch=-1): 18 | if 'lr_spaces' not in self.__dict__: 19 | raise Exception('lr_spaces must be set in "LRSchduler"') 20 | super(LRScheduler, self).__init__(optimizer, last_epoch) 21 | 22 | def get_cur_lr(self): 23 | return self.lr_spaces[self.last_epoch] 24 | 25 | def get_lr(self): 26 | epoch = self.last_epoch 27 | return [self.lr_spaces[epoch] * pg['initial_lr'] / self.start_lr 28 | for pg in self.optimizer.param_groups] 29 | 30 | def __repr__(self): 31 | return "({}) lr spaces: \n{}".format(self.__class__.__name__, 32 | self.lr_spaces) 33 | 34 | 35 | class LogScheduler(LRScheduler): 36 | def __init__(self, optimizer, start_lr=0.03, end_lr=5e-4, 37 | epochs=50, last_epoch=-1, **kwargs): 38 | self.start_lr = start_lr 39 | self.end_lr = end_lr 40 | self.epochs = epochs 41 | self.lr_spaces = np.logspace(math.log10(start_lr), 42 | math.log10(end_lr), 43 | epochs) 44 | 45 | super(LogScheduler, self).__init__(optimizer, last_epoch) 46 | 47 | 48 | class StepScheduler(LRScheduler): 49 | def __init__(self, optimizer, start_lr=0.01, end_lr=None, 50 | step=10, mult=0.1, epochs=50, last_epoch=-1, **kwargs): 51 | if end_lr is not None: 52 | if start_lr is None: 53 | start_lr = end_lr / (mult ** (epochs // step)) 54 | else: # for warm up policy 55 | mult = math.pow(end_lr/start_lr, 1. 
/ (epochs // step)) 56 | self.start_lr = start_lr 57 | self.lr_spaces = self.start_lr * (mult**(np.arange(epochs) // step)) 58 | self.mult = mult 59 | self._step = step 60 | 61 | super(StepScheduler, self).__init__(optimizer, last_epoch) 62 | 63 | 64 | class MultiStepScheduler(LRScheduler): 65 | def __init__(self, optimizer, start_lr=0.01, end_lr=None, 66 | steps=[10, 20, 30, 40], mult=0.5, epochs=50, 67 | last_epoch=-1, **kwargs): 68 | if end_lr is not None: 69 | if start_lr is None: 70 | start_lr = end_lr / (mult ** (len(steps))) 71 | else: 72 | mult = math.pow(end_lr/start_lr, 1. / len(steps)) 73 | self.start_lr = start_lr 74 | self.lr_spaces = self._build_lr(start_lr, steps, mult, epochs) 75 | self.mult = mult 76 | self.steps = steps 77 | 78 | super(MultiStepScheduler, self).__init__(optimizer, last_epoch) 79 | 80 | def _build_lr(self, start_lr, steps, mult, epochs): 81 | lr = [0] * epochs 82 | lr[0] = start_lr 83 | for i in range(1, epochs): 84 | lr[i] = lr[i-1] 85 | if i in steps: 86 | lr[i] *= mult 87 | return np.array(lr, dtype=np.float32) 88 | 89 | 90 | class LinearStepScheduler(LRScheduler): 91 | def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, 92 | epochs=50, last_epoch=-1, **kwargs): 93 | self.start_lr = start_lr 94 | self.end_lr = end_lr 95 | self.lr_spaces = np.linspace(start_lr, end_lr, epochs) 96 | super(LinearStepScheduler, self).__init__(optimizer, last_epoch) 97 | 98 | 99 | class CosStepScheduler(LRScheduler): 100 | def __init__(self, optimizer, start_lr=0.01, end_lr=0.005, 101 | epochs=50, last_epoch=-1, **kwargs): 102 | self.start_lr = start_lr 103 | self.end_lr = end_lr 104 | self.lr_spaces = self._build_lr(start_lr, end_lr, epochs) 105 | 106 | super(CosStepScheduler, self).__init__(optimizer, last_epoch) 107 | 108 | def _build_lr(self, start_lr, end_lr, epochs): 109 | index = np.arange(epochs).astype(np.float32) 110 | lr = end_lr + (start_lr - end_lr) * \ 111 | (1. 
+ np.cos(index * np.pi / epochs)) * 0.5 112 | return lr.astype(np.float32) 113 | 114 | 115 | class WarmUPScheduler(LRScheduler): 116 | def __init__(self, optimizer, warmup, normal, epochs=50, last_epoch=-1): 117 | warmup = warmup.lr_spaces # [::-1] 118 | normal = normal.lr_spaces 119 | self.lr_spaces = np.concatenate([warmup, normal]) 120 | self.start_lr = normal[0] 121 | 122 | super(WarmUPScheduler, self).__init__(optimizer, last_epoch) 123 | 124 | 125 | LRs = { 126 | 'log': LogScheduler, 127 | 'step': StepScheduler, 128 | 'multi-step': MultiStepScheduler, 129 | 'linear': LinearStepScheduler, 130 | 'cos': CosStepScheduler} 131 | 132 | 133 | def _build_lr_scheduler(optimizer, config, epochs=50, last_epoch=-1): 134 | return LRs[config.TYPE](optimizer, last_epoch=last_epoch, 135 | epochs=epochs, **config.KWARGS) 136 | 137 | 138 | def _build_warm_up_scheduler(optimizer, epochs=50, last_epoch=-1): 139 | warmup_epoch = cfg.TRAIN.LR_WARMUP.EPOCH 140 | sc1 = _build_lr_scheduler(optimizer, cfg.TRAIN.LR_WARMUP, 141 | warmup_epoch, last_epoch) 142 | sc2 = _build_lr_scheduler(optimizer, cfg.TRAIN.LR, 143 | epochs - warmup_epoch, last_epoch) 144 | return WarmUPScheduler(optimizer, sc1, sc2, epochs, last_epoch) 145 | 146 | 147 | def build_lr_scheduler(optimizer, epochs=50, last_epoch=-1): 148 | if cfg.TRAIN.LR_WARMUP.WARMUP: 149 | return _build_warm_up_scheduler(optimizer, epochs, last_epoch) 150 | else: 151 | return _build_lr_scheduler(optimizer, cfg.TRAIN.LR, 152 | epochs, last_epoch) 153 | 154 | 155 | if __name__ == '__main__': 156 | import torch.nn as nn 157 | from torch.optim import SGD 158 | 159 | class Net(nn.Module): 160 | def __init__(self): 161 | super(Net, self).__init__() 162 | self.conv = nn.Conv2d(10, 10, kernel_size=3) 163 | net = Net().parameters() 164 | optimizer = SGD(net, lr=0.01) 165 | cfg.TRAIN.LR_WARMUP.WARMUP = False 166 | # test1 167 | step = { 168 | 'type': 'step', 169 | 'start_lr': 0.01, 170 | 'step': 10, 171 | 'mult': 0.1 172 | } 173 | lr = build_lr_scheduler(optimizer, step) 174 | print(lr) 175 | 176 | log = { 177 | 'type': 'log', 178 | 'start_lr': 0.03, 179 | 'end_lr': 5e-4, 180 | } 181 | lr = build_lr_scheduler(optimizer, log) 182 | 183 | print(lr) 184 | 185 | log = { 186 | 'type': 'multi-step', 187 | "start_lr": 0.01, 188 | "mult": 0.1, 189 | "steps": [10, 15, 20] 190 | } 191 | lr = build_lr_scheduler(optimizer, log) 192 | print(lr) 193 | 194 | cos = { 195 | "type": 'cos', 196 | 'start_lr': 0.01, 197 | 'end_lr': 0.0005, 198 | } 199 | lr = build_lr_scheduler(optimizer, cos) 200 | print(lr) 201 | 202 | cfg.TRAIN.LR_WARMUP.WARMUP = True 203 | step = { 204 | 'type': 'step', 205 | 'start_lr': 0.001, 206 | 'end_lr': 0.03, 207 | 'step': 1, 208 | } 209 | 210 | warmup = log.copy() 211 | warmup['warmup'] = step 212 | warmup['warmup']['epoch'] = 5 213 | lr = build_lr_scheduler(optimizer, warmup, epochs=55) 214 | print(lr) 215 | 216 | lr.step() 217 | print(lr.last_epoch) 218 | 219 | lr.step(5) 220 | print(lr.last_epoch) 221 | -------------------------------------------------------------------------------- /pysot/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
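# [Editor's note: illustrative usage sketch, not part of the original file.]
# describe() renders the module tree and highlights trainable parts, and commit()
# records the current git revision; a typical call from a training script looks
# roughly like this (a 'logger' from log_helper is assumed):
#
#   from pysot.utils.misc import describe, commit
#
#   logger.info("model description:\n{}".format(describe(model)))
#   logger.info(commit())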
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | import numpy as np 10 | import torch 11 | 12 | from colorama import Fore, Style 13 | 14 | 15 | __all__ = ['commit', 'describe'] 16 | 17 | 18 | def _exec(cmd): 19 | f = os.popen(cmd, 'r', 1) 20 | return f.read().strip() 21 | 22 | 23 | def _bold(s): 24 | return "\033[1m%s\033[0m" % s 25 | 26 | 27 | def _color(s): 28 | # return f'{Fore.RED}{s}{Style.RESET_ALL}' 29 | return "{}{}{}".format(Fore.RED,s,Style.RESET_ALL) 30 | 31 | 32 | def _describe(model, lines=None, spaces=0): 33 | head = " " * spaces 34 | for name, p in model.named_parameters(): 35 | if '.' in name: 36 | continue 37 | if p.requires_grad: 38 | name = _color(name) 39 | line = "{head}- {name}".format(head=head, name=name) 40 | lines.append(line) 41 | 42 | for name, m in model.named_children(): 43 | space_num = len(name) + spaces + 1 44 | if m.training: 45 | name = _color(name) 46 | line = "{head}.{name} ({type})".format( 47 | head=head, 48 | name=name, 49 | type=m.__class__.__name__) 50 | lines.append(line) 51 | _describe(m, lines, space_num) 52 | 53 | 54 | def commit(): 55 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) 56 | cmd = "cd {}; git log | head -n1 | awk '{{print $2}}'".format(root) 57 | commit = _exec(cmd) 58 | cmd = "cd {}; git log --oneline | head -n1".format(root) 59 | commit_log = _exec(cmd) 60 | return "commit : {}\n log : {}".format(commit, commit_log) 61 | 62 | 63 | def describe(net, name=None): 64 | num = 0 65 | lines = [] 66 | if name is not None: 67 | lines.append(name) 68 | num = len(name) 69 | _describe(net, lines, num) 70 | return "\n".join(lines) 71 | 72 | 73 | def bbox_clip(x, min_value, max_value): 74 | new_x = max(min_value, min(x, max_value)) 75 | return new_x 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /pysot/utils/model_load.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 
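# [Editor's note: illustrative usage sketch, not part of the original file.]
# load_pretrain() and restore_from() below are normally called from the test and
# train scripts; the snapshot paths here are placeholders, not files in the repo:
#
#   from pysot.models.model_builder_gat import ModelBuilder
#   from pysot.utils.model_load import load_pretrain, restore_from
#
#   model = ModelBuilder()
#   model = load_pretrain(model, 'snapshot/siamgat_googlenet.pth').cuda().eval()
#   # or, to resume training:
#   # model, optimizer, epoch = restore_from(model, optimizer, 'snapshot/checkpoint_e10.pth')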
2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import logging 9 | 10 | import torch 11 | 12 | 13 | logger = logging.getLogger('global') 14 | 15 | 16 | def check_keys(model, pretrained_state_dict): 17 | ckpt_keys = set(pretrained_state_dict.keys()) 18 | model_keys = set(model.state_dict().keys()) 19 | used_pretrained_keys = model_keys & ckpt_keys 20 | unused_pretrained_keys = ckpt_keys - model_keys 21 | missing_keys = model_keys - ckpt_keys 22 | # filter 'num_batches_tracked' 23 | missing_keys = [x for x in missing_keys 24 | if not x.endswith('num_batches_tracked')] 25 | if len(missing_keys) > 0: 26 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 27 | logger.info('missing keys:{}'.format(len(missing_keys))) 28 | if len(unused_pretrained_keys) > 0: 29 | logger.info('[Warning] unused_pretrained_keys: {}'.format( 30 | unused_pretrained_keys)) 31 | logger.info('unused checkpoint keys:{}'.format( 32 | len(unused_pretrained_keys))) 33 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 34 | assert len(used_pretrained_keys) > 0, \ 35 | 'load NONE from pretrained checkpoint' 36 | return True 37 | 38 | 39 | def remove_prefix(state_dict, prefix): 40 | ''' Old style model is stored with all names of parameters 41 | share common prefix 'module.' ''' 42 | logger.info('remove prefix \'{}\''.format(prefix)) 43 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 44 | return {f(key): value for key, value in state_dict.items()} 45 | 46 | 47 | def load_pretrain(model, pretrained_path): 48 | logger.info('load pretrained model from {}'.format(pretrained_path)) 49 | device = torch.cuda.current_device() 50 | 51 | if 'inception' in pretrained_path: 52 | pretrained_dict = torch.load(pretrained_path, 53 | map_location=lambda storage, loc: storage.cpu()) 54 | else: 55 | pretrained_dict = torch.load(pretrained_path, 56 | map_location=lambda storage, loc: storage.cuda(device)) 57 | 58 | if "state_dict" in pretrained_dict.keys(): 59 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 60 | 'module.') 61 | else: 62 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 63 | 64 | try: 65 | check_keys(model, pretrained_dict) 66 | except: 67 | logger.info('[Warning]: using pretrain as features.\ 68 | Adding "features." as prefix') 69 | new_dict = {} 70 | for k, v in pretrained_dict.items(): 71 | k = 'features.' 
+ k 72 | new_dict[k] = v 73 | pretrained_dict = new_dict 74 | check_keys(model, pretrained_dict) 75 | model.load_state_dict(pretrained_dict, strict=False) 76 | return model 77 | 78 | 79 | def restore_from(model, optimizer, ckpt_path): 80 | device = torch.cuda.current_device() 81 | ckpt = torch.load(ckpt_path, 82 | map_location=lambda storage, loc: storage.cuda(device)) 83 | epoch = ckpt['epoch'] 84 | 85 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.') 86 | check_keys(model, ckpt_model_dict) 87 | model.load_state_dict(ckpt_model_dict, strict=False) 88 | 89 | check_keys(optimizer, ckpt['optimizer']) 90 | optimizer.load_state_dict(ckpt['optimizer']) 91 | return model, optimizer, epoch 92 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pytorch==1.2.0 3 | opencv-python 4 | pyyaml 5 | yacs 6 | tqdm 7 | colorama 8 | matplotlib 9 | cython 10 | tensorboardX 11 | -------------------------------------------------------------------------------- /toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/__init__.py -------------------------------------------------------------------------------- /toolkit/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .otb import OTBDataset 2 | from .uav import UAVDataset 3 | from .lasot import LaSOTDataset 4 | from .got10k import GOT10kDataset 5 | from .trackingnet import TrackingNetDataset 6 | 7 | 8 | class DatasetFactory(object): 9 | @staticmethod 10 | def create_dataset(**kwargs): 11 | """ 12 | Args: 13 | name: dataset name, e.g. 'OTB2015', 'LaSOT', 'UAV123', 14 | 'GOT-10k', 'TrackingNet' 15 | dataset_root: dataset root 16 | load_img: whether to load image 17 | Return: 18 | dataset 19 | """ 20 | assert 'name' in kwargs, "should provide dataset name" 21 | name = kwargs['name'] 22 | if 'OTB' in name: 23 | dataset = OTBDataset(**kwargs) 24 | elif 'LaSOT' == name: 25 | dataset = LaSOTDataset(**kwargs) 26 | elif 'UAV' in name: 27 | dataset = UAVDataset(**kwargs) 28 | elif 'GOT-10k' == name: 29 | dataset = GOT10kDataset(**kwargs) 30 | elif 'TrackingNet' == name: 31 | dataset = TrackingNetDataset(**kwargs) 32 | else: 33 | raise Exception("unknown dataset {}".format(kwargs['name'])) 34 | return dataset 35 | 36 | -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/dataset.cpython-35.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/got10k.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/got10k.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/lasot.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/lasot.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/otb.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/otb.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/uav.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/uav.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/video.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/video.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/__pycache__/vot.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/datasets/__pycache__/vot.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | class Dataset(object): 4 | def __init__(self, name, dataset_root): 5 | self.name = name 6 | self.dataset_root = dataset_root 7 | self.videos = None 8 | 9 | def __getitem__(self, idx): 10 | if isinstance(idx, str): 11 | return self.videos[idx] 12 | elif isinstance(idx, int): 13 | return self.videos[sorted(list(self.videos.keys()))[idx]] 14 | 15 | def __len__(self): 16 | return len(self.videos) 17 | 18 | def __iter__(self): 19 | keys = sorted(list(self.videos.keys())) 20 | for key in keys: 21 | yield self.videos[key] 22 | 23 | def set_tracker(self, path, tracker_names): 24 | """ 25 | Args: 26 | path: path to tracker results, 27 | tracker_names: list of tracker name 28 | """ 29 | self.tracker_path = path 30 | self.tracker_names = tracker_names 31 | # for video in tqdm(self.videos.values(), 32 | # desc='loading tacker result', ncols=100): 33 | # video.load_tracker(path, tracker_names) 34 | -------------------------------------------------------------------------------- /toolkit/datasets/got10k.py: 
-------------------------------------------------------------------------------- 1 | 2 | import json 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | from .dataset import Dataset 8 | from .video import Video 9 | 10 | class GOT10kVideo(Video): 11 | """ 12 | Args: 13 | name: video name 14 | root: dataset root 15 | video_dir: video directory 16 | init_rect: init rectangle 17 | img_names: image names 18 | gt_rect: groundtruth rectangle 19 | attr: attribute of video 20 | """ 21 | def __init__(self, name, root, video_dir, init_rect, img_names, 22 | gt_rect, attr, load_img=False): 23 | super(GOT10kVideo, self).__init__(name, root, video_dir, 24 | init_rect, img_names, gt_rect, attr, load_img) 25 | 26 | # def load_tracker(self, path, tracker_names=None): 27 | # """ 28 | # Args: 29 | # path(str): path to result 30 | # tracker_name(list): name of tracker 31 | # """ 32 | # if not tracker_names: 33 | # tracker_names = [x.split('/')[-1] for x in glob(path) 34 | # if os.path.isdir(x)] 35 | # if isinstance(tracker_names, str): 36 | # tracker_names = [tracker_names] 37 | # # self.pred_trajs = {} 38 | # for name in tracker_names: 39 | # traj_file = os.path.join(path, name, self.name+'.txt') 40 | # if os.path.exists(traj_file): 41 | # with open(traj_file, 'r') as f : 42 | # self.pred_trajs[name] = [list(map(float, x.strip().split(','))) 43 | # for x in f.readlines()] 44 | # if len(self.pred_trajs[name]) != len(self.gt_traj): 45 | # print(name, len(self.pred_trajs[name]), len(self.gt_traj), self.name) 46 | # else: 47 | 48 | # self.tracker_names = list(self.pred_trajs.keys()) 49 | 50 | class GOT10kDataset(Dataset): 51 | """ 52 | Args: 53 | name: dataset name, should be "GOT-10k" 54 | dataset_root: dataset root dir 55 | """ 56 | def __init__(self, name, dataset_root, load_img=False): 57 | super(GOT10kDataset, self).__init__(name, dataset_root) 58 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 59 | meta_data = json.load(f) 60 | 61 | # load videos 62 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 63 | self.videos = {} 64 | for video in pbar: 65 | pbar.set_postfix_str(video) 66 | self.videos[video] = GOT10kVideo(video, 67 | dataset_root, 68 | meta_data[video]['video_dir'], 69 | meta_data[video]['init_rect'], 70 | meta_data[video]['img_names'], 71 | meta_data[video]['gt_rect'], 72 | None) 73 | self.attr = {} 74 | self.attr['ALL'] = list(self.videos.keys()) 75 | -------------------------------------------------------------------------------- /toolkit/datasets/lasot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | from glob import glob 7 | 8 | from .dataset import Dataset 9 | from .video import Video 10 | 11 | class LaSOTVideo(Video): 12 | """ 13 | Args: 14 | name: video name 15 | root: dataset root 16 | video_dir: video directory 17 | init_rect: init rectangle 18 | img_names: image names 19 | gt_rect: groundtruth rectangle 20 | attr: attribute of video 21 | """ 22 | def __init__(self, name, root, video_dir, init_rect, img_names, 23 | gt_rect, attr, absent, load_img=False): 24 | super(LaSOTVideo, self).__init__(name, root, video_dir, 25 | init_rect, img_names, gt_rect, attr, load_img) 26 | self.absent = np.array(absent, np.int8) 27 | 28 | def load_tracker(self, path, tracker_names=None, store=True): 29 | """ 30 | Args: 31 | path(str): path to result 32 | tracker_name(list): name of tracker 33 | """ 34 | if not tracker_names: 35 | tracker_names =
[x.split('/')[-1] for x in glob(path) 36 | if os.path.isdir(x)] 37 | if isinstance(tracker_names, str): 38 | tracker_names = [tracker_names] 39 | for name in tracker_names: 40 | traj_file = os.path.join(path, name, self.name+'.txt') 41 | if os.path.exists(traj_file): 42 | with open(traj_file, 'r') as f : 43 | pred_traj = [list(map(float, x.strip().split(','))) 44 | for x in f.readlines()] 45 | else: 46 | print("File does not exist: ", traj_file) 47 | if self.name == 'monkey-17': 48 | pred_traj = pred_traj[:len(self.gt_traj)] 49 | if store: 50 | self.pred_trajs[name] = pred_traj 51 | else: 52 | return pred_traj 53 | self.tracker_names = list(self.pred_trajs.keys()) 54 | 55 | 56 | 57 | class LaSOTDataset(Dataset): 58 | """ 59 | Args: 60 | name: dataset name, should be 'LaSOT' 61 | dataset_root: dataset root 62 | load_img: whether to load all images 63 | """ 64 | def __init__(self, name, dataset_root, load_img=False): 65 | super(LaSOTDataset, self).__init__(name, dataset_root) 66 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 67 | meta_data = json.load(f) 68 | 69 | # load videos 70 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 71 | self.videos = {} 72 | for video in pbar: 73 | pbar.set_postfix_str(video) 74 | self.videos[video] = LaSOTVideo(video, 75 | dataset_root, 76 | meta_data[video]['video_dir'], 77 | meta_data[video]['init_rect'], 78 | meta_data[video]['img_names'], 79 | meta_data[video]['gt_rect'], 80 | meta_data[video]['attr'], 81 | meta_data[video]['absent']) 82 | 83 | # set attr 84 | attr = [] 85 | for x in self.videos.values(): 86 | attr += x.attr 87 | attr = set(attr) 88 | self.attr = {} 89 | self.attr['ALL'] = list(self.videos.keys()) 90 | for x in attr: 91 | self.attr[x] = [] 92 | for k, v in self.videos.items(): 93 | for attr_ in v.attr: 94 | self.attr[attr_].append(k) 95 | 96 | 97 | -------------------------------------------------------------------------------- /toolkit/datasets/otb.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from glob import glob 8 | 9 | from .dataset import Dataset 10 | from .video import Video 11 | 12 | 13 | class OTBVideo(Video): 14 | """ 15 | Args: 16 | name: video name 17 | root: dataset root 18 | video_dir: video directory 19 | init_rect: init rectangle 20 | img_names: image names 21 | gt_rect: groundtruth rectangle 22 | attr: attribute of video 23 | """ 24 | def __init__(self, name, root, video_dir, init_rect, img_names, 25 | gt_rect, attr, load_img=False): 26 | super(OTBVideo, self).__init__(name, root, video_dir, 27 | init_rect, img_names, gt_rect, attr, load_img) 28 | 29 | def load_tracker(self, path, tracker_names=None, store=True): 30 | """ 31 | Args: 32 | path(str): path to result 33 | tracker_name(list): name of tracker 34 | """ 35 | if not tracker_names: 36 | tracker_names = [x.split('/')[-1] for x in glob(path) 37 | if os.path.isdir(x)] 38 | if isinstance(tracker_names, str): 39 | tracker_names = [tracker_names] 40 | for name in tracker_names: 41 | traj_file = os.path.join(path, name, self.name+'.txt') 42 | if not os.path.exists(traj_file): 43 | if self.name == 'FleetFace': 44 | txt_name = 'fleetface.txt' 45 | elif self.name == 'Jogging-1': 46 | txt_name = 'jogging_1.txt' 47 | elif self.name == 'Jogging-2': 48 | txt_name = 'jogging_2.txt' 49 | elif self.name == 'Skating2-1': 50 | txt_name = 'skating2_1.txt' 51 | elif self.name ==
'Skating2-2': 52 | txt_name = 'skating2_2.txt' 53 | elif self.name == 'FaceOcc1': 54 | txt_name = 'faceocc1.txt' 55 | elif self.name == 'FaceOcc2': 56 | txt_name = 'faceocc2.txt' 57 | elif self.name == 'Human4-2': 58 | txt_name = 'human4_2.txt' 59 | else: 60 | txt_name = self.name[0].lower()+self.name[1:]+'.txt' 61 | traj_file = os.path.join(path, name, txt_name) 62 | if os.path.exists(traj_file): 63 | with open(traj_file, 'r') as f : 64 | pred_traj = [list(map(float, x.strip().split(','))) 65 | for x in f.readlines()] 66 | if len(pred_traj) != len(self.gt_traj): 67 | print(name, len(pred_traj), len(self.gt_traj), self.name) 68 | if store: 69 | self.pred_trajs[name] = pred_traj 70 | else: 71 | return pred_traj 72 | else: 73 | print(traj_file) 74 | self.tracker_names = list(self.pred_trajs.keys()) 75 | 76 | 77 | 78 | class OTBDataset(Dataset): 79 | """ 80 | Args: 81 | name: dataset name, should be 'OTB100', 'CVPR13', 'OTB50' 82 | dataset_root: dataset root 83 | load_img: wether to load all imgs 84 | """ 85 | def __init__(self, name, dataset_root, load_img=False): 86 | super(OTBDataset, self).__init__(name, dataset_root) 87 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 88 | meta_data = json.load(f) 89 | 90 | # load videos 91 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 92 | self.videos = {} 93 | for video in pbar: 94 | pbar.set_postfix_str(video) 95 | self.videos[video] = OTBVideo(video, 96 | dataset_root, 97 | meta_data[video]['video_dir'], 98 | meta_data[video]['init_rect'], 99 | meta_data[video]['img_names'], 100 | meta_data[video]['gt_rect'], 101 | meta_data[video]['attr'], 102 | load_img) 103 | 104 | # set attr 105 | attr = [] 106 | for x in self.videos.values(): 107 | attr += x.attr 108 | attr = set(attr) 109 | self.attr = {} 110 | self.attr['ALL'] = list(self.videos.keys()) 111 | for x in attr: 112 | self.attr[x] = [] 113 | for k, v in self.videos.items(): 114 | for attr_ in v.attr: 115 | self.attr[attr_].append(k) 116 | -------------------------------------------------------------------------------- /toolkit/datasets/uav.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from tqdm import tqdm 5 | from glob import glob 6 | 7 | from .dataset import Dataset 8 | from .video import Video 9 | 10 | class UAVVideo(Video): 11 | """ 12 | Args: 13 | name: video name 14 | root: dataset root 15 | video_dir: video directory 16 | init_rect: init rectangle 17 | img_names: image names 18 | gt_rect: groundtruth rectangle 19 | attr: attribute of video 20 | """ 21 | def __init__(self, name, root, video_dir, init_rect, img_names, 22 | gt_rect, attr, load_img=False): 23 | super(UAVVideo, self).__init__(name, root, video_dir, 24 | init_rect, img_names, gt_rect, attr, load_img) 25 | 26 | 27 | class UAVDataset(Dataset): 28 | """ 29 | Args: 30 | name: dataset name, should be 'UAV123', 'UAV20L' 31 | dataset_root: dataset root 32 | load_img: wether to load all imgs 33 | """ 34 | def __init__(self, name, dataset_root, load_img=False): 35 | super(UAVDataset, self).__init__(name, dataset_root) 36 | with open(os.path.join(dataset_root, name+'.json'), 'r') as f: 37 | meta_data = json.load(f) 38 | 39 | # load videos 40 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 41 | self.videos = {} 42 | for video in pbar: 43 | pbar.set_postfix_str(video) 44 | self.videos[video] = UAVVideo(video, 45 | dataset_root, 46 | meta_data[video]['video_dir'], 47 | meta_data[video]['init_rect'], 48 | 
meta_data[video]['img_names'], 49 | meta_data[video]['gt_rect'], 50 | meta_data[video]['attr']) 51 | 52 | # set attr 53 | attr = [] 54 | for x in self.videos.values(): 55 | attr += x.attr 56 | attr = set(attr) 57 | self.attr = {} 58 | self.attr['ALL'] = list(self.videos.keys()) 59 | for x in attr: 60 | self.attr[x] = [] 61 | for k, v in self.videos.items(): 62 | for attr_ in v.attr: 63 | self.attr[attr_].append(k) 64 | 65 | -------------------------------------------------------------------------------- /toolkit/datasets/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import re 4 | import numpy as np 5 | import json 6 | 7 | from glob import glob 8 | 9 | class Video(object): 10 | def __init__(self, name, root, video_dir, init_rect, img_names, 11 | gt_rect, attr, load_img=False): 12 | self.name = name 13 | self.video_dir = video_dir 14 | self.init_rect = init_rect 15 | self.gt_traj = gt_rect 16 | self.attr = attr 17 | self.pred_trajs = {} 18 | self.img_names = [os.path.join(root, x) for x in img_names] 19 | self.imgs = None 20 | 21 | if load_img: 22 | self.imgs = [cv2.imread(x) for x in self.img_names] 23 | self.width = self.imgs[0].shape[1] 24 | self.height = self.imgs[0].shape[0] 25 | else: 26 | img = cv2.imread(self.img_names[0]) 27 | assert img is not None, self.img_names[0] 28 | self.width = img.shape[1] 29 | self.height = img.shape[0] 30 | 31 | def load_tracker(self, path, tracker_names=None, store=True): 32 | """ 33 | Args: 34 | path(str): path to result 35 | tracker_name(list): name of tracker 36 | """ 37 | if not tracker_names: 38 | tracker_names = [x.split('/')[-1] for x in glob(path) 39 | if os.path.isdir(x)] 40 | if isinstance(tracker_names, str): 41 | tracker_names = [tracker_names] 42 | for name in tracker_names: 43 | traj_file = os.path.join(path, name, self.name+'.txt') 44 | if os.path.exists(traj_file): 45 | with open(traj_file, 'r') as f : 46 | pred_traj = [list(map(float, x.strip().split(','))) 47 | for x in f.readlines()] 48 | if len(pred_traj) != len(self.gt_traj): 49 | print(name, len(pred_traj), len(self.gt_traj), self.name) 50 | if store: 51 | self.pred_trajs[name] = pred_traj 52 | else: 53 | return pred_traj 54 | else: 55 | print(traj_file) 56 | self.tracker_names = list(self.pred_trajs.keys()) 57 | 58 | def load_img(self): 59 | if self.imgs is None: 60 | self.imgs = [cv2.imread(x) for x in self.img_names] 61 | self.width = self.imgs[0].shape[1] 62 | self.height = self.imgs[0].shape[0] 63 | 64 | def free_img(self): 65 | self.imgs = None 66 | 67 | def __len__(self): 68 | return len(self.img_names) 69 | 70 | def __getitem__(self, idx): 71 | if self.imgs is None: 72 | return cv2.imread(self.img_names[idx]), self.gt_traj[idx] 73 | else: 74 | return self.imgs[idx], self.gt_traj[idx] 75 | 76 | def __iter__(self): 77 | for i in range(len(self.img_names)): 78 | if self.imgs is not None: 79 | yield self.imgs[i], self.gt_traj[i] 80 | else: 81 | yield cv2.imread(self.img_names[i]), self.gt_traj[i] 82 | 83 | def draw_box(self, roi, img, linewidth, color, name=None): 84 | """ 85 | roi: rectangle or polygon 86 | img: numpy array img 87 | linewith: line width of the bbox 88 | """ 89 | if len(roi) > 6 and len(roi) % 2 == 0: 90 | pts = np.array(roi, np.int32).reshape(-1, 1, 2) 91 | color = tuple(map(int, color)) 92 | img = cv2.polylines(img, [pts], True, color, linewidth) 93 | pt = (pts[0, 0, 0], pts[0, 0, 1]-5) 94 | if name: 95 | img = cv2.putText(img, name, pt, cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1) 
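# [Editor's note: comment added for clarity, not in the original file.]
# A roi with more than 6 values and an even length is treated as a polygon above
# (e.g. VOT-style 8-point annotations); a 4-value roi falls through to the
# axis-aligned [x, y, w, h] rectangle branch below.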
96 | elif len(roi) == 4: 97 | if not np.isnan(roi[0]): 98 | roi = list(map(int, roi)) 99 | color = tuple(map(int, color)) 100 | img = cv2.rectangle(img, (roi[0], roi[1]), (roi[0]+roi[2], roi[1]+roi[3]), 101 | color, linewidth) 102 | if name: 103 | img = cv2.putText(img, name, (roi[0], roi[1]-5), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1) 104 | return img 105 | 106 | def show(self, pred_trajs={}, linewidth=2, show_name=False): 107 | """ 108 | pred_trajs: dict of pred_traj, {'tracker_name': list of traj} 109 | pred_traj should contain polygon or rectangle(x, y, width, height) 110 | linewith: line width of the bbox 111 | """ 112 | assert self.imgs is not None 113 | video = [] 114 | cv2.namedWindow(self.name, cv2.WINDOW_NORMAL) 115 | colors = {} 116 | if len(pred_trajs) == 0 and len(self.pred_trajs) > 0: 117 | pred_trajs = self.pred_trajs 118 | for i, (roi, img) in enumerate(zip(self.gt_traj, 119 | self.imgs[self.start_frame:self.end_frame+1])): 120 | img = img.copy() 121 | if len(img.shape) == 2: 122 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 123 | else: 124 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 125 | img = self.draw_box(roi, img, linewidth, (0, 255, 0), 126 | 'gt' if show_name else None) 127 | for name, trajs in pred_trajs.items(): 128 | if name not in colors: 129 | color = tuple(np.random.randint(0, 256, 3)) 130 | colors[name] = color 131 | else: 132 | color = colors[name] 133 | img = self.draw_box(trajs[0][i], img, linewidth, color, 134 | name if show_name else None) 135 | cv2.putText(img, str(i+self.start_frame), (5, 20), 136 | cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (255, 255, 0), 2) 137 | cv2.imshow(self.name, img) 138 | cv2.waitKey(40) 139 | video.append(img.copy()) 140 | return video 141 | -------------------------------------------------------------------------------- /toolkit/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .ope_benchmark import OPEBenchmark 2 | -------------------------------------------------------------------------------- /toolkit/evaluation/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/evaluation/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/evaluation/__pycache__/ope_benchmark.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/evaluation/__pycache__/ope_benchmark.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/utils/__pycache__/statistics.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ohhhyeahhh/SiamGAT/eab3e94f9c9d44c1b007b2b63bdf75129443dd40/toolkit/utils/__pycache__/statistics.cpython-35.pyc -------------------------------------------------------------------------------- /toolkit/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author fangyi.zhang@vipl.ict.ac.cn 3 | """ 4 | import numpy as np 5 | 6 | def overlap_ratio(rect1, rect2): 7 | '''Compute overlap ratio between two rects 8 | Args 9 | rect:2d array of N x [x,y,w,h] 10 | Return: 11 | iou 12 | ''' 13 | # if rect1.ndim==1: 14 | # rect1 
= rect1[np.newaxis, :] 15 | # if rect2.ndim==1: 16 | # rect2 = rect2[np.newaxis, :] 17 | left = np.maximum(rect1[:,0], rect2[:,0]) 18 | right = np.minimum(rect1[:,0]+rect1[:,2], rect2[:,0]+rect2[:,2]) 19 | top = np.maximum(rect1[:,1], rect2[:,1]) 20 | bottom = np.minimum(rect1[:,1]+rect1[:,3], rect2[:,1]+rect2[:,3]) 21 | 22 | intersect = np.maximum(0,right - left) * np.maximum(0,bottom - top) 23 | union = rect1[:,2]*rect1[:,3] + rect2[:,2]*rect2[:,3] - intersect 24 | iou = intersect / union 25 | iou = np.maximum(np.minimum(1, iou), 0) 26 | return iou 27 | 28 | def success_overlap(gt_bb, result_bb, n_frame): 29 | thresholds_overlap = np.arange(0, 1.05, 0.05) 30 | success = np.zeros(len(thresholds_overlap)) 31 | iou = np.ones(len(gt_bb)) * (-1) 32 | # mask = np.sum(gt_bb > 0, axis=1) == 4 #TODO check all dataset 33 | mask = np.sum(gt_bb[:, 2:] > 0, axis=1) == 2 34 | iou[mask] = overlap_ratio(gt_bb[mask], result_bb[mask]) 35 | for i in range(len(thresholds_overlap)): 36 | success[i] = np.sum(iou > thresholds_overlap[i]) / float(n_frame) 37 | return success 38 | 39 | def success_error(gt_center, result_center, thresholds, n_frame): 40 | # n_frame = len(gt_center) 41 | success = np.zeros(len(thresholds)) 42 | dist = np.ones(len(gt_center)) * (-1) 43 | mask = np.sum(gt_center > 0, axis=1) == 2 44 | dist[mask] = np.sqrt(np.sum( 45 | np.power(gt_center[mask] - result_center[mask], 2), axis=1)) 46 | for i in range(len(thresholds)): 47 | success[i] = np.sum(dist <= thresholds[i]) / float(n_frame) 48 | return success 49 | 50 | 51 | -------------------------------------------------------------------------------- /toolkit/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .draw_success_precision import draw_success_precision 2 | -------------------------------------------------------------------------------- /toolkit/visualization/draw_success_precision.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from .draw_utils import COLOR, LINE_STYLE 5 | 6 | def draw_success_precision(success_ret, name, videos, attr, precision_ret=None, 7 | norm_precision_ret=None, bold_name=None, axis=[0, 1]): 8 | # success plot 9 | fig, ax = plt.subplots() 10 | ax.grid(b=True) 11 | ax.set_aspect(1) 12 | plt.xlabel('Overlap threshold') 13 | plt.ylabel('Success rate') 14 | if attr == 'ALL': 15 | plt.title(r'\textbf{Success plots of OPE on %s}' % (name)) 16 | else: 17 | plt.title(r'\textbf{Success plots of OPE - %s}' % (attr)) 18 | plt.axis([0, 1]+axis) 19 | success = {} 20 | thresholds = np.arange(0, 1.05, 0.05) 21 | for tracker_name in success_ret.keys(): 22 | value = [v for k, v in success_ret[tracker_name].items() if k in videos] 23 | success[tracker_name] = np.mean(value) 24 | for idx, (tracker_name, auc) in \ 25 | enumerate(sorted(success.items(), key=lambda x:x[1], reverse=True)): 26 | if tracker_name == bold_name: 27 | label = r"\textbf{[%.3f] %s}" % (auc, tracker_name) 28 | else: 29 | label = "[%.3f] " % (auc) + tracker_name 30 | value = [v for k, v in success_ret[tracker_name].items() if k in videos] 31 | plt.plot(thresholds, np.mean(value, axis=0), 32 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 33 | ax.legend(loc='lower left', labelspacing=0.2) 34 | ax.autoscale(enable=True, axis='both', tight=True) 35 | xmin, xmax, ymin, ymax = plt.axis() 36 | ax.autoscale(enable=False) 37 | ymax += 0.03 38 | plt.axis([xmin, xmax, ymin, 
ymax]) 39 | plt.xticks(np.arange(xmin, xmax+0.01, 0.1)) 40 | plt.yticks(np.arange(ymin, ymax, 0.1)) 41 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 42 | plt.show() 43 | 44 | if precision_ret: 45 | # norm precision plot 46 | fig, ax = plt.subplots() 47 | ax.grid(b=True) 48 | ax.set_aspect(50) 49 | plt.xlabel('Location error threshold') 50 | plt.ylabel('Precision') 51 | if attr == 'ALL': 52 | plt.title(r'\textbf{Precision plots of OPE on %s}' % (name)) 53 | else: 54 | plt.title(r'\textbf{Precision plots of OPE - %s}' % (attr)) 55 | plt.axis([0, 50]+axis) 56 | precision = {} 57 | thresholds = np.arange(0, 51, 1) 58 | for tracker_name in precision_ret.keys(): 59 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos] 60 | precision[tracker_name] = np.mean(value, axis=0)[20] 61 | for idx, (tracker_name, pre) in \ 62 | enumerate(sorted(precision.items(), key=lambda x:x[1], reverse=True)): 63 | if tracker_name == bold_name: 64 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name) 65 | else: 66 | label = "[%.3f] " % (pre) + tracker_name 67 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos] 68 | plt.plot(thresholds, np.mean(value, axis=0), 69 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 70 | ax.legend(loc='lower right', labelspacing=0.2) 71 | ax.autoscale(enable=True, axis='both', tight=True) 72 | xmin, xmax, ymin, ymax = plt.axis() 73 | ax.autoscale(enable=False) 74 | ymax += 0.03 75 | plt.axis([xmin, xmax, ymin, ymax]) 76 | plt.xticks(np.arange(xmin, xmax+0.01, 5)) 77 | plt.yticks(np.arange(ymin, ymax, 0.1)) 78 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 79 | plt.show() 80 | 81 | # norm precision plot 82 | if norm_precision_ret: 83 | fig, ax = plt.subplots() 84 | ax.grid(b=True) 85 | plt.xlabel('Location error threshold') 86 | plt.ylabel('Precision') 87 | if attr == 'ALL': 88 | plt.title(r'\textbf{Normalized Precision plots of OPE on %s}' % (name)) 89 | else: 90 | plt.title(r'\textbf{Normalized Precision plots of OPE - %s}' % (attr)) 91 | norm_precision = {} 92 | thresholds = np.arange(0, 51, 1) / 100 93 | for tracker_name in precision_ret.keys(): 94 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos] 95 | norm_precision[tracker_name] = np.mean(value, axis=0)[20] 96 | for idx, (tracker_name, pre) in \ 97 | enumerate(sorted(norm_precision.items(), key=lambda x:x[1], reverse=True)): 98 | if tracker_name == bold_name: 99 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name) 100 | else: 101 | label = "[%.3f] " % (pre) + tracker_name 102 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos] 103 | plt.plot(thresholds, np.mean(value, axis=0), 104 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 105 | ax.legend(loc='lower right', labelspacing=0.2) 106 | ax.autoscale(enable=True, axis='both', tight=True) 107 | xmin, xmax, ymin, ymax = plt.axis() 108 | ax.autoscale(enable=False) 109 | ymax += 0.03 110 | plt.axis([xmin, xmax, ymin, ymax]) 111 | plt.xticks(np.arange(xmin, xmax+0.01, 0.05)) 112 | plt.yticks(np.arange(ymin, ymax, 0.1)) 113 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 114 | plt.show() 115 | -------------------------------------------------------------------------------- /toolkit/visualization/draw_utils.py: -------------------------------------------------------------------------------- 1 | 2 | COLOR = ((1, 0, 0), 3 | (0, 1, 0), 4 | (1, 0, 1), 5 | (1, 1, 0), 6 | (0 , 162/255, 232/255), 7 | (0.5, 0.5, 0.5), 8 | (0, 0, 1), 9 | (0, 1, 1), 10 | 
(136/255, 0 , 21/255), 11 | (255/255, 127/255, 39/255), 12 | (0, 0, 0)) 13 | 14 | LINE_STYLE = ['-', '--', ':', '-', '--', ':', '-', '--', ':', '-'] 15 | 16 | MARKER_STYLE = ['o', 'v', '<', '*', 'D', 'x', '.', 'x', '<', '.'] 17 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | import argparse 8 | 9 | from glob import glob 10 | from tqdm import tqdm 11 | from multiprocessing import Pool 12 | from toolkit.datasets import OTBDataset, UAVDataset, LaSOTDataset 13 | from toolkit.evaluation.ope_benchmark import OPEBenchmark 14 | 15 | parser = argparse.ArgumentParser(description='tracking evaluation') 16 | 17 | parser.add_argument('--tracker_path', '-p', type=str, default='./results', 18 | help='tracker result path') 19 | parser.add_argument('--dataset', '-d', type=str, default='UAV123', 20 | help='dataset name') 21 | parser.add_argument('--num', '-n', default=1, type=int, 22 | help='number of thread to eval') 23 | parser.add_argument('--tracker_prefix', '-t', default='', 24 | type=str, help='tracker name') 25 | parser.add_argument('--show_video_level', '-s', dest='show_video_level', 26 | action='store_true') 27 | parser.set_defaults(show_video_level=False) 28 | args = parser.parse_args() 29 | 30 | 31 | def main(): 32 | tracker_dir = os.path.join(args.tracker_path, args.dataset) 33 | trackers = glob(os.path.join(args.tracker_path, 34 | args.dataset, 35 | args.tracker_prefix+'*')) 36 | trackers = [x.split('/')[-1] for x in trackers] 37 | 38 | assert len(trackers) > 0 39 | args.num = min(args.num, len(trackers)) 40 | 41 | root = os.path.realpath(os.path.join(os.path.dirname(__file__), 42 | '../testing_dataset')) 43 | 44 | if 'OTB' in args.dataset: 45 | dataset = OTBDataset(args.dataset, root) 46 | dataset.set_tracker(tracker_dir, trackers) 47 | benchmark = OPEBenchmark(dataset) 48 | success_ret = {} 49 | with Pool(processes=args.num) as pool: 50 | for ret in tqdm(pool.imap_unordered(benchmark.eval_success, 51 | trackers), desc='eval success', total=len(trackers), ncols=100): 52 | success_ret.update(ret) 53 | precision_ret = {} 54 | with Pool(processes=args.num) as pool: 55 | for ret in tqdm(pool.imap_unordered(benchmark.eval_precision, 56 | trackers), desc='eval precision', total=len(trackers), ncols=100): 57 | precision_ret.update(ret) 58 | benchmark.show_result(success_ret, precision_ret, 59 | show_video_level=args.show_video_level) 60 | elif 'LaSOT' == args.dataset: 61 | dataset = LaSOTDataset(args.dataset, root) 62 | dataset.set_tracker(tracker_dir, trackers) 63 | benchmark = OPEBenchmark(dataset) 64 | success_ret = {} 65 | with Pool(processes=args.num) as pool: 66 | for ret in tqdm(pool.imap_unordered(benchmark.eval_success, 67 | trackers), desc='eval success', total=len(trackers), ncols=100): 68 | success_ret.update(ret) 69 | precision_ret = {} 70 | with Pool(processes=args.num) as pool: 71 | for ret in tqdm(pool.imap_unordered(benchmark.eval_precision, 72 | trackers), desc='eval precision', total=len(trackers), ncols=100): 73 | precision_ret.update(ret) 74 | norm_precision_ret = {} 75 | with Pool(processes=args.num) as pool: 76 | for ret in tqdm(pool.imap_unordered(benchmark.eval_norm_precision, 77 | trackers), desc='eval norm precision', total=len(trackers), ncols=100): 78 | 
norm_precision_ret.update(ret) 79 | benchmark.show_result(success_ret, precision_ret, norm_precision_ret, 80 | show_video_level=args.show_video_level) 81 | elif 'UAV' in args.dataset: 82 | dataset = UAVDataset(args.dataset, root) 83 | dataset.set_tracker(tracker_dir, trackers) 84 | benchmark = OPEBenchmark(dataset) 85 | success_ret = {} 86 | with Pool(processes=args.num) as pool: 87 | for ret in tqdm(pool.imap_unordered(benchmark.eval_success, 88 | trackers), desc='eval success', total=len(trackers), ncols=100): 89 | success_ret.update(ret) 90 | precision_ret = {} 91 | with Pool(processes=args.num) as pool: 92 | for ret in tqdm(pool.imap_unordered(benchmark.eval_precision, 93 | trackers), desc='eval precision', total=len(trackers), ncols=100): 94 | precision_ret.update(ret) 95 | benchmark.show_result(success_ret, precision_ret, 96 | show_video_level=args.show_video_level) 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /tools/testTracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) SenseTime. All Rights Reserved. 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import argparse 9 | import os 10 | 11 | import cv2 12 | import torch 13 | import numpy as np 14 | import math 15 | import sys 16 | 17 | sys.path.append('../') 18 | base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 19 | sys.path.append(base_dir) 20 | 21 | from pysot.core.config import cfg 22 | from pysot.utils.bbox import get_axis_aligned_bbox 23 | from pysot.utils.model_load import load_pretrain 24 | from pysot.models.model_builder_gat import ModelBuilder 25 | from toolkit.datasets import DatasetFactory 26 | from pysot.tracker.siamgat_tracker import SiamGATTracker 27 | 28 | 29 | parser = argparse.ArgumentParser(description='siamgat tracking') 30 | parser.add_argument('--video', default='', type=str, 31 | help='eval one special video') 32 | parser.add_argument('--dataset', type=str, default='GOT-10k', 33 | help='datasets') # OTB100 LaSOT UAV123 GOT-10k 34 | parser.add_argument('--vis', action='store_true', default=True, 35 | help='whether visualzie result') 36 | parser.add_argument('--snapshot', type=str, default='snapshot/got10k_model.pth', 37 | help='snapshot of models to eval') 38 | parser.add_argument('--config', type=str, default='../experiments/siamgat_googlenet_got10k/config.yaml', 39 | help='config file') 40 | args = parser.parse_args() 41 | 42 | torch.set_num_threads(1) 43 | 44 | 45 | def main(): 46 | # load config 47 | cfg.merge_from_file(args.config) 48 | 49 | # Test dataset 50 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 51 | dataset_root = os.path.join(cur_dir, 'test_dataset', args.dataset) 52 | 53 | # set hyper parameters 54 | params = getattr(cfg.HP_SEARCH, args.dataset) 55 | cfg.TRACK.LR = params[0] 56 | cfg.TRACK.PENALTY_K = params[1] 57 | cfg.TRACK.WINDOW_INFLUENCE = params[2] 58 | 59 | model = ModelBuilder() 60 | 61 | # load model 62 | model = load_pretrain(model, args.snapshot).cuda().eval() 63 | 64 | # build tracker 65 | tracker = SiamGATTracker(model) 66 | 67 | # create dataset 68 | dataset = DatasetFactory.create_dataset(name=args.dataset, 69 | dataset_root=dataset_root, 70 | load_img=False) 71 | 72 | model_name = args.snapshot.split('/')[-1].split('.')[-2] 73 | 74 | # OPE tracking 75 | for v_idx, video 
in enumerate(dataset): 76 | if args.video != '': 77 | # test one special video 78 | if video.name != args.video: 79 | continue 80 | toc = 0 81 | pred_bboxes = [] 82 | track_times = [] 83 | 84 | for idx, (img, gt_bbox) in enumerate(video): 85 | tic = cv2.getTickCount() 86 | 87 | if idx == 0: 88 | cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox)) 89 | gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h] 90 | tracker.init(img, gt_bbox_) 91 | pred_bbox = gt_bbox_ 92 | pred_bboxes.append(pred_bbox) 93 | else: 94 | outputs = tracker.track(img) 95 | pred_bbox = outputs['bbox'] 96 | pred_bboxes.append(pred_bbox) 97 | toc += cv2.getTickCount() - tic 98 | track_times.append((cv2.getTickCount() - tic)/cv2.getTickFrequency()) 99 | if idx == 0: 100 | cv2.destroyAllWindows() 101 | if args.vis and idx > 0: 102 | if not any(map(math.isnan,gt_bbox)): 103 | gt_bbox = list(map(int, gt_bbox)) 104 | pred_bbox = list(map(int, pred_bbox)) 105 | cv2.rectangle(img, (gt_bbox[0], gt_bbox[1]), 106 | (gt_bbox[0]+gt_bbox[2], gt_bbox[1]+gt_bbox[3]), (0, 255, 0), 3) 107 | cv2.rectangle(img, (pred_bbox[0], pred_bbox[1]), 108 | (pred_bbox[0]+pred_bbox[2], pred_bbox[1]+pred_bbox[3]), (0, 255, 255), 3) 109 | cv2.putText(img, str(idx), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2) 110 | cv2.imshow(video.name, img) 111 | cv2.waitKey(1) 112 | toc /= cv2.getTickFrequency() 113 | 114 | # save results 115 | if 'GOT-10k' == args.dataset: 116 | video_path = os.path.join('results', args.dataset, model_name, video.name) 117 | if not os.path.isdir(video_path): 118 | os.makedirs(video_path) 119 | result_path = os.path.join(video_path, '{}_001.txt'.format(video.name)) 120 | with open(result_path, 'w') as f: 121 | for x in pred_bboxes: 122 | f.write(','.join([str(i) for i in x]) + '\n') 123 | result_path = os.path.join(video_path, 124 | '{}_time.txt'.format(video.name)) 125 | with open(result_path, 'w') as f: 126 | for x in track_times: 127 | f.write("{:.6f}\n".format(x)) 128 | else: 129 | model_path = os.path.join('results', args.dataset, model_name) 130 | if not os.path.isdir(model_path): 131 | os.makedirs(model_path) 132 | result_path = os.path.join(model_path, '{}.txt'.format(video.name)) 133 | with open(result_path, 'w') as f: 134 | for x in pred_bboxes: 135 | f.write(','.join([str(i) for i in x])+'\n') 136 | print('({:3d}) Video: {:12s} Time: {:5.1f}s Speed: {:3.1f}fps'.format( 137 | v_idx+1, video.name, toc, idx / toc)) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /training_dataset/coco/gen_json.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from os.path import join 3 | import json 4 | 5 | 6 | dataDir = '.' 
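# For reference, the loop below writes one json per subset (val2017.json / train2017.json)
# mapping each image's crop directory to its annotated objects. Illustrative shape only
# (keys and boxes depend on the actual annotations; "<image_stem>" is a placeholder):
# {
#     "train2017/<image_stem>": {
#         "00": {"000000": [x1, y1, x2, y2]},
#         "01": {"000000": [x1, y1, x2, y2]}
#     }
# }
# Track ids are just the enumeration index of each annotation, and every COCO image is
# treated as a single-frame video, so the frame key is always '000000'.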
7 | for dataType in ['val2017', 'train2017']: 8 | dataset = dict() 9 | annFile = '{}/annotations/instances_{}.json'.format(dataDir,dataType) 10 | coco = COCO(annFile) 11 | n_imgs = len(coco.imgs) 12 | for n, img_id in enumerate(coco.imgs): 13 | print('subset: {} image id: {:04d} / {:04d}'.format(dataType, n, n_imgs)) 14 | img = coco.loadImgs(img_id)[0] 15 | annIds = coco.getAnnIds(imgIds=img['id'], iscrowd=None) 16 | anns = coco.loadAnns(annIds) 17 | video_crop_base_path = join(dataType, img['file_name'].split('/')[-1].split('.')[0]) 18 | 19 | if len(anns) > 0: 20 | dataset[video_crop_base_path] = dict() 21 | 22 | for trackid, ann in enumerate(anns): 23 | rect = ann['bbox'] 24 | c = ann['category_id'] 25 | bbox = [rect[0], rect[1], rect[0]+rect[2], rect[1]+rect[3]] 26 | if rect[2] <= 0 or rect[3] <= 0: # lead nan error in cls. 27 | continue 28 | dataset[video_crop_base_path]['{:02d}'.format(trackid)] = {'000000': bbox} 29 | 30 | print('save json (dataset), please wait 20 seconds~') 31 | json.dump(dataset, open('{}.json'.format(dataType), 'w'), indent=4, sort_keys=True) 32 | print('done!') 33 | 34 | -------------------------------------------------------------------------------- /training_dataset/coco/par_crop.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | import cv2 3 | import numpy as np 4 | from os.path import join, isdir 5 | from os import mkdir, makedirs 6 | from concurrent import futures 7 | import sys 8 | import time 9 | 10 | 11 | # Print iterations progress (thanks StackOverflow) 12 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 13 | """ 14 | Call in a loop to create terminal progress bar 15 | @params: 16 | iteration - Required : current iteration (Int) 17 | total - Required : total iterations (Int) 18 | prefix - Optional : prefix string (Str) 19 | suffix - Optional : suffix string (Str) 20 | decimals - Optional : positive number of decimals in percent complete (Int) 21 | barLength - Optional : character length of bar (Int) 22 | """ 23 | formatStr = "{0:." + str(decimals) + "f}" 24 | percents = formatStr.format(100 * (iteration / float(total))) 25 | filledLength = int(round(barLength * iteration / float(total))) 26 | bar = '' * filledLength + '-' * (barLength - filledLength) 27 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 28 | if iteration == total: 29 | sys.stdout.write('\x1b[2K\r') 30 | sys.stdout.flush() 31 | 32 | 33 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 34 | a = (out_sz-1) / (bbox[2]-bbox[0]) 35 | b = (out_sz-1) / (bbox[3]-bbox[1]) 36 | c = -a * bbox[0] 37 | d = -b * bbox[1] 38 | mapping = np.array([[a, 0, c], 39 | [0, b, d]]).astype(np.float) 40 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 41 | return crop 42 | 43 | 44 | def pos_s_2_bbox(pos, s): 45 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 46 | 47 | 48 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 49 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 
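# SiamFC-style context cropping: the target (w, h) is padded by context_amount * (w + h),
# so with the default context_amount of 0.5 the exemplar crop side in the original image is
# s_z = sqrt((w + 0.5*(w + h)) * (h + 0.5*(w + h))). s_x then widens the crop so that, once the
# exemplar region is resized to exemplar_size pixels, the full search crop spans instanc_size pixels.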
50 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 51 | wc_z = target_size[1] + context_amount * sum(target_size) 52 | hc_z = target_size[0] + context_amount * sum(target_size) 53 | s_z = np.sqrt(wc_z * hc_z) 54 | scale_z = exemplar_size / s_z 55 | d_search = (instanc_size - exemplar_size) / 2 56 | pad = d_search / scale_z 57 | s_x = s_z + 2 * pad 58 | 59 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 60 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 61 | return z, x 62 | 63 | 64 | def crop_img(img, anns, set_crop_base_path, set_img_base_path, instanc_size=511): 65 | frame_crop_base_path = join(set_crop_base_path, img['file_name'].split('/')[-1].split('.')[0]) 66 | if not isdir(frame_crop_base_path): makedirs(frame_crop_base_path) 67 | 68 | im = cv2.imread('{}/{}'.format(set_img_base_path, img['file_name'])) 69 | avg_chans = np.mean(im, axis=(0, 1)) 70 | for trackid, ann in enumerate(anns): 71 | rect = ann['bbox'] 72 | bbox = [rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3]] 73 | if rect[2] <= 0 or rect[3] <=0: 74 | continue 75 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 76 | cv2.imwrite(join(frame_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(0, trackid)), z) 77 | cv2.imwrite(join(frame_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(0, trackid)), x) 78 | 79 | 80 | def main(instanc_size=511, num_threads=12): 81 | dataDir = '.' 82 | crop_path = './crop{:d}'.format(instanc_size) 83 | if not isdir(crop_path): mkdir(crop_path) 84 | 85 | for dataType in ['val2017', 'train2017']: 86 | set_crop_base_path = join(crop_path, dataType) 87 | set_img_base_path = join(dataDir, dataType) 88 | 89 | annFile = '{}/annotations/instances_{}.json'.format(dataDir,dataType) 90 | coco = COCO(annFile) 91 | n_imgs = len(coco.imgs) 92 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 93 | fs = [executor.submit(crop_img, coco.loadImgs(id)[0], 94 | coco.loadAnns(coco.getAnnIds(imgIds=id, iscrowd=None)), 95 | set_crop_base_path, set_img_base_path, instanc_size) for id in coco.imgs] 96 | for i, f in enumerate(futures.as_completed(fs)): 97 | # Write progress to error so that it can be seen 98 | printProgress(i, n_imgs, prefix=dataType, suffix='Done ', barLength=40) 99 | print('done') 100 | 101 | 102 | if __name__ == '__main__': 103 | since = time.time() 104 | main(int(sys.argv[1]), int(sys.argv[2])) 105 | time_elapsed = time.time() - since 106 | print('Total complete in {:.0f}m {:.0f}s'.format( 107 | time_elapsed // 60, time_elapsed % 60)) 108 | -------------------------------------------------------------------------------- /training_dataset/coco/pycocotools/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | # install pycocotools locally 3 | python3 setup.py build_ext --inplace 4 | rm -rf build 5 | 6 | install: 7 | # install pycocotools to the Python site-packages 8 | python3 setup.py build_ext install 9 | rm -rf build 10 | clean: 11 | rm _mask.c _mask.cpython-36m-x86_64-linux-gnu.so 12 | -------------------------------------------------------------------------------- /training_dataset/coco/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /training_dataset/coco/pycocotools/common/gason.h: -------------------------------------------------------------------------------- 1 | // 
https://github.com/vivkin/gason - pulled January 10, 2016 2 | #pragma once 3 | 4 | #include <stdint.h> 5 | #include <stddef.h> 6 | #include <assert.h> 7 | 8 | enum JsonTag { 9 | JSON_NUMBER = 0, 10 | JSON_STRING, 11 | JSON_ARRAY, 12 | JSON_OBJECT, 13 | JSON_TRUE, 14 | JSON_FALSE, 15 | JSON_NULL = 0xF 16 | }; 17 | 18 | struct JsonNode; 19 | 20 | #define JSON_VALUE_PAYLOAD_MASK 0x00007FFFFFFFFFFFULL 21 | #define JSON_VALUE_NAN_MASK 0x7FF8000000000000ULL 22 | #define JSON_VALUE_TAG_MASK 0xF 23 | #define JSON_VALUE_TAG_SHIFT 47 24 | 25 | union JsonValue { 26 | uint64_t ival; 27 | double fval; 28 | 29 | JsonValue(double x) 30 | : fval(x) { 31 | } 32 | JsonValue(JsonTag tag = JSON_NULL, void *payload = nullptr) { 33 | assert((uintptr_t)payload <= JSON_VALUE_PAYLOAD_MASK); 34 | ival = JSON_VALUE_NAN_MASK | ((uint64_t)tag << JSON_VALUE_TAG_SHIFT) | (uintptr_t)payload; 35 | } 36 | bool isDouble() const { 37 | return (int64_t)ival <= (int64_t)JSON_VALUE_NAN_MASK; 38 | } 39 | JsonTag getTag() const { 40 | return isDouble() ? JSON_NUMBER : JsonTag((ival >> JSON_VALUE_TAG_SHIFT) & JSON_VALUE_TAG_MASK); 41 | } 42 | uint64_t getPayload() const { 43 | assert(!isDouble()); 44 | return ival & JSON_VALUE_PAYLOAD_MASK; 45 | } 46 | double toNumber() const { 47 | assert(getTag() == JSON_NUMBER); 48 | return fval; 49 | } 50 | char *toString() const { 51 | assert(getTag() == JSON_STRING); 52 | return (char *)getPayload(); 53 | } 54 | JsonNode *toNode() const { 55 | assert(getTag() == JSON_ARRAY || getTag() == JSON_OBJECT); 56 | return (JsonNode *)getPayload(); 57 | } 58 | }; 59 | 60 | struct JsonNode { 61 | JsonValue value; 62 | JsonNode *next; 63 | char *key; 64 | }; 65 | 66 | struct JsonIterator { 67 | JsonNode *p; 68 | 69 | void operator++() { 70 | p = p->next; 71 | } 72 | bool operator!=(const JsonIterator &x) const { 73 | return p != x.p; 74 | } 75 | JsonNode *operator*() const { 76 | return p; 77 | } 78 | JsonNode *operator->() const { 79 | return p; 80 | } 81 | }; 82 | 83 | inline JsonIterator begin(JsonValue o) { 84 | return JsonIterator{o.toNode()}; 85 | } 86 | inline JsonIterator end(JsonValue) { 87 | return JsonIterator{nullptr}; 88 | } 89 | 90 | #define JSON_ERRNO_MAP(XX) \ 91 | XX(OK, "ok") \ 92 | XX(BAD_NUMBER, "bad number") \ 93 | XX(BAD_STRING, "bad string") \ 94 | XX(BAD_IDENTIFIER, "bad identifier") \ 95 | XX(STACK_OVERFLOW, "stack overflow") \ 96 | XX(STACK_UNDERFLOW, "stack underflow") \ 97 | XX(MISMATCH_BRACKET, "mismatch bracket") \ 98 | XX(UNEXPECTED_CHARACTER, "unexpected character") \ 99 | XX(UNQUOTED_KEY, "unquoted key") \ 100 | XX(BREAKING_BAD, "breaking bad") \ 101 | XX(ALLOCATION_FAILURE, "allocation failure") 102 | 103 | enum JsonErrno { 104 | #define XX(no, str) JSON_##no, 105 | JSON_ERRNO_MAP(XX) 106 | #undef XX 107 | }; 108 | 109 | const char *jsonStrError(int err); 110 | 111 | class JsonAllocator { 112 | struct Zone { 113 | Zone *next; 114 | size_t used; 115 | } *head = nullptr; 116 | 117 | public: 118 | JsonAllocator() = default; 119 | JsonAllocator(const JsonAllocator &) = delete; 120 | JsonAllocator &operator=(const JsonAllocator &) = delete; 121 | JsonAllocator(JsonAllocator &&x) : head(x.head) { 122 | x.head = nullptr; 123 | } 124 | JsonAllocator &operator=(JsonAllocator &&x) { 125 | head = x.head; 126 | x.head = nullptr; 127 | return *this; 128 | } 129 | ~JsonAllocator() { 130 | deallocate(); 131 | } 132 | void *allocate(size_t size); 133 | void deallocate(); 134 | }; 135 | 136 | int jsonParse(char *str, char **endptr, JsonValue *value, JsonAllocator &allocator); 137 |
-------------------------------------------------------------------------------- /training_dataset/coco/pycocotools/common/maskApi.h: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #pragma once 8 | 9 | typedef unsigned int uint; 10 | typedef unsigned long siz; 11 | typedef unsigned char byte; 12 | typedef double* BB; 13 | typedef struct { siz h, w, m; uint *cnts; } RLE; 14 | 15 | /* Initialize/destroy RLE. */ 16 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ); 17 | void rleFree( RLE *R ); 18 | 19 | /* Initialize/destroy RLE array. */ 20 | void rlesInit( RLE **R, siz n ); 21 | void rlesFree( RLE **R, siz n ); 22 | 23 | /* Encode binary masks using RLE. */ 24 | void rleEncode( RLE *R, const byte *mask, siz h, siz w, siz n ); 25 | 26 | /* Decode binary masks encoded via RLE. */ 27 | void rleDecode( const RLE *R, byte *mask, siz n ); 28 | 29 | /* Compute union or intersection of encoded masks. */ 30 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ); 31 | 32 | /* Compute area of encoded masks. */ 33 | void rleArea( const RLE *R, siz n, uint *a ); 34 | 35 | /* Compute intersection over union between masks. */ 36 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ); 37 | 38 | /* Compute non-maximum suppression between bounding masks */ 39 | void rleNms( RLE *dt, siz n, uint *keep, double thr ); 40 | 41 | /* Compute intersection over union between bounding boxes. */ 42 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ); 43 | 44 | /* Compute non-maximum suppression between bounding boxes */ 45 | void bbNms( BB dt, siz n, uint *keep, double thr ); 46 | 47 | /* Get bounding boxes surrounding encoded masks. */ 48 | void rleToBbox( const RLE *R, BB bb, siz n ); 49 | 50 | /* Convert bounding boxes to encoded masks. */ 51 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ); 52 | 53 | /* Convert polygon to encoded mask. */ 54 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ); 55 | 56 | /* Get compressed string representation of encoded mask. */ 57 | char* rleToString( const RLE *R ); 58 | 59 | /* Convert from compressed string representation of encoded mask. */ 60 | void rleFrString( RLE *R, char *s, siz h, siz w ); 61 | -------------------------------------------------------------------------------- /training_dataset/coco/pycocotools/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | #import pycocotools._mask as _mask 4 | from . import _mask 5 | 6 | # Interface for manipulating masks stored in RLE format. 7 | # 8 | # RLE is a simple yet efficient format for storing binary masks. RLE 9 | # first divides a vector (or vectorized image) into a series of piecewise 10 | # constant regions and then for each piece simply stores the length of 11 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 12 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 13 | # (note that the odd counts are always the numbers of zeros). 
Instead of 14 | # storing the counts directly, additional compression is achieved with a 15 | # variable bitrate representation based on a common scheme called LEB128. 16 | # 17 | # Compression is greatest given large piecewise constant regions. 18 | # Specifically, the size of the RLE is proportional to the number of 19 | # *boundaries* in M (or for an image the number of boundaries in the y 20 | # direction). Assuming fairly simple shapes, the RLE representation is 21 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 22 | # is substantially lower, especially for large simple objects (large n). 23 | # 24 | # Many common operations on masks can be computed directly using the RLE 25 | # (without need for decoding). This includes computations such as area, 26 | # union, intersection, etc. All of these operations are linear in the 27 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 28 | # of the object. Computing these operations on the original mask is O(n). 29 | # Thus, using the RLE can result in substantial computational savings. 30 | # 31 | # The following API functions are defined: 32 | # encode - Encode binary masks using RLE. 33 | # decode - Decode binary masks encoded via RLE. 34 | # merge - Compute union or intersection of encoded masks. 35 | # iou - Compute intersection over union between masks. 36 | # area - Compute area of encoded masks. 37 | # toBbox - Get bounding boxes surrounding encoded masks. 38 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 39 | # 40 | # Usage: 41 | # Rs = encode( masks ) 42 | # masks = decode( Rs ) 43 | # R = merge( Rs, intersect=false ) 44 | # o = iou( dt, gt, iscrowd ) 45 | # a = area( Rs ) 46 | # bbs = toBbox( Rs ) 47 | # Rs = frPyObjects( [pyObjects], h, w ) 48 | # 49 | # In the API the following formats are used: 50 | # Rs - [dict] Run-length encoding of binary masks 51 | # R - dict Run-length encoding of binary mask 52 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 53 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 54 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 55 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 56 | # dt,gt - May be either bounding boxes or encoded masks 57 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 58 | # 59 | # Finally, a note about the intersection over union (iou) computation. 60 | # The standard iou of a ground truth (gt) and detected (dt) object is 61 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 62 | # For "crowd" regions, we use a modified criteria. If a gt object is 63 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 64 | # Choosing gt' in the crowd gt that best matches the dt can be done using 65 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 66 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 67 | # For crowd gt regions we use this modified criteria above for the iou. 68 | # 69 | # To compile run "python setup.py build_ext --inplace" 70 | # Please do not contact us for help with compiling. 71 | # 72 | # Microsoft COCO Toolbox. version 2.0 73 | # Data, paper, and tutorials available at: http://mscoco.org/ 74 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
75 | # Licensed under the Simplified BSD License [see coco/license.txt] 76 | 77 | iou = _mask.iou 78 | merge = _mask.merge 79 | frPyObjects = _mask.frPyObjects 80 | 81 | def encode(bimask): 82 | if len(bimask.shape) == 3: 83 | return _mask.encode(bimask) 84 | elif len(bimask.shape) == 2: 85 | h, w = bimask.shape 86 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 87 | 88 | def decode(rleObjs): 89 | if type(rleObjs) == list: 90 | return _mask.decode(rleObjs) 91 | else: 92 | return _mask.decode([rleObjs])[:,:,0] 93 | 94 | def area(rleObjs): 95 | if type(rleObjs) == list: 96 | return _mask.area(rleObjs) 97 | else: 98 | return _mask.area([rleObjs])[0] 99 | 100 | def toBbox(rleObjs): 101 | if type(rleObjs) == list: 102 | return _mask.toBbox(rleObjs) 103 | else: 104 | return _mask.toBbox([rleObjs])[0] 105 | -------------------------------------------------------------------------------- /training_dataset/coco/pycocotools/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | from distutils.extension import Extension 4 | import numpy as np 5 | 6 | # To compile and install locally run "python setup.py build_ext --inplace" 7 | # To install library to Python site-packages run "python setup.py build_ext install" 8 | 9 | ext_modules = [ 10 | Extension( 11 | '_mask', 12 | sources=['common/maskApi.c', '_mask.pyx'], 13 | include_dirs = [np.get_include(), 'common'], 14 | extra_compile_args=['-Wno-cpp', '-Wno-unused-function', '-std=c99'], 15 | ) 16 | ] 17 | 18 | setup(name='pycocotools', 19 | packages=['pycocotools'], 20 | package_dir = {'pycocotools': '.'}, 21 | version='2.0', 22 | ext_modules= 23 | cythonize(ext_modules) 24 | ) 25 | -------------------------------------------------------------------------------- /training_dataset/coco/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing COCO 2 | 3 | ### Download raw images and annotations 4 | 5 | ````shell 6 | wget http://images.cocodataset.org/zips/train2017.zip 7 | wget http://images.cocodataset.org/zips/val2017.zip 8 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 9 | 10 | unzip ./train2017.zip 11 | unzip ./val2017.zip 12 | unzip ./annotations_trainval2017.zip 13 | cd pycocotools && make && cd .. 
14 | ```` 15 | 16 | ### Crop & Generate data info 17 | 18 | ````shell 19 | #python par_crop.py [crop_size] [num_threads] 20 | python par_crop.py 511 12 21 | python gen_json.py 22 | ```` 23 | -------------------------------------------------------------------------------- /training_dataset/det/gen_json.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import mkdir 3 | import glob 4 | import xml.etree.ElementTree as ET 5 | import json 6 | 7 | js = {} 8 | VID_base_path = './ILSVRC' 9 | ann_base_path = join(VID_base_path, 'Annotations/DET/train/') 10 | sub_sets = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i') 11 | for sub_set in sub_sets: 12 | sub_set_base_path = join(ann_base_path, sub_set) 13 | 14 | if 'a' == sub_set: 15 | xmls = sorted(glob.glob(join(sub_set_base_path, '*', '*.xml'))) 16 | else: 17 | xmls = sorted(glob.glob(join(sub_set_base_path, '*.xml'))) 18 | n_imgs = len(xmls) 19 | for f, xml in enumerate(xmls): 20 | print('subset: {} frame id: {:08d} / {:08d}'.format(sub_set, f, n_imgs)) 21 | xmltree = ET.parse(xml) 22 | objects = xmltree.findall('object') 23 | 24 | video = join(sub_set, xml.split('/')[-1].split('.')[0]) 25 | 26 | for id, object_iter in enumerate(objects): 27 | bndbox = object_iter.find('bndbox') 28 | bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 29 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 30 | frame = '%06d' % (0) 31 | obj = '%02d' % (id) 32 | if video not in js: 33 | js[video] = {} 34 | if obj not in js[video]: 35 | js[video][obj] = {} 36 | js[video][obj][frame] = bbox 37 | 38 | train = {k:v for (k,v) in js.items() if 'i/' not in k} 39 | val = {k:v for (k,v) in js.items() if 'i/' in k} 40 | 41 | json.dump(train, open('train.json', 'w'), indent=4, sort_keys=True) 42 | json.dump(val, open('val.json', 'w'), indent=4, sort_keys=True) 43 | -------------------------------------------------------------------------------- /training_dataset/det/par_crop.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import time 9 | import sys 10 | 11 | 12 | # Print iterations progress (thanks StackOverflow) 13 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 14 | """ 15 | Call in a loop to create terminal progress bar 16 | @params: 17 | iteration - Required : current iteration (Int) 18 | total - Required : total iterations (Int) 19 | prefix - Optional : prefix string (Str) 20 | suffix - Optional : suffix string (Str) 21 | decimals - Optional : positive number of decimals in percent complete (Int) 22 | barLength - Optional : character length of bar (Int) 23 | """ 24 | formatStr = "{0:." 
+ str(decimals) + "f}" 25 | percents = formatStr.format(100 * (iteration / float(total))) 26 | filledLength = int(round(barLength * iteration / float(total))) 27 | bar = '' * filledLength + '-' * (barLength - filledLength) 28 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 29 | if iteration == total: 30 | sys.stdout.write('\x1b[2K\r') 31 | sys.stdout.flush() 32 | 33 | 34 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 35 | a = (out_sz - 1) / (bbox[2] - bbox[0]) 36 | b = (out_sz - 1) / (bbox[3] - bbox[1]) 37 | c = -a * bbox[0] 38 | d = -b * bbox[1] 39 | mapping = np.array([[a, 0, c], 40 | [0, b, d]]).astype(np.float) 41 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 42 | return crop 43 | 44 | 45 | def pos_s_2_bbox(pos, s): 46 | return [pos[0] - s / 2, pos[1] - s / 2, pos[0] + s / 2, pos[1] + s / 2] 47 | 48 | 49 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 50 | target_pos = [(bbox[2] + bbox[0]) / 2., (bbox[3] + bbox[1]) / 2.] 51 | target_size = [bbox[2] - bbox[0], bbox[3] - bbox[1]] 52 | wc_z = target_size[1] + context_amount * sum(target_size) 53 | hc_z = target_size[0] + context_amount * sum(target_size) 54 | s_z = np.sqrt(wc_z * hc_z) 55 | scale_z = exemplar_size / s_z 56 | d_search = (instanc_size - exemplar_size) / 2 57 | pad = d_search / scale_z 58 | s_x = s_z + 2 * pad 59 | 60 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 61 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 62 | return z, x 63 | 64 | 65 | def crop_xml(xml, sub_set_crop_path, instanc_size=511): 66 | xmltree = ET.parse(xml) 67 | objects = xmltree.findall('object') 68 | 69 | frame_crop_base_path = join(sub_set_crop_path, xml.split('/')[-1].split('.')[0]) 70 | if not isdir(frame_crop_base_path): makedirs(frame_crop_base_path) 71 | 72 | img_path = xml.replace('xml', 'JPEG').replace('Annotations', 'Data') 73 | 74 | im = cv2.imread(img_path) 75 | avg_chans = np.mean(im, axis=(0, 1)) 76 | 77 | for id, object_iter in enumerate(objects): 78 | bndbox = object_iter.find('bndbox') 79 | bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 80 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 81 | 82 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 83 | cv2.imwrite(join(frame_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(0, id)), z) 84 | cv2.imwrite(join(frame_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(0, id)), x) 85 | 86 | 87 | def main(instanc_size=511, num_threads=24): 88 | crop_path = './crop{:d}'.format(instanc_size) 89 | if not isdir(crop_path): mkdir(crop_path) 90 | VID_base_path = './ILSVRC' 91 | ann_base_path = join(VID_base_path, 'Annotations/DET/train/') 92 | sub_sets = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i') 93 | for sub_set in sub_sets: 94 | sub_set_base_path = join(ann_base_path, sub_set) 95 | if 'a' == sub_set: 96 | xmls = sorted(glob.glob(join(sub_set_base_path, '*', '*.xml'))) 97 | else: 98 | xmls = sorted(glob.glob(join(sub_set_base_path, '*.xml'))) 99 | 100 | n_imgs = len(xmls) 101 | sub_set_crop_path = join(crop_path, sub_set) 102 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 103 | fs = [executor.submit(crop_xml, xml, sub_set_crop_path, instanc_size) for xml in xmls] 104 | for i, f in enumerate(futures.as_completed(fs)): 105 | printProgress(i, n_imgs, prefix=sub_set, suffix='Done ', 
barLength=80) 106 | 107 | 108 | if __name__ == '__main__': 109 | since = time.time() 110 | main(int(sys.argv[1]), int(sys.argv[2])) 111 | time_elapsed = time.time() - since 112 | print('Total complete in {:.0f}m {:.0f}s'.format( 113 | time_elapsed // 60, time_elapsed % 60)) -------------------------------------------------------------------------------- /training_dataset/det/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing DET(Object detection) 2 | Large Scale Visual Recognition Challenge 2015 (ILSVRC2015) 3 | 4 | ### Download dataset 5 | 6 | ````shell 7 | wget http://image-net.org/image/ILSVRC2015/ILSVRC2015_DET.tar.gz 8 | tar -xzvf ./ILSVRC2015_DET.tar.gz 9 | 10 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2013_train ILSVRC/Annotations/DET/train/a 11 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0000 ILSVRC/Annotations/DET/train/b 12 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0001 ILSVRC/Annotations/DET/train/c 13 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0002 ILSVRC/Annotations/DET/train/d 14 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0003 ILSVRC/Annotations/DET/train/e 15 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0004 ILSVRC/Annotations/DET/train/f 16 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0005 ILSVRC/Annotations/DET/train/g 17 | ln -sfb $PWD/ILSVRC/Annotations/DET/train/ILSVRC2014_train_0006 ILSVRC/Annotations/DET/train/h 18 | ln -sfb $PWD/ILSVRC/Annotations/DET/val ILSVRC/Annotations/DET/train/i 19 | 20 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2013_train ILSVRC/Data/DET/train/a 21 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0000 ILSVRC/Data/DET/train/b 22 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0001 ILSVRC/Data/DET/train/c 23 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0002 ILSVRC/Data/DET/train/d 24 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0003 ILSVRC/Data/DET/train/e 25 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0004 ILSVRC/Data/DET/train/f 26 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0005 ILSVRC/Data/DET/train/g 27 | ln -sfb $PWD/ILSVRC/Data/DET/train/ILSVRC2014_train_0006 ILSVRC/Data/DET/train/h 28 | ln -sfb $PWD/ILSVRC/Data/DET/val ILSVRC/Data/DET/train/i 29 | ```` 30 | 31 | ### Crop & Generate data info 32 | 33 | ````shell 34 | #python par_crop.py [crop_size] [num_threads] 35 | python par_crop.py 511 12 36 | python gen_json.py 37 | ```` 38 | -------------------------------------------------------------------------------- /training_dataset/got10k/gen_json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | import json 6 | from os.path import join, exists 7 | import os 8 | import pandas as pd 9 | 10 | dataset_path = 'data' 11 | train_sets = ['GOT-10k_Train_split_01','GOT-10k_Train_split_02','GOT-10k_Train_split_03','GOT-10k_Train_split_04', 12 | 'GOT-10k_Train_split_05','GOT-10k_Train_split_06','GOT-10k_Train_split_07','GOT-10k_Train_split_08', 13 | 'GOT-10k_Train_split_09','GOT-10k_Train_split_10','GOT-10k_Train_split_11','GOT-10k_Train_split_12', 14 | 'GOT-10k_Train_split_13','GOT-10k_Train_split_14','GOT-10k_Train_split_15','GOT-10k_Train_split_16', 15 | 'GOT-10k_Train_split_17','GOT-10k_Train_split_18','GOT-10k_Train_split_19'] 16 | val_set = ['val'] 17 | d_sets = 
{'videos_val':val_set,'videos_train':train_sets} 18 | # videos_val = ['MOT17-02-DPM'] 19 | # videos_train = ['MOT17-04-DPM','MOT17-05-DPM','MOT17-09-DPM','MOT17-11-DPM','MOT17-13-DPM'] 20 | # d_sets = {'videos_val':videos_val,'videos_train':videos_train} 21 | 22 | def parse_and_sched(dl_dir='.'): 23 | # For each of the two datasets 24 | js = {} 25 | for d_set in d_sets: 26 | for dataset in d_sets[d_set]: 27 | videos = os.listdir(os.path.join(dataset_path,dataset)) 28 | for video in videos: 29 | if video == 'list.txt': 30 | continue 31 | video = dataset+'/'+video 32 | gt_path = join(dataset_path, video, 'groundtruth.txt') 33 | f = open(gt_path, 'r') 34 | groundtruth = f.readlines() 35 | f.close() 36 | for idx, gt_line in enumerate(groundtruth): 37 | gt_image = gt_line.strip().split(',') 38 | frame = '%06d' % (int(idx)) 39 | obj = '%02d' % (int(0)) 40 | bbox = [int(float(gt_image[0])), int(float(gt_image[1])), 41 | int(float(gt_image[0])) + int(float(gt_image[2])), 42 | int(float(gt_image[1])) + int(float(gt_image[3]))] # xmin,ymin,xmax,ymax 43 | 44 | if video not in js: 45 | js[video] = {} 46 | if obj not in js[video]: 47 | js[video][obj] = {} 48 | js[video][obj][frame] = bbox 49 | if 'videos_val' == d_set: 50 | json.dump(js, open('val.json', 'w'), indent=4, sort_keys=True) 51 | else: 52 | json.dump(js, open('train.json', 'w'), indent=4, sort_keys=True) 53 | js = {} 54 | 55 | print(d_set+': All videos downloaded' ) 56 | 57 | 58 | if __name__ == '__main__': 59 | parse_and_sched() 60 | -------------------------------------------------------------------------------- /training_dataset/got10k/par_crop.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import listdir, mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import sys 9 | import time 10 | 11 | dataset_path = './data' 12 | sub_sets = ['GOT-10k_Train_split_01','GOT-10k_Train_split_02','GOT-10k_Train_split_03','GOT-10k_Train_split_04', 13 | 'GOT-10k_Train_split_05','GOT-10k_Train_split_06','GOT-10k_Train_split_07','GOT-10k_Train_split_08', 14 | 'GOT-10k_Train_split_09','GOT-10k_Train_split_10','GOT-10k_Train_split_11','GOT-10k_Train_split_12', 15 | 'GOT-10k_Train_split_13','GOT-10k_Train_split_14','GOT-10k_Train_split_15','GOT-10k_Train_split_16', 16 | 'GOT-10k_Train_split_17','GOT-10k_Train_split_18','GOT-10k_Train_split_19','val'] 17 | 18 | # Print iterations progress (thanks StackOverflow) 19 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 20 | """ 21 | Call in a loop to create terminal progress bar 22 | @params: 23 | iteration - Required : current iteration (Int) 24 | total - Required : total iterations (Int) 25 | prefix - Optional : prefix string (Str) 26 | suffix - Optional : suffix string (Str) 27 | decimals - Optional : positive number of decimals in percent complete (Int) 28 | barLength - Optional : character length of bar (Int) 29 | """ 30 | formatStr = "{0:." 
+ str(decimals) + "f}" 31 | percents = formatStr.format(100 * (iteration / float(total))) 32 | filledLength = int(round(barLength * iteration / float(total))) 33 | bar = '' * filledLength + '-' * (barLength - filledLength) 34 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 35 | if iteration == total: 36 | sys.stdout.write('\x1b[2K\r') 37 | sys.stdout.flush() 38 | 39 | 40 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 41 | a = (out_sz-1) / (bbox[2]-bbox[0]) 42 | b = (out_sz-1) / (bbox[3]-bbox[1]) 43 | c = -a * bbox[0] 44 | d = -b * bbox[1] 45 | mapping = np.array([[a, 0, c], 46 | [0, b, d]]).astype(np.float) 47 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 48 | return crop 49 | 50 | 51 | def pos_s_2_bbox(pos, s): 52 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 53 | 54 | 55 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 56 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 57 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 58 | wc_z = target_size[1] + context_amount * sum(target_size) 59 | hc_z = target_size[0] + context_amount * sum(target_size) 60 | s_z = np.sqrt(wc_z * hc_z) 61 | scale_z = exemplar_size / s_z 62 | d_search = (instanc_size - exemplar_size) / 2 63 | pad = d_search / scale_z 64 | s_x = s_z + 2 * pad 65 | 66 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 67 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 68 | return z, x 69 | 70 | 71 | def crop_video(video, d_set, crop_path, instanc_size): 72 | if video != 'list.txt': 73 | video_crop_base_path = join(crop_path, video) 74 | if not isdir(video_crop_base_path): makedirs(video_crop_base_path) 75 | gt_path = join(dataset_path, d_set, video, 'groundtruth.txt') 76 | images_path = join(dataset_path, d_set, video) 77 | f = open(gt_path, 'r') 78 | groundtruth = f.readlines() 79 | f.close() 80 | for idx, gt_line in enumerate(groundtruth): 81 | gt_image = gt_line.strip().split(',') 82 | bbox = [int(float(gt_image[0])),int(float(gt_image[1])),int(float(gt_image[0]))+int(float(gt_image[2])),int(float(gt_image[1]))+int(float(gt_image[3]))]#xmin,ymin,xmax,ymax 83 | 84 | im = cv2.imread(join(images_path,str(idx+1).zfill(8)+'.jpg')) 85 | avg_chans = np.mean(im, axis=(0, 1)) 86 | 87 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 88 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(int(idx), int(0))), z) 89 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(int(idx), int(0))), x) 90 | 91 | 92 | def main(instanc_size=511, num_threads=24): 93 | crop_path = './crop{:d}'.format(instanc_size) 94 | 95 | if not isdir(crop_path): mkdir(crop_path) 96 | for d_set in sub_sets: 97 | save_path = join(crop_path, d_set) 98 | videos = listdir(join(dataset_path,d_set)) 99 | if not isdir(save_path): mkdir(save_path) 100 | 101 | 102 | n_videos = len(videos) 103 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 104 | fs = [executor.submit(crop_video, video, d_set, save_path, instanc_size) for video in videos] 105 | for i, f in enumerate(futures.as_completed(fs)): 106 | # Write progress to error so that it can be seen 107 | printProgress(i, n_videos, prefix='train', suffix='Done ', barLength=40) 108 | 109 | 110 | if __name__ == '__main__': 111 | since = time.time() 112 | main() 113 | time_elapsed = time.time() - since 
114 | print('Total complete in {:.0f}m {:.0f}s'.format( 115 | time_elapsed // 60, time_elapsed % 60)) 116 | -------------------------------------------------------------------------------- /training_dataset/got10k/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing GOT-10K 2 | A Large High-Diversity Benchmark for Generic Object Tracking in the Wild 3 | 4 | ### Prepare dataset 5 | 6 | After downloading the dataset, please unzip it in the *train_dataset/got10k* directory 7 | ````shell 8 | mkdir data 9 | unzip full_data/train_data/*.zip -d ./data 10 | ```` 11 | 12 | ### Crop & Generate data info 13 | 14 | ````shell 15 | #python par_crop.py [crop_size] [num_threads] 16 | python par_crop.py 511 12 17 | python gen_json.py 18 | ```` 19 | -------------------------------------------------------------------------------- /training_dataset/lasot/gen_json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | import json 6 | from os.path import join, exists 7 | import os 8 | import pandas as pd 9 | dataset_path = './data' 10 | 11 | def parse_and_sched(dl_dir='.'): 12 | # For each of the two datasets 13 | f = open('./train_id.txt', 'r') 14 | videos = f.readlines() 15 | f.close() 16 | n_videos = len(videos) 17 | js = {} 18 | for idx,video in enumerate(videos): 19 | print('{}/{}'.format(idx,n_videos)) 20 | video = video.strip() 21 | class_name = video.split('-')[0] 22 | class_path = join(dataset_path, class_name) 23 | gt_path = join(class_path, video, 'groundtruth.txt') 24 | f = open(gt_path, 'r') 25 | groundtruth = f.readlines() 26 | f.close() 27 | for idx, gt_line in enumerate(groundtruth): 28 | gt_image = gt_line.strip().split(',') 29 | frame = '%06d' % (int(idx)) 30 | obj = '%02d' % (int(0)) 31 | bbox = [int(float(gt_image[0])), int(float(gt_image[1])), 32 | int(float(gt_image[0])) + int(float(gt_image[2])), 33 | int(float(gt_image[1])) + int(float(gt_image[3]))] # xmin,ymin,xmax,ymax 34 | x1 = bbox[0] 35 | y1 = bbox[1] 36 | w = bbox[2] 37 | h = bbox[3] 38 | if x1 < 0 or y1 < 0 or w <= 0 or h <= 0: 39 | continue 40 | 41 | if video not in js: 42 | js[video] = {} 43 | if obj not in js[video]: 44 | js[video][obj] = {} 45 | js[video][obj][frame] = bbox 46 | json.dump(js, open('train.json', 'w'), indent=4, sort_keys=True) 47 | js = {} 48 | json.dump(js, open('val.json', 'w'), indent=4, sort_keys=True) 49 | print('done') 50 | 51 | 52 | if __name__ == '__main__': 53 | parse_and_sched() 54 | -------------------------------------------------------------------------------- /training_dataset/lasot/par_crop.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import listdir, mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import sys 9 | import time 10 | 11 | dataset_path = './data' 12 | 13 | # Print iterations progress (thanks StackOverflow) 14 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 15 | """ 16 | Call in a loop to create terminal progress bar 17 | @params: 18 | iteration - Required : current iteration (Int) 19 | total - Required : total iterations (Int) 20 | prefix - Optional : prefix string (Str) 21 | suffix - Optional : suffix string (Str) 22 | decimals - Optional : positive number of decimals in percent
complete (Int) 23 | barLength - Optional : character length of bar (Int) 24 | """ 25 | formatStr = "{0:." + str(decimals) + "f}" 26 | percents = formatStr.format(100 * (iteration / float(total))) 27 | filledLength = int(round(barLength * iteration / float(total))) 28 | bar = '' * filledLength + '-' * (barLength - filledLength) 29 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 30 | if iteration == total: 31 | sys.stdout.write('\x1b[2K\r') 32 | sys.stdout.flush() 33 | 34 | 35 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 36 | a = (out_sz-1) / (bbox[2]-bbox[0]) 37 | b = (out_sz-1) / (bbox[3]-bbox[1]) 38 | c = -a * bbox[0] 39 | d = -b * bbox[1] 40 | mapping = np.array([[a, 0, c], 41 | [0, b, d]]).astype(np.float) 42 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 43 | return crop 44 | 45 | 46 | def pos_s_2_bbox(pos, s): 47 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 48 | 49 | 50 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 51 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 52 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 53 | wc_z = target_size[1] + context_amount * sum(target_size) 54 | hc_z = target_size[0] + context_amount * sum(target_size) 55 | s_z = np.sqrt(wc_z * hc_z) 56 | scale_z = exemplar_size / s_z 57 | d_search = (instanc_size - exemplar_size) / 2 58 | pad = d_search / scale_z 59 | s_x = s_z + 2 * pad 60 | 61 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 62 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 63 | return z, x 64 | 65 | 66 | def crop_video(video, crop_path, instanc_size,num): 67 | video = video.strip() 68 | class_name = video.split('-')[0] 69 | class_path = join(dataset_path,class_name) 70 | video_crop_base_path = join(crop_path, video) 71 | if not isdir(video_crop_base_path): makedirs(video_crop_base_path) 72 | gt_path = join(class_path, video, 'groundtruth.txt') 73 | images_path = join(class_path, video,'img') 74 | f = open(gt_path, 'r') 75 | groundtruth = f.readlines() 76 | f.close() 77 | for idx, gt_line in enumerate(groundtruth): 78 | gt_image = gt_line.strip().split(',') 79 | bbox = [int(float(gt_image[0])),int(float(gt_image[1])),int(float(gt_image[0]))+int(float(gt_image[2])),int(float(gt_image[1]))+int(float(gt_image[3]))]#xmin,ymin,xmax,ymax 80 | 81 | im = cv2.imread(join(images_path,str(idx+1).zfill(8)+'.jpg')) 82 | avg_chans = np.mean(im, axis=(0, 1)) 83 | 84 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 85 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(int(idx), int(0))), z) 86 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(int(idx), int(0))), x) 87 | 88 | 89 | def main(instanc_size=511, num_threads=24): 90 | crop_path = './crop{:d}'.format(instanc_size) 91 | f = open('./train_id.txt', 'r') 92 | videos = f.readlines() 93 | f.close() 94 | if not isdir(crop_path): mkdir(crop_path) 95 | 96 | n_videos = len(videos) 97 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 98 | fs = [executor.submit(crop_video, video, crop_path, instanc_size,idx) for idx,video in enumerate(videos)] 99 | for i, f in enumerate(futures.as_completed(fs)): 100 | # Write progress to error so that it can be seen 101 | printProgress(i, n_videos, prefix='train', suffix='Done ', barLength=40) 102 | 103 | 104 | if __name__ == 
'__main__': 105 | since = time.time() 106 | main() 107 | time_elapsed = time.time() - since 108 | print('Total complete in {:.0f}m {:.0f}s'.format( 109 | time_elapsed // 60, time_elapsed % 60)) 110 | -------------------------------------------------------------------------------- /training_dataset/lasot/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing LaSOT 2 | Large-scale Single Object Tracking 3 | 4 | ### Prepare dataset 5 | 6 | After download the dataset, please unzip the dataset at *train_dataset/lasot* directory 7 | ````shell 8 | mkdir data 9 | unzip LaSOT/zip/*.zip -d ./data 10 | ```` 11 | 12 | ### Crop & Generate data info 13 | 14 | ````shell 15 | #python par_crop.py [crop_size] [num_threads] 16 | python par_crop.py 511 12 17 | python gen_json.py 18 | ```` 19 | -------------------------------------------------------------------------------- /training_dataset/trackingnet/gen_json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | import json 6 | from os.path import join, exists 7 | import os 8 | import pandas as pd 9 | 10 | dataset_path = './data' 11 | 12 | train_sets = ['TRAIN_0', 'TRAIN_1', 'TRAIN_2', 'TRAIN_3', 'TRAIN_4', 'TRAIN_5', 13 | 'TRAIN_6', 'TRAIN_7', 'TRAIN_8', 'TRAIN_9', 'TRAIN_10', 'TRAIN_11'] 14 | test_sets = ['TEST'] 15 | d_sets = {'videos_train': train_sets} 16 | 17 | 18 | def parse_and_sched(dl_dir='.'): 19 | js = {} 20 | for d_set in d_sets: 21 | for dataset in d_sets[d_set]: 22 | anno_path = os.path.join(dataset_path, dataset, 'anno') 23 | anno_files = os.listdir(anno_path) 24 | for video in anno_files: 25 | gt_path = os.path.join(anno_path, video) 26 | f = open(gt_path, 'r') 27 | groundtruth = f.readlines() 28 | f.close() 29 | for idx, gt_line in enumerate(groundtruth): 30 | gt_image = gt_line.strip().split(',') 31 | frame = '%06d' % (int(idx)) 32 | obj = '%02d' % (int(0)) 33 | bbox = [int(float(gt_image[0])), int(float(gt_image[1])), 34 | int(float(gt_image[0])) + int(float(gt_image[2])), 35 | int(float(gt_image[1])) + int(float(gt_image[3]))] # xmin,ymin,xmax,ymax 36 | video_name = dataset + '/' + video.split('.')[0] 37 | if video_name not in js: 38 | js[video_name] = {} 39 | if obj not in js[video_name]: 40 | js[video_name][obj] = {} 41 | js[video_name][obj][frame] = bbox 42 | if 'videos_test' == d_set: 43 | json.dump(js, open('test.json', 'w'), indent=4, sort_keys=True) 44 | else: 45 | json.dump(js, open('train.json', 'w'), indent=4, sort_keys=True) 46 | 47 | print(d_set+': All videos downloaded') 48 | 49 | 50 | if __name__ == '__main__': 51 | parse_and_sched() 52 | -------------------------------------------------------------------------------- /training_dataset/trackingnet/par_crop.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import listdir, mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import sys 9 | import time 10 | 11 | dataset_path = '/data1/trackingNet/crop511' 12 | # dataset_path = './data' 13 | # sub_sets = ['TRAIN_0', 'TRAIN_1', 'TRAIN_2', 'TRAIN_3', 'TRAIN_4', 'TRAIN_5', 14 | # 'TRAIN_6', 'TRAIN_7', 'TRAIN_8', 'TRAIN_9', 'TRAIN_10', 'TRAIN_11'] 15 | sub_sets = ['TRAIN_4'] 16 | 17 | # Print iterations progress (thanks StackOverflow) 18 | def 
printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 19 | """ 20 | Call in a loop to create terminal progress bar 21 | @params: 22 | iteration - Required : current iteration (Int) 23 | total - Required : total iterations (Int) 24 | prefix - Optional : prefix string (Str) 25 | suffix - Optional : suffix string (Str) 26 | decimals - Optional : positive number of decimals in percent complete (Int) 27 | barLength - Optional : character length of bar (Int) 28 | """ 29 | formatStr = "{0:." + str(decimals) + "f}" 30 | percents = formatStr.format(100 * (iteration / float(total))) 31 | filledLength = int(round(barLength * iteration / float(total))) 32 | bar = '' * filledLength + '-' * (barLength - filledLength) 33 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 34 | if iteration == total: 35 | sys.stdout.write('\x1b[2K\r') 36 | sys.stdout.flush() 37 | 38 | 39 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 40 | a = (out_sz-1) / (bbox[2]-bbox[0]) 41 | b = (out_sz-1) / (bbox[3]-bbox[1]) 42 | c = -a * bbox[0] 43 | d = -b * bbox[1] 44 | mapping = np.array([[a, 0, c], 45 | [0, b, d]]).astype(np.float) 46 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 47 | return crop 48 | 49 | 50 | def pos_s_2_bbox(pos, s): 51 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 52 | 53 | 54 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 55 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 56 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 57 | wc_z = target_size[1] + context_amount * sum(target_size) 58 | hc_z = target_size[0] + context_amount * sum(target_size) 59 | s_z = np.sqrt(wc_z * hc_z) 60 | scale_z = exemplar_size / s_z 61 | d_search = (instanc_size - exemplar_size) / 2 62 | pad = d_search / scale_z 63 | s_x = s_z + 2 * pad 64 | 65 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 66 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 67 | return z, x 68 | 69 | 70 | def crop_video(video, d_set, crop_path, instanc_size): 71 | if video != 'list.txt': 72 | video_crop_base_path = join(crop_path, video) 73 | if not isdir(video_crop_base_path): makedirs(video_crop_base_path) 74 | gt_path = join(dataset_path, d_set, 'anno', video + '.txt') 75 | images_path = join(dataset_path, d_set, 'zip', video) 76 | f = open(gt_path, 'r') 77 | groundtruth = f.readlines() 78 | f.close() 79 | for idx, gt_line in enumerate(groundtruth): 80 | gt_image = gt_line.strip().split(',') 81 | bbox = [int(float(gt_image[0])),int(float(gt_image[1])),int(float(gt_image[0]))+int(float(gt_image[2])),int(float(gt_image[1]))+int(float(gt_image[3]))]#xmin,ymin,xmax,ymax 82 | 83 | im = cv2.imread(join(images_path, str(idx)+'.jpg')) 84 | # cv2.rectangle(im, (bbox[0], bbox[1]), 85 | # (bbox[2], bbox[3]), (0, 255, 0), 3) 86 | # cv2.imshow('test', im) 87 | 88 | avg_chans = np.mean(im, axis=(0, 1)) 89 | 90 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 91 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(int(idx), int(0))), z) 92 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(int(idx), int(0))), x) 93 | 94 | 95 | def main(instanc_size=511, num_threads=24): 96 | crop_path = './crop{:d}'.format(instanc_size) 97 | 98 | if not isdir(crop_path): mkdir(crop_path) 99 | for d_set in sub_sets: 100 | save_path = join(crop_path, 
d_set) 101 | videos = listdir(join(dataset_path, d_set, 'zip')) 102 | if not isdir(save_path): mkdir(save_path) 103 | 104 | n_videos = len(videos) 105 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 106 | fs = [executor.submit(crop_video, video, d_set, save_path, instanc_size) for video in videos] 107 | for i, f in enumerate(futures.as_completed(fs)): 108 | # Write progress to error so that it can be seen 109 | printProgress(i, n_videos, prefix='train', suffix='Done ', barLength=40) 110 | 111 | 112 | if __name__ == '__main__': 113 | since = time.time() 114 | main() 115 | time_elapsed = time.time() - since 116 | print('Total complete in {:.0f}m {:.0f}s'.format( 117 | time_elapsed // 60, time_elapsed % 60)) 118 | -------------------------------------------------------------------------------- /training_dataset/trackingnet/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing TrackingNet 2 | TrackingNet: A Large-Scale Dataset and Benchmark for Object Tracking in the Wild 3 | 4 | ### Prepare dataset 5 | 6 | After download the dataset, please unzip the dataset at *train_dataset/TrackingNet* directory 7 | ````shell 8 | mkdir data 9 | unzip TrackingNet/zip/*.zip -d ./data 10 | ```` 11 | 12 | ### Crop & Generate data info 13 | 14 | ````shell 15 | #python par_crop.py [crop_size] [num_threads] 16 | python par_crop.py 511 12 17 | python gen_json.py 18 | ```` 19 | -------------------------------------------------------------------------------- /training_dataset/vid/gen_json.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from os import listdir 3 | import json 4 | import numpy as np 5 | 6 | print('load json (raw vid info), please wait 20 seconds~') 7 | vid = json.load(open('vid.json', 'r')) 8 | 9 | 10 | def check_size(frame_sz, bbox): 11 | min_ratio = 0.1 12 | max_ratio = 0.75 13 | # only accept objects >10% and <75% of the total frame 14 | area_ratio = np.sqrt((bbox[2]-bbox[0])*(bbox[3]-bbox[1])/float(np.prod(frame_sz))) 15 | ok = (area_ratio > min_ratio) and (area_ratio < max_ratio) 16 | return ok 17 | 18 | 19 | def check_borders(frame_sz, bbox): 20 | dist_from_border = 0.05 * (bbox[2] - bbox[0] + bbox[3] - bbox[1])/2 21 | ok = (bbox[0] > dist_from_border) and (bbox[1] > dist_from_border) and \ 22 | ((frame_sz[0] - bbox[2]) > dist_from_border) and \ 23 | ((frame_sz[1] - bbox[3]) > dist_from_border) 24 | return ok 25 | 26 | 27 | snippets = dict() 28 | n_snippets = 0 29 | n_videos = 0 30 | for subset in vid: 31 | for video in subset: 32 | n_videos += 1 33 | frames = video['frame'] 34 | id_set = [] 35 | id_frames = [[]] * 60 # at most 60 objects 36 | for f, frame in enumerate(frames): 37 | objs = frame['objs'] 38 | frame_sz = frame['frame_sz'] 39 | for obj in objs: 40 | trackid = obj['trackid'] 41 | occluded = obj['occ'] 42 | bbox = obj['bbox'] 43 | # if occluded: 44 | # continue 45 | # 46 | # if not(check_size(frame_sz, bbox) and check_borders(frame_sz, bbox)): 47 | # continue 48 | # 49 | # if obj['c'] in ['n01674464', 'n01726692', 'n04468005', 'n02062744']: 50 | # continue 51 | 52 | if trackid not in id_set: 53 | id_set.append(trackid) 54 | id_frames[trackid] = [] 55 | id_frames[trackid].append(f) 56 | if len(id_set) > 0: 57 | snippets[video['base_path']] = dict() 58 | for selected in id_set: 59 | frame_ids = sorted(id_frames[selected]) 60 | sequences = np.split(frame_ids, np.array(np.where(np.diff(frame_ids) > 1)[0]) + 1) 61 | sequences = [s for s in 
sequences if len(s) > 1] # remove isolated frame. 62 | for seq in sequences: 63 | snippet = dict() 64 | for frame_id in seq: 65 | frame = frames[frame_id] 66 | for obj in frame['objs']: 67 | if obj['trackid'] == selected: 68 | o = obj 69 | continue 70 | snippet[frame['img_path'].split('.')[0]] = o['bbox'] 71 | snippets[video['base_path']]['{:02d}'.format(selected)] = snippet 72 | n_snippets += 1 73 | print('video: {:d} snippets_num: {:d}'.format(n_videos, n_snippets)) 74 | 75 | train = {k:v for (k,v) in snippets.items() if 'train' in k} 76 | val = {k:v for (k,v) in snippets.items() if 'val' in k} 77 | 78 | json.dump(train, open('train.json', 'w'), indent=4, sort_keys=True) 79 | json.dump(val, open('val.json', 'w'), indent=4, sort_keys=True) 80 | print('done!') 81 | -------------------------------------------------------------------------------- /training_dataset/vid/par_crop.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import listdir, mkdir, makedirs 3 | import cv2 4 | import numpy as np 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | from concurrent import futures 8 | import sys 9 | import time 10 | 11 | VID_base_path = './ILSVRC2015' 12 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/') 13 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'}) 14 | 15 | 16 | # Print iterations progress (thanks StackOverflow) 17 | def printProgress(iteration, total, prefix='', suffix='', decimals=1, barLength=100): 18 | """ 19 | Call in a loop to create terminal progress bar 20 | @params: 21 | iteration - Required : current iteration (Int) 22 | total - Required : total iterations (Int) 23 | prefix - Optional : prefix string (Str) 24 | suffix - Optional : suffix string (Str) 25 | decimals - Optional : positive number of decimals in percent complete (Int) 26 | barLength - Optional : character length of bar (Int) 27 | """ 28 | formatStr = "{0:." + str(decimals) + "f}" 29 | percents = formatStr.format(100 * (iteration / float(total))) 30 | filledLength = int(round(barLength * iteration / float(total))) 31 | bar = '' * filledLength + '-' * (barLength - filledLength) 32 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 33 | if iteration == total: 34 | sys.stdout.write('\x1b[2K\r') 35 | sys.stdout.flush() 36 | 37 | 38 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 39 | a = (out_sz-1) / (bbox[2]-bbox[0]) 40 | b = (out_sz-1) / (bbox[3]-bbox[1]) 41 | c = -a * bbox[0] 42 | d = -b * bbox[1] 43 | mapping = np.array([[a, 0, c], 44 | [0, b, d]]).astype(np.float) 45 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 46 | return crop 47 | 48 | 49 | def pos_s_2_bbox(pos, s): 50 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 51 | 52 | 53 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 54 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 
55 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 56 | wc_z = target_size[1] + context_amount * sum(target_size) 57 | hc_z = target_size[0] + context_amount * sum(target_size) 58 | s_z = np.sqrt(wc_z * hc_z) 59 | scale_z = exemplar_size / s_z 60 | d_search = (instanc_size - exemplar_size) / 2 61 | pad = d_search / scale_z 62 | s_x = s_z + 2 * pad 63 | 64 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 65 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 66 | return z, x 67 | 68 | 69 | def crop_video(sub_set, video, crop_path, instanc_size): 70 | video_crop_base_path = join(crop_path, sub_set, video) 71 | if not isdir(video_crop_base_path): makedirs(video_crop_base_path) 72 | 73 | sub_set_base_path = join(ann_base_path, sub_set) 74 | xmls = sorted(glob.glob(join(sub_set_base_path, video, '*.xml'))) 75 | for xml in xmls: 76 | xmltree = ET.parse(xml) 77 | # size = xmltree.findall('size')[0] 78 | # frame_sz = [int(it.text) for it in size] 79 | objects = xmltree.findall('object') 80 | objs = [] 81 | filename = xmltree.findall('filename')[0].text 82 | 83 | im = cv2.imread(xml.replace('xml', 'JPEG').replace('Annotations', 'Data')) 84 | avg_chans = np.mean(im, axis=(0, 1)) 85 | for object_iter in objects: 86 | trackid = int(object_iter.find('trackid').text) 87 | # name = (object_iter.find('name')).text 88 | bndbox = object_iter.find('bndbox') 89 | # occluded = int(object_iter.find('occluded').text) 90 | 91 | bbox = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 92 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 93 | z, x = crop_like_SiamFC(im, bbox, instanc_size=instanc_size, padding=avg_chans) 94 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.z.jpg'.format(int(filename), trackid)), z) 95 | cv2.imwrite(join(video_crop_base_path, '{:06d}.{:02d}.x.jpg'.format(int(filename), trackid)), x) 96 | 97 | 98 | def main(instanc_size=511, num_threads=24): 99 | crop_path = './crop{:d}'.format(instanc_size) 100 | if not isdir(crop_path): mkdir(crop_path) 101 | 102 | for sub_set in sub_sets: 103 | sub_set_base_path = join(ann_base_path, sub_set) 104 | videos = sorted(listdir(sub_set_base_path)) 105 | n_videos = len(videos) 106 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 107 | fs = [executor.submit(crop_video, sub_set, video, crop_path, instanc_size) for video in videos] 108 | for i, f in enumerate(futures.as_completed(fs)): 109 | # Write progress to error so that it can be seen 110 | printProgress(i, n_videos, prefix=sub_set, suffix='Done ', barLength=40) 111 | 112 | 113 | if __name__ == '__main__': 114 | since = time.time() 115 | main(int(sys.argv[1]), int(sys.argv[2])) 116 | time_elapsed = time.time() - since 117 | print('Total complete in {:.0f}m {:.0f}s'.format( 118 | time_elapsed // 60, time_elapsed % 60)) 119 | -------------------------------------------------------------------------------- /training_dataset/vid/parse_vid.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from os import listdir 3 | import json 4 | import glob 5 | import xml.etree.ElementTree as ET 6 | 7 | VID_base_path = './ILSVRC2015' 8 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/') 9 | img_base_path = join(VID_base_path, 'Data/VID/train/') 10 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'}) 11 | 12 | vid = [] 13 | for sub_set in sub_sets: 14 | sub_set_base_path = join(ann_base_path, sub_set) 15 | videos = 
sorted(listdir(sub_set_base_path)) 16 | s = [] 17 | for vi, video in enumerate(videos): 18 | print('subset: {} video id: {:04d} / {:04d}'.format(sub_set, vi, len(videos))) 19 | v = dict() 20 | v['base_path'] = join(sub_set, video) 21 | v['frame'] = [] 22 | video_base_path = join(sub_set_base_path, video) 23 | xmls = sorted(glob.glob(join(video_base_path, '*.xml'))) 24 | for xml in xmls: 25 | f = dict() 26 | xmltree = ET.parse(xml) 27 | size = xmltree.findall('size')[0] 28 | frame_sz = [int(it.text) for it in size] 29 | objects = xmltree.findall('object') 30 | objs = [] 31 | for object_iter in objects: 32 | trackid = int(object_iter.find('trackid').text) 33 | name = (object_iter.find('name')).text 34 | bndbox = object_iter.find('bndbox') 35 | occluded = int(object_iter.find('occluded').text) 36 | o = dict() 37 | o['c'] = name 38 | o['bbox'] = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 39 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 40 | o['trackid'] = trackid 41 | o['occ'] = occluded 42 | objs.append(o) 43 | f['frame_sz'] = frame_sz 44 | f['img_path'] = xml.split('/')[-1].replace('xml', 'JPEG') 45 | f['objs'] = objs 46 | v['frame'].append(f) 47 | s.append(v) 48 | vid.append(s) 49 | print('save json (raw vid info), please wait 1 min~') 50 | json.dump(vid, open('vid.json', 'w'), indent=4, sort_keys=True) 51 | print('done!') 52 | -------------------------------------------------------------------------------- /training_dataset/vid/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing VID(Object detection from video) 2 | Large Scale Visual Recognition Challenge 2015 (ILSVRC2015) 3 | 4 | ### Download dataset 5 | 6 | ````shell 7 | wget http://bvisionweb1.cs.unc.edu/ilsvrc2015/ILSVRC2015_VID.tar.gz 8 | tar -xzvf ./ILSVRC2015_VID.tar.gz 9 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0000 ILSVRC2015/Annotations/VID/train/a 10 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0001 ILSVRC2015/Annotations/VID/train/b 11 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0002 ILSVRC2015/Annotations/VID/train/c 12 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/train/ILSVRC2015_VID_train_0003 ILSVRC2015/Annotations/VID/train/d 13 | ln -sfb $PWD/ILSVRC2015/Annotations/VID/val ILSVRC2015/Annotations/VID/train/e 14 | 15 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0000 ILSVRC2015/Data/VID/train/a 16 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0001 ILSVRC2015/Data/VID/train/b 17 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0002 ILSVRC2015/Data/VID/train/c 18 | ln -sfb $PWD/ILSVRC2015/Data/VID/train/ILSVRC2015_VID_train_0003 ILSVRC2015/Data/VID/train/d 19 | ln -sfb $PWD/ILSVRC2015/Data/VID/val ILSVRC2015/Data/VID/train/e 20 | ```` 21 | 22 | ### Crop & Generate data info 23 | 24 | ````shell 25 | python parse_vid.py 26 | 27 | #python par_crop.py [crop_size] [num_threads] 28 | python par_crop.py 511 12 29 | python gen_json.py 30 | ```` 31 | -------------------------------------------------------------------------------- /training_dataset/yt_bb/check.py: -------------------------------------------------------------------------------- 1 | import os 2 | path = '/media/amax/guo/Guo_dataset/Guo/yt_bb_detection_train/1' 3 | videos = os.listdir(path) 4 | path_crop = '/data0/youtubebb/crop511/yt_bb_detection_train/1' 5 | video_crop = os.listdir(path_crop) 6 | num_have = len(video_crop) 7 | num_amount = len(videos) 8 | 
num_miss = 0 9 | num_corr = 0 10 | for video in videos: 11 | video_name = video.split('+') 12 | if video_name[0] in video_crop: 13 | if int(video_name[2][0]) > 0: 14 | print(video_name[0]) 15 | num_have += 1 16 | continue 17 | else: 18 | num_miss += 1 19 | print('num_amount:',num_amount,'num_have:',num_have,'num_miss:',num_miss) -------------------------------------------------------------------------------- /training_dataset/yt_bb/checknum.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import glob 3 | 4 | col_names = ['youtube_id', 'timestamp_ms', 'class_id', 'class_name', 5 | 'object_id', 'object_presence', 'xmin', 'xmax', 'ymin', 'ymax'] 6 | 7 | sets = ['yt_bb_detection_validation', 'yt_bb_detection_train'] 8 | 9 | for subset in sets: 10 | df = pd.DataFrame.from_csv('./'+ subset +'.csv', header=None, index_col=False) 11 | df.columns = col_names 12 | vids = sorted(df['youtube_id'].unique()) 13 | n_vids = len(vids) 14 | print('Total video in {}.csv is {:d}'.format(subset, n_vids)) 15 | 16 | frame_download = glob.glob('./{}/*/*.jpg'.format(subset)) 17 | frame_download = [frame.split('/')[-1] for frame in frame_download] 18 | frame_download = [frame[:frame.find('_')] for frame in frame_download] 19 | frame_download = [frame[:frame.find('_')] for frame in frame_download] 20 | frame_download = [frame[:frame.find('_')] for frame in frame_download] 21 | frame_download = sorted(set(frame_download)) 22 | # print(frame_download) 23 | print('Total downloaded in {} is {:d}'.format(subset, len(frame_download))) 24 | 25 | 26 | print('done') 27 | -------------------------------------------------------------------------------- /training_dataset/yt_bb/gen_json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | import json 6 | from os.path import join, exists 7 | import pandas as pd 8 | 9 | # The data sets to be downloaded 10 | d_sets = ['yt_bb_detection_validation', 'yt_bb_detection_train'] 11 | 12 | # Column names for detection CSV files 13 | col_names = ['youtube_id', 'timestamp_ms','class_id','class_name', 14 | 'object_id','object_presence','xmin','xmax','ymin','ymax'] 15 | 16 | instanc_size = 511 17 | crop_path = './crop{:d}'.format(instanc_size) 18 | 19 | 20 | def parse_and_sched(dl_dir='.'): 21 | # For each of the two datasets 22 | js = {} 23 | for d_set in d_sets: 24 | 25 | # Make the directory for this dataset 26 | d_set_dir = dl_dir+'/'+d_set+'/' 27 | 28 | # Parse csv data using pandas 29 | print (d_set+': Parsing annotations into clip data...') 30 | df = pd.DataFrame.from_csv(d_set+'.csv', header=None, index_col=False) 31 | df.columns = col_names 32 | 33 | # Get list of unique video files 34 | vids = df['youtube_id'].unique() 35 | 36 | for vid in vids: 37 | data = df[df['youtube_id']==vid] 38 | for index, row in data.iterrows(): 39 | youtube_id, timestamp_ms, class_id, class_name, \ 40 | object_id, object_presence, x1, x2, y1, y2 = row 41 | 42 | if object_presence == 'absent': 43 | continue 44 | 45 | if x1 < 0 or x2 < 0 or y1 < 0 or y2 < 0 or y2 < y1 or x2 < x1: 46 | continue 47 | 48 | bbox = [x1, y1, x2, y2] 49 | frame = '%06d' % (int(timestamp_ms) / 1000) 50 | obj = '%02d' % (int(object_id)) 51 | video = join(d_set_dir + str(class_id), youtube_id) 52 | 53 | if not exists(join(crop_path, video, '{}.{}.x.jpg'.format(frame, obj))): 54 | continue 55 | 56 | if video not in js: 57 | js[video] = 
{} 58 | if obj not in js[video]: 59 | js[video][obj] = {} 60 | js[video][obj][frame] = bbox 61 | 62 | if 'yt_bb_detection_train' == d_set: 63 | json.dump(js, open('train.json', 'w'), indent=4, sort_keys=True) 64 | else: 65 | json.dump(js, open('val.json', 'w'), indent=4, sort_keys=True) 66 | js = {} 67 | print(d_set+': All videos downloaded' ) 68 | 69 | 70 | if __name__ == '__main__': 71 | parse_and_sched() 72 | -------------------------------------------------------------------------------- /training_dataset/yt_bb/par_crop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | from subprocess import check_call 6 | from concurrent import futures 7 | import os 8 | from os.path import join 9 | import sys 10 | import cv2 11 | import pandas as pd 12 | import numpy as np 13 | 14 | # The data sets to be downloaded 15 | d_sets = ['yt_bb_detection_validation', 'yt_bb_detection_train'] 16 | 17 | # Column names for detection CSV files 18 | col_names = ['youtube_id', 'timestamp_ms','class_id','class_name', 19 | 'object_id','object_presence','xmin','xmax','ymin','ymax'] 20 | 21 | # Host location of segment lists 22 | web_host = 'https://research.google.com/youtube-bb/' 23 | 24 | 25 | # Print iterations progress (thanks StackOverflow) 26 | def printProgress (iteration, total, prefix = '', suffix = '', decimals = 1, barLength = 100): 27 | """ 28 | Call in a loop to create terminal progress bar 29 | @params: 30 | iteration - Required : current iteration (Int) 31 | total - Required : total iterations (Int) 32 | prefix - Optional : prefix string (Str) 33 | suffix - Optional : suffix string (Str) 34 | decimals - Optional : positive number of decimals in percent complete (Int) 35 | barLength - Optional : character length of bar (Int) 36 | """ 37 | formatStr = "{0:." + str(decimals) + "f}" 38 | percents = formatStr.format(100 * (iteration / float(total))) 39 | filledLength = int(round(barLength * iteration / float(total))) 40 | bar = '█' * filledLength + '-' * (barLength - filledLength) 41 | sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percents, '%', suffix)), 42 | if iteration == total: 43 | sys.stdout.write('\x1b[2K\r') 44 | sys.stdout.flush() 45 | 46 | 47 | instanc_size = 511 48 | crop_path = './crop{:d}'.format(instanc_size) 49 | check_call(['mkdir', '-p', crop_path]) 50 | 51 | 52 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 53 | a = (out_sz-1) / (bbox[2]-bbox[0]) 54 | b = (out_sz-1) / (bbox[3]-bbox[1]) 55 | c = -a * bbox[0] 56 | d = -b * bbox[1] 57 | mapping = np.array([[a, 0, c], 58 | [0, b, d]]).astype(np.float) 59 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 60 | return crop 61 | 62 | 63 | def pos_s_2_bbox(pos, s): 64 | return [pos[0]-s/2, pos[1]-s/2, pos[0]+s/2, pos[1]+s/2] 65 | 66 | 67 | def crop_like_SiamFC(image, bbox, context_amount=0.5, exemplar_size=127, instanc_size=255, padding=(0, 0, 0)): 68 | target_pos = [(bbox[2]+bbox[0])/2., (bbox[3]+bbox[1])/2.] 
69 | target_size = [bbox[2]-bbox[0], bbox[3]-bbox[1]] 70 | wc_z = target_size[1] + context_amount * sum(target_size) 71 | hc_z = target_size[0] + context_amount * sum(target_size) 72 | s_z = np.sqrt(wc_z * hc_z) 73 | scale_z = exemplar_size / s_z 74 | d_search = (instanc_size - exemplar_size) / 2 75 | pad = d_search / scale_z 76 | s_x = s_z + 2 * pad 77 | 78 | z = crop_hwc(image, pos_s_2_bbox(target_pos, s_z), exemplar_size, padding) 79 | x = crop_hwc(image, pos_s_2_bbox(target_pos, s_x), instanc_size, padding) 80 | return z, x 81 | 82 | 83 | # Download and cut a clip to size 84 | def dl_and_cut(vid, data, d_set_dir): 85 | for index, row in data.iterrows(): 86 | youtube_id, timestamp_ms, class_id, class_name,\ 87 | object_id, object_presence, xmin, xmax, ymin, ymax = row 88 | 89 | if object_presence == 'absent': 90 | continue 91 | 92 | class_dir = d_set_dir + str(class_id) 93 | frame_path = class_dir + '/' + youtube_id + '_' + str(timestamp_ms) + \ 94 | '_' + str(class_id) + '_' + str(object_id) + '.jpg' 95 | # Verify that the video has been downloaded. Skip otherwise 96 | if not os.path.exists(frame_path): 97 | continue 98 | 99 | image = cv2.imread(frame_path) 100 | avg_chans = np.mean(image, axis=(0, 1)) 101 | # Scale the normalized box annotations to pixel coordinates 102 | h, w = image.shape[:2] 103 | x1 = xmin*w 104 | x2 = xmax*w 105 | y1 = ymin*h 106 | y2 = ymax*h 107 | if x1 < 0 or x2 < 0 or y1 < 0 or y2 < 0 or y2 < y1 or x2 < x1: 108 | continue 109 | 110 | # Make the class directory if it doesn't exist yet 111 | crop_class_dir = join(crop_path, d_set_dir+str(class_id), youtube_id) 112 | check_call(['mkdir', '-p', crop_class_dir]) 113 | 114 | # Save the extracted image (integer-divide the timestamp so {:06d} receives an int) 115 | bbox = [x1, y1, x2, y2] 116 | z, x = crop_like_SiamFC(image, bbox, instanc_size=instanc_size, padding=avg_chans) 117 | cv2.imwrite(join(crop_class_dir, '{:06d}.{:02d}.z.jpg'.format(int(timestamp_ms)//1000, int(object_id))), z) 118 | cv2.imwrite(join(crop_class_dir, '{:06d}.{:02d}.x.jpg'.format(int(timestamp_ms)//1000, int(object_id))), x) 119 | return True 120 | 121 | 122 | # Parse the annotation csv file and schedule downloads and cuts 123 | def parse_and_sched(dl_dir='.', num_threads=24): 124 | """Crop the entire youtube-bb data set into `crop_path`.
125 | """ 126 | # For each of the two datasets 127 | for d_set in d_sets: 128 | 129 | # Make the directory for this dataset 130 | d_set_dir = dl_dir+'/'+d_set+'/' 131 | 132 | # Download & extract the annotation list 133 | # print (d_set+': Downloading annotations...') 134 | # check_call(['wget', web_host+d_set+'.csv.gz']) 135 | # print (d_set+': Unzipping annotations...') 136 | # check_call(['gzip', '-d', '-f', d_set+'.csv.gz']) 137 | 138 | # Parse csv data using pandas 139 | print (d_set+': Parsing annotations into clip data...') 140 | df = pd.DataFrame.from_csv(d_set+'.csv', header=None, index_col=False) 141 | df.columns = col_names 142 | 143 | # Get list of unique video files 144 | vids = df['youtube_id'].unique() 145 | 146 | # Download and cut in parallel threads giving 147 | with futures.ProcessPoolExecutor(max_workers=num_threads) as executor: 148 | fs = [executor.submit(dl_and_cut,vid,df[df['youtube_id']==vid],d_set_dir) for vid in vids] 149 | for i, f in enumerate(futures.as_completed(fs)): 150 | # Write progress to error so that it can be seen 151 | printProgress(i, len(vids), 152 | prefix = d_set, 153 | suffix = 'Done', 154 | barLength = 40) 155 | 156 | print(d_set+': All videos Crop Done') 157 | 158 | 159 | if __name__ == '__main__': 160 | parse_and_sched() 161 | -------------------------------------------------------------------------------- /training_dataset/yt_bb/readme.md: -------------------------------------------------------------------------------- 1 | # Preprocessing Youtube-bb(YouTube-BoundingBoxes Dataset) 2 | 3 | ### Download raw label 4 | 5 | ````shell 6 | wget https://research.google.com/youtube-bb/yt_bb_detection_train.csv.gz 7 | wget https://research.google.com/youtube-bb/yt_bb_detection_validation.csv.gz 8 | 9 | gzip -d ./yt_bb_detection_train.csv.gz 10 | gzip -d ./yt_bb_detection_validation.csv.gz 11 | ```` 12 | 13 | ### Download raw image by `youtube-bb-utility` 14 | 15 | ````shell 16 | git clone https://github.com/mehdi-shiba/youtube-bb-utility.git 17 | cd youtube-bb-utility 18 | pip install -r requirements.txt 19 | # python download_detection.py [VIDEO_DIR] [NUM_THREADS] 20 | python download_detection.py ../ 12 21 | cd .. 22 | ```` 23 | 24 | ### Crop & Generate data info 25 | 26 | ````shell 27 | python par_crop.py 28 | python gen_json.py 29 | ```` 30 | --------------------------------------------------------------------------------
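All of the `par_crop.py` scripts above share the same SiamFC-style cropping rule: the ground-truth box is padded with half of its (width + height) as context, the exemplar region side `s_z` is the geometric mean of the padded width and height, and the search region side `s_x` adds the extra margin needed to fill the instance crop at the exemplar scale. A minimal standalone sketch of that arithmetic (the helper name and the 100x50 box are only for illustration; 127/511 match the sizes the scripts pass to `crop_like_SiamFC`):

````python
import numpy as np

def siamfc_crop_sides(bbox, context_amount=0.5, exemplar_size=127, instanc_size=511):
    # bbox is [xmin, ymin, xmax, ymax]
    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
    wc_z = w + context_amount * (w + h)   # padded width (the scripts swap w and h here;
    hc_z = h + context_amount * (w + h)   # padded height  the product, hence s_z, is identical)
    s_z = np.sqrt(wc_z * hc_z)            # exemplar region side in the original image
    scale_z = exemplar_size / s_z         # resize factor applied to the exemplar region
    d_search = (instanc_size - exemplar_size) / 2
    s_x = s_z + 2 * d_search / scale_z    # search region side in the original image
    return s_z, s_x

# A 100x50 target gives s_z ~= 148 and s_x ~= 595; crop_hwc then warps these
# regions to a 127x127 exemplar (z) and a 511x511 search patch (x).
print(siamfc_crop_sides([0, 0, 100, 50]))
````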
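Several calls above assume older library versions: `pd.DataFrame.from_csv` (used in `yt_bb/checknum.py`, `yt_bb/gen_json.py` and `yt_bb/par_crop.py`) was removed in pandas 1.0, and `np.float` (used in every `crop_hwc`) was removed in NumPy 1.24. If the scripts are run against current releases, the substitutions below should behave equivalently; this is a sketch of the drop-in replacements, not a change to the shipped code:

````python
import numpy as np
import pandas as pd

# pandas >= 1.0: DataFrame.from_csv is gone; read_csv with the same options is equivalent here
df = pd.read_csv('yt_bb_detection_train.csv', header=None, index_col=False)

# NumPy >= 1.24: np.float is gone; the affine matrix in crop_hwc can use the builtin float
a, b, c, d = 1.0, 1.0, 0.0, 0.0  # placeholder coefficients for illustration
mapping = np.array([[a, 0, c],
                    [0, b, d]]).astype(float)
````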
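The GOT-10K, LaSOT and TrackingNet readmes invoke `python par_crop.py 511 12`, but the `__main__` blocks of those three scripts call `main()` with no arguments, so the crop size and thread count on the command line are silently ignored and the defaults (511, 24) are used; only the VID script forwards `sys.argv`. A small sketch of the same wiring for the other scripts, keeping their defaults when no arguments are given (it assumes the enclosing script's own `main`):

````python
import sys
import time

if __name__ == '__main__':
    since = time.time()
    if len(sys.argv) >= 3:
        # python par_crop.py [crop_size] [num_threads], as documented in the readmes
        main(instanc_size=int(sys.argv[1]), num_threads=int(sys.argv[2]))
    else:
        main()  # fall back to the defaults (511, 24)
    time_elapsed = time.time() - since
    print('Total complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
````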