├── .idea
│   ├── Object-detecting-and-tracking.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   └── vcs.xml
├── DaSiamRPN
│   ├── README.md
│   ├── demo.py
│   ├── eval_otb.py
│   ├── net.py
│   ├── run_SiamRPN.py
│   ├── test_otb.py
│   ├── utils.py
│   ├── vot.py
│   └── vot_SiamRPN.py
├── README.md
├── SSD+SiamRPN.py
├── SSD
│   ├── README.md
│   ├── UseSSD.py
│   ├── caffe_to_tensorflow.py
│   ├── code_test.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── dataset_factory.py
│   │   ├── dataset_utils.py
│   │   ├── imagenet.py
│   │   ├── pascalvoc_2007.py
│   │   ├── pascalvoc_2012.py
│   │   ├── pascalvoc_common.py
│   │   └── pascalvoc_to_tfrecords.py
│   ├── eval_ssd_network.py
│   ├── inspect_checkpoint.py
│   ├── nets
│   │   ├── __init__.py
│   │   ├── caffe_scope.py
│   │   ├── custom_layers.py
│   │   ├── inception.py
│   │   ├── inception_resnet_v2.py
│   │   ├── inception_v3.py
│   │   ├── nets_factory.py
│   │   ├── np_methods.py
│   │   ├── readme
│   │   ├── ssd_common.py
│   │   ├── ssd_vgg_300.py
│   │   ├── ssd_vgg_512.py
│   │   ├── vgg.py
│   │   └── xception.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── inception_preprocessing.py
│   │   ├── preprocessing_factory.py
│   │   ├── readme
│   │   ├── ssd_vgg_preprocessing.py
│   │   ├── tf_image.py
│   │   └── vgg_preprocessing.py
│   ├── readme
│   ├── tf_convert_data.py
│   ├── tf_extended
│   │   ├── __init__.py
│   │   ├── bboxes.py
│   │   ├── image.py
│   │   ├── math.py
│   │   ├── metrics.py
│   │   └── tensors.py
│   ├── tf_utils.py
│   └── train_ssd_network.py
└── spaceshooter
    ├── assets
    │   ├── bolt_gold.png
    │   ├── laserRed16.png
    │   ├── main.png
    │   ├── meteorBrown_big1.png
    │   ├── meteorBrown_big2.png
    │   ├── meteorBrown_med1.png
    │   ├── meteorBrown_med3.png
    │   ├── meteorBrown_small1.png
    │   ├── meteorBrown_small2.png
    │   ├── meteorBrown_tiny1.png
    │   ├── missile.png
    │   ├── playerShip1_orange.png
    │   ├── regularExplosion00.png
    │   ├── regularExplosion01.png
    │   ├── regularExplosion02.png
    │   ├── regularExplosion03.png
    │   ├── regularExplosion04.png
    │   ├── regularExplosion05.png
    │   ├── regularExplosion06.png
    │   ├── regularExplosion07.png
    │   ├── regularExplosion08.png
    │   ├── shield_gold.png
    │   ├── sonicExplosion00.png
    │   ├── sonicExplosion01.png
    │   ├── sonicExplosion02.png
    │   ├── sonicExplosion03.png
    │   ├── sonicExplosion04.png
    │   ├── sonicExplosion05.png
    │   ├── sonicExplosion06.png
    │   ├── sonicExplosion07.png
    │   ├── sonicExplosion08.png
    │   └── starfield.png
    ├── detection_and_tracking.py
    ├── detection_tracking_game.py
    ├── readme
    ├── simulate_game.py
    ├── sounds
    │   ├── expl3.wav
    │   ├── expl6.wav
    │   ├── getready.ogg
    │   ├── menu.ogg
    │   ├── pew.wav
    │   ├── rocket.ogg
    │   ├── rumble1.ogg
    │   └── tgfcoder-FrozenJam-SeamlessLoop.ogg
    └── spaceShooter.py
/DaSiamRPN/README.md:
--------------------------------------------------------------------------------
1 | :trophy:News: **We won the VOT-18 real-time challenge**
2 |
3 | :trophy:News: **We won the second place in the VOT-18 long-term challenge**
4 |
5 | # DaSiamRPN
6 |
7 | This repository includes PyTorch code for reproducing the results on VOT2018.
8 |
9 | [**Distractor-aware Siamese Networks for Visual Object Tracking**](https://arxiv.org/pdf/1808.06048.pdf)
10 |
11 | Zheng Zhu\*, Qiang Wang\*, Bo Li\*, Wei Wu, Junjie Yan, and Weiming Hu
12 |
13 | *European Conference on Computer Vision (ECCV), 2018*
14 |
15 |
16 |
17 | ## Introduction
18 |
19 | **SiamRPN** formulates visual tracking as a joint localization and identification task, initially described in a [CVPR2018 spotlight paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf). (Slides at [CVPR 2018 Spotlight](https://drive.google.com/open?id=1OGIOUqANvYfZjRoQfpiDqhPQtOvPCpdq))
20 |
21 | **DaSiamRPN** improves the performance of SiamRPN by (1) introducing an effective sampling strategy to control the imbalanced sample distribution, (2) designing a novel distractor-aware module to perform incremental learning, and (3) adding a long-term tracking extension. [ECCV2018](https://arxiv.org/pdf/1808.06048.pdf). (Slides at [VOT-18 Real-time challenge winners talk](https://drive.google.com/open?id=1dsEI2uYHDfELK0CW2xgv7R4QdCs6lwfr))
22 |
23 |
24 |

25 |
26 |
27 | ## Prerequisites
28 |
29 | CPU: Intel(R) Core(TM) i7-7700 CPU @ 3.60GHz
30 | GPU: NVIDIA GTX1060
31 |
32 | - python2.7
33 | - pytorch == 0.3.1
34 | - numpy
35 | - opencv
36 |
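
A quick, optional sanity check that the environment matches the versions listed above (a minimal sketch; it only prints the installed versions):

```python
# optional environment check -- the expected versions are taken from the list above
import torch, cv2, numpy as np

print(torch.__version__)  # expected: 0.3.1
print(cv2.__version__)
print(np.__version__)
```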
37 |
38 | ## Pretrained model for SiamRPN
39 |
40 | In our tracker, we use an AlexNet variant as our backbone, which is end-to-end trained for visual tracking.
41 | The pretrained model can be downloaded from google drive: [SiamRPNBIG.model](https://drive.google.com/file/d/1-vNVZxfbIplXHrqMHiJJYWXYWsOIvGsf/view?usp=sharing).
42 | Then, you should copy the pretrained model file `SiamRPNBIG.model` to the subfolder './code', so that the tracker can find and load the pretrained model.
43 |
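
For reference, loading the weights in PyTorch looks roughly like the snippet below (a sketch based on `demo.py` and `net.py` from this repository; the exact path is an assumption and should point to wherever you copied the model file):

```python
import torch
from os.path import realpath, dirname, join

from net import SiamRPNBIG  # defined in net.py

# load the pretrained weights; adjust the path if the model lives elsewhere (assumed: ./code)
net = SiamRPNBIG()
net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'code', 'SiamRPNBIG.model')))
net.eval()  # append .cuda() when a GPU is available
```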
44 |
45 | ## Detailed steps to install the prerequisites
46 |
47 | - install pytorch, numpy, opencv following the instructions in the `run_install.sh`. Please do **not** use conda to install.
48 | - you can alternatively modify `/PATH/TO/CODE/FOLDER/` in `tracker_SiamRPN.m`
49 | If the tracker is ready, you will see the tracking results. (EAO: 0.3827)
50 |
51 |
52 | ## Results
53 | All results can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1HJOvl_irX3KFbtfj88_FVLtukMI1GTCR?usp=sharing).
54 |
55 | |  | VOT2015<br>A / R / EAO | VOT2016<br>A / R / EAO | VOT2017 & VOT2018<br>A / R / EAO | OTB2015<br>OP / DP | UAV123<br>AUC / DP | UAV20L<br>AUC / DP |
56 | | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
57 | | **SiamRPN** CVPR2018 | 0.58 / 1.13 / 0.349 | 0.56 / 0.26 / 0.344 | 0.49 / 0.46 / 0.244 | 81.9 / 85.0 | 0.527 / 0.748 | 0.454 / 0.617 |
58 | | **DaSiamRPN** ECCV2018 | **0.63** / **0.66** / **0.446** | **0.61** / **0.22** / **0.411** | 0.56 / 0.34 / 0.326 | **86.5** / **88.0** | **0.586** / **0.796** | **0.617** / **0.838** |
59 | | **DaSiamRPN** VOT2018 | - | - | **0.59** / **0.28** / **0.383** | - | - | - |
60 |
61 |
62 | # Demo and Test on OTB2015
63 |
64 |

65 |
66 |
67 | - To reproduce the results in the paper, the pretrained model can be downloaded from [Google Drive](https://drive.google.com/open?id=1BtIkp5pB6aqePQGlMb2_Z7bfPy6XEj6H): `SiamRPNOTB.model`.
68 | :zap: :zap: This model is the **fastest** (~200fps) Siamese Tracker with AUC of 0.655 on OTB2015. :zap: :zap:
69 |
70 | - You must download OTB2015 dataset (download [script](code/data/get_otb_data.sh)) at first.
71 |
72 | A simple test example.
73 |
74 | ```
75 | cd code
76 | python demo.py
77 | ```
78 |
79 | If you want to test the performance on OTB2015, please use the following commands.
80 |
81 | ```
82 | cd code
83 | python test_otb.py
84 | python eval_otb.py OTB2015 "Siam*" 0 1
85 | ```
86 |
87 |
88 | # License
89 | Licensed under an MIT license.
90 |
91 |
92 | ## Citing DaSiamRPN
93 |
94 | If you find **DaSiamRPN** and **SiamRPN** useful in your research, please consider citing:
95 |
96 | ```
97 | @inproceedings{Zhu_2018_ECCV,
98 | title={Distractor-aware Siamese Networks for Visual Object Tracking},
99 | author={Zhu, Zheng and Wang, Qiang and Bo, Li and Wu, Wei and Yan, Junjie and Hu, Weiming},
100 | booktitle={European Conference on Computer Vision},
101 | year={2018}
102 | }
103 |
104 | @InProceedings{Li_2018_CVPR,
105 | title = {High Performance Visual Tracking With Siamese Region Proposal Network},
106 | author = {Li, Bo and Yan, Junjie and Wu, Wei and Zhu, Zheng and Hu, Xiaolin},
107 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
108 | year = {2018}
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/DaSiamRPN/demo.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # DaSiamRPN
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | # !/usr/bin/python
7 |
8 | import glob, cv2, torch
9 | import numpy as np
10 | from os.path import realpath, dirname, join
11 |
12 | from net import SiamRPNvot
13 | from run_SiamRPN import SiamRPN_init, SiamRPN_track
14 | from utils import get_axis_aligned_bbox, cxy_wh_2_rect
15 |
16 | # load net
17 | net = SiamRPNvot()
18 | net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model')))
19 |
20 | # net.eval().cuda()
21 | net.eval()
22 |
23 | # image and init box
24 | image_files = sorted(glob.glob('./bag/*.jpg'))
25 | init_rbox = [334.02, 128.36, 438.19, 188.78, 396.39, 260.83, 292.23, 200.41]
26 | [cx, cy, w, h] = get_axis_aligned_bbox(init_rbox)
27 |
28 | # tracker init
29 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
30 | im = cv2.imread(image_files[0]) # HxWxC
31 | state = SiamRPN_init(im, target_pos, target_sz, net)
32 |
33 | # tracking and visualization
34 | toc = 0
35 | for f, image_file in enumerate(image_files):
36 | im = cv2.imread(image_file)
37 | tic = cv2.getTickCount()
38 | state = SiamRPN_track(state, im) # track
39 | toc += cv2.getTickCount() - tic
40 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
41 | res = [int(l) for l in res]
42 | cv2.rectangle(im, (res[0], res[1]), (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3)
43 | cv2.imshow('SiamRPN', im)
44 | cv2.waitKey(1)
45 |
46 | print('Tracking Speed {:.1f}fps'.format((len(image_files) - 1) / (toc / cv2.getTickFrequency())))
47 |
--------------------------------------------------------------------------------
/DaSiamRPN/eval_otb.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import os
4 | import glob
5 | from os.path import join as fullfile
6 | import numpy as np
7 |
8 |
9 | def overlap_ratio(rect1, rect2):
10 | '''
11 | Compute overlap ratio between two rects
12 | - rect: 1d array of [x,y,w,h] or
13 | 2d array of N x [x,y,w,h]
14 | '''
15 |
16 | if rect1.ndim==1:
17 | rect1 = rect1[None,:]
18 | if rect2.ndim==1:
19 | rect2 = rect2[None,:]
20 |
21 | left = np.maximum(rect1[:,0], rect2[:,0])
22 | right = np.minimum(rect1[:,0]+rect1[:,2], rect2[:,0]+rect2[:,2])
23 | top = np.maximum(rect1[:,1], rect2[:,1])
24 | bottom = np.minimum(rect1[:,1]+rect1[:,3], rect2[:,1]+rect2[:,3])
25 |
26 | intersect = np.maximum(0,right - left) * np.maximum(0,bottom - top)
27 | union = rect1[:,2]*rect1[:,3] + rect2[:,2]*rect2[:,3] - intersect
28 | iou = np.clip(intersect / union, 0, 1)
29 | return iou
30 |
31 |
32 | def compute_success_overlap(gt_bb, result_bb):
33 | thresholds_overlap = np.arange(0, 1.05, 0.05)
34 | n_frame = len(gt_bb)
35 | success = np.zeros(len(thresholds_overlap))
36 | iou = overlap_ratio(gt_bb, result_bb)
37 | for i in range(len(thresholds_overlap)):
38 | success[i] = sum(iou > thresholds_overlap[i]) / float(n_frame)
39 | return success
40 |
41 |
42 | def compute_success_error(gt_center, result_center):
43 | thresholds_error = np.arange(0, 51, 1)
44 | n_frame = len(gt_center)
45 | success = np.zeros(len(thresholds_error))
46 | dist = np.sqrt(np.sum(np.power(gt_center - result_center, 2), axis=1))
47 | for i in range(len(thresholds_error)):
48 | success[i] = sum(dist <= thresholds_error[i]) / float(n_frame)
49 | return success
50 |
51 |
52 | def get_result_bb(arch, seq):
53 | result_path = fullfile(arch, seq + '.txt')
54 | temp = np.loadtxt(result_path, delimiter=',').astype(np.float)
55 | return np.array(temp)
56 |
57 |
58 | def convert_bb_to_center(bboxes):
59 | return np.array([(bboxes[:, 0] + (bboxes[:, 2] - 1) / 2),
60 | (bboxes[:, 1] + (bboxes[:, 3] - 1) / 2)]).T
61 |
62 |
63 | def eval_auc(dataset='OTB2015', tracker_reg='S*', start=0, end=1e6):
64 | list_path = os.path.join('data', dataset + '.json')
65 | annos = json.load(open(list_path, 'r'))
66 | seqs = list(annos.keys())
67 |
68 | OTB2013 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
69 | 'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
70 | 'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
71 | 'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
72 | 'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
73 | 'basketball', 'football', 'subway', 'tiger1', 'tiger2']
74 |
75 | trackers = glob.glob(fullfile('test', dataset, tracker_reg))
76 | trackers = trackers[start:min(end, len(trackers))]
77 |
78 | n_seq = len(seqs)
79 | thresholds_overlap = np.arange(0, 1.05, 0.05)
80 | # thresholds_error = np.arange(0, 51, 1)
81 |
82 | success_overlap = np.zeros((n_seq, len(trackers), len(thresholds_overlap)))
83 | # success_error = np.zeros((n_seq, len(trackers), len(thresholds_error)))
84 | for i in range(n_seq):
85 | seq = seqs[i]
86 | gt_rect = np.array(annos[seq]['gt_rect']).astype(np.float)
87 | gt_center = convert_bb_to_center(gt_rect)
88 | for j in range(len(trackers)):
89 | tracker = trackers[j]
90 | print('{:d} processing:{} tracker: {}'.format(i, seq, tracker))
91 | bb = get_result_bb(tracker, seq)
92 | center = convert_bb_to_center(bb)
93 | success_overlap[i][j] = compute_success_overlap(gt_rect, bb)
94 | # success_error[i][j] = compute_success_error(gt_center, center)
95 |
96 | print('Success Overlap')
97 |
98 | if 'OTB2015' == dataset:
99 | OTB2013_id = []
100 | for i in range(n_seq):
101 | if seqs[i] in OTB2013:
102 | OTB2013_id.append(i)
103 | max_auc_OTB2013 = 0.
104 | max_name_OTB2013 = ''
105 | for i in range(len(trackers)):
106 | auc = success_overlap[OTB2013_id, i, :].mean()
107 | if auc > max_auc_OTB2013:
108 | max_auc_OTB2013 = auc
109 | max_name_OTB2013 = trackers[i]
110 | print('%s(OTB2013 AUC:%.4f)' % (trackers[i], auc))
111 |
112 | max_auc = 0.
113 | max_name = ''
114 | for i in range(len(trackers)):
115 | auc = success_overlap[:, i, :].mean()
116 | if auc > max_auc:
117 | max_auc = auc
118 | max_name = trackers[i]
119 | print('%s(OTB2015 AUC:%.4f)' % (trackers[i], auc))
120 |
121 | print('\nOTB2013 Best: %s(%.4f)' % (max_name_OTB2013, max_auc_OTB2013))
122 | print('\nOTB2015 Best: %s(%.4f)' % (max_name, max_auc))
123 | else:
124 | max_auc = 0.
125 | max_name = ''
126 | for i in range(len(trackers)):
127 | auc = success_overlap[:, i, :].mean()
128 | if auc > max_auc:
129 | max_auc = auc
130 | max_name = trackers[i]
131 | print('%s(%.4f)' % (trackers[i], auc))
132 |
133 | print('\n%s Best: %s(%.4f)' % (dataset, max_name, max_auc))
134 |
135 |
136 | if __name__ == "__main__":
137 | if len(sys.argv) < 5:
138 | print('python eval_otb.py OTB2015 Siam* 0 10')
139 | exit()
140 | dataset = sys.argv[1]
141 | tracker_reg = sys.argv[2]
142 | start = int(float(sys.argv[3]))
143 | end = int(float(sys.argv[4]))
144 | eval_auc(dataset, tracker_reg, start, end)
145 |
--------------------------------------------------------------------------------
/DaSiamRPN/net.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # DaSiamRPN
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class SiamRPN(nn.Module):
11 | def __init__(self, size=2, feature_out=512, anchor=5):
12 | configs = [3, 96, 256, 384, 384, 256]
13 | configs = list(map(lambda x: 3 if x == 3 else x * size, configs))
14 | feat_in = configs[-1]
15 | super(SiamRPN, self).__init__()
16 | self.featureExtract = nn.Sequential(
17 | nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2),
18 | nn.BatchNorm2d(configs[1]),
19 | nn.MaxPool2d(kernel_size=3, stride=2),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(configs[1], configs[2], kernel_size=5),
22 | nn.BatchNorm2d(configs[2]),
23 | nn.MaxPool2d(kernel_size=3, stride=2),
24 | nn.ReLU(inplace=True),
25 | nn.Conv2d(configs[2], configs[3], kernel_size=3),
26 | nn.BatchNorm2d(configs[3]),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(configs[3], configs[4], kernel_size=3),
29 | nn.BatchNorm2d(configs[4]),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(configs[4], configs[5], kernel_size=3),
32 | nn.BatchNorm2d(configs[5]),
33 | )
34 |
35 | self.anchor = anchor
36 | self.feature_out = feature_out
37 |
38 | self.conv_r1 = nn.Conv2d(feat_in, feature_out * 4 * anchor, 3)
39 | self.conv_r2 = nn.Conv2d(feat_in, feature_out, 3)
40 | self.conv_cls1 = nn.Conv2d(feat_in, feature_out * 2 * anchor, 3)
41 | self.conv_cls2 = nn.Conv2d(feat_in, feature_out, 3)
42 | self.regress_adjust = nn.Conv2d(4 * anchor, 4 * anchor, 1)
43 |
44 | self.r1_kernel = []
45 | self.cls1_kernel = []
46 |
47 | self.cfg = {}
48 |
49 | def forward(self, x):
50 | x_f = self.featureExtract(x)
51 | return self.regress_adjust(F.conv2d(self.conv_r2(x_f), self.r1_kernel)), \
52 | F.conv2d(self.conv_cls2(x_f), self.cls1_kernel)
53 |
54 | def temple(self, z):
55 | z_f = self.featureExtract(z)
56 | r1_kernel_raw = self.conv_r1(z_f)
57 | cls1_kernel_raw = self.conv_cls1(z_f)
58 | kernel_size = r1_kernel_raw.data.size()[-1]
59 | self.r1_kernel = r1_kernel_raw.view(self.anchor * 4, self.feature_out, kernel_size, kernel_size)
60 | self.cls1_kernel = cls1_kernel_raw.view(self.anchor * 2, self.feature_out, kernel_size, kernel_size)
61 |
62 |
63 | class SiamRPNBIG(SiamRPN):
64 | def __init__(self):
65 | super(SiamRPNBIG, self).__init__(size=2)
66 | self.cfg = {'lr': 0.295, 'window_influence': 0.42, 'penalty_k': 0.055, 'instance_size': 271,
67 | 'adaptive': True} # 0.383
68 |
69 |
70 | class SiamRPNvot(SiamRPN):
71 | def __init__(self):
72 | super(SiamRPNvot, self).__init__(size=1, feature_out=256)
73 | self.cfg = {'lr': 0.45, 'window_influence': 0.44, 'penalty_k': 0.04, 'instance_size': 271,
74 | 'adaptive': False} # 0.355
75 |
76 |
77 | class SiamRPNotb(SiamRPN):
78 | def __init__(self):
79 | super(SiamRPNotb, self).__init__(size=1, feature_out=256)
80 | self.cfg = {'lr': 0.30, 'window_influence': 0.40, 'penalty_k': 0.22, 'instance_size': 271,
81 | 'adaptive': False} # 0.655
82 |
--------------------------------------------------------------------------------
/DaSiamRPN/run_SiamRPN.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # DaSiamRPN
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import numpy as np
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 |
10 |
11 | from utils import get_subwindow_tracking
12 |
13 |
14 | def generate_anchor(total_stride, scales, ratios, score_size):
15 | anchor_num = len(ratios) * len(scales)
16 | anchor = np.zeros((anchor_num, 4), dtype=np.float32)
17 | size = total_stride * total_stride
18 | count = 0
19 | for ratio in ratios:
20 | # ws = int(np.sqrt(size * 1.0 / ratio))
21 | ws = int(np.sqrt(size / ratio))
22 | hs = int(ws * ratio)
23 | for scale in scales:
24 | wws = ws * scale
25 | hhs = hs * scale
26 | anchor[count, 0] = 0
27 | anchor[count, 1] = 0
28 | anchor[count, 2] = wws
29 | anchor[count, 3] = hhs
30 | count += 1
31 |
32 | anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
33 | ori = - (score_size / 2) * total_stride
34 | xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
35 | [ori + total_stride * dy for dy in range(score_size)])
36 | xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
37 | np.tile(yy.flatten(), (anchor_num, 1)).flatten()
38 | anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
39 | return anchor
40 |
41 |
42 | class TrackerConfig(object):
43 | # These are the default hyper-params for DaSiamRPN 0.3827
44 | windowing = 'cosine' # to penalize large displacements [cosine/uniform]
45 | # Params from the network architecture, have to be consistent with the training
46 | exemplar_size = 127 # input z size
47 | instance_size = 271 # input x size (search region)
48 | total_stride = 8
49 | score_size = (instance_size-exemplar_size)/total_stride+1  # (271 - 127) / 8 + 1 = 19
50 | context_amount = 0.5 # context amount for the exemplar
51 | ratios = [0.33, 0.5, 1, 2, 3]
52 | scales = [8, ]
53 | anchor_num = len(ratios) * len(scales)
54 | anchor = []
55 | penalty_k = 0.055
56 | window_influence = 0.42
57 | lr = 0.295
58 | # adaptive change search region #
59 | adaptive = True
60 |
61 | def update(self, cfg):
62 | for k, v in cfg.items():
63 | setattr(self, k, v)
64 | self.score_size = (self.instance_size - self.exemplar_size) / self.total_stride + 1
65 |
66 |
67 | def tracker_eval(net, x_crop, target_pos, target_sz, window, scale_z, p):
68 | delta, score = net(x_crop)
69 |
70 | delta = delta.permute(1, 2, 3, 0).contiguous().view(4, -1).data.cpu().numpy()
71 | score = F.softmax(score.permute(1, 2, 3, 0).contiguous().view(2, -1), dim=0).data[1, :].cpu().numpy()
72 |
73 | delta[0, :] = delta[0, :] * p.anchor[:, 2] + p.anchor[:, 0]
74 | delta[1, :] = delta[1, :] * p.anchor[:, 3] + p.anchor[:, 1]
75 | delta[2, :] = np.exp(delta[2, :]) * p.anchor[:, 2]
76 | delta[3, :] = np.exp(delta[3, :]) * p.anchor[:, 3]
77 |
78 | def change(r):
79 | return np.maximum(r, 1./r)
80 |
81 | def sz(w, h):
82 | pad = (w + h) * 0.5
83 | sz2 = (w + pad) * (h + pad)
84 | return np.sqrt(sz2)
85 |
86 | def sz_wh(wh):
87 | pad = (wh[0] + wh[1]) * 0.5
88 | sz2 = (wh[0] + pad) * (wh[1] + pad)
89 | return np.sqrt(sz2)
90 |
91 | # size penalty
92 | s_c = change(sz(delta[2, :], delta[3, :]) / (sz_wh(target_sz))) # scale penalty
93 | r_c = change((target_sz[0] / target_sz[1]) / (delta[2, :] / delta[3, :])) # ratio penalty
94 |
95 | penalty = np.exp(-(r_c * s_c - 1.) * p.penalty_k)
96 | pscore = penalty * score
97 |
98 | # window float
99 | pscore = pscore * (1 - p.window_influence) + window * p.window_influence
100 | best_pscore_id = np.argmax(pscore)
101 |
102 | target = delta[:, best_pscore_id] / scale_z
103 | target_sz = target_sz / scale_z
104 | lr = penalty[best_pscore_id] * score[best_pscore_id] * p.lr
105 |
106 | res_x = target[0] + target_pos[0]
107 | res_y = target[1] + target_pos[1]
108 |
109 | res_w = target_sz[0] * (1 - lr) + target[2] * lr
110 | res_h = target_sz[1] * (1 - lr) + target[3] * lr
111 |
112 | target_pos = np.array([res_x, res_y])
113 | target_sz = np.array([res_w, res_h])
114 | return target_pos, target_sz, score[best_pscore_id]
115 |
116 |
117 | def SiamRPN_init(im, target_pos, target_sz, net):
118 | state = dict()
119 | p = TrackerConfig()
120 | p.update(net.cfg)
121 | state['im_h'] = im.shape[0]
122 | state['im_w'] = im.shape[1]
123 |
124 | if p.adaptive:
125 | if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004:
126 | p.instance_size = 287 # small object big search region
127 | else:
128 | p.instance_size = 271
129 |
130 | p.score_size = (p.instance_size - p.exemplar_size) / p.total_stride + 1
131 |
132 | p.anchor = generate_anchor(p.total_stride, p.scales, p.ratios, int(p.score_size))
133 |
134 | avg_chans = np.mean(im, axis=(0, 1))
135 |
136 | wc_z = target_sz[0] + p.context_amount * sum(target_sz)
137 | hc_z = target_sz[1] + p.context_amount * sum(target_sz)
138 | s_z = round(np.sqrt(wc_z * hc_z))
139 | # initialize the exemplar
140 | z_crop = get_subwindow_tracking(im, target_pos, p.exemplar_size, s_z, avg_chans)
141 |
142 | z = Variable(z_crop.unsqueeze(0))
143 | # net.temple(z.cuda())
144 | net.temple(z)
145 |
146 | if p.windowing == 'cosine':
147 | window = np.outer(np.hanning(p.score_size), np.hanning(p.score_size))
148 | elif p.windowing == 'uniform':
149 | window = np.ones((p.score_size, p.score_size))
150 | window = np.tile(window.flatten(), p.anchor_num)
151 |
152 | state['p'] = p
153 | state['net'] = net
154 | state['avg_chans'] = avg_chans
155 | state['window'] = window
156 | state['target_pos'] = target_pos
157 | state['target_sz'] = target_sz
158 | return state
159 |
160 |
161 | def SiamRPN_track(state, im):
162 | p = state['p']
163 | net = state['net']
164 | avg_chans = state['avg_chans']
165 | window = state['window']
166 | target_pos = state['target_pos']
167 | target_sz = state['target_sz']
168 |
169 | wc_z = target_sz[1] + p.context_amount * sum(target_sz)
170 | hc_z = target_sz[0] + p.context_amount * sum(target_sz)
171 | s_z = np.sqrt(wc_z * hc_z)
172 | scale_z = p.exemplar_size / s_z
173 | d_search = (p.instance_size - p.exemplar_size) / 2
174 | pad = d_search / scale_z
175 | s_x = s_z + 2 * pad
176 |
177 | # extract scaled crops for search region x at previous target position
178 | x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0))
179 |
180 | # target_pos, target_sz, score = tracker_eval(net, x_crop.cuda(), target_pos, target_sz * scale_z, window, scale_z, p)
181 | target_pos, target_sz, score = tracker_eval(net, x_crop, target_pos, target_sz * scale_z, window, scale_z, p)
182 |
183 | target_pos[0] = max(0, min(state['im_w'], target_pos[0]))
184 | target_pos[1] = max(0, min(state['im_h'], target_pos[1]))
185 | target_sz[0] = max(10, min(state['im_w'], target_sz[0]))
186 | target_sz[1] = max(10, min(state['im_h'], target_sz[1]))
187 | state['target_pos'] = target_pos
188 | state['target_sz'] = target_sz
189 | state['score'] = score
190 | return state
191 |
--------------------------------------------------------------------------------
/DaSiamRPN/test_otb.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # DaSiamRPN
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | #!/usr/bin/python
7 |
8 | import argparse, cv2, torch, json
9 | import numpy as np
10 | from os import makedirs
11 | from os.path import realpath, dirname, join, isdir, exists
12 |
13 | from net import SiamRPNotb
14 | from run_SiamRPN import SiamRPN_init, SiamRPN_track
15 | from utils import rect_2_cxy_wh, cxy_wh_2_rect
16 |
17 | parser = argparse.ArgumentParser(description='PyTorch SiamRPN OTB Test')
18 | parser.add_argument('--dataset', dest='dataset', default='OTB2015', help='datasets')
19 | parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
20 | help='whether visualize result')
21 |
22 |
23 | def track_video(model, video):
24 | toc, regions = 0, []
25 | image_files, gt = video['image_files'], video['gt']
26 | for f, image_file in enumerate(image_files):
27 | im = cv2.imread(image_file) # TODO: batch load
28 | tic = cv2.getTickCount()
29 | if f == 0: # init
30 | target_pos, target_sz = rect_2_cxy_wh(gt[f])
31 | state = SiamRPN_init(im, target_pos, target_sz, model) # init tracker
32 | location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
33 | regions.append(gt[f])
34 | elif f > 0: # tracking
35 | state = SiamRPN_track(state, im) # track
36 | location = cxy_wh_2_rect(state['target_pos']+1, state['target_sz'])
37 | regions.append(location)
38 | toc += cv2.getTickCount() - tic
39 |
40 | if args.visualization and f >= 0: # visualization
41 | if f == 0: cv2.destroyAllWindows()
42 | if len(gt[f]) == 8:
43 | cv2.polylines(im, [np.array(gt[f], np.int).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
44 | else:
45 | cv2.rectangle(im, (gt[f, 0], gt[f, 1]), (gt[f, 0] + gt[f, 2], gt[f, 1] + gt[f, 3]), (0, 255, 0), 3)
46 | if len(location) == 8:
47 | cv2.polylines(im, [location.reshape((-1, 1, 2))], True, (0, 255, 255), 3)
48 | else:
49 | location = [int(l) for l in location] #
50 | cv2.rectangle(im, (location[0], location[1]),
51 | (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3)
52 | cv2.putText(im, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
53 |
54 | cv2.imshow(video['name'], im)
55 | cv2.waitKey(1)
56 | toc /= cv2.getTickFrequency()
57 |
58 | # save result
59 | video_path = join('test', args.dataset, 'SiamRPN_AlexNet_OTB2015')
60 | if not isdir(video_path): makedirs(video_path)
61 | result_path = join(video_path, '{:s}.txt'.format(video['name']))
62 | with open(result_path, "w") as fin:
63 | for x in regions:
64 | fin.write(','.join([str(i) for i in x])+'\n')
65 |
66 | print('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps'.format(
67 | v_id, video['name'], toc, f / toc))
68 | return f / toc
69 |
70 |
71 | def load_dataset(dataset):
72 | base_path = join(realpath(dirname(__file__)), 'data', dataset)
73 | if not exists(base_path):
74 | print("Please download OTB dataset into `data` folder!")
75 | exit()
76 | json_path = join(realpath(dirname(__file__)), 'data', dataset + '.json')
77 | info = json.load(open(json_path, 'r'))
78 | for v in info.keys():
79 | path_name = info[v]['name']
80 | info[v]['image_files'] = [join(base_path, path_name, 'img', im_f) for im_f in info[v]['image_files']]
81 | info[v]['gt'] = np.array(info[v]['gt_rect'])-[1,1,0,0] # our tracker is 0-index
82 | info[v]['name'] = v
83 | return info
84 |
85 |
86 | def main():
87 | global args, v_id
88 | args = parser.parse_args()
89 |
90 | net = SiamRPNotb()
91 | net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNOTB.model')))
92 | net.eval().cuda()
93 |
94 | dataset = load_dataset(args.dataset)
95 | fps_list = []
96 | for v_id, video in enumerate(dataset.keys()):
97 | fps_list.append(track_video(net, dataset[video]))
98 | print('Mean Running Speed {:.1f}fps'.format(np.mean(np.array(fps_list))))
99 |
100 |
101 | if __name__ == '__main__':
102 | main()
103 |
--------------------------------------------------------------------------------
/DaSiamRPN/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # DaSiamRPN
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | import cv2
7 | import torch
8 | import numpy as np
9 |
10 |
11 | def to_numpy(tensor):
12 | if torch.is_tensor(tensor):
13 | return tensor.cpu().numpy()
14 | elif type(tensor).__module__ != 'numpy':
15 | raise ValueError("Cannot convert {} to numpy array"
16 | .format(type(tensor)))
17 | return tensor
18 |
19 |
20 | def to_torch(ndarray):
21 | if type(ndarray).__module__ == 'numpy':
22 | return torch.from_numpy(ndarray)
23 | elif not torch.is_tensor(ndarray):
24 | raise ValueError("Cannot convert {} to torch tensor"
25 | .format(type(ndarray)))
26 | return ndarray
27 |
28 |
29 | def im_to_numpy(img):
30 | img = to_numpy(img)
31 | img = np.transpose(img, (1, 2, 0)) # H*W*C
32 | return img
33 |
34 |
35 | def im_to_torch(img):
36 | img = np.transpose(img, (2, 0, 1)) # C*H*W
37 | img = to_torch(img).float()
38 | return img
39 |
40 |
41 | def torch_to_img(img):
42 | img = to_numpy(torch.squeeze(img, 0))
43 | img = np.transpose(img, (1, 2, 0)) # H*W*C
44 | return img
45 |
46 |
47 | def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans, out_mode='torch', new=False):
48 |
49 | if isinstance(pos, float):
50 | pos = [pos, pos]
51 | sz = original_sz
52 | im_sz = im.shape
53 | c = (original_sz+1) / 2
54 | context_xmin = round(pos[0] - c) # floor(pos(2) - sz(2) / 2);
55 | context_xmax = context_xmin + sz - 1
56 | context_ymin = round(pos[1] - c) # floor(pos(1) - sz(1) / 2);
57 | context_ymax = context_ymin + sz - 1
58 | left_pad = int(max(0., -context_xmin))
59 | top_pad = int(max(0., -context_ymin))
60 | right_pad = int(max(0., context_xmax - im_sz[1] + 1))
61 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1))
62 |
63 | context_xmin = context_xmin + left_pad
64 | context_xmax = context_xmax + left_pad
65 | context_ymin = context_ymin + top_pad
66 | context_ymax = context_ymax + top_pad
67 |
68 | # zzp: a more easy speed version
69 | r, c, k = im.shape
70 | if any([top_pad, bottom_pad, left_pad, right_pad]):
71 | te_im = np.zeros((r + top_pad + bottom_pad, c + left_pad + right_pad, k), np.uint8) # 0 is better than 1 initialization
72 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im
73 | if top_pad:
74 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans
75 | if bottom_pad:
76 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans
77 | if left_pad:
78 | te_im[:, 0:left_pad, :] = avg_chans
79 | if right_pad:
80 | te_im[:, c + left_pad:, :] = avg_chans
81 | im_patch_original = te_im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :]
82 | else:
83 | im_patch_original = im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :]
84 |
85 | if not np.array_equal(model_sz, original_sz):
86 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv to get a better speed
87 | else:
88 | im_patch = im_patch_original
89 |
90 | return im_to_torch(im_patch) if out_mode in 'torch' else im_patch
91 |
92 |
93 | def cxy_wh_2_rect(pos, sz):
94 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) # 0-index
95 |
96 |
97 | def rect_2_cxy_wh(rect):
98 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), np.array([rect[2], rect[3]]) # 0-index
99 |
100 |
101 | def get_axis_aligned_bbox(region):
102 | try:
103 | region = np.array([region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1],
104 | region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]])
105 | except:
106 | region = np.array(region)
107 | cx = np.mean(region[0::2])
108 | cy = np.mean(region[1::2])
109 | x1 = min(region[0::2])
110 | x2 = max(region[0::2])
111 | y1 = min(region[1::2])
112 | y2 = max(region[1::2])
113 | A1 = np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6])
114 | A2 = (x2 - x1) * (y2 - y1)
115 | s = np.sqrt(A1 / A2)
116 | w = s * (x2 - x1) + 1
117 | h = s * (y2 - y1) + 1
118 | return cx, cy, w, h
--------------------------------------------------------------------------------
/DaSiamRPN/vot.py:
--------------------------------------------------------------------------------
1 | """
2 | \file vot.py
3 |
4 | @brief Python utility functions for VOT integration
5 |
6 | @author Luka Cehovin, Alessio Dore
7 |
8 | @date 2016
9 |
10 | """
11 |
12 | import sys
13 | import copy
14 | import collections
15 |
16 | try:
17 | import trax
18 | import trax.server
19 | TRAX = True
20 | except ImportError:
21 | TRAX = False
22 |
23 | Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height'])
24 | Point = collections.namedtuple('Point', ['x', 'y'])
25 | Polygon = collections.namedtuple('Polygon', ['points'])
26 |
27 | def parse_region(string):
28 | tokens = map(float, string.split(','))
29 | if len(tokens) == 4:
30 | return Rectangle(tokens[0], tokens[1], tokens[2], tokens[3])
31 | elif len(tokens) % 2 == 0 and len(tokens) > 4:
32 | return Polygon([Point(tokens[i],tokens[i+1]) for i in xrange(0,len(tokens),2)])
33 | return None
34 |
35 | def encode_region(region):
36 | if isinstance(region, Polygon):
37 | return ','.join(['{},{}'.format(p.x,p.y) for p in region.points])
38 | elif isinstance(region, Rectangle):
39 | return '{},{},{},{}'.format(region.x, region.y, region.width, region.height)
40 | else:
41 | return ""
42 |
43 | def convert_region(region, to):
44 |
45 | if to == 'rectangle':
46 |
47 | if isinstance(region, Rectangle):
48 | return copy.copy(region)
49 | elif isinstance(region, Polygon):
50 | top = sys.float_info.max
51 | bottom = sys.float_info.min
52 | left = sys.float_info.max
53 | right = sys.float_info.min
54 |
55 | for point in region.points:
56 | top = min(top, point.y)
57 | bottom = max(bottom, point.y)
58 | left = min(left, point.x)
59 | right = max(right, point.x)
60 |
61 | return Rectangle(left, top, right - left, bottom - top)
62 |
63 | else:
64 | return None
65 | if to == 'polygon':
66 |
67 | if isinstance(region, Rectangle):
68 | points = []
69 | points.append((region.x, region.y))
70 | points.append((region.x + region.width, region.y))
71 | points.append((region.x + region.width, region.y + region.height))
72 | points.append((region.x, region.y + region.height))
73 | return Polygon(points)
74 |
75 | elif isinstance(region, Polygon):
76 | return copy.copy(region)
77 | else:
78 | return None
79 |
80 | return None
81 |
82 | class VOT(object):
83 | """ Base class for Python VOT integration """
84 | def __init__(self, region_format):
85 | """ Constructor
86 |
87 | Args:
88 | region_format: Region format options
89 | """
90 | assert(region_format in ['rectangle', 'polygon'])
91 | if TRAX:
92 | options = trax.server.ServerOptions(region_format, trax.image.PATH)
93 | self._trax = trax.server.Server(options)
94 |
95 | request = self._trax.wait()
96 | assert(request.type == 'initialize')
97 | if request.region.type == 'polygon':
98 | self._region = Polygon([Point(x[0], x[1]) for x in request.region.points])
99 | else:
100 | self._region = Rectangle(request.region.x, request.region.y, request.region.width, request.region.height)
101 | self._image = str(request.image)
102 | self._trax.status(request.region)
103 | else:
104 | self._files = [x.strip('\n') for x in open('images.txt', 'r').readlines()]
105 | self._frame = 0
106 | self._region = convert_region(parse_region(open('region.txt', 'r').readline()), region_format)
107 | self._result = []
108 |
109 | def region(self):
110 | """
111 | Send configuration message to the client and receive the initialization
112 | region and the path of the first image
113 |
114 | Returns:
115 | initialization region
116 | """
117 |
118 | return self._region
119 |
120 | def report(self, region, confidence = 0):
121 | """
122 | Report the tracking results to the client
123 |
124 | Arguments:
125 | region: region for the frame
126 | """
127 | assert(isinstance(region, Rectangle) or isinstance(region, Polygon))
128 | if TRAX:
129 | if isinstance(region, Polygon):
130 | tregion = trax.region.Polygon([(x.x, x.y) for x in region.points])
131 | else:
132 | tregion = trax.region.Rectangle(region.x, region.y, region.width, region.height)
133 | self._trax.status(tregion, {"confidence" : confidence})
134 | else:
135 | self._result.append(region)
136 | self._frame += 1
137 |
138 | def frame(self):
139 | """
140 | Get a frame (image path) from client
141 |
142 | Returns:
143 | absolute path of the image
144 | """
145 | if TRAX:
146 | if hasattr(self, "_image"):
147 | image = str(self._image)
148 | del self._image
149 | return image
150 |
151 | request = self._trax.wait()
152 |
153 | if request.type == 'frame':
154 | return str(request.image)
155 | else:
156 | return None
157 |
158 | else:
159 | if self._frame >= len(self._files):
160 | return None
161 | return self._files[self._frame]
162 |
163 | def quit(self):
164 | if TRAX:
165 | self._trax.quit()
166 | elif hasattr(self, '_result'):
167 | with open('output.txt', 'w') as f:
168 | for r in self._result:
169 | f.write(encode_region(r))
170 | f.write('\n')
171 |
172 | def __del__(self):
173 | self.quit()
174 |
175 |
--------------------------------------------------------------------------------
/DaSiamRPN/vot_SiamRPN.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # DaSiamRPN
3 | # Licensed under The MIT License
4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
5 | # --------------------------------------------------------
6 | #!/usr/bin/python
7 |
8 | import vot
9 | from vot import Rectangle
10 | import sys
11 | import cv2 # imread
12 | import torch
13 | import numpy as np
14 | from os.path import realpath, dirname, join
15 |
16 | from net import SiamRPNBIG
17 | from run_SiamRPN import SiamRPN_init, SiamRPN_track
18 | from utils import get_axis_aligned_bbox, cxy_wh_2_rect
19 |
20 | # load net
21 | net_file = join(realpath(dirname(__file__)), 'SiamRPNBIG.model')
22 | net = SiamRPNBIG()
23 | net.load_state_dict(torch.load(net_file))
24 | net.eval().cuda()
25 |
26 | # warm up
27 | for i in range(10):
28 | net.temple(torch.autograd.Variable(torch.FloatTensor(1, 3, 127, 127)).cuda())
29 | net(torch.autograd.Variable(torch.FloatTensor(1, 3, 255, 255)).cuda())
30 |
31 | # start to track
32 | handle = vot.VOT("polygon")
33 | Polygon = handle.region()
34 | cx, cy, w, h = get_axis_aligned_bbox(Polygon)
35 |
36 | image_file = handle.frame()
37 | if not image_file:
38 | sys.exit(0)
39 |
40 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
41 | im = cv2.imread(image_file) # HxWxC
42 | state = SiamRPN_init(im, target_pos, target_sz, net) # init tracker
43 | while True:
44 | image_file = handle.frame()
45 | if not image_file:
46 | break
47 | im = cv2.imread(image_file) # HxWxC
48 | state = SiamRPN_track(state, im) # track
49 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
50 |
51 | handle.report(Rectangle(res[0], res[1], res[2], res[3]))
52 |
53 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Object-detecting-and-tracking
2 | This project combines object detection and object tracking. The detector detects objects in the image captured by the camera, and the tracker follows the object chosen by the user. The detector is an SSD model and the tracker is a SiamRPN model. Both are real-time algorithms and can run on a CPU alone. You can run 'SSD+SiamRPN.py' to perform object detection and tracking.
3 |
4 | If you are interested in the principles behind these algorithms, please read these papers:
5 |
6 | * SSD: [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325)
7 | * SiamRPN: [High Performance Visual Tracking with Siamese Region Proposal Network](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf)
8 |
9 | The original TensorFlow code of SSD comes from this repository:
10 | * https://github.com/wcwowwwww/SSD-Tensorflow
11 |
12 | The original PyTorch code of SiamRPN comes from this repository:
13 | * https://github.com/wcwowwwww/DaSiamRPN.
14 |
15 | ### Prerequisites
16 |
17 | * python
18 | * numpy
19 | * opencv
20 | * tensorflow
21 | * pytorch
22 |
23 | ### Pretrained model for SSD
24 |
25 | You can download the model from https://drive.google.com/file/d/0B0qPCUZ-3YwWUXh4UHJrd1RDM3c/view?usp=sharing , then put it into the subfolder 'SSD/checkpoints' so that the detector can find and load the pretrained model.
26 |
27 | ### Pretrained model for SiamRPN
28 |
29 | You can download the model from https://drive.google.com/drive/folders/1BtIkp5pB6aqePQGlMb2_Z7bfPy6XEj6H , then put it into the subfolder 'DaSiamRPN' so that the tracker can find and load the pretrained model.
30 |
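
The hand-off between the detector and the tracker in 'SSD+SiamRPN.py' boils down to the sketch below (simplified from that script; `process_image`, `net`, `video`, `width` and `height` are defined there, and the boxes returned by SSD are normalized to [0, 1]):

```python
import numpy as np
from run_SiamRPN import SiamRPN_init, SiamRPN_track
from utils import cxy_wh_2_rect

# 1. detection: run SSD on the first frame and convert one normalized box to pixels
rclasses, rscores, rbboxes = process_image(frame)   # helper defined in SSD+SiamRPN.py
ymin, xmin, ymax, xmax = rbboxes[0]                 # pick one detection
cx, cy = (xmin + xmax) / 2 * width, (ymin + ymax) / 2 * height
w, h = (xmax - xmin) * width, (ymax - ymin) * height

# 2. tracking: initialize SiamRPN with the detected box, then update it every frame
state = SiamRPN_init(frame, np.array([cx, cy]), np.array([w, h]), net)
while True:
    ok, frame = video.read()
    if not ok:
        break
    state = SiamRPN_track(state, frame)
    x, y, w, h = [int(v) for v in cxy_wh_2_rect(state['target_pos'], state['target_sz'])]
```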
31 | # Combining Object Detection and Object Tracking
32 |
33 | ### Object Detection
34 | The object detection algorithm is SSD, implemented in TensorFlow. For details see: https://github.com/wcwowwwww/SSD-Tensorflow
35 |
36 | Paper: [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325)
37 |
38 | The pretrained model can be downloaded from https://drive.google.com/file/d/0B0qPCUZ-3YwWUXh4UHJrd1RDM3c/view?usp=sharing and should be placed in the SSD/checkpoints folder.
39 |
40 | ### Object Tracking
41 | The object tracking algorithm is SiamRPN, implemented in PyTorch. For details see: https://github.com/wcwowwwww/DaSiamRPN
42 |
43 | Paper: [High Performance Visual Tracking with Siamese Region Proposal Network](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf)
44 |
45 |
46 | For the pretrained model, download the SiamRPNVOT.model file from https://drive.google.com/drive/folders/1BtIkp5pB6aqePQGlMb2_Z7bfPy6XEj6H and put it in the DaSiamRPN folder.
47 |
48 | Both algorithms deliver very good accuracy and speed; no GPU is required, and they run smoothly on a CPU. (Training, of course, still needs GPU acceleration.)
49 |
50 | Runtime environment: python, opencv, tensorflow and pytorch; updating each of them to the latest version is sufficient.
51 |
52 | After downloading the code and models, run SSD+SiamRPN.py to detect and track objects in front of the camera.
53 |
54 | ### spaceshooter
55 | The spaceshooter folder contains a flying-shooter game. Run detection_tracking_game to start it; after pressing Enter to begin, the spaceship is controlled by moving the tracked object in front of the camera.
56 |
57 |
58 |
--------------------------------------------------------------------------------
/SSD+SiamRPN.py:
--------------------------------------------------------------------------------
1 | # ======================================================================
2 | #
3 | # filename : SSD+SiamRPN.py
4 | # description :
5 | #
6 | # created by Wang An at 2019.4.18
7 | #
8 | # ======================================================================
9 |
10 | import sys
11 | from os.path import realpath, dirname, join
12 |
13 | import cv2
14 | import numpy as np
15 | import tensorflow as tf
16 | import torch
17 |
18 | from net import SiamRPNvot
19 | from nets import ssd_vgg_300, np_methods
20 | from preprocessing import ssd_vgg_preprocessing
21 | from run_SiamRPN import SiamRPN_init, SiamRPN_track
22 | from utils import cxy_wh_2_rect
23 |
24 | # ------------------------------- #
25 |
26 |
27 | '''
28 | classes:
29 | 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles
30 | 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows
31 | 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People
32 | 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors
33 | '''
34 | detect_class = 15
35 |
36 | slim = tf.contrib.slim
37 |
38 | # TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!!
39 | gpu_options = tf.GPUOptions(allow_growth=True)
40 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
41 | isess = tf.InteractiveSession(config=config)
42 |
43 | # Input placeholder.
44 | net_shape = (300, 300)
45 | data_format = 'NHWC'
46 | img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
47 | # Evaluation pre-processing: resize to SSD net shape.
48 | image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
49 | img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
50 | image_4d = tf.expand_dims(image_pre, 0)
51 |
52 | # Define the SSD model.
53 | reuse = True if 'ssd_net' in locals() else None
54 | ssd_net = ssd_vgg_300.SSDNet()
55 | with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
56 | predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
57 |
58 | # Restore SSD model.
59 | # ckpt_filename = 'checkpoints/ssd_300_vgg.ckpt'
60 | ckpt_filename = 'SSD/checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'
61 |
62 | isess.run(tf.global_variables_initializer())
63 | saver = tf.train.Saver()
64 | saver.restore(isess, ckpt_filename)
65 |
66 | # SSD default anchor boxes.
67 | ssd_anchors = ssd_net.anchors(net_shape)
68 |
69 |
70 | # Main image processing routine.
71 | def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
72 | # Run SSD network.
73 | rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
74 | feed_dict={img_input: img})
75 |
76 | # Get classes and bboxes from the net outputs.
77 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
78 | rpredictions, rlocalisations, ssd_anchors,
79 | select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
80 |
81 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
82 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
83 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
84 | # Resize bboxes to original image shape. Note: useless for Resize.WARP!
85 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
86 | return rclasses, rscores, rbboxes
87 |
88 |
89 | def get_bboxes(rclasses, rbboxes):
90 | # get center location of object
91 |
92 | number_classes = rclasses.shape[0]
93 | object_bboxes = []
94 | for i in range(number_classes):
95 | object_bbox = dict()
96 | object_bbox['i'] = i
97 | object_bbox['class'] = rclasses[i]
98 | object_bbox['y_min'] = rbboxes[i, 0]
99 | object_bbox['x_min'] = rbboxes[i, 1]
100 | object_bbox['y_max'] = rbboxes[i, 2]
101 | object_bbox['x_max'] = rbboxes[i, 3]
102 | object_bboxes.append(object_bbox)
103 | return object_bboxes
104 |
105 |
106 | # load net
107 | net = SiamRPNvot()
108 | net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'DaSiamRPN/SiamRPNVOT.model')))
109 |
110 | net.eval()
111 |
112 | # open video capture
113 | video = cv2.VideoCapture(0)
114 |
115 | if not video.isOpened():
116 | print("Could not open video")
117 | sys.exit()
118 |
119 | index = True
120 | while index:
121 | # Read first frame.
122 | ok, frame = video.read()
123 | if not ok:
124 | print('Cannot read video file')
125 | sys.exit()
126 |
127 | # Define an initial bounding box
128 | height = frame.shape[0]
129 | width = frame.shape[1]
130 | rclasses, rscores, rbboxes = process_image(frame)
131 | bboxes = get_bboxes(rclasses, rbboxes)
132 | for bbox in bboxes:
133 | if bbox['class'] == detect_class:
134 | print(bbox)
135 | ymin = int(bbox['y_min'] * height)
136 | xmin = int((bbox['x_min']) * width)
137 | ymax = int(bbox['y_max'] * height)
138 | xmax = int((bbox['x_max']) * width)
139 | cx = (xmin + xmax) / 2
140 | cy = (ymin + ymax) / 2
141 | h = ymax - ymin
142 | w = xmax - xmin
143 | new_bbox = (cx, cy, w, h)
144 | print(new_bbox)
145 | index = False
146 | break
147 |
148 | # tracker init
149 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
150 | state = SiamRPN_init(frame, target_pos, target_sz, net)
151 |
152 | # tracking and visualization
153 | toc = 0
154 | count_number = 0
155 |
156 | while True:
157 | # Read a new frame
158 | ok, frame = video.read()
159 | if not ok:
160 | break
161 |
162 | # Start timer
163 | tic = cv2.getTickCount()
164 |
165 | # Update tracker
166 | state = SiamRPN_track(state, frame) # track
167 | # print(state)
168 |
169 | toc += cv2.getTickCount() - tic
170 |
171 | if state:
172 |
173 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
174 | res = [int(l) for l in res]
175 | cv2.rectangle(frame, (res[0], res[1]), (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3)
176 | count_number += 1
177 |
178 | if (not state) or count_number % 40 == 3:
179 | # Tracking failure
180 | cv2.putText(frame, "Tracking failure detected", (100, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
181 | index = True
182 | while index:
183 | ok, frame = video.read()
184 | rclasses, rscores, rbboxes = process_image(frame)
185 | bboxes = get_bboxes(rclasses, rbboxes)
186 | for bbox in bboxes:
187 | if bbox['class'] == detect_class:
188 | ymin = int(bbox['y_min'] * height)
189 | xmin = int(bbox['x_min'] * width)
190 | ymax = int(bbox['y_max'] * height)
191 | xmax = int(bbox['x_max'] * width)
192 | cx = (xmin + xmax) / 2
193 | cy = (ymin + ymax) / 2
194 | h = ymax - ymin
195 | w = xmax - xmin
196 | new_bbox = (cx, cy, w, h)
197 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
198 | state = SiamRPN_init(frame, target_pos, target_sz, net)
199 |
200 | p1 = (int(xmin), int(ymin))
201 | p2 = (int(xmax), int(ymax))
202 | cv2.rectangle(frame, p1, p2, (0, 255, 0), 2, 1)
203 |
204 | index = 0
205 |
206 | break
207 |
208 | cv2.imshow('SSD+SiamRPN', frame)
209 |
210 | # Exit if ESC pressed
211 | k = cv2.waitKey(1) & 0xff
212 | if k == 27:
213 | break
214 |
215 | video.release()
216 | cv2.destroyAllWindows()
217 |
--------------------------------------------------------------------------------
/SSD/README.md:
--------------------------------------------------------------------------------
1 | # SSD: Single Shot MultiBox Detector in TensorFlow
2 |
3 | SSD is a unified framework for object detection with a single network. It was originally introduced in this research [article](http://arxiv.org/abs/1512.02325).
4 |
5 | This repository contains a TensorFlow re-implementation of the original [Caffe code](https://github.com/weiliu89/caffe/tree/ssd). At present, it only implements VGG-based SSD networks (with 300 and 512 inputs), but the architecture of the project is modular and should make it easy to implement and train other SSD variants (ResNet- or Inception-based, for instance). The present TF checkpoints have been directly converted from SSD Caffe models.
6 |
7 | The organisation is inspired by the TF-Slim models repository, which contains the implementation of popular architectures (ResNet, Inception and VGG). Hence, it is separated into three main parts:
8 | * datasets: interface to popular datasets (Pascal VOC, COCO, ...) and scripts to convert the former to TF-Records;
9 | * networks: definition of SSD networks, and common encoding and decoding methods (we refer to the paper on this precise topic);
10 | * pre-processing: pre-processing and data augmentation routines, inspired by original VGG and Inception implementations.
11 |
12 | ## SSD minimal example
13 |
14 | The [SSD Notebook](notebooks/ssd_notebook.ipynb) contains a minimal example of the SSD TensorFlow pipeline. In short, the detection is made of two main steps: running the SSD network on the image and post-processing the output using common algorithms (top-k filtering and the Non-Maximum Suppression algorithm).
15 |
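In code, those two steps look roughly like the following (a condensed sketch using the helpers from `nets/np_methods.py`, mirroring the calls made in `SSD+SiamRPN.py`; `isess`, `image_4d`, `predictions`, `localisations`, `bbox_img`, `img_input` and `ssd_anchors` are assumed to be set up as in that script):

```python
from nets import np_methods

# step 1: run the SSD network on the pre-processed image
rimg, rpredictions, rlocalisations, rbbox_img = isess.run(
    [image_4d, predictions, localisations, bbox_img], feed_dict={img_input: img})

# step 2: post-processing -- score thresholding, top-k filtering, then Non-Maximum Suppression
rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
    rpredictions, rlocalisations, ssd_anchors,
    select_threshold=0.5, img_shape=(300, 300), num_classes=21, decode=True)
rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=0.45)
```
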
16 | Here are two examples of successful detection outputs:
17 | 
18 | 
19 |
20 | To run the notebook you first have to unzip the checkpoint files in ./checkpoint
21 | ```bash
22 | unzip ssd_300_vgg.ckpt.zip
23 | ```
24 | and then start a jupyter notebook with
25 | ```bash
26 | jupyter notebook notebooks/ssd_notebook.ipynb
27 | ```
28 |
29 |
30 | ## Datasets
31 |
32 | The current version only supports Pascal VOC datasets (2007 and 2012). In order to be used for training an SSD model, they first need to be converted to TF-Records using the `tf_convert_data.py` script:
33 | ```bash
34 | DATASET_DIR=./VOC2007/test/
35 | OUTPUT_DIR=./tfrecords
36 | python tf_convert_data.py \
37 | --dataset_name=pascalvoc \
38 | --dataset_dir=${DATASET_DIR} \
39 | --output_name=voc_2007_train \
40 | --output_dir=${OUTPUT_DIR}
41 | ```
42 | Note that the previous command generates a collection of TF-Records instead of a single file in order to ease shuffling during training.
43 |
44 | ## Evaluation on Pascal VOC 2007
45 |
46 | The present TensorFlow implementation of SSD models has the following performance:
47 |
48 | | Model | Training data | Testing data | mAP | FPS |
49 | |--------|:---------:|:------:|:------:|:------:|
50 | | [SSD-300 VGG-based](https://drive.google.com/open?id=0B0qPCUZ-3YwWZlJaRTRRQWRFYXM) | VOC07+12 trainval | VOC07 test | 0.778 | - |
51 | | [SSD-300 VGG-based](https://drive.google.com/file/d/0B0qPCUZ-3YwWUXh4UHJrd1RDM3c/view?usp=sharing) | VOC07+12+COCO trainval | VOC07 test | 0.817 | - |
52 | | [SSD-512 VGG-based](https://drive.google.com/open?id=0B0qPCUZ-3YwWT1RCLVZNN3RTVEU) | VOC07+12+COCO trainval | VOC07 test | 0.837 | - |
53 |
54 | We are working hard at reproducing the same performance as the original [Caffe implementation](https://github.com/weiliu89/caffe/tree/ssd)!
55 |
56 | After downloading and extracting the previous checkpoints, the evaluation metrics should be reproducible by running the following command:
57 | ```bash
58 | EVAL_DIR=./logs/
59 | CHECKPOINT_PATH=./checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt
60 | python eval_ssd_network.py \
61 | --eval_dir=${EVAL_DIR} \
62 | --dataset_dir=${DATASET_DIR} \
63 | --dataset_name=pascalvoc_2007 \
64 | --dataset_split_name=test \
65 | --model_name=ssd_300_vgg \
66 | --checkpoint_path=${CHECKPOINT_PATH} \
67 | --batch_size=1
68 | ```
69 | The evaluation script provides estimates of the recall-precision curve and computes the mAP metrics following the Pascal VOC 2007 and 2012 guidelines.
70 |
71 | In addition, if one wants to experiment with or test a different Caffe SSD checkpoint, the former can be converted to a TensorFlow checkpoint as follows:
72 | ```sh
73 | CAFFE_MODEL=./ckpts/SSD_300x300_ft_VOC0712/VGG_VOC0712_SSD_300x300_ft_iter_120000.caffemodel
74 | python caffe_to_tensorflow.py \
75 | --model_name=ssd_300_vgg \
76 | --num_classes=21 \
77 | --caffemodel_path=${CAFFE_MODEL}
78 | ```
79 |
80 | ## Training
81 |
82 | The script `train_ssd_network.py` is in charge of training the network. Similarly to TF-Slim models, one can pass numerous options to the training process (dataset, optimiser, hyper-parameters, model, ...). In particular, it is possible to provide a checkpoint file which can be used as a starting point in order to fine-tune a network.
83 |
84 | ### Fine-tuning existing SSD checkpoints
85 |
86 | The easiest way to fine-tune the SSD model is to start from a pre-trained SSD network (VGG-300 or VGG-512). For instance, one can fine-tune a model starting from the former as follows:
87 | ```bash
88 | DATASET_DIR=./tfrecords
89 | TRAIN_DIR=./logs/
90 | CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
91 | python train_ssd_network.py \
92 | --train_dir=${TRAIN_DIR} \
93 | --dataset_dir=${DATASET_DIR} \
94 | --dataset_name=pascalvoc_2012 \
95 | --dataset_split_name=train \
96 | --model_name=ssd_300_vgg \
97 | --checkpoint_path=${CHECKPOINT_PATH} \
98 | --save_summaries_secs=60 \
99 | --save_interval_secs=600 \
100 | --weight_decay=0.0005 \
101 | --optimizer=adam \
102 | --learning_rate=0.001 \
103 | --batch_size=32
104 | ```
105 | Note that in addition to the training script flags, one may also want to experiment with data augmentation parameters (random cropping, resolution, ...) in `ssd_vgg_preprocessing.py` and/or network parameters (feature layers, anchor boxes, ...) in `ssd_vgg_300/512.py`.
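
As a rough illustration of what the anchor parameters control, an SSD-style anchor grid for a single feature layer can be generated as below (a simplified sketch, not the exact code in `ssd_vgg_300.py`; the layer size, anchor size and ratios are example values):

```python
import numpy as np

def layer_anchors(img_size, feat_size, anchor_size, ratios=(1.0, 2.0, 0.5)):
    """Normalized anchor centers and box sizes for one square feature layer."""
    y, x = np.mgrid[0:feat_size, 0:feat_size].astype(np.float32)
    cy = (y + 0.5) / feat_size                     # anchor centers in [0, 1]
    cx = (x + 0.5) / feat_size
    h = np.array([anchor_size / img_size / np.sqrt(r) for r in ratios])
    w = np.array([anchor_size / img_size * np.sqrt(r) for r in ratios])
    return cy, cx, h, w

# e.g. the 38x38 feature map of a 300x300 model with 21-pixel anchors:
cy, cx, h, w = layer_anchors(300., 38, 21.)
```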
106 |
107 | Furthermore, the training script can be combined with the evaluation routine in order to monitor the performance of saved checkpoints on a validation dataset. For that purpose, one can pass a GPU memory upper limit to the training and validation scripts so that both can run in parallel on the same device. If some GPU memory is left available for the evaluation script, it can be run in parallel as follows:
108 | ```bash
109 | EVAL_DIR=${TRAIN_DIR}/eval
110 | python eval_ssd_network.py \
111 | --eval_dir=${EVAL_DIR} \
112 | --dataset_dir=${DATASET_DIR} \
113 | --dataset_name=pascalvoc_2007 \
114 | --dataset_split_name=test \
115 | --model_name=ssd_300_vgg \
116 | --checkpoint_path=${TRAIN_DIR} \
117 | --wait_for_checkpoints=True \
118 | --batch_size=1 \
119 | --max_num_batches=500
120 | ```
121 |
122 | ### Fine-tuning a network trained on ImageNet
123 |
124 | One can also try to build a new SSD model based on a standard architecture (VGG, ResNet, Inception, ...) and set up the `multibox` layers on top of it (with specific anchors, ratios, ...). For that purpose, you can fine-tune a network by loading only the weights of the original architecture and randomly initializing the rest of the network. For instance, in the case of the [VGG-16 architecture](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz), one can train a new model as follows:
125 | ```bash
126 | DATASET_DIR=./tfrecords
127 | TRAIN_DIR=./log/
128 | CHECKPOINT_PATH=./checkpoints/vgg_16.ckpt
129 | python train_ssd_network.py \
130 | --train_dir=${TRAIN_DIR} \
131 | --dataset_dir=${DATASET_DIR} \
132 | --dataset_name=pascalvoc_2007 \
133 | --dataset_split_name=train \
134 | --model_name=ssd_300_vgg \
135 | --checkpoint_path=${CHECKPOINT_PATH} \
136 | --checkpoint_model_scope=vgg_16 \
137 | --checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
138 | --trainable_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
139 | --save_summaries_secs=60 \
140 | --save_interval_secs=600 \
141 | --weight_decay=0.0005 \
142 | --optimizer=adam \
143 | --learning_rate=0.001 \
144 | --learning_rate_decay_factor=0.94 \
145 | --batch_size=32
146 | ```
147 | Hence, in the previous command, the training script randomly initializes the weights belonging to the `checkpoint_exclude_scopes` and loads the remaining part of the network from the checkpoint file `vgg_16.ckpt`. Note that we also specify, with the `trainable_scopes` parameter, that only the new SSD components are trained at first while the rest of the VGG network is left unchanged. Once the network has converged to a good first result (~0.5 mAP for instance), you can fine-tune the complete network as follows:
148 | ```bash
149 | DATASET_DIR=./tfrecords
150 | TRAIN_DIR=./log_finetune/
151 | CHECKPOINT_PATH=./log/model.ckpt-N
152 | python train_ssd_network.py \
153 | --train_dir=${TRAIN_DIR} \
154 | --dataset_dir=${DATASET_DIR} \
155 | --dataset_name=pascalvoc_2007 \
156 | --dataset_split_name=train \
157 | --model_name=ssd_300_vgg \
158 | --checkpoint_path=${CHECKPOINT_PATH} \
159 | --checkpoint_model_scope=vgg_16 \
160 | --save_summaries_secs=60 \
161 | --save_interval_secs=600 \
162 | --weight_decay=0.0005 \
163 | --optimizer=adam \
164 | --learning_rate=0.00001 \
165 | --learning_rate_decay_factor=0.94 \
166 | --batch_size=32
167 | ```
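
To make the role of the scope flags above concrete, the variable selection typically boils down to filtering the graph's variables by name scope. A simplified TF-Slim-style sketch (an assumption about the mechanism, not the training script's exact code):

```python
import tensorflow as tf

def get_variables_to_train(trainable_scopes):
    """Return the trainable variables matching a comma-separated list of scopes."""
    if not trainable_scopes:
        return tf.trainable_variables()
    scopes = [scope.strip() for scope in trainable_scopes.split(',')]
    variables_to_train = []
    for scope in scopes:
        variables_to_train.extend(
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
    return variables_to_train
```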
168 |
169 | A number of pre-trained weights of popular deep architectures can be found on the [TF-Slim models page](https://github.com/tensorflow/models/tree/master/slim).
170 |
--------------------------------------------------------------------------------
/SSD/UseSSD.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import tensorflow as tf
3 |
4 | slim = tf.contrib.slim
5 |
6 | from nets import ssd_vgg_300, np_methods
7 | from preprocessing import ssd_vgg_preprocessing
8 |
9 | import multiprocessing
10 |
11 |
12 | def set_centers():
13 |
14 |     print("Starting process: putting object_centers into the queue")
15 |
16 | # TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!!
17 | gpu_options = tf.GPUOptions(allow_growth=True)
18 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
19 | isess = tf.InteractiveSession(config=config)
20 |
21 | # Input placeholder.
22 | net_shape = (300, 300)
23 | data_format = 'NHWC'
24 | img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
25 | # Evaluation pre-processing: resize to SSD net shape.
26 | image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
27 | img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
28 | image_4d = tf.expand_dims(image_pre, 0)
29 |
30 | # Define the SSD model.
31 | reuse = True if 'ssd_net' in locals() else None
32 | ssd_net = ssd_vgg_300.SSDNet()
33 | with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
34 | predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
35 |
36 | # Restore SSD model.
37 | # ckpt_filename = 'checkpoints/ssd_300_vgg.ckpt'
38 | ckpt_filename = 'checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'
39 |
40 | isess.run(tf.global_variables_initializer())
41 | saver = tf.train.Saver()
42 | saver.restore(isess, ckpt_filename)
43 |
44 | # SSD default anchor boxes.
45 | ssd_anchors = ssd_net.anchors(net_shape)
46 |
47 | # Main image processing routine.
48 | def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
49 | # Run SSD network.
50 | rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
51 | feed_dict={img_input: img})
52 |
53 | # Get classes and bboxes from the net outputs.
54 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
55 | rpredictions, rlocalisations, ssd_anchors,
56 | select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
57 |
58 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
59 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
60 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
61 | # Resize bboxes to original image shape. Note: useless for Resize.WARP!
62 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
63 | return rclasses, rscores, rbboxes
64 |
65 | def get_centers(rclasses, rbboxes):
66 | # get center location of object
67 |
68 | number_classes = rclasses.shape[0]
69 | object_centers = []
70 | for i in range(number_classes):
71 | object_center = dict()
72 | object_center['i'] = i
73 | object_center['class'] = rclasses[i]
74 |             object_center['x'] = (rbboxes[i, 1] + rbboxes[i, 3]) / 2  # x coordinate of the object center
75 |             object_center['y'] = (rbboxes[i, 0] + rbboxes[i, 2]) / 2  # y coordinate of the object center
76 | object_centers.append(object_center)
77 | return object_centers
78 |
79 | count = 0
80 | cap = cv2.VideoCapture(0)
81 |
82 | while count < 100:
83 |         # Read a frame from the camera
84 | ret, img = cap.read()
85 | rclasses, rscores, rbboxes = process_image(img)
86 |
87 | '''
88 | classes:
89 | 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles
90 | 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows
91 | 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People
92 | 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors
93 | '''
94 | object_centers = get_centers(rclasses, rbboxes)
95 | # print("put object centers: " + str(object_centers))
96 | for object_center in object_centers:
97 | if object_center['class'] == 5 or object_center['class'] == 7:
98 | new_object_center = object_center
99 | q.put(new_object_center)
100 | count += 1
101 | break
102 |     print("Finished putting object centers into the queue")
103 | cap.release()
104 |
105 |
106 |
107 |
108 | def print_centers():
109 |
110 |
111 |     print("Starting process: printing object_center values from the queue")
112 | while True:
113 |         # q.get(True) blocks until the writer process puts an item into the queue.
114 |         print("get object center:" + str(q.get(True)))
115 |
116 |     print("Finished printing object centers")
117 |
118 |
119 | q = multiprocessing.Queue()
120 |
121 | set_process = multiprocessing.Process(target=set_centers)
122 | print_process = multiprocessing.Process(target=print_centers)
123 |
124 | set_process.start()
125 | print_process.start()
126 |
127 | set_process.join()
128 | print_process.terminate()
129 |
130 | print("Main process exiting")
131 |
132 |
--------------------------------------------------------------------------------
/SSD/caffe_to_tensorflow.py:
--------------------------------------------------------------------------------
1 | """Convert a Caffe model file to TensorFlow checkpoint format.
2 |
3 | Assume that the network built is equivalent to (or a sub-network of) the Caffe
4 | definition.
5 | """
6 | import tensorflow as tf
7 |
8 | from nets import caffe_scope
9 | from nets import nets_factory
10 |
11 | slim = tf.contrib.slim
12 |
13 | # =========================================================================== #
14 | # Main flags.
15 | # =========================================================================== #
16 | tf.app.flags.DEFINE_string(
17 | 'model_name', 'ssd_300_vgg', 'Name of the model to convert.')
18 | tf.app.flags.DEFINE_integer(
19 | 'num_classes', 21, 'Number of classes in the dataset.')
20 | tf.app.flags.DEFINE_string(
21 | 'caffemodel_path', None,
22 | 'The path to the Caffe model file to convert.')
23 |
24 | FLAGS = tf.app.flags.FLAGS
25 |
26 |
27 | # =========================================================================== #
28 | # Main converting routine.
29 | # =========================================================================== #
30 | def main(_):
31 | # Caffe scope...
32 | caffemodel = caffe_scope.CaffeScope()
33 | caffemodel.load(FLAGS.caffemodel_path)
34 |
35 | tf.logging.set_verbosity(tf.logging.INFO)
36 | with tf.Graph().as_default():
37 | global_step = slim.create_global_step()
38 | num_classes = int(FLAGS.num_classes)
39 |
40 | # Select the network.
41 | ssd_class = nets_factory.get_network(FLAGS.model_name)
42 | ssd_params = ssd_class.default_params._replace(num_classes=num_classes)
43 | ssd_net = ssd_class(ssd_params)
44 | ssd_shape = ssd_net.params.img_shape
45 |
46 | # Image placeholder and model.
47 | shape = (1, ssd_shape[0], ssd_shape[1], 3)
48 | img_input = tf.placeholder(shape=shape, dtype=tf.float32)
49 | # Create model.
50 | with slim.arg_scope(ssd_net.arg_scope_caffe(caffemodel)):
51 | ssd_net.net(img_input, is_training=False)
52 |
53 | init_op = tf.global_variables_initializer()
54 | with tf.Session() as session:
55 | # Run the init operation.
56 | session.run(init_op)
57 |
58 | # Save model in checkpoint.
59 | saver = tf.train.Saver()
60 | ckpt_path = FLAGS.caffemodel_path.replace('.caffemodel', '.ckpt')
61 | saver.save(session, ckpt_path, write_meta_graph=False)
62 |
63 |
64 | if __name__ == '__main__':
65 | tf.app.run()
66 |
67 |
--------------------------------------------------------------------------------
/SSD/code_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | from multiprocessing import Process, Queue
5 |
6 |
7 | # Code executed by the data-writing process:
8 | def proc_write(q, urls):
9 | print('Process(%s) is writing...' % os.getpid())
10 | for url in urls:
11 | q.put(url)
12 | print('Put %s to queue...' % url)
13 | time.sleep(random.random())
14 |
15 |
16 | # Code executed by the data-reading process:
17 | def proc_read(q):
18 | print('Process(%s) is reading...' % os.getpid())
19 | while True:
20 | url = q.get(True)
21 | print('Get %s from queue.' % url)
22 |
23 |
24 | if __name__ == '__main__':
25 |     # The parent process creates the Queue and passes it to each child process:
26 | q = Queue()
27 | proc_writer1 = Process(target=proc_write, args=(q, ['url_1', 'url_2', 'url_3']))
28 | proc_reader = Process(target=proc_read, args=(q,))
29 |     # Start the writer child process proc_writer1:
30 | proc_writer1.start()
31 |     # Start the reader child process proc_reader:
32 | proc_reader.start()
33 |     # Wait for proc_writer1 to finish:
34 | proc_writer1.join()
35 |     # proc_reader runs an infinite loop and cannot be joined, so terminate it forcibly:
36 | proc_reader.terminate()
37 |
--------------------------------------------------------------------------------
/SSD/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SSD/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the Cifar10 dataset.
16 |
17 | The dataset scripts used to create the dataset can be found at:
18 | tensorflow/models/slim/data/create_cifar10_dataset.py
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import os
26 |
27 | import tensorflow as tf
28 |
29 | from datasets import dataset_utils
30 |
31 | slim = tf.contrib.slim
32 |
33 | _FILE_PATTERN = 'cifar10_%s.tfrecord'
34 |
35 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000}
36 |
37 | _NUM_CLASSES = 10
38 |
39 | _ITEMS_TO_DESCRIPTIONS = {
40 | 'image': 'A [32 x 32 x 3] color image.',
41 | 'label': 'A single integer between 0 and 9',
42 | }
43 |
44 |
45 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
46 | """Gets a dataset tuple with instructions for reading cifar10.
47 |
48 | Args:
49 | split_name: A train/test split name.
50 | dataset_dir: The base directory of the dataset sources.
51 | file_pattern: The file pattern to use when matching the dataset sources.
52 | It is assumed that the pattern contains a '%s' string so that the split
53 | name can be inserted.
54 | reader: The TensorFlow reader type.
55 |
56 | Returns:
57 | A `Dataset` namedtuple.
58 |
59 | Raises:
60 | ValueError: if `split_name` is not a valid train/test split.
61 | """
62 | if split_name not in SPLITS_TO_SIZES:
63 | raise ValueError('split name %s was not recognized.' % split_name)
64 |
65 | if not file_pattern:
66 | file_pattern = _FILE_PATTERN
67 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
68 |
69 | # Allowing None in the signature so that dataset_factory can use the default.
70 | if not reader:
71 | reader = tf.TFRecordReader
72 |
73 | keys_to_features = {
74 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
75 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
76 | 'image/class/label': tf.FixedLenFeature(
77 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
78 | }
79 |
80 | items_to_handlers = {
81 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]),
82 | 'label': slim.tfexample_decoder.Tensor('image/class/label'),
83 | }
84 |
85 | decoder = slim.tfexample_decoder.TFExampleDecoder(
86 | keys_to_features, items_to_handlers)
87 |
88 | labels_to_names = None
89 | if dataset_utils.has_labels(dataset_dir):
90 | labels_to_names = dataset_utils.read_label_file(dataset_dir)
91 |
92 | return slim.dataset.Dataset(
93 | data_sources=file_pattern,
94 | reader=reader,
95 | decoder=decoder,
96 | num_samples=SPLITS_TO_SIZES[split_name],
97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
98 | num_classes=_NUM_CLASSES,
99 | labels_to_names=labels_to_names)
100 |
--------------------------------------------------------------------------------
/SSD/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A factory-pattern class which returns classification image/label pairs."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from datasets import cifar10
22 | from datasets import imagenet
23 | from datasets import pascalvoc_2007
24 | from datasets import pascalvoc_2012
25 |
26 | datasets_map = {
27 | 'cifar10': cifar10,
28 | 'imagenet': imagenet,
29 | 'pascalvoc_2007': pascalvoc_2007,
30 | 'pascalvoc_2012': pascalvoc_2012,
31 | }
32 |
33 |
34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
35 | """Given a dataset name and a split_name returns a Dataset.
36 |
37 | Args:
38 | name: String, the name of the dataset.
39 | split_name: A train/test split name.
40 | dataset_dir: The directory where the dataset files are stored.
41 | file_pattern: The file pattern to use for matching the dataset source files.
42 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default
43 | reader defined by each dataset is used.
44 | Returns:
45 | A `Dataset` class.
46 | Raises:
47 | ValueError: If the dataset `name` is unknown.
48 | """
49 | if name not in datasets_map:
50 | raise ValueError('Name of dataset unknown %s' % name)
51 | return datasets_map[name].get_split(split_name,
52 | dataset_dir,
53 | file_pattern,
54 | reader)
55 |
--------------------------------------------------------------------------------
/SSD/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains utilities for downloading and converting datasets."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 | import sys
22 | import tarfile
23 |
24 | import tensorflow as tf
25 | from six.moves import urllib
26 |
27 | LABELS_FILENAME = 'labels.txt'
28 |
29 |
30 | def int64_feature(value):
31 | """Wrapper for inserting int64 features into Example proto.
32 | """
33 | if not isinstance(value, list):
34 | value = [value]
35 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
36 |
37 |
38 | def float_feature(value):
39 | """Wrapper for inserting float features into Example proto.
40 | """
41 | if not isinstance(value, list):
42 | value = [value]
43 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
44 |
45 |
46 | def bytes_feature(value):
47 | """Wrapper for inserting bytes features into Example proto.
48 | """
49 | if not isinstance(value, list):
50 | value = [value]
51 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
52 |
53 |
54 | def image_to_tfexample(image_data, image_format, height, width, class_id):
55 | return tf.train.Example(features=tf.train.Features(feature={
56 | 'image/encoded': bytes_feature(image_data),
57 | 'image/format': bytes_feature(image_format),
58 | 'image/class/label': int64_feature(class_id),
59 | 'image/height': int64_feature(height),
60 | 'image/width': int64_feature(width),
61 | }))
62 |
63 |
64 | def download_and_uncompress_tarball(tarball_url, dataset_dir):
65 | """Downloads the `tarball_url` and uncompresses it locally.
66 |
67 | Args:
68 | tarball_url: The URL of a tarball file.
69 | dataset_dir: The directory where the temporary files are stored.
70 | """
71 | filename = tarball_url.split('/')[-1]
72 | filepath = os.path.join(dataset_dir, filename)
73 |
74 | def _progress(count, block_size, total_size):
75 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
76 | filename, float(count * block_size) / float(total_size) * 100.0))
77 | sys.stdout.flush()
78 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
79 | print()
80 | statinfo = os.stat(filepath)
81 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
82 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
83 |
84 |
85 | def write_label_file(labels_to_class_names, dataset_dir,
86 | filename=LABELS_FILENAME):
87 | """Writes a file with the list of class names.
88 |
89 | Args:
90 | labels_to_class_names: A map of (integer) labels to class names.
91 | dataset_dir: The directory in which the labels file should be written.
92 | filename: The filename where the class names are written.
93 | """
94 | labels_filename = os.path.join(dataset_dir, filename)
95 | with tf.gfile.Open(labels_filename, 'w') as f:
96 | for label in labels_to_class_names:
97 | class_name = labels_to_class_names[label]
98 | f.write('%d:%s\n' % (label, class_name))
99 |
100 |
101 | def has_labels(dataset_dir, filename=LABELS_FILENAME):
102 | """Specifies whether or not the dataset directory contains a label map file.
103 |
104 | Args:
105 | dataset_dir: The directory in which the labels file is found.
106 | filename: The filename where the class names are written.
107 |
108 | Returns:
109 | `True` if the labels file exists and `False` otherwise.
110 | """
111 | return tf.gfile.Exists(os.path.join(dataset_dir, filename))
112 |
113 |
114 | def read_label_file(dataset_dir, filename=LABELS_FILENAME):
115 | """Reads the labels file and returns a mapping from ID to class name.
116 |
117 | Args:
118 | dataset_dir: The directory in which the labels file is found.
119 | filename: The filename where the class names are written.
120 |
121 | Returns:
122 | A map from a label (integer) to class name.
123 | """
124 | labels_filename = os.path.join(dataset_dir, filename)
125 | with tf.gfile.Open(labels_filename, 'rb') as f:
126 | lines = f.read()
127 | lines = lines.split(b'\n')
128 | lines = filter(None, lines)
129 |
130 | labels_to_class_names = {}
131 | for line in lines:
132 | index = line.index(b':')
133 | labels_to_class_names[int(line[:index])] = line[index+1:]
134 | return labels_to_class_names
135 |
--------------------------------------------------------------------------------
/SSD/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes.
16 |
17 | Some images have one or more bounding boxes associated with the label of the
18 | image. See details here: http://image-net.org/download-bboxes
19 |
20 | ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use
21 | "WordNet ID" (wnid), which is a concatenation of POS ( i.e. part of speech )
22 | and SYNSET OFFSET of WordNet. For more information, please refer to the
23 | WordNet documentation[http://wordnet.princeton.edu/wordnet/documentation/].
24 |
25 | "There are bounding boxes for over 3000 popular synsets available.
26 | For each synset, there are on average 150 images with bounding boxes."
27 |
28 | WARNING: Don't use for object detection, in this case all the bounding boxes
29 | of the image belong to just one class.
30 | """
31 | from __future__ import absolute_import
32 | from __future__ import division
33 | from __future__ import print_function
34 |
35 | import os
36 |
37 | import tensorflow as tf
38 | from six.moves import urllib
39 |
40 | from datasets import dataset_utils
41 |
42 | slim = tf.contrib.slim
43 |
44 | # TODO(nsilberman): Add tfrecord file type once the script is updated.
45 | _FILE_PATTERN = '%s-*'
46 |
47 | _SPLITS_TO_SIZES = {
48 | 'train': 1281167,
49 | 'validation': 50000,
50 | }
51 |
52 | _ITEMS_TO_DESCRIPTIONS = {
53 | 'image': 'A color image of varying height and width.',
54 | 'label': 'The label id of the image, integer between 0 and 999',
55 | 'label_text': 'The text of the label.',
56 | 'object/bbox': 'A list of bounding boxes.',
57 | 'object/label': 'A list of labels, one per each object.',
58 | }
59 |
60 | _NUM_CLASSES = 1001
61 |
62 |
63 | def create_readable_names_for_imagenet_labels():
64 | """Create a dict mapping label id to human readable string.
65 |
66 | Returns:
67 |     labels_to_names: dictionary where keys are integers from 0 to 1000
68 | and values are human-readable names.
69 |
70 | We retrieve a synset file, which contains a list of valid synset labels used
71 |   by the ILSVRC competition. There is one synset per line, e.g.
72 | # n01440764
73 | # n01443537
74 | We also retrieve a synset_to_human_file, which contains a mapping from synsets
75 | to human-readable names for every synset in Imagenet. These are stored in a
76 | tsv format, as follows:
77 | # n02119247 black fox
78 | # n02119359 silver fox
79 | We assign each synset (in alphabetical order) an integer, starting from 1
80 | (since 0 is reserved for the background class).
81 |
82 | Code is based on
83 | https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463
84 | """
85 |
86 | # pylint: disable=g-line-too-long
87 | base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data/'
88 | synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url)
89 | synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url)
90 |
91 | filename, _ = urllib.request.urlretrieve(synset_url)
92 | synset_list = [s.strip() for s in open(filename).readlines()]
93 | num_synsets_in_ilsvrc = len(synset_list)
94 | assert num_synsets_in_ilsvrc == 1000
95 |
96 | filename, _ = urllib.request.urlretrieve(synset_to_human_url)
97 | synset_to_human_list = open(filename).readlines()
98 | num_synsets_in_all_imagenet = len(synset_to_human_list)
99 | assert num_synsets_in_all_imagenet == 21842
100 |
101 | synset_to_human = {}
102 | for s in synset_to_human_list:
103 | parts = s.strip().split('\t')
104 | assert len(parts) == 2
105 | synset = parts[0]
106 | human = parts[1]
107 | synset_to_human[synset] = human
108 |
109 | label_index = 1
110 | labels_to_names = {0: 'background'}
111 | for synset in synset_list:
112 | name = synset_to_human[synset]
113 | labels_to_names[label_index] = name
114 | label_index += 1
115 |
116 | return labels_to_names
117 |
118 |
119 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
120 | """Gets a dataset tuple with instructions for reading ImageNet.
121 |
122 | Args:
123 | split_name: A train/test split name.
124 | dataset_dir: The base directory of the dataset sources.
125 | file_pattern: The file pattern to use when matching the dataset sources.
126 | It is assumed that the pattern contains a '%s' string so that the split
127 | name can be inserted.
128 | reader: The TensorFlow reader type.
129 |
130 | Returns:
131 | A `Dataset` namedtuple.
132 |
133 | Raises:
134 | ValueError: if `split_name` is not a valid train/test split.
135 | """
136 | if split_name not in _SPLITS_TO_SIZES:
137 | raise ValueError('split name %s was not recognized.' % split_name)
138 |
139 | if not file_pattern:
140 | file_pattern = _FILE_PATTERN
141 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
142 |
143 | # Allowing None in the signature so that dataset_factory can use the default.
144 | if reader is None:
145 | reader = tf.TFRecordReader
146 |
147 | keys_to_features = {
148 | 'image/encoded': tf.FixedLenFeature(
149 | (), tf.string, default_value=''),
150 | 'image/format': tf.FixedLenFeature(
151 | (), tf.string, default_value='jpeg'),
152 | 'image/class/label': tf.FixedLenFeature(
153 | [], dtype=tf.int64, default_value=-1),
154 | 'image/class/text': tf.FixedLenFeature(
155 | [], dtype=tf.string, default_value=''),
156 | 'image/object/bbox/xmin': tf.VarLenFeature(
157 | dtype=tf.float32),
158 | 'image/object/bbox/ymin': tf.VarLenFeature(
159 | dtype=tf.float32),
160 | 'image/object/bbox/xmax': tf.VarLenFeature(
161 | dtype=tf.float32),
162 | 'image/object/bbox/ymax': tf.VarLenFeature(
163 | dtype=tf.float32),
164 | 'image/object/class/label': tf.VarLenFeature(
165 | dtype=tf.int64),
166 | }
167 |
168 | items_to_handlers = {
169 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
170 | 'label': slim.tfexample_decoder.Tensor('image/class/label'),
171 | 'label_text': slim.tfexample_decoder.Tensor('image/class/text'),
172 | 'object/bbox': slim.tfexample_decoder.BoundingBox(
173 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
174 | 'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'),
175 | }
176 |
177 | decoder = slim.tfexample_decoder.TFExampleDecoder(
178 | keys_to_features, items_to_handlers)
179 |
180 | labels_to_names = None
181 | if dataset_utils.has_labels(dataset_dir):
182 | labels_to_names = dataset_utils.read_label_file(dataset_dir)
183 | else:
184 | labels_to_names = create_readable_names_for_imagenet_labels()
185 | dataset_utils.write_label_file(labels_to_names, dataset_dir)
186 |
187 | return slim.dataset.Dataset(
188 | data_sources=file_pattern,
189 | reader=reader,
190 | decoder=decoder,
191 | num_samples=_SPLITS_TO_SIZES[split_name],
192 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
193 | num_classes=_NUM_CLASSES,
194 | labels_to_names=labels_to_names)
195 |
--------------------------------------------------------------------------------
/SSD/datasets/pascalvoc_2007.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the Pascal VOC Dataset (images + annotations).
16 | """
17 | import tensorflow as tf
18 |
19 | from datasets import pascalvoc_common
20 |
21 | slim = tf.contrib.slim
22 |
23 | FILE_PATTERN = 'voc_2007_%s_*.tfrecord'
24 | ITEMS_TO_DESCRIPTIONS = {
25 | 'image': 'A color image of varying height and width.',
26 | 'shape': 'Shape of the image',
27 | 'object/bbox': 'A list of bounding boxes, one per each object.',
28 | 'object/label': 'A list of labels, one per each object.',
29 | }
30 | # (Images, Objects) statistics on every class.
31 | TRAIN_STATISTICS = {
32 | 'none': (0, 0),
33 | 'aeroplane': (238, 306),
34 | 'bicycle': (243, 353),
35 | 'bird': (330, 486),
36 | 'boat': (181, 290),
37 | 'bottle': (244, 505),
38 | 'bus': (186, 229),
39 | 'car': (713, 1250),
40 | 'cat': (337, 376),
41 | 'chair': (445, 798),
42 | 'cow': (141, 259),
43 | 'diningtable': (200, 215),
44 | 'dog': (421, 510),
45 | 'horse': (287, 362),
46 | 'motorbike': (245, 339),
47 | 'person': (2008, 4690),
48 | 'pottedplant': (245, 514),
49 | 'sheep': (96, 257),
50 | 'sofa': (229, 248),
51 | 'train': (261, 297),
52 | 'tvmonitor': (256, 324),
53 | 'total': (5011, 12608),
54 | }
55 | TEST_STATISTICS = {
56 | 'none': (0, 0),
57 | 'aeroplane': (1, 1),
58 | 'bicycle': (1, 1),
59 | 'bird': (1, 1),
60 | 'boat': (1, 1),
61 | 'bottle': (1, 1),
62 | 'bus': (1, 1),
63 | 'car': (1, 1),
64 | 'cat': (1, 1),
65 | 'chair': (1, 1),
66 | 'cow': (1, 1),
67 | 'diningtable': (1, 1),
68 | 'dog': (1, 1),
69 | 'horse': (1, 1),
70 | 'motorbike': (1, 1),
71 | 'person': (1, 1),
72 | 'pottedplant': (1, 1),
73 | 'sheep': (1, 1),
74 | 'sofa': (1, 1),
75 | 'train': (1, 1),
76 | 'tvmonitor': (1, 1),
77 | 'total': (20, 20),
78 | }
79 | SPLITS_TO_SIZES = {
80 | 'train': 5011,
81 | 'test': 4952,
82 | }
83 | SPLITS_TO_STATISTICS = {
84 | 'train': TRAIN_STATISTICS,
85 | 'test': TEST_STATISTICS,
86 | }
87 | NUM_CLASSES = 20
88 |
89 |
90 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
91 |     """Gets a dataset tuple with instructions for reading Pascal VOC 2007.
92 |
93 | Args:
94 | split_name: A train/test split name.
95 | dataset_dir: The base directory of the dataset sources.
96 | file_pattern: The file pattern to use when matching the dataset sources.
97 | It is assumed that the pattern contains a '%s' string so that the split
98 | name can be inserted.
99 | reader: The TensorFlow reader type.
100 |
101 | Returns:
102 | A `Dataset` namedtuple.
103 |
104 | Raises:
105 | ValueError: if `split_name` is not a valid train/test split.
106 | """
107 | if not file_pattern:
108 | file_pattern = FILE_PATTERN
109 | return pascalvoc_common.get_split(split_name, dataset_dir,
110 | file_pattern, reader,
111 | SPLITS_TO_SIZES,
112 | ITEMS_TO_DESCRIPTIONS,
113 | NUM_CLASSES)
114 |
--------------------------------------------------------------------------------
/SSD/datasets/pascalvoc_2012.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the Pascal VOC Dataset (images + annotations).
16 | """
17 | import tensorflow as tf
18 |
19 | from datasets import pascalvoc_common
20 |
21 | slim = tf.contrib.slim
22 |
23 | FILE_PATTERN = 'voc_2012_%s_*.tfrecord'
24 | ITEMS_TO_DESCRIPTIONS = {
25 | 'image': 'A color image of varying height and width.',
26 | 'shape': 'Shape of the image',
27 | 'object/bbox': 'A list of bounding boxes, one per each object.',
28 | 'object/label': 'A list of labels, one per each object.',
29 | }
30 | # (Images, Objects) statistics on every class.
31 | TRAIN_STATISTICS = {
32 | 'none': (0, 0),
33 | 'aeroplane': (670, 865),
34 | 'bicycle': (552, 711),
35 | 'bird': (765, 1119),
36 | 'boat': (508, 850),
37 | 'bottle': (706, 1259),
38 | 'bus': (421, 593),
39 | 'car': (1161, 2017),
40 | 'cat': (1080, 1217),
41 | 'chair': (1119, 2354),
42 | 'cow': (303, 588),
43 | 'diningtable': (538, 609),
44 | 'dog': (1286, 1515),
45 | 'horse': (482, 710),
46 | 'motorbike': (526, 713),
47 | 'person': (4087, 8566),
48 | 'pottedplant': (527, 973),
49 | 'sheep': (325, 813),
50 | 'sofa': (507, 566),
51 | 'train': (544, 628),
52 | 'tvmonitor': (575, 784),
53 | 'total': (11540, 27450),
54 | }
55 | SPLITS_TO_SIZES = {
56 | 'train': 17125,
57 | }
58 | SPLITS_TO_STATISTICS = {
59 | 'train': TRAIN_STATISTICS,
60 | }
61 | NUM_CLASSES = 20
62 |
63 |
64 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
65 |     """Gets a dataset tuple with instructions for reading Pascal VOC 2012.
66 |
67 | Args:
68 | split_name: A train/test split name.
69 | dataset_dir: The base directory of the dataset sources.
70 | file_pattern: The file pattern to use when matching the dataset sources.
71 | It is assumed that the pattern contains a '%s' string so that the split
72 | name can be inserted.
73 | reader: The TensorFlow reader type.
74 |
75 | Returns:
76 | A `Dataset` namedtuple.
77 |
78 | Raises:
79 | ValueError: if `split_name` is not a valid train/test split.
80 | """
81 | if not file_pattern:
82 | file_pattern = FILE_PATTERN
83 | return pascalvoc_common.get_split(split_name, dataset_dir,
84 | file_pattern, reader,
85 | SPLITS_TO_SIZES,
86 | ITEMS_TO_DESCRIPTIONS,
87 | NUM_CLASSES)
88 |
89 |
--------------------------------------------------------------------------------
/SSD/datasets/pascalvoc_common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the Pascal VOC Dataset (images + annotations).
16 | """
17 | import os
18 |
19 | import tensorflow as tf
20 |
21 | from datasets import dataset_utils
22 |
23 | slim = tf.contrib.slim
24 |
25 | VOC_LABELS = {
26 | 'none': (0, 'Background'),
27 | 'aeroplane': (1, 'Vehicle'),
28 | 'bicycle': (2, 'Vehicle'),
29 | 'bird': (3, 'Animal'),
30 | 'boat': (4, 'Vehicle'),
31 | 'bottle': (5, 'Indoor'),
32 | 'bus': (6, 'Vehicle'),
33 | 'car': (7, 'Vehicle'),
34 | 'cat': (8, 'Animal'),
35 | 'chair': (9, 'Indoor'),
36 | 'cow': (10, 'Animal'),
37 | 'diningtable': (11, 'Indoor'),
38 | 'dog': (12, 'Animal'),
39 | 'horse': (13, 'Animal'),
40 | 'motorbike': (14, 'Vehicle'),
41 | 'person': (15, 'Person'),
42 | 'pottedplant': (16, 'Indoor'),
43 | 'sheep': (17, 'Animal'),
44 | 'sofa': (18, 'Indoor'),
45 | 'train': (19, 'Vehicle'),
46 | 'tvmonitor': (20, 'Indoor'),
47 | }
48 |
49 |
50 | def get_split(split_name, dataset_dir, file_pattern, reader,
51 | split_to_sizes, items_to_descriptions, num_classes):
52 | """Gets a dataset tuple with instructions for reading Pascal VOC dataset.
53 |
54 | Args:
55 | split_name: A train/test split name.
56 | dataset_dir: The base directory of the dataset sources.
57 | file_pattern: The file pattern to use when matching the dataset sources.
58 | It is assumed that the pattern contains a '%s' string so that the split
59 | name can be inserted.
60 | reader: The TensorFlow reader type.
61 |
62 | Returns:
63 | A `Dataset` namedtuple.
64 |
65 | Raises:
66 | ValueError: if `split_name` is not a valid train/test split.
67 | """
68 | if split_name not in split_to_sizes:
69 | raise ValueError('split name %s was not recognized.' % split_name)
70 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
71 |
72 | # Allowing None in the signature so that dataset_factory can use the default.
73 | if reader is None:
74 | reader = tf.TFRecordReader
75 | # Features in Pascal VOC TFRecords.
76 | keys_to_features = {
77 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
78 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
79 | 'image/height': tf.FixedLenFeature([1], tf.int64),
80 | 'image/width': tf.FixedLenFeature([1], tf.int64),
81 | 'image/channels': tf.FixedLenFeature([1], tf.int64),
82 | 'image/shape': tf.FixedLenFeature([3], tf.int64),
83 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
84 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
85 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
86 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
87 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
88 | 'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
89 | 'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
90 | }
91 | items_to_handlers = {
92 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
93 | 'shape': slim.tfexample_decoder.Tensor('image/shape'),
94 | 'object/bbox': slim.tfexample_decoder.BoundingBox(
95 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
96 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
97 | 'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
98 | 'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
99 | }
100 | decoder = slim.tfexample_decoder.TFExampleDecoder(
101 | keys_to_features, items_to_handlers)
102 |
103 | labels_to_names = None
104 | if dataset_utils.has_labels(dataset_dir):
105 | labels_to_names = dataset_utils.read_label_file(dataset_dir)
106 | # else:
107 | # labels_to_names = create_readable_names_for_imagenet_labels()
108 | # dataset_utils.write_label_file(labels_to_names, dataset_dir)
109 |
110 | return slim.dataset.Dataset(
111 | data_sources=file_pattern,
112 | reader=reader,
113 | decoder=decoder,
114 | num_samples=split_to_sizes[split_name],
115 | items_to_descriptions=items_to_descriptions,
116 | num_classes=num_classes,
117 | labels_to_names=labels_to_names)
118 |
--------------------------------------------------------------------------------
/SSD/datasets/pascalvoc_to_tfrecords.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Converts Pascal VOC data to TFRecords file format with Example protos.
16 |
17 | The raw Pascal VOC data set is expected to reside in JPEG files located in the
18 | directory 'JPEGImages'. Similarly, bounding box annotations are supposed to be
19 | stored in the 'Annotations' directory.
20 |
21 | This TensorFlow script converts the training and evaluation data into
22 | a sharded data set consisting of TFRecord files, each containing up to
23 | SAMPLES_PER_FILES (200) records.
24 |
25 | Each record within a TFRecord file is a
26 | serialized Example proto. The Example proto contains the following fields:
27 |
28 | image/encoded: string containing JPEG encoded image in RGB colorspace
29 | image/height: integer, image height in pixels
30 | image/width: integer, image width in pixels
31 | image/channels: integer, specifying the number of channels, always 3
32 |     image/format: string, specifying the format, always 'JPEG'
33 |
34 |
35 | image/object/bbox/xmin: list of float specifying the 0+ human annotated
36 | bounding boxes
37 | image/object/bbox/xmax: list of float specifying the 0+ human annotated
38 | bounding boxes
39 | image/object/bbox/ymin: list of float specifying the 0+ human annotated
40 | bounding boxes
41 | image/object/bbox/ymax: list of float specifying the 0+ human annotated
42 | bounding boxes
43 | image/object/bbox/label: list of integer specifying the classification index.
44 | image/object/bbox/label_text: list of string descriptions.
45 |
46 | Note that the length of xmin is identical to the length of xmax, ymin and ymax
47 | for each example.
48 | """
49 | import os
50 | import random
51 | import sys
52 | import xml.etree.ElementTree as ET
53 |
54 | import tensorflow as tf
55 |
56 | from datasets.dataset_utils import int64_feature, float_feature, bytes_feature
57 | from datasets.pascalvoc_common import VOC_LABELS
58 |
59 | # Original dataset organisation.
60 | DIRECTORY_ANNOTATIONS = 'Annotations/'
61 | DIRECTORY_IMAGES = 'JPEGImages/'
62 |
63 | # TFRecords conversion parameters.
64 | RANDOM_SEED = 4242
65 | SAMPLES_PER_FILES = 200
66 |
67 |
68 | def _process_image(directory, name):
69 |     """Process an image and its annotation file.
70 |
71 |     Args:
72 |       directory: the Pascal VOC dataset directory (containing 'JPEGImages/' and 'Annotations/');
73 |       name: image name, without extension, shared by the JPEG and XML files.
74 | Returns:
75 | image_buffer: string, JPEG encoding of RGB image.
76 | height: integer, image height in pixels.
77 | width: integer, image width in pixels.
78 | """
79 | # Read the image file.
80 | filename = directory + DIRECTORY_IMAGES + name + '.jpg'
81 |     image_data = tf.gfile.FastGFile(filename, 'rb').read()
82 |
83 | # Read the XML annotation file.
84 | filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
85 | tree = ET.parse(filename)
86 | root = tree.getroot()
87 |
88 | # Image shape.
89 | size = root.find('size')
90 | shape = [int(size.find('height').text),
91 | int(size.find('width').text),
92 | int(size.find('depth').text)]
93 | # Find annotations.
94 | bboxes = []
95 | labels = []
96 | labels_text = []
97 | difficult = []
98 | truncated = []
99 | for obj in root.findall('object'):
100 | label = obj.find('name').text
101 | labels.append(int(VOC_LABELS[label][0]))
102 | labels_text.append(label.encode('ascii'))
103 |
104 |         if obj.find('difficult') is not None:
105 | difficult.append(int(obj.find('difficult').text))
106 | else:
107 | difficult.append(0)
108 |         if obj.find('truncated') is not None:
109 | truncated.append(int(obj.find('truncated').text))
110 | else:
111 | truncated.append(0)
112 |
113 | bbox = obj.find('bndbox')
114 | bboxes.append((float(bbox.find('ymin').text) / shape[0],
115 | float(bbox.find('xmin').text) / shape[1],
116 | float(bbox.find('ymax').text) / shape[0],
117 | float(bbox.find('xmax').text) / shape[1]
118 | ))
119 | return image_data, shape, bboxes, labels, labels_text, difficult, truncated
120 |
121 |
122 | def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
123 | difficult, truncated):
124 | """Build an Example proto for an image example.
125 |
126 | Args:
127 | image_data: string, JPEG encoding of RGB image;
128 | labels: list of integers, identifier for the ground truth;
129 | labels_text: list of strings, human-readable labels;
130 |       bboxes: list of bounding boxes; each box is a list of normalized floats
131 |          specifying [ymin, xmin, ymax, xmax]. All boxes are assumed to belong
132 | to the same label as the image label.
133 | shape: 3 integers, image shapes in pixels.
134 | Returns:
135 | Example proto
136 | """
137 | xmin = []
138 | ymin = []
139 | xmax = []
140 | ymax = []
141 | for b in bboxes:
142 | assert len(b) == 4
143 | # pylint: disable=expression-not-assigned
144 | [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
145 | # pylint: enable=expression-not-assigned
146 |
147 | image_format = b'JPEG'
148 | example = tf.train.Example(features=tf.train.Features(feature={
149 | 'image/height': int64_feature(shape[0]),
150 | 'image/width': int64_feature(shape[1]),
151 | 'image/channels': int64_feature(shape[2]),
152 | 'image/shape': int64_feature(shape),
153 | 'image/object/bbox/xmin': float_feature(xmin),
154 | 'image/object/bbox/xmax': float_feature(xmax),
155 | 'image/object/bbox/ymin': float_feature(ymin),
156 | 'image/object/bbox/ymax': float_feature(ymax),
157 | 'image/object/bbox/label': int64_feature(labels),
158 | 'image/object/bbox/label_text': bytes_feature(labels_text),
159 | 'image/object/bbox/difficult': int64_feature(difficult),
160 | 'image/object/bbox/truncated': int64_feature(truncated),
161 | 'image/format': bytes_feature(image_format),
162 | 'image/encoded': bytes_feature(image_data)}))
163 | return example
164 |
165 |
166 | def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
167 | """Loads data from image and annotations files and add them to a TFRecord.
168 |
169 | Args:
170 | dataset_dir: Dataset directory;
171 | name: Image name to add to the TFRecord;
172 | tfrecord_writer: The TFRecord writer to use for writing.
173 | """
174 | image_data, shape, bboxes, labels, labels_text, difficult, truncated = \
175 | _process_image(dataset_dir, name)
176 | example = _convert_to_example(image_data, labels, labels_text,
177 | bboxes, shape, difficult, truncated)
178 | tfrecord_writer.write(example.SerializeToString())
179 |
180 |
181 | def _get_output_filename(output_dir, name, idx):
182 | return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)
183 |
184 |
185 | def run(dataset_dir, output_dir, name='voc_train', shuffling=False):
186 | """Runs the conversion operation.
187 |
188 | Args:
189 | dataset_dir: The dataset directory where the dataset is stored.
190 | output_dir: Output directory.
191 | """
192 | if not tf.gfile.Exists(dataset_dir):
193 | tf.gfile.MakeDirs(dataset_dir)
194 |
195 | # Dataset filenames, and shuffling.
196 | path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)
197 | filenames = sorted(os.listdir(path))
198 | if shuffling:
199 | random.seed(RANDOM_SEED)
200 | random.shuffle(filenames)
201 |
202 | # Process dataset files.
203 | i = 0
204 | fidx = 0
205 | while i < len(filenames):
206 | # Open new TFRecord file.
207 | tf_filename = _get_output_filename(output_dir, name, fidx)
208 | with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
209 | j = 0
210 | while i < len(filenames) and j < SAMPLES_PER_FILES:
211 | sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))
212 | sys.stdout.flush()
213 |
214 | filename = filenames[i]
215 | img_name = filename[:-4]
216 | _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)
217 | i += 1
218 | j += 1
219 | fidx += 1
220 |
221 | # Finally, write the labels file:
222 | # labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
223 | # dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
224 | print('\nFinished converting the Pascal VOC dataset!')
225 |
--------------------------------------------------------------------------------
/SSD/inspect_checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A simple script for inspecting checkpoint files."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import argparse
21 | import sys
22 |
23 | import numpy as np
24 | from tensorflow.python import pywrap_tensorflow
25 | from tensorflow.python.platform import app
26 | from tensorflow.python.platform import flags
27 |
28 | FLAGS = None
29 |
30 |
31 | def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
32 | """Prints tensors in a checkpoint file.
33 |
34 | If no `tensor_name` is provided, prints the tensor names and shapes
35 | in the checkpoint file.
36 |
37 | If `tensor_name` is provided, prints the content of the tensor.
38 |
39 | Args:
40 | file_name: Name of the checkpoint file.
41 | tensor_name: Name of the tensor in the checkpoint file to print.
42 | all_tensors: Boolean indicating whether to print all tensors.
43 | """
44 | try:
45 | reader = pywrap_tensorflow.NewCheckpointReader(file_name)
46 | if all_tensors:
47 | var_to_shape_map = reader.get_variable_to_shape_map()
48 | for key in var_to_shape_map:
49 | print("tensor_name: ", key)
50 | print(reader.get_tensor(key))
51 | elif not tensor_name:
52 | print(reader.debug_string().decode("utf-8"))
53 | else:
54 | print("tensor_name: ", tensor_name)
55 | print(reader.get_tensor(tensor_name))
56 | except Exception as e: # pylint: disable=broad-except
57 | print(str(e))
58 | if "corrupted compressed block contents" in str(e):
59 | print("It's likely that your checkpoint file has been compressed "
60 | "with SNAPPY.")
61 |
62 |
63 | def parse_numpy_printoption(kv_str):
64 | """Sets a single numpy printoption from a string of the form 'x=y'.
65 |
66 | See documentation on numpy.set_printoptions() for details about what values
67 | x and y can take. x can be any option listed there other than 'formatter'.
68 |
69 | Args:
70 | kv_str: A string of the form 'x=y', such as 'threshold=100000'
71 |
72 | Raises:
73 | argparse.ArgumentTypeError: If the string couldn't be used to set any
74 |     numpy printoption.
75 | """
76 | k_v_str = kv_str.split("=", 1)
77 | if len(k_v_str) != 2 or not k_v_str[0]:
78 | raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
79 | k, v_str = k_v_str
80 | printoptions = np.get_printoptions()
81 | if k not in printoptions:
82 | raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
83 | v_type = type(printoptions[k])
84 | if v_type is type(None):
85 | raise argparse.ArgumentTypeError(
86 | "Setting '%s' from the command line is not supported." % k)
87 | try:
88 | v = (v_type(v_str) if v_type is not bool
89 | else flags.BooleanParser().Parse(v_str))
90 | except ValueError as e:
91 |     raise argparse.ArgumentTypeError(str(e))
92 | np.set_printoptions(**{k: v})
93 |
94 |
95 | def main(unused_argv):
96 | if not FLAGS.file_name:
97 | print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
98 | "[--tensor_name=tensor_to_print]")
99 | sys.exit(1)
100 | else:
101 | print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
102 | FLAGS.all_tensors)
103 |
104 |
105 | if __name__ == "__main__":
106 | parser = argparse.ArgumentParser()
107 | parser.register("type", "bool", lambda v: v.lower() == "true")
108 | parser.add_argument(
109 | "--file_name", type=str, default="", help="Checkpoint filename. "
110 | "Note, if using Checkpoint V2 format, file_name is the "
111 | "shared prefix between all files in the checkpoint.")
112 | parser.add_argument(
113 | "--tensor_name",
114 | type=str,
115 | default="",
116 | help="Name of the tensor to inspect")
117 | parser.add_argument(
118 | "--all_tensors",
119 | nargs="?",
120 | const=True,
121 | type="bool",
122 | default=False,
123 | help="If True, print the values of all the tensors.")
124 | parser.add_argument(
125 | "--printoptions",
126 | nargs="*",
127 | type=parse_numpy_printoption,
128 | help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
129 | FLAGS, unparsed = parser.parse_known_args()
130 | app.run(main=main, argv=[sys.argv[0]] + unparsed)
131 |
--------------------------------------------------------------------------------
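Editor's note: a minimal usage sketch for the checkpoint inspector above. It assumes a TensorFlow 1.x environment, that the SSD directory is on PYTHONPATH, and a hypothetical checkpoint prefix ./model.ckpt; the tensor name is an example, not taken from this repo.

    # Hypothetical checkpoint path and tensor name, for illustration only.
    from inspect_checkpoint import print_tensors_in_checkpoint_file

    # List every variable name and shape stored in the checkpoint.
    print_tensors_in_checkpoint_file("./model.ckpt", tensor_name="", all_tensors=False)

    # Print the value of a single variable (example name).
    print_tensors_in_checkpoint_file("./model.ckpt",
                                     tensor_name="ssd_300_vgg/conv1/conv1_1/weights",
                                     all_tensors=False)

    # Equivalent command line, via the flags defined under __main__:
    #   python inspect_checkpoint.py --file_name=./model.ckpt --tensor_name=...

--------------------------------------------------------------------------------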
/SSD/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SSD/nets/caffe_scope.py:
--------------------------------------------------------------------------------
1 | """Specific Caffe scope used to import weights from a .caffemodel file.
2 |
3 | The idea is to create special initializers loading weights from protobuf
4 | .caffemodel files.
5 | """
6 | import numpy as np
7 | import tensorflow as tf
8 | from caffe.proto import caffe_pb2
9 |
10 | slim = tf.contrib.slim
11 |
12 |
13 | class CaffeScope(object):
14 | """Caffe scope.
15 | """
16 | def __init__(self):
17 | """Initialize the caffee scope.
18 | """
19 | self.counters = {}
20 | self.layers = {}
21 | self.caffe_layers = None
22 | self.bgr_to_rgb = 0
23 |
24 | def load(self, filename, bgr_to_rgb=True):
25 | """Load weights from a .caffemodel file and initialize counters.
26 |
27 | Params:
28 | filename: caffemodel file.
29 | """
30 | print('Loading Caffe file:', filename)
31 | caffemodel_params = caffe_pb2.NetParameter()
32 | caffemodel_str = open(filename, 'rb').read()
33 | caffemodel_params.ParseFromString(caffemodel_str)
34 | self.caffe_layers = caffemodel_params.layer
35 |
36 | # Layers collection.
37 | self.layers['convolution'] = [i for i, l in enumerate(self.caffe_layers)
38 | if l.type == 'Convolution']
39 | self.layers['l2_normalization'] = [i for i, l in enumerate(self.caffe_layers)
40 | if l.type == 'Normalize']
41 |         # BGR to RGB conversion. Tries to find the first convolution with 3
42 |         # input channels and swap its parameters accordingly.
43 | if bgr_to_rgb:
44 | self.bgr_to_rgb = 1
45 |
46 | def conv_weights_init(self):
47 | def _initializer(shape, dtype, partition_info=None):
48 | counter = self.counters.get(self.conv_weights_init, 0)
49 | idx = self.layers['convolution'][counter]
50 | layer = self.caffe_layers[idx]
51 | # Weights: reshape and transpose dimensions.
52 | w = np.array(layer.blobs[0].data)
53 | w = np.reshape(w, layer.blobs[0].shape.dim)
54 | # w = np.transpose(w, (1, 0, 2, 3))
55 | w = np.transpose(w, (2, 3, 1, 0))
56 | if self.bgr_to_rgb == 1 and w.shape[2] == 3:
57 | print('Convert BGR to RGB in convolution layer:', layer.name)
58 | w[:, :, (0, 1, 2)] = w[:, :, (2, 1, 0)]
59 | self.bgr_to_rgb += 1
60 | self.counters[self.conv_weights_init] = counter + 1
61 | print('Load weights from convolution layer:', layer.name, w.shape)
62 | return tf.cast(w, dtype)
63 | return _initializer
64 |
65 | def conv_biases_init(self):
66 | def _initializer(shape, dtype, partition_info=None):
67 | counter = self.counters.get(self.conv_biases_init, 0)
68 | idx = self.layers['convolution'][counter]
69 | layer = self.caffe_layers[idx]
70 | # Biases data...
71 | b = np.array(layer.blobs[1].data)
72 | self.counters[self.conv_biases_init] = counter + 1
73 | print('Load biases from convolution layer:', layer.name, b.shape)
74 | return tf.cast(b, dtype)
75 | return _initializer
76 |
77 | def l2_norm_scale_init(self):
78 | def _initializer(shape, dtype, partition_info=None):
79 | counter = self.counters.get(self.l2_norm_scale_init, 0)
80 | idx = self.layers['l2_normalization'][counter]
81 | layer = self.caffe_layers[idx]
82 | # Scaling parameter.
83 | s = np.array(layer.blobs[0].data)
84 | s = np.reshape(s, layer.blobs[0].shape.dim)
85 | self.counters[self.l2_norm_scale_init] = counter + 1
86 | print('Load scaling from L2 normalization layer:', layer.name, s.shape)
87 | return tf.cast(s, dtype)
88 | return _initializer
89 |
--------------------------------------------------------------------------------
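Editor's note: a sketch of how the CaffeScope initializers above can be plugged into a slim arg_scope when porting a .caffemodel (in this repo the wiring goes through ssd_arg_scope_caffe, cf. nets_factory.py). The caffemodel path is a placeholder and the caffe protobuf bindings must be importable.

    import tensorflow as tf
    from nets.caffe_scope import CaffeScope

    slim = tf.contrib.slim

    caffe_scope = CaffeScope()
    caffe_scope.load('./ssd_300_vgg.caffemodel')  # placeholder path

    # Convolution weights and biases are read from the caffemodel lazily, each
    # time a variable is created under this arg_scope.
    with slim.arg_scope([slim.conv2d],
                        weights_initializer=caffe_scope.conv_weights_init(),
                        biases_initializer=caffe_scope.conv_biases_init()):
        pass  # build the SSD network here

--------------------------------------------------------------------------------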
/SSD/nets/custom_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Implement some custom layers, not provided by TensorFlow.
16 |
17 | Trying to follow the style/standards used in tf.contrib.layers as much
18 | as possible.
19 | """
20 | import tensorflow as tf
21 | from tensorflow.contrib.framework.python.ops import add_arg_scope
22 | from tensorflow.contrib.framework.python.ops import variables
23 | from tensorflow.contrib.layers.python.layers import utils
24 | from tensorflow.python.ops import init_ops
25 | from tensorflow.python.ops import nn
26 | from tensorflow.python.ops import variable_scope
27 |
28 |
29 | def abs_smooth(x):
30 | """Smoothed absolute function. Useful to compute an L1 smooth error.
31 |
32 |     Defined as:
33 |         x^2 / 2 if abs(x) < 1
34 |         abs(x) - 0.5 if abs(x) >= 1
35 |     Here we use a differentiable formulation based on min and abs. Clearly
36 |     not optimal, but good enough for our purpose!
37 | """
38 | absx = tf.abs(x)
39 | minx = tf.minimum(absx, 1)
40 | r = 0.5 * ((absx - 1) * minx + absx)
41 | return r
42 |
43 |
44 | @add_arg_scope
45 | def l2_normalization(
46 | inputs,
47 | scaling=False,
48 | scale_initializer=init_ops.ones_initializer(),
49 | reuse=None,
50 | variables_collections=None,
51 | outputs_collections=None,
52 | data_format='NHWC',
53 | trainable=True,
54 | scope=None):
55 | """Implement L2 normalization on every feature (i.e. spatial normalization).
56 |
57 |     Should be extended in the near future to other dimensions, providing a more
58 | flexible normalization framework.
59 |
60 | Args:
61 | inputs: a 4-D tensor with dimensions [batch_size, height, width, channels].
62 | scaling: whether or not to add a post scaling operation along the dimensions
63 | which have been normalized.
64 | scale_initializer: An initializer for the weights.
65 | reuse: whether or not the layer and its variables should be reused. To be
66 |       able to reuse the layer, the scope must be given.
67 | variables_collections: optional list of collections for all the variables or
68 | a dictionary containing a different list of collection per variable.
69 | outputs_collections: collection to add the outputs.
70 | data_format: NHWC or NCHW data format.
71 | trainable: If `True` also add variables to the graph collection
72 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
73 | scope: Optional scope for `variable_scope`.
74 | Returns:
75 | A `Tensor` representing the output of the operation.
76 | """
77 |
78 | with variable_scope.variable_scope(
79 | scope, 'L2Normalization', [inputs], reuse=reuse) as sc:
80 | inputs_shape = inputs.get_shape()
81 | inputs_rank = inputs_shape.ndims
82 | dtype = inputs.dtype.base_dtype
83 | if data_format == 'NHWC':
84 | # norm_dim = tf.range(1, inputs_rank-1)
85 | norm_dim = tf.range(inputs_rank-1, inputs_rank)
86 | params_shape = inputs_shape[-1:]
87 | elif data_format == 'NCHW':
88 | # norm_dim = tf.range(2, inputs_rank)
89 | norm_dim = tf.range(1, 2)
90 | params_shape = (inputs_shape[1])
91 |
92 | # Normalize along spatial dimensions.
93 | outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12)
94 | # Additional scaling.
95 | if scaling:
96 | scale_collections = utils.get_variable_collections(
97 | variables_collections, 'scale')
98 | scale = variables.model_variable('gamma',
99 | shape=params_shape,
100 | dtype=dtype,
101 | initializer=scale_initializer,
102 | collections=scale_collections,
103 | trainable=trainable)
104 | if data_format == 'NHWC':
105 | outputs = tf.multiply(outputs, scale)
106 | elif data_format == 'NCHW':
107 | scale = tf.expand_dims(scale, axis=-1)
108 | scale = tf.expand_dims(scale, axis=-1)
109 | outputs = tf.multiply(outputs, scale)
110 | # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1))
111 |
112 | return utils.collect_named_outputs(outputs_collections,
113 | sc.original_name_scope, outputs)
114 |
115 |
116 | @add_arg_scope
117 | def pad2d(inputs,
118 | pad=(0, 0),
119 | mode='CONSTANT',
120 | data_format='NHWC',
121 | trainable=True,
122 | scope=None):
123 | """2D Padding layer, adding a symmetric padding to H and W dimensions.
124 |
125 | Aims to mimic padding in Caffe and MXNet, helping the port of models to
126 | TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`.
127 |
128 | Args:
129 | inputs: 4D input Tensor;
130 | pad: 2-Tuple with padding values for H and W dimensions;
131 | mode: Padding mode. C.f. `tf.pad`
132 | data_format: NHWC or NCHW data format.
133 | """
134 | with tf.name_scope(scope, 'pad2d', [inputs]):
135 | # Padding shape.
136 | if data_format == 'NHWC':
137 | paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]]
138 | elif data_format == 'NCHW':
139 | paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]]
140 | net = tf.pad(inputs, paddings, mode=mode)
141 | return net
142 |
143 |
144 | @add_arg_scope
145 | def channel_to_last(inputs,
146 | data_format='NHWC',
147 | scope=None):
148 | """Move the channel axis to the last dimension. Allows to
149 | provide a single output format whatever the input data format.
150 |
151 | Args:
152 | inputs: Input Tensor;
153 | data_format: NHWC or NCHW.
154 | Return:
155 | Input in NHWC format.
156 | """
157 | with tf.name_scope(scope, 'channel_to_last', [inputs]):
158 | if data_format == 'NHWC':
159 | net = inputs
160 | elif data_format == 'NCHW':
161 | net = tf.transpose(inputs, perm=(0, 2, 3, 1))
162 | return net
163 |
--------------------------------------------------------------------------------
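Editor's note: a small TF 1.x sketch exercising the custom layers above; shapes are illustrative.

    import tensorflow as tf
    from nets import custom_layers

    # Dummy NHWC feature map: batch of 1, 4x4 spatial resolution, 3 channels.
    x = tf.placeholder(tf.float32, [1, 4, 4, 3])

    # Symmetric 1-pixel padding on H and W (mimics Caffe/MXNet padding).
    padded = custom_layers.pad2d(x, pad=(1, 1))            # shape [1, 6, 6, 3]

    # L2-normalize along the channel axis, with a learnable per-channel scale
    # (the normalization SSD applies to the conv4_3 feature map).
    normed = custom_layers.l2_normalization(x, scaling=True)

--------------------------------------------------------------------------------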
/SSD/nets/inception.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Brings inception_v1, inception_v2 and inception_v3 under one namespace."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # pylint: disable=unused-import
22 | # from nets.inception_v1 import inception_v1
23 | # from nets.inception_v1 import inception_v1_arg_scope
24 | # from nets.inception_v1 import inception_v1_base
25 | # from nets.inception_v2 import inception_v2
26 | # from nets.inception_v2 import inception_v2_arg_scope
27 | # from nets.inception_v2 import inception_v2_base
28 | # pylint: enable=unused-import
29 |
--------------------------------------------------------------------------------
/SSD/nets/nets_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models.
16 | """
17 |
18 | import functools
19 |
20 | import tensorflow as tf
21 |
22 | from nets import ssd_vgg_300
23 | from nets import ssd_vgg_512
24 | # from nets import inception
25 | # from nets import overfeat
26 | # from nets import resnet_v1
27 | # from nets import resnet_v2
28 | from nets import vgg
29 |
30 | # from nets import xception
31 |
32 | slim = tf.contrib.slim
33 |
34 | networks_map = {'vgg_a': vgg.vgg_a,
35 | 'vgg_16': vgg.vgg_16,
36 | 'vgg_19': vgg.vgg_19,
37 | 'ssd_300_vgg': ssd_vgg_300.ssd_net,
38 | 'ssd_300_vgg_caffe': ssd_vgg_300.ssd_net,
39 | 'ssd_512_vgg': ssd_vgg_512.ssd_net,
40 | 'ssd_512_vgg_caffe': ssd_vgg_512.ssd_net,
41 | }
42 |
43 | arg_scopes_map = {'vgg_a': vgg.vgg_arg_scope,
44 | 'vgg_16': vgg.vgg_arg_scope,
45 | 'vgg_19': vgg.vgg_arg_scope,
46 | 'ssd_300_vgg': ssd_vgg_300.ssd_arg_scope,
47 | 'ssd_300_vgg_caffe': ssd_vgg_300.ssd_arg_scope_caffe,
48 | 'ssd_512_vgg': ssd_vgg_512.ssd_arg_scope,
49 | 'ssd_512_vgg_caffe': ssd_vgg_512.ssd_arg_scope_caffe,
50 | }
51 |
52 | networks_obj = {'ssd_300_vgg': ssd_vgg_300.SSDNet,
53 | 'ssd_512_vgg': ssd_vgg_512.SSDNet,
54 | }
55 |
56 |
57 | def get_network(name):
58 | """Get a network object from a name.
59 | """
60 | # params = networks_obj[name].default_params if params is None else params
61 | return networks_obj[name]
62 |
63 |
64 | def get_network_fn(name, num_classes, is_training=False, **kwargs):
65 | """Returns a network_fn such as `logits, end_points = network_fn(images)`.
66 |
67 | Args:
68 | name: The name of the network.
69 | num_classes: The number of classes to use for classification.
70 | is_training: `True` if the model is being used for training and `False`
71 | otherwise.
72 | weight_decay: The l2 coefficient for the model weights.
73 | Returns:
74 | network_fn: A function that applies the model to a batch of images. It has
75 | the following signature: logits, end_points = network_fn(images)
76 | Raises:
77 | ValueError: If network `name` is not recognized.
78 | """
79 | if name not in networks_map:
80 |         raise ValueError('Unknown network name: %s' % name)
81 | arg_scope = arg_scopes_map[name](**kwargs)
82 | func = networks_map[name]
83 | @functools.wraps(func)
84 | def network_fn(images, **kwargs):
85 | with slim.arg_scope(arg_scope):
86 | return func(images, num_classes, is_training=is_training, **kwargs)
87 | if hasattr(func, 'default_image_size'):
88 | network_fn.default_image_size = func.default_image_size
89 |
90 | return network_fn
91 |
--------------------------------------------------------------------------------
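Editor's note: a sketch of the factory in use under TF 1.x. The classification branch is shown with vgg_16 for brevity; SSD detectors are normally obtained as SSDNet objects via get_network, as hinted by the commented default_params line above.

    import tensorflow as tf
    from nets import nets_factory

    # Classification network: build a network_fn from its registered name.
    network_fn = nets_factory.get_network_fn('vgg_16', num_classes=1000,
                                             is_training=False)
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    logits, end_points = network_fn(images)

    # Detection network: get the SSDNet class and instantiate it with defaults.
    ssd_class = nets_factory.get_network('ssd_300_vgg')
    ssd_net = ssd_class(ssd_class.default_params)

--------------------------------------------------------------------------------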
/SSD/nets/np_methods.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Additional Numpy methods. Big mess of many things!
16 | """
17 | import numpy as np
18 |
19 |
20 | # =========================================================================== #
21 | # Numpy implementations of SSD boxes functions.
22 | # =========================================================================== #
23 | def ssd_bboxes_decode(feat_localizations,
24 | anchor_bboxes,
25 | prior_scaling=[0.1, 0.1, 0.2, 0.2]):
26 | """Compute the relative bounding boxes from the layer features and
27 | reference anchor bounding boxes.
28 |
29 | Return:
30 | numpy array Nx4: ymin, xmin, ymax, xmax
31 | """
32 | # Reshape for easier broadcasting.
33 | l_shape = feat_localizations.shape
34 | feat_localizations = np.reshape(feat_localizations,
35 | (-1, l_shape[-2], l_shape[-1]))
36 | yref, xref, href, wref = anchor_bboxes
37 | xref = np.reshape(xref, [-1, 1])
38 | yref = np.reshape(yref, [-1, 1])
39 |
40 | # Compute center, height and width
41 | cx = feat_localizations[:, :, 0] * wref * prior_scaling[0] + xref
42 | cy = feat_localizations[:, :, 1] * href * prior_scaling[1] + yref
43 | w = wref * np.exp(feat_localizations[:, :, 2] * prior_scaling[2])
44 | h = href * np.exp(feat_localizations[:, :, 3] * prior_scaling[3])
45 |     # bboxes: ymin, xmin, ymax, xmax.
46 | bboxes = np.zeros_like(feat_localizations)
47 | bboxes[:, :, 0] = cy - h / 2.
48 | bboxes[:, :, 1] = cx - w / 2.
49 | bboxes[:, :, 2] = cy + h / 2.
50 | bboxes[:, :, 3] = cx + w / 2.
51 | # Back to original shape.
52 | bboxes = np.reshape(bboxes, l_shape)
53 | return bboxes
54 |
55 |
56 | def ssd_bboxes_select_layer(predictions_layer,
57 | localizations_layer,
58 | anchors_layer,
59 | select_threshold=0.5,
60 | img_shape=(300, 300),
61 | num_classes=21,
62 | decode=True):
63 | """Extract classes, scores and bounding boxes from features in one layer.
64 |
65 | Return:
66 | classes, scores, bboxes: Numpy arrays...
67 | """
68 | # First decode localizations features if necessary.
69 | if decode:
70 | localizations_layer = ssd_bboxes_decode(localizations_layer, anchors_layer)
71 |
72 | # Reshape features to: Batches x N x N_labels | 4.
73 | p_shape = predictions_layer.shape
74 | batch_size = p_shape[0] if len(p_shape) == 5 else 1
75 | predictions_layer = np.reshape(predictions_layer,
76 | (batch_size, -1, p_shape[-1]))
77 | l_shape = localizations_layer.shape
78 | localizations_layer = np.reshape(localizations_layer,
79 | (batch_size, -1, l_shape[-1]))
80 |
81 | # Boxes selection: use threshold or score > no-label criteria.
82 | if select_threshold is None or select_threshold == 0:
83 | # Class prediction and scores: assign 0. to 0-class
84 | classes = np.argmax(predictions_layer, axis=2)
85 | scores = np.amax(predictions_layer, axis=2)
86 | mask = (classes > 0)
87 | classes = classes[mask]
88 | scores = scores[mask]
89 | bboxes = localizations_layer[mask]
90 | else:
91 | sub_predictions = predictions_layer[:, :, 1:]
92 | idxes = np.where(sub_predictions > select_threshold)
93 | classes = idxes[-1]+1
94 | scores = sub_predictions[idxes]
95 | bboxes = localizations_layer[idxes[:-1]]
96 |
97 | return classes, scores, bboxes
98 |
99 |
100 | def ssd_bboxes_select(predictions_net,
101 | localizations_net,
102 | anchors_net,
103 | select_threshold=0.5,
104 | img_shape=(300, 300),
105 | num_classes=21,
106 | decode=True):
107 | """Extract classes, scores and bounding boxes from network output layers.
108 |
109 | Return:
110 | classes, scores, bboxes: Numpy arrays...
111 | """
112 | l_classes = []
113 | l_scores = []
114 | l_bboxes = []
115 | # l_layers = []
116 | # l_idxes = []
117 | for i in range(len(predictions_net)):
118 | classes, scores, bboxes = ssd_bboxes_select_layer(
119 | predictions_net[i], localizations_net[i], anchors_net[i],
120 | select_threshold, img_shape, num_classes, decode)
121 | l_classes.append(classes)
122 | l_scores.append(scores)
123 | l_bboxes.append(bboxes)
124 | # Debug information.
125 | # l_layers.append(i)
126 | # l_idxes.append((i, idxes))
127 |
128 | classes = np.concatenate(l_classes, 0)
129 | scores = np.concatenate(l_scores, 0)
130 | bboxes = np.concatenate(l_bboxes, 0)
131 | return classes, scores, bboxes
132 |
133 |
134 | # =========================================================================== #
135 | # Common functions for bboxes handling and selection.
136 | # =========================================================================== #
137 | def bboxes_sort(classes, scores, bboxes, top_k=400):
138 | """Sort bounding boxes by decreasing order and keep only the top_k
139 | """
140 | # if priority_inside:
141 | # inside = (bboxes[:, 0] > margin) & (bboxes[:, 1] > margin) & \
142 | # (bboxes[:, 2] < 1-margin) & (bboxes[:, 3] < 1-margin)
143 | # idxes = np.argsort(-scores)
144 | # inside = inside[idxes]
145 | # idxes = np.concatenate([idxes[inside], idxes[~inside]])
146 | idxes = np.argsort(-scores)
147 | classes = classes[idxes][:top_k]
148 | scores = scores[idxes][:top_k]
149 | bboxes = bboxes[idxes][:top_k]
150 | return classes, scores, bboxes
151 |
152 |
153 | def bboxes_clip(bbox_ref, bboxes):
154 | """Clip bounding boxes with respect to reference bbox.
155 | """
156 | bboxes = np.copy(bboxes)
157 | bboxes = np.transpose(bboxes)
158 | bbox_ref = np.transpose(bbox_ref)
159 | bboxes[0] = np.maximum(bboxes[0], bbox_ref[0])
160 | bboxes[1] = np.maximum(bboxes[1], bbox_ref[1])
161 | bboxes[2] = np.minimum(bboxes[2], bbox_ref[2])
162 | bboxes[3] = np.minimum(bboxes[3], bbox_ref[3])
163 | bboxes = np.transpose(bboxes)
164 | return bboxes
165 |
166 |
167 | def bboxes_resize(bbox_ref, bboxes):
168 | """Resize bounding boxes based on a reference bounding box,
169 | assuming that the latter is [0, 0, 1, 1] after transform.
170 | """
171 | bboxes = np.copy(bboxes)
172 | # Translate.
173 | bboxes[:, 0] -= bbox_ref[0]
174 | bboxes[:, 1] -= bbox_ref[1]
175 | bboxes[:, 2] -= bbox_ref[0]
176 | bboxes[:, 3] -= bbox_ref[1]
177 | # Resize.
178 | resize = [bbox_ref[2] - bbox_ref[0], bbox_ref[3] - bbox_ref[1]]
179 | bboxes[:, 0] /= resize[0]
180 | bboxes[:, 1] /= resize[1]
181 | bboxes[:, 2] /= resize[0]
182 | bboxes[:, 3] /= resize[1]
183 | return bboxes
184 |
185 |
186 | def bboxes_jaccard(bboxes1, bboxes2):
187 | """Computing jaccard index between bboxes1 and bboxes2.
188 | Note: bboxes1 and bboxes2 can be multi-dimensional, but should broacastable.
189 | """
190 | bboxes1 = np.transpose(bboxes1)
191 | bboxes2 = np.transpose(bboxes2)
192 | # Intersection bbox and volume.
193 | int_ymin = np.maximum(bboxes1[0], bboxes2[0])
194 | int_xmin = np.maximum(bboxes1[1], bboxes2[1])
195 | int_ymax = np.minimum(bboxes1[2], bboxes2[2])
196 | int_xmax = np.minimum(bboxes1[3], bboxes2[3])
197 |
198 | int_h = np.maximum(int_ymax - int_ymin, 0.)
199 | int_w = np.maximum(int_xmax - int_xmin, 0.)
200 | int_vol = int_h * int_w
201 | # Union volume.
202 | vol1 = (bboxes1[2] - bboxes1[0]) * (bboxes1[3] - bboxes1[1])
203 | vol2 = (bboxes2[2] - bboxes2[0]) * (bboxes2[3] - bboxes2[1])
204 | jaccard = int_vol / (vol1 + vol2 - int_vol)
205 | return jaccard
206 |
207 |
208 | def bboxes_intersection(bboxes_ref, bboxes2):
209 | """Computing jaccard index between bboxes1 and bboxes2.
210 | Note: bboxes1 and bboxes2 can be multi-dimensional, but should broacastable.
211 | """
212 | bboxes_ref = np.transpose(bboxes_ref)
213 | bboxes2 = np.transpose(bboxes2)
214 | # Intersection bbox and volume.
215 | int_ymin = np.maximum(bboxes_ref[0], bboxes2[0])
216 | int_xmin = np.maximum(bboxes_ref[1], bboxes2[1])
217 | int_ymax = np.minimum(bboxes_ref[2], bboxes2[2])
218 | int_xmax = np.minimum(bboxes_ref[3], bboxes2[3])
219 |
220 | int_h = np.maximum(int_ymax - int_ymin, 0.)
221 | int_w = np.maximum(int_xmax - int_xmin, 0.)
222 | int_vol = int_h * int_w
223 | # Union volume.
224 | vol = (bboxes_ref[2] - bboxes_ref[0]) * (bboxes_ref[3] - bboxes_ref[1])
225 | score = int_vol / vol
226 | return score
227 |
228 |
229 | def bboxes_nms(classes, scores, bboxes, nms_threshold=0.45):
230 | """Apply non-maximum selection to bounding boxes.
231 | """
232 |     keep_bboxes = np.ones(scores.shape, dtype=bool)  # np.bool is deprecated; use builtin bool.
233 | for i in range(scores.size-1):
234 | if keep_bboxes[i]:
235 |             # Compute overlap with the following bboxes.
236 | overlap = bboxes_jaccard(bboxes[i], bboxes[(i+1):])
237 | # Overlap threshold for keeping + checking part of the same class
238 | keep_overlap = np.logical_or(overlap < nms_threshold, classes[(i+1):] != classes[i])
239 | keep_bboxes[(i+1):] = np.logical_and(keep_bboxes[(i+1):], keep_overlap)
240 |
241 | idxes = np.where(keep_bboxes)
242 | return classes[idxes], scores[idxes], bboxes[idxes]
243 |
244 |
245 | def bboxes_nms_fast(classes, scores, bboxes, threshold=0.45):
246 | """Apply non-maximum selection to bounding boxes.
247 | """
248 | pass
249 |
250 |
251 |
252 |
253 |
--------------------------------------------------------------------------------
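Editor's note: a pure-NumPy sketch of the post-processing chain these helpers implement (sort by score, clip to the image, non-maximum suppression) on dummy detections; thresholds mirror the defaults above.

    import numpy as np
    from nets import np_methods

    # Two overlapping detections of class 1 plus one separate box of class 2,
    # as (ymin, xmin, ymax, xmax) in normalized coordinates.
    classes = np.array([1, 1, 2])
    scores = np.array([0.9, 0.6, 0.8])
    bboxes = np.array([[0.10, 0.10, 0.50, 0.50],
                       [0.12, 0.12, 0.52, 0.52],
                       [0.60, 0.60, 0.90, 0.90]])

    classes, scores, bboxes = np_methods.bboxes_sort(classes, scores, bboxes, top_k=400)
    bboxes = np_methods.bboxes_clip(np.array([0., 0., 1., 1.]), bboxes)
    classes, scores, bboxes = np_methods.bboxes_nms(classes, scores, bboxes,
                                                    nms_threshold=0.45)
    # The second class-1 box overlaps the first with IoU > 0.45 and is suppressed.

--------------------------------------------------------------------------------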
/SSD/nets/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SSD/nets/vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains model definitions for versions of the Oxford VGG network.
16 |
17 | These model definitions were introduced in the following technical report:
18 |
19 | Very Deep Convolutional Networks For Large-Scale Image Recognition
20 | Karen Simonyan and Andrew Zisserman
21 | arXiv technical report, 2015
22 | PDF: http://arxiv.org/pdf/1409.1556.pdf
23 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
24 | CC-BY-4.0
25 |
26 | More information can be obtained from the VGG website:
27 | www.robots.ox.ac.uk/~vgg/research/very_deep/
28 |
29 | Usage:
30 | with slim.arg_scope(vgg.vgg_arg_scope()):
31 | outputs, end_points = vgg.vgg_a(inputs)
32 |
33 | with slim.arg_scope(vgg.vgg_arg_scope()):
34 | outputs, end_points = vgg.vgg_16(inputs)
35 |
36 | @@vgg_a
37 | @@vgg_16
38 | @@vgg_19
39 | """
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 |
44 | import tensorflow as tf
45 |
46 | slim = tf.contrib.slim
47 |
48 |
49 | def vgg_arg_scope(weight_decay=0.0005):
50 | """Defines the VGG arg scope.
51 |
52 | Args:
53 | weight_decay: The l2 regularization coefficient.
54 |
55 | Returns:
56 | An arg_scope.
57 | """
58 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
59 | activation_fn=tf.nn.relu,
60 | weights_regularizer=slim.l2_regularizer(weight_decay),
61 | biases_initializer=tf.zeros_initializer):
62 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
63 | return arg_sc
64 |
65 |
66 | def vgg_a(inputs,
67 | num_classes=1000,
68 | is_training=True,
69 | dropout_keep_prob=0.5,
70 | spatial_squeeze=True,
71 | scope='vgg_a'):
72 | """Oxford Net VGG 11-Layers version A Example.
73 |
74 | Note: All the fully_connected layers have been transformed to conv2d layers.
75 | To use in classification mode, resize input to 224x224.
76 |
77 | Args:
78 | inputs: a tensor of size [batch_size, height, width, channels].
79 | num_classes: number of predicted classes.
80 | is_training: whether or not the model is being trained.
81 | dropout_keep_prob: the probability that activations are kept in the dropout
82 | layers during training.
83 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
84 | outputs. Useful to remove unnecessary dimensions for classification.
85 | scope: Optional scope for the variables.
86 |
87 | Returns:
88 | the last op containing the log predictions and end_points dict.
89 | """
90 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
91 | end_points_collection = sc.name + '_end_points'
92 | # Collect outputs for conv2d, fully_connected and max_pool2d.
93 | with slim.arg_scope([slim.conv2d, slim.max_pool2d],
94 | outputs_collections=end_points_collection):
95 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
96 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
97 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
98 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
99 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
100 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
101 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
102 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
103 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
104 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
105 | # Use conv2d instead of fully_connected layers.
106 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
107 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
108 | scope='dropout6')
109 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
111 | scope='dropout7')
112 | net = slim.conv2d(net, num_classes, [1, 1],
113 | activation_fn=None,
114 | normalizer_fn=None,
115 | scope='fc8')
116 | # Convert end_points_collection into a end_point dict.
117 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
118 | if spatial_squeeze:
119 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
120 | end_points[sc.name + '/fc8'] = net
121 | return net, end_points
122 | vgg_a.default_image_size = 224
123 |
124 |
125 | def vgg_16(inputs,
126 | num_classes=1000,
127 | is_training=True,
128 | dropout_keep_prob=0.5,
129 | spatial_squeeze=True,
130 | scope='vgg_16'):
131 | """Oxford Net VGG 16-Layers version D Example.
132 |
133 | Note: All the fully_connected layers have been transformed to conv2d layers.
134 | To use in classification mode, resize input to 224x224.
135 |
136 | Args:
137 | inputs: a tensor of size [batch_size, height, width, channels].
138 | num_classes: number of predicted classes.
139 | is_training: whether or not the model is being trained.
140 | dropout_keep_prob: the probability that activations are kept in the dropout
141 | layers during training.
142 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
143 | outputs. Useful to remove unnecessary dimensions for classification.
144 | scope: Optional scope for the variables.
145 |
146 | Returns:
147 | the last op containing the log predictions and end_points dict.
148 | """
149 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
150 | end_points_collection = sc.name + '_end_points'
151 | # Collect outputs for conv2d, fully_connected and max_pool2d.
152 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
153 | outputs_collections=end_points_collection):
154 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
155 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
156 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
157 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
158 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
159 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
160 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
161 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
162 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
163 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
164 | # Use conv2d instead of fully_connected layers.
165 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
166 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
167 | scope='dropout6')
168 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
169 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
170 | scope='dropout7')
171 | net = slim.conv2d(net, num_classes, [1, 1],
172 | activation_fn=None,
173 | normalizer_fn=None,
174 | scope='fc8')
175 | # Convert end_points_collection into a end_point dict.
176 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
177 | if spatial_squeeze:
178 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
179 | end_points[sc.name + '/fc8'] = net
180 | return net, end_points
181 | vgg_16.default_image_size = 224
182 |
183 |
184 | def vgg_19(inputs,
185 | num_classes=1000,
186 | is_training=True,
187 | dropout_keep_prob=0.5,
188 | spatial_squeeze=True,
189 | scope='vgg_19'):
190 | """Oxford Net VGG 19-Layers version E Example.
191 |
192 | Note: All the fully_connected layers have been transformed to conv2d layers.
193 | To use in classification mode, resize input to 224x224.
194 |
195 | Args:
196 | inputs: a tensor of size [batch_size, height, width, channels].
197 | num_classes: number of predicted classes.
198 | is_training: whether or not the model is being trained.
199 | dropout_keep_prob: the probability that activations are kept in the dropout
200 | layers during training.
201 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
202 | outputs. Useful to remove unnecessary dimensions for classification.
203 | scope: Optional scope for the variables.
204 |
205 | Returns:
206 | the last op containing the log predictions and end_points dict.
207 | """
208 | with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc:
209 | end_points_collection = sc.name + '_end_points'
210 | # Collect outputs for conv2d, fully_connected and max_pool2d.
211 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
212 | outputs_collections=end_points_collection):
213 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
214 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
215 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
216 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
217 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3')
218 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
219 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4')
220 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
221 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5')
222 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
223 | # Use conv2d instead of fully_connected layers.
224 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
225 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
226 | scope='dropout6')
227 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
228 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
229 | scope='dropout7')
230 | net = slim.conv2d(net, num_classes, [1, 1],
231 | activation_fn=None,
232 | normalizer_fn=None,
233 | scope='fc8')
234 | # Convert end_points_collection into a end_point dict.
235 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
236 | if spatial_squeeze:
237 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
238 | end_points[sc.name + '/fc8'] = net
239 | return net, end_points
240 | vgg_19.default_image_size = 224
241 |
242 | # Alias
243 | vgg_d = vgg_16
244 | vgg_e = vgg_19
245 |
--------------------------------------------------------------------------------
/SSD/nets/xception.py:
--------------------------------------------------------------------------------
1 | """Definition of Xception model introduced by F. Chollet.
2 |
3 | Usage:
4 | with slim.arg_scope(xception.xception_arg_scope()):
5 | outputs, end_points = xception.xception(inputs)
6 | @@xception
7 | """
8 |
9 | import tensorflow as tf
10 | slim = tf.contrib.slim
11 |
12 |
13 | # =========================================================================== #
14 | # Xception implementation (clean)
15 | # =========================================================================== #
16 | def xception(inputs,
17 | num_classes=1000,
18 | is_training=True,
19 | dropout_keep_prob=0.5,
20 | prediction_fn=slim.softmax,
21 | reuse=None,
22 | scope='xception'):
23 | """Xception model from https://arxiv.org/pdf/1610.02357v2.pdf
24 |
25 | The default image size used to train this network is 299x299.
26 | """
27 |
28 | # end_points collect relevant activations for external use, for example
29 | # summaries or losses.
30 | end_points = {}
31 |
32 | with tf.variable_scope(scope, 'xception', [inputs]):
33 | # Block 1.
34 | end_point = 'block1'
35 | with tf.variable_scope(end_point):
36 | net = slim.conv2d(inputs, 32, [3, 3], stride=2, padding='VALID', scope='conv1')
37 | net = slim.conv2d(net, 64, [3, 3], padding='VALID', scope='conv2')
38 | end_points[end_point] = net
39 |
40 | # Residual block 2.
41 | end_point = 'block2'
42 | with tf.variable_scope(end_point):
43 | res = slim.conv2d(net, 128, [1, 1], stride=2, activation_fn=None, scope='res')
44 | net = slim.separable_convolution2d(net, 128, [3, 3], 1, scope='sepconv1')
45 | net = slim.separable_convolution2d(net, 128, [3, 3], 1, activation_fn=None, scope='sepconv2')
46 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool')
47 | net = res + net
48 | end_points[end_point] = net
49 |
50 | # Residual block 3.
51 | end_point = 'block3'
52 | with tf.variable_scope(end_point):
53 | res = slim.conv2d(net, 256, [1, 1], stride=2, activation_fn=None, scope='res')
54 | net = tf.nn.relu(net)
55 | net = slim.separable_convolution2d(net, 256, [3, 3], 1, scope='sepconv1')
56 | net = slim.separable_convolution2d(net, 256, [3, 3], 1, activation_fn=None, scope='sepconv2')
57 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool')
58 | net = res + net
59 | end_points[end_point] = net
60 |
61 | # Residual block 4.
62 | end_point = 'block4'
63 | with tf.variable_scope(end_point):
64 | res = slim.conv2d(net, 728, [1, 1], stride=2, activation_fn=None, scope='res')
65 | net = tf.nn.relu(net)
66 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, scope='sepconv1')
67 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, scope='sepconv2')
68 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool')
69 | net = res + net
70 | end_points[end_point] = net
71 |
72 | # Middle flow blocks.
73 | for i in range(8):
74 | end_point = 'block' + str(i + 5)
75 | with tf.variable_scope(end_point):
76 | res = net
77 | net = tf.nn.relu(net)
78 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None,
79 | scope='sepconv1')
80 | net = tf.nn.relu(net)
81 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None,
82 | scope='sepconv2')
83 | net = tf.nn.relu(net)
84 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None,
85 | scope='sepconv3')
86 | net = res + net
87 | end_points[end_point] = net
88 |
89 | # Exit flow: blocks 13 and 14.
90 | end_point = 'block13'
91 | with tf.variable_scope(end_point):
92 | res = slim.conv2d(net, 1024, [1, 1], stride=2, activation_fn=None, scope='res')
93 | net = tf.nn.relu(net)
94 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, scope='sepconv1')
95 | net = tf.nn.relu(net)
96 | net = slim.separable_convolution2d(net, 1024, [3, 3], 1, activation_fn=None, scope='sepconv2')
97 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool')
98 | net = res + net
99 | end_points[end_point] = net
100 |
101 | end_point = 'block14'
102 | with tf.variable_scope(end_point):
103 | net = slim.separable_convolution2d(net, 1536, [3, 3], 1, scope='sepconv1')
104 | net = slim.separable_convolution2d(net, 2048, [3, 3], 1, scope='sepconv2')
105 | end_points[end_point] = net
106 |
107 | # Global averaging.
108 | end_point = 'dense'
109 | with tf.variable_scope(end_point):
110 | net = tf.reduce_mean(net, [1, 2], name='reduce_avg')
111 | logits = slim.fully_connected(net, 1000, activation_fn=None)
112 |
113 | end_points['logits'] = logits
114 | end_points['predictions'] = prediction_fn(logits, scope='Predictions')
115 |
116 | return logits, end_points
117 | xception.default_image_size = 299
118 |
119 |
120 | def xception_arg_scope(weight_decay=0.00001, stddev=0.1):
121 | """Defines the default Xception arg scope.
122 |
123 | Args:
124 | weight_decay: The weight decay to use for regularizing the model.
125 |       stddev: The standard deviation of the truncated normal weight initializer.
126 |
127 | Returns:
128 | An `arg_scope` to use for the xception model.
129 | """
130 | batch_norm_params = {
131 | # Decay for the moving averages.
132 | 'decay': 0.9997,
133 | # epsilon to prevent 0s in variance.
134 | 'epsilon': 0.001,
135 | # collection containing update_ops.
136 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
137 | }
138 |
139 | # Set weight_decay for weights in Conv and FC layers.
140 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.separable_convolution2d],
141 | weights_regularizer=slim.l2_regularizer(weight_decay)):
142 | with slim.arg_scope(
143 | [slim.conv2d, slim.separable_convolution2d],
144 | padding='SAME',
145 | weights_initializer=tf.contrib.layers.variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False),
146 | activation_fn=tf.nn.relu,
147 | normalizer_fn=slim.batch_norm,
148 | normalizer_params=batch_norm_params):
149 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as sc:
150 | return sc
151 |
152 |
153 | # =========================================================================== #
154 | # Xception arg scope (Keras hack!)
155 | # =========================================================================== #
156 | def xception_keras_arg_scope(hdf5_file, weight_decay=0.00001):
157 | """Defines an Xception arg scope which initialize layers weights
158 | using a Keras HDF5 file.
159 |
160 |     Quite hacky implementation, but it seems to work!
161 |
162 | Args:
163 | hdf5_file: HDF5 file handle.
164 | weight_decay: The weight decay to use for regularizing the model.
165 |
166 | Returns:
167 | An `arg_scope` to use for the xception model.
168 | """
169 | # Default batch normalization parameters.
170 | batch_norm_params = {
171 | 'center': True,
172 | 'scale': False,
173 | 'decay': 0.9997,
174 | 'epsilon': 0.001,
175 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
176 | }
177 |
178 | # Read weights from HDF5 file.
179 | def keras_bn_params():
180 | def _beta_initializer(shape, dtype, partition_info=None):
181 | keras_bn_params.bidx += 1
182 | k = 'batchnormalization_%i' % keras_bn_params.bidx
183 | kb = 'batchnormalization_%i_beta:0' % keras_bn_params.bidx
184 | return tf.cast(hdf5_file[k][kb][:], dtype)
185 |
186 | def _gamma_initializer(shape, dtype, partition_info=None):
187 | keras_bn_params.gidx += 1
188 | k = 'batchnormalization_%i' % keras_bn_params.gidx
189 | kg = 'batchnormalization_%i_gamma:0' % keras_bn_params.gidx
190 | return tf.cast(hdf5_file[k][kg][:], dtype)
191 |
192 | def _mean_initializer(shape, dtype, partition_info=None):
193 | keras_bn_params.midx += 1
194 | k = 'batchnormalization_%i' % keras_bn_params.midx
195 | km = 'batchnormalization_%i_running_mean:0' % keras_bn_params.midx
196 | return tf.cast(hdf5_file[k][km][:], dtype)
197 |
198 | def _variance_initializer(shape, dtype, partition_info=None):
199 | keras_bn_params.vidx += 1
200 | k = 'batchnormalization_%i' % keras_bn_params.vidx
201 | kv = 'batchnormalization_%i_running_std:0' % keras_bn_params.vidx
202 | return tf.cast(hdf5_file[k][kv][:], dtype)
203 |
204 | # Batch normalisation initializers.
205 | params = batch_norm_params.copy()
206 | params['initializers'] = {
207 | 'beta': _beta_initializer,
208 | 'gamma': _gamma_initializer,
209 | 'moving_mean': _mean_initializer,
210 | 'moving_variance': _variance_initializer,
211 | }
212 | return params
213 | keras_bn_params.bidx = 0
214 | keras_bn_params.gidx = 0
215 | keras_bn_params.midx = 0
216 | keras_bn_params.vidx = 0
217 |
218 | def keras_conv2d_weights():
219 | def _initializer(shape, dtype, partition_info=None):
220 | keras_conv2d_weights.idx += 1
221 | k = 'convolution2d_%i' % keras_conv2d_weights.idx
222 | kw = 'convolution2d_%i_W:0' % keras_conv2d_weights.idx
223 | return tf.cast(hdf5_file[k][kw][:], dtype)
224 | return _initializer
225 | keras_conv2d_weights.idx = 0
226 |
227 | def keras_sep_conv2d_weights():
228 | def _initializer(shape, dtype, partition_info=None):
229 | # Depthwise or Pointwise convolution?
230 | if shape[0] > 1 or shape[1] > 1:
231 | keras_sep_conv2d_weights.didx += 1
232 | k = 'separableconvolution2d_%i' % keras_sep_conv2d_weights.didx
233 | kd = 'separableconvolution2d_%i_depthwise_kernel:0' % keras_sep_conv2d_weights.didx
234 | weights = hdf5_file[k][kd][:]
235 | else:
236 | keras_sep_conv2d_weights.pidx += 1
237 | k = 'separableconvolution2d_%i' % keras_sep_conv2d_weights.pidx
238 | kp = 'separableconvolution2d_%i_pointwise_kernel:0' % keras_sep_conv2d_weights.pidx
239 | weights = hdf5_file[k][kp][:]
240 | return tf.cast(weights, dtype)
241 | return _initializer
242 | keras_sep_conv2d_weights.didx = 0
243 | keras_sep_conv2d_weights.pidx = 0
244 |
245 | def keras_dense_weights():
246 | def _initializer(shape, dtype, partition_info=None):
247 | keras_dense_weights.idx += 1
248 | k = 'dense_%i' % keras_dense_weights.idx
249 | kw = 'dense_%i_W:0' % keras_dense_weights.idx
250 | return tf.cast(hdf5_file[k][kw][:], dtype)
251 | return _initializer
252 | keras_dense_weights.idx = 1
253 |
254 | def keras_dense_biases():
255 | def _initializer(shape, dtype, partition_info=None):
256 | keras_dense_biases.idx += 1
257 | k = 'dense_%i' % keras_dense_biases.idx
258 | kb = 'dense_%i_b:0' % keras_dense_biases.idx
259 | return tf.cast(hdf5_file[k][kb][:], dtype)
260 | return _initializer
261 | keras_dense_biases.idx = 1
262 |
263 | # Default network arg scope.
264 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.separable_convolution2d],
265 | weights_regularizer=slim.l2_regularizer(weight_decay)):
266 | with slim.arg_scope(
267 | [slim.conv2d, slim.separable_convolution2d],
268 | padding='SAME',
269 | activation_fn=tf.nn.relu,
270 | normalizer_fn=slim.batch_norm,
271 | normalizer_params=keras_bn_params()):
272 | with slim.arg_scope([slim.max_pool2d], padding='SAME'):
273 |
274 | # Weights initializers from Keras weights.
275 | with slim.arg_scope([slim.conv2d],
276 | weights_initializer=keras_conv2d_weights()):
277 | with slim.arg_scope([slim.separable_convolution2d],
278 | weights_initializer=keras_sep_conv2d_weights()):
279 | with slim.arg_scope([slim.fully_connected],
280 | weights_initializer=keras_dense_weights(),
281 | biases_initializer=keras_dense_biases()) as sc:
282 | return sc
283 |
284 |
--------------------------------------------------------------------------------
/SSD/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SSD/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from preprocessing import ssd_vgg_preprocessing
24 |
25 | # from preprocessing import cifarnet_preprocessing
26 | # from preprocessing import inception_preprocessing
27 | # from preprocessing import vgg_preprocessing
28 |
29 | slim = tf.contrib.slim
30 |
31 |
32 | def get_preprocessing(name, is_training=False):
33 | """Returns preprocessing_fn(image, height, width, **kwargs).
34 |
35 | Args:
36 | name: The name of the preprocessing function.
37 | is_training: `True` if the model is being used for training.
38 |
39 | Returns:
40 |     preprocessing_fn: A function that preprocesses a single image (before batching).
41 |       It has the following signature:
42 |         outputs = preprocessing_fn(image, labels, bboxes, out_shape, ...).
43 |
44 | Raises:
45 | ValueError: If Preprocessing `name` is not recognized.
46 | """
47 | preprocessing_fn_map = {
48 | 'ssd_300_vgg': ssd_vgg_preprocessing,
49 | 'ssd_512_vgg': ssd_vgg_preprocessing,
50 | }
51 |
52 | if name not in preprocessing_fn_map:
53 | raise ValueError('Preprocessing name [%s] was not recognized' % name)
54 |
55 | def preprocessing_fn(image, labels, bboxes,
56 | out_shape, data_format='NHWC', **kwargs):
57 | return preprocessing_fn_map[name].preprocess_image(
58 | image, labels, bboxes, out_shape, data_format=data_format,
59 | is_training=is_training, **kwargs)
60 | return preprocessing_fn
61 |
--------------------------------------------------------------------------------
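Editor's note: a sketch of consuming the preprocessing factory under TF 1.x; the input placeholder and output shape are illustrative, and the exact return tuple depends on the train/eval preprocessing path.

    import tensorflow as tf
    from preprocessing import preprocessing_factory

    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        'ssd_300_vgg', is_training=False)

    image = tf.placeholder(tf.uint8, [None, None, 3])
    # For evaluation, labels and bboxes may be passed as None; the SSD 300
    # network expects a 300x300 input.
    outputs = image_preprocessing_fn(image, None, None, out_shape=(300, 300))

--------------------------------------------------------------------------------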
/SSD/preprocessing/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SSD/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/SSD/tf_convert_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Convert a dataset to TFRecords format, which can be easily integrated into
16 | a TensorFlow pipeline.
17 |
18 | Usage:
19 | ```shell
20 | python tf_convert_data.py \
21 | --dataset_name=pascalvoc \
22 | --dataset_dir=/tmp/pascalvoc \
23 | --output_name=pascalvoc \
24 | --output_dir=/tmp/
25 | ```
26 | """
27 | import tensorflow as tf
28 |
29 | from datasets import pascalvoc_to_tfrecords
30 |
31 | FLAGS = tf.app.flags.FLAGS
32 |
33 | tf.app.flags.DEFINE_string(
34 | 'dataset_name', 'pascalvoc',
35 | 'The name of the dataset to convert.')
36 | tf.app.flags.DEFINE_string(
37 | 'dataset_dir', None,
38 | 'Directory where the original dataset is stored.')
39 | tf.app.flags.DEFINE_string(
40 | 'output_name', 'pascalvoc',
41 | 'Basename used for TFRecords output files.')
42 | tf.app.flags.DEFINE_string(
43 | 'output_dir', './',
44 | 'Output directory where to store TFRecords files.')
45 |
46 |
47 | def main(_):
48 | if not FLAGS.dataset_dir:
49 | raise ValueError('You must supply the dataset directory with --dataset_dir')
50 | print('Dataset directory:', FLAGS.dataset_dir)
51 | print('Output directory:', FLAGS.output_dir)
52 |
53 | if FLAGS.dataset_name == 'pascalvoc':
54 | pascalvoc_to_tfrecords.run(FLAGS.dataset_dir, FLAGS.output_dir, FLAGS.output_name)
55 | else:
56 | raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name)
57 |
58 | if __name__ == '__main__':
59 | tf.app.run()
60 |
61 |
--------------------------------------------------------------------------------
/SSD/tf_extended/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional metrics.
16 | """
17 |
18 | from tf_extended.bboxes import *
19 | from tf_extended.image import *
20 | from tf_extended.math import *
21 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import
22 | from tf_extended.metrics import *
23 | from tf_extended.tensors import *
24 |
25 |
--------------------------------------------------------------------------------
/SSD/tf_extended/image.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/SSD/tf_extended/image.py
--------------------------------------------------------------------------------
/SSD/tf_extended/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional math functions.
16 | """
17 | import tensorflow as tf
18 | from tensorflow.python.framework import ops
19 | from tensorflow.python.ops import math_ops
20 |
21 |
22 | def safe_divide(numerator, denominator, name):
23 | """Divides two values, returning 0 if the denominator is <= 0.
24 | Args:
25 | numerator: A real `Tensor`.
26 | denominator: A real `Tensor`, with dtype matching `numerator`.
27 | name: Name for the returned op.
28 | Returns:
29 | 0 if `denominator` <= 0, else `numerator` / `denominator`
30 | """
31 | return tf.where(
32 | math_ops.greater(denominator, 0),
33 | math_ops.divide(numerator, denominator),
34 | tf.zeros_like(numerator),
35 | name=name)
36 |
37 |
38 | def cummax(x, reverse=False, name=None):
39 | """Compute the cumulative maximum of the tensor `x` along `axis`. This
40 | operation is similar to the more classic `cumsum`. Only support 1D Tensor
41 | for now.
42 |
43 | Args:
44 | x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
45 | `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
46 | `complex128`, `qint8`, `quint8`, `qint32`, `half`.
47 | axis: Not supported yet; the scan always runs along axis 0.
48 | reverse: A `bool` (default: False).
49 | name: A name for the operation (optional).
50 | Returns:
51 | A `Tensor`. Has the same type as `x`.
52 | """
53 | with ops.name_scope(name, "Cummax", [x]) as name:
54 | x = ops.convert_to_tensor(x, name="x")
55 | # Not very optimal: should directly integrate reverse into tf.scan.
56 | if reverse:
57 | x = tf.reverse(x, axis=[0])
58 | # 'Accumulating' maximum: ensure it is always increasing.
59 | cmax = tf.scan(lambda a, y: tf.maximum(a, y), x,
60 | initializer=None, parallel_iterations=1,
61 | back_prop=False, swap_memory=False)
62 | if reverse:
63 | cmax = tf.reverse(cmax, axis=[0])
64 | return cmax
65 |
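As a quick illustration (assuming TensorFlow 1.x; the input values are made up), `cummax` behaves like a running maximum over a 1-D tensor, optionally scanned from the end:

import tensorflow as tf

x = tf.constant([1.0, 3.0, 2.0, 5.0, 4.0])
fwd = cummax(x)                 # -> [1., 3., 3., 5., 5.]
bwd = cummax(x, reverse=True)   # -> [5., 5., 5., 5., 4.]
with tf.Session() as sess:
    print(sess.run([fwd, bwd]))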
--------------------------------------------------------------------------------
/SSD/tf_extended/tensors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional tensors operations.
16 | """
17 | import tensorflow as tf
18 |
19 |
20 | def get_shape(x, rank=None):
21 | """Returns the dimensions of a Tensor as list of integers or scale tensors.
22 |
23 | Args:
24 | x: N-d Tensor;
25 | rank: Rank of the Tensor. If None, will try to guess it.
26 | Returns:
27 | A list of `[d1, d2, ..., dN]` corresponding to the dimensions of the
28 | input tensor. Dimensions that are statically known are python integers,
29 | otherwise they are integer scalar tensors.
30 | """
31 | if x.get_shape().is_fully_defined():
32 | return x.get_shape().as_list()
33 | else:
34 | static_shape = x.get_shape()
35 | if rank is None:
36 | static_shape = static_shape.as_list()
37 | rank = len(static_shape)
38 | else:
39 | static_shape = x.get_shape().with_rank(rank).as_list()
40 | dynamic_shape = tf.unstack(tf.shape(x), rank)
41 | return [s if s is not None else d
42 | for s, d in zip(static_shape, dynamic_shape)]
43 |
44 |
45 | def pad_axis(x, offset, size, axis=0, name=None):
46 | """Pad a tensor on an axis, with a given offset and output size.
47 | The tensor is padded with zero (i.e. CONSTANT mode). Note that if `size` is
48 | smaller than the existing size + `offset`, the output tensor keeps the
49 | latter, larger dimension.
50 |
51 | Args:
52 | x: Tensor to pad;
53 | offset: Offset to add on the dimension chosen;
54 | size: Final size of the dimension.
55 | Return:
56 | Padded tensor whose dimension on `axis` is `size`, or greater if
57 | the input vector was larger.
58 | """
59 | with tf.name_scope(name, 'pad_axis'):
60 | shape = get_shape(x)
61 | rank = len(shape)
62 | # Padding description.
63 | new_size = tf.maximum(size-offset-shape[axis], 0)
64 | pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1))
65 | pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1))
66 | paddings = tf.stack([pad1, pad2], axis=1)
67 | x = tf.pad(x, paddings, mode='CONSTANT')
68 | # Reshape, to get fully defined shape if possible.
69 | # TODO: fix with tf.slice
70 | shape[axis] = size
71 | x = tf.reshape(x, tf.stack(shape))
72 | return x
73 |
74 |
75 | # def select_at_index(idx, val, t):
76 | # """Return a tensor.
77 | # """
78 | # idx = tf.expand_dims(tf.expand_dims(idx, 0), 0)
79 | # val = tf.expand_dims(val, 0)
80 | # t = t + tf.scatter_nd(idx, val, tf.shape(t))
81 | # return t
82 |
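A small sketch of how these two helpers combine (assuming TensorFlow 1.x; shapes and values are made up): `get_shape` mixes static and dynamic dimensions, and `pad_axis` zero-pads a variable-length detection list to a fixed size.

import tensorflow as tf

boxes = tf.placeholder(tf.float32, shape=(None, 4))      # N boxes, N unknown
print(get_shape(boxes, rank=2))                          # [<scalar Tensor N>, 4]

padded = pad_axis(boxes, offset=0, size=100, axis=0)     # zero-pad rows up to 100
with tf.Session() as sess:
    out = sess.run(padded, feed_dict={boxes: [[0.1, 0.1, 0.5, 0.5]]})
    print(out.shape)                                      # (100, 4)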
--------------------------------------------------------------------------------
/SSD/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Diverse TensorFlow utils, for training, evaluation and so on!
16 | """
17 | import os
18 | from pprint import pprint
19 |
20 | import tensorflow as tf
21 | from tensorflow.contrib.slim.python.slim.data import parallel_reader
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | # =========================================================================== #
27 | # General tools.
28 | # =========================================================================== #
29 | def reshape_list(l, shape=None):
30 | """Reshape list of (list): 1D to 2D or the other way around.
31 |
32 | Args:
33 | l: List or List of list.
34 | shape: 1D or 2D shape.
35 | Return
36 | Reshaped list.
37 | """
38 | r = []
39 | if shape is None:
40 | # Flatten everything.
41 | for a in l:
42 | if isinstance(a, (list, tuple)):
43 | r = r + list(a)
44 | else:
45 | r.append(a)
46 | else:
47 | # Reshape to list of list.
48 | i = 0
49 | for s in shape:
50 | if s == 1:
51 | r.append(l[i])
52 | else:
53 | r.append(l[i:i+s])
54 | i += s
55 | return r
56 |
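# Illustrative example (not part of the original file; values are made up):
#   nested = [[1, 2], 3, [4, 5, 6]]
#   reshape_list(nested)                      # -> [1, 2, 3, 4, 5, 6]
#   reshape_list([1, 2, 3, 4, 5, 6], [2, 1, 3])  # -> [[1, 2], 3, [4, 5, 6]]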
57 |
58 | # =========================================================================== #
59 | # Training utils.
60 | # =========================================================================== #
61 | def print_configuration(flags, ssd_params, data_sources, save_dir=None):
62 | """Print the training configuration.
63 | """
64 | def print_config(stream=None):
65 | print('\n# =========================================================================== #', file=stream)
66 | print('# Training | Evaluation flags:', file=stream)
67 | print('# =========================================================================== #', file=stream)
68 | pprint(flags, stream=stream)
69 |
70 | print('\n# =========================================================================== #', file=stream)
71 | print('# SSD net parameters:', file=stream)
72 | print('# =========================================================================== #', file=stream)
73 | pprint(dict(ssd_params._asdict()), stream=stream)
74 |
75 | print('\n# =========================================================================== #', file=stream)
76 | print('# Training | Evaluation dataset files:', file=stream)
77 | print('# =========================================================================== #', file=stream)
78 | data_files = parallel_reader.get_data_files(data_sources)
79 | pprint(sorted(data_files), stream=stream)
80 | print('', file=stream)
81 |
82 | print_config(None)
83 | # Save to a text file as well.
84 | if save_dir is not None:
85 | if not os.path.exists(save_dir):
86 | os.makedirs(save_dir)
87 | path = os.path.join(save_dir, 'training_config.txt')
88 | with open(path, "w") as out:
89 | print_config(out)
90 |
91 |
92 | def configure_learning_rate(flags, num_samples_per_epoch, global_step):
93 | """Configures the learning rate.
94 |
95 | Args:
96 | num_samples_per_epoch: The number of samples in each epoch of training.
97 | global_step: The global_step tensor.
98 | Returns:
99 | A `Tensor` representing the learning rate.
100 | """
101 | decay_steps = int(num_samples_per_epoch / flags.batch_size *
102 | flags.num_epochs_per_decay)
103 |
104 | if flags.learning_rate_decay_type == 'exponential':
105 | return tf.train.exponential_decay(flags.learning_rate,
106 | global_step,
107 | decay_steps,
108 | flags.learning_rate_decay_factor,
109 | staircase=True,
110 | name='exponential_decay_learning_rate')
111 | elif flags.learning_rate_decay_type == 'fixed':
112 | return tf.constant(flags.learning_rate, name='fixed_learning_rate')
113 | elif flags.learning_rate_decay_type == 'polynomial':
114 | return tf.train.polynomial_decay(flags.learning_rate,
115 | global_step,
116 | decay_steps,
117 | flags.end_learning_rate,
118 | power=1.0,
119 | cycle=False,
120 | name='polynomial_decay_learning_rate')
121 | else:
122 | raise ValueError('learning_rate_decay_type [%s] was not recognized' %
123 | flags.learning_rate_decay_type)
124 |
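# Worked example (hypothetical flag values, not taken from this repository):
#   num_samples_per_epoch = 5000, flags.batch_size = 32, flags.num_epochs_per_decay = 2.0
#   decay_steps = int(5000 / 32 * 2.0) = int(312.5) = 312
# so with 'exponential' decay (staircase=True) the learning rate is multiplied by
# flags.learning_rate_decay_factor once every 312 global steps.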
125 |
126 | def configure_optimizer(flags, learning_rate):
127 | """Configures the optimizer used for training.
128 |
129 | Args:
130 | learning_rate: A scalar or `Tensor` learning rate.
131 | Returns:
132 | An instance of an optimizer.
133 | """
134 | if flags.optimizer == 'adadelta':
135 | optimizer = tf.train.AdadeltaOptimizer(
136 | learning_rate,
137 | rho=flags.adadelta_rho,
138 | epsilon=flags.opt_epsilon)
139 | elif flags.optimizer == 'adagrad':
140 | optimizer = tf.train.AdagradOptimizer(
141 | learning_rate,
142 | initial_accumulator_value=flags.adagrad_initial_accumulator_value)
143 | elif flags.optimizer == 'adam':
144 | optimizer = tf.train.AdamOptimizer(
145 | learning_rate,
146 | beta1=flags.adam_beta1,
147 | beta2=flags.adam_beta2,
148 | epsilon=flags.opt_epsilon)
149 | elif flags.optimizer == 'ftrl':
150 | optimizer = tf.train.FtrlOptimizer(
151 | learning_rate,
152 | learning_rate_power=flags.ftrl_learning_rate_power,
153 | initial_accumulator_value=flags.ftrl_initial_accumulator_value,
154 | l1_regularization_strength=flags.ftrl_l1,
155 | l2_regularization_strength=flags.ftrl_l2)
156 | elif flags.optimizer == 'momentum':
157 | optimizer = tf.train.MomentumOptimizer(
158 | learning_rate,
159 | momentum=flags.momentum,
160 | name='Momentum')
161 | elif flags.optimizer == 'rmsprop':
162 | optimizer = tf.train.RMSPropOptimizer(
163 | learning_rate,
164 | decay=flags.rmsprop_decay,
165 | momentum=flags.rmsprop_momentum,
166 | epsilon=flags.opt_epsilon)
167 | elif flags.optimizer == 'sgd':
168 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
169 | else:
170 | raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
171 | return optimizer
172 |
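# Sketch of how the two helpers above are typically wired together in a training
# script (hypothetical variable names, not taken from this repository):
#   lr = configure_learning_rate(FLAGS, num_samples_per_epoch, global_step)
#   opt = configure_optimizer(FLAGS, lr)     # e.g. FLAGS.optimizer = 'rmsprop'
#   train_op = opt.minimize(total_loss, global_step=global_step)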
173 |
174 | def add_variables_summaries(learning_rate):
175 | summaries = []
176 | for variable in slim.get_model_variables():
177 | summaries.append(tf.summary.histogram(variable.op.name, variable))
178 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
179 | return summaries
180 |
181 |
182 | def update_model_scope(var, ckpt_scope, new_scope):
183 | return var.op.name.replace(new_scope, 'vgg_16')
184 |
185 |
186 | def get_init_fn(flags):
187 | """Returns a function run by the chief worker to warm-start the training.
188 | Note that the init_fn is only run when initializing the model during the very
189 | first global step.
190 |
191 | Returns:
192 | An init function run by the supervisor.
193 | """
194 | if flags.checkpoint_path is None:
195 | return None
196 | # Warn the user if a checkpoint exists in the train_dir. Then ignore.
197 | if tf.train.latest_checkpoint(flags.train_dir):
198 | tf.logging.info(
199 | 'Ignoring --checkpoint_path because a checkpoint already exists in %s'
200 | % flags.train_dir)
201 | return None
202 |
203 | exclusions = []
204 | if flags.checkpoint_exclude_scopes:
205 | exclusions = [scope.strip()
206 | for scope in flags.checkpoint_exclude_scopes.split(',')]
207 |
208 | # TODO(sguada) variables.filter_variables()
209 | variables_to_restore = []
210 | for var in slim.get_model_variables():
211 | excluded = False
212 | for exclusion in exclusions:
213 | if var.op.name.startswith(exclusion):
214 | excluded = True
215 | break
216 | if not excluded:
217 | variables_to_restore.append(var)
218 | # Change model scope if necessary.
219 | if flags.checkpoint_model_scope is not None:
220 | variables_to_restore = \
221 | {var.op.name.replace(flags.model_name,
222 | flags.checkpoint_model_scope): var
223 | for var in variables_to_restore}
224 |
225 |
226 | if tf.gfile.IsDirectory(flags.checkpoint_path):
227 | checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
228 | else:
229 | checkpoint_path = flags.checkpoint_path
230 | tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars))
231 |
232 | return slim.assign_from_checkpoint_fn(
233 | checkpoint_path,
234 | variables_to_restore,
235 | ignore_missing_vars=flags.ignore_missing_vars)
236 |
237 |
238 | def get_variables_to_train(flags):
239 | """Returns a list of variables to train.
240 |
241 | Returns:
242 | A list of variables to train by the optimizer.
243 | """
244 | if flags.trainable_scopes is None:
245 | return tf.trainable_variables()
246 | else:
247 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]
248 |
249 | variables_to_train = []
250 | for scope in scopes:
251 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
252 | variables_to_train.extend(variables)
253 | return variables_to_train
254 |
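# Hypothetical fine-tuning flags illustrating how get_init_fn() and
# get_variables_to_train() interact (the scope names are examples only):
#   --checkpoint_path=./checkpoints/vgg_16.ckpt
#   --checkpoint_model_scope=vgg_16
#   --checkpoint_exclude_scopes=ssd_300_vgg/block8,ssd_300_vgg/box
#   --trainable_scopes=ssd_300_vgg/block8,ssd_300_vgg/box
# get_init_fn() restores every model variable except the excluded scopes,
# renaming flags.model_name to the checkpoint scope, while
# get_variables_to_train() limits gradient updates to the trainable scopes.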
255 |
256 | # =========================================================================== #
257 | # Evaluation utils.
258 | # =========================================================================== #
259 |
--------------------------------------------------------------------------------
/spaceshooter/assets/bolt_gold.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/bolt_gold.png
--------------------------------------------------------------------------------
/spaceshooter/assets/laserRed16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/laserRed16.png
--------------------------------------------------------------------------------
/spaceshooter/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/main.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_big1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_big1.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_big2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_big2.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_med1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_med1.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_med3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_med3.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_small1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_small1.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_small2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_small2.png
--------------------------------------------------------------------------------
/spaceshooter/assets/meteorBrown_tiny1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_tiny1.png
--------------------------------------------------------------------------------
/spaceshooter/assets/missile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/missile.png
--------------------------------------------------------------------------------
/spaceshooter/assets/playerShip1_orange.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/playerShip1_orange.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion00.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion01.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion02.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion03.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion04.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion05.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion06.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion07.png
--------------------------------------------------------------------------------
/spaceshooter/assets/regularExplosion08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion08.png
--------------------------------------------------------------------------------
/spaceshooter/assets/shield_gold.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/shield_gold.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion00.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion01.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion02.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion03.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion04.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion05.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion06.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion07.png
--------------------------------------------------------------------------------
/spaceshooter/assets/sonicExplosion08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion08.png
--------------------------------------------------------------------------------
/spaceshooter/assets/starfield.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/starfield.png
--------------------------------------------------------------------------------
/spaceshooter/detection_and_tracking.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import sys
4 | from os.path import realpath, dirname, join
5 |
6 | import cv2
7 | import numpy as np
8 | import tensorflow as tf
9 | import torch
10 |
11 | from net import SiamRPNvot
12 | from nets import ssd_vgg_300, np_methods
13 | from preprocessing import ssd_vgg_preprocessing
14 | from run_SiamRPN import SiamRPN_init, SiamRPN_track
15 | from utils import cxy_wh_2_rect
16 |
17 |
18 | def get_object_center(q, detect_class):
19 |
20 | # classes:
21 | # 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles
22 | # 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows
23 | # 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People
24 | # 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors
25 |
26 | slim = tf.contrib.slim
27 |
28 | # TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!!
29 | gpu_options = tf.GPUOptions(allow_growth=True)
30 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
31 | isess = tf.InteractiveSession(config=config)
32 |
33 | # Input placeholder.
34 | net_shape = (300, 300)
35 | data_format = 'NHWC'
36 | img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
37 | # Evaluation pre-processing: resize to SSD net shape.
38 | image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
39 | img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
40 | image_4d = tf.expand_dims(image_pre, 0)
41 |
42 | # Define the SSD model.
43 | reuse = True if 'ssd_net' in locals() else None
44 | ssd_net = ssd_vgg_300.SSDNet()
45 | with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
46 | predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)
47 |
48 | # Restore SSD model.
49 | # ckpt_filename = 'checkpoints/ssd_300_vgg.ckpt'
50 | ckpt_filename = '../SSD-Tensorflow/checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'
51 |
52 | isess.run(tf.global_variables_initializer())
53 | saver = tf.train.Saver()
54 | saver.restore(isess, ckpt_filename)
55 |
56 | # SSD default anchor boxes.
57 | ssd_anchors = ssd_net.anchors(net_shape)
58 |
59 | # Main image processing routine.
60 | def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)):
61 | # Run SSD network.
62 | rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img],
63 | feed_dict={img_input: img})
64 |
65 | # Get classes and bboxes from the net outputs.
66 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
67 | rpredictions, rlocalisations, ssd_anchors,
68 | select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True)
69 |
70 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
71 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
72 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
73 | # Resize bboxes to original image shape. Note: useless for Resize.WARP!
74 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
75 | return rclasses, rscores, rbboxes
76 |
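# Illustrative note (not in the original code): process_image returns three
# aligned arrays, e.g. rclasses = [15, 7], rscores = [0.92, 0.61] and rbboxes of
# shape (2, 4), where each row is [y_min, x_min, y_max, x_max] in normalized
# [0, 1] image coordinates (15 = People, 7 = Cars in the VOC class list above).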
77 | def get_bboxes(rclasses, rbboxes):
78 | # collect the detected bounding boxes (normalized coordinates)
79 |
80 | number_classes = rclasses.shape[0]
81 | object_bboxes = []
82 | for i in range(number_classes):
83 | object_bbox = dict()
84 | object_bbox['i'] = i
85 | object_bbox['class'] = rclasses[i]
86 | object_bbox['y_min'] = rbboxes[i, 0]
87 | object_bbox['x_min'] = rbboxes[i, 1]
88 | object_bbox['y_max'] = rbboxes[i, 2]
89 | object_bbox['x_max'] = rbboxes[i, 3]
90 | object_bboxes.append(object_bbox)
91 | return object_bboxes
92 |
93 | # load net
94 | net = SiamRPNvot()
95 | net.load_state_dict(torch.load(join(realpath(dirname(__file__)), '../DaSiamRPN-master/code/SiamRPNVOT.model')))
96 |
97 | net.eval()
98 |
99 | # open video capture
100 | video = cv2.VideoCapture(0)
101 |
102 | if not video.isOpened():
103 | print("Could not open video")
104 | sys.exit()
105 |
106 | index = True
107 | while index:
108 |
109 | # Read first frame.
110 | ok, frame = video.read()
111 | if not ok:
112 | print('Cannot read video file')
113 | sys.exit()
114 |
115 | # Define an initial bounding box
116 | height = frame.shape[0]
117 | width = frame.shape[1]
118 |
119 | rclasses, rscores, rbboxes = process_image(frame)
120 |
121 | bboxes = get_bboxes(rclasses, rbboxes)
122 | for bbox in bboxes:
123 | if bbox['class'] == detect_class:
124 | print(bbox)
125 | ymin = int(bbox['y_min'] * height)
126 | xmin = int((bbox['x_min']) * width)
127 | ymax = int(bbox['y_max'] * height)
128 | xmax = int((bbox['x_max']) * width)
129 | cx = (xmin + xmax) / 2
130 | cy = (ymin + ymax) / 2
131 | h = ymax - ymin
132 | w = xmax - xmin
133 | new_bbox = (cx, cy, w, h)
134 | print(new_bbox)
135 | index = False
136 | break
137 |
138 | # tracker init
139 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
140 | state = SiamRPN_init(frame, target_pos, target_sz, net)
141 |
142 | # tracking and visualization
143 | toc = 0
144 | count_number = 0
145 |
146 | while True:
147 |
148 | # Read a new frame
149 | ok, frame = video.read()
150 | if not ok:
151 | break
152 |
153 | # Start timer
154 | tic = cv2.getTickCount()
155 |
156 | # Update tracker
157 | state = SiamRPN_track(state, frame) # track
158 | # print(state)
159 |
160 | toc += cv2.getTickCount() - tic
161 |
162 | if state:
163 |
164 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
165 | res = [int(l) for l in res]
166 |
167 | cv2.rectangle(frame, (res[0], res[1]), (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3)
168 |
169 | count_number += 1
170 | # set object_center
171 | object_center = dict()
172 | object_center['x'] = state['target_pos'][0] / width
173 | object_center['y'] = state['target_pos'][1] / height
174 | q.put(object_center)
175 |
176 | if (not state) or count_number % 40 == 3:
177 | # Tracking failure
178 | cv2.putText(frame, "Tracking failure detected", (100, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255),
179 | 2)
180 | index = True
181 | while index:
182 | ok, frame = video.read()
183 | rclasses, rscores, rbboxes = process_image(frame)
184 | bboxes = get_bboxes(rclasses, rbboxes)
185 | for bbox in bboxes:
186 | if bbox['class'] == detect_class:
187 | ymin = int(bbox['y_min'] * height)
188 | xmin = int(bbox['x_min'] * width)
189 | ymax = int(bbox['y_max'] * height)
190 | xmax = int(bbox['x_max'] * width)
191 | cx = (xmin + xmax) / 2
192 | cy = (ymin + ymax) / 2
193 | h = ymax - ymin
194 | w = xmax - xmin
195 | new_bbox = (cx, cy, w, h)
196 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
197 | state = SiamRPN_init(frame, target_pos, target_sz, net)
198 |
199 | p1 = (int(xmin), int(ymin))
200 | p2 = (int(xmax), int(ymax))
201 | cv2.rectangle(frame, p1, p2, (0, 255, 0), 2, 1)
202 |
203 | index = 0
204 |
205 | break
206 |
207 | # Resize the frame
208 | resized_frame = cv2.resize(frame, None, fx=0.65, fy=0.65, interpolation=cv2.INTER_AREA)
209 | # Flip the frame horizontally (for a mirrored display)
210 | horizontal = cv2.flip(resized_frame, 1, dst=None)
211 |
212 | # Display the frame
213 | cv2.namedWindow("SSD+SiamRPN", cv2.WINDOW_NORMAL)
214 | cv2.imshow('SSD+SiamRPN', horizontal)
215 |
216 | # Exit if ESC pressed
217 | k = cv2.waitKey(1) & 0xff
218 | if k == 27:
219 | break
220 |
221 | video.release()
222 | cv2.destroyAllWindows()
223 |
--------------------------------------------------------------------------------
/spaceshooter/detection_tracking_game.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import multiprocessing
4 |
5 | import sys
6 |
7 | from detection_and_tracking import get_object_center
8 |
9 | from simulate_game import space_shooter
10 |
11 | sys.path.append("..")
12 |
13 | q = multiprocessing.Queue()
14 |
15 | # classes:
16 | # 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles
17 | # 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows
18 | # 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People
19 | # 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors
20 |
21 | # Specify the object class to detect and track
22 | OBJECT_CLASS = 5
23 |
24 | set_process = multiprocessing.Process(target=get_object_center, args=(q, OBJECT_CLASS,))
25 | game_process = multiprocessing.Process(target=space_shooter, args=(q,))
26 |
27 | set_process.start()
28 | game_process.start()
29 |
30 | game_process.join()
31 | set_process.terminate()
32 |
33 | print("退出主线程")
34 |
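The consumer side lives in simulate_game.py, which is not shown here. A minimal, purely hypothetical sketch of how a game loop could drain the queue without blocking (only the space_shooter(q) signature is taken from the call above):

import queue  # for the Empty exception raised by multiprocessing queues

def space_shooter(q):
    player_x = 0.5
    while True:
        try:
            center = q.get_nowait()   # latest normalized object center {'x', 'y'}
            player_x = center['x']    # steer the ship from the tracked object
        except queue.Empty:
            pass
        # ... update and draw the game using player_x ...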
--------------------------------------------------------------------------------
/spaceshooter/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/spaceshooter/sounds/expl3.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/expl3.wav
--------------------------------------------------------------------------------
/spaceshooter/sounds/expl6.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/expl6.wav
--------------------------------------------------------------------------------
/spaceshooter/sounds/getready.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/getready.ogg
--------------------------------------------------------------------------------
/spaceshooter/sounds/menu.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/menu.ogg
--------------------------------------------------------------------------------
/spaceshooter/sounds/pew.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/pew.wav
--------------------------------------------------------------------------------
/spaceshooter/sounds/rocket.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/rocket.ogg
--------------------------------------------------------------------------------
/spaceshooter/sounds/rumble1.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/rumble1.ogg
--------------------------------------------------------------------------------
/spaceshooter/sounds/tgfcoder-FrozenJam-SeamlessLoop.ogg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/tgfcoder-FrozenJam-SeamlessLoop.ogg
--------------------------------------------------------------------------------