├── README.md ├── augment.py ├── datasets.py ├── environment.yaml ├── infer3d.py ├── loss.py ├── models.py ├── train2d.py ├── train3d.py ├── trainPose.py ├── trainer.py ├── utils.py ├── v1 ├── A.ipynb ├── calibration.py ├── distribution.py ├── triangulation.py └── utils.py └── visualization.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Probabilistic Triangulation V2 2 | 3 | Code for the ICCV 2023 paper: [Probabilistic Triangulation for Uncalibrated Multi-View 3D Human Pose Estimation](https://arxiv.org/abs/2309.04756) 4 | 5 | Abstract: 3D human pose estimation has been a long-standing challenge in computer vision and graphics, where multi-view methods have significantly progressed but are limited by the tedious calibration processes. Existing multi-view methods are restricted to fixed camera pose and therefore lack generalization ability. This paper presents a novel Probabilistic Triangulation module that can be embedded in a calibrated 3D human pose estimation method, generalizing it to uncalibration scenes. The key idea is to use a probability distribution to model the camera pose and iteratively update the distribution from 2D features instead of using camera pose. Specifically, we maintain a camera pose distribution and then iteratively update this distribution by computing the posterior probability of the camera pose through Monte Carlo sampling. This way, the gradients can be directly back-propagated from the 3D pose estimation to the 2D heatmap, enabling end-to-end training. Extensive experiments on Human3.6M and CMU Panoptic demonstrate that our method outperforms other uncalibration methods and achieves comparable results with state-of-the-art calibration methods. Thus, our method achieves a trade-off between estimation accuracy and generalizability. 6 | 7 | ## Version update 8 | 1. Accelerated the model by replacing the backbone with MobileOne; 9 | 2. 
Changed the sampling logic to speed up multi-view fusion; 10 | 3. The model can now run inference in real time on an iPhone. 11 | 12 | ## Getting started 13 | 14 | ### 1. Dataset 15 | 16 | Download and preprocess the dataset by following the instructions in [h36m-fetch](https://github.com/anibali/h36m-fetch) and [learnable triangulation](https://github.com/karfly/learnable-triangulation-pytorch). 17 | 18 | The directory structure after completing all processing: 19 | 20 | ``` 21 | human3.6m 22 | ├── extra 23 | │ ├── bboxes-Human36M-GT.npy 24 | │ ├── human36m-multiview-labels-GTbboxes.npy 25 | │ └── una-dinosauria-data 26 | └── processed 27 | ├── S1 28 | ├── S11 29 | ├── S5 30 | ├── S6 31 | ├── S7 32 | ├── S8 33 | └── S9 34 | ``` 35 | 36 | ### 2. Quick Start 37 | 38 | Use conda to create the environment (a newer version of PyTorch may also be used): 39 | 40 | ``` 41 | conda env create -f environment.yaml 42 | ``` 43 | 44 | Run inference with the [pretrained model](https://drive.google.com/file/d/11baGjN9-iC6AzORrPSLJ_Oyk6kCLiEH6/view?usp=drive_link): 45 | 46 | ``` 47 | python infer3d.py 48 | ``` 49 | 50 | This should produce the following results, where `x3d/l2` is the MPJPE in mm: 51 | ``` 52 | loss 4.177016958594322 53 | loss/hm 2.6981436171952415 54 | loss/x3d 5.655890337684575 55 | x2d/l1 16.779179317109726 56 | x2d/l2 13.22389983604936 57 | x3d/l1 38.431529241449695 58 | x3d/l2 26.103624186095068 59 | ``` 60 | 61 | Train the 3D estimator, which by default uses the pretrained 2D backbone: 62 | 63 | ``` 64 | python train3d.py 65 | ``` 66 | 67 | Train the 2D backbone: 68 | 69 | ``` 70 | python train2d.py 71 | ``` 72 | 73 | ### 3. Some training suggestions 74 | 75 | 1. During training, we found that the accuracy of the 2D pose estimate strongly affects the results: the MPJPE reaches 26 mm with 384x384 inputs, but only 34 mm with 256x256 inputs. 76 | 77 | 2. 
Human3.6M has a single, fixed background, so a model with few parameters tends to overfit to it (probably because the model uses color as a keypoint feature). We added color and brightness augmentation during training to counter this, but it does not fully solve the problem for in-the-wild scenes. Pre-training the model on a broader dataset would address this. 78 | 79 | 3. Human3.6M contains a lot of duplicated data, so sampling it at intervals (see `data_skip_train` and `data_skip_test` in the config) lets you validate training results quickly. 80 | 81 | 4. The voxel-based multi-view fusion approach provides a rich physical prior, but it is a bottleneck for acceleration. In the new version of the code, we use orientation + sampled features as inputs, which greatly speeds up fusion. 82 | 83 | 5. During training, the fusion module is first pre-trained on generated data and then fine-tuned in subsequent training, which achieves better generalization. 84 | 85 | 86 | 87 | ## Citation 88 | 89 | If you find this project useful for your research, please consider citing: 90 | 91 | ``` 92 | @article{hu2023pose, 93 | title={Probabilistic Triangulation for Uncalibrated Multi-View 3D Human Pose Estimation}, 94 | author={Boyuan Jiang and Lei Hu and Shihong Xia}, 95 | journal={IEEE International Conference on Computer Vision}, 96 | year={2023}, 97 | publisher={IEEE} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /augment.py: -------------------------------------------------------------------------------- 1 | from utils import eulid_to_homo 2 | import numpy as np 3 | import cv2 4 | 5 | class Compose(): 6 | def __init__(self, transforms): 7 | self.transforms = transforms 8 | def __call__(self, sample): 9 | for t in self.transforms: 10 | sample = t(sample) 11 | return sample 12 | 13 | class Crop(): 14 | def __init__(self, scaleRange,moveRange): 15 | self.scaleRange = scaleRange 16 | self.moveRange = moveRange 17 | 18 | def __call__(self,sample): 19 | H, W, C = sample['image'].shape 20 | scale = 
np.random.uniform(self.scaleRange[0],self.scaleRange[1]) 21 | moveX = np.random.uniform(self.moveRange[0], self.moveRange[1]) 22 | moveY = np.random.uniform(self.moveRange[0], self.moveRange[1]) 23 | B = sample['box'] 24 | width = B[2] - B[0] 25 | height = B[3] - B[1] 26 | cx = int((B[0] + B[2])/2 + moveX * width) 27 | cy = int((B[1] + B[3])/2 + moveY * height) 28 | 29 | side = int(max(width, height) * scale) 30 | A = np.asarray([ 31 | cx - side//2, 32 | cy - side//2, 33 | cx + side - side//2, 34 | cy + side - side//2 35 | ]) 36 | Aclip = np.clip(A, [0, 0, 0, 0], [W, H, W, H]) 37 | img = np.zeros((side, side, C)) 38 | img[(Aclip[1]-A[1]):(Aclip[3]-A[1]), (Aclip[0]-A[0]):(Aclip[2]-A[0]) 39 | ] = sample['image'][Aclip[1]: Aclip[3], Aclip[0]: Aclip[2]] 40 | 41 | sample['image'] = img 42 | sample['camera'].update_after_crop(A) 43 | del sample['box'] 44 | return sample 45 | 46 | class Resize(): 47 | def __init__(self, image_size, interpolation=cv2.INTER_CUBIC): 48 | self.image_size = (image_size, image_size) if isinstance( 49 | image_size, (int, float)) else image_size 50 | self.interpolation = interpolation 51 | 52 | def __call__(self, sample): 53 | sample['camera'].update_after_resize(sample['image'].shape[:2], self.image_size) 54 | sample['image'] = cv2.resize( 55 | sample['image'], self.image_size, interpolation=self.interpolation) 56 | return sample 57 | 58 | class NormSkeleton(): 59 | def __init__(self, root_id=6): 60 | self.root_id = root_id 61 | 62 | def __call__(self, sample): 63 | x3d = eulid_to_homo(sample['x3d']) @ sample['camera'].extrinsics().T 64 | 65 | x2d = x3d @ sample['camera'].K.T 66 | x2d[:,:2] /= x2d[:,2:] 67 | 68 | sample['x3d'] = (x3d-x3d[self.root_id]).astype(np.float32) 69 | sample['x2d'] = x2d[:,:2].astype(np.float32) 70 | sample['K'] = sample['camera'].K 71 | sample['R'] = sample['camera'].R 72 | sample['t'] = sample['camera'].t 73 | del sample['camera'] 74 | return sample 75 | 76 | class NormImage(): 77 | def __init__(self): 78 | pass 79 | 80 
| def __call__(self,sample): 81 | sample['image'] = np.clip(sample['image']/255., 0.,1.).transpose(2,0,1).astype(np.float32) 82 | return sample 83 | 84 | 85 | class RandomBrightness(object): 86 | def __init__(self, delta=32): 87 | assert delta >= 0.0 88 | assert delta <= 255.0 89 | self.delta = delta 90 | def __call__(self, sample): 91 | if np.random.randint(2): 92 | delta = np.random.uniform(-self.delta, self.delta) 93 | sample['image'] += delta 94 | return sample 95 | 96 | class RandomContrast(object): 97 | def __init__(self, lower=0.5, upper=1.5): 98 | self.lower = lower 99 | self.upper = upper 100 | assert self.upper >= self.lower, "contrast upper must be >= lower." 101 | assert self.lower >= 0, "contrast lower must be non-negative." 102 | def __call__(self, sample): 103 | if np.random.randint(2): 104 | alpha = np.random.uniform(self.lower, self.upper) 105 | sample['image'] *= alpha 106 | return sample 107 | 108 | 109 | class RandomSaturation(object): 110 | def __init__(self, lower=0.5, upper=1.5): 111 | self.lower = lower 112 | self.upper = upper 113 | assert self.upper >= self.lower, "saturation upper must be >= lower." 114 | assert self.lower >= 0, "saturation lower must be non-negative." 
115 | def __call__(self, sample): 116 | if np.random.randint(2): 117 | sample['image'][:, :, 1] *= np.random.uniform(self.lower, self.upper) 118 | return sample 119 | 120 | class RandomHue(object): 121 | def __init__(self, delta=18.0): 122 | assert delta >= 0.0 and delta <= 360.0 123 | self.delta = delta 124 | def __call__(self, sample): 125 | if np.random.randint(2): 126 | sample['image'][:, :, 0] += np.random.uniform(-self.delta, self.delta) 127 | sample['image'][:, :, 0][sample['image'][:, :, 0] > 360.0] -= 360.0 128 | sample['image'][:, :, 0][sample['image'][:, :, 0] < 0.0] += 360.0 129 | return sample 130 | 131 | class SwapChannels(object): 132 | def __init__(self, swaps): 133 | self.swaps = swaps 134 | def __call__(self, image): 135 | image = image[:, :, self.swaps] 136 | return image 137 | 138 | class RandomLightingNoise(object): 139 | def __init__(self): 140 | self.perms = ((0, 1, 2), (0, 2, 1), 141 | (1, 0, 2), (1, 2, 0), 142 | (2, 0, 1), (2, 1, 0)) 143 | def __call__(self, sample): 144 | if np.random.randint(2): 145 | swap = self.perms[np.random.randint(len(self.perms))] 146 | sample['image'] = sample['image'][:,:,swap] 147 | return sample 148 | 149 | class ConvertColor(object): 150 | def __init__(self, current='BGR', transform='HSV'): 151 | self.transform = transform 152 | self.current = current 153 | def __call__(self, sample): 154 | if self.current == 'BGR' and self.transform == 'HSV': 155 | sample['image'] = cv2.cvtColor(sample['image'] , cv2.COLOR_BGR2HSV) 156 | elif self.current == 'HSV' and self.transform == 'BGR': 157 | sample['image'] = cv2.cvtColor(sample['image'] , cv2.COLOR_HSV2BGR) 158 | else: 159 | raise NotImplementedError 160 | return sample 161 | 162 | class PhotometricDistort(): 163 | def __init__(self): 164 | self.distort = Compose([ 165 | RandomBrightness(), 166 | RandomContrast(), 167 | ConvertColor(transform='HSV'), 168 | RandomSaturation(), 169 | RandomHue(), 170 | ConvertColor(current='HSV', transform='BGR'), 171 | 
RandomLightingNoise(), 172 | ]) 173 | def __call__(self, sample): 174 | sample['image'] = np.clip(sample['image'],0,255).astype(np.float32) 175 | return self.distort(sample) 176 | 177 | 178 | class GenHeatmap(): 179 | def __init__(self, num_keypoints, image_size = 256, heatmap_size = 64): 180 | self.num_keypoints = num_keypoints 181 | self.image_size = image_size 182 | self.heatmap_size = heatmap_size 183 | 184 | 185 | def __call__(self,sample): 186 | hm = np.zeros((self.num_keypoints, self.heatmap_size, self.heatmap_size), dtype=np.float32) 187 | reg = np.zeros((self.num_keypoints, 2), dtype=np.float32) 188 | ind = np.zeros((self.num_keypoints), dtype=np.int64) 189 | mask = np.zeros((self.num_keypoints), dtype=np.uint8) 190 | 191 | for i,x2d in enumerate(sample['x2d']): 192 | ct = x2d * self.heatmap_size/self.image_size 193 | ct_int = (ct + 0.5).astype(np.int32) 194 | if ct_int[0] < self.heatmap_size and ct_int[1] < self.heatmap_size and ct_int[0] >= 0 and ct_int[1] >= 0: 195 | radius = 2 196 | self.draw_gaussian(hm[i], ct, radius) 197 | ind[i] = ct_int[1] * self.heatmap_size + ct_int[0] 198 | reg[i] = ct - ct_int 199 | mask[i] = 1 200 | 201 | sample['hm'] = hm 202 | sample['reg'] = reg 203 | sample['mask'] = mask 204 | sample['ind'] = ind 205 | return sample 206 | 207 | def gaussian2D(self, shape, sigma=1): 208 | m, n = [(ss - 1.) / 2. 
for ss in shape] 209 | y, x = np.ogrid[-m:m+1, -n:n+1] 210 | 211 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 212 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 213 | return h 214 | 215 | def draw_gaussian(self, heatmap, center, radius): 216 | diameter = 2 * radius + 1 217 | gaussian = self.gaussian2D((diameter, diameter), sigma=diameter / 6) 218 | 219 | x, y = int(center[0]), int(center[1]) 220 | 221 | height, width = heatmap.shape[0:2] 222 | 223 | left, right = min(x, radius), min(width - x, radius + 1) 224 | top, bottom = min(y, radius), min(height - y, radius + 1) 225 | 226 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 227 | masked_gaussian = gaussian[radius - top:radius + 228 | bottom, radius - left:radius + right] 229 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug 230 | np.maximum(masked_heatmap, masked_gaussian, out=masked_heatmap) 231 | return heatmap 232 | 233 | 234 | 235 | 236 | 237 | 238 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import cv2 4 | from utils import Camera, eulid_to_homo 5 | import numpy as np 6 | from augment import * 7 | import os 8 | from collections import defaultdict 9 | import random 10 | 11 | class Human36M(Dataset): 12 | def __init__(self, cfg, is_train): 13 | super().__init__() 14 | self.cfg = cfg 15 | self.is_train = is_train 16 | self.labels = np.load(cfg['labels_path'], allow_pickle=True).item() 17 | # n_cameras = len(self.labels['camera_names']) 18 | 19 | train_subjects = ['S1', 'S5', 'S6', 'S7', 'S8'] 20 | test_subjects = ['S9','S11'] 21 | train_subjects = list(self.labels['subject_names'].index(x) for x in train_subjects) 22 | test_subjects = list(self.labels['subject_names'].index(x) for x in test_subjects) 23 | 24 | if is_train: 25 | mask = 
np.isin(self.labels['table']['subject_idx'], train_subjects, assume_unique=True) 26 | else: 27 | mask = np.isin(self.labels['table']['subject_idx'], test_subjects, assume_unique=True) 28 | 29 | 30 | self.labels['table'] = self.labels['table'][mask] 31 | 32 | self.augment = Compose([ 33 | Crop(cfg['scaleRange'],cfg['moveRange']), 34 | Resize(cfg['image_size']), 35 | PhotometricDistort(), 36 | NormSkeleton(), 37 | NormImage(), 38 | GenHeatmap(cfg['num_keypoints'],cfg['image_size'],cfg['heatmap_size']), 39 | ]) 40 | 41 | def __len__(self): 42 | if self.is_train: 43 | return (len(self.labels['table'])*len(self.labels['camera_names']))//self.cfg['data_skip_train'] 44 | else: 45 | return len(self.labels['table'])*len(self.labels['camera_names'])//self.cfg['data_skip_test'] 46 | 47 | def __getitem__(self, index): 48 | if self.is_train: 49 | index = index * self.cfg['data_skip_train'] + np.random.randint(self.cfg['data_skip_train']) 50 | else: 51 | index = index * self.cfg['data_skip_test'] + np.random.randint(self.cfg['data_skip_test']) 52 | 53 | camera_idx = index % len(self.labels['camera_names']) 54 | idx = index // len(self.labels['camera_names']) 55 | shot = self.labels['table'][idx] 56 | subject = self.labels['subject_names'][shot['subject_idx']] 57 | action = self.labels['action_names'][shot['action_idx']] 58 | frame_idx = shot['frame_idx'] 59 | camera_name = self.labels['camera_names'][camera_idx] 60 | 61 | 62 | 63 | image = cv2.imread(os.path.join( 64 | self.cfg['root_path'], subject, action, 'imageSequence' + '-undistorted', 65 | camera_name, 'img_%06d.jpg' % (frame_idx+1))) 66 | 67 | box = shot['bbox_by_camera_tlbr'][camera_idx][[1,0,3,2]] # TLBR to LTRB 68 | 69 | # x3d = np.pad( 70 | # shot['keypoints'][:self.cfg['num_keypts']], 71 | # ((0,0), (0,1)), 'constant', constant_values=1.0) 72 | x3d = np.asarray(shot['keypoints'][:self.cfg['num_keypoints']]) 73 | 74 | shot_camera = self.labels['cameras'][shot['subject_idx'], camera_idx] 75 | camera = 
Camera(shot_camera['R'],shot_camera['t'],shot_camera['K']) 76 | 77 | sample = self.augment({'image':image, 'box': box, 'x3d': x3d, 'camera': camera}) 78 | 79 | if self.cfg['use_tag']: 80 | sample['tag'] = { 81 | 'subject': shot['subject_idx'], 82 | 'action': shot['action_idx'], 83 | 'camera': camera_idx, 84 | 'frame': frame_idx, 85 | } 86 | 87 | return sample 88 | 89 | class MultiPose(Dataset): 90 | def __init__(self, cfg, is_train): 91 | super().__init__() 92 | self.cfg = cfg 93 | self.is_train = is_train 94 | self.labels = np.load(cfg['labels_path'], allow_pickle=True).item() 95 | # n_cameras = len(self.labels['camera_names']) 96 | 97 | train_subjects = ['S1', 'S5', 'S6', 'S7', 'S8'] 98 | test_subjects = ['S9', 'S11'] 99 | train_subjects = list(self.labels['subject_names'].index(x) for x in train_subjects) 100 | test_subjects = list(self.labels['subject_names'].index(x) for x in test_subjects) 101 | 102 | if is_train: 103 | mask = np.isin(self.labels['table']['subject_idx'], train_subjects, assume_unique=True) 104 | else: 105 | mask = np.isin(self.labels['table']['subject_idx'], test_subjects, assume_unique=True) 106 | 107 | self.labels['table'] = self.labels['table'][mask] 108 | 109 | self.augment = Compose([ 110 | NormSkeleton(), 111 | ]) 112 | 113 | def __len__(self): 114 | return len(self.labels['table']) 115 | 116 | def __getitem__(self, index): 117 | shot = self.labels['table'][index] 118 | subject = self.labels['subject_names'][shot['subject_idx']] 119 | action = self.labels['action_names'][shot['action_idx']] 120 | frame_idx = shot['frame_idx'] 121 | 122 | view_list = [] 123 | for camera_idx, camera_name in enumerate(self.labels['camera_names']): 124 | x3d = np.asarray(shot['keypoints'][:self.cfg['num_keypoints']]) 125 | shot_camera = self.labels['cameras'][shot['subject_idx'], camera_idx] 126 | camera = Camera(shot_camera['R'],shot_camera['t'],shot_camera['K']) 127 | view_list.append(self.augment({'x3d': x3d, 'camera': camera})) 128 | 129 | 
random.shuffle(view_list) 130 | sample = defaultdict(list) 131 | for view in view_list: 132 | for key in view: 133 | sample[key].append(view[key]) 134 | 135 | for key in sample: 136 | sample[key] = np.asarray(sample[key]) 137 | 138 | V,J,_ = sample['x2d'].shape 139 | # (V,1,3,3) @ (V,J,3,1) -> (V,J,3) -> (V,J,4) 140 | if self.is_train: 141 | rands = np.random.randn(V,J,2)*2 142 | x2d = sample['x2d'] + rands 143 | conf = np.exp(-np.sqrt((rands**2).sum(-1)))[...,None] 144 | else: 145 | x2d = sample['x2d'] 146 | conf = np.ones((V,J,1)) 147 | # (V,1,3,3) @ (V,J,3,1) -> (V,J,3,1) 148 | xdir = (np.linalg.inv(sample['K'])[:,None] @ eulid_to_homo(x2d)[...,None]).squeeze(-1) 149 | xdir = np.concatenate([xdir, conf], axis=-1).astype(np.float32) 150 | sample['xdir'] = xdir 151 | 152 | return sample 153 | 154 | 155 | 156 | 157 | class MultiHuman36M(Dataset): 158 | def __init__(self, cfg, is_train): 159 | super().__init__() 160 | self.cfg = cfg 161 | self.is_train = is_train 162 | self.labels = np.load(cfg['labels_path'], allow_pickle=True).item() 163 | # n_cameras = len(self.labels['camera_names']) 164 | 165 | train_subjects = ['S1', 'S5', 'S6', 'S7', 'S8'] 166 | test_subjects = ['S9', 'S11'] 167 | train_subjects = list(self.labels['subject_names'].index(x) for x in train_subjects) 168 | test_subjects = list(self.labels['subject_names'].index(x) for x in test_subjects) 169 | 170 | if is_train: 171 | mask = np.isin(self.labels['table']['subject_idx'], train_subjects, assume_unique=True) 172 | else: 173 | mask = np.isin(self.labels['table']['subject_idx'], test_subjects, assume_unique=True) 174 | 175 | 176 | self.labels['table'] = self.labels['table'][mask] 177 | 178 | self.augment = Compose([ 179 | Crop(cfg['scaleRange'],cfg['moveRange']), 180 | Resize(cfg['image_size']), 181 | PhotometricDistort(), 182 | NormSkeleton(), 183 | NormImage(), 184 | GenHeatmap(cfg['num_keypoints'],cfg['image_size'],cfg['heatmap_size']), 185 | ]) 186 | 187 | def __len__(self): 188 | if 
self.is_train: 189 | return len(self.labels['table']) // self.cfg['data_skip_train'] 190 | else: 191 | return len(self.labels['table']) // self.cfg['data_skip_test'] 192 | 193 | def __getitem__(self, index): 194 | if self.is_train: 195 | index = index * self.cfg['data_skip_train'] + np.random.randint(self.cfg['data_skip_train']) 196 | else: 197 | index = index * self.cfg['data_skip_test'] + np.random.randint(self.cfg['data_skip_test']) 198 | 199 | shot = self.labels['table'][index] 200 | subject = self.labels['subject_names'][shot['subject_idx']] 201 | action = self.labels['action_names'][shot['action_idx']] 202 | frame_idx = shot['frame_idx'] 203 | 204 | view_list = [] 205 | for camera_idx, camera_name in enumerate(self.labels['camera_names']): 206 | 207 | image = cv2.imread(os.path.join( 208 | self.cfg['root_path'], subject, action, 'imageSequence' + '-undistorted', 209 | camera_name, 'img_%06d.jpg' % (frame_idx+1))) 210 | 211 | box = shot['bbox_by_camera_tlbr'][camera_idx][[1,0,3,2]] # TLBR to LTRB 212 | 213 | # x3d = np.pad( 214 | # shot['keypoints'][:self.cfg['num_keypts']], 215 | # ((0,0), (0,1)), 'constant', constant_values=1.0) 216 | x3d = np.asarray(shot['keypoints'][:self.cfg['num_keypoints']]) 217 | 218 | shot_camera = self.labels['cameras'][shot['subject_idx'], camera_idx] 219 | camera = Camera(shot_camera['R'],shot_camera['t'],shot_camera['K']) 220 | 221 | view_list.append(self.augment({'image':image, 'box': box, 'x3d': x3d, 'camera': camera})) 222 | 223 | # random.shuffle(view_list) 224 | sample = defaultdict(list) 225 | for view in view_list: 226 | for key in view: 227 | sample[key].append(view[key]) 228 | 229 | for key in sample: 230 | sample[key] = np.asarray(sample[key]) 231 | 232 | return sample -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: algo 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 
6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - asttokens=2.4.1=pyhd8ed1ab_0 11 | - backcall=0.2.0=pyh9f0ad1d_0 12 | - blas=1.0=mkl 13 | - brotli-python=1.0.9=py38h6a678d5_7 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2024.2.2=hbcca054_0 16 | - certifi=2024.2.2=pyhd8ed1ab_0 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - comm=0.2.1=pyhd8ed1ab_0 19 | - cuda-cudart=11.8.89=0 20 | - cuda-cupti=11.8.87=0 21 | - cuda-libraries=11.8.0=0 22 | - cuda-nvrtc=11.8.89=0 23 | - cuda-nvtx=11.8.86=0 24 | - cuda-runtime=11.8.0=0 25 | - debugpy=1.6.7=py38h6a678d5_0 26 | - decorator=5.1.1=pyhd8ed1ab_0 27 | - executing=2.0.1=pyhd8ed1ab_0 28 | - ffmpeg=4.3=hf484d3e_0 29 | - filelock=3.13.1=py38h06a4308_0 30 | - freetype=2.12.1=h4a9f257_0 31 | - gmp=6.2.1=h295c915_3 32 | - gmpy2=2.1.2=py38heeb90bb_0 33 | - gnutls=3.6.15=he1e5248_0 34 | - idna=3.4=py38h06a4308_0 35 | - importlib-metadata=7.0.1=pyha770c72_0 36 | - importlib_metadata=7.0.1=hd8ed1ab_0 37 | - intel-openmp=2023.1.0=hdb19cb5_46306 38 | - ipykernel=6.29.2=pyhd33586a_0 39 | - ipython=8.12.2=pyh41d4057_0 40 | - jedi=0.19.1=pyhd8ed1ab_0 41 | - jinja2=3.1.3=py38h06a4308_0 42 | - jpeg=9e=h5eee18b_1 43 | - jupyter_client=8.6.0=pyhd8ed1ab_0 44 | - jupyter_core=5.7.1=py38h578d9bd_0 45 | - lame=3.100=h7b6447c_0 46 | - lcms2=2.12=h3be6417_0 47 | - ld_impl_linux-64=2.38=h1181459_1 48 | - lerc=3.0=h295c915_0 49 | - libcublas=11.11.3.6=0 50 | - libcufft=10.9.0.58=0 51 | - libcufile=1.8.1.2=0 52 | - libcurand=10.3.4.107=0 53 | - libcusolver=11.4.1.48=0 54 | - libcusparse=11.7.5.86=0 55 | - libdeflate=1.17=h5eee18b_1 56 | - libffi=3.4.4=h6a678d5_0 57 | - libgcc-ng=13.2.0=h807b86a_5 58 | - libgomp=13.2.0=h807b86a_5 59 | - libiconv=1.16=h7f8727e_2 60 | - libidn2=2.3.4=h5eee18b_0 61 | - libjpeg-turbo=2.0.0=h9bf148f_0 62 | - libnpp=11.8.0.86=0 63 | - libnvjpeg=11.9.0.86=0 64 | - libpng=1.6.39=h5eee18b_0 65 | - libsodium=1.0.18=h36c2ea0_1 66 | - libstdcxx-ng=11.2.0=h1234567_1 
67 | - libtasn1=4.19.0=h5eee18b_0 68 | - libtiff=4.5.1=h6a678d5_0 69 | - libunistring=0.9.10=h27cfd23_0 70 | - libwebp-base=1.3.2=h5eee18b_0 71 | - llvm-openmp=14.0.6=h9e868ea_0 72 | - lz4-c=1.9.4=h6a678d5_0 73 | - markupsafe=2.1.3=py38h5eee18b_0 74 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 75 | - mkl=2023.1.0=h213fc3f_46344 76 | - mkl-service=2.4.0=py38h5eee18b_1 77 | - mkl_fft=1.3.8=py38h5eee18b_0 78 | - mkl_random=1.2.4=py38hdb19cb5_0 79 | - mpc=1.1.0=h10f8cd9_1 80 | - mpfr=4.0.2=hb69a4c5_1 81 | - mpmath=1.3.0=py38h06a4308_0 82 | - ncurses=6.4=h6a678d5_0 83 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 84 | - nettle=3.7.3=hbbd107a_1 85 | - networkx=3.1=py38h06a4308_0 86 | - numpy=1.24.3=py38hf6e8229_1 87 | - numpy-base=1.24.3=py38h060ed82_1 88 | - openh264=2.1.1=h4ff587b_0 89 | - openjpeg=2.4.0=h3ad879b_0 90 | - openssl=3.2.1=hd590300_0 91 | - packaging=23.2=pyhd8ed1ab_0 92 | - parso=0.8.3=pyhd8ed1ab_0 93 | - pexpect=4.9.0=pyhd8ed1ab_0 94 | - pickleshare=0.7.5=py_1003 95 | - pillow=10.2.0=py38h5eee18b_0 96 | - pip=23.3.1=py38h06a4308_0 97 | - platformdirs=4.2.0=pyhd8ed1ab_0 98 | - prompt-toolkit=3.0.42=pyha770c72_0 99 | - prompt_toolkit=3.0.42=hd8ed1ab_0 100 | - psutil=5.9.8=py38h01eb140_0 101 | - ptyprocess=0.7.0=pyhd3deb0d_0 102 | - pure_eval=0.2.2=pyhd8ed1ab_0 103 | - pygments=2.17.2=pyhd8ed1ab_0 104 | - pysocks=1.7.1=py38h06a4308_0 105 | - python=3.8.18=h955ad1f_0 106 | - python-dateutil=2.8.2=pyhd8ed1ab_0 107 | - python_abi=3.8=2_cp38 108 | - pytorch=2.2.0=py3.8_cuda11.8_cudnn8.7.0_0 109 | - pytorch-cuda=11.8=h7e8668a_5 110 | - pytorch-mutex=1.0=cuda 111 | - pyyaml=6.0.1=py38h5eee18b_0 112 | - pyzmq=25.1.2=py38h6a678d5_0 113 | - readline=8.2=h5eee18b_0 114 | - requests=2.31.0=py38h06a4308_1 115 | - setuptools=68.2.2=py38h06a4308_0 116 | - six=1.16.0=pyh6c4a22f_0 117 | - sqlite=3.41.2=h5eee18b_0 118 | - stack_data=0.6.2=pyhd8ed1ab_0 119 | - sympy=1.12=py38h06a4308_0 120 | - tbb=2021.8.0=hdb19cb5_0 121 | - tk=8.6.12=h1ccaba5_0 122 | - torchaudio=2.2.0=py38_cu118 123 | - 
torchtriton=2.2.0=py38 124 | - torchvision=0.17.0=py38_cu118 125 | - tornado=6.3.3=py38h01eb140_1 126 | - traitlets=5.14.1=pyhd8ed1ab_0 127 | - typing_extensions=4.9.0=py38h06a4308_1 128 | - urllib3=2.1.0=py38h06a4308_1 129 | - wcwidth=0.2.13=pyhd8ed1ab_0 130 | - wheel=0.41.2=py38h06a4308_0 131 | - xz=5.4.5=h5eee18b_0 132 | - yaml=0.2.5=h7b6447c_0 133 | - zeromq=4.3.5=h6a678d5_0 134 | - zipp=3.17.0=pyhd8ed1ab_0 135 | - zlib=1.2.13=h5eee18b_0 136 | - zstd=1.5.5=hc292b87_0 137 | - pip: 138 | - absl-py==2.1.0 139 | - cachetools==5.3.2 140 | - contourpy==1.1.1 141 | - cycler==0.12.1 142 | - fonttools==4.49.0 143 | - google-auth==2.28.0 144 | - google-auth-oauthlib==1.0.0 145 | - grpcio==1.60.1 146 | - importlib-resources==6.1.1 147 | - kiwisolver==1.4.5 148 | - markdown==3.5.2 149 | - matplotlib==3.7.5 150 | - oauthlib==3.2.2 151 | - opencv-python==4.9.0.80 152 | - pandas==2.0.3 153 | - protobuf==4.25.3 154 | - pyasn1==0.5.1 155 | - pyasn1-modules==0.3.0 156 | - pyparsing==3.1.1 157 | - pytz==2024.1 158 | - requests-oauthlib==1.3.1 159 | - rsa==4.9 160 | - seaborn==0.13.2 161 | - tensorboard==2.14.0 162 | - tensorboard-data-server==0.7.2 163 | - tensorboardx==2.6.2.2 164 | - torchsummary==1.5.1 165 | - tqdm==4.66.2 166 | - tzdata==2024.1 167 | - werkzeug==3.0.1 168 | prefix: ~/.conda/envs/algo 169 | -------------------------------------------------------------------------------- /infer3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | from utils import * 8 | from trainer import AverageMeter 9 | from datasets import MultiHuman36M 10 | from models import ProbTri 11 | from loss import Net3d 12 | from tqdm import tqdm 13 | 14 | cfg = { 15 | 'root_path': '/data/human36m/processed', 16 | 'labels_path': '/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy', 17 | 'lr':1e-5, 
18 | 'num_epoch':300, 19 | 'batch_size_train': 32, 20 | 'batch_size_test': 8, 21 | 'num_workers':8, 22 | 'num_keypoints': 17, 23 | 'num_views': 4, 24 | 'scaleRange': [1.1,1.2], 25 | 'moveRange': [-0.1,0.1], 26 | 'image_size':384, 27 | 'heatmap_size': 96, 28 | # 'backbone_path': 'checkpoints/backbone.pth', 29 | # 'fusion_path': 'checkpoints/fusion.pth', 30 | 'model_path': 'checkpoints/pretrain.pth', 31 | 'device':'cuda', 32 | 'save_dir': '/logs/pose3d', 33 | 'use_tag':False, 34 | 'data_skip_train':8, 35 | 'data_skip_test':1, 36 | } 37 | 38 | test_db = MultiHuman36M(cfg, is_train=False) 39 | test_loader = DataLoader( 40 | test_db, 41 | batch_size=cfg['batch_size_test'], 42 | shuffle=False, 43 | num_workers = cfg['num_workers'], 44 | pin_memory = True, 45 | drop_last=True, 46 | ) 47 | 48 | # trainer 49 | model = ProbTri(cfg).cuda() 50 | if 'backbone_path' in cfg: 51 | pretrain_dict = torch.load(cfg['backbone_path']) 52 | missing, unexpected = model.backbone.load_state_dict(pretrain_dict,strict=False) 53 | print('load backbone model, missing length', len(missing), 'unexpected', len(unexpected) , '\n') 54 | 55 | if 'fusion_path' in cfg: 56 | pretrain_dict = torch.load(cfg['fusion_path']) 57 | missing, unexpected = model.fusion.load_state_dict(pretrain_dict,strict=False) 58 | print('load fusion model, missing length', len(missing), 'unexpected', len(unexpected) , '\n') 59 | 60 | if 'model_path' in cfg: 61 | pretrain_dict = torch.load(cfg['model_path']) 62 | missing, unexpected = model.load_state_dict(pretrain_dict,strict=False) 63 | print('missing length', len(missing), 'unexpected', len(unexpected) , '\n') 64 | 65 | net = Net3d(cfg, model) 66 | net.eval() 67 | torch.set_grad_enabled(False) 68 | 69 | 70 | avg_loss_stats = {} 71 | for iter_id, batch in tqdm(enumerate(test_loader)): 72 | for key in batch: 73 | if isinstance(batch[key], torch.Tensor): 74 | batch[key] = batch[key].to(device = cfg['device'], non_blocking=True) 75 | 76 | output, loss, loss_stats = net(batch) 
77 | if 'loss' not in avg_loss_stats: 78 | for key in loss_stats: 79 | avg_loss_stats[key] = AverageMeter() 80 | 81 | for key in loss_stats: 82 | avg_loss_stats[key].update(loss_stats[key].item(), test_loader.batch_size) 83 | 84 | for k, v in avg_loss_stats.items(): 85 | print(k, v.avg) -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class WingLoss(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.omega = 10 10 | self.sigma = 2 11 | self.c = self.omega - self.omega * np.log(1+self.omega/self.sigma) 12 | 13 | def forward(self, pred, target): 14 | """ 15 | Args: 16 | pred : (B,N,3) 17 | target: (B,N,3) 18 | """ 19 | l1, loss = self.calloss(pred, target) 20 | l2 = torch.sqrt(((pred - target) ** 2).sum(-1)).mean(-1) 21 | 22 | loss_stats = { 23 | 'loss': loss.mean().cpu().detach(), 24 | 'diff/l1' : l1.mean().cpu().detach(), 25 | 'diff/l2': l2.mean().cpu().detach(), 26 | } 27 | 28 | return loss.mean(), loss_stats 29 | 30 | def calloss(self, x,t): 31 | diff = torch.abs(x - t).sum(-1) 32 | is_small = (diff < self.omega).float() 33 | small_loss = self.omega * torch.log(1+diff/self.sigma) 34 | big_loss = diff - self.c 35 | loss = (small_loss * is_small + big_loss * (1-is_small)) * 0.1 36 | return diff, loss 37 | 38 | class FocalLoss(nn.Module): 39 | """ 40 | Args: 41 | pred (B, c, h, w) 42 | target (B, c, h, w) 43 | """ 44 | 45 | def __init__(self): 46 | super(FocalLoss, self).__init__() 47 | self.epsilon = 1e-8 48 | 49 | def forward(self, pred, target): 50 | pred = pred.clamp(min=self.epsilon, max=1-self.epsilon) 51 | yeq1_index = target.ge(0.9).float() 52 | other_index = target.lt(0.9).float() 53 | 54 | yeq1_loss = (yeq1_index * torch.log(pred) * torch.pow(1-pred,2)).sum() 55 | other_loss = (other_index * torch.log(1 - pred) 
* torch.pow(pred, 2) * torch.pow(1 - target, 4)).sum() 56 | num_yeq1 = yeq1_index.float().sum() 57 | 58 | if num_yeq1 == 0: 59 | loss = - other_loss 60 | else: 61 | loss = - (yeq1_loss + other_loss) / num_yeq1 62 | 63 | return loss 64 | 65 | def _tranpose_and_gather_feat(feat, ind): 66 | """ 67 | Args: 68 | feat (B,C*2,H,W) 69 | ind (B,C) 70 | Returns: 71 | feat (B,C,2) 72 | """ 73 | # (B,C*2,H,W) -> (B,C,2,H*W) 74 | feat = feat.view(feat.size(0), feat.size(1)//2, 2, -1) 75 | return torch.gather(feat, 3, ind.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, -1)).squeeze(3) 76 | 77 | class RegL1Loss(nn.Module): 78 | """ 79 | Args: 80 | output (B, dim, h, w) 81 | mask (B, max_obj) 82 | ind (B, max_obj) 83 | target (B, max_obj, dim) 84 | Temp: 85 | pred (B, max_obj, dim) 86 | """ 87 | 88 | def __init__(self): 89 | super(RegL1Loss, self).__init__() 90 | 91 | def forward(self, output, mask, ind, target): 92 | pred = _tranpose_and_gather_feat(output, ind) 93 | 94 | mask = mask.unsqueeze(2).expand_as(pred).float() 95 | loss = F.l1_loss(pred * mask, target * mask, reduction='sum') 96 | loss = loss / (mask.sum() + 1e-4) 97 | return loss 98 | 99 | class DetLoss(nn.Module): 100 | def __init__(self,cfg): 101 | super(DetLoss, self).__init__() 102 | self.cfg=cfg 103 | self.hm_crit = FocalLoss() 104 | self.reg_crit = RegL1Loss() 105 | 106 | def forward(self, output, batch): 107 | 108 | hm_loss = self.hm_crit(output['hm'], batch['hm']) 109 | reg_loss = self.reg_crit(output['reg'],batch['mask'],batch['ind'], batch['reg']) 110 | loss = hm_loss + reg_loss 111 | 112 | B,C,H,W = output['hm'].shape 113 | max_val, max_idx = torch.max(output['hm'].view(B, C, -1), dim=2) 114 | # reg = _tranpose_and_gather_feat(output['reg'], max_idx) 115 | reg = torch.gather(output['reg'].view(B,C,2,H*W), 3, max_idx.unsqueeze(-1).unsqueeze(-1).expand(-1,-1,2,-1)).squeeze(3) 116 | 117 | x = (torch.stack([max_idx%W, max_idx//W],dim=-1) + reg) * self.cfg['image_size'] / self.cfg['heatmap_size'] 118 | l1 = 
torch.abs(x - batch['x2d']).sum(-1).mean(-1) 119 | l2 = torch.sqrt(((x - batch['x2d']) ** 2).sum(-1)).mean(-1) 120 | 121 | 122 | loss_stats = { 123 | 'loss': loss.mean().cpu().detach(), 124 | 'loss/hm' : hm_loss.mean().cpu().detach(), 125 | 'loss/reg': reg_loss.mean().cpu().detach(), 126 | 'diff/l1': l1.mean().cpu().detach(), 127 | 'diff/l2': l2.mean().cpu().detach(), 128 | } 129 | output['pred_x2d'] = torch.cat([x,max_val[...,None]],dim=-1) 130 | return loss.mean(), loss_stats 131 | 132 | class CocktailLoss(nn.Module): 133 | def __init__(self): 134 | super().__init__() 135 | self.hm_crit = FocalLoss() 136 | # self.reg_crit = RegL1Loss() 137 | 138 | self.omega = 25 139 | self.sigma = 5 140 | self.c = self.omega - self.omega * np.log(1+self.omega/self.sigma) 141 | 142 | 143 | def forward(self, output, batch): 144 | 145 | # 2d 146 | hm_loss, reg_loss = 0., 0. 147 | B,V = batch['hm'].shape[:2] 148 | # print(output['hm'].shape, batch['hm'].shape, output['reg'].shape, batch['mask'].shape, batch['ind'].shape, batch['reg'].shape) 149 | for i in range(V): 150 | hm_loss += self.hm_crit(output['hm'][:,i], batch['hm'][:,i]) / V 151 | # reg_loss += self.reg_crit(output['reg'][:,i],batch['mask'][:,i],batch['ind'][:,i], batch['reg'][:,i]) / V 152 | 153 | # 3d 154 | pred = output['x3d'] 155 | target = batch['x3d'][:,0] 156 | x3d_l1, x3d_loss = self.calloss(pred, target) 157 | x3d_l2 = torch.sqrt(((pred - target) ** 2).sum(-1)).mean(-1) 158 | 159 | x2d_l1 = torch.abs(output['x2d'] - batch['x2d']).sum(-1).mean(-1) 160 | x2d_l2 = torch.sqrt(((output['x2d'] - batch['x2d']) ** 2).sum(-1)).mean(-1) 161 | 162 | loss = (x3d_loss + hm_loss)*0.5 163 | 164 | loss_stats = { 165 | 'loss': loss.mean().cpu().detach(), 166 | 'loss/hm' : hm_loss.mean().cpu().detach(), 167 | # 'loss/reg': reg_loss.mean().cpu().detach(), 168 | 'loss/x3d': x3d_loss.mean().cpu().detach(), 169 | 'x2d/l1': x2d_l1.mean().cpu().detach(), 170 | 'x2d/l2': x2d_l2.mean().cpu().detach(), 171 | 'x3d/l1' : 
x3d_l1.mean().cpu().detach(), 172 | 'x3d/l2': x3d_l2.mean().cpu().detach(), 173 | } 174 | 175 | return loss.mean(), loss_stats 176 | 177 | def calloss(self, x,t): 178 | diff = torch.abs(x - t).sum(-1) 179 | is_small = (diff < self.omega).float() 180 | small_loss = self.omega * torch.log(1+diff/self.sigma) 181 | big_loss = diff - self.c 182 | loss = (small_loss * is_small + big_loss * (1-is_small)) * 0.1 183 | return diff, loss 184 | 185 | 186 | 187 | 188 | class Net2d(nn.Module): 189 | def __init__(self, cfg, model): 190 | super().__init__() 191 | self.model = model.to(cfg['device']) 192 | self.loss = DetLoss(cfg).to(cfg['device']) 193 | 194 | def forward(self, batch): 195 | out_hm, out_reg, _ = self.model(batch['image']) 196 | output = {'hm':out_hm,'reg':out_reg} 197 | loss, loss_stats = self.loss(output, batch) 198 | return output, loss, loss_stats 199 | 200 | 201 | class Net3d(nn.Module): 202 | def __init__(self, cfg, model): 203 | super().__init__() 204 | self.model = model.to(cfg['device']) 205 | self.loss = CocktailLoss().to(cfg['device']) 206 | 207 | def forward(self, batch): 208 | out_x3d, out_hm, out_reg, out_x2d = self.model(batch['image'],batch['K']) 209 | output = {'x3d':out_x3d,'hm':out_hm,'reg':out_reg,'x2d':out_x2d} 210 | loss, loss_stats = self.loss(output, batch) 211 | return output, loss, loss_stats 212 | 213 | class NetPose(nn.Module): 214 | def __init__(self, cfg, model): 215 | super().__init__() 216 | self.model = model.to(cfg['device']) 217 | self.loss = WingLoss().to(cfg['device']) 218 | 219 | def forward(self, batch): 220 | out = self.model(batch['xdir']) 221 | loss, loss_stats = self.loss(out, batch['x3d'][:,0]) 222 | return out, loss, loss_stats 223 | 224 | 225 | class NetTri(nn.Module): 226 | def __init__(self, cfg, model): 227 | super().__init__() 228 | self.model = model.to(cfg['device']) 229 | self.loss = WingLoss().to(cfg['device']) 230 | 231 | def forward(self, batch): 232 | out = self.model(batch['xdir'], batch['xfeat']) 233 | loss, 
loss_stats = self.loss(out, batch['x3d'][:,0]) 234 | return out, loss, loss_stats 235 | 236 | # class Net(nn.Module): 237 | # def __init__(self, cfg, model): 238 | # super().__init__() 239 | # self.model = model.to(cfg['device']) 240 | # self.loss = WingLoss().to(cfg['device']) 241 | 242 | # def forward(self, batch): 243 | # lds_pred,_ = self.model(batch['image']) 244 | # lds_pred = lds_pred.squeeze() 245 | # loss, loss_stats = self.loss(lds_pred, batch['x2d']) 246 | # return lds_pred, loss, loss_stats 247 | 248 | 249 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | from typing import Optional, List, Tuple 6 | 7 | import copy 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from utils import eulid_to_homo 12 | 13 | # Export only names that are actually defined in this module. 14 | __all__ = ['Pose2d', 'pose2d_model', 'reparameterize_model', 'Fusion', 'ProbTri'] 15 | 16 | 17 | class SEBlock(nn.Module): 18 | """ Squeeze and Excite module. 19 | 20 | Pytorch implementation of `Squeeze-and-Excitation Networks` - 21 | https://arxiv.org/pdf/1709.01507.pdf 22 | """ 23 | 24 | def __init__(self, 25 | in_channels: int, 26 | rd_ratio: float = 0.0625) -> None: 27 | """ Construct a Squeeze and Excite Module. 28 | 29 | :param in_channels: Number of input channels. 30 | :param rd_ratio: Input channel reduction ratio. 31 | """ 32 | super(SEBlock, self).__init__() 33 | self.reduce = nn.Conv2d(in_channels=in_channels, 34 | out_channels=int(in_channels * rd_ratio), 35 | kernel_size=1, 36 | stride=1, 37 | bias=True) 38 | self.expand = nn.Conv2d(in_channels=int(in_channels * rd_ratio), 39 | out_channels=in_channels, 40 | kernel_size=1, 41 | stride=1, 42 | bias=True) 43 | 44 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 45 | """ Apply forward pass.
""" 46 | b, c, h, w = inputs.size() 47 | x = F.avg_pool2d(inputs, kernel_size=[h, w]) 48 | x = self.reduce(x) 49 | x = F.relu(x) 50 | x = self.expand(x) 51 | x = torch.sigmoid(x) 52 | x = x.view(-1, c, 1, 1) 53 | return inputs * x 54 | 55 | 56 | class MobileOneBlock(nn.Module): 57 | """ MobileOne building block. 58 | 59 | This block has a multi-branched architecture at train-time 60 | and plain-CNN style architecture at inference time 61 | For more details, please refer to our paper: 62 | `An Improved One millisecond Mobile Backbone` - 63 | https://arxiv.org/pdf/2206.04040.pdf 64 | """ 65 | def __init__(self, 66 | in_channels: int, 67 | out_channels: int, 68 | kernel_size: int, 69 | stride: int = 1, 70 | padding: int = 0, 71 | dilation: int = 1, 72 | groups: int = 1, 73 | inference_mode: bool = False, 74 | use_se: bool = False, 75 | num_conv_branches: int = 1) -> None: 76 | """ Construct a MobileOneBlock module. 77 | 78 | :param in_channels: Number of channels in the input. 79 | :param out_channels: Number of channels produced by the block. 80 | :param kernel_size: Size of the convolution kernel. 81 | :param stride: Stride size. 82 | :param padding: Zero-padding size. 83 | :param dilation: Kernel dilation factor. 84 | :param groups: Group number. 85 | :param inference_mode: If True, instantiates model in inference mode. 86 | :param use_se: Whether to use SE-ReLU activations. 87 | :param num_conv_branches: Number of linear conv branches. 
88 | """ 89 | super(MobileOneBlock, self).__init__() 90 | self.inference_mode = inference_mode 91 | self.groups = groups 92 | self.stride = stride 93 | self.kernel_size = kernel_size 94 | self.in_channels = in_channels 95 | self.out_channels = out_channels 96 | self.num_conv_branches = num_conv_branches 97 | 98 | # Check if SE-ReLU is requested 99 | if use_se: 100 | self.se = SEBlock(out_channels) 101 | else: 102 | self.se = nn.Identity() 103 | self.activation = nn.ReLU() 104 | 105 | if inference_mode: 106 | self.reparam_conv = nn.Conv2d(in_channels=in_channels, 107 | out_channels=out_channels, 108 | kernel_size=kernel_size, 109 | stride=stride, 110 | padding=padding, 111 | dilation=dilation, 112 | groups=groups, 113 | bias=True) 114 | else: 115 | # Re-parameterizable skip connection 116 | self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \ 117 | if out_channels == in_channels and stride == 1 else None 118 | 119 | # Re-parameterizable conv branches 120 | rbr_conv = list() 121 | for _ in range(self.num_conv_branches): 122 | rbr_conv.append(self._conv_bn(kernel_size=kernel_size, 123 | padding=padding)) 124 | self.rbr_conv = nn.ModuleList(rbr_conv) 125 | 126 | # Re-parameterizable scale branch 127 | self.rbr_scale = None 128 | if kernel_size > 1: 129 | self.rbr_scale = self._conv_bn(kernel_size=1, 130 | padding=0) 131 | 132 | def forward(self, x: torch.Tensor) -> torch.Tensor: 133 | """ Apply forward pass. """ 134 | # Inference mode forward pass. 135 | if self.inference_mode: 136 | return self.activation(self.se(self.reparam_conv(x))) 137 | 138 | # Multi-branched train-time forward pass. 
139 | # Skip branch output 140 | identity_out = 0 141 | if self.rbr_skip is not None: 142 | identity_out = self.rbr_skip(x) 143 | 144 | # Scale branch output 145 | scale_out = 0 146 | if self.rbr_scale is not None: 147 | scale_out = self.rbr_scale(x) 148 | 149 | # Other branches 150 | out = scale_out + identity_out 151 | for ix in range(self.num_conv_branches): 152 | out += self.rbr_conv[ix](x) 153 | 154 | return self.activation(self.se(out)) 155 | 156 | def reparameterize(self): 157 | """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - 158 | https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched 159 | architecture used at training time to obtain a plain CNN-like structure 160 | for inference. 161 | """ 162 | if self.inference_mode: 163 | return 164 | kernel, bias = self._get_kernel_bias() 165 | self.reparam_conv = nn.Conv2d(in_channels=self.rbr_conv[0].conv.in_channels, 166 | out_channels=self.rbr_conv[0].conv.out_channels, 167 | kernel_size=self.rbr_conv[0].conv.kernel_size, 168 | stride=self.rbr_conv[0].conv.stride, 169 | padding=self.rbr_conv[0].conv.padding, 170 | dilation=self.rbr_conv[0].conv.dilation, 171 | groups=self.rbr_conv[0].conv.groups, 172 | bias=True) 173 | self.reparam_conv.weight.data = kernel 174 | self.reparam_conv.bias.data = bias 175 | 176 | # Delete un-used branches 177 | for para in self.parameters(): 178 | para.detach_() 179 | self.__delattr__('rbr_conv') 180 | self.__delattr__('rbr_scale') 181 | if hasattr(self, 'rbr_skip'): 182 | self.__delattr__('rbr_skip') 183 | 184 | self.inference_mode = True 185 | 186 | def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: 187 | """ Method to obtain re-parameterized kernel and bias. 188 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 189 | 190 | :return: Tuple of (kernel, bias) after fusing branches. 
191 | """ 192 | # get weights and bias of scale branch 193 | kernel_scale = 0 194 | bias_scale = 0 195 | if self.rbr_scale is not None: 196 | kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) 197 | # Pad scale branch kernel to match conv branch kernel size. 198 | pad = self.kernel_size // 2 199 | kernel_scale = torch.nn.functional.pad(kernel_scale, 200 | [pad, pad, pad, pad]) 201 | 202 | # get weights and bias of skip branch 203 | kernel_identity = 0 204 | bias_identity = 0 205 | if self.rbr_skip is not None: 206 | kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) 207 | 208 | # get weights and bias of conv branches 209 | kernel_conv = 0 210 | bias_conv = 0 211 | for ix in range(self.num_conv_branches): 212 | _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) 213 | kernel_conv += _kernel 214 | bias_conv += _bias 215 | 216 | kernel_final = kernel_conv + kernel_scale + kernel_identity 217 | bias_final = bias_conv + bias_scale + bias_identity 218 | return kernel_final, bias_final 219 | 220 | def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: 221 | """ Method to fuse batchnorm layer with preceding conv layer. 222 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 223 | 224 | :param branch: Conv-BN sequential branch or BatchNorm2d skip branch to fuse. 225 | :return: Tuple of (kernel, bias) after fusing batchnorm. 
226 | """ 227 | if isinstance(branch, nn.Sequential): 228 | kernel = branch.conv.weight 229 | running_mean = branch.bn.running_mean 230 | running_var = branch.bn.running_var 231 | gamma = branch.bn.weight 232 | beta = branch.bn.bias 233 | eps = branch.bn.eps 234 | else: 235 | assert isinstance(branch, nn.BatchNorm2d) 236 | if not hasattr(self, 'id_tensor'): 237 | input_dim = self.in_channels // self.groups 238 | kernel_value = torch.zeros((self.in_channels, 239 | input_dim, 240 | self.kernel_size, 241 | self.kernel_size), 242 | dtype=branch.weight.dtype, 243 | device=branch.weight.device) 244 | for i in range(self.in_channels): 245 | kernel_value[i, i % input_dim, 246 | self.kernel_size // 2, 247 | self.kernel_size // 2] = 1 248 | self.id_tensor = kernel_value 249 | kernel = self.id_tensor 250 | running_mean = branch.running_mean 251 | running_var = branch.running_var 252 | gamma = branch.weight 253 | beta = branch.bias 254 | eps = branch.eps 255 | std = (running_var + eps).sqrt() 256 | t = (gamma / std).reshape(-1, 1, 1, 1) 257 | return kernel * t, beta - running_mean * gamma / std 258 | 259 | def _conv_bn(self, 260 | kernel_size: int, 261 | padding: int) -> nn.Sequential: 262 | """ Helper method to construct conv-batchnorm layers. 263 | 264 | :param kernel_size: Size of the convolution kernel. 265 | :param padding: Zero-padding size. 266 | :return: Conv-BN module. 
267 | """ 268 | mod_list = nn.Sequential() 269 | mod_list.add_module('conv', nn.Conv2d(in_channels=self.in_channels, 270 | out_channels=self.out_channels, 271 | kernel_size=kernel_size, 272 | stride=self.stride, 273 | padding=padding, 274 | groups=self.groups, 275 | bias=False)) 276 | mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels)) 277 | return mod_list 278 | 279 | 280 | class MobileConv(nn.Module): 281 | def __init__(self, in_channels, out_channels, stride=1, inference_mode = False): 282 | super().__init__() 283 | self.convs = nn.Sequential( 284 | MobileOneBlock(in_channels=in_channels, 285 | out_channels=in_channels, 286 | kernel_size=3, 287 | stride=stride, 288 | padding=1, 289 | groups=1, 290 | inference_mode=inference_mode, 291 | use_se=False, 292 | num_conv_branches=1), 293 | MobileOneBlock(in_channels=in_channels, 294 | out_channels=out_channels, 295 | kernel_size=1, 296 | stride=1, 297 | padding=0, 298 | groups=1, 299 | inference_mode=inference_mode, 300 | use_se=False, 301 | num_conv_branches=1), 302 | ) 303 | def forward(self, x): 304 | return self.convs(x) 305 | 306 | 307 | class UpHead(nn.Module): 308 | def __init__(self,width_multipliers,inference_mode): 309 | super().__init__() 310 | self.proj4 = nn.Sequential( 311 | MobileConv(int(512 * width_multipliers[3]), int(256 * width_multipliers[2]), stride=1, inference_mode=inference_mode), 312 | nn.UpsamplingBilinear2d(scale_factor=2) 313 | ) 314 | self.proj3 = nn.Sequential( 315 | MobileConv(int(256 * width_multipliers[2]), int(128 * width_multipliers[1]), stride=1, inference_mode=inference_mode), 316 | nn.UpsamplingBilinear2d(scale_factor=2) 317 | ) 318 | self.proj2 = nn.Sequential( 319 | MobileConv(int(128 * width_multipliers[1]), int(64 * width_multipliers[0]), stride=1, inference_mode=inference_mode), 320 | nn.UpsamplingBilinear2d(scale_factor=2) 321 | # MobileConv(256, 128, stride=1, inference_mode=inference_mode), 322 | ) 323 | self.proj1 = nn.Sequential( 324 | 
MobileConv(int(64 * width_multipliers[0]), 256, stride=1, inference_mode=inference_mode), 325 | # nn.UpsamplingBilinear2d(scale_factor=2) 326 | MobileConv(256, 128, stride=1, inference_mode=inference_mode), 327 | ) 328 | # self.proj0 = nn.Sequential( 329 | # MobileConv( min(64,int(64 * width_multipliers[0])), 64, stride=1, inference_mode=inference_mode), 330 | # ) 331 | 332 | def forward(self, x1, x2, x3, x4): 333 | elt3 = x3 + self.proj4(x4) 334 | elt2 = x2 + self.proj3(elt3) 335 | elt1 = x1 + self.proj2(elt2) 336 | out = self.proj1(elt1) 337 | # elt0 = x0 + self.proj1(elt1) 338 | # out = self.proj0(elt0) 339 | return out 340 | 341 | 342 | class Pose2d(nn.Module): 343 | def __init__(self, 344 | num_blocks_per_stage: List[int] = [2, 8, 10, 1], 345 | num_classes: int = 1000, 346 | width_multipliers: Optional[List[float]] = None, 347 | inference_mode: bool = False, 348 | use_se: bool = False, 349 | num_conv_branches: int = 1) -> None: 350 | """ 351 | :param num_blocks_per_stage: List of number of blocks per stage. 352 | :param num_classes: Number of classes in the dataset. 353 | :param width_multipliers: List of width multiplier for blocks in a stage. 354 | :param inference_mode: If True, instantiates model in inference mode. 355 | :param use_se: Whether to use SE-ReLU activations. 356 | :param num_conv_branches: Number of linear conv branches. 
357 | """ 358 | super().__init__() 359 | 360 | assert len(width_multipliers) == 4 361 | self.inference_mode = inference_mode 362 | self.in_planes = min(64, int(64 * width_multipliers[0])) 363 | self.use_se = use_se 364 | self.num_conv_branches = num_conv_branches 365 | 366 | # Build stages 367 | self.stage0 = MobileOneBlock(in_channels=3, out_channels=self.in_planes, 368 | kernel_size=3, stride=2, padding=1, 369 | inference_mode=self.inference_mode) 370 | self.cur_layer_idx = 1 371 | self.stage1 = self._make_stage(int(64 * width_multipliers[0]), num_blocks_per_stage[0], 372 | num_se_blocks=0) 373 | self.stage2 = self._make_stage(int(128 * width_multipliers[1]), num_blocks_per_stage[1], 374 | num_se_blocks=0) 375 | self.stage3 = self._make_stage(int(256 * width_multipliers[2]), num_blocks_per_stage[2], 376 | num_se_blocks=int(num_blocks_per_stage[2] // 2) if use_se else 0) 377 | 378 | self.stage4 = self._make_stage(int(512 * width_multipliers[3]), num_blocks_per_stage[3], 379 | num_se_blocks=num_blocks_per_stage[3] if use_se else 0) 380 | # self.neck = nn.Sequential( 381 | # MobileConv(int(512 * width_multipliers[3]), 512, stride=1, inference_mode=inference_mode), 382 | # MobileConv(512, 256, stride=1, inference_mode=inference_mode), 383 | # # nn.UpsamplingBilinear2d(scale_factor=2), 384 | # MobileConv(256, 256, stride=2, inference_mode=inference_mode), 385 | # MobileConv(256, 128, stride=1, inference_mode=inference_mode), 386 | # MobileConv(128, 128, stride=2, inference_mode=inference_mode), 387 | # MobileConv(128, 64, stride=1, inference_mode=inference_mode), 388 | # ) 389 | # self.gap = nn.AdaptiveAvgPool2d(output_size=1) 390 | # self.linear = nn.Linear(64, num_classes*2) 391 | 392 | self.upstage = UpHead(width_multipliers, inference_mode) 393 | self.hm_head = nn.Sequential( 394 | MobileConv( 128, num_classes, stride=1, inference_mode=inference_mode), 395 | nn.Softmax2d(), 396 | ) 397 | self.reg_head = nn.Sequential( 398 | MobileConv( 128, num_classes*2, 
stride=1, inference_mode=inference_mode), 399 | ) 400 | self.freeze() 401 | 402 | def _make_stage(self, 403 | planes: int, 404 | num_blocks: int, 405 | num_se_blocks: int) -> nn.Sequential: 406 | """ Build a stage of MobileOne model. 407 | 408 | :param planes: Number of output channels. 409 | :param num_blocks: Number of blocks in this stage. 410 | :param num_se_blocks: Number of SE blocks in this stage. 411 | :return: A stage of MobileOne model. 412 | """ 413 | # Get strides for all layers 414 | strides = [2] + [1]*(num_blocks-1) 415 | blocks = [] 416 | for ix, stride in enumerate(strides): 417 | use_se = False 418 | if num_se_blocks > num_blocks: 419 | raise ValueError("Number of SE blocks cannot " 420 | "exceed number of layers.") 421 | if ix >= (num_blocks - num_se_blocks): 422 | use_se = True 423 | 424 | # Depthwise conv 425 | blocks.append(MobileOneBlock(in_channels=self.in_planes, 426 | out_channels=self.in_planes, 427 | kernel_size=3, 428 | stride=stride, 429 | padding=1, 430 | groups=self.in_planes, 431 | inference_mode=self.inference_mode, 432 | use_se=use_se, 433 | num_conv_branches=self.num_conv_branches)) 434 | # Pointwise conv 435 | blocks.append(MobileOneBlock(in_channels=self.in_planes, 436 | out_channels=planes, 437 | kernel_size=1, 438 | stride=1, 439 | padding=0, 440 | groups=1, 441 | inference_mode=self.inference_mode, 442 | use_se=use_se, 443 | num_conv_branches=self.num_conv_branches)) 444 | self.in_planes = planes 445 | self.cur_layer_idx += 1 446 | return nn.Sequential(*blocks) 447 | 448 | def freeze(self): 449 | for layer in [self.stage0, self.stage1, self.stage2, self.stage3, self.stage4]: 450 | for param in layer.parameters(): 451 | param.requires_grad = False 452 | 453 | for layer in [self.upstage, self.hm_head,self.reg_head]: 454 | for param in layer.parameters(): 455 | param.requires_grad = True 456 | layer.train() 457 | 458 | 459 | def forward(self, x: torch.Tensor) -> torch.Tensor: 460 | """ Apply forward pass. 
""" 461 | x0 = self.stage0(x) 462 | x1 = self.stage1(x0) 463 | x2 = self.stage2(x1) 464 | x3 = self.stage3(x2) 465 | x4 = self.stage4(x3) 466 | feature = self.upstage(x1,x2,x3,x4) 467 | # feature = self.neck(x) 468 | # out = self.gap(feature) 469 | # out = out.view(out.size(0), -1) 470 | # out = self.linear(out) 471 | return self.hm_head(feature), self.reg_head(feature), feature 472 | 473 | 474 | 475 | def pose2d_model(num_classes: int = 1000, inference_mode: bool = False) -> nn.Module: 476 | return Pose2d(num_classes=num_classes, inference_mode=inference_mode, 477 | width_multipliers = (3.0, 3.5, 3.5, 4.0),use_se=True) 478 | 479 | 480 | def reparameterize_model(model: torch.nn.Module) -> nn.Module: 481 | """ Method returns a model where a multi-branched structure 482 | used in training is re-parameterized into a single branch 483 | for inference. 484 | 485 | :param model: MobileOne model in train mode. 486 | :return: MobileOne model in inference mode. 487 | """ 488 | # Avoid editing original graph 489 | model = copy.deepcopy(model) 490 | for module in model.modules(): 491 | if hasattr(module, 'reparameterize'): 492 | module.reparameterize() 493 | return model 494 | 495 | 496 | class Fusion(nn.Module): 497 | def __init__(self,cfg): 498 | super().__init__() 499 | self.cfg = cfg 500 | self.embed = nn.Sequential( 501 | nn.Linear(4,1024), 502 | nn.ReLU(), 503 | nn.Linear(1024,512), 504 | nn.ReLU(), 505 | nn.Linear(512,256), 506 | nn.ReLU(), 507 | nn.Linear(256,128), 508 | ) 509 | 510 | self.pose0 = nn.Sequential( 511 | nn.Linear(cfg['num_keypoints']*256, 1024), 512 | nn.ReLU(), 513 | nn.Linear(1024,512), 514 | nn.ReLU(), 515 | nn.Linear(512,256), 516 | nn.ReLU(), 517 | # nn.Dropout(p=0.5), 518 | nn.Linear(256,256), 519 | ) 520 | self.pose1 = nn.Sequential( 521 | nn.Linear(cfg['num_views']*256, 1024), 522 | nn.ReLU(), 523 | nn.Linear(1024,512), 524 | nn.ReLU(), 525 | nn.Linear(512,cfg['num_keypoints']*3), 526 | ) 527 | 528 | def forward(self, xdir, xfeat=None): 529 | 
B,V,J,_ = xdir.shape 530 | xdir = self.embed(xdir) 531 | 532 | if xfeat is None: 533 | xfeat = torch.zeros_like(xdir) 534 | # (B,V,J,256) -> (B,V,J*256) 535 | x = torch.cat([xdir,xfeat], dim=-1).view(B,V,J*256) 536 | x = self.pose0(x) # (B,V,256) 537 | x = x.view(B,-1) 538 | out = self.pose1(x).view(B,-1,3) 539 | return out 540 | 541 | 542 | 543 | class ProbTri(nn.Module): 544 | def __init__(self, cfg): 545 | super().__init__() 546 | self.cfg = cfg 547 | self.backbone = pose2d_model(num_classes=cfg['num_keypoints']) 548 | self.fusion = Fusion(cfg) 549 | self.freeze() 550 | 551 | def freeze(self): 552 | self.backbone.freeze() 553 | # for layer in [self.backbone]: 554 | # for param in layer.parameters(): 555 | # param.requires_grad = False 556 | 557 | for layer in [self.fusion]: 558 | for param in layer.parameters(): 559 | param.requires_grad = True 560 | layer.train() 561 | 562 | def forward(self, image, K): 563 | """ 564 | Args: 565 | image (B,V,3,H,W) 566 | K (B,V,3,3) 567 | """ 568 | B,V,_,image_h,image_w = image.shape 569 | image = image.view(-1,3,image_h,image_w) 570 | K = K.view(-1,3,3) 571 | out_hm,out_reg,feature = self.backbone(image) 572 | 573 | # BV,J,H,W 574 | heatmap = out_hm 575 | BV, J, H,W = heatmap.shape 576 | reg = out_reg.view(BV,J,2,H*W) 577 | 578 | # ind (BV,J) 579 | max_val, ind = torch.max(heatmap.view(BV,J, -1), dim=2) 580 | # ind_xy (BV,J,2) 581 | ind_y = ind // W 582 | ind_x = ind % W 583 | ind_xy = torch.stack([ind_x, ind_y], dim=-1) 584 | 585 | # ind (BV,J,1,1) -> ind (BV,J,2,1) / reg (BV,J,2,HW) / reg_val (BV,J,2) 586 | reg_val = torch.gather(reg, 3, ind.unsqueeze(-1).unsqueeze(-1).expand(-1,-1,2,-1)).squeeze(3) 587 | 588 | # coord (BV,J,2) 589 | coord = ind_xy.float() + reg_val 590 | out_x2d = coord.clone()* image_w / W 591 | 592 | # (BV,1,3,3)@(BV,J,3,1) -> (BV,J,3,1) 593 | xdir = (torch.inverse(K)[:,None] @ eulid_to_homo(coord * image_w / W)[...,None]).squeeze(-1) 594 | xdir = torch.cat([xdir, max_val[...,None]], dim = 
-1).view(B,V,J,4) 595 | 596 | coord[:, :, 0] = (coord[:, :, 0] / (W - 1)) * 2 - 1 597 | coord[:, :, 1] = (coord[:, :, 1] / (H - 1)) * 2 - 1 598 | coord = coord.unsqueeze(1) 599 | xfeat = F.grid_sample(feature, coord, mode='bilinear', align_corners=True) 600 | xfeat = xfeat.squeeze(2).transpose(1, 2).view(B,V,J,128) 601 | 602 | out_x3d = self.fusion(xdir, xfeat) 603 | # out_x3d = self.fusion(xdir) 604 | 605 | return out_x3d, out_hm.view(B, V, J, H,W), out_reg.view(B,V,J,2,H,W), out_x2d.view(B,V,J,2) 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | -------------------------------------------------------------------------------- /train2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | from utils import * 8 | from trainer import Trainer 9 | from datasets import Human36M 10 | from models import pose2d_model 11 | from loss import Net2d 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 13 | 14 | cfg = { 15 | 'root_path': '/data/human36m/processed', 16 | 'labels_path': '/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy', 17 | 'lr':1e-4, 18 | 'num_epoch':300, 19 | 'batch_size_train': 128, 20 | 'batch_size_test': 32, 21 | 'num_workers':32, 22 | 'num_keypoints': 17, 23 | 'scaleRange': [1.1,1.2], 24 | 'moveRange': [-0.1,0.1], 25 | 'image_size':384, 26 | 'heatmap_size': 96, 27 | # 'model_path': 'checkpoints/mobileone_s4_unfused.pth.tar', 28 | 'model_path': '/logs/pose2d_384/model_12.pth', 29 | # 'model_path': 'checkpoints/backbone_pretrain.pth', 30 | 'device':'cuda', 31 | 'save_dir': '/logs/pose2d_384', 32 | 'use_tag':False, 33 | 'data_skip_train':16, 34 | 'data_skip_test':16, 35 | } 36 | 37 | train_db = Human36M(cfg, is_train=True) 38 | test_db = Human36M(cfg, is_train=False) 39 | train_loader = DataLoader( 40 | train_db, 41 | batch_size=cfg['batch_size_train'], 42 | 
shuffle=True, 43 | num_workers = cfg['num_workers'], 44 | pin_memory = True, 45 | drop_last=True, 46 | ) 47 | test_loader = DataLoader( 48 | test_db, 49 | batch_size=cfg['batch_size_test'], 50 | shuffle=False, 51 | num_workers = cfg['num_workers'], 52 | pin_memory = True, 53 | drop_last=True, 54 | ) 55 | 56 | 57 | # trainer 58 | model = pose2d_model(num_classes=17) 59 | 60 | if 'model_path' in cfg: 61 | pretrain_dict = torch.load(cfg['model_path']) 62 | missing, unexpected = model.load_state_dict(pretrain_dict,strict=False) 63 | print('missing length', len(missing), 'unexpected', len(unexpected) , '\n') 64 | model = torch.nn.DataParallel(model, device_ids=[0,1,2,3]).cuda() 65 | 66 | net = Net2d(cfg, model) 67 | trainer = Trainer(cfg, net) 68 | trainer.run(train_loader, test_loader) 69 | -------------------------------------------------------------------------------- /train3d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | from utils import * 8 | from trainer import Trainer 9 | from datasets import MultiHuman36M 10 | from models import ProbTri 11 | from loss import Net3d 12 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 13 | 14 | 15 | cfg = { 16 | 'root_path': '/data/human36m/processed', 17 | 'labels_path': '/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy', 18 | 'lr':1e-5, 19 | 'num_epoch':300, 20 | 'batch_size_train': 32, 21 | 'batch_size_test': 8, 22 | 'num_workers':8, 23 | 'num_keypoints': 17, 24 | 'num_views': 4, 25 | 'scaleRange': [1.1,1.2], 26 | 'moveRange': [-0.1,0.1], 27 | 'image_size':384, 28 | 'heatmap_size': 96, 29 | # 'backbone_path': 'checkpoints/backbone.pth', 30 | # 'fusion_path': 'checkpoints/fusion.pth', 31 | 'model_path': 'checkpoints/pretrain.pth', 32 | 'device':'cuda', 33 | 'save_dir': '/logs/pose3d', 34 | 'use_tag':False, 35 | 
'data_skip_train':8, 36 | 'data_skip_test':4, 37 | } 38 | 39 | train_db = MultiHuman36M(cfg, is_train=True) 40 | test_db = MultiHuman36M(cfg, is_train=False) 41 | train_loader = DataLoader( 42 | train_db, 43 | batch_size=cfg['batch_size_train'], 44 | shuffle=True, 45 | num_workers = cfg['num_workers'], 46 | pin_memory = True, 47 | drop_last=True, 48 | ) 49 | test_loader = DataLoader( 50 | test_db, 51 | batch_size=cfg['batch_size_test'], 52 | shuffle=False, 53 | num_workers = cfg['num_workers'], 54 | pin_memory = True, 55 | drop_last=True, 56 | ) 57 | 58 | 59 | # trainer 60 | model = ProbTri(cfg) 61 | if 'backbone_path' in cfg: 62 | pretrain_dict = torch.load(cfg['backbone_path']) 63 | missing, unexpected = model.backbone.load_state_dict(pretrain_dict,strict=False) 64 | print('load backbone model, missing length', len(missing), 'unexpected', len(unexpected) , '\n') 65 | 66 | if 'fusion_path' in cfg: 67 | pretrain_dict = torch.load(cfg['fusion_path']) 68 | missing, unexpected = model.fusion.load_state_dict(pretrain_dict,strict=False) 69 | print('load fusion model, missing length', len(missing), 'unexpected', len(unexpected) , '\n') 70 | 71 | if 'model_path' in cfg: 72 | pretrain_dict = torch.load(cfg['model_path']) 73 | missing, unexpected = model.load_state_dict(pretrain_dict,strict=False) 74 | print('missing length', len(missing), 'unexpected', len(unexpected) , '\n') 75 | model = torch.nn.DataParallel(model, device_ids=[0,1,2,3]).cuda() 76 | 77 | net = Net3d(cfg, model) 78 | trainer = Trainer(cfg, net) 79 | trainer.run(train_loader, test_loader) 80 | -------------------------------------------------------------------------------- /trainPose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import Dataset, DataLoader 4 | import os 5 | import numpy as np 6 | from loss import NetPose 7 | from trainer import Trainer 8 | from models import Fusion 9 | from datasets 
import MultiPose 10 | 11 | # 1e-3, wing loss with an expanded range, batch size 128 12 | # 2e-4, further lowers the loss, batch size 256 13 | 14 | cfg = { 15 | 'root_path': '/data/human36m/processed', 16 | 'labels_path': '/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy', 17 | 'lr':2e-4, 18 | 'num_epoch':300, 19 | 'batch_size_train': 256, 20 | 'batch_size_test': 64, 21 | 'num_workers':8, 22 | 'num_keypoints': 17, 23 | 'num_views':4, 24 | 'image_size':384, 25 | 'heatmap_size':96, 26 | 'device':'cuda', 27 | 'model_path': '/logs/tri/model_125.pth', 28 | 'save_dir': '/logs/tri', 29 | } 30 | 31 | train_db = MultiPose(cfg, is_train=True) 32 | test_db = MultiPose(cfg, is_train=False) 33 | train_loader = DataLoader( 34 | train_db, 35 | batch_size=cfg['batch_size_train'], 36 | shuffle=True, 37 | num_workers = cfg['num_workers'], 38 | pin_memory = True, 39 | drop_last=True, 40 | ) 41 | test_loader = DataLoader( 42 | test_db, 43 | batch_size=cfg['batch_size_test'], 44 | shuffle=False, 45 | num_workers = cfg['num_workers'], 46 | pin_memory = True, 47 | drop_last=True, 48 | ) 49 | 50 | 51 | 52 | model = Fusion(cfg) 53 | 54 | if 'model_path' in cfg: 55 | pretrain_dict = torch.load(cfg['model_path']) 56 | missing, unexpected = model.load_state_dict(pretrain_dict,strict=False) 57 | print('missing length', len(missing), 'unexpected', len(unexpected) , '\n') 58 | # model = torch.nn.DataParallel(model, device_ids=[0,1,2]).cuda() 59 | model = model.cuda() 60 | 61 | net = NetPose(cfg, model) 62 | trainer = Trainer(cfg, net) 63 | trainer.run(train_loader, test_loader) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import tensorboardX 4 | import time 5 | import os 6 | 7 | # EMA 8 | class EMA(): 9 | def __init__(self, model, decay): 10 | self.model = model 11 | self.decay = decay 12 | self.shadow = {} 13 | self.backup = {} 14 | 15 | def 
register(self): 16 | for name, param in self.model.named_parameters(): 17 | if param.requires_grad: 18 | self.shadow[name] = param.data.clone() 19 | 20 | def update(self): 21 | for name, param in self.model.named_parameters(): 22 | if param.requires_grad: 23 | assert name in self.shadow 24 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] 25 | self.shadow[name] = new_average.clone() 26 | 27 | def apply_shadow(self): 28 | for name, param in self.model.named_parameters(): 29 | if param.requires_grad: 30 | assert name in self.shadow 31 | self.backup[name] = param.data 32 | param.data = self.shadow[name] 33 | 34 | def restore(self): 35 | for name, param in self.model.named_parameters(): 36 | if param.requires_grad: 37 | assert name in self.backup 38 | param.data = self.backup[name] 39 | self.backup = {} 40 | 41 | 42 | # AverageMeter 43 | class AverageMeter(): 44 | """Computes and stores the average and current value""" 45 | def __init__(self): 46 | self.reset() 47 | 48 | def reset(self): 49 | self.val = 0 50 | self.avg = 0 51 | self.sum = 0 52 | self.count = 0 53 | 54 | def update(self, val, n=1): 55 | self.val = val 56 | self.sum += val * n 57 | self.count += n 58 | if self.count > 0: 59 | self.avg = self.sum / self.count 60 | 61 | # Logger 62 | class Logger(): 63 | def __init__(self, cfg, is_train): 64 | if not os.path.exists(cfg['save_dir']): 65 | os.makedirs(cfg['save_dir']) 66 | 67 | time_str = time.strftime('%Y-%m-%d-%H-%M') 68 | if is_train: 69 | save_path = os.path.join(cfg['save_dir'],'trainlogs') 70 | else: 71 | save_path = os.path.join(cfg['save_dir'],'testlogs') 72 | 73 | if not os.path.exists(save_path): 74 | os.makedirs(save_path) 75 | 76 | log_dir = os.path.join(save_path, 'logs_{}'.format(time_str)) 77 | print(log_dir) 78 | self.writer = tensorboardX.SummaryWriter(log_dir = log_dir) 79 | 80 | def scalar_summay(self, tag, value, step): 81 | self.writer.add_scalar(tag, value, step) 82 | 83 | 84 | # Trainer 85 | class 
Trainer(): 86 | def __init__(self, cfg, net): 87 | self.cfg = cfg 88 | self.net = net 89 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=cfg['lr']) 90 | 91 | # for state in self.optimizer.state.values(): 92 | # for k, v in state.items(): 93 | # if isinstance(v, torch.Tensor): 94 | # state[k] = v.to(device=cfg['device'], non_blocking=True) 95 | 96 | self.ema = EMA(self.net, 0.999) 97 | self.ema.register() 98 | 99 | self.train_logger = Logger(cfg, is_train=True) 100 | self.test_logger = Logger(cfg, is_train=False) 101 | 102 | def run(self, train_loader, test_loader): 103 | start_epoch = 0 104 | for epoch in range(start_epoch + 1, self.cfg['num_epoch']+1): 105 | # train 106 | log_dict = self.run_epoch(train_loader, is_train=True) 107 | for k, v in log_dict.items(): 108 | print('train epoch=', epoch, k, v.avg) 109 | self.train_logger.scalar_summay(k, v.avg, epoch) 110 | 111 | # test 112 | log_dict = self.run_epoch(test_loader, is_train=False) 113 | for k, v in log_dict.items(): 114 | print('test epoch=', epoch, k, v.avg) 115 | self.test_logger.scalar_summay(k, v.avg, epoch) 116 | 117 | if hasattr(self.net.model, 'module'): 118 | model_state_dict = self.net.model.module.state_dict() 119 | else: 120 | model_state_dict = self.net.model.state_dict() 121 | 122 | save_path = os.path.join(self.cfg['save_dir'], f'model_{epoch}.pth') 123 | 124 | torch.save(model_state_dict, save_path) 125 | 126 | 127 | def run_epoch(self, data_loader, is_train=True): 128 | if is_train: 129 | self.net.train() 130 | torch.set_grad_enabled(True) 131 | 132 | else: 133 | self.net.eval() 134 | torch.set_grad_enabled(False) 135 | 136 | avg_loss_stats = {key: AverageMeter() for key in ['time/data', 'time/infer']} 137 | 138 | t0 = time.time() 139 | for iter_id, batch in tqdm(enumerate(data_loader)): 140 | for key in batch: 141 | if isinstance(batch[key], torch.Tensor): 142 | batch[key] = batch[key].to(device = self.cfg['device'], non_blocking=True) 143 | 144 | 
t1 = time.time() 145 | avg_loss_stats['time/data'].update(t1-t0) 146 | t0 = time.time() 147 | 148 | output, loss, loss_stats = self.net(batch) 149 | 150 | t1 = time.time() 151 | avg_loss_stats['time/infer'].update(t1-t0) 152 | t0 = time.time() 153 | 154 | if is_train: 155 | self.optimizer.zero_grad() 156 | loss.backward() 157 | self.optimizer.step() 158 | self.ema.update() 159 | 160 | if 'loss' not in avg_loss_stats: 161 | for key in loss_stats: 162 | avg_loss_stats[key] = AverageMeter() 163 | 164 | for key in loss_stats: 165 | avg_loss_stats[key].update(loss_stats[key].item(), data_loader.batch_size) 166 | 167 | return avg_loss_stats 168 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | 5 | class Camera(): 6 | def __init__(self,R,t,K): 7 | self.R = np.asarray(R).copy() 8 | self.t = np.asarray(t).copy() 9 | self.K = np.asarray(K).copy() 10 | 11 | def update_after_crop(self, bbox): 12 | left, upper, right, lower = bbox 13 | 14 | cx, cy = self.K[0, 2], self.K[1, 2] 15 | 16 | new_cx = cx - left 17 | new_cy = cy - upper 18 | 19 | self.K[0, 2], self.K[1, 2] = new_cx, new_cy 20 | 21 | def update_after_resize(self, image_shape, new_image_shape): 22 | height, width = image_shape 23 | new_height, new_width = new_image_shape 24 | 25 | fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] 26 | 27 | new_fx = fx * (new_width / width) 28 | new_fy = fy * (new_height / height) 29 | new_cx = cx * (new_width / width) 30 | new_cy = cy * (new_height / height) 31 | 32 | self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] = new_fx, new_fy, new_cx, new_cy 33 | 34 | def projection(self): 35 | return self.K @ self.extrinsics() 36 | 37 | def extrinsics(self): 38 | return np.hstack([self.R, self.t]) 39 | 40 | def eulid_to_homo(points): 41 | """ 42 | points: (...,N,M) 43 | return: 
(...,N,M+1) 44 | """ 45 | if isinstance(points, np.ndarray): 46 | return np.concatenate([points, np.ones((*points.shape[:-1],1))], axis=-1) 47 | elif torch.is_tensor(points): 48 | return torch.cat([points, torch.ones((*points.shape[:-1],1),dtype=points.dtype,device=points.device)],dim=-1) 49 | else: 50 | raise TypeError("Works Only with numpy arrays and Pytorch tensors") 51 | 52 | def homo_to_eulid(points): 53 | """ 54 | points: (...,N,M+1) 55 | return: (...,N,M) 56 | """ 57 | if isinstance(points, np.ndarray): 58 | return points[...,:-1] / points[...,-1,None] 59 | elif torch.is_tensor(points): 60 | return points[...,:-1] / points[...,-1,None] 61 | else: 62 | raise TypeError("Works Only with numpy arrays and Pytorch tensors") 63 | 64 | 65 | -------------------------------------------------------------------------------- /v1/A.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import numpy as np\n", 11 | "from collections import defaultdict\n", 12 | "from tqdm import tqdm\n", 13 | "from utils import *\n", 14 | "\n", 15 | "import math\n", 16 | "\n", 17 | "from triangulation import ProbabilisticTriangulation,cal_mpjpe_batch\n", 18 | "from calibration import CalibrationBatch\n", 19 | "from torch.utils.data import Dataset, DataLoader\n", 20 | "\n", 21 | "from dataset import H36M, H36MPred\n", 22 | "import torch\n", 23 | "import torch.nn as nn\n", 24 | "import numpy as np\n", 25 | "import cv2\n", 26 | "from utils import *\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "from tqdm import tqdm" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 10, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "cfg = {\n", 38 | " \"nB\":32,\n", 39 | " \"nV\":4,\n", 40 | " \"M\":32,\n", 41 | " \"isDistr\": False,\n", 42 | " \"cube_min\": [-0.16, -0.22, 
0],\n", 43 | " \"cube_max\" : [0.22, 0.28, 0.21],\n", 44 | " \"cube_shape\": [64,64,32],\n", 45 | "}" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stderr", 55 | "output_type": "stream", 56 | "text": [ 57 | "69it [43:29, 37.76s/it]" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "h36m = H36MPred()\n", 63 | "h36mloader = DataLoader(h36m, batch_size = cfg[\"nB\"], shuffle = True)\n", 64 | "db = []\n", 65 | "for iter_i, batch in tqdm(enumerate(h36mloader)):\n", 66 | " calibr = CalibrationBatch(cfg,batch['pose_2d_pred'],batch['confidence'])\n", 67 | " weights_log = calibr.monte_carlo()\n", 68 | " R,t = calibr.prob_tri.getbuffer_Rt()\n", 69 | " for j in range(cfg[\"nB\"]):\n", 70 | " db.append({\n", 71 | " 'pose_3d' : batch['pose_3d'][j],\n", 72 | " 'pose_2d_pred': batch['pose_2d_pred'][j],\n", 73 | " 'Rpred': R.quan[:,j],\n", 74 | " 'tpred': t.vector[:,j],\n", 75 | " 'Rgt': batch['Rgt'][j],\n", 76 | " 'tgt': batch['tgt'][j],\n", 77 | " })" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "np.save('after_monte.npy', np.asarray(db))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 7, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "data": { 96 | "text/plain": [ 97 | "[,\n", 98 | " ,\n", 99 | " ,\n", 100 | " ,\n", 101 | " ,\n", 102 | " ,\n", 103 | " ,\n", 104 | " ,\n", 105 | " ,\n", 106 | " ,\n", 107 | " ,\n", 108 | " ,\n", 109 | " ,\n", 110 | " ,\n", 111 | " ,\n", 112 | " ,\n", 113 | " ,\n", 114 | " ,\n", 115 | " ,\n", 116 | " ,\n", 117 | " ,\n", 118 | " ,\n", 119 | " ,\n", 120 | " ,\n", 121 | " ,\n", 122 | " ,\n", 123 | " ,\n", 124 | " ,\n", 125 | " ,\n", 126 | " ,\n", 127 | " ,\n", 128 | " ]" 129 | ] 130 | }, 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | }, 135 | { 136 | "data": { 137 | "image/png": "\n", 138 | 
"text/plain": [ 139 | "
" 140 | ] 141 | }, 142 | "metadata": {}, 143 | "output_type": "display_data" 144 | } 145 | ], 146 | "source": [ 147 | "draw_db = torch.stack(weights_log, dim=0)[...,0]\n", 148 | "plt.plot(draw_db)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "class Cube(nn.Module):\n", 158 | " def __init__(self, cfg):\n", 159 | " super().__init__()\n", 160 | " self.cube_min = torch.tensor(cfg[\"cube_min\"])\n", 161 | " self.cube_max = torch.tensor(cfg[\"cube_max\"])\n", 162 | " self.nB = cfg['nB']\n", 163 | " self.cube_shape = cfg['cube_shape']\n", 164 | " cube = []\n", 165 | " for i in range(self.cube_shape[0]):\n", 166 | " for j in range(self.cube_shape[1]):\n", 167 | " for k in range(self.cube_shape[2]):\n", 168 | " cube.append([i,j,k])\n", 169 | " # (B,Np,3); convert the Python list to a tensor before repeating it per batch\n", 170 | " self.cube = torch.tensor(cube).repeat(self.nB,1,1)\n", 171 | " self.cube = self.cube * (self.cube_max - self.cube_min)[None,None] + self.cube_min[None]\n", 172 | "\n", 173 | " def __call__(self, x, R, t, weights = None):\n", 174 | " \"\"\"\n", 175 | " Args:\n", 176 | " x : (B,V,J,2)\n", 177 | " R : ((M),B,V,*)\n", 178 | " t : ((M),B,V,*)\n", 179 | " weights: (M,B)\n", 180 | " \n", 181 | " \"\"\"\n", 182 | " if weights is not None:\n", 183 | " # (M,B,V,1,3,3)@(1,B,1,Np,3,1) + (M,B,V,1,3,1) -> (M,B,V,Np,3,1) -> (M,B,V,Np,2)\n", 184 | " reproj = homo_to_eulid( (R.matrix[:,:,:,None] @ self.cube[None,:,None,:,:,None] + t.trans[:,:,:,None]).squeeze(-1))\n", 185 | " # ((M,B,V,1,Np,2) - (1,B,V,J,1,2)) * (M,B,1,1,1,1) -> (M,B,V,J,Np,2) -> (B,V,J,Np,2)\n", 186 | " # (B,V,J,Np,2)**2 -> (B,V,J,Np,2) ** 2 -> (B,J,Np)\n", 187 | " mpjpe = (\n", 188 | " (reproj[:,:,:,None] - x[None,...,None,:])**2 * weights[...,None,None,None,None]\n", 189 | " ).sum(0) / weights.sum(0)[...,None,None,None,None]\n", 190 | " \n", 191 | " \n", 192 | " else:\n", 193 | " # (B,V,1,3,3)@(B,1,Np,3,1) + (B,V,1,3,1) -> (B,V,Np,3,1) 
-> (B,V,Np,2)\n", 194 | " reproj = homo_to_eulid( (R.matrix[:,:,None] @ self.cube[:,None,:,:,None] + t.trans[:,:,None]).squeeze(-1))\n", 195 | " # ((B,V,1,Np,2) - (B,V,J,1,2))**2 -> (B,V,J,Np,2) ** 2 -> (B,J,Np)\n", 196 | " mpjpe = (reproj[:,:,None] - x[...,None,:])**2\n", 197 | " \n", 198 | " mpjpe = mpjpe.sum(-4).sum(-1)\n", 199 | " mpjpe = mpjpe.view(self.nB,-1,*self.cube_shape)\n", 200 | " return torch.exp(- mpjpe)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [] 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "Python 3 (ipykernel)", 214 | "language": "python", 215 | "name": "python3" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.8.17" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 2 232 | } 233 | -------------------------------------------------------------------------------- /v1/calibration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from triangulation import ProbabilisticTriangulation 3 | import cv2 4 | import numpy as np 5 | from utils import * 6 | 7 | class CalibrationBatch(): 8 | def __init__(self,cfg, points2d, confi2d): 9 | """ 10 | points2d : (B,V,J,2) 11 | confi2d : (B,V,J) 12 | points3d : (B,J,3) 13 | confi3d : (B,J) 14 | R : (B,V,3,3) 15 | t : (B,V,3,1) 16 | isdistribution : bool 17 | """ 18 | self.n_batch,self.n_view,self.n_joint = points2d.shape[:3] 19 | self.M = cfg["M"] 20 | self.points2d = points2d 21 | self.confi2d = confi2d 22 | self.points3d = torch.zeros((self.n_batch,self.n_joint,3)) 23 | self.confi3d = torch.zeros((self.n_batch,self.n_joint)) 24 | self.R = 
torch.zeros((self.n_batch,self.n_view,3,3)) 25 | self.t = torch.zeros((self.n_batch,self.n_view,3,1)) 26 | self.prob_tri = ProbabilisticTriangulation(cfg) 27 | self.buffer_weights = None 28 | 29 | 30 | def weighted_triangulation(self, points2d, confi2d, R ,t): 31 | """ 32 | Args: 33 | points2d : (V',J,2) 34 | confi2d : (V',J) 35 | R : (V',3,3) 36 | t : (V',3,1) 37 | Returns: 38 | points3d : (J,3) 39 | confi3d : (J) 40 | """ 41 | n_view_filter= points2d.shape[0] 42 | points3d = torch.zeros((self.n_joint, 3)) 43 | confi3d = torch.zeros((self.n_joint)) 44 | # print(points2d.shape,confi2d.shape,R.shape,t.shape) 45 | for j in range(self.n_joint): 46 | A = [] 47 | for i in range(n_view_filter): 48 | if confi2d[i,j] > 0.5: 49 | P = torch.cat([R[i],t[i]],dim=1) 50 | P3T = P[2] 51 | A.append(confi2d[i,j] * (points2d[i,j,0]*P3T - P[0])) 52 | A.append(confi2d[i,j] * (points2d[i,j,1]*P3T - P[1])) 53 | # stack only when there are enough rows; torch.stack raises on an empty list 54 | if len(A) >= 4: 55 | A = torch.stack(A) 56 | u, s, vh = torch.linalg.svd(A) 57 | error = s[-1] 58 | X = vh[len(s) - 1] 59 | points3d[j,:] = X[:3] / X[3] 60 | confi3d[j] = torch.exp(-torch.abs(error)) 61 | else: 62 | points3d[j,:] = torch.tensor([0.0,0.0,0.0]) 63 | confi3d[j] = 0 64 | 65 | return points3d, confi3d 66 | 67 | def weighted_triangulation_sample(self, points2d, confi2d, R ,t): 68 | """ 69 | Args: 70 | points2d : (B,V',J,2) 71 | confi2d : (B,V',J) 72 | R : ((M),B, V',3,3) 73 | t : ((M),B, V',3,1) 74 | Returns: 75 | sample_points3d : ((M),B,J,3) 76 | sample_confi3d : ((M),B,J) 77 | """ 78 | if len(R.matrix.shape[:-2]) > len(points2d.shape[:-2]): 79 | nM = R.matrix.shape[0] 80 | sample_points3d = torch.zeros((nM,self.n_batch,self.n_joint,3)) 81 | sample_confi3d = torch.zeros((nM,self.n_batch,self.n_joint)) 82 | for i in range(nM): 83 | for j in range(self.n_batch): 84 | sample_points3d[i,j], sample_confi3d[i,j] = self.weighted_triangulation( 85 | points2d[j], confi2d[j], R.matrix[i,j], t.trans[i,j] 86 | ) 87 | return sample_points3d, sample_confi3d 
88 | 89 | else: 90 | sample_points3d = torch.zeros((self.n_batch,self.n_joint,3)) 91 | sample_confi3d = torch.zeros((self.n_batch,self.n_joint)) 92 | 93 | for j in range(self.n_batch): 94 | sample_points3d[j], sample_confi3d[j] = self.weighted_triangulation( 95 | points2d[j], confi2d[j], R.matrix[j], t.trans[j] 96 | ) 97 | 98 | return sample_points3d, sample_confi3d 99 | 100 | def pnp(self,batch_id): 101 | self.R[batch_id,0] = torch.eye(3) 102 | self.t[batch_id,0] = torch.zeros((3,1)) 103 | for i in range(1, self.n_view): 104 | mask = torch.logical_and(self.confi2d[batch_id,i]>0.8,self.confi3d[batch_id]>0.8) 105 | p2d = self.points2d[batch_id,i,mask].numpy() 106 | p3d = self.points3d[batch_id,mask].numpy() 107 | ret, rvec, tvec = cv2.solvePnP(p3d, p2d, np.eye(3), np.zeros(5)) 108 | R, _ = cv2.Rodrigues(rvec) 109 | self.R[batch_id,i] = torch.tensor(R) 110 | self.t[batch_id,i] = torch.tensor(tvec) 111 | 112 | 113 | def eight_point(self): 114 | for batch_id in range(self.n_batch): 115 | mask = torch.logical_and(self.confi2d[batch_id,0]>0.5, self.confi2d[batch_id,1]>0.5) 116 | 117 | p0 = self.points2d[batch_id,0,mask].numpy() 118 | p1 = self.points2d[batch_id,1,mask].numpy() 119 | # p0,p1 (N,2) 120 | # print(p0.shape,p1.shape) 121 | E, mask = cv2.findEssentialMat(p0, p1, focal=1.0, pp=(0., 0.), 122 | method=cv2.RANSAC, prob=0.999, threshold=0.0003) 123 | p0_inliers = p0[mask.ravel() == 1] 124 | p1_inliers = p1[mask.ravel() == 1] 125 | point, R, t,mask = cv2.recoverPose(E, p0_inliers, p1_inliers) 126 | self.R[batch_id,0],self.t[batch_id,0] = torch.eye(3), torch.zeros((3,1)) 127 | self.R[batch_id,1],self.t[batch_id,1] = torch.tensor(R),torch.tensor(t) 128 | 129 | # print(self.R[batch_id,1],self.t[batch_id,1]) 130 | 131 | self.points3d[batch_id], self.confi3d[batch_id] = self.weighted_triangulation( 132 | self.points2d[batch_id,:2],self.confi2d[batch_id,:2],self.R[batch_id,:2],self.t[batch_id,:2] 133 | ) 134 | 135 | self.pnp(batch_id) 136 | 137 | # 
print(self.R[batch_id,0],self.t[batch_id,0]) 138 | # print(self.mpjpe(2)) 139 | # print(self.confi3d[batch_id]) 140 | 141 | self.points3d[batch_id], self.confi3d[batch_id] = self.weighted_triangulation( 142 | self.points2d[batch_id],self.confi2d[batch_id],self.R[batch_id],self.t[batch_id] 143 | ) 144 | # print(self.confi3d[batch_id]) 145 | # print(self.mpjpe(self.n_view)) 146 | 147 | def monte_carlo(self): 148 | self.eight_point() 149 | self.prob_tri.update_paramater_init(self.points3d, self.points2d ,self.R,self.t) 150 | 151 | weights_log = [] 152 | for i in range(4): 153 | rot,t = self.prob_tri.sample(self.M) 154 | # print(self.points2d.shape, self.confi2d.shape, rot.quan.shape, t.vector.shape) 155 | sample_points3d, sample_confi3d = self.weighted_triangulation_sample(self.points2d, self.confi2d, rot, t) 156 | weights = cal_mpjpe_batch(sample_points3d, self.points2d, rot, t) 157 | weights_log.append(weights) 158 | self.prob_tri.update_paramater_with_weights(rot,t,weights) 159 | 160 | rot,t = self.prob_tri.getbuffer_Rt() 161 | # print(self.points2d.shape, self.confi2d.shape, rot.quan.shape, t.vector.shape) 162 | points3d, confi3d = self.weighted_triangulation_sample(self.points2d, self.confi2d, rot, t) 163 | self.buffer_weights = cal_mpjpe_batch(points3d,self.points2d,rot,t) 164 | 165 | rot,t = self.prob_tri.getbest_Rt() 166 | self.R = rot.matrix 167 | self.t = t.trans 168 | return weights_log 169 | 170 | def mpjpe(self, n_view_filter): 171 | return (homo_to_eulid((self.R[...,:n_view_filter,None,:,:] @ self.points3d[...,None,:,:,None] + self.t[...,:n_view_filter,None,:,:]).squeeze(-1)) - self.points2d[:,:n_view_filter] ).mean() 172 | 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /v1/distribution.py: -------------------------------------------------------------------------------- 1 | from pyro.distributions import TorchDistribution,constraints 2 | from pyro.distributions import MultivariateStudentT 3 | from 
torch.distributions.multivariate_normal import _batch_mahalanobis, _standard_normal, _batch_mv 4 | 5 | from pyro.distributions.util import broadcast_shape 6 | import torch 7 | import math 8 | 9 | 10 | 11 | def cholesky_wrapper(mat, default_diag=None, force_cpu=True): 12 | device = mat.device 13 | if force_cpu: 14 | mat = mat.cpu() 15 | try: 16 | tril = torch.linalg.cholesky(mat, upper=False) 17 | except RuntimeError: 18 | n_dims = mat.size(-1) 19 | tril = [] 20 | default_tril_single = torch.diag(mat.new_tensor(default_diag)) if default_diag is not None \ 21 | else torch.eye(n_dims, dtype=mat.dtype, device=mat.device) 22 | for cov in mat.reshape(-1, n_dims, n_dims): 23 | try: 24 | tril.append(torch.linalg.cholesky(cov, upper=False)) 25 | except RuntimeError: 26 | tril.append(default_tril_single) 27 | tril = torch.stack(tril, dim=0).reshape(mat.shape) 28 | return tril.to(device) 29 | 30 | 31 | class AngularCentralGaussian(TorchDistribution): 32 | arg_constraints = {'scale_tril': constraints.lower_cholesky} 33 | has_rsample = True 34 | 35 | def __init__(self, scale_tril, validate_args=None, eps=1e-6): 36 | q = scale_tril.size(-1) 37 | assert q > 1 38 | assert scale_tril.shape[-2:] == (q, q) 39 | batch_shape = scale_tril.shape[:-2] 40 | event_shape = (q,) 41 | self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) 42 | self._unbroadcasted_scale_tril = scale_tril 43 | self.q = q 44 | self.area = 2 * math.pi ** (0.5 * q) / math.gamma(0.5 * q) 45 | self.eps = eps 46 | super().__init__(batch_shape, event_shape, validate_args=validate_args) 47 | 48 | def log_prob(self, value): 49 | if self._validate_args: 50 | self._validate_sample(value) 51 | value = value.expand( 52 | broadcast_shape(value.shape[:-1], self._unbroadcasted_scale_tril.shape[:-2]) 53 | + self.event_shape) 54 | M = _batch_mahalanobis(self._unbroadcasted_scale_tril, value) 55 | half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) 56 | return M.log() * (-self.q / 2) - 
half_log_det - math.log(self.area) 57 | 58 | def rsample(self, sample_shape=torch.Size()): 59 | shape = self._extended_shape(sample_shape) 60 | normal = _standard_normal(shape, 61 | dtype=self._unbroadcasted_scale_tril.dtype, 62 | device=self._unbroadcasted_scale_tril.device) 63 | gaussian_samples = _batch_mv(self._unbroadcasted_scale_tril, normal) 64 | gaussian_samples_norm = gaussian_samples.norm(dim=-1) 65 | samples = gaussian_samples / gaussian_samples_norm.unsqueeze(-1) 66 | samples[gaussian_samples_norm < self.eps] = samples.new_tensor( 67 | [1.] + [0. for _ in range(self.q - 1)]) 68 | return samples 69 | 70 | 71 | class AngularCentralGaussianMultiView(): 72 | def __init__(self, scale_tril, df): 73 | self.scale_tril = scale_tril 74 | self.nV = scale_tril.shape[-3] 75 | self.distri = [AngularCentralGaussian(scale_tril[...,i,:,:]) for i in range(self.nV)] 76 | 77 | def __call__(self,sample_shape=torch.Size()): 78 | return torch.stack([self.distri[i](sample_shape) for i in range(self.nV)],dim=-2) 79 | 80 | class MultivariateStudentTMultiView(): 81 | def __init__(self, loc, scale_tril, df): 82 | self.loc = loc 83 | self.scale_tril = scale_tril 84 | self.df = df 85 | self.nV = scale_tril.shape[-3] 86 | self.distri = [ MultivariateStudentT(loc=self.loc[...,i,:], scale_tril = scale_tril[...,i,:,:],df=self.df) for i in range(self.nV)] 87 | 88 | def __call__(self,sample_shape=torch.Size()): 89 | return torch.stack([self.distri[i](sample_shape) for i in range(self.nV)],dim=-2) -------------------------------------------------------------------------------- /v1/triangulation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from distribution import AngularCentralGaussian, cholesky_wrapper 3 | from pyro.distributions import MultivariateStudentT 4 | from utils import * 5 | import torch.nn.functional as F 6 | 7 | class ProbabilisticTriangulation(): 8 | def __init__(self, cfg): 9 | """ 10 | Members: 11 | expect_quan: 
Rotation (B,V,*) 12 | tril_R: (B,V-1,4,4) 13 | mu_t: Translation (B,V,*) 14 | tril_t: (B,V-1,3,3) 15 | """ 16 | self.nB = cfg["nB"] 17 | self.nV = cfg["nV"] 18 | self.M = cfg["M"] 19 | self.isDistr = cfg["isDistr"] 20 | 21 | if self.isDistr: 22 | self.expect_quan = Rotation(torch.tensor([1.,0.,0.,0.]).repeat(self.nB,self.nV,1)) 23 | self.mu_t = Translation(torch.zeros(self.nB,self.nV,3)) 24 | 25 | self.tril_R = torch.eye(4,4).repeat(self.nB, self.nV-1, 1, 1) 26 | self.tril_t = torch.eye(3,3).repeat(self.nB, self.nV-1, 1, 1) 27 | # conv_quan (B,V,4,4) 28 | self.distrR = AngularCentralGaussian(self.tril_R) 29 | # mu_t (B,V,3) conv_t (B,V,3,3) 30 | self.distrT = MultivariateStudentT(loc=self.mu_t.distr_norm(),scale_tril=self.tril_t,df=3) 31 | 32 | self.bufferR, self.bufferT = None,None 33 | 34 | else: 35 | self.bufferR = Rotation(torch.randn(self.M//8,self.nB, self.nV-1, 4)) 36 | self.bufferT = Translation(torch.randn(self.M//8,self.nB, self.nV-1, 3)) 37 | 38 | self.lr = 1e-2 39 | 40 | def sample(self, nM): 41 | 42 | if self.isDistr: 43 | if self.bufferR is not None: 44 | nM -= self.bufferR.quan.shape[-4] 45 | rot = Rotation(self.distrR((nM,))) 46 | t = Translation(self.distrT((nM,))) 47 | if self.bufferR is not None: 48 | # print(rot.quan.shape, self.bufferR.quan.shape) 49 | rot.cat(self.bufferR) 50 | t.cat(self.bufferT) 51 | return rot,t 52 | 53 | else: 54 | buffer_nM = self.bufferR.quan.shape[-4] 55 | 56 | temp_buffer_quan = self.bufferR.quan[...,1:,:].repeat(nM//buffer_nM, 1, 1, 1) 57 | rot = Rotation( 58 | temp_buffer_quan + torch.randn_like(temp_buffer_quan) * self.lr 59 | ) 60 | 61 | temp_buffer_vector = self.bufferT.vector[...,1:,:].repeat(nM//buffer_nM, 1, 1, 1) 62 | t = Translation( 63 | temp_buffer_vector + torch.randn_like(temp_buffer_vector) * self.lr 64 | ) 65 | self.lr *= 0.1 66 | rot.random(lr = self.lr) 67 | t.random(lr = self.lr*10) 68 | return rot,t 69 | 70 | 71 | def update_paramater_init(self,points3d,points2d, rot,t): 72 | """ 73 | Args: 74 | rot 
Tensor -> Rotation: (B,V,3,3) 75 | t Tensor -> Translation: (B,V,3,1) 76 | Returns: 77 | sample_quan : (M,B,V,4) 78 | sample_t : (M,B,V,3) 79 | weights: (M,B) 80 | """ 81 | 82 | self.lr = 1e-3 83 | self.bufferR = Rotation(rot.repeat(self.M//8,1,1,1,1)) 84 | self.bufferT = Translation(t.repeat(self.M//8,1,1,1,1)) 85 | self.bufferR.random(self.lr) 86 | self.bufferT.random(self.lr*10) 87 | rot, t = self.sample(self.M) 88 | 89 | # weights = torch.cat([ 90 | # torch.ones(self.M-1,self.nB) * (0.5/(self.M-1)), 91 | # torch.ones(1,self.nB)*0.5, 92 | # ], dim = 0) 93 | weights = cal_mpjpe_batch(points3d,points2d, rot,t) 94 | self.update_paramater_with_weights(rot, t, weights) 95 | 96 | def update_paramater_with_weights(self,rot,t, weights): 97 | """ 98 | Args: 99 | rot : (M,B,V,*) 100 | t : (M,B,V,*) 101 | weights : (M,B) 102 | Returns: 103 | conv_quan : (B,V,4,4) 104 | mu_t : (B,V,3) 105 | conv_t : (B,V,3,3) 106 | """ 107 | 108 | 109 | topk_weight, indices = torch.topk(weights, self.M//8, dim=0) 110 | # indices (M/2, B) -> (M/2,B,V,*) 111 | indices = indices[...,None,None] 112 | half_quan = rot.quan.gather(0, indices.expand(-1,-1,self.nV,4)) 113 | half_vector = t.vector.gather(0, indices.expand(-1,-1,self.nV,3)) 114 | 115 | rot = Rotation(half_quan) 116 | t = Translation(half_vector) 117 | weights = topk_weight 118 | 119 | # (M,B,V,4) * (M,B,1,1) -> (B,V,4) 120 | self.expect_quan = Rotation( 121 | (rot.quan * weights[...,None,None]).sum(0) / weights.sum(0)[...,None,None] 122 | ) 123 | # (M,B,V,3) * (M,B) -> (B,V,3) 124 | self.mu_t = Translation( 125 | (t.vector * weights[...,None,None]).sum(0) / weights.sum(0)[...,None,None] 126 | ) 127 | 128 | if self.isDistr: 129 | # (B,V-1,4,M) @ (B,V-1,M,4) -> (B,V-1,4,4) 130 | conv_quan = ( 131 | rot.distr_norm().permute(1,2,3,0) @ (rot.distr_norm() * weights[...,None,None]).permute(1,2,0,3) 132 | ) / weights.sum(0)[...,None,None,None] 133 | 134 | # u,s,vt = torch.linalg.svd(conv_quan) 135 | # s *= 
torch.tensor([1,0.1,0.01,0.001])[None,None] 136 | # conv_quan = u @ torch.diag_embed(s) @ vt 137 | 138 | self.tril_quan = cholesky_wrapper(conv_quan) 139 | 140 | # (M,B,V,3) - (1,B,V,3) -> (M,B,V,3) -> (M,B ,V,3) 141 | centered_t = Translation(t.vector - self.mu_t.vector[None]).distr_norm() 142 | # (B,V-1,3,M) @ (B,V-1,M,3) -> (B,V-1,3,3) 143 | conv_t = ( 144 | centered_t.permute(1,2,3,0) @ (centered_t * weights[...,None,None]).permute(1,2,0,3) 145 | ) / weights.sum(0)[...,None,None,None] 146 | 147 | # u,s,vt = torch.linalg.svd(conv_t) 148 | # s *= torch.tensor([1,0.1,0.01])[None,None] 149 | # conv_t = u @ torch.diag_embed(s) @ vt 150 | 151 | self.tril_t = cholesky_wrapper(conv_t) 152 | 153 | self.distrR = AngularCentralGaussian(self.tril_quan) 154 | self.distrT = MultivariateStudentT(loc=self.mu_t.distr_norm(),scale_tril=self.tril_t,df = 3) 155 | 156 | 157 | self.bufferR = Rotation(half_quan) 158 | self.bufferT = Translation(half_vector) 159 | 160 | 161 | def getbest_Rt(self): 162 | return Rotation(self.bufferR.quan[0]), Translation(self.bufferT.vector[0]) 163 | 164 | def getbuffer_Rt(self): 165 | return self.bufferR, self.bufferT 166 | 167 | 168 | -------------------------------------------------------------------------------- /v1/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import torch.nn.functional as F 5 | 6 | JOINT_LINKS = [(0,1),(0,2),(1,3),(2,4),(0,5),(0,6),(5,7),(7,9),(6,8),(8,10),(5,11),(6,12),(11,12),(11,13),(13,15),(12,14),(14,16)] 7 | 8 | 9 | def cal_mpjpe(points3d, points2d, R, t): 10 | """ 11 | Args: 12 | points3d : (...,J,3) 13 | points2d : (...,V,J,2) 14 | R : (...,V,3,3) 15 | t: (...,V,3,1) 16 | """ 17 | return (homo_to_eulid((R[...,None,:,:] @ points3d[...,None,:,:,None] + t[...,None,:,:]).squeeze(-1)) - points2d ).mean() 18 | 19 | 20 | def cal_mpjpe_batch(points3d, points2d, rot, t): 21 | """ 22 | Args: 23 | points3d : ((M),...,J,3) 24 | 
points2d : (...,V,J,2) 25 | rot : Rotation ((M),...,V,*) 26 | t: Translation ((M),...,V,*) 27 | Returns: 28 | weights : (...) or ((M),...) 29 | """ 30 | if(len(rot.quan.shape[:-1]) > len(points2d.shape[:-2])): 31 | return torch.pow(torch.exp( 32 | -( 33 | homo_to_eulid( 34 | (rot.matrix[...,None,:,:] @ points3d[...,None,:,:,None] + t.trans[...,None,:,:] 35 | ).squeeze(-1) 36 | ) - points2d[None] 37 | ).norm(dim=-1).mean((-1,-2)) 38 | ), 4) 39 | else: 40 | return torch.exp( 41 | -( 42 | homo_to_eulid( 43 | (rot.matrix[...,None,:,:] @ points3d[...,None,:,:,None] + t.trans[...,None,:,:] 44 | ).squeeze(-1) 45 | ) - points2d 46 | ).norm(dim=-1).mean((-1,-2)) 47 | ) 48 | 49 | def eulid_to_homo(points): 50 | """ 51 | points: (...,N,M) 52 | return: (...,N,M+1) 53 | """ 54 | if isinstance(points, np.ndarray): 55 | return np.concatenate([points, np.ones((*points.shape[:-1],1))], axis=-1) 56 | elif torch.is_tensor(points): 57 | return torch.cat([points, torch.ones((*points.shape[:-1],1),dtype=points.dtype,device=points.device)],dim=-1) 58 | else: 59 | raise TypeError("Works Only with numpy arrays and Pytorch tensors") 60 | 61 | def homo_to_eulid(points): 62 | """ 63 | points: (...,N,M+1) 64 | return: (...,N,M) 65 | """ 66 | if isinstance(points, np.ndarray): 67 | return points[...,:-1] / points[...,-1,None] 68 | elif torch.is_tensor(points): 69 | return points[...,:-1] / points[...,-1,None] 70 | else: 71 | raise TypeError("Works Only with numpy arrays and Pytorch tensors") 72 | 73 | def calIOU(b1,b2): 74 | """ 75 | Input: 76 | b1,b2: [x1,y1,x2,y2] 77 | """ 78 | s1 = (b1[2] - b1[0]) * (b1[3]-b1[1]) 79 | s2 = (b2[2] - b2[0]) * (b2[3]-b2[1]) 80 | a = max(0,min(b1[2],b2[2]) - max(b1[0],b2[0])) * max(0,min(b1[3],b2[3]) - max(b1[1],b2[1])) 81 | return a/(s1+s2-a) 82 | 83 | 84 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 85 | """ 86 | Convert rotations given as quaternions to rotation matrices. 
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))

def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret

def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)

    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))

class Rotation():
    def __init__(self, rot):
        """
        quan ((M),B,V,4)
        matrix ((M),B,V,3,3)
        """
        assert((rot.shape[-1] == 4) or (rot.shape[-1] == 3 and rot.shape[-2] == 3))

        if rot.shape[-1] == 4:
            rot = self.standard_quan(rot)
            self.quan = rot
            self.matrix = quaternion_to_matrix(self.quan)

        else:
            self.matrix = rot
            self.quan = matrix_to_quaternion(self.matrix)

    def cat(self,rot):
        # Append rotations along the sample dimension M.
        assert(isinstance(rot, (Rotation, torch.Tensor)))
        if isinstance(rot, torch.Tensor):
            rot = Rotation(rot)
        assert(rot.quan.shape[-3:] == self.quan.shape[-3:])
        if len(self.quan.shape) > len(rot.quan.shape):
            self.quan = torch.cat([self.quan, rot.quan[None]], dim=-4)
            self.matrix = torch.cat([self.matrix, rot.matrix[None]], dim=-5)
        else:
            self.quan = torch.cat([self.quan, rot.quan], dim=-4)
            self.matrix = torch.cat([self.matrix, rot.matrix], dim=-5)

    def distr_norm(self):
        # Quaternions of every view except the first (reference) view.
        return self.quan[...,1:,:]

    def standard_quan(self,rot):
        # Normalise to unit quaternions; prepend the identity quaternion as the
        # first view unless the first view already is the identity.
        rot = rot / torch.clamp_(rot.norm(dim = -1)[...,None],min=1e-4)
        size = rot.shape
        # Match dtype and device of the input so the comparison and the
        # concatenation below also work for non-default dtypes and on GPU.
        rot0 = torch.tensor([1,0,0,0],dtype=rot.dtype,device=rot.device).repeat(*(size[:-2]),1,1)
        if F.l1_loss(rot0,rot[...,0:1,:]) < 1e-6:
            return rot
        else:
            return torch.cat([rot0, rot],dim=-2)

    def random(self, lr):
        # Perturb every sample except the first and every view except the first
        # with Gaussian noise scaled by lr, then re-standardise.
        assert(len(self.quan.shape) == 4)
        self.quan[1:,:,1:] += torch.randn_like(self.quan[1:,:,1:]) * lr
        self.quan = self.standard_quan(self.quan)
        self.matrix = quaternion_to_matrix(self.quan)




class Translation():
    def __init__(self, t):
        """
        vector ((M),B,V,3)
        trans ((M),B,V,3,1)
        """
        assert(t.shape[-1] == 3 or (t.shape[-1]==1 and t.shape[-2]==3))
        if t.shape[-1] == 3:
            t = self.standard_vector(t)
            self.vector = t
            self.trans = t.unsqueeze(-1)
        else:
            self.vector = t.squeeze(-1)
            self.trans = t
            # (M,B)
            t_norm = self.vector[...,1,:].norm(dim=-1)
            if not (t_norm == 0.0).any():
                self.vector[...,1:,:] /= t_norm[...,None,None]
            self.trans = self.vector.unsqueeze(-1)

    def cat(self,t):
        # Append translations along the sample dimension M.
        assert(isinstance(t, (Translation, torch.Tensor)))
        if isinstance(t ,torch.Tensor):
            t = Translation(t)
        assert(t.trans.shape[-3:] == self.trans.shape[-3:])
        if len(self.vector.shape) > len(t.vector.shape):
            self.vector = torch.cat([self.vector, t.vector[None]],dim=-4)
            self.trans = torch.cat([self.trans, t.trans[None]], dim=-5)
        else:
            self.vector = torch.cat([self.vector, t.vector],dim=-4)
            self.trans = torch.cat([self.trans, t.trans], dim=-5)

    def distr_norm(self):
        # Translations of every view except the first (reference) view.
        return self.vector[...,1:,:]

    def standard_vector(self,t):
        # Scale so the first non-reference view has unit length; prepend a zero
        # translation as the first view unless the first view already is zero.
        size = t.shape
        # Match dtype and device of the input so the comparison and the
        # concatenation below also work for non-default dtypes and on GPU.
        t0 = torch.tensor([0,0,0],dtype=t.dtype,device=t.device).repeat(*(size[:-2]),1,1)
        if F.l1_loss(t0,t[...,0:1,:]) < 1e-6:
            return t / torch.clamp_(t[...,1,:].norm(dim = -1)[...,None,None],min=1e-4)
        else:
            return torch.cat([t0, t/torch.clamp_(t[...,0,:].norm(dim = -1)[...,None,None],min=1e-4)],dim=-2)

    def random(self, lr):
        # Perturb every sample except the first and every view except the first
        # with Gaussian noise scaled by lr, then re-standardise.
        assert(len(self.vector.shape) == 4)
        self.vector[1:,:,1:] += torch.randn_like(self.vector[1:,:,1:]) * lr
        self.vector = self.standard_vector(self.vector)
        self.trans = self.vector.unsqueeze(-1)
--------------------------------------------------------------------------------
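A minimal, dependency-free sketch of two conventions that `v1/utils.py` relies on: quaternions are stored real part first, and homogeneous points are converted back by dividing by the last coordinate. The names `quat_to_matrix`, `apply`, and `homo_to_euclid` are hypothetical stand-ins written in plain Python for illustration only; they are not part of the repository, which operates on batched `torch` tensors.

```python
import math


def quat_to_matrix(q):
    """Real-part-first quaternion (r, i, j, k) -> 3x3 rotation matrix as nested lists."""
    r, i, j, k = q
    # Dividing by the squared norm means non-unit quaternions give the same matrix.
    two_s = 2.0 / (r * r + i * i + j * j + k * k)
    return [
        [1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r)],
        [two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r)],
        [two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j)],
    ]


def apply(matrix, v):
    """Multiply a 3x3 matrix (nested lists) by a 3-vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in matrix]


def homo_to_euclid(p):
    """Drop the last coordinate after dividing the others by it."""
    return [c / p[-1] for c in p[:-1]]


# A 90-degree rotation about the z axis: q = (cos(45 deg), 0, 0, sin(45 deg)).
half_angle = math.pi / 4
q = (math.cos(half_angle), 0.0, 0.0, math.sin(half_angle))
rotated = apply(quat_to_matrix(q), [1.0, 0.0, 0.0])  # x axis is mapped onto the y axis

# Scaling the quaternion does not change the resulting rotation.
scaled = tuple(2.0 * c for c in q)
rotated_scaled = apply(quat_to_matrix(scaled), [1.0, 0.0, 0.0])

# Homogeneous (2, 4, 2) corresponds to the Euclidean point (1, 2).
pixel = homo_to_euclid([2.0, 4.0, 2.0])
```

The formula mirrors the element order of `quaternion_to_matrix`, so the same check can be repeated against the tensor version when `torch` is available.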