├── README.md ├── checkpoints ├── coco │ └── readme.txt └── mpii │ └── readme.txt ├── config ├── config_hourglass_coco.py └── config_hourglass_mpii.py ├── core ├── dataset │ ├── __init__.py │ ├── data_augment.py │ └── data_generator.py ├── infer │ ├── __init__.py │ ├── freeze_graph.py │ ├── infer_utils.py │ └── visual_utils.py ├── loss │ ├── __init__.py │ └── loss.py ├── network │ ├── __init__.py │ ├── keypoints.py │ └── network_utils.py └── train │ ├── __init__.py │ └── trainer.py ├── data ├── dataset │ └── readme.txt └── name │ └── coco.name ├── demon.py ├── infer_hourglass.py ├── output ├── coco │ └── readme.txt └── mpii │ └── readme.txt ├── script ├── __init__.py ├── ckpt2ckpt.py ├── coco2txt.py ├── mpii2coco.py └── parse_ckpt.py ├── tensorRT ├── c++ │ ├── CMakeLists.txt │ ├── Keypoints_main.cpp │ ├── data │ │ └── images │ │ │ ├── 1_0_origin.jpg │ │ │ ├── 1_0_origin_render.jpg │ │ │ ├── 7_0_origin.jpg │ │ │ └── 7_0_origin_render.jpg │ └── source │ │ ├── ResizeNearestNeighbor.cpp │ │ ├── ResizeNearestNeighbor.cu │ │ ├── ResizeNearestNeighbor.h │ │ ├── keypoints_tensorrt.cpp │ │ ├── keypoints_tensorrt.h │ │ ├── my_plugin.cpp │ │ ├── my_plugin.h │ │ ├── utils.cpp │ │ └── utils.h └── python │ ├── __init__.py │ ├── pb2uff.py │ ├── readpb2graph.py │ └── tfpb2trtpb.py ├── train_hourglass_coco.py └── train_hourglass_mpii.py /README.md: -------------------------------------------------------------------------------- 1 | # Keypoint Detection In Tensorflow and TensorRT C++ 2 | ## 1.Modified hourglass (Hourglass-104) and ResNet-101
3 | 4 | ### Introduction 5 | 此项目为关键点检测训练以及推理加速代码。训练部分用python3 + tensorflow-1.14完成,推理部分用C++ + tensorRT-6完成。
6 | 训练数据集主要为COCO,模型为Hourglass。 7 | 8 | ### Quick Start 9 | * python3 train_hourglass_coco.py
10 | * python3 core/infer/freeze_graph.py --CUDA 0 -c checkpoints/coco/Hourglass_coco.ckpt -o Hourglass.pb
11 | * python3 demon.py
12 | 13 | ### Checkpoints 14 | https://drive.google.com/drive/folders/1pjOH1XUQOuMXlfGddQPvVEjXaXXPU7u1?usp=sharing
15 | 16 | ### Data Format 17 | 如果需要使用自己的数据集进行训练,首先需要将数据转换成如下的格式
18 | (filename1 bxmin,bymin,bxmax,bymax px,py px,py ...)
19 | If multiple points share the same label, separate them with `|`:
20 | (filename1 bxmin,bymin,bxmax,bymax px,py|px,py px,py ...)
21 | (filename2 bxmin,bymin,bxmax,bymax px,py|px,py px,py|px,py ...)
22 | ...
23 | 24 | 25 | ### Inference 26 | 在core/infer/infer_utils.py中的一些api可以用来构建一个简单的inference模型。通过Flask包装一下就可以实现简单的线上推理了。操作示例在infer_hourglass.py中,其中bbx需要通过其他模型获取。
27 | 28 | ### 注意事项 29 | TensorRT部分已经转移到新的仓库下
30 | [https://github.com/Syencil/tensorRT](https://github.com/Syencil/tensorRT) 31 | 32 | ## 2.TensorRT 33 | ## 介绍 34 | 此处项目采用CUDA 10 + tensorRT-6完成推理阶段,可实现模型推理加速,支持FP32,FP16 35 | ### 开始使用 36 | * 1.pb转uff 37 | * cd tensorRT/python 38 | * python3 pb2uff.py 39 | * 2.编译C++文件 40 | * cd tensorRT/c++ 41 | * cmake . 42 | * make 43 | 44 | 45 | ## 尚未完成的部分 46 | * ~~1.数据增强 主要是图像旋转增强这一块有问题,会尽快将包括其他的增强方式加入项目~~ 47 | * ~~2.TensorRT C++中对upsample plugin的实现,框架现已搭好,会尽快更新~~ 48 | * ~~3.通过Hourglass-101构建今年大火的Anchor-free检测器CenterNet:Object as point~~ 49 | * ~~4.tensorRT C++数据预处理和python有点不同,并不影响太多,懒得改了。~~ 50 | * ~~5.Int 8量化矫正,有空再更新~~ 51 | 52 | -------------------------------------------------------------------------------- /checkpoints/coco/readme.txt: -------------------------------------------------------------------------------- 1 | directory contains checkpoints -------------------------------------------------------------------------------- /checkpoints/mpii/readme.txt: -------------------------------------------------------------------------------- 1 | directory contains checkpoints -------------------------------------------------------------------------------- /config/config_hourglass_coco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2019 luozw, Inc. 
# Settings for training / evaluating the Hourglass keypoint model on COCO.
# NOTE(review): the consumers of these values (trainer, network) are not
# visible from this file; the comments below only restate what the values are.

# HARDWARE
CUDA_VISIBLE_DEVICES = "2"
CUDA_VISIBLE_DEVICES_INFER = "1"
MULTI_THREAD_NUM = 4
MULTI_GPU = [0]

# PATH
dataset_dir = "/data/dataset/coco"
train_image_dir = "/data/dataset/coco/images/train2017"
val_image_dir = "/data/dataset/coco/images/val2017"
train_list_path = "data/dataset/coco/coco_train.txt"
val_list_path = "data/dataset/coco/coco_val.txt"

log_dir = "output/coco"
ckpt_dir = "/data/checkpoints/coco"

# AUGMENT
# Keys match the keyword arguments of
# core.dataset.data_augment.image_augment_with_keypoint
# ("hor_flop" is spelled that way in the function signature as well).
augment = dict(
    color_jitter=0.5,
    crop=(0.5, 0.9),
    rotate=(0.5, 15),
    ver_flip=0,
    hor_flop=0,
)

# NETWORK
backbone = "hourglass"
loss_mode = "focal"  # focal, sigmoid, softmax, mse
image_size = (512, 512)
stride = 4
# Input size divided by the stride: (128, 128).
heatmap_size = tuple(side // stride for side in image_size)
num_block = 2
num_depth = 5
residual_dim = [256, 384, 384, 384, 512]

is_maxpool = False
is_nearest = True

# SAVER AND LOADER
max_keep = 30
pre_trained_ckpt = None
ckpt_name = "%s_coco.ckpt" % backbone

# TRAINING
batch_size = 16
learning_rate_init = 1e-3
learning_rate_warmup = 2.5e-4
exp_decay = 0.97

warmup_epoch_size = 0
epoch_size = 40
summary_per = 20
save_per = 2500

regularization_weight = 5e-4

# VAL
val_per = 2500
val_time = 20
val_rate = 0.1

# TEST
# Settings for training / evaluating the Hourglass keypoint model on MPII.
# NOTE(review): the consumers of these values (trainer, network) are not
# visible from this file; the comments below only restate what the values are.

# HARDWARE
CUDA_VISIBLE_DEVICES = "3"
CUDA_VISIBLE_DEVICES_INFER = "0"
MULTI_THREAD_NUM = 4
MULTI_GPU = [0]

# PATH
dataset_dir = "data/dataset/mpii"
train_image_dir = "data/dataset/mpii/images"
val_image_dir = "data/dataset/mpii/images"
train_list_path = "data/dataset/mpii/mpii_train.txt"
# NOTE(review): validation reads the *training* list -- confirm this is
# intentional and not a missing mpii_val.txt.
val_list_path = "data/dataset/mpii/mpii_train.txt"

log_dir = "output/mpii"
ckpt_dir = "checkpoints/mpii"

# AUGMENT
# Keys match the keyword arguments of
# core.dataset.data_augment.image_augment_with_keypoint
# ("hor_flop" is spelled that way in the function signature as well).
augment = dict(
    color_jitter=0.5,
    crop=(0.5, 0.9),
    rotate=(0.5, 30),
    ver_flip=0,
    hor_flop=0.5,
)

# NETWORK
backbone = "hourglass"
loss_mode = "focal"  # focal, sigmoid, softmax, mse
image_size = (512, 512)
stride = 4
# Input size divided by the stride: (128, 128).
heatmap_size = tuple(side // stride for side in image_size)
num_block = 2
num_depth = 5
residual_dim = [256, 384, 384, 384, 512]

is_maxpool = False
is_nearest = True

# SAVER AND LOADER
max_keep = 30
pre_trained_ckpt = None
# NOTE(review): the "_voc" suffix looks copied from another config; it is kept
# unchanged because renaming would orphan checkpoints already saved under it.
ckpt_name = "%s_voc.ckpt" % backbone

# TRAINING
batch_size = 8
learning_rate_init = 2.5e-4
learning_rate_warmup = 1e-4
momentum = 0.9

warmup_epoch_size = 1
epoch_size = 60
summary_per = 20
save_per = 5000

regularization_weight = 5e-4

# VAL
val_per = 200
val_time = 20
val_rate = 0.1

# TEST
def image_augment_with_keypoint(image, keypoints, color_jitter=0.5, crop=(
        0.5, 0.8), rotate=(0.5, 30), ver_flip=0, hor_flop=0.5):
    """
    Randomly augment an image and move its keypoints along with it.

    The pipeline is: vertical flip, horizontal flip, random crop,
    shift/scale/rotate, brightness/contrast jitter, hue/saturation jitter and
    a final resize back to the input height and width, so the returned image
    has the same size as the input.

    :param image: (np.array) H x W x C image
    :param keypoints: nested sequence indexed [class][point] -> (x, y).
        A point with a negative x or y is treated as "not annotated" and is
        passed through untouched. Assumes every class holds the same number
        of points (the data must convert to a rectangular array) -- TODO
        confirm against callers that use '|'-separated multi-point labels.
    :param color_jitter: probability of each of the two colour transforms
    :param crop: (probability, ratio) crop size is ratio * original size
    :param rotate: (probability, rotate_limit) for ShiftScaleRotate
    :param ver_flip: probability of a vertical flip
    :param hor_flop: probability of a horizontal flip
    :return: augmented image, (np.array, float) keypoints with the same
        layout as the input. Points that leave the frame during augmentation
        are set to (-1, -1).
    """
    image_h, image_w = image.shape[0:2]

    # Float copy: the caller's data is never mutated and the sub-pixel
    # coordinates returned by albumentations are not truncated.
    keypoints = np.array(keypoints, dtype=np.float64)
    if keypoints.size:
        # Clamp x against the width and y against the height separately
        # (a single shared bound is wrong for non-square images).
        keypoints[..., 0] = np.minimum(keypoints[..., 0], image_w - 1)
        keypoints[..., 1] = np.minimum(keypoints[..., 1], image_h - 1)

    # Flatten the annotated points and remember where each one came from.
    points_ = []
    idx_ = []
    for i, ps in enumerate(keypoints):
        for j, p in enumerate(ps):
            if p[0] >= 0 and p[1] >= 0:
                points_.append(tuple(p.tolist()))
                idx_.append((i, j))

    def get_aug(aug):
        # remove_invisible=False keeps one output keypoint per input keypoint.
        # With the default (True) albumentations silently drops points that
        # leave the frame, which shifts every later point onto the wrong
        # (class, point) slot when the results are written back below.
        return Compose(
            aug,
            keypoint_params=KeypointParams(
                format="xy",
                remove_invisible=False))

    aug = get_aug([VerticalFlip(p=ver_flip),
                   HorizontalFlip(p=hor_flop),
                   RandomCrop(
                       p=crop[0],
                       height=int(image_h * crop[1]),
                       width=int(image_w * crop[1])),
                   ShiftScaleRotate(p=rotate[0], rotate_limit=rotate[1]),
                   RandomBrightnessContrast(p=color_jitter),
                   HueSaturationValue(p=color_jitter),
                   Resize(p=1, height=image_h, width=image_w)
                   ])
    augmented = aug(image=image, keypoints=points_)

    for (i, j), moved in zip(idx_, augmented["keypoints"]):
        if 0 <= moved[0] < image_w and 0 <= moved[1] < image_h:
            keypoints[i, j] = list(moved)
        else:
            # The point was cropped / rotated out of the image: mark it as
            # missing instead of keeping a stale or out-of-range location.
            keypoints[i, j, :2] = -1

    return augmented["image"], keypoints
class Dataset():
    """
    Key-point detection dataset that yields (image batch, heatmap batch).

    Annotations are read once from a text file; images are loaded lazily,
    cropped to their bounding box, letter-boxed to ``image_size`` and paired
    with one Gaussian heatmap channel per key-point class.
    """

    def __init__(self, image_dir, gt_path, batch_size,
                 augment=None, image_size=(512, 512), heatmap_size=(128, 128)):
        """
        Wrapper for key-points detection dataset
        :param image_dir: (str) image dir
        :param gt_path: (str) data file eg. train.txt or val.txt, etc
        :param batch_size: (int) batch size
        :param augment: None disables augmentation. A dict is forwarded as
            keyword arguments to image_augment_with_keypoint; any other
            non-None value enables augmentation with that function's defaults.
        :param image_size: (int, int) height, width
        :param heatmap_size: (int, int) height, width. can be divided by image_size
        """
        # Only the annotations are kept in memory; images are read on demand
        # (optionally by several threads) because the data set is too large
        # to preload.
        self.gt_path = gt_path
        self.image_dir = image_dir
        self.image_size = image_size
        self.heatmap_size = heatmap_size
        self.batch_size = batch_size
        self.augment = augment

        self.data_set = self.creat_set_from_txt()

        self.num_data = len(self.data_set)
        # Number of key-point classes, taken from the first annotation.
        self.num_class = len(self.data_set[0][2])
        # Down-sampling factor image -> heatmap, derived from the heights.
        self.stride = self.image_size[0] // self.heatmap_size[0]
        # Height / width of the network input (not used inside this class).
        self.ratio = self.image_size[0] / self.image_size[1]

        # Start index of the previous in-order batch; begins one batch before
        # 0 so the first call to sample_batch_image_order() starts at 0.
        self._pre = -self.batch_size

    def creat_set_from_txt(self):
        """
        support multi point
        read image info and gt into memory

        Each non-blank line is
        ``name xmin,ymin,xmax,ymax px,py[|px,py...] px,py[|px,py...] ...``
        with one whitespace-separated group per class and ``|`` between
        several points of the same class. Values are rounded to ints.
        Blank lines are skipped.
        :return: [((str) image_name, [xmin, ymin, xmax, ymax], [[[px, py], ...], ...])]
        """
        image_set = []
        t0 = time.time()
        count = 0

        # The context manager closes the file even if a line fails to parse.
        with open(self.gt_path, 'r') as gt_file:
            for line in gt_file:
                fields = line.split()
                if not fields:
                    # Lines keep their trailing newline, so comparing against
                    # '' never matches; test the split result instead.
                    continue
                count += 1
                if count % 5000 == 0:
                    print("--parse %d " % count)
                bbx = [round(float(x)) for x in fields[1].split(',')]
                points = [[[round(float(x)) for x in point.split(',')]
                           for point in group.split('|')]
                          for group in fields[2:]]
                image_set.append((fields[0], bbx, points))
        print('-Set has been created in %.3fs' % (time.time() - t0))
        return image_set

    def sample_batch_image_random(self):
        """
        sample data (infinitely)
        :return: list of batch_size annotations drawn without replacement
        """
        return random.sample(self.data_set, self.batch_size)

    def sample_batch_image_order(self):
        """
        sample data in order (one shot)
        :return: list with the next batch of annotations; the last batch may
            be shorter than batch_size
        :raises StopIteration: when every annotation has been returned
        """
        self._pre += self.batch_size
        if self._pre >= self.num_data:
            raise StopIteration
        _last = min(self._pre + self.batch_size, self.num_data)
        return self.data_set[self._pre:_last]

    def make_guassian(self, height, width, sigma=3, center=None):
        """
        Build a height x width map with a Gaussian-shaped peak of value 1.
        :param height: (int) map height
        :param width: (int) map width
        :param sigma: width of the peak; the value halves at sigma / 2 from
            the centre
        :param center: (x, y) peak location, defaults to the map centre
        :return: (np.array) height x width float map
        """
        x = np.arange(0, width, 1, float)
        y = np.arange(0, height, 1, float)[:, np.newaxis]
        if center is None:
            x0 = width // 2
            y0 = height // 2
        else:
            x0 = center[0]
            y0 = center[1]
        return np.exp(-4. * np.log(2.) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2)

    def generate_hm(self, joints, heatmap_h_w):
        """
        Render one heatmap channel per class.
        :param joints: [class][point] -> (x, y) in network-input pixels;
            a point with x == -1 or y == -1 is skipped
        :param heatmap_h_w: (int, int) heatmap height, width
        :return: (np.array, float32) H x W x num_classes; each channel is the
            element-wise maximum over the Gaussians of that class's points
        """
        hm_h, hm_w = heatmap_h_w[0], heatmap_h_w[1]
        num_joints = len(joints)
        hm = np.zeros([hm_h, hm_w, num_joints], dtype=np.float32)
        # Peak width depends only on the heatmap size, so compute it once.
        sigma = int(np.sqrt(hm_h) * hm_w * 10 / 4096) + 2
        for i in range(num_joints):
            for joint in joints[i]:
                if joint[0] != -1 and joint[1] != -1:
                    gen_hm = self.make_guassian(
                        hm_h, hm_w, sigma=sigma,
                        center=[joint[0] // self.stride, joint[1] // self.stride])
                    hm[:, :, i] = np.maximum(hm[:, :, i], gen_hm)
        return hm

    def _crop_image_with_pad_and_resize(self, image, bbx, points, ratio=0.05):
        """
        Crop the bounding box (plus a margin), letter-box it to image_size
        and map the key-points into the new image.
        :param image: (np.array) H x W x 3 image
        :param bbx: [xmin, ymin, xmax, ymax] in image pixels
        :param points: [class][point] -> (x, y); (-1, ...) / (..., -1) points
            are copied through unchanged
        :param ratio: margin added on every side, as a fraction of the box
            width / height
        :return: (np.array, uint8) padded image, transformed copy of points
        """
        image_h, image_w = image.shape[0:2]
        crop_points = copy.deepcopy(points)

        w = bbx[2] - bbx[0] + 1
        h = bbx[3] - bbx[1] + 1
        # Grow the box by the margin, then clamp it to the image.
        x_min = max(int(bbx[0] - w * ratio), 0)
        y_min = max(int(bbx[1] - h * ratio), 0)
        x_max = min(int(bbx[2] + w * ratio), image_w - 1)
        y_max = min(int(bbx[3] + h * ratio), image_h - 1)
        crop_image = image[y_min:y_max + 1, x_min:x_max + 1, :]
        w = x_max - x_min + 1
        h = y_max - y_min + 1

        # Resize keeping the aspect ratio, then centre on a grey canvas.
        ih, iw = self.image_size
        scale = min(iw / w, ih / h)
        nw, nh = int(scale * w), int(scale * h)
        image_resized = cv2.resize(crop_image, (nw, nh))

        image_paded = np.full(shape=[ih, iw, 3], fill_value=128, dtype=np.uint8)
        dw, dh = (iw - nw) // 2, (ih - nh) // 2
        image_paded[dh:nh + dh, dw:nw + dw, :] = image_resized

        for i, group in enumerate(points):
            for j, point in enumerate(group):
                if point[0] != -1 and point[1] != -1:
                    crop_points[i][j][0] = (point[0] - x_min) * scale + dw
                    crop_points[i][j][1] = (point[1] - y_min) * scale + dh

        return image_paded, crop_points

    def _one_image_and_heatmap(self, image_set):
        """
        process only one image
        :param image_set: [image_name, bbx, [points]]
        :return: (narray) image_h_w x C, (narray) heatmap_h_w x C'
        :raises IOError: if the image file cannot be read
        """
        image_name, bbx, point = image_set
        image_path = os.path.join(self.image_dir, image_name)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with the
            # offending path instead of an opaque error further down.
            raise IOError("cannot read image: %s" % image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img, point = self._crop_image_with_pad_and_resize(img, bbx, point)
        if self.augment is not None:
            # Forward the configured augmentation parameters; previously the
            # configuration was ignored and the defaults were always used.
            aug_kwargs = self.augment if isinstance(self.augment, dict) else {}
            img, point = image_augment_with_keypoint(img, point, **aug_kwargs)
        hm = self.generate_hm(point, self.heatmap_size)
        return img, hm

    @staticmethod
    def _iter_batches(sample_fn, process_batch):
        """Yield stacked (imgs, hms) until sample_fn raises StopIteration."""
        while True:
            try:
                image_set = sample_fn()
            except StopIteration:
                # Letting StopIteration escape from a generator body is a
                # RuntimeError (PEP 479); end the generator explicitly.
                return
            pairs = process_batch(image_set)
            final_imgs = np.stack([pair[0] for pair in pairs], axis=0)
            final_hms = np.stack([pair[1] for pair in pairs], axis=0)
            yield final_imgs, final_hms

    def iterator(self, max_worker=None, is_oneshot=False):
        """
        Wrapper for batch_data processing
        transform data from txt to imgs and hms
        (Option) utilize multi thread acceleration
        generator images and heatmaps infinitely or make oneshot
        :param max_worker: (optional) (int) max worker for multi-thread.
            0 processes every sample in the calling thread; any other value
            (including None) is passed to ThreadPoolExecutor.
        :param is_oneshot: (optional) (bool) if False, generator will sample infinitely.
            If True, the data is walked once in file order and the generator
            then finishes.
        :return: iterator. imgs, hms = next(iterator)
        """
        if is_oneshot:
            sample_fn = self.sample_batch_image_order
        else:
            sample_fn = self.sample_batch_image_random

        if max_worker == 0:
            yield from self._iter_batches(
                sample_fn,
                lambda batch: [self._one_image_and_heatmap(s) for s in batch])
            return

        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_worker) as executor:
            # executor.map keeps the samples of a batch in their input order.
            yield from self._iter_batches(
                sample_fn,
                lambda batch: list(
                    executor.map(self._one_image_and_heatmap, batch)))
def parse_arg():
    """
    Parse the command line of the freeze script.

    :return: argparse.Namespace with CUDA (str or None), ckpt (str or None),
        output_graph (str or None) and is_training (bool)
    """
    parse = argparse.ArgumentParser()
    parse.add_argument(
        '-C',
        '--CUDA',
        dest='CUDA',
        default=None,
        help='CUDA_VISIBLE_DEVICE')
    parse.add_argument(
        '-c',
        '--ckpt',
        dest='ckpt',
        default=None,
        help='Freeze ckpt path')
    parse.add_argument(
        '-o',
        '--output',
        dest='output_graph',
        default=None,
        help='Output graph path')
    # This flag used to share dest='output_graph' with -o, so passing it
    # overwrote the output path and the flag itself was never readable.
    parse.add_argument(
        '-t',
        '--is_training',
        dest='is_training',
        action='store_true',
        default=False,
        help='Restore the graph from the .meta file instead of rebuilding it')
    return parse.parse_args()


def freeze_graph(input_checkpoint, output_graph, is_training=False):
    '''
    Restore a checkpoint and write it out as a frozen (constant-only) graph.

    :param input_checkpoint: checkpoint path prefix to restore
    :param output_graph: path of the .pb file to write
    :param is_training: if True the graph is imported from
        ``input_checkpoint + '.meta'``; otherwise a Keypoints network is
        rebuilt with is_training=False using config_hourglass_coco, a
        [None, 512, 512, 3] float32 input placeholder and 17 output classes
    :return: None
    '''
    if is_training:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices=True)
    else:
        from core.network.keypoints import Keypoints
        import config.config_hourglass_coco as config
        # Constructing the model adds its ops and variables to the default
        # graph; the returned object itself is not needed here.
        Keypoints(tf.placeholder(name="Placeholder/inputs_x", dtype=tf.float32, shape=[None, 512, 512, 3]),
                  17,
                  num_block=config.num_block,
                  num_depth=config.num_depth,
                  residual_dim=config.residual_dim,
                  is_training=False,
                  is_maxpool=config.is_maxpool,
                  is_nearest=config.is_nearest
                  )
        saver = tf.train.Saver(var_list=tf.global_variables())

    # Output node names must already exist in the restored graph.
    print('Freeze graph')
    output_node_names = ["Keypoints/keypoint_1/conv/Sigmoid"]
    print(output_node_names)

    with tf.Session() as sess:
        # Load the variable values from the checkpoint.
        saver.restore(sess, input_checkpoint)
        # Replace every variable by a constant holding its current value.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=sess.graph_def,
            output_node_names=output_node_names)

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        # Report how many nodes the frozen graph contains.
        print("%d ops in the final graph." %
              len(output_graph_def.node))


if __name__ == '__main__':
    args = parse_arg()
    if args.CUDA is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA
    freeze_graph(args.ckpt, args.output_graph, args.is_training)
All Rights Reserved 5 | 6 | Authors: luozhiwang(luozw1994@outlook.com) 7 | Date: 2019-09-21 8 | """ 9 | import os 10 | import cv2 11 | import time 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | 16 | def read_pb_infer(pb_path, input_node_name_and_val, output_node_name): 17 | """ 18 | [xmin, ymin, xmax, ymax, score, cid] 19 | :param pb_path: 20 | :param input_node_name_and_val: {(str) input_node_name: (any) input_node_val} 21 | :param output_node_name: [(str) output_node_name] 22 | :return: [output] B x Num_bbx x 6 23 | """ 24 | with tf.Graph().as_default(): 25 | output_graph_def = tf.GraphDef() 26 | with open(pb_path, 'rb') as f: 27 | output_graph_def.ParseFromString(f.read()) 28 | tf.import_graph_def(output_graph_def, name='') 29 | config = tf.ConfigProto(allow_soft_placement=True) # 是否自动选择GPU 30 | config.gpu_options.allow_growth = True 31 | with tf.Session(config=config) as sess: 32 | # sess.run(tf.global_variables_initializer()) 33 | # 定义输入的张量名称,对应网络结构的输入张量 34 | # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数 35 | feed_dict = {} 36 | for key in input_node_name_and_val: 37 | input_tensor = sess.graph.get_tensor_by_name(key) 38 | feed_dict[input_tensor] = input_node_name_and_val[key] 39 | 40 | # 定义输出的张量名称 41 | output_tensor = [] 42 | for name in output_node_name: 43 | output_tensor.append(sess.graph.get_tensor_by_name(name)) 44 | 45 | # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字 46 | start_time = time.time() 47 | output = sess.run(output_tensor, feed_dict=feed_dict) 48 | print('Infer time is %.4f' % (time.time() - start_time)) 49 | return output 50 | 51 | 52 | def read_pb(pb_path, input_name, output_name): 53 | """ 54 | Instantiation Session 55 | :param pb_path: (str) pb file path 56 | :param input_name: [(str)] input tensor names 57 | :param output_name: [(str)] output tensor names 58 | :return: (tf.Session) sess, (Tensor) input, (Tensor) output 59 | """ 60 | # return sess 61 | with tf.Graph().as_default(): 62 | 
output_graph_def = tf.GraphDef() 63 | with open(pb_path, 'rb') as f: 64 | output_graph_def.ParseFromString(f.read()) 65 | tf.import_graph_def(output_graph_def, name='') 66 | config = tf.ConfigProto(allow_soft_placement=True) # 是否自动选择GPU 67 | config.gpu_options.allow_growth = True 68 | sess = tf.Session(config=config) 69 | 70 | if not isinstance(input_name, list) or isinstance(input_name, tuple): 71 | input_name = [input_name] 72 | input_tensor = [] 73 | output_tensor = [] 74 | for i in range(len(input_name)): 75 | input_tensor.append(sess.graph.get_tensor_by_name(input_name[i])) 76 | for i in range(len(output_name)): 77 | output_tensor.append(sess.graph.get_tensor_by_name(output_name[i])) 78 | 79 | return sess, input_tensor, output_tensor 80 | 81 | 82 | def pb_infer(sess, output_tensor, input_tensor=None, input_val=None): 83 | """ 84 | get output 85 | :param sess: (tf.Session) sess 86 | :param output_tensor: [(Tensor)] 87 | :param input_tensor: [(Tensor)] 88 | :param input_val: [(np.array)] 89 | :return: 90 | """ 91 | feed_dict = {} 92 | if input_tensor is not None and input_val is not None: 93 | for i in range(len(input_tensor)): 94 | feed_dict[input_tensor[i]] = input_val[i] 95 | 96 | return sess.run(output_tensor, feed_dict) 97 | 98 | 99 | def image_process(image, bbx): 100 | """ 101 | image pre-process 102 | :param image: (str) image_path / (np.array) image in BGR 103 | :param bbx: [(int) xmin, (int) ymin, (int) xmax, (int) ymax] 104 | :return: input_image, bbx 105 | """ 106 | if type(image) == str: 107 | image = cv2.imread(image) 108 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 109 | crop_image, crop_bbx = crop_with_padding_and_resize(image, bbx) 110 | cv2.imwrite("/work/meter_recognition/render_img/crop.jpg", crop_image) 111 | image_norm = crop_image / 127 - 1 112 | return image_norm, crop_bbx 113 | 114 | 115 | def crop_with_padding_and_resize(image, bbx, shape=(512, 512), ratio=0.2): 116 | """ 117 | image pre-process 118 | :param image: image path or BGR 
image 119 | :param bbx: [xmin, ymin, xmax, ymax] 120 | :param shape: output image shape 121 | :param ratio: keep blank for edge 122 | :return: resized and padded image 123 | """ 124 | image_h, image_w = image.shape[0:2] 125 | crop_bbx = np.copy(bbx) 126 | 127 | w = bbx[2] - bbx[0] + 1 128 | h = bbx[3] - bbx[1] + 1 129 | # keep 0.2 blank for edge 130 | crop_bbx[0] = round(bbx[0] - w * ratio) 131 | crop_bbx[1] = round(bbx[1] - h * ratio) 132 | crop_bbx[2] = round(bbx[2] + w * ratio) 133 | crop_bbx[3] = round(bbx[3] + h * ratio) 134 | # clip value from 0 to len-1 135 | crop_bbx[0] = 0 if crop_bbx[0] < 0 else crop_bbx[0] 136 | crop_bbx[1] = 0 if crop_bbx[1] < 0 else crop_bbx[1] 137 | crop_bbx[2] = image_w - 1 if crop_bbx[2] > image_w - 1 else crop_bbx[2] 138 | crop_bbx[3] = image_h - 1 if crop_bbx[3] > image_h - 1 else crop_bbx[3] 139 | # crop the image 140 | crop_image = image[crop_bbx[1]: crop_bbx[3] + 1, crop_bbx[0]: crop_bbx[2] + 1, :] 141 | # update width and height 142 | w = crop_bbx[2] - crop_bbx[0] + 1 143 | h = crop_bbx[3] - crop_bbx[1] + 1 144 | # keep aspect ratio 145 | # padding 146 | if h < w: 147 | pad = int(w - h) 148 | pad_t = pad // 2 149 | pad_d = pad - pad_t 150 | pad_image = np.pad(crop_image, ((pad_t, pad_d), (0, 0), (0, 0)), constant_values=128) 151 | else: 152 | pad = int(h - w) 153 | pad_l = pad // 2 154 | pad_r = pad - pad_l 155 | pad_image = np.pad(crop_image, ((0, 0), (pad_l, pad_r), (0, 0)), constant_values=128) 156 | crop_image = cv2.resize(pad_image, shape) 157 | return crop_image, crop_bbx 158 | 159 | 160 | def rel2abs(bbx, points): 161 | """ 162 | transform points location into original location 163 | :param bbx: [xmin, ymin, xmax, ymax] cropped bbx 164 | :param points: [[x, y, score]] points location in heatmap 165 | :return: [[x, y, score]] points location in original image 166 | """ 167 | bbx = bbx.copy() 168 | h, w = bbx[3] - bbx[1], bbx[2] - bbx[0] 169 | max_len = max(h, w) 170 | pad_t = (max_len - h) // 2 171 | pad_d = (max_len - 
h) - (max_len - h) // 2 172 | pad_l = (max_len - w) // 2 173 | pad_r = (max_len - w) - (max_len - w) // 2 174 | bbx[0] -= pad_l 175 | bbx[1] -= pad_t 176 | bbx[2] += pad_r 177 | bbx[3] += pad_d 178 | for point in points: 179 | point[0] = bbx[0] + point[0] * max_len / 128 180 | point[1] = bbx[1] + point[1] * max_len / 128 181 | return points 182 | 183 | 184 | def draw_point(image, points): 185 | for point in points: 186 | if int(point[0]) != -1 and int(point[1]) != -1: 187 | image = cv2.circle( 188 | image, (int(point[0]), int(point[1])), 5, (255, 204, 0), 3) 189 | return image 190 | 191 | 192 | def pred_one_image(image, bbxes, sess, input_tensor, output_tensor): 193 | processed_images = [] 194 | processed_bbxes = [] 195 | for bbx in bbxes: 196 | input_image, croped_bbxes = image_process(image, bbx) 197 | processed_images.append(input_image) 198 | processed_bbxes.append(croped_bbxes) 199 | batch_image = np.stack(processed_images, axis=0) 200 | batch_hm = pb_infer(sess, output_tensor, input_tensor, [batch_image])[0] 201 | final_point = [] 202 | for i in range(len(batch_image)): 203 | hm = batch_hm[i] 204 | img = batch_image[i] 205 | point = get_results(hm, threshold=0.01)[0] 206 | 207 | point = rel2abs(processed_bbxes[i], point) 208 | final_point.append(point) 209 | return final_point 210 | 211 | 212 | def get_results(hms, threshold=0.6): 213 | if len(hms.shape) == 3: 214 | hms = np.expand_dims(hms, axis=0) 215 | num_class = hms.shape[-1] 216 | results = [] 217 | for b in range(len(hms)): 218 | joints = -1 * np.ones([num_class, 3], dtype=np.float32) 219 | hm = hms[b] 220 | for c in range(num_class): 221 | index = np.unravel_index( 222 | np.argmax(hm[:, :, c]), hm[:, :, c].shape) 223 | # tmp = list(index) 224 | tmp = [index[1], index[0]] 225 | score = hm[index[0], index[1], c] 226 | tmp.append(score) 227 | if score > threshold: 228 | joints[c] = np.array(tmp) 229 | results.append(joints.tolist()) 230 | return results 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 
# Drawing colours, passed unchanged to OpenCV.
bbx_color = (28, 255, 147)
pointer_color = (255, 204, 0)
txt_color = (0, 240, 78)

# Joint index pairs connected by draw_skeleton, keyed by dataset name.
# The COCO pairs are written 1-based and shifted to 0-based when used.
_MPII_LINKS = [(0, 1), (1, 2), (2, 6), (6, 3), (3, 4), (4, 5), (6, 8),
               (8, 13), (13, 14), (14, 15), (8, 12), (12, 11), (11, 10)]
_COCO_LINKS_1BASED = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13],
                      [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10],
                      [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5],
                      [4, 6], [5, 7]]
_SKELETON_LINKS = {
    'mpii': _MPII_LINKS,
    'coco': [(a - 1, b - 1) for a, b in _COCO_LINKS_1BASED],
}


def _is_missing(point):
    """Return True if both coordinates carry the -1 "not found" marker."""
    return point[0] == -1 and point[1] == -1


def get_results(hms, threshold=0.6):
    """
    Pick the strongest location of every heatmap channel.
    :param hms: (np.array) H x W x C or B x H x W x C heatmaps
    :param threshold: a channel whose maximum is not above this value is
        reported as [-1, -1, -1]
    :return: per batch item a list of C entries [x, y, score]
    """
    if len(hms.shape) == 3:
        hms = np.expand_dims(hms, axis=0)
    num_class = hms.shape[-1]
    results = []
    for hm in hms:
        joints = -1 * np.ones([num_class, 3], dtype=np.float32)
        for c in range(num_class):
            channel = hm[:, :, c]
            row, col = np.unravel_index(np.argmax(channel), channel.shape)
            score = channel[row, col]
            if score > threshold:
                joints[c] = np.array([col, row, score])
        results.append(joints.tolist())
    return results


def draw_bbx(image, bbx):
    """
    Draw a rectangle on the image.
    :param image: (np.array) image to draw on
    :param bbx: [xmin, ymin, xmax, ymax]
    :return: image with the rectangle drawn
    """
    image = cv2.rectangle(
        image, (bbx[0], bbx[1]), (bbx[2], bbx[3]), bbx_color, 3)
    return image


def draw_point(image, points):
    """
    Draw a circle at every point whose x and y are both not -1.
    :param image: (np.array) image to draw on
    :param points: [[x, y, ...]]
    :return: image with the circles drawn
    """
    for point in points:
        if point[0] != -1 and point[1] != -1:
            # get_results() yields float coordinates; OpenCV needs ints.
            image = cv2.circle(
                image, (int(point[0]), int(point[1])), 5, pointer_color, 3)
    return image


def draw_skeleton(image, points, dataset='mpii'):
    """
    Draw the joints and, for a known dataset, the limbs between them.
    :param image: (np.array) image to draw on
    :param points: [[x, y, ...]] indexed by joint id
    :param dataset: 'mpii' or 'coco' selects the limb table; for any other
        value only the joints are drawn
    :return: image with the drawing applied
    """
    for point in points:
        if point[0] != -1 and point[1] != -1:
            image = cv2.circle(
                image, (int(point[0]), int(point[1])), 5, pointer_color, 3)
    # Look the name up by equality: `dataset is 'mpii'` compared object
    # identity, which only worked when the string happened to be interned.
    for start, end in _SKELETON_LINKS.get(dataset, ()):
        p_start, p_end = points[start], points[end]
        if not _is_missing(p_start) and not _is_missing(p_end):
            image = cv2.line(
                image,
                (int(p_start[0]), int(p_start[1])),
                (int(p_end[0]), int(p_end[1])),
                bbx_color)
    # Always hand the image back; unknown dataset names used to fall through
    # and return None.
    return image


def visiual_image_with_hm(img, hm):
    """
    Overlay the summed heatmap channels on an image.
    :param img: (np.array) H x W x 3 image
    :param hm: (np.array) h x w x C heatmaps
    :return: img plus the channel sum of hm scaled by 255 and resized to the
        image size; the result is not clipped to [0, 255]
    """
    hm = np.sum(hm, axis=-1) * 255
    hm = np.expand_dims(hm, axis=-1)
    hm = np.tile(hm, (1, 1, 3))
    hm = cv2.resize(hm, (img.shape[1], img.shape[0]))
    img = img + hm
    return img
17 | B batch size 18 | H, W Tensor shape 19 | C num of classes 20 | CELoss 21 | :param features: (Tensor) without actived BxHxWxC 22 | :param heatmap: (Tensor) labels BxHxWxC 23 | :return: (List(Tensor)) 24 | """ 25 | if not isinstance(features, list): 26 | features = [features] 27 | losses = [] 28 | for i in range(len(features)): 29 | loss = - heatmap * tf.log(features[i]) 30 | losses.append(tf.reduce_mean(loss)) 31 | return losses 32 | 33 | 34 | def softmax_cross_entropy(features, heatmap): 35 | print('-Utilize Softmax-Cross-Entropy-Loss') 36 | if not isinstance(features, list): 37 | features = [features] 38 | losses = [] 39 | for i in range(len(features)): 40 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 41 | logits=features[i], 42 | labels=heatmap 43 | ) 44 | losses.append(tf.reduce_mean(loss)) 45 | return losses 46 | 47 | 48 | def focal_loss(features, heatmap, alpha=2, beta=4): 49 | """ 50 | Focal Loss in "CornerNet" 51 | Loss = -1/N * (1-p)**alpha*log(p) if y=1 or (1-y)**beta*p**alpha*log(1-p) 52 | add loss into Graph 53 | :param features: (List(Tensor)) [BxHxWxC] 54 | :param heatmap: (Tensor) BxHxWxC 55 | :param alpha: (int) 56 | :param beta: (int) 57 | :return: (List(Tensor)) 58 | """ 59 | eps = 1e-9 60 | print('-Utilize Focal-Loss') 61 | if type(features) is not list: 62 | features = [features] 63 | losses = [] 64 | for i in range(len(features)): 65 | # feature = tf.nn.sigmoid(features[i]) 66 | feature = tf.clip_by_value(features[i], eps, 1 - eps) 67 | zeros = tf.zeros_like(heatmap) 68 | ones = tf.ones_like(heatmap) 69 | 70 | # mask 71 | mask = tf.where(tf.equal(heatmap, 1.0), ones, zeros) 72 | inv_mask = tf.subtract(1.0, mask) 73 | 74 | # num_pos 75 | num_pos = tf.reduce_sum(mask) 76 | num_pos = tf.maximum(num_pos, 1) 77 | 78 | # pre 79 | pos = tf.multiply(feature, mask) 80 | neg = tf.multiply(1.0 - feature, inv_mask) 81 | pre = tf.log(tf.add(pos, neg) + eps) 82 | 83 | # weight alpha 84 | pos_weight_alpha = tf.multiply(1.0 - feature, mask) 85 | 
neg_weight_alpha = tf.multiply(feature, inv_mask) 86 | weight_alpha = tf.pow(tf.add(pos_weight_alpha, neg_weight_alpha), alpha) 87 | 88 | # weight beta 89 | pos_weight_beta = mask 90 | neg_weight_beta = tf.multiply(1.0 - heatmap, inv_mask) 91 | weight_beta = tf.pow(tf.add(pos_weight_beta, neg_weight_beta), beta) 92 | 93 | # cal loss 94 | loss = tf.reduce_sum(- weight_beta * weight_alpha * pre) / num_pos 95 | 96 | losses.append(loss) 97 | return losses 98 | 99 | 100 | def mean_square_loss(features, heatmap): 101 | print('-Utilize Mse-Loss') 102 | if not isinstance(features, list): 103 | features = [features] 104 | losses = [] 105 | for i in range(len(features)): 106 | loss = tf.losses.mean_squared_error( 107 | heatmap, features[i]) 108 | losses.append(tf.reduce_mean(loss)) 109 | return losses 110 | -------------------------------------------------------------------------------- /core/network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/core/network/__init__.py -------------------------------------------------------------------------------- /core/network/keypoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2019 luozw, Inc. 
# All Rights Reserved 5 | 6 | Authors: luozhiwang(luozw1994@outlook.com) 7 | Date: 2019-09-10 8 | """
import time
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v2
from core.network.network_utils import residual_block_v2, hourglass_block


class Keypoints():
    def __init__(self, inputs, num_class,
                 backbone="hourglass",
                 num_block=2,
                 num_depth=5,
                 residual_dim=(256, 256, 384, 384, 384, 512),
                 is_training=True,
                 is_maxpool=False,
                 is_nearest=True,
                 reuse=False
                 ):
        """
        Modified hourglass. See more in network_utils.py
        The graph is built immediately and stored in self.features.
        :param inputs: (Tensor) BxHxWxC images
        :param num_class: (int) num of classes
        :param backbone: (str) "hourglass" or "resnet_v2_101"; anything else raises ValueError
        :param num_block: (int) num of hourglass block
        :param num_depth: (int) num of down-sampling steps
        :param residual_dim: (list(int)) output dim for each residual block. Length should be num_depth+1
        :param is_training: (bool) is in training parse
        :param is_maxpool: (bool) if true, using max-pool down-sampling. Otherwise, residual block stride will be 2
        :param is_nearest: (bool) if true, using nearest up-sampling. Otherwise, using deconvolution
        :param reuse:(bool) reuse the variable
        """
        self.inputs = inputs
        self.num_class = num_class
        self.backbone = backbone

        self.num_block = num_block
        self.num_depth = num_depth
        self.residual_dim = residual_dim
        self.is_training = is_training
        self.is_maxpool = is_maxpool
        self.is_nearest = is_nearest
        self.reuse = reuse

        self.features = self.graph_hourglass(self.inputs)

    def _bn_relu(self, inputs, scope, summarize=True):
        """Batch-norm followed by ReLU; optionally records an activation histogram."""
        net = slim.batch_norm(
            inputs=inputs,
            activation_fn=tf.nn.relu,
            is_training=self.is_training,
            scope=scope,
            reuse=self.reuse,
            scale=True
        )
        if summarize:
            tf.summary.histogram(net.name + '/activations', net)
        return net

    def _conv(self, inputs, num_outputs, kernel_size, scope,
              stride=1, activation_fn=None):
        """Convolution without normalizer; records an activation histogram."""
        net = slim.conv2d(
            inputs=inputs,
            num_outputs=num_outputs,
            kernel_size=kernel_size,
            stride=stride,
            activation_fn=activation_fn,
            normalizer_fn=None,
            reuse=self.reuse,
            scope=scope
        )
        tf.summary.histogram(net.name + '/activations', net)
        return net

    def pre_process(self, inputs, scope='pre_process'):
        """
        pre-process conv7x7/s=2 -> residual/s=2
        :param inputs: (Tensor) BxHxWxC
        :param scope: (str) scope
        :return: (Tensor) BxH/4xW/4xC
        """
        with tf.variable_scope(scope):
            net = self._conv(inputs, 128, [7, 7], 'conv1', stride=2)
            net = residual_block_v2(
                inputs=net,
                output_dim=256,
                stride=2,
                is_training=self.is_training,
                reuse=self.reuse,
                scope='residual_1'
            )
            return net

    def inter_process(self, inputs_1, inputs_2, scope='inter_process'):
        """
        Merge two tensors between consecutive hourglass blocks.
        Each input goes through bn+relu and a 1x1 conv keeping its own channel
        count, then the two branches are added.
        :param inputs_1: (Tensor) first branch input
        :param inputs_2: (Tensor) second branch input
        :param scope: (str) scope
        :return: (Tensor) element-wise sum of both branches
        """
        with tf.variable_scope(scope):
            branch_1 = self._bn_relu(inputs_1, 'branch_1/bn')
            branch_1 = self._conv(
                branch_1, inputs_1.get_shape().as_list()[-1], [1, 1], 'branch_1/conv')

            branch_2 = self._bn_relu(inputs_2, 'branch_2/bn')
            branch_2 = self._conv(
                branch_2, inputs_2.get_shape().as_list()[-1], [1, 1], 'branch_2/conv')

            output = tf.add(branch_1, branch_2)
            return output

    def hinge(self, inputs, output_dim, scope='hinge'):
        """
        bn+relu followed by a 1x1 conv.
        :param inputs: (Tensor) input features
        :param output_dim: (int) channels of the 1x1 conv
        :param scope: (str) scope
        :return: (Tensor) conv output, without bn or activation
        """
        with tf.variable_scope(scope):
            pre = self._bn_relu(inputs, 'bn')
            outputs = self._conv(pre, output_dim, [1, 1], 'conv')
            return outputs

    def keypoint(self, features, scope='keypoint'):
        """
        key-point branch. return final feature map
        :param features: (Tensor) final backbone features without bn and activated
        :param scope: (str) scope
        :return: [Tensor,...] one sigmoid-activated map with num_class channels per input feature
        """
        keypoint_feature = []
        if not isinstance(features, list):
            features = [features]
        for i, backbone_feature in enumerate(features):
            with tf.variable_scope(scope + '_%d' % i):
                feature = self._bn_relu(backbone_feature, 'pre_bn')
                feature = self._conv(
                    feature, self.num_class, [3, 3], 'conv',
                    activation_fn=tf.nn.sigmoid)
                keypoint_feature.append(feature)

        return keypoint_feature

    def graph_backbone_hourglass(self, inputs):
        """
        Extract features
        :param inputs: (Tensor) BxHxWxC images
        :return: [Tensor] BxH/4xW/4xC. Pre is for inter-mediate supervision, last if for prediction.
        """
        t0 = time.time()
        print('-Begin to creat model')
        with tf.variable_scope('backbone'):
            start_time = time.time()
            pre = self.pre_process(inputs)
            print('--%s has been created in %.3fs' %
                  ('pre_process', time.time() - start_time))
            net = pre
            features = []
            for i in range(self.num_block):
                start_time = time.time()
                hourglass = hourglass_block(
                    inputs=net,
                    num_depth=self.num_depth,
                    residual_dim=self.residual_dim,
                    is_training=self.is_training,
                    is_maxpool=self.is_maxpool,
                    is_nearest=self.is_nearest,
                    reuse=self.reuse,
                    scope='hourglass_%d' % i
                )
                hinge = self.hinge(hourglass, self.residual_dim[0], 'hinge_%d' % i)
                features.append(hinge)
                print('--%s has been created in %.3fs' % ('hourglass_%d' % i, time.time() - start_time))
                # The last block feeds the head directly, so no inter_process is
                # built for it and nothing is reported.
                if i < self.num_block - 1:
                    start_time = time.time()
                    net = self.inter_process(net, hinge, 'inter_process_%d' % i)
                    print('--%s has been created in %.3fs' % ('inter_process_%d' % i, time.time() - start_time))

            print('-Model has been created in %.3fs' % (time.time() - t0))
            return features

    def graph_backbone_resnet_101(self, inputs):
        """
        resnet_v2_101 followed by three stride-2 transposed convs (each with
        bn+relu) and a final 3x3 conv with residual_dim[0] channels.
        :param inputs: (Tensor) BxHxWxC images
        :return: [Tensor] single-element list holding the feature map
        """
        with tf.variable_scope('backbone'):
            feature, _ = resnet_v2.resnet_v2_101(
                inputs, num_classes=None, global_pool=False,
                is_training=self.is_training, reuse=self.reuse, scope="resnet_v2_101")
            with tf.variable_scope('up_sample'):
                for i in (1, 2, 3):
                    name = 'transpose_conv%d' % i
                    feature = slim.conv2d_transpose(
                        feature, 512, 3, 2, activation_fn=None, reuse=self.reuse, scope=name)
                    feature = self._bn_relu(feature, name + '/bn', summarize=False)
                feature = slim.conv2d(
                    feature, self.residual_dim[0], 3, 1, activation_fn=None,
                    normalizer_fn=None, reuse=self.reuse, scope="conv1")
                return [feature]

    def graph_hourglass(self, inputs, scope='Keypoints'):
        """
        graph hourglass net.
        :param inputs: (Tensor) images
        :param scope: (str) scope
        :return: [[Tensor B x H/4 x W/4 x num_class,...]]
        :raises ValueError: when self.backbone is not a supported name
        """
        with tf.variable_scope(scope):
            if self.backbone == "hourglass":
                features = self.graph_backbone_hourglass(inputs)
            elif self.backbone == "resnet_v2_101":
                features = self.graph_backbone_resnet_101(inputs)
            else:
                raise ValueError("Invalid Backbone type!")
            all_features = [self.keypoint(features)]
            print('--PB file input node is %s' % inputs.name)
            print('--PB file output node is %s' % all_features[0][-1].name)
            return all_features
# -------------------------------------------------------------------------------- /core/network/network_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2019 luozw, Inc.
All Rights Reserved 5 | 6 | Authors: luozhiwang(luozw1994@outlook.com) 7 | Date: 2019-09-10 8 | """ 9 | 10 | 11 | import tensorflow as tf 12 | import tensorflow.contrib.slim as slim 13 | 14 | 15 | def residual_block_v2_with_bottom_neck(inputs, output_dim, stride, 16 | is_training=True, reuse=False, scope='residual_block'): 17 | """ 18 | Pre-act mode 19 | modified residual block 20 | bottle neck depth = output_dim / 2 21 | output = conv + short-cut 22 | :param inputs: (Tensor) input tensor BxHxWxC 23 | :param output_dim: (int) multiple of 2 24 | :param stride: (int) if down-sample 25 | :param reuse: (bool) reuse the variable 26 | :param scope: (str) scope name 27 | :param is_training: (bool)bn is in training phase 28 | :return: (Tensor) Bx(H/stride)x(W/stride)xC 29 | """ 30 | dim = output_dim / 4 31 | if output_dim % 2 != 0: 32 | raise ValueError('residual block output dim must be a multiple of 2') 33 | with tf.variable_scope(scope): 34 | depth_in = inputs.get_shape().as_list()[-1] 35 | pre_act = slim.batch_norm( 36 | inputs=inputs, 37 | activation_fn=tf.nn.relu, 38 | is_training=is_training, 39 | scope='pre_act', 40 | scale=True, 41 | reuse=reuse 42 | ) 43 | if output_dim == depth_in: 44 | short_cut = slim.max_pool2d( 45 | inputs=inputs, 46 | kernel_size=[1, 1], 47 | stride=stride, 48 | scope='short_cut' 49 | ) 50 | else: 51 | short_cut = slim.conv2d( 52 | inputs=pre_act, 53 | num_outputs=output_dim, 54 | kernel_size=[1, 1], 55 | stride=stride, 56 | activation_fn=None, 57 | normalizer_fn=None, 58 | scope='short_cut', 59 | reuse=reuse 60 | ) 61 | tf.summary.histogram(short_cut.name + '/activations', short_cut) 62 | 63 | residual = slim.conv2d( 64 | inputs=pre_act, 65 | num_outputs=dim, 66 | kernel_size=[1, 1], 67 | stride=1, 68 | activation_fn=None, 69 | normalizer_fn=None, 70 | scope='conv1', 71 | reuse=reuse 72 | ) 73 | residual = slim.batch_norm( 74 | residual, 75 | activation_fn=tf.nn.relu, 76 | is_training=is_training, 77 | scope='conv1/bn', 78 | scale=True, 79 
| reuse=reuse 80 | ) 81 | tf.summary.histogram(residual.name + '/activations', residual) 82 | 83 | residual = slim.conv2d( 84 | inputs=residual, 85 | num_outputs=dim, 86 | kernel_size=[3, 3], 87 | stride=stride, 88 | activation_fn=None, 89 | normalizer_fn=None, 90 | scope='conv2', 91 | reuse=reuse 92 | ) 93 | residual = slim.batch_norm( 94 | residual, 95 | activation_fn=tf.nn.relu, 96 | is_training=is_training, 97 | scope='conv2/bn', 98 | scale=True, 99 | reuse=reuse 100 | ) 101 | tf.summary.histogram(residual.name + '/activations', residual) 102 | 103 | residual = slim.conv2d( 104 | inputs=residual, 105 | num_outputs=output_dim, 106 | kernel_size=[1, 1], 107 | stride=1, 108 | activation_fn=None, 109 | normalizer_fn=None, 110 | scope='conv3', 111 | reuse=reuse 112 | ) 113 | tf.summary.histogram(residual.name + '/activations', residual) 114 | 115 | output = short_cut + residual 116 | return output 117 | 118 | 119 | def residual_block_v2(inputs, output_dim, stride, 120 | is_training=True, reuse=False, scope='residual_block'): 121 | """ 122 | Pre-act mode 123 | modified residual block 124 | bottle neck depth = output_dim / 2 125 | output = conv + short-cut 126 | :param inputs: (Tensor) input tensor BxHxWxC 127 | :param output_dim: (int) multiple of 2 128 | :param stride: (int) if down-sample 129 | :param scope: (str) scope name 130 | :param is_training: (bool)bn is in training phase 131 | :return: (Tensor) Bx(H/stride)x(W/stride)xC 132 | """ 133 | with tf.variable_scope(scope): 134 | depth_in = inputs.get_shape().as_list()[-1] 135 | pre_act = slim.batch_norm( 136 | inputs=inputs, 137 | activation_fn=tf.nn.relu, 138 | is_training=is_training, 139 | scope='pre_act', 140 | scale=True, 141 | reuse=reuse 142 | ) 143 | if output_dim == depth_in: 144 | short_cut = slim.max_pool2d( 145 | inputs=inputs, 146 | kernel_size=[1, 1], 147 | stride=stride, 148 | scope='short_cut' 149 | ) 150 | else: 151 | short_cut = slim.conv2d( 152 | inputs=pre_act, 153 | num_outputs=output_dim, 
154 | kernel_size=[1, 1], 155 | stride=stride, 156 | activation_fn=None, 157 | normalizer_fn=None, 158 | scope='short_cut', 159 | reuse=reuse 160 | ) 161 | tf.summary.histogram(short_cut.name + '/activations', short_cut) 162 | 163 | residual = slim.conv2d( 164 | inputs=pre_act, 165 | num_outputs=output_dim, 166 | kernel_size=[3, 3], 167 | stride=1, 168 | activation_fn=None, 169 | normalizer_fn=None, 170 | scope='conv1', 171 | reuse=reuse 172 | ) 173 | residual = slim.batch_norm( 174 | residual, 175 | activation_fn=tf.nn.relu, 176 | is_training=is_training, 177 | scope='conv1/bn', 178 | scale=True, 179 | reuse=reuse 180 | ) 181 | tf.summary.histogram(residual.name + '/activations', residual) 182 | 183 | residual = slim.conv2d( 184 | inputs=residual, 185 | num_outputs=output_dim, 186 | kernel_size=[3, 3], 187 | stride=stride, 188 | activation_fn=None, 189 | normalizer_fn=None, 190 | scope='conv2', 191 | reuse=reuse 192 | ) 193 | 194 | tf.summary.histogram(residual.name + '/activations', residual) 195 | 196 | output = short_cut + residual 197 | return output 198 | 199 | 200 | def hourglass_block(inputs, num_depth, residual_dim, 201 | is_training=True, is_maxpool=False, 202 | is_nearest=True, reuse=False, scope='hourglass_block'): 203 | """ 204 | modified hourglass block fellow by "CornerNet" 205 | There 2 residual blocks in short-cut istead of 1 206 | There 2 residual blocks after upsampling 207 | There 4 residual blocks with depth dim (512 in paper) in the middle of hourglass 208 | Attention! residual blocks are in pre-act mode 209 | inputs must be not processed by actived or normlized 210 | :param inputs: (Tensor) BxHxWxC 211 | :param num_depth: (int) depth of downsample 212 | :param residual_dim: (list) dim of residual block. len(residual_dim)=num_depth+1 213 | :param is_training: (bool) bn is in training phase 214 | :param is_maxpool: (bool) if it's True, downsample mode will be maxpool. 
Otherwise, downsample mode will be stride=2 215 | :param is_nearest: (bool) if it's True, upsample mode will be neareast upsample. Otherwise, upsample mode will be deconv. 216 | :param scope: (str) scope name 217 | :return: (Tensor) BxHxWxC 218 | """ 219 | cur_res_dim = inputs.get_shape().as_list()[-1] 220 | next_res_dim = residual_dim[0] 221 | 222 | with tf.variable_scope(scope): 223 | up_1 = residual_block_v2( 224 | inputs=inputs, 225 | output_dim=cur_res_dim, 226 | stride=1, 227 | is_training=is_training, 228 | reuse=reuse, 229 | scope='up_1' 230 | ) 231 | if is_maxpool: 232 | low_1 = slim.max_pool2d( 233 | inputs=inputs, 234 | kernel_size=2, 235 | stride=2, 236 | padding='VALID' 237 | ) 238 | low_1 = residual_block_v2( 239 | inputs=low_1, 240 | output_dim=next_res_dim, 241 | stride=1, 242 | is_training=is_training, 243 | reuse=reuse, 244 | scope='low_1' 245 | ) 246 | else: 247 | low_1 = residual_block_v2( 248 | inputs=inputs, 249 | output_dim=next_res_dim, 250 | stride=2, 251 | is_training=is_training, 252 | reuse=reuse, 253 | scope='low_1' 254 | ) 255 | 256 | if num_depth > 1: 257 | low_2 = hourglass_block( 258 | inputs=low_1, 259 | num_depth=num_depth - 1, 260 | residual_dim=residual_dim[1:], 261 | is_training=is_training, 262 | is_maxpool=is_maxpool, 263 | is_nearest=is_nearest, 264 | reuse=reuse, 265 | scope='hourglass_block_%d' % (num_depth - 1) 266 | ) 267 | else: 268 | low_2 = residual_block_v2( 269 | inputs=low_1, 270 | output_dim=next_res_dim, 271 | stride=1, 272 | is_training=is_training, 273 | reuse=reuse, 274 | scope='low_2' 275 | ) 276 | low_3 = residual_block_v2( 277 | inputs=low_2, 278 | output_dim=cur_res_dim, 279 | stride=1, 280 | is_training=is_training, 281 | reuse=reuse, 282 | scope='low_3' 283 | ) 284 | if is_nearest: 285 | up_2 = tf.image.resize_nearest_neighbor( 286 | images=low_3, 287 | size=tf.shape(low_3)[1:3] * 2, 288 | name='up_2' 289 | ) 290 | else: 291 | up_2 = slim.conv2d_transpose( 292 | inputs=low_3, 293 | 
num_outputs=cur_res_dim, 294 | kernel_size=[3, 3], 295 | stride=2, 296 | reuse=reuse, 297 | scope='up_2' 298 | ) 299 | merge = up_1 + up_2 300 | return merge 301 | -------------------------------------------------------------------------------- /core/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/core/train/__init__.py -------------------------------------------------------------------------------- /core/train/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved 5 | 6 | Authors: luozhiwang(luozw1994@outlook.com) 7 | Date: 2019-09-20 8 | """ 9 | import os 10 | import time 11 | 12 | import tensorflow as tf 13 | import tensorflow.contrib.slim as slim 14 | from core.loss.loss import focal_loss, cross_entropy, softmax_cross_entropy, mean_square_loss 15 | 16 | 17 | class Trainer(): 18 | def __init__(self, model_class, dataset_class, cfg): 19 | 20 | start_time = time.time() 21 | # HARDWARE 22 | self.CUDA_VISIBLE_DEVICES = cfg.CUDA_VISIBLE_DEVICES 23 | if self.CUDA_VISIBLE_DEVICES is not None: 24 | os.environ['CUDA_VISIBLE_DEVICES'] = self.CUDA_VISIBLE_DEVICES 25 | self.MULTI_THREAD_NUM = cfg.MULTI_THREAD_NUM 26 | # self.MULTI_GPU = cfg.MULTI_GPU 27 | # self.NUM_GPU = len(self.MULTI_GPU) 28 | 29 | # NETWORK 30 | self.backbone = cfg.backbone 31 | self.image_size = cfg.image_size 32 | self.heatmap_size = cfg.heatmap_size 33 | self.stride = cfg.stride 34 | self.num_block = cfg.num_block 35 | self.num_depth = cfg.num_depth 36 | self.residual_dim = cfg.residual_dim 37 | self.is_maxpool = cfg.is_maxpool 38 | self.is_nearest = cfg.is_nearest 39 | 40 | # TRAINING 41 | self.batch_size = cfg.batch_size 42 | self.learning_rate_init = cfg.learning_rate_init 43 | self.learning_rate_warmup 
= cfg.learning_rate_warmup 44 | self.exp_decay = cfg.exp_decay 45 | 46 | self.warmup_epoch_size = cfg.warmup_epoch_size 47 | self.epoch_size = cfg.epoch_size 48 | self.summary_per = cfg.summary_per 49 | self.save_per = cfg.save_per 50 | 51 | self.regularization_weight = cfg.regularization_weight 52 | 53 | # VALIDATION 54 | self.val_per = cfg.val_per 55 | self.val_time = cfg.val_time 56 | 57 | # PATH 58 | self.dataset_dir = cfg.dataset_dir 59 | self.train_image_dir = cfg.train_image_dir 60 | self.val_image_dir = cfg.val_image_dir 61 | self.train_list_path = cfg.train_list_path 62 | self.val_list_path = cfg.val_list_path 63 | 64 | self.log_dir = cfg.log_dir 65 | self.ckpt_path = cfg.ckpt_dir 66 | 67 | # SAVER AND LOADER 68 | self.pre_trained_ckpt = cfg.pre_trained_ckpt 69 | self.ckpt_name = cfg.ckpt_name 70 | self.max_keep = cfg.max_keep 71 | 72 | print('-Load config in %.3f' % (time.time() - start_time)) 73 | 74 | # DATASET 75 | self.dataset_class = dataset_class 76 | self.augment = cfg.augment 77 | self.train_dataset = None 78 | self.val_dataset = None 79 | self.train_iterator = None 80 | self.val_iterator = None 81 | 82 | # cal option 83 | self.time = time.strftime( 84 | '%Y_%m_%d_%H_%M_%S', 85 | time.localtime( 86 | time.time())) 87 | self.steps_per_period = None 88 | 89 | # PLACE HOLDER 90 | self.inputs_x = None 91 | self.inputs_y = None 92 | self.is_training = None 93 | 94 | # MODEL 95 | self.model_class = model_class 96 | self.model = None 97 | self.features = None 98 | 99 | self.val_model = None 100 | self.val_features = None 101 | 102 | # LOSS 103 | self.loss_mode = cfg.loss_mode 104 | self.model_losses = None 105 | self.model_loss = None 106 | self.val_model_loss = None 107 | self.trainable_variables = None 108 | self.regularization_loss = None 109 | self.loss = None 110 | 111 | # LEARNING RATE 112 | self.global_step = None 113 | self.learning_rate = None 114 | 115 | # TRAIN OP 116 | self.train_op = None 117 | 118 | # SAVER LOADER SUMMARY 119 | self.loader 
= None 120 | self.saver = None 121 | self.summary_writer = None 122 | self.write_op = None 123 | 124 | # DEBUG 125 | self.is_debug = False 126 | self.gradient = None 127 | self.mean_gradient = None 128 | 129 | # SESSION 130 | self.sess = None 131 | ################################################################# 132 | 133 | def init_inputs(self): 134 | with tf.variable_scope('Placeholder'): 135 | self.inputs_x = tf.placeholder(tf.float32, [None, self.image_size[0], self.image_size[1], 3], 136 | 'inputs_x') 137 | self.inputs_y = tf.placeholder(tf.float32, [None, self.heatmap_size[0], self.heatmap_size[0], 138 | self.train_dataset.num_class], 'inputs_y') 139 | # 如果使用placeholder为BN层的trainable参数,BN层中会处于一种使用tf.cond,tf.switch流控制节点(此处可以在tensorRT以及模型图中得到验证) 140 | # 这样的话每一个BN层都会有两条路径出来,训练太占显存,infer部署的时候还要单独进行剪枝 141 | # 此处直接设置为True的话,训练是没问题的。做val的时候,不调用train_op那么BN的gamma和beta不会更新 142 | # 并且由于mean和var设置为依赖于train_op更新,所以BN在val时所有参数都没有更新,相当于trainable=False 143 | # 然而在tf1.x版本中,trainable=False是让BN处于freeze状态。 144 | # 和infer不同的时,freeze仍然是使用当前batch的mean和var进行处理。 145 | # 在tf2.x版本中,bn已经改成了当trainable为False的时候是infer状态 146 | self.is_training = True 147 | 148 | def init_dataset(self): 149 | start_time = time.time() 150 | 151 | # TRAIN DATASET 152 | self.train_dataset = self.dataset_class(image_dir=self.train_image_dir, 153 | gt_path=self.train_list_path, 154 | batch_size=self.batch_size, 155 | image_size=self.image_size, 156 | heatmap_size=self.heatmap_size, 157 | augment=self.augment) 158 | self.train_iterator = self.train_dataset.iterator( 159 | self.MULTI_THREAD_NUM) 160 | 161 | # VAL DATASET 162 | self.val_dataset = self.dataset_class(image_dir=self.val_image_dir, 163 | gt_path=self.val_list_path, 164 | batch_size=self.batch_size, 165 | image_size=self.image_size, 166 | heatmap_size=self.heatmap_size 167 | ) 168 | self.val_iterator = self.val_dataset.iterator(self.MULTI_THREAD_NUM) 169 | self.steps_per_period = int( 170 | self.train_dataset.num_data / 171 | self.batch_size) 172 | 
print('-Creat dataset in %.3f' % (time.time() - start_time)) 173 | 174 | def init_model(self): 175 | print("-Creat Train model") 176 | self.model = self.model_class(self.inputs_x, self.train_dataset.num_class, 177 | backbone=self.backbone, 178 | num_block=self.num_block, 179 | num_depth=self.num_depth, 180 | residual_dim=self.residual_dim, 181 | is_training=True, 182 | is_maxpool=self.is_maxpool, 183 | is_nearest=self.is_nearest, 184 | reuse=False 185 | ) 186 | self.features = self.model.features[0] 187 | 188 | print("-Creat Val model") 189 | self.val_model = self.model_class(self.inputs_x, self.train_dataset.num_class, 190 | backbone=self.backbone, 191 | num_block=self.num_block, 192 | num_depth=self.num_depth, 193 | residual_dim=self.residual_dim, 194 | is_training=False, 195 | is_maxpool=self.is_maxpool, 196 | is_nearest=self.is_nearest, 197 | reuse=True 198 | ) 199 | self.val_features = self.val_model.features[0] 200 | 201 | def init_learning_rate(self): 202 | start_time = time.time() 203 | # LEARNING RATE 204 | with tf.variable_scope('Learning_rate'): 205 | self.global_step = tf.train.get_or_create_global_step() 206 | warmup_steps = tf.constant(self.warmup_epoch_size * self.steps_per_period, 207 | dtype=tf.int64, name='warmup_steps') 208 | self.learning_rate = tf.cond( 209 | pred=tf.less(self.global_step, warmup_steps), 210 | true_fn=lambda: self.learning_rate_warmup + (self.learning_rate_init - self.learning_rate_warmup) 211 | * tf.cast(self.global_step, tf.float32) / tf.cast(warmup_steps, tf.float32), 212 | false_fn=lambda: tf.train.exponential_decay( 213 | self.learning_rate_init, self.global_step, self.steps_per_period, self.exp_decay, staircase=True) 214 | ) 215 | print('-Creat learning rate in %.3f' % (time.time() - start_time)) 216 | 217 | def init_loss(self): 218 | start_time = time.time() 219 | 220 | # LOSS 221 | with tf.variable_scope('Loss'): 222 | self.trainable_variables = tf.trainable_variables() 223 | if self.loss_mode == 'focal': 224 | loss_fn 
= focal_loss 225 | elif self.loss_mode == 'sigmoid': 226 | loss_fn = cross_entropy 227 | elif self.loss_mode == 'softmax': 228 | loss_fn = softmax_cross_entropy 229 | elif self.loss_mode == 'mse': 230 | loss_fn = mean_square_loss 231 | else: 232 | raise ValueError('Unsupported loss mode: %s' % self.loss_mode) 233 | self.model_losses = loss_fn(self.features, self.inputs_y) 234 | self.model_loss = tf.add_n(self.model_losses) 235 | self.val_model_loss = loss_fn(self.val_features, self.inputs_y)[-1] 236 | self.regularization_loss = tf.add_n( 237 | [tf.nn.l2_loss(var) for var in self.trainable_variables]) 238 | self.regularization_loss = self.regularization_weight * self.regularization_loss 239 | self.loss = self.model_loss + self.regularization_loss 240 | 241 | print('-Creat loss in %.3f' % (time.time() - start_time)) 242 | 243 | def init_train_op(self): 244 | start_time = time.time() 245 | # TRAIN_OP 246 | with tf.name_scope("Train_op"): 247 | optimizer = tf.train.AdamOptimizer( 248 | self.learning_rate) 249 | gvs = optimizer.compute_gradients(self.loss) 250 | clip_gvs = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gvs] 251 | if self.is_debug: 252 | self.mean_gradient = tf.reduce_mean([tf.reduce_mean(g) for g, v in gvs]) 253 | tf.summary.scalar("mean_gradient", self.mean_gradient) 254 | print('Debug mode is on !!!') 255 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 256 | # It's important! 
257 | # Update moving-average in BN 258 | self.train_op = optimizer.apply_gradients(clip_gvs, global_step=self.global_step) 259 | print('-Creat train op in %.3f' % (time.time() - start_time)) 260 | 261 | def init_loader_saver_summary(self): 262 | start_time = time.time() 263 | with tf.name_scope('loader_and_saver'): 264 | self.loader = tf.train.Saver(var_list=tf.global_variables()) 265 | var_list = tf.trainable_variables() 266 | g_list = tf.global_variables() 267 | bn_moving_var = [g for g in g_list if 'moving_mean' in g.name] 268 | bn_moving_var += [g for g in g_list if 'moving_variance' in g.name] 269 | if len(bn_moving_var) < 1: 270 | print('Warning! BatchNorm layer parameters have not been saved!') 271 | var_list += bn_moving_var 272 | self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=self.max_keep) 273 | 274 | with tf.name_scope('summary'): 275 | 276 | tf.summary.image('input_image', self.inputs_x, max_outputs=3) 277 | tf.summary.image('input_hm', tf.reduce_sum(self.inputs_y,axis=-1,keepdims=True), max_outputs=3) 278 | tf.summary.image('output_hm', tf.reduce_sum(self.features[-1],axis=-1,keepdims=True), max_outputs=3) 279 | 280 | tf.summary.scalar("learning_rate", self.learning_rate) 281 | for i in range(len(self.model_losses)): 282 | tf.summary.scalar("block_%d_loss" % i, self.model_losses[i]) 283 | tf.summary.scalar("model_loss", self.model_loss) 284 | tf.summary.scalar("regularization_loss", self.regularization_loss) 285 | tf.summary.scalar("total_loss", self.loss) 286 | # # Optional 287 | # tf.summary.scalar('keypoint_bn_moving_mean', 288 | # tf.reduce_mean(slim.get_variables_by_name('HourglassNet/keypoint_1/pre_bn/moving_mean'))) 289 | # tf.summary.scalar('keypoint_bn_moving_var', tf.reduce_mean( 290 | # slim.get_variables_by_name('HourglassNet/keypoint_1/pre_bn/moving_variance'))) 291 | 292 | if not os.path.exists(self.log_dir): 293 | os.mkdir(self.log_dir) 294 | self.write_op = tf.summary.merge_all() 295 | 296 | print( 297 | 
def train(self):
    """Run the training loop until ``global_step`` reaches the step budget.

    Initializes all variables, optionally restores ``self.pre_trained_ckpt``
    and then feeds one batch from ``self.train_iterator`` per step. Summaries
    are written every ``self.summary_per`` steps, a checkpoint is saved every
    ``self.save_per`` steps and ``self.val_time`` validation batches are
    evaluated every ``self.val_per`` steps (except at step 0). A final
    checkpoint is written after the loop, then the summary writer and the
    session are closed.
    """
    t0 = time.time()
    self.sess.run(tf.global_variables_initializer())
    print('-Model has beed initialized in %.3f' % (time.time() - t0))
    if self.pre_trained_ckpt is not None:
        self._load_ckpt()

    print('Begin to train!')
    total_step = self.epoch_size * self.steps_per_period
    checkpoint_prefix = os.path.join(self.ckpt_path, self.ckpt_name)

    # Read the counter BEFORE testing the loop condition. Testing the value
    # left over from the previous iteration runs one update too many and
    # always trains once, even if a restored checkpoint already reached
    # total_step.
    step = self.sess.run(self.global_step)
    while step < total_step:
        ite = step % self.steps_per_period + 1
        epoch = step // self.steps_per_period + 1
        imgs, hms = next(self.train_iterator)
        # Maps 0..255 pixel values to -1..1 (assumes 8-bit images).
        imgs = (imgs / 127.5) - 1
        feed_dict = {
            self.inputs_x: imgs,
            self.inputs_y: hms,
        }

        if step % self.summary_per == 0:
            if self.is_debug:
                mean_gradient = self.sess.run(self.mean_gradient, feed_dict=feed_dict)
                print('mean_gradient: %.6f ' % mean_gradient)
            summary, _, lr, loss, model_ls, reg_ls = self.sess.run(
                [self.write_op, self.train_op, self.learning_rate, self.loss,
                 self.model_loss, self.regularization_loss],
                feed_dict=feed_dict)
            print(
                'Epoch: %d / %d Iter: %d / %d Step: %d Loss: %.4f Model Loss: %.4f Reg Loss: %.4f Lr: %f' %
                (epoch, self.epoch_size, ite, self.steps_per_period, step, loss, model_ls, reg_ls, lr))
            self.summary_writer.add_summary(summary, step)
        else:
            self.sess.run(
                [self.train_op, self.learning_rate, self.loss,
                 self.model_loss, self.regularization_loss],
                feed_dict=feed_dict)

        if step % self.save_per == 0:
            self.saver.save(self.sess, checkpoint_prefix, global_step=step)

        if step % self.val_per == 0 and step != 0:
            # Validation: evaluate the loss only, no gradient update.
            val_losses = []
            start_time = time.time()
            for _ in range(self.val_time):
                # TODO: save rendered predictions for a few images
                # (e.g. cv2.circle + cv2.imwrite).
                imgs_v, hms_v = next(self.val_iterator)
                imgs_v = (imgs_v / 127.5) - 1
                val_feed = {
                    self.inputs_x: imgs_v,
                    self.inputs_y: hms_v,
                }
                val_losses.append(
                    self.sess.run(self.val_model_loss, feed_dict=val_feed))
            print('Validation %d times in %.3fs mean loss is %f'
                  % (self.val_time, time.time() - start_time, sum(val_losses) / len(val_losses)))

        # Refresh the counter for the next loop test.
        step = self.sess.run(self.global_step)

    self.saver.save(self.sess, checkpoint_prefix, global_step=step)
    self.summary_writer.close()
    self.sess.close()
def read_pb(pb_path, input_node_name_and_val, output_node_name):
    """
    Load a frozen graph from disk and run a single inference pass on it.

    :param pb_path: path of the frozen ``.pb`` graph file
    :param input_node_name_and_val: {(str) input_node_name: (any) input_node_val}
    :param output_node_name: [(str) output_node_name]
    :return: [output], one entry per requested output tensor
    """
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as pb_file:
            graph_def.ParseFromString(pb_file.read())
        tf.import_graph_def(graph_def, name='')

        # Let TF fall back to another device when needed and grow GPU memory on demand.
        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        with tf.Session(config=session_config) as sess:
            # Both inputs and outputs are addressed by TENSOR name (e.g. "x:0"),
            # not by operation name.
            tensor_by_name = sess.graph.get_tensor_by_name
            feed_dict = {
                tensor_by_name(tensor_name): value
                for tensor_name, value in input_node_name_and_val.items()
            }
            fetches = [tensor_by_name(tensor_name) for tensor_name in output_node_name]

            start_time = time.time()
            output = sess.run(fetches, feed_dict=feed_dict)
            print('Infer time is %.4f' % (time.time() - start_time))
            return output
| tile_output = np.tile(one_ouput, (1, 1, 3)) 90 | tile_img =cv2.resize(tile_output, img_size) + img 91 | 92 | cv2.imwrite('render_img/'+str(i)+'_'+str(k)+'_origin.jpg', img) 93 | 94 | 95 | cv2.imwrite('render_img/'+str(i)+'_'+str(k)+'_hm.jpg', tile_img) 96 | 97 | sk_img = draw_skeleton(img, points[i],'coco') 98 | cv2.imwrite('render_img/' + str(i) + '_' + str(k) + '_skeleton.jpg', sk_img) 99 | 100 | img = draw_skeleton(img, gt_points[i],'coco') 101 | cv2.imwrite('render_img/'+str(i)+'_'+str(k)+'_visible.jpg', img) 102 | 103 | # outputs[k] 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /infer_hourglass.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved 5 | 6 | Authors: luozhiwang(luozw1994@outlook.com) 7 | Date: 2019-09-21 8 | """ 9 | from core.infer.infer_utils import read_pb, pred_one_image 10 | from core.infer.visual_utils import draw_point, draw_bbx, draw_skeleton 11 | # image = cv2.imread(img_path) 12 | # # 1.实例化模型 13 | # sess, input_tensor, output_tensor = \ 14 | # read_pb(pb_path, ['Placeholder/inputs_x:0'], ['HourglassNet/keypoint_1/conv/Sigmoid:0']) 15 | # # 2.处理图片 每次处理一个图里面的数据作为batch 16 | # # bbxes 是提前知道的信息 bbxes = [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax]] 17 | # points = pred_one_image(image, bbxes, sess, input_tensor, output_tensor) 18 | # print(points) 19 | # for point in points: 20 | # image = draw_point(image, point) -------------------------------------------------------------------------------- /output/coco/readme.txt: -------------------------------------------------------------------------------- 1 | directory contains tensorboard files -------------------------------------------------------------------------------- /output/mpii/readme.txt: 
def change_name(name, prefix="Keypoints/backbone/"):
    """Return ``name`` with the model's scope ``prefix`` removed.

    Variables of our own network carry a scope prefix that the pre-trained
    checkpoint does not have. Names that do not start with ``prefix`` are
    returned unchanged instead of being cut at a fixed offset.

    :param name: variable name inside our model
    :param prefix: scope prefix to remove (inverse of ``restore_name``)
    :return: the name as it would appear in the pre-trained checkpoint
    """
    if name.startswith(prefix):
        return name[len(prefix):]
    return name
var_new = slim.get_variables_to_restore()

# Find the variables shared by the pre-trained checkpoint and our model.
count = 0
ommit = 0
# Model variable names with the scope prefix removed, so they can be compared
# with the checkpoint keys. `op.name` is the variable name without the ":0"
# output suffix. str.strip(':0') is not a suffix removal: it strips every
# leading/trailing ':' and '0' character and would mangle a name ending in 0.
all_var = {change_name(var.op.name) for var in var_new}
restore_list = []
for key in var_ori:
    if key not in all_var:
        ommit += 1
        continue
    ori_var = reader.get_tensor(key)
    new_var = slim.get_variables_by_name(restore_name(key))[0]
    # Only restore tensors whose shapes match exactly.
    if list(ori_var.shape) == new_var.get_shape().as_list():
        count += 1
        restore_list.append(key)
    else:
        ommit += 1
print('restore ', count)
print('ommit', ommit)
print('all', count + ommit)
var_list = map_var(restore_list)
# `loader` reads the shared variables under their checkpoint names;
# `saver` then writes every model variable under our own names.
loader = tf.train.Saver(
    var_list=var_list)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loader.restore(sess, ori_ckpt)
    saver.save(sess, new_ckpt)
def coco_keypoint2txt(file, txt_path, thre=1):
    """Convert a COCO-style keypoint annotation JSON file to the flat txt format.

    Every kept annotation becomes one line::

        file_name xmin,ymin,xmax,ymax x,y x,y ...

    Coordinates are truncated to int. A keypoint whose visibility flag is 0
    (or an unknown value) is written as ``-1,-1``; flags 1 (labelled but not
    visible) and 2 (visible) keep their coordinates.

    :param file: path of the annotation JSON (needs ``annotations`` and ``images``)
    :param txt_path: path of the txt file to write
    :param thre: an annotation is kept only if it has at most this many
        missing keypoints
    """
    print('Transform %s' % (file))
    # Parse the input before touching the output file, so a bad input does
    # not leave a truncated txt behind; `with` closes the handle.
    with open(file) as reader:
        data = json.load(reader)

    all_count = 0
    good_count = 0
    index = {}  # image_id -> list of formatted annotation strings
    for ann in data['annotations']:
        all_count += 1
        xmin, ymin, box_w, box_h = ann['bbox'][:4]
        fields = ['%d,%d,%d,%d' % (int(xmin), int(ymin), int(xmin + box_w), int(ymin + box_h))]

        keypoints = ann['keypoints']
        missing = 0
        for i in range(len(keypoints) // 3):
            kx, ky, visibility = keypoints[i * 3:i * 3 + 3]
            if visibility in (1, 2):
                fields.append(str(int(kx)) + ',' + str(int(ky)))
                continue
            if visibility != 0:
                print('Unsupported keypoints val')
            fields.append('-1,-1')
            missing += 1

        # Data cleaning: count missing keypoints directly, not by searching
        # the formatted line for '-1,-1' (which would also match the bbox).
        if missing <= thre:
            good_count += 1
            index.setdefault(ann['image_id'], []).append(' '.join(fields) + ' ')

    with open(txt_path, 'w') as writer:
        for image in data['images']:
            for st in index.get(image['id'], ()):
                writer.write(image['file_name'] + ' ' + st + '\n')
    print('total data are %d, write data are %d' % (all_count, good_count))
def check_empty(list, name):
    """Return True when field ``name`` of ``list`` is missing or has no elements.

    :param list: container indexed by field name (a record loaded from the
        .mat annotation file, or any mapping). The parameter name shadows the
        builtin but is kept for backward compatibility.
    :param name: field / key to look up
    :return: True if the lookup fails or the value has length 0, else False
    """
    try:
        value = list[name]
    except (ValueError, KeyError, IndexError):
        # Which exception a missing field raises depends on the container
        # (dict -> KeyError; NumPy structured data -> ValueError or IndexError).
        return True
    return len(value) == 0
'annorect') == False: # any person is annotated 36 | 37 | filename =str(annot_file['annolist'][0][0][0][img_id]['image'][0][0][0][0]) # filename 38 | img = Image.open(osp.join('../data/dataset/mpii/images', filename)) 39 | w, h = img.size 40 | img_dict = { 41 | 'id': img_id, 42 | 'file_name': filename, 43 | 'width': w, 44 | 'height': h} 45 | coco['images'].append(img_dict) 46 | 47 | if db_type == 'test': 48 | continue 49 | 50 | person_num = len(annot_file['annolist'][0][0] 51 | [0][img_id]['annorect'][0]) # person_num 52 | joint_annotated = np.zeros((person_num, joint_num)) 53 | for pid in range(person_num): 54 | 55 | if check_empty(annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid], 'annopoints') == False: # kps is annotated 56 | 57 | bbox = np.zeros((4)) # xmin, ymin, w, h 58 | kps = np.zeros((joint_num, 3)) # xcoord, ycoord, vis 59 | 60 | # kps 61 | annot_joint_num = len( 62 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0]) 63 | for jid in range(annot_joint_num): 64 | annot_jid = \ 65 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0][jid][ 66 | 'id'][0][0] 67 | kps[annot_jid][0] = \ 68 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0][jid][ 69 | 'x'][0][0] 70 | kps[annot_jid][1] = \ 71 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0][jid][ 72 | 'y'][0][0] 73 | kps[annot_jid][2] = 1 74 | 75 | # bbox extract from annotated kps 76 | annot_kps = kps[kps[:, 2] == 1, :].reshape(-1, 3) 77 | xmin = np.min(annot_kps[:, 0]) 78 | ymin = np.min(annot_kps[:, 1]) 79 | xmax = np.max(annot_kps[:, 0]) 80 | ymax = np.max(annot_kps[:, 1]) 81 | width = xmax - xmin - 1 82 | height = ymax - ymin - 1 83 | 84 | # corrupted bounding box 85 | if width <= 0 or height <= 0: 86 | continue 87 | # 20% extend 88 | # else: 89 | # bbox[0] = (xmin + xmax) / 2. - width / 2 * 1.2 90 | # bbox[1] = (ymin + ymax) / 2. 
- height / 2 * 1.2 91 | # bbox[2] = width * 1.2 92 | # bbox[3] = height * 1.2 93 | else: 94 | bbox[0] = max(xmin,0) 95 | bbox[1] = max(ymin,0) 96 | bbox[2] = width 97 | bbox[3] = height 98 | 99 | person_dict = {'id': aid, 'image_id': img_id, 'category_id': 1, 'area': bbox[2] * bbox[3], 100 | 'bbox': bbox.tolist(), 'iscrowd': 0, 'keypoints': kps.reshape(-1).tolist(), 101 | 'num_keypoints': int(np.sum(kps[:, 2] == 1))} 102 | coco['annotations'].append(person_dict) 103 | aid += 1 104 | 105 | category = { 106 | "supercategory": "person", 107 | "id": 1, # to be same as COCO, not using 0 108 | "name": "person", 109 | "skeleton": [[0, 1], 110 | [1, 2], 111 | [2, 6], 112 | [7, 12], 113 | [12, 11], 114 | [11, 10], 115 | [5, 4], 116 | [4, 3], 117 | [3, 6], 118 | [7, 13], 119 | [13, 14], 120 | [14, 15], 121 | [6, 7], 122 | [7, 8], 123 | [8, 9]], 124 | "keypoints": ["r_ankle", "r_knee", "r_hip", 125 | "l_hip", "l_knee", "l_ankle", 126 | "pelvis", "throax", 127 | "upper_neck", "head_top", 128 | "r_wrist", "r_elbow", "r_shoulder", 129 | "l_shoulder", "l_elbow", "l_wrist"]} 130 | 131 | coco['categories'] = [category] 132 | 133 | with open(save_path, 'w') as f: 134 | json.dump(coco, f) 135 | -------------------------------------------------------------------------------- /script/parse_ckpt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2019 luozw, Inc. 
def read_origin(path):
    """List the weight-like variables stored in a checkpoint with a meta graph.

    Imports ``path + '.meta'``, restores the checkpoint into a fresh session
    and prints every global variable whose last name component is one of
    ``weights``, ``gamma``, ``beta``, ``moving_mean`` or ``moving_variance``.

    :param path: checkpoint prefix (without the ``.meta`` extension)
    :return: list of ``[var_name, var_shape]`` in variable order
        (previously the list was built but never returned)
    """
    kept_suffixes = ('weights', 'gamma', 'beta', 'moving_mean', 'moving_variance')
    org_weights_mess = []
    load = tf.train.import_meta_graph(path + '.meta')
    with tf.Session() as sess:
        load.restore(sess, path)
        for var in tf.global_variables():
            var_name = var.op.name
            if str(var_name).split('/')[-1] not in kept_suffixes:
                continue
            var_shape = var.shape
            org_weights_mess.append([var_name, var_shape])
            print("=> " + str(var_name).ljust(50), var_shape)
    return org_weights_mess
os.path.join('../checkpoints', ckpt_name) 69 | key = parse_ckpt(checkpoint_path) 70 | # transform(key,key2val) 71 | -------------------------------------------------------------------------------- /tensorRT/c++/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5) 2 | cuda_add_executable(keypoints Keypoints_main.cpp ../../../common_source/common_tensorrt/logger.cpp source/keypoints_tensorrt.cpp source/keypoints_tensorrt.h source/utils.h source/utils.cpp source/ResizeNearestNeighbor.cpp source/ResizeNearestNeighbor.h source/my_plugin.h source/my_plugin.cpp source/ResizeNearestNeighbor.cu) 3 | set_property(TARGET keypoints PROPERTY FOLDER project/keypoints) 4 | target_link_libraries(keypoints libnvinfer.so libnvparsers.so cudart.so libopencv_core.so libopencv_imgproc.so libopencv_imgcodecs.so) -------------------------------------------------------------------------------- /tensorRT/c++/Keypoints_main.cpp: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2019/12/18 3 | #include "source/keypoints_tensorrt.h" 4 | #include 5 | #include "source/utils.h" 6 | #include "source/my_plugin.h" 7 | 8 | const std::string project_name = "TensorRT_Keypoints"; 9 | void printHelpInfo() 10 | { 11 | std::cout << "Usage: ./keypoints [-h or --help] [-d or " 12 | "--datadir=] [--useDLACore=]\n"; 13 | std::cout << "--help Display help information\n"; 14 | std::cout << "--datadir Specify path to a data directory, overriding " 15 | "the default. This option can be used multiple times to add " 16 | "multiple directories. If no data directories are given, the " 17 | "default is to use (data/samples/mnist/, data/mnist/)" 18 | << std::endl; 19 | std::cout << "--useDLACore=N Specify a DLA engine for layers that support " 20 | "DLA. 
Value can range from 0 to n-1, where n is the number of " 21 | "DLA engines on the platform." 22 | << std::endl; 23 | std::cout << "--int8 Run in Int8 mode.\n"; 24 | std::cout << "--fp16 Run in FP16 mode." << std::endl; 25 | } 26 | 27 | samplesCommon::UffSampleParams initial_params(const samplesCommon::Args &args){ 28 | samplesCommon::UffSampleParams params; 29 | if (args.dataDirs.empty()){ 30 | params.dataDirs.push_back("/work/tensorRT/project/Template/Keypoints/data/images/"); 31 | } 32 | else //!< Use the data directory provided by the user 33 | { 34 | params.dataDirs = args.dataDirs; 35 | } 36 | params.uffFileName = "/work/tensorRT/project/Template/Keypoints/data/uff/keypoints.uff"; 37 | params.inputTensorNames.push_back("Placeholder/inputs_x"); 38 | params.batchSize = 1; 39 | params.outputTensorNames.push_back("Keypoints/keypoint_1/conv/Sigmoid"); 40 | params.dlaCore = args.useDLACore; 41 | // params.int8 = args.runInInt8; 42 | params.int8 = false; 43 | // params.fp16 = args.runInFp16; 44 | params.fp16 = false; 45 | return params; 46 | } 47 | 48 | int main(int argc, char **argv){ 49 | REGISTER_TENSORRT_PLUGIN(MyPlugin); 50 | 51 | samplesCommon::Args args; 52 | if (!samplesCommon::parseArgs(args, argc, argv)){ 53 | gLogError << "Invalid arguments" << std::endl; 54 | printHelpInfo(); 55 | return EXIT_FAILURE; 56 | } 57 | if (args.help) 58 | { 59 | printHelpInfo(); 60 | return EXIT_SUCCESS; 61 | } 62 | auto sampleTest = Logger::defineTest(project_name, argc, argv); 63 | Logger::reportTestStart(sampleTest); 64 | samplesCommon::UffSampleParams params = initial_params(args); 65 | InputParams input_params(512, 512, 3, 128, 128, 17); 66 | Keypoints keypoints(params, input_params); 67 | gLogInfo << "Building and running a GPU inference engine for " << project_name 68 | << std::endl; 69 | if (!keypoints.build()) 70 | { 71 | return Logger::reportFail(sampleTest); 72 | } 73 | gLogInfo << "Begine to Infer" 74 | << std::endl; 75 | if (!keypoints.infer()) 76 | { 77 | return 
Logger::reportFail(sampleTest); 78 | } 79 | gLogInfo << "Destroy the engine" 80 | << std::endl; 81 | if (!keypoints.tearDown()) 82 | { 83 | return Logger::reportFail(sampleTest); 84 | } 85 | return Logger::reportPass(sampleTest); 86 | } -------------------------------------------------------------------------------- /tensorRT/c++/data/images/1_0_origin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/1_0_origin.jpg -------------------------------------------------------------------------------- /tensorRT/c++/data/images/1_0_origin_render.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/1_0_origin_render.jpg -------------------------------------------------------------------------------- /tensorRT/c++/data/images/7_0_origin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/7_0_origin.jpg -------------------------------------------------------------------------------- /tensorRT/c++/data/images/7_0_origin_render.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/7_0_origin_render.jpg -------------------------------------------------------------------------------- /tensorRT/c++/source/ResizeNearestNeighbor.cpp: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/11 3 | 4 | #include "ResizeNearestNeighbor.h" 5 | 6 | // <============ 构造函数 ===========> 7 | 
// Deserializing constructor: rebuilds the plugin state from a buffer that
// was produced by serialize().
//   data   - start of the serialized plugin blob
//   length - size of the blob in bytes
UffUpSamplePluginV2::UffUpSamplePluginV2(const void *data, size_t length){
    if (data== nullptr){
        printf("nullptr\n");
    }
    const char *d = static_cast<const char *>(data);
    const char* const start = d;
    // Field order and types must mirror serialize() / getSerializationSize().
    mCHW = read<nvinfer1::Dims>(d);
    mDataType = read<nvinfer1::DataType>(d);
    mScale = read<float>(d);
    mOutputHeight = read<int>(d);
    mOutputWidth = read<int>(d);
    // The quantization scales are only stored for INT8 plugins.
    if (mDataType == nvinfer1::DataType::kINT8){
        mInHostScale = read<float>(d);
        mOutHostScale = read<float>(d);
    }
    // The whole buffer must have been consumed. This has to be a comparison:
    // `d = start + length` is an assignment and passes for any non-null pointer.
    assert(d == start + length);

}
return 0; 75 | } 76 | 77 | size_t UffUpSamplePluginV2::getSerializationSize() const { 78 | size_t serialization_size = 0; 79 | serialization_size += sizeof(nvinfer1::Dims); 80 | serialization_size += sizeof(nvinfer1::DataType); 81 | serialization_size += sizeof(float); 82 | serialization_size += sizeof(int) * 2; 83 | if (mDataType == nvinfer1::DataType::kINT8){ 84 | serialization_size += sizeof(float) * 2; 85 | } 86 | return serialization_size; 87 | } 88 | 89 | void UffUpSamplePluginV2::serialize(void *buffer) const { 90 | char *d = static_cast(buffer); 91 | const char* const start = d; 92 | printf("serialize mScale %f\n", mScale); 93 | write(d, mCHW); 94 | write(d, mDataType); 95 | write(d, mScale); 96 | write(d, mOutputHeight); 97 | write(d, mOutputWidth); 98 | if (mDataType == nvinfer1::DataType::kINT8){ 99 | write(d, mInHostScale); 100 | write(d, mOutHostScale); 101 | } 102 | assert(d == start + getSerializationSize()); 103 | } 104 | 105 | void UffUpSamplePluginV2::destroy() { 106 | delete this; 107 | } 108 | 109 | void UffUpSamplePluginV2::setPluginNamespace(const char *plugin_namespace) { 110 | mNameSpace = plugin_namespace; 111 | } 112 | 113 | const char *UffUpSamplePluginV2::getPluginNamespace() const { 114 | return mNameSpace.data(); 115 | } 116 | 117 | 118 | // <============ IPluginV2Ext ===========> 119 | nvinfer1::DataType 120 | UffUpSamplePluginV2::getOutputDataType(int index, const nvinfer1::DataType *input_types, int num_inputs) const { 121 | assert(index==0); 122 | assert(input_types!= nullptr); 123 | assert(num_inputs==1); 124 | return input_types[index]; 125 | } 126 | 127 | bool UffUpSamplePluginV2::isOutputBroadcastAcrossBatch(int output_index, const bool *input_is_broadcasted, 128 | int num_inputs) const { 129 | return false; 130 | } 131 | 132 | bool UffUpSamplePluginV2::canBroadcastInputAcrossBatch(int input_idx) const { 133 | return false; 134 | } 135 | 136 | nvinfer1::IPluginV2Ext *UffUpSamplePluginV2::clone() const { 137 | auto *plugin = new 
UffUpSamplePluginV2(*this); 138 | return plugin; 139 | } 140 | 141 | 142 | // <============ IPluginV2IOExt ===========> 143 | void UffUpSamplePluginV2::configurePlugin(const nvinfer1::PluginTensorDesc *plugin_tensor_desc_input, int num_input, 144 | const nvinfer1::PluginTensorDesc *plugin_tensor_desc_output, int num_output) { 145 | assert(num_input==1 && plugin_tensor_desc_input!= nullptr); 146 | assert(num_output==1 && plugin_tensor_desc_output != nullptr); 147 | assert(plugin_tensor_desc_input[0].type == plugin_tensor_desc_output[0].type); 148 | assert(plugin_tensor_desc_input[0].format == nvinfer1::TensorFormat::kLINEAR); 149 | assert(plugin_tensor_desc_output[0].format == nvinfer1::TensorFormat::kLINEAR); 150 | 151 | mInHostScale = plugin_tensor_desc_input->scale; 152 | mOutHostScale = plugin_tensor_desc_output->scale; 153 | 154 | mDataType = plugin_tensor_desc_input[0].type; 155 | } 156 | 157 | bool UffUpSamplePluginV2::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *plugin_tensor_desc_in_out, int num_inputs, 158 | int num_outputs) const { 159 | assert(plugin_tensor_desc_in_out != nullptr); 160 | assert(num_inputs == num_outputs == 1); 161 | assert(pos < num_inputs + num_outputs); 162 | bool condition = true; 163 | condition &= plugin_tensor_desc_in_out[pos].format == nvinfer1::TensorFormat::kLINEAR; 164 | condition &= plugin_tensor_desc_in_out[pos].type != nvinfer1::DataType::kINT32; 165 | condition &= plugin_tensor_desc_in_out[pos].type == plugin_tensor_desc_in_out[0].type; 166 | return condition; 167 | } 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /tensorRT/c++/source/ResizeNearestNeighbor.cu: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/12 3 | #include 4 | 
// Launches UpSampleKernel on `stream` with one thread per output element.
// N, C, H, W describe the OUTPUT tensor; the kernel maps every output index back to its
// source element using mScale.
template <typename Dtype>
void UffUpSamplePluginV2::forwardGpu(const Dtype *input, Dtype *output, int N, int C, int H, int W, cudaStream_t stream) {
    const int num_element = N * C * H * W;
    if (num_element <= 0){
        // Nothing to copy; also avoids launching with a zero or negative block count.
        return;
    }
    const int num_blocks = (num_element - 1) / mThreadNum + 1;
    UpSampleKernel<Dtype><<<num_blocks, mThreadNum, 0, stream>>>(input, output, num_element, mScale, C, H, W);
    // A bad launch configuration is only reported through the runtime's last-error state,
    // so check it here the same way enqueue() checks its memcpy.
    HANDLE_ERROR(cudaGetLastError());
}
mCHW.d[0]; 67 | const int input_h = mCHW.d[1]; 68 | const int input_w = mCHW.d[2]; 69 | const int output_h = mOutputHeight; 70 | const int output_w = mOutputWidth; 71 | int total_element = batch_size * channel * input_h * input_w; 72 | if (input_h == output_h && input_w == output_w){ 73 | HANDLE_ERROR(cudaMemcpyAsync(outputs[0], inputs[0], get_size(mDataType) * total_element, cudaMemcpyDeviceToDevice, stream)); 74 | HANDLE_ERROR(cudaStreamSynchronize(stream)); 75 | return 0; 76 | } 77 | switch (mDataType){ 78 | case nvinfer1::DataType::kFLOAT : 79 | forwardGpu((const float *)inputs[0], (float *)outputs[0], batch_size, channel, output_h, output_w, stream); 80 | break; 81 | case nvinfer1::DataType::kHALF : 82 | forwardGpu<__half>((const __half *)inputs[0], (__half *)outputs[0], batch_size, channel, output_h, output_w, stream); 83 | break; 84 | case nvinfer1::DataType::kINT8 : 85 | forwardGpu((const int8_t *)inputs[0], (int8_t *)outputs[0], batch_size, channel, output_h, output_w, stream); 86 | break; 87 | default: 88 | throw "Unsupported Data Type"; 89 | } 90 | return 0; 91 | } -------------------------------------------------------------------------------- /tensorRT/c++/source/ResizeNearestNeighbor.h: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/11 3 | 4 | #ifndef TENSORRT_RESIZENEARESTNEIGHBOR_H 5 | #define TENSORRT_RESIZENEARESTNEIGHBOR_H 6 | 7 | #include 8 | #include 9 | 10 | #include "utils.h" 11 | 12 | class UffUpSamplePluginV2 : public nvinfer1::IPluginV2IOExt{ 13 | private: 14 | nvinfer1::Dims mCHW; 15 | nvinfer1::DataType mDataType; 16 | float mScale; 17 | int mOutputHeight; 18 | int mOutputWidth; 19 | 20 | float mInHostScale{-1.0}; 21 | float mOutHostScale{-1.0}; 22 | 23 | std::string mNameSpace; 24 | const int mThreadNum = sizeof(unsigned long long) * 8 ; 25 | public: 26 | UffUpSamplePluginV2(const nvinfer1::PluginFieldCollection &fc, float 
scale=2.0); 27 | UffUpSamplePluginV2(const void *data, size_t length); 28 | // IPluginV2 29 | const char* getPluginType () const override; 30 | const char *getPluginVersion () const override; 31 | int getNbOutputs () const override; 32 | nvinfer1::Dims getOutputDimensions (int index, const nvinfer1::Dims *inputs_dims, int number_input_dims) override; 33 | int initialize() override; 34 | void terminate () override; 35 | size_t getWorkspaceSize (int max_batch_size) const override; 36 | int enqueue (int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override; 37 | size_t getSerializationSize () const override; 38 | void serialize (void *buffer) const override; 39 | void destroy () override; 40 | void setPluginNamespace (const char *plugin_namespace) override; 41 | const char *getPluginNamespace () const override; 42 | 43 | // IPluginV2Ext 44 | nvinfer1::DataType getOutputDataType (int index, const nvinfer1::DataType *input_types, int num_inputs) const override; 45 | bool isOutputBroadcastAcrossBatch (int output_index, const bool *input_is_broadcasted, int num_inputs) const override; 46 | bool canBroadcastInputAcrossBatch (int input_idx) const override; 47 | IPluginV2Ext * clone () const override; 48 | 49 | // IPluginV2IOExt 50 | void configurePlugin (const nvinfer1::PluginTensorDesc *plugin_tensor_desc_input, int num_input, const nvinfer1::PluginTensorDesc *plugin_tensor_desc_output, int num_output) override; 51 | bool supportsFormatCombination (int pos, const nvinfer1::PluginTensorDesc *inOut, int num_inputs, int num_outputs) const override; 52 | 53 | // Extension 54 | template 55 | void forwardGpu(const Dtype* input,Dtype * outputint ,int N,int C,int H ,int W, cudaStream_t stream); 56 | }; 57 | 58 | 59 | #endif //TENSORRT_RESIZENEARESTNEIGHBOR_H 60 | -------------------------------------------------------------------------------- /tensorRT/c++/source/keypoints_tensorrt.cpp: 
-------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/7 3 | 4 | #include "keypoints_tensorrt.h" 5 | 6 | bool Keypoints::constructNetwork(Keypoints::sample_unique_ptr &parser, 7 | Keypoints::sample_unique_ptr &network) { 8 | assert(uff_params.inputTensorNames.size() == 1); 9 | assert(uff_params.outputTensorNames.size() == 1); 10 | if (!parser -> registerInput(uff_params.inputTensorNames[0].c_str(), nvinfer1::Dims3(image_c, image_h, image_w), nvuffparser::UffInputOrder::kNCHW)){ 11 | gLogError << "Register Input Failed!" << std::endl; 12 | return false; 13 | } 14 | if (!parser -> registerOutput(uff_params.outputTensorNames[0].c_str())){ 15 | gLogError << "Register Output Failed!" << std::endl; 16 | return false; 17 | } 18 | if (!parser -> parse(uff_params.uffFileName.c_str(), *network, nvinfer1::DataType::kFLOAT)){ 19 | gLogError << "Parse Uff Failed!" << std::endl; 20 | return false; 21 | } 22 | if (uff_params.int8){ 23 | samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f); 24 | } 25 | return true; 26 | } 27 | 28 | bool Keypoints::processInput(const samplesCommon::BufferManager &buffer_manager, const std::string &input_tensor_name, 29 | const std::string &image_path) const { 30 | const int input_h = input_dims.d[1]; 31 | const int input_w = input_dims.d[2]; 32 | std::vector file_data = imagePreprocess(image_path, image_h, image_w); 33 | if (file_data.size() != input_h * input_w * image_c){ 34 | gLogError << "FileData size is "<(buffer_manager.getHostBuffer(input_tensor_name)); 38 | for (int i = 0; i < input_h * input_w * image_c; ++i){ 39 | host_input_buffer[i] = static_cast(file_data[i]) / 128.0 - 1; 40 | } 41 | return true; 42 | } 43 | 44 | std::vector> Keypoints::processOutput(const samplesCommon::BufferManager &buffer_manager, const std::string &output_tensor_name) const { 45 | auto *origin_output = 
static_cast(buffer_manager.getHostBuffer(output_tensor_name)); 46 | gLogInfo<< "Output: "<< std::endl; 47 | // Keypoint index transformation idx_x, idx_y, prob 48 | std::vector> keypoints; 49 | for (int c = 0; c < heatmap_c; ++c){ 50 | std::vector keypoint; 51 | int max_idx = -1; 52 | float max_prob = -1; 53 | // for (int idx = heatmap_h * heatmap_w * c; idx < heatmap_h * heatmap_w * (c + 1); ++idx){ 54 | // if (origin_output[idx] > max_prob){ 55 | // max_idx = idx; 56 | // max_prob = origin_output[idx]; 57 | // } 58 | // } 59 | // keypoint.push_back(static_cast(max_idx % heatmap_w) / heatmap_w); 60 | // keypoint.push_back(static_cast((max_idx / heatmap_w) % heatmap_h) / heatmap_h); 61 | // 迷之操作 输入都是kNCHW 输出怎么就是kNHWC了 62 | for (int idx = c; idx < heatmap_c * heatmap_h * heatmap_w; idx+=heatmap_c){ 63 | if (origin_output[idx] > max_prob){ 64 | max_idx = idx; 65 | max_prob = origin_output[idx]; 66 | } 67 | } 68 | keypoint.push_back(static_cast(max_idx / heatmap_c % heatmap_w) / heatmap_w); 69 | keypoint.push_back(static_cast((max_idx / heatmap_c) / heatmap_w) / heatmap_h); 70 | 71 | keypoint.push_back(max_prob); 72 | keypoints.push_back(keypoint); 73 | } 74 | for (int c = 0; c < heatmap_c; c++){ 75 | gLogInfo << "channel "<< c << " ==> x : "<< keypoints[c][0] << " y : " << keypoints[c][1] << " prob : " << keypoints[c][2]<< std::endl; 76 | } 77 | return keypoints; 78 | } 79 | 80 | Keypoints::Keypoints(samplesCommon::UffSampleParams params, InputParams input_params) : uff_params(std::move(params)), image_h(input_params.image_h), image_w(input_params.image_w), image_c(input_params.image_c), heatmap_h(input_params.heatmap_h), heatmap_w(input_params.heatmap_w), heatmap_c(input_params.heatmap_c){ 81 | gLogInfo << "Keypoints Construction" << std::endl; 82 | } 83 | 84 | bool Keypoints::build() { 85 | auto builder = sample_unique_ptr(nvinfer1::createInferBuilder(gLogger.getTRTLogger())); 86 | if (!builder){ 87 | gLogError << "Create Builder Failed" << std::endl; 88 | return 
false; 89 | } 90 | auto network = sample_unique_ptr(builder -> createNetworkV2(0U)); 91 | if (!network){ 92 | gLogError << "Create Network Failed" << std::endl; 93 | return false; 94 | } 95 | auto parser = sample_unique_ptr(nvuffparser::createUffParser()); 96 | if (!parser){ 97 | gLogError << "Create Parser Failed" << std::endl; 98 | return false; 99 | } 100 | if (!constructNetwork(parser, network)){ 101 | gLogError << "Construct Network Failed" << std::endl; 102 | return false; 103 | } 104 | 105 | // 配置config 106 | builder -> setMaxBatchSize(1); 107 | auto config = sample_unique_ptr(builder -> createBuilderConfig()); 108 | if (!config){ 109 | gLogError << "Create Config Failed" << std::endl; 110 | return false; 111 | } 112 | config -> setMaxWorkspaceSize(1_GiB); 113 | config -> setFlag(BuilderFlag::kGPU_FALLBACK); // 可以使用DLA加速 114 | 115 | if (uff_params.fp16){ 116 | config -> setFlag(BuilderFlag::kFP16); 117 | } 118 | if (uff_params.int8){ 119 | config -> setFlag(BuilderFlag::kINT8); 120 | } 121 | samplesCommon::enableDLA(builder.get(), config.get(), uff_params.dlaCore, true); 122 | cuda_engine = std::shared_ptr(builder -> buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter()); 123 | if (!cuda_engine){ 124 | gLogError << "Create Config Failed" << std::endl; 125 | return false; 126 | } 127 | 128 | assert(network -> getNbInputs() == 1); 129 | assert(network -> getNbOutputs() == 1); 130 | input_dims = network -> getInput(0) ->getDimensions(); 131 | assert(input_dims.nbDims == 3); 132 | 133 | gLogInfo << "Build Network Success!" 
<< std::endl; 134 | return true; 135 | } 136 | 137 | bool Keypoints::infer() { 138 | samplesCommon::BufferManager buffer_manager(cuda_engine, uff_params.batchSize); 139 | auto context = sample_unique_ptr(cuda_engine -> createExecutionContext()); 140 | if (!context){ 141 | gLogError << "Create Context Failed" << std::endl; 142 | return false; 143 | } 144 | // 获取问价夹下所有Image图片 145 | std::vector images; 146 | DIR *dir = opendir(uff_params.dataDirs[0].c_str()); 147 | dirent *p = nullptr; 148 | gLogInfo << "Fetch images in " << uff_params.dataDirs[0]<d_name[0] && (strstr(p -> d_name, ".jpg") || strstr(p -> d_name, "png"))){ 153 | std::string imagePath = uff_params.dataDirs[0]+"/"+p->d_name; 154 | gLogInfo<<"--Image : "<d_name< execute(uff_params.batchSize, buffer_manager.getDeviceBindings().data())){ 163 | gLogError<<"Execute Failed!"<(t_end - t_start).count(); 170 | total += elapsed_time; 171 | buffer_manager.copyOutputToHost(); 172 | 173 | // 将输出结果渲染 174 | std::vector> keypoints; 175 | keypoints = processOutput(buffer_manager, uff_params.outputTensorNames[0]); 176 | cv::Mat ori_img = cv::imread(imagePath, cv::IMREAD_COLOR); 177 | cv::Mat render_img = renderKeypoint(ori_img, keypoints, heatmap_c, 0.3); 178 | saveImage(render_img, imagePath.insert(imagePath.length() - 4, "_render")); 179 | 180 | ++count; 181 | } 182 | } 183 | closedir(dir); 184 | gLogInfo<< "Total run time is " << total <<" ms\n"; 185 | gLogInfo<< "Average over " << count << " files run time is "< 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "utils.h" 13 | 14 | 15 | class Keypoints{ 16 | public: 17 | template 18 | using sample_unique_ptr = std::unique_ptr; 19 | private: 20 | std::shared_ptr cuda_engine{nullptr}; 21 | samplesCommon::UffSampleParams uff_params; 22 | nvinfer1::Dims input_dims; 23 | const int image_h; 24 | const int image_w; 25 | const int image_c; 26 | const int heatmap_h; 27 | const int heatmap_w; 28 | const int heatmap_c; 29 | 30 | // 将权重赋给网络 31 | bool 
constructNetwork(sample_unique_ptr &parser, sample_unique_ptr &network); 32 | // 处理输入,读入图片并存入buffer中 33 | bool processInput(const samplesCommon::BufferManager &buffer_manager, const std::string &input_tensor_name, const std::string &image_path) const; 34 | // 输出后处理,得到最终结果 35 | std::vector> processOutput(const samplesCommon::BufferManager &buffer_manager, const std::string &output_tensor_name) const; 36 | 37 | public: 38 | explicit Keypoints(samplesCommon::UffSampleParams uff_params, InputParams input_params); 39 | bool build(); 40 | bool infer(); 41 | bool tearDown(); 42 | }; 43 | 44 | #endif //TENSORRT_METER_TENSORRT_H 45 | -------------------------------------------------------------------------------- /tensorRT/c++/source/my_plugin.cpp: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/11 3 | 4 | #include "my_plugin.h" 5 | 6 | MyPlugin::~MyPlugin() { 7 | for (auto& item : mPluginUpSample){ 8 | item.reset(); 9 | } 10 | } 11 | 12 | 13 | const char *MyPlugin::getPluginName() const { 14 | return "ResizeNearestNeighbor"; 15 | } 16 | 17 | const char *MyPlugin::getPluginVersion() const { 18 | return "2"; 19 | } 20 | 21 | const nvinfer1::PluginFieldCollection* MyPlugin::getFieldNames() { 22 | // TODO 这里应该是依据参数创建PluginField的 23 | return &mFieldCollection; 24 | } 25 | 26 | nvinfer1::IPluginV2 *MyPlugin::createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) { 27 | if(!strcmp(name, "ResizeNearestNeighbor")){ 28 | printf("Unkown Plugin Name %s", name); 29 | return nullptr; 30 | } 31 | mPluginUpSample.emplace_back(std::unique_ptr(new UffUpSamplePluginV2(*fc))); 32 | return mPluginUpSample.back().get(); 33 | 34 | } 35 | 36 | nvinfer1::IPluginV2 *MyPlugin::deserializePlugin(const char *name, const void *serial_data, size_t serial_length) { 37 | auto plugin = new UffUpSamplePluginV2(serial_data, serial_length); 38 | mPluginName = name; 39 | return plugin; 
40 | } 41 | 42 | void MyPlugin::setPluginNamespace(const char *plugin_name_space) { 43 | mNamespace = plugin_name_space; 44 | } 45 | 46 | const char *MyPlugin::getPluginNamespace() const { 47 | return mNamespace.c_str(); 48 | } 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /tensorRT/c++/source/my_plugin.h: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/11 3 | 4 | #ifndef TENSORRT_MY_PLUGIN_H 5 | #define TENSORRT_MY_PLUGIN_H 6 | 7 | #include "ResizeNearestNeighbor.h" 8 | #include 9 | #include 10 | 11 | class MyPlugin : public nvinfer1::IPluginCreator { 12 | private: 13 | std::string mNamespace; 14 | std::string mPluginName; 15 | nvinfer1::PluginFieldCollection mFieldCollection{0, nullptr}; 16 | std::vector> mPluginUpSample{}; 17 | public: 18 | const char* getPluginName() const override; 19 | const char* getPluginVersion() const override; 20 | const nvinfer1::PluginFieldCollection *getFieldNames() override; 21 | nvinfer1::IPluginV2* createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) override; 22 | nvinfer1::IPluginV2* deserializePlugin(const char *name, const void *serial_data, size_t serial_length) override; 23 | void setPluginNamespace (const char *plugin_name_space) override; 24 | const char* getPluginNamespace() const override; 25 | ~MyPlugin(); 26 | }; 27 | 28 | 29 | #endif //TENSORRT_MY_PLUGIN_H 30 | -------------------------------------------------------------------------------- /tensorRT/c++/source/utils.cpp: -------------------------------------------------------------------------------- 1 | // Created by luozhiwang (luozw1994@outlook.com) 2 | // Date: 2020/2/10 3 | #include "utils.h" 4 | 5 | InputParams::InputParams(int ih, int iw, int ic, int hh, int hw, int hc): image_h(ih), image_w(iw), image_c(ic), heatmap_h(hh), heatmap_w(hw), heatmap_c(hc){ 6 | 7 | } 
// Immutable bundle of the network's image input size and heatmap output size.
class InputParams{
public:
    // Arguments, in order: image height/width/channels, then heatmap height/width/channels.
    InputParams(int ih, int iw, int ic, int hh, int hw, int hc);

