├── .gitignore
├── LICENSE
├── README.md
├── convert_to_voc.py
├── data
    ├── __init__.py
    ├── config.py
    ├── data_augment.py
    ├── video
    │   ├── CARDS_OFFICE_H_T_frame_1085.jpg
    │   ├── hand.avi
    │   └── saveVideo.gif
    └── wider_voc.py
├── egohands_dataset_clean.py
├── layers
    ├── __init__.py
    ├── functions
    │   └── prior_box.py
    └── modules
    │   ├── __init__.py
    │   └── multibox_loss.py
├── make.sh
├── models
    ├── __init__.py
    └── faceboxes.py
├── prepare_data.sh
├── test.py
├── train.py
├── utils
    ├── __init__.py
    ├── box_utils.py
    ├── build.py
    ├── nms
    │   ├── __init__.py
    │   ├── cpu_nms.c
    │   ├── cpu_nms.pyx
    │   ├── gpu_nms.cpp
    │   ├── gpu_nms.hpp
    │   ├── gpu_nms.pyx
    │   ├── nms_kernel.cu
    │   └── py_cpu_nms.py
    ├── nms_wrapper.py
    └── timer.py
└── xml2dict.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # cython generated cpp
107 | .vscode
108 | .idea
109 |
110 | # data
111 | data/Hand/*
112 | weights/
113 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 zll
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hand-detection.PyTorch 2 | Hand detection in PyTorch 3 | 4 |

5 | ![demo](data/video/saveVideo.gif)
6 |
9 |
10 | ### Contents
11 | - [Installation](#installation)
12 | - [Training](#training)
13 | - [Demo](#demo)
14 | - [References](#references)
15 |
16 | ## Installation
17 | 1. Install [PyTorch-0.4.0](https://pytorch.org/) according to your environment.
18 |
19 | 2. Clone this repository. We will refer to the cloned directory as `$HandBoxes_ROOT`.
20 | ```Shell
21 | git clone https://github.com/zllrunning/hand-detection.PyTorch.git
22 | ```
23 |
24 | 3. Compile the NMS module:
25 | ```Shell
26 | ./make.sh
27 | ```
28 |
29 | _Note: We currently only support PyTorch-0.4.0 and Python 3+._
30 |
31 | ## Training
32 |
33 | 1. Prepare the training data:
34 | ```
35 | -- download the EgoHands dataset
36 | -- generate bounding boxes and visualize them to ensure correctness
37 | -- convert the bbox files to VOC format
38 | ```
39 |
40 | ```Shell
41 | cd $HandBoxes_ROOT/
42 | sh prepare_data.sh
43 | ```
44 |
45 | 2. Train the model on the EgoHands dataset:
46 | ```Shell
47 | python3 train.py
48 | ```
49 |
50 | If you do not wish to train the model, you can download [our pre-trained model](https://drive.google.com/open?id=1eFSwZoSfVVroAy7LiGYybW6F8ErshoZW) and save it in `$HandBoxes_ROOT/weights`.
51 |
52 |
53 | ## Demo
54 | 1. Evaluate the trained model:
55 | ```Shell
56 | # evaluate using GPU
57 | python test.py --video data/video/hand.avi
58 | # evaluate using CPU
59 | python test.py --image data/video/CARDS_OFFICE_H_T_frame_1085.jpg --cpu
60 | ```
61 |
62 | ## References
63 | This project is based on [FaceBoxes.PyTorch](https://github.com/zisianw/FaceBoxes.PyTorch).
64 | - [handtracking](https://github.com/victordibia/handtracking)
65 | - [od-annotation](https://github.com/hzylmf/od-annotation)
66 |
--------------------------------------------------------------------------------
/convert_to_voc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | import codecs
3 | import hashlib
4 | import traceback
5 | import os
6 | import json
7 | import random
8 | import xml2dict
9 | import pandas as pd
10 |
11 |
12 | def convert_to_voc2007(file_path_1='annotation/annotation1.txt', file_path_2='annotation/annotation2.txt'):
13 |     """Convert the annotation data to VOC2007 format."""
14 |     # with codecs.open(file_path, mode='r', encoding='utf-8') as file:
15 |     #     lines = file.readlines()
16 |     df_1 = pd.read_csv(file_path_1)
17 |     df_2 = pd.read_csv(file_path_2)
18 |     df = pd.concat([df_1, df_2], axis=0)
19 |     # lines = df.iterrows()
20 |     annotations = dict()
21 |     for index, line in df.iterrows():
22 |         # if line.strip()=='':continue
23 |         # values = line.strip().split(',')
24 |         name = line['filename']
25 |         type = line['class']
26 |         object = dict()
27 |         object['name'] = type
28 |         object['pose'] = 'Unspecified'
29 |         object['truncated'] = 0
30 |         object['difficult'] = 0
31 |         object['bndbox'] = dict()
32 |         object['bndbox']['xmin'] = line['xmin']
33 |         object['bndbox']['ymin'] = line['ymin']
34 |         object['bndbox']['xmax'] = line['xmax']
35 |         object['bndbox']['ymax'] = line['ymax']
36 |         if name not in annotations:
37 |             annotation = dict()
38 |             annotation['folder'] = 'VOC2007'
39 |             annotation['filename'] = name
40 |             annotation['size'] = dict()
41 |             annotation['size']['width'] = line['width']  # if samples are not a uniform size, read the real image size here
42 |             annotation['size']['height'] = line['height']  # if samples are not a uniform size, read the real image size here
43 |             annotation['size']['depth'] = 3
44 |             annotation['segmented'] = 0
45 |             annotation['object'] = [object]
46 |             annotations[name] = annotation
47 |         else:
48 |             annotation = annotations[name]
49 |             annotation['object'].append(object)
50 |
names = [] 51 | path = 'annotation/VOC2007/' 52 | if not os.path.exists(path+'Annotations'): 53 | os.makedirs(path+'Annotations') 54 | for annotation in annotations.items(): 55 | filename = annotation[0].split('.')[0] 56 | names.append(filename) 57 | dic = {'annotation':annotation[1]} 58 | convertedXml = xml2dict.unparse(dic) 59 | xml_nohead = convertedXml.split('\n')[1] 60 | file = codecs.open(path + 'Annotations/'+filename + '.xml', mode='w', encoding='utf-8') 61 | file.write(xml_nohead) 62 | file.close() 63 | random.shuffle(names) 64 | if not os.path.exists(path+'ImageSets'): 65 | os.mkdir(path+'ImageSets') 66 | if not os.path.exists(path+'ImageSets/Main'): 67 | os.mkdir(path+'ImageSets/Main') 68 | file_train = codecs.open(path+'ImageSets/Main/train.txt',mode='w',encoding='utf-8') 69 | file_test = codecs.open(path + 'ImageSets/Main/test.txt', mode='w', encoding='utf-8') 70 | file_train_val = codecs.open(path + 'ImageSets/Main/trainval.txt', mode='w', encoding='utf-8') 71 | file_val = codecs.open(path + 'ImageSets/Main/val.txt', mode='w', encoding='utf-8') 72 | count = len(names) 73 | count_1 = 0.25 * count 74 | count_2 = 0.9 * count 75 | for i in range(count): 76 | if i < count_1: 77 | file_train_val.write(names[i]+'\n') 78 | file_train.write(names[i] + '\n') 79 | elif count_1 <= i = 1) 32 | if not flag.any(): 33 | continue 34 | 35 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2 36 | mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1) 37 | boxes_t = boxes[mask_a].copy() 38 | labels_t = labels[mask_a].copy() 39 | 40 | # ignore tiny faces 41 | b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim 42 | b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim 43 | mask_b = np.minimum(b_w_t, b_h_t) > 16.0 44 | boxes_t = boxes_t[mask_b] 45 | labels_t = labels_t[mask_b] 46 | 47 | if boxes_t.shape[0] == 0: 48 | continue 49 | 50 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]] 51 | 52 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) 53 | boxes_t[:, :2] -= roi[:2] 54 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) 55 | boxes_t[:, 2:] -= roi[:2] 56 | 57 | pad_image_flag = False 58 | 59 | return image_t, boxes_t, labels_t, pad_image_flag 60 | return image, boxes, labels, pad_image_flag 61 | 62 | 63 | def _distort(image): 64 | 65 | def _convert(image, alpha=1, beta=0): 66 | tmp = image.astype(float) * alpha + beta 67 | tmp[tmp < 0] = 0 68 | tmp[tmp > 255] = 255 69 | image[:] = tmp 70 | 71 | image = image.copy() 72 | 73 | if random.randrange(2): 74 | 75 | #brightness distortion 76 | if random.randrange(2): 77 | _convert(image, beta=random.uniform(-32, 32)) 78 | 79 | #contrast distortion 80 | if random.randrange(2): 81 | _convert(image, alpha=random.uniform(0.5, 1.5)) 82 | 83 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 84 | 85 | #saturation distortion 86 | if random.randrange(2): 87 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 88 | 89 | #hue distortion 90 | if random.randrange(2): 91 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 92 | tmp %= 180 93 | image[:, :, 0] = tmp 94 | 95 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 96 | 97 | else: 98 | 99 | #brightness distortion 100 | if random.randrange(2): 101 | _convert(image, beta=random.uniform(-32, 32)) 102 | 103 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 104 | 105 | #saturation distortion 106 | if random.randrange(2): 107 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 108 | 109 | #hue distortion 110 | if random.randrange(2): 111 | tmp = image[:, :, 
0].astype(int) + random.randint(-18, 18) 112 | tmp %= 180 113 | image[:, :, 0] = tmp 114 | 115 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 116 | 117 | #contrast distortion 118 | if random.randrange(2): 119 | _convert(image, alpha=random.uniform(0.5, 1.5)) 120 | 121 | return image 122 | 123 | 124 | def _expand(image, boxes, fill, p): 125 | if random.randrange(2): 126 | return image, boxes 127 | 128 | height, width, depth = image.shape 129 | 130 | scale = random.uniform(1, p) 131 | w = int(scale * width) 132 | h = int(scale * height) 133 | 134 | left = random.randint(0, w - width) 135 | top = random.randint(0, h - height) 136 | 137 | boxes_t = boxes.copy() 138 | boxes_t[:, :2] += (left, top) 139 | boxes_t[:, 2:] += (left, top) 140 | expand_image = np.empty( 141 | (h, w, depth), 142 | dtype=image.dtype) 143 | expand_image[:, :] = fill 144 | expand_image[top:top + height, left:left + width] = image 145 | image = expand_image 146 | 147 | return image, boxes_t 148 | 149 | 150 | def _mirror(image, boxes): 151 | _, width, _ = image.shape 152 | if random.randrange(2): 153 | image = image[:, ::-1] 154 | boxes = boxes.copy() 155 | boxes[:, 0::2] = width - boxes[:, 2::-2] 156 | return image, boxes 157 | 158 | 159 | def _pad_to_square(image, rgb_mean, pad_image_flag): 160 | if not pad_image_flag: 161 | return image 162 | height, width, _ = image.shape 163 | long_side = max(width, height) 164 | image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) 165 | image_t[:, :] = rgb_mean 166 | image_t[0:0 + height, 0:0 + width] = image 167 | return image_t 168 | 169 | 170 | def _resize_subtract_mean(image, insize, rgb_mean): 171 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] 172 | interp_method = interp_methods[random.randrange(5)] 173 | image = cv2.resize(image, (insize, insize), interpolation=interp_method) 174 | image = image.astype(np.float32) 175 | image -= rgb_mean 176 | return image.transpose(2, 0, 1) 177 | 178 | 179 | class preproc(object): 180 | 181 | def __init__(self, img_dim, rgb_means): 182 | self.img_dim = img_dim 183 | self.rgb_means = rgb_means 184 | 185 | def __call__(self, image, targets): 186 | image = image.astype(np.float32) 187 | assert targets.shape[0] > 0, "this image does not have gt" 188 | 189 | boxes = targets[:, :-1].copy() 190 | labels = targets[:, -1].copy() 191 | 192 | #image_t = _distort(image) 193 | #image_t, boxes_t = _expand(image_t, boxes, self.cfg['rgb_mean'], self.cfg['max_expand_ratio']) 194 | #image_t, boxes_t, labels_t = _crop(image_t, boxes, labels, self.img_dim, self.rgb_means) 195 | image_t, boxes_t, labels_t, pad_image_flag = _crop(image, boxes, labels, self.img_dim) 196 | image_t = _distort(image_t) 197 | image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag) 198 | image_t, boxes_t = _mirror(image_t, boxes_t) 199 | height, width, _ = image_t.shape 200 | image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) 201 | boxes_t[:, 0::2] /= width 202 | boxes_t[:, 1::2] /= height 203 | 204 | labels_t = np.expand_dims(labels_t, 1) 205 | targets_t = np.hstack((boxes_t, labels_t)) 206 | 207 | return image_t, targets_t 208 | -------------------------------------------------------------------------------- /data/video/CARDS_OFFICE_H_T_frame_1085.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zllrunning/hand-detection.PyTorch/ed1398d9e31bd02e879688045692124460382109/data/video/CARDS_OFFICE_H_T_frame_1085.jpg 
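To make the `preproc` pipeline in `data/data_augment.py` above concrete, here is a minimal usage sketch (not part of the repo): it feeds a dummy BGR frame and a single `[xmin, ymin, xmax, ymax, label]` box through the transform, using the same `img_dim` and `rgb_means` values as `train.py`, which imports `preproc` from the `data` package. The image and box values below are made up for illustration.

```Python
# Minimal sketch of invoking preproc; assumes the same `from data import preproc`
# that train.py uses. All concrete values here are dummies for illustration.
import numpy as np
from data import preproc

img_dim = 1024                # network input size, as in train.py
rgb_means = (104, 117, 123)   # BGR channel means subtracted from the image, as in train.py

image = np.random.randint(0, 256, (720, 1280, 3), dtype=np.uint8)  # fake BGR frame
targets = np.array([[300.0, 200.0, 520.0, 480.0, 1.0]])            # one 'hand' box + class label

transform = preproc(img_dim, rgb_means)
image_t, targets_t = transform(image, targets)

print(image_t.shape)   # (3, 1024, 1024): CHW, float32, mean-subtracted
print(targets_t)       # boxes rescaled to [0, 1] coordinates with the label column appended
```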
--------------------------------------------------------------------------------
/data/video/hand.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zllrunning/hand-detection.PyTorch/ed1398d9e31bd02e879688045692124460382109/data/video/hand.avi
--------------------------------------------------------------------------------
/data/video/saveVideo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zllrunning/hand-detection.PyTorch/ed1398d9e31bd02e879688045692124460382109/data/video/saveVideo.gif
--------------------------------------------------------------------------------
/data/wider_voc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import sys
4 | import torch
5 | import torch.utils.data as data
6 | import cv2
7 | import numpy as np
8 | if sys.version_info[0] == 2:
9 |     import xml.etree.cElementTree as ET
10 | else:
11 |     import xml.etree.ElementTree as ET
12 |
13 |
14 | WIDER_CLASSES = ('__background__', 'hand')
15 |
16 |
17 | class AnnotationTransform(object):
18 |
19 |     """Transforms a VOC annotation into a Tensor of bbox coords and label index
20 |     Initialized with a dictionary lookup of classnames to indexes
21 |
22 |     Arguments:
23 |         class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
24 |             (default: indexing of WIDER_CLASSES: background, hand)
25 |         keep_difficult (bool, optional): keep difficult instances or not
26 |             (default: True)
27 |         height (int): height
28 |         width (int): width
29 |     """
30 |
31 |     def __init__(self, class_to_ind=None, keep_difficult=True):
32 |         self.class_to_ind = class_to_ind or dict(
33 |             zip(WIDER_CLASSES, range(len(WIDER_CLASSES))))
34 |         self.keep_difficult = keep_difficult
35 |
36 |     def __call__(self, target):
37 |         """
38 |         Arguments:
39 |             target (annotation) : the target annotation to be made usable
40 |                 will be an ET.Element
41 |         Returns:
42 |             a list containing lists of bounding boxes  [bbox coords, class name]
43 |         """
44 |         res = np.empty((0, 5))
45 |         for obj in target.iter('object'):
46 |             difficult = int(obj.find('difficult').text) == 1
47 |             if not self.keep_difficult and difficult:
48 |                 continue
49 |             name = obj.find('name').text.lower().strip()
50 |             bbox = obj.find('bndbox')
51 |
52 |             pts = ['xmin', 'ymin', 'xmax', 'ymax']
53 |             bndbox = []
54 |             for i, pt in enumerate(pts):
55 |                 cur_pt = int(bbox.find(pt).text)
56 |                 bndbox.append(cur_pt)
57 |             label_idx = self.class_to_ind[name]
58 |             bndbox.append(label_idx)
59 |             res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
60 |         return res
61 |
62 |
63 | class VOCDetection(data.Dataset):
64 |
65 |     """VOC Detection Dataset Object
66 |
67 |     input is image, target is annotation
68 |
69 |     Arguments:
70 |         root (string): filepath to the dataset root folder (e.g. data/Hand)
71 |         target_transform (callable, optional): transformation to perform on the
72 |             target `annotation`
73 |             (eg: take in caption string, return tensor of word indices)
74 |     """
75 |
76 |     def __init__(self, root, preproc=None, target_transform=None):
77 |         self.root = root
78 |         self.preproc = preproc
79 |         self.target_transform = target_transform
80 |         self._annopath = os.path.join(self.root, 'Annotations', '%s.xml')
81 |         self._imgpath = os.path.join(self.root, 'images', '%s.jpg')
82 |         self.ids = list()
83 |         with open(os.path.join(self.root, 'ImageSets/Main/trainval.txt'), 'r') as f:
84 |             self.ids = [tuple(line.split()) for line in f]
85 |
86 |
def __getitem__(self, index): 87 | img_id = self.ids[index] 88 | target = ET.parse(self._annopath % img_id[0]).getroot() 89 | img = cv2.imread(self._imgpath % img_id[0], cv2.IMREAD_COLOR) 90 | height, width, _ = img.shape 91 | 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | if self.preproc is not None: 96 | img, target = self.preproc(img, target) 97 | 98 | return torch.from_numpy(img), target 99 | 100 | def __len__(self): 101 | return len(self.ids) 102 | 103 | 104 | def detection_collate(batch): 105 | """Custom collate fn for dealing with batches of images that have a different 106 | number of associated object annotations (bounding boxes). 107 | 108 | Arguments: 109 | batch: (tuple) A tuple of tensor images and lists of annotations 110 | 111 | Return: 112 | A tuple containing: 113 | 1) (tensor) batch of images stacked on their 0 dim 114 | 2) (list of tensors) annotations for a given image are stacked on 0 dim 115 | """ 116 | targets = [] 117 | imgs = [] 118 | for _, sample in enumerate(batch): 119 | for _, tup in enumerate(sample): 120 | if torch.is_tensor(tup): 121 | imgs.append(tup) 122 | elif isinstance(tup, type(np.empty(0))): 123 | annos = torch.from_numpy(tup).float() 124 | targets.append(annos) 125 | 126 | return (torch.stack(imgs, 0), targets) 127 | -------------------------------------------------------------------------------- /egohands_dataset_clean.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import numpy as np 3 | import os 4 | import gc 5 | import six.moves.urllib as urllib 6 | import cv2 7 | import time 8 | import xml.etree.cElementTree as ET 9 | import random 10 | import shutil as sh 11 | from shutil import copyfile 12 | import zipfile 13 | 14 | import csv 15 | 16 | 17 | def save_csv(csv_path, csv_content): 18 | with open(csv_path, 'w') as csvfile: 19 | wr = csv.writer(csvfile) 20 | for i in range(len(csv_content)): 21 | wr.writerow(csv_content[i]) 22 | 23 | 24 | def get_bbox_visualize(base_path, dir): 25 | image_path_array = [] 26 | for root, dirs, filenames in os.walk(base_path + dir): 27 | for f in filenames: 28 | if(f.split(".")[1] == "jpg"): 29 | img_path = base_path + dir + "/" + f 30 | image_path_array.append(img_path) 31 | 32 | #sort image_path_array to ensure its in the low to high order expected in polygon.mat 33 | image_path_array.sort() 34 | boxes = sio.loadmat(base_path + dir + "/polygons.mat") 35 | # there are 100 of these per folder in the egohands dataset 36 | polygons = boxes["polygons"][0] 37 | # first = polygons[0] 38 | # print(len(first)) 39 | pointindex = 0 40 | 41 | for first in polygons: 42 | index = 0 43 | 44 | font = cv2.FONT_HERSHEY_SIMPLEX 45 | 46 | img_id = image_path_array[pointindex] 47 | img = cv2.imread(img_id) 48 | 49 | img_params = {} 50 | img_params["width"] = np.size(img, 1) 51 | img_params["height"] = np.size(img, 0) 52 | head, tail = os.path.split(img_id) 53 | img_params["filename"] = tail 54 | img_params["path"] = os.path.abspath(img_id) 55 | img_params["type"] = "train" 56 | pointindex += 1 57 | 58 | boxarray = [] 59 | csvholder = [] 60 | for pointlist in first: 61 | pst = np.empty((0, 2), int) 62 | max_x = max_y = min_x = min_y = height = width = 0 63 | 64 | findex = 0 65 | for point in pointlist: 66 | if(len(point) == 2): 67 | x = int(point[0]) 68 | y = int(point[1]) 69 | 70 | if(findex == 0): 71 | min_x = x 72 | min_y = y 73 | findex += 1 74 | max_x = x if (x > max_x) else max_x 75 | min_x = x if (x < min_x) else 
min_x
76 |                     max_y = y if (y > max_y) else max_y
77 |                     min_y = y if (y < min_y) else min_y
78 |                     # print(index, "====", len(point))
79 |                     appeno = np.array([[x, y]])
80 |                     pst = np.append(pst, appeno, axis=0)
81 |                     cv2.putText(img, ".", (x, y), font, 0.7,
82 |                                 (255, 255, 255), 2, cv2.LINE_AA)
83 |
84 |             hold = {}
85 |             hold['minx'] = min_x
86 |             hold['miny'] = min_y
87 |             hold['maxx'] = max_x
88 |             hold['maxy'] = max_y
89 |             if (min_x > 0 and min_y > 0 and max_x > 0 and max_y > 0):
90 |                 boxarray.append(hold)
91 |                 labelrow = [tail,
92 |                             np.size(img, 1), np.size(img, 0), "hand", min_x, min_y, max_x, max_y]
93 |                 csvholder.append(labelrow)
94 |
95 |             cv2.polylines(img, [pst], True, (0, 255, 255), 1)
96 |             cv2.rectangle(img, (min_x, max_y),
97 |                           (max_x, min_y), (0, 255, 0), 1)
98 |
99 |         csv_path = img_id.split(".")[0]
100 |         if not os.path.exists(csv_path + ".csv"):
101 |             cv2.putText(img, "DIR : " + dir + " - " + tail, (20, 50),
102 |                         cv2.FONT_HERSHEY_SIMPLEX, 0.75, (77, 255, 9), 2)
103 |             cv2.imshow('Verifying annotation ', img)
104 |             save_csv(csv_path + ".csv", csvholder)
105 |             print("===== saving csv file for ", tail)
106 |         cv2.waitKey(2)  # show the frame briefly (~2 ms) before moving on to the next image
107 |
108 |
109 | def create_directory(dir_path):
110 |     if not os.path.exists(dir_path):
111 |         os.makedirs(dir_path)
112 |
113 | # combine all individual csv files for each image into a single csv file per folder.
114 |
115 |
116 | def generate_label_files(image_dir):
117 |     header = ['filename', 'width', 'height',
118 |               'class', 'xmin', 'ymin', 'xmax', 'ymax']
119 |     for root, dirs, filenames in os.walk(image_dir):
120 |         for dir in dirs:
121 |             csvholder = []
122 |             csvholder.append(header)
123 |             loop_index = 0
124 |             for f in os.listdir(image_dir + dir):
125 |                 if(f.split(".")[1] == "csv"):
126 |                     loop_index += 1
127 |                     #print(loop_index, f)
128 |                     csv_file = open(image_dir + dir + "/" + f, 'r')
129 |                     reader = csv.reader(csv_file)
130 |                     for row in reader:
131 |                         csvholder.append(row)
132 |                     csv_file.close()
133 |                     os.remove(image_dir + dir + "/" + f)
134 |             save_csv(image_dir + dir + "/" + dir + "_labels.csv", csvholder)
135 |             print("Saved label csv for ", dir, image_dir +
136 |                   dir + "/" + dir + "_labels.csv")
137 |
138 |
139 | # Split data, copy to train/test folders
140 | def split_data_test_eval_train(image_dir):
141 |     create_directory("images")
142 |     create_directory("images/train")
143 |     create_directory("images/test")
144 |
145 |     data_size = 4000
146 |     loop_index = 0
147 |     data_sampsize = int(0.1 * data_size)
148 |     test_samp_array = random.sample(range(data_size), k=data_sampsize)
149 |
150 |     for root, dirs, filenames in os.walk(image_dir):
151 |         for dir in dirs:
152 |             for f in os.listdir(image_dir + dir):
153 |                 if(f.split(".")[1] == "jpg"):
154 |                     loop_index += 1
155 |                     print(loop_index, f)
156 |
157 |                     if loop_index in test_samp_array:
158 |                         os.rename(image_dir + dir +
159 |                                   "/" + f, "images/test/" + f)
160 |                         os.rename(image_dir + dir +
161 |                                   "/" + f.split(".")[0] + ".csv", "images/test/" + f.split(".")[0] + ".csv")
162 |                     else:
163 |                         os.rename(image_dir + dir +
164 |                                   "/" + f, "images/train/" + f)
165 |                         os.rename(image_dir + dir +
166 |                                   "/" + f.split(".")[0] + ".csv", "images/train/" + f.split(".")[0] + ".csv")
167 |                     print(loop_index, image_dir + f)
168 |             print("> done scanning directory ", dir)
169 |             os.remove(image_dir + dir + "/polygons.mat")
170 |             os.rmdir(image_dir + dir)
171 |
172 |     print("Train/test content generation complete!")
173 |     generate_label_files("images/")
174 |
175 |
176 | def generate_csv_files(image_dir):
177 |     for
root, dirs, filenames in os.walk(image_dir): 178 | for dir in dirs: 179 | get_bbox_visualize(image_dir, dir) 180 | 181 | print("CSV generation complete!\nGenerating train/test/eval folders") 182 | split_data_test_eval_train("egohands/_LABELLED_SAMPLES/") 183 | 184 | 185 | # rename image files so we can have them all in a train/test/eval folder. 186 | def rename_files(image_dir): 187 | print("Renaming files") 188 | loop_index = 0 189 | for root, dirs, filenames in os.walk(image_dir): 190 | for dir in dirs: 191 | for f in os.listdir(image_dir + dir): 192 | if (dir not in f): 193 | if(f.split(".")[1] == "jpg"): 194 | loop_index += 1 195 | os.rename(image_dir + dir + 196 | "/" + f, image_dir + dir + 197 | "/" + dir + "_" + f) 198 | else: 199 | break 200 | 201 | generate_csv_files("egohands/_LABELLED_SAMPLES/") 202 | 203 | def extract_folder(dataset_path): 204 | print("Egohands dataset already downloaded.\nGenerating CSV files") 205 | if not os.path.exists("egohands"): 206 | zip_ref = zipfile.ZipFile(dataset_path, 'r') 207 | print("> Extracting Dataset files") 208 | zip_ref.extractall("egohands") 209 | print("> Extraction complete") 210 | zip_ref.close() 211 | rename_files("egohands/_LABELLED_SAMPLES/") 212 | 213 | def download_egohands_dataset(dataset_url, dataset_path): 214 | is_downloaded = os.path.exists(dataset_path) 215 | if not is_downloaded: 216 | print( 217 | "> downloading egohands dataset. This may take a while (1.3GB, say 3-5mins). Coffee break?") 218 | opener = urllib.request.URLopener() 219 | opener.retrieve(dataset_url, dataset_path) 220 | print("> download complete") 221 | extract_folder(dataset_path); 222 | 223 | else: 224 | extract_folder(dataset_path) 225 | 226 | 227 | EGOHANDS_DATASET_URL = "http://vision.soic.indiana.edu/egohands_files/egohands_data.zip" 228 | EGO_HANDS_FILE = "egohands_data.zip" 229 | 230 | 231 | download_egohands_dataset(EGOHANDS_DATASET_URL, EGO_HANDS_FILE) 232 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import product as product 3 | import numpy as np 4 | 5 | 6 | class PriorBox(object): 7 | def __init__(self, cfg, box_dimension=None, image_size=None, phase='train'): 8 | super(PriorBox, self).__init__() 9 | self.variance = cfg['variance'] 10 | self.min_sizes = cfg['min_sizes'] 11 | self.steps = cfg['steps'] 12 | self.aspect_ratios = cfg['aspect_ratios'] 13 | self.clip = cfg['clip'] 14 | if phase == 'train': 15 | self.image_size = (cfg['min_dim'], cfg['min_dim']) 16 | self.feature_maps = cfg['feature_maps'] 17 | elif phase == 'test': 18 | self.feature_maps = box_dimension.cpu().numpy().astype(np.int) 19 | self.image_size = image_size 20 | for v in self.variance: 21 | if v <= 0: 22 | raise ValueError('Variances must be greater than 0') 23 | 24 | def forward(self): 25 | mean = [] 26 | for k, f in enumerate(self.feature_maps): 27 | min_sizes = self.min_sizes[k] 28 | for i, j in product(range(f[0]), range(f[1])): 29 | for min_size in min_sizes: 30 | s_kx = min_size / self.image_size[1] 31 | s_ky = min_size / self.image_size[0] 32 | if min_size == 32: 33 | dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.25, j+0.5, 
j+0.75]] 34 | dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.25, i+0.5, i+0.75]] 35 | for cy, cx in product(dense_cy, dense_cx): 36 | mean += [cx, cy, s_kx, s_ky] 37 | elif min_size == 64: 38 | dense_cx = [x*self.steps[k]/self.image_size[1] for x in [j+0, j+0.5]] 39 | dense_cy = [y*self.steps[k]/self.image_size[0] for y in [i+0, i+0.5]] 40 | for cy, cx in product(dense_cy, dense_cx): 41 | mean += [cx, cy, s_kx, s_ky] 42 | else: 43 | cx = (j + 0.5) * self.steps[k] / self.image_size[1] 44 | cy = (i + 0.5) * self.steps[k] / self.image_size[0] 45 | mean += [cx, cy, s_kx, s_ky] 46 | # back to torch land 47 | output = torch.Tensor(mean).view(-1, 4) 48 | if self.clip: 49 | output.clamp_(max=1, min=0) 50 | return output 51 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | 3 | __all__ = ['MultiBoxLoss'] 4 | -------------------------------------------------------------------------------- /layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from utils.box_utils import match, log_sum_exp 6 | from data import cfg 7 | GPU = cfg['gpu_train'] 8 | 9 | class MultiBoxLoss(nn.Module): 10 | """SSD Weighted Loss Function 11 | Compute Targets: 12 | 1) Produce Confidence Target Indices by matching ground truth boxes 13 | with (default) 'priorboxes' that have jaccard index > threshold parameter 14 | (default threshold: 0.5). 15 | 2) Produce localization target by 'encoding' variance into offsets of ground 16 | truth boxes and their matched 'priorboxes'. 17 | 3) Hard negative mining to filter the excessive number of negative examples 18 | that comes with using a large number of default bounding boxes. 19 | (default negative:positive ratio 3:1) 20 | Objective Loss: 21 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 22 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 23 | weighted by α which is set to 1 by cross val. 24 | Args: 25 | c: class confidences, 26 | l: predicted boxes, 27 | g: ground truth boxes 28 | N: number of matched default boxes 29 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 30 | """ 31 | 32 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target): 33 | super(MultiBoxLoss, self).__init__() 34 | self.num_classes = num_classes 35 | self.threshold = overlap_thresh 36 | self.background_label = bkg_label 37 | self.encode_target = encode_target 38 | self.use_prior_for_matching = prior_for_matching 39 | self.do_neg_mining = neg_mining 40 | self.negpos_ratio = neg_pos 41 | self.neg_overlap = neg_overlap 42 | self.variance = [0.1, 0.2] 43 | 44 | def forward(self, predictions, priors, targets): 45 | """Multibox Loss 46 | Args: 47 | predictions (tuple): A tuple containing loc preds, conf preds, 48 | and prior boxes from SSD net. 49 | conf shape: torch.size(batch_size,num_priors,num_classes) 50 | loc shape: torch.size(batch_size,num_priors,4) 51 | priors shape: torch.size(num_priors,4) 52 | 53 | ground_truth (tensor): Ground truth boxes and labels for a batch, 54 | shape: [batch_size,num_objs,5] (last idx is the label). 
55 | """ 56 | 57 | loc_data, conf_data, _ = predictions 58 | priors = priors 59 | num = loc_data.size(0) 60 | num_priors = (priors.size(0)) 61 | 62 | # match priors (default boxes) and ground truth boxes 63 | loc_t = torch.Tensor(num, num_priors, 4) 64 | conf_t = torch.LongTensor(num, num_priors) 65 | for idx in range(num): 66 | truths = targets[idx][:, :-1].data 67 | labels = targets[idx][:, -1].data 68 | defaults = priors.data 69 | match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx) 70 | if GPU: 71 | loc_t = loc_t.cuda() 72 | conf_t = conf_t.cuda() 73 | # wrap targets 74 | loc_t = Variable(loc_t, requires_grad=False) 75 | conf_t = Variable(conf_t, requires_grad=False) 76 | 77 | pos = conf_t > 0 78 | 79 | # Localization Loss (Smooth L1) 80 | # Shape: [batch,num_priors,4] 81 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 82 | loc_p = loc_data[pos_idx].view(-1, 4) 83 | loc_t = loc_t[pos_idx].view(-1, 4) 84 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) 85 | 86 | # Compute max conf across batch for hard negative mining 87 | batch_conf = conf_data.view(-1, self.num_classes) 88 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 89 | 90 | # Hard Negative Mining 91 | loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now 92 | loss_c = loss_c.view(num, -1) 93 | _, loss_idx = loss_c.sort(1, descending=True) 94 | _, idx_rank = loss_idx.sort(1) 95 | num_pos = pos.long().sum(1, keepdim=True) 96 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 97 | neg = idx_rank < num_neg.expand_as(idx_rank) 98 | 99 | # Confidence Loss Including Positive and Negative Examples 100 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 101 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 102 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) 103 | targets_weighted = conf_t[(pos+neg).gt(0)] 104 | loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) 105 | 106 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 107 | N = max(num_pos.data.sum().float(), 1) 108 | loss_l /= N 109 | loss_c /= N 110 | 111 | return loss_l, loss_c 112 | -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ./utils/ 3 | 4 | CUDA_PATH=/usr/local/cuda/ 5 | 6 | python3 build.py build_ext --inplace 7 | 8 | cd .. 
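`make.sh` compiles the Cython/CUDA NMS extensions under `utils/nms` via `utils/build.py`. Once that succeeds, the wrapper that `test.py` calls can also be used directly; a minimal, hedged sketch with made-up boxes (the `[xmin, ymin, xmax, ymax, score]` layout and keep-index return match how `test.py` uses it):

```Python
# Sketch of calling the compiled NMS the same way test.py does.
# The three boxes below are invented for illustration.
import numpy as np
from utils.nms_wrapper import nms

dets = np.array([
    [100, 100, 200, 200, 0.95],   # highest-scoring box, kept
    [105, 105, 205, 205, 0.80],   # overlaps the first heavily -> suppressed
    [400, 150, 480, 260, 0.60],   # disjoint box, kept
], dtype=np.float32)

keep = nms(dets, 0.2, force_cpu=True)   # 0.2 matches test.py's --nms_threshold default
print(dets[keep, :])                    # surviving detections
```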
9 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zllrunning/hand-detection.PyTorch/ed1398d9e31bd02e879688045692124460382109/models/__init__.py -------------------------------------------------------------------------------- /models/faceboxes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicConv2d(nn.Module): 7 | 8 | def __init__(self, in_channels, out_channels, **kwargs): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 11 | self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | x = self.bn(x) 16 | return F.relu(x, inplace=True) 17 | 18 | 19 | class Inception(nn.Module): 20 | 21 | def __init__(self): 22 | super(Inception, self).__init__() 23 | self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0) 24 | self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0) 25 | self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0) 26 | self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1) 27 | self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0) 28 | self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1) 29 | self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 30 | 31 | def forward(self, x): 32 | branch1x1 = self.branch1x1(x) 33 | 34 | branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 35 | branch1x1_2 = self.branch1x1_2(branch1x1_pool) 36 | 37 | branch3x3_reduce = self.branch3x3_reduce(x) 38 | branch3x3 = self.branch3x3(branch3x3_reduce) 39 | 40 | branch3x3_reduce_2 = self.branch3x3_reduce_2(x) 41 | branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2) 42 | branch3x3_3 = self.branch3x3_3(branch3x3_2) 43 | 44 | outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3] 45 | return torch.cat(outputs, 1) 46 | 47 | 48 | class CRelu(nn.Module): 49 | 50 | def __init__(self, in_channels, out_channels, **kwargs): 51 | super(CRelu, self).__init__() 52 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 53 | self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) 54 | 55 | def forward(self, x): 56 | x = self.conv(x) 57 | x = self.bn(x) 58 | x = torch.cat([x, -x], 1) 59 | x = F.relu(x, inplace=True) 60 | return x 61 | 62 | 63 | class FaceBoxes(nn.Module): 64 | 65 | def __init__(self, phase, size, num_classes): 66 | super(FaceBoxes, self).__init__() 67 | self.phase = phase 68 | self.num_classes = num_classes 69 | self.size = size 70 | 71 | self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3) 72 | self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2) 73 | 74 | self.inception1 = Inception() 75 | self.inception2 = Inception() 76 | self.inception3 = Inception() 77 | 78 | self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0) 79 | self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) 80 | 81 | self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0) 82 | self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) 83 | 84 | self.loc, self.conf = self.multibox(self.num_classes) 85 | 86 | if self.phase == 'test': 87 | self.softmax = nn.Softmax(dim=-1) 88 | 89 | if self.phase == 'train': 90 | for m in 
self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | if m.bias is not None: 93 | nn.init.xavier_normal_(m.weight.data) 94 | m.bias.data.fill_(0.02) 95 | else: 96 | m.weight.data.normal_(0, 0.01) 97 | elif isinstance(m, nn.BatchNorm2d): 98 | m.weight.data.fill_(1) 99 | m.bias.data.zero_() 100 | 101 | def multibox(self, num_classes): 102 | loc_layers = [] 103 | conf_layers = [] 104 | loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)] 105 | conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)] 106 | loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] 107 | conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] 108 | loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] 109 | conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] 110 | return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers) 111 | 112 | def forward(self, x): 113 | 114 | sources = list() 115 | loc = list() 116 | conf = list() 117 | detection_dimension = list() 118 | 119 | x = self.conv1(x) 120 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) 121 | x = self.conv2(x) 122 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) 123 | x = self.inception1(x) 124 | x = self.inception2(x) 125 | x = self.inception3(x) 126 | detection_dimension.append(x.shape[2:]) 127 | sources.append(x) 128 | x = self.conv3_1(x) 129 | x = self.conv3_2(x) 130 | detection_dimension.append(x.shape[2:]) 131 | sources.append(x) 132 | x = self.conv4_1(x) 133 | x = self.conv4_2(x) 134 | detection_dimension.append(x.shape[2:]) 135 | sources.append(x) 136 | 137 | detection_dimension = torch.tensor(detection_dimension, device=x.device) 138 | 139 | for (x, l, c) in zip(sources, self.loc, self.conf): 140 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 141 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 142 | 143 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 144 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 145 | 146 | if self.phase == "test": 147 | output = (loc.view(loc.size(0), -1, 4), 148 | self.softmax(conf.view(-1, self.num_classes)), 149 | detection_dimension) 150 | else: 151 | output = (loc.view(loc.size(0), -1, 4), 152 | conf.view(conf.size(0), -1, self.num_classes), 153 | detection_dimension) 154 | 155 | return output 156 | -------------------------------------------------------------------------------- /prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python egohands_dataset_clean.py 4 | 5 | mkdir -p data/Hand/images 6 | mv images/train/*.jpg data/Hand/images 7 | mv images/test/*.jpg data/Hand/images 8 | 9 | python convert_to_voc.py 10 | 11 | mv annotation/VOC2007/* data/Hand/ 12 | 13 | rm -r annotation/ 14 | rm -r egohands/ 15 | rm -r images/ 16 | 17 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import numpy as np 7 | from data import cfg 8 | from layers.functions.prior_box import PriorBox 9 | from utils.nms_wrapper import nms 10 | import cv2 11 | from models.faceboxes import FaceBoxes 12 | from utils.box_utils import decode 13 | from utils.timer import Timer 14 | 15 | parser = argparse.ArgumentParser(description='FaceBoxes') 16 | 17 | parser.add_argument('-m', 
'--trained_model', default='weights/Final_HandBoxes.pth', 18 | type=str, help='Trained state_dict file path to open') 19 | parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference') 20 | parser.add_argument('--video', default='data/video/hand.avi', type=str, help='dataset') 21 | parser.add_argument('--image', default=None, type=str, help='dataset') 22 | parser.add_argument('--confidence_threshold', default=0.2, type=float, help='confidence_threshold') 23 | parser.add_argument('--top_k', default=5000, type=int, help='top_k') 24 | parser.add_argument('--nms_threshold', default=0.2, type=float, help='nms_threshold') 25 | parser.add_argument('--keep_top_k', default=750, type=int, help='keep_top_k') 26 | args = parser.parse_args() 27 | 28 | 29 | def check_keys(model, pretrained_state_dict): 30 | ckpt_keys = set(pretrained_state_dict.keys()) 31 | model_keys = set(model.state_dict().keys()) 32 | used_pretrained_keys = model_keys & ckpt_keys 33 | unused_pretrained_keys = ckpt_keys - model_keys 34 | missing_keys = model_keys - ckpt_keys 35 | print('Missing keys:{}'.format(len(missing_keys))) 36 | print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) 37 | print('Used keys:{}'.format(len(used_pretrained_keys))) 38 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 39 | return True 40 | 41 | 42 | def remove_prefix(state_dict, prefix): 43 | ''' Old style model is stored with all names of parameters sharing common prefix 'module.' ''' 44 | print('remove prefix \'{}\''.format(prefix)) 45 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 46 | return {f(key): value for key, value in state_dict.items()} 47 | 48 | 49 | def load_model(model, pretrained_path, load_to_cpu): 50 | print('Loading pretrained model from {}'.format(pretrained_path)) 51 | if load_to_cpu: 52 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 53 | else: 54 | device = torch.cuda.current_device() 55 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) 56 | if "state_dict" in pretrained_dict.keys(): 57 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') 58 | else: 59 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 60 | check_keys(model, pretrained_dict) 61 | model.load_state_dict(pretrained_dict, strict=False) 62 | return model 63 | 64 | 65 | if __name__ == '__main__': 66 | torch.set_grad_enabled(False) 67 | # net and model 68 | net = FaceBoxes(phase='test', size=None, num_classes=2) # initialize detector 69 | net = load_model(net, args.trained_model, args.cpu) 70 | net.eval() 71 | print('Finished loading model!') 72 | print(net) 73 | cudnn.benchmark = True 74 | device = torch.device("cpu" if args.cpu else "cuda") 75 | net = net.to(device) 76 | 77 | # testing scale 78 | resize = 2 79 | 80 | _t = {'forward_pass': Timer(), 'misc': Timer()} 81 | 82 | if args.image: 83 | to_show = cv2.imread(args.image, cv2.IMREAD_COLOR) 84 | img = np.float32(to_show) 85 | 86 | if resize != 1: 87 | img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) 88 | im_height, im_width, _ = img.shape 89 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) 90 | img -= (104, 117, 123) 91 | img = img.transpose(2, 0, 1) 92 | img = torch.from_numpy(img).unsqueeze(0) 93 | img = img.to(device) 94 | scale = scale.to(device) 95 | 96 | _t['forward_pass'].tic() 97 | out = net(img) # forward pass 98 
| _t['forward_pass'].toc() 99 | _t['misc'].tic() 100 | priorbox = PriorBox(cfg, out[2], (im_height, im_width), phase='test') 101 | priors = priorbox.forward() 102 | priors = priors.to(device) 103 | loc, conf, _ = out 104 | prior_data = priors.data 105 | boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) 106 | boxes = boxes * scale / resize 107 | boxes = boxes.cpu().numpy() 108 | scores = conf.data.cpu().numpy()[:, 1] 109 | 110 | # ignore low scores 111 | inds = np.where(scores > args.confidence_threshold)[0] 112 | boxes = boxes[inds] 113 | scores = scores[inds] 114 | 115 | # keep top-K before NMS 116 | order = scores.argsort()[::-1][:args.top_k] 117 | boxes = boxes[order] 118 | scores = scores[order] 119 | 120 | # do NMS 121 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 122 | #keep = py_cpu_nms(dets, args.nms_threshold) 123 | keep = nms(dets, args.nms_threshold, force_cpu=args.cpu) 124 | dets = dets[keep, :] 125 | 126 | # keep top-K faster NMS 127 | dets = dets[:args.keep_top_k, :] 128 | _t['misc'].toc() 129 | 130 | for i in range(dets.shape[0]): 131 | cv2.rectangle(to_show, (dets[i][0], dets[i][1]), (dets[i][2], dets[i][3]), [0, 0, 255], 3) 132 | 133 | cv2.imshow('image', to_show) 134 | cv2.waitKey(0) 135 | cv2.destroyAllWindows() 136 | 137 | else: 138 | videofile = args.video 139 | 140 | cap = cv2.VideoCapture(videofile) 141 | 142 | assert cap.isOpened(), 'Cannot capture source' 143 | 144 | while cap.isOpened(): 145 | 146 | ret, frame = cap.read() 147 | if ret: 148 | to_show = frame 149 | img = np.float32(to_show) 150 | 151 | if resize != 1: 152 | img = cv2.resize(img, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) 153 | im_height, im_width, _ = img.shape 154 | scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) 155 | img -= (104, 117, 123) 156 | img = img.transpose(2, 0, 1) 157 | img = torch.from_numpy(img).unsqueeze(0) 158 | img = img.to(device) 159 | scale = scale.to(device) 160 | 161 | _t['forward_pass'].tic() 162 | out = net(img) # forward pass 163 | _t['forward_pass'].toc() 164 | _t['misc'].tic() 165 | priorbox = PriorBox(cfg, out[2], (im_height, im_width), phase='test') 166 | priors = priorbox.forward() 167 | priors = priors.to(device) 168 | loc, conf, _ = out 169 | prior_data = priors.data 170 | boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) 171 | boxes = boxes * scale / resize 172 | boxes = boxes.cpu().numpy() 173 | scores = conf.data.cpu().numpy()[:, 1] 174 | 175 | # ignore low scores 176 | inds = np.where(scores > args.confidence_threshold)[0] 177 | boxes = boxes[inds] 178 | scores = scores[inds] 179 | 180 | # keep top-K before NMS 181 | order = scores.argsort()[::-1][:args.top_k] 182 | boxes = boxes[order] 183 | scores = scores[order] 184 | 185 | # do NMS 186 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 187 | # keep = py_cpu_nms(dets, args.nms_threshold) 188 | keep = nms(dets, args.nms_threshold, force_cpu=args.cpu) 189 | dets = dets[keep, :] 190 | 191 | # keep top-K faster NMS 192 | dets = dets[:args.keep_top_k, :] 193 | _t['misc'].toc() 194 | 195 | for i in range(dets.shape[0]): 196 | cv2.rectangle(to_show, (dets[i][0], dets[i][1]), (dets[i][2], dets[i][3]), [0, 0, 255], 3) 197 | 198 | cv2.imshow('image', to_show) 199 | # cv2.waitKey(0) 200 | # cv2.destroyAllWindows() 201 | 202 | key = cv2.waitKey(1) 203 | if key & 0xFF == ord('q'): 204 | break 205 | 206 | 207 | else: 208 | break 209 | 
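For reference, the `decode` call used in the demo above maps each regression output back onto its prior box; `encode` in `utils/box_utils.py` is its exact inverse at training time. Written out as a sketch of the same arithmetic as `decode()`, with a prior (c_x, c_y, w, h), a predicted offset (t_x, t_y, t_w, t_h), and variances (v_0, v_1) = (0.1, 0.2) as hard-coded in `layers/modules/multibox_loss.py`:

```latex
% Center-size box recovered from a prior and a predicted offset, as in decode():
\hat{c}_x = c_x + t_x v_0 w, \qquad \hat{c}_y = c_y + t_y v_0 h, \qquad
\hat{w} = w\, e^{t_w v_1}, \qquad \hat{h} = h\, e^{t_h v_1}
% then converted to corner form:
(x_{\min}, y_{\min}) = \left(\hat{c}_x - \tfrac{\hat{w}}{2},\; \hat{c}_y - \tfrac{\hat{h}}{2}\right), \qquad
(x_{\max}, y_{\max}) = (x_{\min} + \hat{w},\; y_{\min} + \hat{h})
```

`test.py` then multiplies these normalized boxes by `scale / resize` to get pixel coordinates in the original image.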
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | import torch.optim as optim 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | from torch.autograd import Variable 8 | import torch.utils.data as data 9 | from data import AnnotationTransform, VOCDetection, detection_collate, preproc, cfg 10 | from layers.modules import MultiBoxLoss 11 | from layers.functions.prior_box import PriorBox 12 | import time 13 | import math 14 | from models.faceboxes import FaceBoxes 15 | 16 | parser = argparse.ArgumentParser(description='HandBoxes Training') 17 | parser.add_argument('--training_dataset', default='./data/Hand', help='Training dataset directory') 18 | parser.add_argument('-b', '--batch_size', default=32, type=int, help='Batch size for training') 19 | parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading') 20 | parser.add_argument('--ngpu', default=2, type=int, help='gpus') 21 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate') 22 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 23 | parser.add_argument('--resume_net', default=None, help='resume net for retraining') 24 | parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining') 25 | parser.add_argument('-max', '--max_epoch', default=300, type=int, help='max epoch for retraining') 26 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD') 27 | parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD') 28 | parser.add_argument('--save_folder', default='./weights/', help='Location to save checkpoint models') 29 | args = parser.parse_args() 30 | 31 | if not os.path.exists(args.save_folder): 32 | os.mkdir(args.save_folder) 33 | 34 | img_dim = 1024 35 | rgb_means = (104, 117, 123) # bgr order 36 | num_classes = 2 37 | batch_size = args.batch_size 38 | weight_decay = args.weight_decay 39 | gamma = args.gamma 40 | momentum = args.momentum 41 | gpu_train = cfg['gpu_train'] 42 | 43 | net = FaceBoxes('train', img_dim, num_classes) 44 | print("Printing net...") 45 | print(net) 46 | 47 | if args.resume_net is not None: 48 | print('Loading resume network...') 49 | state_dict = torch.load(args.resume_net) 50 | # create new OrderedDict that does not contain `module.` 51 | from collections import OrderedDict 52 | new_state_dict = OrderedDict() 53 | for k, v in state_dict.items(): 54 | head = k[:7] 55 | if head == 'module.': 56 | name = k[7:] # remove `module.` 57 | else: 58 | name = k 59 | new_state_dict[name] = v 60 | net.load_state_dict(new_state_dict) 61 | 62 | if args.ngpu > 1 and gpu_train: 63 | net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) 64 | 65 | device = torch.device('cuda:0' if gpu_train else 'cpu') 66 | cudnn.benchmark = True 67 | net = net.to(device) 68 | 69 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 70 | criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False) 71 | 72 | priorbox = PriorBox(cfg) 73 | with torch.no_grad(): 74 | priors = priorbox.forward() 75 | priors = priors.to(device) 76 | 77 | 78 | def train(): 79 | net.train() 80 | epoch = 0 + args.resume_epoch 81 | print('Loading Dataset...') 82 | 83 | 
    dataset = VOCDetection(args.training_dataset, preproc(img_dim, rgb_means), AnnotationTransform())
84 |
85 |     epoch_size = math.ceil(len(dataset) / args.batch_size)
86 |     max_iter = args.max_epoch * epoch_size
87 |
88 |     stepvalues = (200 * epoch_size, 250 * epoch_size)
89 |     step_index = 0
90 |
91 |     if args.resume_epoch > 0:
92 |         start_iter = args.resume_epoch * epoch_size
93 |     else:
94 |         start_iter = 0
95 |
96 |     for iteration in range(start_iter, max_iter):
97 |         if iteration % epoch_size == 0:
98 |             # create batch iterator
99 |             batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=detection_collate))
100 |             if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > 200):
101 |                 torch.save(net.state_dict(), args.save_folder + 'HandBoxes_epoch_' + repr(epoch) + '.pth')
102 |             epoch += 1
103 |
104 |         load_t0 = time.time()
105 |         if iteration in stepvalues:
106 |             step_index += 1
107 |         lr = adjust_learning_rate(optimizer, args.gamma, epoch, step_index, iteration, epoch_size)
108 |
109 |         # load train data
110 |         images, targets = next(batch_iterator)
111 |         if gpu_train:
112 |             images = Variable(images.cuda())
113 |             targets = [Variable(anno.cuda()) for anno in targets]
114 |         else:
115 |             images = Variable(images)
116 |             targets = [Variable(anno) for anno in targets]
117 |
118 |         # forward
119 |         out = net(images)
120 |
121 |         # backprop
122 |         optimizer.zero_grad()
123 |         loss_l, loss_c = criterion(out, priors, targets)
124 |         loss = cfg['loc_weight'] * loss_l + loss_c
125 |         loss.backward()
126 |         optimizer.step()
127 |         load_t1 = time.time()
128 |         print('Epoch:' + repr(epoch) + ' || epochiter: ' + repr(iteration % epoch_size) + '/' + repr(epoch_size) +
129 |               '|| Total iter ' + repr(iteration) + ' || L: %.4f C: %.4f||' % (cfg['loc_weight']*loss_l.item(), loss_c.item()) +
130 |               'Batch time: %.4f sec. ||' % (load_t1 - load_t0) + 'LR: %.8f' % (lr))
131 |
132 |     torch.save(net.state_dict(), args.save_folder + 'Final_HandBoxes.pth')
133 |
134 |
135 | def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size):
136 |     """Sets the learning rate
137 |     # Adapted from PyTorch Imagenet example:
138 |     # https://github.com/pytorch/examples/blob/master/imagenet/main.py
139 |     """
140 |     if epoch < 0:  # warm-up branch (never taken here, since epoch starts at 0)
141 |         lr = 1e-6 + (args.lr-1e-6) * iteration / (epoch_size * 5)
142 |     else:
143 |         lr = args.lr * (gamma ** (step_index))
144 |     for param_group in optimizer.param_groups:
145 |         param_group['lr'] = lr
146 |     return lr
147 |
148 | if __name__ == '__main__':
149 |     train()
150 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zllrunning/hand-detection.PyTorch/ed1398d9e31bd02e879688045692124460382109/utils/__init__.py
--------------------------------------------------------------------------------
/utils/box_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def point_form(boxes):
6 |     """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7 |     representation for comparison to point form ground truth data.
8 |     Args:
9 |         boxes: (tensor) center-size default boxes from priorbox layers.
10 |     Return:
11 |         boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12 |     """
13 |     return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,     # xmin, ymin
14 |                      boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax
15 |
16 |
17 | def center_size(boxes):
18 |     """ Convert prior_boxes to (cx, cy, w, h)
19 |     representation for comparison to center-size form ground truth data.
20 |     Args:
21 |         boxes: (tensor) point_form boxes
22 |     Return:
23 |         boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
24 |     """
25 |     return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
26 |                       boxes[:, 2:] - boxes[:, :2]), 1)  # w, h
27 |
28 |
29 | def intersect(box_a, box_b):
30 |     """ We resize both tensors to [A,B,2] without new malloc:
31 |     [A,2] -> [A,1,2] -> [A,B,2]
32 |     [B,2] -> [1,B,2] -> [A,B,2]
33 |     Then we compute the area of intersect between box_a and box_b.
34 |     Args:
35 |       box_a: (tensor) bounding boxes, Shape: [A,4].
36 |       box_b: (tensor) bounding boxes, Shape: [B,4].
37 |     Return:
38 |       (tensor) intersection area, Shape: [A,B].
39 |     """
40 |     A = box_a.size(0)
41 |     B = box_b.size(0)
42 |     max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43 |                        box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44 |     min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45 |                        box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46 |     inter = torch.clamp((max_xy - min_xy), min=0)
47 |     return inter[:, :, 0] * inter[:, :, 1]
48 |
49 |
50 | def jaccard(box_a, box_b):
51 |     """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
52 |     is simply the intersection over union of two boxes.  Here we operate on
53 |     ground truth boxes and default boxes.
54 |     E.g.:
55 |         A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56 |     Args:
57 |         box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58 |         box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59 |     Return:
60 |         jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61 |     """
62 |     inter = intersect(box_a, box_b)
63 |     area_a = ((box_a[:, 2]-box_a[:, 0]) *
64 |               (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
65 |     area_b = ((box_b[:, 2]-box_b[:, 0]) *
66 |               (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
67 |     union = area_a + area_b - inter
68 |     return inter / union  # [A,B]
69 |
70 |
71 | def matrix_iou(a, b):
72 |     """
73 |     return iou of a and b, numpy version for data augmentation
74 |     """
75 |     lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
76 |     rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
77 |
78 |     area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
79 |     area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
80 |     area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
81 |     return area_i / (area_a[:, np.newaxis] + area_b - area_i)
82 |
83 |
84 | def matrix_iof(a, b):
85 |     """
86 |     return iof of a and b, numpy version for data augmentation
87 |     """
88 |     lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
89 |     rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
90 |
91 |     area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
92 |     area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
93 |     return area_i / np.maximum(area_a[:, np.newaxis], 1)
94 |
95 |
96 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
97 |     """Match each prior box with the ground truth box of the highest jaccard
98 |     overlap, encode the bounding boxes, then return the matched indices
99 |     corresponding to both confidence and location preds.
100 |     Args:
101 |         threshold: (float) The overlap threshold used when matching boxes.
102 |         truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
103 |         priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
104 |         variances: (tensor) Variances corresponding to each prior coord,
105 |             Shape: [num_priors, 4].
106 |         labels: (tensor) All the class labels for the image, Shape: [num_obj].
107 |         loc_t: (tensor) Tensor to be filled w/ encoded location targets.
108 |         conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
109 |         idx: (int) current batch index
110 |     Return:
111 |         The matched indices corresponding to 1) location and 2) confidence preds.
112 |     """
113 |     # jaccard index
114 |     overlaps = jaccard(
115 |         truths,
116 |         point_form(priors)
117 |     )
118 |     # (Bipartite Matching)
119 |     # [1,num_objects] best prior for each ground truth
120 |     best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
121 | 
122 |     # ignore hard gt
123 |     valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
124 |     best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
125 |     if best_prior_idx_filter.shape[0] <= 0:
126 |         loc_t[idx] = 0
127 |         conf_t[idx] = 0
128 |         return
129 | 
130 |     # [1,num_priors] best ground truth for each prior
131 |     best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
132 |     best_truth_idx.squeeze_(0)
133 |     best_truth_overlap.squeeze_(0)
134 |     best_prior_idx.squeeze_(1)
135 |     best_prior_idx_filter.squeeze_(1)
136 |     best_prior_overlap.squeeze_(1)
137 |     best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
138 |     # TODO refactor: index best_prior_idx with long tensor
139 |     # ensure every gt matches with its prior of max overlap
140 |     for j in range(best_prior_idx.size(0)):
141 |         best_truth_idx[best_prior_idx[j]] = j
142 |     matches = truths[best_truth_idx]          # Shape: [num_priors,4]
143 |     conf = labels[best_truth_idx]             # Shape: [num_priors]
144 |     conf[best_truth_overlap < threshold] = 0  # label as background
145 |     loc = encode(matches, priors, variances)
146 |     loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
147 |     conf_t[idx] = conf  # [num_priors] top class label for each prior
148 | 
149 | 
150 | def encode(matched, priors, variances):
151 |     """Encode the variances from the priorbox layers into the ground truth boxes
152 |     we have matched (based on jaccard overlap) with the prior boxes.
153 |     Args:
154 |         matched: (tensor) Coords of ground truth for each prior in point-form
155 |             Shape: [num_priors, 4].
156 |         priors: (tensor) Prior boxes in center-offset form
157 |             Shape: [num_priors,4].
158 |         variances: (list[float]) Variances of priorboxes
159 |     Return:
160 |         encoded boxes (tensor), Shape: [num_priors, 4]
161 |     """
162 | 
163 |     # dist b/t match center and prior's center
164 |     g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
165 |     # encode variance
166 |     g_cxcy /= (variances[0] * priors[:, 2:])
167 |     # match wh / prior wh
168 |     g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
169 |     g_wh = torch.log(g_wh) / variances[1]
170 |     # return target for smooth_l1_loss
171 |     return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]
172 | 
173 | 
174 | # Adapted from https://github.com/Hakuyume/chainer-ssd
175 | def decode(loc, priors, variances):
176 |     """Decode locations from predictions using priors to undo
177 |     the encoding we did for offset regression at train time.
178 |     Args:
179 |         loc (tensor): location predictions for loc layers,
180 |             Shape: [num_priors,4]
181 |         priors (tensor): Prior boxes in center-offset form.
182 |             Shape: [num_priors,4].
183 |         variances: (list[float]) Variances of priorboxes
184 |     Return:
185 |         decoded bounding box predictions
186 |     """
187 | 
188 |     boxes = torch.cat((
189 |         priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
190 |         priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
191 |     boxes[:, :2] -= boxes[:, 2:] / 2
192 |     boxes[:, 2:] += boxes[:, :2]
193 |     return boxes
194 | 
195 | 
196 | def log_sum_exp(x):
197 |     """Utility function for computing log_sum_exp in a numerically stable way.
198 |     This will be used to determine the unaveraged confidence loss across
199 |     all examples in a batch.
200 |     Args:
201 |         x (Variable(tensor)): conf_preds from conf layers
202 |     """
203 |     x_max = x.data.max()
204 |     return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
205 | 
206 | 
207 | # Original author: Francisco Massa:
208 | # https://github.com/fmassa/object-detection.torch
209 | # Ported to PyTorch by Max deGroot (02/01/2017)
210 | def nms(boxes, scores, overlap=0.5, top_k=200):
211 |     """Apply non-maximum suppression at test time to avoid detecting too many
212 |     overlapping bounding boxes for a given object.
213 |     Args:
214 |         boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
215 |         scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
216 |         overlap: (float) The overlap thresh for suppressing unnecessary boxes.
217 |         top_k: (int) The maximum number of box preds to consider.
218 |     Return:
219 |         The indices of the kept boxes with respect to num_priors.
220 |     """
221 | 
222 |     keep = torch.Tensor(scores.size(0)).fill_(0).long()
223 |     if boxes.numel() == 0:
224 |         return keep
225 |     x1 = boxes[:, 0]
226 |     y1 = boxes[:, 1]
227 |     x2 = boxes[:, 2]
228 |     y2 = boxes[:, 3]
229 |     area = torch.mul(x2 - x1, y2 - y1)
230 |     v, idx = scores.sort(0)  # sort in ascending order
231 |     # I = I[v >= 0.01]
232 |     idx = idx[-top_k:]  # indices of the top-k largest vals
233 |     xx1 = boxes.new()
234 |     yy1 = boxes.new()
235 |     xx2 = boxes.new()
236 |     yy2 = boxes.new()
237 |     w = boxes.new()
238 |     h = boxes.new()
239 | 
240 |     # keep = torch.Tensor()
241 |     count = 0
242 |     while idx.numel() > 0:
243 |         i = idx[-1]  # index of current largest val
244 |         # keep.append(i)
245 |         keep[count] = i
246 |         count += 1
247 |         if idx.size(0) == 1:
248 |             break
249 |         idx = idx[:-1]  # remove kept element from view
250 |         # load bboxes of next highest vals
251 |         torch.index_select(x1, 0, idx, out=xx1)
252 |         torch.index_select(y1, 0, idx, out=yy1)
253 |         torch.index_select(x2, 0, idx, out=xx2)
254 |         torch.index_select(y2, 0, idx, out=yy2)
255 |         # store element-wise max with next highest score
256 |         xx1 = torch.clamp(xx1, min=x1[i])
257 |         yy1 = torch.clamp(yy1, min=y1[i])
258 |         xx2 = torch.clamp(xx2, max=x2[i])
259 |         yy2 = torch.clamp(yy2, max=y2[i])
260 |         w.resize_as_(xx2)
261 |         h.resize_as_(yy2)
262 |         w = xx2 - xx1
263 |         h = yy2 - yy1
264 |         # check sizes of xx1 and xx2.. after each iteration
265 |         w = torch.clamp(w, min=0.0)
266 |         h = torch.clamp(h, min=0.0)
267 |         inter = w*h
268 |         # IoU = i / (area(a) + area(b) - i)
269 |         rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
270 |         union = (rem_areas - inter) + area[i]
271 |         IoU = inter/union  # store result in iou
272 |         # keep only elements with an IoU <= overlap
273 |         idx = idx[IoU.le(overlap)]
274 |     return keep, count
275 | 
276 | 
277 | 
--------------------------------------------------------------------------------
/utils/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 | 
8 | import os
9 | from os.path import join as pjoin
10 | import numpy as np
11 | from distutils.core import setup
12 | from distutils.extension import Extension
13 | from Cython.Distutils import build_ext
14 | 
15 | 
16 | def find_in_path(name, path):
17 |     "Find a file in a search path"
18 |     # adapted from http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
19 |     for dir in path.split(os.pathsep):
20 |         binpath = pjoin(dir, name)
21 |         if os.path.exists(binpath):
22 |             return os.path.abspath(binpath)
23 |     return None
24 | 
25 | 
26 | def locate_cuda():
27 |     """Locate the CUDA environment on the system
28 | 
29 |     Returns a dict with keys 'home', 'nvcc', 'include', and 'lib64'
30 |     and values giving the absolute path to each directory.
31 | 
32 |     Starts by looking for the CUDAHOME env variable. If not found, everything
33 |     is based on finding 'nvcc' in the PATH.
34 |     """
35 | 
36 |     # first check if the CUDAHOME env variable is in use
37 |     if 'CUDAHOME' in os.environ:
38 |         home = os.environ['CUDAHOME']
39 |         nvcc = pjoin(home, 'bin', 'nvcc')
40 |     else:
41 |         # otherwise, search the PATH for NVCC
42 |         default_path = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin')
43 |         nvcc = find_in_path('nvcc', os.environ['PATH'] + os.pathsep + default_path)
44 |         if nvcc is None:
45 |             raise EnvironmentError('The nvcc binary could not be '
46 |                 'located in your $PATH. Either add it to your path, or set $CUDAHOME')
47 |         home = os.path.dirname(os.path.dirname(nvcc))
48 | 
49 |     cudaconfig = {'home': home, 'nvcc': nvcc,
50 |                   'include': pjoin(home, 'include'),
51 |                   'lib64': pjoin(home, 'lib64')}
52 |     for k, v in cudaconfig.items():
53 |         if not os.path.exists(v):
54 |             raise EnvironmentError('The CUDA %s path could not be located in %s' % (k, v))
55 | 
56 |     return cudaconfig
57 | 
58 | 
59 | CUDA = locate_cuda()
60 | 
61 | # Obtain the numpy include directory. This logic works across numpy versions.
62 | try:
63 |     numpy_include = np.get_include()
64 | except AttributeError:
65 |     numpy_include = np.get_numpy_include()
66 | 
67 | 
68 | def customize_compiler_for_nvcc(self):
69 |     """inject deep into distutils to customize how the dispatch
70 |     to gcc/nvcc works.
71 | 
72 |     If you subclass UnixCCompiler, it's not trivial to get your subclass
73 |     injected in, and still have the right customizations (i.e.
74 |     distutils.sysconfig.customize_compiler) run on it. So instead of going
75 |     the OO route, I have this. Note, it's kind of like a weird functional
76 |     subclassing going on."""
77 | 
78 |     # tell the compiler it can process .cu
79 |     self.src_extensions.append('.cu')
80 | 
81 |     # save references to the default compiler_so and _compile methods
82 |     default_compiler_so = self.compiler_so
83 |     super = self._compile
84 | 
85 |     # now redefine the _compile method. This gets executed for each
86 |     # object but distutils doesn't have the ability to change compilers
87 |     # based on source extension: we add it.
88 |     def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
89 |         print(extra_postargs)
90 |         if os.path.splitext(src)[1] == '.cu':
91 |             # use the cuda for .cu files
92 |             self.set_executable('compiler_so', CUDA['nvcc'])
93 |             # use only a subset of the extra_postargs, which are 1-1 translated
94 |             # from the extra_compile_args in the Extension class
95 |             postargs = extra_postargs['nvcc']
96 |         else:
97 |             postargs = extra_postargs['gcc']
98 | 
99 |         super(obj, src, ext, cc_args, postargs, pp_opts)
100 |         # reset the default compiler_so, which we might have changed for cuda
101 |         self.compiler_so = default_compiler_so
102 | 
103 |     # inject our redefined _compile method into the class
104 |     self._compile = _compile
105 | 
106 | 
107 | # run the customize_compiler
108 | class custom_build_ext(build_ext):
109 |     def build_extensions(self):
110 |         customize_compiler_for_nvcc(self.compiler)
111 |         build_ext.build_extensions(self)
112 | 
113 | 
114 | ext_modules = [
115 |     Extension(
116 |         "nms.cpu_nms",
117 |         ["nms/cpu_nms.pyx"],
118 |         extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
119 |         include_dirs=[numpy_include]
120 |     ),
121 |     Extension('nms.gpu_nms',
122 |         ['nms/nms_kernel.cu', 'nms/gpu_nms.pyx'],
123 |         library_dirs=[CUDA['lib64']],
124 |         libraries=['cudart'],
125 |         language='c++',
126 |         runtime_library_dirs=[CUDA['lib64']],
127 |         # this syntax is specific to this build system
128 |         # we're only going to use certain compiler args with nvcc and not with gcc
129 |         # the implementation of this trick is in customize_compiler() below
130 |         extra_compile_args={'gcc': ["-Wno-unused-function"],
131 |                             'nvcc': ['-arch=sm_52',
132 |                                      '--ptxas-options=-v',
133 |                                      '-c',
134 |                                      '--compiler-options',
135 |                                      "'-fPIC'"]},
136 |         include_dirs=[numpy_include, CUDA['include']]
137 |     ),
138 | ]
139 | 
140 | setup(
141 |     name='mot_utils',
142 |     ext_modules=ext_modules,
143 |     # inject our custom trigger
144 |     cmdclass={'build_ext': custom_build_ext},
145 | )
146 | 
--------------------------------------------------------------------------------
/utils/nms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zllrunning/hand-detection.PyTorch/ed1398d9e31bd02e879688045692124460382109/utils/nms/__init__.py
--------------------------------------------------------------------------------
/utils/nms/cpu_nms.pyx:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 | 
8 | import numpy as np
9 | cimport numpy as np
10 | 
11 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
12 |     return a if a >= b else b
13 | 
14 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
15 |     return a if a <= b else b
16 | 
17 | def 
cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): 18 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 19 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 20 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 21 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 22 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 23 | 24 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 25 | cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] 26 | 27 | cdef int ndets = dets.shape[0] 28 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 29 | np.zeros((ndets), dtype=np.int) 30 | 31 | # nominal indices 32 | cdef int _i, _j 33 | # sorted indices 34 | cdef int i, j 35 | # temp variables for box i's (the box currently under consideration) 36 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 37 | # variables for computing overlap with box j (lower scoring box) 38 | cdef np.float32_t xx1, yy1, xx2, yy2 39 | cdef np.float32_t w, h 40 | cdef np.float32_t inter, ovr 41 | 42 | keep = [] 43 | for _i in range(ndets): 44 | i = order[_i] 45 | if suppressed[i] == 1: 46 | continue 47 | keep.append(i) 48 | ix1 = x1[i] 49 | iy1 = y1[i] 50 | ix2 = x2[i] 51 | iy2 = y2[i] 52 | iarea = areas[i] 53 | for _j in range(_i + 1, ndets): 54 | j = order[_j] 55 | if suppressed[j] == 1: 56 | continue 57 | xx1 = max(ix1, x1[j]) 58 | yy1 = max(iy1, y1[j]) 59 | xx2 = min(ix2, x2[j]) 60 | yy2 = min(iy2, y2[j]) 61 | w = max(0.0, xx2 - xx1 + 1) 62 | h = max(0.0, yy2 - yy1 + 1) 63 | inter = w * h 64 | ovr = inter / (iarea + areas[j] - inter) 65 | if ovr >= thresh: 66 | suppressed[j] = 1 67 | 68 | return keep 69 | 70 | def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): 71 | cdef unsigned int N = boxes.shape[0] 72 | cdef float iw, ih, box_area 73 | cdef float ua 74 | cdef int pos = 0 75 | cdef float maxscore = 0 76 | cdef int maxpos = 0 77 | cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov 78 | 79 | for i in range(N): 80 | maxscore = boxes[i, 4] 81 | maxpos = i 82 | 83 | tx1 = boxes[i,0] 84 | ty1 = boxes[i,1] 85 | tx2 = boxes[i,2] 86 | ty2 = boxes[i,3] 87 | ts = boxes[i,4] 88 | 89 | pos = i + 1 90 | # get max box 91 | while pos < N: 92 | if maxscore < boxes[pos, 4]: 93 | maxscore = boxes[pos, 4] 94 | maxpos = pos 95 | pos = pos + 1 96 | 97 | # add max box as a detection 98 | boxes[i,0] = boxes[maxpos,0] 99 | boxes[i,1] = boxes[maxpos,1] 100 | boxes[i,2] = boxes[maxpos,2] 101 | boxes[i,3] = boxes[maxpos,3] 102 | boxes[i,4] = boxes[maxpos,4] 103 | 104 | # swap ith box with position of max box 105 | boxes[maxpos,0] = tx1 106 | boxes[maxpos,1] = ty1 107 | boxes[maxpos,2] = tx2 108 | boxes[maxpos,3] = ty2 109 | boxes[maxpos,4] = ts 110 | 111 | tx1 = boxes[i,0] 112 | ty1 = boxes[i,1] 113 | tx2 = boxes[i,2] 114 | ty2 = boxes[i,3] 115 | ts = boxes[i,4] 116 | 117 | pos = i + 1 118 | # NMS iterations, note that N changes if detection boxes fall below threshold 119 | while pos < N: 120 | x1 = boxes[pos, 0] 121 | y1 = boxes[pos, 1] 122 | x2 = boxes[pos, 2] 123 | y2 = boxes[pos, 3] 124 | s = boxes[pos, 4] 125 | 126 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 127 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 128 | if iw > 0: 129 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 130 | if ih > 0: 131 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 132 | ov = iw * ih / ua #iou between max box and detection box 133 | 134 | if method == 1: # linear 135 | if ov > Nt: 136 | weight = 1 - ov 137 | else: 
138 |                         weight = 1
139 |                     elif method == 2: # gaussian
140 |                         weight = np.exp(-(ov * ov)/sigma)
141 |                     else: # original NMS
142 |                         if ov > Nt:
143 |                             weight = 0
144 |                         else:
145 |                             weight = 1
146 | 
147 |                     boxes[pos, 4] = weight*boxes[pos, 4]
148 | 
149 |                     # if box score falls below threshold, discard the box by swapping with last box
150 |                     # update N
151 |                     if boxes[pos, 4] < threshold:
152 |                         boxes[pos,0] = boxes[N-1, 0]
153 |                         boxes[pos,1] = boxes[N-1, 1]
154 |                         boxes[pos,2] = boxes[N-1, 2]
155 |                         boxes[pos,3] = boxes[N-1, 3]
156 |                         boxes[pos,4] = boxes[N-1, 4]
157 |                         N = N - 1
158 |                         pos = pos - 1
159 | 
160 |             pos = pos + 1
161 | 
162 |     keep = [i for i in range(N)]
163 |     return keep
164 | 
--------------------------------------------------------------------------------
/utils/nms/gpu_nms.hpp:
--------------------------------------------------------------------------------
1 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
2 |           int boxes_dim, float nms_overlap_thresh, int device_id);
3 | 
--------------------------------------------------------------------------------
/utils/nms/gpu_nms.pyx:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Faster R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 | 
8 | import numpy as np
9 | cimport numpy as np
10 | 
11 | assert sizeof(int) == sizeof(np.int32_t)
12 | 
13 | cdef extern from "gpu_nms.hpp":
14 |     void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
15 | 
16 | def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
17 |             np.int32_t device_id=0):
18 |     cdef int boxes_num = dets.shape[0]
19 |     cdef int boxes_dim = dets.shape[1]
20 |     cdef int num_out
21 |     cdef np.ndarray[np.int32_t, ndim=1] \
22 |         keep = np.zeros(boxes_num, dtype=np.int32)
23 |     cdef np.ndarray[np.float32_t, ndim=1] \
24 |         scores = dets[:, 4]
25 |     cdef np.ndarray[np.int_t, ndim=1] \
26 |         order = scores.argsort()[::-1]
27 |     cdef np.ndarray[np.float32_t, ndim=2] \
28 |         sorted_dets = dets[order, :]
29 |     _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
30 |     keep = keep[:num_out]
31 |     return list(order[keep])
32 | 
--------------------------------------------------------------------------------
/utils/nms/nms_kernel.cu:
--------------------------------------------------------------------------------
1 | // ------------------------------------------------------------------
2 | // Faster R-CNN
3 | // Copyright (c) 2015 Microsoft
4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
5 | // Written by Shaoqing Ren
6 | // ------------------------------------------------------------------
7 | 
8 | #include "gpu_nms.hpp"
9 | #include <vector>
10 | #include <iostream>
11 | 
12 | #define CUDA_CHECK(condition) \
13 |   /* Code block avoids redefinition of cudaError_t error */ \
14 |   do { \
15 |     cudaError_t error = condition; \
16 |     if (error != cudaSuccess) { \
17 |       std::cout << cudaGetErrorString(error) << std::endl; \
18 |     } \
19 |   } while (0)
20 | 
21 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
22 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
23 | 
24 | __device__ inline float devIoU(float const * const a, float const * const b) {
25 |   float left = max(a[0], b[0]), right = min(a[2], b[2]);
26 |   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
27 |   float width = 
max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f); 28 | float interS = width * height; 29 | float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); 30 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); 31 | return interS / (Sa + Sb - interS); 32 | } 33 | 34 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, 35 | const float *dev_boxes, unsigned long long *dev_mask) { 36 | const int row_start = blockIdx.y; 37 | const int col_start = blockIdx.x; 38 | 39 | // if (row_start > col_start) return; 40 | 41 | const int row_size = 42 | min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); 43 | const int col_size = 44 | min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); 45 | 46 | __shared__ float block_boxes[threadsPerBlock * 5]; 47 | if (threadIdx.x < col_size) { 48 | block_boxes[threadIdx.x * 5 + 0] = 49 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; 50 | block_boxes[threadIdx.x * 5 + 1] = 51 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; 52 | block_boxes[threadIdx.x * 5 + 2] = 53 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; 54 | block_boxes[threadIdx.x * 5 + 3] = 55 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; 56 | block_boxes[threadIdx.x * 5 + 4] = 57 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; 58 | } 59 | __syncthreads(); 60 | 61 | if (threadIdx.x < row_size) { 62 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; 63 | const float *cur_box = dev_boxes + cur_box_idx * 5; 64 | int i = 0; 65 | unsigned long long t = 0; 66 | int start = 0; 67 | if (row_start == col_start) { 68 | start = threadIdx.x + 1; 69 | } 70 | for (i = start; i < col_size; i++) { 71 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { 72 | t |= 1ULL << i; 73 | } 74 | } 75 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock); 76 | dev_mask[cur_box_idx * col_blocks + col_start] = t; 77 | } 78 | } 79 | 80 | void _set_device(int device_id) { 81 | int current_device; 82 | CUDA_CHECK(cudaGetDevice(¤t_device)); 83 | if (current_device == device_id) { 84 | return; 85 | } 86 | // The call to cudaSetDevice must come before any calls to Get, which 87 | // may perform initialization using the GPU. 
88 |   CUDA_CHECK(cudaSetDevice(device_id));
89 | }
90 | 
91 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
92 |           int boxes_dim, float nms_overlap_thresh, int device_id) {
93 |   _set_device(device_id);
94 | 
95 |   float* boxes_dev = NULL;
96 |   unsigned long long* mask_dev = NULL;
97 | 
98 |   const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
99 | 
100 |   CUDA_CHECK(cudaMalloc(&boxes_dev,
101 |                         boxes_num * boxes_dim * sizeof(float)));
102 |   CUDA_CHECK(cudaMemcpy(boxes_dev,
103 |                         boxes_host,
104 |                         boxes_num * boxes_dim * sizeof(float),
105 |                         cudaMemcpyHostToDevice));
106 | 
107 |   CUDA_CHECK(cudaMalloc(&mask_dev,
108 |                         boxes_num * col_blocks * sizeof(unsigned long long)));
109 | 
110 |   dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
111 |               DIVUP(boxes_num, threadsPerBlock));
112 |   dim3 threads(threadsPerBlock);
113 |   nms_kernel<<<blocks, threads>>>(boxes_num,
114 |                                   nms_overlap_thresh,
115 |                                   boxes_dev,
116 |                                   mask_dev);
117 | 
118 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
119 |   CUDA_CHECK(cudaMemcpy(&mask_host[0],
120 |                         mask_dev,
121 |                         sizeof(unsigned long long) * boxes_num * col_blocks,
122 |                         cudaMemcpyDeviceToHost));
123 | 
124 |   std::vector<unsigned long long> remv(col_blocks);
125 |   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
126 | 
127 |   int num_to_keep = 0;
128 |   for (int i = 0; i < boxes_num; i++) {
129 |     int nblock = i / threadsPerBlock;
130 |     int inblock = i % threadsPerBlock;
131 | 
132 |     if (!(remv[nblock] & (1ULL << inblock))) {
133 |       keep_out[num_to_keep++] = i;
134 |       unsigned long long *p = &mask_host[0] + i * col_blocks;
135 |       for (int j = nblock; j < col_blocks; j++) {
136 |         remv[j] |= p[j];
137 |       }
138 |     }
139 |   }
140 |   *num_out = num_to_keep;
141 | 
142 |   CUDA_CHECK(cudaFree(boxes_dev));
143 |   CUDA_CHECK(cudaFree(mask_dev));
144 | }
145 | 
--------------------------------------------------------------------------------
/utils/nms/py_cpu_nms.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 | 
8 | import numpy as np
9 | 
10 | def py_cpu_nms(dets, thresh):
11 |     """Pure Python NMS baseline."""
12 |     x1 = dets[:, 0]
13 |     y1 = dets[:, 1]
14 |     x2 = dets[:, 2]
15 |     y2 = dets[:, 3]
16 |     scores = dets[:, 4]
17 | 
18 |     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
19 |     order = scores.argsort()[::-1]
20 | 
21 |     keep = []
22 |     while order.size > 0:
23 |         i = order[0]
24 |         keep.append(i)
25 |         xx1 = np.maximum(x1[i], x1[order[1:]])
26 |         yy1 = np.maximum(y1[i], y1[order[1:]])
27 |         xx2 = np.minimum(x2[i], x2[order[1:]])
28 |         yy2 = np.minimum(y2[i], y2[order[1:]])
29 | 
30 |         w = np.maximum(0.0, xx2 - xx1 + 1)
31 |         h = np.maximum(0.0, yy2 - yy1 + 1)
32 |         inter = w * h
33 |         ovr = inter / (areas[i] + areas[order[1:]] - inter)
34 | 
35 |         inds = np.where(ovr <= thresh)[0]
36 |         order = order[inds + 1]
37 | 
38 |     return keep
39 | 
--------------------------------------------------------------------------------
/utils/nms_wrapper.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 | 
8 | from .nms.cpu_nms import cpu_nms, cpu_soft_nms
9 | from .nms.gpu_nms import gpu_nms
10 | 
11 | 
12 | # def nms(dets, thresh, force_cpu=False):
13 | #     """Dispatch to either CPU or GPU NMS implementations."""
14 | #
15 | #     if dets.shape[0] == 0:
16 | #         return []
17 | #     if cfg.USE_GPU_NMS and not force_cpu:
18 | #         return gpu_nms(dets, thresh, device_id=cfg.GPU_ID)
19 | #     else:
20 | #         return cpu_nms(dets, thresh)
21 | 
22 | 
23 | def nms(dets, thresh, force_cpu=False):
24 |     """Dispatch to either CPU or GPU NMS implementations."""
25 | 
26 |     if dets.shape[0] == 0:
27 |         return []
28 |     if force_cpu:
29 |         # return cpu_soft_nms(dets, thresh, method=0)
30 |         return cpu_nms(dets, thresh)
31 |     return gpu_nms(dets, thresh)
32 | 
--------------------------------------------------------------------------------
/utils/timer.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ross Girshick
6 | # --------------------------------------------------------
7 | 
8 | import time
9 | 
10 | 
11 | class Timer(object):
12 |     """A simple timer."""
13 |     def __init__(self):
14 |         self.total_time = 0.
15 |         self.calls = 0
16 |         self.start_time = 0.
17 |         self.diff = 0.
18 |         self.average_time = 0.
19 | 
20 |     def tic(self):
21 |         # using time.time instead of time.clock because time.clock
22 |         # does not normalize for multithreading
23 |         self.start_time = time.time()
24 | 
25 |     def toc(self, average=True):
26 |         self.diff = time.time() - self.start_time
27 |         self.total_time += self.diff
28 |         self.calls += 1
29 |         self.average_time = self.total_time / self.calls
30 |         if average:
31 |             return self.average_time
32 |         else:
33 |             return self.diff
34 | 
35 |     def clear(self):
36 |         self.total_time = 0.
37 |         self.calls = 0
38 |         self.start_time = 0.
39 |         self.diff = 0.
40 |         self.average_time = 0.
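# Usage sketch (illustrative addition; `forward_pass` is a hypothetical
# stand-in for whatever is being timed, e.g. the detector's forward pass):
#
#     _t = Timer()
#     for img in images:
#         _t.tic()
#         dets = forward_pass(img)
#         _t.toc()
#     print('average time: %.4fs over %d calls' % (_t.average_time, _t.calls))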
41 | -------------------------------------------------------------------------------- /xml2dict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | try: 4 | from defusedexpat import pyexpat as expat 5 | except ImportError: 6 | from xml.parsers import expat 7 | from xml.sax.saxutils import XMLGenerator 8 | from xml.sax.xmlreader import AttributesImpl 9 | try: # pragma no cover 10 | from cStringIO import StringIO 11 | except ImportError: # pragma no cover 12 | try: 13 | from StringIO import StringIO 14 | except ImportError: 15 | from io import StringIO 16 | try: # pragma no cover 17 | from collections import OrderedDict 18 | except ImportError: # pragma no cover 19 | try: 20 | from ordereddict import OrderedDict 21 | except ImportError: 22 | OrderedDict = dict 23 | 24 | 25 | __author__ = 'Martin Blech' 26 | __version__ = '0.11.0' 27 | __license__ = 'MIT' 28 | 29 | 30 | class ParsingInterrupted(Exception): 31 | pass 32 | 33 | 34 | class _DictSAXHandler(object): 35 | def __init__(self, 36 | item_depth=0, 37 | item_callback=lambda *args: True, 38 | xml_attribs=True, 39 | attr_prefix='@', 40 | cdata_key='#text', 41 | force_cdata=False, 42 | cdata_separator='', 43 | postprocessor=None, 44 | dict_constructor=OrderedDict, 45 | strip_whitespace=True, 46 | namespace_separator=':', 47 | namespaces=None, 48 | force_list=None): 49 | self.path = [] 50 | self.stack = [] 51 | self.data = [] 52 | self.item = None 53 | self.item_depth = item_depth 54 | self.xml_attribs = xml_attribs 55 | self.item_callback = item_callback 56 | self.attr_prefix = attr_prefix 57 | self.cdata_key = cdata_key 58 | self.force_cdata = force_cdata 59 | self.cdata_separator = cdata_separator 60 | self.postprocessor = postprocessor 61 | self.dict_constructor = dict_constructor 62 | self.strip_whitespace = strip_whitespace 63 | self.namespace_separator = namespace_separator 64 | self.namespaces = namespaces 65 | self.namespace_declarations = OrderedDict() 66 | self.force_list = force_list 67 | 68 | def _build_name(self, full_name): 69 | if not self.namespaces: 70 | return full_name 71 | i = full_name.rfind(self.namespace_separator) 72 | if i == -1: 73 | return full_name 74 | namespace, name = full_name[:i], full_name[i+1:] 75 | short_namespace = self.namespaces.get(namespace, namespace) 76 | if not short_namespace: 77 | return name 78 | else: 79 | return self.namespace_separator.join((short_namespace, name)) 80 | 81 | def _attrs_to_dict(self, attrs): 82 | if isinstance(attrs, dict): 83 | return attrs 84 | return self.dict_constructor(zip(attrs[0::2], attrs[1::2])) 85 | 86 | def startNamespaceDecl(self, prefix, uri): 87 | self.namespace_declarations[prefix or ''] = uri 88 | 89 | def startElement(self, full_name, attrs): 90 | name = self._build_name(full_name) 91 | attrs = self._attrs_to_dict(attrs) 92 | if attrs and self.namespace_declarations: 93 | attrs['xmlns'] = self.namespace_declarations 94 | self.namespace_declarations = OrderedDict() 95 | self.path.append((name, attrs or None)) 96 | if len(self.path) > self.item_depth: 97 | self.stack.append((self.item, self.data)) 98 | if self.xml_attribs: 99 | attr_entries = [] 100 | for key, value in attrs.items(): 101 | key = self.attr_prefix+self._build_name(key) 102 | if self.postprocessor: 103 | entry = self.postprocessor(self.path, key, value) 104 | else: 105 | entry = (key, value) 106 | if entry: 107 | attr_entries.append(entry) 108 | attrs = self.dict_constructor(attr_entries) 109 | else: 110 | attrs = None 111 | 
self.item = attrs or None 112 | self.data = [] 113 | 114 | def endElement(self, full_name): 115 | name = self._build_name(full_name) 116 | if len(self.path) == self.item_depth: 117 | item = self.item 118 | if item is None: 119 | item = (None if not self.data 120 | else self.cdata_separator.join(self.data)) 121 | 122 | should_continue = self.item_callback(self.path, item) 123 | if not should_continue: 124 | raise ParsingInterrupted() 125 | if len(self.stack): 126 | data = (None if not self.data 127 | else self.cdata_separator.join(self.data)) 128 | item = self.item 129 | self.item, self.data = self.stack.pop() 130 | if self.strip_whitespace and data: 131 | data = data.strip() or None 132 | if data and self.force_cdata and item is None: 133 | item = self.dict_constructor() 134 | if item is not None: 135 | if data: 136 | self.push_data(item, self.cdata_key, data) 137 | self.item = self.push_data(self.item, name, item) 138 | else: 139 | self.item = self.push_data(self.item, name, data) 140 | else: 141 | self.item = None 142 | self.data = [] 143 | self.path.pop() 144 | 145 | def characters(self, data): 146 | if not self.data: 147 | self.data = [data] 148 | else: 149 | self.data.append(data) 150 | 151 | def push_data(self, item, key, data): 152 | if self.postprocessor is not None: 153 | result = self.postprocessor(self.path, key, data) 154 | if result is None: 155 | return item 156 | key, data = result 157 | if item is None: 158 | item = self.dict_constructor() 159 | try: 160 | value = item[key] 161 | if isinstance(value, list): 162 | value.append(data) 163 | else: 164 | item[key] = [value, data] 165 | except KeyError: 166 | if self._should_force_list(key, data): 167 | item[key] = [data] 168 | else: 169 | item[key] = data 170 | return item 171 | 172 | def _should_force_list(self, key, value): 173 | if not self.force_list: 174 | return False 175 | try: 176 | return key in self.force_list 177 | except TypeError: 178 | return self.force_list(self.path[:-1], key, value) 179 | 180 | 181 | def parse(xml_input, encoding=None, expat=expat, process_namespaces=False, 182 | namespace_separator=':', disable_entities=True, **kwargs): 183 | """Parse the given XML input and convert it into a dictionary. 184 | `xml_input` can either be a `string` or a file-like object. 185 | If `xml_attribs` is `True`, element attributes are put in the dictionary 186 | among regular child elements, using `@` as a prefix to avoid collisions. If 187 | set to `False`, they are just ignored. 188 | Simple example:: 189 | [u'1', u'2'] 190 | If `item_depth` is `0`, the function returns a dictionary for the root 191 | element (default behavior). Otherwise, it calls `item_callback` every time 192 | an item at the specified depth is found and returns `None` in the end 193 | (streaming mode). 194 | The callback function receives two parameters: the `path` from the document 195 | root to the item (name-attribs pairs), and the `item` (dict). If the 196 | callback's return value is false-ish, parsing will be stopped with the 197 | :class:`ParsingInterrupted` exception. 198 | Streaming example:: 199 | 200 | path:[(u'a', {u'prop': u'x'}), (u'b', None)] item:1 201 | path:[(u'a', {u'prop': u'x'}), (u'b', None)] item:2 202 | The optional argument `postprocessor` is a function that takes `path`, 203 | `key` and `value` as positional arguments and returns a new `(key, value)` 204 | pair where both `key` and `value` may have changed. 
Usage example::
205 | 
206 |         OrderedDict([(u'a', OrderedDict([(u'b:int', [1, 2]), (u'b', u'x')]))])
207 |     You can pass an alternate version of `expat` (such as `defusedexpat`) by
208 |     using the `expat` parameter. E.g.:
209 | 
210 |         OrderedDict([(u'a', u'hello')])
211 |     You can use the force_list argument to force lists to be created even
212 |     when there is only a single child of a given level of hierarchy. The
213 |     force_list argument is a tuple of keys. If the key for a given level
214 |     of hierarchy is in the force_list argument, that level of hierarchy
215 |     will have a list as a child (even if there is only one sub-element).
216 |     The index_keys operation takes precedence over this. This is applied
217 |     after any user-supplied postprocessor has already run.
218 |     For example, given this input:
219 |         <servers>
220 |           <server>
221 |             <name>host1</name>
222 |             <os>Linux</os>
223 |             <interfaces>
224 |               <interface>
225 |                 <name>em0</name>
226 |                 <ip_address>10.0.0.1</ip_address>
227 |               </interface>
228 |             </interfaces>
229 |           </server>
230 |         </servers>
231 |     If called with force_list=('interface',), it will produce
232 |     this dictionary:
233 |         {'servers':
234 |             {'server':
235 |                 {'name': 'host1',
236 |                  'os': 'Linux'},
237 |                 'interfaces':
238 |                     {'interface':
239 |                         [ {'name': 'em0', 'ip_address': '10.0.0.1' } ] } } }
240 |     `force_list` can also be a callable that receives `path`, `key` and
241 |     `value`. This is helpful in cases where the logic that decides whether
242 |     a list should be forced is more complex.
243 |     """
244 |     handler = _DictSAXHandler(namespace_separator=namespace_separator,
245 |                               **kwargs)
246 |     if isinstance(xml_input, str):
247 |         if not encoding:
248 |             encoding = 'utf-8'
249 |         xml_input = xml_input.encode(encoding)
250 |     if not process_namespaces:
251 |         namespace_separator = None
252 |     parser = expat.ParserCreate(
253 |         encoding,
254 |         namespace_separator
255 |     )
256 |     try:
257 |         parser.ordered_attributes = True
258 |     except AttributeError:
259 |         # Jython's expat does not support ordered_attributes
260 |         pass
261 |     parser.StartNamespaceDeclHandler = handler.startNamespaceDecl
262 |     parser.StartElementHandler = handler.startElement
263 |     parser.EndElementHandler = handler.endElement
264 |     parser.CharacterDataHandler = handler.characters
265 |     parser.buffer_text = True
266 |     if disable_entities:
267 |         try:
268 |             # Attempt to disable DTD in Jython's expat parser (Xerces-J).
269 |             feature = "http://apache.org/xml/features/disallow-doctype-decl"
270 |             parser._reader.setFeature(feature, True)
271 |         except AttributeError:
272 |             # For CPython / expat parser.
273 |             # Anything not handled ends up here and entities aren't expanded.
274 |             parser.DefaultHandler = lambda x: None
275 |             # Expects an integer return; zero means failure -> expat.ExpatError.
276 | parser.ExternalEntityRefHandler = lambda *x: 1 277 | if hasattr(xml_input, 'read'): 278 | parser.ParseFile(xml_input) 279 | else: 280 | parser.Parse(xml_input, True) 281 | return handler.item 282 | 283 | 284 | def _process_namespace(name, namespaces, ns_sep=':', attr_prefix='@'): 285 | if not namespaces: 286 | return name 287 | try: 288 | ns, name = name.rsplit(ns_sep, 1) 289 | except ValueError: 290 | pass 291 | else: 292 | ns_res = namespaces.get(ns.strip(attr_prefix)) 293 | name = '{0}{1}{2}{3}'.format( 294 | attr_prefix if ns.startswith(attr_prefix) else '', 295 | ns_res, ns_sep, name) if ns_res else name 296 | return name 297 | 298 | 299 | def _emit(key, value, content_handler, 300 | attr_prefix='@', 301 | cdata_key='#text', 302 | depth=0, 303 | preprocessor=None, 304 | pretty=False, 305 | newl='\n', 306 | indent='\t', 307 | namespace_separator=':', 308 | namespaces=None, 309 | full_document=True): 310 | key = _process_namespace(key, namespaces, namespace_separator, attr_prefix) 311 | if preprocessor is not None: 312 | result = preprocessor(key, value) 313 | if result is None: 314 | return 315 | key, value = result 316 | if (not hasattr(value, '__iter__') 317 | or isinstance(value, str) 318 | or isinstance(value, dict)): 319 | value = [value] 320 | for index, v in enumerate(value): 321 | if full_document and depth == 0 and index > 0: 322 | raise ValueError('document with multiple roots') 323 | if v is None: 324 | v = OrderedDict() 325 | elif not isinstance(v, dict): 326 | v = str(v) 327 | if isinstance(v, str): 328 | v = OrderedDict(((cdata_key, v),)) 329 | cdata = None 330 | attrs = OrderedDict() 331 | children = [] 332 | for ik, iv in v.items(): 333 | if ik == cdata_key: 334 | cdata = iv 335 | continue 336 | if ik.startswith(attr_prefix): 337 | ik = _process_namespace(ik, namespaces, namespace_separator, 338 | attr_prefix) 339 | if ik == '@xmlns' and isinstance(iv, dict): 340 | for k, v in iv.items(): 341 | attr = 'xmlns{0}'.format(':{0}'.format(k) if k else '') 342 | attrs[attr] = str(v) 343 | continue 344 | if not isinstance(iv, str): 345 | iv = str(iv) 346 | attrs[ik[len(attr_prefix):]] = iv 347 | continue 348 | children.append((ik, iv)) 349 | if pretty: 350 | content_handler.ignorableWhitespace(depth * indent) 351 | content_handler.startElement(key, AttributesImpl(attrs)) 352 | if pretty and children: 353 | content_handler.ignorableWhitespace(newl) 354 | for child_key, child_value in children: 355 | _emit(child_key, child_value, content_handler, 356 | attr_prefix, cdata_key, depth+1, preprocessor, 357 | pretty, newl, indent, namespaces=namespaces, 358 | namespace_separator=namespace_separator) 359 | if cdata is not None: 360 | content_handler.characters(cdata) 361 | if pretty and children: 362 | content_handler.ignorableWhitespace(depth * indent) 363 | content_handler.endElement(key) 364 | if pretty and depth: 365 | content_handler.ignorableWhitespace(newl) 366 | 367 | 368 | def unparse(input_dict, output=None, encoding='utf-8', full_document=True, 369 | short_empty_elements=False, 370 | **kwargs): 371 | """Emit an XML document for the given `input_dict` (reverse of `parse`). 372 | The resulting XML document is returned as a string, but if `output` (a 373 | file-like object) is specified, it is written there instead. 374 | Dictionary keys prefixed with `attr_prefix` (default=`'@'`) are interpreted 375 | as XML node attributes, whereas keys equal to `cdata_key` 376 | (default=`'#text'`) are treated as character data. 
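    A tiny round-trip illustration (example added for clarity; the dict and
    output below are hand-worked, not taken from the original docstring)::

        >>> unparse({'a': {'@href': 'x', '#text': 'hi'}}, full_document=False)
        '<a href="x">hi</a>'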
377 | The `pretty` parameter (default=`False`) enables pretty-printing. In this 378 | mode, lines are terminated with `'\n'` and indented with `'\t'`, but this 379 | can be customized with the `newl` and `indent` parameters. 380 | """ 381 | if full_document and len(input_dict) != 1: 382 | raise ValueError('Document must have exactly one root.') 383 | must_return = False 384 | if output is None: 385 | output = StringIO() 386 | must_return = True 387 | if short_empty_elements: 388 | content_handler = XMLGenerator(output, encoding, True) 389 | else: 390 | content_handler = XMLGenerator(output, encoding) 391 | if full_document: 392 | content_handler.startDocument() 393 | for key, value in input_dict.items(): 394 | _emit(key, value, content_handler, full_document=full_document, 395 | **kwargs) 396 | if full_document: 397 | content_handler.endDocument() 398 | if must_return: 399 | value = output.getvalue() 400 | try: # pragma no cover 401 | value = value.decode(encoding) 402 | except AttributeError: # pragma no cover 403 | pass 404 | return value 405 | 406 | if __name__ == '__main__': # pragma: no cover 407 | import sys 408 | import marshal 409 | try: 410 | stdin = sys.stdin.buffer 411 | stdout = sys.stdout.buffer 412 | except AttributeError: 413 | stdin = sys.stdin 414 | stdout = sys.stdout 415 | 416 | (item_depth,) = sys.argv[1:] 417 | item_depth = int(item_depth) 418 | 419 | 420 | def handle_item(path, item): 421 | marshal.dump((path, item), stdout) 422 | return True 423 | 424 | try: 425 | root = parse(stdin, 426 | item_depth=item_depth, 427 | item_callback=handle_item, 428 | dict_constructor=dict) 429 | if item_depth == 0: 430 | handle_item([], root) 431 | except KeyboardInterrupt: 432 | pass --------------------------------------------------------------------------------
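
A minimal usage sketch for xml2dict.py (hedged example: the annotation snippet
and field names below are illustrative, not taken from the repo's data files):

    from xml2dict import parse, unparse

    xml = '<annotation><filename>hand.jpg</filename></annotation>'
    doc = parse(xml)                        # OrderedDict view of the document
    print(doc['annotation']['filename'])    # -> hand.jpg
    print(unparse(doc))                     # back to an XML string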