├── .gitignore ├── DCFNet-JCST.pdf ├── LICENSE ├── README.md ├── track ├── DCFNet.py ├── dataset │ ├── OTB2015.json │ └── gen_otb2013.py ├── eval_otb.py ├── net.py ├── net_param.mat ├── param.pth ├── tune_otb.py └── util.py └── train ├── dataset.py ├── dataset ├── compute-image-mean.py ├── crop_image.py ├── gen_snippet.py └── parse_vid.py ├── net.py └── train_DCFNet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /DCFNet-JCST.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolwood/DCFNet_pytorch/b8434baa2d136df8f55c1addb3e77f40b3c379fc/DCFNet-JCST.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Qiang Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCFNet_pytorch([JCST](https://jcst.ict.ac.cn/en/article/doi/10.1007/s11390-023-3788-3)) 2 | 3 | [️‍🔥News️‍🔥] DCFNet is accepted in JCST. If you find [**DCFNet**](https://arxiv.org/pdf/1704.04057.pdf) useful in your research, please consider citing: 4 | 5 | ``` 6 | @Article{JCST-2309-13788, 7 | title = {DCFNet: Discriminant Correlation Filters Network for Visual Tracking}, 8 | journal = {Journal of Computer Science and Technology}, 9 | year = {2023}, 10 | issn = {1000-9000(Print) /1860-4749(Online)}, 11 | doi = {10.1007/s11390-023-3788-3}, 12 | author = {Wei-Ming Hu and Qiang Wang and Jin Gao and Bing Li and Stephen Maybank} 13 | } 14 | ``` 15 | 16 | 17 | 18 | This repository contains a Python *reimplementation* of the [**DCFNet**](https://arxiv.org/pdf/1704.04057.pdf). 19 | 20 | ### Why implementation in python (PyTorch)? 21 | 22 | - Magical **Autograd** mechanism via PyTorch. Do not need to know the complicated BP. 23 | - Fast Fourier Transforms (**FFT**) supported by PyTorch 0.4.0. 24 | - Engineering demand. 25 | - Fast test speed (**120 FPS** on GTX 1060) and **Multi-GPUs** training. 26 | 27 | ### Contents 28 | 1. [Requirements](#requirements) 29 | 2. [Test](#test) 30 | 3. [Train](#train) 31 | 4. [Citing DCFNet](#citing-dcfnet) 32 | 33 | ## Requirements 34 | 35 | ```shell 36 | git clone --depth=1 https://github.com/foolwood/DCFNet_pytorch 37 | ``` 38 | 39 | Requirements for **PyTorch 0.4.0** and opencv-python 40 | 41 | ```shell 42 | conda install pytorch torchvision -c pytorch 43 | conda install -c menpo opencv 44 | ``` 45 | 46 | Training data (VID) and Test dataset (OTB). 47 | 48 | ## Test 49 | 50 | ```shell 51 | cd DCFNet_pytorch/track 52 | ln -s /path/to/your/OTB2015 ./dataset/OTB2015 53 | ln -s ./dataset/OTB2015 ./dataset/OTB2013 54 | cd dataset & python gen_otb2013.py 55 | python DCFNet.py 56 | ``` 57 | 58 | ## Train 59 | 60 | 1. Download training data. ([**ILSVRC2015 VID**](http://bvisionweb1.cs.unc.edu/ilsvrc2015/download-videos-3j16.php#vid)) 61 | 62 | ``` 63 | ./ILSVRC2015 64 | ├── Annotations 65 | │   └── VID├── a -> ./ILSVRC2015_VID_train_0000 66 | │ ├── b -> ./ILSVRC2015_VID_train_0001 67 | │ ├── c -> ./ILSVRC2015_VID_train_0002 68 | │ ├── d -> ./ILSVRC2015_VID_train_0003 69 | │ ├── e -> ./val 70 | │ ├── ILSVRC2015_VID_train_0000 71 | │ ├── ILSVRC2015_VID_train_0001 72 | │ ├── ILSVRC2015_VID_train_0002 73 | │ ├── ILSVRC2015_VID_train_0003 74 | │ └── val 75 | ├── Data 76 | │   └── VID...........same as Annotations 77 | └── ImageSets 78 | └── VID 79 | ``` 80 | 81 | 2. Prepare training data for `dataloader`. 82 | 83 | ```shell 84 | cd DCFNet_pytorch/train/dataset 85 | python parse_vid.py # save all vid info in a single json 86 | python gen_snippet.py # generate snippets 87 | python crop_image.py # crop and generate a json for dataloader 88 | ``` 89 | 90 | 3. Training. (on multiple ***GPUs*** :zap: :zap: :zap: :zap:) 91 | 92 | ``` 93 | cd DCFNet_pytorch/train/ 94 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_DCFNet.py 95 | ``` 96 | 97 | 98 | ## Fine-tune hyper-parameter 99 | 100 | 1. 
After training, you can simple test the model with default parameter. 101 | 102 | ```shell 103 | cd DCFNet_pytorch/track/ 104 | python DCFNet --model ../train/work/crop_125_2.0/checkpoint.pth.tar 105 | ``` 106 | 107 | 2. Search a better hyper-parameter. 108 | 109 | ```shell 110 | CUDA_VISIBLE_DEVICES=0 python tune_otb.py # run on parallel to speed up searching 111 | python eval_otb.py OTB2013 * 0 10000 112 | ``` 113 | 114 | ## Citing DCFNet 115 | 116 | If you find [**DCFNet**](https://arxiv.org/pdf/1704.04057.pdf) useful in your research, please consider citing: 117 | 118 | ``` 119 | @article{wang2017dcfnet, 120 | title={DCFNet: Discriminant Correlation Filters Network for Visual Tracking}, 121 | author={Wang, Qiang and Gao, Jin and Xing, Junliang and Zhang, Mengdan and Hu, Weiming}, 122 | journal={arXiv preprint arXiv:1704.04057}, 123 | year={2017} 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /track/DCFNet.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import makedirs 3 | import argparse 4 | import json 5 | import numpy as np 6 | import torch 7 | 8 | import cv2 9 | import time as time 10 | from util import crop_chw, gaussian_shaped_labels, cxy_wh_2_rect1, rect1_2_cxy_wh, cxy_wh_2_bbox 11 | from net import DCFNet 12 | from eval_otb import eval_auc 13 | 14 | 15 | class TrackerConfig(object): 16 | # These are the default hyper-params for DCFNet 17 | # OTB2013 / AUC(0.665) 18 | feature_path = 'param.pth' 19 | crop_sz = 125 20 | 21 | lambda0 = 1e-4 22 | padding = 2 23 | output_sigma_factor = 0.1 24 | interp_factor = 0.01 25 | num_scale = 3 26 | scale_step = 1.0275 27 | scale_factor = scale_step ** (np.arange(num_scale) - num_scale / 2) 28 | min_scale_factor = 0.2 29 | max_scale_factor = 5 30 | scale_penalty = 0.9925 31 | scale_penalties = scale_penalty ** (np.abs((np.arange(num_scale) - num_scale / 2))) 32 | 33 | net_input_size = [crop_sz, crop_sz] 34 | net_average_image = np.array([104, 117, 123]).reshape(-1, 1, 1).astype(np.float32) 35 | output_sigma = crop_sz / (1 + padding) * output_sigma_factor 36 | y = gaussian_shaped_labels(output_sigma, net_input_size) 37 | yf = torch.rfft(torch.Tensor(y).view(1, 1, crop_sz, crop_sz).cuda(), signal_ndim=2) 38 | cos_window = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda() 39 | 40 | 41 | class DCFNetTraker(object): 42 | def __init__(self, im, init_rect, config=TrackerConfig(), gpu=True): 43 | self.gpu = gpu 44 | self.config = config 45 | self.net = DCFNet(config) 46 | self.net.load_param(config.feature_path) 47 | self.net.eval() 48 | if gpu: 49 | self.net.cuda() 50 | 51 | # confine results 52 | target_pos, target_sz = rect1_2_cxy_wh(init_rect) 53 | self.min_sz = np.maximum(config.min_scale_factor * target_sz, 4) 54 | self.max_sz = np.minimum(im.shape[:2], config.max_scale_factor * target_sz) 55 | 56 | # crop template 57 | window_sz = target_sz * (1 + config.padding) 58 | bbox = cxy_wh_2_bbox(target_pos, window_sz) 59 | patch = crop_chw(im, bbox, self.config.crop_sz) 60 | 61 | target = patch - config.net_average_image 62 | self.net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda()) 63 | self.target_pos, self.target_sz = target_pos, target_sz 64 | self.patch_crop = np.zeros((config.num_scale, patch.shape[0], patch.shape[1], patch.shape[2]), np.float32) # buff 65 | 66 | def track(self, im): 67 | for i in range(self.config.num_scale): # crop multi-scale search region 68 | window_sz 
= self.target_sz * (self.config.scale_factor[i] * (1 + self.config.padding)) 69 | bbox = cxy_wh_2_bbox(self.target_pos, window_sz) 70 | self.patch_crop[i, :] = crop_chw(im, bbox, self.config.crop_sz) 71 | 72 | search = self.patch_crop - self.config.net_average_image 73 | 74 | if self.gpu: 75 | response = self.net(torch.Tensor(search).cuda()) 76 | else: 77 | response = self.net(torch.Tensor(search)) 78 | peak, idx = torch.max(response.view(self.config.num_scale, -1), 1) 79 | peak = peak.data.cpu().numpy() * self.config.scale_penalties 80 | best_scale = np.argmax(peak) 81 | r_max, c_max = np.unravel_index(idx[best_scale], self.config.net_input_size) 82 | 83 | if r_max > self.config.net_input_size[0] / 2: 84 | r_max = r_max - self.config.net_input_size[0] 85 | if c_max > self.config.net_input_size[1] / 2: 86 | c_max = c_max - self.config.net_input_size[1] 87 | window_sz = self.target_sz * (self.config.scale_factor[best_scale] * (1 + self.config.padding)) 88 | 89 | self.target_pos = self.target_pos + np.array([c_max, r_max]) * window_sz / self.config.net_input_size 90 | self.target_sz = np.minimum(np.maximum(window_sz / (1 + self.config.padding), self.min_sz), self.max_sz) 91 | 92 | # model update 93 | window_sz = self.target_sz * (1 + self.config.padding) 94 | bbox = cxy_wh_2_bbox(self.target_pos, window_sz) 95 | patch = crop_chw(im, bbox, self.config.crop_sz) 96 | target = patch - self.config.net_average_image 97 | self.net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda(), lr=self.config.interp_factor) 98 | 99 | return cxy_wh_2_rect1(self.target_pos, self.target_sz) # 1-index 100 | 101 | 102 | if __name__ == '__main__': 103 | # base dataset path and setting 104 | parser = argparse.ArgumentParser(description='Test DCFNet on OTB') 105 | parser.add_argument('--dataset', metavar='SET', default='OTB2013', 106 | choices=['OTB2013', 'OTB2015'], help='tune on which dataset') 107 | parser.add_argument('--model', metavar='PATH', default='param.pth') 108 | args = parser.parse_args() 109 | 110 | dataset = args.dataset 111 | base_path = join('dataset', dataset) 112 | json_path = join('dataset', dataset + '.json') 113 | annos = json.load(open(json_path, 'r')) 114 | videos = sorted(annos.keys()) 115 | 116 | use_gpu = True 117 | visualization = False 118 | 119 | # default parameter and load feature extractor network 120 | config = TrackerConfig() 121 | net = DCFNet(config) 122 | net.load_param(args.model) 123 | net.eval().cuda() 124 | 125 | speed = [] 126 | # loop videos 127 | for video_id, video in enumerate(videos): # run without resetting 128 | video_path_name = annos[video]['name'] 129 | init_rect = np.array(annos[video]['init_rect']).astype(np.float) 130 | image_files = [join(base_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']] 131 | n_images = len(image_files) 132 | 133 | tic = time.time() # time start 134 | 135 | target_pos, target_sz = rect1_2_cxy_wh(init_rect) # OTB label is 1-indexed 136 | 137 | im = cv2.imread(image_files[0]) # HxWxC 138 | 139 | # confine results 140 | min_sz = np.maximum(config.min_scale_factor * target_sz, 4) 141 | max_sz = np.minimum(im.shape[:2], config.max_scale_factor * target_sz) 142 | 143 | # crop template 144 | window_sz = target_sz * (1 + config.padding) 145 | bbox = cxy_wh_2_bbox(target_pos, window_sz) 146 | patch = crop_chw(im, bbox, config.crop_sz) 147 | 148 | target = patch - config.net_average_image 149 | net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda()) 150 | 151 | res = [cxy_wh_2_rect1(target_pos, target_sz)] # 
save in .txt 152 | patch_crop = np.zeros((config.num_scale, patch.shape[0], patch.shape[1], patch.shape[2]), np.float32) 153 | for f in range(1, n_images): # track 154 | im = cv2.imread(image_files[f]) 155 | 156 | for i in range(config.num_scale): # crop multi-scale search region 157 | window_sz = target_sz * (config.scale_factor[i] * (1 + config.padding)) 158 | bbox = cxy_wh_2_bbox(target_pos, window_sz) 159 | patch_crop[i, :] = crop_chw(im, bbox, config.crop_sz) 160 | 161 | search = patch_crop - config.net_average_image 162 | response = net(torch.Tensor(search).cuda()) 163 | peak, idx = torch.max(response.view(config.num_scale, -1), 1) 164 | peak = peak.data.cpu().numpy() * config.scale_penalties 165 | best_scale = np.argmax(peak) 166 | r_max, c_max = np.unravel_index(idx[best_scale], config.net_input_size) 167 | 168 | if r_max > config.net_input_size[0] / 2: 169 | r_max = r_max - config.net_input_size[0] 170 | if c_max > config.net_input_size[1] / 2: 171 | c_max = c_max - config.net_input_size[1] 172 | window_sz = target_sz * (config.scale_factor[best_scale] * (1 + config.padding)) 173 | 174 | target_pos = target_pos + np.array([c_max, r_max]) * window_sz / config.net_input_size 175 | target_sz = np.minimum(np.maximum(window_sz / (1 + config.padding), min_sz), max_sz) 176 | 177 | # model update 178 | window_sz = target_sz * (1 + config.padding) 179 | bbox = cxy_wh_2_bbox(target_pos, window_sz) 180 | patch = crop_chw(im, bbox, config.crop_sz) 181 | target = patch - config.net_average_image 182 | net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda(), lr=config.interp_factor) 183 | 184 | res.append(cxy_wh_2_rect1(target_pos, target_sz)) # 1-index 185 | 186 | if visualization: 187 | im_show = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) 188 | cv2.rectangle(im_show, (int(target_pos[0] - target_sz[0] / 2), int(target_pos[1] - target_sz[1] / 2)), 189 | (int(target_pos[0] + target_sz[0] / 2), int(target_pos[1] + target_sz[1] / 2)), 190 | (0, 255, 0), 3) 191 | cv2.putText(im_show, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2, cv2.LINE_AA) 192 | cv2.imshow(video, im_show) 193 | cv2.waitKey(1) 194 | 195 | toc = time.time() - tic 196 | fps = n_images / toc 197 | speed.append(fps) 198 | print('{:3d} Video: {:12s} Time: {:3.1f}s\tSpeed: {:3.1f}fps'.format(video_id, video, toc, fps)) 199 | 200 | # save result 201 | test_path = join('result', dataset, 'DCFNet_test') 202 | if not isdir(test_path): makedirs(test_path) 203 | result_path = join(test_path, video + '.txt') 204 | with open(result_path, 'w') as f: 205 | for x in res: 206 | f.write(','.join(['{:.2f}'.format(i) for i in x]) + '\n') 207 | 208 | print('***Total Mean Speed: {:3.1f} (FPS)***'.format(np.mean(speed))) 209 | 210 | eval_auc(dataset, 'DCFNet_test', 0, 1) 211 | -------------------------------------------------------------------------------- /track/dataset/gen_otb2013.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | OTB2015 = json.load(open('OTB2015.json', 'r')) 4 | videos = OTB2015.keys() 5 | 6 | OTB2013 = dict() 7 | for v in videos: 8 | if v in ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix', 9 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek', 10 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking', 11 | 'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1', 12 | 'suv', 
'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2', 13 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2']: 14 | OTB2013[v] = OTB2015[v] 15 | 16 | 17 | json.dump(OTB2013, open('OTB2013.json', 'w'), indent=2) 18 | 19 | -------------------------------------------------------------------------------- /track/eval_otb.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import os 4 | import glob 5 | from os.path import join as fullfile 6 | import numpy as np 7 | 8 | 9 | def overlap_ratio(rect1, rect2): 10 | ''' 11 | Compute overlap ratio between two rects 12 | - rect: 1d array of [x,y,w,h] or 13 | 2d array of N x [x,y,w,h] 14 | ''' 15 | 16 | if rect1.ndim==1: 17 | rect1 = rect1[None,:] 18 | if rect2.ndim==1: 19 | rect2 = rect2[None,:] 20 | 21 | left = np.maximum(rect1[:,0], rect2[:,0]) 22 | right = np.minimum(rect1[:,0]+rect1[:,2], rect2[:,0]+rect2[:,2]) 23 | top = np.maximum(rect1[:,1], rect2[:,1]) 24 | bottom = np.minimum(rect1[:,1]+rect1[:,3], rect2[:,1]+rect2[:,3]) 25 | 26 | intersect = np.maximum(0,right - left) * np.maximum(0,bottom - top) 27 | union = rect1[:,2]*rect1[:,3] + rect2[:,2]*rect2[:,3] - intersect 28 | iou = np.clip(intersect / union, 0, 1) 29 | return iou 30 | 31 | 32 | def compute_success_overlap(gt_bb, result_bb): 33 | thresholds_overlap = np.arange(0, 1.05, 0.05) 34 | n_frame = len(gt_bb) 35 | success = np.zeros(len(thresholds_overlap)) 36 | iou = overlap_ratio(gt_bb, result_bb) 37 | for i in range(len(thresholds_overlap)): 38 | success[i] = sum(iou > thresholds_overlap[i]) / float(n_frame) 39 | return success 40 | 41 | 42 | def compute_success_error(gt_center, result_center): 43 | thresholds_error = np.arange(0, 51, 1) 44 | n_frame = len(gt_center) 45 | success = np.zeros(len(thresholds_error)) 46 | dist = np.sqrt(np.sum(np.power(gt_center - result_center, 2), axis=1)) 47 | for i in range(len(thresholds_error)): 48 | success[i] = sum(dist <= thresholds_error[i]) / float(n_frame) 49 | return success 50 | 51 | 52 | def get_result_bb(arch, seq): 53 | result_path = fullfile(arch, seq + '.txt') 54 | temp = np.loadtxt(result_path, delimiter=',').astype(np.float) 55 | return np.array(temp) 56 | 57 | 58 | def convert_bb_to_center(bboxes): 59 | return np.array([(bboxes[:, 0] + (bboxes[:, 2] - 1) / 2), 60 | (bboxes[:, 1] + (bboxes[:, 3] - 1) / 2)]).T 61 | 62 | 63 | def eval_auc(dataset='OTB2015', tracker_reg='S*', start=0, end=1e6): 64 | list_path = os.path.join('dataset', dataset + '.json') 65 | annos = json.load(open(list_path, 'r')) 66 | seqs = annos.keys() 67 | 68 | OTB2013 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix', 69 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek', 70 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking', 71 | 'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1', 72 | 'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2', 73 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2'] 74 | 75 | OTB2015 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix', 76 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek', 77 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking', 78 | 
'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1', 79 | 'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2', 80 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2', 'clifBar', 'biker', 'bird1', 'blurBody', 81 | 'blurCar2', 'blurFace', 'blurOwl', 'box', 'car1', 'crowds', 'diving', 'dragonBaby', 'human3', 'human4_2', 82 | 'human6', 'human9', 'jump', 'panda', 'redTeam', 'skating2_1', 'skating2_2', 'surfer', 'bird2', 83 | 'blurCar1', 'blurCar3', 'blurCar4', 'board', 'bolt2', 'car2', 'car24', 'coupon', 'dancer', 'dancer2', 84 | 'dog', 'girl2', 'gym', 'human2', 'human5', 'human7', 'human8', 'kiteSurf', 'man', 'rubik', 'skater', 85 | 'skater2', 'toy', 'trans', 'twinnings', 'vase'] 86 | 87 | trackers = glob.glob(fullfile('result', dataset, tracker_reg)) 88 | trackers = trackers[start:min(end, len(trackers))] 89 | 90 | n_seq = len(seqs) 91 | thresholds_overlap = np.arange(0, 1.05, 0.05) 92 | # thresholds_error = np.arange(0, 51, 1) 93 | 94 | success_overlap = np.zeros((n_seq, len(trackers), len(thresholds_overlap))) 95 | # success_error = np.zeros((n_seq, len(trackers), len(thresholds_error))) 96 | for i in range(n_seq): 97 | seq = seqs[i] 98 | gt_rect = np.array(annos[seq]['gt_rect']).astype(np.float) 99 | gt_center = convert_bb_to_center(gt_rect) 100 | for j in range(len(trackers)): 101 | tracker = trackers[j] 102 | print('{:d} processing:{} tracker: {}'.format(i, seq, tracker)) 103 | bb = get_result_bb(tracker, seq) 104 | center = convert_bb_to_center(bb) 105 | success_overlap[i][j] = compute_success_overlap(gt_rect, bb) 106 | # success_error[i][j] = compute_success_error(gt_center, center) 107 | 108 | print('Success Overlap') 109 | 110 | if 'OTB2015' == dataset: 111 | OTB2013_id = [] 112 | for i in range(n_seq): 113 | if seqs[i] in OTB2013: 114 | OTB2013_id.append(i) 115 | max_auc_OTB2013 = 0. 116 | max_name_OTB2013 = '' 117 | for i in range(len(trackers)): 118 | auc = success_overlap[OTB2013_id, i, :].mean() 119 | if auc > max_auc_OTB2013: 120 | max_auc_OTB2013 = auc 121 | max_name_OTB2013 = trackers[i] 122 | print('%s(%.4f)' % (trackers[i], auc)) 123 | 124 | max_auc = 0. 125 | max_name = '' 126 | for i in range(len(trackers)): 127 | auc = success_overlap[:, i, :].mean() 128 | if auc > max_auc: 129 | max_auc = auc 130 | max_name = trackers[i] 131 | print('%s(%.4f)' % (trackers[i], auc)) 132 | 133 | print('\nOTB2013 Best: %s(%.4f)' % (max_name_OTB2013, max_auc_OTB2013)) 134 | print('\nOTB2015 Best: %s(%.4f)' % (max_name, max_auc)) 135 | else: 136 | max_auc = 0. 137 | max_name = '' 138 | for i in range(len(trackers)): 139 | auc = success_overlap[:, i, :].mean() 140 | if auc > max_auc: 141 | max_auc = auc 142 | max_name = trackers[i] 143 | print('%s(%.4f)' % (trackers[i], auc)) 144 | 145 | print('\n%s Best: %s(%.4f)' % (dataset, max_name, max_auc)) 146 | 147 | 148 | if __name__ == "__main__": 149 | if len(sys.argv) < 5: 150 | print('python eval_otb.py OTB2015 DCFNet_test* 0 10') 151 | exit() 152 | dataset = sys.argv[1] 153 | tracker_reg = sys.argv[2] 154 | start = int(sys.argv[3]) 155 | end = int(sys.argv[4]) 156 | eval_auc(dataset, tracker_reg, start, end) 157 | -------------------------------------------------------------------------------- /track/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch # pytorch 0.4.0! 
fft 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def complex_mul(x, z): 8 | out_real = x[..., 0] * z[..., 0] - x[..., 1] * z[..., 1] 9 | out_imag = x[..., 0] * z[..., 1] + x[..., 1] * z[..., 0] 10 | return torch.stack((out_real, out_imag), -1) 11 | 12 | 13 | def complex_mulconj(x, z): 14 | out_real = x[..., 0] * z[..., 0] + x[..., 1] * z[..., 1] 15 | out_imag = x[..., 1] * z[..., 0] - x[..., 0] * z[..., 1] 16 | return torch.stack((out_real, out_imag), -1) 17 | 18 | 19 | class DCFNetFeature(nn.Module): 20 | def __init__(self): 21 | super(DCFNetFeature, self).__init__() 22 | self.feature = nn.Sequential( 23 | nn.Conv2d(3, 32, 3, padding=1), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(32, 32, 3, padding=1), 26 | nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1), 27 | ) 28 | 29 | def forward(self, x): 30 | return self.feature(x) 31 | 32 | 33 | class DCFNet(nn.Module): 34 | def __init__(self, config=None): 35 | super(DCFNet, self).__init__() 36 | self.feature = DCFNetFeature() 37 | self.model_alphaf = [] 38 | self.model_xf = [] 39 | self.config = config 40 | 41 | def forward(self, x): 42 | x = self.feature(x) * self.config.cos_window 43 | xf = torch.rfft(x, signal_ndim=2) 44 | kxzf = torch.sum(complex_mulconj(xf, self.model_zf), dim=1, keepdim=True) 45 | response = torch.irfft(complex_mul(kxzf, self.model_alphaf), signal_ndim=2) 46 | # r_max = torch.max(response) 47 | # cv2.imshow('response', response[0, 0].data.cpu().numpy()) 48 | # cv2.waitKey(0) 49 | return response 50 | 51 | def update(self, z, lr=1.): 52 | z = self.feature(z) * self.config.cos_window 53 | zf = torch.rfft(z, signal_ndim=2) 54 | kzzf = torch.sum(torch.sum(zf ** 2, dim=4, keepdim=True), dim=1, keepdim=True) 55 | alphaf = self.config.yf / (kzzf + self.config.lambda0) 56 | if lr > 0.99: 57 | self.model_alphaf = alphaf 58 | self.model_zf = zf 59 | else: 60 | self.model_alphaf = (1 - lr) * self.model_alphaf.data + lr * alphaf.data 61 | self.model_zf = (1 - lr) * self.model_zf.data + lr * zf.data 62 | 63 | def load_param(self, path='param.pth'): 64 | checkpoint = torch.load(path) 65 | if 'state_dict' in checkpoint.keys(): # from training result 66 | state_dict = checkpoint['state_dict'] 67 | if 'module' in state_dict.keys()[0]: # train with nn.DataParallel 68 | from collections import OrderedDict 69 | new_state_dict = OrderedDict() 70 | for k, v in state_dict.items(): 71 | name = k[7:] # remove `module.` 72 | new_state_dict[name] = v 73 | self.load_state_dict(new_state_dict) 74 | else: 75 | self.load_state_dict(state_dict) 76 | else: 77 | self.feature.load_state_dict(checkpoint) 78 | 79 | 80 | if __name__ == '__main__': 81 | 82 | # network test 83 | net = DCFNetFeature() 84 | net.eval() 85 | for idx, m in enumerate(net.modules()): 86 | print(idx, '->', m) 87 | for name, param in net.named_parameters(): 88 | if 'bias' in name or 'weight' in name: 89 | print(param.size()) 90 | from scipy import io 91 | import numpy as np 92 | p = io.loadmat('net_param.mat') 93 | x = p['res'][0][0][:,:,::-1].copy() 94 | x_out = p['res'][0][-1] 95 | from collections import OrderedDict 96 | pth_state_dict = OrderedDict() 97 | 98 | match_dict = dict() 99 | match_dict['feature.0.weight'] = 'conv1_w' 100 | match_dict['feature.0.bias'] = 'conv1_b' 101 | match_dict['feature.2.weight'] = 'conv2_w' 102 | match_dict['feature.2.bias'] = 'conv2_b' 103 | 104 | for var_name in net.state_dict().keys(): 105 | print var_name 106 | key_in_model = match_dict[var_name] 107 | param_in_model = var_name.rsplit('.', 1)[1] 108 | if 'weight' in var_name: 109 | 
pth_state_dict[var_name] = torch.Tensor(np.transpose(p[key_in_model],(3,2,0,1))) 110 | elif 'bias' in var_name: 111 | pth_state_dict[var_name] = torch.Tensor(np.squeeze(p[key_in_model])) 112 | if var_name == 'feature.0.weight': 113 | weight = pth_state_dict[var_name].data.numpy() 114 | weight = weight[:, ::-1, :, :].copy() # cv2 bgr input 115 | pth_state_dict[var_name] = torch.Tensor(weight) 116 | 117 | 118 | torch.save(pth_state_dict, 'param.pth') 119 | net.load_state_dict(torch.load('param.pth')) 120 | x_t = torch.Tensor(np.expand_dims(np.transpose(x,(2,0,1)), axis=0)) 121 | x_pred = net(x_t).data.numpy() 122 | pred_error = np.sum(np.abs(np.transpose(x_pred,(0,2,3,1)).reshape(-1) - x_out.reshape(-1))) 123 | 124 | x_fft = torch.rfft(x_t, signal_ndim=2, onesided=False) 125 | 126 | 127 | print('model_transfer_error:{:.5f}'.format(pred_error)) 128 | 129 | 130 | -------------------------------------------------------------------------------- /track/net_param.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolwood/DCFNet_pytorch/b8434baa2d136df8f55c1addb3e77f40b3c379fc/track/net_param.mat -------------------------------------------------------------------------------- /track/param.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foolwood/DCFNet_pytorch/b8434baa2d136df8f55c1addb3e77f40b3c379fc/track/param.pth -------------------------------------------------------------------------------- /track/tune_otb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | from os import makedirs 5 | from os.path import isfile, isdir, join 6 | from util import cxy_wh_2_rect1 7 | import torch 8 | import json 9 | from DCFNet import * 10 | 11 | parser = argparse.ArgumentParser(description='Tune parameters for DCFNet tracker on OTB2015') 12 | parser.add_argument('-v', '--visualization', dest='visualization', action='store_true', 13 | help='whether visualize result') 14 | 15 | args = parser.parse_args() 16 | 17 | 18 | def tune_otb(param): 19 | regions = [] # result and states[1 init / 2 lost / 0 skip] 20 | # save result 21 | benchmark_result_path = join('result', param['dataset']) 22 | tracker_path = join(benchmark_result_path, (param['network_name'] + 23 | '_scale_step_{:.3f}'.format(param['config'].scale_step) + 24 | '_scale_penalty_{:.3f}'.format(param['config'].scale_penalty) + 25 | '_interp_factor_{:.3f}'.format(param['config'].interp_factor))) 26 | result_path = join(tracker_path, '{:s}.txt'.format(param['video'])) 27 | if isfile(result_path): 28 | return 29 | if not isdir(tracker_path): makedirs(tracker_path) 30 | with open(result_path, 'w') as f: # Occupation 31 | for x in regions: 32 | f.write('') 33 | 34 | ims = param['ims'] 35 | toc = 0 36 | for f, im in enumerate(ims): 37 | tic = cv2.getTickCount() 38 | if f == 0: # init 39 | init_rect = p['init_rect'] 40 | tracker = DCFNetTraker(ims[f], init_rect, config=param['config']) 41 | regions.append(init_rect) 42 | else: # tracking 43 | rect = tracker.track(ims[f]) 44 | regions.append(rect) 45 | toc += cv2.getTickCount() - tic 46 | 47 | if args.visualization: # visualization (skip lost frame) 48 | if f == 0: cv2.destroyAllWindows() 49 | location = [int(l) for l in location] # int 50 | cv2.rectangle(im, (location[0], location[1]), (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3) 51 | cv2.putText(im, 
str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2) 52 | 53 | cv2.imshow(video, im) 54 | cv2.waitKey(1) 55 | toc /= cv2.getTickFrequency() 56 | print('{:2d} Video: {:12s} Time: {:2.1f}s Speed: {:3.1f}fps'.format(v, video, toc, f / toc)) 57 | regions = np.array(regions) 58 | regions[:,:2] += 1 # 1-index 59 | with open(result_path, 'w') as f: 60 | for x in regions: 61 | f.write(','.join(['{:.2f}'.format(i) for i in x]) + '\n') 62 | 63 | 64 | params = {'dataset':['OTB2013'], 'network':['param.pth'], 65 | 'scale_step':np.arange(1.01, 1.05, 0.005, np.float32), 66 | 'scale_penalty':np.arange(0.98, 1.0, 0.025, np.float32), 67 | 'interp_factor':np.arange(0.001, 0.015, 0.001, np.float32)} 68 | 69 | p = dict() 70 | p['config'] = TrackerConfig() 71 | for network in params['network']: 72 | p['network_name'] = network 73 | np.random.shuffle(params['dataset']) 74 | for dataset in params['dataset']: 75 | base_path = join('dataset', dataset) 76 | json_path = join('dataset', dataset+'.json') 77 | annos = json.load(open(json_path, 'r')) 78 | videos = annos.keys() 79 | p['dataset'] = dataset 80 | np.random.shuffle(videos) 81 | for v, video in enumerate(videos): 82 | p['v'] = v 83 | p['video'] = video 84 | video_path_name = annos[video]['name'] 85 | init_rect = np.array(annos[video]['init_rect']).astype(np.float) 86 | image_files = [join(base_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']] 87 | target_pos = np.array([init_rect[0] + init_rect[2] / 2 -1 , init_rect[1] + init_rect[3] / 2 -1]) # 0-index 88 | target_sz = np.array([init_rect[2], init_rect[3]]) 89 | ims = [] 90 | for image_file in image_files: 91 | im = cv2.imread(image_file) 92 | if im.shape[2] == 1: 93 | cv2.cvtColor(im, im, cv2.COLOR_GRAY2RGB) 94 | ims.append(im) 95 | p['ims'] = ims 96 | p['init_rect'] = init_rect 97 | 98 | np.random.shuffle(params['scale_step']) 99 | np.random.shuffle(params['scale_penalty']) 100 | np.random.shuffle(params['interp_factor']) 101 | for scale_step in params['scale_step']: 102 | for scale_penalty in params['scale_penalty']: 103 | for interp_factor in params['interp_factor']: 104 | p['config'].scale_step = float(scale_step) 105 | p['config'].scale_penalty = float(scale_penalty) 106 | p['config'].interp_factor = float(interp_factor) 107 | tune_otb(p) 108 | -------------------------------------------------------------------------------- /track/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def cxy_wh_2_rect1(pos, sz): 6 | return np.array([pos[0]-sz[0]/2+1, pos[1]-sz[1]/2+1, sz[0], sz[1]]) # 1-index 7 | 8 | 9 | def rect1_2_cxy_wh(rect): 10 | return np.array([rect[0]+rect[2]/2-1, rect[1]+rect[3]/2-1]), np.array([rect[2], rect[3]]) # 0-index 11 | 12 | 13 | def cxy_wh_2_bbox(cxy, wh): 14 | return np.array([cxy[0]-wh[0]/2, cxy[1]-wh[1]/2, cxy[0]+wh[0]/2, cxy[1]+wh[1]/2]) # 0-index 15 | 16 | 17 | def gaussian_shaped_labels(sigma, sz): 18 | x, y = np.meshgrid(np.arange(1, sz[0]+1) - np.floor(float(sz[0]) / 2), np.arange(1, sz[1]+1) - np.floor(float(sz[1]) / 2)) 19 | d = x ** 2 + y ** 2 20 | g = np.exp(-0.5 / (sigma ** 2) * d) 21 | g = np.roll(g, int(-np.floor(float(sz[0]) / 2.) + 1), axis=0) 22 | g = np.roll(g, int(-np.floor(float(sz[1]) / 2.) 
+ 1), axis=1) 23 | return g 24 | 25 | 26 | def crop_chw(image, bbox, out_sz, padding=(0, 0, 0)): 27 | a = (out_sz-1) / (bbox[2]-bbox[0]) 28 | b = (out_sz-1) / (bbox[3]-bbox[1]) 29 | c = -a * bbox[0] 30 | d = -b * bbox[1] 31 | mapping = np.array([[a, 0, c], 32 | [0, b, d]]).astype(np.float) 33 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 34 | return np.transpose(crop, (2, 0, 1)) 35 | 36 | 37 | if __name__ == '__main__': 38 | a = gaussian_shaped_labels(10, [5,5]) 39 | print a -------------------------------------------------------------------------------- /train/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from os.path import join 3 | import cv2 4 | import json 5 | import numpy as np 6 | 7 | 8 | class VID(data.Dataset): 9 | def __init__(self, file='dataset/dataset.json', root='dataset/crop_125_2.0', range=10, train=True): 10 | self.imdb = json.load(open(file, 'r')) 11 | self.root = root 12 | self.range = range 13 | self.train = train 14 | self.mean = np.expand_dims(np.expand_dims(np.array([109, 120, 119]), axis=1), axis=1).astype(np.float32) 15 | 16 | def __getitem__(self, item): 17 | if self.train: 18 | target_id = self.imdb['train_set'][item] 19 | else: 20 | target_id = self.imdb['val_set'][item] 21 | 22 | # range_down = self.imdb['down_index'][target_id] 23 | range_up = self.imdb['up_index'][target_id] 24 | # search_id = np.random.randint(-min(range_down, self.range), min(range_up, self.range)) + target_id 25 | search_id = np.random.randint(1, min(range_up, self.range+1)) + target_id 26 | 27 | target = cv2.imread(join(self.root, '{:08d}.jpg'.format(target_id))) 28 | search = cv2.imread(join(self.root, '{:08d}.jpg'.format(search_id))) 29 | 30 | target = np.transpose(target, (2, 0, 1)).astype(np.float32) - self.mean 31 | search = np.transpose(search, (2, 0, 1)).astype(np.float32) - self.mean 32 | 33 | return target, search 34 | 35 | def __len__(self): 36 | if self.train: 37 | return len(self.imdb['train_set']) 38 | else: 39 | return len(self.imdb['val_set']) 40 | 41 | 42 | if __name__ == '__main__': 43 | import matplotlib.pyplot as plt 44 | import matplotlib.patches as patches 45 | data = VID(train=True) 46 | n = len(data) 47 | fig = plt.figure(1) 48 | ax = fig.add_axes([0, 0, 1, 1]) 49 | 50 | for i in range(n): 51 | z, x = data[i] 52 | z, x = np.transpose(z, (1, 2, 0)).astype(np.uint8), np.transpose(x, (1, 2, 0)).astype(np.uint8) 53 | zx = np.concatenate((z, x), axis=1) 54 | 55 | ax.imshow(cv2.cvtColor(zx, cv2.COLOR_BGR2RGB)) 56 | p = patches.Rectangle( 57 | (125/3, 125/3), 125/3, 125/3, fill=False, clip_on=False, linewidth=2, edgecolor='g') 58 | ax.add_patch(p) 59 | p = patches.Rectangle( 60 | (125 / 3+125, 125 / 3), 125 / 3, 125 / 3, fill=False, clip_on=False, linewidth=2, edgecolor='r') 61 | ax.add_patch(p) 62 | plt.pause(0.5) 63 | -------------------------------------------------------------------------------- /train/dataset/compute-image-mean.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import time 5 | import glob 6 | 7 | from skimage import io 8 | import cv2 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--meanPrefix', default='mean_img', type=str, help="Prefix of the mean file.") 13 | parser.add_argument('--imageDir', default='crop_125_2.0', type=str, help="Directory of images 
to read.") 14 | args = parser.parse_args() 15 | 16 | mean = np.zeros((1, 3, 125, 125)) 17 | N = 0 18 | opencv_backend = True 19 | beginTime = time.time() 20 | files = glob.glob(os.path.join(args.imageDir, '*.jpg')) 21 | for file in files: 22 | if opencv_backend: 23 | img = cv2.imread(file) 24 | else: 25 | img = io.imread(file) 26 | if img.shape == (125, 125, 3): 27 | mean[0][0] += img[:, :, 0] 28 | mean[0][1] += img[:, :, 1] 29 | mean[0][2] += img[:, :, 2] 30 | N += 1 31 | if N % 1000 == 0: 32 | elapsed = time.time() - beginTime 33 | print("Processed {} images in {:.2f} seconds. " 34 | "{:.2f} images/second.".format(N, elapsed, N / elapsed)) 35 | mean[0] /= N 36 | 37 | meanImg = np.transpose(mean[0].astype(np.uint8), (1, 2, 0)) 38 | if opencv_backend: 39 | cv2.imwrite("{}.png".format(args.meanPrefix), meanImg) 40 | else: 41 | io.imsave("{}.png".format(args.meanPrefix), meanImg) 42 | 43 | avg_chans = np.mean(meanImg, axis=(0, 1)) 44 | if opencv_backend: 45 | print("image BGR mean: {}".format(avg_chans)) 46 | else: 47 | print("image RGB mean: {}".format(avg_chans)) -------------------------------------------------------------------------------- /train/dataset/crop_image.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import mkdir 3 | import argparse 4 | import numpy as np 5 | import json 6 | import cv2 7 | import time 8 | 9 | parse = argparse.ArgumentParser(description='Generate training data (cropped) for DCFNet_pytorch') 10 | parse.add_argument('-v', '--visual', dest='visual', action='store_true', help='whether visualise crop') 11 | parse.add_argument('-o', '--output_size', dest='output_size', default=125, type=int, help='crop output size') 12 | parse.add_argument('-p', '--padding', dest='padding', default=2, type=float, help='crop padding size') 13 | 14 | args = parse.parse_args() 15 | 16 | print args 17 | 18 | 19 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)): 20 | bbox = [float(x) for x in bbox] 21 | a = (out_sz-1) / (bbox[2]-bbox[0]) 22 | b = (out_sz-1) / (bbox[3]-bbox[1]) 23 | c = -a * bbox[0] 24 | d = -b * bbox[1] 25 | mapping = np.array([[a, 0, c], 26 | [0, b, d]]).astype(np.float) 27 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding) 28 | return crop 29 | 30 | 31 | def cxy_wh_2_bbox(cxy, wh): 32 | return np.array([cxy[0] - wh[0] / 2, cxy[1] - wh[1] / 2, cxy[0] + wh[0] / 2, cxy[1] + wh[1] / 2]) # 0-index 33 | 34 | 35 | snaps = json.load(open('snippet.json', 'r')) 36 | 37 | num_all_frame = 546315 # cat snippet.json | grep bbox |wc -l 38 | num_val = 1000 39 | # crop image 40 | lmdb = dict() 41 | lmdb['down_index'] = np.zeros(num_all_frame, np.int) # buff 42 | lmdb['up_index'] = np.zeros(num_all_frame, np.int) 43 | 44 | crop_base_path = 'crop_{:d}_{:1.1f}'.format(args.output_size, args.padding) 45 | if not isdir(crop_base_path): 46 | mkdir(crop_base_path) 47 | 48 | count = 0 49 | begin_time = time.time() 50 | for snap in snaps: 51 | frames = snap['frame'] 52 | n_frames = len(frames) 53 | for f, frame in enumerate(frames): 54 | img_path = join(snap['base_path'], frame['img_path']) 55 | im = cv2.imread(img_path) 56 | avg_chans = np.mean(im, axis=(0, 1)) 57 | bbox = frame['obj']['bbox'] 58 | 59 | target_pos = [(bbox[2] + bbox[0])/2, (bbox[3] + bbox[1])/2] 60 | target_sz = np.array([bbox[2] - bbox[0], bbox[3] - bbox[1]]) 61 | window_sz = target_sz * (1 + args.padding) 62 | crop_bbox = cxy_wh_2_bbox(target_pos, window_sz) 63 | patch = 
crop_hwc(im, crop_bbox, args.output_size) 64 | cv2.imwrite(join(crop_base_path, '{:08d}.jpg'.format(count)), patch) 65 | # cv2.imwrite('crop.jpg'.format(count), patch) 66 | 67 | lmdb['down_index'][count] = f 68 | lmdb['up_index'][count] = n_frames - f 69 | count += 1 70 | if count % 100 == 0: 71 | elapsed = time.time() - begin_time 72 | print("Processed {} images in {:.2f} seconds. " 73 | "{:.2f} images/second.".format(count, elapsed, count / elapsed)) 74 | 75 | template_id = np.where(lmdb['up_index'] > 1)[0] # NEVER use the last frame as template! I do not like bidirectional. 76 | rand_split = np.random.choice(len(template_id), len(template_id)) 77 | lmdb['train_set'] = template_id[rand_split[:(len(template_id)-num_val)]] 78 | lmdb['val_set'] = template_id[rand_split[(len(template_id)-num_val):]] 79 | print len(lmdb['train_set']) 80 | print len(lmdb['val_set']) 81 | 82 | # to list for json 83 | lmdb['train_set'] = lmdb['train_set'].tolist() 84 | lmdb['val_set'] = lmdb['val_set'].tolist() 85 | lmdb['down_index'] = lmdb['down_index'].tolist() 86 | lmdb['up_index'] = lmdb['up_index'].tolist() 87 | 88 | print('lmdb json, please wait 5 seconds~') 89 | json.dump(lmdb, open('dataset.json', 'w'), indent=2) 90 | print('done!') 91 | -------------------------------------------------------------------------------- /train/dataset/gen_snippet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | 4 | 5 | def check_size(frame_sz, bbox): 6 | min_ratio = 0.1 7 | max_ratio = 0.75 8 | # only accept objects >10% and <75% of the total frame 9 | area_ratio = np.sqrt((bbox[2]-bbox[0])*(bbox[3]-bbox[1])/float(np.prod(frame_sz))) 10 | ok = (area_ratio > min_ratio) and (area_ratio < max_ratio) 11 | return ok 12 | 13 | 14 | def check_borders(frame_sz, bbox): 15 | dist_from_border = 0.05 * (bbox[2] - bbox[0] + bbox[3] - bbox[1])/2 16 | ok = (bbox[0] > dist_from_border) and (bbox[1] > dist_from_border) and \ 17 | ((frame_sz[0] - bbox[2]) > dist_from_border) and \ 18 | ((frame_sz[1] - bbox[3]) > dist_from_border) 19 | return ok 20 | 21 | 22 | # Filter out snippets 23 | print('load json (raw vid info), please wait 20 seconds~') 24 | vid = json.load(open('vid.json', 'r')) 25 | snippets = [] 26 | n_snippets = 0 27 | n_videos = 0 28 | for subset in vid: 29 | for video in subset: 30 | n_videos += 1 31 | frames = video['frame'] 32 | id_set = [] 33 | id_frames = [[]] * 60 # at most 60 objects 34 | for f, frame in enumerate(frames): 35 | objs = frame['objs'] 36 | frame_sz = frame['frame_sz'] 37 | for obj in objs: 38 | trackid = obj['trackid'] 39 | occluded = obj['occ'] 40 | bbox = obj['bbox'] 41 | if occluded: 42 | continue 43 | 44 | if not(check_size(frame_sz, bbox) and check_borders(frame_sz, bbox)): 45 | continue 46 | 47 | if obj['c'] in ['n01674464', 'n01726692', 'n04468005', 'n02062744']: 48 | continue 49 | 50 | if trackid not in id_set: 51 | id_set.append(trackid) 52 | id_frames[trackid] = [] 53 | id_frames[trackid].append(f) 54 | 55 | for selected in id_set: 56 | frame_ids = sorted(id_frames[selected]) 57 | sequences = np.split(frame_ids, np.array(np.where(np.diff(frame_ids) > 1)[0]) + 1) 58 | sequences = [s for s in sequences if len(s) > 1] # remove isolated frame. 
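# Illustrative example of the two lines above (made-up ids): with frame_ids = [3, 4, 5, 9, 10],
# np.diff(frame_ids) = [1, 1, 4, 1], np.where(diff > 1)[0] + 1 = [3], and np.split therefore
# yields [3, 4, 5] and [9, 10] -- i.e. each snippet keeps only one run of consecutively
# annotated frames for a single track id.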
59 | for seq in sequences: 60 | snippet = dict() 61 | snippet['base_path'] = video['base_path'] 62 | snippet['frame'] = [] 63 | for frame_id in seq: 64 | frame = frames[frame_id] 65 | f = dict() 66 | f['frame_sz'] = frame['frame_sz'] 67 | f['img_path'] = frame['img_path'] 68 | for obj in frame['objs']: 69 | if obj['trackid'] == selected: 70 | o = obj 71 | continue 72 | f['obj'] = o 73 | snippet['frame'].append(f) 74 | snippets.append(snippet) 75 | n_snippets += 1 76 | print('video: {:d} snippets_num: {:d}'.format(n_videos, n_snippets)) 77 | 78 | print('save json (snippets), please wait 20 seconds~') 79 | json.dump(snippets, open('snippet.json', 'w'), indent=2) 80 | print('done!') 81 | -------------------------------------------------------------------------------- /train/dataset/parse_vid.py: -------------------------------------------------------------------------------- 1 | from os.path import join, isdir 2 | from os import listdir 3 | import argparse 4 | import json 5 | import glob 6 | import xml.etree.ElementTree as ET 7 | 8 | parser = argparse.ArgumentParser(description='Parse the VID Annotations for training DCFNet') 9 | parser.add_argument('data', metavar='DIR', help='path to VID') 10 | args = parser.parse_args() 11 | 12 | print('VID2015 Data:') 13 | VID_base_path = args.data 14 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/') 15 | img_base_path = join(VID_base_path, 'Data/VID/train/') 16 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'}) 17 | 18 | vid = [] 19 | for sub_set in sub_sets: 20 | sub_set_base_path = join(ann_base_path, sub_set) 21 | videos = sorted(listdir(sub_set_base_path)) 22 | s = [] 23 | for vi, video in enumerate(videos): 24 | print('subset: {} video id: {:04d} / {:04d}'.format(sub_set, vi, len(videos))) 25 | v = dict() 26 | v['base_path'] = join(img_base_path, sub_set, video) 27 | v['frame'] = [] 28 | video_base_path = join(sub_set_base_path, video) 29 | xmls = sorted(glob.glob(join(video_base_path, '*.xml'))) 30 | for xml in xmls: 31 | f = dict() 32 | xmltree = ET.parse(xml) 33 | size = xmltree.findall('size')[0] 34 | frame_sz = [int(it.text) for it in size] 35 | objects = xmltree.findall('object') 36 | objs = [] 37 | for object_iter in objects: 38 | trackid = int(object_iter.find('trackid').text) 39 | name = (object_iter.find('name')).text 40 | bndbox = object_iter.find('bndbox') 41 | occluded = int(object_iter.find('occluded').text) 42 | o = dict() 43 | o['c'] = name 44 | o['bbox'] = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text), 45 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)] 46 | o['trackid'] = trackid 47 | o['occ'] = occluded 48 | objs.append(o) 49 | f['frame_sz'] = frame_sz 50 | f['img_path'] = xml.split('/')[-1].replace('xml', 'JPEG') 51 | f['objs'] = objs 52 | v['frame'].append(f) 53 | s.append(v) 54 | vid.append(s) 55 | print('save json (raw vid info), please wait 1 min~') 56 | json.dump(vid, open('vid.json', 'w'), indent=2) 57 | print('done!') 58 | 59 | -------------------------------------------------------------------------------- /train/net.py: -------------------------------------------------------------------------------- 1 | import torch # pytorch 0.4.0! 
fft 2 | import torch.nn as nn 3 | 4 | 5 | def complex_mul(x, z): 6 | out_real = x[..., 0] * z[..., 0] - x[..., 1] * z[..., 1] 7 | out_imag = x[..., 0] * z[..., 1] + x[..., 1] * z[..., 0] 8 | return torch.stack((out_real, out_imag), -1) 9 | 10 | 11 | def complex_mulconj(x, z): 12 | out_real = x[..., 0] * z[..., 0] + x[..., 1] * z[..., 1] 13 | out_imag = x[..., 1] * z[..., 0] - x[..., 0] * z[..., 1] 14 | return torch.stack((out_real, out_imag), -1) 15 | 16 | 17 | class DCFNetFeature(nn.Module): 18 | def __init__(self): 19 | super(DCFNetFeature, self).__init__() 20 | self.feature = nn.Sequential( 21 | nn.Conv2d(3, 32, 3), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(32, 32, 3), 24 | nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1), 25 | ) 26 | 27 | def forward(self, x): 28 | return self.feature(x) 29 | 30 | 31 | class DCFNet(nn.Module): 32 | def __init__(self, config=None): 33 | super(DCFNet, self).__init__() 34 | self.feature = DCFNetFeature() 35 | self.yf = config.yf.clone() 36 | self.lambda0 = config.lambda0 37 | 38 | def forward(self, z, x): 39 | z = self.feature(z) 40 | x = self.feature(x) 41 | zf = torch.rfft(z, signal_ndim=2) 42 | xf = torch.rfft(x, signal_ndim=2) 43 | 44 | kzzf = torch.sum(torch.sum(zf ** 2, dim=4, keepdim=True), dim=1, keepdim=True) 45 | kxzf = torch.sum(complex_mulconj(xf, zf), dim=1, keepdim=True) 46 | alphaf = self.yf.to(device=z.device) / (kzzf + self.lambda0) # very Ugly 47 | response = torch.irfft(complex_mul(kxzf, alphaf), signal_ndim=2) 48 | return response 49 | 50 | 51 | if __name__ == '__main__': 52 | 53 | # network test 54 | net = DCFNet() 55 | net.eval() 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /train/train_DCFNet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from os.path import join, isdir, isfile 4 | from os import makedirs 5 | 6 | from dataset import VID 7 | from net import DCFNet 8 | import torch 9 | from torch.utils.data import dataloader 10 | import torch.nn as nn 11 | import torch.backends.cudnn as cudnn 12 | import numpy as np 13 | import time 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Training DCFNet in Pytorch 0.4.0') 17 | parser.add_argument('--input_sz', dest='input_sz', default=125, type=int, help='crop input size') 18 | parser.add_argument('--padding', dest='padding', default=2.0, type=float, help='crop padding size') 19 | parser.add_argument('--range', dest='range', default=10, type=int, help='select range') 20 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 21 | help='number of total epochs to run') 22 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 23 | help='manual epoch number (useful on restarts)') 24 | parser.add_argument('--print-freq', '-p', default=10, type=int, 25 | metavar='N', help='print frequency (default: 10)') 26 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 27 | help='number of data loading workers (default: 8)') 28 | parser.add_argument('-b', '--batch-size', default=32, type=int, 29 | metavar='N', help='mini-batch size (default: 32)') 30 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 31 | metavar='LR', help='initial learning rate') 32 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 33 | help='momentum') 34 | parser.add_argument('--weight-decay', '--wd', default=5e-5, type=float, 35 | metavar='W', help='weight decay (default: 5e-5)') 
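# Illustrative launch (hypothetical values; multi-GPU use exactly as in the README):
#   CUDA_VISIBLE_DEVICES=0,1 python train_DCFNet.py -b 64 --lr 0.01 --epochs 50
# With the defaults above, checkpoints are written to
# ./work/crop_125_2.0/checkpoint.pth.tar (see save_path / save_checkpoint below).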
36 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 37 | parser.add_argument('--save', '-s', default='./work', type=str, help='directory for saving') 38 | 39 | args = parser.parse_args() 40 | 41 | print args 42 | best_loss = 1e6 43 | 44 | 45 | def gaussian_shaped_labels(sigma, sz): 46 | x, y = np.meshgrid(np.arange(1, sz[0]+1) - np.floor(float(sz[0]) / 2), np.arange(1, sz[1]+1) - np.floor(float(sz[1]) / 2)) 47 | d = x ** 2 + y ** 2 48 | g = np.exp(-0.5 / (sigma ** 2) * d) 49 | g = np.roll(g, int(-np.floor(float(sz[0]) / 2.) + 1), axis=0) 50 | g = np.roll(g, int(-np.floor(float(sz[1]) / 2.) + 1), axis=1) 51 | return g.astype(np.float32) 52 | 53 | 54 | class TrackerConfig(object): 55 | crop_sz = 125 56 | output_sz = 121 57 | 58 | lambda0 = 1e-4 59 | padding = 2.0 60 | output_sigma_factor = 0.1 61 | 62 | output_sigma = crop_sz / (1 + padding) * output_sigma_factor 63 | y = gaussian_shaped_labels(output_sigma, [output_sz, output_sz]) 64 | yf = torch.rfft(torch.Tensor(y).view(1, 1, output_sz, output_sz).cuda(), signal_ndim=2) 65 | # cos_window = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda() # train without cos window 66 | 67 | 68 | config = TrackerConfig() 69 | 70 | model = DCFNet(config=config) 71 | model.cuda() 72 | gpu_num = torch.cuda.device_count() 73 | print('GPU NUM: {:2d}'.format(gpu_num)) 74 | if gpu_num > 1: 75 | model = torch.nn.DataParallel(model, list(range(gpu_num))).cuda() 76 | 77 | criterion = nn.MSELoss(size_average=False).cuda() 78 | 79 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 80 | momentum=args.momentum, 81 | weight_decay=args.weight_decay) 82 | 83 | target = torch.Tensor(config.y).cuda().unsqueeze(0).unsqueeze(0).repeat(args.batch_size * gpu_num, 1, 1, 1) # for training 84 | # optionally resume from a checkpoint 85 | if args.resume: 86 | if isfile(args.resume): 87 | print("=> loading checkpoint '{}'".format(args.resume)) 88 | checkpoint = torch.load(args.resume) 89 | args.start_epoch = checkpoint['epoch'] 90 | best_loss = checkpoint['best_loss'] 91 | model.load_state_dict(checkpoint['state_dict']) 92 | optimizer.load_state_dict(checkpoint['optimizer']) 93 | print("=> loaded checkpoint '{}' (epoch {})" 94 | .format(args.resume, checkpoint['epoch'])) 95 | else: 96 | print("=> no checkpoint found at '{}'".format(args.resume)) 97 | 98 | cudnn.benchmark = True 99 | 100 | # training data 101 | crop_base_path = join('dataset', 'crop_{:d}_{:1.1f}'.format(args.input_sz, args.padding)) 102 | if not isdir(crop_base_path): 103 | print('please run gen_training_data.py --output_size {:d} --padding {:.1f}!'.format(args.input_sz, args.padding)) 104 | exit() 105 | 106 | save_path = join(args.save, 'crop_{:d}_{:1.1f}'.format(args.input_sz, args.padding)) 107 | if not isdir(save_path): 108 | makedirs(save_path) 109 | 110 | train_dataset = VID(root=crop_base_path, train=True, range=args.range) 111 | val_dataset = VID(root=crop_base_path, train=False, range=args.range) 112 | 113 | train_loader = torch.utils.data.DataLoader( 114 | train_dataset, batch_size=args.batch_size*gpu_num, shuffle=True, 115 | num_workers=args.workers, pin_memory=True, drop_last=True) 116 | 117 | val_loader = torch.utils.data.DataLoader( 118 | val_dataset, batch_size=args.batch_size*gpu_num, shuffle=False, 119 | num_workers=args.workers, pin_memory=True, drop_last=True) 120 | 121 | 122 | def adjust_learning_rate(optimizer, epoch): 123 | lr = np.logspace(-2, -5, num=args.epochs)[epoch] 124 | for param_group 
in optimizer.param_groups: 125 | param_group['lr'] = lr 126 | 127 | 128 | class AverageMeter(object): 129 | """Computes and stores the average and current value""" 130 | def __init__(self): 131 | self.reset() 132 | 133 | def reset(self): 134 | self.val = 0 135 | self.avg = 0 136 | self.sum = 0 137 | self.count = 0 138 | 139 | def update(self, val, n=1): 140 | self.val = val 141 | self.sum += val * n 142 | self.count += n 143 | self.avg = self.sum / self.count 144 | 145 | 146 | def save_checkpoint(state, is_best, filename=join(save_path, 'checkpoint.pth.tar')): 147 | torch.save(state, filename) 148 | if is_best: 149 | shutil.copyfile(filename, join(save_path, 'model_best.pth.tar')) 150 | 151 | 152 | def train(train_loader, model, criterion, optimizer, epoch): 153 | batch_time = AverageMeter() 154 | data_time = AverageMeter() 155 | losses = AverageMeter() 156 | 157 | # switch to train mode 158 | model.train() 159 | 160 | end = time.time() 161 | for i, (template, search) in enumerate(train_loader): 162 | # measure data loading time 163 | data_time.update(time.time() - end) 164 | 165 | template = template.cuda(non_blocking=True) 166 | search = search.cuda(non_blocking=True) 167 | 168 | # compute output 169 | output = model(template, search) 170 | loss = criterion(output, target)/template.size(0) 171 | 172 | # measure accuracy and record loss 173 | losses.update(loss.item()) 174 | 175 | # compute gradient and do SGD step 176 | optimizer.zero_grad() 177 | loss.backward() 178 | optimizer.step() 179 | 180 | # measure elapsed time 181 | batch_time.update(time.time() - end) 182 | end = time.time() 183 | 184 | if i % args.print_freq == 0: 185 | print('Epoch: [{0}][{1}/{2}]\t' 186 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 187 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 188 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 189 | epoch, i, len(train_loader), batch_time=batch_time, 190 | data_time=data_time, loss=losses)) 191 | 192 | 193 | def validate(val_loader, model, criterion): 194 | batch_time = AverageMeter() 195 | losses = AverageMeter() 196 | 197 | # switch to evaluate mode 198 | model.eval() 199 | 200 | with torch.no_grad(): 201 | end = time.time() 202 | for i, (template, search) in enumerate(val_loader): 203 | 204 | # compute output 205 | template = template.cuda(non_blocking=True) 206 | search = search.cuda(non_blocking=True) 207 | 208 | # compute output 209 | output = model(template, search) 210 | loss = criterion(output, target)/(args.batch_size * gpu_num) 211 | 212 | # measure accuracy and record loss 213 | losses.update(loss.item()) 214 | 215 | # measure elapsed time 216 | batch_time.update(time.time() - end) 217 | end = time.time() 218 | 219 | if i % args.print_freq == 0: 220 | print('Test: [{0}/{1}]\t' 221 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 222 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 223 | i, len(val_loader), batch_time=batch_time, loss=losses)) 224 | 225 | print(' * Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses)) 226 | 227 | return losses.avg 228 | 229 | 230 | for epoch in range(args.start_epoch, args.epochs): 231 | adjust_learning_rate(optimizer, epoch) 232 | 233 | # train for one epoch 234 | train(train_loader, model, criterion, optimizer, epoch) 235 | 236 | # evaluate on validation set 237 | loss = validate(val_loader, model, criterion) 238 | 239 | # remember best loss and save checkpoint 240 | is_best = loss < best_loss 241 | best_loss = min(best_loss, loss) 242 | save_checkpoint({ 243 | 'epoch': epoch + 1, 244 | 
'state_dict': model.state_dict(), 245 | 'best_loss': best_loss, 246 | 'optimizer': optimizer.state_dict(), 247 | }, is_best) 248 | --------------------------------------------------------------------------------
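Both `track/net.py` and `train/net.py` solve the correlation filter in closed form in the Fourier domain: `alphaf = yf / (kzzf + lambda0)`, with the Gaussian label `y` rolled so its peak sits at index (0, 0). The sketch below is illustrative only: plain `numpy.fft` stands in for the `torch.rfft`/`torch.irfft` calls of the PyTorch 0.4 code, and a single-channel random feature map replaces the conv features. It reproduces the closed form and shows that when the search patch equals the template the response peaks at (0, 0), which is why `DCFNet.py` wraps `r_max`/`c_max` values beyond `net_input_size/2` back to negative offsets.

```python
import numpy as np


def gaussian_shaped_labels(sigma, sz):
    # same construction as in track/util.py and train_DCFNet.py: peak rolled to (0, 0)
    x, y = np.meshgrid(np.arange(1, sz[0] + 1) - np.floor(float(sz[0]) / 2),
                       np.arange(1, sz[1] + 1) - np.floor(float(sz[1]) / 2))
    g = np.exp(-0.5 / (sigma ** 2) * (x ** 2 + y ** 2))
    g = np.roll(g, int(-np.floor(float(sz[0]) / 2.) + 1), axis=0)
    return np.roll(g, int(-np.floor(float(sz[1]) / 2.) + 1), axis=1)


sz, lambda0 = 121, 1e-4                                 # output_sz and lambda0 from TrackerConfig
y = gaussian_shaped_labels(125 / 3. * 0.1, [sz, sz])    # output_sigma = crop_sz/(1+padding)*0.1
z = np.random.randn(sz, sz)                             # stand-in template feature map (one channel)
x = z.copy()                                            # search features identical to the template

yf, zf, xf = np.fft.fft2(y), np.fft.fft2(z), np.fft.fft2(x)
alphaf = yf / (np.abs(zf) ** 2 + lambda0)               # kzzf + lambda0 in net.py (single channel)
response = np.real(np.fft.ifft2(xf * np.conj(zf) * alphaf))   # complex_mulconj followed by irfft

print(np.unravel_index(response.argmax(), response.shape))    # (0, 0): peak at the origin
print(np.abs(response - y).max())                       # tiny: response recovers the Gaussian label
```

The multi-channel case in `net.py` only adds a sum over feature channels inside `kzzf` and `kxzf`; the per-frequency division by `kzzf + lambda0` is unchanged.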