    // Input image extents.
    const int image_h, image_w, image_c;
    // Output heatmap extents.
    const int heatmap_h, heatmap_w, heatmap_c;
};
// Copies the object representation of `val` into `buffer` and advances `buffer` past it.
// The copy is done byte by byte because `buffer` is not guaranteed to be aligned for T;
// storing through a reinterpret_cast<T*> would be undefined behaviour in that case.
template <typename T>
void write(char*& buffer, const T& val){
    const char *src = reinterpret_cast<const char *>(&val);
    for (unsigned long i = 0; i < sizeof(T); ++i){
        buffer[i] = src[i];
    }
    buffer += sizeof(T);
}

// Reads a T from `buffer` (the inverse of write) and advances `buffer` past it.
// T must be default-constructible and trivially copyable.
template <typename T>
T read(const char*& buffer)
{
    T val;
    char *dst = reinterpret_cast<char *>(&val);
    for (unsigned long i = 0; i < sizeof(T); ++i){
        dst[i] = buffer[i];
    }
    buffer += sizeof(T);
    return val;
}
def readpb2graph(pb_path, log_dir):
    """Export the graph stored in a frozen ``.pb`` file for viewing in TensorBoard.

    The GraphDef is imported into the default graph and written to ``log_dir``
    as an event file, which makes it easier to rebuild the model by hand with
    the TensorRT C++ API.

    :param pb_path: path of the frozen graph (``.pb``) to read
    :param log_dir: directory the TensorBoard event file is written to
    :return: None
    """
    # gfile.FastGFile is deprecated; GFile is the drop-in replacement.
    with gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Session() as sess:
        tf.import_graph_def(graph_def)
        train_writer = tf.summary.FileWriter(log_dir)
        try:
            train_writer.add_graph(sess.graph)
        finally:
            # Closing flushes the pending event to disk; without it the file
            # may be left empty when the process exits.
            train_writer.close()
