├── README.md
├── checkpoints
│   ├── coco
│   │   └── readme.txt
│   └── mpii
│       └── readme.txt
├── config
│   ├── config_hourglass_coco.py
│   └── config_hourglass_mpii.py
├── core
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── data_augment.py
│   │   └── data_generator.py
│   ├── infer
│   │   ├── __init__.py
│   │   ├── freeze_graph.py
│   │   ├── infer_utils.py
│   │   └── visual_utils.py
│   ├── loss
│   │   ├── __init__.py
│   │   └── loss.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── keypoints.py
│   │   └── network_utils.py
│   └── train
│       ├── __init__.py
│       └── trainer.py
├── data
│   ├── dataset
│   │   └── readme.txt
│   └── name
│       └── coco.name
├── demon.py
├── infer_hourglass.py
├── output
│   ├── coco
│   │   └── readme.txt
│   └── mpii
│       └── readme.txt
├── script
│   ├── __init__.py
│   ├── ckpt2ckpt.py
│   ├── coco2txt.py
│   ├── mpii2coco.py
│   └── parse_ckpt.py
├── tensorRT
│   ├── c++
│   │   ├── CMakeLists.txt
│   │   ├── Keypoints_main.cpp
│   │   ├── data
│   │   │   └── images
│   │   │       ├── 1_0_origin.jpg
│   │   │       ├── 1_0_origin_render.jpg
│   │   │       ├── 7_0_origin.jpg
│   │   │       └── 7_0_origin_render.jpg
│   │   └── source
│   │       ├── ResizeNearestNeighbor.cpp
│   │       ├── ResizeNearestNeighbor.cu
│   │       ├── ResizeNearestNeighbor.h
│   │       ├── keypoints_tensorrt.cpp
│   │       ├── keypoints_tensorrt.h
│   │       ├── my_plugin.cpp
│   │       ├── my_plugin.h
│   │       ├── utils.cpp
│   │       └── utils.h
│   └── python
│       ├── __init__.py
│       ├── pb2uff.py
│       ├── readpb2graph.py
│       └── tfpb2trtpb.py
├── train_hourglass_coco.py
└── train_hourglass_mpii.py
/README.md:
--------------------------------------------------------------------------------
1 | # Keypoint Detection In Tensorflow and TensorRT C++
2 | ## 1. Modified Hourglass (Hourglass-104) and ResNet-101
3 |
4 | ### Introduction
5 | This project provides training and accelerated-inference code for keypoint detection. The training part is implemented in Python 3 + TensorFlow 1.14, and the inference part in C++ + TensorRT 6.
6 | The main training dataset is COCO, and the model is Hourglass.
7 |
8 | ### Quick Start
9 | * python3 train_hourglass_coco.py
10 | * python3 core/infer/freeze_graph.py -C 0 -c checkpoints/coco/Hourglass_coco.ckpt -o Hourglass.pb
11 | * python3 demon.py
12 |
13 | ### Checkpoints
14 | https://drive.google.com/drive/folders/1pjOH1XUQOuMXlfGddQPvVEjXaXXPU7u1?usp=sharing
15 |
16 | ### Data Format
17 | To train on your own dataset, first convert the annotations into the following format:
18 | (filename1 bxmin,bymin,bxmax,bymax px,py px,py ...)
19 | If multiple points share the same label:
20 | (filename1 bxmin,bymin,bxmax,bymax px,py|px,py px,py ...)
21 | (filename2 bxmin,bymin,bxmax,bymax px,py|px,py px,py|px,py ...)
22 | ...
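For example, a line with a hypothetical filename and coordinates, where the second label has two points and the third point is missing (-1,-1 marks a missing point, as handled in core/dataset/data_generator.py):
(img_001.jpg 34,50,420,512 120,88 130,95|142,97 -1,-1)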
23 |
24 |
25 | ### Inference
26 | The APIs in core/infer/infer_utils.py can be used to build a simple inference model; wrapping it with Flask gives a basic online inference service. A usage example is in infer_hourglass.py. Note that the bounding boxes (bbx) must come from a separate detection model. A minimal sketch is shown below.
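A minimal sketch, assuming a frozen `Hourglass.pb` produced by `core/infer/freeze_graph.py` (the tensor names below are the ones printed when the graph is built/frozen; the image path and box are illustrative, and the Flask wrapper is omitted):

```python
import cv2
from core.infer.infer_utils import read_pb, pred_one_image

sess, inputs, outputs = read_pb('Hourglass.pb',
                                ['Placeholder/inputs_x:0'],
                                ['Keypoints/keypoint_1/conv/Sigmoid:0'])

image = cv2.imread('person.jpg')   # BGR image, as image_process expects
bbxes = [[50, 40, 420, 460]]       # person boxes from a separate detector
points = pred_one_image(image, bbxes, sess, inputs, outputs)
# points[i]: one [x, y, score] triple per keypoint class, in original-image coordinates
```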
27 |
28 | ### Notes
29 | The TensorRT part has been moved to a new repository:
30 | [https://github.com/Syencil/tensorRT](https://github.com/Syencil/tensorRT)
31 |
32 | ## 2. TensorRT
33 | ### Introduction
34 | This part uses CUDA 10 + TensorRT 6 for the inference stage, which accelerates model inference and supports FP32 and FP16.
35 | ### Quick Start
36 | * 1. Convert pb to uff
37 | * cd tensorRT/python
38 | * python3 pb2uff.py
39 | * 2. Build the C++ code
40 | * cd tensorRT/c++
41 | * cmake .
42 | * make
43 |
44 |
45 | ## Unfinished Work
46 | * ~~1. Data augmentation: the image-rotation augmentation was buggy; this and other augmentation methods will be added to the project as soon as possible~~
47 | * ~~2. Implementation of the upsample plugin in TensorRT C++; the framework is already in place and will be updated soon~~
48 | * ~~3. Build this year's popular anchor-free detector CenterNet (Objects as Points) on top of Hourglass-101~~
49 | * ~~4. The TensorRT C++ data pre-processing differs slightly from the Python version; the impact is minor, so it is left as-is.~~
50 | * ~~5. INT8 quantization calibration, to be updated when time allows~~
51 |
52 |
--------------------------------------------------------------------------------
/checkpoints/coco/readme.txt:
--------------------------------------------------------------------------------
1 | directory contains checkpoints
--------------------------------------------------------------------------------
/checkpoints/mpii/readme.txt:
--------------------------------------------------------------------------------
1 | directory contains checkpoints
--------------------------------------------------------------------------------
/config/config_hourglass_coco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-11
8 | """
9 |
10 | # HARDWARE
11 | CUDA_VISIBLE_DEVICES = '2'
12 | CUDA_VISIBLE_DEVICES_INFER = '1'
13 | MULTI_THREAD_NUM = 4
14 | MULTI_GPU = [0]
15 |
16 | # PATH
17 | dataset_dir = '/data/dataset/coco'
18 | train_image_dir = '/data/dataset/coco/images/train2017'
19 | val_image_dir = '/data/dataset/coco/images/val2017'
20 | train_list_path = 'data/dataset/coco/coco_train.txt'
21 | val_list_path = 'data/dataset/coco/coco_val.txt'
22 |
23 | log_dir = 'output/coco'
24 | ckpt_dir = '/data/checkpoints/coco'
25 |
26 | # AUGMENT
27 | augment = {
28 | "color_jitter": 0.5,
29 | "crop": (0.5, 0.9),
30 | "rotate": (0.5, 15),
31 | "ver_flip": 0,
32 | "hor_flop": 0,
33 | }
34 |
35 | # NETWORK
36 | backbone = "hourglass"
37 | loss_mode = 'focal' # focal, sigmoid, softmax, mse
38 | image_size = (512, 512)
39 | stride = 4
40 | heatmap_size = (128, 128)
41 | num_block = 2
42 | num_depth = 5
43 | residual_dim = [256, 384, 384, 384, 512]
44 |
45 | is_maxpool = False
46 | is_nearest = True
47 |
48 | # SAVER AND LOADER
49 | max_keep = 30
50 | pre_trained_ckpt = None
51 | ckpt_name = backbone + "_coco" + '.ckpt'
52 |
53 | # TRAINING
54 | batch_size = 16
55 | learning_rate_init = 1e-3
56 | learning_rate_warmup = 2.5e-4
57 | exp_decay = 0.97
58 |
59 | warmup_epoch_size = 0
60 | epoch_size = 40
61 | summary_per = 20
62 | save_per = 2500
63 |
64 | regularization_weight = 5e-4
65 |
66 |
67 |
68 |
69 | # VAL
70 | val_per = 2500
71 | val_time = 20
72 | val_rate = 0.1
73 |
74 | # TEST
75 |
76 |
--------------------------------------------------------------------------------
/config/config_hourglass_mpii.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-11
8 | """
9 |
10 | # HARDWARE
11 | CUDA_VISIBLE_DEVICES = '3'
12 | CUDA_VISIBLE_DEVICES_INFER = '0'
13 | MULTI_THREAD_NUM = 4
14 | MULTI_GPU = [0]
15 |
16 | # PATH
17 | dataset_dir = 'data/dataset/mpii'
18 | train_image_dir = 'data/dataset/mpii/images'
19 | val_image_dir = 'data/dataset/mpii/images'
20 | train_list_path = 'data/dataset/mpii/mpii_train.txt'
21 | val_list_path = 'data/dataset/mpii/mpii_train.txt'
22 |
23 | log_dir = 'output/mpii'
24 | ckpt_dir = 'checkpoints/mpii'
25 |
26 |
27 | # AUGMENT
28 | augment = {
29 | "color_jitter": 0.5,
30 | "crop" : (0.5, 0.9),
31 | "rotate": (0.5,30),
32 | "ver_flip": 0,
33 | "hor_flop": 0.5,
34 | }
35 |
36 | # NETWORK
37 | backbone = "hourglass"
38 | loss_mode = 'focal' # focal, sigmoid, softmax, mse
39 | image_size = (512, 512)
40 | stride = 4
41 | heatmap_size = (128, 128)
42 | num_block = 2
43 | num_depth = 5
44 | residual_dim = [256, 384, 384, 384, 512]
45 |
46 | is_maxpool = False
47 | is_nearest = True
48 |
49 | # SAVER AND LOADER
50 | max_keep = 30
51 | pre_trained_ckpt = None
52 | ckpt_name = backbone + "_mpii" + ".ckpt"
53 |
54 | # TRAINING
55 | batch_size = 8
56 | learning_rate_init = 2.5e-4
57 | learning_rate_warmup = 1e-4
58 | momentum = 0.9
59 | exp_decay = 0.97  # required by core/train/trainer.py; value copied from the COCO config
60 | warmup_epoch_size = 1
61 | epoch_size = 60
62 | summary_per = 20
63 | save_per = 5000
64 |
65 | regularization_weight = 5e-4
66 |
67 | # VAL
68 | val_per = 200
69 | val_time = 20
70 | val_rate = 0.1
71 |
72 | # TEST
73 |
74 |
--------------------------------------------------------------------------------
/core/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/core/dataset/__init__.py
--------------------------------------------------------------------------------
/core/dataset/data_augment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-10-08
8 | """
9 | import numpy as np
10 | from albumentations import (
11 | KeypointParams,
12 | HorizontalFlip,
13 | VerticalFlip,
14 | RandomCrop,
15 | Compose,
16 | ShiftScaleRotate,
17 | RandomBrightnessContrast,
18 | HueSaturationValue,
19 | Resize
20 | )
21 |
22 |
23 | def image_augment_with_keypoint(image, keypoints, color_jitter=0.5,
24 |         crop=(0.5, 0.8), rotate=(0.5, 30), ver_flip=0, hor_flop=0.5):
25 |     """
26 |     Augment an image together with its keypoints via albumentations.
27 |     :param image: (np.array) HxWxC image
28 |     :param keypoints: [num_class][num_point][x, y]; [-1, -1] marks a missing point
29 |     :param color_jitter: (float) probability of brightness/contrast and HSV jitter
30 |     :param crop: (float, float) probability and side ratio of the random crop
31 |     :param rotate: (float, int) probability and degree limit of shift-scale-rotate
32 |     :param ver_flip: (float) probability of a vertical flip
33 |     :param hor_flop: (float) probability of a horizontal flip
34 |     :return: augmented image, updated keypoints
35 |     """
36 |     image_h, image_w = image.shape[0:2]
37 |     keypoints = np.clip(keypoints, None, max(image_w - 1, image_h - 1))
38 |     points_ = []
39 |     idx_ = []
40 |     for i, ps in enumerate(keypoints):
41 |         for j, p in enumerate(ps):
42 |             if p[0] >= 0 and p[1] >= 0:
43 |                 points_.append(p)
44 |                 idx_.append([i, j])
45 | 
46 |     def get_aug(aug):
47 |         return Compose(aug, keypoint_params=KeypointParams(format="xy"))
48 | 
49 |     aug = get_aug([VerticalFlip(p=ver_flip),
50 |                    HorizontalFlip(p=hor_flop),
51 |                    RandomCrop(p=crop[0],
52 |                               height=int(image_h * crop[1]),
53 |                               width=int(image_w * crop[1])),
54 |                    ShiftScaleRotate(p=rotate[0], rotate_limit=rotate[1]),
55 |                    RandomBrightnessContrast(p=color_jitter),
56 |                    HueSaturationValue(p=color_jitter),
57 |                    Resize(p=1, height=image_h, width=image_w)])
58 |     augmented = aug(image=image, keypoints=points_)
59 | 
60 |     # NOTE: KeypointParams defaults to remove_invisible=True, so keypoints that
61 |     # leave the frame (e.g. after the crop) are dropped from augmented["keypoints"];
62 |     # the index mapping below can then go out of sync and keep stale coordinates.
63 |     for i in range(len(augmented["keypoints"])):
64 |         keypoints[idx_[i][0]][idx_[i][1]] = list(augmented["keypoints"][i])
65 | 
66 |     return augmented["image"], keypoints
61 |
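if __name__ == '__main__':
    # Minimal usage sketch (illustrative): the image path and coordinates are
    # hypothetical; keypoints are [num_class][num_point][x, y], with [-1, -1]
    # marking missing points, matching core/dataset/data_generator.py.
    import cv2
    img = cv2.imread('test.jpg')
    kps = [[[120, 80]], [[-1, -1]], [[300, 240]]]
    aug_img, aug_kps = image_augment_with_keypoint(img, kps)
    print(aug_img.shape, aug_kps)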
--------------------------------------------------------------------------------
/core/dataset/data_generator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-16
8 | """
9 | import random
10 | import time
11 | import cv2
12 | import numpy as np
13 | import os
14 | import copy
15 | from core.dataset.data_augment import image_augment_with_keypoint
16 |
17 |
18 | class Dataset():
19 | def __init__(self, image_dir, gt_path, batch_size,
20 | augment=None, image_size=(512, 512), heatmap_size=(128, 128)):
21 | """
22 | Wrapper for key-points detection dataset
23 | :param image_dir: (str) image dir
24 | :param gt_path: (str) data file eg. train.txt or val.txt, etc
25 | :param batch_size: (int) batch size
26 | :param image_size: (int, int) height, width
27 |         :param heatmap_size: (int, int) height, width. Must divide image_size evenly
28 | """
29 |         # The dataset is too large to read into memory at once, and tf.data.Dataset is awkward here,
30 |         # so annotation info is read from txt; reading supports multi-threaded acceleration
31 | self.gt_path = gt_path
32 | self.image_dir = image_dir
33 | self.image_size = image_size
34 | self.heatmap_size = heatmap_size
35 | self.batch_size = batch_size
36 | self.augment = augment
37 |
38 | self.data_set = self.creat_set_from_txt()
39 | # self.transform_image_set_abs_to_rel()
40 |
41 | self.num_data = len(self.data_set)
42 | self.num_class = len(self.data_set[0][2])
43 | self.stride = self.image_size[0] // self.heatmap_size[0]
44 | self.ratio = self.image_size[0] / self.image_size[1]
45 |
46 | self._pre = -self.batch_size
47 |
48 | def creat_set_from_txt(self):
49 | """
50 |         supports multiple points per label;
51 |         reads image info and ground truth into memory
52 | :return: [[(str) image_name, [(int) xmin, (int) ymin, (int) xmax, (int) ymax], [[(int) px, (int) py]]]]
53 | """
54 | image_set = []
55 | t0 = time.time()
56 | count = 0
57 |
58 | for line in open(self.gt_path, 'r').readlines():
59 |             if not line.strip():
60 |                 continue
61 | count += 1
62 | if count % 5000 == 0:
63 | print("--parse %d " % count)
64 | b = line.split()[1].split(',')
65 | points = line.split()[2:]
66 | tmp = []
67 | for point in points:
68 | tmp.append([[round(float(x)) for x in y.split(",")]
69 | for y in point.split('|')])
70 | image_set.append(
71 | (line.split()[0], [round(float(x)) for x in b], tmp))
72 | print('-Set has been created in %.3fs' % (time.time() - t0))
73 | return image_set
74 |
75 | def sample_batch_image_random(self):
76 | """
77 | sample data (infinitely)
78 | :return: list
79 | """
80 | return random.sample(self.data_set, self.batch_size)
81 | # return self.data_set[:self.batch_size]
82 |
83 | def sample_batch_image_order(self):
84 | """
85 | sample data in order (one shot)
86 | :return: list
87 | """
88 | self._pre += self.batch_size
89 | if self._pre >= self.num_data:
90 | raise StopIteration
91 | _last = self._pre + self.batch_size
92 | if _last > self.num_data:
93 | _last = self.num_data
94 | return self.data_set[self._pre:_last]
95 |
96 | def make_guassian(self, height, width, sigma=3, center=None):
97 | x = np.arange(0, width, 1, float)
98 | y = np.arange(0, height, 1, float)[:, np.newaxis]
99 | if center is None:
100 | x0 = width // 2
101 | y0 = height // 2
102 | else:
103 | x0 = center[0]
104 | y0 = center[1]
105 |         return np.exp(-4. * np.log(2.) * ((x - x0) ** 2 + (y - y0) ** 2) / sigma ** 2)
107 |
108 | def generate_hm(self, joints, heatmap_h_w):
109 | num_joints = len(joints)
110 | hm = np.zeros([heatmap_h_w[0], heatmap_h_w[1],
111 | num_joints], dtype=np.float32)
112 | for i in range(num_joints):
113 | for joint in joints[i]:
114 | if joint[0] != -1 and joint[1] != -1:
115 |                     s = int(np.sqrt(heatmap_h_w[0]) * heatmap_h_w[1] * 10 / 4096) + 2
118 | gen_hm = self.make_guassian(heatmap_h_w[0], heatmap_h_w[1], sigma=s,
119 | center=[joint[0] // self.stride, joint[1] // self.stride])
120 | hm[:, :, i] = np.maximum(hm[:, :, i], gen_hm)
121 | return hm
122 |
123 | def _crop_image_with_pad_and_resize(self, image, bbx, points, ratio=0.05):
124 | image_h, image_w = image.shape[0:2]
125 | crop_bbx = copy.deepcopy(bbx)
126 | crop_points = copy.deepcopy(points)
127 |
128 | w = bbx[2] - bbx[0] + 1
129 | h = bbx[3] - bbx[1] + 1
130 | # keep 5% blank for edge
131 | crop_bbx[0] = int(bbx[0] - w * ratio)
132 | crop_bbx[1] = int(bbx[1] - h * ratio)
133 | crop_bbx[2] = int(bbx[2] + w * ratio)
134 | crop_bbx[3] = int(bbx[3] + h * ratio)
135 | # clip value from 0 to len-1
136 | crop_bbx[0] = 0 if crop_bbx[0] < 0 else crop_bbx[0]
137 | crop_bbx[1] = 0 if crop_bbx[1] < 0 else crop_bbx[1]
138 | crop_bbx[2] = image_w - 1 if crop_bbx[2] > image_w - 1 else crop_bbx[2]
139 | crop_bbx[3] = image_h - 1 if crop_bbx[3] > image_h - 1 else crop_bbx[3]
140 | # crop the image
141 |         crop_image = image[crop_bbx[1]: crop_bbx[3] + 1,
142 |                            crop_bbx[0]: crop_bbx[2] + 1, :]
143 | # update width and height
144 | w = crop_bbx[2] - crop_bbx[0] + 1
145 | h = crop_bbx[3] - crop_bbx[1] + 1
146 | # keep aspect ratio
147 |
148 | ih, iw = self.image_size
149 |
150 | scale = min(iw / w, ih / h)
151 | nw, nh = int(scale * w), int(scale * h)
152 | image_resized = cv2.resize(crop_image, (nw, nh))
153 |
154 | image_paded = np.full(shape=[ih, iw, 3], fill_value=128, dtype=np.uint8)
155 | dw, dh = (iw - nw) // 2, (ih - nh) // 2
156 | image_paded[dh:nh + dh, dw:nw + dw, :] = image_resized
157 | for i in range(len(points)):
158 | for j, point in enumerate(points[i]):
159 | if point[0] != -1 and point[1] != -1:
160 | crop_points[i][j][0] = (point[0] - crop_bbx[0]) * scale + dw
161 | crop_points[i][j][1] = (point[1] - crop_bbx[1]) * scale + dh
162 |
163 | return image_paded, crop_points
164 |
165 | def _one_image_and_heatmap(self, image_set):
166 | """
167 | process only one image
168 | :param image_set: [image_name, bbx, [points]]
169 | :return: (narray) image_h_w x C, (narray) heatmap_h_w x C'
170 | """
171 | image_name, bbx, point = image_set
172 | image_path = os.path.join(self.image_dir, image_name)
173 | img = cv2.imread(image_path)
174 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
175 | img, point = self._crop_image_with_pad_and_resize(img, bbx, point)
176 | if self.augment is not None:
177 | img, point = image_augment_with_keypoint(img, point)
178 | hm = self.generate_hm(point, self.heatmap_size)
179 | return img, hm
180 |
181 | def iterator(self, max_worker=None, is_oneshot=False):
182 | """
183 | Wrapper for batch_data processing
184 | transform data from txt to imgs and hms
185 | (Option) utilize multi thread acceleration
186 | generator images and heatmaps infinitely or make oneshot
187 | :param max_worker: (optional) (int) max worker for multi-thread
188 | :param is_oneshot: (optional) (bool) if False, generator will sample infinitely.
189 | :return: iterator. imgs, hms = next(iterator)
190 | """
191 | if is_oneshot:
192 | sample_fn = self.sample_batch_image_order
193 | else:
194 | sample_fn = self.sample_batch_image_random
195 |         if max_worker != 0:
196 | from concurrent.futures import ThreadPoolExecutor, as_completed
197 | with ThreadPoolExecutor(max_worker) as executor:
198 | while True:
199 | image_set = sample_fn()
200 | imgs = []
201 | hms = []
202 | if executor is None:
203 | for i in range(len(image_set)):
204 | img, hm = self._one_image_and_heatmap(image_set[i])
205 | imgs.append(img)
206 | hms.append(hm)
207 | else:
208 | all_task = [
209 | executor.submit(
210 | self._one_image_and_heatmap,
211 | image_set[i]) for i in range(
212 | len(image_set))]
213 | for future in as_completed(all_task):
214 | imgs.append(future.result()[0])
215 | hms.append(future.result()[1])
216 | final_imgs = np.stack(imgs, axis=0)
217 | final_hms = np.stack(hms, axis=0)
218 | yield final_imgs, final_hms
219 | else:
220 | while True:
221 | image_set = sample_fn()
222 | imgs = []
223 | hms = []
224 | for i in range(len(image_set)):
225 | img, hm = self._one_image_and_heatmap(image_set[i])
226 | imgs.append(img)
227 | hms.append(hm)
228 | final_imgs = np.stack(imgs, axis=0)
229 | final_hms = np.stack(hms, axis=0)
230 | yield final_imgs, final_hms
231 |
232 |
233 | if __name__ == '__main__':
234 |
235 | from core.infer.visual_utils import visiual_image_with_hm
236 | import config.config_hourglass_coco as cfg
237 | image_dir = cfg.val_image_dir
238 | gt_path = "../../"+cfg.val_list_path
239 | render_path = '../../render_img'
240 |
241 | ite = 3
242 | batch_size = 16
243 |
244 | coco = Dataset(image_dir, gt_path, batch_size, augment=cfg.augment)
245 | it = coco.iterator(0, True)
246 |
247 | t0 = time.time()
248 | for i in range(ite):
249 | b_img, b_hm = next(it)
250 | for j in range(batch_size):
251 | img = b_img[j][:, :, ::-1]
252 | hm = b_hm[j]
253 | img_hm = visiual_image_with_hm(img, hm)
254 |             cv2.imwrite('../../render_img/%d_%d_img_hm.jpg' % (i, j), img_hm)
261 |
262 | print(time.time() - t0)
263 |
--------------------------------------------------------------------------------
/core/infer/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-21
8 | """
9 |
--------------------------------------------------------------------------------
/core/infer/freeze_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-24
8 | """
9 |
10 | import argparse
11 | import tensorflow as tf
12 | from tensorflow.python.framework import graph_util
13 | # from tensorflow.contrib.tensorrt import trt_convert as trt
14 | import os
15 | import sys
16 |
17 | sys.path.append('.')
18 |
19 |
20 | def parse_arg():
21 | parse = argparse.ArgumentParser()
22 | parse.add_argument(
23 | '-C',
24 | '--CUDA',
25 | dest='CUDA',
26 | default=None,
27 | help='CUDA_VISIBLE_DEVICE')
28 | parse.add_argument(
29 | '-c',
30 | '--ckpt',
31 | dest='ckpt',
32 | default=None,
33 | help='Freeze ckpt path')
34 | parse.add_argument(
35 | '-o',
36 | '--output',
37 | dest='output_graph',
38 | default=None,
39 | help='Output graph path')
40 |     parse.add_argument(
41 |         '-t',
42 |         '--is_training',
43 |         dest='is_training',
44 |         default=False,
45 |         help='Whether BN should use batch statistics (True) or the frozen moving averages (False)')
46 |     return parse.parse_args()
47 |
48 |
49 | def freeze_graph(input_checkpoint, output_graph, is_training=False):
50 | '''
51 | :param input_checkpoint:
52 |     :param output_graph: path to save the frozen PB model
53 | :param is_training: Is BN using moving-mean and moving-var
54 | :return:
55 | '''
56 |     # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check the ckpt state in the directory
57 |     # input_checkpoint = checkpoint.model_checkpoint_path  # get the ckpt file path
58 |
59 | if is_training:
60 | saver = tf.train.import_meta_graph(
61 | input_checkpoint + '.meta', clear_devices=True)
62 | else:
63 | from core.network.keypoints import Keypoints
64 | import config.config_hourglass_coco as config
65 | model = Keypoints(tf.placeholder(name="Placeholder/inputs_x", dtype=tf.float32, shape=[None, 512, 512, 3]),
66 | 17,
67 | num_block=config.num_block,
68 | num_depth=config.num_depth,
69 | residual_dim=config.residual_dim,
70 | is_training=False,
71 | is_maxpool=config.is_maxpool,
72 | is_nearest=config.is_nearest
73 | )
74 | saver = tf.train.Saver(var_list=tf.global_variables())
75 |
76 |     # Specify the output node names; they must exist in the original graph
77 | print('Freeze graph')
78 | output_node_names = ["Keypoints/keypoint_1/conv/Sigmoid"]
79 | print(output_node_names)
80 |
81 | with tf.Session() as sess:
82 | # sess.run(tf.global_variables_initializer())
83 |         saver.restore(sess, input_checkpoint)  # restore the graph and its weights
84 |         output_graph_def = graph_util.convert_variables_to_constants(  # freeze the model: fix variable values as constants
85 |             sess=sess,
86 |             input_graph_def=sess.graph_def,
87 |             output_node_names=output_node_names)
88 | 
89 |         with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
90 |             f.write(output_graph_def.SerializeToString())  # serialize the graph
91 |         print("%d ops in the final graph." %
92 |               len(output_graph_def.node))  # number of op nodes in the final graph
93 |
94 |
95 | if __name__ == '__main__':
96 | args = parse_arg()
97 | if args.CUDA is not None:
98 | os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA
99 |     freeze_graph(args.ckpt, args.output_graph, args.is_training)
100 |
--------------------------------------------------------------------------------
/core/infer/infer_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-21
8 | """
9 | import os
10 | import cv2
11 | import time
12 | import numpy as np
13 | import tensorflow as tf
14 |
15 |
16 | def read_pb_infer(pb_path, input_node_name_and_val, output_node_name):
17 | """
18 |     Run a frozen graph once and return the fetched outputs.
19 | :param pb_path:
20 | :param input_node_name_and_val: {(str) input_node_name: (any) input_node_val}
21 | :param output_node_name: [(str) output_node_name]
22 |     :return: [np.array] one array per output node (here: heatmaps B x H x W x num_class)
23 | """
24 | with tf.Graph().as_default():
25 | output_graph_def = tf.GraphDef()
26 | with open(pb_path, 'rb') as f:
27 | output_graph_def.ParseFromString(f.read())
28 | tf.import_graph_def(output_graph_def, name='')
29 |             config = tf.ConfigProto(allow_soft_placement=True)  # let TF fall back to a supported device automatically
30 | config.gpu_options.allow_growth = True
31 | with tf.Session(config=config) as sess:
32 | # sess.run(tf.global_variables_initializer())
33 |             # Define the input tensor names, matching the network's input tensors,
34 |             # e.g. input:0 is the input image, keep_prob:0 the dropout keep rate (1 at test time), is_training:0 the training flag
35 | feed_dict = {}
36 | for key in input_node_name_and_val:
37 | input_tensor = sess.graph.get_tensor_by_name(key)
38 | feed_dict[input_tensor] = input_node_name_and_val[key]
39 |
40 |             # Define the output tensor names
41 | output_tensor = []
42 | for name in output_node_name:
43 | output_tensor.append(sess.graph.get_tensor_by_name(name))
44 |
45 |             # Check that the loaded model works; note these are tensor names of the input/output nodes, not op names
46 | start_time = time.time()
47 | output = sess.run(output_tensor, feed_dict=feed_dict)
48 | print('Infer time is %.4f' % (time.time() - start_time))
49 | return output
50 |
51 |
52 | def read_pb(pb_path, input_name, output_name):
53 | """
54 | Instantiation Session
55 | :param pb_path: (str) pb file path
56 | :param input_name: [(str)] input tensor names
57 | :param output_name: [(str)] output tensor names
58 | :return: (tf.Session) sess, (Tensor) input, (Tensor) output
59 | """
60 | # return sess
61 | with tf.Graph().as_default():
62 | output_graph_def = tf.GraphDef()
63 | with open(pb_path, 'rb') as f:
64 | output_graph_def.ParseFromString(f.read())
65 | tf.import_graph_def(output_graph_def, name='')
66 |     config = tf.ConfigProto(allow_soft_placement=True)  # let TF fall back to a supported device automatically
67 | config.gpu_options.allow_growth = True
68 | sess = tf.Session(config=config)
69 |
70 |     if not isinstance(input_name, (list, tuple)):
71 | input_name = [input_name]
72 | input_tensor = []
73 | output_tensor = []
74 | for i in range(len(input_name)):
75 | input_tensor.append(sess.graph.get_tensor_by_name(input_name[i]))
76 | for i in range(len(output_name)):
77 | output_tensor.append(sess.graph.get_tensor_by_name(output_name[i]))
78 |
79 | return sess, input_tensor, output_tensor
80 |
81 |
82 | def pb_infer(sess, output_tensor, input_tensor=None, input_val=None):
83 | """
84 | get output
85 | :param sess: (tf.Session) sess
86 | :param output_tensor: [(Tensor)]
87 | :param input_tensor: [(Tensor)]
88 | :param input_val: [(np.array)]
89 | :return:
90 | """
91 | feed_dict = {}
92 | if input_tensor is not None and input_val is not None:
93 | for i in range(len(input_tensor)):
94 | feed_dict[input_tensor[i]] = input_val[i]
95 |
96 | return sess.run(output_tensor, feed_dict)
97 |
98 |
99 | def image_process(image, bbx):
100 | """
101 | image pre-process
102 | :param image: (str) image_path / (np.array) image in BGR
103 | :param bbx: [(int) xmin, (int) ymin, (int) xmax, (int) ymax]
104 | :return: input_image, bbx
105 | """
106 | if type(image) == str:
107 | image = cv2.imread(image)
108 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
109 | crop_image, crop_bbx = crop_with_padding_and_resize(image, bbx)
110 |     # cv2.imwrite("/work/meter_recognition/render_img/crop.jpg", crop_image)  # debug leftover
111 |     image_norm = crop_image / 127 - 1  # normalize to roughly [-1, 1]
112 | return image_norm, crop_bbx
113 |
114 |
115 | def crop_with_padding_and_resize(image, bbx, shape=(512, 512), ratio=0.2):
116 | """
117 | image pre-process
118 | :param image: image path or BGR image
119 | :param bbx: [xmin, ymin, xmax, ymax]
120 | :param shape: output image shape
121 |     :param ratio: (float) fraction of the box size kept as margin on each side
122 | :return: resized and padded image
123 | """
124 | image_h, image_w = image.shape[0:2]
125 | crop_bbx = np.copy(bbx)
126 |
127 | w = bbx[2] - bbx[0] + 1
128 | h = bbx[3] - bbx[1] + 1
129 | # keep 0.2 blank for edge
130 | crop_bbx[0] = round(bbx[0] - w * ratio)
131 | crop_bbx[1] = round(bbx[1] - h * ratio)
132 | crop_bbx[2] = round(bbx[2] + w * ratio)
133 | crop_bbx[3] = round(bbx[3] + h * ratio)
134 | # clip value from 0 to len-1
135 | crop_bbx[0] = 0 if crop_bbx[0] < 0 else crop_bbx[0]
136 | crop_bbx[1] = 0 if crop_bbx[1] < 0 else crop_bbx[1]
137 | crop_bbx[2] = image_w - 1 if crop_bbx[2] > image_w - 1 else crop_bbx[2]
138 | crop_bbx[3] = image_h - 1 if crop_bbx[3] > image_h - 1 else crop_bbx[3]
139 | # crop the image
140 | crop_image = image[crop_bbx[1]: crop_bbx[3] + 1, crop_bbx[0]: crop_bbx[2] + 1, :]
141 | # update width and height
142 | w = crop_bbx[2] - crop_bbx[0] + 1
143 | h = crop_bbx[3] - crop_bbx[1] + 1
144 | # keep aspect ratio
145 | # padding
146 | if h < w:
147 | pad = int(w - h)
148 | pad_t = pad // 2
149 | pad_d = pad - pad_t
150 | pad_image = np.pad(crop_image, ((pad_t, pad_d), (0, 0), (0, 0)), constant_values=128)
151 | else:
152 | pad = int(h - w)
153 | pad_l = pad // 2
154 | pad_r = pad - pad_l
155 | pad_image = np.pad(crop_image, ((0, 0), (pad_l, pad_r), (0, 0)), constant_values=128)
156 | crop_image = cv2.resize(pad_image, shape)
157 | return crop_image, crop_bbx
158 |
159 |
160 | def rel2abs(bbx, points):
161 | """
162 | transform points location into original location
163 | :param bbx: [xmin, ymin, xmax, ymax] cropped bbx
164 | :param points: [[x, y, score]] points location in heatmap
165 | :return: [[x, y, score]] points location in original image
166 | """
167 | bbx = bbx.copy()
168 | h, w = bbx[3] - bbx[1], bbx[2] - bbx[0]
169 | max_len = max(h, w)
170 | pad_t = (max_len - h) // 2
171 | pad_d = (max_len - h) - (max_len - h) // 2
172 | pad_l = (max_len - w) // 2
173 | pad_r = (max_len - w) - (max_len - w) // 2
174 | bbx[0] -= pad_l
175 | bbx[1] -= pad_t
176 | bbx[2] += pad_r
177 | bbx[3] += pad_d
178 |     for point in points:
179 |         point[0] = bbx[0] + point[0] * max_len / 128  # 128 = heatmap size (input 512 / stride 4)
180 |         point[1] = bbx[1] + point[1] * max_len / 128
181 | return points
182 |
183 |
184 | def draw_point(image, points):
185 | for point in points:
186 | if int(point[0]) != -1 and int(point[1]) != -1:
187 | image = cv2.circle(
188 | image, (int(point[0]), int(point[1])), 5, (255, 204, 0), 3)
189 | return image
190 |
191 |
192 | def pred_one_image(image, bbxes, sess, input_tensor, output_tensor):
193 | processed_images = []
194 | processed_bbxes = []
195 | for bbx in bbxes:
196 | input_image, croped_bbxes = image_process(image, bbx)
197 | processed_images.append(input_image)
198 | processed_bbxes.append(croped_bbxes)
199 | batch_image = np.stack(processed_images, axis=0)
200 | batch_hm = pb_infer(sess, output_tensor, input_tensor, [batch_image])[0]
201 | final_point = []
202 | for i in range(len(batch_image)):
203 | hm = batch_hm[i]
204 | img = batch_image[i]
205 | point = get_results(hm, threshold=0.01)[0]
206 |
207 | point = rel2abs(processed_bbxes[i], point)
208 | final_point.append(point)
209 | return final_point
210 |
211 |
212 | def get_results(hms, threshold=0.6):
213 | if len(hms.shape) == 3:
214 | hms = np.expand_dims(hms, axis=0)
215 | num_class = hms.shape[-1]
216 | results = []
217 | for b in range(len(hms)):
218 | joints = -1 * np.ones([num_class, 3], dtype=np.float32)
219 | hm = hms[b]
220 | for c in range(num_class):
221 | index = np.unravel_index(
222 | np.argmax(hm[:, :, c]), hm[:, :, c].shape)
223 | # tmp = list(index)
224 | tmp = [index[1], index[0]]
225 | score = hm[index[0], index[1], c]
226 | tmp.append(score)
227 | if score > threshold:
228 | joints[c] = np.array(tmp)
229 | results.append(joints.tolist())
230 | return results
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
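if __name__ == '__main__':
    # Tiny self-check (illustrative): a single peak at (row=3, col=5) in class 0
    # should come back from get_results as [x=5, y=3, score=1.0].
    hm = np.zeros((8, 8, 1), dtype=np.float32)
    hm[3, 5, 0] = 1.0
    print(get_results(hm, threshold=0.5))  # -> [[[5.0, 3.0, 1.0]]]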
--------------------------------------------------------------------------------
/core/infer/visual_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-23
8 | """
9 | import cv2
10 | import numpy as np
11 |
12 | bbx_color = (28, 255, 147)
13 | pointer_color = (255, 204, 0)
14 | txt_color = (0, 240, 78)
15 |
16 |
17 | def get_results(hms, threshold=0.6):
18 | if len(hms.shape) == 3:
19 | hms = np.expand_dims(hms, axis=0)
20 | num_class = hms.shape[-1]
21 | results = []
22 | for b in range(len(hms)):
23 | joints = -1 * np.ones([num_class, 3], dtype=np.float32)
24 | hm = hms[b]
25 | for c in range(num_class):
26 | index = np.unravel_index(
27 | np.argmax(hm[:, :, c]), hm[:, :, c].shape)
28 | # tmp = list(index)
29 | tmp = [index[1], index[0]]
30 | score = hm[index[0], index[1], c]
31 | tmp.append(score)
32 | if score > threshold:
33 | joints[c] = np.array(tmp)
34 | results.append(joints.tolist())
35 | return results
36 |
37 |
38 | def draw_bbx(image, bbx):
39 | image = cv2.rectangle(
40 | image, (bbx[0], bbx[1]), (bbx[2], bbx[3]), bbx_color, 3)
41 | return image
42 |
43 |
44 | def draw_point(image, points):
45 | for point in points:
46 | if point[0] != -1 and point[1] != -1:
47 | image = cv2.circle(
48 | image, (point[0], point[1]), 5, pointer_color, 3)
49 | return image
50 |
51 |
52 | def draw_skeleton(image, points, dataset='mpii'):
53 | for point in points:
54 | if point[0] != -1 and point[1] != -1:
55 | image = cv2.circle(
56 | image, (int(point[0]), int(point[1])), 5, pointer_color, 3)
57 |     if dataset == 'mpii':
58 | LINKS = [(0, 1), (1, 2), (2, 6), (6, 3), (3, 4), (4, 5), (6, 8),
59 | (8, 13), (13, 14), (14, 15), (8, 12), (12, 11), (11, 10)]
60 | for link in LINKS:
61 | if points[link[0]][:2] != [-1,-1] and points[link[1]][:2] != [-1,-1]:
62 | image = cv2.line(image, (int(points[link[0]][0]),int(points[link[0]][1])), (int(points[link[1]][0]),int(points[link[1]][1])), bbx_color)
63 | return image
64 |     elif dataset == 'coco':
65 | LINKS = [[16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],[6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]
66 | for link in LINKS:
67 | if points[link[0]-1][:2] != [-1,-1] and points[link[1]-1][:2] != [-1,-1]:
68 | image = cv2.line(image, (int(points[link[0]-1][0]),int(points[link[0]-1][1])), (int(points[link[1]-1][0]),int(points[link[1]-1][1])), bbx_color)
69 | return image
70 |
71 |
72 | def visiual_image_with_hm(img, hm):
73 | hm = np.sum(hm, axis=-1) * 255
74 | hm = np.expand_dims(hm, axis=-1)
75 | hm = np.tile(hm, (1, 1, 3))
76 | hm = cv2.resize(hm, (img.shape[1], img.shape[0]))
77 | img = img + hm
78 | # img = np.clip(img, 0, 255)
79 | return img
--------------------------------------------------------------------------------
/core/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/core/loss/__init__.py
--------------------------------------------------------------------------------
/core/loss/loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-10
8 | """
9 | import tensorflow as tf
10 | import tensorflow.contrib.slim as slim
11 |
12 |
13 | def cross_entropy(features, heatmap):
14 |     """
15 |     Sigmoid cross-entropy loss for point locating.
16 |     B batch size
17 |     H, W Tensor shape
18 |     C num of classes
19 |     :param features: (Tensor) sigmoid-activated features BxHxWxC
20 |     :param heatmap: (Tensor) labels BxHxWxC
21 |     :return: (List(Tensor))
22 |     """
23 |     print('-Utilize Sigmoid-Cross-Entropy-Loss')
25 | if not isinstance(features, list):
26 | features = [features]
27 | losses = []
28 | for i in range(len(features)):
29 | loss = - heatmap * tf.log(features[i])
30 | losses.append(tf.reduce_mean(loss))
31 | return losses
32 |
33 |
34 | def softmax_cross_entropy(features, heatmap):
35 | print('-Utilize Softmax-Cross-Entropy-Loss')
36 | if not isinstance(features, list):
37 | features = [features]
38 | losses = []
39 |     for i in range(len(features)):
40 |         # NOTE: sparse_softmax_cross_entropy_with_logits expects integer class labels;
41 |         # dense float heatmaps would need tf.nn.softmax_cross_entropy_with_logits_v2.
42 |         loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
43 |             logits=features[i],
44 |             labels=heatmap
45 |         )
44 | losses.append(tf.reduce_mean(loss))
45 | return losses
46 |
47 |
48 | def focal_loss(features, heatmap, alpha=2, beta=4):
49 | """
50 |     Focal loss from "CornerNet"
51 |     Loss = -1/N * (1-p)**alpha * log(p) if y == 1, else -1/N * (1-y)**beta * p**alpha * log(1-p)
52 | add loss into Graph
53 | :param features: (List(Tensor)) [BxHxWxC]
54 | :param heatmap: (Tensor) BxHxWxC
55 | :param alpha: (int)
56 | :param beta: (int)
57 | :return: (List(Tensor))
58 | """
59 | eps = 1e-9
60 | print('-Utilize Focal-Loss')
61 | if type(features) is not list:
62 | features = [features]
63 | losses = []
64 | for i in range(len(features)):
65 | # feature = tf.nn.sigmoid(features[i])
66 | feature = tf.clip_by_value(features[i], eps, 1 - eps)
67 | zeros = tf.zeros_like(heatmap)
68 | ones = tf.ones_like(heatmap)
69 |
70 | # mask
71 | mask = tf.where(tf.equal(heatmap, 1.0), ones, zeros)
72 | inv_mask = tf.subtract(1.0, mask)
73 |
74 | # num_pos
75 | num_pos = tf.reduce_sum(mask)
76 | num_pos = tf.maximum(num_pos, 1)
77 |
78 | # pre
79 | pos = tf.multiply(feature, mask)
80 | neg = tf.multiply(1.0 - feature, inv_mask)
81 | pre = tf.log(tf.add(pos, neg) + eps)
82 |
83 | # weight alpha
84 | pos_weight_alpha = tf.multiply(1.0 - feature, mask)
85 | neg_weight_alpha = tf.multiply(feature, inv_mask)
86 | weight_alpha = tf.pow(tf.add(pos_weight_alpha, neg_weight_alpha), alpha)
87 |
88 | # weight beta
89 | pos_weight_beta = mask
90 | neg_weight_beta = tf.multiply(1.0 - heatmap, inv_mask)
91 | weight_beta = tf.pow(tf.add(pos_weight_beta, neg_weight_beta), beta)
92 |
93 | # cal loss
94 | loss = tf.reduce_sum(- weight_beta * weight_alpha * pre) / num_pos
95 |
96 | losses.append(loss)
97 | return losses
98 |
99 |
100 | def mean_square_loss(features, heatmap):
101 | print('-Utilize Mse-Loss')
102 | if not isinstance(features, list):
103 | features = [features]
104 | losses = []
105 | for i in range(len(features)):
106 | loss = tf.losses.mean_squared_error(
107 | heatmap, features[i])
108 | losses.append(tf.reduce_mean(loss))
109 | return losses
110 |
--------------------------------------------------------------------------------
/core/network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/core/network/__init__.py
--------------------------------------------------------------------------------
/core/network/keypoints.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-10
8 | """
9 | import time
10 | import tensorflow as tf
11 | import tensorflow.contrib.slim as slim
12 | from tensorflow.contrib.slim.nets import resnet_v2
13 | from core.network.network_utils import residual_block_v2, hourglass_block
14 |
15 |
16 | class Keypoints():
17 | def __init__(self, inputs, num_class,
18 | backbone="hourglass",
19 | num_block=2,
20 | num_depth=5,
21 | residual_dim=(256, 256, 384, 384, 384, 512),
22 | is_training=True,
23 | is_maxpool=False,
24 | is_nearest=True,
25 | reuse=False
26 | ):
27 | """
28 | Modified hourglass. See more in network_utils.py
29 | :param inputs: (Tensor) BxHxWxC images
30 | :param num_class: (int) num of classes
31 | :param num_block: (int) num of hourglass block
32 | :param num_depth: (int) num of down-sampling steps
33 | :param residual_dim: (list(int)) output dim for each residual block. Length should be num_depth+1
34 |         :param is_training: (bool) is in training phase
35 | :param is_maxpool: (bool) if true, using max-pool down-sampling. Otherwise, residual block stride will be 2
36 | :param is_nearest: (bool) if true, using nearest up-sampling. Otherwise, using deconvolution
37 | :param reuse:(bool) reuse the variable
38 | """
39 | self.inputs = inputs
40 | self.num_class = num_class
41 | self.backbone = backbone
42 |
43 | self.num_block = num_block
44 | self.num_depth = num_depth
45 | self.residual_dim = residual_dim
46 | self.is_training = is_training
47 | self.is_maxpool = is_maxpool
48 | self.is_nearest = is_nearest
49 | self.reuse = reuse
50 |
51 | self.features = self.graph_hourglass(self.inputs)
52 |
53 | def pre_process(self, inputs, scope='pre_process'):
54 | """
55 | pre-process conv7x7/s=2 -> residual/s=2
56 | :param inputs: (Tensor) BxHxWxC
57 | :param scope: (str) scope
58 | :return: (Tensor) BxH/4xW/4xC
59 | """
60 | with tf.variable_scope(scope):
61 | net = slim.conv2d(
62 | inputs=inputs,
63 | num_outputs=128,
64 | kernel_size=[7, 7],
65 | stride=2,
66 | activation_fn=None,
67 | normalizer_fn=None,
68 | reuse=self.reuse,
69 | scope='conv1'
70 | )
71 | tf.summary.histogram(net.name + '/activations', net)
72 |
73 | net = residual_block_v2(
74 | inputs=net,
75 | output_dim=256,
76 | stride=2,
77 | is_training=self.is_training,
78 | reuse=self.reuse,
79 | scope='residual_1'
80 | )
81 | return net
82 |
83 | def inter_process(self, inputs_1, inputs_2, scope='inter_process'):
84 | with tf.variable_scope(scope):
85 | branch_1 = slim.batch_norm(
86 | inputs=inputs_1,
87 | activation_fn=tf.nn.relu,
88 | is_training=self.is_training,
89 | scope='branch_1/bn',
90 | reuse=self.reuse,
91 | scale=True
92 | )
93 | tf.summary.histogram(branch_1.name + '/activations', branch_1)
94 |
95 | branch_1 = slim.conv2d(
96 | inputs=branch_1,
97 | num_outputs=inputs_1.get_shape().as_list()[-1],
98 | kernel_size=[1, 1],
99 | stride=1,
100 | activation_fn=None,
101 | normalizer_fn=None,
102 | reuse=self.reuse,
103 | scope='branch_1/conv'
104 | )
105 | tf.summary.histogram(branch_1.name + '/activations', branch_1)
106 |
107 | branch_2 = slim.batch_norm(
108 | inputs=inputs_2,
109 | activation_fn=tf.nn.relu,
110 | is_training=self.is_training,
111 | scope='branch_2/bn',
112 | reuse=self.reuse,
113 | scale=True)
114 | tf.summary.histogram(branch_2.name + '/activations', branch_2)
115 |
116 | branch_2 = slim.conv2d(
117 | inputs=branch_2,
118 | num_outputs=inputs_2.get_shape().as_list()[-1],
119 | kernel_size=[1, 1],
120 | stride=1,
121 | activation_fn=None,
122 | normalizer_fn=None,
123 | reuse=self.reuse,
124 | scope='branch_2/conv'
125 | )
126 | tf.summary.histogram(branch_2.name + '/activations', branch_2)
127 |
128 | output = tf.add(branch_1, branch_2)
129 | return output
130 |
131 | def hinge(self, inputs, output_dim, scope='hinge'):
132 | with tf.variable_scope(scope):
133 | pre = slim.batch_norm(
134 | inputs=inputs,
135 | activation_fn=tf.nn.relu,
136 | is_training=self.is_training,
137 | scope='bn',
138 | reuse=self.reuse,
139 | scale=True
140 | )
141 | tf.summary.histogram(pre.name + '/activations', pre)
142 |
143 | outputs = slim.conv2d(
144 | inputs=pre,
145 | num_outputs=output_dim,
146 | kernel_size=[1, 1],
147 | stride=1,
148 | activation_fn=None,
149 | normalizer_fn=None,
150 | reuse=self.reuse,
151 | scope='conv'
152 | )
153 | tf.summary.histogram(outputs.name + '/activations', outputs)
154 | return outputs
155 |
156 | def keypoint(self, features, scope='keypoint'):
157 | """
158 | key-point branch. return final feature map
159 | :param features: (Tensor) final backbone features without bn and activated
160 | :param scope: (str) scope
161 | :return: [Tensor,...]
162 | """
163 | keypoint_feature = []
164 | if type(features) is not list:
165 | features = [features]
166 | for i in range(len(features)):
167 | with tf.variable_scope(scope+'_%d' % i):
168 | feature = slim.batch_norm(inputs=features[i],
169 | activation_fn=tf.nn.relu,
170 | is_training=self.is_training,
171 | scope='pre_bn',
172 | reuse=self.reuse,
173 | scale=True)
174 | tf.summary.histogram(feature.name + '/activations', feature)
175 | feature = slim.conv2d(
176 | inputs=feature,
177 | num_outputs=self.num_class,
178 | kernel_size=[3, 3],
179 | stride=1,
180 | activation_fn=tf.nn.sigmoid,
181 | normalizer_fn=None,
182 | reuse=self.reuse,
183 | scope='conv'
184 | )
185 | tf.summary.histogram(feature.name + '/activations', feature)
186 | keypoint_feature.append(feature)
187 |
188 | return keypoint_feature
189 |
190 | def graph_backbone_hourglass(self, inputs):
191 | """
192 | Extract features
193 | :param inputs: (Tensor) BxHxWxC images
194 |         :return: [Tensor] BxH/4xW/4xC. Earlier outputs are for intermediate supervision; the last is for prediction.
195 | """
196 | t0 = time.time()
197 |         print('-Begin to create model')
198 | with tf.variable_scope('backbone'):
199 | start_time = time.time()
200 | pre = self.pre_process(inputs)
201 | print('--%s has been created in %.3fs' %
202 | ('pre_process', time.time() - start_time))
203 | net = pre
204 | features = []
205 | for i in range(self.num_block):
206 | start_time = time.time()
207 | hourglass = hourglass_block(
208 | inputs=net,
209 | num_depth=self.num_depth,
210 | residual_dim=self.residual_dim,
211 | is_training=self.is_training,
212 | is_maxpool=self.is_maxpool,
213 | is_nearest=self.is_nearest,
214 | reuse=self.reuse,
215 | scope='hourglass_%d' % i
216 | )
217 | hinge = self.hinge(hourglass, self.residual_dim[0], 'hinge_%d' % i)
218 | features.append(hinge)
219 | print('--%s has been created in %.3fs' % ('hourglass_%d' % i, time.time() - start_time))
220 | start_time = time.time()
221 |                 if i < self.num_block - 1:
222 |                     net = self.inter_process(net, hinge, 'inter_process_%d' % i)
223 |                     print('--%s has been created in %.3fs' % ('inter_process_%d' % i, time.time() - start_time))
223 |
224 | print('-Model has been created in %.3fs' % (time.time() - t0))
225 | return features
226 |
227 | def graph_backbone_resnet_101(self, inputs):
228 | with tf.variable_scope('backbone'):
229 | feature, end_point = resnet_v2.resnet_v2_101(inputs, num_classes=None, global_pool=False, is_training=self.is_training, reuse=self.reuse, scope="resnet_v2_101")
230 | with tf.variable_scope('up_sample'):
231 | feature = slim.conv2d_transpose(feature, 512, 3, 2, activation_fn=None, reuse=self.reuse, scope="transpose_conv1")
232 | feature = slim.batch_norm(inputs=feature,
233 | activation_fn=tf.nn.relu,
234 | is_training=self.is_training,
235 | scope='transpose_conv1/bn',
236 | reuse=self.reuse,
237 | scale=True)
238 | feature = slim.conv2d_transpose(feature, 512, 3, 2, activation_fn=None, reuse=self.reuse, scope="transpose_conv2")
239 | feature = slim.batch_norm(inputs=feature,
240 | activation_fn=tf.nn.relu,
241 | is_training=self.is_training,
242 | scope='transpose_conv2/bn',
243 | reuse=self.reuse,
244 | scale=True)
245 | feature = slim.conv2d_transpose(feature, 512, 3, 2, activation_fn=None, reuse=self.reuse, scope="transpose_conv3")
246 | feature = slim.batch_norm(inputs=feature,
247 | activation_fn=tf.nn.relu,
248 | is_training=self.is_training,
249 | scope='transpose_conv3/bn',
250 | reuse=self.reuse,
251 | scale=True)
252 | feature = slim.conv2d(feature, self.residual_dim[0], 3, 1, activation_fn=None, normalizer_fn=None, reuse=self.reuse, scope="conv1")
253 | return [feature]
254 |
255 | def graph_hourglass(self, inputs, scope='Keypoints'):
256 | """
257 | graph hourglass net.
258 | :param inputs: (Tensor) images
259 | :param scope: (str) scope
260 | :return: [[Tensor B x H/4 x W/4 x num_class,...]]
261 | """
262 | with tf.variable_scope(scope):
263 | if self.backbone == "hourglass":
264 | features = self.graph_backbone_hourglass(inputs)
265 | elif self.backbone == "resnet_v2_101":
266 | features = self.graph_backbone_resnet_101(inputs)
267 | else:
268 | raise ValueError("Invalid Backbone type!")
269 | all_features = [self.keypoint(features)]
270 | print('--PB file input node is %s' % inputs.name)
271 | print('--PB file output node is %s' % all_features[0][-1].name)
272 | return all_features
273 |
274 |
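if __name__ == '__main__':
    # Minimal construction sketch (illustrative), mirroring the usage in
    # core/infer/freeze_graph.py with the COCO config values.
    x = tf.placeholder(tf.float32, [None, 512, 512, 3], name='inputs_x')
    model = Keypoints(x, num_class=17, num_block=2, num_depth=5,
                      residual_dim=[256, 384, 384, 384, 512],
                      is_training=False)
    # model.features[0]: one B x 128 x 128 x 17 heatmap per hourglass block
    print([f.get_shape().as_list() for f in model.features[0]])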
--------------------------------------------------------------------------------
/core/network/network_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-10
8 | """
9 |
10 |
11 | import tensorflow as tf
12 | import tensorflow.contrib.slim as slim
13 |
14 |
15 | def residual_block_v2_with_bottom_neck(inputs, output_dim, stride,
16 | is_training=True, reuse=False, scope='residual_block'):
17 | """
18 | Pre-act mode
19 | modified residual block
20 |     bottle neck depth = output_dim / 4
21 | output = conv + short-cut
22 | :param inputs: (Tensor) input tensor BxHxWxC
23 |     :param output_dim: (int) multiple of 4
24 | :param stride: (int) if down-sample
25 | :param reuse: (bool) reuse the variable
26 | :param scope: (str) scope name
27 | :param is_training: (bool)bn is in training phase
28 | :return: (Tensor) Bx(H/stride)x(W/stride)xC
29 | """
30 |     dim = output_dim // 4  # integer division: slim.conv2d needs an int num_outputs
31 |     if output_dim % 4 != 0:
32 |         raise ValueError('residual block output dim must be a multiple of 4')
33 | with tf.variable_scope(scope):
34 | depth_in = inputs.get_shape().as_list()[-1]
35 | pre_act = slim.batch_norm(
36 | inputs=inputs,
37 | activation_fn=tf.nn.relu,
38 | is_training=is_training,
39 | scope='pre_act',
40 | scale=True,
41 | reuse=reuse
42 | )
43 | if output_dim == depth_in:
44 | short_cut = slim.max_pool2d(
45 | inputs=inputs,
46 | kernel_size=[1, 1],
47 | stride=stride,
48 | scope='short_cut'
49 | )
50 | else:
51 | short_cut = slim.conv2d(
52 | inputs=pre_act,
53 | num_outputs=output_dim,
54 | kernel_size=[1, 1],
55 | stride=stride,
56 | activation_fn=None,
57 | normalizer_fn=None,
58 | scope='short_cut',
59 | reuse=reuse
60 | )
61 | tf.summary.histogram(short_cut.name + '/activations', short_cut)
62 |
63 | residual = slim.conv2d(
64 | inputs=pre_act,
65 | num_outputs=dim,
66 | kernel_size=[1, 1],
67 | stride=1,
68 | activation_fn=None,
69 | normalizer_fn=None,
70 | scope='conv1',
71 | reuse=reuse
72 | )
73 | residual = slim.batch_norm(
74 | residual,
75 | activation_fn=tf.nn.relu,
76 | is_training=is_training,
77 | scope='conv1/bn',
78 | scale=True,
79 | reuse=reuse
80 | )
81 | tf.summary.histogram(residual.name + '/activations', residual)
82 |
83 | residual = slim.conv2d(
84 | inputs=residual,
85 | num_outputs=dim,
86 | kernel_size=[3, 3],
87 | stride=stride,
88 | activation_fn=None,
89 | normalizer_fn=None,
90 | scope='conv2',
91 | reuse=reuse
92 | )
93 | residual = slim.batch_norm(
94 | residual,
95 | activation_fn=tf.nn.relu,
96 | is_training=is_training,
97 | scope='conv2/bn',
98 | scale=True,
99 | reuse=reuse
100 | )
101 | tf.summary.histogram(residual.name + '/activations', residual)
102 |
103 | residual = slim.conv2d(
104 | inputs=residual,
105 | num_outputs=output_dim,
106 | kernel_size=[1, 1],
107 | stride=1,
108 | activation_fn=None,
109 | normalizer_fn=None,
110 | scope='conv3',
111 | reuse=reuse
112 | )
113 | tf.summary.histogram(residual.name + '/activations', residual)
114 |
115 | output = short_cut + residual
116 | return output
117 |
118 |
119 | def residual_block_v2(inputs, output_dim, stride,
120 | is_training=True, reuse=False, scope='residual_block'):
121 | """
122 | Pre-act mode
123 | modified residual block
124 |     two 3x3 convs at output_dim, no bottleneck
125 | output = conv + short-cut
126 | :param inputs: (Tensor) input tensor BxHxWxC
127 | :param output_dim: (int) multiple of 2
128 | :param stride: (int) if down-sample
129 | :param scope: (str) scope name
130 | :param is_training: (bool)bn is in training phase
131 | :return: (Tensor) Bx(H/stride)x(W/stride)xC
132 | """
133 | with tf.variable_scope(scope):
134 | depth_in = inputs.get_shape().as_list()[-1]
135 | pre_act = slim.batch_norm(
136 | inputs=inputs,
137 | activation_fn=tf.nn.relu,
138 | is_training=is_training,
139 | scope='pre_act',
140 | scale=True,
141 | reuse=reuse
142 | )
143 | if output_dim == depth_in:
144 | short_cut = slim.max_pool2d(
145 | inputs=inputs,
146 | kernel_size=[1, 1],
147 | stride=stride,
148 | scope='short_cut'
149 | )
150 | else:
151 | short_cut = slim.conv2d(
152 | inputs=pre_act,
153 | num_outputs=output_dim,
154 | kernel_size=[1, 1],
155 | stride=stride,
156 | activation_fn=None,
157 | normalizer_fn=None,
158 | scope='short_cut',
159 | reuse=reuse
160 | )
161 | tf.summary.histogram(short_cut.name + '/activations', short_cut)
162 |
163 | residual = slim.conv2d(
164 | inputs=pre_act,
165 | num_outputs=output_dim,
166 | kernel_size=[3, 3],
167 | stride=1,
168 | activation_fn=None,
169 | normalizer_fn=None,
170 | scope='conv1',
171 | reuse=reuse
172 | )
173 | residual = slim.batch_norm(
174 | residual,
175 | activation_fn=tf.nn.relu,
176 | is_training=is_training,
177 | scope='conv1/bn',
178 | scale=True,
179 | reuse=reuse
180 | )
181 | tf.summary.histogram(residual.name + '/activations', residual)
182 |
183 | residual = slim.conv2d(
184 | inputs=residual,
185 | num_outputs=output_dim,
186 | kernel_size=[3, 3],
187 | stride=stride,
188 | activation_fn=None,
189 | normalizer_fn=None,
190 | scope='conv2',
191 | reuse=reuse
192 | )
193 |
194 | tf.summary.histogram(residual.name + '/activations', residual)
195 |
196 | output = short_cut + residual
197 | return output
198 |
199 |
200 | def hourglass_block(inputs, num_depth, residual_dim,
201 | is_training=True, is_maxpool=False,
202 | is_nearest=True, reuse=False, scope='hourglass_block'):
203 | """
204 |     modified hourglass block following "CornerNet"
205 |     There are 2 residual blocks in the short-cut instead of 1
206 |     There are 2 residual blocks after upsampling
207 |     There are 4 residual blocks with depth dim (512 in the paper) in the middle of the hourglass
208 |     Attention! residual blocks are in pre-act mode,
209 |     so inputs must not already be activated or normalized
210 | :param inputs: (Tensor) BxHxWxC
211 | :param num_depth: (int) depth of downsample
212 | :param residual_dim: (list) dim of residual block. len(residual_dim)=num_depth+1
213 | :param is_training: (bool) bn is in training phase
214 | :param is_maxpool: (bool) if it's True, downsample mode will be maxpool. Otherwise, downsample mode will be stride=2
215 |     :param is_nearest: (bool) if it's True, upsample mode will be nearest upsampling. Otherwise, upsample mode will be deconv.
216 | :param scope: (str) scope name
217 | :return: (Tensor) BxHxWxC
218 | """
219 | cur_res_dim = inputs.get_shape().as_list()[-1]
220 | next_res_dim = residual_dim[0]
221 |
222 | with tf.variable_scope(scope):
223 | up_1 = residual_block_v2(
224 | inputs=inputs,
225 | output_dim=cur_res_dim,
226 | stride=1,
227 | is_training=is_training,
228 | reuse=reuse,
229 | scope='up_1'
230 | )
231 | if is_maxpool:
232 | low_1 = slim.max_pool2d(
233 | inputs=inputs,
234 | kernel_size=2,
235 | stride=2,
236 | padding='VALID'
237 | )
238 | low_1 = residual_block_v2(
239 | inputs=low_1,
240 | output_dim=next_res_dim,
241 | stride=1,
242 | is_training=is_training,
243 | reuse=reuse,
244 | scope='low_1'
245 | )
246 | else:
247 | low_1 = residual_block_v2(
248 | inputs=inputs,
249 | output_dim=next_res_dim,
250 | stride=2,
251 | is_training=is_training,
252 | reuse=reuse,
253 | scope='low_1'
254 | )
255 |
256 | if num_depth > 1:
257 | low_2 = hourglass_block(
258 | inputs=low_1,
259 | num_depth=num_depth - 1,
260 | residual_dim=residual_dim[1:],
261 | is_training=is_training,
262 | is_maxpool=is_maxpool,
263 | is_nearest=is_nearest,
264 | reuse=reuse,
265 | scope='hourglass_block_%d' % (num_depth - 1)
266 | )
267 | else:
268 | low_2 = residual_block_v2(
269 | inputs=low_1,
270 | output_dim=next_res_dim,
271 | stride=1,
272 | is_training=is_training,
273 | reuse=reuse,
274 | scope='low_2'
275 | )
276 | low_3 = residual_block_v2(
277 | inputs=low_2,
278 | output_dim=cur_res_dim,
279 | stride=1,
280 | is_training=is_training,
281 | reuse=reuse,
282 | scope='low_3'
283 | )
284 | if is_nearest:
285 | up_2 = tf.image.resize_nearest_neighbor(
286 | images=low_3,
287 | size=tf.shape(low_3)[1:3] * 2,
288 | name='up_2'
289 | )
290 | else:
291 | up_2 = slim.conv2d_transpose(
292 | inputs=low_3,
293 | num_outputs=cur_res_dim,
294 | kernel_size=[3, 3],
295 | stride=2,
296 | reuse=reuse,
297 | scope='up_2'
298 | )
299 | merge = up_1 + up_2
300 | return merge
301 |
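if __name__ == '__main__':
    # Shape sanity check (illustrative), using the residual_dim from
    # config/config_hourglass_coco.py: the block preserves the input depth.
    x = tf.placeholder(tf.float32, [None, 128, 128, 256])
    y = hourglass_block(x, num_depth=5,
                        residual_dim=[256, 384, 384, 384, 512])
    print(y.get_shape().as_list())  # channels stay 256; H/W may be dynamic (None)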
--------------------------------------------------------------------------------
/core/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/core/train/__init__.py
--------------------------------------------------------------------------------
/core/train/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-20
8 | """
9 | import os
10 | import time
11 |
12 | import tensorflow as tf
13 | import tensorflow.contrib.slim as slim
14 | from core.loss.loss import focal_loss, cross_entropy, softmax_cross_entropy, mean_square_loss
15 |
16 |
17 | class Trainer():
18 | def __init__(self, model_class, dataset_class, cfg):
19 |
20 | start_time = time.time()
21 | # HARDWARE
22 | self.CUDA_VISIBLE_DEVICES = cfg.CUDA_VISIBLE_DEVICES
23 | if self.CUDA_VISIBLE_DEVICES is not None:
24 | os.environ['CUDA_VISIBLE_DEVICES'] = self.CUDA_VISIBLE_DEVICES
25 | self.MULTI_THREAD_NUM = cfg.MULTI_THREAD_NUM
26 | # self.MULTI_GPU = cfg.MULTI_GPU
27 | # self.NUM_GPU = len(self.MULTI_GPU)
28 |
29 | # NETWORK
30 | self.backbone = cfg.backbone
31 | self.image_size = cfg.image_size
32 | self.heatmap_size = cfg.heatmap_size
33 | self.stride = cfg.stride
34 | self.num_block = cfg.num_block
35 | self.num_depth = cfg.num_depth
36 | self.residual_dim = cfg.residual_dim
37 | self.is_maxpool = cfg.is_maxpool
38 | self.is_nearest = cfg.is_nearest
39 |
40 | # TRAINING
41 | self.batch_size = cfg.batch_size
42 | self.learning_rate_init = cfg.learning_rate_init
43 | self.learning_rate_warmup = cfg.learning_rate_warmup
44 | self.exp_decay = cfg.exp_decay
45 |
46 | self.warmup_epoch_size = cfg.warmup_epoch_size
47 | self.epoch_size = cfg.epoch_size
48 | self.summary_per = cfg.summary_per
49 | self.save_per = cfg.save_per
50 |
51 | self.regularization_weight = cfg.regularization_weight
52 |
53 | # VALIDATION
54 | self.val_per = cfg.val_per
55 | self.val_time = cfg.val_time
56 |
57 | # PATH
58 | self.dataset_dir = cfg.dataset_dir
59 | self.train_image_dir = cfg.train_image_dir
60 | self.val_image_dir = cfg.val_image_dir
61 | self.train_list_path = cfg.train_list_path
62 | self.val_list_path = cfg.val_list_path
63 |
64 | self.log_dir = cfg.log_dir
65 | self.ckpt_path = cfg.ckpt_dir
66 |
67 | # SAVER AND LOADER
68 | self.pre_trained_ckpt = cfg.pre_trained_ckpt
69 | self.ckpt_name = cfg.ckpt_name
70 | self.max_keep = cfg.max_keep
71 |
72 | print('-Load config in %.3f' % (time.time() - start_time))
73 |
74 | # DATASET
75 | self.dataset_class = dataset_class
76 | self.augment = cfg.augment
77 | self.train_dataset = None
78 | self.val_dataset = None
79 | self.train_iterator = None
80 | self.val_iterator = None
81 |
82 | # derived options
83 | self.time = time.strftime(
84 | '%Y_%m_%d_%H_%M_%S',
85 | time.localtime(
86 | time.time()))
87 | self.steps_per_period = None
88 |
89 | # PLACE HOLDER
90 | self.inputs_x = None
91 | self.inputs_y = None
92 | self.is_training = None
93 |
94 | # MODEL
95 | self.model_class = model_class
96 | self.model = None
97 | self.features = None
98 |
99 | self.val_model = None
100 | self.val_features = None
101 |
102 | # LOSS
103 | self.loss_mode = cfg.loss_mode
104 | self.model_losses = None
105 | self.model_loss = None
106 | self.val_model_loss = None
107 | self.trainable_variables = None
108 | self.regularization_loss = None
109 | self.loss = None
110 |
111 | # LEARNING RATE
112 | self.global_step = None
113 | self.learning_rate = None
114 |
115 | # TRAIN OP
116 | self.train_op = None
117 |
118 | # SAVER LOADER SUMMARY
119 | self.loader = None
120 | self.saver = None
121 | self.summary_writer = None
122 | self.write_op = None
123 |
124 | # DEBUG
125 | self.is_debug = False
126 | self.gradient = None
127 | self.mean_gradient = None
128 |
129 | # SESSION
130 | self.sess = None
131 | #################################################################
132 |
133 | def init_inputs(self):
134 | with tf.variable_scope('Placeholder'):
135 | self.inputs_x = tf.placeholder(tf.float32, [None, self.image_size[0], self.image_size[1], 3],
136 | 'inputs_x')
137 | self.inputs_y = tf.placeholder(tf.float32, [None, self.heatmap_size[0], self.heatmap_size[1],
138 | self.train_dataset.num_class], 'inputs_y')
139 | # If a placeholder drives the BN layers' training flag, every BN layer is built with tf.cond/Switch control nodes (visible in the graph and in tensorRT)
140 | # Each BN layer then carries two execution paths, which wastes memory during training and has to be pruned separately for inference deployment
141 | # Pinning it to True here is fine for training. During validation, train_op is not run, so BN's gamma and beta are not updated
142 | # and since the moving mean/var updates depend on train_op, no BN parameter changes at val time, which is equivalent to trainable=False
143 | # Note that in tf 1.x, trainable=False puts BN into a freeze state:
144 | # unlike inference, freeze still normalizes with the current batch's mean and var.
145 | # In tf 2.x this changed: trainable=False switches BN to inference behavior
146 | self.is_training = True
147 |
148 | def init_dataset(self):
149 | start_time = time.time()
150 |
151 | # TRAIN DATASET
152 | self.train_dataset = self.dataset_class(image_dir=self.train_image_dir,
153 | gt_path=self.train_list_path,
154 | batch_size=self.batch_size,
155 | image_size=self.image_size,
156 | heatmap_size=self.heatmap_size,
157 | augment=self.augment)
158 | self.train_iterator = self.train_dataset.iterator(
159 | self.MULTI_THREAD_NUM)
160 |
161 | # VAL DATASET
162 | self.val_dataset = self.dataset_class(image_dir=self.val_image_dir,
163 | gt_path=self.val_list_path,
164 | batch_size=self.batch_size,
165 | image_size=self.image_size,
166 | heatmap_size=self.heatmap_size
167 | )
168 | self.val_iterator = self.val_dataset.iterator(self.MULTI_THREAD_NUM)
169 | self.steps_per_period = int(
170 | self.train_dataset.num_data /
171 | self.batch_size)
172 | print('-Create dataset in %.3f' % (time.time() - start_time))
173 |
174 | def init_model(self):
175 | print("-Creat Train model")
176 | self.model = self.model_class(self.inputs_x, self.train_dataset.num_class,
177 | backbone=self.backbone,
178 | num_block=self.num_block,
179 | num_depth=self.num_depth,
180 | residual_dim=self.residual_dim,
181 | is_training=True,
182 | is_maxpool=self.is_maxpool,
183 | is_nearest=self.is_nearest,
184 | reuse=False
185 | )
186 | self.features = self.model.features[0]
187 |
188 | print("-Creat Val model")
189 | self.val_model = self.model_class(self.inputs_x, self.train_dataset.num_class,
190 | backbone=self.backbone,
191 | num_block=self.num_block,
192 | num_depth=self.num_depth,
193 | residual_dim=self.residual_dim,
194 | is_training=False,
195 | is_maxpool=self.is_maxpool,
196 | is_nearest=self.is_nearest,
197 | reuse=True
198 | )
199 | self.val_features = self.val_model.features[0]
200 |
201 | def init_learning_rate(self):
202 | start_time = time.time()
203 | # LEARNING RATE
204 | with tf.variable_scope('Learning_rate'):
205 | self.global_step = tf.train.get_or_create_global_step()
206 | warmup_steps = tf.constant(self.warmup_epoch_size * self.steps_per_period,
207 | dtype=tf.int64, name='warmup_steps')
208 | self.learning_rate = tf.cond(
209 | pred=tf.less(self.global_step, warmup_steps),
210 | true_fn=lambda: self.learning_rate_warmup + (self.learning_rate_init - self.learning_rate_warmup)
211 | * tf.cast(self.global_step, tf.float32) / tf.cast(warmup_steps, tf.float32),
212 | false_fn=lambda: tf.train.exponential_decay(
213 | self.learning_rate_init, self.global_step, self.steps_per_period, self.exp_decay, staircase=True)
214 | )
215 | print('-Create learning rate in %.3f' % (time.time() - start_time))
216 |
217 | def init_loss(self):
218 | start_time = time.time()
219 |
220 | # LOSS
221 | with tf.variable_scope('Loss'):
222 | self.trainable_variables = tf.trainable_variables()
223 | if self.loss_mode == 'focal':
224 | loss_fn = focal_loss
225 | elif self.loss_mode == 'sigmoid':
226 | loss_fn = cross_entropy
227 | elif self.loss_mode == 'softmax':
228 | loss_fn = softmax_cross_entropy
229 | elif self.loss_mode == 'mse':
230 | loss_fn = mean_square_loss
231 | else:
232 | raise ValueError('Unsupported loss mode: %s' % self.loss_mode)
233 | self.model_losses = loss_fn(self.features, self.inputs_y)
234 | self.model_loss = tf.add_n(self.model_losses)
235 | self.val_model_loss = loss_fn(self.val_features, self.inputs_y)[-1]
236 | self.regularization_loss = tf.add_n(
237 | [tf.nn.l2_loss(var) for var in self.trainable_variables])
238 | self.regularization_loss = self.regularization_weight * self.regularization_loss
239 | self.loss = self.model_loss + self.regularization_loss
240 |
241 | print('-Create loss in %.3f' % (time.time() - start_time))
242 |
243 | def init_train_op(self):
244 | start_time = time.time()
245 | # TRAIN_OP
246 | with tf.name_scope("Train_op"):
247 | optimizer = tf.train.AdamOptimizer(
248 | self.learning_rate)
249 | gvs = optimizer.compute_gradients(self.loss)
250 | clip_gvs = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gvs]
251 | if self.is_debug:
252 | self.mean_gradient = tf.reduce_mean([tf.reduce_mean(g) for g, v in gvs])
253 | tf.summary.scalar("mean_gradient", self.mean_gradient)
254 | print('Debug mode is on !!!')
255 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
256 | # It's important!
257 | # Update moving-average in BN
258 | self.train_op = optimizer.apply_gradients(clip_gvs, global_step=self.global_step)
259 | print('-Create train op in %.3f' % (time.time() - start_time))
260 |
261 | def init_loader_saver_summary(self):
262 | start_time = time.time()
263 | with tf.name_scope('loader_and_saver'):
264 | self.loader = tf.train.Saver(var_list=tf.global_variables())
265 | var_list = tf.trainable_variables()
266 | g_list = tf.global_variables()
267 | bn_moving_var = [g for g in g_list if 'moving_mean' in g.name]
268 | bn_moving_var += [g for g in g_list if 'moving_variance' in g.name]
269 | if len(bn_moving_var) < 1:
270 | print('Warning! BatchNorm layer parameters have not been saved!')
271 | var_list += bn_moving_var
272 | self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=self.max_keep)
273 |
274 | with tf.name_scope('summary'):
275 |
276 | tf.summary.image('input_image', self.inputs_x, max_outputs=3)
277 | tf.summary.image('input_hm', tf.reduce_sum(self.inputs_y,axis=-1,keepdims=True), max_outputs=3)
278 | tf.summary.image('output_hm', tf.reduce_sum(self.features[-1],axis=-1,keepdims=True), max_outputs=3)
279 |
280 | tf.summary.scalar("learning_rate", self.learning_rate)
281 | for i in range(len(self.model_losses)):
282 | tf.summary.scalar("block_%d_loss" % i, self.model_losses[i])
283 | tf.summary.scalar("model_loss", self.model_loss)
284 | tf.summary.scalar("regularization_loss", self.regularization_loss)
285 | tf.summary.scalar("total_loss", self.loss)
286 | # # Optional
287 | # tf.summary.scalar('keypoint_bn_moving_mean',
288 | # tf.reduce_mean(slim.get_variables_by_name('HourglassNet/keypoint_1/pre_bn/moving_mean')))
289 | # tf.summary.scalar('keypoint_bn_moving_var', tf.reduce_mean(
290 | # slim.get_variables_by_name('HourglassNet/keypoint_1/pre_bn/moving_variance')))
291 |
292 | if not os.path.exists(self.log_dir):
293 | os.mkdir(self.log_dir)
294 | self.write_op = tf.summary.merge_all()
295 |
296 | print(
297 | '-Create loader saver and summary in %.3f' %
298 | (time.time() - start_time))
299 |
300 | def init_session(self):
301 | start_time = time.time()
302 | # SESSION
303 | config = tf.ConfigProto(allow_soft_placement=True)  # let TF fall back to another device automatically
304 | config.gpu_options.allow_growth = True
305 | self.sess = tf.Session(config=config)
306 | self.summary_writer = tf.summary.FileWriter(
307 | os.path.join(self.log_dir, self.time), graph=self.sess.graph)
308 | print('-Initializing session in %.3f' % (time.time() - start_time))
309 |
310 | # self.train_launch()
311 | ################################################################
312 | def _load_ckpt(self):
313 | t0 = time.time()
314 | try:
315 | self.loader.restore(self.sess, self.pre_trained_ckpt)
316 | print('Successful restore from %s in time %.2f' %
317 | (self.pre_trained_ckpt, time.time() - t0))
318 | except Exception as e:
319 | print(e)
320 | print('Failed restore from %s in time %.2f' %
321 | (self.pre_trained_ckpt, time.time() - t0))
322 |
323 | def train(self):
324 | t0 = time.time()
325 | self.sess.run(tf.global_variables_initializer())
326 | print('-Model has been initialized in %.3f' % (time.time() - t0))
327 | if self.pre_trained_ckpt is not None:
328 | self._load_ckpt()
329 |
330 | print('Begin to train!')
331 | total_step = self.epoch_size * self.steps_per_period
332 | step = 0
333 | while step < total_step:
334 | # try:
335 | step = self.sess.run(self.global_step)
336 | ite = step % self.steps_per_period + 1
337 | epoch = step // self.steps_per_period + 1
338 | imgs, hms = next(self.train_iterator)
339 | imgs = (imgs / 127.5) - 1
340 | feed_dict = {
341 | self.inputs_x: imgs,
342 | self.inputs_y: hms,
343 | }
344 |
345 | if step % self.summary_per == 0:
346 | if self.is_debug:
347 | mean_gradient = self.sess.run(self.mean_gradient, feed_dict=feed_dict)
348 | print('mean_gradient: %.6f ' % mean_gradient)
349 | summary, _, lr, loss, model_ls, reg_ls = self.sess.run(
350 | [self.write_op, self.train_op, self.learning_rate, self.loss, self.model_loss, self.regularization_loss], feed_dict=feed_dict)
351 | print(
352 | 'Epoch: %d / %d Iter: %d / %d Step: %d Loss: %.4f Model Loss: %.4f Reg Loss: %.4f Lr: %f' %
353 | (epoch, self.epoch_size, ite, self.steps_per_period, step, loss, model_ls, reg_ls, lr))
354 | self.summary_writer.add_summary(summary, step)
355 | else:
356 | _, lr, loss, model_ls, reg_ls = self.sess.run(
357 | [self.train_op, self.learning_rate, self.loss, self.model_loss, self.regularization_loss], feed_dict=feed_dict)
358 |
359 | if step % self.save_per == 0:
360 | self.saver.save(
361 | self.sess,
362 | os.path.join(
363 | self.ckpt_path,
364 | self.ckpt_name),
365 | global_step=step)
366 | if step % self.val_per == 0 and step != 0:
367 | # Validation
368 | losses = []
369 | start_time = time.time()
370 | for s in range(self.val_time):
371 | # TODO compute the loss without updating gradients; keep every loss and print the mean at the end
372 | # TODO save a few rendered outputs locally via cv2.circle and cv2.imwrite
373 | imgs_v, hms_v = next(self.val_iterator)
374 | imgs_v = (imgs_v / 127.5) - 1
375 | feed_dict = {
376 | self.inputs_x: imgs_v,
377 | self.inputs_y: hms_v,
378 | }
379 | loss = self.sess.run(self.val_model_loss, feed_dict=feed_dict)
380 | losses.append(loss)
381 | print('Validation %d times in %.3fs mean loss is %f'
382 | % (self.val_time, time.time() - start_time, sum(losses) / len(losses)))
383 | # except Exception as e:
384 | # print(e)
385 | self.saver.save(
386 | self.sess,
387 | os.path.join(
388 | self.ckpt_path,
389 | self.ckpt_name),
390 | global_step=step)
391 | self.summary_writer.close()
392 | self.sess.close()
393 |
394 | def train_launch(self):
395 | # must in order
396 | self.init_dataset()
397 | self.init_inputs()
398 | self.init_model()
399 |
400 | # optional override
401 | self.init_loss()
402 | self.init_learning_rate()
403 | self.init_train_op()
404 | self.init_loader_saver_summary()
405 | self.init_session()
406 | self.train()
407 |
--------------------------------------------------------------------------------
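A note on the BN handling explained inside `init_inputs` above: because `is_training` is a plain Python `True` rather than a placeholder, BN layers are built without `tf.cond`/Switch nodes, and the moving statistics only advance when the `UPDATE_OPS` run under `train_op`. A minimal TF 1.x sketch of that pattern (shapes and optimizer are illustrative):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
# training=True is a Python bool, so no Switch/Merge nodes are created
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)
loss = tf.reduce_mean(tf.square(h))

# moving_mean/moving_variance update only when train_op runs; running
# just `loss` for validation leaves every BN parameter untouched
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
```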
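Likewise, `init_learning_rate` combines a linear warmup with per-epoch staircase exponential decay. The same schedule in plain Python, with hypothetical hyperparameters, for quick inspection:

```python
def lr_at(step, steps_per_period, warmup_epochs, lr_init, lr_warmup, exp_decay):
    warmup_steps = warmup_epochs * steps_per_period
    if step < warmup_steps:
        # linear ramp from lr_warmup up to lr_init
        return lr_warmup + (lr_init - lr_warmup) * step / warmup_steps
    # staircase exponential decay: one decay per epoch
    return lr_init * exp_decay ** (step // steps_per_period)

print(lr_at(0, 1000, 2, 1e-3, 1e-5, 0.95))     # 1e-05 (start of warmup)
print(lr_at(5000, 1000, 2, 1e-3, 1e-5, 0.95))  # 1e-3 * 0.95**5
```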
/data/dataset/readme.txt:
--------------------------------------------------------------------------------
1 | This dir contains different dataset
--------------------------------------------------------------------------------
/data/name/coco.name:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/data/name/coco.name
--------------------------------------------------------------------------------
/demon.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-24
8 | """
9 | import os
10 | import cv2
11 | import time
12 | import tensorflow as tf
13 | from core.infer.visual_utils import get_results, draw_point, draw_skeleton
14 |
15 |
16 | def read_pb(pb_path, input_node_name_and_val, output_node_name):
17 | """
18 | :param pb_path:
19 | :param input_node_name_and_val: {(str) input_node_name: (any) input_node_val}
20 | :param output_node_name: [(str) output_node_name]
21 | :return: [output]
22 | """
23 | with tf.Graph().as_default():
24 | output_graph_def = tf.GraphDef()
25 | with open(pb_path, 'rb') as f:
26 | output_graph_def.ParseFromString(f.read())
27 | tf.import_graph_def(output_graph_def, name='')
28 | config = tf.ConfigProto(allow_soft_placement=True)  # let TF fall back to another device automatically
29 | config.gpu_options.allow_growth = True
30 | with tf.Session(config=config) as sess:
31 | # sess.run(tf.global_variables_initializer())
32 | # Name the input tensors to match the network's input tensors
33 | # input:0 takes the image; keep_prob:0 is the dropout keep probability (1 at test time); is_training:0 is the training flag
34 | feed_dict = {}
35 | for key in input_node_name_and_val:
36 | input_tensor = sess.graph.get_tensor_by_name(key)
37 | feed_dict[input_tensor] = input_node_name_and_val[key]
38 |
39 | # 定义输出的张量名称
40 | output_tensor = []
41 | for name in output_node_name:
42 | output_tensor.append(sess.graph.get_tensor_by_name(name))
43 |
44 | # Sanity-check the loaded model; note that tensor names (with :0) are passed for inputs and outputs, not op names
45 | start_time = time.time()
46 | output = sess.run(output_tensor, feed_dict=feed_dict)
47 | print('Infer time is %.4f' % (time.time() - start_time))
48 | return output
49 |
50 |
51 | if __name__ == '__main__':
52 | import numpy as np
53 | from core.dataset.data_generator import Dataset
54 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
55 | pb_path = 'Hourglass.pb'
56 | # pb_path = 'tensorRT/TensorRT.pb'
57 | img_dir = '/data/dataset/coco/images/val2017'
58 | gt_path = 'data/dataset/coco/coco_val.txt'
59 | batch_size = 8
60 | img_size = (512,512)
61 | hm_size = (128,128)
62 | dataset = Dataset(img_dir, gt_path, batch_size, None, img_size, hm_size)
63 | it = dataset.iterator(4, False)
64 | image, hm = next(it)
65 | image_norm = (image / 127.5) - 1
66 | input_dict = {'Placeholder/inputs_x:0': image_norm}
67 | output_node_name=['Keypoints/keypoint_1/conv/Sigmoid:0']
68 | outputs = read_pb(pb_path, input_dict, output_node_name)
69 | for k in range(len(outputs)):
70 | # outputs[k] = sigmoid(outputs[k])
71 | points = get_results(outputs[k], 0.3)
72 | gt_points = get_results(hm, 0.3)
73 | print(points)
74 | print(gt_points)
75 | for i in range(len(points)):
76 | img = image[i][:, :, ::-1]
77 | for j in range(len(points[i])):
78 | if points[i][j][0] != -1:
79 | points[i][j][0] = int(points[i][j][0]/hm_size[1]*img.shape[1])
80 | if points[i][j][1] != -1:
81 | points[i][j][1] = int(points[i][j][1]/hm_size[0]*img.shape[0])
82 | for j in range(len(gt_points[i])):
83 | if gt_points[i][j][0] != -1:
84 | gt_points[i][j][0] = int(gt_points[i][j][0]/hm_size[1]*img.shape[1])
85 | if gt_points[i][j][1] != -1:
86 | gt_points[i][j][1] = int(gt_points[i][j][1]/hm_size[0]*img.shape[0])
87 |
88 | one_output = np.sum(outputs[k][i], axis=-1, keepdims=True) * 255
89 | tile_output = np.tile(one_output, (1, 1, 3))
90 | tile_img = cv2.resize(tile_output, img_size) + img
91 |
92 | cv2.imwrite('render_img/'+str(i)+'_'+str(k)+'_origin.jpg', img)
93 |
94 |
95 | cv2.imwrite('render_img/'+str(i)+'_'+str(k)+'_hm.jpg', tile_img)
96 |
97 | sk_img = draw_skeleton(img, points[i],'coco')
98 | cv2.imwrite('render_img/' + str(i) + '_' + str(k) + '_skeleton.jpg', sk_img)
99 |
100 | img = draw_skeleton(img, gt_points[i],'coco')
101 | cv2.imwrite('render_img/'+str(i)+'_'+str(k)+'_visible.jpg', img)
102 |
103 | # outputs[k]
104 |
105 |
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
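`get_results` used above lives in core/infer/visual_utils.py and is not included in this dump; from its usage it takes a batch of heatmaps plus a threshold and returns per-channel peak coordinates, with [-1, -1] for peaks below threshold. A hypothetical NumPy equivalent (the real signature and layout may differ):

```python
import numpy as np

def get_results_sketch(heatmaps, threshold):
    """heatmaps: (N, H, W, C) -> per image, per channel [x, y] or [-1, -1]."""
    n, h, w, c = heatmaps.shape
    results = []
    for i in range(n):
        points = []
        for j in range(c):
            hm = heatmaps[i, :, :, j]
            y, x = np.unravel_index(np.argmax(hm), hm.shape)
            points.append([x, y] if hm[y, x] >= threshold else [-1, -1])
        results.append(points)
    return results
```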
/infer_hourglass.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-21
8 | """
9 | from core.infer.infer_utils import read_pb, pred_one_image
10 | from core.infer.visual_utils import draw_point, draw_bbx, draw_skeleton
11 | # image = cv2.imread(img_path)
12 | # # 1. Instantiate the model
13 | # sess, input_tensor, output_tensor = \
14 | # read_pb(pb_path, ['Placeholder/inputs_x:0'], ['HourglassNet/keypoint_1/conv/Sigmoid:0'])
15 | # # 2. Process the image; the crops from one image form one batch
16 | # # bbxes are known in advance: bbxes = [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax]]
17 | # points = pred_one_image(image, bbxes, sess, input_tensor, output_tensor)
18 | # print(points)
19 | # for point in points:
20 | # image = draw_point(image, point)
--------------------------------------------------------------------------------
/output/coco/readme.txt:
--------------------------------------------------------------------------------
1 | directory contains tensorboard files
--------------------------------------------------------------------------------
/output/mpii/readme.txt:
--------------------------------------------------------------------------------
1 | directory contains tensorboard files
--------------------------------------------------------------------------------
/script/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-17
8 | """
9 |
--------------------------------------------------------------------------------
/script/ckpt2ckpt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-11-29
8 | """
9 |
10 | import tensorflow as tf
11 | from core.network.keypoints import Keypoints
12 | from tensorflow.python import pywrap_tensorflow
13 | import config.config_hourglass_coco as cfg
14 | import tensorflow.contrib.slim as slim
15 | import os
16 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
17 |
18 | ori_ckpt = '/data/checkpoints/pre_train/resnet_v2_101.ckpt'
19 | new_ckpt = os.path.join(cfg.ckpt_dir, "Keypoints_coco_resnet_v2_101.ckpt")
20 |
21 |
22 | def change_name(name):
23 | # our network's variable names carry a prefix
24 | return name[19:]
25 |
26 | def restore_name(name):
27 | return "Keypoints/backbone/" + name
28 |
29 | def map_var(name_list):
30 | # {name in the pretrained checkpoint: tensor in our model}
31 | var_list = {}
32 | for name in name_list:
33 | new_name = restore_name(name)
34 | var_list[name] = slim.get_variables_by_name(new_name)[0]
35 | return var_list
36 |
37 | # origin
38 | reader = pywrap_tensorflow.NewCheckpointReader(ori_ckpt)
39 | var_ori = reader.get_variable_to_shape_map()
40 | # network
41 | inputs = tf.placeholder(tf.float32, [1, 512, 512, 3])
42 | centernet = Keypoints(inputs, 80,
43 | num_block=cfg.num_block,
44 | backbone="resnet_v2_101",
45 | num_depth=cfg.num_depth,
46 | residual_dim=cfg.residual_dim,
47 | is_training=True,
48 | is_maxpool=cfg.is_maxpool,
49 | is_nearest=cfg.is_nearest,
50 | reuse=False
51 | )
52 | var_new = slim.get_variables_to_restore()
53 |
54 | # search common
55 | count = 0
56 | omit = 0
57 | all_var = set()
58 | restore_list = []
59 | for key in var_new:
60 | # the naming changed to carry a "CenterNet/"-style prefix, which needs to be stripped
61 | all_var.add(change_name(key.name.strip(':0')))
62 | for key in var_ori:
63 | if key in all_var:
64 | ori_var = reader.get_tensor(key)
65 | new_var = slim.get_variables_by_name(restore_name(key))[0]
66 | s1 = list(ori_var.shape)
67 | s2 = new_var.get_shape().as_list()
68 | if s1 == s2:
69 | count += 1
70 | restore_list.append(key)
71 | else:
72 | omit += 1
73 | else:
74 | omit += 1
75 | print('restore ', count)
76 | print('omit', omit)
77 | print('all', count + omit)
78 | var_list = map_var(restore_list)
79 | # loader = tf.train.Saver(
80 | # var_list=slim.get_variables_to_restore(
81 | # include=restore_list,
82 | # exclude=['logits']))
83 | loader = tf.train.Saver(
84 | var_list=var_list)
85 | saver = tf.train.Saver()
86 | with tf.Session() as sess:
87 | sess.run(tf.global_variables_initializer())
88 | loader.restore(sess, ori_ckpt)
89 | saver.save(sess, new_ckpt)
90 |
--------------------------------------------------------------------------------
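The slice in `change_name` works because `len('Keypoints/backbone/') == 19`, making `change_name` and `restore_name` inverses on backbone variables (non-backbone names get mangled by the slice, but then simply fail to match anything in the pretrained file). A self-documenting sketch of the same mapping:

```python
PREFIX = "Keypoints/backbone/"  # len(PREFIX) == 19, the magic number above

def change_name(name):
    # strip the model prefix to recover the pretrained checkpoint name
    return name[len(PREFIX):] if name.startswith(PREFIX) else None

def restore_name(name):
    return PREFIX + name

assert change_name(restore_name("resnet_v2_101/conv1/weights")) == "resnet_v2_101/conv1/weights"
```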
/script/coco2txt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-12
8 | """
9 | import os
10 | import json
11 | def coco_keypoint2txt(file, txt_path, thre=1):
12 | with open(txt_path, 'w') as writer:
13 | all_count = 0
14 | good_count = 0
15 | print('Transform %s' % (file))
16 | index = {}
17 | data = json.load(open(file))
18 | for ann in data['annotations']:
19 | all_count+=1
20 | st = str(int(ann['bbox'][0])) + ',' + str(int(ann['bbox'][1])) + ',' + str(int(ann['bbox'][0]+ann['bbox'][2])) + ',' + str(int(ann['bbox'][1]+ann['bbox'][3]))+' '
21 | gt = index.get(ann['image_id'], [])
22 | keypoints = ann['keypoints']
23 | # key = []
24 | for i in range(len(keypoints) // 3):
25 | # v=0: not labeled
26 | if keypoints[i * 3 + 2] == 0:
27 | st += '-1,-1' + ' '
28 | # key.append([-1, -1])
29 | # v=1: labeled but not visible
30 | elif keypoints[i * 3 + 2] == 1:
31 | st += str(int(keypoints[i * 3])) + ',' + \
32 | str(int(keypoints[i * 3 + 1])) + ' '
33 | # st += '-1,-1' + ' '
34 | # key.append([keypoints[i * 3], keypoints[i * 3 + 1]])
35 | # v=2: labeled and visible
36 | elif keypoints[i * 3 + 2] == 2:
37 | st += str(int(keypoints[i * 3])) + ',' + \
38 | str(int(keypoints[i * 3 + 1])) + ' '
39 | # key.append([keypoints[i * 3], keypoints[i * 3 + 1]])
40 | else:
41 | st += '-1,-1' + ' '
42 | print('Unsupported keypoint visibility value')
43 | # key.append([-1, -1])
44 | if st.count('-1,-1') <= thre:
45 | good_count += 1
46 | # data cleaning
47 | gt.append(st)
48 | index[ann['image_id']] = gt
49 | # writer.write(ann['image_id']+' '+st+'\n')
50 | for image in data['images']:
51 | if image['id'] in index:
52 | for i in range(len(index[image['id']])):
53 | writer.write(image['file_name'] + ' ' + index[image['id']][i] + '\n')
54 | print('total annotations: %d, written: %d' % (all_count, good_count))
55 |
56 |
57 | if __name__ == '__main__':
58 | dataset = 'coco'
59 |
60 | if dataset == 'coco':
61 | coco_dir = '/data/dataset/coco'
62 | annotations_dir = os.path.join(coco_dir, 'annotations')
63 | annotation_train = os.path.join(
64 | annotations_dir,
65 | 'person_keypoints_train2017.json')
66 | annotation_val = os.path.join(
67 | annotations_dir,
68 | 'person_keypoints_val2017.json')
69 | coco_keypoint2txt(annotation_train, '../data/dataset/coco/coco_train.txt', 10)
70 | coco_keypoint2txt(annotation_val, '../data/dataset/coco/coco_val.txt', 10)
71 |
72 | if dataset == 'mpii':
73 | mpii_dir = '/data/dataset/mpii'
74 | annotations_dir = os.path.join(mpii_dir, 'annotations')
75 | annotation_train = os.path.join(
76 | annotations_dir,
77 | 'train.json')
78 | annotation_val = os.path.join(
79 | annotations_dir,
80 | 'test.json')
81 | coco_keypoint2txt(annotation_train, '../data/dataset/mpii/mpii_train.txt', 1)
82 | coco_keypoint2txt(annotation_val, '../data/dataset/mpii/mpii_val.txt', 1)
83 |
84 |
--------------------------------------------------------------------------------
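Each line that coco2txt.py writes has the form `filename xmin,ymin,xmax,ymax x1,y1 x2,y2 ... x17,y17`, with `-1,-1` for unlabeled keypoints. A small sketch that parses one such line back:

```python
def parse_line(line):
    parts = line.strip().split(' ')
    filename = parts[0]
    bbox = [int(v) for v in parts[1].split(',')]
    # the remaining fields are the keypoints, "-1,-1" when unlabeled
    keypoints = [[int(v) for v in pt.split(',')] for pt in parts[2:]]
    return filename, bbox, keypoints

print(parse_line("000000001.jpg 10,20,110,220 30,40 -1,-1 50,60"))
# ('000000001.jpg', [10, 20, 110, 220], [[30, 40], [-1, -1], [50, 60]])
```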
/script/mpii2coco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from scipy.io import loadmat, savemat
4 | from PIL import Image
5 | import os
6 | import os.path as osp
7 | import numpy as np
8 | import json
9 |
10 | def check_empty(mat, name):
11 | try:
12 | mat[name]
13 | except ValueError:
14 | return True
15 |
16 | if len(mat[name]) > 0:
17 | return False
18 | else:
19 | return True
20 |
21 |
22 | db_type = 'train' # train, test
23 | annot_file = loadmat('../data/dataset/mpii/annotations/annotations.mat')['RELEASE']
24 | save_path = '../data/dataset/mpii/annotations/' + db_type + '.json'
25 |
26 | joint_num = 16
27 | img_num = len(annot_file['annolist'][0][0][0])
28 |
29 | aid = 0
30 | coco = {'images': [], 'categories': [], 'annotations': []}
31 | for img_id in range(img_num):
32 |
33 | if ((db_type == 'train' and annot_file['img_train'][0][0][0][img_id] == 1) or (
34 | db_type == 'test' and annot_file['img_train'][0][0][0][img_id] == 0)) and \
35 | check_empty(annot_file['annolist'][0][0][0][img_id], 'annorect') == False: # any person is annotated
36 |
37 | filename =str(annot_file['annolist'][0][0][0][img_id]['image'][0][0][0][0]) # filename
38 | img = Image.open(osp.join('../data/dataset/mpii/images', filename))
39 | w, h = img.size
40 | img_dict = {
41 | 'id': img_id,
42 | 'file_name': filename,
43 | 'width': w,
44 | 'height': h}
45 | coco['images'].append(img_dict)
46 |
47 | if db_type == 'test':
48 | continue
49 |
50 | person_num = len(annot_file['annolist'][0][0]
51 | [0][img_id]['annorect'][0]) # person_num
52 | joint_annotated = np.zeros((person_num, joint_num))
53 | for pid in range(person_num):
54 |
55 | if check_empty(annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid], 'annopoints') == False: # kps is annotated
56 |
57 | bbox = np.zeros((4)) # xmin, ymin, w, h
58 | kps = np.zeros((joint_num, 3)) # xcoord, ycoord, vis
59 |
60 | # kps
61 | annot_joint_num = len(
62 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0])
63 | for jid in range(annot_joint_num):
64 | annot_jid = \
65 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0][jid][
66 | 'id'][0][0]
67 | kps[annot_jid][0] = \
68 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0][jid][
69 | 'x'][0][0]
70 | kps[annot_jid][1] = \
71 | annot_file['annolist'][0][0][0][img_id]['annorect'][0][pid]['annopoints']['point'][0][0][0][jid][
72 | 'y'][0][0]
73 | kps[annot_jid][2] = 1
74 |
75 | # bbox extract from annotated kps
76 | annot_kps = kps[kps[:, 2] == 1, :].reshape(-1, 3)
77 | xmin = np.min(annot_kps[:, 0])
78 | ymin = np.min(annot_kps[:, 1])
79 | xmax = np.max(annot_kps[:, 0])
80 | ymax = np.max(annot_kps[:, 1])
81 | width = xmax - xmin - 1
82 | height = ymax - ymin - 1
83 |
84 | # corrupted bounding box
85 | if width <= 0 or height <= 0:
86 | continue
87 | # 20% extend
88 | # else:
89 | # bbox[0] = (xmin + xmax) / 2. - width / 2 * 1.2
90 | # bbox[1] = (ymin + ymax) / 2. - height / 2 * 1.2
91 | # bbox[2] = width * 1.2
92 | # bbox[3] = height * 1.2
93 | else:
94 | bbox[0] = max(xmin,0)
95 | bbox[1] = max(ymin,0)
96 | bbox[2] = width
97 | bbox[3] = height
98 |
99 | person_dict = {'id': aid, 'image_id': img_id, 'category_id': 1, 'area': bbox[2] * bbox[3],
100 | 'bbox': bbox.tolist(), 'iscrowd': 0, 'keypoints': kps.reshape(-1).tolist(),
101 | 'num_keypoints': int(np.sum(kps[:, 2] == 1))}
102 | coco['annotations'].append(person_dict)
103 | aid += 1
104 |
105 | category = {
106 | "supercategory": "person",
107 | "id": 1, # to be same as COCO, not using 0
108 | "name": "person",
109 | "skeleton": [[0, 1],
110 | [1, 2],
111 | [2, 6],
112 | [7, 12],
113 | [12, 11],
114 | [11, 10],
115 | [5, 4],
116 | [4, 3],
117 | [3, 6],
118 | [7, 13],
119 | [13, 14],
120 | [14, 15],
121 | [6, 7],
122 | [7, 8],
123 | [8, 9]],
124 | "keypoints": ["r_ankle", "r_knee", "r_hip",
125 | "l_hip", "l_knee", "l_ankle",
126 | "pelvis", "throax",
127 | "upper_neck", "head_top",
128 | "r_wrist", "r_elbow", "r_shoulder",
129 | "l_shoulder", "l_elbow", "l_wrist"]}
130 |
131 | coco['categories'] = [category]
132 |
133 | with open(save_path, 'w') as f:
134 | json.dump(coco, f)
135 |
--------------------------------------------------------------------------------
/script/parse_ckpt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-20
8 | """
9 |
10 | import os
11 | import tensorflow as tf
12 | from tensorflow.python import pywrap_tensorflow
13 | import numpy as np
14 |
15 | # Read data from checkpoint file
16 | # inspect the mean and variance of model variables
17 |
18 |
19 | def parse_ckpt(checkpoint_path):
20 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
21 | var_to_shape_map = reader.get_variable_to_shape_map()
22 | # Print tensor name and values
23 | key2val = {}
24 | keys = []
25 | for key in var_to_shape_map:
26 |
27 | if key.split('/')[-1] in ['weights', 'biases']:
28 | print("tensor_name: ", key)
29 | keys.append(key)
30 | val = reader.get_tensor(key)
31 | key2val[key] = np.array(val)
32 | print(np.sum(reader.get_tensor(key)))
33 | print(np.var(reader.get_tensor(key)))
34 | return keys, key2val
35 |
36 |
37 | def read_origin(path):
38 | org_weights_mess = []
39 | load = tf.train.import_meta_graph(path + '.meta')
40 | with tf.Session() as sess:
41 | load.restore(sess, path)
42 | for var in tf.global_variables():
43 | var_name = var.op.name
44 | var_name_mess = str(var_name).split('/')
45 | var_shape = var.shape
46 | if (var_name_mess[-1] not in ['weights', 'gamma', 'beta', 'moving_mean', 'moving_variance']):
47 | continue
48 | org_weights_mess.append([var_name, var_shape])
49 | print("=> " + str(var_name).ljust(50), var_shape)
50 |
51 | def transform(key, key2val):
52 | for k in key:
53 | print(k)
54 | try:
55 | name=k.replace('HourglassNet','model').replace('backbone','stacks').replace('hourglass','stage')
56 | print(name)
57 |
58 | except Exception:
59 | pass
60 |
61 |
62 | if __name__ == '__main__':
63 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
64 | # ckpt_name = 'hg_refined_200.ckpt-0'
65 | # checkpoint_path = os.path.join('../checkpoints', 'pretrained', ckpt_name)
66 | # _, key2val = parse_ckpt(checkpoint_path)
67 | ckpt_name ='mpii/Hourglass_mpii.ckpt-39000'
68 | checkpoint_path = os.path.join('../checkpoints', ckpt_name)
69 | keys, key2val = parse_ckpt(checkpoint_path)
70 | # transform(key,key2val)
71 |
--------------------------------------------------------------------------------
/tensorRT/c++/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.5)
2 | project(keypoints)
3 | # FindCUDA supplies cuda_add_executable
4 | find_package(CUDA REQUIRED)
5 | cuda_add_executable(keypoints Keypoints_main.cpp ../../../common_source/common_tensorrt/logger.cpp source/keypoints_tensorrt.cpp source/keypoints_tensorrt.h source/utils.h source/utils.cpp source/ResizeNearestNeighbor.cpp source/ResizeNearestNeighbor.h source/my_plugin.h source/my_plugin.cpp source/ResizeNearestNeighbor.cu)
6 | set_property(TARGET keypoints PROPERTY FOLDER project/keypoints)
7 | target_link_libraries(keypoints libnvinfer.so libnvparsers.so libcudart.so libopencv_core.so libopencv_imgproc.so libopencv_imgcodecs.so)
--------------------------------------------------------------------------------
/tensorRT/c++/Keypoints_main.cpp:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2019/12/18
3 | #include "source/keypoints_tensorrt.h"
4 | #include <iostream>
5 | #include "source/utils.h"
6 | #include "source/my_plugin.h"
7 |
8 | const std::string project_name = "TensorRT_Keypoints";
9 | void printHelpInfo()
10 | {
11 | std::cout << "Usage: ./keypoints [-h or --help] [-d or "
12 | "--datadir=] [--useDLACore=]\n";
13 | std::cout << "--help Display help information\n";
14 | std::cout << "--datadir Specify path to a data directory, overriding "
15 | "the default. This option can be used multiple times to add "
16 | "multiple directories. If no data directories are given, the "
17 | "default is to use (data/samples/mnist/, data/mnist/)"
18 | << std::endl;
19 | std::cout << "--useDLACore=N Specify a DLA engine for layers that support "
20 | "DLA. Value can range from 0 to n-1, where n is the number of "
21 | "DLA engines on the platform."
22 | << std::endl;
23 | std::cout << "--int8 Run in Int8 mode.\n";
24 | std::cout << "--fp16 Run in FP16 mode." << std::endl;
25 | }
26 |
27 | samplesCommon::UffSampleParams initial_params(const samplesCommon::Args &args){
28 | samplesCommon::UffSampleParams params;
29 | if (args.dataDirs.empty()){
30 | params.dataDirs.push_back("/work/tensorRT/project/Template/Keypoints/data/images/");
31 | }
32 | else //!< Use the data directory provided by the user
33 | {
34 | params.dataDirs = args.dataDirs;
35 | }
36 | params.uffFileName = "/work/tensorRT/project/Template/Keypoints/data/uff/keypoints.uff";
37 | params.inputTensorNames.push_back("Placeholder/inputs_x");
38 | params.batchSize = 1;
39 | params.outputTensorNames.push_back("Keypoints/keypoint_1/conv/Sigmoid");
40 | params.dlaCore = args.useDLACore;
41 | // params.int8 = args.runInInt8;
42 | params.int8 = false;
43 | // params.fp16 = args.runInFp16;
44 | params.fp16 = false;
45 | return params;
46 | }
47 |
48 | int main(int argc, char **argv){
49 | REGISTER_TENSORRT_PLUGIN(MyPlugin);
50 |
51 | samplesCommon::Args args;
52 | if (!samplesCommon::parseArgs(args, argc, argv)){
53 | gLogError << "Invalid arguments" << std::endl;
54 | printHelpInfo();
55 | return EXIT_FAILURE;
56 | }
57 | if (args.help)
58 | {
59 | printHelpInfo();
60 | return EXIT_SUCCESS;
61 | }
62 | auto sampleTest = Logger::defineTest(project_name, argc, argv);
63 | Logger::reportTestStart(sampleTest);
64 | samplesCommon::UffSampleParams params = initial_params(args);
65 | InputParams input_params(512, 512, 3, 128, 128, 17);
66 | Keypoints keypoints(params, input_params);
67 | gLogInfo << "Building and running a GPU inference engine for " << project_name
68 | << std::endl;
69 | if (!keypoints.build())
70 | {
71 | return Logger::reportFail(sampleTest);
72 | }
73 | gLogInfo << "Begine to Infer"
74 | << std::endl;
75 | if (!keypoints.infer())
76 | {
77 | return Logger::reportFail(sampleTest);
78 | }
79 | gLogInfo << "Destroy the engine"
80 | << std::endl;
81 | if (!keypoints.tearDown())
82 | {
83 | return Logger::reportFail(sampleTest);
84 | }
85 | return Logger::reportPass(sampleTest);
86 | }
--------------------------------------------------------------------------------
/tensorRT/c++/data/images/1_0_origin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/1_0_origin.jpg
--------------------------------------------------------------------------------
/tensorRT/c++/data/images/1_0_origin_render.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/1_0_origin_render.jpg
--------------------------------------------------------------------------------
/tensorRT/c++/data/images/7_0_origin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/7_0_origin.jpg
--------------------------------------------------------------------------------
/tensorRT/c++/data/images/7_0_origin_render.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Syencil/Keypoints/b0125a7a2fc20cb72f29f1ac771b7f9a54eeccf2/tensorRT/c++/data/images/7_0_origin_render.jpg
--------------------------------------------------------------------------------
/tensorRT/c++/source/ResizeNearestNeighbor.cpp:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/11
3 |
4 | #include "ResizeNearestNeighbor.h"
5 |
6 | // <============ 构造函数 ===========>
7 | UffUpSamplePluginV2::UffUpSamplePluginV2(const nvinfer1::PluginFieldCollection &fc, float scale): mScale(scale){
8 |
9 | }
10 |
11 | UffUpSamplePluginV2::UffUpSamplePluginV2(const void *data, size_t length){
12 | if (data== nullptr){
13 | printf("nullptr\n");
14 | }
15 | const char *d = static_cast<const char *>(data);
16 | const char* const start = d;
17 | mCHW = read<nvinfer1::Dims>(d);
18 | mDataType = read<nvinfer1::DataType>(d);
19 | mScale = read<float>(d);
20 | mOutputHeight = read<int>(d);
21 | mOutputWidth = read<int>(d);
22 | if (mDataType == nvinfer1::DataType::kINT8){
23 | mInHostScale = read<float>(d);
24 | mOutHostScale = read<float>(d);
25 | }
26 | assert(d == start + length);
27 |
28 | }
29 |
30 |
31 | // <============ IPluginV2 ===========>
32 | const char *UffUpSamplePluginV2::getPluginType() const {
33 | // must match IPluginCreator::getPluginName()
34 | return "ResizeNearestNeighbor";
35 | }
36 |
37 | const char *UffUpSamplePluginV2::getPluginVersion() const {
38 | // must match IPluginCreator::getPluginVersion()
39 | return "2";
40 | }
41 |
42 | int UffUpSamplePluginV2::getNbOutputs() const {
43 | return 1;
44 | }
45 |
46 | nvinfer1::Dims
47 | UffUpSamplePluginV2::getOutputDimensions(int index, const nvinfer1::Dims *inputs_dims, int number_input_dims) {
48 | assert(number_input_dims==1);
49 | assert(index == 0);
50 | assert(inputs_dims[0].nbDims==3);
51 | mCHW = inputs_dims[0];
52 | mOutputHeight = inputs_dims[0].d[1] * mScale;
53 | mOutputWidth = inputs_dims[0].d[2] * mScale;
54 | return nvinfer1::Dims3(mCHW.d[0], mOutputHeight, mOutputWidth);
55 | }
56 |
57 | int UffUpSamplePluginV2::initialize() {
58 | // memory can be allocated here
59 | int input_height = mCHW.d[1];
60 | int input_width = mCHW.d[2];
61 | if (mOutputHeight == int(input_height * mScale) && mOutputWidth == int(input_width * mScale)){
62 | return 0;
63 | } else{
64 | return 1;
65 | }
66 | }
67 |
68 | void UffUpSamplePluginV2::terminate() {
69 | // memory can be released here
70 | }
71 |
72 | size_t UffUpSamplePluginV2::getWorkspaceSize(int max_batch_size) const {
73 | // the maximum workspace this layer needs, determined from max_batch_size
74 | return 0;
75 | }
76 |
77 | size_t UffUpSamplePluginV2::getSerializationSize() const {
78 | size_t serialization_size = 0;
79 | serialization_size += sizeof(nvinfer1::Dims);
80 | serialization_size += sizeof(nvinfer1::DataType);
81 | serialization_size += sizeof(float);
82 | serialization_size += sizeof(int) * 2;
83 | if (mDataType == nvinfer1::DataType::kINT8){
84 | serialization_size += sizeof(float) * 2;
85 | }
86 | return serialization_size;
87 | }
88 |
89 | void UffUpSamplePluginV2::serialize(void *buffer) const {
90 | char *d = static_cast<char *>(buffer);
91 | const char* const start = d;
92 | printf("serialize mScale %f\n", mScale);
93 | write(d, mCHW);
94 | write(d, mDataType);
95 | write(d, mScale);
96 | write(d, mOutputHeight);
97 | write(d, mOutputWidth);
98 | if (mDataType == nvinfer1::DataType::kINT8){
99 | write(d, mInHostScale);
100 | write(d, mOutHostScale);
101 | }
102 | assert(d == start + getSerializationSize());
103 | }
104 |
105 | void UffUpSamplePluginV2::destroy() {
106 | delete this;
107 | }
108 |
109 | void UffUpSamplePluginV2::setPluginNamespace(const char *plugin_namespace) {
110 | mNameSpace = plugin_namespace;
111 | }
112 |
113 | const char *UffUpSamplePluginV2::getPluginNamespace() const {
114 | return mNameSpace.data();
115 | }
116 |
117 |
118 | // <============ IPluginV2Ext ===========>
119 | nvinfer1::DataType
120 | UffUpSamplePluginV2::getOutputDataType(int index, const nvinfer1::DataType *input_types, int num_inputs) const {
121 | assert(index==0);
122 | assert(input_types!= nullptr);
123 | assert(num_inputs==1);
124 | return input_types[index];
125 | }
126 |
127 | bool UffUpSamplePluginV2::isOutputBroadcastAcrossBatch(int output_index, const bool *input_is_broadcasted,
128 | int num_inputs) const {
129 | return false;
130 | }
131 |
132 | bool UffUpSamplePluginV2::canBroadcastInputAcrossBatch(int input_idx) const {
133 | return false;
134 | }
135 |
136 | nvinfer1::IPluginV2Ext *UffUpSamplePluginV2::clone() const {
137 | auto *plugin = new UffUpSamplePluginV2(*this);
138 | return plugin;
139 | }
140 |
141 |
142 | // <============ IPluginV2IOExt ===========>
143 | void UffUpSamplePluginV2::configurePlugin(const nvinfer1::PluginTensorDesc *plugin_tensor_desc_input, int num_input,
144 | const nvinfer1::PluginTensorDesc *plugin_tensor_desc_output, int num_output) {
145 | assert(num_input==1 && plugin_tensor_desc_input!= nullptr);
146 | assert(num_output==1 && plugin_tensor_desc_output != nullptr);
147 | assert(plugin_tensor_desc_input[0].type == plugin_tensor_desc_output[0].type);
148 | assert(plugin_tensor_desc_input[0].format == nvinfer1::TensorFormat::kLINEAR);
149 | assert(plugin_tensor_desc_output[0].format == nvinfer1::TensorFormat::kLINEAR);
150 |
151 | mInHostScale = plugin_tensor_desc_input->scale;
152 | mOutHostScale = plugin_tensor_desc_output->scale;
153 |
154 | mDataType = plugin_tensor_desc_input[0].type;
155 | }
156 |
157 | bool UffUpSamplePluginV2::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *plugin_tensor_desc_in_out, int num_inputs,
158 | int num_outputs) const {
159 | assert(plugin_tensor_desc_in_out != nullptr);
160 | assert(num_inputs == 1 && num_outputs == 1);
161 | assert(pos < num_inputs + num_outputs);
162 | bool condition = true;
163 | condition &= plugin_tensor_desc_in_out[pos].format == nvinfer1::TensorFormat::kLINEAR;
164 | condition &= plugin_tensor_desc_in_out[pos].type != nvinfer1::DataType::kINT32;
165 | condition &= plugin_tensor_desc_in_out[pos].type == plugin_tensor_desc_in_out[0].type;
166 | return condition;
167 | }
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
--------------------------------------------------------------------------------
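`serialize()` and the deserializing constructor above must agree on one fixed field order: Dims, DataType, scale, output height, output width, plus the two host scales in INT8 mode. A Python sketch of the same size bookkeeping (field sizes illustrative; the C++ side uses `sizeof` on the real types):

```python
import struct

# writer and reader must walk this layout in the same order
LAYOUT = [("chw", "3i"), ("dtype", "i"), ("scale", "f"),
          ("out_h", "i"), ("out_w", "i")]

def serialized_size(int8_mode):
    size = sum(struct.calcsize(fmt) for _, fmt in LAYOUT)
    if int8_mode:
        size += struct.calcsize("2f")  # in/out host scales
    return size

print(serialized_size(False), serialized_size(True))
```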
/tensorRT/c++/source/ResizeNearestNeighbor.cu:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/12
3 | #include <cstdio>
4 | #include <cstdlib>
5 | #include <cuda_runtime.h>
6 | #include <cuda_fp16.h>
7 |
8 | #include "ResizeNearestNeighbor.h"
9 |
10 | static void HandleError(cudaError_t err, const char *file, int line ) {
11 | if (err != cudaSuccess) {
12 | printf( "%s in %s at line %d\n", cudaGetErrorString( err ),
13 | file, line );
14 | exit( EXIT_FAILURE );
15 | }
16 | }
17 | #define HANDLE_ERROR( err ) (HandleError( err, __FILE__, __LINE__ ))
18 |
19 | __device__ int transform_idx(int idx, int C, int H, int W, float scale_factor){
20 | // decode idx from the innermost dimension outwards
21 | // idx = n*C*H*W + c*H*W + h*W + w
22 | int w = idx % W;
23 | idx /= W;
24 | int h = idx % H;
25 | idx /= H;
26 | int c = idx % C;
27 | idx /= C;
28 | w /= scale_factor;
29 | h /= scale_factor;
30 | int hh = H / scale_factor;
31 | int ww = W / scale_factor;
32 | return idx * C * hh * ww + c * hh * ww + h * ww + w;
33 | }
34 |
35 |
36 | template <typename Dtype>
37 | __global__ void UpSampleKernel(const Dtype *input, Dtype *output, int num_element, float scale_factor, int C, int H, int W){
38 | int tid = threadIdx.x + blockIdx.x * blockDim.x;
39 | if (tid < num_element){
40 | int idx = transform_idx(tid, C, H, W, scale_factor);
41 | output[tid]=input[idx];
42 | }
43 | }
44 |
45 | template <typename Dtype>
46 | void UffUpSamplePluginV2::forwardGpu(const Dtype *input, Dtype *output, int N, int C, int H, int W, cudaStream_t stream) {
47 | int num_element = N * C * H * W;
48 | UpSampleKernel<<<(num_element-1)/mThreadNum+1, mThreadNum, 0, stream>>>(input, output, num_element, mScale, C, H, W);
49 | }
50 |
51 | size_t get_size(nvinfer1::DataType dataType){
52 | switch(dataType){
53 | case nvinfer1::DataType::kFLOAT :
54 | return sizeof(float);
55 | case nvinfer1::DataType::kHALF :
56 | return sizeof(__half);
57 | case nvinfer1::DataType::kINT8 :
58 | return sizeof(int8_t);
59 | default:
60 | throw "Unsupported Data Type";
61 | }
62 | }
63 |
64 | int UffUpSamplePluginV2::enqueue(int batch_size, const void *const *inputs, void **outputs, void *workspace,
65 | cudaStream_t stream) {
66 | const int channel = mCHW.d[0];
67 | const int input_h = mCHW.d[1];
68 | const int input_w = mCHW.d[2];
69 | const int output_h = mOutputHeight;
70 | const int output_w = mOutputWidth;
71 | int total_element = batch_size * channel * input_h * input_w;
72 | if (input_h == output_h && input_w == output_w){
73 | HANDLE_ERROR(cudaMemcpyAsync(outputs[0], inputs[0], get_size(mDataType) * total_element, cudaMemcpyDeviceToDevice, stream));
74 | HANDLE_ERROR(cudaStreamSynchronize(stream));
75 | return 0;
76 | }
77 | switch (mDataType){
78 | case nvinfer1::DataType::kFLOAT :
79 | forwardGpu<float>((const float *)inputs[0], (float *)outputs[0], batch_size, channel, output_h, output_w, stream);
80 | break;
81 | case nvinfer1::DataType::kHALF :
82 | forwardGpu<__half>((const __half *)inputs[0], (__half *)outputs[0], batch_size, channel, output_h, output_w, stream);
83 | break;
84 | case nvinfer1::DataType::kINT8 :
85 | forwardGpu<int8_t>((const int8_t *)inputs[0], (int8_t *)outputs[0], batch_size, channel, output_h, output_w, stream);
86 | break;
87 | default:
88 | throw "Unsupported Data Type";
89 | }
90 | return 0;
91 | }
--------------------------------------------------------------------------------
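The kernel above is written gather-style: every *output* element computes which *input* element it copies from, so no two threads write the same location and the launch over `N*C*out_h*out_w` elements is race-free. A NumPy check of the same index arithmetic (integer scale, NCHW layout):

```python
import numpy as np

def transform_idx(idx, C, H, W, scale):
    # H, W are *output* dims; map an output index back to its input source
    w = idx % W; idx //= W
    h = idx % H; idx //= H
    c = idx % C; n = idx // C
    hh, ww = H // scale, W // scale          # input dims
    return ((n * C + c) * hh + h // scale) * ww + w // scale

x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)   # NCHW input
up = x.repeat(2, axis=2).repeat(2, axis=3)         # nearest-neighbor x2
flat = x.ravel()
out = np.array([flat[transform_idx(i, 3, 8, 8, 2)] for i in range(up.size)])
assert (out == up.ravel()).all()
```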
/tensorRT/c++/source/ResizeNearestNeighbor.h:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/11
3 |
4 | #ifndef TENSORRT_RESIZENEARESTNEIGHBOR_H
5 | #define TENSORRT_RESIZENEARESTNEIGHBOR_H
6 |
7 | #include <NvInfer.h>
8 | #include <string>
9 |
10 | #include "utils.h"
11 |
12 | class UffUpSamplePluginV2 : public nvinfer1::IPluginV2IOExt{
13 | private:
14 | nvinfer1::Dims mCHW;
15 | nvinfer1::DataType mDataType;
16 | float mScale;
17 | int mOutputHeight;
18 | int mOutputWidth;
19 |
20 | float mInHostScale{-1.0};
21 | float mOutHostScale{-1.0};
22 |
23 | std::string mNameSpace;
24 | const int mThreadNum = sizeof(unsigned long long) * 8; // 64 threads per block
25 | public:
26 | UffUpSamplePluginV2(const nvinfer1::PluginFieldCollection &fc, float scale=2.0);
27 | UffUpSamplePluginV2(const void *data, size_t length);
28 | // IPluginV2
29 | const char* getPluginType () const override;
30 | const char *getPluginVersion () const override;
31 | int getNbOutputs () const override;
32 | nvinfer1::Dims getOutputDimensions (int index, const nvinfer1::Dims *inputs_dims, int number_input_dims) override;
33 | int initialize() override;
34 | void terminate () override;
35 | size_t getWorkspaceSize (int max_batch_size) const override;
36 | int enqueue (int batch_size, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream) override;
37 | size_t getSerializationSize () const override;
38 | void serialize (void *buffer) const override;
39 | void destroy () override;
40 | void setPluginNamespace (const char *plugin_namespace) override;
41 | const char *getPluginNamespace () const override;
42 |
43 | // IPluginV2Ext
44 | nvinfer1::DataType getOutputDataType (int index, const nvinfer1::DataType *input_types, int num_inputs) const override;
45 | bool isOutputBroadcastAcrossBatch (int output_index, const bool *input_is_broadcasted, int num_inputs) const override;
46 | bool canBroadcastInputAcrossBatch (int input_idx) const override;
47 | IPluginV2Ext * clone () const override;
48 |
49 | // IPluginV2IOExt
50 | void configurePlugin (const nvinfer1::PluginTensorDesc *plugin_tensor_desc_input, int num_input, const nvinfer1::PluginTensorDesc *plugin_tensor_desc_output, int num_output) override;
51 | bool supportsFormatCombination (int pos, const nvinfer1::PluginTensorDesc *inOut, int num_inputs, int num_outputs) const override;
52 |
53 | // Extension
54 | template <typename Dtype>
55 | void forwardGpu(const Dtype *input, Dtype *output, int N, int C, int H, int W, cudaStream_t stream);
56 | };
57 |
58 |
59 | #endif //TENSORRT_RESIZENEARESTNEIGHBOR_H
60 |
--------------------------------------------------------------------------------
/tensorRT/c++/source/keypoints_tensorrt.cpp:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/7
3 |
4 | #include "keypoints_tensorrt.h"
5 |
6 | bool Keypoints::constructNetwork(Keypoints::sample_unique_ptr<nvuffparser::IUffParser> &parser,
7 | Keypoints::sample_unique_ptr<nvinfer1::INetworkDefinition> &network) {
8 | assert(uff_params.inputTensorNames.size() == 1);
9 | assert(uff_params.outputTensorNames.size() == 1);
10 | if (!parser -> registerInput(uff_params.inputTensorNames[0].c_str(), nvinfer1::Dims3(image_c, image_h, image_w), nvuffparser::UffInputOrder::kNCHW)){
11 | gLogError << "Register Input Failed!" << std::endl;
12 | return false;
13 | }
14 | if (!parser -> registerOutput(uff_params.outputTensorNames[0].c_str())){
15 | gLogError << "Register Output Failed!" << std::endl;
16 | return false;
17 | }
18 | if (!parser -> parse(uff_params.uffFileName.c_str(), *network, nvinfer1::DataType::kFLOAT)){
19 | gLogError << "Parse Uff Failed!" << std::endl;
20 | return false;
21 | }
22 | if (uff_params.int8){
23 | samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
24 | }
25 | return true;
26 | }
27 |
28 | bool Keypoints::processInput(const samplesCommon::BufferManager &buffer_manager, const std::string &input_tensor_name,
29 | const std::string &image_path) const {
30 | const int input_h = input_dims.d[1];
31 | const int input_w = input_dims.d[2];
32 | std::vector<uint8_t> file_data = imagePreprocess(image_path, image_h, image_w);
33 | if (file_data.size() != input_h * input_w * image_c){
34 | gLogError << "FileData size is " << file_data.size() << " but " << input_h * input_w * image_c << " is needed!" << std::endl;
35 | return false;
36 | }
37 | auto *host_input_buffer = static_cast<float *>(buffer_manager.getHostBuffer(input_tensor_name));
38 | for (int i = 0; i < input_h * input_w * image_c; ++i){
39 | host_input_buffer[i] = static_cast<float>(file_data[i]) / 128.0 - 1;
40 | }
41 | return true;
42 | }
43 |
44 | std::vector> Keypoints::processOutput(const samplesCommon::BufferManager &buffer_manager, const std::string &output_tensor_name) const {
45 | auto *origin_output = static_cast<float *>(buffer_manager.getHostBuffer(output_tensor_name));
46 | gLogInfo<< "Output: "<< std::endl;
47 | // Keypoint index transformation idx_x, idx_y, prob
48 | std::vector> keypoints;
49 | for (int c = 0; c < heatmap_c; ++c){
50 | std::vector keypoint;
51 | int max_idx = -1;
52 | float max_prob = -1;
53 | // for (int idx = heatmap_h * heatmap_w * c; idx < heatmap_h * heatmap_w * (c + 1); ++idx){
54 | // if (origin_output[idx] > max_prob){
55 | // max_idx = idx;
56 | // max_prob = origin_output[idx];
57 | // }
58 | // }
59 | // keypoint.push_back(static_cast<float>(max_idx % heatmap_w) / heatmap_w);
60 | // keypoint.push_back(static_cast<float>((max_idx / heatmap_w) % heatmap_h) / heatmap_h);
61 | // oddly, the input is registered as kNCHW but the output comes back as kNHWC
62 | for (int idx = c; idx < heatmap_c * heatmap_h * heatmap_w; idx+=heatmap_c){
63 | if (origin_output[idx] > max_prob){
64 | max_idx = idx;
65 | max_prob = origin_output[idx];
66 | }
67 | }
68 | keypoint.push_back(static_cast<float>(max_idx / heatmap_c % heatmap_w) / heatmap_w);
69 | keypoint.push_back(static_cast<float>((max_idx / heatmap_c) / heatmap_w) / heatmap_h);
70 |
71 | keypoint.push_back(max_prob);
72 | keypoints.push_back(keypoint);
73 | }
74 | for (int c = 0; c < heatmap_c; c++){
75 | gLogInfo << "channel "<< c << " ==> x : "<< keypoints[c][0] << " y : " << keypoints[c][1] << " prob : " << keypoints[c][2]<< std::endl;
76 | }
77 | return keypoints;
78 | }
79 |
80 | Keypoints::Keypoints(samplesCommon::UffSampleParams params, InputParams input_params) : uff_params(std::move(params)), image_h(input_params.image_h), image_w(input_params.image_w), image_c(input_params.image_c), heatmap_h(input_params.heatmap_h), heatmap_w(input_params.heatmap_w), heatmap_c(input_params.heatmap_c){
81 | gLogInfo << "Keypoints Construction" << std::endl;
82 | }
83 |
84 | bool Keypoints::build() {
85 | auto builder = sample_unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger.getTRTLogger()));
86 | if (!builder){
87 | gLogError << "Create Builder Failed" << std::endl;
88 | return false;
89 | }
90 | auto network = sample_unique_ptr<nvinfer1::INetworkDefinition>(builder -> createNetworkV2(0U));
91 | if (!network){
92 | gLogError << "Create Network Failed" << std::endl;
93 | return false;
94 | }
95 | auto parser = sample_unique_ptr<nvuffparser::IUffParser>(nvuffparser::createUffParser());
96 | if (!parser){
97 | gLogError << "Create Parser Failed" << std::endl;
98 | return false;
99 | }
100 | if (!constructNetwork(parser, network)){
101 | gLogError << "Construct Network Failed" << std::endl;
102 | return false;
103 | }
104 |
105 | // configure the builder
106 | builder -> setMaxBatchSize(1);
107 | auto config = sample_unique_ptr<nvinfer1::IBuilderConfig>(builder -> createBuilderConfig());
108 | if (!config){
109 | gLogError << "Create Config Failed" << std::endl;
110 | return false;
111 | }
112 | config -> setMaxWorkspaceSize(1_GiB);
113 | config -> setFlag(BuilderFlag::kGPU_FALLBACK); // 可以使用DLA加速
114 |
115 | if (uff_params.fp16){
116 | config -> setFlag(BuilderFlag::kFP16);
117 | }
118 | if (uff_params.int8){
119 | config -> setFlag(BuilderFlag::kINT8);
120 | }
121 | samplesCommon::enableDLA(builder.get(), config.get(), uff_params.dlaCore, true);
122 | cuda_engine = std::shared_ptr<nvinfer1::ICudaEngine>(builder -> buildEngineWithConfig(*network, *config), samplesCommon::InferDeleter());
123 | if (!cuda_engine){
124 | gLogError << "Create Config Failed" << std::endl;
125 | return false;
126 | }
127 |
128 | assert(network -> getNbInputs() == 1);
129 | assert(network -> getNbOutputs() == 1);
130 | input_dims = network -> getInput(0) ->getDimensions();
131 | assert(input_dims.nbDims == 3);
132 |
133 | gLogInfo << "Build Network Success!" << std::endl;
134 | return true;
135 | }
136 |
137 | bool Keypoints::infer() {
138 | samplesCommon::BufferManager buffer_manager(cuda_engine, uff_params.batchSize);
139 | auto context = sample_unique_ptr<nvinfer1::IExecutionContext>(cuda_engine -> createExecutionContext());
140 | if (!context){
141 | gLogError << "Create Context Failed" << std::endl;
142 | return false;
143 | }
144 | // fetch all images under the data directory
145 | std::vector<std::string> images;
146 | DIR *dir = opendir(uff_params.dataDirs[0].c_str());
147 | dirent *p = nullptr;
148 | gLogInfo << "Fetch images in " << uff_params.dataDirs[0] << std::endl;
149 | float total = 0;
150 | int count = 0;
151 | while ((p = readdir(dir)) != nullptr){
152 | if ('.' != p -> d_name[0] && (strstr(p -> d_name, ".jpg") || strstr(p -> d_name, "png"))){
153 | std::string imagePath = uff_params.dataDirs[0]+"/"+p->d_name;
154 | gLogInfo << "--Image : " << p -> d_name << std::endl;
155 | if (!processInput(buffer_manager, uff_params.inputTensorNames[0], imagePath)){
156 | gLogError << "Process Input Failed!" << std::endl;
157 | return false;
158 | }
159 | buffer_manager.copyInputToDevice();
160 | // time the synchronous execution
161 | auto t_start = std::chrono::high_resolution_clock::now();
162 | if (!context -> execute(uff_params.batchSize, buffer_manager.getDeviceBindings().data())){
163 | gLogError << "Execute Failed!" << std::endl;
164 | return false;
165 | }
166 | auto t_end = std::chrono::high_resolution_clock::now();
167 | // elapsed milliseconds for this image
168 | float elapsed_time =
169 | std::chrono::duration<float, std::milli>(t_end - t_start).count();
170 | total += elapsed_time;
171 | buffer_manager.copyOutputToHost();
172 |
173 | // render the outputs
174 | std::vector<std::vector<float>> keypoints;
175 | keypoints = processOutput(buffer_manager, uff_params.outputTensorNames[0]);
176 | cv::Mat ori_img = cv::imread(imagePath, cv::IMREAD_COLOR);
177 | cv::Mat render_img = renderKeypoint(ori_img, keypoints, heatmap_c, 0.3);
178 | saveImage(render_img, imagePath.insert(imagePath.length() - 4, "_render"));
179 |
180 | ++count;
181 | }
182 | }
183 | closedir(dir);
184 | gLogInfo<< "Total run time is " << total <<" ms\n";
185 | gLogInfo<< "Average over " << count << " files run time is "<
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include "utils.h"
13 |
14 |
15 | class Keypoints{
16 | public:
17 | template <typename T>
18 | using sample_unique_ptr = std::unique_ptr<T, samplesCommon::InferDeleter>;
19 | private:
20 | std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine{nullptr};
21 | samplesCommon::UffSampleParams uff_params;
22 | nvinfer1::Dims input_dims;
23 | const int image_h;
24 | const int image_w;
25 | const int image_c;
26 | const int heatmap_h;
27 | const int heatmap_w;
28 | const int heatmap_c;
29 |
30 | // Parse the UFF model and populate the network with weights
31 | bool constructNetwork(sample_unique_ptr<nvuffparser::IUffParser> &parser, sample_unique_ptr<nvinfer1::INetworkDefinition> &network);
32 | // Process the input: read an image and copy it into the host buffer
33 | bool processInput(const samplesCommon::BufferManager &buffer_manager, const std::string &input_tensor_name, const std::string &image_path) const;
34 | // Post-process the output heatmaps into the final keypoints
35 | std::vector<std::vector<float>> processOutput(const samplesCommon::BufferManager &buffer_manager, const std::string &output_tensor_name) const;
36 |
37 | public:
38 | explicit Keypoints(samplesCommon::UffSampleParams uff_params, InputParams input_params);
39 | bool build();
40 | bool infer();
41 | bool tearDown();
42 | };
43 |
44 | #endif //TENSORRT_METER_TENSORRT_H
45 |
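Since this header only declares the interface, here is a minimal driver sketch. The input node name, image/heatmap shapes, and paths below are assumptions for illustration; the output tensor name is the one exported in tensorRT/python/pb2uff.py, and the parameter struct follows samplesCommon::UffSampleParams from the TensorRT samples.

```cpp
// Hypothetical driver for the Keypoints class; adjust names and shapes to your model.
#include "keypoints_tensorrt.h"

int main(){
    samplesCommon::UffSampleParams params;
    params.uffFileName = "Hourglass.uff";
    params.inputTensorNames.push_back("Placeholder");  // assumed input node name
    params.outputTensorNames.push_back("Keypoints/keypoint_1/conv/Sigmoid");
    params.dataDirs.push_back("data/images");
    params.batchSize = 1;
    params.dlaCore = -1;  // -1 disables DLA
    params.fp16 = false;
    params.int8 = false;

    // assumed: 256x256x3 input, 64x64 heatmaps with 17 COCO keypoints
    InputParams input_params(256, 256, 3, 64, 64, 17);
    Keypoints keypoints(params, input_params);
    if (!keypoints.build() || !keypoints.infer()){
        return 1;
    }
    return keypoints.tearDown() ? 0 : 1;
}
```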
--------------------------------------------------------------------------------
/tensorRT/c++/source/my_plugin.cpp:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/11
3 |
4 | #include "my_plugin.h"
5 |
6 | MyPlugin::~MyPlugin() {
7 | for (auto& item : mPluginUpSample){
8 | item.reset();
9 | }
10 | }
11 |
12 |
13 | const char *MyPlugin::getPluginName() const {
14 | return "ResizeNearestNeighbor";
15 | }
16 |
17 | const char *MyPlugin::getPluginVersion() const {
18 | return "2";
19 | }
20 |
21 | const nvinfer1::PluginFieldCollection* MyPlugin::getFieldNames() {
22 | // TODO: the PluginFieldCollection should be built from the plugin's parameters here
23 | return &mFieldCollection;
24 | }
25 |
26 | nvinfer1::IPluginV2 *MyPlugin::createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) {
27 | if(strcmp(name, "ResizeNearestNeighbor") != 0){
28 | printf("Unknown Plugin Name %s", name);
29 | return nullptr;
30 | }
31 | mPluginUpSample.emplace_back(std::unique_ptr<UffUpSamplePluginV2>(new UffUpSamplePluginV2(*fc)));
32 | return mPluginUpSample.back().get();
33 |
34 | }
35 |
36 | nvinfer1::IPluginV2 *MyPlugin::deserializePlugin(const char *name, const void *serial_data, size_t serial_length) {
37 | auto plugin = new UffUpSamplePluginV2(serial_data, serial_length);
38 | mPluginName = name;
39 | return plugin;
40 | }
41 |
42 | void MyPlugin::setPluginNamespace(const char *plugin_name_space) {
43 | mNamespace = plugin_name_space;
44 | }
45 |
46 | const char *MyPlugin::getPluginNamespace() const {
47 | return mNamespace.c_str();
48 | }
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/tensorRT/c++/source/my_plugin.h:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/11
3 |
4 | #ifndef TENSORRT_MY_PLUGIN_H
5 | #define TENSORRT_MY_PLUGIN_H
6 |
7 | #include "ResizeNearestNeighbor.h"
8 | #include <memory>
9 | #include <vector>
10 |
11 | class MyPlugin : public nvinfer1::IPluginCreator {
12 | private:
13 | std::string mNamespace;
14 | std::string mPluginName;
15 | nvinfer1::PluginFieldCollection mFieldCollection{0, nullptr};
16 | std::vector<std::unique_ptr<UffUpSamplePluginV2>> mPluginUpSample{};
17 | public:
18 | const char* getPluginName() const override;
19 | const char* getPluginVersion() const override;
20 | const nvinfer1::PluginFieldCollection *getFieldNames() override;
21 | nvinfer1::IPluginV2* createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) override;
22 | nvinfer1::IPluginV2* deserializePlugin(const char *name, const void *serial_data, size_t serial_length) override;
23 | void setPluginNamespace (const char *plugin_name_space) override;
24 | const char* getPluginNamespace() const override;
25 | ~MyPlugin();
26 | };
27 |
28 |
29 | #endif //TENSORRT_MY_PLUGIN_H
30 |
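Nothing in this pair of files registers the creator with TensorRT's plugin registry, so the UFF parser cannot yet resolve the ResizeNearestNeighbor op by name. A minimal sketch using the stock registration macro, placed in exactly one translation unit:

```cpp
// Registers MyPlugin with the global plugin registry during static
// initialization; the parser then looks it up via getPluginName()/getPluginVersion().
#include "my_plugin.h"

REGISTER_TENSORRT_PLUGIN(MyPlugin);
```

The same effect is available at runtime through getPluginRegistry()->registerCreator(creator, ""), which is useful when the plugin namespace has to be set before registration.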
--------------------------------------------------------------------------------
/tensorRT/c++/source/utils.cpp:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/10
3 | #include "utils.h"
4 |
5 | InputParams::InputParams(int ih, int iw, int ic, int hh, int hw, int hc): image_h(ih), image_w(iw), image_c(ic), heatmap_h(hh), heatmap_w(hw), heatmap_c(hc){
6 |
7 | }
8 |
9 | std::vector imagePreprocess(const std::string &image_path, const int &image_h, const int &image_w){
10 | // image_path ===> BGR/HWC ===> RGB/CHW
11 | cv::Mat origin_image = cv::imread(image_path, cv::IMREAD_COLOR);
12 | cv::Mat rgb_image = origin_image;
13 | cv::cvtColor(origin_image, rgb_image, cv::COLOR_BGR2RGB);
14 | cv::Mat resized_image(image_h, image_w, CV_8UC3);
15 | cv::resize(rgb_image, resized_image, cv::Size(image_w, image_h)); // cv::Size is (width, height)
16 | std::vector<uchar> file_data(resized_image.reshape(1, 1));
17 | std::vector<float> CHW;
18 | int c, h, w, idx;
19 | for (int i = 0; i < image_h * image_w * 3; ++i){
20 | c = i / (image_h * image_w);
21 | h = i / image_w % image_h;
22 | w = i % image_w;
23 | idx = (h * image_w + w) * 3 + c;
24 | CHW.push_back(static_cast<float>(file_data[idx]));
25 | }
26 | return CHW;
27 | }
28 |
29 |
30 | cv::Mat renderKeypoint(cv::Mat image, const std::vector<std::vector<float>> &keypoints, int nums_keypoints, float thres=0.3){
31 | int image_h = image.rows;
32 | int image_w = image.cols;
33 | int point_x, point_y;
34 | for (int i = 0; i < nums_keypoints; ++i){
35 | if (keypoints[i][2] >= thres){
36 | point_x = image_w * keypoints[i][0];
37 | point_y = image_h * keypoints[i][1];
38 | cv::circle(image, cv::Point(point_x, point_y), 5, cv::Scalar(255, 204,0), 3);
39 | }
40 | }
41 | return image;
42 | }
43 |
44 |
45 | void saveImage(const cv::Mat &image, const std::string &save_path){
46 | cv::imwrite(save_path, image);
47 | }
48 |
49 |
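The HWC-to-CHW arithmetic in imagePreprocess is exactly the kind of index math that breaks silently, so here is a self-contained check of the mapping used above. The 2x2x3 buffer and its values are invented for illustration:

```cpp
// Sanity check: CHW index i = c*H*W + h*W + w must read HWC offset (h*W + w)*C + c.
#include <cassert>
#include <vector>

int main(){
    const int H = 2, W = 2, C = 3;
    std::vector<int> hwc(H * W * C);
    for (int i = 0; i < H * W * C; ++i) hwc[i] = i;  // each element stores its own HWC offset

    std::vector<int> chw;
    for (int i = 0; i < H * W * C; ++i){
        int c = i / (H * W);
        int h = i / W % H;
        int w = i % W;
        chw.push_back(hwc[(h * W + w) * C + c]);
    }
    // Channel 0 of CHW must collect HWC offsets 0, 3, 6, 9.
    assert(chw[0] == 0 && chw[1] == 3 && chw[2] == 6 && chw[3] == 9);
    return 0;
}
```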
--------------------------------------------------------------------------------
/tensorRT/c++/source/utils.h:
--------------------------------------------------------------------------------
1 | // Created by luozhiwang (luozw1994@outlook.com)
2 | // Date: 2020/2/7
3 |
4 | #ifndef TENSORRT_UTILS_H
5 | #define TENSORRT_UTILS_H
6 |
7 | #include <opencv2/opencv.hpp>
8 | #include <iostream>
9 | #include <string>
10 | #include <vector>
11 | #include <cstring>
12 |
13 | class InputParams{
14 | public:
15 | const int image_h;
16 | const int image_w;
17 | const int image_c;
18 | const int heatmap_h;
19 | const int heatmap_w;
20 | const int heatmap_c;
21 | InputParams(int ih, int iw, int ic, int hh, int hw, int hc);
22 | };
23 |
24 | std::vector<float> imagePreprocess(const std::string &image_path, const int &image_h, const int &image_w);
25 |
26 | cv::Mat renderKeypoint(cv::Mat image, const std::vector<std::vector<float>> &keypoints, int nums_keypoints, float thres);
27 |
28 | void saveImage(const cv::Mat &image, const std::string &save_path);
29 |
30 | template <typename T>
31 | void write(char*& buffer, const T& val){
32 | *reinterpret_cast<T*>(buffer) = val;
33 | buffer += sizeof(T);
34 | }
35 |
36 | template <typename T>
37 | T read(const char*& buffer)
38 | {
39 | T val = *reinterpret_cast<const T*>(buffer);
40 | buffer += sizeof(T);
41 | return val;
42 | }
43 |
44 | #endif //TENSORRT_UTILS_H
45 |
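write and read are the usual TensorRT plugin (de)serialization helpers: they copy a trivially copyable value through a raw byte buffer and advance the cursor. A round-trip sketch; the two values are arbitrary:

```cpp
#include <cassert>
#include "utils.h"

int main(){
    char buffer[sizeof(int) + sizeof(float)];
    char *wp = buffer;
    write(wp, 17);    // e.g. number of heatmap channels
    write(wp, 0.3f);  // e.g. a confidence threshold
    const char *rp = buffer;
    assert(read<int>(rp) == 17);      // values come back in write order
    assert(read<float>(rp) == 0.3f);  // bitwise copy, so the round trip is exact
    return 0;
}
```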
--------------------------------------------------------------------------------
/tensorRT/python/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2020/2/16
8 | """
9 |
--------------------------------------------------------------------------------
/tensorRT/python/pb2uff.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-24
8 | """
9 | import uff
10 | # from tensorflow.contrib.tensorrt import trt_convert as trt
11 | import tensorrt as trt
12 | # from tensorflow.python.compiler.tensorrt import trt_convert as trt
13 | # TODO: tf==1.12.0 only supports TensorRT 4
14 | pb_path = '../../Hourglass.pb'
15 | output_nodes = ["Keypoints/keypoint_1/conv/Sigmoid"]
16 | output_filename = 'Hourglass.uff'
17 |
18 | serialized = uff.from_tensorflow_frozen_model(pb_path, output_nodes, output_filename=output_filename)
19 | # print(serialized)
20 |
21 | # convert = trt.TrtGraphConverter(
22 | # input_graph_def=pb_path,
23 | # nodes_blacklist=output_nodes
24 | # )
25 | # frozen_graph = convert.convert()
--------------------------------------------------------------------------------
/tensorRT/python/readpb2graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-29
8 | """
9 | import tensorflow as tf
10 | from tensorflow.python.platform import gfile
11 |
12 |
13 | def readpb2graph(pb_path, log_dir):
14 | """
15 | transfer one pb file to visible graph in tensorboard
16 | You can build a model by tensorRT C++ API more easily!
17 | :param pb_path: pb_path
18 | :param log_dir: log_dir
19 | :return: None
20 | """
21 | with tf.Session() as sess:
22 | with gfile.FastGFile(pb_path, 'rb') as f:
23 | graph_def = tf.GraphDef()
24 | graph_def.ParseFromString(f.read())
25 | g_in = tf.import_graph_def(graph_def)
26 | train_writer = tf.summary.FileWriter(log_dir)
27 | train_writer.add_graph(sess.graph)
28 |
29 |
30 | if __name__ == '__main__':
31 | import os
32 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
33 | pb_path = '../Hourglass.pb'
34 | log_dir = './log'  # hypothetical output directory for the TensorBoard event file
35 | readpb2graph(pb_path, log_dir)
36 |
--------------------------------------------------------------------------------
/tensorRT/python/tfpb2trtpb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-27
8 | """
9 |
10 | import tensorflow as tf
11 | import tensorflow.contrib.tensorrt as trt
12 |
13 |
14 | def tfpb2trtpb(pb_path, output_pb, output_node_name):
15 | # Inference with TF-TRT frozen graph workflow:
16 | graph = tf.Graph()
17 | with graph.as_default():
18 | with tf.Session() as sess:
19 | # First deserialize your frozen graph:
20 | with tf.gfile.GFile(pb_path, 'rb') as f:
21 | graph_def = tf.GraphDef()
22 | graph_def.ParseFromString(f.read())
23 | # Now you can create a TensorRT inference graph from your
24 | # frozen graph:
25 | trt_graph = trt.create_inference_graph(
26 | input_graph_def=graph_def,
27 | outputs=output_node_name,
28 | max_batch_size=1,
29 | max_workspace_size_bytes=2 << 20,
30 | precision_mode='fp32')
31 |
32 | with tf.gfile.GFile(output_pb, "wb") as f: # 保存模型
33 | f.write(trt_graph.SerializeToString())
34 | # Import the TensorRT graph into a new graph and run:
35 | # output_node = tf.import_graph_def(
36 | # trt_graph,
37 | # return_elements=output_node_name)
38 | # sess.run(output_node)
39 |
40 |
41 | if __name__ == '__main__':
42 | import os
43 | os.environ['CUDA_VISIBLE_DEVICES'] = '2'
44 | pb_path = '../Hourglass.pb'
45 | output_path = 'TensorRT.pb'
46 | output_node_name = ['HourglassNet/keypoint_1/conv/BiasAdd']
47 | tfpb2trtpb(pb_path, output_path, output_node_name)
--------------------------------------------------------------------------------
/train_hourglass_coco.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-12
8 | """
9 | from core.train.trainer import Trainer
10 | from core.network.keypoints import Keypoints
11 | from core.dataset.data_generator import Dataset
12 | import config.config_hourglass_coco as cfg
13 | import time
14 | import tensorflow as tf
15 | import tensorflow.contrib.slim as slim
16 | import sys
17 |
18 | sys.path.append('.')
19 |
20 |
21 | class TrainHourglass(Trainer):
22 | def __init__(self, model, dataset, cfg):
23 | super(TrainHourglass, self).__init__(model, dataset, cfg)
24 |
25 | def init_model(self):
26 | # BN moving-average decay 0.96
27 | with slim.arg_scope([slim.batch_norm], decay=0.96):
28 | Trainer.init_model(self)
29 |
30 | def init_train_op(self):
31 | start_time = time.time()
32 | # TRAIN_OP
33 | with tf.name_scope("Train_op"):
34 | optimizer = tf.train.AdamOptimizer(
35 | self.learning_rate)
36 | # optimizer = tf.train.MomentumOptimizer(
37 | # self.learning_rate, 0.9)
38 | gvs = optimizer.compute_gradients(self.loss)
39 | clip_gvs = [(tf.clip_by_value(grad, -5., 5.), var)
40 | for grad, var in gvs]
41 | if self.is_debug:
42 | self.mean_gradient = tf.reduce_mean(
43 | [tf.reduce_mean(g) for g, v in gvs])
44 | tf.summary.scalar("mean_gradient", self.mean_gradient)
45 | print('Debug mode is on !!!')
46 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
47 | # It's important!
48 | # Update moving-average in BN
49 | self.train_op = optimizer.apply_gradients(
50 | clip_gvs, global_step=self.global_step)
51 | print('-Create train op in %.3fs' % (time.time() - start_time))
52 |
53 | def init_loader_saver(self):
54 | start_time = time.time()
55 | with tf.name_scope('loader_and_saver'):
56 | if self.pre_trained_ckpt is not None:
57 | from tensorflow.python import pywrap_tensorflow
58 | reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_ckpt)
59 | var_to_shape_map = reader.get_variable_to_shape_map()
60 | var_to_restore = [k for k in var_to_shape_map]
61 | # var_ = [var for var in tf.global_variables() if var.name.strip(':0') in var_to_restore and var.name.strip(':0')!="Learning_rate/global_step" and "Momentum" not in var.name.strip(':0')]
62 | var_ = [var for var in tf.global_variables() if var.name.split(':')[0] in var_to_restore and var.name.split(':')[0] != "Learning_rate/global_step"]
63 | print('restore var total is %d' % len(var_))
64 | self.loader = tf.train.Saver(var_list=var_)
65 | self.saver = tf.train.Saver(
66 | var_list=tf.global_variables(),
67 | max_to_keep=self.max_keep)
68 | print(
69 | '-Create loader and saver in %.3fs' %
70 | (time.time() - start_time))
71 |
72 |
73 | if __name__ == '__main__':
74 | trainer = TrainHourglass(Keypoints, Dataset, cfg)
75 | trainer.train_launch()
76 |
--------------------------------------------------------------------------------
/train_hourglass_mpii.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2019 luozw, Inc. All Rights Reserved
5 |
6 | Authors: luozhiwang(luozw1994@outlook.com)
7 | Date: 2019-09-26
8 | """
9 | from core.train.trainer import Trainer
10 | from core.network.keypoints import Keypoints
11 | from core.dataset.data_generator import Dataset
12 | import config.config_hourglass_mpii as cfg
13 |
14 | import sys
15 | sys.path.append('.')
16 |
17 |
18 | class TrainHourglass(Trainer):
19 | def __init__(self, model, dataset, cfg):
20 | super(TrainHourglass, self).__init__(model, dataset, cfg)
21 |
22 | def train_launch(self):
23 | self.is_debug = False
24 | # must be called in this order
25 | self.init_dataset()
26 | self.init_inputs()
27 | self.init_model()
28 |
29 | # optional overrides
30 | self.init_loss()
31 | self.init_learning_rate()
32 | self.init_train_op()
33 | self.init_loader_saver_summary()
34 | self.init_session()
35 | self.train()
36 |
37 | if __name__ == '__main__':
38 | trainer = TrainHourglass(Keypoints, Dataset, cfg)
39 | trainer.train_launch()
40 |
--------------------------------------------------------------------------------