├── .gitignore
├── FAQ.md
├── LICENSE
├── README.md
├── data
│   └── MPII
│       ├── annot
│       │   ├── test.h5
│       │   ├── train.h5
│       │   └── valid.h5
│       ├── dp.py
│       └── ref.py
├── models
│   ├── layers.py
│   └── posenet.py
├── task
│   ├── loss.py
│   └── pose.py
├── test.py
├── train.py
└── utils
    ├── group.py
    ├── img.py
    └── misc.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | exp
2 | *.pyc
3 | *__pycache__*
4 | *core.*
5 | *.jpg
6 | *.png
7 | *.pkl
8 | *.txt
9 | *.json
10 | _ext
11 | tmp
12 | *.o*
13 | *~
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | Q: What exactly are the differences in parameters between this code and the Stacked Hourglass paper?
2 | A: The paper uses a batch size of 8 instead of 16, and RMSprop with a learning rate of 2.5e-4 instead of Adam with a learning rate of 1e-3. The paper also decayed the learning rate after validation accuracy plateaued, rather than explicitly at 100k iterations, but this is more or less the same idea. You can switch Adam to RMSprop in the make_network function in task/pose.py; the learning rate and batch size are defined there as well, in __config__['train'].
3 | 
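A minimal sketch of the swap, assuming you replace the optimizer construction inside make_network (the RMSprop call below mirrors the existing Adam line):

```python
# task/pose.py, inside make_network -- sketch of the paper's optimizer settings
train_cfg['optimizer'] = torch.optim.RMSprop(
    filter(lambda p: p.requires_grad, config['net'].parameters()),
    lr=2.5e-4)  # paper setting, replacing Adam at 1e-3
```

4 | Q: Were scores on the 2HG model achieved using the same parameters as the 8HG?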
5 | A: Yes, just change nstack to 2 in task/pose.py.
6 | 
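Concretely, the entry lives in __config__ in task/pose.py; editing the dict literal (or overriding it as below) is all that is needed:

```python
# task/pose.py -- use a 2-stack hourglass instead of the default 8
__config__['inference']['nstack'] = 2
```

7 | Q: How do I interpret the output of the log file?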
8 | A: Each iteration during training or validation writes a line to this file with the corresponding loss. Note that we do not calculate train or validation accuracy during training, as this operation requires extra processing and is expensive. Validation loss can be used as a proxy for deciding when to stop training.
9 | 
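Each line is written by make_train in task/pose.py and looks like "&lt;batch_id&gt;:  combined_hm_loss: &lt;value&gt;". A sketch for pulling the per-iteration losses back out (the experiment name is an example):

```python
import re

# Parse "<batch_id>:  combined_hm_loss: <value>" lines from an experiment log.
losses = []
with open('exp/test_run_001/log') as f:
    for line in f:
        m = re.match(r'\s*(\d+):\s+combined_hm_loss:\s+([0-9.eE+-]+)', line)
        if m:
            losses.append((int(m.group(1)), float(m.group(2))))
```

10 | Q: Is only one model saved during training?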
11 | A: Yes - the most recent model is saved at the end of each epoch. You may want to modify this if you wish to keep the "best" checkpoint, etc.
12 | 
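save_checkpoint in train.py already accepts an is_best flag that copies the checkpoint to model_best.pt; what is missing is tracking the best validation loss. A minimal sketch, assuming you thread the mean validation loss into the saving step yourself (that wiring is not in this repo):

```python
# train.py -- sketch of keeping a "best" checkpoint. valid_loss is an assumed
# argument that the caller must supply from the validation phase.
best = {'loss': float('inf')}

def save_best(state, valid_loss, filename='checkpoint.pt'):
    is_best = valid_loss < best['loss']
    if is_best:
        best['loss'] = valid_loss
    save_checkpoint(state, is_best, filename=filename)  # copies to model_best.pt when is_best
```

13 | Q: How can I display predictions?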
14 | A: There isn't explicit visualization code here, but you can set the pixels of the image at each predicted keypoint to visualize results. For example, you could modify the mpii_eval function in test.py to take in images as well as keypoints and write them to TensorBoard.
15 | 
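A minimal OpenCV sketch along those lines, assuming predictions shaped N x 16 x 3 in image coordinates (as produced by inference in test.py):

```python
import cv2

def draw_keypoints(img, preds):
    # img: HxWx3 uint8 image; preds: Nx16x3 array of (x, y, score)
    vis = img.copy()
    for person in preds:
        for x, y, score in person:
            if score > 0:  # skip joints with no detection
                cv2.circle(vis, (int(x), int(y)), 3, (255, 0, 0), -1)
    return vis
```

16 | Q: The evaluation code evaluates the train and validation sets? What about the test set?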
17 | A: Yes, the evaluation code (test.py) is set up to calculate accuracy on the validation set. Train accuracy (like validation accuracy) is not calculated during training, so it is also calculated here. Default settings in task/pose.py calculate train accuracy on a sample of the train set (300 images) to reduce compute time. To get test accuracy, you must run test.py with images and crops from test.h5, then use the evaluation toolkit provided by MPII and submit as detailed on the [MPII website](http://human-pose.mpi-inf.mpg.de/#evaluation).
18 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2019, princeton-vl
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stacked Hourglass Networks in PyTorch
2 | 
3 | Based on **Stacked Hourglass Networks for Human Pose Estimation.** [Alejandro Newell](https://www.alejandronewell.com/), [Kaiyu Yang](https://www.cs.princeton.edu/~kaiyuy/), and [Jia Deng](https://www.cs.princeton.edu/~jiadeng/). *European Conference on Computer Vision (ECCV)*, 2016. [Github](https://github.com/princeton-vl/pose-hg-train)
4 | 
5 | PyTorch code by [Chris Rockwell](https://crockwell.github.io/); adapted from: **Associative Embedding: End-to-end Learning for Joint Detection and Grouping.**
6 | [Alejandro Newell](http://www-personal.umich.edu/~alnewell/), [Zhiao Huang](https://sites.google.com/view/zhiao-huang), and [Jia Deng](https://www.cs.princeton.edu/~jiadeng/). *Neural Information Processing Systems (NeurIPS)*, 2017. [Github](https://github.com/princeton-vl/pose-ae-train)
7 | 
8 | ## Getting Started
9 | 
10 | This repository provides everything necessary to train and evaluate a single-person pose estimation model on MPII.
If you plan on training your own model from scratch, we highly recommend using multiple GPUs.
11 | 
12 | Requirements:
13 | 
14 | - Python 3 (code has been tested on Python 3.8.2)
15 | - PyTorch (code tested with 1.5)
16 | - CUDA and cuDNN (tested with CUDA 10)
17 | - Python packages (not exhaustive): opencv-python (tested with 4.2), tqdm, cffi, h5py, scipy (tested with 1.4.1), pytz, imageio
18 | 
19 | Structure:
20 | - ```data/```: data loading and data augmentation code
21 | - ```models/```: network architecture definitions
22 | - ```task/```: task-specific functions and training configuration
23 | - ```utils/```: image processing code and miscellaneous helper functions
24 | - ```train.py```: code for model training
25 | - ```test.py```: code for model evaluation
26 | 
27 | #### Dataset
28 | Download the full [MPII Human Pose dataset](http://human-pose.mpi-inf.mpg.de/), and place the images directory in data/MPII/.
29 | 
30 | #### Training and Testing
31 | 
32 | To train a network, call:
33 | 
34 | ```python train.py -e test_run_001``` (```-e,--exp``` allows you to specify an experiment name)
35 | 
36 | To continue an experiment where it left off, you can call:
37 | 
38 | ```python train.py -c test_run_001```
39 | 
40 | All training hyperparameters are defined in ```task/pose.py```, and you can modify ```__config__``` to test different options. You will likely have to change the batch size to accommodate the number of GPUs you have available.
41 | 
42 | Once a model has been trained, you can evaluate it with:
43 | 
44 | ```python test.py -c test_run_001```
45 | 
46 | The train.py option ```-m n``` automatically stops training after n thousand total iterations (when continuing an experiment, iterations from earlier runs count toward the total).
47 | 
48 | #### Pretrained Models
49 | 
50 | An 8HG pretrained model is available [here](http://www-personal.umich.edu/~cnris/original_8hg/checkpoint.pt). It should yield validation accuracy of 0.902.
51 | 
52 | A 2HG pretrained model is available [here](http://www-personal.umich.edu/~cnris/original_2hg/checkpoint.pt). It should yield validation accuracy of 0.883.
53 | 
54 | Models should be placed at exp/&lt;experiment_name&gt;/checkpoint.pt (a quick way to sanity-check a downloaded checkpoint is sketched below).
55 | 
56 | Note: these models were trained using a batch size of 16 and the Adam optimizer with LR 1e-3 (instead of RMSprop at 2.5e-4), as those settings performed better in validation. The code can easily be modified to use the original paper's settings. The original paper reported validation accuracy of 0.881, which this code approximately replicates. The above models were also trained for approximately 200k iterations, while the original paper trained for fewer.
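Checkpoint files follow the format written by save() in train.py: a dict with 'state_dict', 'optimizer', and 'epoch' keys. A quick sanity check on a download (the experiment name below is an example):

```python
import torch

# Inspect a checkpoint saved by train.py / downloaded from the links above.
ckpt = torch.load('exp/original_8hg/checkpoint.pt', map_location='cpu')
print('epoch:', ckpt['epoch'])
print('sample weight keys:', list(ckpt['state_dict'].keys())[:3])
```

57 | 
58 | #### Training/Validation split
59 | 
60 | The train/val split is the same as that found in the authors' [implementation](https://github.com/princeton-vl/pose-hg-train).
61 | 
62 | #### Note
63 | 
64 | During training, a "ConnectionResetError" warning was occasionally displayed between epochs, but it did not affect training.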
65 | -------------------------------------------------------------------------------- /data/MPII/annot/test.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/pytorch_stacked_hourglass/ceedc14b9b8814ba641f4a00baa5d0af153588a9/data/MPII/annot/test.h5 -------------------------------------------------------------------------------- /data/MPII/annot/train.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/pytorch_stacked_hourglass/ceedc14b9b8814ba641f4a00baa5d0af153588a9/data/MPII/annot/train.h5 -------------------------------------------------------------------------------- /data/MPII/annot/valid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-vl/pytorch_stacked_hourglass/ceedc14b9b8814ba641f4a00baa5d0af153588a9/data/MPII/annot/valid.h5 -------------------------------------------------------------------------------- /data/MPII/dp.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import sys 3 | import os 4 | import torch 5 | import numpy as np 6 | import torch.utils.data 7 | import utils.img 8 | 9 | class GenerateHeatmap(): 10 | def __init__(self, output_res, num_parts): 11 | self.output_res = output_res 12 | self.num_parts = num_parts 13 | sigma = self.output_res/64 14 | self.sigma = sigma 15 | size = 6*sigma + 3 16 | x = np.arange(0, size, 1, float) 17 | y = x[:, np.newaxis] 18 | x0, y0 = 3*sigma + 1, 3*sigma + 1 19 | self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 20 | 21 | def __call__(self, keypoints): 22 | hms = np.zeros(shape = (self.num_parts, self.output_res, self.output_res), dtype = np.float32) 23 | sigma = self.sigma 24 | for p in keypoints: 25 | for idx, pt in enumerate(p): 26 | if pt[0] > 0: 27 | x, y = int(pt[0]), int(pt[1]) 28 | if x<0 or y<0 or x>=self.output_res or y>=self.output_res: 29 | continue 30 | ul = int(x - 3*sigma - 1), int(y - 3*sigma - 1) 31 | br = int(x + 3*sigma + 2), int(y + 3*sigma + 2) 32 | 33 | c,d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0] 34 | a,b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1] 35 | 36 | cc,dd = max(0, ul[0]), min(br[0], self.output_res) 37 | aa,bb = max(0, ul[1]), min(br[1], self.output_res) 38 | hms[idx, aa:bb,cc:dd] = np.maximum(hms[idx, aa:bb,cc:dd], self.g[a:b,c:d]) 39 | return hms 40 | 41 | class Dataset(torch.utils.data.Dataset): 42 | def __init__(self, config, ds, index): 43 | self.input_res = config['train']['input_res'] 44 | self.output_res = config['train']['output_res'] 45 | self.generateHeatmap = GenerateHeatmap(self.output_res, config['inference']['num_parts']) 46 | self.ds = ds 47 | self.index = index 48 | 49 | def __len__(self): 50 | return len(self.index) 51 | 52 | def __getitem__(self, idx): 53 | return self.loadImage(self.index[idx % len(self.index)]) 54 | 55 | def loadImage(self, idx): 56 | ds = self.ds 57 | 58 | ## load + crop 59 | orig_img = ds.get_img(idx) 60 | path = ds.get_path(idx) 61 | orig_keypoints = ds.get_kps(idx) 62 | kptmp = orig_keypoints.copy() 63 | c = ds.get_center(idx) 64 | s = ds.get_scale(idx) 65 | normalize = ds.get_normalized(idx) 66 | 67 | cropped = utils.img.crop(orig_img, c, s, (self.input_res, self.input_res)) 68 | for i in range(np.shape(orig_keypoints)[1]): 69 | if orig_keypoints[0,i,0] > 0: 70 | orig_keypoints[0,i,:2] = 
utils.img.transform(orig_keypoints[0,i,:2], c, s, (self.input_res, self.input_res)) 71 | keypoints = np.copy(orig_keypoints) 72 | 73 | ## augmentation -- to be done to cropped image 74 | height, width = cropped.shape[0:2] 75 | center = np.array((width/2, height/2)) 76 | scale = max(height, width)/200 77 | 78 | aug_rot=0 79 | 80 | aug_rot = (np.random.random() * 2 - 1) * 30. 81 | aug_scale = np.random.random() * (1.25 - 0.75) + 0.75 82 | scale *= aug_scale 83 | 84 | mat_mask = utils.img.get_transform(center, scale, (self.output_res, self.output_res), aug_rot)[:2] 85 | 86 | mat = utils.img.get_transform(center, scale, (self.input_res, self.input_res), aug_rot)[:2] 87 | inp = cv2.warpAffine(cropped, mat, (self.input_res, self.input_res)).astype(np.float32)/255 88 | keypoints[:,:,0:2] = utils.img.kpt_affine(keypoints[:,:,0:2], mat_mask) 89 | if np.random.randint(2) == 0: 90 | inp = self.preprocess(inp) 91 | inp = inp[:, ::-1] 92 | keypoints = keypoints[:, ds.flipped_parts['mpii']] 93 | keypoints[:, :, 0] = self.output_res - keypoints[:, :, 0] 94 | orig_keypoints = orig_keypoints[:, ds.flipped_parts['mpii']] 95 | orig_keypoints[:, :, 0] = self.input_res - orig_keypoints[:, :, 0] 96 | 97 | ## set keypoints to 0 when were not visible initially (so heatmap all 0s) 98 | for i in range(np.shape(orig_keypoints)[1]): 99 | if kptmp[0,i,0] == 0 and kptmp[0,i,1] == 0: 100 | keypoints[0,i,0] = 0 101 | keypoints[0,i,1] = 0 102 | orig_keypoints[0,i,0] = 0 103 | orig_keypoints[0,i,1] = 0 104 | 105 | ## generate heatmaps on outres 106 | heatmaps = self.generateHeatmap(keypoints) 107 | 108 | return inp.astype(np.float32), heatmaps.astype(np.float32) 109 | 110 | def preprocess(self, data): 111 | # random hue and saturation 112 | data = cv2.cvtColor(data, cv2.COLOR_RGB2HSV); 113 | delta = (np.random.random() * 2 - 1) * 0.2 114 | data[:, :, 0] = np.mod(data[:,:,0] + (delta * 360 + 360.), 360.) 
115 | 116 | delta_sature = np.random.random() + 0.5 117 | data[:, :, 1] *= delta_sature 118 | data[:,:, 1] = np.maximum( np.minimum(data[:,:,1], 1), 0 ) 119 | data = cv2.cvtColor(data, cv2.COLOR_HSV2RGB) 120 | 121 | # adjust brightness 122 | delta = (np.random.random() * 2 - 1) * 0.3 123 | data += delta 124 | 125 | # adjust contrast 126 | mean = data.mean(axis=2, keepdims=True) 127 | data = (data - mean) * (np.random.random() + 0.5) + mean 128 | data = np.minimum(np.maximum(data, 0), 1) 129 | return data 130 | 131 | 132 | def init(config): 133 | batchsize = config['train']['batchsize'] 134 | current_path = os.path.dirname(os.path.abspath(__file__)) 135 | sys.path.append(current_path) 136 | import ref as ds 137 | ds.init() 138 | 139 | train, valid = ds.setup_val_split() 140 | dataset = { key: Dataset(config, ds, data) for key, data in zip( ['train', 'valid'], [train, valid] ) } 141 | 142 | use_data_loader = config['train']['use_data_loader'] 143 | 144 | loaders = {} 145 | for key in dataset: 146 | loaders[key] = torch.utils.data.DataLoader(dataset[key], batch_size=batchsize, shuffle=True, num_workers=config['train']['num_workers'], pin_memory=False) 147 | 148 | def gen(phase): 149 | batchsize = config['train']['batchsize'] 150 | batchnum = config['train']['{}_iters'.format(phase)] 151 | loader = loaders[phase].__iter__() 152 | for i in range(batchnum): 153 | try: 154 | imgs, heatmaps = next(loader) 155 | except StopIteration: 156 | # to avoid no data provided by dataloader 157 | loader = loaders[phase].__iter__() 158 | imgs, heatmaps = next(loader) 159 | yield { 160 | 'imgs': imgs, #cropped and augmented 161 | 'heatmaps': heatmaps, #based on keypoints. 0 if not in img for joint 162 | } 163 | 164 | 165 | return lambda key: gen(key) 166 | -------------------------------------------------------------------------------- /data/MPII/ref.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | from imageio import imread 4 | import os 5 | import time 6 | 7 | def _isArrayLike(obj): 8 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 9 | 10 | annot_dir = 'data/MPII/annot' 11 | img_dir = 'data/MPII/images' 12 | 13 | assert os.path.exists(img_dir) 14 | mpii, num_examples_train, num_examples_val = None, None, None 15 | 16 | import cv2 17 | 18 | class MPII: 19 | def __init__(self): 20 | print('loading data...') 21 | tic = time.time() 22 | 23 | train_f = h5py.File(os.path.join(annot_dir, 'train.h5'), 'r') 24 | val_f = h5py.File(os.path.join(annot_dir, 'valid.h5'), 'r') 25 | 26 | self.t_center = train_f['center'][()] 27 | t_scale = train_f['scale'][()] 28 | t_part = train_f['part'][()] 29 | t_visible = train_f['visible'][()] 30 | t_normalize = train_f['normalize'][()] 31 | t_imgname = [None] * len(self.t_center) 32 | for i in range(len(self.t_center)): 33 | t_imgname[i] = train_f['imgname'][i].decode('UTF-8') 34 | 35 | self.v_center = val_f['center'][()] 36 | v_scale = val_f['scale'][()] 37 | v_part = val_f['part'][()] 38 | v_visible = val_f['visible'][()] 39 | v_normalize = val_f['normalize'][()] 40 | v_imgname = [None] * len(self.v_center) 41 | for i in range(len(self.v_center)): 42 | v_imgname[i] = val_f['imgname'][i].decode('UTF-8') 43 | 44 | self.center = np.append(self.t_center, self.v_center, axis=0) 45 | self.scale = np.append(t_scale, v_scale) 46 | self.part = np.append(t_part, v_part, axis=0) 47 | self.visible = np.append(t_visible, v_visible, axis=0) 48 | self.normalize = np.append(t_normalize, v_normalize) 49 | 
self.imgname = t_imgname + v_imgname 50 | 51 | print('Done (t={:0.2f}s)'.format(time.time()- tic)) 52 | 53 | def getAnnots(self, idx): 54 | ''' 55 | returns h5 file for train or val set 56 | ''' 57 | return self.imgname[idx], self.part[idx], self.visible[idx], self.center[idx], self.scale[idx], self.normalize[idx] 58 | 59 | def getLength(self): 60 | return len(self.t_center), len(self.v_center) 61 | 62 | def init(): 63 | global mpii, num_examples_train, num_examples_val 64 | mpii = MPII() 65 | num_examples_train, num_examples_val = mpii.getLength() 66 | 67 | # Part reference 68 | parts = {'mpii':['rank', 'rkne', 'rhip', 69 | 'lhip', 'lkne', 'lank', 70 | 'pelv', 'thrx', 'neck', 'head', 71 | 'rwri', 'relb', 'rsho', 72 | 'lsho', 'lelb', 'lwri']} 73 | 74 | flipped_parts = {'mpii':[5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]} 75 | 76 | part_pairs = {'mpii':[[0, 5], [1, 4], [2, 3], [6], [7], [8], [9], [10, 15], [11, 14], [12, 13]]} 77 | 78 | pair_names = {'mpii':['ankle', 'knee', 'hip', 'pelvis', 'thorax', 'neck', 'head', 'wrist', 'elbow', 'shoulder']} 79 | 80 | def setup_val_split(): 81 | ''' 82 | returns index for train and validation imgs 83 | index for validation images starts after that of train images 84 | so that loadImage can tell them apart 85 | ''' 86 | valid = [i+num_examples_train for i in range(num_examples_val)] 87 | train = [i for i in range(num_examples_train)] 88 | return np.array(train), np.array(valid) 89 | 90 | def get_img(idx): 91 | imgname, __, __, __, __, __ = mpii.getAnnots(idx) 92 | path = os.path.join(img_dir, imgname) 93 | img = imread(path) 94 | return img 95 | 96 | def get_path(idx): 97 | imgname, __, __, __, __, __ = mpii.getAnnots(idx) 98 | path = os.path.join(img_dir, imgname) 99 | return path 100 | 101 | def get_kps(idx): 102 | __, part, visible, __, __, __ = mpii.getAnnots(idx) 103 | kp2 = np.insert(part, 2, visible, axis=1) 104 | kps = np.zeros((1, 16, 3)) 105 | kps[0] = kp2 106 | return kps 107 | 108 | def get_normalized(idx): 109 | __, __, __, __, __, n = mpii.getAnnots(idx) 110 | return n 111 | 112 | def get_center(idx): 113 | __, __, __, c, __, __ = mpii.getAnnots(idx) 114 | return c 115 | 116 | def get_scale(idx): 117 | __, __, __, __, s, __ = mpii.getAnnots(idx) 118 | return s -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | Pool = nn.MaxPool2d 4 | 5 | def batchnorm(x): 6 | return nn.BatchNorm2d(x.size()[1])(x) 7 | 8 | class Conv(nn.Module): 9 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True): 10 | super(Conv, self).__init__() 11 | self.inp_dim = inp_dim 12 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True) 13 | self.relu = None 14 | self.bn = None 15 | if relu: 16 | self.relu = nn.ReLU() 17 | if bn: 18 | self.bn = nn.BatchNorm2d(out_dim) 19 | 20 | def forward(self, x): 21 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim) 22 | x = self.conv(x) 23 | if self.bn is not None: 24 | x = self.bn(x) 25 | if self.relu is not None: 26 | x = self.relu(x) 27 | return x 28 | 29 | class Residual(nn.Module): 30 | def __init__(self, inp_dim, out_dim): 31 | super(Residual, self).__init__() 32 | self.relu = nn.ReLU() 33 | self.bn1 = nn.BatchNorm2d(inp_dim) 34 | self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False) 35 | self.bn2 = nn.BatchNorm2d(int(out_dim/2)) 36 | self.conv2 = 
Conv(int(out_dim/2), int(out_dim/2), 3, relu=False) 37 | self.bn3 = nn.BatchNorm2d(int(out_dim/2)) 38 | self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False) 39 | self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False) 40 | if inp_dim == out_dim: 41 | self.need_skip = False 42 | else: 43 | self.need_skip = True 44 | 45 | def forward(self, x): 46 | if self.need_skip: 47 | residual = self.skip_layer(x) 48 | else: 49 | residual = x 50 | out = self.bn1(x) 51 | out = self.relu(out) 52 | out = self.conv1(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | out = self.conv2(out) 56 | out = self.bn3(out) 57 | out = self.relu(out) 58 | out = self.conv3(out) 59 | out += residual 60 | return out 61 | 62 | class Hourglass(nn.Module): 63 | def __init__(self, n, f, bn=None, increase=0): 64 | super(Hourglass, self).__init__() 65 | nf = f + increase 66 | self.up1 = Residual(f, f) 67 | # Lower branch 68 | self.pool1 = Pool(2, 2) 69 | self.low1 = Residual(f, nf) 70 | self.n = n 71 | # Recursive hourglass 72 | if self.n > 1: 73 | self.low2 = Hourglass(n-1, nf, bn=bn) 74 | else: 75 | self.low2 = Residual(nf, nf) 76 | self.low3 = Residual(nf, f) 77 | self.up2 = nn.Upsample(scale_factor=2, mode='nearest') 78 | 79 | def forward(self, x): 80 | up1 = self.up1(x) 81 | pool1 = self.pool1(x) 82 | low1 = self.low1(pool1) 83 | low2 = self.low2(low1) 84 | low3 = self.low3(low2) 85 | up2 = self.up2(low3) 86 | return up1 + up2 87 | -------------------------------------------------------------------------------- /models/posenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.layers import Conv, Hourglass, Pool, Residual 4 | from task.loss import HeatmapLoss 5 | 6 | class UnFlatten(nn.Module): 7 | def forward(self, input): 8 | return input.view(-1, 256, 4, 4) 9 | 10 | class Merge(nn.Module): 11 | def __init__(self, x_dim, y_dim): 12 | super(Merge, self).__init__() 13 | self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False) 14 | 15 | def forward(self, x): 16 | return self.conv(x) 17 | 18 | class PoseNet(nn.Module): 19 | def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs): 20 | super(PoseNet, self).__init__() 21 | 22 | self.nstack = nstack 23 | self.pre = nn.Sequential( 24 | Conv(3, 64, 7, 2, bn=True, relu=True), 25 | Residual(64, 128), 26 | Pool(2, 2), 27 | Residual(128, 128), 28 | Residual(128, inp_dim) 29 | ) 30 | 31 | self.hgs = nn.ModuleList( [ 32 | nn.Sequential( 33 | Hourglass(4, inp_dim, bn, increase), 34 | ) for i in range(nstack)] ) 35 | 36 | self.features = nn.ModuleList( [ 37 | nn.Sequential( 38 | Residual(inp_dim, inp_dim), 39 | Conv(inp_dim, inp_dim, 1, bn=True, relu=True) 40 | ) for i in range(nstack)] ) 41 | 42 | self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] ) 43 | self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] ) 44 | self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] ) 45 | self.nstack = nstack 46 | self.heatmapLoss = HeatmapLoss() 47 | 48 | def forward(self, imgs): 49 | ## our posenet 50 | x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim 51 | x = self.pre(x) 52 | combined_hm_preds = [] 53 | for i in range(self.nstack): 54 | hg = self.hgs[i](x) 55 | feature = self.features[i](hg) 56 | preds = self.outs[i](feature) 57 | combined_hm_preds.append(preds) 58 | if i < self.nstack - 1: 59 | x = x + self.merge_preds[i](preds) + 
self.merge_features[i](feature)
60 |         return torch.stack(combined_hm_preds, 1)
61 | 
62 |     def calc_loss(self, combined_hm_preds, heatmaps):
63 |         combined_loss = []
64 |         for i in range(self.nstack):
65 |             combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))
66 |         combined_loss = torch.stack(combined_loss, dim=1)
67 |         return combined_loss
68 | 
--------------------------------------------------------------------------------
/task/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class HeatmapLoss(torch.nn.Module):
4 |     """
5 |     loss for detection heatmap
6 |     """
7 |     def __init__(self):
8 |         super(HeatmapLoss, self).__init__()
9 | 
10 |     def forward(self, pred, gt):
11 |         l = ((pred - gt)**2)
12 |         l = l.mean(dim=3).mean(dim=2).mean(dim=1)
13 |         return l ## l of dim bsize
--------------------------------------------------------------------------------
/task/pose.py:
--------------------------------------------------------------------------------
1 | """
2 | __config__ contains the options for training and testing
3 | Basically all of the variables related to training are put in __config__['train']
4 | """
5 | import torch
6 | import numpy as np
7 | from torch import nn
8 | import os
9 | from torch.nn import DataParallel
10 | from utils.misc import make_input, make_output, importNet
11 | 
12 | __config__ = {
13 |     'data_provider': 'data.MPII.dp',
14 |     'network': 'models.posenet.PoseNet',
15 |     'inference': {
16 |         'nstack': 8,
17 |         'inp_dim': 256,
18 |         'oup_dim': 16,
19 |         'num_parts': 16,
20 |         'increase': 0,
21 |         'keys': ['imgs'],
22 |         'num_eval': 2958, ## number of val examples used. entire set is 2958
23 |         'train_num_eval': 300, ## number of train examples tested at test time
24 |     },
25 | 
26 |     'train': {
27 |         'batchsize': 16,
28 |         'input_res': 256,
29 |         'output_res': 64,
30 |         'train_iters': 1000,
31 |         'valid_iters': 10,
32 |         'learning_rate': 1e-3,
33 |         'max_num_people' : 1,
34 |         'loss': [
35 |             ['combined_hm_loss', 1],
36 |         ],
37 |         'decay_iters': 100000,
38 |         'decay_lr': 2e-4,
39 |         'num_workers': 2,
40 |         'use_data_loader': True,
41 |     },
42 | }
43 | 
44 | class Trainer(nn.Module):
45 |     """
46 |     The wrapper module that will behave differently for training or testing
47 |     inference_keys specifies the inputs for inference
48 |     """
49 |     def __init__(self, model, inference_keys, calc_loss=None):
50 |         super(Trainer, self).__init__()
51 |         self.model = model
52 |         self.keys = inference_keys
53 |         self.calc_loss = calc_loss
54 | 
55 |     def forward(self, imgs, **inputs):
56 |         inps = {}
57 |         labels = {}
58 | 
59 |         for i in inputs:
60 |             if i in self.keys:
61 |                 inps[i] = inputs[i]
62 |             else:
63 |                 labels[i] = inputs[i]
64 | 
65 |         if not self.training:
66 |             return self.model(imgs, **inps)
67 |         else:
68 |             combined_hm_preds = self.model(imgs, **inps)
69 |             if type(combined_hm_preds)!=list and type(combined_hm_preds)!=tuple:
70 |                 combined_hm_preds = [combined_hm_preds]
71 |             loss = self.calc_loss(**labels, combined_hm_preds=combined_hm_preds)
72 |             return list(combined_hm_preds) + list([loss])
73 | 
74 | def make_network(configs):
75 |     train_cfg = configs['train']
76 |     config = configs['inference']
77 | 
78 |     def calc_loss(*args, **kwargs):
79 |         return poseNet.calc_loss(*args, **kwargs)
80 | 
81 |     ## creating new posenet
82 |     PoseNet = importNet(configs['network'])
83 |     poseNet = PoseNet(**config)
84 |     forward_net = DataParallel(poseNet.cuda())
85 |     config['net'] = Trainer(forward_net, configs['inference']['keys'], calc_loss)
86 | 
87 |     ## optimizer, 
experiment setup 88 | train_cfg['optimizer'] = torch.optim.Adam(filter(lambda p: p.requires_grad,config['net'].parameters()), train_cfg['learning_rate']) 89 | 90 | exp_path = os.path.join('exp', configs['opt'].exp) 91 | if configs['opt'].exp=='pose' and configs['opt'].continue_exp is not None: 92 | exp_path = os.path.join('exp', configs['opt'].continue_exp) 93 | if not os.path.exists(exp_path): 94 | os.mkdir(exp_path) 95 | logger = open(os.path.join(exp_path, 'log'), 'a+') 96 | 97 | def make_train(batch_id, config, phase, **inputs): 98 | for i in inputs: 99 | try: 100 | inputs[i] = make_input(inputs[i]) 101 | except: 102 | pass #for last input, which is a string (id_) 103 | 104 | net = config['inference']['net'] 105 | config['batch_id'] = batch_id 106 | 107 | net = net.train() 108 | 109 | # When in validation phase put batchnorm layers in eval mode 110 | # to prevent running stats from getting updated. 111 | if phase == 'valid': 112 | for module in net.modules(): 113 | if isinstance(module, nn.BatchNorm2d): 114 | module.eval() 115 | 116 | if phase != 'inference': 117 | result = net(inputs['imgs'], **{i:inputs[i] for i in inputs if i!='imgs'}) 118 | num_loss = len(config['train']['loss']) 119 | 120 | losses = {i[0]: result[-num_loss + idx]*i[1] for idx, i in enumerate(config['train']['loss'])} 121 | 122 | loss = 0 123 | toprint = '\n{}: '.format(batch_id) 124 | for i in losses: 125 | loss = loss + torch.mean(losses[i]) 126 | 127 | my_loss = make_output( losses[i] ) 128 | my_loss = my_loss.mean() 129 | 130 | if my_loss.size == 1: 131 | toprint += ' {}: {}'.format(i, format(my_loss.mean(), '.8f')) 132 | else: 133 | toprint += '\n{}'.format(i) 134 | for j in my_loss: 135 | toprint += ' {}'.format(format(j.mean(), '.8f')) 136 | logger.write(toprint) 137 | logger.flush() 138 | 139 | if phase == 'train': 140 | optimizer = train_cfg['optimizer'] 141 | optimizer.zero_grad() 142 | loss.backward() 143 | optimizer.step() 144 | 145 | if batch_id == config['train']['decay_iters']: 146 | ## decrease the learning rate after decay # iterations 147 | for param_group in optimizer.param_groups: 148 | param_group['lr'] = config['train']['decay_lr'] 149 | 150 | return None 151 | else: 152 | out = {} 153 | net = net.eval() 154 | result = net(**inputs) 155 | if type(result)!=list and type(result)!=tuple: 156 | result = [result] 157 | out['preds'] = [make_output(i) for i in result] 158 | return out 159 | return make_train 160 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import tqdm 4 | import os 5 | import numpy as np 6 | import h5py 7 | import copy 8 | 9 | from utils.group import HeatmapParser 10 | import utils.img 11 | import data.MPII.ref as ds 12 | 13 | parser = HeatmapParser() 14 | 15 | def post_process(det, mat_, trainval, c=None, s=None, resolution=None): 16 | mat = np.linalg.pinv(np.array(mat_).tolist() + [[0,0,1]])[:2] 17 | res = det.shape[1:3] 18 | cropped_preds = parser.parse(np.float32([det]))[0] 19 | 20 | if len(cropped_preds) > 0: 21 | cropped_preds[:,:,:2] = utils.img.kpt_affine(cropped_preds[:,:,:2] * 4, mat) #size 1x16x3 22 | 23 | preds = np.copy(cropped_preds) 24 | ##for inverting predictions from input res on cropped to original image 25 | if trainval != 'cropped': 26 | for j in range(preds.shape[1]): 27 | preds[0,j,:2] = utils.img.transform(preds[0,j,:2], c, s, resolution, invert=1) 28 | return preds 29 | 30 | def inference(img, 
func, config, c, s): 31 | """ 32 | forward pass at test time 33 | calls post_process to post process results 34 | """ 35 | 36 | height, width = img.shape[0:2] 37 | center = (width/2, height/2) 38 | scale = max(height, width)/200 39 | res = (config['train']['input_res'], config['train']['input_res']) 40 | 41 | mat_ = utils.img.get_transform(center, scale, res)[:2] 42 | inp = img/255 43 | 44 | def array2dict(tmp): 45 | return { 46 | 'det': tmp[0][:,:,:16], 47 | } 48 | 49 | tmp1 = array2dict(func([inp])) 50 | tmp2 = array2dict(func([inp[:,::-1]])) 51 | 52 | tmp = {} 53 | for ii in tmp1: 54 | tmp[ii] = np.concatenate((tmp1[ii], tmp2[ii]),axis=0) 55 | 56 | det = tmp['det'][0, -1] + tmp['det'][1, -1, :, :, ::-1][ds.flipped_parts['mpii']] 57 | if det is None: 58 | return [], [] 59 | det = det/2 60 | 61 | det = np.minimum(det, 1) 62 | 63 | return post_process(det, mat_, 'valid', c, s, res) 64 | 65 | def mpii_eval(pred, gt, normalizing, num_train, bound=0.5): 66 | """ 67 | Use PCK with threshold of .5 of normalized distance (presumably head size) 68 | """ 69 | 70 | correct = {'all': {'total': 0, 'ankle': 0, 'knee': 0, 'hip': 0, 'pelvis': 0, 71 | 'thorax': 0, 'neck': 0, 'head': 0, 'wrist': 0, 'elbow': 0, 72 | 'shoulder': 0}, 73 | 'visible': {'total': 0, 'ankle': 0, 'knee': 0, 'hip': 0, 'pelvis': 0, 74 | 'thorax': 0, 'neck': 0, 'head': 0, 'wrist': 0, 'elbow': 0, 75 | 'shoulder': 0}, 76 | 'not visible': {'total': 0, 'ankle': 0, 'knee': 0, 'hip': 0, 'pelvis': 0, 77 | 'thorax': 0, 'neck': 0, 'head': 0, 'wrist': 0, 'elbow': 0, 78 | 'shoulder': 0}} 79 | count = copy.deepcopy(correct) 80 | correct_train = copy.deepcopy(correct) 81 | count_train = copy.deepcopy(correct) 82 | idx = 0 83 | for p, g, normalize in zip(pred, gt, normalizing): 84 | for j in range(g.shape[1]): 85 | vis = 'visible' 86 | if g[0,j,0] == 0: ## not in picture! 
87 | continue 88 | if g[0,j,2] == 0: 89 | vis = 'not visible' 90 | joint = 'ankle' 91 | if j==1 or j==4: 92 | joint = 'knee' 93 | elif j==2 or j==3: 94 | joint = 'hip' 95 | elif j==6: 96 | joint = 'pelvis' 97 | elif j==7: 98 | joint = 'thorax' 99 | elif j==8: 100 | joint = 'neck' 101 | elif j==9: 102 | joint = 'head' 103 | elif j==10 or j==15: 104 | joint = 'wrist' 105 | elif j==11 or j==14: 106 | joint = 'elbow' 107 | elif j==12 or j==13: 108 | joint = 'shoulder' 109 | 110 | if idx >= num_train: 111 | count['all']['total'] += 1 112 | count['all'][joint] += 1 113 | count[vis]['total'] += 1 114 | count[vis][joint] += 1 115 | else: 116 | count_train['all']['total'] += 1 117 | count_train['all'][joint] += 1 118 | count_train[vis]['total'] += 1 119 | count_train[vis][joint] += 1 120 | error = np.linalg.norm(p[0]['keypoints'][j,:2]-g[0,j,:2]) / normalize 121 | if idx >= num_train: 122 | if bound > error: 123 | correct['all']['total'] += 1 124 | correct['all'][joint] += 1 125 | correct[vis]['total'] += 1 126 | correct[vis][joint] += 1 127 | else: 128 | if bound > error: 129 | correct_train['all']['total'] += 1 130 | correct_train['all'][joint] += 1 131 | correct_train[vis]['total'] += 1 132 | correct_train[vis][joint] += 1 133 | idx += 1 134 | 135 | ## breakdown by validation set / training set 136 | for k in correct: 137 | print(k, ':') 138 | for key in correct[k]: 139 | print('Val PCK @,', bound, ',', key, ':', round(correct[k][key] / max(count[k][key],1), 3), ', count:', count[k][key]) 140 | print('Tra PCK @,', bound, ',', key, ':', round(correct_train[k][key] / max(count_train[k][key],1), 3), ', count:', count_train[k][key]) 141 | print('\n') 142 | 143 | def get_img(config, num_eval=2958, num_train=300): 144 | ''' 145 | Load validation and training images 146 | ''' 147 | input_res = config['train']['input_res'] 148 | output_res = config['train']['output_res'] 149 | val_f = h5py.File(os.path.join(ds.annot_dir, 'valid.h5'), 'r') 150 | 151 | tr = tqdm.tqdm( range(0, num_train), total = num_train ) 152 | ## training 153 | train_f = h5py.File(os.path.join(ds.annot_dir, 'train.h5') ,'r') 154 | for i in tr: 155 | path_t = '%s/%s' % (ds.img_dir, train_f['imgname'][i].decode('UTF-8')) 156 | 157 | ## img 158 | orig_img = cv2.imread(path_t)[:,:,::-1] 159 | c = train_f['center'][i] 160 | s = train_f['scale'][i] 161 | im = utils.img.crop(orig_img, c, s, (input_res, input_res)) 162 | 163 | ## kp 164 | kp = train_f['part'][i] 165 | vis = train_f['visible'][i] 166 | kp2 = np.insert(kp, 2, vis, axis=1) 167 | kps = np.zeros((1, 16, 3)) 168 | kps[0] = kp2 169 | 170 | ## normalize (to make errors more fair on high pixel imgs) 171 | n = train_f['normalize'][i] 172 | 173 | yield kps, im, c, s, n 174 | 175 | 176 | tr2 = tqdm.tqdm( range(0, num_eval), total = num_eval ) 177 | ## validation 178 | for i in tr2: 179 | path_t = '%s/%s' % (ds.img_dir, val_f['imgname'][i].decode('UTF-8')) 180 | 181 | ## img 182 | orig_img = cv2.imread(path_t)[:,:,::-1] 183 | c = val_f['center'][i] 184 | s = val_f['scale'][i] 185 | im = utils.img.crop(orig_img, c, s, (input_res, input_res)) 186 | 187 | ## kp 188 | kp = val_f['part'][i] 189 | vis = val_f['visible'][i] 190 | kp2 = np.insert(kp, 2, vis, axis=1) 191 | kps = np.zeros((1, 16, 3)) 192 | kps[0] = kp2 193 | 194 | ## normalize (to make errors more fair on high pixel imgs) 195 | n = val_f['normalize'][i] 196 | 197 | yield kps, im, c, s, n 198 | 199 | 200 | def main(): 201 | from train import init 202 | func, config = init() 203 | 204 | def runner(imgs): 205 | return func(0, config, 
'inference', imgs=torch.Tensor(np.float32(imgs)))['preds'] 206 | 207 | def do(img, c, s): 208 | ans = inference(img, runner, config, c, s) 209 | if len(ans) > 0: 210 | ans = ans[:,:,:3] 211 | 212 | ## ans has shape N,16,3 (num preds, joints, x/y/visible) 213 | pred = [] 214 | for i in range(ans.shape[0]): 215 | pred.append({'keypoints': ans[i,:,:]}) 216 | return pred 217 | 218 | gts = [] 219 | preds = [] 220 | normalizing = [] 221 | 222 | num_eval = config['inference']['num_eval'] 223 | num_train = config['inference']['train_num_eval'] 224 | for anns, img, c, s, n in get_img(config, num_eval, num_train): 225 | gts.append(anns) 226 | pred = do(img, c, s) 227 | preds.append(pred) 228 | normalizing.append(n) 229 | 230 | mpii_eval(preds, gts, normalizing, num_train) 231 | 232 | if __name__ == '__main__': 233 | main() 234 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | from os.path import dirname 4 | 5 | import torch.backends.cudnn as cudnn 6 | cudnn.benchmark = True 7 | cudnn.enabled = True 8 | 9 | import torch 10 | import importlib 11 | import argparse 12 | from datetime import datetime 13 | from pytz import timezone 14 | 15 | import shutil 16 | 17 | def parse_command_line(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-c', '--continue_exp', type=str, help='continue exp') 20 | parser.add_argument('-e', '--exp', type=str, default='pose', help='experiments name') 21 | parser.add_argument('-m', '--max_iters', type=int, default=250, help='max number of iterations (thousands)') 22 | args = parser.parse_args() 23 | return args 24 | 25 | def reload(config): 26 | """ 27 | load or initialize model's parameters by config from config['opt'].continue_exp 28 | config['train']['epoch'] records the epoch num 29 | config['inference']['net'] is the model 30 | """ 31 | opt = config['opt'] 32 | 33 | if opt.continue_exp: 34 | resume = os.path.join('exp', opt.continue_exp) 35 | resume_file = os.path.join(resume, 'checkpoint.pt') 36 | if os.path.isfile(resume_file): 37 | print("=> loading checkpoint '{}'".format(resume)) 38 | checkpoint = torch.load(resume_file) 39 | 40 | config['inference']['net'].load_state_dict(checkpoint['state_dict']) 41 | config['train']['optimizer'].load_state_dict(checkpoint['optimizer']) 42 | config['train']['epoch'] = checkpoint['epoch'] 43 | print("=> loaded checkpoint '{}' (epoch {})" 44 | .format(resume, checkpoint['epoch'])) 45 | else: 46 | print("=> no checkpoint found at '{}'".format(resume)) 47 | exit(0) 48 | 49 | if 'epoch' not in config['train']: 50 | config['train']['epoch'] = 0 51 | 52 | def save_checkpoint(state, is_best, filename='checkpoint.pt'): 53 | """ 54 | from pytorch/examples 55 | """ 56 | basename = dirname(filename) 57 | if not os.path.exists(basename): 58 | os.makedirs(basename) 59 | torch.save(state, filename) 60 | if is_best: 61 | shutil.copyfile(filename, 'model_best.pt') 62 | 63 | def save(config): 64 | resume = os.path.join('exp', config['opt'].exp) 65 | if config['opt'].exp=='pose' and config['opt'].continue_exp is not None: 66 | resume = os.path.join('exp', config['opt'].continue_exp) 67 | resume_file = os.path.join(resume, 'checkpoint.pt') 68 | 69 | save_checkpoint({ 70 | 'state_dict': config['inference']['net'].state_dict(), 71 | 'optimizer' : config['train']['optimizer'].state_dict(), 72 | 'epoch': config['train']['epoch'], 73 | }, False, filename=resume_file) 74 | print('=> 
save checkpoint') 75 | 76 | def train(train_func, data_func, config, post_epoch=None): 77 | while True: 78 | fails = 0 79 | print('epoch: ', config['train']['epoch']) 80 | if 'epoch_num' in config['train']: 81 | if config['train']['epoch'] > config['train']['epoch_num']: 82 | break 83 | 84 | for phase in ['train', 'valid']: 85 | num_step = config['train']['{}_iters'.format(phase)] 86 | generator = data_func(phase) 87 | print('start', phase, config['opt'].exp) 88 | 89 | show_range = range(num_step) 90 | show_range = tqdm.tqdm(show_range, total = num_step, ascii=True) 91 | batch_id = num_step * config['train']['epoch'] 92 | if batch_id > config['opt'].max_iters * 1000: 93 | return 94 | for i in show_range: 95 | datas = next(generator) 96 | outs = train_func(batch_id + i, config, phase, **datas) 97 | config['train']['epoch'] += 1 98 | save(config) 99 | 100 | def init(): 101 | """ 102 | task.__config__ contains the variables that control the training and testing 103 | make_network builds a function which can do forward and backward propagation 104 | """ 105 | opt = parse_command_line() 106 | task = importlib.import_module('task.pose') 107 | exp_path = os.path.join('exp', opt.exp) 108 | 109 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 110 | 111 | config = task.__config__ 112 | try: os.makedirs(exp_path) 113 | except FileExistsError: pass 114 | 115 | config['opt'] = opt 116 | config['data_provider'] = importlib.import_module(config['data_provider']) 117 | 118 | func = task.make_network(config) 119 | reload(config) 120 | return func, config 121 | 122 | def main(): 123 | func, config = init() 124 | data_func = config['data_provider'].init(config) 125 | train(func, data_func, config) 126 | print(datetime.now(timezone('EST'))) 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /utils/group.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def match_format(dic): 5 | loc = dic['loc_k'][0,:,0,:] 6 | val = dic['val_k'][0,:,:] 7 | ans = np.hstack((loc, val)) 8 | ans = np.expand_dims(ans, axis = 0) 9 | ret = [] 10 | ret.append(ans) 11 | return ret 12 | 13 | class HeatmapParser: 14 | def __init__(self): 15 | from torch import nn 16 | self.pool = nn.MaxPool2d(3, 1, 1) 17 | 18 | def nms(self, det): 19 | maxm = self.pool(det) 20 | maxm = torch.eq(maxm, det).float() 21 | det = det * maxm 22 | return det 23 | 24 | def calc(self, det): 25 | with torch.no_grad(): 26 | det = torch.autograd.Variable(torch.Tensor(det)) 27 | # This is a better format for future version pytorch 28 | 29 | det = self.nms(det) 30 | h = det.size()[2] 31 | w = det.size()[3] 32 | det = det.view(det.size()[0], det.size()[1], -1) 33 | val_k, ind = det.topk(1, dim=2) 34 | 35 | x = ind % w 36 | y = (ind / w).long() 37 | ind_k = torch.stack((x, y), dim=3) 38 | ans = {'loc_k': ind_k, 'val_k': val_k} 39 | return {key:ans[key].cpu().data.numpy() for key in ans} 40 | 41 | def adjust(self, ans, det): 42 | for batch_id, people in enumerate(ans): 43 | for people_id, i in enumerate(people): 44 | for joint_id, joint in enumerate(i): 45 | if joint[2]>0: 46 | y, x = joint[0:2] 47 | xx, yy = int(x), int(y) 48 | tmp = det[0][joint_id] 49 | if tmp[xx, min(yy+1, tmp.shape[1]-1)]>tmp[xx, max(yy-1, 0)]: 50 | y+=0.25 51 | else: 52 | y-=0.25 53 | 54 | if tmp[min(xx+1, tmp.shape[0]-1), yy]>tmp[max(0, xx-1), yy]: 55 | x+=0.25 56 | else: 57 | x-=0.25 58 | ans[0][0, joint_id, 0:2] = 
(y+0.5, x+0.5) 59 | return ans 60 | 61 | def parse(self, det, adjust=True): 62 | ans = match_format(self.calc(det)) 63 | if adjust: 64 | ans = self.adjust(ans, det) 65 | return ans 66 | -------------------------------------------------------------------------------- /utils/img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import cv2 4 | 5 | # ============================================================================= 6 | # General image processing functions 7 | # ============================================================================= 8 | 9 | def get_transform(center, scale, res, rot=0): 10 | # Generate transformation matrix 11 | h = 200 * scale 12 | t = np.zeros((3, 3)) 13 | t[0, 0] = float(res[1]) / h 14 | t[1, 1] = float(res[0]) / h 15 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 16 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 17 | t[2, 2] = 1 18 | if not rot == 0: 19 | rot = -rot # To match direction of rotation from cropping 20 | rot_mat = np.zeros((3,3)) 21 | rot_rad = rot * np.pi / 180 22 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 23 | rot_mat[0,:2] = [cs, -sn] 24 | rot_mat[1,:2] = [sn, cs] 25 | rot_mat[2,2] = 1 26 | # Need to rotate around center 27 | t_mat = np.eye(3) 28 | t_mat[0,2] = -res[1]/2 29 | t_mat[1,2] = -res[0]/2 30 | t_inv = t_mat.copy() 31 | t_inv[:2,2] *= -1 32 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) 33 | return t 34 | 35 | def transform(pt, center, scale, res, invert=0, rot=0): 36 | # Transform pixel location to different reference 37 | t = get_transform(center, scale, res, rot=rot) 38 | if invert: 39 | t = np.linalg.inv(t) 40 | new_pt = np.array([pt[0], pt[1], 1.]).T 41 | new_pt = np.dot(t, new_pt) 42 | return new_pt[:2].astype(int) 43 | 44 | def crop(img, center, scale, res, rot=0): 45 | # Upper left point 46 | ul = np.array(transform([0, 0], center, scale, res, invert=1)) 47 | # Bottom right point 48 | br = np.array(transform(res, center, scale, res, invert=1)) 49 | 50 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 51 | if len(img.shape) > 2: 52 | new_shape += [img.shape[2]] 53 | new_img = np.zeros(new_shape) 54 | 55 | # Range to fill new array 56 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 57 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 58 | # Range to sample from original image 59 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 60 | old_y = max(0, ul[1]), min(len(img), br[1]) 61 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] 62 | 63 | return cv2.resize(new_img, res) 64 | 65 | def inv_mat(mat): 66 | ans = np.linalg.pinv(np.array(mat).tolist() + [[0,0,1]]) 67 | return ans[:2] 68 | 69 | def kpt_affine(kpt, mat): 70 | kpt = np.array(kpt) 71 | shape = kpt.shape 72 | kpt = kpt.reshape(-1, 2) 73 | return np.dot( np.concatenate((kpt, kpt[:, 0:1]*0+1), axis = 1), mat.T ).reshape(shape) 74 | 75 | def resize(im, res): 76 | return np.array([cv2.resize(im[i],res) for i in range(im.shape[0])]) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import importlib 4 | 5 | # Helpers when setting up training 6 | 7 | def importNet(net): 8 | t = net.split('.') 9 | path, name = '.'.join(t[:-1]), t[-1] 10 | module = importlib.import_module(path) 11 | return eval('module.{}'.format(name)) 12 | 13 | def make_input(t, requires_grad=False, 
need_cuda = True): 14 | inp = torch.autograd.Variable(t, requires_grad=requires_grad) 15 | if need_cuda: 16 | inp = inp.cuda() 17 | return inp 18 | 19 | def make_output(x): 20 | if not (type(x) is list): 21 | return x.cpu().data.numpy() 22 | else: 23 | return [make_output(i) for i in x] 24 | --------------------------------------------------------------------------------