├── CDT.png ├── Denoising └── Options │ └── Denoising_CDT.yml ├── README.md ├── basicsr ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── ffhq_dataset.py │ ├── meta_info │ │ ├── meta_info_DIV2K800sub_GT.txt │ │ ├── meta_info_REDS4_test_GT.txt │ │ ├── meta_info_REDS_GT.txt │ │ ├── meta_info_REDSofficial4_test_GT.txt │ │ ├── meta_info_REDSval_official_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_GT.txt │ │ ├── meta_info_Vimeo90K_test_fast_GT.txt │ │ ├── meta_info_Vimeo90K_test_medium_GT.txt │ │ ├── meta_info_Vimeo90K_test_slow_GT.txt │ │ └── meta_info_Vimeo90K_train_GT.txt │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── metrics │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── archs │ │ ├── CDT_arch.py │ │ ├── __init__.py │ │ └── arch_util.py │ ├── base_model.py │ ├── image_restoration_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ └── lr_scheduler.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── bundle_submissions.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py └── version.py ├── demo.py ├── demo ├── img_dns │ ├── 00001.jpg │ ├── 00074.jpg │ ├── noisy1.png │ └── noisy2.png └── img_ori │ ├── 00001.jpg │ ├── 00074.jpg │ ├── noisy1.png │ └── noisy2.png ├── denoiser └── denoiser_builder.py ├── enhancer ├── DCE_model.py ├── SCT_model.py └── enhancer_builder.py ├── eval.py ├── experiments ├── CDT │ └── README.md ├── DCE++ │ └── README.md ├── SCT │ └── README.md ├── SiamAPN++ │ ├── README.md │ └── config.yaml ├── SiamAPN │ ├── README.md │ └── config.yaml ├── SiamBAN │ ├── LICENSE │ ├── README.md │ └── config.yaml ├── SiamGAT │ ├── README.md │ └── config.yaml └── SiamRPN++_mobilev2 │ ├── LICENSE │ ├── README.md │ └── config.yaml ├── results ├── DarkTrack2021 │ ├── SiamAPN++_DCE.zip │ ├── SiamAPN++_DCE_CDT.zip │ ├── SiamAPN++_SCT.zip │ ├── SiamAPN++_SCT_CDT.zip │ ├── SiamAPN_DCE.zip │ ├── SiamAPN_DCE_CDT.zip │ ├── SiamAPN_SCT.zip │ ├── SiamAPN_SCT_CDT.zip │ ├── SiamBAN_DCE.zip │ ├── SiamBAN_DCE_CDT.zip │ ├── SiamBAN_SCT.zip │ ├── SiamBAN_SCT_CDT.zip │ ├── SiamGAT_DCE.zip │ ├── SiamGAT_DCE_CDT.zip │ ├── SiamGAT_SCT.zip │ ├── SiamGAT_SCT_CDT.zip │ ├── SiamRPN++_DCE.zip │ ├── SiamRPN++_DCE_CDT.zip │ ├── SiamRPN++_SCT.zip │ └── SiamRPN++_SCT_CDT.zip └── UAVDark135 │ ├── SiamAPN++_DCE.zip │ ├── SiamAPN++_DCE_CDT.zip │ ├── SiamAPN++_SCT.zip │ ├── SiamAPN++_SCT_CDT.zip │ ├── SiamAPN_DCE.zip │ ├── SiamAPN_DCE_CDT.zip │ ├── SiamAPN_SCT.zip │ ├── SiamAPN_SCT_CBDNet.zip │ ├── SiamAPN_SCT_CDT.zip │ ├── SiamAPN_SCT_DRUNet.zip │ ├── SiamAPN_SCT_DnCNN.zip │ ├── SiamAPN_SCT_FDnCNN.zip │ ├── SiamAPN_SCT_FFDNet.zip │ ├── SiamAPN_SCT_IRCNN.zip │ ├── SiamAPN_SCT_NAFNet.zip │ ├── SiamAPN_SCT_Restormer.zip │ ├── SiamAPN_SCT_Uformer.zip │ ├── SiamBAN_DCE.zip │ ├── SiamBAN_DCE_CDT.zip │ ├── SiamBAN_SCT.zip │ ├── SiamBAN_SCT_CDT.zip │ ├── SiamGAT_DCE.zip │ ├── SiamGAT_DCE_CDT.zip │ ├── SiamGAT_SCT.zip │ ├── SiamGAT_SCT_CDT.zip │ ├── SiamRPN++_DCE.zip │ ├── SiamRPN++_DCE_CDT.zip │ ├── SiamRPN++_SCT.zip │ └── SiamRPN++_SCT_CDT.zip ├── snot ├── core │ ├── __init__.py │ ├── config.py │ ├── config_adapn.py │ ├── config_apn.py │ ├── 
config_ban.py │ ├── config_gat.py │ └── xcorr.py ├── datasets │ ├── __init__.py │ ├── darktrack.py │ ├── datapath.py │ ├── dataset.py │ ├── uavdark.py │ └── video.py ├── models │ ├── adapn │ │ ├── anchortarget.py │ │ └── utile.py │ ├── adsiamapn_model.py │ ├── apn │ │ ├── anchortarget.py │ │ └── utile.py │ ├── backbone │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── googlenet.py │ │ ├── googlenet_ou.py │ │ ├── mobile_v2.py │ │ └── resnet_atrous.py │ ├── dasiamrpn_model.py │ ├── head │ │ ├── __init__.py │ │ ├── ban.py │ │ ├── car.py │ │ ├── mask.py │ │ └── rpn.py │ ├── model_builder.py │ ├── neck │ │ ├── __init__.py │ │ └── neck.py │ ├── siamapn_model.py │ ├── siamban_model.py │ └── siamgat_model.py ├── pipelines │ ├── pipeline_builder.py │ ├── siamapn_pipeline.py │ ├── siamapnpp_pipeline.py │ ├── siamban_pipeline.py │ ├── siamcar_pipeline.py │ ├── siamgat_pipeline.py │ ├── siammask_pipeline.py │ └── siamrpn_pipeline.py ├── trackers │ ├── adsiamapn_tracker.py │ ├── base_tracker.py │ ├── siamapn_tracker.py │ ├── siamban_tracker.py │ ├── siamcar_tracker.py │ ├── siamgat_tracker.py │ ├── siamrpn_tracker.py │ ├── tracker_builder.py │ └── tracker_builder_ban.py └── utils │ ├── anchor.py │ ├── bbox.py │ ├── evaluation │ ├── __init__.py │ └── ope_benchmark.py │ ├── misc.py │ ├── model_load.py │ ├── statistics.py │ ├── utils_ad.py │ └── visualization │ ├── __init__.py │ ├── draw_success_precision.py │ └── draw_utils.py ├── test.py └── train.py /CDT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/CDT.png -------------------------------------------------------------------------------- /Denoising/Options/Denoising_CDT.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: CDT 3 | model_type: ImageCleanModel 4 | scale: 1 5 | num_gpu: 4 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: Dataset_PairedImage 13 | dataroot_gt: /data/Denoise_dataset/train/SIDD/target_crops 14 | dataroot_lq: /data/Denoise_dataset/train/SIDD/input_crops 15 | geometric_augs: true 16 | 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 8 24 | batch_size_per_gpu: 8 25 | 26 | ### -------------Progressive training-------------------------- 27 | mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu 28 | iters: [92000,64000,48000,36000,36000,24000] 29 | gt_size: 384 # Max patch size for progressive training 30 | gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training. 
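# Note: each stage of progressive training pairs one entry from
# mini_batch_sizes, iters and gt_sizes, so the per-GPU batch shrinks
# (8 -> 1) while the training patch size grows (128 -> 384); the six
# stage lengths above sum to 300000 iterations, matching the single
# fixed-patch schedule commented out below.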
31 | ### ------------------------------------------------------------ 32 | 33 | ### ------- Training on single fixed-patch size 128x128--------- 34 | # mini_batch_sizes: [8] 35 | # iters: [300000] 36 | # gt_size: 128 37 | # gt_sizes: [128] 38 | ### ------------------------------------------------------------ 39 | 40 | dataset_enlarge_ratio: 1 41 | prefetch_mode: ~ 42 | 43 | val: 44 | name: ValSet 45 | type: Dataset_PairedImage 46 | dataroot_gt: /data/Denoise_dataset/val/SIDD/target_crops 47 | dataroot_lq: /data/Denoise_dataset/val/SIDD/input_crops 48 | io_backend: 49 | type: disk 50 | 51 | # network structures 52 | network_g: 53 | type: CDT 54 | in_channels: 3 55 | out_channels: 3 56 | dim: 48 57 | num_blocks: [1,2,3,4] 58 | heads: [1,2,2,4] 59 | ffn_expansion_factor: 2.67 60 | bias: False 61 | LayerNorm_type: BiasFree 62 | 63 | 64 | # path 65 | path: 66 | pretrain_network_g: ~ 67 | strict_load_g: true 68 | resume_state: ~ 69 | 70 | # training settings 71 | train: 72 | total_iter: 100000 73 | warmup_iter: -1 # no warm up 74 | use_grad_clip: true 75 | 76 | # Split 300k iterations into two cycles. 77 | # 1st cycle: fixed 3e-4 LR for 92k iters. 78 | # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters. 79 | scheduler: 80 | type: CosineAnnealingRestartCyclicLR 81 | periods: [92000, 208000] 82 | restart_weights: [1,1] 83 | eta_mins: [0.0003,0.000001] 84 | 85 | mixing_augs: 86 | mixup: true 87 | mixup_beta: 1.2 88 | use_identity: true 89 | 90 | optim_g: 91 | type: AdamW 92 | lr: !!float 3e-4 93 | weight_decay: !!float 1e-4 94 | betas: [0.9, 0.999] 95 | 96 | # losses 97 | pixel_opt: 98 | type: L1Loss 99 | loss_weight: 1 100 | reduction: mean 101 | 102 | # validation settings 103 | val: 104 | window_size: 8 105 | val_freq: !!float 4e3 106 | save_img: false 107 | rgb2bgr: true 108 | use_image: false 109 | max_minibatch: 8 110 | 111 | metrics: 112 | psnr: # metric name, can be arbitrary 113 | type: calculate_psnr 114 | crop_border: 0 115 | test_y_channel: false 116 | 117 | # logging settings 118 | logger: 119 | print_freq: 1000 120 | save_checkpoint_freq: !!float 4e3 121 | use_tb_logger: true 122 | wandb: 123 | project: ~ 124 | resume_id: ~ 125 | 126 | # dist training settings 127 | dist_params: 128 | backend: nccl 129 | port: 29500 130 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from functools import partial 7 | from os import path as osp 8 | 9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 10 | from basicsr.utils import get_root_logger, scandir 11 | from basicsr.utils.dist_util import get_dist_info 12 | 13 | __all__ = ['create_dataset', 'create_dataloader'] 14 | 15 | # automatically scan and import dataset modules 16 | # scan all the files under the data folder with '_dataset' in file names 17 | data_folder = osp.dirname(osp.abspath(__file__)) 18 | dataset_filenames = [ 19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) 20 | if v.endswith('_dataset.py') 21 | ] 22 | # import all the dataset modules 23 | _dataset_modules = [ 24 | importlib.import_module(f'basicsr.data.{file_name}') 25 | for file_name in dataset_filenames 26 | ] 27 | 28 | 29 | def create_dataset(dataset_opt): 30 | """Create dataset. 31 | 32 | Args: 33 | dataset_opt (dict): Configuration for dataset. 
It constains: 34 | name (str): Dataset name. 35 | type (str): Dataset type. 36 | """ 37 | dataset_type = dataset_opt['type'] 38 | 39 | # dynamic instantiation 40 | for module in _dataset_modules: 41 | dataset_cls = getattr(module, dataset_type, None) 42 | if dataset_cls is not None: 43 | break 44 | if dataset_cls is None: 45 | raise ValueError(f'Dataset {dataset_type} is not found.') 46 | 47 | dataset = dataset_cls(dataset_opt) 48 | 49 | logger = get_root_logger() 50 | logger.info( 51 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 52 | 'is created.') 53 | return dataset 54 | 55 | 56 | def create_dataloader(dataset, 57 | dataset_opt, 58 | num_gpu=1, 59 | dist=False, 60 | sampler=None, 61 | seed=None): 62 | """Create dataloader. 63 | 64 | Args: 65 | dataset (torch.utils.data.Dataset): Dataset. 66 | dataset_opt (dict): Dataset options. It contains the following keys: 67 | phase (str): 'train' or 'val'. 68 | num_worker_per_gpu (int): Number of workers for each GPU. 69 | batch_size_per_gpu (int): Training batch size for each GPU. 70 | num_gpu (int): Number of GPUs. Used only in the train phase. 71 | Default: 1. 72 | dist (bool): Whether in distributed training. Used only in the train 73 | phase. Default: False. 74 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 75 | seed (int | None): Seed. Default: None 76 | """ 77 | phase = dataset_opt['phase'] 78 | rank, _ = get_dist_info() 79 | if phase == 'train': 80 | if dist: # distributed training 81 | batch_size = dataset_opt['batch_size_per_gpu'] 82 | num_workers = dataset_opt['num_worker_per_gpu'] 83 | else: # non-distributed training 84 | multiplier = 1 if num_gpu == 0 else num_gpu 85 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 86 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 87 | dataloader_args = dict( 88 | dataset=dataset, 89 | batch_size=batch_size, 90 | shuffle=False, 91 | num_workers=num_workers, 92 | sampler=sampler, 93 | drop_last=True) 94 | if sampler is None: 95 | dataloader_args['shuffle'] = True 96 | dataloader_args['worker_init_fn'] = partial( 97 | worker_init_fn, num_workers=num_workers, rank=rank, 98 | seed=seed) if seed is not None else None 99 | elif phase in ['val', 'test']: # validation 100 | dataloader_args = dict( 101 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 102 | else: 103 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 104 | "Supported ones are 'train', 'val' and 'test'.") 105 | 106 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 107 | 108 | prefetch_mode = dataset_opt.get('prefetch_mode') 109 | if prefetch_mode == 'cpu': # CPUPrefetcher 110 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 111 | logger = get_root_logger() 112 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 113 | f'num_prefetch_queue = {num_prefetch_queue}') 114 | return PrefetchDataLoader( 115 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 116 | else: 117 | # prefetch_mode=None: Normal dataloader 118 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 119 | return torch.utils.data.DataLoader(**dataloader_args) 120 | 121 | 122 | def worker_init_fn(worker_id, num_workers, rank, seed): 123 | # Set the worker seed to num_workers * rank + worker_id + seed 124 | worker_seed = num_workers * rank + worker_id + seed 125 | np.random.seed(worker_seed) 126 | random.seed(worker_seed) 127 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil( 27 | len(self.dataset) * ratio / self.num_replicas) 28 | self.total_size = self.num_samples * self.num_replicas 29 | 30 | def __iter__(self): 31 | # deterministically shuffle based on epoch 32 | g = torch.Generator() 33 | g.manual_seed(self.epoch) 34 | indices = torch.randperm(self.total_size, generator=g).tolist() 35 | 36 | dataset_size = len(self.dataset) 37 | indices = [v % dataset_size for v in indices] 38 | 39 | # subsample 40 | indices = indices[self.rank:self.total_size:self.num_replicas] 41 | assert len(indices) == self.num_samples 42 | 43 | return iter(indices) 44 | 45 | def __len__(self): 46 | return self.num_samples 47 | 48 | def set_epoch(self, epoch): 49 | self.epoch = epoch 50 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.transforms import augment 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor 7 | 8 | 9 | class FFHQDataset(data.Dataset): 10 | """FFHQ dataset for StyleGAN. 11 | 12 | Args: 13 | opt (dict): Config for train datasets. It contains the following keys: 14 | dataroot_gt (str): Data root path for gt. 
15 | io_backend (dict): IO backend type and other kwarg. 16 | mean (list | tuple): Image mean. 17 | std (list | tuple): Image std. 18 | use_hflip (bool): Whether to horizontally flip. 19 | 20 | """ 21 | 22 | def __init__(self, opt): 23 | super(FFHQDataset, self).__init__() 24 | self.opt = opt 25 | # file client (io backend) 26 | self.file_client = None 27 | self.io_backend_opt = opt['io_backend'] 28 | 29 | self.gt_folder = opt['dataroot_gt'] 30 | self.mean = opt['mean'] 31 | self.std = opt['std'] 32 | 33 | if self.io_backend_opt['type'] == 'lmdb': 34 | self.io_backend_opt['db_paths'] = self.gt_folder 35 | if not self.gt_folder.endswith('.lmdb'): 36 | raise ValueError("'dataroot_gt' should end with '.lmdb', " 37 | f'but received {self.gt_folder}') 38 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 39 | self.paths = [line.split('.')[0] for line in fin] 40 | else: 41 | # FFHQ has 70000 images in total 42 | self.paths = [ 43 | osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000) 44 | ] 45 | 46 | def __getitem__(self, index): 47 | if self.file_client is None: 48 | self.file_client = FileClient( 49 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | img_bytes = self.file_client.get(gt_path) 54 | img_gt = imfrombytes(img_bytes, float32=True) 55 | 56 | # random horizontal flip 57 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 58 | # BGR to RGB, HWC to CHW, numpy to tensor 59 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 60 | # normalize 61 | normalize(img_gt, self.mean, self.std, inplace=True) 62 | return {'gt': img_gt, 'gt_path': gt_path} 63 | 64 | def __len__(self): 65 | return len(self.paths) 66 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDS4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 000 100 (720,1280,3) 2 | 011 100 (720,1280,3) 3 | 015 100 (720,1280,3) 4 | 020 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSofficial4_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 246 100 (720,1280,3) 4 | 257 100 (720,1280,3) 5 | -------------------------------------------------------------------------------- /basicsr/data/meta_info/meta_info_REDSval_official_test_GT.txt: -------------------------------------------------------------------------------- 1 | 240 100 (720,1280,3) 2 | 241 100 (720,1280,3) 3 | 242 100 (720,1280,3) 4 | 243 100 (720,1280,3) 5 | 244 100 (720,1280,3) 6 | 245 100 (720,1280,3) 7 | 246 100 (720,1280,3) 8 | 247 100 (720,1280,3) 9 | 248 100 (720,1280,3) 10 | 249 100 (720,1280,3) 11 | 250 100 (720,1280,3) 12 | 251 100 (720,1280,3) 13 | 252 100 (720,1280,3) 14 | 253 100 (720,1280,3) 15 | 254 100 (720,1280,3) 16 | 255 100 (720,1280,3) 17 | 256 100 (720,1280,3) 18 | 257 100 (720,1280,3) 19 | 258 100 (720,1280,3) 20 | 259 100 (720,1280,3) 21 | 260 100 (720,1280,3) 22 | 261 100 (720,1280,3) 23 | 262 100 (720,1280,3) 24 | 263 100 (720,1280,3) 25 | 264 100 (720,1280,3) 26 | 265 100 (720,1280,3) 27 | 266 100 (720,1280,3) 28 | 267 100 (720,1280,3) 29 | 268 100 (720,1280,3) 30 | 269 100 (720,1280,3) 31 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: 
-------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
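Example (a minimal usage sketch; ``train_loader`` is a placeholder name and
``opt['num_gpu']`` is assumed to be non-zero so batches land on the GPU):

    >>> prefetcher = CUDAPrefetcher(train_loader, opt)
    >>> batch = prefetcher.next()
    >>> while batch is not None:
    >>>     # tensors in `batch` are already on the target device
    >>>     batch = prefetcher.next()
    >>> prefetcher.reset()  # rewind before starting the next epoch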
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to( 116 | device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | 8 | 9 | class SingleImageDataset(data.Dataset): 10 | """Read only lq images in the test phase. 11 | 12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 13 | 14 | There are two modes: 15 | 1. 'meta_info_file': Use meta information file to generate paths. 16 | 2. 'folder': Scan folders to generate paths. 17 | 18 | Args: 19 | opt (dict): Config for train datasets. It contains the following keys: 20 | dataroot_lq (str): Data root path for lq. 21 | meta_info_file (str): Path for meta information file. 22 | io_backend (dict): IO backend type and other kwarg. 
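Example opt (a sketch for the folder-scan mode; the path below is
illustrative only):

    {'dataroot_lq': 'demo/img_ori', 'io_backend': {'type': 'disk'}}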
23 | """ 24 | 25 | def __init__(self, opt): 26 | super(SingleImageDataset, self).__init__() 27 | self.opt = opt 28 | # file client (io backend) 29 | self.file_client = None 30 | self.io_backend_opt = opt['io_backend'] 31 | self.mean = opt['mean'] if 'mean' in opt else None 32 | self.std = opt['std'] if 'std' in opt else None 33 | self.lq_folder = opt['dataroot_lq'] 34 | 35 | if self.io_backend_opt['type'] == 'lmdb': 36 | self.io_backend_opt['db_paths'] = [self.lq_folder] 37 | self.io_backend_opt['client_keys'] = ['lq'] 38 | self.paths = paths_from_lmdb(self.lq_folder) 39 | elif 'meta_info_file' in self.opt: 40 | with open(self.opt['meta_info_file'], 'r') as fin: 41 | self.paths = [ 42 | osp.join(self.lq_folder, 43 | line.split(' ')[0]) for line in fin 44 | ] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient( 51 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 52 | 53 | # load lq image 54 | lq_path = self.paths[index] 55 | img_bytes = self.file_client.get(lq_path, 'lq') 56 | img_lq = imfrombytes(img_bytes, float32=True) 57 | 58 | # TODO: color space transform 59 | # BGR to RGB, HWC to CHW, numpy to tensor 60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 61 | # normalize 62 | if self.mean is not None or self.std is not None: 63 | normalize(img_lq, self.mean, self.std, inplace=True) 64 | return {'lq': img_lq, 'lq_path': lq_path} 65 | 66 | def __len__(self): 67 | return len(self.paths) 68 | -------------------------------------------------------------------------------- /basicsr/data/vimeo90k_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from pathlib import Path 4 | from torch.utils import data as data 5 | 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 8 | 9 | 10 | class Vimeo90KDataset(data.Dataset): 11 | """Vimeo90K dataset for training. 12 | 13 | The keys are generated from a meta info txt file. 14 | basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt 15 | 16 | Each line contains: 17 | 1. clip name; 2. frame number; 3. image shape, seperated by a white space. 18 | Examples: 19 | 00001/0001 7 (256,448,3) 20 | 00001/0002 7 (256,448,3) 21 | 22 | Key examples: "00001/0001" 23 | GT (gt): Ground-Truth; 24 | LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames. 25 | 26 | The neighboring frame list for different num_frame: 27 | num_frame | frame list 28 | 1 | 4 29 | 3 | 3,4,5 30 | 5 | 2,3,4,5,6 31 | 7 | 1,2,3,4,5,6,7 32 | 33 | Args: 34 | opt (dict): Config for train dataset. It contains the following keys: 35 | dataroot_gt (str): Data root path for gt. 36 | dataroot_lq (str): Data root path for lq. 37 | meta_info_file (str): Path for meta information file. 38 | io_backend (dict): IO backend type and other kwarg. 39 | 40 | num_frame (int): Window size for input frames. 41 | gt_size (int): Cropped patched size for gt patches. 42 | random_reverse (bool): Random reverse input frames. 43 | use_flip (bool): Use horizontal flips. 44 | use_rot (bool): Use rotation (use vertical flip and transposing h 45 | and w for implementation). 46 | 47 | scale (bool): Scale, which will be added automatically. 
48 | """ 49 | 50 | def __init__(self, opt): 51 | super(Vimeo90KDataset, self).__init__() 52 | self.opt = opt 53 | self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path( 54 | opt['dataroot_lq']) 55 | 56 | with open(opt['meta_info_file'], 'r') as fin: 57 | self.keys = [line.split(' ')[0] for line in fin] 58 | 59 | # file client (io backend) 60 | self.file_client = None 61 | self.io_backend_opt = opt['io_backend'] 62 | self.is_lmdb = False 63 | if self.io_backend_opt['type'] == 'lmdb': 64 | self.is_lmdb = True 65 | self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root] 66 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 67 | 68 | # indices of input images 69 | self.neighbor_list = [ 70 | i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame']) 71 | ] 72 | 73 | # temporal augmentation configs 74 | self.random_reverse = opt['random_reverse'] 75 | logger = get_root_logger() 76 | logger.info(f'Random reverse is {self.random_reverse}.') 77 | 78 | def __getitem__(self, index): 79 | if self.file_client is None: 80 | self.file_client = FileClient( 81 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 82 | 83 | # random reverse 84 | if self.random_reverse and random.random() < 0.5: 85 | self.neighbor_list.reverse() 86 | 87 | scale = self.opt['scale'] 88 | gt_size = self.opt['gt_size'] 89 | key = self.keys[index] 90 | clip, seq = key.split('/') # key example: 00001/0001 91 | 92 | # get the GT frame (im4.png) 93 | if self.is_lmdb: 94 | img_gt_path = f'{key}/im4' 95 | else: 96 | img_gt_path = self.gt_root / clip / seq / 'im4.png' 97 | img_bytes = self.file_client.get(img_gt_path, 'gt') 98 | img_gt = imfrombytes(img_bytes, float32=True) 99 | 100 | # get the neighboring LQ frames 101 | img_lqs = [] 102 | for neighbor in self.neighbor_list: 103 | if self.is_lmdb: 104 | img_lq_path = f'{clip}/{seq}/im{neighbor}' 105 | else: 106 | img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png' 107 | img_bytes = self.file_client.get(img_lq_path, 'lq') 108 | img_lq = imfrombytes(img_bytes, float32=True) 109 | img_lqs.append(img_lq) 110 | 111 | # randomly crop 112 | img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, 113 | img_gt_path) 114 | 115 | # augmentation - flip, rotate 116 | img_lqs.append(img_gt) 117 | img_results = augment(img_lqs, self.opt['use_flip'], 118 | self.opt['use_rot']) 119 | 120 | img_results = img2tensor(img_results) 121 | img_lqs = torch.stack(img_results[0:-1], dim=0) 122 | img_gt = img_results[-1] 123 | 124 | # img_lqs: (t, c, h, w) 125 | # img_gt: (c, h, w) 126 | # key: str 127 | return {'lq': img_lqs, 'gt': img_gt, 'key': key} 128 | 129 | def __len__(self): 130 | return len(self.keys) 131 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .niqe import calculate_niqe 2 | from .psnr_ssim import calculate_psnr, calculate_ssim 3 | 4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 5 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.models.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', 11 | resize_input=True, 12 | normalize_input=False): 13 | # we may not 
resize the input, but in [rosinality/stylegan2-pytorch] it 14 | # does resize the input. 15 | inception = InceptionV3([3], 16 | resize_input=resize_input, 17 | normalize_input=normalize_input) 18 | inception = nn.DataParallel(inception).eval().to(device) 19 | return inception 20 | 21 | 22 | @torch.no_grad() 23 | def extract_inception_features(data_generator, 24 | inception, 25 | len_generator=None, 26 | device='cuda'): 27 | """Extract inception features. 28 | 29 | Args: 30 | data_generator (generator): A data generator. 31 | inception (nn.Module): Inception model. 32 | len_generator (int): Length of the data_generator to show the 33 | progressbar. Default: None. 34 | device (str): Device. Default: cuda. 35 | 36 | Returns: 37 | Tensor: Extracted features. 38 | """ 39 | if len_generator is not None: 40 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 41 | else: 42 | pbar = None 43 | features = [] 44 | 45 | for data in data_generator: 46 | if pbar: 47 | pbar.update(1) 48 | data = data.to(device) 49 | feature = inception(data)[0].view(data.shape[0], -1) 50 | features.append(feature.to('cpu')) 51 | if pbar: 52 | pbar.close() 53 | features = torch.cat(features, 0) 54 | return features 55 | 56 | 57 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 58 | """Numpy implementation of the Frechet Distance. 59 | 60 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 61 | and X_2 ~ N(mu_2, C_2) is 62 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 63 | Stable version by Dougal J. Sutherland. 64 | 65 | Args: 66 | mu1 (np.array): The sample mean over activations. 67 | sigma1 (np.array): The covariance matrix over activations for 68 | generated samples. 69 | mu2 (np.array): The sample mean over activations, precalculated on an 70 | representative data set. 71 | sigma2 (np.array): The covariance matrix over activations, 72 | precalculated on an representative data set. 73 | 74 | Returns: 75 | float: The Frechet Distance. 76 | """ 77 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 78 | assert sigma1.shape == sigma2.shape, ( 79 | 'Two covariances have different dimensions') 80 | 81 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 82 | 83 | # Product might be almost singular 84 | if not np.isfinite(cov_sqrt).all(): 85 | print('Product of cov matrices is singular. Adding {eps} to diagonal ' 86 | 'of cov estimates') 87 | offset = np.eye(sigma1.shape[0]) * eps 88 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 89 | 90 | # Numerical error might give slight imaginary component 91 | if np.iscomplexobj(cov_sqrt): 92 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 93 | m = np.max(np.abs(cov_sqrt.imag)) 94 | raise ValueError(f'Imaginary component {m}') 95 | cov_sqrt = cov_sqrt.real 96 | 97 | mean_diff = mu1 - mu2 98 | mean_norm = mean_diff @ mean_diff 99 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 100 | fid = mean_norm + trace 101 | 102 | return fid 103 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils.matlab_functions import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 
8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError( 25 | f'Wrong input_order {input_order}. Supported input_orders are ' 26 | "'HWC' and 'CHW'") 27 | if len(img.shape) == 2: 28 | img = img[..., None] 29 | if input_order == 'CHW': 30 | img = img.transpose(1, 2, 0) 31 | return img 32 | 33 | 34 | def to_y_channel(img): 35 | """Change to Y channel of YCbCr. 36 | 37 | Args: 38 | img (ndarray): Images with range [0, 255]. 39 | 40 | Returns: 41 | (ndarray): Images with range [0, 255] (float type) without round. 42 | """ 43 | img = img.astype(np.float32) / 255. 44 | if img.ndim == 3 and img.shape[2] == 3: 45 | img = bgr2ycbcr(img, y_only=True) 46 | img = img[..., None] 47 | return img * 255. 48 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import get_root_logger, scandir 5 | 6 | # automatically scan and import model modules 7 | # scan all the files under the 'models' folder and collect files ending with 8 | # '_model.py' 9 | model_folder = osp.dirname(osp.abspath(__file__)) 10 | model_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) 12 | if v.endswith('_model.py') 13 | ] 14 | # import all the model modules 15 | _model_modules = [ 16 | importlib.import_module(f'basicsr.models.{file_name}') 17 | for file_name in model_filenames 18 | ] 19 | 20 | 21 | def create_model(opt): 22 | """Create model. 23 | 24 | Args: 25 | opt (dict): Configuration. It constains: 26 | model_type (str): Model type. 
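Example (sketch; assumes ``opt`` comes from the option parser, e.g. with
``model_type: ImageCleanModel`` as in Denoising_CDT.yml above):

    >>> model = create_model(opt)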
27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamic instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = get_root_logger() 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /basicsr/models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules 7 | # scan all the files under the 'archs' folder and collect files ending with 8 | # '_arch.py' 9 | arch_folder = osp.dirname(osp.abspath(__file__)) 10 | arch_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) 12 | if v.endswith('_arch.py') 13 | ] 14 | # import all the arch modules 15 | _arch_modules = [ 16 | importlib.import_module(f'basicsr.models.archs.{file_name}') 17 | for file_name in arch_filenames 18 | ] 19 | 20 | 21 | def dynamic_instantiation(modules, cls_type, opt): 22 | """Dynamically instantiate class. 23 | 24 | Args: 25 | modules (list[importlib modules]): List of modules from importlib 26 | files. 27 | cls_type (str): Class type. 28 | opt (dict): Class initialization kwargs. 29 | 30 | Returns: 31 | class: Instantiated class. 32 | """ 33 | 34 | for module in modules: 35 | cls_ = getattr(module, cls_type, None) 36 | if cls_ is not None: 37 | break 38 | if cls_ is None: 39 | raise ValueError(f'{cls_type} is not found.') 40 | return cls_(**opt) 41 | 42 | 43 | def define_network(opt): 44 | network_type = opt.pop('type') 45 | net = dynamic_instantiation(_arch_modules, network_type, opt) 46 | return net 47 | -------------------------------------------------------------------------------- /basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss) 2 | 3 | __all__ = [ 4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss', 5 | ] 6 | -------------------------------------------------------------------------------- /basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 
36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/models/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from basicsr.models.losses.loss_util import weighted_loss 7 | 8 | _reduction_modes = ['none', 'mean', 'sum'] 9 | 10 | 11 | @weighted_loss 12 | def l1_loss(pred, target): 13 | return F.l1_loss(pred, target, reduction='none') 14 | 15 | 16 | @weighted_loss 17 | def mse_loss(pred, target): 18 | return F.mse_loss(pred, target, reduction='none') 19 | 20 | 21 | # @weighted_loss 22 | # def charbonnier_loss(pred, target, eps=1e-12): 23 | # return torch.sqrt((pred - target)**2 + eps) 24 | 25 | 26 | class L1Loss(nn.Module): 27 | """L1 (mean absolute error, MAE) loss. 28 | 29 | Args: 30 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 31 | reduction (str): Specifies the reduction to apply to the output. 32 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 33 | """ 34 | 35 | def __init__(self, loss_weight=1.0, reduction='mean'): 36 | super(L1Loss, self).__init__() 37 | if reduction not in ['none', 'mean', 'sum']: 38 | raise ValueError(f'Unsupported reduction mode: {reduction}. 
' 39 | f'Supported ones are: {_reduction_modes}') 40 | 41 | self.loss_weight = loss_weight 42 | self.reduction = reduction 43 | 44 | def forward(self, pred, target, weight=None, **kwargs): 45 | """ 46 | Args: 47 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 48 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 50 | weights. Default: None. 51 | """ 52 | return self.loss_weight * l1_loss( 53 | pred, target, weight, reduction=self.reduction) 54 | 55 | class MSELoss(nn.Module): 56 | """MSE (L2) loss. 57 | 58 | Args: 59 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 60 | reduction (str): Specifies the reduction to apply to the output. 61 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 62 | """ 63 | 64 | def __init__(self, loss_weight=1.0, reduction='mean'): 65 | super(MSELoss, self).__init__() 66 | if reduction not in ['none', 'mean', 'sum']: 67 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 68 | f'Supported ones are: {_reduction_modes}') 69 | 70 | self.loss_weight = loss_weight 71 | self.reduction = reduction 72 | 73 | def forward(self, pred, target, weight=None, **kwargs): 74 | """ 75 | Args: 76 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 77 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 78 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 79 | weights. Default: None. 80 | """ 81 | return self.loss_weight * mse_loss( 82 | pred, target, weight, reduction=self.reduction) 83 | 84 | class PSNRLoss(nn.Module): 85 | 86 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 87 | super(PSNRLoss, self).__init__() 88 | assert reduction == 'mean' 89 | self.loss_weight = loss_weight 90 | self.scale = 10 / np.log(10) 91 | self.toY = toY 92 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 93 | self.first = True 94 | 95 | def forward(self, pred, target): 96 | assert len(pred.size()) == 4 97 | if self.toY: 98 | if self.first: 99 | self.coef = self.coef.to(pred.device) 100 | self.first = False 101 | 102 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 103 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 104 | 105 | pred, target = pred / 255., target / 255. 
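# Note: 65.481 / 128.553 / 24.966 with the +16 offset are the standard
# BT.601 RGB-to-Y conversion weights, so with toY=True the PSNR-style
# loss is evaluated on the luma channel only.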
106 | pass 107 | assert len(pred.size()) == 4 108 | 109 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() 110 | 111 | class CharbonnierLoss(nn.Module): 112 | """Charbonnier Loss (L1)""" 113 | 114 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3): 115 | super(CharbonnierLoss, self).__init__() 116 | self.eps = eps 117 | 118 | def forward(self, x, y): 119 | diff = x - y 120 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 121 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 122 | return loss 123 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import create_dataloader, create_dataset 6 | from basicsr.models import create_model 7 | from basicsr.train import parse_options 8 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str, 9 | make_exp_dirs) 10 | from basicsr.utils.options import dict2str 11 | 12 | 13 | def main(): 14 | # parse options, set distributed setting, set ramdom seed 15 | opt = parse_options(is_train=False) 16 | 17 | torch.backends.cudnn.benchmark = True 18 | # torch.backends.cudnn.deterministic = True 19 | 20 | # mkdir and initialize loggers 21 | make_exp_dirs(opt) 22 | log_file = osp.join(opt['path']['log'], 23 | f"test_{opt['name']}_{get_time_str()}.log") 24 | logger = get_root_logger( 25 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 26 | logger.info(get_env_info()) 27 | logger.info(dict2str(opt)) 28 | 29 | # create test dataset and dataloader 30 | test_loaders = [] 31 | for phase, dataset_opt in sorted(opt['datasets'].items()): 32 | test_set = create_dataset(dataset_opt) 33 | test_loader = create_dataloader( 34 | test_set, 35 | dataset_opt, 36 | num_gpu=opt['num_gpu'], 37 | dist=opt['dist'], 38 | sampler=None, 39 | seed=opt['manual_seed']) 40 | logger.info( 41 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 42 | test_loaders.append(test_loader) 43 | 44 | # create model 45 | model = create_model(opt) 46 | 47 | for test_loader in test_loaders: 48 | test_set_name = test_loader.dataset.opt['name'] 49 | logger.info(f'Testing {test_set_name}...') 50 | rgb2bgr = opt['val'].get('rgb2bgr', True) 51 | # wheather use uint8 image to compute metrics 52 | use_image = opt['val'].get('use_image', True) 53 | model.validation( 54 | test_loader, 55 | current_iter=opt['name'], 56 | tb_logger=None, 57 | save_img=opt['val']['save_img'], 58 | rgb2bgr=rgb2bgr, use_image=use_image) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP 3 | from .logger import (MessageLogger, get_env_info, get_root_logger, 4 | init_tb_logger, init_wandb_logger) 5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 8 | 9 | __all__ = [ 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py 13 | 
'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | 'padding', 33 | 'padding_DP', 34 | 'imfrombytesDP', 35 | 'create_lmdb_for_reds', 36 | 'create_lmdb_for_gopro', 37 | 'create_lmdb_for_rain13k', 38 | ] 39 | -------------------------------------------------------------------------------- /basicsr/utils/bundle_submissions.py: -------------------------------------------------------------------------------- 1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de) 2 | 3 | # This file is part of the implementation as described in the CVPR 2017 paper: 4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs. 5 | # Please see the file LICENSE.txt for the license governing this code. 6 | 7 | 8 | import numpy as np 9 | import scipy.io as sio 10 | import os 11 | import h5py 12 | 13 | def bundle_submissions_raw(submission_folder,session): 14 | ''' 15 | Bundles submission data for raw denoising 16 | 17 | submission_folder Folder where denoised images reside 18 | 19 | Output is written to /bundled/. Please submit 20 | the content of this folder. 21 | ''' 22 | 23 | out_folder = os.path.join(submission_folder, session) 24 | # out_folder = os.path.join(submission_folder, "bundled/") 25 | try: 26 | os.mkdir(out_folder) 27 | except:pass 28 | 29 | israw = True 30 | eval_version="1.0" 31 | 32 | for i in range(50): 33 | Idenoised = np.zeros((20,), dtype=np.object) 34 | for bb in range(20): 35 | filename = '%04d_%02d.mat'%(i+1,bb+1) 36 | s = sio.loadmat(os.path.join(submission_folder,filename)) 37 | Idenoised_crop = s["Idenoised_crop"] 38 | Idenoised[bb] = Idenoised_crop 39 | filename = '%04d.mat'%(i+1) 40 | sio.savemat(os.path.join(out_folder, filename), 41 | {"Idenoised": Idenoised, 42 | "israw": israw, 43 | "eval_version": eval_version}, 44 | ) 45 | 46 | def bundle_submissions_srgb(submission_folder,session): 47 | ''' 48 | Bundles submission data for sRGB denoising 49 | 50 | submission_folder Folder where denoised images reside 51 | 52 | Output is written to /bundled/. Please submit 53 | the content of this folder. 54 | ''' 55 | out_folder = os.path.join(submission_folder, session) 56 | # out_folder = os.path.join(submission_folder, "bundled/") 57 | try: 58 | os.mkdir(out_folder) 59 | except:pass 60 | israw = False 61 | eval_version="1.0" 62 | 63 | for i in range(50): 64 | Idenoised = np.zeros((20,), dtype=np.object) 65 | for bb in range(20): 66 | filename = '%04d_%02d.mat'%(i+1,bb+1) 67 | s = sio.loadmat(os.path.join(submission_folder,filename)) 68 | Idenoised_crop = s["Idenoised_crop"] 69 | Idenoised[bb] = Idenoised_crop 70 | filename = '%04d.mat'%(i+1) 71 | sio.savemat(os.path.join(out_folder, filename), 72 | {"Idenoised": Idenoised, 73 | "israw": israw, 74 | "eval_version": eval_version}, 75 | ) 76 | 77 | 78 | 79 | def bundle_submissions_srgb_v1(submission_folder,session): 80 | ''' 81 | Bundles submission data for sRGB denoising 82 | 83 | submission_folder Folder where denoised images reside 84 | 85 | Output is written to /bundled/. Please submit 86 | the content of this folder. 
87 | ''' 88 | out_folder = os.path.join(submission_folder, session) 89 | # out_folder = os.path.join(submission_folder, "bundled/") 90 | try: 91 | os.mkdir(out_folder) 92 | except:pass 93 | israw = False 94 | eval_version="1.0" 95 | 96 | for i in range(50): 97 | Idenoised = np.zeros((20,), dtype=np.object) 98 | for bb in range(20): 99 | filename = '%04d_%d.mat'%(i+1,bb+1) 100 | s = sio.loadmat(os.path.join(submission_folder,filename)) 101 | Idenoised_crop = s["Idenoised_crop"] 102 | Idenoised[bb] = Idenoised_crop 103 | filename = '%04d.mat'%(i+1) 104 | sio.savemat(os.path.join(out_folder, filename), 105 | {"Idenoised": Idenoised, 106 | "israw": israw, 107 | "eval_version": eval_version}, 108 | ) -------------------------------------------------------------------------------- /basicsr/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 6 | 7 | def prepare_keys(folder_path, suffix='png'): 8 | """Prepare image path list and keys for DIV2K dataset. 9 | 10 | Args: 11 | folder_path (str): Folder path. 12 | 13 | Returns: 14 | list[str]: Image path list. 15 | list[str]: Key list. 16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/input' 66 | lmdb_path = './datasets/GoPro/test/input.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 'png') 69 | 
make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | def create_lmdb_for_rain13k(): 72 | folder_path = './datasets/Rain13k/train/input' 73 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 74 | 75 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 77 | 78 | folder_path = './datasets/Rain13k/train/target' 79 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | def create_lmdb_for_SIDD(): 85 | folder_path = './datasets/SIDD/train/input_crops' 86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 87 | 88 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 90 | 91 | folder_path = './datasets/SIDD/train/gt_crops' 92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | #for val 98 | folder_path = './datasets/SIDD/val/input_crops' 99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 101 | if not osp.exists(folder_path): 102 | os.makedirs(folder_path) 103 | assert osp.exists(mat_path) 104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 105 | N, B, H ,W, C = data.shape 106 | data = data.reshape(N*B, H, W, C) 107 | for i in tqdm(range(N*B)): 108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 109 | img_path_list, keys = prepare_keys(folder_path, 'png') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | 112 | folder_path = './datasets/SIDD/val/gt_crops' 113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 115 | if not osp.exists(folder_path): 116 | os.makedirs(folder_path) 117 | assert osp.exists(mat_path) 118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 119 | N, B, H ,W, C = data.shape 120 | data = data.reshape(N*B, H, W, C) 121 | for i in tqdm(range(N*B)): 122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 123 | img_path_list, keys = prepare_keys(folder_path, 'png') 124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 125 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % 
num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput( 45 | f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | from .misc import sizeof_fmt 6 | 7 | 8 | def download_file_from_google_drive(file_id, save_path): 9 | """Download files from google drive. 10 | 11 | Ref: 12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 13 | 14 | Args: 15 | file_id (str): File id. 16 | save_path (str): Save path. 
17 | """ 18 | 19 | session = requests.Session() 20 | URL = 'https://docs.google.com/uc?export=download' 21 | params = {'id': file_id} 22 | 23 | response = session.get(URL, params=params, stream=True) 24 | token = get_confirm_token(response) 25 | if token: 26 | params['confirm'] = token 27 | response = session.get(URL, params=params, stream=True) 28 | 29 | # get file size 30 | response_file_size = session.get( 31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | if 'Content-Range' in response_file_size.headers: 33 | file_size = int( 34 | response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, 49 | destination, 50 | file_size=None, 51 | chunk_size=32768): 52 | if file_size is not None: 53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 54 | 55 | readable_file_size = sizeof_fmt(file_size) 56 | else: 57 | pbar = None 58 | 59 | with open(destination, 'wb') as f: 60 | downloaded_size = 0 61 | for chunk in response.iter_content(chunk_size): 62 | downloaded_size += chunk_size 63 | if pbar is not None: 64 | pbar.update(1) 65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 66 | f'/ {readable_file_size}') 67 | if chunk: # filter out keep-alive new chunks 68 | f.write(chunk) 69 | if pbar is not None: 70 | pbar.close() 71 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (str): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options. 
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 
101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Wed Mar 9 22:05:30 2022 3 | __version__ = '1.2.0+10018c6' 4 | short_version = '1.2.0' 5 | version_info = (1, 2, 0) 6 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torchvision 4 | import torch.optim 5 | import os 6 | from basicsr.models.archs.CDT_arch import CDT 7 | from denoiser.denoiser_builder import Denoiser 8 | import argparse 9 | from PIL import Image 10 | 11 | 12 | torch.set_num_threads(1) 13 | 14 | parser = argparse.ArgumentParser(description='demo processer') 15 | parser.add_argument('--d_weights', default='./experiments/CDT/model.pth', type=str, 16 | help='weights') 17 | args = parser.parse_args() 18 | 19 | 20 | if __name__ == '__main__': 21 | 22 | parameters = {'CDT':{'in_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[1,2,3,4], 'heads':[1,2,2,4], 'ffn_expansion_factor':2.67, 'bias':False, 'LayerNorm_type':'BiasFree'}, 23 | } 24 | 25 | model = CDT(**parameters['CDT']) 26 | cdt = Denoiser(model, args) 27 | 28 | filePath = './CDT/demo/img_ori/' 29 | 30 | file_list = os.listdir(filePath) 31 | 32 | with torch.no_grad(): 33 | 34 | for file_name in file_list: 35 | img = Image.open(filePath+file_name) 36 | restored = cdt.single_denoise(img) 37 | dns_path = filePath.replace('img_ori', 'img_dns') 38 | torchvision.utils.save_image(restored, dns_path+file_name) 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /demo/img_dns/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_dns/00001.jpg -------------------------------------------------------------------------------- /demo/img_dns/00074.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_dns/00074.jpg -------------------------------------------------------------------------------- /demo/img_dns/noisy1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_dns/noisy1.png -------------------------------------------------------------------------------- /demo/img_dns/noisy2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_dns/noisy2.png -------------------------------------------------------------------------------- /demo/img_ori/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_ori/00001.jpg 
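The demo script above (`demo.py`) assumes it is launched from the parent directory of the repository (note the hard-coded `./CDT/demo/img_ori/` path) and that every file in `img_ori` is a 3-channel image. A slightly more defensive version of its main loop is sketched below; it is only an illustration that reuses the `cdt` denoiser and the `filePath` variable defined in `demo.py`, forcing RGB input and creating the output directory if it is missing.

```python
import os
import torch
import torchvision
from PIL import Image

# Sketch only, not part of the original demo.py: images are converted to RGB so
# that Denoiser.single_denoise always receives an H x W x 3 array, and paths are
# joined explicitly instead of concatenated.
with torch.no_grad():
    for file_name in sorted(os.listdir(filePath)):
        img = Image.open(os.path.join(filePath, file_name)).convert('RGB')
        restored = cdt.single_denoise(img)
        out_dir = filePath.replace('img_ori', 'img_dns')
        os.makedirs(out_dir, exist_ok=True)
        torchvision.utils.save_image(restored, os.path.join(out_dir, file_name))
```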
-------------------------------------------------------------------------------- /demo/img_ori/00074.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_ori/00074.jpg -------------------------------------------------------------------------------- /demo/img_ori/noisy1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_ori/noisy1.png -------------------------------------------------------------------------------- /demo/img_ori/noisy2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/demo/img_ori/noisy2.png -------------------------------------------------------------------------------- /denoiser/denoiser_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from basicsr.models.archs.CDT_arch import CDT 6 | 7 | DENOISERS = { 8 | 'CDT': CDT, 9 | } 10 | 11 | class Denoiser(): 12 | def __init__(self, model, args): 13 | super(Denoiser, self).__init__() 14 | 15 | self.model = model 16 | self.model.cuda() 17 | 18 | checkpoint = torch.load(args.d_weights) 19 | self.model.load_state_dict(checkpoint['params']) 20 | model.eval() 21 | 22 | self.multiples = 8 23 | 24 | def denoise(self, img): 25 | 26 | input_ = torch.div(img, 255.) 27 | 28 | # Pad the input if not_multiple_of 8 29 | h,w = input_.shape[2], input_.shape[3] 30 | H,W = ((h+self.multiples)//self.multiples)*self.multiples, ((w+self.multiples)//self.multiples)*self.multiples 31 | padh = H-h if h%self.multiples!=0 else 0 32 | padw = W-w if w%self.multiples!=0 else 0 33 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 34 | 35 | restored = self.model(input_) 36 | 37 | restored = torch.clamp(restored, 0, 1) 38 | 39 | # Unpad the output 40 | restored = restored[:,:,:h,:w] 41 | 42 | return torch.mul(restored, 255.) 
43 | 44 | def single_denoise(self, img): 45 | img = (np.asarray(img)/255.0) 46 | img = torch.from_numpy(img).float() 47 | img = img.permute(2,0,1) 48 | input_ = img.cuda().unsqueeze(0) 49 | 50 | # Pad the input if not_multiple_of 8 51 | h,w = input_.shape[2], input_.shape[3] 52 | H,W = ((h+self.multiples)//self.multiples)*self.multiples, ((w+self.multiples)//self.multiples)*self.multiples 53 | padh = H-h if h%self.multiples!=0 else 0 54 | padw = W-w if w%self.multiples!=0 else 0 55 | 56 | 57 | restored = self.model(F.pad(input_, (0,padw,0,padh), 'reflect')) 58 | 59 | restored = torch.clamp(restored, 0, 1) 60 | 61 | # Unpad the output 62 | return restored[:,:,:h,:w] 63 | 64 | 65 | def build_denoiser(args): 66 | 67 | parameters = {'CDT':{'in_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[1,2,3,4], 'heads':[1,2,2,4], 'ffn_expansion_factor':2.67, 'bias':False, 'LayerNorm_type':'BiasFree'}, 68 | } 69 | 70 | model = DENOISERS[args.denoisername.split('-')[0]](**parameters[args.denoisername.split('-')[0]]) 71 | return Denoiser(model, args) 72 | 73 | -------------------------------------------------------------------------------- /enhancer/DCE_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | 7 | 8 | class CSDN_Tem(nn.Module): 9 | def __init__(self, in_ch, out_ch): 10 | super(CSDN_Tem, self).__init__() 11 | self.depth_conv = nn.Conv2d( 12 | in_channels=in_ch, 13 | out_channels=in_ch, 14 | kernel_size=3, 15 | stride=1, 16 | padding=1, 17 | groups=in_ch 18 | ) 19 | self.point_conv = nn.Conv2d( 20 | in_channels=in_ch, 21 | out_channels=out_ch, 22 | kernel_size=1, 23 | stride=1, 24 | padding=0, 25 | groups=1 26 | ) 27 | 28 | def forward(self, input): 29 | out = self.depth_conv(input) 30 | out = self.point_conv(out) 31 | return out 32 | 33 | class enhance_net_nopool(nn.Module): 34 | 35 | def __init__(self,scale_factor): 36 | super(enhance_net_nopool, self).__init__() 37 | 38 | self.relu = nn.ReLU(inplace=True) 39 | self.scale_factor = scale_factor 40 | self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor) 41 | number_f = 32 42 | 43 | # zerodce DWC + p-shared 44 | self.e_conv1 = CSDN_Tem(3,number_f) 45 | self.e_conv2 = CSDN_Tem(number_f,number_f) 46 | self.e_conv3 = CSDN_Tem(number_f,number_f) 47 | self.e_conv4 = CSDN_Tem(number_f,number_f) 48 | self.e_conv5 = CSDN_Tem(number_f*2,number_f) 49 | self.e_conv6 = CSDN_Tem(number_f*2,number_f) 50 | self.e_conv7 = CSDN_Tem(number_f*2,3) 51 | 52 | def enhance(self, x,x_r): 53 | 54 | x = x + x_r*(torch.pow(x,2)-x) 55 | x = x + x_r*(torch.pow(x,2)-x) 56 | x = x + x_r*(torch.pow(x,2)-x) 57 | enhance_image_1 = x + x_r*(torch.pow(x,2)-x) 58 | x = enhance_image_1 + x_r*(torch.pow(enhance_image_1,2)-enhance_image_1) 59 | x = x + x_r*(torch.pow(x,2)-x) 60 | x = x + x_r*(torch.pow(x,2)-x) 61 | enhance_image = x + x_r*(torch.pow(x,2)-x) 62 | 63 | return enhance_image 64 | 65 | def forward(self, x): 66 | if self.scale_factor==1: 67 | x_down = x 68 | else: 69 | x_down = F.interpolate(x,scale_factor=1/self.scale_factor, mode='bilinear') 70 | 71 | x1 = self.relu(self.e_conv1(x_down)) 72 | x2 = self.relu(self.e_conv2(x1)) 73 | x3 = self.relu(self.e_conv3(x2)) 74 | x4 = self.relu(self.e_conv4(x3)) 75 | x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1))) 76 | x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1))) 77 | x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1))) 78 | if self.scale_factor==1: 
79 | x_r = x_r 80 | else: 81 | x_r = self.upsample(x_r) 82 | enhance_image = self.enhance(x,x_r) 83 | return enhance_image,x_r 84 | -------------------------------------------------------------------------------- /enhancer/enhancer_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .SCT_model import SCT 5 | from .DCE_model import enhance_net_nopool as DCE 6 | 7 | ENHANCERS = { 8 | 'SCT': SCT, 9 | 'DCE': DCE, 10 | } 11 | 12 | class Enhancer(): 13 | def __init__(self, args): 14 | super(Enhancer, self).__init__() 15 | self.args = args 16 | if args.enhancername.split('-')[0]=='SCT': 17 | self.model = SCT(img_size=128,embed_dim=32,win_size=4,token_embed='linear',token_mlp='resffn') 18 | elif args.enhancername.split('-')[0]=='DCE': 19 | self.model = DCE(scale_factor=12) 20 | 21 | self.model.load_state_dict(torch.load(args.e_weights)) 22 | self.model.cuda().eval() 23 | 24 | def enhance(self, img): 25 | 26 | input_ = torch.div(img, 255.) 27 | if self.args.enhancername.split('-')[0]=='DCE': 28 | self.multiples = 12 29 | 30 | h,w = input_.shape[2], input_.shape[3] 31 | H,W = ((h+self.multiples)//self.multiples)*self.multiples, ((w+self.multiples)//self.multiples)*self.multiples 32 | padh = H-h if h%self.multiples!=0 else 0 33 | padw = W-w if w%self.multiples!=0 else 0 34 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 35 | 36 | enhanced,_ = self.model(input_) 37 | enhanced = enhanced[:,:,:h,:w] 38 | else: 39 | enhanced = self.model(input_) 40 | 41 | enhanced = torch.clamp(enhanced, 0, 1) 42 | 43 | return torch.mul(enhanced, 255.) 44 | 45 | 46 | def build_enhancer(args): 47 | return Enhancer(args) 48 | 49 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | sys.path.append("./") 5 | 6 | from glob import glob 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | from snot.datasets import * 10 | from snot.utils.evaluation import OPEBenchmark 11 | from snot.utils.visualization import draw_success_precision 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Single Object Tracking Evaluation') 15 | parser.add_argument('--dataset_dir', default='',type=str, help='dataset root directory') 16 | parser.add_argument('--dataset', default='',type=str, help='dataset name') 17 | parser.add_argument('--tracker_result_dir',default='', type=str, help='tracker result root') 18 | parser.add_argument('--tracker_path', default='', type=str) 19 | parser.add_argument('--tracker_prefix',default='', type=str) 20 | parser.add_argument('--vis',dest='vis', action='store_true') 21 | parser.add_argument('--show_video_level', default=' ',dest='show_video_level', action='store_true') 22 | parser.add_argument('--num', default=1, type=int, help='number of processes to eval') 23 | args = parser.parse_args() 24 | 25 | 26 | def main(): 27 | tracker_dir = os.path.join(args.tracker_path, args.dataset) 28 | trackers = glob(os.path.join(args.tracker_path, 29 | args.dataset, 30 | args.tracker_prefix)) 31 | trackers = [x.split('/')[-1] for x in args.tracker_prefix.split(',')] 32 | 33 | root = args.dataset_dir + args.dataset 34 | 35 | assert len(trackers) > 0 36 | args.num = min(args.num, len(trackers)) 37 | 38 | if 'UAVDark135' in args.dataset: 39 | dataset = UAVDARKDataset(args.dataset, root) 40 | elif 'DarkTrack2021' in args.dataset: 41 
| dataset = DARKTRACKDataset(args.dataset, root) 42 | 43 | 44 | 45 | dataset.set_tracker(tracker_dir, trackers) 46 | benchmark = OPEBenchmark(dataset) 47 | success_ret = {} 48 | with Pool(processes=args.num) as pool: 49 | for ret in tqdm(pool.imap_unordered(benchmark.eval_success, 50 | trackers), desc='eval success', total=len(trackers), ncols=18): 51 | success_ret.update(ret) 52 | precision_ret = {} 53 | with Pool(processes=args.num) as pool: 54 | for ret in tqdm(pool.imap_unordered(benchmark.eval_precision, 55 | trackers), desc='eval precision', total=len(trackers), ncols=18): 56 | precision_ret.update(ret) 57 | benchmark.show_result(success_ret, precision_ret, 58 | show_video_level=args.show_video_level) 59 | if args.vis: 60 | for attr, videos in dataset.attr.items(): 61 | draw_success_precision(success_ret, 62 | name=dataset.name, 63 | videos=videos, 64 | attr=attr, 65 | precision_ret=precision_ret) 66 | 67 | if __name__ == '__main__': 68 | main() -------------------------------------------------------------------------------- /experiments/CDT/README.md: -------------------------------------------------------------------------------- 1 | # CDT 2 | 3 | ## Model download 4 | 5 | Before your own test, please download our pretrained model and put it into this folder. 6 | The model of CDT can be found at: [model.pth](https://pan.baidu.com/s/1CKOoTUEj8qSHWjzd-FBrYQ?pwd=cdtn)(code:cdtn). 7 | -------------------------------------------------------------------------------- /experiments/DCE++/README.md: -------------------------------------------------------------------------------- 1 | # DCE++ 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model and put it into this folder. 6 | The model of DCE++ can be found at: [model.pth](https://pan.baidu.com/s/1_5dK7B7bPWzWxxcrSuAtww?pwd=dcen)(code:dcen). 7 | You can also download it from official code site: https://github.com/Li-Chongyi/Zero-DCE_extension . 8 | -------------------------------------------------------------------------------- /experiments/SCT/README.md: -------------------------------------------------------------------------------- 1 | # SCT 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model and put it into this folder. 6 | The model of SCT can be found at: [model.pth](https://pan.baidu.com/s/1RBL_QPjmSbqx8TKDaUuqCg?pwd=sctn)(code:sctn). 7 | You can also download it from official code site: https://github.com/vision4robotics/SCT . 8 | -------------------------------------------------------------------------------- /experiments/SiamAPN++/README.md: -------------------------------------------------------------------------------- 1 | # SiamAPN++ 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model and put it into this folder. 6 | The model of SiamAPN++ can be found at: [model.pth](https://pan.baidu.com/s/1tbemYlshgtx8kjer4j4_HQ?pwd=siam)(code:siam). 7 | You can also download it from official code site: https://github.com/vision4robotics/SiamAPN .
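Once a downloaded `model.pth` has been placed in the corresponding folder under `experiments/`, a quick smoke test is to load it on the CPU before launching any tracker. The snippet below is only an illustrative check; the exact checkpoint layout differs per model (the CDT weights, for instance, are later read through the `'params'` key in `denoiser/denoiser_builder.py`, while the tracker snapshots have their own layouts).

```python
import torch

# Illustrative check that a downloaded checkpoint is readable; adjust the path
# to whichever experiment folder is being prepared (CDT, SCT, DCE++, SiamAPN, ...).
ckpt = torch.load('./experiments/CDT/model.pth', map_location='cpu')
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:5])
```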
8 | -------------------------------------------------------------------------------- /experiments/SiamAPN++/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "ADSiamAPN_alexnet" 2 | 3 | BACKBONE: 4 | TYPE: "alexnet" 5 | PRETRAINED: 'alexnet-bn.pth' 6 | TRAIN_LAYERS: ['layer3','layer4','layer5'] 7 | TRAIN_EPOCH: 10 8 | LAYERS_LR: 0.1 9 | 10 | TRACK: 11 | TYPE: 'ADSiamAPNtracker' 12 | EXEMPLAR_SIZE: 127 13 | INSTANCE_SIZE: 287 14 | CONTEXT_AMOUNT: 0.5 15 | STRIDE: 8 16 | w1: 1.2 17 | w2: 1.3 18 | w3: 1.1 19 | 20 | TRAIN: 21 | EPOCH: 50 22 | START_EPOCH: 0 23 | epochthrelod: 0 24 | SEARCH_SIZE: 287 25 | BATCH_SIZE: 220 26 | NUM_GPU: 2 27 | BASE_LR: 0.005 28 | RESUME: '' 29 | WEIGHT_DECAY : 0.0001 30 | PRETRAINED: '' 31 | OUTPUT_SIZE: 21 32 | NUM_WORKERS: 7 33 | LOC_WEIGHT: 2.2 34 | CLS_WEIGHT: 1.0 35 | SHAPE_WEIGHT: 1.8 36 | w1: 1.2 37 | w2: 1.3 38 | w3: 1.1 39 | w4: 1.5 40 | w5: 1.0 41 | 42 | POS_NUM : 16 43 | TOTAL_NUM : 64 44 | NEG_NUM : 16 45 | LARGER: 1.0 46 | range : 1.0 47 | LR: 48 | TYPE: 'log' 49 | KWARGS: 50 | start_lr: 0.01 51 | end_lr: 0.0005 52 | 53 | LR_WARMUP: 54 | TYPE: 'step' 55 | EPOCH: 5 56 | KWARGS: 57 | start_lr: 0.005 58 | end_lr: 0.01 59 | step: 1 60 | 61 | DATASET: 62 | NAMES: 63 | - 'VID' 64 | - 'COCO' 65 | - 'GOT' 66 | - 'YOUTUBEBB' 67 | 68 | 69 | TEMPLATE: 70 | SHIFT: 4 71 | SCALE: 0.05 72 | BLUR: 0.0 73 | FLIP: 0.0 74 | COLOR: 1.0 75 | 76 | SEARCH: 77 | SHIFT: 64 78 | SCALE: 0.18 79 | BLUR: 0.2 80 | FLIP: 0.0 81 | COLOR: 1.0 82 | 83 | NEG: 0.05 84 | GRAY: 0.0 85 | -------------------------------------------------------------------------------- /experiments/SiamAPN/README.md: -------------------------------------------------------------------------------- 1 | # SiamAPN 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model put it into this folder. 6 | The model of SiamAPN can be found at: [model.pth](https://pan.baidu.com/s/1rtlxVnWHjWHrNfIJqI6dNw?pwd=siam)(code:siam). 7 | You can also download it from official code site: https://github.com/vision4robotics/SiamAPN . 
8 | -------------------------------------------------------------------------------- /experiments/SiamAPN/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "SiamAPN_alexnet" 2 | 3 | BACKBONE: 4 | TYPE: "alexnet" 5 | PRETRAINED: 'alexnet-bn.pth' 6 | TRAIN_LAYERS: ['layer3','layer4','layer5'] 7 | TRAIN_EPOCH: 10 8 | LAYERS_LR: 0.1 9 | 10 | TRACK: 11 | TYPE: 'SiamAPNtracker' 12 | EXEMPLAR_SIZE: 127 13 | INSTANCE_SIZE: 287 14 | CONTEXT_AMOUNT: 0.5 15 | STRIDE: 8 16 | PENALTY_K: 0.08 17 | LR: 0.302 18 | w1: 1.18 19 | w2: 1.0 20 | w3: 1.0 21 | 22 | TRAIN: 23 | EPOCH: 50 24 | START_EPOCH: 0 25 | BATCH_SIZE: 124 26 | NUM_GPU: 2 27 | BASE_LR: 0.005 28 | RESUME: '' 29 | WEIGHT_DECAY : 0.0001 30 | PRETRAINED: '' 31 | OUTPUT_SIZE: 21 32 | NUM_WORKERS: 8 33 | LOC_WEIGHT: 1.0 34 | CLS_WEIGHT: 1.0 35 | SHAPE_WEIGHT: 1.0 36 | w1: 1.2 37 | w2: 1.0 38 | w3: 1.0 39 | w4: 1.0 40 | w5: 1.0 41 | POS_NUM : 16 42 | TOTAL_NUM : 64 43 | NEG_NUM : 16 44 | LARGER: 1.0 45 | range : 1.0 46 | LR: 47 | TYPE: 'log' 48 | KWARGS: 49 | start_lr: 0.01 50 | end_lr: 0.0005 51 | 52 | LR_WARMUP: 53 | TYPE: 'step' 54 | EPOCH: 5 55 | KWARGS: 56 | start_lr: 0.005 57 | end_lr: 0.01 58 | step: 1 59 | 60 | DATASET: 61 | NAMES: 62 | - 'VID' 63 | - 'COCO' 64 | - 'GOT' 65 | - 'YOUTUBEBB' 66 | 67 | 68 | TEMPLATE: 69 | SHIFT: 4 70 | SCALE: 0.05 71 | BLUR: 0.0 72 | FLIP: 0.0 73 | COLOR: 1.0 74 | 75 | SEARCH: 76 | SHIFT: 64 77 | SCALE: 0.18 78 | BLUR: 0.2 79 | FLIP: 0.0 80 | COLOR: 1.0 81 | 82 | NEG: 0.05 83 | GRAY: 0.0 84 | -------------------------------------------------------------------------------- /experiments/SiamBAN/README.md: -------------------------------------------------------------------------------- 1 | # SiamBAN 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model put it into this folder. 6 | The model of SiamBAN can be found at: [model.pth](https://pan.baidu.com/s/14bHbGZd4R9q4ye2HjCkOVw?pwd=siam)(code:siam). 7 | You can also download it from official code site: https://github.com/hqucv/siamban . 
8 | -------------------------------------------------------------------------------- /experiments/SiamBAN/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamban_r50_l234" 2 | 3 | BACKBONE: 4 | TYPE: "resnet50" 5 | KWARGS: 6 | used_layers: [2, 3, 4] 7 | PRETRAINED: 'pretrained_models/resnet50.model' 8 | TRAIN_LAYERS: ['layer2', 'layer3', 'layer4'] 9 | TRAIN_EPOCH: 10 10 | LAYERS_LR: 0.1 11 | 12 | ADJUST: 13 | ADJUST: True 14 | TYPE: "AdjustAllLayer" 15 | KWARGS: 16 | in_channels: [512, 1024, 2048] 17 | out_channels: [256, 256, 256] 18 | 19 | BAN: 20 | BAN: True 21 | TYPE: 'MultiBAN' 22 | KWARGS: 23 | in_channels: [256, 256, 256] 24 | cls_out_channels: 2 # if use sigmoid cls, cls_out_channel = 1 else 2 25 | weighted: True 26 | 27 | POINT: 28 | STRIDE: 8 29 | 30 | TRACK: 31 | TYPE: 'SiamBANTracker' 32 | WINDOW_INFLUENCE: 0.4632532824922313 # VOT2018 33 | PENALTY_K: 0.08513642556896711 # VOT2018 34 | LR: 0.44418184746462425 # VOT2018 35 | # WINDOW_INFLUENCE: 0.334450048565355 # VOT2019 36 | # PENALTY_K: 0.0012159181005195463 # VOT2019 37 | # LR: 0.46386814967815493 # VOT2019 38 | EXEMPLAR_SIZE: 127 39 | INSTANCE_SIZE: 255 40 | BASE_SIZE: 8 41 | CONTEXT_AMOUNT: 0.5 42 | 43 | TRAIN: 44 | EPOCH: 20 45 | START_EPOCH: 0 # 0 or resume checkpoint 46 | BATCH_SIZE: 28 47 | BASE_LR: 0.005 48 | CLS_WEIGHT: 1.0 49 | LOC_WEIGHT: 1.0 50 | RESUME: '' # '' or 'snapshot/checkpoint_e.pth' 51 | 52 | LR: 53 | TYPE: 'log' 54 | KWARGS: 55 | start_lr: 0.005 56 | end_lr: 0.00005 57 | LR_WARMUP: 58 | TYPE: 'step' 59 | EPOCH: 5 60 | KWARGS: 61 | start_lr: 0.001 62 | end_lr: 0.005 63 | step: 1 64 | 65 | DATASET: 66 | NAMES: 67 | - 'VID' 68 | - 'YOUTUBEBB' 69 | - 'COCO' 70 | - 'DET' 71 | - 'GOT10K' 72 | - 'LASOT' 73 | 74 | VIDEOS_PER_EPOCH: 1000000 75 | 76 | TEMPLATE: 77 | SHIFT: 4 78 | SCALE: 0.05 79 | BLUR: 0.0 80 | FLIP: 0.0 81 | COLOR: 1.0 82 | 83 | SEARCH: 84 | SHIFT: 64 85 | SCALE: 0.18 86 | BLUR: 0.2 87 | FLIP: 0.0 88 | COLOR: 1.0 89 | 90 | NEG: 0.2 91 | GRAY: 0.0 92 | -------------------------------------------------------------------------------- /experiments/SiamGAT/README.md: -------------------------------------------------------------------------------- 1 | # SiamGAT 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model put it into this folder. 6 | The model of SiamGAT can be found at: [model.pth](https://pan.baidu.com/s/1EqBAqX_0tA_OaL--eRLHgQ?pwd=siam)(code:siam). 7 | You can also download it from official code site: https://github.com/ohhhyeahhh/SiamGAT . 
8 | -------------------------------------------------------------------------------- /experiments/SiamGAT/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamgat_googlenet" 2 | 3 | BACKBONE: 4 | TYPE: "googlenet_ou" 5 | PRETRAINED: 'pretrained_models/inception_v3.pth' 6 | TRAIN_LAYERS: ['Mixed_5b','Mixed_5c','Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'channel_reduce'] 7 | CHANNEL_REDUCE_LAYERS: ['channel_reduce'] 8 | TRAIN_EPOCH: 10 9 | CROP_PAD: 4 10 | LAYERS_LR: 0.1 11 | 12 | TRACK: 13 | TYPE: 'SiamGATTracker' 14 | EXEMPLAR_SIZE: 127 15 | INSTANCE_SIZE: 287 16 | SCORE_SIZE: 25 17 | CONTEXT_AMOUNT: 0.5 18 | STRIDE: 8 19 | OFFSET: 45 20 | 21 | TRAIN: 22 | EPOCH: 20 23 | START_EPOCH: 0 24 | SEARCH_SIZE: 287 25 | BATCH_SIZE: 76 26 | CLS_WEIGHT: 1.0 27 | LOC_WEIGHT: 3.0 28 | CEN_WEIGHT: 1.0 29 | RESUME: '' 30 | PRETRAINED: '' 31 | NUM_CLASSES: 2 32 | NUM_CONVS: 4 33 | PRIOR_PROB: 0.01 34 | OUTPUT_SIZE: 25 35 | ATTENTION: True 36 | 37 | LR: 38 | TYPE: 'log' 39 | KWARGS: 40 | start_lr: 0.01 41 | end_lr: 0.0005 42 | LR_WARMUP: 43 | TYPE: 'step' 44 | EPOCH: 5 45 | KWARGS: 46 | start_lr: 0.005 47 | end_lr: 0.01 48 | step: 1 49 | 50 | DATASET: 51 | NAMES: 52 | - 'VID' 53 | - 'YOUTUBEBB' 54 | - 'COCO' 55 | - 'DET' 56 | - 'GOT' 57 | 58 | VIDEOS_PER_EPOCH: 800000 59 | 60 | TEMPLATE: 61 | SHIFT: 4 62 | SCALE: 0.05 63 | BLUR: 0.0 64 | FLIP: 0.0 65 | COLOR: 1.0 66 | 67 | SEARCH: 68 | SHIFT: 64 69 | SCALE: 0.18 70 | BLUR: 0.2 71 | FLIP: 0.0 72 | COLOR: 1.0 73 | 74 | NEG: 0.2 75 | GRAY: 0.0 76 | -------------------------------------------------------------------------------- /experiments/SiamRPN++_mobilev2/README.md: -------------------------------------------------------------------------------- 1 | # SiamRPN++ 2 | 3 | ## Model prepare 4 | 5 | Before your own test, please download the pretrained model put it into this folder. 6 | The model of SiamRPN++_MobileNetV2 can be found at: [model.pth](https://pan.baidu.com/s/1kycg8HF_yimK4vwINTiUiQ?pwd=siam)(code:siam). 7 | You can also download it from official code site: https://github.com/STVIR/pysot . 
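The `config.yaml` files in the experiment folders above (and the SiamRPN++ one that follows) are plain YAML, so the tracking hyper-parameters can be inspected without touching the tracker code in `snot/`. The sketch below uses PyYAML directly and is not the config loader the trackers themselves use.

```python
import yaml

# Stand-alone inspection of an experiment configuration; the trackers load these
# files through their own config objects under snot/core instead.
with open('./experiments/SiamAPN/config.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['META_ARC'])
print(cfg['TRACK']['TYPE'], cfg['TRACK']['EXEMPLAR_SIZE'], cfg['TRACK']['INSTANCE_SIZE'])
```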
8 | -------------------------------------------------------------------------------- /experiments/SiamRPN++_mobilev2/config.yaml: -------------------------------------------------------------------------------- 1 | META_ARC: "siamrpn_mobilev2_l234_dwxcorr" 2 | 3 | BACKBONE: 4 | TYPE: "mobilenetv2" 5 | KWARGS: 6 | used_layers: [3, 5, 7] 7 | width_mult: 1.4 8 | 9 | ADJUST: 10 | ADJUST: true 11 | TYPE: "AdjustAllLayer" 12 | KWARGS: 13 | in_channels: [44, 134, 448] 14 | out_channels: [256, 256, 256] 15 | 16 | RPN: 17 | TYPE: 'MultiRPN' 18 | KWARGS: 19 | anchor_num: 5 20 | in_channels: [256, 256, 256] 21 | weighted: False 22 | 23 | MASK: 24 | MASK: False 25 | 26 | ANCHOR: 27 | STRIDE: 8 28 | RATIOS: [0.33, 0.5, 1, 2, 3] 29 | SCALES: [8] 30 | ANCHOR_NUM: 5 31 | 32 | TRACK: 33 | TYPE: 'SiamRPNTracker' 34 | PENALTY_K: 0.04 35 | WINDOW_INFLUENCE: 0.4 36 | LR: 0.5 37 | EXEMPLAR_SIZE: 127 38 | INSTANCE_SIZE: 255 39 | BASE_SIZE: 8 40 | CONTEXT_AMOUNT: 0.5 41 | -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN++_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN++_DCE.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN++_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN++_DCE_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN++_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN++_SCT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN++_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN++_SCT_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN_DCE.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN_DCE_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN_SCT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamAPN_SCT_CDT.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamAPN_SCT_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamBAN_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamBAN_DCE.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamBAN_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamBAN_DCE_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamBAN_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamBAN_SCT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamBAN_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamBAN_SCT_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamGAT_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamGAT_DCE.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamGAT_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamGAT_DCE_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamGAT_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamGAT_SCT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamGAT_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamGAT_SCT_CDT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamRPN++_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamRPN++_DCE.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamRPN++_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamRPN++_DCE_CDT.zip 
-------------------------------------------------------------------------------- /results/DarkTrack2021/SiamRPN++_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamRPN++_SCT.zip -------------------------------------------------------------------------------- /results/DarkTrack2021/SiamRPN++_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/DarkTrack2021/SiamRPN++_SCT_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN++_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN++_DCE.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN++_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN++_DCE_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN++_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN++_SCT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN++_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN++_SCT_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_DCE.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_DCE_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_CBDNet.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_CBDNet.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_CDT.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_DRUNet.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_DRUNet.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_DnCNN.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_DnCNN.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_FDnCNN.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_FDnCNN.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_FFDNet.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_FFDNet.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_IRCNN.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_IRCNN.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_NAFNet.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_NAFNet.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_Restormer.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_Restormer.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamAPN_SCT_Uformer.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamAPN_SCT_Uformer.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamBAN_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamBAN_DCE.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamBAN_DCE_CDT.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamBAN_DCE_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamBAN_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamBAN_SCT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamBAN_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamBAN_SCT_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamGAT_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamGAT_DCE.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamGAT_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamGAT_DCE_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamGAT_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamGAT_SCT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamGAT_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamGAT_SCT_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamRPN++_DCE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamRPN++_DCE.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamRPN++_DCE_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamRPN++_DCE_CDT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamRPN++_SCT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamRPN++_SCT.zip -------------------------------------------------------------------------------- /results/UAVDark135/SiamRPN++_SCT_CDT.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/results/UAVDark135/SiamRPN++_SCT_CDT.zip 
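The archives above contain the raw tracking results reported on UAVDark135 and DarkTrack2021. To reproduce the success and precision scores, unzip them into per-tracker folders under the result root and score them with `eval.py`; for example, assuming the dataset is stored under `/Dataset/UAVDark135` and the results were extracted to `./results/UAVDark135/<tracker_name>/`, an invocation along the lines of `python eval.py --dataset UAVDark135 --dataset_dir /Dataset/ --tracker_path ./results --tracker_prefix SiamAPN_SCT,SiamAPN_SCT_CDT` compares the SCT-only and SCT+CDT variants (the paths and tracker names here are illustrative).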
-------------------------------------------------------------------------------- /snot/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vision4robotics/CDT/05a4375d44471e9b603385bba9a303d00f298241/snot/core/__init__.py -------------------------------------------------------------------------------- /snot/core/xcorr.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def xcorr_slow(x, kernel): 11 | """for loop to calculate cross correlation, slow version 12 | """ 13 | batch = x.size()[0] 14 | out = [] 15 | for i in range(batch): 16 | px = x[i] 17 | pk = kernel[i] 18 | px = px.view(1, -1, px.size()[1], px.size()[2]) 19 | pk = pk.view(1, -1, pk.size()[1], pk.size()[2]) 20 | po = F.conv2d(px, pk) 21 | out.append(po) 22 | out = torch.cat(out, 0) 23 | return out 24 | 25 | 26 | def xcorr_fast(x, kernel): 27 | """group conv2d to calculate cross correlation, fast version 28 | """ 29 | batch = kernel.size()[0] 30 | pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3]) 31 | px = x.view(1, -1, x.size()[2], x.size()[3]) 32 | po = F.conv2d(px, pk, groups=batch) 33 | po = po.view(batch, -1, po.size()[2], po.size()[3]) 34 | return po 35 | 36 | 37 | def xcorr_depthwise(x, kernel): 38 | """depthwise cross correlation 39 | """ 40 | batch = kernel.size(0) 41 | channel = kernel.size(1) 42 | x = x.view(1, batch*channel, x.size(2), x.size(3)) 43 | kernel = kernel.view(batch*channel, 1, kernel.size(2), kernel.size(3)) 44 | out = F.conv2d(x, kernel, groups=batch*channel) 45 | out = out.view(batch, channel, out.size(2), out.size(3)) 46 | return out 47 | -------------------------------------------------------------------------------- /snot/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .uavdark import UAVDARKDataset 2 | from .darktrack import DARKTRACKDataset 3 | 4 | 5 | datapath = { 6 | 'UAVDark135':'/Dataset/UAVDark135', 7 | 'DarkTrack2021':'/Dataset/DarkTrack2021', 8 | } 9 | 10 | class DatasetFactory(object): 11 | @staticmethod 12 | def create_dataset(**kwargs): 13 | 14 | assert 'name' in kwargs, "should provide dataset name" 15 | name = kwargs['name'] 16 | if 'UAVDark' in name: 17 | dataset = UAVDARKDataset(**kwargs) 18 | elif 'DarkTrack2021' in name: 19 | dataset = DARKTRACKDataset(**kwargs) 20 | 21 | else: 22 | raise Exception("unknow dataset {}".format(kwargs['name'])) 23 | return dataset 24 | -------------------------------------------------------------------------------- /snot/datasets/darktrack.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | from glob import glob 5 | 6 | from .dataset import Dataset 7 | from .video import Video 8 | 9 | 10 | def loaddata(dataset_root): 11 | 12 | name_list=os.listdir(dataset_root+'/data_seq') 13 | name_list.sort() 14 | 15 | b=[] 16 | for i in range(len(name_list)): 17 | b.append(name_list[i]) 18 | c=[] 19 | 20 | for jj in range(len(name_list)): 21 | imgs=dataset_root+'/data_seq/'+str(name_list[jj]) 22 | txt=dataset_root+'/anno/'+str(name_list[jj])+'.txt' 23 | bbox=[] 24 | f = open(txt) 25 | file= f.readlines() 26 | li=os.listdir(imgs) 27 | li.sort() 28 | for ii in 
range(len(file)): 29 | try: 30 | li[ii]=name_list[jj]+'/'+li[ii] 31 | except: 32 | a=1 33 | 34 | line = file[ii].strip('\n').split(' ') 35 | 36 | 37 | if len(line)!=4: 38 | line = file[ii].strip('\n').split(',') 39 | if len(line)!=4: 40 | line = file[ii].strip('\n').split('\t') 41 | 42 | try: 43 | line[0]=int(line[0]) 44 | except: 45 | line[0]=float(line[0]) 46 | try: 47 | line[1]=int(line[1]) 48 | except: 49 | line[1]=float(line[1]) 50 | try: 51 | line[2]=int(line[2]) 52 | except: 53 | line[2]=float(line[2]) 54 | try: 55 | line[3]=int(line[3]) 56 | except: 57 | line[3]=float(line[3]) 58 | bbox.append(line) 59 | 60 | if len(bbox)!=len(li): 61 | print (jj) 62 | f.close() 63 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 64 | 65 | d=dict(zip(b,c)) 66 | 67 | return d 68 | 69 | 70 | class UAVVideo(Video): 71 | """ 72 | Args: 73 | name: video name 74 | root: dataset root 75 | video_dir: video directory 76 | init_rect: init rectangle 77 | img_names: image names 78 | gt_rect: groundtruth rectangle 79 | attr: attribute of video 80 | """ 81 | def __init__(self, name, root, video_dir, init_rect, img_names, 82 | gt_rect, attr, load_img=False): 83 | super(UAVVideo, self).__init__(name, root, video_dir, 84 | init_rect, img_names, gt_rect, attr, load_img) 85 | 86 | 87 | class DARKTRACKDataset(Dataset): 88 | """ 89 | Args: 90 | name: dataset name, should be 'UAV123', 'UAV20L' 91 | dataset_root: dataset root 92 | load_img: wether to load all imgs 93 | """ 94 | def __init__(self, name, dataset_root, load_img=False): 95 | super(DARKTRACKDataset, self).__init__(name, dataset_root) 96 | meta_data = loaddata(dataset_root) 97 | 98 | # load videos 99 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 100 | self.videos = {} 101 | for video in pbar: 102 | pbar.set_postfix_str(video) 103 | self.videos[video] = UAVVideo(video, 104 | dataset_root+'/data_seq', 105 | meta_data[video]['video_dir'], 106 | meta_data[video]['init_rect'], 107 | meta_data[video]['img_names'], 108 | meta_data[video]['gt_rect'], 109 | meta_data[video]['attr']) 110 | 111 | # set attr 112 | attr = [] 113 | for x in self.videos.values(): 114 | attr += x.attr 115 | attr = set(attr) 116 | self.attr = {} 117 | self.attr['ALL'] = list(self.videos.keys()) 118 | for x in attr: 119 | self.attr[x] = [] 120 | for k, v in self.videos.items(): 121 | for attr_ in v.attr: 122 | self.attr[attr_].append(k) 123 | -------------------------------------------------------------------------------- /snot/datasets/datapath.py: -------------------------------------------------------------------------------- 1 | datapath = { 2 | 'DTB70':'/Dataset/DTB70', 3 | 'UAV123':'/Dataset/UAV123', 4 | 'UAV10':'/Dataset/UAV123_10fps', 5 | 'UAV20':'/Dataset/UAV123_20L', 6 | 'UAVDT':'/Dataset/UAVDT', 7 | 'UAVTrack112':'/Dataset/UAVTrack112', 8 | 'VISDRONED2018':'/Dataset/VisDrone2018-SOT-test', 9 | 'VISDRONED2019':'/Dataset/VisDrone2019', 10 | 'UAVDark135':'/Dataset/UAVDark135_TSP_out' 11 | } -------------------------------------------------------------------------------- /snot/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | class Dataset(object): 4 | def __init__(self, name, dataset_root): 5 | self.name = name 6 | self.dataset_root = dataset_root 7 | self.videos = None 8 | 9 | def __getitem__(self, idx): 10 | if isinstance(idx, str): 11 | return self.videos[idx] 12 | elif isinstance(idx, int): 13 | return 
self.videos[sorted(list(self.videos.keys()))[idx]] 14 | 15 | def __len__(self): 16 | return len(self.videos) 17 | 18 | def __iter__(self): 19 | keys = sorted(list(self.videos.keys())) 20 | for key in keys: 21 | yield self.videos[key] 22 | 23 | def set_tracker(self, path, tracker_names): 24 | """ 25 | Args: 26 | path: path to tracker results, 27 | tracker_names: list of tracker name 28 | """ 29 | self.tracker_path = path 30 | self.tracker_names = tracker_names 31 | # for video in tqdm(self.videos.values(), 32 | # desc='loading tacker result', ncols=100): 33 | # video.load_tracker(path, tracker_names) 34 | -------------------------------------------------------------------------------- /snot/datasets/uavdark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | from glob import glob 5 | 6 | from .dataset import Dataset 7 | from .video import Video 8 | 9 | 10 | def loaddata(dataset_root): 11 | 12 | name_list=os.listdir(dataset_root+'/data_seq') 13 | name_list.sort() 14 | 15 | b=[] 16 | for i in range(len(name_list)): 17 | b.append(name_list[i]) 18 | c=[] 19 | 20 | for jj in range(len(name_list)): 21 | imgs=dataset_root+'/data_seq/'+str(name_list[jj]) 22 | txt=dataset_root+'/anno/'+str(name_list[jj])+'.txt' 23 | bbox=[] 24 | f = open(txt) 25 | file= f.readlines() 26 | li=os.listdir(imgs) 27 | li.sort() 28 | for ii in range(len(file)): 29 | try: 30 | li[ii]=name_list[jj]+'/'+li[ii] 31 | except: 32 | a=1 33 | 34 | line = file[ii].strip('\n').split(' ') 35 | 36 | 37 | if len(line)!=4: 38 | line = file[ii].strip('\n').split(',') 39 | if len(line)!=4: 40 | line = file[ii].strip('\n').split('\t') 41 | 42 | try: 43 | line[0]=int(line[0]) 44 | except: 45 | line[0]=float(line[0]) 46 | try: 47 | line[1]=int(line[1]) 48 | except: 49 | line[1]=float(line[1]) 50 | try: 51 | line[2]=int(line[2]) 52 | except: 53 | line[2]=float(line[2]) 54 | try: 55 | line[3]=int(line[3]) 56 | except: 57 | line[3]=float(line[3]) 58 | bbox.append(line) 59 | 60 | if len(bbox)!=len(li): 61 | print (jj) 62 | f.close() 63 | c.append({'attr':[],'gt_rect':bbox,'img_names':li,'init_rect':bbox[0],'video_dir':name_list[jj]}) 64 | 65 | d=dict(zip(b,c)) 66 | 67 | return d 68 | 69 | 70 | class UAVVideo(Video): 71 | """ 72 | Args: 73 | name: video name 74 | root: dataset root 75 | video_dir: video directory 76 | init_rect: init rectangle 77 | img_names: image names 78 | gt_rect: groundtruth rectangle 79 | attr: attribute of video 80 | """ 81 | def __init__(self, name, root, video_dir, init_rect, img_names, 82 | gt_rect, attr, load_img=False): 83 | super(UAVVideo, self).__init__(name, root, video_dir, 84 | init_rect, img_names, gt_rect, attr, load_img) 85 | 86 | 87 | class UAVDARKDataset(Dataset): 88 | """ 89 | Args: 90 | name: dataset name, should be 'UAV123', 'UAV20L' 91 | dataset_root: dataset root 92 | load_img: wether to load all imgs 93 | """ 94 | def __init__(self, name, dataset_root, load_img=False): 95 | super(UAVDARKDataset, self).__init__(name, dataset_root) 96 | meta_data = loaddata(dataset_root) 97 | 98 | # load videos 99 | pbar = tqdm(meta_data.keys(), desc='loading '+name, ncols=100) 100 | self.videos = {} 101 | for video in pbar: 102 | pbar.set_postfix_str(video) 103 | self.videos[video] = UAVVideo(video, 104 | dataset_root+'/data_seq', 105 | meta_data[video]['video_dir'], 106 | meta_data[video]['init_rect'], 107 | meta_data[video]['img_names'], 108 | meta_data[video]['gt_rect'], 109 | meta_data[video]['attr']) 110 | 111 | # set attr 112 | attr = 
[] 113 | for x in self.videos.values(): 114 | attr += x.attr 115 | attr = set(attr) 116 | self.attr = {} 117 | self.attr['ALL'] = list(self.videos.keys()) 118 | for x in attr: 119 | self.attr[x] = [] 120 | for k, v in self.videos.items(): 121 | for attr_ in v.attr: 122 | self.attr[attr_].append(k) 123 | -------------------------------------------------------------------------------- /snot/datasets/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import re 4 | import numpy as np 5 | import json 6 | 7 | from glob import glob 8 | 9 | class Video(object): 10 | def __init__(self, name, root, video_dir, init_rect, img_names, 11 | gt_rect, attr, load_img=False): 12 | self.name = name 13 | self.video_dir = video_dir 14 | self.init_rect = init_rect 15 | self.gt_traj = gt_rect 16 | self.attr = attr 17 | self.pred_trajs = {} 18 | self.img_names = [os.path.join(root, x) for x in img_names] 19 | self.imgs = None 20 | 21 | if load_img: 22 | self.imgs = [cv2.imread(x) for x in self.img_names] 23 | self.width = self.imgs[0].shape[1] 24 | self.height = self.imgs[0].shape[0] 25 | else: 26 | img = cv2.imread(self.img_names[0]) 27 | assert img is not None, self.img_names[0] 28 | self.width = img.shape[1] 29 | self.height = img.shape[0] 30 | 31 | def load_tracker(self, path, tracker_names=None, store=True): 32 | """ 33 | Args: 34 | path(str): path to result 35 | tracker_name(list): name of tracker 36 | """ 37 | if not tracker_names: 38 | tracker_names = [x.split('/')[-1] for x in glob(path) 39 | if os.path.isdir(x)] 40 | if isinstance(tracker_names, str): 41 | tracker_names = [tracker_names] 42 | for name in tracker_names: 43 | traj_file = os.path.join(path, name, self.name+'.txt') 44 | if os.path.exists(traj_file): 45 | with open(traj_file, 'r') as f : 46 | pred_traj = [list(map(float, x.strip().split(','))) 47 | for x in f.readlines()] 48 | if len(pred_traj) != len(self.gt_traj): 49 | print(name, len(pred_traj), len(self.gt_traj), self.name) 50 | if store: 51 | self.pred_trajs[name] = pred_traj 52 | else: 53 | return pred_traj 54 | else: 55 | print(traj_file) 56 | self.tracker_names = list(self.pred_trajs.keys()) 57 | 58 | def load_img(self): 59 | if self.imgs is None: 60 | self.imgs = [cv2.imread(x) for x in self.img_names] 61 | self.width = self.imgs[0].shape[1] 62 | self.height = self.imgs[0].shape[0] 63 | 64 | def free_img(self): 65 | self.imgs = None 66 | 67 | def __len__(self): 68 | return len(self.img_names) 69 | 70 | def __getitem__(self, idx): 71 | if self.imgs is None: 72 | return cv2.imread(self.img_names[idx]), self.gt_traj[idx] 73 | else: 74 | return self.imgs[idx], self.gt_traj[idx] 75 | 76 | def __iter__(self): 77 | for i in range(len(self.img_names)): 78 | if self.imgs is not None: 79 | yield self.imgs[i], self.gt_traj[i] 80 | else: 81 | yield cv2.imread(self.img_names[i]), self.gt_traj[i] 82 | 83 | def draw_box(self, roi, img, linewidth, color, name=None): 84 | """ 85 | roi: rectangle or polygon 86 | img: numpy array img 87 | linewith: line width of the bbox 88 | """ 89 | if len(roi) > 6 and len(roi) % 2 == 0: 90 | pts = np.array(roi, np.int32).reshape(-1, 1, 2) 91 | color = tuple(map(int, color)) 92 | img = cv2.polylines(img, [pts], True, color, linewidth) 93 | pt = (pts[0, 0, 0], pts[0, 0, 1]-5) 94 | if name: 95 | img = cv2.putText(img, name, pt, cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1) 96 | elif len(roi) == 4: 97 | if not np.isnan(roi[0]): 98 | roi = list(map(int, roi)) 99 | color = tuple(map(int, color)) 100 | 
img = cv2.rectangle(img, (roi[0], roi[1]), (roi[0]+roi[2], roi[1]+roi[3]), 101 | color, linewidth) 102 | if name: 103 | img = cv2.putText(img, name, (roi[0], roi[1]-5), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, color, 1) 104 | return img 105 | 106 | def show(self, pred_trajs={}, linewidth=2, show_name=False): 107 | """ 108 | pred_trajs: dict of pred_traj, {'tracker_name': list of traj} 109 | pred_traj should contain polygon or rectangle(x, y, width, height) 110 | linewith: line width of the bbox 111 | """ 112 | assert self.imgs is not None 113 | video = [] 114 | cv2.namedWindow(self.name, cv2.WINDOW_NORMAL) 115 | colors = {} 116 | if len(pred_trajs) == 0 and len(self.pred_trajs) > 0: 117 | pred_trajs = self.pred_trajs 118 | for i, (roi, img) in enumerate(zip(self.gt_traj, 119 | self.imgs[self.start_frame:self.end_frame+1])): 120 | img = img.copy() 121 | if len(img.shape) == 2: 122 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 123 | else: 124 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 125 | img = self.draw_box(roi, img, linewidth, (0, 255, 0), 126 | 'gt' if show_name else None) 127 | for name, trajs in pred_trajs.items(): 128 | if name not in colors: 129 | color = tuple(np.random.randint(0, 256, 3)) 130 | colors[name] = color 131 | else: 132 | color = colors[name] 133 | img = self.draw_box(trajs[0][i], img, linewidth, color, 134 | name if show_name else None) 135 | cv2.putText(img, str(i+self.start_frame), (5, 20), 136 | cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (255, 255, 0), 2) 137 | cv2.imshow(self.name, img) 138 | cv2.waitKey(40) 139 | video.append(img.copy()) 140 | return video 141 | -------------------------------------------------------------------------------- /snot/models/adapn/anchortarget.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | import torch as t 8 | from snot.core.config_adapn import cfg 9 | from snot.utils.bbox import IoU 10 | 11 | 12 | class AnchorTarget3_adapn(): 13 | def __init__(self): 14 | return 15 | 16 | def select(self,position, keep_num=16): 17 | num = position[0].shape[0] 18 | if num <= keep_num: 19 | return position, num 20 | slt = np.arange(num) 21 | np.random.shuffle(slt) 22 | slt = slt[:keep_num] 23 | return tuple(p[slt] for p in position), keep_num 24 | 25 | def filte(self,over,pos1,num): 26 | top_k_idx=over.argsort()[::-1][0:num] 27 | poss1=(pos1[0][top_k_idx],pos1[1][top_k_idx],pos1[2][top_k_idx]) 28 | return poss1 29 | 30 | def get(self, anchors,targets, size): 31 | num=cfg.TRAIN.BATCH_SIZE//cfg.TRAIN.NUM_GPU 32 | offset=cfg.TRAIN.SEARCH_SIZE//2-cfg.ANCHOR.STRIDE*(size-1)/2 33 | anchor_num=1 34 | cls = -1 * np.ones((num,anchor_num, size, size), dtype=np.int64) 35 | delta = np.zeros((num,4, anchor_num, size, size), dtype=np.float32) 36 | delta_weight = np.zeros((num,anchor_num, size, size), dtype=np.float32) 37 | overlap = np.zeros((num,anchor_num, size, size), dtype=np.float32) 38 | for i in range(num): 39 | anchor=anchors[i] 40 | target=targets[i].cpu().numpy() 41 | neg = cfg.DATASET.NEG and cfg.DATASET.NEG > np.random.random() 42 | # -1 ignore 0 negative 1 positive 43 | tcx = (target[0]+target[2])/2 44 | tcy= (target[1]+target[3])/2 45 | tw=target[2]-target[0] 46 | th=target[3]-target[1] 47 | if neg : 48 | cx = size // 2 49 | cy = size // 2 50 | cx += int(np.ceil((tcx - cfg.TRAIN.SEARCH_SIZE // 2) / 51 | 8 + 0.5)) 52 | cy += 
int(np.ceil((tcy - cfg.TRAIN.SEARCH_SIZE // 2) / 53 | 8 + 0.5)) 54 | l = max(0, cx - 3) 55 | r = min(size, cx + 4) 56 | u = max(0, cy - 3) 57 | d = min(size, cy + 4) 58 | cls[i,:, u:d,l:r ] = 0 59 | neg, neg_num = self.select(np.where(cls[i][0] == 0), cfg.TRAIN.NEG_NUM) 60 | cls[i] = -1 61 | cls[i][0][neg] = 0 62 | overlap[i] = np.zeros((anchor_num, size, size), dtype=np.float32) 63 | continue 64 | 65 | cx, cy, w, h = anchor[:,0].reshape(1,size,size),anchor[:,1].reshape(1,size,size),anchor[:,2].reshape(1,size,size),anchor[:,3].reshape(1,size,size) 66 | x1 = cx - w * 0.5 67 | y1 = cy - h * 0.5 68 | x2 = cx + w * 0.5 69 | y2 = cy + h * 0.5 70 | index=np.minimum(size-1,np.maximum(0,np.int32((target-offset)/cfg.ANCHOR.STRIDE))) 71 | ww=int(index[2]-index[0])+1 72 | hh=int(index[3]-index[1])+1 73 | labelcls2=np.zeros((1,size,size))-2 74 | labelcls2[0,np.maximum(0,index[1]-hh//cfg.TRAIN.labelcls2range1):np.minimum(size,index[3]+1+hh//cfg.TRAIN.labelcls2range1),\ 75 | np.maximum(0,index[0]-ww//cfg.TRAIN.labelcls2range1):np.minimum(size,index[2]+1+ww//cfg.TRAIN.labelcls2range1)]=-1 76 | labelcls2[0,index[1]:(index[3]+1),index[0]:(index[2]+1)]=0 77 | labelcls2[0,index[1]+hh//cfg.TRAIN.labelcls2range2:index[3]-hh//cfg.TRAIN.labelcls2range2+1,\ 78 | index[0]+ww//cfg.TRAIN.labelcls2range2:index[2]-ww//cfg.TRAIN.labelcls2range2+1]=0.5 79 | labelcls2[0,index[1]+hh//cfg.TRAIN.labelcls2range3:index[3]-hh//cfg.TRAIN.labelcls2range3+1,\ 80 | index[0]+ww//cfg.TRAIN.labelcls2range3:index[2]-ww//cfg.TRAIN.labelcls2range3+1]=1 81 | overlap[i] = IoU([x1, y1, x2, y2], target) 82 | pos1 = np.where((overlap[i] > 0.86)) 83 | neg1 = np.where((overlap[i] <= 0.6)) 84 | pos1, pos_num1 = self.select(pos1, cfg.TRAIN.POS_NUM) 85 | neg1, neg_num1 = self.select(neg1, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM) 86 | cls[i][pos1] = 1 87 | cls[i][neg1] = 0 88 | pos = np.where((overlap[i] > 0.83)|((overlap[i] > 0.8)&(labelcls2>=0.5))) 89 | neg = np.where((overlap[i] <= 0.6)) 90 | pos, pos_num = self.select(pos, cfg.TRAIN.POS_NUM) 91 | neg, neg_num = self.select(neg, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM) 92 | if anchor[:,2].min()>0 and anchor[:,3].min()>0: 93 | delta[i][0] = (tcx - cx) / (w+1e-6) 94 | delta[i][1] = (tcy - cy) / (h+1e-6) 95 | delta[i][2] = np.log(tw / (w+1e-6) + 1e-6) 96 | delta[i][3] = np.log(th / (h+1e-6) + 1e-6) 97 | delta_weight[i][pos] = 1. 
/ (pos_num + 1e-6) 98 | delta_weight[i][neg] =0 99 | 100 | cls=t.Tensor(cls).cuda() 101 | delta_weight=t.Tensor(delta_weight).cuda() 102 | delta=t.Tensor(delta).cuda() 103 | 104 | return cls, delta, delta_weight 105 | -------------------------------------------------------------------------------- /snot/models/adsiamapn_model.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/vision4robotics/SiamAPN 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import torch.nn as nn 8 | 9 | from snot.core.config_adapn import cfg 10 | from snot.models.backbone.alexnet import AlexNet_apn 11 | from snot.models.adapn.utile import ADAPN,clsandloc_adapn 12 | from snot.models.adapn.anchortarget import AnchorTarget3_adapn 13 | 14 | 15 | class ModelBuilderADAPN(nn.Module): 16 | def __init__(self): 17 | super(ModelBuilderADAPN, self).__init__() 18 | 19 | self.backbone = AlexNet_apn().cuda() 20 | self.grader=ADAPN(cfg).cuda() 21 | self.new=clsandloc_adapn(cfg).cuda() 22 | self.fin2=AnchorTarget3_adapn() 23 | 24 | def template(self, z): 25 | 26 | zf = self.backbone(z) 27 | self.zf=zf 28 | 29 | def track(self, x): 30 | 31 | xf = self.backbone(x) 32 | xff,ress=self.grader(xf,self.zf) 33 | 34 | self.ranchors=xff 35 | 36 | cls1,cls2,cls3,loc =self.new(xf,self.zf,ress) 37 | 38 | return { 39 | 'cls1': cls1, 40 | 'cls2': cls2, 41 | 'cls3': cls3, 42 | 'loc': loc 43 | } 44 | -------------------------------------------------------------------------------- /snot/models/apn/anchortarget.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | import torch as t 8 | from snot.core.config_apn import cfg 9 | from snot.utils.bbox import IoU 10 | 11 | 12 | class AnchorTarget3_apn(): 13 | def __init__(self): 14 | return 15 | 16 | def select(self,position, keep_num=16): 17 | num = position[0].shape[0] 18 | if num <= keep_num: 19 | return position, num 20 | slt = np.arange(num) 21 | np.random.shuffle(slt) 22 | slt = slt[:keep_num] 23 | return tuple(p[slt] for p in position), keep_num 24 | 25 | def get(self, anchors,targets, size): 26 | num=cfg.TRAIN.BATCH_SIZE//cfg.TRAIN.NUM_GPU 27 | anchor_num=1 28 | cls = -1 * np.ones((num,anchor_num, size, size), dtype=np.int64) 29 | delta = np.zeros((num,4, anchor_num, size, size), dtype=np.float32) 30 | delta_weight = np.zeros((num,anchor_num, size, size), dtype=np.float32) 31 | overlap = np.zeros((num,anchor_num, size, size), dtype=np.float32) 32 | for i in range(num): 33 | anchor=anchors[i] 34 | target=targets[i].cpu().numpy() 35 | index=np.minimum(size-1,np.maximum(0,np.int32((target-cfg.TRAIN.MOV)/cfg.TRAIN.STRIDE))) 36 | w=int(index[2]-index[0]) 37 | h=int(index[3]-index[1]) 38 | neg = cfg.DATASET.NEG and cfg.DATASET.NEG > np.random.random() 39 | # -1 ignore 0 negative 1 positive 40 | tcx = (target[0]+target[2])/2 41 | tcy= (target[1]+target[3])/2 42 | tw=target[2]-target[0] 43 | th=target[3]-target[1] 44 | if neg : 45 | cx = size // 2 46 | cy = size // 2 47 | cx += int(np.ceil((tcx - cfg.TRAIN.SEARCH_SIZE // 2) / 48 | 8 + 0.5)) 49 | cy += int(np.ceil((tcy - cfg.TRAIN.SEARCH_SIZE // 2) / 50 | 8 + 0.5)) 51 | l = max(0, cx - 3) 52 | r = min(size, cx + 4) 53 | u = 
max(0, cy - 3) 54 | d = min(size, cy + 4) 55 | cls[i,:, u:d,l:r ] = 0 56 | neg, neg_num = self.select(np.where(cls[i][0] == 0), cfg.TRAIN.NEG_NUM) 57 | cls[i] = -1 58 | cls[i][0][neg] = 0 59 | overlap[i] = np.zeros((anchor_num, size, size), dtype=np.float32) 60 | continue 61 | 62 | cx, cy, w, h = anchor[:,0].reshape(1,size,size),anchor[:,1].reshape(1,size,size),anchor[:,2].reshape(1,size,size),anchor[:,3].reshape(1,size,size) 63 | x1 = cx - w * 0.5 64 | y1 = cy - h * 0.5 65 | x2 = cx + w * 0.5 66 | y2 = cy + h * 0.5 67 | index=np.minimum(size-1,np.maximum(0,np.int32((target-cfg.TRAIN.MOV)/cfg.TRAIN.STRIDE))) 68 | ww=int(index[2]-index[0]) 69 | hh=int(index[3]-index[1]) 70 | labelcls2=np.zeros((1,size,size))-2 71 | labelcls2[0,index[1]:index[3]+1,index[0]:index[2]+1]=-1 72 | labelcls2[0,index[1]+hh//4:index[3]+1-hh//4,index[0]+ww//4:index[2]+1-ww//4]=1 73 | overlap[i] = IoU([x1, y1, x2, y2], target) 74 | pos1 = np.where((overlap[i] > 0.80)) 75 | neg1 = np.where((overlap[i] <= 0.5)) 76 | pos1, pos_num1 = self.select(pos1, cfg.TRAIN.POS_NUM) 77 | neg1, neg_num1 = self.select(neg1, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM) 78 | cls[i][pos1] = 1 79 | cls[i][neg1] = 0 80 | pos = np.where((overlap[i] > 0.72)) 81 | neg = np.where((overlap[i] <= 0.45)) 82 | pos, pos_num = self.select(pos, cfg.TRAIN.POS_NUM) 83 | neg, neg_num = self.select(neg, cfg.TRAIN.TOTAL_NUM - cfg.TRAIN.POS_NUM) 84 | if anchor[:,2].min()>0 and anchor[:,3].min()>0: 85 | delta[i][0] = (tcx - cx) / (w+1e-6) 86 | delta[i][1] = (tcy - cy) / (h+1e-6) 87 | delta[i][2] = np.log(tw / (w+1e-6) + 1e-6) 88 | delta[i][3] = np.log(th / (h+1e-6) + 1e-6) 89 | delta_weight[i][pos] = 1. / (pos_num + 1e-6) 90 | delta_weight[i][neg] =0 91 | 92 | cls=t.Tensor(cls).cuda() 93 | delta_weight=t.Tensor(delta_weight).cuda() 94 | delta=t.Tensor(delta).cuda() 95 | 96 | return cls, delta, delta_weight 97 | -------------------------------------------------------------------------------- /snot/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from snot.models.backbone.alexnet import alexnetlegacy, alexnet 7 | from snot.models.backbone.mobile_v2 import mobilenetv2 8 | from snot.models.backbone.resnet_atrous import resnet18, resnet34, resnet50 9 | from snot.models.backbone.googlenet import Inception3 10 | from snot.models.backbone.googlenet_ou import Inception3_ou 11 | 12 | BACKBONES = { 13 | 'alexnetlegacy': alexnetlegacy, 14 | 'mobilenetv2': mobilenetv2, 15 | 'resnet18': resnet18, 16 | 'resnet34': resnet34, 17 | 'resnet50': resnet50, 18 | 'alexnet': alexnet, 19 | 'googlenet': Inception3, 20 | 'googlenet_ou': Inception3_ou, 21 | } 22 | 23 | 24 | def get_backbone(name, **kwargs): 25 | return BACKBONES[name](**kwargs) 26 | -------------------------------------------------------------------------------- /snot/models/backbone/mobile_v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch.nn as nn 7 | 8 | 9 | def conv_bn(inp, oup, stride, padding=1): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, padding, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.ReLU6(inplace=True) 14 | ) 15 | 16 | 17 | class 
InvertedResidual(nn.Module): 18 | def __init__(self, inp, oup, stride, expand_ratio, dilation=1): 19 | super(InvertedResidual, self).__init__() 20 | self.stride = stride 21 | 22 | self.use_res_connect = self.stride == 1 and inp == oup 23 | 24 | padding = 2 - stride 25 | if dilation > 1: 26 | padding = dilation 27 | 28 | self.conv = nn.Sequential( 29 | # pw 30 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 31 | nn.BatchNorm2d(inp * expand_ratio), 32 | nn.ReLU6(inplace=True), 33 | # dw 34 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, 35 | stride, padding, dilation=dilation, 36 | groups=inp * expand_ratio, bias=False), 37 | nn.BatchNorm2d(inp * expand_ratio), 38 | nn.ReLU6(inplace=True), 39 | # pw-linear 40 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 41 | nn.BatchNorm2d(oup), 42 | ) 43 | 44 | def forward(self, x): 45 | if self.use_res_connect: 46 | return x + self.conv(x) 47 | else: 48 | return self.conv(x) 49 | 50 | 51 | class MobileNetV2(nn.Sequential): 52 | def __init__(self, width_mult=1.0, used_layers=[3, 5, 7]): 53 | super(MobileNetV2, self).__init__() 54 | 55 | self.interverted_residual_setting = [ 56 | # t, c, n, s 57 | [1, 16, 1, 1, 1], 58 | [6, 24, 2, 2, 1], 59 | [6, 32, 3, 2, 1], 60 | [6, 64, 4, 2, 1], 61 | [6, 96, 3, 1, 1], 62 | [6, 160, 3, 2, 1], 63 | [6, 320, 1, 1, 1], 64 | ] 65 | # 0,2,3,4,6 66 | 67 | self.interverted_residual_setting = [ 68 | # t, c, n, s 69 | [1, 16, 1, 1, 1], 70 | [6, 24, 2, 2, 1], 71 | [6, 32, 3, 2, 1], 72 | [6, 64, 4, 1, 2], 73 | [6, 96, 3, 1, 2], 74 | [6, 160, 3, 1, 4], 75 | [6, 320, 1, 1, 4], 76 | ] 77 | 78 | self.channels = [24, 32, 96, 320] 79 | self.channels = [int(c * width_mult) for c in self.channels] 80 | 81 | input_channel = int(32 * width_mult) 82 | self.last_channel = int(1280 * width_mult) \ 83 | if width_mult > 1.0 else 1280 84 | 85 | self.add_module('layer0', conv_bn(3, input_channel, 2, 0)) 86 | 87 | last_dilation = 1 88 | 89 | self.used_layers = used_layers 90 | 91 | for idx, (t, c, n, s, d) in \ 92 | enumerate(self.interverted_residual_setting, start=1): 93 | output_channel = int(c * width_mult) 94 | 95 | layers = [] 96 | 97 | for i in range(n): 98 | if i == 0: 99 | if d == last_dilation: 100 | dd = d 101 | else: 102 | dd = max(d // 2, 1) 103 | layers.append(InvertedResidual(input_channel, 104 | output_channel, s, t, dd)) 105 | else: 106 | layers.append(InvertedResidual(input_channel, 107 | output_channel, 1, t, d)) 108 | input_channel = output_channel 109 | 110 | last_dilation = d 111 | 112 | self.add_module('layer%d' % (idx), nn.Sequential(*layers)) 113 | 114 | def forward(self, x): 115 | outputs = [] 116 | for idx in range(8): 117 | name = "layer%d" % idx 118 | x = getattr(self, name)(x) 119 | outputs.append(x) 120 | p0, p1, p2, p3, p4 = [outputs[i] for i in [1, 2, 3, 5, 7]] 121 | out = [outputs[i] for i in self.used_layers] 122 | if len(out) == 1: 123 | return out[0] 124 | return out 125 | 126 | 127 | def mobilenetv2(**kwargs): 128 | return MobileNetV2(**kwargs) 129 | -------------------------------------------------------------------------------- /snot/models/dasiamrpn_model.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/foolwood/DaSiamRPN 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SiamRPN(nn.Module): 7 | def __init__(self, size=2, feature_out=512, anchor=5): 8 | configs = [3, 96, 256, 384, 384, 256] 9 | configs = list(map(lambda x: 3 if x==3 else x*size, configs)) 10 | 
feat_in = configs[-1] 11 | super(SiamRPN, self).__init__() 12 | self.featureExtract = nn.Sequential( 13 | nn.Conv2d(configs[0], configs[1] , kernel_size=11, stride=2), 14 | nn.BatchNorm2d(configs[1]), 15 | nn.MaxPool2d(kernel_size=3, stride=2), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(configs[1], configs[2], kernel_size=5), 18 | nn.BatchNorm2d(configs[2]), 19 | nn.MaxPool2d(kernel_size=3, stride=2), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(configs[2], configs[3], kernel_size=3), 22 | nn.BatchNorm2d(configs[3]), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(configs[3], configs[4], kernel_size=3), 25 | nn.BatchNorm2d(configs[4]), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(configs[4], configs[5], kernel_size=3), 28 | nn.BatchNorm2d(configs[5]), 29 | ) 30 | 31 | self.anchor = anchor 32 | self.feature_out = feature_out 33 | 34 | self.conv_r1 = nn.Conv2d(feat_in, feature_out*4*anchor, 3) 35 | self.conv_r2 = nn.Conv2d(feat_in, feature_out, 3) 36 | self.conv_cls1 = nn.Conv2d(feat_in, feature_out*2*anchor, 3) 37 | self.conv_cls2 = nn.Conv2d(feat_in, feature_out, 3) 38 | self.regress_adjust = nn.Conv2d(4*anchor, 4*anchor, 1) 39 | 40 | self.r1_kernel = [] 41 | self.cls1_kernel = [] 42 | 43 | self.cfg = {} 44 | 45 | def forward(self, x): 46 | x_f = self.featureExtract(x) 47 | return self.regress_adjust(F.conv2d(self.conv_r2(x_f), self.r1_kernel)), \ 48 | F.conv2d(self.conv_cls2(x_f), self.cls1_kernel) 49 | 50 | def temple(self, z): 51 | z_f = self.featureExtract(z) 52 | r1_kernel_raw = self.conv_r1(z_f) 53 | cls1_kernel_raw = self.conv_cls1(z_f) 54 | kernel_size = r1_kernel_raw.data.size()[-1] 55 | self.r1_kernel = r1_kernel_raw.view(self.anchor*4, self.feature_out, kernel_size, kernel_size) 56 | self.cls1_kernel = cls1_kernel_raw.view(self.anchor*2, self.feature_out, kernel_size, kernel_size) 57 | 58 | 59 | class SiamRPNBIG(SiamRPN): 60 | def __init__(self): 61 | super(SiamRPNBIG, self).__init__(size=2) 62 | self.cfg = {'lr':0.295, 'window_influence': 0.42, 'penalty_k': 0.055, 'instance_size': 271, 'adaptive': True} # 0.383 63 | 64 | 65 | class SiamRPNvot(SiamRPN): 66 | def __init__(self): 67 | super(SiamRPNvot, self).__init__(size=1, feature_out=256) 68 | self.cfg = {'lr':0.45, 'window_influence': 0.44, 'penalty_k': 0.04, 'instance_size': 271, 'adaptive': False} # 0.355 69 | 70 | 71 | class SiamRPNotb(SiamRPN): 72 | def __init__(self): 73 | super(SiamRPNotb, self).__init__(size=1, feature_out=256) 74 | self.cfg = {'lr': 0.30, 'window_influence': 0.40, 'penalty_k': 0.22, 'instance_size': 271, 'adaptive': False} # 0.655 75 | -------------------------------------------------------------------------------- /snot/models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from snot.models.head.mask import MaskCorr, Refine 7 | from snot.models.head.rpn import UPChannelRPN, DepthwiseRPN, MultiRPN 8 | from snot.models.head.ban import UPChannelBAN, DepthwiseBAN, MultiBAN 9 | 10 | RPNS = { 11 | 'UPChannelRPN': UPChannelRPN, 12 | 'DepthwiseRPN': DepthwiseRPN, 13 | 'MultiRPN': MultiRPN 14 | } 15 | 16 | MASKS = { 17 | 'MaskCorr': MaskCorr, 18 | } 19 | 20 | REFINE = { 21 | 'Refine': Refine, 22 | } 23 | 24 | BANS = { 25 | 'UPChannelBAN': UPChannelBAN, 26 | 'DepthwiseBAN': DepthwiseBAN, 27 | 'MultiBAN': MultiBAN 28 | } 29 | 30 | 31 | def get_rpn_head(name, **kwargs): 32 | return RPNS[name](**kwargs) 
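# A minimal, hypothetical use of the factory above; the keyword names follow
# MultiRPN.__init__ in snot/models/head/rpn.py, while the concrete values are
# illustrative rather than taken from a shipped config:
#   rpn_head = get_rpn_head('MultiRPN', anchor_num=5,
#                           in_channels=[256, 256, 256], weighted=True)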
33 | 34 | 35 | def get_mask_head(name, **kwargs): 36 | return MASKS[name](**kwargs) 37 | 38 | 39 | def get_refine_head(name): 40 | return REFINE[name]() 41 | 42 | 43 | def get_ban_head(name, **kwargs): 44 | return BANS[name](**kwargs) 45 | -------------------------------------------------------------------------------- /snot/models/head/ban.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from snot.core.xcorr import xcorr_fast, xcorr_depthwise 11 | 12 | class BAN(nn.Module): 13 | def __init__(self): 14 | super(BAN, self).__init__() 15 | 16 | def forward(self, z_f, x_f): 17 | raise NotImplementedError 18 | 19 | class UPChannelBAN(BAN): 20 | def __init__(self, feature_in=256, cls_out_channels=2): 21 | super(UPChannelBAN, self).__init__() 22 | 23 | cls_output = cls_out_channels 24 | loc_output = 4 25 | 26 | self.template_cls_conv = nn.Conv2d(feature_in, 27 | feature_in * cls_output, kernel_size=3) 28 | self.template_loc_conv = nn.Conv2d(feature_in, 29 | feature_in * loc_output, kernel_size=3) 30 | 31 | self.search_cls_conv = nn.Conv2d(feature_in, 32 | feature_in, kernel_size=3) 33 | self.search_loc_conv = nn.Conv2d(feature_in, 34 | feature_in, kernel_size=3) 35 | 36 | self.loc_adjust = nn.Conv2d(loc_output, loc_output, kernel_size=1) 37 | 38 | 39 | def forward(self, z_f, x_f): 40 | cls_kernel = self.template_cls_conv(z_f) 41 | loc_kernel = self.template_loc_conv(z_f) 42 | 43 | cls_feature = self.search_cls_conv(x_f) 44 | loc_feature = self.search_loc_conv(x_f) 45 | 46 | cls = xcorr_fast(cls_feature, cls_kernel) 47 | loc = self.loc_adjust(xcorr_fast(loc_feature, loc_kernel)) 48 | return cls, loc 49 | 50 | 51 | class DepthwiseXCorr(nn.Module): 52 | def __init__(self, in_channels, hidden, out_channels, kernel_size=3): 53 | super(DepthwiseXCorr, self).__init__() 54 | self.conv_kernel = nn.Sequential( 55 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 56 | nn.BatchNorm2d(hidden), 57 | nn.ReLU(inplace=True), 58 | ) 59 | self.conv_search = nn.Sequential( 60 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 61 | nn.BatchNorm2d(hidden), 62 | nn.ReLU(inplace=True), 63 | ) 64 | self.head = nn.Sequential( 65 | nn.Conv2d(hidden, hidden, kernel_size=1, bias=False), 66 | nn.BatchNorm2d(hidden), 67 | nn.ReLU(inplace=True), 68 | nn.Conv2d(hidden, out_channels, kernel_size=1) 69 | ) 70 | 71 | 72 | def forward(self, kernel, search): 73 | kernel = self.conv_kernel(kernel) 74 | search = self.conv_search(search) 75 | feature = xcorr_depthwise(search, kernel) 76 | out = self.head(feature) 77 | return out 78 | 79 | 80 | class DepthwiseBAN(BAN): 81 | def __init__(self, in_channels=256, out_channels=256, cls_out_channels=2, weighted=False): 82 | super(DepthwiseBAN, self).__init__() 83 | self.cls = DepthwiseXCorr(in_channels, out_channels, cls_out_channels) 84 | self.loc = DepthwiseXCorr(in_channels, out_channels, 4) 85 | 86 | def forward(self, z_f, x_f): 87 | cls = self.cls(z_f, x_f) 88 | loc = self.loc(z_f, x_f) 89 | return cls, loc 90 | 91 | 92 | class MultiBAN(BAN): 93 | def __init__(self, in_channels, cls_out_channels, weighted=False): 94 | super(MultiBAN, self).__init__() 95 | self.weighted = weighted 96 | for i in range(len(in_channels)): 97 | self.add_module('box'+str(i+2), 
DepthwiseBAN(in_channels[i], in_channels[i], cls_out_channels)) 98 | if self.weighted: 99 | self.cls_weight = nn.Parameter(torch.ones(len(in_channels))) 100 | self.loc_weight = nn.Parameter(torch.ones(len(in_channels))) 101 | self.loc_scale = nn.Parameter(torch.ones(len(in_channels))) 102 | 103 | def forward(self, z_fs, x_fs): 104 | cls = [] 105 | loc = [] 106 | for idx, (z_f, x_f) in enumerate(zip(z_fs, x_fs), start=2): 107 | box = getattr(self, 'box'+str(idx)) 108 | c, l = box(z_f, x_f) 109 | cls.append(c) 110 | loc.append(torch.exp(l*self.loc_scale[idx-2])) 111 | 112 | if self.weighted: 113 | cls_weight = F.softmax(self.cls_weight, 0) 114 | loc_weight = F.softmax(self.loc_weight, 0) 115 | 116 | def avg(lst): 117 | return sum(lst) / len(lst) 118 | 119 | def weighted_avg(lst, weight): 120 | s = 0 121 | for i in range(len(weight)): 122 | s += lst[i] * weight[i] 123 | return s 124 | 125 | if self.weighted: 126 | return weighted_avg(cls, cls_weight), weighted_avg(loc, loc_weight) 127 | else: 128 | return avg(cls), avg(loc) 129 | -------------------------------------------------------------------------------- /snot/models/head/car.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | from torch import nn 8 | import math 9 | 10 | 11 | class CAR(torch.nn.Module): 12 | def __init__(self, cfg, in_channels): 13 | """ 14 | Arguments: 15 | in_channels (int): number of channels of the input feature 16 | """ 17 | super(CAR, self).__init__() 18 | # TODO: Implement the sigmoid version first. 19 | num_classes = cfg.TRAIN.NUM_CLASSES 20 | 21 | cls_tower = [] 22 | bbox_tower = [] 23 | for i in range(cfg.TRAIN.NUM_CONVS): 24 | cls_tower.append( 25 | nn.Conv2d( 26 | in_channels, 27 | in_channels, 28 | kernel_size=3, 29 | stride=1, 30 | padding=1 31 | ) 32 | ) 33 | cls_tower.append(nn.GroupNorm(32, in_channels)) 34 | cls_tower.append(nn.ReLU()) 35 | bbox_tower.append( 36 | nn.Conv2d( 37 | in_channels, 38 | in_channels, 39 | kernel_size=3, 40 | stride=1, 41 | padding=1 42 | ) 43 | ) 44 | bbox_tower.append(nn.GroupNorm(32, in_channels)) 45 | bbox_tower.append(nn.ReLU()) 46 | 47 | self.add_module('cls_tower', nn.Sequential(*cls_tower)) 48 | self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) 49 | self.cls_logits = nn.Conv2d( 50 | in_channels, num_classes, kernel_size=3, stride=1, 51 | padding=1 52 | ) 53 | self.bbox_pred = nn.Conv2d( 54 | in_channels, 4, kernel_size=3, stride=1, 55 | padding=1 56 | ) 57 | self.centerness = nn.Conv2d( 58 | in_channels, 1, kernel_size=3, stride=1, 59 | padding=1 60 | ) 61 | 62 | # initialization 63 | for modules in [self.cls_tower, self.bbox_tower, 64 | self.cls_logits, self.bbox_pred, 65 | self.centerness]: 66 | for l in modules.modules(): 67 | if isinstance(l, nn.Conv2d): 68 | torch.nn.init.normal_(l.weight, std=0.01) 69 | torch.nn.init.constant_(l.bias, 0) 70 | 71 | # initialize the bias for focal loss 72 | prior_prob = cfg.TRAIN.PRIOR_PROB 73 | bias_value = -math.log((1 - prior_prob) / prior_prob) 74 | torch.nn.init.constant_(self.cls_logits.bias, bias_value) 75 | 76 | def forward(self, x): 77 | cls_tower = self.cls_tower(x) 78 | logits = self.cls_logits(cls_tower) 79 | centerness = self.centerness(cls_tower) 80 | bbox_reg = torch.exp(self.bbox_pred(self.bbox_tower(x))) 81 | 82 | return logits, bbox_reg, centerness 83 | 84 | 85 | class Scale(nn.Module): 
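    # Scale wraps a single learnable scalar and multiplies its input by it;
    # in CAR/FCOS-style heads one instance is typically kept per feature level
    # to rescale the regression output (inferred usage, not shown in this file).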
86 | def __init__(self, init_value=1.0): 87 | super(Scale, self).__init__() 88 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 89 | 90 | def forward(self, input): 91 | return input * self.scale 92 | -------------------------------------------------------------------------------- /snot/models/head/mask.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from snot.models.head.rpn import DepthwiseXCorr 10 | from snot.core.xcorr import xcorr_depthwise 11 | 12 | 13 | class MaskCorr(DepthwiseXCorr): 14 | def __init__(self, in_channels, hidden, out_channels, 15 | kernel_size=3, hidden_kernel_size=5): 16 | super(MaskCorr, self).__init__(in_channels, hidden, 17 | out_channels, kernel_size, 18 | hidden_kernel_size) 19 | 20 | def forward(self, kernel, search): 21 | kernel = self.conv_kernel(kernel) 22 | search = self.conv_search(search) 23 | feature = xcorr_depthwise(search, kernel) 24 | out = self.head(feature) 25 | return out, feature 26 | 27 | 28 | class Refine(nn.Module): 29 | def __init__(self): 30 | super(Refine, self).__init__() 31 | self.v0 = nn.Sequential( 32 | nn.Conv2d(64, 16, 3, padding=1), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(16, 4, 3, padding=1), 35 | nn.ReLU(inplace=True), 36 | ) 37 | self.v1 = nn.Sequential( 38 | nn.Conv2d(256, 64, 3, padding=1), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(64, 16, 3, padding=1), 41 | nn.ReLU(inplace=True), 42 | ) 43 | self.v2 = nn.Sequential( 44 | nn.Conv2d(512, 128, 3, padding=1), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(128, 32, 3, padding=1), 47 | nn.ReLU(inplace=True), 48 | ) 49 | self.h2 = nn.Sequential( 50 | nn.Conv2d(32, 32, 3, padding=1), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(32, 32, 3, padding=1), 53 | nn.ReLU(inplace=True), 54 | ) 55 | self.h1 = nn.Sequential( 56 | nn.Conv2d(16, 16, 3, padding=1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(16, 16, 3, padding=1), 59 | nn.ReLU(inplace=True), 60 | ) 61 | self.h0 = nn.Sequential( 62 | nn.Conv2d(4, 4, 3, padding=1), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(4, 4, 3, padding=1), 65 | nn.ReLU(inplace=True), 66 | ) 67 | 68 | self.deconv = nn.ConvTranspose2d(256, 32, 15, 15) 69 | self.post0 = nn.Conv2d(32, 16, 3, padding=1) 70 | self.post1 = nn.Conv2d(16, 4, 3, padding=1) 71 | self.post2 = nn.Conv2d(4, 1, 3, padding=1) 72 | 73 | def forward(self, f, corr_feature, pos): 74 | p0 = F.pad(f[0], [16, 16, 16, 16])[:, :, 4*pos[0]:4*pos[0]+61, 4*pos[1]:4*pos[1]+61] 75 | p1 = F.pad(f[1], [8, 8, 8, 8])[:, :, 2*pos[0]:2*pos[0]+31, 2*pos[1]:2*pos[1]+31] 76 | p2 = F.pad(f[2], [4, 4, 4, 4])[:, :, pos[0]:pos[0]+15, pos[1]:pos[1]+15] 77 | 78 | p3 = corr_feature[:, :, pos[0], pos[1]].view(-1, 256, 1, 1) 79 | 80 | out = self.deconv(p3) 81 | out = self.post0(F.upsample(self.h2(out) + self.v2(p2), size=(31, 31))) 82 | out = self.post1(F.upsample(self.h1(out) + self.v1(p1), size=(61, 61))) 83 | out = self.post2(F.upsample(self.h0(out) + self.v0(p0), size=(127, 127))) 84 | out = out.view(-1, 127*127) 85 | return out 86 | -------------------------------------------------------------------------------- /snot/models/head/rpn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 
from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from snot.core.xcorr import xcorr_fast, xcorr_depthwise 11 | 12 | class RPN(nn.Module): 13 | def __init__(self): 14 | super(RPN, self).__init__() 15 | 16 | def forward(self, z_f, x_f): 17 | raise NotImplementedError 18 | 19 | class UPChannelRPN(RPN): 20 | def __init__(self, anchor_num=5, feature_in=256): 21 | super(UPChannelRPN, self).__init__() 22 | 23 | cls_output = 2 * anchor_num 24 | loc_output = 4 * anchor_num 25 | 26 | self.template_cls_conv = nn.Conv2d(feature_in, 27 | feature_in * cls_output, kernel_size=3) 28 | self.template_loc_conv = nn.Conv2d(feature_in, 29 | feature_in * loc_output, kernel_size=3) 30 | 31 | self.search_cls_conv = nn.Conv2d(feature_in, 32 | feature_in, kernel_size=3) 33 | self.search_loc_conv = nn.Conv2d(feature_in, 34 | feature_in, kernel_size=3) 35 | 36 | self.loc_adjust = nn.Conv2d(loc_output, loc_output, kernel_size=1) 37 | 38 | 39 | def forward(self, z_f, x_f): 40 | cls_kernel = self.template_cls_conv(z_f) 41 | loc_kernel = self.template_loc_conv(z_f) 42 | 43 | cls_feature = self.search_cls_conv(x_f) 44 | loc_feature = self.search_loc_conv(x_f) 45 | 46 | cls = xcorr_fast(cls_feature, cls_kernel) 47 | loc = self.loc_adjust(xcorr_fast(loc_feature, loc_kernel)) 48 | return cls, loc 49 | 50 | 51 | class DepthwiseXCorr(nn.Module): 52 | def __init__(self, in_channels, hidden, out_channels, kernel_size=3, hidden_kernel_size=5): 53 | super(DepthwiseXCorr, self).__init__() 54 | self.conv_kernel = nn.Sequential( 55 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 56 | nn.BatchNorm2d(hidden), 57 | nn.ReLU(inplace=True), 58 | ) 59 | self.conv_search = nn.Sequential( 60 | nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False), 61 | nn.BatchNorm2d(hidden), 62 | nn.ReLU(inplace=True), 63 | ) 64 | self.head = nn.Sequential( 65 | nn.Conv2d(hidden, hidden, kernel_size=1, bias=False), 66 | nn.BatchNorm2d(hidden), 67 | nn.ReLU(inplace=True), 68 | nn.Conv2d(hidden, out_channels, kernel_size=1) 69 | ) 70 | 71 | 72 | def forward(self, kernel, search): 73 | kernel = self.conv_kernel(kernel) 74 | search = self.conv_search(search) 75 | feature = xcorr_depthwise(search, kernel) 76 | out = self.head(feature) 77 | return out 78 | 79 | 80 | class DepthwiseRPN(RPN): 81 | def __init__(self, anchor_num=5, in_channels=256, out_channels=256): 82 | super(DepthwiseRPN, self).__init__() 83 | self.cls = DepthwiseXCorr(in_channels, out_channels, 2 * anchor_num) 84 | self.loc = DepthwiseXCorr(in_channels, out_channels, 4 * anchor_num) 85 | 86 | def forward(self, z_f, x_f): 87 | cls = self.cls(z_f, x_f) 88 | loc = self.loc(z_f, x_f) 89 | return cls, loc 90 | 91 | 92 | class MultiRPN(RPN): 93 | def __init__(self, anchor_num, in_channels, weighted=False): 94 | super(MultiRPN, self).__init__() 95 | self.weighted = weighted 96 | for i in range(len(in_channels)): 97 | self.add_module('rpn'+str(i+2), 98 | DepthwiseRPN(anchor_num, in_channels[i], in_channels[i])) 99 | if self.weighted: 100 | self.cls_weight = nn.Parameter(torch.ones(len(in_channels))) 101 | self.loc_weight = nn.Parameter(torch.ones(len(in_channels))) 102 | 103 | def forward(self, z_fs, x_fs): 104 | cls = [] 105 | loc = [] 106 | for idx, (z_f, x_f) in enumerate(zip(z_fs, x_fs), start=2): 107 | rpn = getattr(self, 'rpn'+str(idx)) 108 | c, l = rpn(z_f, x_f) 109 | cls.append(c) 110 | loc.append(l) 111 | 112 | if self.weighted: 113 | cls_weight = F.softmax(self.cls_weight, 0) 
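            # softmax over the learned per-level weights makes the fusion below a
            # convex combination of the RPN outputs from each backbone level,
            # rather than the plain mean used in the unweighted branch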
114 | loc_weight = F.softmax(self.loc_weight, 0) 115 | 116 | def avg(lst): 117 | return sum(lst) / len(lst) 118 | 119 | def weighted_avg(lst, weight): 120 | s = 0 121 | for i in range(len(weight)): 122 | s += lst[i] * weight[i] 123 | return s 124 | 125 | if self.weighted: 126 | return weighted_avg(cls, cls_weight), weighted_avg(loc, loc_weight) 127 | else: 128 | return avg(cls), avg(loc) 129 | -------------------------------------------------------------------------------- /snot/models/model_builder.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/STVIR/pysot 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import torch.nn as nn 8 | 9 | from snot.core.config import cfg 10 | from snot.models.backbone import get_backbone 11 | from snot.models.head import get_rpn_head, get_mask_head, get_refine_head 12 | from snot.models.neck import get_neck 13 | 14 | 15 | class ModelBuilder(nn.Module): 16 | def __init__(self): 17 | super(ModelBuilder, self).__init__() 18 | 19 | # build backbone 20 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 21 | **cfg.BACKBONE.KWARGS) 22 | 23 | # build adjust layer 24 | if cfg.ADJUST.ADJUST: 25 | self.neck = get_neck(cfg.ADJUST.TYPE, 26 | **cfg.ADJUST.KWARGS) 27 | 28 | # build rpn head 29 | self.rpn_head = get_rpn_head(cfg.RPN.TYPE, 30 | **cfg.RPN.KWARGS) 31 | 32 | # build mask head 33 | if cfg.MASK.MASK: 34 | self.mask_head = get_mask_head(cfg.MASK.TYPE, 35 | **cfg.MASK.KWARGS) 36 | 37 | if cfg.REFINE.REFINE: 38 | self.refine_head = get_refine_head(cfg.REFINE.TYPE) 39 | 40 | def template(self, z): 41 | zf = self.backbone(z) 42 | if cfg.MASK.MASK: 43 | zf = zf[-1] 44 | if cfg.ADJUST.ADJUST: 45 | zf = self.neck(zf) 46 | self.zf = zf 47 | 48 | def track(self, x): 49 | xf = self.backbone(x) 50 | if cfg.MASK.MASK: 51 | self.xf = xf[:-1] 52 | xf = xf[-1] 53 | if cfg.ADJUST.ADJUST: 54 | xf = self.neck(xf) 55 | cls, loc = self.rpn_head(self.zf, xf) 56 | if cfg.MASK.MASK: 57 | mask, self.mask_corr_feature = self.mask_head(self.zf, xf) 58 | return { 59 | 'cls': cls, 60 | 'loc': loc, 61 | 'mask': mask if cfg.MASK.MASK else None 62 | } 63 | 64 | def mask_refine(self, pos): 65 | return self.refine_head(self.xf, self.mask_corr_feature, pos) 66 | -------------------------------------------------------------------------------- /snot/models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from snot.models.neck.neck import AdjustLayer, AdjustAllLayer, BAN_AdjustLayer, BAN_AdjustAllLayer 7 | 8 | NECKS = { 9 | 'AdjustLayer': AdjustLayer, 10 | 'AdjustAllLayer': AdjustAllLayer 11 | } 12 | 13 | def get_neck(name, **kwargs): 14 | return NECKS[name](**kwargs) 15 | 16 | BAN_NECKS = { 17 | 'AdjustLayer': BAN_AdjustLayer, 18 | 'AdjustAllLayer': BAN_AdjustAllLayer 19 | } 20 | 21 | def get_ban_neck(name, **kwargs): 22 | return BAN_NECKS[name](**kwargs) 23 | -------------------------------------------------------------------------------- /snot/models/neck/neck.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 
from __future__ import unicode_literals 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class AdjustLayer(nn.Module): 10 | def __init__(self, in_channels, out_channels, center_size=7): 11 | super(AdjustLayer, self).__init__() 12 | self.downsample = nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 14 | nn.BatchNorm2d(out_channels), 15 | ) 16 | self.center_size = center_size 17 | 18 | def forward(self, x): 19 | x = self.downsample(x) 20 | if x.size(3) < 20: 21 | l = (x.size(3) - self.center_size) // 2 22 | r = l + self.center_size 23 | x = x[:, :, l:r, l:r] 24 | return x 25 | 26 | 27 | class AdjustAllLayer(nn.Module): 28 | def __init__(self, in_channels, out_channels, center_size=7): 29 | super(AdjustAllLayer, self).__init__() 30 | self.num = len(out_channels) 31 | if self.num == 1: 32 | self.downsample = AdjustLayer(in_channels[0], 33 | out_channels[0], 34 | center_size) 35 | else: 36 | for i in range(self.num): 37 | self.add_module('downsample'+str(i+2), 38 | AdjustLayer(in_channels[i], 39 | out_channels[i], 40 | center_size)) 41 | 42 | def forward(self, features): 43 | if self.num == 1: 44 | return self.downsample(features) 45 | else: 46 | out = [] 47 | for i in range(self.num): 48 | adj_layer = getattr(self, 'downsample'+str(i+2)) 49 | out.append(adj_layer(features[i])) 50 | return out 51 | 52 | 53 | class BAN_AdjustLayer(nn.Module): 54 | def __init__(self, in_channels, out_channels): 55 | super(BAN_AdjustLayer, self).__init__() 56 | self.downsample = nn.Sequential( 57 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 58 | nn.BatchNorm2d(out_channels), 59 | ) 60 | 61 | def forward(self, x): 62 | x = self.downsample(x) 63 | if x.size(3) < 20: 64 | l = 4 65 | r = l + 7 66 | x = x[:, :, l:r, l:r] 67 | return x 68 | 69 | 70 | class BAN_AdjustAllLayer(nn.Module): 71 | def __init__(self, in_channels, out_channels): 72 | super(BAN_AdjustAllLayer, self).__init__() 73 | self.num = len(out_channels) 74 | if self.num == 1: 75 | self.downsample = AdjustLayer(in_channels[0], out_channels[0]) 76 | else: 77 | for i in range(self.num): 78 | self.add_module('downsample'+str(i+2), 79 | AdjustLayer(in_channels[i], out_channels[i])) 80 | 81 | def forward(self, features): 82 | if self.num == 1: 83 | return self.downsample(features) 84 | else: 85 | out = [] 86 | for i in range(self.num): 87 | adj_layer = getattr(self, 'downsample'+str(i+2)) 88 | out.append(adj_layer(features[i])) 89 | return out 90 | -------------------------------------------------------------------------------- /snot/models/siamapn_model.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/vision4robotics/SiamAPN 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import torch.nn as nn 8 | 9 | from snot.core.config_apn import cfg 10 | from snot.models.backbone.alexnet import AlexNet_apn 11 | from snot.models.apn.utile import APN,clsandloc_apn 12 | from snot.models.apn.anchortarget import AnchorTarget3_apn 13 | 14 | 15 | class ModelBuilderAPN(nn.Module): 16 | def __init__(self): 17 | super(ModelBuilderAPN, self).__init__() 18 | 19 | self.backbone = AlexNet_apn().cuda() 20 | self.grader=APN(cfg).cuda() 21 | self.new=clsandloc_apn(cfg).cuda() 22 | self.fin2=AnchorTarget3_apn() 23 | 24 | def template(self, z): 25 | 26 | zf1,zf = self.backbone(z) 27 | self.zf=zf 28 | self.zf1=zf1 29 | 30 | def 
track(self, x): 31 | 32 | xf1,xf = self.backbone(x) 33 | xff,ress=self.grader(xf1,self.zf1) 34 | 35 | self.ranchors=xff 36 | 37 | cls1,cls2,cls3,loc =self.new(xf,self.zf,ress) 38 | 39 | return { 40 | 'cls1': cls1, 41 | 'cls2': cls2, 42 | 'cls3': cls3, 43 | 'loc': loc 44 | } 45 | -------------------------------------------------------------------------------- /snot/models/siamban_model.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/hqucv/siamban 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import torch.nn as nn 8 | 9 | from snot.core.config_ban import cfg 10 | from snot.models.backbone import get_backbone 11 | from snot.models.head import get_ban_head 12 | from snot.models.neck import get_ban_neck 13 | 14 | 15 | class ModelBuilderBAN(nn.Module): 16 | def __init__(self): 17 | super(ModelBuilderBAN, self).__init__() 18 | 19 | # build backbone 20 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 21 | **cfg.BACKBONE.KWARGS) 22 | 23 | # build adjust layer 24 | if cfg.ADJUST.ADJUST: 25 | self.neck = get_ban_neck(cfg.ADJUST.TYPE, 26 | **cfg.ADJUST.KWARGS) 27 | 28 | # build ban head 29 | if cfg.BAN.BAN: 30 | self.head = get_ban_head(cfg.BAN.TYPE, 31 | **cfg.BAN.KWARGS) 32 | 33 | def template(self, z): 34 | zf = self.backbone(z) 35 | if cfg.ADJUST.ADJUST: 36 | zf = self.neck(zf) 37 | self.zf = zf 38 | 39 | def track(self, x): 40 | xf = self.backbone(x) 41 | if cfg.ADJUST.ADJUST: 42 | xf = self.neck(xf) 43 | cls, loc = self.head(self.zf, xf) 44 | return { 45 | 'cls': cls, 46 | 'loc': loc 47 | } 48 | -------------------------------------------------------------------------------- /snot/models/siamgat_model.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/ohhhyeahhh/SiamGAT 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from snot.core.config_gat import cfg 12 | from snot.models.backbone import get_backbone 13 | from snot.models.head.car import CAR 14 | 15 | 16 | class Graph_Attention_Union(nn.Module): 17 | def __init__(self, in_channel, out_channel): 18 | super(Graph_Attention_Union, self).__init__() 19 | 20 | # search region nodes linear transformation 21 | self.support = nn.Conv2d(in_channel, in_channel, 1, 1) 22 | 23 | # target template nodes linear transformation 24 | self.query = nn.Conv2d(in_channel, in_channel, 1, 1) 25 | 26 | # linear transformation for message passing 27 | self.g = nn.Sequential( 28 | nn.Conv2d(in_channel, in_channel, 1, 1), 29 | nn.BatchNorm2d(in_channel), 30 | nn.ReLU(inplace=True), 31 | ) 32 | 33 | # aggregated feature 34 | self.fi = nn.Sequential( 35 | nn.Conv2d(in_channel*2, out_channel, 1, 1), 36 | nn.BatchNorm2d(out_channel), 37 | nn.ReLU(inplace=True), 38 | ) 39 | 40 | def forward(self, zf, xf): 41 | # linear transformation 42 | xf_trans = self.query(xf) 43 | zf_trans = self.support(zf) 44 | 45 | # linear transformation for message passing 46 | xf_g = self.g(xf) 47 | zf_g = self.g(zf) 48 | 49 | # calculate similarity 50 | shape_x = xf_trans.shape 51 | shape_z = zf_trans.shape 52 | 53 | zf_trans_plain = zf_trans.view(-1, shape_z[1], shape_z[2] * shape_z[3]) 
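        # flatten spatial positions into graph nodes: zf_trans_plain stays
        # (N, C, Hz*Wz), while zf_g_plain / xf_trans_plain below are permuted to
        # (N, nodes, C); their matmul then yields a (search node x template node)
        # similarity map, softmax-normalised over the template nodes before the
        # template messages are aggregated into the search features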
54 | zf_g_plain = zf_g.view(-1, shape_z[1], shape_z[2] * shape_z[3]).permute(0, 2, 1) 55 | xf_trans_plain = xf_trans.view(-1, shape_x[1], shape_x[2] * shape_x[3]).permute(0, 2, 1) 56 | 57 | similar = torch.matmul(xf_trans_plain, zf_trans_plain) 58 | similar = F.softmax(similar, dim=2) 59 | 60 | embedding = torch.matmul(similar, zf_g_plain).permute(0, 2, 1) 61 | embedding = embedding.view(-1, shape_x[1], shape_x[2], shape_x[3]) 62 | 63 | # aggregated feature 64 | output = torch.cat([embedding, xf_g], 1) 65 | output = self.fi(output) 66 | return output 67 | 68 | 69 | class ModelBuilderGAT(nn.Module): 70 | def __init__(self): 71 | super(ModelBuilderGAT, self).__init__() 72 | 73 | # build backbone 74 | self.backbone = get_backbone(cfg.BACKBONE.TYPE, 75 | **cfg.BACKBONE.KWARGS) 76 | 77 | # build car head 78 | self.car_head = CAR(cfg, 256) 79 | 80 | # build response map 81 | self.attention = Graph_Attention_Union(256, 256) 82 | 83 | def template(self, z, roi): 84 | zf = self.backbone(z, roi) 85 | self.zf = zf 86 | 87 | def track(self, x): 88 | xf = self.backbone(x) 89 | 90 | features = self.attention(self.zf, xf) 91 | 92 | cls, loc, cen = self.car_head(features) 93 | return { 94 | 'cls': cls, 95 | 'loc': loc, 96 | 'cen': cen 97 | } 98 | -------------------------------------------------------------------------------- /snot/pipelines/pipeline_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | 8 | from snot.pipelines.siamapn_pipeline import SiamAPNPipeline 9 | from snot.pipelines.siamapnpp_pipeline import SiamAPNppPipeline 10 | from snot.pipelines.siamban_pipeline import SiamBANPipeline 11 | from snot.pipelines.siamgat_pipeline import SiamGATPipeline 12 | from snot.pipelines.siamrpn_pipeline import SiamRPNppPipeline 13 | 14 | 15 | TRACKERS = { 16 | 'SiamAPN': SiamAPNPipeline, 17 | 'SiamAPN++': SiamAPNppPipeline, 18 | 'SiamRPN++': SiamRPNppPipeline, 19 | 'SiamBAN': SiamBANPipeline, 20 | 'SiamGAT': SiamGATPipeline 21 | } 22 | 23 | def build_pipeline(args, enhancer, denoiser): 24 | return TRACKERS[args.trackername.split('_')[0]](args, enhancer, denoiser) 25 | 26 | -------------------------------------------------------------------------------- /snot/pipelines/siamban_pipeline.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | 8 | from snot.core.config_ban import cfg 9 | from snot.models.siamban_model import ModelBuilderBAN 10 | from snot.trackers.siamban_tracker import SiamBANTracker 11 | from snot.utils.bbox import get_axis_aligned_bbox 12 | from snot.utils.model_load import load_pretrain 13 | 14 | 15 | class DNS_SiamBANTracker(SiamBANTracker): 16 | def __init__(self, model, enhancer=None, denoiser=None): 17 | super(DNS_SiamBANTracker, self).__init__(model) 18 | 19 | self.model = model 20 | self.model.eval() 21 | 22 | self.enhancer = enhancer 23 | self.denoiser = denoiser 24 | 25 | def init(self, img, bbox): 26 | """ 27 | args: 28 | img(np.ndarray): BGR image 29 | bbox: (x, y, w, h) bbox 30 | """ 31 | self.center_pos = np.array([bbox[0]+(bbox[2]-1)/2, 32 | bbox[1]+(bbox[3]-1)/2]) 33 | self.size = np.array([bbox[2], bbox[3]]) 34 | 35 | # calculate z crop size 
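        # SiamFC-style context padding: with p = CONTEXT_AMOUNT * (w + h), the
        # exemplar crop side is s_z = sqrt((w + p) * (h + p)), i.e. the square
        # root of the padded box area computed just below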
36 | w_z = self.size[0] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size) 37 | h_z = self.size[1] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size) 38 | s_z = round(np.sqrt(w_z * h_z)) 39 | 40 | # calculate channle average 41 | self.channel_average = np.mean(img, axis=(0, 1)) 42 | 43 | # get crop 44 | z_crop = self.get_subwindow(img, self.center_pos, 45 | cfg.TRACK.EXEMPLAR_SIZE, 46 | s_z, self.channel_average) 47 | if self.enhancer is not None: 48 | z_crop = self.enhancer.enhance(z_crop) 49 | if self.denoiser is not None: 50 | z_crop = self.denoiser.denoise(z_crop) 51 | self.model.template(z_crop) 52 | 53 | def track(self, img): 54 | """ 55 | args: 56 | img(np.ndarray): BGR image 57 | return: 58 | bbox(list):[x, y, width, height] 59 | """ 60 | w_z = self.size[0] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size) 61 | h_z = self.size[1] + cfg.TRACK.CONTEXT_AMOUNT * np.sum(self.size) 62 | s_z = np.sqrt(w_z * h_z) 63 | scale_z = cfg.TRACK.EXEMPLAR_SIZE / s_z 64 | s_x = s_z * (cfg.TRACK.INSTANCE_SIZE / cfg.TRACK.EXEMPLAR_SIZE) 65 | x_crop = self.get_subwindow(img, self.center_pos, 66 | cfg.TRACK.INSTANCE_SIZE, 67 | round(s_x), self.channel_average) 68 | if self.enhancer is not None: 69 | x_crop = self.enhancer.enhance(x_crop) 70 | if self.denoiser is not None: 71 | x_crop = self.denoiser.denoise(x_crop) 72 | outputs = self.model.track(x_crop) 73 | 74 | score = self._convert_score(outputs['cls']) 75 | pred_bbox = self._convert_bbox(outputs['loc'], self.points) 76 | 77 | def change(r): 78 | return np.maximum(r, 1. / r) 79 | 80 | def sz(w, h): 81 | pad = (w + h) * 0.5 82 | return np.sqrt((w + pad) * (h + pad)) 83 | 84 | # scale penalty 85 | s_c = change(sz(pred_bbox[2, :], pred_bbox[3, :]) / 86 | (sz(self.size[0]*scale_z, self.size[1]*scale_z))) 87 | 88 | # aspect ratio penalty 89 | r_c = change((self.size[0]/self.size[1]) / 90 | (pred_bbox[2, :]/pred_bbox[3, :])) 91 | penalty = np.exp(-(r_c * s_c - 1) * cfg.TRACK.PENALTY_K) 92 | pscore = penalty * score 93 | 94 | # window penalty 95 | pscore = pscore * (1 - cfg.TRACK.WINDOW_INFLUENCE) + \ 96 | self.window * cfg.TRACK.WINDOW_INFLUENCE 97 | best_idx = np.argmax(pscore) 98 | bbox = pred_bbox[:, best_idx] / scale_z 99 | lr = penalty[best_idx] * score[best_idx] * cfg.TRACK.LR 100 | 101 | cx = bbox[0] + self.center_pos[0] 102 | cy = bbox[1] + self.center_pos[1] 103 | 104 | # smooth bbox 105 | width = self.size[0] * (1 - lr) + bbox[2] * lr 106 | height = self.size[1] * (1 - lr) + bbox[3] * lr 107 | 108 | # clip boundary 109 | cx, cy, width, height = self._bbox_clip(cx, cy, width, 110 | height, img.shape[:2]) 111 | 112 | # udpate state 113 | self.center_pos = np.array([cx, cy]) 114 | self.size = np.array([width, height]) 115 | 116 | bbox = [cx - width / 2, 117 | cy - height / 2, 118 | width, 119 | height] 120 | best_score = score[best_idx] 121 | return { 122 | 'bbox': bbox, 123 | 'best_score': best_score 124 | } 125 | 126 | 127 | class SiamBANPipeline(): 128 | def __init__(self, args, enhancer=None, denoiser=None): 129 | super(SiamBANPipeline, self).__init__() 130 | if not args.config: 131 | args.config = './experiments/SiamBAN/config.yaml' 132 | if not args.snapshot: 133 | args.snapshot = './experiments/SiamBAN/model.pth' 134 | 135 | cfg.merge_from_file(args.config) 136 | self.model = ModelBuilderBAN() 137 | self.model = load_pretrain(self.model, args.snapshot).cuda().eval() 138 | self.enhancer = enhancer 139 | self.denoiser = denoiser 140 | self.tracker = DNS_SiamBANTracker(self.model, self.enhancer, self.denoiser) 141 | 142 | def init(self, img, gt_bbox): 143 
| cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox)) 144 | gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h] 145 | self.tracker.init(img, gt_bbox_) 146 | pred_bbox = gt_bbox_ 147 | 148 | return pred_bbox 149 | 150 | def track(self, img): 151 | outputs = self.tracker.track(img) 152 | pred_bbox = outputs['bbox'] 153 | 154 | return pred_bbox -------------------------------------------------------------------------------- /snot/pipelines/siammask_pipeline.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | 8 | from snot.core.config import cfg 9 | from snot.models.model_builder import ModelBuilder 10 | from snot.trackers.tracker_builder import build_tracker 11 | from snot.utils.bbox import get_axis_aligned_bbox 12 | from snot.utils.model_load import load_pretrain 13 | 14 | 15 | class SiamMaskPipeline(): 16 | def __init__(self, args): 17 | super(SiamMaskPipeline, self).__init__() 18 | if not args.config: 19 | args.config = './experiments/SiamMask_r50/config.yaml' 20 | if not args.snapshot: 21 | args.snapshot = './experiments/SiamMask_r50/model.pth' 22 | 23 | cfg.merge_from_file(args.config) 24 | self.model = ModelBuilder() 25 | self.model = load_pretrain(self.model, args.snapshot).cuda().eval() 26 | self.tracker = build_tracker(self.model) 27 | 28 | def init(self, img, gt_bbox): 29 | cx, cy, w, h = get_axis_aligned_bbox(np.array(gt_bbox)) 30 | gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h] 31 | self.tracker.init(img, gt_bbox_) 32 | pred_bbox = gt_bbox_ 33 | 34 | return pred_bbox 35 | 36 | def track(self, img): 37 | outputs = self.tracker.track(img) 38 | pred_bbox = outputs['bbox'] 39 | 40 | return pred_bbox 41 | 42 | -------------------------------------------------------------------------------- /snot/trackers/base_tracker.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/STVIR/pysot 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import numpy as np 8 | import torch 9 | import cv2 10 | 11 | from snot.core.config import cfg 12 | 13 | class BaseTracker(object): 14 | """ Base tracker of single object tracking 15 | """ 16 | def init(self, img, bbox): 17 | """ 18 | args: 19 | img(np.ndarray): BGR image 20 | bbox(list): [x, y, width, height] 21 | x, y need to be 0-based 22 | """ 23 | raise NotImplementedError 24 | 25 | def track(self, img): 26 | """ 27 | args: 28 | img(np.ndarray): BGR image 29 | return: 30 | bbox(list):[x, y, width, height] 31 | """ 32 | raise NotImplementedError 33 | 34 | 35 | class SiameseTracker(BaseTracker): 36 | def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans): 37 | """ 38 | args: 39 | im: BGR image 40 | pos: center position 41 | model_sz: exemplar size 42 | original_sz: original size 43 | avg_chans: channel average 44 | """ 45 | if isinstance(pos, float): 46 | pos = [pos, pos] 47 | sz = original_sz 48 | im_sz = im.shape 49 | c = (original_sz + 1) / 2 50 | # context_xmin = round(pos[0] - c) # py2 and py3 round 51 | context_xmin = np.floor(pos[0] - c + 0.5) 52 | context_xmax = context_xmin + sz - 1 53 | # context_ymin = round(pos[1] - c) 54 | context_ymin = np.floor(pos[1] - c + 0.5) 55 | context_ymax = context_ymin + sz - 1 56 | 
left_pad = int(max(0., -context_xmin)) 57 | top_pad = int(max(0., -context_ymin)) 58 | right_pad = int(max(0., context_xmax - im_sz[1] + 1)) 59 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1)) 60 | 61 | context_xmin = context_xmin + left_pad 62 | context_xmax = context_xmax + left_pad 63 | context_ymin = context_ymin + top_pad 64 | context_ymax = context_ymax + top_pad 65 | 66 | r, c, k = im.shape 67 | if any([top_pad, bottom_pad, left_pad, right_pad]): 68 | size = (r + top_pad + bottom_pad, c + left_pad + right_pad, k) 69 | te_im = np.zeros(size, np.uint8) 70 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im 71 | if top_pad: 72 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans 73 | if bottom_pad: 74 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans 75 | if left_pad: 76 | te_im[:, 0:left_pad, :] = avg_chans 77 | if right_pad: 78 | te_im[:, c + left_pad:, :] = avg_chans 79 | im_patch = te_im[int(context_ymin):int(context_ymax + 1), 80 | int(context_xmin):int(context_xmax + 1), :] 81 | else: 82 | im_patch = im[int(context_ymin):int(context_ymax + 1), 83 | int(context_xmin):int(context_xmax + 1), :] 84 | 85 | if not np.array_equal(model_sz, original_sz): 86 | im_patch = cv2.resize(im_patch, (model_sz, model_sz)) 87 | im_patch = im_patch.transpose(2, 0, 1) 88 | im_patch = im_patch[np.newaxis, :, :, :] 89 | im_patch = im_patch.astype(np.float32) 90 | im_patch = torch.from_numpy(im_patch) 91 | if cfg.CUDA: 92 | im_patch = im_patch.cuda() 93 | return im_patch 94 | -------------------------------------------------------------------------------- /snot/trackers/tracker_builder.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/STVIR/pysot 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | from snot.core.config import cfg 8 | from snot.trackers.siamrpn_tracker import SiamRPNTracker 9 | from snot.trackers.siammask_tracker import SiamMaskTracker 10 | from snot.trackers.siamrpnlt_tracker import SiamRPNLTTracker 11 | 12 | TRACKS = { 13 | 'SiamRPNTracker': SiamRPNTracker, 14 | 'SiamMaskTracker': SiamMaskTracker, 15 | 'SiamRPNLTTracker': SiamRPNLTTracker 16 | } 17 | 18 | 19 | def build_tracker(model): 20 | return TRACKS[cfg.TRACK.TYPE](model) 21 | -------------------------------------------------------------------------------- /snot/trackers/tracker_builder_ban.py: -------------------------------------------------------------------------------- 1 | # Parts of this code come from https://github.com/hqucv/siamban 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | from snot.core.config_ban import cfg 8 | from snot.trackers.siamban_tracker import SiamBANTracker 9 | 10 | TRACKS = { 11 | 'SiamBANTracker': SiamBANTracker 12 | } 13 | 14 | 15 | def build_tracker_ban(model): 16 | return TRACKS[cfg.TRACK.TYPE](model) 17 | -------------------------------------------------------------------------------- /snot/utils/anchor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import math 7 | 8 | import numpy as np 9 | 10 | from snot.utils.bbox import 
corner2center, center2corner 11 | 12 | 13 | class Anchors: 14 | """ 15 | This class generates anchors. 16 | """ 17 | def __init__(self, stride, ratios, scales, image_center=0, size=0): 18 | self.stride = stride 19 | self.ratios = ratios 20 | self.scales = scales 21 | self.image_center = image_center 22 | self.size = size 23 | 24 | self.anchor_num = len(self.scales) * len(self.ratios) 25 | 26 | self.anchors = None 27 | 28 | self.generate_anchors() 29 | 30 | def generate_anchors(self): 31 | """ 32 | generate anchors based on predefined configuration 33 | """ 34 | self.anchors = np.zeros((self.anchor_num, 4), dtype=np.float32) 35 | size = self.stride * self.stride 36 | count = 0 37 | for r in self.ratios: 38 | ws = int(math.sqrt(size*1. / r)) 39 | hs = int(ws * r) 40 | 41 | for s in self.scales: 42 | w = ws * s 43 | h = hs * s 44 | self.anchors[count][:] = [-w*0.5, -h*0.5, w*0.5, h*0.5][:] 45 | count += 1 46 | 47 | def generate_all_anchors(self, im_c, size): 48 | """ 49 | im_c: image center 50 | size: image size 51 | """ 52 | if self.image_center == im_c and self.size == size: 53 | return False 54 | self.image_center = im_c 55 | self.size = size 56 | 57 | a0x = im_c - size // 2 * self.stride 58 | ori = np.array([a0x] * 4, dtype=np.float32) 59 | zero_anchors = self.anchors + ori 60 | 61 | x1 = zero_anchors[:, 0] 62 | y1 = zero_anchors[:, 1] 63 | x2 = zero_anchors[:, 2] 64 | y2 = zero_anchors[:, 3] 65 | 66 | x1, y1, x2, y2 = map(lambda x: x.reshape(self.anchor_num, 1, 1), 67 | [x1, y1, x2, y2]) 68 | cx, cy, w, h = corner2center([x1, y1, x2, y2]) 69 | 70 | disp_x = np.arange(0, size).reshape(1, 1, -1) * self.stride 71 | disp_y = np.arange(0, size).reshape(1, -1, 1) * self.stride 72 | 73 | cx = cx + disp_x 74 | cy = cy + disp_y 75 | 76 | # broadcast 77 | zero = np.zeros((self.anchor_num, size, size), dtype=np.float32) 78 | cx, cy, w, h = map(lambda x: x + zero, [cx, cy, w, h]) 79 | x1, y1, x2, y2 = center2corner([cx, cy, w, h]) 80 | 81 | self.all_anchors = (np.stack([x1, y1, x2, y2]).astype(np.float32), 82 | np.stack([cx, cy, w, h]).astype(np.float32)) 83 | return True 84 | -------------------------------------------------------------------------------- /snot/utils/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from collections import namedtuple 7 | 8 | import numpy as np 9 | 10 | 11 | Corner = namedtuple('Corner', 'x1 y1 x2 y2') 12 | # alias 13 | BBox = Corner 14 | Center = namedtuple('Center', 'x y w h') 15 | 16 | 17 | def corner2center(corner): 18 | """ convert (x1, y1, x2, y2) to (cx, cy, w, h) 19 | Args: 20 | corner: Corner or np.array (4*N) 21 | Return: 22 | Center or np.array (4 * N) 23 | """ 24 | if isinstance(corner, Corner): 25 | x1, y1, x2, y2 = corner 26 | return Center((x1 + x2) * 0.5, (y1 + y2) * 0.5, (x2 - x1), (y2 - y1)) 27 | else: 28 | x1, y1, x2, y2 = corner[0], corner[1], corner[2], corner[3] 29 | x = (x1 + x2) * 0.5 30 | y = (y1 + y2) * 0.5 31 | w = x2 - x1 32 | h = y2 - y1 33 | return x, y, w, h 34 | 35 | 36 | def center2corner(center): 37 | """ convert (cx, cy, w, h) to (x1, y1, x2, y2) 38 | Args: 39 | center: Center or np.array (4 * N) 40 | Return: 41 | Corner or np.array (4 * N) 42 | """ 43 | if isinstance(center, Center): 44 | x, y, w, h = center 45 | return Corner(x - w * 0.5, y - h * 0.5, x + w * 0.5, y + h * 0.5) 46 | else: 47 | x, y, w, h = center[0], 
center[1], center[2], center[3] 48 | x1 = x - w * 0.5 49 | y1 = y - h * 0.5 50 | x2 = x + w * 0.5 51 | y2 = y + h * 0.5 52 | return x1, y1, x2, y2 53 | 54 | 55 | def IoU(rect1, rect2): 56 | """ calculate intersection over union 57 | Args: 58 | rect1: (x1, y1, x2, y2) 59 | rect2: (x1, y1, x2, y2) 60 | Returns: 61 | iou 62 | """ 63 | # overlap 64 | x1, y1, x2, y2 = rect1[0], rect1[1], rect1[2], rect1[3] 65 | tx1, ty1, tx2, ty2 = rect2[0], rect2[1], rect2[2], rect2[3] 66 | 67 | xx1 = np.maximum(tx1, x1) 68 | yy1 = np.maximum(ty1, y1) 69 | xx2 = np.minimum(tx2, x2) 70 | yy2 = np.minimum(ty2, y2) 71 | 72 | ww = np.maximum(0, xx2 - xx1) 73 | hh = np.maximum(0, yy2 - yy1) 74 | 75 | area = (x2-x1) * (y2-y1) 76 | target_a = (tx2-tx1) * (ty2 - ty1) 77 | inter = ww * hh 78 | iou = inter / (area + target_a - inter) 79 | return iou 80 | 81 | 82 | def cxy_wh_2_rect(pos, sz): 83 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 0-index 84 | """ 85 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) 86 | 87 | 88 | def rect_2_cxy_wh(rect): 89 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 0-index 90 | """ 91 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), \ 92 | np.array([rect[2], rect[3]]) 93 | 94 | 95 | def cxy_wh_2_rect1(pos, sz): 96 | """ convert (cx, cy, w, h) to (x1, y1, w, h), 1-index 97 | """ 98 | return np.array([pos[0]-sz[0]/2+1, pos[1]-sz[1]/2+1, sz[0], sz[1]]) 99 | 100 | 101 | def rect1_2_cxy_wh(rect): 102 | """ convert (x1, y1, w, h) to (cx, cy, w, h), 1-index 103 | """ 104 | return np.array([rect[0]+rect[2]/2-1, rect[1]+rect[3]/2-1]), \ 105 | np.array([rect[2], rect[3]]) 106 | 107 | 108 | def get_axis_aligned_bbox(region): 109 | """ convert region to (cx, cy, w, h) represented by an axis-aligned box 110 | """ 111 | nv = region.size 112 | if nv == 8: 113 | cx = np.mean(region[0::2]) 114 | cy = np.mean(region[1::2]) 115 | x1 = min(region[0::2]) 116 | x2 = max(region[0::2]) 117 | y1 = min(region[1::2]) 118 | y2 = max(region[1::2]) 119 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * \ 120 | np.linalg.norm(region[2:4] - region[4:6]) 121 | A2 = (x2 - x1) * (y2 - y1) 122 | s = np.sqrt(A1 / A2) 123 | w = s * (x2 - x1) + 1 124 | h = s * (y2 - y1) + 1 125 | else: 126 | x = region[0] 127 | y = region[1] 128 | w = region[2] 129 | h = region[3] 130 | cx = x+w/2 131 | cy = y+h/2 132 | return cx, cy, w, h 133 | 134 | 135 | def get_min_max_bbox(region): 136 | """ convert region to (cx, cy, w, h) represented by a min-max box 137 | """ 138 | nv = region.size 139 | if nv == 8: 140 | cx = np.mean(region[0::2]) 141 | cy = np.mean(region[1::2]) 142 | x1 = min(region[0::2]) 143 | x2 = max(region[0::2]) 144 | y1 = min(region[1::2]) 145 | y2 = max(region[1::2]) 146 | w = x2 - x1 147 | h = y2 - y1 148 | else: 149 | x = region[0] 150 | y = region[1] 151 | w = region[2] 152 | h = region[3] 153 | cx = x+w/2 154 | cy = y+h/2 155 | return cx, cy, w, h 156 | -------------------------------------------------------------------------------- /snot/utils/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .ope_benchmark import OPEBenchmark 2 | -------------------------------------------------------------------------------- /snot/utils/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | 8 | from colorama import Fore, Style 
9 | 10 | __all__ = ['commit', 'describe'] 11 | 12 | 13 | def _exec(cmd): 14 | f = os.popen(cmd, 'r', 1) 15 | return f.read().strip() 16 | 17 | 18 | def _bold(s): 19 | return "\033[1m%s\033[0m" % s 20 | 21 | 22 | def _color(s): 23 | # return f'{Fore.RED}{s}{Style.RESET_ALL}' 24 | return "{}{}{}".format(Fore.RED,s,Style.RESET_ALL) 25 | 26 | 27 | def _describe(model, lines=None, spaces=0): 28 | head = " " * spaces 29 | for name, p in model.named_parameters(): 30 | if '.' in name: 31 | continue 32 | if p.requires_grad: 33 | name = _color(name) 34 | line = "{head}- {name}".format(head=head, name=name) 35 | lines.append(line) 36 | 37 | for name, m in model.named_children(): 38 | space_num = len(name) + spaces + 1 39 | if m.training: 40 | name = _color(name) 41 | line = "{head}.{name} ({type})".format( 42 | head=head, 43 | name=name, 44 | type=m.__class__.__name__) 45 | lines.append(line) 46 | _describe(m, lines, space_num) 47 | 48 | 49 | def commit(): 50 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) 51 | cmd = "cd {}; git log | head -n1 | awk '{{print $2}}'".format(root) 52 | commit = _exec(cmd) 53 | cmd = "cd {}; git log --oneline | head -n1".format(root) 54 | commit_log = _exec(cmd) 55 | return "commit : {}\n log : {}".format(commit, commit_log) 56 | 57 | 58 | def describe(net, name=None): 59 | num = 0 60 | lines = [] 61 | if name is not None: 62 | lines.append(name) 63 | num = len(name) 64 | _describe(net, lines, num) 65 | return "\n".join(lines) 66 | 67 | 68 | def bbox_clip(x, min_value, max_value): 69 | new_x = max(min_value, min(x, max_value)) 70 | return new_x 71 | -------------------------------------------------------------------------------- /snot/utils/model_load.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import logging 7 | 8 | import torch 9 | 10 | 11 | logger = logging.getLogger('global') 12 | 13 | 14 | def check_keys(model, pretrained_state_dict): 15 | ckpt_keys = set(pretrained_state_dict.keys()) 16 | model_keys = set(model.state_dict().keys()) 17 | used_pretrained_keys = model_keys & ckpt_keys 18 | unused_pretrained_keys = ckpt_keys - model_keys 19 | missing_keys = model_keys - ckpt_keys 20 | # filter 'num_batches_tracked' 21 | missing_keys = [x for x in missing_keys 22 | if not x.endswith('num_batches_tracked')] 23 | if len(missing_keys) > 0: 24 | logger.info('[Warning] missing keys: {}'.format(missing_keys)) 25 | logger.info('missing keys:{}'.format(len(missing_keys))) 26 | if len(unused_pretrained_keys) > 0: 27 | logger.info('[Warning] unused_pretrained_keys: {}'.format( 28 | unused_pretrained_keys)) 29 | logger.info('unused checkpoint keys:{}'.format( 30 | len(unused_pretrained_keys))) 31 | logger.info('used keys:{}'.format(len(used_pretrained_keys))) 32 | assert len(used_pretrained_keys) > 0, \ 33 | 'load NONE from pretrained checkpoint' 34 | return True 35 | 36 | 37 | def remove_prefix(state_dict, prefix): 38 | ''' Old style model is stored with all names of parameters 39 | share common prefix 'module.' 
''' 40 | logger.info('remove prefix \'{}\''.format(prefix)) 41 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 42 | return {f(key): value for key, value in state_dict.items()} 43 | 44 | 45 | def load_pretrain(model, pretrained_path): 46 | logger.info('load pretrained model from {}'.format(pretrained_path)) 47 | device = torch.cuda.current_device() 48 | pretrained_dict = torch.load(pretrained_path, 49 | map_location=lambda storage, loc: storage.cuda(device)) 50 | if "state_dict" in pretrained_dict.keys(): 51 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 52 | 'module.') 53 | else: 54 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 55 | 56 | try: 57 | check_keys(model, pretrained_dict) 58 | except: 59 | logger.info('[Warning]: using pretrain as features.\ 60 | Adding "features." as prefix') 61 | new_dict = {} 62 | for k, v in pretrained_dict.items(): 63 | k = 'features.' + k 64 | new_dict[k] = v 65 | pretrained_dict = new_dict 66 | check_keys(model, pretrained_dict) 67 | model.load_state_dict(pretrained_dict, strict=False) 68 | return model 69 | 70 | 71 | def restore_from(model, optimizer, ckpt_path): 72 | device = torch.cuda.current_device() 73 | ckpt = torch.load(ckpt_path, 74 | map_location=lambda storage, loc: storage.cuda(device)) 75 | epoch = ckpt['epoch'] 76 | 77 | ckpt_model_dict = remove_prefix(ckpt['state_dict'], 'module.') 78 | check_keys(model, ckpt_model_dict) 79 | model.load_state_dict(ckpt_model_dict, strict=False) 80 | 81 | check_keys(optimizer, ckpt['optimizer']) 82 | optimizer.load_state_dict(ckpt['optimizer']) 83 | return model, optimizer, epoch 84 | -------------------------------------------------------------------------------- /snot/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author fangyi.zhang@vipl.ict.ac.cn 3 | """ 4 | import numpy as np 5 | 6 | def overlap_ratio(rect1, rect2): 7 | '''Compute overlap ratio between two rects 8 | Args 9 | rect:2d array of N x [x,y,w,h] 10 | Return: 11 | iou 12 | ''' 13 | # if rect1.ndim==1: 14 | # rect1 = rect1[np.newaxis, :] 15 | # if rect2.ndim==1: 16 | # rect2 = rect2[np.newaxis, :] 17 | left = np.maximum(rect1[:,0], rect2[:,0]) 18 | right = np.minimum(rect1[:,0]+rect1[:,2], rect2[:,0]+rect2[:,2]) 19 | top = np.maximum(rect1[:,1], rect2[:,1]) 20 | bottom = np.minimum(rect1[:,1]+rect1[:,3], rect2[:,1]+rect2[:,3]) 21 | 22 | intersect = np.maximum(0,right - left) * np.maximum(0,bottom - top) 23 | union = rect1[:,2]*rect1[:,3] + rect2[:,2]*rect2[:,3] - intersect 24 | iou = intersect / union 25 | iou = np.maximum(np.minimum(1, iou), 0) 26 | return iou 27 | 28 | def success_overlap(gt_bb, result_bb, n_frame): 29 | thresholds_overlap = np.arange(0, 1.05, 0.05) 30 | success = np.zeros(len(thresholds_overlap)) 31 | iou = np.ones(len(gt_bb)) * (-1) 32 | # mask = np.sum(gt_bb > 0, axis=1) == 4 #TODO check all dataset 33 | mask = np.sum(gt_bb[:, 2:] > 0, axis=1) == 2 34 | iou[mask] = overlap_ratio(gt_bb[mask], result_bb[mask]) 35 | for i in range(len(thresholds_overlap)): 36 | success[i] = np.sum(iou > thresholds_overlap[i]) / float(n_frame) 37 | return success 38 | 39 | def success_error(gt_center, result_center, thresholds, n_frame): 40 | # n_frame = len(gt_center) 41 | success = np.zeros(len(thresholds)) 42 | dist = np.ones(len(gt_center)) * (-1) 43 | mask = np.sum(gt_center > 0, axis=1) == 2 44 | dist[mask] = np.sqrt(np.sum( 45 | np.power(gt_center[mask] - result_center[mask], 2), axis=1)) 46 | for i in 
range(len(thresholds)): 47 | success[i] = np.sum(dist <= thresholds[i]) / float(n_frame) 48 | return success 49 | 50 | 51 | -------------------------------------------------------------------------------- /snot/utils/utils_ad.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | 23 | 24 | def im_to_numpy(img): 25 | img = to_numpy(img) 26 | img = np.transpose(img, (1, 2, 0)) # H*W*C 27 | return img 28 | 29 | 30 | def im_to_torch(img): 31 | img = np.transpose(img, (2, 0, 1)) # C*H*W 32 | img = to_torch(img).float() 33 | return img 34 | 35 | 36 | def torch_to_img(img): 37 | img = to_numpy(torch.squeeze(img, 0)) 38 | img = np.transpose(img, (1, 2, 0)) # H*W*C 39 | return img 40 | 41 | 42 | def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans, out_mode='torch', new=False): 43 | 44 | if isinstance(pos, float): 45 | pos = [pos, pos] 46 | sz = original_sz 47 | im_sz = im.shape 48 | c = (original_sz+1) / 2 49 | context_xmin = round(pos[0] - c) # floor(pos(2) - sz(2) / 2); 50 | context_xmax = context_xmin + sz - 1 51 | context_ymin = round(pos[1] - c) # floor(pos(1) - sz(1) / 2); 52 | context_ymax = context_ymin + sz - 1 53 | left_pad = int(max(0., -context_xmin)) 54 | top_pad = int(max(0., -context_ymin)) 55 | right_pad = int(max(0., context_xmax - im_sz[1] + 1)) 56 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1)) 57 | 58 | context_xmin = context_xmin + left_pad 59 | context_xmax = context_xmax + left_pad 60 | context_ymin = context_ymin + top_pad 61 | context_ymax = context_ymax + top_pad 62 | 63 | # zzp: a more easy speed version 64 | r, c, k = im.shape 65 | if any([top_pad, bottom_pad, left_pad, right_pad]): 66 | te_im = np.zeros((r + top_pad + bottom_pad, c + left_pad + right_pad, k), np.uint8) # 0 is better than 1 initialization 67 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im 68 | if top_pad: 69 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans 70 | if bottom_pad: 71 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans 72 | if left_pad: 73 | te_im[:, 0:left_pad, :] = avg_chans 74 | if right_pad: 75 | te_im[:, c + left_pad:, :] = avg_chans 76 | im_patch_original = te_im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :] 77 | else: 78 | im_patch_original = im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :] 79 | 80 | if not np.array_equal(model_sz, original_sz): 81 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv to get a better speed 82 | else: 83 | im_patch = im_patch_original 84 | 85 | return im_to_torch(im_patch) if out_mode in 'torch' else im_patch 86 | -------------------------------------------------------------------------------- /snot/utils/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .draw_success_precision import draw_success_precision 2 | 
-------------------------------------------------------------------------------- /snot/utils/visualization/draw_success_precision.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | from .draw_utils import COLOR, LINE_STYLE 5 | 6 | def draw_success_precision(success_ret, name, videos, attr, precision_ret=None, 7 | norm_precision_ret=None, bold_name=None, axis=[0, 1]): 8 | # success plot 9 | fig, ax = plt.subplots() 10 | ax.grid(b=True) 11 | ax.set_aspect(1) 12 | plt.xlabel('Overlap threshold') 13 | plt.ylabel('Success rate') 14 | if attr == 'ALL': 15 | plt.title(r'\textbf{Success plots of OPE on %s}' % (name)) 16 | else: 17 | plt.title(r'\textbf{Success plots of OPE - %s}' % (attr)) 18 | plt.axis([0, 1]+axis) 19 | success = {} 20 | thresholds = np.arange(0, 1.05, 0.05) 21 | for tracker_name in success_ret.keys(): 22 | value = [v for k, v in success_ret[tracker_name].items() if k in videos] 23 | success[tracker_name] = np.mean(value) 24 | for idx, (tracker_name, auc) in \ 25 | enumerate(sorted(success.items(), key=lambda x:x[1], reverse=True)): 26 | if tracker_name == bold_name: 27 | label = r"\textbf{[%.3f] %s}" % (auc, tracker_name) 28 | else: 29 | label = "[%.3f] " % (auc) + tracker_name 30 | value = [v for k, v in success_ret[tracker_name].items() if k in videos] 31 | plt.plot(thresholds, np.mean(value, axis=0), 32 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 33 | ax.legend(loc='lower left', labelspacing=0.2) 34 | ax.autoscale(enable=True, axis='both', tight=True) 35 | xmin, xmax, ymin, ymax = plt.axis() 36 | ax.autoscale(enable=False) 37 | ymax += 0.03 38 | plt.axis([xmin, xmax, ymin, ymax]) 39 | plt.xticks(np.arange(xmin, xmax+0.01, 0.1)) 40 | plt.yticks(np.arange(ymin, ymax, 0.1)) 41 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 42 | plt.show() 43 | 44 | if precision_ret: 45 | # norm precision plot 46 | fig, ax = plt.subplots() 47 | ax.grid(b=True) 48 | ax.set_aspect(50) 49 | plt.xlabel('Location error threshold') 50 | plt.ylabel('Precision') 51 | if attr == 'ALL': 52 | plt.title(r'\textbf{Precision plots of OPE on %s}' % (name)) 53 | else: 54 | plt.title(r'\textbf{Precision plots of OPE - %s}' % (attr)) 55 | plt.axis([0, 50]+axis) 56 | precision = {} 57 | thresholds = np.arange(0, 51, 1) 58 | for tracker_name in precision_ret.keys(): 59 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos] 60 | precision[tracker_name] = np.mean(value, axis=0)[20] 61 | for idx, (tracker_name, pre) in \ 62 | enumerate(sorted(precision.items(), key=lambda x:x[1], reverse=True)): 63 | if tracker_name == bold_name: 64 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name) 65 | else: 66 | label = "[%.3f] " % (pre) + tracker_name 67 | value = [v for k, v in precision_ret[tracker_name].items() if k in videos] 68 | plt.plot(thresholds, np.mean(value, axis=0), 69 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 70 | ax.legend(loc='lower right', labelspacing=0.2) 71 | ax.autoscale(enable=True, axis='both', tight=True) 72 | xmin, xmax, ymin, ymax = plt.axis() 73 | ax.autoscale(enable=False) 74 | ymax += 0.03 75 | plt.axis([xmin, xmax, ymin, ymax]) 76 | plt.xticks(np.arange(xmin, xmax+0.01, 5)) 77 | plt.yticks(np.arange(ymin, ymax, 0.1)) 78 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 79 | plt.show() 80 | 81 | # norm precision plot 82 | if norm_precision_ret: 83 | fig, ax = plt.subplots() 84 | ax.grid(b=True) 85 | plt.xlabel('Location error 
threshold') 86 | plt.ylabel('Precision') 87 | if attr == 'ALL': 88 | plt.title(r'\textbf{Normalized Precision plots of OPE on %s}' % (name)) 89 | else: 90 | plt.title(r'\textbf{Normalized Precision plots of OPE - %s}' % (attr)) 91 | norm_precision = {} 92 | thresholds = np.arange(0, 51, 1) / 100 93 | for tracker_name in precision_ret.keys(): 94 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos] 95 | norm_precision[tracker_name] = np.mean(value, axis=0)[20] 96 | for idx, (tracker_name, pre) in \ 97 | enumerate(sorted(norm_precision.items(), key=lambda x:x[1], reverse=True)): 98 | if tracker_name == bold_name: 99 | label = r"\textbf{[%.3f] %s}" % (pre, tracker_name) 100 | else: 101 | label = "[%.3f] " % (pre) + tracker_name 102 | value = [v for k, v in norm_precision_ret[tracker_name].items() if k in videos] 103 | plt.plot(thresholds, np.mean(value, axis=0), 104 | color=COLOR[idx], linestyle=LINE_STYLE[idx],label=label, linewidth=2) 105 | ax.legend(loc='lower right', labelspacing=0.2) 106 | ax.autoscale(enable=True, axis='both', tight=True) 107 | xmin, xmax, ymin, ymax = plt.axis() 108 | ax.autoscale(enable=False) 109 | ymax += 0.03 110 | plt.axis([xmin, xmax, ymin, ymax]) 111 | plt.xticks(np.arange(xmin, xmax+0.01, 0.05)) 112 | plt.yticks(np.arange(ymin, ymax, 0.1)) 113 | ax.set_aspect((xmax - xmin)/(ymax-ymin)) 114 | plt.show() 115 | -------------------------------------------------------------------------------- /snot/utils/visualization/draw_utils.py: -------------------------------------------------------------------------------- 1 | 2 | COLOR = ((1, 0, 0), 3 | (0, 1, 0), 4 | (1, 0, 1), 5 | (1, 1, 0), 6 | (0 , 162/255, 232/255), 7 | (0.5, 0.5, 0.5), 8 | (0, 0, 1), 9 | (0, 1, 1), 10 | (136/255, 0 , 21/255), 11 | (255/255, 127/255, 39/255), 12 | (0, 0, 0)) 13 | 14 | LINE_STYLE = ['-', '--', ':', '-', '--', ':', '-', '--', ':', '-'] 15 | 16 | MARKER_STYLE = ['o', 'v', '<', '*', 'D', 'x', '.', 'x', '<', '.'] 17 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import argparse 7 | import os 8 | 9 | import cv2 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from snot.pipelines.pipeline_builder import build_pipeline 14 | from snot.datasets import DatasetFactory, datapath 15 | from denoiser.denoiser_builder import build_denoiser 16 | from enhancer.enhancer_builder import build_enhancer 17 | 18 | torch.set_num_threads(1) 19 | 20 | parser = argparse.ArgumentParser(description='siamese tracking') 21 | parser.add_argument('--dataset', default='', type=str, 22 | help='datasets') 23 | parser.add_argument('--datasetpath', default='', type=str, 24 | help='the path of datasets') 25 | parser.add_argument('--config', default='./experiments/SiamAPN/config.yaml', type=str, 26 | help='config file') 27 | parser.add_argument('--snapshot', default='./experiments/SiamAPN/model.pth', type=str, 28 | help='snapshot of models to eval') 29 | parser.add_argument('--trackername', default='SiamAPN', type=str, 30 | help='name of tracker') 31 | 32 | parser.add_argument('--e_weights', default='./experiments/SCT/model.pth', type=str, 33 | help='weights') 34 | parser.add_argument('--enhancername', default='SCT', type=str, 35 | help='name of enhancer') 36 | 37 | 
parser.add_argument('--d_weights', default='./experiments/CDT/model.pth', type=str, 38 | help='weights') 39 | parser.add_argument('--denoisername', default='CDT', type=str, 40 | help='name of denoiser') 41 | 42 | parser.add_argument('--video', default='', type=str, 43 | help='evaluate one specific video') 44 | parser.add_argument('--vis', default=True, action='store_true', 45 | help='whether to visualize results') 46 | args = parser.parse_args() 47 | 48 | 49 | 50 | def main(): 51 | if args.enhancername.split('-')[0]: 52 | enhancer = build_enhancer(args) 53 | else: 54 | enhancer = None 55 | if args.denoisername.split('-')[0]: 56 | denoiser = build_denoiser(args) 57 | else: 58 | denoiser = None 59 | pipeline = build_pipeline(args, enhancer=enhancer, denoiser=denoiser) 60 | 61 | 62 | for dataset_name in args.dataset.split(','): 63 | # create dataset 64 | try: 65 | dataset_root = args.datasetpath + datapath[dataset_name] 66 | except: 67 | print('Unknown dataset: {}'.format(dataset_name)) 68 | dataset = DatasetFactory.create_dataset(name=dataset_name, 69 | dataset_root=dataset_root, 70 | load_img=False) 71 | model_name = args.trackername 72 | 73 | # OPE tracking 74 | IDX = 0 75 | TOC = 0 76 | for v_idx, video in enumerate(dataset): 77 | if args.video != '': 78 | # test one specific video 79 | if video.name != args.video: 80 | continue 81 | toc = 0 82 | pred_bboxes = [] 83 | for idx, (img, gt_bbox) in enumerate(video): 84 | tic = cv2.getTickCount() 85 | if idx == 0: 86 | pred_bbox = pipeline.init(img, gt_bbox) 87 | pred_bboxes.append(pred_bbox) 88 | else: 89 | pred_bbox = pipeline.track(img) 90 | pred_bboxes.append(pred_bbox) 91 | toc += cv2.getTickCount() - tic 92 | if idx == 0: 93 | cv2.destroyAllWindows() 94 | if args.vis and idx > 0: 95 | try: 96 | gt_bbox = list(map(int, gt_bbox)) 97 | cv2.rectangle(img, (gt_bbox[0], gt_bbox[1]), 98 | (gt_bbox[0]+gt_bbox[2], gt_bbox[1]+gt_bbox[3]), (0, 255, 0), 3) 99 | except: 100 | pass 101 | pred_bbox = list(map(int, pred_bbox)) 102 | cv2.rectangle(img, (gt_bbox[0], gt_bbox[1]), 103 | (gt_bbox[0]+gt_bbox[2], gt_bbox[1]+gt_bbox[3]), (0, 255, 0), 3) 104 | cv2.rectangle(img, (pred_bbox[0], pred_bbox[1]), 105 | (pred_bbox[0]+pred_bbox[2], pred_bbox[1]+pred_bbox[3]), (0, 255, 255), 3) 106 | cv2.putText(img, str(idx), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2) 107 | cv2.imshow(video.name, img) 108 | cv2.waitKey(1) 109 | toc /= cv2.getTickFrequency() 110 | # save results 111 | model_path = os.path.join('results', dataset_name, model_name) 112 | if not os.path.isdir(model_path): 113 | os.makedirs(model_path) 114 | result_path = os.path.join(model_path, '{}.txt'.format(video.name)) 115 | with open(result_path, 'w') as f: 116 | for x in pred_bboxes: 117 | f.write(','.join([str(i) for i in x])+'\n') 118 | print('({:3d}) Video: {:12s} Time: {:5.1f}s Speed: {:3.1f}fps'.format(v_idx+1, video.name, toc, idx / toc)) 119 | IDX += idx 120 | TOC += toc 121 | print('Total Time: {:5.1f}s Average Speed: {:3.1f}fps'.format(TOC, IDX / TOC)) 122 | fps_path = os.path.join('results', dataset_name, '{}.txt'.format(model_name)) 123 | with open(fps_path, 'w') as f: 124 | f.write('Time:{:5.1f},Speed:{:3.1f}'.format(TOC, IDX / TOC)) 125 | 126 | if __name__ == '__main__': 127 | main() 128 | --------------------------------------------------------------------------------
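
Note: the sketch below is not part of the repository; it only illustrates how the builders and the init/track interface exposed in test.py and the pipeline classes above fit together. The frame paths, the initial box, and the exact set of fields required on args are assumptions; the builder functions, default config/snapshot paths, and the [x, y, w, h] box convention are taken from test.py.

    # Minimal programmatic sketch of the enhance -> denoise -> track loop (illustration only).
    import cv2
    from argparse import Namespace

    from snot.pipelines.pipeline_builder import build_pipeline
    from denoiser.denoiser_builder import build_denoiser
    from enhancer.enhancer_builder import build_enhancer

    # Fields mirror the argparse options in test.py; additional fields may be
    # required by the individual builders (assumption).
    args = Namespace(
        config='./experiments/SiamBAN/config.yaml',    # tracker config
        snapshot='./experiments/SiamBAN/model.pth',    # tracker weights
        trackername='SiamBAN',
        e_weights='./experiments/SCT/model.pth', enhancername='SCT',   # low-light enhancer
        d_weights='./experiments/CDT/model.pth', denoisername='CDT',   # denoiser
    )

    pipeline = build_pipeline(args,
                              enhancer=build_enhancer(args),
                              denoiser=build_denoiser(args))

    frames = ['frame_0001.jpg', 'frame_0002.jpg']      # hypothetical frame list
    init_bbox = [100, 120, 40, 60]                     # hypothetical [x, y, w, h] on frame 0

    for idx, path in enumerate(frames):
        img = cv2.imread(path)                         # BGR image, as the trackers expect
        if idx == 0:
            pipeline.init(img, init_bbox)              # template crop is enhanced, then denoised
        else:
            bbox = pipeline.track(img)                 # search crop likewise; returns [x, y, w, h]
            print(idx, bbox)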