├── .idea
│   ├── Object-detecting-and-tracking.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   └── vcs.xml
├── DaSiamRPN
│   ├── README.md
│   ├── demo.py
│   ├── eval_otb.py
│   ├── net.py
│   ├── run_SiamRPN.py
│   ├── test_otb.py
│   ├── utils.py
│   ├── vot.py
│   └── vot_SiamRPN.py
├── README.md
├── SSD+SiamRPN.py
├── SSD
│   ├── README.md
│   ├── UseSSD.py
│   ├── caffe_to_tensorflow.py
│   ├── code_test.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── dataset_factory.py
│   │   ├── dataset_utils.py
│   │   ├── imagenet.py
│   │   ├── pascalvoc_2007.py
│   │   ├── pascalvoc_2012.py
│   │   ├── pascalvoc_common.py
│   │   └── pascalvoc_to_tfrecords.py
│   ├── eval_ssd_network.py
│   ├── inspect_checkpoint.py
│   ├── nets
│   │   ├── __init__.py
│   │   ├── caffe_scope.py
│   │   ├── custom_layers.py
│   │   ├── inception.py
│   │   ├── inception_resnet_v2.py
│   │   ├── inception_v3.py
│   │   ├── nets_factory.py
│   │   ├── np_methods.py
│   │   ├── readme
│   │   ├── ssd_common.py
│   │   ├── ssd_vgg_300.py
│   │   ├── ssd_vgg_512.py
│   │   ├── vgg.py
│   │   └── xception.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── inception_preprocessing.py
│   │   ├── preprocessing_factory.py
│   │   ├── readme
│   │   ├── ssd_vgg_preprocessing.py
│   │   ├── tf_image.py
│   │   └── vgg_preprocessing.py
│   ├── readme
│   ├── tf_convert_data.py
│   ├── tf_extended
│   │   ├── __init__.py
│   │   ├── bboxes.py
│   │   ├── image.py
│   │   ├── math.py
│   │   ├── metrics.py
│   │   └── tensors.py
│   ├── tf_utils.py
│   └── train_ssd_network.py
└── spaceshooter
    ├── assets
    │   ├── bolt_gold.png
    │   ├── laserRed16.png
    │   ├── main.png
    │   ├── meteorBrown_big1.png
    │   ├── meteorBrown_big2.png
    │   ├── meteorBrown_med1.png
    │   ├── meteorBrown_med3.png
    │   ├── meteorBrown_small1.png
    │   ├── meteorBrown_small2.png
    │   ├── meteorBrown_tiny1.png
    │   ├── missile.png
    │   ├── playerShip1_orange.png
    │   ├── regularExplosion00.png
    │   ├── regularExplosion01.png
    │   ├── regularExplosion02.png
    │   ├── regularExplosion03.png
    │   ├── regularExplosion04.png
    │   ├── regularExplosion05.png
    │   ├── regularExplosion06.png
    │   ├── regularExplosion07.png
    │   ├── regularExplosion08.png
    │   ├── shield_gold.png
    │   ├── sonicExplosion00.png
    │   ├── sonicExplosion01.png
    │   ├── sonicExplosion02.png
    │   ├── sonicExplosion03.png
    │   ├── sonicExplosion04.png
    │   ├── sonicExplosion05.png
    │   ├── sonicExplosion06.png
    │   ├── sonicExplosion07.png
    │   ├── sonicExplosion08.png
    │   └── starfield.png
    ├── detection_and_tracking.py
    ├── detection_tracking_game.py
    ├── readme
    ├── simulate_game.py
    ├── sounds
    │   ├── expl3.wav
    │   ├── expl6.wav
    │   ├── getready.ogg
    │   ├── menu.ogg
    │   ├── pew.wav
    │   ├── rocket.ogg
    │   ├── rumble1.ogg
    │   └── tgfcoder-FrozenJam-SeamlessLoop.ogg
    └── spaceShooter.py


/DaSiamRPN/README.md:
--------------------------------------------------------------------------------
:trophy:News: **We won the VOT-18 real-time challenge**

:trophy:News: **We won second place in the VOT-18 long-term challenge**

# DaSiamRPN

This repository includes PyTorch code for reproducing the results on VOT2018.

[**Distractor-aware Siamese Networks for Visual Object Tracking**](https://arxiv.org/pdf/1808.06048.pdf)

Zheng Zhu\*, Qiang Wang\*, Bo Li\*, Wei Wu, Junjie Yan, and Weiming Hu

*European Conference on Computer Vision (ECCV), 2018*

## Introduction

**SiamRPN** formulates visual tracking as simultaneous localization and identification, as first described in a [CVPR2018 spotlight paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf). (Slides at [CVPR 2018 Spotlight](https://drive.google.com/open?id=1OGIOUqANvYfZjRoQfpiDqhPQtOvPCpdq))

**DaSiamRPN** improves the performance of SiamRPN by (1) introducing an effective sampling strategy to control the imbalanced sample distribution, (2) designing a novel distractor-aware module to perform incremental learning, and (3) adding a long-term tracking extension. [ECCV2018](https://arxiv.org/pdf/1808.06048.pdf). (Slides at [VOT-18 Real-time challenge winners talk](https://drive.google.com/open?id=1dsEI2uYHDfELK0CW2xgv7R4QdCs6lwfr))
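
In `net.py`, `temple` turns the exemplar (template) features into correlation kernels, and `forward` slides those kernels over the search-region features with `F.conv2d`, which computes a cross-correlation (no kernel flip). The NumPy sketch below is an illustrative, loop-based version of that "valid" cross-correlation, not code from this repository; the name `xcorr` and the reduced channel count are assumptions chosen to keep it small. With the repository's sizes (127 px exemplar, 271 px search region), the backbone plus the 3x3 `conv_cls1`/`conv_cls2` layers give 4x4 kernels and a 22x22 search map, so the response is 19x19, matching `score_size = (271 - 127) / 8 + 1 = 19` in `run_SiamRPN.py`.

```python
import numpy as np


def xcorr(search_feat, kernels):
    """'Valid' cross-correlation of a bank of kernels over a feature map.

    search_feat: (C, H, W) search-region features
    kernels:     (K, C, kh, kw) kernels built from the exemplar features
    returns:     (K, H - kh + 1, W - kw + 1) response maps
    """
    k_out, c, kh, kw = kernels.shape
    c_in, h, w = search_feat.shape
    assert c == c_in, "channel counts must match"
    out = np.zeros((k_out, h - kh + 1, w - kw + 1))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = search_feat[:, i:i + kh, j:j + kw]  # (C, kh, kw) window
            # Dot product of every kernel with the window, summed over C, kh, kw.
            out[:, i, j] = np.tensordot(kernels, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out


# Toy shapes: 2 * 5 anchors = 10 classification kernels, as in the repo's cls branch,
# but only 8 channels instead of 256 so the loop stays fast.
rng = np.random.default_rng(0)
search = rng.standard_normal((8, 22, 22))
cls_kernels = rng.standard_normal((10, 8, 4, 4))
print(xcorr(search, cls_kernels).shape)  # (10, 19, 19)
```

The explicit loops trade speed for readability; in the repository the same operation is a single `F.conv2d` call on the GPU or CPU.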
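
The "identification" and "localization" outputs are then decoded anchor by anchor. The sketch below mirrors the first steps of `tracker_eval` in `run_SiamRPN.py`: a two-way softmax keeps the foreground probability, and the four regression outputs shift and rescale each anchor, given as (cx, cy, w, h). It is a simplified NumPy-only illustration that assumes the outputs are already flattened to one column per anchor; `decode_rpn_outputs` is a hypothetical name, and the real tracker additionally applies the scale/ratio penalty and the cosine window before choosing the best anchor.

```python
import numpy as np


def decode_rpn_outputs(cls_logits, deltas, anchors):
    """Decode flattened RPN outputs into boxes and foreground scores.

    cls_logits: (2, N) background/foreground logits, one column per anchor
    deltas:     (4, N) regression outputs (dx, dy, dw, dh)
    anchors:    (N, 4) anchors as (cx, cy, w, h)
    returns:    boxes (4, N) as (cx, cy, w, h) and scores (N,)
    """
    # Identification: softmax over the two classes, keep the foreground row.
    shifted = cls_logits - cls_logits.max(axis=0, keepdims=True)  # numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=0, keepdims=True)
    scores = probs[1]

    # Localization: move the anchor centre and rescale its width and height.
    boxes = np.zeros((4, deltas.shape[1]))
    boxes[0] = deltas[0] * anchors[:, 2] + anchors[:, 0]
    boxes[1] = deltas[1] * anchors[:, 3] + anchors[:, 1]
    boxes[2] = np.exp(deltas[2]) * anchors[:, 2]
    boxes[3] = np.exp(deltas[3]) * anchors[:, 3]
    return boxes, scores


anchors = np.array([[0.0, 0.0, 64.0, 64.0],
                    [8.0, 0.0, 64.0, 64.0]])
cls_logits = np.array([[2.0, -1.0],   # background logits
                       [-2.0, 3.0]])  # foreground logits
deltas = np.zeros((4, 2))
deltas[0, 1] = 0.5  # shift the second anchor right by half its width
boxes, scores = decode_rpn_outputs(cls_logits, deltas, anchors)
best = int(np.argmax(scores))
# best is 1 and boxes[:, best] is (40, 0, 64, 64)
```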

## Prerequisites

CPU: Intel(R) Core(TM) i7-7700 CPU @ 3.60GHz
GPU: NVIDIA GTX1060

- python2.7
- pytorch == 0.3.1
- numpy
- opencv

## Pretrained model for SiamRPN

In our tracker, we use an AlexNet variant as our backbone, which is end-to-end trained for visual tracking.
The pretrained model can be downloaded from Google Drive: [SiamRPNBIG.model](https://drive.google.com/file/d/1-vNVZxfbIplXHrqMHiJJYWXYWsOIvGsf/view?usp=sharing).
Then copy the pretrained model file `SiamRPNBIG.model` to the subfolder './code' so that the tracker can find and load it.

## Detailed steps to install the prerequisites

- Install pytorch, numpy, and opencv following the instructions in `run_install.sh`. Please do **not** use conda to install them.
- Alternatively, you can modify `/PATH/TO/CODE/FOLDER/` in `tracker_SiamRPN.m`.
  If the tracker is ready, you will see the tracking results (EAO: 0.3827).

## Results
All results can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1HJOvl_irX3KFbtfj88_FVLtukMI1GTCR?usp=sharing).

| | VOT2015<br>A / R / EAO | VOT2016<br>A / R / EAO | VOT2017 & VOT2018<br>A / R / EAO | OTB2015<br>OP / DP | UAV123<br>AUC / DP | UAV20L<br>AUC / DP |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| **SiamRPN**<br>CVPR2018 | 0.58 / 1.13 / 0.349 | 0.56 / 0.26 / 0.344 | 0.49 / 0.46 / 0.244 | 81.9 / 85.0 | 0.527 / 0.748 | 0.454 / 0.617 |
| **DaSiamRPN**<br>ECCV2018 | **0.63** / **0.66** / **0.446** | **0.61** / **0.22** / **0.411** | 0.56 / 0.34 / 0.326 | **86.5** / **88.0** | **0.586** / **0.796** | **0.617** / **0.838** |
| **DaSiamRPN**<br>VOT2018 | - | - | **0.59** / **0.28** / **0.383** | - | - | - |


# Demo and Test on OTB2015

- To reproduce the results in the paper, download the pretrained model `SiamRPNOTB.model` from [Google Drive](https://drive.google.com/open?id=1BtIkp5pB6aqePQGlMb2_Z7bfPy6XEj6H).
:zap: :zap: This model is the **fastest** (~200fps) Siamese Tracker with AUC of 0.655 on OTB2015. :zap: :zap:

- You must first download the OTB2015 dataset (download [script](code/data/get_otb_data.sh)).

A simple test example:

```
cd code
python demo.py
```

If you want to test the performance on OTB2015, please use the following commands.

```
cd code
python test_otb.py
python eval_otb.py OTB2015 "Siam*" 0 1
```


# License
Licensed under an MIT license.


## Citing DaSiamRPN

If you find **DaSiamRPN** and **SiamRPN** useful in your research, please consider citing:

```
@inproceedings{Zhu_2018_ECCV,
  title={Distractor-aware Siamese Networks for Visual Object Tracking},
  author={Zhu, Zheng and Wang, Qiang and Li, Bo and Wu, Wei and Yan, Junjie and Hu, Weiming},
  booktitle={European Conference on Computer Vision},
  year={2018}
}

@InProceedings{Li_2018_CVPR,
  title = {High Performance Visual Tracking With Siamese Region Proposal Network},
  author = {Li, Bo and Yan, Junjie and Wu, Wei and Zhu, Zheng and Hu, Xiaolin},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year = {2018}
}
```

--------------------------------------------------------------------------------
/DaSiamRPN/demo.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# DaSiamRPN
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
# !/usr/bin/python

import glob, cv2, torch
import numpy as np
from os.path import realpath, dirname, join

from net import SiamRPNvot
from run_SiamRPN import SiamRPN_init, SiamRPN_track
from utils import get_axis_aligned_bbox, cxy_wh_2_rect

# load net
net = SiamRPNvot()
net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model')))

# net.eval().cuda()
net.eval()

# image and init box
image_files = sorted(glob.glob('./bag/*.jpg'))
init_rbox = [334.02, 128.36, 438.19, 188.78, 396.39, 260.83, 292.23, 200.41]
[cx, cy, w, h] = get_axis_aligned_bbox(init_rbox)

# tracker init
target_pos, target_sz = np.array([cx, cy]), np.array([w, h])
im = cv2.imread(image_files[0])  # HxWxC
state = SiamRPN_init(im, target_pos, target_sz, net)

# tracking and visualization
toc = 0
for f, image_file in enumerate(image_files):
    im = cv2.imread(image_file)
    tic = cv2.getTickCount()
    state = SiamRPN_track(state, im)  # track
    toc += cv2.getTickCount() - tic
    res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
    res = [int(l) for l in res]
    cv2.rectangle(im, (res[0], res[1]), (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3)
    cv2.imshow('SiamRPN', im)
    cv2.waitKey(1)

print('Tracking Speed {:.1f}fps'.format((len(image_files) - 1) / (toc / cv2.getTickFrequency())))

--------------------------------------------------------------------------------
/DaSiamRPN/eval_otb.py:
--------------------------------------------------------------------------------
import sys
import json
import os
import glob
from os.path import join as fullfile
import numpy as np


def overlap_ratio(rect1, rect2):
    '''
    Compute overlap ratio between two rects
    - rect: 1d array of [x,y,w,h] or
            2d array of N x [x,y,w,h]
    '''

    if rect1.ndim == 1:
        rect1 = rect1[None, :]
    if rect2.ndim == 1:
        rect2 = rect2[None, :]

    left = np.maximum(rect1[:, 0], rect2[:, 0])
    right = np.minimum(rect1[:, 0] + rect1[:, 2], rect2[:, 0] + rect2[:, 2])
    top = np.maximum(rect1[:, 1], rect2[:, 1])
    bottom = np.minimum(rect1[:, 1] + rect1[:, 3], rect2[:, 1] + rect2[:, 3])

    intersect = np.maximum(0, right - left) * np.maximum(0, bottom - top)
    union = rect1[:, 2] * rect1[:, 3] + rect2[:, 2] * rect2[:, 3] - intersect
    iou = np.clip(intersect / union, 0, 1)
    return iou


def compute_success_overlap(gt_bb, result_bb):
    thresholds_overlap = np.arange(0, 1.05, 0.05)
    n_frame = len(gt_bb)
    success = np.zeros(len(thresholds_overlap))
    iou = overlap_ratio(gt_bb, result_bb)
    for i in range(len(thresholds_overlap)):
        success[i] = sum(iou > thresholds_overlap[i]) / float(n_frame)
    return success


def compute_success_error(gt_center, result_center):
    thresholds_error = np.arange(0, 51, 1)
    n_frame = len(gt_center)
    success = np.zeros(len(thresholds_error))
    dist = np.sqrt(np.sum(np.power(gt_center - result_center, 2), axis=1))
    for i in range(len(thresholds_error)):
        success[i] = sum(dist <= thresholds_error[i]) / float(n_frame)
    return success


def get_result_bb(arch, seq):
    result_path = fullfile(arch, seq + '.txt')
    # builtin float: the np.float alias was removed in NumPy 1.24
    temp = np.loadtxt(result_path, delimiter=',').astype(float)
    return np.array(temp)


def convert_bb_to_center(bboxes):
    return np.array([(bboxes[:, 0] + (bboxes[:, 2] - 1) / 2),
                     (bboxes[:, 1] + (bboxes[:, 3] - 1) / 2)]).T


def eval_auc(dataset='OTB2015', tracker_reg='S*', start=0, end=1e6):
    list_path = os.path.join('data', dataset + '.json')
    annos = json.load(open(list_path, 'r'))
    seqs = list(annos.keys())

    OTB2013 = ['carDark', 'car4', 'david', 'david2', 'sylvester', 'trellis', 'fish', 'mhyang', 'soccer', 'matrix',
               'ironman', 'deer', 'skating1', 'shaking', 'singer1', 'singer2', 'coke', 'bolt', 'boy', 'dudek',
               'crossing', 'couple', 'football1', 'jogging_1', 'jogging_2', 'doll', 'girl', 'walking2', 'walking',
               'fleetface', 'freeman1', 'freeman3', 'freeman4', 'david3', 'jumping', 'carScale', 'skiing', 'dog1',
               'suv', 'motorRolling', 'mountainBike', 'lemming', 'liquor', 'woman', 'faceocc1', 'faceocc2',
               'basketball', 'football', 'subway', 'tiger1', 'tiger2']

    trackers = glob.glob(fullfile('test', dataset, tracker_reg))
    trackers = trackers[start:min(end, len(trackers))]

    n_seq = len(seqs)
    thresholds_overlap = np.arange(0, 1.05, 0.05)
    # thresholds_error = np.arange(0, 51, 1)

    success_overlap = np.zeros((n_seq, len(trackers), len(thresholds_overlap)))
    # success_error = np.zeros((n_seq, len(trackers), len(thresholds_error)))
    for i in range(n_seq):
        seq = seqs[i]
        gt_rect = np.array(annos[seq]['gt_rect']).astype(float)
        gt_center = convert_bb_to_center(gt_rect)
        for j in range(len(trackers)):
            tracker = trackers[j]
            print('{:d} processing:{} tracker: {}'.format(i, seq, tracker))
            bb = get_result_bb(tracker, seq)
            center = convert_bb_to_center(bb)
            success_overlap[i][j] = compute_success_overlap(gt_rect, bb)
            # success_error[i][j] = compute_success_error(gt_center, center)

    print('Success Overlap')

    if 'OTB2015' == dataset:
        OTB2013_id = []
        for i in range(n_seq):
            if seqs[i] in OTB2013:
                OTB2013_id.append(i)
        max_auc_OTB2013 = 0.
        max_name_OTB2013 = ''
        for i in range(len(trackers)):
            auc = success_overlap[OTB2013_id, i, :].mean()
            if auc > max_auc_OTB2013:
                max_auc_OTB2013 = auc
                max_name_OTB2013 = trackers[i]
            print('%s(OTB2013 AUC:%.4f)' % (trackers[i], auc))

        max_auc = 0.
        max_name = ''
        for i in range(len(trackers)):
            auc = success_overlap[:, i, :].mean()
            if auc > max_auc:
                max_auc = auc
                max_name = trackers[i]
            print('%s(OTB2015 AUC:%.4f)' % (trackers[i], auc))

        print('\nOTB2013 Best: %s(%.4f)' % (max_name_OTB2013, max_auc_OTB2013))
        print('\nOTB2015 Best: %s(%.4f)' % (max_name, max_auc))
    else:
        max_auc = 0.
        max_name = ''
        for i in range(len(trackers)):
            auc = success_overlap[:, i, :].mean()
            if auc > max_auc:
                max_auc = auc
                max_name = trackers[i]
            print('%s(%.4f)' % (trackers[i], auc))

        print('\n%s Best: %s(%.4f)' % (dataset, max_name, max_auc))


if __name__ == "__main__":
    if len(sys.argv) < 5:
        print('python eval_otb.py OTB2015 Siam* 0 10')
        exit()
    dataset = sys.argv[1]
    tracker_reg = sys.argv[2]
    start = int(float(sys.argv[3]))
    end = int(float(sys.argv[4]))
    eval_auc(dataset, tracker_reg, start, end)

--------------------------------------------------------------------------------
/DaSiamRPN/net.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# DaSiamRPN
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class SiamRPN(nn.Module):
    def __init__(self, size=2, feature_out=512, anchor=5):
        configs = [3, 96, 256, 384, 384, 256]
        configs = list(map(lambda x: 3 if x == 3 else x * size, configs))
        feat_in = configs[-1]
        super(SiamRPN, self).__init__()
        self.featureExtract = nn.Sequential(
            nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2),
            nn.BatchNorm2d(configs[1]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[1], configs[2], kernel_size=5),
            nn.BatchNorm2d(configs[2]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[2], configs[3], kernel_size=3),
            nn.BatchNorm2d(configs[3]),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[3], configs[4], kernel_size=3),
            nn.BatchNorm2d(configs[4]),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[4], configs[5], kernel_size=3),
            nn.BatchNorm2d(configs[5]),
        )

        self.anchor = anchor
        self.feature_out = feature_out

        self.conv_r1 = nn.Conv2d(feat_in, feature_out * 4 * anchor, 3)
        self.conv_r2 = nn.Conv2d(feat_in, feature_out, 3)
        self.conv_cls1 = nn.Conv2d(feat_in, feature_out * 2 * anchor, 3)
        self.conv_cls2 = nn.Conv2d(feat_in, feature_out, 3)
        self.regress_adjust = nn.Conv2d(4 * anchor, 4 * anchor, 1)

        self.r1_kernel = []
        self.cls1_kernel = []

        self.cfg = {}

    def forward(self, x):
        x_f = self.featureExtract(x)
        return self.regress_adjust(F.conv2d(self.conv_r2(x_f), self.r1_kernel)), \
               F.conv2d(self.conv_cls2(x_f), self.cls1_kernel)

    def temple(self, z):
        z_f = self.featureExtract(z)
        r1_kernel_raw = self.conv_r1(z_f)
        cls1_kernel_raw = self.conv_cls1(z_f)
        kernel_size = r1_kernel_raw.data.size()[-1]
        self.r1_kernel = r1_kernel_raw.view(self.anchor * 4, self.feature_out, kernel_size, kernel_size)
        self.cls1_kernel = cls1_kernel_raw.view(self.anchor * 2, self.feature_out, kernel_size, kernel_size)


class SiamRPNBIG(SiamRPN):
    def __init__(self):
        super(SiamRPNBIG, self).__init__(size=2)
        self.cfg = {'lr': 0.295, 'window_influence': 0.42, 'penalty_k': 0.055, 'instance_size': 271,
                    'adaptive': True}  # 0.383


class SiamRPNvot(SiamRPN):
    def __init__(self):
        super(SiamRPNvot, self).__init__(size=1, feature_out=256)
        self.cfg = {'lr': 0.45, 'window_influence': 0.44, 'penalty_k': 0.04,
                    'instance_size': 271, 'adaptive': False}  # 0.355


class SiamRPNotb(SiamRPN):
    def __init__(self):
        super(SiamRPNotb, self).__init__(size=1, feature_out=256)
        self.cfg = {'lr': 0.30, 'window_influence': 0.40, 'penalty_k': 0.22, 'instance_size': 271,
                    'adaptive': False}  # 0.655

--------------------------------------------------------------------------------
/DaSiamRPN/run_SiamRPN.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# DaSiamRPN
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F


from utils import get_subwindow_tracking


def generate_anchor(total_stride, scales, ratios, score_size):
    anchor_num = len(ratios) * len(scales)
    anchor = np.zeros((anchor_num, 4), dtype=np.float32)
    size = total_stride * total_stride
    count = 0
    for ratio in ratios:
        # ws = int(np.sqrt(size * 1.0 / ratio))
        ws = int(np.sqrt(size / ratio))
        hs = int(ws * ratio)
        for scale in scales:
            wws = ws * scale
            hhs = hs * scale
            anchor[count, 0] = 0
            anchor[count, 1] = 0
            anchor[count, 2] = wws
            anchor[count, 3] = hhs
            count += 1

    anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
    # floor division keeps the Python 2 result (anchor grid centred on 0) under Python 3
    ori = - (score_size // 2) * total_stride
    xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
                         [ori + total_stride * dy for dy in range(score_size)])
    xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
             np.tile(yy.flatten(), (anchor_num, 1)).flatten()
    anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
    return anchor


class TrackerConfig(object):
    # These are the default hyper-params for DaSiamRPN 0.3827
    windowing = 'cosine'  # to penalize large displacements [cosine/uniform]
    # Params from the network architecture, have to be consistent with the training
    exemplar_size = 127  # input z size
    instance_size = 271  # input x size (search region)
    total_stride = 8
    score_size = (instance_size - exemplar_size) // total_stride + 1  # integer, as under Python 2
    context_amount = 0.5  # context amount for the exemplar
    ratios = [0.33, 0.5, 1, 2, 3]
    scales = [8, ]
    anchor_num = len(ratios) * len(scales)
    anchor = []
    penalty_k = 0.055
    window_influence = 0.42
    lr = 0.295
    # adaptive change search region #
    adaptive = True

    def update(self, cfg):
        for k, v in cfg.items():
            setattr(self, k, v)
        self.score_size = (self.instance_size - self.exemplar_size) // self.total_stride + 1


def tracker_eval(net, x_crop, target_pos, target_sz, window, scale_z, p):
    delta, score = net(x_crop)

    delta = delta.permute(1, 2, 3, 0).contiguous().view(4, -1).data.cpu().numpy()
    score = F.softmax(score.permute(1, 2, 3, 0).contiguous().view(2, -1), dim=0).data[1, :].cpu().numpy()

    delta[0, :] = delta[0, :] * p.anchor[:, 2] + p.anchor[:, 0]
    delta[1, :] = delta[1, :] * p.anchor[:, 3] + p.anchor[:, 1]
    delta[2, :] = np.exp(delta[2, :]) * p.anchor[:, 2]
    delta[3, :] = np.exp(delta[3, :]) * p.anchor[:, 3]

    def change(r):
        return np.maximum(r, 1. / r)

    def sz(w, h):
        pad = (w + h) * 0.5
        sz2 = (w + pad) * (h + pad)
        return np.sqrt(sz2)

    def sz_wh(wh):
        pad = (wh[0] + wh[1]) * 0.5
        sz2 = (wh[0] + pad) * (wh[1] + pad)
        return np.sqrt(sz2)

    # size penalty
    s_c = change(sz(delta[2, :], delta[3, :]) / (sz_wh(target_sz)))  # scale penalty
    r_c = change((target_sz[0] / target_sz[1]) / (delta[2, :] / delta[3, :]))  # ratio penalty

    penalty = np.exp(-(r_c * s_c - 1.) * p.penalty_k)
    pscore = penalty * score

    # window float
    pscore = pscore * (1 - p.window_influence) + window * p.window_influence
    best_pscore_id = np.argmax(pscore)

    target = delta[:, best_pscore_id] / scale_z
    target_sz = target_sz / scale_z
    lr = penalty[best_pscore_id] * score[best_pscore_id] * p.lr

    res_x = target[0] + target_pos[0]
    res_y = target[1] + target_pos[1]

    res_w = target_sz[0] * (1 - lr) + target[2] * lr
    res_h = target_sz[1] * (1 - lr) + target[3] * lr

    target_pos = np.array([res_x, res_y])
    target_sz = np.array([res_w, res_h])
    return target_pos, target_sz, score[best_pscore_id]


def SiamRPN_init(im, target_pos, target_sz, net):
    state = dict()
    p = TrackerConfig()
    p.update(net.cfg)
    state['im_h'] = im.shape[0]
    state['im_w'] = im.shape[1]

    if p.adaptive:
        if ((target_sz[0] * target_sz[1]) / float(state['im_h'] * state['im_w'])) < 0.004:
            p.instance_size = 287  # small object big search region
        else:
            p.instance_size = 271

        p.score_size = (p.instance_size - p.exemplar_size) // p.total_stride + 1

    p.anchor = generate_anchor(p.total_stride, p.scales, p.ratios, int(p.score_size))

    avg_chans = np.mean(im, axis=(0, 1))

    wc_z = target_sz[0] + p.context_amount * sum(target_sz)
    hc_z = target_sz[1] + p.context_amount * sum(target_sz)
    s_z = round(np.sqrt(wc_z * hc_z))
    # initialize the exemplar
    z_crop = get_subwindow_tracking(im, target_pos, p.exemplar_size, s_z, avg_chans)

    z = Variable(z_crop.unsqueeze(0))
    # net.temple(z.cuda())
    net.temple(z)

    if p.windowing == 'cosine':
        window = np.outer(np.hanning(p.score_size), np.hanning(p.score_size))
    elif p.windowing == 'uniform':
        window = np.ones((p.score_size, p.score_size))
    window = np.tile(window.flatten(), p.anchor_num)

    state['p'] = p
    state['net'] = net
    state['avg_chans'] = avg_chans
    state['window'] = window
    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    return state


def SiamRPN_track(state, im):
    p = state['p']
    net = state['net']
    avg_chans = state['avg_chans']
    window = state['window']
    target_pos = state['target_pos']
    target_sz = state['target_sz']

    wc_z = target_sz[1] + p.context_amount * sum(target_sz)
    hc_z = target_sz[0] + p.context_amount * sum(target_sz)
    s_z = np.sqrt(wc_z * hc_z)
    scale_z = p.exemplar_size / s_z
    d_search = (p.instance_size - p.exemplar_size) / 2
    pad = d_search / scale_z
    s_x = s_z + 2 * pad

    # extract scaled crops for search region x at previous target position
    x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0))

    # target_pos, target_sz, score = tracker_eval(net, x_crop.cuda(), target_pos, target_sz * scale_z, window, scale_z, p)
    target_pos, target_sz, score = tracker_eval(net, x_crop, target_pos, target_sz * scale_z, window, scale_z, p)

    target_pos[0] = max(0, min(state['im_w'], target_pos[0]))
    target_pos[1] = max(0, min(state['im_h'], target_pos[1]))
    target_sz[0] = max(10, min(state['im_w'], target_sz[0]))
    target_sz[1] = max(10, min(state['im_h'], target_sz[1]))
    state['target_pos'] = target_pos
    state['target_sz'] = target_sz
    state['score'] = score
    return state

--------------------------------------------------------------------------------
/DaSiamRPN/test_otb.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# DaSiamRPN
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
#!/usr/bin/python

import argparse, cv2, torch, json
import numpy as np
from os import makedirs
from os.path import realpath, dirname, join, isdir, exists

from net import SiamRPNotb
from run_SiamRPN import SiamRPN_init, SiamRPN_track
from utils import rect_2_cxy_wh, cxy_wh_2_rect

parser = argparse.ArgumentParser(description='PyTorch SiamRPN OTB Test')
parser.add_argument('--dataset', dest='dataset', default='OTB2015', help='datasets')
parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
                    help='whether visualize result')


def track_video(model, video):
    toc, regions = 0, []
    image_files, gt = video['image_files'], video['gt']
    for f, image_file in enumerate(image_files):
        im = cv2.imread(image_file)  # TODO: batch load
        tic = cv2.getTickCount()
        if f == 0:  # init
            target_pos, target_sz = rect_2_cxy_wh(gt[f])
            state = SiamRPN_init(im, target_pos, target_sz, model)  # init tracker
            location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
            regions.append(gt[f])
        elif f > 0:  # tracking
            state = SiamRPN_track(state, im)  # track
            location = cxy_wh_2_rect(state['target_pos'] + 1, state['target_sz'])
            regions.append(location)
        toc += cv2.getTickCount() - tic

        if args.visualization and f >= 0:  # visualization
            if f == 0: cv2.destroyAllWindows()
            if len(gt[f]) == 8:
                # builtin int: the np.int alias was removed in NumPy 1.24
                cv2.polylines(im, [np.array(gt[f], int).reshape((-1, 1, 2))], True, (0, 255, 0), 3)
            else:
                # OpenCV drawing functions expect integer pixel coordinates
                cv2.rectangle(im, (int(gt[f, 0]), int(gt[f, 1])),
                              (int(gt[f, 0] + gt[f, 2]), int(gt[f, 1] + gt[f, 3])), (0, 255, 0), 3)
            if len(location) == 8:
                cv2.polylines(im, [location.reshape((-1, 1, 2))], True, (0, 255, 255), 3)
            else:
                location = [int(l) for l in location]
                cv2.rectangle(im, (location[0], location[1]),
                              (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 3)
            cv2.putText(im, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

            cv2.imshow(video['name'], im)
            cv2.waitKey(1)
    toc /= cv2.getTickFrequency()

    # save result
    video_path = join('test', args.dataset, 'SiamRPN_AlexNet_OTB2015')
    if not isdir(video_path): makedirs(video_path)
    result_path = join(video_path, '{:s}.txt'.format(video['name']))
    with open(result_path, "w") as fin:
        for x in regions:
            fin.write(','.join([str(i) for i in x]) + '\n')

    print('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps'.format(
        v_id, video['name'], toc, f / toc))
    return f / toc


def load_dataset(dataset):
    base_path = join(realpath(dirname(__file__)), 'data', dataset)
    if not exists(base_path):
        print("Please download OTB dataset into `data` folder!")
        exit()
    json_path = join(realpath(dirname(__file__)), 'data', dataset + '.json')
    info = json.load(open(json_path, 'r'))
    for v in info.keys():
        path_name = info[v]['name']
        info[v]['image_files'] = [join(base_path, path_name, 'img', im_f) for im_f in info[v]['image_files']]
        info[v]['gt'] = np.array(info[v]['gt_rect']) - [1, 1, 0, 0]  # our tracker is 0-index
        info[v]['name'] = v
    return info


def main():
    global args, v_id
    args = parser.parse_args()

    net = SiamRPNotb()
    net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNOTB.model')))
    # keep the model on the CPU: run_SiamRPN.py feeds CPU tensors (its .cuda() calls are commented out)
    net.eval()

    dataset = load_dataset(args.dataset)
    fps_list = []
    for v_id, video in enumerate(dataset.keys()):
        fps_list.append(track_video(net, dataset[video]))
    print('Mean Running Speed {:.1f}fps'.format(np.mean(np.array(fps_list))))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/DaSiamRPN/utils.py:
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # DaSiamRPN 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | import cv2 7 | import torch 8 | import numpy as np 9 | 10 | 11 | def to_numpy(tensor): 12 | if torch.is_tensor(tensor): 13 | return tensor.cpu().numpy() 14 | elif type(tensor).__module__ != 'numpy': 15 | raise ValueError("Cannot convert {} to numpy array" 16 | .format(type(tensor))) 17 | return tensor 18 | 19 | 20 | def to_torch(ndarray): 21 | if type(ndarray).__module__ == 'numpy': 22 | return torch.from_numpy(ndarray) 23 | elif not torch.is_tensor(ndarray): 24 | raise ValueError("Cannot convert {} to torch tensor" 25 | .format(type(ndarray))) 26 | return ndarray 27 | 28 | 29 | def im_to_numpy(img): 30 | img = to_numpy(img) 31 | img = np.transpose(img, (1, 2, 0)) # H*W*C 32 | return img 33 | 34 | 35 | def im_to_torch(img): 36 | img = np.transpose(img, (2, 0, 1)) # C*H*W 37 | img = to_torch(img).float() 38 | return img 39 | 40 | 41 | def torch_to_img(img): 42 | img = to_numpy(torch.squeeze(img, 0)) 43 | img = np.transpose(img, (1, 2, 0)) # H*W*C 44 | return img 45 | 46 | 47 | def get_subwindow_tracking(im, pos, model_sz, original_sz, avg_chans, out_mode='torch', new=False): 48 | 49 | if isinstance(pos, float): 50 | pos = [pos, pos] 51 | sz = original_sz 52 | im_sz = im.shape 53 | c = (original_sz+1) / 2 54 | context_xmin = round(pos[0] - c) # floor(pos(2) - sz(2) / 2); 55 | context_xmax = context_xmin + sz - 1 56 | context_ymin = round(pos[1] - c) # floor(pos(1) - sz(1) / 2); 57 | context_ymax = context_ymin + sz - 1 58 | left_pad = int(max(0., -context_xmin)) 59 | top_pad = int(max(0., -context_ymin)) 60 | right_pad = int(max(0., context_xmax - im_sz[1] + 1)) 61 | bottom_pad = int(max(0., context_ymax - im_sz[0] + 1)) 62 | 63 | context_xmin = 
context_xmin + left_pad 64 | context_xmax = context_xmax + left_pad 65 | context_ymin = context_ymin + top_pad 66 | context_ymax = context_ymax + top_pad 67 | 68 | # zzp: a simpler and faster version 69 | r, c, k = im.shape 70 | if any([top_pad, bottom_pad, left_pad, right_pad]): 71 | te_im = np.zeros((r + top_pad + bottom_pad, c + left_pad + right_pad, k), np.uint8) # 0 is better than 1 initialization 72 | te_im[top_pad:top_pad + r, left_pad:left_pad + c, :] = im 73 | if top_pad: 74 | te_im[0:top_pad, left_pad:left_pad + c, :] = avg_chans 75 | if bottom_pad: 76 | te_im[r + top_pad:, left_pad:left_pad + c, :] = avg_chans 77 | if left_pad: 78 | te_im[:, 0:left_pad, :] = avg_chans 79 | if right_pad: 80 | te_im[:, c + left_pad:, :] = avg_chans 81 | im_patch_original = te_im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :] 82 | else: 83 | im_patch_original = im[int(context_ymin):int(context_ymax + 1), int(context_xmin):int(context_xmax + 1), :] 84 | 85 | if not np.array_equal(model_sz, original_sz): 86 | im_patch = cv2.resize(im_patch_original, (model_sz, model_sz)) # zzp: use cv2 for better speed 87 | else: 88 | im_patch = im_patch_original 89 | 90 | return im_to_torch(im_patch) if out_mode == 'torch' else im_patch 91 | 92 | 93 | def cxy_wh_2_rect(pos, sz): 94 | return np.array([pos[0]-sz[0]/2, pos[1]-sz[1]/2, sz[0], sz[1]]) # 0-index 95 | 96 | 97 | def rect_2_cxy_wh(rect): 98 | return np.array([rect[0]+rect[2]/2, rect[1]+rect[3]/2]), np.array([rect[2], rect[3]]) # 0-index 99 | 100 | 101 | def get_axis_aligned_bbox(region): 102 | try: 103 | region = np.array([region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1], 104 | region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]]) 105 | except (TypeError, IndexError): # region is already a flat sequence of coordinates 106 | region = np.array(region) 107 | cx = np.mean(region[0::2]) 108 | cy = np.mean(region[1::2]) 109 | x1 = min(region[0::2]) 110 | x2 = max(region[0::2]) 111 | y1 = 
min(region[1::2]) 112 | y2 = max(region[1::2]) 113 | A1
= np.linalg.norm(region[0:2] - region[2:4]) * np.linalg.norm(region[2:4] - region[4:6]) 114 | A2 = (x2 - x1) * (y2 - y1) 115 | s = np.sqrt(A1 / A2) 116 | w = s * (x2 - x1) + 1 117 | h = s * (y2 - y1) + 1 118 | return cx, cy, w, h -------------------------------------------------------------------------------- /DaSiamRPN/vot.py: -------------------------------------------------------------------------------- 1 | """ 2 | \file vot.py 3 | 4 | @brief Python utility functions for VOT integration 5 | 6 | @author Luka Cehovin, Alessio Dore 7 | 8 | @date 2016 9 | 10 | """ 11 | 12 | import sys 13 | import copy 14 | import collections 15 | 16 | try: 17 | import trax 18 | import trax.server 19 | TRAX = True 20 | except ImportError: 21 | TRAX = False 22 | 23 | Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height']) 24 | Point = collections.namedtuple('Point', ['x', 'y']) 25 | Polygon = collections.namedtuple('Polygon', ['points']) 26 | 27 | def parse_region(string): 28 | tokens = list(map(float, string.split(','))) # list() so len() and indexing also work on Python 3 29 | if len(tokens) == 4: 30 | return Rectangle(tokens[0], tokens[1], tokens[2], tokens[3]) 31 | elif len(tokens) % 2 == 0 and len(tokens) > 4: 32 | return Polygon([Point(tokens[i],tokens[i+1]) for i in range(0,len(tokens),2)]) 33 | return None 34 | 35 | def encode_region(region): 36 | if isinstance(region, Polygon): 37 | return ','.join(['{},{}'.format(p.x,p.y) for p in region.points]) 38 | elif isinstance(region, Rectangle): 39 | return '{},{},{},{}'.format(region.x, region.y, region.width, region.height) 40 | else: 41 | return "" 42 | 43 | def convert_region(region, to): 44 | 45 | if to == 'rectangle': 46 | 47 | if isinstance(region, Rectangle): 48 | return copy.copy(region) 49 | elif isinstance(region, Polygon): 50 | top = sys.float_info.max 51 | bottom = -sys.float_info.max # sys.float_info.min is the smallest positive float, not the most negative one 52 | left = sys.float_info.max 53 | right = -sys.float_info.max 54 
| 55 | for point in region.points: 56 | top = min(top, point.y) 57 | bottom = max(bottom, point.y) 58 |
left = min(left, point.x) 59 | right = max(right, point.x) 60 | 61 | return Rectangle(left, top, right - left, bottom - top) 62 | 63 | else: 64 | return None 65 | if to == 'polygon': 66 | 67 | if isinstance(region, Rectangle): 68 | points = [] 69 | points.append((region.x, region.y)) 70 | points.append((region.x + region.width, region.y)) 71 | points.append((region.x + region.width, region.y + region.height)) 72 | points.append((region.x, region.y + region.height)) 73 | return Polygon(points) 74 | 75 | elif isinstance(region, Polygon): 76 | return copy.copy(region) 77 | else: 78 | return None 79 | 80 | return None 81 | 82 | class VOT(object): 83 | """ Base class for Python VOT integration """ 84 | def __init__(self, region_format): 85 | """ Constructor 86 | 87 | Args: 88 | region_format: Region format options 89 | """ 90 | assert(region_format in ['rectangle', 'polygon']) 91 | if TRAX: 92 | options = trax.server.ServerOptions(region_format, trax.image.PATH) 93 | self._trax = trax.server.Server(options) 94 | 95 | request = self._trax.wait() 96 | assert(request.type == 'initialize') 97 | if request.region.type == 'polygon': 98 | self._region = Polygon([Point(x[0], x[1]) for x in request.region.points]) 99 | else: 100 | self._region = Rectangle(request.region.x, request.region.y, request.region.width, request.region.height) 101 | self._image = str(request.image) 102 | self._trax.status(request.region) 103 | else: 104 | self._files = [x.strip('\n') for x in open('images.txt', 'r').readlines()] 105 | self._frame = 0 106 | self._region = convert_region(parse_region(open('region.txt', 'r').readline()), region_format) 107 | self._result = [] 108 | 109 | def region(self): 110 | """ 111 | Send configuration message to the client and receive the initialization 112 | region and the path of the first image 113 | 114 | Returns: 115 | initialization region 116 | """ 117 | 118 | return self._region 119 | 120 | def report(self, region, confidence = 0): 121 | """ 122 | Report the 
tracking results to the client 123 | 124 | Arguments: 125 | region: region for the frame 126 | """ 127 | assert(isinstance(region, Rectangle) or isinstance(region, Polygon)) 128 | if TRAX: 129 | if isinstance(region, Polygon): 130 | tregion = trax.region.Polygon([(x.x, x.y) for x in region.points]) 131 | else: 132 | tregion = trax.region.Rectangle(region.x, region.y, region.width, region.height) 133 | self._trax.status(tregion, {"confidence" : confidence}) 134 | else: 135 | self._result.append(region) 136 | self._frame += 1 137 | 138 | def frame(self): 139 | """ 140 | Get a frame (image path) from client 141 | 142 | Returns: 143 | absolute path of the image 144 | """ 145 | if TRAX: 146 | if hasattr(self, "_image"): 147 | image = str(self._image) 148 | del self._image 149 | return image 150 | 151 | request = self._trax.wait() 152 | 153 | if request.type == 'frame': 154 | return str(request.image) 155 | else: 156 | return None 157 | 158 | else: 159 | if self._frame >= len(self._files): 160 | return None 161 | return self._files[self._frame] 162 | 163 | def quit(self): 164 | if TRAX: 165 | self._trax.quit() 166 | elif hasattr(self, '_result'): 167 | with open('output.txt', 'w') as f: 168 | for r in self._result: 169 | f.write(encode_region(r)) 170 | f.write('\n') 171 | 172 | def __del__(self): 173 | self.quit() 174 | 175 | -------------------------------------------------------------------------------- /DaSiamRPN/vot_SiamRPN.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # DaSiamRPN 3 | # Licensed under The MIT License 4 | # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) 5 | # -------------------------------------------------------- 6 | #!/usr/bin/python 7 | 8 | import vot 9 | from vot import Rectangle 10 | import sys 11 | import cv2 # imread 12 | import torch 13 | import numpy as np 14 | from os.path import realpath, dirname, join 15 | 16 | from net import 
SiamRPNBIG 17 | from run_SiamRPN import SiamRPN_init, SiamRPN_track 18 | from utils import get_axis_aligned_bbox, cxy_wh_2_rect 19 | 20 | # load net 21 | net_file = join(realpath(dirname(__file__)), 'SiamRPNBIG.model') 22 | net = SiamRPNBIG() 23 | net.load_state_dict(torch.load(net_file)) 24 | net.eval().cuda() 25 | 26 | # warm up 27 | for i in range(10): 28 | net.temple(torch.autograd.Variable(torch.FloatTensor(1, 3, 127, 127)).cuda()) 29 | net(torch.autograd.Variable(torch.FloatTensor(1, 3, 255, 255)).cuda()) 30 | 31 | # start to track 32 | handle = vot.VOT("polygon") 33 | Polygon = handle.region() 34 | cx, cy, w, h = get_axis_aligned_bbox(Polygon) 35 | 36 | image_file = handle.frame() 37 | if not image_file: 38 | sys.exit(0) 39 | 40 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) 41 | im = cv2.imread(image_file) # HxWxC 42 | state = SiamRPN_init(im, target_pos, target_sz, net) # init tracker 43 | while True: 44 | image_file = handle.frame() 45 | if not image_file: 46 | break 47 | im = cv2.imread(image_file) # HxWxC 48 | state = SiamRPN_track(state, im) # track 49 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz']) 50 | 51 | handle.report(Rectangle(res[0], res[1], res[2], res[3])) 52 | 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Object-detecting-and-tracking 2 | This project combines object detection and object tracking. The detector detects objects in the images captured by the camera, and the tracker then follows one of the detected objects (the first detection of the class selected by `detect_class` in 'SSD+SiamRPN.py'). The detector is an SSD model and the tracker is a SiamRPN (DaSiamRPN) model. Both models run in real time and neither needs a GPU for inference: they can run on a CPU alone. Run 'SSD+SiamRPN.py' to perform object detection and tracking.
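The hand-off between the two models can be sketched as follows. SSD reports each box as normalized `(ymin, xmin, ymax, xmax)` coordinates in [0, 1], while the SiamRPN tracker is initialized with a pixel center (`target_pos`) and size (`target_sz`); 'SSD+SiamRPN.py' performs this conversion inline. The helper names `first_box_of_class` and `ssd_box_to_target` are hypothetical, and the detections in the example are made up for illustration.

```python
import numpy as np


def first_box_of_class(rclasses, rbboxes, wanted_class):
    """Return the first SSD box whose class id equals wanted_class, or None."""
    for cls, box in zip(rclasses, rbboxes):
        if cls == wanted_class:
            return box
    return None


def ssd_box_to_target(box, frame_width, frame_height):
    """Convert a normalized SSD box (ymin, xmin, ymax, xmax) into the
    (target_pos, target_sz) pair, in pixels, used to initialize SiamRPN."""
    ymin, xmin, ymax, xmax = box
    # Scale to pixels and truncate to int, as the script does.
    x0, x1 = int(xmin * frame_width), int(xmax * frame_width)
    y0, y1 = int(ymin * frame_height), int(ymax * frame_height)
    target_pos = np.array([(x0 + x1) / 2, (y0 + y1) / 2])  # center (cx, cy)
    target_sz = np.array([x1 - x0, y1 - y0])                # size (w, h)
    return target_pos, target_sz


# Made-up detections: a car (class 7) and a person (class 15) on a 640x480 frame.
rclasses = np.array([7, 15])
rbboxes = np.array([[0.1, 0.1, 0.3, 0.3],
                    [0.25, 0.5, 0.75, 0.75]])
box = first_box_of_class(rclasses, rbboxes, 15)
pos, sz = ssd_box_to_target(box, 640, 480)
print(pos, sz)  # [400. 240.] [160 240]
```

In the script, the resulting pair is then passed to `SiamRPN_init(frame, target_pos, target_sz, net)` to start tracking.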
3 | 4 | If you are interested in the principles of these algorithms, please read these papers: 5 | 6 | * SSD: [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325) 7 | * SiamRPN: [High Performance Visual Tracking with Siamese Region Proposal Network](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf) 8 | 9 | This repository includes the original TensorFlow code of SSD from: 10 | * https://github.com/wcwowwwww/SSD-Tensorflow 11 | 12 | This repository includes the original PyTorch code of SiamRPN from: 13 | * https://github.com/wcwowwwww/DaSiamRPN 14 | 15 | ### Prerequisites 16 | 17 | * python 18 | * numpy 19 | * opencv 20 | * tensorflow (1.x: the SSD code uses `tf.contrib.slim`, which is not available in TensorFlow 2.x) 21 | * pytorch 22 | 23 | ### Pretrained model for SSD 24 | 25 | Download the model from https://drive.google.com/file/d/0B0qPCUZ-3YwWUXh4UHJrd1RDM3c/view?usp=sharing and put it into the subfolder 'SSD/checkpoints' so that the detector can find and load the pretrained model ('SSD+SiamRPN.py' loads 'SSD/checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt'). 26 | 27 | ### Pretrained model for SiamRPN 28 | 29 | Download the model file 'SiamRPNVOT.model' from https://drive.google.com/drive/folders/1BtIkp5pB6aqePQGlMb2_Z7bfPy6XEj6H and put it into the subfolder 'DaSiamRPN' so that the tracker can find and load the pretrained model.
30 | 31 | # Combining object detection and object tracking 32 | 33 | ### Object detection 34 | The object detector is SSD, implemented in TensorFlow. For details, see: https://github.com/wcwowwwww/SSD-Tensorflow 35 | 36 | Paper: [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325) 37 | 38 | The pretrained model can be downloaded from https://drive.google.com/file/d/0B0qPCUZ-3YwWUXh4UHJrd1RDM3c/view?usp=sharing ; after downloading, put it in the SSD/checkpoints folder. 39 | 40 | ### Object tracking 41 | The object tracker is SiamRPN, implemented in PyTorch. For details, see: https://github.com/wcwowwwww/DaSiamRPN 42 | 43 | Paper: [High Performance Visual Tracking with Siamese Region Proposal Network](http://openaccess.thecvf.com/content_cvpr_2018/papers/Li_High_Performance_Visual_CVPR_2018_paper.pdf) 44 | 45 | 46 | For the pretrained model, download the SiamRPNVOT.model file from https://drive.google.com/drive/folders/1BtIkp5pB6aqePQGlMb2_Z7bfPy6XEj6H and put it in the DaSiamRPN folder. 47 | 48 | Both algorithms perform very well in accuracy and speed: they do not need a GPU and run smoothly on a CPU (training, of course, still requires GPU acceleration). 49 | 50 | Runtime environment: python, opencv, tensorflow (1.x, since the SSD code uses `tf.contrib.slim`) and pytorch. 51 | 52 | After downloading the code and the models, run SSD+SiamRPN.py to detect and track objects in front of the camera. 53 | 54 | ### spaceshooter 55 | The spaceshooter folder contains a space shooting game. Run detection_tracking_game.py to open the game; after pressing Enter to start, you steer the ship by moving the tracked object in front of the camera. 
56 | 57 | 58 | -------------------------------------------------------------------------------- /SSD+SiamRPN.py: -------------------------------------------------------------------------------- 1 | # ====================================================================== 2 | # 3 | # filename : SSD+SiamRPN.py 4 | # description : 5 | # 6 | # created by Wang An at 2019.4.18 7 | # 8 | # ====================================================================== 9 | 10 | import sys 11 | from os.path import realpath, dirname, join 12 | 13 | import cv2 14 | import numpy as np 15 | import tensorflow as tf 16 | import torch 17 | 18 | from net import SiamRPNvot 19 | from nets import ssd_vgg_300, np_methods 20 | from preprocessing import ssd_vgg_preprocessing 21 | from run_SiamRPN import SiamRPN_init, SiamRPN_track 22 | from utils import cxy_wh_2_rect 23 | 24 | # ------------------------------- # 25 | 26 | 27 | ''' 28 |
classes: 29 | 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles 30 | 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows 31 | 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People 32 | 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors 33 | ''' 34 | detect_class = 15 35 | 36 | slim = tf.contrib.slim 37 | 38 | # TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!! 39 | gpu_options = tf.GPUOptions(allow_growth=True) 40 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options) 41 | isess = tf.InteractiveSession(config=config) 42 | 43 | # Input placeholder. 44 | net_shape = (300, 300) 45 | data_format = 'NHWC' 46 | img_input = tf.placeholder(tf.uint8, shape=(None, None, 3)) 47 | # Evaluation pre-processing: resize to SSD net shape. 48 | image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval( 49 | img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE) 50 | image_4d = tf.expand_dims(image_pre, 0) 51 | 52 | # Define the SSD model. 53 | reuse = True if 'ssd_net' in locals() else None 54 | ssd_net = ssd_vgg_300.SSDNet() 55 | with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)): 56 | predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse) 57 | 58 | # Restore SSD model. 59 | # ckpt_filename = 'checkpoints/ssd_300_vgg.ckpt' 60 | ckpt_filename = 'SSD/checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt' 61 | 62 | isess.run(tf.global_variables_initializer()) 63 | saver = tf.train.Saver() 64 | saver.restore(isess, ckpt_filename) 65 | 66 | # SSD default anchor boxes. 67 | ssd_anchors = ssd_net.anchors(net_shape) 68 | 69 | 70 | # Main image processing routine. 71 | def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)): 72 | # Run SSD network. 
73 | rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img], 74 | feed_dict={img_input: img}) 75 | 76 | # Get classes and bboxes from the net outputs. 77 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select( 78 | rpredictions, rlocalisations, ssd_anchors, 79 | select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True) 80 | 81 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes) 82 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400) 83 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold) 84 | # Resize bboxes to original image shape. Note: useless for Resize.WARP! 85 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes) 86 | return rclasses, rscores, rbboxes 87 | 88 | 89 | def get_bboxes(rclasses, rbboxes): 90 | # collect each detection's class and normalized corner coordinates (ymin, xmin, ymax, xmax) 91 | 92 | number_detections = rclasses.shape[0] 93 | object_bboxes = [] 94 | for i in range(number_detections): 95 | object_bbox = dict() 96 | object_bbox['i'] = i 97 | object_bbox['class'] = rclasses[i] 98 | object_bbox['y_min'] = rbboxes[i, 0] 99 | object_bbox['x_min'] = rbboxes[i, 1] 100 | object_bbox['y_max'] = rbboxes[i, 2] 101 | object_bbox['x_max'] = rbboxes[i, 3] 102 | object_bboxes.append(object_bbox) 103 | return object_bboxes 104 | 105 | 106 | # load net 107 | net = SiamRPNvot() 108 | net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'DaSiamRPN/SiamRPNVOT.model'))) 109 | 110 | net.eval() 111 | 112 | # open video capture 113 | video = cv2.VideoCapture(0) 114 | 115 | if not video.isOpened(): 116 | print("Could not open video") 117 | sys.exit() 118 | 119 | index = True 120 | while index: 121 | # Read first frame.
122 | ok, frame = video.read() 123 | if not ok: 124 | print('Cannot read video file') 125 | sys.exit() 126 | 127 | # Define an initial bounding box 128 | height = frame.shape[0] 129 | width = frame.shape[1] 130 | rclasses, rscores, rbboxes = process_image(frame) 131 | bboxes = get_bboxes(rclasses, rbboxes) 132 | for bbox in bboxes: 133 | if bbox['class'] == detect_class: 134 | print(bbox) 135 | ymin = int(bbox['y_min'] * height) 136 | xmin = int((bbox['x_min']) * width) 137 | ymax = int(bbox['y_max'] * height) 138 | xmax = int((bbox['x_max']) * width) 139 | cx = (xmin + xmax) / 2 140 | cy = (ymin + ymax) / 2 141 | h = ymax - ymin 142 | w = xmax - xmin 143 | new_bbox = (cx, cy, w, h) 144 | print(new_bbox) 145 | index = False 146 | break 147 | 148 | # tracker init 149 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) 150 | state = SiamRPN_init(frame, target_pos, target_sz, net) 151 | 152 | # tracking and visualization 153 | toc = 0 154 | count_number = 0 155 | 156 | while True: 157 | # Read a new frame 158 | ok, frame = video.read() 159 | if not ok: 160 | break 161 | 162 | # Start timer 163 | tic = cv2.getTickCount() 164 | 165 | # Update tracker 166 | state = SiamRPN_track(state, frame) # track 167 | # print(state) 168 | 169 | toc += cv2.getTickCount() - tic 170 | 171 | if state: 172 | 173 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz']) 174 | res = [int(l) for l in res] 175 | cv2.rectangle(frame, (res[0], res[1]), (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3) 176 | count_number += 1 177 | 178 | if (not state) or count_number % 40 == 3: 179 | # Tracking failure 180 | cv2.putText(frame, "Tracking failure detected", (100, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2) 181 | index = True 182 | while index: 183 | ok, frame = video.read() 184 | rclasses, rscores, rbboxes = process_image(frame) 185 | bboxes = get_bboxes(rclasses, rbboxes) 186 | for bbox in bboxes: 187 | if bbox['class'] == detect_class: 188 | ymin = 
int(bbox['y_min'] * height) 189 | xmin = int(bbox['x_min'] * width) 190 | ymax = int(bbox['y_max'] * height) 191 | xmax = int(bbox['x_max'] * width) 192 | cx = (xmin + xmax) / 2 193 | cy = (ymin + ymax) / 2 194 | h = ymax - ymin 195 | w = xmax - xmin 196 | new_bbox = (cx, cy, w, h) 197 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) 198 | state = SiamRPN_init(frame, target_pos, target_sz, net) 199 | 200 | p1 = (int(xmin), int(ymin)) 201 | p2 = (int(xmax), int(ymax)) 202 | cv2.rectangle(frame, p1, p2, (0, 255, 0), 2, 1) 203 | 204 | index = False 205 | 206 | break 207 | 208 | cv2.imshow('SSD+SiamRPN', frame) 209 | 210 | # Exit if ESC pressed 211 | k = cv2.waitKey(1) & 0xff 212 | if k == 27: 213 | break 214 | 215 | video.release() 216 | cv2.destroyAllWindows() 217 | -------------------------------------------------------------------------------- /SSD/README.md: -------------------------------------------------------------------------------- 1 | # SSD: Single Shot MultiBox Detector in TensorFlow 2 | 3 | SSD is a unified framework for object detection with a single network. It was originally introduced in this research [article](http://arxiv.org/abs/1512.02325). 4 | 5 | This repository contains a TensorFlow re-implementation of the original [Caffe code](https://github.com/weiliu89/caffe/tree/ssd). At present, it only implements VGG-based SSD networks (with 300 and 512 inputs), but the architecture of the project is modular and should make it easy to implement and train other SSD variants (ResNet- or Inception-based, for instance). The current TF checkpoints have been converted directly from SSD Caffe models. 6 | 7 | The organisation is inspired by the TF-Slim models repository containing the implementation of popular architectures (ResNet, Inception and VGG). Hence, it is separated into three main parts: 8 | * datasets: interface to popular datasets (Pascal VOC, COCO, ...)
and scripts to convert them to TF-Records; 9 | * networks: definition of SSD networks, and common encoding and decoding methods (we refer to the paper on this precise topic); 10 | * pre-processing: pre-processing and data augmentation routines, inspired by the original VGG and Inception implementations. 11 | 12 | ## SSD minimal example 13 | 14 | The [SSD Notebook](notebooks/ssd_notebook.ipynb) contains a minimal example of the SSD TensorFlow pipeline. In short, detection consists of two main steps: running the SSD network on the image and post-processing the output with common algorithms (top-k filtering and Non-Maximum Suppression). 15 | 16 | Here are two examples of successful detection outputs: 17 | ![](pictures/ex1.png "SSD anchors") 18 | ![](pictures/ex2.png "SSD anchors") 19 | 20 | To run the notebook, you first have to unzip the checkpoint files in ./checkpoints 21 | ```bash 22 | unzip ssd_300_vgg.ckpt.zip 23 | ``` 24 | and then start a jupyter notebook with 25 | ```bash 26 | jupyter notebook notebooks/ssd_notebook.ipynb 27 | ``` 28 | 29 | 30 | ## Datasets 31 | 32 | The current version only supports Pascal VOC datasets (2007 and 2012). In order to be used for training an SSD model, they need to be converted to TF-Records using the `tf_convert_data.py` script: 33 | ```bash 34 | DATASET_DIR=./VOC2007/test/ 35 | OUTPUT_DIR=./tfrecords 36 | python tf_convert_data.py \ 37 | --dataset_name=pascalvoc \ 38 | --dataset_dir=${DATASET_DIR} \ 39 | --output_name=voc_2007_train \ 40 | --output_dir=${OUTPUT_DIR} 41 | ``` 42 | Note that the previous command generates a collection of TF-Records instead of a single file, in order to ease shuffling during training.
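To illustrate why several record files make shuffling easier, here is a minimal pure-Python sketch of splitting a dataset into fixed-size shards. The shard size of 200 and the `{name}_{idx:03d}.tfrecord` naming pattern are illustrative assumptions, not necessarily what `tf_convert_data.py` uses, and `plan_shards` is a hypothetical helper.

```python
import random


def plan_shards(example_ids, output_name, samples_per_file=200):
    """Split example ids into consecutive fixed-size shards.

    Returns (filename, ids) pairs. Shard size and naming pattern are
    illustrative assumptions.
    """
    shards = []
    for idx, start in enumerate(range(0, len(example_ids), samples_per_file)):
        filename = '{}_{:03d}.tfrecord'.format(output_name, idx)
        shards.append((filename, example_ids[start:start + samples_per_file]))
    return shards


shards = plan_shards(list(range(450)), 'voc_2007_train')
print([(name, len(ids)) for name, ids in shards])
# [('voc_2007_train_000.tfrecord', 200), ('voc_2007_train_001.tfrecord', 200),
#  ('voc_2007_train_002.tfrecord', 50)]

# A training reader can reshuffle the file order every epoch (and interleave
# reads across files), which is far cheaper than shuffling one huge file.
epoch_order = [name for name, _ in shards]
random.Random(0).shuffle(epoch_order)
```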
43 | 44 | ## Evaluation on Pascal VOC 2007 45 | 46 | The present TensorFlow implementation of SSD models has the following performance: 47 | 48 | | Model | Training data | Testing data | mAP | FPS | 49 | |--------|:---------:|:------:|:------:|:------:| 50 | | [SSD-300 VGG-based](https://drive.google.com/open?id=0B0qPCUZ-3YwWZlJaRTRRQWRFYXM) | VOC07+12 trainval | VOC07 test | 0.778 | - | 51 | | [SSD-300 VGG-based](https://drive.google.com/file/d/0B0qPCUZ-3YwWUXh4UHJrd1RDM3c/view?usp=sharing) | VOC07+12+COCO trainval | VOC07 test | 0.817 | - | 52 | | [SSD-512 VGG-based](https://drive.google.com/open?id=0B0qPCUZ-3YwWT1RCLVZNN3RTVEU) | VOC07+12+COCO trainval | VOC07 test | 0.837 | - | 53 | 54 | We are working hard to reproduce the same performance as the original [Caffe implementation](https://github.com/weiliu89/caffe/tree/ssd)! 55 | 56 | After downloading and extracting the previous checkpoints, the evaluation metrics should be reproducible by running the following command: 57 | ```bash 58 | EVAL_DIR=./logs/ 59 | CHECKPOINT_PATH=./checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt 60 | python eval_ssd_network.py \ 61 | --eval_dir=${EVAL_DIR} \ 62 | --dataset_dir=${DATASET_DIR} \ 63 | --dataset_name=pascalvoc_2007 \ 64 | --dataset_split_name=test \ 65 | --model_name=ssd_300_vgg \ 66 | --checkpoint_path=${CHECKPOINT_PATH} \ 67 | --batch_size=1 68 | ``` 69 | The evaluation script provides estimates of the precision-recall curve and computes the mAP metrics following the Pascal VOC 2007 and 2012 guidelines.
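The two Pascal VOC conventions differ in how average precision (AP) is read off the precision-recall curve: VOC 2007 averages the interpolated precision at 11 recall points (0, 0.1, ..., 1), while VOC 2010 and later (used for VOC 2012) take the area under the monotone precision envelope. A minimal NumPy sketch, assuming `recall` and `precision` are already sorted by decreasing detection score; it illustrates the two conventions and is not the code used by `eval_ssd_network.py`.

```python
import numpy as np


def voc07_ap(recall, precision):
    """11-point interpolated AP (Pascal VOC 2007 convention)."""
    ap = 0.0
    for t in np.arange(0.0, 1.1, 0.1):
        # Interpolated precision: best precision at any recall >= t (0 if none).
        mask = recall >= t
        p = np.max(precision[mask]) if mask.any() else 0.0
        ap += p / 11.0
    return float(ap)


def voc12_ap(recall, precision):
    """Area under the monotone precision envelope (VOC 2010+ convention)."""
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    # Make precision non-increasing when read from right to left.
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    # Sum the rectangles where recall changes value.
    idx = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))


# Toy curve with made-up numbers.
recall = np.array([0.25, 0.5, 0.75])
precision = np.array([1.0, 0.8, 0.6])
print(round(voc07_ap(recall, precision), 3), round(voc12_ap(recall, precision), 3))  # 0.6 0.6
```

The two values coincide for this toy curve but differ in general: for a single point at recall 0.55 with precision 1.0, the 11-point AP is 6/11 (about 0.545) while the envelope area is 0.55.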
70 | 71 | In addition, if one wants to experiment with or test a different Caffe SSD checkpoint, it can be converted to a TensorFlow checkpoint as follows: 72 | ```sh 73 | CAFFE_MODEL=./ckpts/SSD_300x300_ft_VOC0712/VGG_VOC0712_SSD_300x300_ft_iter_120000.caffemodel 74 | python caffe_to_tensorflow.py \ 75 | --model_name=ssd_300_vgg \ 76 | --num_classes=21 \ 77 | --caffemodel_path=${CAFFE_MODEL} 78 | ``` 79 | 80 | ## Training 81 | 82 | The script `train_ssd_network.py` is in charge of training the network. Similarly to TF-Slim models, one can pass numerous options to the training process (dataset, optimiser, hyper-parameters, model, ...). In particular, it is possible to provide a checkpoint file which can be used as a starting point in order to fine-tune a network. 83 | 84 | ### Fine-tuning existing SSD checkpoints 85 | 86 | The easiest way to fine-tune an SSD model is to start from a pre-trained SSD network (VGG-300 or VGG-512). For instance, one can fine-tune a model starting from the former as follows: 87 | ```bash 88 | DATASET_DIR=./tfrecords 89 | TRAIN_DIR=./logs/ 90 | CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt 91 | python train_ssd_network.py \ 92 | --train_dir=${TRAIN_DIR} \ 93 | --dataset_dir=${DATASET_DIR} \ 94 | --dataset_name=pascalvoc_2012 \ 95 | --dataset_split_name=train \ 96 | --model_name=ssd_300_vgg \ 97 | --checkpoint_path=${CHECKPOINT_PATH} \ 98 | --save_summaries_secs=60 \ 99 | --save_interval_secs=600 \ 100 | --weight_decay=0.0005 \ 101 | --optimizer=adam \ 102 | --learning_rate=0.001 \ 103 | --batch_size=32 104 | ``` 105 | Note that in addition to the training script flags, one may also want to experiment with data augmentation parameters (random cropping, resolution, ...) in `ssd_vgg_preprocessing.py` and/or network parameters (feature layers, anchor boxes, ...) in `ssd_vgg_300.py`/`ssd_vgg_512.py`. 
106 | 107 | Furthermore, the training script can be combined with the evaluation routine in order to monitor the performance of saved checkpoints on a validation dataset.
For that purpose, one can pass the training and evaluation scripts a GPU memory upper limit so that both can run in parallel on the same device. If enough GPU memory is left for the evaluation script, it can be run in parallel as follows: 108 | ```bash 109 | EVAL_DIR=${TRAIN_DIR}/eval 110 | python eval_ssd_network.py \ 111 | --eval_dir=${EVAL_DIR} \ 112 | --dataset_dir=${DATASET_DIR} \ 113 | --dataset_name=pascalvoc_2007 \ 114 | --dataset_split_name=test \ 115 | --model_name=ssd_300_vgg \ 116 | --checkpoint_path=${TRAIN_DIR} \ 117 | --wait_for_checkpoints=True \ 118 | --batch_size=1 \ 119 | --max_num_batches=500 120 | ``` 121 | 122 | ### Fine-tuning a network trained on ImageNet 123 | 124 | One can also try to build a new SSD model on top of a standard architecture (VGG, ResNet, Inception, ...) by adding the `multibox` layers (with specific anchors, ratios, ...) to it. To do so, you can fine-tune a network by loading only the weights of the original architecture and randomly initializing the rest of the network.
For instance, in the case of the [VGG-16 architecture](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz), one can train a new model as follows: 125 | ```bash 126 | DATASET_DIR=./tfrecords 127 | TRAIN_DIR=./log/ 128 | CHECKPOINT_PATH=./checkpoints/vgg_16.ckpt 129 | python train_ssd_network.py \ 130 | --train_dir=${TRAIN_DIR} \ 131 | --dataset_dir=${DATASET_DIR} \ 132 | --dataset_name=pascalvoc_2007 \ 133 | --dataset_split_name=train \ 134 | --model_name=ssd_300_vgg \ 135 | --checkpoint_path=${CHECKPOINT_PATH} \ 136 | --checkpoint_model_scope=vgg_16 \ 137 | --checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \ 138 | --trainable_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \ 139 | --save_summaries_secs=60 \ 140 | --save_interval_secs=600 \ 141 | --weight_decay=0.0005 \ 142 | --optimizer=adam \ 143 | --learning_rate=0.001 \ 144 | --learning_rate_decay_factor=0.94 \ 145 | --batch_size=32 146 | ``` 147 | In the command above, the training script randomly initializes the weights belonging to the `checkpoint_exclude_scopes` and loads the remaining part of the network from the checkpoint file `vgg_16.ckpt`. Note that we also use the `trainable_scopes` parameter to train only the new SSD components at first and leave the rest of the VGG network unchanged.
Once the network has converged to a good first result (~0.5 mAP, for instance), you can fine-tune the complete network as follows: 148 | ```bash 149 | DATASET_DIR=./tfrecords 150 | TRAIN_DIR=./log_finetune/ 151 | CHECKPOINT_PATH=./log/model.ckpt-N 152 | python train_ssd_network.py \ 153 | --train_dir=${TRAIN_DIR} \ 154 | --dataset_dir=${DATASET_DIR} \ 155 | --dataset_name=pascalvoc_2007 \ 156 | --dataset_split_name=train \ 157 | --model_name=ssd_300_vgg \ 158 | --checkpoint_path=${CHECKPOINT_PATH} \ 159 | --checkpoint_model_scope=vgg_16 \ 160 | --save_summaries_secs=60 \ 161 | --save_interval_secs=600 \ 162 | --weight_decay=0.0005 \ 163 | --optimizer=adam \ 164 | --learning_rate=0.00001 \ 165 | --learning_rate_decay_factor=0.94 \ 166 | --batch_size=32 167 | ``` 168 | 169 | A number of pre-trained weights of popular deep architectures can be found on the [TF-Slim models page](https://github.com/tensorflow/models/tree/master/slim). 170 | -------------------------------------------------------------------------------- /SSD/UseSSD.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import tensorflow as tf 3 | 4 | slim = tf.contrib.slim 5 | 6 | from nets import ssd_vgg_300, np_methods 7 | from preprocessing import ssd_vgg_preprocessing 8 | 9 | import multiprocessing 10 | 11 | 12 | def set_centers(): 13 | 14 | print("Starting process: putting object_centers into the queue") 15 | 16 | # TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!! 17 | gpu_options = tf.GPUOptions(allow_growth=True) 18 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options) 19 | isess = tf.InteractiveSession(config=config) 20 | 21 | # Input placeholder. 22 | net_shape = (300, 300) 23 | data_format = 'NHWC' 24 | img_input = tf.placeholder(tf.uint8, shape=(None, None, 3)) 25 | # Evaluation pre-processing: resize to SSD net shape.
26 | image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval( 27 | img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE) 28 | image_4d = tf.expand_dims(image_pre, 0) 29 | 30 | # Define the SSD model. 31 | reuse = True if 'ssd_net' in locals() else None 32 | ssd_net = ssd_vgg_300.SSDNet() 33 | with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)): 34 | predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse) 35 | 36 | # Restore SSD model. 37 | # ckpt_filename = 'checkpoints/ssd_300_vgg.ckpt' 38 | ckpt_filename = 'checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt' 39 | 40 | isess.run(tf.global_variables_initializer()) 41 | saver = tf.train.Saver() 42 | saver.restore(isess, ckpt_filename) 43 | 44 | # SSD default anchor boxes. 45 | ssd_anchors = ssd_net.anchors(net_shape) 46 | 47 | # Main image processing routine. 48 | def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)): 49 | # Run SSD network. 50 | rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img], 51 | feed_dict={img_input: img}) 52 | 53 | # Get classes and bboxes from the net outputs. 54 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select( 55 | rpredictions, rlocalisations, ssd_anchors, 56 | select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True) 57 | 58 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes) 59 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400) 60 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold) 61 | # Resize bboxes to original image shape. Note: useless for Resize.WARP! 
62 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes) 63 | return rclasses, rscores, rbboxes 64 | 65 | def get_centers(rclasses, rbboxes): 66 | # get center location of object 67 | 68 | number_classes = rclasses.shape[0] 69 | object_centers = [] 70 | for i in range(number_classes): 71 | object_center = dict() 72 | object_center['i'] = i 73 | object_center['class'] = rclasses[i] 74 | object_center['x'] = (rbboxes[i, 1] + rbboxes[i, 3]) / 2 # x coordinate of the object center 75 | object_center['y'] = (rbboxes[i, 0] + rbboxes[i, 2]) / 2 # y coordinate of the object center 76 | object_centers.append(object_center) 77 | return object_centers 78 | 79 | count = 0 80 | cap = cv2.VideoCapture(0) 81 | 82 | while count < 100: 83 | # Read a frame from the camera 84 | ret, img = cap.read() 85 | rclasses, rscores, rbboxes = process_image(img) 86 | 87 | ''' 88 | classes: 89 | 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles 90 | 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows 91 | 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People 92 | 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors 93 | ''' 94 | object_centers = get_centers(rclasses, rbboxes) 95 | # print("put object centers: " + str(object_centers)) 96 | for object_center in object_centers: 97 | if object_center['class'] == 5 or object_center['class'] == 7: 98 | new_object_center = object_center 99 | q.put(new_object_center) 100 | count += 1 101 | break 102 | print("Finished putting object centers") 103 | cap.release() 104 | 105 | 106 | 107 | 108 | def print_centers(): 109 | 110 | 111 | print("Starting process: printing object_center") 112 | while True: 113 | if q: 114 | print("get object center:" + str(q.get(True))) 115 | 116 | print("Finished printing") 117 | 118 | 119 | q = multiprocessing.Queue() 120 | 121 | set_process = multiprocessing.Process(target=set_centers) 122 | print_process = multiprocessing.Process(target=print_centers) 123 | 124 | set_process.start() 125 | print_process.start() 126 | 127 | set_process.join() 128 | print_process.terminate() 129 | 130 | print("Exiting main process") 131 | 132 |
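`UseSSD.py` and `code_test.py` both end their reader process with `terminate()`, because the reader loops forever. A common alternative is for the producer to send an end-of-stream marker so the consumer returns on its own and can simply be `join()`ed. The following is a minimal sketch of that pattern under stated assumptions: the names `STOP`, `box_centers`, `produce`, `consume` and `run_pipeline` are hypothetical (not part of this repository), `None` is assumed never to be a real payload, and boxes use the same `(ymin, xmin, ymax, xmax)` order that `get_centers()` relies on.

```python
import multiprocessing as mp

# Hypothetical end-of-stream marker. None is used because it is a singleton that
# survives pickling, so `item is STOP` still works after crossing a process boundary
# (a fresh object() sentinel would arrive in the child as a different object).
STOP = None


def box_centers(boxes):
    """Return (x, y) centers for boxes stored as (ymin, xmin, ymax, xmax)."""
    return [((xmin + xmax) / 2, (ymin + ymax) / 2) for ymin, xmin, ymax, xmax in boxes]


def produce(q, frames):
    """Put the center of every box in every frame on the queue, then STOP."""
    for boxes in frames:
        for center in box_centers(boxes):
            q.put(center)
    q.put(STOP)  # tells the consumer that no more items will arrive


def consume(q, sink=print):
    """Pass items to `sink` until STOP arrives; return how many items were handled."""
    handled = 0
    while True:
        item = q.get()  # blocks until the producer has put something
        if item is STOP:
            return handled
        sink(item)
        handled += 1


def run_pipeline(frames):
    """Run the consumer in a child process and wait for it to exit on its own."""
    q = mp.Queue()
    reader = mp.Process(target=consume, args=(q,))
    reader.start()
    produce(q, frames)
    reader.join()  # no terminate() needed: the reader returns after seeing STOP
```

Because `produce` and `consume` only call `put()` and `get()`, they can also be exercised in a single process with `queue.Queue`. On platforms that start processes with `spawn` (Windows, and macOS by default), call `run_pipeline(...)` only from under an `if __name__ == "__main__":` guard; the same caveat applies to `UseSSD.py`, which creates its queue and processes at module level.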
-------------------------------------------------------------------------------- /SSD/caffe_to_tensorflow.py: -------------------------------------------------------------------------------- 1 | """Convert a Caffe model file to TensorFlow checkpoint format. 2 | 3 | Assume that the network built is equivalent to (or a sub-network of) the Caffe 4 | definition. 5 | """ 6 | import tensorflow as tf 7 | 8 | from nets import caffe_scope 9 | from nets import nets_factory 10 | 11 | slim = tf.contrib.slim 12 | 13 | # =========================================================================== # 14 | # Main flags. 15 | # =========================================================================== # 16 | tf.app.flags.DEFINE_string( 17 | 'model_name', 'ssd_300_vgg', 'Name of the model to convert.') 18 | tf.app.flags.DEFINE_integer( 19 | 'num_classes', 21, 'Number of classes in the dataset.') 20 | tf.app.flags.DEFINE_string( 21 | 'caffemodel_path', None, 22 | 'The path to the Caffe model file to convert.') 23 | 24 | FLAGS = tf.app.flags.FLAGS 25 | 26 | 27 | # =========================================================================== # 28 | # Main converting routine. 29 | # =========================================================================== # 30 | def main(_): 31 | # Caffe scope... 32 | caffemodel = caffe_scope.CaffeScope() 33 | caffemodel.load(FLAGS.caffemodel_path) 34 | 35 | tf.logging.set_verbosity(tf.logging.INFO) 36 | with tf.Graph().as_default(): 37 | global_step = slim.create_global_step() 38 | num_classes = int(FLAGS.num_classes) 39 | 40 | # Select the network. 41 | ssd_class = nets_factory.get_network(FLAGS.model_name) 42 | ssd_params = ssd_class.default_params._replace(num_classes=num_classes) 43 | ssd_net = ssd_class(ssd_params) 44 | ssd_shape = ssd_net.params.img_shape 45 | 46 | # Image placeholder and model. 47 | shape = (1, ssd_shape[0], ssd_shape[1], 3) 48 | img_input = tf.placeholder(shape=shape, dtype=tf.float32) 49 | # Create model.
50 | with slim.arg_scope(ssd_net.arg_scope_caffe(caffemodel)): 51 | ssd_net.net(img_input, is_training=False) 52 | 53 | init_op = tf.global_variables_initializer() 54 | with tf.Session() as session: 55 | # Run the init operation. 56 | session.run(init_op) 57 | 58 | # Save model in checkpoint. 59 | saver = tf.train.Saver() 60 | ckpt_path = FLAGS.caffemodel_path.replace('.caffemodel', '.ckpt') 61 | saver.save(session, ckpt_path, write_meta_graph=False) 62 | 63 | 64 | if __name__ == '__main__': 65 | tf.app.run() 66 | 67 | -------------------------------------------------------------------------------- /SSD/code_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from multiprocessing import Process, Queue 5 | 6 | 7 | # Code executed by the writer process: 8 | def proc_write(q, urls): 9 | print('Process(%s) is writing...' % os.getpid()) 10 | for url in urls: 11 | q.put(url) 12 | print('Put %s to queue...' % url) 13 | time.sleep(random.random()) 14 | 15 | 16 | # Code executed by the reader process: 17 | def proc_read(q): 18 | print('Process(%s) is reading...' % os.getpid()) 19 | while True: 20 | url = q.get(True) 21 | print('Get %s from queue.'
% url) 22 | 23 | 24 | if __name__ == '__main__': 25 | # The parent process creates the Queue and passes it to each child process: 26 | q = Queue() 27 | proc_writer1 = Process(target=proc_write, args=(q, ['url_1', 'url_2', 'url_3'])) 28 | proc_reader = Process(target=proc_read, args=(q,)) 29 | # Start the child process proc_writer to write: 30 | proc_writer1.start() 31 | # Start the child process proc_reader to read: 32 | proc_reader.start() 33 | # Wait for proc_writer to finish: 34 | proc_writer1.join() 35 | # proc_reader runs an infinite loop, so we cannot wait for it to finish and must terminate it forcibly: 36 | proc_reader.terminate() 37 | -------------------------------------------------------------------------------- /SSD/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SSD/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset.
16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/slim/data/create_cifar10_dataset.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | 27 | import tensorflow as tf 28 | 29 | from datasets import dataset_utils 30 | 31 | slim = tf.contrib.slim 32 | 33 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 34 | 35 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 36 | 37 | _NUM_CLASSES = 10 38 | 39 | _ITEMS_TO_DESCRIPTIONS = { 40 | 'image': 'A [32 x 32 x 3] color image.', 41 | 'label': 'A single integer between 0 and 9', 42 | } 43 | 44 | 45 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 46 | """Gets a dataset tuple with instructions for reading cifar10. 47 | 48 | Args: 49 | split_name: A train/test split name. 50 | dataset_dir: The base directory of the dataset sources. 51 | file_pattern: The file pattern to use when matching the dataset sources. 52 | It is assumed that the pattern contains a '%s' string so that the split 53 | name can be inserted. 54 | reader: The TensorFlow reader type. 55 | 56 | Returns: 57 | A `Dataset` namedtuple. 58 | 59 | Raises: 60 | ValueError: if `split_name` is not a valid train/test split. 61 | """ 62 | if split_name not in SPLITS_TO_SIZES: 63 | raise ValueError('split name %s was not recognized.' % split_name) 64 | 65 | if not file_pattern: 66 | file_pattern = _FILE_PATTERN 67 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 68 | 69 | # Allowing None in the signature so that dataset_factory can use the default. 
70 | if not reader: 71 | reader = tf.TFRecordReader 72 | 73 | keys_to_features = { 74 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 75 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 76 | 'image/class/label': tf.FixedLenFeature( 77 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 78 | } 79 | 80 | items_to_handlers = { 81 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 82 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 83 | } 84 | 85 | decoder = slim.tfexample_decoder.TFExampleDecoder( 86 | keys_to_features, items_to_handlers) 87 | 88 | labels_to_names = None 89 | if dataset_utils.has_labels(dataset_dir): 90 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 91 | 92 | return slim.dataset.Dataset( 93 | data_sources=file_pattern, 94 | reader=reader, 95 | decoder=decoder, 96 | num_samples=SPLITS_TO_SIZES[split_name], 97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 98 | num_classes=_NUM_CLASSES, 99 | labels_to_names=labels_to_names) 100 | -------------------------------------------------------------------------------- /SSD/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import imagenet 23 | from datasets import pascalvoc_2007 24 | from datasets import pascalvoc_2012 25 | 26 | datasets_map = { 27 | 'cifar10': cifar10, 28 | 'imagenet': imagenet, 29 | 'pascalvoc_2007': pascalvoc_2007, 30 | 'pascalvoc_2012': pascalvoc_2012, 31 | } 32 | 33 | 34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 35 | """Given a dataset name and a split_name returns a Dataset. 36 | 37 | Args: 38 | name: String, the name of the dataset. 39 | split_name: A train/test split name. 40 | dataset_dir: The directory where the dataset files are stored. 41 | file_pattern: The file pattern to use for matching the dataset source files. 42 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 43 | reader defined by each dataset is used. 44 | Returns: 45 | A `Dataset` class. 46 | Raises: 47 | ValueError: If the dataset `name` is unknown. 48 | """ 49 | if name not in datasets_map: 50 | raise ValueError('Name of dataset unknown %s' % name) 51 | return datasets_map[name].get_split(split_name, 52 | dataset_dir, 53 | file_pattern, 54 | reader) 55 | -------------------------------------------------------------------------------- /SSD/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | import tensorflow as tf 25 | from six.moves import urllib 26 | 27 | LABELS_FILENAME = 'labels.txt' 28 | 29 | 30 | def int64_feature(value): 31 | """Wrapper for inserting int64 features into Example proto. 32 | """ 33 | if not isinstance(value, list): 34 | value = [value] 35 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 36 | 37 | 38 | def float_feature(value): 39 | """Wrapper for inserting float features into Example proto. 40 | """ 41 | if not isinstance(value, list): 42 | value = [value] 43 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 44 | 45 | 46 | def bytes_feature(value): 47 | """Wrapper for inserting bytes features into Example proto. 
48 | """ 49 | if not isinstance(value, list): 50 | value = [value] 51 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 52 | 53 | 54 | def image_to_tfexample(image_data, image_format, height, width, class_id): 55 | return tf.train.Example(features=tf.train.Features(feature={ 56 | 'image/encoded': bytes_feature(image_data), 57 | 'image/format': bytes_feature(image_format), 58 | 'image/class/label': int64_feature(class_id), 59 | 'image/height': int64_feature(height), 60 | 'image/width': int64_feature(width), 61 | })) 62 | 63 | 64 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 65 | """Downloads the `tarball_url` and uncompresses it locally. 66 | 67 | Args: 68 | tarball_url: The URL of a tarball file. 69 | dataset_dir: The directory where the temporary files are stored. 70 | """ 71 | filename = tarball_url.split('/')[-1] 72 | filepath = os.path.join(dataset_dir, filename) 73 | 74 | def _progress(count, block_size, total_size): 75 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 76 | filename, float(count * block_size) / float(total_size) * 100.0)) 77 | sys.stdout.flush() 78 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 79 | print() 80 | statinfo = os.stat(filepath) 81 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 82 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 83 | 84 | 85 | def write_label_file(labels_to_class_names, dataset_dir, 86 | filename=LABELS_FILENAME): 87 | """Writes a file with the list of class names. 88 | 89 | Args: 90 | labels_to_class_names: A map of (integer) labels to class names. 91 | dataset_dir: The directory in which the labels file should be written. 92 | filename: The filename where the class names are written. 
93 | """ 94 | labels_filename = os.path.join(dataset_dir, filename) 95 | with tf.gfile.Open(labels_filename, 'w') as f: 96 | for label in labels_to_class_names: 97 | class_name = labels_to_class_names[label] 98 | f.write('%d:%s\n' % (label, class_name)) 99 | 100 | 101 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 102 | """Specifies whether or not the dataset directory contains a label map file. 103 | 104 | Args: 105 | dataset_dir: The directory in which the labels file is found. 106 | filename: The filename where the class names are written. 107 | 108 | Returns: 109 | `True` if the labels file exists and `False` otherwise. 110 | """ 111 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 112 | 113 | 114 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 115 | """Reads the labels file and returns a mapping from ID to class name. 116 | 117 | Args: 118 | dataset_dir: The directory in which the labels file is found. 119 | filename: The filename where the class names are written. 120 | 121 | Returns: 122 | A map from a label (integer) to class name. 123 | """ 124 | labels_filename = os.path.join(dataset_dir, filename) 125 | with tf.gfile.Open(labels_filename, 'rb') as f: 126 | lines = f.read() 127 | lines = lines.split(b'\n') 128 | lines = filter(None, lines) 129 | 130 | labels_to_class_names = {} 131 | for line in lines: 132 | index = line.index(b':') 133 | labels_to_class_names[int(line[:index])] = line[index+1:] 134 | return labels_to_class_names 135 | -------------------------------------------------------------------------------- /SSD/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes. 16 | 17 | Some images have one or more bounding boxes associated with the label of the 18 | image. See details here: http://image-net.org/download-bboxes 19 | 20 | ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use 21 | "WordNet ID" (wnid), which is a concatenation of POS (i.e. part of speech) 22 | and SYNSET OFFSET of WordNet. For more information, please refer to the 23 | WordNet documentation [http://wordnet.princeton.edu/wordnet/documentation/]. 24 | 25 | "There are bounding boxes for over 3000 popular synsets available. 26 | For each synset, there are on average 150 images with bounding boxes." 27 | 28 | WARNING: Don't use for object detection: all the bounding boxes of an image 29 | belong to just one class. 30 | """ 31 | from __future__ import absolute_import 32 | from __future__ import division 33 | from __future__ import print_function 34 | 35 | import os 36 | 37 | import tensorflow as tf 38 | from six.moves import urllib 39 | 40 | from datasets import dataset_utils 41 | 42 | slim = tf.contrib.slim 43 | 44 | # TODO(nsilberman): Add tfrecord file type once the script is updated.
45 | _FILE_PATTERN = '%s-*' 46 | 47 | _SPLITS_TO_SIZES = { 48 | 'train': 1281167, 49 | 'validation': 50000, 50 | } 51 | 52 | _ITEMS_TO_DESCRIPTIONS = { 53 | 'image': 'A color image of varying height and width.', 54 | 'label': 'The label id of the image, integer between 0 and 999', 55 | 'label_text': 'The text of the label.', 56 | 'object/bbox': 'A list of bounding boxes.', 57 | 'object/label': 'A list of labels, one per each object.', 58 | } 59 | 60 | _NUM_CLASSES = 1001 61 | 62 | 63 | def create_readable_names_for_imagenet_labels(): 64 | """Create a dict mapping label id to human readable string. 65 | 66 | Returns: 67 | labels_to_names: dictionary where keys are integers from 0 to 1000 68 | and values are human-readable names. 69 | 70 | We retrieve a synset file, which contains a list of valid synset labels used 71 | by the ILSVRC competition. There is one synset per line, e.g. 72 | # n01440764 73 | # n01443537 74 | We also retrieve a synset_to_human_file, which contains a mapping from synsets 75 | to human-readable names for every synset in Imagenet. These are stored in a 76 | tsv format, as follows: 77 | # n02119247 black fox 78 | # n02119359 silver fox 79 | We assign each synset (in alphabetical order) an integer, starting from 1 80 | (since 0 is reserved for the background class).
81 | 82 | Code is based on 83 | https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463 84 | """ 85 | 86 | # pylint: disable=g-line-too-long 87 | base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data/' 88 | synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url) 89 | synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url) 90 | 91 | filename, _ = urllib.request.urlretrieve(synset_url) 92 | synset_list = [s.strip() for s in open(filename).readlines()] 93 | num_synsets_in_ilsvrc = len(synset_list) 94 | assert num_synsets_in_ilsvrc == 1000 95 | 96 | filename, _ = urllib.request.urlretrieve(synset_to_human_url) 97 | synset_to_human_list = open(filename).readlines() 98 | num_synsets_in_all_imagenet = len(synset_to_human_list) 99 | assert num_synsets_in_all_imagenet == 21842 100 | 101 | synset_to_human = {} 102 | for s in synset_to_human_list: 103 | parts = s.strip().split('\t') 104 | assert len(parts) == 2 105 | synset = parts[0] 106 | human = parts[1] 107 | synset_to_human[synset] = human 108 | 109 | label_index = 1 110 | labels_to_names = {0: 'background'} 111 | for synset in synset_list: 112 | name = synset_to_human[synset] 113 | labels_to_names[label_index] = name 114 | label_index += 1 115 | 116 | return labels_to_names 117 | 118 | 119 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 120 | """Gets a dataset tuple with instructions for reading ImageNet. 121 | 122 | Args: 123 | split_name: A train/test split name. 124 | dataset_dir: The base directory of the dataset sources. 125 | file_pattern: The file pattern to use when matching the dataset sources. 126 | It is assumed that the pattern contains a '%s' string so that the split 127 | name can be inserted. 128 | reader: The TensorFlow reader type. 129 | 130 | Returns: 131 | A `Dataset` namedtuple. 132 | 133 | Raises: 134 | ValueError: if `split_name` is not a valid train/test split. 
135 | """ 136 | if split_name not in _SPLITS_TO_SIZES: 137 | raise ValueError('split name %s was not recognized.' % split_name) 138 | 139 | if not file_pattern: 140 | file_pattern = _FILE_PATTERN 141 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 142 | 143 | # Allowing None in the signature so that dataset_factory can use the default. 144 | if reader is None: 145 | reader = tf.TFRecordReader 146 | 147 | keys_to_features = { 148 | 'image/encoded': tf.FixedLenFeature( 149 | (), tf.string, default_value=''), 150 | 'image/format': tf.FixedLenFeature( 151 | (), tf.string, default_value='jpeg'), 152 | 'image/class/label': tf.FixedLenFeature( 153 | [], dtype=tf.int64, default_value=-1), 154 | 'image/class/text': tf.FixedLenFeature( 155 | [], dtype=tf.string, default_value=''), 156 | 'image/object/bbox/xmin': tf.VarLenFeature( 157 | dtype=tf.float32), 158 | 'image/object/bbox/ymin': tf.VarLenFeature( 159 | dtype=tf.float32), 160 | 'image/object/bbox/xmax': tf.VarLenFeature( 161 | dtype=tf.float32), 162 | 'image/object/bbox/ymax': tf.VarLenFeature( 163 | dtype=tf.float32), 164 | 'image/object/class/label': tf.VarLenFeature( 165 | dtype=tf.int64), 166 | } 167 | 168 | items_to_handlers = { 169 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 170 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 171 | 'label_text': slim.tfexample_decoder.Tensor('image/class/text'), 172 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 173 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 174 | 'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'), 175 | } 176 | 177 | decoder = slim.tfexample_decoder.TFExampleDecoder( 178 | keys_to_features, items_to_handlers) 179 | 180 | labels_to_names = None 181 | if dataset_utils.has_labels(dataset_dir): 182 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 183 | else: 184 | labels_to_names = create_readable_names_for_imagenet_labels() 185 | 
dataset_utils.write_label_file(labels_to_names, dataset_dir) 186 | 187 | return slim.dataset.Dataset( 188 | data_sources=file_pattern, 189 | reader=reader, 190 | decoder=decoder, 191 | num_samples=_SPLITS_TO_SIZES[split_name], 192 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 193 | num_classes=_NUM_CLASSES, 194 | labels_to_names=labels_to_names) 195 | -------------------------------------------------------------------------------- /SSD/datasets/pascalvoc_2007.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Pascal VOC Dataset (images + annotations). 16 | """ 17 | import tensorflow as tf 18 | 19 | from datasets import pascalvoc_common 20 | 21 | slim = tf.contrib.slim 22 | 23 | FILE_PATTERN = 'voc_2007_%s_*.tfrecord' 24 | ITEMS_TO_DESCRIPTIONS = { 25 | 'image': 'A color image of varying height and width.', 26 | 'shape': 'Shape of the image', 27 | 'object/bbox': 'A list of bounding boxes, one per each object.', 28 | 'object/label': 'A list of labels, one per each object.', 29 | } 30 | # (Images, Objects) statistics on every class. 
31 | TRAIN_STATISTICS = { 32 | 'none': (0, 0), 33 | 'aeroplane': (238, 306), 34 | 'bicycle': (243, 353), 35 | 'bird': (330, 486), 36 | 'boat': (181, 290), 37 | 'bottle': (244, 505), 38 | 'bus': (186, 229), 39 | 'car': (713, 1250), 40 | 'cat': (337, 376), 41 | 'chair': (445, 798), 42 | 'cow': (141, 259), 43 | 'diningtable': (200, 215), 44 | 'dog': (421, 510), 45 | 'horse': (287, 362), 46 | 'motorbike': (245, 339), 47 | 'person': (2008, 4690), 48 | 'pottedplant': (245, 514), 49 | 'sheep': (96, 257), 50 | 'sofa': (229, 248), 51 | 'train': (261, 297), 52 | 'tvmonitor': (256, 324), 53 | 'total': (5011, 12608), 54 | } 55 | TEST_STATISTICS = { 56 | 'none': (0, 0), 57 | 'aeroplane': (1, 1), 58 | 'bicycle': (1, 1), 59 | 'bird': (1, 1), 60 | 'boat': (1, 1), 61 | 'bottle': (1, 1), 62 | 'bus': (1, 1), 63 | 'car': (1, 1), 64 | 'cat': (1, 1), 65 | 'chair': (1, 1), 66 | 'cow': (1, 1), 67 | 'diningtable': (1, 1), 68 | 'dog': (1, 1), 69 | 'horse': (1, 1), 70 | 'motorbike': (1, 1), 71 | 'person': (1, 1), 72 | 'pottedplant': (1, 1), 73 | 'sheep': (1, 1), 74 | 'sofa': (1, 1), 75 | 'train': (1, 1), 76 | 'tvmonitor': (1, 1), 77 | 'total': (20, 20), 78 | } 79 | SPLITS_TO_SIZES = { 80 | 'train': 5011, 81 | 'test': 4952, 82 | } 83 | SPLITS_TO_STATISTICS = { 84 | 'train': TRAIN_STATISTICS, 85 | 'test': TEST_STATISTICS, 86 | } 87 | NUM_CLASSES = 20 88 | 89 | 90 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 91 | """Gets a dataset tuple with instructions for reading Pascal VOC 2007. 92 | 93 | Args: 94 | split_name: A train/test split name. 95 | dataset_dir: The base directory of the dataset sources. 96 | file_pattern: The file pattern to use when matching the dataset sources. 97 | It is assumed that the pattern contains a '%s' string so that the split 98 | name can be inserted. 99 | reader: The TensorFlow reader type. 100 | 101 | Returns: 102 | A `Dataset` namedtuple. 103 | 104 | Raises: 105 | ValueError: if `split_name` is not a valid train/test split.
106 | """ 107 | if not file_pattern: 108 | file_pattern = FILE_PATTERN 109 | return pascalvoc_common.get_split(split_name, dataset_dir, 110 | file_pattern, reader, 111 | SPLITS_TO_SIZES, 112 | ITEMS_TO_DESCRIPTIONS, 113 | NUM_CLASSES) 114 | -------------------------------------------------------------------------------- /SSD/datasets/pascalvoc_2012.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Pascal VOC Dataset (images + annotations). 16 | """ 17 | import tensorflow as tf 18 | 19 | from datasets import pascalvoc_common 20 | 21 | slim = tf.contrib.slim 22 | 23 | FILE_PATTERN = 'voc_2012_%s_*.tfrecord' 24 | ITEMS_TO_DESCRIPTIONS = { 25 | 'image': 'A color image of varying height and width.', 26 | 'shape': 'Shape of the image', 27 | 'object/bbox': 'A list of bounding boxes, one per each object.', 28 | 'object/label': 'A list of labels, one per each object.', 29 | } 30 | # (Images, Objects) statistics on every class. 
31 | TRAIN_STATISTICS = { 32 | 'none': (0, 0), 33 | 'aeroplane': (670, 865), 34 | 'bicycle': (552, 711), 35 | 'bird': (765, 1119), 36 | 'boat': (508, 850), 37 | 'bottle': (706, 1259), 38 | 'bus': (421, 593), 39 | 'car': (1161, 2017), 40 | 'cat': (1080, 1217), 41 | 'chair': (1119, 2354), 42 | 'cow': (303, 588), 43 | 'diningtable': (538, 609), 44 | 'dog': (1286, 1515), 45 | 'horse': (482, 710), 46 | 'motorbike': (526, 713), 47 | 'person': (4087, 8566), 48 | 'pottedplant': (527, 973), 49 | 'sheep': (325, 813), 50 | 'sofa': (507, 566), 51 | 'train': (544, 628), 52 | 'tvmonitor': (575, 784), 53 | 'total': (11540, 27450), 54 | } 55 | SPLITS_TO_SIZES = { 56 | 'train': 17125, 57 | } 58 | SPLITS_TO_STATISTICS = { 59 | 'train': TRAIN_STATISTICS, 60 | } 61 | NUM_CLASSES = 20 62 | 63 | 64 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 65 | """Gets a dataset tuple with instructions for reading Pascal VOC 2012. 66 | 67 | Args: 68 | split_name: A train/test split name. 69 | dataset_dir: The base directory of the dataset sources. 70 | file_pattern: The file pattern to use when matching the dataset sources. 71 | It is assumed that the pattern contains a '%s' string so that the split 72 | name can be inserted. 73 | reader: The TensorFlow reader type. 74 | 75 | Returns: 76 | A `Dataset` namedtuple. 77 | 78 | Raises: 79 | ValueError: if `split_name` is not a valid train/test split.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Pascal VOC Dataset (images + annotations). 16 | """ 17 | import os 18 | 19 | import tensorflow as tf 20 | 21 | from datasets import dataset_utils 22 | 23 | slim = tf.contrib.slim 24 | 25 | VOC_LABELS = { 26 | 'none': (0, 'Background'), 27 | 'aeroplane': (1, 'Vehicle'), 28 | 'bicycle': (2, 'Vehicle'), 29 | 'bird': (3, 'Animal'), 30 | 'boat': (4, 'Vehicle'), 31 | 'bottle': (5, 'Indoor'), 32 | 'bus': (6, 'Vehicle'), 33 | 'car': (7, 'Vehicle'), 34 | 'cat': (8, 'Animal'), 35 | 'chair': (9, 'Indoor'), 36 | 'cow': (10, 'Animal'), 37 | 'diningtable': (11, 'Indoor'), 38 | 'dog': (12, 'Animal'), 39 | 'horse': (13, 'Animal'), 40 | 'motorbike': (14, 'Vehicle'), 41 | 'person': (15, 'Person'), 42 | 'pottedplant': (16, 'Indoor'), 43 | 'sheep': (17, 'Animal'), 44 | 'sofa': (18, 'Indoor'), 45 | 'train': (19, 'Vehicle'), 46 | 'tvmonitor': (20, 'Indoor'), 47 | } 48 | 49 | 50 | def get_split(split_name, dataset_dir, file_pattern, reader, 51 | split_to_sizes, items_to_descriptions, num_classes): 52 | """Gets a dataset tuple with instructions for reading Pascal VOC dataset. 53 | 54 | Args: 55 | split_name: A train/test split name. 56 | dataset_dir: The base directory of the dataset sources. 57 | file_pattern: The file pattern to use when matching the dataset sources. 
58 | It is assumed that the pattern contains a '%s' string so that the split 59 | name can be inserted. 60 | reader: The TensorFlow reader type. 61 | 62 | Returns: 63 | A `Dataset` namedtuple. 64 | 65 | Raises: 66 | ValueError: if `split_name` is not a valid train/test split. 67 | """ 68 | if split_name not in split_to_sizes: 69 | raise ValueError('split name %s was not recognized.' % split_name) 70 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 71 | 72 | # Allowing None in the signature so that dataset_factory can use the default. 73 | if reader is None: 74 | reader = tf.TFRecordReader 75 | # Features in Pascal VOC TFRecords. 76 | keys_to_features = { 77 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 78 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 79 | 'image/height': tf.FixedLenFeature([1], tf.int64), 80 | 'image/width': tf.FixedLenFeature([1], tf.int64), 81 | 'image/channels': tf.FixedLenFeature([1], tf.int64), 82 | 'image/shape': tf.FixedLenFeature([3], tf.int64), 83 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 84 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 85 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 86 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 87 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64), 88 | 'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64), 89 | 'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64), 90 | } 91 | items_to_handlers = { 92 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 93 | 'shape': slim.tfexample_decoder.Tensor('image/shape'), 94 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 95 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 96 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'), 97 | 'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'), 98 | 
'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'), 99 | } 100 | decoder = slim.tfexample_decoder.TFExampleDecoder( 101 | keys_to_features, items_to_handlers) 102 | 103 | labels_to_names = None 104 | if dataset_utils.has_labels(dataset_dir): 105 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 106 | # else: 107 | # labels_to_names = create_readable_names_for_imagenet_labels() 108 | # dataset_utils.write_label_file(labels_to_names, dataset_dir) 109 | 110 | return slim.dataset.Dataset( 111 | data_sources=file_pattern, 112 | reader=reader, 113 | decoder=decoder, 114 | num_samples=split_to_sizes[split_name], 115 | items_to_descriptions=items_to_descriptions, 116 | num_classes=num_classes, 117 | labels_to_names=labels_to_names) 118 | -------------------------------------------------------------------------------- /SSD/datasets/pascalvoc_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Converts Pascal VOC data to TFRecords file format with Example protos. 16 | 17 | The raw Pascal VOC data set is expected to reside in JPEG files located in the 18 | directory 'JPEGImages'. 
Similarly, bounding box annotations are expected to be 19 | stored in the 'Annotations' directory. 20 | 21 | This TensorFlow script converts every annotated image of the dataset into 22 | a sharded data set of TFRecord files named '<name>_000.tfrecord', etc. 23 | 24 | Each TFRecord file contains at most SAMPLES_PER_FILES (200) records; the 25 | last shard may hold fewer. Each record within the TFRecord file is a 26 | serialized Example proto. The Example proto contains the following fields: 27 | 28 | image/encoded: string containing JPEG encoded image in RGB colorspace 29 | image/height: integer, image height in pixels 30 | image/width: integer, image width in pixels 31 | image/channels: integer, specifying the number of channels, always 3 32 | image/format: string, specifying the format, always 'JPEG' 33 | image/shape: list of 3 integers, [height, width, channels] 34 | 35 | image/object/bbox/xmin: list of normalized floats in [0, 1], one per human 36 | annotated bounding box 37 | image/object/bbox/xmax: list of normalized floats in [0, 1], one per human 38 | annotated bounding box 39 | image/object/bbox/ymin: list of normalized floats in [0, 1], one per human 40 | annotated bounding box 41 | image/object/bbox/ymax: list of normalized floats in [0, 1], one per human 42 | annotated bounding box 43 | image/object/bbox/label: list of integers specifying the class index. 44 | image/object/bbox/label_text: list of string descriptions. 45 | image/object/bbox/difficult, image/object/bbox/truncated: lists of 0/1 flags. 46 | Note that the length of xmin is identical to the length of xmax, ymin and ymax 47 | for each example. 48 | """ 49 | import os 50 | import random 51 | import sys 52 | import xml.etree.ElementTree as ET 53 | 54 | import tensorflow as tf 55 | 56 | from datasets.dataset_utils import int64_feature, float_feature, bytes_feature 57 | from datasets.pascalvoc_common import VOC_LABELS 58 | 59 | # Original dataset organisation. 60 | DIRECTORY_ANNOTATIONS = 'Annotations/' 61 | DIRECTORY_IMAGES = 'JPEGImages/' 62 | 63 | # TFRecords conversion parameters.
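`_process_image` reads each `Annotations/<name>.xml` file, divides the pixel coordinates by the image height and width, and stores each box as `(ymin, xmin, ymax, xmax)`. The following is a stdlib-only sketch of that parsing step. It assumes a hypothetical inline annotation (`SAMPLE_XML`) and a hypothetical helper name (`parse_voc_annotation`), and it mirrors the logic without being the repository's code:

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation following the Pascal VOC XML layout.
SAMPLE_XML = """<annotation>
  <size><width>500</width><height>375</height><depth>3</depth></size>
  <object>
    <name>dog</name><difficult>1</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>person</name>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>370</ymax></bndbox>
  </object>
</annotation>"""


def parse_voc_annotation(xml_text):
    """Return ((height, width), objects) with normalized boxes (sketch)."""
    root = ET.fromstring(xml_text)
    size = root.find('size')
    height = int(size.find('height').text)
    width = int(size.find('width').text)
    objects = []
    for obj in root.findall('object'):
        difficult = obj.find('difficult')
        bndbox = obj.find('bndbox')
        objects.append({
            'name': obj.find('name').text,
            # Compare with None: a leaf Element has no children, so its
            # truth value is False even when the tag is present.
            'difficult': int(difficult.text) if difficult is not None else 0,
            # Normalized (ymin, xmin, ymax, xmax), the order stored in the TFRecords.
            'bbox': (float(bndbox.find('ymin').text) / height,
                     float(bndbox.find('xmin').text) / width,
                     float(bndbox.find('ymax').text) / height,
                     float(bndbox.find('xmax').text) / width),
        })
    return (height, width), objects
```

The `is not None` test matters here: writing `if obj.find('difficult'):` would silently report every object as not difficult, because an element without children evaluates as false.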
64 | RANDOM_SEED = 4242 65 | SAMPLES_PER_FILES = 200 66 | 67 | 68 | def _process_image(directory, name): 69 | """Process an image and its XML annotation file. 70 | 71 | Args: 72 | directory: string, root directory of the Pascal VOC dataset. 73 | name: string, image basename without extension, e.g. '000005'. 74 | Returns: 75 | image_data: bytes, raw JPEG-encoded image. 76 | shape: list of 3 integers, [height, width, depth] in pixels. 77 | bboxes, labels, labels_text, difficult, truncated: per-object lists. 78 | """ 79 | # Read the image file (binary mode: the JPEG bytes are stored as-is). 80 | filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg') 81 | image_data = tf.gfile.FastGFile(filename, 'rb').read() 82 | 83 | # Read the XML annotation file. 84 | filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml') 85 | tree = ET.parse(filename) 86 | root = tree.getroot() 87 | 88 | # Image shape. 89 | size = root.find('size') 90 | shape = [int(size.find('height').text), 91 | int(size.find('width').text), 92 | int(size.find('depth').text)] 93 | # Find annotations.
94 | bboxes = [] 95 | labels = [] 96 | labels_text = [] 97 | difficult = [] 98 | truncated = [] 99 | for obj in root.findall('object'): 100 | label = obj.find('name').text 101 | labels.append(int(VOC_LABELS[label][0])) 102 | labels_text.append(label.encode('ascii')) 103 | 104 | if obj.find('difficult') is not None: 105 | difficult.append(int(obj.find('difficult').text)) 106 | else: 107 | difficult.append(0) 108 | if obj.find('truncated') is not None: 109 | truncated.append(int(obj.find('truncated').text)) 110 | else: 111 | truncated.append(0) 112 | 113 | bbox = obj.find('bndbox') 114 | bboxes.append((float(bbox.find('ymin').text) / shape[0], 115 | float(bbox.find('xmin').text) / shape[1], 116 | float(bbox.find('ymax').text) / shape[0], 117 | float(bbox.find('xmax').text) / shape[1] 118 | )) 119 | return image_data, shape, bboxes, labels, labels_text, difficult, truncated 120 | 121 | 122 | def _convert_to_example(image_data, labels, labels_text, bboxes, shape, 123 | difficult, truncated): 124 | """Build an Example proto for an image example. 125 | 126 | Args: 127 | image_data: bytes, JPEG encoding of RGB image; 128 | labels: list of integers, one class index per box; 129 | labels_text: list of bytes, human-readable labels; 130 | bboxes: list of bounding boxes; each box is a tuple of normalized floats 131 | [ymin, xmin, ymax, xmax]; box i has label labels[i]. 132 | difficult, truncated: lists of 0/1 integer flags, one per box; 133 | shape: 3 integers, image shape [height, width, channels] in pixels.
134 | Returns: 135 | Example proto 136 | """ 137 | xmin = [] 138 | ymin = [] 139 | xmax = [] 140 | ymax = [] 141 | for b in bboxes: 142 | assert len(b) == 4 143 | # Each box is (ymin, xmin, ymax, xmax); split it into per-coordinate lists. 144 | for coord_list, point in zip([ymin, xmin, ymax, xmax], b): 145 | coord_list.append(point) 146 | 147 | image_format = b'JPEG' 148 | example = tf.train.Example(features=tf.train.Features(feature={ 149 | 'image/height': int64_feature(shape[0]), 150 | 'image/width': int64_feature(shape[1]), 151 | 'image/channels': int64_feature(shape[2]), 152 | 'image/shape': int64_feature(shape), 153 | 'image/object/bbox/xmin': float_feature(xmin), 154 | 'image/object/bbox/xmax': float_feature(xmax), 155 | 'image/object/bbox/ymin': float_feature(ymin), 156 | 'image/object/bbox/ymax': float_feature(ymax), 157 | 'image/object/bbox/label': int64_feature(labels), 158 | 'image/object/bbox/label_text': bytes_feature(labels_text), 159 | 'image/object/bbox/difficult': int64_feature(difficult), 160 | 'image/object/bbox/truncated': int64_feature(truncated), 161 | 'image/format': bytes_feature(image_format), 162 | 'image/encoded': bytes_feature(image_data)})) 163 | return example 164 | 165 | 166 | def _add_to_tfrecord(dataset_dir, name, tfrecord_writer): 167 | """Loads data from image and annotation files and adds them to a TFRecord. 168 | 169 | Args: 170 | dataset_dir: Dataset directory; 171 | name: Image name to add to the TFRecord; 172 | tfrecord_writer: The TFRecord writer to use for writing.
173 | """ 174 | image_data, shape, bboxes, labels, labels_text, difficult, truncated = \ 175 | _process_image(dataset_dir, name) 176 | example = _convert_to_example(image_data, labels, labels_text, 177 | bboxes, shape, difficult, truncated) 178 | tfrecord_writer.write(example.SerializeToString()) 179 | 180 | 181 | def _get_output_filename(output_dir, name, idx): 182 | return '%s/%s_%03d.tfrecord' % (output_dir, name, idx) 183 | 184 | 185 | def run(dataset_dir, output_dir, name='voc_train', shuffling=False): 186 | """Converts the dataset into sharded '<name>_NNN.tfrecord' files. 187 | 188 | Args: 189 | dataset_dir: The dataset directory where the dataset is stored. 190 | output_dir: Output directory, created if it does not exist. 191 | """ 192 | if not tf.gfile.Exists(output_dir): 193 | tf.gfile.MakeDirs(output_dir) 194 | 195 | # Dataset filenames, and shuffling. 196 | path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS) 197 | filenames = sorted(os.listdir(path)) 198 | if shuffling: 199 | random.seed(RANDOM_SEED) 200 | random.shuffle(filenames) 201 | 202 | # Process dataset files. 203 | i = 0 204 | fidx = 0 205 | while i < len(filenames): 206 | # Open new TFRecord file.
207 | tf_filename = _get_output_filename(output_dir, name, fidx) 208 | with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer: 209 | j = 0 210 | while i < len(filenames) and j < SAMPLES_PER_FILES: 211 | sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames))) 212 | sys.stdout.flush() 213 | 214 | filename = filenames[i] 215 | img_name = filename[:-4] 216 | _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer) 217 | i += 1 218 | j += 1 219 | fidx += 1 220 | 221 | # Finally, write the labels file: 222 | # labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES)) 223 | # dataset_utils.write_label_file(labels_to_class_names, dataset_dir) 224 | print('\nFinished converting the Pascal VOC dataset!') 225 | -------------------------------------------------------------------------------- /SSD/inspect_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A simple script for inspecting checkpoint files.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import argparse 21 | import sys 22 | 23 | import numpy as np 24 | from tensorflow.python import pywrap_tensorflow 25 | from tensorflow.python.platform import app 26 | from tensorflow.python.platform import flags 27 | 28 | FLAGS = None 29 | 30 | 31 | def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): 32 | """Prints tensors in a checkpoint file. 33 | 34 | If no `tensor_name` is provided, prints the tensor names and shapes 35 | in the checkpoint file. 36 | 37 | If `tensor_name` is provided, prints the content of the tensor. 38 | 39 | Args: 40 | file_name: Name of the checkpoint file. 41 | tensor_name: Name of the tensor in the checkpoint file to print. 42 | all_tensors: Boolean indicating whether to print all tensors. 43 | """ 44 | try: 45 | reader = pywrap_tensorflow.NewCheckpointReader(file_name) 46 | if all_tensors: 47 | var_to_shape_map = reader.get_variable_to_shape_map() 48 | for key in var_to_shape_map: 49 | print("tensor_name: ", key) 50 | print(reader.get_tensor(key)) 51 | elif not tensor_name: 52 | print(reader.debug_string().decode("utf-8")) 53 | else: 54 | print("tensor_name: ", tensor_name) 55 | print(reader.get_tensor(tensor_name)) 56 | except Exception as e: # pylint: disable=broad-except 57 | print(str(e)) 58 | if "corrupted compressed block contents" in str(e): 59 | print("It's likely that your checkpoint file has been compressed " 60 | "with SNAPPY.") 61 | 62 | 63 | def parse_numpy_printoption(kv_str): 64 | """Sets a single numpy printoption from a string of the form 'x=y'. 65 | 66 | See documentation on numpy.set_printoptions() for details about what values 67 | x and y can take. x can be any option listed there other than 'formatter'.
68 | 69 | Args: 70 | kv_str: A string of the form 'x=y', such as 'threshold=100000' 71 | 72 | Raises: 73 | argparse.ArgumentTypeError: If the string couldn't be used to set any 74 | numpy printoption. 75 | """ 76 | k_v_str = kv_str.split("=", 1) 77 | if len(k_v_str) != 2 or not k_v_str[0]: 78 | raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str) 79 | k, v_str = k_v_str 80 | printoptions = np.get_printoptions() 81 | if k not in printoptions: 82 | raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k) 83 | v_type = type(printoptions[k]) 84 | if v_type is type(None): 85 | raise argparse.ArgumentTypeError( 86 | "Setting '%s' from the command line is not supported." % k) 87 | try: 88 | v = (v_type(v_str) if v_type is not bool 89 | else flags.BooleanParser().Parse(v_str)) 90 | except ValueError as e: 91 | raise argparse.ArgumentTypeError(str(e)) 92 | np.set_printoptions(**{k: v}) 93 | 94 | 95 | def main(unused_argv): 96 | if not FLAGS.file_name: 97 | print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " 98 | "[--tensor_name=tensor_to_print]") 99 | sys.exit(1) 100 | else: 101 | print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name, 102 | FLAGS.all_tensors) 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser() 107 | parser.register("type", "bool", lambda v: v.lower() == "true") 108 | parser.add_argument( 109 | "--file_name", type=str, default="", help="Checkpoint filename.
" 110 | "Note, if using Checkpoint V2 format, file_name is the " 111 | "shared prefix between all files in the checkpoint.") 112 | parser.add_argument( 113 | "--tensor_name", 114 | type=str, 115 | default="", 116 | help="Name of the tensor to inspect") 117 | parser.add_argument( 118 | "--all_tensors", 119 | nargs="?", 120 | const=True, 121 | type="bool", 122 | default=False, 123 | help="If True, print the values of all the tensors.") 124 | parser.add_argument( 125 | "--printoptions", 126 | nargs="*", 127 | type=parse_numpy_printoption, 128 | help="Argument for numpy.set_printoptions(), in the form 'k=v'.") 129 | FLAGS, unparsed = parser.parse_known_args() 130 | app.run(main=main, argv=[sys.argv[0]] + unparsed) 131 | -------------------------------------------------------------------------------- /SSD/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SSD/nets/caffe_scope.py: -------------------------------------------------------------------------------- 1 | """Specific Caffe scope used to import weights from a .caffemodel file. 2 | 3 | The idea is to create special initializers loading weights from protobuf 4 | .caffemodel files. 5 | """ 6 | import numpy as np 7 | import tensorflow as tf 8 | from caffe.proto import caffe_pb2 9 | 10 | slim = tf.contrib.slim 11 | 12 | 13 | class CaffeScope(object): 14 | """Caffe scope. 15 | """ 16 | def __init__(self): 17 | """Initialize the Caffe scope. 18 | """ 19 | self.counters = {} 20 | self.layers = {} 21 | self.caffe_layers = None 22 | self.bgr_to_rgb = 0 23 | 24 | def load(self, filename, bgr_to_rgb=True): 25 | """Load weights from a .caffemodel file and initialize counters. 26 | 27 | Params: 28 | filename: caffemodel file.
29 | """ 30 | print('Loading Caffe file:', filename) 31 | caffemodel_params = caffe_pb2.NetParameter() 32 | caffemodel_str = open(filename, 'rb').read() 33 | caffemodel_params.ParseFromString(caffemodel_str) 34 | self.caffe_layers = caffemodel_params.layer 35 | 36 | # Layers collection. 37 | self.layers['convolution'] = [i for i, l in enumerate(self.caffe_layers) 38 | if l.type == 'Convolution'] 39 | self.layers['l2_normalization'] = [i for i, l in enumerate(self.caffe_layers) 40 | if l.type == 'Normalize'] 41 | # BGR to RGB conversion: conv_weights_init swaps the input-channel order 42 | # of the first convolution with 3 input channels. 43 | if bgr_to_rgb: 44 | self.bgr_to_rgb = 1 45 | 46 | def conv_weights_init(self): 47 | def _initializer(shape, dtype, partition_info=None): 48 | counter = self.counters.get(self.conv_weights_init, 0) 49 | idx = self.layers['convolution'][counter] 50 | layer = self.caffe_layers[idx] 51 | # Weights: reshape, then transpose Caffe (out, in, h, w) to TF (h, w, in, out). 52 | w = np.array(layer.blobs[0].data) 53 | w = np.reshape(w, layer.blobs[0].shape.dim) 54 | # w = np.transpose(w, (1, 0, 2, 3)) 55 | w = np.transpose(w, (2, 3, 1, 0)) 56 | if self.bgr_to_rgb == 1 and w.shape[2] == 3: 57 | print('Convert BGR to RGB in convolution layer:', layer.name) 58 | w[:, :, (0, 1, 2)] = w[:, :, (2, 1, 0)] 59 | self.bgr_to_rgb += 1 60 | self.counters[self.conv_weights_init] = counter + 1 61 | print('Load weights from convolution layer:', layer.name, w.shape) 62 | return tf.cast(w, dtype) 63 | return _initializer 64 | 65 | def conv_biases_init(self): 66 | def _initializer(shape, dtype, partition_info=None): 67 | counter = self.counters.get(self.conv_biases_init, 0) 68 | idx = self.layers['convolution'][counter] 69 | layer = self.caffe_layers[idx] 70 | # Biases data...
71 | b = np.array(layer.blobs[1].data) 72 | self.counters[self.conv_biases_init] = counter + 1 73 | print('Load biases from convolution layer:', layer.name, b.shape) 74 | return tf.cast(b, dtype) 75 | return _initializer 76 | 77 | def l2_norm_scale_init(self): 78 | def _initializer(shape, dtype, partition_info=None): 79 | counter = self.counters.get(self.l2_norm_scale_init, 0) 80 | idx = self.layers['l2_normalization'][counter] 81 | layer = self.caffe_layers[idx] 82 | # Scaling parameter. 83 | s = np.array(layer.blobs[0].data) 84 | s = np.reshape(s, layer.blobs[0].shape.dim) 85 | self.counters[self.l2_norm_scale_init] = counter + 1 86 | print('Load scaling from L2 normalization layer:', layer.name, s.shape) 87 | return tf.cast(s, dtype) 88 | return _initializer 89 | -------------------------------------------------------------------------------- /SSD/nets/custom_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implement some custom layers, not provided by TensorFlow. 
16 | 17 | Tries to follow the style and standards used in tf.contrib.layers as 18 | closely as possible. 19 | """ 20 | import tensorflow as tf 21 | from tensorflow.contrib.framework.python.ops import add_arg_scope 22 | from tensorflow.contrib.framework.python.ops import variables 23 | from tensorflow.contrib.layers.python.layers import utils 24 | from tensorflow.python.ops import init_ops 25 | from tensorflow.python.ops import nn 26 | from tensorflow.python.ops import variable_scope 27 | 28 | 29 | def abs_smooth(x): 30 | """Smoothed absolute function. Useful to compute a smooth L1 error. 31 | 32 | Defined as: 33 | x^2 / 2 if abs(x) < 1 34 | abs(x) - 0.5 if abs(x) >= 1 35 | Computed here as 0.5 * ((abs(x) - 1) * min(abs(x), 1) + abs(x)), which 36 | matches both branches without an explicit conditional. 37 | """ 38 | absx = tf.abs(x) 39 | minx = tf.minimum(absx, 1) 40 | r = 0.5 * ((absx - 1) * minx + absx) 41 | return r 42 | 43 | 44 | @add_arg_scope 45 | def l2_normalization( 46 | inputs, 47 | scaling=False, 48 | scale_initializer=init_ops.ones_initializer(), 49 | reuse=None, 50 | variables_collections=None, 51 | outputs_collections=None, 52 | data_format='NHWC', 53 | trainable=True, 54 | scope=None): 55 | """Implement L2 normalization across channels at every spatial position. 56 | 57 | Could be extended in the future to other dimensions, providing a more 58 | flexible normalization framework. 59 | 60 | Args: 61 | inputs: a 4-D tensor with dimensions [batch_size, height, width, channels]. 62 | scaling: whether or not to add a post scaling operation along the dimensions 63 | which have been normalized. 64 | scale_initializer: An initializer for the scaling weights. 65 | reuse: whether or not the layer and its variables should be reused. To be 66 | able to reuse the layer, `scope` must be given. 67 | variables_collections: optional list of collections for all the variables or 68 | a dictionary containing a different list of collection per variable.
69 | outputs_collections: collection to add the outputs. 70 | data_format: NHWC or NCHW data format. 71 | trainable: If `True` also add variables to the graph collection 72 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). 73 | scope: Optional scope for `variable_scope`. 74 | Returns: 75 | A `Tensor` representing the output of the operation. 76 | """ 77 | 78 | with variable_scope.variable_scope( 79 | scope, 'L2Normalization', [inputs], reuse=reuse) as sc: 80 | inputs_shape = inputs.get_shape() 81 | inputs_rank = inputs_shape.ndims 82 | dtype = inputs.dtype.base_dtype 83 | if data_format == 'NHWC': 84 | # norm_dim = tf.range(1, inputs_rank-1) 85 | norm_dim = tf.range(inputs_rank-1, inputs_rank) 86 | params_shape = inputs_shape[-1:] 87 | elif data_format == 'NCHW': 88 | # norm_dim = tf.range(2, inputs_rank) 89 | norm_dim = tf.range(1, 2) 90 | params_shape = inputs_shape[1:2] 91 | 92 | # Normalize along the channel dimension. 93 | outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12) 94 | # Additional scaling. 95 | if scaling: 96 | scale_collections = utils.get_variable_collections( 97 | variables_collections, 'scale') 98 | scale = variables.model_variable('gamma', 99 | shape=params_shape, 100 | dtype=dtype, 101 | initializer=scale_initializer, 102 | collections=scale_collections, 103 | trainable=trainable) 104 | if data_format == 'NHWC': 105 | outputs = tf.multiply(outputs, scale) 106 | elif data_format == 'NCHW': 107 | scale = tf.expand_dims(scale, axis=-1) 108 | scale = tf.expand_dims(scale, axis=-1) 109 | outputs = tf.multiply(outputs, scale) 110 | # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1)) 111 | 112 | return utils.collect_named_outputs(outputs_collections, 113 | sc.original_name_scope, outputs) 114 | 115 | 116 | @add_arg_scope 117 | def pad2d(inputs, 118 | pad=(0, 0), 119 | mode='CONSTANT', 120 | data_format='NHWC', 121 | trainable=True, 122 | scope=None): 123 | """2D Padding layer, adding a symmetric padding to H and W dimensions.
124 | 125 | Aims to mimic padding in Caffe and MXNet, easing the porting of models to 126 | TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`. 127 | 128 | Args: 129 | inputs: 4D input Tensor; 130 | pad: 2-Tuple with padding values for H and W dimensions; 131 | mode: Padding mode. Cf. `tf.pad` 132 | data_format: NHWC or NCHW data format. 133 | """ 134 | with tf.name_scope(scope, 'pad2d', [inputs]): 135 | # Padding shape. 136 | if data_format == 'NHWC': 137 | paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]] 138 | elif data_format == 'NCHW': 139 | paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]] 140 | net = tf.pad(inputs, paddings, mode=mode) 141 | return net 142 | 143 | 144 | @add_arg_scope 145 | def channel_to_last(inputs, 146 | data_format='NHWC', 147 | scope=None): 148 | """Move the channel axis to the last dimension, so that the output is 149 | always in NHWC format whatever the input data format. 150 | 151 | Args: 152 | inputs: Input Tensor; 153 | data_format: NHWC or NCHW. 154 | Return: 155 | Input in NHWC format. 156 | """ 157 | with tf.name_scope(scope, 'channel_to_last', [inputs]): 158 | if data_format == 'NHWC': 159 | net = inputs 160 | elif data_format == 'NCHW': 161 | net = tf.transpose(inputs, perm=(0, 2, 3, 1)) 162 | return net 163 | -------------------------------------------------------------------------------- /SSD/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings inception_v1, inception_v2 and inception_v3 under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | # from nets.inception_v1 import inception_v1 23 | # from nets.inception_v1 import inception_v1_arg_scope 24 | # from nets.inception_v1 import inception_v1_base 25 | # from nets.inception_v2 import inception_v2 26 | # from nets.inception_v2 import inception_v2_arg_scope 27 | # from nets.inception_v2 import inception_v2_base 28 | # pylint: enable=unused-import 29 | -------------------------------------------------------------------------------- /SSD/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models. 16 | """ 17 | 18 | import functools 19 | 20 | import tensorflow as tf 21 | 22 | from nets import ssd_vgg_300 23 | from nets import ssd_vgg_512 24 | # from nets import inception 25 | # from nets import overfeat 26 | # from nets import resnet_v1 27 | # from nets import resnet_v2 28 | from nets import vgg 29 | 30 | # from nets import xception 31 | 32 | slim = tf.contrib.slim 33 | 34 | networks_map = {'vgg_a': vgg.vgg_a, 35 | 'vgg_16': vgg.vgg_16, 36 | 'vgg_19': vgg.vgg_19, 37 | 'ssd_300_vgg': ssd_vgg_300.ssd_net, 38 | 'ssd_300_vgg_caffe': ssd_vgg_300.ssd_net, 39 | 'ssd_512_vgg': ssd_vgg_512.ssd_net, 40 | 'ssd_512_vgg_caffe': ssd_vgg_512.ssd_net, 41 | } 42 | 43 | arg_scopes_map = {'vgg_a': vgg.vgg_arg_scope, 44 | 'vgg_16': vgg.vgg_arg_scope, 45 | 'vgg_19': vgg.vgg_arg_scope, 46 | 'ssd_300_vgg': ssd_vgg_300.ssd_arg_scope, 47 | 'ssd_300_vgg_caffe': ssd_vgg_300.ssd_arg_scope_caffe, 48 | 'ssd_512_vgg': ssd_vgg_512.ssd_arg_scope, 49 | 'ssd_512_vgg_caffe': ssd_vgg_512.ssd_arg_scope_caffe, 50 | } 51 | 52 | networks_obj = {'ssd_300_vgg': ssd_vgg_300.SSDNet, 53 | 'ssd_512_vgg': ssd_vgg_512.SSDNet, 54 | } 55 | 56 | 57 | def get_network(name): 58 | """Get a network object from a name. 59 | """ 60 | # params = networks_obj[name].default_params if params is None else params 61 | return networks_obj[name] 62 | 63 | 64 | def get_network_fn(name, num_classes, is_training=False, **kwargs): 65 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 66 | 67 | Args: 68 | name: The name of the network. 69 | num_classes: The number of classes to use for classification. 70 | is_training: `True` if the model is being used for training and `False` 71 | otherwise. 72 | weight_decay: The l2 coefficient for the model weights. 
73 | Returns: 74 | network_fn: A function that applies the model to a batch of images. It has 75 | the following signature: logits, end_points = network_fn(images) 76 | Raises: 77 | ValueError: If network `name` is not recognized. 78 | """ 79 | if name not in networks_map: 80 | raise ValueError('Name of network unknown %s' % name) 81 | arg_scope = arg_scopes_map[name](**kwargs) 82 | func = networks_map[name] 83 | @functools.wraps(func) 84 | def network_fn(images, **kwargs): 85 | with slim.arg_scope(arg_scope): 86 | return func(images, num_classes, is_training=is_training, **kwargs) 87 | if hasattr(func, 'default_image_size'): 88 | network_fn.default_image_size = func.default_image_size 89 | 90 | return network_fn 91 | -------------------------------------------------------------------------------- /SSD/nets/np_methods.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional Numpy methods. Big mess of many things! 16 | """ 17 | import numpy as np 18 | 19 | 20 | # =========================================================================== # 21 | # Numpy implementations of SSD boxes functions. 
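`ssd_bboxes_decode` applies the SSD offset parameterization. The network predicts `(cx, cy, w, h)` offsets per anchor, scaled by `prior_scaling`, and the anchors are stored as `(yref, xref, href, wref)`. The following is a scalar, pure-Python sketch of the same arithmetic for a single anchor, together with a hypothetical inverse `encode_box` that demonstrates the round trip. Both helper names are illustrative and are not part of the repository:

```python
import math


def decode_box(loc, anchor, prior_scaling=(0.1, 0.1, 0.2, 0.2)):
    """Decode one (dx, dy, dw, dh) prediction against one anchor (sketch).

    anchor is (yref, xref, href, wref), the order used by np_methods.
    Returns (ymin, xmin, ymax, xmax) in image-relative coordinates.
    """
    dx, dy, dw, dh = loc
    yref, xref, href, wref = anchor
    cx = dx * wref * prior_scaling[0] + xref
    cy = dy * href * prior_scaling[1] + yref
    w = wref * math.exp(dw * prior_scaling[2])
    h = href * math.exp(dh * prior_scaling[3])
    return (cy - h / 2.0, cx - w / 2.0, cy + h / 2.0, cx + w / 2.0)


def encode_box(box, anchor, prior_scaling=(0.1, 0.1, 0.2, 0.2)):
    """Hypothetical inverse of decode_box: box -> (dx, dy, dw, dh)."""
    ymin, xmin, ymax, xmax = box
    yref, xref, href, wref = anchor
    cy, cx = (ymin + ymax) / 2.0, (xmin + xmax) / 2.0
    h, w = ymax - ymin, xmax - xmin
    return ((cx - xref) / wref / prior_scaling[0],
            (cy - yref) / href / prior_scaling[1],
            math.log(w / wref) / prior_scaling[2],
            math.log(h / href) / prior_scaling[3])
```

With zero offsets, the decoded box is the anchor itself. For example, `decode_box((0, 0, 0, 0), (0.5, 0.5, 0.2, 0.4))` gives approximately `(0.4, 0.3, 0.6, 0.7)`, which is a quick sanity check when porting weights.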
22 | # =========================================================================== # 23 | def ssd_bboxes_decode(feat_localizations, 24 | anchor_bboxes, 25 | prior_scaling=[0.1, 0.1, 0.2, 0.2]): 26 | """Compute the relative bounding boxes from the layer features and 27 | reference anchor bounding boxes. 28 | 29 | Return: 30 | numpy array Nx4: ymin, xmin, ymax, xmax 31 | """ 32 | # Reshape for easier broadcasting. 33 | l_shape = feat_localizations.shape 34 | feat_localizations = np.reshape(feat_localizations, 35 | (-1, l_shape[-2], l_shape[-1])) 36 | yref, xref, href, wref = anchor_bboxes 37 | xref = np.reshape(xref, [-1, 1]) 38 | yref = np.reshape(yref, [-1, 1]) 39 | 40 | # Compute center, height and width 41 | cx = feat_localizations[:, :, 0] * wref * prior_scaling[0] + xref 42 | cy = feat_localizations[:, :, 1] * href * prior_scaling[1] + yref 43 | w = wref * np.exp(feat_localizations[:, :, 2] * prior_scaling[2]) 44 | h = href * np.exp(feat_localizations[:, :, 3] * prior_scaling[3]) 45 | # bboxes: ymin, xmin, ymax, xmax. 46 | bboxes = np.zeros_like(feat_localizations) 47 | bboxes[:, :, 0] = cy - h / 2. 48 | bboxes[:, :, 1] = cx - w / 2. 49 | bboxes[:, :, 2] = cy + h / 2. 50 | bboxes[:, :, 3] = cx + w / 2. 51 | # Back to original shape. 52 | bboxes = np.reshape(bboxes, l_shape) 53 | return bboxes 54 | 55 | 56 | def ssd_bboxes_select_layer(predictions_layer, 57 | localizations_layer, 58 | anchors_layer, 59 | select_threshold=0.5, 60 | img_shape=(300, 300), 61 | num_classes=21, 62 | decode=True): 63 | """Extract classes, scores and bounding boxes from features in one layer. 64 | 65 | Return: 66 | classes, scores, bboxes: Numpy arrays... 67 | """ 68 | # First decode localizations features if necessary. 69 | if decode: 70 | localizations_layer = ssd_bboxes_decode(localizations_layer, anchors_layer) 71 | 72 | # Reshape features to: Batches x N x N_labels | 4.
73 | p_shape = predictions_layer.shape 74 | batch_size = p_shape[0] if len(p_shape) == 5 else 1 75 | predictions_layer = np.reshape(predictions_layer, 76 | (batch_size, -1, p_shape[-1])) 77 | l_shape = localizations_layer.shape 78 | localizations_layer = np.reshape(localizations_layer, 79 | (batch_size, -1, l_shape[-1])) 80 | 81 | # Boxes selection: use threshold or score > no-label criteria. 82 | if select_threshold is None or select_threshold == 0: 83 | # Class prediction and scores: assign 0. to 0-class 84 | classes = np.argmax(predictions_layer, axis=2) 85 | scores = np.amax(predictions_layer, axis=2) 86 | mask = (classes > 0) 87 | classes = classes[mask] 88 | scores = scores[mask] 89 | bboxes = localizations_layer[mask] 90 | else: 91 | sub_predictions = predictions_layer[:, :, 1:] 92 | idxes = np.where(sub_predictions > select_threshold) 93 | classes = idxes[-1]+1 94 | scores = sub_predictions[idxes] 95 | bboxes = localizations_layer[idxes[:-1]] 96 | 97 | return classes, scores, bboxes 98 | 99 | 100 | def ssd_bboxes_select(predictions_net, 101 | localizations_net, 102 | anchors_net, 103 | select_threshold=0.5, 104 | img_shape=(300, 300), 105 | num_classes=21, 106 | decode=True): 107 | """Extract classes, scores and bounding boxes from network output layers. 108 | 109 | Return: 110 | classes, scores, bboxes: Numpy arrays... 111 | """ 112 | l_classes = [] 113 | l_scores = [] 114 | l_bboxes = [] 115 | # l_layers = [] 116 | # l_idxes = [] 117 | for i in range(len(predictions_net)): 118 | classes, scores, bboxes = ssd_bboxes_select_layer( 119 | predictions_net[i], localizations_net[i], anchors_net[i], 120 | select_threshold, img_shape, num_classes, decode) 121 | l_classes.append(classes) 122 | l_scores.append(scores) 123 | l_bboxes.append(bboxes) 124 | # Debug information. 
125 | # l_layers.append(i) 126 | # l_idxes.append((i, idxes)) 127 | 128 | classes = np.concatenate(l_classes, 0) 129 | scores = np.concatenate(l_scores, 0) 130 | bboxes = np.concatenate(l_bboxes, 0) 131 | return classes, scores, bboxes 132 | 133 | 134 | # =========================================================================== # 135 | # Common functions for bboxes handling and selection. 136 | # =========================================================================== # 137 | def bboxes_sort(classes, scores, bboxes, top_k=400): 138 | """Sort bounding boxes by decreasing score and keep only the top_k. 139 | """ 140 | # if priority_inside: 141 | # inside = (bboxes[:, 0] > margin) & (bboxes[:, 1] > margin) & \ 142 | # (bboxes[:, 2] < 1-margin) & (bboxes[:, 3] < 1-margin) 143 | # idxes = np.argsort(-scores) 144 | # inside = inside[idxes] 145 | # idxes = np.concatenate([idxes[inside], idxes[~inside]]) 146 | idxes = np.argsort(-scores) 147 | classes = classes[idxes][:top_k] 148 | scores = scores[idxes][:top_k] 149 | bboxes = bboxes[idxes][:top_k] 150 | return classes, scores, bboxes 151 | 152 | 153 | def bboxes_clip(bbox_ref, bboxes): 154 | """Clip bounding boxes to the reference bbox. 155 | """ 156 | bboxes = np.copy(bboxes) 157 | bboxes = np.transpose(bboxes) 158 | bbox_ref = np.transpose(bbox_ref) 159 | bboxes[0] = np.maximum(bboxes[0], bbox_ref[0]) 160 | bboxes[1] = np.maximum(bboxes[1], bbox_ref[1]) 161 | bboxes[2] = np.minimum(bboxes[2], bbox_ref[2]) 162 | bboxes[3] = np.minimum(bboxes[3], bbox_ref[3]) 163 | bboxes = np.transpose(bboxes) 164 | return bboxes 165 | 166 | 167 | def bboxes_resize(bbox_ref, bboxes): 168 | """Resize bounding boxes based on a reference bounding box, 169 | assuming that the latter is [0, 0, 1, 1] after transform. 170 | """ 171 | bboxes = np.copy(bboxes) 172 | # Translate. 
173 | bboxes[:, 0] -= bbox_ref[0] 174 | bboxes[:, 1] -= bbox_ref[1] 175 | bboxes[:, 2] -= bbox_ref[0] 176 | bboxes[:, 3] -= bbox_ref[1] 177 | # Resize.
178 | resize = [bbox_ref[2] - bbox_ref[0], bbox_ref[3] - bbox_ref[1]] 179 | bboxes[:, 0] /= resize[0] 180 | bboxes[:, 1] /= resize[1] 181 | bboxes[:, 2] /= resize[0] 182 | bboxes[:, 3] /= resize[1] 183 | return bboxes 184 | 185 | 186 | def bboxes_jaccard(bboxes1, bboxes2): 187 | """Compute the Jaccard index (IoU) between bboxes1 and bboxes2. 188 | Note: bboxes1 and bboxes2 can be multi-dimensional, but must be broadcastable. 189 | """ 190 | bboxes1 = np.transpose(bboxes1) 191 | bboxes2 = np.transpose(bboxes2) 192 | # Intersection bbox and volume. 193 | int_ymin = np.maximum(bboxes1[0], bboxes2[0]) 194 | int_xmin = np.maximum(bboxes1[1], bboxes2[1]) 195 | int_ymax = np.minimum(bboxes1[2], bboxes2[2]) 196 | int_xmax = np.minimum(bboxes1[3], bboxes2[3]) 197 | 198 | int_h = np.maximum(int_ymax - int_ymin, 0.) 199 | int_w = np.maximum(int_xmax - int_xmin, 0.) 200 | int_vol = int_h * int_w 201 | # Union volume. 202 | vol1 = (bboxes1[2] - bboxes1[0]) * (bboxes1[3] - bboxes1[1]) 203 | vol2 = (bboxes2[2] - bboxes2[0]) * (bboxes2[3] - bboxes2[1]) 204 | jaccard = int_vol / (vol1 + vol2 - int_vol) 205 | return jaccard 206 | 207 | 208 | def bboxes_intersection(bboxes_ref, bboxes2): 209 | """Compute the intersection between bboxes_ref and bboxes2, normalized by the area of bboxes_ref. 210 | Note: bboxes_ref and bboxes2 can be multi-dimensional, but must be broadcastable. 211 | """ 212 | bboxes_ref = np.transpose(bboxes_ref) 213 | bboxes2 = np.transpose(bboxes2) 214 | # Intersection bbox and volume. 215 | int_ymin = np.maximum(bboxes_ref[0], bboxes2[0]) 216 | int_xmin = np.maximum(bboxes_ref[1], bboxes2[1]) 217 | int_ymax = np.minimum(bboxes_ref[2], bboxes2[2]) 218 | int_xmax = np.minimum(bboxes_ref[3], bboxes2[3]) 219 | 220 | int_h = np.maximum(int_ymax - int_ymin, 0.) 221 | int_w = np.maximum(int_xmax - int_xmin, 0.) 222 | int_vol = int_h * int_w 223 | # Reference volume.
224 | vol = (bboxes_ref[2] - bboxes_ref[0]) * (bboxes_ref[3] - bboxes_ref[1]) 225 | score = int_vol / vol 226 | return score 227 | 228 | 229 | def bboxes_nms(classes, scores, bboxes, nms_threshold=0.45): 230 | """Apply non-maximum suppression to bounding boxes (assumed sorted by decreasing score). 231 | """ 232 | keep_bboxes = np.ones(scores.shape, dtype=bool) 233 | for i in range(scores.size-1): 234 | if keep_bboxes[i]: 235 | # Compute overlap with the following bboxes. 236 | overlap = bboxes_jaccard(bboxes[i], bboxes[(i+1):]) 237 | # Keep boxes below the overlap threshold or belonging to a different class. 238 | keep_overlap = np.logical_or(overlap < nms_threshold, classes[(i+1):] != classes[i]) 239 | keep_bboxes[(i+1):] = np.logical_and(keep_bboxes[(i+1):], keep_overlap) 240 | 241 | idxes = np.where(keep_bboxes) 242 | return classes[idxes], scores[idxes], bboxes[idxes] 243 | 244 | 245 | def bboxes_nms_fast(classes, scores, bboxes, threshold=0.45): 246 | """Apply non-maximum suppression to bounding boxes (not implemented). 247 | """ 248 | pass 249 | 250 | 251 | 252 | 253 | -------------------------------------------------------------------------------- /SSD/nets/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SSD/nets/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains model definitions for versions of the Oxford VGG network. 16 | 17 | These model definitions were introduced in the following technical report: 18 | 19 | Very Deep Convolutional Networks For Large-Scale Image Recognition 20 | Karen Simonyan and Andrew Zisserman 21 | arXiv technical report, 2015 22 | PDF: http://arxiv.org/pdf/1409.1556.pdf 23 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 24 | CC-BY-4.0 25 | 26 | More information can be obtained from the VGG website: 27 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 28 | 29 | Usage: 30 | with slim.arg_scope(vgg.vgg_arg_scope()): 31 | outputs, end_points = vgg.vgg_a(inputs) 32 | 33 | with slim.arg_scope(vgg.vgg_arg_scope()): 34 | outputs, end_points = vgg.vgg_16(inputs) 35 | 36 | @@vgg_a 37 | @@vgg_16 38 | @@vgg_19 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import tensorflow as tf 45 | 46 | slim = tf.contrib.slim 47 | 48 | 49 | def vgg_arg_scope(weight_decay=0.0005): 50 | """Defines the VGG arg scope. 51 | 52 | Args: 53 | weight_decay: The l2 regularization coefficient. 54 | 55 | Returns: 56 | An arg_scope. 57 | """ 58 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 59 | activation_fn=tf.nn.relu, 60 | weights_regularizer=slim.l2_regularizer(weight_decay), 61 | biases_initializer=tf.zeros_initializer): 62 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 63 | return arg_sc 64 | 65 | 66 | def vgg_a(inputs, 67 | num_classes=1000, 68 | is_training=True, 69 | dropout_keep_prob=0.5, 70 | spatial_squeeze=True, 71 | scope='vgg_a'): 72 | """Oxford Net VGG 11-Layers version A Example. 
73 | 74 | Note: All the fully_connected layers have been transformed to conv2d layers. 75 | To use in classification mode, resize input to 224x224. 76 | 77 | Args: 78 | inputs: a tensor of size [batch_size, height, width, channels]. 79 | num_classes: number of predicted classes. 80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not to squeeze the spatial dimensions of the 84 | outputs. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | 87 | Returns: 88 | the last op containing the log predictions and end_points dict. 89 | """ 90 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc: 91 | end_points_collection = sc.name + '_end_points' 92 | # Collect outputs for conv2d and max_pool2d. 93 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 96 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 97 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 98 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 99 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 100 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 101 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 102 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 103 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 104 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 105 | # Use conv2d instead of fully_connected layers.
106 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 107 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 108 | scope='dropout6') 109 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout7') 112 | net = slim.conv2d(net, num_classes, [1, 1], 113 | activation_fn=None, 114 | normalizer_fn=None, 115 | scope='fc8') 116 | # Convert end_points_collection into an end_points dict. 117 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 118 | if spatial_squeeze: 119 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 120 | end_points[sc.name + '/fc8'] = net 121 | return net, end_points 122 | vgg_a.default_image_size = 224 123 | 124 | 125 | def vgg_16(inputs, 126 | num_classes=1000, 127 | is_training=True, 128 | dropout_keep_prob=0.5, 129 | spatial_squeeze=True, 130 | scope='vgg_16'): 131 | """Oxford Net VGG 16-Layers version D Example. 132 | 133 | Note: All the fully_connected layers have been transformed to conv2d layers. 134 | To use in classification mode, resize input to 224x224. 135 | 136 | Args: 137 | inputs: a tensor of size [batch_size, height, width, channels]. 138 | num_classes: number of predicted classes. 139 | is_training: whether or not the model is being trained. 140 | dropout_keep_prob: the probability that activations are kept in the dropout 141 | layers during training. 142 | spatial_squeeze: whether or not to squeeze the spatial dimensions of the 143 | outputs. Useful to remove unnecessary dimensions for classification. 144 | scope: Optional scope for the variables. 145 | 146 | Returns: 147 | the last op containing the log predictions and end_points dict. 148 | """ 149 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: 150 | end_points_collection = sc.name + '_end_points' 151 | # Collect outputs for conv2d, fully_connected and max_pool2d.
152 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 153 | outputs_collections=end_points_collection): 154 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 155 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 156 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 157 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 158 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 159 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 160 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 161 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 162 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 163 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 164 | # Use conv2d instead of fully_connected layers. 165 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 166 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 167 | scope='dropout6') 168 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 169 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 170 | scope='dropout7') 171 | net = slim.conv2d(net, num_classes, [1, 1], 172 | activation_fn=None, 173 | normalizer_fn=None, 174 | scope='fc8') 175 | # Convert end_points_collection into an end_points dict. 176 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 177 | if spatial_squeeze: 178 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 179 | end_points[sc.name + '/fc8'] = net 180 | return net, end_points 181 | vgg_16.default_image_size = 224 182 | 183 | 184 | def vgg_19(inputs, 185 | num_classes=1000, 186 | is_training=True, 187 | dropout_keep_prob=0.5, 188 | spatial_squeeze=True, 189 | scope='vgg_19'): 190 | """Oxford Net VGG 19-Layers version E Example. 191 | 192 | Note: All the fully_connected layers have been transformed to conv2d layers. 193 | To use in classification mode, resize input to 224x224.
194 | 195 | Args: 196 | inputs: a tensor of size [batch_size, height, width, channels]. 197 | num_classes: number of predicted classes. 198 | is_training: whether or not the model is being trained. 199 | dropout_keep_prob: the probability that activations are kept in the dropout 200 | layers during training. 201 | spatial_squeeze: whether or not to squeeze the spatial dimensions of the 202 | outputs. Useful to remove unnecessary dimensions for classification. 203 | scope: Optional scope for the variables. 204 | 205 | Returns: 206 | the last op containing the log predictions and end_points dict. 207 | """ 208 | with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc: 209 | end_points_collection = sc.name + '_end_points' 210 | # Collect outputs for conv2d, fully_connected and max_pool2d. 211 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 212 | outputs_collections=end_points_collection): 213 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 214 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 215 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 216 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 217 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') 218 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 219 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') 220 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 221 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') 222 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 223 | # Use conv2d instead of fully_connected layers.
224 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 225 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 226 | scope='dropout6') 227 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 228 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 229 | scope='dropout7') 230 | net = slim.conv2d(net, num_classes, [1, 1], 231 | activation_fn=None, 232 | normalizer_fn=None, 233 | scope='fc8') 234 | # Convert end_points_collection into an end_points dict. 235 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 236 | if spatial_squeeze: 237 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 238 | end_points[sc.name + '/fc8'] = net 239 | return net, end_points 240 | vgg_19.default_image_size = 224 241 | 242 | # Alias 243 | vgg_d = vgg_16 244 | vgg_e = vgg_19 245 | -------------------------------------------------------------------------------- /SSD/nets/xception.py: -------------------------------------------------------------------------------- 1 | """Definition of Xception model introduced by F. Chollet. 2 | 3 | Usage: 4 | with slim.arg_scope(xception.xception_arg_scope()): 5 | outputs, end_points = xception.xception(inputs) 6 | @@xception 7 | """ 8 | 9 | import tensorflow as tf 10 | slim = tf.contrib.slim 11 | 12 | 13 | # =========================================================================== # 14 | # Xception implementation (clean) 15 | # =========================================================================== # 16 | def xception(inputs, 17 | num_classes=1000, 18 | is_training=True, 19 | dropout_keep_prob=0.5, 20 | prediction_fn=slim.softmax, 21 | reuse=None, 22 | scope='xception'): 23 | """Xception model from https://arxiv.org/pdf/1610.02357v2.pdf 24 | 25 | The default image size used to train this network is 299x299. 26 | """ 27 | 28 | # end_points collect relevant activations for external use, for example 29 | # summaries or losses.
30 | end_points = {} 31 | 32 | with tf.variable_scope(scope, 'xception', [inputs]): 33 | # Block 1. 34 | end_point = 'block1' 35 | with tf.variable_scope(end_point): 36 | net = slim.conv2d(inputs, 32, [3, 3], stride=2, padding='VALID', scope='conv1') 37 | net = slim.conv2d(net, 64, [3, 3], padding='VALID', scope='conv2') 38 | end_points[end_point] = net 39 | 40 | # Residual block 2. 41 | end_point = 'block2' 42 | with tf.variable_scope(end_point): 43 | res = slim.conv2d(net, 128, [1, 1], stride=2, activation_fn=None, scope='res') 44 | net = slim.separable_convolution2d(net, 128, [3, 3], 1, scope='sepconv1') 45 | net = slim.separable_convolution2d(net, 128, [3, 3], 1, activation_fn=None, scope='sepconv2') 46 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool') 47 | net = res + net 48 | end_points[end_point] = net 49 | 50 | # Residual block 3. 51 | end_point = 'block3' 52 | with tf.variable_scope(end_point): 53 | res = slim.conv2d(net, 256, [1, 1], stride=2, activation_fn=None, scope='res') 54 | net = tf.nn.relu(net) 55 | net = slim.separable_convolution2d(net, 256, [3, 3], 1, scope='sepconv1') 56 | net = slim.separable_convolution2d(net, 256, [3, 3], 1, activation_fn=None, scope='sepconv2') 57 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool') 58 | net = res + net 59 | end_points[end_point] = net 60 | 61 | # Residual block 4. 62 | end_point = 'block4' 63 | with tf.variable_scope(end_point): 64 | res = slim.conv2d(net, 728, [1, 1], stride=2, activation_fn=None, scope='res') 65 | net = tf.nn.relu(net) 66 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, scope='sepconv1') 67 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, scope='sepconv2') 68 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool') 69 | net = res + net 70 | end_points[end_point] = net 71 | 72 | # Middle flow blocks. 
73 | for i in range(8): 74 | end_point = 'block' + str(i + 5) 75 | with tf.variable_scope(end_point): 76 | res = net 77 | net = tf.nn.relu(net) 78 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, 79 | scope='sepconv1') 80 | net = tf.nn.relu(net) 81 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, 82 | scope='sepconv2') 83 | net = tf.nn.relu(net) 84 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, 85 | scope='sepconv3') 86 | net = res + net 87 | end_points[end_point] = net 88 | 89 | # Exit flow: blocks 13 and 14. 90 | end_point = 'block13' 91 | with tf.variable_scope(end_point): 92 | res = slim.conv2d(net, 1024, [1, 1], stride=2, activation_fn=None, scope='res') 93 | net = tf.nn.relu(net) 94 | net = slim.separable_convolution2d(net, 728, [3, 3], 1, activation_fn=None, scope='sepconv1') 95 | net = tf.nn.relu(net) 96 | net = slim.separable_convolution2d(net, 1024, [3, 3], 1, activation_fn=None, scope='sepconv2') 97 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool') 98 | net = res + net 99 | end_points[end_point] = net 100 | 101 | end_point = 'block14' 102 | with tf.variable_scope(end_point): 103 | net = slim.separable_convolution2d(net, 1536, [3, 3], 1, scope='sepconv1') 104 | net = slim.separable_convolution2d(net, 2048, [3, 3], 1, scope='sepconv2') 105 | end_points[end_point] = net 106 | 107 | # Global averaging. 108 | end_point = 'dense' 109 | with tf.variable_scope(end_point): 110 | net = tf.reduce_mean(net, [1, 2], name='reduce_avg') 111 | logits = slim.fully_connected(net, num_classes, activation_fn=None) 112 | 113 | end_points['logits'] = logits 114 | end_points['predictions'] = prediction_fn(logits, scope='Predictions') 115 | 116 | return logits, end_points 117 | xception.default_image_size = 299 118 | 119 | 120 | def xception_arg_scope(weight_decay=0.00001, stddev=0.1): 121 | """Defines the default Xception arg scope.
122 | 123 | Args: 124 | weight_decay: The weight decay to use for regularizing the model. 125 | stddev: The standard deviation of the truncated normal weight initializer (currently unused: weights use a variance-scaling initializer). 126 | 127 | Returns: 128 | An `arg_scope` to use for the xception model. 129 | """ 130 | batch_norm_params = { 131 | # Decay for the moving averages. 132 | 'decay': 0.9997, 133 | # epsilon to prevent 0s in variance. 134 | 'epsilon': 0.001, 135 | # collection containing update_ops. 136 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 137 | } 138 | 139 | # Set weight_decay for weights in Conv and FC layers. 140 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.separable_convolution2d], 141 | weights_regularizer=slim.l2_regularizer(weight_decay)): 142 | with slim.arg_scope( 143 | [slim.conv2d, slim.separable_convolution2d], 144 | padding='SAME', 145 | weights_initializer=tf.contrib.layers.variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False), 146 | activation_fn=tf.nn.relu, 147 | normalizer_fn=slim.batch_norm, 148 | normalizer_params=batch_norm_params): 149 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as sc: 150 | return sc 151 | 152 | 153 | # =========================================================================== # 154 | # Xception arg scope (Keras hack!) 155 | # =========================================================================== # 156 | def xception_keras_arg_scope(hdf5_file, weight_decay=0.00001): 157 | """Defines an Xception arg scope that initializes layer weights 158 | from a Keras HDF5 file. 159 | 160 | Quite hacky implementation, but seems to be working! 161 | 162 | Args: 163 | hdf5_file: HDF5 file handle. 164 | weight_decay: The weight decay to use for regularizing the model. 165 | 166 | Returns: 167 | An `arg_scope` to use for the xception model. 168 | """ 169 | # Default batch normalization parameters.
170 | batch_norm_params = { 171 | 'center': True, 172 | 'scale': False, 173 | 'decay': 0.9997, 174 | 'epsilon': 0.001, 175 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 176 | } 177 | 178 | # Read weights from HDF5 file. 179 | def keras_bn_params(): 180 | def _beta_initializer(shape, dtype, partition_info=None): 181 | keras_bn_params.bidx += 1 182 | k = 'batchnormalization_%i' % keras_bn_params.bidx 183 | kb = 'batchnormalization_%i_beta:0' % keras_bn_params.bidx 184 | return tf.cast(hdf5_file[k][kb][:], dtype) 185 | 186 | def _gamma_initializer(shape, dtype, partition_info=None): 187 | keras_bn_params.gidx += 1 188 | k = 'batchnormalization_%i' % keras_bn_params.gidx 189 | kg = 'batchnormalization_%i_gamma:0' % keras_bn_params.gidx 190 | return tf.cast(hdf5_file[k][kg][:], dtype) 191 | 192 | def _mean_initializer(shape, dtype, partition_info=None): 193 | keras_bn_params.midx += 1 194 | k = 'batchnormalization_%i' % keras_bn_params.midx 195 | km = 'batchnormalization_%i_running_mean:0' % keras_bn_params.midx 196 | return tf.cast(hdf5_file[k][km][:], dtype) 197 | 198 | def _variance_initializer(shape, dtype, partition_info=None): 199 | keras_bn_params.vidx += 1 200 | k = 'batchnormalization_%i' % keras_bn_params.vidx 201 | kv = 'batchnormalization_%i_running_std:0' % keras_bn_params.vidx 202 | return tf.cast(hdf5_file[k][kv][:], dtype) 203 | 204 | # Batch normalisation initializers. 
205 | params = batch_norm_params.copy() 206 | params['initializers'] = { 207 | 'beta': _beta_initializer, 208 | 'gamma': _gamma_initializer, 209 | 'moving_mean': _mean_initializer, 210 | 'moving_variance': _variance_initializer, 211 | } 212 | return params 213 | keras_bn_params.bidx = 0 214 | keras_bn_params.gidx = 0 215 | keras_bn_params.midx = 0 216 | keras_bn_params.vidx = 0 217 | 218 | def keras_conv2d_weights(): 219 | def _initializer(shape, dtype, partition_info=None): 220 | keras_conv2d_weights.idx += 1 221 | k = 'convolution2d_%i' % keras_conv2d_weights.idx 222 | kw = 'convolution2d_%i_W:0' % keras_conv2d_weights.idx 223 | return tf.cast(hdf5_file[k][kw][:], dtype) 224 | return _initializer 225 | keras_conv2d_weights.idx = 0 226 | 227 | def keras_sep_conv2d_weights(): 228 | def _initializer(shape, dtype, partition_info=None): 229 | # Depthwise or Pointwise convolution? 230 | if shape[0] > 1 or shape[1] > 1: 231 | keras_sep_conv2d_weights.didx += 1 232 | k = 'separableconvolution2d_%i' % keras_sep_conv2d_weights.didx 233 | kd = 'separableconvolution2d_%i_depthwise_kernel:0' % keras_sep_conv2d_weights.didx 234 | weights = hdf5_file[k][kd][:] 235 | else: 236 | keras_sep_conv2d_weights.pidx += 1 237 | k = 'separableconvolution2d_%i' % keras_sep_conv2d_weights.pidx 238 | kp = 'separableconvolution2d_%i_pointwise_kernel:0' % keras_sep_conv2d_weights.pidx 239 | weights = hdf5_file[k][kp][:] 240 | return tf.cast(weights, dtype) 241 | return _initializer 242 | keras_sep_conv2d_weights.didx = 0 243 | keras_sep_conv2d_weights.pidx = 0 244 | 245 | def keras_dense_weights(): 246 | def _initializer(shape, dtype, partition_info=None): 247 | keras_dense_weights.idx += 1 248 | k = 'dense_%i' % keras_dense_weights.idx 249 | kw = 'dense_%i_W:0' % keras_dense_weights.idx 250 | return tf.cast(hdf5_file[k][kw][:], dtype) 251 | return _initializer 252 | keras_dense_weights.idx = 1 253 | 254 | def keras_dense_biases(): 255 | def _initializer(shape, dtype, partition_info=None): 256 
| keras_dense_biases.idx += 1 257 | k = 'dense_%i' % keras_dense_biases.idx 258 | kb = 'dense_%i_b:0' % keras_dense_biases.idx 259 | return tf.cast(hdf5_file[k][kb][:], dtype) 260 | return _initializer 261 | keras_dense_biases.idx = 1 262 | 263 | # Default network arg scope. 264 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.separable_convolution2d], 265 | weights_regularizer=slim.l2_regularizer(weight_decay)): 266 | with slim.arg_scope( 267 | [slim.conv2d, slim.separable_convolution2d], 268 | padding='SAME', 269 | activation_fn=tf.nn.relu, 270 | normalizer_fn=slim.batch_norm, 271 | normalizer_params=keras_bn_params()): 272 | with slim.arg_scope([slim.max_pool2d], padding='SAME'): 273 | 274 | # Weights initializers from Keras weights. 275 | with slim.arg_scope([slim.conv2d], 276 | weights_initializer=keras_conv2d_weights()): 277 | with slim.arg_scope([slim.separable_convolution2d], 278 | weights_initializer=keras_sep_conv2d_weights()): 279 | with slim.arg_scope([slim.fully_connected], 280 | weights_initializer=keras_dense_weights(), 281 | biases_initializer=keras_dense_biases()) as sc: 282 | return sc 283 | 284 | -------------------------------------------------------------------------------- /SSD/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SSD/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building preprocessing functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import ssd_vgg_preprocessing 24 | 25 | # from preprocessing import cifarnet_preprocessing 26 | # from preprocessing import inception_preprocessing 27 | # from preprocessing import vgg_preprocessing 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def get_preprocessing(name, is_training=False): 33 | """Returns preprocessing_fn(image, labels, bboxes, out_shape, data_format='NHWC', **kwargs). 34 | 35 | Args: 36 | name: The name of the preprocessing function. 37 | is_training: `True` if the model is being used for training. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocesses a single image (pre-batch) 41 | together with its labels and bounding boxes. It has the following signature: 42 | preprocessing_fn(image, labels, bboxes, out_shape, data_format='NHWC', ...), returning the output of `preprocess_image` from the selected module. 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized.
46 | """ 47 | preprocessing_fn_map = { 48 | 'ssd_300_vgg': ssd_vgg_preprocessing, 49 | 'ssd_512_vgg': ssd_vgg_preprocessing, 50 | } 51 | 52 | if name not in preprocessing_fn_map: 53 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 54 | 55 | def preprocessing_fn(image, labels, bboxes, 56 | out_shape, data_format='NHWC', **kwargs): 57 | return preprocessing_fn_map[name].preprocess_image( 58 | image, labels, bboxes, out_shape, data_format=data_format, 59 | is_training=is_training, **kwargs) 60 | return preprocessing_fn 61 | -------------------------------------------------------------------------------- /SSD/preprocessing/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SSD/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SSD/tf_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Convert a dataset to TFRecords format, which can be easily integrated into 16 | a TensorFlow pipeline. 
17 | 18 | Usage: 19 | ```shell 20 | python tf_convert_data.py \ 21 | --dataset_name=pascalvoc \ 22 | --dataset_dir=/tmp/pascalvoc \ 23 | --output_name=pascalvoc \ 24 | --output_dir=/tmp/ 25 | ``` 26 | """ 27 | import tensorflow as tf 28 | 29 | from datasets import pascalvoc_to_tfrecords 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | 33 | tf.app.flags.DEFINE_string( 34 | 'dataset_name', 'pascalvoc', 35 | 'The name of the dataset to convert.') 36 | tf.app.flags.DEFINE_string( 37 | 'dataset_dir', None, 38 | 'Directory where the original dataset is stored.') 39 | tf.app.flags.DEFINE_string( 40 | 'output_name', 'pascalvoc', 41 | 'Basename used for TFRecords output files.') 42 | tf.app.flags.DEFINE_string( 43 | 'output_dir', './', 44 | 'Output directory where to store TFRecords files.') 45 | 46 | 47 | def main(_): 48 | if not FLAGS.dataset_dir: 49 | raise ValueError('You must supply the dataset directory with --dataset_dir') 50 | print('Dataset directory:', FLAGS.dataset_dir) 51 | print('Output directory:', FLAGS.output_dir) 52 | 53 | if FLAGS.dataset_name == 'pascalvoc': 54 | pascalvoc_to_tfrecords.run(FLAGS.dataset_dir, FLAGS.output_dir, FLAGS.output_name) 55 | else: 56 | raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name) 57 | 58 | if __name__ == '__main__': 59 | tf.app.run() 60 | 61 | -------------------------------------------------------------------------------- /SSD/tf_extended/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional metrics. 16 | """ 17 | 18 | from tf_extended.bboxes import * 19 | from tf_extended.image import * 20 | from tf_extended.math import * 21 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import 22 | from tf_extended.metrics import * 23 | from tf_extended.tensors import * 24 | 25 | -------------------------------------------------------------------------------- /SSD/tf_extended/image.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/SSD/tf_extended/image.py -------------------------------------------------------------------------------- /SSD/tf_extended/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional math functions. 16 | """ 17 | import tensorflow as tf 18 | from tensorflow.python.framework import ops 19 | from tensorflow.python.ops import math_ops 20 | 21 | 22 | def safe_divide(numerator, denominator, name): 23 | """Divides two values, returning 0 if the denominator is <= 0. 24 | Args: 25 | numerator: A real `Tensor`. 26 | denominator: A real `Tensor`, with dtype matching `numerator`. 27 | name: Name for the returned op. 28 | Returns: 29 | 0 if `denominator` <= 0, else `numerator` / `denominator` 30 | """ 31 | return tf.where( 32 | math_ops.greater(denominator, 0), 33 | math_ops.divide(numerator, denominator), 34 | tf.zeros_like(numerator), 35 | name=name) 36 | 37 | 38 | def cummax(x, reverse=False, name=None): 39 | """Compute the cumulative maximum of the 1-D tensor `x`. This 40 | operation is similar to the more classic `cumsum`. Only 1-D tensors are 41 | supported for now. 42 | 43 | Args: 44 | x: A 1-D real-valued `Tensor` (e.g. `float32`, `float64`, `int32`, 45 | `int64`). Complex types are not supported, since the accumulation 46 | relies on `tf.maximum`. 47 | reverse: A `bool` (default: False). If True, the maximum is 48 | accumulated from the last element towards the first. 49 | name: A name for the operation (optional). 50 | Returns: 51 | A `Tensor`. Has the same type as `x`. 52 | """ 53 | with ops.name_scope(name, "Cummax", [x]) as name: 54 | x = ops.convert_to_tensor(x, name="x") 55 | # Not very optimal: should directly integrate reverse into tf.scan. 56 | if reverse: 57 | x = tf.reverse(x, axis=[0]) 58 | # 'Accumulating' maximum: ensures the output is non-decreasing.
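        # Illustrative, hand-computed example (not part of the original code):
        #   x = [1, 3, 2, 5, 4]                -> cmax = [1, 3, 3, 5, 5]
        #   x = [1, 3, 2, 5, 4], reverse=True  -> cmax = [5, 5, 5, 5, 4]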
59 | cmax = tf.scan(lambda a, y: tf.maximum(a, y), x, 60 | initializer=None, parallel_iterations=1, 61 | back_prop=False, swap_memory=False) 62 | if reverse: 63 | cmax = tf.reverse(cmax, axis=[0]) 64 | return cmax 65 | -------------------------------------------------------------------------------- /SSD/tf_extended/tensors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional tensor operations. 16 | """ 17 | import tensorflow as tf 18 | 19 | 20 | def get_shape(x, rank=None): 21 | """Returns the dimensions of a Tensor as a list of integers or scalar tensors. 22 | 23 | Args: 24 | x: N-d Tensor; 25 | rank: Rank of the Tensor. If None, it is inferred from the static shape. 26 | Returns: 27 | A list of `[d1, d2, ..., dN]` corresponding to the dimensions of the 28 | input tensor. Dimensions that are statically known are python integers, 29 | otherwise they are integer scalar tensors.
30 | """ 31 | if x.get_shape().is_fully_defined(): 32 | return x.get_shape().as_list() 33 | else: 34 | static_shape = x.get_shape() 35 | if rank is None: 36 | static_shape = static_shape.as_list() 37 | rank = len(static_shape) 38 | else: 39 | static_shape = x.get_shape().with_rank(rank).as_list() 40 | dynamic_shape = tf.unstack(tf.shape(x), rank) 41 | return [s if s is not None else d 42 | for s, d in zip(static_shape, dynamic_shape)] 43 | 44 | 45 | def pad_axis(x, offset, size, axis=0, name=None): 46 | """Pad a tensor along one axis, with a given offset and output size. 47 | The tensor is padded with zeros (i.e. CONSTANT mode): `offset` zeros are 48 | inserted before the data on `axis`, and trailing zeros fill it up to `size`. 49 | `size` must be at least the existing size + `offset`; larger inputs are not handled yet, because the final reshape to `size` fails (see the TODO below). 50 | 51 | Args: 52 | x: Tensor to pad; 53 | offset: Number of zeros to insert before the data on `axis`; 54 | size: Final size of the dimension; 55 | axis: Axis to pad (default: 0). 56 | Returns: 57 | Padded tensor whose dimension on `axis` is `size`. 58 | """ 59 | with tf.name_scope(name, 'pad_axis'): 60 | shape = get_shape(x) 61 | rank = len(shape) 62 | # Padding description. 63 | new_size = tf.maximum(size-offset-shape[axis], 0) 64 | pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1)) 65 | pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1)) 66 | paddings = tf.stack([pad1, pad2], axis=1) 67 | x = tf.pad(x, paddings, mode='CONSTANT') 68 | # Reshape, to get fully defined shape if possible. 69 | # TODO: fix with tf.slice 70 | shape[axis] = size 71 | x = tf.reshape(x, tf.stack(shape)) 72 | return x 73 | 74 | 75 | # def select_at_index(idx, val, t): 76 | # """Return a tensor.
77 | # """ 78 | # idx = tf.expand_dims(tf.expand_dims(idx, 0), 0) 79 | # val = tf.expand_dims(val, 0) 80 | # t = t + tf.scatter_nd(idx, val, tf.shape(t)) 81 | # return t 82 | -------------------------------------------------------------------------------- /SSD/tf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Miscellaneous TensorFlow utilities for training, evaluation, etc. 16 | """ 17 | import os 18 | from pprint import pprint 19 | 20 | import tensorflow as tf 21 | from tensorflow.contrib.slim.python.slim.data import parallel_reader 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | # =========================================================================== # 27 | # General tools. 28 | # =========================================================================== # 29 | def reshape_list(l, shape=None): 30 | """Reshape a list (of lists): flatten 2-D to 1-D, or regroup 1-D into 2-D. 31 | 32 | Args: 33 | l: List, or list of lists/tuples. 34 | shape: None to flatten `l`; otherwise a list of group sizes, where a size of 1 keeps a bare element. 35 | Returns: 36 | Reshaped list. 37 | """ 38 | r = [] 39 | if shape is None: 40 | # Flatten everything. 41 | for a in l: 42 | if isinstance(a, (list, tuple)): 43 | r = r + list(a) 44 | else: 45 | r.append(a) 46 | else: 47 | # Regroup into a list of lists.
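        # Illustrative example with hypothetical elements a, b, c:
        #   reshape_list([a, [b, c]])              -> [a, b, c]      (flatten)
        #   reshape_list([a, b, c], shape=[1, 2])  -> [a, [b, c]]    (regroup)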
48 | i = 0 49 | for s in shape: 50 | if s == 1: 51 | r.append(l[i]) 52 | else: 53 | r.append(l[i:i+s]) 54 | i += s 55 | return r 56 | 57 | 58 | # =========================================================================== # 59 | # Training utils. 60 | # =========================================================================== # 61 | def print_configuration(flags, ssd_params, data_sources, save_dir=None): 62 | """Print the training configuration. 63 | """ 64 | def print_config(stream=None): 65 | print('\n# =========================================================================== #', file=stream) 66 | print('# Training | Evaluation flags:', file=stream) 67 | print('# =========================================================================== #', file=stream) 68 | pprint(flags, stream=stream) 69 | 70 | print('\n# =========================================================================== #', file=stream) 71 | print('# SSD net parameters:', file=stream) 72 | print('# =========================================================================== #', file=stream) 73 | pprint(dict(ssd_params._asdict()), stream=stream) 74 | 75 | print('\n# =========================================================================== #', file=stream) 76 | print('# Training | Evaluation dataset files:', file=stream) 77 | print('# =========================================================================== #', file=stream) 78 | data_files = parallel_reader.get_data_files(data_sources) 79 | pprint(sorted(data_files), stream=stream) 80 | print('', file=stream) 81 | 82 | print_config(None) 83 | # Save to a text file as well. 84 | if save_dir is not None: 85 | if not os.path.exists(save_dir): 86 | os.makedirs(save_dir) 87 | path = os.path.join(save_dir, 'training_config.txt') 88 | with open(path, "w") as out: 89 | print_config(out) 90 | 91 | 92 | def configure_learning_rate(flags, num_samples_per_epoch, global_step): 93 | """Configures the learning rate. 
94 | 95 | Args: 96 | num_samples_per_epoch: The number of samples in each epoch of training. 97 | global_step: The global_step tensor. 98 | Returns: 99 | A `Tensor` representing the learning rate. 100 | """ 101 | decay_steps = int(num_samples_per_epoch / flags.batch_size * 102 | flags.num_epochs_per_decay) 103 | 104 | if flags.learning_rate_decay_type == 'exponential': 105 | return tf.train.exponential_decay(flags.learning_rate, 106 | global_step, 107 | decay_steps, 108 | flags.learning_rate_decay_factor, 109 | staircase=True, 110 | name='exponential_decay_learning_rate') 111 | elif flags.learning_rate_decay_type == 'fixed': 112 | return tf.constant(flags.learning_rate, name='fixed_learning_rate') 113 | elif flags.learning_rate_decay_type == 'polynomial': 114 | return tf.train.polynomial_decay(flags.learning_rate, 115 | global_step, 116 | decay_steps, 117 | flags.end_learning_rate, 118 | power=1.0, 119 | cycle=False, 120 | name='polynomial_decay_learning_rate') 121 | else: 122 | raise ValueError('learning_rate_decay_type [%s] was not recognized' % 123 | flags.learning_rate_decay_type) 124 | 125 | 126 | def configure_optimizer(flags, learning_rate): 127 | """Configures the optimizer used for training. 128 | 129 | Args: 130 | learning_rate: A scalar or `Tensor` learning rate. 131 | Returns: 132 | An instance of an optimizer.
133 | """ 134 | if flags.optimizer == 'adadelta': 135 | optimizer = tf.train.AdadeltaOptimizer( 136 | learning_rate, 137 | rho=flags.adadelta_rho, 138 | epsilon=flags.opt_epsilon) 139 | elif flags.optimizer == 'adagrad': 140 | optimizer = tf.train.AdagradOptimizer( 141 | learning_rate, 142 | initial_accumulator_value=flags.adagrad_initial_accumulator_value) 143 | elif flags.optimizer == 'adam': 144 | optimizer = tf.train.AdamOptimizer( 145 | learning_rate, 146 | beta1=flags.adam_beta1, 147 | beta2=flags.adam_beta2, 148 | epsilon=flags.opt_epsilon) 149 | elif flags.optimizer == 'ftrl': 150 | optimizer = tf.train.FtrlOptimizer( 151 | learning_rate, 152 | learning_rate_power=flags.ftrl_learning_rate_power, 153 | initial_accumulator_value=flags.ftrl_initial_accumulator_value, 154 | l1_regularization_strength=flags.ftrl_l1, 155 | l2_regularization_strength=flags.ftrl_l2) 156 | elif flags.optimizer == 'momentum': 157 | optimizer = tf.train.MomentumOptimizer( 158 | learning_rate, 159 | momentum=flags.momentum, 160 | name='Momentum') 161 | elif flags.optimizer == 'rmsprop': 162 | optimizer = tf.train.RMSPropOptimizer( 163 | learning_rate, 164 | decay=flags.rmsprop_decay, 165 | momentum=flags.rmsprop_momentum, 166 | epsilon=flags.opt_epsilon) 167 | elif flags.optimizer == 'sgd': 168 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 169 | else: 170 | raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer) 171 | return optimizer 172 | 173 | 174 | def add_variables_summaries(learning_rate): 175 | summaries = [] 176 | for variable in slim.get_model_variables(): 177 | summaries.append(tf.summary.histogram(variable.op.name, variable)) 178 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate)) 179 | return summaries 180 | 181 | 182 | def update_model_scope(var, ckpt_scope, new_scope): 183 | return var.op.name.replace(new_scope, ckpt_scope) 184 | 185 | 186 | def get_init_fn(flags): 187 | """Returns a function run by the chief worker
to warm-start the training. 188 | Note that the init_fn is only run when initializing the model during the very 189 | first global step. 190 | 191 | Returns: 192 | An init function run by the supervisor. 193 | """ 194 | if flags.checkpoint_path is None: 195 | return None 196 | # Warn the user if a checkpoint exists in the train_dir. Then ignore. 197 | if tf.train.latest_checkpoint(flags.train_dir): 198 | tf.logging.info( 199 | 'Ignoring --checkpoint_path because a checkpoint already exists in %s' 200 | % flags.train_dir) 201 | return None 202 | 203 | exclusions = [] 204 | if flags.checkpoint_exclude_scopes: 205 | exclusions = [scope.strip() 206 | for scope in flags.checkpoint_exclude_scopes.split(',')] 207 | 208 | # TODO(sguada) variables.filter_variables() 209 | variables_to_restore = [] 210 | for var in slim.get_model_variables(): 211 | excluded = False 212 | for exclusion in exclusions: 213 | if var.op.name.startswith(exclusion): 214 | excluded = True 215 | break 216 | if not excluded: 217 | variables_to_restore.append(var) 218 | # Change model scope if necessary. 219 | if flags.checkpoint_model_scope is not None: 220 | variables_to_restore = \ 221 | {var.op.name.replace(flags.model_name, 222 | flags.checkpoint_model_scope): var 223 | for var in variables_to_restore} 224 | 225 | 226 | if tf.gfile.IsDirectory(flags.checkpoint_path): 227 | checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path) 228 | else: 229 | checkpoint_path = flags.checkpoint_path 230 | tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars)) 231 | 232 | return slim.assign_from_checkpoint_fn( 233 | checkpoint_path, 234 | variables_to_restore, 235 | ignore_missing_vars=flags.ignore_missing_vars) 236 | 237 | 238 | def get_variables_to_train(flags): 239 | """Returns a list of variables to train. 240 | 241 | Returns: 242 | A list of variables to train by the optimizer. 
243 | """ 244 | if flags.trainable_scopes is None: 245 | return tf.trainable_variables() 246 | else: 247 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')] 248 | 249 | variables_to_train = [] 250 | for scope in scopes: 251 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 252 | variables_to_train.extend(variables) 253 | return variables_to_train 254 | 255 | 256 | # =========================================================================== # 257 | # Evaluation utils. 258 | # =========================================================================== # 259 | -------------------------------------------------------------------------------- /spaceshooter/assets/bolt_gold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/bolt_gold.png -------------------------------------------------------------------------------- /spaceshooter/assets/laserRed16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/laserRed16.png -------------------------------------------------------------------------------- /spaceshooter/assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/main.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_big1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_big1.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_big2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_big2.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_med1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_med1.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_med3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_med3.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_small1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_small1.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_small2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_small2.png -------------------------------------------------------------------------------- /spaceshooter/assets/meteorBrown_tiny1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/meteorBrown_tiny1.png -------------------------------------------------------------------------------- /spaceshooter/assets/missile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/missile.png -------------------------------------------------------------------------------- /spaceshooter/assets/playerShip1_orange.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/playerShip1_orange.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion00.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion01.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion01.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion02.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion03.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion04.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion05.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion06.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion06.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion07.png -------------------------------------------------------------------------------- /spaceshooter/assets/regularExplosion08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/regularExplosion08.png -------------------------------------------------------------------------------- /spaceshooter/assets/shield_gold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/shield_gold.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion00.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion01.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion01.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion02.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion03.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion04.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion05.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion06.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion06.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion07.png -------------------------------------------------------------------------------- /spaceshooter/assets/sonicExplosion08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/sonicExplosion08.png -------------------------------------------------------------------------------- /spaceshooter/assets/starfield.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/assets/starfield.png -------------------------------------------------------------------------------- /spaceshooter/detection_and_tracking.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import sys 4 | from os.path import realpath, dirname, join 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow as tf 9 | import torch 10 | 11 | from net import SiamRPNvot 12 | from nets import ssd_vgg_300, np_methods 13 | from preprocessing import ssd_vgg_preprocessing 14 | from run_SiamRPN import SiamRPN_init, SiamRPN_track 15 | from utils import cxy_wh_2_rect 16 | 17 | 18 | def get_object_center(q, detect_class): 19 | 20 | # classes: 21 | # 1.Aeroplanes 2.Bicycles 
3.Birds 4.Boats 5.Bottles 22 | # 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows 23 | # 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People 24 | # 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors 25 | 26 | slim = tf.contrib.slim 27 | 28 | # TensorFlow session: grow memory when needed. TF, DO NOT USE ALL MY GPU MEMORY!!! 29 | gpu_options = tf.GPUOptions(allow_growth=True) 30 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options) 31 | isess = tf.InteractiveSession(config=config) 32 | 33 | # Input placeholder. 34 | net_shape = (300, 300) 35 | data_format = 'NHWC' 36 | img_input = tf.placeholder(tf.uint8, shape=(None, None, 3)) 37 | # Evaluation pre-processing: resize to SSD net shape. 38 | image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval( 39 | img_input, None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE) 40 | image_4d = tf.expand_dims(image_pre, 0) 41 | 42 | # Define the SSD model. 43 | reuse = True if 'ssd_net' in locals() else None 44 | ssd_net = ssd_vgg_300.SSDNet() 45 | with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)): 46 | predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse) 47 | 48 | # Restore SSD model. 49 | # ckpt_filename = 'checkpoints/ssd_300_vgg.ckpt' 50 | ckpt_filename = '../SSD-Tensorflow/checkpoints/VGG_VOC0712_SSD_300x300_ft_iter_120000.ckpt' 51 | 52 | isess.run(tf.global_variables_initializer()) 53 | saver = tf.train.Saver() 54 | saver.restore(isess, ckpt_filename) 55 | 56 | # SSD default anchor boxes. 57 | ssd_anchors = ssd_net.anchors(net_shape) 58 | 59 | # Main image processing routine. 60 | def process_image(img, select_threshold=0.5, nms_threshold=.45, net_shape=(300, 300)): 61 | # Run SSD network. 
62 | rimg, rpredictions, rlocalisations, rbbox_img = isess.run([image_4d, predictions, localisations, bbox_img], 63 | feed_dict={img_input: img}) 64 | 65 | # Get classes and bboxes from the net outputs. 66 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select( 67 | rpredictions, rlocalisations, ssd_anchors, 68 | select_threshold=select_threshold, img_shape=net_shape, num_classes=21, decode=True) 69 | 70 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes) 71 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400) 72 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold) 73 | # Resize bboxes to original image shape. Note: useless for Resize.WARP! 74 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes) 75 | return rclasses, rscores, rbboxes 76 | 77 | def get_bboxes(rclasses, rbboxes): 78 | # Collect each detection's class and normalized box corners into a dict 79 | 80 | number_classes = rclasses.shape[0] 81 | object_bboxes = [] 82 | for i in range(number_classes): 83 | object_bbox = dict() 84 | object_bbox['i'] = i 85 | object_bbox['class'] = rclasses[i] 86 | object_bbox['y_min'] = rbboxes[i, 0] 87 | object_bbox['x_min'] = rbboxes[i, 1] 88 | object_bbox['y_max'] = rbboxes[i, 2] 89 | object_bbox['x_max'] = rbboxes[i, 3] 90 | object_bboxes.append(object_bbox) 91 | return object_bboxes 92 | 93 | # load net 94 | net = SiamRPNvot() 95 | net.load_state_dict(torch.load(join(realpath(dirname(__file__)), '../DaSiamRPN-master/code/SiamRPNVOT.model'))) 96 | 97 | net.eval() 98 | 99 | # open video capture 100 | video = cv2.VideoCapture(0) 101 | 102 | if not video.isOpened(): 103 | print("Could not open video") 104 | sys.exit() 105 | 106 | index = True 107 | while index: 108 | 109 | # Read first frame.
110 | ok, frame = video.read() 111 | if not ok: 112 | print('Cannot read video file') 113 | sys.exit() 114 | 115 | # Define an initial bounding box 116 | height = frame.shape[0] 117 | width = frame.shape[1] 118 | 119 | rclasses, rscores, rbboxes = process_image(frame) 120 | 121 | bboxes = get_bboxes(rclasses, rbboxes) 122 | for bbox in bboxes: 123 | if bbox['class'] == detect_class: 124 | print(bbox) 125 | ymin = int(bbox['y_min'] * height) 126 | xmin = int(bbox['x_min'] * width) 127 | ymax = int(bbox['y_max'] * height) 128 | xmax = int(bbox['x_max'] * width) 129 | cx = (xmin + xmax) / 2 130 | cy = (ymin + ymax) / 2 131 | h = ymax - ymin 132 | w = xmax - xmin 133 | new_bbox = (cx, cy, w, h) 134 | print(new_bbox) 135 | index = False 136 | break 137 | 138 | # tracker init 139 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) 140 | state = SiamRPN_init(frame, target_pos, target_sz, net) 141 | 142 | # tracking and visualization 143 | toc = 0 144 | count_number = 0 145 | 146 | while True: 147 | 148 | # Read a new frame 149 | ok, frame = video.read() 150 | if not ok: 151 | break 152 | 153 | # Start timer 154 | tic = cv2.getTickCount() 155 | 156 | # Update tracker 157 | state = SiamRPN_track(state, frame) # track 158 | # print(state) 159 | 160 | toc += cv2.getTickCount() - tic 161 | 162 | if state: 163 | 164 | res = cxy_wh_2_rect(state['target_pos'], state['target_sz']) 165 | res = [int(l) for l in res] 166 | 167 | cv2.rectangle(frame, (res[0], res[1]), (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3) 168 | 169 | count_number += 1 170 | # Send the normalized object center to the game process 171 | object_center = dict() 172 | object_center['x'] = state['target_pos'][0] / width 173 | object_center['y'] = state['target_pos'][1] / height 174 | q.put(object_center) 175 | 176 | if (not state) or count_number % 40 == 3: 177 | # Tracking failure, or periodic SSD re-detection (once every 40 tracked frames) 178 | cv2.putText(frame, "Tracking failure detected", (100, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 179 | 2) 180 | index = True 181 | while
index: 182 | ok, frame = video.read() 183 | rclasses, rscores, rbboxes = process_image(frame) 184 | bboxes = get_bboxes(rclasses, rbboxes) 185 | for bbox in bboxes: 186 | if bbox['class'] == detect_class: 187 | ymin = int(bbox['y_min'] * height) 188 | xmin = int(bbox['x_min'] * width) 189 | ymax = int(bbox['y_max'] * height) 190 | xmax = int(bbox['x_max'] * width) 191 | cx = (xmin + xmax) / 2 192 | cy = (ymin + ymax) / 2 193 | h = ymax - ymin 194 | w = xmax - xmin 195 | new_bbox = (cx, cy, w, h) 196 | target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) 197 | state = SiamRPN_init(frame, target_pos, target_sz, net) 198 | 199 | p1 = (int(xmin), int(ymin)) 200 | p2 = (int(xmax), int(ymax)) 201 | cv2.rectangle(frame, p1, p2, (0, 255, 0), 2, 1) 202 | 203 | index = False 204 | 205 | break 206 | 207 | # Resize the frame 208 | resized_frame = cv2.resize(frame, None, fx=0.65, fy=0.65, interpolation=cv2.INTER_AREA) 209 | # Flip the frame horizontally (mirror view) 210 | horizontal = cv2.flip(resized_frame, 1, dst=None) 211 | 212 | # Show the frame 213 | cv2.namedWindow("SSD+SiamRPN", cv2.WINDOW_NORMAL) 214 | cv2.imshow('SSD+SiamRPN', horizontal) 215 | 216 | # Exit if ESC pressed 217 | k = cv2.waitKey(1) & 0xff 218 | if k == 27: 219 | break 220 | 221 | video.release() 222 | cv2.destroyAllWindows() 223 | -------------------------------------------------------------------------------- /spaceshooter/detection_tracking_game.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import multiprocessing 4 | 5 | import sys 6 | 7 | from detection_and_tracking import get_object_center 8 | 9 | from simulate_game import space_shooter 10 | 11 | sys.path.append("..") 12 | 13 | q = multiprocessing.Queue() 14 | 15 | # classes: 16 | # 1.Aeroplanes 2.Bicycles 3.Birds 4.Boats 5.Bottles 17 | # 6.Buses 7.Cars 8.Cats 9.Chairs 10.Cows 18 | # 11.Dining tables 12.Dogs 13.Horses 14.Motorbikes 15.People 19 | # 16.Potted plants 17.Sheep 18.Sofas 19.Trains 20.TV/Monitors 20
| 21 | # Class to detect and track (5 = Bottles) 22 | OBJECT_CLASS = 5 23 | 24 | set_process = multiprocessing.Process(target=get_object_center, args=(q, OBJECT_CLASS,)) 25 | game_process = multiprocessing.Process(target=space_shooter, args=(q,)) 26 | 27 | set_process.start() 28 | game_process.start() 29 | 30 | game_process.join() 31 | set_process.terminate() 32 | 33 | print("Exiting main thread") 34 | -------------------------------------------------------------------------------- /spaceshooter/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /spaceshooter/sounds/expl3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/expl3.wav -------------------------------------------------------------------------------- /spaceshooter/sounds/expl6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/expl6.wav -------------------------------------------------------------------------------- /spaceshooter/sounds/getready.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/getready.ogg -------------------------------------------------------------------------------- /spaceshooter/sounds/menu.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/menu.ogg
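`detection_tracking_game.py` runs the detector/tracker and the game as two processes connected by a `multiprocessing.Queue`: `get_object_center` puts a dict with normalized `x`/`y` centers on the queue for each tracked frame, and `space_shooter` (in `simulate_game.py`, not shown in this section) reads from it. The sketch below isolates that producer/consumer pattern so it can run without a camera or model weights. `fake_tracker` and `drain` are hypothetical stand-ins, not repository functions, and the `None` sentinel is an assumption; the original instead calls `set_process.terminate()` once the game process exits. The sketch also creates its process under `if __name__ == "__main__":`, which the original script does not; that guard is needed where the start method is `spawn` (the default on Windows and macOS), because each child re-imports the main module.

```python
import multiprocessing
import queue


def fake_tracker(q, n_frames):
    """Hypothetical stand-in for get_object_center: emits normalized centers."""
    for i in range(n_frames):
        # Same shape as the dict built from state['target_pos'] / (width, height).
        q.put({'x': i / n_frames, 'y': 0.5})
    q.put(None)  # assumed sentinel: tells the consumer no more data is coming


def drain(q, timeout=5.0):
    """Hypothetical consumer: collect centers until the sentinel or a timeout."""
    centers = []
    while True:
        try:
            item = q.get(timeout=timeout)
        except queue.Empty:
            break  # producer stalled or died; stop waiting
        if item is None:
            break
        centers.append(item)
    return centers


if __name__ == "__main__":
    q = multiprocessing.Queue()
    producer = multiprocessing.Process(target=fake_tracker, args=(q, 10))
    producer.start()
    # Consume before join(): a process that has put items on a queue may not
    # exit until those items have been flushed to the underlying pipe.
    centers = drain(q)
    producer.join()
    print(len(centers), centers[-1])
```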
-------------------------------------------------------------------------------- /spaceshooter/sounds/pew.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/pew.wav -------------------------------------------------------------------------------- /spaceshooter/sounds/rocket.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/rocket.ogg -------------------------------------------------------------------------------- /spaceshooter/sounds/rumble1.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/rumble1.ogg -------------------------------------------------------------------------------- /spaceshooter/sounds/tgfcoder-FrozenJam-SeamlessLoop.ogg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnWang-AI/Real-Time-Object-Detection-and-Tracking/9b2bce8b9151fc589c09ff07518221e8161bd506/spaceshooter/sounds/tgfcoder-FrozenJam-SeamlessLoop.ogg --------------------------------------------------------------------------------
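The box arithmetic in `detection_and_tracking.py` (used for the first detection and again for re-detection) converts an SSD box, which `get_bboxes` stores as normalized `[y_min, x_min, y_max, x_max]`, into the pixel center and size `(cx, cy, w, h)` that is passed to `SiamRPN_init`, and later reports the tracked center divided by the frame width and height. The sketch below restates that arithmetic as small pure functions so it can be checked without a camera or model weights. All helper names are hypothetical, not part of the repository, and `needs_redetect` simply mirrors the condition `(not state) or count_number % 40 == 3`.

```python
import numpy as np


def first_box_of_class(rclasses, rbboxes, detect_class):
    """Return the first normalized box [ymin, xmin, ymax, xmax] of detect_class, or None."""
    for cls, box in zip(rclasses, rbboxes):
        if cls == detect_class:
            return box
    return None


def ssd_box_to_cxywh(box, width, height):
    """Normalized SSD box -> pixel (cx, cy, w, h), truncating corners with int()."""
    ymin = int(box[0] * height)
    xmin = int(box[1] * width)
    ymax = int(box[2] * height)
    xmax = int(box[3] * width)
    return (xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin


def normalized_center(cx, cy, width, height):
    """Pixel center -> the {'x', 'y'} dict put on the queue (values in [0, 1])."""
    return {'x': cx / width, 'y': cy / height}


def needs_redetect(state_ok, count_number, period=40, offset=3):
    """Re-run SSD on tracking failure, or once every `period` tracked frames."""
    return (not state_ok) or count_number % period == offset


if __name__ == "__main__":
    # Made-up detections on a 640x480 frame: a person (15) and a bottle (5).
    rclasses = np.array([15, 5])
    rbboxes = np.array([[0.0, 0.0, 0.1, 0.1], [0.25, 0.5, 0.75, 1.0]])
    box = first_box_of_class(rclasses, rbboxes, 5)
    cx, cy, w, h = ssd_box_to_cxywh(box, 640, 480)
    print((cx, cy, w, h), normalized_center(cx, cy, 640, 480))
    # prints (480.0, 240.0, 320, 240) {'x': 0.75, 'y': 0.5}
```

As in the original, the corners are truncated with `int()` before the center is computed, so `cx` and `cy` can land on half-pixels; the sketch keeps that behavior rather than rounding.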