├── .gitignore ├── README.md ├── media └── ori_images │ └── ski.jpg ├── pics ├── leaderboard.png └── ski.jpg ├── scripts ├── init_dir.sh ├── test.sh └── visualize.sh ├── src ├── __init__.py ├── dataset │ ├── __init__.py │ └── reader.py ├── models │ ├── __init__.py │ └── hourglass_ae.py └── utils │ ├── __init__.py │ ├── aux.py │ ├── nms.so │ └── visualize.py └── tools ├── test.py └── visualize_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | *.py[cod] 3 | *.so 4 | *.egg 5 | *.egg-info 6 | .idea 7 | dist 8 | build 9 | 10 | 11 | results 12 | tmp 13 | checkpoints 14 | !src/utils/nms.so 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | Code for reproducing [Associative Embedding: End-to-End Learning for Joint Detection and Grouping](https://arxiv.org/abs/1611.05424) 4 | 5 | ## Results 6 | 7 |
8 | 9 |   10 | 11 |
12 | 13 | ## Contents 14 | 15 | 1. src: source code, including the model, the data reader and utilities 16 | 2. tools: entry points for testing and visualization 17 | 3. scripts: shell scripts that run testing and visualization 18 | 19 | Other directories are self-explanatory. 20 | 21 | ## Requirements 22 | 23 | * Python 2.7 24 | * OpenCV 25 | * PyTorch v0.3.0 26 | * cuDNN 27 | * NumPy 28 | 29 | ## Instructions 30 | 31 | A pretrained model is available [here](https://pan.baidu.com/s/1nvKJlFz). Place it in the `./checkpoints` directory (the default path is `./checkpoints/model_best.pth.tar`) or modify the variable `CHECKPOINT` in `./scripts/test.sh`. 32 | 33 | All scripts use relative paths and should be run from the repository root. 34 | Run 35 | ``` 36 | ./scripts/init_dir.sh 37 | ``` 38 | to create the necessary directories. 39 | 40 | Run 41 | ``` 42 | ./scripts/test.sh 43 | ``` 44 | to test the model on the images in the `media/ori_images` directory. The results are written to `./results/test_results.json`. To test on your own images, change the variable `IMAGES_DIR` in `./scripts/test.sh`. 45 | 46 | In the same way, run 47 | ``` 48 | ./scripts/visualize.sh 49 | ``` 50 | to visualize the results. The rendered images will be saved to `./results/imgs`. 51 | 52 | ## Output Format 53 | The results are stored in JSON in the following format: 54 | 55 | ```json 56 | [ 57 | { 58 | "image_id": "a0f6bdc065a602b7b84a67fb8d14ce403d902e0d", 59 | "keypoint_annotations": { 60 | "human1": [261, 294, 1, 281, 328, 1, 0, 0, 0, 213, 295, 1, 208, 346, 1, 192, 335, 1, 245, 375, 1, 255, 432, 1, 244, 494, 1, 221, 379, 1, 219, 442, 1, 226, 491, 1, 226, 256, 1, 231, 284, 1], 61 | "human2": [313, 301, 1, 305, 337, 1, 321, 345, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 313, 359, 1, 320, 409, 1, 311, 454, 1, 0, 0, 0, 330, 409, 1, 324, 446, 1, 337, 284, 1, 327, 302, 1], 62 | "human3": [373, 304, 1, 346, 286, 1, 332, 263, 1, 0, 0, 0, 0, 0, 0, 345, 313, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 363, 386, 1, 361, 424, 1, 361, 475, 1, 365, 273, 1, 369, 297, 1] 63 | } 64 | } 65 | ] 66 | ``` 67 | 68 | A detailed explanation of the format can be found on the [AI Challenger](https://challenger.ai/competition/keypoint/subject) website. 69 | 70 | ## Notice 71 | 72 | * The release date of the training
code has not been decided yet. 73 | * The model is simplified due to GPU memory limitations: the output feature map is 8 times smaller than the input image in each spatial dimension. 74 | 75 |
76 | 77 |
78 | -------------------------------------------------------------------------------- /media/ori_images/ski.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/media/ori_images/ski.jpg -------------------------------------------------------------------------------- /pics/leaderboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/pics/leaderboard.png -------------------------------------------------------------------------------- /pics/ski.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/pics/ski.jpg -------------------------------------------------------------------------------- /scripts/init_dir.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p results 4 | mkdir -p results/imgs 5 | mkdir -p checkpoints 6 | mkdir -p tmp 7 | 8 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | IMAGES_DIR="./media/ori_images" 4 | RESULT_PATH="./results/test_results.json" 5 | CHECKPOINT="./checkpoints/model_best.pth.tar" 6 | 7 | export CUDA_VISIBLE_DEVICES=0 8 | export PYTHONPATH="$(pwd)${PYTHONPATH:+:${PYTHONPATH}}" # tools/test.py imports `src.*` from the repo root 9 | python tools/test.py \ 10 | "$IMAGES_DIR" \ 11 | --test_results="$RESULT_PATH" \ 12 | --resume="$CHECKPOINT" 13 | 14 | -------------------------------------------------------------------------------- /scripts/visualize.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export
CUDA_VISIBLE_DEVICES=0 4 | 5 | IMAGES_DIR="./media/ori_images" 6 | OUTPUT_DIR="./results/imgs" 7 | RESULT_PATH="./results/test_results.json" 8 | 9 | python tools/visualize_test.py \ 10 | $RESULT_PATH \ 11 | $IMAGES_DIR \ 12 | $OUTPUT_DIR \ 13 | 14 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | -------------------------------------------------------------------------------- /src/dataset/reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | 5 | class TestReader(object): 6 | def __init__(self, data_dir, transform=None): 7 | 8 | self.points_num = 14 9 | 10 | self.img_height = 512 11 | self.img_width = 512 12 | self.label_height = self.img_height >> 3 13 | self.label_width = self.img_width >> 3 14 | self.group_height = self.img_height >> 3 15 | self.group_width = self.img_width >> 3 16 | 17 | self.img_ids = list() 18 | self.img_paths = list() 19 | self.transform = transform 20 | 21 | for filename in os.listdir(data_dir): 22 | img_id = filename.split('.')[0] 23 | img_path = os.path.join(data_dir, filename) 24 | 25 | self.img_ids.append(img_id) 26 | self.img_paths.append(img_path) 27 | 28 | assert len(self.img_ids) == len(self.img_paths) 29 | 30 | print "Size of data for test: %d" % self.__len__() 31 | 32 | def __len__(self): 33 | return len(self.img_ids) 34 | 35 | def __getitem__(self, idx): 36 | img = Image.open(self.img_paths[idx]) 37 | f_img = img.transpose(Image.FLIP_LEFT_RIGHT) 38 | width, height = img.size 39 | 40 | img_1 = self._resize(img, self.img_height, 
self.img_width) 41 | img_2 = self._resize(img, self.img_height + 128, self.img_width + 128) 42 | 43 | f_img_1 = self._resize(f_img, self.img_height, self.img_width) 44 | f_img_2 = self._resize(f_img, self.img_height + 128, self.img_width + 128) 45 | 46 | if self.transform: 47 | img_1 = self.transform(img_1) 48 | img_2 = self.transform(img_2) 49 | 50 | f_img_1 = self.transform(f_img_1) 51 | f_img_2 = self.transform(f_img_2) 52 | 53 | imgs = ( 54 | img_1, 55 | img_2, 56 | f_img_1, 57 | f_img_2, 58 | ) 59 | 60 | sample = (imgs, (self.label_height, self.label_width), (height, width), 61 | self.img_ids[idx], self.img_paths[idx]) 62 | 63 | return sample 64 | 65 | def _resize(self, img, img_height, img_width): 66 | out_img = img.resize((img_width, img_height), Image.BILINEAR) 67 | 68 | return out_img 69 | 70 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | -------------------------------------------------------------------------------- /src/models/hourglass_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['HourglassNet', 'hg'] 6 | 7 | 8 | class ConvAct(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size, stride=1): 10 | super(ConvAct, self).__init__() 11 | padding = kernel_size / 2 12 | 13 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True) 14 | self.sigmoid = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = x * self.sigmoid(x) 19 | 20 | return x 21 | 22 | class ActConv(nn.Module): 23 | def __init__(self, inplanes, planes, kernel_size, stride=1): 24 | super(ActConv, self).__init__() 25 | padding = kernel_size / 2 26 | 27 | self.sigmoid = nn.Sigmoid() 28 
| self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True) 29 | 30 | def forward(self, x): 31 | x = x * self.sigmoid(x) 32 | x = self.conv(x) 33 | 34 | return x 35 | 36 | 37 | class Conv(nn.Module): 38 | def __init__(self, inplanes, planes, kernel_size, stride=1): 39 | super(Conv, self).__init__() 40 | padding = kernel_size / 2 41 | 42 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True) 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | 47 | return x 48 | 49 | def make_conv(in_f, out_f, kernel_size, stride=1, mode='ConvAct'): 50 | if mode == "ConvAct": 51 | return nn.Sequential( 52 | ConvAct(in_f, out_f, kernel_size, stride) 53 | ) 54 | elif mode == "Conv": 55 | return nn.Sequential( 56 | Conv(in_f, out_f, kernel_size, stride) 57 | ) 58 | elif mode == "ActConv": 59 | return nn.Sequential( 60 | ActConv(in_f, out_f, kernel_size, stride) 61 | ) 62 | 63 | 64 | class Hourglass(nn.Module): 65 | def __init__(self, f): 66 | super(Hourglass, self).__init__() 67 | self.n = 4 68 | self.kernel_size = 3 69 | self.f = f 70 | self.g = f >> 1 71 | 72 | self.upsample = nn.Upsample(scale_factor=2) 73 | self.hg = self._make_hg(self.n) 74 | 75 | def _make_conv(self, in_f, out_f, kernel_size): 76 | return nn.Sequential( 77 | ActConv(in_f, out_f, kernel_size) 78 | ) 79 | 80 | def _make_hg(self, n = 4): 81 | hg = [] 82 | f3 = self.f 83 | f2 = self.f + self.g 84 | f1 = self.f + self.g * 2 85 | f0 = self.f + self.g * 3 86 | ff = self.f + self.g * 4 87 | 88 | f_difc = { 89 | 3: (f3, f2), 90 | 2: (f2, f1), 91 | 1: (f1, f0) 92 | } 93 | for i in range(n): 94 | tmp = [] 95 | if i == 0: 96 | tmp.append(self._make_conv(f0, f0, self.kernel_size)) 97 | tmp.append(self._make_conv(f0, ff, self.kernel_size)) 98 | tmp.append(self._make_conv(ff, f0, self.kernel_size)) 99 | tmp.append(self._make_conv(ff, ff, self.kernel_size)) 100 | else: 101 | tmp.append(self._make_conv(f_difc[i][0], 
f_difc[i][0], self.kernel_size)) 102 | tmp.append(self._make_conv(f_difc[i][0], f_difc[i][1], self.kernel_size)) 103 | tmp.append(self._make_conv(f_difc[i][1], f_difc[i][0], self.kernel_size)) 104 | hg.append(nn.ModuleList(tmp)) 105 | return nn.ModuleList(hg) 106 | 107 | def _hg_forward(self, n, x): 108 | up1 = self.hg[n-1][0](x) 109 | low = F.max_pool2d(x, 2, stride=2) 110 | low = self.hg[n-1][1](low) 111 | 112 | if n > 1: 113 | low = self._hg_forward(n-1, low) 114 | else: 115 | low = self.hg[n-1][3](low) 116 | 117 | low = self.hg[n-1][2](low) 118 | up2 = self.upsample(low) 119 | 120 | return up1 + up2 121 | 122 | def forward(self, x): 123 | return self._hg_forward(self.n, x) 124 | 125 | class StageBase(nn.Module): 126 | def __init__(self, f): 127 | super(StageBase, self).__init__() 128 | self.maxpool = nn.MaxPool2d(2, stride=2) 129 | 130 | self.conv1 = make_conv(3, 64, 7, stride=2, mode='ConvAct') 131 | self.conv2 = make_conv(64, 128, 3, mode="ConvAct") 132 | 133 | self.conv3 = make_conv(128, 128, 3, mode='ConvAct') 134 | self.conv4 = make_conv(128, 128, 3, mode='ConvAct') 135 | 136 | self.conv5 = make_conv(128, 256, 3, mode='ConvAct') 137 | self.conv6 = make_conv(256, f, 3, mode='Conv') 138 | 139 | def forward(self, x): 140 | 141 | x = self.conv1(x) 142 | x = self.conv2(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.conv3(x) 146 | x = self.conv4(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.conv5(x) 150 | x = self.conv6(x) 151 | 152 | return x 153 | 154 | class HourglassNet(nn.Module): 155 | 156 | def __init__(self, f=256, num_stacks=2 , num_classes=14, embedding_len=14): 157 | super(HourglassNet, self).__init__() 158 | 159 | self.num_stacks = num_stacks 160 | 161 | self.stage_base = self._make(StageBase(f)) 162 | 163 | hg, fc1, fc2, score, embedding, fc_, score_ = [], [], [], [], [], [], [] 164 | 165 | for i in range(num_stacks): 166 | hg.append(Hourglass(f)) 167 | fc1.append(make_conv(f, f, kernel_size=3, mode="ActConv")) 168 | fc2.append(make_conv(f, 
f, kernel_size=1, mode="ActConv")) 169 | score.append(make_conv(f, num_classes, kernel_size=1, mode="Conv")) 170 | embedding.append(make_conv(f, embedding_len, kernel_size=1, mode="Conv")) 171 | if i < num_stacks-1: 172 | fc_.append(make_conv(f, f, kernel_size=1, mode="Conv")) 173 | score_.append(make_conv(num_classes+embedding_len, f, kernel_size=1, mode="Conv")) 174 | 175 | self.hg = nn.ModuleList(hg) 176 | self.fc1 = nn.ModuleList(fc1) 177 | self.fc2 = nn.ModuleList(fc2) 178 | self.score = nn.ModuleList(score) 179 | self.embedding = nn.ModuleList(embedding) 180 | self.fc_ = nn.ModuleList(fc_) 181 | self.score_ = nn.ModuleList(score_) 182 | 183 | def _make(self, target): 184 | return nn.Sequential( 185 | target 186 | ) 187 | 188 | def forward(self, x): 189 | out = [] 190 | 191 | x = self.stage_base(x) 192 | 193 | for i in range(self.num_stacks): 194 | y = self.hg[i](x) 195 | y = self.fc1[i](y) 196 | y = self.fc2[i](y) 197 | score = self.score[i](y) 198 | embedding = self.embedding[i](y) 199 | out.append((score, embedding)) 200 | if i < self.num_stacks-1: 201 | fc_ = self.fc_[i](y) 202 | score_ = self.score_[i](torch.cat((score, embedding), dim=1)) 203 | x = x + fc_ + score_ 204 | 205 | return out 206 | 207 | 208 | def hg(**kwargs): 209 | model = HourglassNet(f=kwargs['f'], num_stacks=kwargs['num_stacks'], embedding_len=kwargs['embedding_len'], 210 | num_classes=kwargs['num_classes']) 211 | return model 212 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | -------------------------------------------------------------------------------- /src/utils/aux.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from nms import nms 4 | 5 | 6 | def nms_fm(detect_fm, group_fm, point_num=14, threshold=0.2, 
extra_space=9): 7 | """ 8 | :param detect_fm: fm of detection, shape=[C1, H, W], numpy array 9 | :param group_fm: fm of grouping, shape=[C2, H, W], numpy array 10 | :param point_num: the number of considered keypoint 11 | :param threshold: threshold to select activation, float 12 | :param extra_space: extra_space of the nms, int 13 | :return: list of tuple, ((y, x), score, tag), (y, x) -> coordinate, score -> confidence score, tag -> vector for grouping 14 | """ 15 | out_list = list() 16 | for idx in range(point_num): 17 | out_list.append([]) 18 | 19 | cc, yy, xx = nms(detect_fm, threshold, extra_space) 20 | cc = cc.tolist() 21 | yy = yy.tolist() 22 | xx = xx.tolist() 23 | assert len(yy) == len(xx) 24 | assert len(cc) == len(xx) 25 | tmp_len = len(xx) 26 | for i in xrange(tmp_len): 27 | c = cc[i] 28 | y = yy[i] 29 | x = xx[i] 30 | score = detect_fm[c, y, x] 31 | tag = group_fm[c, y, x, :] 32 | out_list[c].append(((y, x), score, tag)) 33 | 34 | return out_list 35 | 36 | 37 | def group_with_keypoint(in_list, ori_h, ori_w, cur_h, cur_w, threshold=1, min_part_num=3, point_num=14): 38 | def any_true(check_list): 39 | for item in check_list: 40 | for subitem in item: 41 | if subitem: 42 | return True 43 | return False 44 | 45 | def resize_coord(coord): 46 | y, x = coord 47 | y = int(float(y) / cur_h * ori_h + ori_h / cur_h / 2) 48 | x = int(float(x) / cur_w * ori_w + ori_w / cur_w / 2) 49 | return (y, x) 50 | 51 | def tag_dis(lhs, rhs): 52 | return np.linalg.norm(lhs - rhs) 53 | 54 | check_seq = [12, 13, 0, 3, 6, 9, 1, 2, 4, 5, 7, 8, 10, 11] 55 | check_list = list() 56 | out_dict = dict() 57 | human_count = 0 58 | for idx in range(point_num): 59 | item_len = len(in_list[idx]) 60 | check_list.append([True] * item_len) 61 | 62 | while (any_true(check_list)): 63 | human_count += 1 64 | human_name = "human%d" % (human_count) 65 | tmp_coords = np.zeros(point_num * 3, dtype=np.int32).reshape(point_num, 3) 66 | part_count = 0 67 | 68 | finish = False 69 | for i in check_seq: 70 
| if finish: 71 | break 72 | for j in range(len(in_list[i])): 73 | if check_list[i][j]: 74 | cur_coord, score, tag = in_list[i][j] 75 | y, x = resize_coord(cur_coord) 76 | tmp_coords[i][0] = x 77 | tmp_coords[i][1] = y 78 | tmp_coords[i][2] = 1 79 | check_list[i][j] = False 80 | others = [k for k in range(point_num) if k != i] 81 | for ii in others: 82 | max_score = 0. 83 | for jj in range(len(in_list[ii])): 84 | if check_list[ii][jj]: 85 | cur_coord, sub_score, sub_tag = in_list[ii][jj] 86 | yy, xx = resize_coord(cur_coord) 87 | if tag_dis(tag, sub_tag) < threshold and check_list[ii][jj] and sub_score > max_score: 88 | max_score = sub_score 89 | tmp_coords[ii][0] = xx 90 | tmp_coords[ii][1] = yy 91 | tmp_coords[ii][2] = 1 92 | check_list[ii][jj] = False 93 | part_count += 1 94 | finish = True 95 | break 96 | if part_count >= min_part_num: 97 | out_dict[human_name] = tmp_coords.reshape(-1).tolist() 98 | if len(out_dict) == 0: 99 | out_dict['human1'] = [0] * point_num * 3 100 | return out_dict 101 | 102 | 103 | def flip_fm(fm): 104 | left_right_pair = ( 105 | (0, 3), 106 | (1, 4), 107 | (2, 5), 108 | (6, 9), 109 | (7, 10), 110 | (8, 11), 111 | ) 112 | out_fm = np.empty_like(fm) 113 | out_fm[12, :, :] = fm[12, :, :] 114 | out_fm[13, :, :] = fm[13, :, :] 115 | 116 | for idx1, idx2 in left_right_pair: 117 | out_fm[idx1, :, :] = fm[idx2, :, :] 118 | out_fm[idx2, :, :] = fm[idx1, :, :] 119 | 120 | out_fm = out_fm[:, :, ::-1] 121 | return out_fm 122 | 123 | 124 | def integrate_fm_group(detect_fm_list, group_fm_list, height, width): 125 | """ 126 | :param detect_fm_list: list of detect_fm 127 | :param group_fm_list: list of group_fm 128 | :return: integrated detect_fm and group_fm 129 | """ 130 | resized_detect_list = list() 131 | resized_group_list = list() 132 | 133 | for fm in detect_fm_list: 134 | resized_detect_list.append(resize_fm(fm, height, width)) 135 | 136 | for fm in group_fm_list: 137 | tmp = list() 138 | resized_group_list.append(resize_fm(fm, height, width)) 
139 | out_detect_fm = sum(resized_detect_list) / len(resized_detect_list) 140 | out_group_fm = np.stack(resized_group_list, axis=-1) 141 | 142 | return out_detect_fm, out_group_fm 143 | 144 | 145 | def resize_fm(fm, dst_h, dst_w): 146 | """ 147 | :param fm: [C, H, W] 148 | :param dst_h: 149 | :param dst_w: 150 | :return: 151 | """ 152 | out_list = list() 153 | for item in fm: 154 | out_list.append(cv2.resize(item, (dst_w, dst_h))) 155 | output = np.stack(out_list, axis=0) 156 | return output 157 | 158 | -------------------------------------------------------------------------------- /src/utils/nms.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/src/utils/nms.so -------------------------------------------------------------------------------- /src/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import cv2 3 | import numpy as np 4 | from PIL import ImageDraw 5 | 6 | def draw_point_on_img_dict(img, in_dict, label=None, point_num = 14): 7 | """ 8 | :param img: ori_img, [B, 3, H, W] 9 | """ 10 | colors = ( 11 | (0, 255, 255), 12 | (255, 0, 255), 13 | (255, 255, 0), 14 | (255, 0, 0), 15 | (0, 255, 0), 16 | (0, 0, 255), 17 | (0, 127, 255), 18 | (255, 0, 127), 19 | (127, 255, 0), 20 | (0, 255, 127), 21 | (127, 0, 255), 22 | (255, 127, 0) 23 | ) 24 | line_pair = ( 25 | (1, 2), 26 | (2, 3), 27 | (4, 5), 28 | (5, 6), 29 | (7, 8), 30 | (8, 9), 31 | (10, 11), 32 | (11, 12), 33 | (13, 14), 34 | ) 35 | 36 | if label is not None: 37 | length = len(label) 38 | for idx in xrange(length): 39 | coords = label[idx].reshape(point_num, 3).astype(np.int32) 40 | xx = coords[:, 0] 41 | yy = coords[:, 1] 42 | vv = coords[:, 2] 43 | for i in xrange(point_num): 44 | x = xx[i] 45 | y = yy[i] 46 | v = vv[i] 47 | if v == 1: 48 | cv2.circle(img, (x, y), 2, 
(255, 255, 255), 2) 49 | 50 | idx = 0 51 | for key, val in in_dict.items(): 52 | coords = np.array(val).reshape(point_num, 3).astype(np.int32) 53 | xx = coords[:, 0] 54 | yy = coords[:, 1] 55 | vv = coords[:, 2] 56 | for i in xrange(point_num): 57 | x = xx[i] 58 | y = yy[i] 59 | v = vv[i] 60 | if v == 1: 61 | cv2.circle(img, (x, y), 2, colors[idx], 2) 62 | for i1, i2 in line_pair: 63 | i1 -= 1 64 | i2 -= 1 65 | 66 | x1 = xx[i1] 67 | y1 = yy[i1] 68 | v1 = vv[i1] 69 | 70 | x2 = xx[i2] 71 | y2 = yy[i2] 72 | v2 = vv[i2] 73 | 74 | if v1 == 1 and v2 == 1: 75 | cv2.line(img, (x1, y1), (x2, y2), colors[idx]) 76 | if idx < len(colors): 77 | idx += 1 78 | else: 79 | idx = len(colors) - 1 80 | 81 | plt.imshow(img) 82 | plt.show() 83 | 84 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import argparse 5 | import json 6 | import math 7 | 8 | import torch 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim 12 | import torch.utils.data 13 | import torchvision.transforms as transforms 14 | 15 | import src.models.hourglass_ae 16 | import src.dataset.reader 17 | import src.utils.aux 18 | 19 | parser = argparse.ArgumentParser(description='AI Challenger Keypoints Detection') 20 | parser.add_argument('test_data_dir', metavar='DIR', 21 | help='path to test_data_dir') 22 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 23 | help='number of data loading workers (default: 4)') 24 | parser.add_argument('-b', '--batch-size', default=1, type=int, 25 | metavar='N', help='mini-batch size (default: 1)') 26 | parser.add_argument('--checkpoint_path', default='./results/hg_ae.pth.tar', type=str, metavar='PATH', 27 | help='path to checkpoint') 28 | parser.add_argument('--test_results', default='./results/multi_test_results.json', type=str, 
metavar='PATH', 29 | help='path to test results.') 30 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 31 | help='path to checkpoint') 32 | 33 | best_prec1 = 0 34 | 35 | def main(): 36 | global args, best_prec1 37 | args = parser.parse_args() 38 | 39 | # create model 40 | model = src.models.hourglass_ae.hg(f=256, num_stacks=3, num_classes=14, embedding_len=14) 41 | 42 | model = torch.nn.DataParallel(model).cuda() 43 | 44 | checkpoint = torch.load(args.resume) 45 | model.load_state_dict(checkpoint['state_dict']) 46 | 47 | all_model = (model, ) 48 | 49 | cudnn.benchmark = True 50 | 51 | normalize = transforms.Normalize(mean=[0.4798, 0.4517, 0.4220], 52 | std=[0.2558, 0.2481, 0.2468]) 53 | test_dataset = src.dataset.reader.TestReader( 54 | data_dir=args.test_data_dir, 55 | transform=transforms.Compose([ 56 | transforms.ToTensor(), 57 | normalize 58 | ]) 59 | ) 60 | test_loader = torch.utils.data.DataLoader( 61 | test_dataset, batch_size=1, shuffle=False, 62 | num_workers=args.workers, pin_memory=True) 63 | 64 | test(test_loader, all_model) 65 | 66 | def test(test_loader, models): 67 | # switch to evaluate mode 68 | for model in models: 69 | model.eval() 70 | 71 | out_data = list() 72 | print "Testing..." 
73 | for i, (img, cur_size, ori_size, img_id, img_path) in enumerate(test_loader): 74 | img_id = img_id[0] 75 | 76 | ori_h, ori_w = ori_size 77 | ori_h = ori_h.numpy()[0] 78 | ori_w = ori_w.numpy()[0] 79 | 80 | cur_h, cur_w = cur_size 81 | cur_h = cur_h.numpy()[0] 82 | cur_w = cur_w.numpy()[0] 83 | 84 | detect_list = list() 85 | group_list = list() 86 | for model in models: 87 | n = len(img) 88 | for img_idx, input_img in enumerate(img): 89 | input_img = input_img.cuda(async=True) 90 | img_var = torch.autograd.Variable(input_img, volatile=True) 91 | 92 | output = model(img_var) 93 | output = output[-1] 94 | out_detect = output[0] 95 | out_group = output[1] 96 | 97 | out_detect = out_detect.data.cpu().numpy()[0] 98 | out_group = out_group.data.cpu().numpy()[0] 99 | if img_idx >= n/2: 100 | # The back half of the images is flipped 101 | out_detect = src.utils.aux.flip_fm(out_detect) 102 | out_group = src.utils.aux.flip_fm(out_group) 103 | detect_list.append(out_detect) 104 | group_list.append(out_group) 105 | 106 | out_detect, out_group = src.utils.aux.integrate_fm_group(detect_list, group_list, cur_h, cur_w) 107 | 108 | out_list = src.utils.aux.nms_fm(out_detect, out_group, threshold=0.3, extra_space=5) 109 | keypoint_annos = src.utils.aux.group_with_keypoint(out_list, ori_h, ori_w, 110 | cur_h, cur_w, threshold=0.5 * math.sqrt(n * len(models))) 111 | if (i + 1) % 1000 == 0: 112 | print ("finish %d images" % (i+1)) 113 | 114 | item = dict() 115 | item['image_id'] = img_id 116 | item['keypoint_annotations'] = keypoint_annos 117 | out_data.append(item) 118 | 119 | print ("Saving results at %s" % args.test_results) 120 | with open(args.test_results, 'wb') as fw: 121 | json.dump(out_data, fw) 122 | 123 | return 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /tools/visualize_test.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | import argparse 5 | import os 6 | import json 7 | import random 8 | import cv2 9 | 10 | import src.utils.visualize 11 | 12 | parser = argparse.ArgumentParser(description='AI Challenger') 13 | parser.add_argument('test_results', metavar='FILE', 14 | help='path to test_results') 15 | parser.add_argument('test_data_dir', metavar='DIR', 16 | help='path to test_data_dir') 17 | parser.add_argument('output_dir', metavar='DIR', 18 | help='path to output_dir') 19 | 20 | def main(): 21 | global args 22 | args = parser.parse_args() 23 | 24 | with open(args.test_results, 'r') as fr: 25 | data = json.load(fr) 26 | random.shuffle(data) 27 | 28 | for item in data: 29 | img_path = os.path.join(args.test_data_dir, item['image_id'] + '.jpg') 30 | img_np = cv2.imread(img_path) 31 | if img_np is None: # cv2.imread returns None for missing/unreadable files (e.g. non-.jpg inputs) 32 | print("Cannot read image %s, skipping" % img_path) 33 | continue 34 | img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) 35 | keypoint_dict = item['keypoint_annotations'] 36 | src.utils.visualize.draw_point_on_img_dict(img_np, keypoint_dict) 37 | 38 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 39 | cv2.imwrite(os.path.join(args.output_dir, item['image_id']+".jpg"), img_np) 40 | 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | --------------------------------------------------------------------------------