├── lib
│   ├── nms
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── nms.cpython-37.pyc
│   │   │   └── __init__.cpython-37.pyc
│   │   ├── cpu_nms.cpython-37m-x86_64-linux-gnu.so
│   │   ├── gpu_nms.cpython-37m-x86_64-linux-gnu.so
│   │   ├── gpu_nms.hpp
│   │   ├── gpu_nms.pyx
│   │   ├── cpu_nms.pyx
│   │   ├── nms_kernel.cu
│   │   ├── setup_linux.py
│   │   └── nms.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── utils.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── consistency.cpython-37.pyc
│   │   │   ├── transforms.cpython-37.pyc
│   │   │   └── augmentation_pool.cpython-37.pyc
│   │   ├── zipreader.py
│   │   ├── vis_skeleton.py
│   │   ├── consistency.py
│   │   ├── augmentation_pool.py
│   │   ├── transforms.py
│   │   └── utils.py
│   ├── core
│   │   ├── __pycache__
│   │   │   ├── loss.cpython-37.pyc
│   │   │   ├── evaluate.cpython-37.pyc
│   │   │   ├── function.cpython-37.pyc
│   │   │   ├── function1.cpython-37.pyc
│   │   │   └── inference.cpython-37.pyc
│   │   ├── evaluate.py
│   │   ├── inference.py
│   │   ├── loss.py
│   │   └── function1.py
│   ├── dataset
│   │   ├── __pycache__
│   │   │   ├── coco.cpython-37.pyc
│   │   │   ├── mpii.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── coco_align.cpython-37.pyc
│   │   │   └── JointsDataset.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── mpii.py
│   │   └── JointsDataset.py
│   ├── Makefile
│   ├── config
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── default.cpython-37.pyc
│   │   │   └── models.cpython-37.pyc
│   │   ├── __init__.py
│   │   ├── models.py
│   │   └── default.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── pose_hrnet.cpython-37.pyc
│   │   │   └── pose_hrnet_part.cpython-37.pyc
│   │   └── __init__.py
│   └── dataset_animal
│       ├── __pycache__
│       │   ├── ap10k.cpython-37.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── ap10k_info.cpython-37.pyc
│       │   ├── ap10k_mt_v3.cpython-37.pyc
│       │   ├── dataset_info.cpython-37.pyc
│       │   ├── kpt_2d_base.cpython-37.pyc
│       │   ├── ap10k_category.cpython-37.pyc
│       │   ├── ap10k_fewshot.cpython-37.pyc
│       │   ├── ap10k_test_category.cpython-37.pyc
│       │   └── kpt_2d_base_mt_pseudol_v3.cpython-37.pyc
│       ├── __init__.py
│       ├── dataset_info.py
│       ├── ap10k_info.py
│       └── ap10k_test_category.py
├── Network.png
├── teaser.png
├── teaser1.png
├── tools
│   ├── _init_paths.py
│   ├── test.py
│   ├── train.py
│   └── train_mt_part.py
├── LICENSE
├── requirements.txt
├── data
│   └── label_list
│       ├── annotation_list_5
│       ├── annotation_list_10
│       ├── annotation_list_Bovidae
│       ├── annotation_list_15
│       ├── annotation_list_20
│       └── annotation_list_25
├── experiments
│   └── ap10k
│       ├── resnet
│       │   └── res50_256x192_d256x3_adam_lr1e-3.yaml
│       └── hrnet
│           └── w32_256x192_adam_lr1e-3.yaml
└── README.md

/lib/nms/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/Network.png
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/teaser.png
--------------------------------------------------------------------------------
/teaser1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/teaser1.png
-------------------------------------------------------------------------------- /lib/nms/__pycache__/nms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/nms/__pycache__/nms.cpython-37.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/core/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/coco.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset/__pycache__/coco.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/mpii.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset/__pycache__/mpii.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nms/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/nms/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /lib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | cd nms; python setup_linux.py build_ext --inplace; rm -rf build; cd ../../ 3 | clean: 4 | cd nms; rm *.so; cd ../../ 5 | -------------------------------------------------------------------------------- /lib/config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/default.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/config/__pycache__/default.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/config/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/evaluate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/core/__pycache__/evaluate.cpython-37.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/function.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/core/__pycache__/function.cpython-37.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/function1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/core/__pycache__/function1.cpython-37.pyc -------------------------------------------------------------------------------- /lib/core/__pycache__/inference.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/core/__pycache__/inference.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/pose_hrnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/models/__pycache__/pose_hrnet.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nms/cpu_nms.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/nms/cpu_nms.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /lib/nms/gpu_nms.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/nms/gpu_nms.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /lib/utils/__pycache__/consistency.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/utils/__pycache__/consistency.cpython-37.pyc -------------------------------------------------------------------------------- /lib/utils/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/utils/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/coco_align.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset/__pycache__/coco_align.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/ap10k.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/ap10k.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset/__pycache__/JointsDataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset/__pycache__/JointsDataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/models/__pycache__/pose_hrnet_part.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/models/__pycache__/pose_hrnet_part.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/ap10k_info.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/ap10k_info.cpython-37.pyc -------------------------------------------------------------------------------- /lib/utils/__pycache__/augmentation_pool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/utils/__pycache__/augmentation_pool.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/ap10k_mt_v3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/ap10k_mt_v3.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/dataset_info.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/dataset_info.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/kpt_2d_base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/kpt_2d_base.cpython-37.pyc -------------------------------------------------------------------------------- /lib/nms/gpu_nms.hpp: -------------------------------------------------------------------------------- 1 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num, 2 | int boxes_dim, float nms_overlap_thresh, int device_id); 3 | 
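The `_nms` declaration above is the raw C/CUDA entry point; gpu_nms.pyx and cpu_nms.pyx (shown later in this listing) wrap it for Python. A minimal usage sketch of those wrappers — assuming the extensions have been built in place with `make` in lib/ (see lib/Makefile) and that lib/ is on sys.path as tools/_init_paths.py arranges, so these exact import paths are an assumption:

    import numpy as np
    from nms.cpu_nms import cpu_nms
    from nms.gpu_nms import gpu_nms

    # dets: (N, 5) float32 rows of [x1, y1, x2, y2, score]
    dets = np.array([[10, 10, 100, 100, 0.9],
                     [12, 12, 98, 102, 0.8],
                     [200, 200, 300, 300, 0.7]], dtype=np.float32)

    keep = cpu_nms(dets, 0.5)                   # indices of the boxes kept
    keep_gpu = gpu_nms(dets, 0.5, device_id=0)  # same suppression, run on GPU 0

Both functions return indices into dets of the boxes that survive suppression at the 0.5 overlap threshold, with higher-scoring boxes suppressing overlapping lower-scoring ones.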
-------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/ap10k_category.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/ap10k_category.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/ap10k_fewshot.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/ap10k_fewshot.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/ap10k_test_category.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/ap10k_test_category.cpython-37.pyc -------------------------------------------------------------------------------- /lib/dataset_animal/__pycache__/kpt_2d_base_mt_pseudol_v3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaneyddtt/ScarceNet/HEAD/lib/dataset_animal/__pycache__/kpt_2d_base_mt_pseudol_v3.cpython-37.pyc -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from .default import _C as cfg 8 | from .default import update_config 9 | from .models import MODEL_EXTRAS 10 | -------------------------------------------------------------------------------- /lib/dataset_animal/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .ap10k import AnimalAP10KDataset as ap10k 6 | from .ap10k_fewshot import AnimalAP10KDataset as ap10k_fewshot 7 | from .ap10k_mt_v3 import AnimalAP10KDataset as ap10k_mt_v3 8 | from .ap10k_category import AnimalAP10KDataset as ap10k_category 9 | from .ap10k_test_category import AnimalAP10KDataset as ap10k_test_category -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 5 | # ------------------------------------------------------------------------------
 6 | 
 7 | from __future__ import absolute_import
 8 | from __future__ import division
 9 | from __future__ import print_function
10 | 
11 | from .mpii import MPIIDataset as mpii
12 | from .coco import COCODataset as coco
13 | from .coco_align import COCODataset as coco_align
14 | 
--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------
 2 | # Copyright (c) Microsoft
 3 | # Licensed under the MIT License.
 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 5 | # ------------------------------------------------------------------------------
 6 | 
 7 | from __future__ import absolute_import
 8 | from __future__ import division
 9 | from __future__ import print_function
10 | 
11 | import models.pose_hrnet
12 | import models.pose_hrnet_part
13 | 
--------------------------------------------------------------------------------
/tools/_init_paths.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------
 2 | # pose.pytorch
 3 | # Copyright (c) 2018-present Microsoft
 4 | # Licensed under The Apache-2.0 License [see LICENSE for details]
 5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 6 | # ------------------------------------------------------------------------------
 7 | 
 8 | from __future__ import absolute_import
 9 | from __future__ import division
10 | from __future__ import print_function
11 | 
12 | import os.path as osp
13 | import sys
14 | 
15 | 
16 | def add_path(path):
17 |     if path not in sys.path:
18 |         sys.path.insert(0, path)
19 | 
20 | 
21 | this_dir = osp.dirname(__file__)
22 | 
23 | lib_path = osp.join(this_dir, '..', 'lib')
24 | add_path(lib_path)
25 | 
26 | mm_path = osp.join(this_dir, '..', 'lib/poseeval/py-motmetrics')
27 | add_path(mm_path)
28 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Li Chen
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2021.10.8 2 | charset-normalizer==2.0.9 3 | cycler==0.11.0 4 | Cython==0.29.24 5 | dataclasses==0.6 6 | easydict==1.7 7 | future==0.18.2 8 | idna==3.3 9 | imageio==2.10.1 10 | imgaug==0.4.0 11 | joblib==1.1.0 12 | json-tricks==3.15.5 13 | jsonpatch==1.32 14 | jsonpointer==2.2 15 | kiwisolver==1.3.2 16 | matplotlib==3.4.3 17 | networkx==2.6.3 18 | numpy==1.20.2 19 | opencv-contrib-python==3.4.2.17 20 | opencv-python==3.4.2.17 21 | pandas==1.3.4 22 | Pillow==8.2.0 23 | progress==1.6 24 | protobuf==3.19.1 25 | pycocotools==2.0.2 26 | pyparsing==3.0.4 27 | python-dateutil==2.8.2 28 | pytz==2021.3 29 | PyWavelets==1.1.1 30 | PyYAML==6.0 31 | pyzmq==22.3.0 32 | requests==2.26.0 33 | scikit-image==0.18.3 34 | scikit-learn==1.0.2 35 | scipy==1.2.1 36 | seaborn==0.11.2 37 | Shapely==1.8.0 38 | six==1.16.0 39 | tensorboardX==1.6 40 | threadpoolctl==3.1.0 41 | tifffile==2021.10.12 42 | torch==1.7.1+cu101 43 | torchaudio==0.7.2 44 | torchfile==0.1.0 45 | torchnet==0.0.4 46 | torchvision==0.8.2+cu101 47 | tornado==6.1 48 | tqdm==4.62.3 49 | typing-extensions==3.10.0.0 50 | urllib3==1.26.7 51 | visdom==0.1.8.9 52 | websocket-client==1.2.3 53 | xtcocotools==1.10 54 | yacs==0.1.8 55 | -------------------------------------------------------------------------------- /lib/nms/gpu_nms.pyx: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import numpy as np 12 | cimport numpy as np 13 | 14 | assert sizeof(int) == sizeof(np.int32_t) 15 | 16 | cdef extern from "gpu_nms.hpp": 17 | void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int) 18 | 19 | def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh, 20 | np.int32_t device_id=0): 21 | cdef int boxes_num = dets.shape[0] 22 | cdef int boxes_dim = dets.shape[1] 23 | cdef int num_out 24 | cdef np.ndarray[np.int32_t, ndim=1] \ 25 | keep = np.zeros(boxes_num, dtype=np.int32) 26 | cdef np.ndarray[np.float32_t, ndim=1] \ 27 | scores = dets[:, 4] 28 | cdef np.ndarray[np.int32_t, ndim=1] \ 29 | order = scores.argsort()[::-1].astype(np.int32) 30 | cdef np.ndarray[np.float32_t, ndim=2] \ 31 | sorted_dets = dets[order, :] 32 | _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id) 33 | keep = keep[:num_out] 34 | return list(order[keep]) 35 | -------------------------------------------------------------------------------- /data/label_list/annotation_list_5: -------------------------------------------------------------------------------- 1 | [20026, 20057, 20077, 20067, 20045, 20104, 20155, 20204, 20211, 20198, 20319, 20318, 20244, 20281, 20280, 20493, 20338, 20380, 20450, 20343, 20608, 20622, 20643, 20679, 20624, 23092, 22937, 22848, 23159, 23359, 105, 359, 192, 228, 404, 1107, 1099, 1096, 1039, 1079, 1117, 1193, 1115, 1458, 1138, 2193, 1998, 2061, 2030, 2080, 5000, 4706, 5111, 4916, 2577, 7756, 7076, 7377, 7714, 8123, 48663, 48595, 48656, 48655, 48680, 19939, 19870, 19847, 19875, 19880, 49381, 49383, 49312, 49355, 49386, 50219, 50029, 50230, 50209, 50185, 51848, 51807, 51885, 51802, 51808, 51931, 51918, 52060, 52100, 51981, 44364, 44266, 44341, 44902, 44890, 54873, 54978, 54866, 54817, 54817, 50842, 50864, 50994, 51181, 50839, 48787, 48819, 48772, 48732, 48810, 49029, 48978, 48920, 48949, 48924, 20782, 20756, 20726, 20692, 20803, 22082, 22046, 22050, 22116, 22047, 50313, 50334, 50319, 50343, 50257, 23660, 23449, 23572, 23635, 23484, 55635, 55653, 55643, 55665, 55655, 55853, 55783, 55682, 55822, 55823, 56723, 56655, 56601, 56733, 56644, 58570, 58447, 58522, 58430, 58492, 27092, 27103, 29840, 26603, 26826, 31632, 31721, 31466, 31055, 31199, 45002, 45043, 45083, 45025, 45003, 46216, 46276, 46261, 46205, 46292, 43209, 43150, 43229, 43101, 43152, 32039, 31994, 32058, 31966, 31920, 40412, 40795, 33006, 32551, 40235, 36362, 36511, 36456, 36360, 36472, 37590, 37666, 37578, 37532, 37858, 37894, 37889, 37878, 37888, 37874, 37977, 37945, 38497, 38011, 37983, 38913, 38848, 38730, 38912, 38855, 39855, 39856, 39834, 39808, 39751, 40860, 40853, 40846, 40907, 40916, 41040, 41025, 40952, 41067, 40962, 47652, 47585, 47510, 47463, 47595, 11012, 9944, 10350, 14157, 10323, 18112, 17669, 17856, 17715, 18292, 19463, 19335, 19313, 19350, 19325] -------------------------------------------------------------------------------- /experiments/ap10k/resnet/res50_256x192_d256x3_adam_lr1e-3.yaml: -------------------------------------------------------------------------------- 1 | AUTO_RESUME: true 2 | CUDNN: 3 | BENCHMARK: true 4 | DETERMINISTIC: false 5 | ENABLED: true 6 | DATA_DIR: '' 7 | GPUS: (0,1,2,3) 8 | OUTPUT_DIR: 'output' 9 | LOG_DIR: 'log' 10 | WORKERS: 4 11 | PRINT_FREQ: 100 12 | 13 | DATASET: 
14 | COLOR_RGB: true 15 | DATASET: 'ap10k' 16 | DATA_FORMAT: jpg 17 | FLIP: true 18 | NUM_JOINTS_HALF_BODY: 8 19 | PROB_HALF_BODY: 0.3 20 | ROOT: 'data/animalpose/' 21 | ROT_FACTOR: 40 22 | SCALE_FACTOR: 0.5 23 | TEST_SET: 'test' 24 | TRAIN_SET: 'train' 25 | VAL_SET: 'val' 26 | SELECT_DATA: false 27 | SUPERCATEGORY: ['Hominidae'] 28 | MODEL: 29 | INIT_WEIGHTS: true 30 | NAME: 'pose_resnet' 31 | PRETRAINED: 'models/pytorch/imagenet/resnet50-19c8e357.pth' 32 | IMAGE_SIZE: 33 | - 256 34 | - 192 35 | HEATMAP_SIZE: 36 | - 64 37 | - 48 38 | SIGMA: 2 39 | NUM_JOINTS: 17 40 | TARGET_TYPE: 'gaussian' 41 | EXTRA: 42 | FINAL_CONV_KERNEL: 1 43 | DECONV_WITH_BIAS: false 44 | NUM_DECONV_LAYERS: 3 45 | NUM_DECONV_FILTERS: 46 | - 256 47 | - 256 48 | - 256 49 | NUM_DECONV_KERNELS: 50 | - 4 51 | - 4 52 | - 4 53 | NUM_LAYERS: 50 54 | LOSS: 55 | USE_TARGET_WEIGHT: true 56 | TRAIN: 57 | BATCH_SIZE_PER_GPU: 32 58 | SHUFFLE: true 59 | BEGIN_EPOCH: 0 60 | END_EPOCH: 210 61 | OPTIMIZER: 'adam' 62 | LR: 0.001 63 | LR_FACTOR: 0.1 64 | LR_STEP: 65 | - 170 66 | - 200 67 | WD: 0.0001 68 | GAMMA1: 0.99 69 | GAMMA2: 0.0 70 | MOMENTUM: 0.9 71 | NESTEROV: false 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | COCO_BBOX_FILE: '' 75 | BBOX_THRE: 1.0 76 | IMAGE_THRE: 0.0 77 | IN_VIS_THRE: 0.2 78 | MODEL_FILE: '' 79 | NMS_THRE: 1.0 80 | OKS_THRE: 0.9 81 | FLIP_TEST: true 82 | POST_PROCESS: true 83 | SHIFT_HEATMAP: true 84 | USE_GT_BBOX: true 85 | DEBUG: 86 | DEBUG: true 87 | SAVE_BATCH_IMAGES_GT: true 88 | SAVE_BATCH_IMAGES_PRED: true 89 | SAVE_HEATMAPS_GT: true 90 | SAVE_HEATMAPS_PRED: true 91 | -------------------------------------------------------------------------------- /lib/config/models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 5 | # ------------------------------------------------------------------------------
 6 | 
 7 | from __future__ import absolute_import
 8 | from __future__ import division
 9 | from __future__ import print_function
10 | 
11 | from yacs.config import CfgNode as CN
12 | 
13 | 
14 | # pose_resnet related params
15 | POSE_RESNET = CN()
16 | POSE_RESNET.NUM_LAYERS = 50
17 | POSE_RESNET.DECONV_WITH_BIAS = False
18 | POSE_RESNET.NUM_DECONV_LAYERS = 3
19 | POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
20 | POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
21 | POSE_RESNET.FINAL_CONV_KERNEL = 1
22 | POSE_RESNET.PRETRAINED_LAYERS = ['*']
23 | 
24 | # pose_multi_resolution_net related params
25 | POSE_HIGH_RESOLUTION_NET = CN()
26 | POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
27 | POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
28 | POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
29 | 
30 | POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
31 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
32 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
33 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
34 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
35 | POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
36 | POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
37 | 
38 | POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
39 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
40 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
41 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
42 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
43 | POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
44 | POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
45 | 
46 | POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
47 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
48 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
49 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
50 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
51 | POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
52 | POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
53 | 
54 | 
55 | MODEL_EXTRAS = {
56 |     'pose_resnet': POSE_RESNET,
57 |     'pose_high_resolution_net': POSE_HIGH_RESOLUTION_NET,
58 | }
59 | 
--------------------------------------------------------------------------------
/lib/utils/zipreader.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------
 2 | # Copyright (c) Microsoft
 3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 5 | # ------------------------------------------------------------------------------
 6 | 
 7 | from __future__ import absolute_import
 8 | from __future__ import division
 9 | from __future__ import print_function
10 | 
11 | import os
12 | import zipfile
13 | import xml.etree.ElementTree as ET
14 | 
15 | import cv2
16 | import numpy as np
17 | 
18 | _im_zfile = []
19 | _xml_path_zip = []
20 | _xml_zfile = []
21 | 
22 | 
23 | def imread(filename, flags=cv2.IMREAD_COLOR):
24 |     global _im_zfile
25 |     path = filename
26 |     pos_at = path.find('@')
27 |     if pos_at == -1:
28 |         print("character '@' is not found in the given path '%s'"%(path))
29 |         assert 0
30 |     path_zip = path[0: pos_at]
31 |     path_img = path[pos_at + 2:]
32 |     if not os.path.isfile(path_zip):
33 |         print("zip file '%s' is not found"%(path_zip))
34 |         assert 0
35 |     for i in range(len(_im_zfile)):
36 |         if _im_zfile[i]['path'] == path_zip:
37 |             data = _im_zfile[i]['zipfile'].read(path_img)
38 |             return cv2.imdecode(np.frombuffer(data, np.uint8), flags)
39 | 
40 |     _im_zfile.append({
41 |         'path': path_zip,
42 |         'zipfile': zipfile.ZipFile(path_zip, 'r')
43 |     })
44 |     data = _im_zfile[-1]['zipfile'].read(path_img)
45 | 
46 |     return cv2.imdecode(np.frombuffer(data, np.uint8), flags)
47 | 
48 | 
49 | def xmlread(filename):
50 |     global _xml_path_zip
51 |     global _xml_zfile
52 |     path = filename
53 |     pos_at = path.find('@')
54 |     if pos_at == -1:
55 |         print("character '@' is not found in the given path '%s'"%(path))
56 |         assert 0
57 |     path_zip = path[0: pos_at]
58 |     path_xml = path[pos_at + 2:]
59 |     if not os.path.isfile(path_zip):
60 |         print("zip file '%s' is not found"%(path_zip))
61 |         assert 0
62 |     for i in range(len(_xml_path_zip)):
63 |         if _xml_path_zip[i] == path_zip:
64 |             data = _xml_zfile[i].open(path_xml)
65 |             return ET.fromstring(data.read())
66 |     _xml_path_zip.append(path_zip)
67 |     print("read new xml file '%s'"%(path_zip))
68 |     _xml_zfile.append(zipfile.ZipFile(path_zip, 'r'))
69 |     data = _xml_zfile[-1].open(path_xml)
70 |     return ET.fromstring(data.read())
71 | 
--------------------------------------------------------------------------------
/lib/core/evaluate.py:
--------------------------------------------------------------------------------
 1 | # ------------------------------------------------------------------------------
 2 | # Copyright (c) Microsoft
 3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import numpy as np 12 | 13 | from core.inference import get_max_preds 14 | 15 | 16 | def calc_dists(preds, target, normalize): 17 | preds = preds.astype(np.float32) 18 | target = target.astype(np.float32) 19 | dists = np.zeros((preds.shape[1], preds.shape[0])) 20 | for n in range(preds.shape[0]): 21 | for c in range(preds.shape[1]): 22 | if target[n, c, 0] > 1 and target[n, c, 1] > 1: 23 | normed_preds = preds[n, c, :] / normalize[n] 24 | normed_targets = target[n, c, :] / normalize[n] 25 | dists[c, n] = np.linalg.norm(normed_preds - normed_targets) 26 | else: 27 | dists[c, n] = -1 28 | return dists 29 | 30 | 31 | def dist_acc(dists, thr=0.5): 32 | ''' Return percentage below threshold while ignoring values with a -1 ''' 33 | dist_cal = np.not_equal(dists, -1) 34 | num_dist_cal = dist_cal.sum() 35 | if num_dist_cal > 0: 36 | return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal 37 | else: 38 | return -1 39 | 40 | 41 | def accuracy(output, target, hm_type='gaussian', thr=0.5): 42 | ''' 43 | Calculate accuracy according to PCK, 44 | but uses ground truth heatmap rather than x,y locations 45 | First value to be returned is average accuracy across 'idxs', 46 | followed by individual accuracies 47 | ''' 48 | idx = list(range(output.shape[1])) 49 | norm = 1.0 50 | if hm_type == 'gaussian': 51 | pred, _ = get_max_preds(output) 52 | target, _ = get_max_preds(target) 53 | h = output.shape[2] 54 | w = output.shape[3] 55 | norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10 56 | dists = calc_dists(pred, target, norm) 57 | 58 | acc = np.zeros((len(idx) + 1)) 59 | avg_acc = 0 60 | cnt = 0 61 | 62 | for i in range(len(idx)): 63 | acc[i + 1] = dist_acc(dists[idx[i]]) 64 | if acc[i + 1] >= 0: 65 | avg_acc = avg_acc + acc[i + 1] 66 | cnt += 1 67 | 68 | avg_acc = avg_acc / cnt if cnt != 0 else 0 69 | if cnt != 0: 70 | acc[0] = avg_acc 71 | return acc, avg_acc, cnt, pred 72 | 73 | 74 | -------------------------------------------------------------------------------- /lib/nms/cpu_nms.pyx: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import numpy as np 12 | cimport numpy as np 13 | 14 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b): 15 | return a if a >= b else b 16 | 17 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b): 18 | return a if a <= b else b 19 | 20 | def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): 21 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 22 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 23 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 24 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 25 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 26 | 27 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 28 | cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1].astype('i') 29 | 30 | cdef int ndets = dets.shape[0] 31 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 32 | np.zeros((ndets), dtype=np.int) 33 | 34 | # nominal indices 35 | cdef int _i, _j 36 | # sorted indices 37 | cdef int i, j 38 | # temp variables for box i's (the box currently under consideration) 39 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 40 | # variables for computing overlap with box j (lower scoring box) 41 | cdef np.float32_t xx1, yy1, xx2, yy2 42 | cdef np.float32_t w, h 43 | cdef np.float32_t inter, ovr 44 | 45 | keep = [] 46 | for _i in range(ndets): 47 | i = order[_i] 48 | if suppressed[i] == 1: 49 | continue 50 | keep.append(i) 51 | ix1 = x1[i] 52 | iy1 = y1[i] 53 | ix2 = x2[i] 54 | iy2 = y2[i] 55 | iarea = areas[i] 56 | for _j in range(_i + 1, ndets): 57 | j = order[_j] 58 | if suppressed[j] == 1: 59 | continue 60 | xx1 = max(ix1, x1[j]) 61 | yy1 = max(iy1, y1[j]) 62 | xx2 = min(ix2, x2[j]) 63 | yy2 = min(iy2, y2[j]) 64 | w = max(0.0, xx2 - xx1 + 1) 65 | h = max(0.0, yy2 - yy1 + 1) 66 | inter = w * h 67 | ovr = inter / (iarea + areas[j] - inter) 68 | if ovr >= thresh: 69 | suppressed[j] = 1 70 | 71 | return keep 72 | -------------------------------------------------------------------------------- /experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml: -------------------------------------------------------------------------------- 1 | AUTO_RESUME: true 2 | CUDNN: 3 | BENCHMARK: true 4 | DETERMINISTIC: false 5 | ENABLED: true 6 | DATA_DIR: '' 7 | GPUS: (0, 1) 8 | OUTPUT_DIR: 'output' 9 | LOG_DIR: 'log' 10 | WORKERS: 4 11 | PRINT_FREQ: 100 12 | LABEL_PER_CLASS: 15 13 | 14 | DATASET: 15 | COLOR_RGB: true 16 | DATASET: 'ap10k' 17 | DATA_FORMAT: jpg 18 | FLIP: true 19 | NUM_JOINTS_HALF_BODY: 8 20 | PROB_HALF_BODY: 0.3 21 | ROOT: 'data/animalpose/' 22 | ROT_FACTOR: 40 23 | SCALE_FACTOR: 0.5 24 | TEST_SET: 'test' 25 | TRAIN_SET: 'train' 26 | VAL_SET: 'val' 27 | SELECT_DATA: false 28 | SUPERCATEGORY: ['Bovidae'] 29 | MODEL: 30 | INIT_WEIGHTS: true 31 | NAME: pose_hrnet 32 | NUM_JOINTS: 17 33 | PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth' 34 | TARGET_TYPE: gaussian 35 | IMAGE_SIZE: 36 | - 256 37 | - 256 38 | HEATMAP_SIZE: 39 | - 64 40 | - 64 41 | SIGMA: 2 42 | EXTRA: 43 | PRETRAINED_LAYERS: 44 | - 'conv1' 45 | - 'bn1' 46 | - 'conv2' 47 | - 'bn2' 48 | - 'layer1' 49 | - 'transition1' 50 | - 'stage2' 51 | - 'transition2' 52 | - 'stage3' 53 | - 'transition3' 54 | - 'stage4' 55 | FINAL_CONV_KERNEL: 1 56 | STAGE2: 57 | 
NUM_MODULES: 1 58 | NUM_BRANCHES: 2 59 | BLOCK: BASIC 60 | NUM_BLOCKS: 61 | - 4 62 | - 4 63 | NUM_CHANNELS: 64 | - 32 65 | - 64 66 | FUSE_METHOD: SUM 67 | STAGE3: 68 | NUM_MODULES: 4 69 | NUM_BRANCHES: 3 70 | BLOCK: BASIC 71 | NUM_BLOCKS: 72 | - 4 73 | - 4 74 | - 4 75 | NUM_CHANNELS: 76 | - 32 77 | - 64 78 | - 128 79 | FUSE_METHOD: SUM 80 | STAGE4: 81 | NUM_MODULES: 3 82 | NUM_BRANCHES: 4 83 | BLOCK: BASIC 84 | NUM_BLOCKS: 85 | - 4 86 | - 4 87 | - 4 88 | - 4 89 | NUM_CHANNELS: 90 | - 32 91 | - 64 92 | - 128 93 | - 256 94 | FUSE_METHOD: SUM 95 | LOSS: 96 | USE_TARGET_WEIGHT: true 97 | TRAIN: 98 | BATCH_SIZE_PER_GPU: 32 99 | SHUFFLE: true 100 | BEGIN_EPOCH: 0 101 | END_EPOCH: 210 102 | OPTIMIZER: adam 103 | LR: 0.001 104 | LR_FACTOR: 0.1 105 | LR_STEP: 106 | - 170 107 | - 200 108 | WD: 0.0001 109 | GAMMA1: 0.99 110 | GAMMA2: 0.0 111 | MOMENTUM: 0.9 112 | NESTEROV: false 113 | TEST: 114 | BATCH_SIZE_PER_GPU: 32 115 | COCO_BBOX_FILE: '' 116 | BBOX_THRE: 1.0 117 | IMAGE_THRE: 0.0 118 | IN_VIS_THRE: 0.2 119 | MODEL_FILE: '' 120 | NMS_THRE: 1.0 121 | OKS_THRE: 0.9 122 | USE_GT_BBOX: true 123 | FLIP_TEST: true 124 | POST_PROCESS: true 125 | SHIFT_HEATMAP: true 126 | DEBUG: 127 | DEBUG: true 128 | SAVE_BATCH_IMAGES_GT: true 129 | SAVE_BATCH_IMAGES_PRED: true 130 | SAVE_HEATMAPS_GT: true 131 | SAVE_HEATMAPS_PRED: true 132 | -------------------------------------------------------------------------------- /data/label_list/annotation_list_10: -------------------------------------------------------------------------------- 1 | [31973, 32030, 32075, 32002, 31934, 31969, 31949, 32019, 32013, 32027, 40398, 34250, 34661, 33394, 40783, 40592, 32740, 40403, 40047, 35828, 36429, 36415, 36371, 36393, 36392, 36499, 36402, 36494, 36344, 36390, 37700, 37626, 37668, 37624, 37843, 37601, 37671, 37728, 37658, 37851, 37889, 37878, 37883, 37891, 37882, 37892, 37884, 37874, 37895, 37876, 38286, 38541, 37927, 37949, 38319, 37952, 37918, 38716, 38027, 38297, 38755, 38787, 38848, 38779, 38755, 38749, 38824, 38829, 38871, 38724, 39807, 39906, 39820, 39747, 39780, 39873, 39816, 39908, 39862, 39772, 40859, 40922, 40897, 40929, 40886, 40930, 40905, 40901, 40884, 40854, 40996, 41084, 41146, 41120, 40987, 41067, 41142, 41129, 41008, 41013, 44922, 44331, 44266, 44902, 44339, 44275, 44914, 44299, 44344, 44276, 20054, 20061, 20044, 20068, 20018, 20076, 20039, 20036, 20067, 20047, 20191, 20192, 20106, 20208, 20136, 20193, 20170, 20234, 20195, 20189, 20330, 20303, 20307, 20286, 20270, 20286, 20290, 20297, 20329, 20279, 20564, 20408, 20533, 20578, 20557, 20343, 20478, 20551, 20452, 20542, 20620, 20656, 20634, 20677, 20607, 20645, 20608, 20613, 20629, 20651, 44943, 44962, 45080, 45061, 44962, 45014, 45031, 44995, 45085, 45046, 46214, 46299, 46183, 46122, 46146, 46241, 46132, 46201, 46233, 46121, 49340, 49259, 49381, 49406, 49257, 49407, 49379, 49299, 49317, 49366, 50073, 50046, 50101, 50010, 50144, 50219, 49978, 50194, 50128, 50009, 233, 337, 107, 282, 198, 227, 119, 604, 125, 38, 1050, 1092, 1096, 1031, 1044, 1028, 1103, 1096, 1043, 1067, 1591, 1924, 1213, 1177, 1122, 1208, 1214, 1187, 1768, 1175, 2105, 1980, 2051, 2059, 2112, 2005, 2019, 2144, 2189, 2088, 4237, 5156, 5289, 2455, 4768, 3594, 3109, 5666, 2729, 2455, 7939, 8036, 8012, 7756, 7274, 7446, 7446, 7409, 6108, 7952, 48908, 48873, 48785, 48816, 48787, 48880, 48822, 48758, 48832, 48787, 49052, 49002, 49102, 48913, 48917, 48980, 49000, 48973, 48933, 49124, 51833, 51865, 51862, 51805, 51857, 51840, 51845, 51806, 51854, 51802, 51920, 52046, 52010, 52136, 51921, 52105, 51899, 51963, 
52012, 52140, 22718, 22782, 22734, 23417, 22667, 22703, 22676, 22751, 23226, 23215, 55642, 55660, 55634, 55646, 55659, 55653, 55664, 55643, 55626, 55657, 55822, 55698, 55710, 55850, 55793, 55767, 55721, 55856, 55782, 55823, 56636, 56691, 56666, 56567, 56631, 56571, 56667, 56719, 56601, 56704, 58447, 58485, 58626, 58604, 58452, 58454, 58517, 58478, 58526, 58554, 51197, 50967, 50797, 51043, 51216, 51216, 50798, 50853, 50844, 50808, 43206, 43097, 43109, 43273, 43216, 43077, 43185, 43091, 43092, 43126, 20767, 20819, 20762, 20700, 20734, 20858, 20738, 20727, 20791, 20863, 22653, 22081, 22285, 22053, 22507, 22307, 22061, 22075, 22618, 22015, 10077, 12854, 10796, 8821, 12091, 8851, 18400, 9417, 17392, 18848, 17732, 17647, 18268, 17689, 17740, 17742, 18101, 17738, 17714, 18303, 19452, 19619, 19355, 19350, 19574, 19302, 19541, 19346, 19508, 19350, 54828, 54875, 55009, 54829, 54856, 54820, 54978, 54817, 54976, 54969, 19863, 19849, 19976, 19901, 19886, 19991, 19987, 19927, 19915, 19935, 50385, 50274, 50436, 50328, 50377, 50323, 50297, 50277, 50428, 50330, 47644, 47487, 47640, 47468, 47587, 47556, 47607, 47627, 47512, 47520, 23572, 23472, 23513, 23624, 23588, 23436, 23707, 23536, 23483, 23479, 29152, 28605, 29313, 29473, 29546, 30228, 29239, 26792, 29374, 26670, 31055, 31632, 30507, 31555, 31477, 30695, 31621, 31144, 30651, 30981, 48685, 48559, 48671, 48556, 48595, 48641, 48623, 48716, 48611, 48626] -------------------------------------------------------------------------------- /lib/core/inference.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import math 12 | 13 | import numpy as np 14 | 15 | from utils.transforms import transform_preds 16 | 17 | 18 | def get_max_preds(batch_heatmaps): 19 | ''' 20 | get predictions from score maps 21 | heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) 22 | ''' 23 | assert isinstance(batch_heatmaps, np.ndarray), \ 24 | 'batch_heatmaps should be numpy.ndarray' 25 | assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim' 26 | 27 | batch_size = batch_heatmaps.shape[0] 28 | num_joints = batch_heatmaps.shape[1] 29 | width = batch_heatmaps.shape[3] 30 | heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1)) 31 | idx = np.argmax(heatmaps_reshaped, 2) 32 | maxvals = np.amax(heatmaps_reshaped, 2) 33 | 34 | maxvals = maxvals.reshape((batch_size, num_joints, 1)) 35 | idx = idx.reshape((batch_size, num_joints, 1)) 36 | 37 | preds = np.tile(idx, (1, 1, 2)).astype(np.float32) 38 | 39 | preds[:, :, 0] = (preds[:, :, 0]) % width 40 | preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) 41 | 42 | pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) 43 | pred_mask = pred_mask.astype(np.float32) 44 | 45 | preds *= pred_mask 46 | return preds, maxvals 47 | 48 | 49 | def get_final_preds(config, batch_heatmaps, center, scale): 50 | coords, maxvals = get_max_preds(batch_heatmaps) 51 | 52 | heatmap_height = batch_heatmaps.shape[2] 53 | heatmap_width = batch_heatmaps.shape[3] 54 | 55 | # post-processing 56 | if config.TEST.POST_PROCESS: 57 | for n in range(coords.shape[0]): 58 | for p in 
range(coords.shape[1]): 59 | hm = batch_heatmaps[n][p] 60 | px = int(math.floor(coords[n][p][0] + 0.5)) 61 | py = int(math.floor(coords[n][p][1] + 0.5)) 62 | if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1: 63 | diff = np.array( 64 | [ 65 | hm[py][px+1] - hm[py][px-1], 66 | hm[py+1][px]-hm[py-1][px] 67 | ] 68 | ) 69 | coords[n][p] += np.sign(diff) * .25 70 | 71 | preds = coords.copy() 72 | 73 | # Transform back 74 | for i in range(coords.shape[0]): 75 | preds[i] = transform_preds( 76 | coords[i], center[i], scale[i], [heatmap_width, heatmap_height] 77 | ) 78 | 79 | return preds, maxvals 80 | 81 | 82 | def get_final_preds_const(batch_heatmaps, center, scale): 83 | coords, maxvals = get_max_preds(batch_heatmaps) 84 | 85 | heatmap_height = batch_heatmaps.shape[2] 86 | heatmap_width = batch_heatmaps.shape[3] 87 | 88 | # post-processing 89 | preds = coords.copy() 90 | 91 | # Transform back 92 | for i in range(coords.shape[0]): 93 | preds[i] = transform_preds( 94 | coords[i], center[i], scale[i], [heatmap_width, heatmap_height] 95 | ) 96 | 97 | return preds, maxvals -------------------------------------------------------------------------------- /lib/utils/vis_skeleton.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torchvision.transforms import transforms 4 | 5 | 6 | def cv2_plot_lines(frame, pts, order): 7 | color_mapping = {1: [255, 0, 255], 2: [255, 0, 0], 3: [255, 0, 127], 4: [255, 255, 255], 5: [0, 0, 255], 8 | 6: [0, 127, 255], 7: [0, 255, 255], 8: [0, 255, 0], 9: [200, 162, 200]} 9 | # point_size = 7 10 | point_size = 2 11 | if order == 0: 12 | # other animals 13 | # plot nose-eyes 14 | cv2.line(frame, (pts[2, 0], pts[2, 1]), (pts[0, 0], pts[0, 1]), color_mapping[5], point_size) 15 | cv2.line(frame, (pts[2, 0], pts[2, 1]), (pts[1, 0], pts[1, 1]), color_mapping[5], point_size) 16 | cv2.line(frame, (pts[0, 0], pts[0, 1]), (pts[1, 0], pts[1, 1]), color_mapping[5], point_size) 17 | 18 | # plot neck and nose 19 | # cv2.line(frame, (pts[2, 0], pts[2, 1]), (pts[3, 0], pts[3, 1]), color_mapping[8], point_size) 20 | 21 | # plot neck and base tail 22 | cv2.line(frame, (pts[4, 0], pts[4, 1]), (pts[3, 0], pts[3, 1]), color_mapping[8], point_size) 23 | 24 | # plot left front leg 25 | cv2.line(frame, (pts[3, 0], pts[3, 1]), (pts[5, 0], pts[5, 1]), color_mapping[1], point_size) 26 | cv2.line(frame, (pts[5, 0], pts[5, 1]), (pts[6, 0], pts[6, 1]), color_mapping[1], point_size) 27 | cv2.line(frame, (pts[7, 0], pts[7, 1]), (pts[6, 0], pts[6, 1]), color_mapping[1], point_size) 28 | 29 | # plot right front leg 30 | cv2.line(frame, (pts[3, 0], pts[3, 1]), (pts[8, 0], pts[8, 1]), color_mapping[2], point_size) 31 | cv2.line(frame, (pts[8, 0], pts[8, 1]), (pts[9, 0], pts[9, 1]), color_mapping[2], point_size) 32 | cv2.line(frame, (pts[10, 0], pts[10, 1]), (pts[9, 0], pts[9, 1]), color_mapping[2], point_size) 33 | 34 | # plot left back leg 35 | cv2.line(frame, (pts[4, 0], pts[4, 1]), (pts[11, 0], pts[11, 1]), color_mapping[6], point_size) 36 | cv2.line(frame, (pts[12, 0], pts[12, 1]), (pts[11, 0], pts[11, 1]), color_mapping[6], point_size) 37 | cv2.line(frame, (pts[12, 0], pts[12, 1]), (pts[13, 0], pts[13, 1]), color_mapping[6], point_size) 38 | 39 | # plot right back leg 40 | cv2.line(frame, (pts[4, 0], pts[4, 1]), (pts[14, 0], pts[14, 1]), color_mapping[7], point_size) 41 | cv2.line(frame, (pts[15, 0], pts[15, 1]), (pts[14, 0], pts[14, 1]), color_mapping[7], point_size) 42 | cv2.line(frame, (pts[15, 0], pts[15, 1]), 
(pts[16, 0], pts[16, 1]), color_mapping[7], point_size) 43 | return frame 44 | 45 | 46 | def cv2_visualize_keypoints(frames, pts, savepath, idx, num_pts=17, order=0): 47 | inv_normalize = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]), 48 | transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.])]) 49 | inputs = inv_normalize(frames) 50 | inputs = inputs.numpy().transpose(0, 2, 3, 1) 51 | for b in range(inputs.shape[0]): 52 | frame = np.uint8(inputs[b].copy() * 255) 53 | kpt = pts[b].astype(np.int) 54 | x = [] 55 | y = [] 56 | for i in range(num_pts): 57 | x.append(kpt[i, 0]) 58 | y.append(kpt[i, 1]) 59 | # plot keypoints on each image 60 | cv2.circle(frame, (x[-1], y[-1]), 2, (0, 255, 0), -1) 61 | frame = cv2_plot_lines(frame, kpt, order) 62 | # cv2.imshow('frame', frame) 63 | # cv2.waitKey(0) 64 | cv2.imwrite(savepath + str(idx + b) + '.jpg', frame) -------------------------------------------------------------------------------- /data/label_list/annotation_list_Bovidae: -------------------------------------------------------------------------------- 1 | [1, 2, 4, 34, 37, 38, 46, 49, 58, 70, 71, 582, 82, 83, 84, 90, 92, 604, 93, 96, 99, 101, 103, 104, 105, 106, 107, 108, 114, 115, 116, 117, 118, 119, 121, 122, 123, 125, 126, 129, 130, 132, 133, 134, 135, 136, 137, 138, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 153, 154, 155, 156, 164, 165, 167, 169, 170, 171, 172, 173, 177, 178, 180, 181, 183, 184, 185, 187, 190, 191, 192, 195, 196, 198, 199, 200, 201, 202, 715, 204, 203, 206, 207, 208, 209, 210, 212, 213, 726, 215, 216, 217, 220, 221, 223, 224, 225, 226, 227, 228, 230, 231, 232, 233, 234, 237, 238, 239, 240, 241, 242, 243, 249, 260, 271, 282, 315, 848, 337, 859, 348, 359, 371, 382, 404, 437, 448, 970, 992, 482, 1025, 1027, 1028, 1030, 1031, 1033, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1047, 1049, 1050, 1052, 1053, 1058, 1060, 1061, 1062, 1064, 1067, 1068, 1069, 1072, 1073, 1075, 1076, 1079, 1082, 1083, 1084, 1089, 1091, 1092, 1094, 1095, 1096, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1106, 1107, 1108, 1109, 1112, 1546, 1557, 1558, 1580, 1591, 1613, 1113, 1114, 1115, 1117, 1118, 1119, 1120, 1122, 1123, 1126, 1128, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1646, 1138, 1139, 1141, 1143, 1144, 1657, 1145, 1147, 1146, 1149, 1150, 1151, 1154, 1669, 1158, 1160, 1162, 1164, 1165, 1166, 1167, 1680, 1168, 1169, 1171, 1172, 1174, 1175, 1176, 1177, 1178, 1691, 1180, 1181, 1182, 1179, 1183, 1185, 1187, 1702, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1713, 1202, 1203, 1205, 1206, 1208, 1209, 1211, 1213, 1214, 1215, 1217, 1218, 1219, 1220, 1221, 1222, 1735, 1224, 1225, 1227, 1228, 1746, 1236, 1757, 1768, 1258, 1779, 1269, 1791, 1280, 1802, 1813, 1302, 1835, 1324, 1846, 1335, 1336, 1347, 1358, 1380, 1902, 1402, 1924, 1413, 1935, 1424, 1435, 1446, 1967, 1969, 1458, 1971, 1972, 1970, 1974, 1975, 1976, 1977, 1978, 1491, 1502, 1513, 1524, 1535, 2048, 2049, 2051, 2059, 2061, 2063, 2065, 2067, 2069, 2072, 2073, 2074, 2075, 2076, 2078, 2079, 2080, 2081, 2083, 2084, 2085, 2087, 2088, 2089, 2090, 2091, 2092, 2093, 2094, 2095, 2096, 2097, 2098, 2099, 2100, 2104, 2105, 2106, 2108, 2109, 2111, 2112, 2113, 2114, 2117, 2118, 2119, 2120, 2121, 2123, 2124, 2127, 2128, 2132, 2133, 2135, 2137, 2138, 2139, 2140, 2141, 2142, 2144, 2146, 2150, 2151, 2153, 2154, 2156, 2157, 2158, 2159, 2160, 2161, 2164, 2165, 2166, 2167, 2168, 2170, 2172, 2173, 2174, 2176, 2177, 2178, 2179, 2181, 2182, 2183, 2185, 2188, 2189, 2190, 2191, 
2193, 2196, 1980, 1981, 1983, 1984, 1986, 1988, 1991, 1992, 1993, 1995, 1997, 1998, 1999, 2002, 2003, 2004, 2005, 2007, 2008, 2009, 2011, 2013, 2015, 2016, 2017, 2019, 2022, 2023, 2024, 2025, 2027, 2029, 2030, 2033, 2035, 2036, 2037, 2038, 2039, 2040, 2045, 2046, 2047, 4608, 5633, 5122, 3594, 5644, 4109, 5133, 2577, 5655, 5144, 4633, 3099, 5666, 5156, 3109, 4645, 5677, 5167, 4657, 2610, 4665, 5178, 4670, 5189, 4682, 5211, 4706, 5219, 2665, 4718, 2677, 4730, 2688, 4225, 3717, 4744, 4237, 4238, 5267, 2710, 3223, 4250, 4768, 2721, 4262, 2729, 5289, 2732, 4781, 4274, 5300, 2754, 4805, 5322, 4299, 4817, 4311, 2776, 4829, 5344, 4323, 2788, 4841, 5355, 4331, 4336, 4853, 5366, 4855, 5377, 4867, 2821, 4360, 4361, 4879, 2832, 3346, 4373, 4892, 2332, 4385, 5411, 4904, 2344, 4397, 5422, 3887, 4916, 5433, 5444, 4422, 2377, 4943, 4434, 2388, 4952, 2399, 4447, 5477, 2410, 4459, 5489, 3442, 3443, 2421, 3963, 4989, 2433, 5511, 5000, 2443, 3471, 5522, 5011, 2455, 5022, 2975, 2466, 5544, 4521, 5033, 5555, 5044, 2488, 5056, 4545, 2499, 4559, 5588, 5589, 5078, 4571, 5599, 5089, 3554, 5096, 4595, 4086, 5111, 4607, 7680, 7692, 8210, 7187, 8222, 7199, 7712, 7714, 7718, 7720, 7213, 7728, 7730, 7734, 8246, 7736, 7740, 7745, 6722, 7237, 7752, 7753, 7756, 7759, 8271, 7249, 7763, 7765, 7769, 8285, 7261, 7773, 7776, 7778, 5734, 8297, 7274, 5739, 7791, 7793, 8309, 7286, 7298, 7816, 7310, 7841, 8357, 7335, 7336, 7853, 7344, 8369, 7348, 7865, 6842, 6843, 6845, 8382, 5822, 7360, 7877, 6855, 8395, 7372, 7377, 6354, 8404, 6868, 8408, 7385, 8409, 7901, 5856, 6880, 8421, 5862, 7397, 7913, 7914, 5867, 7915, 7409, 8445, 7939, 8457, 7434, 7952, 6928, 8469, 7446, 7964, 6953, 7466, 6965, 6966, 6967, 8000, 6979, 7496, 8012, 6478, 7508, 6488, 7003, 8036, 5989, 7015, 7532, 8050, 7027, 7545, 8063, 7039, 7557, 7052, 8087, 7064, 7076, 7588, 8111, 7088, 7089, 7607, 8123, 7100, 7102, 7619, 8135, 6600, 7631, 8147, 7643, 6108, 6111, 8159, 8161, 7656, 8174, 7668, 7163] -------------------------------------------------------------------------------- /lib/core/loss.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
 5 | # ------------------------------------------------------------------------------
 6 | 
 7 | from __future__ import absolute_import
 8 | from __future__ import division
 9 | from __future__ import print_function
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import numpy as np
14 | 
15 | 
16 | class JointsMSELoss(nn.Module):
17 |     def __init__(self, use_target_weight):
18 |         super(JointsMSELoss, self).__init__()
19 |         self.criterion = nn.MSELoss(reduction='mean')
20 |         self.use_target_weight = use_target_weight
21 | 
22 |     def forward(self, output, target, target_weight):
23 |         batch_size = output.size(0)
24 |         num_joints = output.size(1)
25 |         heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
26 |         heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
27 |         loss = 0
28 | 
29 |         for idx in range(num_joints):
30 |             heatmap_pred = heatmaps_pred[idx].squeeze()
31 |             heatmap_gt = heatmaps_gt[idx].squeeze()
32 |             if self.use_target_weight:
33 |                 loss += 0.5 * self.criterion(
34 |                     heatmap_pred.mul(target_weight[:, idx]),
35 |                     heatmap_gt.mul(target_weight[:, idx])
36 |                 )
37 |             else:
38 |                 loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
39 | 
40 |         return loss / num_joints
41 | 
42 | 
43 | criterion_mse = nn.MSELoss(reduction='none')
44 | 
45 | 
46 | def select_small_loss_samples_v2(output, target, target_weight, topk_rate):
47 |     batch_size = output.size(0)
48 |     number_joints = output.size(1)
49 |     num_visible_joints = torch.count_nonzero(target_weight)
50 |     num_small_loss_samples = int(num_visible_joints * topk_rate)
51 |     output_re = output.reshape(batch_size, number_joints, -1)
52 |     target_re = target.reshape(batch_size, number_joints, -1)
53 |     loss = criterion_mse(output_re.mul(target_weight), target_re.mul(target_weight)).mean(-1)
54 |     loss_max = loss.max() * torch.ones_like(loss)
55 |     weight = (target_weight > 0)
56 |     # set the loss for joints with weight 0 to a large number to avoid them being selected
57 |     loss = torch.where(weight.squeeze(-1), loss, loss_max)
58 |     dim_last = loss.size(-1)
59 |     _, topk_idx = torch.topk(loss.flatten(), k=num_small_loss_samples, largest=False)
60 |     topk_idx = topk_idx.unsqueeze(-1)
61 |     idx_re = torch.cat([topk_idx // dim_last, topk_idx % dim_last], dim=-1)
62 |     return idx_re
63 | 
64 | 
65 | class CurriculumLoss(nn.Module):
66 |     def __init__(self, use_target_weight=True):
67 |         super(CurriculumLoss, self).__init__()
68 |         self.criterion = nn.MSELoss(reduction='none')
69 |         self.use_target_weight = use_target_weight
70 | 
71 |     def forward(self, output, target, target_weight, top_k):
72 |         batch_size = output.size(0)
73 |         num_joints = output.size(1)
74 |         heatmaps_pred = output.reshape((batch_size, num_joints, -1))
75 |         heatmaps_gt = target.reshape((batch_size, num_joints, -1))
76 | 
77 |         if self.use_target_weight:
78 |             loss = 0.5 * (self.criterion(
79 |                 heatmaps_pred.mul(target_weight),
80 |                 heatmaps_gt.mul(target_weight)
81 |             )).mean(-1)
82 |         else:
83 |             loss = 0.5 * (self.criterion(heatmaps_pred, heatmaps_gt)).mean(-1)
84 |         weights_bool = (target_weight > 0)
85 |         loss_clone = loss.clone().detach().requires_grad_(False)
86 |         loss_inf = 1e8 * torch.ones_like(loss_clone, requires_grad=False)
87 |         # set the loss of invalid joints (weight 0) to a large value so that they won't be
88 |         # selected as reliable pseudo labels; only joints with smaller loss are selected
89 |         loss_clone = torch.where(weights_bool.squeeze(-1), loss_clone, loss_inf)
90 |         _, topk_idx = torch.topk(loss_clone, k=top_k, dim=-1,
-------------------------------------------------------------------------------- /lib/config/default.py: --------------------------------------------------------------------------------
1 |
2 | # ------------------------------------------------------------------------------
3 | # Copyright (c) Microsoft
4 | # Licensed under the MIT License.
5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 |
14 | from yacs.config import CfgNode as CN
15 |
16 |
17 | _C = CN()
18 |
19 | _C.OUTPUT_DIR = ''
20 | _C.LOG_DIR = ''
21 | _C.DATA_DIR = ''
22 | _C.GPUS = (0,)
23 | _C.WORKERS = 4
24 | _C.PRINT_FREQ = 20
25 | _C.AUTO_RESUME = False
26 | _C.PIN_MEMORY = True
27 | _C.RANK = 0
28 | _C.LABEL_PER_CLASS = 15
29 |
30 | # Cudnn related params
31 | _C.CUDNN = CN()
32 | _C.CUDNN.BENCHMARK = True
33 | _C.CUDNN.DETERMINISTIC = False
34 | _C.CUDNN.ENABLED = True
35 |
36 | # common params for NETWORK
37 | _C.MODEL = CN()
38 | _C.MODEL.NAME = 'pose_hrnet'
39 | _C.MODEL.INIT_WEIGHTS = True
40 | _C.MODEL.PRETRAINED = ''
41 | _C.MODEL.NUM_JOINTS = 17
42 | _C.MODEL.TAG_PER_JOINT = True
43 | _C.MODEL.TARGET_TYPE = 'gaussian'
44 | _C.MODEL.IMAGE_SIZE = [256, 256]  # width * height, ex: 192 * 256
45 | _C.MODEL.HEATMAP_SIZE = [64, 64]  # width * height, ex: 24 * 32
46 | _C.MODEL.SIGMA = 2
47 | _C.MODEL.EXTRA = CN(new_allowed=True)
48 |
49 | _C.LOSS = CN()
50 | _C.LOSS.USE_OHKM = False
51 | _C.LOSS.TOPK = 8
52 | _C.LOSS.USE_TARGET_WEIGHT = True
53 | _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
54 |
55 | # DATASET related params
56 | _C.DATASET = CN()
57 | _C.DATASET.ROOT = ''
58 | _C.DATASET.DATASET = 'mpii'
59 | _C.DATASET.TRAIN_SET = 'train'
60 | _C.DATASET.TEST_SET = 'test'
61 | _C.DATASET.VAL_SET = 'val'
62 | _C.DATASET.DATA_FORMAT = 'jpg'
63 | _C.DATASET.HYBRID_JOINTS_TYPE = ''
64 | _C.DATASET.SELECT_DATA = False
65 | _C.DATASET.SUPERCATEGORY = []
66 | # training data augmentation
67 | _C.DATASET.FLIP = True
68 | _C.DATASET.SCALE_FACTOR = 0.25
69 | _C.DATASET.ROT_FACTOR = 30
70 | _C.DATASET.PROB_HALF_BODY = 0.0
71 | _C.DATASET.NUM_JOINTS_HALF_BODY = 8
72 | _C.DATASET.COLOR_RGB = False
73 |
74 | # train
75 | _C.TRAIN = CN()
76 |
77 | _C.TRAIN.LR_FACTOR = 0.1
78 | _C.TRAIN.LR_STEP = [90, 110]
79 | _C.TRAIN.LR = 0.001
80 |
81 | _C.TRAIN.OPTIMIZER = 'adam'
82 | _C.TRAIN.MOMENTUM = 0.9
83 | _C.TRAIN.WD = 0.0001
84 | _C.TRAIN.NESTEROV = False
85 | _C.TRAIN.GAMMA1 = 0.99
86 | _C.TRAIN.GAMMA2 = 0.0
87 |
88 | _C.TRAIN.BEGIN_EPOCH = 0
89 | _C.TRAIN.END_EPOCH = 140
90 |
91 | _C.TRAIN.RESUME = False
92 | _C.TRAIN.CHECKPOINT = ''
93 |
94 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32
95 | _C.TRAIN.SHUFFLE = True
96 |
97 | # testing
98 | _C.TEST = CN()
99 |
100 | # size of images for each device
101 | _C.TEST.BATCH_SIZE_PER_GPU = 32
102 | # Test Model Epoch
103 | _C.TEST.FLIP_TEST = False
104 | _C.TEST.POST_PROCESS = False
105 | _C.TEST.SHIFT_HEATMAP = False
106 |
107 | _C.TEST.USE_GT_BBOX = False
108 |
109 | # nms
110 | _C.TEST.IMAGE_THRE = 0.1
111 | _C.TEST.NMS_THRE = 0.6
112 | _C.TEST.SOFT_NMS = False
113 | _C.TEST.OKS_THRE = 0.5
114 | _C.TEST.IN_VIS_THRE = 0.0
115 | _C.TEST.COCO_BBOX_FILE = ''
116 | _C.TEST.BBOX_THRE = 1.0
117 | _C.TEST.MODEL_FILE = ''
118 |
119 | # 
debug 120 | _C.DEBUG = CN() 121 | _C.DEBUG.DEBUG = False 122 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 123 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 124 | _C.DEBUG.SAVE_HEATMAPS_GT = False 125 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 126 | 127 | 128 | def update_config(cfg, args): 129 | cfg.defrost() 130 | cfg.merge_from_file(args.cfg) 131 | cfg.merge_from_list(args.opts) 132 | 133 | if args.modelDir: 134 | cfg.OUTPUT_DIR = args.modelDir 135 | 136 | if args.logDir: 137 | cfg.LOG_DIR = args.logDir 138 | 139 | if args.dataDir: 140 | cfg.DATA_DIR = args.dataDir 141 | 142 | cfg.DATASET.ROOT = os.path.join( 143 | cfg.DATA_DIR, cfg.DATASET.ROOT 144 | ) 145 | 146 | cfg.MODEL.PRETRAINED = os.path.join( 147 | cfg.DATA_DIR, cfg.MODEL.PRETRAINED 148 | ) 149 | 150 | if cfg.TEST.MODEL_FILE: 151 | cfg.TEST.MODEL_FILE = os.path.join( 152 | cfg.DATA_DIR, cfg.TEST.MODEL_FILE 153 | ) 154 | 155 | cfg.freeze() 156 | 157 | 158 | if __name__ == '__main__': 159 | import sys 160 | with open(sys.argv[1], 'w') as f: 161 | print(_C, file=f) 162 | 163 | -------------------------------------------------------------------------------- /lib/dataset_animal/dataset_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | 4 | 5 | class DatasetInfo: 6 | 7 | def __init__(self, dataset_info): 8 | self._dataset_info = dataset_info 9 | self.dataset_name = self._dataset_info['dataset_name'] 10 | self.paper_info = self._dataset_info['paper_info'] 11 | self.keypoint_info = self._dataset_info['keypoint_info'] 12 | self.skeleton_info = self._dataset_info['skeleton_info'] 13 | self.joint_weights = np.array( 14 | self._dataset_info['joint_weights'], dtype=np.float32)[:, None] 15 | 16 | self.sigmas = np.array(self._dataset_info['sigmas']) 17 | 18 | self._parse_keypoint_info() 19 | self._parse_skeleton_info() 20 | 21 | def _parse_skeleton_info(self): 22 | """Parse skeleton information. 23 | 24 | - link_num (int): number of links. 25 | - skeleton (list((2,))): list of links (id). 26 | - skeleton_name (list((2,))): list of links (name). 27 | - pose_link_color (np.ndarray): the color of the link for 28 | visualization. 29 | """ 30 | self.link_num = len(self.skeleton_info.keys()) 31 | self.pose_link_color = [] 32 | 33 | self.skeleton_name = [] 34 | self.skeleton = [] 35 | for skid in self.skeleton_info.keys(): 36 | link = self.skeleton_info[skid]['link'] 37 | self.skeleton_name.append(link) 38 | self.skeleton.append([ 39 | self.keypoint_name2id[link[0]], self.keypoint_name2id[link[1]] 40 | ]) 41 | self.pose_link_color.append(self.skeleton_info[skid].get( 42 | 'color', [255, 128, 0])) 43 | self.pose_link_color = np.array(self.pose_link_color) 44 | 45 | def _parse_keypoint_info(self): 46 | """Parse keypoint information. 47 | 48 | - keypoint_num (int): number of keypoints. 49 | - keypoint_id2name (dict): mapping keypoint id to keypoint name. 50 | - keypoint_name2id (dict): mapping keypoint name to keypoint id. 51 | - upper_body_ids (list): a list of keypoints that belong to the 52 | upper body. 53 | - lower_body_ids (list): a list of keypoints that belong to the 54 | lower body. 55 | - flip_index (list): list of flip index (id) 56 | - flip_pairs (list((2,))): list of flip pairs (id) 57 | - flip_index_name (list): list of flip index (name) 58 | - flip_pairs_name (list((2,))): list of flip pairs (name) 59 | - pose_kpt_color (np.ndarray): the color of the keypoint for 60 | visualization. 
61 | """ 62 | 63 | self.keypoint_num = len(self.keypoint_info.keys()) 64 | self.keypoint_id2name = {} 65 | self.keypoint_name2id = {} 66 | 67 | self.pose_kpt_color = [] 68 | self.upper_body_ids = [] 69 | self.lower_body_ids = [] 70 | 71 | self.flip_index_name = [] 72 | self.flip_pairs_name = [] 73 | 74 | for kid in self.keypoint_info.keys(): 75 | 76 | keypoint_name = self.keypoint_info[kid]['name'] 77 | self.keypoint_id2name[kid] = keypoint_name 78 | self.keypoint_name2id[keypoint_name] = kid 79 | self.pose_kpt_color.append(self.keypoint_info[kid].get( 80 | 'color', [255, 128, 0])) 81 | 82 | type = self.keypoint_info[kid].get('type', '') 83 | if type == 'upper': 84 | self.upper_body_ids.append(kid) 85 | elif type == 'lower': 86 | self.lower_body_ids.append(kid) 87 | else: 88 | pass 89 | 90 | swap_keypoint = self.keypoint_info[kid].get('swap', '') 91 | if swap_keypoint == keypoint_name or swap_keypoint == '': 92 | self.flip_index_name.append(keypoint_name) 93 | else: 94 | self.flip_index_name.append(swap_keypoint) 95 | if [swap_keypoint, keypoint_name] not in self.flip_pairs_name: 96 | self.flip_pairs_name.append([keypoint_name, swap_keypoint]) 97 | 98 | self.flip_pairs = [[ 99 | self.keypoint_name2id[pair[0]], self.keypoint_name2id[pair[1]] 100 | ] for pair in self.flip_pairs_name] 101 | self.flip_index = [ 102 | self.keypoint_name2id[name] for name in self.flip_index_name 103 | ] 104 | self.pose_kpt_color = np.array(self.pose_kpt_color) 105 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import os 8 | import pprint 9 | 10 | import torch 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torchvision.transforms as transforms 17 | 18 | import _init_paths 19 | from config import cfg 20 | from config import update_config 21 | from core.loss import JointsMSELoss 22 | from core.function import validate 23 | from utils.utils import create_logger 24 | import models 25 | import dataset_animal 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Train keypoints network') 29 | # general 30 | parser.add_argument('--cfg', 31 | help='experiment configure file name', 32 | required=True, 33 | type=str) 34 | 35 | parser.add_argument('opts', 36 | help="Modify config options using the command-line", 37 | default=None, 38 | nargs=argparse.REMAINDER) 39 | 40 | parser.add_argument('--modelDir', 41 | help='model directory', 42 | type=str, 43 | default='') 44 | parser.add_argument('--logDir', 45 | help='log directory', 46 | type=str, 47 | default='') 48 | parser.add_argument('--dataDir', 49 | help='data directory', 50 | type=str, 51 | default='') 52 | parser.add_argument('--prevModelDir', 53 | help='prev Model directory', 54 | type=str, 55 | default='') 56 | parser.add_argument('--animalpose', 57 | help='train on ap10k', 58 | action='store_true') 59 | parser.add_argument('--vis', action='store_true') 60 | 61 | 62 | 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | def main(): 68 | args = parse_args() 69 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 70 | update_config(cfg, args) 71 | 72 | logger, final_output_dir, tb_log_dir = create_logger( 
73 | cfg, args.cfg, 'test') 74 | 75 | logger.info(pprint.pformat(args)) 76 | logger.info(cfg) 77 | 78 | # cudnn related setting 79 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 80 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 81 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 82 | 83 | model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( 84 | cfg, is_train=False 85 | ) 86 | 87 | if cfg.TEST.MODEL_FILE: 88 | logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) 89 | checkpoint = torch.load(cfg.TEST.MODEL_FILE) 90 | model.load_state_dict(checkpoint['state_dict'], strict=True) 91 | else: 92 | model_state_file = os.path.join( 93 | final_output_dir, 'final_state.pth' 94 | ) 95 | logger.info('=> loading model from {}'.format(model_state_file)) 96 | model.load_state_dict(torch.load(model_state_file)) 97 | 98 | model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda() 99 | 100 | # define loss function (criterion) and optimizer 101 | criterion = JointsMSELoss( 102 | use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT 103 | ).cuda() 104 | 105 | # Data loading code 106 | normalize = transforms.Normalize( 107 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 108 | ) 109 | 110 | if args.animalpose: 111 | valid_dataset = eval('dataset_animal.' + cfg.DATASET.DATASET)( 112 | cfg, cfg.DATASET.ROOT, cfg.DATASET.VAL_SET, False, 113 | transforms.Compose([ 114 | transforms.ToTensor(), 115 | normalize, 116 | ]) 117 | ) 118 | else: 119 | valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)( 120 | cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, 121 | transforms.Compose([ 122 | transforms.ToTensor(), 123 | normalize, 124 | ]) 125 | ) 126 | valid_loader = torch.utils.data.DataLoader( 127 | valid_dataset, 128 | batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS), 129 | shuffle=False, 130 | num_workers=cfg.WORKERS, 131 | pin_memory=True 132 | ) 133 | 134 | # evaluate on validation set 135 | validate(cfg, valid_loader, valid_dataset, model, criterion, 136 | final_output_dir, tb_log_dir, animalpose=args.animalpose, vis=args.vis) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /data/label_list/annotation_list_15: -------------------------------------------------------------------------------- 1 | [30106, 27825, 26981, 29546, 29189, 30261, 26814, 29374, 26976, 29962, 29485, 26792, 27865, 26881, 27025, 30418, 31044, 30884, 31666, 30806, 31355, 30954, 30980, 30978, 31199, 30957, 30975, 30999, 30940, 31832, 49302, 49257, 49308, 49236, 49317, 49399, 49274, 49285, 49405, 49315, 49363, 49337, 49283, 49322, 49323, 50037, 50190, 50108, 50132, 50218, 50049, 50214, 50026, 50135, 50018, 50162, 50034, 50136, 50238, 50003, 19969, 19989, 19886, 20002, 19937, 19838, 19832, 19897, 19917, 19870, 19913, 19848, 19983, 19925, 19939, 32046, 31906, 31924, 32044, 32020, 31913, 32050, 32052, 31950, 31956, 32078, 32035, 32033, 32063, 31985, 40246, 33006, 36206, 40783, 40405, 32973, 33505, 40825, 32739, 33739, 35517, 32895, 33283, 34727, 35272, 36331, 36338, 36351, 36485, 36454, 36422, 36340, 36419, 36393, 36373, 36387, 36500, 36408, 36434, 36409, 37658, 37859, 37626, 37861, 37606, 37642, 37733, 37532, 37559, 37785, 37764, 37843, 37679, 37792, 37648, 37888, 37867, 37876, 37895, 37878, 37881, 37866, 37871, 37877, 37885, 37893, 37889, 37886, 37894, 37891, 38697, 38497, 37993, 38710, 37997, 38047, 38021, 38003, 38018, 38330, 37974, 38564, 38019, 38231, 37957, 38869, 38859, 38785, 38773, 38726, 38821, 38879, 
38850, 38762, 38775, 38882, 38862, 38783, 38797, 38736, 39808, 39817, 39919, 39923, 39856, 39811, 39870, 39840, 39746, 39892, 39841, 39778, 39770, 39757, 39804, 40905, 40906, 40849, 40845, 40901, 40869, 40880, 40913, 40930, 40870, 40872, 40860, 40923, 40846, 40868, 41015, 41040, 40971, 41009, 41135, 41034, 41049, 40974, 41067, 40941, 41139, 41024, 40947, 41057, 41012, 44904, 44689, 44910, 44286, 44777, 44411, 44339, 44265, 44357, 44364, 44300, 44285, 44711, 44566, 44257, 604, 238, 46, 135, 71, 243, 232, 185, 104, 105, 187, 207, 150, 165, 154, 1099, 1053, 1075, 1112, 1084, 1025, 1030, 1028, 1098, 1089, 1060, 1091, 1033, 1044, 1038, 1143, 1144, 1164, 1145, 1280, 1135, 1213, 1228, 1146, 1160, 1335, 1209, 1220, 1168, 1149, 2158, 2023, 1999, 2083, 2151, 2093, 2119, 2081, 2111, 2160, 2173, 2048, 2027, 2141, 2159, 4682, 4607, 5111, 3223, 5219, 3443, 5589, 4238, 5655, 2821, 4434, 4336, 2710, 2688, 5544, 8395, 7199, 7914, 7776, 7003, 7557, 8012, 8147, 7714, 6868, 7736, 7286, 8408, 7039, 6111, 48688, 48592, 48582, 48597, 48596, 48606, 48578, 48637, 48708, 48647, 48587, 48589, 48717, 48569, 48584, 20860, 20778, 20885, 20762, 20805, 20814, 20720, 20759, 20848, 20764, 20692, 20782, 20861, 20727, 20741, 22631, 22072, 22407, 22053, 22041, 22638, 22076, 22081, 22658, 22070, 22042, 22396, 22207, 22114, 22440, 50780, 50797, 50821, 50912, 50853, 50870, 51163, 51256, 50815, 51192, 50847, 51151, 51218, 50790, 51103, 23624, 23602, 23512, 23635, 23523, 23623, 23550, 23690, 23695, 23625, 23584, 23448, 23471, 23687, 23706, 45079, 45065, 45036, 45028, 45033, 45025, 45003, 44927, 45070, 45051, 45112, 44997, 45076, 45002, 44966, 46119, 46177, 46181, 46166, 46164, 46261, 46105, 46222, 46121, 46151, 46251, 46156, 46291, 46140, 46265, 47530, 47491, 47628, 47517, 47469, 47643, 47594, 47588, 47555, 47492, 47639, 47485, 47464, 47587, 47470, 54986, 54974, 54875, 54944, 54845, 54988, 54855, 54948, 54939, 54953, 54967, 55006, 54979, 54852, 54917, 10276, 17391, 10345, 14968, 10360, 9279, 8828, 9330, 11036, 8814, 18406, 8884, 10454, 18760, 18330, 17653, 17714, 17707, 17990, 17890, 18288, 17645, 17733, 18156, 18279, 17751, 18201, 17669, 17708, 17740, 19326, 19787, 19808, 19719, 19619, 19342, 19799, 19306, 19288, 19329, 19308, 19788, 19375, 19272, 19541, 22937, 22716, 22696, 22679, 22701, 23413, 22717, 23015, 22661, 22766, 23420, 22729, 22710, 22765, 23004, 43091, 43262, 43244, 43089, 43126, 43161, 43266, 43182, 43154, 43087, 43169, 43224, 43106, 43160, 43206, 50366, 50379, 50389, 50382, 50415, 50300, 50363, 50279, 50328, 50400, 50426, 50394, 50285, 50309, 50374, 20022, 20013, 20051, 20055, 20045, 20076, 20065, 20067, 20018, 20058, 20062, 20023, 20068, 20069, 20072, 20186, 20183, 20082, 20119, 20212, 20128, 20211, 20103, 20208, 20166, 20141, 20152, 20081, 20135, 20198, 20285, 20308, 20303, 20259, 20318, 20323, 20266, 20292, 20276, 20312, 20298, 20249, 20253, 20263, 20277, 20552, 20541, 20478, 20371, 20374, 20587, 20385, 20376, 20416, 20520, 20369, 20384, 20397, 20439, 20480, 20606, 20638, 20684, 20641, 20675, 20598, 20624, 20669, 20672, 20613, 20649, 20656, 20634, 20631, 20604, 48754, 48793, 48808, 48801, 48826, 48780, 48807, 48734, 48747, 48797, 48874, 48790, 48908, 48758, 48761, 49102, 48940, 48911, 49091, 49058, 48983, 48924, 49011, 49012, 49095, 49003, 49025, 48919, 49063, 49028, 51885, 51872, 51801, 51804, 51816, 51837, 51846, 51865, 51841, 51838, 51845, 51811, 51829, 51819, 51859, 51985, 51962, 52030, 52026, 51906, 52145, 51963, 52102, 51919, 51973, 51924, 52132, 52156, 51914, 52100, 55660, 55651, 55624, 55667, 55645, 
55631, 55661, 55625, 55647, 55630, 55640, 55662, 55669, 55663, 55664, 55757, 55758, 55873, 55771, 55822, 55901, 55710, 55775, 55699, 55803, 55802, 55704, 55879, 55705, 55795, 56634, 56567, 56606, 56716, 56694, 56736, 56713, 56728, 56638, 56691, 56621, 56622, 56677, 56629, 56684, 58430, 58573, 58564, 58517, 58514, 58528, 58537, 58488, 58439, 58619, 58543, 58511, 58475, 58522, 58463] -------------------------------------------------------------------------------- /lib/utils/consistency.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import cv2 4 | import torchvision.transforms as transforms 5 | from core.inference import get_final_preds_const, get_final_preds 6 | from utils.transforms import get_affine_transform, flip_back 7 | 8 | 9 | normalize = transforms.Normalize( 10 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 11 | ) 12 | trans_inp = transforms.Compose([ 13 | transforms.ToTensor(), 14 | normalize, 15 | ]) 16 | 17 | 18 | def prediction_check(inp, model, dataset, c_ori, s_ori, num_transform=1, num_kpts=17): 19 | s0 = np.array([256/200.0, 256/200.0], dtype=np.float32) 20 | sf = 0.25 21 | rf = 30 22 | c = np.array([128, 128]) 23 | image_size = np.array([256, 256]) 24 | score_map_avg = np.zeros((1, num_kpts, 64, 64)) 25 | 26 | for i in range(num_transform): 27 | img = inp.clone().numpy() 28 | if i == 0: 29 | s = s0 30 | r = 0 31 | else: 32 | s = s0 * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) 33 | r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) 34 | trans = get_affine_transform(c, s, r, image_size) 35 | img = cv2.warpAffine( 36 | img, 37 | trans, 38 | (int(image_size[0]), int(image_size[1])), 39 | flags=cv2.INTER_LINEAR) 40 | 41 | input = (trans_inp(img)).unsqueeze(0) 42 | outputs, _ = model(input.cuda()) 43 | score_map = outputs[-1] if isinstance(outputs, list) else outputs 44 | feat_map = score_map.squeeze(0).detach().cpu().numpy() 45 | 46 | flip_input = input.flip(3) 47 | flip_output, _ = model(flip_input.cuda()) 48 | flip_output_re = flip_back(flip_output.detach().cpu().numpy(), 49 | dataset.flip_pairs) 50 | feat_map += np.squeeze(flip_output_re) 51 | feat_map /= 2 52 | M = cv2.getRotationMatrix2D((32, 32), -r, 1) 53 | feat_map = cv2.warpAffine(feat_map.transpose(1, 2, 0), M, (64, 64)) 54 | feat_map = cv2.resize(feat_map, None, fx=s[0]*200.0/256.0, fy=s[1]*200.0/256.0, interpolation=cv2.INTER_LINEAR) 55 | if feat_map.shape[0] < 64: 56 | start = 32 - feat_map.shape[0]//2 57 | end = start + feat_map.shape[0] 58 | score_map_avg[0][:, start:end, start:end] += feat_map.transpose(2, 0, 1) 59 | else: 60 | start = feat_map.shape[0]//2 - 32 61 | end = feat_map.shape[0]//2 + 32 62 | score_map_avg[0] += feat_map[start:end, start:end].transpose(2, 0, 1) 63 | 64 | score_map_avg = score_map_avg/num_transform 65 | confidence_score = np.max(score_map_avg, axis=(0, 2, 3)) 66 | 67 | confidence = confidence_score.astype(np.float32) 68 | preds, _ = get_final_preds_const(score_map_avg, c_ori, s_ori) 69 | generated_kpts = np.zeros((num_kpts, 3)).astype(np.float32) 70 | generated_kpts[:, :2] = preds[0, :, :2] 71 | generated_kpts[:, 2] = confidence 72 | return generated_kpts, score_map_avg 73 | 74 | 75 | def generate_target(joints, joints_vis): 76 | ''' 77 | :param joints: [num_joints, 3] 78 | :param joints_vis: [num_joints, 3] 79 | :return: target, target_weight(1: visible, 0: invisible) 80 | ''' 81 | num_joints = 17 82 | image_size = np.array([256, 256]) 83 | heatmap_size = np.array([64, 64]) 84 | sigma = 2 85 | 
target_weight = np.ones((num_joints, 1), dtype=np.float32) 86 | target_weight[:, 0] = joints_vis[:, 0] 87 | 88 | target = np.zeros((num_joints, heatmap_size[0], heatmap_size[1]), dtype=np.float32) 89 | 90 | tmp_size = sigma * 3 91 | 92 | for joint_id in range(17): 93 | feat_stride = image_size / heatmap_size 94 | mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5) 95 | mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5) 96 | # Check that any part of the gaussian is in-bounds 97 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 98 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] 99 | if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \ 100 | or br[0] < 0 or br[1] < 0: 101 | # If not, just return the image as is 102 | target_weight[joint_id] = 0 103 | continue 104 | 105 | # # Generate gaussian 106 | size = 2 * tmp_size + 1 107 | x = np.arange(0, size, 1, np.float32) 108 | y = x[:, np.newaxis] 109 | x0 = y0 = size // 2 110 | # The gaussian is not normalized, we want the center value to equal 1 111 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 112 | 113 | # Usable gaussian range 114 | g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0] 115 | g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1] 116 | # Image range 117 | img_x = max(0, ul[0]), min(br[0], heatmap_size[0]) 118 | img_y = max(0, ul[1]), min(br[1], heatmap_size[1]) 119 | 120 | v = target_weight[joint_id] 121 | if v > 0.5: 122 | target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \ 123 | g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 124 | 125 | return np.expand_dims(target, axis=0), target_weight -------------------------------------------------------------------------------- /lib/dataset_animal/ap10k_info.py: -------------------------------------------------------------------------------- 1 | dataset_info = dict( 2 | dataset_name='ap10k', 3 | paper_info=dict( 4 | author='Yu, Hang and Xu, Yufei and Zhang, Jing and ' 5 | 'Zhao, Wei and Guan, Ziyu and Tao, Dacheng', 6 | title='AP-10K: A Benchmark for Animal Pose Estimation in the Wild', 7 | container='35th Conference on Neural Information Processing Systems ' 8 | '(NeurIPS 2021) Track on Datasets and Bench-marks.', 9 | year='2021', 10 | homepage='https://github.com/AlexTheBad/AP-10K', 11 | ), 12 | keypoint_info={ 13 | 0: 14 | dict( 15 | name='L_Eye', id=0, color=[0, 255, 0], type='upper', swap='R_Eye'), 16 | 1: 17 | dict( 18 | name='R_Eye', 19 | id=1, 20 | color=[255, 128, 0], 21 | type='upper', 22 | swap='L_Eye'), 23 | 2: 24 | dict(name='Nose', id=2, color=[51, 153, 255], type='upper', swap=''), 25 | 3: 26 | dict(name='Neck', id=3, color=[51, 153, 255], type='upper', swap=''), 27 | 4: 28 | dict( 29 | name='Root of tail', 30 | id=4, 31 | color=[51, 153, 255], 32 | type='lower', 33 | swap=''), 34 | 5: 35 | dict( 36 | name='L_Shoulder', 37 | id=5, 38 | color=[51, 153, 255], 39 | type='upper', 40 | swap='R_Shoulder'), 41 | 6: 42 | dict( 43 | name='L_Elbow', 44 | id=6, 45 | color=[51, 153, 255], 46 | type='upper', 47 | swap='R_Elbow'), 48 | 7: 49 | dict( 50 | name='L_F_Paw', 51 | id=7, 52 | color=[0, 255, 0], 53 | type='upper', 54 | swap='R_F_Paw'), 55 | 8: 56 | dict( 57 | name='R_Shoulder', 58 | id=8, 59 | color=[0, 255, 0], 60 | type='upper', 61 | swap='L_Shoulder'), 62 | 9: 63 | dict( 64 | name='R_Elbow', 65 | id=9, 66 | color=[255, 128, 0], 67 | type='upper', 68 | swap='L_Elbow'), 69 | 10: 70 | dict( 71 | name='R_F_Paw', 72 | id=10, 73 | color=[0, 255, 0], 74 | type='lower', 75 | swap='L_F_Paw'), 76 | 11: 77 | dict( 78 | name='L_Hip', 79 | 
id=11,
80 |         color=[255, 128, 0],
81 |         type='lower',
82 |         swap='R_Hip'),
83 |     12:
84 |     dict(
85 |         name='L_Knee',
86 |         id=12,
87 |         color=[255, 128, 0],
88 |         type='lower',
89 |         swap='R_Knee'),
90 |     13:
91 |     dict(
92 |         name='L_B_Paw',
93 |         id=13,
94 |         color=[0, 255, 0],
95 |         type='lower',
96 |         swap='R_B_Paw'),
97 |     14:
98 |     dict(
99 |         name='R_Hip', id=14, color=[0, 255, 0], type='lower',
100 |         swap='L_Hip'),
101 |     15:
102 |     dict(
103 |         name='R_Knee',
104 |         id=15,
105 |         color=[0, 255, 0],
106 |         type='lower',
107 |         swap='L_Knee'),
108 |     16:
109 |     dict(
110 |         name='R_B_Paw',
111 |         id=16,
112 |         color=[0, 255, 0],
113 |         type='lower',
114 |         swap='L_B_Paw'),
115 |     },
116 |     skeleton_info={
117 |         0: dict(link=('L_Eye', 'R_Eye'), id=0, color=[0, 0, 255]),
118 |         1: dict(link=('L_Eye', 'Nose'), id=1, color=[0, 0, 255]),
119 |         2: dict(link=('R_Eye', 'Nose'), id=2, color=[0, 0, 255]),
120 |         3: dict(link=('Nose', 'Neck'), id=3, color=[0, 255, 0]),
121 |         4: dict(link=('Neck', 'Root of tail'), id=4, color=[0, 255, 0]),
122 |         5: dict(link=('Neck', 'L_Shoulder'), id=5, color=[0, 255, 255]),
123 |         6: dict(link=('L_Shoulder', 'L_Elbow'), id=6, color=[0, 255, 255]),
124 |         7: dict(link=('L_Elbow', 'L_F_Paw'), id=7, color=[0, 255, 255]),
125 |         8: dict(link=('Neck', 'R_Shoulder'), id=8, color=[6, 156, 250]),
126 |         9: dict(link=('R_Shoulder', 'R_Elbow'), id=9, color=[6, 156, 250]),
127 |         10: dict(link=('R_Elbow', 'R_F_Paw'), id=10, color=[6, 156, 250]),
128 |         11: dict(link=('Root of tail', 'L_Hip'), id=11, color=[0, 255, 255]),
129 |         12: dict(link=('L_Hip', 'L_Knee'), id=12, color=[0, 255, 255]),
130 |         13: dict(link=('L_Knee', 'L_B_Paw'), id=13, color=[0, 255, 255]),
131 |         14: dict(link=('Root of tail', 'R_Hip'), id=14, color=[6, 156, 250]),
132 |         15: dict(link=('R_Hip', 'R_Knee'), id=15, color=[6, 156, 250]),
133 |         16: dict(link=('R_Knee', 'R_B_Paw'), id=16, color=[6, 156, 250]),
134 |     },
135 |     joint_weights=[
136 |         1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5,
137 |         1.5
138 |     ],
139 |
140 |     # Note: The original paper did not provide enough information about
141 |     # the sigmas.
We modified from 'https://github.com/cocodataset/'
142 |     # 'cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L523'
143 |     sigmas=[
144 |         0.025, 0.025, 0.026, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062,
145 |         0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089
146 |     ])
147 |
-------------------------------------------------------------------------------- /lib/nms/nms_kernel.cu: --------------------------------------------------------------------------------
1 | // ------------------------------------------------------------------
2 | // Copyright (c) Microsoft
3 | // Licensed under The MIT License
4 | // Modified from MATLAB Faster R-CNN (https://github.com/shaoqingren/faster_rcnn)
5 | // ------------------------------------------------------------------
6 |
7 | #include "gpu_nms.hpp"
8 | #include <vector>
9 | #include <iostream>
10 |
11 | #define CUDA_CHECK(condition) \
12 |   /* Code block avoids redefinition of cudaError_t error */ \
13 |   do { \
14 |     cudaError_t error = condition; \
15 |     if (error != cudaSuccess) { \
16 |       std::cout << cudaGetErrorString(error) << std::endl; \
17 |     } \
18 |   } while (0)
19 |
20 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
21 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
22 |
23 | __device__ inline float devIoU(float const * const a, float const * const b) {
24 |   float left = max(a[0], b[0]), right = min(a[2], b[2]);
25 |   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
26 |   float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
27 |   float interS = width * height;
28 |   float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
29 |   float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
30 |   return interS / (Sa + Sb - interS);
31 | }
32 |
33 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
34 |                            const float *dev_boxes, unsigned long long *dev_mask) {
35 |   const int row_start = blockIdx.y;
36 |   const int col_start = blockIdx.x;
37 |
38 |   // if (row_start > col_start) return;
39 |
40 |   const int row_size =
41 |         min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
42 |   const int col_size =
43 |         min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
44 |
45 |   __shared__ float block_boxes[threadsPerBlock * 5];
46 |   if (threadIdx.x < col_size) {
47 |     block_boxes[threadIdx.x * 5 + 0] =
48 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
49 |     block_boxes[threadIdx.x * 5 + 1] =
50 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
51 |     block_boxes[threadIdx.x * 5 + 2] =
52 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
53 |     block_boxes[threadIdx.x * 5 + 3] =
54 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
55 |     block_boxes[threadIdx.x * 5 + 4] =
56 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
57 |   }
58 |   __syncthreads();
59 |
60 |   if (threadIdx.x < row_size) {
61 |     const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
62 |     const float *cur_box = dev_boxes + cur_box_idx * 5;
63 |     int i = 0;
64 |     unsigned long long t = 0;
65 |     int start = 0;
66 |     if (row_start == col_start) {
67 |       start = threadIdx.x + 1;
68 |     }
69 |     for (i = start; i < col_size; i++) {
70 |       if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
71 |         t |= 1ULL << i;
72 |       }
73 |     }
74 |     const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
75 |     dev_mask[cur_box_idx * col_blocks + col_start] = t;
76 |   }
77 | }
78 |
79 | void _set_device(int device_id) {
80 |   int current_device;
81 |   CUDA_CHECK(cudaGetDevice(&current_device));
82 |   
if (current_device == device_id) {
83 |     return;
84 |   }
85 |   // The call to cudaSetDevice must come before any calls to Get, which
86 |   // may perform initialization using the GPU.
87 |   CUDA_CHECK(cudaSetDevice(device_id));
88 | }
89 |
90 | void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
91 |           int boxes_dim, float nms_overlap_thresh, int device_id) {
92 |   _set_device(device_id);
93 |
94 |   float* boxes_dev = NULL;
95 |   unsigned long long* mask_dev = NULL;
96 |
97 |   const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
98 |
99 |   CUDA_CHECK(cudaMalloc(&boxes_dev,
100 |                         boxes_num * boxes_dim * sizeof(float)));
101 |   CUDA_CHECK(cudaMemcpy(boxes_dev,
102 |                         boxes_host,
103 |                         boxes_num * boxes_dim * sizeof(float),
104 |                         cudaMemcpyHostToDevice));
105 |
106 |   CUDA_CHECK(cudaMalloc(&mask_dev,
107 |                         boxes_num * col_blocks * sizeof(unsigned long long)));
108 |
109 |   dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
110 |               DIVUP(boxes_num, threadsPerBlock));
111 |   dim3 threads(threadsPerBlock);
112 |   nms_kernel<<<blocks, threads>>>(boxes_num,
113 |                                   nms_overlap_thresh,
114 |                                   boxes_dev,
115 |                                   mask_dev);
116 |
117 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
118 |   CUDA_CHECK(cudaMemcpy(&mask_host[0],
119 |                         mask_dev,
120 |                         sizeof(unsigned long long) * boxes_num * col_blocks,
121 |                         cudaMemcpyDeviceToHost));
122 |
123 |   std::vector<unsigned long long> remv(col_blocks);
124 |   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
125 |
126 |   int num_to_keep = 0;
127 |   for (int i = 0; i < boxes_num; i++) {
128 |     int nblock = i / threadsPerBlock;
129 |     int inblock = i % threadsPerBlock;
130 |
131 |     if (!(remv[nblock] & (1ULL << inblock))) {
132 |       keep_out[num_to_keep++] = i;
133 |       unsigned long long *p = &mask_host[0] + i * col_blocks;
134 |       for (int j = nblock; j < col_blocks; j++) {
135 |         remv[j] |= p[j];
136 |       }
137 |     }
138 |   }
139 |   *num_out = num_to_keep;
140 |
141 |   CUDA_CHECK(cudaFree(boxes_dev));
142 |   CUDA_CHECK(cudaFree(mask_dev));
143 | }
144 |
-------------------------------------------------------------------------------- /lib/nms/setup_linux.py: --------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Pose.gluon
3 | # Copyright (c) 2018-present Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
6 | # --------------------------------------------------------
7 |
8 | import os
9 | from os.path import join as pjoin
10 | from setuptools import setup
11 | from distutils.extension import Extension
12 | from Cython.Distutils import build_ext
13 | import numpy as np
14 |
15 |
16 | def find_in_path(name, path):
17 |     "Find a file in a search path"
18 |     # Adapted from
19 |     # http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
20 |     for dir in path.split(os.pathsep):
21 |         binpath = pjoin(dir, name)
22 |         if os.path.exists(binpath):
23 |             return os.path.abspath(binpath)
24 |     return None
25 |
26 |
27 | def locate_cuda():
28 |     """Locate the CUDA environment on the system
29 |     Returns a dict with keys 'home', 'nvcc', 'include', and 'lib64'
30 |     and values giving the absolute path to each directory.
31 |     Starts by looking for the CUDAHOME env variable. If not found, everything
32 |     is based on finding 'nvcc' in the PATH.
33 | """ 34 | 35 | # first check if the CUDAHOME env variable is in use 36 | if 'CUDAHOME' in os.environ: 37 | home = os.environ['CUDAHOME'] 38 | nvcc = pjoin(home, 'bin', 'nvcc') 39 | else: 40 | # otherwise, search the PATH for NVCC 41 | default_path = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin') 42 | nvcc = find_in_path('nvcc', os.environ['PATH'] + os.pathsep + default_path) 43 | if nvcc is None: 44 | raise EnvironmentError('The nvcc binary could not be ' 45 | 'located in your $PATH. Either add it to your path, or set $CUDAHOME') 46 | home = os.path.dirname(os.path.dirname(nvcc)) 47 | 48 | cudaconfig = {'home':home, 'nvcc':nvcc, 49 | 'include': pjoin(home, 'include'), 50 | 'lib64': pjoin(home, 'lib64')} 51 | for k, v in cudaconfig.items(): 52 | if not os.path.exists(v): 53 | raise EnvironmentError('The CUDA %s path could not be located in %s' % (k, v)) 54 | 55 | return cudaconfig 56 | CUDA = locate_cuda() 57 | 58 | 59 | # Obtain the numpy include directory. This logic works across numpy versions. 60 | try: 61 | numpy_include = np.get_include() 62 | except AttributeError: 63 | numpy_include = np.get_numpy_include() 64 | 65 | 66 | def customize_compiler_for_nvcc(self): 67 | """inject deep into distutils to customize how the dispatch 68 | to gcc/nvcc works. 69 | If you subclass UnixCCompiler, it's not trivial to get your subclass 70 | injected in, and still have the right customizations (i.e. 71 | distutils.sysconfig.customize_compiler) run on it. So instead of going 72 | the OO route, I have this. Note, it's kindof like a wierd functional 73 | subclassing going on.""" 74 | 75 | # tell the compiler it can processes .cu 76 | self.src_extensions.append('.cu') 77 | 78 | # save references to the default compiler_so and _comple methods 79 | default_compiler_so = self.compiler_so 80 | super = self._compile 81 | 82 | # now redefine the _compile method. This gets executed for each 83 | # object but distutils doesn't have the ability to change compilers 84 | # based on source extension: we add it. 
85 |     def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
86 |         if os.path.splitext(src)[1] == '.cu':
87 |             # use the cuda for .cu files
88 |             self.set_executable('compiler_so', CUDA['nvcc'])
89 |             # use only a subset of the extra_postargs, which are 1-1 translated
90 |             # from the extra_compile_args in the Extension class
91 |             postargs = extra_postargs['nvcc']
92 |         else:
93 |             postargs = extra_postargs['gcc']
94 |
95 |         super(obj, src, ext, cc_args, postargs, pp_opts)
96 |         # reset the default compiler_so, which we might have changed for cuda
97 |         self.compiler_so = default_compiler_so
98 |
99 |     # inject our redefined _compile method into the class
100 |     self._compile = _compile
101 |
102 |
103 | # run the customize_compiler
104 | class custom_build_ext(build_ext):
105 |     def build_extensions(self):
106 |         customize_compiler_for_nvcc(self.compiler)
107 |         build_ext.build_extensions(self)
108 |
109 |
110 | ext_modules = [
111 |     Extension(
112 |         "cpu_nms",
113 |         ["cpu_nms.pyx"],
114 |         extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
115 |         include_dirs = [numpy_include]
116 |     ),
117 |     Extension('gpu_nms',
118 |         ['nms_kernel.cu', 'gpu_nms.pyx'],
119 |         library_dirs=[CUDA['lib64']],
120 |         libraries=['cudart'],
121 |         language='c++',
122 |         runtime_library_dirs=[CUDA['lib64']],
123 |         # this syntax is specific to this build system
124 |         # we're only going to use certain compiler args with nvcc and not with
125 |         # gcc; the implementation of this trick is in customize_compiler_for_nvcc() above
126 |         extra_compile_args={'gcc': ["-Wno-unused-function"],
127 |                             'nvcc': ['-arch=sm_35',
128 |                                      '--ptxas-options=-v',
129 |                                      '-c',
130 |                                      '--compiler-options',
131 |                                      "'-fPIC'"]},
132 |         include_dirs = [numpy_include, CUDA['include']]
133 |     ),
134 | ]
135 |
136 | setup(
137 |     name='nms',
138 |     ext_modules=ext_modules,
139 |     # inject our custom trigger
140 |     cmdclass={'build_ext': custom_build_ext},
141 | )
142 |
-------------------------------------------------------------------------------- /lib/utils/augmentation_pool.py: --------------------------------------------------------------------------------
1 | import logging
2 | import random
3 |
4 | import numpy as np
5 | import PIL
6 | import PIL.ImageOps
7 | import PIL.ImageEnhance
8 | import PIL.ImageDraw
9 | from PIL import Image, ImageFilter
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | PARAMETER_MAX = 10
14 |
15 |
16 | def AutoContrast(img, **kwarg):
17 |     return PIL.ImageOps.autocontrast(img)
18 |
19 |
20 | def Brightness(img, v, max_v, bias=0):
21 |     v = _float_parameter(v, max_v) + bias
22 |     return PIL.ImageEnhance.Brightness(img).enhance(v)
23 |
24 |
25 | def Color(img, v, max_v, bias=0):
26 |     v = _float_parameter(v, max_v) + bias
27 |     return PIL.ImageEnhance.Color(img).enhance(v)
28 |
29 |
30 | def Contrast(img, v, max_v, bias=0):
31 |     v = _float_parameter(v, max_v) + bias
32 |     return PIL.ImageEnhance.Contrast(img).enhance(v)
33 |
34 |
35 | def Cutout(img, v, max_v, bias=0):
36 |     if v == 0:
37 |         return img
38 |     v = _float_parameter(v, max_v) + bias
39 |     v = int(v * min(img.size))
40 |     return CutoutAbs(img, v)
41 |
42 |
43 | def CutoutAbs(img, v, **kwarg):
44 |     w, h = img.size
45 |     x0 = np.random.uniform(0, w)
46 |     y0 = np.random.uniform(0, h)
47 |     x0 = int(max(0, x0 - v / 2.))
48 |     y0 = int(max(0, y0 - v / 2.))
49 |     x1 = int(min(w, x0 + v))
50 |     y1 = int(min(h, y0 + v))
51 |     xy = (x0, y0, x1, y1)
52 |     # gray
53 |     color = (127, 127, 127)
54 |     img = img.copy()
55 |     PIL.ImageDraw.Draw(img).rectangle(xy, color)
56 |     return img
57 |
58 |
59 | def 
Equalize(img, **kwarg):
60 |     return PIL.ImageOps.equalize(img)
61 |
62 |
63 | def Identity(img, **kwarg):
64 |     return img
65 |
66 |
67 | def Invert(img, **kwarg):
68 |     return PIL.ImageOps.invert(img)
69 |
70 |
71 | def Posterize(img, v, max_v, bias=0):
72 |     v = _int_parameter(v, max_v) + bias
73 |     return PIL.ImageOps.posterize(img, v)
74 |
75 | def Blur(img, **kwarg):
76 |     return img.filter(ImageFilter.BLUR)
77 |
78 | def Rotate(img, v, max_v, bias=0):
79 |     v = _int_parameter(v, max_v) + bias
80 |     if random.random() < 0.5:
81 |         v = -v
82 |     return img.rotate(v)
83 |
84 |
85 | def Sharpness(img, v, max_v, bias=0):
86 |     v = _float_parameter(v, max_v) + bias
87 |     return PIL.ImageEnhance.Sharpness(img).enhance(v)
88 |
89 |
90 | def ShearX(img, v, max_v, bias=0):
91 |     v = _float_parameter(v, max_v) + bias
92 |     if random.random() < 0.5:
93 |         v = -v
94 |     return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
95 |
96 |
97 | def ShearY(img, v, max_v, bias=0):
98 |     v = _float_parameter(v, max_v) + bias
99 |     if random.random() < 0.5:
100 |         v = -v
101 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
102 |
103 |
104 | def Solarize(img, v, max_v, bias=0):
105 |     v = _int_parameter(v, max_v) + bias
106 |     return PIL.ImageOps.solarize(img, 256 - v)
107 |
108 |
109 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
110 |     v = _int_parameter(v, max_v) + bias
111 |     if random.random() < 0.5:
112 |         v = -v
113 |     img_np = np.array(img).astype(np.int32)
114 |     img_np = img_np + v
115 |     img_np = np.clip(img_np, 0, 255)
116 |     img_np = img_np.astype(np.uint8)
117 |     img = Image.fromarray(img_np)
118 |     return PIL.ImageOps.solarize(img, threshold)
119 |
120 |
121 | def TranslateX(img, v, max_v, bias=0):
122 |     v = _float_parameter(v, max_v) + bias
123 |     if random.random() < 0.5:
124 |         v = -v
125 |     v = int(v * img.size[0])
126 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
127 |
128 |
129 | def TranslateY(img, v, max_v, bias=0):
130 |     v = _float_parameter(v, max_v) + bias
131 |     if random.random() < 0.5:
132 |         v = -v
133 |     v = int(v * img.size[1])
134 |     return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
135 |
136 |
137 | def _float_parameter(v, max_v):
138 |     return float(v) * max_v / PARAMETER_MAX
139 |
140 |
141 | def _int_parameter(v, max_v):
142 |     return int(v * max_v / PARAMETER_MAX)
143 |
144 |
145 | def fixmatch_augment_pool():
146 |     # FixMatch paper
147 |     augs = [(AutoContrast, None, None),
148 |             (Brightness, 0.9, 0.05),
149 |             (Color, 0.9, 0.05),
150 |             (Contrast, 0.9, 0.05),
151 |             (Equalize, None, None),
152 |             (Identity, None, None),
153 |             (Posterize, 4, 4),
154 |             (Rotate, 30, 0),
155 |             (Sharpness, 0.9, 0.05),
156 |             (ShearX, 0.3, 0),
157 |             (ShearY, 0.3, 0),
158 |             (Solarize, 256, 0),
159 |             (TranslateX, 0.3, 0),
160 |             (TranslateY, 0.3, 0)]
161 |     return augs
162 |
163 |
164 | def pose_augment_pool():
165 |     # photometric subset of the FixMatch pool (no geometric ops, which would break keypoint alignment)
166 |     augs = [(AutoContrast, None, None),
167 |             (Brightness, 0.9, 0.05),
168 |             (Color, 0.9, 0.05),
169 |             (Contrast, 0.9, 0.05),
170 |             (Equalize, None, None),
171 |             (Identity, None, None),
172 |             (Posterize, 4, 4),
173 |             (Sharpness, 0.9, 0.05),
174 |             (Solarize, 256, 0)]
175 |     return augs
176 |
177 |
178 | class RandAugmentMC(object):
179 |     def __init__(self, n, m, num_cutout):
180 |         assert n >= 1
181 |         assert 1 <= m <= 10
182 |         self.n = n
183 |         self.m = m
184 |         self.augment_pool = pose_augment_pool()
185 |         self.num_cutout = num_cutout
186 |
187 |     def __call__(self, img):
188 |         ops = random.choices(self.augment_pool, k=self.n)
189 |         for op, max_v, bias 
in ops:
190 |             v = np.random.randint(1, self.m)
191 |             if random.random() < 0.5:
192 |                 img = op(img, v=v, max_v=max_v, bias=bias)
193 |         if random.random() < 0.5:
194 |             for i in range(self.num_cutout):
195 |                 img = CutoutAbs(img, int(32))
196 |         return img
197 |
198 |
-------------------------------------------------------------------------------- /lib/nms/nms.py: --------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import numpy as np
12 |
13 | # from .cpu_nms import cpu_nms
14 | # from .gpu_nms import gpu_nms
15 |
16 |
17 | def py_nms_wrapper(thresh):
18 |     def _nms(dets):
19 |         return nms(dets, thresh)
20 |     return _nms
21 |
22 |
23 | # def cpu_nms_wrapper(thresh):
24 | #     def _nms(dets):
25 | #         return cpu_nms(dets, thresh)
26 | #     return _nms
27 |
28 |
29 | # def gpu_nms_wrapper(thresh, device_id):
30 | #     def _nms(dets):
31 | #         return gpu_nms(dets, thresh, device_id)
32 | #     return _nms
33 |
34 |
35 | def nms(dets, thresh):
36 |     """
37 |     greedily select boxes with high confidence, keeping boxes whose overlap with the current maximum is <= thresh
38 |     and ruling out boxes with overlap >= thresh
39 |     :param dets: [[x1, y1, x2, y2, score]]
40 |     :param thresh: retain overlap < thresh
41 |     :return: indexes to keep
42 |     """
43 |     if dets.shape[0] == 0:
44 |         return []
45 |
46 |     x1 = dets[:, 0]
47 |     y1 = dets[:, 1]
48 |     x2 = dets[:, 2]
49 |     y2 = dets[:, 3]
50 |     scores = dets[:, 4]
51 |
52 |     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
53 |     order = scores.argsort()[::-1]
54 |
55 |     keep = []
56 |     while order.size > 0:
57 |         i = order[0]
58 |         keep.append(i)
59 |         xx1 = np.maximum(x1[i], x1[order[1:]])
60 |         yy1 = np.maximum(y1[i], y1[order[1:]])
61 |         xx2 = np.minimum(x2[i], x2[order[1:]])
62 |         yy2 = np.minimum(y2[i], y2[order[1:]])
63 |
64 |         w = np.maximum(0.0, xx2 - xx1 + 1)
65 |         h = np.maximum(0.0, yy2 - yy1 + 1)
66 |         inter = w * h
67 |         ovr = inter / (areas[i] + areas[order[1:]] - inter)
68 |
69 |         inds = np.where(ovr <= thresh)[0]
70 |         order = order[inds + 1]
71 |
72 |     return keep
73 |
74 |
75 | def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
76 |     if not isinstance(sigmas, np.ndarray):
77 |         sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
78 |     vars = (sigmas * 2) ** 2
79 |     xg = g[0::3]
80 |     yg = g[1::3]
81 |     vg = g[2::3]
82 |     ious = np.zeros((d.shape[0]))
83 |     for n_d in range(0, d.shape[0]):
84 |         xd = d[n_d, 0::3]
85 |         yd = d[n_d, 1::3]
86 |         vd = d[n_d, 2::3]
87 |         dx = xd - xg
88 |         dy = yd - yg
89 |         e = (dx ** 2 + dy ** 2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
90 |         if in_vis_thre is not None:
91 |             ind = np.logical_and(vg > in_vis_thre, vd > in_vis_thre)
92 |             e = e[ind]
93 |         ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
94 |     return ious
95 |
96 |
97 | def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
98 |     """
99 |     greedily select keypoint detections with high confidence, keeping those whose overlap with the current maximum is <= thresh
100 |     and ruling out overlap >= thresh, where overlap = oks
101 |     :param kpts_db
102 |     :param thresh: retain overlap < thresh
103 |     :return: indexes to keep
104 |     """
105 |     if 
len(kpts_db) == 0: 106 | return [] 107 | 108 | scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))]) 109 | kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))]) 110 | areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))]) 111 | 112 | order = scores.argsort()[::-1] 113 | 114 | keep = [] 115 | while order.size > 0: 116 | i = order[0] 117 | keep.append(i) 118 | 119 | oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre) 120 | 121 | inds = np.where(oks_ovr <= thresh)[0] 122 | order = order[inds + 1] 123 | 124 | return keep 125 | 126 | 127 | def rescore(overlap, scores, thresh, type='gaussian'): 128 | assert overlap.shape[0] == scores.shape[0] 129 | if type == 'linear': 130 | inds = np.where(overlap >= thresh)[0] 131 | scores[inds] = scores[inds] * (1 - overlap[inds]) 132 | else: 133 | scores = scores * np.exp(- overlap**2 / thresh) 134 | 135 | return scores 136 | 137 | 138 | def soft_oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None): 139 | """ 140 | greedily select boxes with high confidence and overlap with current maximum <= thresh 141 | rule out overlap >= thresh, overlap = oks 142 | :param kpts_db 143 | :param thresh: retain overlap < thresh 144 | :return: indexes to keep 145 | """ 146 | if len(kpts_db) == 0: 147 | return [] 148 | 149 | scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))]) 150 | kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))]) 151 | areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))]) 152 | 153 | order = scores.argsort()[::-1] 154 | scores = scores[order] 155 | 156 | # max_dets = order.size 157 | max_dets = 20 158 | keep = np.zeros(max_dets, dtype=np.intp) 159 | keep_cnt = 0 160 | while order.size > 0 and keep_cnt < max_dets: 161 | i = order[0] 162 | 163 | oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre) 164 | 165 | order = order[1:] 166 | scores = rescore(oks_ovr, scores[1:], thresh) 167 | 168 | tmp = scores.argsort()[::-1] 169 | order = order[tmp] 170 | scores = scores[tmp] 171 | 172 | keep[keep_cnt] = i 173 | keep_cnt += 1 174 | 175 | keep = keep[:keep_cnt] 176 | 177 | return keep 178 | # kpts_db = kpts_db[:keep_cnt] 179 | 180 | # return kpts_db 181 | -------------------------------------------------------------------------------- /data/label_list/annotation_list_20: -------------------------------------------------------------------------------- 1 | [22669, 22705, 23420, 22661, 23417, 22738, 22687, 22937, 22837, 22702, 23418, 22667, 22696, 23425, 22746, 22793, 23203, 23292, 22724, 22689, 27936, 29387, 29851, 29017, 27070, 26837, 26803, 30306, 29684, 26970, 28936, 30028, 29079, 29491, 29773, 28234, 26748, 29227, 26848, 26603, 30740, 30996, 30773, 30629, 31843, 31677, 30986, 30651, 30507, 30981, 30595, 31666, 30995, 30817, 31266, 30529, 30984, 30551, 31432, 31288, 44294, 44318, 44811, 44577, 44466, 44600, 44248, 44332, 44319, 44280, 44689, 44357, 44331, 44897, 44314, 44411, 44589, 44268, 44922, 44389, 54997, 54817, 54882, 54875, 54969, 54928, 54818, 54938, 54836, 54831, 54839, 54852, 54899, 55008, 54906, 54845, 54944, 54992, 54995, 54854, 48872, 48753, 48789, 48873, 48752, 48763, 48844, 48838, 48807, 48788, 48834, 48862, 48882, 48908, 48846, 48756, 48888, 48856, 48779, 48854, 49031, 49074, 49094, 49002, 48946, 48973, 49036, 48913, 49037, 49084, 48933, 49041, 49061, 48988, 48999, 48971, 48978, 49004, 49001, 48995, 55627, 55648, 55651, 55631, 55659, 55663, 55664, 
55629, 55634, 55657, 55640, 55643, 55636, 55625, 55626, 55653, 55660, 55641, 55624, 55635, 55825, 55743, 55748, 55803, 55674, 55820, 55793, 55783, 55767, 55735, 55828, 55821, 55868, 55764, 55859, 55876, 55694, 55889, 55850, 55686, 56703, 56617, 56680, 56730, 56731, 56574, 56752, 56726, 56567, 56651, 56606, 56628, 56721, 56750, 56577, 56684, 56657, 56693, 56747, 56718, 58541, 58623, 58459, 58567, 58624, 58487, 58511, 58527, 58612, 58432, 58454, 58515, 58435, 58577, 58467, 58585, 58461, 58445, 58458, 58569, 45098, 45095, 45060, 45069, 45083, 45037, 45000, 44952, 44933, 45066, 45070, 44945, 45005, 45092, 44965, 45077, 45076, 45013, 44934, 45067, 46176, 46197, 46196, 46129, 46155, 46111, 46231, 46148, 46229, 46102, 46187, 46290, 46103, 46284, 46240, 46154, 46281, 46132, 46134, 46121, 20863, 20879, 20853, 20866, 20761, 20805, 20877, 20769, 20721, 20704, 20833, 20749, 20695, 20705, 20880, 20733, 20789, 20860, 20796, 20689, 22131, 22418, 22631, 22076, 22075, 22008, 22341, 22625, 22081, 22263, 22028, 21999, 22048, 22496, 22019, 22092, 22363, 22047, 22452, 22073, 32047, 32017, 32039, 31952, 32035, 31977, 32058, 32054, 31966, 31936, 31917, 31929, 31892, 32064, 31963, 31906, 31974, 31938, 32052, 31954, 40803, 32706, 36283, 40386, 33006, 35295, 40831, 40414, 40778, 40408, 33295, 33472, 40147, 40380, 34694, 35095, 39947, 32595, 35850, 40821, 36497, 36491, 36387, 36338, 36482, 36458, 36432, 36344, 36321, 36477, 36505, 36493, 36320, 36401, 36445, 36466, 36471, 36413, 36500, 36352, 37532, 37699, 37696, 37682, 37851, 37653, 37787, 37816, 37598, 37584, 37526, 37645, 37654, 37684, 37548, 37821, 37859, 37708, 37599, 37591, 37866, 37891, 37895, 37878, 37874, 37876, 37885, 37894, 37888, 37881, 37892, 37877, 37886, 37871, 37872, 37883, 37893, 37889, 37882, 37870, 38564, 37910, 37918, 38275, 38037, 38108, 38519, 38041, 38047, 37909, 38080, 38709, 37997, 37993, 38035, 38708, 38186, 37968, 37966, 38586, 38779, 38884, 38907, 38782, 38803, 38854, 38826, 38831, 38776, 38732, 38811, 38840, 38808, 38737, 38823, 38748, 38730, 38836, 38768, 38859, 39789, 39820, 39790, 39871, 39855, 39857, 39897, 39878, 39862, 39751, 39792, 39763, 39736, 39852, 39920, 39750, 39772, 39816, 39899, 39886, 40856, 40868, 40906, 40921, 40884, 40927, 40881, 40911, 40875, 40896, 40923, 40893, 40870, 40843, 40886, 40908, 40910, 40882, 40902, 40897, 40969, 41020, 41116, 41122, 41101, 41039, 41070, 40980, 41143, 41148, 41044, 41066, 41009, 41023, 41003, 41112, 41000, 41120, 41138, 40983, 20075, 20021, 20066, 20036, 20010, 20057, 20047, 20043, 20076, 20059, 20077, 20038, 20018, 20073, 20013, 20039, 20051, 20071, 20069, 20055, 20143, 20164, 20129, 20177, 20214, 20218, 20179, 20185, 20189, 20162, 20130, 20087, 20096, 20158, 20114, 20209, 20196, 20119, 20081, 20091, 20252, 20296, 20288, 20303, 20285, 20262, 20336, 20250, 20329, 20330, 20328, 20249, 20304, 20276, 20332, 20275, 20323, 20281, 20325, 20270, 20399, 20369, 20356, 20566, 20351, 20481, 20521, 20482, 20402, 20379, 20425, 20434, 20533, 20551, 20365, 20426, 20380, 20524, 20386, 20530, 20645, 20676, 20669, 20636, 20607, 20609, 20610, 20643, 20671, 20594, 20632, 20684, 20679, 20673, 20634, 20602, 20589, 20604, 20613, 20620, 51859, 51872, 51811, 51863, 51825, 51848, 51812, 51845, 51837, 51846, 51810, 51877, 51834, 51849, 51835, 51818, 51838, 51867, 51860, 51802, 51974, 51923, 52119, 52058, 51920, 51953, 51890, 52007, 52008, 51915, 52073, 51949, 52142, 52140, 52003, 52095, 52156, 52086, 51958, 51985, 49226, 49299, 49266, 49293, 49317, 49276, 49346, 49287, 49421, 49262, 49259, 49417, 49224, 49408, 
49339, 49379, 49340, 49285, 49357, 49312, 50022, 50094, 50017, 50207, 50106, 50108, 50170, 50141, 50183, 50088, 50156, 50125, 50213, 49991, 50223, 49982, 50167, 50027, 50204, 50233, 43194, 43113, 43121, 43270, 43182, 43149, 43256, 43210, 43251, 43141, 43108, 43153, 43183, 43101, 43091, 43250, 43087, 43231, 43191, 43096, 23720, 23619, 23519, 23453, 23660, 23515, 23476, 23694, 23710, 23465, 23707, 23675, 23454, 23581, 23528, 23491, 23505, 23625, 23520, 23628, 315, 116, 181, 482, 859, 123, 604, 107, 715, 191, 225, 37, 203, 173, 359, 992, 92, 150, 2, 126, 1050, 1040, 1101, 1038, 1107, 1082, 1044, 1108, 1035, 1028, 1043, 1027, 1102, 1072, 1037, 1069, 1103, 1099, 1109, 1039, 1150, 1178, 1130, 1175, 1280, 1524, 1558, 1145, 1591, 1224, 1978, 1336, 1146, 1176, 1149, 1132, 1358, 1302, 1977, 1324, 1983, 2124, 2016, 2138, 2069, 2003, 2023, 2128, 1988, 2088, 2079, 2051, 2118, 2146, 2154, 2108, 2112, 2005, 2047, 2037, 4422, 4665, 4238, 5022, 3346, 2776, 5144, 3109, 2677, 4237, 2399, 5267, 5211, 2332, 5444, 5644, 5219, 5344, 5655, 5355, 8445, 6842, 8036, 5989, 7187, 7763, 5856, 8000, 7773, 6979, 8135, 7557, 7335, 7052, 7274, 8159, 7409, 8210, 7434, 7213, 50272, 50297, 50294, 50430, 50443, 50368, 50320, 50300, 50410, 50364, 50398, 50324, 50317, 50283, 50428, 50421, 50344, 50388, 50330, 50446, 48548, 48612, 48673, 48714, 48542, 48560, 48700, 48706, 48625, 48584, 48569, 48707, 48695, 48688, 48607, 48715, 48724, 48565, 48547, 48551, 47490, 47562, 47463, 47553, 47519, 47635, 47545, 47548, 47627, 47639, 47563, 47580, 47482, 47595, 47498, 47646, 47484, 47640, 47524, 47557, 10780, 17057, 13374, 18915, 9289, 10709, 18318, 8861, 10255, 10384, 8877, 18341, 8878, 10476, 9364, 18345, 9374, 9359, 9356, 8828, 17743, 18090, 18285, 17665, 17715, 17747, 18145, 17661, 17658, 17979, 17709, 18288, 18272, 17698, 17734, 18089, 18279, 17720, 17751, 17704, 19530, 19822, 19297, 19377, 19307, 19271, 19831, 19312, 19357, 19552, 19315, 19339, 19374, 19287, 19816, 19802, 19696, 19269, 19299, 19817, 19917, 19999, 19959, 19935, 20006, 19888, 19984, 19866, 19982, 19847, 19929, 19867, 19912, 19947, 19854, 19955, 19833, 19919, 19952, 19884, 50762, 51197, 50839, 51048, 51079, 51103, 50800, 51068, 50841, 51246, 51012, 51099, 50817, 50768, 50767, 51135, 50809, 51007, 51034, 50779] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ScarceNet 2 | 3 | **About** 4 | 5 | This is the source code for our paper 6 | 7 | Chen Li, Gim Hee Lee. ScarceNet: Animal Pose Estimation with Scarce Annotations. In CVPR 2023. 8 | 9 | In this paper, we aim to achieve accurate animal pose estimation with only a small set of labeled images and unlabeled images. 10 | 11 |

12 | 13 |

14 | 15 | We design a pseudo-label-based framework to learn from scarce animal pose data. We first apply the small-loss trick to select a set of reliable pseudo labels. Despite its effectiveness, pseudo-label selection by the small-loss trick tends to discard numerous high-loss samples. This results in high 16 | wastage, since those discarded samples can still provide extra information for better discrimination. In view of this, we propose a reusable sample re-labeling step to further identify reusable samples among the high-loss samples via an agreement check, and to re-generate the corresponding pseudo labels for supervision. Lastly, we design a student-teacher framework to enforce consistency between the outputs of the student and teacher networks. The two sketches below illustrate the agreement check and the teacher update. 17 | 18 |
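
The agreement check can be pictured as follows (an illustrative sketch only, not the repository's implementation, which lives in the training scripts): keypoints predicted from two differently transformed views of the same unlabeled image are mapped back to a common frame and compared per joint, and a high-loss joint is re-labeled only when the two predictions agree within a distance threshold. Here `agree_thresh` is a made-up parameter name for this sketch:

```
import numpy as np

def agreement_check(kpts_a, kpts_b, agree_thresh=2.0):
    # kpts_a, kpts_b: (num_joints, 2) keypoints predicted from two
    # differently transformed views of the same image, already mapped
    # back to a common coordinate frame.
    dist = np.linalg.norm(kpts_a - kpts_b, axis=-1)  # (num_joints,)
    return dist < agree_thresh                       # boolean mask of reusable joints
```
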

19 | 20 |
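The student-teacher consistency of the last step follows the mean-teacher recipe: the teacher is an exponential moving average (EMA) of the student, and the student is additionally supervised by the teacher's predictions. The repository's actual EMA update is update_ema_variables in lib/utils/utils.py; the sketch below is only a minimal illustration, with a one-layer convolution standing in for the HRNet pose model.
```
import torch
import torch.nn as nn

student = nn.Conv2d(3, 17, kernel_size=1)  # stand-in for the pose network
teacher = nn.Conv2d(3, 17, kernel_size=1)
teacher.load_state_dict(student.state_dict())

@torch.no_grad()
def ema_update(student, teacher, alpha=0.999):
    # teacher <- alpha * teacher + (1 - alpha) * student
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1.0 - alpha)

x = torch.rand(2, 3, 64, 48)  # a batch of augmented crops
consistency = ((student(x) - teacher(x).detach()) ** 2).mean()
consistency.backward()  # gradients reach the student only
ema_update(student, teacher)
```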

20 | 
21 | For more details, please refer to [our paper](http://arxiv.org/abs/2303.15023).
22 | 
23 | **Dependencies**
24 | 1. Python 3.7
25 | 2. PyTorch 1.7
26 | 
27 | Please refer to requirements.txt for more details on dependencies.
28 | 
29 | **Download datasets**
30 | 
31 | * Clone this repository:
32 | 
33 | ```
34 | git clone https://github.com/chaneyddtt/ScarceNet.git
35 | ```
36 | * Download the [AP10K dataset](https://github.com/AlexTheBad/AP-10K). Put the data and annotations under root/data/animalpose/.
37 | * Download the [HRNet](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch) pretrained on ImageNet, and put it under root/models/.
38 | 
39 | **Test**
40 | 
41 | Download [our models](https://drive.google.com/file/d/1MQGyyf1MQjETRG_CRBMEfrV0IFiis_ZQ/view?usp=share_link) and put them under the root/output folder. Test the semi-supervised setting by running the command below. (We provide models trained with 5, 10, 15, 20 and 25 labels per category.)
42 | ```
43 | CUDA_VISIBLE_DEVICES=0 python tools/test.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose OUTPUT_DIR test TEST.MODEL_FILE output/output_part5_updatev2/model_best.pth MODEL.NAME pose_hrnet_part GPUS [0,]
44 | ```
45 | You can also test the transfer learning setting by running the command below (modify DATASET.SUPERCATEGORY to test on different animal categories):
46 | ```
47 | CUDA_VISIBLE_DEVICES=0 python tools/test.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose OUTPUT_DIR test TEST.MODEL_FILE output/output_part_updatev2_transfer/model_best.pth MODEL.NAME pose_hrnet_part GPUS [0,] DATASET.DATASET ap10k_test_category DATASET.SELECT_DATA True DATASET.SUPERCATEGORY "'deer',"
48 | ```
49 | 
50 | **Train**
51 | 
52 | If you do not want to train from scratch, you can download [the models](https://drive.google.com/file/d/1Cel43-6dzj4o9xBIXa_cBsZcvvht5Uoz/view?usp=share_link) trained with the few labeled animal data, as well as the corresponding [pseudo labels](https://drive.google.com/file/d/152_mFWBO2Scc7MhcnsSWPnrCn4RsdhQX/view?usp=share_link) generated by these models. Move the pseudo label folder under root/data/ and the pretrained models under root/output/. Train our model by running the command below. You can change the number of labels by setting 'LABEL_PER_CLASS' to 5, 10, 15, 20 or 25; the pretrained model and output directory need to be changed accordingly.
53 | ```
54 | CUDA_VISIBLE_DEVICES=0,1 python tools/train_mt_part.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose --augment --pretrained output/output_animal_hrnet5_part/model_best.pth OUTPUT_DIR output_part5_updatev2 DATASET.DATASET ap10k_mt_v3 MODEL.NAME pose_hrnet_part TRAIN.BATCH_SIZE_PER_GPU 16 LABEL_PER_CLASS 5
55 | ```
56 | You can also train from scratch. First, train the model with the few labels by running:
57 | ```
58 | CUDA_VISIBLE_DEVICES=0,1 python tools/train.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose OUTPUT_DIR output_animal_hrnet5_part DATASET.DATASET ap10k_fewshot MODEL.NAME pose_hrnet_part LABEL_PER_CLASS 5
59 | ```
60 | Change the number of labels by setting 'LABEL_PER_CLASS' to 5, 10, 15, 20 or 25; the output directory needs to be changed accordingly. Note that the labeled data are randomly selected from each category.
61 | 
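The --generate_pseudol runs in the next steps produce the initial pseudo labels for the unlabeled images. The sketch below is illustrative only: the repository decodes heatmaps with get_max_preds in lib/core/inference.py, and the exact on-disk format written under root/data/pseudo_labels/ may differ. It shows the basic idea, namely taking the argmax of each predicted heatmap as the pseudo keypoint and keeping the peak value as a confidence.
```
import numpy as np

def heatmaps_to_pseudo_labels(heatmaps):
    # heatmaps: (batch, num_joints, h, w) array of predicted heatmaps.
    # Returns (batch, num_joints, 2) pixel coordinates and (batch, num_joints)
    # peak values, which can serve as confidences for later filtering.
    b, j, h, w = heatmaps.shape
    flat = heatmaps.reshape(b, j, -1)
    idx = flat.argmax(axis=2)
    conf = flat.max(axis=2)
    coords = np.stack([idx % w, idx // w], axis=2).astype(np.float32)  # (x, y)
    return coords, conf

coords, conf = heatmaps_to_pseudo_labels(np.random.rand(2, 17, 64, 48))
```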
62 | 
63 | Create the folder root/data/pseudo_labels/5shots/ and generate pseudo labels by running:
64 | ```
65 | CUDA_VISIBLE_DEVICES=0 python tools/train_mt_part.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose --generate_pseudol --pretrained output/output_animal_hrnet5_part/model_best.pth OUTPUT_DIR test MODEL.NAME pose_hrnet_part TRAIN.BATCH_SIZE_PER_GPU 32 GPUS [0,] LABEL_PER_CLASS 5
66 | ```
67 | Train the whole pipeline by running:
68 | ```
69 | CUDA_VISIBLE_DEVICES=0,1 python tools/train_mt_part.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose --augment --pretrained output/output_animal_hrnet5_part/model_best.pth OUTPUT_DIR output_part5_updatev2 DATASET.DATASET ap10k_mt_v3 MODEL.NAME pose_hrnet_part TRAIN.BATCH_SIZE_PER_GPU 16 LABEL_PER_CLASS 5
70 | ```
71 | 
72 | The training steps for the transfer setting are similar. First, train with the labels from the Bovidae family by running:
73 | ```
74 | CUDA_VISIBLE_DEVICES=0,1 python tools/train.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose OUTPUT_DIR output_animal_hrnet_part_bovidae DATASET.DATASET ap10k MODEL.NAME pose_hrnet_part DATASET.SELECT_DATA True
75 | ```
76 | 
77 | Then create the folder root/data/pseudo_labels/0shots and generate pseudo labels by running:
78 | ```
79 | CUDA_VISIBLE_DEVICES=0 python tools/train_mt_part.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose --generate_pseudol --pretrained output/output_animal_hrnet_part_bovidae/model_best.pth OUTPUT_DIR test MODEL.NAME pose_hrnet_part TRAIN.BATCH_SIZE_PER_GPU 32 GPUS [0,] LABEL_PER_CLASS 0
80 | ```
81 | Lastly, train the whole pipeline by running:
82 | ```
83 | CUDA_VISIBLE_DEVICES=0,1 python tools/train_mt_part.py --cfg experiments/ap10k/hrnet/w32_256x192_adam_lr1e-3.yaml --animalpose --augment --pretrained output/output_animal_hrnet_part_bovidae/model_best.pth --few_shot_setting OUTPUT_DIR output_part_updatev2_transfer DATASET.DATASET ap10k_mt_v3 MODEL.NAME pose_hrnet_part TRAIN.BATCH_SIZE_PER_GPU 16 LABEL_PER_CLASS 0 DATASET.SELECT_DATA True
84 | ```
85 | 
86 | **Acknowledgements**
87 | 
88 | The code for the network architecture, data preprocessing, and evaluation is adapted from [HRNet](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch) and [AP-10K](https://github.com/AlexTheBad/AP-10K).
89 | 
90 | 
91 | 
92 | 
--------------------------------------------------------------------------------
/lib/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 | 
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | 
11 | import numpy as np
12 | import cv2
13 | 
14 | 
15 | def flip_back(output_flipped, matched_parts):
16 |     '''
17 |     output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
18 |     '''
19 |     assert output_flipped.ndim == 4,\
20 |         'output_flipped should be [batch_size, num_joints, height, width]'
21 | 
22 |     output_flipped = output_flipped[:, :, :, ::-1]
23 | 
24 |     for pair in matched_parts:
25 |         tmp = output_flipped[:, pair[0], :, :].copy()
26 |         output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
27 |         output_flipped[:, pair[1], :, :] = tmp
28 | 
29 |     return output_flipped
30 | 
31 | 
32 | def fliplr_joints(joints, joints_vis, width, matched_parts):
33 |     """
34 |     flip coords
35 |     """
36 |     # Flip horizontal
37 |     joints[:, 0] = width - joints[:, 0] - 1
38 | 
39 |     # Change left-right parts
40 |     for pair in matched_parts:
41 |         joints[pair[0], :], joints[pair[1], :] = \
42 |             joints[pair[1], :], joints[pair[0], :].copy()
43 |         joints_vis[pair[0], :], joints_vis[pair[1], :] = \
44 |             joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
45 | 
46 |     return joints*joints_vis, joints_vis
47 | 
48 | 
49 | def fliplr_joints_batch(joints, joints_vis, width, matched_parts):
50 |     """
51 |     flip coords
52 |     """
53 |     # Flip horizontal
54 |     joints_flip = joints.copy()
55 |     joints_flip[:, :, 0] = width - joints_flip[:, :, 0] - 1
56 |     joints_vis_flip = joints_vis.copy()
57 |     # Change left-right parts
58 |     for pair in matched_parts:
59 |         joints_flip[:, pair[0], :], joints_flip[:, pair[1], :] = \
60 |             joints_flip[:, pair[1], :], joints_flip[:, pair[0], :].copy()
61 |         joints_vis_flip[:, pair[0], :], joints_vis_flip[:, pair[1], :] = \
62 |             joints_vis_flip[:, pair[1], :], joints_vis_flip[:, pair[0], :].copy()
63 |     return joints_flip*joints_vis_flip, joints_vis_flip
64 | 
65 | 
66 | # Do not set the invisible joint coordinates to zero, because we do not care about the visibility during re-labeling:
67 | # we will use a joint as long as it fulfils the agreement check.
68 | def fliplr_joints_batch_v2(joints, joints_vis, width, matched_parts):
69 |     """
70 |     flip coords
71 |     """
72 |     # Flip horizontal
73 |     joints_flip = joints.copy()
74 |     joints_flip[:, :, 0] = width - joints_flip[:, :, 0] - 1
75 |     joints_vis_flip = joints_vis.copy()
76 |     # Change left-right parts
77 |     for pair in matched_parts:
78 |         joints_flip[:, pair[0], :], joints_flip[:, pair[1], :] = \
79 |             joints_flip[:, pair[1], :], joints_flip[:, pair[0], :].copy()
80 |         joints_vis_flip[:, pair[0], :], joints_vis_flip[:, pair[1], :] = \
81 |             joints_vis_flip[:, pair[1], :], joints_vis_flip[:, pair[0], :].copy()
82 |     return joints_flip, joints_vis_flip
83 | 
84 | 
85 | def fliplr_weights_batch(joints_vis, matched_parts):
86 |     """
87 |     flip coords
88 |     """
89 |     # Flip horizontal
90 |     joints_vis_flip = joints_vis.copy()
91 |     # Change left-right parts
92 |     for pair in matched_parts:
93 |         joints_vis_flip[:, pair[0], :], joints_vis_flip[:, pair[1], :] = \
94 |             joints_vis_flip[:, pair[1], :], joints_vis_flip[:, pair[0], :].copy()
95 |     return joints_vis_flip
96 | 
97 | 
98 | def transform_preds(coords, center, scale, output_size):
99 |     target_coords = np.zeros(coords.shape)
100 |     trans = get_affine_transform(center, scale, 0, output_size, inv=1)
101 |     for p in range(coords.shape[0]):
102 |         target_coords[p, 0:2] = 
affine_transform(coords[p, 0:2], trans) 103 | return target_coords 104 | 105 | 106 | def get_affine_transform( 107 | center, scale, rot, output_size, 108 | shift=np.array([0, 0], dtype=np.float32), inv=0 109 | ): 110 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list): 111 | # print(scale) 112 | scale = np.array([scale, scale]) 113 | 114 | scale_tmp = scale * 200.0 115 | src_w = scale_tmp[0] 116 | dst_w = output_size[0] 117 | dst_h = output_size[1] 118 | 119 | rot_rad = np.pi * rot / 180 120 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 121 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 122 | 123 | src = np.zeros((3, 2), dtype=np.float32) 124 | dst = np.zeros((3, 2), dtype=np.float32) 125 | src[0, :] = center + scale_tmp * shift 126 | src[1, :] = center + src_dir + scale_tmp * shift 127 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 128 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 129 | 130 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 131 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 132 | 133 | if inv: 134 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 135 | else: 136 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 137 | 138 | return trans 139 | 140 | 141 | def affine_transform(pt, t): 142 | new_pt = np.array([pt[0], pt[1], 1.]).T 143 | new_pt = np.dot(t, new_pt) 144 | return new_pt[:2] 145 | 146 | 147 | def get_3rd_point(a, b): 148 | direct = a - b 149 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 150 | 151 | 152 | def get_dir(src_point, rot_rad): 153 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 154 | 155 | src_result = [0, 0] 156 | src_result[0] = src_point[0] * cs - src_point[1] * sn 157 | src_result[1] = src_point[0] * sn + src_point[1] * cs 158 | 159 | return src_result 160 | 161 | 162 | def crop(img, center, scale, output_size, rot=0): 163 | trans = get_affine_transform(center, scale, rot, output_size) 164 | 165 | dst_img = cv2.warpAffine( 166 | img, trans, (int(output_size[0]), int(output_size[1])), 167 | flags=cv2.INTER_LINEAR 168 | ) 169 | 170 | return dst_img 171 | 172 | 173 | def get_transform(center, scale, res, rot=0): 174 | """ 175 | General image processing functions 176 | """ 177 | # Generate transformation matrix 178 | h = 200 * scale[0] 179 | t = np.zeros((3, 3)) 180 | t[0, 0] = float(res[1]) / h 181 | t[1, 1] = float(res[0]) / h 182 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 183 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 184 | t[2, 2] = 1 185 | if not rot == 0: 186 | rot = -rot # To match direction of rotation from cropping 187 | rot_mat = np.zeros((3,3)) 188 | rot_rad = rot * np.pi / 180 189 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 190 | rot_mat[0,:2] = [cs, -sn] 191 | rot_mat[1,:2] = [sn, cs] 192 | rot_mat[2,2] = 1 193 | # Need to rotate around center 194 | t_mat = np.eye(3) 195 | t_mat[0,2] = -res[1]/2 196 | t_mat[1,2] = -res[0]/2 197 | t_inv = t_mat.copy() 198 | t_inv[:2,2] *= -1 199 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) 200 | return t 201 | -------------------------------------------------------------------------------- /lib/dataset/mpii.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import logging 12 | import os 13 | import json_tricks as json 14 | from collections import OrderedDict 15 | 16 | import numpy as np 17 | from scipy.io import loadmat, savemat 18 | 19 | from dataset.JointsDataset import JointsDataset 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class MPIIDataset(JointsDataset): 26 | def __init__(self, cfg, root, image_set, is_train, transform=None): 27 | super().__init__(cfg, root, image_set, is_train, transform) 28 | 29 | self.num_joints = 16 30 | self.flip_pairs = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]] 31 | self.parent_ids = [1, 2, 6, 6, 3, 4, 6, 6, 7, 8, 11, 12, 7, 7, 13, 14] 32 | 33 | self.upper_body_ids = (7, 8, 9, 10, 11, 12, 13, 14, 15) 34 | self.lower_body_ids = (0, 1, 2, 3, 4, 5, 6) 35 | 36 | self.db = self._get_db() 37 | 38 | if is_train and cfg.DATASET.SELECT_DATA: 39 | self.db = self.select_data(self.db) 40 | 41 | logger.info('=> load {} samples'.format(len(self.db))) 42 | 43 | def _get_db(self): 44 | # create train/val split 45 | file_name = os.path.join( 46 | self.root, 'annot', self.image_set+'.json' 47 | ) 48 | with open(file_name) as anno_file: 49 | anno = json.load(anno_file) 50 | 51 | gt_db = [] 52 | for a in anno: 53 | image_name = a['image'] 54 | 55 | c = np.array(a['center'], dtype=np.float) 56 | s = np.array([a['scale'], a['scale']], dtype=np.float) 57 | 58 | # Adjust center/scale slightly to avoid cropping limbs 59 | if c[0] != -1: 60 | c[1] = c[1] + 15 * s[1] 61 | s = s * 1.25 62 | 63 | # MPII uses matlab format, index is based 1, 64 | # we should first convert to 0-based index 65 | c = c - 1 66 | 67 | joints_3d = np.zeros((self.num_joints, 3), dtype=np.float) 68 | joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float) 69 | if self.image_set != 'test': 70 | joints = np.array(a['joints']) 71 | joints[:, 0:2] = joints[:, 0:2] - 1 72 | joints_vis = np.array(a['joints_vis']) 73 | assert len(joints) == self.num_joints, \ 74 | 'joint num diff: {} vs {}'.format(len(joints), 75 | self.num_joints) 76 | 77 | joints_3d[:, 0:2] = joints[:, 0:2] 78 | joints_3d_vis[:, 0] = joints_vis[:] 79 | joints_3d_vis[:, 1] = joints_vis[:] 80 | 81 | image_dir = 'images.zip@' if self.data_format == 'zip' else 'images' 82 | gt_db.append( 83 | { 84 | 'image': os.path.join(self.root, image_dir, image_name), 85 | 'center': c, 86 | 'scale': s, 87 | 'joints_3d': joints_3d, 88 | 'joints_3d_vis': joints_3d_vis, 89 | 'filename': '', 90 | 'imgnum': 0, 91 | } 92 | ) 93 | 94 | return gt_db 95 | 96 | def evaluate(self, cfg, preds, output_dir, *args, **kwargs): 97 | # convert 0-based index to 1-based index 98 | preds = preds[:, :, 0:2] + 1.0 99 | 100 | if output_dir: 101 | pred_file = os.path.join(output_dir, 'pred.mat') 102 | savemat(pred_file, mdict={'preds': preds}) 103 | 104 | if 'test' in cfg.DATASET.TEST_SET: 105 | return {'Null': 0.0}, 0.0 106 | 107 | SC_BIAS = 0.6 108 | threshold = 0.5 109 | 110 | gt_file = os.path.join(cfg.DATASET.ROOT, 111 | 'annot', 112 | 'gt_{}.mat'.format(cfg.DATASET.TEST_SET)) 113 | gt_dict = loadmat(gt_file) 114 | dataset_joints = gt_dict['dataset_joints'] 115 | jnt_missing = gt_dict['jnt_missing'] 116 | pos_gt_src = gt_dict['pos_gt_src'] 117 | headboxes_src = gt_dict['headboxes_src'] 118 | 119 | pos_pred_src = np.transpose(preds, 
[1, 2, 0]) 120 | 121 | head = np.where(dataset_joints == 'head')[1][0] 122 | lsho = np.where(dataset_joints == 'lsho')[1][0] 123 | lelb = np.where(dataset_joints == 'lelb')[1][0] 124 | lwri = np.where(dataset_joints == 'lwri')[1][0] 125 | lhip = np.where(dataset_joints == 'lhip')[1][0] 126 | lkne = np.where(dataset_joints == 'lkne')[1][0] 127 | lank = np.where(dataset_joints == 'lank')[1][0] 128 | 129 | rsho = np.where(dataset_joints == 'rsho')[1][0] 130 | relb = np.where(dataset_joints == 'relb')[1][0] 131 | rwri = np.where(dataset_joints == 'rwri')[1][0] 132 | rkne = np.where(dataset_joints == 'rkne')[1][0] 133 | rank = np.where(dataset_joints == 'rank')[1][0] 134 | rhip = np.where(dataset_joints == 'rhip')[1][0] 135 | 136 | jnt_visible = 1 - jnt_missing 137 | uv_error = pos_pred_src - pos_gt_src 138 | uv_err = np.linalg.norm(uv_error, axis=1) 139 | headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :] 140 | headsizes = np.linalg.norm(headsizes, axis=0) 141 | headsizes *= SC_BIAS 142 | scale = np.multiply(headsizes, np.ones((len(uv_err), 1))) 143 | scaled_uv_err = np.divide(uv_err, scale) 144 | scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible) 145 | jnt_count = np.sum(jnt_visible, axis=1) 146 | less_than_threshold = np.multiply((scaled_uv_err <= threshold), 147 | jnt_visible) 148 | PCKh = np.divide(100.*np.sum(less_than_threshold, axis=1), jnt_count) 149 | 150 | # save 151 | rng = np.arange(0, 0.5+0.01, 0.01) 152 | pckAll = np.zeros((len(rng), 16)) 153 | 154 | for r in range(len(rng)): 155 | threshold = rng[r] 156 | less_than_threshold = np.multiply(scaled_uv_err <= threshold, 157 | jnt_visible) 158 | pckAll[r, :] = np.divide(100.*np.sum(less_than_threshold, axis=1), 159 | jnt_count) 160 | 161 | PCKh = np.ma.array(PCKh, mask=False) 162 | PCKh.mask[6:8] = True 163 | 164 | jnt_count = np.ma.array(jnt_count, mask=False) 165 | jnt_count.mask[6:8] = True 166 | jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64) 167 | 168 | name_value = [ 169 | ('Head', PCKh[head]), 170 | ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])), 171 | ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])), 172 | ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])), 173 | ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])), 174 | ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])), 175 | ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])), 176 | ('Mean', np.sum(PCKh * jnt_ratio)), 177 | ('Mean@0.1', np.sum(pckAll[11, :] * jnt_ratio)) 178 | ] 179 | name_value = OrderedDict(name_value) 180 | 181 | return name_value, name_value['Mean'] 182 | -------------------------------------------------------------------------------- /data/label_list/annotation_list_25: -------------------------------------------------------------------------------- 1 | [48752, 48839, 48787, 48876, 48866, 48850, 48859, 48768, 48725, 48848, 48738, 48790, 48864, 48817, 48728, 48852, 48827, 48904, 48753, 48746, 48796, 48808, 48907, 48782, 48751, 49094, 49026, 49000, 48950, 49031, 48951, 49084, 48912, 49071, 48944, 48966, 48971, 49041, 49085, 48949, 49099, 48984, 49025, 49001, 49012, 49051, 49028, 49075, 49055, 49082, 28605, 26976, 29435, 29350, 30228, 29940, 29485, 29380, 28269, 30273, 27713, 30006, 29326, 29873, 29067, 26970, 27099, 29116, 30350, 29091, 29152, 27269, 28943, 29510, 28852, 30999, 30740, 31066, 30595, 30977, 30971, 31643, 31554, 30806, 30970, 30984, 30651, 30507, 31566, 30817, 31443, 31088, 30983, 31133, 30990, 31577, 31233, 30992, 31199, 31555, 49245, 49312, 49309, 49406, 49265, 49287, 49269, 49226, 49323, 49317, 49303, 49276, 49346, 49264, 49291, 49349, 49296, 49410, 
49236, 49350, 49263, 49360, 49369, 49241, 49290, 49984, 50041, 50193, 50104, 50220, 50053, 50095, 50208, 50055, 50214, 49978, 50052, 50235, 50105, 50004, 50238, 49993, 50191, 50213, 50036, 50231, 50145, 50022, 49983, 50147, 200, 149, 103, 135, 237, 180, 108, 155, 58, 217, 282, 143, 234, 107, 226, 221, 96, 172, 260, 240, 125, 106, 715, 90, 84, 1079, 1038, 1075, 1092, 1083, 1050, 1072, 1060, 1073, 1091, 1030, 1061, 1098, 1103, 1028, 1108, 1036, 1040, 1025, 1037, 1109, 1064, 1031, 1058, 1049, 1513, 1218, 1935, 1680, 1193, 1977, 1128, 1122, 1197, 1217, 1813, 1182, 1144, 1132, 1236, 1117, 1702, 1126, 1175, 1143, 1114, 1222, 1967, 1224, 1172, 2156, 2109, 2106, 2158, 2151, 2104, 2111, 2015, 2173, 1988, 1983, 2174, 2138, 2069, 2099, 2072, 2160, 2095, 2112, 2123, 2146, 2088, 2075, 2150, 2159, 4521, 4331, 3346, 4238, 5033, 4657, 4086, 4781, 4768, 4422, 2821, 3717, 2688, 5433, 4879, 2455, 5096, 4706, 4109, 2610, 4373, 5444, 4853, 4817, 5011, 7102, 7249, 7753, 7745, 7773, 7286, 6966, 7409, 7274, 7213, 8050, 7446, 8063, 7964, 7336, 8469, 7728, 7557, 7734, 5822, 7763, 6111, 5862, 5989, 8246, 43256, 43101, 43109, 43203, 43179, 43211, 43126, 43193, 43189, 43250, 43087, 43105, 43107, 43113, 43093, 43174, 43213, 43095, 43226, 43188, 43244, 43210, 43092, 43106, 43231, 45128, 45039, 45087, 45098, 45001, 44976, 44943, 44926, 44969, 44960, 44953, 45082, 44962, 45015, 44998, 44935, 44930, 45014, 45086, 45116, 45020, 45013, 45113, 44986, 44981, 46201, 46273, 46157, 46215, 46251, 46286, 46271, 46165, 46183, 46193, 46168, 46256, 46233, 46176, 46238, 46209, 46228, 46290, 46178, 46268, 46284, 46265, 46109, 46149, 46151, 20054, 20033, 20072, 20075, 20040, 20038, 20066, 20020, 20023, 20049, 20055, 20044, 20078, 20043, 20017, 20067, 20039, 20057, 20032, 20051, 20076, 20026, 20010, 20013, 20061, 20165, 20191, 20234, 20103, 20155, 20100, 20180, 20184, 20122, 20172, 20127, 20235, 20158, 20242, 20213, 20230, 20237, 20173, 20104, 20194, 20124, 20116, 20176, 20139, 20121, 20329, 20280, 20250, 20252, 20328, 20278, 20245, 20336, 20277, 20246, 20270, 20279, 20332, 20275, 20330, 20265, 20290, 20253, 20272, 20305, 20298, 20311, 20249, 20304, 20292, 20399, 20455, 20354, 20376, 20568, 20338, 20368, 20427, 20452, 20496, 20535, 20385, 20493, 20403, 20437, 20475, 20566, 20581, 20543, 20339, 20557, 20538, 20411, 20480, 20422, 20682, 20639, 20665, 20610, 20592, 20670, 20655, 20646, 20601, 20632, 20645, 20607, 20634, 20609, 20661, 20642, 20684, 20683, 20630, 20672, 20643, 20599, 20635, 20602, 20620, 48578, 48706, 48607, 48710, 48546, 48649, 48692, 48713, 48696, 48708, 48639, 48567, 48587, 48542, 48584, 48678, 48637, 48707, 48571, 48724, 48547, 48716, 48570, 48588, 48615, 31989, 32064, 31957, 32017, 32013, 32054, 31950, 31930, 32065, 32072, 31904, 31919, 31969, 31952, 31923, 32047, 32073, 32005, 32029, 31954, 32002, 31892, 31932, 32011, 32071, 35695, 34250, 40406, 32828, 40436, 40808, 35616, 40782, 40047, 40658, 39933, 40778, 40269, 34461, 40492, 40428, 40470, 36017, 40413, 34984, 40812, 35550, 33306, 34906, 36239, 36434, 36338, 36481, 36482, 36367, 36319, 36394, 36428, 36397, 36365, 36477, 36420, 36511, 36342, 36361, 36493, 36487, 36470, 36321, 36341, 36333, 36307, 36362, 36404, 36349, 37590, 37591, 37728, 37559, 37826, 37679, 37713, 37833, 37584, 37815, 37522, 37697, 37805, 37626, 37534, 37601, 37618, 37526, 37588, 37698, 37736, 37685, 37785, 37681, 37733, 37888, 37889, 37891, 37892, 37893, 37894, 37895, 37866, 37867, 37870, 37871, 37872, 37874, 37876, 37877, 37878, 37881, 37882, 37883, 37884, 37885, 37886, 38497, 38286, 38041, 37993, 
38674, 38353, 37979, 38675, 38031, 38486, 37909, 38035, 37953, 38036, 37901, 37968, 37910, 38408, 37948, 38704, 38032, 38017, 37997, 38047, 38023, 38752, 38832, 38785, 38739, 38855, 38840, 38748, 38914, 38728, 38836, 38737, 38819, 38721, 38863, 38792, 38811, 38860, 38825, 38852, 38738, 38900, 38843, 38731, 38745, 38751, 39908, 39779, 39863, 39746, 39768, 39892, 39857, 39889, 39862, 39923, 39836, 39803, 39770, 39852, 39781, 39813, 39871, 39841, 39802, 39920, 39751, 39918, 39780, 39745, 39915, 40870, 40894, 40934, 40912, 40846, 40868, 40849, 40904, 40909, 40928, 40921, 40930, 40848, 40887, 40897, 40929, 40860, 40916, 40886, 40882, 40893, 40875, 40880, 40850, 40925, 40945, 41080, 40961, 40964, 41120, 41131, 41039, 41051, 41138, 40966, 40952, 41102, 41003, 41054, 41045, 41078, 41011, 41104, 41116, 41041, 40969, 41067, 41097, 41114, 41020, 47652, 47533, 47541, 47673, 47507, 47486, 47638, 47558, 47627, 47456, 47595, 47510, 47498, 47502, 47599, 47484, 47572, 47611, 47602, 47505, 47590, 47589, 47639, 47647, 47493, 50272, 50350, 50326, 50258, 50279, 50309, 50427, 50378, 50292, 50259, 50348, 50355, 50385, 50308, 50341, 50382, 50334, 50373, 50325, 50278, 50409, 50417, 50262, 50426, 50359, 19982, 19976, 20004, 19921, 19858, 19953, 19893, 19926, 19947, 19901, 19984, 19999, 19985, 19952, 20007, 19937, 19881, 19865, 19835, 19912, 19973, 19948, 19890, 19916, 19918, 44248, 44356, 44302, 44294, 44345, 44331, 44280, 44357, 44722, 44511, 44600, 44317, 44589, 44298, 44244, 44308, 44343, 44909, 44285, 44711, 44292, 44500, 44262, 44273, 44655, 23699, 23645, 23483, 23579, 23710, 23624, 23724, 23523, 23610, 23732, 23690, 23628, 23703, 23534, 23435, 23706, 23566, 23568, 23720, 23684, 23604, 23455, 23603, 23485, 23707, 54982, 54901, 54913, 54906, 54851, 54828, 54829, 55000, 54951, 54987, 54947, 54949, 54955, 54918, 54856, 54935, 54817, 54892, 54816, 54838, 54831, 54917, 54985, 55004, 54952, 20762, 20795, 20724, 20756, 20790, 20721, 20699, 20850, 20860, 20742, 20686, 20792, 20863, 20700, 20738, 20703, 20789, 20879, 20866, 20770, 20869, 20874, 20858, 20778, 20716, 22642, 22058, 22330, 22649, 22646, 22082, 22640, 22274, 22016, 22067, 22015, 22068, 22076, 22022, 22627, 22196, 22285, 22218, 22085, 22152, 22207, 22429, 22363, 22653, 22115, 55645, 55653, 55631, 55655, 55664, 55669, 55627, 55657, 55630, 55647, 55652, 55659, 55661, 55634, 55624, 55626, 55667, 55642, 55663, 55625, 55635, 55648, 55665, 55658, 55641, 55884, 55892, 55782, 55899, 55893, 55709, 55764, 55699, 55901, 55859, 55783, 55704, 55873, 55862, 55748, 55705, 55755, 55674, 55767, 55820, 55902, 55839, 55692, 55750, 55896, 56696, 56745, 56707, 56557, 56580, 56656, 56627, 56655, 56714, 56608, 56693, 56670, 56748, 56692, 56625, 56675, 56642, 56597, 56561, 56702, 56687, 56673, 56605, 56728, 56571, 58535, 58492, 58528, 58558, 58450, 58592, 58463, 58569, 58470, 58476, 58523, 58546, 58480, 58447, 58506, 58497, 58475, 58474, 58575, 58510, 58520, 58618, 58504, 58478, 58577, 18324, 19159, 10785, 12869, 16735, 10364, 10359, 18412, 10347, 8883, 18406, 19026, 14079, 13409, 10498, 18959, 8857, 9014, 9248, 12721, 12862, 18393, 9409, 12717, 8830, 17709, 17658, 17682, 17680, 17646, 17755, 18167, 18201, 17720, 17655, 17690, 18156, 18303, 17713, 17707, 17879, 18045, 17747, 18001, 18280, 17714, 17708, 17660, 17669, 17665, 19792, 19350, 19816, 19552, 19652, 19441, 19826, 19789, 19810, 19306, 19797, 19267, 19812, 19619, 19279, 19260, 19258, 19269, 19290, 19363, 19719, 19831, 19300, 19264, 19302, 23292, 23419, 22774, 22730, 23409, 23004, 22691, 22993, 22715, 23303, 22717, 23337, 
22751, 22702, 23215, 22704, 22662, 23370, 22837, 22735, 22772, 23429, 22737, 22709, 22763, 51809, 51876, 51853, 51829, 51870, 51807, 51812, 51878, 51863, 51814, 51805, 51825, 51854, 51822, 51813, 51835, 51818, 51808, 51827, 51881, 51856, 51802, 51880, 51819, 51850, 52086, 52068, 51983, 52130, 51982, 52084, 52059, 51903, 51899, 52150, 52058, 52038, 51939, 52096, 52103, 52137, 52032, 51896, 51993, 51925, 51953, 52007, 52034, 51980, 52071, 50813, 50815, 50825, 50857, 50957, 51068, 51128, 50790, 51020, 50858, 50820, 50954, 51112, 50929, 50870, 51078, 51246, 51024, 50913, 51228, 51192, 51179, 50967, 51086, 50760] -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | import os 8 | import pprint 9 | 10 | import torch 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim as optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torchvision.transforms as transforms 17 | from tensorboardX import SummaryWriter 18 | 19 | import _init_paths 20 | from config import cfg 21 | from config import update_config 22 | from core.loss import JointsMSELoss 23 | from core.function import train 24 | from core.function import validate 25 | from utils.utils import get_optimizer 26 | from utils.utils import save_checkpoint 27 | from utils.utils import create_logger 28 | from utils.utils import get_model_summary 29 | import models 30 | import dataset_animal 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description='Train keypoints network') 34 | # general 35 | parser.add_argument('--cfg', 36 | help='experiment configure file name', 37 | required=True, 38 | type=str) 39 | 40 | parser.add_argument('opts', 41 | help="Modify config options using the command-line", 42 | default=None, 43 | nargs=argparse.REMAINDER) 44 | 45 | # philly 46 | parser.add_argument('--modelDir', 47 | help='model directory', 48 | type=str, 49 | default='') 50 | parser.add_argument('--logDir', 51 | help='log directory', 52 | type=str, 53 | default='') 54 | parser.add_argument('--dataDir', 55 | help='data directory', 56 | type=str, 57 | default='') 58 | parser.add_argument('--prevModelDir', 59 | help='prev Model directory', 60 | type=str, 61 | default='') 62 | 63 | parser.add_argument('--animalpose', 64 | help='train on ap10k', 65 | action='store_true') 66 | 67 | parser.add_argument('--fewshot', 68 | help='train on ap10k with few shot annotations', 69 | action='store_true') 70 | 71 | parser.add_argument('--pretrained', 72 | help='path for pretrained model', 73 | type=str, 74 | default='') 75 | parser.add_argument('--resume', help='path to resume', type=str, default='') 76 | 77 | parser.add_argument('--evaluate', action='store_true') 78 | args = parser.parse_args() 79 | 80 | return args 81 | 82 | 83 | def main(): 84 | args = parse_args() 85 | update_config(cfg, args) 86 | 87 | logger, final_output_dir, tb_log_dir = create_logger( 88 | cfg, args.cfg, 'train') 89 | 90 | logger.info(pprint.pformat(args)) 91 | logger.info(cfg) 92 | 93 | # cudnn related setting 94 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 95 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 96 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 97 | 98 | model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( 99 | 
cfg, is_train=True 100 | ) 101 | 102 | writer_dict = { 103 | 'writer': SummaryWriter(log_dir=tb_log_dir), 104 | 'train_global_steps': 0, 105 | 'valid_global_steps': 0, 106 | } 107 | 108 | dump_input = torch.rand( 109 | (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]) 110 | ) 111 | 112 | logger.info(get_model_summary(model, dump_input)) 113 | 114 | # define loss function (criterion) and optimizer 115 | criterion = JointsMSELoss( 116 | use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT 117 | ).cuda() 118 | 119 | # Data loading code 120 | normalize = transforms.Normalize( 121 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 122 | ) 123 | 124 | if args.animalpose: 125 | train_dataset = eval('dataset_animal.' + cfg.DATASET.DATASET)( 126 | cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True, 127 | transforms.Compose([ 128 | transforms.ToTensor(), 129 | normalize, 130 | ]) 131 | ) 132 | valid_dataset = eval('dataset_animal.' + 'ap10k')( 133 | cfg, cfg.DATASET.ROOT, cfg.DATASET.VAL_SET, False, 134 | transforms.Compose([ 135 | transforms.ToTensor(), 136 | normalize, 137 | ]) 138 | ) 139 | else: 140 | train_dataset = eval('dataset.'+cfg.DATASET.DATASET)( 141 | cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True, 142 | transforms.Compose([ 143 | transforms.ToTensor(), 144 | normalize, 145 | ]) 146 | ) 147 | valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)( 148 | cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, 149 | transforms.Compose([ 150 | transforms.ToTensor(), 151 | normalize, 152 | ]) 153 | ) 154 | 155 | train_loader = torch.utils.data.DataLoader( 156 | train_dataset, 157 | batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS), 158 | shuffle=cfg.TRAIN.SHUFFLE, 159 | num_workers=cfg.WORKERS, 160 | pin_memory=cfg.PIN_MEMORY 161 | ) 162 | valid_loader = torch.utils.data.DataLoader( 163 | valid_dataset, 164 | batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS), 165 | shuffle=False, 166 | num_workers=cfg.WORKERS, 167 | pin_memory=cfg.PIN_MEMORY 168 | ) 169 | 170 | best_perf = 0.0 171 | last_epoch = -1 172 | best_perf_epoch = 0 173 | optimizer = get_optimizer(cfg, model) 174 | 175 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH 176 | checkpoint_file = os.path.join( 177 | final_output_dir, 'checkpoint.pth' 178 | ) 179 | 180 | # load pretrained model 181 | if os.path.exists(args.pretrained): 182 | logger.info("=> loading checkpoint '{}'".format(args.pretrained)) 183 | pretrained_model = torch.load(args.pretrained) 184 | model.load_state_dict(pretrained_model['state_dict']) 185 | 186 | # resume pretrained model 187 | if os.path.exists(args.resume): 188 | logger.info("=> resume from checkpoint '{}'".format(args.resume)) 189 | checkpoint = torch.load(args.resume) 190 | begin_epoch = checkpoint['epoch'] 191 | best_perf = checkpoint['perf'] 192 | last_epoch = checkpoint['epoch'] 193 | model.load_state_dict(checkpoint['state_dict']) 194 | optimizer.load_state_dict(checkpoint['optimizer']) 195 | logger.info("=> loaded checkpoint '{}' (epoch {})".format( 196 | checkpoint_file, checkpoint['epoch'])) 197 | 198 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 199 | optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR, 200 | last_epoch=last_epoch 201 | ) 202 | 203 | model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda() 204 | 205 | # evaluate 206 | if args.evaluate: 207 | 208 | acc = validate(cfg, valid_loader, valid_dataset, model, criterion, 209 | final_output_dir, tb_log_dir, writer_dict, args.animalpose) 210 | return 211 | 212 | # train 213 | for epoch in range(begin_epoch, 
cfg.TRAIN.END_EPOCH): 214 | lr_scheduler.step() 215 | 216 | # train for one epoch 217 | train(cfg, train_loader, model, criterion, optimizer, epoch, 218 | final_output_dir, tb_log_dir, writer_dict) 219 | 220 | # evaluate on validation set 221 | perf_indicator = validate( 222 | cfg, valid_loader, valid_dataset, model, criterion, 223 | final_output_dir, tb_log_dir, writer_dict, args.animalpose 224 | ) 225 | 226 | if perf_indicator >= best_perf: 227 | best_perf = perf_indicator 228 | best_model = True 229 | best_perf_epoch = epoch + 1 230 | else: 231 | best_model = False 232 | 233 | logger.info('=> saving checkpoint to {}'.format(final_output_dir)) 234 | save_checkpoint({ 235 | 'epoch': epoch + 1, 236 | 'model': cfg.MODEL.NAME, 237 | 'state_dict': model.module.state_dict(), 238 | 'perf': perf_indicator, 239 | 'optimizer': optimizer.state_dict(), 240 | }, best_model, final_output_dir) 241 | 242 | final_model_state_file = os.path.join( 243 | final_output_dir, 'final_state.pth' 244 | ) 245 | logger.info('=> saving final model state to {}'.format( 246 | final_model_state_file) 247 | ) 248 | logger.info('Best accuracy {} at epoch {}'.format(best_perf, best_perf_epoch)) 249 | torch.save(model.module.state_dict(), final_model_state_file) 250 | writer_dict['writer'].close() 251 | 252 | 253 | if __name__ == '__main__': 254 | main() 255 | -------------------------------------------------------------------------------- /lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import logging 13 | import shutil 14 | import time 15 | from collections import namedtuple 16 | from pathlib import Path 17 | import numpy as np 18 | 19 | import torch 20 | import torch.optim as optim 21 | import torch.nn as nn 22 | 23 | 24 | def create_logger(cfg, cfg_name, phase='train'): 25 | # root_output_dir = Path(cfg.OUTPUT_DIR) 26 | root_output_dir = Path('output') 27 | # set up logger 28 | if not root_output_dir.exists(): 29 | print('=> creating {}'.format(root_output_dir)) 30 | root_output_dir.mkdir() 31 | 32 | dataset = cfg.DATASET.DATASET + '_' + cfg.DATASET.HYBRID_JOINTS_TYPE \ 33 | if cfg.DATASET.HYBRID_JOINTS_TYPE else cfg.DATASET.DATASET 34 | dataset = dataset.replace(':', '_') 35 | model = cfg.MODEL.NAME 36 | cfg_name = os.path.basename(cfg_name).split('.')[0] 37 | 38 | # final_output_dir = root_output_dir / dataset / model / cfg_name 39 | final_output_dir = Path('output/' + cfg.OUTPUT_DIR) 40 | print('=> creating {}'.format(final_output_dir)) 41 | final_output_dir.mkdir(parents=True, exist_ok=True) 42 | 43 | time_str = time.strftime('%Y-%m-%d-%H-%M') 44 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 45 | final_log_file = final_output_dir / log_file 46 | head = '%(asctime)-15s %(message)s' 47 | logging.basicConfig(filename=str(final_log_file), 48 | format=head) 49 | logger = logging.getLogger() 50 | logger.setLevel(logging.INFO) 51 | console = logging.StreamHandler() 52 | logging.getLogger('').addHandler(console) 53 | 54 | # tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \ 55 | # (cfg_name + '_' + time_str) 56 | 
tensorboard_log_dir = final_output_dir / cfg.LOG_DIR 57 | print('=> creating {}'.format(tensorboard_log_dir)) 58 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True) 59 | 60 | return logger, str(final_output_dir), str(tensorboard_log_dir) 61 | 62 | 63 | def get_optimizer(cfg, model): 64 | optimizer = None 65 | if cfg.TRAIN.OPTIMIZER == 'sgd': 66 | optimizer = optim.SGD( 67 | model.parameters(), 68 | lr=cfg.TRAIN.LR, 69 | momentum=cfg.TRAIN.MOMENTUM, 70 | weight_decay=cfg.TRAIN.WD, 71 | nesterov=cfg.TRAIN.NESTEROV 72 | ) 73 | elif cfg.TRAIN.OPTIMIZER == 'adam': 74 | optimizer = optim.Adam( 75 | model.parameters(), 76 | lr=cfg.TRAIN.LR 77 | ) 78 | 79 | return optimizer 80 | 81 | 82 | def save_checkpoint(states, is_best, output_dir, 83 | filename='checkpoint.pth'): 84 | filepath = os.path.join(output_dir, filename) 85 | torch.save(states, filepath) 86 | if is_best: 87 | shutil.copyfile(filepath, os.path.join(output_dir, 'model_best.pth')) 88 | 89 | 90 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False): 91 | """ 92 | :param model: 93 | :param input_tensors: 94 | :param item_length: 95 | :return: 96 | """ 97 | 98 | summary = [] 99 | 100 | ModuleDetails = namedtuple( 101 | "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"]) 102 | hooks = [] 103 | layer_instances = {} 104 | 105 | def add_hooks(module): 106 | 107 | def hook(module, input, output): 108 | class_name = str(module.__class__.__name__) 109 | 110 | instance_index = 1 111 | if class_name not in layer_instances: 112 | layer_instances[class_name] = instance_index 113 | else: 114 | instance_index = layer_instances[class_name] + 1 115 | layer_instances[class_name] = instance_index 116 | 117 | layer_name = class_name + "_" + str(instance_index) 118 | 119 | params = 0 120 | 121 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \ 122 | class_name.find("Linear") != -1: 123 | for param_ in module.parameters(): 124 | params += param_.view(-1).size(0) 125 | 126 | flops = "Not Available" 127 | if class_name.find("Conv") != -1 and hasattr(module, "weight"): 128 | flops = ( 129 | torch.prod( 130 | torch.LongTensor(list(module.weight.data.size()))) * 131 | torch.prod( 132 | torch.LongTensor(list(output.size())[2:]))).item() 133 | elif isinstance(module, nn.Linear): 134 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \ 135 | * input[0].size(1)).item() 136 | 137 | if isinstance(input[0], list): 138 | input = input[0] 139 | if isinstance(output, list): 140 | output = output[0] 141 | 142 | summary.append( 143 | ModuleDetails( 144 | name=layer_name, 145 | input_size=list(input[0].size()), 146 | output_size=list(output.size()), 147 | num_parameters=params, 148 | multiply_adds=flops) 149 | ) 150 | 151 | if not isinstance(module, nn.ModuleList) \ 152 | and not isinstance(module, nn.Sequential) \ 153 | and module != model: 154 | hooks.append(module.register_forward_hook(hook)) 155 | 156 | model.eval() 157 | model.apply(add_hooks) 158 | 159 | space_len = item_length 160 | 161 | model(*input_tensors) 162 | for hook in hooks: 163 | hook.remove() 164 | 165 | details = '' 166 | if verbose: 167 | details = "Model Summary" + \ 168 | os.linesep + \ 169 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format( 170 | ' ' * (space_len - len("Name")), 171 | ' ' * (space_len - len("Input Size")), 172 | ' ' * (space_len - len("Output Size")), 173 | ' ' * (space_len - len("Parameters")), 174 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \ 175 | + 
os.linesep + '-' * space_len * 5 + os.linesep 176 | 177 | params_sum = 0 178 | flops_sum = 0 179 | for layer in summary: 180 | params_sum += layer.num_parameters 181 | if layer.multiply_adds != "Not Available": 182 | flops_sum += layer.multiply_adds 183 | if verbose: 184 | details += "{}{}{}{}{}{}{}{}{}{}".format( 185 | layer.name, 186 | ' ' * (space_len - len(layer.name)), 187 | layer.input_size, 188 | ' ' * (space_len - len(str(layer.input_size))), 189 | layer.output_size, 190 | ' ' * (space_len - len(str(layer.output_size))), 191 | layer.num_parameters, 192 | ' ' * (space_len - len(str(layer.num_parameters))), 193 | layer.multiply_adds, 194 | ' ' * (space_len - len(str(layer.multiply_adds)))) \ 195 | + os.linesep + '-' * space_len * 5 + os.linesep 196 | 197 | details += os.linesep \ 198 | + "Total Parameters: {:,}".format(params_sum) \ 199 | + os.linesep + '-' * space_len * 5 + os.linesep 200 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \ 201 | + os.linesep + '-' * space_len * 5 + os.linesep 202 | details += "Number of Layers" + os.linesep 203 | for layer in layer_instances: 204 | details += "{} : {} layers ".format(layer, layer_instances[layer]) 205 | 206 | return details 207 | 208 | 209 | def update_ema_variables(model, ema_model, alpha, global_step): 210 | # Use the true average until the exponential average is more correct 211 | alpha = min(1 - 1 / (global_step + 1), alpha) 212 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 213 | # ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 214 | ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) 215 | 216 | 217 | def update_ema_statedict(model, ema_model, alpha, global_step): 218 | alpha = min(1 - 1 / (global_step + 1), alpha) 219 | state_dict = model.state_dict() 220 | state_dict_ema = ema_model.state_dict() 221 | for (k, v), (k_ema, v_ema) in zip(state_dict.items(), state_dict_ema.items()): 222 | assert k == k_ema 223 | v_ema.copy_(v_ema * alpha + (1. 
- alpha) * v) 224 | 225 | 226 | def update_ema_variables_spatial(model, ema_model, alpha, global_step): 227 | # Use the true average until the exponential average is more correct 228 | alpha = min(1 - 1 / (global_step + 1), alpha) 229 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 230 | # ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 231 | tmp_prob = np.random.rand() 232 | if tmp_prob < 0.6: 233 | pass 234 | else: 235 | ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) 236 | 237 | 238 | def sigmoid_rampup(current, rampup_length): 239 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 240 | if rampup_length == 0: 241 | return 1.0 242 | else: 243 | current = np.clip(current, 0.0, rampup_length) 244 | phase = 1.0 - current / rampup_length 245 | return float(np.exp(-5.0 * phase * phase)) 246 | 247 | 248 | def get_current_consistency_weight(const_weight, epoch, consistency_rampup): 249 | return const_weight * sigmoid_rampup(epoch, consistency_rampup) 250 | 251 | 252 | def cosine_rampdown(current, rampdown_length): 253 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 254 | # assert 0 <= current <= rampdown_length 255 | current = np.clip(current, 0.0, rampdown_length) 256 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 257 | 258 | 259 | def get_current_topkrate(epoch, rampdown_epoch, min_rate): 260 | r = cosine_rampdown(epoch, rampdown_epoch) 261 | return np.clip(r, min_rate, 1) -------------------------------------------------------------------------------- /lib/core/function1.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import logging 7 | 8 | import numpy as np 9 | import torch 10 | from core.evaluate import accuracy 11 | from core.inference import get_final_preds, get_final_preds_const, get_max_preds 12 | from utils.transforms import flip_back, fliplr_joints_batch, fliplr_weights_batch, fliplr_joints_batch_v2 13 | from utils.utils import update_ema_variables, get_current_consistency_weight, get_current_topkrate 14 | from core.function import AverageMeter 15 | from core.loss import select_small_loss_samples_v2 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def train_mt_update(config, args, train_loader, dataset, model, model_ema, criterion, optimizer, epoch, 20 | output_dir, tb_log_dir, writer_dict): 21 | 22 | batch_time = AverageMeter() 23 | data_time = AverageMeter() 24 | losses = AverageMeter() 25 | losses_sup = AverageMeter() 26 | losses_consistency = AverageMeter() 27 | acc = AverageMeter() 28 | acc_ema = AverageMeter() 29 | # switch to train mode 30 | model.train() 31 | model_ema.train() 32 | end = time.time() 33 | 34 | # ratio of small loss samples, reduce gradually to avoid overfitting to initial pseudo labels 35 | topk_rate = get_current_topkrate(epoch, args.topkrampdown, args.minrate) 36 | for i, (input, target, target_weight, input_ema, input1, meta) in enumerate(train_loader): 37 | 38 | # measure data loading time 39 | data_time.update(time.time() - end) 40 | input = input.cuda() 41 | input_ema = input_ema.cuda() 42 | input2 = input_ema.detach().clone() 43 | input1 = input1.cuda() 44 | target = target.cuda(non_blocking=True) 45 | target_weight = target_weight.cuda(non_blocking=True) 46 | use_label = meta['use_label'].cuda(non_blocking=True) 47 | flip_var = meta['flip'].numpy() 48 | img_width = 
meta['img_width'].numpy()
49 |         warpmat1 = meta['warpmat1'].numpy()
50 |         warpmat2 = meta['warpmat_v2'].numpy()
51 |         flip_var_v2 = meta['flip_v2'].numpy()
52 |         # compute output
53 |         outputs, _ = model(input)
54 |         outputs_copy = outputs.detach().clone()
55 | 
56 |         # select small loss samples
57 |         small_loss_idx = select_small_loss_samples_v2(outputs_copy, target, target_weight, topk_rate)
58 |         weights_small_loss = torch.zeros_like(target_weight)
59 |         weights_small_loss[small_loss_idx[:, 0], small_loss_idx[:, 1], 0] = 1
60 | 
61 |         if epoch > args.update_epoch:
62 |             with torch.no_grad():
63 |                 outputs1, _ = model(input1)
64 |                 outputs2, _ = model(input2)
65 |             # compute joint locations for input_v1, in the scale of the heatmap size
66 |             joints1, _ = get_max_preds(outputs1.detach().cpu().numpy())
67 |             c2 = meta['center_ema'].numpy()
68 |             s2 = meta['scale_ema'].numpy()
69 | 
70 |             joints2_vis = meta['joints_vis_ema'][:, :, :2].numpy()
71 |             # compute the joint locations for input_v2, in the scale of the original image size
72 |             joints2, _ = get_final_preds_const(outputs2.cpu().numpy(), c2, s2)
73 |             # flip output_v2 to keep it consistent with input_v1
74 |             joints2_flip, joints2_vis_flip = fliplr_joints_batch_v2(joints2, joints2_vis, img_width[:, None],
75 |                                                                     dataset.flip_pairs)
76 |             joints2_1 = np.where(flip_var_v2[:, None, None], joints2_flip, joints2)
77 |             joints2_1 = np.concatenate([joints2_1[:, :, :2], torch.ones(joints2_1.shape[0], joints2_1.shape[1], 1)], axis=-1)
78 |             # transform output_v2 to keep it consistent with input_v1
79 |             joints21_trans = np.matmul(warpmat2, np.transpose(joints2_1, (0, 2, 1)))
80 |             joints21_trans = np.transpose(joints21_trans, (0, 2, 1))
81 |             joints21_trans_hmsize = joints21_trans/4 + 0.5
82 |             # compute the distance between the predictions for input_v1 and input_v2
83 |             dist = np.linalg.norm(joints1 - joints21_trans_hmsize, axis=-1, keepdims=True)
84 | 
85 |             # use the current model prediction as supervision, hence we need to transform the prediction to keep
86 |             # it consistent with the input
87 |             joints2_ori = np.where(flip_var[:, None, None], joints2_flip, joints2)
88 |             joints2_ori = np.concatenate([joints2_ori[:, :, :2], torch.ones(joints2_ori.shape[0], joints2_ori.shape[1], 1)], axis=-1)
89 |             joints2_ori_trans = np.matmul(warpmat1, np.transpose(joints2_ori, (0, 2, 1)))
90 |             joints2_ori_trans = np.transpose(joints2_ori_trans, (0, 2, 1))
91 | 
92 |             # agreement check to select reusable samples
93 |             mask = (dist < args.dist_thre)
94 |             weights2 = mask.astype(float)
95 |             # exclude the small loss samples
96 |             weights2 = weights2 * (1 - weights_small_loss.cpu().numpy())
97 | 
98 |             # re-labeling for reusable samples
99 |             hms2_re = np.zeros_like(target.cpu().numpy())
100 |             targets_weight2_re = np.zeros_like(target_weight.cpu().numpy())
101 |             for b in range(hms2_re.shape[0]):
102 |                 hm2_re, target_weight2_re = dataset.generate_target(joints2_ori_trans[b], weights2[b])
103 |                 hms2_re[b] = hm2_re
104 |                 targets_weight2_re[b] = target_weight2_re
105 |             hms2_re = torch.from_numpy(hms2_re).cuda()
106 |             targets_weight2_re = torch.from_numpy(targets_weight2_re).cuda()
107 | 
108 |             mask_ = torch.tensor(weights2 > 0).cuda()
109 |             # use the model prediction for reusable samples, and the initial pseudo label otherwise
110 |             target2 = torch.where(mask_[:, :, :, None], hms2_re, target)
111 |             targets_weight2 = targets_weight2_re + weights_small_loss
112 |             assert targets_weight2.max() < 2
113 | 
114 |             # increase the weights for the supervised loss
115 |             target = torch.where(use_label[:, None, None, None], target, target2)
116 |             target_weight = 
args.true_label_w * target_weight
117 |             target_weight = torch.where(use_label[:, None, None], target_weight, targets_weight2)
118 | 
119 |         if isinstance(outputs, list):
120 |             loss_sup = criterion(outputs[0], target, target_weight)
121 |             for output in outputs[1:]:
122 |                 loss_sup += criterion(output, target, target_weight)
123 |         else:
124 |             output = outputs
125 |             loss_sup = criterion(output, target, target_weight)
126 | 
127 |         with torch.no_grad():
128 |             if epoch > args.update_epoch:
129 |                 # feed the student input into the teacher network although we do not use the result; this keeps
130 |                 # the batch norm statistics of the student and teacher from drifting too far apart
131 |                 _, _ = model_ema(input)
132 |             outputs_ema, _ = model_ema(input_ema)
133 | 
134 |         # student-teacher consistency
135 |         c_ema = meta['center_ema'].numpy()
136 |         s_ema = meta['scale_ema'].numpy()
137 |         joints_vis = meta['joints_vis_ema'][:, :, :2].numpy()
138 |         joints_ori, _ = get_final_preds_const(outputs_ema.cpu().numpy(), c_ema, s_ema)
139 |         # transform the teacher output to keep it consistent with the student's input view
140 |         joints_ori_flip, joints_vis_flip = fliplr_joints_batch(joints_ori, joints_vis, img_width[:, None], dataset.flip_pairs)
141 |         joints = np.where(flip_var[:, None, None], joints_ori_flip, joints_ori)
142 |         joints_vis = np.where(flip_var[:, None, None], joints_vis_flip, joints_vis)
143 |         joints = np.concatenate([joints[:, :, :2], torch.ones(joints.shape[0], joints.shape[1], 1)], axis=-1)
144 |         joints_trans = np.matmul(warpmat1, np.transpose(joints, (0, 2, 1)))
145 |         joints_trans = np.transpose(joints_trans, (0, 2, 1))
146 | 
147 |         hms_ema_re = np.zeros_like(target.cpu().numpy())
148 |         targets_weight_ema_re = np.zeros_like(target_weight.cpu().numpy())
149 |         for b in range(hms_ema_re.shape[0]):
150 |             hm_ema_re, target_weight_ema_re = dataset.generate_target(joints_trans[b], joints_vis[b])
151 |             hms_ema_re[b] = hm_ema_re
152 |             targets_weight_ema_re[b] = target_weight_ema_re
153 | 
154 |         hms_ema_re = torch.from_numpy(hms_ema_re).cuda()
155 |         target_weight_ema_re = torch.from_numpy(targets_weight_ema_re).cuda()
156 | 
157 |         loss_consistency = criterion(output, hms_ema_re, target_weight_ema_re)
158 |         const_loss_weight = get_current_consistency_weight(args.const_weight, epoch, args.consistency_rampup)
159 |         loss = const_loss_weight * loss_consistency + loss_sup
160 |         # compute gradient and do update step
161 | 
162 |         writer = writer_dict['writer']
163 |         global_steps = writer_dict['train_global_steps']
164 | 
165 |         optimizer.zero_grad()
166 |         loss.backward()
167 |         optimizer.step()
168 |         update_ema_variables(model, model_ema, 0.999, global_steps)
169 |         # measure accuracy and record loss
170 |         losses.update(loss.item(), input.size(0))
171 |         losses_sup.update(loss_sup.item(), input.size(0))
172 |         losses_consistency.update(loss_consistency.item(), input.size(0))
173 |         _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
174 |                                          target.detach().cpu().numpy())
175 |         _, avg_acc_ema, cnt_ema, pred_ema = accuracy(outputs_ema.cpu().numpy(),
176 |                                                      target.detach().cpu().numpy())
177 |         acc.update(avg_acc, cnt)
178 |         acc_ema.update(avg_acc_ema, cnt_ema)
179 |         # measure elapsed time
180 |         batch_time.update(time.time() - end)
181 |         end = time.time()
182 | 
183 |         if i % config.PRINT_FREQ == 0:
184 |             msg = 'Epoch: [{0}][{1}/{2}]\t' \
185 |                   'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
186 |                   'Speed {speed:.1f} samples/s\t' \
187 |                   'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
188 |                   'Loss_sup {loss_sup.val:.5f} 
({loss_sup.avg:.5f})\t' \ 189 | 'Loss_const {loss_const.val:.5f} ({loss_const.avg:.5f})\t' \ 190 | 'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'\ 191 | 'Accuracy_ema {acc_ema.val:.3f} ({acc_ema.avg:.3f})'.format( 192 | epoch, i, len(train_loader), batch_time=batch_time, 193 | speed=input.size(0)/batch_time.val, 194 | data_time=data_time, loss_sup=losses_sup, loss_const=losses_consistency, acc=acc, acc_ema=acc_ema) 195 | logger.info(msg) 196 | 197 | writer.add_scalar('train_loss', losses.val, global_steps) 198 | writer.add_scalar('loss_sup', losses_sup.val, global_steps) 199 | writer.add_scalar('loss_const', losses_consistency.val, global_steps) 200 | writer.add_scalar('train_acc', acc.val, global_steps) 201 | writer_dict['train_global_steps'] = global_steps + 1 202 | -------------------------------------------------------------------------------- /lib/dataset/JointsDataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import copy 12 | import logging 13 | import random 14 | 15 | import cv2 16 | import numpy as np 17 | import torch 18 | from torch.utils.data import Dataset 19 | 20 | from utils.transforms import get_affine_transform 21 | from utils.transforms import affine_transform 22 | from utils.transforms import fliplr_joints 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class JointsDataset(Dataset): 29 | def __init__(self, cfg, root, image_set, is_train, transform=None): 30 | self.num_joints = 0 31 | self.pixel_std = 200 32 | self.flip_pairs = [] 33 | self.parent_ids = [] 34 | 35 | self.is_train = is_train 36 | self.root = root 37 | self.image_set = image_set 38 | 39 | self.output_path = cfg.OUTPUT_DIR 40 | self.data_format = cfg.DATASET.DATA_FORMAT 41 | 42 | self.scale_factor = cfg.DATASET.SCALE_FACTOR 43 | self.rotation_factor = cfg.DATASET.ROT_FACTOR 44 | self.flip = cfg.DATASET.FLIP 45 | self.num_joints_half_body = cfg.DATASET.NUM_JOINTS_HALF_BODY 46 | self.prob_half_body = cfg.DATASET.PROB_HALF_BODY 47 | self.color_rgb = cfg.DATASET.COLOR_RGB 48 | 49 | self.target_type = cfg.MODEL.TARGET_TYPE 50 | self.image_size = np.array(cfg.MODEL.IMAGE_SIZE) 51 | self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE) 52 | self.sigma = cfg.MODEL.SIGMA 53 | self.use_different_joints_weight = cfg.LOSS.USE_DIFFERENT_JOINTS_WEIGHT 54 | self.joints_weight = 1 55 | 56 | self.transform = transform 57 | self.db = [] 58 | 59 | def _get_db(self): 60 | raise NotImplementedError 61 | 62 | def evaluate(self, cfg, preds, output_dir, *args, **kwargs): 63 | raise NotImplementedError 64 | 65 | def half_body_transform(self, joints, joints_vis): 66 | upper_joints = [] 67 | lower_joints = [] 68 | for joint_id in range(self.num_joints): 69 | if joints_vis[joint_id][0] > 0: 70 | if joint_id in self.upper_body_ids: 71 | upper_joints.append(joints[joint_id]) 72 | else: 73 | lower_joints.append(joints[joint_id]) 74 | 75 | if np.random.randn() < 0.5 and len(upper_joints) > 2: 76 | selected_joints = upper_joints 77 | else: 78 | selected_joints = lower_joints \ 79 | if len(lower_joints) > 2 else upper_joints 80 | 81 | if len(selected_joints) < 2: 82 | 
82 |             return None, None
83 | 
84 |         selected_joints = np.array(selected_joints, dtype=np.float32)
85 |         center = selected_joints.mean(axis=0)[:2]
86 | 
87 |         left_top = np.amin(selected_joints, axis=0)
88 |         right_bottom = np.amax(selected_joints, axis=0)
89 | 
90 |         w = right_bottom[0] - left_top[0]
91 |         h = right_bottom[1] - left_top[1]
92 | 
93 |         if w > self.aspect_ratio * h:
94 |             h = w * 1.0 / self.aspect_ratio
95 |         elif w < self.aspect_ratio * h:
96 |             w = h * self.aspect_ratio
97 | 
98 |         scale = np.array(
99 |             [
100 |                 w * 1.0 / self.pixel_std,
101 |                 h * 1.0 / self.pixel_std
102 |             ],
103 |             dtype=np.float32
104 |         )
105 | 
106 |         scale = scale * 1.5
107 | 
108 |         return center, scale
109 | 
110 |     def __len__(self):
111 |         return len(self.db)
112 | 
113 |     def __getitem__(self, idx):
114 |         db_rec = copy.deepcopy(self.db[idx])
115 | 
116 |         image_file = db_rec['image']
117 |         filename = db_rec['filename'] if 'filename' in db_rec else ''
118 |         imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
119 | 
120 |         if self.data_format == 'zip':
121 |             from utils import zipreader
122 |             data_numpy = zipreader.imread(
123 |                 image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
124 |             )
125 |         else:
126 |             data_numpy = cv2.imread(
127 |                 image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
128 |             )
129 | 
130 |         if data_numpy is None:
131 |             logger.error('=> failed to read {}'.format(image_file))
132 |             raise ValueError('Failed to read {}'.format(image_file))
133 | 
134 |         if self.color_rgb:
135 |             data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
136 | 
137 |         joints = db_rec['joints_3d']
138 |         joints_vis = db_rec['joints_3d_vis']
139 | 
140 |         c = db_rec['center']
141 |         s = db_rec['scale']
142 |         score = db_rec['score'] if 'score' in db_rec else 1
143 |         r = 0
144 | 
145 |         if self.is_train:
146 |             if (np.sum(joints_vis[:, 0]) > self.num_joints_half_body
147 |                     and np.random.rand() < self.prob_half_body):
148 |                 c_half_body, s_half_body = self.half_body_transform(
149 |                     joints, joints_vis
150 |                 )
151 | 
152 |                 if c_half_body is not None and s_half_body is not None:
153 |                     c, s = c_half_body, s_half_body
154 | 
155 |             sf = self.scale_factor
156 |             rf = self.rotation_factor
157 |             s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
158 |             r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
159 |                 if random.random() <= 0.6 else 0
160 | 
161 |             if self.flip and random.random() <= 0.5:
162 |                 data_numpy = data_numpy[:, ::-1, :]
163 |                 joints, joints_vis = fliplr_joints(
164 |                     joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
165 |                 c[0] = data_numpy.shape[1] - c[0] - 1
166 | 
167 |         trans = get_affine_transform(c, s, r, self.image_size)
168 |         input = cv2.warpAffine(
169 |             data_numpy,
170 |             trans,
171 |             (int(self.image_size[0]), int(self.image_size[1])),
172 |             flags=cv2.INTER_LINEAR)
173 | 
174 |         if self.transform:
175 |             input = self.transform(input)
176 | 
177 |         for i in range(self.num_joints):
178 |             if joints_vis[i, 0] > 0.0:
179 |                 joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
180 | 
181 |         target, target_weight = self.generate_target(joints, joints_vis)
182 | 
183 |         target = torch.from_numpy(target)
184 |         target_weight = torch.from_numpy(target_weight)
185 | 
186 |         meta = {
187 |             'image': image_file,
188 |             'filename': filename,
189 |             'imgnum': imgnum,
190 |             'joints': np.float32(joints),
191 |             'joints_vis': joints_vis,
192 |             'center': c,
193 |             'scale': s,
194 |             'rotation': r,
195 |             'score': score
196 |         }
197 | 
198 |         return input, target, target_weight, meta
199 | 
200 |     def select_data(self, db):
201 |         db_selected = []
202 |         for rec in db:
203 |             num_vis = 0
204 |             joints_x = 0.0
205 |             joints_y = 0.0
206 |             for joint, joint_vis in zip(
207 |                     rec['joints_3d'], rec['joints_3d_vis']):
208 |                 if joint_vis[0] <= 0:
209 |                     continue
210 |                 num_vis += 1
211 | 
212 |                 joints_x += joint[0]
213 |                 joints_y += joint[1]
214 |             if num_vis == 0:
215 |                 continue
216 | 
217 |             joints_x, joints_y = joints_x / num_vis, joints_y / num_vis
218 | 
219 |             area = rec['scale'][0] * rec['scale'][1] * (self.pixel_std**2)
220 |             joints_center = np.array([joints_x, joints_y])
221 |             bbox_center = np.array(rec['center'])
222 |             diff_norm2 = np.linalg.norm((joints_center-bbox_center), 2)
223 |             ks = np.exp(-1.0*(diff_norm2**2) / ((0.2)**2*2.0*area))
224 | 
225 |             metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
226 |             if ks > metric:
227 |                 db_selected.append(rec)
228 | 
229 |         logger.info('=> num db: {}'.format(len(db)))
230 |         logger.info('=> num selected db: {}'.format(len(db_selected)))
231 |         return db_selected
232 | 
233 |     def generate_target(self, joints, joints_vis):
234 |         '''
235 |         :param joints: [num_joints, 3]
236 |         :param joints_vis: [num_joints, 3]
237 |         :return: target, target_weight(1: visible, 0: invisible)
238 |         '''
239 |         target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
240 |         target_weight[:, 0] = joints_vis[:, 0]
241 | 
242 |         assert self.target_type == 'gaussian', \
243 |             'Only support gaussian map now!'
244 | 
245 |         if self.target_type == 'gaussian':
246 |             target = np.zeros((self.num_joints,
247 |                                self.heatmap_size[1],
248 |                                self.heatmap_size[0]),
249 |                               dtype=np.float32)
250 | 
251 |             tmp_size = self.sigma * 3
252 | 
253 |             for joint_id in range(self.num_joints):
254 |                 feat_stride = self.image_size / self.heatmap_size
255 |                 mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
256 |                 mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
257 |                 # Check that any part of the gaussian is in-bounds
258 |                 ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
259 |                 br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
260 |                 if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
261 |                         or br[0] < 0 or br[1] < 0:
262 |                     # If not, mark the joint as not visible and skip it
263 |                     target_weight[joint_id] = 0
264 |                     continue
265 | 
266 |                 # Generate gaussian
267 |                 size = 2 * tmp_size + 1
268 |                 x = np.arange(0, size, 1, np.float32)
269 |                 y = x[:, np.newaxis]
270 |                 x0 = y0 = size // 2
271 |                 # The gaussian is not normalized, we want the center value to equal 1
272 |                 g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
273 | 
274 |                 # Usable gaussian range
275 |                 g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
276 |                 g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
277 |                 # Image range
278 |                 img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
279 |                 img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
280 | 
281 |                 v = target_weight[joint_id]
282 |                 if v > 0.5:
283 |                     target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
284 |                         g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
285 | 
286 |         if self.use_different_joints_weight:
287 |             target_weight = np.multiply(target_weight, self.joints_weight)
288 | 
289 |         return target, target_weight
--------------------------------------------------------------------------------
/tools/train_mt_part.py:
--------------------------------------------------------------------------------
1 | 
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | 
6 | import argparse
7 | import os
8 | import pprint
9 | import shutil
10 | import numpy as np
11 | 
12 | import torch
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torchvision.transforms as transforms
18 | from tensorboardX import SummaryWriter
19 | 
20 | import _init_paths
21 | from config import cfg
22 | from config import update_config
23 | from core.loss import JointsMSELoss
24 | from core.function import validate_mt, AverageMeter, validate
25 | from core.function1 import train_mt_update
26 | from utils.utils import get_optimizer
27 | from utils.utils import save_checkpoint
28 | from utils.utils import create_logger
29 | from utils.utils import get_model_summary
30 | from utils.augmentation_pool import RandAugmentMC
31 | from utils.consistency import prediction_check
32 | from core.evaluate import accuracy
33 | import models
34 | import dataset_animal
35 | 
36 | 
37 | def parse_args():
38 |     parser = argparse.ArgumentParser(description='Train keypoints network')
39 |     # general
40 |     parser.add_argument('--cfg',
41 |                         help='experiment configure file name',
42 |                         required=True,
43 |                         type=str)
44 | 
45 |     parser.add_argument('opts',
46 |                         help="Modify config options using the command-line",
47 |                         default=None,
48 |                         nargs=argparse.REMAINDER)
49 | 
50 |     # philly
51 |     parser.add_argument('--modelDir',
52 |                         help='model directory',
53 |                         type=str,
54 |                         default='')
55 |     parser.add_argument('--logDir',
56 |                         help='log directory',
57 |                         type=str,
58 |                         default='')
59 |     parser.add_argument('--dataDir',
60 |                         help='data directory',
61 |                         type=str,
62 |                         default='')
63 |     parser.add_argument('--prevModelDir',
64 |                         help='prev model directory',
65 |                         type=str,
66 |                         default='')
67 | 
68 |     parser.add_argument('--animalpose',
69 |                         help='train on ap10k',
70 |                         action='store_true')
71 | 
72 |     parser.add_argument('--fewshot',
73 |                         help='train on ap10k with few-shot annotations',
74 |                         action='store_true')
75 | 
76 |     parser.add_argument('--num_transforms',
77 |                         help='number of transformations used for generating pseudo labels',
78 |                         type=int,
79 |                         default=5)
80 |     parser.add_argument('--generate_pseudol',
81 |                         help='set true to generate pseudo labels',
82 |                         action='store_true')
83 | 
84 |     parser.add_argument('--pretrained',
85 |                         help='path to the pretrained model',
86 |                         type=str,
87 |                         default='')
88 |     parser.add_argument('--resume', help='path to the checkpoint to resume from', type=str, default='')
89 | 
90 |     parser.add_argument('--const_weight', type=float, default=2.0)
91 |     parser.add_argument('--consistency_rampup', type=int, default=10)
92 | 
93 |     parser.add_argument('--dist_thre', type=float, default=0.6)
94 |     parser.add_argument('--update_epoch', type=int, default=-1)
95 | 
96 |     parser.add_argument('--topkrampdown', type=int, default=30)
97 |     parser.add_argument('--minrate', type=float, default=0.8)
98 | 
99 |     parser.add_argument('--score_thre', type=float, default=0.5)
100 | 
101 |     parser.add_argument('--sf_aggre', type=float, default=0.1)
102 |     parser.add_argument('--rf_aggre', type=float, default=20)
103 |     parser.add_argument('--length', type=int, default=32)
104 |     parser.add_argument('--nholes', type=int, default=6)
105 |     parser.add_argument('--augment', action='store_true')
106 |     parser.add_argument('--std_gaussian', type=float, default=0.2)
107 | 
108 |     parser.add_argument('--epochs', type=int, default=150)
109 |     parser.add_argument('--true_label_w', type=float, default=2.0)
110 | 
111 |     parser.add_argument('--lr_factor', type=float, default=0.1)
112 |     parser.add_argument('--schedule', type=int, nargs='+', default=[190, 200],
113 |                         help='Decrease learning rate at these epochs.')
114 | 
115 |     parser.add_argument('--percentage', type=float, default=0.4)
116 | 
117 |     parser.add_argument('--m', type=int, default=10)
118 |     parser.add_argument('--n', type=int, default=2)
119 | 
120 |     parser.add_argument('--evaluate', action='store_true')
121 |     parser.add_argument('--few_shot_setting', action='store_false')
122 |     args = parser.parse_args()
123 | 
124 |     return args
125 | 
126 | 
127 | def main():
128 | 
129 |     args = parse_args()
130 |     update_config(cfg, args)
131 | 
132 |     logger, final_output_dir, tb_log_dir = create_logger(
133 |         cfg, args.cfg, 'train')
134 | 
135 |     logger.info(pprint.pformat(args))
136 |     logger.info(cfg)
137 | 
138 |     # cudnn related setting
139 |     cudnn.benchmark = cfg.CUDNN.BENCHMARK
140 |     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
141 |     torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
142 | 
143 |     model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
144 |         cfg, is_train=True
145 |     )
146 |     model_ema = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
147 |         cfg, is_train=True
148 |     )
149 |     for param in model_ema.parameters():
150 |         param.detach_()  # detach in place so the EMA teacher receives no gradients
151 | 
152 |     # copy model file
153 |     this_dir = os.path.dirname(__file__)
154 |     shutil.copy2(
155 |         os.path.join(this_dir, '../lib/core/function1.py'),
156 |         final_output_dir)
157 | 
158 |     writer_dict = {
159 |         'writer': SummaryWriter(log_dir=tb_log_dir),
160 |         'train_global_steps': 0,
161 |         'valid_global_steps': 0,
162 |     }
163 | 
164 |     dump_input = torch.rand(
165 |         (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0])
166 |     )
167 | 
168 |     logger.info(get_model_summary(model, dump_input))
169 | 
170 |     # define loss function (criterion) and optimizer
171 |     criterion = JointsMSELoss(
172 |         use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
173 |     ).cuda()
174 | 
175 |     # Data loading code
176 |     normalize = transforms.Normalize(
177 |         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
178 |     )
179 | 
180 |     if args.animalpose:
181 |         if args.augment:
182 |             logger.info("Add strong augmentations to student")
183 |             transfm_stu = transforms.Compose([RandAugmentMC(n=args.n, m=args.m, num_cutout=args.nholes),
184 |                                               transforms.ToTensor(),
185 |                                               normalize])
186 |             transfm_tea = transforms.Compose([
187 |                 transforms.ToTensor(),
188 |                 normalize
189 |             ])
190 |         else:
191 |             logger.info("Without strong augmentations to student")
192 |             transfm_stu = transforms.Compose([
193 |                 transforms.ToTensor(),
194 |                 normalize
195 |             ])
196 | 
197 |             transfm_tea = transforms.Compose([
198 |                 transforms.ToTensor(),
199 |                 normalize
200 |             ])
201 | 
202 |         train_dataset = eval('dataset_animal.' + 'ap10k')(
203 |             cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
204 |             transforms.Compose([
205 |                 transforms.ToTensor(),
206 |                 normalize,
207 |             ])
208 |         )
209 |         valid_dataset = eval('dataset_animal.' + 'ap10k')(
210 |             cfg, cfg.DATASET.ROOT, cfg.DATASET.VAL_SET, False,
211 |             transforms.Compose([
212 |                 transforms.ToTensor(),
213 |                 normalize,
214 |             ])
215 |         )
216 |     else:
217 |         train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
218 |             cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
219 |             transforms.Compose([
220 |                 transforms.ToTensor(),
221 |                 normalize,
222 |             ])
223 |         )
224 |         valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
225 |             cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
226 |             transforms.Compose([
227 |                 transforms.ToTensor(),
228 |                 normalize,
229 |             ])
230 |         )
231 | 
232 |     train_loader = torch.utils.data.DataLoader(
233 |         train_dataset,
234 |         batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
235 |         shuffle=cfg.TRAIN.SHUFFLE,
236 |         num_workers=cfg.WORKERS,
237 |         pin_memory=cfg.PIN_MEMORY,
238 |         drop_last=True
239 |     )
240 |     valid_loader = torch.utils.data.DataLoader(
241 |         valid_dataset,
242 |         batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
243 |         shuffle=False,
244 |         num_workers=cfg.WORKERS,
245 |         pin_memory=cfg.PIN_MEMORY
246 |     )
247 | 
248 |     best_perf = 0.0
249 |     last_epoch = -1
250 |     best_perf_epoch = 0
251 |     optimizer = get_optimizer(cfg, model)
252 | 
253 |     begin_epoch = cfg.TRAIN.BEGIN_EPOCH
254 |     checkpoint_file = os.path.join(
255 |         final_output_dir, 'checkpoint.pth'
256 |     )
257 | 
258 |     # load pretrained model
259 |     if os.path.exists(args.pretrained):
260 |         logger.info("=> loading checkpoint '{}'".format(args.pretrained))
261 |         pretrained_model = torch.load(args.pretrained)
262 |         model.load_state_dict(pretrained_model['state_dict'])
263 | 
264 |     # resume from a previous checkpoint
265 |     if os.path.exists(args.resume):
266 |         logger.info("=> resume from checkpoint '{}'".format(args.resume))
267 |         checkpoint = torch.load(args.resume)
268 |         begin_epoch = checkpoint['epoch']
269 |         best_perf = checkpoint['perf']
270 |         last_epoch = checkpoint['epoch']
271 |         model.load_state_dict(checkpoint['state_dict'])
272 |         optimizer.load_state_dict(checkpoint['optimizer'])
273 |         logger.info("=> loaded checkpoint '{}' (epoch {})".format(
274 |             args.resume, checkpoint['epoch']))
275 | 
276 |     model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
277 |     model_ema = torch.nn.DataParallel(model_ema, device_ids=cfg.GPUS).cuda()
278 | 
279 | 
280 |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
281 |         optimizer, args.schedule, args.lr_factor,
282 |         last_epoch=last_epoch
283 |     )
284 | 
285 |     if args.evaluate:
286 |         validate(cfg, valid_loader, valid_dataset, model, criterion,
287 |                  final_output_dir, tb_log_dir, writer_dict, args.animalpose)
288 |         return
289 | 
290 |     for epoch in range(begin_epoch, args.epochs):
291 |         if epoch == begin_epoch:
292 |             if args.generate_pseudol:
293 |                 model.eval()
294 |                 train_dataset = eval('dataset_animal.ap10k')(
295 |                     cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, False,
296 |                     # transforms.Compose([
297 |                     #     transforms.ToTensor(),
298 |                     #     normalize,
299 |                     # ])
300 |                 )
301 |                 train_loader = torch.utils.data.DataLoader(
302 |                     train_dataset,
303 |                     batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
304 |                     shuffle=False,
305 |                     num_workers=cfg.WORKERS,
306 |                     pin_memory=cfg.PIN_MEMORY
307 |                 )
308 |                 pseudo_kpts = {}
309 |                 acc_pseudol = AverageMeter()
310 |                 for _, (input, target, target_weight, meta) in enumerate(train_loader):
311 |                     for i in range(input.size(0)):
312 |                         c = meta['center'].numpy()
313 |                         s = meta['scale'].numpy()
314 |                         generated_kpts, score_map = prediction_check(input[i], model, train_dataset, c[i:i+1], s[i:i+1],
315 |                                                                      args.num_transforms)
316 |                         pseudo_kptsts = None  # noqa: E501 (placeholder removed below)
316 |                         pseudo_kpts[int(meta['index'][i].numpy())] = generated_kpts
317 |                         _, avg_acc, cnt, pred = accuracy(score_map,
318 |                                                          target[i].unsqueeze(0).numpy())
319 |                         acc_pseudol.update(avg_acc, cnt)
320 |                 print("Acc on the training dataset (pseudo-labels): {}".format(acc_pseudol.avg))
321 |                 np.save('data/pseudo_labels/{}shots/pseudo_labels_train.npy'.format(cfg.LABEL_PER_CLASS),
322 |                         pseudo_kpts)
323 |                 break
324 | 
325 |             train_dataset = eval('dataset_animal.' + cfg.DATASET.DATASET)(
326 |                 cfg, args, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
327 |                 transfm_stu, transfm_tea
328 |             )
329 | 
330 |             train_loader = torch.utils.data.DataLoader(
331 |                 train_dataset,
332 |                 batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
333 |                 shuffle=cfg.TRAIN.SHUFFLE,
334 |                 num_workers=cfg.WORKERS,
335 |                 pin_memory=cfg.PIN_MEMORY,
336 |                 drop_last=True
337 |             )
338 | 
339 |         lr_scheduler.step()
340 |         train_mt_update(cfg, args, train_loader, train_dataset, model, model_ema, criterion, optimizer, epoch,
341 |                         final_output_dir, tb_log_dir, writer_dict)
342 | 
343 |         # evaluate on validation set
344 |         perf_indicator = validate_mt(
345 |             cfg, valid_loader, valid_dataset, model, model_ema, criterion,
346 |             final_output_dir, tb_log_dir, writer_dict, args.animalpose)
347 | 
348 |         if perf_indicator >= best_perf:
349 |             best_perf = perf_indicator
350 |             best_model = True
351 |             best_perf_epoch = epoch + 1
352 |         else:
353 |             best_model = False
354 |         # save model
355 |         logger.info('=> saving checkpoint to {}'.format(final_output_dir))
356 |         save_checkpoint({
357 |             'epoch': epoch + 1,
358 |             'model': cfg.MODEL.NAME,
359 |             'state_dict': model_ema.module.state_dict(),
360 |             'perf': perf_indicator,
361 |             'optimizer': optimizer.state_dict(),
362 |         }, best_model, final_output_dir)
363 | 
364 |     final_model_state_file = os.path.join(
365 |         final_output_dir, 'final_state.pth'
366 |     )
367 |     logger.info('=> saving final model state to {}'.format(
368 |         final_model_state_file)
369 |     )
370 |     logger.info('Best accuracy {} at epoch {}'.format(best_perf, best_perf_epoch))
371 |     torch.save(model.module.state_dict(), final_model_state_file)
372 |     writer_dict['writer'].close()
373 | 
374 | 
375 | if __name__ == '__main__':
376 |     main()
377 | 
--------------------------------------------------------------------------------
/lib/dataset_animal/ap10k_test_category.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import warnings
4 | from collections import OrderedDict, defaultdict
5 | import logging
6 | import sys
7 | sys.path.insert(0, './lib')
8 | 
9 | import cv2
10 | import json_tricks as json
11 | import numpy as np
12 | from xtcocotools.cocoeval import COCOeval
13 | from xtcocotools.coco import COCO
14 | from nms.nms import oks_nms
15 | from nms.nms import soft_oks_nms
16 | 
17 | from dataset_animal.kpt_2d_base import Kpt2dSviewRgbImgTopDownDataset
18 | from dataset_animal.dataset_info import DatasetInfo
19 | from dataset_animal import ap10k_info
20 | 
21 | 
22 | data_cfg = dict(
23 |     image_size=[256, 256],
24 |     heatmap_size=[64, 64],
25 |     soft_nms=False,
26 |     nms_thr=1.0,
27 |     oks_thr=0.9,
28 |     vis_thr=0.2,
29 |     use_gt_bbox=True,
30 |     det_bbox_thr=0.0,
31 |     bbox_file='',
32 | )
33 | logger = logging.getLogger(__name__)
34 | 
35 | ap10k2coco = np.array([2, 0, 1, 3, 4, 5, 8, 6, 9, 7, 10, 11, 14, 12, 15, 13, 16])
36 | class AnimalAP10KDataset(Kpt2dSviewRgbImgTopDownDataset):
37 |     """AP-10K dataset for animal pose estimation.
38 | 
39 |     `AP-10K: A Benchmark for Animal Pose Estimation in the Wild',
40 |     NeurIPS Dataset Track, 2021.
41 |     More details can be found in the AP-10K paper.
42 | 
43 |     The dataset loads raw features and applies the specified transforms
44 |     to return a dict containing the image tensors and other information.
45 | 
46 |     AP-10K keypoint indexes::
47 | 
48 |         0: 'L_Eye',
49 |         1: 'R_Eye',
50 |         2: 'Nose',
51 |         3: 'Neck',
52 |         4: 'root of tail',
53 |         5: 'L_Shoulder',
54 |         6: 'L_Elbow',
55 |         7: 'L_F_Paw',
56 |         8: 'R_Shoulder',
57 |         9: 'R_Elbow',
58 |         10: 'R_F_Paw',
59 |         11: 'L_Hip',
60 |         12: 'L_Knee',
61 |         13: 'L_B_Paw',
62 |         14: 'R_Hip',
63 |         15: 'R_Knee',
64 |         16: 'R_B_Paw'
65 | 
66 |     Args:
67 |         ann_file (str): Path to the annotation file.
68 |         img_prefix (str): Path to a directory where images are held.
69 |             Default: None.
70 |         data_cfg (dict): config
71 |         pipeline (list[dict | callable]): A sequence of data transforms.
72 |         dataset_info (DatasetInfo): A class containing all dataset info.
73 |         test_mode (bool): Store True when building test or
74 |             validation dataset. Default: False.
75 |     """
76 | 
77 |     def __init__(self,
78 |                  cfg,
79 |                  root,
80 |                  image_set,
81 |                  is_train,
82 |                  transform=None
83 |                  ):
84 | 
85 |         super().__init__(
86 |             cfg,
87 |             root,
88 |             image_set,
89 |             is_train,
90 |             transform
91 |         )
92 | 
93 |         self.nms_thr = data_cfg['nms_thr']
94 |         self.image_thre = data_cfg.get('det_bbox_thr', 0.0)
95 |         self.soft_nms = data_cfg['soft_nms']
96 |         self.oks_thre = data_cfg['oks_thr']
97 |         self.in_vis_thre = data_cfg['vis_thr']
98 |         self.bbox_file = data_cfg['bbox_file']
99 |         self.use_gt_bbox = data_cfg['use_gt_bbox']
100 | 
101 |         self.use_nms = data_cfg.get('use_nms', True)
102 | 
103 |         self.image_width = data_cfg['image_size'][0]
104 |         self.image_height = data_cfg['image_size'][1]
105 |         self.pixel_std = 200
106 | 
107 |         self.root = root
108 |         self.select_data = cfg.DATASET.SELECT_DATA
109 | 
110 |         assert image_set == 'val'
111 |         logger.info('Loading validation annotations')
112 |         ann_file = self.root + 'annotations/ap10k-val-split1.json'
113 | 
114 |         self.coco = COCO(ann_file)
115 |         if 'categories' in self.coco.dataset:
116 |             cats = [
117 |                 cat['name']
118 |                 for cat in self.coco.loadCats(self.coco.getCatIds())
119 |             ]
120 |             self.classes = ['__background__'] + cats
121 |             self.num_classes = len(self.classes)
122 |             self._class_to_ind = dict(
123 |                 zip(self.classes, range(self.num_classes)))
124 |             self._class_to_coco_ind = dict(
125 |                 zip(cats, self.coco.getCatIds()))
126 |             self._coco_ind_to_class_ind = dict(
127 |                 (self._class_to_coco_ind[cls], self._class_to_ind[cls])
128 |                 for cls in self.classes[1:])
129 | 
130 |         if self.select_data:
131 |             catids = []
132 |             self.img_ids = []
133 |             catids.extend(self.coco.getCatIds(catNms=cfg.DATASET.SUPERCATEGORY))
134 |             for catid in catids:
135 |                 self.img_ids.extend(self.coco.catToImgs[catid])
136 |             self.img_ids = list(set(self.img_ids))
137 |         else:
138 |             self.img_ids = self.coco.getImgIds()
139 |         self.num_images = len(self.img_ids)
140 |         print('=> num_images: {}'.format(self.num_images))
141 | 
142 |         self.id2name, self.name2id = self._get_mapping_id_name(
143 |             self.coco.imgs)
144 |         dataset_info = ap10k_info.dataset_info
145 |         dataset_info = DatasetInfo(dataset_info)
146 |         self.num_joints = dataset_info.keypoint_num
147 |         self.flip_pairs = dataset_info.flip_pairs
148 |         self.parent_ids = None
149 |         self.upper_body_ids = dataset_info.upper_body_ids
150 |         self.lower_body_ids = dataset_info.lower_body_ids
151 |         self.joints_weight = np.array(dataset_info.joint_weights,
152 |                                       dtype=np.float32).reshape((self.num_joints, 1))
153 | 
154 |         self.sigmas = dataset_info.sigmas
155 |         self.few_shot_setting = False
156 |         self.db, self.id2Cat = self._get_db()
157 |         logger.info('=> load {} samples'.format(len(self.db)))
158 | 
159 |     def _get_db(self):
160 |         """Load dataset."""
161 |         assert self.use_gt_bbox
162 |         gt_db, id2Cat = self._load_coco_keypoint_annotations()
163 |         return gt_db, id2Cat
164 | 
165 |     def _load_coco_keypoint_annotations(self):
166 |         """Ground truth bbox and keypoints."""
167 |         gt_db, id2Cat = [], dict()
168 |         for img_id in self.img_ids:
169 |             db_tmp, id2Cat_tmp = self._load_coco_keypoint_annotation_kernel(
170 |                 img_id)
171 |             gt_db.extend(db_tmp)
172 |             id2Cat.update({img_id: id2Cat_tmp})
173 |         return gt_db, id2Cat
174 | 
175 |     def _supercat2ids(self):
176 |         self.supercat2ids = dict()
177 |         for k in self.coco.cats.keys():
178 |             supercategory = self.coco.cats[k]['supercategory']
179 |             id = self.coco.cats[k]['id']
180 |             if supercategory in self.supercat2ids.keys():
181 |                 self.supercat2ids[supercategory].append(id)
182 |             else:
183 |                 self.supercat2ids[supercategory] = [id]
184 | 
185 | 
186 |     def _load_coco_keypoint_annotation_kernel(self, img_id):
187 |         """load annotation from COCOAPI.
188 | 
189 |         Note:
190 |             bbox:[x1, y1, w, h]
191 |         Args:
192 |             img_id: coco image id
193 |         Returns:
194 |             dict: db entry
195 |         """
196 |         img_ann = self.coco.loadImgs(img_id)[0]
197 |         width = img_ann['width']
198 |         height = img_ann['height']
199 |         num_joints = self.num_joints
200 | 
201 |         ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
202 |         objs = self.coco.loadAnns(ann_ids)
203 | 
204 |         # sanitize bboxes
205 |         valid_objs = []
206 |         for obj in objs:
207 |             if 'bbox' not in obj:
208 |                 continue
209 |             x, y, w, h = obj['bbox']
210 |             x1 = max(0, x)
211 |             y1 = max(0, y)
212 |             x2 = min(width - 1, x1 + max(0, w - 1))
213 |             y2 = min(height - 1, y1 + max(0, h - 1))
214 |             if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
215 |                 obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
216 |                 valid_objs.append(obj)
217 |         objs = valid_objs
218 | 
219 |         bbox_id = 0
220 |         rec = []
221 |         id2Cat = []
222 |         for obj in objs:
223 |             if 'keypoints' not in obj:
224 |                 continue
225 |             if max(obj['keypoints']) == 0:
226 |                 continue
227 |             if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
228 |                 continue
229 |             joints_3d = np.zeros((num_joints, 3), dtype=np.float32)
230 |             joints_3d_visible = np.zeros((num_joints, 3), dtype=np.float32)
231 | 
232 |             keypoints = np.array(obj['keypoints']).reshape(-1, 3)
233 |             # keypoints = keypoints[ap10k2coco]
234 |             joints_3d[:, :2] = keypoints[:, :2]
235 |             joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
236 | 
237 |             center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
238 | 
239 |             image_file = os.path.join(self.root, 'data', self.id2name[img_id])
240 |             rec.append({
241 |                 'image_file': image_file,
242 |                 'center': center,
243 |                 'scale': scale,
244 |                 'bbox': obj['clean_bbox'][:4],
245 |                 'joints_3d': joints_3d,
246 |                 'joints_3d_visible': joints_3d_visible,
247 |                 'bbox_score': 1,
248 |                 'bbox_id': bbox_id
249 |             })
250 |             category = obj['category_id']
251 |             id2Cat.append({
252 |                 'image_file': image_file,
253 |                 'bbox_id': bbox_id,
254 |                 'category': category,
255 |             })
256 |             bbox_id = bbox_id + 1
257 | 
258 |         return rec, id2Cat
259 | 
260 |     def evaluate(self, cfg, preds, output_dir, all_boxes, img_path,
261 |                  *args, **kwargs):
262 |         rank = cfg.RANK
263 | 
264 |         res_folder = os.path.join(output_dir, 'results')
265 |         if not os.path.exists(res_folder):
266 |             try:
267 |                 os.makedirs(res_folder)
268 |             except Exception:
269 |                 logger.error('Failed to make {}'.format(res_folder))
270 | 
271 |         res_file = os.path.join(
272 |             res_folder, 'keypoints_{}_results_{}.json'.format(
273 |                 self.image_set, rank)
274 |         )
275 | 
276 |         # person x (keypoints)
277 |         _kpts = []
278 |         for idx, kpt in enumerate(preds):
279 |             image_name = img_path[idx][-16:]
280 |             image_id = self.name2id[image_name]
281 |             bbox_id = int(all_boxes[idx][6])
282 |             cat = self.id2Cat[image_id][bbox_id]['category']
283 |             _kpts.append({
284 |                 'keypoints': kpt,
285 |                 'center': all_boxes[idx][0:2],
286 |                 'scale': all_boxes[idx][2:4],
287 |                 'area': all_boxes[idx][4],
288 |                 'score': all_boxes[idx][5],
289 |                 'image': int(img_path[idx][-16:-4]),
290 |                 'bbox_id': bbox_id,
291 |                 'category': cat
292 |             })
293 |         # image x person x (keypoints)
294 |         kpts = defaultdict(list)
295 |         for kpt in _kpts:
296 |             kpts[kpt['image']].append(kpt)
297 |         kpts = self._sort_and_unique_bboxes(kpts)
298 |         # rescoring and oks nms
299 |         num_joints = self.num_joints
300 |         in_vis_thre = self.in_vis_thre
301 |         oks_thre = self.oks_thre
302 |         oks_nmsed_kpts = []
303 |         for img in kpts.keys():
304 |             img_kpts = kpts[img]
305 |             for n_p in img_kpts:
306 |                 box_score = n_p['score']
307 |                 kpt_score = 0
308 |                 valid_num = 0
309 |                 for n_jt in range(0, num_joints):
310 |                     t_s = n_p['keypoints'][n_jt][2]
311 |                     if t_s > in_vis_thre:
312 |                         kpt_score = kpt_score + t_s
313 |                         valid_num = valid_num + 1
314 |                 if valid_num != 0:
315 |                     kpt_score = kpt_score / valid_num
316 |                 # rescoring
317 |                 n_p['score'] = kpt_score * box_score
318 | 
319 |             if self.soft_nms:
320 |                 keep = soft_oks_nms(
321 |                     [img_kpts[i] for i in range(len(img_kpts))],
322 |                     oks_thre
323 |                 )
324 |             else:
325 |                 keep = oks_nms(
326 |                     [img_kpts[i] for i in range(len(img_kpts))],
327 |                     oks_thre, self.sigmas
328 |                 )
329 | 
330 |             if len(keep) == 0:
331 |                 oks_nmsed_kpts.append(img_kpts)
332 |             else:
333 |                 oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])
334 | 
335 |         self._write_coco_keypoint_results(
336 |             oks_nmsed_kpts, res_file)
337 | 
338 |         info_str = self._do_python_keypoint_eval(
339 |             res_file, res_folder)
340 |         name_value = OrderedDict(info_str)
341 |         return name_value, name_value['AP']
342 | 
343 |     def _write_coco_keypoint_results(self, keypoints, res_file):
344 |         data_pack = [
345 |             {
346 |                 'cat_id': self._class_to_coco_ind[cls],
347 |                 'cls_ind': cls_ind,
348 |                 'cls': cls,
349 |                 'ann_type': 'keypoints',
350 |                 'keypoints': keypoints
351 |             }
352 |             for cls_ind, cls in enumerate(self.classes) if not cls == '__background__'
353 |         ]
354 | 
355 |         results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
356 |         logger.info('=> writing results json to %s' % res_file)
357 |         with open(res_file, 'w') as f:
358 |             json.dump(results, f, sort_keys=True, indent=4)
359 |         try:
360 |             json.load(open(res_file))
361 |         except Exception:
362 |             content = []
363 |             with open(res_file, 'r') as f:
364 |                 for line in f:
365 |                     content.append(line)
366 |             content[-1] = ']'
367 |             with open(res_file, 'w') as f:
368 |                 for c in content:
369 |                     f.write(c)
370 | 
371 |     def _coco_keypoint_results_one_category_kernel(self, data_pack):
372 |         keypoints = data_pack['keypoints']
373 |         cat_results = []
374 | 
375 |         for img_kpts in keypoints:
376 |             if len(img_kpts) == 0:
377 |                 continue
378 | 
379 |             _key_points = np.array([img_kpts[k]['keypoints']
380 |                                     for k in range(len(img_kpts))])
381 |             key_points = np.zeros(
382 |                 (_key_points.shape[0], self.num_joints * 3), dtype=np.float64
383 |             )
384 | 
385 |             for ipt in range(self.num_joints):
386 |                 key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
387 |                 key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
388 |                 key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]  # keypoint score
389 | 
390 |             result = [
391 |                 {
392 |                     'image_id': img_kpts[k]['image'],
393 |                     'category_id': img_kpts[k]['category'],
394 |                     'keypoints': list(key_points[k]),
395 |                     'score': img_kpts[k]['score'],
396 |                     'center': list(img_kpts[k]['center']),
397 |                     'scale': list(img_kpts[k]['scale'])
398 |                 }
399 |                 for k in range(len(img_kpts))
400 |             ]
401 |             cat_results.extend(result)
402 | 
403 |         return cat_results
404 | 
405 |     def _do_python_keypoint_eval(self, res_file, res_folder):
406 |         coco_dt = self.coco.loadRes(res_file)
407 |         coco_eval = COCOeval(self.coco, coco_dt, 'keypoints', self.sigmas)
408 |         coco_eval.params.useSegm = None
409 |         if self.select_data:
410 |             coco_eval.params.imgIds = self.img_ids
411 |         coco_eval.evaluate()
412 |         coco_eval.accumulate()
413 |         coco_eval.summarize()
414 | 
415 |         stats_names = ['AP', 'Ap .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
416 | 
417 |         info_str = []
418 |         for ind, name in enumerate(stats_names):
419 |             info_str.append((name, coco_eval.stats[ind]))
420 | 
421 |         return info_str
422 | 
423 |     def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
424 |         """Sort kpts and remove the repeated ones."""
425 |         for img_id, persons in kpts.items():
426 |             num = len(persons)
427 |             kpts[img_id] = sorted(kpts[img_id], key=lambda x: x[key])
428 |             for i in range(num - 1, 0, -1):
429 |                 if kpts[img_id][i][key] == kpts[img_id][i - 1][key]:
430 |                     del kpts[img_id][i]
431 | 
432 |         return kpts
--------------------------------------------------------------------------------
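
train_mt_update in lib/core/function1.py leans on two mean-teacher helpers whose definitions are not part of this listing: update_ema_variables and get_current_consistency_weight. A minimal sketch of the standard mean-teacher formulation (Tarvainen and Valpola, 2017) that is consistent with the call sites above; the repository's actual implementations may differ in detail:

import math

import torch


def update_ema_variables(model, ema_model, alpha, global_step):
    # Ramp the decay up from 0 toward `alpha` so the teacher tracks the
    # student closely during the first steps, then averages slowly.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param, alpha=1 - alpha)


def get_current_consistency_weight(weight, epoch, rampup_length):
    # Sigmoid-shaped ramp-up: ~0 at epoch 0, `weight` once `rampup_length`
    # epochs have passed, so the consistency term does not dominate early on.
    if rampup_length == 0:
        return weight
    current = min(max(epoch, 0.0), rampup_length)
    phase = 1.0 - current / rampup_length
    return weight * math.exp(-5.0 * phase * phase)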
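
The Gaussian targets produced by JointsDataset.generate_target are unnormalized, so the heatmap cell nearest each joint peaks at exactly 1. A small self-contained check of that logic; the sizes below are illustrative, not taken from the experiment configs:

import numpy as np

image_size = np.array([192, 256])   # (w, h), illustrative
heatmap_size = np.array([48, 64])   # (w, h), 4x downsampled
sigma = 2

joint = np.array([96.0, 128.0])     # image-space (x, y)
stride = image_size / heatmap_size
mu_x, mu_y = (joint / stride + 0.5).astype(int)

xs = np.arange(heatmap_size[0])
ys = np.arange(heatmap_size[1])[:, None]
hm = np.exp(-((xs - mu_x) ** 2 + (ys - mu_y) ** 2) / (2 * sigma ** 2))

assert hm.max() == 1.0                                        # unnormalized peak
assert np.unravel_index(hm.argmax(), hm.shape) == (mu_y, mu_x)  # peak at the joint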
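
AnimalAP10KDataset calls self._xywh2cs, inherited from Kpt2dSviewRgbImgTopDownDataset in kpt_2d_base.py and not shown in this listing. A sketch of the usual top-down conversion from a COCO-style box to the (center, scale) pair, assuming pixel_std = 200 as in the classes above and the common 1.25 padding factor (the exact padding used by the repo is not confirmed here):

import numpy as np

def xywh2cs(x, y, w, h, aspect_ratio, pixel_std=200.0, padding=1.25):
    # Box center in pixel coordinates.
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
    # Pad the box to the network input aspect ratio (image_width / image_height).
    if w > aspect_ratio * h:
        h = w / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    # Scale is the box size in units of pixel_std, slightly enlarged.
    scale = np.array([w / pixel_std, h / pixel_std], dtype=np.float32) * padding
    return center, scale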
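
The rescoring loop inside evaluate() condenses to a single rule: the final detection score is the box score times the mean confidence of the keypoints above in_vis_thre (0.2 here), and zero when no keypoint clears the threshold. As a standalone function:

def rescore(box_score, kpt_confidences, vis_thr=0.2):
    # Average only the keypoint confidences that pass the visibility threshold.
    valid = [c for c in kpt_confidences if c > vis_thr]
    kpt_score = sum(valid) / len(valid) if valid else 0.0
    return kpt_score * box_score

# Example: two of three keypoints pass, mean 0.7, so 0.9 * 0.7 = 0.63.
assert abs(rescore(0.9, [0.8, 0.1, 0.6]) - 0.63) < 1e-9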