├── .gitignore
├── DCFNet-JCST.pdf
├── LICENSE
├── README.md
├── track
│   ├── DCFNet.py
│   ├── dataset
│   │   ├── OTB2015.json
│   │   └── gen_otb2013.py
│   ├── eval_otb.py
│   ├── net.py
│   ├── net_param.mat
│   ├── param.pth
│   ├── tune_otb.py
│   └── util.py
└── train
    ├── dataset.py
    ├── dataset
    │   ├── compute-image-mean.py
    │   ├── crop_image.py
    │   ├── gen_snippet.py
    │   └── parse_vid.py
    ├── net.py
    └── train_DCFNet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/DCFNet-JCST.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolwood/DCFNet_pytorch/b8434baa2d136df8f55c1addb3e77f40b3c379fc/DCFNet-JCST.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Qiang Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCFNet_pytorch([JCST](https://jcst.ict.ac.cn/en/article/doi/10.1007/s11390-023-3788-3))
2 |
3 | [🔥News🔥] DCFNet has been accepted by JCST. If you find [**DCFNet**](https://arxiv.org/pdf/1704.04057.pdf) useful in your research, please consider citing:
4 |
5 | ```
6 | @Article{JCST-2309-13788,
7 | title = {DCFNet: Discriminant Correlation Filters Network for Visual Tracking},
8 | journal = {Journal of Computer Science and Technology},
9 | year = {2023},
10 | issn = {1000-9000(Print) /1860-4749(Online)},
11 | doi = {10.1007/s11390-023-3788-3},
12 | author = {Wei-Ming Hu and Qiang Wang and Jin Gao and Bing Li and Stephen Maybank}
13 | }
14 | ```
15 |
16 |
17 |
18 | This repository contains a Python (PyTorch) *reimplementation* of the [**DCFNet**](https://arxiv.org/pdf/1704.04057.pdf) tracker.
19 |
20 | ### Why a Python (PyTorch) implementation?
21 | 
22 | - The magical **Autograd** mechanism of PyTorch: no need to derive the complicated backpropagation by hand.
23 | - Fast Fourier Transform (**FFT**) support in PyTorch 0.4.0 (a minimal sketch follows below).
24 | - Engineering demand.
25 | - Fast test speed (**120 FPS** on a GTX 1060) and **multi-GPU** training.
26 |
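For readers unfamiliar with the PyTorch 0.4.0 FFT API that this repo builds on, here is a minimal, self-contained sketch of FFT-based circular cross-correlation with toy tensors (not the real DCFNet features); it mirrors `complex_mulconj` in `track/net.py`:

```python
import torch

# Toy template / search feature maps: (batch, channel, height, width)
z = torch.randn(1, 1, 8, 8)
x = torch.randn(1, 1, 8, 8)

# Real-to-complex FFT; onesided=False keeps the full spectrum, so the spatial
# size is unchanged and a trailing [real, imag] dimension is appended.
zf = torch.rfft(z, signal_ndim=2, onesided=False)   # (1, 1, 8, 8, 2)
xf = torch.rfft(x, signal_ndim=2, onesided=False)

# Multiply xf by the complex conjugate of zf = circular cross-correlation.
real = xf[..., 0] * zf[..., 0] + xf[..., 1] * zf[..., 1]
imag = xf[..., 1] * zf[..., 0] - xf[..., 0] * zf[..., 1]
response = torch.irfft(torch.stack((real, imag), -1), signal_ndim=2, onesided=False)
print(response.shape)  # torch.Size([1, 1, 8, 8])
```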
27 | ### Contents
28 | 1. [Requirements](#requirements)
29 | 2. [Test](#test)
30 | 3. [Train](#train)
31 | 4. [Citing DCFNet](#citing-dcfnet)
32 |
33 | ## Requirements
34 |
35 | ```shell
36 | git clone --depth=1 https://github.com/foolwood/DCFNet_pytorch
37 | ```
38 |
39 | Requires **PyTorch 0.4.0** and opencv-python:
40 |
41 | ```shell
42 | conda install pytorch torchvision -c pytorch
43 | conda install -c menpo opencv
44 | ```
45 |
46 | You will also need the training data (ILSVRC2015 VID) and the test dataset (OTB).
47 |
48 | ## Test
49 |
50 | ```shell
51 | cd DCFNet_pytorch/track
52 | ln -s /path/to/your/OTB2015 ./dataset/OTB2015
53 | ln -s ./OTB2015 ./dataset/OTB2013
54 | cd dataset && python gen_otb2013.py && cd ..
55 | python DCFNet.py
56 | ```
57 |
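Beyond the benchmark script, `track/DCFNet.py` also defines a `DCFNetTraker` class that can be run on your own frames. A minimal sketch (run from `DCFNet_pytorch/track`, CUDA required; the frame paths and initial box below are placeholders):

```python
import cv2
from DCFNet import DCFNetTraker, TrackerConfig

frames = ['frames/0001.jpg', 'frames/0002.jpg', 'frames/0003.jpg']  # placeholder paths
init_rect = [100, 150, 60, 80]  # [x, y, w, h] of the target in the first frame (1-indexed)

tracker = DCFNetTraker(cv2.imread(frames[0]), init_rect, config=TrackerConfig())
for frame in frames[1:]:
    rect = tracker.track(cv2.imread(frame))  # returns [x, y, w, h] for this frame
    print(rect)
```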
58 | ## Train
59 |
60 | 1. Download training data. ([**ILSVRC2015 VID**](http://bvisionweb1.cs.unc.edu/ilsvrc2015/download-videos-3j16.php#vid))
61 |
62 | ```
63 | ./ILSVRC2015
64 | ├── Annotations
65 | │ └── VID├── a -> ./ILSVRC2015_VID_train_0000
66 | │ ├── b -> ./ILSVRC2015_VID_train_0001
67 | │ ├── c -> ./ILSVRC2015_VID_train_0002
68 | │ ├── d -> ./ILSVRC2015_VID_train_0003
69 | │ ├── e -> ./val
70 | │ ├── ILSVRC2015_VID_train_0000
71 | │ ├── ILSVRC2015_VID_train_0001
72 | │ ├── ILSVRC2015_VID_train_0002
73 | │ ├── ILSVRC2015_VID_train_0003
74 | │ └── val
75 | ├── Data
76 | │ └── VID...........same as Annotations
77 | └── ImageSets
78 | └── VID
79 | ```
80 |
81 | 2. Prepare training data for `dataloader`.
82 |
83 | ```shell
84 | cd DCFNet_pytorch/train/dataset
85 | python parse_vid.py # save all vid info in a single json
86 | python gen_snippet.py # generate snippets
87 | python crop_image.py # crop and generate a json for dataloader
88 | ```
89 |
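After these three scripts finish, `dataset/dataset.json` and the cropped patches under `dataset/crop_125_2.0/` are what `train/dataset.py` consumes. Before moving on to step 3, you can optionally sanity-check that one (template, search) pair loads as expected (a sketch, run from `DCFNet_pytorch/train`):

```python
from dataset import VID  # train/dataset.py

# Reads dataset/dataset.json and crops from dataset/crop_125_2.0 by default.
data = VID(train=True, range=10)
template, search = data[0]  # CHW float32 arrays, mean-subtracted 125x125 patches
print(len(data), template.shape, search.shape)
```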
90 | 3. Training. (on multiple ***GPUs*** :zap: :zap: :zap: :zap:)
91 |
92 | ```
93 | cd DCFNet_pytorch/train/
94 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train_DCFNet.py
95 | ```
96 |
97 |
98 | ## Fine-tune hyper-parameter
99 |
100 | 1. After training, you can simply test the model with the default hyper-parameters.
101 |
102 | ```shell
103 | cd DCFNet_pytorch/track/
104 | python DCFNet.py --model ../train/work/crop_125_2.0/checkpoint.pth.tar
105 | ```
106 |
107 | 2. Search for better hyper-parameters.
108 |
109 | ```shell
110 | CUDA_VISIBLE_DEVICES=0 python tune_otb.py # run several instances in parallel to speed up the search
111 | python eval_otb.py OTB2013 '*' 0 10000
112 | ```
113 |
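`eval_otb.py` takes four positional arguments: the dataset name, a glob pattern matched against sub-directories of `result/<dataset>/`, and a start/end slice over the matched trackers. The same evaluation can also be called from Python (a sketch, run from `DCFNet_pytorch/track`; the glob below assumes results written by `tune_otb.py`, which prefixes its result directories with `param.pth`):

```python
from eval_otb import eval_auc

# Evaluate every result directory under result/OTB2013/ whose name starts with 'param.pth'
eval_auc(dataset='OTB2013', tracker_reg='param.pth*', start=0, end=10000)
```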
114 | ## Citing DCFNet
115 |
116 | If you find [**DCFNet**](https://arxiv.org/pdf/1704.04057.pdf) useful in your research, please consider citing:
117 |
118 | ```
119 | @article{wang2017dcfnet,
120 | title={DCFNet: Discriminant Correlation Filters Network for Visual Tracking},
121 | author={Wang, Qiang and Gao, Jin and Xing, Junliang and Zhang, Mengdan and Hu, Weiming},
122 | journal={arXiv preprint arXiv:1704.04057},
123 | year={2017}
124 | }
125 | ```
126 |
--------------------------------------------------------------------------------
/track/DCFNet.py:
--------------------------------------------------------------------------------
1 | from os.path import join, isdir
2 | from os import makedirs
3 | import argparse
4 | import json
5 | import numpy as np
6 | import torch
7 |
8 | import cv2
9 | import time
10 | from util import crop_chw, gaussian_shaped_labels, cxy_wh_2_rect1, rect1_2_cxy_wh, cxy_wh_2_bbox
11 | from net import DCFNet
12 | from eval_otb import eval_auc
13 |
14 |
15 | class TrackerConfig(object):
16 | # These are the default hyper-params for DCFNet
17 | # OTB2013 / AUC(0.665)
18 | feature_path = 'param.pth'
19 | crop_sz = 125
20 |
21 | lambda0 = 1e-4
22 | padding = 2
23 | output_sigma_factor = 0.1
24 | interp_factor = 0.01
25 | num_scale = 3
26 | scale_step = 1.0275
27 | scale_factor = scale_step ** (np.arange(num_scale) - num_scale / 2)
28 | min_scale_factor = 0.2
29 | max_scale_factor = 5
30 | scale_penalty = 0.9925
31 | scale_penalties = scale_penalty ** (np.abs((np.arange(num_scale) - num_scale / 2)))
32 |
33 | net_input_size = [crop_sz, crop_sz]
34 | net_average_image = np.array([104, 117, 123]).reshape(-1, 1, 1).astype(np.float32)
35 | output_sigma = crop_sz / (1 + padding) * output_sigma_factor
36 | y = gaussian_shaped_labels(output_sigma, net_input_size)
37 | yf = torch.rfft(torch.Tensor(y).view(1, 1, crop_sz, crop_sz).cuda(), signal_ndim=2)
38 | cos_window = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda()
39 |
40 |
41 | class DCFNetTraker(object):
42 | def __init__(self, im, init_rect, config=TrackerConfig(), gpu=True):
43 | self.gpu = gpu
44 | self.config = config
45 | self.net = DCFNet(config)
46 | self.net.load_param(config.feature_path)
47 | self.net.eval()
48 | if gpu:
49 | self.net.cuda()
50 |
51 | # confine results
52 | target_pos, target_sz = rect1_2_cxy_wh(init_rect)
53 | self.min_sz = np.maximum(config.min_scale_factor * target_sz, 4)
54 | self.max_sz = np.minimum(im.shape[:2], config.max_scale_factor * target_sz)
55 |
56 | # crop template
57 | window_sz = target_sz * (1 + config.padding)
58 | bbox = cxy_wh_2_bbox(target_pos, window_sz)
59 | patch = crop_chw(im, bbox, self.config.crop_sz)
60 |
61 | target = patch - config.net_average_image
62 | self.net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda())
63 | self.target_pos, self.target_sz = target_pos, target_sz
64 | self.patch_crop = np.zeros((config.num_scale, patch.shape[0], patch.shape[1], patch.shape[2]), np.float32) # buff
65 |
66 | def track(self, im):
67 | for i in range(self.config.num_scale): # crop multi-scale search region
68 | window_sz = self.target_sz * (self.config.scale_factor[i] * (1 + self.config.padding))
69 | bbox = cxy_wh_2_bbox(self.target_pos, window_sz)
70 | self.patch_crop[i, :] = crop_chw(im, bbox, self.config.crop_sz)
71 |
72 | search = self.patch_crop - self.config.net_average_image
73 |
74 | if self.gpu:
75 | response = self.net(torch.Tensor(search).cuda())
76 | else:
77 | response = self.net(torch.Tensor(search))
78 | peak, idx = torch.max(response.view(self.config.num_scale, -1), 1)
79 | peak = peak.data.cpu().numpy() * self.config.scale_penalties
80 | best_scale = np.argmax(peak)
81 | r_max, c_max = np.unravel_index(idx[best_scale], self.config.net_input_size)
82 |
83 | if r_max > self.config.net_input_size[0] / 2:
84 | r_max = r_max - self.config.net_input_size[0]
85 | if c_max > self.config.net_input_size[1] / 2:
86 | c_max = c_max - self.config.net_input_size[1]
87 | window_sz = self.target_sz * (self.config.scale_factor[best_scale] * (1 + self.config.padding))
88 |
89 | self.target_pos = self.target_pos + np.array([c_max, r_max]) * window_sz / self.config.net_input_size
90 | self.target_sz = np.minimum(np.maximum(window_sz / (1 + self.config.padding), self.min_sz), self.max_sz)
91 |
92 | # model update
93 | window_sz = self.target_sz * (1 + self.config.padding)
94 | bbox = cxy_wh_2_bbox(self.target_pos, window_sz)
95 | patch = crop_chw(im, bbox, self.config.crop_sz)
96 | target = patch - self.config.net_average_image
97 | self.net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda(), lr=self.config.interp_factor)
98 |
99 | return cxy_wh_2_rect1(self.target_pos, self.target_sz) # 1-index
100 |
101 |
102 | if __name__ == '__main__':
103 | # base dataset path and setting
104 | parser = argparse.ArgumentParser(description='Test DCFNet on OTB')
105 | parser.add_argument('--dataset', metavar='SET', default='OTB2013',
106 | choices=['OTB2013', 'OTB2015'], help='tune on which dataset')
107 | parser.add_argument('--model', metavar='PATH', default='param.pth')
108 | args = parser.parse_args()
109 |
110 | dataset = args.dataset
111 | base_path = join('dataset', dataset)
112 | json_path = join('dataset', dataset + '.json')
113 | annos = json.load(open(json_path, 'r'))
114 | videos = sorted(annos.keys())
115 |
116 | use_gpu = True
117 | visualization = False
118 |
119 | # default parameter and load feature extractor network
120 | config = TrackerConfig()
121 | net = DCFNet(config)
122 | net.load_param(args.model)
123 | net.eval().cuda()
124 |
125 | speed = []
126 | # loop videos
127 | for video_id, video in enumerate(videos): # run without resetting
128 | video_path_name = annos[video]['name']
129 | init_rect = np.array(annos[video]['init_rect']).astype(np.float)
130 | image_files = [join(base_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']]
131 | n_images = len(image_files)
132 |
133 | tic = time.time() # time start
134 |
135 | target_pos, target_sz = rect1_2_cxy_wh(init_rect) # OTB label is 1-indexed
136 |
137 | im = cv2.imread(image_files[0]) # HxWxC
138 |
139 | # confine results
140 | min_sz = np.maximum(config.min_scale_factor * target_sz, 4)
141 | max_sz = np.minimum(im.shape[:2], config.max_scale_factor * target_sz)
142 |
143 | # crop template
144 | window_sz = target_sz * (1 + config.padding)
145 | bbox = cxy_wh_2_bbox(target_pos, window_sz)
146 | patch = crop_chw(im, bbox, config.crop_sz)
147 |
148 | target = patch - config.net_average_image
149 | net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda())
150 |
151 | res = [cxy_wh_2_rect1(target_pos, target_sz)] # save in .txt
152 | patch_crop = np.zeros((config.num_scale, patch.shape[0], patch.shape[1], patch.shape[2]), np.float32)
153 | for f in range(1, n_images): # track
154 | im = cv2.imread(image_files[f])
155 |
156 | for i in range(config.num_scale): # crop multi-scale search region
157 | window_sz = target_sz * (config.scale_factor[i] * (1 + config.padding))
158 | bbox = cxy_wh_2_bbox(target_pos, window_sz)
159 | patch_crop[i, :] = crop_chw(im, bbox, config.crop_sz)
160 |
161 | search = patch_crop - config.net_average_image
162 | response = net(torch.Tensor(search).cuda())
163 | peak, idx = torch.max(response.view(config.num_scale, -1), 1)
164 | peak = peak.data.cpu().numpy() * config.scale_penalties
165 | best_scale = np.argmax(peak)
166 | r_max, c_max = np.unravel_index(idx[best_scale], config.net_input_size)
167 |
168 | if r_max > config.net_input_size[0] / 2:
169 | r_max = r_max - config.net_input_size[0]
170 | if c_max > config.net_input_size[1] / 2:
171 | c_max = c_max - config.net_input_size[1]
172 | window_sz = target_sz * (config.scale_factor[best_scale] * (1 + config.padding))
173 |
174 | target_pos = target_pos + np.array([c_max, r_max]) * window_sz / config.net_input_size
175 | target_sz = np.minimum(np.maximum(window_sz / (1 + config.padding), min_sz), max_sz)
176 |
177 | # model update
178 | window_sz = target_sz * (1 + config.padding)
179 | bbox = cxy_wh_2_bbox(target_pos, window_sz)
180 | patch = crop_chw(im, bbox, config.crop_sz)
181 | target = patch - config.net_average_image
182 | net.update(torch.Tensor(np.expand_dims(target, axis=0)).cuda(), lr=config.interp_factor)
183 |
184 | res.append(cxy_wh_2_rect1(target_pos, target_sz)) # 1-index
185 |
186 | if visualization:
187 | im_show = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
188 | cv2.rectangle(im_show, (int(target_pos[0] - target_sz[0] / 2), int(target_pos[1] - target_sz[1] / 2)),
189 | (int(target_pos[0] + target_sz[0] / 2), int(target_pos[1] + target_sz[1] / 2)),
190 | (0, 255, 0), 3)
191 | cv2.putText(im_show, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2, cv2.LINE_AA)
192 | cv2.imshow(video, im_show)
193 | cv2.waitKey(1)
194 |
195 | toc = time.time() - tic
196 | fps = n_images / toc
197 | speed.append(fps)
198 | print('{:3d} Video: {:12s} Time: {:3.1f}s\tSpeed: {:3.1f}fps'.format(video_id, video, toc, fps))
199 |
200 | # save result
201 | test_path = join('result', dataset, 'DCFNet_test')
202 | if not isdir(test_path): makedirs(test_path)
203 | result_path = join(test_path, video + '.txt')
204 | with open(result_path, 'w') as f:
205 | for x in res:
206 | f.write(','.join(['{:.2f}'.format(i) for i in x]) + '\n')
207 |
208 | print('***Total Mean Speed: {:3.1f} (FPS)***'.format(np.mean(speed)))
209 |
210 | eval_auc(dataset, 'DCFNet_test', 0, 1)
211 |
--------------------------------------------------------------------------------
/track/dataset/gen_otb2013.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | OTB2015 = json.load(open('OTB2015.json', 'r'))
4 | videos = OTB2015.keys()
5 |
6 | OTB2013 = dict()
7 | for v in videos:
8 | if v in ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
9 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
10 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
11 | 'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
12 | 'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
13 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2']:
14 | OTB2013[v] = OTB2015[v]
15 |
16 |
17 | json.dump(OTB2013, open('OTB2013.json', 'w'), indent=2)
18 |
19 |
--------------------------------------------------------------------------------
/track/eval_otb.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import os
4 | import glob
5 | from os.path import join as fullfile
6 | import numpy as np
7 |
8 |
9 | def overlap_ratio(rect1, rect2):
10 | '''
11 | Compute overlap ratio between two rects
12 | - rect: 1d array of [x,y,w,h] or
13 | 2d array of N x [x,y,w,h]
14 | '''
15 |
16 | if rect1.ndim==1:
17 | rect1 = rect1[None,:]
18 | if rect2.ndim==1:
19 | rect2 = rect2[None,:]
20 |
21 | left = np.maximum(rect1[:,0], rect2[:,0])
22 | right = np.minimum(rect1[:,0]+rect1[:,2], rect2[:,0]+rect2[:,2])
23 | top = np.maximum(rect1[:,1], rect2[:,1])
24 | bottom = np.minimum(rect1[:,1]+rect1[:,3], rect2[:,1]+rect2[:,3])
25 |
26 | intersect = np.maximum(0,right - left) * np.maximum(0,bottom - top)
27 | union = rect1[:,2]*rect1[:,3] + rect2[:,2]*rect2[:,3] - intersect
28 | iou = np.clip(intersect / union, 0, 1)
29 | return iou
30 |
31 |
32 | def compute_success_overlap(gt_bb, result_bb):
33 | thresholds_overlap = np.arange(0, 1.05, 0.05)
34 | n_frame = len(gt_bb)
35 | success = np.zeros(len(thresholds_overlap))
36 | iou = overlap_ratio(gt_bb, result_bb)
37 | for i in range(len(thresholds_overlap)):
38 | success[i] = sum(iou > thresholds_overlap[i]) / float(n_frame)
39 | return success
40 |
41 |
42 | def compute_success_error(gt_center, result_center):
43 | thresholds_error = np.arange(0, 51, 1)
44 | n_frame = len(gt_center)
45 | success = np.zeros(len(thresholds_error))
46 | dist = np.sqrt(np.sum(np.power(gt_center - result_center, 2), axis=1))
47 | for i in range(len(thresholds_error)):
48 | success[i] = sum(dist <= thresholds_error[i]) / float(n_frame)
49 | return success
50 |
51 |
52 | def get_result_bb(arch, seq):
53 | result_path = fullfile(arch, seq + '.txt')
54 | temp = np.loadtxt(result_path, delimiter=',').astype(np.float)
55 | return np.array(temp)
56 |
57 |
58 | def convert_bb_to_center(bboxes):
59 | return np.array([(bboxes[:, 0] + (bboxes[:, 2] - 1) / 2),
60 | (bboxes[:, 1] + (bboxes[:, 3] - 1) / 2)]).T
61 |
62 |
63 | def eval_auc(dataset='OTB2015', tracker_reg='S*', start=0, end=1e6):
64 | list_path = os.path.join('dataset', dataset + '.json')
65 | annos = json.load(open(list_path, 'r'))
66 |     seqs = list(annos.keys())
67 |
68 | OTB2013 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
69 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
70 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
71 | 'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
72 | 'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
73 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2']
74 |
75 | OTB2015 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
76 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
77 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
78 | 'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
79 | 'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
80 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2', 'clifBar', 'biker', 'bird1', 'blurBody',
81 | 'blurCar2', 'blurFace', 'blurOwl', 'box', 'car1', 'crowds', 'diving', 'dragonBaby', 'human3', 'human4_2',
82 | 'human6', 'human9', 'jump', 'panda', 'redTeam', 'skating2_1', 'skating2_2', 'surfer', 'bird2',
83 | 'blurCar1', 'blurCar3', 'blurCar4', 'board', 'bolt2', 'car2', 'car24', 'coupon', 'dancer', 'dancer2',
84 | 'dog', 'girl2', 'gym', 'human2', 'human5', 'human7', 'human8', 'kiteSurf', 'man', 'rubik', 'skater',
85 | 'skater2', 'toy', 'trans', 'twinnings', 'vase']
86 |
87 | trackers = glob.glob(fullfile('result', dataset, tracker_reg))
88 | trackers = trackers[start:min(end, len(trackers))]
89 |
90 | n_seq = len(seqs)
91 | thresholds_overlap = np.arange(0, 1.05, 0.05)
92 | # thresholds_error = np.arange(0, 51, 1)
93 |
94 | success_overlap = np.zeros((n_seq, len(trackers), len(thresholds_overlap)))
95 | # success_error = np.zeros((n_seq, len(trackers), len(thresholds_error)))
96 | for i in range(n_seq):
97 | seq = seqs[i]
98 | gt_rect = np.array(annos[seq]['gt_rect']).astype(np.float)
99 | gt_center = convert_bb_to_center(gt_rect)
100 | for j in range(len(trackers)):
101 | tracker = trackers[j]
102 | print('{:d} processing:{} tracker: {}'.format(i, seq, tracker))
103 | bb = get_result_bb(tracker, seq)
104 | center = convert_bb_to_center(bb)
105 | success_overlap[i][j] = compute_success_overlap(gt_rect, bb)
106 | # success_error[i][j] = compute_success_error(gt_center, center)
107 |
108 | print('Success Overlap')
109 |
110 | if 'OTB2015' == dataset:
111 | OTB2013_id = []
112 | for i in range(n_seq):
113 | if seqs[i] in OTB2013:
114 | OTB2013_id.append(i)
115 | max_auc_OTB2013 = 0.
116 | max_name_OTB2013 = ''
117 | for i in range(len(trackers)):
118 | auc = success_overlap[OTB2013_id, i, :].mean()
119 | if auc > max_auc_OTB2013:
120 | max_auc_OTB2013 = auc
121 | max_name_OTB2013 = trackers[i]
122 | print('%s(%.4f)' % (trackers[i], auc))
123 |
124 | max_auc = 0.
125 | max_name = ''
126 | for i in range(len(trackers)):
127 | auc = success_overlap[:, i, :].mean()
128 | if auc > max_auc:
129 | max_auc = auc
130 | max_name = trackers[i]
131 | print('%s(%.4f)' % (trackers[i], auc))
132 |
133 | print('\nOTB2013 Best: %s(%.4f)' % (max_name_OTB2013, max_auc_OTB2013))
134 | print('\nOTB2015 Best: %s(%.4f)' % (max_name, max_auc))
135 | else:
136 | max_auc = 0.
137 | max_name = ''
138 | for i in range(len(trackers)):
139 | auc = success_overlap[:, i, :].mean()
140 | if auc > max_auc:
141 | max_auc = auc
142 | max_name = trackers[i]
143 | print('%s(%.4f)' % (trackers[i], auc))
144 |
145 | print('\n%s Best: %s(%.4f)' % (dataset, max_name, max_auc))
146 |
147 |
148 | if __name__ == "__main__":
149 | if len(sys.argv) < 5:
150 | print('python eval_otb.py OTB2015 DCFNet_test* 0 10')
151 | exit()
152 | dataset = sys.argv[1]
153 | tracker_reg = sys.argv[2]
154 | start = int(sys.argv[3])
155 | end = int(sys.argv[4])
156 | eval_auc(dataset, tracker_reg, start, end)
157 |
--------------------------------------------------------------------------------
/track/net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch # pytorch 0.4.0! fft
3 | import numpy as np
4 | import cv2
5 |
6 |
7 | def complex_mul(x, z):
8 | out_real = x[..., 0] * z[..., 0] - x[..., 1] * z[..., 1]
9 | out_imag = x[..., 0] * z[..., 1] + x[..., 1] * z[..., 0]
10 | return torch.stack((out_real, out_imag), -1)
11 |
12 |
13 | def complex_mulconj(x, z):
14 | out_real = x[..., 0] * z[..., 0] + x[..., 1] * z[..., 1]
15 | out_imag = x[..., 1] * z[..., 0] - x[..., 0] * z[..., 1]
16 | return torch.stack((out_real, out_imag), -1)
17 |
18 |
19 | class DCFNetFeature(nn.Module):
20 | def __init__(self):
21 | super(DCFNetFeature, self).__init__()
22 | self.feature = nn.Sequential(
23 | nn.Conv2d(3, 32, 3, padding=1),
24 | nn.ReLU(inplace=True),
25 | nn.Conv2d(32, 32, 3, padding=1),
26 | nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1),
27 | )
28 |
29 | def forward(self, x):
30 | return self.feature(x)
31 |
32 |
33 | class DCFNet(nn.Module):
34 | def __init__(self, config=None):
35 | super(DCFNet, self).__init__()
36 | self.feature = DCFNetFeature()
37 | self.model_alphaf = []
38 | self.model_xf = []
39 | self.config = config
40 |
41 | def forward(self, x):
42 | x = self.feature(x) * self.config.cos_window
43 | xf = torch.rfft(x, signal_ndim=2)
44 | kxzf = torch.sum(complex_mulconj(xf, self.model_zf), dim=1, keepdim=True)
45 | response = torch.irfft(complex_mul(kxzf, self.model_alphaf), signal_ndim=2)
46 | # r_max = torch.max(response)
47 | # cv2.imshow('response', response[0, 0].data.cpu().numpy())
48 | # cv2.waitKey(0)
49 | return response
50 |
51 | def update(self, z, lr=1.):
52 | z = self.feature(z) * self.config.cos_window
53 | zf = torch.rfft(z, signal_ndim=2)
54 | kzzf = torch.sum(torch.sum(zf ** 2, dim=4, keepdim=True), dim=1, keepdim=True)
55 | alphaf = self.config.yf / (kzzf + self.config.lambda0)
56 | if lr > 0.99:
57 | self.model_alphaf = alphaf
58 | self.model_zf = zf
59 | else:
60 | self.model_alphaf = (1 - lr) * self.model_alphaf.data + lr * alphaf.data
61 | self.model_zf = (1 - lr) * self.model_zf.data + lr * zf.data
62 |
63 | def load_param(self, path='param.pth'):
64 | checkpoint = torch.load(path)
65 | if 'state_dict' in checkpoint.keys(): # from training result
66 | state_dict = checkpoint['state_dict']
67 |             if 'module' in list(state_dict.keys())[0]:  # train with nn.DataParallel
68 | from collections import OrderedDict
69 | new_state_dict = OrderedDict()
70 | for k, v in state_dict.items():
71 | name = k[7:] # remove `module.`
72 | new_state_dict[name] = v
73 | self.load_state_dict(new_state_dict)
74 | else:
75 | self.load_state_dict(state_dict)
76 | else:
77 | self.feature.load_state_dict(checkpoint)
78 |
79 |
80 | if __name__ == '__main__':
81 |
82 | # network test
83 | net = DCFNetFeature()
84 | net.eval()
85 | for idx, m in enumerate(net.modules()):
86 | print(idx, '->', m)
87 | for name, param in net.named_parameters():
88 | if 'bias' in name or 'weight' in name:
89 | print(param.size())
90 | from scipy import io
91 | import numpy as np
92 | p = io.loadmat('net_param.mat')
93 | x = p['res'][0][0][:,:,::-1].copy()
94 | x_out = p['res'][0][-1]
95 | from collections import OrderedDict
96 | pth_state_dict = OrderedDict()
97 |
98 | match_dict = dict()
99 | match_dict['feature.0.weight'] = 'conv1_w'
100 | match_dict['feature.0.bias'] = 'conv1_b'
101 | match_dict['feature.2.weight'] = 'conv2_w'
102 | match_dict['feature.2.bias'] = 'conv2_b'
103 |
104 | for var_name in net.state_dict().keys():
105 |         print(var_name)
106 | key_in_model = match_dict[var_name]
107 | param_in_model = var_name.rsplit('.', 1)[1]
108 | if 'weight' in var_name:
109 | pth_state_dict[var_name] = torch.Tensor(np.transpose(p[key_in_model],(3,2,0,1)))
110 | elif 'bias' in var_name:
111 | pth_state_dict[var_name] = torch.Tensor(np.squeeze(p[key_in_model]))
112 | if var_name == 'feature.0.weight':
113 | weight = pth_state_dict[var_name].data.numpy()
114 | weight = weight[:, ::-1, :, :].copy() # cv2 bgr input
115 | pth_state_dict[var_name] = torch.Tensor(weight)
116 |
117 |
118 | torch.save(pth_state_dict, 'param.pth')
119 | net.load_state_dict(torch.load('param.pth'))
120 | x_t = torch.Tensor(np.expand_dims(np.transpose(x,(2,0,1)), axis=0))
121 | x_pred = net(x_t).data.numpy()
122 | pred_error = np.sum(np.abs(np.transpose(x_pred,(0,2,3,1)).reshape(-1) - x_out.reshape(-1)))
123 |
124 | x_fft = torch.rfft(x_t, signal_ndim=2, onesided=False)
125 |
126 |
127 | print('model_transfer_error:{:.5f}'.format(pred_error))
128 |
129 |
130 |
--------------------------------------------------------------------------------
/track/net_param.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolwood/DCFNet_pytorch/b8434baa2d136df8f55c1addb3e77f40b3c379fc/track/net_param.mat
--------------------------------------------------------------------------------
/track/param.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foolwood/DCFNet_pytorch/b8434baa2d136df8f55c1addb3e77f40b3c379fc/track/param.pth
--------------------------------------------------------------------------------
/track/tune_otb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | from os import makedirs
5 | from os.path import isfile, isdir, join
6 | from util import cxy_wh_2_rect1
7 | import torch
8 | import json
9 | from DCFNet import *
10 |
11 | parser = argparse.ArgumentParser(description='Tune parameters for DCFNet tracker on OTB2015')
12 | parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
13 | help='whether visualize result')
14 |
15 | args = parser.parse_args()
16 |
17 |
18 | def tune_otb(param):
19 | regions = [] # result and states[1 init / 2 lost / 0 skip]
20 | # save result
21 | benchmark_result_path = join('result', param['dataset'])
22 | tracker_path = join(benchmark_result_path, (param['network_name'] +
23 | '_scale_step_{:.3f}'.format(param['config'].scale_step) +
24 | '_scale_penalty_{:.3f}'.format(param['config'].scale_penalty) +
25 | '_interp_factor_{:.3f}'.format(param['config'].interp_factor)))
26 | result_path = join(tracker_path, '{:s}.txt'.format(param['video']))
27 | if isfile(result_path):
28 | return
29 | if not isdir(tracker_path): makedirs(tracker_path)
30 | with open(result_path, 'w') as f: # Occupation
31 | for x in regions:
32 | f.write('')
33 |
34 | ims = param['ims']
35 | toc = 0
36 | for f, im in enumerate(ims):
37 | tic = cv2.getTickCount()
38 | if f == 0: # init
39 |             init_rect = param['init_rect']
40 | tracker = DCFNetTraker(ims[f], init_rect, config=param['config'])
41 | regions.append(init_rect)
42 | else: # tracking
43 | rect = tracker.track(ims[f])
44 | regions.append(rect)
45 | toc += cv2.getTickCount() - tic
46 |
47 | if args.visualization: # visualization (skip lost frame)
48 | if f == 0: cv2.destroyAllWindows()
49 |             location = [int(l) for l in regions[-1]]  # int
50 | cv2.rectangle(im, (location[0], location[1]), (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3)
51 | cv2.putText(im, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
52 |
53 | cv2.imshow(video, im)
54 | cv2.waitKey(1)
55 | toc /= cv2.getTickFrequency()
56 | print('{:2d} Video: {:12s} Time: {:2.1f}s Speed: {:3.1f}fps'.format(v, video, toc, f / toc))
57 | regions = np.array(regions)
58 | regions[:,:2] += 1 # 1-index
59 | with open(result_path, 'w') as f:
60 | for x in regions:
61 | f.write(','.join(['{:.2f}'.format(i) for i in x]) + '\n')
62 |
63 |
64 | params = {'dataset':['OTB2013'], 'network':['param.pth'],
65 | 'scale_step':np.arange(1.01, 1.05, 0.005, np.float32),
66 | 'scale_penalty':np.arange(0.98, 1.0, 0.025, np.float32),
67 | 'interp_factor':np.arange(0.001, 0.015, 0.001, np.float32)}
68 |
69 | p = dict()
70 | p['config'] = TrackerConfig()
71 | for network in params['network']:
72 | p['network_name'] = network
73 | np.random.shuffle(params['dataset'])
74 | for dataset in params['dataset']:
75 | base_path = join('dataset', dataset)
76 | json_path = join('dataset', dataset+'.json')
77 | annos = json.load(open(json_path, 'r'))
78 |         videos = list(annos.keys())
79 | p['dataset'] = dataset
80 | np.random.shuffle(videos)
81 | for v, video in enumerate(videos):
82 | p['v'] = v
83 | p['video'] = video
84 | video_path_name = annos[video]['name']
85 | init_rect = np.array(annos[video]['init_rect']).astype(np.float)
86 | image_files = [join(base_path, video_path_name, 'img', im_f) for im_f in annos[video]['image_files']]
87 | target_pos = np.array([init_rect[0] + init_rect[2] / 2 -1 , init_rect[1] + init_rect[3] / 2 -1]) # 0-index
88 | target_sz = np.array([init_rect[2], init_rect[3]])
89 | ims = []
90 | for image_file in image_files:
91 | im = cv2.imread(image_file)
92 | if im.shape[2] == 1:
93 |                 im = cv2.cvtColor(im, cv2.COLOR_GRAY2RGB)
94 | ims.append(im)
95 | p['ims'] = ims
96 | p['init_rect'] = init_rect
97 |
98 | np.random.shuffle(params['scale_step'])
99 | np.random.shuffle(params['scale_penalty'])
100 | np.random.shuffle(params['interp_factor'])
101 | for scale_step in params['scale_step']:
102 | for scale_penalty in params['scale_penalty']:
103 | for interp_factor in params['interp_factor']:
104 | p['config'].scale_step = float(scale_step)
105 | p['config'].scale_penalty = float(scale_penalty)
106 | p['config'].interp_factor = float(interp_factor)
107 | tune_otb(p)
108 |
--------------------------------------------------------------------------------
/track/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def cxy_wh_2_rect1(pos, sz):
6 | return np.array([pos[0]-sz[0]/2+1, pos[1]-sz[1]/2+1, sz[0], sz[1]]) # 1-index
7 |
8 |
9 | def rect1_2_cxy_wh(rect):
10 | return np.array([rect[0]+rect[2]/2-1, rect[1]+rect[3]/2-1]), np.array([rect[2], rect[3]]) # 0-index
11 |
12 |
13 | def cxy_wh_2_bbox(cxy, wh):
14 | return np.array([cxy[0]-wh[0]/2, cxy[1]-wh[1]/2, cxy[0]+wh[0]/2, cxy[1]+wh[1]/2]) # 0-index
15 |
16 |
17 | def gaussian_shaped_labels(sigma, sz):
18 | x, y = np.meshgrid(np.arange(1, sz[0]+1) - np.floor(float(sz[0]) / 2), np.arange(1, sz[1]+1) - np.floor(float(sz[1]) / 2))
19 | d = x ** 2 + y ** 2
20 | g = np.exp(-0.5 / (sigma ** 2) * d)
21 | g = np.roll(g, int(-np.floor(float(sz[0]) / 2.) + 1), axis=0)
22 | g = np.roll(g, int(-np.floor(float(sz[1]) / 2.) + 1), axis=1)
23 | return g
24 |
25 |
26 | def crop_chw(image, bbox, out_sz, padding=(0, 0, 0)):
27 | a = (out_sz-1) / (bbox[2]-bbox[0])
28 | b = (out_sz-1) / (bbox[3]-bbox[1])
29 | c = -a * bbox[0]
30 | d = -b * bbox[1]
31 | mapping = np.array([[a, 0, c],
32 | [0, b, d]]).astype(np.float)
33 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding)
34 | return np.transpose(crop, (2, 0, 1))
35 |
36 |
37 | if __name__ == '__main__':
38 | a = gaussian_shaped_labels(10, [5,5])
39 |     print(a)
--------------------------------------------------------------------------------
/train/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from os.path import join
3 | import cv2
4 | import json
5 | import numpy as np
6 |
7 |
8 | class VID(data.Dataset):
9 | def __init__(self, file='dataset/dataset.json', root='dataset/crop_125_2.0', range=10, train=True):
10 | self.imdb = json.load(open(file, 'r'))
11 | self.root = root
12 | self.range = range
13 | self.train = train
14 | self.mean = np.expand_dims(np.expand_dims(np.array([109, 120, 119]), axis=1), axis=1).astype(np.float32)
15 |
16 | def __getitem__(self, item):
17 | if self.train:
18 | target_id = self.imdb['train_set'][item]
19 | else:
20 | target_id = self.imdb['val_set'][item]
21 |
22 | # range_down = self.imdb['down_index'][target_id]
23 | range_up = self.imdb['up_index'][target_id]
24 | # search_id = np.random.randint(-min(range_down, self.range), min(range_up, self.range)) + target_id
25 | search_id = np.random.randint(1, min(range_up, self.range+1)) + target_id
26 |
27 | target = cv2.imread(join(self.root, '{:08d}.jpg'.format(target_id)))
28 | search = cv2.imread(join(self.root, '{:08d}.jpg'.format(search_id)))
29 |
30 | target = np.transpose(target, (2, 0, 1)).astype(np.float32) - self.mean
31 | search = np.transpose(search, (2, 0, 1)).astype(np.float32) - self.mean
32 |
33 | return target, search
34 |
35 | def __len__(self):
36 | if self.train:
37 | return len(self.imdb['train_set'])
38 | else:
39 | return len(self.imdb['val_set'])
40 |
41 |
42 | if __name__ == '__main__':
43 | import matplotlib.pyplot as plt
44 | import matplotlib.patches as patches
45 | data = VID(train=True)
46 | n = len(data)
47 | fig = plt.figure(1)
48 | ax = fig.add_axes([0, 0, 1, 1])
49 |
50 | for i in range(n):
51 | z, x = data[i]
52 | z, x = np.transpose(z, (1, 2, 0)).astype(np.uint8), np.transpose(x, (1, 2, 0)).astype(np.uint8)
53 | zx = np.concatenate((z, x), axis=1)
54 |
55 | ax.imshow(cv2.cvtColor(zx, cv2.COLOR_BGR2RGB))
56 | p = patches.Rectangle(
57 | (125/3, 125/3), 125/3, 125/3, fill=False, clip_on=False, linewidth=2, edgecolor='g')
58 | ax.add_patch(p)
59 | p = patches.Rectangle(
60 | (125 / 3+125, 125 / 3), 125 / 3, 125 / 3, fill=False, clip_on=False, linewidth=2, edgecolor='r')
61 | ax.add_patch(p)
62 | plt.pause(0.5)
63 |
--------------------------------------------------------------------------------
/train/dataset/compute-image-mean.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | import time
5 | import glob
6 |
7 | from skimage import io
8 | import cv2
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--meanPrefix', default='mean_img', type=str, help="Prefix of the mean file.")
13 | parser.add_argument('--imageDir', default='crop_125_2.0', type=str, help="Directory of images to read.")
14 | args = parser.parse_args()
15 |
16 | mean = np.zeros((1, 3, 125, 125))
17 | N = 0
18 | opencv_backend = True
19 | beginTime = time.time()
20 | files = glob.glob(os.path.join(args.imageDir, '*.jpg'))
21 | for file in files:
22 | if opencv_backend:
23 | img = cv2.imread(file)
24 | else:
25 | img = io.imread(file)
26 | if img.shape == (125, 125, 3):
27 | mean[0][0] += img[:, :, 0]
28 | mean[0][1] += img[:, :, 1]
29 | mean[0][2] += img[:, :, 2]
30 | N += 1
31 | if N % 1000 == 0:
32 | elapsed = time.time() - beginTime
33 | print("Processed {} images in {:.2f} seconds. "
34 | "{:.2f} images/second.".format(N, elapsed, N / elapsed))
35 | mean[0] /= N
36 |
37 | meanImg = np.transpose(mean[0].astype(np.uint8), (1, 2, 0))
38 | if opencv_backend:
39 | cv2.imwrite("{}.png".format(args.meanPrefix), meanImg)
40 | else:
41 | io.imsave("{}.png".format(args.meanPrefix), meanImg)
42 |
43 | avg_chans = np.mean(meanImg, axis=(0, 1))
44 | if opencv_backend:
45 | print("image BGR mean: {}".format(avg_chans))
46 | else:
47 | print("image RGB mean: {}".format(avg_chans))
--------------------------------------------------------------------------------
/train/dataset/crop_image.py:
--------------------------------------------------------------------------------
1 | from os.path import join, isdir
2 | from os import mkdir
3 | import argparse
4 | import numpy as np
5 | import json
6 | import cv2
7 | import time
8 |
9 | parse = argparse.ArgumentParser(description='Generate training data (cropped) for DCFNet_pytorch')
10 | parse.add_argument('-v', '--visual', dest='visual', action='store_true', help='whether visualise crop')
11 | parse.add_argument('-o', '--output_size', dest='output_size', default=125, type=int, help='crop output size')
12 | parse.add_argument('-p', '--padding', dest='padding', default=2, type=float, help='crop padding size')
13 |
14 | args = parse.parse_args()
15 |
16 | print(args)
17 |
18 |
19 | def crop_hwc(image, bbox, out_sz, padding=(0, 0, 0)):
20 | bbox = [float(x) for x in bbox]
21 | a = (out_sz-1) / (bbox[2]-bbox[0])
22 | b = (out_sz-1) / (bbox[3]-bbox[1])
23 | c = -a * bbox[0]
24 | d = -b * bbox[1]
25 | mapping = np.array([[a, 0, c],
26 | [0, b, d]]).astype(np.float)
27 | crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding)
28 | return crop
29 |
30 |
31 | def cxy_wh_2_bbox(cxy, wh):
32 | return np.array([cxy[0] - wh[0] / 2, cxy[1] - wh[1] / 2, cxy[0] + wh[0] / 2, cxy[1] + wh[1] / 2]) # 0-index
33 |
34 |
35 | snaps = json.load(open('snippet.json', 'r'))
36 |
37 | num_all_frame = 546315 # cat snippet.json | grep bbox |wc -l
38 | num_val = 1000
39 | # crop image
40 | lmdb = dict()
41 | lmdb['down_index'] = np.zeros(num_all_frame, np.int) # buff
42 | lmdb['up_index'] = np.zeros(num_all_frame, np.int)
43 |
44 | crop_base_path = 'crop_{:d}_{:1.1f}'.format(args.output_size, args.padding)
45 | if not isdir(crop_base_path):
46 | mkdir(crop_base_path)
47 |
48 | count = 0
49 | begin_time = time.time()
50 | for snap in snaps:
51 | frames = snap['frame']
52 | n_frames = len(frames)
53 | for f, frame in enumerate(frames):
54 | img_path = join(snap['base_path'], frame['img_path'])
55 | im = cv2.imread(img_path)
56 | avg_chans = np.mean(im, axis=(0, 1))
57 | bbox = frame['obj']['bbox']
58 |
59 | target_pos = [(bbox[2] + bbox[0])/2, (bbox[3] + bbox[1])/2]
60 | target_sz = np.array([bbox[2] - bbox[0], bbox[3] - bbox[1]])
61 | window_sz = target_sz * (1 + args.padding)
62 | crop_bbox = cxy_wh_2_bbox(target_pos, window_sz)
63 | patch = crop_hwc(im, crop_bbox, args.output_size)
64 | cv2.imwrite(join(crop_base_path, '{:08d}.jpg'.format(count)), patch)
65 | # cv2.imwrite('crop.jpg'.format(count), patch)
66 |
67 | lmdb['down_index'][count] = f
68 | lmdb['up_index'][count] = n_frames - f
69 | count += 1
70 | if count % 100 == 0:
71 | elapsed = time.time() - begin_time
72 | print("Processed {} images in {:.2f} seconds. "
73 | "{:.2f} images/second.".format(count, elapsed, count / elapsed))
74 |
75 | template_id = np.where(lmdb['up_index'] > 1)[0] # NEVER use the last frame as template! I do not like bidirectional.
76 | rand_split = np.random.permutation(len(template_id))  # random permutation so train/val sets are disjoint
77 | lmdb['train_set'] = template_id[rand_split[:(len(template_id)-num_val)]]
78 | lmdb['val_set'] = template_id[rand_split[(len(template_id)-num_val):]]
79 | print(len(lmdb['train_set']))
80 | print(len(lmdb['val_set']))
81 |
82 | # to list for json
83 | lmdb['train_set'] = lmdb['train_set'].tolist()
84 | lmdb['val_set'] = lmdb['val_set'].tolist()
85 | lmdb['down_index'] = lmdb['down_index'].tolist()
86 | lmdb['up_index'] = lmdb['up_index'].tolist()
87 |
88 | print('lmdb json, please wait 5 seconds~')
89 | json.dump(lmdb, open('dataset.json', 'w'), indent=2)
90 | print('done!')
91 |
--------------------------------------------------------------------------------
/train/dataset/gen_snippet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 |
4 |
5 | def check_size(frame_sz, bbox):
6 | min_ratio = 0.1
7 | max_ratio = 0.75
8 | # only accept objects >10% and <75% of the total frame
9 | area_ratio = np.sqrt((bbox[2]-bbox[0])*(bbox[3]-bbox[1])/float(np.prod(frame_sz)))
10 | ok = (area_ratio > min_ratio) and (area_ratio < max_ratio)
11 | return ok
12 |
13 |
14 | def check_borders(frame_sz, bbox):
15 | dist_from_border = 0.05 * (bbox[2] - bbox[0] + bbox[3] - bbox[1])/2
16 | ok = (bbox[0] > dist_from_border) and (bbox[1] > dist_from_border) and \
17 | ((frame_sz[0] - bbox[2]) > dist_from_border) and \
18 | ((frame_sz[1] - bbox[3]) > dist_from_border)
19 | return ok
20 |
21 |
22 | # Filter out snippets
23 | print('load json (raw vid info), please wait 20 seconds~')
24 | vid = json.load(open('vid.json', 'r'))
25 | snippets = []
26 | n_snippets = 0
27 | n_videos = 0
28 | for subset in vid:
29 | for video in subset:
30 | n_videos += 1
31 | frames = video['frame']
32 | id_set = []
33 | id_frames = [[]] * 60 # at most 60 objects
34 | for f, frame in enumerate(frames):
35 | objs = frame['objs']
36 | frame_sz = frame['frame_sz']
37 | for obj in objs:
38 | trackid = obj['trackid']
39 | occluded = obj['occ']
40 | bbox = obj['bbox']
41 | if occluded:
42 | continue
43 |
44 | if not(check_size(frame_sz, bbox) and check_borders(frame_sz, bbox)):
45 | continue
46 |
47 | if obj['c'] in ['n01674464', 'n01726692', 'n04468005', 'n02062744']:
48 | continue
49 |
50 | if trackid not in id_set:
51 | id_set.append(trackid)
52 | id_frames[trackid] = []
53 | id_frames[trackid].append(f)
54 |
55 | for selected in id_set:
56 | frame_ids = sorted(id_frames[selected])
57 | sequences = np.split(frame_ids, np.array(np.where(np.diff(frame_ids) > 1)[0]) + 1)
58 | sequences = [s for s in sequences if len(s) > 1] # remove isolated frame.
59 | for seq in sequences:
60 | snippet = dict()
61 | snippet['base_path'] = video['base_path']
62 | snippet['frame'] = []
63 | for frame_id in seq:
64 | frame = frames[frame_id]
65 | f = dict()
66 | f['frame_sz'] = frame['frame_sz']
67 | f['img_path'] = frame['img_path']
68 | for obj in frame['objs']:
69 | if obj['trackid'] == selected:
70 | o = obj
71 | continue
72 | f['obj'] = o
73 | snippet['frame'].append(f)
74 | snippets.append(snippet)
75 | n_snippets += 1
76 | print('video: {:d} snippets_num: {:d}'.format(n_videos, n_snippets))
77 |
78 | print('save json (snippets), please wait 20 seconds~')
79 | json.dump(snippets, open('snippet.json', 'w'), indent=2)
80 | print('done!')
81 |
--------------------------------------------------------------------------------
/train/dataset/parse_vid.py:
--------------------------------------------------------------------------------
1 | from os.path import join, isdir
2 | from os import listdir
3 | import argparse
4 | import json
5 | import glob
6 | import xml.etree.ElementTree as ET
7 |
8 | parser = argparse.ArgumentParser(description='Parse the VID Annotations for training DCFNet')
9 | parser.add_argument('data', metavar='DIR', help='path to VID')
10 | args = parser.parse_args()
11 |
12 | print('VID2015 Data:')
13 | VID_base_path = args.data
14 | ann_base_path = join(VID_base_path, 'Annotations/VID/train/')
15 | img_base_path = join(VID_base_path, 'Data/VID/train/')
16 | sub_sets = sorted({'a', 'b', 'c', 'd', 'e'})
17 |
18 | vid = []
19 | for sub_set in sub_sets:
20 | sub_set_base_path = join(ann_base_path, sub_set)
21 | videos = sorted(listdir(sub_set_base_path))
22 | s = []
23 | for vi, video in enumerate(videos):
24 | print('subset: {} video id: {:04d} / {:04d}'.format(sub_set, vi, len(videos)))
25 | v = dict()
26 | v['base_path'] = join(img_base_path, sub_set, video)
27 | v['frame'] = []
28 | video_base_path = join(sub_set_base_path, video)
29 | xmls = sorted(glob.glob(join(video_base_path, '*.xml')))
30 | for xml in xmls:
31 | f = dict()
32 | xmltree = ET.parse(xml)
33 | size = xmltree.findall('size')[0]
34 | frame_sz = [int(it.text) for it in size]
35 | objects = xmltree.findall('object')
36 | objs = []
37 | for object_iter in objects:
38 | trackid = int(object_iter.find('trackid').text)
39 | name = (object_iter.find('name')).text
40 | bndbox = object_iter.find('bndbox')
41 | occluded = int(object_iter.find('occluded').text)
42 | o = dict()
43 | o['c'] = name
44 | o['bbox'] = [int(bndbox.find('xmin').text), int(bndbox.find('ymin').text),
45 | int(bndbox.find('xmax').text), int(bndbox.find('ymax').text)]
46 | o['trackid'] = trackid
47 | o['occ'] = occluded
48 | objs.append(o)
49 | f['frame_sz'] = frame_sz
50 | f['img_path'] = xml.split('/')[-1].replace('xml', 'JPEG')
51 | f['objs'] = objs
52 | v['frame'].append(f)
53 | s.append(v)
54 | vid.append(s)
55 | print('save json (raw vid info), please wait 1 min~')
56 | json.dump(vid, open('vid.json', 'w'), indent=2)
57 | print('done!')
58 |
59 |
--------------------------------------------------------------------------------
/train/net.py:
--------------------------------------------------------------------------------
1 | import torch # pytorch 0.4.0! fft
2 | import torch.nn as nn
3 |
4 |
5 | def complex_mul(x, z):
6 | out_real = x[..., 0] * z[..., 0] - x[..., 1] * z[..., 1]
7 | out_imag = x[..., 0] * z[..., 1] + x[..., 1] * z[..., 0]
8 | return torch.stack((out_real, out_imag), -1)
9 |
10 |
11 | def complex_mulconj(x, z):
12 | out_real = x[..., 0] * z[..., 0] + x[..., 1] * z[..., 1]
13 | out_imag = x[..., 1] * z[..., 0] - x[..., 0] * z[..., 1]
14 | return torch.stack((out_real, out_imag), -1)
15 |
16 |
17 | class DCFNetFeature(nn.Module):
18 | def __init__(self):
19 | super(DCFNetFeature, self).__init__()
20 | self.feature = nn.Sequential(
21 | nn.Conv2d(3, 32, 3),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(32, 32, 3),
24 | nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1),
25 | )
26 |
27 | def forward(self, x):
28 | return self.feature(x)
29 |
30 |
31 | class DCFNet(nn.Module):
32 | def __init__(self, config=None):
33 | super(DCFNet, self).__init__()
34 | self.feature = DCFNetFeature()
35 | self.yf = config.yf.clone()
36 | self.lambda0 = config.lambda0
37 |
38 | def forward(self, z, x):
39 | z = self.feature(z)
40 | x = self.feature(x)
41 | zf = torch.rfft(z, signal_ndim=2)
42 | xf = torch.rfft(x, signal_ndim=2)
43 |
44 | kzzf = torch.sum(torch.sum(zf ** 2, dim=4, keepdim=True), dim=1, keepdim=True)
45 | kxzf = torch.sum(complex_mulconj(xf, zf), dim=1, keepdim=True)
46 | alphaf = self.yf.to(device=z.device) / (kzzf + self.lambda0) # very Ugly
47 | response = torch.irfft(complex_mul(kxzf, alphaf), signal_ndim=2)
48 | return response
49 |
50 |
51 | if __name__ == '__main__':
52 |
53 | # network test
54 |     net = DCFNetFeature()  # DCFNet itself needs a config providing yf (see train_DCFNet.py)
55 | net.eval()
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/train/train_DCFNet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | from os.path import join, isdir, isfile
4 | from os import makedirs
5 |
6 | from dataset import VID
7 | from net import DCFNet
8 | import torch
9 | from torch.utils.data import dataloader
10 | import torch.nn as nn
11 | import torch.backends.cudnn as cudnn
12 | import numpy as np
13 | import time
14 |
15 |
16 | parser = argparse.ArgumentParser(description='Training DCFNet in Pytorch 0.4.0')
17 | parser.add_argument('--input_sz', dest='input_sz', default=125, type=int, help='crop input size')
18 | parser.add_argument('--padding', dest='padding', default=2.0, type=float, help='crop padding size')
19 | parser.add_argument('--range', dest='range', default=10, type=int, help='select range')
20 | parser.add_argument('--epochs', default=50, type=int, metavar='N',
21 | help='number of total epochs to run')
22 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
23 | help='manual epoch number (useful on restarts)')
24 | parser.add_argument('--print-freq', '-p', default=10, type=int,
25 | metavar='N', help='print frequency (default: 10)')
26 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
27 | help='number of data loading workers (default: 8)')
28 | parser.add_argument('-b', '--batch-size', default=32, type=int,
29 | metavar='N', help='mini-batch size (default: 32)')
30 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
31 | metavar='LR', help='initial learning rate')
32 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
33 | help='momentum')
34 | parser.add_argument('--weight-decay', '--wd', default=5e-5, type=float,
35 | metavar='W', help='weight decay (default: 5e-5)')
36 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
37 | parser.add_argument('--save', '-s', default='./work', type=str, help='directory for saving')
38 |
39 | args = parser.parse_args()
40 |
41 | print(args)
42 | best_loss = 1e6
43 |
44 |
45 | def gaussian_shaped_labels(sigma, sz):
46 | x, y = np.meshgrid(np.arange(1, sz[0]+1) - np.floor(float(sz[0]) / 2), np.arange(1, sz[1]+1) - np.floor(float(sz[1]) / 2))
47 | d = x ** 2 + y ** 2
48 | g = np.exp(-0.5 / (sigma ** 2) * d)
49 | g = np.roll(g, int(-np.floor(float(sz[0]) / 2.) + 1), axis=0)
50 | g = np.roll(g, int(-np.floor(float(sz[1]) / 2.) + 1), axis=1)
51 | return g.astype(np.float32)
52 |
53 |
54 | class TrackerConfig(object):
55 | crop_sz = 125
56 | output_sz = 121
57 |
58 | lambda0 = 1e-4
59 | padding = 2.0
60 | output_sigma_factor = 0.1
61 |
62 | output_sigma = crop_sz / (1 + padding) * output_sigma_factor
63 | y = gaussian_shaped_labels(output_sigma, [output_sz, output_sz])
64 | yf = torch.rfft(torch.Tensor(y).view(1, 1, output_sz, output_sz).cuda(), signal_ndim=2)
65 | # cos_window = torch.Tensor(np.outer(np.hanning(crop_sz), np.hanning(crop_sz))).cuda() # train without cos window
66 |
67 |
68 | config = TrackerConfig()
69 |
70 | model = DCFNet(config=config)
71 | model.cuda()
72 | gpu_num = torch.cuda.device_count()
73 | print('GPU NUM: {:2d}'.format(gpu_num))
74 | if gpu_num > 1:
75 | model = torch.nn.DataParallel(model, list(range(gpu_num))).cuda()
76 |
77 | criterion = nn.MSELoss(size_average=False).cuda()
78 |
79 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
80 | momentum=args.momentum,
81 | weight_decay=args.weight_decay)
82 |
83 | target = torch.Tensor(config.y).cuda().unsqueeze(0).unsqueeze(0).repeat(args.batch_size * gpu_num, 1, 1, 1) # for training
84 | # optionally resume from a checkpoint
85 | if args.resume:
86 | if isfile(args.resume):
87 | print("=> loading checkpoint '{}'".format(args.resume))
88 | checkpoint = torch.load(args.resume)
89 | args.start_epoch = checkpoint['epoch']
90 | best_loss = checkpoint['best_loss']
91 | model.load_state_dict(checkpoint['state_dict'])
92 | optimizer.load_state_dict(checkpoint['optimizer'])
93 | print("=> loaded checkpoint '{}' (epoch {})"
94 | .format(args.resume, checkpoint['epoch']))
95 | else:
96 | print("=> no checkpoint found at '{}'".format(args.resume))
97 |
98 | cudnn.benchmark = True
99 |
100 | # training data
101 | crop_base_path = join('dataset', 'crop_{:d}_{:1.1f}'.format(args.input_sz, args.padding))
102 | if not isdir(crop_base_path):
103 |     print('please run dataset/crop_image.py --output_size {:d} --padding {:.1f}!'.format(args.input_sz, args.padding))
104 | exit()
105 |
106 | save_path = join(args.save, 'crop_{:d}_{:1.1f}'.format(args.input_sz, args.padding))
107 | if not isdir(save_path):
108 | makedirs(save_path)
109 |
110 | train_dataset = VID(root=crop_base_path, train=True, range=args.range)
111 | val_dataset = VID(root=crop_base_path, train=False, range=args.range)
112 |
113 | train_loader = torch.utils.data.DataLoader(
114 | train_dataset, batch_size=args.batch_size*gpu_num, shuffle=True,
115 | num_workers=args.workers, pin_memory=True, drop_last=True)
116 |
117 | val_loader = torch.utils.data.DataLoader(
118 | val_dataset, batch_size=args.batch_size*gpu_num, shuffle=False,
119 | num_workers=args.workers, pin_memory=True, drop_last=True)
120 |
121 |
122 | def adjust_learning_rate(optimizer, epoch):
123 | lr = np.logspace(-2, -5, num=args.epochs)[epoch]
124 | for param_group in optimizer.param_groups:
125 | param_group['lr'] = lr
126 |
127 |
128 | class AverageMeter(object):
129 | """Computes and stores the average and current value"""
130 | def __init__(self):
131 | self.reset()
132 |
133 | def reset(self):
134 | self.val = 0
135 | self.avg = 0
136 | self.sum = 0
137 | self.count = 0
138 |
139 | def update(self, val, n=1):
140 | self.val = val
141 | self.sum += val * n
142 | self.count += n
143 | self.avg = self.sum / self.count
144 |
145 |
146 | def save_checkpoint(state, is_best, filename=join(save_path, 'checkpoint.pth.tar')):
147 | torch.save(state, filename)
148 | if is_best:
149 | shutil.copyfile(filename, join(save_path, 'model_best.pth.tar'))
150 |
151 |
152 | def train(train_loader, model, criterion, optimizer, epoch):
153 | batch_time = AverageMeter()
154 | data_time = AverageMeter()
155 | losses = AverageMeter()
156 |
157 | # switch to train mode
158 | model.train()
159 |
160 | end = time.time()
161 | for i, (template, search) in enumerate(train_loader):
162 | # measure data loading time
163 | data_time.update(time.time() - end)
164 |
165 | template = template.cuda(non_blocking=True)
166 | search = search.cuda(non_blocking=True)
167 |
168 | # compute output
169 | output = model(template, search)
170 | loss = criterion(output, target)/template.size(0)
171 |
172 | # measure accuracy and record loss
173 | losses.update(loss.item())
174 |
175 | # compute gradient and do SGD step
176 | optimizer.zero_grad()
177 | loss.backward()
178 | optimizer.step()
179 |
180 | # measure elapsed time
181 | batch_time.update(time.time() - end)
182 | end = time.time()
183 |
184 | if i % args.print_freq == 0:
185 | print('Epoch: [{0}][{1}/{2}]\t'
186 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
187 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
188 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
189 | epoch, i, len(train_loader), batch_time=batch_time,
190 | data_time=data_time, loss=losses))
191 |
192 |
193 | def validate(val_loader, model, criterion):
194 | batch_time = AverageMeter()
195 | losses = AverageMeter()
196 |
197 | # switch to evaluate mode
198 | model.eval()
199 |
200 | with torch.no_grad():
201 | end = time.time()
202 | for i, (template, search) in enumerate(val_loader):
203 |
204 | # compute output
205 | template = template.cuda(non_blocking=True)
206 | search = search.cuda(non_blocking=True)
207 |
208 | # compute output
209 | output = model(template, search)
210 | loss = criterion(output, target)/(args.batch_size * gpu_num)
211 |
212 | # measure accuracy and record loss
213 | losses.update(loss.item())
214 |
215 | # measure elapsed time
216 | batch_time.update(time.time() - end)
217 | end = time.time()
218 |
219 | if i % args.print_freq == 0:
220 | print('Test: [{0}/{1}]\t'
221 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
222 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
223 | i, len(val_loader), batch_time=batch_time, loss=losses))
224 |
225 | print(' * Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
226 |
227 | return losses.avg
228 |
229 |
230 | for epoch in range(args.start_epoch, args.epochs):
231 | adjust_learning_rate(optimizer, epoch)
232 |
233 | # train for one epoch
234 | train(train_loader, model, criterion, optimizer, epoch)
235 |
236 | # evaluate on validation set
237 | loss = validate(val_loader, model, criterion)
238 |
239 | # remember best loss and save checkpoint
240 | is_best = loss < best_loss
241 | best_loss = min(best_loss, loss)
242 | save_checkpoint({
243 | 'epoch': epoch + 1,
244 | 'state_dict': model.state_dict(),
245 | 'best_loss': best_loss,
246 | 'optimizer': optimizer.state_dict(),
247 | }, is_best)
248 |
--------------------------------------------------------------------------------