├── src ├── pypose │ ├── __init__.py │ ├── ref.py │ ├── draw.py │ ├── eval.py │ ├── img.py │ ├── data.py │ └── report.py ├── misc │ ├── predictMPII.m │ ├── verify_gradient.lua │ ├── visualize_results.lua │ ├── analyze_occlusion_pr.py │ ├── generate_exs.lua │ ├── monitor_experiments.py │ ├── examples.py │ └── pck_figs.py ├── models │ ├── layers │ │ ├── Residual.lua │ │ └── MRF.lua │ ├── hg-generic.lua │ ├── hg-stacked-no-int.lua │ ├── hg-stacked.lua │ ├── hg.lua │ ├── hg-stacked-2.lua │ ├── hg-stacked-3.lua │ └── hg-stacked-4.lua ├── main.lua ├── model.lua ├── util │ ├── pose-int.lua │ ├── pose-vol.lua │ ├── pose.lua │ ├── pose-c2f.lua │ ├── eval.lua │ ├── Logger.lua │ └── img.lua ├── dataloader.lua ├── opts.lua ├── train.lua ├── data.lua └── ref.lua ├── data.sh ├── LICENSE └── README.md /src/pypose/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['img', 'draw', 'data', 'ref', 'report', 'eval'] 2 | 3 | import img, draw, data, ref, report, eval 4 | -------------------------------------------------------------------------------- /data.sh: -------------------------------------------------------------------------------- 1 | # Download H36M annotations 2 | mkdir data 3 | cd data 4 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/h36m_annot.tar 5 | tar -xf h36m_annot.tar 6 | rm h36m_annot.tar 7 | 8 | # Download H36M images 9 | mkdir -p h36m/images 10 | cd h36m/images 11 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S1.tar 12 | tar -xf S1.tar 13 | rm S1.tar 14 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S5.tar 15 | tar -xf S5.tar 16 | rm S5.tar 17 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S6.tar 18 | tar -xf S6.tar 19 | rm S6.tar 20 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S7.tar 21 | tar -xf S7.tar 22 | rm S7.tar 23 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S8.tar 24 | tar -xf S8.tar 25 | rm S8.tar 26 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S9.tar 27 | tar -xf S9.tar 28 | rm S9.tar 29 | wget http://visiondata.cis.upenn.edu/volumetric/h36m/S11.tar 30 | tar -xf S11.tar 31 | rm S11.tar 32 | cd ../../.. -------------------------------------------------------------------------------- /src/misc/predictMPII.m: -------------------------------------------------------------------------------- 1 | 2 | load('../../data/mpii/annot.mat'); 3 | imgname = h5read('../../data/mpii/annot/test.h5', '/imgname'); 4 | index = h5read('../../data/mpii/annot/test.h5', '/index') + 1; 5 | person = h5read('../../data/mpii/annot/test.h5', '/person') + 1; 6 | scale = h5read('../../data/mpii/annot/test.h5', '/scale'); 7 | preds = h5read('../../exp/mpii/best/preds_full.h5', '/preds_tf'); 8 | nJoints = size(preds, 2); 9 | 10 | for i = 1:size(preds, 3) 11 | idx = index(i); 12 | assert(strcmp(RELEASE.annolist(idx).image.name, imgname{i}) == 1); 13 | assert(RELEASE.img_train(idx) == 0); 14 | 15 | x = cell(1, nJoints); 16 | y = cell(1, nJoints); 17 | id = cell(1, nJoints); 18 | for j = 1:nJoints 19 | x{j} = preds(1, j, i); 20 | y{j} = preds(2, j, i); 21 | id{j} = j - 1; 22 | end 23 | RELEASE.annolist(idx).annorect(person(i)).annopoints = struct('point', struct('x', x, 'y', y,'id', id)); 24 | end 25 | 26 | save('mpii-prediction-01.mat', 'RELEASE'); 27 | -------------------------------------------------------------------------------- /src/models/layers/Residual.lua: -------------------------------------------------------------------------------- 1 | local conv = nnlib.SpatialConvolution 2 | local batchnorm = nn.SpatialBatchNormalization 3 | local relu = nnlib.ReLU 4 | 5 | -- Main convolutional block 6 | local function convBlock(numIn,numOut) 7 | return nn.Sequential() 8 | :add(batchnorm(numIn)) 9 | :add(relu(true)) 10 | :add(conv(numIn,numOut/2,1,1)) 11 | :add(batchnorm(numOut/2)) 12 | :add(relu(true)) 13 | :add(conv(numOut/2,numOut/2,3,3,1,1,1,1)) 14 | :add(batchnorm(numOut/2)) 15 | :add(relu(true)) 16 | :add(conv(numOut/2,numOut,1,1)) 17 | end 18 | 19 | -- Skip layer 20 | local function skipLayer(numIn,numOut) 21 | if numIn == numOut then 22 | return nn.Identity() 23 | else 24 | return nn.Sequential() 25 | :add(conv(numIn,numOut,1,1)) 26 | end 27 | end 28 | 29 | -- Residual block 30 | function Residual(numIn,numOut) 31 | return nn.Sequential() 32 | :add(nn.ConcatTable() 33 | :add(convBlock(numIn,numOut)) 34 | :add(skipLayer(numIn,numOut))) 35 | :add(nn.CAddTable(true)) 36 | end 37 | 38 | -------------------------------------------------------------------------------- /src/main.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | paths.dofile('ref.lua') -- Parse command line input and do global variable initialization 3 | paths.dofile('data.lua') -- Set up data processing 4 | paths.dofile('model.lua') -- Read in network model 5 | paths.dofile('train.lua') -- Load up training/testing functions 6 | 7 | torch.setnumthreads(1) 8 | local Dataloader = require 'dataloader' 9 | loader = Dataloader.create(opt) 10 | 11 | isFinished = false -- Finish early if validation accuracy plateaus, can be adjusted with opt.threshold 12 | 13 | -- Main training loop 14 | for i=1,opt.nEpochs do 15 | train() 16 | valid() 17 | collectgarbage() 18 | epoch = epoch + 1 19 | if isFinished then break end 20 | end 21 | 22 | -- Update options/reference for last epoch 23 | opt.lastEpoch = epoch - 1 24 | torch.save(opt.save .. '/options.t7', opt) 25 | 26 | -- Generate final predictions on validation set 27 | if opt.finalPredictions == 1 then predict() end 28 | 29 | -- Save model 30 | model:clearState() 31 | torch.save(paths.concat(opt.save,'final_model.t7'), model) 32 | torch.save(paths.concat(opt.save,'optimState.t7'), optimState) 33 | -------------------------------------------------------------------------------- /src/model.lua: -------------------------------------------------------------------------------- 1 | --- Load up network model or initialize from scratch 2 | paths.dofile('models/' .. opt.netType .. '.lua') 3 | 4 | -- Continuing an experiment where it left off 5 | if opt.continue or opt.branch ~= 'none' then 6 | local prevModel = opt.load .. '/model_' .. opt.lastEpoch .. '.t7' 7 | print('==> Loading model from: ' .. prevModel) 8 | model = torch.load(prevModel) 9 | 10 | -- Or a path to previously trained model is provided 11 | elseif opt.loadModel ~= 'none' then 12 | assert(paths.filep(opt.loadModel), 'File not found: ' .. opt.loadModel) 13 | print('==> Loading model from: ' .. opt.loadModel) 14 | model = torch.load(opt.loadModel) 15 | 16 | -- Or we're starting fresh 17 | else 18 | print('==> Creating model from file: models/' .. opt.netType .. '.lua') 19 | model = createModel(modelArgs) 20 | end 21 | 22 | -- Criterion (can be set in the opt.task file as well) 23 | if not criterion then 24 | criterion = nn[opt.crit .. 'Criterion']() 25 | end 26 | 27 | if opt.GPU ~= -1 then 28 | -- Convert model to CUDA 29 | print('==> Converting model to CUDA') 30 | model:cuda() 31 | criterion:cuda() 32 | 33 | cudnn.fastest = true 34 | cudnn.benchmark = true 35 | end 36 | -------------------------------------------------------------------------------- /src/util/pose-int.lua: -------------------------------------------------------------------------------- 1 | -- Get prediction coordinates 2 | predDim = {nParts,2} 3 | 4 | criterion = nn.ParallelCriterion() 5 | for i = 1,opt.nStack do criterion:add(nn.MSECriterion()) end 6 | 7 | -- Code to generate training samples from raw images. 8 | function generateSample(set, idx) 9 | local pts = annot[set]['part'][idx] 10 | local c = annot[set]['center'][idx] 11 | local s = annot[set]['scale'][idx] 12 | local img = image.load(opt.dataDir .. '/images/' .. annot[set]['images'][idx]) 13 | 14 | -- For single-person pose estimation with a centered/scaled figure 15 | local inp = crop(img, c, s, 0, opt.inputRes) 16 | local out = torch.zeros(nParts, opt.outputRes, opt.outputRes) 17 | for i = 1,nParts do 18 | if pts[i][1] > 0 then -- Checks that there is a ground truth annotation 19 | drawGaussian(out[i], transform(torch.add(pts[i],1), c, s, 0, opt.outputRes), 1) 20 | end 21 | end 22 | 23 | return inp,out 24 | end 25 | 26 | function preprocess(input, label) 27 | newLabel = {} 28 | for i = 1,opt.nStack do newLabel[i] = label end 29 | return input, newLabel 30 | end 31 | 32 | function postprocess(set, idx, output) 33 | local preds = getPreds(output[#output]) 34 | return preds 35 | end 36 | 37 | function accuracy(output,label) 38 | local jntIdxs = {mpii={1,2,3,4,5,6,11,12,15,16},flic={2,3,5,6,7,8}} 39 | return heatmapAccuracy(output[#output],label[#output],nil,jntIdxs[opt.dataset]) 40 | end 41 | -------------------------------------------------------------------------------- /src/misc/verify_gradient.lua: -------------------------------------------------------------------------------- 1 | -- This hasn't been used in a while, but does a gradient check for custom layer modules 2 | 3 | paths.dofile('../net/setup.lua') -- Initialize options and load data 4 | paths.dofile('../net/model.lua') -- Read in network model 5 | 6 | local idx = 1 7 | local batch_size = 2 8 | local inp_dims = {{idx,idx+batch_size-1}} 9 | for j = 1,dataDim:size()[1] do inp_dims[j+1] = {1,dataDim[j]} end 10 | local label_dims = {{idx,idx+batch_size-1}} 11 | for j = 1,labelDim:size()[1] do label_dims[j+1] = {1,labelDim[j]} end 12 | local inputs = trainFile:read('data'):partial(unpack(inp_dims)) 13 | local labels = trainFile:read('label'):partial(unpack(label_dims)) 14 | in_size = inputs:size() 15 | 16 | local function check_net(input) 17 | input = input:view(in_size) 18 | if opt.GPU ~= -1 then 19 | input = torch.CudaTensor(input:size()):copy(input) 20 | labels = torch.CudaTensor(labels:size()):copy(labels) 21 | end 22 | if preprocess then input,labels = preprocess(input,labels) end 23 | local output = model:forward(input) 24 | local loss = criterion:forward(output,labels) 25 | local d_loss = criterion:backward(output,labels) 26 | local grad = model:backward(input, d_loss) 27 | if unprocess then grad = unprocess(grad, batch_size) end 28 | grad = grad:view(grad:numel()) 29 | return loss, grad 30 | end 31 | 32 | timer = torch.Timer() 33 | diff,a,b = optim.checkgrad(check_net, inputs:view(inputs:numel()), 1e-7) 34 | print(a:cat(b,2)) 35 | print(diff) 36 | print("Time: " .. timer:time().real) 37 | 38 | trainFile:close() 39 | -------------------------------------------------------------------------------- /src/util/pose-vol.lua: -------------------------------------------------------------------------------- 1 | -- Get prediction coordinates 2 | predDim = {nParts,3} 3 | 4 | -- Code to generate training samples from raw images. 5 | function generateSample(set, idx) 6 | local pts = annot[set]['part'][idx] 7 | local c = annot[set]['center'][idx] 8 | local s = annot[set]['scale'][idx] 9 | local z = annot[set]['zind'][idx] 10 | local img = image.load(opt.dataDir .. '/images/' .. annot[set]['images'][idx]) 11 | 12 | -- For single-person pose estimation with a centered/scaled figure 13 | local inp = crop(img, c, s, 0, opt.inputRes) 14 | local out = {} 15 | 16 | local sigma_2d = 2 17 | 18 | local size_z = 2*torch.floor((6*sigma_2d*opt.resZ[1]/opt.outputRes+1)/2)+1 19 | local outTemp = torch.zeros(nParts*opt.resZ[1], opt.outputRes, opt.outputRes) 20 | for j = 1,nParts do 21 | if pts[j][1] > 0 then -- Checks that there is a ground truth annotation 22 | drawGaussian3D(outTemp:sub((j-1)*opt.resZ[1]+1,j*opt.resZ[1]), transform(torch.add(pts[j],1), c, s, 0, opt.outputRes), torch.ceil(z[j]*opt.resZ[1]/opt.outputRes), sigma_2d, size_z) 23 | end 24 | end 25 | table.insert(out,outTemp) 26 | 27 | return inp,out 28 | end 29 | 30 | function preprocess(input, label) 31 | return input, label[1] 32 | end 33 | 34 | function postprocess(set, idx, output) 35 | local preds = getPreds3D(output) 36 | return preds 37 | end 38 | 39 | function accuracy(output,label) 40 | local jntIdxs = {mpii={1,2,3,4,5,6,11,12,15,16},flic={2,3,5,6,7,8},h36m={1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}} 41 | return heatmapAccuracy(output,label,nil,jntIdxs[opt.dataset]) 42 | end 43 | -------------------------------------------------------------------------------- /src/util/pose.lua: -------------------------------------------------------------------------------- 1 | -- Get prediction coordinates 2 | predDim = {nParts,2} 3 | 4 | -- Code to generate training samples from raw images. 5 | function generateSample(set, idx) 6 | local pts = annot[set]['part'][idx] 7 | local c = annot[set]['center'][idx] 8 | local s = annot[set]['scale'][idx] 9 | local img = image.load(opt.dataDir .. '/images/' .. annot[set]['images'][idx]) 10 | 11 | -- For single-person pose estimation with a centered/scaled figure 12 | local inp = crop(img, c, s, 0, opt.inputRes) 13 | local out = torch.zeros(nParts, opt.outputRes, opt.outputRes) 14 | for i = 1,nParts do 15 | if pts[i][1] > 0 then -- Checks that there is a ground truth annotation 16 | drawGaussian(out[i], transform(torch.add(pts[i],1), c, s, 0, opt.outputRes), 2) 17 | end 18 | end 19 | 20 | return inp,out 21 | end 22 | 23 | -- function preprocess(input, label) 24 | -- require 'image' 25 | -- w = image.display{image=input,win=w} 26 | -- w2 = image.display{image=label:view(label:size(1)*label:size(2),label:size(3),label:size(4)),win=w2} 27 | -- return input, label 28 | -- end 29 | 30 | function postprocess(set, idx, output) 31 | -- Return predictions in the heatmap coordinate space 32 | -- The evaluation code will apply the transformation back to the original image space 33 | -- (Though we could also do it here) 34 | local preds = getPreds(output) 35 | return preds 36 | end 37 | 38 | function accuracy(output,label) 39 | -- Only care about accuracy across the most difficult joints 40 | local jntIdxs = {mpii={1,2,3,4,5,6,11,12,15,16},flic={2,3,5,6,7,8}} 41 | return heatmapAccuracy(output,label,nil,jntIdxs[opt.dataset]) 42 | end 43 | -------------------------------------------------------------------------------- /src/util/pose-c2f.lua: -------------------------------------------------------------------------------- 1 | -- Get prediction coordinates 2 | predDim = {nParts,3} 3 | 4 | criterion = nn.ParallelCriterion() 5 | for i = 1,opt.nStack do criterion:add(nn.MSECriterion()) end 6 | 7 | -- Code to generate training samples from raw images. 8 | function generateSample(set, idx) 9 | local pts = annot[set]['part'][idx] 10 | local c = annot[set]['center'][idx] 11 | local s = annot[set]['scale'][idx] 12 | local z = annot[set]['zind'][idx] 13 | local img = image.load(opt.dataDir .. '/images/' .. annot[set]['images'][idx]) 14 | 15 | -- For single-person pose estimation with a centered/scaled figure 16 | local inp = crop(img, c, s, 0, opt.inputRes) 17 | local out = {} 18 | 19 | local sigma_2d = 2 20 | 21 | for i = 1,opt.nStack do 22 | local size_z = 2*torch.floor((6*sigma_2d*opt.resZ[i]/opt.outputRes+1)/2)+1 23 | local outTemp = torch.zeros(nParts*opt.resZ[i], opt.outputRes, opt.outputRes) 24 | for j = 1,nParts do 25 | if pts[j][1] > 0 then -- Checks that there is a ground truth annotation 26 | drawGaussian3D(outTemp:sub((j-1)*opt.resZ[i]+1,j*opt.resZ[i]), transform(torch.add(pts[j],1), c, s, 0, opt.outputRes), torch.ceil(z[j]*opt.resZ[i]/opt.outputRes), sigma_2d, size_z) 27 | end 28 | end 29 | out[i] = outTemp 30 | end 31 | 32 | return inp,out 33 | end 34 | 35 | function preprocess(input, label) 36 | return input, label 37 | end 38 | 39 | function postprocess(set, idx, output) 40 | local preds = getPreds3D(output[#output]) 41 | return preds 42 | end 43 | 44 | function accuracy(output,label) 45 | local jntIdxs = {mpii={1,2,3,4,5,6,11,12,15,16},flic={2,3,5,6,7,8},h36m={1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}} 46 | return heatmapAccuracy(output[#output],label[#label],nil,jntIdxs[opt.dataset]) 47 | end 48 | -------------------------------------------------------------------------------- /src/pypose/ref.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.misc import imread 3 | import h5py 4 | 5 | # Home of the posenet directory, (change if not in your home directory) 6 | posedir = os.environ["HOME"] + '/posenet' 7 | # Global options 8 | in_res = [256, 256] 9 | out_res = [64, 64] 10 | 11 | # Load annotations 12 | # Example call: ref.load('mpii','train') 13 | def load(dataset, settype): 14 | return h5py.File('%s/data/%s/annot/%s.h5' % (posedir, dataset, settype), 'r+') 15 | 16 | # Part reference 17 | parts = {'flic':['lsho', 'lelb', 'lwri', 18 | 'rsho', 'relb', 'rwri', 19 | 'lhip', 'rhip', 20 | 'leye', 'reye', 'nose'], 21 | 'mpii':['rank', 'rkne', 'rhip', 22 | 'lhip', 'lkne', 'lank', 23 | 'pelv', 'thrx', 'neck', 'head', 24 | 'rwri', 'relb', 'rsho', 25 | 'lsho', 'lelb', 'lwri']} 26 | 27 | flipped_parts = {'flic':[3, 4, 5, 0, 1, 2, 7, 6, 9, 8, 10], 28 | 'mpii':[5, 4, 3, 2, 1, 0, 6, 7, 8, 9, 15, 14, 13, 12, 11, 10]} 29 | 30 | part_pairs = {'flic':[[0, 3], [1, 4], [2, 5], [6, 7], [8, 9, 10]], 31 | 'mpii':[[0, 5], [1, 4], [2, 3], [6], [7], [8], [9], [10, 15], [11, 14], [12, 13]]} 32 | 33 | pair_names = {'flic':['shoulder', 'elbow', 'wrist', 'hip', 'face'], 34 | 'mpii':['ankle', 'knee', 'hip', 'pelvis', 'thorax', 'neck', 'head', 'wrist', 'elbow', 'shoulder']} 35 | 36 | def partinfo(annot, idx, part): 37 | # This function can take either the part name or the index of the part 38 | if type(part) is str: 39 | part = parts[annot.attrs['name']].index(part) 40 | return annot['part'][idx, part] 41 | 42 | # Load in an image 43 | def loadimg(annot, idx): 44 | imgpath = '%s/data/%s/images/%s' % (posedir, annot.attrs['name'], annot['imgname'][idx]) 45 | return imread(imgpath) 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Large portions of this code were built off: 2 | https://github.com/soumith/imagenet-multiGPU.torch 3 | Copyright (c) 2016, Soumith Chintala 4 | and 5 | https://github.com/anewell/pose-hg-train 6 | Copyright (c) 2016, University of Michigan 7 | 8 | For the rest of the code: 9 | Copyright (c) 2017, University of Pennsylvania 10 | All rights reserved. 11 | 12 | Redistribution and use in source and binary forms, with or without 13 | modification, are permitted provided that the following conditions are met: 14 | 15 | * Redistributions of source code must retain the above copyright notice, this 16 | list of conditions and the following disclaimer. 17 | 18 | * Redistributions in binary form must reproduce the above copyright notice, 19 | this list of conditions and the following disclaimer in the documentation 20 | and/or other materials provided with the distribution. 21 | 22 | * Neither the name of pose-hg-train nor the names of its 23 | contributors may be used to endorse or promote products derived from 24 | this software without specific prior written permission. 25 | 26 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 27 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 28 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 29 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 30 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 31 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 32 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 33 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 34 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 35 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 36 | -------------------------------------------------------------------------------- /src/models/hg-generic.lua: -------------------------------------------------------------------------------- 1 | paths.dofile('layers/Residual.lua') 2 | 3 | local function hourglass(n, f, inp) 4 | -- Upper branch 5 | local up1 = Residual(f,f)(inp) 6 | 7 | -- Lower branch 8 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 9 | local low1 = Residual(f,f)(pool) 10 | local low2 11 | 12 | if n > 1 then low2 = hourglass(n-1,f,low1) 13 | else low2 = Residual(f,f)(low1) end 14 | 15 | local low3 = Residual(f,f)(low2) 16 | local up2 = nn.SpatialUpSamplingNearest(2)(low3) 17 | 18 | -- Bring two branches together 19 | return nn.CAddTable()({up1,up2}) 20 | end 21 | 22 | local function lin(numIn,numOut,inp) 23 | -- Apply 1x1 convolution, stride 1, no padding 24 | local l = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 25 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l)) 26 | end 27 | 28 | function createModel() 29 | 30 | local inp = nn.Identity()() 31 | 32 | -- Initial processing of the image 33 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 34 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 35 | local r1 = Residual(64,128)(cnv1) 36 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 37 | local r4 = Residual(128,128)(pool) 38 | local r5 = Residual(128,opt.nFeats)(r4) 39 | 40 | local out = {} 41 | local inter = r5 42 | 43 | for i = 1,opt.nStack do 44 | local hg = hourglass(4,opt.nFeats,inter) 45 | 46 | -- Linear layer to produce first set of predictions 47 | local ll = lin(opt.nFeats,opt.nFeats,hg) 48 | 49 | -- Predicted heatmaps 50 | local tmpOut = nnlib.SpatialConvolution(opt.nFeats,outputDim[1],1,1,1,1,0,0)(ll) 51 | table.insert(out,tmpOut) 52 | 53 | if i < opt.nStack then inter = nn.CAddTable()({inter, hg}) end 54 | end 55 | 56 | -- Final model 57 | local model = nn.gModule({inp}, out) 58 | 59 | return model 60 | 61 | end 62 | -------------------------------------------------------------------------------- /src/models/hg-stacked-no-int.lua: -------------------------------------------------------------------------------- 1 | -- Same model without intermediate supervision 2 | 3 | paths.dofile('layers/Residual.lua') 4 | 5 | local function hourglass(n, numIn, numOut, inp) 6 | -- Upper branch 7 | local up1 = Residual(numIn,256)(inp) 8 | local up2 = Residual(256,256)(up1) 9 | local up4 = Residual(256,numOut)(up2) 10 | 11 | -- Lower branch 12 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 13 | local low1 = Residual(numIn,256)(pool) 14 | local low2 = Residual(256,256)(low1) 15 | local low5 = Residual(256,256)(low2) 16 | local low6 17 | if n > 1 then 18 | low6 = hourglass(n-1,256,numOut,low5) 19 | else 20 | low6 = Residual(256,numOut)(low5) 21 | end 22 | local low7 = Residual(numOut,numOut)(low6) 23 | local up5 = nn.SpatialUpSamplingNearest(2)(low7) 24 | 25 | -- Bring two branches together 26 | return nn.CAddTable()({up4,up5}) 27 | end 28 | 29 | local function lin(numIn,numOut,inp) 30 | -- Apply 1x1 convolution, no stride, no padding 31 | local l_ = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 32 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l_)) 33 | end 34 | 35 | function createModel() 36 | 37 | local inp = nn.Identity()() 38 | 39 | -- Initial processing of the image 40 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 41 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 42 | local r1 = Residual(64,128)(cnv1) 43 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 44 | local r4 = Residual(128,128)(pool) 45 | local r5 = Residual(128,128)(r4) 46 | local r6 = Residual(128,256)(r5) 47 | 48 | -- First hourglass 49 | local hg1 = hourglass(4,256,512,r6) 50 | 51 | -- Linear layers to produce first set of predictions 52 | local l1 = lin(512,512,hg1) 53 | local l2 = lin(512,256,l1) 54 | 55 | -- Concatenate with previous linear features 56 | local cat1 = nn.JoinTable(2)({l2,pool}) 57 | 58 | -- Second hourglass 59 | local hg2 = hourglass(4,256+128,512,cat1) 60 | 61 | -- Linear layers to produce predictions again 62 | local l3 = lin(512,512,hg2) 63 | local l4 = lin(512,512,l3) 64 | 65 | -- Output heatmaps 66 | local out = nnlib.SpatialConvolution(512,outputDim[1],1,1,1,1,0,0)(l4) 67 | 68 | -- Final model 69 | local model = nn.gModule({inp}, {out}) 70 | 71 | return model 72 | 73 | end 74 | -------------------------------------------------------------------------------- /src/misc/visualize_results.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | arg = {'-GPU','-1'} 3 | paths.dofile('../ref.lua') 4 | require 'sys' 5 | 6 | pairRef = { 7 | {1,2}, {2,3}, {3,7}, 8 | {4,5}, {4,7}, {5,6}, 9 | {7,9}, {9,10}, 10 | {14,9}, {11,12}, {12,13}, 11 | {13,9}, {14,15}, {15,16} 12 | } 13 | 14 | partNames = {'RAnk','RKne','RHip','LHip','LKne','LAnk', 15 | 'Pelv','Thrx','Neck','Head', 16 | 'RWri','RElb','RSho','LSho','LElb','LWri'} 17 | 18 | function transformCoords(i,cds) 19 | local c = annot['test']['center'][i] 20 | local s = annot['test']['scale'][i] 21 | local new_cds = torch.zeros(cds:size()) 22 | for j = 1,cds:size(1) do 23 | new_cds[j] = transform(cds[j]:add(-.5),c,s,0,64,true) 24 | end 25 | return new_cds 26 | end 27 | 28 | predsfile = opt.expDir .. '/best/preds_full.h5' 29 | preds = hdf5.open(predsfile) 30 | 31 | for i = 1,10000 do 32 | hms = preds:read('pred_heatmaps'):partial({i,i},{1,16},{1,64},{1,64}) 33 | print(i) 34 | old_coords = preds:read('preds'):partial({i,i},{1,16},{1,2}) 35 | new_coord = transformCoords(i,old_coords[1]) 36 | new_coords_2 = preds:read('preds_tf'):partial({i,i},{1,16},{1,2}) 37 | im = image.load(opt.dataDir .. '/images/' .. annot['test']['images'][i]) 38 | for j = 1,#pairRef do 39 | if hms[1][pairRef[j][1]]:max() > .05 and hms[1][pairRef[j][2]]:max() > .05 then 40 | local s = annot['test']['scale'][i] 41 | if string.sub(partNames[pairRef[j][1]],1,1) == 'L' then 42 | drawLine(im[1],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,0,2) 43 | drawLine(im[2],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,0,2) 44 | drawLine(im[3],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,1,2) 45 | elseif string.sub(partNames[pairRef[j][1]],1,1) == 'R' then 46 | drawLine(im[1],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,1,2) 47 | drawLine(im[2],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,0,2) 48 | drawLine(im[3],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,0,2) 49 | else 50 | drawLine(im[1],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,.7,2) 51 | drawLine(im[2],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,0,2) 52 | drawLine(im[3],new_coord[pairRef[j][1]],new_coord[pairRef[j][2]],4*s,1,.7,2) 53 | end 54 | else 55 | print("Not drawing:",partNames[pairRef[j][1]],partNames[pairRef[j][2]]) 56 | end 57 | end 58 | w = image.display{image=im,win=w} 59 | w2 = image.display{image=hms[1],win=w2} 60 | sys.sleep(.2) 61 | end 62 | -------------------------------------------------------------------------------- /src/misc/analyze_occlusion_pr.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | from pypose import ref 4 | import h5py 5 | import numpy as np 6 | 7 | a = ref.load('mpii','valid') 8 | 9 | ''' 10 | predFile = '/home/-/posenet/exp/mpii/hg-I-2/preds.h5' 11 | f = h5py.File(predFile,'r') 12 | p = f['preds_tf'] 13 | hms = f['pred_heatmaps'] 14 | 15 | max_act = np.zeros((2958,16)) 16 | mean_act = np.zeros((2958,16)) 17 | 18 | for i in xrange(2958): 19 | print i 20 | for j in xrange(16): 21 | max_act[i][j] = hms[i][j].max() 22 | mean_act[i][j] = hms[i][j].mean() 23 | 24 | np.save('maxAct.npy',max_act) 25 | np.save('meanAct.npy',mean_act) 26 | ''' 27 | 28 | maxAct = np.load('maxAct.npy') 29 | meanAct = np.load('meanAct.npy') 30 | 31 | thr = np.arange(0.01,.9,.01) 32 | thr2 = np.arange(-0.00001,.015,.0002) 33 | 34 | # precision = tp / tp + fp 35 | # recall = tp / tp + fn 36 | partChoice = [1,4] 37 | ptIdx = a['part'][:,partChoice,0] <= 0 38 | mxp = [] 39 | mxr = [] 40 | mnp = [] 41 | mnr = [] 42 | track_tp = [] 43 | track_fp = [] 44 | track_fn = [] 45 | track_tn = [] 46 | max_acc = 0 47 | for i in xrange(thr2.size): 48 | maxIdx = maxAct[:,partChoice] <= thr[i] 49 | tp = (maxIdx * ptIdx).sum() 50 | fp = (maxIdx * -ptIdx).sum() 51 | fn = (-maxIdx * ptIdx).sum() 52 | if tp+fp == 0: 53 | mxp += [1.] 54 | else: 55 | mxp += [(float(tp)/(tp+fp))] 56 | mxr += [(float(tp)/(tp+fn))] 57 | meanIdx = meanAct[:,partChoice] < thr2[i] 58 | tp = (meanIdx * ptIdx).sum() 59 | fp = (meanIdx * -ptIdx).sum() 60 | fn = (-meanIdx * ptIdx).sum() 61 | tn = (-meanIdx * -ptIdx).sum() 62 | acc = float(tp + tn) / (tp + fp +fn + tn) 63 | # if acc > max_acc: 64 | # print thr2[i],acc 65 | # max_acc = acc 66 | if tp+fp == 0: 67 | mnp += [1.] 68 | else: 69 | mnp += [(float(tp)/(tp+fp))] 70 | mnr += [(float(tp)/(tp+fn))] 71 | track_tp += [tp] 72 | track_fp += [fp] 73 | track_fn += [fn] 74 | track_tn += [tn] 75 | 76 | f = plt.figure() 77 | ax1 = f.add_subplot(111) 78 | ax1.plot(mxr,mxp,label='Max') 79 | ax1.plot(mnr,mnp,label='Mean') 80 | ax1.set_ylim(0,1) 81 | ax1.set_xlim(0,1) 82 | ax1.set_title('Ankle') 83 | ax1.set_xlabel('Recall') 84 | ax1.set_ylabel('Precision') 85 | ax1.legend(loc='lower right') 86 | """ 87 | ax2 = f.add_subplot(212) 88 | ax2.plot(track_tp,label='tp') 89 | ax2.plot(track_fp,label='fp') 90 | ax2.plot(track_fn,label='fn') 91 | ax2.plot(track_tn,label='tn') 92 | ax2.plot(mnp,label='p') 93 | ax2.plot(mnr,label='r') 94 | ax2.legend() 95 | """ 96 | plt.show() 97 | -------------------------------------------------------------------------------- /src/pypose/draw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.draw 3 | 4 | def line(img, pt1, pt2, color, width): 5 | # Draw a line on an image 6 | # Make sure dimension of color matches number of channels in img 7 | 8 | # First get coordinates for corners of the line 9 | diff = np.array([pt1[1] - pt2[1], pt1[0] - pt2[0]], np.float) 10 | mag = np.linalg.norm(diff) 11 | if mag >= 1: 12 | diff *= width / (2 * mag) 13 | x = np.array([pt1[0] - diff[0], pt2[0] - diff[0], pt2[0] + diff[0], pt1[0] + diff[0]], int) 14 | y = np.array([pt1[1] + diff[1], pt2[1] + diff[1], pt2[1] - diff[1], pt1[1] - diff[1]], int) 15 | else: 16 | d = float(width) / 2 17 | x = np.array([pt1[0] - d, pt1[0] + d, pt1[0] + d, pt1[0] - d], int) 18 | y = np.array([pt1[1] - d, pt1[1] - d, pt1[1] + d, pt1[1] + d], int) 19 | 20 | # noinspection PyArgumentList 21 | rr, cc = skimage.draw.polygon(y, x, img.shape) 22 | img[rr, cc] = color 23 | 24 | return img 25 | 26 | def limb(img, pt1, pt2, color, width): 27 | # Specific handling of a limb, in case the annotation isn't there for one of the joints 28 | if pt1[0] > 0 and pt2[0] > 0: 29 | line(img, pt1, pt2, color, width) 30 | elif pt1[0] > 0: 31 | circle(img, pt1, color, width) 32 | elif pt2[0] > 0: 33 | circle(img, pt2, color, width) 34 | 35 | def gaussian(img, pt, sigma): 36 | # Draw a 2D gaussian 37 | 38 | # Check that any part of the gaussian is in-bounds 39 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] 40 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] 41 | if (ul[0] > img.shape[1] or ul[1] >= img.shape[0] or 42 | br[0] < 0 or br[1] < 0): 43 | # If not, just return the image as is 44 | return img 45 | 46 | # Generate gaussian 47 | size = 6 * sigma + 1 48 | x = np.arange(0, size, 1, float) 49 | y = x[:, np.newaxis] 50 | x0 = y0 = size // 2 51 | # The gaussian is not normalized, we want the center value to equal 1 52 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 53 | 54 | # Usable gaussian range 55 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] 56 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 57 | # Image range 58 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 59 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 60 | 61 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 62 | return img 63 | 64 | def circle(img, pt, color, radius): 65 | # Draw a circle 66 | # Mostly a convenient wrapper for skimage.draw.circle 67 | 68 | rr, cc = skimage.draw.circle(pt[1], pt[0], radius, img.shape) 69 | img[rr, cc] = color 70 | return img 71 | -------------------------------------------------------------------------------- /src/models/hg-stacked.lua: -------------------------------------------------------------------------------- 1 | paths.dofile('layers/Residual.lua') 2 | 3 | local function hourglass(n, numIn, numOut, inp) 4 | -- Upper branch 5 | local up1 = Residual(numIn,256)(inp) 6 | local up2 = Residual(256,256)(up1) 7 | local up4 = Residual(256,numOut)(up2) 8 | 9 | -- Lower branch 10 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 11 | local low1 = Residual(numIn,256)(pool) 12 | local low2 = Residual(256,256)(low1) 13 | local low5 = Residual(256,256)(low2) 14 | local low6 15 | if n > 1 then 16 | low6 = hourglass(n-1,256,numOut,low5) 17 | else 18 | low6 = Residual(256,numOut)(low5) 19 | end 20 | local low7 = Residual(numOut,numOut)(low6) 21 | local up5 = nn.SpatialUpSamplingNearest(2)(low7) 22 | 23 | -- Bring two branches together 24 | return nn.CAddTable()({up4,up5}) 25 | end 26 | 27 | local function lin(numIn,numOut,inp) 28 | -- Apply 1x1 convolution, no stride, no padding 29 | local l_ = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 30 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l_)) 31 | end 32 | 33 | function createModel() 34 | 35 | local inp = nn.Identity()() 36 | 37 | -- Initial processing of the image 38 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 39 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 40 | local r1 = Residual(64,128)(cnv1) 41 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 42 | local r4 = Residual(128,128)(pool) 43 | local r5 = Residual(128,128)(r4) 44 | local r6 = Residual(128,256)(r5) 45 | 46 | -- First hourglass 47 | local hg1 = hourglass(4,256,512,r6) 48 | 49 | -- Linear layers to produce first set of predictions 50 | local l1 = lin(512,512,hg1) 51 | local l2 = lin(512,256,l1) 52 | 53 | -- First predicted heatmaps 54 | local out1 = nnlib.SpatialConvolution(256,outputDim[1][1],1,1,1,1,0,0)(l2) 55 | local out1_ = nnlib.SpatialConvolution(outputDim[1][1],256+128,1,1,1,1,0,0)(out1) 56 | 57 | -- Concatenate with previous linear features 58 | local cat1 = nn.JoinTable(2)({l2,pool}) 59 | local cat1_ = nnlib.SpatialConvolution(256+128,256+128,1,1,1,1,0,0)(cat1) 60 | local int1 = nn.CAddTable()({cat1_,out1_}) 61 | 62 | -- Second hourglass 63 | local hg2 = hourglass(4,256+128,512,int1) 64 | 65 | -- Linear layers to produce predictions again 66 | local l3 = lin(512,512,hg2) 67 | local l4 = lin(512,512,l3) 68 | 69 | -- Output heatmaps 70 | local out2 = nnlib.SpatialConvolution(512,outputDim[2][1],1,1,1,1,0,0)(l4) 71 | 72 | -- Final model 73 | local model = nn.gModule({inp}, {out1,out2}) 74 | 75 | return model 76 | 77 | end 78 | -------------------------------------------------------------------------------- /src/models/hg.lua: -------------------------------------------------------------------------------- 1 | paths.dofile('layers/Residual.lua') 2 | 3 | local function hourglass(n, f, inp) 4 | -- Upper branch 5 | local up1 = inp 6 | for i = 1,opt.nModules do up1 = Residual(f,f)(up1) end 7 | 8 | -- Lower branch 9 | local low1 = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 10 | for i = 1,opt.nModules do low1 = Residual(f,f)(low1) end 11 | local low2 12 | 13 | if n > 1 then low2 = hourglass(n-1,f,low1) 14 | else 15 | low2 = low1 16 | for i = 1,opt.nModules do low2 = Residual(f,f)(low2) end 17 | end 18 | 19 | local low3 = low2 20 | for i = 1,opt.nModules do low3 = Residual(f,f)(low3) end 21 | local up2 = nn.SpatialUpSamplingNearest(2)(low3) 22 | 23 | -- Bring two branches together 24 | return nn.CAddTable()({up1,up2}) 25 | end 26 | 27 | local function lin(numIn,numOut,inp) 28 | -- Apply 1x1 convolution, stride 1, no padding 29 | local l = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 30 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l)) 31 | end 32 | 33 | function createModel() 34 | 35 | local inp = nn.Identity()() 36 | 37 | -- Initial processing of the image 38 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 39 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 40 | local r1 = Residual(64,128)(cnv1) 41 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 42 | local r4 = Residual(128,128)(pool) 43 | local r5 = Residual(128,opt.nFeats)(r4) 44 | 45 | local out = {} 46 | local inter = r5 47 | 48 | for i = 1,opt.nStack do 49 | local hg = hourglass(4,opt.nFeats,inter) 50 | 51 | -- Residual layers at output resolution 52 | local ll = hg 53 | for j = 1,opt.nModules do ll = Residual(opt.nFeats,opt.nFeats)(ll) end 54 | -- Linear layer to produce first set of predictions 55 | ll = lin(opt.nFeats,opt.nFeats,ll) 56 | 57 | -- Predicted heatmaps 58 | local tmpOut 59 | if opt.nStack > 1 then 60 | tmpOut = nnlib.SpatialConvolution(opt.nFeats,outputDim[i][1],1,1,1,1,0,0)(ll) 61 | else 62 | tmpOut = nnlib.SpatialConvolution(opt.nFeats,outputDim[1],1,1,1,1,0,0)(ll) 63 | end 64 | table.insert(out,tmpOut) 65 | 66 | -- Add predictions back 67 | if i < opt.nStack then 68 | local ll_ = nnlib.SpatialConvolution(opt.nFeats,opt.nFeats,1,1,1,1,0,0)(ll) 69 | local tmpOut_ = nnlib.SpatialConvolution(outputDim[i][1],opt.nFeats,1,1,1,1,0,0)(tmpOut) 70 | inter = nn.CAddTable()({inter, ll_, tmpOut_}) 71 | end 72 | end 73 | 74 | -- Final model 75 | local model = nn.gModule({inp}, out) 76 | 77 | return model 78 | 79 | end 80 | -------------------------------------------------------------------------------- /src/models/hg-stacked-2.lua: -------------------------------------------------------------------------------- 1 | paths.dofile('layers/Residual.lua') 2 | 3 | local function hourglass(n, numIn, numOut, inp) 4 | -- Upper branch 5 | local up1 = Residual(numIn,256)(inp) 6 | local up2 = Residual(256,256)(up1) 7 | local up4 = Residual(256,numOut)(up2) 8 | 9 | -- Lower branch 10 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 11 | local low1 = Residual(numIn,256)(pool) 12 | local low2 = Residual(256,256)(low1) 13 | local low5 = Residual(256,256)(low2) 14 | local low6 15 | if n > 1 then 16 | low6 = hourglass(n-1,256,numOut,low5) 17 | else 18 | low6 = Residual(256,numOut)(low5) 19 | end 20 | local low7 = Residual(numOut,numOut)(low6) 21 | local up5 = nn.SpatialUpSamplingNearest(2)(low7) 22 | 23 | -- Bring two branches together 24 | return nn.CAddTable()({up4,up5}) 25 | end 26 | 27 | local function lin(numIn,numOut,inp) 28 | -- Apply 1x1 convolution, no stride, no padding 29 | local l_ = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 30 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l_)) 31 | end 32 | 33 | function createModel() 34 | 35 | local inp = nn.Identity()() 36 | 37 | -- Initial processing of the image 38 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 39 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 40 | local r1 = Residual(64,128)(cnv1) 41 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 42 | local r4 = Residual(128,128)(pool) 43 | local r5 = Residual(128,128)(r4) 44 | local r6 = Residual(128,256)(r5) 45 | 46 | -- First hourglass 47 | local hg1 = hourglass(4,256,512,r6) 48 | 49 | -- Linear layers to produce first set of predictions 50 | local l1 = lin(512,512,hg1) 51 | local l2 = lin(512,256,l1) 52 | 53 | -- First predicted heatmaps 54 | local out1 = nnlib.SpatialConvolution(256,outputDim[1][1],1,1,1,1,0,0)(l2) 55 | local out1_ = nnlib.SpatialConvolution(outputDim[1][1],256+128,1,1,1,1,0,0)(out1) 56 | 57 | -- Concatenate with previous linear features 58 | local cat1 = nn.JoinTable(2)({l2,pool}) 59 | local cat1_ = nnlib.SpatialConvolution(256+128,256+128,1,1,1,1,0,0)(cat1) 60 | local int1 = nn.CAddTable()({cat1_,out1_}) 61 | 62 | -- Second hourglass 63 | local hg2 = hourglass(4,256+128,512,int1) 64 | 65 | -- Linear layers to produce predictions again 66 | local l3 = lin(512,512,hg2) 67 | local l4 = lin(512,512,l3) 68 | 69 | -- Second predicted heatmaps 70 | local out2 = nnlib.SpatialConvolution(512,outputDim[2][1],1,1,1,1,0,0)(l4) 71 | 72 | -- Final model 73 | local model = nn.gModule({inp}, {out1,out2}) 74 | 75 | return model 76 | 77 | end 78 | -------------------------------------------------------------------------------- /src/misc/generate_exs.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | arg = {'-GPU','-1'} 3 | paths.dofile('../ref.lua') 4 | require 'sys' 5 | 6 | pairRef = { 7 | {1,2}, {2,3}, {3,7}, 8 | {4,5}, {4,7}, {5,6}, 9 | {7,9}, {9,10}, 10 | {14,9}, {11,12}, {12,13}, 11 | {13,9}, {14,15}, {15,16} 12 | } 13 | 14 | partNames = {'RLAnk','RLKne','RLHip','LLHip','LLKne','LLAnk', 15 | 'Pelv','Thrx','Neck','Head', 16 | 'RUWri','RUElb','RUSho','LUSho','LUElb','LUWri'} 17 | 18 | function transformCoords(i,cds) 19 | local c = annot['test']['center'][i] 20 | local s = annot['test']['scale'][i] 21 | local new_cds = torch.zeros(cds:size()) 22 | for j = 1,cds:size(1) do 23 | new_cds[j] = transform(cds[j]:add(-.5),c,s,0,64,true) 24 | end 25 | return new_cds 26 | end 27 | 28 | predsfile = opt.expDir .. '/best/preds_full.h5' 29 | preds = hdf5.open(predsfile) 30 | 31 | function getImg(idx,res) 32 | hms = preds:read('pred_heatmaps'):partial({idx,idx},{1,16},{1,64},{1,64}) 33 | coord = preds:read('preds_tf'):partial({idx,idx},{1,16},{1,2})[1] 34 | im = image.load(opt.dataDir .. '/images/' .. annot['test']['images'][idx]) 35 | local c = annot['test']['center'][idx] 36 | local s = annot['test']['scale'][idx] 37 | for j = 1,#pairRef do 38 | if hms[1][pairRef[j][1]]:max() > .1 and hms[1][pairRef[j][2]]:max() > .1 then 39 | if string.sub(partNames[pairRef[j][1]],1,1) == 'L' then 40 | drawLine(im[1],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,0,2) 41 | drawLine(im[2],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,0,2) 42 | drawLine(im[3],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,1,2) 43 | elseif string.sub(partNames[pairRef[j][1]],1,1) == 'R' then 44 | drawLine(im[1],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,1,2) 45 | drawLine(im[2],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,0,2) 46 | drawLine(im[3],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,0,2) 47 | else 48 | drawLine(im[1],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,.7,2) 49 | drawLine(im[2],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,0,2) 50 | drawLine(im[3],coord[pairRef[j][1]],coord[pairRef[j][2]],4*s,1,.7,2) 51 | end 52 | end 53 | end 54 | im = crop(im, c, s, 0, res) 55 | return im 56 | end 57 | 58 | function compileImages(imgs, nrows, ncols, res) 59 | print(imgs[1]:size()) 60 | local totalImg = torch.zeros(3,nrows*res,ncols*res) 61 | for i = 1,#imgs do 62 | local r = torch.floor((i-1)/ncols) + 1 63 | local c = ((i - 1) % ncols) + 1 64 | print(r) 65 | print(c) 66 | totalImg:sub(1,3,(r-1)*res+1,r*res,(c-1)*res+1,c*res):copy(imgs[i]) 67 | end 68 | w = image.display{image=totalImg,win=w} 69 | return totalImg 70 | end 71 | 72 | num_imgs = 10 73 | test_idxs = torch.randperm(11000):sub(1,12) 74 | num_imgs = test_idxs:numel() 75 | ims = {} 76 | for i = 1,num_imgs do 77 | print(test_idxs[i]) 78 | ims[i] = getImg(test_idxs[i],728) 79 | end 80 | final = compileImages(ims, 6, 2, 728) 81 | image.savePNG('examples.png',final) 82 | -------------------------------------------------------------------------------- /src/dataloader.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- Copyright (c) 2016, Facebook, Inc. 3 | -- All rights reserved. 4 | -- 5 | -- This source code is licensed under the BSD-style license found in the 6 | -- LICENSE file in the root directory of this source tree. An additional grant 7 | -- of patent rights can be found in the PATENTS file in the same directory. 8 | -- 9 | -- Multi-threaded data loader 10 | -- 11 | 12 | local Threads = require 'threads' 13 | Threads.serialization('threads.sharedserialize') 14 | 15 | local M = {} 16 | local DataLoader = torch.class('resnet.DataLoader', M) 17 | 18 | function DataLoader.create(opt) 19 | -- The train and valid loader 20 | local loaders = {} 21 | 22 | for i, split in ipairs{'train', 'valid'} do 23 | loaders[split] = M.DataLoader(opt, split) 24 | end 25 | 26 | return loaders 27 | end 28 | 29 | function DataLoader:__init(opt, split) 30 | local function init() 31 | _G.opt = opt 32 | _G.split = split 33 | _G.alreadyChecked = true 34 | paths.dofile('ref.lua') 35 | paths.dofile('data.lua') 36 | end 37 | 38 | local function main(idx) 39 | torch.setnumthreads(1) 40 | if split == 'valid' then _G.isTesting = true end 41 | return opt[split .. 'Iters']*opt[split .. 'Batch'] 42 | end 43 | 44 | local threads, sizes = Threads(opt.nThreads, init, main) 45 | self.threads = threads 46 | self.__size = sizes[1][1] 47 | self.batchsize = opt[split .. 'Batch'] 48 | self.split = split 49 | end 50 | 51 | function DataLoader:size() 52 | return math.ceil(self.__size / self.batchSize) 53 | end 54 | 55 | function DataLoader:run() 56 | local threads = self.threads 57 | local size, batchsize = self.__size, self.batchsize 58 | local perm = torch.randperm(size) 59 | 60 | local idx, sample = 1, nil 61 | local function enqueue() 62 | while idx <= size and threads:acceptsjob() do 63 | local indices = perm:narrow(1, idx, math.min(batchsize, size - idx + 1)) 64 | threads:addjob( 65 | function(indices) 66 | local idx_ = nil 67 | if _G.isTesting then idx_ = idx end 68 | local inp,out = _G.loadData(_G.split, idx_, batchsize) 69 | collectgarbage() 70 | return {inp,out} 71 | end, 72 | 73 | function(_sample_) 74 | sample = _sample_ 75 | end, 76 | 77 | indices 78 | ) 79 | idx = idx + batchsize 80 | end 81 | end 82 | 83 | local n = 0 84 | local function loop() 85 | enqueue() 86 | if not threads:hasjob() then 87 | return nil 88 | end 89 | threads:dojob() 90 | if threads:haserror() then 91 | threads:synchronize() 92 | end 93 | enqueue() 94 | n = n + 1 95 | return n, sample 96 | end 97 | 98 | return loop 99 | end 100 | 101 | return M.DataLoader 102 | -------------------------------------------------------------------------------- /src/models/hg-stacked-3.lua: -------------------------------------------------------------------------------- 1 | paths.dofile('layers/Residual.lua') 2 | 3 | local function hourglass(n, numIn, numOut, inp) 4 | -- Upper branch 5 | local up1 = Residual(numIn,256)(inp) 6 | local up2 = Residual(256,256)(up1) 7 | local up4 = Residual(256,numOut)(up2) 8 | 9 | -- Lower branch 10 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 11 | local low1 = Residual(numIn,256)(pool) 12 | local low2 = Residual(256,256)(low1) 13 | local low5 = Residual(256,256)(low2) 14 | local low6 15 | if n > 1 then 16 | low6 = hourglass(n-1,256,numOut,low5) 17 | else 18 | low6 = Residual(256,numOut)(low5) 19 | end 20 | local low7 = Residual(numOut,numOut)(low6) 21 | local up5 = nn.SpatialUpSamplingNearest(2)(low7) 22 | 23 | -- Bring two branches together 24 | return nn.CAddTable()({up4,up5}) 25 | end 26 | 27 | local function lin(numIn,numOut,inp) 28 | -- Apply 1x1 convolution, no stride, no padding 29 | local l_ = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 30 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l_)) 31 | end 32 | 33 | function createModel() 34 | 35 | local inp = nn.Identity()() 36 | 37 | -- Initial processing of the image 38 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 39 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 40 | local r1 = Residual(64,128)(cnv1) 41 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 42 | local r4 = Residual(128,128)(pool) 43 | local r5 = Residual(128,128)(r4) 44 | local r6 = Residual(128,256)(r5) 45 | 46 | -- First hourglass 47 | local hg1 = hourglass(4,256,512,r6) 48 | 49 | -- Linear layers to produce first set of predictions 50 | local l1 = lin(512,512,hg1) 51 | local l2 = lin(512,256,l1) 52 | 53 | -- First predicted heatmaps 54 | local out1 = nnlib.SpatialConvolution(256,outputDim[1][1],1,1,1,1,0,0)(l2) 55 | local out1_ = nnlib.SpatialConvolution(outputDim[1][1],256+128,1,1,1,1,0,0)(out1) 56 | 57 | -- Concatenate with previous linear features 58 | local cat1 = nn.JoinTable(2)({l2,pool}) 59 | local cat1_ = nnlib.SpatialConvolution(256+128,256+128,1,1,1,1,0,0)(cat1) 60 | local int1 = nn.CAddTable()({cat1_,out1_}) 61 | 62 | -- Second hourglass 63 | local hg2 = hourglass(4,256+128,512,int1) 64 | 65 | -- Linear layers to produce predictions again 66 | local l3 = lin(512,512,hg2) 67 | local l4 = lin(512,256,l3) 68 | 69 | -- Second predicted heatmaps 70 | local out2 = nnlib.SpatialConvolution(256,outputDim[2][1],1,1,1,1,0,0)(l4) 71 | local out2_ = nnlib.SpatialConvolution(outputDim[2][1],256+256,1,1,1,1,0,0)(out2) 72 | 73 | -- Concatenate with previous linear features 74 | local cat2 = nn.JoinTable(2)({l4,l2}) 75 | local cat2_ = nnlib.SpatialConvolution(256+256,256+256,1,1,1,1,0,0)(cat2) 76 | local int2 = nn.CAddTable()({cat2_,out2_}) 77 | 78 | -- Third hourglass 79 | local hg3 = hourglass(4,256+256,512,int2) 80 | 81 | -- Linear layers to produce predictions again 82 | local l5 = lin(512,512,hg3) 83 | local l6 = lin(512,512,l5) 84 | 85 | -- Third predicted heatmaps 86 | local out3 = nnlib.SpatialConvolution(512,outputDim[3][1],1,1,1,1,0,0)(l6) 87 | 88 | -- Final model 89 | local model = nn.gModule({inp}, {out1,out2,out3}) 90 | 91 | return model 92 | 93 | end 94 | -------------------------------------------------------------------------------- /src/pypose/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ref 3 | import img 4 | 5 | # Reference for other predictions 6 | other_preds = {'nyu':{'flic':'nyu_pred', 'mpii':'nyu_pred'}} 7 | def get_path(dataset_name, file_name): 8 | return ref.posedir + '/data/' + dataset_name + '/ref/' + file_name + '.npy' 9 | 10 | # Load ground truth annotations 11 | annot = {'flic':ref.load('flic','test'), 12 | 'mpii':ref.load('mpii','valid'), 13 | 'mpii_train':ref.load('mpii','train'), 14 | 'mpii_test':ref.load('mpii','test')} 15 | 16 | def getdists(pred, dotrain=False): 17 | # Get normalized distances between predictions and ground truth 18 | 19 | # Automatically figures out dataset based on number of parts 20 | if pred.shape[1] == 11: 21 | dataset = 'flic' 22 | elif pred.shape[1] == 16: 23 | dataset = 'mpii' 24 | else: 25 | print "Error: Bad prediction file." 26 | return 0 27 | 28 | idx_ref = [] 29 | if dotrain: 30 | idx_ref = np.load(get_path(dataset,'train_idxs')) 31 | dataset += '_train' 32 | dists = np.zeros((len(idx_ref),annot[dataset]['part'].shape[1])) 33 | else: 34 | dists = np.zeros(annot[dataset]['part'].shape[:2]) 35 | 36 | # Loop through samples and parts 37 | for i in xrange(dists.shape[0]): 38 | if dotrain: 39 | idx = idx_ref[i] 40 | else: 41 | idx = i 42 | scale = annot[dataset]['normalize'][idx] 43 | for j in xrange(dists.shape[1]): 44 | if annot[dataset]['part'][i,j,0] <= 0: 45 | dists[i,j] = -1 46 | else: 47 | dists[i,j] = np.linalg.norm(annot[dataset]['part'][idx,j] - pred[i,j]) / scale 48 | return dists 49 | 50 | def getaccuracy(arr, thresh, no_zero=True, filt=None): 51 | # Returns number of elements in arr that fall below the given threshold 52 | # filt should be a binary array the same size as arr 53 | if filt is None: 54 | # If no filter has been provided create entirely true array 55 | filt = np.array([True for _ in xrange(len(arr))]) 56 | else: 57 | filt = filt.copy() 58 | 59 | if no_zero: 60 | filt *= (arr > 0) 61 | 62 | return float(sum(arr[filt] <= thresh)) / filt.sum() 63 | 64 | def pdjdata(dataset, dists, partnames=None, rng=None, filt=None): 65 | # Return data for creating a PDJ plot 66 | # Returns the average curve for the parts provided 67 | 68 | if partnames is None: 69 | partnames = ref.parts[dataset] 70 | 71 | if rng is None: 72 | # If no range is provided use the default ranges for flic and mpii 73 | if dataset == 'flic': 74 | rng = [0, .21, .01] 75 | elif dataset == 'mpii': 76 | rng = [0, .51, .01] 77 | 78 | t = np.arange(rng[0],rng[1],rng[2]) 79 | pdj = np.zeros(len(t)) 80 | 81 | if filt is None or filt.sum() > 0: 82 | for choice in partnames: 83 | part_idx = ref.parts[dataset].index(choice) 84 | for i in xrange(len(t)): 85 | pdj[i] += getaccuracy(dists[:, part_idx], t[i], filt=filt) 86 | 87 | pdj /= len(partnames) # Average across all chosen parts 88 | 89 | return pdj, t 90 | 91 | def transformpreds(dataset, preds, res, rot=False, dotrain=False): 92 | # Predictions from torch will need to go through a coordinate transformation 93 | new_preds = np.zeros(preds.shape) 94 | idx_ref = np.arange(len(new_preds)) 95 | if dotrain: 96 | idx_ref = np.load(get_path(dataset,'train_idxs')) 97 | dataset += '_train' 98 | for i in xrange(preds.shape[0]): 99 | idx = idx_ref[i] 100 | c = annot[dataset]['center'][idx] 101 | s = annot[dataset]['scale'][idx] 102 | if rot: 103 | r = annot[dataset]['torsoangle'][idx] 104 | else: 105 | r = 0 106 | for j in xrange(preds.shape[1]): 107 | new_preds[i,j] = img.transform(preds[i,j]-.5, c, s, res, invert=1, rot=r) 108 | return new_preds 109 | -------------------------------------------------------------------------------- /src/models/hg-stacked-4.lua: -------------------------------------------------------------------------------- 1 | paths.dofile('layers/Residual.lua') 2 | 3 | local function hourglass(n, numIn, numOut, inp) 4 | -- Upper branch 5 | local up1 = Residual(numIn,256)(inp) 6 | local up2 = Residual(256,256)(up1) 7 | local up4 = Residual(256,numOut)(up2) 8 | 9 | -- Lower branch 10 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(inp) 11 | local low1 = Residual(numIn,256)(pool) 12 | local low2 = Residual(256,256)(low1) 13 | local low5 = Residual(256,256)(low2) 14 | local low6 15 | if n > 1 then 16 | low6 = hourglass(n-1,256,numOut,low5) 17 | else 18 | low6 = Residual(256,numOut)(low5) 19 | end 20 | local low7 = Residual(numOut,numOut)(low6) 21 | local up5 = nn.SpatialUpSamplingNearest(2)(low7) 22 | 23 | -- Bring two branches together 24 | return nn.CAddTable()({up4,up5}) 25 | end 26 | 27 | local function lin(numIn,numOut,inp) 28 | -- Apply 1x1 convolution, no stride, no padding 29 | local l_ = nnlib.SpatialConvolution(numIn,numOut,1,1,1,1,0,0)(inp) 30 | return nnlib.ReLU(true)(nn.SpatialBatchNormalization(numOut)(l_)) 31 | end 32 | 33 | function createModel() 34 | 35 | local inp = nn.Identity()() 36 | 37 | -- Initial processing of the image 38 | local cnv1_ = nnlib.SpatialConvolution(3,64,7,7,2,2,3,3)(inp) -- 128 39 | local cnv1 = nnlib.ReLU(true)(nn.SpatialBatchNormalization(64)(cnv1_)) 40 | local r1 = Residual(64,128)(cnv1) 41 | local pool = nnlib.SpatialMaxPooling(2,2,2,2)(r1) -- 64 42 | local r4 = Residual(128,128)(pool) 43 | local r5 = Residual(128,128)(r4) 44 | local r6 = Residual(128,256)(r5) 45 | 46 | -- First hourglass 47 | local hg1 = hourglass(4,256,512,r6) 48 | 49 | -- Linear layers to produce first set of predictions 50 | local l1 = lin(512,512,hg1) 51 | local l2 = lin(512,256,l1) 52 | 53 | -- First predicted heatmaps 54 | local out1 = nnlib.SpatialConvolution(256,outputDim[1][1],1,1,1,1,0,0)(l2) 55 | local out1_ = nnlib.SpatialConvolution(outputDim[1][1],256+128,1,1,1,1,0,0)(out1) 56 | 57 | -- Concatenate with previous linear features 58 | local cat1 = nn.JoinTable(2)({l2,pool}) 59 | local cat1_ = nnlib.SpatialConvolution(256+128,256+128,1,1,1,1,0,0)(cat1) 60 | local int1 = nn.CAddTable()({cat1_,out1_}) 61 | 62 | -- Second hourglass 63 | local hg2 = hourglass(4,256+128,512,int1) 64 | 65 | -- Linear layers to produce predictions again 66 | local l3 = lin(512,512,hg2) 67 | local l4 = lin(512,256,l3) 68 | 69 | -- Second predicted heatmaps 70 | local out2 = nnlib.SpatialConvolution(256,outputDim[2][1],1,1,1,1,0,0)(l4) 71 | local out2_ = nnlib.SpatialConvolution(outputDim[2][1],256+256,1,1,1,1,0,0)(out2) 72 | 73 | -- Concatenate with previous linear features 74 | local cat2 = nn.JoinTable(2)({l4,l2}) 75 | local cat2_ = nnlib.SpatialConvolution(256+256,256+256,1,1,1,1,0,0)(cat2) 76 | local int2 = nn.CAddTable()({cat2_,out2_}) 77 | 78 | -- Third hourglass 79 | local hg3 = hourglass(4,256+256,512,int2) 80 | 81 | -- Linear layers to produce predictions again 82 | local l5 = lin(512,512,hg3) 83 | local l6 = lin(512,256,l5) 84 | 85 | -- Third predicted heatmaps 86 | local out3 = nnlib.SpatialConvolution(256,outputDim[3][1],1,1,1,1,0,0)(l6) 87 | local out3_ = nnlib.SpatialConvolution(outputDim[3][1],256+256,1,1,1,1,0,0)(out3) 88 | 89 | -- Concatenate with previous linear features 90 | local cat3 = nn.JoinTable(2)({l6,l4}) 91 | local cat3_ = nnlib.SpatialConvolution(256+256,256+256,1,1,1,1,0,0)(cat3) 92 | local int3 = nn.CAddTable()({cat3_,out3_}) 93 | 94 | -- Fourth hourglass 95 | local hg4 = hourglass(4,256+256,512,int3) 96 | 97 | -- Linear layers to produce predictions again 98 | local l7 = lin(512,512,hg4) 99 | local l8 = lin(512,512,l7) 100 | 101 | -- Output heatmaps 102 | local out4 = nnlib.SpatialConvolution(512,outputDim[4][1],1,1,1,1,0,0)(l8) 103 | 104 | -- Final model 105 | local model = nn.gModule({inp}, {out1,out2,out3,out4}) 106 | 107 | return model 108 | 109 | end 110 | -------------------------------------------------------------------------------- /src/opts.lua: -------------------------------------------------------------------------------- 1 | projectDir = projectDir or paths.concat(os.getenv('HOME'),'pose-hg-train') 2 | 3 | local M = { } 4 | 5 | function M.parse(arg) 6 | local cmd = torch.CmdLine() 7 | cmd:text() 8 | cmd:text(' ---------- General options ------------------------------------') 9 | cmd:text() 10 | cmd:option('-expID', 'default', 'Experiment ID') 11 | cmd:option('-dataset', 'h36m', 'Dataset choice') 12 | cmd:option('-dataDir', projectDir .. '/data', 'Data directory') 13 | cmd:option('-expDir', projectDir .. '/exp', 'Experiments directory') 14 | cmd:option('-manualSeed', -1, 'Manually set RNG seed') 15 | cmd:option('-GPU', 1, 'Default preferred GPU, if set to -1: no GPU') 16 | cmd:option('-finalPredictions', 0, 'Generate a final set of predictions at the end of training (default no, set to 1 for yes)') 17 | cmd:option('-nThreads', 4, 'Number of data loading threads') 18 | cmd:text() 19 | cmd:text(' ---------- Model options --------------------------------------') 20 | cmd:text() 21 | cmd:option('-netType', 'hg', 'Network model') 22 | cmd:option('-loadModel', 'none', 'Provide full path to a previously trained model') 23 | cmd:option('-continue', false, 'Pick up where an experiment left off') 24 | cmd:option('-branch', 'none', 'Provide a parent expID to branch off') 25 | cmd:option('-snapshot', 5, 'How often to take a snapshot of the model (0 = never)') 26 | cmd:option('-task', 'pose-c2f', 'Network task: pose-vol | pose-c2f') 27 | cmd:option('-nFeats', 256, 'Number of features in the hourglass') 28 | cmd:option('-nModules', 3, 'Number of modules in the provided hourglass model') 29 | cmd:option('-nStack', 4, 'Number of stacks in the provided hourglass model') 30 | cmd:option('-resZ', '1,2,4,64', 'Resolution of z-dimension for the output of the corresponding hourglass') 31 | cmd:text() 32 | cmd:text(' ---------- Hyperparameter options -----------------------------') 33 | cmd:text() 34 | cmd:option('-LR', 2.5e-4, 'Learning rate') 35 | cmd:option('-LRdecay', 0.0, 'Learning rate decay') 36 | cmd:option('-momentum', 0.0, 'Momentum') 37 | cmd:option('-weightDecay', 0.0, 'Weight decay') 38 | cmd:option('-crit', 'MSE', 'Criterion type') 39 | cmd:option('-optMethod', 'rmsprop', 'Optimization method: rmsprop | sgd | nag | adadelta') 40 | cmd:option('-threshold', .001, 'Threshold (on validation accuracy growth) to cut off training early') 41 | cmd:text() 42 | cmd:text(' ---------- Training options -----------------------------------') 43 | cmd:text() 44 | cmd:option('-nEpochs', 100, 'Total number of epochs to run') 45 | cmd:option('-lastEpoch', 1, 'Training from a previous epoch model (when -continue is activated)') 46 | cmd:option('-trainIters', 4000, 'Number of train iterations per epoch') 47 | cmd:option('-trainBatch', 4, 'Mini-batch size') 48 | cmd:option('-validIters', 2958, 'Number of validation iterations per epoch') 49 | cmd:option('-validBatch', 1, 'Mini-batch size for validation') 50 | cmd:text() 51 | cmd:text(' ---------- Data options ---------------------------------------') 52 | cmd:text() 53 | cmd:option('-inputRes', 256, 'Input image resolution') 54 | cmd:option('-outputRes', 64, 'Output heatmap resolution') 55 | cmd:option('-trainFile', '', 'Name of training data file') 56 | cmd:option('-validFile', '', 'Name of validation file') 57 | cmd:option('-scaleFactor', .25, 'Degree of scale augmentation') 58 | cmd:option('-rotFactor', 30, 'Degree of rotation augmentation') 59 | 60 | local opt = cmd:parse(arg or {}) 61 | opt.expDir = paths.concat(opt.expDir, opt.dataset) 62 | opt.dataDir = paths.concat(opt.dataDir, opt.dataset) 63 | opt.save = paths.concat(opt.expDir, opt.expID) 64 | -- convert string to table 65 | tmpResZ = {} 66 | for match in opt.resZ:gmatch("([^,%s]+)") do 67 | tmpResZ[#tmpResZ + 1] = tonumber(match) 68 | end 69 | opt.resZ = tmpResZ 70 | return opt 71 | end 72 | 73 | return M 74 | -------------------------------------------------------------------------------- /src/misc/monitor_experiments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | import os 5 | from subprocess import call 6 | import time 7 | from datetime import datetime 8 | import sys 9 | 10 | plt.ion() 11 | try: 12 | experiments_to_show = sys.argv[1].split(',') 13 | except: 14 | print "Error: No experiments provided" 15 | exit() 16 | 17 | print "Monitoring the following experiments:", 18 | for exp in experiments_to_show: print exp, 19 | print "" 20 | 21 | track_multiple = sys.argv[2] == '1' 22 | if track_multiple: 23 | exp_to_track = experiments_to_show[0] 24 | print "Tracking all variations of:",exp_to_track 25 | experiments_to_show.remove(exp_to_track) 26 | 27 | def readlog(filepath): 28 | done_first = False 29 | arr = None 30 | with open(filepath,'r') as f: 31 | for line in f: 32 | if not done_first: 33 | done_first = True 34 | else: 35 | vals = np.array(map(float,line.split())) 36 | if arr is None: 37 | arr = vals.reshape(1,np.size(vals)) 38 | else: 39 | arr = np.concatenate((arr, vals[:arr.shape[1]].reshape(1,arr.shape[1]))) 40 | return arr 41 | 42 | while True: 43 | logs = {} 44 | 45 | for dirname, dirnames, filenames in os.walk('../../exp/mpii'): 46 | for subdirname in dirnames: 47 | logs[subdirname] = {} 48 | train_path = '../../exp/mpii/' + subdirname + '/train.log' 49 | test_path = '../../exp/mpii/' + subdirname + '/test.log' 50 | if (os.path.exists(train_path) and os.path.exists(test_path) and 51 | os.stat(train_path).st_size != 0 and os.stat(test_path).st_size != 0): 52 | logs[subdirname]['train'] = readlog(train_path) 53 | logs[subdirname]['test'] = readlog(test_path) 54 | if track_multiple and exp_to_track in subdirname: 55 | if not subdirname in experiments_to_show: 56 | experiments_to_show += [subdirname] 57 | 58 | print "Updated experiments to show:", 59 | for exp in experiments_to_show: print exp, 60 | print "" 61 | 62 | idx = [1, 2, 0] # Epoch, Loss, Accuracy indices 63 | 64 | plt.clf() 65 | 66 | fig = plt.figure(1, facecolor='w') 67 | last = 25 68 | last_str = '(last %d)' % last 69 | 70 | axs = {"Train loss":fig.add_subplot(421), 71 | "Test loss":fig.add_subplot(423), 72 | "Train accuracy":fig.add_subplot(425), 73 | "Test accuracy":fig.add_subplot(427), 74 | ("Train loss %s"%last_str):fig.add_subplot(422), 75 | ("Test loss %s"%last_str):fig.add_subplot(424), 76 | ("Train accuracy %s"%last_str):fig.add_subplot(426), 77 | ("Test accuracy %s"%last_str):fig.add_subplot(428)} 78 | 79 | for k in axs.keys(): 80 | for tick in axs[k].xaxis.get_major_ticks(): 81 | tick.label.set_fontsize(8) 82 | for tick in axs[k].yaxis.get_major_ticks(): 83 | tick.label.set_fontsize(8) 84 | 85 | plt_idx = idx[0] 86 | if 'loss' in k: plt_idx = idx[1] 87 | if 'accuracy' in k: plt_idx = idx[2] 88 | 89 | start_idx = 0 90 | if last_str in k: start_idx = -last 91 | 92 | log_choice = 'train' 93 | if 'Test' in k: log_choice = 'test' 94 | 95 | max_x = -1 96 | for exp in experiments_to_show: 97 | log = logs[exp][log_choice] 98 | temp_start_idx = start_idx 99 | if abs(start_idx) > log.shape[0]: temp_start_idx = 0 100 | axs[k].plot(log[temp_start_idx:,idx[0]], log[temp_start_idx:,plt_idx], label=exp) 101 | if log[-1,idx[0]] > max_x: 102 | max_x = log[-1,idx[0]] 103 | 104 | if last_str in k: 105 | axs[k].set_xlim(max(0,max_x-last),max_x) 106 | else: 107 | axs[k].set_xlim(0,max_x) 108 | 109 | axs[k].set_title(k) 110 | if 'accuracy' in k and not last_str in k: 111 | axs[k].set_ylim(0,1) 112 | 113 | axs['Test accuracy'].legend(loc='lower right', fontsize=10) 114 | print time.strftime('%X %x %Z') 115 | plt.show() 116 | plt.pause(900) 117 | -------------------------------------------------------------------------------- /src/misc/examples.py: -------------------------------------------------------------------------------- 1 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 2 | # All of these examples are really, really outdated but offer some insights 3 | # into using the python code, if you want to check it out 4 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 5 | 6 | 7 | import numpy as np 8 | import pypose as pose 9 | import pypose.mpii as ds # Use this to swap which dataset you want to use 10 | 11 | # Sample dataset generation 12 | if False: 13 | # Everything is pretty self explanatory here (outdated and not functional anymore 14 | # command line interface is better) 15 | filename = 'sample' 16 | numsamples = 100 17 | is_train = 1 18 | augmentation = 1 19 | pose.data.generateset(ds, filename, numsamples, is_train, chg=augmentation) 20 | 21 | # Sample report 22 | # (compares performance based on whether the person is facing forward or backward) 23 | if False: 24 | # Get predictions 25 | preds = np.load(pose.eval.get_path(ds.name, 'nyu_pred')) 26 | # Get prediction error 27 | dists = pose.eval.getdists(preds) 28 | 29 | # To create our filters for the report: 307-DR,304-hg-D 30 | # Load up ground truth annotations 31 | gt_idx = pose.eval.gt_idx[ds.name] 32 | # Compare shoulder annotations 33 | sho_diff = np.array([ds.partinfo(gt_idx[i,0],gt_idx[i,1],'lsho')[0][0] - 34 | ds.partinfo(gt_idx[i,0],gt_idx[i,1],'rsho')[0][0] 35 | for i in xrange(len(gt_idx))], np.float) 36 | # Normalize difference by sample scale size 37 | sho_diff /= gt_idx[:,2] 38 | # Define the filters, numpy generates boolean arrays out of these comparisons 39 | filtnames = ['Forward', 'Back', 'Profile', 'Total'] 40 | thresh = .3 41 | filts = [sho_diff > thresh, 42 | sho_diff < -thresh, 43 | (sho_diff < thresh) * (sho_diff > -thresh), 44 | None] 45 | 46 | # Prepare the document 47 | title='Performance Comparison - Facing Forward or Backward' 48 | pdf = pose.report.PdfPages(pose.ref.posedir+'/img/reports/fwd_back_sample.pdf') 49 | 50 | # Add whatever pages you want 51 | print "Doing overall comparison..." 52 | pose.report.filtercomparison(ds.name, dists, filts, filtnames=filtnames, title=title, pdf=pdf) 53 | for i,filt in enumerate(filts[:-1]): 54 | print "Generating images for - %s..." % filtnames[i] 55 | pose.report.sampleimages(ds, preds, dists=dists, pdf=pdf, title=filtnames[i], filt=filt) 56 | pose.report.sampleimages(ds, preds, dists=dists, pdf=pdf, title=filtnames[i], filt=filt, get_worst=True) 57 | 58 | # Save the pdf 59 | pdf.close() 60 | 61 | if True: 62 | # Get predictions 63 | preds = np.load(pose.eval.get_path(ds.name, 'nyu_pred')) 64 | # Get prediction error 65 | dists = pose.eval.getdists(preds) 66 | 67 | # To create our filters for the report: 68 | # Load up ground truth annotations 69 | gt_idx = pose.eval.gt_idx[ds.name] 70 | # Calculate torso angles (note this only works for mpii) 71 | torso_angles = np.array([abs(ds.torsoangle(gt_idx[i,0], gt_idx[i,1])) for i in xrange(len(gt_idx))]) 72 | # Define filters 73 | filtnames = ['< 20 degrees','20 < 40','40 < 120', '> 120', 'Total'] 74 | filts = [torso_angles <= 20, 75 | (20 < torso_angles) * (torso_angles < 40), 76 | (40 < torso_angles) * (torso_angles < 120), 77 | (120 < torso_angles), 78 | None] 79 | 80 | # Prepare the document 81 | title='Performance Comparison - Torso Deviation from Vertical' 82 | pdf = pose.report.PdfPages(pose.ref.posedir+'/img/reports/torso_angle_sample.pdf') 83 | 84 | print "Doing overall comparison..." 85 | pose.report.filtercomparison(ds.name, dists, filts, filtnames=filtnames, title=title, pdf=pdf) 86 | for i in xrange(7): 87 | # This loop will only generate poor performing images for the first filter (people who are upright) 88 | print "Generating images for page - %d..." % i 89 | pose.report.sampleimages(ds, preds, dists=dists, pdf=pdf, title=filtnames[0], filt=filts[0], 90 | get_worst=True, page_num=i+1) 91 | 92 | # Save the pdf 93 | pdf.close() 94 | 95 | 96 | """ 97 | overall performance - taken out of report.py not adjusted to work here 98 | 99 | def make(dataset, preds, partnames=None): 100 | pdf = PdfPages(ref.posedir+'/img/test.pdf') 101 | 102 | num_pages = 10 103 | dists = eval.getdists(preds) 104 | 105 | for i in xrange(num_pages): 106 | print "Page %d..." % i 107 | page_choice = i + 1 108 | if i < num_pages / 2: 109 | get_worst = False 110 | else: 111 | page_choice -= num_pages / 2 112 | get_worst = True 113 | sampleimages(dataset, preds, dists=dists, pdf=pdf, get_worst=get_worst, 114 | partnames=partnames, title='Overall Performance', page_num=page_choice) 115 | 116 | pdf.close() 117 | 118 | """ 119 | -------------------------------------------------------------------------------- /src/util/eval.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------- 2 | -- Helpful functions for evaluation 3 | ------------------------------------------------------------------------------- 4 | 5 | function calcDists(preds, label, normalize) 6 | local dists = torch.Tensor(preds:size(2), preds:size(1)) 7 | local diff = torch.Tensor(2) 8 | for i = 1,preds:size(1) do 9 | for j = 1,preds:size(2) do 10 | if label[i][j][1] > 1 and label[i][j][2] > 1 then 11 | dists[j][i] = torch.dist(label[i][j],preds[i][j])/normalize[i] 12 | else 13 | dists[j][i] = -1 14 | end 15 | end 16 | end 17 | return dists 18 | end 19 | 20 | function getPreds(hm) 21 | assert(hm:size():size() == 4, 'Input must be 4-D tensor') 22 | local max, idx = torch.max(hm:view(hm:size(1), hm:size(2), hm:size(3) * hm:size(4)), 3) 23 | local preds = torch.repeatTensor(idx, 1, 1, 2):float() 24 | preds[{{}, {}, 1}]:apply(function(x) return (x - 1) % hm:size(4) + 1 end) 25 | preds[{{}, {}, 2}]:add(-1):div(hm:size(3)):floor():add(1) 26 | return preds 27 | end 28 | 29 | function getPreds3D(hm) 30 | assert(hm:size():size() == 4, 'Input must be 4-D tensor') 31 | local max, idx = torch.max(hm:view(hm:size(1), nParts, opt.resZ[opt.nStack] * hm:size(3) * hm:size(4)), 3) 32 | local preds = torch.repeatTensor(idx, 1, 1, 3):float() 33 | preds[{{}, {}, 1}]:apply(function(x) return (x - 1) % hm:size(4) + 1 end) 34 | preds[{{}, {}, 2}]:add(-1):div(hm:size(4)):floor():mod(hm:size(3)):add(1) 35 | preds[{{}, {}, 3}]:add(-1):div(hm:size(3)*hm:size(4)):floor():add(1) 36 | return preds 37 | end 38 | 39 | function distAccuracy(dists, thr) 40 | -- Return percentage below threshold while ignoring values with a -1 41 | if not thr then thr = .5 end 42 | if torch.ne(dists,-1):sum() > 0 then 43 | return dists:le(thr):eq(dists:ne(-1)):sum() / dists:ne(-1):sum() 44 | else 45 | return -1 46 | end 47 | end 48 | 49 | function heatmapAccuracy(output, label, thr, idxs) 50 | -- Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations 51 | -- First value to be returned is average accuracy across 'idxs', followed by individual accuracies 52 | 53 | local preds 54 | local gt 55 | if opt.resZ[opt.nStack] == 1 then 56 | preds = getPreds(output) 57 | gt = getPreds(label) 58 | else 59 | preds = getPreds3D(output) 60 | gt = getPreds3D(label) 61 | end 62 | 63 | local dists = calcDists(preds, gt, torch.ones(preds:size(1))*opt.outputRes/10) 64 | local acc = {} 65 | local avgAcc = 0.0 66 | local badIdxCount = 0 67 | 68 | if not idxs then 69 | for i = 1,dists:size(1) do 70 | acc[i+1] = distAccuracy(dists[i]) 71 | if acc[i+1] >= 0 then avgAcc = avgAcc + acc[i+1] 72 | else badIdxCount = badIdxCount + 1 end 73 | end 74 | acc[1] = avgAcc / (dists:size(1) - badIdxCount) 75 | else 76 | for i = 1,#idxs do 77 | acc[i+1] = distAccuracy(dists[idxs[i]]) 78 | if acc[i+1] >= 0 then avgAcc = avgAcc + acc[i+1] 79 | else badIdxCount = badIdxCount + 1 end 80 | end 81 | acc[1] = avgAcc / (#idxs - badIdxCount) 82 | end 83 | return unpack(acc) 84 | end 85 | 86 | function basicAccuracy(output, label, thr) 87 | -- Calculate basic accuracy 88 | if not thr then thr = .5 end -- Default threshold of .5 89 | output = output:view(output:numel()) 90 | label = label:view(label:numel()) 91 | 92 | local rounded_output = torch.ceil(output - thr):typeAs(label) 93 | local eql = torch.eq(label,rounded_output):typeAs(label) 94 | 95 | return eql:sum()/output:numel() 96 | end 97 | 98 | function displayPCK(dists, part_idx, label, title, show_key) 99 | -- Generate standard PCK plot 100 | if not (type(part_idx) == 'table') then 101 | part_idx = {part_idx} 102 | end 103 | 104 | curve_res = 11 105 | num_curves = #dists 106 | local t = torch.linspace(0,.5,curve_res) 107 | local pdj_scores = torch.zeros(num_curves, curve_res) 108 | local plot_args = {} 109 | print(title) 110 | for curve = 1,num_curves do 111 | for i = 1,curve_res do 112 | t[i] = (i-1)*.05 113 | local acc = 0.0 114 | for j = 1,#part_idx do 115 | acc = acc + distAccuracy(dists[curve][part_idx[j]], t[i]) 116 | end 117 | pdj_scores[curve][i] = acc / #part_idx 118 | end 119 | plot_args[curve] = {label[curve],t,pdj_scores[curve],'-'} 120 | print(label[curve],pdj_scores[curve][curve_res]) 121 | end 122 | 123 | require 'gnuplot' 124 | gnuplot.raw('set title "' .. title .. '"') 125 | if not show_key then gnuplot.raw('unset key') 126 | else gnuplot.raw('set key font ",6" right bottom') end 127 | gnuplot.raw('set xrange [0:.5]') 128 | gnuplot.raw('set yrange [0:1]') 129 | gnuplot.plot(unpack(plot_args)) 130 | end 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Coarse-to-Fine Volumetric Prediction for Single-Image 3D Human Pose (Training code) 2 | ## Georgios Pavlakos, Xiaowei Zhou, Konstantinos G. Derpanis, Kostas Daniilidis 3 | 4 | This is the training code for the paper **Coarse-to-Fine Volumetric Prediction for Single-Image 3D Human Pose**. Please follow the links to read the [paper](https://arxiv.org/abs/1611.07828) and visit the corresponding [project page](https://www.seas.upenn.edu/~pavlakos/projects/volumetric). This code follows closely the [original training code for the Stacked Hourglass networks](https://github.com/anewell/pose-hg-train) by Alejandro Newell, so you can follow the corresponding release for an elaborate description on the command line arguments and options. Here, we provide details so that you can train a network with a volumetric output for 3D human pose estimation (or generally 3D keypoint localization). 5 | 6 | For the testng code please visit this [repository](https://github.com/geopavlakos/c2f-vol-demo). 7 | 8 | We provide code and data to train our models on [Human3.6M](http://vision.imar.ro/human3.6m/description.php). Please follow the instructions below to setup and use our code. To run this code, make sure the following are installed: 9 | 10 | - [Torch7](https://github.com/torch/torch7) 11 | - hdf5 12 | - cudnn 13 | 14 | ### 1) Data format 15 | 16 | We provide the data for the training and testing set of Human3.6M. Please run the following script to get all the relevant data (**be careful, since the size is over 32GB**) 17 | 18 | ``` 19 | data.sh 20 | ``` 21 | 22 | These images are extracted from the videos of the [original dataset](http://vision.imar.ro/human3.6m/description.php), and correspond to the images used for testing by the most typical protocol. The filename protocol we follow is: 23 | 24 | ``` 25 | S[subject number]_[Action Name].[Camera Name]_[Frame Number].jpg 26 | ``` 27 | 28 | An example for Subject 5, performing action Eating (iteration 1), when we consider camera name '55011271' and frame 321, is: 29 | 30 | ``` 31 | S5_Eating_1.55011271_000321.jpg 32 | ``` 33 | 34 | Check also the files: 35 | 36 | ``` 37 | data/hm36m/annot/train_images.txt 38 | data/hm36m/annot/valid_images.txt 39 | ``` 40 | 41 | to figure out the complete list of training and testing images. 42 | 43 | ### 2) Training 44 | 45 | You can train a typical Coarse-to-Fine Volumetric prediction model by using the command line: 46 | 47 | ``` 48 | th main.lua -dataset h36m -expID test-run-c2f -netType hg-stacked-4 -task pose-c2f -nStack 4 -resZ 1,2,4,64 -LR 2.5e-4 -nEpochs 1000 -trainIters 1000 -validIters 1000 49 | ``` 50 | 51 | Please check the file ''opts.lua'' for all the relevant command line options. Our code follows closely the [original training code for the Stacked Hourglass networks](https://github.com/anewell/pose-hg-train) by Alejandro Newell, so you can follow the corresponding repository for an elaborate description on the command line arguments and options. An additional argument required to model our class of networks is 'resZ'. This is a list with the resolution of the z-dimension for each hourglass' output. The length of the list must match the number of the hourglass components ''nStack''. 52 | 53 | Also, to replicate the models used in our paper, please use the architectures defined in the files: 54 | 55 | ``` 56 | src/models/hg-stacked-2.lua 57 | src/models/hg-stacked-3.lua 58 | src/models/hg-stacked-4.lua 59 | ``` 60 | 61 | with 2,3 and 4 hourglasses respectively. Again, these follow the original Stacked Hourglass network design. Alternatively, you can use the typical hourglass architecture: 62 | 63 | ``` 64 | src/models/hg.lua 65 | ``` 66 | 67 | which has a more uniform network design. If you are using a single hourglass (no iterative coarse-to-fine prediction), you can train a simple model by using the command line: 68 | 69 | ``` 70 | th main.lua -dataset h36m -expID test-run-vol -netType hg -task pose-vol -nStack 1 -resZ 64 -LR 2.5e-4 -nEpochs 1000 -trainIters 1000 -validIters 1000 71 | ``` 72 | 73 | ### 3) Evaluation 74 | 75 | You can evaluate your trained model on users S9 and S11 of Human3.6M, by running: 76 | 77 | ``` 78 | th main.lua -dataset h36m -expID test-run-c2f -task pose-c2f -nStack 4 -finalPredictions 1 -nEpochs 0 -validIters 109867 -loadModel \path\to\model 79 | ``` 80 | 81 | or you can use our [demo code](https://github.com/geopavlakos/c2f-vol-demo) for that. 82 | 83 | ### 4) Training on your own data 84 | 85 | Compared to training a hourglass with a 2D output, the only overhead our code requires is to provide the index for the z-Dimension (zind), for each keypoint. We provide this in a 1-64 scale, and the code adapts it when the resolution is smaller. As long as you provide this information during training, along with the pixel locations of each keypoint, you should be able to use our training code on your custom data. 86 | 87 | ### Citing 88 | 89 | If you find this code useful for your research, please consider citing the following paper: 90 | 91 | @Inproceedings{pavlakos17volumetric, 92 | Title = {Coarse-to-Fine Volumetric Prediction for Single-Image 3{D} Human Pose}, 93 | Author = {Pavlakos, Georgios and Zhou, Xiaowei and Derpanis, Konstantinos G and Daniilidis, Kostas}, 94 | Booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 95 | Year = {2017} 96 | } 97 | 98 | ### Acknowledgements 99 | 100 | This code follows closely the [released code](https://github.com/anewell/pose-hg-train) for the Stacked Hourglass networks by Alejandro Newell. We gratefully appreciate the impact it had on our work. If you use our code, please consider citing the [original paper](http://arxiv.org/abs/1603.06937) as well. 101 | -------------------------------------------------------------------------------- /src/util/Logger.lua: -------------------------------------------------------------------------------- 1 | --[[ Logger: a simple class to log symbols during training, 2 | and automate plot generation 3 | 4 | #### Slightly modified from optim.Logger to allow appending to logs #### 5 | 6 | Example: 7 | logger = Logger('somefile.log') -- file to save stuff 8 | 9 | for i = 1,N do -- log some symbols during 10 | train_error = ... -- training/testing 11 | test_error = ... 12 | logger:add{['training error'] = train_error, 13 | ['test error'] = test_error} 14 | end 15 | 16 | logger:style{['training error'] = '-', -- define styles for plots 17 | ['test error'] = '-'} 18 | logger:plot() -- and plot 19 | 20 | ---- OR --- 21 | 22 | logger = optim.Logger('somefile.log') -- file to save stuff 23 | logger:setNames{'training error', 'test error'} 24 | 25 | for i = 1,N do -- log some symbols during 26 | train_error = ... -- training/testing 27 | test_error = ... 28 | logger:add{train_error, test_error} 29 | end 30 | 31 | logger:style{'-', '-'} -- define styles for plots 32 | logger:plot() -- and plot 33 | ]] 34 | require 'xlua' 35 | local Logger = torch.class('Logger') 36 | 37 | function Logger:__init(filename, continue, timestamp) 38 | if filename then 39 | self.name = filename 40 | os.execute('mkdir -p "' .. paths.dirname(filename) .. '"') 41 | if timestamp then 42 | -- append timestamp to create unique log file 43 | filename = filename .. '-'..os.date("%Y_%m_%d_%X") 44 | end 45 | if continue then 46 | self.file = io.open(filename,'a') 47 | else 48 | self.file = io.open(filename,'w') 49 | end 50 | self.epsfile = self.name .. '.eps' 51 | else 52 | self.file = io.stdout 53 | self.name = 'stdout' 54 | print(' warning: no path provided, logging to std out') 55 | end 56 | self.continue = continue 57 | self.empty = true 58 | self.symbols = {} 59 | self.styles = {} 60 | self.names = {} 61 | self.idx = {} 62 | self.figure = nil 63 | end 64 | 65 | function Logger:setNames(names) 66 | self.names = names 67 | self.empty = false 68 | self.nsymbols = #names 69 | for k,key in pairs(names) do 70 | self.file:write(key .. '\t') 71 | self.symbols[k] = {} 72 | self.styles[k] = {'+'} 73 | self.idx[key] = k 74 | end 75 | self.file:write('\n') 76 | self.file:flush() 77 | end 78 | 79 | function Logger:add(symbols) 80 | -- (1) first time ? print symbols' names on first row 81 | if self.empty then 82 | self.empty = false 83 | self.nsymbols = #symbols 84 | for k,val in pairs(symbols) do 85 | if not self.continue then self.file:write(k .. '\t') end 86 | self.symbols[k] = {} 87 | self.styles[k] = {'+'} 88 | self.names[k] = k 89 | end 90 | self.idx = self.names 91 | if not self.continue then self.file:write('\n') end 92 | end 93 | -- (2) print all symbols on one row 94 | for k,val in pairs(symbols) do 95 | if type(val) == 'number' then 96 | self.file:write(string.format('%11.4e',val) .. '\t') 97 | elseif type(val) == 'string' then 98 | self.file:write(val .. '\t') 99 | else 100 | xlua.error('can only log numbers and strings', 'Logger') 101 | end 102 | end 103 | self.file:write('\n') 104 | self.file:flush() 105 | -- (3) save symbols in internal table 106 | for k,val in pairs(symbols) do 107 | table.insert(self.symbols[k], val) 108 | end 109 | end 110 | 111 | function Logger:style(symbols) 112 | for name,style in pairs(symbols) do 113 | if type(style) == 'string' then 114 | self.styles[name] = {style} 115 | elseif type(style) == 'table' then 116 | self.styles[name] = style 117 | else 118 | xlua.error('style should be a string or a table of strings','Logger') 119 | end 120 | end 121 | end 122 | 123 | function Logger:plot(...) 124 | if not xlua.require('gnuplot') then 125 | if not self.warned then 126 | print(' warning: cannot plot with this version of Torch') 127 | self.warned = true 128 | end 129 | return 130 | end 131 | local plotit = false 132 | local plots = {} 133 | local plotsymbol = 134 | function(name,list) 135 | if #list > 1 then 136 | local nelts = #list 137 | local plot_y = torch.Tensor(nelts) 138 | for i = 1,nelts do 139 | plot_y[i] = list[i] 140 | end 141 | for _,style in ipairs(self.styles[name]) do 142 | table.insert(plots, {self.names[name], plot_y, style}) 143 | end 144 | plotit = true 145 | end 146 | end 147 | local args = {...} 148 | if not args[1] then -- plot all symbols 149 | for name,list in pairs(self.symbols) do 150 | plotsymbol(name,list) 151 | end 152 | else -- plot given symbols 153 | for _,name in ipairs(args) do 154 | plotsymbol(self.idx[name], self.symbols[self.idx[name]]) 155 | end 156 | end 157 | if plotit then 158 | self.figure = gnuplot.figure(self.figure) 159 | gnuplot.plot(plots) 160 | gnuplot.grid('on') 161 | gnuplot.title('') 162 | if self.epsfile then 163 | os.execute('rm -f "' .. self.epsfile .. '"') 164 | local epsfig = gnuplot.epsfigure(self.epsfile) 165 | gnuplot.plot(plots) 166 | gnuplot.grid('on') 167 | gnuplot.title('') 168 | gnuplot.plotflush() 169 | gnuplot.close(epsfig) 170 | end 171 | end 172 | end 173 | -------------------------------------------------------------------------------- /src/train.lua: -------------------------------------------------------------------------------- 1 | -- Track accuracy 2 | opt.lastAcc = opt.lastAcc or 0 3 | opt.bestAcc = opt.bestAcc or 0 4 | -- We save snapshots of the best model only when evaluating on the full validation set 5 | trackBest = (opt.validIters * opt.validBatch == ref.valid.nsamples) 6 | 7 | -- The dimensions of 'final predictions' are defined by the opt.task file 8 | -- This allows some flexibility for post-processing of the network output 9 | preds = torch.Tensor(ref.valid.nsamples, unpack(predDim)) 10 | 11 | -- We also save the raw output of the network (in this case heatmaps) 12 | --if type(outputDim[1]) == "table" then predHMs = torch.Tensor(ref.valid.nsamples, unpack(outputDim[#outputDim])) 13 | --else predHMs = torch.Tensor(ref.valid.nsamples, unpack(outputDim)) end 14 | 15 | -- Model parameters 16 | param, gradparam = model:getParameters() 17 | 18 | -- Main processing step 19 | function step(tag) 20 | local avgLoss, avgAcc = 0.0, 0.0 21 | local output, err, idx 22 | local r = ref[tag] 23 | local function evalFn(x) return criterion.output, gradparam end 24 | 25 | if tag == 'train' then 26 | print("==> Starting epoch: " .. epoch .. "/" .. (opt.nEpochs + opt.epochNumber - 1)) 27 | model:training() 28 | set = 'train' 29 | isTesting = false -- Global flag 30 | else 31 | if tag == 'predict' then print("==> Generating predictions...") end 32 | model:evaluate() 33 | set = 'valid' 34 | isTesting = true 35 | end 36 | 37 | for i,sample in loader[set]:run() do 38 | 39 | xlua.progress(i, r.iters) 40 | local input, label = unpack(sample) 41 | 42 | if opt.GPU ~= -1 then 43 | -- Convert to CUDA 44 | input = applyFn(function (x) return x:cuda() end, input) 45 | label = applyFn(function (x) return x:cuda() end, label) 46 | end 47 | 48 | -- Do a forward pass and calculate loss 49 | local output = model:forward(input) 50 | local err = criterion:forward(output, label) 51 | 52 | -- Training: Do backpropagation and optimization 53 | if tag == 'train' then 54 | model:zeroGradParameters() 55 | model:backward(input, criterion:backward(output, label)) 56 | optfn(evalFn, param, optimState) 57 | 58 | -- Validation: Get flipped output 59 | else 60 | output = applyFn(function (x) return x:clone() end, output) 61 | local flip_ = customFlip or flip 62 | local shuffleLR_ = customShuffleLR or shuffleLR 63 | local flippedOut = model:forward(flip_(input)) 64 | if opt.nStack == 1 then 65 | flippedOut = applyFn(function (x) return flip_(shuffleLR_(x,matchedParts3D[1])) end, flippedOut) 66 | else 67 | for j = 1,opt.nStack do 68 | flippedOut[j] = applyFn(function (x) return flip_(shuffleLR_(x,matchedParts3D[j])) end, flippedOut[j]) 69 | end 70 | end 71 | output = applyFn(function (x,y) return x:add(y):div(2) end, output, flippedOut) 72 | 73 | end 74 | 75 | -- Synchronize with GPU 76 | if opt.GPU ~= -1 then cutorch.synchronize() end 77 | 78 | -- If we're generating predictions, save output 79 | if tag == 'predict' or (tag == 'valid' and trackBest) then 80 | --if type(outputDim[1]) == "table" then 81 | -- -- If we're getting a table of heatmaps, save the last one 82 | -- predHMs:sub(i,i+r.batchsize-1):copy(output[#output]) 83 | --else 84 | -- predHMs:sub(i,i+r.batchsize-1):copy(output) 85 | --end 86 | if postprocess then preds:sub(i,i+r.batchsize-1):copy(postprocess(set,i,output)) end 87 | end 88 | 89 | -- Calculate accuracy 90 | local acc = accuracy(output, label) 91 | avgLoss = avgLoss + err 92 | avgAcc = avgAcc + acc 93 | end 94 | 95 | avgLoss = avgLoss / r.iters 96 | avgAcc = avgAcc / r.iters 97 | 98 | local epochStep = torch.floor(ref.train.nsamples / (r.iters * r.batchsize * 2)) 99 | if tag == 'train' and epoch % epochStep == 0 then 100 | if avgAcc - opt.lastAcc < opt.threshold then 101 | isFinished = true --Training has plateaued 102 | end 103 | opt.lastAcc = avgAcc 104 | end 105 | 106 | -- Print and log some useful performance metrics 107 | print(string.format(" %s : Loss: %.7f Acc: %.4f" % {set, avgLoss, avgAcc})) 108 | if r.log then 109 | r.log:add{ 110 | ['epoch '] = string.format("%d" % epoch), 111 | ['loss '] = string.format("%.6f" % avgLoss), 112 | ['acc '] = string.format("%.4f" % avgAcc), 113 | ['LR '] = string.format("%g" % optimState.learningRate) 114 | } 115 | end 116 | 117 | if tag == 'train' and opt.snapshot ~= 0 and epoch % opt.snapshot == 0 then 118 | -- Take an intermediate training snapshot 119 | model:clearState() 120 | torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model) 121 | torch.save(paths.concat(opt.save, 'optimState.t7'), optimState) 122 | elseif tag == 'valid' and trackBest and avgAcc > opt.bestAcc then 123 | -- A new record validation accuracy has been hit, save the model and predictions 124 | predFile = hdf5.open(opt.save .. '/best_preds.h5', 'w') 125 | --predFile:write('heatmaps', predHMs) 126 | if postprocess then predFile:write('preds', preds) end 127 | predFile:close() 128 | model:clearState() 129 | torch.save(paths.concat(opt.save, 'best_model.t7'), model) 130 | torch.save(paths.concat(opt.save, 'optimState.t7'), optimState) 131 | opt.bestAcc = avgAcc 132 | elseif tag == 'predict' then 133 | -- Save final predictions 134 | predFile = hdf5.open(opt.save .. '/preds.h5', 'w') 135 | --predFile:write('heatmaps', predHMs) 136 | if postprocess then predFile:write('preds', preds) end 137 | predFile:close() 138 | end 139 | end 140 | 141 | function train() step('train') end 142 | function valid() step('valid') end 143 | function predict() step('predict') end 144 | -------------------------------------------------------------------------------- /src/data.lua: -------------------------------------------------------------------------------- 1 | -- Manage HDF5 files 2 | useHDF5 = {} 3 | files = {} 4 | 5 | for _,l in ipairs({'train','valid'}) do 6 | useHDF5[l] = false 7 | 8 | if #opt[l .. 'File'] > 0 then 9 | useHDF5[l] = true 10 | files[l] = hdf5.open(opt.dataDir .. '/' .. opt[l .. 'File'] .. '.h5', 'r') 11 | local dataSize = torch.Tensor(files[l]:read('data'):dataspaceSize()) 12 | ref[l]['nsamples'] = dataSize[1] 13 | 14 | -- Use the hdf5 file to determine data/label sizes 15 | -- These could be anything, not just images/heatmaps 16 | dataDim = torch.totable(dataSize:sub(2,-1)) 17 | local labelSize = torch.Tensor(files[l]:read('label'):dataspaceSize()) 18 | if labelSize:numel() > 1 then labelDim = torch.totable(labelSize:sub(2,-1)) 19 | else labelDim = {} end 20 | 21 | print("HDF5 file provided (" .. l .. "): " .. opt[l .. "File"]) 22 | print("Number of samples: " .. ref[l]['nsamples']) 23 | end 24 | end 25 | 26 | -- More legible way to print out tensor dimensions 27 | local function print_dims(prefix,d) 28 | local s = "" 29 | if #d == 0 then s = "single value" 30 | elseif #d == 1 then s = string.format("vector of length: %d", d[1]) 31 | else 32 | s = string.format("tensor with dimensions: %d", d[1]) 33 | for i = 2,table.getn(d) do s = s .. string.format(" x %d", d[i]) end 34 | end 35 | print(prefix .. s) 36 | end 37 | 38 | function loadData(set, idx, batchsize) 39 | -- Load in a mini-batch of data 40 | local input,label 41 | 42 | -- Read data from a provided hdf5 file 43 | if useHDF5[set] then 44 | idx = idx or torch.random(annot[set]['nsamples'] - batchsize) 45 | local inp_dims = {{idx,idx+batchsize-1}} 46 | for i = 1,#dataDim do inp_dims[i+1] = {1,dataDim[i]} end 47 | local label_dims = {{idx,idx+batchsize-1}} 48 | for i = 1,#labelDim do label_dims[i+1] = {1,labelDim[i]} end 49 | 50 | input = files[set]:read('data'):partial(unpack(inp_dims)) 51 | label = files[set]:read('label'):partial(unpack(label_dims)) 52 | 53 | if opt.inputRes ~= dataDim[2] or opt.outputRes ~= labelDim[2] then 54 | -- Data is a fixed size coming from the hdf5 file, so this allows us to resize it 55 | input = image.scale(input:view(batchsize*dataDim[1],dataDim[2],dataDim[3]),opt.inputRes) 56 | input = input:view(batchsize,dataDim[1],opt.inputRes,opt.inputRes) 57 | label = image.scale(label:view(batchsize*labelDim[1],labelDim[2],labelDim[3]),opt.outputRes) 58 | label = label:view(batchsize,labelDim[1],opt.outputRes,opt.outputRes) 59 | end 60 | 61 | -- Or generate a new sample 62 | else 63 | input = torch.Tensor(batchsize, unpack(dataDim)) 64 | label = {} 65 | local labelTemp 66 | for i = 1, opt.nStack do 67 | table.insert(label,torch.Tensor(batchsize, opt.resZ[i]*labelDim[1], labelDim[2], labelDim[3])) 68 | end 69 | for i = 1, batchsize do 70 | idx_ = idx or torch.random(annot[set]['nsamples']) 71 | idx_ = (idx_ + i - 2) % annot[set]['nsamples'] + 1 72 | input[i],labelTemp = generateSample(set, idx_) 73 | for j = 1, opt.nStack do 74 | label[j][i] = labelTemp[j] 75 | end 76 | end 77 | end 78 | 79 | if input:max() > 2 then 80 | input:div(255) 81 | end 82 | 83 | -- Augment data (during training only) 84 | if not isTesting then 85 | local s = torch.randn(batchsize):mul(opt.scaleFactor):add(1):clamp(1-opt.scaleFactor,1+opt.scaleFactor) 86 | local r = torch.randn(batchsize):mul(opt.rotFactor):clamp(-2*opt.rotFactor,2*opt.rotFactor) 87 | 88 | for i = 1, batchsize do 89 | -- Color 90 | input[{i, 1, {}, {}}]:mul(torch.uniform(0.8, 1.2)):clamp(0, 1) 91 | input[{i, 2, {}, {}}]:mul(torch.uniform(0.8, 1.2)):clamp(0, 1) 92 | input[{i, 3, {}, {}}]:mul(torch.uniform(0.8, 1.2)):clamp(0, 1) 93 | 94 | -- Scale/rotation 95 | if torch.uniform() <= .6 then r[i] = 0 end 96 | local inp,out = opt.inputRes, opt.outputRes 97 | input[i] = crop(input[i], {(inp+1)/2,(inp+1)/2}, inp*s[i]/200, r[i], inp) 98 | for j = 1, opt.nStack do 99 | label[j][i] = crop(label[j][i], {(out+1)/2,(out+1)/2}, out*s[i]/200, r[i], out) 100 | end 101 | end 102 | 103 | -- Flip 104 | local flip_ = customFlip or flip 105 | local shuffleLR_ = customShuffleLR or shuffleLR 106 | if torch.uniform() <= .5 then 107 | input = flip_(input) 108 | for i = 1, opt.nStack do 109 | label[i] = flip_(shuffleLR_(label[i],matchedParts3D[i])) 110 | end 111 | end 112 | end 113 | 114 | -- Do task-specific preprocessing 115 | if preprocess then input,label = preprocess(input,label,batchsize,set,idx) end 116 | 117 | return input, label 118 | end 119 | 120 | -- Check data preprocessing if there is any 121 | if not alreadyChecked then 122 | 123 | if preprocess then 124 | print_dims("Original input is a ", dataDim) 125 | print_dims("Original output is a ", labelDim) 126 | print("After preprocessing ---") 127 | local temp_input,temp_label = loadData('train',1,1) 128 | -- Input 129 | if type(temp_input) == "table" then 130 | inputDim = {} 131 | print("Input is a table of %d values" % table.getn(temp_input)) 132 | for i = 1,#temp_input do 133 | inputDim[i] = torch.totable(temp_input[i][1]:size()) 134 | print_dims("Input %d is a "%i, inputDim[i]) 135 | end 136 | else 137 | inputDim = torch.totable(temp_input[1]:size()) 138 | print_dims("Input is a ", inputDim) 139 | end 140 | 141 | -- Output 142 | if type(temp_label) == "table" then 143 | outputDim = {} 144 | print("Output is a table of %d values" % #temp_label) 145 | for i = 1,#temp_label do 146 | outputDim[i] = torch.totable(temp_label[i][1]:size()) 147 | print_dims("Output %d is a "%i, outputDim[i]) 148 | end 149 | else 150 | outputDim = torch.totable(temp_label[1]:size()) 151 | print_dims("Output is a ", outputDim) 152 | end 153 | else 154 | inputDim = dataDim 155 | outputDim = labelDim 156 | print_dims("Input is a ", inputDim) 157 | print_dims("Output is a ", outputDim) 158 | end 159 | 160 | end 161 | -------------------------------------------------------------------------------- /src/misc/pck_figs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from scipy.io import loadmat 3 | import numpy as np 4 | import h5py 5 | 6 | def setuppdjplot(ax, i): 7 | # Configuration of ticks in plots 8 | major_ticks_y = np.arange(0,1.01,.2) 9 | minor_ticks_y = np.arange(0,1.01,.1) 10 | major_ticks_x = np.arange(0,.21,.1) 11 | minor_ticks_x = np.arange(0,.21,.05) 12 | ax.set_yticks(major_ticks_y) 13 | ax.set_yticks(minor_ticks_y, minor=True) 14 | ax.set_xticks(major_ticks_x) 15 | ax.set_xticks(minor_ticks_x, minor=True) 16 | for tick in ax.yaxis.get_major_ticks(): 17 | # tick.label.set_fontsize(8) 18 | if i == 0: 19 | tick.label.set_visible(False) 20 | # for tick in ax.yaxis.get_major_ticks(): 21 | # tick.label.set_fontsize(8) 22 | ax.grid() 23 | ax.grid(which='minor', alpha=0.5) 24 | 25 | def plottraintest(ax, train_log, test_log, title='Loss'): 26 | idx = [0,2] 27 | for tick in ax.xaxis.get_major_ticks(): 28 | tick.label.set_fontsize(8) 29 | for tick in ax.yaxis.get_major_ticks(): 30 | tick.label.set_fontsize(8) 31 | 32 | t = (np.arange(len(train_log[1:,idx[0]])) + 1) 33 | ax.plot(t, train_log[1:,idx[0]], label='Train', color='k') 34 | t = (np.arange(len(test_log[1:,idx[1]])) + 1) 35 | ax.plot(t, test_log[1:,idx[1]], label='Test', color='r') 36 | # ax.set_ylim(0,1) 37 | ax.legend(loc='upper right', fontsize=10) 38 | ax.set_title('Training/Test %s'%title, fontsize=10) 39 | 40 | def loadpreds(predfile, obs): 41 | with h5py.File(predfile, 'r') as f: 42 | preds = np.array(f['preds_tf']) 43 | dist_key = 'dist_' 44 | if obs: dist_key += 'o' 45 | dists = np.array(f[dist_key]) 46 | return preds, dists 47 | 48 | def getaccuracy(arr, thresh, no_zero=True, filt=None): 49 | # Returns number of elements in arr that fall below the given threshold 50 | # filt should be a binary array the same size as arr 51 | if filt is None: 52 | # If no filter has been provided create entirely true array 53 | filt = np.array([True for _ in xrange(len(arr))]) 54 | else: 55 | filt = filt.copy() 56 | 57 | if no_zero: 58 | filt *= (arr > 0) 59 | 60 | return float(sum(arr[filt] <= thresh)) / filt.sum() 61 | 62 | def pdjdata(dataset, dists, partidx, rng=None, filt=None): 63 | # Return data for creating a PDJ plot 64 | # Returns the average curve for the parts provided 65 | 66 | if rng is None: 67 | # If no range is provided use the default ranges for flic and mpii 68 | if dataset == 'flic': 69 | rng = [0, .21, .01] 70 | elif dataset == 'mpii': 71 | rng = [0, .51, .01] 72 | 73 | t = np.arange(rng[0],rng[1],rng[2]) 74 | pdj = np.zeros(len(t)) 75 | 76 | for i in xrange(len(t)): 77 | for j in partidx: 78 | pdj[i] += getaccuracy(dists[:, j], t[i], filt=filt) 79 | 80 | pdj /= len(partidx) # Average across all chosen parts 81 | return pdj, t 82 | 83 | # ============================================================================= 84 | # MPII Figures 85 | # ============================================================================= 86 | 87 | dirnames = ['carreira15arxiv','pishchulin15arxiv', 88 | 'tompson14nips','tompson15cvpr', 89 | 'wei16arxiv','our_model'] 90 | 91 | results = [] 92 | 93 | for d in dirnames: 94 | results += [loadmat('mpii_results/'+d+'/pckAll.mat')] 95 | ''' 96 | plt.plot(results[0]['range'].T,results[0]['pck'][:,0]) 97 | plt.plot(results[0]['range'].T,results[1]['pck'][:,0]) 98 | plt.plot(results[0]['range'].T,results[2]['pck'][:,0]) 99 | plt.plot(results[0]['range'].T,results[3]['pck'][:,0]) 100 | plt.plot(results[0]['range'].T,results[4]['pck'][:,0]) 101 | plt.plot(results[0]['range'].T,results[5]['pck'][:,0]) 102 | plt.show() 103 | ''' 104 | 105 | # ============================================================================= 106 | # FLIC Figures 107 | # ============================================================================= 108 | 109 | # flic_wri = loadmat('flic_results_wrist') 110 | # flic_wri = flic_wri['new_data'] 111 | # flic_elb = loadmat('flic_results_elbow') 112 | # flic_elb = flic_elb['new_data'] 113 | # nyu_dists = np.load('/home/-/posenet/data/flic/ref/nyu_dists_flic_obs.npy') 114 | # _,our_dists = loadpreds('/home/-/posenet/exp/flic/base/preds.h5',True) 115 | 116 | # # Plot elbow results 117 | # nyu_elb,t = pdjdata('flic',nyu_dists,[1,4]) 118 | # our_elb,_ = pdjdata('flic',our_dists,[1,4]) 119 | # f = plt.figure(facecolor='w') 120 | # ax = f.add_subplot(1,2,2) 121 | # lines = ax.plot(t,our_elb) 122 | # plt.setp(lines,linewidth=2) 123 | # lines = ax.plot(t,nyu_elb) 124 | # plt.setp(lines,linewidth=2) 125 | # lines = ax.plot(t,flic_elb[[0,2],:21].T) 126 | # setuppdjplot(ax, 0) 127 | # ax.set_title('Elbow',fontsize=24) 128 | # ax.set_ylabel('Detection Rate (%)',fontsize=22) 129 | # ax.set_xlabel('Normalized Distance',fontsize=22) 130 | # slabels = ['Ours','Tompson et al.','Chen et al.','Toshev et al.'] 131 | # plt.setp(lines,linewidth=2) 132 | # ax.legend(loc=4,labels=labels, fontsize=16) 133 | 134 | # # Plot elbow results 135 | # nyu_wri,t = pdjdata('flic',nyu_dists,[2,5]) 136 | # our_wri,_ = pdjdata('flic',our_dists,[2,5]) 137 | # ax = f.add_subplot(1,2,1) 138 | # lines = ax.plot(t,our_wri) 139 | # plt.setp(lines,linewidth=2) 140 | # lines = ax.plot(t,nyu_wri) 141 | # plt.setp(lines,linewidth=2) 142 | # lines = ax.plot(t,flic_wri[[0,2],:21].T) 143 | # setuppdjplot(ax, 1) 144 | # ax.set_ylabel('Detection Rate (%)',fontsize=22) 145 | # ax.set_xlabel('Normalized Distance',fontsize=22) 146 | # ax.set_title('Wrist',fontsize=24) 147 | # plt.setp(lines,linewidth=2) 148 | 149 | # ax2 = f.add_subplot(1,1,1) 150 | # ax2.axis('off') 151 | # ax2.set_xlabel('Normalized Distance') 152 | # t = f.suptitle('FLIC Results',fontsize=32) 153 | # plt.show() 154 | 155 | 156 | # ============================================================================= 157 | # Loss comparisons 158 | # ============================================================================= 159 | 160 | exps = ['307-DR', '304-hg-D', 'hg-1_s2_b1', 'F28-hg-I-2', '304-hg-IA'] 161 | 162 | logs = [] 163 | rounds = [] 164 | for exp in exps: 165 | log = np.loadtxt('log_data/mpii/'+exp+'/test.log',skiprows=1) 166 | rounds += [log[:,1]] 167 | logs += [log[:,0]] 168 | 169 | f = plt.figure(facecolor='w') 170 | ax = f.add_subplot(111) 171 | f.suptitle('Validation Accuracy Across Training',fontsize=32) 172 | ax.set_ylabel('Average Accuracy (%)',fontsize=24) 173 | ax.set_xlabel('Training Iterations (x2000)',fontsize=24) 174 | for i in xrange(len(exps)): 175 | ln = ax.plot(rounds[i],logs[i]) 176 | plt.setp(ln,linewidth=1.5) 177 | 178 | labels = ['HG','HG-Int','HG-Stacked','HG-Stacked-Int','HG-Stacked-Add'] 179 | ax.legend(loc=4,labels=labels,fontsize=16) 180 | ax.set_xlim(0,100) 181 | ax.set_ylim(0,1) 182 | 183 | plt.show() -------------------------------------------------------------------------------- /src/ref.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------- 2 | -- Load necessary libraries and files 3 | ------------------------------------------------------------------------------- 4 | 5 | require 'torch' 6 | require 'xlua' 7 | require 'optim' 8 | require 'nn' 9 | require 'nnx' 10 | require 'nngraph' 11 | require 'hdf5' 12 | require 'string' 13 | require 'image' 14 | 15 | paths.dofile('util/img.lua') 16 | paths.dofile('util/eval.lua') 17 | paths.dofile('util/Logger.lua') 18 | 19 | torch.setdefaulttensortype('torch.FloatTensor') 20 | 21 | -- Project directory 22 | projectDir = paths.concat(os.getenv('HOME'),'c2f-vol-train') 23 | 24 | ------------------------------------------------------------------------------- 25 | -- Process command line options 26 | ------------------------------------------------------------------------------- 27 | 28 | if not opt then 29 | 30 | local opts = paths.dofile('opts.lua') 31 | opt = opts.parse(arg) 32 | 33 | print('Saving everything to: ' .. opt.save) 34 | os.execute('mkdir -p ' .. opt.save) 35 | 36 | if opt.GPU == -1 then 37 | nnlib = nn 38 | else 39 | require 'cutorch' 40 | require 'cunn' 41 | require 'cudnn' 42 | nnlib = cudnn 43 | cutorch.setDevice(opt.GPU) 44 | end 45 | 46 | if opt.branch ~= 'none' or opt.continue then 47 | -- Continuing training from a prior experiment 48 | -- Figure out which new options have been set 49 | local setOpts = {} 50 | for i = 1,#arg do 51 | if arg[i]:sub(1,1) == '-' then table.insert(setOpts,arg[i]:sub(2,-1)) end 52 | end 53 | 54 | -- Where to load the previous options/model from 55 | if opt.branch ~= 'none' then opt.load = opt.expDir .. '/' .. opt.branch 56 | else opt.load = opt.expDir .. '/' .. opt.expID end 57 | 58 | -- Keep previous options, except those that were manually set 59 | local opt_ = opt 60 | opt = torch.load(opt_.load .. '/options.t7') 61 | opt.save = opt_.save 62 | opt.load = opt_.load 63 | opt.continue = opt_.continue 64 | for i = 1,#setOpts do opt[setOpts[i]] = opt_[setOpts[i]] end 65 | 66 | epoch = opt.lastEpoch + 1 67 | 68 | -- If there's a previous optimState, load that too 69 | if paths.filep(opt.load .. '/optimState.t7') then 70 | optimState = torch.load(opt.load .. '/optimState.t7') 71 | optimState.learningRate = opt.LR 72 | end 73 | 74 | else epoch = 1 end 75 | opt.epochNumber = epoch 76 | 77 | -- Training hyperparameters 78 | -- (Some of these aren't relevant for rmsprop which is the optimization we use) 79 | if not optimState then 80 | optimState = { 81 | learningRate = opt.LR, 82 | learningRateDecay = opt.LRdecay, 83 | momentum = opt.momentum, 84 | dampening = 0.0, 85 | weightDecay = opt.weightDecay 86 | } 87 | end 88 | 89 | -- Optimization function 90 | optfn = optim[opt.optMethod] 91 | 92 | -- Random number seed 93 | if opt.manualSeed ~= -1 then torch.manualSeed(opt.manualSeed) 94 | else torch.seed() end 95 | 96 | -- Save options to experiment directory 97 | torch.save(opt.save .. '/options.t7', opt) 98 | 99 | end 100 | 101 | ------------------------------------------------------------------------------- 102 | -- Load in annotations 103 | ------------------------------------------------------------------------------- 104 | 105 | annotLabels = {'train', 'valid'} 106 | annot,ref = {},{} 107 | for _,l in ipairs(annotLabels) do 108 | local a, namesFile 109 | if opt.dataset == 'mpii' and l == 'valid' and opt.finalPredictions == 1 then 110 | a = hdf5.open(opt.dataDir .. '/annot/test.h5') 111 | namesFile = io.open(opt.dataDir .. '/annot/test_images.txt') 112 | else 113 | a = hdf5.open(opt.dataDir .. '/annot/' .. l .. '.h5') 114 | namesFile = io.open(opt.dataDir .. '/annot/' .. l .. '_images.txt') 115 | end 116 | annot[l] = {} 117 | 118 | -- Read in annotation information 119 | local tags = {'part', 'center', 'scale', 'zind'} 120 | for _,tag in ipairs(tags) do annot[l][tag] = a:read(tag):all() end 121 | annot[l]['nsamples'] = annot[l]['part']:size()[1] 122 | 123 | -- Load in image file names (reading strings wasn't working from hdf5) 124 | annot[l]['images'] = {} 125 | local toIdxs = {} 126 | local idx = 1 127 | for line in namesFile:lines() do 128 | annot[l]['images'][idx] = line 129 | if not toIdxs[line] then toIdxs[line] = {} end 130 | table.insert(toIdxs[line], idx) 131 | idx = idx + 1 132 | end 133 | namesFile:close() 134 | 135 | -- This allows us to reference multiple people who are in the same image 136 | annot[l]['imageToIdxs'] = toIdxs 137 | 138 | -- Set up reference for training parameters 139 | ref[l] = {} 140 | ref[l].nsamples = annot[l]['nsamples'] 141 | ref[l].iters = opt[l .. 'Iters'] 142 | ref[l].batchsize = opt[l .. 'Batch'] 143 | ref[l].log = Logger(paths.concat(opt.save, l .. '.log'), opt.continue) 144 | end 145 | 146 | ref.predict = {} 147 | ref.predict.nsamples = annot.valid.nsamples 148 | ref.predict.iters = annot.valid.nsamples 149 | ref.predict.batchsize = 1 150 | 151 | -- Default input is assumed to be an image and output is assumed to be a heatmap 152 | -- This can change if an hdf5 file is used, or if opt.task specifies something different 153 | nParts = annot['train']['part']:size(2) 154 | dataDim = {3, opt.inputRes, opt.inputRes} 155 | labelDim = {nParts, opt.outputRes, opt.outputRes} 156 | 157 | -- Load up task specific variables/functions 158 | -- (this allows a decent amount of flexibility in network input/output and training) 159 | paths.dofile('util/' .. opt.task .. '.lua') 160 | 161 | local matchedParts 162 | if opt.dataset == 'mpii' then 163 | matchedParts = { 164 | {1,6}, {2,5}, {3,4}, 165 | {11,16}, {12,15}, {13,14} 166 | } 167 | elseif opt.dataset == 'flic' then 168 | matchedParts = { 169 | {1,4}, {2,5}, {3,6}, {7,8}, {9,10} 170 | } 171 | elseif opt.dataset == 'lsp' then 172 | matchedParts = { 173 | {1,6}, {2,5}, {3,4}, {7,12}, {8,11}, {9,10} 174 | } 175 | elseif opt.dataset == 'h36m' then 176 | matchedParts = { 177 | {2,5}, {3,6}, {4,7}, {12,15}, {13,16}, {14,17} 178 | } 179 | end 180 | 181 | matchedParts3D = {} 182 | for i = 1, opt.nStack do 183 | local matchTemp = {} 184 | for j = 1,#matchedParts do 185 | for k = 1,opt.resZ[i] do 186 | table.insert(matchTemp,{(matchedParts[j][1]-1)*opt.resZ[i]+k,(matchedParts[j][2]-1)*opt.resZ[i]+k}) 187 | end 188 | end 189 | matchedParts3D[i] = matchTemp 190 | end 191 | 192 | function applyFn(fn, t, t2) 193 | -- Helper function for applying an operation whether passed a table or tensor 194 | local t_ = {} 195 | if type(t) == "table" then 196 | if t2 then 197 | for i = 1,#t do t_[i] = applyFn(fn, t[i], t2[i]) end 198 | else 199 | for i = 1,#t do t_[i] = applyFn(fn, t[i]) end 200 | end 201 | else t_ = fn(t, t2) end 202 | return t_ 203 | end 204 | -------------------------------------------------------------------------------- /src/pypose/img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import scipy.signal 4 | import math 5 | 6 | import draw 7 | import ref 8 | 9 | # ============================================================================= 10 | # General image processing functions 11 | # ============================================================================= 12 | 13 | def get_transform(center, scale, res, rot=0): 14 | # Generate transformation matrix 15 | h = 200 * scale 16 | t = np.zeros((3, 3)) 17 | t[0, 0] = float(res[1]) / h 18 | t[1, 1] = float(res[0]) / h 19 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 20 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 21 | t[2, 2] = 1 22 | if not rot == 0: 23 | rot = -rot # To match direction of rotation from cropping 24 | rot_mat = np.zeros((3,3)) 25 | rot_rad = rot * np.pi / 180 26 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 27 | rot_mat[0,:2] = [cs, -sn] 28 | rot_mat[1,:2] = [sn, cs] 29 | rot_mat[2,2] = 1 30 | # Need to rotate around center 31 | t_mat = np.eye(3) 32 | t_mat[0,2] = -res[1]/2 33 | t_mat[1,2] = -res[0]/2 34 | t_inv = t_mat.copy() 35 | t_inv[:2,2] *= -1 36 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) 37 | return t 38 | 39 | def transform(pt, center, scale, res, invert=0, rot=0): 40 | # Transform pixel location to different reference 41 | t = get_transform(center, scale, res, rot=rot) 42 | if invert: 43 | t = np.linalg.inv(t) 44 | new_pt = np.array([pt[0], pt[1], 1.]).T 45 | new_pt = np.dot(t, new_pt) 46 | return new_pt[:2].astype(int) 47 | 48 | def crop(img, center, scale, res, rot=0): 49 | # Upper left point 50 | ul = np.array(transform([0, 0], center, scale, res, invert=1)) 51 | # Bottom right point 52 | br = np.array(transform(res, center, scale, res, invert=1)) 53 | 54 | # Padding so that when rotated proper amount of context is included 55 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) 56 | if not rot == 0: 57 | ul -= pad 58 | br += pad 59 | 60 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 61 | if len(img.shape) > 2: 62 | new_shape += [img.shape[2]] 63 | new_img = np.zeros(new_shape) 64 | 65 | # Range to fill new array 66 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 67 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 68 | # Range to sample from original image 69 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 70 | old_y = max(0, ul[1]), min(len(img), br[1]) 71 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] 72 | 73 | if not rot == 0: 74 | # Remove padding 75 | new_img = scipy.misc.imrotate(new_img, rot) 76 | new_img = new_img[pad:-pad, pad:-pad] 77 | 78 | return scipy.misc.imresize(new_img, res) 79 | 80 | def two_pt_crop(img, scale, pt1, pt2, pad, res, chg=None): 81 | center = (pt1+pt2) / 2 82 | scale = max(20*scale, np.linalg.norm(pt1-pt2)) * .007 83 | scale *= pad 84 | angle = math.atan2(pt2[1]-pt1[1],pt2[0]-pt1[0]) * 180 / math.pi - 90 85 | flip = False 86 | 87 | # Handle data augmentation 88 | if chg is not None: 89 | # Flipping 90 | if 'flip' in chg: 91 | if np.random.rand() < .5: 92 | flip = True 93 | # Scaling 94 | if 'scale' in chg: 95 | scale *= min(1+chg['scale'], max(1-chg['scale'], (np.random.randn() * chg['scale']) + 1)) 96 | # Rotation 97 | if 'rotate' in chg: 98 | angle += np.random.randint(-chg['rotate'], chg['rotate'] + 1) 99 | # Translation 100 | if 'translate' in chg: 101 | for i in xrange(2): 102 | offset = np.random.randint(-chg['translate'], chg['translate'] + 1) * scale 103 | center[i] += offset 104 | 105 | # Create input image 106 | cropped = crop(img, center, scale, res, rot=angle) 107 | inp = np.zeros((3, res[0], res[1])) 108 | for i in xrange(3): 109 | inp[i, :, :] = cropped[:, :, i] 110 | 111 | # Create heatmap 112 | hm = np.zeros((2,res[0],res[1])) 113 | draw.gaussian(hm[0],transform(pt1, center, scale, res, rot=angle), 2) 114 | draw.gaussian(hm[1],transform(pt2, center, scale, res, rot=angle), 2) 115 | 116 | if flip: 117 | inp = np.array([np.fliplr(inp[i]) for i in xrange(len(inp))]) 118 | hm = np.array([np.fliplr(hm[i]) for i in xrange(len(hm))]) 119 | 120 | return inp, hm 121 | 122 | def nms(img): 123 | # Do non-maximum suppression on a 2D array 124 | win_size = 3 125 | domain = np.ones((win_size, win_size)) 126 | maxes = scipy.signal.order_filter(img, domain, win_size ** 2 - 1) 127 | diff = maxes - img 128 | result = img.copy() 129 | result[diff > 0] = 0 130 | return result 131 | 132 | # ============================================================================= 133 | # Helpful display functions 134 | # ============================================================================= 135 | 136 | def gauss(x, a, b, c, d=0): 137 | return a * np.exp(-(x - b)**2 / (2 * c**2)) + d 138 | 139 | def color_heatmap(x): 140 | color = np.zeros((x.shape[0],x.shape[1],3)) 141 | color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3) 142 | color[:,:,1] = gauss(x, 1, .5, .3) 143 | color[:,:,2] = gauss(x, 1, .2, .3) 144 | color[color > 1] = 1 145 | color = (color * 255).astype(np.uint8) 146 | return color 147 | 148 | def sample_with_heatmap(dataset, inp, out, num_rows=2, parts_to_show=None): 149 | img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0])) 150 | for i in xrange(3): 151 | img[:, :, i] = inp[i, :, :] 152 | 153 | if parts_to_show is None: 154 | parts_to_show = np.arange(out.shape[0]) 155 | 156 | # Generate a single image to display input/output pair 157 | num_cols = np.ceil(float(len(parts_to_show)) / num_rows) 158 | size = img.shape[0] / num_rows 159 | 160 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8) 161 | full_img[:img.shape[0], :img.shape[1]] = img 162 | 163 | inp_small = scipy.misc.imresize(img, [size, size]) 164 | 165 | # Set up heatmap display for each part 166 | for i, part in enumerate(parts_to_show): 167 | if type(part) is str: 168 | part_idx = ref.parts[dataset].index(part) 169 | else: 170 | part_idx = part 171 | out_resized = scipy.misc.imresize(out[part_idx], [size, size]) 172 | out_resized = out_resized.astype(float)/255 173 | out_img = inp_small.copy() * .3 174 | color_hm = color_heatmap(out_resized) 175 | out_img += color_hm * .7 176 | 177 | col_offset = (i % num_cols + num_rows) * size 178 | row_offset = (i // num_cols) * size 179 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img 180 | 181 | return full_img 182 | 183 | def sample_with_skeleton(annot, idx, preds, res=None): 184 | 185 | # Load image and basic info 186 | ds = annot.attrs['name'] 187 | img = ref.loadimg(annot, idx) 188 | c = annot['center'][idx] 189 | s = annot['scale'][idx] 190 | if res is None: 191 | res = [256, 256] 192 | 193 | # Skeleton colors 194 | colors = [(255, 0, 0), # Upper arm (left) 195 | (255, 100, 100), # Lower arm (left) 196 | (0, 0, 255), # Upper arm (right) 197 | (100, 100, 255), # Lower arm (right) 198 | (100, 255, 100), # Head/neck/face 199 | (255, 75, 0), # Upper leg (left) 200 | (255, 175, 100), # Lower leg (left) 201 | (0, 75, 255), # Upper leg (right) 202 | (100, 175, 255) # Lower leg (right) 203 | ] 204 | 205 | # Draw arms 206 | draw.limb(img, preds[ref.parts[ds].index('lsho')], preds[ref.parts[ds].index('lelb')], colors[0], 5 * s) 207 | draw.limb(img, preds[ref.parts[ds].index('lwri')], preds[ref.parts[ds].index('lelb')], colors[1], 5 * s) 208 | draw.limb(img, preds[ref.parts[ds].index('rsho')], preds[ref.parts[ds].index('relb')], colors[2], 5 * s) 209 | draw.limb(img, preds[ref.parts[ds].index('rwri')], preds[ref.parts[ds].index('relb')], colors[3], 5 * s) 210 | 211 | if ds == 'mpii': 212 | # MPII 213 | # Draw head 214 | draw.circle(img, preds[ref.parts[ds].index('head')], colors[4], 5 * s) 215 | draw.circle(img, preds[ref.parts[ds].index('neck')], colors[4], 5 * s) 216 | 217 | # Draw legs 218 | draw.limb(img, preds[ref.parts[ds].index('lhip')], preds[ref.parts[ds].index('lkne')], colors[5], 5 * s) 219 | draw.limb(img, preds[ref.parts[ds].index('lank')], preds[ref.parts[ds].index('lkne')], colors[6], 5 * s) 220 | draw.limb(img, preds[ref.parts[ds].index('rhip')], preds[ref.parts[ds].index('rkne')], colors[7], 5 * s) 221 | draw.limb(img, preds[ref.parts[ds].index('rank')], preds[ref.parts[ds].index('rkne')], colors[8], 5 * s) 222 | 223 | elif ds == 'flic': 224 | # FLIC 225 | # Draw face 226 | draw.circle(img, preds[ref.parts[ds].index('leye')], colors[4], 3 * s) 227 | draw.circle(img, preds[ref.parts[ds].index('reye')], colors[4], 3 * s) 228 | draw.circle(img, preds[ref.parts[ds].index('nose')], colors[4], 3 * s) 229 | 230 | # Draw hips 231 | draw.circle(img, preds[ref.parts[ds].index('lhip')], colors[5], 5 * s) 232 | draw.circle(img, preds[ref.parts[ds].index('rhip')], colors[7], 5 * s) 233 | 234 | return crop(img, c, s, res) 235 | -------------------------------------------------------------------------------- /src/util/img.lua: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------ 2 | -- Coordinate transformation 3 | ------------------------------------------------------------------------------- 4 | 5 | function getTransform(center, scale, rot, res) 6 | local h = 200 * scale 7 | local t = torch.eye(3) 8 | 9 | -- Scaling 10 | t[1][1] = res / h 11 | t[2][2] = res / h 12 | 13 | -- Translation 14 | t[1][3] = res * (-center[1] / h + .5) 15 | t[2][3] = res * (-center[2] / h + .5) 16 | 17 | -- Rotation 18 | if rot ~= 0 then 19 | rot = -rot 20 | local r = torch.eye(3) 21 | local ang = rot * math.pi / 180 22 | local s = math.sin(ang) 23 | local c = math.cos(ang) 24 | r[1][1] = c 25 | r[1][2] = -s 26 | r[2][1] = s 27 | r[2][2] = c 28 | -- Need to make sure rotation is around center 29 | local t_ = torch.eye(3) 30 | t_[1][3] = -res/2 31 | t_[2][3] = -res/2 32 | local t_inv = torch.eye(3) 33 | t_inv[1][3] = res/2 34 | t_inv[2][3] = res/2 35 | t = t_inv * r * t_ * t 36 | end 37 | 38 | return t 39 | end 40 | 41 | function transform(pt, center, scale, rot, res, invert) 42 | local pt_ = torch.ones(3) 43 | pt_[1],pt_[2] = pt[1]-1,pt[2]-1 44 | 45 | local t = getTransform(center, scale, rot, res) 46 | if invert then 47 | t = torch.inverse(t) 48 | end 49 | local new_point = (t*pt_):sub(1,2):add(1e-4) 50 | 51 | return new_point:int():add(1) 52 | end 53 | 54 | ------------------------------------------------------------------------------- 55 | -- Cropping 56 | ------------------------------------------------------------------------------- 57 | 58 | function crop(img, center, scale, rot, res) 59 | local ul = transform({1,1}, center, scale, 0, res, true) 60 | local br = transform({res+1,res+1}, center, scale, 0, res, true) 61 | 62 | 63 | local pad = math.floor(torch.norm((ul - br):float())/2 - (br[1]-ul[1])/2) 64 | if rot ~= 0 then 65 | ul = ul - pad 66 | br = br + pad 67 | end 68 | 69 | local newDim,newImg,ht,wd 70 | 71 | if img:size():size() > 2 then 72 | newDim = torch.IntTensor({img:size(1), br[2] - ul[2], br[1] - ul[1]}) 73 | newImg = torch.zeros(newDim[1],newDim[2],newDim[3]) 74 | ht = img:size(2) 75 | wd = img:size(3) 76 | else 77 | newDim = torch.IntTensor({br[2] - ul[2], br[1] - ul[1]}) 78 | newImg = torch.zeros(newDim[1],newDim[2]) 79 | ht = img:size(1) 80 | wd = img:size(2) 81 | end 82 | 83 | local newX = torch.Tensor({math.max(1, -ul[1] + 2), math.min(br[1], wd+1) - ul[1]}) 84 | local newY = torch.Tensor({math.max(1, -ul[2] + 2), math.min(br[2], ht+1) - ul[2]}) 85 | local oldX = torch.Tensor({math.max(1, ul[1]), math.min(br[1], wd+1) - 1}) 86 | local oldY = torch.Tensor({math.max(1, ul[2]), math.min(br[2], ht+1) - 1}) 87 | 88 | if newDim:size(1) > 2 then 89 | newImg:sub(1,newDim[1],newY[1],newY[2],newX[1],newX[2]):copy(img:sub(1,newDim[1],oldY[1],oldY[2],oldX[1],oldX[2])) 90 | else 91 | newImg:sub(newY[1],newY[2],newX[1],newX[2]):copy(img:sub(oldY[1],oldY[2],oldX[1],oldX[2])) 92 | end 93 | 94 | if rot ~= 0 then 95 | newImg = image.rotate(newImg, rot * math.pi / 180, 'bilinear') 96 | if newDim:size(1) > 2 then 97 | newImg = newImg:sub(1,newDim[1],pad,newDim[2]-pad,pad,newDim[3]-pad) 98 | else 99 | newImg = newImg:sub(pad,newDim[1]-pad,pad,newDim[2]-pad) 100 | end 101 | end 102 | 103 | newImg = image.scale(newImg,res,res) 104 | return newImg 105 | end 106 | 107 | function two_point_crop(img, s, pt1, pt2, pad, res) 108 | local center = (pt1 + pt2) / 2 109 | local scale = math.max(20*s,torch.norm(pt1 - pt2)) * .007 110 | scale = scale * pad 111 | local angle = math.atan2(pt2[2]-pt1[2],pt2[1]-pt1[1]) * 180 / math.pi - 90 112 | return crop(img, center, scale, angle, res) 113 | end 114 | 115 | ------------------------------------------------------------------------------- 116 | -- Non-maximum Suppression 117 | ------------------------------------------------------------------------------- 118 | 119 | -- Set up max network for NMS 120 | nms_window_size = 3 121 | nms_pad = (nms_window_size - 1)/2 122 | maxlayer = nn.Sequential() 123 | if cudnn then 124 | maxlayer:add(cudnn.SpatialMaxPooling(nms_window_size, nms_window_size,1,1, nms_pad, nms_pad)) 125 | maxlayer:cuda() 126 | else 127 | maxlayer:add(nn.SpatialMaxPooling(nms_window_size, nms_window_size,1,1, nms_pad,nms_pad)) 128 | end 129 | maxlayer:evaluate() 130 | 131 | function local_maxes(hm, n, c, s, hm_idx) 132 | hm = torch.Tensor(1,16,64,64):copy(hm):float() 133 | if hm_idx then hm = hm:sub(1,-1,hm_idx,hm_idx) end 134 | local hm_dim = hm:size() 135 | local max_out 136 | -- First do nms 137 | if cudnn then 138 | local hmCuda = torch.CudaTensor(1, hm_dim[2], hm_dim[3], hm_dim[4]) 139 | hmCuda:copy(hm) 140 | max_out = maxlayer:forward(hmCuda) 141 | cutorch.synchronize() 142 | else 143 | max_out = maxlayer:forward(hm) 144 | end 145 | 146 | local nms = torch.cmul(hm, torch.eq(hm, max_out:float()):float())[1] 147 | -- Loop through each heatmap retrieving top n locations, and their scores 148 | local pred_coords = torch.Tensor(hm_dim[2], n, 2) 149 | local pred_scores = torch.Tensor(hm_dim[2], n) 150 | for i = 1, hm_dim[2] do 151 | local nms_flat = nms[i]:view(nms[i]:nElement()) 152 | local vals,idxs = torch.sort(nms_flat,1,true) 153 | for j = 1,n do 154 | local pt = {idxs[j] % 64, torch.ceil(idxs[j] / 64) } 155 | pred_coords[i][j] = transform(pt, c, s, 0, 64, true) 156 | pred_scores[i][j] = vals[j] 157 | end 158 | end 159 | return pred_coords, pred_scores 160 | end 161 | 162 | ------------------------------------------------------------------------------- 163 | -- Draw gaussian 164 | ------------------------------------------------------------------------------- 165 | 166 | function drawGaussian(img, pt, sigma) 167 | -- Draw a 2D gaussian 168 | -- Check that any part of the gaussian is in-bounds 169 | local ul = {math.floor(pt[1] - 3 * sigma), math.floor(pt[2] - 3 * sigma)} 170 | local br = {math.floor(pt[1] + 3 * sigma), math.floor(pt[2] + 3 * sigma)} 171 | -- If not, return the image as is 172 | if (ul[1] > img:size(2) or ul[2] > img:size(1) or br[1] < 1 or br[2] < 1) then return img end 173 | -- Generate gaussian 174 | local size = 6 * sigma + 1 175 | local g = image.gaussian(size) -- , 1 / size, 1) 176 | -- Usable gaussian range 177 | local g_x = {math.max(1, -ul[1]), math.min(br[1], img:size(2)) - math.max(1, ul[1]) + math.max(1, -ul[1])} 178 | local g_y = {math.max(1, -ul[2]), math.min(br[2], img:size(1)) - math.max(1, ul[2]) + math.max(1, -ul[2])} 179 | -- Image range 180 | local img_x = {math.max(1, ul[1]), math.min(br[1], img:size(2))} 181 | local img_y = {math.max(1, ul[2]), math.min(br[2], img:size(1))} 182 | assert(g_x[1] > 0 and g_y[1] > 0) 183 | img:sub(img_y[1], img_y[2], img_x[1], img_x[2]):add(g:sub(g_y[1], g_y[2], g_x[1], g_x[2])) 184 | img[img:gt(1)] = 1 185 | return img 186 | end 187 | 188 | function drawGaussian3D(vol, pt, z, sigma_2d, size_z) 189 | 190 | local resZ = vol:size(1) 191 | local res2D = vol:size(2) 192 | local temp = torch.zeros(res2D,res2D) 193 | drawGaussian(temp, pt, sigma_2d) 194 | local zun = image.gaussian(size_z)[torch.ceil(size_z/2)] 195 | local count = 0 196 | for i = z-torch.floor(size_z/2),z+torch.floor(size_z/2) do 197 | count = count+1 198 | if i>0 and i 0 and x_idx -1 > 0 and y_idx < img:size(1) and x_idx < img:size(2) then 225 | img:sub(y_idx-1,y_idx,x_idx-1,x_idx):fill(val) 226 | end 227 | end 228 | end 229 | img[img:gt(1)] = 1 230 | end 231 | 232 | ------------------------------------------------------------------------------- 233 | -- Flipping functions 234 | ------------------------------------------------------------------------------- 235 | 236 | function shuffleLR(x,matchedParts) 237 | local dim 238 | if x:nDimension() == 4 then 239 | dim = 2 240 | else 241 | assert(x:nDimension() == 3) 242 | dim = 1 243 | end 244 | 245 | for i = 1,#matchedParts do 246 | local idx1, idx2 = unpack(matchedParts[i]) 247 | local tmp = x:narrow(dim, idx1, 1):clone() 248 | x:narrow(dim, idx1, 1):copy(x:narrow(dim, idx2, 1)) 249 | x:narrow(dim, idx2, 1):copy(tmp) 250 | end 251 | 252 | return x 253 | end 254 | 255 | function flip(x) 256 | require 'image' 257 | local y = torch.FloatTensor(x:size()) 258 | for i = 1, x:size(1) do 259 | image.hflip(y[i], x[i]:float()) 260 | end 261 | return y:typeAs(x) 262 | end 263 | -------------------------------------------------------------------------------- /src/models/layers/MRF.lua: -------------------------------------------------------------------------------- 1 | -- Designed only for loop-free models 2 | local MRF, parent = torch.class('MRF', 'nn.Module') 3 | 4 | function gen_pairwise_ref(conn) 5 | local pw_ref = {} 6 | local inv_ref = {} 7 | local num_nodes = conn:size()[1] 8 | local count = 1 9 | for i = 1,num_nodes do 10 | for j = i+1,num_nodes do 11 | if conn[i][j] == 1 then 12 | if pw_ref[i] == nil then pw_ref[i] = {} end 13 | if pw_ref[j] == nil then pw_ref[j] = {} end 14 | pw_ref[i][j] = count 15 | pw_ref[j][i] = -count 16 | inv_ref[count] = {i,j} 17 | count = count + 1 18 | end 19 | end 20 | end 21 | return pw_ref,inv_ref 22 | end 23 | 24 | function MRF:__init(connections, num_states) 25 | -- No error checking here, but connections should be 0 along the diagonal 26 | -- and connections should equal connections:transpose 27 | parent.__init(self) 28 | self.num_nodes = connections:size()[1] 29 | self.num_states = num_states 30 | self.num_edges = connections:sum() / 2 31 | self.conn = torch.Tensor(connections:size()):copy(connections) 32 | self.all_idxs = torch.linspace(1,self.num_nodes,self.num_nodes) 33 | self.pw_ref,_ = gen_pairwise_ref(connections) 34 | self.do_full_reset = true 35 | end 36 | 37 | function MRF:reset(input, batch_size) 38 | if self.do_full_reset then 39 | -- Allocate memory for everything once 40 | self.unary = torch.Tensor(batch_size, self.num_nodes, self.num_states) 41 | self.pairwise = torch.Tensor(batch_size, self.num_edges, self.num_states, self.num_states) 42 | self.msg = torch.Tensor(batch_size, self.num_nodes, self.num_nodes, 2, self.num_states) 43 | self.msg_flag = torch.zeros(batch_size, self.num_nodes, self.num_nodes) 44 | self.z = torch.Tensor(batch_size, self.num_nodes) 45 | self.temp_output = torch.Tensor(self.batch_size, self.num_nodes * self.num_states) 46 | self.temp_gradInput = torch.zeros(self.batch_size, input[1]:numel()) 47 | -- In case everything is else is on GPU 48 | self.output = torch.Tensor(self.batch_size, self.num_nodes * self.num_states):typeAs(input) 49 | self.gradInput = torch.zeros(self.batch_size, input[1]:numel()):typeAs(input) 50 | if self.out_vector then 51 | -- Extra stuff to handle receiving a single vector as input properly 52 | self.output = self.output:view(out:size(2)) 53 | self.gradInput = self.gradInput:view(self.gradInput:numel()) 54 | end 55 | else 56 | self.msg_flag:fill(0) 57 | end 58 | end 59 | 60 | function MRF:set_potentials(unary, pairwise) 61 | self.unary:copy(unary) 62 | self.pairwise:copy(pairwise) 63 | end 64 | 65 | function MRF:get_neighbors(node_idx, exclude) 66 | local idxs = self.all_idxs[self.conn[node_idx]:eq(1)] 67 | if exclude ~= nil then idxs = idxs[idxs:ne(exclude)] end 68 | return idxs 69 | end 70 | 71 | function MRF:get_message(sample_idx, node_idx, exclude) 72 | local nb = self:get_neighbors(node_idx, exclude) 73 | local msg = self.msg[sample_idx][node_idx] 74 | local msg_flag = self.msg_flag[sample_idx][node_idx] 75 | if nb:nDimension() == 0 then 76 | -- Leaf node, return unary scores 77 | return self.unary[sample_idx][node_idx] 78 | else 79 | -- Get message from all neighbors 80 | local out_msg = torch.ones(self.num_states):typeAs(self.unary) 81 | for i = 1,nb:numel() do 82 | if msg_flag[nb[i]] == 0 then 83 | -- Get pairwise scores 84 | local pw_idx = self.pw_ref[node_idx][nb[i]] 85 | local pw_vals = self.pairwise[sample_idx][math.abs(pw_idx)] 86 | if pw_idx < 0 then pw_vals = pw_vals:transpose(1,2) end 87 | -- Get message from neighbor 88 | msg[nb[i]][1] = self:get_message(sample_idx, nb[i], node_idx) 89 | msg[nb[i]][2] = pw_vals * msg[nb[i]][1] 90 | msg_flag[nb[i]] = 1 91 | end 92 | out_msg:cmul(msg[nb[i]][2]) 93 | end 94 | out_msg:cmul(self.unary[sample_idx][node_idx]) 95 | if exclude == nil then 96 | -- The only reason to keep z per node is in case a graph is provided that isn't connected 97 | -- It doesn't add much overhead, so it is easy enough to keep, instead of restricting to 98 | -- only connected graphs 99 | self.z[sample_idx][node_idx] = out_msg:sum() 100 | out_msg:div(self.z[sample_idx][node_idx]) 101 | end 102 | return out_msg 103 | end 104 | end 105 | 106 | function MRF:updateOutput(input) 107 | local num_in = self.num_states*(self.num_nodes + self.num_edges * self.num_states) 108 | -- If only presented with a 1D vector, change the view so it is 2D 109 | if input:nDimension() == 1 then 110 | input = input:view(1,input:size()[1]) 111 | self.out_vector = true 112 | else 113 | self.out_vector = false 114 | end 115 | if input:nDimension() > 2 then 116 | -- No more than two dimensions expected 117 | print("Bad number of input dimensions to graphical model (must be 1 or 2)") 118 | elseif input[1]:numel() ~= num_in then 119 | -- Input should be a vector of size unary:numel() + pairwise:numel() 120 | print("Input has wrong number of elements") 121 | print(num_in .. " expected.") 122 | print(input[1]:numel() .. " received.") 123 | else 124 | if self.batch_size ~= input:size(1) then 125 | -- If we haven't done it before, or batch size changes 126 | -- Initialize batch size and output tensor dimensions 127 | self.batch_size = input:size(1) 128 | self.do_full_reset = true 129 | end 130 | 131 | -- Reset, important for message flags 132 | self:reset(input, self.batch_size) 133 | -- Set potentials 134 | local un = input:sub(1,-1,1,self.unary[1]:numel()) 135 | local pw = input:sub(1,-1,self.unary[1]:numel()+1,-1) 136 | self:set_potentials(un,pw) 137 | -- For each sample in the batch calculate output 138 | for i = 1,self.batch_size do 139 | for j = 1,self.num_nodes do 140 | local start_idx = (j-1)*self.num_states + 1 141 | self.temp_output[i][{{start_idx,start_idx+self.num_states-1}}] = self:get_message(i,j) 142 | end 143 | end 144 | 145 | if (torch.eq(self.temp_output,self.temp_output):sum() ~= self.temp_output:numel()) then 146 | print(input) 147 | print(self.temp_output) 148 | print("nan present in output") 149 | assert(false) 150 | end 151 | 152 | -- Copy output to GPU 153 | self.output:copy(self.temp_output) 154 | return self.output 155 | end 156 | end 157 | 158 | function MRF:gradient_message(sample_idx, node_idx, curr_grad, exclude) 159 | -- Change how we see the gradient input vector to ease data access 160 | local out = self.temp_output[sample_idx]:view(self.unary[1]:size()) 161 | local grad_in_un = self.temp_gradInput[sample_idx]:sub(1,self.unary[1]:numel()):view(self.unary[1]:size()) 162 | local grad_in_pw = self.temp_gradInput[sample_idx]:sub(self.unary[1]:numel()+1,-1):view(self.pairwise[1]:size()) 163 | -- Get node neighbors and messages from forward pass 164 | local nb = self:get_neighbors(node_idx, exclude) 165 | local in_msg = self.msg[sample_idx][node_idx] 166 | local out_msg 167 | if exclude == nil then out_msg = out[node_idx]*self.z[sample_idx][node_idx] 168 | else out_msg = self.msg[sample_idx][exclude][node_idx][1] end 169 | 170 | if nb:nDimension() == 0 then 171 | -- Leaf node, no neighbors, update all the unary score gradients 172 | if (torch.eq(curr_grad,curr_grad):sum() ~= self.num_states) then 173 | print("There was a nan in the passed gradient....") 174 | print(curr_grad) 175 | end 176 | grad_in_un[node_idx]:add(curr_grad) 177 | else 178 | -- Else, pass gradient message back to each neighbor before updating unary scores 179 | local temp_grad = torch.cmul(curr_grad, out_msg) 180 | for i = 1,nb:size()[1] do 181 | local temp_grad_2 = torch.cdiv(temp_grad, in_msg[nb[i]][2]) 182 | if (torch.eq(temp_grad_2,temp_grad_2):sum() ~= self.num_states) then 183 | print("Temp_grad_2 nan...") 184 | print(curr_grad) 185 | end 186 | local pw_idx = self.pw_ref[node_idx][nb[i]] 187 | local pw = self.pairwise[sample_idx][math.abs(pw_idx)] 188 | if pw_idx > 0 then pw = pw:transpose(1,2) end 189 | -- Pass message back 190 | self:gradient_message(sample_idx, nb[i], pw * temp_grad_2, node_idx) 191 | -- Calculate pairwise gradient 192 | local pw_grad = torch.ger(temp_grad_2, in_msg[nb[i]][1]) -- Vector outer product 193 | if node_idx > nb[i] then pw_grad = pw_grad:transpose(1,2) end 194 | grad_in_pw[math.abs(pw_idx)]:add(pw_grad) 195 | end 196 | temp_grad:cdiv(self.unary[sample_idx][node_idx]) 197 | temp_grad[torch.eq(self.unary[sample_idx][node_idx],0)] = 0 198 | if (torch.eq(temp_grad,temp_grad):sum() ~= self.num_states) then 199 | print("There was a nan in temp gradient...still?") 200 | print(curr_grad) 201 | end 202 | grad_in_un[node_idx]:add(temp_grad) 203 | end 204 | end 205 | 206 | function MRF:updateGradInput(input, gradOutput) 207 | -- If only presented with a 1D vector, change the view so it is 2D 208 | if input:nDimension() == 1 then 209 | gradOutput = gradOutput:view(1, gradOutput:numel()) 210 | end 211 | 212 | for i = 1,self.batch_size do 213 | -- For each sample, loop through each node 214 | -- The outer loop is embarrassingly parallel, the inner loop can easily be as well (at the moment it is not) 215 | local out = self.temp_output[i]:view(self.num_nodes,self.num_states) 216 | local grad_out = gradOutput[i]:view(self.num_nodes,self.num_states):double() 217 | for j = 1,self.num_nodes do 218 | -- To start handle gradient calculation taking into account normalization by z 219 | local temp_mat = -torch.repeatTensor(out[j],self.num_states,1) + torch.diag(torch.ones(self.num_states)) 220 | local temp_grad = (temp_mat * grad_out[j]) / self.z[i][j] 221 | -- Then send gradient message back through nodes 222 | self:gradient_message(i, j, temp_grad, nil) 223 | end 224 | end 225 | 226 | if (torch.eq(self.temp_gradInput,self.temp_gradInput):sum() ~= self.temp_gradInput:numel()) then 227 | print(self.temp_output) 228 | print(gradOutput) 229 | print(self.temp_gradInput) 230 | print("nan present") 231 | assert(false) 232 | end 233 | 234 | self.gradInput:copy(self.temp_gradInput) 235 | return self.gradInput 236 | end -------------------------------------------------------------------------------- /src/pypose/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import getopt, sys, time 4 | 5 | import img 6 | import draw 7 | import ref 8 | import segment 9 | 10 | def gendefault(annot, idx, img_in, chg=None): 11 | # Initialize sample parameters 12 | c = annot['center'][idx] 13 | s = annot['scale'][idx] 14 | flip, r = False, 0 15 | flip_idxs = ref.flipped_parts[annot.attrs['name']] 16 | 17 | # Handle data augmentation 18 | if chg is not None: 19 | # Flipping 20 | if 'flip' in chg: 21 | if np.random.rand() < .5: 22 | flip = True 23 | # Scaling 24 | if 'scale' in chg: 25 | s *= min(1+chg['scale'], max(1-chg['scale'], (np.random.randn() * chg['scale']) + 1)) 26 | # Rotation 27 | if 'rotate' in chg: 28 | if chg['rotate'] == -1: 29 | # Force vertical orientation 30 | r = annot['torsoangle'][idx] 31 | else: 32 | r = np.random.randint(-chg['rotate'], chg['rotate'] + 1) 33 | # Translation 34 | if 'translate' in chg: 35 | for i in xrange(2): 36 | offset = np.random.randint(-chg['translate'], chg['translate'] + 1) 37 | c[i] += offset 38 | 39 | # Generate input image 40 | cropped = img.crop(img_in, c, s, ref.in_res, rot=r) 41 | inp = np.zeros((3, ref.in_res[0], ref.in_res[1])) 42 | for i in xrange(3): 43 | inp[i, :, :] = cropped[:, :, i] 44 | 45 | # Generate part heatmap output 46 | num_parts = annot['part'].shape[1] 47 | out = np.zeros((num_parts, ref.out_res[0], ref.out_res[1])) 48 | for i in xrange(num_parts): 49 | pt = annot['part'][idx,i] 50 | if pt[0] > 0: 51 | draw.gaussian(out[i], img.transform(pt, c, s, ref.out_res, rot=r), 2) 52 | 53 | # Flip sample 54 | if flip: 55 | inp = np.array([np.fliplr(inp[i]) for i in xrange(len(inp))]) 56 | out = np.array([np.fliplr(out[flip_idxs[i]]) for i in xrange(len(out))]) 57 | 58 | return inp,out 59 | 60 | def gendetect(annot, idx, img_in, chg=None): 61 | img_c = [img_in.shape[1]/2, img_in.shape[0]/2] 62 | img_s = max(img_in.shape) / 200 63 | flip, r = False, 0 64 | idxs = np.where(annot['index'][:] == annot['index'][idx])[0] 65 | 66 | # Handle data augmentation 67 | if chg is not None: 68 | # Flipping 69 | if 'flip' in chg: 70 | if np.random.rand() < .5: 71 | flip = True 72 | # Scaling 73 | if 'scale' in chg: 74 | img_s *= min(1+chg['scale'], max(1-chg['scale'], (np.random.randn() * chg['scale']) + 1)) 75 | # Rotation 76 | # if 'rotate' in chg: 77 | # r = np.random.randint(-chg['rotate'], chg['rotate'] + 1) 78 | # Translation 79 | if 'translate' in chg: 80 | for i in xrange(2): 81 | offset = np.random.randint(-chg['translate'], chg['translate'] + 1) 82 | c[i] += offset 83 | 84 | img_c[0] += img_s * np.random.randint(-10,10) 85 | img_c[1] += img_s * np.random.randint(-10,10) 86 | cropped = img.crop(img_in, img_c, img_s, ref.in_res) 87 | inp = np.zeros((3, ref.in_res[0], ref.in_res[1])) 88 | for i in xrange(3): inp[i, :, :] = cropped[:, :, i] 89 | 90 | out = np.zeros((2, ref.out_res[0], ref.out_res[1])) 91 | for i in idxs: 92 | pt = img.transform(annot['center'][i], img_c, img_s, ref.out_res) 93 | draw.gaussian(out[0], pt, 1) 94 | out[1,pt[1]-1:pt[1]+1,pt[0]-1:pt[0]+1] = annot['scale'][i] / img_s 95 | 96 | if flip: 97 | inp = np.array([np.fliplr(inp[i]) for i in xrange(len(inp))]) 98 | out = np.array([np.fliplr(out[i]) for i in xrange(len(out))]) 99 | 100 | return inp,out 101 | 102 | def gencascade(annot, idx, img_in, chg=None, extra_args=None): 103 | jnt1 = extra_args[0] 104 | jnt2 = extra_args[1] 105 | pt1 = annot['part'][idx,jnt1] 106 | pt2 = annot['part'][idx,jnt2] 107 | if pt1.min() <= 0 or pt2.min() <= 0: 108 | return np.zeros((3,ref.out_res[0],ref.out_res[1])), np.zeros((2,ref.out_res[0],ref.out_res[1])) 109 | else: 110 | return img.two_pt_crop(img_in, annot['scale'][idx], pt1, pt2, 1.8, ref.out_res, chg) 111 | 112 | def gensample(annot, idx, chg=None, sampletype='default', extra_args=None): 113 | img_in = ref.loadimg(annot, idx) 114 | if sampletype == 'default': 115 | return gendefault(annot, idx, img_in, chg) 116 | elif sampletype == 'detect': 117 | return gendetect(annot, idx, img_in, chg) 118 | elif sampletype == 'cascade': 119 | return gencascade(annot, idx, img_in, chg, extra_args) 120 | 121 | def generateset(dataset, settype, filename, numsamples, datadir=None, chg=None, sampletype='default', idxs=None, extra_args=None): 122 | # Generate full hdf5 dataset 123 | 124 | # Path to output file 125 | if datadir is None: 126 | filepath = ref.posedir + '/data/' + dataset + '/' + filename + '.h5' 127 | else: 128 | filepath = datadir + '/' + dataset + '/' + filename + '.h5' 129 | # Load in annotations 130 | annot = ref.load(dataset, settype) 131 | 132 | # Option to strictly follow the order of the provided annotations 133 | # Useful for generating test sets. 134 | if idxs is None: 135 | numavailable = len(annot['index']) # Number of available samples 136 | else: 137 | numavailable = len(idxs) 138 | inorder = False 139 | if numsamples == -1: 140 | numsamples = numavailable 141 | inorder = True 142 | 143 | print "" 144 | print "Generating %s %s set: %s" % (dataset, sampletype, settype) 145 | print "Path to dataset: %s" % filepath 146 | print "Number of samples: %d" % numsamples 147 | print "Data augmentation: %s" % (str(chg)) 148 | 149 | # Data/label sizes can be all over the place, this is the easiest way to check 150 | ex_in, ex_out = gensample(annot, 0, chg=chg, sampletype=sampletype, extra_args=extra_args) 151 | 152 | # Initialize numpy arrays to hold data 153 | data = np.zeros((numsamples, ex_in.shape[0], ex_in.shape[1], ex_in.shape[2]), np.float32) 154 | label = np.zeros((numsamples, ex_out.shape[0], ex_out.shape[1], ex_out.shape[2]), np.float32) 155 | ref_idxs = np.zeros((numsamples, 1), np.float32) 156 | 157 | # Loop to generate new samples 158 | print '' 159 | print '| Progress |' 160 | print '|', 161 | sys.stdout.flush() 162 | 163 | starttime = time.time() 164 | for i in xrange(numsamples): 165 | if idxs is not None: idx = idxs[i] 166 | elif inorder: idx = i 167 | else: idx = np.random.randint(numavailable) 168 | 169 | data[i], label[i] = gensample(annot, idx, chg=chg, sampletype=sampletype, extra_args=extra_args) 170 | ref_idxs[i] = idx 171 | 172 | if i % (numsamples/10) == 0: 173 | print '=', 174 | sys.stdout.flush() 175 | 176 | print '|' 177 | print '' 178 | print 'Done!', 179 | print '(%.2f seconds to complete.)' % (time.time() - starttime) 180 | print '' 181 | 182 | # Write out to hdf5 files 183 | with h5py.File(filepath, 'w') as f: 184 | f['data'] = data 185 | f['label'] = label 186 | f['index'] = ref_idxs 187 | 188 | def helpmessage(): 189 | print "Extra flags:" 190 | print " -d, --dataset :: Datset choice (mpii or flic), REQUIRED" 191 | print " -o, --outfile :: Output file for data (do not include '.h5'), REQUIRED" 192 | print " -p, --prefix :: Directory to save to (no need to include dataset name)" 193 | print " -t, --type :: Dataset type (train or test), default is train" 194 | print " -n, --numsamples :: Number of samples to generate, default is all available (-1) for test and 100 for train" 195 | print "" 196 | print "Augmentation options: (default Tompson's options for train, none for test)" 197 | print " -m, --move :: Translate (0 - 50)" 198 | print " -z, --zoom :: Scale (0.0 - 1.0)" 199 | print " -r, --rotate :: Rotate (-1 for fixed vertical, 0-180 for max distortion)" 200 | print " (Tompson's options are: -m 0 -z .5 -r 20" 201 | print "" 202 | print "Other dataset types:" 203 | print " -q, --detect" 204 | print " -c, --cascade :: Provide first joint as argument, must use additional argument below" 205 | print " -j, --pairedjoint :: Provide second joint to be used with 'cascade'" 206 | print "" 207 | print "Additional limb heatmap output:" 208 | print " -s, --segment :: - 0 No limb segment output (default)" 209 | print " - 1 Does not distinguish parts, angle == angle + 180" 210 | print " - 2 Distinguishes different part types, angle == angle + 180" 211 | print " - 3 Distinguishes different part types, angle != angle + 180" 212 | sys.exit(2) 213 | 214 | def main(argv): 215 | # Default values 216 | dataset = None 217 | datadir = None 218 | outfile = None 219 | numsamples = 100 220 | settype = 'train' 221 | chg = None 222 | sampletype = 'default' 223 | jnt1 = -1 224 | jnt2 = -1 225 | extra = None 226 | 227 | # Process command line arguments 228 | try: 229 | opts, args = getopt.getopt(argv, "hd:o:p:t:n:m:z:r:s:qc:j:", ["help", "dataset=", "outfile=", "prefix=", "type=", 230 | "numsamples=", "move=", "zoom=", "rotate=", 231 | "segment=", "detect", "cascade=", "pairedjoint="]) 232 | except getopt.GetoptError: 233 | print "Incorrect arguments" 234 | helpmessage() 235 | sys.exit() 236 | for opt,arg in opts: 237 | # Help 238 | if opt in ('-h','--help'): 239 | helpmessage() 240 | # Dataset choice 241 | elif opt in ('-d','--dataset'): 242 | dataset = arg 243 | if not (dataset in ['mpii', 'flic']): 244 | print "Bad argument for --dataset" 245 | helpmessage() 246 | # Output file 247 | elif opt in ('-o','--outfile'): 248 | outfile = arg 249 | # Prefix 250 | elif opt in ('-p','--prefix'): 251 | datadir = arg 252 | # Set type 253 | elif opt in ('-t','--type'): 254 | settype = arg 255 | if not (settype in ['train','test','valid','train_obs','test_obs']): 256 | print "Bad argument for --type" 257 | helpmessage() 258 | # Number of samples 259 | elif opt in ('-n','--numsamples'): 260 | numsamples = int(arg) 261 | if numsamples < -1: 262 | print "Bad argument for --numsamples" 263 | helpmessage() 264 | # Move 265 | elif opt in ('-m','--move'): 266 | move = int(arg) 267 | if not 0 <= move <= 50: 268 | print "Bad argument for --move" 269 | helpmessage() 270 | else: 271 | if chg is None: 272 | chg = {} 273 | chg['translate'] = move 274 | # Zoom 275 | elif opt in ('-z','--zoom'): 276 | zoom = float(arg) 277 | if not 0 <= zoom <= 1: 278 | print "Bad argument for --zoom" 279 | helpmessage() 280 | else: 281 | if chg is None: 282 | chg = {} 283 | chg['scale'] = zoom 284 | # Rotate 285 | elif opt in ('-r','--rotate'): 286 | rot = int(arg) 287 | if not -1 <= rot <= 180: 288 | print "Bad argument for --rotate" 289 | helpmessage() 290 | else: 291 | if chg is None: 292 | chg = {} 293 | chg['rotate'] = rot 294 | # Segment 295 | elif opt in ('-s','--segment'): 296 | seg = int(arg) 297 | if not (0 <= seg <= 3): 298 | print "Bad argument for --segment" 299 | helpmessage() 300 | # Detect 301 | elif opt in ('-q','--detect'): 302 | sampletype = 'detect' 303 | # Cascade 304 | elif opt in ('-c','--cascade'): 305 | sampletype = 'cascade' 306 | jnt1 = int(arg) 307 | elif opt in ('-j','--pairedjoint'): 308 | jnt2 = int(arg) 309 | 310 | if dataset is None: 311 | print "No dataset chosen." 312 | helpmessage() 313 | if outfile is None: 314 | print "No output filename chosen." 315 | helpmessage() 316 | 317 | if settype in ['test','test_obs']: 318 | # Test set has a standard number of images, and no augmentation 319 | numsamples = -1 320 | elif settype == 'train' and chg is None: 321 | if sampletype == 'default': chg = {'rotate':20, 'scale':.5} 322 | elif sampletype == 'cascade': chg = {'rotate':20,'scale':.2, 'translate':20} 323 | else: chg = {} 324 | chg['flip'] = True 325 | 326 | # If we're generating cascade data make sure two joints have been provided 327 | if sampletype == 'cascade': 328 | if jnt1 == -1 or jnt2 == -1: 329 | print "Need two joints to generate cascade data" 330 | helpmessage() 331 | extra = [jnt1, jnt2] 332 | 333 | generateset(dataset, settype, outfile, numsamples, datadir=datadir, chg=chg, sampletype=sampletype, extra_args=extra) 334 | 335 | if __name__ == "__main__": 336 | main(sys.argv[1:]) 337 | -------------------------------------------------------------------------------- /src/pypose/report.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from matplotlib.backends.backend_pdf import PdfPages 6 | import sys, getopt 7 | import h5py 8 | 9 | import img 10 | import eval 11 | import ref 12 | import data 13 | 14 | # ============================================================================= 15 | # Helper functions 16 | # ============================================================================= 17 | 18 | # loss_idxs = [5,4] 19 | loss_idxs = [2,2] 20 | acc_idxs = [3,5] 21 | 22 | def doit(train_log,test_log,idx=None): 23 | f = plt.figure() 24 | ax = f.add_subplot(1,1,1) 25 | plottraintest(ax, train_log, test_log, idx=idx, title='Accuracy') 26 | plt.show() 27 | plt.savefig('whatever.png') 28 | plt.clf() 29 | return 30 | 31 | def plottraintest(ax, train_log, test_log, idx=loss_idxs, title='Loss'): 32 | for tick in ax.xaxis.get_major_ticks(): 33 | tick.label.set_fontsize(8) 34 | for tick in ax.yaxis.get_major_ticks(): 35 | tick.label.set_fontsize(8) 36 | 37 | t = (np.arange(len(train_log[1:,idx[0]])) + 1) 38 | ax.plot(t, train_log[1:,idx[0]], label='Train', color='k') 39 | t = (np.arange(len(test_log[1:,idx[1]])) + 1) 40 | ax.plot(t, test_log[1:,idx[1]], label='Test', color='r') 41 | # ax.set_ylim(0,1) 42 | ax.legend(loc='upper right', fontsize=10) 43 | ax.set_title('Training/Test %s'%title, fontsize=10) 44 | 45 | def tabletext(ax, txt, r, c, align='center', size=10): 46 | # Shift left/right for alignment 47 | if align == 'left': c -= .4 48 | elif align == 'right': c += .4 49 | # Handle weirdness because the first column is double width 50 | if c > 0: c += 1 51 | elif align == 'center': c += .5 52 | elif align == 'right': c += 1 53 | # Write text 54 | ax.text(c+.5, .25-r, txt, fontsize=size, horizontalalignment=align) 55 | 56 | def setuptable(ax, numrows, numcols, row_labels=None, col_labels=None): 57 | # Initial table set up 58 | # (first column is double width to allow more room for label names) 59 | for r in xrange(numrows): 60 | if row_labels is not None and r < len(row_labels): 61 | tabletext(ax, row_labels[r], r+1, 0, align='left') 62 | if r: thk = 1 63 | else: thk = 2 64 | plt.plot([0,numcols+1],[-r,-r],color='k',linewidth=thk) 65 | for c in xrange(numcols): 66 | if col_labels is not None and c < len(col_labels): 67 | tabletext(ax, col_labels[c], 0, c) 68 | if c: thk = 1 69 | else: thk = 2 70 | plt.plot([c+2, c+2],[1,1-numrows],color='k',linewidth=thk) 71 | ax.set_xlim(0,numcols+1) 72 | ax.set_ylim(-numrows+1,1) 73 | ax.get_xaxis().set_visible(False) 74 | ax.get_yaxis().set_visible(False) 75 | 76 | def setuppdjplot(ax, i, num_parts, num_cols): 77 | # Configuration of ticks in plots 78 | major_ticks_y = np.arange(0,1.01,.2) 79 | minor_ticks_y = np.arange(0,1.01,.1) 80 | ax.set_yticks(major_ticks_y) 81 | ax.set_yticks(minor_ticks_y, minor=True) 82 | for tick in ax.xaxis.get_major_ticks(): 83 | tick.label.set_fontsize(8) 84 | if i < num_parts - num_cols: 85 | tick.label.set_visible(False) 86 | for tick in ax.yaxis.get_major_ticks(): 87 | tick.label.set_fontsize(8) 88 | if not (i % num_cols == 0): 89 | tick.label.set_visible(False) 90 | ax.grid() 91 | ax.grid(which='minor', alpha=0.5) 92 | 93 | def loadpreds(dataset, predfile, pred_opts, get_hms=False, dotrain=False): 94 | num_parts, vert, obs = pred_opts 95 | hms = None 96 | with h5py.File(predfile, 'r+') as f: 97 | # Choose appropriate key 98 | if vert: k = 'preds_v' 99 | else: k = 'preds_tf' 100 | # Load predictions 101 | if k in f.keys(): 102 | preds = np.array(f[k]) 103 | else: 104 | preds = eval.transformpreds(dataset, f['preds'], 105 | [64, 64], rot=vert) 106 | f[k] = preds 107 | 108 | # Ignore additional predictions from segmentation (soon to be unnecessary) 109 | if preds.shape[1] > num_parts: 110 | preds = preds[:,:num_parts,:] 111 | # Also load heatmaps if necessary 112 | if get_hms: 113 | hms = np.array(f['preds_raw']) 114 | 115 | # Load distances 116 | dist_key = 'dist_' 117 | if vert: dist_key += 'v' 118 | if obs: dist_key += 'o' 119 | if dist_key in f.keys(): 120 | dists = np.array(f[dist_key]) 121 | else: 122 | # Evaluate distances to ground truth 123 | dists = eval.getdists(preds, dotrain) 124 | f[dist_key] = dists 125 | return preds, dists, hms 126 | 127 | # ============================================================================= 128 | # Page creation functions 129 | # ============================================================================= 130 | 131 | # Sample images page 132 | # - m x n rows and columns of skeleton images 133 | # - allow filters and sort by overall score or score for specific parts 134 | # - input res, number of images 135 | default_res = [256, 256] 136 | def sampleimages(annot, preds, dists=None, partnames=None, filt=None, num_rows=7, num_cols=5, res=default_res, get_worst=False, page_num=1, pdf=None, title='Prediction Examples'): 137 | # Dataset name 138 | ds = annot.attrs['name'] 139 | 140 | # Initialize blank page 141 | plt.clf() 142 | fig = plt.figure(figsize=(8.5,11), dpi=100, facecolor='w') 143 | ax = fig.add_subplot(111) 144 | page = np.zeros((res[0]*num_rows, res[1]*num_cols, 3), np.uint8) 145 | 146 | # If no specific parts have been chosen, use them all for scoring 147 | if partnames is None: 148 | partnames = ref.parts[ds] 149 | part_idxs = [ref.parts[ds].index(part) for part in partnames] 150 | part_filt = [i in part_idxs for i in xrange(len(ref.parts[ds]))] 151 | 152 | # If no filter is provided create entirely true array 153 | if filt is None: 154 | filt = np.array([True for _ in xrange(len(preds))]) 155 | else: 156 | filt = filt.copy() 157 | 158 | # If no precalculated distances are provided, calculate them 159 | if dists is None: 160 | dists = eval.getdists(preds) 161 | 162 | # Determine scores from which we'll sort the images 163 | scores = np.zeros(len(preds)) 164 | for i in xrange(len(preds)): 165 | # A bit of an interesting line below, gets the mean distance for a particular image 166 | # while only considering the parts we want and ignoring any parts where there's no annotation 167 | vals = dists[i, part_filt * (annot['part'][i,:,0] > 0)] 168 | if len(vals) > 0: 169 | scores[i] = vals.mean() 170 | else: 171 | # If no valid annotation to make a score, filter out this example 172 | filt[i] = False 173 | if get_worst: 174 | # Flip the scores if we're getting the worst images 175 | scores = -scores 176 | best_idxs = scores.argsort() 177 | curr_idx = 0 178 | 179 | # Start filling in the overall image 180 | for i in xrange(page_num * num_rows * num_cols): 181 | while curr_idx < len(best_idxs) and not filt[best_idxs[curr_idx]]: 182 | curr_idx += 1 183 | if curr_idx >= len(best_idxs): break 184 | 185 | # If we're doing multiple pages, pass over the images that have already been used 186 | if i >= (page_num - 1) * num_rows * num_cols: 187 | idx = best_idxs[curr_idx] 188 | curr_pred = preds[idx].copy() 189 | curr_pred[-(part_filt * (annot['part'][idx,:,0] > 0)), :] = -1000 190 | new_img = img.sample_with_skeleton(annot, idx, curr_pred, res=res) 191 | row = ((i % (num_rows * num_cols)) / num_cols) * res[0] 192 | col = ((i % (num_rows * num_cols)) % num_cols) * res[1] 193 | 194 | page[row:row+res[0], col:col+res[1]] = new_img 195 | curr_idx += 1 196 | 197 | # Plot management 198 | if not get_worst: 199 | title += ' - Best - ' 200 | else: 201 | title += ' - Worst - ' 202 | title += 'Page %d' % page_num 203 | ax.set_title(title) 204 | ax.imshow(page) 205 | ax.axis('off') 206 | fig.subplots_adjust(left=0.05,right=.95,bottom=0.05,top=.95) 207 | if pdf: 208 | pdf.savefig() 209 | else: 210 | plt.show() 211 | 212 | # Part heatmaps page 213 | def partheatmaps(annot, preds, preds_raw, dists=None, partnames=None, filt=None, num_rows=7, vert=False, num_cols=2, res=default_res, get_worst=False, page_num=1, pdf=None, title='Prediction Examples'): 214 | # Dataset name 215 | ds = annot.attrs['name'] 216 | 217 | # Initialize blank page 218 | plt.clf() 219 | fig = plt.figure(figsize=(8.5,11), dpi=100, facecolor='w') 220 | ax = fig.add_subplot(111) 221 | 222 | # If no specific parts have been chosen, use them all for scoring 223 | if partnames is None: 224 | partnames = ref.parts[ds] 225 | part_idxs = [ref.parts[ds].index(part) if type(part) is str else part for part in partnames] 226 | part_filt = [i in part_idxs for i in xrange(len(ref.parts[ds]))] 227 | page = np.zeros((res[0]*num_rows, res[1]*num_cols*(1+len(part_idxs)), 3), np.uint8) 228 | 229 | # If no filter is provided create entirely true array 230 | if filt is None: 231 | filt = np.array([True for _ in xrange(len(preds))]) 232 | else: 233 | filt = filt.copy() 234 | 235 | # If no precalculated distances are provided, calculate them 236 | if dists is None: 237 | dists = eval.getdists(preds) 238 | 239 | # Determine scores from which we'll sort the images 240 | scores = np.zeros(len(preds)) 241 | for i in xrange(len(preds)): 242 | # A bit of an interesting line below, gets the mean distance for a particular image 243 | # while only considering the parts we want and ignoring any parts where there's no annotation 244 | vals = dists[i, part_filt * (annot['part'][i,:,0] > 0)] 245 | if len(vals) > 0: 246 | scores[i] = vals.mean() 247 | else: 248 | # If no valid annotation to make a score, filter out this example 249 | filt[i] = False 250 | if get_worst: 251 | # Flip the scores if we're getting the worst images 252 | scores = -scores 253 | best_idxs = scores.argsort() 254 | if title[:4] == 'head' and get_worst: 255 | np.save('worst_head_idxs',best_idxs[:200]) 256 | curr_idx = 0 257 | 258 | # Start filling in the overall image 259 | for i in xrange(page_num * num_rows * num_cols): 260 | while curr_idx < len(best_idxs) and not filt[best_idxs[curr_idx]]: 261 | curr_idx += 1 262 | if curr_idx >= len(best_idxs): break 263 | 264 | # If we're doing multiple pages, pass over the images that have already been used 265 | if i >= (page_num - 1) * num_rows * num_cols: 266 | idx = best_idxs[curr_idx] 267 | if vert: 268 | inp, _ = data.gensample(annot, idx, chg={'rotate':-1}) 269 | else: 270 | inp, _ = data.gensample(annot, idx) 271 | new_img = img.sample_with_heatmap(ds, inp, preds_raw[idx], num_rows=1, parts_to_show=part_idxs) 272 | row = ((i % (num_rows * num_cols)) / num_cols) * res[0] 273 | col = ((i % (num_rows * num_cols)) % num_cols) * res[1] * (1+len(part_idxs)) 274 | 275 | page[row:row+res[0], col:col+(res[1]*(1+len(part_idxs)))] = new_img 276 | curr_idx += 1 277 | 278 | # Plot management 279 | if not get_worst: 280 | title += ' - Best - ' 281 | else: 282 | title += ' - Worst - ' 283 | title += 'Page %d' % page_num 284 | ax.set_title(title) 285 | ax.imshow(page) 286 | ax.axis('off') 287 | fig.subplots_adjust(left=0.05,right=.95,bottom=0.05,top=.95) 288 | if pdf: 289 | pdf.savefig() 290 | else: 291 | plt.show() 292 | fig.clf() 293 | 294 | # Filter comparison page 295 | # - inputs a list of filters 296 | # - top row is overall performance comparison, and table with numbers including # samples 297 | # - performance broken down by part categories 298 | def filtercomparison(dataset, dists, filts, filtnames, pdf=None, title='Performance comparison', other_dists=None, parts_to_show=None): 299 | # Initialize blank page 300 | # plt.clf() 301 | fig = plt.figure(figsize=(8.5,11), dpi=100, facecolor='w') 302 | 303 | if parts_to_show is None: 304 | part_labels = ref.pair_names[dataset] 305 | parts_to_show = ref.part_pairs[dataset] 306 | else: 307 | part_labels = [parts_to_show[i][0] for i in xrange(len(parts_to_show))] 308 | 309 | # Configuration of ticks in plots 310 | major_ticks_y = np.arange(0,1.01,.2) 311 | minor_ticks_y = np.arange(0,1.01,.1) 312 | 313 | #------------------------------------------------------------------- 314 | # Table with performance numbers 315 | #------------------------------------------------------------------- 316 | 317 | ax_table = fig.add_subplot(5,1,1) 318 | 319 | cols = ['', '#', 'Full'] + part_labels 320 | rows = ['Label'] + filtnames 321 | num_samples = [len(dists) if filt is None else filt.sum() for filt in filts] 322 | num_samples = [''] + num_samples 323 | if other_dists is not None: 324 | rows += other_dists.keys() 325 | num_samples += [len(dists) if other_dists[k][1] is None else other_dists[k][1].sum() 326 | for k in other_dists.keys()] 327 | 328 | # Initial table set up 329 | for r in xrange(len(rows)): 330 | # Filter labels 331 | ax_table.text(.1,len(rows)-.75-r,rows[r],fontsize=10,horizontalalignment='left') 332 | # Number of samples available from each filter 333 | ax_table.text(2.5,len(rows)-.75-r,num_samples[r],fontsize=10,horizontalalignment='center') 334 | if r < len(rows) - 1: 335 | thk = 1 336 | else: 337 | thk = 2 338 | plt.plot([0,len(cols)+1],[r,r],color='k',linewidth=thk) 339 | 340 | for c in xrange(1, len(cols) + 1): 341 | ax_table.text(c+.5,len(rows)-.75,cols[c-1],fontsize=10,horizontalalignment='center') 342 | if c > 2: 343 | thk = 1 344 | else: 345 | thk = 2 346 | if c < len(cols): 347 | plt.plot([c+1, c+1],[0,len(rows)],color='k',linewidth=thk) 348 | 349 | # Performance numbers get filled in as we create the PDJ charts 350 | 351 | ax_table.set_xlim(0,len(cols)+1) 352 | ax_table.get_xaxis().set_visible(False) 353 | ax_table.get_yaxis().set_visible(False) 354 | 355 | ax_table.set_title(title, y=1.05) 356 | 357 | #------------------------------------------------------------------- 358 | # Overall performance chart 359 | #------------------------------------------------------------------- 360 | ax = fig.add_subplot(5,3,4) 361 | 362 | ax.set_yticks(major_ticks_y) 363 | ax.set_yticks(minor_ticks_y, minor=True) 364 | for tick in ax.xaxis.get_major_ticks(): 365 | tick.label.set_fontsize(8) 366 | tick.label.set_visible(False) 367 | for tick in ax.yaxis.get_major_ticks(): 368 | tick.label.set_fontsize(8) 369 | 370 | ax.grid() 371 | ax.grid(which='minor', alpha=0.5) 372 | 373 | for i,filt in enumerate(filts): 374 | d, t = eval.pdjdata(dataset, dists, filt=filt) 375 | # Plot PDJ curve 376 | ax.plot(t,d) 377 | # Display performance number in table 378 | ax_table.text(3.5, len(rows)-i-1.75, '%04.1f' % (d[-1]*100), 379 | fontsize=10, horizontalalignment='center') 380 | if other_dists is not None: 381 | for i,k in enumerate(other_dists): 382 | d, t = eval.pdjdata(dataset, other_dists[k][0], filt=other_dists[k][1]) 383 | # Plot PDJ curve 384 | ax.plot(t,d) 385 | # Display performance number in table 386 | ax_table.text(3.5, len(rows)-i-len(filts)-1.75, '%04.1f' % (d[-1]*100), 387 | fontsize=10, horizontalalignment='center') 388 | 389 | ax.set_ylim([0,1]) 390 | 391 | # box = ax.get_position() 392 | # ax.set_position([box.x0, box.y0, box.width * 0.9, box.height]) 393 | ax_labels = filtnames 394 | if other_dists is not None: 395 | ax_labels += other_dists.keys() 396 | ax.legend(loc='center left', bbox_to_anchor=(1.1, 0.5), labels=ax_labels, fontsize=11) 397 | ax.set_title('Overall', fontsize=10) 398 | 399 | #------------------------------------------------------------------- 400 | # Separate charts for each part 401 | #------------------------------------------------------------------- 402 | num_cols = 3 403 | for i,pts in enumerate(parts_to_show): 404 | ax = fig.add_subplot(5,num_cols,i+7) 405 | setuppdjplot(ax, i, len(parts_to_show), num_cols) 406 | 407 | for j,filt in enumerate(filts): 408 | d, t = eval.pdjdata(dataset, dists, filt=filt, partnames=pts) 409 | # Plot PDJ curve 410 | ax.plot(t,d) 411 | # Display performance number in table 412 | ax_table.text(i+4.5, len(rows)-j-1.75, '%04.1f' % (d[-1]*100), fontsize=10, horizontalalignment='center') 413 | if other_dists is not None: 414 | for j,k in enumerate(other_dists): 415 | d, t = eval.pdjdata(dataset, other_dists[k][0], filt=other_dists[k][1], partnames=pts) 416 | # Plot PDJ curve 417 | ax.plot(t,d) 418 | # Display performance number in table 419 | ax_table.text(i+4.5, len(rows)-j-1.75-len(filts), '%04.1f' % (d[-1]*100), 420 | fontsize=10, horizontalalignment='center') 421 | 422 | ax.set_title(part_labels[i], fontsize=10) 423 | ax.set_ylim([0,1]) 424 | 425 | if pdf: 426 | pdf.savefig() 427 | else: 428 | plt.show() 429 | 430 | # Training report page 431 | # - plot showing loss train/test across iterations 432 | # - accuracy vs state of the art for shoulders, elbows, wrists 433 | # - in plots also include (previous best benchmark) 434 | # - quite a lot of overlap with filter comparison code, a bit sloppy 435 | # there's probably a better code interface to design to pull this altogether 436 | def trainingoverview(dataset, dists, filts, filtnames, pdf=None, 437 | other_dists=None, parts_to_show=None, exp_id='default'): 438 | # Initialize blank page 439 | fig = plt.figure(figsize=(8.5,11), dpi=100, facecolor='w') 440 | 441 | # Default parts to show performance results 442 | if parts_to_show is None: 443 | parts_to_show = [] 444 | if dataset == 'flic': parts_to_show += [('Face',['leye','reye','nose'])] 445 | else: parts_to_show += [('Head',['head','neck']), 446 | ('Ank',['lank','rank']), 447 | ('Knee',['lkne','rkne'])] 448 | parts_to_show += [('Sho',['lsho','rsho']), 449 | ('Elb',['lelb','relb']), 450 | ('Wri',['lwri','rwri']), 451 | ('Hip',['lhip','rhip'])] 452 | part_labels = [parts_to_show[i][0] for i in xrange(len(parts_to_show))] 453 | 454 | # Load training logs 455 | log_dir = ref.posedir + '/exp/' + dataset + '/' + exp_id 456 | train_log = np.loadtxt(log_dir + '/train.log', skiprows=1) 457 | test_log = np.loadtxt(log_dir + '/test.log', skiprows=1) 458 | # Plot loss 459 | ax = fig.add_subplot(4,1,2) 460 | plottraintest(ax, train_log, test_log) 461 | 462 | # Setup table to hold performance numbers 463 | ax_table = fig.add_subplot(4,1,1) 464 | cols = ['Label'] + part_labels 465 | rows = filtnames 466 | if other_dists is not None: 467 | rows += other_dists.keys() 468 | setuptable(ax_table, len(rows)+1, len(cols), row_labels=rows, col_labels=cols) 469 | ax_table.set_title('%s - Experiment: %s' % (dataset.upper(), exp_id), y=1.05) 470 | 471 | # Generate PDJ charts for each part 472 | num_cols = 4 473 | for i,pts in enumerate(parts_to_show): 474 | ax = fig.add_subplot(4,num_cols,i+2*num_cols+1) 475 | setuppdjplot(ax, i, len(parts_to_show), num_cols) 476 | 477 | for j,filt in enumerate(filts): 478 | d, t = eval.pdjdata(dataset, dists, filt=filt, partnames=pts[1]) 479 | # Plot PDJ curve 480 | ax.plot(t,d,label=filtnames[j]) 481 | # Display performance number in table 482 | tabletext(ax_table, '%04.1f'%(d[-1]*100), j+1, i+1) 483 | if other_dists is not None: 484 | for j,k in enumerate(other_dists): 485 | d, t = eval.pdjdata(dataset, other_dists[k][0], filt=other_dists[k][1], partnames=pts[1]) 486 | # Plot PDJ curve 487 | ax.plot(t,d,label=k) 488 | # Display performance number in table 489 | tabletext(ax_table, '%04.1f'%(d[-1]*100), j+1+len(filts), i+1) 490 | 491 | ax.set_title(pts[0], fontsize=10) 492 | ax.set_ylim([0,1]) 493 | 494 | if i == len(parts_to_show) - 1: 495 | ax_labels = filtnames 496 | if other_dists is not None: 497 | ax_labels += other_dists.keys() 498 | ax.legend(loc='upper left', bbox_to_anchor=(1.1, 1), fontsize=11) 499 | 500 | if pdf: pdf.savefig() 501 | else: plt.show() 502 | 503 | # Limb report page 504 | # A super basic report that just includes training/test loss and accuracy over time 505 | def limbreport(dataset, exp_id, pdf=None): 506 | fig = plt.figure(figsize=(8.5,11), dpi=100, facecolor='w') 507 | 508 | # Plot train/test loss 509 | ax = fig.add_subplot(3,1,1) 510 | exp_dir = ref.posedir + '/exp/' + dataset + '/' + exp_id 511 | train_log = np.loadtxt(exp_dir + '/train.log', skiprows=1) 512 | test_log = np.loadtxt(exp_dir + '/test.log', skiprows=1) 513 | plottraintest(ax, train_log, test_log) 514 | 515 | # Plot test accuracy over time 516 | ax = fig.add_subplot(3,1,2) 517 | plottraintest(ax, train_log, test_log, idx=acc_idxs, title='Accuracy') 518 | 519 | # Plot precision/recall curve 520 | with h5py.File(exp_dir + '/preds.h5', 'r') as f: 521 | preds = np.array(f['preds_raw']) 522 | with h5py.File(ref.posedir + '/data/mpii/sho_neck_test.h5','r') as f: 523 | label = np.array(f['label']) 524 | 525 | p,r = [],[] 526 | for thrsh in np.arange(0,1.01,.05): 527 | preds_bool = np.floor(preds+thrsh) 528 | tp = (preds_bool * label).sum() 529 | fp = ((preds_bool - label) == 1).sum() 530 | fn = ((label - preds_bool) == 1).sum() 531 | p += [float(tp)/(tp+fp) if tp > 0 else 1] 532 | r += [float(tp)/(tp+fn)] 533 | ax = fig.add_subplot(3,1,3) 534 | ax.plot(r,p) 535 | ax.set_xlabel('Recall') 536 | ax.set_ylabel('Precision') 537 | ax.set_xlim([0,1]) 538 | ax.set_ylim([0,1]) 539 | if pdf: 540 | pdf.savefig() 541 | else: 542 | plt.show() 543 | 544 | # ============================================================================= 545 | # Main command line interface 546 | # ============================================================================= 547 | 548 | def helpmessage(): 549 | print "This isn't too helpful, updating message soon..." 550 | sys.exit(2) 551 | 552 | # Main 553 | def main(argv): 554 | dataset = None 555 | exp_id = None 556 | extra = [] 557 | prev = [] 558 | other_dists = {} 559 | vert = False 560 | images = False 561 | obs = False 562 | limb = False 563 | 564 | # Process command line arguments 565 | try: 566 | opts, args = getopt.getopt(argv, "hd:e:c:p:viol", ["help", "dataset=", "expID=", "compare=", "prev=", 567 | "vert", "images", "obs", "limb"]) 568 | except getopt.GetoptError: 569 | print "Incorrect arguments" 570 | helpmessage() 571 | sys.exit(2) 572 | for opt, arg in opts: 573 | if opt in ('-h', '--help'): 574 | helpmessage() 575 | elif opt in ('-d', '--dataset'): dataset = arg 576 | elif opt in ('-e', '--expID'): exp_id = arg 577 | elif opt in ('-c', '--compare'): extra += arg.split(',') 578 | elif opt in ('-p', '--prev'): prev += arg.split(',') 579 | elif opt in ('-v', '--vert'): vert = True 580 | elif opt in ('-i', '--images'): images = True 581 | elif opt in ('-o', '--obs'): obs = True 582 | elif opt in ('-l', '--limb'): limb = True 583 | 584 | if dataset is None: 585 | print "No dataset chosen." 586 | helpmessage() 587 | if not (dataset in ['mpii','flic']): 588 | print "Bad argument for --dataset" 589 | helpmessage() 590 | if exp_id is None: 591 | print "No experiment number provided." 592 | helpmessage() 593 | expdir = ref.posedir + '/exp/' + dataset + '/' + exp_id 594 | 595 | # Generate the simple report for mini limb networks 596 | if limb: 597 | pdf = PdfPages(expdir + '/report.pdf') 598 | limbreport(dataset, exp_id, pdf) 599 | pdf.close() 600 | return 601 | 602 | # Load in dataset information 603 | num_parts = len(ref.parts[dataset]) 604 | if obs: 605 | annot = ref.load(dataset, 'test_obs') 606 | eval.annot[dataset] = annot 607 | else: 608 | annot = ref.load(dataset, 'valid') 609 | 610 | # Load predictions 611 | print "Loading predictions" 612 | pred_opts = [num_parts, vert, obs] 613 | preds, dists, _ = loadpreds(dataset, expdir + '/preds.h5', pred_opts, images) 614 | 615 | # Load previous predictions 616 | for prv in prev: 617 | _,d,_ = loadpreds(dataset, expdir + '/preds_%s.h5' % prv, pred_opts) 618 | other_dists[prv] = [d, None] 619 | 620 | # Load comparison predictions 621 | for ext in extra: 622 | predfile = ref.posedir + '/exp/' + dataset + '/' + ext + '/preds.h5' 623 | _,d,_ = loadpreds(dataset, predfile, pred_opts) 624 | other_dists[ext] = [d, None] 625 | 626 | # Load previous best 627 | if vert: predfile = expdir + '/../best/preds_vert.h5' 628 | else: predfile = expdir + '/../best/preds.h5' 629 | _,best_dists,_ = loadpreds(dataset, predfile, pred_opts) 630 | #other_dists["Kaiyu's best model"] = [best_dists, None] 631 | 632 | # Load NYU predictions 633 | if dataset == 'mpii': 634 | nyu_dists = np.load(eval.get_path(dataset, 'nyu_dists')) 635 | else: 636 | if not obs: nyu_preds = np.load(eval.get_path(dataset, 'nyu_pred')) 637 | else: nyu_preds = np.load(eval.get_path(dataset, 'nyu_pred_obs')) 638 | nyu_dists = eval.getdists(nyu_preds) 639 | np.save('nyu_dists_%s%s'%(dataset,'_obs' if obs else ''),nyu_dists) 640 | other_dists['Tompson'] = [nyu_dists, None] 641 | 642 | # Load training set predictions 643 | if False: 644 | _,d,_ = loadpreds(dataset, expdir + '/preds_train.h5', pred_opts, dotrain=True) 645 | other_dists['Train'] = [d, None] 646 | 647 | filt = None 648 | 649 | print "Creating overview page" 650 | # Main report creation 651 | pdf = PdfPages(expdir + '/report.pdf') 652 | 653 | # Training overview page 654 | trainingoverview(dataset, dists, [filt], [exp_id], exp_id=exp_id, 655 | other_dists=other_dists, pdf=pdf) 656 | 657 | if images: 658 | print "Creating prediction examples page" 659 | # Overall performance examples 660 | num_good_exs = 2 661 | num_bad_exs = 6 662 | for i in xrange(num_good_exs): 663 | sampleimages(annot,preds,dists,pdf=pdf,page_num=i+1) 664 | for i in xrange(num_bad_exs): 665 | sampleimages(annot,preds,dists,get_worst=True,pdf=pdf,page_num=i+1) 666 | 667 | # print "Creating part heatmap examples" 668 | # # Heatmap examples 669 | # for i in xrange(len(ref.part_pairs[dataset])): 670 | # title = ref.pair_names[dataset][i] 671 | # pt_names = ref.part_pairs[dataset][i] 672 | # if not title == 'face': 673 | # partheatmaps(annot,preds,preds_raw,dists=dists,partnames=pt_names,title='%s Heatmap Examples'%title, 674 | # pdf=pdf, page_num=1, vert=vert) 675 | # for j in xrange(1,3): 676 | # partheatmaps(annot,preds,preds_raw,dists=dists,partnames=pt_names,title='%s Heatmap Examples'%title, 677 | # pdf=pdf, page_num=j, vert=vert, get_worst=True) 678 | 679 | pdf.close() 680 | 681 | if __name__ == "__main__": 682 | main(sys.argv[1:]) 683 | 684 | # Reference for creating a left/right filter 685 | # filt = np.array([ref.partinfo(annot,i,'lsho')[0] > 686 | # ref.partinfo(annot,i,'rsho')[0] 687 | # for i in xrange(len(annot['index']))]) 688 | --------------------------------------------------------------------------------