├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assests ├── cctv.png ├── infer_results.jpg └── test.jpg ├── infer.py ├── pose ├── __init__.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ └── hrnet.py │ ├── posehrnet.py │ └── simdr.py └── utils │ ├── __init__.py │ ├── boxes.py │ ├── decode.py │ └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | *.cfg 23 | !cfg/yolov3*.cfg 24 | 25 | test.ipynb 26 | test.py 27 | 28 | assets/ 29 | 30 | storage.googleapis.com 31 | test_imgs/ 32 | runs/* 33 | data/* 34 | !data/images/zidane.jpg 35 | !data/images/bus.jpg 36 | !data/coco.names 37 | !data/coco_paper.names 38 | !data/coco.data 39 | !data/coco_*.data 40 | !data/coco_*.txt 41 | !data/trainvalno5k.shapes 42 | !data/*.sh 43 | 44 | pycocotools/* 45 | results*.txt 46 | gcp_test*.sh 47 | 48 | checkpoints/ 49 | output/ 50 | 51 | # Datasets ------------------------------------------------------------------------------------------------------------- 52 | coco/ 53 | coco128/ 54 | VOC/ 55 | 56 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 57 | *.m~ 58 | *.mat 59 | !targets*.mat 60 | 61 | # Neural Network weights ----------------------------------------------------------------------------------------------- 62 | *.weights 63 | *.pt 64 | *.onnx 65 | *.mlmodel 66 | *.torchscript 67 | darknet53.conv.74 68 | yolov3-tiny.conv.15 69 | 70 | # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- 71 | # Byte-compiled / optimized / DLL files 72 | __pycache__/ 73 | *.py[cod] 74 | *$py.class 75 | 76 | # C extensions 77 | *.so 78 | 79 | # Distribution / packaging 80 | .Python 81 | env/ 82 | build/ 83 | develop-eggs/ 84 | dist/ 85 | downloads/ 86 | eggs/ 87 | .eggs/ 88 | lib/ 89 | lib64/ 90 | parts/ 91 | sdist/ 92 | var/ 93 | wheels/ 94 | *.egg-info/ 95 | wandb/ 96 | .installed.cfg 97 | *.egg 98 | 99 | 100 | # PyInstaller 101 | # Usually these files are written by a python script from a template 102 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
103 | *.manifest 104 | *.spec 105 | 106 | # Installer logs 107 | pip-log.txt 108 | pip-delete-this-directory.txt 109 | 110 | # Unit test / coverage reports 111 | htmlcov/ 112 | .tox/ 113 | .coverage 114 | .coverage.* 115 | .cache 116 | nosetests.xml 117 | coverage.xml 118 | *.cover 119 | .hypothesis/ 120 | 121 | # Translations 122 | *.mo 123 | *.pot 124 | 125 | # Django stuff: 126 | *.log 127 | local_settings.py 128 | 129 | # Flask stuff: 130 | instance/ 131 | .webassets-cache 132 | 133 | # Scrapy stuff: 134 | .scrapy 135 | 136 | # Sphinx documentation 137 | docs/_build/ 138 | 139 | # PyBuilder 140 | target/ 141 | 142 | # Jupyter Notebook 143 | .ipynb_checkpoints 144 | 145 | # pyenv 146 | .python-version 147 | 148 | # celery beat schedule file 149 | celerybeat-schedule 150 | 151 | # SageMath parsed files 152 | *.sage.py 153 | 154 | # dotenv 155 | .env 156 | 157 | # virtualenv 158 | .venv* 159 | venv*/ 160 | ENV*/ 161 | 162 | # Spyder project settings 163 | .spyderproject 164 | .spyproject 165 | 166 | # Rope project settings 167 | .ropeproject 168 | 169 | # mkdocs documentation 170 | /site 171 | 172 | # mypy 173 | .mypy_cache/ 174 | 175 | 176 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 177 | 178 | # General 179 | .DS_Store 180 | .AppleDouble 181 | .LSOverride 182 | 183 | # Icon must end with two \r 184 | Icon 185 | Icon? 186 | 187 | # Thumbnails 188 | ._* 189 | 190 | # Files that might appear in the root of a volume 191 | .DocumentRevisions-V100 192 | .fseventsd 193 | .Spotlight-V100 194 | .TemporaryItems 195 | .Trashes 196 | .VolumeIcon.icns 197 | .com.apple.timemachine.donotpresent 198 | 199 | # Directories potentially created on remote AFP share 200 | .AppleDB 201 | .AppleDesktop 202 | Network Trash Folder 203 | Temporary Items 204 | .apdisk 205 | 206 | 207 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 208 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 209 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 210 | 211 | # User-specific stuff: 212 | .idea/* 213 | .idea/**/workspace.xml 214 | .idea/**/tasks.xml 215 | .idea/dictionaries 216 | .html # Bokeh Plots 217 | .pg # TensorFlow Frozen Graphs 218 | .avi # videos 219 | 220 | # Sensitive or high-churn files: 221 | .idea/**/dataSources/ 222 | .idea/**/dataSources.ids 223 | .idea/**/dataSources.local.xml 224 | .idea/**/sqlDataSources.xml 225 | .idea/**/dynamic.xml 226 | .idea/**/uiDesigner.xml 227 | 228 | # Gradle: 229 | .idea/**/gradle.xml 230 | .idea/**/libraries 231 | 232 | # CMake 233 | cmake-build-debug/ 234 | cmake-build-release/ 235 | 236 | # Mongo Explorer plugin: 237 | .idea/**/mongoSettings.xml 238 | 239 | ## File-based project format: 240 | *.iws 241 | 242 | ## Plugin-specific files: 243 | 244 | # IntelliJ 245 | out/ 246 | 247 | # mpeltonen/sbt-idea plugin 248 | .idea_modules/ 249 | 250 | # JIRA plugin 251 | atlassian-ide-plugin.xml 252 | 253 | # Cursive Clojure plugin 254 | .idea/replstate.xml 255 | 256 | # Crashlytics plugin (for Android Studio and IntelliJ) 257 | com_crashlytics_export_strings.xml 258 | crashlytics.properties 259 | crashlytics-build.properties 260 | fabric.properties 261 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "yolov5"] 2 | path = yolov5 3 | url = 
https://github.com/ultralytics/yolov5 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sithu3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Top-Down Multi-person Pose Estimation 2 | 3 | ## Introduction 4 | 5 | Pose estimation finds the keypoints belonging to the people in an image. There are two approaches to pose estimation: 6 | 7 | * **Bottom-Up** methods first find all keypoints in the image and then group them into individual people. (Generally faster, but less accurate) 8 | * **Top-Down** methods first detect the people in the image and then estimate the keypoints for each person. (Generally more computationally intensive, but more accurate) 9 | 10 | This repo only includes top-down pose estimation models. 11 | 12 | ## Model Zoo 13 | 14 | [hrnet]: https://arxiv.org/abs/1908.07919 15 | [simdr]: http://arxiv.org/abs/2107.03332 16 | [psa]: https://arxiv.org/abs/2107.00782 17 | [rlepose]: https://arxiv.org/abs/2107.11291 18 | 19 | [hrnetw32]: https://drive.google.com/file/d/1YlPrQMZdNTMWIX3QJ5iKixN3qd0NCKFO/view?usp=sharing 20 | [hrnetw48]: https://drive.google.com/file/d/1hug4ptbf9Y125h9ZH72x4asY2lHt7NA6/view?usp=sharing 21 | 22 | [phrnetw32]: https://drive.google.com/file/d/1os6T42ri4zsVPXwceli3J3KtksIaaGgu/view?usp=sharing 23 | [phrnetw48]: https://drive.google.com/file/d/1MbEjiXkV83Pm3G2o_Rni4j9CT_jRDSAQ/view?usp=sharing 24 | [simdrw32]: https://drive.google.com/file/d/1Bd8h2H30tCN8WuLIhuSRF9ViN6zghj29/view?usp=sharing 25 | [simdrw48]: https://drive.google.com/file/d/1WU_9e0MxgrO8X4W6wKo16L8siCdwgLSZ/view?usp=sharing 26 | [sasimdrw48]: https://drive.google.com/file/d/1Tj9bGL7g7XRyL2F1a-uAcWhgYXnXpqBY/view?usp=sharing 27 | 28 |
29 | COCO-val with 56.4 Detector AP 30 | 31 | Model | Backbone | Image Size | AP | AP50 | AP75 | Params (M) | FLOPs (B) | FPS | Weights 32 | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- 33 | [PoseHRNet][hrnet] | HRNet-w32 | 256x192 | 74.4 | 90.5 | 81.9 | 29 | 7 | 25 | [download][phrnetw32] 34 | | | HRNet-w48 | 256x192 | 75.1 | 90.6 | 82.2 | 64 | 15 | 24 | [download][phrnetw48] 35 | [SimDR][simdr] | HRNet-w32 | 256x192 | 75.3 | - | - | 31 | 7 | 25 | [download][simdrw32] 36 | | | HRNet-w48 | 256x192 | 75.9 | 90.4 | 82.7 | 66 | 15 | 24 | [download][simdrw48] 37 | 38 |
39 | 40 | > Note: FPS is measured on a GTX 1660 Ti with one person per frame, including pre-processing, model inference, and post-processing. Both the detection and pose models run in PyTorch FP32. 41 | 42 |
43 | COCO-test with 60.9 Detector AP (click to expand) 44 | 45 | Model | Backbone | Image Size | AP | AP50 | AP75 | Params (M) | FLOPs (B) | Weights 46 | --- | --- | --- | --- | --- | --- | --- | --- | --- 47 | [SimDR*][simdr] | HRNet-w48 | 256x192 | 75.4 | 92.4 | 82.7 | 66 | 15 | [download][sasimdrw48] 48 | [RLEPose][rlepose] | HRNet-w48 | 384x288 | 75.7 | 92.3 | 82.9 | - | - | - 49 | [UDP+PSA][psa] | HRNet-w48 | 256x192 | 78.9 | 93.6 | 85.8 | 70 | 16 | - 50 |
52 | 53 |
54 |
55 | Download Backbone Models' Weights (click to expand) 56 | 57 | Model | Weights 58 | --- | --- 59 | HRNet-w32 | [download][hrnetw32] 60 | HRNet-w48 | [download][hrnetw48] 61 | 62 |
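For completeness, here is a minimal sketch of initializing a pose model with one of these backbone checkpoints; the `checkpoints/backbone/hrnet_w32.pth` path is only an assumed save location (it mirrors the path used in the `__main__` block of `pose/models/backbones/hrnet.py`):

```python
# Sketch only: assumes the HRNet-w32 backbone weights above were saved to
# checkpoints/backbone/hrnet_w32.pth (an assumed location, not part of the repo).
from pose.models import PoseHRNet

model = PoseHRNet('w32')
# init_pretrained() loads the backbone weights with strict=False, so only the
# HRNet trunk is initialized; the final pose head still needs to be trained.
model.init_pretrained('checkpoints/backbone/hrnet_w32.pth')
```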
63 | 64 | ## Requirements 65 | 66 | * torch >= 1.8.1 67 | * torchvision >= 0.9.1 68 | 69 | Other requirements can be installed with `pip install -r requirements.txt`. 70 | 71 | Clone the repository recursively: 72 | 73 | ```bash 74 | $ git clone --recursive https://github.com/sithu31296/pose-estimation.git 75 | ``` 76 | 77 | 78 | ## Inference 79 | 80 | * Download a YOLOv5m model trained on the [CrowdHuman](https://www.crowdhuman.org/) dataset from [here](https://drive.google.com/file/d/1gglIwqxaH2iTvy6lZlXuAcMpd_U0GCUb/view?usp=sharing). (The weights are from [deepakcrk/yolov5-crowdhuman](https://github.com/deepakcrk/yolov5-crowdhuman).) 81 | * Download a pose estimation model's weights from the tables. 82 | * Run the following command. 83 | 84 | ```bash 85 | $ python infer.py --source TEST_SOURCE --det-model DET_MODEL_PATH --pose-model POSE_MODEL_PATH --img-size 640 86 | ``` 87 | 88 | Arguments: 89 | 90 | * `source`: Testing sources 91 | * To test an image, set to image file path. (For example, `assests/test.jpg`) 92 | * To test a folder containing images, set to folder name. (For example, `assests/`) 93 | * To test a video, set to video file path. (For example, `assests/video.mp4`) 94 | * To test with a webcam, set to `0`. 95 | * `det-model`: YOLOv5 model's weights path 96 | * `pose-model`: Pose estimation model's weights path 97 | 98 | Example inference results (image credit: [[1](https://www.flickr.com/photos/fotologic/6038911779/in/photostream/), [2](https://neuralet.com/article/pose-estimation-on-nvidia-jetson-platforms-using-openpifpaf/)]): 99 | 100 | ![infer_result](assests/infer_results.jpg) 101 | 102 | 103 | ## References 104 | 105 | * https://github.com/leoxiaobin/deep-high-resolution-net.pytorch 106 | * https://github.com/ultralytics/yolov5 107 | 108 | ## Citations 109 | 110 | ``` 111 | @article{WangSCJDZLMTWLX19, 112 | title={Deep High-Resolution Representation Learning for Visual Recognition}, 113 | author={Jingdong Wang and Ke Sun and Tianheng Cheng and 114 | Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and 115 | Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao}, 116 | journal = {TPAMI}, 117 | year={2019} 118 | } 119 | 120 | @misc{li20212d, 121 | title={Is 2D Heatmap Representation Even Necessary for Human Pose Estimation?}, 122 | author={Yanjie Li and Sen Yang and Shoukui Zhang and Zhicheng Wang and Wankou Yang and Shu-Tao Xia and Erjin Zhou}, 123 | year={2021}, 124 | eprint={2107.03332}, 125 | archivePrefix={arXiv}, 126 | primaryClass={cs.CV} 127 | } 128 | 129 | ``` -------------------------------------------------------------------------------- /assests/cctv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/assests/cctv.png -------------------------------------------------------------------------------- /assests/infer_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/assests/infer_results.jpg -------------------------------------------------------------------------------- /assests/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/assests/test.jpg --------------------------------------------------------------------------------
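As a quick orientation for the `infer.py` listing that follows, here is a minimal sketch of using its `Pose` class programmatically; the checkpoint paths are assumptions that mirror the script's argparse defaults, not files shipped with the repo:

```python
# Illustrative use of infer.py's Pose wrapper; checkpoint paths are assumed.
import cv2
from infer import Pose

pose = Pose(
    det_model='checkpoints/crowdhuman_yolov5m.pt',                    # YOLOv5m person detector (CrowdHuman)
    pose_model='checkpoints/pretrained/simdr_hrnet_w32_256x192.pth',  # top-down pose model
    img_size=640,
)

image = cv2.cvtColor(cv2.imread('assests/test.jpg'), cv2.COLOR_BGR2RGB)  # predict() expects RGB
output = pose.predict(image)                                             # keypoints are drawn onto the image
cv2.imwrite('test_out.jpg', cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
```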
/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import argparse 4 | import numpy as np 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | from torchvision import transforms as T 8 | 9 | from pose.models import get_pose_model 10 | from pose.utils.boxes import letterbox, scale_boxes, non_max_suppression, xyxy2xywh 11 | from pose.utils.decode import get_final_preds, get_simdr_final_preds 12 | from pose.utils.utils import setup_cudnn, get_affine_transform, draw_keypoints 13 | from pose.utils.utils import VideoReader, VideoWriter, WebcamStream, FPS 14 | 15 | import sys 16 | sys.path.insert(0, 'yolov5') 17 | from yolov5.models.experimental import attempt_load 18 | 19 | 20 | class Pose: 21 | def __init__(self, 22 | det_model, 23 | pose_model, 24 | img_size=640, 25 | conf_thres=0.25, 26 | iou_thres=0.45, 27 | ) -> None: 28 | self.img_size = img_size 29 | self.conf_thres = conf_thres 30 | self.iou_thres = iou_thres 31 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | self.det_model = attempt_load(det_model, map_location=self.device) 33 | self.det_model = self.det_model.to(self.device) 34 | 35 | self.model_name = pose_model 36 | self.pose_model = get_pose_model(pose_model) 37 | self.pose_model.load_state_dict(torch.load(pose_model, map_location='cpu')) 38 | self.pose_model = self.pose_model.to(self.device) 39 | self.pose_model.eval() 40 | 41 | self.patch_size = (192, 256) 42 | 43 | self.pose_transform = T.Compose([ 44 | T.ToTensor(), 45 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 46 | ]) 47 | 48 | self.coco_skeletons = [ 49 | [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8], 50 | [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7] 51 | ] 52 | 53 | def preprocess(self, image): 54 | img = letterbox(image, new_shape=self.img_size) 55 | img = np.ascontiguousarray(img.transpose((2, 0, 1))) 56 | img = torch.from_numpy(img).to(self.device) 57 | img = img.float() / 255.0 58 | img = img[None] 59 | return img 60 | 61 | def box_to_center_scale(self, boxes, pixel_std=200): 62 | boxes = xyxy2xywh(boxes) 63 | r = self.patch_size[0] / self.patch_size[1] 64 | mask = boxes[:, 2] > boxes[:, 3] * r 65 | boxes[mask, 3] = boxes[mask, 2] / r 66 | boxes[~mask, 2] = boxes[~mask, 3] * r 67 | boxes[:, 2:] /= pixel_std 68 | boxes[:, 2:] *= 1.25 69 | return boxes 70 | 71 | def predict_poses(self, boxes, img): 72 | image_patches = [] 73 | for cx, cy, w, h in boxes: 74 | trans = get_affine_transform(np.array([cx, cy]), np.array([w, h]), self.patch_size) 75 | img_patch = cv2.warpAffine(img, trans, self.patch_size, flags=cv2.INTER_LINEAR) 76 | img_patch = self.pose_transform(img_patch) 77 | image_patches.append(img_patch) 78 | 79 | image_patches = torch.stack(image_patches).to(self.device) 80 | return self.pose_model(image_patches) 81 | 82 | def postprocess(self, pred, img1, img0): 83 | pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=0) 84 | 85 | for det in pred: 86 | if len(det): 87 | boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu() 88 | boxes = self.box_to_center_scale(boxes) 89 | outputs = self.predict_poses(boxes, img0) 90 | 91 | if 'simdr' in self.model_name: 92 | coords = get_simdr_final_preds(*outputs, boxes, self.patch_size) 93 | else: 94 | coords = get_final_preds(outputs, boxes) 95 | 96 | draw_keypoints(img0, coords, self.coco_skeletons) 97 | 98 | @torch.no_grad() 99 | def predict(self, image): 100 | img = 
self.preprocess(image) 101 | pred = self.det_model(img)[0] 102 | self.postprocess(pred, img, image) 103 | return image 104 | 105 | 106 | def argument_parser(): 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('--source', type=str, default='assests/test.jpg') 109 | parser.add_argument('--det-model', type=str, default='checkpoints/crowdhuman_yolov5m.pt') 110 | parser.add_argument('--pose-model', type=str, default='checkpoints/pretrained/simdr_hrnet_w32_256x192.pth') 111 | parser.add_argument('--img-size', type=int, default=640) 112 | parser.add_argument('--conf-thres', type=float, default=0.4) 113 | parser.add_argument('--iou-thres', type=float, default=0.5) 114 | return parser.parse_args() 115 | 116 | 117 | if __name__ == '__main__': 118 | setup_cudnn() 119 | args = argument_parser() 120 | pose = Pose( 121 | args.det_model, 122 | args.pose_model, 123 | args.img_size, 124 | args.conf_thres, 125 | args.iou_thres 126 | ) 127 | 128 | source = Path(args.source) 129 | 130 | if source.is_file() and source.suffix in ['.jpg', '.png']: 131 | image = cv2.imread(str(source)) 132 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 133 | output = pose.predict(image) 134 | cv2.imwrite(f"{str(source).rsplit('.', maxsplit=1)[0]}_out.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) 135 | 136 | elif source.is_dir(): 137 | files = source.glob("*.jpg") 138 | for file in files: 139 | image = cv2.imread(str(file)) 140 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 141 | output = pose.predict(image) 142 | cv2.imwrite(f"{str(file).rsplit('.', maxsplit=1)[0]}_out.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) 143 | 144 | elif source.is_file() and source.suffix in ['.mp4', '.avi']: 145 | reader = VideoReader(args.source) 146 | writer = VideoWriter(f"{args.source.rsplit('.', maxsplit=1)[0]}_out.mp4", reader.fps) 147 | fps = FPS(len(reader.frames)) 148 | 149 | for frame in tqdm(reader): 150 | fps.start() 151 | output = pose.predict(frame.numpy()) 152 | fps.stop(False) 153 | writer.update(output) 154 | 155 | print(f"FPS: {fps.fps}") 156 | writer.write() 157 | 158 | else: 159 | webcam = WebcamStream() 160 | fps = FPS() 161 | 162 | for frame in webcam: 163 | fps.start() 164 | output = pose.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 165 | fps.stop() 166 | cv2.imshow('frame', cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) -------------------------------------------------------------------------------- /pose/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/pose/__init__.py -------------------------------------------------------------------------------- /pose/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .posehrnet import PoseHRNet 2 | from .simdr import SimDR 3 | 4 | 5 | __all__ = ['PoseHRNet', 'SimDR'] 6 | 7 | 8 | def get_pose_model(model_path: str): 9 | if 'posehrnet' in model_path: 10 | model = PoseHRNet('w32' if 'w32' in model_path else 'w48') 11 | elif 'simdr' in model_path: 12 | model = SimDR('w32' if 'w32' in model_path else 'w48') 13 | else: 14 | raise NotImplementedError 15 | return model -------------------------------------------------------------------------------- /pose/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .hrnet import HRNet -------------------------------------------------------------------------------- 
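A brief aside on the detector-to-pose handoff in `infer.py` above: `Pose.box_to_center_scale` converts each xyxy person box into the (center, scale) convention used for cropping and later by `transform_preds`, fixing the aspect ratio to 192:256, dividing by the conventional pixel_std of 200, and adding a 25% margin. A small worked sketch with made-up numbers:

```python
# Worked example of the center/scale convention from infer.py (numbers are illustrative).
import torch
from pose.utils.boxes import xyxy2xywh

box = torch.tensor([[100., 200., 300., 500.]])   # xyxy person box: 200 px wide, 300 px tall
cxcywh = xyxy2xywh(box)                          # -> [[200., 350., 200., 300.]]

r = 192 / 256                                    # target patch aspect ratio (width / height)
w, h = cxcywh[0, 2], cxcywh[0, 3]
if w > h * r:
    h = w / r                                    # too wide: grow the crop vertically
else:
    w = h * r                                    # too tall: grow the crop horizontally (300 * 0.75 = 225)

scale = torch.stack([w, h]) / 200 * 1.25         # pixel_std = 200, plus a 25% context margin
print(scale)                                     # tensor([1.4062, 1.8750])
```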
/pose/models/backbones/hrnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class Conv(nn.Sequential): 6 | def __init__(self, c1, c2, k, s=1, p=0): 7 | super().__init__( 8 | nn.Conv2d(c1, c2, k, s, p, bias=False), 9 | nn.BatchNorm2d(c2), 10 | nn.ReLU(True) 11 | ) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super().__init__() 19 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(True) 22 | 23 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | identity = x 31 | 32 | out = self.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | 35 | if self.downsample is not None: 36 | identity = self.downsample(x) 37 | 38 | out += identity 39 | out = self.relu(out) 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super().__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | 51 | self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | 54 | self.conv3 = nn.Conv2d(planes, planes*self.expansion, 1, bias=False) 55 | self.bn3 = nn.BatchNorm2d(planes*self.expansion) 56 | 57 | self.relu = nn.ReLU(True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x: Tensor) -> Tensor: 62 | identity = x 63 | 64 | out = self.relu(self.bn1(self.conv1(x))) 65 | out = self.relu(self.bn2(self.conv2(out))) 66 | out = self.bn3(self.conv3(out)) 67 | 68 | if self.downsample is not None: 69 | identity = self.downsample(x) 70 | 71 | out += identity 72 | out = self.relu(out) 73 | return out 74 | 75 | 76 | class HRModule(nn.Module): 77 | def __init__(self, num_branches, num_channels, ms_output=True): 78 | super().__init__() 79 | self.num_branches = num_branches 80 | self.branches = nn.ModuleList([ 81 | nn.Sequential(*[ 82 | BasicBlock(num_channels[i], num_channels[i]) 83 | for _ in range(4)]) 84 | for i in range(num_branches)]) 85 | 86 | self.fuse_layers = self._make_fuse_layers(num_branches, num_channels, ms_output) 87 | self.relu = nn.ReLU(True) 88 | 89 | def _make_fuse_layers(self, num_branches, num_channels, ms_output=True): 90 | fuse_layers = [] 91 | 92 | for i in range(num_branches if ms_output else 1): 93 | fuse_layer = [] 94 | 95 | for j in range(num_branches): 96 | if j > i: 97 | fuse_layer.append( 98 | nn.Sequential( 99 | nn.Conv2d(num_channels[j], num_channels[i], 1, bias=False), 100 | nn.BatchNorm2d(num_channels[i]), 101 | nn.Upsample(scale_factor=2**(j-i), mode='nearest') 102 | ) 103 | ) 104 | elif j == i: 105 | fuse_layer.append(None) 106 | else: 107 | conv3x3s = [] 108 | for k in range(i-j): 109 | if k == i - j -1: 110 | conv3x3s.append( 111 | nn.Sequential( 112 | nn.Conv2d(num_channels[j], num_channels[i], 3, 2, 1, bias=False), 113 | nn.BatchNorm2d(num_channels[i]) 114 | ) 115 | ) 116 | else: 117 | conv3x3s.append(Conv(num_channels[j], num_channels[j], 3, 2, 1)) 118 | fuse_layer.append(nn.Sequential(*conv3x3s)) 119 | fuse_layers.append(nn.ModuleList(fuse_layer)) 120 | 121 | return 
nn.ModuleList(fuse_layers) 122 | 123 | def forward(self, x: Tensor) -> Tensor: 124 | for i, m in enumerate(self.branches): 125 | x[i] = m(x[i]) 126 | 127 | x_fuse = [] 128 | 129 | for i, fm in enumerate(self.fuse_layers): 130 | y = x[0] if i == 0 else fm[0](x[0]) 131 | 132 | for j in range(1, self.num_branches): 133 | y = y + x[j] if i == j else y + fm[j](x[j]) 134 | x_fuse.append(self.relu(y)) 135 | return x_fuse 136 | 137 | 138 | hrnet_settings = { 139 | "w18": [18, 36, 72, 144], 140 | "w32": [32, 64, 128, 256], 141 | "w48": [48, 96, 192, 384] 142 | } 143 | 144 | 145 | class HRNet(nn.Module): 146 | def __init__(self, backbone: str = 'w18') -> None: 147 | super().__init__() 148 | assert backbone in hrnet_settings.keys(), f"HRNet model name should be in {list(hrnet_settings.keys())}" 149 | 150 | # stem 151 | self.conv1 = nn.Conv2d(3, 64, 3, 2, 1, bias=False) 152 | self.bn1 = nn.BatchNorm2d(64) 153 | self.conv2 = nn.Conv2d(64, 64, 3, 2, 1, bias=False) 154 | self.bn2 = nn.BatchNorm2d(64) 155 | self.relu = nn.ReLU(True) 156 | 157 | self.all_channels = hrnet_settings[backbone] 158 | 159 | # Stage 1 160 | self.layer1 = self._make_layer(64, 64, 4) 161 | stage1_out_channel = Bottleneck.expansion * 64 162 | 163 | # Stage 2 164 | stage2_channels = self.all_channels[:2] 165 | self.transition1 = self._make_transition_layer([stage1_out_channel], stage2_channels) 166 | self.stage2 = self._make_stage(1, 2, stage2_channels) 167 | 168 | # # Stage 3 169 | stage3_channels = self.all_channels[:3] 170 | self.transition2 = self._make_transition_layer(stage2_channels, stage3_channels) 171 | self.stage3 = self._make_stage(4, 3, stage3_channels) 172 | 173 | # # Stage 4 174 | self.transition3 = self._make_transition_layer(stage3_channels, self.all_channels) 175 | self.stage4 = self._make_stage(3, 4, self.all_channels, ms_output=False) 176 | 177 | def _make_layer(self, inplanes, planes, blocks): 178 | downsample = None 179 | if inplanes != planes * Bottleneck.expansion: 180 | downsample = nn.Sequential( 181 | nn.Conv2d(inplanes, planes*Bottleneck.expansion, 1, bias=False), 182 | nn.BatchNorm2d(planes*Bottleneck.expansion) 183 | ) 184 | 185 | layers = [] 186 | layers.append(Bottleneck(inplanes, planes, downsample=downsample)) 187 | inplanes = planes * Bottleneck.expansion 188 | 189 | for _ in range(1, blocks): 190 | layers.append(Bottleneck(inplanes, planes)) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def _make_transition_layer(self, c1s, c2s): 195 | num_branches_pre = len(c1s) 196 | num_branches_cur = len(c2s) 197 | 198 | transition_layers = [] 199 | 200 | for i in range(num_branches_cur): 201 | if i < num_branches_pre: 202 | if c1s[i] != c2s[i]: 203 | transition_layers.append(Conv(c1s[i], c2s[i], 3, 1, 1)) 204 | else: 205 | transition_layers.append(None) 206 | else: 207 | conv3x3s = [] 208 | for j in range(i+1-num_branches_pre): 209 | inchannels = c1s[-1] 210 | outchannels = c2s[i] if j == i-num_branches_pre else inchannels 211 | conv3x3s.append(Conv(inchannels, outchannels, 3, 2, 1)) 212 | transition_layers.append(nn.Sequential(*conv3x3s)) 213 | 214 | return nn.ModuleList(transition_layers) 215 | 216 | 217 | def _make_stage(self, num_modules, num_branches, num_channels, ms_output=True): 218 | modules = [] 219 | 220 | for i in range(num_modules): 221 | # multi-scale output is only used in last module 222 | if not ms_output and i == num_modules - 1: 223 | reset_ms_output = False 224 | else: 225 | reset_ms_output = True 226 | modules.append(HRModule(num_branches, num_channels, reset_ms_output)) 227 | 228 
| return nn.Sequential(*modules) 229 | 230 | 231 | def forward(self, x: Tensor) -> Tensor: 232 | x = self.relu(self.bn1(self.conv1(x))) 233 | x = self.relu(self.bn2(self.conv2(x))) 234 | 235 | # Stage 1 236 | x = self.layer1(x) 237 | 238 | # Stage 2 239 | x_list = [trans(x) if trans is not None else x for trans in self.transition1] 240 | y_list = self.stage2(x_list) 241 | 242 | # Stage 3 243 | x_list = [trans(y_list[-1]) if trans is not None else y_list[i] for i, trans in enumerate(self.transition2)] 244 | y_list = self.stage3(x_list) 245 | 246 | # # Stage 4 247 | x_list = [trans(y_list[-1]) if trans is not None else y_list[i] for i, trans in enumerate(self.transition3)] 248 | y_list = self.stage4(x_list) 249 | return y_list[0] 250 | 251 | 252 | if __name__ == '__main__': 253 | model = HRNet('w32') 254 | model.load_state_dict(torch.load('./checkpoints/backbone/hrnet_w32.pth', map_location='cpu'), strict=False) 255 | x = torch.randn(1, 3, 224, 224) 256 | y = model(x) 257 | print(y.shape) -------------------------------------------------------------------------------- /pose/models/posehrnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from .backbones import HRNet 4 | 5 | 6 | class PoseHRNet(nn.Module): 7 | def __init__(self, backbone: str = 'w32', num_joints: int = 17): 8 | super().__init__() 9 | self.backbone = HRNet(backbone) 10 | self.final_layer = nn.Conv2d(self.backbone.all_channels[0], num_joints, 1) 11 | 12 | self.apply(self._init_weights) 13 | 14 | def _init_weights(self, m: nn.Module) -> None: 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 17 | elif isinstance(m, nn.BatchNorm2d): 18 | nn.init.constant_(m.weight, 1) 19 | nn.init.constant_(m.bias, 0) 20 | 21 | def init_pretrained(self, pretrained: str = None) -> None: 22 | if pretrained: 23 | self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False) 24 | 25 | def forward(self, x: Tensor) -> Tensor: 26 | out = self.backbone(x) 27 | out = self.final_layer(out) 28 | return out 29 | 30 | 31 | if __name__ == '__main__': 32 | model = PoseHRNet('w48') 33 | model.load_state_dict(torch.load('checkpoints/pretrained/posehrnet_w48_256x192.pth', map_location='cpu')) 34 | x = torch.randn(1, 3, 256, 192) 35 | y = model(x) 36 | print(y.shape) -------------------------------------------------------------------------------- /pose/models/simdr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from .backbones import HRNet 4 | 5 | 6 | class SimDR(nn.Module): 7 | def __init__(self, backbone: str = 'w32', num_joints: int = 17, image_size: tuple = (256, 192)): 8 | super().__init__() 9 | self.backbone = HRNet(backbone) 10 | self.final_layer = nn.Conv2d(self.backbone.all_channels[0], num_joints, 1) 11 | self.mlp_head_x = nn.Linear(3072, int(image_size[1] * 2.0)) 12 | self.mlp_head_y = nn.Linear(3072, int(image_size[0] * 2.0)) 13 | 14 | self.apply(self._init_weights) 15 | 16 | def _init_weights(self, m: nn.Module) -> None: 17 | if isinstance(m, nn.Conv2d): 18 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 19 | elif isinstance(m, nn.BatchNorm2d): 20 | nn.init.constant_(m.weight, 1) 21 | nn.init.constant_(m.bias, 0) 22 | 23 | def init_pretrained(self, pretrained: str = None) -> None: 24 | if pretrained: 25 | self.backbone.load_state_dict(torch.load(pretrained, 
map_location='cpu'), strict=False) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | out = self.backbone(x) 29 | out = self.final_layer(out).flatten(2) 30 | pred_x = self.mlp_head_x(out) 31 | pred_y = self.mlp_head_y(out) 32 | return pred_x, pred_y 33 | 34 | 35 | if __name__ == '__main__': 36 | from torch.nn import functional as F 37 | model = SimDR('w32') 38 | model.load_state_dict(torch.load('checkpoints/pretrained/simdr_hrnet_w32_256x192.pth', map_location='cpu')) 39 | x = torch.randn(4, 3, 256, 192) 40 | px, py = model(x) 41 | print(px.shape, py.shape) 42 | -------------------------------------------------------------------------------- /pose/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/pose/utils/__init__.py -------------------------------------------------------------------------------- /pose/utils/boxes.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from torchvision import ops 5 | 6 | 7 | def letterbox(img, new_shape=(640, 640)): 8 | H, W = img.shape[:2] 9 | if isinstance(new_shape, int): 10 | new_shape = (new_shape, new_shape) 11 | 12 | r = min(new_shape[0] / H, new_shape[1] / W) 13 | nH, nW = round(H * r), round(W * r) 14 | pH, pW = np.mod(new_shape[0] - nH, 32) / 2, np.mod(new_shape[1] - nW, 32) / 2 15 | 16 | if (H, W) != (nH, nW): 17 | img = cv2.resize(img, (nW, nH), interpolation=cv2.INTER_LINEAR) 18 | 19 | top, bottom = round(pH - 0.1), round(pH + 0.1) 20 | left, right = round(pW - 0.1), round(pW + 0.1) 21 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) 22 | return img 23 | 24 | 25 | def scale_boxes(boxes, orig_shape, new_shape): 26 | H, W = orig_shape 27 | nH, nW = new_shape 28 | gain = min(nH / H, nW / W) 29 | pad = (nH - H * gain) / 2, (nW - W * gain) / 2 30 | 31 | boxes[:, ::2] -= pad[1] 32 | boxes[:, 1::2] -= pad[0] 33 | boxes[:, :4] /= gain 34 | 35 | boxes[:, ::2].clamp_(0, orig_shape[1]) 36 | boxes[:, 1::2].clamp_(0, orig_shape[0]) 37 | return boxes.round() 38 | 39 | 40 | def xywh2xyxy(x): 41 | boxes = x.clone() 42 | boxes[:, 0] = x[:, 0] - x[:, 2] / 2 43 | boxes[:, 1] = x[:, 1] - x[:, 3] / 2 44 | boxes[:, 2] = x[:, 0] + x[:, 2] / 2 45 | boxes[:, 3] = x[:, 1] + x[:, 3] / 2 46 | return boxes 47 | 48 | 49 | def xyxy2xywh(x): 50 | y = x.clone() 51 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center 52 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center 53 | y[:, 2] = x[:, 2] - x[:, 0] # width 54 | y[:, 3] = x[:, 3] - x[:, 1] # height 55 | return y 56 | 57 | 58 | def non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None): 59 | candidates = pred[..., 4] > conf_thres 60 | 61 | max_wh = 4096 62 | max_nms = 30000 63 | max_det = 300 64 | 65 | output = [torch.zeros((0, 6), device=pred.device)] * pred.shape[0] 66 | 67 | for xi, x in enumerate(pred): 68 | x = x[candidates[xi]] 69 | 70 | if not x.shape[0]: continue 71 | 72 | # compute conf 73 | x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf 74 | 75 | # box 76 | box = xywh2xyxy(x[:, :4]) 77 | 78 | # detection matrix nx6 79 | conf, j = x[:, 5:].max(1, keepdim=True) 80 | x = torch.cat([box, conf, j.float()], dim=1)[conf.view(-1) > conf_thres] 81 | 82 | # filter by class 83 | if classes is not None: 84 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] 85 | 86 | # check shape 87 | n = x.shape[0] 88 
| if not n: 89 | continue 90 | elif n > max_nms: 91 | x = x[x[:, 4].argsort(descending=True)[:max_nms]] 92 | 93 | # batched nms 94 | c = x[:, 5:6] * max_wh 95 | boxes, scores = x[:, :4] + c, x[:, 4] 96 | keep = ops.nms(boxes, scores, iou_thres) 97 | 98 | if keep.shape[0] > max_det: 99 | keep = keep[:max_det] 100 | 101 | output[xi] = x[keep] 102 | 103 | return output -------------------------------------------------------------------------------- /pose/utils/decode.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | from torch import Tensor 5 | 6 | 7 | def get_simdr_final_preds(pred_x: Tensor, pred_y: Tensor, boxes: Tensor, image_size: tuple): 8 | center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy() 9 | 10 | pred_x, pred_y = pred_x.softmax(dim=2), pred_y.softmax(dim=2) 11 | pred_x, pred_y = pred_x.max(dim=2)[-1], pred_y.max(dim=2)[-1] 12 | coords = torch.stack([pred_x / 2, pred_y / 2], dim=-1).cpu().numpy() 13 | 14 | for i in range(coords.shape[0]): 15 | coords[i] = transform_preds(coords[i], center[i], scale[i], image_size) 16 | return coords.astype(int) 17 | 18 | 19 | def get_final_preds(heatmaps: Tensor, boxes: Tensor): 20 | center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy() 21 | heatmaps = heatmaps.cpu().numpy() 22 | B, C, H, W = heatmaps.shape 23 | coords = get_max_preds(heatmaps) 24 | 25 | for n in range(B): 26 | for p in range(C): 27 | hm = heatmaps[n][p] 28 | px = int(math.floor(coords[n][p][0] + 0.5)) 29 | py = int(math.floor(coords[n][p][1] + 0.5)) 30 | 31 | if 1 < px < W - 1 and 1 < py < H - 1: 32 | diff = np.array([ 33 | hm[py][px+1] - hm[py][px-1], 34 | hm[py+1][px] - hm[py-1][px] 35 | ]) 36 | coords[n][p] += np.sign(diff) * .25 37 | 38 | for i in range(B): 39 | coords[i] = transform_preds(coords[i], center[i], scale[i], [W, H]) 40 | return coords.astype(int) 41 | 42 | 43 | def get_max_preds(heatmaps: np.ndarray): 44 | B, C, _, W = heatmaps.shape 45 | heatmaps = heatmaps.reshape((B, C, -1)) 46 | idx = np.argmax(heatmaps, axis=2).reshape((B, C, 1)) 47 | maxvals = np.amax(heatmaps, axis=2).reshape((B, C, 1)) 48 | preds = np.tile(idx, (1, 1, 2)).astype(np.float32) 49 | preds[:, :, 0] = preds[:, :, 0] % W 50 | preds[:, :, 1] = preds[:, :, 1] // W 51 | preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1) 52 | return preds 53 | 54 | 55 | def transform_preds(coords, center, scale, output_size): 56 | scale = scale * 200 57 | scale_x = scale[0] / output_size[0] 58 | scale_y = scale[1] / output_size[1] 59 | target_coords = np.ones_like(coords) 60 | target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5 61 | target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5 62 | return target_coords -------------------------------------------------------------------------------- /pose/utils/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import time 5 | from torchvision import io 6 | from threading import Thread 7 | from torch.backends import cudnn 8 | 9 | 10 | def setup_cudnn() -> None: 11 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 12 | cudnn.benchmark = True 13 | cudnn.deterministic = False 14 | 15 | 16 | def draw_coco_keypoints(img, keypoints, skeletons): 17 | if keypoints == []: return img 18 | image = img.copy() 19 | for kpts in keypoints: 20 | for x, y, v in kpts: 21 | if v == 2: 22 | 
cv2.circle(image, (x, y), 4, (255, 0, 0), 2) 23 | for kid1, kid2 in skeletons: 24 | x1, y1, v1 = kpts[kid1-1] 25 | x2, y2, v2 = kpts[kid2-1] 26 | if v1 == 2 and v2 == 2: 27 | cv2.line(image, (x1, y1), (x2, y2), (0, 255, 0), 2) 28 | return image 29 | 30 | 31 | def draw_keypoints(img, keypoints, skeletons): 32 | if keypoints == []: return img 33 | for kpts in keypoints: 34 | for x, y in kpts: 35 | cv2.circle(img, (x, y), 4, (255, 0, 0), 2, cv2.LINE_AA) 36 | for kid1, kid2 in skeletons: 37 | cv2.line(img, kpts[kid1-1], kpts[kid2-1], (0, 255, 0), 2, cv2.LINE_AA) 38 | 39 | 40 | class WebcamStream: 41 | def __init__(self, src=0) -> None: 42 | self.cap = cv2.VideoCapture(src) 43 | self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) 44 | assert self.cap.isOpened(), f"Failed to open webcam {src}" 45 | _, self.frame = self.cap.read() 46 | Thread(target=self.update, args=([]), daemon=True).start() 47 | 48 | def update(self): 49 | while self.cap.isOpened(): 50 | _, self.frame = self.cap.read() 51 | 52 | def __iter__(self): 53 | self.count = -1 54 | return self 55 | 56 | def __next__(self): 57 | self.count += 1 58 | 59 | if cv2.waitKey(1) == ord('q'): 60 | self.stop() 61 | 62 | return self.frame.copy() 63 | 64 | def stop(self): 65 | cv2.destroyAllWindows() 66 | raise StopIteration 67 | 68 | def __len__(self): 69 | return 0 70 | 71 | 72 | class VideoReader: 73 | def __init__(self, video: str): 74 | self.frames, _, info = io.read_video(video, pts_unit='sec') 75 | self.fps = info['video_fps'] 76 | 77 | print(f"Processing '{video}'...") 78 | print(f"Total Frames: {len(self.frames)}") 79 | print(f"Video Size : {list(self.frames.shape[1:-1])}") 80 | print(f"Video FPS : {self.fps}") 81 | 82 | def __iter__(self): 83 | self.count = 0 84 | return self 85 | 86 | def __len__(self): 87 | return len(self.frames) 88 | 89 | def __next__(self): 90 | if self.count == len(self.frames): 91 | raise StopIteration 92 | frame = self.frames[self.count] 93 | self.count += 1 94 | return frame 95 | 96 | 97 | class VideoWriter: 98 | def __init__(self, file_name, fps): 99 | self.fname = file_name 100 | self.fps = fps 101 | self.frames = [] 102 | 103 | def update(self, frame): 104 | if isinstance(frame, np.ndarray): 105 | frame = torch.from_numpy(frame) 106 | self.frames.append(frame) 107 | 108 | def write(self): 109 | print(f"Saving video to '{self.fname}'...") 110 | io.write_video(self.fname, torch.stack(self.frames), self.fps) 111 | 112 | 113 | class FPS: 114 | def __init__(self, avg=10) -> None: 115 | self.accum_time = 0 116 | self.counts = 0 117 | self.avg = avg 118 | 119 | def synchronize(self): 120 | if torch.cuda.is_available(): 121 | torch.cuda.synchronize() 122 | 123 | def start(self): 124 | self.synchronize() 125 | self.prev_time = time.time() 126 | 127 | def stop(self, debug=True): 128 | self.synchronize() 129 | self.accum_time += time.time() - self.prev_time 130 | self.counts += 1 131 | if self.counts == self.avg: 132 | self.fps = round(self.counts / self.accum_time) 133 | if debug: print(f"FPS: {self.fps}") 134 | self.counts = 0 135 | self.accum_time = 0 136 | 137 | 138 | def get_dir(src_point, rot): 139 | rot_rad = np.pi * rot / 180 140 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 141 | p1 = src_point[0] * cs - src_point[1] * sn 142 | p2 = src_point[0] * sn + src_point[1] * cs 143 | return p1, p2 144 | 145 | 146 | def get_3rd_point(a, b): 147 | direct = a - b 148 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 149 | 150 | 151 | def get_affine_transform(center, scale, patch_size, rot=0, inv=False): 152 | shift = 
np.array([0, 0], dtype=np.float32) 153 | scale_tmp = scale * 200 154 | src_w = scale_tmp[0] 155 | dst_w = patch_size[0] 156 | dst_h = patch_size[1] 157 | 158 | src_dir = get_dir([0, src_w * -0.5], rot) 159 | dst_dir = np.array([0, dst_w * -0.5], dtype=np.float32) 160 | src = np.zeros((3, 2), dtype=np.float32) 161 | dst = np.zeros((3, 2), dtype=np.float32) 162 | 163 | src[0, :] = center + scale_tmp * shift 164 | src[1, :] = center + src_dir + scale_tmp * shift 165 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 166 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 167 | 168 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 169 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 170 | 171 | return cv2.getAffineTransform(dst, src) if inv else cv2.getAffineTransform(src, dst) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | numpy 3 | tqdm 4 | --------------------------------------------------------------------------------
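Finally, a toy check of the heatmap decoding in `pose/utils/decode.py`: `get_max_preds` takes each joint's argmax over the flattened heatmap and converts the flat index back into (x, y) coordinates. The numbers below are made up purely for illustration:

```python
# Toy example for get_max_preds: one image, one joint, 4x4 heatmap (values illustrative).
import numpy as np
from pose.utils.decode import get_max_preds

heatmap = np.zeros((1, 1, 4, 4), dtype=np.float32)   # (batch, joints, H, W)
heatmap[0, 0, 2, 3] = 0.9                            # peak at row y=2, column x=3

coords = get_max_preds(heatmap)
print(coords)   # [[[3. 2.]]] -> x = flat_index % W, y = flat_index // W
```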