├── README.md ├── experiments └── h36m │ ├── resnet50_pos.yaml │ ├── vit_fcn.yaml │ ├── vit_fcn_fusion.yaml │ ├── vit_pos_trans_encoder.yaml │ └── vit_pos_trans_encoder_token.yaml ├── images ├── framework.png └── quali_eval.jpg ├── infer.py ├── lib ├── __init__.py ├── datasets │ ├── __init__.py │ ├── h36m.py │ ├── joints_dataset.py │ ├── mocap_dataset.py │ ├── mpii3d.py │ ├── multiview_h36m.py │ ├── multiview_mpii3d.py │ └── multiview_totalcapture.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── resnet.py │ │ ├── swin_transformer_v2.py │ │ └── vit.py │ ├── components │ │ ├── __init__.py │ │ ├── pose_transformer.py │ │ └── t_cond_mlp.py │ ├── discriminator.py │ ├── fusion.py │ ├── heads │ │ ├── __init__.py │ │ ├── position_encoding.py │ │ └── smpl_head.py │ ├── losses.py │ └── smpl_wrapper.py └── utils │ ├── __init__.py │ ├── config.py │ ├── geometry.py │ ├── img.py │ ├── log_utils.py │ ├── multiview.py │ ├── pose_utils.py │ ├── renderer.py │ ├── transforms.py │ ├── triangulate.py │ ├── vis.py │ └── zipreader.py ├── requirements.txt └── run.py /README.md: -------------------------------------------------------------------------------- 1 | # Human Mesh Recovery from Arbitrary Multi-view Images 2 | 3 | This repository contains the official implementation of our paper *Human Mesh Recovery from Arbitrary Multi-view Images*, by Xiaoben Li, Mancheng Meng, Ziyan Wu, Terrence Chen, Fan Yang\*, Dinggang Shen. [paper](https://arxiv.org/abs/2403.12434) 4 | ![arch](./images/framework.png) 5 | 6 | **Abstract:** Human mesh recovery from arbitrary multi-view images involves two characteristics: arbitrary camera poses and an arbitrary number of camera views. Because of this variability, designing a unified framework to tackle the task is challenging. The challenge can be summarized as the dilemma of simultaneously estimating arbitrary camera poses and recovering the human mesh from arbitrary multi-view images while maintaining flexibility. To solve this dilemma, we propose a divide-and-conquer framework for Unified Human Mesh Recovery (U-HMR) from arbitrary multi-view images. In particular, U-HMR consists of a decoupled structure, camera and body decoupling (CBD), and two main components: camera pose estimation (CPE) and arbitrary view fusion (AVF). As camera poses and the human body mesh are independent of each other, CBD splits their estimation into two sub-tasks handled by two individual sub-networks (i.e., CPE and AVF), so the two sub-tasks are disentangled. In CPE, since each camera pose is unrelated to the others, we adopt a shared MLP to process all views in parallel. In AVF, in order to fuse multi-view information and make the fusion operation independent of the number of views, we introduce a transformer decoder with a SMPL parameters query token to extract cross-view features for mesh recovery. To demonstrate the efficacy and flexibility of the proposed framework and the effect of each component, we conduct extensive experiments on three public datasets: Human3.6M, MPI-INF-3DHP, and TotalCapture. 7 | 8 | ## Installation 9 | 10 | ### Environment Setup 11 | 12 | - OS: Linux 13 | - Python: 3.9.12 14 | - PyTorch: 1.11.0 15 | 16 | You can create a conda environment using the following commands:
17 | 18 | ``` 19 | conda create --name U-HMR 20 | source activate U-HMR 21 | pip install -r requirements.txt 22 | ``` 23 | [Find the Docker environment here](https://pan.baidu.com/s/19eWya63THlsNeXJ_eDm6zg?pwd=7k7k) 24 | ### Data Setup 25 | 26 | - **Human3.6M**: We follow the [H36M-Toolbox](https://github.com/CHUNYUWANG/H36M-Toolbox.git) to process the Human3.6M dataset. The SMPL annotations are from [PyMAF](https://github.com/HongwenZhang/PyMAF). 27 | - **MPI-INF-3DHP**: We follow the preprocessing code provided with the dataset and the code from [SPIN](https://github.com/nkolot/SPIN). 28 | - **TotalCapture**: We follow the [TotalCapture-Toolbox](https://github.com/zhezh/TotalCapture-Toolbox) and the code provided with the dataset. 29 | - **Others**: Most of the other data are from [SPIN](https://github.com/nkolot/SPIN) and [4D-Humans](https://github.com/shubham-goel/4D-Humans). 30 | 31 | ### Pretrained models 32 | We provide a set of pretrained models via [BaiduDisk](https://pan.baidu.com/s/1rtV533AlhQ6PRq8u6YsGjA) (password: uhmr) and [OneDrive](https://1drv.ms/f/s!AqnMGeLS2QFOhbpt34u-GTmwNIpbRQ?e=aSWriS), as follows: 33 | - ResNet50 34 | - MLP 35 | - Independent tokens 36 | - Decoupling with MLP 37 | - Decoupling with Transformer decoder 38 | 39 | ## Model training and evaluation 40 | 41 | ### Code structure 42 | The general structure of the project is as follows. 43 | 1. experiments: config files that contain the arguments for each experiment 44 | 2. lib: core code, containing datasets, models and utils 45 | 3. logs: tensorboard log directory 46 | 4. output: training model output directory 47 | 5. run.py: training and testing code that calls the dataset and model functions 48 | ### Usage 49 | ``` 50 | python run.py --cfg_name cfg_file.yaml --dataset dataset_name 51 | ``` 52 | ### Model Inference 53 | ``` 54 | python infer.py --cfg_name cfg_file.yaml --image_dir ./test_data 55 | ``` 56 | - `--cfg_name`: path to the model configuration file 57 | - `--image_dir`: directory containing the test images 58 | 59 | ### Results 60 | 1. Tested on Human3.6M 61 | 62 | |Method|MPJPE (mm)|PA-MPJPE (mm)| 63 | |--|--|--| 64 | |Ours-SV (ResNet50)|52.9|42.5| 65 | |Ours-SV (ViT)|43.3|32.6| 66 | |Ours-MV (ResNet50)|36.3|28.3| 67 | |Ours-MV (ViT)|31.0|22.8| 68 | 69 | 2. Qualitative Results 70 | ![qual_eval](./images/quali_eval.jpg) 71 | ## Acknowledgement 72 | 73 | This work could not have been done without contributions from other open-source projects, and we sincerely thank the contributors of those projects. 74 | 75 | - [SPIN](https://github.com/nkolot/SPIN) 76 | - [PyMAF](https://github.com/HongwenZhang/PyMAF) 77 | - [4D-Humans](https://github.com/shubham-goel/4D-Humans) 78 | 79 | 80 | ## Citation 81 | If you find this repository useful, please kindly cite our paper. 
82 | ``` 83 | @misc{li2024uhmr, 84 | title={Human Mesh Recovery from Arbitrary Multi-view Images}, 85 | author={Xiaoben Li and Mancheng Meng and Ziyan Wu and Terrence Chen and Fan Yang and Dinggang Shen}, 86 | year={2024}, 87 | eprint={2403.12434}, 88 | archivePrefix={arXiv} 89 | } 90 | 91 | ``` 92 | -------------------------------------------------------------------------------- /experiments/h36m/resnet50_pos.yaml: -------------------------------------------------------------------------------- 1 | IS_TRAIN: True 2 | GPUS: '0' 3 | OUTPUT_DIR: output 4 | LOG_DIR: logs 5 | DATASET: 6 | COLOR_RGB: true 7 | SOURCE: h36m 8 | TRAIN_DATASET: multiview_h36m 9 | TEST_DATASET: multiview_h36m 10 | ROOT: '/public_bme/data/XiaobenLi/HPE/mocap_data/' 11 | TRAIN_SUBSET: train 12 | TEST_SUBSET: validation 13 | DATA_FORMAT: 'zip' 14 | CROP: True 15 | ROT_FACTOR: 0 16 | SCALE_FACTOR: 0 17 | N_VIEWS: 4 18 | WITH_DAMAGED: False 19 | PREFETCH: False 20 | MOCAP: '/public_bme/data/XiaobenLi/HPE/mocap_data/extra/cmu_mocap.npz' 21 | GENERAL: 22 | NUM_WORKERS: 8 23 | SMPL: 24 | MODEL_PATH: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl 25 | GENDER: neutral 26 | NUM_BODY_JOINTS: 23 27 | JOINT_REGRESSOR_EXTRA: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/SMPL_to_J19.pkl 28 | MEAN_PARAMS: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl_mean_params.npz 29 | EXTRA: 30 | FOCAL_LENGTH: 5000 31 | MODEL: 32 | BACKBONE: 33 | TYPE: resnet 34 | PRETRAINED_WEIGHTS: /public_bme/data/XiaobenLi/HPE/pretrained_models/pose_coco/pose_resnet_50_256x192.pth 35 | NUM_LAYERS: 50 36 | IMAGE_SIZE: 37 | - 256 38 | - 256 39 | SMPL_HEAD: 40 | TYPE: transformer_decoder 41 | IN_CHANNELS: 2048 42 | TRANSFORMER_DECODER: 43 | depth: 6 44 | heads: 8 45 | mlp_dim: 1024 46 | dim_head: 64 47 | dropout: 0.0 48 | emb_dropout: 0.0 49 | norm: layer 50 | context_dim: 2048 51 | POSITIONAL_ENCODING: 'SinePositionalEncoding3D' 52 | TRAIN: 53 | RESUME: True 54 | LR: 1.0e-05 55 | WEIGHT_DECAY: 1.0e-4 56 | BATCH_SIZE: 72 57 | RENDER_MESH: False 58 | TOTAL_EPOCHS: 200 59 | LOG_INTERVAL: 100 60 | LOSS_WEIGHTS: 61 | KEYPOINTS_3D: 0.05 62 | KEYPOINTS_2D: 0.01 63 | GLOBAL_ORIENT: 0.001 64 | BODY_POSE: 0.001 65 | BETAS: 0.0005 66 | ADVERSARIAL: 0. 
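# NOTE (assumption, to be checked against lib/models/losses.py): these weights presumably scale the
# corresponding terms of a single training objective, roughly
#   L_total = 0.05*L_kp3D + 0.01*L_kp2D + 0.001*L_global_orient + 0.001*L_body_pose + 0.0005*L_betas + ADVERSARIAL*L_adv,
# so setting ADVERSARIAL to 0. effectively disables the discriminator term for this ResNet-50 run.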
67 | TEST: 68 | BATCH_SIZE: 72 69 | MODEL_FILE: '/public/bme/home/lixb1/Projects/HPE/multi_view_trans/output/h36m/resnet50_pos/2023-10-30-21-04/model_best.pth.tar' -------------------------------------------------------------------------------- /experiments/h36m/vit_fcn.yaml: -------------------------------------------------------------------------------- 1 | IS_TRAIN: True 2 | GPUS: '0' 3 | OUTPUT_DIR: output 4 | LOG_DIR: logs 5 | DATASET: 6 | COLOR_RGB: true 7 | SOURCE: h36m 8 | TRAIN_DATASET: multiview_h36m 9 | TEST_DATASET: multiview_h36m 10 | ROOT: '/public_bme/data/XiaobenLi/HPE/mocap_data/' 11 | TRAIN_SUBSET: train 12 | TEST_SUBSET: validation 13 | DATA_FORMAT: 'zip' 14 | CROP: True 15 | ROT_FACTOR: 0 16 | SCALE_FACTOR: 0 17 | N_VIEWS: 4 18 | WITH_DAMAGED: False 19 | PREFETCH: False 20 | GENERAL: 21 | NUM_WORKERS: 8 22 | SMPL: 23 | MODEL_PATH: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl 24 | GENDER: neutral 25 | NUM_BODY_JOINTS: 23 26 | JOINT_REGRESSOR_EXTRA: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/SMPL_to_J19.pkl 27 | MEAN_PARAMS: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl_mean_params.npz 28 | EXTRA: 29 | FOCAL_LENGTH: 5000 30 | MODEL: 31 | BACKBONE: 32 | TYPE: vit 33 | PRETRAINED_WEIGHTS: /public_bme/data/XiaobenLi/pretrained_models/vitpose_backbone.pth 34 | IMAGE_SIZE: 35 | - 256 36 | - 256 37 | SMPL_HEAD: 38 | TYPE: fcn 39 | IN_CHANNELS: 1280 40 | TRANSFORMER_DECODER: 41 | depth: 6 42 | heads: 8 43 | mlp_dim: 1024 44 | dim_head: 64 45 | dropout: 0.0 46 | emb_dropout: 0.0 47 | norm: layer 48 | context_dim: 1280 49 | TRAIN: 50 | RESUME: True 51 | LR: 1.0e-05 52 | WEIGHT_DECAY: 1.0e-4 53 | BATCH_SIZE: 24 54 | RENDER_MESH: False 55 | TOTAL_EPOCHS: 200 56 | LOG_INTERVAL: 100 57 | LOSS_WEIGHTS: 58 | KEYPOINTS_3D: 0.05 59 | KEYPOINTS_2D: 0.01 60 | GLOBAL_ORIENT: 0.001 61 | BODY_POSE: 0.001 62 | BETAS: 0.0005 63 | ADVERSARIAL: 0.0005 64 | TEST: 65 | BATCH_SIZE: 72 66 | MODEL_FILE: '/public/bme/home/lixb1/Projects/HPE/multi_view_trans/output/h36m/vit/2023-10-21-22-19/model_best.pth.tar' -------------------------------------------------------------------------------- /experiments/h36m/vit_fcn_fusion.yaml: -------------------------------------------------------------------------------- 1 | IS_TRAIN: True 2 | GPUS: '0' 3 | OUTPUT_DIR: output 4 | LOG_DIR: logs 5 | DATASET: 6 | COLOR_RGB: true 7 | SOURCE: h36m 8 | TRAIN_DATASET: multiview_h36m 9 | TEST_DATASET: multiview_h36m 10 | ROOT: '/public_bme/data/XiaobenLi/HPE/mocap_data/' 11 | TRAIN_SUBSET: train 12 | TEST_SUBSET: validation 13 | DATA_FORMAT: 'zip' 14 | CROP: True 15 | ROT_FACTOR: 0 16 | SCALE_FACTOR: 0 17 | N_VIEWS: 4 18 | WITH_DAMAGED: False 19 | PREFETCH: False 20 | MOCAP: '/public_bme/data/XiaobenLi/HPE/mocap_data/extra/cmu_mocap.npz' 21 | GENERAL: 22 | NUM_WORKERS: 8 23 | SMPL: 24 | MODEL_PATH: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl 25 | GENDER: neutral 26 | NUM_BODY_JOINTS: 23 27 | JOINT_REGRESSOR_EXTRA: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/SMPL_to_J19.pkl 28 | MEAN_PARAMS: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl_mean_params.npz 29 | EXTRA: 30 | FOCAL_LENGTH: 5000 31 | MODEL: 32 | BACKBONE: 33 | TYPE: vit 34 | PRETRAINED_WEIGHTS: /public_bme/data/XiaobenLi/pretrained_models/vitpose_backbone.pth 35 | IMAGE_SIZE: 36 | - 256 37 | - 256 38 | SMPL_HEAD: 39 | TYPE: fcn_fusion 40 | IN_CHANNELS: 1280 41 | TRANSFORMER_DECODER: 42 | depth: 6 43 | heads: 8 44 | mlp_dim: 1024 45 | dim_head: 64 46 | dropout: 0.0 47 | emb_dropout: 0.0 48 | norm: layer 49 | context_dim: 1280 50 
| TRAIN: 51 | RESUME: True 52 | LR: 1.0e-05 53 | WEIGHT_DECAY: 1.0e-4 54 | BATCH_SIZE: 24 55 | RENDER_MESH: False 56 | TOTAL_EPOCHS: 200 57 | LOG_INTERVAL: 100 58 | LOSS_WEIGHTS: 59 | KEYPOINTS_3D: 0.05 60 | KEYPOINTS_2D: 0.01 61 | GLOBAL_ORIENT: 0.001 62 | BODY_POSE: 0.001 63 | BETAS: 0.0005 64 | ADVERSARIAL: 0. 65 | TEST: 66 | BATCH_SIZE: 72 67 | MODEL_FILE: '/public/bme/home/lixb1/Projects/HPE/multi_view_trans/output/h36m/vit/2023-10-21-22-19/model_best.pth.tar' -------------------------------------------------------------------------------- /experiments/h36m/vit_pos_trans_encoder.yaml: -------------------------------------------------------------------------------- 1 | IS_TRAIN: True 2 | GPUS: '0' 3 | OUTPUT_DIR: output 4 | LOG_DIR: logs 5 | DATASET: 6 | COLOR_RGB: true 7 | SOURCE: h36m 8 | TRAIN_DATASET: multiview_h36m 9 | TEST_DATASET: multiview_h36m 10 | ROOT: '/public_bme/data/XiaobenLi/HPE/mocap_data/' 11 | TRAIN_SUBSET: train 12 | TEST_SUBSET: validation 13 | DATA_FORMAT: 'zip' 14 | CROP: True 15 | ROT_FACTOR: 0 16 | SCALE_FACTOR: 0 17 | N_VIEWS: 4 18 | WITH_DAMAGED: False 19 | PREFETCH: False 20 | MOCAP: '/public_bme/data/XiaobenLi/HPE/mocap_data/extra/cmu_mocap.npz' 21 | GENERAL: 22 | NUM_WORKERS: 8 23 | SMPL: 24 | MODEL_PATH: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl 25 | GENDER: neutral 26 | NUM_BODY_JOINTS: 23 27 | JOINT_REGRESSOR_EXTRA: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/SMPL_to_J19.pkl 28 | MEAN_PARAMS: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl_mean_params.npz 29 | EXTRA: 30 | FOCAL_LENGTH: 5000 31 | MODEL: 32 | PRETRAINED: /public_bme/data/XiaobenLi/pretrained_models/epoch=35-step=1000000.ckpt 33 | BACKBONE: 34 | TYPE: vit 35 | PRETRAINED_WEIGHTS: /public_bme/data/XiaobenLi/pretrained_models/vitpose_backbone.pth 36 | IMAGE_SIZE: 37 | - 256 38 | - 256 39 | SMPL_HEAD: 40 | TYPE: transformer_decoder 41 | IN_CHANNELS: 2048 42 | TRANSFORMER_DECODER: 43 | depth: 6 44 | heads: 8 45 | mlp_dim: 1024 46 | dim_head: 64 47 | dropout: 0.0 48 | emb_dropout: 0.0 49 | norm: layer 50 | context_dim: 1280 51 | POSITIONAL_ENCODING: 'SinePositionalEncoding3D' 52 | TRAIN: 53 | RESUME: True 54 | LR: 1.0e-05 55 | WEIGHT_DECAY: 1.0e-4 56 | BATCH_SIZE: 24 57 | RENDER_MESH: False 58 | TOTAL_EPOCHS: 200 59 | LOG_INTERVAL: 100 60 | LOSS_WEIGHTS: 61 | KEYPOINTS_3D: 0.05 62 | KEYPOINTS_2D: 0.01 63 | GLOBAL_ORIENT: 0.001 64 | BODY_POSE: 0.001 65 | BETAS: 0.0005 66 | ADVERSARIAL: 0. 
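# NOTE: context_dim (1280) matches the embedding dimension of the ViT backbone built in
# lib/models/backbones/vit.py, whereas the ResNet-50 config uses 2048; the absolute
# /public_bme and /public paths in this file are specific to the authors' environment and
# need to be adapted to a local setup.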
67 | TEST: 68 | BATCH_SIZE: 72 69 | MODEL_FILE: '/public/bme/home/lixb1/Projects/HPE/multi_view_trans/output/h36m/vit_pos_trans_encoder/2023-10-31-07-40/model_best.pth.tar' -------------------------------------------------------------------------------- /experiments/h36m/vit_pos_trans_encoder_token.yaml: -------------------------------------------------------------------------------- 1 | IS_TRAIN: True 2 | GPUS: '0' 3 | OUTPUT_DIR: output 4 | LOG_DIR: logs 5 | DATASET: 6 | COLOR_RGB: true 7 | SOURCE: h36m 8 | TRAIN_DATASET: multiview_h36m 9 | TEST_DATASET: multiview_h36m 10 | ROOT: '/public_bme/data/XiaobenLi/HPE/mocap_data/' 11 | TRAIN_SUBSET: train 12 | TEST_SUBSET: validation 13 | DATA_FORMAT: 'zip' 14 | CROP: True 15 | ROT_FACTOR: 0 16 | SCALE_FACTOR: 0 17 | N_VIEWS: 4 18 | WITH_DAMAGED: False 19 | PREFETCH: False 20 | GENERAL: 21 | NUM_WORKERS: 8 22 | SMPL: 23 | MODEL_PATH: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl 24 | GENDER: neutral 25 | NUM_BODY_JOINTS: 23 26 | JOINT_REGRESSOR_EXTRA: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/SMPL_to_J19.pkl 27 | MEAN_PARAMS: /public_bme/data/XiaobenLi/HPE/mocap_data/extra/smpl_mean_params.npz 28 | EXTRA: 29 | FOCAL_LENGTH: 5000 30 | MODEL: 31 | BACKBONE: 32 | TYPE: vit 33 | PRETRAINED_WEIGHTS: /public_bme/data/XiaobenLi/pretrained_models/vitpose_backbone.pth 34 | IMAGE_SIZE: 35 | - 256 36 | - 256 37 | SMPL_HEAD: 38 | TYPE: transformer_decoder_token 39 | IN_CHANNELS: 2048 40 | TRANSFORMER_DECODER: 41 | depth: 6 42 | heads: 8 43 | mlp_dim: 1024 44 | dim_head: 64 45 | dropout: 0.0 46 | emb_dropout: 0.0 47 | norm: layer 48 | context_dim: 1280 49 | POSITIONAL_ENCODING: 'SinePositionalEncoding3D' 50 | TRAIN: 51 | RESUME: True 52 | LR: 1.0e-05 53 | WEIGHT_DECAY: 1.0e-4 54 | BATCH_SIZE: 24 55 | RENDER_MESH: False 56 | TOTAL_EPOCHS: 200 57 | LOG_INTERVAL: 100 58 | LOSS_WEIGHTS: 59 | KEYPOINTS_3D: 0.05 60 | KEYPOINTS_2D: 0.01 61 | GLOBAL_ORIENT: 0.001 62 | BODY_POSE: 0.001 63 | BETAS: 0.0005 64 | ADVERSARIAL: 0.0005 65 | TEST: 66 | BATCH_SIZE: 72 67 | MODEL_FILE: '/public/bme/home/lixb1/Projects/HPE/multi_view_trans/output/h36m/vit/2023-10-21-22-19/model_best.pth.tar' -------------------------------------------------------------------------------- /images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaobenLi00/U-HMR/54f2b6516018a79590ea8e5d160fe1f3bd9b7c8d/images/framework.png -------------------------------------------------------------------------------- /images/quali_eval.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaobenLi00/U-HMR/54f2b6516018a79590ea8e5d160fe1f3bd9b7c8d/images/quali_eval.jpg -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | # export PYOPENGL_PLATFORM=osmesa 2 | # python infer.py --cfg_name experiments/h36m/resnet50_pos_3view.yaml --image_dir ./test_data 3 | import torch 4 | from lib.utils.config import get_config 5 | from lib.models.fusion import Mv_Fusion 6 | 7 | from lib.utils.renderer import Renderer 8 | from lib.models.smpl_wrapper import SMPL 9 | from lib.utils import vis 10 | import os 11 | import time 12 | import argparse 13 | import random 14 | from torchvision import transforms 15 | from PIL import Image 16 | import glob 17 | import numpy as np 18 | from einops import rearrange 19 | import trimesh 20 | 21 | 22 | def 
process_images(image_folder, cfg): 23 | transform = transforms.Compose([ 24 | transforms.Resize((cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])), 25 | transforms.ToTensor(), 26 | transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 27 | ]) 28 | 29 | image_paths = sorted(glob.glob(os.path.join(image_folder, '*'))) 30 | inputs = [] 31 | 32 | for path in image_paths[:4]: 33 | img = Image.open(path).convert('RGB') 34 | img_tensor = transform(img).unsqueeze(0).cuda() # [1, C, H, W] 35 | inputs.append(img_tensor) 36 | 37 | return inputs 38 | 39 | def convert_to_image(keypoints_2d_vis, images_pred, save_dir): 40 | if keypoints_2d_vis.dtype == np.float32: 41 | keypoints_2d_vis = (keypoints_2d_vis * 255).astype(np.uint8) 42 | elif keypoints_2d_vis.dtype != np.uint8: 43 | raise ValueError("keypoints_2d_vis must be uint8 or float32") 44 | 45 | if keypoints_2d_vis.shape[0] == 4: # RGBA 46 | keypoints_2d_vis = keypoints_2d_vis.transpose(1, 2, 0) 47 | Image.fromarray(keypoints_2d_vis, 'RGBA').save(os.path.join(save_dir, f'keypoints_2d_vis1.png')) 48 | 49 | 50 | images_pred = images_pred.detach().cpu() 51 | if images_pred.dim() == 3 and images_pred.size(0) == 3: 52 | images_pred = images_pred.permute(1, 2, 0) # H x W x 3 53 | 54 | if images_pred.dtype == torch.float32: 55 | images_pred = (images_pred * 255).clamp(0, 255).to(torch.uint8) 56 | Image.fromarray(images_pred.numpy(), 'RGB').save(os.path.join(save_dir, f'mesh_vis1.png')) 57 | 58 | 59 | 60 | 61 | 62 | def visualize_results(input, output, n_views, mesh_renderer, smpl): 63 | images = torch.cat(input, dim=0) 64 | images = rearrange(images, "(n b) c d e -> (b n) c d e", n=n_views) 65 | 66 | pred_keypoints_2d = rearrange(output['pred_keypoints_2d'], "(n b) c d -> (b n) c d", n=n_views) 67 | 68 | keypoints_2d_vis = vis.visualize_2d_pose(images, pred_keypoints_2d) 69 | 70 | 71 | images = images.detach() * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1) 72 | images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1) 73 | 74 | 75 | pred_vertice = rearrange(output['pred_vertices'], "(n b) c d -> (b n) c d", n=n_views) 76 | pred_cam_t = rearrange(output['pred_cam_t'], "(n b) c -> (b n) c", n=n_views) 77 | 78 | images_pred = mesh_renderer.visualize_tb(pred_vertice.detach(), pred_cam_t.detach(), images) 79 | 80 | convert_to_image(keypoints_2d_vis, images_pred, "./vis") 81 | 82 | mesh_vertices = rearrange(pred_vertice, "(b n) c d -> b n c d", n=n_views) 83 | for b in range(mesh_vertices.shape[0]): 84 | for n in range(mesh_vertices.shape[1]): 85 | mesh_vertice = mesh_vertices.clone().detach().cpu().numpy()[b, n] 86 | vertex_colors = np.ones([mesh_vertice.shape[0], 4]) * [0.82, 0.9, 0.98, 1.0] 87 | face_colors = np.ones([smpl.faces.shape[0], 4]) * [0.82, 0.9, 0.98, 1.0] 88 | mesh = trimesh.Trimesh(mesh_vertice, smpl.faces, face_colors=face_colors, vertex_colors=vertex_colors, process=False) 89 | mesh.export(f'mesh_{b}_{n}.obj') 90 | 91 | 92 | def infer(): 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument( 95 | '--cfg_name', help='path to the experiment configuration file', required=True, type=str) 96 | parser.add_argument( 97 | '--image_dir', help='directory containing the test images', required=True, type=str) 98 | args = parser.parse_args() 99 | cfg = get_config(args.cfg_name, merge=False) 100 | 101 | 102 | gpus=[0] 103 | model = Mv_Fusion(cfg, tensorboard_log_dir=None) 104 | model = torch.nn.DataParallel(model, device_ids=gpus).cuda() 105 | 106 | 107 | smpl_cfg = {k.lower(): v for k,v in 
dict(cfg.SMPL).items()} 108 | smpl = SMPL(**smpl_cfg) 109 | 110 | mesh_renderer = Renderer(focal_length=cfg.EXTRA.FOCAL_LENGTH, 111 | img_res=256, faces=smpl.faces) 112 | 113 | checkpoint = torch.load(cfg.TEST.MODEL_FILE) 114 | model.module.load_state_dict(checkpoint['state_dict']) 115 | model.eval() 116 | input_sub = process_images(args.image_dir, cfg) 117 | n_views = len(input_sub) 118 | 119 | with torch.no_grad(): 120 | output = model.module.forward_step( 121 | input_sub, 122 | n_views 123 | ) 124 | 125 | visualize_results(input_sub,output,n_views,mesh_renderer,smpl) 126 | if __name__ == "__main__": 127 | infer() -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaobenLi00/U-HMR/54f2b6516018a79590ea8e5d160fe1f3bd9b7c8d/lib/__init__.py -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | # from dataset.mpii import MPIIDataset as mpii 6 | # from dataset.h36m import H36MDataset as h36m 7 | from .multiview_h36m import MultiViewH36M as multiview_h36m 8 | from .h36m import H36M as h36m 9 | from .mpii3d import MPII3D as mpii3d 10 | # from dataset.mixed_dataset import MixedDataset as mixed 11 | from .multiview_mpii3d import MultiViewMPII3D as multiview_mpii3d 12 | from .multiview_totalcapture import MultiViewTotalCapture as multiview_totalcapture 13 | from .mocap_dataset import MoCapDataset as mocap_dataset -------------------------------------------------------------------------------- /lib/datasets/h36m.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import numpy as np 13 | import pickle 14 | import collections 15 | import torchvision.transforms as transforms 16 | 17 | from .joints_dataset import JointsDataset 18 | 19 | 20 | class H36M(JointsDataset): 21 | 22 | def __init__(self, cfg, image_set, is_train, transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( 23 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ])): 25 | super().__init__(cfg, image_set, is_train, transform) 26 | self.actual_joints = { 27 | 0: 'root', 28 | 1: 'rhip', 29 | 2: 'rkne', 30 | 3: 'rank', 31 | 4: 'lhip', 32 | 5: 'lkne', 33 | 6: 'lank', 34 | 7: 'belly', 35 | 8: 'neck', 36 | 9: 'nose', 37 | 10: 'head', 38 | 11: 'lsho', 39 | 12: 'lelb', 40 | 13: 'lwri', 41 | 14: 'rsho', 42 | 15: 'relb', 43 | 16: 'rwri' 44 | } 45 | 46 | if is_train: 47 | anno_file = osp.join(self.root, 'h36m', 'annot', 48 | 'h36m_{}_with_mosh_sample_5.pkl'.format(image_set)) 49 | else: 50 | anno_file = osp.join(self.root, 'h36m', 'annot', 51 | 'h36m_{}.pkl'.format(image_set)) 52 | 53 | self.db = self.load_db(anno_file) 54 | 55 | # print(len(self.db)) 56 | # self.u2a_mapping = super().get_mapping() 57 | # super().do_mapping() 58 | if not cfg.DATASET.WITH_DAMAGED: 59 | self.db = [db_rec for db_rec in self.db if not self.isdamaged(db_rec)] 60 | # print(len(self.db)) 61 | # self.grouping = self.get_group(self.db) 62 | # self.group_size = len(self.grouping) 63 | # print(self.group_size) 64 | self.db_size = len(self.db) 65 | 66 | 67 | def load_db(self, dataset_file): 68 | with open(dataset_file, 'rb') as f: 69 | dataset = pickle.load(f) 70 | return dataset 71 | 72 | def get_group(self, db): 73 | grouping = {} 74 | nitems = len(db) 75 | for i in range(nitems): 76 | keystr = self.get_key_str(db[i]) 77 | camera_id = db[i]['camera_id'] 78 | if keystr not in grouping: 79 | grouping[keystr] = [-1, -1, -1, -1] 80 | grouping[keystr][camera_id] = i 81 | 82 | filtered_grouping = [] 83 | for _, v in grouping.items(): 84 | if np.all(np.array(v) != -1): 85 | filtered_grouping.append(v) 86 | 87 | if self.is_train: 88 | # filtered_grouping = filtered_grouping[::400] 89 | pass 90 | else: 91 | filtered_grouping = filtered_grouping[::64] 92 | # pass 93 | 94 | 95 | 96 | return filtered_grouping 97 | 98 | def __getitem__(self, idx): 99 | input, meta = [], [] 100 | i, m = super().__getitem__(idx) 101 | input.append(i) 102 | meta.append(m) 103 | return input, meta 104 | # input, target, weight, meta = [], [], [], [] 105 | # input, meta = [], [] 106 | # items = self.grouping[idx] 107 | # for item in items: 108 | # # i, t, w, m = super().__getitem__(item) 109 | # i, m = super().__getitem__(item) 110 | # input.append(i) 111 | # # target.append(t) 112 | # # weight.append(w) 113 | # meta.append(m) 114 | # # return input, target, weight, meta 115 | # # return input, meta, idx 116 | # return input, meta 117 | def __len__(self): 118 | # return self.group_size 119 | return self.db_size 120 | 121 | def isdamaged(self, db_rec): 122 | # from https://github.com/yihui-he/epipolar-transformers/blob/4da5cbca762aef6a89d37f889789f772b87d2688/data/datasets/joints_dataset.py#L174 123 | #damaged seq 124 | #'Greeting-2', 'SittingDown-2', 'Waiting-1' 125 | if db_rec['subject'] == 9: 126 | if db_rec['action'] != 5 or 
db_rec['subaction'] != 2: 127 | if db_rec['action'] != 10 or db_rec['subaction'] != 2: 128 | if db_rec['action'] != 13 or db_rec['subaction'] != 1: 129 | return False 130 | else: 131 | return False 132 | return True 133 | 134 | def get_key_str(self, datum): 135 | return 's_{:02}_act_{:02}_subact_{:02}_imgid_{:06}'.format( 136 | datum['subject'], datum['action'], datum['subaction'], 137 | datum['image_id']) 138 | -------------------------------------------------------------------------------- /lib/datasets/joints_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import cv2 12 | import copy 13 | import random 14 | import numpy as np 15 | import os.path as osp 16 | 17 | import torch 18 | from torch.utils.data import Dataset 19 | import sys 20 | 21 | from ..utils import zipreader 22 | from ..utils.transforms import get_affine_transform 23 | from ..utils.transforms import affine_transform 24 | from ..utils.triangulate import triangulate_poses, camera_to_world_frame 25 | 26 | 27 | class JointsDataset(Dataset): 28 | 29 | def __init__(self, cfg, subset, is_train, transform=None): 30 | self.is_train = is_train 31 | self.subset = subset 32 | 33 | self.root = cfg.DATASET.ROOT 34 | self.data_format = cfg.DATASET.DATA_FORMAT 35 | self.scale_factor = cfg.DATASET.SCALE_FACTOR 36 | self.rotation_factor = cfg.DATASET.ROT_FACTOR 37 | self.image_size = cfg.MODEL.IMAGE_SIZE 38 | self.transform = transform 39 | self.db = [] 40 | self.color_rgb = cfg.DATASET.COLOR_RGB 41 | 42 | self.num_joints = 17 43 | def _get_db(self): 44 | raise NotImplementedError 45 | 46 | def evaluate(self, cfg, preds, output_dir, *args, **kwargs): 47 | raise NotImplementedError 48 | 49 | def __len__(self, ): 50 | return len(self.db) 51 | 52 | def __getitem__(self, idx): 53 | db_rec = copy.deepcopy(self.db[idx]) 54 | 55 | image_dir = 'images.zip@' if self.data_format == 'zip' else '' 56 | image_file = osp.join(self.root, db_rec['source'], image_dir, 'images', 57 | db_rec['image']) 58 | if self.data_format == 'zip': 59 | 60 | data_numpy = zipreader.imread( 61 | image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) 62 | else: 63 | data_numpy = cv2.imread( 64 | image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) 65 | if self.color_rgb: 66 | data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB) 67 | joints = db_rec['joints_2d'].copy() 68 | joints_vis = db_rec['joints_vis'].copy() 69 | 70 | center = np.array(db_rec['center']).copy() 71 | scale = np.array(db_rec['scale']).copy() 72 | rotation = 0 73 | 74 | if self.is_train: 75 | sf = self.scale_factor 76 | rf = self.rotation_factor 77 | scale = scale * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) 78 | rotation = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) \ 79 | if random.random() <= 0.6 else 0 80 | # sf = 0.2 81 | # scale = scale * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) 82 | 83 | # center_shift = np.array([(random.random()-0.5)*40 , (random.random()-0.5)*40]) 84 | # center += center_shift 85 | trans = get_affine_transform(center, scale, rotation, 
self.image_size) 86 | input = cv2.warpAffine( 87 | data_numpy, 88 | trans, (int(self.image_size[0]), int(self.image_size[1])), 89 | flags=cv2.INTER_LINEAR) 90 | 91 | if self.transform: 92 | input = self.transform(input) 93 | 94 | for i in range(self.num_joints): 95 | if joints_vis[i, 0] > 0.0: 96 | joints[i, 0:2] = affine_transform(joints[i, 0:2], trans) 97 | if (np.min(joints[i, :2]) < 0 or 98 | joints[i, 0] >= self.image_size[0] or 99 | joints[i, 1] >= self.image_size[1]): 100 | joints_vis[i, :] = 0 101 | 102 | meta = { 103 | 'scale': scale, 104 | 'center': center, 105 | 'rotation': rotation, 106 | 'joints_2d': db_rec['joints_2d'].astype(np.float32), 107 | 'joints_2d_transformed': joints.astype(np.float32), 108 | # 'joints_3d_world': camera_to_world_frame(db_rec['joints_3d_camera'], db_rec['camera']['R'], db_rec['camera']['T']).astype(np.float32), 109 | 'joints_3d_camera': db_rec['joints_3d_camera'].astype(np.float32), 110 | # 'camera_params': db_rec['camera'], 111 | 'joints_vis': joints_vis, 112 | 'source': db_rec['source'], 113 | 'index': idx 114 | } 115 | # if self.is_train and db_rec['source'] == 'h36m': 116 | if self.is_train: 117 | meta['smpl_params'] = {'global_orient': db_rec['global_orient'].astype(np.float32), 118 | 'body_pose': db_rec['body_pose'].astype(np.float32), 119 | 'betas': db_rec['betas'].astype(np.float32)} 120 | meta['has_smpl_params'] = {'global_orient': np.ones(1,dtype=np.float32)[0], 121 | 'body_pose': np.ones(1,dtype=np.float32)[0], 122 | 'betas': np.ones(1,dtype=np.float32)[0]} 123 | else: 124 | # meta['has_smpl'] = False 125 | meta['smpl_params'] = {'global_orient': np.zeros((1,3,3),dtype=np.float32), 126 | 'body_pose': np.zeros((23,3,3),dtype=np.float32), 127 | 'betas': np.zeros(10,dtype=np.float32)} 128 | meta['has_smpl_params'] = {'global_orient': np.zeros(1,dtype=np.float32)[0], 129 | 'body_pose': np.zeros(1,dtype=np.float32)[0], 130 | 'betas': np.zeros(1,dtype=np.float32)[0]} 131 | return input, meta -------------------------------------------------------------------------------- /lib/datasets/mocap_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | class MoCapDataset: 5 | 6 | def __init__(self, dataset_file: str): 7 | """ 8 | Dataset class used for loading a dataset of unpaired SMPL parameter annotations 9 | Args: 10 | cfg (CfgNode): Model config file. 11 | dataset_file (str): Path to npz file containing dataset info. 12 | """ 13 | data = np.load(dataset_file) 14 | pose = data['body_pose'].astype(np.float32)[:, 3:] 15 | betas = data['betas'].astype(np.float32) 16 | # self.pose = data['body_pose'].astype(np.float32)[:, 3:] 17 | # self.betas = data['betas'].astype(np.float32) 18 | self.pose = np.concatenate((pose, pose), axis=0) 19 | self.betas = np.concatenate((betas, betas), axis=0) 20 | self.length = len(self.pose) 21 | 22 | def __getitem__(self, idx: int) -> Dict: 23 | pose = self.pose[idx].copy() 24 | betas = self.betas[idx].copy() 25 | item = {'body_pose': pose, 'betas': betas} 26 | return item 27 | 28 | def __len__(self) -> int: 29 | return self.length 30 | -------------------------------------------------------------------------------- /lib/datasets/mpii3d.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import numpy as np 13 | import pickle 14 | import collections 15 | import torchvision.transforms as transforms 16 | 17 | from .joints_dataset import JointsDataset 18 | 19 | 20 | class MPII3D(JointsDataset): 21 | 22 | def __init__(self, cfg, image_set, is_train, transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( 23 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ])): 25 | super().__init__(cfg, image_set, is_train, transform) 26 | self.actual_joints = { 27 | 0: 'root', 28 | 1: 'rhip', 29 | 2: 'rkne', 30 | 3: 'rank', 31 | 4: 'lhip', 32 | 5: 'lkne', 33 | 6: 'lank', 34 | 7: 'belly', 35 | 8: 'neck', 36 | 9: 'nose', 37 | 10: 'head', 38 | 11: 'lsho', 39 | 12: 'lelb', 40 | 13: 'lwri', 41 | 14: 'rsho', 42 | 15: 'relb', 43 | 16: 'rwri' 44 | } 45 | 46 | if cfg.DATASET.CROP: 47 | anno_file = osp.join(self.root, 'mpi_inf_3dhp', 'annot', 48 | 'mpi_inf_3dhp_{}_new.pkl'.format(image_set)) 49 | else: 50 | anno_file = osp.join(self.root, 'mpi_inf_3dhp', 'annot', 51 | 'mpi_inf_3dhp_{}_uncrop.pkl'.format(image_set)) 52 | 53 | self.db = self.load_db(anno_file) 54 | 55 | # self.u2a_mapping = super().get_mapping() 56 | # super().do_mapping() 57 | 58 | # self.grouping = self.get_group(self.db) 59 | # self.group_size = len(self.grouping) 60 | self.db_size = len(self.db) 61 | 62 | def load_db(self, dataset_file): 63 | with open(dataset_file, 'rb') as f: 64 | dataset = pickle.load(f) 65 | return dataset 66 | 67 | def get_group(self, db): 68 | grouping = {} 69 | nitems = len(db) 70 | for i in range(nitems): 71 | keystr = self.get_key_str(db[i]) 72 | camera_id = db[i]['camera_id'] 73 | if keystr not in grouping: 74 | grouping[keystr] = [-1, -1, -1, -1] 75 | grouping[keystr][camera_id] = i 76 | 77 | filtered_grouping = [] 78 | for _, v in grouping.items(): 79 | if np.all(np.array(v) != -1): 80 | filtered_grouping.append(v) 81 | 82 | if self.is_train: 83 | filtered_grouping = filtered_grouping[::5] 84 | # filtered_grouping = filtered_grouping[::500] 85 | else: 86 | filtered_grouping = filtered_grouping[::5] 87 | # filtered_grouping = filtered_grouping[::640] 88 | # filtered_grouping = filtered_grouping[::64] 89 | 90 | # if self.is_train: 91 | # filtered_grouping = filtered_grouping[:1] 92 | # else: 93 | # # pass 94 | # filtered_grouping = filtered_grouping[46635:46636] 95 | # # filtered_grouping = filtered_grouping[46636:46637] 96 | # # filtered_grouping = filtered_grouping[:1] 97 | 98 | 99 | 100 | return filtered_grouping 101 | 102 | def __getitem__(self, idx): 103 | input, meta = [], [] 104 | i, m = super().__getitem__(idx) 105 | input.append(i) 106 | meta.append(m) 107 | return input, meta 108 | # input, target, weight, meta = [], [], [], [] 109 | # input, meta = [], [] 110 | # items = self.grouping[idx] 111 | # for item in items: 112 | # # i, t, w, m = super().__getitem__(item) 113 | # i, m = super().__getitem__(item) 114 | # input.append(i) 115 | # # target.append(t) 116 | # # weight.append(w) 117 | # meta.append(m) 118 | # # return input, target, weight, meta 119 | # return input, meta 120 | 121 | def __len__(self): 122 | # return self.group_size 123 | return self.db_size 124 | 125 | def get_key_str(self, datum): 126 | return 's_{:02}_seq_{:02}_imgid_{:06}'.format( 127 | 
datum['subject'], datum['sequence'], 128 | datum['image_id']) 129 | 130 | def evaluate(self, pred, *args, **kwargs): 131 | pass -------------------------------------------------------------------------------- /lib/datasets/multiview_h36m.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import numpy as np 13 | import pickle 14 | import collections 15 | import torchvision.transforms as transforms 16 | 17 | from .joints_dataset import JointsDataset 18 | 19 | 20 | class MultiViewH36M(JointsDataset): 21 | 22 | def __init__(self, cfg, image_set, is_train, transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( 23 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ])): 25 | super().__init__(cfg, image_set, is_train, transform) 26 | self.actual_joints = { 27 | 0: 'root', 28 | 1: 'rhip', 29 | 2: 'rkne', 30 | 3: 'rank', 31 | 4: 'lhip', 32 | 5: 'lkne', 33 | 6: 'lank', 34 | 7: 'belly', 35 | 8: 'neck', 36 | 9: 'nose', 37 | 10: 'head', 38 | 11: 'lsho', 39 | 12: 'lelb', 40 | 13: 'lwri', 41 | 14: 'rsho', 42 | 15: 'relb', 43 | 16: 'rwri' 44 | } 45 | 46 | if is_train: 47 | anno_file = osp.join(self.root, 'h36m', 'annot', 48 | 'h36m_{}_with_mosh_sample_5_rot.pkl'.format(image_set)) 49 | else: 50 | anno_file = osp.join(self.root, 'h36m', 'annot', 51 | 'h36m_{}.pkl'.format(image_set)) 52 | 53 | self.db = self.load_db(anno_file) 54 | 55 | # print(len(self.db)) 56 | # self.u2a_mapping = super().get_mapping() 57 | # super().do_mapping() 58 | if not cfg.DATASET.WITH_DAMAGED: 59 | self.db = [db_rec for db_rec in self.db if not self.isdamaged(db_rec)] 60 | # print(len(self.db)) 61 | self.grouping = self.get_group(self.db) 62 | self.group_size = len(self.grouping) 63 | # print(self.group_size) 64 | 65 | 66 | def load_db(self, dataset_file): 67 | with open(dataset_file, 'rb') as f: 68 | dataset = pickle.load(f) 69 | return dataset 70 | 71 | def get_group(self, db): 72 | grouping = {} 73 | nitems = len(db) 74 | for i in range(nitems): 75 | keystr = self.get_key_str(db[i]) 76 | camera_id = db[i]['camera_id'] 77 | if keystr not in grouping: 78 | grouping[keystr] = [-1, -1, -1, -1] 79 | grouping[keystr][camera_id] = i 80 | 81 | filtered_grouping = [] 82 | for _, v in grouping.items(): 83 | if np.all(np.array(v) != -1): 84 | filtered_grouping.append(v) 85 | 86 | if self.is_train: 87 | # filtered_grouping = filtered_grouping[::400] 88 | pass 89 | else: 90 | filtered_grouping = filtered_grouping[::64] 91 | # pass 92 | 93 | 94 | 95 | return filtered_grouping 96 | 97 | def __getitem__(self, idx): 98 | # input, target, weight, meta = [], [], [], [] 99 | input, meta = [], [] 100 | items = self.grouping[idx] 101 | for item in items: 102 | # i, t, w, m = super().__getitem__(item) 103 | i, m = super().__getitem__(item) 104 | input.append(i) 105 | # target.append(t) 106 | # weight.append(w) 107 | meta.append(m) 108 | # return input, target, weight, meta 109 | # return input, meta, idx 110 | return input, meta 111 | def __len__(self): 112 | return self.group_size 113 | 114 | def 
isdamaged(self, db_rec): 115 | # from https://github.com/yihui-he/epipolar-transformers/blob/4da5cbca762aef6a89d37f889789f772b87d2688/data/datasets/joints_dataset.py#L174 116 | #damaged seq 117 | #'Greeting-2', 'SittingDown-2', 'Waiting-1' 118 | if db_rec['subject'] == 9: 119 | if db_rec['action'] != 5 or db_rec['subaction'] != 2: 120 | if db_rec['action'] != 10 or db_rec['subaction'] != 2: 121 | if db_rec['action'] != 13 or db_rec['subaction'] != 1: 122 | return False 123 | else: 124 | return False 125 | return True 126 | 127 | def get_key_str(self, datum): 128 | return 's_{:02}_act_{:02}_subact_{:02}_imgid_{:06}'.format( 129 | datum['subject'], datum['action'], datum['subaction'], 130 | datum['image_id']) 131 | -------------------------------------------------------------------------------- /lib/datasets/multiview_mpii3d.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import numpy as np 13 | import pickle 14 | import collections 15 | import torchvision.transforms as transforms 16 | 17 | from .joints_dataset import JointsDataset 18 | 19 | 20 | class MultiViewMPII3D(JointsDataset): 21 | 22 | def __init__(self, cfg, image_set, is_train, transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( 23 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ])): 25 | super().__init__(cfg, image_set, is_train, transform) 26 | self.actual_joints = { 27 | 0: 'root', 28 | 1: 'rhip', 29 | 2: 'rkne', 30 | 3: 'rank', 31 | 4: 'lhip', 32 | 5: 'lkne', 33 | 6: 'lank', 34 | 7: 'belly', 35 | 8: 'neck', 36 | 9: 'nose', 37 | 10: 'head', 38 | 11: 'lsho', 39 | 12: 'lelb', 40 | 13: 'lwri', 41 | 14: 'rsho', 42 | 15: 'relb', 43 | 16: 'rwri' 44 | } 45 | 46 | if is_train: 47 | anno_file = osp.join(self.root, 'mpi_inf_3dhp', 'annot', 48 | 'mpi_inf_3dhp_{}_new_all_views_pseudo.pkl'.format(image_set)) 49 | else: 50 | anno_file = osp.join(self.root, 'mpi_inf_3dhp', 'annot', 51 | 'mpi_inf_3dhp_{}_new_all_views.pkl'.format(image_set)) 52 | 53 | self.db = self.load_db(anno_file) 54 | 55 | # self.u2a_mapping = super().get_mapping() 56 | # super().do_mapping() 57 | 58 | self.grouping = self.get_group(self.db) 59 | self.group_size = len(self.grouping) 60 | 61 | def load_db(self, dataset_file): 62 | with open(dataset_file, 'rb') as f: 63 | dataset_all_view = pickle.load(f) 64 | # return dataset 65 | dataset = [] 66 | for item in dataset_all_view: 67 | if item['camera_id'] == 0: 68 | dataset.append(item) 69 | if item['camera_id'] == 2: 70 | item['camera_id'] = 1 71 | dataset.append(item) 72 | if item['camera_id'] == 6: 73 | item['camera_id'] = 2 74 | dataset.append(item) 75 | if item['camera_id'] == 7: 76 | item['camera_id'] = 3 77 | dataset.append(item) 78 | return dataset 79 | 80 | def get_group(self, db): 81 | grouping = {} 82 | nitems = len(db) 83 | for i in range(nitems): 84 | keystr = self.get_key_str(db[i]) 85 | camera_id = db[i]['camera_id'] 86 | if keystr not in grouping: 87 | grouping[keystr] = [-1, -1, -1, -1] 88 | grouping[keystr][camera_id] = i 89 | 90 | filtered_grouping = [] 
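        # Keep only frames for which all four (remapped) camera views are present, i.e. groups that
        # still contain a -1 placeholder are dropped; the surviving groups are then subsampled
        # (every 5th group below) for both the training and the test split.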
91 | for _, v in grouping.items(): 92 | if np.all(np.array(v) != -1): 93 | filtered_grouping.append(v) 94 | 95 | if self.is_train: 96 | filtered_grouping = filtered_grouping[::5] 97 | # filtered_grouping = filtered_grouping[::500] 98 | else: 99 | filtered_grouping = filtered_grouping[::5] 100 | # filtered_grouping = filtered_grouping[::640] 101 | # filtered_grouping = filtered_grouping[::64] 102 | 103 | # if self.is_train: 104 | # filtered_grouping = filtered_grouping[:1] 105 | # else: 106 | # # pass 107 | # filtered_grouping = filtered_grouping[46635:46636] 108 | # # filtered_grouping = filtered_grouping[46636:46637] 109 | # # filtered_grouping = filtered_grouping[:1] 110 | 111 | 112 | 113 | return filtered_grouping 114 | 115 | def __getitem__(self, idx): 116 | # input, target, weight, meta = [], [], [], [] 117 | input, meta = [], [] 118 | items = self.grouping[idx] 119 | for item in items: 120 | # i, t, w, m = super().__getitem__(item) 121 | i, m = super().__getitem__(item) 122 | input.append(i) 123 | # target.append(t) 124 | # weight.append(w) 125 | meta.append(m) 126 | # return input, target, weight, meta 127 | return input, meta 128 | 129 | def __len__(self): 130 | return self.group_size 131 | 132 | def get_key_str(self, datum): 133 | return 's_{:02}_seq_{:02}_imgid_{:06}'.format( 134 | datum['subject'], datum['sequence'], 135 | datum['image_id']) 136 | 137 | def evaluate(self, pred, *args, **kwargs): 138 | pass -------------------------------------------------------------------------------- /lib/datasets/multiview_totalcapture.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import numpy as np 13 | import pickle 14 | import collections 15 | import torchvision.transforms as transforms 16 | 17 | from .joints_dataset import JointsDataset 18 | 19 | 20 | class MultiViewTotalCapture(JointsDataset): 21 | 22 | def __init__(self, cfg, image_set, is_train, transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize( 23 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ])): 25 | super().__init__(cfg, image_set, is_train, transform) 26 | self.actual_joints = { 27 | 0: 'root', 28 | 1: 'rhip', 29 | 2: 'rkne', 30 | 3: 'rank', 31 | 4: 'lhip', 32 | 5: 'lkne', 33 | 6: 'lank', 34 | 7: 'belly', 35 | 8: 'neck', 36 | 9: 'head', 37 | 10: 'lsho', 38 | 11: 'lelb', 39 | 12: 'lwri', 40 | 13: 'rsho', 41 | 14: 'relb', 42 | 15: 'rwri' 43 | } 44 | 45 | if is_train: 46 | anno_file = osp.join(self.root, 'totalcapture', 'annot', 47 | 'totalcapture_{}_new_17_pseudo.pkl'.format(image_set)) 48 | else: 49 | anno_file = osp.join(self.root, 'totalcapture', 'annot', 50 | 'totalcapture_{}_new_17.pkl'.format(image_set)) 51 | 52 | self.db = self.load_db(anno_file) 53 | 54 | # self.u2a_mapping = super().get_mapping() 55 | # super().do_mapping() 56 | 57 | self.grouping = self.get_group(self.db) 58 | self.group_size = len(self.grouping) 59 | 60 | def index_to_action_names(self): 61 | return { 62 | 2: 'Direction', 63 | 3: 'Discuss', 64 | 4: 'Eating', 65 | 5: 'Greet', 66 | 6: 'Phone', 67 | 7: 'Photo', 68 | 8: 'Pose', 69 | 9: 'Purchase', 70 | 10: 'Sitting', 71 | 11: 'SittingDown', 72 | 12: 'Smoke', 73 | 13: 'Wait', 74 | 14: 'WalkDog', 75 | 15: 'Walk', 76 | 16: 'WalkTwo' 77 | } 78 | 79 | def load_db(self, dataset_file): 80 | with open(dataset_file, 'rb') as f: 81 | dataset_all_view = pickle.load(f) 82 | dataset = [] 83 | for item in dataset_all_view: 84 | if item['camera_id'] % 2 == 0: 85 | item['camera_id'] = int(item['camera_id']/2) 86 | dataset.append(item) 87 | return dataset 88 | 89 | def get_group(self, db): 90 | grouping = {} 91 | nitems = len(db) 92 | for i in range(nitems): 93 | keystr = self.get_key_str(db[i]) 94 | camera_id = db[i]['camera_id'] 95 | if keystr not in grouping: 96 | grouping[keystr] = [-1, -1, -1, -1] 97 | grouping[keystr][camera_id] = i 98 | 99 | filtered_grouping = [] 100 | for _, v in grouping.items(): 101 | if np.all(np.array(v) != -1): 102 | filtered_grouping.append(v) 103 | 104 | if self.is_train: 105 | filtered_grouping = filtered_grouping[::5] 106 | # filtered_grouping = filtered_grouping[::500] 107 | else: 108 | filtered_grouping = filtered_grouping[::10] 109 | # filtered_grouping = filtered_grouping[::640] 110 | 111 | # if self.is_train: 112 | # filtered_grouping = filtered_grouping[:1] 113 | # else: 114 | # # pass 115 | # filtered_grouping = filtered_grouping[46635:46636] 116 | # # filtered_grouping = filtered_grouping[46636:46637] 117 | # # filtered_grouping = filtered_grouping[:1] 118 | 119 | 120 | 121 | return filtered_grouping 122 | 123 | def __getitem__(self, idx): 124 | # input, target, weight, meta = [], [], [], [] 125 | input, meta = [], [] 126 | items = self.grouping[idx] 127 | for item in items: 128 | # i, t, w, m = super().__getitem__(item) 129 | i, m = super().__getitem__(item) 130 | input.append(i) 131 | # target.append(t) 132 | # 
weight.append(w) 133 | meta.append(m) 134 | # return input, target, weight, meta 135 | # return input, meta, idx 136 | return input, meta 137 | 138 | def __len__(self): 139 | return self.group_size 140 | 141 | def get_key_str(self, datum): 142 | return 's_{:02}_act_{:02}_subact_{:02}_imgid_{:06}'.format( 143 | datum['subject'], datum['action'], datum['subaction'], 144 | datum['image_id']) 145 | 146 | def evaluate(self, pred, *args, **kwargs): 147 | pred = pred.copy() 148 | 149 | headsize = self.image_size[0] / 10.0 150 | threshold = 0.5 151 | 152 | u2a = self.u2a_mapping 153 | a2u = {v: k for k, v in u2a.items() if v != '*'} 154 | a = list(a2u.keys()) 155 | u = list(a2u.values()) 156 | indexes = list(range(len(a))) 157 | indexes.sort(key=a.__getitem__) 158 | sa = list(map(a.__getitem__, indexes)) 159 | su = np.array(list(map(u.__getitem__, indexes))) 160 | 161 | gt = [] 162 | for items in self.grouping: 163 | for item in items: 164 | gt.append(self.db[item]['joints_2d'][su, :2]) 165 | gt = np.array(gt) 166 | pred = pred[:, su, :2] 167 | 168 | distance = np.sqrt(np.sum((gt - pred) ** 2, axis=2)) 169 | detected = (distance <= headsize * threshold) 170 | 171 | joint_detection_rate = np.sum(detected, axis=0) / np.float(gt.shape[0]) 172 | 173 | name_values = collections.OrderedDict() 174 | joint_names = self.actual_joints 175 | for i in range(len(a2u)): 176 | name_values[joint_names[sa[i]]] = joint_detection_rate[i] 177 | return name_values, np.mean(joint_detection_rate) 178 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaobenLi00/U-HMR/54f2b6516018a79590ea8e5d160fe1f3bd9b7c8d/lib/models/__init__.py -------------------------------------------------------------------------------- /lib/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .vit import vit 2 | from .resnet import resnet 3 | from .swin_transformer_v2 import swin_v2 4 | 5 | def create_backbone(cfg): 6 | if cfg.MODEL.BACKBONE.TYPE == 'vit': 7 | return vit(cfg) 8 | elif cfg.MODEL.BACKBONE.TYPE == 'resnet': 9 | return resnet(cfg) 10 | elif cfg.MODEL.BACKBONE.TYPE == 'swin_v2': 11 | return swin_v2(cfg) 12 | else: 13 | raise NotImplementedError('Backbone type is not implemented') -------------------------------------------------------------------------------- /lib/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | # from ..utils.geometry import rot6d_to_rotmat 6 | # from lib.utils.geometry import rot6d_to_rotmat 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | """ Redefinition of Bottleneck residual block 11 | Adapted from the official PyTorch implementation 12 | """ 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None): 16 | super(Bottleneck, self).__init__() 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 23 | self.bn3 = nn.BatchNorm2d(planes * 4) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.downsample = downsample 26 | self.stride = stride 
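        # The three convolutions configured above form the standard ResNet bottleneck: a 1x1 conv
        # reduces the channel count, a 3x3 conv (optionally strided) does the spatial filtering, and
        # a final 1x1 conv expands the channels by `expansion` (4x); forward() adds the identity
        # (or `downsample`d) shortcut before the last ReLU.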
27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv3(out) 40 | out = self.bn3(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | class ResNet(nn.Module): 51 | """ SMPL Iterative Regressor with ResNet backbone 52 | """ 53 | 54 | def __init__(self, block, layers): 55 | self.inplanes = 64 56 | super(ResNet, self).__init__() 57 | # npose = 24 * 6 58 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 59 | bias=False) 60 | self.bn1 = nn.BatchNorm2d(64) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 63 | self.layer1 = self._make_layer(block, 64, layers[0]) 64 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 65 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 66 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 67 | 68 | 69 | def _make_layer(self, block, planes, blocks, stride=1): 70 | downsample = None 71 | if stride != 1 or self.inplanes != planes * block.expansion: 72 | downsample = nn.Sequential( 73 | nn.Conv2d(self.inplanes, planes * block.expansion, 74 | kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(planes * block.expansion), 76 | ) 77 | 78 | layers = [] 79 | layers.append(block(self.inplanes, planes, stride, downsample)) 80 | self.inplanes = planes * block.expansion 81 | for i in range(1, blocks): 82 | layers.append(block(self.inplanes, planes)) 83 | 84 | return nn.Sequential(*layers) 85 | 86 | 87 | def forward(self, x): 88 | 89 | 90 | x = self.conv1(x) 91 | x = self.bn1(x) 92 | x = self.relu(x) 93 | x = self.maxpool(x) 94 | 95 | x1 = self.layer1(x) 96 | x2 = self.layer2(x1) 97 | x3 = self.layer3(x2) 98 | x4 = self.layer4(x3) 99 | 100 | return x4 101 | 102 | resnet_spec = {50: (Bottleneck, [3, 4, 6, 3]), 103 | 101: (Bottleneck, [3, 4, 23, 3]), 104 | 152: (Bottleneck, [3, 8, 36, 3])} 105 | 106 | def resnet(cfg): 107 | num_layers = cfg.MODEL.BACKBONE.NUM_LAYERS 108 | block_class, layers = resnet_spec[num_layers] 109 | 110 | return ResNet(block_class, layers) 111 | -------------------------------------------------------------------------------- /lib/models/backbones/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import torch 5 | from functools import partial 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint as checkpoint 9 | 10 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 11 | 12 | def vit(cfg): 13 | return ViT( 14 | img_size=(256, 192), 15 | patch_size=16, 16 | embed_dim=1280, 17 | depth=32, 18 | num_heads=16, 19 | ratio=1, 20 | use_checkpoint=False, 21 | mlp_ratio=4, 22 | qkv_bias=True, 23 | drop_path_rate=0.55 24 | ) 25 | 26 | def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): 27 | """ 28 | Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token 29 | dimension for the original embeddings. 30 | Args: 31 | abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). 32 | has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. 33 | hw (Tuple): size of input image tokens. 
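        h, w (int): target height and width of the token grid that the embeddings are resized to.
        ori_h, ori_w (int): original height and width of the token grid that abs_pos corresponds to.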
34 | 35 | Returns: 36 | Absolute positional embeddings after processing with shape (1, H, W, C) 37 | """ 38 | cls_token = None 39 | B, L, C = abs_pos.shape 40 | if has_cls_token: 41 | cls_token = abs_pos[:, 0:1] 42 | abs_pos = abs_pos[:, 1:] 43 | 44 | if ori_h != h or ori_w != w: 45 | new_abs_pos = F.interpolate( 46 | abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), 47 | size=(h, w), 48 | mode="bicubic", 49 | align_corners=False, 50 | ).permute(0, 2, 3, 1).reshape(B, -1, C) 51 | 52 | else: 53 | new_abs_pos = abs_pos 54 | 55 | if cls_token is not None: 56 | new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) 57 | return new_abs_pos 58 | 59 | class DropPath(nn.Module): 60 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 61 | """ 62 | def __init__(self, drop_prob=None): 63 | super(DropPath, self).__init__() 64 | self.drop_prob = drop_prob 65 | 66 | def forward(self, x): 67 | return drop_path(x, self.drop_prob, self.training) 68 | 69 | def extra_repr(self): 70 | return 'p={}'.format(self.drop_prob) 71 | 72 | class Mlp(nn.Module): 73 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 74 | super().__init__() 75 | out_features = out_features or in_features 76 | hidden_features = hidden_features or in_features 77 | self.fc1 = nn.Linear(in_features, hidden_features) 78 | self.act = act_layer() 79 | self.fc2 = nn.Linear(hidden_features, out_features) 80 | self.drop = nn.Dropout(drop) 81 | 82 | def forward(self, x): 83 | x = self.fc1(x) 84 | x = self.act(x) 85 | x = self.fc2(x) 86 | x = self.drop(x) 87 | return x 88 | 89 | class Attention(nn.Module): 90 | def __init__( 91 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 92 | proj_drop=0., attn_head_dim=None,): 93 | super().__init__() 94 | self.num_heads = num_heads 95 | head_dim = dim // num_heads 96 | self.dim = dim 97 | 98 | if attn_head_dim is not None: 99 | head_dim = attn_head_dim 100 | all_head_dim = head_dim * self.num_heads 101 | 102 | self.scale = qk_scale or head_dim ** -0.5 103 | 104 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) 105 | 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(all_head_dim, dim) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | def forward(self, x): 111 | B, N, C = x.shape 112 | qkv = self.qkv(x) 113 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 114 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 115 | 116 | q = q * self.scale 117 | attn = (q @ k.transpose(-2, -1)) 118 | 119 | attn = attn.softmax(dim=-1) 120 | attn = self.attn_drop(attn) 121 | 122 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 123 | x = self.proj(x) 124 | x = self.proj_drop(x) 125 | 126 | return x 127 | 128 | class Block(nn.Module): 129 | 130 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 131 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, 132 | norm_layer=nn.LayerNorm, attn_head_dim=None 133 | ): 134 | super().__init__() 135 | 136 | self.norm1 = norm_layer(dim) 137 | self.attn = Attention( 138 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 139 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim 140 | ) 141 | 142 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 143 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 144 | self.norm2 = norm_layer(dim) 145 | mlp_hidden_dim = int(dim * mlp_ratio) 146 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 147 | 148 | def forward(self, x): 149 | x = x + self.drop_path(self.attn(self.norm1(x))) 150 | x = x + self.drop_path(self.mlp(self.norm2(x))) 151 | return x 152 | 153 | class PatchEmbed(nn.Module): 154 | """ Image to Patch Embedding 155 | """ 156 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): 157 | super().__init__() 158 | img_size = to_2tuple(img_size) 159 | patch_size = to_2tuple(patch_size) 160 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) 161 | self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) 162 | self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.num_patches = num_patches 166 | 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) 168 | 169 | def forward(self, x, **kwargs): 170 | B, C, H, W = x.shape 171 | x = self.proj(x) 172 | Hp, Wp = x.shape[2], x.shape[3] 173 | 174 | x = x.flatten(2).transpose(1, 2) 175 | return x, (Hp, Wp) 176 | 177 | class HybridEmbed(nn.Module): 178 | """ CNN Feature Map Embedding 179 | Extract feature map from CNN, flatten, project to embedding dim. 180 | """ 181 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 182 | super().__init__() 183 | assert isinstance(backbone, nn.Module) 184 | img_size = to_2tuple(img_size) 185 | self.img_size = img_size 186 | self.backbone = backbone 187 | if feature_size is None: 188 | with torch.no_grad(): 189 | training = backbone.training 190 | if training: 191 | backbone.eval() 192 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 193 | feature_size = o.shape[-2:] 194 | feature_dim = o.shape[1] 195 | backbone.train(training) 196 | else: 197 | feature_size = to_2tuple(feature_size) 198 | feature_dim = self.backbone.feature_info.channels()[-1] 199 | self.num_patches = feature_size[0] * feature_size[1] 200 | self.proj = nn.Linear(feature_dim, embed_dim) 201 | 202 | def forward(self, x): 203 | x = self.backbone(x)[-1] 204 | x = x.flatten(2).transpose(1, 2) 205 | x = self.proj(x) 206 | return x 207 | 208 | class ViT(nn.Module): 209 | 210 | def __init__(self, 211 | img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, 212 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 213 | drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, 214 | frozen_stages=-1, ratio=1, last_norm=True, 215 | patch_padding='pad', freeze_attn=False, freeze_ffn=False, 216 | ): 217 | # Protect mutable default arguments 218 | super(ViT, self).__init__() 219 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 220 | self.num_classes = num_classes 221 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 222 | self.frozen_stages = frozen_stages 223 | self.use_checkpoint = use_checkpoint 224 | self.patch_padding = patch_padding 225 | self.freeze_attn = freeze_attn 226 | self.freeze_ffn = freeze_ffn 227 | self.depth = depth 228 | 229 | if hybrid_backbone is not None: 230 | self.patch_embed = HybridEmbed( 231 | 
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 232 | else: 233 | self.patch_embed = PatchEmbed( 234 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) 235 | num_patches = self.patch_embed.num_patches 236 | 237 | # since the pretraining model has class token 238 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 239 | 240 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 241 | 242 | self.blocks = nn.ModuleList([ 243 | Block( 244 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 246 | ) 247 | for i in range(depth)]) 248 | 249 | self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() 250 | 251 | if self.pos_embed is not None: 252 | trunc_normal_(self.pos_embed, std=.02) 253 | 254 | self._freeze_stages() 255 | 256 | def _freeze_stages(self): 257 | """Freeze parameters.""" 258 | if self.frozen_stages >= 0: 259 | self.patch_embed.eval() 260 | for param in self.patch_embed.parameters(): 261 | param.requires_grad = False 262 | 263 | for i in range(1, self.frozen_stages + 1): 264 | m = self.blocks[i] 265 | m.eval() 266 | for param in m.parameters(): 267 | param.requires_grad = False 268 | 269 | if self.freeze_attn: 270 | for i in range(0, self.depth): 271 | m = self.blocks[i] 272 | m.attn.eval() 273 | m.norm1.eval() 274 | for param in m.attn.parameters(): 275 | param.requires_grad = False 276 | for param in m.norm1.parameters(): 277 | param.requires_grad = False 278 | 279 | if self.freeze_ffn: 280 | self.pos_embed.requires_grad = False 281 | self.patch_embed.eval() 282 | for param in self.patch_embed.parameters(): 283 | param.requires_grad = False 284 | for i in range(0, self.depth): 285 | m = self.blocks[i] 286 | m.mlp.eval() 287 | m.norm2.eval() 288 | for param in m.mlp.parameters(): 289 | param.requires_grad = False 290 | for param in m.norm2.parameters(): 291 | param.requires_grad = False 292 | 293 | def init_weights(self): 294 | """Initialize the weights in backbone. 295 | Args: 296 | pretrained (str, optional): Path to pre-trained weights. 297 | Defaults to None. 
298 | """ 299 | def _init_weights(m): 300 | if isinstance(m, nn.Linear): 301 | trunc_normal_(m.weight, std=.02) 302 | if isinstance(m, nn.Linear) and m.bias is not None: 303 | nn.init.constant_(m.bias, 0) 304 | elif isinstance(m, nn.LayerNorm): 305 | nn.init.constant_(m.bias, 0) 306 | nn.init.constant_(m.weight, 1.0) 307 | 308 | self.apply(_init_weights) 309 | 310 | def get_num_layers(self): 311 | return len(self.blocks) 312 | 313 | @torch.jit.ignore 314 | def no_weight_decay(self): 315 | return {'pos_embed', 'cls_token'} 316 | 317 | def forward_features(self, x): 318 | B, C, H, W = x.shape 319 | x, (Hp, Wp) = self.patch_embed(x) 320 | 321 | if self.pos_embed is not None: 322 | # fit for multiple GPU training 323 | # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference 324 | x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] 325 | 326 | for blk in self.blocks: 327 | if self.use_checkpoint: 328 | x = checkpoint.checkpoint(blk, x) 329 | else: 330 | x = blk(x) 331 | 332 | x = self.last_norm(x) 333 | 334 | xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() 335 | 336 | return xp 337 | 338 | def forward(self, x): 339 | x = self.forward_features(x) 340 | return x 341 | 342 | def train(self, mode=True): 343 | """Convert the model into training mode.""" 344 | super().train(mode) 345 | self._freeze_stages() -------------------------------------------------------------------------------- /lib/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaobenLi00/U-HMR/54f2b6516018a79590ea8e5d160fe1f3bd9b7c8d/lib/models/components/__init__.py -------------------------------------------------------------------------------- /lib/models/components/pose_transformer.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | from typing import Callable, Optional 3 | 4 | import torch 5 | from einops import rearrange 6 | from einops.layers.torch import Rearrange 7 | from torch import nn 8 | 9 | from .t_cond_mlp import ( 10 | AdaptiveLayerNorm1D, 11 | FrequencyEmbedder, 12 | normalization_layer, 13 | ) 14 | # from .vit import Attention, FeedForward 15 | 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | 21 | def default(val, d): 22 | if exists(val): 23 | return val 24 | return d() if isfunction(d) else d 25 | 26 | 27 | class PreNorm(nn.Module): 28 | def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1): 29 | super().__init__() 30 | self.norm = normalization_layer(norm, dim, norm_cond_dim) 31 | self.fn = fn 32 | 33 | def forward(self, x: torch.Tensor, *args, **kwargs): 34 | if isinstance(self.norm, AdaptiveLayerNorm1D): 35 | return self.fn(self.norm(x, *args), **kwargs) 36 | else: 37 | return self.fn(self.norm(x), **kwargs) 38 | 39 | 40 | class FeedForward(nn.Module): 41 | def __init__(self, dim, hidden_dim, dropout=0.0): 42 | super().__init__() 43 | self.net = nn.Sequential( 44 | nn.Linear(dim, hidden_dim), 45 | nn.GELU(), 46 | nn.Dropout(dropout), 47 | nn.Linear(hidden_dim, dim), 48 | nn.Dropout(dropout), 49 | ) 50 | 51 | def forward(self, x): 52 | return self.net(x) 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 57 | super().__init__() 58 | inner_dim = dim_head * heads 59 | project_out = not (heads == 1 and dim_head == dim) 60 | 61 | self.heads = heads 62 | self.scale = dim_head**-0.5 63 | 64 | 
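# Note on this attention block: `scale = dim_head ** -0.5` is the usual
# 1/sqrt(d_k) factor of scaled dot-product attention. A single bias-free Linear
# produces q, k and v jointly (width inner_dim * 3 = heads * dim_head * 3);
# forward() splits them per head with einops.rearrange, applies
# softmax(q @ k^T * scale) @ v, and `to_out` projects the concatenated heads
# back to `dim` unless heads == 1 and dim_head == dim, where it is an identity.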
self.attend = nn.Softmax(dim=-1) 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 68 | 69 | self.to_out = ( 70 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 71 | if project_out 72 | else nn.Identity() 73 | ) 74 | 75 | def forward(self, x): 76 | qkv = self.to_qkv(x).chunk(3, dim=-1) 77 | q, k, v = map(lambda t: rearrange( 78 | t, "b n (h d) -> b h n d", h=self.heads), qkv) 79 | 80 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 81 | 82 | attn = self.attend(dots) 83 | attn = self.dropout(attn) 84 | 85 | out = torch.matmul(attn, v) 86 | out = rearrange(out, "b h n d -> b n (h d)") 87 | return self.to_out(out) 88 | 89 | 90 | class CrossAttention(nn.Module): 91 | def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 92 | super().__init__() 93 | inner_dim = dim_head * heads 94 | project_out = not (heads == 1 and dim_head == dim) 95 | 96 | self.heads = heads 97 | self.scale = dim_head**-0.5 98 | 99 | self.attend = nn.Softmax(dim=-1) 100 | self.dropout = nn.Dropout(dropout) 101 | 102 | context_dim = default(context_dim, dim) 103 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 104 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 105 | 106 | self.to_out = ( 107 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 108 | if project_out 109 | else nn.Identity() 110 | ) 111 | 112 | def forward(self, x, context=None): 113 | context = default(context, x) 114 | k, v = self.to_kv(context).chunk(2, dim=-1) 115 | q = self.to_q(x) 116 | q, k, v = map(lambda t: rearrange( 117 | t, "b n (h d) -> b h n d", h=self.heads), [q, k, v]) 118 | 119 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 120 | 121 | attn = self.attend(dots) 122 | attn = self.dropout(attn) 123 | 124 | out = torch.matmul(attn, v) 125 | out = rearrange(out, "b h n d -> b n (h d)") 126 | return self.to_out(out) 127 | 128 | 129 | class Transformer(nn.Module): 130 | def __init__( 131 | self, 132 | dim: int, 133 | depth: int, 134 | heads: int, 135 | dim_head: int, 136 | mlp_dim: int, 137 | dropout: float = 0.0, 138 | norm: str = "layer", 139 | norm_cond_dim: int = -1, 140 | ): 141 | super().__init__() 142 | self.layers = nn.ModuleList([]) 143 | for _ in range(depth): 144 | sa = Attention(dim, heads=heads, 145 | dim_head=dim_head, dropout=dropout) 146 | ff = FeedForward(dim, mlp_dim, dropout=dropout) 147 | self.layers.append( 148 | nn.ModuleList( 149 | [ 150 | PreNorm(dim, sa, norm=norm, 151 | norm_cond_dim=norm_cond_dim), 152 | PreNorm(dim, ff, norm=norm, 153 | norm_cond_dim=norm_cond_dim), 154 | ] 155 | ) 156 | ) 157 | 158 | def forward(self, x: torch.Tensor, *args): 159 | for attn, ff in self.layers: 160 | x = attn(x, *args) + x 161 | x = ff(x, *args) + x 162 | return x 163 | 164 | 165 | class TransformerCrossAttn(nn.Module): 166 | def __init__( 167 | self, 168 | dim: int, 169 | depth: int, 170 | heads: int, 171 | dim_head: int, 172 | mlp_dim: int, 173 | dropout: float = 0.0, 174 | norm: str = "layer", 175 | norm_cond_dim: int = -1, 176 | context_dim: Optional[int] = None, 177 | ): 178 | super().__init__() 179 | self.layers = nn.ModuleList([]) 180 | for _ in range(depth): 181 | sa = Attention(dim, heads=heads, 182 | dim_head=dim_head, dropout=dropout) 183 | ca = CrossAttention( 184 | dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout 185 | ) 186 | ff = FeedForward(dim, mlp_dim, dropout=dropout) 187 | self.layers.append( 188 | nn.ModuleList( 189 | [ 190 | PreNorm(dim, sa, 
norm=norm, 191 | norm_cond_dim=norm_cond_dim), 192 | PreNorm(dim, ca, norm=norm, 193 | norm_cond_dim=norm_cond_dim), 194 | PreNorm(dim, ff, norm=norm, 195 | norm_cond_dim=norm_cond_dim), 196 | ] 197 | ) 198 | ) 199 | 200 | def forward(self, x: torch.Tensor, *args, context=None, context_list=None): 201 | if context_list is None: 202 | context_list = [context] * len(self.layers) 203 | if len(context_list) != len(self.layers): 204 | raise ValueError( 205 | f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})") 206 | 207 | for i, (self_attn, cross_attn, ff) in enumerate(self.layers): 208 | x = self_attn(x, *args) + x 209 | x = cross_attn(x, *args, context=context_list[i]) + x 210 | x = ff(x, *args) + x 211 | return x 212 | 213 | 214 | class DropTokenDropout(nn.Module): 215 | def __init__(self, p: float = 0.1): 216 | super().__init__() 217 | if p < 0 or p > 1: 218 | raise ValueError( 219 | "dropout probability has to be between 0 and 1, " "but got {}".format( 220 | p) 221 | ) 222 | self.p = p 223 | 224 | def forward(self, x: torch.Tensor): 225 | # x: (batch_size, seq_len, dim) 226 | if self.training and self.p > 0: 227 | zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool() 228 | # TODO: permutation idx for each batch using torch.argsort 229 | if zero_mask.any(): 230 | x = x[:, ~zero_mask, :] 231 | return x 232 | 233 | 234 | class ZeroTokenDropout(nn.Module): 235 | def __init__(self, p: float = 0.1): 236 | super().__init__() 237 | if p < 0 or p > 1: 238 | raise ValueError( 239 | "dropout probability has to be between 0 and 1, " "but got {}".format( 240 | p) 241 | ) 242 | self.p = p 243 | 244 | def forward(self, x: torch.Tensor): 245 | # x: (batch_size, seq_len, dim) 246 | if self.training and self.p > 0: 247 | zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool() 248 | # Zero-out the masked tokens 249 | x[zero_mask, :] = 0 250 | return x 251 | 252 | 253 | class TransformerEncoder(nn.Module): 254 | def __init__( 255 | self, 256 | num_tokens: int, 257 | token_dim: int, 258 | dim: int, 259 | depth: int, 260 | heads: int, 261 | mlp_dim: int, 262 | dim_head: int = 64, 263 | dropout: float = 0.0, 264 | emb_dropout: float = 0.0, 265 | emb_dropout_type: str = "drop", 266 | emb_dropout_loc: str = "token", 267 | norm: str = "layer", 268 | norm_cond_dim: int = -1, 269 | token_pe_numfreq: int = -1, 270 | ): 271 | super().__init__() 272 | if token_pe_numfreq > 0: 273 | token_dim_new = token_dim * (2 * token_pe_numfreq + 1) 274 | self.to_token_embedding = nn.Sequential( 275 | Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim), 276 | FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1), 277 | Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new), 278 | nn.Linear(token_dim_new, dim), 279 | ) 280 | else: 281 | self.to_token_embedding = nn.Linear(token_dim, dim) 282 | self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) 283 | if emb_dropout_type == "drop": 284 | self.dropout = DropTokenDropout(emb_dropout) 285 | elif emb_dropout_type == "zero": 286 | self.dropout = ZeroTokenDropout(emb_dropout) 287 | else: 288 | raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}") 289 | self.emb_dropout_loc = emb_dropout_loc 290 | 291 | self.transformer = Transformer( 292 | dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim 293 | ) 294 | 295 | def forward(self, inp: torch.Tensor, *args, **kwargs): 296 | x = inp 297 | 298 | if self.emb_dropout_loc == "input": 299 | x = self.dropout(x) 300 | x = 
self.to_token_embedding(x) 301 | 302 | if self.emb_dropout_loc == "token": 303 | x = self.dropout(x) 304 | b, n, _ = x.shape 305 | x += self.pos_embedding[:, :n] 306 | 307 | if self.emb_dropout_loc == "token_afterpos": 308 | x = self.dropout(x) 309 | x = self.transformer(x, *args) 310 | return x 311 | 312 | 313 | class TransformerDecoder(nn.Module): 314 | def __init__( 315 | self, 316 | num_tokens: int, 317 | token_dim: int, 318 | dim: int, 319 | depth: int, 320 | heads: int, 321 | mlp_dim: int, 322 | dim_head: int = 64, 323 | dropout: float = 0.0, 324 | emb_dropout: float = 0.0, 325 | emb_dropout_type: str = 'drop', 326 | norm: str = "layer", 327 | norm_cond_dim: int = -1, 328 | context_dim: Optional[int] = None, 329 | skip_token_embedding: bool = False, 330 | pos_embedding = None, 331 | ): 332 | super().__init__() 333 | if not skip_token_embedding: 334 | self.to_token_embedding = nn.Linear(token_dim, dim) 335 | else: 336 | self.to_token_embedding = nn.Identity() 337 | if token_dim != dim: 338 | raise ValueError( 339 | f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True" 340 | ) 341 | 342 | self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim)) 343 | 344 | if emb_dropout_type == "drop": 345 | self.dropout = DropTokenDropout(emb_dropout) 346 | elif emb_dropout_type == "zero": 347 | self.dropout = ZeroTokenDropout(emb_dropout) 348 | elif emb_dropout_type == "normal": 349 | self.dropout = nn.Dropout(emb_dropout) 350 | 351 | self.transformer = TransformerCrossAttn( 352 | dim, 353 | depth, 354 | heads, 355 | dim_head, 356 | mlp_dim, 357 | dropout, 358 | norm=norm, 359 | norm_cond_dim=norm_cond_dim, 360 | context_dim=context_dim, 361 | ) 362 | 363 | def forward(self, inp: torch.Tensor, *args, context=None, context_list=None): 364 | x = self.to_token_embedding(inp) 365 | b, n, _ = x.shape 366 | 367 | x = self.dropout(x) 368 | x += self.pos_embedding[:, :n] # torch.Size([8, 1, 768]) torch.Size([2, 1, 2048]) 369 | 370 | x = self.transformer(x, *args, context=context, # context torch.Size([8, 64, 768]) 371 | context_list=context_list) 372 | return x 373 | -------------------------------------------------------------------------------- /lib/models/components/t_cond_mlp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List, Optional 3 | 4 | import torch 5 | 6 | 7 | class AdaptiveLayerNorm1D(torch.nn.Module): 8 | def __init__(self, data_dim: int, norm_cond_dim: int): 9 | super().__init__() 10 | if data_dim <= 0: 11 | raise ValueError(f"data_dim must be positive, but got {data_dim}") 12 | if norm_cond_dim <= 0: 13 | raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}") 14 | self.norm = torch.nn.LayerNorm( 15 | data_dim 16 | ) # TODO: Check if elementwise_affine=True is correct 17 | self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim) 18 | torch.nn.init.zeros_(self.linear.weight) 19 | torch.nn.init.zeros_(self.linear.bias) 20 | 21 | def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 22 | # x: (batch, ..., data_dim) 23 | # t: (batch, norm_cond_dim) 24 | # return: (batch, data_dim) 25 | x = self.norm(x) 26 | alpha, beta = self.linear(t).chunk(2, dim=-1) 27 | 28 | # Add singleton dimensions to alpha and beta 29 | if x.dim() > 2: 30 | alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1]) 31 | beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1]) 32 | 33 | return x * (1 + alpha) + beta 34 | 35 | 36 | class 
SequentialCond(torch.nn.Sequential): 37 | def forward(self, input, *args, **kwargs): 38 | for module in self: 39 | if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)): 40 | # print(f'Passing on args to {module}', [a.shape for a in args]) 41 | input = module(input, *args, **kwargs) 42 | else: 43 | # print(f'Skipping passing args to {module}', [a.shape for a in args]) 44 | input = module(input) 45 | return input 46 | 47 | 48 | def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1): 49 | if norm == "batch": 50 | return torch.nn.BatchNorm1d(dim) 51 | elif norm == "layer": 52 | return torch.nn.LayerNorm(dim) 53 | elif norm == "ada": 54 | assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}" 55 | return AdaptiveLayerNorm1D(dim, norm_cond_dim) 56 | elif norm is None: 57 | return torch.nn.Identity() 58 | else: 59 | raise ValueError(f"Unknown norm: {norm}") 60 | 61 | 62 | def linear_norm_activ_dropout( 63 | input_dim: int, 64 | output_dim: int, 65 | activation: torch.nn.Module = torch.nn.ReLU(), 66 | bias: bool = True, 67 | norm: Optional[str] = "layer", # Options: ada/batch/layer 68 | dropout: float = 0.0, 69 | norm_cond_dim: int = -1, 70 | ) -> SequentialCond: 71 | layers = [] 72 | layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias)) 73 | if norm is not None: 74 | layers.append(normalization_layer(norm, output_dim, norm_cond_dim)) 75 | layers.append(copy.deepcopy(activation)) 76 | if dropout > 0.0: 77 | layers.append(torch.nn.Dropout(dropout)) 78 | return SequentialCond(*layers) 79 | 80 | 81 | def create_simple_mlp( 82 | input_dim: int, 83 | hidden_dims: List[int], 84 | output_dim: int, 85 | activation: torch.nn.Module = torch.nn.ReLU(), 86 | bias: bool = True, 87 | norm: Optional[str] = "layer", # Options: ada/batch/layer 88 | dropout: float = 0.0, 89 | norm_cond_dim: int = -1, 90 | ) -> SequentialCond: 91 | layers = [] 92 | prev_dim = input_dim 93 | for hidden_dim in hidden_dims: 94 | layers.extend( 95 | linear_norm_activ_dropout( 96 | prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim 97 | ) 98 | ) 99 | prev_dim = hidden_dim 100 | layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias)) 101 | return SequentialCond(*layers) 102 | 103 | 104 | class ResidualMLPBlock(torch.nn.Module): 105 | def __init__( 106 | self, 107 | input_dim: int, 108 | hidden_dim: int, 109 | num_hidden_layers: int, 110 | output_dim: int, 111 | activation: torch.nn.Module = torch.nn.ReLU(), 112 | bias: bool = True, 113 | norm: Optional[str] = "layer", # Options: ada/batch/layer 114 | dropout: float = 0.0, 115 | norm_cond_dim: int = -1, 116 | ): 117 | super().__init__() 118 | if not (input_dim == output_dim == hidden_dim): 119 | raise NotImplementedError( 120 | f"input_dim {input_dim} != output_dim {output_dim} is not implemented" 121 | ) 122 | 123 | layers = [] 124 | prev_dim = input_dim 125 | for i in range(num_hidden_layers): 126 | layers.append( 127 | linear_norm_activ_dropout( 128 | prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim 129 | ) 130 | ) 131 | prev_dim = hidden_dim 132 | self.model = SequentialCond(*layers) 133 | self.skip = torch.nn.Identity() 134 | 135 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 136 | return x + self.model(x, *args, **kwargs) 137 | 138 | 139 | class ResidualMLP(torch.nn.Module): 140 | def __init__( 141 | self, 142 | input_dim: int, 143 | hidden_dim: int, 144 | num_hidden_layers: int, 145 | output_dim: int, 146 | activation: 
torch.nn.Module = torch.nn.ReLU(), 147 | bias: bool = True, 148 | norm: Optional[str] = "layer", # Options: ada/batch/layer 149 | dropout: float = 0.0, 150 | num_blocks: int = 1, 151 | norm_cond_dim: int = -1, 152 | ): 153 | super().__init__() 154 | self.input_dim = input_dim 155 | self.model = SequentialCond( 156 | linear_norm_activ_dropout( 157 | input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim 158 | ), 159 | *[ 160 | ResidualMLPBlock( 161 | hidden_dim, 162 | hidden_dim, 163 | num_hidden_layers, 164 | hidden_dim, 165 | activation, 166 | bias, 167 | norm, 168 | dropout, 169 | norm_cond_dim, 170 | ) 171 | for _ in range(num_blocks) 172 | ], 173 | torch.nn.Linear(hidden_dim, output_dim, bias=bias), 174 | ) 175 | 176 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 177 | return self.model(x, *args, **kwargs) 178 | 179 | 180 | class FrequencyEmbedder(torch.nn.Module): 181 | def __init__(self, num_frequencies, max_freq_log2): 182 | super().__init__() 183 | frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies) 184 | self.register_buffer("frequencies", frequencies) 185 | 186 | def forward(self, x): 187 | # x should be of size (N,) or (N, D) 188 | N = x.size(0) 189 | if x.dim() == 1: # (N,) 190 | x = x.unsqueeze(1) # (N, D) where D=1 191 | x_unsqueezed = x.unsqueeze(-1) # (N, D, 1) 192 | scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies) 193 | s = torch.sin(scaled) 194 | c = torch.cos(scaled) 195 | embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view( 196 | N, -1 197 | ) # (N, D * 2 * num_frequencies + D) 198 | return embedded 199 | 200 | -------------------------------------------------------------------------------- /lib/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Discriminator(nn.Module): 5 | 6 | def __init__(self): 7 | """ 8 | Pose + Shape discriminator proposed in HMR 9 | """ 10 | super(Discriminator, self).__init__() 11 | 12 | self.num_joints = 23 13 | # poses_alone 14 | self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1) 15 | nn.init.xavier_uniform_(self.D_conv1.weight) 16 | nn.init.zeros_(self.D_conv1.bias) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1) 19 | nn.init.xavier_uniform_(self.D_conv2.weight) 20 | nn.init.zeros_(self.D_conv2.bias) 21 | pose_out = [] 22 | for i in range(self.num_joints): 23 | pose_out_temp = nn.Linear(32, 1) 24 | nn.init.xavier_uniform_(pose_out_temp.weight) 25 | nn.init.zeros_(pose_out_temp.bias) 26 | pose_out.append(pose_out_temp) 27 | self.pose_out = nn.ModuleList(pose_out) 28 | 29 | # betas 30 | self.betas_fc1 = nn.Linear(10, 10) 31 | nn.init.xavier_uniform_(self.betas_fc1.weight) 32 | nn.init.zeros_(self.betas_fc1.bias) 33 | self.betas_fc2 = nn.Linear(10, 5) 34 | nn.init.xavier_uniform_(self.betas_fc2.weight) 35 | nn.init.zeros_(self.betas_fc2.bias) 36 | self.betas_out = nn.Linear(5, 1) 37 | nn.init.xavier_uniform_(self.betas_out.weight) 38 | nn.init.zeros_(self.betas_out.bias) 39 | 40 | # poses_joint 41 | self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024) 42 | nn.init.xavier_uniform_(self.D_alljoints_fc1.weight) 43 | nn.init.zeros_(self.D_alljoints_fc1.bias) 44 | self.D_alljoints_fc2 = nn.Linear(1024, 1024) 45 | nn.init.xavier_uniform_(self.D_alljoints_fc2.weight) 46 | nn.init.zeros_(self.D_alljoints_fc2.bias) 47 | self.D_alljoints_out = nn.Linear(1024, 1) 48 | 
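# Note on the heads defined above: following the HMR discriminator, each of the
# 23 joint rotations gets its own 32 -> 1 linear head (self.pose_out), the 10
# betas go through betas_fc1/fc2/betas_out, and the concatenated per-joint
# features go through D_alljoints_fc1/fc2/D_alljoints_out, so forward() returns
# 23 + 1 + 1 = 25 real/fake scores per sample.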
nn.init.xavier_uniform_(self.D_alljoints_out.weight) 49 | nn.init.zeros_(self.D_alljoints_out.bias) 50 | 51 | 52 | def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor: 53 | """ 54 | Forward pass of the discriminator. 55 | Args: 56 | poses (torch.Tensor): Tensor of shape (B, 23, 3, 3) containing a batch of SMPL body poses (excluding the global orientation). 57 | betas (torch.Tensor): Tensor of shape (B, 10) containign a batch of SMPL beta coefficients. 58 | Returns: 59 | torch.Tensor: Discriminator output with shape (B, 25) 60 | """ 61 | #import ipdb; ipdb.set_trace() 62 | #bn = poses.shape[0] 63 | # poses B x 207 64 | #poses = poses.reshape(bn, -1) 65 | # poses B x num_joints x 1 x 9 66 | poses = poses.reshape(-1, self.num_joints, 1, 9) 67 | bn = poses.shape[0] 68 | # poses B x 9 x num_joints x 1 69 | poses = poses.permute(0, 3, 1, 2).contiguous() 70 | 71 | # poses_alone 72 | poses = self.D_conv1(poses) 73 | poses = self.relu(poses) 74 | poses = self.D_conv2(poses) 75 | poses = self.relu(poses) 76 | 77 | poses_out = [] 78 | for i in range(self.num_joints): 79 | poses_out_ = self.pose_out[i](poses[:, :, i, 0]) 80 | poses_out.append(poses_out_) 81 | poses_out = torch.cat(poses_out, dim=1) 82 | 83 | # betas 84 | betas = self.betas_fc1(betas) 85 | betas = self.relu(betas) 86 | betas = self.betas_fc2(betas) 87 | betas = self.relu(betas) 88 | betas_out = self.betas_out(betas) 89 | 90 | # poses_joint 91 | poses = poses.reshape(bn,-1) 92 | poses_all = self.D_alljoints_fc1(poses) 93 | poses_all = self.relu(poses_all) 94 | poses_all = self.D_alljoints_fc2(poses_all) 95 | poses_all = self.relu(poses_all) 96 | poses_all_out = self.D_alljoints_out(poses_all) 97 | 98 | disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1) 99 | return disc_out 100 | -------------------------------------------------------------------------------- /lib/models/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from .smpl_wrapper import SMPL 5 | from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss 6 | from typing import Any, Dict, Mapping, Tuple 7 | from ..utils.geometry import perspective_projection, rot6d_to_rotmat, aa_to_rotmat 8 | from ..utils import vis 9 | from ..utils.pose_utils import reconstruction_error 10 | from einops import rearrange 11 | from .discriminator import Discriminator 12 | import torch.nn.functional as F 13 | import math 14 | from tensorboardX import SummaryWriter 15 | import logging 16 | from .backbones import create_backbone 17 | from .heads import build_smpl_head 18 | import trimesh 19 | logger = logging.getLogger(__name__) 20 | 21 | class Mv_Fusion(nn.Module): 22 | def __init__(self, cfg, tensorboard_log_dir) -> None: 23 | super(Mv_Fusion, self).__init__() 24 | self.cfg = cfg 25 | # Create backbone feature extractor 26 | self.backbone = create_backbone(cfg) 27 | if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None): 28 | logger.info(f'=> Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}') 29 | # self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict']) 30 | if cfg.MODEL.BACKBONE.TYPE == 'resnet': 31 | self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu'), strict=False) 32 | elif cfg.MODEL.BACKBONE.TYPE == 'vit': 33 | self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, 
map_location='cpu')['state_dict']) 34 | # self.backbone.eval() 35 | # for param in self.backbone.parameters(): 36 | # param.requires_grad = False 37 | elif cfg.MODEL.BACKBONE.TYPE == 'swin_v2': 38 | self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['model']) 39 | 40 | # Create SMPL head 41 | self.smpl_head = build_smpl_head(cfg) 42 | 43 | # Create discriminator 44 | if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: 45 | self.discriminator = Discriminator() 46 | 47 | # Define loss functions 48 | self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1') 49 | self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1') 50 | self.smpl_parameter_loss = ParameterLoss() 51 | if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: 52 | self.optimizer, self.optimizer_disc = self.get_optimizer() 53 | else: 54 | self.optimizer = self.get_optimizer() 55 | 56 | 57 | smpl_cfg = {k.lower(): v for k,v in dict(cfg.SMPL).items()} 58 | self.smpl = SMPL(**smpl_cfg) 59 | joints_list_17 = [14, 2, 1, 0, 3, 4, 5, 16, 12, 17, 18, 9, 10, 11, 8, 7, 6] 60 | self.joints_list = [i + 25 for i in joints_list_17] 61 | # self.n_views = cfg.DATASET.N_VIEWS 62 | 63 | self.writer = SummaryWriter(log_dir=tensorboard_log_dir) 64 | if cfg.TRAIN.RENDER_MESH: 65 | 66 | from ..utils.renderer import Renderer 67 | self.mesh_renderer = Renderer(focal_length=cfg.EXTRA.FOCAL_LENGTH, 68 | img_res=256, faces=self.smpl.faces) 69 | self.step_count = {'train': 0, 70 | 'val': 0, 71 | 'train_vis': 0, 72 | 'val_vis': 0} 73 | self.log_dict_mpii3d = {'mpjpe': [], 74 | 'pa-mpjpe': []} 75 | # self.epoch = 0 76 | # self.len_data = 0 77 | # self.count_dict = {'epoch': 0, 'len_data': 0} 78 | 79 | def forward_step(self, x, n_views) -> Dict: 80 | x = torch.cat(x, dim = 0) 81 | if self.cfg.MODEL.BACKBONE.TYPE == 'vit': 82 | features = self.backbone(x[:,:,:,32:-32]) 83 | else: 84 | features = self.backbone(x) 85 | pred_body_pose, pred_betas, pred_global_orientation, pred_cam = self.smpl_head(features, n_views) 86 | n_sample = x.shape[0] 87 | 88 | # pred_body_pose = pred_body_pose.repeat(self.n_views, 1, 1, 1) 89 | # pred_betas = pred_betas.repeat(self.n_views, 1) 90 | pred_body_pose = pred_body_pose.repeat(n_views, 1, 1, 1) 91 | pred_betas = pred_betas.repeat(n_views, 1) 92 | pred_pose = torch.cat([pred_global_orientation, pred_body_pose], dim = 1) 93 | 94 | output = self.forward_smpl(pred_pose, pred_betas, pred_cam, n_sample) 95 | 96 | return output 97 | 98 | def compute_loss(self, output, meta, dataset = None): 99 | 100 | pred_smpl_params = output['pred_smpl_params'] 101 | pred_keypoints_2d = output['pred_keypoints_2d'] 102 | pred_keypoints_3d = output['pred_keypoints_3d'] 103 | 104 | n_sample = pred_smpl_params['body_pose'].shape[0] 105 | 106 | # Get annotations 107 | # gt_keypoints_2d = batch['keypoints_2d'] 108 | # gt_keypoints_3d = batch['keypoints_3d'] 109 | gt_2d_list = [] 110 | gt_3d_list = [] 111 | vis_list = [] 112 | for m in meta: 113 | gt_2d_list.append(m['joints_2d_transformed']) 114 | gt_3d_list.append(m['joints_3d_camera']) 115 | vis_list.append(m['joints_vis']) 116 | vis_joints = torch.cat(vis_list, dim=0) 117 | gt_keypoints_2d = torch.cat(gt_2d_list, dim=0) 118 | gt_keypoints_3d = torch.cat(gt_3d_list, dim=0) / 1000 119 | gt_keypoints_2d = gt_keypoints_2d / (self.cfg.MODEL.IMAGE_SIZE[0] / 1.) 
- 0.5 120 | gt_keypoints_3d = gt_keypoints_3d - gt_keypoints_3d[:, [0], :] 121 | pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, [0], :] 122 | # print(pred_keypoints_3d.shape) 123 | # print(gt_keypoints_3d.shape) 124 | # print(vis_joints.shape) 125 | # print(pred_keypoints_2d.shape) 126 | # print(gt_keypoints_2d.shape) 127 | # Compute 2D and 3D keypoint loss 128 | loss_keypoints_2d = self.keypoint_2d_loss(vis_joints[:,:,[0]]*pred_keypoints_2d, vis_joints[:,:,[0]]*gt_keypoints_2d) 129 | loss_keypoints_3d = self.keypoint_3d_loss(vis_joints[:,:,[0]]*pred_keypoints_3d, vis_joints[:,:,[0]]*gt_keypoints_3d) 130 | loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\ 131 | self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d 132 | # if has_smpl: 133 | # gt_smpl_params = batch['smpl_params'] 134 | gt_global_orient_list = [] 135 | 136 | gt_body_pose_list = [] 137 | gt_shape_list = [] 138 | has_global_orient_list = [] 139 | 140 | has_body_pose_list = [] 141 | has_shape_list = [] 142 | 143 | for m in meta: 144 | # gt_global_orient_list.append(aa_to_rotmat(m['smpl_params']['global_orient'])) 145 | # gt_body_pose_list.append(aa_to_rotmat(m['smpl_params']['body_pose'])) 146 | gt_global_orient_list.append(m['smpl_params']['global_orient']) 147 | gt_body_pose_list.append(m['smpl_params']['body_pose']) 148 | gt_shape_list.append(m['smpl_params']['betas']) 149 | 150 | has_global_orient_list.append(m['has_smpl_params']['global_orient']) 151 | has_body_pose_list.append(m['has_smpl_params']['body_pose']) 152 | has_shape_list.append(m['has_smpl_params']['betas']) 153 | gt_global_orient = torch.cat(gt_global_orient_list, dim=0) 154 | gt_body_pose = torch.cat(gt_body_pose_list, dim=0) 155 | gt_shape = torch.cat(gt_shape_list, dim=0) 156 | has_global_orient = torch.cat(has_global_orient_list, dim=0) 157 | has_body_pose = torch.cat(has_body_pose_list, dim=0) 158 | has_shape = torch.cat(has_shape_list, dim=0) 159 | 160 | gt_smpl_params = {'global_orient': gt_global_orient, 161 | 'body_pose': gt_body_pose, 162 | 'betas': gt_shape} 163 | has_smpl_params = {'global_orient':has_global_orient, 164 | 'body_pose': has_body_pose, 165 | 'betas': has_shape} 166 | # Compute loss on SMPL parameters 167 | loss_smpl_params = {} 168 | for k, pred in pred_smpl_params.items(): 169 | gt = gt_smpl_params[k].view(n_sample, -1) 170 | has_gt = has_smpl_params[k] 171 | loss_smpl_params[k] = self.smpl_parameter_loss(pred.reshape(n_sample, -1), gt.reshape(n_sample, -1), has_gt) 172 | 173 | loss += sum([loss_smpl_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_smpl_params]) 174 | 175 | losses = dict(loss=loss.detach(), 176 | loss_keypoints_2d=loss_keypoints_2d.detach(), 177 | loss_keypoints_3d=loss_keypoints_3d.detach()) 178 | # if has_smpl: 179 | for k, v in loss_smpl_params.items(): 180 | losses['loss_' + k] = v.detach() 181 | 182 | output['losses'] = losses 183 | 184 | # mpjpe = np.mean(np.sqrt(np.sum((pred_keypoints_3d.detach().cpu().numpy() - gt_keypoints_3d.detach().cpu().numpy()) ** 2, axis=-1))) * 1000 185 | # rec_error = reconstruction_error(pred_keypoints_3d.detach().cpu().numpy(), gt_keypoints_3d.detach().cpu().numpy(), reduction='mean') * 1000 186 | 187 | mpjpe =((np.sqrt(np.sum((pred_keypoints_3d.detach().cpu().numpy() - gt_keypoints_3d.detach().cpu().numpy()) ** 2, axis=-1)) * vis_joints[:,:,0].detach().cpu().numpy()).sum() / (vis_joints[:,:,0]+1e-9).detach().cpu().numpy().sum() ) * 1000 188 | 189 | if dataset == 'totalcapture': 190 | idx = [0,1,2,3,4,5,6,7,8,10,11,12,13,14,15,16] 191 | 
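# Note: for TotalCapture the PA-MPJPE below is computed on the 16-joint subset
# selected by `idx` (joint index 9 is excluded); the other datasets use all 17
# regressed joints. Both the visibility-weighted MPJPE above and the PA-MPJPE
# are reported in millimetres (the * 1000 converts back from metres).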
rec_error = reconstruction_error(pred_keypoints_3d[:, idx].detach().cpu().numpy(), gt_keypoints_3d[:, idx].detach().cpu().numpy(), reduction='mean') * 1000 192 | else: 193 | rec_error = reconstruction_error(pred_keypoints_3d.detach().cpu().numpy(), gt_keypoints_3d.detach().cpu().numpy(), reduction='mean') * 1000 194 | 195 | self.log_dict_mpii3d['mpjpe'].append(mpjpe) 196 | self.log_dict_mpii3d['pa-mpjpe'].append(rec_error) 197 | 198 | metrics = dict(mpjpe=mpjpe, rec_error=rec_error) 199 | output['metrics'] = metrics 200 | return loss 201 | 202 | def training_step_discriminator(self, batch: Dict, 203 | body_pose: torch.Tensor, 204 | betas: torch.Tensor, 205 | optimizer: torch.optim.Optimizer) -> torch.Tensor: 206 | """ 207 | Run a discriminator training step 208 | Args: 209 | batch (Dict): Dictionary containing mocap batch data 210 | body_pose (torch.Tensor): Regressed body pose from current step 211 | betas (torch.Tensor): Regressed betas from current step 212 | optimizer (torch.optim.Optimizer): Discriminator optimizer 213 | Returns: 214 | torch.Tensor: Discriminator loss 215 | """ 216 | batch_size = body_pose.shape[0] 217 | gt_body_pose = batch['body_pose'] 218 | gt_betas = batch['betas'] 219 | gt_rotmat = aa_to_rotmat(gt_body_pose.view(-1,3)).view(batch_size, -1, 3, 3) 220 | disc_fake_out = self.discriminator(body_pose.detach(), betas.detach()) 221 | loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size 222 | disc_real_out = self.discriminator(gt_rotmat, gt_betas) 223 | loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size 224 | loss_disc = loss_fake + loss_real 225 | loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc 226 | optimizer.zero_grad() 227 | loss.backward() 228 | optimizer.step() 229 | return loss_disc.detach() 230 | 231 | def forward(self, x, meta, batch_idx, mocap, meters, len_data, n_views, epoch = 0, train: bool = True, dataset = None): 232 | output = self.forward_step(x, n_views) 233 | pred_smpl_params = output['pred_smpl_params'] 234 | loss = self.compute_loss(output, meta, dataset) 235 | n_samples = mocap['body_pose'].shape[0] 236 | if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: 237 | disc_out = self.discriminator(pred_smpl_params['body_pose'].reshape(n_samples, -1), pred_smpl_params['betas'].reshape(n_samples, -1)) 238 | loss_adv = ((disc_out - 1.0) ** 2).sum() / n_samples 239 | loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv 240 | if train: 241 | self.optimizer.zero_grad() 242 | loss.backward() 243 | self.optimizer.step() 244 | if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: 245 | loss_disc = self.training_step_discriminator(mocap, pred_smpl_params['body_pose'].reshape(n_samples, -1), pred_smpl_params['betas'].reshape(n_samples, -1), self.optimizer_disc) 246 | output['losses']['loss_gen'] = loss_adv 247 | output['losses']['loss_disc'] = loss_disc 248 | self.tensorboard_logging(x, output, self.step_count, batch_idx, meters, len_data, epoch, train) 249 | if not train: 250 | return output['metrics'] 251 | return None 252 | 253 | 254 | def get_optimizer(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: 255 | 256 | all_params = list(self.backbone.parameters()) + list(self.smpl_head.parameters()) 257 | param_groups = [{'params': filter(lambda p: p.requires_grad, all_params), 'lr': self.cfg.TRAIN.LR}] 258 | 259 | optimizer = torch.optim.AdamW(params=param_groups, 260 | # lr=self.cfg.TRAIN.LR, 261 | weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) 262 | if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0: 263 | optimizer_disc = 
torch.optim.AdamW(params=self.discriminator.parameters(), 264 | lr=self.cfg.TRAIN.LR, 265 | # lr=self.cfg.TRAIN.LR * 10, 266 | weight_decay=self.cfg.TRAIN.WEIGHT_DECAY) 267 | return optimizer, optimizer_disc 268 | return optimizer 269 | 270 | def tensorboard_logging(self, input, output, step_count, batch_idx, meters, len_data, epoch, train: bool = True): 271 | 272 | mode = 'train' if train else 'val' 273 | losses = output['losses'] 274 | # n_samples = output['pred_keypoints_2d'].shape[0] 275 | for loss_name, val in losses.items(): 276 | self.writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count['{}'.format(mode)]) 277 | meters['{}_loss'.format(mode)].update(output['losses']['loss']) 278 | meters['{}_mpjpe'.format(mode)].update(output['metrics']['mpjpe']) 279 | meters['{}_rec_error'.format(mode)].update(output['metrics']['rec_error']) 280 | step_count['{}'.format(mode)] += 1 281 | if batch_idx % self.cfg.TRAIN.LOG_INTERVAL == 0: 282 | # images = torch.cat(input, dim=0) 283 | # images = rearrange(images, "(n b) c d e -> (b n) c d e", n=4) 284 | # pred_keypoints_2d = rearrange(output['pred_keypoints_2d'], "(n b) c d -> (b n) c d", n=self.n_views) 285 | # keypoints_2d_vis = vis.visualize_2d_pose(images, pred_keypoints_2d) 286 | # self.writer.add_image(mode + '/pred_2d', keypoints_2d_vis, step_count['{}_vis'.format(mode)]) 287 | if self.cfg.TRAIN.RENDER_MESH: 288 | images = images.detach() * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1) 289 | images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1) 290 | pred_vertice = rearrange(output['pred_vertices'], "(n b) c d -> (b n) c d", n=4) 291 | pred_cam_t = rearrange(output['pred_cam_t'], "(n b) c -> (b n) c", n=4) 292 | images_pred = self.mesh_renderer.visualize_tb(pred_vertice.detach(), pred_cam_t.detach(), images) 293 | self.writer.add_image(mode + '/pred_shape', images_pred, step_count['{}_vis'.format(mode)]) 294 | # mesh_vertices = rearrange(pred_vertice, "(b n) c d -> b n c d", n=4) 295 | # for b in range(mesh_vertices.shape[0]): 296 | # for n in range(mesh_vertices.shape[1]): 297 | # mesh_vertice = mesh_vertices.clone().detach().cpu().numpy()[b, n] 298 | # vertex_colors = np.ones([mesh_vertice.shape[0], 4]) * [0.82, 0.9, 0.98, 1.0] 299 | # face_colors = np.ones([self.smpl.faces.shape[0], 4]) * [0.82, 0.9, 0.98, 1.0] 300 | # mesh = trimesh.Trimesh(mesh_vertice, self.smpl.faces, face_colors=face_colors, vertex_colors=vertex_colors, process=False) 301 | # mesh.export('/home/benlee/projects/HPE/multi_view_trans/output/h36m' + '/meshes/' + 'mesh_{}_{}_{}.obj'.format(step_count['{}_vis'.format(mode)] ,b, n)) 302 | step_count['{}_vis'.format(mode)] += 1 303 | if train: 304 | msg = (f'Epoch: [{epoch}][{batch_idx}/{len_data}]\t' 305 | f'Loss: {meters["{}_loss".format(mode)].val:.5f} ({meters["{}_loss".format(mode)].avg:.5f})\t' 306 | f'MPJPE: {meters["{}_mpjpe".format(mode)].val:.3f} ({meters["{}_mpjpe".format(mode)].avg:.3f})\t' 307 | f'REC_ERROR: {meters["{}_rec_error".format(mode)].val:.3f} ({meters["{}_rec_error".format(mode)].avg:.3f})') 308 | else: 309 | msg = (f'Test: [{batch_idx}/{len_data}]\t' 310 | f'Loss: {meters["{}_loss".format(mode)].val:.5f} ({meters["{}_loss".format(mode)].avg:.5f})\t' 311 | f'MPJPE: {meters["{}_mpjpe".format(mode)].val:.3f} ({meters["{}_mpjpe".format(mode)].avg:.3f})\t' 312 | f'REC_ERROR: {meters["{}_rec_error".format(mode)].val:.3f} ({meters["{}_rec_error".format(mode)].avg:.3f})') 313 | logger.info(msg) 314 | 315 | def 
forward_smpl(self, pred_rotmat, pred_betas, pred_cam, n_sample): 316 | 317 | pred_smpl_params = {'global_orient': pred_rotmat[:, [0]], 318 | 'body_pose': pred_rotmat[:, 1:], 319 | 'betas': pred_betas} 320 | output = {} 321 | output['pred_cam'] = pred_cam 322 | output['pred_smpl_params'] = pred_smpl_params 323 | 324 | # Compute camera translation 325 | device = pred_smpl_params['body_pose'].device 326 | dtype = pred_smpl_params['body_pose'].dtype 327 | focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(n_sample, 2, device=device, dtype=dtype) 328 | pred_cam_t = torch.stack([pred_cam[:, 1], 329 | pred_cam[:, 2], 330 | 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE[0] * pred_cam[:, 0] +1e-9)],dim=-1) 331 | output['pred_cam_t'] = pred_cam_t 332 | output['focal_length'] = focal_length 333 | 334 | # Compute model vertices, joints and the projected joints 335 | pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(n_sample, -1, 3, 3) 336 | pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(n_sample, -1, 3, 3) 337 | pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(n_sample, -1) 338 | smpl_output = self.smpl(**{k: v.float() for k,v in pred_smpl_params.items()}, pose2rot=False) 339 | pred_keypoints_3d = smpl_output.joints[:, self.joints_list, :] # (8, 44, 3) -> (8, 17, 3) 340 | pred_vertices = smpl_output.vertices 341 | output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(n_sample, -1, 3) 342 | output['pred_vertices'] = pred_vertices.reshape(n_sample, -1, 3) 343 | pred_cam_t = pred_cam_t.reshape(-1, 3) 344 | focal_length = focal_length.reshape(-1, 2) 345 | pred_keypoints_2d = perspective_projection(pred_keypoints_3d, 346 | translation=pred_cam_t, 347 | focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE[0]) # range [-0.5, 0.5] 348 | output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(n_sample, -1, 2) 349 | return output 350 | 351 | if __name__ == "__main__": 352 | from ..utils.config import get_config 353 | from ..utils.log_utils import create_logger 354 | cfg_name = 'swin.yaml' 355 | cfg = get_config('../../experiments/h36m/{}'.format(cfg_name), merge= False) 356 | if cfg.IS_TRAIN: 357 | phase = 'train' 358 | else: 359 | phase = 'test' 360 | logger, final_output_dir, tensorboard_log_dir = create_logger(cfg, cfg_name, phase) 361 | model = Mv_Fusion(cfg, tensorboard_log_dir) 362 | x = [torch.randn(2, 3, 256, 256) for i in range(4)] -------------------------------------------------------------------------------- /lib/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .smpl_head import build_smpl_head 2 | -------------------------------------------------------------------------------- /lib/models/heads/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from mmdetection (https://github.com/open-mmlab/mmdetection) 5 | # Copyright (c) OpenMMLab. All rights reserved. 
6 | # ------------------------------------------------------------------------ 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | # from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING 12 | # from mmcv.runner import BaseModule 13 | 14 | # @POSITIONAL_ENCODING.register_module() 15 | class SinePositionalEncoding3D(nn.Module): 16 | """Position encoding with sine and cosine functions. 17 | See `End-to-End Object Detection with Transformers 18 | `_ for details. 19 | Args: 20 | num_feats (int): The feature dimension for each position 21 | along x-axis or y-axis. Note the final returned dimension 22 | for each position is 2 times of this value. 23 | temperature (int, optional): The temperature used for scaling 24 | the position embedding. Defaults to 10000. 25 | normalize (bool, optional): Whether to normalize the position 26 | embedding. Defaults to False. 27 | scale (float, optional): A scale factor that scales the position 28 | embedding. The scale will be used only when `normalize` is True. 29 | Defaults to 2*pi. 30 | eps (float, optional): A value added to the denominator for 31 | numerical stability. Defaults to 1e-6. 32 | offset (float): offset add to embed when do the normalization. 33 | Defaults to 0. 34 | init_cfg (dict or list[dict], optional): Initialization config dict. 35 | Default: None 36 | """ 37 | 38 | def __init__(self, 39 | context_dim, 40 | temperature=10000, 41 | normalize=False, 42 | scale=2 * math.pi, 43 | eps=1e-6, 44 | offset=0.): 45 | super(SinePositionalEncoding3D, self).__init__() 46 | if normalize: 47 | assert isinstance(scale, (float, int)), 'when normalize is set,' \ 48 | 'scale should be provided and in float or int type, ' \ 49 | f'found {type(scale)}' 50 | self.num_feats = context_dim // 2 51 | self.context_dim = context_dim 52 | self.temperature = temperature 53 | self.normalize = normalize 54 | self.scale = scale 55 | self.eps = eps 56 | self.offset = offset 57 | self.adapt_pos3d = nn.Sequential( 58 | nn.Conv2d(self.context_dim*3//2, self.context_dim*4, kernel_size=1, stride=1, padding=0), 59 | nn.ReLU(), 60 | nn.Conv2d(self.context_dim*4, self.context_dim, kernel_size=1, stride=1, padding=0), 61 | ) 62 | 63 | def forward(self, x): 64 | """Forward function for `SinePositionalEncoding`. 65 | Args: 66 | mask (Tensor): ByteTensor mask. Non-zero values representing 67 | ignored positions, while zero values means valid positions 68 | for this image. Shape [bs, N, h, w]. 69 | Returns: 70 | pos (Tensor): Returned position embedding with shape 71 | [bs, num_feats*2, h, w]. 72 | """ 73 | # For convenience of exporting to ONNX, it's required to convert 74 | # `masks` from bool to int. 
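# Note on the computation below: `not_mask` marks every token of every view as
# valid, and the three cumulative sums assign each token its (view, row, column)
# position. Each axis is encoded with num_feats = context_dim // 2 interleaved
# sine/cosine channels, the three encodings are concatenated to
# 3 * context_dim // 2 channels, and `adapt_pos3d` (1x1 convolutions) projects
# them back to context_dim, so the returned embedding matches x's shape
# (B, N, context_dim, H, W); the __main__ check at the bottom of this file runs
# it on a (2, 4, 512, 8, 8) tensor.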
75 | # mask = mask.to(torch.int) 76 | # not_mask = 1 - mask # logical_not 77 | not_mask = torch.ones(x.shape[0], x.shape[1], x.shape[3], x.shape[4], dtype=torch.bool, device=x.device) 78 | n_embed = not_mask.cumsum(1, dtype=torch.float32) 79 | y_embed = not_mask.cumsum(2, dtype=torch.float32) 80 | x_embed = not_mask.cumsum(3, dtype=torch.float32) 81 | if self.normalize: 82 | n_embed = (n_embed + self.offset) / \ 83 | (n_embed[:, -1:, :, :] + self.eps) * self.scale 84 | y_embed = (y_embed + self.offset) / \ 85 | (y_embed[:, :, -1:, :] + self.eps) * self.scale 86 | x_embed = (x_embed + self.offset) / \ 87 | (x_embed[:, :, :, -1:] + self.eps) * self.scale 88 | dim_t = torch.arange( 89 | self.num_feats, dtype=torch.float32, device=not_mask.device) 90 | dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) 91 | pos_n = n_embed[:, :, :, :, None] / dim_t 92 | pos_x = x_embed[:, :, :, :, None] / dim_t 93 | pos_y = y_embed[:, :, :, :, None] / dim_t 94 | # use `view` instead of `flatten` for dynamically exporting to ONNX 95 | B, N, H, W = not_mask.size() 96 | pos_n = torch.stack( 97 | (pos_n[:, :, :, :, 0::2].sin(), pos_n[:, :, :, :, 1::2].cos()), 98 | dim=4).view(B, N, H, W, -1) 99 | pos_x = torch.stack( 100 | (pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), 101 | dim=4).view(B, N, H, W, -1) 102 | pos_y = torch.stack( 103 | (pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), 104 | dim=4).view(B, N, H, W, -1) 105 | pos = torch.cat((pos_n, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) 106 | 107 | pos = self.adapt_pos3d(pos.flatten(0, 1)).view(x.size()) 108 | 109 | return pos 110 | 111 | def __repr__(self): 112 | """str: a string that describes the module""" 113 | repr_str = self.__class__.__name__ 114 | repr_str += f'(num_feats={self.num_feats}, ' 115 | repr_str += f'temperature={self.temperature}, ' 116 | repr_str += f'normalize={self.normalize}, ' 117 | repr_str += f'scale={self.scale}, ' 118 | repr_str += f'eps={self.eps})' 119 | return repr_str 120 | 121 | 122 | # @POSITIONAL_ENCODING.register_module() 123 | class LearnedPositionalEncoding3D(nn.Module): 124 | """Position embedding with learnable embedding weights. 125 | Args: 126 | num_feats (int): The feature dimension for each position 127 | along x-axis or y-axis. The final returned dimension for 128 | each position is 2 times of this value. 129 | row_num_embed (int, optional): The dictionary size of row embeddings. 130 | Default 50. 131 | col_num_embed (int, optional): The dictionary size of col embeddings. 132 | Default 50. 133 | init_cfg (dict or list[dict], optional): Initialization config dict. 134 | """ 135 | 136 | def __init__(self, 137 | num_feats, 138 | row_num_embed=50, 139 | col_num_embed=50, 140 | init_cfg=dict(type='Uniform', layer='Embedding')): 141 | super(LearnedPositionalEncoding3D, self).__init__(init_cfg) 142 | self.row_embed = nn.Embedding(row_num_embed, num_feats) 143 | self.col_embed = nn.Embedding(col_num_embed, num_feats) 144 | self.num_feats = num_feats 145 | self.row_num_embed = row_num_embed 146 | self.col_num_embed = col_num_embed 147 | 148 | def forward(self, mask): 149 | """Forward function for `LearnedPositionalEncoding`. 150 | Args: 151 | mask (Tensor): ByteTensor mask. Non-zero values representing 152 | ignored positions, while zero values means valid positions 153 | for this image. Shape [bs, h, w]. 154 | Returns: 155 | pos (Tensor): Returned position embedding with shape 156 | [bs, num_feats*2, h, w]. 
157 | """ 158 | h, w = mask.shape[-2:] 159 | x = torch.arange(w, device=mask.device) 160 | y = torch.arange(h, device=mask.device) 161 | x_embed = self.col_embed(x) 162 | y_embed = self.row_embed(y) 163 | pos = torch.cat(( 164 | x_embed.unsqueeze(0).repeat(h, 1, 1), 165 | y_embed.unsqueeze(1).repeat(1, w, 1) 166 | ),dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) 167 | return pos 168 | 169 | def __repr__(self): 170 | """str: a string that describes the module""" 171 | repr_str = self.__class__.__name__ 172 | repr_str += f'(num_feats={self.num_feats}, ' 173 | repr_str += f'row_num_embed={self.row_num_embed}, ' 174 | repr_str += f'col_num_embed={self.col_num_embed})' 175 | return repr_str 176 | 177 | 178 | def build_position_encoding(cfg): 179 | # num_feats = cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER.context_dim // 2 180 | if cfg.MODEL.SMPL_HEAD.POSITIONAL_ENCODING == 'SinePositionalEncoding3D': 181 | # TODO find a better way of exposing other arguments 182 | position_embedding = SinePositionalEncoding3D(context_dim=cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER.context_dim, normalize=True) 183 | # elif cfg.MODEL.SMPL_HEAD.position_embedding in ('v3', 'learned'): 184 | # position_embedding = PositionEmbeddingLearned(N_steps) 185 | # else: 186 | # raise ValueError(f"not supported {args.position_embedding}") 187 | 188 | return position_embedding 189 | 190 | if __name__ == "__main__": 191 | position_embedding = SinePositionalEncoding3D(context_dim=512) 192 | # position_embedding = PositionEmbeddingSine(num_pos_feats=256) 193 | pos_embed = position_embedding(torch.randn(2, 4, 512, 8, 8)) 194 | print(pos_embed.shape) -------------------------------------------------------------------------------- /lib/models/heads/smpl_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import einops 6 | from einops import rearrange 7 | 8 | 9 | from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat 10 | from ..components.pose_transformer import TransformerDecoder 11 | from .position_encoding import build_position_encoding 12 | 13 | 14 | def build_smpl_head(cfg): 15 | smpl_head_type = cfg.MODEL.SMPL_HEAD.get('TYPE', 'hmr') 16 | if smpl_head_type == 'transformer_decoder': 17 | return SMPLTransformerDecoderHead(cfg) 18 | if smpl_head_type == 'fcn': 19 | return SMPLFCNHead(cfg) 20 | if smpl_head_type == 'transformer_decoder_token': 21 | return SMPLTransformerDecoderTokenHead(cfg) 22 | if smpl_head_type == 'fcn_fusion': 23 | return SMPLFCNFusionHead(cfg) 24 | else: 25 | raise ValueError('Unknown SMPL head type: {}'.format(smpl_head_type)) 26 | 27 | 28 | class SMPLTransformerDecoderHead(nn.Module): 29 | """ Cross-attention based SMPL Transformer decoder 30 | """ 31 | 32 | def __init__(self, cfg): 33 | super().__init__() 34 | self.cfg = cfg 35 | # self.n_views = cfg.DATASET.N_VIEWS 36 | self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d') 37 | self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] 38 | npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS) 39 | self.npose = npose 40 | self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get( 41 | 'TRANSFORMER_INPUT', 'zero') == 'mean_shape' 42 | if cfg.MODEL.BACKBONE.TYPE == 'resnet': 43 | transformer_args = dict( 44 | num_tokens=1, 45 | token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1, 46 | dim=1024, 47 | ) 48 | else: 49 | transformer_args = dict( 50 | num_tokens=1, 51 | 
token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1, 52 | dim=1024, 53 | ) 54 | transformer_args = {**transformer_args, ** 55 | dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER)} 56 | self.transformer = TransformerDecoder( 57 | **transformer_args 58 | ) 59 | if cfg.MODEL.SMPL_HEAD.POSITIONAL_ENCODING == 'SinePositionalEncoding3D': 60 | # TODO find a better way of exposing other arguments 61 | self.position_embedding = build_position_encoding(cfg) 62 | dim = transformer_args['dim'] 63 | context_dim = transformer_args['context_dim'] 64 | # self.decpose = nn.Linear(dim, npose) 65 | # self.decglobalorientation = nn.Linear(context_dim, 6) 66 | # self.decshape = nn.Linear(dim, 10) 67 | # self.deccam = nn.Linear(context_dim, 3) 68 | 69 | self.decpose = nn.Linear(dim, npose) 70 | self.decshape = nn.Linear(dim, 10) 71 | 72 | self.fc1 = nn.Linear(context_dim, 1024) 73 | self.drop1 = nn.Dropout() 74 | self.fc2 = nn.Linear(1024, 1024) 75 | self.drop2 = nn.Dropout() 76 | self.decglobalorientation = nn.Linear(1024, 6) 77 | self.deccam = nn.Linear(1024, 3) 78 | 79 | 80 | 81 | # self.avgpool = nn.AvgPool2d(8, stride=1) 82 | if cfg.MODEL.BACKBONE.TYPE == 'resnet' or cfg.MODEL.BACKBONE.TYPE == 'swin_v2': 83 | self.avgpool = nn.AvgPool2d(8, stride=1) 84 | elif cfg.MODEL.BACKBONE.TYPE == 'vit': 85 | self.avgpool = nn.AvgPool2d((16, 12), stride=1) 86 | 87 | if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False): 88 | # True by default in MLP. False by default in Transformer 89 | nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) 90 | nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) 91 | nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) 92 | nn.init.xavier_uniform_(self.decglobalorientation.weight, gain=0.01) 93 | 94 | mean_params = np.load(cfg.SMPL.MEAN_PARAMS) 95 | init_pose = torch.from_numpy( 96 | mean_params['pose'].astype(np.float32)).unsqueeze(0) 97 | init_betas = torch.from_numpy( 98 | mean_params['shape'].astype('float32')).unsqueeze(0) 99 | init_cam = torch.from_numpy( 100 | mean_params['cam'].astype(np.float32)).unsqueeze(0) 101 | self.register_buffer('init_global_orientation', init_pose[:, :6]) 102 | self.register_buffer('init_body_pose', init_pose[:, 6:]) 103 | self.register_buffer('init_betas', init_betas) 104 | self.register_buffer('init_cam', init_cam) 105 | 106 | def forward(self, features, n_views, **kwargs): 107 | features_pooled = self.avgpool(features) 108 | features_pooled = features_pooled.view(features_pooled.size(0), -1) # (8, 2048) 109 | # features = features.view(features.size(0), -1, features.size(1)) # (8, 64, 2048) 110 | # features = rearrange(features, "(n b) c h w-> b n c h w", n=self.n_views) # (2, 256, 2048) 111 | features = rearrange(features, "(n b) c h w-> b n c h w", n=n_views) # (2, 256, 2048) 112 | 113 | if self.cfg.MODEL.SMPL_HEAD.POSITIONAL_ENCODING == 'SinePositionalEncoding3D': 114 | pos_embed = self.position_embedding(features) 115 | features = features + pos_embed 116 | features = rearrange(features, "b n c h w-> b (n h w) c") # (512, 2048, 8, 8) 117 | 118 | 119 | # batch_size = xf_b.shape[0] 120 | n_sample = features_pooled.shape[0] # 8 = 4 * 2 121 | # batch_size = n_sample // self.n_views # 2 122 | batch_size = n_sample // n_views # 2 123 | init_body_pose = self.init_body_pose.expand(batch_size, -1) # (2, 138) 124 | init_betas = self.init_betas.expand(batch_size, -1) # (2, 10) 125 | init_cam = self.init_cam.expand(n_sample, -1) # (8, 3) 126 | init_global_orientation = self.init_global_orientation.expand( 127 | n_sample, -1) # (8, 6) 128 | 
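        # Shape summary for the refinement loop below (the inline comments
        # assume n_views = 4 and batch_size = 2 as an example):
        #   - body pose and betas are predicted once per multi-view sample
        #     (batch_size rows) from the transformer-decoder output token;
        #   - global orientation and camera are predicted per view
        #     (n_sample = batch_size * n_views rows) from the pooled per-view
        #     features via the fc1/fc2 MLP.
        # Each IEF iteration adds a residual update on top of the SMPL mean
        # parameters registered in __init__.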
129 | # TODO: Convert init_body_pose to aa rep if needed 130 | if self.joint_rep_type == 'aa': 131 | raise NotImplementedError 132 | 133 | pred_body_pose = init_body_pose 134 | pred_betas = init_betas 135 | pred_cam = init_cam 136 | pred_global_orientation = init_global_orientation 137 | 138 | for i in range(self.cfg.MODEL.SMPL_HEAD.get('IEF_ITERS', 1)): 139 | # Input token to transformer is zero token 140 | if self.input_is_mean_shape: 141 | token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[ 142 | :, None, :] 143 | else: 144 | token = torch.zeros(batch_size, 1, 1).to(features.device) 145 | 146 | # Pass through transformer 147 | token_out = self.transformer(token, context=features) 148 | token_out = token_out.squeeze(1) # (B, C) 149 | 150 | # Readout from token_out 151 | pred_body_pose = self.decpose(token_out) + pred_body_pose 152 | pred_betas = self.decshape(token_out) + pred_betas 153 | 154 | xf = self.fc1(features_pooled) 155 | xf = self.drop1(xf) 156 | xf = self.fc2(xf) 157 | xf = self.drop2(xf) 158 | pred_global_orientation = self.decglobalorientation(xf) + pred_global_orientation 159 | pred_cam = self.deccam(xf) + pred_cam 160 | 161 | # pred_global_orientation = self.decglobalorientation( 162 | # features_pooled) + pred_global_orientation 163 | # pred_cam = self.deccam(features_pooled) + pred_cam 164 | 165 | 166 | # Convert self.joint_rep_type -> rotmat 167 | joint_conversion_fn = { 168 | '6d': rot6d_to_rotmat, 169 | 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()) 170 | }[self.joint_rep_type] 171 | 172 | 173 | pred_body_pose = joint_conversion_fn(pred_body_pose).view( 174 | batch_size, self.cfg.SMPL.NUM_BODY_JOINTS, 3, 3) 175 | pred_global_orientation = joint_conversion_fn(pred_global_orientation).view( 176 | n_sample, 1, 3, 3) 177 | 178 | return pred_body_pose, pred_betas, pred_global_orientation, pred_cam 179 | 180 | class SMPLFCNHead(nn.Module): 181 | def __init__(self, cfg): 182 | super().__init__() 183 | self.cfg = cfg 184 | self.n_views = cfg.DATASET.N_VIEWS 185 | # self.batch_size = cfg.TRAIN.BATCH_SIZE 186 | self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d') 187 | self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] 188 | npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS) 189 | self.npose = npose 190 | 191 | self.fc1 = nn.Linear(self.cfg.MODEL.SMPL_HEAD.IN_CHANNELS * self.n_views + npose + 10 + 6 * self.n_views + 3 * self.n_views, 1024) 192 | self.drop1 = nn.Dropout() 193 | self.fc2 = nn.Linear(1024, 1024) 194 | self.drop2 = nn.Dropout() 195 | 196 | self.decpose = nn.Linear(1024, npose) 197 | self.decshape = nn.Linear(1024, 10) 198 | self.deccam = nn.Linear(1024, 3 * self.n_views) 199 | self.decglobalorientation = nn.Linear(1024, 6 * self.n_views) 200 | if cfg.MODEL.BACKBONE.TYPE == 'resnet' or cfg.MODEL.BACKBONE.TYPE == 'swin_v2': 201 | self.avgpool = nn.AvgPool2d(8, stride=1) 202 | elif cfg.MODEL.BACKBONE.TYPE == 'vit': 203 | self.avgpool = nn.AvgPool2d((16, 12), stride=1) 204 | if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False): 205 | # True by default in MLP. 
False by default in Transformer 206 | nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) 207 | nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) 208 | nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) 209 | nn.init.xavier_uniform_(self.decglobalorientation.weight, gain=0.01) 210 | 211 | mean_params = np.load(cfg.SMPL.MEAN_PARAMS) 212 | init_pose = torch.from_numpy( 213 | mean_params['pose'].astype(np.float32)).unsqueeze(0) 214 | init_betas = torch.from_numpy( 215 | mean_params['shape'].astype('float32')).unsqueeze(0) 216 | init_cam = torch.from_numpy( 217 | mean_params['cam'].astype(np.float32)).unsqueeze(0) 218 | self.register_buffer('init_global_orientation', init_pose[:, :6]) 219 | self.register_buffer('init_body_pose', init_pose[:, 6:]) 220 | self.register_buffer('init_betas', init_betas) 221 | self.register_buffer('init_cam', init_cam) 222 | def forward(self, features, n_iter=3): 223 | features = self.avgpool(features) 224 | features = features.view(features.size(0), -1) 225 | features = rearrange(features, "(n b) c -> b (n c)", n=self.n_views) 226 | # n_sample = features.shape[0] # 8 = 4 * 2 227 | batch_size = features.size(0) # 2 228 | init_body_pose = self.init_body_pose.expand(batch_size, -1) # (2, 138) 229 | init_betas = self.init_betas.expand(batch_size, -1) # (2, 10) 230 | init_cam = self.init_cam.expand(batch_size * self.n_views, -1).contiguous( 231 | ).view(batch_size, -1) # (8, 3) 232 | init_global_orientation = self.init_global_orientation.expand( 233 | batch_size * self.n_views, -1).contiguous( 234 | ).view(batch_size, -1) # (8, 6) 235 | # TODO: Convert init_body_pose to aa rep if needed 236 | if self.joint_rep_type == 'aa': 237 | raise NotImplementedError 238 | 239 | pred_body_pose = init_body_pose 240 | pred_betas = init_betas 241 | pred_cam = init_cam 242 | pred_global_orientation = init_global_orientation 243 | for i in range(n_iter): 244 | xc = torch.cat([features, pred_body_pose, pred_betas, pred_global_orientation, pred_cam], 1) 245 | xc = self.fc1(xc) 246 | xc = self.drop1(xc) 247 | xc = self.fc2(xc) 248 | xc = self.drop2(xc) 249 | pred_body_pose = self.decpose(xc) + pred_body_pose 250 | pred_betas = self.decshape(xc) + pred_betas 251 | pred_global_orientation = self.decglobalorientation(xc) + pred_global_orientation 252 | pred_cam = self.deccam(xc) + pred_cam 253 | joint_conversion_fn = { 254 | '6d': rot6d_to_rotmat, 255 | 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()) 256 | }[self.joint_rep_type] 257 | 258 | 259 | pred_body_pose = joint_conversion_fn(pred_body_pose).view( 260 | batch_size, self.cfg.SMPL.NUM_BODY_JOINTS, 3, 3) 261 | 262 | pred_global_orientation = rearrange(pred_global_orientation, "b (n c) -> (n b) c", n=self.n_views) 263 | pred_global_orientation = joint_conversion_fn(pred_global_orientation).view( 264 | batch_size * self.n_views, 1, 3, 3) 265 | pred_cam = rearrange(pred_cam, "b (n c) -> (n b) c", n=self.n_views) 266 | return pred_body_pose, pred_betas, pred_global_orientation, pred_cam 267 | 268 | class SMPLTransformerDecoderTokenHead(nn.Module): 269 | def __init__(self, cfg): 270 | super().__init__() 271 | self.cfg = cfg 272 | self.n_views = cfg.DATASET.N_VIEWS 273 | self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d') 274 | self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] 275 | npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS) 276 | self.npose = npose 277 | self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get( 278 | 'TRANSFORMER_INPUT', 'zero') == 'mean_shape' 279 | if 
cfg.MODEL.BACKBONE.TYPE == 'resnet': 280 | transformer_args = dict( 281 | num_tokens=1 + self.n_views, 282 | token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1, 283 | dim=1024, 284 | ) 285 | else: 286 | transformer_args = dict( 287 | num_tokens=1 + self.n_views, 288 | token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1, 289 | dim=1024, 290 | ) 291 | transformer_args = {**transformer_args, ** 292 | dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER)} 293 | self.transformer = TransformerDecoder( 294 | **transformer_args 295 | ) 296 | if cfg.MODEL.SMPL_HEAD.POSITIONAL_ENCODING == 'SinePositionalEncoding3D': 297 | # TODO find a better way of exposing other arguments 298 | self.position_embedding = build_position_encoding(cfg) 299 | dim = transformer_args['dim'] 300 | context_dim = transformer_args['context_dim'] 301 | self.decpose = nn.Linear(dim, npose) 302 | self.decglobalorientation = nn.Linear(dim, 6) 303 | self.decshape = nn.Linear(dim, 10) 304 | self.deccam = nn.Linear(dim, 3) 305 | # self.avgpool = nn.AvgPool2d(8, stride=1) 306 | if cfg.MODEL.BACKBONE.TYPE == 'resnet' or cfg.MODEL.BACKBONE.TYPE == 'swin_v2': 307 | self.avgpool = nn.AvgPool2d(8, stride=1) 308 | elif cfg.MODEL.BACKBONE.TYPE == 'vit': 309 | self.avgpool = nn.AvgPool2d((16, 12), stride=1) 310 | 311 | if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False): 312 | # True by default in MLP. False by default in Transformer 313 | nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) 314 | nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) 315 | nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) 316 | nn.init.xavier_uniform_(self.decglobalorientation.weight, gain=0.01) 317 | 318 | mean_params = np.load(cfg.SMPL.MEAN_PARAMS) 319 | init_pose = torch.from_numpy( 320 | mean_params['pose'].astype(np.float32)).unsqueeze(0) 321 | init_betas = torch.from_numpy( 322 | mean_params['shape'].astype('float32')).unsqueeze(0) 323 | init_cam = torch.from_numpy( 324 | mean_params['cam'].astype(np.float32)).unsqueeze(0) 325 | self.register_buffer('init_global_orientation', init_pose[:, :6]) 326 | self.register_buffer('init_body_pose', init_pose[:, 6:]) 327 | self.register_buffer('init_betas', init_betas) 328 | self.register_buffer('init_cam', init_cam) 329 | 330 | def forward(self, features, **kwargs): 331 | # features_pooled = self.avgpool(features) 332 | # features_pooled = features_pooled.view(features_pooled.size(0), -1) # (8, 2048) 333 | # features = features.view(features.size(0), -1, features.size(1)) # (8, 64, 2048) 334 | n_sample = features.shape[0] 335 | features = rearrange(features, "(n b) c h w-> b n c h w", n=self.n_views) # (2, 256, 2048) 336 | 337 | if self.cfg.MODEL.SMPL_HEAD.POSITIONAL_ENCODING == 'SinePositionalEncoding3D': 338 | pos_embed = self.position_embedding(features) 339 | features = features + pos_embed 340 | features = rearrange(features, "b n c h w-> b (n h w) c") # (512, 2048, 8, 8) 341 | 342 | 343 | # batch_size = xf_b.shape[0] 344 | # 8 = 4 * 2 345 | batch_size = n_sample // self.n_views # 2 346 | init_body_pose = self.init_body_pose.expand(batch_size, -1) # (2, 138) 347 | init_betas = self.init_betas.expand(batch_size, -1) # (2, 10) 348 | init_cam = self.init_cam.expand(n_sample, -1) # (8, 3) 349 | init_global_orientation = self.init_global_orientation.expand( 350 | n_sample, -1) # (8, 6) 351 | 352 | # TODO: Convert init_body_pose to aa rep if needed 353 | if self.joint_rep_type == 'aa': 354 | raise NotImplementedError 355 | 356 | pred_body_pose = init_body_pose 357 | pred_betas = 
init_betas 358 | pred_cam = init_cam 359 | pred_global_orientation = init_global_orientation 360 | 361 | for i in range(self.cfg.MODEL.SMPL_HEAD.get('IEF_ITERS', 1)): 362 | # Input token to transformer is zero token 363 | if self.input_is_mean_shape: 364 | token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[ 365 | :, None, :] 366 | else: 367 | token = torch.zeros(batch_size, 1 + self.n_views, 1).to(features.device) 368 | 369 | # Pass through transformer 370 | token_out = self.transformer(token, context=features) 371 | # token_out = token_out.squeeze(1) # (B, C) 372 | 373 | # Readout from token_out 374 | pred_body_pose = self.decpose(token_out[:, 1]) + pred_body_pose 375 | pred_betas = self.decshape(token_out[:, 1]) + pred_betas 376 | token_cam = rearrange(token_out[:, 1:], "b n c -> (n b) c") 377 | pred_global_orientation = self.decglobalorientation( 378 | token_cam) + pred_global_orientation 379 | pred_cam = self.deccam(token_cam) + pred_cam 380 | 381 | 382 | # Convert self.joint_rep_type -> rotmat 383 | joint_conversion_fn = { 384 | '6d': rot6d_to_rotmat, 385 | 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()) 386 | }[self.joint_rep_type] 387 | 388 | 389 | pred_body_pose = joint_conversion_fn(pred_body_pose).view( 390 | batch_size, self.cfg.SMPL.NUM_BODY_JOINTS, 3, 3) 391 | pred_global_orientation = joint_conversion_fn(pred_global_orientation).view( 392 | n_sample, 1, 3, 3) 393 | 394 | return pred_body_pose, pred_betas, pred_global_orientation, pred_cam 395 | 396 | class SMPLFCNFusionHead(nn.Module): 397 | def __init__(self, cfg): 398 | super().__init__() 399 | self.cfg = cfg 400 | self.n_views = cfg.DATASET.N_VIEWS 401 | self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d') 402 | self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type] 403 | npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS) 404 | self.npose = npose 405 | self.in_channels = cfg.MODEL.SMPL_HEAD.IN_CHANNELS 406 | self.fc1 = nn.Linear(self.cfg.MODEL.SMPL_HEAD.IN_CHANNELS + npose + 10, 1024) 407 | self.drop1 = nn.Dropout() 408 | self.fc2 = nn.Linear(1024, 1024) 409 | self.drop2 = nn.Dropout() 410 | self.decpose = nn.Linear(1024, npose) 411 | self.decshape = nn.Linear(1024, 10) 412 | 413 | 414 | self.fc3 = nn.Linear(self.in_channels, 1024) 415 | self.drop3 = nn.Dropout() 416 | self.fc4 = nn.Linear(1024, 1024) 417 | self.drop4 = nn.Dropout() 418 | 419 | self.decglobalorientation = nn.Linear(1024, 6) 420 | self.deccam = nn.Linear(1024, 3) 421 | 422 | self.attn = nn.Linear(self.in_channels, self.in_channels) 423 | # self.drop3 = nn.Dropout() 424 | # self.avgpool = nn.AvgPool2d(8, stride=1) 425 | if cfg.MODEL.BACKBONE.TYPE == 'resnet' or cfg.MODEL.BACKBONE.TYPE == 'swin_v2': 426 | self.avgpool = nn.AvgPool2d(8, stride=1) 427 | elif cfg.MODEL.BACKBONE.TYPE == 'vit': 428 | self.avgpool = nn.AvgPool2d((16, 12), stride=1) 429 | 430 | if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False): 431 | # True by default in MLP. 
False by default in Transformer 432 | nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) 433 | nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) 434 | nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) 435 | nn.init.xavier_uniform_(self.decglobalorientation.weight, gain=0.01) 436 | 437 | mean_params = np.load(cfg.SMPL.MEAN_PARAMS) 438 | init_pose = torch.from_numpy( 439 | mean_params['pose'].astype(np.float32)).unsqueeze(0) 440 | init_betas = torch.from_numpy( 441 | mean_params['shape'].astype('float32')).unsqueeze(0) 442 | init_cam = torch.from_numpy( 443 | mean_params['cam'].astype(np.float32)).unsqueeze(0) 444 | self.register_buffer('init_global_orientation', init_pose[:, :6]) 445 | self.register_buffer('init_body_pose', init_pose[:, 6:]) 446 | self.register_buffer('init_betas', init_betas) 447 | self.register_buffer('init_cam', init_cam) 448 | 449 | def forward(self, features, n_iter=3): 450 | n_sample = features.shape[0] # 8 = 4 * 2 451 | features = self.avgpool(features) 452 | features = features.view(features.size(0), -1) 453 | features_attn = self.attn(features) 454 | features_views = rearrange(features, "(n b) c -> b n c", n=self.n_views) 455 | features_attn = rearrange(features_attn, "(n b) c -> b n c", n=self.n_views) 456 | features_attn = F.softmax(features_attn, dim=1) 457 | features_fuse = torch.sum(features_views * features_attn, dim=1) 458 | # n_sample = features.shape[0] # 8 = 4 * 2 459 | batch_size = features_fuse.size(0) # 2 460 | init_body_pose = self.init_body_pose.expand(batch_size, -1) # (2, 138) 461 | init_betas = self.init_betas.expand(batch_size, -1) # (2, 10) 462 | init_cam = self.init_cam.expand(n_sample, -1) # (8, 3) 463 | init_global_orientation = self.init_global_orientation.expand( 464 | n_sample, -1) # (8, 6) 465 | 466 | # TODO: Convert init_body_pose to aa rep if needed 467 | if self.joint_rep_type == 'aa': 468 | raise NotImplementedError 469 | 470 | pred_body_pose = init_body_pose 471 | pred_betas = init_betas 472 | pred_cam = init_cam 473 | pred_global_orientation = init_global_orientation 474 | for i in range(n_iter): 475 | xc = torch.cat([features_fuse, pred_body_pose, pred_betas], 1) 476 | xc = self.fc1(xc) 477 | xc = self.drop1(xc) 478 | xc = self.fc2(xc) 479 | xc = self.drop2(xc) 480 | pred_body_pose = self.decpose(xc) + pred_body_pose 481 | pred_betas = self.decshape(xc) + pred_betas 482 | 483 | xf = self.fc3(features) 484 | xf = self.drop3(xf) 485 | xf = self.fc4(xf) 486 | xf = self.drop4(xf) 487 | pred_global_orientation = self.decglobalorientation(xf) + pred_global_orientation 488 | pred_cam = self.deccam(xf) + pred_cam 489 | 490 | # pred_global_orientation = self.decglobalorientation(features) + pred_global_orientation 491 | # pred_cam = self.deccam(features) + pred_cam 492 | 493 | # Convert self.joint_rep_type -> rotmat 494 | joint_conversion_fn = { 495 | '6d': rot6d_to_rotmat, 496 | 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous()) 497 | }[self.joint_rep_type] 498 | 499 | 500 | pred_body_pose = joint_conversion_fn(pred_body_pose).view( 501 | batch_size, self.cfg.SMPL.NUM_BODY_JOINTS, 3, 3) 502 | pred_global_orientation = joint_conversion_fn(pred_global_orientation).view( 503 | n_sample, 1, 3, 3) 504 | 505 | return pred_body_pose, pred_betas, pred_global_orientation, pred_cam -------------------------------------------------------------------------------- /lib/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class 
Keypoint2DLoss(nn.Module): 5 | 6 | def __init__(self, loss_type: str = 'l1'): 7 | """ 8 | 2D keypoint loss module. 9 | Args: 10 | loss_type (str): Choose between l1 and l2 losses. 11 | """ 12 | super(Keypoint2DLoss, self).__init__() 13 | if loss_type == 'l1': 14 | self.loss_fn = nn.L1Loss(reduction='none') 15 | elif loss_type == 'l2': 16 | self.loss_fn = nn.MSELoss(reduction='none') 17 | else: 18 | raise NotImplementedError('Unsupported loss function') 19 | 20 | def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor: 21 | """ 22 | Compute 2D reprojection loss on the keypoints. 23 | Args: 24 | pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints) 25 | gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence. 26 | Returns: 27 | torch.Tensor: 2D keypoint loss. 28 | """ 29 | # conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() 30 | # batch_size = conf.shape[0] 31 | loss = self.loss_fn(pred_keypoints_2d, gt_keypoints_2d).sum(dim=(1,2)) 32 | # return loss.sum() 33 | return loss.mean() 34 | 35 | 36 | class Keypoint3DLoss(nn.Module): 37 | 38 | def __init__(self, loss_type: str = 'l1'): 39 | """ 40 | 3D keypoint loss module. 41 | Args: 42 | loss_type (str): Choose between l1 and l2 losses. 43 | """ 44 | super(Keypoint3DLoss, self).__init__() 45 | if loss_type == 'l1': 46 | self.loss_fn = nn.L1Loss(reduction='none') 47 | elif loss_type == 'l2': 48 | self.loss_fn = nn.MSELoss(reduction='none') 49 | else: 50 | raise NotImplementedError('Unsupported loss function') 51 | 52 | def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor): 53 | """ 54 | Compute 3D keypoint loss. 55 | Args: 56 | pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints) 57 | gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence. 58 | Returns: 59 | torch.Tensor: 3D keypoint loss. 60 | """ 61 | batch_size = pred_keypoints_3d.shape[0] 62 | # gt_keypoints_3d = gt_keypoints_3d.clone() 63 | # pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1) 64 | # gt_keypoints_3d = gt_keypoints_3d- gt_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1) 65 | # conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() 66 | # gt_keypoints_3d = gt_keypoints_3d[:, :, :-1] 67 | loss = self.loss_fn(pred_keypoints_3d, gt_keypoints_3d).sum(dim=(1,2)) 68 | # return loss.sum() 69 | return loss.mean() 70 | 71 | class ParameterLoss(nn.Module): 72 | 73 | def __init__(self): 74 | """ 75 | SMPL parameter loss module. 76 | """ 77 | super(ParameterLoss, self).__init__() 78 | self.loss_fn = nn.MSELoss(reduction='none') 79 | 80 | def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor): 81 | """ 82 | Compute SMPL parameter loss. 83 | Args: 84 | pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas) 85 | gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth SMPL parameters. 86 | Returns: 87 | torch.Tensor: L2 parameter loss loss. 
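        Note:
            has_param (torch.Tensor): Tensor of shape [B] (broadcast over the
            remaining dimensions) flagging which samples carry ground-truth
            parameters; samples with a zero flag contribute nothing to the loss.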
88 | """ 89 | batch_size = pred_param.shape[0] 90 | num_dims = len(pred_param.shape) 91 | mask_dimension = [batch_size] + [1] * (num_dims-1) 92 | has_param = has_param.type(pred_param.type()).view(*mask_dimension) 93 | # loss_param = self.loss_fn(pred_param, gt_param) 94 | loss_param = (has_param * self.loss_fn(pred_param, gt_param)) 95 | # return loss_param.sum() 96 | return loss_param.mean() -------------------------------------------------------------------------------- /lib/models/smpl_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | from typing import Optional 5 | import smplx 6 | from smplx.lbs import vertices2joints 7 | from smplx.utils import SMPLOutput 8 | 9 | 10 | class SMPL(smplx.SMPLLayer): 11 | def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs): 12 | """ 13 | Extension of the official SMPL implementation to support more joints. 14 | Args: 15 | Same as SMPLLayer. 16 | joint_regressor_extra (str): Path to extra joint regressor. 17 | """ 18 | super(SMPL, self).__init__(*args, **kwargs) 19 | smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 20 | 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34] 21 | 22 | if joint_regressor_extra is not None: 23 | self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32)) 24 | self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long)) 25 | self.update_hips = update_hips 26 | 27 | def forward(self, *args, **kwargs) -> SMPLOutput: 28 | """ 29 | Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified. 
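        Joints are reordered to the 25-joint OpenPose-style layout via `joint_map`;
        if `update_hips` is set, the hip joints are heuristically adjusted using the
        mid-hip joint before any extra joints are appended.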
30 | """ 31 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 32 | joints = smpl_output.joints[:, self.joint_map, :] 33 | if self.update_hips: 34 | joints[:,[9,12]] = joints[:,[9,12]] + \ 35 | 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \ 36 | 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]])) 37 | if hasattr(self, 'joint_regressor_extra'): 38 | extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices) 39 | joints = torch.cat([joints, extra_joints], dim=1) 40 | smpl_output.joints = joints 41 | return smpl_output 42 | 43 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaobenLi00/U-HMR/54f2b6516018a79590ea8e5d160fe1f3bd9b7c8d/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | from yacs.config import CfgNode as CN 4 | 5 | def to_lower(x: Dict) -> Dict: 6 | """ 7 | Convert all dictionary keys to lowercase 8 | Args: 9 | x (dict): Input dictionary 10 | Returns: 11 | dict: Output dictionary with all keys converted to lowercase 12 | """ 13 | return {k.lower(): v for k, v in x.items()} 14 | 15 | 16 | _C = CN(new_allowed=True) 17 | 18 | _C.GENERAL = CN(new_allowed=True) 19 | _C.GENERAL.RESUME = True 20 | _C.GENERAL.TIME_TO_RUN = 3300 21 | _C.GENERAL.VAL_STEPS = 100 22 | _C.GENERAL.LOG_STEPS = 100 23 | _C.GENERAL.CHECKPOINT_STEPS = 20000 24 | _C.GENERAL.CHECKPOINT_DIR = "checkpoints" 25 | _C.GENERAL.SUMMARY_DIR = "tensorboard" 26 | _C.GENERAL.NUM_GPUS = 1 27 | _C.GENERAL.NUM_WORKERS = 4 28 | _C.GENERAL.MIXED_PRECISION = True 29 | _C.GENERAL.ALLOW_CUDA = True 30 | _C.GENERAL.PIN_MEMORY = False 31 | _C.GENERAL.DISTRIBUTED = False 32 | _C.GENERAL.LOCAL_RANK = 0 33 | _C.GENERAL.USE_SYNCBN = False 34 | _C.GENERAL.WORLD_SIZE = 1 35 | 36 | _C.TRAIN = CN(new_allowed=True) 37 | _C.TRAIN.NUM_EPOCHS = 100 38 | _C.TRAIN.BATCH_SIZE = 32 39 | _C.TRAIN.SHUFFLE = True 40 | _C.TRAIN.WARMUP = False 41 | _C.TRAIN.NORMALIZE_PER_IMAGE = False 42 | _C.TRAIN.CLIP_GRAD = False 43 | _C.TRAIN.CLIP_GRAD_VALUE = 1.0 44 | _C.LOSS_WEIGHTS = CN(new_allowed=True) 45 | 46 | _C.DATASETS = CN(new_allowed=True) 47 | 48 | _C.MODEL = CN(new_allowed=True) 49 | _C.MODEL.IMAGE_SIZE = [256, 256] 50 | 51 | _C.EXTRA = CN(new_allowed=True) 52 | _C.EXTRA.FOCAL_LENGTH = 5000 53 | 54 | _C.DATASETS.CONFIG = CN(new_allowed=True) 55 | _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3 56 | _C.DATASETS.CONFIG.ROT_FACTOR = 30 57 | _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02 58 | _C.DATASETS.CONFIG.COLOR_SCALE = 0.2 59 | _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6 60 | _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5 61 | _C.DATASETS.CONFIG.DO_FLIP = True 62 | _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5 63 | _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10 64 | 65 | 66 | def default_config() -> CN: 67 | """ 68 | Get a yacs CfgNode object with the default config values. 69 | """ 70 | # Return a clone so that the defaults will not be altered 71 | # This is for the "local variable" use pattern 72 | return _C.clone() 73 | 74 | 75 | 76 | def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN: 77 | """ 78 | Read a config file and optionally merge it with the default config file. 79 | Args: 80 | config_file (str): Path to config file. 
81 | merge (bool): Whether to merge with the default config or not. 82 | Returns: 83 | CfgNode: Config as a yacs CfgNode object. 84 | """ 85 | if merge: 86 | cfg = default_config() 87 | else: 88 | cfg = CN(new_allowed=True) 89 | cfg.merge_from_file(config_file) 90 | 91 | cfg.freeze() 92 | return cfg 93 | -------------------------------------------------------------------------------- /lib/utils/geometry.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | """ 7 | Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula 8 | Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR 9 | """ 10 | 11 | 12 | def aa_to_rotmat(theta): 13 | """Convert axis-angle representation to rotation matrix. 14 | Args: 15 | theta: size = [B, 3] 16 | Returns: 17 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 18 | """ 19 | n = theta.size(0) 20 | if theta.size(-1) != 3: 21 | theta = theta.view(-1, 3).contiguous() 22 | l1norm = torch.norm(theta + 1e-8, p=2, dim=1) 23 | angle = torch.unsqueeze(l1norm, -1) 24 | normalized = torch.div(theta, angle) 25 | angle = angle * 0.5 26 | v_cos = torch.cos(angle) 27 | v_sin = torch.sin(angle) 28 | quat = torch.cat([v_cos, v_sin * normalized], dim=1) 29 | return quat_to_rotmat(quat).view(n, -1, 3, 3) 30 | 31 | 32 | def quat_to_rotmat(quat): 33 | """Convert quaternion coefficients to rotation matrix. 34 | Args: 35 | quat: size = [B, 4] 4 <===>(w, x, y, z) 36 | Returns: 37 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 38 | """ 39 | norm_quat = quat 40 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) 41 | w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] 42 | 43 | B = quat.size(0) 44 | 45 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 46 | wx, wy, wz = w * x, w * y, w * z 47 | xy, xz, yz = x * y, x * z, y * z 48 | 49 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 50 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 51 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) 52 | return rotMat 53 | 54 | 55 | def rot6d_to_rotmat(x): 56 | """Convert 6D rotation representation to 3x3 rotation matrix. 57 | Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 58 | Input: 59 | (B,6) Batch of 6-D rotation representations 60 | Output: 61 | (B,3,3) Batch of corresponding rotation matrices 62 | """ 63 | x = x.view(-1, 3, 2) 64 | a1 = x[:, :, 0] 65 | a2 = x[:, :, 1] 66 | b1 = F.normalize(a1) 67 | b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) 68 | b3 = torch.cross(b1, b2) 69 | return torch.stack((b1, b2, b3), dim=-1) 70 | 71 | 72 | def perspective_projection(points: torch.Tensor, 73 | translation: torch.Tensor, 74 | focal_length: torch.Tensor, 75 | camera_center: Optional[torch.Tensor] = None, 76 | rotation: Optional[torch.Tensor] = None) -> torch.Tensor: 77 | """ 78 | Computes the perspective projection of a set of 3D points. 79 | Args: 80 | points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points. 81 | translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation. 82 | focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels. 
83 | camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels. 84 | rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation. 85 | Returns: 86 | torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points. 87 | """ 88 | batch_size = points.shape[0] 89 | if rotation is None: 90 | rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1) 91 | if camera_center is None: 92 | camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype) 93 | # Populate intrinsic camera matrix K. 94 | K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype) 95 | K[:,0,0] = focal_length[:,0] 96 | K[:,1,1] = focal_length[:,1] 97 | K[:,2,2] = 1. 98 | K[:,:-1, -1] = camera_center 99 | 100 | # Transform points 101 | points = torch.einsum('bij,bkj->bki', rotation, points) 102 | points = points + translation.unsqueeze(1) 103 | 104 | # Apply perspective distortion 105 | projected_points = points / points[:,:,-1].unsqueeze(-1) 106 | 107 | # Apply camera intrinsics 108 | projected_points = torch.einsum('bij,bkj->bki', K, projected_points) 109 | 110 | return projected_points[:, :, :-1] -------------------------------------------------------------------------------- /lib/utils/img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from PIL import Image 4 | 5 | import torch 6 | 7 | IMAGENET_MEAN, IMAGENET_STD = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225]) 8 | 9 | 10 | def crop_image(image, bbox): 11 | """Crops area from image specified as bbox. Always returns area of size as bbox filling missing parts with zeros 12 | Args: 13 | image numpy array of shape (height, width, 3): input image 14 | bbox tuple of size 4: input bbox (left, upper, right, lower) 15 | 16 | Returns: 17 | cropped_image numpy array of shape (height, width, 3): resulting cropped image 18 | 19 | """ 20 | 21 | image_pil = Image.fromarray(image) 22 | image_pil = image_pil.crop(bbox) 23 | 24 | return np.asarray(image_pil) 25 | 26 | 27 | def resize_image(image, shape): 28 | return cv2.resize(image, (shape[1], shape[0]), interpolation=cv2.INTER_AREA) 29 | 30 | 31 | def get_square_bbox(bbox): 32 | """Makes square bbox from any bbox by stretching of minimal length side 33 | 34 | Args: 35 | bbox tuple of size 4: input bbox (left, upper, right, lower) 36 | 37 | Returns: 38 | bbox: tuple of size 4: resulting square bbox (left, upper, right, lower) 39 | """ 40 | 41 | left, upper, right, lower = bbox 42 | width, height = right - left, lower - upper 43 | 44 | if width > height: 45 | y_center = (upper + lower) // 2 46 | upper = y_center - width // 2 47 | lower = upper + width 48 | else: 49 | x_center = (left + right) // 2 50 | left = x_center - height // 2 51 | right = left + height 52 | 53 | return left, upper, right, lower 54 | 55 | 56 | def scale_bbox(bbox, scale): 57 | left, upper, right, lower = bbox 58 | width, height = right - left, lower - upper 59 | 60 | x_center, y_center = (right + left) // 2, (lower + upper) // 2 61 | new_width, new_height = int(scale * width), int(scale * height) 62 | 63 | new_left = x_center - new_width // 2 64 | new_right = new_left + new_width 65 | 66 | new_upper = y_center - new_height // 2 67 | new_lower = new_upper + new_height 68 | 69 | return new_left, new_upper, new_right, new_lower 70 | 71 | 72 | def to_numpy(tensor): 73 | if torch.is_tensor(tensor): 74 | return 
tensor.cpu().detach().numpy() 75 | elif type(tensor).__module__ != 'numpy': 76 | raise ValueError("Cannot convert {} to numpy array" 77 | .format(type(tensor))) 78 | return tensor 79 | 80 | 81 | def to_torch(ndarray): 82 | if type(ndarray).__module__ == 'numpy': 83 | return torch.from_numpy(ndarray) 84 | elif not torch.is_tensor(ndarray): 85 | raise ValueError("Cannot convert {} to torch tensor" 86 | .format(type(ndarray))) 87 | return ndarray 88 | 89 | 90 | def image_batch_to_numpy(image_batch): 91 | image_batch = to_numpy(image_batch) 92 | image_batch = np.transpose(image_batch, (0, 2, 3, 1)) # BxCxHxW -> BxHxWxC 93 | return image_batch 94 | 95 | 96 | def image_batch_to_torch(image_batch): 97 | image_batch = np.transpose(image_batch, (0, 3, 1, 2)) # BxHxWxC -> BxCxHxW 98 | image_batch = to_torch(image_batch).float() 99 | return image_batch 100 | 101 | 102 | def normalize_image(image): 103 | """Normalizes image using ImageNet mean and std 104 | 105 | Args: 106 | image numpy array of shape (h, w, 3): image 107 | 108 | Returns normalized_image numpy array of shape (h, w, 3): normalized image 109 | """ 110 | return (image / 255.0 - IMAGENET_MEAN) / IMAGENET_STD 111 | 112 | 113 | def denormalize_image(image): 114 | """Reverse to normalize_image() function""" 115 | return np.clip(255.0 * (image * IMAGENET_STD + IMAGENET_MEAN), 0, 255) 116 | -------------------------------------------------------------------------------- /lib/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import logging 13 | import time 14 | from pathlib import Path 15 | 16 | import torch 17 | import torch.optim as optim 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | 22 | def create_logger(cfg, cfg_name, phase='train'): 23 | root_output_dir = Path(cfg.OUTPUT_DIR) 24 | # set up logger 25 | if not root_output_dir.exists(): 26 | logger.info('=> creating {}'.format(root_output_dir)) 27 | root_output_dir.mkdir() 28 | 29 | data_source = cfg.DATASET.SOURCE 30 | # model, _ = get_model_name(cfg) 31 | cfg_name = os.path.basename(cfg_name).split('.')[0] 32 | 33 | time_str = time.strftime('%Y-%m-%d-%H-%M') 34 | 35 | final_output_dir = root_output_dir / data_source / cfg_name / time_str 36 | 37 | print('=> creating {}'.format(final_output_dir)) 38 | final_output_dir.mkdir(parents=True, exist_ok=True) 39 | 40 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 41 | final_log_file = final_output_dir / log_file 42 | head = '%(asctime)-15s %(message)s' 43 | logging.basicConfig(filename=str(final_log_file), 44 | format=head) 45 | logger = logging.getLogger() 46 | logger.setLevel(logging.INFO) 47 | console = logging.StreamHandler() 48 | logging.getLogger('').addHandler(console) 49 | 50 | tensorboard_log_dir = Path(cfg.LOG_DIR) / data_source / cfg_name / time_str 51 | print('=> creating {}'.format(tensorboard_log_dir)) 52 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True) 53 | 54 | return logger, str(final_output_dir), str(tensorboard_log_dir) 55 | 56 | 57 | 58 | def load_checkpoint(model, output_dir, 
filename='checkpoint.pth.tar'): 59 | file = os.path.join(output_dir, filename) 60 | if os.path.isfile(file): 61 | checkpoint = torch.load(file) 62 | start_epoch = checkpoint['epoch'] 63 | model.module.load_state_dict(checkpoint['state_dict']) 64 | model.module.optimizer.load_state_dict(checkpoint['optimizer']) 65 | logger.info('=> load checkpoint {} (epoch {})' 66 | .format(file, start_epoch)) 67 | 68 | return start_epoch, model 69 | 70 | else: 71 | logger.info('=> no checkpoint found at {}'.format(file)) 72 | return 0, model 73 | 74 | 75 | def save_checkpoint(states, is_best, output_dir, 76 | filename='checkpoint.pth.tar', bestname='model_best.pth.tar'): 77 | torch.save(states, os.path.join(output_dir, filename)) 78 | if is_best: 79 | torch.save(states, os.path.join(output_dir, bestname)) 80 | -------------------------------------------------------------------------------- /lib/utils/multiview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Camera: 6 | def __init__(self, R, t, K, dist=None, name=""): 7 | self.R = np.array(R).copy() 8 | assert self.R.shape == (3, 3) 9 | 10 | self.t = np.array(t).copy() 11 | assert self.t.size == 3 12 | self.t = self.t.reshape(3, 1) 13 | 14 | self.K = np.array(K).copy() 15 | assert self.K.shape == (3, 3) 16 | 17 | self.dist = dist 18 | if self.dist is not None: 19 | self.dist = np.array(self.dist).copy().flatten() 20 | 21 | self.name = name 22 | 23 | def update_after_crop(self, bbox): 24 | left, upper, right, lower = bbox 25 | 26 | cx, cy = self.K[0, 2], self.K[1, 2] 27 | 28 | new_cx = cx - left 29 | new_cy = cy - upper 30 | 31 | self.K[0, 2], self.K[1, 2] = new_cx, new_cy 32 | 33 | def update_after_resize(self, image_shape, new_image_shape): 34 | height, width = image_shape 35 | new_height, new_width = new_image_shape 36 | 37 | fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] 38 | 39 | new_fx = fx * (new_width / width) 40 | new_fy = fy * (new_height / height) 41 | new_cx = cx * (new_width / width) 42 | new_cy = cy * (new_height / height) 43 | 44 | self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] = new_fx, new_fy, new_cx, new_cy 45 | 46 | @property 47 | def projection(self): 48 | return self.K.dot(self.extrinsics) 49 | 50 | @property 51 | def extrinsics(self): 52 | return np.hstack([self.R, self.t]) 53 | 54 | 55 | def euclidean_to_homogeneous(points): 56 | """Converts euclidean points to homogeneous 57 | 58 | Args: 59 | points numpy array or torch tensor of shape (N, M): N euclidean points of dimension M 60 | 61 | Returns: 62 | numpy array or torch tensor of shape (N, M + 1): homogeneous points 63 | """ 64 | if isinstance(points, np.ndarray): 65 | return np.hstack([points, np.ones((len(points), 1))]) 66 | elif torch.is_tensor(points): 67 | return torch.cat([points, torch.ones((points.shape[0], 1), dtype=points.dtype, device=points.device)], dim=1) 68 | else: 69 | raise TypeError("Works only with numpy arrays and PyTorch tensors.") 70 | 71 | 72 | def homogeneous_to_euclidean(points): 73 | """Converts homogeneous points to euclidean 74 | 75 | Args: 76 | points numpy array or torch tensor of shape (N, M + 1): N homogeneous points of dimension M 77 | 78 | Returns: 79 | numpy array or torch tensor of shape (N, M): euclidean points 80 | """ 81 | if isinstance(points, np.ndarray): 82 | return (points.T[:-1] / points.T[-1]).T 83 | elif torch.is_tensor(points): 84 | return (points.transpose(1, 0)[:-1] / points.transpose(1, 0)[-1]).transpose(1, 
0) 85 | else: 86 | raise TypeError("Works only with numpy arrays and PyTorch tensors.") 87 | 88 | 89 | def project_3d_points_to_image_plane_without_distortion(proj_matrix, points_3d, convert_back_to_euclidean=True): 90 | """Project 3D points to image plane not taking into account distortion 91 | Args: 92 | proj_matrix numpy array or torch tensor of shape (3, 4): projection matrix 93 | points_3d numpy array or torch tensor of shape (N, 3): 3D points 94 | convert_back_to_euclidean bool: if True, then resulting points will be converted to euclidean coordinates 95 | NOTE: division by zero can be here if z = 0 96 | Returns: 97 | numpy array or torch tensor of shape (N, 2): 3D points projected to image plane 98 | """ 99 | if isinstance(proj_matrix, np.ndarray) and isinstance(points_3d, np.ndarray): 100 | result = euclidean_to_homogeneous(points_3d) @ proj_matrix.T 101 | if convert_back_to_euclidean: 102 | result = homogeneous_to_euclidean(result) 103 | return result 104 | elif torch.is_tensor(proj_matrix) and torch.is_tensor(points_3d): 105 | result = euclidean_to_homogeneous(points_3d) @ proj_matrix.t() 106 | if convert_back_to_euclidean: 107 | result = homogeneous_to_euclidean(result) 108 | return result 109 | else: 110 | raise TypeError("Works only with numpy arrays and PyTorch tensors.") 111 | 112 | 113 | def triangulate_point_from_multiple_views_linear(proj_matricies, points): 114 | """Triangulates one point from multiple (N) views using direct linear transformation (DLT). 115 | For more information look at "Multiple view geometry in computer vision", 116 | Richard Hartley and Andrew Zisserman, 12.2 (p. 312). 117 | 118 | Args: 119 | proj_matricies numpy array of shape (N, 3, 4): sequence of projection matricies (3x4) 120 | points numpy array of shape (N, 2): sequence of points' coordinates 121 | 122 | Returns: 123 | point_3d numpy array of shape (3,): triangulated point 124 | """ 125 | assert len(proj_matricies) == len(points) 126 | 127 | n_views = len(proj_matricies) 128 | A = np.zeros((2 * n_views, 4)) 129 | for j in range(len(proj_matricies)): 130 | A[j * 2 + 0] = points[j][0] * proj_matricies[j][2, :] - proj_matricies[j][0, :] 131 | A[j * 2 + 1] = points[j][1] * proj_matricies[j][2, :] - proj_matricies[j][1, :] 132 | 133 | u, s, vh = np.linalg.svd(A, full_matrices=False) 134 | point_3d_homo = vh[3, :] 135 | 136 | point_3d = homogeneous_to_euclidean(point_3d_homo) 137 | 138 | return point_3d 139 | 140 | 141 | def triangulate_point_from_multiple_views_linear_torch(proj_matricies, points, confidences=None): 142 | """Similar as triangulate_point_from_multiple_views_linear() but for PyTorch. 143 | For more information see its documentation. 144 | Args: 145 | proj_matricies torch tensor of shape (N, 3, 4): sequence of projection matricies (3x4) 146 | points torch tensor of of shape (N, 2): sequence of points' coordinates 147 | confidences None or torch tensor of shape (N,): confidences of points [0.0, 1.0]. 
148 | If None, all confidences are supposed to be 1.0 149 | Returns: 150 | point_3d numpy torch tensor of shape (3,): triangulated point 151 | """ 152 | assert len(proj_matricies) == len(points) 153 | 154 | n_views = len(proj_matricies) 155 | 156 | if confidences is None: 157 | confidences = torch.ones(n_views, dtype=torch.float32, device=points.device) 158 | 159 | A = proj_matricies[:, 2:3].expand(n_views, 2, 4) * points.view(n_views, 2, 1) 160 | A -= proj_matricies[:, :2] 161 | A *= confidences.view(-1, 1, 1) 162 | 163 | u, s, vh = torch.svd(A.view(-1, 4)) 164 | 165 | point_3d_homo = -vh[:, 3] 166 | point_3d = homogeneous_to_euclidean(point_3d_homo.unsqueeze(0))[0] 167 | 168 | return point_3d 169 | 170 | 171 | def triangulate_batch_of_points(proj_matricies_batch, points_batch, confidences_batch=None): 172 | batch_size, n_views, n_joints = points_batch.shape[:3] 173 | point_3d_batch = torch.zeros(batch_size, n_joints, 3, dtype=torch.float32, device=points_batch.device) 174 | 175 | for batch_i in range(batch_size): 176 | for joint_i in range(n_joints): 177 | points = points_batch[batch_i, :, joint_i, :] 178 | 179 | confidences = confidences_batch[batch_i, :, joint_i] if confidences_batch is not None else None 180 | point_3d = triangulate_point_from_multiple_views_linear_torch(proj_matricies_batch[batch_i], points, confidences=confidences) 181 | point_3d_batch[batch_i, joint_i] = point_3d 182 | 183 | return point_3d_batch 184 | 185 | 186 | def calc_reprojection_error_matrix(keypoints_3d, keypoints_2d_list, proj_matricies): 187 | reprojection_error_matrix = [] 188 | for keypoints_2d, proj_matrix in zip(keypoints_2d_list, proj_matricies): 189 | keypoints_2d_projected = project_3d_points_to_image_plane_without_distortion(proj_matrix, keypoints_3d) 190 | reprojection_error = 1 / 2 * np.sqrt(np.sum((keypoints_2d - keypoints_2d_projected) ** 2, axis=1)) 191 | reprojection_error_matrix.append(reprojection_error) 192 | 193 | return np.vstack(reprojection_error_matrix).T 194 | -------------------------------------------------------------------------------- /lib/utils/pose_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parts of the code are adapted from https://github.com/akanazawa/hmr 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | import numpy as np 8 | import torch 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | 14 | def compute_similarity_transform(S1, S2): 15 | """ 16 | Computes a similarity transform (sR, t) that takes 17 | a set of 3D points S1 (3 x N) closest to a set of 3D points S2, 18 | where R is an 3x3 rotation matrix, t 3x1 translation, s scale. 19 | i.e. solves the orthogonal Procrutes problem. 20 | """ 21 | transposed = False 22 | if S1.shape[0] != 3 and S1.shape[0] != 2: 23 | S1 = S1.T 24 | S2 = S2.T 25 | transposed = True 26 | assert (S2.shape[1] == S1.shape[1]) 27 | 28 | # 1. Remove mean. 29 | mu1 = S1.mean(axis=1, keepdims=True) 30 | mu2 = S2.mean(axis=1, keepdims=True) 31 | X1 = S1 - mu1 32 | X2 = S2 - mu2 33 | 34 | # 2. Compute variance of X1 used for scale. 35 | var1 = np.sum(X1**2) 36 | 37 | # 3. The outer product of X1 and X2. 38 | K = X1.dot(X2.T) 39 | 40 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are 41 | # singular vectors of K. 42 | U, s, Vh = np.linalg.svd(K) 43 | V = Vh.T 44 | # Construct Z that fixes the orientation of R to get det(R)=1. 
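    # (Standard orthogonal Procrustes / Kabsch solution: when det(U @ V.T) < 0
    # the best orthogonal map is a reflection, so the last singular direction
    # is flipped to keep R a proper rotation with det(R) = +1.)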
45 | Z = np.eye(U.shape[0]) 46 | Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) 47 | # Construct R. 48 | R = V.dot(Z.dot(U.T)) 49 | 50 | # 5. Recover scale. 51 | scale = np.trace(R.dot(K)) / var1 52 | 53 | # 6. Recover translation. 54 | t = mu2 - scale*(R.dot(mu1)) 55 | 56 | # 7. Error: 57 | S1_hat = scale*R.dot(S1) + t 58 | 59 | if transposed: 60 | S1_hat = S1_hat.T 61 | 62 | return S1_hat 63 | 64 | 65 | def compute_similarity_transform_batch(S1, S2): 66 | """Batched version of compute_similarity_transform.""" 67 | S1_hat = np.zeros_like(S1) 68 | for i in range(S1.shape[0]): 69 | S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) 70 | return S1_hat 71 | 72 | 73 | def reconstruction_error(S1, S2, reduction='mean', vis=None): 74 | """Do Procrustes alignment and compute reconstruction error.""" 75 | S1_hat = compute_similarity_transform_batch(S1, S2) 76 | re = np.sqrt(((S1_hat - S2) ** 2).sum(axis=-1)).mean(axis=-1) 77 | # re = np.sqrt(((S1_hat - S2) ** 2).sum(axis=-1)) 78 | # print(re) 79 | # logger.info(f'reconstruction_error: {re}') 80 | # if vis is not None: 81 | # # pass 82 | # re = (re*vis[:,:,0]).sum(axis=-1) / (vis[:,:,0].sum(axis=-1) + 1e-9) 83 | # # logger.info(f'reconstruction_error: {re}') 84 | # # print(re) 85 | # else: 86 | # re = re.mean(axis=-1) 87 | if reduction == 'mean': 88 | re = re.mean() 89 | elif reduction == 'sum': 90 | re = re.sum() 91 | return re 92 | -------------------------------------------------------------------------------- /lib/utils/renderer.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import pyrender 3 | import numpy as np 4 | from torchvision.utils import make_grid 5 | import torch 6 | import os 7 | # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 8 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 9 | os.environ['EGL_DEVICE_ID'] = os.environ['GPU_DEVICE_ORDINAL'].split(',')[0] \ 10 | if 'GPU_DEVICE_ORDINAL' in os.environ.keys() else '0' 11 | 12 | def create_raymond_lights(): 13 | import pyrender 14 | thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) 15 | phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) 16 | 17 | nodes = [] 18 | 19 | for phi, theta in zip(phis, thetas): 20 | xp = np.sin(theta) * np.cos(phi) 21 | yp = np.sin(theta) * np.sin(phi) 22 | zp = np.cos(theta) 23 | 24 | z = np.array([xp, yp, zp]) 25 | z = z / np.linalg.norm(z) 26 | x = np.array([-z[1], z[0], 0.0]) 27 | if np.linalg.norm(x) == 0: 28 | x = np.array([1.0, 0.0, 0.0]) 29 | x = x / np.linalg.norm(x) 30 | y = np.cross(z, x) 31 | 32 | matrix = np.eye(4) 33 | matrix[:3,:3] = np.c_[x,y,z] 34 | nodes.append(pyrender.Node( 35 | light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), 36 | matrix=matrix 37 | )) 38 | 39 | return nodes 40 | 41 | class Renderer: 42 | """ 43 | Renderer used for visualizing the SMPL model 44 | Code adapted from https://github.com/vchoutas/smplify-x 45 | """ 46 | 47 | def __init__(self, focal_length=5000, img_res=224, faces=None): 48 | self.img_res = img_res 49 | self.focal_length = focal_length 50 | self.camera_center = [img_res // 2, img_res // 2] 51 | self.faces = faces 52 | 53 | def visualize_tb(self, vertices, camera_translation, images): 54 | vertices = vertices.cpu().numpy() 55 | 56 | camera_translation = camera_translation.cpu().numpy() 57 | images = images.cpu() 58 | images_np = np.transpose(images.numpy(), (0, 2, 3, 1)) 59 | rend_imgs = [] 60 | for i in range(vertices.shape[0]): 61 | rend_img = torch.from_numpy(np.transpose(self.__call__( 62 | vertices[i], 
camera_translation[i], images_np[i]), (2, 0, 1))).float() 63 | rend_img_side = torch.from_numpy(np.transpose(self.__call__( 64 | vertices[i], camera_translation[i], images_np[i], side_view = True), (2, 0, 1))).float() 65 | rend_img_top = torch.from_numpy(np.transpose(self.__call__( 66 | vertices[i], camera_translation[i], images_np[i], top_view = True), (2, 0, 1))).float() 67 | rend_imgs.append(images[i]) 68 | rend_imgs.append(rend_img) 69 | rend_imgs.append(rend_img_side) 70 | rend_imgs.append(rend_img_top) 71 | rend_imgs = make_grid(rend_imgs, nrow=4) 72 | return rend_imgs 73 | 74 | def __call__(self, vertices, camera_translation, image, side_view=False, top_view = False, rot_angle=90): 75 | renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,viewport_height=self.img_res,point_size=1.0) 76 | material = pyrender.MetallicRoughnessMaterial( 77 | metallicFactor=0.2, 78 | alphaMode='OPAQUE', 79 | baseColorFactor=(0.865, 0.915, 0.98, 1.0)) 80 | 81 | camera_translation[0] *= -1. 82 | 83 | mesh = trimesh.Trimesh(vertices, self.faces) 84 | if side_view: 85 | rot = trimesh.transformations.rotation_matrix( 86 | np.radians(rot_angle), [0, 1, 0]) 87 | mesh.apply_transform(rot) 88 | if top_view: 89 | rot = trimesh.transformations.rotation_matrix( 90 | np.radians(rot_angle), [1, 0, 0]) 91 | mesh.apply_transform(rot) 92 | rot = trimesh.transformations.rotation_matrix( 93 | np.radians(180), [1, 0, 0]) 94 | mesh.apply_transform(rot) 95 | mesh = pyrender.Mesh.from_trimesh(mesh, material=material) 96 | 97 | scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3)) 98 | scene.add(mesh, 'mesh') 99 | 100 | camera_pose = np.eye(4) 101 | camera_pose[:3, 3] = camera_translation 102 | camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length, 103 | cx=self.camera_center[0], cy=self.camera_center[1]) 104 | scene.add(camera, pose=camera_pose) 105 | 106 | # light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1) 107 | # light_pose = np.eye(4) 108 | 109 | # light_pose[:3, 3] = np.array([0, -1, 1]) 110 | # scene.add(light, pose=light_pose) 111 | 112 | # light_pose[:3, 3] = np.array([0, 1, 1]) 113 | # scene.add(light, pose=light_pose) 114 | 115 | # light_pose[:3, 3] = np.array([1, 1, 2]) 116 | # scene.add(light, pose=light_pose) 117 | 118 | light_nodes = create_raymond_lights() 119 | for node in light_nodes: 120 | scene.add_node(node) 121 | 122 | color, rend_depth = renderer.render( 123 | scene, flags=pyrender.RenderFlags.RGBA) 124 | color = color.astype(np.float32) / 255.0 125 | valid_mask = (rend_depth > 0)[:, :, None] 126 | # output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image) 127 | 128 | if side_view or top_view: 129 | output_img = color[:, :, :3] 130 | # if not side_view: 131 | else: 132 | output_img = (color[:, :, :3] * valid_mask + 133 | (1 - valid_mask) * image) 134 | # else: 135 | # output_img = color[:, :, :3] 136 | renderer.delete() 137 | return output_img -------------------------------------------------------------------------------- /lib/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 
4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import numpy as np 12 | import cv2 13 | 14 | import torch 15 | 16 | 17 | def flip_back(output_flipped, matched_parts): 18 | ''' 19 | ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width) 20 | ''' 21 | assert output_flipped.ndim == 4,\ 22 | 'output_flipped should be [batch_size, num_joints, height, width]' 23 | 24 | output_flipped = output_flipped[:, :, :, ::-1] 25 | 26 | for pair in matched_parts: 27 | tmp = output_flipped[:, pair[0], :, :].copy() 28 | output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] 29 | output_flipped[:, pair[1], :, :] = tmp 30 | 31 | return output_flipped 32 | 33 | 34 | def fliplr_joints(joints, joints_vis, width, matched_parts): 35 | """ 36 | flip coords 37 | """ 38 | # Flip horizontal 39 | joints[:, 0] = width - joints[:, 0] - 1 40 | 41 | # Change left-right parts 42 | for pair in matched_parts: 43 | joints[pair[0], :], joints[pair[1], :] = \ 44 | joints[pair[1], :], joints[pair[0], :].copy() 45 | joints_vis[pair[0], :], joints_vis[pair[1], :] = \ 46 | joints_vis[pair[1], :], joints_vis[pair[0], :].copy() 47 | 48 | return joints * joints_vis, joints_vis 49 | 50 | 51 | def transform_preds(coords, center, scale, output_size): 52 | target_coords = np.zeros(coords.shape) 53 | trans = get_affine_transform(center, scale, 0, output_size, inv=1) 54 | for p in range(coords.shape[0]): 55 | target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) 56 | return target_coords 57 | 58 | 59 | def get_affine_transform(center, 60 | scale, 61 | rot, 62 | output_size, 63 | shift=np.array([0, 0], dtype=np.float32), 64 | inv=0): 65 | if not isinstance(scale, np.ndarray) and not isinstance(scale, list): 66 | scale = np.array([scale, scale]) 67 | 68 | scale_tmp = scale * 200.0 69 | src_w = scale_tmp[0] 70 | dst_w = output_size[0] 71 | dst_h = output_size[1] 72 | 73 | rot_rad = np.pi * rot / 180 74 | src_dir = get_dir([0, src_w * -0.5], rot_rad) 75 | dst_dir = np.array([0, dst_w * -0.5], np.float32) 76 | 77 | src = np.zeros((3, 2), dtype=np.float32) 78 | dst = np.zeros((3, 2), dtype=np.float32) 79 | src[0, :] = center + scale_tmp * shift 80 | src[1, :] = center + src_dir + scale_tmp * shift 81 | dst[0, :] = [dst_w * 0.5, dst_h * 0.5] 82 | dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir 83 | 84 | src[2:, :] = get_3rd_point(src[0, :], src[1, :]) 85 | dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) 86 | 87 | if inv: 88 | trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) 89 | else: 90 | trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) 91 | 92 | return trans 93 | 94 | 95 | def affine_transform(pt, t): 96 | new_pt = np.array([pt[0], pt[1], 1.]).T 97 | new_pt = np.dot(t, new_pt) 98 | return new_pt[:2] 99 | 100 | 101 | def affine_transform_pts(pts, t): 102 | xyz = np.add( 103 | np.array([[1, 0], [0, 1], [0, 0]]).dot(pts.T), np.array([[0], [0], 104 | [1]])) 105 | return np.dot(t, xyz).T 106 | 107 | 108 | def affine_transform_pts_cuda(pts, t): 109 | npts = pts.shape[0] 110 | pts_homo = torch.cat([pts, torch.ones(npts, 1, device=pts.device)], dim=1) 111 | out = torch.mm(t, torch.t(pts_homo)) 112 | return torch.t(out[:2, :]) 113 | 114 | 115 | def get_3rd_point(a, b): 116 | direct = a - b 117 | return b + np.array([-direct[1], direct[0]], dtype=np.float32) 
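# Illustrative usage sketch (added comment, not part of the original module): the
# helpers above are typically chained to map 2D points between the full image and a
# cropped patch. The center/scale values below are hypothetical; `scale` is expressed
# in units of 200 pixels, matching `scale_tmp = scale * 200.0` in get_affine_transform.
#
#   center = np.array([256.0, 256.0], dtype=np.float32)   # assumed person center
#   scale = 1.2                                            # roughly a 240 x 240 px box
#   trans = get_affine_transform(center, scale, rot=0, output_size=[224, 224])
#   pt_patch = affine_transform(np.array([300.0, 280.0]), trans)   # image -> patch
#   inv_trans = get_affine_transform(center, scale, rot=0, output_size=[224, 224], inv=1)
#   pt_image = affine_transform(pt_patch, inv_trans)               # patch -> image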
118 | 119 | 120 | def get_dir(src_point, rot_rad): 121 | sn, cs = np.sin(rot_rad), np.cos(rot_rad) 122 | 123 | src_result = [0, 0] 124 | src_result[0] = src_point[0] * cs - src_point[1] * sn 125 | src_result[1] = src_point[0] * sn + src_point[1] * cs 126 | 127 | return src_result 128 | 129 | 130 | def crop(img, center, scale, output_size, rot=0): 131 | trans = get_affine_transform(center, scale, rot, output_size) 132 | 133 | dst_img = cv2.warpAffine( 134 | img, 135 | trans, (int(output_size[0]), int(output_size[1])), 136 | flags=cv2.INTER_LINEAR) 137 | 138 | return dst_img -------------------------------------------------------------------------------- /lib/utils/triangulate.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | # import parse 8 | import pickle 9 | import numpy as np 10 | 11 | from pymvg.camera_model import CameraModel 12 | from pymvg.multi_camera_system import MultiCameraSystem 13 | # from multiviews.cameras import unfold_camera_param 14 | 15 | def camera_to_world_frame(x, R, T): 16 | """ 17 | Args 18 | x: Nx3 points in camera coordinates 19 | R: 3x3 Camera rotation matrix 20 | T: 3x1 Camera translation parameters 21 | Returns 22 | xcam: Nx3 points in world coordinates 23 | """ 24 | 25 | xcam = R.T.dot(x.T) + T # rotate and translate 26 | # xcam = R.dot(x.T) + T # rotate and translate 27 | return xcam.T 28 | 29 | 30 | def unfold_camera_param(camera): 31 | R = camera['R'] 32 | T = camera['T'] 33 | f = 0.5 * (camera['fx'] + camera['fy']) 34 | c = np.array([camera['cx'], camera['cy']]) 35 | k = camera['k'] 36 | p = camera['p'] 37 | return R, T, f, c, k, p 38 | 39 | def build_multi_camera_system(cameras): 40 | """ 41 | Build a multi-camera system with pymvg package for triangulation 42 | 43 | Args: 44 | cameras: list of camera parameters 45 | Returns: 46 | cams_system: a multi-cameras system 47 | """ 48 | pymvg_cameras = [] 49 | for (name, camera) in cameras: 50 | R, T, f, c, k, p = unfold_camera_param(camera) 51 | camera_matrix = np.array( 52 | [[f, 0, c[0]], [0, f, c[1]], [0, 0, 1]], dtype=float) 53 | proj_matrix = np.zeros((3, 4)) 54 | proj_matrix[:3, :3] = camera_matrix 55 | distortion = np.array([k[0], k[1], p[0], p[1], k[2]]) 56 | distortion.shape = (5,) 57 | T = -np.matmul(R, T) 58 | M = camera_matrix.dot(np.concatenate((R, T), axis=1)) 59 | camera = CameraModel.load_camera_from_M( 60 | M, name=name, distortion_coefficients=distortion) 61 | pymvg_cameras.append(camera) 62 | return MultiCameraSystem(pymvg_cameras) 63 | 64 | 65 | def triangulate_one_point(camera_system, points_2d_set): 66 | """ 67 | Triangulate 3d point in world coordinates with multi-views 2d points 68 | 69 | Args: 70 | camera_system: pymvg camera system 71 | points_2d_set: list of structure (camera_name, point2d) 72 | Returns: 73 | points_3d: 3x1 point in world coordinates 74 | """ 75 | points_3d = camera_system.find3d(points_2d_set) 76 | return points_3d 77 | 78 | 79 | def triangulate_poses(camera_params, poses2d, nviews): 80 | """ 81 | Triangulate 3d points in world coordinates of multi-view 2d poses 82 | by interatively calling $triangulate_one_point$ 83 | 84 | Args: 85 | camera_params: a list of camera parameters, each corresponding to 86 | 
one prediction in poses2d 87 | poses2d: ndarray of shape nxkx2, len(cameras) == n 88 | Returns: 89 | poses3d: ndarray of shape n/nviews x k x 3 90 | """ 91 | # nviews = 4 92 | njoints = poses2d.shape[1] 93 | ninstances = len(camera_params) // nviews 94 | 95 | poses3d = [] 96 | for i in range(ninstances): 97 | cameras = [] 98 | for j in range(nviews): 99 | camera_name = 'camera_{}'.format(j) 100 | cameras.append((camera_name, camera_params[i * nviews + j])) 101 | camera_system = build_multi_camera_system(cameras) 102 | 103 | pose3d = np.zeros((njoints, 3)) 104 | for k in range(njoints): 105 | points_2d_set = [] 106 | 107 | for j in range(nviews): 108 | camera_name = 'camera_{}'.format(j) 109 | points_2d = poses2d[i * nviews + j, k, :] 110 | points_2d_set.append((camera_name, points_2d)) 111 | pose3d[k, :] = triangulate_one_point(camera_system, points_2d_set).T 112 | poses3d.append(pose3d) 113 | return np.array(poses3d) 114 | -------------------------------------------------------------------------------- /lib/utils/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | import skimage.transform 4 | import cv2 5 | 6 | import torch 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | from matplotlib import pylab as plt 11 | from mpl_toolkits.mplot3d import axes3d, Axes3D 12 | 13 | 14 | from lib.utils.img import image_batch_to_numpy, to_numpy, denormalize_image, resize_image 15 | from lib.utils.multiview import project_3d_points_to_image_plane_without_distortion 16 | 17 | CONNECTIVITY_DICT = { 18 | 'cmu': [(0, 2), (0, 9), (1, 0), (1, 17), (2, 12), (3, 0), (4, 3), (5, 4), (6, 2), (7, 6), (8, 7), (9, 10), (10, 11), (12, 13), (13, 14), (15, 1), (16, 15), (17, 18)], 19 | 'coco': [(0, 1), (0, 2), (1, 3), (2, 4), (5, 7), (7, 9), (6, 8), (8, 10), (11, 13), (13, 15), (12, 14), (14, 16), (5, 6), (5, 11), (6, 12), (11, 12)], 20 | "mpii": [(0, 1), (1, 2), (2, 6), (5, 4), (4, 3), (3, 6), (6, 7), (7, 8), (8, 9), (8, 12), (8, 13), (10, 11), (11, 12), (13, 14), (14, 15)], 21 | # "human36m": [(0, 1), (1, 2), (2, 6), (5, 4), (4, 3), (3, 6), (6, 7), (7, 8), (8, 16), (9, 16), (8, 12), (11, 12), (10, 11), (8, 13), (13, 14), (14, 15)], 22 | "h36m": [(3, 2), (2, 1), (1, 0), (6, 5), (5, 4), (4, 0), (0, 7), (7, 8), (8, 9), (9, 10), (8, 14), (14, 15), (15, 16), (8, 11), (11, 12), (12, 13)], 23 | "kth": [(0, 1), (1, 2), (5, 4), (4, 3), (6, 7), (7, 8), (11, 10), (10, 9), (2, 3), (3, 9), (2, 8), (9, 12), (8, 12), (12, 13)], 24 | } 25 | 26 | COLOR_DICT = { 27 | 'coco': [ 28 | (102, 0, 153), (153, 0, 102), (51, 0, 153), (153, 0, 153), # head 29 | (51, 153, 0), (0, 153, 0), # left arm 30 | (153, 102, 0), (153, 153, 0), # right arm 31 | (0, 51, 153), (0, 0, 153), # left leg 32 | (0, 153, 102), (0, 153, 153), # right leg 33 | (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0) # body 34 | ], 35 | 36 | 'h36m': [ 37 | (0, 153, 102), (0, 153, 153), (0, 153, 153), # right leg 38 | (0, 51, 153), (0, 0, 153), (0, 0, 153), # left leg 39 | (153, 0, 0), (153, 0, 0), # body 40 | (153, 0, 102), (153, 0, 102), # head 41 | (153, 153, 0), (153, 153, 0), (153, 102, 0), # right arm 42 | (0, 153, 0), (0, 153, 0), (51, 153, 0) # left arm 43 | ], 44 | 45 | 'kth': [ 46 | (0, 153, 102), (0, 153, 153), # right leg 47 | (0, 51, 153), (0, 0, 153), # left leg 48 | (153, 102, 0), (153, 153, 0), # right arm 49 | (51, 153, 0), (0, 153, 0), # left arm 50 | (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0), # body 51 | (102, 0, 153) # head 52 | ] 53 | } 
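# Note added for clarity (not in the original file): each entry of CONNECTIVITY_DICT
# lists (joint_from, joint_to) index pairs for one skeleton convention, and COLOR_DICT
# gives one RGB tuple (0-255) per bone in the same order; draw_2d_pose and draw_3d_pose
# below walk the two structures in lockstep. A minimal, hypothetical usage sketch with
# random keypoints in the 'h36m' joint ordering (17 joints):
#
#   fig, ax = plt.subplots()
#   keypoints_2d = np.random.rand(17, 2) * 224
#   draw_2d_pose(keypoints_2d, ax, kind='h36m')
#   pose_img = fig_to_array(fig)   # HxWx4 RGBA array for logging/visualization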
54 | 55 | JOINT_NAMES_DICT = { 56 | 'coco': { 57 | 0: "nose", 58 | 1: "left_eye", 59 | 2: "right_eye", 60 | 3: "left_ear", 61 | 4: "right_ear", 62 | 5: "left_shoulder", 63 | 6: "right_shoulder", 64 | 7: "left_elbow", 65 | 8: "right_elbow", 66 | 9: "left_wrist", 67 | 10: "right_wrist", 68 | 11: "left_hip", 69 | 12: "right_hip", 70 | 13: "left_knee", 71 | 14: "right_knee", 72 | 15: "left_ankle", 73 | 16: "right_ankle" 74 | } 75 | } 76 | 77 | 78 | def fig_to_array(fig): 79 | fig.canvas.draw() 80 | fig_image = np.array(fig.canvas.renderer._renderer) 81 | 82 | return fig_image 83 | 84 | 85 | def visualize_batch(images_batch, heatmaps_batch, keypoints_2d_batch, proj_matricies_batch, 86 | keypoints_3d_batch_gt, keypoints_3d_batch_pred, 87 | kind="cmu", 88 | cuboids_batch=None, 89 | confidences_batch=None, 90 | batch_index=0, size=5, 91 | max_n_cols=10, 92 | pred_kind=None 93 | ): 94 | if pred_kind is None: 95 | pred_kind = kind 96 | 97 | n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2] 98 | 99 | n_rows = 3 100 | n_rows = n_rows + 1 if keypoints_2d_batch is not None else n_rows 101 | n_rows = n_rows + 1 if cuboids_batch is not None else n_rows 102 | n_rows = n_rows + 1 if confidences_batch is not None else n_rows 103 | 104 | n_cols = min(n_views, max_n_cols) 105 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size)) 106 | axes = axes.reshape(n_rows, n_cols) 107 | 108 | image_shape = images_batch.shape[3:] 109 | heatmap_shape = heatmaps_batch.shape[3:] 110 | 111 | row_i = 0 112 | 113 | # images 114 | axes[row_i, 0].set_ylabel("image", size='large') 115 | 116 | images = image_batch_to_numpy(images_batch[batch_index]) 117 | images = denormalize_image(images).astype(np.uint8) 118 | images = images[..., ::-1] # bgr -> rgb 119 | 120 | for view_i in range(n_cols): 121 | axes[row_i][view_i].imshow(images[view_i]) 122 | row_i += 1 123 | 124 | # 2D keypoints (pred) 125 | if keypoints_2d_batch is not None: 126 | axes[row_i, 0].set_ylabel("2d keypoints (pred)", size='large') 127 | 128 | keypoints_2d = to_numpy(keypoints_2d_batch)[batch_index] 129 | for view_i in range(n_cols): 130 | axes[row_i][view_i].imshow(images[view_i]) 131 | draw_2d_pose(keypoints_2d[view_i], axes[row_i][view_i], kind=kind) 132 | row_i += 1 133 | 134 | # 2D keypoints (gt projected) 135 | axes[row_i, 0].set_ylabel("2d keypoints (gt projected)", size='large') 136 | 137 | for view_i in range(n_cols): 138 | axes[row_i][view_i].imshow(images[view_i]) 139 | keypoints_2d_gt_proj = project_3d_points_to_image_plane_without_distortion(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), keypoints_3d_batch_gt[batch_index].detach().cpu().numpy()) 140 | draw_2d_pose(keypoints_2d_gt_proj, axes[row_i][view_i], kind=kind) 141 | row_i += 1 142 | 143 | # 2D keypoints (pred projected) 144 | axes[row_i, 0].set_ylabel("2d keypoints (pred projected)", size='large') 145 | 146 | for view_i in range(n_cols): 147 | axes[row_i][view_i].imshow(images[view_i]) 148 | keypoints_2d_pred_proj = project_3d_points_to_image_plane_without_distortion(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), keypoints_3d_batch_pred[batch_index].detach().cpu().numpy()) 149 | draw_2d_pose(keypoints_2d_pred_proj, axes[row_i][view_i], kind=pred_kind) 150 | row_i += 1 151 | 152 | # cuboids 153 | if cuboids_batch is not None: 154 | axes[row_i, 0].set_ylabel("cuboid", size='large') 155 | 156 | for view_i in range(n_cols): 157 | cuboid = cuboids_batch[batch_index] 158 | 
axes[row_i][view_i].imshow(cuboid.render(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), images[view_i].copy())) 159 | row_i += 1 160 | 161 | # confidences 162 | if confidences_batch is not None: 163 | axes[row_i, 0].set_ylabel("confidences", size='large') 164 | 165 | for view_i in range(n_cols): 166 | confidences = to_numpy(confidences_batch[batch_index, view_i]) 167 | xs = np.arange(len(confidences)) 168 | 169 | axes[row_i, view_i].bar(xs, confidences, color='green') 170 | axes[row_i, view_i].set_xticks(xs) 171 | if torch.max(confidences_batch).item() <= 1.0: 172 | axes[row_i, view_i].set_ylim(0.0, 1.0) 173 | 174 | fig.tight_layout() 175 | 176 | fig_image = fig_to_array(fig) 177 | 178 | plt.close('all') 179 | 180 | return fig_image 181 | 182 | def visualize_2d_pose(images, keypoits_2ds, kind = 'h36m', size = 5): 183 | images = image_batch_to_numpy(images) 184 | images = denormalize_image(images).astype(np.uint8) 185 | n_cols = 2 186 | # n_rows = images.shape[0] 187 | n_rows = min(images.shape[0], 4 * 8) 188 | keypoits_2ds = to_numpy(keypoits_2ds) 189 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size)) 190 | axes = axes.reshape(n_rows, n_cols) 191 | for row_i in range(n_rows): 192 | axes[row_i][0].imshow(images[row_i]) 193 | axes[row_i][1].imshow(images[row_i]) 194 | draw_2d_pose(keypoits_2ds[row_i], axes[row_i][1], kind = kind) 195 | fig.tight_layout() 196 | 197 | fig_image = fig_to_array(fig) 198 | 199 | plt.close('all') 200 | 201 | return fig_image.transpose(2, 0, 1) 202 | 203 | def visualize_heatmaps(images_batch, heatmaps_batch, 204 | kind="cmu", 205 | batch_index=0, size=5, 206 | max_n_rows=10, max_n_cols=10): 207 | n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2] 208 | heatmap_shape = heatmaps_batch.shape[3:] 209 | 210 | n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows) 211 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size)) 212 | axes = axes.reshape(n_rows, n_cols) 213 | 214 | # images 215 | images = image_batch_to_numpy(images_batch[batch_index]) 216 | images = denormalize_image(images).astype(np.uint8) 217 | images = images[..., ::-1] # bgr -> 218 | 219 | # heatmaps 220 | heatmaps = to_numpy(heatmaps_batch[batch_index]) 221 | 222 | for row in range(n_rows): 223 | for col in range(n_cols): 224 | if col == 0: 225 | axes[row, col].set_ylabel(str(row), size='large') 226 | axes[row, col].imshow(images[row]) 227 | else: 228 | if row == 0: 229 | joint_name = JOINT_NAMES_DICT[kind][col - 1] if kind in JOINT_NAMES_DICT else str(col - 1) 230 | axes[row, col].set_title(joint_name) 231 | 232 | axes[row, col].imshow(resize_image(images[row], heatmap_shape)) 233 | axes[row, col].imshow(heatmaps[row, col - 1], alpha=0.5) 234 | 235 | fig.tight_layout() 236 | 237 | fig_image = fig_to_array(fig) 238 | 239 | plt.close('all') 240 | 241 | return fig_image 242 | 243 | 244 | def visualize_volumes(images_batch, volumes_batch, proj_matricies_batch, 245 | kind="cmu", 246 | cuboids_batch=None, 247 | batch_index=0, size=5, 248 | max_n_rows=10, max_n_cols=10): 249 | n_views, n_joints = volumes_batch.shape[1], volumes_batch.shape[2] 250 | 251 | n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows) 252 | fig = plt.figure(figsize=(n_cols * size, n_rows * size)) 253 | 254 | # images 255 | images = image_batch_to_numpy(images_batch[batch_index]) 256 | images = denormalize_image(images).astype(np.uint8) 257 | images = images[..., ::-1] # bgr -> 258 | 
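# (Descriptive comment added: the `[..., ::-1]` slice above reverses the channel
# order from OpenCV-style BGR to RGB so matplotlib displays the colors correctly.)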
259 | # heatmaps 260 | volumes = to_numpy(volumes_batch[batch_index]) 261 | 262 | for row in range(n_rows): 263 | for col in range(n_cols): 264 | if col == 0: 265 | ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1) 266 | ax.set_ylabel(str(row), size='large') 267 | 268 | cuboid = cuboids_batch[batch_index] 269 | ax.imshow(cuboid.render(proj_matricies_batch[batch_index, row].detach().cpu().numpy(), images[row].copy())) 270 | else: 271 | ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1, projection='3d') 272 | 273 | if row == 0: 274 | joint_name = JOINT_NAMES_DICT[kind][col - 1] if kind in JOINT_NAMES_DICT else str(col - 1) 275 | ax.set_title(joint_name) 276 | 277 | draw_voxels(volumes[col - 1], ax, norm=True) 278 | 279 | fig.tight_layout() 280 | 281 | fig_image = fig_to_array(fig) 282 | 283 | plt.close('all') 284 | 285 | return fig_image 286 | 287 | 288 | def draw_2d_pose(keypoints, ax, kind='h36m', keypoints_mask=None, point_size=2, line_width=1, radius=None, colors=None): 289 | """ 290 | Visualizes a 2d skeleton 291 | 292 | Args 293 | keypoints numpy array of shape (19, 2): pose to draw in CMU format. 294 | ax: matplotlib axis to draw on 295 | """ 296 | connectivity = CONNECTIVITY_DICT[kind] 297 | 298 | # color = 'blue' if color is None else color 299 | colors = COLOR_DICT[kind] 300 | 301 | if keypoints_mask is None: 302 | keypoints_mask = [True] * len(keypoints) 303 | 304 | # points 305 | ax.scatter(keypoints[keypoints_mask][:, 0], keypoints[keypoints_mask][:, 1], c='red', s=point_size) 306 | 307 | # connections 308 | for i, (index_from, index_to) in enumerate(connectivity): 309 | if keypoints_mask[index_from] and keypoints_mask[index_to]: 310 | xs, ys = [np.array([keypoints[index_from, j], keypoints[index_to, j]]) for j in range(2)] 311 | color = colors[i] 312 | color = np.array(color) / 255 313 | ax.plot(xs, ys, c=color, lw=line_width) 314 | 315 | if radius is not None: 316 | root_keypoint_index = 0 317 | xroot, yroot = keypoints[root_keypoint_index, 0], keypoints[root_keypoint_index, 1] 318 | 319 | ax.set_xlim([-radius + xroot, radius + xroot]) 320 | ax.set_ylim([-radius + yroot, radius + yroot]) 321 | 322 | ax.set_aspect('equal') 323 | ax.xaxis.set_visible(False) 324 | ax.yaxis.set_visible(False) 325 | 326 | 327 | def draw_2d_pose_cv2(keypoints, canvas, kind='cmu', keypoints_mask=None, point_size=2, point_color=(255, 255, 255), line_width=1, radius=None, color=None, anti_aliasing_scale=1): 328 | canvas = canvas.copy() 329 | 330 | shape = np.array(canvas.shape[:2]) 331 | new_shape = shape * anti_aliasing_scale 332 | canvas = resize_image(canvas, tuple(new_shape)) 333 | 334 | keypoints = keypoints * anti_aliasing_scale 335 | point_size = point_size * anti_aliasing_scale 336 | line_width = line_width * anti_aliasing_scale 337 | 338 | connectivity = CONNECTIVITY_DICT[kind] 339 | 340 | color = 'blue' if color is None else color 341 | 342 | if keypoints_mask is None: 343 | keypoints_mask = [True] * len(keypoints) 344 | 345 | # connections 346 | for i, (index_from, index_to) in enumerate(connectivity): 347 | if keypoints_mask[index_from] and keypoints_mask[index_to]: 348 | pt_from = tuple(np.array(keypoints[index_from, :]).astype(int)) 349 | pt_to = tuple(np.array(keypoints[index_to, :]).astype(int)) 350 | 351 | if kind in COLOR_DICT: 352 | color = COLOR_DICT[kind][i] 353 | else: 354 | color = (0, 0, 255) 355 | 356 | cv2.line(canvas, pt_from, pt_to, color=color, thickness=line_width) 357 | 358 | if kind == 'coco': 359 | mid_collarbone = (keypoints[5, :] + keypoints[6, 
:]) / 2 360 | nose = keypoints[0, :] 361 | 362 | pt_from = tuple(np.array(nose).astype(int)) 363 | pt_to = tuple(np.array(mid_collarbone).astype(int)) 364 | 365 | if kind in COLOR_DICT: 366 | color = (153, 0, 51) 367 | else: 368 | color = (0, 0, 255) 369 | 370 | cv2.line(canvas, pt_from, pt_to, color=color, thickness=line_width) 371 | 372 | # points 373 | for pt in keypoints[keypoints_mask]: 374 | cv2.circle(canvas, tuple(pt.astype(int)), point_size, color=point_color, thickness=-1) 375 | 376 | canvas = resize_image(canvas, tuple(shape)) 377 | 378 | return canvas 379 | 380 | 381 | def draw_3d_pose(keypoints, ax, keypoints_mask=None, kind='cmu', radius=None, root=None, point_size=2, line_width=2, draw_connections=True): 382 | connectivity = CONNECTIVITY_DICT[kind] 383 | 384 | if keypoints_mask is None: 385 | keypoints_mask = [True] * len(keypoints) 386 | 387 | if draw_connections: 388 | # Make connection matrix 389 | for i, joint in enumerate(connectivity): 390 | if keypoints_mask[joint[0]] and keypoints_mask[joint[1]]: 391 | xs, ys, zs = [np.array([keypoints[joint[0], j], keypoints[joint[1], j]]) for j in range(3)] 392 | 393 | if kind in COLOR_DICT: 394 | color = COLOR_DICT[kind][i] 395 | else: 396 | color = (0, 0, 255) 397 | 398 | color = np.array(color) / 255 399 | 400 | ax.plot(xs, ys, zs, lw=line_width, c=color) 401 | 402 | if kind == 'coco': 403 | mid_collarbone = (keypoints[5, :] + keypoints[6, :]) / 2 404 | nose = keypoints[0, :] 405 | 406 | xs, ys, zs = [np.array([nose[j], mid_collarbone[j]]) for j in range(3)] 407 | 408 | if kind in COLOR_DICT: 409 | color = (153, 0, 51) 410 | else: 411 | color = (0, 0, 255) 412 | 413 | color = np.array(color) / 255 414 | 415 | ax.plot(xs, ys, zs, lw=line_width, c=color) 416 | 417 | 418 | ax.scatter(keypoints[keypoints_mask][:, 0], keypoints[keypoints_mask][:, 1], keypoints[keypoints_mask][:, 2], 419 | s=point_size, c=np.array([230, 145, 56])/255, edgecolors='black') # np.array([230, 145, 56])/255 420 | 421 | if radius is not None: 422 | if root is None: 423 | root = np.mean(keypoints, axis=0) 424 | xroot, yroot, zroot = root 425 | ax.set_xlim([-radius + xroot, radius + xroot]) 426 | ax.set_ylim([-radius + yroot, radius + yroot]) 427 | ax.set_zlim([-radius + zroot, radius + zroot]) 428 | 429 | ax.set_aspect('equal') 430 | 431 | 432 | # Get rid of the panes 433 | background_color = np.array([252, 252, 252]) / 255 434 | 435 | ax.w_xaxis.set_pane_color(background_color) 436 | ax.w_yaxis.set_pane_color(background_color) 437 | ax.w_zaxis.set_pane_color(background_color) 438 | 439 | # Get rid of the ticks 440 | ax.set_xticklabels([]) 441 | ax.set_yticklabels([]) 442 | ax.set_zticklabels([]) 443 | 444 | 445 | def draw_voxels(voxels, ax, shape=(8, 8, 8), norm=True, alpha=0.1): 446 | # resize for visualization 447 | zoom = np.array(shape) / np.array(voxels.shape) 448 | voxels = skimage.transform.resize(voxels, shape, mode='constant', anti_aliasing=True) 449 | voxels = voxels.transpose(2, 0, 1) 450 | 451 | if norm and voxels.max() - voxels.min() > 0: 452 | voxels = (voxels - voxels.min()) / (voxels.max() - voxels.min()) 453 | 454 | filled = np.ones(voxels.shape) 455 | 456 | # facecolors 457 | cmap = plt.get_cmap("Blues") 458 | 459 | facecolors_a = cmap(voxels, alpha=alpha) 460 | facecolors_a = facecolors_a.reshape(-1, 4) 461 | 462 | facecolors_hex = np.array(list(map(lambda x: matplotlib.colors.to_hex(x, keep_alpha=True), facecolors_a))) 463 | facecolors_hex = facecolors_hex.reshape(*voxels.shape) 464 | 465 | # explode voxels to perform 3d alpha rendering 
(https://matplotlib.org/devdocs/gallery/mplot3d/voxels_numpy_logo.html) 466 | def explode(data): 467 | size = np.array(data.shape) * 2 468 | data_e = np.zeros(size - 1, dtype=data.dtype) 469 | data_e[::2, ::2, ::2] = data 470 | return data_e 471 | 472 | filled_2 = explode(filled) 473 | facecolors_2 = explode(facecolors_hex) 474 | 475 | # shrink the gaps 476 | x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2 477 | x[0::2, :, :] += 0.05 478 | y[:, 0::2, :] += 0.05 479 | z[:, :, 0::2] += 0.05 480 | x[1::2, :, :] += 0.95 481 | y[:, 1::2, :] += 0.95 482 | z[:, :, 1::2] += 0.95 483 | 484 | # draw voxels 485 | ax.voxels(x, y, z, filled_2, facecolors=facecolors_2) 486 | 487 | ax.set_xlabel("z"); ax.set_ylabel("x"); ax.set_zlabel("y") 488 | ax.invert_xaxis(); ax.invert_zaxis() 489 | -------------------------------------------------------------------------------- /lib/utils/zipreader.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License. 4 | # Written by Chunyu Wang (chnuwa@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import zipfile 13 | import xml.etree.ElementTree as ET 14 | 15 | import cv2 16 | import numpy as np 17 | 18 | _im_zfile = [] 19 | _xml_path_zip = [] 20 | _xml_zfile = [] 21 | 22 | 23 | def imread(filename, flags=cv2.IMREAD_COLOR): 24 | global _im_zfile 25 | path = filename 26 | pos_at = path.index('@') 27 | if pos_at == -1: 28 | print("character '@' is not found from the given path '%s'" % (path)) 29 | assert 0 30 | path_zip = path[0:pos_at] 31 | path_img = path[pos_at + 2:] 32 | if not os.path.isfile(path_zip): 33 | print("zip file '%s' is not found" % (path_zip)) 34 | assert 0 35 | for i in range(len(_im_zfile)): 36 | if _im_zfile[i]['path'] == path_zip: 37 | data = _im_zfile[i]['zipfile'].read(path_img) 38 | return cv2.imdecode(np.frombuffer(data, np.uint8), flags) 39 | 40 | _im_zfile.append({ 41 | 'path': path_zip, 42 | 'zipfile': zipfile.ZipFile(path_zip, 'r') 43 | }) 44 | data = _im_zfile[-1]['zipfile'].read(path_img) 45 | 46 | return cv2.imdecode(np.frombuffer(data, np.uint8), flags) 47 | 48 | 49 | def xmlread(filename): 50 | global _xml_path_zip 51 | global _xml_zfile 52 | path = filename 53 | pos_at = path.index('@') 54 | if pos_at == -1: 55 | print("character '@' is not found from the given path '%s'" % (path)) 56 | assert 0 57 | path_zip = path[0:pos_at] 58 | path_xml = path[pos_at + 2:] 59 | if not os.path.isfile(path_zip): 60 | print("zip file '%s' is not found" % (path_zip)) 61 | assert 0 62 | for i in range(len(_xml_path_zip)): 63 | if _xml_path_zip[i] == path_zip: 64 | data = _xml_zfile[i].open(path_xml) 65 | return ET.fromstring(data.read()) 66 | _xml_path_zip.append(path_zip) 67 | print("read new xml file '%s'" % (path_zip)) 68 | _xml_zfile.append(zipfile.ZipFile(path_zip, 'r')) 69 | data = _xml_zfile[-1].open(path_xml) 70 | return ET.fromstring(data.read()) 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.11 2 | einops==0.6.1 3 | h5py==3.6.0 4 | matplotlib==3.5.1 5 | 
opencv-python==4.8.0.76 6 | pymvg==2.1.0 7 | pyopengl==3.1.0 8 | smplx==0.1.28 9 | stack_data==0.2.0 10 | tensorboardx==2.5.1 11 | timm==0.9.7 12 | torch==1.11.0 13 | torchvision==0.12.0 14 | tqdm==4.64.0 15 | trimesh==3.23.5 16 | typing-extensions==4.7.1 17 | yacs==0.1.8 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib import datasets 3 | from lib.utils.config import get_config 4 | from lib.models.fusion import Mv_Fusion 5 | from lib.utils.log_utils import create_logger, load_checkpoint, save_checkpoint 6 | import os 7 | import time 8 | import argparse 9 | import random 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | def main(): 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | '--cfg_name', help='experiment configure file name', required=True, type=str) 33 | parser.add_argument( 34 | '--dataset', help='dataset name', required=True, type=str) 35 | args = parser.parse_args() 36 | cfg = get_config('experiments/{}/{}'.format(args.dataset, args.cfg_name), merge=False) 37 | cfg_name = args.cfg_name 38 | 39 | 40 | # cfg_name = 'vit_pos_trans_encoder.yaml' 41 | # cfg = get_config('experiments/h36m/{}'.format(cfg_name), merge= False) 42 | 43 | if cfg.IS_TRAIN: 44 | phase = 'train' 45 | else: 46 | phase = 'test' 47 | logger, final_output_dir, tensorboard_log_dir = create_logger(cfg, cfg_name, phase) 48 | gpus=[0] 49 | model = Mv_Fusion(cfg, tensorboard_log_dir) 50 | model = torch.nn.DataParallel(model, device_ids=gpus).cuda() 51 | 52 | mocap_dataset = datasets.mocap_dataset(cfg.DATASET.MOCAP) 53 | train_dataset = eval('datasets.' + cfg.DATASET.TRAIN_DATASET)(cfg, cfg.DATASET.TRAIN_SUBSET, True) 54 | 55 | train_loader = torch.utils.data.DataLoader( 56 | train_dataset, 57 | batch_size=cfg.TRAIN.BATCH_SIZE, 58 | shuffle=True, 59 | drop_last=True, 60 | num_workers=cfg.GENERAL.NUM_WORKERS, 61 | pin_memory=True) 62 | mocap_loader = torch.utils.data.DataLoader( 63 | mocap_dataset, 64 | batch_size=cfg.TRAIN.BATCH_SIZE * cfg.DATASET.N_VIEWS, 65 | shuffle=True, 66 | drop_last=True, 67 | num_workers=1, 68 | pin_memory=True) 69 | val_dataset = eval('datasets.'
+ cfg.DATASET.TEST_DATASET)(cfg, cfg.DATASET.TEST_SUBSET, False) 70 | val_loader = torch.utils.data.DataLoader( 71 | val_dataset, 72 | batch_size=cfg.TRAIN.BATCH_SIZE, 73 | shuffle=False, 74 | drop_last = True, 75 | num_workers=cfg.GENERAL.NUM_WORKERS, 76 | pin_memory=True) 77 | logger.info(f'=> Loaded datasets') 78 | len_train_data = len(train_loader) 79 | len_val_data = len(val_loader) 80 | best_perf = 1000000.0 81 | best_model = False 82 | if not cfg.IS_TRAIN: 83 | meters = {k: AverageMeter() for k in ['train_loss', 'val_loss', 'train_mpjpe', 'val_mpjpe', 'train_rec_error', 'val_rec_error']} 84 | model.eval() 85 | with torch.no_grad(): 86 | model.module.load_state_dict(torch.load(cfg.TEST.MODEL_FILE)['state_dict']) 87 | for i, data in enumerate(zip(val_loader, mocap_loader)): 88 | n_views = 4 89 | subset = random.sample(range(0, 4), n_views) 90 | subset.sort() 91 | (input, meta), mocap = data 92 | input_sub = [] 93 | meta_sub = [] 94 | for j in subset: 95 | input_sub.append(input[j]) 96 | meta_sub.append(meta[j]) 97 | model(input_sub, meta_sub, i, mocap, meters, len_val_data, n_views, train = False) 98 | logger.info(f'val_mpjpe: {meters["val_mpjpe"].avg}\t val_rec_error: {meters["val_rec_error"].avg}') 99 | return 100 | if cfg.TRAIN.RESUME: 101 | start_epoch, model = load_checkpoint(model, final_output_dir) 102 | for epoch in range(start_epoch, cfg.TRAIN.TOTAL_EPOCHS): 103 | meters = {k: AverageMeter() for k in ['train_loss', 'val_loss', 'train_mpjpe', 'val_mpjpe', 'train_rec_error', 'val_rec_error']} 104 | model.train() 105 | 106 | 107 | for i, data in enumerate(zip(train_loader, mocap_loader)): 108 | 109 | n_views = 4 110 | subset = random.sample(range(0, 4), n_views) 111 | subset.sort() 112 | (input, meta), mocap = data 113 | input_sub = [] 114 | meta_sub = [] 115 | for j in subset: 116 | input_sub.append(input[j]) 117 | meta_sub.append(meta[j]) 118 | mocap_sub = {} 119 | for k,v in mocap.items(): 120 | mocap_sub[k] = v[:cfg.TRAIN.BATCH_SIZE * n_views] 121 | model(input_sub, meta_sub, i, mocap_sub, meters, len_train_data, n_views, epoch, True) 122 | 123 | model.eval() 124 | with torch.no_grad(): 125 | for i, data in enumerate(zip(val_loader, mocap_loader)): 126 | n_views = 4 127 | subset = random.sample(range(0, 4), n_views) 128 | subset.sort() 129 | (input, meta), mocap = data 130 | input_sub = [] 131 | meta_sub = [] 132 | for j in subset: 133 | input_sub.append(input[j]) 134 | meta_sub.append(meta[j]) 135 | mocap_sub = {} 136 | for k,v in mocap.items(): 137 | mocap_sub[k] = v[:cfg.TRAIN.BATCH_SIZE * n_views] 138 | model(input_sub, meta_sub, i, mocap_sub, meters, len_val_data, n_views, epoch, False) 139 | logger.info(f'val_mpjpe: {meters["val_mpjpe"].avg}\t val_rec_error: {meters["val_rec_error"].avg}') 140 | perf_indicator = meters['val_mpjpe'].avg 141 | if perf_indicator < best_perf: 142 | best_perf = perf_indicator 143 | best_model = True 144 | else: 145 | best_model = False 146 | 147 | logger.info('=> saving checkpoint to {}'.format(final_output_dir)) 148 | save_checkpoint({ 149 | 'epoch': epoch, 150 | 'state_dict': model.module.state_dict(), 151 | 'perf': perf_indicator, 152 | 'optimizer': model.module.optimizer.state_dict(), 153 | }, best_model, final_output_dir) 154 | final_model_state_file = os.path.join(final_output_dir, 155 | 'final_state.pth.tar') 156 | logger.info('saving final model state to {}'.format(final_model_state_file)) 157 | torch.save(model.module.state_dict(), final_model_state_file) 158 | return 159 | 160 | 161 | 162 | 163 | if __name__ == "__main__": 164 | 
main() --------------------------------------------------------------------------------
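A note on the metrics logged by `run.py`: `val_rec_error` is the Procrustes-aligned joint error (PA-MPJPE) produced by the `reconstruction_error` helper shown earlier, which appears to live in `lib/utils/pose_utils.py` per the directory tree. The snippet below is only a minimal sketch of calling that helper on dummy data; the import path is assumed, and the array shapes and units are illustrative rather than taken from the released evaluation code.

```
import numpy as np
from lib.utils.pose_utils import reconstruction_error  # assumed module path

# Hypothetical predicted and ground-truth 3D joints: (batch, num_joints, 3), in metres.
pred_3d = np.random.rand(2, 17, 3)
gt_3d = pred_3d + 0.005 * np.random.randn(2, 17, 3)

# Procrustes-align each prediction to its ground truth, then average the per-joint error.
pa_mpjpe = reconstruction_error(pred_3d, gt_3d, reduction='mean')
print(1000.0 * pa_mpjpe)  # report in millimetres
```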