├── sample_img.jpg ├── detected_keypoints.png ├── deployModel.m ├── CHANGELOG ├── model-train ├── transformImage.m ├── generateGHeatmap.m ├── getImdbNoAug.m ├── imRotateGetBatch.m ├── imScaleGetBatch.m ├── +dagnn │ └── RegLoss.m ├── flipKeyPointsCoords.m ├── vl_nnheatloss.m ├── cnn_regressor_dag.m ├── cnn_train_dag_reg.m └── cnn_regressor_get_batch.m ├── LICENSE ├── demo_keypoint.m ├── getMPIIData ├── splitMPIIData_V4.m └── getMPIIData_v3.m ├── demolive_keypoints.m ├── README.md ├── KeyPointDetector.m ├── trainBodyPose_example.m └── dagnetworks └── initialize3ObjeRecFusion.m
/sample_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ox-vgg/keypoint_detection/HEAD/sample_img.jpg
--------------------------------------------------------------------------------
/detected_keypoints.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ox-vgg/keypoint_detection/HEAD/detected_keypoints.png
--------------------------------------------------------------------------------
/deployModel.m:
--------------------------------------------------------------------------------
1 | clearvars; close all; clc;
2 | 
3 | % Clean a trained model from the momentum and training stats.
4 | % The file <model>.mat in the current folder is overwritten in place with
5 | % a file that contains only the variable 'net'.
6 | 
7 | % MatConvNet library
8 | run(fullfile(fileparts(mfilename('fullpath')),...
9 |     '..','matconvnet-b23','matlab', 'vl_setupnn.m')) ;
10 | 
11 | model = 'keypoint-v3';
12 | modelFile = sprintf('%s.mat',model);
13 | 
14 | % Load only 'net' into a struct instead of dumping every stored variable
15 | % into the workspace, and fail early if the file is not a model file.
16 | stored = load(modelFile, 'net');
17 | if ~isfield(stored, 'net')
18 |     error('deployModel:noNet', ...
19 |         'File ''%s'' does not contain a variable ''net''.', modelFile);
20 | end
21 | 
22 | net = dagnn.DagNN.loadobj(stored.net);
23 | net.move('cpu');    % move the network to the CPU before saving
24 | net.rebuild();
25 | save(modelFile,'net');
26 | 
--------------------------------------------------------------------------------
/CHANGELOG:
--------------------------------------------------------------------------------
1 | keypoint-v1 : Initial model file (epoch 30) [July 14, 2016]
2 | keypoint-v2 : Recurrent model as described in https://arxiv.org/abs/1605.02914 [Aug. 
18, 2016]
3 | keypoint-v3 : Recurrent model with 2 iterations as described in http://www.robots.ox.ac.uk/~vgg/publications/2017/Belagiannis17/ [March. 23, 2017]
4 | keypoint-v4 : Recurrent model with 2 iterations (trained on MSCOCO, then fine-tuned on MPII Pose) as described in http://www.robots.ox.ac.uk/~vgg/publications/2017/Belagiannis17/ [March. 23, 2017]
5 | 
6 | All the old model files can be downloaded from https://gitlab.com/vggdemo/keypoint_data
--------------------------------------------------------------------------------
/model-train/transformImage.m:
--------------------------------------------------------------------------------
1 | function [new_im, grn, gcn] = transformImage(I, gr, gc, TM)
2 | %More info at: http://stackoverflow.com/questions/13366771/matlab-image-transformation
3 | 
4 | [new_im, xdata, ydata] = imtransform(I, TM,'XYScale',1,'FillValues', 128);
5 | w = xdata(2)-xdata(1) +1;
6 | h = ydata(2)-ydata(1)+1;
7 | scalex = size(new_im,2)/w;
8 | scaley = size(new_im,1)/h;
9 | 
10 | coords = [gc(:), gr(:)];
11 | coords_tf = tformfwd(TM, coords);
12 | 
13 | %translation
14 | coords_tf_mg(:,1) = coords_tf(:,1) - xdata(1) + 1;
15 | coords_tf_mg(:,2) = coords_tf(:,2) - ydata(1) + 1;
16 | 
17 | %scale
18 | coords_tf_mg(:,1) = coords_tf_mg(:,1)*scalex;
19 | coords_tf_mg(:,2) = coords_tf_mg(:,2)*scaley;
20 | 
21 | coords_tf_mg = round(coords_tf_mg);
22 | grn = coords_tf_mg(:,2);
23 | gcn = coords_tf_mg(:,1);
24 | grn = reshape(grn, size(gr,1), size(gr,2));
25 | gcn = reshape(gcn, size(gc,1), size(gc,2));
26 | end
--------------------------------------------------------------------------------
/model-train/generateGHeatmap.m:
--------------------------------------------------------------------------------
1 | function heatmap = generateGHeatmap(heatmap,cxy,theta,len,opts)
2 | %GENERATEGHEATMAP Add a rotated 2-D Gaussian blob to a heatmap.
3 | %   HEATMAP = GENERATEGHEATMAP(HEATMAP, CXY, THETA, LEN, OPTS) evaluates an
4 | %   anisotropic Gaussian on the pixel grid of HEATMAP, multiplies it by
5 | %   OPTS.magnif and adds it to HEATMAP.
6 | %
7 | %   CXY   - centre of the Gaussian: CXY(1) is compared with the column
8 | %           index and CXY(2) with the row index of HEATMAP.
9 | %   THETA - orientation in degrees (180-THETA is used internally).
10 | %   LEN   - length from which the spreads are derived:
11 | %           sigma1 = LEN*OPTS.facX and sigma2 = LEN*OPTS.facY.
12 | %   OPTS  - struct with the fields facX, facY and magnif.
13 | %
14 | %   The sampling grid always has the size of HEATMAP, so non-square
15 | %   heatmaps are handled as well (results for square ones are unchanged).
16 | 
17 | % X1(i,j) = j (column index), X2(i,j) = i (row index); both matrices are
18 | % size(heatmap,1)-by-size(heatmap,2), i.e. they match HEATMAP.
19 | [X1,X2] = meshgrid(1:size(heatmap,2), 1:size(heatmap,1));
20 | 
21 | sigma1 = len*opts.facX; %sigmx (part length)
22 | sigma2 = len*opts.facY; %sigy
23 | theta = 180-theta;
24 | 
25 | % coefficients of the quadratic form of the rotated Gaussian
26 | a = ((cosd(theta)^2) / (2*sigma1^2)) + ((sind(theta)^2) / (2*sigma2^2));
27 | b = -((sind(2*theta)) / (4*sigma1^2)) + ((sind(2*theta)) / (4*sigma2^2));
28 | c = ((sind(theta)^2) / (2*sigma1^2)) + ((cosd(theta)^2) / (2*sigma2^2));
29 | 
30 | mu(1)= cxy(1);
31 | mu(2)= cxy(2);
32 | 
33 | %add up the heatmap for each individual
34 | newMap = exp(-(a*(X1 - mu(1)).^2 + 2*b*(X1 - mu(1)).*(X2 - mu(2)) + c*(X2 - mu(2)).^2));
35 | 
36 | newMap(isnan(newMap))=0; %NaNs can appear e.g. for a zero sigma (len == 0)
37 | newMap(newMap<10^-18)=0; %cut the tails
38 | 
39 | heatmap = heatmap + (opts.magnif*newMap);
--------------------------------------------------------------------------------
/model-train/getImdbNoAug.m:
--------------------------------------------------------------------------------
1 | % --------------------------------------------------------------------
2 | function imdb = getImdbNoAug(opts)
3 | % --------------------------------------------------------------------
4 | 
5 | % Load the data to form the imdb file
6 | 
7 | load(opts.DataMatTrain); %training data
8 | 
9 | imdb.images.data=imgPath;
10 | sets=ones(1,numel(imgPath));
11 | imdb.images.labels=ptsAll;
12 | 
13 | clear imgPath ptsAll;
14 | 
15 | load(opts.DataMatVal); %validation data
16 | 
17 | sets=[sets 2*ones(1,numel(imgPath))];
18 | imdb.images.data=[imdb.images.data imgPath];
19 | 
20 | if iscell(imdb.images.labels)%different formats of ground-truth
21 | imdb.images.labels=[imdb.images.labels ptsAll];
22 | else
23 | imdb.images.labels=cat(3,imdb.images.labels,ptsAll);
24 | end
25 | 
26 | imdb.images.set=sets;
27 | imdb.meta.sets = {'train', 'val', 'test'} ;
28 | imdb.patchHei=opts.patchHei;
29 | imdb.patchWi=opts.patchWi;
30 | imdb.averageImage = [];
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
Copyright (c) 2017, University of Oxford 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | -------------------------------------------------------------------------------- /model-train/imRotateGetBatch.m: -------------------------------------------------------------------------------- 1 | function [imo,pts] = imRotateGetBatch(im,pts) 2 | 3 | % imshow(uint8(im));hold on; 4 | % for i=1:size(pts,1) 5 | % text(pts(i,1),pts(i,2), 'x','Color','g','FontSize',15); 6 | % end 7 | % hold off; pause(); 8 | 9 | theta1=rand(1)*40; 10 | theta2=-theta1; 11 | 12 | if rand(1)>0.5 13 | theta1=theta2; 14 | end 15 | 16 | tform = maketform('affine',[cosd(theta1) -sind(theta1) 0; sind(theta1) cosd(theta1) 0; 0 0 1]); 17 | [imo, ptsy, ptsx] = transformImage(im, pts(:,2), pts(:,1), tform); 18 | 19 | %zeros should remain zeros 20 | ptsx = ptsx.*double(pts(:,1)>0 & pts(:,2)>0); 21 | ptsy = ptsy.*double(pts(:,1)>0 & pts(:,2)>0); 22 | 23 | %crop the image and change the origin of the points 24 | x = 1 + round(rand(1) * (size(imo,2) - size(im,2)-1)); 25 | y = 1 + round(rand(1) * (size(imo,1) - size(im,1)-1)); 26 | imo = imo(y:y+size(im,1)-1,x:x+size(im,2)-1,:); 27 | 28 | %GT points 29 | ptsx=ptsx-x +1; 30 | ptsy=ptsy-y +1; 31 | 32 | %zeros should remain zeros 33 | ptsx = ptsx.*double(pts(:,1)>0 & pts(:,2)>0); 34 | ptsy = ptsy.*double(pts(:,1)>0 & pts(:,2)>0); 35 | 36 | %exclude out of plane points 37 | ptsy = ptsy.*double(ptsx>=1 & ptsy>=1); 38 | ptsx = ptsx.*double(ptsx>=1 & ptsy>=1); 39 | ptsy = ptsy.*double(ptsy size(im,2) %crop 31 | cropSize = round((size(imo,2) - size(im,1))/2); 32 | imo = imo(cropSize:end-cropSize,cropSize:end-cropSize,:); 33 | %crop if it is required to euqalize the dims. 
34 | imo=imo(1:size(im,1),1:size(im,2),:); 35 | ptsx=ptsx-cropSize; 36 | ptsy=ptsy-cropSize; 37 | end 38 | 39 | if size(imo)~=size(im) 40 | disp('problem'); 41 | end 42 | 43 | %zeros should remain zeros 44 | ptsx = ptsx.*double(pts(:,1)>0 & pts(:,2)>0); 45 | ptsy = ptsy.*double(pts(:,1)>0 & pts(:,2)>0); 46 | 47 | %exclude out of plane points 48 | ptsy = ptsy.*double(ptsx>=1 & ptsy>=1); 49 | ptsx = ptsx.*double(ptsx>=1 & ptsy>=1); 50 | ptsy = ptsy.*double(ptsy0; 6 | poseGTresc(idx,1)=leng-poseGTresc(idx,1); %change origin 7 | temp=poseGTresc; 8 | poseGTresc=temp; 9 | poseGTresc(1,:)=temp(6,:); 10 | poseGTresc(6,:)=temp(1,:); 11 | poseGTresc(2,:)=temp(5,:); 12 | poseGTresc(5,:)=temp(2,:); 13 | poseGTresc(3,:)=temp(4,:); 14 | poseGTresc(4,:)=temp(3,:); 15 | poseGTresc(13,:)=temp(14,:); 16 | poseGTresc(14,:)=temp(13,:); 17 | poseGTresc(12,:)=temp(15,:); 18 | poseGTresc(15,:)=temp(12,:); 19 | poseGTresc(11,:)=temp(16,:); 20 | poseGTresc(16,:)=temp(11,:); 21 | y=poseGTresc; 22 | elseif strcmp(jnts,'mscoco')==1 23 | idx = poseGTresc(:,1)>0; 24 | poseGTresc(idx,1)=leng-poseGTresc(idx,1); %change origin 25 | temp=poseGTresc; 26 | poseGTresc=temp; 27 | poseGTresc(2,:)=temp(3,:); 28 | poseGTresc(3,:)=temp(2,:); 29 | poseGTresc(4,:)=temp(5,:); 30 | poseGTresc(5,:)=temp(4,:); 31 | poseGTresc(6,:)=temp(7,:); 32 | poseGTresc(7,:)=temp(6,:); 33 | poseGTresc(8,:)=temp(9,:); 34 | poseGTresc(9,:)=temp(8,:); 35 | poseGTresc(10,:)=temp(11,:); 36 | poseGTresc(11,:)=temp(10,:); 37 | poseGTresc(12,:)=temp(13,:); 38 | poseGTresc(13,:)=temp(12,:); 39 | poseGTresc(14,:)=temp(15,:); 40 | poseGTresc(15,:)=temp(14,:); 41 | poseGTresc(16,:)=temp(17,:); 42 | poseGTresc(17,:)=temp(16,:); 43 | y=poseGTresc; 44 | else %full body (LSP) 45 | idx = poseGTresc(:,1)>0; 46 | poseGTresc(idx,1)=leng-poseGTresc(idx,1); %change origin 47 | temp=poseGTresc; 48 | poseGTresc=temp; 49 | poseGTresc(1,:)=temp(6,:); 50 | poseGTresc(6,:)=temp(1,:); 51 | poseGTresc(2,:)=temp(5,:); 52 | poseGTresc(5,:)=temp(2,:); 53 | 
poseGTresc(3,:)=temp(4,:);
54 | poseGTresc(4,:)=temp(3,:);
55 | poseGTresc(7,:)=temp(12,:);
56 | poseGTresc(12,:)=temp(7,:);
57 | poseGTresc(8,:)=temp(11,:);
58 | poseGTresc(11,:)=temp(8,:);
59 | poseGTresc(9,:)=temp(10,:);
60 | poseGTresc(10,:)=temp(9,:);
61 | y=poseGTresc;
62 | end
63 | 
64 | end
--------------------------------------------------------------------------------
/demo_keypoint.m:
--------------------------------------------------------------------------------
1 | %% Keypoints detection in Human Pose Estimation
2 | % This is a MatConvNet demo for human pose estimation.
3 | % Related Work: Belagiannis V., and Zisserman A.,
4 | % Recurrent Human Pose Estimation, FG2017.
5 | % Contact: Vasileios Belagiannis, vb@robots.ox.ac.uk
6 | % Part of this demo has been written by Abhishek Dutta (adutta@robots.ox.ac.uk)
7 | % For further details, visit http://www.robots.ox.ac.uk/~vgg/software/keypoint_detection/
8 | 
9 | close all; clear; clc
10 | 
11 | addpath('model-train');
12 | 
13 | % Update these according to your requirements
14 | USE_GPU = 0; % 1 for GPU
15 | img_fn = 'sample_img.jpg';
16 | 
17 | DEMO_MODEL_FN = 'keypoint-v4.mat';
18 | MATCONVNET_DIR = '../matconvnet-b23/';
19 | 
20 | %
21 | % Compile matconvnet
22 | % http://www.vlfeat.org/matconvnet/install/
23 | %
24 | if ~exist( fullfile(MATCONVNET_DIR, 'matlab', 'mex'), 'dir' )
25 |     disp('Compiling matconvnet ...')
26 |     % take vl_compilenn from the same MatConvNet tree that is tested above
27 |     addpath(fullfile(MATCONVNET_DIR, 'matlab'));
28 |     vl_compilenn('enableGpu', logical(USE_GPU));
29 |     fprintf(1, '\n\nMatcovnet compilation finished.');
30 | end
31 | 
32 | % setup matconvnet path variables
33 | matconvnet_setup_fn = fullfile(MATCONVNET_DIR, 'matlab', 'vl_setupnn.m');
34 | run(matconvnet_setup_fn) ;
35 | 
36 | % Initialize keypoint detector
37 | keypoint_detector = KeyPointDetector(DEMO_MODEL_FN, MATCONVNET_DIR, USE_GPU);
38 | 
39 | % Detect keypoints
40 | fprintf(1, '\nDetecting keypoints in image : %s', img_fn);
41 | [kpx, kpy, kpname] = get_all_keypoints(keypoint_detector, img_fn);
42 | 
43 | % Display the keypoints
44 | img = imread(img_fn);
45 | figure('Name', 'Detected Keypoints');
46 | imshow(img); hold on;
47 | plot(kpx, kpy, 'r.', 'MarkerSize', round(size(img,2)/10)); hold on;
48 | 
49 | voffset = -10;
50 | for i=1:length(kpname)
51 |     text(double(kpx(i)), double(kpy(i) + voffset), ...
52 |         kpname{i}, 'color', 'yellow', 'FontSize', 8, ...
53 |         'backgroundcolor', 'black');
54 |     hold on;
55 |     voffset = voffset * -1; % to prevent cluttering of annotations
56 | end
57 | hold off;
58 | 
59 | fprintf(1, '\nShowing detected keypoints:');
60 | for i=1:length(kpname)
61 |     fprintf(1, '\n%s\tat\t(%d,%d)', kpname{i}, kpx(i), kpy(i));
62 | end
63 | fprintf(1, '\n');
--------------------------------------------------------------------------------
/getMPIIData/splitMPIIData_V4.m:
--------------------------------------------------------------------------------
1 | clearvars; close all; clc;
2 | 
3 | %load('extractedData_detector_0_orcacle_0.mat');
4 | imgSize=[256, 256];
5 | load(sprintf('extractedData_%d_%d',imgSize(1),imgSize(2)));
6 | 
7 | %validation data from Tompson
8 | load('mpii_predictions/data/detections');
9 | 
10 | h = waitbar(0,'Please wait...');
11 | 
12 | ptsAll_train=[];
13 | ptsAll_train_box=[];
14 | imgPath_train=[];
15 | ptsAll_test=[];
16 | ptsAll_test_box=[];
17 | imgPath_test=[];
18 | tompson_val=[];
19 | 
20 | idx=1:16;
21 | 
22 | tompson_cnt=0; %counter for the validation data of Tompson
23 | for i=1:size(sets_train_idx,1)
24 | 
25 | cnt=1;
26 | indiv=1;
27 | poseGT(:,:,indiv)=ptsAll{i}(idx,:);
28 | while ~isempty(ptsRest{i,cnt}) && cnt0
30 | indiv=indiv+1;
31 | poseGT(:,:,indiv)=ptsRest{i,cnt}(idx,:);
32 | end
33 | cnt=cnt+1;
34 | end
35 | 
36 | idxImg=find(RELEASE_img_index==sets_train_idx(i,1));
37 | idxPe=find(RELEASE_person_index(idxImg)==sets_train_idx(i,2));
38 | if isempty(idxPe)
39 | 
40 | ptsAll_train{1,size(ptsAll_train,2)+1}=poseGT;
41 | 
42 | 
ptsAll_train_box(:,:,numel(ptsAll_train))=bbox{i};
43 | imgPath_train{numel(imgPath_train)+1}=img_final{i};
44 | else
45 | 
46 | ptsAll_test{1,size(ptsAll_test,2)+1}=poseGT;
47 | 
48 | ptsAll_test_box(:,:,numel(ptsAll_test))=bbox{i};
49 | imgPath_test{numel(imgPath_test)+1}=img_final{i};
50 | 
51 | tompson_val(size(tompson_val,1)+1,1:2)=sets_train_idx(i,:);
52 | tompson_val(size(tompson_val,1),3:6)=bbox{i}; %bounding box for going back to the original coordinate system
53 | tompson_val(size(tompson_val,1),7:10)=pad_train{i};
54 | end
55 | 
56 | % %plot the points
57 | % imshow(img_final{i}); hold on;
58 | % for j=1:size(poseGT,3)
59 | % tempY =poseGT(:,:,j);
60 | % for po=1:size(tempY,1)
61 | % text(tempY(po,1),tempY(po,2), int2str(po),'Color','m','FontSize',15);
62 | % end
63 | % end
64 | % hold off; pause();
65 | 
66 | 
67 | clear poseGT;
68 | 
69 | waitbar(i / size(sets_train_idx,1));
70 | end
71 | close all;
72 | 
73 | clear ptsAll;
74 | 
75 | 
76 | ptsAll=ptsAll_train;
77 | imgPath=imgPath_train;
78 | 
79 | % for cnt=1:numel(imgPath)
80 | % %plot the points
81 | % imshow(imgPath{cnt}); hold on;
82 | % tempY = ptsAll(:,:,cnt);
83 | % for po=1:size(tempY,1)
84 | % text(tempY(po,1),tempY(po,2), int2str(po),'Color','m','FontSize',15);
85 | % end
86 | % hold off; pause();
87 | %
88 | % end
89 | 
90 | if numel(idx)==16
91 | save('MPI_imdbsT1aug0.mat','imgPath','ptsAll','-v7.3'); %pose data
92 | else
93 | %do nothing
94 | end
95 | 
96 | clear ptsAll imgPath;
97 | 
98 | ptsAll=ptsAll_test;
99 | imgPath=imgPath_test;
100 | 
101 | if numel(idx)==16
102 | save('MPI_imdbsV1aug0.mat','imgPath','ptsAll','tompson_val','-v7.3'); %pose data
103 | else
104 | %do nothing
105 | end
106 | 
107 | close(h);
--------------------------------------------------------------------------------
/demolive_keypoints.m:
--------------------------------------------------------------------------------
1 | function demolive_keypoints()
2 | %% Keypoints detection: Human Pose Estimation (Live Demo)
3 | % This is a MatConvNet demo for human pose estimation.
4 | % Related Work: Belagiannis V., and Zisserman A.,
5 | % Recurrent Human Pose Estimation, FG (2017).
6 | % Contact: Vasileios Belagiannis, vb@robots.ox.ac.uk
7 | % Part of this demo has been written by Andrea Vedaldi
8 | % For further details, visit http://www.robots.ox.ac.uk/~vgg/software/keypoint_detection/
9 | %
10 | % Grabs frames from web-camera 1, centre-crops and resizes them to
11 | % opts.imageSize, evaluates the network and overlays one colour-coded
12 | % heatmap per keypoint on the live image. The loop stops when the figure
13 | % is closed.
14 | 
15 | %Set the model name (recpose-iter2 or recposeFT-iter2)
16 | modelName='recposeFT-iter2';
17 | 
18 | %Setup MatConvNet
19 | run('../matconvnet-b23/matlab/vl_setupnn');
20 | 
21 | %Add the Web-Camera package (machine-specific folder, skipped when absent)
22 | webcamDir = '/Users/belajohn/Documents/MATLAB/SupportPackages/R2016a/toolbox/matlab/webcam/supportpackages';
23 | if exist(webcamDir, 'dir')
24 |     addpath(webcamDir);
25 | end
26 | 
27 | % Fixed parameters
28 | GPUt=[]; %default test on CPU
29 | opts.imageSize = [248, 248];
30 | GPUon = numel(GPUt)>0;
31 | 
32 | % Load model (only the variable 'net' is read from the file)
33 | stored = load(modelName, 'net');
34 | net = dagnn.DagNN.loadobj(stored.net) ;
35 | 
36 | % Prediction Layer
37 | pred={'prediction10'};
38 | 
39 | if GPUon
40 |     gpuDevice(GPUt);
41 |     net.move('gpu');
42 | else
43 |     net.move('cpu');
44 | end
45 | net.mode='test';
46 | 
47 | % web-camera
48 | cam = webcam(1);
49 | 
50 | start = tic ;
51 | img = snapshot(cam);
52 | n = 0 ;
53 | 
54 | figure(1) ; clf ; hold all ;
55 | h = opts.imageSize(1) ;
56 | w = opts.imageSize(2) ;
57 | fig.axis = gca ;
58 | fig.keypoints = [1:16] ;
59 | % background surface showing the camera image
60 | fig.bg = surface([1 w ; 1 w], [h h ; 1 1], zeros(2), ...
61 |     'facecolor', 'texturemap', 'cdata', img, ...
62 |     'facealpha', 1, ...
63 |     'edgecolor', 'none') ;
64 | fig.colors = jet(numel(fig.keypoints)) ;
65 | % one uniformly coloured surface per keypoint; its alpha data is replaced
66 | % by the predicted heatmap in the loop below
67 | for k = 1:numel(fig.keypoints)
68 |     icol = repmat(reshape(fig.colors(k,:),1,1,3),[62 62]) ;
69 |     fig.fg{k} = surface([1 w ; 1 w], [h h ; 1 1], zeros(2) + k, ...
70 |         'facecolor', 'texturemap', 'cdata', icol, ...
71 |         'facealpha', 'texturemap', 'alphadata', 0.5*ones(62,62,1), ...
72 |         'edgecolor', 'none') ;
73 | end
74 | set(fig.axis,'fontsize', 18) ;
75 | axis equal off ;
76 | xlim([1 w]) ;
77 | ylim([1 h]) ;
78 | 
79 | % run until the figure (and with it the background surface) is closed
80 | while ishghandle(fig.bg)
81 | 
82 |     % load an image
83 |     img = snapshot(cam);
84 | 
85 |     d = size(img,1)-size(img,2) ;
86 |     dy = floor(max(d,0)/2) ;
87 |     dx = floor(max(-d,0)/2) ;
88 |     img = img(dy+1:end-dy, dx+1:end-dx, :) ; % center crop
89 |     img = imresize(img,opts.imageSize, 'bilinear') ;
90 |     img = single(img)/256 ;
91 |     im_ = img - 0.5 ;
92 | 
93 |     if GPUon
94 |         im_ = gpuArray(im_);
95 |     end
96 | 
97 |     %evaluate the image
98 |     net.eval({'input', im_}) ;
99 | 
100 |     %gather the requested predictions (one cell per requested variable)
101 |     output = cell(numel(pred),1);
102 |     for i=1:numel(pred)
103 |         output{i} = net.vars(net.getVarIndex(pred{i})).value ;
104 |     end
105 | 
106 |     % the figure may have been closed while the network was evaluated
107 |     if ~ishghandle(fig.bg)
108 |         break ;
109 |     end
110 | 
111 |     % plot
112 |     set(fig.bg, 'cdata', img) ;
113 |     for k = 1:numel(fig.keypoints)
114 |         map = output{1}(:,:,k) ;
115 |         if GPUon
116 |             map = gather(map) ; % graphics properties need CPU data
117 |         end
118 |         peak = max(map(:)) ;
119 |         if peak ~= 0 % avoid 0/0 = NaN alpha data for an all-zero map
120 |             map = map / peak ;
121 |         end
122 |         set(fig.fg{k}, 'alphadata', map) ;
123 |     end
124 | 
125 |     elapsed = toc(start) ;
126 |     n = n + 1 ;
127 |     title(sprintf('Keypoint Detection (%.1f Hz)', n/elapsed)) ;
128 |     drawnow ;
129 | end
130 | 
131 | end
--------------------------------------------------------------------------------
/model-train/vl_nnheatloss.m:
--------------------------------------------------------------------------------
1 | function Y = vl_nnheatloss(X,c, dzdy, varargin)
2 | 
3 | %Created by Vasileios Belagiannis. 
4 | %Contact: vb@robots.ox.ac.uk
5 | 
6 | %Heatmap Loss
7 | %
8 | % X     predicted heatmaps.
9 | % c     ground truth. As a cell array, row 2 holds the heatmaps and row 3
10 | %       the gradient weight masks for the '*-heatmap' losses; rows 8 and 9
11 | %       are used for the '*-pairwiseheatmap' losses. For the l2 losses c
12 | %       may also be numeric: it is then used directly as the ground-truth
13 | %       heatmaps and the cell array has to be passed in the 'labels' option.
14 | % dzdy  empty for the forward pass (the loss value is returned); otherwise
15 | %       the backward pass is computed (weighted residual scaled by dzdy for
16 | %       the l2 losses, zeros for the mse losses).
17 | %
18 | % Options: 'loss'   loss name, matched case-insensitively,
19 | %          'ignOcc' if true, entries with negative ground truth are zeroed,
20 | %          'labels' cell array of labels, only needed when c is numeric.
21 | 
22 | opts.loss = 'l2loss-heatmap' ;
23 | opts.ignOcc=0;
24 | opts.labels = {} ; %must be a known field, otherwise vl_argparse rejects 'labels'
25 | opts = vl_argparse(opts,varargin) ;
26 | 
27 | %compare against the lower-case name everywhere, like the switch does
28 | lossName = lower(opts.loss) ;
29 | 
30 | switch lossName
31 |     case {'l2loss-heatmap', 'l2loss-pairwiseheatmap'}
32 |         %rows of the label cell array holding GT heatmaps / weight masks
33 |         if strcmp(lossName,'l2loss-heatmap')
34 |             gtRow = 2; maskRow = 3;
35 |         else
36 |             gtRow = 8; maskRow = 9;
37 |         end
38 | 
39 |         %GT
40 |         if iscell(c)
41 |             Y = cat(4,c{gtRow,:});
42 |         else
43 |             if isempty(opts.labels)
44 |                 error('Numeric ground truth requires the ''labels'' option.') ;
45 |             end
46 |             Y = c;
47 |             c = opts.labels;
48 |         end
49 |         weight_mask = cat(4,c{maskRow,:});
50 | 
51 |         res=(Y-X);
52 | 
53 |         %missing annotation - zeros contribution
54 |         idx=repmat(sum(sum(Y,1),2)==0,size(res,1),size(res,2));
55 |         res(idx)= zerosLike(res(idx)); %check it again!!!
56 | 
57 |         n=1;
58 |         if isempty(dzdy) %forward
59 |             Y = sqrt(sum(res(:).^2))/(size(res,1)*size(res,2)*size(res,3)) *1000;%scale factor
60 |         else
61 | 
62 |             %if occluded keypoints - ignore them
63 |             if opts.ignOcc
64 |                 idxOcc=Y<0;
65 |                 res(idxOcc)= zerosLike(res(idxOcc));
66 |             end
67 | 
68 |             %gradient weighting
69 |             res=weight_mask.*res;
70 | 
71 |             Y_= -1.*res;
72 |             Y = single (Y_ * (dzdy / n) );
73 |         end
74 | 
75 |     case {'mse-heatmap', 'mse-pairwiseheatmap'} %mean squarred error
76 |         %GT stored in sparse matrices stacked next to each other
77 |         if strcmp(lossName,'mse-heatmap')
78 |             Y = cat(4,c{2,:});
79 |         else
80 |             Y = cat(4,c{8,:});
81 |         end
82 | 
83 |         if isempty(dzdy) %forward
84 | 
85 |             fun = @(A,B) A-B;
86 |             err = bsxfun(fun,Y,X);
87 | 
88 |             %missing annotation - zeros contribution
89 |             idx=repmat(sum(sum(Y,1),2)==0,size(err,1),size(err,2));
90 |             err(idx)= zerosLike(err(idx)); %check it again!!!
91 | 
92 |             %if occluded keypoints - ignore them
93 |             if opts.ignOcc
94 |                 idxOcc=Y<0;
95 |                 err(idxOcc)= zerosLike(err(idxOcc));
96 |             end
97 | 
98 |             %dim 1 - 62, dim 2 - 62, dim 3 - 16 (body joints -heatmaps)
99 |             Y = sum(err(:).^2)/(size(X,1)*size(X,2)*size(X,3));
100 | 
101 |         else %nothing to backprop
102 |             Y = zerosLike(X) ;
103 |         end
104 |     otherwise
105 |         error('Unknown loss ''%s''.', opts.loss) ;
106 | end
107 | 
108 | % --------------------------------------------------------------------
109 | function y = zerosLike(x)
110 | % --------------------------------------------------------------------
111 | %single-precision zeros of the size of x, on the GPU when x is a gpuArray
112 | if isa(x,'gpuArray')
113 |     y = gpuArray.zeros(size(x),'single') ;
114 | else
115 |     y = zeros(size(x),'single') ;
116 | end
117 | 
118 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keypoint Detection
2 | Our model uses a convolutional neural network with a recurrent module to detect body keypoints from a single image. Below, we provide code to train and test the model.
3 | 
4 | For a web-based demo visit the [project page](http://www.robots.ox.ac.uk/~vgg/software/keypoint_detection/).
5 | 
6 | ### Software Dependencies
7 | You need MATLAB and the [MatConvNet](https://github.com/vlfeat/matconvnet) toolbox to run this demo. This demo has been compiled and tested for Matlab 2016a, CUDNN v5.1 and cuda 8.0 (using a Linux machine).
8 | 
9 | ### Train Model (MPII Human Pose Dataset)
10 | 1. Download and prepare the dataset by executing **getMPIIData_v3.m** and then **splitMPIIData_V4.m**.
11 | - The data preparation takes several minutes to be completed. The output is the train (**MPI_imdbsT1aug0.mat**, ~3GB) and validation (**MPI_imdbsV1aug0.mat**, ~0.3GB) files.
12 | 2. Execute **trainBodyPose_example.m** (need to set some parameters such as MatConvNet path and GPU index).
13 | 3. To test the model, run the demo explained below.
14 | 15 | Parameters for training a model with one recurrent iteration: 16 | ``` 17 | net = initialize3ObjeRecFusion(opts,2,0,'shareFlag',[0,1]); 18 | 19 | opts.derOutputs = {'objective1', 1,'objective2', 1, 'objective4', 1,'objective5', 1, 'objective7', 1,'objective8', 1}; 20 | 21 | ``` 22 | 23 | Parameters for training a model with two recurrent iterations (**default**): 24 | ``` 25 | net = initialize3ObjeRecFusion(opts,3,0,'shareFlag',[0,1,1]); 26 | 27 | opts.derOutputs = {'objective1', 1,'objective2', 1, 'objective4', 1,'objective5', 1, 'objective7', 1,'objective8', 1, 'objective10', 1,'objective11', 1}; 28 | 29 | ``` 30 | 31 | ### Run Demo 32 | 1. Download a [pre-trained model](https://github.com/ox-vgg/keypoint_models/tree/master/models) if you haven't trained one. 33 | 2. Execute **demo_keypoint.m**. 34 | 35 | ### Run Live Demo 36 | You need to have the **Web-Camera** package installed and set the path to it. 37 | 38 | 1. Download a [pre-trained model](https://github.com/ox-vgg/keypoint_models/tree/master/models) if you haven't trained one. 39 | 2. Execute **demolive_keypoints.m**. 40 | 41 | ## Citation 42 | 43 | @inproceedings{Belagiannis17, 44 | title={Recurrent Human Pose Estimation}, 45 | author={Belagiannis, Vasileios and Zisserman, Andrew}, 46 | booktitle={International Conference on Automatic Face and Gesture Recognition}, 47 | year={2017}, 48 | organization={IEEE} 49 | } 50 | 51 | 52 | For further details on the model and web-based demo, please visit our [project-page](http://www.robots.ox.ac.uk/~vgg/software/keypoint_detection/). 53 | 54 | ## License 55 | 56 | Copyright (c) 2017, University of Oxford 57 | All rights reserved. 58 | 59 | Redistribution and use in source and binary forms, with or without 60 | modification, are permitted provided that the following conditions are met: 61 | 62 | * Redistributions of source code must retain the above copyright notice, this 63 | list of conditions and the following disclaimer. 
64 | 65 | * Redistributions in binary form must reproduce the above copyright notice, 66 | this list of conditions and the following disclaimer in the documentation 67 | and/or other materials provided with the distribution. 68 | 69 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 70 | -------------------------------------------------------------------------------- /KeyPointDetector.m: -------------------------------------------------------------------------------- 1 | classdef KeyPointDetector0 %pad width 59 | sq_img = padarray(img,[0,diff]); 60 | else 61 | sq_img = padarray(img,[-diff,0]); 62 | end 63 | [sq_imdim1, sq_imdim2, ~] = size(sq_img); 64 | 65 | % scale the image to a standard size 66 | sq_norm_img = imresize(sq_img, obj.norm_img_size); 67 | 68 | %single format and mean subtraction 69 | sq_norm_img = single(sq_norm_img); 70 | 71 | if ( obj.gpu_id ~= 0 ) 72 | sq_norm_img = gpuArray(sq_norm_img); 73 | end 74 | 75 | sq_norm_img = bsxfun(@minus, sq_norm_img, single(repmat(128,1,1,3))) ; %subtract mean 76 | sq_norm_img = sq_norm_img./256; 77 | 78 | %evaluate the image 79 | obj.net.mode = 'test'; 80 | obj.net.eval({'input', sq_norm_img}) ; 81 | net_output = obj.net.vars(obj.net.getVarIndex('prediction10')).value ; 82 | heatmap_count = size(net_output, 3); 83 | 84 | 
kpx = zeros(1, heatmap_count); 85 | kpy = zeros(1, heatmap_count); 86 | kpname = cell(1, heatmap_count); 87 | 88 | for hid=1:heatmap_count 89 | hmap = net_output(:, :, hid); 90 | [hy, hx, ~] = find( hmap == max(hmap(:))); 91 | sq_norm_img_x = (hx(1)/obj.net_output_size(1))*obj.norm_img_size(1); 92 | sq_norm_img_y = (hy(1)/obj.net_output_size(2))*obj.norm_img_size(2); 93 | 94 | sq_img_x = (sq_norm_img_x/obj.norm_img_size(1)) * sq_imdim1; 95 | sq_img_y = (sq_norm_img_y/obj.norm_img_size(1)) * sq_imdim2; 96 | 97 | img_x = sq_img_x; 98 | img_y = sq_img_y; 99 | 100 | if diff>0 101 | img_x = img_x - diff; 102 | else 103 | img_y = img_y + diff; 104 | end 105 | 106 | kpx(1,hid) = int32( gather( img_x ) ); 107 | kpy(1,hid) = int32( gather( img_y ) ); 108 | kpname{hid} = obj.keypoint_names{hid}; 109 | end 110 | end 111 | 112 | end 113 | end 114 | -------------------------------------------------------------------------------- /trainBodyPose_example.m: -------------------------------------------------------------------------------- 1 | function trainBodyPose_example(varargin) 2 | %% Keypoints detection: Training on MPII Human Pose Dataset 3 | % This is a MatConvNet demo for human pose human estimation. 4 | % Related Work: Belagiannis V., and Zisserman A., 5 | % Recurrent Human Pose Estimation, FG (2017). 6 | % Contact: Vasileios Belagiannis, vb@robots.ox.ac.uk 7 | 8 | %% Read Me 9 | % Before executing the training script, the MPII dataset hat to be 10 | % downloaded using the following scripts: 11 | % 1. Download and prepare MPII Pose Dataset by executing: 12 | % getMPIIData_v3.m and then splitMPIIData_V4.m (may take several minutes) 13 | % To train a model the following has to be defined: 14 | % - MatConvNet directories 15 | % - Number of GPU (opts.gpus) 16 | 17 | %% MatConvNet library setup (execute vl_setupnn.m) 18 | run(fullfile(fileparts(mfilename('fullpath')),... 
19 | '..', 'matconvnet-b23','matlab', 'vl_setupnn.m')) ; 20 | 21 | %% Pose Example 22 | 23 | addpath('dagnetworks'); 24 | addpath('model-train'); 25 | 26 | opts.datas='MPI'; 27 | 28 | %actual input to the network (after augmentation) 29 | opts.patchHei=248; 30 | opts.patchWi=248; 31 | 32 | opts.flipFlg='mpi';%model for flipping the joint (flip augmentation) 33 | opts.cam=1; 34 | opts.aug=0; 35 | opts.NoAug=1; %used for calling the correct imdb creation function 36 | 37 | opts.batchSize = 20; 38 | opts.numEpochs = 50 ; 39 | opts.learningRate = [0.00001*ones(1, 30) 0.000005*ones(1, 5) 0.000001*ones(1, 50)] ; 40 | opts.batchNormalization = 1;%useful for big networks 41 | 42 | %GPU 43 | opts.gpus = [3]; 44 | 45 | opts.outNode=16;%if heatmaps loss, then number of heatmaps 46 | opts.outPairNode=15;% pairwise terms 47 | opts.inNode=3; 48 | opts.lossFunc='l2loss-heatmap'; 49 | opts.lossFunc2='l2loss-pairwiseheatmap'; 50 | opts.lossFunc3=[]; 51 | opts.ConcFeat=384; %number of channels at concat 52 | opts.skip_layer = 'layer20'; %skip layer 53 | 54 | %export path, imdb store path and location of training / validation data. 55 | opts.expDir = sprintf('../data/v1.00-%s_%s_%d_2Obje3Fus',opts.datas,opts.lossFunc,opts.cam) ; 56 | opts.imdbPath = sprintf('../data/%s-baseline_imdb%d.mat',opts.datas, opts.cam); 57 | opts.DataMatTrain=sprintf('../data/%s_imdbsT%daug%d.mat',opts.datas,opts.cam,opts.aug); 58 | opts.DataMatVal=sprintf('../data/%s_imdbsV%daug%d.mat',opts.datas,opts.cam,opts.aug); 59 | 60 | %transformation from input image (248X248) to ouput heatmap(62X62) 61 | trf=[0.25 0 0 ; 0 0.25 0; 0 0 1]; %only scale 62 | 63 | %build network - 1 iteration 64 | %net = initialize3ObjeRecFusion(opts,2,0,'shareFlag', [0,1]); 65 | 66 | %objectives 67 | %opts.derOutputs = {'objective1', 1,'objective2', 1, ... %feed-forward 68 | % 'objective4', 1,'objective5', 1, ... 
%iter 0 (not-shared w) 69 | % 'objective7', 1,'objective8', 1}; %iter 1 (shared w) 70 | 71 | 72 | 73 | %build network - 2 iterations 74 | net = initialize3ObjeRecFusion(opts,3,0,'shareFlag',[0,1,1]); 75 | 76 | %objectives 77 | opts.derOutputs = {'objective1', 1,'objective2', 1, ... %feed-forward 78 | 'objective4', 1,'objective5', 1, ... %iter 0 (not-shared w) 79 | 'objective7', 1,'objective8', 1, ... %iter 1 (shared w) 80 | 'objective10', 1,'objective11', 1}; %iter 2 (shared w) 81 | 82 | 83 | opts.net=net; 84 | 85 | opts.numThreads = 15; 86 | opts.transformation = 'f25' ; 87 | opts.averageImage = single(repmat(128,1,1,opts.inNode)); 88 | opts.fast = 1; 89 | opts.imageSize = [248, 248] ; 90 | opts.border = [8, 8] ; 91 | opts.bord=[0,0,0,0]; %cropping border 92 | 93 | %heatmap setting; 94 | opts.heatmap=1; 95 | opts.trf=trf; 96 | opts.sigma=1.3; 97 | opts.FiltSize=31; 98 | opts.HeatMapSize=[62, 62]; 99 | opts.padGTim=[0 0]; 100 | opts.rotate=1;%rotation flag 101 | opts.scale=1;%scale augm. 102 | %extra parse settings 103 | 104 | %occluded keypoints 105 | opts.inOcclud=1; 106 | 107 | %multiple instances 108 | opts.multipInst=0; 109 | 110 | %heatmap scheme 111 | opts.HeatMapScheme=1; %how to generate heatmaps 112 | 113 | opts.train.momentum=0.95; 114 | 115 | opts.negHeat=0;%set to 1 to include negative values for the occlusion 116 | opts.ignoreOcc=0;%set to 1 to include negative values for the occlusion 117 | opts.ignoreRest=0; %quasi single human training 118 | 119 | opts.pairHeatmap=1; %generate heatmaps for pairs of body parts 120 | opts.bodyPairs = [1 2 3 4 5 7 8 9 11 12 13 14 14 15 7; 2 3 7 5 6 4 10 10 12 13 8 8 15 16 8]; %full body - MPI 121 | 122 | opts.magnif=12;%amplifier for the body heatmaps 123 | opts.facX=0.15;%pairwise heatmap width (def. 
0.15) 124 | opts.facY=0.08;%pairwise heatmap height 125 | 126 | %imdb generation function 127 | opts.imdbfn =@getImdbNoAug; 128 | 129 | opts = vl_argparse(opts, varargin); 130 | 131 | %create imdb and train 132 | cnn_regressor_dag(opts); 133 | -------------------------------------------------------------------------------- /model-train/cnn_regressor_dag.m: --------------------------------------------------------------------------------
function [net, info] = cnn_regressor_dag(varargin)
%CNN_REGRESSOR_DAG Create (or load) the imdb and train the model.
%   [NET, INFO] = CNN_REGRESSOR_DAG(...) parses the supplied options in
%   three passes (main settings, training settings, batch settings), loads
%   the imdb from OPTS.IMDBPATH or builds it with OPTS.IMDBFN, and trains
%   the network with CNN_TRAIN_DAG_REG.
%
%   The arguments left over after the three parsing passes must carry the
%   network to train in a field called `net`.
%
%   Returns:
%     NET   the network returned by CNN_TRAIN_DAG_REG.
%     INFO  the statistics returned by CNN_TRAIN_DAG_REG.

% Dataset
opts.datas='BBC';

% Network input resolution
opts.patchHei=120;
opts.patchWi=80;

% Camera (always 1 for this setup)
opts.cam=1;

% Augmentation settings
opts.aug=0;
opts.NoAug=0;

% Export directory for model and imdb
opts.expDir = sprintf('/data/vb/Temp/%s-baseline%d',opts.datas,opts.cam) ;
opts.imdbPath = fullfile(opts.expDir, sprintf('imdb%d.mat',opts.cam));

opts.train.batchSize = 256 ;
opts.train.numSubBatches = 1;
opts.train.numEpochs = 100 ;
opts.train.continue = true ;
opts.train.derOutputs= {'objective', 1} ;
opts.train.learningRate = [0.001*ones(1, 17) 0.0005*ones(1, 50) 0.002*ones(1, 500) 0.03*ones(1, 130) 0.01*ones(1, 100)] ;
opts.train.momentum=0.9;
opts.useBnorm = false ;
opts.batchNormalization = 0;
opts.train.prefetch = false ;

%GPU
opts.train.gpus = [];

% Architecture parameters
opts.initNet='/home/bazile/Temp/data/tukey0.mat'; %pre-trained network
opts.outNode=14;%14 bbc, 18,28,42
opts.outPairNode=8;% pairwise terms
opts.outCombiNode=5;
opts.inNode=3;
opts.lossFunc='tukeyloss-heatmap';
opts.lossFunc2='tukeyloss-pairwiseheatmap';
opts.lossFunc3=[];
opts.errMetric = 'mse-combo';
opts.train.thrs=0;
opts.train.refine=false;
opts.HighRes = 0; %high resolution output
opts.ConcFeat=768; %number of channels at concat
opts.skip_layer = ''; %skip layer
opts.train.hardNeg=0;%hard negative mining

% Dataset (train and validation files)
opts.DataMatTrain=sprintf('/mnt/ramdisk/vb/%s/%s_imdbsT%daug%d.mat',opts.datas,opts.datas,opts.cam,opts.aug);
opts.DataMatVal=sprintf('/mnt/ramdisk/vb/%s/%s_imdbsV%daug%d.mat',opts.datas,opts.datas,opts.cam,opts.aug);
opts.DataMatTrain2=[]; %for combination of different datasets

% IMDB generation function
opts.imdbfn= [];

% Batch parameters
bopts.numThreads = 15;
bopts.transformation = 'f5' ;
bopts.averageImage = single(repmat(128,1,1,opts.inNode));
bopts.imageSize = [120, 80] ;
bopts.border = [10, 10] ;
bopts.heatmap=0;
bopts.trf=[];
bopts.sigma=[];
bopts.HeatMapSize=[];
bopts.flipFlg='bbc';%full, bbc
bopts.inOcclud=1; %include occluded points
bopts.multipInst=1; %include multiple instances in the heatmaps
bopts.HeatMapScheme=1; %how to generate heatmaps
bopts.rotate=0;%rotation augm.
bopts.scale=0;%scale augm.
bopts.color=0;%color augm.
bopts.pairHeatmap=0;
bopts.bodyPairs = [];
bopts.ignoreOcc=0;%requires
bopts.magnif=8;%amplifier for the body heatmaps
bopts.facX=0.15;%pairwise heatmap width
bopts.facY=0.08;%pairwise heatmap height

% Parse settings: each pass consumes the fields it knows and hands the
% remainder to the next one.
[opts, trainParams] = vl_argparse(opts, varargin); %main settings
[opts.train, boptsParams]= vl_argparse(opts.train, trainParams); %train settings
[bopts, netParams]= vl_argparse(bopts, boptsParams); %batch settings
net=netParams{1}.net; %network
clear trainParams boptsParams netParams;

opts.train.bodyPairs = bopts.bodyPairs;%structured prediction training
opts.train.trf = bopts.trf;%transformation from the input to the output space

useGpu = numel(opts.train.gpus) > 0 ;
bopts.GPU=useGpu;

%Paths OSX / Ubuntu
opts.train.expDir = opts.expDir ;

% --------------------------------------------------------------------
%                                                         Prepare data
% --------------------------------------------------------------------

% Ask explicitly for a file: a bare exist(name) also reports variables,
% folders and functions that happen to share the name.
if exist(opts.imdbPath, 'file') == 2
  imdb = load(opts.imdbPath);
else
  if ~isa(opts.imdbfn, 'function_handle')
    error('cnn_regressor_dag:missingImdbFn', ...
      'No imdb found at ''%s'' and opts.imdbfn is not a function handle.', ...
      opts.imdbPath) ;
  end
  imdb = opts.imdbfn(opts);
  % Only create the export directory when needed (avoids the MATLAB
  % "directory already exists" warning).
  if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
  save(opts.imdbPath, '-struct', 'imdb','-v7.3') ;
end

% --------------------------------------------------------------------
%                                                                Train
% --------------------------------------------------------------------

fn = getBatchDagNNWrapper(bopts,useGpu) ;

% cnn_train_dag_reg returns [net, stats]; capture both so that INFO holds
% the statistics and NET is the network handed back by the trainer (which
% may have been reloaded from a checkpoint).
[net, info] = cnn_train_dag_reg(net, imdb, fn, opts.train) ;

% -------------------------------------------------------------------------
function fn = getBatchDagNNWrapper(opts, useGpu)
% -------------------------------------------------------------------------
% Bind the batch options and the GPU flag into a two-argument handle
% fn(imdb, batch), the form in which the trainer calls it.
fn = @(imdb,batch) getBatchDagNN(imdb,batch,opts,useGpu) ;

% -------------------------------------------------------------------------
function inputs = getBatchDagNN(imdb, batch, opts, useGpu)
% -------------------------------------------------------------------------
% Body of getBatchDagNN: fetch the images and labels for BATCH. When the
% caller requests no output, 'prefetch' is set (nargout == 0) and nothing
% is returned.

[im, lab] = cnn_regressor_get_batch(imdb, batch, opts, ...
                                    'prefetch', nargout == 0) ;
if nargout > 0
  if useGpu
    im = gpuArray(im) ;
  end
  inputs = {'input', im, 'label', lab} ;
end
-------------------------------------------------------------------------------- /model-train/cnn_train_dag_reg.m: --------------------------------------------------------------------------------
function [net,stats] = cnn_train_dag_reg(net, imdb, getBatch, varargin)
%CNN_TRAIN_DAG Demonstrates training a CNN using the DagNN wrapper
%    CNN_TRAIN_DAG() is similar to CNN_TRAIN(), but works with
%    the DagNN wrapper instead of the SimpleNN wrapper.
%
%    [NET, STATS] = CNN_TRAIN_DAG_REG(NET, IMDB, GETBATCH, ...) trains NET
%    for OPTS.NUMEPOCHS epochs, resuming from the last checkpoint found in
%    OPTS.EXPDIR when OPTS.CONTINUE is true. GETBATCH is called as
%    GETBATCH(IMDB, BATCH), where BATCH holds entries of OPTS.TRAIN or
%    OPTS.VAL. STATS contains one entry per epoch for the 'train' and
%    'val' modes.

% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).

opts.expDir = fullfile('data','exp') ;
opts.continue = true ;
opts.batchSize = 256 ;
opts.numSubBatches = 1 ;
opts.train = [] ;
opts.val = [] ;
opts.gpus = [] ;
opts.prefetch = false ;
opts.numEpochs = 300 ;
opts.learningRate = 0.001 ;
opts.weightDecay = 0.0005 ;
opts.momentum = 0.9 ;
opts.refine=false;%my param
opts.scbox=0;%my param
opts.iter=1;%my param
opts.thrs=0;%my param
opts.bodyPairs=[];%my param
opts.trf=[];%my param
opts.hardNeg=0;%my param
opts.shuffle=1;%my param
opts.saveMomentum = true ;
opts.nesterovUpdate = false ;
opts.randomSeed = 0 ;
opts.profile = false ;
opts.parameterServer.method = 'mmap' ;
opts.parameterServer.prefix = 'mcn' ;

opts.derOutputs = {'objective', 1} ;
opts.extractStatsFn = @extractStats ;
opts.plotStatistics = true;
opts = vl_argparse(opts, varargin) ;

if ~exist(opts.expDir, 'dir'), mkdir(opts.expDir) ; end
if isempty(opts.train), opts.train = find(imdb.images.set==1) ; end
if isempty(opts.val), opts.val = find(imdb.images.set==2) ; end
if isnan(opts.train), opts.train = [] ; end
if isnan(opts.val), opts.val = [] ; end

% -------------------------------------------------------------------------
%                                                            Initialization
% -------------------------------------------------------------------------

evaluateMode = isempty(opts.train) ;
if ~evaluateMode
  if isempty(opts.derOutputs)
    % error() with a single argument does not expand escape sequences, so
    % the message carries no trailing '\n'.
    error('DEROUTPUTS must be specified when training.') ;
  end
end

% -------------------------------------------------------------------------
%                                                        Train and validate
% -------------------------------------------------------------------------

modelPath = @(ep) fullfile(opts.expDir, sprintf('net-epoch-%d.mat', ep));
modelFigPath = fullfile(opts.expDir, 'net-train.pdf') ;

start = opts.continue * findLastCheckpoint(opts.expDir) ;
if start >= 1
  fprintf('%s: resuming by loading epoch %d\n', mfilename, start) ;
  [net, state, stats] = loadState(modelPath(start)) ;
else
  state = [] ;
end

for epoch=start+1:opts.numEpochs

  % Set the random seed based on the epoch and opts.randomSeed.
  % This is important for reproducibility, including when training
  % is restarted from a checkpoint.

  rng(epoch + opts.randomSeed) ;
  prepareGPUs(opts, epoch == start+1) ;

  % Train for one epoch.
  params = opts ;
  params.epoch = epoch ;
  params.learningRate = opts.learningRate(min(epoch, numel(opts.learningRate))) ;
  if opts.shuffle
    params.train = opts.train(randperm(numel(opts.train))) ; % shuffle
    params.val = opts.val(randperm(numel(opts.val))) ;
  else
    %this is used for processing videos. it creates random subbatches
    % of consecutive frames

    numTrain = numel(opts.train) ;
    if numTrain == 0
      params.train = opts.train ;
    else
      % A window can never be longer than the training set itself;
      % without this cap no window exists and the indexing below fails.
      Nel = min(opts.batchSize, numTrain) ;
      numWindows = numTrain - Nel + 1 ;
      % Row i holds the positions i .. i+Nel-1 of one window of
      % consecutive frames.
      combis = zeros(numWindows, Nel) ;
      for i=1:numWindows
        combis(i,:)=i:i+Nel-1;
      end
      part_order = randperm(numWindows);
      tmp = combis(part_order,:)';
      tmp = tmp(:)';
      % The windows contain positions within opts.train, so map them back
      % to the actual image indices before handing them to getBatch.
      params.train = reshape(opts.train(tmp(1:numTrain)), 1, []) ;
    end
    params.val = opts.val(randperm(numel(opts.val))) ; %not needed yet
  end
  params.imdb = imdb ;
  params.getBatch = getBatch ;

  if numel(opts.gpus) <= 1
    [net, state] = processEpoch(net, state, params, 'train') ;
    [net, state] = processEpoch(net, state, params, 'val') ;
    if ~evaluateMode
      saveState(modelPath(epoch), net, state) ;
    end
    lastStats = state.stats ;
  else
    spmd
      [net, state] = processEpoch(net, state, params, 'train') ;
      [net, state] = processEpoch(net, state, params, 'val') ;
      if labindex == 1 && ~evaluateMode
        saveState(modelPath(epoch), net, state) ;
      end
      lastStats = state.stats ;
    end
    lastStats = accumulateStats(lastStats) ;
  end

  stats.train(epoch) = lastStats.train ;
  stats.val(epoch) = lastStats.val ;
  clear lastStats ;
  saveStats(modelPath(epoch), stats) ;

  if opts.plotStatistics
    switchFigure(1) ; clf ;
    plots = setdiff(...
      cat(2,...
      fieldnames(stats.train)', ...
      fieldnames(stats.val)'), {'num', 'time'}) ;
    for p = plots
      p = char(p) ;
      values = zeros(0, epoch) ;
      leg = {} ;
      for f = {'train', 'val'}
        f = char(f) ;
        if isfield(stats.(f), p)
          tmp = [stats.(f).(p)] ;
          values(end+1,:) = tmp(1,:)' ;
          leg{end+1} = f ;
        end
      end
      subplot(1,numel(plots),find(strcmp(p,plots))) ;
      plot(1:epoch, values','o-') ;
      xlabel('epoch') ;
      title(p) ;
      legend(leg{:}) ;
      grid on ;
    end
    drawnow ;
    print(1, modelFigPath, '-dpdf') ;
  end
end

% With multiple GPUs, return one copy
if isa(net, 'Composite'), net = net{1} ; end

% -------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
% -------------------------------------------------------------------------
% Run one pass over params.(mode), where MODE is 'train' or 'val', and
% store the resulting statistics in state.stats.(mode).
%
% Note that net is not strictly needed as an output argument as net
% is a handle class. However, this fixes some aliasing issue in the
% spmd caller.

% initialize with momentum 0
if isempty(state) || isempty(state.momentum)
  state.momentum = num2cell(zeros(1, numel(net.params))) ;
end

% move CNN to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
  net.move('gpu') ;
  state.momentum = cellfun(@gpuArray, state.momentum, 'uniformoutput', false) ;
end
if numGpus > 1
  parserv = ParameterServer(params.parameterServer) ;
  net.setParameterServer(parserv) ;
else
  parserv = [] ;
end

% profile
if params.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

num = 0 ;
epoch = params.epoch ;
subset = params.(mode) ;
adjustTime = 0 ;

% %find all loss layers to pass parameters afterwards
% reg_layers = [];
% for i=1:2:numel(opts.derOutputs)
%     obj_name = opts.derOutputs{i};
%     reg_layers(numel(reg_layers)+1) = net.getLayerIndex(obj_name);
%     idx=sscanf(obj_name, 'objective%d');
%     err_name = sprintf('error%d',idx);
%     reg_layers(numel(reg_layers)+1) = net.getLayerIndex(err_name);
% end
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;

start = tic ;
for t=1:params.batchSize:numel(subset)
  fprintf('%s: epoch %02d: %3d/%3d:', mode, epoch, ...
          fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
  batchSize = min(params.batchSize, numel(subset) - t + 1) ;

  for s=1:params.numSubBatches
    % get this image batch and prefetch the next
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+params.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    %%pass parameters to regloss
    %for i=1:numel(reg_layers)
    %    iter=((state.epoch-1)*opts.batchSize) + round(t/opts.batchSize)+1;
    %    net.layers(reg_layers(i)).block.iter = iter;
    %    net.layers(reg_layers(i)).block.scbox = opts.scbox;
    %    net.layers(reg_layers(i)).block.hardNeg = opts.hardNeg;
    %end

    clear inputs;
    inputs = params.getBatch(params.imdb, batch) ;

    if params.prefetch
      if s == params.numSubBatches
        batchStart = t + (labindex-1) + params.batchSize ;
        batchEnd = min(t+2*params.batchSize-1, numel(subset)) ;
      else
        batchStart = batchStart + numlabs ;
      end
      nextBatch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
      params.getBatch(params.imdb, nextBatch) ;
    end

    if strcmp(mode, 'train')
      net.mode = 'normal' ;
      % Derivatives are accumulated over the sub-batches of one batch;
      % the first sub-batch starts from scratch.
      net.accumulateParamDers = (s ~= 1) ;
      net.eval(inputs, params.derOutputs, 'holdOn', s < params.numSubBatches) ;
    else
      net.mode = 'test' ;
      net.eval(inputs) ;
    end
  end

  % Accumulate gradient.
  if strcmp(mode, 'train')
    if ~isempty(parserv), parserv.sync() ; end
    state = accumulateGradients(net, state, params, batchSize, parserv) ;
  end

  % Get statistics.
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats.num = num ;
  stats.time = time ;
  stats = params.extractStatsFn(stats,net) ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == 3*params.batchSize + 1
    % compensate for the first three iterations, which are outliers
    adjustTime = 4*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s: %.5f', f, stats.(f)) ;
  end
  fprintf('\n') ;
end

% Save back to state.
state.stats.(mode) = stats ;
if params.profile
  if numGpus <= 1
    state.prof.(mode) = profile('info') ;
    profile off ;
  else
    state.prof.(mode) = mpiprofile('info');
    mpiprofile off ;
  end
end
if ~params.saveMomentum
  state.momentum = [] ;
else
  state.momentum = cellfun(@gather, state.momentum, 'uniformoutput', false) ;
end

net.reset() ;
net.move('cpu') ;

% -------------------------------------------------------------------------
function state = accumulateGradients(net, state, params, batchSize, parserv)
% -------------------------------------------------------------------------
% Apply one update to every parameter of NET. Derivatives come from the
% parameter server when PARSERV is non-empty and from net.params(p).der
% otherwise. Parameters with trainMethod 'average' are updated as a moving
% average; 'gradient' parameters use momentum and weight decay.

for p=1:numel(net.params)

  if ~isempty(parserv)
    parDer = parserv.pullWithIndex(p) ;
  else
    parDer = net.params(p).der ;
  end

  switch net.params(p).trainMethod

    case 'average' % mainly for batch normalization
      thisLR = net.params(p).learningRate ;
      net.params(p).value = vl_taccum(...
          1 - thisLR, net.params(p).value, ...
          (thisLR/batchSize/net.params(p).fanout), parDer) ;

    case 'gradient'
      thisDecay = params.weightDecay * net.params(p).weightDecay ;
      thisLR = params.learningRate * net.params(p).learningRate ;

      % Normalize gradient and incorporate weight decay.
      parDer = vl_taccum(1/batchSize, parDer, ...
                         thisDecay, net.params(p).value) ;

      % Update momentum.
      state.momentum{p} = vl_taccum(...
        params.momentum, state.momentum{p}, ...
        -1, parDer) ;

      % Nesterov update (aka one step ahead).
      if params.nesterovUpdate
        delta = vl_taccum(...
          params.momentum, state.momentum{p}, ...
          -1, parDer) ;
      else
        delta = state.momentum{p} ;
      end

      % Update parameters.
      net.params(p).value = vl_taccum(...
        1, net.params(p).value, thisLR, delta) ;

    otherwise
      error('Unknown training method ''%s'' for parameter ''%s''.', ...
        net.params(p).trainMethod, ...
        net.params(p).name) ;
  end
end

% -------------------------------------------------------------------------
function stats = accumulateStats(stats_)
% -------------------------------------------------------------------------
% Merge the per-worker statistics in STATS_ into a single structure: every
% field except 'num' becomes the average weighted by each worker's 'num',
% and 'num' becomes the total.

for s = {'train', 'val'}
  s = char(s) ;
  total = 0 ;

  % initialize stats structure with same fields and same order as
  % stats_{1}
  stats__ = stats_{1} ;
  names = fieldnames(stats__.(s))' ;
  values = zeros(1, numel(names)) ;
  fields = cat(1, names, num2cell(values)) ;
  stats.(s) = struct(fields{:}) ;

  for g = 1:numel(stats_)
    stats__ = stats_{g} ;
    num__ = stats__.(s).num ;
    total = total + num__ ;

    for f = setdiff(fieldnames(stats__.(s))', 'num')
      f = char(f) ;
      stats.(s).(f) = stats.(s).(f) + stats__.(s).(f) * num__ ;

      if g == numel(stats_)
        stats.(s).(f) = stats.(s).(f) / total ;
      end
    end
  end
  stats.(s).num = total ;
end

% -------------------------------------------------------------------------
function stats = extractStats(stats, net)
% -------------------------------------------------------------------------
% Copy the running average of every loss layer into STATS, using the name
% of the layer's first output as the field name. The classes are visited
% in the order listed below.
lossClasses = {'dagnn.Loss', ...           %generic loss
               'dagnn.RegLoss', ...        %regression loss
               'dagnn.StructLoss', ...     %structured loss
               'dagnn.TempStructLoss', ... %temporal structured loss
               'dagnn.FlowLoss'} ;         %optical flow loss
for c = 1:numel(lossClasses)
  sel = find(cellfun(@(x) isa(x,lossClasses{c}), {net.layers.block})) ;
  for i = 1:numel(sel)
    stats.(net.layers(sel(i)).outputs{1}) = net.layers(sel(i)).block.average ;
  end
end

% -------------------------------------------------------------------------
function saveState(fileName, net_, state)
% -------------------------------------------------------------------------
% Store the serialised network and the training state in FILENAME.
net = net_.saveobj() ;
save(fileName, 'net', 'state') ;

% -------------------------------------------------------------------------
function saveStats(fileName, stats)
% -------------------------------------------------------------------------
% Append STATS to FILENAME when the file already exists, otherwise create
% the file with STATS only.
if exist(fileName, 'file')
  save(fileName, 'stats', '-append') ;
else
  save(fileName, 'stats') ;
end

% -------------------------------------------------------------------------
function [net, state, stats] = loadState(fileName)
% -------------------------------------------------------------------------
% Load the network, training state and statistics of a checkpoint. A file
% without 'stats' is reported as a partially saved epoch.
load(fileName, 'net', 'state', 'stats') ;
net = dagnn.DagNN.loadobj(net) ;
if isempty(whos('stats'))
  error('Epoch ''%s'' was only partially saved. Delete this file and try again.', ...
        fileName) ;
end

% -------------------------------------------------------------------------
function epoch = findLastCheckpoint(modelDir)
% -------------------------------------------------------------------------
% Return the highest N for which MODELDIR contains net-epoch-N.mat, or 0
% when there is no such file.
list = dir(fullfile(modelDir, 'net-epoch-*.mat')) ;
% Anchor the pattern and escape the dot so only purely numeric epoch
% suffixes match.
tokens = regexp({list.name}, '^net-epoch-(\d+)\.mat$', 'tokens') ;
% Files picked up by the wildcard but not matching the pattern (e.g.
% net-epoch-best.mat) yield empty tokens; skip them instead of indexing
% into an empty cell.
tokens = tokens(~cellfun(@isempty, tokens)) ;
epoch = cellfun(@(x) sscanf(x{1}{1}, '%d'), tokens) ;
epoch = max([epoch 0]) ;

% -------------------------------------------------------------------------
function switchFigure(n)
% -------------------------------------------------------------------------
% Make figure N the current figure, creating it if selecting it fails.
if get(0,'CurrentFigure') ~= n
  try
    set(0,'CurrentFigure',n) ;
  catch
    figure(n) ;
  end
end

% -------------------------------------------------------------------------
function clearMex()
% -------------------------------------------------------------------------
clear vl_tmove vl_imreadjpeg ;

% -------------------------------------------------------------------------
function prepareGPUs(opts, cold)
% -------------------------------------------------------------------------
% Select the GPU(s) listed in opts.gpus. With several GPUs a parallel pool
% with one worker per GPU is (re)created first; creating a pool forces a
% cold start. The devices are only reset on a cold start.
numGpus = numel(opts.gpus) ;
if numGpus > 1
  % check parallel pool integrity as it could have timed out
  pool = gcp('nocreate') ;
  if ~isempty(pool) && pool.NumWorkers ~= numGpus
    delete(pool) ;
  end
  pool = gcp('nocreate') ;
  if isempty(pool)
    parpool('local', numGpus) ;
    cold = true ;
  end

end
if numGpus >= 1 && cold
  fprintf('%s: resetting GPU\n', mfilename)
  clearMex() ;
  if numGpus == 1
    gpuDevice(opts.gpus)
  else
    spmd
      clearMex() ;
      gpuDevice(opts.gpus(labindex))
    end
  end
end
-------------------------------------------------------------------------------- /dagnetworks/initialize3ObjeRecFusion.m: -------------------------------------------------------------------------------- 1 | function net = initialize3ObjeRecFusion(opts,Niter,resConn, varargin) 2 | % Related Work: Belagiannis V., and Zisserman A., 3 | % Recurrent Human Pose Estimation, FG (2017). 4 | % Contact: Vasileios Belagiannis, vb@robots.ox.ac.uk 5 | % The default network is defined to perform 2 iterations 6 | % (i.e. 2 recurrent layers: 1 without shared weights and another with shared) 7 | % Niter: number of iterations (1 - no shared w, 2 - shared w and etc..) 8 | % resConn: Residual connection(not maintained anymore, it should stay to 0) 9 | % To build a network with 1 non-shared and 1 recurrent iteration: 10 | % net = initialize3ObjeRecFusion(opts,2,0,[0,1]); 11 | % To build a network with 1 non-shared and 2 recurrent iterations: 12 | % net = initialize3ObjeRecFusion(opts,3,0,[0,1,1]); 13 | 14 | scal = 1 ; 15 | init_bias = 0.0; 16 | net.layers = {} ; 17 | opts.cudnnWorkspaceLimit = 1024*1024*1024*5; 18 | lopts.shareFlag=[0,1];%default: 2 iterations (1. not shared w, 2.
shared w) 19 | lopts.ref_idx=[];%reference layer to get the share weights 20 | lopts = vl_argparse(lopts, varargin); 21 | convOpts = {'CudnnWorkspaceLimit', opts.cudnnWorkspaceLimit} ; 22 | id=0; 23 | 24 | %define the reference recurrent layer for the shared weights 25 | lopts.ref_idx=0; 26 | i=0; 27 | while i0 133 | %for j=1:numel(data.annolist(i).annorect) 134 | for k=1:numel(data.single_person{i}) 135 | j = data.single_person{i}(k); 136 | if isfield(data.annolist(i).annorect(j),'objpos') %GT exists 137 | if ~isempty(data.annolist(i).annorect(j).objpos) %localize objects 138 | cnt_test = cnt_test + 1; 139 | 140 | hei=round(baseH*data.annolist(i).annorect(j).scale); 141 | wid=round(baseW*data.annolist(i).annorect(j).scale); 142 | 143 | obj_pose(1) = data.annolist(i).annorect(j).objpos.x; 144 | obj_pose(2) = data.annolist(i).annorect(j).objpos.y; 145 | 146 | xUpLe=round(obj_pose(1)-wid/2); 147 | yUpLe=round(obj_pose(2)-hei/2); 148 | 149 | %check if the bounding box exceeds the image plane 150 | %and pad the image and all GT poses 151 | padUpX = ceil(abs(min(0,xUpLe))); 152 | padUpY = ceil(abs(min(0,yUpLe))); 153 | xUpLe = max(1,xUpLe + padUpX); 154 | yUpLe = max(1,yUpLe + padUpY); 155 | padLoX = ceil(max(size(img,2), xUpLe + wid) - size(img,2)); 156 | padLoY = ceil(max(size(img,1), yUpLe + hei) - size(img,1)); 157 | 158 | imgPad = uint8(128*ones(padUpY+size(img,1)+padLoY, ... 
159 | padUpX+size(img,2)+padLoX,3)); 160 | imgPad(1+padUpY:padUpY+size(img,1),1+padUpX:padUpX+size(img,2),:) = img; 161 | 162 | %check if the bounding box exceeds the image plane 163 | %xUpLe=round(max(1,xUpLe)); 164 | %wid=round(min(size(img,2),xUpLe+wid-1)-xUpLe); 165 | %yUpLe=round(max(1,yUpLe)); 166 | %hei=round(min(size(img,1),yUpLe+hei-1)-yUpLe); 167 | 168 | bbox_test{cnt_test} = [xUpLe, yUpLe, wid, hei]; 169 | pad_test{cnt_test} = [padUpX, padUpY, padLoX, padLoY]; 170 | 171 | %crop the image 172 | img_final_test{cnt_test}=imgPad(yUpLe:yUpLe+hei,xUpLe:xUpLe+wid,:); 173 | 174 | %resize to standard size 175 | s_s = [size(img_final_test{cnt_test},1) size(img_final_test{cnt_test},2)]; 176 | s_t = [imgSize(1) imgSize(2)]; 177 | s = s_s.\s_t; 178 | 179 | %image resized 180 | img_final_test{cnt_test} = imresize(img_final_test{cnt_test}, 'scale', s, 'method', 'bilinear'); 181 | 182 | %index for mapping back 183 | testMap(cnt_test,1) = i; 184 | testMap(cnt_test,2) = j; 185 | 186 | %imshow(img_final_test{cnt_test});disp(i); pause(); close; 187 | end 188 | end 189 | end 190 | end 191 | 192 | %for each instance 193 | for j=1:numel(data.annolist(i).annorect) 194 | 195 | %if isfield(data.annolist(i).annorect(j),'objpos') 196 | if isfield(data.annolist(i).annorect(j),'annopoints') %GT exists 197 | if ~isempty(data.annolist(i).annorect(j).objpos) %localize objects 198 | cnt=cnt+1; 199 | 200 | sets_train(i)=1; %include in training 201 | 202 | sets_train_idx(cnt,1)=i; 203 | sets_train_idx(cnt,2)=j; 204 | 205 | %store the joints of the active individual 206 | ptsAll{cnt}=zeros(16,3); 207 | poseGT = data.annolist(i).annorect(j).annopoints.point; 208 | for p=1:numel(poseGT) 209 | ptsAll{cnt}(poseGT(p).id+1,1)=poseGT(p).x +1;%1-indexed 210 | ptsAll{cnt}(poseGT(p).id+1,2)=poseGT(p).y +1; 211 | if ~isempty(poseGT(p).is_visible) 212 | ptsAll{cnt}(poseGT(p).id+1,3)=poseGT(p).is_visible; 213 | else 214 | ptsAll{cnt}(poseGT(p).id+1,3)=1;%head 215 | end 216 | end 217 | 218 | % %plot the 
keypoints 219 | % for p=1:size(ptsAll{cnt},1) 220 | % if (ptsAll{cnt}(p,1)==0 && ptsAll{cnt}(p,3)==1) || (ptsAll{cnt}(p,2)==0 && ptsAll{cnt}(p,3)==1) 221 | % imshow(img); hold on; 222 | % plot(ptsAll{cnt}(p,1),ptsAll{cnt}(p,2),'rx'); 223 | % hold off; pause(); 224 | % end 225 | % end 226 | 227 | %store the joints of the rest individuals 228 | valPerson=0; 229 | if ~isempty(data.annolist(i).valPerson) %check for validation indiv. 230 | valPerson=data.annolist(i).valPerson; 231 | end 232 | 233 | ptsRest{cnt,1}=[];%required initialization 234 | if valPerson==0 %multiple individuals only for training frames 235 | 236 | cnt_rest=0; 237 | for jrest=1:numel(data.annolist(i).annorect) 238 | %exclude the active indiv. & validation indiv. (if any) 239 | if jrest~=j && jrest~=valPerson && ... 240 | ~isempty(data.annolist(i).annorect(jrest).annopoints) %missing annotation 241 | cnt_rest=cnt_rest+1; 242 | ptsRest{cnt,cnt_rest}=zeros(16,3); 243 | poseGT = data.annolist(i).annorect(jrest).annopoints.point; 244 | for p=1:numel(poseGT) 245 | ptsRest{cnt,cnt_rest}(poseGT(p).id+1,1)=poseGT(p).x +1;%1-indexed; 246 | ptsRest{cnt,cnt_rest}(poseGT(p).id+1,2)=poseGT(p).y +1;%1-indexed; 247 | if ~isempty(poseGT(p).is_visible) 248 | ptsRest{cnt,cnt_rest}(poseGT(p).id+1,3)=poseGT(p).is_visible; 249 | else 250 | ptsRest{cnt,cnt_rest}(poseGT(p).id+1,3)=1;%head 251 | end 252 | end 253 | end 254 | end 255 | 256 | %debug 257 | %if jrest==valPerson 258 | % disp(['validation individual should be excluded ' num2str(valPerson)]); 259 | %end 260 | 261 | end %if - valPerson==0 262 | 263 | hei=round(baseH*data.annolist(i).annorect(j).scale); 264 | wid=round(baseW*data.annolist(i).annorect(j).scale); 265 | 266 | obj_pose(1) = data.annolist(i).annorect(j).objpos.x; 267 | obj_pose(2) = data.annolist(i).annorect(j).objpos.y; 268 | 269 | %imshow(img); hold on; 270 | %text(obj_pose(1),obj_pose(2),'C','Color','m','FontSize',22); 271 | %pause(); hold off; 272 | 273 | xUpLe=round(obj_pose(1)-wid/2); 274 | 
yUpLe=round(obj_pose(2)-hei/2);

%check if the bounding box exceeds the image plane
%and pad the image and all GT poses
padUpX = ceil(abs(min(0,xUpLe)));
padUpY = ceil(abs(min(0,yUpLe)));
xUpLe = max(1,xUpLe + padUpX);
yUpLe = max(1,yUpLe + padUpY);
padLoX = ceil(max(size(img,2), xUpLe + wid) - size(img,2));
padLoY = ceil(max(size(img,1), yUpLe + hei) - size(img,1));

%padded canvas filled with the value 128; the original image is copied
%in at offset (padUpX, padUpY)
imgPad = uint8(128*ones(padUpY+size(img,1)+padLoY, ...
    padUpX+size(img,2)+padLoX,3));
imgPad(1+padUpY:padUpY+size(img,1),1+padUpX:padUpX+size(img,2),:) = img;

%check if the bounding box exceeds the image plane
%xUpLe=round(max(1,xUpLe));
%wid=round(min(size(img,2),xUpLe+wid-1)-xUpLe);
%yUpLe=round(max(1,yUpLe));
%hei=round(min(size(img,1),yUpLe+hei-1)-yUpLe);

bbox{cnt} = [xUpLe, yUpLe, wid, hei];
pad_train{cnt} = [padUpX, padUpY, padLoX, padLoY];

%crop the image
img_final{cnt}=imgPad(yUpLe:yUpLe+hei,xUpLe:xUpLe+wid,:);

%change the origin for the padded image
%(only rows with both coordinates > 0 are shifted, so rows left at zero
%stay at zero)
idx=(ptsAll{cnt}(:,1)>0 & ptsAll{cnt}(:,2)>0);
ptsAll{cnt}(idx,1)=ptsAll{cnt}(idx,1)+padUpX;
ptsAll{cnt}(idx,2)=ptsAll{cnt}(idx,2)+padUpY;
for jRest=1:size(ptsRest,2)
    if ~isempty(ptsRest{cnt,jRest})
        %the mask must come from this individual's own points; using the
        %active individual's mask shifts the wrong rows
        idxRest=(ptsRest{cnt,jRest}(:,1)>0 & ptsRest{cnt,jRest}(:,2)>0);
        ptsRest{cnt,jRest}(idxRest,1)=ptsRest{cnt,jRest}(idxRest,1)+padUpX;
        ptsRest{cnt,jRest}(idxRest,2)=ptsRest{cnt,jRest}(idxRest,2)+padUpY;
    end
end

%shift the origin for the active individual
%(rows that fall outside the crop are reset to zero in both coordinates)
ptsAll{cnt}(:,1)=ptsAll{cnt}(:,1)-(xUpLe-1);
ptsAll{cnt}(:,2)=ptsAll{cnt}(:,2)-(yUpLe-1);
checkX=double(ptsAll{cnt}(:,1)>0);
checkY=double(ptsAll{cnt}(:,2)>0);
ptsAll{cnt}(:,1:2)= ptsAll{cnt}(:,1:2).*[checkX checkX];
ptsAll{cnt}(:,1:2)= ptsAll{cnt}(:,1:2).*[checkY checkY];
checkX=double(ptsAll{cnt}(:,1)<=size(img_final{cnt},2));
checkY=double(ptsAll{cnt}(:,2)<=size(img_final{cnt},1)); 322 | ptsAll{cnt}(:,1:2)= ptsAll{cnt}(:,1:2).*[checkX checkX]; 323 | ptsAll{cnt}(:,1:2)= ptsAll{cnt}(:,1:2).*[checkY checkY]; 324 | 325 | %resize to standard size 326 | s_s = [size(img_final{cnt},1) size(img_final{cnt},2)]; 327 | s_t = [imgSize(1) imgSize(2)]; 328 | s = s_s.\s_t; 329 | tf = [ s(2) 0 0; 0 s(1) 0; 0 0 1]; 330 | T = affine2d(tf); 331 | 332 | %points scaled 333 | [ptsAll{cnt}(:,1),ptsAll{cnt}(:,2)] = transformPointsForward(T, ptsAll{cnt}(:,1),ptsAll{cnt}(:,2)); 334 | 335 | %shift the origin for the rest 336 | for jRest=1:size(ptsRest,2) 337 | if ~isempty(ptsRest{cnt,jRest}) 338 | ptsRest{cnt,jRest}(:,1)=ptsRest{cnt,jRest}(:,1)-(xUpLe-1); 339 | ptsRest{cnt,jRest}(:,2)=ptsRest{cnt,jRest}(:,2)-(yUpLe-1); 340 | checkX=double(ptsRest{cnt,jRest}(:,1)>0); 341 | checkY=double(ptsRest{cnt,jRest}(:,2)>0); 342 | ptsRest{cnt,jRest}(:,1:2)=ptsRest{cnt,jRest}(:,1:2).*[checkX checkX]; 343 | ptsRest{cnt,jRest}(:,1:2)=ptsRest{cnt,jRest}(:,1:2).*[checkY checkY]; 344 | checkX=double(ptsRest{cnt,jRest}(:,1)<=size(img_final{cnt},2)); 345 | checkY=double(ptsRest{cnt,jRest}(:,2)<=size(img_final{cnt},1)); 346 | ptsRest{cnt,jRest}(:,1:2)=ptsRest{cnt,jRest}(:,1:2).*[checkX checkX]; 347 | ptsRest{cnt,jRest}(:,1:2)=ptsRest{cnt,jRest}(:,1:2).*[checkY checkY]; 348 | 349 | %points scaled 350 | idx=ptsRest{cnt,jRest}(:,1)>0;%not need for checking ptsRest{cnt,jRest}(:,2)>0; 351 | [ptsRest{cnt,jRest}(idx,1),ptsRest{cnt,jRest}(idx,2)] ... 
352 | = transformPointsForward(T,ptsRest{cnt,jRest}(idx,1),ptsRest{cnt,jRest}(idx,2)); 353 | 354 | end 355 | end 356 | clear xUpLe yUpLe wid hei padUpX padUpY padLoX padLoY; 357 | 358 | %image resized 359 | img_final{cnt} = imresize(img_final{cnt}, 'scale', s, 'method', 'bilinear'); 360 | 361 | 362 | % %visualization 363 | % if valPerson~=0 364 | % imshow(img_final{cnt}); hold on; 365 | % x=size(img_final{cnt},2)/2; 366 | % y=size(img_final{cnt},1)/2; 367 | % text(x,y,'C','Color','m','FontSize',22); 368 | % poseGT=ptsAll{cnt}; 369 | % for jj=1:1:size(poseGT,1) %active indiv. 370 | % if poseGT(jj,3)==1 371 | % text(poseGT(jj,1),poseGT(jj,2),int2str(jj),'Color','m','FontSize',16); 372 | % else 373 | % text(poseGT(jj,1),poseGT(jj,2),int2str(jj),'Color','r','FontSize',16); 374 | % end 375 | % end 376 | % 377 | % for jRest=1:size(ptsRest,2) %rest indiv. 378 | % if ~isempty(ptsRest{cnt,jRest}) 379 | % poseGT=(ptsRest{cnt,jRest}); 380 | % for jj=1:1:size(poseGT,1) 381 | % if poseGT(jj,3)==1 382 | % text(poseGT(jj,1),poseGT(jj,2),int2str(jj),'Color','g','FontSize',16); 383 | % else 384 | % text(poseGT(jj,1),poseGT(jj,2),int2str(jj),'Color','r','FontSize',16); 385 | % end 386 | % end 387 | % 388 | % end 389 | % end 390 | % pause(); 391 | % hold off; 392 | % 393 | % %check mapping back 394 | % close all; 395 | % imshow(img); hold on; 396 | % %original GT points 397 | % poseGT = data.annolist(i).annorect(j).annopoints.point; 398 | % for p=1:numel(poseGT) 399 | % text(poseGT(p).x,poseGT(p).y,int2str(poseGT(p).id+1),'Color','c','FontSize',16); 400 | % end 401 | % 402 | % %transformed-points 403 | % poseGT=ptsAll{cnt}; 404 | % [poseGT(:,1),poseGT(:,2)] = transformPointsInverse(T, ptsAll{cnt}(:,1),ptsAll{cnt}(:,2)); 405 | % poseGT(:,1) = poseGT(:,1) + bbox{cnt}(1)-1; 406 | % poseGT(:,2) = poseGT(:,2) + bbox{cnt}(2)-1; 407 | % poseGT(:,1) = poseGT(:,1) - pad_train{cnt}(1); 408 | % poseGT(:,2) = poseGT(:,2) - pad_train{cnt}(2); 409 | % for jj=1:1:size(poseGT,1) %active indiv. 
410 | % if poseGT(jj,3)==1 411 | % text(poseGT(jj,1),poseGT(jj,2),int2str(jj),'Color','m','FontSize',16); 412 | % else 413 | % text(poseGT(jj,1),poseGT(jj,2),int2str(jj),'Color','r','FontSize',16); 414 | % end 415 | % end 416 | % 417 | % pause(); 418 | % hold off; 419 | % 420 | % end 421 | % %visualization 422 | 423 | end 424 | end 425 | end 426 | 427 | %disp(i); 428 | waitbar(i / length(data.img_train)); 429 | end 430 | close(h); 431 | 432 | 433 | %storefile=sprintf('extractedData_detector_%d_orcacle_%d',detector, oracleTr); 434 | storefile=sprintf('extractedData_%d_%d',imgSize(1),imgSize(2)); 435 | save(storefile,'img_final','ptsAll','ptsRest','sets_train','sets_train_idx','bbox','pad_train','-v7.3'); 436 | 437 | clear img_final bbox; 438 | imgPath = img_final_test; 439 | bbox = bbox_test; 440 | poseGT = []; 441 | storefile=sprintf('testMPI_%d_%d',imgSize(1),imgSize(2)); 442 | save(storefile,'imgPath','poseGT','bbox','testMap','pad_test','-v7.3'); -------------------------------------------------------------------------------- /model-train/cnn_regressor_get_batch.m: --------------------------------------------------------------------------------
function [imo, labels] = cnn_regressor_get_batch(imdb, batch, varargin)
%CNN_REGRESSOR_GET_BATCH  Build one training batch of images and keypoint labels.
%   [IMO, LABELS] = CNN_REGRESSOR_GET_BATCH(IMDB, BATCH, 'opt', val, ...)
%
%   Inputs
%     imdb  - struct; imdb.images.data is a cell array of images and
%             imdb.images.labels{k} is a (keypoints x 3 x instances) array
%             whose columns are read as [x, y, flag]. Slice 1 is the main
%             individual, slices 2:end are the remaining individuals.
%     batch - indices into imdb.images.data / imdb.images.labels.
%     varargin - name/value pairs overriding the defaults listed below
%             (parsed with vl_argparse).
%
%   Outputs
%     imo    - single array of size
%              imageSize(1) x imageSize(2) x 3 x (numel(batch)*numAugments).
%     labels - cell array with one column per batch element. Rows written here:
%                1  - cropped (and possibly flipped) keypoints of the main person
%                2  - keypoint heatmaps            (only if opts.heatmap)
%                3  - keypoint weighting mask      (only if opts.heatmap)
%                4  - number of individuals        (only if opts.heatmap)
%                8  - pairwise heatmaps            (only if opts.pairHeatmap)
%                9  - pairwise weighting mask      (only if opts.pairHeatmap)
%                12 - per-keypoint occlusion map   (only if opts.heatmap)
%
%   NOTE(review): outputs are indexed by the batch position i only, so
%   opts.numAugments > 1 overwrites the same slot on every augmentation.
%   NOTE(review): opts.pairHeatmap reuses poseMAP / poseRest computed in the
%   opts.heatmap branch, so it requires opts.heatmap to be enabled as well.

opts.imageSize = [120, 80] ;
opts.border = [10, 10] ;
opts.keepAspect = false ;
opts.numAugments = 1 ;
opts.transformation = 'f5' ;
opts.averageImage = [];
opts.rgbVariance = zeros(0,3,'single') ;
opts.interpolation = 'bilinear' ;
opts.numThreads = 15 ;
opts.prefetch = false ;
opts.heatmap=0;
opts.trf=[];
opts.sigma=[];%heatmap variance
opts.HeatMapSize=[];
opts.flipFlg='bbc';
opts.inOcclud=1; %include occluded points
opts.multipInst=1; %include multiple instances in the heatmaps
opts.GPU=0;
opts.HeatMapScheme=1; %how to generate heatmaps
opts.rotate=0;%rotation augmentation
opts.scale=0;%scale augmentation
opts.color=0;%color augmentation
opts.ignoreOcc=0;
opts.ignoreRest=0; %negate the heatmap magnitude of the other individuals
                   %(this field is read below and must exist in the defaults)
opts.bodyPairs = [];
opts.pairHeatmap=0;
opts.magnif=8;
opts.facX=0.15;%pairwise heatmap width
opts.facY=0.08;%pairwise heatmap height
opts = vl_argparse(opts, varargin);

im = imdb.images.data(batch) ;

%crop/flip candidates: rows are [vertical offset; horizontal offset; flip flag]
tfs = [] ;
switch opts.transformation
    case 'none'
        tfs = [
            .5 ;
            .5 ;
            0 ] ;
    case 'flipOnly'
        tfs = [
            .5 .5;
            .5 .5;
            0 1] ;
    case 'f5'
        tfs = [...
            .5 0 0 1 1 .5 0 0 1 1 ;
            .5 0 1 0 1 .5 0 1 0 1 ;
            0 0 0 0 0 1 1 1 1 1] ;
    case 'f25'
        [tx,ty] = meshgrid(linspace(0,1,5)) ;
        tfs = [tx(:)' ; ty(:)' ; zeros(1,numel(tx))] ;
        tfs_ = tfs ;
        tfs_(3,:) = 1 ;
        tfs = [tfs,tfs_] ;
    case 'stretch'
        %crop parameters are drawn per sample inside the loop
    otherwise
        error('Uknown transformations %s', opts.transformation) ;
end
%random permutation of the candidates for every batch element
[~,transformations] = sort(rand(size(tfs,2), numel(batch)), 1) ;

if ~isempty(opts.rgbVariance) && isempty(opts.averageImage)
    opts.averageImage = zeros(1,1,3) ;
end
if numel(opts.averageImage) == 3
    opts.averageImage = reshape(opts.averageImage, 1,1,3) ;
end

imo = zeros(opts.imageSize(1), opts.imageSize(2), 3, ...
    numel(batch)*opts.numAugments, 'single') ;


%store the GT infromation for the error estimation
if opts.heatmap
    %1 - keypoints, 2 - heatmap, 3 - weight mask, 4 - number of instances,
    %5 - body heatmap, 6 - body weight mask, 7 - segmentation mask, 8 -
    %pairwise heatmap, 9 - pairwise weight mask, 12 - occlusion map

    labels=cell(19,numel(batch));
else
    labels=cell(1,numel(batch));
end


for i=1:numel(batch)
    fr=batch(i);

    % acquire image
    img = single(im{i}) ;

    %ground-truth
    poseGT=imdb.images.labels{fr}(:,:,1);
    poseRest=imdb.images.labels{fr}(:,:,2:end);
    if opts.multipInst==0 %exclude multiple instances
        poseRest=[];
    end

    %occlusion map of the main individual: always one value per keypoint
    if opts.ignoreOcc
        poseOccl = double(poseGT(:,3) ==1) - double(poseGT(:,3) ==0); %visible 1, invisible -1
    else
        poseOccl = ones(size(poseGT,1),1);
    end

    if opts.inOcclud==0 %exclude occluded keypoints (zero their coordinates)
        poseGT(:,1:2)=bsxfun(@times, poseGT(:,1:2), poseGT(:,3));
    end
    poseGT=poseGT(:,1:2);

    if sum(size(poseRest))>0 %rest keypoints
        if opts.ignoreOcc
            poseRestOccl=double(poseRest(:,3,:) ==1) - double(poseRest(:,3,:) ==0); %visible 1, invisible -1
        else
            poseRestOccl = ones(size(poseRest,1),size(poseRest,3));
        end
        poseRestOccl=squeeze(poseRestOccl);

        if opts.inOcclud==0
            poseRest(:,1:2,:)=bsxfun(@times, poseRest(:,1:2,:), poseRest(:,3,:));
        end
        poseRest=poseRest(:,1:2,:);
    end

    %ensure correct values for the main keypoint
    idx=poseGT(:,1)>0; %zeros in x, means zeros in y as well
    poseGT(idx,1) = max(1,poseGT(idx,1));
    poseGT(idx,1) = min(size(img,2),poseGT(idx,1));
    poseGT(idx,2) = max(1,poseGT(idx,2));
    poseGT(idx,2) = min(size(img,1),poseGT(idx,2));
    clear idx;

    %ensure correct values for the rest
    if sum(size(poseRest))>0
        for k=1:size(poseRest,3)
            idx=poseRest(:,1,k)>0; %zeros in x, means zeros in y as well
            poseRest(idx,1,k) = max(1,poseRest(idx,1,k));
            poseRest(idx,1,k) = min(size(img,2),poseRest(idx,1,k));
            poseRest(idx,2,k) = max(1,poseRest(idx,2,k));
            poseRest(idx,2,k) = min(size(img,1),poseRest(idx,2,k));
            clear idx;
        end
    end

    %Data
    imt=img; clear img;
    tempY=poseGT;
    tempRest=poseRest;

    % start - real part

    insta=1; %minimum number of individuals

    %color augmentation: independent gain in [0.9, 1.1] per channel
    if opts.color && rand(1)>0.5
        imt(:,:,1) = imt(:,:,1)*(0.9 + rand(1)*(1.1-0.9));
        imt(:,:,2) = imt(:,:,2)*(0.9 + rand(1)*(1.1-0.9));
        imt(:,:,3) = imt(:,:,3)*(0.9 + rand(1)*(1.1-0.9));
        imt = round(imt);
        %NOTE(review): clamps to [1,256]; [0,255] may have been intended - confirm
        imt(imt>256)=256;
        imt(imt<1)=1;
    end

    %rotate augmentation
    if opts.rotate && rand(1)>0.5
        [imt,tempY,tempRest] = applyGeomAugment(@imRotateGetBatch,imt,tempY,tempRest);
    end

    %scale augmentation
    if opts.scale && rand(1)>0.5
        [imt,tempY,tempRest] = applyGeomAugment(@imScaleGetBatch,imt,tempY,tempRest);
    end

    % crop & flip
    w = size(imt,2) ;
    h = size(imt,1) ;
    for ai = 1:opts.numAugments
        switch opts.transformation
            case 'stretch'
                sz = round(min(opts.imageSize(1:2)' .* (1-0.1+0.2*rand(2,1)), [w;h])) ;
                dx = randi(w - sz(2) + 1, 1) ;
                dy = randi(h - sz(1) + 1, 1) ;
                flip = rand > 0.5 ;
            otherwise
                %tf = tfs(:, transformations(mod(ai-1, numel(transformations)) + 1)) ;
                tf = tfs(:, transformations(1,i)) ;
                %tf=[0,0,0]';%debug
                sz = opts.imageSize(1:2) ;
                dx = floor((w - sz(2)) * tf(2)) + 1 ;
                dy = floor((h - sz(1)) * tf(1)) + 1 ;
                flip = tf(3) ;
        end

        %exclude missing annotation
        idx=tempY(:,1)>0 & tempY(:,2)>0; %zeros in x, means zeros in y as well

        %check if all keypoints are within image frame; if some fall before
        %the crop origin, move the origin back
        checkY=[tempY(idx,1)-dx+1, tempY(idx,2)-dy+1];
        if sum(checkY(:)<0) ~=0
            ofsX=min(0,min(checkY(:,1)));
            ofsY=min(0,min(checkY(:,2)));
            ofsX=ofsX -(ofsX<0); %origin 1,1
            ofsY=ofsY -(ofsY<0); %origin 1,1
            dx = floor(dx+ofsX);
            dy = floor(dy+ofsY);
        end

        %updated dx,dy: if some keypoints fall beyond the crop, move it forward
        checkY=[tempY(idx,1)-dx+1, tempY(idx,2)-dy+1];
        if sum(checkY(:,2)> (sz(1) + dy)) ~=0 || sum(checkY(:,1)> (sz(2) + dx)) ~=0
            ofsX=max(sz(2)+dx-1,max(checkY(:,1))) - (sz(2)+dx-1);
            ofsY=max(sz(1)+dy-1,max(checkY(:,2))) - (sz(1)+dy-1);
            dx = floor(dx+ofsX);
            dy = floor(dy+ofsY);
        end
        clear checkY;

        %crop keypoints
        tempY(idx,1)=tempY(idx,1)-dx+1;
        tempY(idx,2)=tempY(idx,2)-dy+1;

        clear idx;

        %crop images points
        sx = round(linspace(dx, sz(2)+dx-1, opts.imageSize(2))) ;
        sy = round(linspace(dy, sz(1)+dy-1, opts.imageSize(1))) ;

        if flip
            sx = fliplr(sx) ;
            tempY = flipKeyPointsCoords(opts.imageSize(2),tempY,opts.flipFlg); %flip keypoints
        end

        %crop extra points
        if sum(size(tempRest))>0
            insta=insta+size(tempRest,3); %number of individuals

            for j=1:size(tempRest,3)

                %exclude missing annotation
                idx= tempRest(:,1,j)>0 & tempRest(:,2,j)>0;

                tempRest(idx,1,j) = tempRest(idx,1,j) - dx +1;
                tempRest(idx,2,j) = tempRest(idx,2,j) - dy +1;

                clear idx;

                if flip
                    tempRest(:,:,j) = flipKeyPointsCoords(opts.imageSize(2), tempRest(:,:,j),opts.flipFlg);
                end

                %exclude the out-of-plane points
                idx= (tempRest(:,1,j) > numel(sx)) | (tempRest(:,1,j) < 1); %X coord.
                tempRest(idx,:,j) = 0; %both x and y
                idx= (tempRest(:,2,j) > numel(sy)) | (tempRest(:,2,j) < 1); %Y coord.
                tempRest(idx,:,j) = 0; %both x and y

                clear idx;
            end
        end

        %generate heatmap and segmentation map
        if opts.heatmap

            %heatmap size (defined based on the output of the network)
            heatmap=zeros(opts.HeatMapSize(1),opts.HeatMapSize(2),size(tempY,1));
            heatmap_mask=zeros(opts.HeatMapSize(1),opts.HeatMapSize(2),size(tempY,1));

            %rest of keypoints - heatmaps
            if sum(size(tempRest))>0
                %from here on poseRest holds the rest keypoints in heatmap space
                poseRest = zeros(size(tempRest));
                for k=1:size(tempRest,3)

                    %transform the keypoints to the output heatmap space
                    restY= (opts.trf*[tempRest(:,:,k) ones(size(tempRest(:,:,k),1),1)]')';
                    poseRest(:,:,k)=restY(:,1:2);

                    %add first the rest keypoints and then the main
                    for j=1:size(poseRest,1)
                        %fix rounding problems
                        if poseRest(j,1,k)>0 && poseRest(j,2,k)>0 %missing keypoints
                            x=min(max(1,round(poseRest(j,1,k))),size(heatmap,2));
                            y=min(max(1,round(poseRest(j,2,k))),size(heatmap,1));

                            topts=opts;
                            if topts.ignoreOcc %negative values in order to ignore at the loss layer
                                topts.magnif=topts.magnif.*poseRestOccl(j,k); %occlusion map
                            end
                            if topts.ignoreRest && topts.magnif>0 %second constraint because of the above
                                topts.magnif=topts.magnif.*(-1); %single indiv. ignore rest
                            end
                            topts.facX=1;
                            topts.facY=1;
                            heatmap(:,:,j) = generateGHeatmap(heatmap(:,:,j),[x,y],180,topts.sigma,topts);
                        end
                    end
                end
            end

            %main keypoints - heatmaps

            %transform the keypoints to the output heatmap space
            poseMAP = (opts.trf*[tempY ones(size(tempY,1),1)]')';
            poseMAP = poseMAP(:,1:2);

            for j=1:size(poseMAP,1)

                %fix rounding problems
                if poseMAP(j,1)>0 && poseMAP(j,2)>0
                    x=min(max(1,round(poseMAP(j,1))),size(heatmap,2));
                    y=min(max(1,round(poseMAP(j,2))),size(heatmap,1));

                    topts=opts;
                    if topts.ignoreOcc %negative values in order to ignore
                        topts.magnif=topts.magnif.*poseOccl(j); %occlusion map
                    end
                    topts.facX=1;
                    topts.facY=1;
                    heatmap(:,:,j) = generateGHeatmap(heatmap(:,:,j),[x,y],180,topts.sigma,topts);
                end

                %generate weights for balancing positive / negative cells
                heatmap_mask = getWeightMask(opts,insta,0,heatmap,j,heatmap_mask);

                %visualization
                %mapVisualize(opts, imt, sy, sx, heatmap, j, poseMAP, tempY, poseRest);
            end

            labels{2,i} = heatmap; %body part heatmap
            labels{3,i} = heatmap_mask; %body part weighting mask
            labels{4,i} = insta;%number of instances
            labels{12,i} = poseOccl;%occlusion binary map
        end

        if opts.pairHeatmap
            bodyPairs = opts.bodyPairs;

            %pair heatmaps
            pair_heatmap=zeros(opts.HeatMapSize(1),opts.HeatMapSize(2),size(opts.bodyPairs,2));
            pair_heatmap_mask=zeros(opts.HeatMapSize(1),opts.HeatMapSize(2),size(opts.bodyPairs,2));

            for j=1:size(bodyPairs,2)
                part_A = bodyPairs(1,j);
                part_B = bodyPairs(2,j);

                %go through the rest keypoints first
                if sum(size(tempRest))>0
                    for k=1:size(tempRest,3)
                        if opts.ignoreOcc %in case of occluded keypoint, remove the part
                            if poseRestOccl(part_A,k)==-1
                                poseRest(part_A,:,k)=0;
                            end
                            if poseRestOccl(part_B,k)==-1
                                poseRest(part_B,:,k)=0;
                            end
                        end

                        if sum(poseRest(part_A,:,k))>0 && sum(poseRest(part_B,:,k))>0 %both keypoints available

                            %get mu, sigma and theta for generating the heatmaps
                            [part_center, theta, len] = getHeatMapParams(poseRest(part_A,:,k),poseRest(part_B,:,k));

                            %in case of single individual training, ignore the rest parts
                            if opts.ignoreRest
                                opts.magnif=-opts.magnif;
                            end

                            %generate the heatmap
                            pair_heatmap(:,:,j) = generateGHeatmap(pair_heatmap(:,:,j),part_center,theta,len,opts);

                            %restore the magnif value to be positive
                            if opts.ignoreRest
                                opts.magnif=-opts.magnif;
                            end
                        end
                    end
                end

                if opts.ignoreOcc %in case of occluded keypoint, remove the part
                    if poseOccl(part_A)==-1
                        poseMAP(part_A,:)=0;
                    end
                    if poseOccl(part_B)==-1
                        poseMAP(part_B,:)=0;
                    end
                end

                %go through the main keypoints
                if sum(poseMAP(part_A,:))>0 && sum(poseMAP(part_B,:))>0
                    %get mu, sigma and theta for generating the heatmaps
                    [part_center, theta, len] = getHeatMapParams(poseMAP(part_A,:),poseMAP(part_B,:));

                    %generate the heatmap
                    pair_heatmap(:,:,j) = generateGHeatmap(pair_heatmap(:,:,j),part_center,theta,len,opts);
                end

                %generate weights for balancing positive / negative cells
                pair_heatmap_mask = getWeightMask(opts,insta,0,pair_heatmap,j,pair_heatmap_mask);

                %visualization
                %mapVisualize(opts, imt, sy, sx, pair_heatmap, j, poseMAP, tempY, poseRest);

            end

            labels{8,i} = pair_heatmap; %pairwise heatmap
            labels{9,i} = pair_heatmap_mask; %pairwise weighting mask
        end

        %0-1 output (keep to original coords for the heatmaps)
        %tempY = treeCoords(tempY,[],imdb.patchHei,imdb.patchWi,1);

        %store output
        %labY = reshape (tempY',size(tempY,1)*2,1);%do not reshape
        labels{1,i}=tempY;

        % %plot the points
        % imshow(uint8(imt(sy,sx,:))); hold on;
        % %tempY = treeCoords(tempY,[],imdb.patchHei,imdb.patchWi,0);
        % for po=1:size(tempY,1)
        %     text(tempY(po,1),tempY(po,2), int2str(po),'Color','m','FontSize',15);
        % end
        % if sum(size(tempRest))>0
        %     for k=1:size(tempRest,3)
        %         tempY = tempRest(:,:,k);
        %         for po=1:size(tempY,1)
        %             text(tempY(po,1),tempY(po,2), int2str(po),'Color','g','FontSize',15);
        %         end
        %     end
        % end
        % hold off; pause();
        % %plot the points

        if ~isempty(opts.averageImage)
            offset = opts.averageImage ;
            if ~isempty(opts.rgbVariance)
                offset = bsxfun(@plus, offset, reshape(opts.rgbVariance * randn(3,1), 1,1,3)) ;
            end
            imo(:,:,:,i) = bsxfun(@minus, imt(sy,sx,:), offset) ;
            imo(:,:,:,i) = imo(:,:,:,i)./256;
        else
            imo(:,:,:,i) = imt(sy,sx,:) ;
        end
    end
end

clear im;


end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [imt, tempY, tempRest] = applyGeomAugment(augFn, imt, tempY, tempRest)
%APPLYGEOMAUGMENT  Run a geometric augmentation on an image and all keypoints.
%   AUGFN is called as [imt, pts] = augFn(imt, pts), where pts stacks the main
%   keypoints followed by every rest individual (one block of size(tempY,1)
%   rows each). The transformed rows are written back to tempY / tempRest.

nKp = size(tempY,1);
pts = [tempY;reshape(permute(tempRest,[1,3,2]),[],2)];
[imt,pts] = augFn(imt,pts);

tempY = pts(1:nKp,:);
pts = pts(nKp+1:end,:);
cc = 1;
while size(pts,1)>0
    tempRest(:,:,cc) = pts(1:nKp,:);
    pts = pts(nKp+1:end,:);
    cc = cc+1;
end

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function heatmap_mask = getWeightMask(opts,insta,thrs,heatmap,j,heatmap_mask)
%GETWEIGHTMASK  Fill channel j of the weighting mask for heatmap channel j.
%   Cells above THRS are treated as foreground. With opts.HeatMapScheme==1
%   foreground cells get weight 1-(#foreground/#cells) and background cells get
%   1-(#background/#cells), so the rarer class weighs more. If foreground cells
%   outnumber background cells the whole channel is set to 1. A channel with no
%   positive cell is left untouched. INSTA is accepted but not used here.

if sum(sum(heatmap(:,:,j)>0))>0 %if everything background, do nothing
    %integral-idea
    heat_pixels=double((heatmap(:,:,j))>thrs); %remove abs for occlusion
    integ_a=sum(heat_pixels(:));
    integ_b=numel(heat_pixels(:,:))-integ_a;
    heatmap_wa= 1- (integ_a/numel(heatmap(:,:,j)));
    heatmap_wb= 1- (integ_b/numel(heatmap(:,:,j)));

    if opts.HeatMapScheme==0
        %removed
    elseif opts.HeatMapScheme==1
        heat_pixels(heat_pixels<1) =heatmap_wb;%first this - order important
        heat_pixels(heat_pixels>=1)=heatmap_wa;

        %if more foreground pixels, do not weight
        if integ_a>integ_b
            heat_pixels=1;
        end

        heatmap_mask(:,:,j) = heat_pixels;
    end
end

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [part_center, theta, len] = getHeatMapParams(keypA,keypB)
%GETHEATMAPPARAMS  Gaussian parameters for the part joining two keypoints.
%   keypA, keypB - [x, y] coordinates of the two keypoints.
%   part_center  - 1x2 midpoint [x, y] between the keypoints.
%   theta        - orientation in degrees, atand(dy/dx), i.e. within [-90, 90].
%                  Returns 0 when the keypoints coincide (the plain quotient
%                  would be 0/0 and yield NaN).
%   len          - Euclidean distance between the keypoints.

dx = keypB(1) - keypA(1);
dy = keypB(2) - keypA(2);

%center of the Gaussian
part_center(1,1) = keypA(1) + dx./2;
part_center(1,2) = keypA(2) + dy./2;

%rotation of the Gaussian
if dx==0 && dy==0
    theta = 0; %degenerate part: orientation undefined, avoid NaN
else
    theta = atand(dy / dx);
end

%sigma of the gaussian based on the part length
len = norm(keypB(:) - keypA(:));

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function mapVisualize(opts, imt, sy, sx, heatmap, j, poseMAP, tempY, poseRest)
%MAPVISUALIZE  Debug display of heatmap channel j over the cropped image.
%   Shows the cropped image imt(sy,sx,:) with keypoint j marked, the image
%   warped with opts.trf and tinted by the heatmap together with the main and
%   rest keypoints, and the raw heatmap. Waits for a key press, then closes
%   all figures.

disp('j is');
disp(j);

%visualization
tform = affine2d(opts.trf');
I=imwarp(uint8(imt(sy,sx,:)),tform);

%pad the warped image up to the heatmap size
padFact=[0,0];
if size(I,1)~=size(heatmap(:,:,j),1)
    padFact(1)=size(heatmap(:,:,j),1)-size(I,1);
end
if size(I,2)~=size(heatmap(:,:,j),2)
    padFact(2)=size(heatmap(:,:,j),2)-size(I,2);
end

%original image
imshow(uint8(imt(sy,sx,:))); hold on;
plot(tempY(j,1),tempY(j,2),'rx');hold off;

figure;
I = padarray(I,padFact,150,'pre');
I(:,:,2)=I(:,:,2)+uint8( double(rgb2gray(I)).*50.*(heatmap(:,:,j)-0.0) );
I(:,:,3)=I(:,:,3)+uint8( double(rgb2gray(I)).*10.*(heatmap(:,:,j)-0.0) );
imshow(I); hold on; plot(poseMAP(j,1),poseMAP(j,2),'rx');

if sum(size(poseRest))>0
    hold on;
    for k=1:size(poseRest,3)
        plot(poseRest(j,1,k),poseRest(j,2,k),'gx');
    end
    hold off;
end
figure; imagesc(heatmap(:,:,j));

pause(); close all;
%visualization

end

--------------------------------------------------------------------------------