├── LICENSE ├── README.md ├── apps ├── test_color.py ├── test_scanimate.py ├── train_color.py └── train_scanimate.py ├── configs ├── default.yaml └── example.yaml ├── download_aist_demo_motion.sh ├── install.sh ├── installation.txt ├── lib ├── __init__.py ├── config.py ├── data │ ├── CapeDataset.py │ └── __init__.py ├── ext_trimesh.py ├── geo_util.py ├── geometry.py ├── mesh_util.py ├── model │ ├── BaseIMNet3d.py │ ├── IGRSDFNet.py │ ├── LBSNet.py │ ├── MLP.py │ ├── TNet.py │ └── __init__.py ├── net_util.py └── sdf.py ├── render └── render_aist.py ├── requirements.txt ├── smpl ├── LICENSE ├── README.md ├── setup.py └── smpl │ ├── __init__.py │ ├── body_models.py │ ├── joint_names.py │ ├── lbs.py │ ├── utils.py │ ├── vertex_ids.py │ └── vertex_joint_selector.py └── teaser ├── aist_0.gif └── teaser.png /LICENSE: -------------------------------------------------------------------------------- 1 | License 2 | 3 | Software Copyright License for non-commercial scientific research purposes 4 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SCANimate model, data and software, (the "Data & Software"), pre-trained neural network model parameters, pre-trained animatable avatars ("Scanimats"), software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License 5 | 6 | Ownership / Licensees 7 | The Software and the associated materials has been developed at the 8 | 9 | Max Planck Institute for Intelligent Systems (hereinafter "MPI"). 10 | 11 | Any copyright or patent right is owned by and proprietary material of the 12 | 13 | Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) 14 | 15 | hereinafter the “Licensor”. 16 | 17 | License Grant 18 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: 19 | 20 | To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization; 21 | To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects; 22 | Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. 23 | 24 | The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial use of any kind. By downloading the Data & Software, you agree not to reverse engineer it. 
25 | 26 | No Distribution 27 | The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. 28 | 29 | Disclaimer of Representations and Warranties 30 | You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE Data & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party. 31 | 32 | Limitation of Liability 33 | Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. 34 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. 35 | Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders. 36 | The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause. 37 | 38 | No Maintenance Services 39 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time. 40 | 41 | Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. 42 | 43 | Publications using the Data & Software 44 | You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software. 
45 | 46 | Citation: 47 | 48 | @inproceedings{Saito:CVPR:2021, 49 | title = {{SCANimate}: Weakly Supervised Learning of Skinned Clothed Avatar Networks}, 50 | author = {Saito, Shunsuke and Yang, Jinlong and Ma, Qianli and Black, Michael J.}, 51 | booktitle = {Proceedings IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)}, 52 | month = jun, 53 | year = {2021}, 54 | doi = {}, 55 | month_numeric = {6} 56 | } 57 | Commercial licensing opportunities 58 | For commercial uses of the Software, please send email to ps-license@tue.mpg.de 59 | 60 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCANimate: Weakly Supervised Learning of Skinned Clothed Avatar Networks (CVPR 2021 Oral) 2 | 3 | [![Paper](https://img.shields.io/badge/arXiv-Paper-b31b1b.svg)](https://arxiv.org/pdf/2104.03313) 4 | 5 | This repository contains the official PyTorch implementation of: 6 | 7 | **SCANimate: Weakly Supervised Learning of Skinned Clothed Avatar Networks**
*Shunsuke Saito, Jinlong Yang, Qianli Ma, and Michael J. Black*
[Full paper](https://arxiv.org/pdf/2104.03313.pdf) | [5min Presentation](https://youtu.be/EeNFvmNuuog) | [Video](https://youtu.be/ohavL55Oznw) | [Project website](https://scanimate.is.tue.mpg.de/) | [Poster](https://scanimate.is.tue.mpg.de/media/upload/poster/CVPR_poster_SCANimate.pdf) 8 | 9 | ![](teaser/aist_0.gif) 10 | 11 | 12 | ## Installation 13 | Please follow the instructions in `./installation.txt` to install the environment and the SMPL model. 14 | 15 | ## Run SCANimate 16 | **0. Activate the environment if it is not already activated:** 17 | ```sh 18 | $ source ./venv/scanimate/bin/activate 19 | ``` 20 | 21 | **1. First, download the pretrained model, some motion sequences and other files for the demo:** 22 | 23 | - Download an AIST++ dance motion sequence for testing (CC BY 4.0 license): 24 | 25 | ```sh 26 | $ . ./download_aist_demo_motion.sh 27 | ``` 28 | This script will create a `data` folder under the current directory; please make sure it is placed under the `SCANimate` directory. 29 | 30 | - Download pre-trained scanimats for the animation test: 31 | Please visit https://scanimate.is.tue.mpg.de/download.php, register, log in, read and agree to the license, and then download some demo scanimats. 32 | Unzip the zip file into the `./data` directory. 33 | 34 | - Download a subset of the CAPE data for the training demo: 35 | Please visit https://scanimate.is.tue.mpg.de/download.php, register, log in, read and agree to the license, and then download the data for the training demo. 36 | Unzip the zip file into the `./data` directory. 37 | 38 | - Now you should have a `./data` directory under `SCANimate`. Within `./data` you will have 5 directories: `minimal_body`, `pretrained`, `pretrained_configs`, `test`, and `train`. 39 | 40 | 41 | ### Run animation demos: 42 | **2. Now you can run the test demo with the following command:** 43 | 44 | ```sh 45 | $ python -m apps.test_scanimate -c ./data/pretrained_configs/release_03223_shortlong.yaml -t ./data/test/gLO_sBM_cAll_d14_mLO1_ch05 46 | ``` 47 | - You can replace the configuration file with other files under `./data/pretrained_configs/` to try other subjects. 48 | - You can also replace the test motions with others under `./data/test`. 49 | - The results will be generated under `./demo_result/results_test`. 50 | 51 | **3. The generated mesh sequences under `./demo_result` can be rendered with `render/render_aist.py`:** 52 | 53 | First, install Open3D (used for rendering the results): 54 | 55 | ```sh 56 | $ pip install open3d==0.12.0 57 | ``` 58 | 59 | Then run: 60 | 61 | ```sh 62 | $ python render/render_aist.py -i demo_result/results_test/release_03223_shortlong_test_gLO_sBM_cAll_d14_mLO1_ch05/ -o demo_result 63 | ``` 64 | ### Run training demo 65 | **2. Now you can run the training demo with:** 66 | ```sh 67 | $ python -m apps.train_scanimate -c ./configs/example.yaml 68 | ``` 69 | The results can be found under `./demo_result/results/example`. 70 | 71 | **3. Train on your own data** 72 | Organize your data in the same structure as `./data/train/example_03375_shortlong`: a `.ply` file containing a T-posed SMPL body mesh and a folder containing the training frames. 73 | Each frame corresponds to two files: one `.npz` file containing SMPL parameters that describe the pose of the body (only 'transl' and 'pose' matter) and one `.ply` file containing the clothed scan. The body should align with the scan. 74 | Then change `./configs/example.yaml` to point to your data directory and you are good to go!
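For reference, the snippet below sketches how one such frame could be written out with `numpy` and `trimesh` (both already used in this repository). It is only an illustration: the frame-folder name, the helper function, and the example call are hypothetical and not part of SCANimate, and the assumption that `pose` holds 72 axis-angle values (24 joints × 3) follows the standard SMPL convention. Mirror the actual layout of `./data/train/example_03375_shortlong` and the paths set in `./configs/example.yaml`.

```python
# Hypothetical helper for preparing your own training frames (not part of the SCANimate code base).
import os
import numpy as np
import trimesh

def export_frame(frames_dir, frame_name, scan_ply_path, pose, transl):
    """Write <frame_name>.npz (SMPL 'pose' and 'transl') and <frame_name>.ply (clothed scan)."""
    os.makedirs(frames_dir, exist_ok=True)
    pose = np.asarray(pose, dtype=np.float32).reshape(-1)     # assumed: 72 axis-angle values (24 joints x 3)
    transl = np.asarray(transl, dtype=np.float32).reshape(3)  # global translation of the body
    # Only 'transl' and 'pose' matter for training; other SMPL parameters can be omitted.
    np.savez(os.path.join(frames_dir, frame_name + '.npz'), pose=pose, transl=transl)
    # The clothed scan must already be aligned with the SMPL body posed by the parameters above.
    scan = trimesh.load(scan_ply_path, process=False)
    scan.export(os.path.join(frames_dir, frame_name + '.ply'))

# Example call with placeholder paths:
# export_frame('./data/train/my_subject/frames', 'frame_0000',
#              '/path/to/aligned_scan_0000.ply', pose=np.zeros(72), transl=np.zeros(3))
```

The per-subject T-posed SMPL body mesh is stored alongside this folder as a single `.ply` file, as in the example subject.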
75 | 76 | ## License 77 | Software Copyright License for non-commercial scientific research purposes. Please read carefully the terms and conditions and any accompanying documentation before you download and/or use the SCANimate code, including the scripts, animation demos, pre-trained neural network model parameters and the pre-trained animatable avatars ("Scanimats"). By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this GitHub repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License. 78 | 79 | The provided demo data (including the body pose, raw scans and mesh registrations of clothed human bodies) are subject to the license of the [CAPE Dataset](https://cape.is.tue.mpg.de/). 80 | 81 | ## Related Research 82 | **[SCALE: Modeling Clothed Humans with a Surface Codec of Articulated Local Elements (CVPR 2021)](https://qianlim.github.io/SCALE)**
83 | *Qianli Ma, Shunsuke Saito, Jinlong Yang, Siyu Tang, Michael J. Black* 84 | 85 | Modeling pose-dependent shapes of clothed humans *explicitly* with hundreds of articulated surface elements: the clothing deforms naturally even in the presence of topological change! 86 | 87 | **[Learning to Dress 3D People in Generative Clothing (CVPR 2020)](https://cape.is.tue.mpg.de/)**
88 | *Qianli Ma, Jinlong Yang, Anurag Ranjan, Sergi Pujades, Gerard Pons-Moll, Siyu Tang, Michael J. Black* 89 | 90 | CAPE --- a generative model and a large-scale dataset for 3D clothed human meshes in varied poses and garment types. 91 | We trained SCANimate using the [CAPE dataset](https://cape.is.tue.mpg.de/dataset), check it out! 92 | 93 | 94 | ## Citations 95 | If you find our code or paper useful to your research, please consider citing: 96 | 97 | ```bibtex 98 | @inproceedings{Saito:CVPR:2021, 99 | title = {{SCANimate}: Weakly Supervised Learning of Skinned Clothed Avatar Networks}, 100 | author = {Saito, Shunsuke and Yang, Jinlong and Ma, Qianli and Black, Michael J.}, 101 | booktitle = {Proceedings IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)}, 102 | month = jun, 103 | year = {2021}, 104 | month_numeric = {6}} 105 | ``` 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /apps/test_color.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from lib.data.CapeDataset import CapeDataset_scan_color 18 | import smpl 19 | from torch.utils.data import DataLoader 20 | from lib.config import load_config 21 | from lib.model.IGRSDFNet import IGRSDFNet 22 | from lib.model.LBSNet import LBSNet 23 | from lib.model.TNet import TNet 24 | 25 | import argparse 26 | import torch 27 | import os 28 | import json 29 | import numpy as np 30 | 31 | from lib.net_util import batch_rod2quat,homogenize, load_network, get_posemap 32 | import torch.nn as nn 33 | import math 34 | from lib.mesh_util import replace_hands_feet_wcolor 35 | from lib.mesh_util import reconstruction, save_obj_mesh, save_obj_mesh_with_color, scalar_to_color 36 | import time 37 | import trimesh 38 | from tqdm import tqdm 39 | from apps.train_color import gen_color_mesh 40 | 41 | def test(opt, test_input_dir): 42 | cuda = torch.device('cuda:0') 43 | 44 | tmp_dirs = test_input_dir.split('/') 45 | 46 | test_input_basedir = tmp_dirs.pop() 47 | while test_input_basedir == '': 48 | test_input_basedir = tmp_dirs.pop() 49 | opt['data']['test_dir'] = test_input_dir 50 | 51 | exp_name = opt['experiment']['name'] 52 | ckpt_dir = '%s/%s' % (opt['experiment']['ckpt_dir'], exp_name) 53 | result_dir = '%s/%s_test_%s_color' % (opt['experiment']['result_dir']+'_color', exp_name, test_input_basedir) 54 | os.makedirs(result_dir, exist_ok=True) 55 | 56 | # load checkpoints 57 | ckpt_dict = None 58 | print("Loading checkpoint from " + ckpt_dir) 59 | if os.path.isfile(os.path.join(ckpt_dir, 'ckpt_color_latest.pt')): 60 | print('loading ckpt...', os.path.join(ckpt_dir, 'ckpt_color_latest.pt')) 61 | ckpt_dict = torch.load(os.path.join(ckpt_dir, 'ckpt_color_latest.pt')) 62 | else: 63 | print('error: ckpt does not 
exist [%s]' % opt['experiment']['ckpt_file']) 64 | exit() 65 | 66 | model = smpl.create(opt['data']['smpl_dir'], model_type='smpl_vitruvian', 67 | gender=opt['data']['smpl_gender'], use_face_contour=False, 68 | ext='npz').to(cuda) 69 | 70 | 71 | tmp_dir = opt['data']['data_dir'] 72 | tmp_dir_files = sorted([f for f in os.listdir(tmp_dir) if '.ply' in f]) 73 | customized_minimal_ply = os.path.join(tmp_dir, tmp_dir_files[0]) 74 | test_dataset = CapeDataset_scan_color(opt['data'], phase='test', smpl=model, 75 | customized_minimal_ply=customized_minimal_ply, 76 | full_test = True) 77 | 78 | # reference_body_vs_train = train_dataset.subjects_minimal_v 79 | reference_body_vs_test = test_dataset.Tpose_minimal_v 80 | smpl_vitruvian = model.initiate_vitruvian(device = cuda, body_neutral_v = test_dataset.Tpose_minimal_v) 81 | 82 | 83 | test_data_loader = DataLoader(test_dataset, 84 | batch_size=1, shuffle=False, 85 | num_workers=0, pin_memory=False) 86 | 87 | 88 | # for now, all the hand, face joints are combined with body joints for smplx 89 | gt_lbs_smpl = model.lbs_weights[:,:24].clone() 90 | root_idx = model.parents.cpu().numpy() 91 | idx_list = list(range(root_idx.shape[0])) 92 | for i in range(root_idx.shape[0]): 93 | if i > 23: 94 | root = idx_list[root_idx[i]] 95 | gt_lbs_smpl[:,root] += model.lbs_weights[:,i] 96 | idx_list[i] = root 97 | gt_lbs_smpl = gt_lbs_smpl[None].permute(0,2,1) 98 | 99 | smpl_vitruvian = model.initiate_vitruvian(device = cuda, body_neutral_v = test_dataset.Tpose_minimal_v) 100 | 101 | # define bounding box 102 | bbox_smpl = (smpl_vitruvian[0].cpu().numpy().min(0).astype(np.float32), smpl_vitruvian[0].cpu().numpy().max(0).astype(np.float32)) 103 | bbox_center, bbox_size = 0.5 * (bbox_smpl[0] + bbox_smpl[1]), (bbox_smpl[1] - bbox_smpl[0]) 104 | bbox_min = np.stack([bbox_center[0]-0.55*bbox_size[0],bbox_center[1]-0.6*bbox_size[1],bbox_center[2]-1.5*bbox_size[2]], 0).astype(np.float32) 105 | bbox_max = np.stack([bbox_center[0]+0.55*bbox_size[0],bbox_center[1]+0.6*bbox_size[1],bbox_center[2]+1.5*bbox_size[2]], 0).astype(np.float32) 106 | 107 | pose_map = get_posemap(opt['model']['posemap_type'], 24, model.parents, opt['model']['n_traverse'], opt['model']['normalize_posemap']) 108 | 109 | igr_net = IGRSDFNet(opt['model']['igr_net'], bbox_min, bbox_max, pose_map).to(cuda) 110 | fwd_skin_net = LBSNet(opt['model']['fwd_skin_net'], bbox_min, bbox_max, posed=False).to(cuda) 111 | texture_net = TNet(opt['model']['igr_net']).to(cuda) 112 | 113 | lat_vecs_igr = nn.Embedding(1, opt['model']['igr_net']['g_dim']).to(cuda) 114 | 115 | if opt['model']['igr_net']['g_dim'] > 0: 116 | torch.nn.init.constant_(lat_vecs_igr.weight.data, 0.0) 117 | 118 | print(igr_net) 119 | print(fwd_skin_net) 120 | print(texture_net) 121 | 122 | if ckpt_dict is not None: 123 | if 'igr_net' in ckpt_dict: 124 | load_network(igr_net, ckpt_dict['igr_net']) 125 | else: 126 | print("Couldn't find igr_net in checkpoints!") 127 | 128 | if 'fwd_skin_net' in ckpt_dict: 129 | load_network(fwd_skin_net, ckpt_dict['fwd_skin_net']) 130 | else: 131 | print("Couldn't find fwd_skin_net in checkpoints!") 132 | 133 | if 'lat_vecs_igr'in ckpt_dict: 134 | load_network(lat_vecs_igr, ckpt_dict['lat_vecs_igr']) 135 | else: 136 | print("Couldn't find lat_vecs_igr in checkpoints!") 137 | 138 | if 'texture_net'in ckpt_dict: 139 | load_network(texture_net, ckpt_dict['texture_net']) 140 | else: 141 | print("Couldn't find texture_net in checkpoints!") 142 | 143 | print('test data size: ', len(test_data_loader)) 144 | 145 | # Test color 
module 146 | print('start test inference') 147 | igr_net.set_lbsnet(fwd_skin_net) 148 | gen_color_mesh(opt, result_dir, igr_net, fwd_skin_net, lat_vecs_igr, texture_net, model, smpl_vitruvian, test_data_loader, cuda, 149 | reference_body_v=test_data_loader.dataset.Tpose_minimal_v) 150 | 151 | with open(os.path.join(result_dir, '../', exp_name+'_'+test_input_basedir+'.txt'), 'w') as finish_file: 152 | finish_file.write('Done!') 153 | 154 | def testWrapper(args=None): 155 | parser = argparse.ArgumentParser( 156 | description='Test SCANimate color.' 157 | ) 158 | parser.add_argument('--config', '-c', type=str, help='Path to config file.') 159 | parser.add_argument('--test_dir', '-t', type=str, required=True, help='Path to test directory') 160 | args = parser.parse_args() 161 | 162 | opt = load_config(args.config, 'configs/default.yaml') 163 | 164 | test(opt, args.test_dir) 165 | 166 | if __name__ == '__main__': 167 | testWrapper() -------------------------------------------------------------------------------- /apps/test_scanimate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import os 18 | import argparse 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | from torch.utils.data import DataLoader 23 | 24 | import smpl 25 | from lib.config import load_config 26 | from lib.net_util import load_network, get_posemap 27 | from lib.model.IGRSDFNet import IGRSDFNet 28 | from lib.model.LBSNet import LBSNet 29 | from lib.data.CapeDataset import CapeDataset_scan 30 | 31 | import math 32 | 33 | from apps.train_scanimate import gen_mesh2 34 | 35 | import logging 36 | logging.basicConfig(level=logging.DEBUG) 37 | 38 | def test(opt, test_input_dir): 39 | cuda = torch.device('cuda:0') 40 | 41 | tmp_dirs = test_input_dir.split('/') 42 | 43 | test_input_basedir = tmp_dirs.pop() 44 | while test_input_basedir == '': 45 | test_input_basedir = tmp_dirs.pop() 46 | opt['data']['test_dir'] = test_input_dir 47 | 48 | exp_name = opt['experiment']['name'] 49 | ckpt_dir = '%s/%s' % (opt['experiment']['ckpt_dir'], exp_name) 50 | result_dir = '%s_test/%s_test_%s' % (opt['experiment']['result_dir'], exp_name, test_input_basedir) 51 | os.makedirs(result_dir, exist_ok=True) 52 | 53 | model = smpl.create(opt['data']['smpl_dir'], model_type='smpl_vitruvian', 54 | gender=opt['data']['smpl_gender'], use_face_contour=False, 55 | ext='npz').to(cuda) 56 | 57 | tmp_dir = opt['data']['data_dir'] 58 | tmp_dir_files = sorted([f for f in os.listdir(tmp_dir) if '.ply' in f]) 59 | customized_minimal_ply = os.path.join(tmp_dir, tmp_dir_files[0]) 60 | test_dataset = CapeDataset_scan(opt['data'], phase='test', smpl=model, 61 | customized_minimal_ply=customized_minimal_ply, full_test = True, device=cuda) 62 | 63 | reference_body_vs_test = test_dataset.Tpose_minimal_v 64 | smpl_vitruvian = model.initiate_vitruvian(device = cuda, body_neutral_v = test_dataset.Tpose_minimal_v) 65 | 66 | test_data_loader = DataLoader(test_dataset, 67 | batch_size=1, shuffle=False, 68 | num_workers=0, pin_memory=False) 69 | 70 | 71 | # for now, all the hand, face joints are combined with body joints for smplx 72 | gt_lbs_smpl = model.lbs_weights[:,:24].clone() 73 | root_idx = model.parents.cpu().numpy() 74 | idx_list = list(range(root_idx.shape[0])) 75 | for i in range(root_idx.shape[0]): 76 | if i > 23: 77 | root = idx_list[root_idx[i]] 78 | gt_lbs_smpl[:,root] += model.lbs_weights[:,i] 79 | idx_list[i] = root 80 | gt_lbs_smpl = gt_lbs_smpl[None].permute(0,2,1) 81 | 82 | betas = torch.zeros([1, 10], dtype=torch.float32, device=cuda) 83 | body_pose = torch.zeros([1, 69], dtype=torch.float32, device=cuda) 84 | body_pose[:,2] = math.radians(30) # for vitruvian pose 85 | body_pose[:,5] = math.radians(-30) # for vitruvian pose 86 | global_orient = torch.zeros((1, 3), dtype=torch.float32, device=cuda) 87 | transl = torch.zeros((1, 3), dtype=torch.float32, device=cuda) 88 | 89 | # define bounding box 90 | bbox_smpl = (smpl_vitruvian[0].cpu().numpy().min(0).astype(np.float32), smpl_vitruvian[0].cpu().numpy().max(0).astype(np.float32)) 91 | bbox_center, bbox_size = 0.5 * (bbox_smpl[0] + bbox_smpl[1]), (bbox_smpl[1] - bbox_smpl[0]) 92 | bbox_min = np.stack([bbox_center[0]-0.55*bbox_size[0],bbox_center[1]-0.6*bbox_size[1],bbox_center[2]-1.5*bbox_size[2]], 0).astype(np.float32) 93 | bbox_max = np.stack([bbox_center[0]+0.55*bbox_size[0],bbox_center[1]+0.6*bbox_size[1],bbox_center[2]+1.5*bbox_size[2]], 0).astype(np.float32) 94 | 95 | pose_map = get_posemap(opt['model']['posemap_type'], 24, model.parents, opt['model']['n_traverse'], 
opt['model']['normalize_posemap']) 96 | 97 | igr_net = IGRSDFNet(opt['model']['igr_net'], bbox_min, bbox_max, pose_map).to(cuda) 98 | fwd_skin_net = LBSNet(opt['model']['fwd_skin_net'], bbox_min, bbox_max, posed=False).to(cuda) 99 | inv_skin_net = LBSNet(opt['model']['inv_skin_net'], bbox_min, bbox_max, posed=True).to(cuda) 100 | 101 | lat_vecs_igr = nn.Embedding(1, opt['model']['igr_net']['g_dim']).to(cuda) 102 | 103 | if opt['model']['igr_net']['g_dim'] > 0: 104 | torch.nn.init.constant_(lat_vecs_igr.weight.data, 0.0) 105 | 106 | print(igr_net) 107 | print(fwd_skin_net) 108 | print(inv_skin_net) 109 | 110 | # load checkpoints 111 | ckpt_dict = None 112 | logging.info("Loading checkpoint from %s" % ckpt_dir) 113 | if os.path.isfile(os.path.join(ckpt_dir, 'ckpt_latest.pt')): 114 | logging.info('loading ckpt [%s]'%os.path.join(ckpt_dir, 'ckpt_latest.pt')) 115 | ckpt_dict = torch.load(os.path.join(ckpt_dir, 'ckpt_latest.pt')) 116 | else: 117 | logging.error('error: ckpt does not exist [%s]' % opt['experiment']['ckpt_file']) 118 | exit() 119 | 120 | if ckpt_dict is not None: 121 | if 'igr_net' in ckpt_dict: 122 | load_network(igr_net, ckpt_dict['igr_net']) 123 | else: 124 | print("Couldn't find igr_net in checkpoints!") 125 | 126 | if 'fwd_skin_net' in ckpt_dict: 127 | load_network(fwd_skin_net, ckpt_dict['fwd_skin_net']) 128 | else: 129 | print("Couldn't find fwd_skin_net in checkpoints!") 130 | 131 | if 'lat_vecs_igr'in ckpt_dict: 132 | load_network(lat_vecs_igr, ckpt_dict['lat_vecs_igr']) 133 | else: 134 | print("Couldn't find lat_vecs_igr in checkpoints!") 135 | 136 | else: 137 | logging.error("No checkpoint!") 138 | exit() 139 | 140 | logging.info('test data size: %d'%len(test_data_loader)) 141 | 142 | logging.info('Start test inference') 143 | igr_net.set_lbsnet(fwd_skin_net) 144 | 145 | gen_mesh2(opt, result_dir, igr_net, fwd_skin_net, lat_vecs_igr, model, smpl_vitruvian, test_data_loader, cuda, 146 | reference_body_v=test_data_loader.dataset.Tpose_minimal_v) 147 | 148 | with open(os.path.join(result_dir, '../', exp_name+'_'+test_input_basedir+'.txt'), 'w') as finish_file: 149 | finish_file.write('Done!') 150 | 151 | 152 | def testWrapper(args=None): 153 | parser = argparse.ArgumentParser( 154 | description='Test SCANimate.' 155 | ) 156 | parser.add_argument('--config', '-c', type=str, help='Path to config file.') 157 | parser.add_argument('--test_dir', '-t', type=str, \ 158 | required=True,\ 159 | help='Path to test directory') 160 | args = parser.parse_args() 161 | 162 | opt = load_config(args.config, 'configs/default.yaml') 163 | 164 | test(opt, args.test_dir) 165 | 166 | if __name__ == '__main__': 167 | testWrapper() -------------------------------------------------------------------------------- /apps/train_color.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from lib.data.CapeDataset import CapeDataset_scan_color 18 | import smpl 19 | from torch.utils.data import DataLoader 20 | from lib.config import load_config 21 | from lib.model.IGRSDFNet import IGRSDFNet 22 | from lib.model.LBSNet import LBSNet 23 | from lib.model.TNet import TNet 24 | 25 | from lib.geo_util import compute_normal_v 26 | 27 | import argparse 28 | import torch 29 | import os 30 | import json 31 | import numpy as np 32 | 33 | from lib.net_util import batch_rod2quat,homogenize, load_network, get_posemap 34 | import torch.nn as nn 35 | import math 36 | from lib.mesh_util import replace_hands_feet_wcolor 37 | from lib.mesh_util import reconstruction, save_obj_mesh, save_obj_mesh_with_color, scalar_to_color 38 | import time 39 | import trimesh 40 | from tqdm import tqdm 41 | 42 | def gen_train_color_mesh(opt, result_dir, fwd_skin_net, inv_skin_net, lat_vecs_inv_skin, model, smpl_vitruvian, train_data_loader, cuda, name='', reference_body_v=None): 43 | dataset = train_data_loader.dataset 44 | smpl_face = torch.LongTensor(model.faces[:,[0,2,1]].astype(np.int32))[None].to(cuda) 45 | 46 | def process(data, idx=0): 47 | frame_names = data['frame_name'] 48 | scan_color = data['colors'].to(device=cuda) 49 | betas = data['betas'][None].to(device=cuda) 50 | body_pose = data['body_pose'][None].to(device=cuda) 51 | scan_posed = data['scan_posed'][None].to(device=cuda) 52 | original_colors = data['original_colors'] 53 | transl = data['transl'][None].to(device=cuda) 54 | f_ids = torch.LongTensor([data['f_id']]).to(device=cuda) 55 | smpl_data = data['smpl_data'] 56 | faces = data['faces'].numpy() 57 | global_orient = body_pose[:,:3] 58 | body_pose = body_pose[:,3:] 59 | 60 | if not reference_body_v == None: 61 | output = model(betas=betas, body_pose=body_pose, global_orient=0*global_orient, transl=0*transl, return_verts=True, custom_out=True, 62 | body_neutral_v = reference_body_v.expand(body_pose.shape[0], -1, -1)) 63 | else: 64 | output = model(betas=betas, body_pose=body_pose, global_orient=0*global_orient, transl=0*transl, return_verts=True, custom_out=True) 65 | smpl_posed_joints = output.joints 66 | rootT = model.get_root_T(global_orient, transl, smpl_posed_joints[:,0:1,:]) 67 | 68 | smpl_neutral = output.v_shaped 69 | smpl_cano = output.v_posed 70 | smpl_posed = output.vertices.contiguous() 71 | bmax = smpl_posed.max(1)[0] 72 | bmin = smpl_posed.min(1)[0] 73 | offset = 0.2*(bmax - bmin) 74 | bmax += offset 75 | bmin -= offset 76 | jT = output.joint_transform[:,:24] 77 | smpl_n_posed = compute_normal_v(smpl_posed, smpl_face.expand(smpl_posed.shape[0],-1,-1)) 78 | scan_posed = torch.einsum('bst,bvt->bsv', torch.inverse(rootT), homogenize(scan_posed))[:,:3,:] # remove root transform 79 | 80 | if inv_skin_net.opt['g_dim'] > 0: 81 | lat = lat_vecs_inv_skin(f_ids) # (B, Z) 82 | inv_skin_net.set_global_feat(lat) 83 | feat3d_posed = None 84 | res_scan_p = inv_skin_net(feat3d_posed, scan_posed, jT=jT, bmin=bmin[:,:,None], bmax=bmax[:,:,None]) 85 | pred_scan_cano = res_scan_p['pred_smpl_cano'].permute(0,2,1) 86 | 87 | # res_smpl_p = inv_skin_net(feat3d_posed, smpl_posed.permute(0,2,1), jT=jT, bmin=bmin[:,:,None], bmax=bmax[:,:,None]) 88 | # pred_smpl_cano = res_smpl_p['pred_smpl_cano'].permute(0,2,1) 89 | # save_obj_mesh('%s/pred_smpl_cano%s%s.obj' % (result_dir, str(idx).zfill(4), name), pred_smpl_cano[0].cpu().numpy(), model.faces[:,[0,2,1]]) 90 | if name=='_pt3': 91 | scan_faces, scan_mask = 
dataset.get_raw_scan_face_and_mask(frame_id = f_ids[0].cpu().numpy()) 92 | valid_scan_faces = scan_faces[scan_mask,:] 93 | pred_scan_cano_mesh = trimesh.Trimesh(vertices = pred_scan_cano[0].cpu().numpy(), faces = valid_scan_faces[:,[0,2,1]], vertex_colors = original_colors, process=False) 94 | save_obj_mesh_with_color('%s/%s_scan_cano_%s.obj' % (result_dir, frame_names, str(idx).zfill(4)), pred_scan_cano_mesh.vertices, pred_scan_cano_mesh.faces, original_colors) 95 | 96 | feat3d_cano = None 97 | pred_scan_reposed = fwd_skin_net(feat3d_cano, pred_scan_cano.permute(0,2,1), jT=jT)['pred_smpl_posed'].permute(0,2,1) 98 | save_obj_mesh('%s/%s_pred_scan_reposed_%s%s.obj' % (result_dir, frame_names, str(idx).zfill(4), name), pred_scan_reposed[0].cpu().numpy(), faces) 99 | 100 | if True: 101 | with torch.no_grad(): 102 | print("Output canonicalized train meshes...") 103 | for i in tqdm(range(len(dataset))): 104 | if not i % 5 == 0: 105 | continue 106 | data = dataset[i] 107 | process(data, i) 108 | 109 | 110 | def gen_color_mesh(opt, result_dir, igr_net, fwd_skin_net, lat_vecs_igr, texture_net, model, smpl_vitruvian, test_data_loader, cuda, reference_body_v=None, largest_component=False): 111 | bbox_min = igr_net.bbox_min.squeeze().cpu().numpy() 112 | bbox_max = igr_net.bbox_max.squeeze().cpu().numpy() 113 | 114 | with torch.no_grad(): 115 | torch.cuda.empty_cache() 116 | for test_idx, test_data in enumerate(tqdm(test_data_loader)): 117 | frame_names = test_data['frame_name'] 118 | betas = test_data['betas'].to(device=cuda) 119 | body_pose = test_data['body_pose'].to(device=cuda) 120 | sub_ids = test_data['sub_id'].to(device=cuda) 121 | transl = test_data['transl'].to(device=cuda) 122 | global_orient = body_pose[:,:3] 123 | body_pose = body_pose[:,3:] 124 | if not reference_body_v == None: 125 | output = model(betas=betas, body_pose=body_pose, global_orient=0*global_orient, transl=0*transl, return_verts=True, custom_out=True, 126 | body_neutral_v = reference_body_v.expand(body_pose.shape[0], -1, -1)) 127 | else: 128 | output = model(betas=betas, body_pose=body_pose, global_orient=0*global_orient, transl=0*transl, return_verts=True, custom_out=True) 129 | # smpl_posed_joints = output.joints 130 | # rootT = model.get_root_T(global_orient, transl, smpl_posed_joints[:,0:1,:]) 131 | 132 | smpl_neutral = output.v_shaped 133 | jT = output.joint_transform[:,:24] 134 | 135 | if igr_net.opt['g_dim'] > 0: 136 | lat = lat_vecs_igr(sub_ids) # (B, Z) 137 | igr_net.set_global_feat(lat) 138 | 139 | set_pose_feat = batch_rod2quat(body_pose.reshape(-1, 3)).view(betas.shape[0], -1, 4) 140 | igr_net.set_pose_feat(set_pose_feat) 141 | 142 | verts, faces, _, _, vcolors = reconstruction(igr_net, cuda, torch.eye(4)[None].to(cuda), opt['experiment']['vol_res'],\ 143 | bbox_min, bbox_max, use_octree=True, thresh=0.0, 144 | texture_net = texture_net) 145 | # save_obj_mesh_with_color('%s/%s_cano%s.obj' % (result_dir, frame_names[0], str(test_idx).zfill(4)), verts, faces, vcolors) 146 | 147 | verts_torch = torch.Tensor(verts)[None].to(cuda) 148 | feat3d = None 149 | res = fwd_skin_net(feat3d, verts_torch.permute(0,2,1), jT=jT) 150 | pred_lbs = res['pred_lbs_smpl_cano'].permute(0,2,1) 151 | 152 | pred_scan_posed = res['pred_smpl_posed'].permute(0,2,1) 153 | rootT = test_data['rootT'].cuda() 154 | pred_scan_posed = torch.einsum('bst,bvt->bvs', rootT, homogenize(pred_scan_posed))[0,:,:3] 155 | pred_scan_posed = pred_scan_posed.cpu().numpy() 156 | save_obj_mesh_with_color('%s/%s_posed%s.obj' % (result_dir, frame_names[0], 
str(test_idx).zfill(4)), pred_scan_posed, faces, vcolors) 157 | 158 | 159 | def train_color(opt, ckpt_dir, result_dir, texture_net, igr_net, fwd_skin_net, inv_skin_net, lat_vecs_igr, lat_vecs_inv_skin, 160 | model, smpl_vitruvian, gt_lbs_smpl, train_data_loader, test_data_loader, cuda, reference_body_v=None): 161 | 162 | fwd_skin_net.eval() 163 | igr_net.eval() 164 | inv_skin_net.eval() 165 | igr_net.set_lbsnet(fwd_skin_net) 166 | 167 | smpl_face = torch.LongTensor(model.faces[:,[0,2,1]].astype(np.int32))[None].to(cuda) 168 | 169 | optimizer = torch.optim.Adam([{ 170 | "params": texture_net.parameters(), 171 | "lr": opt['training']['lr_sdf']}]) 172 | 173 | n_iter = 0 174 | max_train_idx = 0 175 | start_time = time.time() 176 | current_number_processed_samples = 0 177 | 178 | train_data_loader.dataset.resample_flag = True 179 | 180 | opt['training']['num_epoch_sdf'] = opt['training']['num_epoch_sdf']//4 181 | 182 | for epoch in range(opt['training']['num_epoch_sdf']): 183 | texture_net.train() 184 | if epoch == opt['training']['num_epoch_sdf']//2 or epoch == 3*(opt['training']['num_epoch_sdf']//4): 185 | for j, _ in enumerate(optimizer.param_groups): 186 | optimizer.param_groups[j]['lr'] *= 0.1 187 | for train_idx, train_data in enumerate(train_data_loader): 188 | betas = train_data['betas'].to(device=cuda) 189 | body_pose = train_data['body_pose'].to(device=cuda) 190 | sub_ids = train_data['sub_id'].to(device=cuda) 191 | transl = train_data['transl'].to(device=cuda) 192 | f_ids = train_data['f_id'].to(device=cuda) 193 | smpl_data = train_data['smpl_data'] 194 | 195 | scan_v_posed = train_data['scan_cano_uni'].to(device=cuda) 196 | scan_n_posed = train_data['normals_uni'].to(device=cuda) 197 | scan_color = train_data['colors'].to(device=cuda) 198 | scan_color = scan_color.permute(0,2,1) 199 | 200 | global_orient = body_pose[:,:3] 201 | body_pose = body_pose[:,3:] 202 | 203 | smpl_neutral = smpl_data['smpl_neutral'].cuda() 204 | smpl_cano = smpl_data['smpl_cano'].cuda() 205 | smpl_posed = smpl_data['smpl_posed'].cuda() 206 | smpl_n_posed = smpl_data['smpl_n_posed'].cuda() 207 | bmax = smpl_data['bmax'].cuda() 208 | bmin = smpl_data['bmin'].cuda() 209 | jT = smpl_data['jT'].cuda() 210 | inv_rootT = smpl_data['inv_rootT'].cuda() 211 | 212 | with torch.no_grad(): 213 | scan_v_posed = torch.einsum('bst,bvt->bsv', inv_rootT, homogenize(scan_v_posed))[:,:3,:] # remove root transform 214 | scan_n_posed = torch.einsum('bst,bvt->bsv', inv_rootT[:,:3,:3], scan_n_posed) 215 | 216 | if opt['model']['inv_skin_net']['g_dim'] > 0: 217 | lat = lat_vecs_inv_skin(f_ids) # (B, Z) 218 | inv_skin_net.set_global_feat(lat) 219 | 220 | feat3d_posed = None 221 | res_lbs_p, _, _ = inv_skin_net(feat3d_posed, smpl_posed.permute(0,2,1), gt_lbs_smpl, scan_v_posed, jT=jT, nml_scan=scan_n_posed, bmin=bmin[:,:,None], bmax=bmax[:,:,None]) 222 | scan_cano = res_lbs_p['pred_scan_cano'] 223 | normal_cano = res_lbs_p['normal_scan_cano'] 224 | 225 | if opt['model']['igr_net']['g_dim'] > 0: 226 | lat = lat_vecs_igr(sub_ids) # (B, Z) 227 | # print("subid", sub_ids) 228 | igr_net.set_global_feat(lat) 229 | 230 | smpl_neutral = smpl_neutral.permute(0,2,1) 231 | 232 | set_pose_feat = batch_rod2quat(body_pose.reshape(-1, 3)).view(betas.shape[0], -1, 4) 233 | igr_net.set_pose_feat(set_pose_feat) 234 | 235 | pts0 = scan_cano[0].detach().cpu().permute(1,0).numpy() 236 | clr0 = scan_color[0].detach().cpu().permute(1,0).numpy() 237 | clr0 = train_data['colors'][0,:,:].detach().cpu().numpy() 238 | 239 | scan_cano, scan_color = 
replace_hands_feet_wcolor(scan_cano, 240 | scan_color, smpl_neutral, 241 | opt['data']['num_sample_surf'], 242 | vitruvian_angle = model.vitruvian_angle) 243 | 244 | sdf, last_layer_feature, point_local_feat = igr_net.query(scan_cano, return_last_layer_feature=True) 245 | 246 | 247 | err, err_dict = texture_net(point_local_feat, last_layer_feature, scan_color) 248 | 249 | err_dict['All'] = err.item() 250 | 251 | optimizer.zero_grad() 252 | err.backward() 253 | optimizer.step() 254 | 255 | if n_iter % opt['training']['freq_plot'] == 0: 256 | err_txt = ''.join(['{}: {:.3f} '.format(k, v) for k,v in err_dict.items()]) 257 | time_now = time.time() 258 | duration = time_now-start_time 259 | current_number_processed_samples += f_ids.shape[0] 260 | persample_process_time = duration/current_number_processed_samples 261 | current_number_processed_samples = -f_ids.shape[0] 262 | print('[%03d/%03d]:[%04d/%04d] %02f FPS, %s' % (epoch, opt['training']['num_epoch_sdf'], 263 | train_idx, len(train_data_loader), 1.0/persample_process_time, err_txt)) 264 | start_time = time.time() 265 | 266 | if (n_iter+1) % 200 == 0 or (epoch == opt['training']['num_epoch_sdf']-1 and train_idx == max_train_idx): 267 | ckpt_dict = { 268 | 'opt': opt, 269 | 'epoch': epoch, 270 | 'iter': n_iter, 271 | 'igr_net': igr_net.state_dict(), 272 | 'fwd_skin_net': fwd_skin_net.state_dict(), 273 | 'lat_vecs_igr': lat_vecs_igr.state_dict(), 274 | 'lat_vecs_inv_skin': lat_vecs_inv_skin.state_dict(), 275 | 'texture_net': texture_net.state_dict(), 276 | 'optimizer': optimizer.state_dict() 277 | } 278 | torch.save(ckpt_dict, '%s/ckpt_color_latest.pt' % ckpt_dir) 279 | if (n_iter+1) % 1000 == 0: 280 | torch.save(ckpt_dict, '%s/ckpt_color_epoch%d.pt' % (ckpt_dir, epoch)) 281 | 282 | if n_iter == 0: 283 | train_data_loader.dataset.is_train = False 284 | texture_net.eval() 285 | gen_train_color_mesh(opt, result_dir, fwd_skin_net, inv_skin_net, lat_vecs_inv_skin, model, smpl_vitruvian, train_data_loader, cuda, '_pt3', reference_body_v=train_data_loader.dataset.Tpose_minimal_v) 286 | train_data_loader.dataset.is_train = True 287 | 288 | if (n_iter+1) % opt['training']['freq_mesh'] == 0 or (epoch == opt['training']['num_epoch_sdf']-1 and train_idx == max_train_idx): 289 | texture_net.eval() 290 | gen_color_mesh(opt, result_dir, igr_net, fwd_skin_net, lat_vecs_igr, texture_net, model, smpl_vitruvian, test_data_loader, cuda, 291 | reference_body_v=test_data_loader.dataset.Tpose_minimal_v) 292 | 293 | if max_train_idx < train_idx: 294 | max_train_idx = train_idx 295 | n_iter += 1 296 | current_number_processed_samples += f_ids.shape[0] 297 | 298 | def train(opt): 299 | cuda = torch.device('cuda:0') 300 | 301 | exp_name = opt['experiment']['name'] 302 | ckpt_dir = '%s/%s' % (opt['experiment']['ckpt_dir'], exp_name) 303 | result_dir = '%s/%s' % (opt['experiment']['result_dir']+'_color', exp_name+'_color') 304 | log_dir = '%s/%s' % (opt['experiment']['log_dir'], exp_name) 305 | 306 | os.makedirs(ckpt_dir, exist_ok=True) 307 | os.makedirs(result_dir, exist_ok=True) 308 | os.makedirs(log_dir, exist_ok=True) 309 | 310 | # Backup config into log_dir 311 | with open(os.path.join(log_dir, 'config.json'), 'w') as config_file: 312 | config_file.write(json.dumps(opt)) 313 | 314 | # load checkpoints 315 | ckpt_dict = None 316 | if opt['experiment']['ckpt_file'] is not None: 317 | if os.path.isfile(opt['experiment']['ckpt_file']): 318 | print('loading for ckpt...', opt['experiment']['ckpt_file']) 319 | ckpt_dict = torch.load(opt['experiment']['ckpt_file']) 320 | 
else: 321 | print('error: ckpt does not exist [%s]' % opt['experiment']['ckpt_file']) 322 | elif opt['training']['continue_train']: 323 | # if opt['training']['resume_epoch'] < 0: 324 | model_path = '%s/ckpt_latest.pt' % ckpt_dir 325 | # else: 326 | # model_path = '%s/ckpt_epoch_%d.pt' % (ckpt_dir, opt['training']['resume_epoch']) 327 | if os.path.isfile(model_path): 328 | print('Resuming from ', model_path) 329 | ckpt_dict = torch.load(model_path) 330 | else: 331 | print('error: ckpt does not exist [%s]' % model_path) 332 | elif opt['training']['use_pretrain']: 333 | model_path = '%s/ckpt_pretrain.pt' % ckpt_dir 334 | if os.path.isfile(model_path): 335 | print('Resuming from ', model_path) 336 | ckpt_dict = torch.load(model_path) 337 | print('Pretrained model loaded.') 338 | else: 339 | print('error: ckpt does not exist [%s]' % model_path) 340 | 341 | 342 | model = smpl.create(opt['data']['smpl_dir'], model_type='smpl_vitruvian', 343 | gender=opt['data']['smpl_gender'], use_face_contour=False, 344 | ext='npz').to(cuda) 345 | 346 | 347 | train_dataset = CapeDataset_scan_color(opt['data'], phase='train', smpl=model) 348 | test_dataset = CapeDataset_scan_color(opt['data'], phase='test', smpl=model, full_test = True) 349 | 350 | reference_body_vs_train = train_dataset.Tpose_minimal_v 351 | reference_body_vs_test = test_dataset.Tpose_minimal_v 352 | 353 | smpl_vitruvian = model.initiate_vitruvian(device = cuda, body_neutral_v = train_dataset.Tpose_minimal_v) 354 | 355 | 356 | train_data_loader = DataLoader(train_dataset, 357 | batch_size=8, shuffle=True,#not opt['training']['serial_batch'], 358 | num_workers=16, pin_memory=opt['training']['pin_memory']) 359 | test_data_loader = DataLoader(test_dataset, 360 | batch_size=1, shuffle=False, 361 | num_workers=0, pin_memory=False) 362 | 363 | 364 | # for now, all the hand, face joints are combined with body joints for smpl 365 | gt_lbs_smpl = model.lbs_weights[:,:24].clone() 366 | root_idx = model.parents.cpu().numpy() 367 | idx_list = list(range(root_idx.shape[0])) 368 | for i in range(root_idx.shape[0]): 369 | if i > 23: 370 | root = idx_list[root_idx[i]] 371 | gt_lbs_smpl[:,root] += model.lbs_weights[:,i] 372 | idx_list[i] = root 373 | gt_lbs_smpl = gt_lbs_smpl[None].permute(0,2,1) 374 | 375 | smpl_vitruvian = model.initiate_vitruvian(device = cuda, body_neutral_v = train_dataset.Tpose_minimal_v) 376 | 377 | # define bounding box 378 | bbox_smpl = (smpl_vitruvian[0].cpu().numpy().min(0).astype(np.float32), smpl_vitruvian[0].cpu().numpy().max(0).astype(np.float32)) 379 | bbox_center, bbox_size = 0.5 * (bbox_smpl[0] + bbox_smpl[1]), (bbox_smpl[1] - bbox_smpl[0]) 380 | bbox_min = np.stack([bbox_center[0]-0.55*bbox_size[0],bbox_center[1]-0.6*bbox_size[1],bbox_center[2]-1.5*bbox_size[2]], 0).astype(np.float32) 381 | bbox_max = np.stack([bbox_center[0]+0.55*bbox_size[0],bbox_center[1]+0.6*bbox_size[1],bbox_center[2]+1.5*bbox_size[2]], 0).astype(np.float32) 382 | 383 | pose_map = get_posemap(opt['model']['posemap_type'], 24, model.parents, opt['model']['n_traverse'], opt['model']['normalize_posemap']) 384 | 385 | igr_net = IGRSDFNet(opt['model']['igr_net'], bbox_min, bbox_max, pose_map).to(cuda) 386 | fwd_skin_net = LBSNet(opt['model']['fwd_skin_net'], bbox_min, bbox_max, posed=False).to(cuda) 387 | inv_skin_net = LBSNet(opt['model']['inv_skin_net'], bbox_min, bbox_max, posed=True).to(cuda) 388 | texture_net = TNet(opt['model']['igr_net']).to(cuda) 389 | 390 | lat_vecs_igr = nn.Embedding(1, opt['model']['igr_net']['g_dim']).to(cuda) 391 | 
lat_vecs_inv_skin = nn.Embedding(len(train_dataset), opt['model']['inv_skin_net']['g_dim']).to(cuda) 392 | 393 | if opt['model']['igr_net']['g_dim'] > 0: 394 | torch.nn.init.constant_(lat_vecs_igr.weight.data, 0.0) 395 | #torch.nn.init.normal_(lat_vecs_igr.weight.data, 0.0, 1.0 / math.sqrt(opt['model']['igr_net']['g_dim'])) 396 | 397 | if opt['model']['inv_skin_net']['g_dim'] > 0: 398 | torch.nn.init.normal_(lat_vecs_inv_skin.weight.data, 0.0, 1.0 / math.sqrt(opt['model']['inv_skin_net']['g_dim'])) 399 | 400 | print(igr_net) 401 | print(fwd_skin_net) 402 | print(inv_skin_net) 403 | print(texture_net) 404 | 405 | if ckpt_dict is not None: 406 | if 'igr_net' in ckpt_dict: 407 | load_network(igr_net, ckpt_dict['igr_net']) 408 | else: 409 | print("Couldn't find igr_net in checkpoints!") 410 | 411 | if 'fwd_skin_net' in ckpt_dict: 412 | load_network(fwd_skin_net, ckpt_dict['fwd_skin_net']) 413 | else: 414 | print("Couldn't find fwd_skin_net in checkpoints!") 415 | 416 | if 'inv_skin_net' in ckpt_dict: 417 | load_network(inv_skin_net, ckpt_dict['inv_skin_net']) 418 | else: 419 | print("Couldn't find inv_skin_net in checkpoints!") 420 | print("Try to find pretrained model...") 421 | model_path = '%s/ckpt_trained_skin_nets.pt' % ckpt_dir 422 | if os.path.isfile(model_path): 423 | print('Sucessfully found pretrained model of inv_skin_net: ', model_path) 424 | pretrained_ckpt_dict = torch.load(model_path) 425 | fwd_skin_net.load_state_dict(pretrained_ckpt_dict['fwd_skin_net']) 426 | inv_skin_net.load_state_dict(pretrained_ckpt_dict['inv_skin_net']) 427 | lat_vecs_inv_skin.load_state_dict(pretrained_ckpt_dict['lat_vecs_inv_skin']) 428 | # load_network(inv_skin_net, pretrained_ckpt_dict['inv_skin_net']) 429 | else: 430 | print("No pretrained model has been found") 431 | exit() 432 | 433 | if 'lat_vecs_igr'in ckpt_dict: 434 | load_network(lat_vecs_igr, ckpt_dict['lat_vecs_igr']) 435 | else: 436 | print("Couldn't find lat_vecs_igr in checkpoints!") 437 | 438 | if 'lat_vecs_inv_skin'in ckpt_dict: 439 | load_network(lat_vecs_inv_skin, ckpt_dict['lat_vecs_inv_skin']) 440 | else: 441 | print("Couldn't find lat_vecs_inv_skin in checkpoints!") 442 | 443 | 444 | print('train data size: ', len(train_data_loader)) 445 | print('test data size: ', len(test_data_loader)) 446 | 447 | 448 | # get only valid triangles 449 | print('Computing valid triangles...') 450 | train_data_loader.dataset.compute_valid_tri(inv_skin_net, model, lat_vecs_inv_skin, smpl_vitruvian) 451 | 452 | # Train color module 453 | print('Start training color module!') 454 | 455 | train_color(opt, ckpt_dir, result_dir, texture_net, igr_net, fwd_skin_net, inv_skin_net, lat_vecs_igr, lat_vecs_inv_skin, model, 456 | smpl_vitruvian, gt_lbs_smpl, train_data_loader, test_data_loader, cuda, reference_body_v = reference_body_vs_train) 457 | 458 | with open(os.path.join(result_dir, '../', exp_name+'.txt'), 'w') as finish_file: 459 | finish_file.write('Done!') 460 | 461 | def trainWrapper(args=None): 462 | parser = argparse.ArgumentParser( 463 | description='Train SCANimate color.' 
464 | ) 465 | parser.add_argument('--config', '-c', type=str, help='Path to config file.') 466 | args = parser.parse_args() 467 | 468 | opt = load_config(args.config, 'configs/default.yaml') 469 | 470 | train(opt) 471 | 472 | if __name__ == '__main__': 473 | trainWrapper() -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: /is/cluster/work/jyang/data/scanimation/miniCAPE 3 | smpl_dir: ./smplx/models 4 | smpl_gender: male 5 | num_sample_surf: 5000 6 | num_sample_scan_igr: 5000 7 | num_sample_smpl_igr: 5000 8 | num_sample_bbox_igr: 2000 9 | num_sample_edge: 5000 10 | sigma_body: 0.1 11 | experiment: 12 | name: example 13 | debug: false 14 | log_dir: /is/cluster/work/jyang/experiments/scanimation/runs 15 | result_dir: /is/cluster/work/jyang/experiments/scanimation/results 16 | ckpt_dir: /is/cluster/work/jyang/experiments/scanimation/checkpoints 17 | ckpt_file: 18 | vol_res: 256 19 | netG_file: 20 | netC_file: 21 | training: 22 | num_threads: 1 23 | serial_batch: false 24 | pin_memory: false 25 | batch_size: 2 26 | skip_pt1: false 27 | skip_pt2: false 28 | end2end: false 29 | num_epoch_pt1: 100 30 | num_epoch_pt2: 500 31 | num_epoch_sdf: 4000 32 | lr_pt1: 0.004 33 | lr_pt2: 0.004 34 | lr_sdf: 0.004 35 | freq_plot: 10 36 | freq_save: 500 37 | freq_mesh: 1000 38 | num_interp: 10 39 | resume_epoch: -1 40 | use_pretrain: false 41 | continue_train: false 42 | finetune: false 43 | no_gen_mesh: false 44 | no_num_eval: false 45 | num_eval_sample: 400 46 | num_eval_mesh: 5 47 | test: 48 | test_folder_path: /is/cluster/work/jyang/experiments/scanimation/test 49 | model: 50 | bps_res_c: 16 51 | bps_res_p: 32 52 | posemap_type: both 53 | n_traverse: 4 54 | normalize_posemap: true 55 | id_type: subject 56 | sdf_net: 57 | lambda_sdf: 1.0 58 | lambda_nml: 1.0 59 | lambda_reg: 0.1 60 | lambda_bbox: 10.0 61 | lambda_pmap: 1.0 62 | lambda_lat: 0.01 63 | pose_dim: 4 64 | g_dim: 64 65 | learn_posemap: false 66 | use_embed: true 67 | d_size: 5 68 | n_bound: 500 69 | nml_scale: 0.1 70 | mlp: 71 | ch_dim: 72 | - 7 73 | - 512 74 | - 512 75 | - 343 76 | - 512 77 | - 1 78 | res_layers: 79 | - 3 80 | nlactiv: softplus 81 | norm: weight 82 | last_op: null 83 | lbs_net_c: 84 | lambda_smpl: 10.0 85 | lambda_scan: 1.0 86 | lambda_cyc_scan: 0.1 87 | lambda_cyc_smpl: 0.0 88 | lambda_lat: 0.01 89 | lambda_edge: 0.1 90 | use_embed: true 91 | d_size: 6 92 | g_dim: 0 93 | mlp: 94 | ch_dim: 95 | - 7 96 | - 256 97 | - 256 98 | - 256 99 | - 24 100 | res_layers: 101 | - 2 102 | nlactiv: leakyrelu 103 | norm: none 104 | last_op: softmax 105 | lbs_net_p: 106 | lambda_smpl: 10.0 107 | lambda_scan: 1.0 108 | lambda_lat: 0.01 109 | lambda_l_edge: 0.0 110 | lambda_w_edge: 0.1 111 | lambda_sparse: 0.001 112 | use_embed: true 113 | d_size: 8 114 | g_dim: 64 115 | p_val: 0.8 116 | mlp: 117 | ch_dim: 118 | - 8 119 | - 256 120 | - 256 121 | - 256 122 | - 24 123 | res_layers: 124 | - 2 125 | nlactiv: leakyrelu 126 | norm: none 127 | last_op: softmax 128 | -------------------------------------------------------------------------------- /configs/example.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | data_dir: ./data/train/example_03375_shortlong # Training data directory 3 | num_sample_bbox_igr: 2000 # Number of sample points in the bounding box 4 | num_sample_edge: 5000 5 | num_sample_scan_igr: 5000 6 | num_sample_smpl_igr: 5000 # 
num_sample_scan_igr+num_sample_smpl_igr: number of sample points around the surface 7 | num_sample_surf: 8000 # Number of sample points on the clothing surface 8 | sigma_body: 0.05 # Standard deviation controlling the sample points around the surface 9 | smpl_dir: ./smpl/models # The directory containing 10 PCA .pkl SMPL model 10 | smpl_gender: male 11 | test_dir: ./data/train/example_03375_shortlong # Test data directory, by default it is the same as training dir. Test directory should be additionally specified when run apps.test_scanimate 12 | train_ratio: 1.0 # Using different portion of training data. 13 | experiment: 14 | ckpt_dir: ./demo_result/checkpoints # Where to save checkpoints and final trained models 15 | log_dir: ./demo_result/runs # Where to save intermediate results 16 | name: example # Experiment name 17 | result_dir: ./demo_result/results # Where to save the results 18 | vol_res: 256 # Voxel resolution for marching cube reconstruction 19 | model: 20 | fwd_skin_net: 21 | d_size: 4 # Degree of positional encoding 22 | g_dim: 0 # Dimension of global latent code 23 | lambda_cyc_scan: 0.1 24 | lambda_cyc_smpl: 0.0 25 | lambda_edge: 0.1 26 | lambda_lat: 0.01 27 | lambda_scan: 1.0 28 | lambda_smpl: 10.0 29 | mlp: 30 | ch_dim: 31 | - 3 32 | - 256 33 | - 256 34 | - 256 35 | - 24 36 | last_op: softmax 37 | nlactiv: leakyrelu 38 | norm: none 39 | res_layers: 40 | - 2 41 | use_embed: true # Use positional encoding 42 | igr_net: 43 | d_size: 4 # Degree of positional encoding 44 | g_dim: 64 # Dimension of global latent code 45 | lambda_bbox: 1.0 46 | lambda_lat: 1.0 47 | lambda_nml: 1.0 48 | lambda_non_zero: 0.1 49 | lambda_pmap: 1.0 50 | lambda_reg: 1.0 51 | lambda_sdf: 1.0 52 | learn_posemap: false 53 | mlp: 54 | ch_dim: 55 | - 3 56 | - 512 57 | - 512 58 | - 512 59 | - 343 60 | - 512 61 | - 512 62 | - 1 63 | last_op: null 64 | nlactiv: softplus 65 | norm: weight 66 | res_layers: 67 | - 4 68 | n_bound: 500 69 | pose_dim: 4 # N ring neighbouring joints will be considered as pose condition 70 | use_embed: true # Use positional encoding 71 | inv_skin_net: 72 | d_size: 4 # Degree of positional encoding 73 | g_dim: 64 # Dimension of global latent code 74 | lambda_l_edge: 0.0 75 | lambda_lat: 0.01 76 | lambda_scan: 1.0 77 | lambda_smpl: 10.0 78 | lambda_sparse: 0.001 79 | lambda_w_edge: 0.1 80 | mlp: 81 | ch_dim: 82 | - 3 83 | - 256 84 | - 256 85 | - 256 86 | - 24 87 | last_op: softmax 88 | nlactiv: leakyrelu 89 | norm: none 90 | res_layers: 91 | - 2 92 | p_val: 0.8 93 | use_embed: true # Use positional encoding 94 | n_traverse: 4 95 | normalize_posemap: true 96 | posemap_type: both 97 | training: 98 | batch_size: 4 99 | continue_train: true 100 | freq_mesh: 10000 # Output intermediate mesh results for every N iterations 101 | freq_plot: 100 # Output training information for every N iterations 102 | freq_save: 5000 # Save checkpoints for every N iterations 103 | lr_pt1: 0.004 # Learning rates 104 | lr_pt2: 0.001 105 | lr_sdf: 0.004 106 | num_epoch_pt1: 200 # Epoches 107 | num_epoch_pt2: 200 108 | num_epoch_sdf: 20000 109 | num_threads: 8 110 | pin_memory: false 111 | resample_every_n_epoch: 1 # Actively resample points at every Nth epoch 112 | skip_pt1: false 113 | skip_pt2: false 114 | use_trained_skin_nets: false # If continue_train is true and ckpt_skin_net.pth exists, the code will skip training skin_nets automatically. 
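# Note: the training/testing apps load this file via lib.config.load_config(<this file>, 'configs/default.yaml'),
# which reads configs/default.yaml first and then recursively overrides it with the values given here,
# so any key omitted from this file falls back to the default config. A config can also inherit from
# another config through an 'inherit_from' key (see lib/config.py).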
115 | -------------------------------------------------------------------------------- /download_aist_demo_motion.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | base_path=$(pwd) 3 | mkdir data && cd data 4 | mkdir test && cd test 5 | 6 | echo Downloading test motion sequences... 7 | wget https://scanimate.is.tue.mpg.de/media/upload/demo_data/aist_demo_seq.zip 8 | unzip aist_demo_seq.zip -d ./ 9 | rm aist_demo_seq.zip 10 | 11 | cd $base_path 12 | echo Done! -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Set up virtualenv for scanimate..." 3 | virtualenv -p python3.6 ./venv/scanimate 4 | . ./venv/scanimate/bin/activate 5 | 6 | echo "Installing torch..." 7 | echo "If you are using other versions (default python3.6 cuda 10.1), change the cuda version and python version in ./install.sh" 8 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 9 | 10 | echo "Installing pytorch3d..." 11 | echo "If you are using other versions (default python3.6 cuda 10.1), change the cuda version and python version in ./install.sh" 12 | pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py36_cu101_pyt171/download.html 13 | 14 | echo "Installing other dependencies..." 15 | pip install -r requirements.txt 16 | 17 | echo "Installing customized smpl code" 18 | cd smpl 19 | python3 setup.py install 20 | cd ../ 21 | 22 | echo "Done!" -------------------------------------------------------------------------------- /installation.txt: -------------------------------------------------------------------------------- 1 | 1. Set up the virtual environment: 2 | Go to the scanimate directory in the command line, then run 3 | $ source ./install.sh 4 | If you use other Python and CUDA versions (default: Python 3.6, CUDA 10.1), please change the CUDA version and Python version in ./install.sh 5 | 6 | 2. Download the SMPL model: 7 | Download the SMPL models from https://smpl.is.tue.mpg.de/ and put them into the models folder (./scanimate/smpl/models/smpl) 8 | By default we use the 10-PCA models in .pkl format. 9 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shunsukesaito/SCANimate/f2eeb5799fd20fd9d5933472f6aedf1560296cbe/lib/__init__.py -------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved.
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import yaml 18 | 19 | # General config 20 | def load_config(path, default_path): 21 | ''' Loads config file. 22 | Args: 23 | path (str): path to config file 24 | default_path (bool): whether to use default path 25 | ''' 26 | # Load configuration from file itself 27 | with open(path, 'r') as f: 28 | cfg_special = yaml.load(f, Loader=yaml.FullLoader) 29 | 30 | # Check if we should inherit from a config 31 | inherit_from = cfg_special.get('inherit_from') 32 | 33 | # If yes, load this config first as default 34 | # If no, use the default_path 35 | if inherit_from is not None: 36 | cfg = load_config(inherit_from, default_path) 37 | elif default_path is not None: 38 | with open(default_path, 'r') as f: 39 | cfg = yaml.load(f, Loader=yaml.FullLoader) 40 | else: 41 | cfg = dict() 42 | 43 | # Include main configuration 44 | update_recursive(cfg, cfg_special) 45 | 46 | return cfg 47 | 48 | def update_recursive(dict1, dict2): 49 | ''' Update two config dictionaries recursively. 50 | Args: 51 | dict1 (dict): first dictionary to be updated 52 | dict2 (dict): second dictionary which entries should be used 53 | ''' 54 | for k, v in dict2.items(): 55 | # Add item if not yet in dict1 56 | if k not in dict1: 57 | dict1[k] = None 58 | # Update 59 | if isinstance(dict1[k], dict): 60 | update_recursive(dict1[k], v) 61 | else: 62 | dict1[k] = v 63 | -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shunsukesaito/SCANimate/f2eeb5799fd20fd9d5933472f6aedf1560296cbe/lib/data/__init__.py -------------------------------------------------------------------------------- /lib/ext_trimesh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import numpy as np 18 | 19 | from trimesh.sample import util 20 | from trimesh.sample import transformations 21 | import trimesh 22 | import torch 23 | 24 | def slice_mesh_plane_with_mesh_color(mesh, plane_normal, plane_origin): 25 | """ 26 | Slice a mesh with vertex color into two by a plane, 27 | return the half indicated by the normal direction. 28 | 29 | Parameters 30 | mesh: trimesh.Trimesh 31 | Mesh with mesh.visual.vertex_colors 32 | plane_normal: (3,) float 33 | Normal direction of the slice plane 34 | plane_origion: (3,) float 35 | One point on the slice plane 36 | """ 37 | 38 | dots = np.einsum('i,ij->j', plane_normal, 39 | (mesh.vertices - plane_origin).T)[mesh.faces] 40 | 41 | # Find vertex orientations w.r.t. 
faces for all triangles: 42 | # -1 -> vertex "inside" plane (positive normal direction) 43 | # 0 -> vertex on plane 44 | # 1 -> vertex "outside" plane (negative normal direction) 45 | signs = np.zeros(mesh.faces.shape, dtype=np.int8) 46 | signs[dots < -1e-8] = 1 47 | signs[dots > 1e-8] = -1 48 | signs[np.logical_and(dots >= -1e-8, dots <= 1e-8)] = 0 49 | 50 | # Find all triangles that intersect this plane 51 | # onedge <- indices of all triangles intersecting the plane 52 | # inside <- indices of all triangles "inside" the plane (positive normal) 53 | signs_sum = signs.sum(axis=1, dtype=np.int8) 54 | signs_asum = np.abs(signs).sum(axis=1, dtype=np.int8) 55 | 56 | # Cases: 57 | # (0,0,0), (-1,0,0), (-1,-1,0), (-1,-1,-1) <- inside 58 | # (1,0,0), (1,1,0), (1,1,1) <- outside 59 | # (1,0,-1), (1,-1,-1), (1,1,-1) <- onedge 60 | # onedge = np.logical_and(signs_asum >= 2, 61 | # np.abs(signs_sum) <= 1) 62 | inside = (signs_sum == -signs_asum) 63 | 64 | # Automatically include all faces that are "inside" 65 | new_faces = mesh.faces[inside] 66 | 67 | selected_vertex_ids = np.unique(new_faces) 68 | 69 | new_vertices = mesh.vertices[selected_vertex_ids] 70 | new_vertex_colors = mesh.visual.vertex_colors[selected_vertex_ids] 71 | 72 | old_vid2new_vid = np.zeros((mesh.vertices.shape[0]), dtype = np.int64) - 1 73 | for new_vid, old_vid in enumerate(selected_vertex_ids): 74 | old_vid2new_vid[old_vid] = new_vid 75 | new_faces = old_vid2new_vid[new_faces] 76 | 77 | half_mesh = trimesh.Trimesh(vertices = new_vertices, faces = new_faces, process=False) 78 | half_mesh.visual.vertex_colors[:,:] = new_vertex_colors[:,:] 79 | 80 | return half_mesh 81 | 82 | def slice_mesh_plane_with_texture_coordinates(mesh, plane_normal, plane_origin): 83 | """ 84 | Slice a mesh with vertex color into two by a plane, 85 | return the half indicated by the normal direction. 86 | 87 | Parameters 88 | mesh: trimesh.Trimesh 89 | Mesh with mesh.visual.vertex_colors 90 | plane_normal: (3,) float 91 | Normal direction of the slice plane 92 | plane_origion: (3,) float 93 | One point on the slice plane 94 | """ 95 | 96 | dots = np.einsum('i,ij->j', plane_normal, 97 | (mesh.vertices - plane_origin).T)[mesh.faces] 98 | 99 | # Find vertex orientations w.r.t. 
faces for all triangles: 100 | # -1 -> vertex "inside" plane (positive normal direction) 101 | # 0 -> vertex on plane 102 | # 1 -> vertex "outside" plane (negative normal direction) 103 | signs = np.zeros(mesh.faces.shape, dtype=np.int8) 104 | signs[dots < -1e-8] = 1 105 | signs[dots > 1e-8] = -1 106 | signs[np.logical_and(dots >= -1e-8, dots <= 1e-8)] = 0 107 | 108 | # Find all triangles that intersect this plane 109 | # onedge <- indices of all triangles intersecting the plane 110 | # inside <- indices of all triangles "inside" the plane (positive normal) 111 | signs_sum = signs.sum(axis=1, dtype=np.int8) 112 | signs_asum = np.abs(signs).sum(axis=1, dtype=np.int8) 113 | 114 | # Cases: 115 | # (0,0,0), (-1,0,0), (-1,-1,0), (-1,-1,-1) <- inside 116 | # (1,0,0), (1,1,0), (1,1,1) <- outside 117 | # (1,0,-1), (1,-1,-1), (1,1,-1) <- onedge 118 | # onedge = np.logical_and(signs_asum >= 2, 119 | # np.abs(signs_sum) <= 1) 120 | inside = (signs_sum == -signs_asum) 121 | 122 | # Automatically include all faces that are "inside" 123 | new_faces = mesh.faces[inside] 124 | 125 | selected_vertex_ids = np.unique(new_faces) 126 | 127 | new_vertices = mesh.vertices[selected_vertex_ids] 128 | new_uv = mesh.visual.uv[selected_vertex_ids] 129 | 130 | old_vid2new_vid = np.zeros((mesh.vertices.shape[0]), dtype = np.int64) - 1 131 | for new_vid, old_vid in enumerate(selected_vertex_ids): 132 | old_vid2new_vid[old_vid] = new_vid 133 | new_faces = old_vid2new_vid[new_faces] 134 | 135 | half_mesh = trimesh.Trimesh(vertices = new_vertices, faces = new_faces, process=False) 136 | 137 | return half_mesh, new_uv 138 | 139 | def sample_surface_wnormal(mesh, count, mask=None): 140 | """ 141 | Sample the surface of a mesh, returning the specified 142 | number of points 143 | For individual triangle sampling uses this method: 144 | http://mathworld.wolfram.com/TrianglePointPicking.html 145 | Parameters 146 | --------- 147 | mesh : trimesh.Trimesh 148 | Geometry to sample the surface of 149 | count : int 150 | Number of points to return 151 | Returns 152 | --------- 153 | samples : (count, 3) float 154 | Points in space on the surface of mesh 155 | face_index : (count,) int 156 | Indices of faces for each sampled point 157 | """ 158 | 159 | # len(mesh.faces) float, array of the areas 160 | # of each face of the mesh 161 | area = mesh.area_faces 162 | if mask is not None: 163 | area = mask * area 164 | # total area (float) 165 | area_sum = np.sum(area) 166 | # cumulative area (len(mesh.faces)) 167 | area_cum = np.cumsum(area) 168 | face_pick = np.random.random(count) * area_sum 169 | face_index = np.searchsorted(area_cum, face_pick) 170 | 171 | # pull triangles into the form of an origin + 2 vectors 172 | tri_origins = mesh.triangles[:, 0] 173 | tri_vectors = mesh.triangles[:, 1:].copy() 174 | tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) 175 | 176 | # do the same for normal 177 | normals = mesh.vertex_normals.view(np.ndarray)[mesh.faces] 178 | nml_origins = normals[:, 0] 179 | nml_vectors = normals[:, 1:]#.copy() 180 | nml_vectors -= np.tile(nml_origins, (1, 2)).reshape((-1, 2, 3)) 181 | 182 | # pull the vectors for the faces we are going to sample from 183 | tri_origins = tri_origins[face_index] 184 | tri_vectors = tri_vectors[face_index] 185 | 186 | # pull the vectors for the faces we are going to sample from 187 | nml_origins = nml_origins[face_index] 188 | nml_vectors = nml_vectors[face_index] 189 | 190 | # randomly generate two 0-1 scalar components to multiply edge vectors by 191 | random_lengths = 
np.random.random((len(tri_vectors), 2, 1)) 192 | 193 | # points will be distributed on a quadrilateral if we use 2 0-1 samples 194 | # if the two scalar components sum less than 1.0 the point will be 195 | # inside the triangle, so we find vectors longer than 1.0 and 196 | # transform them to be inside the triangle 197 | random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 198 | random_lengths[random_test] -= 1.0 199 | random_lengths = np.abs(random_lengths) 200 | 201 | # multiply triangle edge vectors by the random lengths and sum 202 | sample_vector = (tri_vectors * random_lengths).sum(axis=1) 203 | sample_normal = (nml_vectors * random_lengths).sum(axis=1) 204 | 205 | # finally, offset by the origin to generate 206 | # (n,3) points in space on the triangle 207 | samples = sample_vector + tri_origins 208 | 209 | normals = sample_normal + nml_origins 210 | 211 | return samples, normals, face_index 212 | 213 | def sample_surface_wnormalcolor(mesh, count, mask=None): 214 | """ 215 | Sample the surface of a mesh, returning the specified 216 | number of points 217 | For individual triangle sampling uses this method: 218 | http://mathworld.wolfram.com/TrianglePointPicking.html 219 | Parameters 220 | --------- 221 | mesh : trimesh.Trimesh 222 | Geometry to sample the surface of 223 | count : int 224 | Number of points to return 225 | Returns 226 | --------- 227 | samples : (count, 3) float 228 | Points in space on the surface of mesh 229 | face_index : (count,) int 230 | Indices of faces for each sampled point 231 | """ 232 | 233 | # len(mesh.faces) float, array of the areas 234 | # of each face of the mesh 235 | area = mesh.area_faces 236 | if mask is not None: 237 | area = mask * area 238 | # total area (float) 239 | area_sum = np.sum(area) 240 | # cumulative area (len(mesh.faces)) 241 | area_cum = np.cumsum(area) 242 | face_pick = np.random.random(count) * area_sum 243 | face_index = np.searchsorted(area_cum, face_pick) 244 | 245 | # pull triangles into the form of an origin + 2 vectors 246 | tri_origins = mesh.triangles[:, 0] 247 | tri_vectors = mesh.triangles[:, 1:].copy() 248 | tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) 249 | 250 | # do the same for normal 251 | normals = mesh.vertex_normals.view(np.ndarray)[mesh.faces] 252 | nml_origins = normals[:, 0] 253 | nml_vectors = normals[:, 1:]#.copy() 254 | nml_vectors -= np.tile(nml_origins, (1, 2)).reshape((-1, 2, 3)) 255 | 256 | colors = mesh.visual.vertex_colors[:,:3].astype(np.float32) 257 | colors = colors / 255.0 258 | colors = colors.view(np.ndarray)[mesh.faces] 259 | clr_origins = colors[:, 0] 260 | clr_vectors = colors[:, 1:]#.copy() 261 | clr_vectors -= np.tile(clr_origins, (1, 2)).reshape((-1, 2, 3)) 262 | 263 | # pull the vectors for the faces we are going to sample from 264 | tri_origins = tri_origins[face_index] 265 | tri_vectors = tri_vectors[face_index] 266 | 267 | # pull the vectors for the faces we are going to sample from 268 | nml_origins = nml_origins[face_index] 269 | nml_vectors = nml_vectors[face_index] 270 | 271 | clr_origins = clr_origins[face_index] 272 | clr_vectors = clr_vectors[face_index] 273 | 274 | # randomly generate two 0-1 scalar components to multiply edge vectors by 275 | random_lengths = np.random.random((len(tri_vectors), 2, 1)) 276 | 277 | # points will be distributed on a quadrilateral if we use 2 0-1 samples 278 | # if the two scalar components sum less than 1.0 the point will be 279 | # inside the triangle, so we find vectors longer than 1.0 and 280 | # transform them to be 
inside the triangle 281 | random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 282 | random_lengths[random_test] -= 1.0 283 | random_lengths = np.abs(random_lengths) 284 | 285 | # multiply triangle edge vectors by the random lengths and sum 286 | sample_vector = (tri_vectors * random_lengths).sum(axis=1) 287 | sample_normal = (nml_vectors * random_lengths).sum(axis=1) 288 | sample_color = (clr_vectors * random_lengths).sum(axis=1) 289 | 290 | # finally, offset by the origin to generate 291 | # (n,3) points in space on the triangle 292 | samples = sample_vector + tri_origins 293 | 294 | normals = sample_normal + nml_origins 295 | 296 | colors = sample_color + clr_origins 297 | 298 | # mesh.export('train_sample_before_sample.ply') 299 | # print(mesh.visual.vertex_colors.dtype) 300 | # tmp_mesh = trimesh.Trimesh(vertices = samples, faces = np.zeros((0,3), dtype = np.int64), process=False) 301 | # tmp_mesh.visual.vertex_colors[:,:3] = (colors[:,:]*255.0).astype(np.int8) 302 | # tmp_mesh.export('train_sample.ply') 303 | # exit() 304 | 305 | return samples, normals, colors, face_index 306 | 307 | -------------------------------------------------------------------------------- /lib/geo_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import math 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as Fn 22 | 23 | def detectBoundary(F): 24 | ''' 25 | input: 26 | F: (F, 3) numpy triangle list 27 | return: 28 | (F) boundary flag 29 | ''' 30 | tri_dic = {} 31 | nV = F.max() 32 | for i in range(F.shape[0]): 33 | idx = [F[i,0],F[i,1],F[i,2]] 34 | 35 | if (idx[1],idx[0]) in tri_dic: 36 | tri_dic[(idx[1],idx[0])].append(i) 37 | else: 38 | tri_dic[(idx[0],idx[1])] = [i] 39 | 40 | if (idx[2],idx[1]) in tri_dic: 41 | tri_dic[(idx[2],idx[1])].append(i) 42 | else: 43 | tri_dic[(idx[1],idx[2])] = [i] 44 | 45 | if (idx[0],idx[2]) in tri_dic: 46 | tri_dic[(idx[0],idx[2])].append(i) 47 | else: 48 | tri_dic[(idx[2],idx[0])] = [i] 49 | 50 | v_boundary = np.array((nV+1)*[False]) 51 | for key in tri_dic: 52 | if len(tri_dic[key]) != 2: 53 | v_boundary[key[0]] = True 54 | v_boundary[key[1]] = True 55 | 56 | boundary = v_boundary[F[:,0]] | v_boundary[F[:,1]] | v_boundary[F[:,2]] 57 | 58 | return boundary 59 | 60 | def computeMeanCurvature(V, N, F, norm_factor=10.0): 61 | ''' 62 | input: 63 | V: (B, N, 3) 64 | N: (B, N, 3) 65 | F: (B, F, 3) 66 | output: 67 | (B, F, 3) cotangent weight, corresponding edge is ordered in 23, 31, 12 68 | ''' 69 | B, nF = F.size()[:2] 70 | 71 | indices_repeat = F[:,:,None].expand(*F.size()[:2],3,*F.size()[2:]) 72 | 73 | v1 = torch.gather(V, 1, indices_repeat[:, :, :, 0].long()) 74 | v2 = torch.gather(V, 1, indices_repeat[:, :, :, 1].long()) 75 | v3 = torch.gather(V, 1, indices_repeat[:, :, :, 2].long()) 76 | 77 | n1 = torch.gather(N, 1, indices_repeat[:, :, :, 0].long()) 78 | n2 = torch.gather(N, 1, indices_repeat[:, :, :, 1].long()) 79 | n3 = torch.gather(N, 1, indices_repeat[:, :, :, 2].long()) 80 | 81 | dv1 = v2 - v3 82 | dv2 = v3 - v1 83 | dv3 = v1 - v2 84 | 85 | lsq1 = dv1.pow(2).sum(2) 86 | lsq2 = dv2.pow(2).sum(2) 87 | lsq3 = dv3.pow(2).sum(2) 88 | 89 | dn1 = n2 - n3 90 | dn2 = n3 - n1 91 | dn3 = n1 - n2 92 | 93 | c1 = (dv1 * dn1).sum(2) / (lsq1 + 1e-8) # (B, F) 94 | c2 = (dv2 * dn2).sum(2) / (lsq2 + 1e-8) 95 | c3 = (dv3 * dn3).sum(2) / (lsq3 + 1e-8) 96 | 97 | C = torch.stack([c1, c2, c3], 2)[:,:,:,None].expand(B, nF, 3, 2).contiguous().view(B, -1, 1) 98 | 99 | idx1 = F[:,:,0:1] 100 | idx2 = F[:,:,1:2] 101 | idx3 = F[:,:,2:] 102 | 103 | idx23 = torch.stack([idx2, idx3], 3) 104 | idx31 = torch.stack([idx3, idx1], 3) 105 | idx12 = torch.stack([idx1, idx2], 3) 106 | 107 | Fst = torch.cat([idx23, idx31, idx12], 2).contiguous().view(B, -1, 1) 108 | 109 | Hv = torch.zeros_like(V[:,:,0:1]) # (B, N) 110 | Cv = torch.zeros_like(V[:,:,0:1]) # (B, N) 111 | Cnt = torch.ones_like(C) # (B, N) 112 | Hv = Hv.scatter_add_(1, Fst.long(), C) 113 | Cv = Cv.scatter_add_(1, Fst.long(), Cnt) 114 | 115 | Hv = Hv / Cv / (2.0*norm_factor) + 0.5 # to roughly range [-1, 1] 116 | 117 | return Hv 118 | 119 | def vertices_to_faces(vertices, faces): 120 | assert (faces.ndimension() == 3) 121 | assert (vertices.shape[0] == faces.shape[0]) 122 | assert (faces.shape[2] == 3) 123 | 124 | bs, nv, c = vertices.shape 125 | bs, nf = faces.shape[:2] 126 | device = vertices.device 127 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] 128 | vertices = vertices.reshape((bs * nv, c)) 129 | # pytorch only supports long and byte tensors for indexing 130 | return vertices[faces.long()] 131 | 132 | def compute_normal_v(Vs, Fs, norm=True): 133 | B, nF = Fs.size()[:2] 134 | Vf = vertices_to_faces(Vs, Fs) 135 | Vf 
= Vf.reshape((B * nF, 3, 3)) 136 | v10 = Vf[:, 1] - Vf[:, 0] 137 | v20 = Vf[:, 2] - Vf[:, 0] 138 | nf = torch.cross(v10, v20).view(B, nF, 3) # (B * nF, 3) 139 | 140 | Ns = torch.zeros(Vs.size()) # (B, N, 3) 141 | Fs = Fs.view(Fs.size(0),Fs.size(1),3,1).expand(Fs.size(0),Fs.size(1),3,3) 142 | nf = nf.view(nf.size(0),nf.size(1),1,nf.size(2)).expand_as(Fs).contiguous() 143 | Ns = Ns.scatter_add_(1, Fs.long().reshape(Fs.size(0),-1,3).cpu(), nf.reshape(Fs.size(0),-1,3).cpu()).type_as(Vs) 144 | 145 | # Ns = torch.zeros_like(Vs) # (B, N, 3) 146 | # Fs = Fs.view(B,nF,3,1).expand(B,nF,3,3) 147 | # nf = nf.view(B,nF,1,3).expand_as(Fs).contiguous() 148 | # Ns = Ns.scatter_add_(1, Fs.long().view(B,-1,3), nf.view(B,-1,3)) 149 | 150 | if norm: 151 | Ns = Fn.normalize(Ns, dim=2) 152 | 153 | return Ns 154 | -------------------------------------------------------------------------------- /lib/geometry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import torch 18 | 19 | def index_custom(feat, uv): 20 | ''' 21 | args: 22 | feat: (B, C, H, W) 23 | uv: (B, 2, N) 24 | return: 25 | (B, C, N) 26 | ''' 27 | device = feat.device 28 | B, C, H, W = feat.size() 29 | _, _, N = uv.size() 30 | 31 | x, y = uv[:,0], uv[:,1] 32 | x = (W-1.0) * (0.5 * x.contiguous().view(-1) + 0.5) 33 | y = (H-1.0) * (0.5 * y.contiguous().view(-1) + 0.5) 34 | 35 | x0 = torch.floor(x).int() 36 | x1 = x0 + 1 37 | y0 = torch.floor(y).int() 38 | y1 = y0 + 1 39 | 40 | max_x = W - 1 41 | max_y = H - 1 42 | 43 | x0_clamp = torch.clamp(x0, 0, max_x) 44 | x1_clamp = torch.clamp(x1, 0, max_x) 45 | y0_clamp = torch.clamp(y0, 0, max_y) 46 | y1_clamp = torch.clamp(y1, 0, max_y) 47 | 48 | dim2 = W 49 | dim1 = W * H 50 | 51 | base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, N).contiguous().view(-1).to(device) 52 | 53 | base_y0 = base + y0_clamp * dim2 54 | base_y1 = base + y1_clamp * dim2 55 | 56 | idx_y0_x0 = base_y0 + x0_clamp 57 | idx_y0_x1 = base_y0 + x1_clamp 58 | idx_y1_x0 = base_y1 + x0_clamp 59 | idx_y1_x1 = base_y1 + x1_clamp 60 | 61 | # (B,C,H,W) -> (B,H,W,C) 62 | im_flat = feat.permute(0,2,3,1).contiguous().view(-1, C) 63 | i_y0_x0 = torch.gather(im_flat, 0, idx_y0_x0.unsqueeze(1).expand(-1,C).long()) 64 | i_y0_x1 = torch.gather(im_flat, 0, idx_y0_x1.unsqueeze(1).expand(-1,C).long()) 65 | i_y1_x0 = torch.gather(im_flat, 0, idx_y1_x0.unsqueeze(1).expand(-1,C).long()) 66 | i_y1_x1 = torch.gather(im_flat, 0, idx_y1_x1.unsqueeze(1).expand(-1,C).long()) 67 | 68 | # Check the out-of-boundary case. 
69 | x0_valid = (x0 <= max_x) & (x0 >= 0) 70 | x1_valid = (x1 <= max_x) & (x1 >= 0) 71 | y0_valid = (y0 <= max_y) & (y0 >= 0) 72 | y1_valid = (y1 <= max_y) & (y1 >= 0) 73 | 74 | x0 = x0.float() 75 | x1 = x1.float() 76 | y0 = y0.float() 77 | y1 = y1.float() 78 | 79 | w_y0_x0 = ((x1 - x) * (y1 - y) * (x1_valid * y1_valid).float()).unsqueeze(1) 80 | w_y0_x1 = ((x - x0) * (y1 - y) * (x0_valid * y1_valid).float()).unsqueeze(1) 81 | w_y1_x0 = ((x1 - x) * (y - y0) * (x1_valid * y0_valid).float()).unsqueeze(1) 82 | w_y1_x1 = ((x - x0) * (y - y0) * (x0_valid * y0_valid).float()).unsqueeze(1) 83 | 84 | output = w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1 # (B, N, C) 85 | 86 | return output.view(B, N, C).permute(0,2,1).contiguous() 87 | 88 | def index3d_custom(feat, pts): 89 | ''' 90 | args: 91 | feat: (B, C, D, H, W) 92 | pts: (B, 3, N) 93 | return: 94 | (B, C, N) 95 | ''' 96 | device = feat.device 97 | B, C, D, H, W = feat.size() 98 | _, _, N = pts.size() 99 | 100 | x, y, z = pts[:,0], pts[:,1], pts[:,2] 101 | x = (W-1.0) * (0.5 * x.contiguous().view(-1) + 0.5) 102 | y = (H-1.0) * (0.5 * y.contiguous().view(-1) + 0.5) 103 | z = (D-1.0) * (0.5 * z.contiguous().view(-1) + 0.5) 104 | 105 | x0 = torch.floor(x).int() 106 | x1 = x0 + 1 107 | y0 = torch.floor(y).int() 108 | y1 = y0 + 1 109 | z0 = torch.floor(z).int() 110 | z1 = z0 + 1 111 | 112 | max_x = W - 1 113 | max_y = H - 1 114 | max_z = D - 1 115 | 116 | x0_clamp = torch.clamp(x0, 0, max_x) 117 | x1_clamp = torch.clamp(x1, 0, max_x) 118 | y0_clamp = torch.clamp(y0, 0, max_y) 119 | y1_clamp = torch.clamp(y1, 0, max_y) 120 | z0_clamp = torch.clamp(z0, 0, max_z) 121 | z1_clamp = torch.clamp(z1, 0, max_z) 122 | 123 | dim3 = W 124 | dim2 = W * H 125 | dim1 = W * H * D 126 | 127 | base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, N).contiguous().view(-1).to(device) 128 | 129 | base_z0_y0 = base + z0_clamp * dim2 + y0_clamp * dim3 130 | base_z0_y1 = base + z0_clamp * dim2 + y1_clamp * dim3 131 | base_z1_y0 = base + z1_clamp * dim2 + y0_clamp * dim3 132 | base_z1_y1 = base + z1_clamp * dim2 + y1_clamp * dim3 133 | 134 | idx_z0_y0_x0 = base_z0_y0 + x0_clamp 135 | idx_z0_y0_x1 = base_z0_y0 + x1_clamp 136 | idx_z0_y1_x0 = base_z0_y1 + x0_clamp 137 | idx_z0_y1_x1 = base_z0_y1 + x1_clamp 138 | idx_z1_y0_x0 = base_z1_y0 + x0_clamp 139 | idx_z1_y0_x1 = base_z1_y0 + x1_clamp 140 | idx_z1_y1_x0 = base_z1_y1 + x0_clamp 141 | idx_z1_y1_x1 = base_z1_y1 + x1_clamp 142 | 143 | # (B,C,D,H,W) -> (B,D,H,W,C) 144 | im_flat = feat.permute(0,2,3,4,1).contiguous().view(-1, C) 145 | i_z0_y0_x0 = torch.gather(im_flat, 0, idx_z0_y0_x0.unsqueeze(1).expand(-1,C).long()) 146 | i_z0_y0_x1 = torch.gather(im_flat, 0, idx_z0_y0_x1.unsqueeze(1).expand(-1,C).long()) 147 | i_z0_y1_x0 = torch.gather(im_flat, 0, idx_z0_y1_x0.unsqueeze(1).expand(-1,C).long()) 148 | i_z0_y1_x1 = torch.gather(im_flat, 0, idx_z0_y1_x1.unsqueeze(1).expand(-1,C).long()) 149 | i_z1_y0_x0 = torch.gather(im_flat, 0, idx_z1_y0_x0.unsqueeze(1).expand(-1,C).long()) 150 | i_z1_y0_x1 = torch.gather(im_flat, 0, idx_z1_y0_x1.unsqueeze(1).expand(-1,C).long()) 151 | i_z1_y1_x0 = torch.gather(im_flat, 0, idx_z1_y1_x0.unsqueeze(1).expand(-1,C).long()) 152 | i_z1_y1_x1 = torch.gather(im_flat, 0, idx_z1_y1_x1.unsqueeze(1).expand(-1,C).long()) 153 | 154 | # Check the out-of-boundary case. 
155 | x0_valid = (x0 <= max_x) & (x0 >= 0) 156 | x1_valid = (x1 <= max_x) & (x1 >= 0) 157 | y0_valid = (y0 <= max_y) & (y0 >= 0) 158 | y1_valid = (y1 <= max_y) & (y1 >= 0) 159 | z0_valid = (z0 <= max_z) & (z0 >= 0) 160 | z1_valid = (z1 <= max_z) & (z1 >= 0) 161 | 162 | x0 = x0.float() 163 | x1 = x1.float() 164 | y0 = y0.float() 165 | y1 = y1.float() 166 | z0 = z0.float() 167 | z1 = z1.float() 168 | 169 | w_z0_y0_x0 = ((x1 - x) * (y1 - y) * (z1 - z) * (x1_valid * y1_valid * z1_valid).float()).unsqueeze(1) 170 | w_z0_y0_x1 = ((x - x0) * (y1 - y) * (z1 - z) * (x0_valid * y1_valid * z1_valid).float()).unsqueeze(1) 171 | w_z0_y1_x0 = ((x1 - x) * (y - y0) * (z1 - z) * (x1_valid * y0_valid * z1_valid).float()).unsqueeze(1) 172 | w_z0_y1_x1 = ((x - x0) * (y - y0) * (z1 - z) * (x0_valid * y0_valid * z1_valid).float()).unsqueeze(1) 173 | w_z1_y0_x0 = ((x1 - x) * (y1 - y) * (z - z0) * (x1_valid * y1_valid * z0_valid).float()).unsqueeze(1) 174 | w_z1_y0_x1 = ((x - x0) * (y1 - y) * (z - z0) * (x0_valid * y1_valid * z0_valid).float()).unsqueeze(1) 175 | w_z1_y1_x0 = ((x1 - x) * (y - y0) * (z - z0) * (x1_valid * y0_valid * z0_valid).float()).unsqueeze(1) 176 | w_z1_y1_x1 = ((x - x0) * (y - y0) * (z - z0) * (x0_valid * y0_valid * z0_valid).float()).unsqueeze(1) 177 | 178 | output = w_z0_y0_x0 * i_z0_y0_x0 + w_z0_y0_x1 * i_z0_y0_x1 + w_z0_y1_x0 * i_z0_y1_x0 + w_z0_y1_x1 * i_z0_y1_x1 \ 179 | + w_z1_y0_x0 * i_z1_y0_x0 + w_z1_y0_x1 * i_z1_y0_x1 + w_z1_y1_x0 * i_z1_y1_x0 + w_z1_y1_x1 * i_z1_y1_x1 180 | 181 | return output.view(B, N, C).permute(0,2,1).contiguous() 182 | 183 | def index3d_nearest(feat, pts): 184 | ''' 185 | args: 186 | feat: (B, C, D, H, W) 187 | pts: (B, 3, N) 188 | return: 189 | (B, C, N) 190 | ''' 191 | device = feat.device 192 | B, C, D, H, W = feat.size() 193 | _, _, N = pts.size() 194 | 195 | x, y, z = pts[:,0], pts[:,1], pts[:,2] 196 | x = (W-1.0) * (0.5 * x.contiguous().view(-1) + 0.5) 197 | y = (H-1.0) * (0.5 * y.contiguous().view(-1) + 0.5) 198 | z = (D-1.0) * (0.5 * z.contiguous().view(-1) + 0.5) 199 | 200 | x0 = torch.floor(x).int() 201 | y0 = torch.floor(y).int() 202 | z0 = torch.floor(z).int() 203 | 204 | max_x = W - 1 205 | max_y = H - 1 206 | max_z = D - 1 207 | 208 | x0_clamp = torch.clamp(x0, 0, max_x) 209 | y0_clamp = torch.clamp(y0, 0, max_y) 210 | z0_clamp = torch.clamp(z0, 0, max_z) 211 | 212 | s = x - x0.float() 213 | t = y - y0.float() 214 | v = z - z0.float() 215 | 216 | dim3 = W 217 | dim2 = W * H 218 | dim1 = W * H * D 219 | 220 | x0_valid = (x0 <= max_x) & (x0 >= 0) 221 | y0_valid = (y0 <= max_y) & (y0 >= 0) 222 | z0_valid = (z0 <= max_z) & (z0 >= 0) 223 | 224 | base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, N).contiguous().view(-1).to(device) 225 | 226 | base_z0_y0 = base + z0_clamp * dim2 + y0_clamp * dim3 227 | idx_z0_y0_x0 = base_z0_y0 + x0_clamp 228 | 229 | # (B,C,D,H,W) -> (B,D,H,W,C) 230 | im_flat = feat.permute(0,2,3,4,1).contiguous().view(-1, C) 231 | i_z0_y0_x0 = torch.gather(im_flat, 0, idx_z0_y0_x0.unsqueeze(1).expand(-1,C).long()) 232 | 233 | w_z0_y0_x0 = ((x0_valid * y0_valid * z0_valid).float()).unsqueeze(1) 234 | 235 | stv = torch.stack([s.view(B,-1), t.view(B,-1), v.view(B,-1)], 1) 236 | 237 | output = (w_z0_y0_x0 * i_z0_y0_x0).view(B, N, C).permute(0,2,1).contiguous() 238 | 239 | return output, stv-0.5 # (-0.5, 0.5) 240 | 241 | def index3d_nearest_overlap(feat, pts): 242 | ''' 243 | args: 244 | feat: (B, C, D, H, W) 245 | pts: (B, 3, N) 246 | return: 247 | (B, C, N*8) 248 | ''' 249 | device = feat.device 250 | B, C, D, H, W = feat.size() 
251 | _, _, N = pts.size() 252 | 253 | x, y, z = pts[:,0], pts[:,1], pts[:,2] 254 | x = (W-1.0) * (0.5 * x.contiguous().view(-1) + 0.5) 255 | y = (H-1.0) * (0.5 * y.contiguous().view(-1) + 0.5) 256 | z = (D-1.0) * (0.5 * z.contiguous().view(-1) + 0.5) 257 | 258 | x0 = torch.floor(x).int() 259 | y0 = torch.floor(y).int() 260 | z0 = torch.floor(z).int() 261 | 262 | s = x - x0.float() 263 | t = y - y0.float() 264 | v = z - z0.float() 265 | 266 | s_side = (s >= 0.5) 267 | t_side = (t >= 0.5) 268 | v_side = (v >= 0.5) 269 | 270 | xn = x0[:,None].expand(-1,2).contiguous() 271 | yn = y0[:,None].expand(-1,2).contiguous() 272 | zn = z0[:,None].expand(-1,2).contiguous() 273 | 274 | s = s[:,None].expand(-1,2).contiguous() 275 | t = t[:,None].expand(-1,2).contiguous() 276 | v = v[:,None].expand(-1,2).contiguous() 277 | 278 | xn[s_side,1] += 1 279 | xn[~s_side,1] -= 1 280 | yn[t_side,1] += 1 281 | yn[~t_side,1] -= 1 282 | zn[v_side,1] += 1 283 | zn[~v_side,1] -= 1 284 | 285 | s[s_side,1] -= 1.0 286 | s[~s_side,1] += 1.0 287 | t[t_side,1] -= 1.0 288 | t[~t_side,1] += 1.0 289 | v[v_side,1] -= 1.0 290 | v[~v_side,1] += 1.0 291 | 292 | s = s[:,None,None,:].expand(-1,2,2,-1).contiguous().view(-1) 293 | t = t[:,None,:,None].expand(-1,2,-1,2).contiguous().view(-1) 294 | v = v[:,:,None,None].expand(-1,-1,2,2).contiguous().view(-1) 295 | 296 | max_x = W - 1 297 | max_y = H - 1 298 | max_z = D - 1 299 | 300 | xn_clamp = torch.clamp(xn, 0, max_x) 301 | yn_clamp = torch.clamp(yn, 0, max_y) 302 | zn_clamp = torch.clamp(zn, 0, max_z) 303 | 304 | dim3 = W 305 | dim2 = W * H 306 | dim1 = W * H * D 307 | 308 | xn_valid = (xn <= max_x) & (xn >= 0) 309 | yn_valid = (yn <= max_y) & (yn >= 0) 310 | zn_valid = (zn <= max_z) & (zn >= 0) 311 | 312 | base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, N).contiguous().view(-1).to(device) 313 | 314 | base_zn_yn = base[:,None,None,None] + zn_clamp[:,:,None,None] * dim2 + yn_clamp[:,None,:,None] * dim3 315 | idx_zn_yn_xn = base_zn_yn + xn_clamp[:,None,None,:] # (BN, 2, 2, 2) 316 | 317 | # (B,C,D,H,W) -> (B,D,H,W,C) 318 | im_flat = feat.permute(0,2,3,4,1).contiguous().view(-1, C) 319 | i_z0_y0_x0 = torch.gather(im_flat, 0, idx_zn_yn_xn.view(-1,1).expand(-1,C).long()) 320 | 321 | w_z0_y0_x0 = ((xn_valid[:,None,None,:] * yn_valid[:,None,:,None] * zn_valid[:,:,None,None]).float()).view(-1,1) 322 | 323 | stv = torch.stack([s.view(B,-1), t.view(B,-1), v.view(B,-1)], 1) 324 | 325 | output = (w_z0_y0_x0 * i_z0_y0_x0).view(B, N*8, C).permute(0,2,1).contiguous() 326 | 327 | return output, stv-0.5 # (-1.0, 1.0) 328 | 329 | def index3d_nearest_boundary(feat, pts): 330 | ''' 331 | args: 332 | feat: (B, C, D, H, W) 333 | pts: (B, 3, N) 334 | return: 335 | sampled feature (B, C, N*6) 336 | stv (B, 3, N*6) 337 | ''' 338 | device = feat.device 339 | B, C, D, H, W = feat.size() 340 | _, _, N = pts.size() 341 | 342 | x, y, z = pts[:,0], pts[:,1], pts[:,2] 343 | x = (W-1.0) * (0.5 * x.contiguous().view(-1) + 0.5) 344 | y = (H-1.0) * (0.5 * y.contiguous().view(-1) + 0.5) 345 | z = (D-1.0) * (0.5 * z.contiguous().view(-1) + 0.5) 346 | 347 | x0 = torch.floor(x).int() 348 | y0 = torch.floor(y).int() 349 | z0 = torch.floor(z).int() 350 | 351 | s = x - x0.float() 352 | t = y - y0.float() 353 | v = z - z0.float() 354 | 355 | s_side = (s >= 0.5) 356 | t_side = (t >= 0.5) 357 | v_side = (v >= 0.5) 358 | 359 | s_mid = torch.stack([s, s_side.float()], -1) 360 | t_mid = torch.stack([t, t_side.float()], -1) 361 | v_mid = torch.stack([v, v_side.float()], -1) 362 | 363 | s_mid = s_mid[:,[1,0,0]] 364 | t_mid 
= t_mid[:,[0,1,0]] 365 | v_mid = v_mid[:,[0,0,1]] 366 | 367 | xn = x0[:,None].expand(-1,2).contiguous() 368 | yn = y0[:,None].expand(-1,2).contiguous() 369 | zn = z0[:,None].expand(-1,2).contiguous() 370 | 371 | s = s[:,None].expand(-1,2).contiguous() 372 | t = t[:,None].expand(-1,2).contiguous() 373 | v = v[:,None].expand(-1,2).contiguous() 374 | 375 | xn[s_side,1] += 1 376 | xn[~s_side,1] -= 1 377 | yn[t_side,1] += 1 378 | yn[~t_side,1] -= 1 379 | zn[v_side,1] += 1 380 | zn[~v_side,1] -= 1 381 | 382 | s[s_side,1] = 0.0 383 | s[~s_side,1] = 1.0 384 | t[t_side,1] = 0.0 385 | t[~t_side,1] = 1.0 386 | v[v_side,1] = 0.0 387 | v[~v_side,1] = 1.0 388 | 389 | s = s[:,[1,0,0]] 390 | t = t[:,[0,1,0]] 391 | v = v[:,[0,0,1]] 392 | 393 | max_x = W - 1 394 | max_y = H - 1 395 | max_z = D - 1 396 | 397 | xn_clamp = torch.clamp(xn, 0, max_x) 398 | yn_clamp = torch.clamp(yn, 0, max_y) 399 | zn_clamp = torch.clamp(zn, 0, max_z) 400 | 401 | dim3 = W 402 | dim2 = W * H 403 | dim1 = W * H * D 404 | 405 | xn_valid = (xn <= max_x) & (xn >= 0) 406 | yn_valid = (yn <= max_y) & (yn >= 0) 407 | zn_valid = (zn <= max_z) & (zn >= 0) 408 | 409 | base = (dim1 * torch.arange(B).int()).view(B, 1).expand(B, N).contiguous().view(-1).to(device) 410 | 411 | base_zn_yn = base[:,None] + zn_clamp[:,[0,0,0,1]] * dim2 + yn_clamp[:,[0,0,1,0]] * dim3 412 | idx_zn_yn_xn = base_zn_yn + xn_clamp[:,[0,1,0,0]] # (BN, 4) 413 | 414 | # (B,C,D,H,W) -> (B,D,H,W,C) 415 | im_flat = feat.permute(0,2,3,4,1).contiguous().view(-1, C) 416 | i_z0_y0_x0 = torch.gather(im_flat, 0, idx_zn_yn_xn.view(-1,1).expand(-1,C).long()) 417 | w_z0_y0_x0 = ((xn_valid[:,[0,1,0,0]] * yn_valid[:,[0,0,1,0]] * zn_valid[:,[0,0,0,1]]).float()).view(-1,1) 418 | 419 | output = (w_z0_y0_x0 * i_z0_y0_x0).view(B,N,4,C).permute(0,3,1,2) 420 | 421 | out_mid = output[:,:,:,:1].expand(-1,-1,-1,3) 422 | out_bound = output[:,:,:,1:] 423 | 424 | s = torch.cat([s_mid, s], -1) 425 | t = torch.cat([t_mid, t], -1) 426 | v = torch.cat([v_mid, v], -1) 427 | 428 | stv = torch.stack([s.view(B,-1), t.view(B,-1), v.view(B,-1)], 1) 429 | output = torch.cat([out_mid, out_bound], -1) 430 | output = output.view(B,C,-1) 431 | 432 | return output, stv-0.5 # (-1.0, 1.0) 433 | 434 | def index(feat, uv, mode='bilinear'): 435 | ''' 436 | :param feat: [B, C, H, W] image features 437 | :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1] 438 | :return: [B, C, N] image features at the uv coordinates 439 | ''' 440 | uv = uv.transpose(1, 2) # [B, N, 2] 441 | uv = uv.unsqueeze(2) # [B, N, 1, 2] 442 | samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True, mode=mode) # [B, C, N, 1] 443 | return samples[:, :, :, 0] # [B, C, N] 444 | 445 | def index3d(feat, pts, mode='bilinear'): 446 | ''' 447 | :param feat: [B, C, D, H, W] image features 448 | :param pts: [B, 3, N] normalized 3d coordinates, range [-1, 1] 449 | :return: [B, C, N] image features at the pts coordinates 450 | ''' 451 | pts = pts.transpose(1, 2) # [B, N, 3] 452 | pts = pts[:,:,None,None] # [B, N, 1, 1, 3] 453 | samples = torch.nn.functional.grid_sample(feat, pts, align_corners=True, mode=mode) # [B, C, N, 1, 1] 454 | return samples[:, :, :, 0, 0] # [B, C, N] 455 | 456 | def orthogonal(points, calibrations): 457 | ''' 458 | Compute the orthogonal projections of 3D points into the image plane by given projection matrix 459 | :param points: [B, 3, N] Tensor of 3D points 460 | :param calibrations: [B, 3, 4] Tensor of projection matrix 461 | :param transforms: [B, 2, 3] Tensor of image transform matrix 462 | :return: 
xyz: [B, 3, N] Tensor of xyz coordinates in the image plane 463 | ''' 464 | rot = calibrations[:, :3, :3] 465 | trans = calibrations[:, :3, 3:4] 466 | pts = torch.baddbmm(trans, rot, points) # [B, 3, N] 467 | return pts 468 | 469 | 470 | def perspective(points, calibrations): 471 | ''' 472 | Compute the perspective projections of 3D points into the image plane by given projection matrix 473 | :param points: [Bx3xN] Tensor of 3D points 474 | :param calibrations: [Bx3x4] Tensor of projection matrix 475 | :param transforms: [Bx2x3] Tensor of image transform matrix 476 | :return: xy: [Bx2xN] Tensor of xy coordinates in the image plane 477 | ''' 478 | rot = calibrations[:, :3, :3] 479 | trans = calibrations[:, :3, 3:4] 480 | homo = torch.baddbmm(trans, rot, points) # [B, 3, N] 481 | xy = homo[:, :2, :] / homo[:, 2:3, :] 482 | xyz = torch.cat([xy, homo[:, 2:3, :]], 1) 483 | return xyz 484 | -------------------------------------------------------------------------------- /lib/model/BaseIMNet3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import os 18 | import copy 19 | import math 20 | import torch 21 | import torch.autograd as autograd 22 | import torch.nn as nn 23 | from torch.nn import init 24 | import torch.nn.functional as F 25 | from .MLP import MLP 26 | 27 | from ..net_util import init_net, load_network, get_embedder, init_mlp_siren 28 | from torch.nn import init 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | import torch.autograd as autograd 32 | # import functools 33 | 34 | import numpy as np 35 | from ..mesh_util import save_obj_mesh_with_color, save_obj_mesh 36 | from ..geometry import index3d, index3d_custom 37 | 38 | class BaseIMNet3d(nn.Module): 39 | def __init__(self, 40 | opt, 41 | bbox_min=[-1.0,-1.0,-1.0], 42 | bbox_max=[1.0,1.0,1.0] 43 | ): 44 | super(BaseIMNet3d, self).__init__() 45 | 46 | self.body_centric_encoding = False if opt['mlp']['ch_dim'][0] == 3 else True 47 | 48 | self.name = 'base_imnet3d' 49 | self.opt = copy.deepcopy(opt) 50 | 51 | if opt['use_embed']: 52 | self.embedder, self.opt['mlp']['ch_dim'][0] = get_embedder(opt['d_size'], input_dims=opt['mlp']['ch_dim'][0]) 53 | else: 54 | self.embedder = None 55 | 56 | if 'g_dim' in self.opt: 57 | self.opt['mlp']['ch_dim'][0] += self.opt['g_dim'] 58 | if 'pose_dim' in self.opt: 59 | self.opt['mlp']['ch_dim'][0] += self.opt['pose_dim'] * 23 60 | 61 | self.mlp = MLP( 62 | filter_channels=self.opt['mlp']['ch_dim'], 63 | res_layers=self.opt['mlp']['res_layers'], 64 | last_op=self.opt['mlp']['last_op'], 65 | nlactiv=self.opt['mlp']['nlactiv'], 66 | norm=self.opt['mlp']['norm']) 67 | 68 | init_net(self) 69 | 70 | if self.opt['mlp']['nlactiv'] == 'sin': # SIREN 71 | self.mlp.apply(init_mlp_siren) 72 | 73 | self.register_buffer('bbox_min', torch.Tensor(bbox_min)[None,:,None]) 74 | self.register_buffer('bbox_max', torch.Tensor(bbox_max)[None,:,None]) 75 | 76 | self.feat3d = None 77 | self.global_feat = None 78 | 79 | def filter(self, feat): 80 | ''' 81 | Store 3d feature 82 | args: 83 | feat: (B, C, D, H, W) 84 | ''' 85 | self.feat3d = feat 86 | 87 | def set_global_feat(self, feat): 88 | self.global_feat = feat 89 | 90 | def query(self, points, calib_tensor=None, bmin=None, bmax=None): 91 | ''' 92 | Given 3D points, query the network predictions for each point. 
93 | args: 94 | points: (B, 3, N) 95 | return: 96 | (B, C, N) 97 | ''' 98 | N = points.size(2) 99 | 100 | if bmin is None: 101 | bmin = self.bbox_min 102 | if bmax is None: 103 | bmax = self.bbox_max 104 | points_nc3d = 2.0 * (points - bmin) / (bmax - bmin) - 1.0 # normalized coordiante 105 | # points_nc3d = 1.0*points 106 | if self.feat3d is not None and self.body_centric_encoding: 107 | point_local_feat = index3d_custom(self.feat3d, points_nc3d) 108 | else: # not body_centric_encoding 109 | point_local_feat = points_nc3d 110 | 111 | if self.embedder is not None: 112 | point_local_feat = self.embedder(point_local_feat.permute(0,2,1)).permute(0,2,1) 113 | 114 | if self.global_feat is not None: 115 | point_local_feat = torch.cat([point_local_feat, self.global_feat[:,:,None].expand(-1,-1,N)], 1) 116 | 117 | w0 = 30.0 if self.opt['mlp']['nlactiv'] == 'sin' else 1.0 118 | 119 | return self.mlp(w0*point_local_feat) 120 | 121 | # for debug 122 | def get_point_feat(self, feat, points, custom_index=True, bmin=None, bmax=None): 123 | if bmin is None: 124 | bmin = self.bbox_min 125 | if bmax is None: 126 | bmax = self.bbox_max 127 | 128 | points_nc3d = 2.0 * (points - bmin) / (bmax - bmin) - 1.0 # normalized coordiante/ 129 | # points_nc3d = 1.0*points/ 130 | if custom_index: 131 | return index3d_custom(feat, points_nc3d) 132 | else: 133 | return index3d(feat, points_nc3d) 134 | 135 | def forward(self, feat, points): 136 | # Set 3d feature 137 | self.filter(feat) 138 | 139 | return self.query(points) 140 | -------------------------------------------------------------------------------- /lib/model/IGRSDFNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import os 18 | import math 19 | import torch 20 | import torch.autograd as autograd 21 | import torch.nn as nn 22 | from torch.nn import init 23 | import torch.nn.functional as F 24 | from .MLP import MLP 25 | 26 | from ..net_util import init_net, load_network, get_embedder, init_mlp_geometric, init_mlp_siren 27 | from torch.nn import init 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import torch.autograd as autograd 31 | # import functools 32 | 33 | import numpy as np 34 | from ..mesh_util import save_obj_mesh_with_color, save_obj_mesh, scalar_to_color 35 | from ..geometry import index3d, index3d_custom 36 | from .BaseIMNet3d import BaseIMNet3d 37 | 38 | from ..net_util import get_embedder 39 | 40 | class IGRSDFNet(BaseIMNet3d): 41 | def __init__(self, 42 | opt, 43 | bbox_min, 44 | bbox_max, 45 | pose_map 46 | ): 47 | super(IGRSDFNet, self).__init__(opt, bbox_min, bbox_max) 48 | 49 | self.name = 'neural_sdf_bpsigr' 50 | 51 | if opt['mlp']['nlactiv'] != 'sin': # SIREN 52 | self.mlp.apply(init_mlp_geometric) 53 | 54 | self.body_centric_encoding = True if opt['mlp']['ch_dim'][0] == 7 else False 55 | 56 | self.lbs_net = None 57 | self.pose_feat = None 58 | 59 | if opt['learn_posemap']: 60 | self.register_buffer('pose_map_init', pose_map) 61 | self.register_parameter('pose_map', nn.Parameter(pose_map)) 62 | else: 63 | self.register_buffer('pose_map', pose_map) 64 | 65 | self.bbox_regularization = True if not opt['lambda_bbox'] == 0 else False 66 | self.space_non_zero_regu = True if not opt['lambda_non_zero'] == 0 else False 67 | 68 | def set_lbsnet(self, net): 69 | self.lbs_net = net 70 | 71 | def set_pose_feat(self, pose_feat): 72 | self.pose_feat = pose_feat 73 | 74 | def filter(self, feat): 75 | self.feat3d = feat 76 | if self.lbs_net is not None: 77 | self.lbs_net.filter(feat) 78 | 79 | def query(self, points, calib_tensor=None, return_negative=True, update_lbs=False, bmin=None, bmax=None, return_last_layer_feature = False): 80 | ''' 81 | Given 3D points, query the network predictions for each point. 
82 | args: 83 | points: (B, 3, N) 84 | return: 85 | (B, C, N) 86 | ''' 87 | N = points.size(2) 88 | 89 | if self.lbs_net is not None: 90 | self.lbs_net.filter(self.feat3d) 91 | # NOTE: the first value belongs to root 92 | lbs = self.lbs_net.query(points, bmin=bmin, bmax=bmax) 93 | if self.opt['learn_posemap']: 94 | lbs = torch.einsum('bjv,jl->blv', lbs, F.softmax(self.pose_map,dim=0)) 95 | else: 96 | lbs = torch.einsum('bjv,jl->blv', lbs, self.pose_map) 97 | if not update_lbs: 98 | lbs = lbs.detach() 99 | 100 | if bmin is None: 101 | bmin = self.bbox_min 102 | if bmax is None: 103 | bmax = self.bbox_max 104 | 105 | points_nc3d = 2.0 * (points - bmin) / (bmax - bmin) - 1.0 # normalized coordiante 106 | points_nc3d = points_nc3d.clamp(min=-1.0, max=1.0) 107 | # points_nc3d = 1.0 * points 108 | 109 | in_bbox = (points_nc3d[:, 0] >= -1.0) & (points_nc3d[:, 0] <= 1.0) &\ 110 | (points_nc3d[:, 1] >= -1.0) & (points_nc3d[:, 1] <= 1.0) &\ 111 | (points_nc3d[:, 2] >= -1.0) & (points_nc3d[:, 2] <= 1.0) 112 | in_bbox = in_bbox[:,None].float() 113 | 114 | if self.feat3d is not None and self.body_centric_encoding: 115 | point_local_feat = index3d(self.feat3d, points_nc3d) 116 | else: 117 | point_local_feat = points_nc3d 118 | 119 | if self.embedder is not None: 120 | point_local_feat = self.embedder(point_local_feat.permute(0,2,1)).permute(0,2,1) 121 | 122 | if self.global_feat is not None: 123 | global_feat = self.global_feat[:,:,None].expand(-1,-1,N) 124 | point_local_feat = torch.cat([point_local_feat, global_feat], 1) 125 | 126 | if self.pose_feat is not None: 127 | if self.lbs_net is not None: 128 | pose_feat = self.pose_feat.view(self.pose_feat.size(0),-1,self.opt['pose_dim'],1) * lbs[:,:,None] 129 | # Use entire feature 130 | if 'full_pose' in self.opt.keys(): 131 | if self.opt['full_pose']: 132 | pose_feat = self.pose_feat.view(self.pose_feat.size(0),-1,self.opt['pose_dim'],1) * torch.ones_like(lbs[:,:,None]) 133 | pose_feat = pose_feat.reshape(pose_feat.size(0),-1,N) 134 | else: 135 | pose_feat = self.pose_feat[:,:,None].expand(-1,-1,N) 136 | point_local_feat = torch.cat([point_local_feat, pose_feat], 1) 137 | 138 | w0 = 30.0 if self.opt['mlp']['nlactiv'] == 'sin' else 1.0 139 | 140 | if not return_last_layer_feature: 141 | if return_negative: 142 | return -in_bbox*self.mlp(w0*point_local_feat)-(1.0-in_bbox) 143 | else: 144 | return in_bbox*self.mlp(w0*point_local_feat)+(1.0-in_bbox) 145 | else: 146 | if return_negative: 147 | sdf, last_layer_feature = self.mlp(w0*point_local_feat, return_last_layer_feature = True) 148 | sdf = -in_bbox*sdf-(1.0-in_bbox) 149 | else: 150 | sdf, last_layer_feature = self.mlp(w0*point_local_feat, return_last_layer_feature = True) 151 | sdf = in_bbox*sdf+(1.0-in_bbox) 152 | return sdf, last_layer_feature, point_local_feat 153 | 154 | def compute_normal(self, points, normalize=False, return_pred=False, custom_index=False, update_lbs=False, bmin=None, bmax=None): 155 | ''' 156 | since image sampling operation does not have second order derivative, 157 | normal can be computed only via finite difference (forward differentiation) 158 | ''' 159 | N = points.size(2) 160 | 161 | with torch.enable_grad(): 162 | points.requires_grad_() 163 | 164 | if self.lbs_net is not None: 165 | self.lbs_net.filter(self.feat3d) 166 | # NOTE: the first value belongs to root 167 | lbs = self.lbs_net.query(points, bmin=bmin, bmax=bmax) 168 | if self.opt['learn_posemap']: 169 | lbs = torch.einsum('bjv,jl->blv', lbs, F.softmax(self.pose_map,dim=0)) 170 | else: 171 | lbs = 
torch.einsum('bjv,jl->blv', lbs, self.pose_map) 172 | if not update_lbs: 173 | lbs = lbs.detach() 174 | 175 | if bmin is None: 176 | bmin = self.bbox_min 177 | if bmax is None: 178 | bmax = self.bbox_max 179 | points_nc3d = 2.0 * (points - bmin) / (bmax - bmin) - 1.0 # normalized coordiante 180 | points_nc3d = points_nc3d.clamp(min=-1.0, max=1.0) 181 | # points_nc3d = 1.0 * points 182 | 183 | if self.feat3d is None: 184 | point_local_feat = points_nc3d 185 | else: 186 | if custom_index: 187 | point_local_feat = index3d_custom(self.feat3d, points_nc3d) 188 | else: 189 | point_local_feat = index3d(self.feat3d, points_nc3d) 190 | 191 | if not self.body_centric_encoding: 192 | point_local_feat = points_nc3d 193 | 194 | if self.embedder is not None: 195 | point_local_feat = self.embedder(point_local_feat.permute(0,2,1)).permute(0,2,1) 196 | 197 | if self.global_feat is not None: 198 | global_feat = self.global_feat[:,:,None].expand(-1,-1,N) 199 | point_local_feat = torch.cat([point_local_feat, global_feat], 1) 200 | 201 | if self.pose_feat is not None: 202 | if self.lbs_net is not None: 203 | pose_feat = self.pose_feat.view(self.pose_feat.size(0),-1,self.opt['pose_dim'],1) * lbs[:,:,None] 204 | # Use entire feature 205 | if 'full_pose' in self.opt.keys(): 206 | if self.opt['full_pose']: 207 | pose_feat = self.pose_feat.view(self.pose_feat.size(0),-1,self.opt['pose_dim'],1) * torch.ones_like(lbs[:,:,None]) 208 | 209 | pose_feat = pose_feat.reshape(pose_feat.size(0),-1,N) 210 | else: 211 | pose_feat = self.pose_feat[:,:,None].expand(-1,-1,N) 212 | point_local_feat = torch.cat([point_local_feat, pose_feat], 1) 213 | 214 | w0 = 30.0 if self.opt['mlp']['nlactiv'] == 'sin' else 1.0 215 | 216 | pred = self.mlp(w0*point_local_feat) 217 | normal = autograd.grad( 218 | [pred.sum()], [points], 219 | create_graph=True, retain_graph=True, only_inputs=True)[0] 220 | 221 | if normalize: 222 | normal = F.normalize(normal, dim=1, eps=1e-6) 223 | 224 | if return_pred: 225 | return normal, pred 226 | else: 227 | return normal 228 | 229 | def get_error(self, res): 230 | ''' 231 | based on https://arxiv.org/pdf/2002.10099.pdf 232 | ''' 233 | err_dict = {} 234 | 235 | error_ls = self.opt['lambda_sdf'] * nn.L1Loss()(res['sdf_surface'], torch.zeros_like(res['sdf_surface'])) 236 | error_nml = self.opt['lambda_nml'] * torch.norm(res['nml_surface'] - res['nml_gt'], p=2, dim=1).mean() 237 | 238 | nml_reg = torch.cat((res['nml_surface'], res['nml_igr']), dim=2) 239 | # error_reg = self.opt['lambda_reg'] * (torch.norm(res['nml_igr'], p=2, dim=1) - 1).mean().pow(2) 240 | error_reg = self.opt['lambda_reg'] * (torch.norm(nml_reg, p=2, dim=1) - 1).pow(2).mean() 241 | 242 | err_dict['LS'] = error_ls.item() 243 | err_dict['N'] = error_nml.item() 244 | err_dict['R'] = error_reg.item() 245 | error = error_ls + error_nml + error_reg 246 | 247 | if self.bbox_regularization: 248 | error_bbox = self.opt['lambda_bbox'] * F.leaky_relu(res['sdf_bound'], 1e-6, inplace=True).mean() 249 | err_dict['BB'] = error_bbox.item() 250 | error += error_bbox 251 | 252 | if self.space_non_zero_regu: 253 | error_non_zero = self.opt['lambda_non_zero'] * torch.exp(-100.0*torch.abs(res['sdf_igr'])).mean() 254 | err_dict['NZ'] = error_non_zero.item() 255 | error += error_non_zero 256 | 257 | if self.pose_map.requires_grad: 258 | error_pose_map = self.opt['lambda_pmap'] * (F.softmax(self.pose_map,dim=0)-self.pose_map_init).abs().mean() 259 | err_dict['PMap'] = error_pose_map.item() 260 | error += error_pose_map 261 | 262 | if self.global_feat is not None: 
263 | error_lat = self.opt['lambda_lat'] * torch.norm(self.global_feat, dim=1).mean() 264 | err_dict['z-sdf'] = error_lat.item() 265 | error += error_lat 266 | 267 | return error, err_dict 268 | 269 | 270 | def forward(self, feat, pts_surface, pts_body, pts_bbox, normals, bmin=None, bmax=None): 271 | ''' 272 | args: 273 | feat: (B, C, D, H, W) 274 | pts_surface: (B, 3, N) 275 | pts_body: (B, 3, N*) 276 | pts_bbox: (B, 3, N**) 277 | normals: (B, 3, N) 278 | ''' 279 | # set volumetric feature 280 | self.filter(feat) 281 | nml_surface, sdf_surface = self.compute_normal(points=pts_surface, 282 | normalize=False, 283 | return_pred=True, 284 | custom_index=True, 285 | update_lbs=True, 286 | bmin=bmin, bmax=bmax) 287 | if self.bbox_regularization: 288 | with torch.no_grad(): 289 | bbox_xmin = pts_bbox[:,:,:self.opt['n_bound']].clone() 290 | bbox_xmin[:, 0] = self.bbox_min[0,0,0] 291 | bbox_ymin = pts_bbox[:,:,:self.opt['n_bound']].clone() 292 | bbox_ymin[:, 1] = self.bbox_min[0,1,0] 293 | bbox_zmin = pts_bbox[:,:,:self.opt['n_bound']].clone() 294 | bbox_zmin[:, 2] = self.bbox_min[0,2,0] 295 | bbox_xmax = pts_bbox[:,:,:self.opt['n_bound']].clone() 296 | bbox_xmax[:, 0] = self.bbox_max[0,0,0] 297 | bbox_ymax = pts_bbox[:,:,:self.opt['n_bound']].clone() 298 | bbox_ymax[:, 1] = self.bbox_max[0,1,0] 299 | bbox_zmax = pts_bbox[:,:,:self.opt['n_bound']].clone() 300 | bbox_zmax[:, 2] = self.bbox_max[0,2,0] 301 | 302 | pts_bound = torch.cat([bbox_xmin, bbox_ymin, bbox_zmin, bbox_xmax, bbox_ymax, bbox_zmax],-1) 303 | 304 | sdf_bound = self.query(pts_bound, bmin=bmin, bmax=bmax) 305 | 306 | pts_igr = torch.cat([pts_body, pts_bbox], 2) 307 | nml_igr, sdf_igr = self.compute_normal(points=pts_igr, 308 | normalize=False, 309 | return_pred=True, 310 | custom_index=True, 311 | bmin=bmin, bmax=bmax) 312 | 313 | res = {'sdf_surface': sdf_surface, 'sdf_igr': sdf_igr[:,:,pts_body.shape[2]:], 'nml_surface': nml_surface, 314 | 'nml_igr': nml_igr, 'nml_gt': normals} 315 | 316 | if self.bbox_regularization: 317 | res['sdf_bound'] = sdf_bound 318 | 319 | # get the error 320 | error, err_dict = self.get_error(res) 321 | 322 | return res, error, err_dict 323 | -------------------------------------------------------------------------------- /lib/model/LBSNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | 20 | import numpy as np 21 | from .BaseIMNet3d import BaseIMNet3d 22 | from ..net_util import homogenize 23 | 24 | class LBSNet(BaseIMNet3d): 25 | def __init__(self, 26 | opt, 27 | bbox_min, 28 | bbox_max, 29 | posed=False 30 | ): 31 | super(LBSNet, self).__init__(opt, bbox_min, bbox_max) 32 | self.name = 'lbs_net' 33 | self.source_space = 'posed' if posed else 'cano' 34 | 35 | def get_error(self, res): 36 | err_dict = {} 37 | 38 | errLBS_SMPL = self.opt['lambda_smpl'] * (res['pred_lbs_smpl_%s' % self.source_space]-res['gt_lbs_smpl']).pow(2).mean() 39 | err_dict['SW-SMPL/%s' % self.source_space[0]] = errLBS_SMPL.item() 40 | error = errLBS_SMPL 41 | 42 | if ('reference_lbs_scan_%s' % self.source_space) in res: 43 | errLBS_SCAN = self.opt['lambda_scan'] * (res['pred_lbs_scan_%s' % self.source_space]-res['reference_lbs_scan_%s' % self.source_space]).pow(2).mean() 44 | err_dict['SW-SCAN/%s' % self.source_space[0]] = errLBS_SCAN.item() 45 | error += errLBS_SCAN 46 | if 'pred_smpl_posed' in res and 'gt_smpl_posed' in res: 47 | errCyc_SMPL = self.opt['lambda_cyc_smpl'] * (res['pred_smpl_posed'] - res['gt_smpl_posed']).abs().mean() 48 | err_dict['Cy-SMPL'] = errCyc_SMPL.item() 49 | error += errCyc_SMPL 50 | if ('tar_edge_%s' % self.source_space) in res and ('src_edge_%s' % self.source_space) in res: 51 | errEdge = self.opt['lambda_l_edge'] * (res['w_tri'][:,None]*(1.0 - res['src_edge_%s' % self.source_space] / (res['tar_edge_%s' % self.source_space]+1e-8)).abs()).mean() 52 | err_dict['L-Edge'] = errEdge.item() 53 | error += errEdge 54 | if ('pred_lbs_tri_%s' % self.source_space) in res: 55 | pred_lbs_tri = res['pred_lbs_tri_%s' % self.source_space] 56 | le1 = (pred_lbs_tri[:,:,:,0] - pred_lbs_tri[:,:,:,1]).abs().sum(1) 57 | le2 = (pred_lbs_tri[:,:,:,1] - pred_lbs_tri[:,:,:,2]).abs().sum(1) 58 | le3 = (pred_lbs_tri[:,:,:,2] - pred_lbs_tri[:,:,:,0]).abs().sum(1) 59 | errEdge = self.opt['lambda_w_edge'] * (res['w_tri'] * (le1 + le2 + le3)).mean() 60 | err_dict['SW-Edge'] = errEdge.item() 61 | error += errEdge 62 | if 'pred_smpl_cano' in res and 'gt_smpl_cano' in res: 63 | errCyc_SMPL = self.opt['lambda_cyc_smpl'] * (res['pred_smpl_cano'] - res['gt_smpl_cano']).abs().mean() 64 | if 'Cy(SMPL)' in err_dict: 65 | err_dict['Cy-SMPL'] += errCyc_SMPL.item() 66 | else: 67 | err_dict['Cy-SMPL'] = errCyc_SMPL.item() 68 | error += errCyc_SMPL 69 | if 'pred_scan_posed' in res and 'gt_scan_posed' in res: 70 | errCyc_SCAN = self.opt['lambda_cyc_scan'] * (res['pred_scan_posed'] - res['gt_scan_posed']).abs().mean() 71 | err_dict['Cy-SCAN'] = errCyc_SCAN.item() 72 | error += errCyc_SCAN 73 | if 'pred_lbs_scan_cano' in res and 'pred_lbs_scan_posed' in res: 74 | errLBS_SCAN = self.opt['lambda_scan'] * (res['pred_lbs_scan_cano']-res['pred_lbs_scan_posed']).pow(2).sum(1).mean() 75 | err_dict['SW-SCAN-Cy'] = errLBS_SCAN.item() 76 | error += errLBS_SCAN 77 | 78 | if self.source_space == 'posed' and 'pred_lbs_scan_posed' in res: 79 | errSparse = self.opt['lambda_sparse'] * (res['pred_lbs_scan_posed'].abs()+1e-12).pow(self.opt['p_val']).sum(1).mean() 80 | err_dict['Sprs'] = errSparse.item() 81 | error += errSparse 82 | 83 | if self.global_feat is not None: 84 | error_lat = self.opt['lambda_lat'] * torch.norm(self.global_feat, dim=1).mean() 85 | err_dict['z-lbs/%s' % self.source_space[0]] = error_lat.item() 86 | error += error_lat 87 | 88 | return error, err_dict 89 | 90 | def forward(self, feat, smpl, 
gt_lbs_smpl=None, scan=None, reference_lbs_scan=None, jT=None, res_posed=None, nml_scan=None, v_tri=None, w_tri=None, bmin=None, bmax=None): 91 | B = smpl.shape[0] 92 | 93 | if self.body_centric_encoding: 94 | # set volumetric feature 95 | self.filter(feat) # In case it is body centric encoding 96 | 97 | pred_lbs_smpl = self.query(smpl, bmin=bmin, bmax=bmax) 98 | 99 | res = {} 100 | if res_posed is not None: 101 | res = res_posed 102 | res['pred_lbs_smpl_%s' % self.source_space] = pred_lbs_smpl 103 | 104 | space_transformed_to = 'cano' if self.source_space == 'posed' else 'posed' 105 | if jT is not None: 106 | pred_vT = torch.einsum('bjst,bjv->bvst', jT, pred_lbs_smpl) 107 | pred_vT[:,:,3,3] = 1.0 108 | if self.source_space == 'posed': 109 | pred_vT = torch.inverse(pred_vT.reshape(-1,4,4)).view(B,-1,4,4) 110 | smpl_transformed = torch.einsum('bvst,btv->bsv', pred_vT, homogenize(smpl,1))[:,:3,:] 111 | res['pred_smpl_%s' % space_transformed_to] = smpl_transformed 112 | 113 | if scan is None and gt_lbs_smpl is None: 114 | return res 115 | 116 | res['gt_smpl_%s' % self.source_space] = smpl 117 | if gt_lbs_smpl is not None: 118 | res['gt_lbs_smpl'] = gt_lbs_smpl 119 | if reference_lbs_scan is not None: 120 | res['reference_lbs_scan_%s' % self.source_space] = reference_lbs_scan 121 | if scan is not None: 122 | pred_lbs_scan = self.query(scan, bmin=bmin, bmax=bmax) 123 | res['pred_lbs_scan_%s' % self.source_space] = pred_lbs_scan 124 | if jT is not None: 125 | pred_vT = torch.einsum('bjst,bjv->bvst', jT, pred_lbs_scan) 126 | pred_vT[:,:,3,3] = 1.0 127 | if space_transformed_to == 'cano': 128 | pred_vT = torch.inverse(pred_vT.reshape(-1,4,4)).view(B,-1,4,4) 129 | res['gt_scan_posed'] = scan 130 | if nml_scan is not None: 131 | nml_T = torch.einsum('bvst,btv->bsv', pred_vT[:,:,:3,:3], nml_scan) 132 | nml_T = F.normalize(nml_T, dim=1) 133 | res['normal_scan_cano'] = nml_T 134 | res['pred_scan_%s' % space_transformed_to] = torch.einsum('bvst,btv->bsv', pred_vT, homogenize(scan,1))[:,:3,:] 135 | if v_tri is not None and jT is not None: 136 | v_tri_reshape = v_tri.view(B,3,-1,3) 137 | e1 = torch.norm(v_tri_reshape[:,:,:,0] - v_tri_reshape[:,:,:,1], p=2, dim=1, keepdim=True) 138 | e2 = torch.norm(v_tri_reshape[:,:,:,1] - v_tri_reshape[:,:,:,2], p=2, dim=1, keepdim=True) 139 | e3 = torch.norm(v_tri_reshape[:,:,:,2] - v_tri_reshape[:,:,:,0], p=2, dim=1, keepdim=True) 140 | e = torch.cat([e1,e2,e3], 1) 141 | res['tar_edge_%s' % self.source_space] = e 142 | pred_lbs_tri = self.query(v_tri, bmin=bmin, bmax=bmax) 143 | pred_vT = torch.einsum('bjst,bjv->bvst', jT, pred_lbs_tri) 144 | pred_vT[:,:,3,3] = 1.0 145 | if space_transformed_to == 'cano': 146 | pred_vT = torch.inverse(pred_vT.reshape(-1,4,4)).view(B,-1,4,4) 147 | pred_tri = torch.einsum('bvst,btv->bsv', pred_vT, homogenize(v_tri,1))[:,:3,:].view(B,3,-1,3) 148 | E1 = torch.norm(pred_tri[:,:,:,0] - pred_tri[:,:,:,1], p=2, dim=1, keepdim=True) 149 | E2 = torch.norm(pred_tri[:,:,:,1] - pred_tri[:,:,:,2], p=2, dim=1, keepdim=True) 150 | E3 = torch.norm(pred_tri[:,:,:,2] - pred_tri[:,:,:,0], p=2, dim=1, keepdim=True) 151 | E = torch.cat([E1,E2,E3], 1) 152 | res['src_edge_%s' % self.source_space] = E 153 | res['pred_lbs_tri_%s' % self.source_space] = pred_lbs_tri.view(B,pred_lbs_tri.shape[1],-1,3) 154 | if w_tri is not None: 155 | res['w_tri'] = w_tri 156 | 157 | # get the error 158 | error, err_dict = self.get_error(res) 159 | 160 | return res, error, err_dict 161 | -------------------------------------------------------------------------------- 
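Note: the transform blending used throughout `LBSNet.forward` is standard linear blend skinning expressed with `torch.einsum`: predicted per-point skinning weights mix the per-joint 4x4 transforms `jT`, and the blended matrix is applied to homogenized points. The following is a minimal, self-contained sketch of that einsum pattern only; batch size, joint count and point count are made up for illustration and are not taken from the configs.

```python
import torch

# Illustrative sizes only: batch 2, 24 joints, 100 points.
B, J, N = 2, 24, 100
jT = torch.eye(4).repeat(B, J, 1, 1)           # (B, J, 4, 4) per-joint transforms
lbs = torch.softmax(torch.randn(B, J, N), 1)   # (B, J, N) skinning weights, sum to 1 over joints
pts = torch.randn(B, 3, N)                     # (B, 3, N) points to be warped

# Blend the joint transforms per point: (B, N, 4, 4)
vT = torch.einsum('bjst,bjv->bvst', jT, lbs)
vT[:, :, 3, 3] = 1.0

# Apply each per-point matrix to the homogenized point, as in LBSNet.forward.
pts_h = torch.cat([pts, torch.ones_like(pts[:, :1])], 1)      # (B, 4, N)
pts_warped = torch.einsum('bvst,btv->bsv', vT, pts_h)[:, :3]  # (B, 3, N)
print(pts_warped.shape)  # torch.Size([2, 3, 100])
```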
/lib/model/MLP.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from ..net_util import Mish, Sin 22 | 23 | class MLP(nn.Module): 24 | def __init__(self, filter_channels, res_layers=[], last_op=None, nlactiv='leakyrelu', norm='none'): 25 | super(MLP, self).__init__() 26 | 27 | self.filters = nn.ModuleList() 28 | 29 | if last_op == 'sigmoid': 30 | self.last_op = nn.Sigmoid() 31 | elif last_op == 'tanh': 32 | self.last_op = nn.Tanh() 33 | elif last_op == 'softmax': 34 | self.last_op = nn.Softmax(dim=1) 35 | else: 36 | self.last_op = None 37 | 38 | self.res_layers = res_layers 39 | for l in range(0, len(filter_channels) - 1): 40 | if l in res_layers: 41 | if norm == 'weight' and l != len(filter_channels) - 2: 42 | self.filters.append( 43 | nn.utils.weight_norm(nn.Conv1d( 44 | filter_channels[l] + filter_channels[0], 45 | filter_channels[l + 1], 46 | 1))) 47 | else: 48 | self.filters.append( 49 | nn.Conv1d( 50 | filter_channels[l] + filter_channels[0], 51 | filter_channels[l + 1], 52 | 1)) 53 | else: 54 | if norm == 'weight' and l != len(filter_channels) - 2: 55 | self.filters.append(nn.utils.weight_norm(nn.Conv1d( 56 | filter_channels[l], 57 | filter_channels[l + 1], 58 | 1))) 59 | else: 60 | self.filters.append(nn.Conv1d( 61 | filter_channels[l], 62 | filter_channels[l + 1], 63 | 1)) 64 | 65 | self.nlactiv = None 66 | if nlactiv == 'leakyrelu': 67 | self.nlactiv = nn.LeakyReLU() 68 | elif nlactiv == 'softplus': 69 | self.nlactiv = nn.Softplus(beta=100, threshold=20) 70 | elif nlactiv == 'relu': 71 | self.nlactiv = nn.ReLU() 72 | elif nlactiv == 'mish': 73 | self.nlactiv = Mish() 74 | elif nlactiv == 'elu': 75 | self.nlactiv = nn.ELU(0.1) 76 | elif nlactiv == 'sin': 77 | self.nlactiv = Sin() 78 | 79 | def forward(self, feature, return_last_layer_feature = False): 80 | ''' 81 | :param feature: list of [BxC_inxN] tensors of image features 82 | :param xy: [Bx3xN] tensor of (x,y) coodinates in the image plane 83 | :return: [BxC_outxN] tensor of features extracted at the coordinates 84 | ''' 85 | 86 | y = feature 87 | y0 = feature 88 | last_layer_feature = None 89 | for i, f in enumerate(self.filters): 90 | if i in self.res_layers: 91 | y = f(torch.cat([y, y0], 1)) 92 | else: 93 | y = f(y) 94 | 95 | if i != len(self.filters) - 1 and self.nlactiv is not None: 96 | y = self.nlactiv(y) 97 | 98 | if i == len(self.filters) - 2 and return_last_layer_feature: 99 | last_layer_feature = y.clone() 100 | last_layer_feature = last_layer_feature.detach() 101 | 102 | if self.last_op: 103 | y = self.last_op(y) 104 | 105 | if not return_last_layer_feature: 106 | return y 107 | else: 108 | return y, last_layer_feature 109 | 
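Usage note for the point-wise MLP above: it operates on (B, C_in, N) tensors via 1x1 `Conv1d` layers, optionally re-concatenating the input features at the layers listed in `res_layers`, and can also return the penultimate-layer features. A minimal sketch, assuming the repo root is on `PYTHONPATH` with its dependencies (e.g. pytorch3d) installed; the channel sizes here are made up, the real ones come from the YAML configs.

```python
import torch
from lib.model.MLP import MLP

# Hypothetical channel configuration for illustration only.
mlp = MLP(filter_channels=[3, 128, 128, 4],
          res_layers=[2],        # layer 2 re-concatenates the raw input features
          last_op='softmax',     # per-point distribution over the 4 output channels
          nlactiv='softplus',
          norm='weight')

points = torch.randn(8, 3, 1024)                          # (B, C_in, N)
out = mlp(points)                                         # (B, 4, 1024)
out, feat = mlp(points, return_last_layer_feature=True)   # feat: (B, 128, 1024), detached
```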
-------------------------------------------------------------------------------- /lib/model/TNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import os 18 | import math 19 | import torch 20 | import torch.autograd as autograd 21 | import torch.nn as nn 22 | from torch.nn import init 23 | import torch.nn.functional as F 24 | from .MLP import MLP 25 | 26 | from ..net_util import init_net, load_network, get_embedder, init_mlp_geometric, init_mlp_siren 27 | from torch.nn import init 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import torch.autograd as autograd 31 | 32 | import numpy as np 33 | from ..mesh_util import save_obj_mesh_with_color, save_obj_mesh, scalar_to_color 34 | from ..geometry import index3d, index3d_custom 35 | 36 | from ..net_util import get_embedder 37 | 38 | class TNet(nn.Module): 39 | def __init__(self, 40 | opt = None 41 | ): 42 | super(TNet, self).__init__() 43 | 44 | self.name = 'color_net' 45 | 46 | if opt is None: 47 | opt = { 48 | 'use_embed': True, 49 | 'd_size': 5, 50 | 'mlp':{ 51 | 'ch_dim': [3 , 256, 256, 256, 256, 3], 52 | 'res_layers': [2], 53 | 'last_op': 'softplus', 54 | 'nlactiv': 'softplus', 55 | 'norm': 'weight', 56 | 'last_op': 'none' 57 | }, 58 | 'feature_dim': 512, 59 | 'pose_dim': 4, 60 | 'g_dim': 64 61 | } 62 | else: 63 | opt['feature_dim'] = 512 64 | opt['mlp']['ch_dim'][-1] = 3 65 | 66 | 67 | self.opt = opt 68 | if self.opt['use_embed']: 69 | _, self.opt['mlp']['ch_dim'][0] = get_embedder(opt['d_size'], input_dims=self.opt['mlp']['ch_dim'][0]) 70 | 71 | if 'g_dim' in self.opt: 72 | self.opt['mlp']['ch_dim'][0] += self.opt['g_dim'] 73 | 74 | if 'pose_dim' in self.opt: 75 | self.opt['mlp']['ch_dim'][0] += self.opt['pose_dim'] * 23 76 | 77 | self.opt['mlp']['ch_dim'][0] += self.opt['feature_dim'] 78 | 79 | self.mlp = MLP( 80 | filter_channels=self.opt['mlp']['ch_dim'], 81 | res_layers=self.opt['mlp']['res_layers'], 82 | last_op=self.opt['mlp']['last_op'], 83 | nlactiv=self.opt['mlp']['nlactiv'], 84 | norm=self.opt['mlp']['norm']) 85 | 86 | init_net(self) 87 | 88 | def query(self, points_discription, last_layer_feature): 89 | input_data = torch.cat([points_discription, last_layer_feature], 1) 90 | 91 | return self.mlp(input_data) 92 | 93 | def forward(self, points_discription, last_layer_feature, target_color): 94 | input_data = torch.cat([points_discription, last_layer_feature], 1) 95 | 96 | pred_color = self.mlp(input_data) 97 | 98 | err_dict = {} 99 | 100 | error_color = nn.L1Loss()(pred_color, target_color) 101 | 102 | err_dict['CLR'] = error_color.item() 103 | 104 | return error_color, err_dict 105 | -------------------------------------------------------------------------------- /lib/model/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .BaseIMNet3d import BaseIMNet3d 2 | from .IGRSDFNet import IGRSDFNet 3 | from .LBSNet import LBSNet 4 | from .MLP import MLP 5 | -------------------------------------------------------------------------------- /lib/net_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import math 18 | import torch 19 | from torch.nn import init 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import functools 23 | 24 | import numpy as np 25 | from tqdm import tqdm 26 | 27 | from pytorch3d.ops import knn_gather, knn_points 28 | 29 | def compute_knn_feat(vsrc, vtar, vfeat, K=1): 30 | dist, idx, Vnn = knn_points(vsrc, vtar, K=K, return_nn=True) 31 | return knn_gather(vfeat, idx) 32 | 33 | def homogenize(v,dim=2): 34 | ''' 35 | args: 36 | v: (B, N, C) 37 | return: 38 | (B, N, C+1) 39 | ''' 40 | if dim == 2: 41 | return torch.cat([v, torch.ones_like(v[:,:,:1])], -1) 42 | elif dim == 1: 43 | return torch.cat([v, torch.ones_like(v[:,:1,:])], 1) 44 | else: 45 | raise NotImplementedError('unsupported homogenize dimension [%d]' % dim) 46 | 47 | def transform_normal(net, x, n): 48 | ''' 49 | args: 50 | flow network that returns (B, 3, N) 51 | x: (B, N, 3) 52 | n: (B, N, 3) 53 | ''' 54 | x = x.permute(0,2,1) 55 | with torch.enable_grad(): 56 | x.requires_grad_() 57 | 58 | pred = net.query(x) 59 | 60 | dfdx = autograd.grad( 61 | [pred.sum()], [x], 62 | create_graph=True, retain_graph=True, only_inputs=True)[0] 63 | print(dfdx.shape) 64 | # torch.einsum('bc') 65 | # if normalize: 66 | # normal = F.normalize(normal, dim=1, eps=1e-6) 67 | 68 | def get_posemap(map_type, n_joints, parents, n_traverse=1, normalize=True): 69 | pose_map = torch.zeros(n_joints,n_joints-1) 70 | if map_type == 'parent': 71 | for i in range(n_joints-1): 72 | pose_map[i+1,i] = 1.0 73 | elif map_type == 'children': 74 | for i in range(n_joints-1): 75 | parent = parents[i+1] 76 | for j in range(n_traverse): 77 | pose_map[parent, i] += 1.0 78 | if parent == 0: 79 | break 80 | parent = parents[parent] 81 | if normalize: 82 | pose_map /= pose_map.sum(0,keepdim=True)+1e-16 83 | elif map_type == 'both': 84 | for i in range(n_joints-1): 85 | pose_map[i+1,i] += 1.0 86 | parent = parents[i+1] 87 | for j in range(n_traverse): 88 | pose_map[parent, i] += 1.0 89 | if parent == 0: 90 | break 91 | parent = parents[parent] 92 | if normalize: 93 | pose_map /= pose_map.sum(0,keepdim=True)+1e-16 94 | else: 95 | raise NotImplementedError('unsupported pose map type [%s]' % map_type) 96 | return pose_map 97 | 98 | def batch_rot2euler(R): 99 | ''' 100 | args: 101 | Rs: (B, 3, 3) 102 | return: 103 | (B, 3) euler angle (x, y, z) 104 | ''' 105 | sy = torch.sqrt(R[:,0,0] * R[:,0,0] + 
R[:,1,0] * R[:,1,0]) 106 | singular = (sy < 1e-6).float()[:,None] 107 | 108 | x = torch.atan2(R[:,2,1] , R[:,2,2]) 109 | y = torch.atan2(-R[:,2,0], sy) 110 | z = torch.atan2(R[:,1,0], R[:,0,0]) 111 | euler = torch.stack([x,y,z],1) 112 | 113 | euler_s = euler.clone() 114 | euler_s[:,0] = torch.atan2(-R[:,1,2], R[:,1,1]) 115 | euler_s[:,1] = torch.atan2(-R[:,2,0], sy) 116 | euler_s[:,2] = 0 117 | 118 | return (1.0-singular)*euler + singular * euler_s 119 | 120 | 121 | def batch_rod2euler(rot_vecs): 122 | R = batch_rodrigues(rot_vecs) 123 | return batch_rot2euler(R) 124 | 125 | def batch_rod2quat(rot_vecs): 126 | batch_size = rot_vecs.shape[0] 127 | 128 | angle = torch.norm(rot_vecs + 1e-16, dim=1, keepdim=True) 129 | rot_dir = rot_vecs / angle 130 | 131 | cos = torch.cos(angle / 2) 132 | sin = torch.sin(angle / 2) 133 | 134 | # Bx1 arrays 135 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 136 | 137 | qx = rx * sin 138 | qy = ry * sin 139 | qz = rz * sin 140 | qw = cos-1.0 141 | 142 | return torch.cat([qx,qy,qz,qw], dim=1) 143 | 144 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 145 | ''' Calculates the rotation matrices for a batch of rotation vectors 146 | Parameters 147 | ---------- 148 | rot_vecs: torch.tensor Nx3 149 | array of N axis-angle vectors 150 | Returns 151 | ------- 152 | R: torch.tensor Nx3x3 153 | The rotation matrices for the given axis-angle parameters 154 | ''' 155 | 156 | batch_size = rot_vecs.shape[0] 157 | device = rot_vecs.device 158 | 159 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 160 | rot_dir = rot_vecs / angle 161 | 162 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 163 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 164 | 165 | # Bx1 arrays 166 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 167 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 168 | 169 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 170 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 171 | .view((batch_size, 3, 3)) 172 | 173 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 174 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 175 | return rot_mat 176 | 177 | def quat_to_matrix(rvec): 178 | ''' 179 | args: 180 | rvec: (B, N, 4) 181 | ''' 182 | B, N, _ = rvec.size() 183 | 184 | theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=2)) 185 | rvec = rvec / theta[:, :, None] 186 | return torch.stack(( 187 | 1. - 2. * rvec[:, :, 1] ** 2 - 2. * rvec[:, :, 2] ** 2, 188 | 2. * (rvec[:, :, 0] * rvec[:, :, 1] - rvec[:, :, 2] * rvec[:, :, 3]), 189 | 2. * (rvec[:, :, 0] * rvec[:, :, 2] + rvec[:, :, 1] * rvec[:, :, 3]), 190 | 191 | 2. * (rvec[:, :, 0] * rvec[:, :, 1] + rvec[:, :, 2] * rvec[:, :, 3]), 192 | 1. - 2. * rvec[:, :, 0] ** 2 - 2. * rvec[:, :, 2] ** 2, 193 | 2. * (rvec[:, :, 1] * rvec[:, :, 2] - rvec[:, :, 0] * rvec[:, :, 3]), 194 | 195 | 2. * (rvec[:, :, 0] * rvec[:, :, 2] - rvec[:, :, 1] * rvec[:, :, 3]), 196 | 2. * (rvec[:, :, 0] * rvec[:, :, 3] + rvec[:, :, 1] * rvec[:, :, 2]), 197 | 1. - 2. * rvec[:, :, 0] ** 2 - 2. 
* rvec[:, :, 1] ** 2 198 | ), dim=2).view(B, N, 3, 3) 199 | 200 | def rot6d_to_matrix(rot6d): 201 | ''' 202 | args: 203 | rot6d: (B, N, 6) 204 | return: 205 | rotation matrix: (B, N, 3, 3) 206 | ''' 207 | x_raw = rot6d[:,:,0:3] 208 | y_raw = rot6d[:,:,3:6] 209 | 210 | x = F.normalize(x_raw, dim=2) 211 | z = torch.cross(x, y_raw, dim=2) 212 | z = F.normalize(z, dim=2) 213 | y = torch.cross(z, x, dim=2) 214 | 215 | rotmat = torch.cat((x[:,:,:,None],y[:,:,:,None],z[:,:,:,None]), -1) # (B, 3, 3) 216 | 217 | return rotmat 218 | 219 | def compute_affinemat(param, rot_dim): 220 | ''' 221 | args: 222 | param: (B, N, 9/12) 223 | return: 224 | (B, N, 4, 4) 225 | ''' 226 | B, N, C = param.size() 227 | rot = param[:,:,:rot_dim] 228 | 229 | if C - rot_dim == 3: 230 | trans = param[:,:,rot_dim:] 231 | scale = torch.ones_like(trans) 232 | elif C - rot_dim == 6: 233 | trans = param[:,:,rot_dim:(rot_dim+3)] 234 | scale = param[:,:,(rot_dim+3):] 235 | else: 236 | raise ValueError('unsupported dimension [%d]' % C) 237 | 238 | if rot_dim == 3: 239 | rotmat = batch_rodrigues(rot) 240 | elif rot_dim == 4: 241 | rotmat = quat_to_matrix(rot) 242 | elif rot_dim == 6: 243 | rotmat = rot6d_to_matrix(rot) 244 | else: 245 | raise NotImplementedError('unsupported rot dimension [%d]' % rot_dim) 246 | 247 | A = torch.eye(4)[None,None].to(param.device).expand(B, N, -1, -1).contiguous() 248 | A[:,:,:3, 3] = trans # (B, N, 3, 1) 249 | A[:,:,:3,:3] = rotmat * scale[:,:,None,:] # (B, N, 3, 3) 250 | 251 | return A 252 | 253 | def compositional_affine(param, num_comp, rot_dim): 254 | ''' 255 | args: 256 | param: (B, N, M*(9/12)+M) 257 | return: 258 | (B, N, 4, 4) 259 | ''' 260 | B, N, _ = param.size() 261 | 262 | weight = torch.exp(param[:,:,:num_comp])[:,:,:,None,None] 263 | 264 | affine_param = param[:,:,num_comp:].reshape(B, N*num_comp, -1) 265 | A = compute_affinemat(affine_param, rot_dim).view(B, N, num_comp, 4, 4) 266 | 267 | return (weight * A).sum(2) / weight.sum(dim=2).clamp(min=0.001) 268 | 269 | 270 | class Embedder: 271 | def __init__(self, **kwargs): 272 | 273 | self.kwargs = kwargs 274 | self.create_embedding_fn() 275 | 276 | 277 | def create_embedding_fn(self): 278 | 279 | embed_fns = [] 280 | d = self.kwargs['input_dims'] 281 | out_dim = 0 282 | if self.kwargs['include_input']: 283 | embed_fns.append(lambda x : x) 284 | out_dim += d 285 | 286 | max_freq = self.kwargs['max_freq_log2'] 287 | N_freqs = self.kwargs['num_freqs'] 288 | 289 | if self.kwargs['log_sampling']: 290 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 291 | else: 292 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 293 | 294 | for freq in freq_bands: 295 | for p_fn in self.kwargs['periodic_fns']: 296 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 297 | out_dim += d 298 | 299 | self.embed_fns = embed_fns 300 | self.out_dim = out_dim 301 | 302 | def embed(self, inputs): 303 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 304 | 305 | def get_embedder(multires, i=0, input_dims=3): 306 | 307 | if i == -1: 308 | return nn.Identity(), input_dims 309 | 310 | embed_kwargs = { 311 | 'include_input' : True, 312 | 'input_dims' : input_dims, 313 | 'max_freq_log2' : multires-1, 314 | 'num_freqs' : multires, 315 | 'log_sampling' : True, 316 | 'periodic_fns' : [torch.sin, torch.cos], 317 | } 318 | 319 | embedder_obj = Embedder(**embed_kwargs) 320 | embed = lambda x, eo=embedder_obj : eo.embed(x) 321 | return embed, embedder_obj.out_dim 322 | 323 | class Mish(nn.Module): 324 | def 
__init__(self): 325 | super().__init__() 326 | 327 | def forward(self, x): 328 | #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) 329 | return x *( torch.tanh(F.softplus(x))) 330 | 331 | class Sin(nn.Module): 332 | def __init__(self): 333 | super().__init__() 334 | 335 | def forward(self, x): 336 | return torch.sin(x) 337 | 338 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 339 | "3x3 convolution with padding" 340 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 341 | stride=strd, padding=padding, bias=bias) 342 | 343 | def conv3x3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 344 | "3x3 convolution with padding" 345 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 346 | stride=strd, padding=padding, bias=bias) 347 | 348 | def init_mlp_siren(m): 349 | classname = m.__class__.__name__ 350 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 351 | d_in = m.weight.data.size()[1] 352 | init.uniform_(m.weight.data, -math.sqrt(6/d_in), math.sqrt(6/d_in)) 353 | if hasattr(m, 'bias') and m.bias is not None: 354 | init.constant_(m.bias.data, 0.0) 355 | 356 | # From IGR paper 357 | def init_mlp_geometric(m): 358 | classname = m.__class__.__name__ 359 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 360 | d_out = m.weight.data.size()[0] 361 | if d_out == 1: 362 | d_in = m.weight.data.size()[1] 363 | init.constant_(m.weight.data, math.sqrt(math.pi/d_in)) 364 | if hasattr(m, 'bias') and m.bias is not None: 365 | init.constant_(m.bias.data, -1.0) 366 | else: 367 | init.normal_(m.weight.data, 0.0, math.sqrt(2/d_out)) 368 | if hasattr(m, 'bias') and m.bias is not None: 369 | init.constant_(m.bias.data, 0.0) 370 | 371 | def init_weights(net, init_type='normal', init_gain=0.02): 372 | """Initialize network weights. 373 | 374 | Parameters: 375 | net (network) -- network to be initialized 376 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 377 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 378 | 379 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 380 | work better for some applications. Feel free to try yourself. 381 | """ 382 | 383 | def init_func(m): # define the initialization function 384 | classname = m.__class__.__name__ 385 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 386 | if init_type == 'normal': 387 | init.normal_(m.weight.data, 0.0, init_gain) 388 | elif init_type == 'xavier': 389 | init.xavier_normal_(m.weight.data, gain=init_gain) 390 | elif init_type == 'kaiming': 391 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 392 | elif init_type == 'orthogonal': 393 | init.orthogonal_(m.weight.data, gain=init_gain) 394 | else: 395 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 396 | if hasattr(m, 'bias') and m.bias is not None: 397 | init.constant_(m.bias.data, 0.0) 398 | elif classname.find( 399 | 'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
400 | init.normal_(m.weight.data, 1.0, init_gain) 401 | init.constant_(m.bias.data, 0.0) 402 | 403 | print('initialize network with %s' % init_type) 404 | net.apply(init_func) # apply the initialization function 405 | 406 | 407 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 408 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 409 | Parameters: 410 | net (network) -- the network to be initialized 411 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 412 | gain (float) -- scaling factor for normal, xavier and orthogonal. 413 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 414 | 415 | Return an initialized network. 416 | """ 417 | if len(gpu_ids) > 0: 418 | assert (torch.cuda.is_available()) 419 | net.to(gpu_ids[0]) 420 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 421 | init_weights(net, init_type, init_gain=init_gain) 422 | return net 423 | 424 | 425 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 426 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 427 | 428 | Arguments: 429 | netD (network) -- discriminator network 430 | real_data (tensor array) -- real images 431 | fake_data (tensor array) -- generated images from the generator 432 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 433 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 434 | constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 435 | lambda_gp (float) -- weight for this loss 436 | 437 | Returns the gradient penalty loss 438 | """ 439 | if lambda_gp > 0.0: 440 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 441 | interpolatesv = real_data 442 | elif type == 'fake': 443 | interpolatesv = fake_data 444 | elif type == 'mixed': 445 | alpha = torch.rand(real_data.shape[0], 1) 446 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view( 447 | *real_data.shape) 448 | alpha = alpha.to(device) 449 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 450 | else: 451 | raise NotImplementedError('{} not implemented'.format(type)) 452 | interpolatesv.requires_grad_(True) 453 | disc_interpolates = netD(interpolatesv) 454 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 455 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 456 | create_graph=True, retain_graph=True, only_inputs=True) 457 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 458 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 459 | return gradient_penalty, gradients 460 | else: 461 | return 0.0, None 462 | 463 | def get_norm_layer(norm_type='instance'): 464 | """Return a normalization layer 465 | Parameters: 466 | norm_type (str) -- the name of the normalization layer: batch | instance | none 467 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 468 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
469 | """ 470 | if norm_type == 'batch': 471 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 472 | elif norm_type == 'instance': 473 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 474 | elif norm_type == 'group': 475 | norm_layer = functools.partial(nn.GroupNorm, 32) 476 | elif norm_type == 'none': 477 | norm_layer = None 478 | else: 479 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 480 | return norm_layer 481 | 482 | def load_network(network, state_dict): 483 | try: 484 | network.load_state_dict(state_dict) 485 | except: 486 | model_dict = network.state_dict() 487 | try: 488 | state_dict = {k: v for k, v in state_dict.items() if k in model_dict} 489 | network.load_state_dict(state_dict) 490 | except: 491 | print('Pretrained network has fewer layers; The following are not initialized:') 492 | for k, v in state_dict.items(): 493 | if v.size() == model_dict[k].size(): 494 | model_dict[k] = v 495 | 496 | not_initialized = set() 497 | 498 | for k, v in model_dict.items(): 499 | if k not in state_dict or v.size() != state_dict[k].size(): 500 | not_initialized.add(k.split('.')[0]) 501 | 502 | print(sorted(not_initialized)) 503 | network.load_state_dict(model_dict) 504 | 505 | 506 | class Flatten(nn.Module): 507 | def forward(self, input): 508 | return input.view(input.size(0), -1) 509 | 510 | class ConvBlock(nn.Module): 511 | def __init__(self, in_planes, out_planes, norm='batch'): 512 | super(ConvBlock, self).__init__() 513 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 514 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 515 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 516 | 517 | if norm == 'batch': 518 | self.bn1 = nn.BatchNorm2d(in_planes) 519 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 520 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 521 | self.bn4 = nn.BatchNorm2d(in_planes) 522 | elif norm == 'group': 523 | self.bn1 = nn.GroupNorm(32, in_planes) 524 | self.bn2 = nn.GroupNorm(32, int(out_planes / 2)) 525 | self.bn3 = nn.GroupNorm(32, int(out_planes / 4)) 526 | self.bn4 = nn.GroupNorm(32, in_planes) 527 | 528 | if in_planes != out_planes: 529 | self.downsample = nn.Sequential( 530 | self.bn4, 531 | nn.ReLU(True), 532 | nn.Conv2d(in_planes, out_planes, 533 | kernel_size=1, stride=1, bias=False), 534 | ) 535 | else: 536 | self.downsample = None 537 | 538 | def forward(self, x): 539 | residual = x 540 | 541 | out1 = self.bn1(x) 542 | out1 = F.relu(out1, True) 543 | out1 = self.conv1(out1) 544 | 545 | out2 = self.bn2(out1) 546 | out2 = F.relu(out2, True) 547 | out2 = self.conv2(out2) 548 | 549 | out3 = self.bn3(out2) 550 | out3 = F.relu(out3, True) 551 | out3 = self.conv3(out3) 552 | 553 | out3 = torch.cat((out1, out2, out3), 1) 554 | 555 | if self.downsample is not None: 556 | residual = self.downsample(residual) 557 | 558 | out3 += residual 559 | 560 | return out3 561 | 562 | 563 | class ConvBlock3d(nn.Module): 564 | def __init__(self, in_planes, out_planes, norm='batch'): 565 | super(ConvBlock3d, self).__init__() 566 | self.conv1 = conv3x3x3(in_planes, int(out_planes / 2)) 567 | self.conv2 = conv3x3x3(int(out_planes / 2), int(out_planes / 4)) 568 | self.conv3 = conv3x3x3(int(out_planes / 4), int(out_planes / 4)) 569 | 570 | if norm == 'batch': 571 | self.bn1 = nn.BatchNorm2d(in_planes) 572 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 573 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 574 | self.bn4 = 
nn.BatchNorm2d(in_planes) 575 | elif norm == 'group': 576 | self.bn1 = nn.GroupNorm(min(32,in_planes), in_planes) 577 | self.bn2 = nn.GroupNorm(min(32,int(out_planes / 2)), int(out_planes / 2)) 578 | self.bn3 = nn.GroupNorm(min(32,int(out_planes / 4)), int(out_planes / 4)) 579 | self.bn4 = nn.GroupNorm(min(32,in_planes), in_planes) 580 | 581 | if in_planes != out_planes: 582 | self.downsample = nn.Sequential( 583 | self.bn4, 584 | nn.ReLU(True), 585 | nn.Conv3d(in_planes, out_planes, 586 | kernel_size=1, stride=1, bias=False), 587 | ) 588 | else: 589 | self.downsample = None 590 | 591 | def forward(self, x): 592 | residual = x 593 | 594 | out1 = self.bn1(x) 595 | out1 = F.relu(out1, True) 596 | out1 = self.conv1(out1) 597 | 598 | out2 = self.bn2(out1) 599 | out2 = F.relu(out2, True) 600 | out2 = self.conv2(out2) 601 | 602 | out3 = self.bn3(out2) 603 | out3 = F.relu(out3, True) 604 | out3 = self.conv3(out3) 605 | 606 | out3 = torch.cat((out1, out2, out3), 1) 607 | 608 | if self.downsample is not None: 609 | residual = self.downsample(residual) 610 | 611 | out3 += residual 612 | 613 | return out3 614 | 615 | 616 | class Unet3d(nn.Module): 617 | def __init__(self, input_nc, output_nc, num_downs, ngf=16, norm_layer=nn.GroupNorm): 618 | super(Unet3d, self).__init__() 619 | # construct unet structure 620 | unet_block = UnetSkipConnectionBlock3d(ngf * 4, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 621 | for i in range(num_downs - 4): # add intermediate layers with ngf * 8 filters 622 | unet_block = UnetSkipConnectionBlock3d(ngf * 4, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 623 | # gradually reduce the number of filters from ngf * 8 to ngf 624 | unet_block = UnetSkipConnectionBlock3d(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 625 | unet_block = UnetSkipConnectionBlock3d(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 626 | self.model = UnetSkipConnectionBlock3d(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 627 | 628 | def forward(self, input): 629 | """Standard forward""" 630 | return self.model(input) 631 | 632 | 633 | class UnetSkipConnectionBlock3d(nn.Module): 634 | def __init__(self, outer_nc, inner_nc, input_nc=None, 635 | submodule=None, outermost=False, innermost=False, norm_layer=nn.GroupNorm): 636 | super(UnetSkipConnectionBlock3d, self).__init__() 637 | self.outermost = outermost 638 | use_bias = False 639 | if input_nc is None: 640 | input_nc = outer_nc 641 | downconv = nn.Conv3d(input_nc, inner_nc, kernel_size=4, 642 | stride=2, padding=1, bias=use_bias) 643 | downrelu = nn.LeakyReLU(0.2, True) 644 | downnorm = norm_layer(16, inner_nc) 645 | uprelu = nn.ReLU(True) 646 | upnorm = norm_layer(16, outer_nc) 647 | 648 | if outermost: 649 | upconv = nn.ConvTranspose3d(inner_nc * 2, outer_nc, 650 | kernel_size=4, stride=2, 651 | padding=1) 652 | down = [downconv] 653 | up = [uprelu, upconv] 654 | model = down + [submodule] + up 655 | elif innermost: 656 | upconv = nn.ConvTranspose3d(inner_nc, outer_nc, 657 | kernel_size=4, stride=2, 658 | padding=1, bias=use_bias) 659 | down = [downrelu, downconv] 660 | up = [uprelu, upconv, upnorm] 661 | model = down + up 662 | else: 663 | upconv = nn.ConvTranspose3d(inner_nc * 2, outer_nc, 664 | kernel_size=4, stride=2, 665 | padding=1, bias=use_bias) 666 | down = [downrelu, downconv, downnorm] 667 | up = [uprelu, upconv, upnorm] 668 | 
669 | model = down + [submodule] + up 670 | 671 | self.model = nn.Sequential(*model) 672 | 673 | def forward(self, x): 674 | if self.outermost: 675 | return self.model(x) 676 | else: # add skip connections 677 | return torch.cat([x, self.model(x)], 1) -------------------------------------------------------------------------------- /lib/sdf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | import numpy as np 18 | 19 | 20 | def create_grid(resX, resY, resZ, b_min=np.array([0, 0, 0]), b_max=np.array([1, 1, 1]), transform=None): 21 | ''' 22 | Create a dense grid of given resolution and bounding box 23 | :param resX: resolution along X axis 24 | :param resY: resolution along Y axis 25 | :param resZ: resolution along Z axis 26 | :param b_min: vec3 (x_min, y_min, z_min) bounding box corner 27 | :param b_max: vec3 (x_max, y_max, z_max) bounding box corner 28 | :return: [3, resX, resY, resZ] coordinates of the grid, and transform matrix from mesh index 29 | ''' 30 | coords = np.mgrid[:resX, :resY, :resZ] 31 | coords = coords.reshape(3, -1) 32 | coords_matrix = np.eye(4) 33 | length = b_max - b_min 34 | coords_matrix[0, 0] = length[0] / resX 35 | coords_matrix[1, 1] = length[1] / resY 36 | coords_matrix[2, 2] = length[2] / resZ 37 | coords_matrix[0:3, 3] = b_min 38 | coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4] 39 | if transform is not None: 40 | coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4] 41 | coords_matrix = np.matmul(transform, coords_matrix) 42 | coords = coords.reshape(3, resX, resY, resZ) 43 | return coords, coords_matrix 44 | 45 | 46 | def batch_eval(points, eval_func, num_samples=512 * 512 * 512): 47 | num_pts = points.shape[1] 48 | sdf = np.zeros(num_pts) 49 | 50 | num_batches = num_pts // num_samples 51 | for i in range(num_batches): 52 | sdf[i * num_samples:i * num_samples + num_samples] = eval_func( 53 | points[:, i * num_samples:i * num_samples + num_samples]) 54 | if num_pts % num_samples: 55 | sdf[num_batches * num_samples:] = eval_func(points[:, num_batches * num_samples:]) 56 | 57 | return sdf 58 | 59 | 60 | def eval_grid(coords, eval_func, num_samples=512 * 512 * 512): 61 | resolution = coords.shape[1:4] 62 | coords = coords.reshape([3, -1]) 63 | sdf = batch_eval(coords, eval_func, num_samples=num_samples) 64 | return sdf.reshape(resolution) 65 | 66 | 67 | def eval_grid_octree(coords, eval_func, 68 | init_resolution=64, threshold=0.01, 69 | num_samples=512 * 512 * 512): 70 | resolution = coords.shape[1:4] 71 | 72 | sdf = np.zeros(resolution) 73 | 74 | dirty = np.ones(resolution, dtype=np.bool) 75 | grid_mask = np.zeros(resolution, dtype=np.bool) 76 | 77 | reso = resolution[0] // init_resolution 78 | 79 | while reso > 0: 80 | # subdivide the grid 81 | 
grid_mask[0:resolution[0]:reso, 0:resolution[1]:reso, 0:resolution[2]:reso] = True 82 | # test samples in this iteration 83 | test_mask = np.logical_and(grid_mask, dirty) 84 | #print('step size:', reso, 'test sample size:', test_mask.sum()) 85 | points = coords[:, test_mask] 86 | 87 | sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples) 88 | dirty[test_mask] = False 89 | 90 | # do interpolation 91 | if reso <= 1: 92 | break 93 | for x in range(0, resolution[0] - reso, reso): 94 | for y in range(0, resolution[1] - reso, reso): 95 | for z in range(0, resolution[2] - reso, reso): 96 | # if center marked, return 97 | if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]: 98 | continue 99 | v0 = sdf[x, y, z] 100 | v1 = sdf[x, y, z + reso] 101 | v2 = sdf[x, y + reso, z] 102 | v3 = sdf[x, y + reso, z + reso] 103 | v4 = sdf[x + reso, y, z] 104 | v5 = sdf[x + reso, y, z + reso] 105 | v6 = sdf[x + reso, y + reso, z] 106 | v7 = sdf[x + reso, y + reso, z + reso] 107 | v = np.array([v0, v1, v2, v3, v4, v5, v6, v7]) 108 | v_min = v.min() 109 | v_max = v.max() 110 | # this cell is all the same 111 | if (v_max - v_min) < threshold: 112 | sdf[x:x + reso, y:y + reso, z:z + reso] = (v_max + v_min) / 2 113 | dirty[x:x + reso, y:y + reso, z:z + reso] = False 114 | reso //= 2 115 | 116 | return sdf.reshape(resolution) 117 | -------------------------------------------------------------------------------- /render/render_aist.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | import os 5 | from os.path import isdir, isfile, join 6 | 7 | import json 8 | 9 | import argparse 10 | import subprocess 11 | 12 | from tqdm import tqdm 13 | import time 14 | import random 15 | import subprocess 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | def render_single_image(result_mesh_file, output_image_file, vis, yprs, raw_color = True): 20 | # Render result image 21 | mesh = o3d.io.read_triangle_mesh(result_mesh_file) 22 | mesh.compute_vertex_normals() 23 | if not raw_color: 24 | mesh.paint_uniform_color([0.7, 0.7, 0.7]) 25 | mesh.vertices = o3d.utility.Vector3dVector(np.asarray(mesh.vertices) - y_axis_offset) 26 | 27 | vis.add_geometry(mesh) 28 | ctr = vis.get_view_control() 29 | ctr.convert_from_pinhole_camera_parameters(cam_params) 30 | 31 | for ypr in yprs: 32 | ctr.rotate(0, RENDER_RESOLUTION/180*ypr[1]) 33 | ctr.rotate(RENDER_RESOLUTION/180*ypr[0], 0) 34 | 35 | vis.poll_events() 36 | vis.update_renderer() 37 | # time_stamp_result_image = str(time.time()) 38 | # output_result_image_file = join(output_dir, result_name+'_'+time_stamp_result_image+'.png') 39 | # vis.capture_screen_image(output_result_image_file, True) 40 | result_image = vis.capture_screen_float_buffer(False) 41 | vis.clear_geometries() 42 | 43 | result_img = np.asarray(result_image) 44 | 45 | plt.imsave(output_image_file, np.asarray(result_img), dpi = 1) 46 | 47 | return result_img 48 | 49 | 50 | y_axis_offset = np.array([0.0, 1.25, 0.0]) 51 | # o3d.utility.set_verbosity_level(o3d.cpu.pybind.utility.VerbosityLevel(1)) 52 | random.seed(0) 53 | 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('-i', '--input_dir', type=str, required=True, help='result directory') 56 | parser.add_argument('-o', '--out_dir', type=str, default='demo_result', help='Output directory or filename') 57 | parser.add_argument('-v', '--video_name', type=str, default='video', help='Output directory or filename') 58 | parser.add_argument('-n', '--num', type=int, 
default=0, help='Output directory or filename') 59 | 60 | args = parser.parse_args() 61 | 62 | input_dirs = [args.input_dir] 63 | 64 | vis = o3d.visualization.Visualizer() 65 | RENDER_RESOLUTION = 512 66 | FOCAL_LENGTH = 1.5 67 | vis.create_window(width=RENDER_RESOLUTION, height=RENDER_RESOLUTION) 68 | 69 | opt = vis.get_render_option() 70 | # opt.background_color = np.asarray([0, 0, 0]) 71 | # opt.mesh_color_option = o3d.visualization.MeshColorOption.Normal 72 | # opt.mesh_shade_option = o3d.visualization.MeshShadeOption.Color 73 | # opt.mesh_show_wireframe = True 74 | opt.light_on = True 75 | 76 | for dir_index in tqdm(range(len(input_dirs))): 77 | args.input_dir = input_dirs[dir_index] 78 | 79 | input_dir = args.input_dir 80 | output_dir = args.out_dir 81 | if output_dir == '': 82 | output_dir = input_dir 83 | output_dir = output_dir[:-1] if output_dir[-1] == '/' else output_dir 84 | video_name = args.video_name 85 | 86 | cam_intrinsics = o3d.camera.PinholeCameraIntrinsic() 87 | INTRINSIC = np.eye(3, dtype=np.float32) 88 | INTRINSIC[0,0] = FOCAL_LENGTH*RENDER_RESOLUTION 89 | INTRINSIC[1,1] = FOCAL_LENGTH*RENDER_RESOLUTION 90 | INTRINSIC[0,2] = RENDER_RESOLUTION/2-0.5 91 | INTRINSIC[1,2] = RENDER_RESOLUTION/2-0.5 92 | cam_intrinsics.intrinsic_matrix = INTRINSIC 93 | # print(cam_intrinsics.intrinsic_matrix) 94 | 95 | cam_intrinsics.width = RENDER_RESOLUTION 96 | cam_intrinsics.height = RENDER_RESOLUTION 97 | 98 | 99 | EXTRINSIC = np.array([[ 1.0, 0.0, 0.0, 0.0], 100 | [ 0.0, -1.0, 0.0, 0.5], 101 | [ 0.0, 0.0, -1.0, 2.7], 102 | [ 0.0, 0.0, 0.0, 1.0]]) 103 | cam_params = o3d.camera.PinholeCameraParameters() 104 | cam_params.intrinsic = cam_intrinsics 105 | cam_params.extrinsic = EXTRINSIC 106 | 107 | if isdir(join(output_dir, 'tmp')): 108 | tmp_dir = join(output_dir, 'tmp') 109 | command = 'rm -rf ' + tmp_dir 110 | subprocess.run(command, shell=True, stdout=subprocess.DEVNULL) 111 | tmp_dir = join(output_dir, 'tmp') 112 | os.makedirs(tmp_dir, exist_ok = True) 113 | 114 | has_color = True if 'color' in input_dir else False 115 | 116 | mesh_files = sorted([f for f in os.listdir(input_dir) if f[-3:]=='obj']) 117 | 118 | yprs = [[0, 0]] 119 | # if 'knocking1_poses' in input_dir: 120 | # yprs = [[0, -90], [-90, 0]] 121 | # if 'misc_poses' in input_dir or 'misc2_poses' in input_dir: 122 | # yprs = [[0, -90], [180, 0]] 123 | # if 'irish_dance' in input_dir: 124 | # # yprs = [[0, -90]] 125 | # yprs = [[0, -90], [180, 0]] 126 | 127 | for i, mesh_file in enumerate(tqdm(mesh_files[::1])): 128 | mesh_file_path = join(input_dir, mesh_file) 129 | output_image_file = join(tmp_dir, str(i).zfill(5)+'.png') 130 | 131 | _ = render_single_image(mesh_file_path, output_image_file, vis, yprs, has_color) 132 | 133 | if video_name == '': 134 | video_file = join(output_dir+'/../', output_dir.split('/')[-1]+'.mp4') 135 | else: 136 | video_file = join(output_dir, video_name+'.mp4') 137 | command = 'ffmpeg -r 30 -i '+join(tmp_dir,'%05d.png') + ' -c:v libx264 -vf fps=30 -pix_fmt yuv420p -y '+video_file 138 | subprocess.run(command, shell=True) 139 | 140 | command = 'rm -rf ' + tmp_dir 141 | subprocess.run(command, shell=True, stdout=subprocess.DEVNULL) 142 | 143 | 144 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | trimesh==3.9.20 3 | matplotlib==3.3.4 4 | scikit-image==0.17.2 5 | pyyaml==5.4.1 6 | chumpy==0.70 7 | tqdm==4.61.1 
-------------------------------------------------------------------------------- /smpl/LICENSE: -------------------------------------------------------------------------------- 1 | License 2 | 3 | Software Copyright License for non-commercial scientific research purposes 4 | Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License 5 | 6 | Ownership / Licensees 7 | The Software and the associated materials has been developed at the 8 | 9 | Max Planck Institute for Intelligent Systems (hereinafter "MPI"). 10 | 11 | Any copyright or patent right is owned by and proprietary material of the 12 | 13 | Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”) 14 | 15 | hereinafter the “Licensor”. 16 | 17 | License Grant 18 | Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right: 19 | 20 | To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization; 21 | To use the Model & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects; 22 | Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Model & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission. 23 | 24 | The Model & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Model & Software to train methods/algorithms/neural networks/etc. for commercial use of any kind. By downloading the Model & Software, you agree not to reverse engineer it. 25 | 26 | No Distribution 27 | The Model & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only. 28 | 29 | Disclaimer of Representations and Warranties 30 | You expressly acknowledge and agree that the Model & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. 
Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model & Software, (ii) that the use of the Model & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model & Software will not cause any damage of any kind to you or a third party. 31 | 32 | Limitation of Liability 33 | Because this Model & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage. 34 | Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded. 35 | Patent claims generated through the usage of the Model & Software cannot be directed towards the copyright holders. 36 | The Model & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Model & Software and is not responsible for any problems such modifications cause. 37 | 38 | No Maintenance Services 39 | You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Model & Software at any time. 40 | 41 | Defects of the Model & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication. 42 | 43 | Publications using the Model & Software 44 | You acknowledge that the Model & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Model & Software. 45 | 46 | Citation: 47 | 48 | 49 | @inproceedings{SMPL-X:2019, 50 | title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image}, 51 | author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.}, 52 | booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)}, 53 | year = {2019} 54 | } 55 | Commercial licensing opportunities 56 | For commercial uses of the Software, please send email to ps-license@tue.mpg.de 57 | 58 | This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention. 
59 | -------------------------------------------------------------------------------- /smpl/README.md: -------------------------------------------------------------------------------- 1 | ## SMPL-X: A new joint 3D model of the human body, face and hands together 2 | 3 | [[Paper Page](https://smpl-x.is.tue.mpg.de)] [[Paper](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/497/SMPL-X.pdf)] 4 | [[Supp. Mat.](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/498/SMPL-X-supp.pdf)] 5 | 6 | ![SMPL-X Examples](./images/teaser_fig.png) 7 | 8 | ## Table of Contents 9 | * [License](#license) 10 | * [Description](#description) 11 | * [Installation](#installation) 12 | * [Downloading the model](#downloading-the-model) 13 | * [Loading SMPL-X, SMPL+H and SMPL](#loading-smpl-x-smplh-and-smpl) 14 | * [SMPL and SMPL+H setup](#smpl-and-smplh-setup) 15 | * [Model loading](https://github.com/vchoutas/smplx#model-loading) 16 | * [Example](#example) 17 | * [Citation](#citation) 18 | * [Acknowledgments](#acknowledgments) 19 | * [Contact](#contact) 20 | 21 | ## License 22 | 23 | Software Copyright License for **non-commercial scientific research purposes**. 24 | Please read carefully the [terms and conditions](https://github.com/vchoutas/smplx/blob/master/LICENSE) and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this [License](./LICENSE). 25 | 26 | ## Disclaimer 27 | 28 | The original images used for the figures 1 and 2 of the paper can be found in this link. 29 | The images in the paper are used under license from gettyimages.com. 30 | We have acquired the right to use them in the publication, but redistribution is not allowed. 31 | Please follow the instructions on the given link to acquire right of usage. 32 | Our results are obtained on the 483 × 724 pixels resolution of the original images. 33 | 34 | ## Description 35 | 36 | *SMPL-X* (SMPL eXpressive) is a unified body model with shape parameters trained jointly for the 37 | face, hands and body. *SMPL-X* uses standard vertex based linear blend skinning with learned corrective blend 38 | shapes, has N = 10, 475 vertices and K = 54 joints, 39 | which include joints for the neck, jaw, eyeballs and fingers. 40 | SMPL-X is defined by a function M(θ, β, ψ), where θ is the pose parameters, β the shape parameters and 41 | ψ the facial expression parameters. 42 | 43 | 44 | ## Installation 45 | 46 | To install the model please follow the next steps in the specified order: 47 | 1. To install from PyPi simply run: 48 | ```Shell 49 | pip install smplx[all] 50 | ``` 51 | 2. 
Clone this repository and install it using the *setup.py* script: 52 | ```Shell 53 | git clone https://github.com/vchoutas/smplx 54 | python setup.py install 55 | ``` 56 | 57 | ## Downloading the model 58 | 59 | To download the *SMPL-X* model go to [this project website](https://smpl-x.is.tue.mpg.de) and register to get access to the downloads section. 60 | 61 | To download the *SMPL+H* model go to [this project website](http://mano.is.tue.mpg.de) and register to get access to the downloads section. 62 | 63 | To download the *SMPL* model go to [this](http://smpl.is.tue.mpg.de) (male and female models) and [this](http://smplify.is.tue.mpg.de) (gender neutral model) project website and register to get access to the downloads section. 64 | 65 | ## Loading SMPL-X, SMPL+H and SMPL 66 | 67 | ### SMPL and SMPL+H setup 68 | 69 | The loader gives the option to use any of the SMPL-X, SMPL+H and SMPL models. Depending on the model you want to use, please follow the respective download instructions. To switch between SMPL, SMPL+H and SMPL-X just change the *model_path* or *model_type* parameters. For more details please check the docs of the model classes. 70 | Before using SMPL and SMPL+H you should follow the instructions in [tools/README.md](./tools/README.md) to remove the 71 | Chumpy objects from both model pkls, as well as merge the MANO parameters with SMPL+H. 72 | 73 | ### Model loading 74 | 75 | You can either use the [create](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L54) 76 | function from [body_models](./smplx/body_models.py) or directly call the constructor for the 77 | [SMPL](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L106), 78 | [SMPL+H](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L395) and 79 | [SMPL-X](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L628) model. The path to the model can either be the path to the file with the parameters or a directory with the following structure: 80 | ```bash 81 | models 82 | ├── smpl 83 | │   ├── SMPL_FEMALE.pkl 84 | │   └── SMPL_MALE.pkl 85 | │   └── SMPL_NEUTRAL.pkl 86 | ├── smplh 87 | │   ├── SMPLH_FEMALE.pkl 88 | │   └── SMPLH_MALE.pkl 89 | └── smplx 90 | ├── SMPLX_FEMALE.npz 91 | ├── SMPLX_FEMALE.pkl 92 | ├── SMPLX_MALE.npz 93 | ├── SMPLX_MALE.pkl 94 | ├── SMPLX_NEUTRAL.npz 95 | └── SMPLX_NEUTRAL.pkl 96 | ``` 97 | 98 | ## Example 99 | 100 | After installing the *smplx* package and downloading the model parameters you should be able to run the *demo.py* 101 | script to visualize the results. For this step you have to install the [pyrender](https://pyrender.readthedocs.io/en/latest/index.html) and [trimesh](https://trimsh.org/) packages. 102 | 103 | `python examples/demo.py --model-folder $SMPLX_FOLDER --plot-joints=True --gender="neutral"` 104 | 105 | ![SMPL-X Examples](./images/example.png) 106 | 107 | ## Citation 108 | 109 | Depending on which model is loaded for your project, i.e. SMPL-X or SMPL+H or SMPL, please cite the most relevant work below, listed in the same order: 110 | 111 | ``` 112 | @inproceedings{SMPL-X:2019, 113 | title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image}, 114 | author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.}, 115 | booktitle = {Proceedings IEEE Conf. 
on Computer Vision and Pattern Recognition (CVPR)}, 116 | year = {2019} 117 | } 118 | ``` 119 | 120 | ``` 121 | @article{MANO:SIGGRAPHASIA:2017, 122 | title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together}, 123 | author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.}, 124 | journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, 125 | volume = {36}, 126 | number = {6}, 127 | series = {245:1--245:17}, 128 | month = nov, 129 | year = {2017}, 130 | month_numeric = {11} 131 | } 132 | ``` 133 | 134 | ``` 135 | @article{SMPL:2015, 136 | author = {Loper, Matthew and Mahmood, Naureen and Romero, Javier and Pons-Moll, Gerard and Black, Michael J.}, 137 | title = {{SMPL}: A Skinned Multi-Person Linear Model}, 138 | journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)}, 139 | month = oct, 140 | number = {6}, 141 | pages = {248:1--248:16}, 142 | publisher = {ACM}, 143 | volume = {34}, 144 | year = {2015} 145 | } 146 | ``` 147 | 148 | This repository was originally developed for SMPL-X / SMPLify-X (CVPR 2019), you might be interested in having a look: [https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de). 149 | 150 | ## Acknowledgments 151 | 152 | ### Facial Contour 153 | 154 | Special thanks to [Soubhik Sanyal](https://github.com/soubhiksanyal) for sharing the Tensorflow code used for the facial 155 | landmarks. 156 | 157 | ## Contact 158 | The code of this repository was implemented by [Vassilis Choutas](vassilis.choutas@tuebingen.mpg.de). 159 | 160 | For questions, please contact [smplx@tue.mpg.de](smplx@tue.mpg.de). 161 | 162 | For commercial licensing (and all related questions for business applications), please contact [ps-licensing@tue.mpg.de](ps-licensing@tue.mpg.de). 163 | -------------------------------------------------------------------------------- /smpl/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems and the Max Planck Institute for Biological 14 | # Cybernetics. All rights reserved. 15 | # 16 | # Contact: ps-license@tuebingen.mpg.de 17 | 18 | import io 19 | import os 20 | 21 | from setuptools import setup 22 | 23 | # Package meta-data. 24 | NAME = 'smpl' 25 | DESCRIPTION = 'Customized PyTorch module for SMPL body model, adapted from Vassilis Choutas' 26 | URL = 'http://smpl-x.is.tuebingen.mpg.de' 27 | EMAIL = 'jinlong.yang@tuebingen.mpg.de' 28 | AUTHOR = 'Jinlong Yang' 29 | REQUIRES_PYTHON = '>=3.6.0' 30 | VERSION = '0.1.13' 31 | 32 | here = os.path.abspath(os.path.dirname(__file__)) 33 | 34 | try: 35 | FileNotFoundError 36 | except NameError: 37 | FileNotFoundError = IOError 38 | 39 | # Import the README and use it as the long-description. 40 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 
41 | try: 42 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 43 | long_description = '\n' + f.read() 44 | except FileNotFoundError: 45 | long_description = DESCRIPTION 46 | 47 | # Load the package's __version__.py module as a dictionary. 48 | about = {} 49 | if not VERSION: 50 | with open(os.path.join(here, NAME, '__version__.py')) as f: 51 | exec(f.read(), about) 52 | else: 53 | about['__version__'] = VERSION 54 | 55 | pyrender_reqs = ['pyrender>=0.1.23', 'trimesh>=2.37.6', 'shapely'] 56 | matplotlib_reqs = ['matplotlib'] 57 | open3d_reqs = ['open3d-python'] 58 | 59 | setup(name=NAME, 60 | version=about['__version__'], 61 | description=DESCRIPTION, 62 | long_description=long_description, 63 | long_description_content_type='text/markdown', 64 | author=AUTHOR, 65 | author_email=EMAIL, 66 | python_requires=REQUIRES_PYTHON, 67 | url=URL, 68 | install_requires=[ 69 | 'numpy>=1.16.2', 70 | 'torch>=1.0.1.post2', 71 | 'torchgeometry>=0.1.2' 72 | ], 73 | extras_require={ 74 | 'pyrender': pyrender_reqs, 75 | 'open3d': open3d_reqs, 76 | 'matplotlib': matplotlib_reqs, 77 | 'all': pyrender_reqs + matplotlib_reqs + open3d_reqs 78 | }, 79 | packages=['smpl']) 80 | -------------------------------------------------------------------------------- /smpl/smpl/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from .body_models import create, SMPL 18 | -------------------------------------------------------------------------------- /smpl/smpl/joint_names.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | JOINT_NAMES = [ 18 | 'pelvis', 19 | 'left_hip', 20 | 'right_hip', 21 | 'spine1', 22 | 'left_knee', 23 | 'right_knee', 24 | 'spine2', 25 | 'left_ankle', 26 | 'right_ankle', 27 | 'spine3', 28 | 'left_foot', 29 | 'right_foot', 30 | 'neck', 31 | 'left_collar', 32 | 'right_collar', 33 | 'head', 34 | 'left_shoulder', 35 | 'right_shoulder', 36 | 'left_elbow', 37 | 'right_elbow', 38 | 'left_wrist', 39 | 'right_wrist', 40 | 'jaw', 41 | 'left_eye_smplhf', 42 | 'right_eye_smplhf', 43 | 'left_index1', 44 | 'left_index2', 45 | 'left_index3', 46 | 'left_middle1', 47 | 'left_middle2', 48 | 'left_middle3', 49 | 'left_pinky1', 50 | 'left_pinky2', 51 | 'left_pinky3', 52 | 'left_ring1', 53 | 'left_ring2', 54 | 'left_ring3', 55 | 'left_thumb1', 56 | 'left_thumb2', 57 | 'left_thumb3', 58 | 'right_index1', 59 | 'right_index2', 60 | 'right_index3', 61 | 'right_middle1', 62 | 'right_middle2', 63 | 'right_middle3', 64 | 'right_pinky1', 65 | 'right_pinky2', 66 | 'right_pinky3', 67 | 'right_ring1', 68 | 'right_ring2', 69 | 'right_ring3', 70 | 'right_thumb1', 71 | 'right_thumb2', 72 | 'right_thumb3', 73 | 'nose', 74 | 'right_eye', 75 | 'left_eye', 76 | 'right_ear', 77 | 'left_ear', 78 | 'left_big_toe', 79 | 'left_small_toe', 80 | 'left_heel', 81 | 'right_big_toe', 82 | 'right_small_toe', 83 | 'right_heel', 84 | 'left_thumb', 85 | 'left_index', 86 | 'left_middle', 87 | 'left_ring', 88 | 'left_pinky', 89 | 'right_thumb', 90 | 'right_index', 91 | 'right_middle', 92 | 'right_ring', 93 | 'right_pinky', 94 | 'right_eye_brow1', 95 | 'right_eye_brow2', 96 | 'right_eye_brow3', 97 | 'right_eye_brow4', 98 | 'right_eye_brow5', 99 | 'left_eye_brow5', 100 | 'left_eye_brow4', 101 | 'left_eye_brow3', 102 | 'left_eye_brow2', 103 | 'left_eye_brow1', 104 | 'nose1', 105 | 'nose2', 106 | 'nose3', 107 | 'nose4', 108 | 'right_nose_2', 109 | 'right_nose_1', 110 | 'nose_middle', 111 | 'left_nose_1', 112 | 'left_nose_2', 113 | 'right_eye1', 114 | 'right_eye2', 115 | 'right_eye3', 116 | 'right_eye4', 117 | 'right_eye5', 118 | 'right_eye6', 119 | 'left_eye4', 120 | 'left_eye3', 121 | 'left_eye2', 122 | 'left_eye1', 123 | 'left_eye6', 124 | 'left_eye5', 125 | 'right_mouth_1', 126 | 'right_mouth_2', 127 | 'right_mouth_3', 128 | 'mouth_top', 129 | 'left_mouth_3', 130 | 'left_mouth_2', 131 | 'left_mouth_1', 132 | 'left_mouth_5', # 59 in OpenPose output 133 | 'left_mouth_4', # 58 in OpenPose output 134 | 'mouth_bottom', 135 | 'right_mouth_4', 136 | 'right_mouth_5', 137 | 'right_lip_1', 138 | 'right_lip_2', 139 | 'lip_top', 140 | 'left_lip_2', 141 | 'left_lip_1', 142 | 'left_lip_3', 143 | 'lip_bottom', 144 | 'right_lip_3', 145 | # Face contour 146 | 'right_contour_1', 147 | 'right_contour_2', 148 | 'right_contour_3', 149 | 'right_contour_4', 150 | 'right_contour_5', 151 | 'right_contour_6', 152 | 'right_contour_7', 153 | 'right_contour_8', 154 | 'contour_middle', 155 | 'left_contour_8', 156 | 'left_contour_7', 157 | 'left_contour_6', 158 | 'left_contour_5', 159 | 'left_contour_4', 160 | 'left_contour_3', 161 | 'left_contour_2', 162 | 'left_contour_1', 163 | ] 164 | -------------------------------------------------------------------------------- /smpl/smpl/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 
5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | from .utils import rot_mat_to_euler 27 | 28 | 29 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, 30 | dynamic_lmk_b_coords, 31 | neck_kin_chain, dtype=torch.float32): 32 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks 33 | 34 | 35 | To do so, we first compute the rotation of the neck around the y-axis 36 | and then use a pre-computed look-up table to find the faces and the 37 | barycentric coordinates that will be used. 38 | 39 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) 40 | for providing the original TensorFlow implementation and for the LUT. 41 | 42 | Parameters 43 | ---------- 44 | vertices: torch.tensor BxVx3, dtype = torch.float32 45 | The tensor of input vertices 46 | pose: torch.tensor Bx(Jx3), dtype = torch.float32 47 | The current pose of the body model 48 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long 49 | The look-up table from neck rotation to faces 50 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 51 | The look-up table from neck rotation to barycentric coordinates 52 | neck_kin_chain: list 53 | A python list that contains the indices of the joints that form the 54 | kinematic chain of the neck. 55 | dtype: torch.dtype, optional 56 | 57 | Returns 58 | ------- 59 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long 60 | A tensor of size BxL that contains the indices of the faces that 61 | will be used to compute the current dynamic landmarks. 62 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 63 | A tensor of size BxL that contains the indices of the faces that 64 | will be used to compute the current dynamic landmarks. 
65 | ''' 66 | 67 | batch_size = vertices.shape[0] 68 | 69 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 70 | neck_kin_chain) 71 | rot_mats = batch_rodrigues( 72 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 73 | 74 | rel_rot_mat = torch.eye(3, device=vertices.device, 75 | dtype=dtype).unsqueeze_(dim=0) 76 | for idx in range(len(neck_kin_chain)): 77 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 78 | 79 | y_rot_angle = torch.round( 80 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 81 | max=39)).to(dtype=torch.long) 82 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 83 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 84 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 85 | y_rot_angle = (neg_mask * neg_vals + 86 | (1 - neg_mask) * y_rot_angle) 87 | 88 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 89 | 0, y_rot_angle) 90 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 91 | 0, y_rot_angle) 92 | 93 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 94 | 95 | 96 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords, output_transform=False, vertex_transform=None): 97 | ''' Calculates landmarks by barycentric interpolation 98 | 99 | Parameters 100 | ---------- 101 | vertices: torch.tensor BxVx3, dtype = torch.float32 102 | The tensor of input vertices 103 | faces: torch.tensor Fx3, dtype = torch.long 104 | The faces of the mesh 105 | lmk_faces_idx: torch.tensor L, dtype = torch.long 106 | The tensor with the indices of the faces used to calculate the 107 | landmarks. 108 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 109 | The tensor of barycentric coordinates that are used to interpolate 110 | the landmarks 111 | 112 | Returns 113 | ------- 114 | landmarks: torch.tensor BxLx3, dtype = torch.float32 115 | The coordinates of the landmarks for each mesh in the batch 116 | ''' 117 | # Extract the indices of the vertices for each face 118 | # BxLx3 119 | batch_size, num_verts = vertices.shape[:2] 120 | device = vertices.device 121 | 122 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 123 | batch_size, -1, 3) 124 | 125 | lmk_faces += torch.arange( 126 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts 127 | 128 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( 129 | batch_size, -1, 3, 3) 130 | 131 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 132 | if not output_transform: 133 | return landmarks 134 | 135 | lmk_transform = vertex_transform.view(-1, 4, 4)[lmk_faces].view( 136 | batch_size, -1, 3, 4, 4) 137 | landmarks_transform = torch.einsum('blfij, blf->blij', [lmk_transform, lmk_bary_coords]) 138 | return landmarks, landmarks_transform 139 | 140 | 141 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 142 | lbs_weights, pose2rot=True, body_neutral_v = None, dtype=torch.float32, custom_out=False): 143 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 144 | 145 | Parameters 146 | ---------- 147 | betas : torch.tensor BxNB 148 | The tensor of shape parameters 149 | pose : torch.tensor Bx(J + 1) * 3 150 | The pose parameters in axis-angle format 151 | v_template torch.tensor BxVx3 152 | The template mesh that will be deformed 153 | shapedirs : torch.tensor 1xNB 154 | The tensor of PCA shape displacements 155 | posedirs : torch.tensor Px(V * 3) 156 | The pose PCA coefficients 157 | J_regressor : torch.tensor JxV 158 | The regressor array 
that is used to calculate the joints from 159 | the position of the vertices 160 | parents: torch.tensor J 161 | The array that describes the kinematic tree for the model 162 | lbs_weights: torch.tensor N x V x (J + 1) 163 | The linear blend skinning weights that represent how much the 164 | rotation matrix of each part affects each vertex 165 | pose2rot: bool, optional 166 | Flag on whether to convert the input pose tensor to rotation 167 | matrices. The default value is True. If False, then the pose tensor 168 | should already contain rotation matrices and have a size of 169 | Bx(J + 1)x9 170 | dtype: torch.dtype, optional 171 | 172 | Returns 173 | ------- 174 | verts: torch.tensor BxVx3 175 | The vertices of the mesh after applying the shape and pose 176 | displacements. 177 | joints: torch.tensor BxJx3 178 | The joints of the model 179 | ''' 180 | 181 | batch_size = max(betas.shape[0], pose.shape[0]) 182 | device = betas.device 183 | 184 | # Add shape contribution 185 | v_shaped = v_template + blend_shapes(betas, shapedirs) 186 | 187 | if not body_neutral_v == None: 188 | v_shaped = body_neutral_v 189 | 190 | # Get the joints 191 | # NxJx3 array 192 | J = vertices2joints(J_regressor, v_shaped) 193 | 194 | # 3. Add pose blend shapes 195 | # N x J x 3 x 3 196 | ident = torch.eye(3, dtype=dtype, device=device) 197 | if pose2rot: 198 | rot_mats = batch_rodrigues( 199 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) 200 | 201 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 202 | # (N x P) x (P, V * 3) -> N x V x 3 203 | pose_offsets = torch.matmul(pose_feature, posedirs) \ 204 | .view(batch_size, -1, 3) 205 | else: 206 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 207 | rot_mats = pose.view(batch_size, -1, 3, 3) 208 | 209 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 210 | posedirs).view(batch_size, -1, 3) 211 | 212 | v_posed = pose_offsets + v_shaped 213 | # 4. Get the global joint location 214 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 215 | 216 | # 5. 
Do skinning: 217 | # W is N x V x (J + 1) 218 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 219 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 220 | num_joints = J_regressor.shape[0] 221 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ 222 | .view(batch_size, -1, 4, 4) 223 | 224 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], 225 | dtype=dtype, device=device) 226 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 227 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 228 | 229 | verts = v_homo[:, :, :3, 0] 230 | 231 | ret = {'verts': verts, 'joints': J_transformed} 232 | if custom_out: 233 | ret['vT'] = T 234 | ret['jT'] = A 235 | ret['v_shaped'] = v_shaped 236 | ret['v_posed'] = v_posed 237 | return ret 238 | 239 | 240 | def vertices2joints(J_regressor, vertices): 241 | ''' Calculates the 3D joint locations from the vertices 242 | 243 | Parameters 244 | ---------- 245 | J_regressor : torch.tensor JxV 246 | The regressor array that is used to calculate the joints from the 247 | position of the vertices 248 | vertices : torch.tensor BxVx3 249 | The tensor of mesh vertices 250 | 251 | Returns 252 | ------- 253 | torch.tensor BxJx3 254 | The location of the joints 255 | ''' 256 | 257 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 258 | 259 | 260 | def blend_shapes(betas, shape_disps): 261 | ''' Calculates the per vertex displacement due to the blend shapes 262 | 263 | 264 | Parameters 265 | ---------- 266 | betas : torch.tensor Bx(num_betas) 267 | Blend shape coefficients 268 | shape_disps: torch.tensor Vx3x(num_betas) 269 | Blend shapes 270 | 271 | Returns 272 | ------- 273 | torch.tensor BxVx3 274 | The per-vertex displacement due to shape deformation 275 | ''' 276 | 277 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 278 | # i.e. Multiply each shape displacement by its corresponding beta and 279 | # then sum them. 
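# (Illustrative note, not in the original source: with betas of shape
# (B, num_betas) and shape_disps of shape (V, 3, num_betas), the einsum
# below contracts the shared index l and yields a (B, V, 3) tensor of
# per-vertex offsets.)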
280 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 281 | return blend_shape 282 | 283 | 284 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 285 | ''' Calculates the rotation matrices for a batch of rotation vectors 286 | Parameters 287 | ---------- 288 | rot_vecs: torch.tensor Nx3 289 | array of N axis-angle vectors 290 | Returns 291 | ------- 292 | R: torch.tensor Nx3x3 293 | The rotation matrices for the given axis-angle parameters 294 | ''' 295 | 296 | batch_size = rot_vecs.shape[0] 297 | device = rot_vecs.device 298 | 299 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 300 | rot_dir = rot_vecs / angle 301 | 302 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 303 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 304 | 305 | # Bx1 arrays 306 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 307 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 308 | 309 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 310 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 311 | .view((batch_size, 3, 3)) 312 | 313 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 314 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 315 | return rot_mat 316 | 317 | 318 | def transform_mat(R, t): 319 | ''' Creates a batch of transformation matrices 320 | Args: 321 | - R: Bx3x3 array of a batch of rotation matrices 322 | - t: Bx3x1 array of a batch of translation vectors 323 | Returns: 324 | - T: Bx4x4 Transformation matrix 325 | ''' 326 | # No padding left or right, only add an extra row 327 | return torch.cat([F.pad(R, [0, 0, 0, 1]), 328 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 329 | 330 | 331 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 332 | """ 333 | Applies a batch of rigid transformations to the joints 334 | 335 | Parameters 336 | ---------- 337 | rot_mats : torch.tensor BxNx3x3 338 | Tensor of rotation matrices 339 | joints : torch.tensor BxNx3 340 | Locations of joints 341 | parents : torch.tensor BxN 342 | The kinematic tree of each object 343 | dtype : torch.dtype, optional: 344 | The data type of the created tensors, the default is torch.float32 345 | 346 | Returns 347 | ------- 348 | posed_joints : torch.tensor BxNx3 349 | The locations of the joints after applying the pose rotations 350 | rel_transforms : torch.tensor BxNx4x4 351 | The relative (with respect to the root joint) rigid transformations 352 | for all the joints 353 | """ 354 | 355 | joints = torch.unsqueeze(joints, dim=-1) 356 | 357 | rel_joints = joints.clone() 358 | rel_joints[:, 1:] -= joints[:, parents[1:]] 359 | 360 | transforms_mat = transform_mat( 361 | rot_mats.view(-1, 3, 3), 362 | rel_joints.reshape(-1, 3, 1)).view(-1, joints.shape[1], 4, 4) 363 | 364 | transform_chain = [transforms_mat[:, 0]] 365 | for i in range(1, parents.shape[0]): 366 | # Subtract the joint location at the rest pose 367 | # No need for rotation, since it's identity when at rest 368 | curr_res = torch.matmul(transform_chain[parents[i]], 369 | transforms_mat[:, i]) 370 | transform_chain.append(curr_res) 371 | 372 | transforms = torch.stack(transform_chain, dim=1) 373 | 374 | # The last column of the transformations contains the posed joints 375 | posed_joints = transforms[:, :, :3, 3] 376 | 377 | # The last column of the transformations contains the posed joints 378 | posed_joints = transforms[:, :, :3, 3] 379 | 380 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 381 | 382 | rel_transforms = transforms - 
F.pad( 383 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) 384 | 385 | return posed_joints, rel_transforms 386 | -------------------------------------------------------------------------------- /smpl/smpl/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import print_function 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | import numpy as np 22 | import torch 23 | 24 | 25 | def to_tensor(array, dtype=torch.float32): 26 | if 'torch.tensor' not in str(type(array)): 27 | return torch.tensor(array, dtype=dtype) 28 | 29 | 30 | class Struct(object): 31 | def __init__(self, **kwargs): 32 | for key, val in kwargs.items(): 33 | setattr(self, key, val) 34 | 35 | 36 | def to_np(array, dtype=np.float32): 37 | if 'scipy.sparse' in str(type(array)): 38 | array = array.todense() 39 | return np.array(array, dtype=dtype) 40 | 41 | 42 | def rot_mat_to_euler(rot_mats): 43 | # Calculates rotation matrix to euler angles 44 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 45 | 46 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 47 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 48 | return torch.atan2(-rot_mats[:, 2, 0], sy) 49 | -------------------------------------------------------------------------------- /smpl/smpl/vertex_ids.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import print_function 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | # Joint name to vertex mapping. 
SMPL/SMPL-H/SMPL-X vertices that correspond to 22 | # MSCOCO and OpenPose joints 23 | vertex_ids = { 24 | 'smplh': { 25 | 'nose': 332, 26 | 'reye': 6260, 27 | 'leye': 2800, 28 | 'rear': 4071, 29 | 'lear': 583, 30 | 'rthumb': 6191, 31 | 'rindex': 5782, 32 | 'rmiddle': 5905, 33 | 'rring': 6016, 34 | 'rpinky': 6133, 35 | 'lthumb': 2746, 36 | 'lindex': 2319, 37 | 'lmiddle': 2445, 38 | 'lring': 2556, 39 | 'lpinky': 2673, 40 | 'LBigToe': 3216, 41 | 'LSmallToe': 3226, 42 | 'LHeel': 3387, 43 | 'RBigToe': 6617, 44 | 'RSmallToe': 6624, 45 | 'RHeel': 6787 46 | }, 47 | 'smplx': { 48 | 'nose': 9120, 49 | 'reye': 9929, 50 | 'leye': 9448, 51 | 'rear': 616, 52 | 'lear': 6, 53 | 'rthumb': 8079, 54 | 'rindex': 7669, 55 | 'rmiddle': 7794, 56 | 'rring': 7905, 57 | 'rpinky': 8022, 58 | 'lthumb': 5361, 59 | 'lindex': 4933, 60 | 'lmiddle': 5058, 61 | 'lring': 5169, 62 | 'lpinky': 5286, 63 | 'LBigToe': 5770, 64 | 'LSmallToe': 5780, 65 | 'LHeel': 8846, 66 | 'RBigToe': 8463, 67 | 'RSmallToe': 8474, 68 | 'RHeel': 8635 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /smpl/smpl/vertex_joint_selector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from .utils import to_tensor 27 | 28 | 29 | class VertexJointSelector(nn.Module): 30 | 31 | def __init__(self, vertex_ids=None, 32 | use_hands=True, 33 | use_feet_keypoints=True, **kwargs): 34 | super(VertexJointSelector, self).__init__() 35 | 36 | extra_joints_idxs = [] 37 | 38 | face_keyp_idxs = np.array([ 39 | vertex_ids['nose'], 40 | vertex_ids['reye'], 41 | vertex_ids['leye'], 42 | vertex_ids['rear'], 43 | vertex_ids['lear']], dtype=np.int64) 44 | 45 | extra_joints_idxs = np.concatenate([extra_joints_idxs, 46 | face_keyp_idxs]) 47 | 48 | if use_feet_keypoints: 49 | feet_keyp_idxs = np.array([vertex_ids['LBigToe'], 50 | vertex_ids['LSmallToe'], 51 | vertex_ids['LHeel'], 52 | vertex_ids['RBigToe'], 53 | vertex_ids['RSmallToe'], 54 | vertex_ids['RHeel']], dtype=np.int32) 55 | 56 | extra_joints_idxs = np.concatenate( 57 | [extra_joints_idxs, feet_keyp_idxs]) 58 | 59 | if use_hands: 60 | self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] 61 | 62 | tips_idxs = [] 63 | for hand_id in ['l', 'r']: 64 | for tip_name in self.tip_names: 65 | tips_idxs.append(vertex_ids[hand_id + tip_name]) 66 | 67 | extra_joints_idxs = np.concatenate( 68 | [extra_joints_idxs, tips_idxs]) 69 | 70 | self.register_buffer('extra_joints_idxs', 71 | to_tensor(extra_joints_idxs, dtype=torch.long)) 72 | 73 | def forward(self, vertices, joints, output_joint_transform=False, vertex_transform=None, joints_transform=None): 74 | extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) 75 | joints = torch.cat([joints, extra_joints], dim=1) 76 | if not output_joint_transform: 77 | return joints 78 | 79 | extra_joints_transform = torch.index_select(vertex_transform, 1, self.extra_joints_idxs) 80 | joints_transform = torch.cat([joints_transform, extra_joints_transform], dim=1) 81 | return joints, joints_transform 82 | -------------------------------------------------------------------------------- /teaser/aist_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shunsukesaito/SCANimate/f2eeb5799fd20fd9d5933472f6aedf1560296cbe/teaser/aist_0.gif -------------------------------------------------------------------------------- /teaser/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shunsukesaito/SCANimate/f2eeb5799fd20fd9d5933472f6aedf1560296cbe/teaser/teaser.png --------------------------------------------------------------------------------
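For readers who want to poke at the skinning math in `smpl/smpl/lbs.py` without downloading the registered model files, the sketch below calls the `lbs()` routine shown above with tiny random tensors. All sizes, the toy kinematic chain, the dummy blend shapes/regressor/skinning weights, and the `smpl.lbs` import path are illustrative assumptions, not values shipped with the repository; the real SMPL parameters come from the model files obtained through the download links in `smpl/README.md`.

```python
# Minimal sketch: exercise lbs() from smpl/smpl/lbs.py with toy tensors.
# Every dimension below is made up; it only has to satisfy the documented shapes.
import torch
from smpl.lbs import lbs  # assumes the bundled `smpl` package has been installed

B, V, J, NB = 2, 6, 24, 10                  # batch, vertices, joints, shape coeffs
betas = torch.zeros(B, NB)                  # neutral shape
pose = 0.1 * torch.randn(B, J * 3)          # axis-angle pose, root joint included
v_template = torch.rand(B, V, 3)            # toy template mesh
shapedirs = torch.zeros(V, 3, NB)           # dummy shape blend shapes
posedirs = torch.zeros((J - 1) * 9, V * 3)  # dummy pose blend shapes
J_regressor = torch.rand(J, V)
J_regressor = J_regressor / J_regressor.sum(dim=1, keepdim=True)
parents = torch.cat([torch.zeros(1, dtype=torch.long), torch.arange(J - 1)])
lbs_weights = torch.zeros(V, J)
lbs_weights[:, 0] = 1.0                     # every vertex follows the root joint

out = lbs(betas, pose, v_template, shapedirs, posedirs,
          J_regressor, parents, lbs_weights, custom_out=True)
print(out['verts'].shape)   # (B, V, 3)    posed vertices
print(out['joints'].shape)  # (B, J, 3)    posed joints
print(out['vT'].shape)      # (B, V, 4, 4) per-vertex skinning transforms
```

The extra entries `vT`, `jT`, `v_shaped`, and `v_posed` returned when `custom_out=True` appear to be the main customization of this bundled package relative to the upstream smplx implementation it was adapted from.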