├── img
│   ├── 100.png
│   ├── xywh.png
│   ├── SiamRPN.png
│   ├── WeChat.jpg
│   ├── centr.png
│   ├── coord.png
│   ├── corners.png
│   ├── error.png
│   └── new_file.png
├── train
│   ├── __pycache__
│   │   └── net.cpython-37.pyc
│   ├── experiments
│   │   └── default
│   │       └── parameters.json
│   ├── test.py
│   ├── config.py
│   ├── network.py
│   ├── custom_transforms.py
│   ├── net.py
│   ├── loss.py
│   ├── train_siamrpn.py
│   ├── util.py
│   └── data.py
├── tracking
│   ├── experiments
│   │   └── default
│   │       └── parameters.json
│   ├── config.py
│   ├── run_tracking.py
│   ├── network.py
│   ├── data_loader.py
│   ├── util.py
│   ├── custom_transforms.py
│   └── siamRPNBIG.py
├── .gitignore
├── README.md
└── fixed.py
/img/100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/100.png
--------------------------------------------------------------------------------
/img/xywh.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/xywh.png
--------------------------------------------------------------------------------
/img/SiamRPN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/SiamRPN.png
--------------------------------------------------------------------------------
/img/WeChat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/WeChat.jpg
--------------------------------------------------------------------------------
/img/centr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/centr.png
--------------------------------------------------------------------------------
/img/coord.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/coord.png
--------------------------------------------------------------------------------
/img/corners.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/corners.png
--------------------------------------------------------------------------------
/img/error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/error.png
--------------------------------------------------------------------------------
/img/new_file.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/img/new_file.png
--------------------------------------------------------------------------------
/train/__pycache__/net.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arbitularov/SiamRPN-PyTorch/HEAD/train/__pycache__/net.cpython-37.pyc
--------------------------------------------------------------------------------
/train/experiments/default/parameters.json:
--------------------------------------------------------------------------------
1 | {
2 | "template_img_size": 127,
3 | "detection_img_size": 271,
4 | "stride": 8,
5 | "lr": 1e-5,
6 | "epoches": 200,
7 | "weight_decay": 0.0005,
8 | "momentum": 0.9,
9 | "context": 0.5,
10 | "ratios": [0.33, 0.5, 1, 2, 3],
11 | "scales": [8],
12 | "penalty_k": 0.055,
13 | "window_influence": 0.42
14 | }
15 |
--------------------------------------------------------------------------------
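
A minimal sketch of how these values are consumed: the tracking entry point (run_tracking.py, listed below) loads the file with a plain json.load, and the training code presumably reads it the same way. The experiment name 'default' matches the directory layout above.

```
import json
import os

experiment_name = 'default'  # matches experiments/default/ above
json_path = os.path.join('experiments', experiment_name, 'parameters.json')
with open(json_path) as f:
    params = json.load(f)

print(params['template_img_size'], params['detection_img_size'], params['stride'])  # 127 271 8
```
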
/tracking/experiments/default/parameters.json:
--------------------------------------------------------------------------------
1 | {
2 | "template_img_size": 127,
3 | "detection_img_size": 271,
4 | "stride": 8,
5 | "lr": 0.001,
6 | "epoches": 20,
7 | "weight_decay": 0.00005,
8 | "momentum": 0.9,
9 | "context": 0.5,
10 | "ratios": [0.33, 0.5, 1, 2, 3],
11 | "scales": [8],
12 | "penalty_k": 0.055,
13 | "window_influence": 0.42
14 | }
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | test/data
2 | youtube_BB
3 | youtube-bb.py
4 | experiments
5 | test/__pycache__
6 | __pycache__
7 | test/results
8 | tracking/results
9 | model
10 | model_e1.pth
11 | model_e2.pth
12 | model_e3.pth
13 | model_e4.pth
14 | weights-000.pth.tar
15 | test
16 | OTBreports
17 | OTBresults
18 | results
19 | siamrpn_7.pth
20 | siamrpn_50.pth
21 | siamrpn_25.pth
22 | model_e6.pth
23 | model_e9.pth
24 | model_e11.pth
25 | model_e12.pth
26 | model_e10.pth
27 | model_e25.pth
28 | model_e74.pth
29 | model_got.pth
30 | model.pth
31 | test.png
32 | test1.png
33 | weights-0690000.pth.tar
34 | cache
35 |
--------------------------------------------------------------------------------
/train/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | img = cv2.imread('detection_img.png')
5 |
6 | f = open('text.txt', 'r')
7 | boxes = []
8 | for line in f:  # each line holds one printed [cx, cy, w, h] box
9 | a = line.replace('.', ',')
10 | a = a.replace('[', '')
11 | a = a.replace(']', '')
12 | a = a.replace("\n", '')
13 | a = a.split(',')
14 |
15 |
16 |     # keep only the first four fields as an integer [cx, cy, w, h] box
17 |     a = np.asarray(a)
18 |     a = np.asarray([int(a[0]), int(a[1]), int(a[2]), int(a[3])])
19 |     boxes.append(a)
20 | f.close()
21 |
22 | count = 0
23 | for box in boxes:
24 | print(box)
25 | cx , cy, w, h = box
26 | cx_big = 255/2 + (cx/0.16)
27 | cy_big = 255/2 + (cy/0.16)
28 |
29 | x1 = int(cx_big - w/2)
30 | x2 = int(cx_big + w/2)
31 |
32 | y1 = int(cy_big - h/2)
33 | y2 = int(cy_big + h/2)
34 |
35 | r = int(np.random.choice(range(250)))
36 | g = int(np.random.choice(range(250)))
37 | b = int(np.random.choice(range(250)))
38 |     count += 1
39 |     '''if count >= 3:
40 |         count = 1'''
41 |
42 |     frame = cv2.rectangle(img, (x1,y1), (x2,y2), (r, g, b), count)  # thickness grows with each box
43 |
44 | cv2.imwrite('detection_img1.png',frame)
45 |
--------------------------------------------------------------------------------
/tracking/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class TrackerConfig(object):
4 | # These are the default hyper-params for DaSiamRPN 0.3827
5 | windowing = 'cosine' # to penalize large displacements [cosine/uniform]
6 | # Params from the network architecture, have to be consistent with the training
7 | template_img_size = 127 # input z size
8 | detection_img_size = 271 # input x size (search region)
9 | total_stride = 8
10 | valid_scope = int((detection_img_size - template_img_size) / total_stride / 2)
11 | score_size = int((detection_img_size - template_img_size)/total_stride+1)
12 | context_amount = 0.5 # context amount for the exemplar
13 | ratios = [0.33, 0.5, 1, 2, 3]
14 | scales = [8, ]
15 | anchor_num = len(ratios) * len(scales)
16 | anchor = []
17 | penalty_k = 0.055
18 | window_influence = 0.42
19 | lr = 0.295
20 | lr_box = 0.30
21 |
22 | min_scale = 0.1
23 | max_scale = 10
24 |
25 | anchor_base_size = 8
26 | anchor_scales = np.array([8, ])
27 | anchor_ratios = np.array([0.33, 0.5, 1, 2, 3])
28 | size = anchor_num * score_size * score_size
29 |
30 | config = TrackerConfig()
31 |
--------------------------------------------------------------------------------
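
The derived fields in TrackerConfig follow directly from the three sizes above; a quick standalone check of the arithmetic (not part of the tracker itself):

```
detection_img_size, template_img_size, total_stride = 271, 127, 8
anchor_num = 5 * 1  # len(ratios) * len(scales)

score_size = (detection_img_size - template_img_size) // total_stride + 1    # 144 / 8 + 1 = 19
valid_scope = (detection_img_size - template_img_size) // total_stride // 2  # 18 / 2 = 9
size = anchor_num * score_size * score_size                                   # 5 * 19 * 19 = 1805

print(score_size, valid_scope, size)  # 19 9 1805
```
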
/train/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Config(object):
4 |
5 | '''config for train_siamrpn.py'''
6 | template_img_size = 127
7 | detection_img_size = 271
8 | epoches = 200
9 | train_epoch_size = 1000
10 | val_epoch_size = 100
11 |
12 | train_batch_size = 32 # training batch size
13 | valid_batch_size = 8 # validation batch size
14 | train_num_workers = 16 # number of workers of train dataloader
15 | valid_num_workers = 16
16 |
17 | start_lr = 3e-3
18 | end_lr = 1e-5
19 | warm_lr = 1e-3
20 | warm_scale = warm_lr/start_lr
21 | lr = np.logspace(np.log10(start_lr), np.log10(end_lr), num=epoches)[0]
22 | gamma = np.logspace(np.log10(start_lr), np.log10(end_lr), num=epoches)[1] / \
23 | np.logspace(np.log10(start_lr), np.log10(end_lr), num=epoches)[0]
24 | momentum = 0.9
25 | weight_decay = 0.0005
26 |
27 | clip = 100 # grad clip
28 |
29 | anchor_scales = np.array([8, ])
30 | anchor_ratios = np.array([0.33, 0.5, 1, 2, 3])
31 | anchor_num = len(anchor_scales) * len(anchor_ratios) # 5
32 | score_size = int((detection_img_size - template_img_size) / 8 + 1)
33 | size = anchor_num * score_size * score_size
34 |
35 | '''config for data.py'''
36 |
37 | out_feature = 19
38 | max_inter = 80
39 | fix_former_3_layers = True
40 | pretrained_model = '/home/arbi/Загрузки/alexnet.pth' # '/home/arbi/desktop/alexnet.pth' # # '/Users/arbi/Desktop/alexnet.pth'
41 |
42 | total_stride = 8
43 | anchor_total_stride = total_stride
44 | anchor_base_size = 8
45 | anchor_scales = np.array([8, ])
46 | anchor_ratios = np.array([0.33, 0.5, 1, 2, 3])
47 |
48 | valid_scope = int((detection_img_size - template_img_size) / total_stride / 2)
49 | anchor_valid_scope = 2 * valid_scope + 1
50 | pos_threshold = 0.6
51 | neg_threshold = 0.3
52 |
53 | context = 0.5
54 | penalty_k = 0.055
55 | window_influence = 0.42
56 | eps = 0.01
57 |
58 | max_translate = 12
59 | scale_resize = 0.15
60 | gray_ratio = 0.25
61 | exem_stretch = False
62 |
63 | '''config for net.py'''
64 | num_pos = 16
65 | num_neg = 48
66 | lamb = 5
67 |
68 | ohem_pos = False
69 | ohem_neg = False
70 | ohem_reg = False
71 |
72 |
73 | config = Config()
74 |
--------------------------------------------------------------------------------
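
The learning-rate fields encode a geometric decay from start_lr to end_lr over the 200 epochs: lr is the first value of the log-spaced schedule and gamma the constant per-epoch factor, which train/net.py applies through util.adjust_learning_rate whenever a checkpoint is saved. A short sketch of the schedule:

```
import numpy as np

start_lr, end_lr, epoches = 3e-3, 1e-5, 200
schedule = np.logspace(np.log10(start_lr), np.log10(end_lr), num=epoches)

lr = schedule[0]                    # 3e-3, the initial learning rate
gamma = schedule[1] / schedule[0]   # ~0.9718, per-epoch decay factor
print(lr, gamma, lr * gamma ** (epoches - 1))  # the last value decays back to ~1e-5
```
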
/tracking/run_tracking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from got10k.experiments import *
3 | from siamRPNBIG import TrackerSiamRPNBIG
4 | import argparse
5 | import os
6 | import json
7 |
8 | parser = argparse.ArgumentParser(description='PyTorch SiameseRPN Tracking')
9 |
10 | parser.add_argument('--tracker_path', default='/home/arbi/desktop/data', metavar='DIR',help='path to dataset')
11 | parser.add_argument('--experiment_name', default='default', metavar='DIR',help='path to weight')
12 | parser.add_argument('--net_path', default='../train/experiments/default/model/model_e17.pth', metavar='DIR',help='path to weight')
13 | # ../train/experiments/default/model/model_e1.pth # ../model.pth #../siamrpn_7.pth
14 | # /Users/arbi/Desktop # /home/arbi/desktop/GOT-10k
15 | # /media/arbi/9132EE0B9756C987/dataset/GOT-10k/full_data
16 | parser.add_argument('--visualize', default=True, help='visualize')
17 |
18 | args = parser.parse_args()
19 |
20 | if __name__ == '__main__':
21 |
22 | """Load the parameters from json file"""
23 | json_path = os.path.join('experiments/{}'.format(args.experiment_name), 'parameters.json')
24 | assert os.path.isfile(json_path), ("No json configuration file found at {}".format(json_path))
25 | with open(json_path) as data_file:
26 | params = json.load(data_file)
27 |
28 | '''setup tracker'''
29 | tracker = TrackerSiamRPNBIG(params, args.net_path)
30 |
31 | '''setup experiments'''
32 | # 7 datasets with different versions
33 | '''
34 | experiments = ExperimentGOT10k('data/GOT-10k', subset='test'),
35 | ExperimentOTB('data/OTB', version=2015),
36 | ExperimentOTB('data/OTB', version=2013),
37 | ExperimentVOT('data/vot2018', version=2018),
38 | ExperimentUAV123('data/UAV123', version='UAV123'),
39 | ExperimentUAV123('data/UAV123', version='UAV20L'),
40 | ExperimentDTB70('data/DTB70'),
41 | ExperimentTColor128('data/Temple-color-128'),
42 | ExperimentNfS('data/nfs', fps=30),
43 | ExperimentNfS('data/nfs', fps=240),
44 | ]
45 |
46 | for e in experiments:
47 | e.run(tracker, visualize=True)
48 | e.report([tracker.name])
49 | '''
50 |
51 | '''
52 | experiments = ExperimentGOT10k(args.tracker_path, subset='val',
53 | result_dir='experiments/{}/results'.format(args.experiment_name),
54 | report_dir='experiments/{}/reports'.format(args.experiment_name))
55 |
56 | '''
57 | experiments = ExperimentOTB('/home/arbi/desktop/data', version=2015,
58 | result_dir='experiments/{}/OTB2015resultsGOT-10k_42'.format(args.experiment_name),
59 | report_dir='experiments/{}/OTB2015reportsGOT-10k_42'.format(args.experiment_name))
60 |
61 |
62 | '''run experiments'''
63 | experiments.run(tracker, visualize = args.visualize)
64 | experiments.report([tracker.name])
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On OTB2015 I got 8.41 and 0.625 without RPN, in the [SiamFusion project](https://github.com/arbitularov/SiamFusion)
2 | 
3 | # SiamRPN-PyTorch
4 | Implementation of SiamRPN in PyTorch, trained on the GOT-10k dataset
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## How to run Training
15 | 1. Download the GOT-10k dataset from http://got-10k.aitestunion.com/downloads
16 | 2. Run the train_siamrpn.py script:
17 | ```
18 | cd train
19 |
20 | python3 train_siamrpn.py --train_path=/path/to/dataset/GOT-10k/train
21 | ```
22 |
23 | ## How to run Tracking
24 | [Coming Soon]
25 |
26 |
27 | ## pip install
28 | ```
29 | pip3 install shapely
30 | ```
31 |
32 | ## How to fix the GOT-10k dataset
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | 1. First you need to delete four videos:
43 | ```
44 | GOT-10k_Train_008628
45 | GOT-10k_Train_008630
46 | GOT-10k_Train_009058
47 | GOT-10k_Train_009059
48 | ```
49 | Their annotations are invalid: ymin and xmin are greater than the size of the image.
50 |
51 | 2. Run the fixed.py script:
52 | ```
53 | python3 fixed.py --dataset_path=/path/to/dataset/GOT-10k/train
54 | ```
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 | Afterwards you will have a new_file.txt file with detailed information about where each error was found.
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 | You do not need to change anything yourself; the fixed.py script does it for you.
74 |
75 | ## My contacts
76 |
77 | E-mail: arbi.tularov@yandex.ru
78 |
79 | WeChat: tularov_arbi
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 | ## Authors
90 |
91 | * `Bo Li` - paper - [Siamese-RPN](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf)
92 | * `De jiasong` - code - [Siamese-RPN-pytorch](https://github.com/songdejia/Siamese-RPN-pytorch)
93 | * `Makalo` - code - [Siamese-RPN-tensorflow](https://github.com/makalo/Siamese-RPN-tensorflow)
94 |
95 | ## Citation
96 | ```
97 | Paper: @InProceedings{Li_2018_CVPR,
98 | author = {Li, Bo and Yan, Junjie and Wu, Wei and Zhu, Zheng and Hu, Xiaolin},
99 | title = {High Performance Visual Tracking With Siamese Region Proposal Network},
100 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
101 | month = {June},
102 | year = {2018}
103 | }
104 | ```
105 |
--------------------------------------------------------------------------------
/fixed.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | from tqdm import tqdm
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser(description='Fixed GOT-10k Dataset')
7 | parser.add_argument('--dataset_path', default='/Users/arbi/Desktop/val', metavar='DIR',help='path to dataset')
8 | param = parser.parse_args()
9 |
10 | sub_class_dir = [sub_class_dir for sub_class_dir in os.listdir(param.dataset_path) if os.path.isdir(os.path.join(param.dataset_path, sub_class_dir))]
11 |
12 | array_error = []
13 |
14 | for name_dir in tqdm(sub_class_dir):
15 |
16 | meta = open("{}/{}/meta_info.ini".format(param.dataset_path, name_dir), "r")
17 |
18 | read_meta = meta.readlines()
19 |
20 | w_and_h = re.findall(r'\d+', read_meta[10])
21 | meta.close()
22 |
23 | groundtruth = open("{}/{}/groundtruth.txt".format(param.dataset_path, name_dir), "r")
24 | read_groundtruth = groundtruth.readlines()
25 | count_gt = len(read_groundtruth)
26 | groundtruth.close()
27 |
28 | groundtruth_write = open("{}/{}/groundtruth.txt".format(param.dataset_path, name_dir), "w")
29 |
30 | groundtruth_array = []
31 | for i2, name_gt in enumerate(read_groundtruth):
32 |
33 | gt = [abs(int(float(i))) for i in name_gt.strip('\n').split(',')]
34 |         w = gt[0]+gt[2]   # xmax = xmin + width
35 |         h = gt[1]+gt[3]   # ymax = ymin + height
36 |
37 |         if w > int(w_and_h[0]) and h > int(w_and_h[1]):
38 |             print('i2', i2+1,'w and h')
39 |             info = 'w and h {}, img: {}, img_size_w: {} < xmax = {} + {} = {} and img_size_h: {} < ymax = {} + {} = {}'.format(name_dir, i2+1, w_and_h[0], gt[0], gt[2], w, w_and_h[1], gt[1], gt[3], h)
40 | array_error.append(info)
41 | w_fixed = gt[2] - (w - int(w_and_h[0]))
42 | h_fixed = gt[3] - (h - int(w_and_h[1]))
43 | gt_fixed = '{}.0000,{}.0000,{}.0000,{}.0000'.format(gt[0], gt[1], w_fixed, h_fixed)
44 | groundtruth_array.append(gt_fixed)
45 |
46 | elif w > int(w_and_h[0]):
47 | #print('i2', i2+1,'just w')
48 | info = 'just w {}, img: {}, img_size: {} < xmax = {} + {} = {}'.format(name_dir, i2+1, w_and_h[0], gt[0], gt[2], w)
49 | array_error.append(info)
50 | w_fixed = gt[2] - (w - int(w_and_h[0]))
51 | gt_fixed = '{}.0000,{}.0000,{}.0000,{}.0000'.format(gt[0], gt[1], w_fixed, gt[3])
52 | groundtruth_array.append(gt_fixed)
53 |
54 | elif h > int(w_and_h[1]):
55 | #print('i2', i2+1,'just h')
56 |             info = 'just h {}, img: {}, img_size: {} < ymax = {} + {} = {}'.format(name_dir, i2+1, w_and_h[1], gt[1], gt[3], h)
57 | array_error.append(info)
58 | h_fixed = gt[3] - (h - int(w_and_h[1]))
59 | gt_fixed = '{}.0000,{}.0000,{}.0000,{}.0000'.format(gt[0], gt[1], gt[2], h_fixed)
60 | groundtruth_array.append(gt_fixed)
61 |
62 | else:
63 | #print('i2', i2+1,'all it\'s ok')
64 | gt_fixed = '{}.0000,{}.0000,{}.0000,{}.0000'.format(gt[0], gt[1], gt[2], gt[3])
65 | groundtruth_array.append(gt_fixed)
66 | try:
67 | for l in groundtruth_array:
68 | groundtruth_write.write('{}\n'.format(l))
69 | finally:
70 | groundtruth_write.close()
71 |
72 | new_file = open("new_file.txt", "w")
73 | try:
74 | for i in array_error:
75 | new_file.write('{}\n'.format(i))
76 | finally:
77 | new_file.close()
78 |
--------------------------------------------------------------------------------
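
A worked example of the clamp fixed.py applies (the numbers here are invented for illustration): if a frame is 1280 px wide and a box has xmin = 1201 with width 120, then xmax = 1321 exceeds the frame, so the width is reduced by the overshoot and a corresponding 'just w ...' line is written to new_file.txt.

```
img_w = 1280          # hypothetical frame width read from meta_info.ini
xmin, w = 1201, 120   # hypothetical groundtruth values
xmax = xmin + w       # 1321, outside the frame

if xmax > img_w:
    w_fixed = w - (xmax - img_w)  # 120 - 41 = 79, so xmin + w_fixed == img_w
print(w_fixed)  # 79
```
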
/train/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torchvision.transforms as transforms
5 |
6 | from torchvision.models import alexnet
7 | from torch.autograd import Variable
8 | from torch import nn
9 |
10 | from config import config
11 |
12 |
13 | class SiameseAlexNet(nn.Module):
14 | def __init__(self, ):
15 | super(SiameseAlexNet, self).__init__()
16 | self.featureExtract = nn.Sequential(
17 | nn.Conv2d(3, 96, 11, stride=2),
18 | nn.BatchNorm2d(96),
19 | nn.MaxPool2d(3, stride=2),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(96, 256, 5),
22 | nn.BatchNorm2d(256),
23 | nn.MaxPool2d(3, stride=2),
24 | nn.ReLU(inplace=True),
25 | nn.Conv2d(256, 384, 3),
26 | nn.BatchNorm2d(384),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(384, 384, 3),
29 | nn.BatchNorm2d(384),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(384, 256, 3),
32 | nn.BatchNorm2d(256),
33 | )
34 | self.anchor_num = config.anchor_num
35 | self.input_size = config.detection_img_size
36 | self.score_displacement = int((self.input_size - config.template_img_size) / config.total_stride)
37 | self.conv_cls1 = nn.Conv2d(256, 256 * 2 * self.anchor_num, kernel_size=3, stride=1, padding=0)
38 | self.conv_r1 = nn.Conv2d(256, 256 * 4 * self.anchor_num, kernel_size=3, stride=1, padding=0)
39 |
40 | self.conv_cls2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
41 | self.conv_r2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
42 | self.regress_adjust = nn.Conv2d(4 * self.anchor_num, 4 * self.anchor_num, 1)
43 |
44 | def init_weights(self):
45 | for m in self.modules():
46 | if isinstance(m, nn.Conv2d):
47 | nn.init.normal_(m.weight.data, std= 0.0005)
48 | nn.init.normal_(m.bias.data, std= 0.0005)
49 | elif isinstance(m, nn.BatchNorm2d):
50 | m.weight.data.fill_(1)
51 | m.bias.data.zero_()
52 |
53 | def forward(self, template, detection):
54 | N = template.size(0)
55 | template_feature = self.featureExtract(template)
56 | detection_feature = self.featureExtract(detection)
57 |
58 | kernel_score = self.conv_cls1(template_feature).view(N, 2 * self.anchor_num, 256, 4, 4)
59 | kernel_regression = self.conv_r1(template_feature).view(N, 4 * self.anchor_num, 256, 4, 4)
60 | conv_score = self.conv_cls2(detection_feature)
61 | conv_regression = self.conv_r2(detection_feature)
62 |
63 | conv_scores = conv_score.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
64 | score_filters = kernel_score.reshape(-1, 256, 4, 4)
65 | pred_score = F.conv2d(conv_scores, score_filters, groups=N).reshape(N, 10, self.score_displacement + 1,
66 | self.score_displacement + 1)
67 |
68 | conv_reg = conv_regression.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
69 | reg_filters = kernel_regression.reshape(-1, 256, 4, 4)
70 | pred_regression = self.regress_adjust(
71 | F.conv2d(conv_reg, reg_filters, groups=N).reshape(N, 20, self.score_displacement + 1,
72 | self.score_displacement + 1))
73 | return pred_score, pred_regression
74 |
75 | def track_init(self, template):
76 | N = template.size(0)
77 | template_feature = self.featureExtract(template)
78 |
79 | kernel_score = self.conv_cls1(template_feature).view(N, 2 * self.anchor_num, 256, 4, 4)
80 | kernel_regression = self.conv_r1(template_feature).view(N, 4 * self.anchor_num, 256, 4, 4)
81 | self.score_filters = kernel_score.reshape(-1, 256, 4, 4)
82 | self.reg_filters = kernel_regression.reshape(-1, 256, 4, 4)
83 |
84 | def track(self, detection):
85 | N = detection.size(0)
86 | detection_feature = self.featureExtract(detection)
87 |
88 | conv_score = self.conv_cls2(detection_feature)
89 | conv_regression = self.conv_r2(detection_feature)
90 |
91 | conv_scores = conv_score.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
92 | pred_score = F.conv2d(conv_scores, self.score_filters, groups=N).reshape(N, 10, self.score_displacement + 1,
93 | self.score_displacement + 1)
94 | conv_reg = conv_regression.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
95 | pred_regression = self.regress_adjust(
96 | F.conv2d(conv_reg, self.reg_filters, groups=N).reshape(N, 20, self.score_displacement + 1,
97 | self.score_displacement + 1))
98 | return pred_score, pred_regression
99 |
--------------------------------------------------------------------------------
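
A shape-level smoke test of the network above, assuming it is run from the train directory (so config.py resolves) with PyTorch installed; inputs are random and N = 2:

```
import torch
from network import SiameseAlexNet  # the module above

net = SiameseAlexNet().eval()
template = torch.randn(2, 3, 127, 127)   # exemplar crops
detection = torch.randn(2, 3, 271, 271)  # search-region crops

with torch.no_grad():
    pred_score, pred_regression = net(template, detection)

print(pred_score.shape)       # torch.Size([2, 10, 19, 19]) -> 2 * anchor_num score channels
print(pred_regression.shape)  # torch.Size([2, 20, 19, 19]) -> 4 * anchor_num offset channels
```
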
/tracking/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | import torchvision.transforms as transforms
5 | from custom_transforms import ToTensor
6 |
7 | from torchvision.models import alexnet
8 | from torch.autograd import Variable
9 | from torch import nn
10 |
11 |
12 | from config import config
13 |
14 |
15 | class SiameseAlexNet(nn.Module):
16 | def __init__(self, ):
17 | super(SiameseAlexNet, self).__init__()
18 | self.featureExtract = nn.Sequential(
19 | nn.Conv2d(3, 96, 11, stride=2),
20 | nn.BatchNorm2d(96),
21 | nn.MaxPool2d(3, stride=2),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(96, 256, 5),
24 | nn.BatchNorm2d(256),
25 | nn.MaxPool2d(3, stride=2),
26 | nn.ReLU(inplace=True),
27 | nn.Conv2d(256, 384, 3),
28 | nn.BatchNorm2d(384),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(384, 384, 3),
31 | nn.BatchNorm2d(384),
32 | nn.ReLU(inplace=True),
33 | nn.Conv2d(384, 256, 3),
34 | nn.BatchNorm2d(256),
35 | )
36 | self.anchor_num = config.anchor_num
37 | self.input_size = config.detection_img_size
38 | self.score_displacement = int((self.input_size - config.template_img_size) / config.total_stride)
39 | self.conv_cls1 = nn.Conv2d(256, 256 * 2 * self.anchor_num, kernel_size=3, stride=1, padding=0)
40 | self.conv_r1 = nn.Conv2d(256, 256 * 4 * self.anchor_num, kernel_size=3, stride=1, padding=0)
41 |
42 | self.conv_cls2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
43 | self.conv_r2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)
44 | self.regress_adjust = nn.Conv2d(4 * self.anchor_num, 4 * self.anchor_num, 1)
45 |
46 | def init_weights(self):
47 | for m in self.modules():
48 | if isinstance(m, nn.Conv2d):
49 | # nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu')
50 | nn.init.normal_(m.weight.data, std=0.0005)
51 | nn.init.normal_(m.bias.data, std=0.0005)
52 | elif isinstance(m, nn.BatchNorm2d):
53 | m.weight.data.fill_(1)
54 | m.bias.data.zero_()
55 |
56 | def forward(self, template, detection):
57 | N = template.size(0)
58 | template_feature = self.featureExtract(template)
59 | detection_feature = self.featureExtract(detection)
60 |
61 | kernel_score = self.conv_cls1(template_feature).view(N, 2 * self.anchor_num, 256, 4, 4)
62 | kernel_regression = self.conv_r1(template_feature).view(N, 4 * self.anchor_num, 256, 4, 4)
63 | conv_score = self.conv_cls2(detection_feature)
64 | conv_regression = self.conv_r2(detection_feature)
65 |
66 | conv_scores = conv_score.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
67 | score_filters = kernel_score.reshape(-1, 256, 4, 4)
68 | pred_score = F.conv2d(conv_scores, score_filters, groups=N).reshape(N, 10, self.score_displacement + 1,
69 | self.score_displacement + 1)
70 |
71 | conv_reg = conv_regression.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
72 | reg_filters = kernel_regression.reshape(-1, 256, 4, 4)
73 | pred_regression = self.regress_adjust(
74 | F.conv2d(conv_reg, reg_filters, groups=N).reshape(N, 20, self.score_displacement + 1,
75 | self.score_displacement + 1))
76 | return pred_score, pred_regression
77 |
78 | def track_init(self, template):
79 | N = template.size(0)
80 | template_feature = self.featureExtract(template)
81 |
82 | kernel_score = self.conv_cls1(template_feature).view(N, 2 * self.anchor_num, 256, 4, 4)
83 | kernel_regression = self.conv_r1(template_feature).view(N, 4 * self.anchor_num, 256, 4, 4)
84 | self.score_filters = kernel_score.reshape(-1, 256, 4, 4)
85 | self.reg_filters = kernel_regression.reshape(-1, 256, 4, 4)
86 |
87 | def track(self, detection):
88 | N = detection.size(0)
89 | detection_feature = self.featureExtract(detection)
90 |
91 | conv_score = self.conv_cls2(detection_feature)
92 | conv_regression = self.conv_r2(detection_feature)
93 |
94 | conv_scores = conv_score.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
95 | pred_score = F.conv2d(conv_scores, self.score_filters, groups=N).reshape(N, 10, self.score_displacement + 1,
96 | self.score_displacement + 1)
97 | conv_reg = conv_regression.reshape(1, -1, self.score_displacement + 4, self.score_displacement + 4)
98 | pred_regression = self.regress_adjust(
99 | F.conv2d(conv_reg, self.reg_filters, groups=N).reshape(N, 20, self.score_displacement + 1,
100 | self.score_displacement + 1))
101 | return pred_score, pred_regression
102 |
--------------------------------------------------------------------------------
/tracking/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import cv2
5 | import torch
6 | import random
7 | import numpy as np
8 | from torchvision import datasets, transforms, utils
9 |
10 | class TrackerDataLoader(object):
11 |
12 | def get_instance_image(self, img, bbox, size_z, size_x, context_amount, img_mean=None):
13 |
14 | cx, cy, w, h = bbox # float type
15 |         wc_z = w + context_amount * (w + h)
16 |         hc_z = h + context_amount * (w + h)
17 | s_z = np.sqrt(wc_z * hc_z) # the width of the crop box
18 | scale_z = size_z / s_z
19 |
20 | s_x = s_z * size_x / size_z
21 | instance_img, scale_x = self.crop_and_pad(img, cx, cy, size_x, s_x, img_mean)
22 | w_x = w * scale_x
23 | h_x = h * scale_x
24 | # point_1 = (size_x + 1) / 2 - w_x / 2, (size_x + 1) / 2 - h_x / 2
25 | # point_2 = (size_x + 1) / 2 + w_x / 2, (size_x + 1) / 2 + h_x / 2
26 | # frame = cv2.rectangle(instance_img, (int(point_1[0]),int(point_1[1])), (int(point_2[0]),int(point_2[1])), (0, 255, 0), 2)
27 | # cv2.imwrite('1.jpg', frame)
28 | return instance_img, w_x, h_x, scale_x
29 |
30 | def get_exemplar_image(self, img, bbox, size_z, context_amount, img_mean=None):
31 | cx, cy, w, h = bbox
32 | wc_z = w + context_amount * (w + h)
33 | hc_z = h + context_amount * (w + h)
34 | s_z = np.sqrt(wc_z * hc_z)
35 | scale_z = size_z / s_z
36 | exemplar_img, _ = self.crop_and_pad(img, cx, cy, size_z, s_z, img_mean)
37 | return exemplar_img, scale_z, s_z
38 |
39 | def crop_and_pad(self, img, cx, cy, model_sz, original_sz, img_mean=None):
40 |
41 | def round_up(value):
42 |             # replaces the built-in round(): the epsilon and offset give consistent half-up rounding
43 | return round(value + 1e-6 + 1000) - 1000
44 | im_h, im_w, _ = img.shape
45 |
46 | xmin = cx - (original_sz - 1) / 2
47 | xmax = xmin + original_sz - 1
48 | ymin = cy - (original_sz - 1) / 2
49 | ymax = ymin + original_sz - 1
50 |
51 | left = int(round_up(max(0., -xmin)))
52 | top = int(round_up(max(0., -ymin)))
53 | right = int(round_up(max(0., xmax - im_w + 1)))
54 | bottom = int(round_up(max(0., ymax - im_h + 1)))
55 |
56 | xmin = int(round_up(xmin + left))
57 | xmax = int(round_up(xmax + left))
58 | ymin = int(round_up(ymin + top))
59 | ymax = int(round_up(ymax + top))
60 | r, c, k = img.shape
61 | if any([top, bottom, left, right]):
62 | te_im = np.zeros((r + top + bottom, c + left + right, k), np.uint8) # 0 is better than 1 initialization
63 | te_im[top:top + r, left:left + c, :] = img
64 | if top:
65 | te_im[0:top, left:left + c, :] = img_mean
66 | if bottom:
67 | te_im[r + top:, left:left + c, :] = img_mean
68 | if left:
69 | te_im[:, 0:left, :] = img_mean
70 | if right:
71 | te_im[:, c + left:, :] = img_mean
72 | im_patch_original = te_im[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
73 | else:
74 | im_patch_original = img[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
75 | if not np.array_equal(model_sz, original_sz):
76 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv to get a better speed
77 | else:
78 | im_patch = im_patch_original
79 | scale = model_sz / im_patch_original.shape[0]
80 | return im_patch, scale
81 |
82 | def box_transform_inv(self, anchors, offset):
83 | anchor_xctr = anchors[:, :1]
84 | anchor_yctr = anchors[:, 1:2]
85 | anchor_w = anchors[:, 2:3]
86 | anchor_h = anchors[:, 3:]
87 | offset_x, offset_y, offset_w, offset_h = offset[:, :1], offset[:, 1:2], offset[:, 2:3], offset[:, 3:],
88 |
89 | box_cx = anchor_w * offset_x + anchor_xctr
90 | box_cy = anchor_h * offset_y + anchor_yctr
91 | box_w = anchor_w * np.exp(offset_w)
92 | box_h = anchor_h * np.exp(offset_h)
93 | box = np.hstack([box_cx, box_cy, box_w, box_h])
94 | return box
95 |
96 | def generate_anchors(self, total_stride, base_size, scales, ratios, score_size):
97 | anchor_num = len(ratios) * len(scales)
98 | anchor = np.zeros((anchor_num, 4), dtype=np.float32)
99 | size = base_size * base_size
100 | count = 0
101 | for ratio in ratios:
102 | # ws = int(np.sqrt(size * 1.0 / ratio))
103 | ws = int(np.sqrt(size / ratio))
104 | hs = int(ws * ratio)
105 | for scale in scales:
106 | wws = ws * scale
107 | hhs = hs * scale
108 | anchor[count, 0] = 0
109 | anchor[count, 1] = 0
110 | anchor[count, 2] = wws
111 | anchor[count, 3] = hhs
112 | count += 1
113 |
114 | anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
115 | # (5,4x225) to (225x5,4)
116 | ori = - (score_size // 2) * total_stride
117 | # the left displacement
118 | xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
119 | [ori + total_stride * dy for dy in range(score_size)])
120 | # (15,15)
121 | xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
122 | np.tile(yy.flatten(), (anchor_num, 1)).flatten()
123 | # (15,15) to (225,1) to (5,225) to (225x5,1)
124 | anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
125 | return anchor
126 |
--------------------------------------------------------------------------------
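
get_exemplar_image and get_instance_image follow the usual SiamRPN context rule for the crop sizes; a worked example for a hypothetical 100 x 60 target:

```
import numpy as np

w, h = 100.0, 60.0
context_amount = 0.5
size_z, size_x = 127, 271

wc_z = w + context_amount * (w + h)  # 180.0
hc_z = h + context_amount * (w + h)  # 140.0
s_z = np.sqrt(wc_z * hc_z)           # ~158.7, side of the square exemplar crop in the original image
s_x = s_z * size_x / size_z          # ~338.7, side of the search-region crop
scale_z = size_z / s_z               # ~0.80, maps original pixels onto the 127 x 127 exemplar

print(round(s_z, 1), round(s_x, 1), round(scale_z, 2))  # 158.7 338.7 0.8
```
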
/tracking/util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 |
5 | class Util(object):
6 |
7 | def generate_anchors(self, total_stride, base_size, scales, ratios, score_size):
8 | anchor_num = len(ratios) * len(scales)
9 | anchor = np.zeros((anchor_num, 4), dtype=np.float32)
10 | size = base_size * base_size
11 | count = 0
12 | for ratio in ratios:
13 | # ws = int(np.sqrt(size * 1.0 / ratio))
14 | ws = int(np.sqrt(size / ratio))
15 | hs = int(ws * ratio)
16 | for scale in scales:
17 | wws = ws * scale
18 | hhs = hs * scale
19 | anchor[count, 0] = 0
20 | anchor[count, 1] = 0
21 | anchor[count, 2] = wws
22 | anchor[count, 3] = hhs
23 | count += 1
24 |
25 | anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
26 | # (5,4x225) to (225x5,4)
27 | ori = - (score_size // 2) * total_stride
28 | # the left displacement
29 | xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
30 | [ori + total_stride * dy for dy in range(score_size)])
31 | # (15,15)
32 | xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
33 | np.tile(yy.flatten(), (anchor_num, 1)).flatten()
34 | # (15,15) to (225,1) to (5,225) to (225x5,1)
35 | anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
36 | return anchor
37 |
38 | def get_subwindow_tracking(self, im, pos, model_sz, original_sz, avg_chans, out_mode='torch'):
39 |
40 | # im (720, 1280, 3)
41 | # pos [406. 377.5]
42 | # model_sz 127
43 | # original_sz 768.0
44 | # avg_chans [115.18894748 111.79296549 109.10407878]
45 |
46 | if isinstance(pos, float):
47 | pos = [pos, pos]
48 | sz = original_sz # original_sz 768.0
49 | im_sz = im.shape # im (720, 1280, 3)
50 | c = (original_sz+1) / 2 # 384.5
51 | context_xmin = round(pos[0] - c) # floor(pos(2) - sz(2) / 2);
52 | context_xmax = context_xmin + sz - 1
53 | context_ymin = round(pos[1] - c) # floor(pos(1) - sz(1) / 2);
54 | context_ymax = context_ymin + sz - 1
55 | left_pad = int(max(0., -context_xmin))
56 | top_pad = int(max(0., -context_ymin))
57 | right_pad = int(max(0., context_xmax - im_sz[1] + 1))
58 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
59 |
60 | context_xmin = context_xmin + left_pad
61 | context_xmax = context_xmax + left_pad
62 | context_ymin = context_ymin + top_pad
63 | context_ymax = context_ymax + top_pad
64 |
65 | # zzp: a more easy speed version
66 | r, c, k = im.shape
67 | if any([top_pad, bottom_pad, left_pad, right_pad]):
68 | te_im = np.zeros((r + top_pad + bottom_pad, c + left_pad + right_pad, k), np.uint8) # 0 is better than 1 initialization
69 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
70 | if top_pad:
71 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
72 | if bottom_pad:
73 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
74 | if left_pad:
75 | te_im[:, 0:left_pad, :] = avg_chans
76 | if right_pad:
77 | te_im[:, c + left_pad:, :] = avg_chans
78 | im_patch_original = te_im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :]
79 | else:
80 | im_patch_original = im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :]
81 |
82 | if not np.array_equal(model_sz, original_sz):
83 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz))
84 | else:
85 | im_patch = im_patch_original
86 |
87 |         # cv2.imshow('foto', im_patch)  # debug view, left disabled (no waitKey follows)
88 |
89 | def im_to_torch(img):
90 |
91 | def to_torch(ndarray):
92 | if type(ndarray).__module__ == 'numpy':
93 | return torch.from_numpy(ndarray)
94 | elif not torch.is_tensor(ndarray):
95 | raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
96 | return ndarray
97 | img = np.transpose(img, (2, 0, 1)) # C*H*W
98 | img = to_torch(img).float()
99 | return img
100 |
101 | return im_to_torch(im_patch) if out_mode in 'torch' else im_patch
102 |
103 | def cxy_wh_2_rect(self, pos, sz):
104 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) # 0-index
105 |
106 | def x1y1_wh_to_xy_wh(self, rect):
107 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), np.array([rect[2], rect[3]]) # 0-index
108 |
109 | def box_transform_inv(self, anchors, offset):
110 | anchor_xctr = anchors[:, :1]
111 | anchor_yctr = anchors[:, 1:2]
112 | anchor_w = anchors[:, 2:3]
113 | anchor_h = anchors[:, 3:]
114 | offset_x, offset_y, offset_w, offset_h = offset[:, :1], offset[:, 1:2], offset[:, 2:3], offset[:, 3:],
115 |
116 | box_cx = anchor_w * offset_x + anchor_xctr
117 | box_cy = anchor_h * offset_y + anchor_yctr
118 | box_w = anchor_w * np.exp(offset_w)
119 | box_h = anchor_h * np.exp(offset_h)
120 | box = np.hstack([box_cx, box_cy, box_w, box_h])
121 | return box
122 |
123 | def change(self, r):
124 | return np.maximum(r, 1. / r)
125 |
126 | def sz(self, w, h):
127 | pad = (w + h) * 0.5
128 | sz2 = (w + pad) * (h + pad)
129 | return np.sqrt(sz2)
130 |
131 | def sz_wh(self, wh):
132 | pad = (wh[0] + wh[1]) * 0.5
133 | sz2 = (wh[0] + pad) * (wh[1] + pad)
134 | return np.sqrt(sz2)
135 |
136 | util = Util()
137 |
--------------------------------------------------------------------------------
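
A short sketch of how the two helpers above fit together, assuming it is run from the tracking directory so this module imports as util: generate_anchors lays out anchor_num x score_size^2 centred boxes, and box_transform_inv decodes predicted offsets back into (cx, cy, w, h) boxes (zero offsets give back the anchors).

```
import numpy as np
from util import util  # the singleton defined above

anchors = util.generate_anchors(total_stride=8, base_size=8, scales=np.array([8]),
                                ratios=np.array([0.33, 0.5, 1, 2, 3]), score_size=19)
print(anchors.shape)  # (1805, 4): 5 anchors at each of 19 * 19 positions, each row (cx, cy, w, h)

offsets = np.zeros_like(anchors)             # zero offsets decode to the anchors themselves
boxes = util.box_transform_inv(anchors, offsets)
print(np.allclose(boxes, anchors))           # True
```
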
/train/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | class RandomStretch(object):
7 | def __init__(self, max_stretch=0.05):
8 |         """Randomly resize the image according to the stretch factor
9 | Args:
10 | max_stretch(float): 0 to 1 value
11 | """
12 | self.max_stretch = max_stretch
13 |
14 | def __call__(self, sample):
15 | """
16 | Args:
17 | sample(numpy array): 3 or 1 dim image
18 | """
19 | scale_h = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch)
20 | scale_w = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch)
21 | h, w = sample.shape[:2]
22 | shape = int(w * scale_w), int(h * scale_h)
23 |         return cv2.resize(sample, shape, interpolation=cv2.INTER_LINEAR)
24 |
25 |
26 | class CenterCrop(object):
27 | def __init__(self, size):
28 |         """Crop the image at its center to the given size;
29 |         if the size is greater than the image size, zero padding is applied
30 | Args:
31 | size (tuple): desired size
32 | """
33 | self.size = size
34 |
35 | def __call__(self, sample):
36 | """
37 | Args:
38 | sample(numpy array): 3 or 1 dim image
39 | """
40 | shape = sample.shape[:2]
41 | cy, cx = (shape[0] - 1) // 2, (shape[1] - 1) // 2
42 | ymin, xmin = cy - self.size[0] // 2, cx - self.size[1] // 2
43 | ymax, xmax = cy + self.size[0] // 2 + self.size[0] % 2, \
44 | cx + self.size[1] // 2 + self.size[1] % 2
45 | left = right = top = bottom = 0
46 | im_h, im_w = shape
47 | if xmin < 0:
48 | left = int(abs(xmin))
49 | if xmax > im_w:
50 | right = int(xmax - im_w)
51 | if ymin < 0:
52 | top = int(abs(ymin))
53 | if ymax > im_h:
54 | bottom = int(ymax - im_h)
55 |
56 | xmin = int(max(0, xmin))
57 | xmax = int(min(im_w, xmax))
58 | ymin = int(max(0, ymin))
59 | ymax = int(min(im_h, ymax))
60 | im_patch = sample[ymin:ymax, xmin:xmax]
61 | if left != 0 or right != 0 or top != 0 or bottom != 0:
62 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
63 | cv2.BORDER_CONSTANT, value=0)
64 | return im_patch
65 |
66 |
67 | class RandomCrop(object):
68 | def __init__(self, size, max_translate):
69 |         """Crop the image around a randomly shifted center to the given size;
70 |         if the size is greater than the image size, zero padding is applied
71 | Args:
72 | size (tuple): desired size
73 | max_translate: max translate of random shift
74 | """
75 | self.size = size
76 | self.max_translate = max_translate
77 |
78 | def __call__(self, sample):
79 | """
80 | Args:
81 | sample(numpy array): 3 or 1 dim image
82 | """
83 | shape = sample.shape[:2]
84 | cy_o = (shape[0] - 1) // 2
85 | cx_o = (shape[1] - 1) // 2
86 | cy = np.random.randint(cy_o - self.max_translate,
87 | cy_o + self.max_translate + 1)
88 | cx = np.random.randint(cx_o - self.max_translate,
89 | cx_o + self.max_translate + 1)
90 | assert abs(cy - cy_o) <= self.max_translate and \
91 | abs(cx - cx_o) <= self.max_translate
92 | ymin = cy - self.size[0] // 2
93 | xmin = cx - self.size[1] // 2
94 | ymax = cy + self.size[0] // 2 + self.size[0] % 2
95 | xmax = cx + self.size[1] // 2 + self.size[1] % 2
96 | left = right = top = bottom = 0
97 | im_h, im_w = shape
98 | if xmin < 0:
99 | left = int(abs(xmin))
100 | if xmax > im_w:
101 | right = int(xmax - im_w)
102 | if ymin < 0:
103 | top = int(abs(ymin))
104 | if ymax > im_h:
105 | bottom = int(ymax - im_h)
106 |
107 | xmin = int(max(0, xmin))
108 | xmax = int(min(im_w, xmax))
109 | ymin = int(max(0, ymin))
110 | ymax = int(min(im_h, ymax))
111 | im_patch = sample[ymin:ymax, xmin:xmax]
112 | if left != 0 or right != 0 or top != 0 or bottom != 0:
113 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
114 | cv2.BORDER_CONSTANT, value=0)
115 | return im_patch
116 |
117 |
118 | class ColorAug(object):
119 | def __init__(self, type_in='z'):
120 | if type_in == 'z':
121 | rgb_var = np.array([[3.2586416e+03, 2.8992207e+03, 2.6392236e+03],
122 | [2.8992207e+03, 3.0958174e+03, 2.9321748e+03],
123 | [2.6392236e+03, 2.9321748e+03, 3.4533721e+03]])
124 | if type_in == 'x':
125 | rgb_var = np.array([[2.4847285e+03, 2.1796064e+03, 1.9766885e+03],
126 | [2.1796064e+03, 2.3441289e+03, 2.2357402e+03],
127 | [1.9766885e+03, 2.2357402e+03, 2.7369697e+03]])
128 | self.v, _ = np.linalg.eig(rgb_var)
129 | self.v = np.sqrt(self.v)
130 |
131 | def __call__(self, sample):
132 | return sample + 0.1 * self.v * np.random.randn(3)
133 |
134 |
135 | class RandomBlur(object):
136 | def __init__(self, ratio):
137 | self.ratio = ratio
138 |
139 | def __call__(self, sample):
140 | if np.random.rand(1) < self.ratio:
141 | # random kernel size
142 | kernel_size = np.random.choice([3, 5, 7])
143 | # random gaussian sigma
144 | sigma = np.random.rand() * 5
145 | return cv2.GaussianBlur(sample, (kernel_size, kernel_size), sigma)
146 | else:
147 | return sample
148 |
149 |
150 | class Normalize(object):
151 | def __init__(self):
152 | self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
153 | self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
154 |
155 | def __call__(self, sample):
156 | return (sample / 255. - self.mean) / self.std
157 |
158 |
159 | class ToTensor(object):
160 | def __call__(self, sample):
161 | sample = sample.transpose(2, 0, 1)
162 | return torch.from_numpy(sample.astype(np.float32))
163 |
--------------------------------------------------------------------------------
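
These transforms are plain callables on numpy images, so they can be chained with torchvision's Compose. A minimal sketch of an exemplar pipeline; the crop size is just an illustration, the actual pipelines live in data.py:

```
import numpy as np
from torchvision import transforms
from custom_transforms import RandomStretch, CenterCrop, Normalize, ToTensor

exemplar_pipeline = transforms.Compose([
    RandomStretch(max_stretch=0.05),
    CenterCrop((127, 127)),
    Normalize(),
    ToTensor(),
])

crop = np.random.randint(0, 255, (135, 135, 3)).astype(np.float32)  # dummy crop around a target
tensor = exemplar_pipeline(crop)
print(tensor.shape)  # torch.Size([3, 127, 127])
```
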
/tracking/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | class RandomStretch(object):
7 | def __init__(self, max_stretch=0.05):
8 |         """Randomly resize the image according to the stretch factor
9 | Args:
10 | max_stretch(float): 0 to 1 value
11 | """
12 | self.max_stretch = max_stretch
13 |
14 | def __call__(self, sample):
15 | """
16 | Args:
17 | sample(numpy array): 3 or 1 dim image
18 | """
19 | scale_h = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch)
20 | scale_w = 1.0 + np.random.uniform(-self.max_stretch, self.max_stretch)
21 | h, w = sample.shape[:2]
22 | shape = int(w * scale_w), int(h * scale_h)
23 |         return cv2.resize(sample, shape, interpolation=cv2.INTER_LINEAR)
24 |
25 |
26 | class CenterCrop(object):
27 | def __init__(self, size):
28 |         """Crop the image at its center to the given size;
29 |         if the size is greater than the image size, zero padding is applied
30 | Args:
31 | size (tuple): desired size
32 | """
33 | self.size = size
34 |
35 | def __call__(self, sample):
36 | """
37 | Args:
38 | sample(numpy array): 3 or 1 dim image
39 | """
40 | shape = sample.shape[:2]
41 | cy, cx = (shape[0] - 1) // 2, (shape[1] - 1) // 2
42 | ymin, xmin = cy - self.size[0] // 2, cx - self.size[1] // 2
43 | ymax, xmax = cy + self.size[0] // 2 + self.size[0] % 2, \
44 | cx + self.size[1] // 2 + self.size[1] % 2
45 | left = right = top = bottom = 0
46 | im_h, im_w = shape
47 | if xmin < 0:
48 | left = int(abs(xmin))
49 | if xmax > im_w:
50 | right = int(xmax - im_w)
51 | if ymin < 0:
52 | top = int(abs(ymin))
53 | if ymax > im_h:
54 | bottom = int(ymax - im_h)
55 |
56 | xmin = int(max(0, xmin))
57 | xmax = int(min(im_w, xmax))
58 | ymin = int(max(0, ymin))
59 | ymax = int(min(im_h, ymax))
60 | im_patch = sample[ymin:ymax, xmin:xmax]
61 | if left != 0 or right != 0 or top != 0 or bottom != 0:
62 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
63 | cv2.BORDER_CONSTANT, value=0)
64 | return im_patch
65 |
66 |
67 | class RandomCrop(object):
68 | def __init__(self, size, max_translate):
69 |         """Crop the image around a randomly shifted center to the given size;
70 |         if the size is greater than the image size, zero padding is applied
71 | Args:
72 | size (tuple): desired size
73 | max_translate: max translate of random shift
74 | """
75 | self.size = size
76 | self.max_translate = max_translate
77 |
78 | def __call__(self, sample):
79 | """
80 | Args:
81 | sample(numpy array): 3 or 1 dim image
82 | """
83 | shape = sample.shape[:2]
84 | cy_o = (shape[0] - 1) // 2
85 | cx_o = (shape[1] - 1) // 2
86 | cy = np.random.randint(cy_o - self.max_translate,
87 | cy_o + self.max_translate + 1)
88 | cx = np.random.randint(cx_o - self.max_translate,
89 | cx_o + self.max_translate + 1)
90 | assert abs(cy - cy_o) <= self.max_translate and \
91 | abs(cx - cx_o) <= self.max_translate
92 | ymin = cy - self.size[0] // 2
93 | xmin = cx - self.size[1] // 2
94 | ymax = cy + self.size[0] // 2 + self.size[0] % 2
95 | xmax = cx + self.size[1] // 2 + self.size[1] % 2
96 | left = right = top = bottom = 0
97 | im_h, im_w = shape
98 | if xmin < 0:
99 | left = int(abs(xmin))
100 | if xmax > im_w:
101 | right = int(xmax - im_w)
102 | if ymin < 0:
103 | top = int(abs(ymin))
104 | if ymax > im_h:
105 | bottom = int(ymax - im_h)
106 |
107 | xmin = int(max(0, xmin))
108 | xmax = int(min(im_w, xmax))
109 | ymin = int(max(0, ymin))
110 | ymax = int(min(im_h, ymax))
111 | im_patch = sample[ymin:ymax, xmin:xmax]
112 | if left != 0 or right != 0 or top != 0 or bottom != 0:
113 | im_patch = cv2.copyMakeBorder(im_patch, top, bottom, left, right,
114 | cv2.BORDER_CONSTANT, value=0)
115 | return im_patch
116 |
117 |
118 | class ColorAug(object):
119 | def __init__(self, type_in='z'):
120 | if type_in == 'z':
121 | rgb_var = np.array([[3.2586416e+03, 2.8992207e+03, 2.6392236e+03],
122 | [2.8992207e+03, 3.0958174e+03, 2.9321748e+03],
123 | [2.6392236e+03, 2.9321748e+03, 3.4533721e+03]])
124 | if type_in == 'x':
125 | rgb_var = np.array([[2.4847285e+03, 2.1796064e+03, 1.9766885e+03],
126 | [2.1796064e+03, 2.3441289e+03, 2.2357402e+03],
127 | [1.9766885e+03, 2.2357402e+03, 2.7369697e+03]])
128 | self.v, _ = np.linalg.eig(rgb_var)
129 | self.v = np.sqrt(self.v)
130 |
131 | def __call__(self, sample):
132 | return sample + 0.1 * self.v * np.random.randn(3)
133 |
134 |
135 | class RandomBlur(object):
136 | def __init__(self, ratio):
137 | self.ratio = ratio
138 |
139 | def __call__(self, sample):
140 | if np.random.rand(1) < self.ratio:
141 | # random kernel size
142 | kernel_size = np.random.choice([3, 5, 7])
143 | # random gaussian sigma
144 | sigma = np.random.rand() * 5
145 | return cv2.GaussianBlur(sample, (kernel_size, kernel_size), sigma)
146 | else:
147 | return sample
148 |
149 |
150 | class Normalize(object):
151 | def __init__(self):
152 | self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
153 | self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
154 |
155 | def __call__(self, sample):
156 | return (sample / 255. - self.mean) / self.std
157 |
158 |
159 | class ToTensor(object):
160 | def __call__(self, sample):
161 | sample = sample.transpose(2, 0, 1)
162 | return torch.from_numpy(sample.astype(np.float32))
163 |
--------------------------------------------------------------------------------
/train/net.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import cv2
5 | import torch
6 | import random
7 | import numpy as np
8 | import torch.nn as nn
9 | from util import util
10 | import matplotlib.pyplot as plt
11 | import torch.nn.functional as F
12 | from config import config
13 | from got10k.trackers import Tracker
14 | from network import SiameseAlexNet
15 | from loss import rpn_smoothL1, rpn_cross_entropy_balance
16 |
17 | class TrackerSiamRPN(Tracker):
18 |
19 | def __init__(self, net_path=None, **kargs):
20 | super(TrackerSiamRPN, self).__init__(
21 | name='SiamRPN', is_deterministic=True)
22 |
23 | '''setup GPU device if available'''
24 | self.cuda = torch.cuda.is_available()
25 | self.device = torch.device('cuda:0' if self.cuda else 'cpu')
26 |
27 | '''setup model'''
28 | self.net = SiameseAlexNet()
29 | #self.net.init_weights()
30 |
31 | if net_path is not None:
32 | self.net.load_state_dict(torch.load(
33 | net_path, map_location = lambda storage, loc: storage ))
34 | if self.cuda:
35 | self.net = self.net.to(self.device)
36 |
37 | '''setup optimizer'''
38 | self.optimizer = torch.optim.SGD(
39 | self.net.parameters(),
40 | lr = config.lr,
41 | momentum = config.momentum,
42 | weight_decay = config.weight_decay)
43 |
44 | def step(self, epoch, dataset, anchors, i = 0, train=True):
45 |
46 | if train:
47 | self.net.train()
48 | else:
49 | self.net.eval()
50 |
51 | template, detection, regression_target, conf_target = dataset
52 |
53 | if self.cuda:
54 | template, detection = template.cuda(), detection.cuda()
55 | regression_target, conf_target = regression_target.cuda(), conf_target.cuda()
56 |
57 | pred_score, pred_regression = self.net(template, detection)
58 |
59 | pred_conf = pred_score.reshape(-1, 2, config.size).permute(0, 2, 1)
60 |
61 | pred_offset = pred_regression.reshape(-1, 4, config.size).permute(0, 2, 1)
62 |
63 | cls_loss = rpn_cross_entropy_balance( pred_conf,
64 | conf_target,
65 | config.num_pos,
66 | config.num_neg,
67 | anchors,
68 | ohem_pos=config.ohem_pos,
69 | ohem_neg=config.ohem_neg)
70 |
71 | reg_loss = rpn_smoothL1(pred_offset,
72 | regression_target,
73 | conf_target,
74 | config.num_pos,
75 | ohem=config.ohem_reg)
76 |
77 | loss = cls_loss + config.lamb * reg_loss
78 |
79 | '''anchors_show = anchors
80 | exem_img = template[0].cpu().numpy().transpose(1, 2, 0) # (127, 127, 3)
81 | #cv2.imwrite('exem_img.png', exem_img)
82 |
83 | inst_img = detection[0].cpu().numpy().transpose(1, 2, 0) # (255, 255, 3)
84 | #cv2.imwrite('inst_img.png', inst_img)
85 |
86 |
87 |
88 | topk = 1
89 | cls_pred = F.softmax(pred_conf, dim=2)[0, :, 1]
90 |
91 | topk_box = util.get_topk_box(cls_pred, pred_offset[0], anchors_show, topk=topk)
92 | img_box = util.add_box_img(inst_img, topk_box, color=(0, 0, 255))
93 |
94 | cv2.imwrite('pred_inst.png', img_box)
95 |
96 | cls_pred = conf_target[0]
97 | gt_box = util.get_topk_box(cls_pred, regression_target[0], anchors_show)
98 | #print('gt_box', gt_box)
99 | img_box = util.add_box_img(img_box, gt_box, color=(255, 0, 0), x = 1, y = 1)
100 | #print('gt_box', gt_box)
101 | cv2.imwrite('pred_inst_gt.png', img_box)'''
102 |
103 | if train:
104 | self.optimizer.zero_grad()
105 | loss.backward()
106 | torch.nn.utils.clip_grad_norm_(self.net.parameters(), config.clip)
107 | self.optimizer.step()
108 |
109 | return cls_loss, reg_loss, loss
110 |
111 | '''save model'''
112 | def save(self,model, exp_name_dir, epoch):
113 | util.adjust_learning_rate(self.optimizer, config.gamma)
114 |
115 | model_save_dir_pth = '{}/model'.format(exp_name_dir)
116 | if not os.path.exists(model_save_dir_pth):
117 | os.makedirs(model_save_dir_pth)
118 | net_path = os.path.join(model_save_dir_pth, 'model_e%d.pth' % (epoch + 1))
119 | torch.save(model.net.state_dict(), net_path)
120 |
121 | '''class SiamRPN(nn.Module):
122 |
123 | def __init__(self, anchor_num = 5):
124 | super(SiamRPN, self).__init__()
125 |
126 | self.anchor_num = anchor_num
127 | self.feature = nn.Sequential(
128 | # conv1
129 | nn.Conv2d(3, 64, kernel_size = 11, stride = 2),
130 | nn.BatchNorm2d(64),
131 | nn.ReLU(inplace = True),
132 | nn.MaxPool2d(kernel_size = 3, stride = 2),
133 | # conv2
134 | nn.Conv2d(64, 192, kernel_size = 5),
135 | nn.BatchNorm2d(192),
136 | nn.ReLU(inplace=True),
137 | nn.MaxPool2d(kernel_size = 3, stride = 2),
138 | # conv3
139 | nn.Conv2d(192, 384, kernel_size = 3),
140 | nn.BatchNorm2d(384),
141 | nn.ReLU(inplace = True),
142 | # conv4
143 | nn.Conv2d(384, 256, kernel_size = 3),
144 | nn.BatchNorm2d(256),
145 | nn.ReLU(inplace = True),
146 | # conv5
147 | nn.Conv2d(256, 256, kernel_size = 3),
148 | nn.BatchNorm2d(256))
149 |
150 | self.conv_reg_z = nn.Conv2d(256, 256 * 4 * self.anchor_num, 3, 1)
151 | self.conv_reg_x = nn.Conv2d(256, 256, 3)
152 | self.conv_cls_z = nn.Conv2d(256, 256 * 2 * anchor_num, 3, 1)
153 | self.conv_cls_x = nn.Conv2d(256, 256, 3)
154 | self.adjust_reg = nn.Conv2d(4 * anchor_num, 4 * anchor_num*1, 1)
155 |
156 | def forward(self, z, x):
157 | return self.inference(x, *self.learn(z))
158 |
159 | def learn(self, z):
160 | z = self.feature(z)
161 | kernel_reg = self.conv_reg_z(z)
162 | kernel_cls = self.conv_cls_z(z)
163 |
164 | k = kernel_reg.size()[-1]
165 | kernel_reg = kernel_reg.view(4 * self.anchor_num, 256, k, k)
166 | kernel_cls = kernel_cls.view(2 * self.anchor_num, 256, k, k)
167 |
168 | return kernel_reg, kernel_cls
169 |
170 | def inference(self, x, kernel_reg, kernel_cls):
171 | x = self.feature(x)
172 | x_reg = self.conv_reg_x(x)
173 | x_cls = self.conv_cls_x(x)
174 |
175 | out_reg = self.adjust_reg(F.conv2d(x_reg, kernel_reg))
176 | out_cls = F.conv2d(x_cls, kernel_cls)
177 |
178 | return out_reg, out_cls'''
179 |
--------------------------------------------------------------------------------
/train/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from util import util
7 |
8 | def rpn_cross_entropy_old(input, target):
9 | mask_ignore = target == -1
10 |     mask_calcu = ~mask_ignore
11 | loss = F.cross_entropy(input=input[mask_calcu], target=target[mask_calcu])
12 | return loss
13 |
14 |
15 | def rpn_cross_entropy_balance_old(input, target, num_pos, num_neg):
16 | cal_index_pos = np.array([], dtype=np.int64)
17 | cal_index_neg = np.array([], dtype=np.int64)
18 | for batch_id in range(target.shape[0]):
19 | print(target[batch_id])
20 | pos_index = np.random.choice(np.where(target[batch_id].cpu() == 1)[0], num_pos)
21 | neg_index = np.random.choice(np.where(target[batch_id].cpu() == 0)[0], num_neg)
22 | cal_index_pos = np.append(cal_index_pos, batch_id * target.shape[1] + pos_index)
23 | cal_index_neg = np.append(cal_index_neg, batch_id * target.shape[1] + neg_index)
24 | pos_loss = F.cross_entropy(input=input.reshape(-1, 2)[cal_index_pos], target=target.flatten()[cal_index_pos],
25 | reduction='sum') / cal_index_pos.shape[0]
26 | neg_loss = F.cross_entropy(input=input.reshape(-1, 2)[cal_index_neg], target=target.flatten()[cal_index_neg],
27 | reduction='sum') / cal_index_neg.shape[0]
28 | loss = (pos_loss + neg_loss) / 2
29 | # loss = F.cross_entropy(input=input.reshape(-1, 2)[cal_index], target=target.flatten()[cal_index])
30 | return loss
31 |
32 | def rpn_smoothL1_old(input, target, label):
33 | pos_index = np.where(label.cpu() == 1)
34 | loss = F.smooth_l1_loss(input[pos_index], target[pos_index])
35 | return loss
36 |
37 | def rpn_cross_entropy_balance(input, target, num_pos, num_neg, anchors, ohem_pos=None, ohem_neg=None):
38 | cuda = torch.cuda.is_available()
39 | loss_all = []
40 | for batch_id in range(target.shape[0]):
41 | min_pos = min(len(np.where(target[batch_id].cpu() == 1)[0]), num_pos)
42 | min_neg = int(min(len(np.where(target[batch_id].cpu() == 1)[0]) * num_neg / num_pos, num_neg))
43 |
44 | pos_index = np.where(target[batch_id].cpu() == 1)[0].tolist()
45 | neg_index = np.where(target[batch_id].cpu() == 0)[0].tolist()
46 |
47 | if ohem_pos:
48 | if len(pos_index) > 0:
49 | pos_loss_bid = F.cross_entropy(input=input[batch_id][pos_index],
50 | target=target[batch_id][pos_index], reduction='none')
51 | selected_pos_index = util.nms(anchors[pos_index], pos_loss_bid.cpu().detach().numpy(), min_pos)
52 | pos_loss_bid_final = pos_loss_bid[selected_pos_index]
53 | else:
54 | if cuda:
55 | pos_loss_bid = torch.FloatTensor([0]).cuda()
56 | else:
57 | pos_loss_bid = torch.FloatTensor([0])
58 | pos_loss_bid_final = pos_loss_bid
59 | else:
60 | pos_index_random = random.sample(pos_index, min_pos)
61 | if len(pos_index) > 0:
62 | pos_loss_bid_final = F.cross_entropy(input=input[batch_id][pos_index_random],
63 | target=target[batch_id][pos_index_random], reduction='none')
64 | else:
65 | if cuda:
66 | pos_loss_bid_final = torch.FloatTensor([0]).cuda()
67 | else:
68 | pos_loss_bid_final = torch.FloatTensor([0])
69 |
70 | if ohem_neg:
71 | if len(pos_index) > 0:
72 | neg_loss_bid = F.cross_entropy(input=input[batch_id][neg_index],
73 | target=target[batch_id][neg_index], reduction='none')
74 | selected_neg_index = util.nms(anchors[neg_index], neg_loss_bid.cpu().detach().numpy(), min_neg)
75 | neg_loss_bid_final = neg_loss_bid[selected_neg_index]
76 | else:
77 | neg_loss_bid = F.cross_entropy(input=input[batch_id][neg_index],
78 | target=target[batch_id][neg_index], reduction='none')
79 | selected_neg_index = util.nms(anchors[neg_index], neg_loss_bid.cpu().detach().numpy(), num_neg)
80 | neg_loss_bid_final = neg_loss_bid[selected_neg_index]
81 | else:
82 | if len(pos_index) > 0:
83 | neg_index_random = random.sample(np.where(target[batch_id].cpu() == 0)[0].tolist(), min_neg)
84 | #neg_index_random = np.where(target[batch_id].cpu() == 0)[0].tolist()
85 | neg_loss_bid_final = F.cross_entropy(input=input[batch_id][neg_index_random],
86 | target=target[batch_id][neg_index_random], reduction='none')
87 | else:
88 | neg_index_random = random.sample(np.where(target[batch_id].cpu() == 0)[0].tolist(), num_neg)
89 | neg_loss_bid_final = F.cross_entropy(input=input[batch_id][neg_index_random],
90 | target=target[batch_id][neg_index_random], reduction='none')
91 | loss_bid = (pos_loss_bid_final.mean() + neg_loss_bid_final.mean()) / 2
92 | loss_all.append(loss_bid)
93 | final_loss = torch.stack(loss_all).mean()
94 | return final_loss
95 |
96 |
97 | def rpn_smoothL1(input, target, label, num_pos=16, ohem=None):
98 | cuda = torch.cuda.is_available()
99 | loss_all = []
100 | for batch_id in range(target.shape[0]):
101 | min_pos = min(len(np.where(label[batch_id].cpu() == 1)[0]), num_pos)
102 | if ohem:
103 | pos_index = np.where(label[batch_id].cpu() == 1)[0]
104 | if len(pos_index) > 0:
105 | loss_bid = F.smooth_l1_loss(input[batch_id][pos_index], target[batch_id][pos_index], reduction='none')
106 | sort_index = torch.argsort(loss_bid.mean(1))
107 | loss_bid_ohem = loss_bid[sort_index[-num_pos:]]
108 | else:
109 | if cuda:
110 | loss_bid_ohem = torch.FloatTensor([0]).cuda()[0]
111 | else:
112 | loss_bid_ohem = torch.FloatTensor([0])[0]
113 |
114 | loss_all.append(loss_bid_ohem.mean())
115 | else:
116 | pos_index = np.where(label[batch_id].cpu() == 1)[0]
117 | pos_index = random.sample(pos_index.tolist(), min_pos)
118 | if len(pos_index) > 0:
119 | loss_bid = F.smooth_l1_loss(input[batch_id][pos_index], target[batch_id][pos_index])
120 | else:
121 | if cuda:
122 | loss_bid = torch.FloatTensor([0]).cuda()[0]
123 | else:
124 | loss_bid = torch.FloatTensor([0])[0]
125 |
126 | loss_all.append(loss_bid.mean())
127 | final_loss = torch.stack(loss_all).mean()
128 | return final_loss
129 |
--------------------------------------------------------------------------------
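Usage sketch (not part of the repository): a minimal smoke test for rpn_cross_entropy_balance and rpn_smoothL1 above, assuming it is run from the train/ directory. The tensor shapes mirror how train_siamrpn.py appears to call these functions and are assumptions, not code taken from the repo:

import torch
import numpy as np
from loss import rpn_cross_entropy_balance, rpn_smoothL1

B, N = 2, 1805                          # batch size, anchors (19x19 positions x 5 anchors)
pred_cls = torch.randn(B, N, 2)         # per-anchor (background, foreground) logits
pred_reg = torch.randn(B, N, 4)         # per-anchor (dx, dy, dw, dh) offsets
reg_target = torch.randn(B, N, 4)
label = torch.full((B, N), -1, dtype=torch.long)
label[:, :16] = 1                       # pretend the first 16 anchors are positives
label[:, 16:64] = 0                     # and the next 48 are negatives
anchors = np.zeros((N, 4), np.float32)  # only inspected when OHEM is enabled

closs = rpn_cross_entropy_balance(pred_cls, label, num_pos=16, num_neg=48, anchors=anchors)
rloss = rpn_smoothL1(pred_reg, reg_target, label, num_pos=16)
print(closs.item(), rloss.item())

With the default ohem arguments the anchors array is never used, so this exercises only the random pos/neg sampling path.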
/train/train_siamrpn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | import json
5 | import torch
6 | import random
7 | import logging
8 | import argparse
9 | import numpy as np
10 | from tqdm import tqdm
11 | from torch.nn import init
12 | from config import config
13 | from net import TrackerSiamRPN
14 | from data import TrainDataLoader
15 | from torch.utils.data import DataLoader
16 | from util import util, AverageMeter, SavePlot
17 | from got10k.datasets import ImageNetVID, GOT10k
18 | from torchvision import datasets, transforms, utils
19 | from custom_transforms import Normalize, ToTensor, RandomStretch, \
20 | RandomCrop, CenterCrop, RandomBlur, ColorAug
21 |
22 | torch.manual_seed(1234) # config.seed
23 |
24 |
25 | parser = argparse.ArgumentParser(description='PyTorch SiameseRPN Training')
26 |
27 | parser.add_argument('--train_path', default='/home/arbi/desktop/GOT-10k', metavar='DIR',help='path to dataset')
28 | parser.add_argument('--experiment_name', default='default', metavar='DIR',help='path to weight')
29 | parser.add_argument('--checkpoint_path', default=None, help='resume')
30 | # /home/arbi/desktop/GOT-10k # /Users/arbi/Desktop # /home/arbi/desktop/ILSVRC
31 | # 'experiments/default/model/model_e1.pth'
32 | def main():
33 |
34 | '''parameter initialization'''
35 | args = parser.parse_args()
36 | exp_name_dir = util.experiment_name_dir(args.experiment_name)
37 |
38 | '''model on gpu'''
39 | model = TrackerSiamRPN()
40 |
41 | '''setup train data loader'''
42 | name = 'GOT-10k'
43 | assert name in ['VID', 'GOT-10k', 'All']
44 | if name == 'GOT-10k':
45 | root_dir = args.train_path
46 | seq_dataset = GOT10k(root_dir, subset='train')
47 | elif name == 'VID':
48 | root_dir = '/home/arbi/desktop/ILSVRC'
49 | seq_dataset = ImageNetVID(root_dir, subset=('train'))
50 | elif name == 'All':
51 | root_dir_vid = '/home/arbi/desktop/ILSVRC'
52 | seq_datasetVID = ImageNetVID(root_dir_vid, subset=('train'))
53 | root_dir_got = args.train_path
54 | seq_datasetGOT = GOT10k(root_dir_got, subset='train')
55 | seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
56 | print('seq_dataset', len(seq_dataset))
57 |
58 | train_z_transforms = transforms.Compose([
59 | ToTensor()
60 | ])
61 | train_x_transforms = transforms.Compose([
62 | ToTensor()
63 | ])
64 |
65 | train_data = TrainDataLoader(seq_dataset, train_z_transforms, train_x_transforms, name)
66 | anchors = train_data.anchors
67 | train_loader = DataLoader( dataset = train_data,
68 | batch_size = config.train_batch_size,
69 | shuffle = True,
70 | num_workers= config.train_num_workers,
71 | pin_memory = True)
72 |
73 | '''setup val data loader'''
74 | name = 'GOT-10k'
75 | assert name in ['VID', 'GOT-10k', 'All']
76 | if name == 'GOT-10k':
77 | root_dir = args.train_path
78 | seq_dataset_val = GOT10k(root_dir, subset='val')
79 | elif name == 'VID':
80 | root_dir = '/home/arbi/desktop/ILSVRC'
81 | seq_dataset_val = ImageNetVID(root_dir, subset=('val'))
82 | elif name == 'All':
83 | root_dir_vid = '/home/arbi/desktop/ILSVRC'
84 | seq_datasetVID = ImageNetVID(root_dir_vid, subset=('val'))
85 | root_dir_got = args.train_path
86 | seq_datasetGOT = GOT10k(root_dir_got, subset='val')
87 | seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
88 | print('seq_dataset_val', len(seq_dataset_val))
89 |
90 | valid_z_transforms = transforms.Compose([
91 | ToTensor()
92 | ])
93 | valid_x_transforms = transforms.Compose([
94 | ToTensor()
95 | ])
96 |
97 | val_data = TrainDataLoader(seq_dataset_val, valid_z_transforms, valid_x_transforms, name)
98 | val_loader = DataLoader( dataset = val_data,
99 | batch_size = config.valid_batch_size,
100 | shuffle = False,
101 | num_workers= config.valid_num_workers,
102 | pin_memory = True)
103 |
104 | '''load weights'''
105 |
106 |     if args.checkpoint_path is not None:
107 |         assert os.path.isfile(args.checkpoint_path), '{} is not a valid checkpoint_path'.format(args.checkpoint_path)
108 |         checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
109 |         if 'model' in checkpoint.keys():
110 |             model.net.load_state_dict(checkpoint['model'])
111 |         else:
112 |             model.net.load_state_dict(checkpoint)
113 |         torch.cuda.empty_cache()
114 |         print('Loaded weights from checkpoint {}'.format(args.checkpoint_path))
115 |
116 | elif config.pretrained_model:
117 | checkpoint = torch.load(config.pretrained_model)
118 | # change name and load parameters
119 | checkpoint = {k.replace('features.features', 'featureExtract'): v for k, v in checkpoint.items()}
120 | model_dict = model.net.state_dict()
121 | model_dict.update(checkpoint)
122 | model.net.load_state_dict(model_dict)
123 | #torch.cuda.empty_cache()
124 |
125 | '''train phase'''
126 | train_closses, train_rlosses, train_tlosses = AverageMeter(), AverageMeter(), AverageMeter()
127 | val_closses, val_rlosses, val_tlosses = AverageMeter(), AverageMeter(), AverageMeter()
128 |
129 | train_val_plot = SavePlot(exp_name_dir, 'train_val_plot')
130 |
131 | for epoch in range(config.epoches):
132 | model.net.train()
133 | if config.fix_former_3_layers:
134 | util.freeze_layers(model.net)
135 | print('Train epoch {}/{}'.format(epoch+1, config.epoches))
136 | train_loss = []
137 | with tqdm(total=config.train_epoch_size) as progbar:
138 | for i, dataset in enumerate(train_loader):
139 |
140 | closs, rloss, loss = model.step(epoch, dataset,anchors, i, train=True)
141 |
142 | closs_ = closs.cpu().item()
143 |
144 | if np.isnan(closs_):
145 | sys.exit(0)
146 |
147 | train_closses.update(closs.cpu().item())
148 | train_rlosses.update(rloss.cpu().item())
149 | train_tlosses.update(loss.cpu().item())
150 |
151 | progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
152 | rloss='{:05.5f}'.format(train_rlosses.avg),
153 | tloss='{:05.3f}'.format(train_tlosses.avg))
154 |
155 | progbar.update()
156 | train_loss.append(train_tlosses.avg)
157 |
158 | if i >= config.train_epoch_size - 1:
159 |
160 | '''save model'''
161 | model.save(model, exp_name_dir, epoch)
162 |
163 | break
164 |
165 | train_loss = np.mean(train_loss)
166 |
167 | '''val phase'''
168 | val_loss = []
169 | with tqdm(total=config.val_epoch_size) as progbar:
170 | print('Val epoch {}/{}'.format(epoch+1, config.epoches))
171 | for i, dataset in enumerate(val_loader):
172 |
173 | val_closs, val_rloss, val_tloss = model.step(epoch, dataset, anchors, train=False)
174 |
175 | closs_ = val_closs.cpu().item()
176 |
177 | if np.isnan(closs_):
178 | sys.exit(0)
179 |
180 | val_closses.update(val_closs.cpu().item())
181 | val_rlosses.update(val_rloss.cpu().item())
182 | val_tlosses.update(val_tloss.cpu().item())
183 |
184 | progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
185 | rloss='{:05.5f}'.format(val_rlosses.avg),
186 | tloss='{:05.3f}'.format(val_tlosses.avg))
187 |
188 | progbar.update()
189 |
190 | val_loss.append(val_tlosses.avg)
191 |
192 | if i >= config.val_epoch_size - 1:
193 | break
194 |
195 | val_loss = np.mean(val_loss)
196 | train_val_plot.update(train_loss, val_loss)
197 | print ('Train loss: {}, val loss: {}'.format(train_loss, val_loss))
198 |
199 |
200 | if __name__ == '__main__':
201 | main()
202 |
--------------------------------------------------------------------------------
/train/util.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch
5 | import cv2
6 |
7 | class Util(object):
8 |
9 | def add_box_img(self, img, boxes, color=(0, 255, 0), x = 1, y = 1):
10 | # boxes (x,y,w,h)
11 | if boxes.ndim == 1:
12 | boxes = boxes[None, :]
13 | img = img.copy()
14 | img_ctx = (img.shape[1] - 1) / 2
15 | img_cty = (img.shape[0] - 1) / 2
16 | for box in boxes:
17 | point_1 = [img_ctx - box[2] / 2 + (box[0]/x) + 0.5, img_cty - box[3] / 2 + (box[1]/y) + 0.5]
18 | point_2 = [img_ctx + box[2] / 2 + (box[0]/x) - 0.5, img_cty + box[3] / 2 + (box[1]/y) - 0.5]
19 | point_1[0] = np.clip(point_1[0], 0, img.shape[1])
20 | point_2[0] = np.clip(point_2[0], 0, img.shape[1])
21 | point_1[1] = np.clip(point_1[1], 0, img.shape[0])
22 | point_2[1] = np.clip(point_2[1], 0, img.shape[0])
23 | img = cv2.rectangle(img, (int(point_1[0]), int(point_1[1])), (int(point_2[0]), int(point_2[1])),
24 | color, 2)
25 | return img
26 |
27 | def get_topk_box(self, cls_score, pred_regression, anchors, topk=10):
28 | # anchors xc,yc,w,h
29 | regress_offset = pred_regression.cpu().detach().numpy()
30 |
31 | scores, index = torch.topk(cls_score, topk, )
32 | index = index.view(-1).cpu().detach().numpy()
33 |
34 | topk_offset = regress_offset[index, :]
35 | anchors = anchors[index, :]
36 | pred_box = self.box_transform_inv(anchors, topk_offset)
37 | return pred_box
38 |
39 | def box_transform_inv(self, anchors, offset):
40 | anchor_xctr = anchors[:, :1]
41 | anchor_yctr = anchors[:, 1:2]
42 | anchor_w = anchors[:, 2:3]
43 | anchor_h = anchors[:, 3:]
44 | offset_x, offset_y, offset_w, offset_h = offset[:, :1], offset[:, 1:2], offset[:, 2:3], offset[:, 3:],
45 |
46 | box_cx = anchor_w * offset_x + anchor_xctr
47 | box_cy = anchor_h * offset_y + anchor_yctr
48 | box_w = anchor_w * np.exp(offset_w)
49 | box_h = anchor_h * np.exp(offset_h)
50 | box = np.hstack([box_cx, box_cy, box_w, box_h])
51 | return box
52 |
53 | def data_split(self, seq_datasetVID, seq_datasetGOT):
54 | seq_dataset = []
55 | for i in seq_datasetVID:
56 | seq_dataset.append(i)
57 |
58 | for i, data in enumerate(seq_datasetGOT):
59 | seq_dataset.append(data)
60 | if i >= 8600:
61 | break
62 | return seq_dataset
63 |
64 | def generate_anchors(self, total_stride, base_size, scales, ratios, score_size):
65 | anchor_num = len(ratios) * len(scales) # 5
66 | anchor = np.zeros((anchor_num, 4), dtype=np.float32)
67 | size = base_size * base_size
68 | count = 0
69 | for ratio in ratios:
70 | # ws = int(np.sqrt(size * 1.0 / ratio))
71 | ws = int(np.sqrt(size / ratio))
72 | hs = int(ws * ratio)
73 | for scale in scales:
74 | wws = ws * scale
75 | hhs = hs * scale
76 | anchor[count, 0] = 0
77 | anchor[count, 1] = 0
78 | anchor[count, 2] = wws
79 | anchor[count, 3] = hhs
80 | count += 1
81 |
82 | anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
83 | ori = - (score_size // 2) * total_stride
84 | # the left displacement
85 | xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
86 | [ori + total_stride * dy for dy in range(score_size)])
87 |
88 | xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
89 | np.tile(yy.flatten(), (anchor_num, 1)).flatten()
90 | anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
91 | return anchor
92 |
93 | # freeze layers
94 | def freeze_layers(self, model):
95 | for layer in model.featureExtract[:10]:
96 | if isinstance(layer, nn.BatchNorm2d):
97 | layer.eval()
98 | for k, v in layer.named_parameters():
99 | v.requires_grad = False
100 | elif isinstance(layer, nn.Conv2d):
101 | for k, v in layer.named_parameters():
102 | v.requires_grad = False
103 | elif isinstance(layer, nn.MaxPool2d):
104 | continue
105 | elif isinstance(layer, nn.ReLU):
106 | continue
107 | else:
108 | raise KeyError('error in fixing former 3 layers')
109 |
110 | def experiment_name_dir(self, experiment_name):
111 | experiment_name_dir = 'experiments/{}'.format(experiment_name)
112 | if experiment_name == 'default':
113 |             print('You are using the "default" experiment. Consider copying "default", renaming the new folder, and adjusting the settings in its "parameters.json".')
114 | else:
115 | print('You are using "{}" experiment'.format(experiment_name))
116 | return experiment_name_dir
117 |
118 | def adjust_learning_rate(self, optimizer, decay=0.1):
119 |         """Multiply the learning rate of every parameter group by `decay` (0.1 by default)"""
120 | for param_group in optimizer.param_groups:
121 | param_group['lr'] = decay * param_group['lr']
122 |
123 | def nms(self, bboxes, scores, num, threshold=0.7):
124 | print('scores', scores)
125 | sort_index = np.argsort(scores)[::-1]
126 | print('sort_index', sort_index)
127 | sort_boxes = bboxes[sort_index]
128 | selected_bbox = [sort_boxes[0]]
129 | selected_index = [sort_index[0]]
130 | for i, bbox in enumerate(sort_boxes):
131 |             iou = compute_iou(selected_bbox, bbox)  # NOTE: no compute_iou is defined in this module; an IoU helper such as TrainDataLoader.compute_iou in data.py must be provided before nms can run
132 | print(iou, bbox, selected_bbox)
133 | if np.max(iou) < threshold:
134 | selected_bbox.append(bbox)
135 | selected_index.append(sort_index[i])
136 | if len(selected_bbox) >= num:
137 | break
138 | return selected_index
139 |
140 | util = Util()
141 |
142 | class AverageMeter(object):
143 | '''Computes and stores the average and current value'''
144 | def __init__(self):
145 | self.reset()
146 |
147 | def reset(self):
148 | self.val = 0
149 | self.avg = 0
150 | self.sum = 0
151 | self.count = 0
152 |
153 | def update(self, val, n=1):
154 | self.val = val
155 | self.sum += val * n
156 | self.count += n
157 | self.avg = self.sum / self.count
158 |
159 | class SavePlot(object):
160 | def __init__(self, exp_name_dir,
161 | name = 'plot',
162 | title = 'Siamese RPN',
163 | ylabel = 'loss',
164 | xlabel = 'epoch',
165 | show = False):
166 |
167 | self.step = 0
168 | self.exp_name_dir = exp_name_dir
169 | self.steps_array = []
170 | self.train_array = []
171 | self.val_array = []
172 | self.name = name
173 | self.title = title
174 | self.ylabel = ylabel
175 | self.xlabel = xlabel
176 | self.show = show
177 |
178 | self.plot( self.exp_name_dir,
179 | self.steps_array,
180 | self.train_array,
181 | self.val_array,
182 | self.name,
183 | self.title,
184 | self.ylabel,
185 | self.xlabel,
186 | self.show)
187 |
188 | self.plt.legend()
189 |
190 | def update(self, train,
191 | val,
192 | train_label = 'train loss',
193 | val_label = 'val loss',
194 | count_step=1):
195 |
196 | self.step += count_step
197 | self.steps_array.append(self.step)
198 | self.train_array.append(train)
199 | self.val_array.append(val)
200 |
201 | self.plot(exp_name_dir = self.exp_name_dir,
202 | step = self.steps_array,
203 | train = self.train_array,
204 | val = self.val_array,
205 | name = self.name,
206 | title = self.title,
207 | ylabel = self.ylabel,
208 | xlabel = self.xlabel,
209 | show = self.show,
210 | train_label = train_label,
211 | val_label = val_label)
212 |
213 | def plot(self, exp_name_dir,
214 | step,
215 | train,
216 | val,
217 | name,
218 | title,
219 | ylabel,
220 | xlabel,
221 | show,
222 | train_label = 'train loss',
223 | val_label = 'val loss'):
224 |         self.plt = plt
225 |         self.plt.clf()  # start from a clean figure so curves are not re-drawn on top of old ones
226 |         self.plt.plot(step, train, label = train_label, color = 'red')
227 |         self.plt.plot(step, val, label = val_label, color = 'black')
228 |         self.plt.legend()
229 |
230 |         self.plt.title(title)
231 |         self.plt.ylabel(ylabel)
232 |         self.plt.xlabel(xlabel)
233 |
234 |         '''save plot'''
235 |         self.plt.savefig("{}/{}.png".format(exp_name_dir, name))
236 |         if show:
237 |             self.plt.show()
238 |
--------------------------------------------------------------------------------
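Usage sketch (not part of the repository): how the anchor generator above is typically driven, followed by decoding offsets with box_transform_inv. The parameter values are common SiamRPN settings (stride 8, base size 8, scale 8, five aspect ratios, a 19x19 score map) and are assumptions, not values read from this repo's config.py:

import numpy as np
from util import util

anchors = util.generate_anchors(total_stride=8,
                                base_size=8,
                                scales=np.array([8]),
                                ratios=np.array([0.33, 0.5, 1.0, 2.0, 3.0]),
                                score_size=19)
print(anchors.shape)                    # (1805, 4): 19*19 positions x 5 anchors, each (cx, cy, w, h)

offsets = np.zeros_like(anchors)        # pretend the network predicted zero offsets
boxes = util.box_transform_inv(anchors, offsets)
print(boxes[:2])                        # equals the anchors themselves for zero offsets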
/tracking/siamRPNBIG.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | from util import util
6 | import torch.nn.functional as F
7 | from config import TrackerConfig
8 | import torchvision.transforms as transforms
9 | from custom_transforms import ToTensor
10 | from config import config
11 | from torch.autograd import Variable
12 | from got10k.trackers import Tracker
13 | from network import SiameseAlexNet
14 | from data_loader import TrackerDataLoader
15 | from PIL import Image, ImageOps, ImageStat, ImageDraw
16 |
17 | class SiamRPN(nn.Module):
18 |
19 | def __init__(self, anchor_num = 5):
20 | super(SiamRPN, self).__init__()
21 |
22 | self.anchor_num = anchor_num
23 | self.feature = nn.Sequential(
24 | # conv1
25 | nn.Conv2d(3, 64, kernel_size = 11, stride = 2),
26 | nn.BatchNorm2d(64),
27 | nn.ReLU(inplace = True),
28 | nn.MaxPool2d(kernel_size = 3, stride = 2),
29 | # conv2
30 | nn.Conv2d(64, 192, kernel_size = 5),
31 | nn.BatchNorm2d(192),
32 | nn.ReLU(inplace=True),
33 | nn.MaxPool2d(kernel_size = 3, stride = 2),
34 | # conv3
35 | nn.Conv2d(192, 384, kernel_size = 3),
36 | nn.BatchNorm2d(384),
37 | nn.ReLU(inplace = True),
38 | # conv4
39 | nn.Conv2d(384, 256, kernel_size = 3),
40 | nn.BatchNorm2d(256),
41 | nn.ReLU(inplace = True),
42 | # conv5
43 | nn.Conv2d(256, 256, kernel_size = 3),
44 | nn.BatchNorm2d(256))
45 |
46 | self.conv_reg_z = nn.Conv2d(256, 256 * 4 * self.anchor_num, 3, 1)
47 | self.conv_reg_x = nn.Conv2d(256, 256, 3)
48 | self.conv_cls_z = nn.Conv2d(256, 256 * 2 * anchor_num, 3, 1)
49 | self.conv_cls_x = nn.Conv2d(256, 256, 3)
50 | self.adjust_reg = nn.Conv2d(4 * anchor_num, 4 * anchor_num*1, 1)
51 |
52 | def forward(self, z, x):
53 | return self.inference(x, *self.learn(z))
54 |
55 | def learn(self, z):
56 | z = self.feature(z)
57 | kernel_reg = self.conv_reg_z(z)
58 | kernel_cls = self.conv_cls_z(z)
59 |
60 | k = kernel_reg.size()[-1]
61 | kernel_reg = kernel_reg.view(4 * self.anchor_num, 256, k, k)
62 | kernel_cls = kernel_cls.view(2 * self.anchor_num, 256, k, k)
63 |
64 | return kernel_reg, kernel_cls
65 |
66 | def inference(self, x, kernel_reg, kernel_cls):
67 | x = self.feature(x)
68 | x_reg = self.conv_reg_x(x)
69 | x_cls = self.conv_cls_x(x)
70 |
71 | out_reg = self.adjust_reg(F.conv2d(x_reg, kernel_reg))
72 | out_cls = F.conv2d(x_cls, kernel_cls)
73 |
74 | return out_reg, out_cls
75 |
76 | class TrackerSiamRPNBIG(Tracker):
77 | def __init__(self, params, model_path = None, **kargs):
78 | super(TrackerSiamRPNBIG, self).__init__(name='SiamRPN', is_deterministic=True)
79 |
80 | self.model = SiameseAlexNet()
81 |
82 | self.cuda = torch.cuda.is_available()
83 | self.device = torch.device('cuda:0' if self.cuda else 'cpu')
84 |
85 |         checkpoint = torch.load(model_path, map_location = self.device)
86 |
87 |         if 'model' in checkpoint.keys():
88 |             self.model.load_state_dict(checkpoint['model'])
89 |         else:
90 |             self.model.load_state_dict(checkpoint)
91 |
92 |
93 | if self.cuda:
94 | self.model = self.model.cuda()
95 | self.model.eval()
96 | self.transforms = transforms.Compose([
97 | ToTensor()
98 | ])
99 |
100 | valid_scope = 2 * config.valid_scope + 1
101 | self.anchors = util.generate_anchors( config.total_stride,
102 | config.anchor_base_size,
103 | config.anchor_scales,
104 | config.anchor_ratios,
105 | valid_scope)
106 | self.window = np.tile(np.outer(np.hanning(config.score_size), np.hanning(config.score_size))[None, :],
107 | [config.anchor_num, 1, 1]).flatten()
108 |
109 | self.data_loader = TrackerDataLoader()
110 |
111 | def _cosine_window(self, size):
112 | """
113 | get the cosine window
114 | """
115 | cos_window = np.hanning(int(size[0]))[:, np.newaxis].dot(np.hanning(int(size[1]))[np.newaxis, :])
116 | cos_window = cos_window.astype(np.float32)
117 | cos_window /= np.sum(cos_window)
118 | return cos_window
119 |
120 | def init(self, frame, bbox):
121 |
122 |         """ initialize the SiamRPN tracker
123 | Args:
124 | frame: an RGB image
125 | bbox: one-based bounding box [x, y, width, height]
126 | """
127 | frame = np.asarray(frame)
128 | '''bbox[0] = bbox[0] + bbox[2]/2
129 | bbox[1] = bbox[1] + bbox[3]/2'''
130 |
131 | self.pos = np.array([bbox[0] + bbox[2] / 2 - 1 / 2, bbox[1] + bbox[3] / 2 - 1 / 2]) # center x, center y, zero based
132 | #self.pos = np.array([bbox[0], bbox[1]]) # center x, center y, zero based
133 |
134 | self.target_sz = np.array([bbox[2], bbox[3]]) # width, height
135 | self.bbox = np.array([bbox[0] + bbox[2] / 2 - 1 / 2, bbox[1] + bbox[3] / 2 - 1 / 2, bbox[2], bbox[3]])
136 | #self.bbox = np.array([bbox[0], bbox[1], bbox[2], bbox[3]])
137 |
138 | self.origin_target_sz = np.array([bbox[2], bbox[3]])
139 | # get exemplar img
140 | self.img_mean = np.mean(frame, axis=(0, 1))
141 |
142 | exemplar_img, _, _ = self.data_loader.get_exemplar_image( frame,
143 | self.bbox,
144 | config.template_img_size,
145 | config.context_amount,
146 | self.img_mean)
147 |
148 | #cv2.imshow('exemplar_img', exemplar_img)
149 | # get exemplar feature
150 | exemplar_img = self.transforms(exemplar_img)[None, :, :, :]
151 | if self.cuda:
152 | self.model.track_init(exemplar_img.cuda())
153 | else:
154 | self.model.track_init(exemplar_img)
155 |
156 | def update(self, frame):
157 | """track object based on the previous frame
158 | Args:
159 | frame: an RGB image
160 |
161 | Returns:
162 |             bbox: 1-based bounding box as np.array([xmin, ymin, width, height])
163 | """
164 | frame = np.asarray(frame)
165 |
166 | instance_img, _, _, scale_x = self.data_loader.get_instance_image( frame,
167 | self.bbox,
168 | config.template_img_size,
169 | config.detection_img_size,
170 | config.context_amount,
171 | self.img_mean)
172 | #cv2.imshow('instance_img', instance_img)
173 |
174 | instance_img = self.transforms(instance_img)[None, :, :, :]
175 | if self.cuda:
176 | pred_score, pred_regression = self.model.track(instance_img.cuda())
177 | else:
178 | pred_score, pred_regression = self.model.track(instance_img)
179 |
180 | pred_conf = pred_score.reshape(-1, 2, config.size ).permute(0, 2, 1)
181 | pred_offset = pred_regression.reshape(-1, 4, config.size ).permute(0, 2, 1)
182 |
183 | delta = pred_offset[0].cpu().detach().numpy()
184 | box_pred = util.box_transform_inv(self.anchors, delta)
185 | score_pred = F.softmax(pred_conf, dim=2)[0, :, 1].cpu().detach().numpy()
186 |
187 | s_c = util.change(util.sz(box_pred[:, 2], box_pred[:, 3]) / (util.sz_wh(self.target_sz * scale_x))) # scale penalty
188 | r_c = util.change((self.target_sz[0] / self.target_sz[1]) / (box_pred[:, 2] / box_pred[:, 3])) # ratio penalty
189 | penalty = np.exp(-(r_c * s_c - 1.) * config.penalty_k)
190 | pscore = penalty * score_pred
191 | pscore = pscore * (1 - config.window_influence) + self.window * config.window_influence
192 | best_pscore_id = np.argmax(pscore)
193 | target = box_pred[best_pscore_id, :] / scale_x
194 |
195 | lr = penalty[best_pscore_id] * score_pred[best_pscore_id] * config.lr_box
196 |
197 | res_x = np.clip(target[0] + self.pos[0], 0, frame.shape[1])
198 | res_y = np.clip(target[1] + self.pos[1], 0, frame.shape[0])
199 |
200 | res_w = np.clip(self.target_sz[0] * (1 - lr) + target[2] * lr, config.min_scale * self.origin_target_sz[0],
201 | config.max_scale * self.origin_target_sz[0])
202 | res_h = np.clip(self.target_sz[1] * (1 - lr) + target[3] * lr, config.min_scale * self.origin_target_sz[1],
203 | config.max_scale * self.origin_target_sz[1])
204 |
205 | self.pos = np.array([res_x, res_y])
206 | self.target_sz = np.array([res_w, res_h])
207 |
208 | bbox = np.array([res_x, res_y, res_w, res_h])
209 | #print('bbox', bbox)
210 | self.bbox = (
211 | np.clip(bbox[0], 0, frame.shape[1]).astype(np.float64),
212 | np.clip(bbox[1], 0, frame.shape[0]).astype(np.float64),
213 | np.clip(bbox[2], 10, frame.shape[1]).astype(np.float64),
214 | np.clip(bbox[3], 10, frame.shape[0]).astype(np.float64))
215 |
216 | res_x = res_x - res_w/2 # x -> x1
217 | res_y = res_y - res_h/2 # y -> y1
218 | bbox = np.array([res_x, res_y, res_w, res_h])
219 | return bbox
220 |
--------------------------------------------------------------------------------
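Usage sketch (not part of the repository): a minimal tracking loop showing how TrackerSiamRPNBIG is meant to be driven; it assumes it is run from the tracking/ directory, and the checkpoint path and image sequence are placeholders you must supply. In the repository itself, run_tracking.py presumably drives the same class through the got10k toolkit:

import cv2
from glob import glob
from siamRPNBIG import TrackerSiamRPNBIG

# placeholder checkpoint path (e.g. one produced by train/train_siamrpn.py)
tracker = TrackerSiamRPNBIG(params=None, model_path='experiments/default/model/model_e1.pth')

frames = sorted(glob('sequence/*.jpg'))                       # placeholder frame paths
first = cv2.cvtColor(cv2.imread(frames[0]), cv2.COLOR_BGR2RGB)
tracker.init(first, [100, 100, 50, 80])                       # [x, y, w, h] of the target in frame 0

for path in frames[1:]:
    frame = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    x, y, w, h = tracker.update(frame)
    print(path, int(x), int(y), int(w), int(h))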
/train/data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | import cv2
5 | import time
6 | import torch
7 | import random
8 | import numpy as np
9 | import os.path as osp
10 | from util import util
11 | from PIL import Image
12 | from config import config
13 | from torch.utils.data import Dataset
14 | from got10k.datasets import ImageNetVID, GOT10k
15 | from torchvision import datasets, transforms, utils
16 |
17 | from custom_transforms import Normalize, ToTensor, RandomStretch, \
18 | RandomCrop, CenterCrop, RandomBlur, ColorAug
19 |
20 | class TrainDataLoader(Dataset):
21 | def __init__(self, seq_dataset, z_transforms, x_transforms, name = 'GOT-10k'):
22 |
23 | self.max_inter = config.max_inter
24 | self.z_transforms = z_transforms
25 | self.x_transforms = x_transforms
26 | self.sub_class_dir = seq_dataset
27 | self.ret = {}
28 | self.count = 0
29 | self.index = 3000
30 | self.name = name
31 | self.anchors = util.generate_anchors( config.anchor_total_stride,
32 | config.anchor_base_size,
33 | config.anchor_scales,
34 | config.anchor_ratios,
35 | config.score_size)
36 |
37 |
38 | def _pick_img_pairs(self, index_of_subclass):
39 |
40 |         assert index_of_subclass < len(self.sub_class_dir), 'index_of_subclass should be less than the total number of classes'
41 |
42 | video_name = self.sub_class_dir[index_of_subclass][0]
43 |
44 | video_num = len(video_name)
45 | video_gt = self.sub_class_dir[index_of_subclass][1]
46 |
47 | status = True
48 | while status:
49 | if self.max_inter >= video_num-1:
50 | self.max_inter = video_num//2
51 |
52 | template_index = np.clip(random.choice(range(0, max(1, video_num - self.max_inter))), 0, video_num-1)
53 |
54 | detection_index= np.clip(random.choice(range(1, max(2, self.max_inter))) + template_index, 0, video_num-1)
55 |
56 | template_img_path, detection_img_path = video_name[template_index], video_name[detection_index]
57 |
58 | template_gt = video_gt[template_index]
59 |
60 | detection_gt = video_gt[detection_index]
61 |
62 | if template_gt[2]*template_gt[3]*detection_gt[2]*detection_gt[3] != 0:
63 | status = False
64 | else:
65 | #print('Warning : Encounter object missing, reinitializing ...')
66 | print( 'index_of_subclass:', index_of_subclass, '\n',
67 | 'template_index:', template_index, '\n',
68 | 'template_gt:', template_gt, '\n',
69 | 'detection_index:', detection_index, '\n',
70 | 'detection_gt:', detection_gt, '\n')
71 |
72 |
73 |         # load information of template and detection
74 | self.ret['template_img_path'] = template_img_path
75 | self.ret['detection_img_path'] = detection_img_path
76 | self.ret['template_target_x1y1wh'] = template_gt
77 | self.ret['detection_target_x1y1wh']= detection_gt
78 | t1, t2 = self.ret['template_target_x1y1wh'].copy(), self.ret['detection_target_x1y1wh'].copy()
79 | self.ret['template_target_xywh'] = np.array([t1[0]+t1[2]//2, t1[1]+t1[3]//2, t1[2], t1[3]], np.float32)
80 | self.ret['detection_target_xywh'] = np.array([t2[0]+t2[2]//2, t2[1]+t2[3]//2, t2[2], t2[3]], np.float32)
81 | self.ret['anchors'] = self.anchors
82 | #self._average()
83 |
84 | def open(self):
85 |
86 | '''template'''
87 |         #template_img = cv2.imread(self.ret['template_img_path'])  # cv2.imread cannot open some .JPEG files, so PIL is used instead
88 | template_img = Image.open(self.ret['template_img_path'])
89 | template_img = np.array(template_img)
90 |
91 | detection_img = Image.open(self.ret['detection_img_path'])
92 | detection_img = np.array(detection_img)
93 |
94 | if np.random.rand(1) < config.gray_ratio:
95 |
96 | template_img = cv2.cvtColor(template_img, cv2.COLOR_RGB2GRAY)
97 | template_img = cv2.cvtColor(template_img, cv2.COLOR_GRAY2RGB)
98 | detection_img = cv2.cvtColor(detection_img, cv2.COLOR_RGB2GRAY)
99 | detection_img = cv2.cvtColor(detection_img, cv2.COLOR_GRAY2RGB)
100 |
101 | img_mean = np.mean(template_img, axis=(0, 1))
102 | #img_mean = tuple(map(int, template_img.mean(axis=(0, 1))))
103 |
104 | exemplar_img, scale_z, s_z, w_x, h_x = self.get_exemplar_image( template_img,
105 | self.ret['template_target_xywh'],
106 | config.template_img_size,
107 | config.context, img_mean )
108 |
109 | size_x = config.template_img_size
110 | x1, y1 = int((size_x + 1) / 2 - w_x / 2), int((size_x + 1) / 2 - h_x / 2)
111 | x2, y2 = int((size_x + 1) / 2 + w_x / 2), int((size_x + 1) / 2 + h_x / 2)
112 | #frame = cv2.rectangle(exemplar_img, (x1,y1), (x2,y2), (0, 255, 0), 1)
113 | #cv2.imwrite('exemplar_img.png',frame)
114 | #cv2.waitKey(0)
115 |
116 | self.ret['exemplar_img'] = exemplar_img
117 |
118 | '''detection'''
119 | #detection_img = cv2.imread(self.ret['detection_img_path'])
120 | d = self.ret['detection_target_xywh']
121 | cx, cy, w, h = d # float type
122 |
123 | wc_z = w + 0.5 * (w + h)
124 | hc_z = h + 0.5 * (w + h)
125 | s_z = np.sqrt(wc_z * hc_z)
126 |
127 | s_x = s_z / (config.detection_img_size//2)
128 | img_mean_d = tuple(map(int, detection_img.mean(axis=(0, 1))))
129 |
130 | a_x_ = np.random.choice(range(-12,12))
131 | a_x = a_x_ * s_x
132 |
133 | b_y_ = np.random.choice(range(-12,12))
134 | b_y = b_y_ * s_x
135 |
136 | instance_img, a_x, b_y, w_x, h_x, scale_x = self.get_instance_image( detection_img, d,
137 | config.template_img_size, # 127
138 | config.detection_img_size,# 255
139 | config.context, # 0.5
140 | a_x, b_y,
141 | img_mean_d )
142 |
143 | size_x = config.detection_img_size
144 |
145 | x1, y1 = int((size_x + 1) / 2 - w_x / 2), int((size_x + 1) / 2 - h_x / 2)
146 | x2, y2 = int((size_x + 1) / 2 + w_x / 2), int((size_x + 1) / 2 + h_x / 2)
147 |
148 | #frame_d = cv2.rectangle(instance_img, (int(x1+(a_x*scale_x)),int(y1+(b_y*scale_x))), (int(x2+(a_x*scale_x)),int(y2+(b_y*scale_x))), (0, 255, 0), 1)
149 | #cv2.imwrite('detection_img_ori.png',frame_d)
150 |
151 | w = x2 - x1
152 | h = y2 - y1
153 | cx = x1 + w/2
154 | cy = y1 + h/2
155 |
156 | #print('[a_x_, b_y_, w, h]', [int(a_x_), int(b_y_), w, h])
157 |
158 | self.ret['instance_img'] = instance_img
159 | #self.ret['cx, cy, w, h'] = [int(a_x_*0.16), int(b_y_*0.16), w, h]
160 | self.ret['cx, cy, w, h'] = [int(a_x_), int(b_y_), w, h]
161 |
162 | def get_exemplar_image(self, img, bbox, size_z, context_amount, img_mean=None):
163 | cx, cy, w, h = bbox
164 |
165 | wc_z = w + context_amount * (w + h)
166 | hc_z = h + context_amount * (w + h)
167 | s_z = np.sqrt(wc_z * hc_z)
168 | scale_z = size_z / s_z
169 |
170 | exemplar_img, scale_x = self.crop_and_pad_old(img, cx, cy, size_z, s_z, img_mean)
171 |
172 | w_x = w * scale_x
173 | h_x = h * scale_x
174 |
175 | return exemplar_img, scale_z, s_z, w_x, h_x
176 |
177 | def get_instance_image(self, img, bbox, size_z, size_x, context_amount, a_x, b_y, img_mean=None):
178 |
179 | cx, cy, w, h = bbox # float type
180 |
181 | #cx, cy = cx - a_x , cy - b_y
182 | wc_z = w + context_amount * (w + h)
183 | hc_z = h + context_amount * (w + h)
184 | s_z = np.sqrt(wc_z * hc_z) # the width of the crop box
185 |
186 | scale_z = size_z / s_z
187 |
188 | s_x = s_z * size_x / size_z
189 | instance_img, gt_w, gt_h, scale_x, scale_h, scale_w = self.crop_and_pad(img, cx, cy, w, h, a_x, b_y, size_x, s_x, img_mean)
190 | w_x = gt_w #* scale_x #w * scale_x
191 | h_x = gt_h #* scale_x #h * scale_x
192 |
193 | #cx, cy = cx/ scale_w *scale_x, cy/ scale_h *scale_x
194 | #cx, cy = cx/ scale_w, cy/ scale_h
195 | a_x, b_y = a_x*scale_w, b_y*scale_h
196 | x1, y1 = int((size_x + 1) / 2 - w_x / 2), int((size_x + 1) / 2 - h_x / 2)
197 | x2, y2 = int((size_x + 1) / 2 + w_x / 2), int((size_x + 1) / 2 + h_x / 2)
198 | '''frame = cv2.rectangle(instance_img, ( int(x1+(a_x*scale_x)),
199 | int(y1+(b_y*scale_x))),
200 | (int(x2+(a_x*scale_x)),
201 | int(y2+(b_y*scale_x))),
202 | (0, 255, 0), 1)'''
203 | #cv2.imwrite('1.jpg', frame)
204 | return instance_img, a_x, b_y, w_x, h_x, scale_x
205 |
206 | def crop_and_pad(self, img, cx, cy, gt_w, gt_h, a_x, b_y, model_sz, original_sz, img_mean=None):
207 |
208 | #random = np.random.uniform(-0.15, 0.15)
209 | scale_h = 1.0 + np.random.uniform(-0.15, 0.15)
210 | scale_w = 1.0 + np.random.uniform(-0.15, 0.15)
211 |
212 | im_h, im_w, _ = img.shape
213 |
214 | xmin = (cx-a_x) - ((original_sz - 1) / 2)* scale_w
215 | xmax = (cx-a_x) + ((original_sz - 1) / 2)* scale_w
216 |
217 | ymin = (cy-b_y) - ((original_sz - 1) / 2)* scale_h
218 | ymax = (cy-b_y) + ((original_sz - 1) / 2)* scale_h
219 |
220 | #print('xmin, xmax, ymin, ymax', xmin, xmax, ymin, ymax)
221 |
222 | left = int(self.round_up(max(0., -xmin)))
223 | top = int(self.round_up(max(0., -ymin)))
224 | right = int(self.round_up(max(0., xmax - im_w + 1)))
225 | bottom = int(self.round_up(max(0., ymax - im_h + 1)))
226 |
227 | xmin = int(self.round_up(xmin + left))
228 | xmax = int(self.round_up(xmax + left))
229 | ymin = int(self.round_up(ymin + top))
230 | ymax = int(self.round_up(ymax + top))
231 |
232 | r, c, k = img.shape
233 | if any([top, bottom, left, right]):
234 | te_im_ = np.zeros((int((r + top + bottom)), int((c + left + right)), k), np.uint8) # 0 is better than 1 initialization
235 | te_im = np.zeros((int((r + top + bottom)), int((c + left + right)), k), np.uint8) # 0 is better than 1 initialization
236 |
237 | #cv2.imwrite('te_im1.jpg', te_im)
238 | te_im[:, :, :] = img_mean
239 | #cv2.imwrite('te_im2_1.jpg', te_im)
240 | te_im[top:top + r, left:left + c, :] = img
241 | #cv2.imwrite('te_im2.jpg', te_im)
242 |
243 | if top:
244 | te_im[0:top, left:left + c, :] = img_mean
245 | if bottom:
246 | te_im[r + top:, left:left + c, :] = img_mean
247 | if left:
248 | te_im[:, 0:left, :] = img_mean
249 | if right:
250 | te_im[:, c + left:, :] = img_mean
251 |
252 | im_patch_original = te_im[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
253 |
254 | #cv2.imwrite('te_im3.jpg', im_patch_original)
255 |
256 | else:
257 | im_patch_original = img[int(ymin):int((ymax) + 1), int(xmin):int((xmax) + 1), :]
258 |
259 | #cv2.imwrite('te_im4.jpg', im_patch_original)
260 |
261 | if not np.array_equal(model_sz, original_sz):
262 |
263 | h, w, _ = im_patch_original.shape
264 |
265 |
266 | if h < w:
267 | scale_h_ = 1
268 | scale_w_ = h/w
269 | scale = config.detection_img_size/h
270 | elif h > w:
271 | scale_h_ = w/h
272 | scale_w_ = 1
273 | scale = config.detection_img_size/w
274 | elif h == w:
275 | scale_h_ = 1
276 | scale_w_ = 1
277 | scale = config.detection_img_size/w
278 |
279 | gt_w = gt_w * scale_w_
280 | gt_h = gt_h * scale_h_
281 |
282 | gt_w = gt_w * scale
283 | gt_h = gt_h * scale
284 |
285 | #im_patch = cv2.resize(im_patch_original_, (shape)) # zzp: use cv to get a better speed
286 | #cv2.imwrite('te_im8.jpg', im_patch)
287 |
288 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv to get a better speed
289 | #cv2.imwrite('te_im9.jpg', im_patch)
290 |
291 |
292 | else:
293 | im_patch = im_patch_original
294 | #scale = model_sz / im_patch_original.shape[0]
295 | return im_patch, gt_w, gt_h, scale, scale_h_, scale_w_
296 |
297 |
298 |
299 |
300 | def crop_and_pad_old(self, img, cx, cy, model_sz, original_sz, img_mean=None):
301 | im_h, im_w, _ = img.shape
302 |
303 | xmin = cx - (original_sz - 1) / 2
304 | xmax = xmin + original_sz - 1
305 | ymin = cy - (original_sz - 1) / 2
306 | ymax = ymin + original_sz - 1
307 |
308 | left = int(self.round_up(max(0., -xmin)))
309 | top = int(self.round_up(max(0., -ymin)))
310 | right = int(self.round_up(max(0., xmax - im_w + 1)))
311 | bottom = int(self.round_up(max(0., ymax - im_h + 1)))
312 |
313 | xmin = int(self.round_up(xmin + left))
314 | xmax = int(self.round_up(xmax + left))
315 | ymin = int(self.round_up(ymin + top))
316 | ymax = int(self.round_up(ymax + top))
317 | r, c, k = img.shape
318 | if any([top, bottom, left, right]):
319 | te_im = np.zeros((r + top + bottom, c + left + right, k), np.uint8) # 0 is better than 1 initialization
320 | te_im[top:top + r, left:left + c, :] = img
321 | if top:
322 | te_im[0:top, left:left + c, :] = img_mean
323 | if bottom:
324 | te_im[r + top:, left:left + c, :] = img_mean
325 | if left:
326 | te_im[:, 0:left, :] = img_mean
327 | if right:
328 | te_im[:, c + left:, :] = img_mean
329 | im_patch_original = te_im[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
330 | else:
331 | im_patch_original = img[int(ymin):int(ymax + 1), int(xmin):int(xmax + 1), :]
332 | if not np.array_equal(model_sz, original_sz):
333 |
334 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv to get a better speed
335 | else:
336 | im_patch = im_patch_original
337 | scale = model_sz / im_patch_original.shape[0]
338 | return im_patch, scale
339 |
340 | def round_up(self, value):
341 | return round(value + 1e-6 + 1000) - 1000
342 |
343 | def _target(self):
344 |
345 | regression_target, conf_target = self.compute_target(self.anchors,
346 | np.array(list(map(round, self.ret['cx, cy, w, h']))))
347 |
348 | return regression_target, conf_target
349 |
350 | def compute_target(self, anchors, box):
351 | #box = [-(box[0]), -(box[1]), box[2], box[3]]
352 | regression_target = self.box_transform(anchors, box)
353 |
354 | iou = self.compute_iou(anchors, box).flatten()
355 | #print(np.max(iou))
356 | pos_index = np.where(iou > config.pos_threshold)[0]
357 | neg_index = np.where(iou < config.neg_threshold)[0]
358 | label = np.ones_like(iou) * -1
359 |
360 | label[pos_index] = 1
361 | label[neg_index] = 0
362 | '''print(len(neg_index))
363 | for i, neg_ind in enumerate(neg_index):
364 | if i % 40 == 0:
365 | label[neg_ind] = 0'''
366 |
367 |
368 |
369 | #max_index = np.argsort(iou.flatten())[-20:]
370 |
371 | return regression_target, label
372 |
373 | def box_transform(self, anchors, gt_box):
374 | anchor_xctr = anchors[:, :1]
375 | anchor_yctr = anchors[:, 1:2]
376 | anchor_w = anchors[:, 2:3]
377 | anchor_h = anchors[:, 3:]
378 | gt_cx, gt_cy, gt_w, gt_h = gt_box
379 |
380 | target_x = (gt_cx - anchor_xctr) / anchor_w
381 | target_y = (gt_cy - anchor_yctr) / anchor_h
382 | target_w = np.log(gt_w / anchor_w)
383 | target_h = np.log(gt_h / anchor_h)
384 | regression_target = np.hstack((target_x, target_y, target_w, target_h))
385 | return regression_target
386 |
387 | def compute_iou(self, anchors, box):
388 | if np.array(anchors).ndim == 1:
389 | anchors = np.array(anchors)[None, :]
390 | else:
391 | anchors = np.array(anchors)
392 | if np.array(box).ndim == 1:
393 | box = np.array(box)[None, :]
394 | else:
395 | box = np.array(box)
396 | gt_box = np.tile(box.reshape(1, -1), (anchors.shape[0], 1))
397 |
398 | anchor_x1 = anchors[:, :1] - anchors[:, 2:3] / 2 + 0.5
399 | anchor_x2 = anchors[:, :1] + anchors[:, 2:3] / 2 - 0.5
400 | anchor_y1 = anchors[:, 1:2] - anchors[:, 3:] / 2 + 0.5
401 | anchor_y2 = anchors[:, 1:2] + anchors[:, 3:] / 2 - 0.5
402 |
403 | gt_x1 = gt_box[:, :1] - gt_box[:, 2:3] / 2 + 0.5
404 | gt_x2 = gt_box[:, :1] + gt_box[:, 2:3] / 2 - 0.5
405 | gt_y1 = gt_box[:, 1:2] - gt_box[:, 3:] / 2 + 0.5
406 | gt_y2 = gt_box[:, 1:2] + gt_box[:, 3:] / 2 - 0.5
407 |
408 | xx1 = np.max([anchor_x1, gt_x1], axis=0)
409 | xx2 = np.min([anchor_x2, gt_x2], axis=0)
410 | yy1 = np.max([anchor_y1, gt_y1], axis=0)
411 | yy2 = np.min([anchor_y2, gt_y2], axis=0)
412 |
413 | inter_area = np.max([xx2 - xx1, np.zeros(xx1.shape)], axis=0) * np.max([yy2 - yy1, np.zeros(xx1.shape)],
414 | axis=0)
415 | area_anchor = (anchor_x2 - anchor_x1) * (anchor_y2 - anchor_y1)
416 | area_gt = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
417 | iou = inter_area / (area_anchor + area_gt - inter_area + 1e-6)
418 | return iou
419 |
420 |     def _transform(self):
421 |
422 | self.ret['train_x_transforms'] = self.x_transforms(self.ret['instance_img'])
423 | self.ret['train_z_transforms'] = self.z_transforms(self.ret['exemplar_img'])
424 |
425 | def __getitem__(self, index):
426 | index = random.choice(range(len(self.sub_class_dir)))
427 | '''if len(self.sub_class_dir) > 180:
428 | index = self.index
429 | self.index += 1
430 |
431 | if self.index >= 8000:
432 | self.index = 3000
433 |
434 | index = random.choice(range(3000, 8000))
435 |
436 | if index in self.index:
437 | index = random.choice(range(3000, 8000))
438 | print("index in self.index")
439 |
440 | if not index in self.index:
441 | self.index.append(index)
442 | if len(self.index) >= 3000:
443 | self.index = []
444 | else:
445 | index = random.choice(range(len(self.sub_class_dir)))'''
446 |
447 | if self.name == 'GOT-10k':
448 | if index == 4418 or index == 8627 or index == 8629 or index == 9057 or index == 9058:
449 | index += 3
450 | self._pick_img_pairs(index)
451 | self.open()
452 |         self._transform()
453 | regression_target, conf_target = self._target()
454 | self.count += 1
455 |
456 | return self.ret['train_z_transforms'], self.ret['train_x_transforms'], regression_target, conf_target.astype(np.int64)
457 |
458 | def __len__(self):
459 | return config.train_epoch_size*64
460 |
461 | if __name__ == "__main__":
462 |
463 | root_dir = '/Users/arbi/Desktop'
464 | seq_dataset = GOT10k(root_dir, subset='val')
465 |     train_data = TrainDataLoader(seq_dataset, transforms.Compose([ToTensor()]), transforms.Compose([ToTensor()]))
466 | train_data.__getitem__(180)
467 |
--------------------------------------------------------------------------------
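Worked example (not part of the repository): the regression targets produced by box_transform above, and decoded by box_transform_inv in train/util.py, follow the standard RPN parametrisation. A tiny numpy round trip with made-up anchor and ground-truth values:

import numpy as np

anchor = np.array([[0.0, 0.0, 64.0, 64.0]])        # (cx, cy, w, h)
gt_cx, gt_cy, gt_w, gt_h = 8.0, -4.0, 80.0, 40.0

dx = (gt_cx - anchor[:, 0]) / anchor[:, 2]         # 0.125
dy = (gt_cy - anchor[:, 1]) / anchor[:, 3]         # -0.0625
dw = np.log(gt_w / anchor[:, 2])                   # log(1.25)
dh = np.log(gt_h / anchor[:, 3])                   # log(0.625)

# decoding recovers the ground-truth box
cx = anchor[:, 2] * dx + anchor[:, 0]
cy = anchor[:, 3] * dy + anchor[:, 1]
w = anchor[:, 2] * np.exp(dw)
h = anchor[:, 3] * np.exp(dh)
print(cx, cy, w, h)                                # [8.] [-4.] [80.] [40.]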