├── .gitignore
├── README.md
├── coco_utils.py
├── cocoapi
│   ├── LuaAPI
│   │   ├── CocoApi.lua
│   │   ├── MaskApi.lua
│   │   ├── cocoDemo.lua
│   │   ├── env.lua
│   │   ├── init.lua
│   │   └── rocks
│   │       └── coco-scm-1.rockspec
│   ├── MatlabAPI
│   │   ├── CocoApi.m
│   │   ├── CocoEval.m
│   │   ├── CocoUtils.m
│   │   ├── MaskApi.m
│   │   ├── cocoDemo.m
│   │   ├── evalDemo.m
│   │   ├── gason.m
│   │   └── private
│   │       ├── gasonMex.cpp
│   │       ├── gasonMex.mexa64
│   │       ├── gasonMex.mexmaci64
│   │       ├── getPrmDflt.m
│   │       ├── maskApiMex.c
│   │       ├── maskApiMex.mexa64
│   │       └── maskApiMex.mexmaci64
│   ├── PythonAPI
│   │   ├── Makefile
│   │   ├── main.py
│   │   ├── mul-main.py
│   │   ├── pycocoDemo.ipynb
│   │   ├── pycocoEvalDemo.ipynb
│   │   ├── pycocotools
│   │   │   ├── __init__.py
│   │   │   ├── _mask.c
│   │   │   ├── _mask.pyx
│   │   │   ├── _mask.so
│   │   │   ├── coco.py
│   │   │   ├── cocoeval.py
│   │   │   └── mask.py
│   │   └── setup.py
│   └── common
│       ├── gason.cpp
│       ├── gason.h
│       ├── maskApi.c
│       └── maskApi.h
├── dataloader.py
├── estimator.py
├── eval.py
├── mobilenetv2.py
├── networks.py
├── pose_dataset
│   ├── lsp
│   │   ├── README.txt
│   │   └── joints.mat
│   ├── lsp_dataset.py
│   ├── lsp_ext
│   │   ├── README.txt
│   │   ├── joints.mat
│   │   ├── test_joints.csv
│   │   ├── train_joints.csv
│   │   └── train_lsp_small_joints.csv
│   ├── mpii
│   │   ├── data.json
│   │   ├── mpii_human_pose_v1_u12_1.mat
│   │   ├── mpii_human_pose_v1_u12_2
│   │   │   ├── README.md
│   │   │   └── bsd.txt
│   │   ├── test_joints.csv
│   │   └── train_joints.csv
│   └── mpii_dataset.py
├── pycocotools
├── run_webcam.py
└── training.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | */.ipynb_checkpoints/*
3 | *.t7
4 | *.jpg
5 | *.txt
6 | *.pyc
7 | models
8 | __pycache__
9 | others
10 | .vscode
11 | .gitignore
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MobilePose
2 | 
3 | MobilePose is a **tiny** PyTorch implementation of a single-person 2D pose estimation framework. It provides interfaces for training, inference, and evaluation, together with a dataloader supporting various data augmentation options. The final trained models satisfy the basic requirements (speed, size, and accuracy) of mobile devices.
4 | 
5 | Some code for MobilenetV2 and visualization is adapted from [pytorch-mobilenet-v2](https://github.com/tonylins/pytorch-mobilenet-v2) and [tf-pose-estimation](https://github.com/ildoonet/tf-pose-estimation). Thanks to the original authors.
6 | 
7 | ## Functionality
8 | 
9 | 1. **Tiny** trained models (ResNet18 [43MB], MobilenetV2 [8.9MB])
10 | 2. **Fast** inference speed (GPU > 100 FPS, CPU ~30 FPS)
11 | 3. **Accurate** keypoint estimation (75~85 mAP at 0.5 IoU)
12 | 
13 | ## Requirements:
14 | 
15 | - Python 3.6.2
16 | - PyTorch 0.2.0\_3
17 | - imgaug 0.2.5
18 | 
19 | ## Todo List:
20 | 
21 | - [x] multi-thread dataloader
22 | - [x] training and inference
23 | - [x] performance evaluation
24 | - [x] multi-scale training
25 | - [x] support for ResNet18/MobilenetV2
26 | - [x] data augmentation (rotate/shift/flip/multi-scale/noise)
27 | - [x] MacBook camera realtime display script
28 | 
29 | ## Usage
30 | 
31 | 1. Training:
32 |    ```shell
33 |    export CUDA_VISIBLE_DEVICES=0; python training.py --model=mobilenet/resnet --gpu=0
34 |    ```
35 | 2. Evaluation:
36 |    ```shell
37 |    export CUDA_VISIBLE_DEVICES=0; python eval.py --model=mobilenet/resnet
38 |    ```
39 | 3. Realtime visualization:
40 |    ```shell
41 |    python run_webcam.py --model=mobilenet/resnet
42 |    ```
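For orientation, here is a minimal single-image inference sketch. The model loading, preprocessing, and checkpoint path below are illustrative assumptions, not the repo's actual API; the real entry points live in `networks.py`, `estimator.py`, and `run_webcam.py`:

```python
# Hypothetical inference sketch; names and preprocessing are assumptions.
# The network outputs a 1x32 tensor, i.e. 16 (x, y) keypoint pairs
# (see coco_utils.py below).
import numpy as np
import torch
from torch.autograd import Variable  # PyTorch 0.2-era wrapper

def infer_keypoints(net, image):
    """image: HxWx3 uint8 numpy array, already cropped/resized to the input size."""
    x = image.astype(np.float32) / 255.0                 # scale to [0, 1]
    x = torch.from_numpy(x.transpose(2, 0, 1))           # HWC -> CHW
    out = net(Variable(x.unsqueeze(0), volatile=True))   # add batch dim
    return out.data.cpu().numpy().reshape(-1, 2)         # 16 keypoints as (x, y)

# net = torch.load('models/mobilenet_best.t7')  # assumed checkpoint path
# keypoints = infer_keypoints(net, frame)
```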
43 | 
44 | ## Contributors
45 | 
46 | MobilePose is developed and maintained by [Yuliang Xiu](http://xiuyuliang.cn/about/), [Zexin Chen](https://github.com/ZexinChen) and [Yinghong Fang](https://github.com/Fangyh09).
47 | 
48 | ## License
49 | 
50 | MobilePose is freely available for non-commercial use. For commercial queries, please contact [Cewu Lu](http://www.mvig.org/) or [SightPlus Co. Ltd](https://www.sightp.com/).
51 | 
--------------------------------------------------------------------------------
/coco_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | File: coco_utils.py
3 | Project: MobilePose
4 | File Created: Saturday, 3rd March 2018 7:04:57 pm
5 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
6 | -----
7 | Last Modified: Thursday, 8th March 2018 3:02:15 pm
8 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
9 | -----
10 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group
11 | '''
12 | 
13 | 
14 | # define coco classes and helpers
15 | import json
16 | import numpy as np
17 | from collections import namedtuple, Mapping
18 | 
19 | # Create a namedtuple whose fields default to the given values (or None)
20 | def namedtuple_with_defaults(typename, field_names, default_values=()):
21 |     T = namedtuple(typename, field_names)
22 |     T.__new__.__defaults__ = (None,) * len(T._fields)
23 |     if isinstance(default_values, Mapping):
24 |         prototype = T(**default_values)
25 |     else:
26 |         prototype = T(*default_values)
27 |     T.__new__.__defaults__ = tuple(prototype)
28 |     return T
29 | 
30 | # Used for solving TypeError: Object of type 'float32' is not JSON serializable
31 | class MyEncoder(json.JSONEncoder):
32 |     def default(self, obj):
33 |         if isinstance(obj, np.integer):
34 |             return int(obj)
35 |         elif isinstance(obj, np.floating):
36 |             return float(obj)
37 |         elif isinstance(obj, np.ndarray):
38 |             return obj.tolist()
39 |         else:
40 |             return super(MyEncoder, self).default(obj)
41 | 
42 | # Classes for coco ground truth, CocoImage and CocoAnnotation
43 | CocoImage = namedtuple_with_defaults('image', ['file_name', 'height', 'width', 'id'])
44 | CocoAnnotation = namedtuple_with_defaults('annotation', ['num_keypoints', 'area',
45 |                                                          'iscrowd', 'keypoints',
46 |                                                          'image_id', 'bbox', 'category_id',
47 |                                                          'id'])
48 | class CocoData:
49 |     def __init__(self, coco_images_arr, coco_annotations_arr):
50 |         self.Coco = {}
51 |         coco_images_arr = [item._asdict() for item in coco_images_arr]
52 |         coco_annotations_arr = [item._asdict() for item in coco_annotations_arr]
53 |         self.Coco['images'] = coco_images_arr
54 |         self.Coco['annotations'] = coco_annotations_arr
55 |         self.Coco['categories'] = [{"id": 1, "name": "test"}]
56 | 
57 |     def dumps(self):
58 |         return json.dumps(self.Coco, cls=MyEncoder)
59 | 
60 | # Round the prob entry of each keypoint [x, y, prob] triplet to an int
61 | def float2int(str_data):
62 |     json_data = json.loads(str_data)
63 |     annotations = []
64 |     if 'annotations' in json_data:
65 |         annotations = json_data['annotations']
66 |     else:
67 |         annotations = json_data
68 |     json_size = len(annotations)
69 |     for i in range(json_size):
70 |         annotation = annotations[i]
71 |         keypoints = annotation['keypoints']
72 |         keypoints_num = int(len(keypoints) / 3)
73 |         for j in range(keypoints_num):
74 |             keypoints[j * 3 + 2] = int(round(keypoints[j * 3 + 2]))
75 |     return json.dumps(json_data)
76 | 
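A quick, illustrative round trip through the helpers above (all values are toy data):

```python
# Toy usage of the helpers defined above; values are made up.
img = CocoImage(file_name="demo.jpg", height=224, width=224, id=0)
print(img.id)      # -> 0; any field left unset defaults to None

anns = [{"keypoints": [12.3, 45.6, 1.7] * 16}]   # 16 (x, y, prob) triplets
print(float2int(json.dumps(anns, cls=MyEncoder)))
# -> every third entry is rounded to an int: [12.3, 45.6, 2, ...]
```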
77 | # Append coco ground truth to coco_images_arr and coco_annotations_arr
78 | def transform_to_coco_gt(datas, coco_images_arr, coco_annotations_arr):
79 |     """
80 |     datas: num_samples * 32 Tensor (16 keypoints per sample)
81 | 
82 |     output:
83 |     results are appended in place to coco_images_arr and coco_annotations_arr
84 |     """
85 |     for idx, sample in enumerate(datas):
86 |         coco_image = CocoImage()
87 |         coco_annotation = CocoAnnotation()
88 |         sample = np.array(sample.numpy()).reshape(-1, 2)
89 |         num_keypoints = len(sample)
90 |         # append visibility flag v=2 (labeled and visible) to each keypoint
91 |         keypoints = np.append(sample, np.ones(num_keypoints).reshape(-1, 1) * 2, axis=1)
92 |         xmin = np.min(sample[:, 0])
93 |         ymin = np.min(sample[:, 1])
94 |         xmax = np.max(sample[:, 0])
95 |         ymax = np.max(sample[:, 1])
96 |         width = xmax - xmin    # bbox extent along x
97 |         height = ymax - ymin   # bbox extent along y
98 |         coco_image = coco_image._replace(id=idx, width=width, height=height, file_name="")
99 |         coco_annotation = coco_annotation._replace(num_keypoints=num_keypoints)
100 |         coco_annotation = coco_annotation._replace(area=width * height)
101 |         coco_annotation = coco_annotation._replace(keypoints=keypoints.reshape(-1))
102 |         coco_annotation = coco_annotation._replace(image_id=idx)
103 |         coco_annotation = coco_annotation._replace(bbox=[xmin, ymin, width, height])
104 |         coco_annotation = coco_annotation._replace(category_id=1)  # default "1" for keypoint
105 |         coco_annotation = coco_annotation._replace(id=idx)
106 |         coco_annotation = coco_annotation._replace(iscrowd=0)
107 |         coco_images_arr.append(coco_image)
108 |         coco_annotations_arr.append(coco_annotation)
109 | 
110 | # Coco predict result class
111 | CocoPredictAnnotation = namedtuple_with_defaults('predict_anno', ['image_id', 'category_id', 'keypoints', 'score'])
112 | 
113 | # Append coco predict results to coco_pred_annotations_arr
114 | def transform_to_coco_pred(datas, coco_pred_annotations_arr, beg_idx):
115 |     """
116 |     datas: num_samples * 32 Variable (16 keypoints per sample)
117 | 
118 |     output:
119 |     results are appended in place to coco_pred_annotations_arr
120 |     """
121 |     for idx, sample in enumerate(datas):
122 |         coco_pred_annotation = CocoPredictAnnotation()
123 |         sample = np.array(sample.data.cpu().numpy()).reshape(-1, 2)
124 |         num_keypoints = len(sample)
125 |         # append visibility flag v=2 (labeled and visible) to each keypoint
126 |         keypoints = np.append(sample, np.ones(num_keypoints).reshape(-1, 1) * 2, axis=1)
127 |         # set value
128 |         cur_idx = beg_idx + idx
129 |         coco_pred_annotation = coco_pred_annotation._replace(image_id=cur_idx)
130 |         coco_pred_annotation = coco_pred_annotation._replace(category_id=1)
131 |         coco_pred_annotation = coco_pred_annotation._replace(keypoints=keypoints.reshape(-1))
132 |         coco_pred_annotation = coco_pred_annotation._replace(score=2)
133 |         # add to arr
134 |         coco_pred_annotations_arr.append(coco_pred_annotation)
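To make the intended use concrete, here is an illustrative sketch that wires these helpers together on random tensors and produces COCO-format JSON strings. The authoritative flow lives in `eval.py`; `Variable` is the PyTorch 0.2-era wrapper this code expects:

```python
# Illustrative only: exercise the helpers above with random data.
import json
import torch
from torch.autograd import Variable
from coco_utils import (CocoData, MyEncoder, float2int,
                        transform_to_coco_gt, transform_to_coco_pred)

gt = torch.rand(4, 32) * 224              # 4 samples, 16 (x, y) pairs each
pred = Variable(torch.rand(4, 32) * 224)  # predictions arrive as Variables

images, gt_anns, pred_anns = [], [], []
transform_to_coco_gt(gt, images, gt_anns)
transform_to_coco_pred(pred, pred_anns, beg_idx=0)

gt_json = CocoData(images, gt_anns).dumps()   # COCO-style ground truth
pred_json = float2int(json.dumps([a._asdict() for a in pred_anns], cls=MyEncoder))
# gt_json and pred_json can then be handed to the pycocotools evaluation.
```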
--------------------------------------------------------------------------------
/cocoapi/LuaAPI/CocoApi.lua:
--------------------------------------------------------------------------------
1 | --[[----------------------------------------------------------------------------
2 | 
3 | Interface for accessing the Common Objects in COntext (COCO) dataset.
4 | 
5 | For an overview of the API please see http://mscoco.org/dataset/#download.
6 | CocoApi.lua (this file) is modeled after the Matlab CocoApi.m:
7 | https://github.com/pdollar/coco/blob/master/MatlabAPI/CocoApi.m
8 | 
9 | The following API functions are defined in the Lua API:
10 |  CocoApi    - Load COCO annotation file and prepare data structures.
11 |  getAnnIds  - Get ann ids that satisfy given filter conditions.
12 |  getCatIds  - Get cat ids that satisfy given filter conditions.
13 |  getImgIds  - Get img ids that satisfy given filter conditions.
14 |  loadAnns   - Load anns with the specified ids.
15 |  loadCats   - Load cats with the specified ids.
16 |  loadImgs   - Load imgs with the specified ids.
17 |  showAnns   - Display the specified annotations.
18 | Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
19 | For detailed usage information please see cocoDemo.lua.
20 | 
21 | LIMITATIONS: the following API functions are NOT defined in the Lua API:
22 |  loadRes    - Load algorithm results and create API for accessing them.
23 |  download   - Download COCO images from mscoco.org server.
24 | In addition, currently getCatIds() and getImgIds() do not accept filters.
25 | getAnnIds() can be called using getAnnIds({imgId=id}) and getAnnIds({catId=id}).
26 | 
27 | Note: loading COCO JSON annotations to Lua tables is quite slow. Hence, a call
28 | to CocoApi(annFile) converts the annotations to a custom 'flattened' format that
29 | is more efficient. The first time a COCO JSON is loaded, the conversion is
30 | invoked (this may take up to a minute). The converted data is then stored in a
31 | t7 file (the code must have write permission to the dir of the JSON file).
32 | Future calls of cocoApi=CocoApi(annFile) take a fraction of a second. To view the
33 | created data just inspect cocoApi.data of a created instance of the CocoApi.
34 | 
35 | Common Objects in COntext (COCO) Toolbox.      version 3.0
36 | Data, paper, and tutorials available at:  http://mscoco.org/
37 | Code written by Pedro O. Pinheiro and Piotr Dollar, 2016.
38 | Licensed under the Simplified BSD License [see coco/license.txt]
39 | 
40 | ------------------------------------------------------------------------------]]
41 | 
42 | local json = require 'cjson'
43 | local coco = require 'coco.env'
44 | 
45 | local TensorTable = torch.class('TensorTable',coco)
46 | local CocoSeg = torch.class('CocoSeg',coco)
47 | local CocoApi = torch.class('CocoApi',coco)
48 | 
49 | --------------------------------------------------------------------------------
50 | 
51 | --[[ TensorTable is a lightweight data structure for storing variable size 1D
52 | tensors. Tables of tensors are slow to save/load to disk. Instead, TensorTable
53 | stores all the data in a single long tensor (along with indices into the tensor)
54 | making serialization fast. A TensorTable may only contain 1D same-type torch
55 | tensors or strings. It supports only creation from a table and indexing.
]] 56 | 57 | function TensorTable:__init( T ) 58 | local n = #T; assert(n>0) 59 | local isStr = torch.type(T[1])=='string' 60 | assert(isStr or torch.isTensor(T[1])) 61 | local c=function(s) return torch.CharTensor(torch.CharStorage():string(s)) end 62 | if isStr then local S=T; T={}; for i=1,n do T[i]=c(S[i]) end end 63 | local ms, idx = torch.LongTensor(n), torch.LongTensor(n+1) 64 | for i=1,n do ms[i]=T[i]:numel() end 65 | idx[1]=1; idx:narrow(1,2,n):copy(ms); idx=idx:cumsum() 66 | local type = string.sub(torch.type(T[1]),7,-1) 67 | local data = torch[type](idx[n+1]-1) 68 | if isStr then type='string' end 69 | for i=1,n do if ms[i]>0 then data:sub(idx[i],idx[i+1]-1):copy(T[i]) end end 70 | if ms:eq(ms[1]):all() and ms[1]>0 then data=data:view(n,ms[1]); idx=nil end 71 | self.data, self.idx, self.type = data, idx, type 72 | end 73 | 74 | function TensorTable:__index__( i ) 75 | if torch.type(i)~='number' then return false end 76 | local d, idx, type = self.data, self.idx, self.type 77 | if idx and idx[i]==idx[i+1] then 78 | if type=='string' then d='' else d=torch[type]() end 79 | else 80 | if idx then d=d:sub(idx[i],idx[i+1]-1) else d=d[i] end 81 | if type=='string' then d=d:clone():storage():string() end 82 | end 83 | return d, true 84 | end 85 | 86 | -------------------------------------------------------------------------------- 87 | 88 | --[[ CocoSeg is an efficient data structure for storing COCO segmentations. ]] 89 | 90 | function CocoSeg:__init( segs ) 91 | local polys, pIdx, sizes, rles, p, isStr = {}, {}, {}, {}, 0, 0 92 | for i,seg in pairs(segs) do if seg.size then isStr=seg.counts break end end 93 | isStr = torch.type(isStr)=='string' 94 | for i,seg in pairs(segs) do 95 | pIdx[i], sizes[i] = {}, {} 96 | if seg.size then 97 | sizes[i],rles[i] = seg.size,seg.counts 98 | else 99 | if isStr then rles[i]='' else rles[i]={} end 100 | for j=1,#seg do p=p+1; pIdx[i][j],polys[p] = p,seg[j] end 101 | end 102 | pIdx[i],sizes[i] = torch.LongTensor(pIdx[i]),torch.IntTensor(sizes[i]) 103 | if not isStr then rles[i]=torch.IntTensor(rles[i]) end 104 | end 105 | for i=1,p do polys[i]=torch.DoubleTensor(polys[i]) end 106 | self.polys, self.pIdx = coco.TensorTable(polys), coco.TensorTable(pIdx) 107 | self.sizes, self.rles = coco.TensorTable(sizes), coco.TensorTable(rles) 108 | end 109 | 110 | function CocoSeg:__index__( i ) 111 | if torch.type(i)~='number' then return false end 112 | if self.sizes[i]:numel()>0 then 113 | return {size=self.sizes[i],counts=self.rles[i]}, true 114 | else 115 | local ids, polys = self.pIdx[i], {} 116 | for i=1,ids:numel() do polys[i]=self.polys[ids[i]] end 117 | return polys, true 118 | end 119 | end 120 | 121 | -------------------------------------------------------------------------------- 122 | 123 | --[[ CocoApi is the API to the COCO dataset, see main comment for details. ]] 124 | 125 | function CocoApi:__init( annFile ) 126 | assert( string.sub(annFile,-4,-1)=='json' and paths.filep(annFile) ) 127 | local torchFile = string.sub(annFile,1,-6) .. '.t7' 128 | if not paths.filep(torchFile) then self:__convert(annFile,torchFile) end 129 | local data = torch.load(torchFile) 130 | self.data, self.inds = data, {} 131 | for k,v in pairs({images='img',categories='cat',annotations='ann'}) do 132 | local M = {}; self.inds[v..'IdsMap']=M 133 | if data[k] then for i=1,data[k].id:size(1) do M[data[k].id[i]]=i end end 134 | end 135 | end 136 | 137 | function CocoApi:__convert( annFile, torchFile ) 138 | print('convert: '..annFile..' 
--> .t7 [please be patient]') 139 | local tic = torch.tic() 140 | -- load data and decode json 141 | local data = torch.CharStorage(annFile):string() 142 | data = json.decode(data); collectgarbage() 143 | -- transpose and flatten each field in the coco data struct 144 | local convert = {images=true, categories=true, annotations=true} 145 | for field, d in pairs(data) do if convert[field] then 146 | print('converting: '..field) 147 | local n, out = #d, {} 148 | if n==0 then d,n={d},1 end 149 | for k,v in pairs(d[1]) do 150 | local t, isReg = torch.type(v), true 151 | for i=1,n do isReg=isReg and torch.type(d[i][k])==t end 152 | if t=='number' and isReg then 153 | out[k] = torch.DoubleTensor(n) 154 | for i=1,n do out[k][i]=d[i][k] end 155 | elseif t=='string' and isReg then 156 | out[k]={}; for i=1,n do out[k][i]=d[i][k] end 157 | out[k] = coco.TensorTable(out[k]) 158 | elseif t=='table' and isReg and torch.type(v[1])=='number' then 159 | out[k]={}; for i=1,n do out[k][i]=torch.DoubleTensor(d[i][k]) end 160 | out[k] = coco.TensorTable(out[k]) 161 | if not out[k].idx then out[k]=out[k].data end 162 | else 163 | out[k]={}; for i=1,n do out[k][i]=d[i][k] end 164 | if k=='segmentation' then out[k] = coco.CocoSeg(out[k]) end 165 | end 166 | collectgarbage() 167 | end 168 | if out.id then out.idx=torch.range(1,out.id:size(1)) end 169 | data[field] = out 170 | collectgarbage() 171 | end end 172 | -- create mapping from cat/img index to anns indices for that cat/img 173 | print('convert: building indices') 174 | local makeMap = function( type, type_id ) 175 | if not data[type] or not data.annotations then return nil end 176 | local invmap, n = {}, data[type].id:size(1) 177 | for i=1,n do invmap[data[type].id[i]]=i end 178 | local map = {}; for i=1,n do map[i]={} end 179 | data.annotations[type_id..'x'] = data.annotations[type_id]:clone() 180 | for i=1,data.annotations.id:size(1) do 181 | local id = invmap[data.annotations[type_id][i]] 182 | data.annotations[type_id..'x'][i] = id 183 | table.insert(map[id],data.annotations.id[i]) 184 | end 185 | for i=1,n do map[i]=torch.LongTensor(map[i]) end 186 | return coco.TensorTable(map) 187 | end 188 | data.annIdsPerImg = makeMap('images','image_id') 189 | data.annIdsPerCat = makeMap('categories','category_id') 190 | -- save to disk 191 | torch.save( torchFile, data ) 192 | print(('convert: complete [%.2f s]'):format(torch.toc(tic))) 193 | end 194 | 195 | function CocoApi:getAnnIds( filters ) 196 | if not filters then filters = {} end 197 | if filters.imgId then 198 | return self.data.annIdsPerImg[self.inds.imgIdsMap[filters.imgId]] or {} 199 | elseif filters.catId then 200 | return self.data.annIdsPerCat[self.inds.catIdsMap[filters.catId]] or {} 201 | else 202 | return self.data.annotations.id 203 | end 204 | end 205 | 206 | function CocoApi:getCatIds() 207 | return self.data.categories.id 208 | end 209 | 210 | function CocoApi:getImgIds() 211 | return self.data.images.id 212 | end 213 | 214 | function CocoApi:loadAnns( ids ) 215 | return self:__load(self.data.annotations,self.inds.annIdsMap,ids) 216 | end 217 | 218 | function CocoApi:loadCats( ids ) 219 | return self:__load(self.data.categories,self.inds.catIdsMap,ids) 220 | end 221 | 222 | function CocoApi:loadImgs( ids ) 223 | return self:__load(self.data.images,self.inds.imgIdsMap,ids) 224 | end 225 | 226 | function CocoApi:showAnns( img, anns ) 227 | local n, h, w = #anns, img:size(2), img:size(3) 228 | local MaskApi, clrs = coco.MaskApi, torch.rand(n,3)*.6+.4 229 | local O = 
img:clone():contiguous():float() 230 | if n==0 then anns,n={anns},1 end 231 | if anns[1].keypoints then for i=1,n do if anns[i].iscrowd==0 then 232 | local sk, kp, j, k = self:loadCats(anns[i].category_id)[1].skeleton 233 | kp=anns[i].keypoints; k=kp:size(1); j=torch.range(1,k,3):long(); k=k/3; 234 | local x,y,v = kp:index(1,j), kp:index(1,j+1), kp:index(1,j+2) 235 | for _,s in pairs(sk) do if v[s[1]]>0 and v[s[2]]>0 then 236 | MaskApi.drawLine(O,x[s[1]],y[s[1]],x[s[2]],y[s[2]],.75,clrs[i]) 237 | end end 238 | for j=1,k do if v[j]==1 then MaskApi.drawCirc(O,x[j],y[j],4,{0,0,0}) end end 239 | for j=1,k do if v[j]>0 then MaskApi.drawCirc(O,x[j],y[j],3,clrs[i]) end end 240 | end end end 241 | if anns[1].segmentation or anns[1].bbox then 242 | local Rs, alpha = {}, anns[1].keypoints and .25 or .4 243 | for i=1,n do 244 | Rs[i]=anns[i].segmentation 245 | if Rs[i] and #Rs[i]>0 then Rs[i]=MaskApi.frPoly(Rs[i],h,w) end 246 | if not Rs[i] then Rs[i]=MaskApi.frBbox(anns[i].bbox,h,w)[1] end 247 | end 248 | MaskApi.drawMasks(O,MaskApi.decode(Rs),nil,alpha,clrs) 249 | end 250 | return O 251 | end 252 | 253 | function CocoApi:__load( data, map, ids ) 254 | if not torch.isTensor(ids) then ids=torch.LongTensor({ids}) end 255 | local out, idx = {}, nil 256 | for i=1,ids:numel() do 257 | out[i], idx = {}, map[ids[i]] 258 | for k,v in pairs(data) do out[i][k]=v[idx] end 259 | end 260 | return out 261 | end 262 | -------------------------------------------------------------------------------- /cocoapi/LuaAPI/MaskApi.lua: -------------------------------------------------------------------------------- 1 | --[[---------------------------------------------------------------------------- 2 | 3 | Interface for manipulating masks stored in RLE format. 4 | 5 | For an overview of RLE please see http://mscoco.org/dataset/#download. 6 | Additionally, more detailed information can be found in the Matlab MaskApi.m: 7 | https://github.com/pdollar/coco/blob/master/MatlabAPI/MaskApi.m 8 | 9 | The following API functions are defined: 10 | encode - Encode binary masks using RLE. 11 | decode - Decode binary masks encoded via RLE. 12 | merge - Compute union or intersection of encoded masks. 13 | iou - Compute intersection over union between masks. 14 | nms - Compute non-maximum suppression between ordered masks. 15 | area - Compute area of encoded masks. 16 | toBbox - Get bounding boxes surrounding encoded masks. 17 | frBbox - Convert bounding boxes to encoded masks. 18 | frPoly - Convert polygon to encoded mask. 19 | drawCirc - Draw circle into image (alters input). 20 | drawLine - Draw line into image (alters input). 21 | drawMasks - Draw masks into image (alters input). 22 | 23 | Usage: 24 | Rs = MaskApi.encode( masks ) 25 | masks = MaskApi.decode( Rs ) 26 | R = MaskApi.merge( Rs, [intersect=false] ) 27 | o = MaskApi.iou( dt, gt, [iscrowd=false] ) 28 | keep = MaskApi.nms( dt, thr ) 29 | a = MaskApi.area( Rs ) 30 | bbs = MaskApi.toBbox( Rs ) 31 | Rs = MaskApi.frBbox( bbs, h, w ) 32 | R = MaskApi.frPoly( poly, h, w ) 33 | MaskApi.drawCirc( img, x, y, rad, clr ) 34 | MaskApi.drawLine( img, x0, y0, x1, y1, rad, clr ) 35 | MaskApi.drawMasks( img, masks, [maxn=n], [alpha=.4], [clrs] ) 36 | For detailed usage information please see cocoDemo.lua. 
37 | 38 | In the API the following formats are used: 39 | R,Rs - [table] Run-length encoding of binary mask(s) 40 | masks - [nxhxw] Binary mask(s) 41 | bbs - [nx4] Bounding box(es) stored as [x y w h] 42 | poly - Polygon stored as {[x1 y1 x2 y2...],[x1 y1 ...],...} 43 | dt,gt - May be either bounding boxes or encoded masks 44 | Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 45 | 46 | Common Objects in COntext (COCO) Toolbox. version 3.0 47 | Data, paper, and tutorials available at: http://mscoco.org/ 48 | Code written by Pedro O. Pinheiro and Piotr Dollar, 2016. 49 | Licensed under the Simplified BSD License [see coco/license.txt] 50 | 51 | ------------------------------------------------------------------------------]] 52 | 53 | local ffi = require 'ffi' 54 | local coco = require 'coco.env' 55 | 56 | coco.MaskApi = {} 57 | local MaskApi = coco.MaskApi 58 | 59 | coco.libmaskapi = ffi.load(package.searchpath('libmaskapi',package.cpath)) 60 | local libmaskapi = coco.libmaskapi 61 | 62 | -------------------------------------------------------------------------------- 63 | 64 | MaskApi.encode = function( masks ) 65 | local n, h, w = masks:size(1), masks:size(2), masks:size(3) 66 | masks = masks:type('torch.ByteTensor'):transpose(2,3) 67 | local data = masks:contiguous():data() 68 | local Qs = MaskApi._rlesInit(n) 69 | libmaskapi.rleEncode(Qs[0],data,h,w,n) 70 | return MaskApi._rlesToLua(Qs,n) 71 | end 72 | 73 | MaskApi.decode = function( Rs ) 74 | local Qs, n, h, w = MaskApi._rlesFrLua(Rs) 75 | local masks = torch.ByteTensor(n,w,h):zero():contiguous() 76 | libmaskapi.rleDecode(Qs,masks:data(),n) 77 | MaskApi._rlesFree(Qs,n) 78 | return masks:transpose(2,3) 79 | end 80 | 81 | MaskApi.merge = function( Rs, intersect ) 82 | intersect = intersect or 0 83 | local Qs, n, h, w = MaskApi._rlesFrLua(Rs) 84 | local Q = MaskApi._rlesInit(1) 85 | libmaskapi.rleMerge(Qs,Q,n,intersect) 86 | MaskApi._rlesFree(Qs,n) 87 | return MaskApi._rlesToLua(Q,1)[1] 88 | end 89 | 90 | MaskApi.iou = function( dt, gt, iscrowd ) 91 | if not iscrowd then iscrowd = NULL else 92 | iscrowd = iscrowd:type('torch.ByteTensor'):contiguous():data() 93 | end 94 | if torch.isTensor(gt) and torch.isTensor(dt) then 95 | local nDt, k = dt:size(1), dt:size(2); assert(k==4) 96 | local nGt, k = gt:size(1), gt:size(2); assert(k==4) 97 | local dDt = dt:type('torch.DoubleTensor'):contiguous():data() 98 | local dGt = gt:type('torch.DoubleTensor'):contiguous():data() 99 | local o = torch.DoubleTensor(nGt,nDt):contiguous() 100 | libmaskapi.bbIou(dDt,dGt,nDt,nGt,iscrowd,o:data()) 101 | return o:transpose(1,2) 102 | else 103 | local qDt, nDt = MaskApi._rlesFrLua(dt) 104 | local qGt, nGt = MaskApi._rlesFrLua(gt) 105 | local o = torch.DoubleTensor(nGt,nDt):contiguous() 106 | libmaskapi.rleIou(qDt,qGt,nDt,nGt,iscrowd,o:data()) 107 | MaskApi._rlesFree(qDt,nDt); MaskApi._rlesFree(qGt,nGt) 108 | return o:transpose(1,2) 109 | end 110 | end 111 | 112 | MaskApi.nms = function( dt, thr ) 113 | if torch.isTensor(dt) then 114 | local n, k = dt:size(1), dt:size(2); assert(k==4) 115 | local Q = dt:type('torch.DoubleTensor'):contiguous():data() 116 | local kp = torch.IntTensor(n):contiguous() 117 | libmaskapi.bbNms(Q,n,kp:data(),thr) 118 | return kp 119 | else 120 | local Q, n = MaskApi._rlesFrLua(dt) 121 | local kp = torch.IntTensor(n):contiguous() 122 | libmaskapi.rleNms(Q,n,kp:data(),thr) 123 | MaskApi._rlesFree(Q,n) 124 | return kp 125 | end 126 | end 127 | 128 | MaskApi.area = function( Rs ) 129 | local Qs, n, h, w = 
MaskApi._rlesFrLua(Rs) 130 | local a = torch.IntTensor(n):contiguous() 131 | libmaskapi.rleArea(Qs,n,a:data()) 132 | MaskApi._rlesFree(Qs,n) 133 | return a 134 | end 135 | 136 | MaskApi.toBbox = function( Rs ) 137 | local Qs, n, h, w = MaskApi._rlesFrLua(Rs) 138 | local bb = torch.DoubleTensor(n,4):contiguous() 139 | libmaskapi.rleToBbox(Qs,bb:data(),n) 140 | MaskApi._rlesFree(Qs,n) 141 | return bb 142 | end 143 | 144 | MaskApi.frBbox = function( bbs, h, w ) 145 | if bbs:dim()==1 then bbs=bbs:view(1,bbs:size(1)) end 146 | local n, k = bbs:size(1), bbs:size(2); assert(k==4) 147 | local data = bbs:type('torch.DoubleTensor'):contiguous():data() 148 | local Qs = MaskApi._rlesInit(n) 149 | libmaskapi.rleFrBbox(Qs[0],data,h,w,n) 150 | return MaskApi._rlesToLua(Qs,n) 151 | end 152 | 153 | MaskApi.frPoly = function( poly, h, w ) 154 | local n = #poly 155 | local Qs, Q = MaskApi._rlesInit(n), MaskApi._rlesInit(1) 156 | for i,p in pairs(poly) do 157 | local xy = p:type('torch.DoubleTensor'):contiguous():data() 158 | libmaskapi.rleFrPoly(Qs[i-1],xy,p:size(1)/2,h,w) 159 | end 160 | libmaskapi.rleMerge(Qs,Q[0],n,0) 161 | MaskApi._rlesFree(Qs,n) 162 | return MaskApi._rlesToLua(Q,1)[1] 163 | end 164 | 165 | -------------------------------------------------------------------------------- 166 | 167 | MaskApi.drawCirc = function( img, x, y, rad, clr ) 168 | assert(img:isContiguous() and img:dim()==3) 169 | local k, h, w, data = img:size(1), img:size(2), img:size(3), img:data() 170 | for dx=-rad,rad do for dy=-rad,rad do 171 | local xi, yi = torch.round(x+dx), torch.round(y+dy) 172 | if dx*dx+dy*dy<=rad*rad and xi>=0 and yi>=0 and xi=0 and yi>=0 and xi= 5.1", 17 | "torch >= 7.0", 18 | "lua-cjson" 19 | } 20 | 21 | build = { 22 | type = "builtin", 23 | modules = { 24 | ["coco.env"] = "LuaAPI/env.lua", 25 | ["coco.init"] = "LuaAPI/init.lua", 26 | ["coco.MaskApi"] = "LuaAPI/MaskApi.lua", 27 | ["coco.CocoApi"] = "LuaAPI/CocoApi.lua", 28 | libmaskapi = { 29 | sources = { "common/maskApi.c" }, 30 | incdirs = { "common/" } 31 | } 32 | } 33 | } 34 | 35 | -- luarocks make LuaAPI/rocks/coco-scm-1.rockspec 36 | -- https://github.com/pdollar/coco/raw/master/LuaAPI/rocks/coco-scm-1.rockspec 37 | -------------------------------------------------------------------------------- /cocoapi/MatlabAPI/CocoApi.m: -------------------------------------------------------------------------------- 1 | classdef CocoApi 2 | % Interface for accessing the Microsoft COCO dataset. 3 | % 4 | % Microsoft COCO is a large image dataset designed for object detection, 5 | % segmentation, and caption generation. CocoApi.m is a Matlab API that 6 | % assists in loading, parsing and visualizing the annotations in COCO. 7 | % Please visit http://mscoco.org/ for more information on COCO, including 8 | % for the data, paper, and tutorials. The exact format of the annotations 9 | % is also described on the COCO website. For example usage of the CocoApi 10 | % please see cocoDemo.m. In addition to this API, please download both 11 | % the COCO images and annotations in order to run the demo. 12 | % 13 | % An alternative to using the API is to load the annotations directly 14 | % into a Matlab struct. This can be achieved via: 15 | % data = gason(fileread(annFile)); 16 | % Using the API provides additional utility functions. Note that this API 17 | % supports both *instance* and *caption* annotations. In the case of 18 | % captions not all functions are defined (e.g. categories are undefined). 
19 | % 20 | % The following API functions are defined: 21 | % CocoApi - Load COCO annotation file and prepare data structures. 22 | % getAnnIds - Get ann ids that satisfy given filter conditions. 23 | % getCatIds - Get cat ids that satisfy given filter conditions. 24 | % getImgIds - Get img ids that satisfy given filter conditions. 25 | % loadAnns - Load anns with the specified ids. 26 | % loadCats - Load cats with the specified ids. 27 | % loadImgs - Load imgs with the specified ids. 28 | % showAnns - Display the specified annotations. 29 | % loadRes - Load algorithm results and create API for accessing them. 30 | % Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 31 | % Help on each functions can be accessed by: "help CocoApi>function". 32 | % 33 | % See also CocoApi>CocoApi, CocoApi>getAnnIds, CocoApi>getCatIds, 34 | % CocoApi>getImgIds, CocoApi>loadAnns, CocoApi>loadCats, 35 | % CocoApi>loadImgs, CocoApi>showAnns, CocoApi>loadRes 36 | % 37 | % Microsoft COCO Toolbox. version 2.0 38 | % Data, paper, and tutorials available at: http://mscoco.org/ 39 | % Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 40 | % Licensed under the Simplified BSD License [see coco/license.txt] 41 | 42 | properties 43 | data % COCO annotation data structure 44 | inds % data structures for fast indexing 45 | end 46 | 47 | methods 48 | function coco = CocoApi( annFile ) 49 | % Load COCO annotation file and prepare data structures. 50 | % 51 | % USAGE 52 | % coco = CocoApi( annFile ) 53 | % 54 | % INPUTS 55 | % annFile - COCO annotation filename 56 | % 57 | % OUTPUTS 58 | % coco - initialized coco object 59 | fprintf('Loading and preparing annotations... '); clk=clock; 60 | if(isstruct(annFile)), coco.data=annFile; else 61 | coco.data=gason(fileread(annFile)); end 62 | is.imgIds = [coco.data.images.id]'; 63 | is.imgIdsMap = makeMap(is.imgIds); 64 | if( isfield(coco.data,'annotations') ) 65 | ann=coco.data.annotations; o=[ann.image_id]; 66 | if(isfield(ann,'category_id')), o=o*1e10+[ann.category_id]; end 67 | [~,o]=sort(o); ann=ann(o); coco.data.annotations=ann; 68 | s={'category_id','area','iscrowd','id','image_id'}; 69 | t={'annCatIds','annAreas','annIscrowd','annIds','annImgIds'}; 70 | for f=1:5, if(isfield(ann,s{f})), is.(t{f})=[ann.(s{f})]'; end; end 71 | is.annIdsMap = makeMap(is.annIds); 72 | is.imgAnnIdsMap = makeMultiMap(is.imgIds,... 73 | is.imgIdsMap,is.annImgIds,is.annIds,0); 74 | end 75 | if( isfield(coco.data,'categories') ) 76 | is.catIds = [coco.data.categories.id]'; 77 | is.catIdsMap = makeMap(is.catIds); 78 | if(isfield(is,'annCatIds')), is.catImgIdsMap = makeMultiMap(... 79 | is.catIds,is.catIdsMap,is.annCatIds,is.annImgIds,1); end 80 | end 81 | coco.inds=is; fprintf('DONE (t=%0.2fs).\n',etime(clock,clk)); 82 | 83 | function map = makeMap( keys ) 84 | % Make map from key to integer id associated with key. 85 | if(isempty(keys)), map=containers.Map(); return; end 86 | map=containers.Map(keys,1:length(keys)); 87 | end 88 | 89 | function map = makeMultiMap( keys, keysMap, keysAll, valsAll, sqz ) 90 | % Make map from keys to set of vals associated with each key. 
91 | js=values(keysMap,num2cell(keysAll)); js=[js{:}]; 92 | m=length(js); n=length(keys); k=zeros(1,n); 93 | for i=1:m, j=js(i); k(j)=k(j)+1; end; vs=zeros(n,max(k)); k(:)=0; 94 | for i=1:m, j=js(i); k(j)=k(j)+1; vs(j,k(j))=valsAll(i); end 95 | map = containers.Map('KeyType','double','ValueType','any'); 96 | if(sqz), for j=1:n, map(keys(j))=unique(vs(j,1:k(j))); end 97 | else for j=1:n, map(keys(j))=vs(j,1:k(j)); end; end 98 | end 99 | end 100 | 101 | function ids = getAnnIds( coco, varargin ) 102 | % Get ann ids that satisfy given filter conditions. 103 | % 104 | % USAGE 105 | % ids = coco.getAnnIds( params ) 106 | % 107 | % INPUTS 108 | % params - filtering parameters (struct or name/value pairs) 109 | % setting any filter to [] skips that filter 110 | % .imgIds - [] get anns for given imgs 111 | % .catIds - [] get anns for given cats 112 | % .areaRng - [] get anns for given area range (e.g. [0 inf]) 113 | % .iscrowd - [] get anns for given crowd label (0 or 1) 114 | % 115 | % OUTPUTS 116 | % ids - integer array of ann ids 117 | def = {'imgIds',[],'catIds',[],'areaRng',[],'iscrowd',[]}; 118 | [imgIds,catIds,ar,iscrowd] = getPrmDflt(varargin,def,1); 119 | if( length(imgIds)==1 ) 120 | t = coco.loadAnns(coco.inds.imgAnnIdsMap(imgIds)); 121 | if(~isempty(catIds)), t = t(ismember([t.category_id],catIds)); end 122 | if(~isempty(ar)), a=[t.area]; t = t(a>=ar(1) & a<=ar(2)); end 123 | if(~isempty(iscrowd)), t = t([t.iscrowd]==iscrowd); end 124 | ids = [t.id]; 125 | else 126 | ids=coco.inds.annIds; K = true(length(ids),1); t = coco.inds; 127 | if(~isempty(imgIds)), K = K & ismember(t.annImgIds,imgIds); end 128 | if(~isempty(catIds)), K = K & ismember(t.annCatIds,catIds); end 129 | if(~isempty(ar)), a=t.annAreas; K = K & a>=ar(1) & a<=ar(2); end 130 | if(~isempty(iscrowd)), K = K & t.annIscrowd==iscrowd; end 131 | ids=ids(K); 132 | end 133 | end 134 | 135 | function ids = getCatIds( coco, varargin ) 136 | % Get cat ids that satisfy given filter conditions. 137 | % 138 | % USAGE 139 | % ids = coco.getCatIds( params ) 140 | % 141 | % INPUTS 142 | % params - filtering parameters (struct or name/value pairs) 143 | % setting any filter to [] skips that filter 144 | % .catNms - [] get cats for given cat names 145 | % .supNms - [] get cats for given supercategory names 146 | % .catIds - [] get cats for given cat ids 147 | % 148 | % OUTPUTS 149 | % ids - integer array of cat ids 150 | if(~isfield(coco.data,'categories')), ids=[]; return; end 151 | def={'catNms',[],'supNms',[],'catIds',[]}; t=coco.data.categories; 152 | [catNms,supNms,catIds] = getPrmDflt(varargin,def,1); 153 | if(~isempty(catNms)), t = t(ismember({t.name},catNms)); end 154 | if(~isempty(supNms)), t = t(ismember({t.supercategory},supNms)); end 155 | if(~isempty(catIds)), t = t(ismember([t.id],catIds)); end 156 | ids = [t.id]; 157 | end 158 | 159 | function ids = getImgIds( coco, varargin ) 160 | % Get img ids that satisfy given filter conditions. 
161 | % 162 | % USAGE 163 | % ids = coco.getImgIds( params ) 164 | % 165 | % INPUTS 166 | % params - filtering parameters (struct or name/value pairs) 167 | % setting any filter to [] skips that filter 168 | % .imgIds - [] get imgs for given ids 169 | % .catIds - [] get imgs with all given cats 170 | % 171 | % OUTPUTS 172 | % ids - integer array of img ids 173 | def={'imgIds',[],'catIds',[]}; ids=coco.inds.imgIds; 174 | [imgIds,catIds] = getPrmDflt(varargin,def,1); 175 | if(~isempty(imgIds)), ids=intersect(ids,imgIds); end 176 | if(isempty(catIds)), return; end 177 | t=values(coco.inds.catImgIdsMap,num2cell(catIds)); 178 | for i=1:length(t), ids=intersect(ids,t{i}); end 179 | end 180 | 181 | function anns = loadAnns( coco, ids ) 182 | % Load anns with the specified ids. 183 | % 184 | % USAGE 185 | % anns = coco.loadAnns( ids ) 186 | % 187 | % INPUTS 188 | % ids - integer ids specifying anns 189 | % 190 | % OUTPUTS 191 | % anns - loaded ann objects 192 | ids = values(coco.inds.annIdsMap,num2cell(ids)); 193 | anns = coco.data.annotations([ids{:}]); 194 | end 195 | 196 | function cats = loadCats( coco, ids ) 197 | % Load cats with the specified ids. 198 | % 199 | % USAGE 200 | % cats = coco.loadCats( ids ) 201 | % 202 | % INPUTS 203 | % ids - integer ids specifying cats 204 | % 205 | % OUTPUTS 206 | % cats - loaded cat objects 207 | if(~isfield(coco.data,'categories')), cats=[]; return; end 208 | ids = values(coco.inds.catIdsMap,num2cell(ids)); 209 | cats = coco.data.categories([ids{:}]); 210 | end 211 | 212 | function imgs = loadImgs( coco, ids ) 213 | % Load imgs with the specified ids. 214 | % 215 | % USAGE 216 | % imgs = coco.loadImgs( ids ) 217 | % 218 | % INPUTS 219 | % ids - integer ids specifying imgs 220 | % 221 | % OUTPUTS 222 | % imgs - loaded img objects 223 | ids = values(coco.inds.imgIdsMap,num2cell(ids)); 224 | imgs = coco.data.images([ids{:}]); 225 | end 226 | 227 | function hs = showAnns( coco, anns ) 228 | % Display the specified annotations. 
229 | % 230 | % USAGE 231 | % hs = coco.showAnns( anns ) 232 | % 233 | % INPUTS 234 | % anns - annotations to display 235 | % 236 | % OUTPUTS 237 | % hs - handles to segment graphic objects 238 | n=length(anns); if(n==0), return; end 239 | r=.4:.2:1; [r,g,b]=ndgrid(r,r,r); cs=[r(:) g(:) b(:)]; 240 | cs=cs(randperm(size(cs,1)),:); cs=repmat(cs,100,1); 241 | if( isfield( anns,'keypoints') ) 242 | for i=1:n 243 | a=anns(i); if(isfield(a,'iscrowd') && a.iscrowd), continue; end 244 | seg={}; if(isfield(a,'segmentation')), seg=a.segmentation; end 245 | k=a.keypoints; x=k(1:3:end)+1; y=k(2:3:end)+1; v=k(3:3:end); 246 | k=coco.loadCats(a.category_id); k=k.skeleton; c=cs(i,:); hold on 247 | p={'FaceAlpha',.25,'LineWidth',2,'EdgeColor',c}; % polygon 248 | for j=seg, xy=j{1}+.5; fill(xy(1:2:end),xy(2:2:end),c,p{:}); end 249 | p={'Color',c,'LineWidth',3}; % skeleton 250 | for j=k, s=j{1}; if(all(v(s)>0)), line(x(s),y(s),p{:}); end; end 251 | p={'MarkerSize',8,'MarkerFaceColor',c,'MarkerEdgeColor'}; % pnts 252 | plot(x(v>0),y(v>0),'o',p{:},'k'); 253 | plot(x(v>1),y(v>1),'o',p{:},c); hold off; 254 | end 255 | elseif( any(isfield(anns,{'segmentation','bbox'})) ) 256 | if(~isfield(anns,'iscrowd')), [anns(:).iscrowd]=deal(0); end 257 | if(~isfield(anns,'segmentation')), S={anns.bbox}; %#ok 258 | for i=1:n, x=S{i}(1); w=S{i}(3); y=S{i}(2); h=S{i}(4); 259 | anns(i).segmentation={[x,y,x,y+h,x+w,y+h,x+w,y]}; end; end 260 | S={anns.segmentation}; hs=zeros(10000,1); k=0; hold on; 261 | pFill={'FaceAlpha',.4,'LineWidth',3}; 262 | for i=1:n 263 | if(anns(i).iscrowd), C=[.01 .65 .40]; else C=rand(1,3); end 264 | if(isstruct(S{i})), M=double(MaskApi.decode(S{i})); k=k+1; 265 | hs(k)=imagesc(cat(3,M*C(1),M*C(2),M*C(3)),'Alphadata',M*.5); 266 | else for j=1:length(S{i}), P=S{i}{j}+.5; k=k+1; 267 | hs(k)=fill(P(1:2:end),P(2:2:end),C,pFill{:}); end 268 | end 269 | end 270 | hs=hs(1:k); hold off; 271 | elseif( isfield(anns,'caption') ) 272 | S={anns.caption}; 273 | for i=1:n, S{i}=[int2str(i) ') ' S{i} '\newline']; end 274 | S=[S{:}]; title(S,'FontSize',12); 275 | end 276 | end 277 | 278 | function cocoRes = loadRes( coco, resFile ) 279 | % Load algorithm results and create API for accessing them. 280 | % 281 | % The API for accessing and viewing algorithm results is identical to 282 | % the CocoApi for the ground truth. The single difference is that the 283 | % ground truth results are replaced by the algorithm results. 284 | % 285 | % USAGE 286 | % cocoRes = coco.loadRes( resFile ) 287 | % 288 | % INPUTS 289 | % resFile - COCO results filename 290 | % 291 | % OUTPUTS 292 | % cocoRes - initialized results API 293 | fprintf('Loading and preparing results... 
'); clk=clock; 294 | cdata=coco.data; R=gason(fileread(resFile)); m=length(R); 295 | valid=ismember([R.image_id],[cdata.images.id]); 296 | if(~all(valid)), error('Results provided for invalid images.'); end 297 | t={'segmentation','bbox','keypoints','caption'}; t=t{isfield(R,t)}; 298 | if(strcmp(t,'caption')) 299 | for i=1:m, R(i).id=i; end; imgs=cdata.images; 300 | cdata.images=imgs(ismember([imgs.id],[R.image_id])); 301 | else 302 | assert(all(isfield(R,{'category_id','score',t}))); 303 | s=cat(1,R.(t)); if(strcmp(t,'bbox')), a=s(:,3).*s(:,4); end 304 | if(strcmp(t,'segmentation')), a=MaskApi.area(s); end 305 | if(strcmp(t,'keypoints')), x=s(:,1:3:end)'; y=s(:,2:3:end)'; 306 | a=(max(x)-min(x)).*(max(y)-min(y)); end 307 | for i=1:m, R(i).area=a(i); R(i).id=i; end 308 | end 309 | fprintf('DONE (t=%0.2fs).\n',etime(clock,clk)); 310 | cdata.annotations=R; cocoRes=CocoApi(cdata); 311 | end 312 | end 313 | 314 | end 315 | -------------------------------------------------------------------------------- /cocoapi/MatlabAPI/CocoUtils.m: -------------------------------------------------------------------------------- 1 | classdef CocoUtils 2 | % Utility functions for testing and validation of COCO code. 3 | % 4 | % The following utility functions are defined: 5 | % convertPascalGt - Convert ground truth for PASCAL to COCO format. 6 | % convertImageNetGt - Convert ground truth for ImageNet to COCO format. 7 | % convertPascalDt - Convert detections on PASCAL to COCO format. 8 | % convertImageNetDt - Convert detections on ImageNet to COCO format. 9 | % validateOnPascal - Validate COCO eval code against PASCAL code. 10 | % validateOnImageNet - Validate COCO eval code against ImageNet code. 11 | % generateFakeDt - Generate fake detections from ground truth. 12 | % validateMaskApi - Validate MaskApi against Matlab functions. 13 | % gasonSplit - Split JSON file into multiple JSON files. 14 | % gasonMerge - Merge JSON files into single JSON file. 15 | % Help on each functions can be accessed by: "help CocoUtils>function". 16 | % 17 | % See also CocoApi MaskApi CocoEval CocoUtils>convertPascalGt 18 | % CocoUtils>convertImageNetGt CocoUtils>convertPascalDt 19 | % CocoUtils>convertImageNetDt CocoUtils>validateOnPascal 20 | % CocoUtils>validateOnImageNet CocoUtils>generateFakeDt 21 | % CocoUtils>validateMaskApi CocoUtils>gasonSplit CocoUtils>gasonMerge 22 | % 23 | % Microsoft COCO Toolbox. version 2.0 24 | % Data, paper, and tutorials available at: http://mscoco.org/ 25 | % Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 26 | % Licensed under the Simplified BSD License [see coco/license.txt] 27 | 28 | methods( Static ) 29 | function convertPascalGt( dataDir, year, split, annFile ) 30 | % Convert ground truth for PASCAL to COCO format. 31 | % 32 | % USAGE 33 | % CocoUtils.convertPascalGt( dataDir, year, split, annFile ) 34 | % 35 | % INPUTS 36 | % dataDir - dir containing VOCdevkit/ 37 | % year - dataset year (e.g. '2007') 38 | % split - dataset split (e.g. 'val') 39 | % annFile - annotation file for writing results 40 | if(exist(annFile,'file')), return; end 41 | fprintf('Converting PASCAL VOC dataset... 
'); clk=tic; 42 | dev=[dataDir '/VOCdevkit/']; addpath(genpath([dev '/VOCcode'])); 43 | VOCinit; C=VOCopts.classes'; catsMap=containers.Map(C,1:length(C)); 44 | f=fopen([dev '/VOC' year '/ImageSets/Main/' split '.txt']); 45 | is=textscan(f,'%s %*s'); is=is{1}; fclose(f); n=length(is); 46 | data=CocoUtils.initData(C,n); 47 | for i=1:n, nm=[is{i} '.jpg']; 48 | f=[dev '/VOC' year '/Annotations/' is{i} '.xml']; 49 | R=PASreadrecord(f); hw=R.imgsize([2 1]); O=R.objects; 50 | id=is{i}; id(id=='_')=[]; id=str2double(id); 51 | ignore=[O.difficult]; bbs=cat(1,O.bbox); 52 | t=catsMap.values({O.class}); catIds=[t{:}]; iscrowd=ignore*0; 53 | data=CocoUtils.addData(data,nm,id,hw,catIds,ignore,iscrowd,bbs); 54 | end 55 | f=fopen(annFile,'w'); fwrite(f,gason(data)); fclose(f); 56 | fprintf('DONE (t=%0.2fs).\n',toc(clk)); 57 | end 58 | 59 | function convertImageNetGt( dataDir, year, split, annFile ) 60 | % Convert ground truth for ImageNet to COCO format. 61 | % 62 | % USAGE 63 | % CocoUtils.convertImageNetGt( dataDir, year, split, annFile ) 64 | % 65 | % INPUTS 66 | % dataDir - dir containing ILSVRC*/ folders 67 | % year - dataset year (e.g. '2013') 68 | % split - dataset split (e.g. 'val') 69 | % annFile - annotation file for writing results 70 | if(exist(annFile,'file')), return; end 71 | fprintf('Converting ImageNet dataset... '); clk=tic; 72 | dev=[dataDir '/ILSVRC' year '_devkit/']; 73 | addpath(genpath([dev '/evaluation/'])); 74 | t=[dev '/data/meta_det.mat']; 75 | t=load(t); synsets=t.synsets(1:200); catNms={synsets.name}; 76 | catsMap=containers.Map({synsets.WNID},1:length(catNms)); 77 | if(~strcmp(split,'val')), blacklist=cell(1,2); else 78 | f=[dev '/data/' 'ILSVRC' year '_det_validation_blacklist.txt']; 79 | f=fopen(f); blacklist=textscan(f,'%d %s'); fclose(f); 80 | t=catsMap.values(blacklist{2}); blacklist{2}=[t{:}]; 81 | end 82 | if(strcmp(split,'train')) 83 | dl=@(i) [dev '/data/det_lists/' split '_pos_' int2str(i) '.txt']; 84 | is=cell(1,200); for i=1:200, f=fopen(dl(i)); 85 | is{i}=textscan(f,'%s %*s'); is{i}=is{i}{1}; fclose(f); end 86 | is=unique(cat(1,is{:})); n=length(is); 87 | else 88 | f=fopen([dev '/data/det_lists/' split '.txt']); 89 | is=textscan(f,'%s %*s'); is=is{1}; fclose(f); n=length(is); 90 | end 91 | data=CocoUtils.initData(catNms,n); 92 | for i=1:n 93 | f=[dataDir '/ILSVRC' year '_DET_bbox_' split '/' is{i} '.xml']; 94 | R=VOCreadxml(f); R=R.annotation; nm=[is{i} '.JPEG']; 95 | hw=str2double({R.size.height R.size.width}); 96 | if(~isfield(R,'object')), catIds=[]; bbs=[]; else 97 | O=R.object; t=catsMap.values({O.name}); catIds=[t{:}]; 98 | b=[O.bndbox]; bbs=str2double({b.xmin; b.ymin; b.xmax; b.ymax})'; 99 | end 100 | j=blacklist{2}(blacklist{1}==i); m=numel(j); b=[0 0 hw(2) hw(1)]; 101 | catIds=[j catIds]; bbs=[repmat(b,m,1); bbs]; %#ok 102 | ignore=ismember(catIds,j); iscrowd=ignore*0; iscrowd(1:m)=1; 103 | data=CocoUtils.addData(data,nm,i,hw,catIds,ignore,iscrowd,bbs); 104 | end 105 | f=fopen(annFile,'w'); fwrite(f,gason(data)); fclose(f); 106 | fprintf('DONE (t=%0.2fs).\n',toc(clk)); 107 | end 108 | 109 | function convertPascalDt( srcFiles, tarFile ) 110 | % Convert detections on PASCAL to COCO format. 
111 | % 112 | % USAGE 113 | % CocoUtils.convertPascalDt( srcFiles, tarFile ) 114 | % 115 | % INPUTS 116 | % srcFiles - source detection file(s) in PASCAL format 117 | % tarFile - target detection file in COCO format 118 | if(exist(tarFile,'file')), return; end; R=[]; 119 | for i=1:length(srcFiles), f=fopen(srcFiles{i},'r'); 120 | R1=textscan(f,'%d %f %f %f %f %f'); fclose(f); 121 | [~,~,x0,y0,x1,y1]=deal(R1{:}); b=[x0-1 y0-1 x1-x0+1 y1-y0+1]; 122 | b(:,3:4)=max(b(:,3:4),1); b=mat2cell(b,ones(1,size(b,1)),4); 123 | R=[R; struct('image_id',num2cell(R1{1}),'bbox',b,... 124 | 'category_id',i,'score',num2cell(R1{2}))]; %#ok 125 | end 126 | f=fopen(tarFile,'w'); fwrite(f,gason(R)); fclose(f); 127 | end 128 | 129 | function convertImageNetDt( srcFile, tarFile ) 130 | % Convert detections on ImageNet to COCO format. 131 | % 132 | % USAGE 133 | % CocoUtils.convertImageNetDt( srcFile, tarFile ) 134 | % 135 | % INPUTS 136 | % srcFile - source detection file in ImageNet format 137 | % tarFile - target detection file in COCO format 138 | if(exist(tarFile,'file')), return; end; f=fopen(srcFile,'r'); 139 | R=textscan(f,'%d %d %f %f %f %f %f'); fclose(f); 140 | [~,~,~,x0,y0,x1,y1]=deal(R{:}); b=[x0-1 y0-1 x1-x0+1 y1-y0+1]; 141 | b(:,3:4)=max(b(:,3:4),1); bbox=mat2cell(b,ones(1,size(b,1)),4); 142 | R=struct('image_id',num2cell(R{1}),'bbox',bbox,... 143 | 'category_id',num2cell(R{2}),'score',num2cell(R{3})); 144 | f=fopen(tarFile,'w'); fwrite(f,gason(R)); fclose(f); 145 | end 146 | 147 | function validateOnPascal( dataDir ) 148 | % Validate COCO eval code against PASCAL code. 149 | % 150 | % USAGE 151 | % CocoUtils.validateOnPascal( dataDir ) 152 | % 153 | % INPUTS 154 | % dataDir - dir containing VOCdevkit/ 155 | split='val'; year='2007'; thrs=0:.001:1; T=length(thrs); 156 | dev=[dataDir '/VOCdevkit/']; addpath(genpath([dev '/VOCcode/'])); 157 | d=pwd; cd(dev); VOCinit; cd(d); O=VOCopts; O.testset=split; 158 | O.detrespath=[O.detrespath(1:end-10) split '_%s.txt']; 159 | catNms=O.classes; K=length(catNms); ap=zeros(K,1); 160 | for i=1:K, [R,P]=VOCevaldet(O,'comp3',catNms{i},0); R1=[R; inf]; 161 | P1=[P; 0]; for t=1:T, ap(i)=ap(i)+max(P1(R1>=thrs(t)))/T; end; end 162 | srcFile=[dev '/results/VOC' year '/Main/comp3_det_' split]; 163 | resFile=[srcFile '.json']; annFile=[dev '/VOC2007/' split '.json']; 164 | sfs=cell(1,K); for i=1:K, sfs{i}=[srcFile '_' catNms{i} '.txt']; end 165 | CocoUtils.convertPascalGt(dataDir,year,split,annFile); 166 | CocoUtils.convertPascalDt(sfs,resFile); 167 | D=CocoApi(annFile); R=D.loadRes(resFile); E=CocoEval(D,R); 168 | p=E.params; p.recThrs=thrs; p.iouThrs=.5; p.areaRng=[0 inf]; 169 | p.useSegm=0; p.maxDets=inf; E.params=p; E.evaluate(); E.accumulate(); 170 | apCoco=squeeze(mean(E.eval.precision,2)); deltas=abs(apCoco-ap); 171 | fprintf('AP delta: mean=%.2e median=%.2e max=%.2e\n',... 172 | mean(deltas),median(deltas),max(deltas)) 173 | if(max(deltas)>1e-2), msg='FAILED'; else msg='PASSED'; end 174 | warning(['Eval code *' msg '* validation!']); 175 | end 176 | 177 | function validateOnImageNet( dataDir ) 178 | % Validate COCO eval code against ImageNet code. 179 | % 180 | % USAGE 181 | % CocoUtils.validateOnImageNet( dataDir ) 182 | % 183 | % INPUTS 184 | % dataDir - dir containing ILSVRC*/ folders 185 | warning(['Set pixelTolerance=0 in line 30 of eval_detection.m '... 
186 | '(and delete cache) otherwise AP will differ by >1e-4!']); 187 | year='2013'; dev=[dataDir '/ILSVRC' year '_devkit/']; 188 | fs = { [dev 'evaluation/demo.val.pred.det.txt'] 189 | [dataDir '/ILSVRC' year '_DET_bbox_val/'] 190 | [dev 'data/meta_det.mat'] 191 | [dev 'data/det_lists/val.txt'] 192 | [dev 'data/ILSVRC' year '_det_validation_blacklist.txt'] 193 | [dev 'data/ILSVRC' year '_det_validation_cache.mat'] }; 194 | addpath(genpath([dev 'evaluation/'])); 195 | ap=eval_detection(fs{:})'; 196 | resFile=[fs{1}(1:end-3) 'json']; 197 | annFile=[dev 'data/ILSVRC' year '_val.json']; 198 | CocoUtils.convertImageNetDt(fs{1},resFile); 199 | CocoUtils.convertImageNetGt(dataDir,year,'val',annFile) 200 | D=CocoApi(annFile); R=D.loadRes(resFile); E=CocoEval(D,R); 201 | p=E.params; p.recThrs=0:.0001:1; p.iouThrs=.5; p.areaRng=[0 inf]; 202 | p.useSegm=0; p.maxDets=inf; E.params=p; E.evaluate(); E.accumulate(); 203 | apCoco=squeeze(mean(E.eval.precision,2)); deltas=abs(apCoco-ap); 204 | fprintf('AP delta: mean=%.2e median=%.2e max=%.2e\n',... 205 | mean(deltas),median(deltas),max(deltas)) 206 | if(max(deltas)>1e-4), msg='FAILED'; else msg='PASSED'; end 207 | warning(['Eval code *' msg '* validation!']); 208 | end 209 | 210 | function generateFakeDt( coco, dtFile, varargin ) 211 | % Generate fake detections from ground truth. 212 | % 213 | % USAGE 214 | % CocoUtils.generateFakeDt( coco, dtFile, varargin ) 215 | % 216 | % INPUTS 217 | % coco - instance of CocoApi containing ground truth 218 | % dtFile - target file for writing detection results 219 | % params - parameters (struct or name/value pairs) 220 | % .n - [100] number images for which to generate dets 221 | % .fn - [.20] false negative rate (00; if(~any(v)), continue; end 251 | x=o(1:3:end); y=o(2:3:end); x(~v)=mean(x(v)); y(~v)=mean(y(v)); 252 | x=max(0,min(w-1,x+dx)); o(1:3:end)=x; o(2:3:end)=y; 253 | end 254 | k=k+1; R(k).image_id=imgIds(i); R(k).category_id=catId; 255 | R(k).(opts.type)=o; R(k).score=round(rand(rstream)*1000)/1000; 256 | end 257 | end 258 | R=R(1:k); f=fopen(dtFile,'w'); fwrite(f,gason(R)); fclose(f); 259 | fprintf('DONE (t=%0.2fs).\n',toc(clk)); 260 | end 261 | 262 | function validateMaskApi( coco ) 263 | % Validate MaskApi against Matlab functions. 264 | % 265 | % USAGE 266 | % CocoUtils.validateMaskApi( coco ) 267 | % 268 | % INPUTS 269 | % coco - instance of CocoApi containing ground truth 270 | S=coco.data.annotations; S=S(~[S.iscrowd]); S={S.segmentation}; 271 | h=1000; n=1000; Z=cell(1,n); A=Z; B=Z; M=Z; IB=zeros(1,n); 272 | fprintf('Running MaskApi implementations... '); clk=tic; 273 | for i=1:n, A{i}=MaskApi.frPoly(S{i},h,h); end 274 | Ia=MaskApi.iou(A{1},[A{:}]); 275 | fprintf('DONE (t=%0.2fs).\n',toc(clk)); 276 | fprintf('Running Matlab implementations... '); clk=tic; 277 | for i=1:n, M1=0; for j=1:length(S{i}), x=S{i}{j}+.5; 278 | M1=M1+poly2mask(x(1:2:end),x(2:2:end),h,h); end 279 | M{i}=uint8(M1>0); B{i}=MaskApi.encode(M{i}); 280 | IB(i)=sum(sum(M{1}&M{i}))/sum(sum(M{1}|M{i})); 281 | end 282 | fprintf('DONE (t=%0.2fs).\n',toc(clk)); 283 | if(isequal(A,B)&&isequal(Ia,IB)), 284 | msg='PASSED'; else msg='FAILED'; end 285 | warning(['MaskApi *' msg '* validation!']); 286 | end 287 | 288 | function gasonSplit( name, k ) 289 | % Split JSON file into multiple JSON files. 290 | % 291 | % Splits file 'name.json' into multiple files 'name-*.json'. Only 292 | % works for JSON arrays. Memory efficient. Inverted by gasonMerge(). 
293 | % 294 | % USAGE 295 | % CocoUtils.gasonSplit( name, k ) 296 | % 297 | % INPUTS 298 | % name - file containing JSON array (w/o '.json' ext) 299 | % k - number of files to split JSON into 300 | s=gasonMex('split',fileread([name '.json']),k); k=length(s); 301 | for i=1:k, f=fopen(sprintf('%s-%06i.json',name,i),'w'); 302 | fwrite(f,s{i}); fclose(f); end 303 | end 304 | 305 | function gasonMerge( name ) 306 | % Merge JSON files into single JSON file. 307 | % 308 | % Merge files 'name-*.json' into single file 'name.json'. Only works 309 | % for JSON arrays. Memory efficient. Inverted by gasonSplit(). 310 | % 311 | % USAGE 312 | % CocoUtils.gasonMerge( name ) 313 | % 314 | % INPUTS 315 | % name - files containing JSON arrays (w/o '.json' ext) 316 | s=dir([name '-*.json']); s=sort({s.name}); k=length(s); 317 | p=fileparts(name); for i=1:k, s{i}=fullfile(p,s{i}); end 318 | for i=1:k, s{i}=fileread(s{i}); end; s=gasonMex('merge',s); 319 | f=fopen([name '.json'],'w'); fwrite(f,s); fclose(f); 320 | end 321 | end 322 | 323 | methods( Static, Access=private ) 324 | function data = initData( catNms, n ) 325 | % Helper for convert() functions: init annotations. 326 | m=length(catNms); ms=num2cell(1:m); 327 | I = struct('file_name',0,'height',0,'width',0,'id',0); 328 | C = struct('supercategory','none','id',ms,'name',catNms); 329 | A = struct('segmentation',0,'area',0,'iscrowd',0,... 330 | 'image_id',0,'bbox',0,'category_id',0,'id',0,'ignore',0); 331 | I=repmat(I,1,n); A=repmat(A,1,n*20); 332 | data = struct('images',I,'type','instances',... 333 | 'annotations',A,'categories',C,'nImgs',0,'nAnns',0); 334 | end 335 | 336 | function data = addData( data,nm,id,hw,catIds,ignore,iscrowd,bbs ) 337 | % Helper for convert() functions: add annotations. 338 | data.nImgs=data.nImgs+1; 339 | data.images(data.nImgs)=struct('file_name',nm,... 340 | 'height',hw(1),'width',hw(2),'id',id); 341 | for j=1:length(catIds), data.nAnns=data.nAnns+1; k=data.nAnns; 342 | b=bbs(j,:); b=b-1; b(3:4)=b(3:4)-b(1:2)+1; 343 | x1=b(1); x2=b(1)+b(3); y1=b(2); y2=b(2)+b(4); 344 | S={{[x1 y1 x1 y2 x2 y2 x2 y1]}}; a=b(3)*b(4); 345 | data.annotations(k)=struct('segmentation',S,'area',a,... 346 | 'iscrowd',iscrowd(j),'image_id',id,'bbox',b,... 347 | 'category_id',catIds(j),'id',k,'ignore',ignore(j)); 348 | end 349 | if( data.nImgs == length(data.images) ) 350 | data.annotations=data.annotations(1:data.nAnns); 351 | data=rmfield(data,{'nImgs','nAnns'}); 352 | end 353 | end 354 | end 355 | 356 | end 357 | -------------------------------------------------------------------------------- /cocoapi/MatlabAPI/MaskApi.m: -------------------------------------------------------------------------------- 1 | classdef MaskApi 2 | % Interface for manipulating masks stored in RLE format. 3 | % 4 | % RLE is a simple yet efficient format for storing binary masks. RLE 5 | % first divides a vector (or vectorized image) into a series of piecewise 6 | % constant regions and then for each piece simply stores the length of 7 | % that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 8 | % be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 9 | % (note that the odd counts are always the numbers of zeros). Instead of 10 | % storing the counts directly, additional compression is achieved with a 11 | % variable bitrate representation based on a common scheme called LEB128. 12 | % 13 | % Compression is greatest given large piecewise constant regions. 
14 | % Specifically, the size of the RLE is proportional to the number of 15 | % *boundaries* in M (or for an image the number of boundaries in the y 16 | % direction). Assuming fairly simple shapes, the RLE representation is 17 | % O(sqrt(n)) where n is number of pixels in the object. Hence space usage 18 | % is substantially lower, especially for large simple objects (large n). 19 | % 20 | % Many common operations on masks can be computed directly using the RLE 21 | % (without need for decoding). This includes computations such as area, 22 | % union, intersection, etc. All of these operations are linear in the 23 | % size of the RLE, in other words they are O(sqrt(n)) where n is the area 24 | % of the object. Computing these operations on the original mask is O(n). 25 | % Thus, using the RLE can result in substantial computational savings. 26 | % 27 | % The following API functions are defined: 28 | % encode - Encode binary masks using RLE. 29 | % decode - Decode binary masks encoded via RLE. 30 | % merge - Compute union or intersection of encoded masks. 31 | % iou - Compute intersection over union between masks. 32 | % nms - Compute non-maximum suppression between ordered masks. 33 | % area - Compute area of encoded masks. 34 | % toBbox - Get bounding boxes surrounding encoded masks. 35 | % frBbox - Convert bounding boxes to encoded masks. 36 | % frPoly - Convert polygon to encoded mask. 37 | % 38 | % Usage: 39 | % Rs = MaskApi.encode( masks ) 40 | % masks = MaskApi.decode( Rs ) 41 | % R = MaskApi.merge( Rs, [intersect=false] ) 42 | % o = MaskApi.iou( dt, gt, [iscrowd=false] ) 43 | % keep = MaskApi.nms( dt, thr ) 44 | % a = MaskApi.area( Rs ) 45 | % bbs = MaskApi.toBbox( Rs ) 46 | % Rs = MaskApi.frBbox( bbs, h, w ) 47 | % R = MaskApi.frPoly( poly, h, w ) 48 | % 49 | % In the API the following formats are used: 50 | % R,Rs - [struct] Run-length encoding of binary mask(s) 51 | % masks - [hxwxn] Binary mask(s) (must have type uint8) 52 | % bbs - [nx4] Bounding box(es) stored as [x y w h] 53 | % poly - Polygon stored as {[x1 y1 x2 y2...],[x1 y1 ...],...} 54 | % dt,gt - May be either bounding boxes or encoded masks 55 | % Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 56 | % 57 | % Finally, a note about the intersection over union (iou) computation. 58 | % The standard iou of a ground truth (gt) and detected (dt) object is 59 | % iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 60 | % For "crowd" regions, we use a modified criteria. If a gt object is 61 | % marked as "iscrowd", we allow a dt to match any subregion of the gt. 62 | % Choosing gt' in the crowd gt that best matches the dt can be done using 63 | % gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 64 | % iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 65 | % For crowd gt regions we use this modified criteria above for the iou. 66 | % 67 | % To compile use the following (some precompiled binaries are included): 68 | % mex('CFLAGS=\$CFLAGS -Wall -std=c99','-largeArrayDims',... 69 | % 'private/maskApiMex.c','../common/maskApi.c',... 70 | % '-I../common/','-outdir','private'); 71 | % Please do not contact us for help with compiling. 72 | % 73 | % Microsoft COCO Toolbox. version 2.0 74 | % Data, paper, and tutorials available at: http://mscoco.org/ 75 | % Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
76 | % Licensed under the Simplified BSD License [see coco/license.txt]
77 | 
78 | methods( Static )
79 |   function Rs = encode( masks )
80 |     Rs = maskApiMex( 'encode', masks );
81 |   end
82 | 
83 |   function masks = decode( Rs )
84 |     masks = maskApiMex( 'decode', Rs );
85 |   end
86 | 
87 |   function R = merge( Rs, varargin )
88 |     R = maskApiMex( 'merge', Rs, varargin{:} );
89 |   end
90 | 
91 |   function o = iou( dt, gt, varargin )
92 |     o = maskApiMex( 'iou', dt', gt', varargin{:} );
93 |   end
94 | 
95 |   function keep = nms( dt, thr )
96 |     keep = maskApiMex('nms',dt',thr);
97 |   end
98 | 
99 |   function a = area( Rs )
100 |     a = maskApiMex( 'area', Rs );
101 |   end
102 | 
103 |   function bbs = toBbox( Rs )
104 |     bbs = maskApiMex( 'toBbox', Rs )';
105 |   end
106 | 
107 |   function Rs = frBbox( bbs, h, w )
108 |     Rs = maskApiMex( 'frBbox', bbs', h, w );
109 |   end
110 | 
111 |   function R = frPoly( poly, h, w )
112 |     R = maskApiMex( 'frPoly', poly, h, w );
113 |   end
114 | end
115 | 
116 | end
117 | 
--------------------------------------------------------------------------------
/cocoapi/MatlabAPI/cocoDemo.m:
--------------------------------------------------------------------------------
1 | %% Demo for the CocoApi (see CocoApi.m)
2 | 
3 | %% initialize COCO api (please specify dataType/annType below)
4 | annTypes = { 'instances', 'captions', 'person_keypoints' };
5 | dataType='val2014'; annType=annTypes{1}; % specify dataType/annType
6 | annFile=sprintf('../annotations/%s_%s.json',annType,dataType);
7 | coco=CocoApi(annFile);
8 | 
9 | %% display COCO categories and supercategories
10 | if( ~strcmp(annType,'captions') )
11 |   cats = coco.loadCats(coco.getCatIds());
12 |   nms={cats.name}; fprintf('COCO categories: ');
13 |   fprintf('%s, ',nms{:}); fprintf('\n');
14 |   nms=unique({cats.supercategory}); fprintf('COCO supercategories: ');
15 |   fprintf('%s, ',nms{:}); fprintf('\n');
16 | end
17 | 
18 | %% get all images containing given categories, select one at random
19 | catIds = coco.getCatIds('catNms',{'person','dog','skateboard'});
20 | imgIds = coco.getImgIds('catIds',catIds);
21 | imgId = imgIds(randi(length(imgIds)));
22 | 
23 | %% load and display image
24 | img = coco.loadImgs(imgId);
25 | I = imread(sprintf('../images/%s/%s',dataType,img.file_name));
26 | figure(1); imagesc(I); axis('image'); set(gca,'XTick',[],'YTick',[])
27 | 
28 | %% load and display annotations
29 | annIds = coco.getAnnIds('imgIds',imgId,'catIds',catIds,'iscrowd',[]);
30 | anns = coco.loadAnns(annIds); coco.showAnns(anns);
--------------------------------------------------------------------------------
/cocoapi/MatlabAPI/evalDemo.m:
--------------------------------------------------------------------------------
1 | %% Demo demonstrating the algorithm result formats for COCO
2 | 
3 | %% select results type for demo (bbox, segm, or keypoints)
4 | type = {'segm','bbox','keypoints'}; type = type{1}; % specify type here
5 | fprintf('Running demo for *%s* results.\n\n',type);
6 | 
7 | %% initialize COCO ground truth api
8 | dataDir='../'; prefix='instances'; dataType='val2014';
9 | if(strcmp(type,'keypoints')), prefix='person_keypoints'; end
10 | annFile=sprintf('%s/annotations/%s_%s.json',dataDir,prefix,dataType);
11 | cocoGt=CocoApi(annFile);
12 | 
13 | %% initialize COCO detections api
14 | resFile='%s/results/%s_%s_fake%s100_results.json';
15 | resFile=sprintf(resFile,dataDir,prefix,dataType,type);
16 | cocoDt=cocoGt.loadRes(resFile);
17 | 
18 | %% visualize gt and dt side by side
19 | imgIds=sort(cocoGt.getImgIds()); imgIds=imgIds(1:100);
20 | imgId = imgIds(randi(100)); img = cocoGt.loadImgs(imgId);
21 | I = imread(sprintf('%s/images/val2014/%s',dataDir,img.file_name));
22 | figure(1); subplot(1,2,1); imagesc(I); axis('image'); axis off;
23 | annIds = cocoGt.getAnnIds('imgIds',imgId); title('ground truth')
24 | anns = cocoGt.loadAnns(annIds); cocoGt.showAnns(anns);
25 | figure(1); subplot(1,2,2); imagesc(I); axis('image'); axis off;
26 | annIds = cocoDt.getAnnIds('imgIds',imgId); title('results')
27 | anns = cocoDt.loadAnns(annIds); cocoDt.showAnns(anns);
28 | 
29 | %% load raw JSON and show exact format for results
30 | fprintf('results structure has the following format:\n');
31 | res = gason(fileread(resFile)); disp(res)
32 | 
33 | %% the following command can be used to save the results back to disk
34 | if(0), f=fopen(resFile,'w'); fwrite(f,gason(res)); fclose(f); end
35 | 
36 | %% run COCO evaluation code (see CocoEval.m)
37 | cocoEval=CocoEval(cocoGt,cocoDt,type);
38 | cocoEval.params.imgIds=imgIds;
39 | cocoEval.evaluate();
40 | cocoEval.accumulate();
41 | cocoEval.summarize();
42 | 
43 | %% generate Derek Hoiem style analysis of false positives (slow)
44 | if(0), cocoEval.analyze(); end
45 | 
--------------------------------------------------------------------------------
/cocoapi/MatlabAPI/gason.m:
--------------------------------------------------------------------------------
1 | function out = gason( in )
2 | % Convert between JSON strings and corresponding JSON objects.
3 | %
4 | % This parser is based on Gason written and maintained by Ivan Vashchaev:
5 | %   https://github.com/vivkin/gason
6 | % Gason is a "lightweight and fast JSON parser for C++". Please see the
7 | % above link for license information and additional details about Gason.
8 | %
9 | % Given a JSON string, gason calls the C++ parser and converts the output
10 | % into an appropriate Matlab structure. As the parsing is performed in mex
11 | % the resulting parser is blazingly fast. Large JSON structs (100MB+) take
12 | % only a few seconds to parse (compared to hours for pure Matlab parsers).
13 | %
14 | % Given a JSON object, gason calls the C++ encoder to convert the object
15 | % back into a JSON string representation. Nearly any Matlab struct, cell
16 | % array, or numeric array represents a valid JSON object. Note that gason()
17 | % can be used to go both from JSON string to JSON object and back.
18 | %
19 | % Gason requires C++11 to compile (for GCC this requires version 4.7 or
20 | % later). The following command compiles the parser (may require tweaking):
21 | %   mex('CXXFLAGS=\$CXXFLAGS -std=c++11 -Wall','-largeArrayDims',...
22 | %     'private/gasonMex.cpp','../common/gason.cpp',...
23 | %     '-I../common/','-outdir','private');
24 | % Note the use of the "-std=c++11" flag. A number of precompiled binaries
25 | % are included, please do not contact us for help with compiling. If needed
26 | % you can specify a compiler by adding the option 'CXX="/usr/bin/g++"'.
27 | %
28 | % Note that by default JSON arrays that contain only numbers are stored as
29 | % regular Matlab arrays. Likewise, JSON arrays that contain only objects of
30 | % the same type are stored as Matlab struct arrays. This is much faster and
31 | % can use considerably less memory than always using Matlab cell arrays.
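The gasonMex.cpp source later in this dump implements this "regular object array" detection in C++ (isRegularObjArray); the same idea reduces to the following Python sketch (a hypothetical helper operating on an already-parsed JSON array; it relies on dicts preserving insertion order):

def is_regular_obj_array(arr):
    # a JSON array of objects can be stored as a struct array only if every
    # element is an object with the same keys in the same order
    if not arr or not all(isinstance(o, dict) for o in arr):
        return False
    keys = list(arr[0])
    return all(list(o) == keys for o in arr)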
32 | % 33 | % USAGE 34 | % object = gason( string ) 35 | % string = gason( object ) 36 | % 37 | % INPUTS/OUTPUTS 38 | % string - JSON string 39 | % object - JSON object 40 | % 41 | % EXAMPLE 42 | % o = struct('first',{'piotr','ty'},'last',{'dollar','lin'}) 43 | % s = gason( o ) % convert JSON object -> JSON string 44 | % p = gason( s ) % convert JSON string -> JSON object 45 | % 46 | % See also 47 | % 48 | % Microsoft COCO Toolbox. version 2.0 49 | % Data, paper, and tutorials available at: http://mscoco.org/ 50 | % Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 51 | % Licensed under the Simplified BSD License [see coco/license.txt] 52 | 53 | out = gasonMex( 'convert', in ); 54 | -------------------------------------------------------------------------------- /cocoapi/MatlabAPI/private/gasonMex.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #include "gason.h" 8 | #include "mex.h" 9 | #include "string.h" 10 | #include "math.h" 11 | #include 12 | #include 13 | #include 14 | typedef std::ostringstream ostrm; 15 | typedef unsigned long siz; 16 | typedef unsigned short ushort; 17 | 18 | siz length( const JsonValue &a ) { 19 | // get number of elements in JSON_ARRAY or JSON_OBJECT 20 | siz k=0; auto n=a.toNode(); while(n) { k++; n=n->next; } return k; 21 | } 22 | 23 | bool isRegularObjArray( const JsonValue &a ) { 24 | // check if all JSON_OBJECTs in JSON_ARRAY have the same fields 25 | JsonValue o=a.toNode()->value; siz k, n; const char **keys; 26 | n=length(o); keys=new const char*[n]; 27 | k=0; for(auto j:o) keys[k++]=j->key; 28 | for( auto i:a ) { 29 | if(length(i->value)!=n) return false; k=0; 30 | for(auto j:i->value) if(strcmp(j->key,keys[k++])) return false; 31 | } 32 | delete [] keys; return true; 33 | } 34 | 35 | mxArray* json( const JsonValue &o ) { 36 | // convert JsonValue to Matlab mxArray 37 | siz k, m, n; mxArray *M; const char **keys; 38 | switch( o.getTag() ) { 39 | case JSON_NUMBER: 40 | return mxCreateDoubleScalar(o.toNumber()); 41 | case JSON_STRING: 42 | return mxCreateString(o.toString()); 43 | case JSON_ARRAY: { 44 | if(!o.toNode()) return mxCreateDoubleMatrix(1,0,mxREAL); 45 | JsonValue o0=o.toNode()->value; JsonTag tag=o0.getTag(); 46 | n=length(o); bool isRegular=true; 47 | for(auto i:o) isRegular=isRegular && i->value.getTag()==tag; 48 | if( isRegular && tag==JSON_OBJECT && isRegularObjArray(o) ) { 49 | m=length(o0); keys=new const char*[m]; 50 | k=0; for(auto j:o0) keys[k++]=j->key; 51 | M = mxCreateStructMatrix(1,n,m,keys); 52 | k=0; for(auto i:o) { m=0; for(auto j:i->value) 53 | mxSetFieldByNumber(M,k,m++,json(j->value)); k++; } 54 | delete [] keys; return M; 55 | } else if( isRegular && tag==JSON_NUMBER ) { 56 | M = mxCreateDoubleMatrix(1,n,mxREAL); double *p=mxGetPr(M); 57 | k=0; for(auto i:o) p[k++]=i->value.toNumber(); return M; 58 | } else { 59 | M = mxCreateCellMatrix(1,n); 60 | k=0; for(auto i:o) mxSetCell(M,k++,json(i->value)); 61 | return M; 62 | } 63 | } 64 | case JSON_OBJECT: 65 | if(!o.toNode()) return mxCreateStructMatrix(1,0,0,NULL); 66 | n=length(o); keys=new const char*[n]; 67 | k=0; for(auto i:o) 
keys[k++]=i->key; 68 | M = mxCreateStructMatrix(1,1,n,keys); k=0; 69 | for(auto i:o) mxSetFieldByNumber(M,0,k++,json(i->value)); 70 | delete [] keys; return M; 71 | case JSON_TRUE: 72 | return mxCreateDoubleScalar(1); 73 | case JSON_FALSE: 74 | return mxCreateDoubleScalar(0); 75 | case JSON_NULL: 76 | return mxCreateDoubleMatrix(0,0,mxREAL); 77 | default: return NULL; 78 | } 79 | } 80 | 81 | template ostrm& json( ostrm &S, T *A, siz n ) { 82 | // convert numeric array to JSON string with casting 83 | if(n==0) { S<<"[]"; return S; } if(n==1) { S< ostrm& json( ostrm &S, T *A, siz n ) { 89 | // convert numeric array to JSON string without casting 90 | return json(S,A,n); 91 | } 92 | 93 | ostrm& json( ostrm &S, const char *A ) { 94 | // convert char array to JSON string (handle escape characters) 95 | #define RPL(a,b) case a: { S << b; A++; break; } 96 | S << "\""; while( *A>0 ) switch( *A ) { 97 | RPL('"',"\\\""); RPL('\\',"\\\\"); RPL('/',"\\/"); RPL('\b',"\\b"); 98 | RPL('\f',"\\f"); RPL('\n',"\\n"); RPL('\r',"\\r"); RPL('\t',"\\t"); 99 | default: S << *A; A++; 100 | } 101 | S << "\""; return S; 102 | } 103 | 104 | ostrm& json( ostrm& S, const JsonValue *o ) { 105 | // convert JsonValue to JSON string 106 | switch( o->getTag() ) { 107 | case JSON_NUMBER: S << o->toNumber(); return S; 108 | case JSON_TRUE: S << "true"; return S; 109 | case JSON_FALSE: S << "false"; return S; 110 | case JSON_NULL: S << "null"; return S; 111 | case JSON_STRING: return json(S,o->toString()); 112 | case JSON_ARRAY: 113 | S << "["; for(auto i:*o) { 114 | json(S,&i->value) << (i->next ? "," : ""); } 115 | S << "]"; return S; 116 | case JSON_OBJECT: 117 | S << "{"; for(auto i:*o) { 118 | json(S,i->key) << ":"; 119 | json(S,&i->value) << (i->next ? "," : ""); } 120 | S << "}"; return S; 121 | default: return S; 122 | } 123 | } 124 | 125 | ostrm& json( ostrm& S, const mxArray *M ) { 126 | // convert Matlab mxArray to JSON string 127 | siz i, j, m, n=mxGetNumberOfElements(M); 128 | void *A=mxGetData(M); ostrm *nms; 129 | switch( mxGetClassID(M) ) { 130 | case mxDOUBLE_CLASS: return json(S,(double*) A,n); 131 | case mxSINGLE_CLASS: return json(S,(float*) A,n); 132 | case mxINT64_CLASS: return json(S,(int64_t*) A,n); 133 | case mxUINT64_CLASS: return json(S,(uint64_t*) A,n); 134 | case mxINT32_CLASS: return json(S,(int32_t*) A,n); 135 | case mxUINT32_CLASS: return json(S,(uint32_t*) A,n); 136 | case mxINT16_CLASS: return json(S,(int16_t*) A,n); 137 | case mxUINT16_CLASS: return json(S,(uint16_t*) A,n); 138 | case mxINT8_CLASS: return json(S,(int8_t*) A,n); 139 | case mxUINT8_CLASS: return json(S,(uint8_t*) A,n); 140 | case mxLOGICAL_CLASS: return json(S,(uint8_t*) A,n); 141 | case mxCHAR_CLASS: return json(S,mxArrayToString(M)); 142 | case mxCELL_CLASS: 143 | S << "["; for(i=0; i0) json(S,mxGetCell(M,n-1)); S << "]"; return S; 145 | case mxSTRUCT_CLASS: 146 | if(n==0) { S<<"{}"; return S; } m=mxGetNumberOfFields(M); 147 | if(m==0) { S<<"["; for(i=0; i1) S<<"["; nms=new ostrm[m]; 149 | for(j=0; j1) S<<"]"; delete [] nms; return S; 156 | default: 157 | mexErrMsgTxt( "Unknown type." 
); return S; 158 | } 159 | } 160 | 161 | mxArray* mxCreateStringRobust( const char* str ) { 162 | // convert char* to Matlab string (robust version of mxCreateString) 163 | mxArray *M; ushort *c; mwSize n[2]={1,strlen(str)}; 164 | M=mxCreateCharArray(2,n); c=(ushort*) mxGetData(M); 165 | for( siz i=0; i1 ) mexErrMsgTxt("One output expected."); 182 | 183 | if(!strcmp(action,"convert")) { 184 | if( nr!=1 ) mexErrMsgTxt("One input expected."); 185 | if( mxGetClassID(pr[0])==mxCHAR_CLASS ) { 186 | // object = mexFunction( string ) 187 | char *str = mxArrayToStringRobust(pr[0]); 188 | int status = jsonParse(str, &endptr, &val, allocator); 189 | if( status != JSON_OK) mexErrMsgTxt(jsonStrError(status)); 190 | pl[0] = json(val); mxFree(str); 191 | } else { 192 | // string = mexFunction( object ) 193 | ostrm S; S << std::setprecision(12); json(S,pr[0]); 194 | pl[0]=mxCreateStringRobust(S.str().c_str()); 195 | } 196 | 197 | } else if(!strcmp(action,"split")) { 198 | // strings = mexFunction( string, k ) 199 | if( nr!=2 ) mexErrMsgTxt("Two input expected."); 200 | char *str = mxArrayToStringRobust(pr[0]); 201 | int status = jsonParse(str, &endptr, &val, allocator); 202 | if( status != JSON_OK) mexErrMsgTxt(jsonStrError(status)); 203 | if( val.getTag()!=JSON_ARRAY ) mexErrMsgTxt("Array expected"); 204 | siz i=0, t=0, n=length(val), k=(siz) mxGetScalar(pr[1]); 205 | k=(k>n)?n:(k<1)?1:k; k=ceil(n/ceil(double(n)/k)); 206 | pl[0]=mxCreateCellMatrix(1,k); ostrm S; S<value); t--; if(!o->next) t=0; S << (t ? "," : "]"); 210 | if(!t) mxSetCell(pl[0],i++,mxCreateStringRobust(S.str().c_str())); 211 | } 212 | 213 | } else if(!strcmp(action,"merge")) { 214 | // string = mexFunction( strings ) 215 | if( nr!=1 ) mexErrMsgTxt("One input expected."); 216 | if(!mxIsCell(pr[0])) mexErrMsgTxt("Cell array expected."); 217 | siz n = mxGetNumberOfElements(pr[0]); 218 | ostrm S; S << std::setprecision(12); S << "["; 219 | for( siz i=0; ivalue) << (j->next ? "," : ""); 225 | mxFree(str); if(i1) 14 | % [ param1 ... paramN ] = getPrmDflt( prm, dfs, [checkExtra] ) 15 | % 16 | % INPUTS 17 | % prm - param struct or cell of form {'name1' v1 'name2' v2 ...} 18 | % dfs - cell of form {'name1' def1 'name2' def2 ...} 19 | % checkExtra - [0] if 1 throw error if prm contains params not in dfs 20 | % if -1 if prm contains params not in dfs adds them 21 | % 22 | % OUTPUTS (nargout==1) 23 | % prm - parameter struct with fields 'name1' through 'nameN' assigned 24 | % 25 | % OUTPUTS (nargout>1) 26 | % param1 - value assigned to parameter with 'name1' 27 | % ... 28 | % paramN - value assigned to parameter with 'nameN' 29 | % 30 | % EXAMPLE 31 | % dfs = { 'x','REQ', 'y',0, 'z',[], 'eps',1e-3 }; 32 | % prm = getPrmDflt( struct('x',1,'y',1), dfs ) 33 | % [ x y z eps ] = getPrmDflt( {'x',2,'y',1}, dfs ) 34 | % 35 | % See also INPUTPARSER 36 | % 37 | % Piotr's Computer Vision Matlab Toolbox Version 2.60 38 | % Copyright 2014 Piotr Dollar. 
[pdollar-at-gmail.com] 39 | % Licensed under the Simplified BSD License [see external/bsd.txt] 40 | 41 | if( mod(length(dfs),2) ), error('odd number of default parameters'); end 42 | if nargin<=2, checkExtra = 0; end 43 | 44 | % get the input parameters as two cell arrays: prmVal and prmField 45 | if iscell(prm) && length(prm)==1, prm=prm{1}; end 46 | if iscell(prm) 47 | if(mod(length(prm),2)), error('odd number of parameters in prm'); end 48 | prmField = prm(1:2:end); prmVal = prm(2:2:end); 49 | else 50 | if(~isstruct(prm)), error('prm must be a struct or a cell'); end 51 | prmVal = struct2cell(prm); prmField = fieldnames(prm); 52 | end 53 | 54 | % get and update default values using quick for loop 55 | dfsField = dfs(1:2:end); dfsVal = dfs(2:2:end); 56 | if checkExtra>0 57 | for i=1:length(prmField) 58 | j = find(strcmp(prmField{i},dfsField)); 59 | if isempty(j), error('parameter %s is not valid', prmField{i}); end 60 | dfsVal(j) = prmVal(i); 61 | end 62 | elseif checkExtra<0 63 | for i=1:length(prmField) 64 | j = find(strcmp(prmField{i},dfsField)); 65 | if isempty(j), j=length(dfsVal)+1; dfsField{j}=prmField{i}; end 66 | dfsVal(j) = prmVal(i); 67 | end 68 | else 69 | for i=1:length(prmField) 70 | dfsVal(strcmp(prmField{i},dfsField)) = prmVal(i); 71 | end 72 | end 73 | 74 | % check for missing values 75 | if any(strcmp('REQ',dfsVal)) 76 | cmpArray = find(strcmp('REQ',dfsVal)); 77 | error(['Required field ''' dfsField{cmpArray(1)} ''' not specified.'] ); 78 | end 79 | 80 | % set output 81 | if nargout==1 82 | varargout{1} = cell2struct( dfsVal, dfsField, 2 ); 83 | else 84 | varargout = dfsVal; 85 | end 86 | -------------------------------------------------------------------------------- /cocoapi/MatlabAPI/private/maskApiMex.c: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #include "mex.h" 8 | #include "maskApi.h" 9 | #include 10 | 11 | void checkType( const mxArray *M, mxClassID id ) { 12 | if(mxGetClassID(M)!=id) mexErrMsgTxt("Invalid type."); 13 | } 14 | 15 | mxArray* toMxArray( const RLE *R, siz n ) { 16 | const char *fs[] = {"size", "counts"}; 17 | mxArray *M=mxCreateStructMatrix(1,n,2,fs); 18 | for( siz i=0; i1) mexErrMsgTxt(err); 35 | for( i=0; i<*n; i++ ) { 36 | mxArray *S, *C; double *s; void *c; 37 | S=mxGetFieldByNumber(M,i,O[0]); checkType(S,mxDOUBLE_CLASS); 38 | C=mxGetFieldByNumber(M,i,O[1]); s=mxGetPr(S); c=mxGetData(C); 39 | h=(siz)s[0]; w=(siz)s[1]; m=mxGetNumberOfElements(C); 40 | if(same && i>0 && (h!=R[0].h || w!=R[0].w)) mexErrMsgTxt(err); 41 | if( mxGetClassID(C)==mxDOUBLE_CLASS ) { 42 | rleInit(R+i,h,w,m,0); 43 | for(j=0; j=2) ? (mxGetScalar(pr[1])>0) : false; 74 | rleMerge(R,&M,n,intersect); pl[0]=toMxArray(&M,1); rleFree(&M); 75 | 76 | } else if(!strcmp(action,"area")) { 77 | R=frMxArray(pr[0],&n,0); 78 | pl[0]=mxCreateNumericMatrix(1,n,mxUINT32_CLASS,mxREAL); 79 | uint *a=(uint*) mxGetPr(pl[0]); rleArea(R,n,a); 80 | 81 | } else if(!strcmp(action,"iou")) { 82 | if(nr>2) checkType(pr[2],mxUINT8_CLASS); siz nDt, nGt; 83 | byte *iscrowd = nr>2 ? 
(byte*) mxGetPr(pr[2]) : NULL; 84 | if(mxIsStruct(pr[0]) || mxIsStruct(pr[1])) { 85 | RLE *dt=frMxArray(pr[0],&nDt,1), *gt=frMxArray(pr[1],&nGt,1); 86 | pl[0]=mxCreateNumericMatrix(nDt,nGt,mxDOUBLE_CLASS,mxREAL); 87 | double *o=mxGetPr(pl[0]); rleIou(dt,gt,nDt,nGt,iscrowd,o); 88 | rlesFree(&dt,nDt); rlesFree(>,nGt); 89 | } else { 90 | checkType(pr[0],mxDOUBLE_CLASS); checkType(pr[1],mxDOUBLE_CLASS); 91 | double *dt=mxGetPr(pr[0]); nDt=mxGetN(pr[0]); 92 | double *gt=mxGetPr(pr[1]); nGt=mxGetN(pr[1]); 93 | pl[0]=mxCreateNumericMatrix(nDt,nGt,mxDOUBLE_CLASS,mxREAL); 94 | double *o=mxGetPr(pl[0]); bbIou(dt,gt,nDt,nGt,iscrowd,o); 95 | } 96 | 97 | } else if(!strcmp(action,"nms")) { 98 | siz n; uint *keep; double thr=(double) mxGetScalar(pr[1]); 99 | if(mxIsStruct(pr[0])) { 100 | RLE *dt=frMxArray(pr[0],&n,1); 101 | pl[0]=mxCreateNumericMatrix(1,n,mxUINT32_CLASS,mxREAL); 102 | keep=(uint*) mxGetPr(pl[0]); rleNms(dt,n,keep,thr); 103 | rlesFree(&dt,n); 104 | } else { 105 | checkType(pr[0],mxDOUBLE_CLASS); 106 | double *dt=mxGetPr(pr[0]); n=mxGetN(pr[0]); 107 | pl[0]=mxCreateNumericMatrix(1,n,mxUINT32_CLASS,mxREAL); 108 | keep=(uint*) mxGetPr(pl[0]); bbNms(dt,n,keep,thr); 109 | } 110 | 111 | } else if(!strcmp(action,"toBbox")) { 112 | R=frMxArray(pr[0],&n,0); 113 | pl[0]=mxCreateNumericMatrix(4,n,mxDOUBLE_CLASS,mxREAL); 114 | BB bb=mxGetPr(pl[0]); rleToBbox(R,bb,n); 115 | 116 | } else if(!strcmp(action,"frBbox")) { 117 | checkType(pr[0],mxDOUBLE_CLASS); 118 | double *bb=mxGetPr(pr[0]); n=mxGetN(pr[0]); 119 | h=(siz)mxGetScalar(pr[1]); w=(siz)mxGetScalar(pr[2]); 120 | rlesInit(&R,n); rleFrBbox(R,bb,h,w,n); pl[0]=toMxArray(R,n); 121 | 122 | } else if(!strcmp(action,"frPoly")) { 123 | checkType(pr[0],mxCELL_CLASS); n=mxGetNumberOfElements(pr[0]); 124 | h=(siz)mxGetScalar(pr[1]); w=(siz)mxGetScalar(pr[2]); rlesInit(&R,n); 125 | for(siz i=0; i malloc(h*w*n* sizeof(byte)) 85 | self._h = h 86 | self._w = w 87 | self._n = n 88 | # def __dealloc__(self): 89 | # the memory management of _mask has been passed to np.ndarray 90 | # it doesn't need to be freed here 91 | 92 | # called when passing into np.array() and return an np.ndarray in column-major order 93 | def __array__(self): 94 | cdef np.npy_intp shape[1] 95 | shape[0] = self._h*self._w*self._n 96 | # Create a 1D array, and reshape it to fortran/Matlab column-major array 97 | ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT8, self._mask).reshape((self._h, self._w, self._n), order='F') 98 | # The _mask allocated by Masks is now handled by ndarray 99 | PyArray_ENABLEFLAGS(ndarray, np.NPY_OWNDATA) 100 | return ndarray 101 | 102 | # internal conversion from Python RLEs object to compressed RLE format 103 | def _toString(RLEs Rs): 104 | cdef siz n = Rs.n 105 | cdef bytes py_string 106 | cdef char* c_string 107 | objs = [] 108 | for i in range(n): 109 | c_string = rleToString( &Rs._R[i] ) 110 | py_string = c_string 111 | objs.append({ 112 | 'size': [Rs._R[i].h, Rs._R[i].w], 113 | 'counts': py_string 114 | }) 115 | free(c_string) 116 | return objs 117 | 118 | # internal conversion from compressed RLE format to Python RLEs object 119 | def _frString(rleObjs): 120 | cdef siz n = len(rleObjs) 121 | Rs = RLEs(n) 122 | cdef bytes py_string 123 | cdef char* c_string 124 | for i, obj in enumerate(rleObjs): 125 | if PYTHON_VERSION == 2: 126 | py_string = str(obj['counts']).encode('utf8') 127 | elif PYTHON_VERSION == 3: 128 | py_string = str.encode(obj['counts']) if type(obj['counts']) == str else obj['counts'] 129 | else: 130 | raise Exception('Python 
version must be 2 or 3') 131 | c_string = py_string 132 | rleFrString( &Rs._R[i], c_string, obj['size'][0], obj['size'][1] ) 133 | return Rs 134 | 135 | # encode mask to RLEs objects 136 | # list of RLE string can be generated by RLEs member function 137 | def encode(np.ndarray[np.uint8_t, ndim=3, mode='fortran'] mask): 138 | h, w, n = mask.shape[0], mask.shape[1], mask.shape[2] 139 | cdef RLEs Rs = RLEs(n) 140 | rleEncode(Rs._R,mask.data,h,w,n) 141 | objs = _toString(Rs) 142 | return objs 143 | 144 | # decode mask from compressed list of RLE string or RLEs object 145 | def decode(rleObjs): 146 | cdef RLEs Rs = _frString(rleObjs) 147 | h, w, n = Rs._R[0].h, Rs._R[0].w, Rs._n 148 | masks = Masks(h, w, n) 149 | rleDecode(Rs._R, masks._mask, n); 150 | return np.array(masks) 151 | 152 | def merge(rleObjs, intersect=0): 153 | cdef RLEs Rs = _frString(rleObjs) 154 | cdef RLEs R = RLEs(1) 155 | rleMerge(Rs._R, R._R, Rs._n, intersect) 156 | obj = _toString(R)[0] 157 | return obj 158 | 159 | def area(rleObjs): 160 | cdef RLEs Rs = _frString(rleObjs) 161 | cdef uint* _a = malloc(Rs._n* sizeof(uint)) 162 | rleArea(Rs._R, Rs._n, _a) 163 | cdef np.npy_intp shape[1] 164 | shape[0] = Rs._n 165 | a = np.array((Rs._n, ), dtype=np.uint8) 166 | a = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT32, _a) 167 | PyArray_ENABLEFLAGS(a, np.NPY_OWNDATA) 168 | return a 169 | 170 | # iou computation. support function overload (RLEs-RLEs and bbox-bbox). 171 | def iou( dt, gt, pyiscrowd ): 172 | def _preproc(objs): 173 | if len(objs) == 0: 174 | return objs 175 | if type(objs) == np.ndarray: 176 | if len(objs.shape) == 1: 177 | objs = objs.reshape((objs[0], 1)) 178 | # check if it's Nx4 bbox 179 | if not len(objs.shape) == 2 or not objs.shape[1] == 4: 180 | raise Exception('numpy ndarray input is only for *bounding boxes* and should have Nx4 dimension') 181 | objs = objs.astype(np.double) 182 | elif type(objs) == list: 183 | # check if list is in box format and convert it to np.ndarray 184 | isbox = np.all(np.array([(len(obj)==4) and ((type(obj)==list) or (type(obj)==np.ndarray)) for obj in objs])) 185 | isrle = np.all(np.array([type(obj) == dict for obj in objs])) 186 | if isbox: 187 | objs = np.array(objs, dtype=np.double) 188 | if len(objs.shape) == 1: 189 | objs = objs.reshape((1,objs.shape[0])) 190 | elif isrle: 191 | objs = _frString(objs) 192 | else: 193 | raise Exception('list input can be bounding box (Nx4) or RLEs ([RLE])') 194 | else: 195 | raise Exception('unrecognized type. 
The following type: RLEs (rle), np.ndarray (box), and list (box) are supported.') 196 | return objs 197 | def _rleIou(RLEs dt, RLEs gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 198 | rleIou( dt._R, gt._R, m, n, iscrowd.data, _iou.data ) 199 | def _bbIou(np.ndarray[np.double_t, ndim=2] dt, np.ndarray[np.double_t, ndim=2] gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 200 | bbIou( dt.data, gt.data, m, n, iscrowd.data, _iou.data ) 201 | def _len(obj): 202 | cdef siz N = 0 203 | if type(obj) == RLEs: 204 | N = obj.n 205 | elif len(obj)==0: 206 | pass 207 | elif type(obj) == np.ndarray: 208 | N = obj.shape[0] 209 | return N 210 | # convert iscrowd to numpy array 211 | cdef np.ndarray[np.uint8_t, ndim=1] iscrowd = np.array(pyiscrowd, dtype=np.uint8) 212 | # simple type checking 213 | cdef siz m, n 214 | dt = _preproc(dt) 215 | gt = _preproc(gt) 216 | m = _len(dt) 217 | n = _len(gt) 218 | if m == 0 or n == 0: 219 | return [] 220 | if not type(dt) == type(gt): 221 | raise Exception('The dt and gt should have the same data type, either RLEs, list or np.ndarray') 222 | 223 | # define local variables 224 | cdef double* _iou = 0 225 | cdef np.npy_intp shape[1] 226 | # check type and assign iou function 227 | if type(dt) == RLEs: 228 | _iouFun = _rleIou 229 | elif type(dt) == np.ndarray: 230 | _iouFun = _bbIou 231 | else: 232 | raise Exception('input data type not allowed.') 233 | _iou = malloc(m*n* sizeof(double)) 234 | iou = np.zeros((m*n, ), dtype=np.double) 235 | shape[0] = m*n 236 | iou = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _iou) 237 | PyArray_ENABLEFLAGS(iou, np.NPY_OWNDATA) 238 | _iouFun(dt, gt, iscrowd, m, n, iou) 239 | return iou.reshape((m,n), order='F') 240 | 241 | def toBbox( rleObjs ): 242 | cdef RLEs Rs = _frString(rleObjs) 243 | cdef siz n = Rs.n 244 | cdef BB _bb = malloc(4*n* sizeof(double)) 245 | rleToBbox( Rs._R, _bb, n ) 246 | cdef np.npy_intp shape[1] 247 | shape[0] = 4*n 248 | bb = np.array((1,4*n), dtype=np.double) 249 | bb = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _bb).reshape((n, 4)) 250 | PyArray_ENABLEFLAGS(bb, np.NPY_OWNDATA) 251 | return bb 252 | 253 | def frBbox(np.ndarray[np.double_t, ndim=2] bb, siz h, siz w ): 254 | cdef siz n = bb.shape[0] 255 | Rs = RLEs(n) 256 | rleFrBbox( Rs._R, bb.data, h, w, n ) 257 | objs = _toString(Rs) 258 | return objs 259 | 260 | def frPoly( poly, siz h, siz w ): 261 | cdef np.ndarray[np.double_t, ndim=1] np_poly 262 | n = len(poly) 263 | Rs = RLEs(n) 264 | for i, p in enumerate(poly): 265 | np_poly = np.array(p, dtype=np.double, order='F') 266 | rleFrPoly( &Rs._R[i], np_poly.data, int(len(p)/2), h, w ) 267 | objs = _toString(Rs) 268 | return objs 269 | 270 | def frUncompressedRLE(ucRles, siz h, siz w): 271 | cdef np.ndarray[np.uint32_t, ndim=1] cnts 272 | cdef RLE R 273 | cdef uint *data 274 | n = len(ucRles) 275 | objs = [] 276 | for i in range(n): 277 | Rs = RLEs(1) 278 | cnts = np.array(ucRles[i]['counts'], dtype=np.uint32) 279 | # time for malloc can be saved here but it's fine 280 | data = malloc(len(cnts)* sizeof(uint)) 281 | for j in range(len(cnts)): 282 | data[j] = cnts[j] 283 | R = RLE(ucRles[i]['size'][0], ucRles[i]['size'][1], len(cnts), data) 284 | Rs._R[0] = R 285 | objs.append(_toString(Rs)[0]) 286 | return objs 287 | 288 | def frPyObjects(pyobj, h, w): 289 | # encode rle from a list of python objects 290 | if type(pyobj) == np.ndarray: 291 | objs = frBbox(pyobj, h, w) 292 | elif type(pyobj) == 
list and len(pyobj[0]) == 4:
293 |         objs = frBbox(pyobj, h, w)
294 |     elif type(pyobj) == list and len(pyobj[0]) > 4:
295 |         objs = frPoly(pyobj, h, w)
296 |     elif type(pyobj) == list and type(pyobj[0]) == dict \
297 |             and 'counts' in pyobj[0] and 'size' in pyobj[0]:
298 |         objs = frUncompressedRLE(pyobj, h, w)
299 |     # encode rle from single python object
300 |     elif type(pyobj) == list and len(pyobj) == 4:
301 |         objs = frBbox([pyobj], h, w)[0]
302 |     elif type(pyobj) == list and len(pyobj) > 4:
303 |         objs = frPoly([pyobj], h, w)[0]
304 |     elif type(pyobj) == dict and 'counts' in pyobj and 'size' in pyobj:
305 |         objs = frUncompressedRLE([pyobj], h, w)[0]
306 |     else:
307 |         raise Exception('input type is not supported.')
308 |     return objs
--------------------------------------------------------------------------------
/cocoapi/PythonAPI/pycocotools/_mask.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-lab/MobilePose-pytorch/351fc4c244416c24cd2e35e4edb98b5ad856a722/cocoapi/PythonAPI/pycocotools/_mask.so
--------------------------------------------------------------------------------
/cocoapi/PythonAPI/pycocotools/coco.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 | __version__ = '2.0'
3 | # Interface for accessing the Microsoft COCO dataset.
4 | 
5 | # Microsoft COCO is a large image dataset designed for object detection,
6 | # segmentation, and caption generation. pycocotools is a Python API that
7 | # assists in loading, parsing and visualizing the annotations in COCO.
8 | # Please visit http://mscoco.org/ for more information on COCO, including
9 | # for the data, paper, and tutorials. The exact format of the annotations
10 | # is also described on the COCO website. For example usage of pycocotools,
11 | # please see pycocotools_demo.ipynb. In addition to this API, please download both
12 | # the COCO images and annotations in order to run the demo.
13 | 
14 | # An alternative to using the API is to load the annotations directly
15 | # into a Python dictionary.
16 | # Using the API provides additional utility functions. Note that this API
17 | # supports both *instance* and *caption* annotations. In the case of
18 | # captions not all functions are defined (e.g. categories are undefined).
19 | 
20 | # The following API functions are defined:
21 | # COCO - COCO api class that loads the COCO annotation file and prepares data structures.
22 | # decodeMask - Decode binary mask M encoded via run-length encoding.
23 | # encodeMask - Encode binary mask M using run-length encoding.
24 | # getAnnIds - Get ann ids that satisfy given filter conditions.
25 | # getCatIds - Get cat ids that satisfy given filter conditions.
26 | # getImgIds - Get img ids that satisfy given filter conditions.
27 | # loadAnns - Load anns with the specified ids.
28 | # loadCats - Load cats with the specified ids.
29 | # loadImgs - Load imgs with the specified ids.
30 | # annToMask - Convert segmentation in an annotation to binary mask.
31 | # showAnns - Display the specified annotations.
32 | # loadRes - Load algorithm results and create API for accessing them.
33 | # download - Download COCO images from mscoco.org server.
34 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
35 | # Help on each function can be accessed by: "help COCO>function".
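Before the class definition, a minimal usage sketch of the API listed above, mirroring the cocoDemo.m flow shown earlier in this dump (the annotation path is an assumption):

from pycocotools.coco import COCO

coco = COCO('../annotations/instances_val2014.json')  # path is an assumption
catIds = coco.getCatIds(catNms=['person', 'dog', 'skateboard'])
imgIds = coco.getImgIds(catIds=catIds)                # images containing all three cats
annIds = coco.getAnnIds(imgIds=imgIds[0], catIds=catIds, iscrowd=None)
anns = coco.loadAnns(annIds)                          # list of annotation dicts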
36 | 37 | # See also COCO>decodeMask, 38 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 39 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 40 | # COCO>loadImgs, COCO>annToMask, COCO>showAnns 41 | 42 | # Microsoft COCO Toolbox. version 2.0 43 | # Data, paper, and tutorials available at: http://mscoco.org/ 44 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 45 | # Licensed under the Simplified BSD License [see bsd.txt] 46 | 47 | import json 48 | import time 49 | import matplotlib.pyplot as plt 50 | from matplotlib.collections import PatchCollection 51 | from matplotlib.patches import Polygon 52 | import numpy as np 53 | import copy 54 | import itertools 55 | from . import mask as maskUtils 56 | import os 57 | from collections import defaultdict 58 | import sys 59 | PYTHON_VERSION = sys.version_info[0] 60 | if PYTHON_VERSION == 2: 61 | from urllib import urlretrieve 62 | elif PYTHON_VERSION == 3: 63 | from urllib.request import urlretrieve 64 | 65 | 66 | def _isArrayLike(obj): 67 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 68 | 69 | 70 | class COCO: 71 | def __init__(self, annotation_file=None): 72 | """ 73 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 74 | :param annotation_file (str): location of annotation file 75 | :param image_folder (str): location to the folder that hosts images. 76 | :return: 77 | """ 78 | # load dataset 79 | self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict() 80 | self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) 81 | if not annotation_file == None: 82 | print('loading annotations into memory...') 83 | tic = time.time() 84 | dataset = json.load(open(annotation_file, 'r')) 85 | assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) 86 | print('Done (t={:0.2f}s)'.format(time.time()- tic)) 87 | self.dataset = dataset 88 | self.createIndex() 89 | 90 | def createIndex(self): 91 | # create index 92 | print('creating index...') 93 | anns, cats, imgs = {}, {}, {} 94 | imgToAnns,catToImgs = defaultdict(list),defaultdict(list) 95 | if 'annotations' in self.dataset: 96 | for ann in self.dataset['annotations']: 97 | imgToAnns[ann['image_id']].append(ann) 98 | anns[ann['id']] = ann 99 | 100 | if 'images' in self.dataset: 101 | for img in self.dataset['images']: 102 | imgs[img['id']] = img 103 | 104 | if 'categories' in self.dataset: 105 | for cat in self.dataset['categories']: 106 | cats[cat['id']] = cat 107 | 108 | if 'annotations' in self.dataset and 'categories' in self.dataset: 109 | for ann in self.dataset['annotations']: 110 | catToImgs[ann['category_id']].append(ann['image_id']) 111 | 112 | print('index created!') 113 | 114 | # create class members 115 | self.anns = anns 116 | self.imgToAnns = imgToAnns 117 | self.catToImgs = catToImgs 118 | self.imgs = imgs 119 | self.cats = cats 120 | 121 | def info(self): 122 | """ 123 | Print information about the annotation file. 124 | :return: 125 | """ 126 | for key, value in self.dataset['info'].items(): 127 | print('{}: {}'.format(key, value)) 128 | 129 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 130 | """ 131 | Get ann ids that satisfy given filter conditions. default skips that filter 132 | :param imgIds (int array) : get anns for given imgs 133 | catIds (int array) : get anns for given cats 134 | areaRng (float array) : get anns for given area range (e.g. 
[0 inf]) 135 | iscrowd (boolean) : get anns for given crowd label (False or True) 136 | :return: ids (int array) : integer array of ann ids 137 | """ 138 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 139 | catIds = catIds if _isArrayLike(catIds) else [catIds] 140 | 141 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 142 | anns = self.dataset['annotations'] 143 | else: 144 | if not len(imgIds) == 0: 145 | lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns] 146 | anns = list(itertools.chain.from_iterable(lists)) 147 | else: 148 | anns = self.dataset['annotations'] 149 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 150 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 151 | if not iscrowd == None: 152 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 153 | else: 154 | ids = [ann['id'] for ann in anns] 155 | return ids 156 | 157 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 158 | """ 159 | filtering parameters. default skips that filter. 160 | :param catNms (str array) : get cats for given cat names 161 | :param supNms (str array) : get cats for given supercategory names 162 | :param catIds (int array) : get cats for given cat ids 163 | :return: ids (int array) : integer array of cat ids 164 | """ 165 | catNms = catNms if _isArrayLike(catNms) else [catNms] 166 | supNms = supNms if _isArrayLike(supNms) else [supNms] 167 | catIds = catIds if _isArrayLike(catIds) else [catIds] 168 | 169 | if len(catNms) == len(supNms) == len(catIds) == 0: 170 | cats = self.dataset['categories'] 171 | else: 172 | cats = self.dataset['categories'] 173 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 174 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 175 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 176 | ids = [cat['id'] for cat in cats] 177 | return ids 178 | 179 | def getImgIds(self, imgIds=[], catIds=[]): 180 | ''' 181 | Get img ids that satisfy given filter conditions. 182 | :param imgIds (int array) : get imgs for given ids 183 | :param catIds (int array) : get imgs with all given cats 184 | :return: ids (int array) : integer array of img ids 185 | ''' 186 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 187 | catIds = catIds if _isArrayLike(catIds) else [catIds] 188 | 189 | if len(imgIds) == len(catIds) == 0: 190 | ids = self.imgs.keys() 191 | else: 192 | ids = set(imgIds) 193 | for i, catId in enumerate(catIds): 194 | if i == 0 and len(ids) == 0: 195 | ids = set(self.catToImgs[catId]) 196 | else: 197 | ids &= set(self.catToImgs[catId]) 198 | return list(ids) 199 | 200 | def loadAnns(self, ids=[]): 201 | """ 202 | Load anns with the specified ids. 203 | :param ids (int array) : integer ids specifying anns 204 | :return: anns (object array) : loaded ann objects 205 | """ 206 | if _isArrayLike(ids): 207 | return [self.anns[id] for id in ids] 208 | elif type(ids) == int: 209 | return [self.anns[ids]] 210 | 211 | def loadCats(self, ids=[]): 212 | """ 213 | Load cats with the specified ids. 
214 | :param ids (int array) : integer ids specifying cats 215 | :return: cats (object array) : loaded cat objects 216 | """ 217 | if _isArrayLike(ids): 218 | return [self.cats[id] for id in ids] 219 | elif type(ids) == int: 220 | return [self.cats[ids]] 221 | 222 | def loadImgs(self, ids=[]): 223 | """ 224 | Load anns with the specified ids. 225 | :param ids (int array) : integer ids specifying img 226 | :return: imgs (object array) : loaded img objects 227 | """ 228 | if _isArrayLike(ids): 229 | return [self.imgs[id] for id in ids] 230 | elif type(ids) == int: 231 | return [self.imgs[ids]] 232 | 233 | def showAnns(self, anns): 234 | """ 235 | Display the specified annotations. 236 | :param anns (array of object): annotations to display 237 | :return: None 238 | """ 239 | if len(anns) == 0: 240 | return 0 241 | if 'segmentation' in anns[0] or 'keypoints' in anns[0]: 242 | datasetType = 'instances' 243 | elif 'caption' in anns[0]: 244 | datasetType = 'captions' 245 | else: 246 | raise Exception('datasetType not supported') 247 | if datasetType == 'instances': 248 | ax = plt.gca() 249 | ax.set_autoscale_on(False) 250 | polygons = [] 251 | color = [] 252 | for ann in anns: 253 | c = (np.random.random((1, 3))*0.6+0.4).tolist()[0] 254 | if 'segmentation' in ann: 255 | if type(ann['segmentation']) == list: 256 | # polygon 257 | for seg in ann['segmentation']: 258 | poly = np.array(seg).reshape((int(len(seg)/2), 2)) 259 | polygons.append(Polygon(poly)) 260 | color.append(c) 261 | else: 262 | # mask 263 | t = self.imgs[ann['image_id']] 264 | if type(ann['segmentation']['counts']) == list: 265 | rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width']) 266 | else: 267 | rle = [ann['segmentation']] 268 | m = maskUtils.decode(rle) 269 | img = np.ones( (m.shape[0], m.shape[1], 3) ) 270 | if ann['iscrowd'] == 1: 271 | color_mask = np.array([2.0,166.0,101.0])/255 272 | if ann['iscrowd'] == 0: 273 | color_mask = np.random.random((1, 3)).tolist()[0] 274 | for i in range(3): 275 | img[:,:,i] = color_mask[i] 276 | ax.imshow(np.dstack( (img, m*0.5) )) 277 | if 'keypoints' in ann and type(ann['keypoints']) == list: 278 | # turn skeleton into zero-based index 279 | sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1 280 | kp = np.array(ann['keypoints']) 281 | x = kp[0::3] 282 | y = kp[1::3] 283 | v = kp[2::3] 284 | for sk in sks: 285 | if np.all(v[sk]>0): 286 | plt.plot(x[sk],y[sk], linewidth=3, color=c) 287 | plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2) 288 | plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2) 289 | p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) 290 | ax.add_collection(p) 291 | p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2) 292 | ax.add_collection(p) 293 | elif datasetType == 'captions': 294 | for ann in anns: 295 | print(ann['caption']) 296 | 297 | def loadRes(self, resFile): 298 | """ 299 | Load result file and return a result api object. 
300 |         :param resFile (str) : file name of result file
301 |         :return: res (obj) : result api object
302 |         """
303 |         res = COCO()
304 |         res.dataset['images'] = [img for img in self.dataset['images']]
305 | 
306 |         print('Loading and preparing results...')
307 |         tic = time.time()
308 |         if type(resFile) == str or (PYTHON_VERSION == 2 and type(resFile) == unicode):
309 |             anns = json.load(open(resFile))
310 |         elif type(resFile) == np.ndarray:
311 |             anns = self.loadNumpyAnnotations(resFile)
312 |         else:
313 |             anns = resFile
314 |         assert type(anns) == list, 'results is not an array of objects'
315 |         annsImgIds = [ann['image_id'] for ann in anns]
316 |         assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
317 |             'Results do not correspond to current coco set'
318 |         if 'caption' in anns[0]:
319 |             imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
320 |             res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
321 |             for id, ann in enumerate(anns):
322 |                 ann['id'] = id+1
323 |         elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
324 |             res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
325 |             for id, ann in enumerate(anns):
326 |                 bb = ann['bbox']
327 |                 x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
328 |                 if not 'segmentation' in ann:
329 |                     ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
330 |                 ann['area'] = bb[2]*bb[3]
331 |                 ann['id'] = id+1
332 |                 ann['iscrowd'] = 0
333 |         elif 'segmentation' in anns[0]:
334 |             res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
335 |             for id, ann in enumerate(anns):
336 |                 # now only support compressed RLE format as segmentation results
337 |                 ann['area'] = maskUtils.area(ann['segmentation'])
338 |                 if not 'bbox' in ann:
339 |                     ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
340 |                 ann['id'] = id+1
341 |                 ann['iscrowd'] = 0
342 |         elif 'keypoints' in anns[0]:
343 |             res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
344 |             for id, ann in enumerate(anns):
345 |                 s = ann['keypoints']
346 |                 x = s[0::3]
347 |                 y = s[1::3]
348 |                 x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
349 |                 ann['area'] = (x1-x0)*(y1-y0)
350 |                 ann['id'] = id + 1
351 |                 ann['bbox'] = [x0,y0,x1-x0,y1-y0]
352 |         print('DONE (t={:0.2f}s)'.format(time.time()- tic))
353 | 
354 |         res.dataset['annotations'] = anns
355 |         res.createIndex()
356 |         return res
357 | 
358 |     def download(self, tarDir = None, imgIds = [] ):
359 |         '''
360 |         Download COCO images from mscoco.org server.
361 | :param tarDir (str): COCO results directory name 362 | imgIds (list): images to be downloaded 363 | :return: 364 | ''' 365 | if tarDir is None: 366 | print('Please specify target directory') 367 | return -1 368 | if len(imgIds) == 0: 369 | imgs = self.imgs.values() 370 | else: 371 | imgs = self.loadImgs(imgIds) 372 | N = len(imgs) 373 | if not os.path.exists(tarDir): 374 | os.makedirs(tarDir) 375 | for i, img in enumerate(imgs): 376 | tic = time.time() 377 | fname = os.path.join(tarDir, img['file_name']) 378 | if not os.path.exists(fname): 379 | urlretrieve(img['coco_url'], fname) 380 | print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic)) 381 | 382 | def loadNumpyAnnotations(self, data): 383 | """ 384 | Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class} 385 | :param data (numpy.ndarray) 386 | :return: annotations (python nested list) 387 | """ 388 | print('Converting ndarray to lists...') 389 | assert(type(data) == np.ndarray) 390 | print(data.shape) 391 | assert(data.shape[1] == 7) 392 | N = data.shape[0] 393 | ann = [] 394 | for i in range(N): 395 | if i % 1000000 == 0: 396 | print('{}/{}'.format(i,N)) 397 | ann += [{ 398 | 'image_id' : int(data[i, 0]), 399 | 'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ], 400 | 'score' : data[i, 5], 401 | 'category_id': int(data[i, 6]), 402 | }] 403 | return ann 404 | 405 | def annToRLE(self, ann): 406 | """ 407 | Convert annotation which can be polygons, uncompressed RLE to RLE. 408 | :return: binary mask (numpy 2D array) 409 | """ 410 | t = self.imgs[ann['image_id']] 411 | h, w = t['height'], t['width'] 412 | segm = ann['segmentation'] 413 | if type(segm) == list: 414 | # polygon -- a single object might consist of multiple parts 415 | # we merge all parts into one mask rle code 416 | rles = maskUtils.frPyObjects(segm, h, w) 417 | rle = maskUtils.merge(rles) 418 | elif type(segm['counts']) == list: 419 | # uncompressed RLE 420 | rle = maskUtils.frPyObjects(segm, h, w) 421 | else: 422 | # rle 423 | rle = ann['segmentation'] 424 | return rle 425 | 426 | def annToMask(self, ann): 427 | """ 428 | Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask. 429 | :return: binary mask (numpy 2D array) 430 | """ 431 | rle = self.annToRLE(ann) 432 | m = maskUtils.decode(rle) 433 | return m -------------------------------------------------------------------------------- /cocoapi/PythonAPI/pycocotools/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import pycocotools._mask as _mask 4 | 5 | # Interface for manipulating masks stored in RLE format. 6 | # 7 | # RLE is a simple yet efficient format for storing binary masks. RLE 8 | # first divides a vector (or vectorized image) into a series of piecewise 9 | # constant regions and then for each piece simply stores the length of 10 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 12 | # (note that the odd counts are always the numbers of zeros). Instead of 13 | # storing the counts directly, additional compression is achieved with a 14 | # variable bitrate representation based on a common scheme called LEB128. 15 | # 16 | # Compression is greatest given large piecewise constant regions. 
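To make the counts concrete, a tiny pure-Python sketch of the uncompressed counts computation described above (it states the definition only, not the C implementation, and reproduces the two examples just given):

def rle_counts(m):
    # m: flat 0/1 sequence (column-major order for images); first run counts zeros
    counts, prev, run = [], 0, 0
    for v in m:
        if v == prev:
            run += 1
        else:
            counts.append(run)
            prev, run = v, 1
    counts.append(run)
    return counts

assert rle_counts([0, 0, 1, 1, 1, 0, 1]) == [2, 3, 1, 1]
assert rle_counts([1, 1, 1, 1, 1, 1, 0]) == [0, 6, 1]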
17 | # Specifically, the size of the RLE is proportional to the number of 18 | # *boundaries* in M (or for an image the number of boundaries in the y 19 | # direction). Assuming fairly simple shapes, the RLE representation is 20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 21 | # is substantially lower, especially for large simple objects (large n). 22 | # 23 | # Many common operations on masks can be computed directly using the RLE 24 | # (without need for decoding). This includes computations such as area, 25 | # union, intersection, etc. All of these operations are linear in the 26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 27 | # of the object. Computing these operations on the original mask is O(n). 28 | # Thus, using the RLE can result in substantial computational savings. 29 | # 30 | # The following API functions are defined: 31 | # encode - Encode binary masks using RLE. 32 | # decode - Decode binary masks encoded via RLE. 33 | # merge - Compute union or intersection of encoded masks. 34 | # iou - Compute intersection over union between masks. 35 | # area - Compute area of encoded masks. 36 | # toBbox - Get bounding boxes surrounding encoded masks. 37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 38 | # 39 | # Usage: 40 | # Rs = encode( masks ) 41 | # masks = decode( Rs ) 42 | # R = merge( Rs, intersect=false ) 43 | # o = iou( dt, gt, iscrowd ) 44 | # a = area( Rs ) 45 | # bbs = toBbox( Rs ) 46 | # Rs = frPyObjects( [pyObjects], h, w ) 47 | # 48 | # In the API the following formats are used: 49 | # Rs - [dict] Run-length encoding of binary masks 50 | # R - dict Run-length encoding of binary mask 51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 53 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 55 | # dt,gt - May be either bounding boxes or encoded masks 56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 57 | # 58 | # Finally, a note about the intersection over union (iou) computation. 59 | # The standard iou of a ground truth (gt) and detected (dt) object is 60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 61 | # For "crowd" regions, we use a modified criteria. If a gt object is 62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 63 | # Choosing gt' in the crowd gt that best matches the dt can be done using 64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 66 | # For crowd gt regions we use this modified criteria above for the iou. 67 | # 68 | # To compile run "python setup.py build_ext --inplace" 69 | # Please do not contact us for help with compiling. 70 | # 71 | # Microsoft COCO Toolbox. version 2.0 72 | # Data, paper, and tutorials available at: http://mscoco.org/ 73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
74 | # Licensed under the Simplified BSD License [see coco/license.txt] 75 | 76 | iou = _mask.iou 77 | merge = _mask.merge 78 | frPyObjects = _mask.frPyObjects 79 | 80 | def encode(bimask): 81 | if len(bimask.shape) == 3: 82 | return _mask.encode(bimask) 83 | elif len(bimask.shape) == 2: 84 | h, w = bimask.shape 85 | return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] 86 | 87 | def decode(rleObjs): 88 | if type(rleObjs) == list: 89 | return _mask.decode(rleObjs) 90 | else: 91 | return _mask.decode([rleObjs])[:,:,0] 92 | 93 | def area(rleObjs): 94 | if type(rleObjs) == list: 95 | return _mask.area(rleObjs) 96 | else: 97 | return _mask.area([rleObjs])[0] 98 | 99 | def toBbox(rleObjs): 100 | if type(rleObjs) == list: 101 | return _mask.toBbox(rleObjs) 102 | else: 103 | return _mask.toBbox([rleObjs])[0] -------------------------------------------------------------------------------- /cocoapi/PythonAPI/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | from distutils.extension import Extension 4 | import numpy as np 5 | 6 | # To compile and install locally run "python setup.py build_ext --inplace" 7 | # To install library to Python site-packages run "python setup.py build_ext install" 8 | 9 | ext_modules = [ 10 | Extension( 11 | 'pycocotools._mask', 12 | sources=['../common/maskApi.c', 'pycocotools/_mask.pyx'], 13 | include_dirs = [np.get_include(), '../common'], 14 | extra_compile_args=['-Wno-cpp', '-Wno-unused-function', '-std=c99'], 15 | ) 16 | ] 17 | 18 | setup(name='pycocotools', 19 | packages=['pycocotools'], 20 | package_dir = {'pycocotools': 'pycocotools'}, 21 | version='2.0', 22 | ext_modules= 23 | cythonize(ext_modules) 24 | ) -------------------------------------------------------------------------------- /cocoapi/common/gason.cpp: -------------------------------------------------------------------------------- 1 | // https://github.com/vivkin/gason - pulled January 10, 2016 2 | #include "gason.h" 3 | #include 4 | 5 | #define JSON_ZONE_SIZE 4096 6 | #define JSON_STACK_SIZE 32 7 | 8 | const char *jsonStrError(int err) { 9 | switch (err) { 10 | #define XX(no, str) \ 11 | case JSON_##no: \ 12 | return str; 13 | JSON_ERRNO_MAP(XX) 14 | #undef XX 15 | default: 16 | return "unknown"; 17 | } 18 | } 19 | 20 | void *JsonAllocator::allocate(size_t size) { 21 | size = (size + 7) & ~7; 22 | 23 | if (head && head->used + size <= JSON_ZONE_SIZE) { 24 | char *p = (char *)head + head->used; 25 | head->used += size; 26 | return p; 27 | } 28 | 29 | size_t allocSize = sizeof(Zone) + size; 30 | Zone *zone = (Zone *)malloc(allocSize <= JSON_ZONE_SIZE ? 
JSON_ZONE_SIZE : allocSize); 31 | if (zone == nullptr) 32 | return nullptr; 33 | zone->used = allocSize; 34 | if (allocSize <= JSON_ZONE_SIZE || head == nullptr) { 35 | zone->next = head; 36 | head = zone; 37 | } else { 38 | zone->next = head->next; 39 | head->next = zone; 40 | } 41 | return (char *)zone + sizeof(Zone); 42 | } 43 | 44 | void JsonAllocator::deallocate() { 45 | while (head) { 46 | Zone *next = head->next; 47 | free(head); 48 | head = next; 49 | } 50 | } 51 | 52 | static inline bool isspace(char c) { 53 | return c == ' ' || (c >= '\t' && c <= '\r'); 54 | } 55 | 56 | static inline bool isdelim(char c) { 57 | return c == ',' || c == ':' || c == ']' || c == '}' || isspace(c) || !c; 58 | } 59 | 60 | static inline bool isdigit(char c) { 61 | return c >= '0' && c <= '9'; 62 | } 63 | 64 | static inline bool isxdigit(char c) { 65 | return (c >= '0' && c <= '9') || ((c & ~' ') >= 'A' && (c & ~' ') <= 'F'); 66 | } 67 | 68 | static inline int char2int(char c) { 69 | if (c <= '9') 70 | return c - '0'; 71 | return (c & ~' ') - 'A' + 10; 72 | } 73 | 74 | static double string2double(char *s, char **endptr) { 75 | char ch = *s; 76 | if (ch == '-') 77 | ++s; 78 | 79 | double result = 0; 80 | while (isdigit(*s)) 81 | result = (result * 10) + (*s++ - '0'); 82 | 83 | if (*s == '.') { 84 | ++s; 85 | 86 | double fraction = 1; 87 | while (isdigit(*s)) { 88 | fraction *= 0.1; 89 | result += (*s++ - '0') * fraction; 90 | } 91 | } 92 | 93 | if (*s == 'e' || *s == 'E') { 94 | ++s; 95 | 96 | double base = 10; 97 | if (*s == '+') 98 | ++s; 99 | else if (*s == '-') { 100 | ++s; 101 | base = 0.1; 102 | } 103 | 104 | unsigned int exponent = 0; 105 | while (isdigit(*s)) 106 | exponent = (exponent * 10) + (*s++ - '0'); 107 | 108 | double power = 1; 109 | for (; exponent; exponent >>= 1, base *= base) 110 | if (exponent & 1) 111 | power *= base; 112 | 113 | result *= power; 114 | } 115 | 116 | *endptr = s; 117 | return ch == '-' ? 
-result : result; 118 | } 119 | 120 | static inline JsonNode *insertAfter(JsonNode *tail, JsonNode *node) { 121 | if (!tail) 122 | return node->next = node; 123 | node->next = tail->next; 124 | tail->next = node; 125 | return node; 126 | } 127 | 128 | static inline JsonValue listToValue(JsonTag tag, JsonNode *tail) { 129 | if (tail) { 130 | auto head = tail->next; 131 | tail->next = nullptr; 132 | return JsonValue(tag, head); 133 | } 134 | return JsonValue(tag, nullptr); 135 | } 136 | 137 | int jsonParse(char *s, char **endptr, JsonValue *value, JsonAllocator &allocator) { 138 | JsonNode *tails[JSON_STACK_SIZE]; 139 | JsonTag tags[JSON_STACK_SIZE]; 140 | char *keys[JSON_STACK_SIZE]; 141 | JsonValue o; 142 | int pos = -1; 143 | bool separator = true; 144 | JsonNode *node; 145 | *endptr = s; 146 | 147 | while (*s) { 148 | while (isspace(*s)) { 149 | ++s; 150 | if (!*s) break; 151 | } 152 | *endptr = s++; 153 | switch (**endptr) { 154 | case '-': 155 | if (!isdigit(*s) && *s != '.') { 156 | *endptr = s; 157 | return JSON_BAD_NUMBER; 158 | } 159 | case '0': 160 | case '1': 161 | case '2': 162 | case '3': 163 | case '4': 164 | case '5': 165 | case '6': 166 | case '7': 167 | case '8': 168 | case '9': 169 | o = JsonValue(string2double(*endptr, &s)); 170 | if (!isdelim(*s)) { 171 | *endptr = s; 172 | return JSON_BAD_NUMBER; 173 | } 174 | break; 175 | case '"': 176 | o = JsonValue(JSON_STRING, s); 177 | for (char *it = s; *s; ++it, ++s) { 178 | int c = *it = *s; 179 | if (c == '\\') { 180 | c = *++s; 181 | switch (c) { 182 | case '\\': 183 | case '"': 184 | case '/': 185 | *it = c; 186 | break; 187 | case 'b': 188 | *it = '\b'; 189 | break; 190 | case 'f': 191 | *it = '\f'; 192 | break; 193 | case 'n': 194 | *it = '\n'; 195 | break; 196 | case 'r': 197 | *it = '\r'; 198 | break; 199 | case 't': 200 | *it = '\t'; 201 | break; 202 | case 'u': 203 | c = 0; 204 | for (int i = 0; i < 4; ++i) { 205 | if (isxdigit(*++s)) { 206 | c = c * 16 + char2int(*s); 207 | } else { 208 | *endptr = s; 209 | return JSON_BAD_STRING; 210 | } 211 | } 212 | if (c < 0x80) { 213 | *it = c; 214 | } else if (c < 0x800) { 215 | *it++ = 0xC0 | (c >> 6); 216 | *it = 0x80 | (c & 0x3F); 217 | } else { 218 | *it++ = 0xE0 | (c >> 12); 219 | *it++ = 0x80 | ((c >> 6) & 0x3F); 220 | *it = 0x80 | (c & 0x3F); 221 | } 222 | break; 223 | default: 224 | *endptr = s; 225 | return JSON_BAD_STRING; 226 | } 227 | } else if ((unsigned int)c < ' ' || c == '\x7F') { 228 | *endptr = s; 229 | return JSON_BAD_STRING; 230 | } else if (c == '"') { 231 | *it = 0; 232 | ++s; 233 | break; 234 | } 235 | } 236 | if (!isdelim(*s)) { 237 | *endptr = s; 238 | return JSON_BAD_STRING; 239 | } 240 | break; 241 | case 't': 242 | if (!(s[0] == 'r' && s[1] == 'u' && s[2] == 'e' && isdelim(s[3]))) 243 | return JSON_BAD_IDENTIFIER; 244 | o = JsonValue(JSON_TRUE); 245 | s += 3; 246 | break; 247 | case 'f': 248 | if (!(s[0] == 'a' && s[1] == 'l' && s[2] == 's' && s[3] == 'e' && isdelim(s[4]))) 249 | return JSON_BAD_IDENTIFIER; 250 | o = JsonValue(JSON_FALSE); 251 | s += 4; 252 | break; 253 | case 'n': 254 | if (!(s[0] == 'u' && s[1] == 'l' && s[2] == 'l' && isdelim(s[3]))) 255 | return JSON_BAD_IDENTIFIER; 256 | o = JsonValue(JSON_NULL); 257 | s += 3; 258 | break; 259 | case ']': 260 | if (pos == -1) 261 | return JSON_STACK_UNDERFLOW; 262 | if (tags[pos] != JSON_ARRAY) 263 | return JSON_MISMATCH_BRACKET; 264 | o = listToValue(JSON_ARRAY, tails[pos--]); 265 | break; 266 | case '}': 267 | if (pos == -1) 268 | return JSON_STACK_UNDERFLOW; 269 | if (tags[pos] != JSON_OBJECT) 
270 |                 return JSON_MISMATCH_BRACKET;
271 |             if (keys[pos] != nullptr)
272 |                 return JSON_UNEXPECTED_CHARACTER;
273 |             o = listToValue(JSON_OBJECT, tails[pos--]);
274 |             break;
275 |         case '[':
276 |             if (++pos == JSON_STACK_SIZE)
277 |                 return JSON_STACK_OVERFLOW;
278 |             tails[pos] = nullptr;
279 |             tags[pos] = JSON_ARRAY;
280 |             keys[pos] = nullptr;
281 |             separator = true;
282 |             continue;
283 |         case '{':
284 |             if (++pos == JSON_STACK_SIZE)
285 |                 return JSON_STACK_OVERFLOW;
286 |             tails[pos] = nullptr;
287 |             tags[pos] = JSON_OBJECT;
288 |             keys[pos] = nullptr;
289 |             separator = true;
290 |             continue;
291 |         case ':':
292 |             if (separator || keys[pos] == nullptr)
293 |                 return JSON_UNEXPECTED_CHARACTER;
294 |             separator = true;
295 |             continue;
296 |         case ',':
297 |             if (separator || keys[pos] != nullptr)
298 |                 return JSON_UNEXPECTED_CHARACTER;
299 |             separator = true;
300 |             continue;
301 |         case '\0':
302 |             continue;
303 |         default:
304 |             return JSON_UNEXPECTED_CHARACTER;
305 |         }
306 | 
307 |         separator = false;
308 | 
309 |         if (pos == -1) {
310 |             *endptr = s;
311 |             *value = o;
312 |             return JSON_OK;
313 |         }
314 | 
315 |         if (tags[pos] == JSON_OBJECT) {
316 |             if (!keys[pos]) {
317 |                 if (o.getTag() != JSON_STRING)
318 |                     return JSON_UNQUOTED_KEY;
319 |                 keys[pos] = o.toString();
320 |                 continue;
321 |             }
322 |             if ((node = (JsonNode *) allocator.allocate(sizeof(JsonNode))) == nullptr)
323 |                 return JSON_ALLOCATION_FAILURE;
324 |             tails[pos] = insertAfter(tails[pos], node);
325 |             tails[pos]->key = keys[pos];
326 |             keys[pos] = nullptr;
327 |         } else {
328 |             if ((node = (JsonNode *) allocator.allocate(sizeof(JsonNode) - sizeof(char *))) == nullptr)
329 |                 return JSON_ALLOCATION_FAILURE;
330 |             tails[pos] = insertAfter(tails[pos], node);
331 |         }
332 |         tails[pos]->value = o;
333 |     }
334 |     return JSON_BREAKING_BAD;
335 | }
336 | 
-------------------------------------------------------------------------------- /cocoapi/common/gason.h: --------------------------------------------------------------------------------
1 | // https://github.com/vivkin/gason - pulled January 10, 2016
2 | #pragma once
3 | 
4 | #include <stdint.h>
5 | #include <stddef.h>
6 | #include <assert.h>
7 | 
8 | enum JsonTag {
9 |     JSON_NUMBER = 0,
10 |     JSON_STRING,
11 |     JSON_ARRAY,
12 |     JSON_OBJECT,
13 |     JSON_TRUE,
14 |     JSON_FALSE,
15 |     JSON_NULL = 0xF
16 | };
17 | 
18 | struct JsonNode;
19 | 
20 | #define JSON_VALUE_PAYLOAD_MASK 0x00007FFFFFFFFFFFULL
21 | #define JSON_VALUE_NAN_MASK 0x7FF8000000000000ULL
22 | #define JSON_VALUE_TAG_MASK 0xF
23 | #define JSON_VALUE_TAG_SHIFT 47
24 | 
25 | union JsonValue {
26 |     uint64_t ival;
27 |     double fval;
28 | 
29 |     JsonValue(double x)
30 |         : fval(x) {
31 |     }
32 |     JsonValue(JsonTag tag = JSON_NULL, void *payload = nullptr) {
33 |         assert((uintptr_t)payload <= JSON_VALUE_PAYLOAD_MASK);
34 |         ival = JSON_VALUE_NAN_MASK | ((uint64_t)tag << JSON_VALUE_TAG_SHIFT) | (uintptr_t)payload;
35 |     }
36 |     bool isDouble() const {
37 |         return (int64_t)ival <= (int64_t)JSON_VALUE_NAN_MASK;
38 |     }
39 |     JsonTag getTag() const {
40 |         return isDouble() ? 
JSON_NUMBER : JsonTag((ival >> JSON_VALUE_TAG_SHIFT) & JSON_VALUE_TAG_MASK);
41 |     }
42 |     uint64_t getPayload() const {
43 |         assert(!isDouble());
44 |         return ival & JSON_VALUE_PAYLOAD_MASK;
45 |     }
46 |     double toNumber() const {
47 |         assert(getTag() == JSON_NUMBER);
48 |         return fval;
49 |     }
50 |     char *toString() const {
51 |         assert(getTag() == JSON_STRING);
52 |         return (char *)getPayload();
53 |     }
54 |     JsonNode *toNode() const {
55 |         assert(getTag() == JSON_ARRAY || getTag() == JSON_OBJECT);
56 |         return (JsonNode *)getPayload();
57 |     }
58 | };
59 | 
60 | struct JsonNode {
61 |     JsonValue value;
62 |     JsonNode *next;
63 |     char *key;
64 | };
65 | 
66 | struct JsonIterator {
67 |     JsonNode *p;
68 | 
69 |     void operator++() {
70 |         p = p->next;
71 |     }
72 |     bool operator!=(const JsonIterator &x) const {
73 |         return p != x.p;
74 |     }
75 |     JsonNode *operator*() const {
76 |         return p;
77 |     }
78 |     JsonNode *operator->() const {
79 |         return p;
80 |     }
81 | };
82 | 
83 | inline JsonIterator begin(JsonValue o) {
84 |     return JsonIterator{o.toNode()};
85 | }
86 | inline JsonIterator end(JsonValue) {
87 |     return JsonIterator{nullptr};
88 | }
89 | 
90 | #define JSON_ERRNO_MAP(XX) \
91 |     XX(OK, "ok") \
92 |     XX(BAD_NUMBER, "bad number") \
93 |     XX(BAD_STRING, "bad string") \
94 |     XX(BAD_IDENTIFIER, "bad identifier") \
95 |     XX(STACK_OVERFLOW, "stack overflow") \
96 |     XX(STACK_UNDERFLOW, "stack underflow") \
97 |     XX(MISMATCH_BRACKET, "mismatch bracket") \
98 |     XX(UNEXPECTED_CHARACTER, "unexpected character") \
99 |     XX(UNQUOTED_KEY, "unquoted key") \
100 |     XX(BREAKING_BAD, "breaking bad") \
101 |     XX(ALLOCATION_FAILURE, "allocation failure")
102 | 
103 | enum JsonErrno {
104 | #define XX(no, str) JSON_##no,
105 |     JSON_ERRNO_MAP(XX)
106 | #undef XX
107 | };
108 | 
109 | const char *jsonStrError(int err);
110 | 
111 | class JsonAllocator {
112 |     struct Zone {
113 |         Zone *next;
114 |         size_t used;
115 |     } *head = nullptr;
116 | 
117 | public:
118 |     JsonAllocator() = default;
119 |     JsonAllocator(const JsonAllocator &) = delete;
120 |     JsonAllocator &operator=(const JsonAllocator &) = delete;
121 |     JsonAllocator(JsonAllocator &&x) : head(x.head) {
122 |         x.head = nullptr;
123 |     }
124 |     JsonAllocator &operator=(JsonAllocator &&x) {
125 |         head = x.head;
126 |         x.head = nullptr;
127 |         return *this;
128 |     }
129 |     ~JsonAllocator() {
130 |         deallocate();
131 |     }
132 |     void *allocate(size_t size);
133 |     void deallocate();
134 | };
135 | 
136 | int jsonParse(char *str, char **endptr, JsonValue *value, JsonAllocator &allocator);
137 | 
-------------------------------------------------------------------------------- /cocoapi/common/maskApi.c: --------------------------------------------------------------------------------
1 | /**************************************************************************
2 | * Microsoft COCO Toolbox.      version 2.0
3 | * Data, paper, and tutorials available at:  http://mscoco.org/
4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
5 | * Licensed under the Simplified BSD License [see coco/license.txt]
6 | **************************************************************************/
7 | #include "maskApi.h"
8 | #include <math.h>
9 | #include <stdlib.h>
10 | 
11 | uint umin( uint a, uint b ) { return (a<b) ? a : b; }
12 | uint umax( uint a, uint b ) { return (a>b) ? a : b; }
13 | 
14 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ) {
15 |   R->h=h; R->w=w; R->m=m; R->cnts=(m==0)?0:malloc(sizeof(uint)*m);
16 |   siz j; if(cnts) for(j=0; j<m; j++) R->cnts[j]=cnts[j];
17 | }
18 | 
19 | void rleFree( RLE *R ) {
20 |   free(R->cnts); R->cnts=0;
21 | }
22 | 
23 | void rlesInit( RLE **R, siz n ) {
24 |   siz i; *R = (RLE*) malloc(sizeof(RLE)*n);
25 |   for(i=0; i<n; i++) rleInit((*R)+i,0,0,0,0);
26 | }
27 | 
28 | void rlesFree( RLE **R, siz n ) {
29 |   siz i; for(i=0; i<n; i++) rleFree((*R)+i); free(*R); *R=0;
30 | }
31 | 
32 | void rleEncode( RLE *R, const byte *M, siz h, siz w, siz n ) {
33 |   siz i, j, k, a=w*h; uint c, *cnts; byte p;
34 |   cnts = malloc(sizeof(uint)*(a+1));
35 |   for(i=0; i<n; i++) {
36 |     const byte *T=M+a*i; k=0; p=0; c=0;
37 |     for(j=0; j<a; j++) { if(T[j]!=p) { cnts[k++]=c; c=0; p=T[j]; } c++; }
38 |     cnts[k++]=c; rleInit(R+i,h,w,k,cnts);
39 |   }
40 |   free(cnts);
41 | }
42 | 
43 | void rleDecode( const RLE *R, byte *M, siz n ) {
44 |   siz i, j, k; for( i=0; i<n; i++ ) {
45 |     byte v=0; for( j=0; j<R[i].m; j++ ) {
46 |       for( k=0; k<R[i].cnts[j]; k++ ) *(M++)=v; v=!v; }}
47 | }
48 | 
49 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ) {
50 |   uint *cnts, c, ca, cb, cc, ct; int v, va, vb, vp;
51 |   siz i, a, b, h=R[0].h, w=R[0].w, m=R[0].m; RLE A, B;
52 |   if(n==0) { rleInit(M,0,0,0,0); return; }
53 |   if(n==1) { rleInit(M,h,w,m,R[0].cnts); return; }
54 |   cnts = malloc(sizeof(uint)*(h*w+1));
55 |   for( a=0; a<m; a++ ) cnts[a]=R[0].cnts[a];
56 |   for( i=1; i<n; i++ ) {
57 |     B=R[i]; if(B.h!=h||B.w!=w) { h=w=m=0; break; }
58 |     rleInit(&A,h,w,m,cnts); ca=A.cnts[0]; cb=B.cnts[0];
59 |     v=va=vb=0; m=0; a=b=1; cc=0; ct=1;
60 |     while( ct>0 ) {
61 |       c=umin(ca,cb); cc+=c; ct=0;
62 |       ca-=c; if(!ca && a<A.m) { ca=A.cnts[a++]; va=!va; } ct+=ca;
63 |       cb-=c; if(!cb && b<B.m) { cb=B.cnts[b++]; vb=!vb; } ct+=cb;
64 |       vp=v; if(intersect) v=va&&vb; else v=va||vb;
65 |       if( v!=vp || ct==0 ) { cnts[m++]=cc; cc=0; }
66 |     }
67 |     rleFree(&A);
68 |   }
69 |   rleInit(M,h,w,m,cnts); free(cnts);
70 | }
71 | 
72 | void rleArea( const RLE *R, siz n, uint *a ) {
73 |   siz i, j; for( i=0; i<n; i++ ) {
74 |     a[i]=0; for( j=1; j<R[i].m; j+=2 ) a[i]+=R[i].cnts[j]; }
75 | }
76 | 
77 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) {
78 |   siz g, d; BB db, gb; int crowd;
79 |   db=malloc(sizeof(double)*m*4); rleToBbox(dt,db,m);
80 |   gb=malloc(sizeof(double)*n*4); rleToBbox(gt,gb,n);
81 |   bbIou(db,gb,m,n,iscrowd,o); free(db); free(gb);
82 |   for( g=0; g<n; g++ ) for( d=0; d<m; d++ ) if(o[g*m+d]>0) {
83 |     crowd=iscrowd!=NULL && iscrowd[g];
84 |     if(dt[d].h!=gt[g].h || dt[d].w!=gt[g].w) { o[g*m+d]=-1; continue; }
85 |     siz ka, kb, a, b; uint c, ca, cb, ct, i, u; int va, vb;
86 |     ca=dt[d].cnts[0]; ka=dt[d].m; va=vb=0;
87 |     cb=gt[g].cnts[0]; kb=gt[g].m; a=b=1; i=u=0; ct=1;
88 |     while( ct>0 ) {
89 |       c=umin(ca,cb); if(va||vb) { u+=c; if(va&&vb) i+=c; } ct=0;
90 |       ca-=c; if(!ca && a<ka) { ca=dt[d].cnts[a++]; va=!va; } ct+=ca;
91 |       cb-=c; if(!cb && b<kb) { cb=gt[g].cnts[b++]; vb=!vb; } ct+=cb;
92 |     }
93 |     if(i==0) u=1; else if(crowd) rleArea(dt+d,1,&u);
94 |     o[g*m+d] = (double)i/(double)u;
95 |   }
96 | }
97 | 
98 | void rleNms( RLE *dt, siz n, uint *keep, double thr ) {
99 |   siz i, j; double u;
100 |   for( i=0; i<n; i++ ) keep[i]=1;
101 |   for( i=0; i<n; i++ ) if(keep[i]) {
102 |     for( j=i+1; j<n; j++ ) if(keep[j]) {
103 |       rleIou(dt+i,dt+j,1,1,0,&u);
104 |       if(u>thr) keep[j]=0;
105 |     }
106 |   }
107 | }
108 | 
109 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) {
110 |   double h, w, i, u, ga, da; siz g, d; int crowd;
111 |   for( g=0; g<n; g++ ) {
112 |     BB G=gt+g*4; ga=G[2]*G[3]; crowd=iscrowd!=NULL && iscrowd[g];
113 |     for( d=0; d<m; d++ ) {
114 |       BB D=dt+d*4; da=D[2]*D[3]; o[g*m+d]=0;
115 |       w=fmin(D[2]+D[0],G[2]+G[0])-fmax(D[0],G[0]); if(w<=0) continue;
116 |       h=fmin(D[3]+D[1],G[3]+G[1])-fmax(D[1],G[1]); if(h<=0) continue;
117 |       i=w*h; u = crowd ? da : da+ga-i; o[g*m+d]=i/u;
118 |     }
119 |   }
120 | }
121 | 
122 | void bbNms( BB dt, siz n, uint *keep, double thr ) {
123 |   siz i, j; double u;
124 |   for( i=0; i<n; i++ ) keep[i]=1;
125 |   for( i=0; i<n; i++ ) if(keep[i]) {
126 |     for( j=i+1; j<n; j++ ) if(keep[j]) {
127 |       bbIou(dt+i,dt+j,1,1,0,&u);
128 |       if(u>thr) keep[j]=0;
129 |     }
130 |   }
131 | }
132 | 
133 | void rleToBbox( const RLE *R, BB bb, siz n ) {
134 |   siz i; for( i=0; i<n; i++ ) {
135 |     uint h, w, x, y, xs, ys, xe, ye, xp, cc, t; siz j, m;
136 |     h=(uint)R[i].h; w=(uint)R[i].w; m=R[i].m;
137 |     m=((siz)(m/2))*2; xs=w; ys=h; xe=ye=0; cc=0;
138 |     if(m==0) { bb[4*i]=bb[4*i+1]=bb[4*i+2]=bb[4*i+3]=0; continue; }
139 |     for( j=0; j<m; j++ ) {
140 |       cc+=R[i].cnts[j]; t=cc-j%2; y=t%h; x=(t-y)/h;
141 |       if(j%2==0) xp=x; else if(xp<x) { ys=0; ye=h-1; }
142 |       xs=umin(xs,x); xe=umax(xe,x); ys=umin(ys,y); ye=umax(ye,y);
143 |     }
144 |     bb[4*i]=xs; bb[4*i+2]=xe-xs+1;
145 |     bb[4*i+1]=ys; bb[4*i+3]=ye-ys+1;
146 |   }
147 | }
148 | 
149 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ) {
150 |   siz i; for( i=0; i<n; i++ ) {
151 |     double xs=bb[4*i+0], xe=xs+bb[4*i+2];
152 |     double ys=bb[4*i+1], ye=ys+bb[4*i+3];
153 |     double xy[8] = {xs,ys,xs,ye,xe,ye,xe,ys};
154 |     rleFrPoly( R+i, xy, 4, h, w );
155 |   }
156 | }
157 | 
158 | int uintCompare(const void *a, const void *b) {
159 |   uint c=*((uint*)a), d=*((uint*)b); return c>d?1:c<d?-1:0;
160 | }
161 | 
162 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ) {
163 |   /* upsample and get discrete points densely along entire boundary */
164 |   siz j, m=0; double scale=5; int *x, *y, *u, *v; uint *a, *b;
165 |   x=malloc(sizeof(int)*(k+1)); y=malloc(sizeof(int)*(k+1));
166 |   for(j=0; j<k; j++) x[j]=(int)(scale*xy[j*2+0]+.5); x[k]=x[0];
167 |   for(j=0; j<k; j++) y[j]=(int)(scale*xy[j*2+1]+.5); y[k]=y[0];
168 |   for(j=0; j<k; j++) m+=umax(abs(x[j]-x[j+1]),abs(y[j]-y[j+1]))+1;
169 |   u=malloc(sizeof(int)*m); v=malloc(sizeof(int)*m); m=0;
170 |   for( j=0; j<k; j++ ) {
171 |     int xs=x[j], xe=x[j+1], ys=y[j], ye=y[j+1], dx, dy, t, d, flip; double s;
172 |     dx=abs(xe-xs); dy=abs(ys-ye); flip = (dx>=dy && xs>xe) || (dx<dy && ys>ye);
173 |     if(flip) { t=xs; xs=xe; xe=t; t=ys; ys=ye; ye=t; }
174 |     s = dx>=dy ? (double)(ye-ys)/dx : (double)(xe-xs)/dy;
175 |     if(dx>=dy) for( d=0; d<=dx; d++ ) {
176 |       t=flip?dx-d:d; u[m]=t+xs; v[m]=(int)(ys+s*t+.5); m++;
177 |     } else for( d=0; d<=dy; d++ ) {
178 |       t=flip?dy-d:d; v[m]=t+ys; u[m]=(int)(xs+s*t+.5); m++;
179 |     }
180 |   }
181 |   /* get points along y-boundary and downsample */
182 |   free(x); free(y); k=m; m=0; double xd, yd;
183 |   x=malloc(sizeof(int)*k); y=malloc(sizeof(int)*k);
184 |   for( j=1; j<k; j++ ) if(u[j]!=u[j-1]) {
185 |     xd=(double)(u[j]<u[j-1]?u[j]:u[j]-1); xd=(xd+.5)/scale-.5;
186 |     if( floor(xd)!=xd || xd<0 || xd>w-1 ) continue;
187 |     yd=(double)(v[j]<v[j-1]?v[j]:v[j-1]); yd=(yd+.5)/scale-.5;
188 |     if(yd<0) yd=0; else if(yd>h) yd=h; yd=ceil(yd);
189 |     x[m]=(int) xd; y[m]=(int) yd; m++;
190 |   }
191 |   /* compute rle encoding given y-boundary points */
192 |   k=m; a=malloc(sizeof(uint)*(k+1));
193 |   for( j=0; j<k; j++ ) a[j]=(uint)(x[j]*(int)(h)+y[j]);
194 |   a[k++]=(uint)(h*w); free(u); free(v); free(x); free(y);
195 |   qsort(a,k,sizeof(uint),uintCompare); uint p=0;
196 |   for( j=0; j<k; j++ ) { uint t=a[j]; a[j]-=p; p=t; }
197 |   b=malloc(sizeof(uint)*k); j=m=0; b[m++]=a[j++];
198 |   while(j<k) if(a[j]>0) b[m++]=a[j++]; else {
199 |     j++; if(j<k) b[m-1]+=a[j++]; }
200 |   rleInit(R,h,w,m,b); free(a); free(b);
201 | }
202 | 
203 | char* rleToString( const RLE *R ) {
204 |   /* Similar to LEB128 but using 6 bits/char and ascii chars 48-111. */
205 |   siz i, m=R->m, p=0; long x; int more;
206 |   char *s=malloc(sizeof(char)*m*6);
207 |   for( i=0; i<m; i++ ) {
208 |     x=(long) R->cnts[i]; if(i>2) x-=(long) R->cnts[i-2]; more=1;
209 |     while( more ) {
210 |       char c=x & 0x1f; x >>= 5; more=(c & 0x10) ? x!=-1 : x!=0;
211 |       if(more) c |= 0x20; c+=48; s[p++]=c;
212 |     }
213 |   }
214 |   s[p]=0; return s;
215 | }
216 | 
217 | void rleFrString( RLE *R, char *s, siz h, siz w ) {
218 |   siz m=0, p=0, k; long x; int more; uint *cnts;
219 |   while( s[m] ) m++; cnts=malloc(sizeof(uint)*m); m=0;
220 |   while( s[p] ) {
221 |     x=0; k=0; more=1;
222 |     while( more ) {
223 |       char c=s[p]-48; x |= (c & 0x1f) << 5*k;
224 |       more = c & 0x20; p++; k++;
225 |       if(!more && (c & 0x10)) x |= -1 << 5*k;
226 |     }
227 |     if(m>2) x+=(long) cnts[m-2]; cnts[m++]=(uint) x;
228 |   }
229 |   rleInit(R,h,w,m,cnts); free(cnts);
230 | }
231 | 
-------------------------------------------------------------------------------- /cocoapi/common/maskApi.h: --------------------------------------------------------------------------------
1 | /**************************************************************************
2 | * Microsoft COCO Toolbox.      version 2.0
3 | * Data, paper, and tutorials available at:  http://mscoco.org/
4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #pragma once 8 | 9 | typedef unsigned int uint; 10 | typedef unsigned long siz; 11 | typedef unsigned char byte; 12 | typedef double* BB; 13 | typedef struct { siz h, w, m; uint *cnts; } RLE; 14 | 15 | /* Initialize/destroy RLE. */ 16 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ); 17 | void rleFree( RLE *R ); 18 | 19 | /* Initialize/destroy RLE array. */ 20 | void rlesInit( RLE **R, siz n ); 21 | void rlesFree( RLE **R, siz n ); 22 | 23 | /* Encode binary masks using RLE. */ 24 | void rleEncode( RLE *R, const byte *mask, siz h, siz w, siz n ); 25 | 26 | /* Decode binary masks encoded via RLE. */ 27 | void rleDecode( const RLE *R, byte *mask, siz n ); 28 | 29 | /* Compute union or intersection of encoded masks. */ 30 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ); 31 | 32 | /* Compute area of encoded masks. */ 33 | void rleArea( const RLE *R, siz n, uint *a ); 34 | 35 | /* Compute intersection over union between masks. */ 36 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ); 37 | 38 | /* Compute non-maximum suppression between bounding masks */ 39 | void rleNms( RLE *dt, siz n, uint *keep, double thr ); 40 | 41 | /* Compute intersection over union between bounding boxes. */ 42 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ); 43 | 44 | /* Compute non-maximum suppression between bounding boxes */ 45 | void bbNms( BB dt, siz n, uint *keep, double thr ); 46 | 47 | /* Get bounding boxes surrounding encoded masks. */ 48 | void rleToBbox( const RLE *R, BB bb, siz n ); 49 | 50 | /* Convert bounding boxes to encoded masks. */ 51 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ); 52 | 53 | /* Convert polygon to encoded mask. */ 54 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ); 55 | 56 | /* Get compressed string representation of encoded mask. */ 57 | char* rleToString( const RLE *R ); 58 | 59 | /* Convert from compressed string representation of encoded mask. 
*/
60 | void rleFrString( RLE *R, char *s, siz h, siz w );
61 | 
-------------------------------------------------------------------------------- /dataloader.py: --------------------------------------------------------------------------------
1 | '''
2 | File: dataloader.py
3 | Project: MobilePose
4 | File Created: Thursday, 8th March 2018 3:00:27 pm
5 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
6 | -----
7 | Last Modified: Thursday, 8th March 2018 3:00:39 pm
8 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn>)
9 | -----
10 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group
11 | '''
12 | 
13 | import csv
14 | import numpy as np
15 | import os
16 | from skimage import io, transform
17 | import cv2
18 | 
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch.optim as optim
23 | import torch.backends.cudnn as cudnn
24 | from torch.utils.data import Dataset, DataLoader
25 | from torchvision import datasets, transforms, utils, models
26 | from torch.autograd import Variable
27 | 
28 | # import matplotlib.pyplot as plt
29 | 
30 | def crop_camera(image, ratio=0.15):
31 |     height = image.shape[0]
32 |     width = image.shape[1]
33 |     mid_width = width / 2.0
34 |     width_20 = width * ratio
35 |     crop_img = image[0:int(height), int(mid_width - width_20):int(mid_width + width_20)]
36 |     return crop_img
37 | 
38 | def expand_bbox(left, right, top, bottom, img_width, img_height):
39 |     width = right-left
40 |     height = bottom-top
41 |     # ratio = np.random.random_sample()*0.2
42 |     ratio = 0.15
43 |     new_left = np.clip(left-ratio*width,0,img_width)
44 |     new_right = np.clip(right+ratio*width,0,img_width)
45 |     new_top = np.clip(top-ratio*height,0,img_height)
46 |     new_bottom = np.clip(bottom+ratio*height,0,img_height)
47 | 
48 |     return [int(new_left), int(new_top), int(new_right), int(new_bottom)]
49 | 
50 | # Aspect-preserving resize transform (no padding), used for the mobilenet-best model
51 | 
52 | class Wrap(object):
53 | 
54 |     def __init__(self, output_size):
55 |         assert isinstance(output_size, (int, tuple))
56 |         self.output_size = output_size
57 | 
58 |     def __call__(self, sample):
59 |         image_, pose_ = sample['image']/256.0, sample['pose']
60 | 
61 |         h, w = image_.shape[:2]
62 |         if isinstance(self.output_size, int):
63 |             if h > w:
64 |                 new_h, new_w = self.output_size * h / w, self.output_size
65 |             else:
66 |                 new_h, new_w = self.output_size, self.output_size * w / h
67 |         else:
68 |             new_h, new_w = self.output_size
69 | 
70 |         new_h, new_w = int(new_h), int(new_w)
71 | 
72 |         image = transform.resize(image_, (new_h, new_w))  # skimage output_shape is (rows, cols)
73 |         pose = (pose_.reshape([-1,2])/np.array([w,h])*np.array([new_w,new_h])).flatten()
74 |         return {'image': image, 'pose': pose}
75 | 
76 | 
77 | # Letterbox resize transform (scale + center pad), used for resnet18 and mobilenet
78 | 
79 | class Rescale(object):
80 | 
81 | 
82 |     def __init__(self, output_size):
83 |         assert isinstance(output_size, (int, tuple))
84 |         self.output_size = output_size
85 | 
86 |     def __call__(self, sample):
87 |         image_, pose_ = sample['image']/256.0, sample['pose']
88 |         h, w = image_.shape[:2]
89 |         im_scale = min(float(self.output_size[0]) / float(h), float(self.output_size[1]) / float(w))
90 |         new_h = int(image_.shape[0] * im_scale)
91 |         new_w = int(image_.shape[1] * im_scale)
92 |         image = cv2.resize(image_, (new_w, new_h),
93 |                            interpolation=cv2.INTER_LINEAR)
94 |         left_pad = (self.output_size[1] - new_w) // 2
95 |         right_pad = (self.output_size[1] - new_w) - left_pad
96 |         top_pad = (self.output_size[0] - new_h) // 2
97 |         bottom_pad = (self.output_size[0] - new_h) - top_pad
98 |         
mean=np.array([0.485, 0.456, 0.406]) 99 | pad = ((top_pad, bottom_pad), (left_pad, right_pad)) 100 | image = np.stack([np.pad(image[:,:,c], pad, mode='constant', constant_values=mean[c]) 101 | for c in range(3)], axis=2) 102 | pose = (pose_.reshape([-1,2])/np.array([w,h])*np.array([new_w,new_h])) 103 | pose += [left_pad, top_pad] 104 | pose = pose.flatten() 105 | 106 | return {'image': image, 'pose': pose} 107 | 108 | 109 | class Expansion(object): 110 | 111 | def __call__(self, sample): 112 | image, pose = sample['image'], sample['pose'] 113 | h, w = image.shape[:2] 114 | x = np.arange(0, h) 115 | y = np.arange(0, w) 116 | x, y = np.meshgrid(x, y) 117 | x = x[:,:, np.newaxis] 118 | y = y[:,:, np.newaxis] 119 | image = np.concatenate((image, x, y), axis=2) 120 | 121 | return {'image': image, 122 | 'pose': pose} 123 | 124 | class ToTensor(object): 125 | 126 | def __call__(self, sample): 127 | image, pose = sample['image'], sample['pose'] 128 | h, w = image.shape[:2] 129 | 130 | x_mean = np.mean(image[:,:,3]) 131 | x_std = np.std(image[:,:,3]) 132 | y_mean = np.mean(image[:,:,4]) 133 | y_std = np.std(image[:,:,4]) 134 | 135 | mean=np.array([0.485, 0.456, 0.406, x_mean, y_mean]) 136 | std=np.array([0.229, 0.224, 0.225, x_std, y_std]) 137 | 138 | # mean=np.array([0.485, 0.456, 0.406]) 139 | # std=np.array([0.229, 0.224, 0.225]) 140 | 141 | image = (image-mean)/(std) 142 | image = torch.from_numpy(image.transpose((2, 0, 1))).float() 143 | pose = torch.from_numpy(pose).float() 144 | 145 | return {'image': image, 146 | 'pose': pose} 147 | 148 | class PoseDataset(Dataset): 149 | 150 | def __init__(self, csv_file, transform): 151 | 152 | with open(csv_file) as f: 153 | self.f_csv = list(csv.reader(f, delimiter='\t')) 154 | self.transform = transform 155 | 156 | def __len__(self): 157 | return len(self.f_csv) 158 | 159 | def __getitem__(self, idx): 160 | ROOT_DIR = "/home/yuliang/code/deeppose_tf/datasets/mpii" 161 | line = self.f_csv[idx][0].split(",") 162 | img_path = os.path.join(ROOT_DIR,'images',line[0]) 163 | image = io.imread(img_path) 164 | height, width = image.shape[0], image.shape[1] 165 | pose = np.array([float(item) for item in line[1:]]).reshape([-1,2]) 166 | 167 | xmin = np.min(pose[:,0]) 168 | ymin = np.min(pose[:,1]) 169 | xmax = np.max(pose[:,0]) 170 | ymax = np.max(pose[:,1]) 171 | 172 | box = expand_bbox(xmin, xmax, ymin, ymax, width, height) 173 | image = image[box[1]:box[3],box[0]:box[2],:] 174 | pose = (pose-np.array([box[0],box[1]])).flatten() 175 | sample = {'image': image, 'pose':pose} 176 | if self.transform: 177 | sample = self.transform(sample) 178 | return sample 179 | 180 | 181 | import imgaug as ia 182 | from imgaug import augmenters as iaa 183 | from scipy import misc 184 | import copy 185 | import random 186 | from imgaug import parameters as iap 187 | 188 | class Augmentation(object): 189 | 190 | def pose2keypoints(self, image, pose): 191 | keypoints = [] 192 | for row in range(int(pose.shape[0])): 193 | x = pose[row,0] 194 | y = pose[row,1] 195 | keypoints.append(ia.Keypoint(x=x, y=y)) 196 | return ia.KeypointsOnImage(keypoints, shape=image.shape) 197 | 198 | def keypoints2pose(self, keypoints_aug): 199 | one_person = [] 200 | for kp_idx, keypoint in enumerate(keypoints_aug.keypoints): 201 | x_new, y_new = keypoint.x, keypoint.y 202 | one_person.append(np.array(x_new).astype(np.float32)) 203 | one_person.append(np.array(y_new).astype(np.float32)) 204 | return np.array(one_person).reshape([-1,2]) 205 | 206 | def __call__(self, sample): 207 | image, pose= 
sample['image'], sample['pose'].reshape([-1,2])
208 | 
209 |         # augmentation choices
210 |         seq = iaa.SomeOf(2, [
211 |             iaa.Sometimes(0.4, iaa.Scale((0.5, 1.0))),
212 |             iaa.Sometimes(0.6, iaa.CropAndPad(percent=(-0.25, 0.25), pad_mode=["edge"], keep_size=False)),
213 |             iaa.Fliplr(0.1),
214 |             iaa.Sometimes(0.4, iaa.AdditiveGaussianNoise(scale=(0, 0.05*50))),
215 |             iaa.Sometimes(0.1, iaa.GaussianBlur(sigma=(0, 3.0)))
216 |         ])
217 |         seq_det = seq.to_deterministic()
218 | 
219 |         image_aug = seq_det.augment_images([image])[0]
220 |         keypoints_aug = seq_det.augment_keypoints([self.pose2keypoints(image,pose)])[0]
221 | 
222 |         return {'image': image_aug, 'pose': self.keypoints2pose(keypoints_aug)}
-------------------------------------------------------------------------------- /estimator.py: --------------------------------------------------------------------------------
1 | '''
2 | File: estimator.py
3 | Project: MobilePose
4 | File Created: Thursday, 8th March 2018 3:02:01 pm
5 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
6 | -----
7 | Last Modified: Thursday, 8th March 2018 3:02:06 pm
8 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn>)
9 | -----
10 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group
11 | '''
12 | 
13 | import itertools
14 | import logging
15 | import math
16 | from collections import namedtuple
17 | 
18 | import cv2
19 | import numpy as np
20 | import torch
21 | 
22 | from scipy.ndimage import maximum_filter, gaussian_filter
23 | from skimage import io, transform
24 | 
25 | from torch.autograd import Variable
26 | 
27 | class ResEstimator:
28 |     def __init__(self, graph_path, target_size=(224, 224)):
29 |         self.target_size = target_size
30 |         self.graph_path = graph_path
31 |         self.net = torch.load(graph_path,map_location=lambda storage, loc: storage)
32 |         self.net.eval()
33 | 
34 |     def addlayer(self, image):
35 |         h, w = image.shape[:2]
36 |         x = np.arange(0, h)
37 |         y = np.arange(0, w)
38 |         x, y = np.meshgrid(x, y)
39 |         x = x[:,:, np.newaxis]
40 |         y = y[:,:, np.newaxis]
41 |         image = np.concatenate((image, x, y), axis=2)
42 | 
43 |         return image
44 | 
45 |     def wrap(self, image, output_size):
46 |         image_ = image
47 |         h, w = image_.shape[:2]
48 |         if isinstance(output_size, int):
49 |             if h > w:
50 |                 new_h, new_w = output_size * h / w, output_size
51 |             else:
52 |                 new_h, new_w = output_size, output_size * w / h
53 |         else:
54 |             new_h, new_w = output_size
55 | 
56 |         new_h, new_w = int(new_h), int(new_w)
57 | 
58 |         image = transform.resize(image_, (new_h, new_w))  # skimage output_shape is (rows, cols)
59 |         pose_fun = lambda x: (x.reshape([-1,2]) * 1.0 /np.array([new_w, new_h])*np.array([w,h]))
60 |         return {'image': image, 'pose_fun': pose_fun}
61 | 
62 |     def rescale(self, image, output_size):
63 |         image_ = image
64 |         h, w = image_.shape[:2]
65 |         im_scale = min(float(output_size[0]) / float(h), float(output_size[1]) / float(w))
66 |         new_h = int(image_.shape[0] * im_scale)
67 |         new_w = int(image_.shape[1] * im_scale)
68 |         image = cv2.resize(image_, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
69 |         left_pad = int((output_size[1] - new_w) / 2.0); right_pad = (output_size[1] - new_w) - left_pad
70 |         top_pad = int((output_size[0] - new_h) / 2.0); bottom_pad = (output_size[0] - new_h) - top_pad
71 |         mean = np.array([0.485, 0.456, 0.406])  # pad with the ImageNet channel means (cf. Rescale in dataloader.py)
72 |         pad = ((top_pad, bottom_pad), (left_pad, right_pad))  # asymmetric pads guarantee the exact output_size
73 |         image = np.stack([np.pad(image[:,:,c], pad, mode='constant', constant_values=mean[c]) for c in range(3)], axis=2)
74 |         pose_fun = lambda x: (((x.reshape([-1,2])-[left_pad, top_pad]) * 1.0 /np.array([new_w, new_h])*np.array([w,h])))
75 |         return {'image': image, 'pose_fun': pose_fun}
76 | 
77 |     def to_tensor(self, 
image): 78 | x_mean = np.mean(image[:,:,3]) 79 | x_std = np.std(image[:,:,3]) 80 | y_mean = np.mean(image[:,:,4]) 81 | y_std = np.std(image[:,:,4]) 82 | mean=np.array([0.485, 0.456, 0.406, x_mean, y_mean]) 83 | std=np.array([0.229, 0.224, 0.225, x_std, y_std]) 84 | image = torch.from_numpy(((image-mean)/std).transpose((2, 0, 1))).float() 85 | return image 86 | 87 | def inference(self, in_npimg, model): 88 | canvas = np.zeros_like(in_npimg) 89 | height = canvas.shape[0] 90 | width = canvas.shape[1] 91 | 92 | if model == 'resnet': 93 | rescale_out = self.rescale(in_npimg, (227,227)) 94 | elif model =='mobilenet': 95 | rescale_out = self.rescale(in_npimg, (224,224)) 96 | 97 | image = rescale_out['image']/256.0 98 | image = self.addlayer(image) 99 | image = self.to_tensor(image) 100 | image = image.unsqueeze(0) 101 | pose_fun = rescale_out['pose_fun'] 102 | 103 | keypoints = self.net(Variable(image)) 104 | keypoints = keypoints.data.cpu().numpy() 105 | keypoints = pose_fun(keypoints).astype(int) 106 | 107 | return keypoints 108 | 109 | @staticmethod 110 | def draw_humans(npimg, pose, imgcopy=False): 111 | if imgcopy: 112 | npimg = np.copy(npimg) 113 | image_h, image_w = npimg.shape[:2] 114 | centers = {} 115 | 116 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 117 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 118 | [170, 0, 255], [255, 0, 255]] 119 | 120 | pairs = [[8,9],[11,12],[11,10],[2,1],[1,0],[13,14],[14,15],[3,4],[4,5],[8,7],[7,6],[6,2],[6,3],[8,12],[8,13]] 121 | colors_skeleton = ['r', 'y', 'y', 'g', 'g', 'y', 'y', 'g', 'g', 'm', 'm', 'g', 'g', 'y','y'] 122 | colors_skeleton = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 123 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 124 | [170, 0, 255]] 125 | 126 | for idx in range(len(colors)): 127 | cv2.circle(npimg, (pose[idx,0], pose[idx,1]), 3, colors[idx], thickness=3, lineType=8, shift=0) 128 | for idx in range(len(colors_skeleton)): 129 | npimg = cv2.line(npimg, (pose[pairs[idx][0],0], pose[pairs[idx][0],1]), (pose[pairs[idx][1],0], pose[pairs[idx][1],1]), colors_skeleton[idx], 3) 130 | 131 | return npimg 132 | 133 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | ''' 3 | File: eval.py 4 | Project: MobilePose 5 | File Created: Thursday, 8th March 2018 1:54:07 pm 6 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn) 7 | ----- 8 | Last Modified: Thursday, 8th March 2018 3:01:51 pm 9 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn>) 10 | ----- 11 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group 12 | ''' 13 | 14 | import warnings 15 | warnings.filterwarnings('ignore') 16 | 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | from torch.autograd import Variable 20 | from torch.utils.data import Dataset, DataLoader 21 | from torchvision import datasets, transforms, utils, models 22 | from tqdm import tqdm 23 | from skimage import io, transform 24 | from math import ceil 25 | import numpy as np 26 | import torch 27 | import csv 28 | import os 29 | import argparse 30 | import time 31 | 32 | from dataloader import * 33 | from coco_utils import * 34 | from networks import * 35 | from pycocotools.coco import COCO 36 | from 
pycocotools.cocoeval import COCOeval
37 | 
38 | gpus = [0,1]
39 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
40 | torch.backends.cudnn.enabled = True
41 | print(torch.cuda.device_count())
42 | 
43 | if __name__ == '__main__':
44 | 
45 |     parser = argparse.ArgumentParser(description='MobilePose Demo')
46 |     parser.add_argument('--model', type=str, default="resnet")
47 |     args = parser.parse_args()
48 |     modeltype = args.model
49 | 
50 |     # user defined parameters
51 |     filename = "final-aug.t7"
52 |     num_threads = 10
53 | 
54 |     PATH_PREFIX = "./results/{}".format(modeltype)
55 |     full_name="./models/{}/{}".format(modeltype, filename)
56 |     # full_name = "/home/yuliang/code/MobilePose-pytorch/models/demo/mobilenetv2_224x224-robust.t7"
57 |     # full_name = "/home/yuliang/code/MobilePose-pytorch/models/demo/resnet18_227x227.t7"
58 |     # full_name = "/home/yuliang/code/MobilePose-pytorch/models/demo/mobilenet-best.t7"
59 |     # ROOT_DIR = "/home/yuliang/code/deeppose_tf/datasets/mpii"
60 |     ROOT_DIR = "../deeppose_tf/datasets/mpii"
61 | 
62 |     if modeltype == 'resnet':
63 |         input_size = 227
64 |     elif modeltype == 'mobilenet':
65 |         input_size = 224
66 | 
67 |     print("Loading testing dataset, wait...")
68 | 
69 |     # load dataset
70 |     test_dataset = PoseDataset(csv_file=os.path.join(ROOT_DIR,'test_joints.csv'),
71 |                 transform=transforms.Compose([
72 |                                # Rescale((input_size, input_size)), # for resnet18 and mobilenet
73 |                                Wrap((input_size,input_size)), # only for mobilenet-best
74 |                                Expansion(),
75 |                                ToTensor()
76 |                            ]))
77 |     test_dataset_size = len(test_dataset)
78 | 
79 |     test_dataloader = DataLoader(test_dataset, batch_size=test_dataset_size,
80 |                         shuffle=False, num_workers = num_threads)
81 | 
82 |     # get all test data
83 |     all_test_data = {}
84 |     for i_batch, sample_batched in enumerate(tqdm(test_dataloader)):
85 |         all_test_data = sample_batched
86 | 
87 |     def eval_coco(net_path, result_gt_json_path, result_pred_json_path):
88 |         """
89 |         Example:
90 |             eval_coco('/home/yuliang/code/PoseFlow/checkpoint140.t7',
91 |             'result-gt-json.txt', 'result-pred-json.txt')
92 |         """
93 |         # gpu mode
94 |         net = Net().cuda(device_id=gpus[0])
95 |         net = torch.load(net_path).cuda(device_id=gpus[0])
96 | 
97 |         # cpu mode
98 |         # net = Net()
99 |         # net = torch.load(net_path, map_location=lambda storage, loc: storage)
100 | 
101 |         ## generate ground truth json
102 |         total_size = len(all_test_data['image'])
103 |         all_coco_images_arr = []
104 |         all_coco_annotations_arr = []
105 |         transform_to_coco_gt(all_test_data['pose'], all_coco_images_arr, all_coco_annotations_arr)
106 |         coco = CocoData(all_coco_images_arr, all_coco_annotations_arr)
107 |         coco_str = coco.dumps()
108 |         result_gt_json = float2int(coco_str)
109 | 
110 |         # save ground truth json to file
111 |         f = open(result_gt_json_path, "w")
112 |         f.write(result_gt_json)
113 |         f.close()
114 | 
115 |         # generate prediction json
116 |         total_size = len(all_test_data['image'])
117 |         all_coco_pred_annotations_arr = []
118 |         for i in tqdm(range(1, int(ceil(total_size / 100.0 + 1)))):
119 |             sample_data = {}
120 | 
121 |             # gpu mode
122 |             sample_data['image'] = all_test_data['image'][100 * (i - 1) : min(100 * i, total_size)].cuda(device=gpus[0])
123 |             # cpu mode
124 |             # sample_data['image'] = all_test_data['image'][100 * (i - 1) : min(100 * i, total_size)]
125 | 
126 |             # print('test dataset contains: %d'%(len(sample_data['image'])))
127 |             # t0 = time.time()
128 |             output = net(Variable(sample_data['image'],volatile=True))
129 |             # print('FPS is %f'%(1.0/((time.time()-t0)/len(sample_data['image']))))
130 | 
131 |             
transform_to_coco_pred(output, all_coco_pred_annotations_arr, 100 * (i - 1)) 132 | 133 | all_coco_pred_annotations_arr = [item._asdict() for item in all_coco_pred_annotations_arr] 134 | result_pred_json = json.dumps(all_coco_pred_annotations_arr, cls=MyEncoder) 135 | result_pred_json = float2int(result_pred_json) 136 | 137 | # save result predict json to file 138 | f = open(result_pred_json_path, "w") 139 | f.write(result_pred_json) 140 | f.close() 141 | 142 | 143 | 144 | eval_coco(full_name, os.path.join(PATH_PREFIX, 'result-gt-json.txt'), os.path.join(PATH_PREFIX, 'result-pred-json.txt')) 145 | 146 | # evaluation 147 | annType = ['segm','bbox','keypoints'] 148 | annType = annType[2] 149 | prefix = 'person_keypoints' if annType=='keypoints' else 'instances' 150 | 151 | print('Running demo for *%s* results.'%(annType)) 152 | 153 | annFile = os.path.join(PATH_PREFIX, "result-gt-json.txt") 154 | cocoGt=COCO(annFile) 155 | resFile = os.path.join(PATH_PREFIX,"result-pred-json.txt") 156 | cocoDt=cocoGt.loadRes(resFile) 157 | imgIds=sorted(cocoGt.getImgIds()) 158 | 159 | cocoEval = COCOeval(cocoGt,cocoDt,annType) 160 | cocoEval.params.imgIds = imgIds 161 | cocoEval.evaluate() 162 | cocoEval.accumulate() 163 | cocoEval.summarize() 164 | -------------------------------------------------------------------------------- /mobilenetv2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | File: mobilenetv2.py 3 | Project: MobilePose 4 | File Created: Thursday, 8th March 2018 2:51:18 pm 5 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn) 6 | ----- 7 | Last Modified: Thursday, 8th March 2018 3:01:19 pm 8 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn>) 9 | ----- 10 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group 11 | ''' 12 | 13 | import torch.nn as nn 14 | import math 15 | 16 | 17 | def conv_bn(inp, oup, stride): 18 | return nn.Sequential( 19 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 20 | nn.BatchNorm2d(oup), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | 25 | def conv_1x1_bn(inp, oup): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 28 | nn.BatchNorm2d(oup), 29 | nn.ReLU(inplace=True) 30 | ) 31 | 32 | 33 | class InvertedResidual(nn.Module): 34 | def __init__(self, inp, oup, stride, expand_ratio): 35 | super(InvertedResidual, self).__init__() 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU6(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 48 | nn.BatchNorm2d(inp * expand_ratio), 49 | nn.ReLU6(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | 55 | def forward(self, x): 56 | if self.use_res_connect: 57 | return x + self.conv(x) 58 | else: 59 | return self.conv(x) 60 | 61 | 62 | class MobileNetV2(nn.Module): 63 | def __init__(self, image_channel=5, n_class=32, input_size=224, width_mult=1.): 64 | super(MobileNetV2, self).__init__() 65 | # setting of inverted residual blocks 66 | self.interverted_residual_setting = [ 67 | # t, c, n, s 68 | [1, 16, 1, 1], 69 | [6, 24, 2, 2], 70 | [6, 32, 3, 2], 71 | [6, 64, 4, 2], 72 | [6, 96, 3, 1], 73 | [6, 160, 3, 2], 74 | [6, 320, 1, 1], 75 | ] 76 | 77 | # 
building first layer 78 | assert input_size % 32 == 0 79 | input_channel = int(32 * width_mult) 80 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 81 | # self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 32 82 | self.features = [conv_bn(image_channel, input_channel, 2)] 83 | # building inverted residual blocks 84 | for t, c, n, s in self.interverted_residual_setting: 85 | output_channel = int(c * width_mult) 86 | for i in range(n): 87 | if i == 0: 88 | self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 89 | else: 90 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 91 | input_channel = output_channel 92 | # building last several layers 93 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 94 | self.features.append(nn.AvgPool2d(int(input_size/32))) 95 | # make it nn.Sequential 96 | self.features = nn.Sequential(*self.features) 97 | 98 | # building classifier 99 | self.classifier = nn.Sequential( 100 | nn.Dropout(p=0.5), 101 | nn.Linear(self.last_channel, n_class), 102 | ) 103 | 104 | self._initialize_weights() 105 | # print(width_mult) 106 | 107 | def forward(self, x): 108 | x = self.features(x) 109 | x = x.view(-1, self.last_channel) 110 | x = self.classifier(x) 111 | return x 112 | 113 | def _initialize_weights(self): 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | if m.bias is not None: 119 | m.bias.data.zero_() 120 | elif isinstance(m, nn.BatchNorm2d): 121 | m.weight.data.fill_(1) 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.Linear): 124 | n = m.weight.size(1) 125 | m.weight.data.normal_(0, 0.01) 126 | m.bias.data.zero_() 127 | 128 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | File: networks.py 3 | Project: MobilePose 4 | File Created: Thursday, 8th March 2018 2:59:28 pm 5 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn) 6 | ----- 7 | Last Modified: Thursday, 8th March 2018 3:01:29 pm 8 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn>) 9 | ----- 10 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group 11 | ''' 12 | 13 | 14 | from torchvision import models 15 | import torch.nn as nn 16 | from mobilenetv2 import * 17 | 18 | def get_graph_path(model_name): 19 | return { 20 | 'resnet': './models/demo/resnet18_227x227.t7', 21 | 'mobilenet': './models/demo/mobilenetv2_224x224.t7', 22 | }[model_name] 23 | 24 | def model_wh(model_name): 25 | width, height = model_name.split('_')[-1].split('x') 26 | return int(width), int(height.split(".")[0]) 27 | 28 | class Net(nn.Module): 29 | 30 | def __init__(self): 31 | super(Net, self).__init__() 32 | model = models.resnet18(pretrained=True) 33 | model.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3,bias=True) 34 | model.fc=nn.Linear(512,32) 35 | for param in model.parameters(): 36 | param.requires_grad = True 37 | self.resnet = model.cuda() 38 | 39 | def forward(self, x): 40 | 41 | pose_out = self.resnet(x) 42 | return pose_out -------------------------------------------------------------------------------- /pose_dataset/lsp/README.txt: -------------------------------------------------------------------------------- 1 | Leeds Sports Pose Dataset (Original size images) 2 | Sam Johnson and Mark Everingham 
3 | http://sam.johnson.io/research/lsp.html 4 | 5 | This dataset contains 2000 images of mostly sports people 6 | gathered from Flickr. The file joints.mat contains 14 joint 7 | locations for each image along with a binary value specifying 8 | joint visibility. 9 | 10 | The ordering of the joints is as follows: 11 | 12 | Right ankle 13 | Right knee 14 | Right hip 15 | Left hip 16 | Left knee 17 | Left ankle 18 | Right wrist 19 | Right elbow 20 | Right shoulder 21 | Left shoulder 22 | Left elbow 23 | Left wrist 24 | Neck 25 | Head top 26 | 27 | This archive contains two folders: 28 | images - containing the original images 29 | visualized - containing the images with poses visualized 30 | 31 | 32 | If you use this dataset please cite 33 | 34 | Sam Johnson and Mark Everingham 35 | "Clustered Pose and Nonlinear Appearance Models for Human Pose Estimation" 36 | In proceedings of the 21st British Machine Vision Conference (BMVC2010) 37 | 38 | @inproceedings{Johnson10, 39 | title = {Clustered Pose and Nonlinear Appearance Models for Human Pose Estimation}, 40 | author = {Johnson, Sam and Everingham, Mark}, 41 | year = {2010}, 42 | booktitle = {Proceedings of the British Machine Vision Conference}, 43 | note = {doi:10.5244/C.24.12} 44 | } 45 | 46 | For further information on this dataset including our results and protocols please visit: 47 | http://sam.johnson.io/research/lsp.html 48 | E-mail: s.a.johnson04@leeds.ac.uk 49 | -------------------------------------------------------------------------------- /pose_dataset/lsp/joints.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-lab/MobilePose-pytorch/351fc4c244416c24cd2e35e4edb98b5ad856a722/pose_dataset/lsp/joints.mat -------------------------------------------------------------------------------- /pose_dataset/lsp_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2016 Artsiom Sanakoyeu 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from os.path import basename 7 | from scipy.io import loadmat 8 | import argparse 9 | import glob 10 | import re 11 | import os.path 12 | 13 | from scripts.config import * 14 | 15 | 16 | def create_data(images_dir, joints_mat_path, transpose_order=(2, 0, 1)): 17 | """ 18 | Create a list of lines in format: 19 | image_path, x1, y1, x2,y2, ... 20 | where xi, yi - coordinates of the i-th joint 21 | """ 22 | joints = loadmat(joints_mat_path) 23 | print(joints['joints'].shape) 24 | joints = joints['joints'].transpose(*transpose_order) 25 | print(joints.shape) 26 | joints = joints[:, :, :2] 27 | print(joints.shape) 28 | if joints.shape[1:] != (14, 2): 29 | raise ValueError('Incorrect shape of the joints matrix of joints.mat. ' 30 | 'Expected: (?, 14, 2); received: (?, {}, {})'.format( 31 | joints.shape[1], joints.shape[2])) 32 | 33 | lines = list() 34 | for img_path in sorted(glob.glob(os.path.join(images_dir, '*.jpg'))): 35 | index = int(re.search(r'im([0-9]+)', basename(img_path)).groups()[0]) - 1 36 | joints_str_list = [str(j) if j > 0 else '-1' for j in joints[index].flatten().tolist()] 37 | 38 | out_list = [img_path] 39 | out_list.extend(joints_str_list) 40 | out_str = ','.join(out_list) 41 | 42 | lines.append(out_str) 43 | return lines 44 | 45 | 46 | if __name__ == '__main__': 47 | """ 48 | Write train.csv and test.csv. 49 | Each line in csv file will be in the following format: 50 | image_name, x1, y1, x2,y2, ... 
51 | where xi, yi - coordinates of the i-th joint 52 | Train file consists of 11000 lines (all images from extended LSP + first 1000 images from small LSP). 53 | Test file consists of 1000 lines (last 1000 images from small LSP). 54 | """ 55 | 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--extended_lsp_images_dir', type=str, default=os.path.join(LSP_EXT_DATASET_ROOT, 'images')) 58 | parser.add_argument('--extended_lsp_joints_path', type=str, default=os.path.join(LSP_EXT_DATASET_ROOT, 'joints.mat')) 59 | parser.add_argument('--small_lsp_images_dir', type=str, default=os.path.join(LSP_DATASET_ROOT, 'images')) 60 | parser.add_argument('--small_lsp_joints_path', type=str, default=os.path.join(LSP_DATASET_ROOT, 'joints.mat')) 61 | parser.add_argument('--output_dir', type=str, default=LSP_EXT_DATASET_ROOT) 62 | args = parser.parse_args() 63 | print(args) 64 | if not os.path.exists(args.output_dir): 65 | os.makedirs(args.output_dir) 66 | 67 | file_train = open('%s/train_joints.csv' % args.output_dir, 'w') 68 | file_test = open('%s/test_joints.csv' % args.output_dir, 'w') 69 | file_train_lsp_small = open('%s/train_lsp_small_joints.csv' % args.output_dir, 'w') 70 | 71 | print('Read LSP_EXT') 72 | lsp_ext_lines = create_data(args.extended_lsp_images_dir, args.extended_lsp_joints_path, 73 | transpose_order=(2, 0, 1)) 74 | print('Read LSP') 75 | lsp_small_lines = create_data(args.small_lsp_images_dir, args.small_lsp_joints_path, 76 | transpose_order=(2, 1, 0)) # different dim order 77 | print('Extended LSP images:', len(lsp_ext_lines)) 78 | print('Small LSP images:', len(lsp_small_lines)) 79 | if len(lsp_ext_lines) != 10000: 80 | raise Exception('Extended LSP dataset must contain 10000 images!') 81 | if len(lsp_small_lines) != 2000: 82 | raise Exception('Small LSP dataset must contain 2000 images!') 83 | num_small_lsp_train = 1000 84 | 85 | for line in lsp_ext_lines: 86 | print(line, file=file_train) 87 | for line in lsp_small_lines[:num_small_lsp_train]: 88 | print(line, file=file_train) 89 | print(line, file=file_train_lsp_small) 90 | for line in lsp_small_lines[num_small_lsp_train:]: 91 | print(line, file=file_test) 92 | 93 | file_train.close() 94 | file_test.close() 95 | file_train_lsp_small.close() 96 | -------------------------------------------------------------------------------- /pose_dataset/lsp_ext/README.txt: -------------------------------------------------------------------------------- 1 | Leeds Sports Pose Extended Training Dataset 2 | Sam Johnson and Mark Everingham 3 | http://sam.johnson.io/research/lspet.html 4 | 5 | This is a set of 10,000 images gathered from Flickr searches for 6 | the tags 'parkour', 'gymnastics', and 'athletics'. Each image has 7 | a corresponding annotation gathered from Amazon Mechanical Turk. 8 | The images have been scaled such that the annotated person is 9 | roughly 150 pixels in length. 10 | 11 | The archive contains two top-level files and one folder: 12 | README.txt - this document 13 | joints.mat - a MATLAB format matrix 'joints' consisting of 14 14 | joint locations and visibility flags. 
Joints are
15 | labelled in the following order:
16 | Right ankle
17 | Right knee
18 | Right hip
19 | Left hip
20 | Left knee
21 | Left ankle
22 | Right wrist
23 | Right elbow
24 | Right shoulder
25 | Left shoulder
26 | Left elbow
27 | Left wrist
28 | Neck
29 | Head top
30 | images/ - 10,000 images
31 | 
32 | There is a second archive:
33 | http://sam.johnson.io/research/lspet_dataset_visualized.zip
34 | containing the 10,000 images above with rendered poses.
35 | 
36 | If you use this dataset please cite
37 | 
38 | Sam Johnson and Mark Everingham
39 | "Learning Effective Human Pose Estimation from Inaccurate Annotation"
40 | In proceedings of Computer Vision and Pattern Recognition (CVPR) 2011
41 | 
42 | @inproceedings{Johnson11,
43 |    title = {Learning Effective Human Pose Estimation from Inaccurate Annotation},
44 |    author = {Johnson, Sam and Everingham, Mark},
45 |    year = {2011},
46 |    booktitle = {Proceedings of Computer Vision and Pattern Recognition (CVPR) 2011}
47 | }
48 | 
49 | E-mail: s.a.johnson04@leeds.ac.uk
50 | 
-------------------------------------------------------------------------------- /pose_dataset/lsp_ext/joints.mat: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-lab/MobilePose-pytorch/351fc4c244416c24cd2e35e4edb98b5ad856a722/pose_dataset/lsp_ext/joints.mat
-------------------------------------------------------------------------------- /pose_dataset/mpii/mpii_human_pose_v1_u12_1.mat: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-lab/MobilePose-pytorch/351fc4c244416c24cd2e35e4edb98b5ad856a722/pose_dataset/mpii/mpii_human_pose_v1_u12_1.mat
-------------------------------------------------------------------------------- /pose_dataset/mpii/mpii_human_pose_v1_u12_2/README.md: --------------------------------------------------------------------------------
1 | ---------------------------------------------------------------------------
2 | MPII Human Pose Dataset, Version 1.0
3 | Copyright 2015 Max Planck Institute for Informatics
4 | Licensed under the Simplified BSD License [see bsd.txt]
5 | ---------------------------------------------------------------------------
6 | 
7 | We are making the annotations and the corresponding code freely available for research
8 | purposes. If you would like to use the dataset for any other purposes please contact
9 | the authors.
10 | 
11 | ### Introduction
12 | The MPII Human Pose dataset is a state-of-the-art benchmark for evaluation
13 | of articulated human pose estimation. The dataset includes around
14 | **25K images** containing over **40K people** with annotated body
15 | joints. The images were systematically collected using an established
16 | taxonomy of everyday human activities. Overall the dataset covers
17 | **410 human activities** and each image is assigned an activity
18 | label. Each image was extracted from a YouTube video and provided with
19 | preceding and following unannotated frames. In addition, for the test
20 | set we obtained richer annotations including body part occlusions and
21 | 3D torso and head orientations.
22 | 
23 | Following the best practices for the performance evaluation benchmarks
24 | in the literature, we withhold the test annotations to prevent
25 | overfitting and tuning on the test set. We are working on an automatic
26 | evaluation server and performance analysis tools based on rich test
27 | set annotations.
28 | 
29 | ### Citing the dataset
30 | ```
31 | @inproceedings{andriluka14cvpr,
32 |                author = {Mykhaylo Andriluka and Leonid Pishchulin and Peter Gehler and Bernt Schiele},
33 |                title = {2D Human Pose Estimation: New Benchmark and State of the Art Analysis},
34 |                booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
35 |                year = {2014},
36 |                month = {June}
37 | }
38 | ```
39 | 
40 | ### Download
41 | 
42 | -. **Images (12.9 GB)**
43 | 
44 | http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1.tar.gz
45 | -. **Annotations (12.5 MB)**
46 | 
47 | http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1_u12.tar.gz
48 | -. **Videos for each image (25 batches x 17 GB)**
49 | 
50 | http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1_sequences_batch1.tar.gz
51 | ...
52 | http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1_sequences_batch25.tar.gz
53 | -. **Image - video mapping (239 KB)**
54 | 
55 | http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/mpii_human_pose_v1_sequences_keyframes.mat
56 | 
57 | ### Annotation description
58 | Annotations are stored in a matlab structure `RELEASE` having the following fields
59 | 
60 | - `.annolist(imgidx)` - annotations for image `imgidx`
61 |   - `.image.name` - image filename
62 |   - `.annorect(ridx)` - body annotations for a person `ridx`
63 |     - `.x1, .y1, .x2, .y2` - coordinates of the head rectangle
64 |     - `.scale` - person scale w.r.t. 200 px height
65 |     - `.objpos` - rough human position in the image
66 |     - `.annopoints.point` - person-centric body joint annotations
67 |       - `.x, .y` - coordinates of a joint
68 |       - `id` - joint id
69 | [//]: # "(0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee, 5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck, 9 - head top, 10 - r wrist, 11 - r elbow, 12 - r shoulder, 13 - l shoulder, 14 - l elbow, 15 - l wrist)"
70 |       - `is_visible` - joint visibility
71 |   - `.vidx` - video index in `video_list`
72 |   - `.frame_sec` - image position in video, in seconds
73 | 
74 | - `img_train(imgidx)` - training/testing image assignment
75 | - `single_person(imgidx)` - contains rectangle id `ridx` of *sufficiently separated* individuals
76 | - `act(imgidx)` - activity/category label for image `imgidx`
77 |   - `act_name` - activity name
78 |   - `cat_name` - category name
79 |   - `act_id` - activity id
80 | - `video_list(videoidx)` - specifies video id as is provided by YouTube. To watch the video on YouTube go to https://www.youtube.com/watch?v=video_list(videoidx)
81 | 
82 | ### Browsing the dataset
83 | - Please use our online tool for browsing the data
84 | http://human-pose.mpi-inf.mpg.de/#dataset
85 | - Red rectangles mark testing images
86 | 
87 | ### References
88 | - **2D Human Pose Estimation: New Benchmark and State of the Art Analysis.**
89 | 
90 | Mykhaylo Andriluka, Leonid Pishchulin, Peter Gehler and Bernt Schiele.
91 | 
92 | IEEE CVPR'14
93 | - **Fine-grained Activity Recognition with Holistic and Pose based Features.**
94 | 
95 | Leonid Pishchulin, Mykhaylo Andriluka and Bernt Schiele.
96 | 
97 | GCPR'14
98 | 
99 | ### Contact
100 | You can reach us via `@mpi-inf.mpg.de`
101 | We are looking forward to your feedback. If you have any questions related to the dataset please let us know.
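
### Loading the annotations in Python

A minimal sketch (illustrative only; assumes `scipy` is installed and the annotation `.mat` file, here the `u12_1` variant shipped in this repository, is in the working directory):

```python
from scipy.io import loadmat

mat = loadmat('mpii_human_pose_v1_u12_1.mat')
release = mat['RELEASE']

annolist = release['annolist'][0, 0][0]    # per-image annotations
img_train = release['img_train'][0, 0][0]  # 1 = training image, 0 = test image

anno = annolist[0]
print(anno['image']['name'][0, 0][0])      # image filename
```

The same access pattern is used by `pose_dataset/mpii_dataset.py` in this repository to convert the annotations to csv.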
102 | -------------------------------------------------------------------------------- /pose_dataset/mpii/mpii_human_pose_v1_u12_2/bsd.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Max Planck Institute for Informatics 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | -------------------------------------------------------------------------------- /pose_dataset/mpii_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2016 Shunta Saito (original code) 3 | # Copyright (c) 2016 Artsiom Sanakoyeu 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | from __future__ import unicode_literals 9 | from scipy.io import loadmat 10 | from itertools import izip 11 | import json 12 | import numpy as np 13 | 14 | from scripts.config import * 15 | 16 | MPII_DATA_DIR = MPII_DATASET_ROOT 17 | MPII_OUT_DIR = MPII_DATASET_ROOT 18 | 19 | 20 | def fix_wrong_joints(joint): 21 | if '12' in joint and '13' in joint and '2' in joint and '3' in joint: 22 | if ((joint['12'][0] < joint['13'][0]) and 23 | (joint['3'][0] < joint['2'][0])): 24 | joint['2'], joint['3'] = joint['3'], joint['2'] 25 | if ((joint['12'][0] > joint['13'][0]) and 26 | (joint['3'][0] > joint['2'][0])): 27 | joint['2'], joint['3'] = joint['3'], joint['2'] 28 | 29 | return joint 30 | 31 | 32 | def save_joints(): 33 | """ 34 | Convert annotations mat file to json and save on disk. 35 | Only persons with annotations of all 16 joints will be written in the json. 
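    An illustrative example of one output line (field values are hypothetical):
    {"filename": "im0001.jpg", "train": 1, "head_rect": [x1, y1, x2, y2],
     "is_visible": {"0": 1, ...}, "joint_pos": {"0": [x, y], ...}}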
36 | """ 37 | joint_data_fn = os.path.join(MPII_OUT_DIR, 'data.json') 38 | mat = loadmat(os.path.join(MPII_DATA_DIR, 'mpii_human_pose_v1_u12_1.mat')) 39 | 40 | fp = open(joint_data_fn, 'w') 41 | 42 | for i, (anno, train_flag) in enumerate( 43 | izip(mat['RELEASE']['annolist'][0, 0][0], 44 | mat['RELEASE']['img_train'][0, 0][0])): 45 | 46 | img_fn = anno['image']['name'][0, 0][0] 47 | train_flag = int(train_flag) 48 | 49 | if 'annopoints' in str(anno['annorect'].dtype): 50 | annopoints = anno['annorect']['annopoints'][0] 51 | head_x1s = anno['annorect']['x1'][0] 52 | head_y1s = anno['annorect']['y1'][0] 53 | head_x2s = anno['annorect']['x2'][0] 54 | head_y2s = anno['annorect']['y2'][0] 55 | for annopoint, head_x1, head_y1, head_x2, head_y2 in \ 56 | izip(annopoints, head_x1s, head_y1s, head_x2s, head_y2s): 57 | if len(annopoint) > 0: 58 | head_rect = [float(head_x1[0, 0]), 59 | float(head_y1[0, 0]), 60 | float(head_x2[0, 0]), 61 | float(head_y2[0, 0])] 62 | 63 | # joint coordinates 64 | annopoint = annopoint['point'][0, 0] 65 | j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]] 66 | x = [x[0, 0] for x in annopoint['x'][0]] 67 | y = [y[0, 0] for y in annopoint['y'][0]] 68 | joint_pos = {} 69 | for _j_id, (_x, _y) in zip(j_id, zip(x, y)): 70 | joint_pos[str(_j_id)] = [float(_x), float(_y)] 71 | # joint_pos = fix_wrong_joints(joint_pos) 72 | 73 | # visiblity list 74 | if 'is_visible' in str(annopoint.dtype): 75 | vis = [v[0] if v else [0] 76 | for v in annopoint['is_visible'][0]] 77 | vis = dict([(k, int(v[0])) if len(v) > 0 else v 78 | for k, v in zip(j_id, vis)]) 79 | else: 80 | vis = None 81 | 82 | if len(joint_pos) == 16: 83 | data = { 84 | 'filename': img_fn, 85 | 'train': train_flag, 86 | 'head_rect': head_rect, 87 | 'is_visible': vis, 88 | 'joint_pos': joint_pos 89 | } 90 | 91 | print(json.dumps(data), file=fp) 92 | 93 | 94 | def write_line(datum, fp): 95 | """ 96 | Write a line in format: 97 | image_name, x1, y1, x2,y2, ... 
119 | def write_line(datum, fp):
120 |     """
121 |     Write one csv line in the format:
122 |         image_name, x1, y1, x2, y2, ...
123 |     where xi, yi are the coordinates of the i-th joint.
124 |     """
125 |     # sort by joint id so the columns always appear in the same order
126 |     joints = sorted([[int(k), v] for k, v in datum['joint_pos'].items()])
127 |     joints = np.array([j for i, j in joints]).flatten()
128 | 
129 |     out = [datum['filename']]
130 |     out.extend(joints)
131 |     out = [str(o) for o in out]
132 |     out = ','.join(out)
133 | 
134 |     print(out, file=fp)
135 | 
136 | 
137 | def split_train_test():
138 |     # 90/10 train/test split with a fixed seed, so the split is reproducible
139 |     fp_test = open(os.path.join(MPII_OUT_DIR, 'test_joints.csv'), 'w')
140 |     fp_train = open(os.path.join(MPII_OUT_DIR, 'train_joints.csv'), 'w')
141 |     all_data = open(os.path.join(MPII_OUT_DIR, 'data.json')).readlines()
142 |     N = len(all_data)
143 |     N_test = int(N * 0.1)
144 |     N_train = N - N_test
145 | 
146 |     print('N:{}'.format(N))
147 |     print('N_train:{}'.format(N_train))
148 |     print('N_test:{}'.format(N_test))
149 | 
150 |     np.random.seed(1701)
151 |     perm = np.random.permutation(N)
152 |     test_indices = perm[:N_test]
153 |     train_indices = perm[N_test:]
154 | 
155 |     print('train_indices:{}'.format(len(train_indices)))
156 |     print('test_indices:{}'.format(len(test_indices)))
157 | 
158 |     for i in train_indices:
159 |         datum = json.loads(all_data[i].strip())
160 |         write_line(datum, fp_train)
161 | 
162 |     for i in test_indices:
163 |         datum = json.loads(all_data[i].strip())
164 |         write_line(datum, fp_test)
165 | 
166 |     fp_train.close()
167 |     fp_test.close()
168 | 
169 | 
170 | if __name__ == '__main__':
171 |     save_joints()
172 |     split_train_test()
173 | 
--------------------------------------------------------------------------------
/pycocotools:
--------------------------------------------------------------------------------
1 | cocoapi/PythonAPI/pycocotools
--------------------------------------------------------------------------------
/run_webcam.py:
--------------------------------------------------------------------------------
1 | '''
2 | File: run_webcam.py
3 | Project: MobilePose
4 | File Created: Thursday, 8th March 2018 2:19:39 pm
5 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
6 | -----
7 | Last Modified: Thursday, 8th March 2018 3:01:35 pm
8 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
9 | -----
10 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group
11 | '''
12 | 
13 | import argparse
14 | import logging
15 | import time
16 | 
17 | import cv2
18 | import numpy as np
19 | 
20 | import torch
21 | import torch.nn as nn
22 | from torchvision import models
23 | 
24 | from estimator import ResEstimator
25 | import matplotlib.pyplot as plt
26 | from networks import *
27 | from dataloader import crop_camera
28 | 
29 | if __name__ == '__main__':
30 | 
31 |     parser = argparse.ArgumentParser(description='MobilePose Realtime Webcam.')
32 |     parser.add_argument('--model', type=str, default='resnet', help='mobilenet|resnet')
33 |     parser.add_argument('--camera', type=int, default=0)
34 | 
35 |     args = parser.parse_args()
36 | 
37 |     # resolve the checkpoint path and input size for the chosen backbone
38 |     w, h = model_wh(get_graph_path(args.model))
39 |     e = ResEstimator(get_graph_path(args.model), target_size=(w, h))
40 |     cam = cv2.VideoCapture(args.camera)
41 | 
42 |     # warm-up read to make sure the camera is ready
43 |     ret_val, image = cam.read()
44 |     image = crop_camera(image)
45 | 
46 |     while True:
47 |         ret_val, image = cam.read()
48 |         image = crop_camera(image)
49 |         humans = e.inference(image, args.model)
50 |         image = ResEstimator.draw_humans(image, humans, imgcopy=False)
51 |         cv2.imshow('tf-pose-estimation result', image)
52 |         if cv2.waitKey(1) == 27:  # ESC quits
53 |             break
54 | 
55 |     cv2.destroyAllWindows()
56 | 
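57 | # Usage (see the repository README): python run_webcam.py --model=mobilenet/resnet
58 | # Press ESC in the preview window to quit.
59 | 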
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | '''
4 | File: training.py
5 | Project: MobilePose
6 | File Created: Thursday, 8th March 2018 2:50:11 pm
7 | Author: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
8 | -----
9 | Last Modified: Thursday, 8th March 2018 2:50:51 pm
10 | Modified By: Yuliang Xiu (yuliangxiu@sjtu.edu.cn)
11 | -----
12 | Copyright 2018 - 2018 Shanghai Jiao Tong University, Machine Vision and Intelligence Group
13 | '''
14 | 
15 | # silence library warnings
16 | import warnings
17 | warnings.filterwarnings('ignore')
18 | 
19 | import os
20 | import argparse
21 | import numpy as np
22 | 
23 | import torch
24 | import torch.nn as nn
25 | import torch.optim as optim
26 | from torch.autograd import Variable
27 | from torch.utils.data import Dataset, DataLoader
28 | 
29 | from networks import *
30 | from dataloader import *
31 | 
32 | if __name__ == '__main__':
33 | 
34 |     parser = argparse.ArgumentParser(description='MobilePose Demo')
35 |     parser.add_argument('--model', type=str, default="resnet")
36 |     parser.add_argument('--gpu', type=str, default="0")
37 |     args = parser.parse_args()
38 |     modeltype = args.model
39 | 
40 |     # user defined parameters
41 |     num_threads = 10
42 | 
43 |     if modeltype == 'resnet':
44 |         modelname = "final-aug.t7"
45 |         pretrain = True
46 |         batchsize = 256
47 |         minloss = 316.52189376  # changed expand ratio
48 |         # minloss = 272.49565467  # fixed expand ratio
49 |         learning_rate = 1e-05
50 |         net = Net().cuda()
51 |         inputsize = 227
52 |     elif modeltype == "mobilenet":
53 |         modelname = "final-aug.t7"
54 |         pretrain = True
55 |         batchsize = 128
56 |         minloss = 396.84708708  # changed expand ratio
57 |         # minloss = 332.48316225  # fixed expand ratio
58 |         learning_rate = 1e-06
59 |         net = MobileNetV2(image_channel=5).cuda()
60 |         inputsize = 224
61 |     else:
62 |         raise ValueError('unknown model: %s (expected resnet|mobilenet)' % modeltype)
63 | 
64 |     # gpu setting
65 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
66 |     torch.backends.cudnn.enabled = True
67 |     gpus = [0, 1]
68 |     print("GPU NUM: %d" % (torch.cuda.device_count()))
69 | 
70 |     logname = modeltype + '-log.txt'
71 | 
72 |     if pretrain:
73 |         # resume from the saved checkpoint (cuda(device_id=...) is the PyTorch 0.2-era API)
74 |         net = torch.load('./models/%s/%s' % (modeltype, modelname)).cuda(device_id=gpus[0])
75 | 
76 |     ROOT_DIR = "../deeppose_tf/datasets/mpii"
77 |     PATH_PREFIX = './models/{}/'.format(modeltype)
78 | 
79 |     train_dataset = PoseDataset(csv_file=os.path.join(ROOT_DIR, 'train_joints.csv'),
80 |                                 transform=transforms.Compose([
81 |                                     # Augmentation(),
82 |                                     Rescale((inputsize, inputsize)),
83 |                                     # Wrap((inputsize, inputsize)),
84 |                                     Expansion(),
85 |                                     ToTensor()
86 |                                 ]))
87 |     # note: shuffling is disabled here in the original training setup
88 |     train_dataloader = DataLoader(train_dataset, batch_size=batchsize,
89 |                                   shuffle=False, num_workers=num_threads)
90 | 
91 |     test_dataset = PoseDataset(csv_file=os.path.join(ROOT_DIR, 'test_joints.csv'),
92 |                                transform=transforms.Compose([
93 |                                    Rescale((inputsize, inputsize)),
94 |                                    # Wrap((inputsize, inputsize)),
95 |                                    Expansion(),
96 |                                    ToTensor()
97 |                                ]))
98 |     test_dataloader = DataLoader(test_dataset, batch_size=batchsize,
99 |                                  shuffle=False, num_workers=num_threads)
100 | 
101 |     criterion = nn.MSELoss().cuda()
102 |     # optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08)
103 |     optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
104 | 
105 |     def mse_loss(input, target):
106 |         # mean squared error over all elements, matching nn.MSELoss's default averaging
107 |         return torch.sum(torch.pow(input - target, 2)) / input.nelement()
108 | 
109 |     train_loss_all = []
110 |     valid_loss_all = []
111 | 
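112 |     # Sketch of the loop below: each epoch makes one pass over the training csv;
113 |     # every 2nd epoch the model is scored on the test csv, and the whole network
114 |     # object is checkpointed whenever the validation loss beats the best value
115 |     # seen so far (initialized from the hard-coded `minloss` above).
116 | 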
117 |     for epoch in range(1000):  # loop over the dataset multiple times
118 | 
119 |         train_loss_epoch = []
120 |         for i, data in enumerate(train_dataloader):
121 |             images, poses = data['image'], data['pose']
122 |             images, poses = Variable(images.cuda()), Variable(poses.cuda())
123 |             optimizer.zero_grad()
124 |             outputs = net(images)
125 |             loss = criterion(outputs, poses)
126 |             loss.backward()
127 |             optimizer.step()
128 | 
129 |             train_loss_epoch.append(loss.data[0])
130 | 
131 |         if epoch % 2 == 0:
132 |             valid_loss_epoch = []
133 |             for i_batch, sample_batched in enumerate(test_dataloader):
134 | 
135 |                 net_forward = net
136 |                 images = sample_batched['image'].cuda()
137 |                 poses = sample_batched['pose'].cuda()
138 |                 # volatile=True disables autograd bookkeeping for evaluation (PyTorch 0.2-era API)
139 |                 outputs = net_forward(Variable(images, volatile=True))
140 |                 valid_loss_epoch.append(mse_loss(outputs.data, poses))
141 | 
142 |             if np.mean(np.array(valid_loss_epoch)) < minloss:
143 |                 minloss = np.mean(np.array(valid_loss_epoch))
144 |                 checkpoint_file = PATH_PREFIX + modelname
145 |                 torch.save(net, checkpoint_file)
146 |                 print('==> checkpoint model saving to %s' % checkpoint_file)
147 | 
148 |             print('[epoch %d] train loss: %.8f, valid loss: %.8f' %
149 |                   (epoch + 1, np.mean(np.array(train_loss_epoch)), np.mean(np.array(valid_loss_epoch))))
150 |             with open(PATH_PREFIX + logname, 'a+') as file_output:
151 |                 file_output.write('[epoch %d] train loss: %.8f, valid loss: %.8f\n' %
152 |                                   (epoch + 1, np.mean(np.array(train_loss_epoch)), np.mean(np.array(valid_loss_epoch))))
153 |                 file_output.flush()
154 | 
155 |     print('Finished Training')
--------------------------------------------------------------------------------