├── .gitignore
├── README.md
├── media
│   └── ori_images
│       └── ski.jpg
├── pics
│   ├── leaderboard.png
│   └── ski.jpg
├── scripts
│   ├── init_dir.sh
│   ├── test.sh
│   └── visualize.sh
├── src
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── reader.py
│   ├── models
│   │   ├── __init__.py
│   │   └── hourglass_ae.py
│   └── utils
│       ├── __init__.py
│       ├── aux.py
│       ├── nms.so
│       └── visualize.py
└── tools
    ├── test.py
    └── visualize_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # python
2 | *.py[cod]
3 | *.so
4 | *.egg
5 | *.egg-info
6 | .idea
7 | dist
8 | build
9 |
10 |
11 | results
12 | tmp
13 | checkpoints
14 | !src/utils/nms.so
15 |
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | Code for reproducing [Associative Embedding: End-to-End Learning for Joint Detection and Grouping](https://arxiv.org/abs/1611.05424).
4 |
5 | ## Results
6 |
 7 | ![leaderboard](pics/leaderboard.png)
 8 |
 9 | ![example result](pics/ski.jpg)
10 |
11 |
12 |
13 | ## Contents
14 |
15 | 1. src: source code, including the model, data reader, and utils
16 | 2. tools: main entry points that run testing and visualization
17 | 3. scripts: shell scripts that launch testing and visualization
18 |
19 | Other directories are self-explanatory.
20 |
21 | ## Requirements
22 |
23 | * Python 2.7
24 | * OpenCV
25 | * PyTorch v0.3.0
26 | * cuDNN
27 | * numpy
28 |
29 | ## Instructions
30 |
31 | A pretrained model is available [here](https://pan.baidu.com/s/1nvKJlFz). Place it in the `./checkpoints` directory, or point the variable `CHECKPOINT` in `./scripts/test.sh` at its location.
32 |
33 |
34 | Run
35 | ```
36 | ./scripts/init_dir.sh
37 | ```
38 | to create the necessary directories.
39 |
40 | Run
41 | ```
42 | ./scripts/test.sh
43 | ```
44 | to run the model on the images in the `media` directory, or change the variable `IMAGES_DIR` in `./scripts/test.sh` to test on your own images.
45 |
46 | In the same way, run
47 | ```
48 | ./scripts/visualize.sh
49 | ```
50 | to visualize the results. The rendered images will be saved to `./results/imgs`.
51 |
52 | ## Output Format
53 | Results are saved in JSON as follows:
54 |
55 | ```json
56 | [
57 |   {
58 |     "image_id": "a0f6bdc065a602b7b84a67fb8d14ce403d902e0d",
59 |     "keypoint_annotations": {
60 |       "human1": [261, 294, 1, 281, 328, 1, 0, 0, 0, 213, 295, 1, 208, 346, 1, 192, 335, 1, 245, 375, 1, 255, 432, 1, 244, 494, 1, 221, 379, 1, 219, 442, 1, 226, 491, 1, 226, 256, 1, 231, 284, 1],
61 |       "human2": [313, 301, 1, 305, 337, 1, 321, 345, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 313, 359, 1, 320, 409, 1, 311, 454, 1, 0, 0, 0, 330, 409, 1, 324, 446, 1, 337, 284, 1, 327, 302, 1],
62 |       "human3": [373, 304, 1, 346, 286, 1, 332, 263, 1, 0, 0, 0, 0, 0, 0, 345, 313, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 363, 386, 1, 361, 424, 1, 361, 475, 1, 365, 273, 1, 369, 297, 1]
63 |     }
64 |   }
65 | ]
66 | ```
67 |
68 | A detailed explanation of the format can be found on the [AI Challenger](https://challenger.ai/competition/keypoint/subject) website.
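For a quick check of a results file, here is a minimal sketch (run from the repository root; the path matches the default `RESULT_PATH` in `./scripts/test.sh`):

```python
import json

# Load the results written by ./scripts/test.sh.
with open("./results/test_results.json") as fr:
    results = json.load(fr)

for item in results:
    annos = item["keypoint_annotations"]
    print("image %s: %d people" % (item["image_id"], len(annos)))
    for human, flat in sorted(annos.items()):
        # Each person is a flat list of 14 keypoints stored as (x, y, v) triples;
        # v == 1 marks a detected keypoint and (0, 0, 0) marks a missing one.
        visible = sum(1 for i in range(0, len(flat), 3) if flat[i + 2] == 1)
        print("  %s: %d / 14 keypoints" % (human, visible))
```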
69 |
70 | ## Notice
71 |
72 | * A release date for the training code has not been decided yet.
73 | * The model is simplified due to GPU memory limitations: the output feature map is 8 times smaller than the input image.
74 |
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/media/ori_images/ski.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/media/ori_images/ski.jpg
--------------------------------------------------------------------------------
/pics/leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/pics/leaderboard.png
--------------------------------------------------------------------------------
/pics/ski.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/pics/ski.jpg
--------------------------------------------------------------------------------
/scripts/init_dir.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p results
4 | mkdir -p results/imgs
5 | mkdir -p checkpoints
6 | mkdir -p tmp
7 |
8 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | IMAGES_DIR="./media/ori_images"
4 | RESULT_PATH="./results/test_results.json"
5 | CHECKPOINT="./checkpoints/model_best.pth.tar"
6 |
7 | export CUDA_VISIBLE_DEVICES=0
8 |
9 | python tools/test.py \
10 | $IMAGES_DIR \
11 | --test_results=$RESULT_PATH \
12 | --resume=$CHECKPOINT
13 |
14 |
--------------------------------------------------------------------------------
/scripts/visualize.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0
4 |
5 | IMAGES_DIR="./media/ori_images"
6 | OUTPUT_DIR="./results/imgs"
7 | RESULT_PATH="./results/test_results.json"
8 |
9 | python tools/visualize_test.py \
10 | $RESULT_PATH \
11 | $IMAGES_DIR \
12 | $OUTPUT_DIR
13 |
14 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
--------------------------------------------------------------------------------
/src/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
--------------------------------------------------------------------------------
/src/dataset/reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 |
5 | class TestReader(object):
6 | def __init__(self, data_dir, transform=None):
7 |
8 | self.points_num = 14
9 |
10 | self.img_height = 512
11 | self.img_width = 512
12 | self.label_height = self.img_height >> 3
13 | self.label_width = self.img_width >> 3
14 | self.group_height = self.img_height >> 3
15 | self.group_width = self.img_width >> 3
16 |
17 | self.img_ids = list()
18 | self.img_paths = list()
19 | self.transform = transform
20 |
21 | for filename in os.listdir(data_dir):
22 | img_id = filename.split('.')[0]
23 | img_path = os.path.join(data_dir, filename)
24 |
25 | self.img_ids.append(img_id)
26 | self.img_paths.append(img_path)
27 |
28 | assert len(self.img_ids) == len(self.img_paths)
29 |
30 | print "Size of data for test: %d" % self.__len__()
31 |
32 | def __len__(self):
33 | return len(self.img_ids)
34 |
35 | def __getitem__(self, idx):
36 | img = Image.open(self.img_paths[idx])
37 | f_img = img.transpose(Image.FLIP_LEFT_RIGHT)
38 | width, height = img.size
39 |
40 | img_1 = self._resize(img, self.img_height, self.img_width)
41 | img_2 = self._resize(img, self.img_height + 128, self.img_width + 128)
42 |
43 | f_img_1 = self._resize(f_img, self.img_height, self.img_width)
44 | f_img_2 = self._resize(f_img, self.img_height + 128, self.img_width + 128)
45 |
46 | if self.transform:
47 | img_1 = self.transform(img_1)
48 | img_2 = self.transform(img_2)
49 |
50 | f_img_1 = self.transform(f_img_1)
51 | f_img_2 = self.transform(f_img_2)
52 |
53 | imgs = (
54 | img_1,
55 | img_2,
56 | f_img_1,
57 | f_img_2,
58 | )
59 |
60 | sample = (imgs, (self.label_height, self.label_width), (height, width),
61 | self.img_ids[idx], self.img_paths[idx])
62 |
63 | return sample
64 |
65 | def _resize(self, img, img_height, img_width):
66 | out_img = img.resize((img_width, img_height), Image.BILINEAR)
67 |
68 | return out_img
69 |
70 |
--------------------------------------------------------------------------------
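A minimal sketch of what one `TestReader` sample contains, assuming Python 2.7 (per the README) and the bundled `./media/ori_images` directory, run from the repository root:

```python
from torchvision import transforms
from src.dataset.reader import TestReader

reader = TestReader("./media/ori_images", transform=transforms.ToTensor())
imgs, label_size, ori_size, img_id, img_path = reader[0]

# imgs holds four tensors: the image resized to 512x512 and 640x640, plus the same
# two sizes of the horizontally flipped image (test-time augmentation).
print(img_id)        # filename without extension, later used as image_id in the results JSON
print(label_size)    # (64, 64): heatmaps are predicted at 1/8 of the 512x512 input
print(ori_size)      # original (height, width) of the image on disk
for t in imgs:
    print(t.size())  # torch.Size([3, 512, 512]) or torch.Size([3, 640, 640])
```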
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
--------------------------------------------------------------------------------
/src/models/hourglass_ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['HourglassNet', 'hg']
6 |
7 |
8 | class ConvAct(nn.Module):
9 | def __init__(self, inplanes, planes, kernel_size, stride=1):
10 | super(ConvAct, self).__init__()
11 | padding = kernel_size / 2
12 |
13 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
14 | self.sigmoid = nn.Sigmoid()
15 |
16 | def forward(self, x):
17 | x = self.conv(x)
18 | x = x * self.sigmoid(x)
19 |
20 | return x
21 |
22 | class ActConv(nn.Module):
23 | def __init__(self, inplanes, planes, kernel_size, stride=1):
24 | super(ActConv, self).__init__()
25 | padding = kernel_size / 2
26 |
27 | self.sigmoid = nn.Sigmoid()
28 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
29 |
30 | def forward(self, x):
31 | x = x * self.sigmoid(x)
32 | x = self.conv(x)
33 |
34 | return x
35 |
36 |
37 | class Conv(nn.Module):
38 | def __init__(self, inplanes, planes, kernel_size, stride=1):
39 | super(Conv, self).__init__()
40 | padding = kernel_size / 2
41 |
42 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 |
47 | return x
48 |
49 | def make_conv(in_f, out_f, kernel_size, stride=1, mode='ConvAct'):
50 | if mode == "ConvAct":
51 | return nn.Sequential(
52 | ConvAct(in_f, out_f, kernel_size, stride)
53 | )
54 | elif mode == "Conv":
55 | return nn.Sequential(
56 | Conv(in_f, out_f, kernel_size, stride)
57 | )
58 | elif mode == "ActConv":
59 | return nn.Sequential(
60 | ActConv(in_f, out_f, kernel_size, stride)
61 | )
62 |
63 |
64 | class Hourglass(nn.Module):
65 | def __init__(self, f):
66 | super(Hourglass, self).__init__()
67 | self.n = 4
68 | self.kernel_size = 3
69 | self.f = f
70 | self.g = f >> 1
71 |
72 | self.upsample = nn.Upsample(scale_factor=2)
73 | self.hg = self._make_hg(self.n)
74 |
75 | def _make_conv(self, in_f, out_f, kernel_size):
76 | return nn.Sequential(
77 | ActConv(in_f, out_f, kernel_size)
78 | )
79 |
80 | def _make_hg(self, n = 4):
81 | hg = []
82 | f3 = self.f
83 | f2 = self.f + self.g
84 | f1 = self.f + self.g * 2
85 | f0 = self.f + self.g * 3
86 | ff = self.f + self.g * 4
87 |
88 | f_difc = {
89 | 3: (f3, f2),
90 | 2: (f2, f1),
91 | 1: (f1, f0)
92 | }
93 | for i in range(n):
94 | tmp = []
95 | if i == 0:
96 | tmp.append(self._make_conv(f0, f0, self.kernel_size))
97 | tmp.append(self._make_conv(f0, ff, self.kernel_size))
98 | tmp.append(self._make_conv(ff, f0, self.kernel_size))
99 | tmp.append(self._make_conv(ff, ff, self.kernel_size))
100 | else:
101 | tmp.append(self._make_conv(f_difc[i][0], f_difc[i][0], self.kernel_size))
102 | tmp.append(self._make_conv(f_difc[i][0], f_difc[i][1], self.kernel_size))
103 | tmp.append(self._make_conv(f_difc[i][1], f_difc[i][0], self.kernel_size))
104 | hg.append(nn.ModuleList(tmp))
105 | return nn.ModuleList(hg)
106 |
107 | def _hg_forward(self, n, x):
108 | up1 = self.hg[n-1][0](x)
109 | low = F.max_pool2d(x, 2, stride=2)
110 | low = self.hg[n-1][1](low)
111 |
112 | if n > 1:
113 | low = self._hg_forward(n-1, low)
114 | else:
115 | low = self.hg[n-1][3](low)
116 |
117 | low = self.hg[n-1][2](low)
118 | up2 = self.upsample(low)
119 |
120 | return up1 + up2
121 |
122 | def forward(self, x):
123 | return self._hg_forward(self.n, x)
124 |
125 | class StageBase(nn.Module):
126 | def __init__(self, f):
127 | super(StageBase, self).__init__()
128 | self.maxpool = nn.MaxPool2d(2, stride=2)
129 |
130 | self.conv1 = make_conv(3, 64, 7, stride=2, mode='ConvAct')
131 | self.conv2 = make_conv(64, 128, 3, mode="ConvAct")
132 |
133 | self.conv3 = make_conv(128, 128, 3, mode='ConvAct')
134 | self.conv4 = make_conv(128, 128, 3, mode='ConvAct')
135 |
136 | self.conv5 = make_conv(128, 256, 3, mode='ConvAct')
137 | self.conv6 = make_conv(256, f, 3, mode='Conv')
138 |
139 | def forward(self, x):
140 |
141 | x = self.conv1(x)
142 | x = self.conv2(x)
143 | x = self.maxpool(x)
144 |
145 | x = self.conv3(x)
146 | x = self.conv4(x)
147 | x = self.maxpool(x)
148 |
149 | x = self.conv5(x)
150 | x = self.conv6(x)
151 |
152 | return x
153 |
154 | class HourglassNet(nn.Module):
155 |
156 | def __init__(self, f=256, num_stacks=2 , num_classes=14, embedding_len=14):
157 | super(HourglassNet, self).__init__()
158 |
159 | self.num_stacks = num_stacks
160 |
161 | self.stage_base = self._make(StageBase(f))
162 |
163 | hg, fc1, fc2, score, embedding, fc_, score_ = [], [], [], [], [], [], []
164 |
165 | for i in range(num_stacks):
166 | hg.append(Hourglass(f))
167 | fc1.append(make_conv(f, f, kernel_size=3, mode="ActConv"))
168 | fc2.append(make_conv(f, f, kernel_size=1, mode="ActConv"))
169 | score.append(make_conv(f, num_classes, kernel_size=1, mode="Conv"))
170 | embedding.append(make_conv(f, embedding_len, kernel_size=1, mode="Conv"))
171 | if i < num_stacks-1:
172 | fc_.append(make_conv(f, f, kernel_size=1, mode="Conv"))
173 | score_.append(make_conv(num_classes+embedding_len, f, kernel_size=1, mode="Conv"))
174 |
175 | self.hg = nn.ModuleList(hg)
176 | self.fc1 = nn.ModuleList(fc1)
177 | self.fc2 = nn.ModuleList(fc2)
178 | self.score = nn.ModuleList(score)
179 | self.embedding = nn.ModuleList(embedding)
180 | self.fc_ = nn.ModuleList(fc_)
181 | self.score_ = nn.ModuleList(score_)
182 |
183 | def _make(self, target):
184 | return nn.Sequential(
185 | target
186 | )
187 |
188 | def forward(self, x):
189 | out = []
190 |
191 | x = self.stage_base(x)
192 |
193 | for i in range(self.num_stacks):
194 | y = self.hg[i](x)
195 | y = self.fc1[i](y)
196 | y = self.fc2[i](y)
197 | score = self.score[i](y)
198 | embedding = self.embedding[i](y)
199 | out.append((score, embedding))
200 | if i < self.num_stacks-1:
201 | fc_ = self.fc_[i](y)
202 | score_ = self.score_[i](torch.cat((score, embedding), dim=1))
203 | x = x + fc_ + score_
204 |
205 | return out
206 |
207 |
208 | def hg(**kwargs):
209 | model = HourglassNet(f=kwargs['f'], num_stacks=kwargs['num_stacks'], embedding_len=kwargs['embedding_len'],
210 | num_classes=kwargs['num_classes'])
211 | return model
212 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
--------------------------------------------------------------------------------
/src/utils/aux.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from nms import nms
4 |
5 |
6 | def nms_fm(detect_fm, group_fm, point_num=14, threshold=0.2, extra_space=9):
7 | """
8 | :param detect_fm: fm of detection, shape=[C1, H, W], numpy array
9 | :param group_fm: fm of grouping, shape=[C2, H, W], numpy array
10 | :param point_num: the number of considered keypoint
11 | :param threshold: threshold to select activation, float
12 | :param extra_space: extra_space of the nms, int
13 | :return: a list with one entry per keypoint class; each entry is a list of tuples ((y, x), score, tag), where (y, x) is the peak coordinate, score its confidence, and tag the embedding vector used for grouping
14 | """
15 | out_list = list()
16 | for idx in range(point_num):
17 | out_list.append([])
18 |
19 | cc, yy, xx = nms(detect_fm, threshold, extra_space)
20 | cc = cc.tolist()
21 | yy = yy.tolist()
22 | xx = xx.tolist()
23 | assert len(yy) == len(xx)
24 | assert len(cc) == len(xx)
25 | tmp_len = len(xx)
26 | for i in xrange(tmp_len):
27 | c = cc[i]
28 | y = yy[i]
29 | x = xx[i]
30 | score = detect_fm[c, y, x]
31 | tag = group_fm[c, y, x, :]
32 | out_list[c].append(((y, x), score, tag))
33 |
34 | return out_list
35 |
36 |
37 | def group_with_keypoint(in_list, ori_h, ori_w, cur_h, cur_w, threshold=1, min_part_num=3, point_num=14):
38 | def any_true(check_list):
39 | for item in check_list:
40 | for subitem in item:
41 | if subitem:
42 | return True
43 | return False
44 |
45 | def resize_coord(coord):
46 | y, x = coord
47 | y = int(float(y) / cur_h * ori_h + ori_h / cur_h / 2)
48 | x = int(float(x) / cur_w * ori_w + ori_w / cur_w / 2)
49 | return (y, x)
50 |
51 | def tag_dis(lhs, rhs):
52 | return np.linalg.norm(lhs - rhs)
53 |
54 | check_seq = [12, 13, 0, 3, 6, 9, 1, 2, 4, 5, 7, 8, 10, 11]
55 | check_list = list()
56 | out_dict = dict()
57 | human_count = 0
58 | for idx in range(point_num):
59 | item_len = len(in_list[idx])
60 | check_list.append([True] * item_len)
61 |
62 | while (any_true(check_list)):
63 | human_count += 1
64 | human_name = "human%d" % (human_count)
65 | tmp_coords = np.zeros(point_num * 3, dtype=np.int32).reshape(point_num, 3)
66 | part_count = 0
67 |
68 | finish = False
69 | for i in check_seq:
70 | if finish:
71 | break
72 | for j in range(len(in_list[i])):
73 | if check_list[i][j]:
74 | cur_coord, score, tag = in_list[i][j]
75 | y, x = resize_coord(cur_coord)
76 | tmp_coords[i][0] = x
77 | tmp_coords[i][1] = y
78 | tmp_coords[i][2] = 1
79 | check_list[i][j] = False
80 | others = [k for k in range(point_num) if k != i]
81 | for ii in others:
82 | max_score = 0.
83 | for jj in range(len(in_list[ii])):
84 | if check_list[ii][jj]:
85 | cur_coord, sub_score, sub_tag = in_list[ii][jj]
86 | yy, xx = resize_coord(cur_coord)
87 | if tag_dis(tag, sub_tag) < threshold and check_list[ii][jj] and sub_score > max_score:
88 | max_score = sub_score
89 | tmp_coords[ii][0] = xx
90 | tmp_coords[ii][1] = yy
91 | tmp_coords[ii][2] = 1
92 | check_list[ii][jj] = False
93 | part_count += 1
94 | finish = True
95 | break
96 | if part_count >= min_part_num:
97 | out_dict[human_name] = tmp_coords.reshape(-1).tolist()
98 | if len(out_dict) == 0:
99 | out_dict['human1'] = [0] * point_num * 3
100 | return out_dict
101 |
102 |
103 | def flip_fm(fm):
104 | left_right_pair = (
105 | (0, 3),
106 | (1, 4),
107 | (2, 5),
108 | (6, 9),
109 | (7, 10),
110 | (8, 11),
111 | )
112 | out_fm = np.empty_like(fm)
113 | out_fm[12, :, :] = fm[12, :, :]
114 | out_fm[13, :, :] = fm[13, :, :]
115 |
116 | for idx1, idx2 in left_right_pair:
117 | out_fm[idx1, :, :] = fm[idx2, :, :]
118 | out_fm[idx2, :, :] = fm[idx1, :, :]
119 |
120 | out_fm = out_fm[:, :, ::-1]
121 | return out_fm
122 |
123 |
124 | def integrate_fm_group(detect_fm_list, group_fm_list, height, width):
125 | """
126 | :param detect_fm_list: list of detect_fm
127 | :param group_fm_list: list of group_fm
128 | :return: detection maps averaged over the list and grouping maps stacked along a new last axis, both resized to (height, width)
129 | """
130 | resized_detect_list = list()
131 | resized_group_list = list()
132 |
133 | for fm in detect_fm_list:
134 | resized_detect_list.append(resize_fm(fm, height, width))
135 |
136 | for fm in group_fm_list:
137 | tmp = list()
138 | resized_group_list.append(resize_fm(fm, height, width))
139 | out_detect_fm = sum(resized_detect_list) / len(resized_detect_list)
140 | out_group_fm = np.stack(resized_group_list, axis=-1)
141 |
142 | return out_detect_fm, out_group_fm
143 |
144 |
145 | def resize_fm(fm, dst_h, dst_w):
146 | """
147 | :param fm: [C, H, W]
148 | :param dst_h:
149 | :param dst_w:
150 | :return:
151 | """
152 | out_list = list()
153 | for item in fm:
154 | out_list.append(cv2.resize(item, (dst_w, dst_h)))
155 | output = np.stack(out_list, axis=0)
156 | return output
157 |
158 |
--------------------------------------------------------------------------------
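A small check of the flip/resize helpers above (assuming Python 2.7 and that the bundled `src/utils/nms.so` imports on your platform, since `aux.py` imports it at module level; run from the repository root):

```python
import numpy as np
from src.utils.aux import flip_fm, resize_fm

fm = np.random.rand(14, 64, 64).astype(np.float32)

# flip_fm swaps each left/right channel pair (indices follow the AI Challenger keypoint
# order, e.g. shoulders 0 <-> 3) and mirrors every map horizontally, so predictions made
# on a flipped image can be merged with the originals, as tools/test.py does.
flipped = flip_fm(fm)
assert np.allclose(flipped[0], fm[3][:, ::-1])    # right shoulder taken from the mirrored left shoulder
assert np.allclose(flipped[12], fm[12][:, ::-1])  # head top is only mirrored, never swapped

# resize_fm bilinearly resizes every channel with cv2 and restacks them.
big = resize_fm(fm, 128, 128)
print(big.shape)  # (14, 128, 128)
```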
/src/utils/nms.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JakeRenn/AI-Challenger-Keypoints-pytorch/8bfcba9b49be53d67d711ff860018c998714538f/src/utils/nms.so
--------------------------------------------------------------------------------
/src/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import cv2
3 | import numpy as np
4 | from PIL import ImageDraw
5 |
6 | def draw_point_on_img_dict(img, in_dict, label=None, point_num = 14):
7 | """
8 | :param img: original image as an [H, W, 3] RGB numpy array; keypoints are drawn on it in place
9 | """
10 | colors = (
11 | (0, 255, 255),
12 | (255, 0, 255),
13 | (255, 255, 0),
14 | (255, 0, 0),
15 | (0, 255, 0),
16 | (0, 0, 255),
17 | (0, 127, 255),
18 | (255, 0, 127),
19 | (127, 255, 0),
20 | (0, 255, 127),
21 | (127, 0, 255),
22 | (255, 127, 0)
23 | )
24 | line_pair = (
25 | (1, 2),
26 | (2, 3),
27 | (4, 5),
28 | (5, 6),
29 | (7, 8),
30 | (8, 9),
31 | (10, 11),
32 | (11, 12),
33 | (13, 14),
34 | )
35 |
36 | if label is not None:
37 | length = len(label)
38 | for idx in xrange(length):
39 | coords = label[idx].reshape(point_num, 3).astype(np.int32)
40 | xx = coords[:, 0]
41 | yy = coords[:, 1]
42 | vv = coords[:, 2]
43 | for i in xrange(point_num):
44 | x = xx[i]
45 | y = yy[i]
46 | v = vv[i]
47 | if v == 1:
48 | cv2.circle(img, (x, y), 2, (255, 255, 255), 2)
49 |
50 | idx = 0
51 | for key, val in in_dict.items():
52 | coords = np.array(val).reshape(point_num, 3).astype(np.int32)
53 | xx = coords[:, 0]
54 | yy = coords[:, 1]
55 | vv = coords[:, 2]
56 | for i in xrange(point_num):
57 | x = xx[i]
58 | y = yy[i]
59 | v = vv[i]
60 | if v == 1:
61 | cv2.circle(img, (x, y), 2, colors[idx], 2)
62 | for i1, i2 in line_pair:
63 | i1 -= 1
64 | i2 -= 1
65 |
66 | x1 = xx[i1]
67 | y1 = yy[i1]
68 | v1 = vv[i1]
69 |
70 | x2 = xx[i2]
71 | y2 = yy[i2]
72 | v2 = vv[i2]
73 |
74 | if v1 == 1 and v2 == 1:
75 | cv2.line(img, (x1, y1), (x2, y2), colors[idx])
76 |         # cycle through the palette so more than len(colors) people do not index out of range
77 |         idx = (idx + 1) % len(colors)
78 |
79 |
80 |
81 | plt.imshow(img)
82 | plt.show()
83 |
84 |
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 | import argparse
5 | import json
6 | import math
7 |
8 | import torch
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.utils.data
13 | import torchvision.transforms as transforms
14 |
15 | import src.models.hourglass_ae
16 | import src.dataset.reader
17 | import src.utils.aux
18 |
19 | parser = argparse.ArgumentParser(description='AI Challenger Keypoints Detection')
20 | parser.add_argument('test_data_dir', metavar='DIR',
21 | help='path to test_data_dir')
22 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
23 | help='number of data loading workers (default: 4)')
24 | parser.add_argument('-b', '--batch-size', default=1, type=int,
25 | metavar='N', help='mini-batch size (default: 1)')
26 | parser.add_argument('--checkpoint_path', default='./results/hg_ae.pth.tar', type=str, metavar='PATH',
27 | help='path to checkpoint')
28 | parser.add_argument('--test_results', default='./results/multi_test_results.json', type=str, metavar='PATH',
29 | help='path to test results.')
30 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
31 | help='path to checkpoint')
32 |
33 | best_prec1 = 0
34 |
35 | def main():
36 | global args, best_prec1
37 | args = parser.parse_args()
38 |
39 | # create model
40 | model = src.models.hourglass_ae.hg(f=256, num_stacks=3, num_classes=14, embedding_len=14)
41 |
42 | model = torch.nn.DataParallel(model).cuda()
43 |
44 | checkpoint = torch.load(args.resume)
45 | model.load_state_dict(checkpoint['state_dict'])
46 |
47 | all_model = (model, )
48 |
49 | cudnn.benchmark = True
50 |
51 | normalize = transforms.Normalize(mean=[0.4798, 0.4517, 0.4220],
52 | std=[0.2558, 0.2481, 0.2468])
53 | test_dataset = src.dataset.reader.TestReader(
54 | data_dir=args.test_data_dir,
55 | transform=transforms.Compose([
56 | transforms.ToTensor(),
57 | normalize
58 | ])
59 | )
60 | test_loader = torch.utils.data.DataLoader(
61 | test_dataset, batch_size=1, shuffle=False,
62 | num_workers=args.workers, pin_memory=True)
63 |
64 | test(test_loader, all_model)
65 |
66 | def test(test_loader, models):
67 | # switch to evaluate mode
68 | for model in models:
69 | model.eval()
70 |
71 | out_data = list()
72 | print "Testing..."
73 | for i, (img, cur_size, ori_size, img_id, img_path) in enumerate(test_loader):
74 | img_id = img_id[0]
75 |
76 | ori_h, ori_w = ori_size
77 | ori_h = ori_h.numpy()[0]
78 | ori_w = ori_w.numpy()[0]
79 |
80 | cur_h, cur_w = cur_size
81 | cur_h = cur_h.numpy()[0]
82 | cur_w = cur_w.numpy()[0]
83 |
84 | detect_list = list()
85 | group_list = list()
86 | for model in models:
87 | n = len(img)
88 | for img_idx, input_img in enumerate(img):
89 | input_img = input_img.cuda(async=True)
90 | img_var = torch.autograd.Variable(input_img, volatile=True)
91 |
92 | output = model(img_var)
93 | output = output[-1]
94 | out_detect = output[0]
95 | out_group = output[1]
96 |
97 | out_detect = out_detect.data.cpu().numpy()[0]
98 | out_group = out_group.data.cpu().numpy()[0]
99 | if img_idx >= n/2:
100 | # The second half of the tuple holds the horizontally flipped images; flip their outputs back
101 | out_detect = src.utils.aux.flip_fm(out_detect)
102 | out_group = src.utils.aux.flip_fm(out_group)
103 | detect_list.append(out_detect)
104 | group_list.append(out_group)
105 |
106 | out_detect, out_group = src.utils.aux.integrate_fm_group(detect_list, group_list, cur_h, cur_w)
107 |
108 | out_list = src.utils.aux.nms_fm(out_detect, out_group, threshold=0.3, extra_space=5)
109 | keypoint_annos = src.utils.aux.group_with_keypoint(out_list, ori_h, ori_w,
110 | cur_h, cur_w, threshold=0.5 * math.sqrt(n * len(models)))
111 | if (i + 1) % 1000 == 0:
112 | print ("finished %d images" % (i + 1))
113 |
114 | item = dict()
115 | item['image_id'] = img_id
116 | item['keypoint_annotations'] = keypoint_annos
117 | out_data.append(item)
118 |
119 | print ("Saving results at %s" % args.test_results)
120 | with open(args.test_results, 'wb') as fw:
121 | json.dump(out_data, fw)
122 |
123 | return
124 |
125 | if __name__ == '__main__':
126 | main()
127 |
--------------------------------------------------------------------------------
/tools/visualize_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 | import argparse
5 | import os
6 | import json
7 | import random
8 | import cv2
9 |
10 | import src.utils.visualize
11 |
12 | parser = argparse.ArgumentParser(description='AI Challenger')
13 | parser.add_argument('test_results', metavar='FILE',
14 | help='path to test_results')
15 | parser.add_argument('test_data_dir', metavar='DIR',
16 | help='path to test_data_dir')
17 | parser.add_argument('output_dir', metavar='DIR',
18 | help='path to output_dir')
19 |
20 | def main():
21 | global args
22 | args = parser.parse_args()
23 |
24 | with open(args.test_results, 'rb') as fr:
25 | data = json.load(fr)
26 | random.shuffle(data)
27 |
28 | for item in data:
29 |
30 | img_path = os.path.join(args.test_data_dir, item['image_id'] + '.jpg')
31 | img_np = cv2.imread(img_path)
32 | img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
33 |
34 | keypoint_dict = item['keypoint_annotations']
35 |
36 | src.utils.visualize.draw_point_on_img_dict(img_np, keypoint_dict)
37 |
38 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
39 | cv2.imwrite(os.path.join(args.output_dir, item['image_id']+".jpg"), img_np)
40 |
41 |
42 |
43 | if __name__ == '__main__':
44 | main()
45 |
--------------------------------------------------------------------------------