├── overview.png
├── README.md
└── predictor.py

/overview.png:
--------------------------------------------------------------------------------
(binary image: framework overview figure; raw file at
https://raw.githubusercontent.com/shajiayu1/DiffusionPose/HEAD/overview.png)
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# DiffusionPose

## Stay tuned!

We propose DiffusionPose, a new framework that formulates multi-person 2D pose
estimation as a denoising diffusion process from noisy keypoints to human body
keypoints.

The code and paper will be released soon.

![Overview of DiffusionPose](https://github.com/shajiayu1/DiffusionPose/blob/main/overview.png)
--------------------------------------------------------------------------------
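The README's one-sentence description is the only public detail of the method so far. Purely to illustrate what "a denoising diffusion process from noisy keypoints to human body keypoints" could look like in code, here is a toy sketch; every name, the noise schedule, and the DDIM-style update are assumptions for illustration, not the released method:

```python
import torch

def q_sample(keypoints, t, alphas_cumprod):
    """Forward process: corrupt ground-truth keypoints (N, K, 2) at step t."""
    noise = torch.randn_like(keypoints)
    x_t = alphas_cumprod[t].sqrt() * keypoints + (1 - alphas_cumprod[t]).sqrt() * noise
    return x_t, noise

@torch.no_grad()
def denoise(model, image_feats, alphas_cumprod, shape, steps):
    """Reverse process: start from Gaussian noise and iteratively refine
    keypoint coordinates conditioned on image features (DDIM-style, eta=0).
    `model` is a hypothetical network that predicts clean keypoints."""
    x = torch.randn(shape)  # pure-noise keypoints
    for t in reversed(range(steps)):
        x0 = model(x, t, image_feats)  # predicted clean keypoints
        eps = (x - alphas_cumprod[t].sqrt() * x0) / (1 - alphas_cumprod[t]).sqrt()
        a_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
        x = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps
    return x0
```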
/predictor.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
import atexit
import bisect
import multiprocessing as mp
from collections import deque

import cv2
import torch

from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode, Visualizer


class VisualizationDemo(object):
    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode): the config used to build the predictor.
            instance_mode (ColorMode): color mode used for drawing instances.
            parallel (bool): whether to run the model in different processes from visualization.
                Useful since the visualization logic can be slow.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            num_gpu = torch.cuda.device_count()
            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
        else:
            self.predictor = DefaultPredictor(cfg)

        # Workaround: keep the score threshold around so that run_on_image
        # can filter low-confidence instances itself.
        self.threshold = cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST

    def run_on_image(self, image):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.
        """
        vis_output = None
        predictions = self.predictor(image)
        # Filter out low-scoring instances.
        instances = predictions["instances"]
        new_instances = instances[instances.scores > self.threshold]
        predictions = {"instances": new_instances}
        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
        if "panoptic_seg" in predictions:
            panoptic_seg, segments_info = predictions["panoptic_seg"]
            vis_output = visualizer.draw_panoptic_seg_predictions(
                panoptic_seg.to(self.cpu_device), segments_info
            )
        else:
            if "sem_seg" in predictions:
                vis_output = visualizer.draw_sem_seg(
                    predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )
            if "instances" in predictions:
                instances = predictions["instances"].to(self.cpu_device)
                vis_output = visualizer.draw_instance_predictions(predictions=instances)

        return predictions, vis_output

    def _frame_from_video(self, video):
        while video.isOpened():
            success, frame = video.read()
            if success:
                yield frame
            else:
                break

    def run_on_video(self, video):
        """
        Visualizes predictions on frames of the input video.

        Args:
            video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
                either a webcam or a video file.

        Yields:
            ndarray: BGR visualizations of each video frame.
        """
        video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)

        def process_predictions(frame, predictions):
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if "panoptic_seg" in predictions:
                panoptic_seg, segments_info = predictions["panoptic_seg"]
                vis_frame = video_visualizer.draw_panoptic_seg_predictions(
                    frame, panoptic_seg.to(self.cpu_device), segments_info
                )
            elif "instances" in predictions:
                predictions = predictions["instances"].to(self.cpu_device)
                vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
            elif "sem_seg" in predictions:
                vis_frame = video_visualizer.draw_sem_seg(
                    frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
                )

            # Convert Matplotlib RGB format back to OpenCV BGR format.
            vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
            return vis_frame

        frame_gen = self._frame_from_video(video)
        if self.parallel:
            buffer_size = self.predictor.default_buffer_size

            frame_data = deque()

            for cnt, frame in enumerate(frame_gen):
                frame_data.append(frame)
                self.predictor.put(frame)

                # Only start consuming results once the pipeline is full,
                # so that all worker processes stay busy.
                if cnt >= buffer_size:
                    frame = frame_data.popleft()
                    predictions = self.predictor.get()
                    yield process_predictions(frame, predictions)

            # Drain the remaining in-flight frames.
            while len(frame_data):
                frame = frame_data.popleft()
                predictions = self.predictor.get()
                yield process_predictions(frame, predictions)
        else:
            for frame in frame_gen:
                yield process_predictions(frame, self.predictor(frame))

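# --- Illustrative addition (not part of the original file) -------------------
# The parallel branch of run_on_video above pipelines frames: raw frames wait
# in a deque while their predictions are in flight, and results are consumed
# only once `default_buffer_size` tasks are queued. The same pattern stripped
# to its essentials; the `predictor` argument is assumed to expose put()/get()
# like AsyncPredictor below:
def _buffered_map(predictor, items, buffer_size):
    pending = deque()
    for cnt, item in enumerate(items):
        pending.append(item)
        predictor.put(item)
        if cnt >= buffer_size:
            yield pending.popleft(), predictor.get()
    # Drain whatever is still in flight.
    while pending:
        yield pending.popleft(), predictor.get()
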
143 | """ 144 | 145 | class _StopToken: 146 | pass 147 | 148 | class _PredictWorker(mp.Process): 149 | def __init__(self, cfg, task_queue, result_queue): 150 | self.cfg = cfg 151 | self.task_queue = task_queue 152 | self.result_queue = result_queue 153 | super().__init__() 154 | 155 | def run(self): 156 | predictor = DefaultPredictor(self.cfg) 157 | 158 | while True: 159 | task = self.task_queue.get() 160 | if isinstance(task, AsyncPredictor._StopToken): 161 | break 162 | idx, data = task 163 | result = predictor(data) 164 | self.result_queue.put((idx, result)) 165 | 166 | def __init__(self, cfg, num_gpus: int = 1): 167 | """ 168 | Args: 169 | cfg (CfgNode): 170 | num_gpus (int): if 0, will run on CPU 171 | """ 172 | num_workers = max(num_gpus, 1) 173 | self.task_queue = mp.Queue(maxsize=num_workers * 3) 174 | self.result_queue = mp.Queue(maxsize=num_workers * 3) 175 | self.procs = [] 176 | for gpuid in range(max(num_gpus, 1)): 177 | cfg = cfg.clone() 178 | cfg.defrost() 179 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" 180 | self.procs.append( 181 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) 182 | ) 183 | 184 | self.put_idx = 0 185 | self.get_idx = 0 186 | self.result_rank = [] 187 | self.result_data = [] 188 | 189 | for p in self.procs: 190 | p.start() 191 | atexit.register(self.shutdown) 192 | 193 | def put(self, image): 194 | self.put_idx += 1 195 | self.task_queue.put((self.put_idx, image)) 196 | 197 | def get(self): 198 | self.get_idx += 1 # the index needed for this request 199 | if len(self.result_rank) and self.result_rank[0] == self.get_idx: 200 | res = self.result_data[0] 201 | del self.result_data[0], self.result_rank[0] 202 | return res 203 | 204 | while True: 205 | # make sure the results are returned in the correct order 206 | idx, res = self.result_queue.get() 207 | if idx == self.get_idx: 208 | return res 209 | insert = bisect.bisect(self.result_rank, idx) 210 | self.result_rank.insert(insert, idx) 211 | self.result_data.insert(insert, res) 212 | 213 | def __len__(self): 214 | return self.put_idx - self.get_idx 215 | 216 | def __call__(self, image): 217 | self.put(image) 218 | return self.get() 219 | 220 | def shutdown(self): 221 | for _ in self.procs: 222 | self.task_queue.put(AsyncPredictor._StopToken()) 223 | 224 | @property 225 | def default_buffer_size(self): 226 | return len(self.procs) * 5 227 | --------------------------------------------------------------------------------