├── CMakeLists.txt
├── README.md
├── img
│   ├── final.jpg
│   ├── result.jpg
│   ├── runningman.jpg
│   ├── runningman2.jpg
│   └── runningman_cropped.jpg
├── model_trace.py
└── prediction.cpp

/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 2.8)
project(predict_demo)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O3")

find_package(OpenCV REQUIRED)
find_package(Torch REQUIRED)

include_directories(${OpenCV_INCLUDE_DIRS})
add_executable(SPPE prediction.cpp)
target_link_libraries(SPPE ${OpenCV_LIBS} ${TORCH_LIBRARIES})

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Update (19/2/2020):

If you are still using this code, be aware of two problems:

1. The model input should be 256x320 (WxH), not 224x224; the smaller input costs accuracy. I have fixed this in my new repo (MPPE).
2. Extracting the keypoints with nested loops that call `tensor.item()` on every element wastes a huge amount of time. Consider the following instead (also used in MPPE):

    auto ft = output_tensor.flatten(2,3);   // flatten each 80x64 heatmap into a vector
    auto maxresult = at::max(ft, 2);        // max over the flattened dimension
    auto maxid = std::get<1>(maxresult);    // tensor of argmax indices

    for (int kpts = 0; kpts < 17; kpts++) {
        int i = (int)(maxid[personid][kpts].item().toFloat());
        max_x = (i % 64) + 1;               // column in the 64-wide heatmap
        max_y = (i / 64) + 1;               // row in the 80-tall heatmap
        coor[kpts][0] = max_x;
        coor[kpts][1] = max_y;
    }
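If you prefer to keep the explicit loops, a CPU `accessor` also avoids the per-element `.item()` calls. A minimal sketch (not taken from MPPE), assuming `output_tensor` has shape 1x17x80x64 and has already been moved to the CPU:

    // Typed view into the tensor's storage; no per-element tensor creation.
    auto heatmaps = output_tensor.accessor<float, 4>();
    int coor[17][2];
    for (int kpts = 0; kpts < 17; kpts++) {
        float best = heatmaps[0][kpts][0][0];
        int best_x = 0, best_y = 0;
        for (int i = 0; i < 80; i++) {
            for (int j = 0; j < 64; j++) {
                if (heatmaps[0][kpts][i][j] > best) {
                    best = heatmaps[0][kpts][i][j];
                    best_x = j;
                    best_y = i;
                }
            }
        }
        coor[kpts][0] = best_x;
        coor[kpts][1] = best_y;
    }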
This demo shows how to build a single-person pose estimation program using libtorch.

The model is trained with PyTorch (AlphaPose's SPPE model); check their GitHub for how to train it.

## Contents

1. [Requirements](#requirements)
2. [Build](#build)
3. [Usage](#usage)


## Requirements

- PyTorch
- Libtorch
- OpenCV

## Build

### Step 1

Download the model via this link:

https://drive.google.com/file/d/1xEQnogxHAkurNebHGatHzZvkN7N8khtt/view?usp=sharing

Put the model "duc_se.pth" into the directory:

"models/sppe"


### Step 2

Take a look at ``prediction.cpp`` to see how the estimation is done.

- run ``model_trace.py``; it produces the file ``posemodel.pt``
- compile the C++ program with ``-DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch``, for example:

```
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/home/luisrodman/libtorch ..
make
```

- test the program:

``SPPE <path-to-posemodel.pt>``

```
== Switch to GPU mode
== PoseModel loaded!
== Input image path: [enter Q to exit]
../img/runningman_cropped.jpg
== image size: [263 x 374] ==
== simply resize: [224 x 224] ==
Keypoint : 0
x :12 y : 7
Probability :12
Keypoint : 1
x :13 y : 6
Probability :9
Keypoint : 2
x :30 y : 11
Probability :1
Keypoint : 3
x :17 y : 5
Probability :8
Keypoint : 4
x :30 y : 13
Probability :4
Keypoint : 5
x :24 y : 9
Probability :9
Keypoint : 6
x :37 y : 39
Probability :18
Keypoint : 7
x :33 y : 13
Probability :21
Keypoint : 8
x :40 y : 39
Probability :13
Keypoint : 9
x :7 y : 20
Probability :11
Keypoint : 10
x :7 y : 20
Probability :17
Keypoint : 11
x :22 y : 29
Probability :9
Keypoint : 12
x :29 y : 26
Probability :9
Keypoint : 13
x :8 y : 35
Probability :28
Keypoint : 14
x :36 y : 38
Probability :36
Keypoint : 15
x :12 y : 48
Probability :6
Keypoint : 16
x :50 y : 46
Probability :20
Visualizing Result ...
```

(The log above was captured with the old 224x224 input; the current code resizes to 256x320 and produces 80x64 heatmaps.)

![](./img/runningman_cropped.jpg)

![](./img/result.jpg)

![](./img/final.jpg)

--------------------------------------------------------------------------------
/img/final.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/winggo12/Libtorch-SPPE/34c797987a7d5321fe9dba1235d5fae21d092197/img/final.jpg
--------------------------------------------------------------------------------
/img/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/winggo12/Libtorch-SPPE/34c797987a7d5321fe9dba1235d5fae21d092197/img/result.jpg
--------------------------------------------------------------------------------
/img/runningman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/winggo12/Libtorch-SPPE/34c797987a7d5321fe9dba1235d5fae21d092197/img/runningman.jpg
--------------------------------------------------------------------------------
/img/runningman2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/winggo12/Libtorch-SPPE/34c797987a7d5321fe9dba1235d5fae21d092197/img/runningman2.jpg
--------------------------------------------------------------------------------
/img/runningman_cropped.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/winggo12/Libtorch-SPPE/34c797987a7d5321fe9dba1235d5fae21d092197/img/runningman_cropped.jpg
--------------------------------------------------------------------------------
/model_trace.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class InferenNet(nn.Module):
    def __init__(self):
        super(InferenNet, self).__init__()
        model = createModel()
        model.load_state_dict(torch.load('./models/sppe/duc_se.pth'))
        model.eval()
        self.pyranet = model

    def forward(self, x):
        out = self.pyranet(x)
        # keep only the first 17 channels: the COCO keypoint heatmaps
        out = out.narrow(1, 0, 17)
        return out

#_____________________Resnet Layer_____________________#

class SELayer(nn.Module):
    """Squeeze-and-excitation channel attention."""

    def __init__(self, channel, reduction=1):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        if reduction:
            self.se = SELayer(planes * 4)

        self.reduc = reduction
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.reduc:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = F.relu(out)

        return out


class SEResnet(nn.Module):
    """ SEResnet """

    def __init__(self, architecture):
        super(SEResnet, self).__init__()
        assert architecture in ["resnet50", "resnet101"]
        self.inplanes = 64
        self.layers = [3, 4, {"resnet50": 6, "resnet101": 23}[architecture], 3]
        self.block = Bottleneck

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-5, momentum=0.01, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.make_layer(self.block, 64, self.layers[0])
        self.layer2 = self.make_layer(
            self.block, 128, self.layers[1], stride=2)
        self.layer3 = self.make_layer(
            self.block, 256, self.layers[2], stride=2)

        self.layer4 = self.make_layer(
            self.block, 512, self.layers[3], stride=2)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))  # 64 * h/4 * w/4
        x = self.layer1(x)  # 256 * h/4 * w/4
        x = self.layer2(x)  # 512 * h/8 * w/8
        x = self.layer3(x)  # 1024 * h/16 * w/16
        x = self.layer4(x)  # 2048 * h/32 * w/32
        return x

    def stages(self):
        return [self.layer1, self.layer2, self.layer3, self.layer4]

    def make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        if downsample is not None:
            layers.append(block(self.inplanes, planes, stride, downsample, reduction=True))
        else:
            layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

#_____________________DUC Layer_____________________#

class DUC(nn.Module):
    '''
    Dense Upsampling Convolution.
    INPUT: inplanes * ht * wd
    OUTPUT: (planes // upscale_factor**2) * (ht * upscale_factor) * (wd * upscale_factor)
    '''
    def __init__(self, inplanes, planes, upscale_factor=2):
        super(DUC, self).__init__()
        self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pixel_shuffle(x)
        return x

#_____________________Create the whole network_____________________#

def createModel():
    return FastPose()


class FastPose(nn.Module):
    DIM = 128

    def __init__(self):
        super(FastPose, self).__init__()

        self.preact = SEResnet('resnet101')

        self.suffle1 = nn.PixelShuffle(2)
        self.duc1 = DUC(512, 1024, upscale_factor=2)
        self.duc2 = DUC(256, 512, upscale_factor=2)

        self.conv_out = nn.Conv2d(
            self.DIM, 33, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Variable):
        out = self.preact(x)
        out = self.suffle1(out)
        out = self.duc1(out)
        out = self.duc2(out)
        out = self.conv_out(out)
        return out


model = InferenNet()
# Trace with a 320x256 (HxW) dummy input to match prediction.cpp; the original
# traced at 224x224, which the 19/2/2020 README note flags as an accuracy loss.
example = torch.rand(1, 3, 320, 256)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("posemodel.pt")
--------------------------------------------------------------------------------
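Once ``model_trace.py`` has written ``posemodel.pt``, the trace can be sanity-checked from C++ before building the full demo below. A minimal sketch, assuming the 320x256 (HxW) input size discussed in the README note:

    #include <torch/script.h>
    #include <iostream>

    int main() {
      // Load the module saved by model_trace.py.
      torch::jit::script::Module module = torch::jit::load("posemodel.pt");

      // Dummy NCHW input matching the traced shape.
      auto input = torch::rand({1, 3, 320, 256});
      auto out = module.forward({input}).toTensor();

      // Expect 17 heatmaps at 1/4 of the input resolution: [1, 17, 80, 64].
      std::cout << out.sizes() << std::endl;
      return 0;
    }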
/prediction.cpp:
--------------------------------------------------------------------------------
// One-stop header.
#include <torch/script.h>

// headers for opencv
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#define kIMAGE_SIZE_W 256
#define kIMAGE_SIZE_H 320
#define kCHANNELS 3

bool LoadImage(std::string file_name, cv::Mat &image, cv::Mat &cloneimage) {
  image = cv::imread(file_name);  // CV_8UC3
  if (image.empty() || !image.data) {
    return false;
  }
  cloneimage = image.clone();
  std::cout << "Image (W,H) : " << image.size().width << " , "
            << image.size().height << std::endl;
  // (Reconstructed: the middle of this function was lost in the dump. The rest
  //  of the file implies the network expects a 256x320 RGB float image in [0,1].)
  cv::cvtColor(image, image, cv::COLOR_BGR2RGB);  // OpenCV loads BGR; the model expects RGB
  cv::resize(image, image, cv::Size(kIMAGE_SIZE_W, kIMAGE_SIZE_H));
  image.convertTo(image, CV_32FC3, 1.0f / 255.0f);  // rescale pixels to [0, 1]
  return true;
}

int main(int argc, const char *argv[]) {
  if (argc != 2) {
    std::cerr << "Usage: SPPE <path-to-exported-script-module>"
              << std::endl;
    return -1;
  }

  // Debug log of raw tensors (the original log file name was lost in the dump).
  std::ofstream myfile("log.txt");

  torch::jit::script::Module module = torch::jit::load(argv[1]);
  std::cout << "== Switch to GPU mode" << std::endl;
  // to GPU
  module.to(at::kCUDA);

  std::cout << "== PoseModel loaded!\n";

  std::string file_name = "";
  cv::Mat image;
  cv::Mat copied_image;
  while (true) {
    std::cout << "== Input image path: [enter Q to exit]" << std::endl;
    std::cin >> file_name;
    if (file_name == "Q") {
      break;
    }
    if (LoadImage(file_name, image, copied_image)) {
      myfile << "Image Data Size " << std::endl;
      myfile << image << std::endl;
      auto input_tensor = torch::from_blob(
          image.data, {1, kIMAGE_SIZE_H, kIMAGE_SIZE_W, kCHANNELS});
      myfile << "Input Tensor Before Normalization " << std::endl;
      myfile << input_tensor << std::endl;

      // NHWC -> NCHW first, so that the per-channel ImageNet normalization
      // below really addresses channels (indexing [0][c] before the permute
      // would slice image rows instead of channels).
      input_tensor = input_tensor.permute({0, 3, 1, 2});

      input_tensor[0][0] = input_tensor[0][0].sub_(0.485).div_(0.229);
      input_tensor[0][1] = input_tensor[0][1].sub_(0.456).div_(0.224);
      input_tensor[0][2] = input_tensor[0][2].sub_(0.406).div_(0.225);

      // to GPU
      input_tensor = input_tensor.to(at::kCUDA);

      auto output = module.forward({input_tensor});  // type : [ Variable[CUDAFloatType]{1,17,80,64} ]

      torch::Tensor out_tensor = output.toTensor();

      out_tensor = out_tensor.to(at::kCPU);

      //-------------------Finding the keypoint with the highest probability -------------------------------------------//

      int coor[17][3];
      int max_x = 0;
      int max_y = 0;
      int prob = 0;

      float max = out_tensor[0][0][0][0].item().toFloat();

      myfile << "Input_tensor: " << std::endl;
      myfile << input_tensor << std::endl;
      myfile << "Output_tensor: " << std::endl;
      myfile << out_tensor << std::endl;
      myfile.close();

      // Exhaustive argmax over each 80x64 heatmap. Calling .item() per element
      // is very slow; see the README note for a faster flatten + at::max version.
      for (int kpts = 0; kpts < 17; kpts++) {
        max = out_tensor[0][kpts][0][0].item().toFloat();
        for (int i = 0; i < 80; i++) {
          for (int j = 0; j < 64; j++) {
            if (out_tensor[0][kpts][i][j].item().toFloat() > max) {
              max = out_tensor[0][kpts][i][j].item().toFloat();
              max_x = j;
              max_y = i;
              prob = (int)(max * 100);
            }
          }
        }
        coor[kpts][0] = max_x;
        coor[kpts][1] = max_y;
        coor[kpts][2] = prob;
      }

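      // Note (added; not in the original file): coor[][] holds indices on the
      // 64 (x) by 80 (y) heatmap grid, not pixels. To overlay a keypoint on the
      // original image, rescale first, e.g.:
      //   float sx = (float)copied_image.size().width  / 64.0f;
      //   float sy = (float)copied_image.size().height / 80.0f;
      //   cv::Point p_img((int)(coor[kpts][0] * sx), (int)(coor[kpts][1] * sy));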
<< "Keypoint : " << kpts << std::endl; 154 | std::cout << "w : " << copied_image.size().width << " h : " << copied_image.size().height <