├── LICENSE
├── README.md
├── demo_img
    ├── cxk.jpg
    ├── cxk.mp4
    ├── epoch0_step1500_i_1.jpg
    ├── epoch0_step200_i_1.jpg
    ├── epoch0_step500_i_0.jpg
    ├── epoch0_step900_i_1.jpg
    ├── epoch1_step1500_i_1.jpg
    ├── epoch1_step200_i_1.jpg
    ├── epoch1_step500_i_0.jpg
    ├── epoch1_step900_i_1.jpg
    ├── epoch2_step1500_i_1.jpg
    ├── epoch2_step200_i_1.jpg
    ├── epoch2_step500_i_0.jpg
    ├── epoch2_step900_i_1.jpg
    ├── epoch3_step1500_i_1.jpg
    ├── epoch3_step200_i_1.jpg
    ├── epoch3_step500_i_0.jpg
    ├── epoch3_step900_i_1.jpg
    └── result.jpg
├── src
    ├── __pycache__
    │   ├── dataset.cpython-36.pyc
    │   ├── heatmap.cpython-36.pyc
    │   ├── hrnet.cpython-36.pyc
    │   └── utils.cpython-36.pyc
    ├── dataset.py
    ├── evaluate.py
    ├── heatmap.py
    ├── hrnet.py
    ├── temp.py
    ├── test.py
    ├── train.py
    └── utils.py
└── test_img
    ├── step11_i_0.jpg
    └── step136_i_0.jpg


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 VXallset

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# deep-high-resolution-net.TensorFlow
A TensorFlow implementation of HRNet-32. The model is trained on the AI Challenger dataset.

Just for fun! Here is CXK, a **'famous' actor** in China, with the keypoints estimated by HRNet-32.
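
In this implementation, HRNet-32 outputs one 256 x 192 heatmap per keypoint (14 channels, see src/hrnet.py). src/heatmap.py then reads each keypoint off its channel as the position of the maximum value. The NumPy sketch below illustrates that decoding step under simplifying assumptions. The function name `decode_heatmaps` is hypothetical. It returns (row, column, class) triples, using class 1 for a detected point and 3 for a channel whose peak does not exceed the threshold, which follows the class convention used in the code.

```python
import numpy as np


def decode_heatmaps(heatmaps, threshold=0.0):
    """Hypothetical helper: read one keypoint per channel from an HRNet-style output.

    heatmaps: array of shape (batch, height, width, num_keypoints).
    Returns an int array of shape (batch, num_keypoints, 3) with rows
    (row, column, class), where class 1 = detected and 3 = not detected.
    """
    batch, _, _, num_kp = heatmaps.shape
    result = np.zeros((batch, num_kp, 3), dtype=np.int64)
    result[..., 2] = 3  # default class: not detected
    for b in range(batch):
        for n in range(num_kp):
            channel = heatmaps[b, :, :, n]
            # convert the flat argmax index back to (row, column) coordinates
            row, col = np.unravel_index(np.argmax(channel), channel.shape)
            if channel[row, col] > threshold:
                result[b, n] = (row, col, 1)
    return result


# Usage: a fake 14-channel output whose keypoint 0 peaks at row 100, column 50.
fake = np.zeros((1, 256, 192, 14), dtype=np.float32)
fake[0, 100, 50, 0] = 1.0
print(decode_heatmaps(fake)[0, 0].tolist())  # [100, 50, 1]
```

`np.unravel_index` converts the flat argmax index back to (row, column) without hard-coding the heatmap width.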

For more details, please refer to the [paper](https://arxiv.org/abs/1902.09212) and the [dataset](https://challenger.ai/competition/keypoint).

# Environment
- Python 3.6 or higher
- TensorFlow 1.11 or higher
- PyCharm

# How to Use
### For Training
- Download the AI Challenger dataset.
- Convert the images in the AI Challenger dataset (train_images folder) to TFRecords by running dataset.py. Make sure that the **dataset_root_path** used in the **extract_people_from_dataset()** function points to the AI Challenger dataset you downloaded in the previous step.
- Run train.py!

Please note that the structure of HRNet is complicated. I trained the HRNet-32 network on 2 Nvidia Titan V graphics cards. Because of the limited graphics memory (16 GB), the maximum batch size I could use was 2, and one epoch (189176 steps) took around 30 hours. The model files were uploaded to [Google Drive](https://drive.google.com/drive/folders/13ll_UyKLW31ozasChqzB_91sWEE4I2PZ?usp=sharing) and [Baidu Cloud](https://pan.baidu.com/s/1bTmiP3MxxC17pF1S4pDpWQ) (extraction code: 7hym).

### For Testing
- Finish the steps in the training section.
- Make sure the dataset name and the model file name are correct.
- Run test.py!

The result images will be saved in the _test_img_ folder. test.py also generates the distances.npy and classes.npy files, which evaluate.py later uses to calculate AP50 and AP75.

### For Evaluating
- Run evaluate.py.

It will print AP50 and AP75 to the command line.

### For Debugging
If you encounter any problems, try running _temp.py_ first to check that everything works properly. It is a simple demo script that predicts the human pose in the cxk.mp4 video, and it is easier to debug than the other scripts.

# What You Will See
### For Training
- The loss information.
- Examples of images predicted by the network, saved into the _./demo_img/_ folder.

Epoch Number | example image 1 | example image 2 | example image 3 | example image 4
:-: | :-: | :-: | :-: | :-:
epoch 0 | | | | 
epoch 1 | | | | 
epoch 2 | | | | 
epoch 3 | | | | 

### For Testing
- The results on the test images will be saved into the _./test_img/_ folder.
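
evaluate.py reports AP50 and AP75 as the fraction of test people whose object keypoint similarity (OKS) exceeds 0.5 and 0.75, respectively. In that script, the per-keypoint normalizer sigma² is the mean squared prediction error of that keypoint over the whole test set. The standard COCO definition instead uses the object scale and fixed per-keypoint constants. The vectorized NumPy sketch below reproduces the script's computation. The function name `ap_from_distances` and the toy arrays are illustrative assumptions, and the sketch assumes every person has at least one marked keypoint (class != 3).

```python
import numpy as np


def ap_from_distances(distances, classes, thresholds=(0.5, 0.75)):
    """Hypothetical helper mirroring the AP computation in src/evaluate.py.

    distances: (num_people, num_keypoints) pixel distances to the ground truth.
    classes:   (num_people, num_keypoints) ground-truth classes (3 = not marked).
    Returns {threshold: fraction of people whose OKS exceeds the threshold}.
    """
    # Per-keypoint normalizer estimated from the errors themselves (as in calculate_sigma2s).
    sigma2s = np.mean(distances ** 2, axis=0)
    # Only marked keypoints (class != 3) contribute to a person's OKS.
    mask = (classes != 3).astype(np.float64)
    similarity = np.exp(-distances ** 2 / (2.0 * sigma2s))
    oks = np.sum(similarity * mask, axis=1) / np.sum(mask, axis=1)
    return {t: float(np.mean(oks > t)) for t in thresholds}


# Usage with two toy people and two keypoints each.
distances = np.array([[0.0, 2.0], [2.0, 0.0]])
classes = np.ones((2, 2))
print(ap_from_distances(distances, classes))  # {0.5: 1.0, 0.75: 0.0}
```

Because sigma is estimated from the same errors being scored, these numbers are relative to the model's own error spread. They are therefore not directly comparable with COCO-style AP values.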

# For More
Contact me: vxallset@outlook.com

--------------------------------------------------------------------------------
/demo_img/cxk.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/cxk.jpg
--------------------------------------------------------------------------------
/demo_img/cxk.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/cxk.mp4
--------------------------------------------------------------------------------
/demo_img/epoch0_step1500_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch0_step1500_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch0_step200_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch0_step200_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch0_step500_i_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch0_step500_i_0.jpg
--------------------------------------------------------------------------------
/demo_img/epoch0_step900_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch0_step900_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch1_step1500_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch1_step1500_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch1_step200_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch1_step200_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch1_step500_i_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch1_step500_i_0.jpg
--------------------------------------------------------------------------------
/demo_img/epoch1_step900_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch1_step900_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch2_step1500_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch2_step1500_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch2_step200_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch2_step200_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch2_step500_i_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch2_step500_i_0.jpg
--------------------------------------------------------------------------------
/demo_img/epoch2_step900_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch2_step900_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch3_step1500_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch3_step1500_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch3_step200_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch3_step200_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/epoch3_step500_i_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch3_step500_i_0.jpg
--------------------------------------------------------------------------------
/demo_img/epoch3_step900_i_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/epoch3_step900_i_1.jpg
--------------------------------------------------------------------------------
/demo_img/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/demo_img/result.jpg
--------------------------------------------------------------------------------
/src/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/src/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/heatmap.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/src/__pycache__/heatmap.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/hrnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/src/__pycache__/hrnet.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/src/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
"""
This file is used to generate TFRecords using the AI Challenger dataset.

@ Author: Yu Sun. vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019

"""
import numpy as np
import os
import time
from skimage import io, draw
from skimage.transform import resize
from random import shuffle
import tensorflow as tf
import json

"""
# Each single-person image is extracted from the AI Challenger dataset in the following steps:
1. Get the coordinates of the person's bounding box in the original image.
2. Crop the box and resize it to 256 x 192 (a 4:3 height : width ratio). The box is not padded to 4:3 before
   resizing, so the person's aspect ratio may be distorted.

Note that the keypoint coordinates are also re-calculated relative to the cropped and resized foreground.

"""


def draw_points_on_img(img, point_ver, point_hor, point_class):
    # Draw a black disc of radius 10 on every marked keypoint (class 3 = not marked).
    # draw.circle(r, c, radius, shape) exists in older scikit-image releases; newer releases replace it with
    # draw.disk((r, c), radius, shape=shape).
    for i in range(len(point_class)):
        if point_class[i] != 3:
            rr, cc = draw.circle(point_ver[i], point_hor[i], 10, (256, 192))
            #draw.set_color(img, [rr, cc], [0., 0., 0.], alpha=5)
            img[rr, cc, :] = 0
            #io.imshow(img)
            #io.show()

    return img


def draw_lines_on_img(img, point_ver, point_hor, point_class):
    # pairs of keypoint indices connected by a limb
    line_list = [[0, 1], [1, 2], [3, 4], [4, 5], [6, 7], [7, 8], [9, 10],
                 [10, 11], [12, 13], [13, 6], [13, 9], [13, 0], [13, 3]]

    # key point class: 1: visible, 2: not visible, 3: not marked
    for start_point_id in range(len(point_class)):
        if point_class[start_point_id] == 3:
            continue
        for end_point_id in range(len(point_class)):
            if point_class[end_point_id] == 3:
                continue

            if [start_point_id, end_point_id] in line_list:
                rr, cc = draw.line(int(point_ver[start_point_id]), int(point_hor[start_point_id]),
                                   int(point_ver[end_point_id]), int(point_hor[end_point_id]))
                draw.set_color(img, [rr, cc], [255, 0, 0])

    return img


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def extract_people_from_dataset(dataset_root_path='../../../dataset/ai_challenger/', image_save_path='../dataset/imgs/',
                                tfrecords_path='../dataset/', is_shuffle=True):
    """
    This function is used to extract people from the AI Challenger dataset. Each extracted image contains only one
    person and is saved as a single .jpg file. Finally, the images and the corresponding annotations are saved into
    a .tfrecords file.

    :param dataset_root_path: the root path of the AI Challenger dataset.
    :param image_save_path: the path used to save the clipped images.
    :param tfrecords_path: the path used to save the .tfrecords file.
    :param is_shuffle: whether to shuffle the data (currently unused).
    :return: None.
    """
    annotation_file = os.path.join(dataset_root_path, 'keypoint_train_annotations_20170909.json')
    image_read_path = os.path.join(dataset_root_path, 'train_images')
    tfrecords_file = os.path.join(tfrecords_path, 'train.tfrecords')

    if not os.path.exists(tfrecords_path):
        os.mkdir(tfrecords_path)
    if os.path.exists(tfrecords_file):
        os.remove(tfrecords_file)
    if os.path.exists(image_save_path):
        # remove the images left over from a previous run
        useless = os.listdir(image_save_path)
        for onefile in useless:
            os.remove(os.path.join(image_save_path, onefile))
    else:
        os.mkdir(image_save_path)

    saved_number = 0
    image_number = 0
    start_time = time.time()
    with tf.python_io.TFRecordWriter(tfrecords_file) as tfwriter:

        with open(annotation_file, 'r') as jsfile:
            data = json.load(jsfile)

        for one_item in data:
            img_id = one_item['image_id']
            image_number += 1
            if image_number % 100 == 0:
                print('Processed {} images, extracted {} people from the dataset. '
                      'time = {}'.format(image_number, saved_number, time.time() - start_time))

            kps = one_item['keypoint_annotations']
            boxes = one_item['human_annotations']

            # read image
            img_filename = os.path.join(image_read_path, img_id + '.jpg')
            img = io.imread(img_filename)

            for i in range(len(boxes)):
                # construct the name of a human in the dictionary,
                # for example, the first one (when i = 0) is 'human1'
                human_name = 'human' + str(i+1)

                kp = kps[human_name]
                box = boxes[human_name]
                p1_hor, p1_ver, p2_hor, p2_ver = box
                foreground = img[p1_ver:p2_ver, p1_hor:p2_hor, :]

                try:
                    foreground = resize(foreground, (256, 192, 3))
                except ValueError:
                    print('ValueError at image {} and {}'.format(image_number, human_name))
                    continue

                # resize() returns floats in [0, 1]; convert back to uint8
                foreground = foreground * 255.0
                foreground_uint8 = np.uint8(foreground)

                # map the keypoints from the original image into the 192 (width) x 256 (height) crop
                kp_hor = (np.array(kp[0::3]) - p1_hor) / (p2_hor - p1_hor) * 192
                kp_ver = (np.array(kp[1::3]) - p1_ver) / (p2_ver - p1_ver) * 256
                kp_class = np.array(kp[2::3])

                img_name = img_id + '_' + human_name + '.jpg'

                io.imsave(os.path.join(image_save_path, img_name), foreground_uint8)

                # clip before casting to uint8 so that coordinates on or outside the border do not wrap around
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            'image_name': _bytes_feature(img_name.encode()),
                            'image_raw': _bytes_feature(foreground_uint8.tobytes()),
                            'keypoints_ver': _bytes_feature(np.uint8(np.clip(kp_ver, 0, 255)).tobytes()),
                            'keypoints_hor': _bytes_feature(np.uint8(np.clip(kp_hor, 0, 191)).tobytes()),
                            'keypoints_class': _bytes_feature(np.uint8(kp_class).tobytes())
                        }))
                tfwriter.write(example.SerializeToString())

                saved_number += 1
    print('Extracted {} people from the dataset in total.'.format(saved_number))


def decode_proto(proto):
    features = tf.parse_single_example(proto,
                                       features={
                                           'image_name': tf.FixedLenFeature([], tf.string),
                                           'image_raw': tf.FixedLenFeature([], tf.string),
                                           'keypoints_ver': tf.FixedLenFeature([], tf.string),
                                           'keypoints_hor': tf.FixedLenFeature([], tf.string),
                                           'keypoints_class': tf.FixedLenFeature([], tf.string),
                                       })
    image_name = features['image_name']

    image_raw = tf.decode_raw(features['image_raw'], out_type=tf.uint8)
    image = tf.reshape(image_raw, [256, 192, 3])

    keypoints_ver = tf.decode_raw(features['keypoints_ver'], out_type=tf.uint8)
    keypoints_hor = tf.decode_raw(features['keypoints_hor'], out_type=tf.uint8)
    keypoints_class = tf.decode_raw(features['keypoints_class'], out_type=tf.uint8)
    return image_name, image, keypoints_ver, keypoints_hor, keypoints_class


def decode_tfrecord(filename_queue):
    tfreader = tf.TFRecordReader()
    _, proto = tfreader.read(filename_queue)
    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = decode_proto(proto)

    return image_name, image, keypoints_ver, keypoints_hor, keypoints_class


def input_batch(datasetname, batch_size, num_epochs):
    """
    This function is used to decode the TFRecord file and return a batch of images as well as their information.
    :param datasetname: the name of the TFRecord file.
    :param batch_size: the number of images in a batch.
    :param num_epochs: the number of epochs.
    :return: a batch of images as well as their information.
    """
    with tf.name_scope('input_batch'):
        # The shuffle transformation uses a finite-sized buffer to shuffle elements
        # in memory. The parameter is the number of elements in the buffer. For
        # completely uniform shuffling, set the parameter to be the same as the
        # number of elements in the dataset.
        mydataset = tf.data.TFRecordDataset(datasetname)
        mydataset = mydataset.map(decode_proto)

        # Setting the buffer to the dataset size probably fails because the buffer would keep every decoded crop in
        # memory (about 378k images of 256 x 192 x 3 bytes, i.e. more than 50 GB), so shuffling is disabled here.
        # mydataset = mydataset.shuffle(200)
        mydataset = mydataset.repeat(num_epochs * 2)
        # drop the remaining elements that cannot fill a whole batch
        mydataset = mydataset.batch(batch_size, drop_remainder=True)
        iterator = mydataset.make_one_shot_iterator()

        nextelement = iterator.get_next()
        return nextelement


def mytest():
    tfrecord_file = '../dataset/train.tfrecords'

    filename_queue = tf.train.string_input_producer([tfrecord_file], num_epochs=None)
    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = decode_tfrecord(filename_queue)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # while not coord.should_stop():
            for i in range(10):
                img_name, img, point_ver, point_hor, point_class = sess.run([image_name, image, keypoints_ver,
                                                                            keypoints_hor, keypoints_class])

                print(img_name, point_hor, point_ver, point_class)

                # draw only the marked keypoints (class 3 = not marked)
                for k in range(len(point_class)):
                    if point_class[k] != 3:
                        rr, cc = draw.circle(point_ver[k], point_hor[k], 10, (256, 192))
                        img[rr, cc, :] = 0

                io.imshow(img)
                io.show()

        except tf.errors.OutOfRangeError:
            print('Done reading')
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    extract_people_from_dataset()
    #mytest()

--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
"""
This file is used to evaluate the performance of the model. Please run test.py before running this file.

@ Author: Yu Sun. vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019

"""
import numpy as np


def calculate_sigma2s(distances):
    # Per-keypoint mean squared distance over all samples. It is used as sigma^2 in the OKS below, i.e. the
    # normalization is estimated from the errors themselves instead of from the object scale.
    sigma2s = np.zeros(14, dtype=np.float64)
    for keypoint_id in range(14):
        distance = distances[:, keypoint_id]
        distance2 = distance ** 2
        sigma2s[keypoint_id] = np.mean(distance2)
    return sigma2s


def calculate_OKS(distances, classes):
    sigma2s = calculate_sigma2s(distances)
    sigmas = np.sqrt(sigma2s)
    oks = np.zeros(len(distances))
    for idx in range(len(distances)):
        one_distance = distances[idx]
        one_class = classes[idx]
        # only the marked keypoints (class != 3) are counted
        marked = np.array(one_class != 3, dtype=int)
        one_oks = np.sum(np.exp(-one_distance ** 2 / (2.0 * sigmas ** 2)) * marked) / np.sum(marked)
        oks[idx] = one_oks

    return oks


if __name__ == '__main__':
    distance_file = 'distances.npy'
    classes_file = 'classes.npy'
    distances = np.load(distance_file)
    classes = np.load(classes_file)
    oks = calculate_OKS(distances, classes)
    oks50_mask = np.array(oks > 0.5, dtype=int)
    oks75_mask = np.array(oks > 0.75, dtype=int)
    ap50 = np.sum(oks50_mask) / len(oks50_mask)
    ap75 = np.sum(oks75_mask) / len(oks75_mask)
    print("AP50 = {}, AP75 = {}".format(ap50, ap75))
--------------------------------------------------------------------------------
/src/heatmap.py:
--------------------------------------------------------------------------------
"""
This file is used to generate the heatmaps and provides related helper functions.

@ Author: Yu Sun.
vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019

"""
import numpy as np
from skimage import io, draw
from dataset import draw_lines_on_img


def gaussian_kernel(kernel_length=3, sigma=1.):
    """
    Create a normalized 2D Gaussian kernel with side length kernel_length and standard deviation sigma.
    """

    ax = np.arange(-kernel_length // 2 + 1., kernel_length // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)

    kernel = np.exp(-0.5 * (np.square(xx) + np.square(yy)) / np.square(sigma))

    return kernel / np.sum(kernel)


def calculate_groundtruth_heatmap(keypoint_ver, keypoint_hor, kepoint_class, kernel_length=3, sigma=1.0):
    # Place a small Gaussian on every marked keypoint; returns a (batch_size, 256, 192, keypoints_number) array.
    batch_size, keypoints_number = kepoint_class.shape
    assert kernel_length % 2 == 1, 'kernel_length must be odd!'
    kernel = gaussian_kernel(kernel_length=kernel_length, sigma=sigma)
    half_length = kernel_length // 2
    heatmap = np.zeros((batch_size, 256, 192, keypoints_number), dtype=np.float32)

    for b in range(batch_size):
        for n in range(keypoints_number):
            # if the keypoint class is 3 (not marked), skip it
            if kepoint_class[b, n] == 3:
                continue

            # cast to Python int so that offsets near the borders are computed without uint8 overflow
            ver = int(keypoint_ver[b, n])
            hor = int(keypoint_hor[b, n])
            for i in range(-half_length, half_length + 1):
                for j in range(-half_length, half_length + 1):
                    if ver + i >= 256 or ver + i < 0 or hor + j >= 192 or hor + j < 0:
                        continue
                    heatmap[b, ver + i, hor + j, n] += kernel[i + half_length, j + half_length]
    return heatmap


def decode_output(net_output, threshold=0.0):
    # Take the argmax of every heatmap channel as the keypoint position; peaks not above the threshold get class 3.
    batch_size, size_ver, size_hor, keypoints_number = net_output.shape
    kp_ver = np.zeros((batch_size, keypoints_number))
    kp_hor = np.zeros_like(kp_ver)
    kp_class = np.ones_like(kp_hor) * 3

    for b in range(batch_size):
        for n in range(keypoints_number):
            max_index = np.argmax(net_output[b, :, :, n])
            max_row = max_index // size_hor
            max_col = max_index % size_hor
            if net_output[b, max_row, max_col, n] > threshold:
                # print(net_output[b, max_row, max_col, n])
                kp_ver[b, n] = max_row
                kp_hor[b, n] = max_col
                kp_class[b, n] = 1
    # interleave as [ver_0, hor_0, class_0, ver_1, hor_1, class_1, ...]
    prediction = np.zeros((batch_size, keypoints_number * 3))
    prediction[:, ::3] = kp_ver
    prediction[:, 1::3] = kp_hor
    prediction[:, 2::3] = kp_class
    return prediction


def decode_pose(images, net_output, threshold=0.001):
    # key point class: 1: visible, 2: invisible, 3: not marked
    prediction = decode_output(net_output, threshold=threshold)

    batch_size, size_ver, size_hor, keypoints_number = net_output.shape
    kp_ver = prediction[:, ::3]
    kp_hor = prediction[:, 1::3]
    kp_class = prediction[:, 2::3]

    for b in range(batch_size):
        point_hor = kp_hor[b]
        point_ver = kp_ver[b]
        point_class = kp_class[b]
        images[b, :, :, :] = draw_lines_on_img(images[b], point_ver, point_hor, point_class)
        for i in range(len(point_class)):
            if point_class[i] != 3:
                rr, cc = draw.circle(point_ver[i], point_hor[i], 10, (256, 192))
                images[b, rr, cc, :] = 0

    return images


def calculate_distance(prediction, groundtruth):
    # Euclidean pixel distance per keypoint; keypoints not marked in the ground truth (class 3) get distance 0.
    kp_ver_pred = prediction[:, ::3]
    kp_hor_pred = prediction[:, 1::3]
    kp_class_pred = prediction[:, 2::3]

    kp_ver_gt = groundtruth[:, ::3]
    kp_hor_gt = groundtruth[:, 1::3]
    kp_class_gt = groundtruth[:, 2::3]

    distance2 = (kp_ver_gt - kp_ver_pred) ** 2 + (kp_hor_gt - kp_hor_pred) ** 2
    mask = np.array(kp_class_gt != 3, dtype=int)
    result = np.sqrt(distance2) * mask
    return result
--------------------------------------------------------------------------------
/src/hrnet.py:
--------------------------------------------------------------------------------
"""
This is the structure of HRNet-32, a TensorFlow implementation of the CVPR 2019 paper "Deep High-Resolution
Representation Learning for Human Pose Estimation".

@ Author: Yu Sun. vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 06, 2019

"""
import tensorflow as tf
from utils import *


def stage1(input, name='stage1', is_training=True):
    output = []
    with tf.variable_scope(name):
        s1_res1 = residual_unit_bottleneck(input, name='rs1', is_training=is_training)
        s1_res2 = residual_unit_bottleneck(s1_res1, name='rs2', is_training=is_training)
        s1_res3 = residual_unit_bottleneck(s1_res2, name='rs3', is_training=is_training)
        s1_res4 = residual_unit_bottleneck(s1_res3, name='rs4', is_training=is_training)
        output.append(conv_2d(s1_res4, channels=32, activation=leaky_Relu, name=name + '_output',
                              is_training=is_training))
    return output


def stage2(input, name='stage2', is_training=True):
    with tf.variable_scope(name):
        sub_networks = exchange_between_stage(input, name='between_stage', is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block', is_training=is_training)
    return sub_networks


def stage3(input, name='stage3', is_training=True):
    with tf.variable_scope(name):
        sub_networks = exchange_between_stage(input, name=name, is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block1', is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block2', is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block3', is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block4', is_training=is_training)
    return sub_networks


def stage4(input, name='stage4', is_training=True):
    with tf.variable_scope(name):
        sub_networks = exchange_between_stage(input, name=name, is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block1', is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block2', is_training=is_training)
        sub_networks = exchange_block(sub_networks, name='exchange_block3', is_training=is_training)
    return sub_networks


def HRNet(input, is_training=True, eps=1e-10):
    output = stage1(input=input, is_training=is_training)
    output = stage2(input=output, is_training=is_training)
    output = stage3(input=output, is_training=is_training)
    output = stage4(input=output, is_training=is_training)

    # The output contains 4 sub-networks. We only need the first (highest-resolution) one, which has fused the
    # information from all resolution levels.
    output = output[0]

    # use a 3x3 convolution to reduce the channels of the feature maps to 14 (the number of keypoints)
    output = conv_2d(output, channels=14, kernel_size=3, batch_normalization=False, name='change_channel',
                     is_training=is_training, activation=tf.nn.relu)
    # sigmoid can convert the output to the interval of (0, 1)
    # output = tf.nn.sigmoid(output, name='net_output')

    # If we don't normalize the output so that each heatmap sums to 1, the net may predict 0 on all pixels, which
    # gives a loss of around 1.75 per image (batch_size = 1, 256, 192, 14). This is because the values of a
    # 3 x 3 Gaussian kernel are g =
    # [[0.07511361 0.1238414  0.07511361]
    #  [0.1238414  0.20417996 0.1238414 ]
    #  [0.07511361 0.1238414  0.07511361]]

    # so g^2 =
    # [[0.00564205 0.01533669 0.00564205]
    #  [0.01533669 0.04168945 0.01533669]
    #  [0.00564205 0.01533669 0.00564205]]
    # therefore, with all 14 keypoints marked and away from the image border, np.sum(g^2) * 14 = 1.75846

    # To avoid this, we normalize the net output by dividing the value on every pixel by the sum of the values
    # of that heatmap (1, 256, 192, 1). Alternatively, we could add a classification loss
    # to predict the class of the keypoints.


    # sum up the values over all pixels, which gives a [batch_size, 14] tensor, then expand the dims to
    # [batch_size, 1, 1, 14] so as to normalize the output
    output_sum = tf.expand_dims(tf.expand_dims(tf.reduce_sum(tf.reduce_sum(output, axis=-2),
                                                             axis=-2), axis=-2), axis=-2, name='net_output_sum')

    output = tf.truediv(output, output_sum + eps, name='net_output_final')

    return output


def mytest():
    input = tf.ones((16, 256, 192, 3))
    output = HRNet(input)

    print(output)


def compute_loss(net_output, ground_truth):
    # sum of squared errors over the whole batch
    diff = tf.square(tf.subtract(net_output, ground_truth), name='square_difference')
    loss = tf.reduce_sum(diff, name='loss')
    #loss = tf.losses.mean_squared_error(ground_truth, net_output)

    return loss


if __name__ == '__main__':
    mytest()
--------------------------------------------------------------------------------
/src/temp.py:
--------------------------------------------------------------------------------
"""
This file is a simple demo that tests the model by predicting the human poses in the cxk.mp4 video.

@ Author: Yu Sun.
vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Apr 13, 2019

"""
import numpy as np
import tensorflow as tf
from hrnet import *
import dataset
from heatmap import *
import time
import os
from skimage import io
from skimage.transform import resize
import cv2


def main(use_GPU=True):
    batch_size = 1
    num_epochs = 10
    image_numbers = 378352
    # image_numbers = 4500

    root_path = os.getcwd()[:-3]

    datasetname = os.path.join(root_path, 'dataset/test.tfrecords')
    model_folder = os.path.join(root_path, 'models/')
    modelfile = os.path.join(root_path, 'models/epoch2.ckpt-567528')

    global_step = tf.Variable(0, trainable=False)

    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = dataset.input_batch(
        datasetname=datasetname, batch_size=batch_size, num_epochs=num_epochs)

    input_images = tf.placeholder(tf.float32, [None, 256, 192, 3])
    ground_truth = tf.placeholder(tf.float32, [None, 256, 192, 14])

    input_images = tf.cast(input_images / 255.0, tf.float32, name='change_type')
    net_output = HRNet(input=input_images)
    loss = compute_loss(net_output=net_output, ground_truth=ground_truth)

    saver = tf.train.Saver()
    device = '/gpu:0'
    if not use_GPU:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = '/cpu:0'

    video_capture = cv2.VideoCapture('./cxk.mp4')
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    start_second = 0
    start_frame = fps * start_second
    video_capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

    with tf.Session() as sess:
        with tf.device(device):
            sess.run(tf.global_variables_initializer())
            saver.restore(sess=sess, save_path=modelfile)

            try:
                framid = 0
                while True:
                    start_time = time.time()
                    retval, img_data = video_capture.read()
                    if not retval:
                        break
                    img_data = cv2.cvtColor(img_data, code=cv2.COLOR_BGR2RGB)
                    _img = cv2.resize(img_data, (192, 256))
                    _img = np.array([_img])

                    tnet_output = sess.run(net_output, feed_dict={input_images: _img})

                    # prediction = decode_output(tnet_output, threshold=0.001)

                    timgs = decode_pose(_img, tnet_output, threshold=0.001)
                    resultimg = timgs[0] / 255.0

                    io.imsave('../demo_img/frame_{}.jpg'.format(framid), resultimg)
                    framid += 1
                    print('time = {}'.format(time.time() - start_time))
                    print('---------------------------------------------------------------------------------')

            except tf.errors.OutOfRangeError:
                print('End testing...')
            finally:
                total_time = time.time() - start_time
                print('Running time: {} s'.format(total_time))
                print('Done!')


if __name__ == '__main__':
    main(use_GPU=True)
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
"""
This file is used to test the model using the AI Challenger dataset.

@ Author: Yu Sun.
vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019

"""
import numpy as np
import tensorflow as tf
from hrnet import *
import dataset
from heatmap import *
import time
import os
from skimage import io

from functools import reduce
from operator import mul


def get_num_params():
    # Count the trainable parameters: the product of each variable's dimensions, summed over all variables.
    num_params = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        num_params += reduce(mul, [dim.value for dim in shape], 1)
    return num_params


def main(device_option='/gpu:0'):
    batch_size = 1
    num_epochs = 10
    image_numbers = 378352
    # image_numbers = 4500

    root_path = os.getcwd()[:-3]

    datasetname = os.path.join(root_path, 'dataset/train.tfrecords')
    model_folder = os.path.join(root_path, 'models/')
    modelfile = os.path.join(root_path, 'models/epoch2.ckpt-567528')

    global_step = tf.Variable(0, trainable=False)

    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = dataset.input_batch(
        datasetname=datasetname, batch_size=batch_size, num_epochs=num_epochs)

    input_images = tf.placeholder(tf.float32, [None, 256, 192, 3])
    ground_truth = tf.placeholder(tf.float32, [None, 256, 192, 14])

    input_images = tf.cast(input_images / 255.0, tf.float32, name='change_type')
    net_output = HRNet(input=input_images)
    loss = compute_loss(net_output=net_output, ground_truth=ground_truth)

    saver = tf.train.Saver()
    # os.environ['CUDA_VISIBLE_DEVICES'] = ''

    with tf.Session() as sess:
        with tf.device(device_option):
            sess.run(tf.global_variables_initializer())
            saver.restore(sess=sess, save_path=modelfile)
            # print(get_num_params())

            writer = tf.summary.FileWriter('../log/', sess.graph)
            start_time = time.time()
            try:
                distances = 0
                classes = 0
                for step in range(int(image_numbers / batch_size)):
                    _img, _kp_ver, _kp_hor, _kp_class = sess.run(
                        [image, keypoints_ver, keypoints_hor, keypoints_class])
                    _gt = calculate_groundtruth_heatmap(_kp_ver, _kp_hor, _kp_class)

                    tloss, tnet_output = sess.run([loss, net_output],
                                                  feed_dict={input_images: _img, ground_truth: _gt})

                    prediction = decode_output(tnet_output, threshold=0.001)
                    # Interleave the ground truth as [ver_0, hor_0, class_0, ver_1, hor_1, class_1, ...] per image.
                    gt_all = np.zeros((batch_size, 14*3))
                    gt_all[:, ::3] = _kp_ver
                    gt_all[:, 1::3] = _kp_hor
                    gt_all[:, 2::3] = _kp_class
                    distance = calculate_distance(prediction, gt_all)

                    if step == 0:
                        distances = distance
                        classes = _kp_class
                    elif step == 1000:
                        np.save('distances.npy', distances)
                        np.save('classes.npy', classes)
                        break
                    else:
                        distances = np.append(distances, distance, axis=0)
                        classes = np.append(classes, _kp_class, axis=0)

                    timgs = decode_pose(_img, tnet_output, threshold=0.001)
                    for i in range(batch_size):
                        io.imsave('../test_img/step{}_i_{}.jpg'.format(step, i), timgs[i])
                    print('Step = {:>6}/{:>6}, loss = {:.6f}, time = {}'
                          .format(step, int(image_numbers / batch_size), tloss,
                                  time.time() - start_time))
                    print('---------------------------------------------------------------------------------')

            except tf.errors.OutOfRangeError:
                print('End testing...')
            finally:
                total_time = time.time() - start_time
                print('Running time: {} s'.format(total_time))
                print('Done!')


if __name__ == '__main__':
    main(device_option='/gpu:0')
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
"""
This file is used to train the HRNet-32 model.

@ Author: Yu Sun.
vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 27, 2019

"""
import numpy as np
import tensorflow as tf
from hrnet import *
import dataset
from heatmap import *
import time
import os
from skimage import io


def main(gpu_device='/gpu:0'):
    is_training = True

    batch_size = 1
    num_epochs = 10
    image_numbers = 378352
    learning_rate = 0.001
    save_epoch_number = 1
    root_path = os.getcwd()[:-3]

    datasetname = os.path.join(root_path, 'dataset/train.tfrecords')
    model_folder = os.path.join(root_path, 'models/')
    modelfile = os.path.join(root_path, 'models/model.ckpt')

    global_step = tf.Variable(0, trainable=False)

    image_name, image, keypoints_ver, keypoints_hor, keypoints_class = dataset.input_batch(
        datasetname=datasetname, batch_size=batch_size, num_epochs=num_epochs)

    input_images = tf.placeholder(tf.float32, [None, 256, 192, 3])
    ground_truth = tf.placeholder(tf.float32, [None, 256, 192, 14])

    # Note: input_images is rebound here to the scaled tensor, not the placeholder. Feeding it in feed_dict below
    # replaces that tensor's value, so the division by 255 is bypassed and the fed images reach HRNet unchanged.
    input_images = tf.cast(input_images / 255.0, tf.float32, name='change_type')
    net_output = HRNet(input=input_images, is_training=is_training)
    loss = compute_loss(net_output=net_output, ground_truth=ground_truth)

    saver = tf.train.Saver()
    # tf.layers.batch_normalization registers its moving-average updates in tf.GraphKeys.UPDATE_OPS; run them
    # with every training step so that the moving statistics are actually updated.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

    with tf.Session() as sess:
        # tf.device only affects ops created inside it; the graph is already built above, so this does not change
        # where the ops run.
        with tf.device(gpu_device):
            sess.run(tf.global_variables_initializer())

            writer = tf.summary.FileWriter('../log/', sess.graph)
            start_time = time.time()
            try:
                for epoch in range(num_epochs):
                    epoch_time = time.time()
                    for step in range(int(image_numbers / batch_size)):
                        _img, _kp_ver, _kp_hor, _kp_class = sess.run(
                            [image, keypoints_ver, keypoints_hor, keypoints_class])
                        _gt = calculate_groundtruth_heatmap(_kp_ver, _kp_hor, _kp_class)

                        train_step.run(feed_dict={input_images: _img, ground_truth: _gt})

                        if step % 100 == 0:
                            tloss, tnet_output = sess.run([loss, net_output],
                                                          feed_dict={input_images: _img, ground_truth: _gt})

                            timgs = decode_pose(_img, tnet_output, threshold=0.0)
                            for i in range(batch_size):
                                io.imsave('../demo_img/epoch{}_step{}_i_{}.jpg'.format(epoch, step, i), timgs[i])
                            print('Epoch {:>2}/{}, step = {:>6}/{:>6}, loss = {:.6f}, time = {}'
                                  .format(epoch, num_epochs, step, int(image_numbers / batch_size), tloss,
                                          time.time() - epoch_time))
                            print('---------------------------------------------------------------------------------')
                    if epoch % save_epoch_number == 0:
                        saver.save(sess, model_folder + 'epoch{}.ckpt'.format(epoch), global_step=global_step)
                        print('Model saved in: {}'.format(model_folder + 'epoch{}.ckpt'.format(epoch)))
            except tf.errors.OutOfRangeError:
                print('End training...')
            finally:
                total_time = time.time() - start_time
                saver.save(sess, modelfile, global_step=global_step)
                print('Model saved as: {}, running time: {} s'.format(modelfile, total_time))
                print('Done!')

    """
    imgs, kp_vers, kp_hors, kp_classses = sess.run([output, keypoints_ver, keypoints_hor, keypoints_class])
    img = imgs[0]
    kp_ver = kp_vers[0]
    kp_hor = kp_hors[0]
    kp_classs = kp_classses[0]

    dataset.draw_points_on_img(img, point_ver=kp_ver, point_hor=kp_hor, point_class=kp_classs)
    """


if __name__ == '__main__':
    main(gpu_device='/gpu:0')
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
"""
These are the utilities for deep learning, implemented with TensorFlow.

@ Author: Yu Sun.
vxallset@outlook.com

@ Date created: Jun 04, 2019

@ Last modified: Jun 06, 2019

"""
import tensorflow as tf


def leaky_Relu(input, name=''):
    return tf.nn.leaky_relu(input, alpha=0.1, name=name + '_relu')


def conv_2d(inputs, channels, kernel_size=3, strides=1, batch_normalization=True, activation=None,
            name='', padding='same', kernel_initializer=tf.random_normal_initializer(stddev=0.01), is_training=True):
    """2-D convolution, optionally followed by batch normalization and an activation function."""
    output = tf.layers.conv2d(inputs=inputs, filters=channels, kernel_size=kernel_size, strides=strides,
                              padding=padding, name=name + '_conv', kernel_initializer=kernel_initializer)
    name = name + '_conv'

    if batch_normalization:
        output = tf.layers.batch_normalization(output, axis=-1, momentum=0.9, name=name+'_bn', training=is_training)
        name = name + '_bn'

    if activation:
        output = activation(output, name=name)

    return output


def down_sampling(input, method='strided_convolution', rate=2, name='', activation=leaky_Relu, is_training=True):
    assert method == 'max_pooling' or method == 'strided_convolution', \
        'Unknown type of down_sample method! "strided_convolution" and "' \
        'max_pooling" are expected, but "' + method + '" is provided!'
    output = input

    if method == 'strided_convolution':
        _, _, _, channels = input.get_shape()
        channels = channels.value
        output = input
        loop_index = 1
        new_rate = rate
        # Each stride-2 convolution halves the resolution and doubles the channels, so a rate of 2 ** k uses
        # k convolutions and outputs channels * rate feature maps.
        while new_rate > 1:
            assert new_rate % 2 == 0, 'The rate of down_sampling (using "strided_convolution") must be a power of ' \
                                      '2, but "{}" is provided!'.format(rate)
            output = conv_2d(output, channels=channels * (2 ** loop_index), strides=2, activation=activation,
                             name=name + 'down_sampling' + '_x' + str(loop_index * 2), is_training=is_training)
            loop_index += 1
            new_rate = int(new_rate / 2)

    elif method == 'max_pooling':
        output = tf.layers.max_pooling2d(input, pool_size=rate, strides=rate, name=name+'_max_pooling')

    return output


def up_sampling(input, channels, method='nearest_neighbor', rate=2, name='', activation=leaky_Relu, is_training=True):
    assert method == 'nearest_neighbor', 'Only "nearest_neighbor" method is supported now! ' \
                                         'However, "' + method + '" is provided.'
    output = input
    if method == 'nearest_neighbor':
        _, x, y, _ = input.get_shape()
        x = x.value
        y = y.value

        output = tf.image.resize_nearest_neighbor(input, size=(x*rate, y*rate), name=name + '_upsampling')
        name += '_upsampling'
        # A 1 x 1 convolution aligns the number of channels with the target branch.
        output = conv_2d(output, channels=channels, kernel_size=1, activation=activation,
                         name=name + '_align_channels', is_training=is_training)

    return output


# Repeated multi-scale fusion (namely the exchange block) within a stage (the input and the output have the same
# number of sub-networks)
def exchange_within_stage(inputs, name='exchange_within_stage', is_training=True):
    with tf.variable_scope(name):
        subnetworks_number = len(inputs)
        outputs = []

        # suppose i is the index of the input sub-network, o is the index of the output sub-network
        for o in range(subnetworks_number):
            one_subnetwork = 0
            for i in range(subnetworks_number):
                if i == o:
                    # if in the same resolution
                    temp_subnetwork = inputs[i]
                elif i - o < 0:
                    # if the input resolution is greater than the output resolution, down-sampling with rate
                    # of 2 ** (o - i)
                    temp_subnetwork = down_sampling(inputs[i], rate=2 ** (o - i), name='i_{}_o_{}'.format(i, o),
                                                    is_training=is_training)
                else:
                    # if the input resolution is smaller than the output resolution, up-sampling with rate of
                    # 2 ** (i - o)
                    _, _, _, c = inputs[o].get_shape()
                    temp_subnetwork = up_sampling(inputs[i], channels=c, rate=2 ** (i - o),
                                                  name='i_{}_o_{}'.format(i, o), is_training=is_training)
                one_subnetwork = tf.add(temp_subnetwork, one_subnetwork, name='add_i_{}_o_{}'.format(i, o))
            outputs.append(one_subnetwork)
        return outputs


# Repeated multi-scale fusion (namely the exchange block) between two stages (the output has one more sub-network
# than the input: a new branch at half the lowest input resolution)
def exchange_between_stage(inputs,
                           name='exchange_between_stage', is_training=True):
    subnetworks_number = len(inputs)
    outputs = []

    # suppose i is the index of the input sub-network, o is the index of the output sub-network
    for o in range(subnetworks_number):
        one_subnetwork = 0
        for i in range(subnetworks_number):
            if i == o:
                # if in the same resolution
                temp_subnetwork = inputs[i]
            elif i - o < 0:
                # if the input resolution is greater than the output resolution, down-sampling with rate
                # of 2 ** (o - i)
                temp_subnetwork = down_sampling(inputs[i], rate=2 ** (o - i), name='i_{}_o_{}'.format(i, o),
                                                is_training=is_training)
            else:
                # if the input resolution is smaller than the output resolution, up-sampling with rate of
                # 2 ** (i - o)
                _, _, _, c = inputs[o].get_shape()
                temp_subnetwork = up_sampling(inputs[i], channels=c, rate=2 ** (i - o),
                                              name='i_{}_o_{}'.format(i, o), is_training=is_training)
            one_subnetwork = tf.add(temp_subnetwork, one_subnetwork, name='add_i_{}_o_{}'.format(i, o))
        outputs.append(one_subnetwork)
    # Add the new, lower-resolution branch by down-sampling the lowest-resolution input by 2.
    one_subnetwork = down_sampling(inputs[-1], rate=2, name='new_resolution', is_training=is_training)
    outputs.append(one_subnetwork)
    return outputs


def residual_unit_bottleneck(input, name='RU_bottleneck', channels=64, is_training=True):
    """
    Residual unit with bottleneck design, default width is 64.
    :param input: 4-D NHWC input tensor.
    :param name: prefix for the layer names.
    :param channels: width of the bottleneck (the 1x1 and 3x3 convolutions).
    :param is_training: passed to batch normalization.
    :return: tensor with the same shape as input.
    """
    _, _, _, c = input.get_shape()
    conv_1x1_1 = conv_2d(input, channels=channels, kernel_size=1, activation=leaky_Relu, name=name + '_conv1x1_1',
                         is_training=is_training)
    conv_3x3 = conv_2d(conv_1x1_1, channels=channels, activation=leaky_Relu, name=name + '_conv3x3',
                       is_training=is_training)
    conv_1x1_2 = conv_2d(conv_3x3, channels=c, kernel_size=1, name=name + '_conv1x1_2', is_training=is_training)
    _output = tf.add(input, conv_1x1_2, name=name + '_add')
    output = leaky_Relu(_output, name=name + '_out')
    return output


def residual_unit(input, name='RU', is_training=True):
    """
    Residual unit with two 3 x 3 convolution layers.
    :param input: 4-D NHWC input tensor.
    :param name: prefix for the layer names.
    :param is_training: passed to batch normalization.
    :return: tensor with the same shape as input.
    """
    _, _, _, channels = input.get_shape()
    conv3x3_1 = conv_2d(inputs=input, channels=channels, activation=leaky_Relu, name=name + '_conv3x3_1',
                        is_training=is_training)
    conv3x3_2 = conv_2d(inputs=conv3x3_1, channels=channels, name=name + '_conv3x3_2', is_training=is_training)
    _output = tf.add(input, conv3x3_2, name=name + '_add')
    output = leaky_Relu(_output, name=name + '_out')
    return output


def exchange_block(inputs, name='exchange_block', is_training=True):
    # Four residual units on every branch, followed by a multi-scale fusion across all branches.
    with tf.variable_scope(name):
        output = []
        level = 0
        for input in inputs:
            sub_network = residual_unit(input, name='level{}RU1'.format(level), is_training=is_training)
            sub_network = residual_unit(sub_network, name='level{}RU2'.format(level), is_training=is_training)
            sub_network = residual_unit(sub_network, name='level{}RU3'.format(level), is_training=is_training)
            sub_network = residual_unit(sub_network, name='level{}RU4'.format(level), is_training=is_training)
            output.append(sub_network)
            level += 1
        outputs = exchange_within_stage(output, is_training=is_training)
    return outputs
--------------------------------------------------------------------------------
/test_img/step11_i_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/test_img/step11_i_0.jpg
--------------------------------------------------------------------------------
/test_img/step136_i_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VXallset/deep-high-resolution-net.TensorFlow/d885abc6f8699f5dfd09b270170f3c68fbf32ac2/test_img/step136_i_0.jpg
--------------------------------------------------------------------------------
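The per-heatmap normalization at the end of `HRNet()` in `src/hrnet.py`, and the 1.75846 figure quoted in its comments, can be checked outside TensorFlow. The NumPy sketch below is an illustration only and is not part of the repository: `gaussian_kernel`, `normalize_heatmaps` and `zero_prediction_loss` are hypothetical helper names, sigma = 1 is an assumption (it reproduces the 3 x 3 kernel values quoted in the comments), and it assumes every one of the 14 ground-truth channels holds exactly one such kernel, as the comment's arithmetic implies.

```python
import numpy as np


def gaussian_kernel(size=3, sigma=1.0):
    """Return a size x size Gaussian kernel normalized to sum to 1.

    Assumption: sigma=1.0, which reproduces the values quoted in HRNet()
    (0.07511361, 0.1238414, 0.20417996).
    """
    half = size // 2
    ys, xs = np.mgrid[-half:half + 1, -half:half + 1]
    g = np.exp(-(xs ** 2 + ys ** 2) / (2.0 * sigma ** 2))
    return g / g.sum()


def normalize_heatmaps(output, eps=1e-10):
    """Divide each [H, W] heatmap of a [B, H, W, C] array by its own sum.

    Summing over axes (1, 2) with keepdims=True yields a [B, 1, 1, C] array,
    the same shape the nested tf.reduce_sum / tf.expand_dims calls build.
    """
    output_sum = output.sum(axis=(1, 2), keepdims=True)
    return output / (output_sum + eps)


def zero_prediction_loss(num_keypoints=14, sigma=1.0):
    """Sum of squared errors (as in compute_loss) for an all-zero prediction,
    assuming each ground-truth channel contains one 3 x 3 Gaussian."""
    g = gaussian_kernel(3, sigma)
    return num_keypoints * np.sum(g ** 2)


if __name__ == '__main__':
    print(np.round(gaussian_kernel(), 8))
    print(zero_prediction_loss())  # approximately 1.75846

    # Hypothetical raw network output: batch of 1, 256 x 192 pixels, 14 channels.
    raw = np.random.rand(1, 256, 192, 14)
    heatmaps = normalize_heatmaps(raw)
    print(heatmaps.sum(axis=(1, 2)))  # every entry is approximately 1.0
```

Summing over both spatial axes with `keepdims=True` is equivalent to the nested `tf.reduce_sum`/`tf.expand_dims` chain in `HRNet()`; in TensorFlow 1.5 or later the same tensor could likely be built with `tf.reduce_sum(output, axis=[1, 2], keepdims=True)`. Note that a channel whose raw output is all zeros stays all zeros, because of the `eps` term.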