├── .gitignore
├── FAQ.md
├── LICENSE
├── README.md
├── data
│   └── MPII
│       ├── annot
│       │   ├── test.h5
│       │   ├── train.h5
│       │   └── valid.h5
│       ├── dp.py
│       └── ref.py
├── models
│   ├── layers.py
│   └── posenet.py
├── task
│   ├── loss.py
│   └── pose.py
├── test.py
├── train.py
└── utils
    ├── group.py
    ├── img.py
    └── misc.py
/.gitignore:
--------------------------------------------------------------------------------
1 | exp
2 | *.pyc
3 | *__pycache__*
4 | *core.*
5 | *.jpg
6 | *.png
7 | *.pkl
8 | *.txt
9 | *.json
10 | _ext
11 | tmp
12 | *.o*
13 | *~
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | Q: What exactly are the differences in parameters between this code and the Stacked Hourglass paper?
2 | A: The paper uses a batch size of 8 instead of 16, and RMSprop with a learning rate of 2.5e-4 instead of Adam with a learning rate of 1e-3. The paper also decayed the learning rate "after validation accuracy plateaued" rather than explicitly at 100k iterations, but this is more or less the same idea. You can change Adam to RMSprop in the "make_network" function within task/pose.py; the learning rate and batch size are set in __config__ in the same file (see the sketch below).
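
For example, a minimal sketch of reverting to the paper's settings, using the names defined in task/pose.py (an illustration, not committed code):

```python
# in task/pose.py
__config__['train']['batchsize'] = 8           # paper value; this repo defaults to 16
__config__['train']['learning_rate'] = 2.5e-4  # paper value; this repo defaults to 1e-3

# inside make_network, swap Adam for RMSprop:
train_cfg['optimizer'] = torch.optim.RMSprop(
    filter(lambda p: p.requires_grad, config['net'].parameters()),
    lr=train_cfg['learning_rate'])
```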
3 |
4 | Q: Were scores on 2HG model achieved using same parameters as 8HG?
5 | A: Yes; just change nstack to 2 in task/pose.py (see below).
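
That is, in the ```__config__``` dict (sketch):

```python
__config__['inference']['nstack'] = 2  # default is 8
```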
6 |
7 | Q: How do I interpret the output of the log file?
8 | A: Each iteration during training or validation writes a line to this file with the corresponding loss. Note that we do not calculate train or validation accuracy during training, as that requires extra processing and is expensive. Validation loss can be used as a proxy for deciding when to stop training. An example line is shown below.
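
As a rough illustration, the lines written by make_train in task/pose.py look like the following (the loss value here is made up):

```
1500:  combined_hm_loss: 0.00123456
```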
9 |
10 | Q: Only one model is saved during training?
11 | A: Yes - the most recent model is saved at the end of each epoch. You may want to modify this if you wish to also keep the "best" checkpoint; see the sketch below.
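
For instance, save_checkpoint in train.py already accepts an is_best flag that copies the checkpoint to model_best.pt; a sketch of wiring it up (best_valid_loss is a hypothetical running minimum you would have to track yourself):

```python
# hypothetical modification to save() in train.py
is_best = valid_loss < best_valid_loss  # both variables tracked by you
save_checkpoint({
    'state_dict': config['inference']['net'].state_dict(),
    'optimizer' : config['train']['optimizer'].state_dict(),
    'epoch': config['train']['epoch'],
}, is_best, filename=resume_file)
```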
12 |
13 | Q: How can I display predictions?
14 | A: There isn't explicit visualization code here, but you can color the image pixels corresponding to predicted keypoints. To do this, for example, you could modify the mpii_eval function in test.py to take in images as well as keypoints and write the annotated images to TensorBoard. A sketch of the drawing step is below.
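
A minimal sketch of the drawing step (not part of the repository; assumes pred[i]['keypoints'] of shape 16x3 as produced in test.py):

```python
import cv2
import numpy as np

def draw_keypoints(img, keypoints, radius=2):
    """Return a copy of img with one person's 16 MPII keypoints drawn on it.
    keypoints: array of shape (16, 3) holding (x, y, score) rows."""
    vis = np.ascontiguousarray(img).copy()
    for x, y, score in keypoints:
        if score > 0:  # skip joints with no usable prediction
            cv2.circle(vis, (int(x), int(y)), radius, (255, 0, 0), -1)
    return vis
```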
15 |
16 | Q: The evaluation code evaluates train and validation set? What about the test set?
17 | A: Yes, the evaluation code (test.py) is set up to calculate accuracy on the validation set. Train accuracy (like validation accuracy) is not calculated during training, so it is also calculated here. The default settings in task/pose.py calculate train accuracy on a sample of the train set (300 images) to reduce compute time. To get test accuracy, you must run test.py with images and crops from test.h5, then use the evaluation toolkit provided by MPII, and submit as detailed on the [MPII website](http://human-pose.mpi-inf.mpg.de/#evaluation).
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, princeton-vl
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stacked Hourglass Networks in Pytorch
2 |
3 | Based on **Stacked Hourglass Networks for Human Pose Estimation.** [Alejandro Newell](https://www.alejandronewell.com/), [Kaiyu Yang](https://www.cs.princeton.edu/~kaiyuy/), and [Jia Deng](https://www.cs.princeton.edu/~jiadeng/). *European Conference on Computer Vision (ECCV)*, 2016. [Github](https://github.com/princeton-vl/pose-hg-train)
4 |
5 | PyTorch code by [Chris Rockwell](https://crockwell.github.io/); adapted from: **Associative Embedding: End-to-end Learning for Joint Detection and Grouping.**
6 | [Alejandro Newell](http://www-personal.umich.edu/~alnewell/), [Zhiao Huang](https://sites.google.com/view/zhiao-huang), and [Jia Deng](https://www.cs.princeton.edu/~jiadeng/). *Neural Information Processing Systems (NeurIPS)*, 2017. [Github](https://github.com/princeton-vl/pose-ae-train)
7 |
8 | ## Getting Started
9 |
10 | This repository provides everything necessary to train and evaluate a single-person pose estimation model on MPII. If you plan on training your own model from scratch, we highly recommend using multiple GPUs.
11 |
12 | Requirements:
13 |
14 | - Python 3 (code has been tested on Python 3.8.2)
15 | - PyTorch (code tested with 1.5)
16 | - CUDA and cuDNN (tested with CUDA 10)
17 | - Python packages (not exhaustive): opencv-python (tested with 4.2), tqdm, cffi, h5py, scipy (tested with 1.4.1), pytz, imageio
18 |
19 | Structure:
20 | - ```data/```: data loading and data augmentation code
21 | - ```models/```: network architecture definitions
22 | - ```task/```: task-specific functions and training configuration
23 | - ```utils/```: image processing code and miscellaneous helper functions
24 | - ```train.py```: code for model training
25 | - ```test.py```: code for model evaluation
26 |
27 | #### Dataset
28 | Download the full [MPII Human Pose dataset](http://human-pose.mpi-inf.mpg.de/), and place the images directory in data/MPII/, giving the layout shown below.
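
The paths expected by ```data/MPII/ref.py``` are:

```
data/MPII/annot/    # train.h5 / valid.h5 / test.h5 (provided with this repo)
data/MPII/images/   # the downloaded MPII images
```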
29 |
30 | #### Training and Testing
31 |
32 | To train a network, call:
33 |
34 | ```python train.py -e test_run_001``` (```-e,--exp``` allows you to specify an experiment name)
35 |
36 | To continue an experiment where it left off, you can call:
37 |
38 | ```python train.py -c test_run_001```
39 |
40 | All training hyperparameters are defined in ```task/pose.py```, and you can modify ```__config__``` to test different options. It is likely you will have to change the batch size to accommodate the number of GPUs you have available. A few commonly changed entries are shown below.
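
For reference, these are some of the ```__config__``` entries most likely to need changing (values shown are the repository defaults):

```python
__config__['train']['batchsize'] = 16        # reduce for fewer or smaller GPUs
__config__['train']['learning_rate'] = 1e-3  # used with Adam by default
__config__['inference']['nstack'] = 8        # number of stacked hourglasses
```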
41 |
42 | Once a model has been trained, you can evaluate it with:
43 |
44 | ```python test.py -c test_run_001```
45 |
46 | The ```-m n``` option to ```train.py``` automatically stops training after n thousand total iterations (when continuing an experiment, the count includes iterations from previous runs).
47 |
48 | #### Pretrained Models
49 |
50 | An 8HG pretrained model is available [here](http://www-personal.umich.edu/~cnris/original_8hg/checkpoint.pt). It should yield validation accuracy of 0.902.
51 |
52 | A 2HG pretrained model is available [here](http://www-personal.umich.edu/~cnris/original_2hg/checkpoint.pt). It should yield validation accuracy of 0.883.
53 |
54 | Models should be placed at ```exp/<exp_name>/checkpoint.pt``` (e.g. ```exp/test_run_001/checkpoint.pt```).
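
The checkpoint is a plain dict saved by ```train.py```; a minimal sketch of inspecting one manually (the experiment name is illustrative):

```python
import torch

checkpoint = torch.load('exp/test_run_001/checkpoint.pt')
print('saved at epoch', checkpoint['epoch'])
# checkpoint['state_dict'] and checkpoint['optimizer'] restore the network
# and optimizer, as done in reload() in train.py
```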
55 |
56 | Note: the models above were trained with a batch size of 16 and the Adam optimizer with a learning rate of 1e-3 (instead of RMSprop at 2.5e-4), as these settings outperformed the originals on validation. The code can easily be modified to use the original paper settings. The original paper reported validation accuracy of 0.881, which this code approximately replicates. The models above were also trained for approximately 200k iterations, while the original paper trained for fewer.
57 |
58 | #### Training/Validation split
59 |
60 | The train/val split is the same as that found in the authors' [implementation](https://github.com/princeton-vl/pose-hg-train).
61 |
62 | #### Note
63 |
64 | During training, a "ConnectionResetError" warning was occasionally displayed between epochs, but it did not affect training.
65 |
--------------------------------------------------------------------------------
/data/MPII/annot/test.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/pytorch_stacked_hourglass/ceedc14b9b8814ba641f4a00baa5d0af153588a9/data/MPII/annot/test.h5
--------------------------------------------------------------------------------
/data/MPII/annot/train.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/pytorch_stacked_hourglass/ceedc14b9b8814ba641f4a00baa5d0af153588a9/data/MPII/annot/train.h5
--------------------------------------------------------------------------------
/data/MPII/annot/valid.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-vl/pytorch_stacked_hourglass/ceedc14b9b8814ba641f4a00baa5d0af153588a9/data/MPII/annot/valid.h5
--------------------------------------------------------------------------------
/data/MPII/dp.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import sys
3 | import os
4 | import torch
5 | import numpy as np
6 | import torch.utils.data
7 | import utils.img
8 |
9 | class GenerateHeatmap():
10 | def __init__(self, output_res, num_parts):
11 | self.output_res = output_res
12 | self.num_parts = num_parts
13 | sigma = self.output_res/64
14 | self.sigma = sigma
15 | size = 6*sigma + 3
16 | x = np.arange(0, size, 1, float)
17 | y = x[:, np.newaxis]
18 | x0, y0 = 3*sigma + 1, 3*sigma + 1
19 | self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
20 |
21 | def __call__(self, keypoints):
22 | hms = np.zeros(shape = (self.num_parts, self.output_res, self.output_res), dtype = np.float32)
23 | sigma = self.sigma
24 | for p in keypoints:
25 | for idx, pt in enumerate(p):
26 | if pt[0] > 0:
27 | x, y = int(pt[0]), int(pt[1])
28 | if x<0 or y<0 or x>=self.output_res or y>=self.output_res:
29 | continue
30 | ul = int(x - 3*sigma - 1), int(y - 3*sigma - 1)
31 | br = int(x + 3*sigma + 2), int(y + 3*sigma + 2)
32 |
33 | c,d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0]
34 | a,b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1]
35 |
36 | cc,dd = max(0, ul[0]), min(br[0], self.output_res)
37 | aa,bb = max(0, ul[1]), min(br[1], self.output_res)
38 | hms[idx, aa:bb,cc:dd] = np.maximum(hms[idx, aa:bb,cc:dd], self.g[a:b,c:d])
39 | return hms
40 |
41 | class Dataset(torch.utils.data.Dataset):
42 | def __init__(self, config, ds, index):
43 | self.input_res = config['train']['input_res']
44 | self.output_res = config['train']['output_res']
45 | self.generateHeatmap = GenerateHeatmap(self.output_res, config['inference']['num_parts'])
46 | self.ds = ds
47 | self.index = index
48 |
49 | def __len__(self):
50 | return len(self.index)
51 |
52 | def __getitem__(self, idx):
53 | return self.loadImage(self.index[idx % len(self.index)])
54 |
55 | def loadImage(self, idx):
56 | ds = self.ds
57 |
58 | ## load + crop
59 | orig_img = ds.get_img(idx)
60 | path = ds.get_path(idx)
61 | orig_keypoints = ds.get_kps(idx)
62 | kptmp = orig_keypoints.copy()
63 | c = ds.get_center(idx)
64 | s = ds.get_scale(idx)
65 | normalize = ds.get_normalized(idx)
66 |
67 | cropped = utils.img.crop(orig_img, c, s, (self.input_res, self.input_res))
68 | for i in range(np.shape(orig_keypoints)[1]):
69 | if orig_keypoints[0,i,0] > 0:
70 | orig_keypoints[0,i,:2] = utils.img.transform(orig_keypoints[0,i,:2], c, s, (self.input_res, self.input_res))
71 | keypoints = np.copy(orig_keypoints)
72 |
73 | ## augmentation -- to be done to cropped image
74 | height, width = cropped.shape[0:2]
75 | center = np.array((width/2, height/2))
76 | scale = max(height, width)/200
77 |
78 | aug_rot=0
79 |
80 | aug_rot = (np.random.random() * 2 - 1) * 30.
81 | aug_scale = np.random.random() * (1.25 - 0.75) + 0.75
82 | scale *= aug_scale
83 |
84 | mat_mask = utils.img.get_transform(center, scale, (self.output_res, self.output_res), aug_rot)[:2]
85 |
86 | mat = utils.img.get_transform(center, scale, (self.input_res, self.input_res), aug_rot)[:2]
87 | inp = cv2.warpAffine(cropped, mat, (self.input_res, self.input_res)).astype(np.float32)/255
88 | keypoints[:,:,0:2] = utils.img.kpt_affine(keypoints[:,:,0:2], mat_mask)
89 | if np.random.randint(2) == 0:
90 | inp = self.preprocess(inp)
91 | inp = inp[:, ::-1]
92 | keypoints = keypoints[:, ds.flipped_parts['mpii']]
93 | keypoints[:, :, 0] = self.output_res - keypoints[:, :, 0]
94 | orig_keypoints = orig_keypoints[:, ds.flipped_parts['mpii']]
95 | orig_keypoints[:, :, 0] = self.input_res - orig_keypoints[:, :, 0]
96 |
97 |         ## set keypoints to 0 when they were not visible initially (so heatmaps are all 0s)
98 | for i in range(np.shape(orig_keypoints)[1]):
99 | if kptmp[0,i,0] == 0 and kptmp[0,i,1] == 0:
100 | keypoints[0,i,0] = 0
101 | keypoints[0,i,1] = 0
102 | orig_keypoints[0,i,0] = 0
103 | orig_keypoints[0,i,1] = 0
104 |
105 | ## generate heatmaps on outres
106 | heatmaps = self.generateHeatmap(keypoints)
107 |
108 | return inp.astype(np.float32), heatmaps.astype(np.float32)
109 |
110 | def preprocess(self, data):
111 | # random hue and saturation
112 |         data = cv2.cvtColor(data, cv2.COLOR_RGB2HSV)
113 | delta = (np.random.random() * 2 - 1) * 0.2
114 | data[:, :, 0] = np.mod(data[:,:,0] + (delta * 360 + 360.), 360.)
115 |
116 | delta_sature = np.random.random() + 0.5
117 | data[:, :, 1] *= delta_sature
118 | data[:,:, 1] = np.maximum( np.minimum(data[:,:,1], 1), 0 )
119 | data = cv2.cvtColor(data, cv2.COLOR_HSV2RGB)
120 |
121 | # adjust brightness
122 | delta = (np.random.random() * 2 - 1) * 0.3
123 | data += delta
124 |
125 | # adjust contrast
126 | mean = data.mean(axis=2, keepdims=True)
127 | data = (data - mean) * (np.random.random() + 0.5) + mean
128 | data = np.minimum(np.maximum(data, 0), 1)
129 | return data
130 |
131 |
132 | def init(config):
133 | batchsize = config['train']['batchsize']
134 | current_path = os.path.dirname(os.path.abspath(__file__))
135 | sys.path.append(current_path)
136 | import ref as ds
137 | ds.init()
138 |
139 | train, valid = ds.setup_val_split()
140 | dataset = { key: Dataset(config, ds, data) for key, data in zip( ['train', 'valid'], [train, valid] ) }
141 |
142 | use_data_loader = config['train']['use_data_loader']
143 |
144 | loaders = {}
145 | for key in dataset:
146 | loaders[key] = torch.utils.data.DataLoader(dataset[key], batch_size=batchsize, shuffle=True, num_workers=config['train']['num_workers'], pin_memory=False)
147 |
148 | def gen(phase):
149 | batchsize = config['train']['batchsize']
150 | batchnum = config['train']['{}_iters'.format(phase)]
151 | loader = loaders[phase].__iter__()
152 | for i in range(batchnum):
153 | try:
154 | imgs, heatmaps = next(loader)
155 | except StopIteration:
156 | # to avoid no data provided by dataloader
157 | loader = loaders[phase].__iter__()
158 | imgs, heatmaps = next(loader)
159 | yield {
160 | 'imgs': imgs, #cropped and augmented
161 | 'heatmaps': heatmaps, #based on keypoints. 0 if not in img for joint
162 | }
163 |
164 |
165 | return lambda key: gen(key)
166 |
--------------------------------------------------------------------------------
/data/MPII/ref.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | from imageio import imread
4 | import os
5 | import time
6 |
7 | def _isArrayLike(obj):
8 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
9 |
10 | annot_dir = 'data/MPII/annot'
11 | img_dir = 'data/MPII/images'
12 |
13 | assert os.path.exists(img_dir)
14 | mpii, num_examples_train, num_examples_val = None, None, None
15 |
16 | import cv2
17 |
18 | class MPII:
19 | def __init__(self):
20 | print('loading data...')
21 | tic = time.time()
22 |
23 | train_f = h5py.File(os.path.join(annot_dir, 'train.h5'), 'r')
24 | val_f = h5py.File(os.path.join(annot_dir, 'valid.h5'), 'r')
25 |
26 | self.t_center = train_f['center'][()]
27 | t_scale = train_f['scale'][()]
28 | t_part = train_f['part'][()]
29 | t_visible = train_f['visible'][()]
30 | t_normalize = train_f['normalize'][()]
31 | t_imgname = [None] * len(self.t_center)
32 | for i in range(len(self.t_center)):
33 | t_imgname[i] = train_f['imgname'][i].decode('UTF-8')
34 |
35 | self.v_center = val_f['center'][()]
36 | v_scale = val_f['scale'][()]
37 | v_part = val_f['part'][()]
38 | v_visible = val_f['visible'][()]
39 | v_normalize = val_f['normalize'][()]
40 | v_imgname = [None] * len(self.v_center)
41 | for i in range(len(self.v_center)):
42 | v_imgname[i] = val_f['imgname'][i].decode('UTF-8')
43 |
44 | self.center = np.append(self.t_center, self.v_center, axis=0)
45 | self.scale = np.append(t_scale, v_scale)
46 | self.part = np.append(t_part, v_part, axis=0)
47 | self.visible = np.append(t_visible, v_visible, axis=0)
48 | self.normalize = np.append(t_normalize, v_normalize)
49 | self.imgname = t_imgname + v_imgname
50 |
51 | print('Done (t={:0.2f}s)'.format(time.time()- tic))
52 |
53 | def getAnnots(self, idx):
54 | '''
55 | returns h5 file for train or val set
56 | '''
57 | return self.imgname[idx], self.part[idx], self.visible[idx], self.center[idx], self.scale[idx], self.normalize[idx]
58 |
59 | def getLength(self):
60 | return len(self.t_center), len(self.v_center)
61 |
62 | def init():
63 | global mpii, num_examples_train, num_examples_val
64 | mpii = MPII()
65 | num_examples_train, num_examples_val = mpii.getLength()
66 |
67 | # Part reference
68 | parts = {'mpii':['rank', 'rkne', 'rhip',
69 | 'lhip', 'lkne', 'lank',
70 | 'pelv', 'thrx', 'neck', 'head',
71 | 'rwri', 'relb', 'rsho',
72 | 'lsho', 'lelb', 'lwri']}
73 |
74 | flipped_parts = {'mpii':[5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]}
75 |
76 | part_pairs = {'mpii':[[0, 5], [1, 4], [2, 3], [6], [7], [8], [9], [10, 15], [11, 14], [12, 13]]}
77 |
78 | pair_names = {'mpii':['ankle', 'knee', 'hip', 'pelvis', 'thorax', 'neck', 'head', 'wrist', 'elbow', 'shoulder']}
79 |
80 | def setup_val_split():
81 | '''
82 | returns index for train and validation imgs
83 | index for validation images starts after that of train images
84 | so that loadImage can tell them apart
85 | '''
86 | valid = [i+num_examples_train for i in range(num_examples_val)]
87 | train = [i for i in range(num_examples_train)]
88 | return np.array(train), np.array(valid)
89 |
90 | def get_img(idx):
91 | imgname, __, __, __, __, __ = mpii.getAnnots(idx)
92 | path = os.path.join(img_dir, imgname)
93 | img = imread(path)
94 | return img
95 |
96 | def get_path(idx):
97 | imgname, __, __, __, __, __ = mpii.getAnnots(idx)
98 | path = os.path.join(img_dir, imgname)
99 | return path
100 |
101 | def get_kps(idx):
102 | __, part, visible, __, __, __ = mpii.getAnnots(idx)
103 | kp2 = np.insert(part, 2, visible, axis=1)
104 | kps = np.zeros((1, 16, 3))
105 | kps[0] = kp2
106 | return kps
107 |
108 | def get_normalized(idx):
109 | __, __, __, __, __, n = mpii.getAnnots(idx)
110 | return n
111 |
112 | def get_center(idx):
113 | __, __, __, c, __, __ = mpii.getAnnots(idx)
114 | return c
115 |
116 | def get_scale(idx):
117 | __, __, __, __, s, __ = mpii.getAnnots(idx)
118 | return s
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | Pool = nn.MaxPool2d
4 |
5 | def batchnorm(x):
6 | return nn.BatchNorm2d(x.size()[1])(x)
7 |
8 | class Conv(nn.Module):
9 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
10 | super(Conv, self).__init__()
11 | self.inp_dim = inp_dim
12 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True)
13 | self.relu = None
14 | self.bn = None
15 | if relu:
16 | self.relu = nn.ReLU()
17 | if bn:
18 | self.bn = nn.BatchNorm2d(out_dim)
19 |
20 | def forward(self, x):
21 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
22 | x = self.conv(x)
23 | if self.bn is not None:
24 | x = self.bn(x)
25 | if self.relu is not None:
26 | x = self.relu(x)
27 | return x
28 |
29 | class Residual(nn.Module):
30 | def __init__(self, inp_dim, out_dim):
31 | super(Residual, self).__init__()
32 | self.relu = nn.ReLU()
33 | self.bn1 = nn.BatchNorm2d(inp_dim)
34 | self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False)
35 | self.bn2 = nn.BatchNorm2d(int(out_dim/2))
36 | self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False)
37 | self.bn3 = nn.BatchNorm2d(int(out_dim/2))
38 | self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False)
39 | self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
40 | if inp_dim == out_dim:
41 | self.need_skip = False
42 | else:
43 | self.need_skip = True
44 |
45 | def forward(self, x):
46 | if self.need_skip:
47 | residual = self.skip_layer(x)
48 | else:
49 | residual = x
50 | out = self.bn1(x)
51 | out = self.relu(out)
52 | out = self.conv1(out)
53 | out = self.bn2(out)
54 | out = self.relu(out)
55 | out = self.conv2(out)
56 | out = self.bn3(out)
57 | out = self.relu(out)
58 | out = self.conv3(out)
59 | out += residual
60 | return out
61 |
62 | class Hourglass(nn.Module):
63 | def __init__(self, n, f, bn=None, increase=0):
64 | super(Hourglass, self).__init__()
65 | nf = f + increase
66 | self.up1 = Residual(f, f)
67 | # Lower branch
68 | self.pool1 = Pool(2, 2)
69 | self.low1 = Residual(f, nf)
70 | self.n = n
71 | # Recursive hourglass
72 | if self.n > 1:
73 | self.low2 = Hourglass(n-1, nf, bn=bn)
74 | else:
75 | self.low2 = Residual(nf, nf)
76 | self.low3 = Residual(nf, f)
77 | self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
78 |
79 | def forward(self, x):
80 | up1 = self.up1(x)
81 | pool1 = self.pool1(x)
82 | low1 = self.low1(pool1)
83 | low2 = self.low2(low1)
84 | low3 = self.low3(low2)
85 | up2 = self.up2(low3)
86 | return up1 + up2
87 |
--------------------------------------------------------------------------------
/models/posenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from models.layers import Conv, Hourglass, Pool, Residual
4 | from task.loss import HeatmapLoss
5 |
6 | class UnFlatten(nn.Module):
7 | def forward(self, input):
8 | return input.view(-1, 256, 4, 4)
9 |
10 | class Merge(nn.Module):
11 | def __init__(self, x_dim, y_dim):
12 | super(Merge, self).__init__()
13 | self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)
14 |
15 | def forward(self, x):
16 | return self.conv(x)
17 |
18 | class PoseNet(nn.Module):
19 | def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs):
20 | super(PoseNet, self).__init__()
21 |
22 | self.nstack = nstack
23 | self.pre = nn.Sequential(
24 | Conv(3, 64, 7, 2, bn=True, relu=True),
25 | Residual(64, 128),
26 | Pool(2, 2),
27 | Residual(128, 128),
28 | Residual(128, inp_dim)
29 | )
30 |
31 | self.hgs = nn.ModuleList( [
32 | nn.Sequential(
33 | Hourglass(4, inp_dim, bn, increase),
34 | ) for i in range(nstack)] )
35 |
36 | self.features = nn.ModuleList( [
37 | nn.Sequential(
38 | Residual(inp_dim, inp_dim),
39 | Conv(inp_dim, inp_dim, 1, bn=True, relu=True)
40 | ) for i in range(nstack)] )
41 |
42 | self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
43 | self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] )
44 | self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] )
45 | self.nstack = nstack
46 | self.heatmapLoss = HeatmapLoss()
47 |
48 | def forward(self, imgs):
49 | ## our posenet
50 | x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim
51 | x = self.pre(x)
52 | combined_hm_preds = []
53 | for i in range(self.nstack):
54 | hg = self.hgs[i](x)
55 | feature = self.features[i](hg)
56 | preds = self.outs[i](feature)
57 | combined_hm_preds.append(preds)
58 | if i < self.nstack - 1:
59 | x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
60 | return torch.stack(combined_hm_preds, 1)
61 |
62 | def calc_loss(self, combined_hm_preds, heatmaps):
63 | combined_loss = []
64 | for i in range(self.nstack):
65 | combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))
66 | combined_loss = torch.stack(combined_loss, dim=1)
67 | return combined_loss
68 |
--------------------------------------------------------------------------------
/task/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class HeatmapLoss(torch.nn.Module):
4 | """
5 | loss for detection heatmap
6 | """
7 | def __init__(self):
8 | super(HeatmapLoss, self).__init__()
9 |
10 | def forward(self, pred, gt):
11 | l = ((pred - gt)**2)
12 | l = l.mean(dim=3).mean(dim=2).mean(dim=1)
13 | return l ## l of dim bsize
--------------------------------------------------------------------------------
/task/pose.py:
--------------------------------------------------------------------------------
1 | """
2 | __config__ contains the options for training and testing
3 | Basically all of the variables related to training are put in __config__['train']
4 | """
5 | import torch
6 | import numpy as np
7 | from torch import nn
8 | import os
9 | from torch.nn import DataParallel
10 | from utils.misc import make_input, make_output, importNet
11 |
12 | __config__ = {
13 | 'data_provider': 'data.MPII.dp',
14 | 'network': 'models.posenet.PoseNet',
15 | 'inference': {
16 | 'nstack': 8,
17 | 'inp_dim': 256,
18 | 'oup_dim': 16,
19 | 'num_parts': 16,
20 | 'increase': 0,
21 | 'keys': ['imgs'],
22 | 'num_eval': 2958, ## number of val examples used. entire set is 2958
23 | 'train_num_eval': 300, ## number of train examples tested at test time
24 | },
25 |
26 | 'train': {
27 | 'batchsize': 16,
28 | 'input_res': 256,
29 | 'output_res': 64,
30 | 'train_iters': 1000,
31 | 'valid_iters': 10,
32 | 'learning_rate': 1e-3,
33 | 'max_num_people' : 1,
34 | 'loss': [
35 | ['combined_hm_loss', 1],
36 | ],
37 | 'decay_iters': 100000,
38 | 'decay_lr': 2e-4,
39 | 'num_workers': 2,
40 | 'use_data_loader': True,
41 | },
42 | }
43 |
44 | class Trainer(nn.Module):
45 | """
46 |     The wrapper module that will behave differently for training or testing
47 | inference_keys specify the inputs for inference
48 | """
49 | def __init__(self, model, inference_keys, calc_loss=None):
50 | super(Trainer, self).__init__()
51 | self.model = model
52 | self.keys = inference_keys
53 | self.calc_loss = calc_loss
54 |
55 | def forward(self, imgs, **inputs):
56 | inps = {}
57 | labels = {}
58 |
59 | for i in inputs:
60 | if i in self.keys:
61 | inps[i] = inputs[i]
62 | else:
63 | labels[i] = inputs[i]
64 |
65 | if not self.training:
66 | return self.model(imgs, **inps)
67 | else:
68 | combined_hm_preds = self.model(imgs, **inps)
69 | if type(combined_hm_preds)!=list and type(combined_hm_preds)!=tuple:
70 | combined_hm_preds = [combined_hm_preds]
71 | loss = self.calc_loss(**labels, combined_hm_preds=combined_hm_preds)
72 | return list(combined_hm_preds) + list([loss])
73 |
74 | def make_network(configs):
75 | train_cfg = configs['train']
76 | config = configs['inference']
77 |
78 | def calc_loss(*args, **kwargs):
79 | return poseNet.calc_loss(*args, **kwargs)
80 |
81 | ## creating new posenet
82 | PoseNet = importNet(configs['network'])
83 | poseNet = PoseNet(**config)
84 | forward_net = DataParallel(poseNet.cuda())
85 | config['net'] = Trainer(forward_net, configs['inference']['keys'], calc_loss)
86 |
87 | ## optimizer, experiment setup
88 | train_cfg['optimizer'] = torch.optim.Adam(filter(lambda p: p.requires_grad,config['net'].parameters()), train_cfg['learning_rate'])
89 |
90 | exp_path = os.path.join('exp', configs['opt'].exp)
91 | if configs['opt'].exp=='pose' and configs['opt'].continue_exp is not None:
92 | exp_path = os.path.join('exp', configs['opt'].continue_exp)
93 | if not os.path.exists(exp_path):
94 | os.mkdir(exp_path)
95 | logger = open(os.path.join(exp_path, 'log'), 'a+')
96 |
97 | def make_train(batch_id, config, phase, **inputs):
98 | for i in inputs:
99 | try:
100 | inputs[i] = make_input(inputs[i])
101 | except:
102 | pass #for last input, which is a string (id_)
103 |
104 | net = config['inference']['net']
105 | config['batch_id'] = batch_id
106 |
107 | net = net.train()
108 |
109 | # When in validation phase put batchnorm layers in eval mode
110 | # to prevent running stats from getting updated.
111 | if phase == 'valid':
112 | for module in net.modules():
113 | if isinstance(module, nn.BatchNorm2d):
114 | module.eval()
115 |
116 | if phase != 'inference':
117 | result = net(inputs['imgs'], **{i:inputs[i] for i in inputs if i!='imgs'})
118 | num_loss = len(config['train']['loss'])
119 |
120 | losses = {i[0]: result[-num_loss + idx]*i[1] for idx, i in enumerate(config['train']['loss'])}
121 |
122 | loss = 0
123 | toprint = '\n{}: '.format(batch_id)
124 | for i in losses:
125 | loss = loss + torch.mean(losses[i])
126 |
127 | my_loss = make_output( losses[i] )
128 | my_loss = my_loss.mean()
129 |
130 | if my_loss.size == 1:
131 | toprint += ' {}: {}'.format(i, format(my_loss.mean(), '.8f'))
132 | else:
133 | toprint += '\n{}'.format(i)
134 | for j in my_loss:
135 | toprint += ' {}'.format(format(j.mean(), '.8f'))
136 | logger.write(toprint)
137 | logger.flush()
138 |
139 | if phase == 'train':
140 | optimizer = train_cfg['optimizer']
141 | optimizer.zero_grad()
142 | loss.backward()
143 | optimizer.step()
144 |
145 | if batch_id == config['train']['decay_iters']:
146 | ## decrease the learning rate after decay # iterations
147 | for param_group in optimizer.param_groups:
148 | param_group['lr'] = config['train']['decay_lr']
149 |
150 | return None
151 | else:
152 | out = {}
153 | net = net.eval()
154 | result = net(**inputs)
155 | if type(result)!=list and type(result)!=tuple:
156 | result = [result]
157 | out['preds'] = [make_output(i) for i in result]
158 | return out
159 | return make_train
160 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import tqdm
4 | import os
5 | import numpy as np
6 | import h5py
7 | import copy
8 |
9 | from utils.group import HeatmapParser
10 | import utils.img
11 | import data.MPII.ref as ds
12 |
13 | parser = HeatmapParser()
14 |
15 | def post_process(det, mat_, trainval, c=None, s=None, resolution=None):
16 | mat = np.linalg.pinv(np.array(mat_).tolist() + [[0,0,1]])[:2]
17 | res = det.shape[1:3]
18 | cropped_preds = parser.parse(np.float32([det]))[0]
19 |
20 | if len(cropped_preds) > 0:
21 | cropped_preds[:,:,:2] = utils.img.kpt_affine(cropped_preds[:,:,:2] * 4, mat) #size 1x16x3
22 |
23 | preds = np.copy(cropped_preds)
24 | ##for inverting predictions from input res on cropped to original image
25 | if trainval != 'cropped':
26 | for j in range(preds.shape[1]):
27 | preds[0,j,:2] = utils.img.transform(preds[0,j,:2], c, s, resolution, invert=1)
28 | return preds
29 |
30 | def inference(img, func, config, c, s):
31 | """
32 | forward pass at test time
33 | calls post_process to post process results
34 | """
35 |
36 | height, width = img.shape[0:2]
37 | center = (width/2, height/2)
38 | scale = max(height, width)/200
39 | res = (config['train']['input_res'], config['train']['input_res'])
40 |
41 | mat_ = utils.img.get_transform(center, scale, res)[:2]
42 | inp = img/255
43 |
44 | def array2dict(tmp):
45 | return {
46 | 'det': tmp[0][:,:,:16],
47 | }
48 |
49 | tmp1 = array2dict(func([inp]))
50 | tmp2 = array2dict(func([inp[:,::-1]]))
51 |
52 | tmp = {}
53 | for ii in tmp1:
54 | tmp[ii] = np.concatenate((tmp1[ii], tmp2[ii]),axis=0)
55 |
56 | det = tmp['det'][0, -1] + tmp['det'][1, -1, :, :, ::-1][ds.flipped_parts['mpii']]
57 | if det is None:
58 | return [], []
59 | det = det/2
60 |
61 | det = np.minimum(det, 1)
62 |
63 | return post_process(det, mat_, 'valid', c, s, res)
64 |
65 | def mpii_eval(pred, gt, normalizing, num_train, bound=0.5):
66 | """
67 |     Use PCKh@0.5: a joint counts as correct when the distance between prediction and ground truth, divided by the per-image normalization (head size), is below 0.5
68 | """
69 |
70 | correct = {'all': {'total': 0, 'ankle': 0, 'knee': 0, 'hip': 0, 'pelvis': 0,
71 | 'thorax': 0, 'neck': 0, 'head': 0, 'wrist': 0, 'elbow': 0,
72 | 'shoulder': 0},
73 | 'visible': {'total': 0, 'ankle': 0, 'knee': 0, 'hip': 0, 'pelvis': 0,
74 | 'thorax': 0, 'neck': 0, 'head': 0, 'wrist': 0, 'elbow': 0,
75 | 'shoulder': 0},
76 | 'not visible': {'total': 0, 'ankle': 0, 'knee': 0, 'hip': 0, 'pelvis': 0,
77 | 'thorax': 0, 'neck': 0, 'head': 0, 'wrist': 0, 'elbow': 0,
78 | 'shoulder': 0}}
79 | count = copy.deepcopy(correct)
80 | correct_train = copy.deepcopy(correct)
81 | count_train = copy.deepcopy(correct)
82 | idx = 0
83 | for p, g, normalize in zip(pred, gt, normalizing):
84 | for j in range(g.shape[1]):
85 | vis = 'visible'
86 | if g[0,j,0] == 0: ## not in picture!
87 | continue
88 | if g[0,j,2] == 0:
89 | vis = 'not visible'
90 | joint = 'ankle'
91 | if j==1 or j==4:
92 | joint = 'knee'
93 | elif j==2 or j==3:
94 | joint = 'hip'
95 | elif j==6:
96 | joint = 'pelvis'
97 | elif j==7:
98 | joint = 'thorax'
99 | elif j==8:
100 | joint = 'neck'
101 | elif j==9:
102 | joint = 'head'
103 | elif j==10 or j==15:
104 | joint = 'wrist'
105 | elif j==11 or j==14:
106 | joint = 'elbow'
107 | elif j==12 or j==13:
108 | joint = 'shoulder'
109 |
110 | if idx >= num_train:
111 | count['all']['total'] += 1
112 | count['all'][joint] += 1
113 | count[vis]['total'] += 1
114 | count[vis][joint] += 1
115 | else:
116 | count_train['all']['total'] += 1
117 | count_train['all'][joint] += 1
118 | count_train[vis]['total'] += 1
119 | count_train[vis][joint] += 1
120 | error = np.linalg.norm(p[0]['keypoints'][j,:2]-g[0,j,:2]) / normalize
121 | if idx >= num_train:
122 | if bound > error:
123 | correct['all']['total'] += 1
124 | correct['all'][joint] += 1
125 | correct[vis]['total'] += 1
126 | correct[vis][joint] += 1
127 | else:
128 | if bound > error:
129 | correct_train['all']['total'] += 1
130 | correct_train['all'][joint] += 1
131 | correct_train[vis]['total'] += 1
132 | correct_train[vis][joint] += 1
133 | idx += 1
134 |
135 | ## breakdown by validation set / training set
136 | for k in correct:
137 | print(k, ':')
138 | for key in correct[k]:
139 | print('Val PCK @,', bound, ',', key, ':', round(correct[k][key] / max(count[k][key],1), 3), ', count:', count[k][key])
140 | print('Tra PCK @,', bound, ',', key, ':', round(correct_train[k][key] / max(count_train[k][key],1), 3), ', count:', count_train[k][key])
141 | print('\n')
142 |
143 | def get_img(config, num_eval=2958, num_train=300):
144 | '''
145 | Load validation and training images
146 | '''
147 | input_res = config['train']['input_res']
148 | output_res = config['train']['output_res']
149 | val_f = h5py.File(os.path.join(ds.annot_dir, 'valid.h5'), 'r')
150 |
151 | tr = tqdm.tqdm( range(0, num_train), total = num_train )
152 | ## training
153 | train_f = h5py.File(os.path.join(ds.annot_dir, 'train.h5') ,'r')
154 | for i in tr:
155 | path_t = '%s/%s' % (ds.img_dir, train_f['imgname'][i].decode('UTF-8'))
156 |
157 | ## img
158 | orig_img = cv2.imread(path_t)[:,:,::-1]
159 | c = train_f['center'][i]
160 | s = train_f['scale'][i]
161 | im = utils.img.crop(orig_img, c, s, (input_res, input_res))
162 |
163 | ## kp
164 | kp = train_f['part'][i]
165 | vis = train_f['visible'][i]
166 | kp2 = np.insert(kp, 2, vis, axis=1)
167 | kps = np.zeros((1, 16, 3))
168 | kps[0] = kp2
169 |
170 | ## normalize (to make errors more fair on high pixel imgs)
171 | n = train_f['normalize'][i]
172 |
173 | yield kps, im, c, s, n
174 |
175 |
176 | tr2 = tqdm.tqdm( range(0, num_eval), total = num_eval )
177 | ## validation
178 | for i in tr2:
179 | path_t = '%s/%s' % (ds.img_dir, val_f['imgname'][i].decode('UTF-8'))
180 |
181 | ## img
182 | orig_img = cv2.imread(path_t)[:,:,::-1]
183 | c = val_f['center'][i]
184 | s = val_f['scale'][i]
185 | im = utils.img.crop(orig_img, c, s, (input_res, input_res))
186 |
187 | ## kp
188 | kp = val_f['part'][i]
189 | vis = val_f['visible'][i]
190 | kp2 = np.insert(kp, 2, vis, axis=1)
191 | kps = np.zeros((1, 16, 3))
192 | kps[0] = kp2
193 |
194 | ## normalize (to make errors more fair on high pixel imgs)
195 | n = val_f['normalize'][i]
196 |
197 | yield kps, im, c, s, n
198 |
199 |
200 | def main():
201 | from train import init
202 | func, config = init()
203 |
204 | def runner(imgs):
205 | return func(0, config, 'inference', imgs=torch.Tensor(np.float32(imgs)))['preds']
206 |
207 | def do(img, c, s):
208 | ans = inference(img, runner, config, c, s)
209 | if len(ans) > 0:
210 | ans = ans[:,:,:3]
211 |
212 | ## ans has shape N,16,3 (num preds, joints, x/y/visible)
213 | pred = []
214 | for i in range(ans.shape[0]):
215 | pred.append({'keypoints': ans[i,:,:]})
216 | return pred
217 |
218 | gts = []
219 | preds = []
220 | normalizing = []
221 |
222 | num_eval = config['inference']['num_eval']
223 | num_train = config['inference']['train_num_eval']
224 | for anns, img, c, s, n in get_img(config, num_eval, num_train):
225 | gts.append(anns)
226 | pred = do(img, c, s)
227 | preds.append(pred)
228 | normalizing.append(n)
229 |
230 | mpii_eval(preds, gts, normalizing, num_train)
231 |
232 | if __name__ == '__main__':
233 | main()
234 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | from os.path import dirname
4 |
5 | import torch.backends.cudnn as cudnn
6 | cudnn.benchmark = True
7 | cudnn.enabled = True
8 |
9 | import torch
10 | import importlib
11 | import argparse
12 | from datetime import datetime
13 | from pytz import timezone
14 |
15 | import shutil
16 |
17 | def parse_command_line():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('-c', '--continue_exp', type=str, help='continue exp')
20 | parser.add_argument('-e', '--exp', type=str, default='pose', help='experiments name')
21 | parser.add_argument('-m', '--max_iters', type=int, default=250, help='max number of iterations (thousands)')
22 | args = parser.parse_args()
23 | return args
24 |
25 | def reload(config):
26 | """
27 |     load or initialize the model's parameters from config['opt'].continue_exp
28 | config['train']['epoch'] records the epoch num
29 | config['inference']['net'] is the model
30 | """
31 | opt = config['opt']
32 |
33 | if opt.continue_exp:
34 | resume = os.path.join('exp', opt.continue_exp)
35 | resume_file = os.path.join(resume, 'checkpoint.pt')
36 | if os.path.isfile(resume_file):
37 | print("=> loading checkpoint '{}'".format(resume))
38 | checkpoint = torch.load(resume_file)
39 |
40 | config['inference']['net'].load_state_dict(checkpoint['state_dict'])
41 | config['train']['optimizer'].load_state_dict(checkpoint['optimizer'])
42 | config['train']['epoch'] = checkpoint['epoch']
43 | print("=> loaded checkpoint '{}' (epoch {})"
44 | .format(resume, checkpoint['epoch']))
45 | else:
46 | print("=> no checkpoint found at '{}'".format(resume))
47 | exit(0)
48 |
49 | if 'epoch' not in config['train']:
50 | config['train']['epoch'] = 0
51 |
52 | def save_checkpoint(state, is_best, filename='checkpoint.pt'):
53 | """
54 | from pytorch/examples
55 | """
56 | basename = dirname(filename)
57 | if not os.path.exists(basename):
58 | os.makedirs(basename)
59 | torch.save(state, filename)
60 | if is_best:
61 | shutil.copyfile(filename, 'model_best.pt')
62 |
63 | def save(config):
64 | resume = os.path.join('exp', config['opt'].exp)
65 | if config['opt'].exp=='pose' and config['opt'].continue_exp is not None:
66 | resume = os.path.join('exp', config['opt'].continue_exp)
67 | resume_file = os.path.join(resume, 'checkpoint.pt')
68 |
69 | save_checkpoint({
70 | 'state_dict': config['inference']['net'].state_dict(),
71 | 'optimizer' : config['train']['optimizer'].state_dict(),
72 | 'epoch': config['train']['epoch'],
73 | }, False, filename=resume_file)
74 | print('=> save checkpoint')
75 |
76 | def train(train_func, data_func, config, post_epoch=None):
77 | while True:
78 | fails = 0
79 | print('epoch: ', config['train']['epoch'])
80 | if 'epoch_num' in config['train']:
81 | if config['train']['epoch'] > config['train']['epoch_num']:
82 | break
83 |
84 | for phase in ['train', 'valid']:
85 | num_step = config['train']['{}_iters'.format(phase)]
86 | generator = data_func(phase)
87 | print('start', phase, config['opt'].exp)
88 |
89 | show_range = range(num_step)
90 | show_range = tqdm.tqdm(show_range, total = num_step, ascii=True)
91 | batch_id = num_step * config['train']['epoch']
92 | if batch_id > config['opt'].max_iters * 1000:
93 | return
94 | for i in show_range:
95 | datas = next(generator)
96 | outs = train_func(batch_id + i, config, phase, **datas)
97 | config['train']['epoch'] += 1
98 | save(config)
99 |
100 | def init():
101 | """
102 | task.__config__ contains the variables that control the training and testing
103 | make_network builds a function which can do forward and backward propagation
104 | """
105 | opt = parse_command_line()
106 | task = importlib.import_module('task.pose')
107 | exp_path = os.path.join('exp', opt.exp)
108 |
109 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
110 |
111 | config = task.__config__
112 | try: os.makedirs(exp_path)
113 | except FileExistsError: pass
114 |
115 | config['opt'] = opt
116 | config['data_provider'] = importlib.import_module(config['data_provider'])
117 |
118 | func = task.make_network(config)
119 | reload(config)
120 | return func, config
121 |
122 | def main():
123 | func, config = init()
124 | data_func = config['data_provider'].init(config)
125 | train(func, data_func, config)
126 | print(datetime.now(timezone('EST')))
127 |
128 | if __name__ == '__main__':
129 | main()
130 |
--------------------------------------------------------------------------------
/utils/group.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def match_format(dic):
5 | loc = dic['loc_k'][0,:,0,:]
6 | val = dic['val_k'][0,:,:]
7 | ans = np.hstack((loc, val))
8 | ans = np.expand_dims(ans, axis = 0)
9 | ret = []
10 | ret.append(ans)
11 | return ret
12 |
13 | class HeatmapParser:
14 | def __init__(self):
15 | from torch import nn
16 | self.pool = nn.MaxPool2d(3, 1, 1)
17 |
18 | def nms(self, det):
19 | maxm = self.pool(det)
20 | maxm = torch.eq(maxm, det).float()
21 | det = det * maxm
22 | return det
23 |
24 | def calc(self, det):
25 |         with torch.no_grad():
26 |             # torch.autograd.Variable is deprecated; a plain tensor suffices
27 |             det = torch.Tensor(det)
28 |
29 | det = self.nms(det)
30 | h = det.size()[2]
31 | w = det.size()[3]
32 | det = det.view(det.size()[0], det.size()[1], -1)
33 | val_k, ind = det.topk(1, dim=2)
34 |
35 | x = ind % w
36 | y = (ind / w).long()
37 | ind_k = torch.stack((x, y), dim=3)
38 | ans = {'loc_k': ind_k, 'val_k': val_k}
39 | return {key:ans[key].cpu().data.numpy() for key in ans}
40 |
41 | def adjust(self, ans, det):
42 | for batch_id, people in enumerate(ans):
43 | for people_id, i in enumerate(people):
44 | for joint_id, joint in enumerate(i):
45 | if joint[2]>0:
46 | y, x = joint[0:2]
47 | xx, yy = int(x), int(y)
48 | tmp = det[0][joint_id]
49 | if tmp[xx, min(yy+1, tmp.shape[1]-1)]>tmp[xx, max(yy-1, 0)]:
50 | y+=0.25
51 | else:
52 | y-=0.25
53 |
54 | if tmp[min(xx+1, tmp.shape[0]-1), yy]>tmp[max(0, xx-1), yy]:
55 | x+=0.25
56 | else:
57 | x-=0.25
58 |                         ans[0][0, joint_id, 0:2] = (y+0.5, x+0.5) # written for single-person output: batch 0, person 0
59 | return ans
60 |
61 | def parse(self, det, adjust=True):
62 | ans = match_format(self.calc(det))
63 | if adjust:
64 | ans = self.adjust(ans, det)
65 | return ans
66 |
--------------------------------------------------------------------------------
/utils/img.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc
3 | import cv2
4 |
5 | # =============================================================================
6 | # General image processing functions
7 | # =============================================================================
8 |
9 | def get_transform(center, scale, res, rot=0):
10 | # Generate transformation matrix
11 | h = 200 * scale
12 | t = np.zeros((3, 3))
13 | t[0, 0] = float(res[1]) / h
14 | t[1, 1] = float(res[0]) / h
15 | t[0, 2] = res[1] * (-float(center[0]) / h + .5)
16 | t[1, 2] = res[0] * (-float(center[1]) / h + .5)
17 | t[2, 2] = 1
18 | if not rot == 0:
19 | rot = -rot # To match direction of rotation from cropping
20 | rot_mat = np.zeros((3,3))
21 | rot_rad = rot * np.pi / 180
22 | sn,cs = np.sin(rot_rad), np.cos(rot_rad)
23 | rot_mat[0,:2] = [cs, -sn]
24 | rot_mat[1,:2] = [sn, cs]
25 | rot_mat[2,2] = 1
26 | # Need to rotate around center
27 | t_mat = np.eye(3)
28 | t_mat[0,2] = -res[1]/2
29 | t_mat[1,2] = -res[0]/2
30 | t_inv = t_mat.copy()
31 | t_inv[:2,2] *= -1
32 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
33 | return t
34 |
35 | def transform(pt, center, scale, res, invert=0, rot=0):
36 | # Transform pixel location to different reference
37 | t = get_transform(center, scale, res, rot=rot)
38 | if invert:
39 | t = np.linalg.inv(t)
40 | new_pt = np.array([pt[0], pt[1], 1.]).T
41 | new_pt = np.dot(t, new_pt)
42 | return new_pt[:2].astype(int)
43 |
44 | def crop(img, center, scale, res, rot=0):
45 | # Upper left point
46 | ul = np.array(transform([0, 0], center, scale, res, invert=1))
47 | # Bottom right point
48 | br = np.array(transform(res, center, scale, res, invert=1))
49 |
50 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
51 | if len(img.shape) > 2:
52 | new_shape += [img.shape[2]]
53 | new_img = np.zeros(new_shape)
54 |
55 | # Range to fill new array
56 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
57 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
58 | # Range to sample from original image
59 | old_x = max(0, ul[0]), min(len(img[0]), br[0])
60 | old_y = max(0, ul[1]), min(len(img), br[1])
61 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
62 |
63 | return cv2.resize(new_img, res)
64 |
65 | def inv_mat(mat):
66 | ans = np.linalg.pinv(np.array(mat).tolist() + [[0,0,1]])
67 | return ans[:2]
68 |
69 | def kpt_affine(kpt, mat):
70 | kpt = np.array(kpt)
71 | shape = kpt.shape
72 | kpt = kpt.reshape(-1, 2)
73 | return np.dot( np.concatenate((kpt, kpt[:, 0:1]*0+1), axis = 1), mat.T ).reshape(shape)
74 |
75 | def resize(im, res):
76 | return np.array([cv2.resize(im[i],res) for i in range(im.shape[0])])
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import importlib
4 |
5 | # Helpers when setting up training
6 |
7 | def importNet(net):
8 | t = net.split('.')
9 | path, name = '.'.join(t[:-1]), t[-1]
10 | module = importlib.import_module(path)
11 |     return getattr(module, name)
12 |
13 | def make_input(t, requires_grad=False, need_cuda = True):
14 | inp = torch.autograd.Variable(t, requires_grad=requires_grad)
15 | if need_cuda:
16 | inp = inp.cuda()
17 | return inp
18 |
19 | def make_output(x):
20 | if not (type(x) is list):
21 | return x.cpu().data.numpy()
22 | else:
23 | return [make_output(i) for i in x]
24 |
--------------------------------------------------------------------------------