├── .gitignore ├── LICENSE ├── README.md ├── create_pb.py ├── data ├── create_tfrecords.py ├── prn_pipeline.ipynb ├── test_detector_pipeline.ipynb └── test_keypoint_pipeline.ipynb ├── detector ├── __init__.py ├── anchor_generator.py ├── backbones │ ├── __init__.py │ └── mobilenet_v1.py ├── box_predictor.py ├── constants.py ├── fpn.py ├── input_pipeline │ ├── __init__.py │ ├── color_augmentations.py │ ├── heatmap_creation.py │ ├── keypoints_detector_pipeline.py │ ├── person_detector_pipeline.py │ ├── prn_pipeline.py │ ├── random_crop.py │ └── random_rotation.py ├── keypoint_subnet.py ├── prn.py ├── retinanet.py ├── training_target_creation.py └── utils │ ├── __init__.py │ ├── box_utils.py │ ├── layer_utils.py │ └── nms.py ├── inference ├── detector.py ├── predict.ipynb └── utils.py ├── keypoints_model.py ├── metrics.py ├── person_detector_model.py ├── prn_model.py ├── train_keypoints.py ├── train_person_detector.py └── train_prn.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | __pycache__ 4 | *.pb 5 | 6 | models/ 7 | pretrained/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Dan Antoshchenko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiPoseNet in `tensorflow` (*work in progress :wrench:*) 2 | 3 | This is an implementation of [MultiPoseNet: Fast Multi-Person Pose Estimation using Pose Residual Network](https://arxiv.org/abs/1807.04067). 4 | 5 | ## How to use this 6 | 7 | 1. Download COCO dataset. 8 | 2. Run `create_tfrecords.py`. 9 | 3. Run `train_keypoints.py`. 10 | 4. Run `train_person_detector.py`. 11 | 5. Run `train_prn.py`. 12 | 6. Run `create_pb.py`. 13 | 14 | ## Requirements 15 | 1. tensorflow 1.15 16 | 2. Pillow 6.1, opencv-python 4.1 17 | 3. numpy 1.17, scipy 1.3 18 | 4. matplotlib 3.1, tqdm 4.36 19 | 5. 
[pycocotools](https://github.com/cocodataset/cocoapi/) 20 | -------------------------------------------------------------------------------- /create_pb.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.backbones import mobilenet_v1 3 | from detector import KeypointSubnet 4 | from detector import RetinaNet 5 | from detector import prn 6 | 7 | 8 | """ 9 | This code creates a .pb frozen inference graph. 10 | """ 11 | 12 | # result will be here 13 | PB_FILE_PATH = 'inference/model.pb' 14 | 15 | # must be an integer 16 | BATCH_SIZE = 1 17 | 18 | # used by PRN 19 | CROP_SIZE = [56, 36] 20 | 21 | # the .pb file will 22 | # output these tensors 23 | OUTPUT_NAMES = [ 24 | 'boxes', 'scores', 'num_boxes', 25 | 'keypoint_heatmaps', 'segmentation_masks', 26 | 'keypoint_scores', 'keypoint_positions' 27 | ] 28 | 29 | # parameters of the backbone 30 | # and of non maximum suppression 31 | PARAMS = { 32 | 'depth_multiplier': 1.0, 33 | 'score_threshold': 0.3, 34 | 'iou_threshold': 0.6, 35 | 'max_boxes': 25 36 | } 37 | 38 | # trained models 39 | KEYPOINTS_CHECKPOINT = 'models/run00/model.ckpt-200000' 40 | PERSON_DETECTOR_CHECKPOINT = 'models/run01/model.ckpt-150000' 41 | PRN_CHECKPOINT = 'models/run02/model.ckpt-200000' 42 | 43 | 44 | def create_full_graph(images, params): 45 | """ 46 | Batch size must be a static value. 47 | Image size must be divisible by 128. 48 | 49 | Arguments: 50 | images: a float tensor with shape [b, h, w, 3]. 51 | params: a dict. 52 | Returns: 53 | boxes: a float tensor with shape [b, max_boxes, 4], 54 | where max_boxes = max(num_boxes). 55 | scores: a float tensor with shape [b, max_boxes]. 56 | num_boxes: an int tensor with shape [b]. 57 | keypoint_heatmaps: a float tensor with shape [b, h / 4, w / 4, 17]. 58 | segmentation_masks: a float tensor with shape [b, h / 4, w / 4]. 59 | keypoint_scores: a float tensor with shape [total_num_boxes], 60 | where total_num_boxes = sum(num_boxes). 61 | keypoint_positions: a float tensor with shape [total_num_boxes, 17, 2]. 
62 | """ 63 | 64 | is_training = False 65 | backbone_features = mobilenet_v1(images, is_training, params['depth_multiplier']) 66 | 67 | with tf.variable_scope('keypoint_subnet'): 68 | subnet = KeypointSubnet(backbone_features, is_training, params) 69 | 70 | with tf.variable_scope('retinanet'): 71 | retinanet = RetinaNet(backbone_features, tf.shape(images), is_training, params) 72 | 73 | predictions = { 74 | 'keypoint_heatmaps': tf.sigmoid(subnet.heatmaps[:, :, :, :17]), 75 | 'segmentation_masks': subnet.heatmaps[:, :, :, 17] 76 | } 77 | predictions.update(retinanet.get_predictions( 78 | score_threshold=params['score_threshold'], 79 | iou_threshold=params['iou_threshold'], 80 | max_detections=params['max_boxes'] 81 | )) 82 | 83 | batch_size = images.shape[0].value 84 | assert batch_size is not None 85 | 86 | heatmaps = predictions['keypoint_heatmaps'] # shape [b, h / 4, w / 4, 17] 87 | predicted_boxes = predictions['boxes'] # shape [b, max_boxes, 4] 88 | num_boxes = predictions['num_boxes'] # shape [b] 89 | 90 | M = tf.reduce_max(heatmaps, [1, 2], keepdims=True) 91 | mask = tf.to_float(M > 0.2) 92 | m = tf.reduce_min(heatmaps, [1, 2], keepdims=True) 93 | heatmaps = (heatmaps - m)/(M - m) 94 | heatmaps *= mask 95 | 96 | boxes, box_ind = [], [] 97 | for i in range(batch_size): 98 | n = num_boxes[i] 99 | boxes.append(predicted_boxes[i][:n]) 100 | box_ind.append(i * tf.ones([n], dtype=tf.int32)) 101 | 102 | boxes = tf.concat(boxes, axis=0) # shape [num_boxes, 4] 103 | box_ind = tf.concat(box_ind, axis=0) # shape [num_boxes] 104 | # where num_boxes is equal to sum(num_boxes) 105 | 106 | crops = tf.image.crop_and_resize( 107 | heatmaps, boxes, box_ind, 108 | crop_size=CROP_SIZE 109 | ) # shape [num_boxes, 56, 36, 17] 110 | 111 | num_boxes = tf.shape(crops)[0] 112 | logits = prn(crops, is_training) 113 | # it has shape [num_boxes, 56, 36, 17] 114 | 115 | H, W = CROP_SIZE 116 | logits = tf.reshape(logits, [num_boxes, H * W, 17]) 117 | probabilities = tf.nn.softmax(logits, axis=1) 118 | probabilities = tf.reshape(probabilities, [num_boxes, H, W, 17]) 119 | 120 | def argmax_2d(x): 121 | """ 122 | Arguments: 123 | x: a tensor with shape [b, h, w, c]. 124 | Returns: 125 | an int tensor with shape [b, c, 2]. 
126 | """ 127 | shape = tf.unstack(tf.shape(x)) 128 | b, h, w, c = shape 129 | 130 | flat_x = tf.reshape(x, [b, h * w, c]) 131 | argmax = tf.argmax(flat_x, axis=1, output_type=tf.int32) 132 | 133 | argmax_y = argmax // w 134 | argmax_x = argmax % w 135 | 136 | return tf.stack([argmax_y, argmax_x], axis=2) 137 | 138 | keypoint_scores = tf.reduce_max(probabilities, axis=[1, 2]) # shape [num_boxes, 17] 139 | keypoint_positions = tf.to_float(argmax_2d(probabilities)) # shape [num_boxes, 17, 2] 140 | 141 | scaler = tf.stack(CROP_SIZE, axis=0) 142 | keypoint_positions /= tf.to_float(scaler) 143 | 144 | predictions.update({ 145 | 'keypoint_scores': keypoint_scores, 146 | 'keypoint_positions': keypoint_positions 147 | }) 148 | 149 | predictions = { 150 | n: tf.identity(predictions[n], name=n) 151 | for n in OUTPUT_NAMES 152 | } 153 | return predictions 154 | 155 | 156 | def convert_to_pb(): 157 | 158 | tf.logging.set_verbosity('INFO') 159 | graph = tf.Graph() 160 | config = tf.ConfigProto() 161 | config.gpu_options.visible_device_list = '0' 162 | 163 | with graph.as_default(): 164 | 165 | shape = [BATCH_SIZE, None, None, 3] 166 | raw_images = tf.placeholder(dtype=tf.uint8, shape=shape, name='images') 167 | images = (1.0/255.0) * tf.to_float(raw_images) 168 | predictions = create_full_graph(images, PARAMS) 169 | 170 | tf.train.init_from_checkpoint( 171 | KEYPOINTS_CHECKPOINT, 172 | {'MobilenetV1/': 'MobilenetV1/'} 173 | ) 174 | tf.train.init_from_checkpoint( 175 | KEYPOINTS_CHECKPOINT, 176 | {'/': 'keypoint_subnet/'} 177 | ) 178 | tf.train.init_from_checkpoint( 179 | PERSON_DETECTOR_CHECKPOINT, 180 | {'/': 'retinanet/'} 181 | ) 182 | tf.train.init_from_checkpoint( 183 | PRN_CHECKPOINT, 184 | {'PRN/': 'PRN/'} 185 | ) 186 | init = tf.global_variables_initializer() 187 | 188 | with tf.Session(config=config) as sess: 189 | sess.run(init) 190 | 191 | input_graph_def = tf.graph_util.convert_variables_to_constants( 192 | sess, graph.as_graph_def(), 193 | output_node_names=OUTPUT_NAMES 194 | ) 195 | 196 | nms_nodes = [n.name for n in input_graph_def.node if 'nms' in n.name] 197 | output_graph_def = tf.graph_util.remove_training_nodes( 198 | input_graph_def, protected_nodes=OUTPUT_NAMES + nms_nodes 199 | ) 200 | 201 | with tf.gfile.GFile(PB_FILE_PATH, 'wb') as f: 202 | f.write(output_graph_def.SerializeToString()) 203 | print('%d ops in the final graph.' % len(output_graph_def.node)) 204 | 205 | 206 | convert_to_pb() 207 | -------------------------------------------------------------------------------- /data/create_tfrecords.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import math 4 | import random 5 | import shutil 6 | import cv2 7 | import numpy as np 8 | from tqdm import tqdm 9 | from PIL import Image 10 | import tensorflow.compat.v1 as tf 11 | from pycocotools.coco import COCO 12 | 13 | 14 | """ 15 | This script creates training and validation data for 16 | person detection and keypoints heatmap regression. 17 | 18 | Just run: 19 | python create_tfrecords.py 20 | 21 | And don't forget to set the right paths below. 
22 | """ 23 | 24 | 25 | # paths to downloaded data 26 | IMAGES_DIR = '/home/dan/datasets/COCO/images/' 27 | # (it contains folders train2017 and val2017) 28 | ANNOTATIONS_DIR = '/home/dan/datasets/COCO/annotations/' 29 | # (it contains files person_keypoints_*.json) 30 | 31 | # path where the converted data will be stored 32 | RESULT_PATH = '/home/dan/datasets/COCO/multiposenet/' 33 | 34 | # because dataset is big we will split it into parts 35 | NUM_TRAIN_SHARDS = 300 36 | NUM_VAL_SHARDS = 1 37 | 38 | # all masks (for segmentation and for loss) are reduced in size 39 | DOWNSAMPLE = 4 40 | 41 | # we don't use poorly visible persons 42 | MIN_NUM_KEYPOINTS = 2 43 | MIN_BOX_SIDE = 5 44 | 45 | 46 | def to_tf_example(image_path, annotations, coco): 47 | """ 48 | Arguments: 49 | image_path: a string. 50 | annotations: a list of dicts. 51 | coco: an instance of COCO. 52 | Returns: 53 | an instance of tf.train.Example. 54 | """ 55 | 56 | with tf.gfile.GFile(image_path, 'rb') as f: 57 | encoded_jpg = f.read() 58 | 59 | # check image format 60 | image = Image.open(io.BytesIO(encoded_jpg)) 61 | if not image.format == 'JPEG': 62 | return None 63 | 64 | width, height = image.size 65 | if image.mode == 'L': # if grayscale 66 | rgb_image = np.stack(3*[np.array(image)], axis=2) 67 | encoded_jpg = to_jpeg_bytes(rgb_image) 68 | image = Image.open(io.BytesIO(encoded_jpg)) 69 | assert image.mode == 'RGB' 70 | assert width > 0 and height > 0 71 | 72 | # whether to use a pixel for computing the loss 73 | loss_mask = np.ones((height, width), dtype='bool') 74 | 75 | # whether there is a person on a pixel 76 | segmentation_mask = np.zeros((height, width), dtype='bool') 77 | 78 | boxes, keypoints = [], [] 79 | for a in annotations: 80 | 81 | xmin, ymin, w, h = a['bbox'] 82 | xmax, ymax = xmin + w, ymin + h 83 | 84 | # sometimes boxes go over the edges 85 | ymin = np.clip(ymin, 0.0, height) 86 | xmin = np.clip(xmin, 0.0, width) 87 | ymax = np.clip(ymax, 0.0, height) 88 | xmax = np.clip(xmax, 0.0, width) 89 | 90 | ymin, ymax = min(ymin, ymax), max(ymin, ymax) 91 | xmin, xmax = min(xmin, xmax), max(xmin, xmax) 92 | 93 | h, w = ymax - ymin, xmax - xmin 94 | 95 | # do not add barely visible people, 96 | # do not add small boxes 97 | is_bad = (a['num_keypoints'] < MIN_NUM_KEYPOINTS) or\ 98 | (h < MIN_BOX_SIDE) or (w < MIN_BOX_SIDE) 99 | 100 | if is_bad: 101 | unannotated_person_mask = coco.annToMask(a) 102 | use_this = unannotated_person_mask == 0 103 | loss_mask = np.logical_and(use_this, loss_mask) 104 | continue 105 | 106 | person_mask = coco.annToMask(a) 107 | segmentation_mask = np.logical_or(person_mask == 1, segmentation_mask) 108 | 109 | points = np.array(a['keypoints'], dtype='int64').reshape(17, 3) 110 | x, y, v = np.split(points, 3, axis=1) 111 | x = np.clip(x, 0, width - 1) 112 | y = np.clip(y, 0, height - 1) 113 | points = np.stack([y, x, v], axis=1) # note the change (x, y) -> (y, x) 114 | 115 | boxes.append((ymin, xmin, ymax, xmax)) 116 | keypoints.append(points) 117 | 118 | # every image must have boxes 119 | if len(boxes) < 1: 120 | return None 121 | 122 | num_persons = len(boxes) 123 | boxes = np.array(boxes, dtype='float32') 124 | keypoints = np.stack(keypoints, axis=0).astype('int64') 125 | 126 | # downsample and encode masks 127 | masks_width, masks_height = math.ceil(width/DOWNSAMPLE), math.ceil(height/DOWNSAMPLE) 128 | masks = np.stack([loss_mask, segmentation_mask], axis=2) 129 | masks = masks.astype('uint8') 130 | masks = cv2.resize(masks, (masks_width, masks_height), interpolation=cv2.INTER_LANCZOS4) 131 | 
masks = np.packbits(masks > 0) 132 | # we use `ceil` because of the 'SAME' padding 133 | 134 | example = tf.train.Example(features=tf.train.Features(feature={ 135 | 'image': _bytes_feature(encoded_jpg), 136 | 'num_persons': _int64_feature(num_persons), 137 | 'boxes': _float_list_feature(boxes.reshape(-1)), 138 | 'keypoints': _int64_list_feature(keypoints.reshape(-1)), 139 | 'masks': _bytes_feature(masks.tostring()) 140 | })) 141 | return example 142 | 143 | 144 | def _bytes_feature(value): 145 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 146 | 147 | 148 | def _float_list_feature(value): 149 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 150 | 151 | 152 | def _int64_feature(value): 153 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 154 | 155 | 156 | def _int64_list_feature(value): 157 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 158 | 159 | 160 | def to_jpeg_bytes(array): 161 | image = Image.fromarray(array) 162 | tmp = io.BytesIO() 163 | image.save(tmp, format='jpeg') 164 | return tmp.getvalue() 165 | 166 | 167 | def convert(coco, image_dir, result_path, num_shards): 168 | 169 | # get all images with people 170 | catIds = coco.getCatIds(catNms=['person']) 171 | examples_list = coco.getImgIds(catIds=catIds) 172 | 173 | shutil.rmtree(result_path, ignore_errors=True) 174 | os.mkdir(result_path) 175 | 176 | # randomize image order 177 | random.shuffle(examples_list) 178 | num_examples = len(examples_list) 179 | print('Number of images:', num_examples) 180 | 181 | shard_size = math.ceil(num_examples/num_shards) 182 | print('Number of images per shard:', shard_size) 183 | 184 | shard_id = 0 185 | num_examples_written = 0 186 | num_skipped_images = 0 187 | for example in tqdm(examples_list): 188 | 189 | if num_examples_written == 0: 190 | shard_path = os.path.join(result_path, 'shard-%04d.tfrecords' % shard_id) 191 | if not os.path.exists(shard_path): 192 | writer = tf.python_io.TFRecordWriter(shard_path) 193 | 194 | image_metadata = coco.loadImgs(example)[0] 195 | image_path = os.path.join(image_dir, image_metadata['file_name']) 196 | annIds = coco.getAnnIds(imgIds=image_metadata['id'], catIds=catIds, iscrowd=None) 197 | annotations = coco.loadAnns(annIds) 198 | 199 | tf_example = to_tf_example(image_path, annotations, coco) 200 | if tf_example is None: 201 | num_skipped_images += 1 202 | continue 203 | writer.write(tf_example.SerializeToString()) 204 | num_examples_written += 1 205 | 206 | if num_examples_written == shard_size: 207 | shard_id += 1 208 | num_examples_written = 0 209 | writer.close() 210 | 211 | if num_examples_written != 0: 212 | shard_id += 1 213 | writer.close() 214 | 215 | print('Number of skipped images:', num_skipped_images) 216 | print('Number of shards:', shard_id) 217 | print('Result is here:', result_path, '\n') 218 | 219 | 220 | shutil.rmtree(RESULT_PATH, ignore_errors=True) 221 | os.mkdir(RESULT_PATH) 222 | 223 | coco = COCO(os.path.join(ANNOTATIONS_DIR, 'person_keypoints_train2017.json')) 224 | image_dir = os.path.join(IMAGES_DIR, 'train2017') 225 | result_path = os.path.join(RESULT_PATH, 'train') 226 | convert(coco, image_dir, result_path, NUM_TRAIN_SHARDS) 227 | 228 | coco = COCO(os.path.join(ANNOTATIONS_DIR, 'person_keypoints_val2017.json')) 229 | image_dir = os.path.join(IMAGES_DIR, 'val2017') 230 | result_path = os.path.join(RESULT_PATH, 'val') 231 | convert(coco, image_dir, result_path, NUM_VAL_SHARDS) 232 | 
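For reference, here is a minimal sketch (not part of the repository) of how records written by this script could be read back in TensorFlow 1.x. The feature spec below simply mirrors the keys serialized in `to_tf_example`; the `parse_example` helper and the shard filename are placeholders for illustration, and the actual training input pipelines live under `detector/input_pipeline/`.

import tensorflow.compat.v1 as tf


def parse_example(serialized):
    # keys mirror what to_tf_example writes
    features = {
        'image': tf.FixedLenFeature([], tf.string),
        'num_persons': tf.FixedLenFeature([], tf.int64),
        'boxes': tf.VarLenFeature(tf.float32),
        'keypoints': tf.VarLenFeature(tf.int64),
        'masks': tf.FixedLenFeature([], tf.string)
    }
    parsed = tf.parse_single_example(serialized, features)

    image = tf.image.decode_jpeg(parsed['image'], channels=3)
    num_persons = tf.to_int32(parsed['num_persons'])

    # absolute (ymin, xmin, ymax, xmax) boxes
    boxes = tf.reshape(tf.sparse.to_dense(parsed['boxes']), [-1, 4])

    # (y, x, visibility) triplets for 17 keypoints per person
    keypoints = tf.reshape(tf.sparse.to_dense(parsed['keypoints']), [-1, 17, 3])

    # loss and segmentation masks are still bit-packed by np.packbits;
    # they need to be unpacked and reshaped to [height/4, width/4, 2]
    packed_masks = tf.decode_raw(parsed['masks'], tf.uint8)

    return image, num_persons, boxes, keypoints, packed_masks


dataset = tf.data.TFRecordDataset(['shard-0000.tfrecords']).map(parse_example)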
-------------------------------------------------------------------------------- /data/prn_pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import tensorflow.compat.v1 as tf\n", 21 | "from PIL import Image, ImageDraw\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from matplotlib.colors import ListedColormap\n", 24 | "\n", 25 | "import sys\n", 26 | "sys.path.append('..')\n", 27 | "from detector.input_pipeline import PoseResidualNetworkPipeline" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "cmap = plt.cm.get_cmap('autumn')\n", 37 | "new_cmap = cmap(np.arange(cmap.N))\n", 38 | "new_cmap[:, -1] = np.sqrt(np.linspace(0, 1, cmap.N)) # set alpha\n", 39 | "cmap = ListedColormap(new_cmap) # create new colormap" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "# Build a graph" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": { 53 | "scrolled": false 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "tf.reset_default_graph()\n", 58 | "\n", 59 | "files = [\n", 60 | " '/home/dan/datasets/COCO/multiposenet/train/shard-0001.tfrecords',\n", 61 | " '/home/dan/datasets/COCO/multiposenet/train/shard-0002.tfrecords',\n", 62 | " '/home/dan/datasets/COCO/multiposenet/train/shard-0003.tfrecords',\n", 63 | "]\n", 64 | "\n", 65 | "with tf.device('/cpu:0'):\n", 66 | " pipeline = PoseResidualNetworkPipeline(files, is_training=True, batch_size=10, max_keypoints=None)\n", 67 | " dataset = pipeline.dataset\n", 68 | " iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)\n", 69 | " init = iterator.make_initializer(dataset)\n", 70 | " features, labels = iterator.get_next()\n", 71 | "\n", 72 | "print(features, labels)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# Show an image" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "ORDER = {\n", 89 | " 0: 'nose',\n", 90 | " 1: 'left eye', 2: 'right eye',\n", 91 | " 3: 'left ear', 4: 'right ear',\n", 92 | " 5: 'left shoulder', 6: 'right shoulder',\n", 93 | " 7: 'left elbow', 8: 'right elbow',\n", 94 | " 9: 'left wrist', 10: 'right wrist',\n", 95 | " 11: 'left hip', 12: 'right hip',\n", 96 | " 13: 'left knee', 14: 'right knee',\n", 97 | " 15: 'left ankle', 16: 'right ankle'\n", 98 | "}\n", 99 | "\n", 100 | "\n", 101 | "EDGES = [\n", 102 | " (0, 1), (0, 2),\n", 103 | " (1, 3), (2, 4),\n", 104 | " (5, 7), (7, 9), (6, 8), (8, 10),\n", 105 | " (11, 13), (13, 15), (12, 14), (14, 16),\n", 106 | " (3, 5), (4, 6),\n", 107 | " (5, 11), (6, 12)\n", 108 | "]\n", 109 | "\n", 110 | "\n", 111 | "def get_keypoints(heatmaps, box, threshold):\n", 112 | " \"\"\"\n", 113 | " Arguments:\n", 114 | " heatmaps: a numpy float array with shape [h, w, 17].\n", 115 | " box: a numpy array with shape [4].\n", 116 | " threshold: a float number.\n", 117 | " Returns:\n", 118 | " a numpy int array with shape [17, 3].\n", 119 | " \"\"\"\n", 120 | " 
keypoints = np.zeros([17, 3], dtype='int32')\n", 121 | "\n", 122 | " ymin, xmin, ymax, xmax = box\n", 123 | " height, width = ymax - ymin, xmax - xmin\n", 124 | " h, w, _ = heatmaps.shape\n", 125 | "\n", 126 | " for j in range(17):\n", 127 | " mask = heatmaps[:, :, j]\n", 128 | " if mask.max() > threshold:\n", 129 | " y, x = np.unravel_index(mask.argmax(), mask.shape)\n", 130 | " y = np.clip(int(y * height/h), 0, height)\n", 131 | " x = np.clip(int(x * width/w), 0, width)\n", 132 | " keypoints[j] = np.array([x, y, 1])\n", 133 | "\n", 134 | " return keypoints\n", 135 | "\n", 136 | "\n", 137 | "def draw_pose(draw, keypoints, box):\n", 138 | " \"\"\"\n", 139 | " Arguments:\n", 140 | " draw: an instance of ImageDraw.Draw.\n", 141 | " keypoints: a numpy int array with shape [17, 3].\n", 142 | " box: a numpy int array with shape [4].\n", 143 | " \"\"\"\n", 144 | " ymin, xmin, ymax, xmax = box\n", 145 | " keypoints = keypoints.copy()\n", 146 | " keypoints += np.array([xmin, ymin, 0])\n", 147 | "\n", 148 | " for (p, q) in EDGES:\n", 149 | "\n", 150 | " x1, y1, v1 = keypoints[p]\n", 151 | " x2, y2, v2 = keypoints[q]\n", 152 | "\n", 153 | " both_visible = v1 > 0 and v2 > 0\n", 154 | " if both_visible:\n", 155 | " draw.line([(x1, y1), (x2, y2)], fill='red')\n", 156 | "\n", 157 | " for j in range(17):\n", 158 | " x, y, v = keypoints[j]\n", 159 | " if v > 0:\n", 160 | " s = 1\n", 161 | " draw.ellipse([\n", 162 | " (x - s, y - s),\n", 163 | " (x + s, y + s)\n", 164 | " ], fill='red')" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "scrolled": false 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "with tf.Session() as sess:\n", 176 | " sess.run(init)\n", 177 | " output = sess.run([features, labels])" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "scrolled": false 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "i = 0\n", 189 | "w, h = 36, 56\n", 190 | "w, h = w * 5, h * 5\n", 191 | "\n", 192 | "background = Image.new('RGBA', (w * 2, h * 17), (255, 255, 255, 255))\n", 193 | "draw = ImageDraw.Draw(background, 'RGBA')\n", 194 | "\n", 195 | "keypoints = get_keypoints(output[1][i], box=(0, 0, h, w), threshold=0.9)\n", 196 | "heatmaps = (cmap(output[0][i]) * 255).astype('uint8')\n", 197 | "binary_masks = (output[1][i] * 255).astype('uint8')\n", 198 | "\n", 199 | "for j, name in ORDER.items():\n", 200 | "\n", 201 | " heat = Image.fromarray(heatmaps[:, :, j]).resize((w, h))\n", 202 | " mask = Image.fromarray(binary_masks[:, :, j]).resize((w, h))\n", 203 | "\n", 204 | " background.paste(heat, (0, j*h))\n", 205 | " background.paste(mask, (w, j*h))\n", 206 | " draw.text((0, j*h), name, fill='red')\n", 207 | " draw_pose(draw, keypoints, box=(j*h, w, j*h + h, 2 * w))\n", 208 | " \n", 209 | "background" 210 | ] 211 | } 212 | ], 213 | "metadata": { 214 | "kernelspec": { 215 | "display_name": "Python 3", 216 | "language": "python", 217 | "name": "python3" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 228 | "pygments_lexer": "ipython3", 229 | "version": "3.6.9" 230 | } 231 | }, 232 | "nbformat": 4, 233 | "nbformat_minor": 2 234 | } 235 | -------------------------------------------------------------------------------- /data/test_detector_pipeline.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import tensorflow.compat.v1 as tf\n", 20 | "from PIL import Image, ImageDraw\n", 21 | "from tqdm import tqdm\n", 22 | "import numpy as np\n", 23 | "import cv2\n", 24 | "import os\n", 25 | "import math\n", 26 | "import time\n", 27 | "\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "\n", 31 | "import sys\n", 32 | "sys.path.append('..')\n", 33 | "from detector.input_pipeline import DetectorPipeline" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Build a graph" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "scrolled": false 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "tf.reset_default_graph()\n", 52 | "\n", 53 | "dataset_path = '/home/dan/datasets/COCO/multiposenet/train/'\n", 54 | "filenames = os.listdir(dataset_path)\n", 55 | "filenames = [n for n in filenames if n.endswith('.tfrecords')]\n", 56 | "filenames = [os.path.join(dataset_path, n) for n in sorted(filenames)]\n", 57 | "\n", 58 | "batch_size = 16\n", 59 | "params = {\n", 60 | " 'batch_size': batch_size, \n", 61 | " 'image_size': (640, 640), \n", 62 | " 'min_dimension': 640\n", 63 | "}\n", 64 | "\n", 65 | "pipeline = DetectorPipeline(filenames, is_training=True, params=params)\n", 66 | "dataset = pipeline.dataset\n", 67 | "iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)\n", 68 | "init = iterator.make_initializer(dataset)\n", 69 | "features, labels = iterator.get_next()\n", 70 | "features.update(labels)\n", 71 | "features" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Show an image" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def draw_boxes(image, boxes):\n", 88 | " \n", 89 | " image = Image.fromarray(image).copy()\n", 90 | " draw = ImageDraw.Draw(image, 'RGBA')\n", 91 | "\n", 92 | " for box in boxes:\n", 93 | " ymin, xmin, ymax, xmax = box\n", 94 | " fill = (255, 0, 0, 45)\n", 95 | " outline = 'red'\n", 96 | " draw.rectangle(\n", 97 | " [(xmin, ymin), (xmax, ymax)],\n", 98 | " fill=fill, outline=outline\n", 99 | " )\n", 100 | "\n", 101 | " return image" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": { 108 | "scrolled": false 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "with tf.Session() as sess:\n", 113 | " sess.run(init)\n", 114 | " output = sess.run(features)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": { 121 | "scrolled": false 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "i = 0\n", 126 | "image = (255.0*output['images'][i]).astype('uint8')\n", 127 | "boxes = output['boxes'][i].copy()\n", 128 | "\n", 129 | "num_boxes = output['num_boxes'][i]\n", 130 | "boxes = boxes[:num_boxes]\n", 131 | "\n", 132 | "h, w, _ = image.shape\n", 133 | "scaler = np.array([h, w, h, w], dtype='float32')\n", 134 | "boxes *= scaler\n", 135 | "\n", 136 | "draw_boxes(image, boxes)" 137 | ] 138 | }, 
139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "# Measure speed" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "times = []\n", 153 | "with tf.Session() as sess:\n", 154 | " sess.run(init)\n", 155 | " for _ in range(105):\n", 156 | " start = time.perf_counter()\n", 157 | " output = sess.run(features)\n", 158 | " times.append(time.perf_counter() - start)\n", 159 | "\n", 160 | "times = np.array(times[5:])\n", 161 | "print(times.mean(), times.std())" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "# Measure box scale distribution" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": { 175 | "scrolled": true 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "# when running this don't forget \n", 180 | "# to set `image_size` very small (like 128)\n", 181 | "\n", 182 | "num_epochs = 10\n", 183 | "datasets_size = pipeline.num_examples\n", 184 | "num_batches_per_epoch = datasets_size // batch_size\n", 185 | "num_steps = num_epochs * num_batches_per_epoch\n", 186 | "\n", 187 | "result = []\n", 188 | "with tf.Session() as sess:\n", 189 | " sess.run(init) \n", 190 | " for _ in tqdm(range(num_steps)):\n", 191 | " output = sess.run(features)\n", 192 | " boxes = output['boxes']\n", 193 | " num_boxes = output['num_boxes']\n", 194 | " result.append((boxes, num_boxes))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "scales = []\n", 204 | "for b, n in result:\n", 205 | "\n", 206 | " ymin, xmin, ymax, xmax = np.split(b, 4, axis=2)\n", 207 | " h, w = ymax - ymin, xmax - xmin\n", 208 | " s = np.squeeze(np.sqrt(h * w), axis=2)\n", 209 | "\n", 210 | " for i in range(len(n)):\n", 211 | " scales.append(s[i][:n[i]])\n", 212 | "\n", 213 | "scales = np.concatenate(scales)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "plt.hist(scales, bins=100);" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.6.9" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 2 247 | } 248 | -------------------------------------------------------------------------------- /data/test_keypoint_pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import tensorflow.compat.v1 as tf\n", 20 | "from PIL import Image, ImageDraw\n", 21 | "import numpy as np\n", 22 | "import cv2\n", 23 | "import math\n", 24 | "import time\n", 25 | "\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from matplotlib.colors import ListedColormap\n", 28 | 
"\n", 29 | "import sys\n", 30 | "sys.path.append('..')\n", 31 | "from detector.input_pipeline import KeypointPipeline" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "cmap = plt.cm.get_cmap('autumn')\n", 41 | "new_cmap = cmap(np.arange(cmap.N))\n", 42 | "new_cmap[:, -1] = np.sqrt(np.linspace(0, 1, cmap.N)) # set alpha\n", 43 | "cmap = ListedColormap(new_cmap) # create new colormap" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# Build a graph" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "scrolled": false 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "tf.reset_default_graph()\n", 62 | "\n", 63 | "files = [\n", 64 | " '/home/dan/datasets/COCO/multiposenet/train/shard-0001.tfrecords',\n", 65 | " '/home/dan/datasets/COCO/multiposenet/train/shard-0002.tfrecords',\n", 66 | " '/home/dan/datasets/COCO/multiposenet/train/shard-0003.tfrecords'\n", 67 | "]\n", 68 | "params = {\n", 69 | " 'batch_size': 16, \n", 70 | " 'image_size': (512, 512), \n", 71 | " 'min_dimension': 512\n", 72 | "}\n", 73 | "\n", 74 | "pipeline = KeypointPipeline(files, is_training=True, params=params)\n", 75 | "dataset = pipeline.dataset\n", 76 | "iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)\n", 77 | "init = iterator.make_initializer(dataset)\n", 78 | "features, labels = iterator.get_next()\n", 79 | "features.update(labels)\n", 80 | "features" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Show an image" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "ORDER = {\n", 97 | " 0: 'nose',\n", 98 | " 1: 'left eye', 2: 'right eye',\n", 99 | " 3: 'left ear', 4: 'right ear',\n", 100 | " 5: 'left shoulder', 6: 'right shoulder',\n", 101 | " 7: 'left elbow', 8: 'right elbow',\n", 102 | " 9: 'left wrist', 10: 'right wrist',\n", 103 | " 11: 'left hip', 12: 'right hip',\n", 104 | " 13: 'left knee', 14: 'right knee',\n", 105 | " 15: 'left ankle', 16: 'right ankle'\n", 106 | "}\n", 107 | "\n", 108 | "\n", 109 | "def plot_maps(image, heatmaps, segmentation_mask, loss_mask):\n", 110 | " \"\"\"\n", 111 | " Arguments:\n", 112 | " image: a float numpy array with shape [h, w, 3].\n", 113 | " heatmaps: a float numpy array with shape [h / 4, w / 4, 17].\n", 114 | " segmentation_mask: a float numpy array with shape [h / 4, w / 4].\n", 115 | " loss_mask: a float numpy array with shape [h / 4, w / 4].\n", 116 | " \"\"\"\n", 117 | "\n", 118 | " h, w, _ = image.shape\n", 119 | " h, w = (h // 2), (w // 2)\n", 120 | " background = Image.new('RGBA', (w, h * 19), (255, 255, 255, 255))\n", 121 | " draw = ImageDraw.Draw(background, 'RGBA')\n", 122 | " \n", 123 | " image = (255 * image).astype('uint8')\n", 124 | " image = Image.fromarray(image)\n", 125 | " image = image.resize((w, h), Image.LANCZOS)\n", 126 | " image.putalpha(255)\n", 127 | "\n", 128 | " heatmaps = (255 * cmap(heatmaps)).astype('uint8')\n", 129 | " # it has shape [h, w, 17, 4]\n", 130 | " \n", 131 | " heats = []\n", 132 | " for j, name in ORDER.items():\n", 133 | "\n", 134 | " heat = Image.fromarray(heatmaps[:, :, j])\n", 135 | " heat = heat.resize((w, h), Image.LANCZOS)\n", 136 | " heat = Image.alpha_composite(image, heat)\n", 137 | " background.paste(heat, (0, j * h))\n", 138 | " draw.text((0, j * h), name, 
fill='red')\n", 139 | " \n", 140 | " def draw_mask(mask):\n", 141 | " mask = np.clip(mask, 0.0, 1.0)\n", 142 | " mask = (255 * mask).astype('uint8')\n", 143 | " mask = Image.fromarray(mask)\n", 144 | " mask = mask.resize((w, h), Image.LANCZOS).convert('RGB')\n", 145 | " mask.putalpha(mask.convert('L'))\n", 146 | " mask = Image.alpha_composite(image, mask)\n", 147 | " return mask\n", 148 | " \n", 149 | " mask = draw_mask(segmentation_mask)\n", 150 | " background.paste(mask, (0, 17 * h))\n", 151 | " draw.text((0, 17 * h), 'segmentation mask', fill='red')\n", 152 | " \n", 153 | " mask = draw_mask(loss_mask)\n", 154 | " background.paste(mask, (0, 18 * h))\n", 155 | " draw.text((0, 18 * h), 'loss mask', fill='red')\n", 156 | "\n", 157 | " return background" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": { 164 | "scrolled": false 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "with tf.Session() as sess:\n", 169 | " sess.run(init)\n", 170 | " output = sess.run(features)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "scrolled": false 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "i = 10\n", 182 | "image = output['images'][i]\n", 183 | "heatmaps = output['heatmaps'][i]\n", 184 | "segmentation_mask = output['segmentation_masks'][i]\n", 185 | "loss_mask = output['loss_masks'][i]\n", 186 | "\n", 187 | "plot_maps(image, heatmaps, segmentation_mask, loss_mask)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "# Measure speed" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "scrolled": true 202 | }, 203 | "outputs": [], 204 | "source": [ 205 | "times = []\n", 206 | "with tf.Session() as sess:\n", 207 | " sess.run(init)\n", 208 | " for _ in range(105):\n", 209 | " start = time.perf_counter()\n", 210 | " output = sess.run(features)\n", 211 | " times.append(time.perf_counter() - start)\n", 212 | "\n", 213 | "times = np.array(times[5:])\n", 214 | "print(times.mean(), times.std())" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "Python 3", 221 | "language": "python", 222 | "name": "python3" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.6.9" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 2 239 | } 240 | -------------------------------------------------------------------------------- /detector/__init__.py: -------------------------------------------------------------------------------- 1 | from .retinanet import RetinaNet 2 | from .keypoint_subnet import KeypointSubnet 3 | from .prn import prn 4 | -------------------------------------------------------------------------------- /detector/anchor_generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import itertools 3 | 4 | 5 | """ 6 | Note that for FPN it is required that image height and image width 7 | are divisible by maximal feature stride (by default it is 128). 8 | It is because of all upsampling layers. 
9 | """ 10 | 11 | 12 | class AnchorGenerator: 13 | def __init__(self, strides=[8, 16, 32, 64, 128], 14 | scales=[32, 64, 128, 256, 512], 15 | scale_multipliers=[1.0, 1.4142], 16 | aspect_ratios=[1.0, 2.0, 0.5]): 17 | """ 18 | The number of scales and strides must 19 | be equal to the number of feature maps. 20 | 21 | Note that 1.4142 is equal to sqrt(2). 22 | 23 | So, the number of anchors on each feature map is: 24 | w * h * len(aspect_ratios) * len(scale_multipliers), 25 | where (w, h) is the spatial size of the feature map. 26 | 27 | Arguments: 28 | strides: a list of integers, the feature strides. 29 | scales: a list of integers, a main scale for each feature map. 30 | scale_multipliers: a list of floats, a factors for a main scale. 31 | aspect_ratios: a list of float numbers, aspect ratios to place on each grid point. 32 | """ 33 | assert len(strides) == len(scales) 34 | self.strides = strides 35 | self.scales = scales 36 | self.scale_multipliers = scale_multipliers 37 | self.aspect_ratios = aspect_ratios 38 | self.num_anchors_per_location = len(aspect_ratios) * len(scale_multipliers) 39 | 40 | def __call__(self, image_height, image_width): 41 | """ 42 | Note that we don't need to pass feature map shapes 43 | because we use only 'SAME' padding in all our networks. 44 | 45 | Arguments: 46 | image_height, image_width: scalar int tensors. 47 | Returns: 48 | a float tensor with shape [num_anchors, 4], 49 | boxes with normalized coordinates. 50 | """ 51 | with tf.name_scope('anchor_generator'): 52 | 53 | image_height = tf.to_float(image_height) 54 | image_width = tf.to_float(image_width) 55 | 56 | feature_map_info = [] 57 | num_anchors_per_feature_map = [] 58 | for stride in self.strides: 59 | h = tf.to_int32(tf.ceil(image_height/stride)) 60 | w = tf.to_int32(tf.ceil(image_width/stride)) 61 | feature_map_info.append((stride, h, w)) 62 | num_anchors_per_feature_map.append(h * w * self.num_anchors_per_location) 63 | 64 | # these are needed elsewhere 65 | self.num_anchors_per_feature_map = num_anchors_per_feature_map 66 | 67 | anchors = [] 68 | 69 | # this is shared by all feature maps 70 | pairs = list(itertools.product(self.scale_multipliers, self.aspect_ratios)) 71 | aspect_ratios = tf.constant([a for _, a in pairs], dtype=tf.float32) 72 | 73 | for i, (stride, h, w) in enumerate(feature_map_info): 74 | 75 | scales = tf.constant([m * self.scales[i] for m, _ in pairs], dtype=tf.float32) 76 | stride = tf.constant(stride, dtype=tf.float32) 77 | 78 | """ 79 | It is true that 80 | image_height = h * stride - x, where 0 <= x < stride. 81 | 82 | Then image_height = (h - 1) * stride + (stride - x). 83 | So offset y must be equal to 0.5 * (stride - x). 84 | 85 | x = h * stride - image_height, 86 | y = 0.5 * (image_height - (h - 1) * stride), 87 | 0 < y <= 0.5 * stride. 88 | 89 | Offset y is maximal when image_height is divisible by stride. 90 | Offset y is minimal when image_height = k * stride + 1, where k is a positive integer. 
91 | """ 92 | offset_y = 0.5 * (image_height - (tf.to_float(h) - 1.0) * stride) 93 | offset_x = 0.5 * (image_width - (tf.to_float(w) - 1.0) * stride) 94 | 95 | anchors.append(tile_anchors( 96 | grid_height=h, grid_width=w, 97 | scales=scales, aspect_ratios=aspect_ratios, 98 | anchor_stride=(stride, stride), 99 | anchor_offset=(offset_y, offset_x) 100 | )) 101 | 102 | with tf.name_scope('concatenate_normalize'): 103 | 104 | # this is for visualization and debugging only 105 | self.raw_anchors = anchors 106 | 107 | anchors = tf.concat(anchors, axis=0) 108 | 109 | # convert to the [0, 1] range 110 | scaler = tf.to_float(tf.stack([ 111 | image_height, image_width, 112 | image_height, image_width 113 | ])) 114 | anchors /= scaler 115 | 116 | return anchors 117 | 118 | 119 | def tile_anchors( 120 | grid_height, grid_width, 121 | scales, aspect_ratios, 122 | anchor_stride, anchor_offset): 123 | """ 124 | It returns boxes in absolute coordinates. 125 | 126 | Arguments: 127 | grid_height: a scalar int tensor, size of the grid in the y direction. 128 | grid_width: a scalar int tensor, size of the grid in the x direction. 129 | scales: a float tensor with shape [N], 130 | it represents the scale of each box in the basis set. 131 | aspect_ratios: a float tensor with shape [N], 132 | it represents the aspect ratio of each box in the basis set. 133 | anchor_stride: a tuple of float scalar tensors, 134 | difference in centers between anchors for adjacent grid positions. 135 | anchor_offset: a tuple of float scalar tensors, 136 | center of the anchor on upper left element of the grid ((0, 0)-th anchor). 137 | Returns: 138 | a float tensor with shape [grid_height * grid_width * N, 4]. 139 | """ 140 | N = tf.size(scales) 141 | ratio_sqrts = tf.sqrt(aspect_ratios) 142 | heights = scales / ratio_sqrts 143 | widths = scales * ratio_sqrts 144 | # widths/heights = aspect_ratios, 145 | # and scales = sqrt(heights * widths) 146 | 147 | # get a grid of box centers 148 | y_centers = tf.to_float(tf.range(grid_height)) * anchor_stride[0] + anchor_offset[0] 149 | x_centers = tf.to_float(tf.range(grid_width)) * anchor_stride[1] + anchor_offset[1] 150 | x_centers, y_centers = tf.meshgrid(x_centers, y_centers) 151 | # they have shape [grid_height, grid_width] 152 | 153 | centers = tf.stack([y_centers, x_centers], axis=2) 154 | centers = tf.expand_dims(centers, 2) 155 | centers = tf.tile(centers, [1, 1, N, 1]) 156 | # shape [grid_height, grid_width, N, 2] 157 | 158 | sizes = tf.stack([heights, widths], axis=1) 159 | sizes = tf.expand_dims(tf.expand_dims(sizes, 0), 0) 160 | sizes = tf.tile(sizes, [grid_height, grid_width, 1, 1]) 161 | # shape [grid_height, grid_width, N, 2] 162 | 163 | boxes = tf.concat([centers - 0.5 * sizes, centers + 0.5 * sizes], axis=3) 164 | # it has shape [grid_height, grid_width, N, 4] 165 | boxes = tf.reshape(boxes, [-1, 4]) 166 | return boxes 167 | -------------------------------------------------------------------------------- /detector/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenet_v1 import mobilenet_v1 2 | -------------------------------------------------------------------------------- /detector/backbones/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import tensorflow.contrib as contrib 3 | import tensorflow.contrib.slim as slim 4 | from detector.constants import DATA_FORMAT 5 | 6 | 7 | BATCH_NORM_MOMENTUM = 0.95 8 | 
BATCH_NORM_EPSILON = 1e-3 9 | 10 | 11 | def mobilenet_v1(images, is_training, depth_multiplier=1.0): 12 | """ 13 | This implementation works with checkpoints from here: 14 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md 15 | 16 | Arguments: 17 | images: a float tensor with shape [b, h, w, 3], 18 | a batch of RGB images with pixel values in the range [0, 1]. 19 | is_training: a boolean. 20 | depth_multiplier: a float number, multiplier for the number of filters in a layer. 21 | Returns: 22 | a dict with four float tensors. 23 | """ 24 | 25 | def depth(x): 26 | """Reduce the number of filters in a layer.""" 27 | return max(int(x * depth_multiplier), 8) 28 | 29 | def batch_norm(x): 30 | x = tf.layers.batch_normalization( 31 | x, axis=1 if DATA_FORMAT == 'channels_first' else 3, 32 | center=True, scale=True, 33 | momentum=BATCH_NORM_MOMENTUM, 34 | epsilon=BATCH_NORM_EPSILON, 35 | training=is_training, fused=True, 36 | name='BatchNorm' 37 | ) 38 | return x 39 | 40 | with tf.name_scope('standardize_input'): 41 | x = (2.0 * images) - 1.0 42 | 43 | with tf.variable_scope('MobilenetV1'): 44 | params = { 45 | 'padding': 'SAME', 46 | 'activation_fn': tf.nn.relu6, 'normalizer_fn': batch_norm, 47 | 'data_format': 'NCHW' if DATA_FORMAT == 'channels_first' else 'NHWC' 48 | } 49 | with slim.arg_scope([slim.conv2d, depthwise_conv], **params): 50 | features = {} 51 | 52 | if DATA_FORMAT == 'channels_first': 53 | x = tf.transpose(x, [0, 3, 1, 2]) 54 | 55 | layer_name = 'Conv2d_0' 56 | x = slim.conv2d(x, depth(32), (3, 3), stride=2, scope=layer_name) 57 | features[layer_name] = x 58 | 59 | strides_and_filters = [ 60 | (1, 64), 61 | (2, 128), (1, 128), 62 | (2, 256), (1, 256), 63 | (2, 512), (1, 512), (1, 512), (1, 512), (1, 512), (1, 512), 64 | (2, 1024), (1, 1024) 65 | ] 66 | for i, (stride, num_filters) in enumerate(strides_and_filters, 1): 67 | 68 | layer_name = 'Conv2d_%d_depthwise' % i 69 | x = depthwise_conv(x, stride=stride, scope=layer_name) 70 | features[layer_name] = x 71 | 72 | layer_name = 'Conv2d_%d_pointwise' % i 73 | x = slim.conv2d(x, depth(num_filters), (1, 1), stride=1, scope=layer_name) 74 | features[layer_name] = x 75 | 76 | return { 77 | 'c2': features['Conv2d_3_pointwise'], 'c3': features['Conv2d_5_pointwise'], 78 | 'c4': features['Conv2d_11_pointwise'], 'c5': features['Conv2d_13_pointwise'] 79 | } 80 | 81 | 82 | @contrib.framework.add_arg_scope 83 | def depthwise_conv( 84 | x, kernel=3, stride=1, padding='SAME', 85 | activation_fn=None, normalizer_fn=None, 86 | data_format='NHWC', scope='depthwise_conv'): 87 | with tf.variable_scope(scope): 88 | 89 | if data_format == 'NHWC': 90 | in_channels = x.shape[3].value 91 | strides = [1, stride, stride, 1] 92 | else: 93 | in_channels = x.shape[1].value 94 | strides = [1, 1, stride, stride] 95 | 96 | W = tf.get_variable( 97 | 'depthwise_weights', 98 | [kernel, kernel, in_channels, 1], 99 | dtype=tf.float32 100 | ) 101 | x = tf.nn.depthwise_conv2d(x, W, strides, padding, data_format=data_format) 102 | x = normalizer_fn(x) if normalizer_fn is not None else x # batch normalization 103 | x = activation_fn(x) if activation_fn is not None else x # nonlinearity 104 | return x 105 | -------------------------------------------------------------------------------- /detector/box_predictor.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import DATA_FORMAT 3 | from detector.utils import batch_norm_relu, conv2d_same 4 | 
5 | 6 | def retinanet_box_predictor( 7 | image_features, is_training, 8 | num_anchors_per_location=6, 9 | depth=256, min_level=3): 10 | """ 11 | Adds box predictors to each feature map, 12 | reshapes, and returns concatenated results. 13 | 14 | Arguments: 15 | image_features: a list of float tensors where the ith tensor 16 | has shape [batch_size, channels_i, height_i, width_i]. 17 | is_training: a boolean. 18 | num_anchors_per_location, depth, min_level: integers. 19 | Returns: 20 | encoded_boxes: a float tensor with shape [batch_size, num_anchors, 4]. 21 | class_predictions: a float tensor with shape [batch_size, num_anchors]. 22 | """ 23 | 24 | encoded_boxes = [] 25 | class_predictions = [] 26 | 27 | """ 28 | The convolution layers in the box net are shared among all levels, but 29 | each level has its batch normalization to capture the statistical 30 | difference among different levels. The same for the class net. 31 | """ 32 | 33 | with tf.variable_scope('box_net', reuse=tf.AUTO_REUSE): 34 | for level, p in enumerate(image_features, min_level): 35 | encoded_boxes.append(box_net( 36 | p, is_training, depth, level, 37 | num_anchors_per_location 38 | )) 39 | 40 | with tf.variable_scope('class_net', reuse=tf.AUTO_REUSE): 41 | for level, p in enumerate(image_features, min_level): 42 | class_predictions.append(class_net( 43 | p, is_training, depth, level, 44 | num_anchors_per_location 45 | )) 46 | 47 | return reshape_and_concatenate( 48 | encoded_boxes, class_predictions, 49 | num_anchors_per_location 50 | ) 51 | 52 | 53 | def reshape_and_concatenate( 54 | encoded_boxes, class_predictions, 55 | num_anchors_per_location): 56 | 57 | # batch size is a static value 58 | # during training and evaluation 59 | batch_size = encoded_boxes[0].shape[0].value 60 | if batch_size is None: 61 | batch_size = tf.shape(encoded_boxes[0])[0] 62 | 63 | # it is important that reshaping here is the same as when anchors were generated 64 | with tf.name_scope('reshaping_and_concatenation'): 65 | for i in range(len(encoded_boxes)): 66 | 67 | # get spatial dimensions 68 | shape = tf.shape(encoded_boxes[i]) 69 | if DATA_FORMAT == 'channels_first': 70 | height_i, width_i = shape[2], shape[3] 71 | else: 72 | height_i, width_i = shape[1], shape[2] 73 | 74 | # total number of anchors 75 | num_anchors_on_feature_map = height_i * width_i * num_anchors_per_location 76 | 77 | y = encoded_boxes[i] 78 | y = tf.transpose(y, perm=[0, 2, 3, 1]) if DATA_FORMAT == 'channels_first' else y 79 | y = tf.reshape(y, [batch_size, height_i, width_i, num_anchors_per_location, 4]) 80 | encoded_boxes[i] = tf.reshape(y, [batch_size, num_anchors_on_feature_map, 4]) 81 | 82 | y = class_predictions[i] 83 | y = tf.transpose(y, perm=[0, 2, 3, 1]) if DATA_FORMAT == 'channels_first' else y 84 | y = tf.reshape(y, [batch_size, height_i, width_i, num_anchors_per_location]) 85 | class_predictions[i] = tf.reshape(y, [batch_size, num_anchors_on_feature_map]) 86 | 87 | encoded_boxes = tf.concat(encoded_boxes, axis=1) 88 | class_predictions = tf.concat(class_predictions, axis=1) 89 | 90 | return {'encoded_boxes': encoded_boxes, 'class_predictions': class_predictions} 91 | 92 | 93 | def class_net(x, is_training, depth, level, num_anchors_per_location): 94 | """ 95 | Arguments: 96 | x: a float tensor with shape [batch_size, depth, height, width]. 97 | is_training: a boolean. 98 | depth, level, num_anchors_per_location: integers. 99 | Returns: 100 | a float tensor with shape [batch_size, num_anchors_per_location, height, width]. 
101 | """ 102 | 103 | for i in range(4): 104 | x = conv2d_same(x, depth, kernel_size=3, name='conv3x3_%d' % i) 105 | x = batch_norm_relu(x, is_training, name='batch_norm_%d_for_level_%d' % (i, level)) 106 | 107 | import math 108 | p = 0.01 # probability of foreground 109 | # note that sigmoid(-log((1 - p) / p)) = p 110 | 111 | logits = tf.layers.conv2d( 112 | x, num_anchors_per_location, 113 | kernel_size=(3, 3), padding='same', 114 | bias_initializer=tf.constant_initializer(-math.log((1.0 - p) / p)), 115 | kernel_initializer=tf.random_normal_initializer(stddev=0.01), 116 | data_format=DATA_FORMAT, name='logits' 117 | ) 118 | return logits 119 | 120 | 121 | def box_net(x, is_training, depth, level, num_anchors_per_location): 122 | """ 123 | Arguments: 124 | x: a float tensor with shape [batch_size, depth, height, width]. 125 | is_training: a boolean. 126 | depth, level, num_anchors_per_location: integers. 127 | Returns: 128 | a float tensor with shape [batch_size, 4 * num_anchors_per_location, height, width]. 129 | """ 130 | 131 | for i in range(4): 132 | x = conv2d_same(x, depth, kernel_size=3, name='conv3x3_%d' % i) 133 | x = batch_norm_relu(x, is_training, name='batch_norm_%d_for_level_%d' % (i, level)) 134 | 135 | encoded_boxes = tf.layers.conv2d( 136 | x, 4 * num_anchors_per_location, 137 | kernel_size=(3, 3), padding='same', 138 | bias_initializer=tf.zeros_initializer(), 139 | kernel_initializer=tf.random_normal_initializer(stddev=0.01), 140 | data_format=DATA_FORMAT, name='encoded_boxes' 141 | ) 142 | return encoded_boxes 143 | -------------------------------------------------------------------------------- /detector/constants.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | 3 | # all image sizes must be divisible by this value 4 | DIVISOR = 128 5 | 6 | # 'channels_first' or 'channels_last' 7 | DATA_FORMAT = 'channels_first' 8 | 9 | # number of body landmarks that will be predicted 10 | NUM_KEYPOINTS = 17 11 | 12 | # all heatmaps and masks are downsampled 13 | DOWNSAMPLE = 4 14 | 15 | # a small value 16 | EPSILON = 1e-8 17 | 18 | # this is used when we are doing box encoding/decoding 19 | SCALE_FACTORS = [10.0, 10.0, 5.0, 5.0] 20 | 21 | # here are input pipeline settings, 22 | # you need to tweak these numbers for your system, 23 | # it can accelerate training 24 | SHUFFLE_BUFFER_SIZE = 10000 25 | NUM_PARALLEL_CALLS = 12 26 | 27 | # images are resized before feeding them to the network 28 | RESIZE_METHOD = tf.image.ResizeMethod.BILINEAR 29 | 30 | # thresholds for iou when creating training targets 31 | POSITIVES_THRESHOLD = 0.5 32 | NEGATIVES_THRESHOLD = 0.5 33 | 34 | # this is used in tf.map_fn when creating training targets or doing nms 35 | PARALLEL_ITERATIONS = 10 36 | 37 | # if overlap of a box with an image less than this value it is removed 38 | OVERLAP_THRESHOLD = 0.1 39 | -------------------------------------------------------------------------------- /detector/fpn.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.utils import conv2d_same, batch_norm_relu 3 | from detector.constants import DATA_FORMAT 4 | 5 | 6 | def feature_pyramid_network( 7 | features, is_training, depth, 8 | min_level=3, add_coarse_features=True, 9 | scope='fpn'): 10 | """ 11 | For person detector subnetwork we 12 | use min_level=3 and add_coarse_features=True 13 | (like in the original retinanet paper). 
14 | 15 | For keypoint detector subnetwork we 16 | use min_level=2 and add_coarse_features=False 17 | (like in the original multiposenet paper). 18 | 19 | Arguments: 20 | features: a dict with four float tensors. 21 | It must have keys ['c2', 'c3', 'c4', 'c5']. 22 | Where a number in a name means that 23 | a feature has stride `2 ** number`. 24 | is_training: a boolean. 25 | depth: an integer. 26 | min_level: an integer, minimal feature stride 27 | that will be used is `2 ** min_level`. 28 | Possible values are [2, 3, 4, 5] 29 | add_coarse_features: a boolean, whether to add 30 | features with strides 64 and 128. 31 | scope: a string. 32 | Returns: 33 | a dict with float tensors. 34 | """ 35 | 36 | with tf.variable_scope(scope): 37 | 38 | x = conv2d_same(features['c5'], depth, kernel_size=1, name='lateral5') 39 | p5 = conv2d_same(x, depth, kernel_size=3, name='p5') 40 | enriched_features = {'p5': p5} 41 | 42 | if add_coarse_features: 43 | p6 = conv2d_same(features['c5'], depth, kernel_size=3, stride=2, name='p6') 44 | pre_p7 = batch_norm_relu(p6, is_training, name='pre_p7_bn') 45 | p7 = conv2d_same(pre_p7, depth, kernel_size=3, stride=2, name='p7') 46 | enriched_features.update({'p6': p6, 'p7': p7}) 47 | 48 | # top-down path 49 | for i in reversed(range(min_level, 5)): 50 | lateral = conv2d_same(features[f'c{i}'], depth, kernel_size=1, name=f'lateral{i}') 51 | x = nearest_neighbor_upsample(x) + lateral 52 | p = conv2d_same(x, depth, kernel_size=3, name=f'p{i}') 53 | enriched_features[f'p{i}'] = p 54 | 55 | return enriched_features 56 | 57 | 58 | def nearest_neighbor_upsample(x): 59 | """ 60 | Arguments: 61 | x: a float tensor with shape [b, h, w, c]. 62 | Returns: 63 | a float tensor with shape [b, 2 * h, 2 * w, c]. 64 | """ 65 | 66 | if DATA_FORMAT == 'channels_first': 67 | x = tf.transpose(x, [0, 2, 3, 1]) 68 | 69 | shape = tf.shape(x) 70 | h, w = shape[1], shape[2] 71 | x = tf.image.resize_nearest_neighbor(x, [2 * h, 2 * w]) 72 | 73 | if DATA_FORMAT == 'channels_first': 74 | x = tf.transpose(x, [0, 3, 1, 2]) 75 | 76 | return x 77 | -------------------------------------------------------------------------------- /detector/input_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .keypoints_detector_pipeline import KeypointPipeline 2 | from .person_detector_pipeline import DetectorPipeline 3 | from .prn_pipeline import PoseResidualNetworkPipeline 4 | -------------------------------------------------------------------------------- /detector/input_pipeline/color_augmentations.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | 3 | 4 | """ 5 | `image` is assumed to be a float tensor with shape [height, width, 3]. 6 | It is a RGB image with pixel values in the range [0, 1]. 7 | """ 8 | 9 | 10 | def random_color_manipulations(image, probability=0.5, grayscale_probability=0.1): 11 | """ 12 | This function randomly changes color of an image. 
13 | 14 | It is taken from here: 15 | https://cloud.google.com/tpu/docs/inception-v3-advanced 16 | """ 17 | def manipulate(image): 18 | 19 | br_delta = tf.random_uniform([], -32.0/255.0, 32.0/255.0) 20 | cb_factor = tf.random_uniform([], -0.1, 0.1) 21 | cr_factor = tf.random_uniform([], -0.1, 0.1) 22 | 23 | red_offset = 1.402 * cr_factor + br_delta 24 | green_offset = -0.344136 * cb_factor - 0.714136 * cr_factor + br_delta 25 | blue_offset = 1.772 * cb_factor + br_delta 26 | 27 | channels = tf.split(axis=2, num_or_size_splits=3, value=image) 28 | channels[0] += red_offset 29 | channels[1] += green_offset 30 | channels[2] += blue_offset 31 | 32 | image = tf.concat(axis=2, values=channels) 33 | image = tf.clip_by_value(image, 0.0, 1.0) 34 | return image 35 | 36 | def to_grayscale(image): 37 | image = tf.image.rgb_to_grayscale(image) 38 | image = tf.image.grayscale_to_rgb(image) 39 | return image 40 | 41 | do_it = tf.less(tf.random_uniform([]), probability) 42 | image = tf.cond(do_it, lambda: manipulate(image), lambda: image) 43 | 44 | do_it = tf.less(tf.random_uniform([]), grayscale_probability) 45 | image = tf.cond(do_it, lambda: to_grayscale(image), lambda: image) 46 | 47 | return image 48 | 49 | 50 | def random_pixel_value_scale(image, probability=0.5, minval=0.9, maxval=1.1): 51 | """ 52 | This function scales each pixel 53 | independently of the other ones. 54 | 55 | Arguments: 56 | image: a float tensor with shape [height, width, 3], 57 | an image with pixel values varying between zero and one. 58 | probability: a float number. 59 | minval: a float number, lower ratio of scaling pixel values. 60 | maxval: a float number, upper ratio of scaling pixel values. 61 | Returns: 62 | a float tensor with shape [height, width, 3]. 63 | """ 64 | def random_value_scale(image): 65 | color_coefficient = tf.random_uniform( 66 | tf.shape(image), minval=minval, 67 | maxval=maxval, dtype=tf.float32 68 | ) 69 | image = tf.multiply(image, color_coefficient) 70 | image = tf.clip_by_value(image, 0.0, 1.0) 71 | return image 72 | 73 | do_it = tf.less(tf.random_uniform([]), probability) 74 | image = tf.cond(do_it, lambda: random_value_scale(image), lambda: image) 75 | return image 76 | -------------------------------------------------------------------------------- /detector/input_pipeline/heatmap_creation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from scipy import signal 4 | 5 | 6 | def get_heatmaps(keypoints, boxes, width, height, downsample): 7 | """ 8 | Arguments: 9 | keypoints: a numpy int array with shape [num_persons, 17, 3]. 10 | It is in format (y, x, visibility), 11 | where coordinates `y, x` are in the ranges 12 | [0, height - 1] and [0, width - 1]. 13 | And a keypoint is visible if `visibility > 0`. 14 | boxes: a numpy float array with shape [num_persons, 4], 15 | person bounding boxes in absolute coordinates. 16 | width, height: integers, size of the original image. 17 | downsample: an integer. 18 | Returns: 19 | a numpy float array with shape [height/downsample, width/downsample, 17]. 
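    A minimal usage sketch (shapes and values here are illustrative, not taken from this repository):

        keypoints = np.zeros([2, 17, 3], dtype=np.int32)  # two persons, no visible keypoints
        boxes = np.array([[10, 10, 200, 150], [50, 60, 300, 220]], dtype=np.float32)
        heatmaps = get_heatmaps(keypoints, boxes, width=320, height=256, downsample=4)
        # heatmaps.shape == (64, 80, 17), all zeros here because no keypoint has visibility > 0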
20 | """ 21 | 22 | min_sigma, max_sigma = 1.0, 4.0 23 | scaler = np.array([height - 1.0, width - 1.0], dtype=np.float32) 24 | keypoints = keypoints.astype(np.float32) 25 | 26 | # compute output size 27 | h = math.ceil(height / downsample) 28 | w = math.ceil(width / downsample) 29 | 30 | ymin, xmin, ymax, xmax = np.split(boxes, 4, axis=1) 31 | # they have shape [num_persons, 1] 32 | 33 | scale = np.sqrt((ymax - ymin) * (xmax - xmin)) 34 | sigmas = np.squeeze(scale * 0.007, axis=1) 35 | 36 | kernels = [] # each person has different blob size 37 | sigmas = np.clip(sigmas, min_sigma, max_sigma) 38 | 39 | for sigma in sigmas: 40 | kernels.append(get_kernel(sigma)) 41 | 42 | heatmaps = [] 43 | for i in range(17): 44 | 45 | is_visible = keypoints[:, i, 2] > 0 46 | num_visible = is_visible.sum() 47 | 48 | if num_visible == 0: 49 | empty = np.zeros([h, w], dtype=np.float32) 50 | heatmaps.append(empty) 51 | continue 52 | 53 | person_id = np.where(is_visible)[0] 54 | body_part = keypoints[is_visible, i, :2] 55 | # it has shape [num_visible, 2] 56 | 57 | # to the [0, 1] range 58 | body_part /= scaler 59 | 60 | heatmaps_for_part = [] 61 | for i in range(num_visible): 62 | 63 | kernel = kernels[person_id[i]] 64 | y, x = body_part[i] 65 | 66 | heatmap = create_heatmap(y, x, kernel, w, h) 67 | heatmaps_for_part.append(heatmap) 68 | 69 | heatmaps.append(np.stack(heatmaps_for_part, axis=2).max(2)) 70 | 71 | heatmaps = np.stack(heatmaps, axis=2) 72 | return heatmaps 73 | 74 | 75 | def get_kernel(std): 76 | """Returns a 2D Gaussian kernel array.""" 77 | 78 | k = np.ceil(np.sqrt(- 2.0 * std**2 * np.log(0.01))) 79 | # it is true that exp(- 0.5 * k**2 / std**2) < 0.01 80 | 81 | size = 2 * int(k) + 1 82 | x = signal.windows.gaussian(size, std=std).reshape([size, 1]) 83 | x = np.outer(x, x).astype(np.float32) 84 | return x 85 | 86 | 87 | def create_heatmap(y, x, kernel, width, height): 88 | """ 89 | Arguments: 90 | y, x: float numbers, normalized to the [0, 1] range. 91 | kernel: a numpy float array with shape [2 * k + 1, 2 * k + 1]. 92 | width, height: integers. 93 | Returns: 94 | a numpy float array with shape [height, width]. 
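    The kernel is pasted onto a zero canvas centred at the rounded keypoint location; the canvas is padded by the half kernel size on each side so that no boundary checks are needed, and the padding is cropped away before returning.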
95 | """ 96 | 97 | # half kernel size 98 | k = (kernel.shape[0] - 1) // 2 99 | 100 | x = x * (width - 1) 101 | y = y * (height - 1) 102 | x, y = int(round(x)), int(round(y)) 103 | # they are in ranges [0, width - 1] and [0, height - 1] 104 | 105 | xmin, ymin = x - k, y - k 106 | xmax, ymax = x + k, y + k 107 | 108 | shape = [height + 2 * k, width + 2 * k] 109 | heatmap = np.zeros(shape, dtype=np.float32) 110 | 111 | # shift coordinates 112 | xmin, ymin = xmin + k, ymin + k 113 | xmax, ymax = xmax + k, ymax + k 114 | 115 | heatmap[ymin:(ymax + 1), xmin:(xmax + 1)] = kernel 116 | heatmap = heatmap[k:-k, k:-k] 117 | 118 | return heatmap 119 | -------------------------------------------------------------------------------- /detector/input_pipeline/keypoints_detector_pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import SHUFFLE_BUFFER_SIZE, NUM_PARALLEL_CALLS, RESIZE_METHOD 3 | from detector.constants import DOWNSAMPLE, DIVISOR, OVERLAP_THRESHOLD 4 | from detector.input_pipeline.random_crop import random_image_crop 5 | from detector.input_pipeline.random_rotation import random_image_rotation 6 | from detector.input_pipeline.color_augmentations import random_color_manipulations, random_pixel_value_scale 7 | from detector.input_pipeline.heatmap_creation import get_heatmaps 8 | 9 | 10 | class KeypointPipeline: 11 | """ 12 | Input pipeline for training or evaluating 13 | networks for heatmaps regression. 14 | """ 15 | def __init__(self, filenames, is_training, params): 16 | """ 17 | During the evaluation we resize images keeping aspect ratio. 18 | 19 | Arguments: 20 | filenames: a list of strings, paths to tfrecords files. 21 | is_training: a boolean. 22 | params: a dict. 23 | """ 24 | self.is_training = is_training 25 | 26 | if not is_training: 27 | batch_size = 1 28 | min_dimension = params['min_dimension'] 29 | assert min_dimension % DIVISOR == 0 30 | self.min_dimension = min_dimension 31 | else: 32 | batch_size = params['batch_size'] 33 | width, height = params['image_size'] 34 | assert height % DIVISOR == 0 35 | assert width % DIVISOR == 0 36 | self.image_size = [height, width] 37 | 38 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 39 | 40 | if is_training: 41 | num_shards = len(filenames) 42 | dataset = dataset.shuffle(buffer_size=num_shards) 43 | 44 | dataset = dataset.flat_map(tf.data.TFRecordDataset) 45 | dataset = dataset.prefetch(buffer_size=batch_size) 46 | 47 | if is_training: 48 | dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE) 49 | 50 | dataset = dataset.repeat(None if is_training else 1) 51 | dataset = dataset.map(self.parse_and_preprocess, num_parallel_calls=NUM_PARALLEL_CALLS) 52 | 53 | dataset = dataset.batch(batch_size, drop_remainder=True) 54 | dataset = dataset.prefetch(buffer_size=1) 55 | 56 | self.dataset = dataset 57 | 58 | def parse_and_preprocess(self, example_proto): 59 | """ 60 | All heatmaps and masks have values in [0, 1] range. 61 | 62 | Returns: 63 | image: a float tensor with shape [height, width, 3], 64 | an RGB image with pixel values in the range [0, 1]. 65 | heatmaps: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 17]. 66 | loss_masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE]. 67 | segmentation_masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE]. 68 | num_boxes: an int tensor with shape []. 
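    A rough usage sketch of the whole pipeline (the parameter values are only examples):

        params = {'batch_size': 16, 'image_size': (640, 640)}
        pipeline = KeypointPipeline(filenames, is_training=True, params=params)
        features, labels = pipeline.dataset.make_one_shot_iterator().get_next()
        # features['images']: [16, 640, 640, 3], labels['heatmaps']: [16, 160, 160, 17]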
69 | """ 70 | image, masks, boxes, keypoints = self.parse(example_proto) 71 | 72 | if self.is_training: 73 | image, masks, boxes, keypoints = augmentation( 74 | image, masks, boxes, 75 | keypoints, self.image_size 76 | ) 77 | else: 78 | image, masks, boxes, keypoints = resize_keeping_aspect_ratio( 79 | image, masks, boxes, keypoints, 80 | self.min_dimension, DIVISOR 81 | ) 82 | 83 | shape = tf.shape(image) 84 | image_height, image_width = shape[0], shape[1] 85 | 86 | heatmaps = tf.py_func( 87 | lambda k, b, w, h: get_heatmaps(k, b, w, h, DOWNSAMPLE), 88 | [keypoints, boxes, image_width, image_height], 89 | tf.float32, stateful=False 90 | ) 91 | 92 | if self.is_training: 93 | from math import ceil 94 | height, width = self.image_size 95 | h = ceil(height/DOWNSAMPLE) 96 | w = ceil(width/DOWNSAMPLE) 97 | heatmaps.set_shape([h, w, 17]) 98 | else: 99 | heatmaps.set_shape([None, None, 17]) 100 | 101 | # this is needed for normalization 102 | num_boxes = tf.shape(boxes)[0] 103 | 104 | features = {'images': image} 105 | labels = { 106 | 'heatmaps': heatmaps, 107 | 'loss_masks': masks[:, :, 0], 108 | 'segmentation_masks': masks[:, :, 1], 109 | 'num_boxes': num_boxes 110 | } 111 | return features, labels 112 | 113 | def parse(self, example_proto): 114 | """ 115 | Returns: 116 | image: a float tensor with shape [height, width, 3], 117 | an RGB image with pixel values in the range [0, 1]. 118 | masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2]. 119 | boxes: a float tensor with shape [num_persons, 4], in absolute coordinates. 120 | keypoints: an int tensor with shape [num_persons, 17, 3], in absolute coordinates. 121 | """ 122 | features = { 123 | 'image': tf.FixedLenFeature([], tf.string), 124 | 'num_persons': tf.FixedLenFeature([], tf.int64), 125 | 'boxes': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True), 126 | 'keypoints': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True), 127 | 'masks': tf.FixedLenFeature([], tf.string) 128 | } 129 | parsed_features = tf.parse_single_example(example_proto, features) 130 | 131 | # get an image 132 | image = tf.image.decode_jpeg(parsed_features['image'], channels=3) 133 | image = tf.image.convert_image_dtype(image, tf.float32) 134 | # now pixel values are scaled to the [0, 1] range 135 | 136 | # get number of people on the image 137 | num_persons = tf.to_int32(parsed_features['num_persons']) 138 | # it is assumed that num_persons > 0 139 | 140 | # get groundtruth boxes, they are in absolute coordinates 141 | boxes = tf.reshape(parsed_features['boxes'], [num_persons, 4]) 142 | # they are used to guide the data augmentation (when doing a random crop) 143 | # and to choose sigmas for gaussian blobs 144 | 145 | # get keypoints, they are in absolute coordinates 146 | keypoints = tf.to_int32(parsed_features['keypoints']) 147 | keypoints = tf.reshape(keypoints, [num_persons, 17, 3]) 148 | 149 | # get size of masks, they are downsampled 150 | shape = tf.shape(image) 151 | image_height, image_width = shape[0], shape[1] 152 | masks_height = tf.to_int32(tf.ceil(image_height/DOWNSAMPLE)) 153 | masks_width = tf.to_int32(tf.ceil(image_width/DOWNSAMPLE)) 154 | # (we use the 'SAME' padding in the networks) 155 | 156 | # get masks (loss and segmentation masks) 157 | masks = tf.decode_raw(parsed_features['masks'], tf.uint8) 158 | # unpack bits (reverse np.packbits) 159 | b = tf.constant([128, 64, 32, 16, 8, 4, 2, 1], dtype=tf.uint8) 160 | masks = tf.reshape(tf.bitwise.bitwise_and(masks[:, None], b), [-1]) 161 | masks = 
masks[:(masks_height * masks_width * 2)] 162 | masks = tf.cast(masks > 0, tf.uint8) 163 | 164 | # reshape to the initial form 165 | masks = tf.reshape(masks, [masks_height, masks_width, 2]) 166 | masks = tf.to_float(masks) # it has binary values only 167 | 168 | return image, masks, boxes, keypoints 169 | 170 | 171 | def augmentation(image, masks, boxes, keypoints, image_size): 172 | image, masks, boxes, keypoints = random_image_rotation(image, masks, boxes, keypoints, max_angle=45, probability=0.7) 173 | image, masks, boxes, keypoints = randomly_crop_and_resize(image, masks, boxes, keypoints, image_size, probability=0.9) 174 | image = random_color_manipulations(image, probability=0.5, grayscale_probability=0.1) 175 | image = random_pixel_value_scale(image, probability=0.1, minval=0.9, maxval=1.1) 176 | image, masks, boxes, keypoints = random_flip_left_right(image, masks, boxes, keypoints) 177 | return image, masks, boxes, keypoints 178 | 179 | 180 | def resize_keeping_aspect_ratio(image, masks, boxes, keypoints, min_dimension, divisor): 181 | """ 182 | This function resizes and possibly pads with zeros. 183 | When using a usual FPN, divisor must be equal to 128. 184 | 185 | Arguments: 186 | image: a float tensor with shape [height, width, 3]. 187 | masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2]. 188 | boxes: a float tensor with shape [num_persons, 4]. 189 | keypoints: an int tensor with shape [num_persons, 17, 3]. 190 | min_dimension, divisor: integers. 191 | Returns: 192 | image: a float tensor with shape [h, w, 3], 193 | where `min_dimension = min(h, w)`, 194 | `h` and `w` are divisible by `DIVISOR`. 195 | masks: a float tensor with shape [h / DOWNSAMPLE, w / DOWNSAMPLE, 2]. 196 | boxes: a float tensor with shape [num_persons, 4]. 197 | keypoints: an int tensor with shape [num_persons, 17, 3]. 
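    For instance, checking the arithmetic with made-up sizes: a 600x800 image with min_dimension=512 and divisor=128 is first resized so that its shorter side becomes 512 (giving 512x683), then the longer side is zero-padded to the next multiple of 128, so the output is 512x768.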
198 | """ 199 | 200 | assert min_dimension % divisor == 0 201 | min_dimension = tf.constant(min_dimension, dtype=tf.int32) 202 | divisor = tf.constant(divisor, dtype=tf.int32) 203 | 204 | shape = tf.shape(image) 205 | height, width = shape[0], shape[1] 206 | 207 | original_min_dim = tf.minimum(height, width) 208 | scale_factor = tf.to_float(min_dimension / original_min_dim) 209 | 210 | # RESIZE AND PAD IMAGE 211 | 212 | def scale(x): 213 | unpadded_x = tf.to_int32(tf.round(tf.to_float(x) * scale_factor)) 214 | x = tf.to_int32(tf.ceil(unpadded_x / divisor)) 215 | pad = divisor * x - unpadded_x 216 | return (unpadded_x, pad) 217 | 218 | zero = tf.constant(0, dtype=tf.int32) 219 | new_height, pad_height, new_width, pad_width = tf.cond( 220 | tf.greater_equal(height, width), 221 | lambda: scale(height) + (min_dimension, zero), 222 | lambda: (min_dimension, zero) + scale(width) 223 | ) 224 | 225 | # final image size 226 | h = new_height + pad_height 227 | w = new_width + pad_width 228 | 229 | # resize keeping aspect ratio 230 | image = tf.image.resize_images(image, [new_height, new_width], method=RESIZE_METHOD) 231 | 232 | # pad image at the bottom or at the right 233 | image = tf.image.pad_to_bounding_box(image, offset_height=0, offset_width=0, target_height=h, target_width=w) 234 | 235 | # RESIZE AND PAD MASKS 236 | 237 | # new size of masks with padding 238 | map_height = tf.to_int32(tf.ceil(h / DOWNSAMPLE)) 239 | map_width = tf.to_int32(tf.ceil(w / DOWNSAMPLE)) 240 | 241 | # new size of only masks without padding 242 | map_only_height = tf.to_int32(tf.ceil(new_height / DOWNSAMPLE)) 243 | map_only_width = tf.to_int32(tf.ceil(new_width / DOWNSAMPLE)) 244 | 245 | masks = tf.image.resize_images( 246 | masks, [map_only_height, map_only_width], 247 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR 248 | ) 249 | 250 | masks = tf.image.pad_to_bounding_box( 251 | masks, offset_height=0, offset_width=0, 252 | target_height=map_height, target_width=map_width 253 | ) 254 | 255 | # TRANSFORM KEYPOINTS 256 | 257 | keypoint_scaler = tf.stack([new_height/height, new_width/width]) 258 | keypoint_scaler = tf.to_float(keypoint_scaler) 259 | 260 | points, v = tf.split(keypoints, [2, 1], axis=2) 261 | points = tf.to_int32(tf.round(tf.to_float(points) * keypoint_scaler)) 262 | y, x = tf.split(points, 2, axis=2) 263 | y = tf.clip_by_value(y, 0, h - 1) 264 | x = tf.clip_by_value(x, 0, w - 1) 265 | keypoints = tf.concat([y, x, v], axis=2) 266 | 267 | # TRANSFORM BOXES 268 | 269 | box_scaler = tf.concat(2 * [keypoint_scaler], axis=0) 270 | boxes *= box_scaler 271 | 272 | return image, masks, boxes, keypoints 273 | 274 | 275 | def randomly_crop_and_resize(image, masks, boxes, keypoints, image_size, probability=0.5): 276 | """ 277 | Arguments: 278 | image: a float tensor with shape [height, width, 3]. 279 | masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2]. 280 | boxes: a float tensor with shape [num_persons, 4]. 281 | keypoints: an int tensor with shape [num_persons, 17, 3]. 282 | image_size: a tuple of integers (h, w). 283 | probability: a float number. 284 | Returns: 285 | image: a float tensor with shape [h, w, 3]. 286 | masks: a float tensor with shape [h / DOWNSAMPLE, w / DOWNSAMPLE, 2]. 287 | boxes: a float tensor with shape [num_remaining, 4]. 288 | keypoints: an int tensor with shape [num_remaining, 17, 3]. 
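    Note that the input boxes are absolute, `random_image_crop` below works with (and returns) boxes normalized to the crop window, and the inner `rescale` converts them back to absolute coordinates of the final `image_size` image.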
289 | """ 290 | 291 | shape = tf.to_float(tf.shape(image)) 292 | height, width = shape[0], shape[1] 293 | scaler = tf.stack([height, width, height, width]) 294 | boxes /= scaler # to the [0, 1] range 295 | 296 | def crop(image, boxes, keypoints): 297 | """ 298 | Arguments: 299 | image: a float tensor with shape [height, width, 3]. 300 | boxes: a float tensor with shape [num_persons, 4]. 301 | keypoints: an int tensor with shape [num_persons, 17, 3]. 302 | Returns: 303 | image: a float tensor with shape [None, None, 3]. 304 | boxes: a float tensor with shape [num_remaining, 4]. 305 | keypoints: an int tensor with shape [num_remaining, 17, 3]. 306 | window: a float tensor with shape [4]. 307 | """ 308 | 309 | image, boxes, window, keep_indices = random_image_crop( 310 | image, boxes, min_object_covered=0.9, 311 | aspect_ratio_range=(0.95, 1.05), 312 | area_range=(0.5, 1.0), 313 | overlap_threshold=OVERLAP_THRESHOLD 314 | ) 315 | 316 | keypoints = tf.gather(keypoints, keep_indices) 317 | # it has shape [num_remaining, 17, 3] 318 | 319 | ymin, xmin, ymax, xmax = tf.unstack(window * scaler) 320 | points, v = tf.split(keypoints, [2, 1], axis=2) 321 | points = tf.to_float(points) # shape [num_remaining, 17, 2] 322 | 323 | translation = tf.stack([ymin, xmin]) 324 | points = tf.to_int32(tf.round(points - translation)) 325 | keypoints = tf.concat([points, v], axis=2) 326 | 327 | # note that after this some keypoints will be invisible, 328 | # so we need to modify the `v` vector later 329 | 330 | return image, boxes, keypoints, window 331 | 332 | whole_image_window = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32) 333 | do_it = tf.less(tf.random_uniform([]), probability) 334 | 335 | image, boxes, keypoints, window = tf.cond( 336 | do_it, lambda: crop(image, boxes, keypoints), 337 | lambda: (image, boxes, keypoints, whole_image_window) 338 | ) 339 | 340 | def correct_keypoints(image_shape, keypoints): 341 | """ 342 | Arguments: 343 | image_shape: an int tensor with shape [3]. 344 | keypoints: an int tensor with shape [num_persons, 17, 3]. 345 | Returns: 346 | an int tensor with shape [num_persons, 17, 3]. 347 | """ 348 | y, x, v = tf.split(keypoints, 3, axis=2) 349 | 350 | height = image_shape[0] 351 | width = image_shape[1] 352 | 353 | coordinate_violations = tf.concat([ 354 | tf.less(y, 0), tf.less(x, 0), 355 | tf.greater_equal(y, height), 356 | tf.greater_equal(x, width) 357 | ], axis=2) # shape [num_persons, 17, 4] 358 | 359 | valid_indicator = tf.logical_not(tf.reduce_any(coordinate_violations, axis=2)) 360 | valid_indicator = tf.expand_dims(valid_indicator, 2) 361 | # it has shape [num_persons, 17, 1] 362 | 363 | v *= tf.to_int32(valid_indicator) 364 | keypoints = tf.concat([y, x, v], axis=2) 365 | return keypoints 366 | 367 | def rescale(boxes, keypoints, old_shape, new_shape): 368 | """ 369 | Arguments: 370 | boxes: a float tensor with shape [num_persons, 4]. 371 | keypoints: an int tensor with shape [num_persons, 17, 3]. 372 | old_shape, new_shape: int tensors with shape [3]. 373 | Returns: 374 | a float tensor with shape [num_persons, 4]. 375 | an int tensor with shape [num_persons, 17, 3]. 
376 | """ 377 | points, v = tf.split(keypoints, [2, 1], axis=2) 378 | points = tf.to_float(points) 379 | 380 | old_shape = tf.to_float(old_shape) 381 | new_shape = tf.to_float(new_shape) 382 | old_height, old_width = old_shape[0], old_shape[1] 383 | new_height, new_width = new_shape[0], new_shape[1] 384 | 385 | scaler = tf.stack([new_height/old_height, new_width/old_width]) 386 | points *= scaler 387 | 388 | scaler = tf.stack([new_height, new_width]) 389 | scaler = tf.concat(2 * [scaler], axis=0) 390 | boxes *= scaler 391 | 392 | new_height = tf.to_int32(new_height) 393 | new_width = tf.to_int32(new_width) 394 | 395 | points = tf.to_int32(tf.round(points)) 396 | y, x = tf.split(points, 2, axis=2) 397 | y = tf.clip_by_value(y, 0, new_height - 1) 398 | x = tf.clip_by_value(x, 0, new_width - 1) 399 | keypoints = tf.concat([y, x, v], axis=2) 400 | return boxes, keypoints 401 | 402 | old_shape = tf.shape(image) 403 | keypoints = correct_keypoints(old_shape, keypoints) 404 | 405 | h, w = image_size # image size that will be used for training 406 | image = tf.image.resize_images(image, [h, w], method=RESIZE_METHOD) 407 | 408 | masks_height = tf.to_int32(tf.ceil(h / DOWNSAMPLE)) 409 | masks_width = tf.to_int32(tf.ceil(w / DOWNSAMPLE)) 410 | 411 | masks = tf.image.crop_and_resize( 412 | image=tf.expand_dims(masks, 0), 413 | boxes=tf.expand_dims(window, 0), 414 | box_indices=tf.constant([0], dtype=tf.int32), 415 | crop_size=[masks_height, masks_width], 416 | method='nearest' 417 | ) 418 | masks = masks[0] 419 | 420 | boxes, keypoints = rescale(boxes, keypoints, old_shape, tf.shape(image)) 421 | return image, masks, boxes, keypoints 422 | 423 | 424 | def random_flip_left_right(image, masks, boxes, keypoints): 425 | 426 | def flip(image, masks, boxes, keypoints): 427 | 428 | flipped_image = tf.image.flip_left_right(image) 429 | flipped_masks = tf.image.flip_left_right(masks) 430 | 431 | y, x, v = tf.unstack(keypoints, axis=2) 432 | width = tf.shape(image)[1] 433 | flipped_x = width - 1 - x 434 | flipped_keypoints = tf.stack([y, flipped_x, v], axis=2) 435 | 436 | width = tf.to_float(width) 437 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 438 | flipped_boxes = tf.stack([ymin, width - xmax, ymax, width - xmin], axis=1) 439 | 440 | """ 441 | The keypoint order: 442 | 0: 'nose', 443 | 1: 'left eye', 2: 'right eye', 444 | 3: 'left ear', 4: 'right ear', 445 | 5: 'left shoulder', 6: 'right shoulder', 446 | 7: 'left elbow', 8: 'right elbow', 447 | 9: 'left wrist', 10: 'right wrist', 448 | 11: 'left hip', 12: 'right hip', 449 | 13: 'left knee', 14: 'right knee', 450 | 15: 'left ankle', 16: 'right ankle' 451 | """ 452 | 453 | correct_order = tf.constant([0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]) 454 | flipped_keypoints = tf.gather(flipped_keypoints, correct_order, axis=1) 455 | 456 | return flipped_image, flipped_masks, flipped_boxes, flipped_keypoints 457 | 458 | do_it = tf.less(tf.random_uniform([]), 0.5) 459 | image, masks, boxes, keypoints = tf.cond( 460 | do_it, 461 | lambda: flip(image, masks, boxes, keypoints), 462 | lambda: (image, masks, boxes, keypoints) 463 | ) 464 | return image, masks, boxes, keypoints 465 | -------------------------------------------------------------------------------- /detector/input_pipeline/person_detector_pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import SHUFFLE_BUFFER_SIZE, NUM_PARALLEL_CALLS, RESIZE_METHOD, DIVISOR 3 | from 
detector.input_pipeline.color_augmentations import random_color_manipulations, random_pixel_value_scale 4 | from detector.input_pipeline.random_crop import random_image_crop 5 | 6 | 7 | class DetectorPipeline: 8 | """ 9 | Input pipeline for training or evaluating object detectors. 10 | It is assumed that all boxes are of the same class. 11 | """ 12 | def __init__(self, filenames, is_training, params): 13 | """ 14 | During the evaluation we resize images keeping aspect ratio. 15 | 16 | Arguments: 17 | filenames: a list of strings, paths to tfrecords files. 18 | is_training: a boolean. 19 | params: a dict. 20 | """ 21 | self.is_training = is_training 22 | 23 | def get_num_samples(filename): 24 | return sum(1 for _ in tf.python_io.tf_record_iterator(filename)) 25 | 26 | num_examples = 0 27 | for filename in filenames: 28 | num_examples_in_file = get_num_samples(filename) 29 | num_examples += num_examples_in_file 30 | self.num_examples = num_examples 31 | 32 | if not is_training: 33 | batch_size = 1 34 | self.image_size = [None, None] 35 | self.min_dimension = params['min_dimension'] 36 | else: 37 | batch_size = params['batch_size'] 38 | width, height = params['image_size'] 39 | assert height % DIVISOR == 0 40 | assert width % DIVISOR == 0 41 | self.image_size = [height, width] 42 | 43 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 44 | 45 | if is_training: 46 | num_shards = len(filenames) 47 | dataset = dataset.shuffle(buffer_size=num_shards) 48 | 49 | dataset = dataset.flat_map(tf.data.TFRecordDataset) 50 | dataset = dataset.prefetch(buffer_size=batch_size) 51 | 52 | if is_training: 53 | dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE) 54 | 55 | dataset = dataset.repeat(None if is_training else 1) 56 | dataset = dataset.map(self.parse_and_preprocess, num_parallel_calls=NUM_PARALLEL_CALLS) 57 | 58 | padded_shapes = ({'images': self.image_size + [3]}, {'boxes': [None, 4], 'num_boxes': []}) 59 | dataset = dataset.padded_batch(batch_size, padded_shapes, drop_remainder=True) 60 | dataset = dataset.prefetch(buffer_size=1) 61 | 62 | self.dataset = dataset 63 | 64 | def parse_and_preprocess(self, example_proto): 65 | """ 66 | Returns: 67 | image: a float tensor with shape [height, width, 3], 68 | an RGB image with pixel values in the range [0, 1]. 69 | boxes: a float tensor with shape [num_boxes, 4]. 70 | num_boxes: an int tensor with shape []. 
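    Here `boxes` are normalized to the [0, 1] range in [ymin, xmin, ymax, xmax] order; after batching they are zero-padded along the first axis to the longest example in the batch (see `padded_shapes` in `__init__`), and `num_boxes` tells how many of them are real.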
71 | """ 72 | features = { 73 | 'image': tf.FixedLenFeature([], tf.string), 74 | 'num_persons': tf.FixedLenFeature([], tf.int64), 75 | 'boxes': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True) 76 | } 77 | parsed_features = tf.parse_single_example(example_proto, features) 78 | 79 | # get an image 80 | image = tf.image.decode_jpeg(parsed_features['image'], channels=3) 81 | image = tf.image.convert_image_dtype(image, tf.float32) 82 | # now pixel values are scaled to the [0, 1] range 83 | 84 | # get number of people on the image 85 | num_boxes = tf.to_int32(parsed_features['num_persons']) 86 | # it is assumed that num_boxes > 0 87 | 88 | # get groundtruth boxes, they are in absolute coordinates 89 | boxes = tf.reshape(parsed_features['boxes'], [num_boxes, 4]) 90 | 91 | # to the [0, 1] range 92 | height, width = tf.shape(image)[0], tf.shape(image)[1] 93 | scaler = tf.to_float(tf.stack([height, width, height, width])) 94 | boxes /= scaler 95 | 96 | if self.is_training: 97 | image, boxes = augmentation(image, boxes, self.image_size) 98 | else: 99 | image, boxes = resize_keeping_aspect_ratio(image, boxes, self.min_dimension, DIVISOR) 100 | 101 | # it could change after augmentations 102 | num_boxes = tf.shape(boxes)[0] 103 | 104 | features = {'images': image} 105 | labels = {'boxes': boxes, 'num_boxes': num_boxes} 106 | return features, labels 107 | 108 | 109 | def augmentation(image, boxes, image_size): 110 | image, boxes = randomly_crop_and_resize(image, boxes, image_size, probability=0.9) 111 | image, boxes = randomly_pad(image, boxes, probability=0.1) 112 | image = random_color_manipulations(image, probability=0.33, grayscale_probability=0.033) 113 | image = random_pixel_value_scale(image, probability=0.1, minval=0.8, maxval=1.2) 114 | boxes = random_box_jitter(boxes, ratio=0.01) 115 | image, boxes = random_flip_left_right(image, boxes) 116 | return image, boxes 117 | 118 | 119 | def randomly_crop_and_resize(image, boxes, image_size, probability=0.9): 120 | 121 | def crop(image, boxes): 122 | image, boxes, _, _ = random_image_crop( 123 | image, boxes, 124 | min_object_covered=0.9, 125 | aspect_ratio_range=(0.85, 1.15), 126 | area_range=(0.75, 1.0), 127 | overlap_threshold=0.3 128 | ) 129 | return image, boxes 130 | 131 | do_it = tf.less(tf.random_uniform([]), probability) 132 | image, boxes = tf.cond( 133 | do_it, 134 | lambda: crop(image, boxes), 135 | lambda: (image, boxes) 136 | ) 137 | image = tf.image.resize_images(image, image_size, method=RESIZE_METHOD) 138 | return image, boxes 139 | 140 | 141 | def randomly_pad(image, boxes, probability=0.9): 142 | """ 143 | This function makes content of the image 144 | smaller by scaling and padding it with zeros. 
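    The content is shrunk by a random factor in [0.5, 0.9], placed at a random offset inside a zero canvas of the original size, and the normalized boxes are scaled and translated to match.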
145 | """ 146 | 147 | def pad(image, boxes): 148 | 149 | shape = tf.shape(image) 150 | height, width = shape[0], shape[1] 151 | 152 | # randomly reduce image scale 153 | scale = tf.random_uniform([], 0.5, 0.9) 154 | scaled_height = tf.to_int32(scale * tf.to_float(height)) 155 | scaled_width = tf.to_int32(scale * tf.to_float(width)) 156 | 157 | image = tf.image.resize_images( 158 | image, [scaled_height, scaled_width], 159 | method=RESIZE_METHOD 160 | ) 161 | 162 | # randomly pad to the initial size 163 | offset_y = height - scaled_height 164 | offset_x = width - scaled_width 165 | offset_y = tf.random_uniform([], 0, offset_y, dtype=tf.int32) 166 | offset_x = tf.random_uniform([], 0, offset_x, dtype=tf.int32) 167 | image = tf.image.pad_to_bounding_box(image, offset_y, offset_x, height, width) 168 | 169 | # transform boxes 170 | boxes *= scale 171 | offset_y = tf.to_float(offset_y/height) 172 | offset_x = tf.to_float(offset_x/width) 173 | translation = tf.stack([offset_y, offset_x, offset_y, offset_x]) 174 | boxes += translation 175 | 176 | return image, boxes 177 | 178 | do_it = tf.less(tf.random_uniform([]), probability) 179 | image, boxes = tf.cond(do_it, lambda: pad(image, boxes), lambda: (image, boxes)) 180 | return image, boxes 181 | 182 | 183 | def resize_keeping_aspect_ratio(image, boxes, min_dimension, divisor): 184 | """ 185 | This function resizes and possibly pads with zeros. 186 | When using a usual FPN, divisor must be equal to 128. 187 | 188 | Arguments: 189 | image: a float tensor with shape [height, width, 3]. 190 | boxes: a float tensor with shape [n, 4]. 191 | min_dimension: an integer. 192 | divisor: an integer. 193 | Returns: 194 | image: a float tensor with shape [h, w, 3], 195 | where `min_dimension = min(h, w)`, 196 | `h` and `w` are divisible by `divisor`. 197 | boxes: a float tensor with shape [n, 4]. 
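    Note that here the boxes are in normalized coordinates, so after padding they are rescaled by new_height/h and new_width/w to keep pointing at the same pixels.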
198 | """ 199 | assert min_dimension % divisor == 0 200 | 201 | min_dimension = tf.constant(min_dimension, dtype=tf.int32) 202 | divisor = tf.constant(divisor, dtype=tf.int32) 203 | 204 | shape = tf.shape(image) 205 | height, width = shape[0], shape[1] 206 | 207 | original_min_dim = tf.minimum(height, width) 208 | scale_factor = tf.to_float(min_dimension / original_min_dim) 209 | 210 | def scale(x): 211 | unpadded_x = tf.to_int32(tf.round(tf.to_float(x) * scale_factor)) 212 | x = tf.to_int32(tf.ceil(unpadded_x / divisor)) 213 | pad = divisor * x - unpadded_x 214 | return (unpadded_x, pad) 215 | 216 | zero = tf.constant(0, dtype=tf.int32) 217 | new_height, pad_height, new_width, pad_width = tf.cond( 218 | tf.greater_equal(height, width), 219 | lambda: scale(height) + (min_dimension, zero), 220 | lambda: (min_dimension, zero) + scale(width) 221 | ) 222 | 223 | # resize keeping aspect ratio 224 | image = tf.image.resize_images(image, [new_height, new_width], method=RESIZE_METHOD) 225 | 226 | h = new_height + pad_height 227 | w = new_width + pad_width 228 | 229 | image = tf.image.pad_to_bounding_box( 230 | image, offset_height=0, offset_width=0, 231 | target_height=h, target_width=w 232 | ) 233 | # it pads image at the bottom or at the right 234 | 235 | # we need to rescale bounding box coordinates 236 | box_scaler = tf.to_float(tf.stack([ 237 | new_height/h, new_width/w, 238 | new_height/h, new_width/w 239 | ])) 240 | 241 | boxes *= box_scaler 242 | return image, boxes 243 | 244 | 245 | def random_flip_left_right(image, boxes): 246 | 247 | def flip(image, boxes): 248 | flipped_image = tf.image.flip_left_right(image) 249 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 250 | flipped_xmin = 1.0 - xmax 251 | flipped_xmax = 1.0 - xmin 252 | flipped_boxes = tf.stack([ymin, flipped_xmin, ymax, flipped_xmax], axis=1) 253 | return flipped_image, flipped_boxes 254 | 255 | do_it = tf.less(tf.random_uniform([]), 0.5) 256 | image, boxes = tf.cond(do_it, lambda: flip(image, boxes), lambda: (image, boxes)) 257 | return image, boxes 258 | 259 | 260 | def random_box_jitter(boxes, ratio=0.05): 261 | """Randomly jitter bounding boxes. 262 | 263 | Arguments: 264 | boxes: a float tensor with shape [N, 4]. 265 | ratio: a float number. 266 | The ratio of the box width and height that the corners can jitter. 267 | For example if the width is 100 pixels and ratio is 0.05, 268 | the corners can jitter up to 5 pixels in the x direction. 269 | Returns: 270 | a float tensor with shape [N, 4]. 271 | """ 272 | def jitter_box(box, ratio): 273 | """ 274 | Arguments: 275 | box: a float tensor with shape [4]. 276 | ratio: a float number. 277 | Returns: 278 | a float tensor with shape [4]. 
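        As a quick illustration with made-up numbers: with ratio=0.05, a normalized box of height 0.2 and width 0.4 can have each corner coordinate shifted by at most 0.01 vertically and 0.02 horizontally.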
279 | """ 280 | ymin, xmin, ymax, xmax = tf.unstack(box, axis=0) 281 | box_height, box_width = ymax - ymin, xmax - xmin 282 | hw_coefs = tf.stack([box_height, box_width, box_height, box_width]) 283 | 284 | rand_numbers = tf.random_uniform( 285 | [4], minval=-ratio, maxval=ratio, dtype=tf.float32 286 | ) 287 | hw_rand_coefs = tf.multiply(hw_coefs, rand_numbers) 288 | 289 | jittered_box = tf.add(box, hw_rand_coefs) 290 | return jittered_box 291 | 292 | distorted_boxes = tf.map_fn( 293 | lambda x: jitter_box(x, ratio), 294 | boxes, dtype=tf.float32, back_prop=False 295 | ) 296 | distorted_boxes = tf.clip_by_value(distorted_boxes, 0.0, 1.0) 297 | return distorted_boxes 298 | -------------------------------------------------------------------------------- /detector/input_pipeline/prn_pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import SHUFFLE_BUFFER_SIZE, NUM_PARALLEL_CALLS, DOWNSAMPLE 3 | from detector.input_pipeline.heatmap_creation import get_heatmaps 4 | 5 | 6 | # height and width 7 | CROP_SIZE = [56, 36] 8 | 9 | 10 | class PoseResidualNetworkPipeline: 11 | """ 12 | """ 13 | def __init__(self, filenames, is_training, batch_size, max_keypoints=None): 14 | """ 15 | Arguments: 16 | filenames: a list of strings, paths to tfrecords files. 17 | is_training: a boolean. 18 | batch_size: an integer. 19 | max_keypoints: an integer or None. 20 | """ 21 | self.is_training = is_training 22 | self.max_keypoints = max_keypoints 23 | 24 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 25 | 26 | if is_training: 27 | num_shards = len(filenames) 28 | dataset = dataset.shuffle(buffer_size=num_shards) 29 | 30 | dataset = dataset.flat_map(tf.data.TFRecordDataset) 31 | dataset = dataset.prefetch(buffer_size=batch_size) 32 | 33 | if is_training: 34 | dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE) 35 | 36 | dataset = dataset.repeat(None if is_training else 1) 37 | dataset = dataset.map(self.parse_and_preprocess, num_parallel_calls=NUM_PARALLEL_CALLS) 38 | dataset = dataset.apply(tf.data.experimental.unbatch()) 39 | 40 | if is_training: 41 | dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE) 42 | 43 | dataset = dataset.batch(batch_size) 44 | dataset = dataset.prefetch(buffer_size=1) 45 | 46 | self.dataset = dataset 47 | 48 | def parse_and_preprocess(self, example_proto): 49 | """ 50 | Returns: 51 | crops: a float tensor with shape [num_persons, height, width, 17]. 52 | labels: a float tensor with shape [num_persons, height, width, 17]. 
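        Here `height, width` equal CROP_SIZE; `crops` are per-person crops of the groundtruth heatmaps and `labels` are the corresponding binary keypoint maps.

        A rough usage sketch (the batch size here is arbitrary):

            pipeline = PoseResidualNetworkPipeline(filenames, is_training=True, batch_size=32)
            crops, labels = pipeline.dataset.make_one_shot_iterator().get_next()
            # both have shape [32, 56, 36, 17] (the last batch may be smaller)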
53 | """ 54 | features = { 55 | 'image': tf.FixedLenFeature([], tf.string), 56 | 'num_persons': tf.FixedLenFeature([], tf.int64), 57 | 'boxes': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True), 58 | 'keypoints': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True) 59 | } 60 | parsed_features = tf.parse_single_example(example_proto, features) 61 | 62 | # get size of the image 63 | shape = tf.image.extract_jpeg_shape(parsed_features['image']) 64 | image_height, image_width = shape[0], shape[1] 65 | scaler = tf.to_float(tf.stack(2 * [image_height, image_width])) 66 | 67 | # get number of people on the image 68 | num_persons = tf.to_int32(parsed_features['num_persons']) 69 | # it is assumed that num_persons > 0 70 | 71 | # get groundtruth boxes, they are in absolute coordinates 72 | boxes = tf.reshape(parsed_features['boxes'], [num_persons, 4]) 73 | 74 | # get keypoints, they are in absolute coordinates 75 | keypoints = tf.to_int32(parsed_features['keypoints']) 76 | keypoints = tf.reshape(keypoints, [num_persons, 17, 3]) 77 | 78 | if self.max_keypoints is not None: 79 | 80 | # curriculum learning by sorting 81 | # annotations based on number of keypoints 82 | 83 | is_visible = tf.to_int32(keypoints[:, :, 2] > 0) # shape [num_persons, 17] 84 | is_good = tf.less_equal(tf.reduce_sum(is_visible, axis=1), self.max_keypoints) 85 | # it has shape [num_persons] 86 | 87 | keypoints = tf.boolean_mask(keypoints, is_good) 88 | boxes = tf.boolean_mask(boxes, is_good) 89 | num_persons = tf.shape(boxes)[0] 90 | 91 | heatmaps = tf.py_func( 92 | lambda k, b, w, h: get_heatmaps(k, b, w, h, DOWNSAMPLE), 93 | [keypoints, boxes, image_width, image_height], 94 | tf.float32, stateful=False 95 | ) 96 | heatmaps.set_shape([None, None, 17]) 97 | 98 | box_indices = tf.zeros([num_persons], dtype=tf.int32) 99 | crops = tf.image.crop_and_resize( 100 | tf.expand_dims(heatmaps, 0), 101 | boxes/scaler, box_indices, 102 | crop_size=CROP_SIZE 103 | ) 104 | 105 | def fn(x): 106 | """ 107 | Arguments: 108 | keypoints: a float tensor with shape [17, 3]. 109 | box: a float tensor with shape [4]. 110 | Returns: 111 | a float tensor with shape [height, width, 17]. 
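            Keypoint coordinates are first translated to be relative to the box, scaled into the CROP_SIZE grid, rounded and clipped; every visible keypoint then lights up exactly one pixel of its channel.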
112 | """ 113 | keypoints, box = x 114 | 115 | ymin, xmin, ymax, xmax = tf.unstack(box, axis=0) 116 | y, x, v = tf.unstack(keypoints, axis=1) 117 | keypoints = tf.stack([y, x], axis=1) 118 | 119 | part_id = tf.where(v > 0.0) # shape [num_visible, 1] 120 | part_id = tf.to_int32(part_id) 121 | num_visible = tf.shape(part_id)[0] 122 | keypoints = tf.gather(keypoints, tf.squeeze(part_id, 1)) 123 | # it has shape [num_visible, 2], they have absolute coordinates 124 | 125 | # transform keypoints coordinates 126 | # to be relative to the box 127 | h, w = ymax - ymin, xmax - xmin 128 | height, width = CROP_SIZE 129 | translation = tf.stack([ymin, xmin]) 130 | scaler = tf.to_float(tf.stack([height/h, width/w], axis=0)) 131 | 132 | keypoints -= translation 133 | keypoints *= scaler 134 | keypoints = tf.to_int32(tf.round(keypoints)) 135 | # it has shape [num_visible, 2] 136 | 137 | y, x = tf.unstack(keypoints, axis=1) 138 | y = tf.clip_by_value(y, 0, height - 1) 139 | x = tf.clip_by_value(x, 0, width - 1) 140 | keypoints = tf.stack([y, x], axis=1) 141 | 142 | indices = tf.to_int64(tf.concat([keypoints, part_id], axis=1)) 143 | values = tf.ones([num_visible], dtype=tf.float32) 144 | binary_map = tf.sparse.SparseTensor(indices, values, dense_shape=[height, width, 17]) 145 | binary_map = tf.sparse.to_dense(binary_map, default_value=0, validate_indices=False) 146 | return binary_map 147 | 148 | labels = tf.map_fn( 149 | fn, (tf.to_float(keypoints), boxes), 150 | dtype=tf.float32, back_prop=False, 151 | ) 152 | 153 | if self.is_training: 154 | crops, labels = random_flip_left_right(crops, labels) 155 | 156 | return crops, labels 157 | 158 | 159 | def random_flip_left_right(crops, labels): 160 | 161 | def randomly_flip(x): 162 | """ 163 | Arguments: 164 | crops, labels: float tensors with shape [height, width, 17]. 165 | Returns: 166 | float tensors with shape [height, width, 17]. 
167 | """ 168 | crops, labels = x 169 | 170 | def flip(crops, labels): 171 | 172 | crops = tf.image.flip_left_right(crops) 173 | labels = tf.image.flip_left_right(labels) 174 | 175 | """ 176 | The keypoint order: 177 | 0: 'nose', 178 | 1: 'left eye', 2: 'right eye', 179 | 3: 'left ear', 4: 'right ear', 180 | 5: 'left shoulder', 6: 'right shoulder', 181 | 7: 'left elbow', 8: 'right elbow', 182 | 9: 'left wrist', 10: 'right wrist', 183 | 11: 'left hip', 12: 'right hip', 184 | 13: 'left knee', 14: 'right knee', 185 | 15: 'left ankle', 16: 'right ankle' 186 | """ 187 | 188 | correct_order = tf.constant([0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]) 189 | crops = tf.gather(crops, correct_order, axis=2) 190 | labels = tf.gather(labels, correct_order, axis=2) 191 | return crops, labels 192 | 193 | do_it = tf.less(tf.random_uniform([]), 0.5) 194 | crops, labels = tf.cond(do_it, lambda: flip(crops, labels), lambda: (crops, labels)) 195 | 196 | return crops, labels 197 | 198 | crops, labels = tf.map_fn( 199 | randomly_flip, (crops, labels), 200 | dtype=(tf.float32, tf.float32), 201 | back_prop=False, 202 | ) 203 | return crops, labels 204 | -------------------------------------------------------------------------------- /detector/input_pipeline/random_crop.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.utils import intersection, area 3 | from detector.constants import EPSILON 4 | 5 | 6 | def random_image_crop( 7 | image, boxes, min_object_covered=0.9, 8 | aspect_ratio_range=(0.75, 1.33), area_range=(0.5, 1.0), 9 | overlap_threshold=0.3): 10 | """ 11 | Performs random crop. Given the input image and its bounding boxes, 12 | this op randomly crops a subimage. Given a user-provided set of input constraints, 13 | the crop window is resampled until it satisfies these constraints. 14 | If within 100 trials it is unable to find a valid crop, the original 15 | image is returned. Both input boxes and returned boxes are in normalized 16 | form (e.g., lie in the unit square [0, 1]). 17 | 18 | Arguments: 19 | image: a float tensor with shape [height, width, 3]. 20 | boxes: a float tensor containing bounding boxes. It has shape 21 | [num_boxes, 4]. Boxes are in normalized form, meaning 22 | their coordinates vary between [0, 1]. 23 | Each row is in the form of [ymin, xmin, ymax, xmax]. 24 | min_object_covered: the cropped image must cover at least this fraction of 25 | at least one of the input bounding boxes. 26 | aspect_ratio_range: allowed range for aspect ratio of cropped image. 27 | area_range: allowed range for area ratio between cropped image and the 28 | original image. 29 | overlap_threshold: minimum overlap thresh with new cropped 30 | image to keep the box. 31 | Returns: 32 | image: cropped image, a float tensor with shape [None, None, 3]. 33 | boxes: a float tensor with shape [num_remaining, 4], remaining boxes. 34 | Where 0 <= num_remaining <= num_boxes. 35 | window: a float tensor with shape [4], in normalized coordinates. 36 | keep_indices: an int tensor with shape [num_remaining], 37 | indices of remaining boxes in input boxes tensor. 
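    Note that `window` stays in the normalized coordinates of the original image, while the returned boxes are already expressed relative to the crop window and clipped to [0, 1].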
38 | """ 39 | 40 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 41 | tf.shape(image), 42 | bounding_boxes=tf.expand_dims(boxes, 0), 43 | min_object_covered=min_object_covered, 44 | aspect_ratio_range=aspect_ratio_range, 45 | area_range=area_range, 46 | max_attempts=100, 47 | use_image_if_no_bounding_boxes=True 48 | ) 49 | begin, size, window = sample_distorted_bounding_box 50 | image = tf.slice(image, begin, size) 51 | image.set_shape([None, None, 3]) 52 | window = tf.squeeze(window, axis=[0, 1]) 53 | 54 | # remove boxes that are completely outside the cropped image 55 | boxes, inside_window_ids = prune_completely_outside_window(boxes, window) 56 | # why do i need this function? i believe the one below is enough 57 | 58 | # remove boxes that are too much outside the cropped image 59 | boxes, keep_indices = prune_non_overlapping_boxes( 60 | boxes, tf.expand_dims(window, 0), 61 | min_overlap=overlap_threshold 62 | ) 63 | 64 | # change coordinates of the remaining boxes 65 | boxes = change_coordinate_frame(boxes, window) 66 | 67 | keep_indices = tf.gather(inside_window_ids, keep_indices) 68 | return image, boxes, window, keep_indices 69 | 70 | 71 | def prune_completely_outside_window(boxes, window): 72 | """ 73 | Prunes bounding boxes that fall completely outside of the given window. 74 | This function does not clip partially overflowing boxes. 75 | 76 | Arguments: 77 | boxes: a float tensor with shape [M_in, 4]. 78 | window: a float tensor with shape [4] representing [ymin, xmin, ymax, xmax] 79 | of the window. 80 | Returns: 81 | boxes: a float tensor with shape [M_out, 4] where 0 <= M_out <= M_in. 82 | valid_indices: a long tensor with shape [M_out] indexing the valid bounding boxes 83 | in the input 'boxes' tensor. 84 | """ 85 | y_min, x_min, y_max, x_max = tf.split(boxes, num_or_size_splits=4, axis=1) 86 | # they have shape [None, 1] 87 | win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window) 88 | # they have shape [] 89 | 90 | coordinate_violations = tf.concat([ 91 | tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max), 92 | tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min) 93 | ], axis=1) 94 | valid_indices = tf.squeeze( 95 | tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), 96 | axis=1 97 | ) 98 | boxes = tf.gather(boxes, valid_indices) 99 | return boxes, valid_indices 100 | 101 | 102 | def prune_non_overlapping_boxes(boxes1, boxes2, min_overlap): 103 | """ 104 | Prunes the boxes in boxes1 that overlap less than thresh with boxes2. 105 | For each box in boxes1, we want its IOA to be more than min_overlap with 106 | at least one of the boxes in boxes2. If it does not, we remove it. 107 | 108 | Arguments: 109 | boxes1: a float tensor with shape [N, 4]. 110 | boxes2: a float tensor with shape [M, 4]. 111 | min_overlap: minimum required overlap between boxes, 112 | to count them as overlapping. 113 | Returns: 114 | boxes: a float tensor with shape [N', 4]. 115 | keep_indices: a long tensor with shape [N'] indexing kept bounding boxes in the 116 | first input tensor ('boxes1'). 
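    Overlap here is IOA (intersection over the area of the box from `boxes1`), not IOU, so a box that lies entirely inside one of the `boxes2` windows always gets overlap 1.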
117 | """ 118 | overlap = ioa(boxes2, boxes1) # shape [M, N] 119 | overlap = tf.reduce_max(overlap, axis=0) # shape [N] 120 | 121 | keep_bool = tf.greater_equal(overlap, min_overlap) 122 | keep_indices = tf.squeeze(tf.where(keep_bool), axis=1) 123 | 124 | boxes = tf.gather(boxes1, keep_indices) 125 | return boxes, keep_indices 126 | 127 | 128 | def change_coordinate_frame(boxes, window): 129 | """Change coordinate frame of the boxes to be relative to window's frame. 130 | 131 | Arguments: 132 | boxes: a float tensor with shape [N, 4]. 133 | window: a float tensor with shape [4]. 134 | Returns: 135 | a float tensor with shape [N, 4]. 136 | """ 137 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 138 | ymin -= window[0] 139 | xmin -= window[1] 140 | ymax -= window[0] 141 | xmax -= window[1] 142 | 143 | win_height = window[2] - window[0] 144 | win_width = window[3] - window[1] 145 | boxes = tf.stack([ 146 | ymin/win_height, xmin/win_width, 147 | ymax/win_height, xmax/win_width 148 | ], axis=1) 149 | boxes = tf.clip_by_value(boxes, 0.0, 1.0) 150 | return boxes 151 | 152 | 153 | def ioa(boxes1, boxes2): 154 | """ 155 | Computes pairwise intersection-over-area between box collections. 156 | intersection-over-area (IOA) between two boxes box1 and box2 is defined as 157 | their intersection area over box2's area. Note that ioa is not symmetric, 158 | that is, ioa(box1, box2) != ioa(box2, box1). 159 | 160 | Arguments: 161 | boxes1: a float tensor with shape [N, 4]. 162 | boxes2: a float tensor with shape [M, 4]. 163 | Returns: 164 | a float tensor with shape [N, M] representing pairwise ioa scores. 165 | """ 166 | intersections = intersection(boxes1, boxes2) # shape [N, M] 167 | areas = tf.expand_dims(area(boxes2), 0) # shape [1, M] 168 | return tf.clip_by_value(tf.divide(intersections, areas + EPSILON), 0.0, 1.0) 169 | -------------------------------------------------------------------------------- /detector/input_pipeline/random_rotation.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import tensorflow.contrib as contrib 3 | from detector.constants import DOWNSAMPLE, OVERLAP_THRESHOLD 4 | from detector.input_pipeline.random_crop import prune_non_overlapping_boxes 5 | 6 | 7 | def random_image_rotation(image, masks, boxes, keypoints, max_angle=45, probability=0.9): 8 | """ 9 | What this function does: 10 | 1. It takes a random box and rotates everything around its center. 11 | 2. Then it rescales the image so that the box not too small or not too big. 12 | 3. Then it translates the image's center to be at the box's center. 13 | 14 | All coordinates are absolute: 15 | 1. Boxes have coordinates in ranges [0, height] and [0, width]. 16 | 2. Keypoints have coordinates in ranges [0, height - 1] and [0, width - 1]. 17 | 18 | Arguments: 19 | image: a float tensor with shape [height, width, 3]. 20 | masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2]. 21 | boxes: a float tensor with shape [num_persons, 4]. 22 | keypoints: an int tensor with shape [num_persons, 17, 3]. 23 | max_angle: an integer. 24 | probability: a float number. 25 | Returns: 26 | image: a float tensor with shape [height, width, 3]. 27 | masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2]. 28 | boxes: a float tensor with shape [num_remaining_boxes, 4], 29 | where num_remaining_boxes <= num_persons. 30 | keypoints: an int tensor with shape [num_remaining_boxes, 17, 3]. 
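    Note that the returned boxes are the axis-aligned boxes enclosing the rotated corners of the originals, so they can be somewhat looser than tight person boxes.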
31 | """ 32 | def rotate(image, masks, boxes, keypoints): 33 | 34 | # get the center of the image 35 | image_shape = tf.to_float(tf.shape(image)) 36 | image_height = image_shape[0] 37 | image_width = image_shape[1] 38 | image_center = 0.5 * tf.stack([image_height, image_width]) 39 | image_center = tf.reshape(image_center, [1, 2]) 40 | 41 | box_center, box_width = get_random_box_center(boxes, image_height, image_width) 42 | rotation = get_random_rotation(max_angle, box_center, image_width) 43 | scaler = get_random_scaling(box_center, box_width, image_width) 44 | 45 | rotation *= scaler 46 | translation = image_center - tf.matmul(box_center, rotation) 47 | 48 | """ 49 | Assume tensor `points` has shape [n, 2]. 50 | 1. points = points - box_center (translate center of the coordinate system to the box center) 51 | 2. points = points * rotation (rotate and scale relative to the new center) 52 | 3. points = points + box_center (translate back) 53 | 4. points = points - center_translation (translate image center to the box center) 54 | 55 | So full transformation is: 56 | (points - box_center) * rotation + box_center - center_translation = 57 | = points * rotation + translation, where translation = image_center - rotation * box_center. 58 | """ 59 | 60 | boxes = transform_boxes(boxes, rotation, translation) 61 | keypoints = transform_keypoints(keypoints, rotation, translation) 62 | # after this some boxes and keypoints could be out of the image 63 | 64 | boxes, keypoints = correct(boxes, keypoints, image_height, image_width) 65 | # now all boxes and keypoints are inside the image 66 | 67 | transform = get_inverse_transform(rotation, translation) 68 | image = contrib.image.transform(image, transform, interpolation='BILINEAR') 69 | 70 | # masks are smaller than the image 71 | scaler = tf.stack([1, 1, DOWNSAMPLE, 1, 1, DOWNSAMPLE, 1, 1]) 72 | masks_transform = transform / tf.to_float(scaler) 73 | 74 | masks = contrib.image.transform(masks, masks_transform, interpolation='NEAREST') 75 | # masks are binary so we use the nearest neighbor interpolation 76 | 77 | return image, masks, boxes, keypoints 78 | 79 | do_it = tf.less(tf.random_uniform([]), probability) 80 | image, masks, boxes, keypoints = tf.cond( 81 | do_it, 82 | lambda: rotate(image, masks, boxes, keypoints), 83 | lambda: (image, masks, boxes, keypoints) 84 | ) 85 | return image, masks, boxes, keypoints 86 | 87 | 88 | def get_random_box_center(boxes, image_height, image_width): 89 | """ 90 | Arguments: 91 | boxes: a float tensor with shape [num_persons, 4]. 92 | image_height, image_width: float tensors with shape []. 93 | Returns: 94 | box_center: a float tensor with shape [1, 2]. 95 | box_width: a float tensor with shape []. 
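    The chosen center is clipped to the middle part of the image (25-75% of the height, 20-80% of the width), which keeps the subsequent rotation and rescaling from introducing too much zero padding.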
96 | """ 97 | 98 | # get a random bounding box 99 | box = tf.random_shuffle(boxes)[0] 100 | # it has shape [4] 101 | 102 | ymin, xmin, ymax, xmax = tf.unstack(box) 103 | box_height, box_width = ymax - ymin, xmax - xmin 104 | 105 | # get the center of the box 106 | cy = ymin + 0.5 * box_height 107 | cx = xmin + 0.5 * box_width 108 | 109 | # we will rotate around the box's center, 110 | # but the center mustn't be too near to the border of the image 111 | cy = tf.clip_by_value(cy, 0.25 * image_height, 0.75 * image_height) 112 | cx = tf.clip_by_value(cx, 0.2 * image_width, 0.8 * image_width) 113 | box_center = tf.stack([cy, cx]) 114 | box_center = tf.reshape(box_center, [1, 2]) 115 | 116 | return box_center, box_width 117 | 118 | 119 | def get_random_rotation(max_angle, rotation_center, image_width): 120 | """ 121 | Arguments: 122 | max_angle: an integer, angle in degrees. 123 | rotation_center: a float tensor with shape [1, 2]. 124 | image_width: a float tensor with shape []. 125 | Returns: 126 | a float tensor with shape [2, 2]. 127 | """ 128 | 129 | PI = 3.141592653589793 130 | max_angle_radians = max_angle * (PI/180.0) 131 | 132 | # x-coordinate of the rotation center 133 | cx = rotation_center[0, 1] 134 | 135 | # relative distance between centers 136 | distance_to_image_center = tf.abs(cx - 0.5 * image_width) 137 | distance_to_image_center /= image_width 138 | 139 | # if the center is too near to the borders then 140 | # reduce the maximal rotation angle 141 | decay = (0.6 - 2.0 * distance_to_image_center)/0.6 142 | decay = tf.maximum(decay, 0.0) 143 | max_angle_radians *= decay 144 | 145 | # decay is in [0, 1] range, 146 | # decay = 1 if cx = 0.5 * image_width, 147 | # decay = 0 if cx = 0.2 * image_width 148 | 149 | # get a random angle 150 | theta = tf.random_uniform( 151 | [], minval=-max_angle_radians, 152 | maxval=max_angle_radians, 153 | dtype=tf.float32 154 | ) 155 | 156 | rotation = tf.stack([ 157 | tf.cos(theta), tf.sin(theta), 158 | -tf.sin(theta), tf.cos(theta) 159 | ], axis=0) 160 | rotation = tf.reshape(rotation, [2, 2]) 161 | 162 | return rotation 163 | 164 | 165 | def get_random_scaling(rotation_center, box_width, image_width): 166 | """ 167 | Arguments: 168 | rotation_center: a float tensor with shape [1, 2]. 169 | box_width: a float tensor with shape []. 170 | image_width: a float tensor with shape []. 171 | Returns: 172 | a float tensor with shape []. 
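    A quick numeric check with made-up sizes: for a 50 px wide box in a 500 px wide image whose center is 200 px from the nearest border, necessary_scale = 500 / (2 * 200) = 1.25 and size_ratio = 10, so the scaler is drawn from [1.25, 10/3] and the scaled box ends up roughly 62-167 px wide.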
173 | """ 174 | 175 | # x-coordinate of the rotation center 176 | cx = rotation_center[0, 1] 177 | 178 | # the distance to the nearest border 179 | distance = tf.minimum(cx, image_width - cx) 180 | 181 | # i believe this minimizes the amount 182 | # of zero padding after rescaling 183 | necessary_scale = image_width/(2.0 * distance) 184 | # it is always bigger or equal to 1 185 | 186 | # with this scaling the distance to the 187 | # nearest border will be half of image width 188 | 189 | size_ratio = image_width/box_width 190 | # it is always bigger or equal to 1 191 | 192 | # new box width will be 193 | # maximum one third of the image width 194 | max_scale = size_ratio/3.0 195 | 196 | min_scale = tf.maximum(size_ratio/8.0, necessary_scale) 197 | # this is all very confusing 198 | 199 | min_scale = tf.minimum(min_scale, max_scale - 1e-4) 200 | # now always min_scale < max_scale 201 | 202 | # get a random image scaler 203 | scaler = tf.random_uniform( 204 | [], minval=min_scale, 205 | maxval=max_scale, 206 | dtype=tf.float32 207 | ) 208 | 209 | return scaler 210 | 211 | 212 | def transform_boxes(boxes, rotation, translation): 213 | """ 214 | Arguments: 215 | boxes: a float tensor with shape [num_persons, 4]. 216 | rotation: a float tensor with shape [2, 2]. 217 | translation: a float tensor with shape [1, 2]. 218 | Returns: 219 | a float tensor with shape [num_persons, 4]. 220 | """ 221 | 222 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 223 | p1 = tf.stack([ymin, xmin], axis=1) # top left 224 | p2 = tf.stack([ymin, xmax], axis=1) # top right 225 | p3 = tf.stack([ymax, xmin], axis=1) # buttom left 226 | p4 = tf.stack([ymax, xmax], axis=1) # buttom right 227 | points = tf.concat([p1, p2, p3, p4], axis=0) 228 | # it has shape [4 * num_persons, 2] 229 | 230 | points = tf.matmul(points, rotation) + translation 231 | p1, p2, p3, p4 = tf.split(points, num_or_size_splits=4, axis=0) 232 | 233 | # get boxes that contain the original boxes 234 | ymin = tf.minimum(p1[:, 0], p2[:, 0]) 235 | ymax = tf.maximum(p3[:, 0], p4[:, 0]) 236 | xmin = tf.minimum(p1[:, 1], p3[:, 1]) 237 | xmax = tf.maximum(p2[:, 1], p4[:, 1]) 238 | 239 | boxes = tf.stack([ymin, xmin, ymax, xmax], axis=1) 240 | return boxes 241 | 242 | 243 | def transform_keypoints(keypoints, rotation, translation): 244 | """ 245 | Arguments: 246 | keypoints: an int tensor with shape [num_persons, 17, 3]. 247 | rotation: a float tensor with shape [2, 2]. 248 | translation: a float tensor with shape [1, 2]. 249 | Returns: 250 | an int tensor with shape [num_persons, 17, 3]. 251 | """ 252 | 253 | points, v = tf.split(keypoints, [2, 1], axis=2) 254 | # they have shapes [num_persons, 17, 2] and [num_persons, 17, 1] 255 | 256 | points = tf.to_float(points) 257 | points = tf.reshape(points, [-1, 2]) 258 | points = tf.matmul(points, rotation) + translation 259 | points = tf.to_int32(tf.round(points)) 260 | points = tf.reshape(points, [-1, 17, 2]) 261 | keypoints = tf.concat([points, v], axis=2) 262 | 263 | return keypoints 264 | 265 | 266 | def correct(boxes, keypoints, image_height, image_width): 267 | """ 268 | Remove boxes and keypoints that are outside of the image. 269 | 270 | Arguments: 271 | boxes: a float tensor with shape [num_persons, 4]. 272 | keypoints: an int tensor with shape [num_persons, 17, 3]. 273 | image_height, image_width: float tensors with shape []. 274 | Returns: 275 | boxes: a float tensor with shape [num_remaining_boxes, 4], 276 | where num_remaining_boxes <= num_persons. 
277 | keypoints: an int tensor with shape [num_remaining_boxes, 17, 3]. 278 | """ 279 | 280 | window = tf.stack([0.0, 0.0, image_height, image_width]) 281 | boxes, keep_indices = prune_non_overlapping_boxes( 282 | boxes, tf.expand_dims(window, 0), 283 | min_overlap=OVERLAP_THRESHOLD 284 | ) 285 | 286 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 287 | ymin = tf.clip_by_value(ymin, 0.0, image_height) 288 | xmin = tf.clip_by_value(xmin, 0.0, image_width) 289 | ymax = tf.clip_by_value(ymax, 0.0, image_height) 290 | xmax = tf.clip_by_value(xmax, 0.0, image_width) 291 | boxes = tf.stack([ymin, xmin, ymax, xmax], axis=1) 292 | 293 | keypoints = tf.gather(keypoints, keep_indices) 294 | y, x, v = tf.split(keypoints, 3, axis=2) 295 | 296 | image_height = tf.to_int32(image_height) 297 | image_width = tf.to_int32(image_width) 298 | 299 | coordinate_violations = tf.concat([ 300 | tf.less(y, 0), tf.less(x, 0), 301 | tf.greater_equal(y, image_height), 302 | tf.greater_equal(x, image_width) 303 | ], axis=2) # shape [num_remaining_boxes, 17, 4] 304 | 305 | valid_indicator = tf.logical_not(tf.reduce_any(coordinate_violations, axis=2)) 306 | valid_indicator = tf.expand_dims(valid_indicator, 2) 307 | # it has shape [num_remaining_boxes, 17, 1] 308 | 309 | v *= tf.to_int32(valid_indicator) 310 | keypoints = tf.concat([y, x, v], axis=2) 311 | 312 | return boxes, keypoints 313 | 314 | 315 | def get_inverse_transform(rotation, translation): 316 | """ 317 | If y = x * rotation + translation 318 | then x = (y - translation) * inverse_rotation. 319 | 320 | Or x = y * inverse_rotation + inverse_translation, 321 | where inverse_translation = - translation * inverse_rotation. 322 | 323 | This function returns transformation in the 324 | format required by `tf.contrib.image.transform`. 325 | 326 | Arguments: 327 | rotation: a float tensor with shape [2, 2]. 328 | translation: a float tensor with shape [1, 2]. 329 | Returns: 330 | a float tensor with shape [8]. 331 | """ 332 | 333 | a, b = rotation[0, 0], rotation[0, 1] 334 | c, d = rotation[1, 0], rotation[1, 1] 335 | 336 | inverse_rotation = tf.stack([d, -b, -c, a]) / (a * d - b * c) 337 | inverse_rotation = tf.reshape(inverse_rotation, [2, 2]) 338 | 339 | inverse_translation = - tf.matmul(translation, inverse_rotation) 340 | inverse_translation = tf.squeeze(inverse_translation, axis=0) 341 | # it has shape [2] 342 | 343 | translate_y, translate_x = tf.unstack(inverse_translation, axis=0) 344 | transform = tf.stack([ 345 | inverse_rotation[0, 0], inverse_rotation[0, 1], translate_x, 346 | inverse_rotation[1, 0], inverse_rotation[1, 1], translate_y, 347 | 0.0, 0.0 348 | ]) 349 | 350 | return transform 351 | -------------------------------------------------------------------------------- /detector/keypoint_subnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import NUM_KEYPOINTS, DATA_FORMAT 3 | from detector.utils import batch_norm_relu, conv2d_same 4 | from detector.fpn import feature_pyramid_network 5 | 6 | 7 | DEPTH = 128 8 | 9 | 10 | class KeypointSubnet: 11 | def __init__(self, backbone_features, is_training, params): 12 | """ 13 | Arguments: 14 | backbone_features: a dict with float tensors. 15 | It contains keys ['c2', 'c3', 'c4', 'c5']. 16 | is_training: a boolean. 17 | params: a dict. 
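# A minimal NumPy round-trip check of the affine inversion used in
# `get_inverse_transform` above: if y = x * R + t, then
# x = y * inv(R) - t * inv(R).  Values are made up.
import numpy as np

theta = np.deg2rad(25.0)
rotation = np.array([
    [np.cos(theta), np.sin(theta)],
    [-np.sin(theta), np.cos(theta)]
])
translation = np.array([[12.0, -7.0]])

inverse_rotation = np.linalg.inv(rotation)
inverse_translation = -np.matmul(translation, inverse_rotation)

x = np.array([[3.0, 4.0], [-1.0, 2.5]])
y = np.matmul(x, rotation) + translation
assert np.allclose(np.matmul(y, inverse_rotation) + inverse_translation, x)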
18 | """ 19 | 20 | self.enriched_features = feature_pyramid_network( 21 | backbone_features, is_training, depth=DEPTH, min_level=2, 22 | add_coarse_features=False, scope='keypoint_fpn' 23 | ) 24 | normalized_enriched_features = { 25 | n: batch_norm_relu(x, is_training, name=f'{n}_batch_norm') 26 | for n, x in self.enriched_features.items() 27 | } 28 | # it is a dict with keys ['p2', 'p3', 'p4', 'p5'] 29 | 30 | upsampled_features = [] 31 | for level in range(2, 6): 32 | with tf.variable_scope(f'phi_subnet_{level}'): 33 | x = normalized_enriched_features[f'p{level}'] 34 | y = phi_subnet(x, is_training, upsample=2**(level - 2)) 35 | upsampled_features.append(y) 36 | 37 | upsampled_features = tf.concat(upsampled_features, axis=1 if DATA_FORMAT == 'channels_first' else 3) 38 | x = conv2d_same(upsampled_features, 64, kernel_size=3, name='final_conv3x3') 39 | x = batch_norm_relu(x, is_training, name='final_bn') 40 | 41 | p = 0.01 # probability of a keypoint 42 | # sigmoid(-log((1 - p) / p)) = p 43 | 44 | import math 45 | value = -math.log((1.0 - p) / p) 46 | keypoints_bias = 17 * [value] 47 | bias_initializer = tf.constant_initializer(keypoints_bias + [0.0]) 48 | 49 | self.heatmaps = tf.layers.conv2d( 50 | x, NUM_KEYPOINTS + 1, kernel_size=1, padding='same', 51 | bias_initializer=bias_initializer, 52 | kernel_initializer=tf.random_normal_initializer(stddev=1e-4), 53 | data_format=DATA_FORMAT, name='heatmaps' 54 | ) 55 | 56 | if DATA_FORMAT == 'channels_first': 57 | 58 | self.heatmaps = tf.transpose(self.heatmaps, [0, 2, 3, 1]) 59 | self.enriched_features = { 60 | n: tf.transpose(x, [0, 2, 3, 1]) 61 | for n, x in self.enriched_features.items() 62 | } 63 | 64 | 65 | def phi_subnet(x, is_training, upsample): 66 | """ 67 | Arguments: 68 | x: a float tensor with shape [b, h, w, c]. 69 | is_training: a boolean. 70 | upsample: an integer. 71 | Returns: 72 | a float tensor with shape [b, upsample * h, upsample * w, depth]. 73 | """ 74 | 75 | x = conv2d_same(x, DEPTH, kernel_size=3, name='conv1') 76 | x = batch_norm_relu(x, is_training, name='bn1') 77 | x = conv2d_same(x, DEPTH, kernel_size=3, name='conv2') 78 | x = batch_norm_relu(x, is_training, name='bn2') 79 | 80 | if DATA_FORMAT == 'channels_first': 81 | x = tf.transpose(x, [0, 2, 3, 1]) 82 | 83 | shape = tf.shape(x) 84 | h, w = shape[1], shape[2] 85 | new_size = [upsample * h, upsample * w] 86 | x = tf.image.resize_bilinear(x, new_size) 87 | 88 | if DATA_FORMAT == 'channels_first': 89 | x = tf.transpose(x, [0, 3, 1, 2]) 90 | 91 | return x 92 | -------------------------------------------------------------------------------- /detector/prn.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import tensorflow.contrib.slim as slim 3 | 4 | 5 | def prn(x, is_training): 6 | """ 7 | Arguments: 8 | x: a float tensor with shape [b, h, w, c]. 9 | is_training: a boolean. 10 | Returns: 11 | a float tensor with shape [b, h, w, c]. 
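# A minimal NumPy sketch of what the pose residual network below computes:
# a two-layer MLP on the flattened crop with a skip connection.  The shapes
# and random weights here are made up; they only illustrate the data flow.
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

b, h, w, c = 2, 7, 5, 17  # made-up crop shape, 17 keypoint channels
x = np.random.rand(b, h, w, c).astype('float32')
flat = x.reshape(b, h * w * c)

w1 = 0.01 * np.random.randn(h * w * c, 1024).astype('float32')
w2 = 0.01 * np.random.randn(1024, h * w * c).astype('float32')

y = relu(np.matmul(flat, w1))   # fc1
y = relu(np.matmul(y, w2))      # fc2
out = (flat + y).reshape(b, h, w, c)  # residual connection, reshape back
print(out.shape)  # (2, 7, 5, 17)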
12 | """ 13 | with tf.variable_scope('PRN'): 14 | 15 | b = tf.shape(x)[0] 16 | _, h, w, c = x.shape.as_list() # must be static 17 | x = tf.reshape(x, [b, h * w * c]) # flatten 18 | 19 | with slim.arg_scope([slim.fully_connected], weights_initializer=tf.variance_scaling_initializer()): 20 | y = slim.fully_connected(x, 1024, activation_fn=tf.nn.relu, scope='fc1') 21 | # y = slim.dropout(y, keep_prob=0.5, is_training=is_training) 22 | y = slim.fully_connected(y, h * w * c, activation_fn=tf.nn.relu, scope='fc2') 23 | 24 | x += y 25 | return tf.reshape(x, [b, h, w, c]) 26 | -------------------------------------------------------------------------------- /detector/retinanet.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import PARALLEL_ITERATIONS, POSITIVES_THRESHOLD, NEGATIVES_THRESHOLD 3 | from detector.utils import batch_non_max_suppression, batch_norm_relu 4 | from detector.training_target_creation import get_training_targets 5 | from detector.box_predictor import retinanet_box_predictor 6 | from detector.anchor_generator import AnchorGenerator 7 | from detector.fpn import feature_pyramid_network 8 | 9 | 10 | DEPTH = 128 11 | 12 | 13 | class RetinaNet: 14 | def __init__(self, backbone_features, image_shape, is_training, params): 15 | """ 16 | Arguments: 17 | backbone_features: a dict with float tensors. 18 | It contains keys ['c2', 'c3', 'c4', 'c5']. 19 | image_shape: an int tensor with shape [4]. 20 | is_training: a boolean. 21 | params: a dict. 22 | """ 23 | 24 | enriched_features = feature_pyramid_network( 25 | backbone_features, is_training, 26 | depth=DEPTH, min_level=3, 27 | add_coarse_features=True, scope='fpn' 28 | ) 29 | enriched_features = { 30 | n: batch_norm_relu(x, is_training, name=f'{n}_batch_norm') 31 | for n, x in enriched_features.items() 32 | } 33 | 34 | # the detector supports images of various sizes 35 | image_height = image_shape[1] 36 | image_width = image_shape[2] 37 | 38 | anchor_generator = AnchorGenerator( 39 | strides=[8, 16, 32, 64, 128], 40 | scales=[32, 64, 128, 256, 512], 41 | scale_multipliers=[1.0, 1.4142], 42 | aspect_ratios=[1.0, 2.0, 0.5] 43 | ) 44 | self.anchors = anchor_generator(image_height, image_width) # shape [num_anchors, 4] 45 | num_anchors_per_location = anchor_generator.num_anchors_per_location 46 | 47 | self.raw_predictions = retinanet_box_predictor( 48 | [enriched_features[f'p{i}'] for i in range(3, 8)], 49 | is_training, num_anchors_per_location=num_anchors_per_location, 50 | depth=64, min_level=3 51 | ) 52 | # it returns a dict with two float tensors: 53 | # `encoded_boxes` has shape [batch_size, num_anchors, 4], 54 | # `class_predictions` has shape [batch_size, num_anchors] 55 | 56 | def get_predictions(self, score_threshold=0.05, iou_threshold=0.5, max_detections=25): 57 | """Postprocess outputs of the network. 58 | 59 | Returns: 60 | boxes: a float tensor with shape [batch_size, N, 4]. 61 | scores: a float tensor with shape [batch_size, N]. 62 | num_boxes: an int tensor with shape [batch_size], it 63 | represents the number of detections on an image. 64 | 65 | Where N = max_detections. 
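# A minimal sketch of how a caller might unpack the padded outputs described
# above: rows past `num_boxes[i]` are zero padding.  The arrays here are
# random stand-ins for the tensors returned by the graph.
import numpy as np

max_detections = 5
boxes = np.random.rand(2, max_detections, 4)   # [batch_size, N, 4]
scores = np.random.rand(2, max_detections)     # [batch_size, N]
num_boxes = np.array([3, 1])                   # [batch_size]

for i, n in enumerate(num_boxes):
    valid_boxes = boxes[i, :n]
    valid_scores = scores[i, :n]
    print(valid_boxes.shape, valid_scores.shape)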
66 | """ 67 | with tf.name_scope('postprocessing'): 68 | 69 | encoded_boxes = self.raw_predictions['encoded_boxes'] 70 | # it has shape [batch_size, num_anchors, 4] 71 | 72 | class_predictions = self.raw_predictions['class_predictions'] 73 | scores = tf.sigmoid(class_predictions) 74 | # it has shape [batch_size, num_anchors] 75 | 76 | with tf.name_scope('nms'): 77 | boxes, scores, num_detections = batch_non_max_suppression( 78 | encoded_boxes, self.anchors, scores, score_threshold=score_threshold, 79 | iou_threshold=iou_threshold, max_detections=max_detections 80 | ) 81 | return {'boxes': boxes, 'scores': scores, 'num_boxes': num_detections} 82 | 83 | def loss(self, groundtruth, params): 84 | """Compute scalar loss tensors with respect to provided groundtruth. 85 | 86 | Arguments: 87 | groundtruth: a dict with the following keys 88 | 'boxes': a float tensor with shape [batch_size, max_num_boxes, 4]. 89 | 'num_boxes': an int tensor with shape [batch_size], 90 | where max_num_boxes = max(num_boxes). 91 | params: a dict with parameters. 92 | Returns: 93 | two float tensors with shape []. 94 | """ 95 | regression_targets, matches = self._create_targets(groundtruth) 96 | 97 | with tf.name_scope('losses'): 98 | 99 | # whether an anchor contains something 100 | is_matched = tf.to_float(tf.greater_equal(matches, 0)) 101 | 102 | not_ignore = tf.to_float(tf.greater_equal(matches, -1)) 103 | # if a value is `-2` then we ignore its anchor 104 | 105 | with tf.name_scope('classification_loss'): 106 | 107 | class_predictions = self.raw_predictions['class_predictions'] 108 | # shape [batch_size, num_anchors] 109 | 110 | cls_losses = focal_loss( 111 | class_predictions, is_matched, weights=not_ignore, 112 | gamma=params['gamma'], alpha=params['alpha'] 113 | ) # shape [batch_size, num_anchors] 114 | 115 | cls_loss = tf.reduce_sum(cls_losses, axis=[0, 1]) 116 | 117 | with tf.name_scope('localization_loss'): 118 | 119 | encoded_boxes = self.raw_predictions['encoded_boxes'] 120 | # it has shape [batch_size, num_anchors, 4] 121 | 122 | loc_losses = localization_loss( 123 | encoded_boxes, regression_targets, 124 | weights=is_matched 125 | ) # shape [batch_size, num_anchors] 126 | 127 | loc_loss = tf.reduce_sum(loc_losses, axis=[0, 1]) 128 | 129 | with tf.name_scope('normalization'): 130 | matches_per_image = tf.reduce_sum(is_matched, axis=1) # shape [batch_size] 131 | num_matches = tf.reduce_sum(matches_per_image) # shape [] 132 | normalizer = tf.maximum(num_matches, 1.0) 133 | 134 | return {'localization_loss': loc_loss/normalizer, 'classification_loss': cls_loss/normalizer} 135 | 136 | def _create_targets(self, groundtruth): 137 | """ 138 | Arguments: 139 | groundtruth: a dict with the following keys 140 | 'boxes': a float tensor with shape [batch_size, N, 4]. 141 | 'num_boxes': an int tensor with shape [batch_size]. 142 | Returns: 143 | regression_targets: a float tensor with shape [batch_size, num_anchors, 4]. 144 | matches: an int tensor with shape [batch_size, num_anchors], 145 | `-1` means that an anchor box is negative (background), 146 | and `-2` means that we must ignore this anchor box. 
147 | """ 148 | def fn(x): 149 | boxes, num_boxes = x 150 | boxes = boxes[:num_boxes] 151 | 152 | regression_targets, matches = get_training_targets( 153 | self.anchors, boxes, 154 | positives_threshold=POSITIVES_THRESHOLD, 155 | negatives_threshold=NEGATIVES_THRESHOLD 156 | ) 157 | return regression_targets, matches 158 | 159 | with tf.name_scope('target_creation'): 160 | regression_targets, matches = tf.map_fn( 161 | fn, [groundtruth['boxes'], groundtruth['num_boxes']], 162 | dtype=(tf.float32, tf.int32), 163 | parallel_iterations=PARALLEL_ITERATIONS, 164 | back_prop=False, swap_memory=False, infer_shape=True 165 | ) 166 | return regression_targets, matches 167 | 168 | 169 | def localization_loss(predictions, targets, weights): 170 | """A usual L1 smooth loss. 171 | 172 | Arguments: 173 | predictions: a float tensor with shape [batch_size, num_anchors, 4], 174 | representing the (encoded) predicted locations of objects. 175 | targets: a float tensor with shape [batch_size, num_anchors, 4], 176 | representing the regression targets. 177 | weights: a float tensor with shape [batch_size, num_anchors]. 178 | Returns: 179 | a float tensor with shape [batch_size, num_anchors]. 180 | """ 181 | abs_diff = tf.abs(predictions - targets) 182 | abs_diff_lt_1 = tf.less(abs_diff, 1.0) 183 | loss = tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5) 184 | return weights * tf.reduce_sum(loss, axis=2) 185 | 186 | 187 | def focal_loss(predictions, targets, weights, gamma=2.0, alpha=0.25): 188 | """ 189 | Here it is assumed that there is only one class. 190 | 191 | Arguments: 192 | predictions: a float tensor with shape [batch_size, num_anchors], 193 | representing the predicted logits. 194 | targets: a float tensor with shape [batch_size, num_anchors], 195 | representing binary classification targets. 196 | weights: a float tensor with shape [batch_size, num_anchors]. 197 | gamma, alpha: float numbers. 198 | Returns: 199 | a float tensor with shape [batch_size, num_anchors]. 200 | """ 201 | positive_label_mask = tf.equal(targets, 1.0) 202 | 203 | negative_log_p_t = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=predictions) 204 | probabilities = tf.sigmoid(predictions) 205 | p_t = tf.where(positive_label_mask, probabilities, 1.0 - probabilities) 206 | # they all have shape [batch_size, num_anchors] 207 | 208 | modulating_factor = tf.pow(1.0 - p_t, gamma) 209 | weighted_loss = tf.where( 210 | positive_label_mask, 211 | alpha * negative_log_p_t, 212 | (1.0 - alpha) * negative_log_p_t 213 | ) 214 | focal_loss = modulating_factor * weighted_loss 215 | # they all have shape [batch_size, num_anchors] 216 | 217 | return weights * focal_loss 218 | -------------------------------------------------------------------------------- /detector/training_target_creation.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.utils import encode, iou 3 | 4 | 5 | def get_training_targets( 6 | anchors, groundtruth_boxes, 7 | positives_threshold=0.5, 8 | negatives_threshold=0.4): 9 | """ 10 | Arguments: 11 | anchors: a float tensor with shape [num_anchors, 4]. 12 | groundtruth_boxes: a float tensor with shape [N, 4]. 13 | positives_threshold: a float number. 14 | negatives_threshold: a float number. 15 | Returns: 16 | regression_targets: a float tensor with shape [num_anchors, 4]. 17 | matches: an int tensor with shape [num_anchors]. 
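# Minimal NumPy versions of the two losses defined above, evaluated on a
# handful of made-up anchors (the per-anchor weights are omitted here).
import numpy as np

def smooth_l1(predictions, targets):
    abs_diff = np.abs(predictions - targets)
    return np.where(abs_diff < 1.0, 0.5 * np.square(abs_diff), abs_diff - 0.5).sum(axis=-1)

def focal(logits, targets, gamma=2.0, alpha=0.25):
    probabilities = 1.0 / (1.0 + np.exp(-logits))
    p_t = np.where(targets == 1.0, probabilities, 1.0 - probabilities)
    cross_entropy = -np.log(p_t)
    alpha_t = np.where(targets == 1.0, alpha, 1.0 - alpha)
    return np.power(1.0 - p_t, gamma) * alpha_t * cross_entropy

print(smooth_l1(np.array([[0.2, -0.1, 0.4, 2.0]]), np.zeros([1, 4])))
print(focal(np.array([2.0, -4.0, 0.5]), np.array([1.0, 0.0, 0.0])))
# the easy negative (logit -4.0) contributes almost nothing to the loss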
18 | """ 19 | 20 | with tf.name_scope('matching'): 21 | 22 | N = tf.shape(groundtruth_boxes)[0] 23 | num_anchors = tf.shape(anchors)[0] 24 | only_background = tf.fill([num_anchors], -1) 25 | 26 | matches = tf.to_int32(tf.cond( 27 | tf.greater(N, 0), 28 | lambda: match_boxes( 29 | anchors, groundtruth_boxes, 30 | positives_threshold=positives_threshold, 31 | negatives_threshold=negatives_threshold, 32 | force_match_groundtruth=True 33 | ), 34 | lambda: only_background 35 | )) 36 | 37 | with tf.name_scope('target_creation'): 38 | regression_targets = create_targets( 39 | anchors, groundtruth_boxes, matches 40 | ) 41 | 42 | return regression_targets, matches 43 | 44 | 45 | def match_boxes( 46 | anchors, groundtruth_boxes, positives_threshold=0.5, 47 | negatives_threshold=0.4, force_match_groundtruth=True): 48 | """ 49 | If an anchor has IoU over `positives_threshold` with any groundtruth box, 50 | it will be set a positive label. 51 | Anchors which have highest IoU for a groundtruth box will 52 | also be assigned a positive label. 53 | Meanwhile, if other anchors have IoU less than `negatives_threshold` 54 | with all groundtruth boxes, their labels will be negative. 55 | 56 | Matching algorithm: 57 | 1) for each groundtruth box choose the anchor with largest IoU, 58 | 2) remove this set of anchors from the set of all anchors, 59 | 3) for each remaining anchor choose the groundtruth box with largest IoU, 60 | but only if this IoU is larger than `positives_threshold`, 61 | 4) remove this set of matched anchors from the set of all anchors, 62 | 5) for each remaining anchor if it has IoU less than `negatives_threshold` 63 | with all groundtruth boxes set it to `negative`, otherwise set it to `ignore`. 64 | 65 | Note: after step 1, it could happen that for some two groundtruth boxes 66 | chosen anchors are the same. Let's hope this never happens. 67 | Also see the comments below. 68 | 69 | Arguments: 70 | anchors: a float tensor with shape [num_anchors, 4]. 71 | groundtruth_boxes: a float tensor with shape [N, 4]. 72 | positives_threshold: a float number. 73 | negatives_threshold: a float number. 74 | force_match_groundtruth: a boolean, whether to try to make sure 75 | that all groundtruth boxes are matched. 76 | Returns: 77 | an int tensor with shape [num_anchors], possible values 78 | that it can contain are [-2, -1, 0, 1, 2, ..., (N - 1)], 79 | where numbers in the range [0, N - 1] mean indices of the groundtruth boxes, 80 | `-1` means that an anchor box is negative (background), 81 | and `-2` means that we must ignore this anchor box. 
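# A minimal NumPy sketch of the thresholding step of the matching described
# above (the force-matching of leftover groundtruth boxes is omitted).
# The IoU values are made up.
import numpy as np

iou_matrix = np.array([
    [0.62, 0.10, 0.48],
    [0.05, 0.30, 0.55],
])  # shape [N, num_anchors]
positives_threshold, negatives_threshold = 0.5, 0.4

matches = iou_matrix.argmax(axis=0)    # best groundtruth box for each anchor
matched_vals = iou_matrix.max(axis=0)  # the corresponding IoU

matches = np.where(matched_vals >= positives_threshold, matches, -1)
to_ignore = (matched_vals < positives_threshold) & (matched_vals >= negatives_threshold)
matches = np.where(to_ignore, -2, matches)
print(matches)  # [0, -1, 1]: anchors 0 and 2 are positive, anchor 1 is background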
82 | """ 83 | assert positives_threshold >= negatives_threshold 84 | 85 | # for each anchor box choose the groundtruth box with largest iou 86 | similarity_matrix = iou(groundtruth_boxes, anchors) # shape [N, num_anchors] 87 | matches = tf.argmax(similarity_matrix, axis=0, output_type=tf.int32) # shape [num_anchors] 88 | matched_vals = tf.reduce_max(similarity_matrix, axis=0) # shape [num_anchors] 89 | is_positive = tf.to_int32(tf.greater_equal(matched_vals, positives_threshold)) 90 | 91 | if positives_threshold == negatives_threshold: 92 | is_negative = 1 - is_positive 93 | matches = matches * is_positive + (-1 * is_negative) 94 | else: 95 | is_negative = tf.to_int32(tf.greater(negatives_threshold, matched_vals)) 96 | to_ignore = (1 - is_positive) * (1 - is_negative) 97 | matches = matches * is_positive + (-1 * is_negative) + (-2 * to_ignore) 98 | 99 | # after this, it could happen that some groundtruth 100 | # boxes are not matched with any anchor box 101 | 102 | if force_match_groundtruth: 103 | # now we must ensure that each row (groundtruth box) is matched to 104 | # at least one column (which is not guaranteed 105 | # otherwise if `positives_threshold` is high) 106 | 107 | # for each groundtruth box choose the anchor box with largest iou 108 | # (force match for each groundtruth box) 109 | forced_matches_ids = tf.argmax(similarity_matrix, axis=1, output_type=tf.int32) # shape [N] 110 | # if all indices in forced_matches_ids are different then all rows will be matched 111 | 112 | num_anchors = tf.shape(anchors)[0] 113 | forced_matches_indicators = tf.one_hot(forced_matches_ids, depth=num_anchors, dtype=tf.int32) # shape [N, num_anchors] 114 | forced_match_row_ids = tf.argmax(forced_matches_indicators, axis=0, output_type=tf.int32) # shape [num_anchors] 115 | 116 | # some forced matches could be very bad! 117 | forced_matches_values = tf.reduce_max(similarity_matrix, axis=1) # shape [N] 118 | small_iou = 0.05 # this requires that forced match has at least small intersection 119 | is_okay = tf.to_int32(tf.greater_equal(forced_matches_values, small_iou)) # shape [N] 120 | forced_matches_indicators = forced_matches_indicators * tf.expand_dims(is_okay, axis=1) 121 | 122 | forced_match_mask = tf.greater(tf.reduce_max(forced_matches_indicators, axis=0), 0) # shape [num_anchors] 123 | matches = tf.where(forced_match_mask, forced_match_row_ids, matches) 124 | # even after this it could happen that some rows aren't matched, 125 | # but i believe that this event has low probability 126 | 127 | return matches 128 | 129 | 130 | def create_targets(anchors, groundtruth_boxes, matches): 131 | """Returns regression and classification targets for each anchor. 132 | 133 | Arguments: 134 | anchors: a float tensor with shape [num_anchors, 4]. 135 | groundtruth_boxes: a float tensor with shape [N, 4]. 136 | matches: an int tensor with shape [num_anchors]. 137 | Returns: 138 | a float tensor with shape [num_anchors, 4]. 
139 | """ 140 | matched_anchor_indices = tf.where(tf.greater_equal(matches, 0)) # shape [num_matches, 1] 141 | matched_anchor_indices = tf.to_int32(tf.squeeze(matched_anchor_indices, axis=1)) 142 | 143 | unmatched_anchor_indices = tf.where(tf.less(matches, 0)) # shape [num_anchors - num_matches, 1] 144 | unmatched_anchor_indices = tf.to_int32(tf.squeeze(unmatched_anchor_indices, axis=1)) 145 | 146 | matched_gt_indices = tf.gather(matches, matched_anchor_indices) # shape [num_matches] 147 | matched_gt_boxes = tf.gather(groundtruth_boxes, matched_gt_indices) # shape [num_matches, 4] 148 | matched_anchors = tf.gather(anchors, matched_anchor_indices) # shape [num_matches, 4] 149 | 150 | matched_reg_targets = encode(matched_gt_boxes, matched_anchors) # shape [num_matches, 4] 151 | num_unmatched = tf.size(unmatched_anchor_indices) # num_anchors - num_matches 152 | unmatched_reg_targets = tf.zeros([num_unmatched, 4], dtype=tf.float32) 153 | 154 | regression_targets = tf.dynamic_stitch( 155 | [matched_anchor_indices, unmatched_anchor_indices], 156 | [matched_reg_targets, unmatched_reg_targets] 157 | ) # shape [num_anchors, 4] 158 | 159 | return regression_targets 160 | -------------------------------------------------------------------------------- /detector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .box_utils import iou, area, intersection, encode 2 | from .layer_utils import conv2d_same, batch_norm_relu 3 | from .nms import batch_non_max_suppression 4 | -------------------------------------------------------------------------------- /detector/utils/box_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import EPSILON, SCALE_FACTORS 3 | 4 | 5 | """ 6 | Tools for dealing with bounding boxes. 7 | 8 | All boxes are of the format [ymin, xmin, ymax, xmax] if not stated otherwise. 9 | Also the following must be true: ymin < ymax and xmin < xmax. 10 | And box coordinates are normalized to the [0, 1] range. 11 | """ 12 | 13 | 14 | def iou(boxes1, boxes2): 15 | """Computes pairwise intersection-over-union between two box collections. 16 | Arguments: 17 | boxes1: a float tensor with shape [N, 4]. 18 | boxes2: a float tensor with shape [M, 4]. 19 | Returns: 20 | a float tensor with shape [N, M] representing pairwise iou scores. 21 | """ 22 | intersections = intersection(boxes1, boxes2) 23 | areas1 = area(boxes1) 24 | areas2 = area(boxes2) 25 | unions = tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections 26 | return tf.clip_by_value(tf.divide(intersections, unions + EPSILON), 0.0, 1.0) 27 | 28 | 29 | def intersection(boxes1, boxes2): 30 | """Compute pairwise intersection areas between boxes. 31 | Arguments: 32 | boxes1: a float tensor with shape [N, 4]. 33 | boxes2: a float tensor with shape [M, 4]. 34 | Returns: 35 | a float tensor with shape [N, M] representing pairwise intersections. 
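# The same pairwise IoU computation written in NumPy, as a minimal sketch
# with made-up boxes in normalized [ymin, xmin, ymax, xmax] format.
import numpy as np

def pairwise_iou(boxes1, boxes2):
    ymin1, xmin1, ymax1, xmax1 = np.split(boxes1, 4, axis=1)  # each [N, 1]
    ymin2, xmin2, ymax2, xmax2 = np.split(boxes2, 4, axis=1)  # each [M, 1]

    heights = np.maximum(0.0, np.minimum(ymax1, ymax2.T) - np.maximum(ymin1, ymin2.T))
    widths = np.maximum(0.0, np.minimum(xmax1, xmax2.T) - np.maximum(xmin1, xmin2.T))
    intersections = heights * widths  # shape [N, M]

    areas1 = (ymax1 - ymin1) * (xmax1 - xmin1)  # [N, 1]
    areas2 = (ymax2 - ymin2) * (xmax2 - xmin2)  # [M, 1]
    unions = areas1 + areas2.T - intersections
    return intersections / unions

boxes1 = np.array([[0.0, 0.0, 0.5, 0.5]])
boxes2 = np.array([[0.0, 0.0, 0.5, 0.5], [0.25, 0.25, 0.75, 0.75]])
print(pairwise_iou(boxes1, boxes2))  # [[1.0, ~0.143]]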
36 | """ 37 | ymin1, xmin1, ymax1, xmax1 = tf.split(boxes1, num_or_size_splits=4, axis=1) 38 | ymin2, xmin2, ymax2, xmax2 = tf.split(boxes2, num_or_size_splits=4, axis=1) 39 | # they all have shapes like [None, 1] 40 | 41 | all_pairs_min_ymax = tf.minimum(ymax1, tf.transpose(ymax2)) 42 | all_pairs_max_ymin = tf.maximum(ymin1, tf.transpose(ymin2)) 43 | intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin) 44 | all_pairs_min_xmax = tf.minimum(xmax1, tf.transpose(xmax2)) 45 | all_pairs_max_xmin = tf.maximum(xmin1, tf.transpose(xmin2)) 46 | intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin) 47 | # they all have shape [N, M] 48 | 49 | return intersect_heights * intersect_widths 50 | 51 | 52 | def area(boxes): 53 | """Computes area of boxes. 54 | Arguments: 55 | boxes: a float tensor with shape [N, 4]. 56 | Returns: 57 | a float tensor with shape [N] representing box areas. 58 | """ 59 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 60 | return (ymax - ymin) * (xmax - xmin) 61 | 62 | 63 | def to_center_coordinates(boxes): 64 | """Convert bounding boxes of the format 65 | [ymin, xmin, ymax, xmax] to the format [cy, cx, h, w]. 66 | 67 | Arguments: 68 | boxes: a float tensor with shape [N, 4]. 69 | Returns: 70 | a list of float tensors with shape [N] 71 | that represent cy, cx, h, w. 72 | """ 73 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 74 | h, w = ymax - ymin, xmax - xmin 75 | cy, cx = ymin + 0.5 * h, xmin + 0.5 * w 76 | return [cy, cx, h, w] 77 | 78 | 79 | def encode(boxes, anchors): 80 | """Encode boxes with respect to anchors (or proposals). 81 | 82 | Arguments: 83 | boxes: a float tensor with shape [N, 4]. 84 | anchors: a float tensor with shape [N, 4]. 85 | Returns: 86 | a float tensor with shape [N, 4], 87 | anchor-encoded boxes of the format [ty, tx, th, tw]. 88 | """ 89 | 90 | ycenter_a, xcenter_a, ha, wa = to_center_coordinates(anchors) 91 | ycenter, xcenter, h, w = to_center_coordinates(boxes) 92 | 93 | # to avoid NaN in division and log below 94 | ha += EPSILON 95 | wa += EPSILON 96 | h += EPSILON 97 | w += EPSILON 98 | 99 | ty = (ycenter - ycenter_a)/ha 100 | tx = (xcenter - xcenter_a)/wa 101 | th = tf.log(h / ha) 102 | tw = tf.log(w / wa) 103 | 104 | ty *= SCALE_FACTORS[0] 105 | tx *= SCALE_FACTORS[1] 106 | th *= SCALE_FACTORS[2] 107 | tw *= SCALE_FACTORS[3] 108 | 109 | return tf.stack([ty, tx, th, tw], axis=1) 110 | 111 | 112 | def decode(codes, anchors): 113 | """Decode relative codes to normal boxes. 114 | 115 | Arguments: 116 | codes: a float tensor with shape [N, 4], 117 | anchor-encoded boxes of the format [ty, tx, th, tw]. 118 | anchors: a float tensor with shape [N, 4]. 119 | Returns: 120 | a float tensor with shape [N, 4], 121 | bounding boxes of the format [ymin, xmin, ymax, xmax]. 
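# A minimal NumPy round-trip check of the encoding above: decoding the
# encoded box with the same anchor recovers the original box.  The scale
# factors here are placeholders; the real values live in detector/constants.py.
import numpy as np

scale_factors = np.array([10.0, 10.0, 5.0, 5.0])  # assumed for illustration

def to_center(boxes):
    ymin, xmin, ymax, xmax = np.split(boxes, 4, axis=1)
    h, w = ymax - ymin, xmax - xmin
    return ymin + 0.5 * h, xmin + 0.5 * w, h, w

def encode_np(boxes, anchors):
    cy, cx, h, w = to_center(boxes)
    cya, cxa, ha, wa = to_center(anchors)
    codes = np.concatenate([(cy - cya)/ha, (cx - cxa)/wa, np.log(h/ha), np.log(w/wa)], axis=1)
    return codes * scale_factors

def decode_np(codes, anchors):
    cya, cxa, ha, wa = to_center(anchors)
    ty, tx, th, tw = np.split(codes / scale_factors, 4, axis=1)
    cy, cx = ty * ha + cya, tx * wa + cxa
    h, w = np.exp(th) * ha, np.exp(tw) * wa
    return np.concatenate([cy - 0.5 * h, cx - 0.5 * w, cy + 0.5 * h, cx + 0.5 * w], axis=1)

boxes = np.array([[0.1, 0.2, 0.4, 0.5]])
anchors = np.array([[0.0, 0.1, 0.5, 0.6]])
assert np.allclose(decode_np(encode_np(boxes, anchors), anchors), boxes)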
122 | """ 123 | 124 | ycenter_a, xcenter_a, ha, wa = to_center_coordinates(anchors) 125 | ty, tx, th, tw = tf.unstack(codes, axis=1) 126 | 127 | ty /= SCALE_FACTORS[0] 128 | tx /= SCALE_FACTORS[1] 129 | th /= SCALE_FACTORS[2] 130 | tw /= SCALE_FACTORS[3] 131 | 132 | h = tf.exp(th) * ha 133 | w = tf.exp(tw) * wa 134 | ycenter = ty * ha + ycenter_a 135 | xcenter = tx * wa + xcenter_a 136 | 137 | ymin, xmin = ycenter - 0.5 * h, xcenter - 0.5 * w 138 | ymax, xmax = ycenter + 0.5 * h, xcenter + 0.5 * w 139 | return tf.stack([ymin, xmin, ymax, xmax], axis=1) 140 | -------------------------------------------------------------------------------- /detector/utils/layer_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import DATA_FORMAT 3 | 4 | 5 | BATCH_NORM_MOMENTUM = 0.95 6 | BATCH_NORM_EPSILON = 1e-3 7 | 8 | 9 | def batch_norm_relu(x, is_training, use_relu=True, name=None): 10 | x = tf.layers.batch_normalization( 11 | inputs=x, axis=1 if DATA_FORMAT == 'channels_first' else 3, 12 | momentum=BATCH_NORM_MOMENTUM, epsilon=BATCH_NORM_EPSILON, 13 | center=True, scale=True, training=is_training, 14 | fused=True, name=name 15 | ) 16 | return x if not use_relu else tf.nn.relu(x) 17 | 18 | 19 | def conv2d_same(x, num_filters, kernel_size=3, stride=1, name=None): 20 | 21 | assert kernel_size in [1, 3] 22 | assert stride in [1, 2] 23 | 24 | if kernel_size == 3: 25 | 26 | if DATA_FORMAT == 'channels_first': 27 | paddings = [[0, 0], [0, 0], [1, 1], [1, 1]] 28 | else: 29 | paddings = [[0, 0], [1, 1], [1, 1], [0, 0]] 30 | 31 | x = tf.pad(x, paddings) 32 | 33 | return tf.layers.conv2d( 34 | inputs=x, filters=num_filters, 35 | kernel_size=kernel_size, strides=stride, 36 | padding='valid', use_bias=False, 37 | kernel_initializer=tf.variance_scaling_initializer(), 38 | data_format=DATA_FORMAT, name=name 39 | ) 40 | -------------------------------------------------------------------------------- /detector/utils/nms.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector.constants import PARALLEL_ITERATIONS 3 | from detector.utils.box_utils import decode 4 | 5 | 6 | def batch_non_max_suppression( 7 | encoded_boxes, 8 | anchors, scores, 9 | score_threshold, 10 | iou_threshold, 11 | max_detections): 12 | """ 13 | Arguments: 14 | encoded_boxes: a float tensor with shape [batch_size, N, 4]. 15 | anchors: a float tensor with shape [N, 4]. 16 | scores: a float tensor with shape [batch_size, N]. 17 | score_threshold: a float number. 18 | iou_threshold: a float number. 19 | max_detections: an integer. 20 | Returns: 21 | boxes: a float tensor with shape [batch_size, N', 4]. 22 | scores: a float tensor with shape [batch_size, N']. 23 | num_detections: an int tensor with shape [batch_size]. 24 | 25 | Where N' = max_detections. 
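# A minimal sketch of why the per-image results are zero-padded below:
# padding to a fixed `max_detections` lets images with different numbers of
# detections be stacked into one batch tensor, while `num_detections`
# records how many rows are valid.  The boxes here are made up.
import numpy as np

max_detections = 4
kept_boxes = np.array([[0.1, 0.1, 0.3, 0.3], [0.5, 0.5, 0.9, 0.8]])  # after NMS
num_boxes = len(kept_boxes)

padded = np.pad(kept_boxes, [(0, max_detections - num_boxes), (0, 0)])
print(padded.shape, num_boxes)  # (4, 4) 2, the last two rows are zeros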
26 | """ 27 | def fn(x): 28 | encoded_boxes, scores = x 29 | 30 | is_confident = scores >= score_threshold # shape [N] 31 | encoded_boxes = tf.boolean_mask(encoded_boxes, is_confident) # shape [num_confident, 4] 32 | scores = tf.boolean_mask(scores, is_confident) # shape [num_confident] 33 | chosen_anchors = tf.boolean_mask(anchors, is_confident) # shape [num_confident, 4] 34 | 35 | boxes = decode(encoded_boxes, chosen_anchors) # shape [num_confident, 4] 36 | boxes = tf.clip_by_value(boxes, 0.0, 1.0) 37 | 38 | selected_indices = tf.image.non_max_suppression( 39 | boxes, scores, max_output_size=max_detections, 40 | iou_threshold=iou_threshold, score_threshold=score_threshold 41 | ) 42 | 43 | boxes = tf.gather(boxes, selected_indices) 44 | scores = tf.gather(scores, selected_indices) 45 | num_boxes = tf.to_int32(tf.size(selected_indices)) 46 | 47 | zero_padding = max_detections - num_boxes 48 | boxes = tf.pad(boxes, [[0, zero_padding], [0, 0]]) 49 | scores = tf.pad(scores, [[0, zero_padding]]) 50 | 51 | boxes.set_shape([max_detections, 4]) 52 | scores.set_shape([max_detections]) 53 | return boxes, scores, num_boxes 54 | 55 | boxes, scores, num_detections = tf.map_fn( 56 | fn, [encoded_boxes, scores], 57 | dtype=(tf.float32, tf.float32, tf.int32), 58 | parallel_iterations=PARALLEL_ITERATIONS, 59 | back_prop=False, swap_memory=False, infer_shape=True 60 | ) 61 | return boxes, scores, num_detections 62 | -------------------------------------------------------------------------------- /inference/detector.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | import numpy as np 3 | 4 | 5 | class Detector: 6 | def __init__(self, model_path, gpu_memory_fraction=0.25, visible_device_list='0'): 7 | """ 8 | Arguments: 9 | model_path: a string, path to a pb file. 10 | gpu_memory_fraction: a float number. 11 | visible_device_list: a string. 12 | """ 13 | with tf.gfile.GFile(model_path, 'rb') as f: 14 | graph_def = tf.GraphDef() 15 | graph_def.ParseFromString(f.read()) 16 | 17 | graph = tf.Graph() 18 | with graph.as_default(): 19 | tf.import_graph_def(graph_def, name='import') 20 | 21 | self.input_image = graph.get_tensor_by_name('import/images:0') 22 | output_names = [ 23 | 'boxes', 'scores', 'num_boxes', 24 | 'keypoint_heatmaps', 'segmentation_masks', 25 | 'keypoint_scores', 'keypoint_positions' 26 | ] 27 | self.output_ops = {n: graph.get_tensor_by_name(f'import/{n}:0') for n in output_names} 28 | 29 | gpu_options = tf.GPUOptions( 30 | per_process_gpu_memory_fraction=gpu_memory_fraction, 31 | visible_device_list=visible_device_list 32 | ) 33 | config_proto = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False) 34 | self.sess = tf.Session(graph=graph, config=config_proto) 35 | 36 | def __call__(self, image, score_threshold=0.05): 37 | """ 38 | Arguments: 39 | image: a numpy uint8 array with shape [height, width, 3], 40 | that represents a RGB image. 41 | score_threshold: a float number. 
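# A minimal sketch (the file name is made up) of preparing an input for this
# detector: the frozen graph expects height and width divisible by 128 (see
# the assert just below), so resize to the nearest multiples before calling.
import numpy as np
from PIL import Image

image = Image.open('some_image.jpg')
w, h = image.size
new_w = max(128, 128 * round(w / 128))
new_h = max(128, 128 * round(h / 128))
image = image.resize((new_w, new_h))
image_array = np.array(image)  # shape [new_h, new_w, 3]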
42 | """ 43 | 44 | h, w, _ = image.shape 45 | assert h % 128 == 0 and w % 128 == 0 46 | 47 | feed_dict = {self.input_image: np.expand_dims(image, 0)} 48 | outputs = self.sess.run(self.output_ops, feed_dict) 49 | outputs.update({ 50 | n: v[0] for n, v in outputs.items() 51 | if n not in ['keypoint_scores', 'keypoint_positions'] 52 | }) 53 | 54 | n = outputs['num_boxes'] 55 | to_keep = outputs['scores'][:n] > score_threshold 56 | outputs['boxes'] = outputs['boxes'][:n][to_keep] 57 | outputs['scores'] = outputs['scores'][:n][to_keep] 58 | outputs['keypoint_positions'] = outputs['keypoint_positions'][to_keep] 59 | outputs['keypoint_scores'] = outputs['keypoint_scores'][to_keep] 60 | 61 | return outputs 62 | -------------------------------------------------------------------------------- /inference/predict.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "from PIL import Image, ImageDraw\n", 11 | "import os\n", 12 | "import cv2\n", 13 | "import random\n", 14 | "import time\n", 15 | "\n", 16 | "from detector import Detector" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "MODEL_PATH = 'model.pb'\n", 26 | "IMAGES_FOLDER = '/home/dan/datasets/COCO/images/val2017/'" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "# Load the model" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "detector = Detector(MODEL_PATH, visible_device_list='0')" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "# Get an image" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "names = os.listdir(IMAGES_FOLDER)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "n = random.choice(names)\n", 68 | "path = os.path.join(IMAGES_FOLDER, n)\n", 69 | "\n", 70 | "image = Image.open(path)\n", 71 | "image = image.resize((640, 640))\n", 72 | "image" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# Predict" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "image_array = np.array(image)\n", 89 | "outputs = detector(image_array, score_threshold=0.0001)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "# Show detections" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "EDGES = [\n", 106 | " (0, 1), (0, 2),\n", 107 | " (1, 3), (2, 4),\n", 108 | " (5, 7), (7, 9), (6, 8), (8, 10),\n", 109 | " (11, 13), (13, 15), (12, 14), (14, 16),\n", 110 | " (3, 5), (4, 6),\n", 111 | " (5, 11), (6, 12)\n", 112 | "]" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "outputs['keypoint_scores']" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 
| "source": [ 130 | "def draw_everything(image, outputs):\n", 131 | "\n", 132 | " image_copy = image.copy()\n", 133 | " image_copy.putalpha(255)\n", 134 | " draw = ImageDraw.Draw(image_copy, 'RGBA')\n", 135 | " width, height = image_copy.size\n", 136 | " \n", 137 | " scaler = np.array([height, width, height, width])\n", 138 | " boxes = scaler * outputs['boxes']\n", 139 | "\n", 140 | " for i, box in enumerate(boxes):\n", 141 | " \n", 142 | " ymin, xmin, ymax, xmax = box\n", 143 | " draw.rectangle([(xmin, ymin), (xmax, ymax)], outline='red')\n", 144 | " \n", 145 | " keypoints = outputs['keypoint_positions'][i]\n", 146 | " keypoints = keypoints[:, [1, 0]].copy()\n", 147 | " keypoints *= np.array([xmax - xmin, ymax - ymin])\n", 148 | " keypoints += np.array([xmin, ymin])\n", 149 | " visibility = np.ones([17, 1], dtype=np.float32)\n", 150 | " keypoints = np.concatenate([keypoints, visibility], axis=1)\n", 151 | "\n", 152 | " for (p, q) in EDGES:\n", 153 | "\n", 154 | " x1, y1, v1 = keypoints[p]\n", 155 | " x2, y2, v2 = keypoints[q]\n", 156 | "\n", 157 | " both_visible = v1 > 0 and v2 > 0\n", 158 | " if both_visible:\n", 159 | " draw.line([(x1, y1), (x2, y2)])\n", 160 | "\n", 161 | " for j in range(17):\n", 162 | " x, y, v = keypoints[j]\n", 163 | " if v > 0:\n", 164 | " s = 2\n", 165 | " draw.ellipse([\n", 166 | " (x - s, y - s),\n", 167 | " (x + s, y + s)\n", 168 | " ], fill='red')\n", 169 | "\n", 170 | " return image_copy" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "draw_everything(image, outputs)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "segmentation_masks = outputs['segmentation_masks'].copy()\n", 189 | "m = segmentation_masks.min()\n", 190 | "M = segmentation_masks.max()\n", 191 | "segmentation_masks = 255.0 * (segmentation_masks - m)/(M - m)\n", 192 | "segmentation_masks = Image.fromarray(segmentation_masks.astype('uint8'))\n", 193 | "segmentation_masks" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "# Measure speed" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "scrolled": true 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "times = []\n", 212 | "for _ in range(110):\n", 213 | " start = time.perf_counter()\n", 214 | " result = detector(image_array, score_threshold=0.25)\n", 215 | " times.append(time.perf_counter() - start)\n", 216 | " \n", 217 | "times = np.array(times)\n", 218 | "times = times[10:]\n", 219 | "print(times.mean(), times.std())" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "# Show heatmaps " 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "import matplotlib.pyplot as plt\n", 236 | "from matplotlib.colors import ListedColormap\n", 237 | "\n", 238 | "cmap = plt.cm.get_cmap('autumn')\n", 239 | "new_cmap = cmap(np.arange(cmap.N))\n", 240 | "new_cmap[:, -1] = np.sqrt(np.linspace(0, 1, cmap.N)) # set alpha\n", 241 | "cmap = ListedColormap(new_cmap) # create new colormap" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "ORDER = {\n", 251 | " 0: 'nose',\n", 252 | " 1: 'left eye', 2: 
'right eye',\n", 253 | " 3: 'left ear', 4: 'right ear',\n", 254 | " 5: 'left shoulder', 6: 'right shoulder',\n", 255 | " 7: 'left elbow', 8: 'right elbow',\n", 256 | " 9: 'left wrist', 10: 'right wrist',\n", 257 | " 11: 'left hip', 12: 'right hip',\n", 258 | " 13: 'left knee', 14: 'right knee',\n", 259 | " 15: 'left ankle', 16: 'right ankle'\n", 260 | "}\n", 261 | "\n", 262 | "\n", 263 | "def plot_maps(image, heatmaps, segmentation_mask):\n", 264 | " \"\"\"\n", 265 | " Arguments:\n", 266 | " image: a float numpy array with shape [h, w, 3].\n", 267 | " heatmaps: a float numpy array with shape [h / 4, w / 4, 17].\n", 268 | " segmentation_mask: a float numpy array with shape [h / 4, w / 4].\n", 269 | " loss_mask: a float numpy array with shape [h / 4, w / 4].\n", 270 | " \"\"\"\n", 271 | "\n", 272 | " h, w, _ = image.shape\n", 273 | " h, w = (h // 2), (w // 2)\n", 274 | " background = Image.new('RGBA', (w, h * 18), (255, 255, 255, 255))\n", 275 | " draw = ImageDraw.Draw(background, 'RGBA')\n", 276 | " \n", 277 | " image = Image.fromarray(image)\n", 278 | " image = image.resize((w, h), Image.LANCZOS)\n", 279 | " image.putalpha(255)\n", 280 | "\n", 281 | " heatmaps = (255 * cmap(heatmaps)).astype('uint8')\n", 282 | " # it has shape [h, w, 17, 4]\n", 283 | " \n", 284 | " heats = []\n", 285 | " for j, name in ORDER.items():\n", 286 | "\n", 287 | " heat = Image.fromarray(heatmaps[:, :, j])\n", 288 | " heat = heat.resize((w, h), Image.LANCZOS)\n", 289 | " heat = Image.alpha_composite(image, heat)\n", 290 | " background.paste(heat, (0, j * h))\n", 291 | " draw.text((0, j * h), name, fill='red')\n", 292 | " \n", 293 | " def draw_mask(mask):\n", 294 | " mask = np.clip(mask, 0.0, 1.0)\n", 295 | " mask = (255 * mask).astype('uint8')\n", 296 | " mask = Image.fromarray(mask)\n", 297 | " mask = mask.resize((w, h), Image.LANCZOS).convert('RGB')\n", 298 | " mask.putalpha(mask.convert('L'))\n", 299 | " mask = Image.alpha_composite(image, mask)\n", 300 | " return mask\n", 301 | " \n", 302 | " mask = draw_mask(segmentation_mask)\n", 303 | " background.paste(mask, (0, 17 * h))\n", 304 | " draw.text((0, 17 * h), 'segmentation mask', fill='red')\n", 305 | "\n", 306 | " return background" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "h = outputs['keypoint_heatmaps']\n", 316 | "m = h.min(0).min(0)\n", 317 | "M = h.max(0).max(0)\n", 318 | "h = (h - m)/(M - m)\n", 319 | "\n", 320 | "plot_maps(image_array, h, outputs['segmentation_masks'])" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "kernelspec": { 326 | "display_name": "Python 3", 327 | "language": "python", 328 | "name": "python3" 329 | }, 330 | "language_info": { 331 | "codemirror_mode": { 332 | "name": "ipython", 333 | "version": 3 334 | }, 335 | "file_extension": ".py", 336 | "mimetype": "text/x-python", 337 | "name": "python", 338 | "nbconvert_exporter": "python", 339 | "pygments_lexer": "ipython3", 340 | "version": "3.6.7" 341 | } 342 | }, 343 | "nbformat": 4, 344 | "nbformat_minor": 1 345 | } 346 | -------------------------------------------------------------------------------- /inference/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import ImageDraw 3 | 4 | 5 | """ 6 | The keypoint order: 7 | 0: 'nose', 8 | 1: 'left eye', 2: 'right eye', 9 | 3: 'left ear', 4: 'right ear', 10 | 5: 'left shoulder', 6: 'right shoulder', 11 | 7: 'left elbow', 8: 'right elbow', 12 | 9: 'left 
wrist', 10: 'right wrist', 13 | 11: 'left hip', 12: 'right hip', 14 | 13: 'left knee', 14: 'right knee', 15 | 15: 'left ankle', 16: 'right ankle' 16 | """ 17 | 18 | 19 | EDGES = [ 20 | (0, 1), (0, 2), 21 | (1, 3), (2, 4), 22 | (5, 7), (7, 9), (6, 8), (8, 10), 23 | (11, 13), (13, 15), (12, 14), (14, 16), 24 | (3, 5), (4, 6), 25 | (5, 11), (6, 12) 26 | ] 27 | 28 | 29 | def get_keypoints(heatmaps, box, threshold): 30 | """ 31 | Arguments: 32 | heatmaps: a numpy float array with shape [h, w, 17]. 33 | box: a numpy array with shape [4]. 34 | threshold: a float number. 35 | Returns: 36 | a numpy int array with shape [17, 3]. 37 | """ 38 | keypoints = np.zeros([17, 3], dtype='int32') 39 | 40 | ymin, xmin, ymax, xmax = box 41 | height, width = ymax - ymin, xmax - xmin 42 | h, w, _ = heatmaps.shape 43 | 44 | for j in range(17): 45 | mask = heatmaps[:, :, j] 46 | if mask.max() > threshold: 47 | y, x = np.unravel_index(mask.argmax(), mask.shape) 48 | y = np.clip(int(y * height/h), 0, height) 49 | x = np.clip(int(x * width/w), 0, width) 50 | keypoints[j] = np.array([x, y, 1]) 51 | 52 | return keypoints 53 | 54 | 55 | def draw_pose(draw, keypoints, box): 56 | """ 57 | Arguments: 58 | draw: an instance of ImageDraw.Draw. 59 | keypoints: a numpy int array with shape [17, 3]. 60 | box: a numpy int array with shape [4]. 61 | """ 62 | ymin, xmin, ymax, xmax = box 63 | keypoints += np.array([xmin, ymin, 0]) 64 | 65 | for (p, q) in EDGES: 66 | 67 | x1, y1, v1 = keypoints[p] 68 | x2, y2, v2 = keypoints[q] 69 | 70 | both_visible = v1 > 0 and v2 > 0 71 | if both_visible: 72 | draw.line([(x1, y1), (x2, y2)]) 73 | 74 | for j in range(17): 75 | x, y, v = keypoints[j] 76 | if v > 0: 77 | s = 8 78 | draw.ellipse([ 79 | (x - s, y - s), 80 | (x + s, y + s) 81 | ], fill='red') 82 | 83 | 84 | def draw_everything(image, outputs): 85 | 86 | image_copy = image.copy() 87 | image_copy.putalpha(255) 88 | draw = ImageDraw.Draw(image_copy, 'RGBA') 89 | width, height = image_copy.size 90 | scaler = np.array([height, width, height, width]) 91 | 92 | n = outputs['num_boxes'] 93 | boxes = scaler * outputs['boxes'] 94 | for box in boxes: 95 | ymin, xmin, ymax, xmax = box 96 | draw.rectangle([(xmin, ymin), (xmax, ymax)], outline='red') 97 | 98 | mask = outputs['keypoint_heatmaps'][:, :, 0]#outputs['segmentation_masks'] 99 | m, M = mask.min(), mask.max() 100 | mask = (mask - m)/(M - m) 101 | mask = np.expand_dims(mask, 2) 102 | color = np.array([255, 255, 255]) 103 | mask = Image.fromarray((mask*color).astype('uint8')) 104 | mask.putalpha(mask.convert('L')) 105 | mask = mask.resize((width, height)) 106 | image_copy.alpha_composite(mask) 107 | return image_copy 108 | -------------------------------------------------------------------------------- /keypoints_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector import KeypointSubnet 3 | from detector.backbones import mobilenet_v1 4 | 5 | 6 | def model_fn(features, labels, mode, params): 7 | 8 | assert mode != tf.estimator.ModeKeys.PREDICT 9 | is_training = mode == tf.estimator.ModeKeys.TRAIN 10 | 11 | images = features['images'] 12 | # it has shape [b, height, width, 3] 13 | 14 | backbone_features = mobilenet_v1( 15 | images, is_training, 16 | params['depth_multiplier'] 17 | ) 18 | subnet = KeypointSubnet( 19 | backbone_features, 20 | is_training, params 21 | ) 22 | 23 | # add l2 regularization 24 | if params['weight_decay'] > 0.0: 25 | add_weight_decay(params['weight_decay']) 26 | 
regularization_loss = tf.losses.get_regularization_loss() 27 | tf.summary.scalar('regularization_loss', regularization_loss) 28 | 29 | losses = {} 30 | 31 | heatmaps = labels['heatmaps'] 32 | # it has shape [b, h, w, 17], 33 | # where (h, w) = (height / 4, width / 4), 34 | # and `b` is batch size 35 | 36 | batch_size = tf.shape(heatmaps)[0] 37 | normalizer = tf.to_float(batch_size) 38 | 39 | segmentation_masks = tf.expand_dims(labels['segmentation_masks'], 3) 40 | loss_masks = tf.expand_dims(labels['loss_masks'], 3) 41 | # they have shape [b, h, w, 1] 42 | 43 | predicted_heatmaps = subnet.heatmaps[:, :, :, :17] 44 | predicted_segmentation_masks = tf.expand_dims(subnet.heatmaps[:, :, :, 17], 3) 45 | 46 | focal_loss_value = focal_loss( 47 | heatmaps, labels['num_boxes'], 48 | predicted_heatmaps, alpha=2.0, beta=4.0 49 | ) # shape [b, h, w] 50 | focal_loss_value = tf.squeeze(loss_masks, 3) * focal_loss_value 51 | focal_loss_value = tf.reduce_sum(focal_loss_value, axis=[0, 1, 2]) 52 | losses['focal_loss'] = focal_loss_value/normalizer 53 | 54 | regression_loss = tf.nn.l2_loss(loss_masks * (predicted_segmentation_masks - segmentation_masks)) 55 | losses['regression_loss'] = 1e-3 * regression_loss/normalizer 56 | 57 | # additional supervision 58 | # with person segmentation 59 | for level in range(2, 6): 60 | 61 | x = subnet.enriched_features[f'p{level}'] 62 | x = tf.expand_dims(x[:, :, :, 0], 3) 63 | # it has shape [b, height / stride, width / stride, 1], 64 | # where stride is equal to level ** 2 65 | 66 | x = tf.nn.l2_loss(loss_masks * (x - segmentation_masks)) 67 | losses[f'segmentation_loss_at_level_{level}'] = 1e-5 * x/normalizer 68 | 69 | shape = tf.shape(segmentation_masks) 70 | height, width = shape[1], shape[2] 71 | new_size = [height // 2, width // 2] 72 | 73 | segmentation_masks = tf.image.resize_bilinear(segmentation_masks, new_size) 74 | loss_masks = tf.image.resize_bilinear(loss_masks, new_size) 75 | 76 | for n, v in losses.items(): 77 | tf.losses.add_loss(v) 78 | tf.summary.scalar(n, v) 79 | total_loss = tf.losses.get_total_loss(add_regularization_losses=True) 80 | 81 | with tf.name_scope('eval_metrics'): 82 | 83 | shape = tf.shape(heatmaps) 84 | height, width = shape[1], shape[2] 85 | area = tf.to_float(height * width) 86 | 87 | loss_masks = tf.expand_dims(labels['loss_masks'], 3) 88 | predicted_heatmaps = tf.sigmoid(predicted_heatmaps) 89 | per_pixel_reg_loss = tf.nn.l2_loss(loss_masks * (predicted_heatmaps - heatmaps))/(normalizer * area) 90 | tf.summary.scalar('per_pixel_reg_loss', per_pixel_reg_loss) 91 | 92 | if mode == tf.estimator.ModeKeys.EVAL: 93 | 94 | eval_metric_ops = { 95 | 'eval_regression_loss': tf.metrics.mean(losses['regression_loss']), 96 | 'eval_focal_loss': tf.metrics.mean(losses['focal_loss']), 97 | 'eval_per_pixel_reg_loss': tf.metrics.mean(per_pixel_reg_loss), 98 | 'eval_segmentation_loss_at_level_2': tf.metrics.mean(losses['segmentation_loss_at_level_2']), 99 | 'eval_segmentation_loss_at_level_5': tf.metrics.mean(losses['segmentation_loss_at_level_5']) 100 | } 101 | 102 | return tf.estimator.EstimatorSpec( 103 | mode, loss=total_loss, 104 | eval_metric_ops=eval_metric_ops 105 | ) 106 | 107 | with tf.variable_scope('learning_rate'): 108 | global_step = tf.train.get_global_step() 109 | learning_rate = tf.train.cosine_decay( 110 | params['initial_learning_rate'], global_step, 111 | decay_steps=params['num_steps'], alpha=1e-4 112 | ) 113 | tf.summary.scalar('learning_rate', learning_rate) 114 | 115 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 116 
| with tf.control_dependencies(update_ops): 117 | optimizer = tf.train.AdamOptimizer(learning_rate) 118 | grads_and_vars = optimizer.compute_gradients(total_loss) 119 | grads_and_vars = [(tf.clip_by_value(g, -200, 200), v) for g, v in grads_and_vars] 120 | train_op = optimizer.apply_gradients(grads_and_vars, global_step) 121 | 122 | for g, v in grads_and_vars: 123 | tf.summary.histogram(v.name[:-2] + '_hist', v) 124 | tf.summary.histogram(v.name[:-2] + '_grad_hist', g) 125 | 126 | return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op) 127 | 128 | 129 | def add_weight_decay(weight_decay): 130 | 131 | trainable_vars = tf.trainable_variables() 132 | kernels = [ 133 | v for v in trainable_vars 134 | if ('weights' in v.name or 'kernel' in v.name) and 'depthwise_weights' not in v.name 135 | ] 136 | for k in kernels: 137 | x = tf.multiply(weight_decay, tf.nn.l2_loss(k)) 138 | tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, x) 139 | 140 | 141 | def focal_loss(heatmaps, num_boxes, predictions, alpha=2.0, beta=4.0): 142 | """ 143 | This is loss function from here: 144 | "CornerNet: Detecting Objects as Paired Keypoints" 145 | (https://arxiv.org/abs/1808.01244) 146 | 147 | Arguments: 148 | heatmaps: a float tensor with shape [b, h, w, c]. 149 | num_boxes: a long tensor with shape [b]. 150 | predictions: a float tensor with shape [b, h, w, c], 151 | it represents logits. 152 | alpha, beta: float numbers. 153 | Returns: 154 | a float tensor with shape [b, h, w]. 155 | """ 156 | 157 | # notation like in the paper 158 | y = heatmaps 159 | y_hat = predictions 160 | 161 | is_extreme_point = tf.equal(y, 1.0) 162 | # binary tensor with shape [b, h, w, c] 163 | 164 | losses = tf.nn.sigmoid_cross_entropy_with_logits( 165 | logits=y_hat, labels=tf.to_float(is_extreme_point) 166 | ) # shape [b, h, w, c] 167 | 168 | # to the [0, 1] range 169 | y_hat = tf.sigmoid(y_hat) 170 | 171 | weights = tf.where( 172 | is_extreme_point, tf.pow(1.0 - y_hat, alpha), 173 | tf.pow(1.0 - y, beta) * tf.pow(y_hat, alpha) 174 | ) # shape [b, h, w, c] 175 | 176 | b = tf.shape(y)[0] # batch size 177 | normalizer = tf.to_float(tf.reshape(num_boxes, [b, 1, 1])) + 1.0 178 | return tf.reduce_sum(weights * losses, 3)/normalizer 179 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow.compat.v1 as tf 3 | 4 | 5 | """ 6 | For evaluation during the training I use average precision @ iou=0.5 7 | like in PASCAL VOC Challenge (detection task): 8 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/devkit_doc.pdf 9 | 10 | But after the training I test trained models 11 | using the official evaluation scripts. 12 | """ 13 | 14 | 15 | class Evaluator: 16 | """It creates ops like in tf.metrics API.""" 17 | 18 | def __init__(self): 19 | self.initialize() 20 | 21 | def evaluate(self, iou_threshold=0.5): 22 | self.metrics = evaluate_detector( 23 | self.groundtruth, 24 | self.detections, 25 | iou_threshold 26 | ) 27 | 28 | def get_metric_ops(self, groundtruth, predictions): 29 | """ 30 | Arguments: 31 | groundtruth: a dict with the following keys 32 | 'boxes': a float tensor with shape [1, N, 4]. 33 | predictions: a dict with the following keys 34 | 'boxes': a float tensor with shape [1, M, 4]. 35 | 'scores': a float tensor with shape [1, M]. 36 | 'num_boxes': a float tensor with shape [1]. 
37 | """ 38 | 39 | def update_op_func(gt_boxes, boxes, scores): 40 | image_name = '{}'.format(self.unique_image_id) 41 | self.unique_image_id += 1 42 | self.add_groundtruth(image_name, gt_boxes) 43 | self.add_detections(image_name, boxes, scores) 44 | 45 | num_boxes = predictions['num_boxes'][0] 46 | tensors = [ 47 | groundtruth['boxes'][0], 48 | predictions['boxes'][0][:num_boxes], 49 | predictions['scores'][0][:num_boxes] 50 | ] 51 | update_op = tf.py_func(update_op_func, tensors, []) 52 | 53 | def evaluate_func(): 54 | self.evaluate() 55 | self.initialize() 56 | evaluate_op = tf.py_func(evaluate_func, [], []) 57 | 58 | def get_value_func(measure): 59 | def value_func(): 60 | return np.float32(self.metrics[measure]) 61 | return value_func 62 | 63 | with tf.control_dependencies([evaluate_op]): 64 | 65 | metric_names = [ 66 | 'AP', 'precision', 'recall', 'mean_iou_for_TP', 67 | 'best_threshold', 'total_FP', 'total_FN' 68 | ] 69 | 70 | eval_metric_ops = {} 71 | for measure in metric_names: 72 | name = 'metrics/' + measure 73 | value_op = tf.py_func(get_value_func(measure), [], tf.float32) 74 | eval_metric_ops[name] = (value_op, update_op) 75 | 76 | return eval_metric_ops 77 | 78 | def initialize(self): 79 | self.detections = [] 80 | 81 | # groundtruth boxes are separated by image 82 | self.groundtruth = {} 83 | 84 | # i will use this counter as an unique image identifier 85 | self.unique_image_id = 0 86 | 87 | def add_detections(self, image_name, boxes, scores): 88 | """ 89 | Arguments: 90 | image_name: a numpy string array with shape []. 91 | boxes: a numpy float array with shape [M, 4]. 92 | scores: a numpy float array with shape [M]. 93 | """ 94 | for box, score in zip(boxes, scores): 95 | self.detections.append(get_box(box, image_name, score)) 96 | 97 | def add_groundtruth(self, image_name, boxes): 98 | for box in boxes: 99 | g = self.groundtruth 100 | if image_name in g: 101 | g[image_name] += [get_box(box)] 102 | else: 103 | g[image_name] = [get_box(box)] 104 | 105 | 106 | def get_box(box, image_name=None, score=None): 107 | ymin, xmin, ymax, xmax = box 108 | dictionary = { 109 | 'ymin': ymin, 'xmin': xmin, 110 | 'ymax': ymax, 'xmax': xmax, 111 | } 112 | 113 | # groundtruth and predicted boxes 114 | # have different format 115 | is_prediction = (score is not None)\ 116 | and (image_name is not None) 117 | is_groundtruth = not is_prediction 118 | 119 | if is_prediction: 120 | dictionary['image_name'] = image_name 121 | dictionary['confidence'] = score 122 | elif is_groundtruth: 123 | dictionary['is_matched'] = False 124 | 125 | return dictionary 126 | 127 | 128 | def evaluate_detector(groundtruth, detections, iou_threshold=0.5): 129 | """ 130 | Arguments: 131 | groundtruth: a dict of lists with boxes, 132 | image -> list of groundtruth boxes on the image. 133 | detections: a list of boxes. 134 | iou_threshold: a float number. 135 | Returns: 136 | a dict with seven values. 
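# A minimal worked example of the AP computation used below: with two
# groundtruth boxes and three detections (TP, FP, TP in confidence order),
# precision and recall evolve as follows, and AP is the sum of precision
# times the change in recall.
precision = [1.0, 0.5, 2.0/3.0]
recall = [0.5, 0.5, 1.0]

ap, previous_recall = 0.0, 0.0
for p, r in zip(precision, recall):
    ap += p * (r - previous_recall)
    previous_recall = r
print(ap)  # 1.0 * 0.5 + 0.5 * 0.0 + (2/3) * 0.5 = 0.8333...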
137 | """ 138 | 139 | # each ground truth box is either TP or FN 140 | num_groundtruth_boxes = 0 141 | 142 | for boxes in groundtruth.values(): 143 | num_groundtruth_boxes += len(boxes) 144 | num_groundtruth_boxes = max(num_groundtruth_boxes, 1) 145 | 146 | # sort by confidence in decreasing order 147 | detections.sort(key=lambda box: box['confidence'], reverse=True) 148 | 149 | num_correct_detections = 0 150 | num_detections = 0 151 | mean_iou = 0.0 152 | precision = [0.0]*len(detections) 153 | recall = [0.0]*len(detections) 154 | confidences = [box['confidence'] for box in detections] 155 | 156 | for k, detection in enumerate(detections): 157 | 158 | # each detection is either TP or FP 159 | num_detections += 1 160 | 161 | groundtruth_boxes = groundtruth.get(detection['image_name'], []) 162 | best_groundtruth_i, max_iou = match(detection, groundtruth_boxes) 163 | 164 | if best_groundtruth_i >= 0 and max_iou >= iou_threshold: 165 | box = groundtruth_boxes[best_groundtruth_i] 166 | if not box['is_matched']: 167 | box['is_matched'] = True 168 | num_correct_detections += 1 # increase number of TP 169 | mean_iou += max_iou 170 | 171 | precision[k] = num_correct_detections/num_detections # TP/(TP + FP) 172 | recall[k] = num_correct_detections/num_groundtruth_boxes # TP/(TP + FN) 173 | 174 | ap = compute_ap(precision, recall) 175 | best_threshold, best_precision, best_recall = compute_best_threshold( 176 | precision, recall, confidences 177 | ) 178 | mean_iou /= max(num_correct_detections, 1) 179 | 180 | return { 181 | 'AP': ap, 'precision': best_precision, 182 | 'recall': best_recall, 'best_threshold': best_threshold, 183 | 'mean_iou_for_TP': mean_iou, 'total_FP': num_detections - num_correct_detections, 184 | 'total_FN': num_groundtruth_boxes - num_correct_detections 185 | } 186 | 187 | 188 | def compute_best_threshold(precision, recall, confidences): 189 | """ 190 | Arguments: 191 | precision, recall, confidences: lists of floats of the same length. 192 | Returns: 193 | 1. a float number, best confidence threshold. 194 | 2. a float number, precision at the threshold. 195 | 3. a float number, recall at the threshold. 196 | """ 197 | if len(confidences) == 0: 198 | return 0.0, 0.0, 0.0 199 | 200 | precision = np.array(precision) 201 | recall = np.array(recall) 202 | confidences = np.array(confidences) 203 | 204 | diff = np.abs(precision - recall) 205 | prod = precision*recall 206 | best_i = np.argmax(prod*(1.0 - diff)) 207 | best_threshold = confidences[best_i] 208 | 209 | return best_threshold, precision[best_i], recall[best_i] 210 | 211 | 212 | def compute_iou(box1, box2): 213 | w = min(box1['xmax'], box2['xmax']) - max(box1['xmin'], box2['xmin']) 214 | if w > 0: 215 | h = min(box1['ymax'], box2['ymax']) - max(box1['ymin'], box2['ymin']) 216 | if h > 0: 217 | intersection = w*h 218 | w1 = box1['xmax'] - box1['xmin'] 219 | h1 = box1['ymax'] - box1['ymin'] 220 | w2 = box2['xmax'] - box2['xmin'] 221 | h2 = box2['ymax'] - box2['ymin'] 222 | union = (w1*h1 + w2*h2) - intersection 223 | return float(intersection)/float(union) 224 | return 0.0 225 | 226 | 227 | def match(detection, groundtruth_boxes): 228 | """ 229 | Arguments: 230 | detection: a box. 231 | groundtruth_boxes: a list of boxes. 232 | Returns: 233 | best_i: an integer, index of the best groundtruth box. 234 | max_iou: a float number. 
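If no groundtruth box has a positive IoU with the detection,
        best_i is -1 and max_iou is 0.0.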
235 | """ 236 | best_i = -1 237 | max_iou = 0.0 238 | for i, box in enumerate(groundtruth_boxes): 239 | iou = compute_iou(detection, box) 240 | if iou > max_iou: 241 | best_i = i 242 | max_iou = iou 243 | return best_i, max_iou 244 | 245 | 246 | def compute_ap(precision, recall): 247 | previous_recall_value = 0.0 248 | ap = 0.0 249 | # recall is in increasing order 250 | for p, r in zip(precision, recall): 251 | delta = r - previous_recall_value 252 | ap += p*delta 253 | previous_recall_value = r 254 | return ap 255 | -------------------------------------------------------------------------------- /person_detector_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector import RetinaNet 3 | from detector.backbones import mobilenet_v1 4 | from keypoints_model import add_weight_decay 5 | from metrics import Evaluator 6 | 7 | 8 | def model_fn(features, labels, mode, params): 9 | 10 | assert mode != tf.estimator.ModeKeys.PREDICT 11 | is_training = mode == tf.estimator.ModeKeys.TRAIN 12 | 13 | images = features['images'] 14 | backbone_features = mobilenet_v1( 15 | images, is_training=False, 16 | depth_multiplier=params['depth_multiplier'] 17 | ) 18 | retinanet = RetinaNet( 19 | backbone_features, 20 | tf.shape(images), 21 | is_training, params 22 | ) 23 | 24 | # add nms to the graph 25 | if not is_training: 26 | predictions = retinanet.get_predictions( 27 | score_threshold=params['score_threshold'], 28 | iou_threshold=params['iou_threshold'], 29 | max_detections=params['max_boxes'] 30 | ) 31 | 32 | # add l2 regularization 33 | add_weight_decay(params['weight_decay']) 34 | regularization_loss = tf.losses.get_regularization_loss() 35 | tf.summary.scalar('regularization_loss', regularization_loss) 36 | 37 | # create localization and classification losses 38 | losses = retinanet.loss(labels, params) 39 | tf.losses.add_loss(params['localization_loss_weight'] * losses['localization_loss']) 40 | tf.losses.add_loss(params['classification_loss_weight'] * losses['classification_loss']) 41 | tf.summary.scalar('localization_loss', losses['localization_loss']) 42 | tf.summary.scalar('classification_loss', losses['classification_loss']) 43 | total_loss = tf.losses.get_total_loss(add_regularization_losses=True) 44 | 45 | if mode == tf.estimator.ModeKeys.EVAL: 46 | 47 | shape = features['images'].shape 48 | batch_size = shape[0].value 49 | assert batch_size == 1 50 | 51 | evaluator = Evaluator() 52 | eval_metric_ops = evaluator.get_metric_ops(labels, predictions) 53 | 54 | return tf.estimator.EstimatorSpec( 55 | mode, loss=total_loss, 56 | eval_metric_ops=eval_metric_ops 57 | ) 58 | 59 | with tf.variable_scope('learning_rate'): 60 | global_step = tf.train.get_global_step() 61 | learning_rate = tf.train.cosine_decay( 62 | params['initial_learning_rate'], global_step, 63 | decay_steps=params['num_steps'], alpha=1e-4 64 | ) 65 | tf.summary.scalar('learning_rate', learning_rate) 66 | 67 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 68 | with tf.control_dependencies(update_ops): 69 | optimizer = tf.train.AdamOptimizer(learning_rate) 70 | 71 | # backbone network is frozen 72 | var_list = [v for v in tf.trainable_variables() if 'MobilenetV1' not in v.name] 73 | 74 | grads_and_vars = optimizer.compute_gradients(total_loss, var_list) 75 | train_op = optimizer.apply_gradients(grads_and_vars, global_step) 76 | 77 | for g, v in grads_and_vars: 78 | tf.summary.histogram(v.name[:-2] + '_hist', v) 79 | 
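# also record a histogram of the gradient for this variable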
tf.summary.histogram(v.name[:-2] + '_grad_hist', g) 80 | 81 | return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op) 82 | -------------------------------------------------------------------------------- /prn_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | from detector import prn 3 | 4 | 5 | def model_fn(features, labels, mode, params): 6 | 7 | assert mode != tf.estimator.ModeKeys.PREDICT 8 | is_training = mode == tf.estimator.ModeKeys.TRAIN 9 | 10 | heatmaps = features # shape [b, h, w, c] 11 | logits = prn(heatmaps, is_training) 12 | # it has shape [b, h, w, c] 13 | 14 | b = tf.shape(heatmaps)[0] 15 | _, h, w, c = heatmaps.shape.as_list() 16 | 17 | labels = tf.reshape(labels, [b, h * w, c]) 18 | logits = tf.reshape(logits, [b, h * w, c]) 19 | probabilities = tf.nn.softmax(logits, axis=1) 20 | 21 | losses = tf.losses.log_loss( 22 | labels, probabilities, 23 | loss_collection=None, 24 | reduction=tf.losses.Reduction.NONE 25 | ) 26 | # it has shape [b, h * w, c] 27 | 28 | loss = tf.reduce_mean(losses, axis=[0, 1, 2]) 29 | tf.losses.add_loss(loss) 30 | tf.summary.scalar('logloss', loss) 31 | total_loss = tf.losses.get_total_loss(add_regularization_losses=True) 32 | 33 | if mode == tf.estimator.ModeKeys.EVAL: 34 | eval_metric_ops = {'eval_loss': tf.metrics.mean(losses)} 35 | return tf.estimator.EstimatorSpec( 36 | mode, loss=total_loss, 37 | eval_metric_ops=eval_metric_ops 38 | ) 39 | 40 | assert mode == tf.estimator.ModeKeys.TRAIN 41 | with tf.variable_scope('learning_rate'): 42 | global_step = tf.train.get_global_step() 43 | learning_rate = tf.train.cosine_decay( 44 | params['initial_learning_rate'], global_step, 45 | decay_steps=params['num_steps'], alpha=1e-4 46 | ) 47 | tf.summary.scalar('learning_rate', learning_rate) 48 | 49 | optimizer = tf.train.AdamOptimizer(learning_rate) 50 | grads_and_vars = optimizer.compute_gradients(total_loss) 51 | train_op = optimizer.apply_gradients(grads_and_vars, global_step) 52 | 53 | for g, v in grads_and_vars: 54 | tf.summary.histogram(v.name[:-2] + '_hist', v) 55 | tf.summary.histogram(v.name[:-2] + '_grad_hist', g) 56 | 57 | return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op) 58 | -------------------------------------------------------------------------------- /train_keypoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow.compat.v1 as tf 3 | from keypoints_model import model_fn 4 | from detector.input_pipeline import KeypointPipeline 5 | 6 | 7 | PARAMS = { 8 | 'model_dir': 'models/run00/', 9 | 'train_dataset': '/home/dan/datasets/COCO/multiposenet/train/', 10 | 'val_dataset': '/home/dan/datasets/COCO/multiposenet/val/', 11 | 'pretrained_checkpoint': 'pretrained/mobilenet_v1_1.0_224.ckpt', 12 | 13 | 'backbone': 'mobilenet', 14 | 'depth_multiplier': 1.0, 15 | 'weight_decay': 0.0, 16 | 17 | 'num_steps': 200000, 18 | 'initial_learning_rate': 3e-4, 19 | 20 | 'min_dimension': 512, 21 | 'batch_size': 16, 22 | 'image_size': (512, 512) 23 | } 24 | 25 | 26 | def get_input_fn(is_training=True): 27 | 28 | dataset_path = PARAMS['train_dataset'] if is_training else PARAMS['val_dataset'] 29 | filenames = os.listdir(dataset_path) 30 | filenames = [n for n in filenames if n.endswith('.tfrecords')] 31 | filenames = [os.path.join(dataset_path, n) for n in sorted(filenames)] 32 | 33 | def input_fn(): 34 | pipeline = KeypointPipeline(filenames, is_training, PARAMS) 35 | 
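# pipeline.dataset is a tf.data.Dataset; the Estimator consumes it as the result of this input_fn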
return pipeline.dataset 36 | 37 | return input_fn 38 | 39 | 40 | tf.logging.set_verbosity('INFO') 41 | session_config = tf.ConfigProto(allow_soft_placement=True) 42 | session_config.gpu_options.visible_device_list = '0' 43 | 44 | 45 | run_config = tf.estimator.RunConfig() 46 | run_config = run_config.replace( 47 | model_dir=PARAMS['model_dir'], session_config=session_config, 48 | save_summary_steps=200, save_checkpoints_secs=7200, 49 | log_step_count_steps=1000 50 | ) 51 | 52 | 53 | train_input_fn = get_input_fn(is_training=True) 54 | val_input_fn = get_input_fn(is_training=False) 55 | warm_start = tf.estimator.WarmStartSettings(PARAMS['pretrained_checkpoint'], 'MobilenetV1/*') 56 | estimator = tf.estimator.Estimator(model_fn, params=PARAMS, config=run_config, warm_start_from=warm_start) 57 | 58 | 59 | train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=PARAMS['num_steps']) 60 | eval_spec = tf.estimator.EvalSpec(val_input_fn, steps=None, start_delay_secs=7200, throttle_secs=7200) 61 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 62 | -------------------------------------------------------------------------------- /train_person_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow.compat.v1 as tf 3 | from detector.input_pipeline import DetectorPipeline 4 | from person_detector_model import model_fn 5 | 6 | 7 | PARAMS = { 8 | 'model_dir': 'models/run01/', 9 | 'train_dataset': '/home/dan/datasets/COCO/multiposenet/train/', 10 | 'val_dataset': '/home/dan/datasets/COCO/multiposenet/val/', 11 | 'pretrained_checkpoint': 'models/run00/model.ckpt-200000', 12 | 13 | 'backbone': 'mobilenet', 14 | 'depth_multiplier': 1.0, 15 | 'weight_decay': 5e-5, 16 | 17 | 'score_threshold': 0.3, 'iou_threshold': 0.6, 'max_boxes': 25, 18 | 'localization_loss_weight': 1.0, 'classification_loss_weight': 2.0, 19 | 20 | 'gamma': 2.0, 21 | 'alpha': 0.25, 22 | 23 | 'num_steps': 150000, 24 | 'initial_learning_rate': 1e-3, 25 | 26 | 'min_dimension': 640, 27 | 'batch_size': 16, 28 | 'image_size': (640, 640) 29 | } 30 | 31 | 32 | def get_input_fn(is_training=True): 33 | 34 | dataset_path = PARAMS['train_dataset'] if is_training else PARAMS['val_dataset'] 35 | filenames = os.listdir(dataset_path) 36 | filenames = [n for n in filenames if n.endswith('.tfrecords')] 37 | filenames = [os.path.join(dataset_path, n) for n in sorted(filenames)] 38 | 39 | def input_fn(): 40 | pipeline = DetectorPipeline(filenames, is_training, PARAMS) 41 | return pipeline.dataset 42 | 43 | return input_fn 44 | 45 | 46 | tf.logging.set_verbosity('INFO') 47 | session_config = tf.ConfigProto(allow_soft_placement=True) 48 | session_config.gpu_options.visible_device_list = '0' 49 | 50 | 51 | run_config = tf.estimator.RunConfig() 52 | run_config = run_config.replace( 53 | model_dir=PARAMS['model_dir'], session_config=session_config, 54 | save_summary_steps=200, save_checkpoints_secs=1800, 55 | log_step_count_steps=1000 56 | ) 57 | 58 | 59 | train_input_fn = get_input_fn(is_training=True) 60 | val_input_fn = get_input_fn(is_training=False) 61 | warm_start = tf.estimator.WarmStartSettings(PARAMS['pretrained_checkpoint'], ['MobilenetV1/*']) 62 | estimator = tf.estimator.Estimator(model_fn, params=PARAMS, config=run_config, warm_start_from=warm_start) 63 | 64 | 65 | train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=PARAMS['num_steps']) 66 | eval_spec = tf.estimator.EvalSpec(val_input_fn, steps=None, start_delay_secs=7200, throttle_secs=7200) 67 
| tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 68 | -------------------------------------------------------------------------------- /train_prn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow.compat.v1 as tf 3 | from detector.input_pipeline import PoseResidualNetworkPipeline 4 | from prn_model import model_fn 5 | 6 | 7 | NUM_STEPS_PER_KEYPOINT = 10000 8 | NUM_STEPS = 200000 9 | 10 | 11 | PARAMS = { 12 | 'model_dir': 'models/run02/', 13 | 'train_dataset': '/home/dan/datasets/COCO/multiposenet/train/', 14 | 'val_dataset': '/home/dan/datasets/COCO/multiposenet/val/', 15 | 16 | 'num_steps': NUM_STEPS, 17 | 'initial_learning_rate': 1e-3, 18 | 19 | 'batch_size': 32 20 | } 21 | 22 | 23 | def get_input_fn(is_training=True, max_keypoints=None): 24 | 25 | dataset_path = PARAMS['train_dataset'] if is_training else PARAMS['val_dataset'] 26 | batch_size = PARAMS['batch_size'] 27 | filenames = os.listdir(dataset_path) 28 | filenames = [n for n in filenames if n.endswith('.tfrecords')] 29 | filenames = [os.path.join(dataset_path, n) for n in sorted(filenames)] 30 | 31 | def input_fn(): 32 | pipeline = PoseResidualNetworkPipeline(filenames, is_training, batch_size, max_keypoints) 33 | return pipeline.dataset 34 | 35 | return input_fn 36 | 37 | 38 | tf.logging.set_verbosity('INFO') 39 | session_config = tf.ConfigProto(allow_soft_placement=True) 40 | session_config.gpu_options.visible_device_list = '0' 41 | 42 | 43 | run_config = tf.estimator.RunConfig() 44 | run_config = run_config.replace( 45 | model_dir=PARAMS['model_dir'], session_config=session_config, 46 | save_summary_steps=200, save_checkpoints_secs=1800, 47 | log_step_count_steps=1000 48 | ) 49 | 50 | 51 | val_input_fn = get_input_fn(is_training=False) 52 | estimator = tf.estimator.Estimator(model_fn, params=PARAMS, config=run_config) 53 | 54 | 55 | for i in range(14): 56 | max_keypoints = i + 4 57 | train_input_fn = get_input_fn(is_training=True, max_keypoints=max_keypoints) 58 | estimator.train(train_input_fn, steps=NUM_STEPS_PER_KEYPOINT) 59 | estimator.evaluate(val_input_fn, steps=None) 60 | 61 | 62 | train_input_fn = get_input_fn(is_training=True) 63 | train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=PARAMS['num_steps']) 64 | eval_spec = tf.estimator.EvalSpec(val_input_fn, steps=None, start_delay_secs=3600, throttle_secs=3600) 65 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 66 | --------------------------------------------------------------------------------