class TrainHourglass(Trainer):
    """Trainer for the Hourglass keypoint network on COCO.

    Overrides the model, train-op and loader/saver setup of ``Trainer``.
    """

    def __init__(self, model, dataset, cfg):
        super(TrainHourglass, self).__init__(model, dataset, cfg)

    def init_model(self):
        """Build the model with batch-norm moving-average decay set to 0.96."""
        with slim.arg_scope([slim.batch_norm], decay=0.96):
            Trainer.init_model(self)

    def init_train_op(self):
        """Create ``self.train_op``: Adam with gradients clipped to [-5, 5]."""
        start_time = time.time()
        with tf.name_scope("Train_op"):
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            # compute_gradients() yields (None, var) for variables the loss
            # does not depend on; clip_by_value() cannot take None, and
            # apply_gradients() would skip those pairs anyway.
            gvs = [(grad, var)
                   for grad, var in optimizer.compute_gradients(self.loss)
                   if grad is not None]
            clip_gvs = [(tf.clip_by_value(grad, -5., 5.), var)
                        for grad, var in gvs]
            if self.is_debug:
                self.mean_gradient = tf.reduce_mean(
                    [tf.reduce_mean(grad) for grad, _ in gvs])
                tf.summary.scalar("mean_gradient", self.mean_gradient)
                print('Debug mode is on !!!')
            # Run the UPDATE_OPS collection (batch-norm moving averages)
            # together with every training step.
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                self.train_op = optimizer.apply_gradients(
                    clip_gvs, global_step=self.global_step)
        print('-Creat train op in %.3f' % (time.time() - start_time))

    def init_loader_saver(self):
        """Create ``self.saver`` and, if a pre-trained checkpoint is set, ``self.loader``.

        The loader restores every global variable whose name is present in the
        checkpoint, except the global step.
        """
        start_time = time.time()
        with tf.name_scope('loader_and_saver'):
            if self.pre_trained_ckpt is not None:
                from tensorflow.python import pywrap_tensorflow
                reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_ckpt)
                # Set for O(1) membership tests.
                names_in_ckpt = set(reader.get_variable_to_shape_map())
                # var.op.name is the variable name without the ":0" output
                # suffix. str.strip(':0') removes *characters*, so it would
                # turn "x_10:0" into "x_1" and mismatch the checkpoint.
                var_ = [var for var in tf.global_variables()
                        if var.op.name in names_in_ckpt
                        and var.op.name != "Learning_rate/global_step"]
                print('restore var total is %d' % len(var_))
                self.loader = tf.train.Saver(var_list=var_)
            self.saver = tf.train.Saver(
                var_list=tf.global_variables(),
                max_to_keep=self.max_keep)
        print(
            '-Creat loader saver in %.3f' %
            (time.time() - start_time))
class TrainHourglass(Trainer):
    """Trainer for the Hourglass network on MPII; only the launch sequence is customised."""

    def __init__(self, model, dataset, cfg):
        super().__init__(model, dataset, cfg)

    def train_launch(self):
        """Disable debug mode, then run every setup stage in order and start training."""
        self.is_debug = False
        stage_names = (
            # these three must run first, in this order
            'init_dataset',
            'init_inputs',
            'init_model',
            # stages a subclass may override
            'init_loss',
            'init_learning_rate',
            'init_train_op',
            'init_loader_saver_summary',
            'init_session',
            'train',
        )
        for stage_name in stage_names:
            # Look each stage up right before calling it, as a sequence of
            # direct method calls would.
            getattr(self, stage_name)()