├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assests
│   ├── cctv.png
│   ├── infer_results.jpg
│   └── test.jpg
├── infer.py
├── pose
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   └── hrnet.py
│   │   ├── posehrnet.py
│   │   └── simdr.py
│   └── utils
│       ├── __init__.py
│       ├── boxes.py
│       ├── decode.py
│       └── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Repo-specific GitIgnore ----------------------------------------------------------------------------------------------
2 | *.jpg
3 | *.jpeg
4 | *.png
5 | *.bmp
6 | *.tif
7 | *.tiff
8 | *.heic
9 | *.JPG
10 | *.JPEG
11 | *.PNG
12 | *.BMP
13 | *.TIF
14 | *.TIFF
15 | *.HEIC
16 | *.mp4
17 | *.mov
18 | *.MOV
19 | *.avi
20 | *.data
21 | *.json
22 | *.cfg
23 | !cfg/yolov3*.cfg
24 |
25 | test.ipynb
26 | test.py
27 |
28 | assets/
29 |
30 | storage.googleapis.com
31 | test_imgs/
32 | runs/*
33 | data/*
34 | !data/images/zidane.jpg
35 | !data/images/bus.jpg
36 | !data/coco.names
37 | !data/coco_paper.names
38 | !data/coco.data
39 | !data/coco_*.data
40 | !data/coco_*.txt
41 | !data/trainvalno5k.shapes
42 | !data/*.sh
43 |
44 | pycocotools/*
45 | results*.txt
46 | gcp_test*.sh
47 |
48 | checkpoints/
49 | output/
50 |
51 | # Datasets -------------------------------------------------------------------------------------------------------------
52 | coco/
53 | coco128/
54 | VOC/
55 |
56 | # MATLAB GitIgnore -----------------------------------------------------------------------------------------------------
57 | *.m~
58 | *.mat
59 | !targets*.mat
60 |
61 | # Neural Network weights -----------------------------------------------------------------------------------------------
62 | *.weights
63 | *.pt
64 | *.onnx
65 | *.mlmodel
66 | *.torchscript
67 | darknet53.conv.74
68 | yolov3-tiny.conv.15
69 |
70 | # GitHub Python GitIgnore ----------------------------------------------------------------------------------------------
71 | # Byte-compiled / optimized / DLL files
72 | __pycache__/
73 | *.py[cod]
74 | *$py.class
75 |
76 | # C extensions
77 | *.so
78 |
79 | # Distribution / packaging
80 | .Python
81 | env/
82 | build/
83 | develop-eggs/
84 | dist/
85 | downloads/
86 | eggs/
87 | .eggs/
88 | lib/
89 | lib64/
90 | parts/
91 | sdist/
92 | var/
93 | wheels/
94 | *.egg-info/
95 | wandb/
96 | .installed.cfg
97 | *.egg
98 |
99 |
100 | # PyInstaller
101 | # Usually these files are written by a python script from a template
102 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
103 | *.manifest
104 | *.spec
105 |
106 | # Installer logs
107 | pip-log.txt
108 | pip-delete-this-directory.txt
109 |
110 | # Unit test / coverage reports
111 | htmlcov/
112 | .tox/
113 | .coverage
114 | .coverage.*
115 | .cache
116 | nosetests.xml
117 | coverage.xml
118 | *.cover
119 | .hypothesis/
120 |
121 | # Translations
122 | *.mo
123 | *.pot
124 |
125 | # Django stuff:
126 | *.log
127 | local_settings.py
128 |
129 | # Flask stuff:
130 | instance/
131 | .webassets-cache
132 |
133 | # Scrapy stuff:
134 | .scrapy
135 |
136 | # Sphinx documentation
137 | docs/_build/
138 |
139 | # PyBuilder
140 | target/
141 |
142 | # Jupyter Notebook
143 | .ipynb_checkpoints
144 |
145 | # pyenv
146 | .python-version
147 |
148 | # celery beat schedule file
149 | celerybeat-schedule
150 |
151 | # SageMath parsed files
152 | *.sage.py
153 |
154 | # dotenv
155 | .env
156 |
157 | # virtualenv
158 | .venv*
159 | venv*/
160 | ENV*/
161 |
162 | # Spyder project settings
163 | .spyderproject
164 | .spyproject
165 |
166 | # Rope project settings
167 | .ropeproject
168 |
169 | # mkdocs documentation
170 | /site
171 |
172 | # mypy
173 | .mypy_cache/
174 |
175 |
176 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore -----------------------------------------------
177 |
178 | # General
179 | .DS_Store
180 | .AppleDouble
181 | .LSOverride
182 |
183 | # Icon must end with two \r
184 | Icon
185 | Icon?
186 |
187 | # Thumbnails
188 | ._*
189 |
190 | # Files that might appear in the root of a volume
191 | .DocumentRevisions-V100
192 | .fseventsd
193 | .Spotlight-V100
194 | .TemporaryItems
195 | .Trashes
196 | .VolumeIcon.icns
197 | .com.apple.timemachine.donotpresent
198 |
199 | # Directories potentially created on remote AFP share
200 | .AppleDB
201 | .AppleDesktop
202 | Network Trash Folder
203 | Temporary Items
204 | .apdisk
205 |
206 |
207 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore
208 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
209 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
210 |
211 | # User-specific stuff:
212 | .idea/*
213 | .idea/**/workspace.xml
214 | .idea/**/tasks.xml
215 | .idea/dictionaries
216 | .html # Bokeh Plots
217 | .pg # TensorFlow Frozen Graphs
218 | .avi # videos
219 |
220 | # Sensitive or high-churn files:
221 | .idea/**/dataSources/
222 | .idea/**/dataSources.ids
223 | .idea/**/dataSources.local.xml
224 | .idea/**/sqlDataSources.xml
225 | .idea/**/dynamic.xml
226 | .idea/**/uiDesigner.xml
227 |
228 | # Gradle:
229 | .idea/**/gradle.xml
230 | .idea/**/libraries
231 |
232 | # CMake
233 | cmake-build-debug/
234 | cmake-build-release/
235 |
236 | # Mongo Explorer plugin:
237 | .idea/**/mongoSettings.xml
238 |
239 | ## File-based project format:
240 | *.iws
241 |
242 | ## Plugin-specific files:
243 |
244 | # IntelliJ
245 | out/
246 |
247 | # mpeltonen/sbt-idea plugin
248 | .idea_modules/
249 |
250 | # JIRA plugin
251 | atlassian-ide-plugin.xml
252 |
253 | # Cursive Clojure plugin
254 | .idea/replstate.xml
255 |
256 | # Crashlytics plugin (for Android Studio and IntelliJ)
257 | com_crashlytics_export_strings.xml
258 | crashlytics.properties
259 | crashlytics-build.properties
260 | fabric.properties
261 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "yolov5"]
2 | path = yolov5
3 | url = https://github.com/ultralytics/yolov5
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 sithu3
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Top-Down Multi-person Pose Estimation
2 |
3 | ## Introduction
4 |
5 | Pose estimation finds the keypoints belonging to the people in an image. There are two approaches to pose estimation:
6 |
7 | * **Bottom-Up** first finds all keypoints in the image and then groups them into individual people. (Generally faster but less accurate)
8 | * **Top-Down** first detects people in the image and then estimates the keypoints for each detected person. (Generally more computationally intensive but more accurate)
9 |
10 | This repo will only include top-down pose estimation models.
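
For intuition, the sketch below shows the two-stage structure that a top-down pipeline such as the one in `infer.py` follows; `detect_people` and `estimate_pose` are hypothetical callables standing in for the YOLOv5 person detector and the pose network.

```python
from typing import Callable, List

def top_down_pose(image, detect_people: Callable, estimate_pose: Callable) -> List:
    """Conceptual top-down flow: detect people first, then estimate keypoints per person."""
    person_boxes = detect_people(image)    # stage 1: person detection
    return [estimate_pose(image, box)      # stage 2: keypoints for each detected person box
            for box in person_boxes]
```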
11 |
12 | ## Model Zoo
13 |
14 | [hrnet]: https://arxiv.org/abs/1908.07919
15 | [simdr]: http://arxiv.org/abs/2107.03332
16 | [psa]: https://arxiv.org/abs/2107.00782
17 | [rlepose]: https://arxiv.org/abs/2107.11291
18 |
19 | [hrnetw32]: https://drive.google.com/file/d/1YlPrQMZdNTMWIX3QJ5iKixN3qd0NCKFO/view?usp=sharing
20 | [hrnetw48]: https://drive.google.com/file/d/1hug4ptbf9Y125h9ZH72x4asY2lHt7NA6/view?usp=sharing
21 |
22 | [phrnetw32]: https://drive.google.com/file/d/1os6T42ri4zsVPXwceli3J3KtksIaaGgu/view?usp=sharing
23 | [phrnetw48]: https://drive.google.com/file/d/1MbEjiXkV83Pm3G2o_Rni4j9CT_jRDSAQ/view?usp=sharing
24 | [simdrw32]: https://drive.google.com/file/d/1Bd8h2H30tCN8WuLIhuSRF9ViN6zghj29/view?usp=sharing
25 | [simdrw48]: https://drive.google.com/file/d/1WU_9e0MxgrO8X4W6wKo16L8siCdwgLSZ/view?usp=sharing
26 | [sasimdrw48]: https://drive.google.com/file/d/1Tj9bGL7g7XRyL2F1a-uAcWhgYXnXpqBY/view?usp=sharing
27 |
28 |
29 | COCO-val with 56.4 Detector AP
30 |
31 | Model | Backbone | Image Size | AP | AP50 | AP75 | Params (M) | FLOPs (B) | FPS | Weights
32 | --- | --- | --- | --- | --- | --- | --- | --- | --- | ---
33 | [PoseHRNet][hrnet] | HRNet-w32 | 256x192 | 74.4 | 90.5 | 81.9 | 29 | 7 | 25 | [download][phrnetw32]
34 | | | HRNet-w48 | 256x192 | 75.1 | 90.6 | 82.2 | 64 | 15 | 24 | [download][phrnetw48]
35 | [SimDR][simdr] | HRNet-w32 | 256x192 | 75.3 | - | - | 31 | 7 | 25 | [download][simdrw32]
36 | | | HRNet-w48 | 256x192 | 75.9 | 90.4 | 82.7 | 66 | 15 | 24 | [download][simdrw48]
37 |
38 |
39 |
40 | > Note: FPS is measured on a GTX 1660 Ti with one person per frame, including pre-processing, model inference and post-processing. Both detection and pose models run in PyTorch FP32.
41 |
42 |
43 | COCO-test with 60.9 Detector AP
44 |
45 | Model | Backbone | Image Size | AP | AP50 | AP75 | Params (M) | FLOPs (B) | Weights
46 | --- | --- | --- | --- | --- | --- | --- | --- | ---
47 | [SimDR*][simdr] | HRNet-w48 | 256x192 | 75.4 | 92.4 | 82.7 | 66 | 15 | [download][sasimdrw48]
48 | [RLEPose][rlepose] | HRNet-w48 | 384x288 | 75.7 | 92.3 | 82.9 | - | - | -
49 | [UDP+PSA][psa] | HRNet-w48 | 256x192 | 78.9 | 93.6 | 85.8 | 70 | 16 | -
50 |
51 |
52 |
53 |
54 |
55 | Download Backbone Models' Weights
56 |
57 | Model | Weights
58 | --- | ---
59 | HRNet-w32 | [download][hrnetw32]
60 | HRNet-w48 | [download][hrnetw48]
61 |
62 |
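To initialize a pose model from one of these backbone checkpoints (e.g. before training), a minimal sketch using `init_pretrained` is shown below; the checkpoint path is an assumption about where you saved the downloaded file.

```python
from pose.models import PoseHRNet

model = PoseHRNet('w32')
# loads the HRNet backbone weights only (strict=False); the final prediction layer stays randomly initialized
model.init_pretrained('checkpoints/backbone/hrnet_w32.pth')
```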
63 |
64 | ## Requirements
65 |
66 | * torch >= 1.8.1
67 | * torchvision >= 0.9.1
68 |
69 | Other requirements can be installed with `pip install -r requirements.txt`.
70 |
71 | Clone the repository recursively:
72 |
73 | ```bash
74 | $ git clone --recursive https://github.com/sithu31296/pose-estimation.git
75 | ```
76 |
77 |
78 | ## Inference
79 |
80 | * Download a YOLOv5m model trained on the [CrowdHuman](https://www.crowdhuman.org/) dataset from [here](https://drive.google.com/file/d/1gglIwqxaH2iTvy6lZlXuAcMpd_U0GCUb/view?usp=sharing). (The weights are from [deepakcrk/yolov5-crowdhuman](https://github.com/deepakcrk/yolov5-crowdhuman).)
81 | * Download a pose estimation model's weights from the tables.
82 | * Run the following command.
83 |
84 | ```bash
85 | $ python infer.py --source TEST_SOURCE --det-model DET_MODEL_PATH --pose-model POSE_MODEL_PATH --img-size 640
86 | ```
87 |
88 | Arguments:
89 |
90 | * `source`: Testing sources
91 | * To test an image, set to image file path. (For example, `assests/test.jpg`)
92 | * To test a folder containing images, set to folder name. (For example, `assests/`)
93 | * To test a video, set to video file path. (For example, `assests/video.mp4`)
94 | * To test with a webcam, set to `0`.
95 | * `det-model`: YOLOv5 model's weights path
96 | * `pose-model`: Pose estimation model's weights path
97 |
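The detector and pose model can also be used from Python directly. A minimal sketch, assuming the default checkpoint locations from `infer.py` and that the script is run from the repository root:

```python
import cv2
from infer import Pose
from pose.utils.utils import setup_cudnn

setup_cudnn()
pose = Pose(
    det_model='checkpoints/crowdhuman_yolov5m.pt',                     # YOLOv5 person detector weights
    pose_model='checkpoints/pretrained/simdr_hrnet_w32_256x192.pth',   # pose model weights
    img_size=640,
)

image = cv2.cvtColor(cv2.imread('assests/test.jpg'), cv2.COLOR_BGR2RGB)  # Pose.predict expects an RGB image
output = pose.predict(image)                                             # keypoints are drawn onto the returned image
cv2.imwrite('assests/test_out.jpg', cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
```
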
98 | Example inference results (image credit: [[1](https://www.flickr.com/photos/fotologic/6038911779/in/photostream/), [2](https://neuralet.com/article/pose-estimation-on-nvidia-jetson-platforms-using-openpifpaf/)]):
99 |
100 | 
101 |
102 |
103 | ## References
104 |
105 | * https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
106 | * https://github.com/ultralytics/yolov5
107 |
108 | ## Citations
109 |
110 | ```
111 | @article{WangSCJDZLMTWLX19,
112 | title={Deep High-Resolution Representation Learning for Visual Recognition},
113 | author={Jingdong Wang and Ke Sun and Tianheng Cheng and
114 | Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and
115 | Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao},
116 |   journal={TPAMI},
117 | year={2019}
118 | }
119 |
120 | @misc{li20212d,
121 | title={Is 2D Heatmap Representation Even Necessary for Human Pose Estimation?},
122 | author={Yanjie Li and Sen Yang and Shoukui Zhang and Zhicheng Wang and Wankou Yang and Shu-Tao Xia and Erjin Zhou},
123 | year={2021},
124 | eprint={2107.03332},
125 | archivePrefix={arXiv},
126 | primaryClass={cs.CV}
127 | }
128 |
129 | ```
--------------------------------------------------------------------------------
/assests/cctv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/assests/cctv.png
--------------------------------------------------------------------------------
/assests/infer_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/assests/infer_results.jpg
--------------------------------------------------------------------------------
/assests/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/assests/test.jpg
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | from pathlib import Path
7 | from torchvision import transforms as T
8 |
9 | from pose.models import get_pose_model
10 | from pose.utils.boxes import letterbox, scale_boxes, non_max_suppression, xyxy2xywh
11 | from pose.utils.decode import get_final_preds, get_simdr_final_preds
12 | from pose.utils.utils import setup_cudnn, get_affine_transform, draw_keypoints
13 | from pose.utils.utils import VideoReader, VideoWriter, WebcamStream, FPS
14 |
15 | import sys
16 | sys.path.insert(0, 'yolov5')
17 | from yolov5.models.experimental import attempt_load
18 |
19 |
20 | class Pose:
21 | def __init__(self,
22 | det_model,
23 | pose_model,
24 | img_size=640,
25 | conf_thres=0.25,
26 | iou_thres=0.45,
27 | ) -> None:
28 | self.img_size = img_size
29 | self.conf_thres = conf_thres
30 | self.iou_thres = iou_thres
31 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32 | self.det_model = attempt_load(det_model, map_location=self.device)
33 | self.det_model = self.det_model.to(self.device)
34 |
35 | self.model_name = pose_model
36 | self.pose_model = get_pose_model(pose_model)
37 | self.pose_model.load_state_dict(torch.load(pose_model, map_location='cpu'))
38 | self.pose_model = self.pose_model.to(self.device)
39 | self.pose_model.eval()
40 |
41 | self.patch_size = (192, 256)
42 |
43 | self.pose_transform = T.Compose([
44 | T.ToTensor(),
45 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
46 | ])
47 |
48 | self.coco_skeletons = [
49 | [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
50 | [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]
51 | ]
52 |
53 | def preprocess(self, image):
54 | img = letterbox(image, new_shape=self.img_size)
55 | img = np.ascontiguousarray(img.transpose((2, 0, 1)))
56 | img = torch.from_numpy(img).to(self.device)
57 | img = img.float() / 255.0
58 | img = img[None]
59 | return img
60 |
61 | def box_to_center_scale(self, boxes, pixel_std=200):
62 | boxes = xyxy2xywh(boxes)
63 | r = self.patch_size[0] / self.patch_size[1]
64 | mask = boxes[:, 2] > boxes[:, 3] * r
65 | boxes[mask, 3] = boxes[mask, 2] / r
66 | boxes[~mask, 2] = boxes[~mask, 3] * r
67 | boxes[:, 2:] /= pixel_std
68 | boxes[:, 2:] *= 1.25
69 | return boxes
70 |
71 | def predict_poses(self, boxes, img):
72 | image_patches = []
73 | for cx, cy, w, h in boxes:
74 | trans = get_affine_transform(np.array([cx, cy]), np.array([w, h]), self.patch_size)
75 | img_patch = cv2.warpAffine(img, trans, self.patch_size, flags=cv2.INTER_LINEAR)
76 | img_patch = self.pose_transform(img_patch)
77 | image_patches.append(img_patch)
78 |
79 | image_patches = torch.stack(image_patches).to(self.device)
80 | return self.pose_model(image_patches)
81 |
82 | def postprocess(self, pred, img1, img0):
83 | pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=0)
84 |
85 | for det in pred:
86 | if len(det):
87 | boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu()
88 | boxes = self.box_to_center_scale(boxes)
89 | outputs = self.predict_poses(boxes, img0)
90 |
91 | if 'simdr' in self.model_name:
92 | coords = get_simdr_final_preds(*outputs, boxes, self.patch_size)
93 | else:
94 | coords = get_final_preds(outputs, boxes)
95 |
96 | draw_keypoints(img0, coords, self.coco_skeletons)
97 |
98 | @torch.no_grad()
99 | def predict(self, image):
100 | img = self.preprocess(image)
101 | pred = self.det_model(img)[0]
102 | self.postprocess(pred, img, image)
103 | return image
104 |
105 |
106 | def argument_parser():
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument('--source', type=str, default='assests/test.jpg')
109 | parser.add_argument('--det-model', type=str, default='checkpoints/crowdhuman_yolov5m.pt')
110 | parser.add_argument('--pose-model', type=str, default='checkpoints/pretrained/simdr_hrnet_w32_256x192.pth')
111 | parser.add_argument('--img-size', type=int, default=640)
112 | parser.add_argument('--conf-thres', type=float, default=0.4)
113 | parser.add_argument('--iou-thres', type=float, default=0.5)
114 | return parser.parse_args()
115 |
116 |
117 | if __name__ == '__main__':
118 | setup_cudnn()
119 | args = argument_parser()
120 | pose = Pose(
121 | args.det_model,
122 | args.pose_model,
123 | args.img_size,
124 | args.conf_thres,
125 | args.iou_thres
126 | )
127 |
128 | source = Path(args.source)
129 |
130 | if source.is_file() and source.suffix in ['.jpg', '.png']:
131 | image = cv2.imread(str(source))
132 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
133 | output = pose.predict(image)
134 | cv2.imwrite(f"{str(source).rsplit('.', maxsplit=1)[0]}_out.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
135 |
136 | elif source.is_dir():
137 | files = source.glob("*.jpg")
138 | for file in files:
139 | image = cv2.imread(str(file))
140 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
141 | output = pose.predict(image)
142 | cv2.imwrite(f"{str(file).rsplit('.', maxsplit=1)[0]}_out.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
143 |
144 | elif source.is_file() and source.suffix in ['.mp4', '.avi']:
145 | reader = VideoReader(args.source)
146 | writer = VideoWriter(f"{args.source.rsplit('.', maxsplit=1)[0]}_out.mp4", reader.fps)
147 | fps = FPS(len(reader.frames))
148 |
149 | for frame in tqdm(reader):
150 | fps.start()
151 | output = pose.predict(frame.numpy())
152 | fps.stop(False)
153 | writer.update(output)
154 |
155 | print(f"FPS: {fps.fps}")
156 | writer.write()
157 |
158 | else:
159 | webcam = WebcamStream()
160 | fps = FPS()
161 |
162 | for frame in webcam:
163 | fps.start()
164 | output = pose.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
165 | fps.stop()
166 | cv2.imshow('frame', cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
--------------------------------------------------------------------------------
/pose/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/pose/__init__.py
--------------------------------------------------------------------------------
/pose/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .posehrnet import PoseHRNet
2 | from .simdr import SimDR
3 |
4 |
5 | __all__ = ['PoseHRNet', 'SimDR']
6 |
7 |
8 | def get_pose_model(model_path: str):
9 | if 'posehrnet' in model_path:
10 | model = PoseHRNet('w32' if 'w32' in model_path else 'w48')
11 | elif 'simdr' in model_path:
12 | model = SimDR('w32' if 'w32' in model_path else 'w48')
13 | else:
14 | raise NotImplementedError
15 | return model
--------------------------------------------------------------------------------
/pose/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .hrnet import HRNet
--------------------------------------------------------------------------------
/pose/models/backbones/hrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 |
4 |
5 | class Conv(nn.Sequential):
6 | def __init__(self, c1, c2, k, s=1, p=0):
7 | super().__init__(
8 | nn.Conv2d(c1, c2, k, s, p, bias=False),
9 | nn.BatchNorm2d(c2),
10 | nn.ReLU(True)
11 | )
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super().__init__()
19 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(True)
22 |
23 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 |
26 | self.downsample = downsample
27 | self.stride = stride
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | identity = x
31 |
32 | out = self.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 |
35 | if self.downsample is not None:
36 | identity = self.downsample(x)
37 |
38 | out += identity
39 | out = self.relu(out)
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | expansion = 4
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None):
47 | super().__init__()
48 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 |
51 | self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 |
54 | self.conv3 = nn.Conv2d(planes, planes*self.expansion, 1, bias=False)
55 | self.bn3 = nn.BatchNorm2d(planes*self.expansion)
56 |
57 | self.relu = nn.ReLU(True)
58 | self.downsample = downsample
59 | self.stride = stride
60 |
61 | def forward(self, x: Tensor) -> Tensor:
62 | identity = x
63 |
64 | out = self.relu(self.bn1(self.conv1(x)))
65 | out = self.relu(self.bn2(self.conv2(out)))
66 | out = self.bn3(self.conv3(out))
67 |
68 | if self.downsample is not None:
69 | identity = self.downsample(x)
70 |
71 | out += identity
72 | out = self.relu(out)
73 | return out
74 |
75 |
76 | class HRModule(nn.Module):
77 | def __init__(self, num_branches, num_channels, ms_output=True):
78 | super().__init__()
79 | self.num_branches = num_branches
80 | self.branches = nn.ModuleList([
81 | nn.Sequential(*[
82 | BasicBlock(num_channels[i], num_channels[i])
83 | for _ in range(4)])
84 | for i in range(num_branches)])
85 |
86 | self.fuse_layers = self._make_fuse_layers(num_branches, num_channels, ms_output)
87 | self.relu = nn.ReLU(True)
88 |
89 | def _make_fuse_layers(self, num_branches, num_channels, ms_output=True):
90 | fuse_layers = []
91 |
92 | for i in range(num_branches if ms_output else 1):
93 | fuse_layer = []
94 |
95 | for j in range(num_branches):
96 | if j > i:
97 | fuse_layer.append(
98 | nn.Sequential(
99 | nn.Conv2d(num_channels[j], num_channels[i], 1, bias=False),
100 | nn.BatchNorm2d(num_channels[i]),
101 | nn.Upsample(scale_factor=2**(j-i), mode='nearest')
102 | )
103 | )
104 | elif j == i:
105 | fuse_layer.append(None)
106 | else:
107 | conv3x3s = []
108 | for k in range(i-j):
109 | if k == i - j -1:
110 | conv3x3s.append(
111 | nn.Sequential(
112 | nn.Conv2d(num_channels[j], num_channels[i], 3, 2, 1, bias=False),
113 | nn.BatchNorm2d(num_channels[i])
114 | )
115 | )
116 | else:
117 | conv3x3s.append(Conv(num_channels[j], num_channels[j], 3, 2, 1))
118 | fuse_layer.append(nn.Sequential(*conv3x3s))
119 | fuse_layers.append(nn.ModuleList(fuse_layer))
120 |
121 | return nn.ModuleList(fuse_layers)
122 |
123 |     def forward(self, x: list) -> list:
124 | for i, m in enumerate(self.branches):
125 | x[i] = m(x[i])
126 |
127 | x_fuse = []
128 |
129 | for i, fm in enumerate(self.fuse_layers):
130 | y = x[0] if i == 0 else fm[0](x[0])
131 |
132 | for j in range(1, self.num_branches):
133 | y = y + x[j] if i == j else y + fm[j](x[j])
134 | x_fuse.append(self.relu(y))
135 | return x_fuse
136 |
137 |
138 | hrnet_settings = {
139 | "w18": [18, 36, 72, 144],
140 | "w32": [32, 64, 128, 256],
141 | "w48": [48, 96, 192, 384]
142 | }
143 |
144 |
145 | class HRNet(nn.Module):
146 | def __init__(self, backbone: str = 'w18') -> None:
147 | super().__init__()
148 | assert backbone in hrnet_settings.keys(), f"HRNet model name should be in {list(hrnet_settings.keys())}"
149 |
150 | # stem
151 | self.conv1 = nn.Conv2d(3, 64, 3, 2, 1, bias=False)
152 | self.bn1 = nn.BatchNorm2d(64)
153 | self.conv2 = nn.Conv2d(64, 64, 3, 2, 1, bias=False)
154 | self.bn2 = nn.BatchNorm2d(64)
155 | self.relu = nn.ReLU(True)
156 |
157 | self.all_channels = hrnet_settings[backbone]
158 |
159 | # Stage 1
160 | self.layer1 = self._make_layer(64, 64, 4)
161 | stage1_out_channel = Bottleneck.expansion * 64
162 |
163 | # Stage 2
164 | stage2_channels = self.all_channels[:2]
165 | self.transition1 = self._make_transition_layer([stage1_out_channel], stage2_channels)
166 | self.stage2 = self._make_stage(1, 2, stage2_channels)
167 |
168 |         # Stage 3
169 | stage3_channels = self.all_channels[:3]
170 | self.transition2 = self._make_transition_layer(stage2_channels, stage3_channels)
171 | self.stage3 = self._make_stage(4, 3, stage3_channels)
172 |
173 |         # Stage 4
174 | self.transition3 = self._make_transition_layer(stage3_channels, self.all_channels)
175 | self.stage4 = self._make_stage(3, 4, self.all_channels, ms_output=False)
176 |
177 | def _make_layer(self, inplanes, planes, blocks):
178 | downsample = None
179 | if inplanes != planes * Bottleneck.expansion:
180 | downsample = nn.Sequential(
181 | nn.Conv2d(inplanes, planes*Bottleneck.expansion, 1, bias=False),
182 | nn.BatchNorm2d(planes*Bottleneck.expansion)
183 | )
184 |
185 | layers = []
186 | layers.append(Bottleneck(inplanes, planes, downsample=downsample))
187 | inplanes = planes * Bottleneck.expansion
188 |
189 | for _ in range(1, blocks):
190 | layers.append(Bottleneck(inplanes, planes))
191 |
192 | return nn.Sequential(*layers)
193 |
194 | def _make_transition_layer(self, c1s, c2s):
195 | num_branches_pre = len(c1s)
196 | num_branches_cur = len(c2s)
197 |
198 | transition_layers = []
199 |
200 | for i in range(num_branches_cur):
201 | if i < num_branches_pre:
202 | if c1s[i] != c2s[i]:
203 | transition_layers.append(Conv(c1s[i], c2s[i], 3, 1, 1))
204 | else:
205 | transition_layers.append(None)
206 | else:
207 | conv3x3s = []
208 | for j in range(i+1-num_branches_pre):
209 | inchannels = c1s[-1]
210 | outchannels = c2s[i] if j == i-num_branches_pre else inchannels
211 | conv3x3s.append(Conv(inchannels, outchannels, 3, 2, 1))
212 | transition_layers.append(nn.Sequential(*conv3x3s))
213 |
214 | return nn.ModuleList(transition_layers)
215 |
216 |
217 | def _make_stage(self, num_modules, num_branches, num_channels, ms_output=True):
218 | modules = []
219 |
220 | for i in range(num_modules):
221 | # multi-scale output is only used in last module
222 | if not ms_output and i == num_modules - 1:
223 | reset_ms_output = False
224 | else:
225 | reset_ms_output = True
226 | modules.append(HRModule(num_branches, num_channels, reset_ms_output))
227 |
228 | return nn.Sequential(*modules)
229 |
230 |
231 | def forward(self, x: Tensor) -> Tensor:
232 | x = self.relu(self.bn1(self.conv1(x)))
233 | x = self.relu(self.bn2(self.conv2(x)))
234 |
235 | # Stage 1
236 | x = self.layer1(x)
237 |
238 | # Stage 2
239 | x_list = [trans(x) if trans is not None else x for trans in self.transition1]
240 | y_list = self.stage2(x_list)
241 |
242 | # Stage 3
243 | x_list = [trans(y_list[-1]) if trans is not None else y_list[i] for i, trans in enumerate(self.transition2)]
244 | y_list = self.stage3(x_list)
245 |
246 |         # Stage 4
247 | x_list = [trans(y_list[-1]) if trans is not None else y_list[i] for i, trans in enumerate(self.transition3)]
248 | y_list = self.stage4(x_list)
249 | return y_list[0]
250 |
251 |
252 | if __name__ == '__main__':
253 | model = HRNet('w32')
254 | model.load_state_dict(torch.load('./checkpoints/backbone/hrnet_w32.pth', map_location='cpu'), strict=False)
255 | x = torch.randn(1, 3, 224, 224)
256 | y = model(x)
257 | print(y.shape)
--------------------------------------------------------------------------------
/pose/models/posehrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from .backbones import HRNet
4 |
5 |
6 | class PoseHRNet(nn.Module):
7 | def __init__(self, backbone: str = 'w32', num_joints: int = 17):
8 | super().__init__()
9 | self.backbone = HRNet(backbone)
10 | self.final_layer = nn.Conv2d(self.backbone.all_channels[0], num_joints, 1)
11 |
12 | self.apply(self._init_weights)
13 |
14 | def _init_weights(self, m: nn.Module) -> None:
15 | if isinstance(m, nn.Conv2d):
16 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
17 | elif isinstance(m, nn.BatchNorm2d):
18 | nn.init.constant_(m.weight, 1)
19 | nn.init.constant_(m.bias, 0)
20 |
21 | def init_pretrained(self, pretrained: str = None) -> None:
22 | if pretrained:
23 | self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
24 |
25 | def forward(self, x: Tensor) -> Tensor:
26 | out = self.backbone(x)
27 | out = self.final_layer(out)
28 | return out
29 |
30 |
31 | if __name__ == '__main__':
32 | model = PoseHRNet('w48')
33 | model.load_state_dict(torch.load('checkpoints/pretrained/posehrnet_w48_256x192.pth', map_location='cpu'))
34 | x = torch.randn(1, 3, 256, 192)
35 | y = model(x)
36 | print(y.shape)
--------------------------------------------------------------------------------
/pose/models/simdr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from .backbones import HRNet
4 |
5 |
6 | class SimDR(nn.Module):
7 | def __init__(self, backbone: str = 'w32', num_joints: int = 17, image_size: tuple = (256, 192)):
8 | super().__init__()
9 | self.backbone = HRNet(backbone)
10 | self.final_layer = nn.Conv2d(self.backbone.all_channels[0], num_joints, 1)
11 | self.mlp_head_x = nn.Linear(3072, int(image_size[1] * 2.0))
12 | self.mlp_head_y = nn.Linear(3072, int(image_size[0] * 2.0))
13 |
14 | self.apply(self._init_weights)
15 |
16 | def _init_weights(self, m: nn.Module) -> None:
17 | if isinstance(m, nn.Conv2d):
18 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
19 | elif isinstance(m, nn.BatchNorm2d):
20 | nn.init.constant_(m.weight, 1)
21 | nn.init.constant_(m.bias, 0)
22 |
23 | def init_pretrained(self, pretrained: str = None) -> None:
24 | if pretrained:
25 | self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
26 |
27 |     def forward(self, x: Tensor) -> tuple:
28 | out = self.backbone(x)
29 | out = self.final_layer(out).flatten(2)
30 | pred_x = self.mlp_head_x(out)
31 | pred_y = self.mlp_head_y(out)
32 | return pred_x, pred_y
33 |
34 |
35 | if __name__ == '__main__':
36 | from torch.nn import functional as F
37 | model = SimDR('w32')
38 | model.load_state_dict(torch.load('checkpoints/pretrained/simdr_hrnet_w32_256x192.pth', map_location='cpu'))
39 | x = torch.randn(4, 3, 256, 192)
40 | px, py = model(x)
41 | print(px.shape, py.shape)
42 |
--------------------------------------------------------------------------------
/pose/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sithu31296/pose-estimation/b00da09cfaf0ee25cdc900a46ac0a2e2a878f16a/pose/utils/__init__.py
--------------------------------------------------------------------------------
/pose/utils/boxes.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from torchvision import ops
5 |
6 |
7 | def letterbox(img, new_shape=(640, 640)):
8 | H, W = img.shape[:2]
9 | if isinstance(new_shape, int):
10 | new_shape = (new_shape, new_shape)
11 |
12 | r = min(new_shape[0] / H, new_shape[1] / W)
13 | nH, nW = round(H * r), round(W * r)
14 | pH, pW = np.mod(new_shape[0] - nH, 32) / 2, np.mod(new_shape[1] - nW, 32) / 2
15 |
16 | if (H, W) != (nH, nW):
17 | img = cv2.resize(img, (nW, nH), interpolation=cv2.INTER_LINEAR)
18 |
19 | top, bottom = round(pH - 0.1), round(pH + 0.1)
20 | left, right = round(pW - 0.1), round(pW + 0.1)
21 | img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
22 | return img
23 |
24 |
25 | def scale_boxes(boxes, orig_shape, new_shape):
26 | H, W = orig_shape
27 | nH, nW = new_shape
28 | gain = min(nH / H, nW / W)
29 | pad = (nH - H * gain) / 2, (nW - W * gain) / 2
30 |
31 | boxes[:, ::2] -= pad[1]
32 | boxes[:, 1::2] -= pad[0]
33 | boxes[:, :4] /= gain
34 |
35 | boxes[:, ::2].clamp_(0, orig_shape[1])
36 | boxes[:, 1::2].clamp_(0, orig_shape[0])
37 | return boxes.round()
38 |
39 |
40 | def xywh2xyxy(x):
41 | boxes = x.clone()
42 | boxes[:, 0] = x[:, 0] - x[:, 2] / 2
43 | boxes[:, 1] = x[:, 1] - x[:, 3] / 2
44 | boxes[:, 2] = x[:, 0] + x[:, 2] / 2
45 | boxes[:, 3] = x[:, 1] + x[:, 3] / 2
46 | return boxes
47 |
48 |
49 | def xyxy2xywh(x):
50 | y = x.clone()
51 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
52 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
53 | y[:, 2] = x[:, 2] - x[:, 0] # width
54 | y[:, 3] = x[:, 3] - x[:, 1] # height
55 | return y
56 |
57 |
58 | def non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None):
59 | candidates = pred[..., 4] > conf_thres
60 |
61 | max_wh = 4096
62 | max_nms = 30000
63 | max_det = 300
64 |
65 | output = [torch.zeros((0, 6), device=pred.device)] * pred.shape[0]
66 |
67 | for xi, x in enumerate(pred):
68 | x = x[candidates[xi]]
69 |
70 | if not x.shape[0]: continue
71 |
72 | # compute conf
73 | x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
74 |
75 | # box
76 | box = xywh2xyxy(x[:, :4])
77 |
78 | # detection matrix nx6
79 | conf, j = x[:, 5:].max(1, keepdim=True)
80 | x = torch.cat([box, conf, j.float()], dim=1)[conf.view(-1) > conf_thres]
81 |
82 | # filter by class
83 | if classes is not None:
84 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
85 |
86 | # check shape
87 | n = x.shape[0]
88 | if not n:
89 | continue
90 | elif n > max_nms:
91 | x = x[x[:, 4].argsort(descending=True)[:max_nms]]
92 |
93 | # batched nms
94 | c = x[:, 5:6] * max_wh
95 | boxes, scores = x[:, :4] + c, x[:, 4]
96 | keep = ops.nms(boxes, scores, iou_thres)
97 |
98 | if keep.shape[0] > max_det:
99 | keep = keep[:max_det]
100 |
101 | output[xi] = x[keep]
102 |
103 | return output
--------------------------------------------------------------------------------
/pose/utils/decode.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | from torch import Tensor
5 |
6 |
7 | def get_simdr_final_preds(pred_x: Tensor, pred_y: Tensor, boxes: Tensor, image_size: tuple):
8 | center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy()
9 |
10 | pred_x, pred_y = pred_x.softmax(dim=2), pred_y.softmax(dim=2)
11 | pred_x, pred_y = pred_x.max(dim=2)[-1], pred_y.max(dim=2)[-1]
12 | coords = torch.stack([pred_x / 2, pred_y / 2], dim=-1).cpu().numpy()
13 |
14 | for i in range(coords.shape[0]):
15 | coords[i] = transform_preds(coords[i], center[i], scale[i], image_size)
16 | return coords.astype(int)
17 |
18 |
19 | def get_final_preds(heatmaps: Tensor, boxes: Tensor):
20 | center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy()
21 | heatmaps = heatmaps.cpu().numpy()
22 | B, C, H, W = heatmaps.shape
23 | coords = get_max_preds(heatmaps)
24 |
25 | for n in range(B):
26 | for p in range(C):
27 | hm = heatmaps[n][p]
28 | px = int(math.floor(coords[n][p][0] + 0.5))
29 | py = int(math.floor(coords[n][p][1] + 0.5))
30 |
31 | if 1 < px < W - 1 and 1 < py < H - 1:
32 | diff = np.array([
33 | hm[py][px+1] - hm[py][px-1],
34 | hm[py+1][px] - hm[py-1][px]
35 | ])
36 | coords[n][p] += np.sign(diff) * .25
37 |
38 | for i in range(B):
39 | coords[i] = transform_preds(coords[i], center[i], scale[i], [W, H])
40 | return coords.astype(int)
41 |
42 |
43 | def get_max_preds(heatmaps: np.ndarray):
44 | B, C, _, W = heatmaps.shape
45 | heatmaps = heatmaps.reshape((B, C, -1))
46 | idx = np.argmax(heatmaps, axis=2).reshape((B, C, 1))
47 | maxvals = np.amax(heatmaps, axis=2).reshape((B, C, 1))
48 | preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
49 | preds[:, :, 0] = preds[:, :, 0] % W
50 | preds[:, :, 1] = preds[:, :, 1] // W
51 | preds = np.where(np.tile(maxvals, (1, 1, 2)) > 0.0, preds, -1)
52 | return preds
53 |
54 |
55 | def transform_preds(coords, center, scale, output_size):
56 | scale = scale * 200
57 | scale_x = scale[0] / output_size[0]
58 | scale_y = scale[1] / output_size[1]
59 | target_coords = np.ones_like(coords)
60 | target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
61 | target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
62 | return target_coords
--------------------------------------------------------------------------------
/pose/utils/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import time
5 | from torchvision import io
6 | from threading import Thread
7 | from torch.backends import cudnn
8 |
9 |
10 | def setup_cudnn() -> None:
11 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
12 | cudnn.benchmark = True
13 | cudnn.deterministic = False
14 |
15 |
16 | def draw_coco_keypoints(img, keypoints, skeletons):
17 | if keypoints == []: return img
18 | image = img.copy()
19 | for kpts in keypoints:
20 | for x, y, v in kpts:
21 | if v == 2:
22 | cv2.circle(image, (x, y), 4, (255, 0, 0), 2)
23 | for kid1, kid2 in skeletons:
24 | x1, y1, v1 = kpts[kid1-1]
25 | x2, y2, v2 = kpts[kid2-1]
26 | if v1 == 2 and v2 == 2:
27 | cv2.line(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
28 | return image
29 |
30 |
31 | def draw_keypoints(img, keypoints, skeletons):
32 |     if len(keypoints) == 0: return img
33 | for kpts in keypoints:
34 | for x, y in kpts:
35 | cv2.circle(img, (x, y), 4, (255, 0, 0), 2, cv2.LINE_AA)
36 | for kid1, kid2 in skeletons:
37 | cv2.line(img, kpts[kid1-1], kpts[kid2-1], (0, 255, 0), 2, cv2.LINE_AA)
38 |
39 |
40 | class WebcamStream:
41 | def __init__(self, src=0) -> None:
42 | self.cap = cv2.VideoCapture(src)
43 | self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)
44 | assert self.cap.isOpened(), f"Failed to open webcam {src}"
45 | _, self.frame = self.cap.read()
46 | Thread(target=self.update, args=([]), daemon=True).start()
47 |
48 | def update(self):
49 | while self.cap.isOpened():
50 | _, self.frame = self.cap.read()
51 |
52 | def __iter__(self):
53 | self.count = -1
54 | return self
55 |
56 | def __next__(self):
57 | self.count += 1
58 |
59 | if cv2.waitKey(1) == ord('q'):
60 | self.stop()
61 |
62 | return self.frame.copy()
63 |
64 | def stop(self):
65 | cv2.destroyAllWindows()
66 | raise StopIteration
67 |
68 | def __len__(self):
69 | return 0
70 |
71 |
72 | class VideoReader:
73 | def __init__(self, video: str):
74 | self.frames, _, info = io.read_video(video, pts_unit='sec')
75 | self.fps = info['video_fps']
76 |
77 | print(f"Processing '{video}'...")
78 | print(f"Total Frames: {len(self.frames)}")
79 | print(f"Video Size : {list(self.frames.shape[1:-1])}")
80 | print(f"Video FPS : {self.fps}")
81 |
82 | def __iter__(self):
83 | self.count = 0
84 | return self
85 |
86 | def __len__(self):
87 | return len(self.frames)
88 |
89 | def __next__(self):
90 | if self.count == len(self.frames):
91 | raise StopIteration
92 | frame = self.frames[self.count]
93 | self.count += 1
94 | return frame
95 |
96 |
97 | class VideoWriter:
98 | def __init__(self, file_name, fps):
99 | self.fname = file_name
100 | self.fps = fps
101 | self.frames = []
102 |
103 | def update(self, frame):
104 | if isinstance(frame, np.ndarray):
105 | frame = torch.from_numpy(frame)
106 | self.frames.append(frame)
107 |
108 | def write(self):
109 | print(f"Saving video to '{self.fname}'...")
110 | io.write_video(self.fname, torch.stack(self.frames), self.fps)
111 |
112 |
113 | class FPS:
114 | def __init__(self, avg=10) -> None:
115 | self.accum_time = 0
116 | self.counts = 0
117 | self.avg = avg
118 |
119 | def synchronize(self):
120 | if torch.cuda.is_available():
121 | torch.cuda.synchronize()
122 |
123 | def start(self):
124 | self.synchronize()
125 | self.prev_time = time.time()
126 |
127 | def stop(self, debug=True):
128 | self.synchronize()
129 | self.accum_time += time.time() - self.prev_time
130 | self.counts += 1
131 | if self.counts == self.avg:
132 | self.fps = round(self.counts / self.accum_time)
133 | if debug: print(f"FPS: {self.fps}")
134 | self.counts = 0
135 | self.accum_time = 0
136 |
137 |
138 | def get_dir(src_point, rot):
139 | rot_rad = np.pi * rot / 180
140 | sn, cs = np.sin(rot_rad), np.cos(rot_rad)
141 | p1 = src_point[0] * cs - src_point[1] * sn
142 | p2 = src_point[0] * sn + src_point[1] * cs
143 | return p1, p2
144 |
145 |
146 | def get_3rd_point(a, b):
147 | direct = a - b
148 | return b + np.array([-direct[1], direct[0]], dtype=np.float32)
149 |
150 |
151 | def get_affine_transform(center, scale, patch_size, rot=0, inv=False):
152 | shift = np.array([0, 0], dtype=np.float32)
153 | scale_tmp = scale * 200
154 | src_w = scale_tmp[0]
155 | dst_w = patch_size[0]
156 | dst_h = patch_size[1]
157 |
158 | src_dir = get_dir([0, src_w * -0.5], rot)
159 | dst_dir = np.array([0, dst_w * -0.5], dtype=np.float32)
160 | src = np.zeros((3, 2), dtype=np.float32)
161 | dst = np.zeros((3, 2), dtype=np.float32)
162 |
163 | src[0, :] = center + scale_tmp * shift
164 | src[1, :] = center + src_dir + scale_tmp * shift
165 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
166 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
167 |
168 | src[2:, :] = get_3rd_point(src[0, :], src[1, :])
169 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
170 |
171 | return cv2.getAffineTransform(dst, src) if inv else cv2.getAffineTransform(src, dst)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | numpy
3 | tqdm
4 |
--------------------------------------------------------------------------------