├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 1.gif ├── 2.gif └── 3.gif ├── config ├── config.yaml ├── datamodule │ ├── cape.yaml │ ├── dfaust.yaml │ └── jointlim.yaml └── experiments │ └── cape.yaml ├── demo.py ├── download_data.sh ├── environment.yml ├── lib ├── dataset │ ├── cape.py │ ├── dfaust.py │ ├── dfaust_split.yml │ └── jointlim.py ├── libmise │ ├── mise.cpp │ └── mise.pyx ├── model │ ├── broyden.py │ ├── deformer.py │ ├── helpers.py │ ├── metrics.py │ ├── network.py │ ├── sample.py │ └── smpl.py ├── smpl │ ├── body_models.py │ ├── lbs.py │ ├── utils.py │ ├── vertex_ids.py │ └── vertex_joint_selector.py ├── snarf_model.py └── utils │ ├── meshing.py │ └── render.py ├── preprocess ├── body_model.py ├── lbs.py ├── sample_points.py └── utils.py ├── setup.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | *.py,cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | cover/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | # For a library or package, you might want to ignore these files since the code is 86 | # intended to run in multiple environments; otherwise, check them in: 87 | # .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # pytype static type analyzer 134 | .pytype/ 135 | 136 | # Cython debug symbols 137 | cython_debug/ 138 | 139 | **.pyc 140 | **.in 141 | **.rst 142 | **.ini 143 | 144 | # Custom 145 | **.out 146 | data 147 | outputs 148 | wandb 149 | multirun 150 | lib/smpl/smpl_model 151 | **.pkl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Xu Chen, Yufeng Zheng, Michael J. Black, Otmar Hilliges, Andreas Geiger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SNARF: Differentiable Forward Skinning for Animating Non-rigid Neural Implicit Shapes 2 | ## [Paper](https://arxiv.org/pdf/2104.03953.pdf) | [Supp](https://bit.ly/3t1Tk6F) | [Video](https://youtu.be/rCEpFTKjFHE) | [Project Page](https://xuchen-ethz.github.io/snarf) | Blog ([AIT](https://eth-ait.medium.com/animate-implicit-shapes-with-forward-skinning-c7ebbf355694),[AVG](https://autonomousvision.github.io/snarf/)) 3 | 4 | 5 | 6 | 7 | Official code release for ICCV 2021 paper [*SNARF: Differentiable Forward Skinning for Animating Non-rigid Neural Implicit Shapes*](https://arxiv.org/pdf/2104.03953.pdf). We propose a novel forward skinning module to animate neural implicit shapes with good generalization to unseen poses. 8 | 9 | **Update:** we have released an improved version, FastSNARF, which is 150x faster than SNARF. Check it out [here](https://github.com/xuchen-ethz/fast-snarf). 
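For readers who want the core idea before diving into the code: SNARF represents skinning weights with an MLP in canonical space and, for every point in the posed space, searches for its canonical correspondence by root finding on the forward linear blend skinning (LBS) map, differentiating through the solution implicitly. The snippet below is a minimal sketch of that forward map only (tensor names are illustrative; the actual implementation is the `skinning` function in `lib/model/deformer.py`):

```
import torch
from torch import einsum
import torch.nn.functional as F

def lbs(x_c, w, tfs):
    # Forward linear blend skinning (sketch).
    # x_c: canonical points [B, N, 3]
    # w:   skinning weights [B, N, J], e.g. predicted by an MLP in canonical space
    # tfs: per-bone transformation matrices [B, J, 4, 4]
    # returns deformed points [B, N, 3]
    x_h = F.pad(x_c, (0, 1), value=1.0)               # homogeneous coordinates
    x_d = einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)    # weighted sum of bone transforms
    return x_d[:, :, :3]
```

Animating a shape then amounts to finding, for each deformed query point `x_d`, a canonical point `x_c` such that `lbs(x_c, w(x_c), tfs) = x_d`, which SNARF solves with Broyden's method (`lib/model/broyden.py`).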
10 | 11 | If you find our code or paper useful, please cite as 12 | ``` 13 | @inproceedings{chen2021snarf, 14 | title={SNARF: Differentiable Forward Skinning for Animating Non-Rigid Neural Implicit Shapes}, 15 | author={Chen, Xu and Zheng, Yufeng and Black, Michael J and Hilliges, Otmar and Geiger, Andreas}, 16 | booktitle={International Conference on Computer Vision (ICCV)}, 17 | year={2021} 18 | } 19 | ``` 20 | 21 | # Quick Start 22 | Clone this repo: 23 | ``` 24 | git clone https://github.com/xuchen-ethz/snarf.git 25 | cd snarf 26 | ``` 27 | 28 | Install the environment: 29 | ``` 30 | conda env create -f environment.yml 31 | conda activate snarf 32 | python setup.py install 33 | ``` 34 | 35 | 36 | Download the [SMPL models](https://smpl.is.tue.mpg.de) (version 1.0.0 for Python 2.7, with 10 shape PCs) and move them to the corresponding places: 37 | ``` 38 | mkdir lib/smpl/smpl_model/ 39 | mv /path/to/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl lib/smpl/smpl_model/SMPL_FEMALE.pkl 40 | mv /path/to/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl lib/smpl/smpl_model/SMPL_MALE.pkl 41 | ``` 42 | 43 | Download our pretrained models and test motion sequences: 44 | ``` 45 | sh ./download_data.sh 46 | ``` 47 | 48 | Run a quick demo for a clothed human: 49 | ``` 50 | python demo.py expname=cape subject=3375 demo.motion_path=data/aist_demo/seqs +experiments=cape 51 | ``` 52 | You can then find the video in `outputs/cape/3375/demo.mp4` and images in `outputs/cape/3375/images/`. To save the meshes, add `demo.save_mesh=true` to the command. 53 | 54 | You can also try other subjects (see `outputs/data/cape` for available options) by setting `subject=xx`, and other motion sequences from [AMASS](https://amass.is.tue.mpg.de/download.php) by setting `demo.motion_path=/path/to/amass_motion.npz`. 55 | 56 | Some motion sequences have a high frame rate and one might want to skip some frames. To do this, add `demo.every_n_frames=x` to consider only every x-th frame of the motion sequence (e.g. `demo.every_n_frames=10` for PosePrior sequences). 57 | 58 | By default, we use `demo.fast_mode=true` for fast mesh extraction. In this mode, we first extract the mesh in canonical space and then forward-skin it to the posed space. This bypasses the root finding during inference and is therefore faster, but it does not truly deform a continuous field. To first deform the continuous field and then extract the mesh in deformed space, use `demo.fast_mode=false` instead. 59 | 60 | # Training and Evaluation 61 | 62 | ## Install Additional Dependencies 63 | Install [kaolin](https://kaolin.readthedocs.io/en/latest/notes/installation.html) for fast occupancy queries from meshes. 64 | ``` 65 | git clone https://github.com/NVIDIAGameWorks/kaolin 66 | cd kaolin 67 | git checkout v0.9.0 68 | python setup.py develop 69 | ``` 70 | ## Minimally Clothed Human 71 | ### Prepare Datasets 72 | Download the [AMASS](https://amass.is.tue.mpg.de/download.php) dataset. We use the ''DFaust Synthetic'' and ''PosePrior'' subsets in SMPL-H format. Unzip the dataset into the `data` folder. 
73 | ``` 74 | tar -xf DFaust67.tar.bz2 -C data 75 | tar -xf MPILimits.tar.bz2 -C data 76 | ``` 77 | 78 | Preprocess the dataset: 79 | ``` 80 | python preprocess/sample_points.py --output_folder data/DFaust_processed 81 | python preprocess/sample_points.py --output_folder data/MPI_processed --skip 10 --poseprior 82 | ``` 83 | 84 | 85 | ### Training 86 | Run the following command to train for a specified subject: 87 | ``` 88 | python train.py subject=50002 89 | ``` 90 | Training logs are available on [wandb](https://wandb.ai/home) (registration needed, free of charge). Training should take ~12h on a single 2080Ti. 91 | 92 | ### Evaluation 93 | Run the following command to evaluate the method for a specified subject on within-distribution data (DFaust test split): 94 | ``` 95 | python test.py subject=50002 96 | ``` 97 | and on out-of-distribution data (PosePrior): 98 | ``` 99 | python test.py subject=50002 datamodule=jointlim 100 | ``` 101 | 102 | ### Generate Animation 103 | You can use the trained model to generate an animation (same as in Quick Start): 104 | ``` 105 | python demo.py expname='dfaust' subject=50002 demo.motion_path='data/aist_demo/seqs' 106 | ``` 107 | 108 | 109 | ## Clothed Human 110 | 111 | ### Training 112 | Download the [CAPE](https://cape.is.tue.mpg.de/) dataset and unzip it into the `data` folder. 113 | 114 | Run the following command to train for a specified subject and clothing type: 115 | ``` 116 | python train.py datamodule=cape subject=3375 datamodule.clothing='blazerlong' +experiments=cape 117 | ``` 118 | Training logs are available on [wandb](https://wandb.ai/home). Training should take ~24h on a single 2080Ti. 119 | 120 | ### Generate Animation 121 | You can use the trained model to generate an animation (same as in Quick Start): 122 | ``` 123 | python demo.py expname=cape subject=3375 demo.motion_path=data/aist_demo/seqs +experiments=cape 124 | ``` 125 | 126 | # Acknowledgement 127 | We use the pre-processing code in [PTF](https://github.com/taconite/PTF) and [LEAP](https://github.com/neuralbodies/leap) with some adaptations (`./preprocess`). The network and sampling parts of the code (`lib/model/network.py` and `lib/model/sample.py`) are implemented based on [IGR](https://github.com/amosgropp/IGR) and [IDR](https://github.com/lioryariv/idr). The code for extracting meshes (`lib/utils/meshing.py`) is adapted from [NASA](https://github.com/tensorflow/graphics/tree/master/tensorflow_graphics/projects/nasa). Our implementation of Broyden's method (`lib/model/broyden.py`) is based on [DEQ](https://github.com/locuslab/deq). We sincerely thank these authors for their awesome work. 
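As background for the root finding mentioned above: `lib/model/broyden.py` solves, for every deformed query point, the equation "forward-skinned canonical point = deformed point" with Broyden's method. Below is a minimal, generic single-point sketch of that solver (the function `g`, the name `broyden_sketch`, and the tolerances are illustrative, not the repository's API; the real implementation is batched and masked):

```
import torch

def broyden_sketch(g, x0, J_inv0, max_steps=50, tol=1e-6):
    # Toy Broyden solver for g(x) = 0 on a single point.
    x, J_inv = x0.clone(), J_inv0.clone()
    gx = g(x)
    for _ in range(max_steps):
        dx = -J_inv @ gx                  # quasi-Newton step
        x = x + dx
        gx_new = g(x)
        dg, gx = gx_new - gx, gx_new
        if gx.norm() < tol:               # converged
            break
        # rank-1 ("good" Broyden) update of the inverse Jacobian
        vT = dx @ J_inv
        denom = vT @ dg
        if denom.abs() < 1e-9:            # guard against division by zero
            denom = denom + 1e-9
        J_inv = J_inv + ((dx - J_inv @ dg) / denom).unsqueeze(1) * vT.unsqueeze(0)
    return x

# example: root of g(x) = (x0^2 - 2, x1 - 1), starting from (1, 0)
g = lambda x: torch.stack([x[0] ** 2 - 2.0, x[1] - 1.0])
print(broyden_sketch(g, torch.tensor([1.0, 0.0]), torch.eye(2)))  # -> approx. (1.414, 1.0)
```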
128 | -------------------------------------------------------------------------------- /assets/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuchen-ethz/snarf/ae0c893cc049f0f8270eaa401e138dff5d4637b9/assets/1.gif -------------------------------------------------------------------------------- /assets/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuchen-ethz/snarf/ae0c893cc049f0f8270eaa401e138dff5d4637b9/assets/2.gif -------------------------------------------------------------------------------- /assets/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuchen-ethz/snarf/ae0c893cc049f0f8270eaa401e138dff5d4637b9/assets/3.gif -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - datamodule: dfaust 3 | 4 | hydra: 5 | run: 6 | dir: outputs/${expname}/${subject} 7 | 8 | expname: dfaust 9 | subject: 50002 10 | epoch: last 11 | resume: false 12 | 13 | trainer: 14 | gradient_clip_val: 0.1 15 | check_val_every_n_epoch: 5 16 | deterministic: true 17 | max_steps: 45000 18 | gpus: 1 19 | model: 20 | # shape MLP 21 | network: 22 | d_in: 3 23 | d_out: 1 24 | depth: 8 25 | width: 256 26 | multires: 0 27 | skip_layer: [4] 28 | cond_layer: [4] 29 | dim_cond_embed: 8 30 | weight_norm: true 31 | geometric_init: true 32 | bias: 1 33 | deformer: 34 | softmax_mode: hierarchical 35 | # LBS MLP 36 | network: 37 | d_in: 3 38 | d_out: 25 39 | depth: 4 40 | width: 128 41 | multires: 0 42 | skip_layer: [] 43 | cond_layer: [] 44 | dim_cond_embed: 0 45 | weight_norm: true 46 | geometric_init: false 47 | bias: 1 48 | optim: 49 | lr: 1e-3 50 | soft_blend: 5 51 | nepochs_pretrain: 1 52 | lambda_bone_occ: 1 53 | lambda_bone_w: 10 54 | 55 | demo: 56 | motion_path: data/aist_demo/seqs 57 | resolution: 256 58 | save_mesh: false 59 | every_n_frames: 2 60 | output_video_name: aist 61 | verbose: false 62 | fast_mode: true -------------------------------------------------------------------------------- /config/datamodule/cape.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lib.dataset.cape.CAPEDataModule 3 | dataset_path: ./data/CAPE 4 | num_workers: 10 5 | subject: ${subject} 6 | clothing: 'longshort' 7 | batch_size: 8 8 | processor: 9 | _target_: lib.dataset.cape.CAPEDataProcessor 10 | points_per_frame: 2000 11 | sampler: 12 | _target_: lib.model.sample.PointInSpace 13 | global_sigma: 1.8 14 | local_sigma: 0.01 15 | -------------------------------------------------------------------------------- /config/datamodule/dfaust.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lib.dataset.dfaust.DFaustDataModule 3 | dataset_path: ./data/DFaust_processed 4 | subject: ${subject} 5 | num_workers: 10 6 | batch_size: 8 7 | points_per_frame: 2000 8 | 9 | -------------------------------------------------------------------------------- /config/datamodule/jointlim.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lib.dataset.jointlim.JointLimDataModule 3 | dataset_path: ./data/MPILimits_processed 4 | subject: ${subject} 5 | num_workers: 10 
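  # Note: this datamodule appears to be used only for out-of-distribution evaluation
  # (`python test.py subject=50002 datamodule=jointlim`); JointLimDataModule in
  # lib/dataset/jointlim.py only builds a test dataloader with batch_size=1, so no
  # batch_size or sampler options are needed here.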
-------------------------------------------------------------------------------- /config/experiments/cape.yaml: -------------------------------------------------------------------------------- 1 | expname: cape 2 | trainer: 3 | max_steps: 100000 4 | model: 5 | network: 6 | multires: 4 7 | cond_layer: [0] 8 | dim_cond_embed: -1 9 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import hydra 5 | import torch 6 | import imageio 7 | import numpy as np 8 | import pytorch_lightning as pl 9 | 10 | from tqdm import trange 11 | from lib.snarf_model import SNARFModel 12 | from lib.model.helpers import rectify_pose 13 | 14 | @hydra.main(config_path="config", config_name="config") 15 | def main(opt): 16 | 17 | print(opt.pretty()) 18 | pl.seed_everything(42, workers=True) 19 | 20 | # set up model 21 | meta_info = np.load('meta_info.npz') 22 | 23 | if opt.epoch == 'last': 24 | checkpoint_path = './checkpoints/last.ckpt' 25 | else: 26 | checkpoint_path = glob.glob('./checkpoints/epoch=%d*.ckpt'%opt.epoch)[0] 27 | 28 | model = SNARFModel.load_from_checkpoint( 29 | checkpoint_path=checkpoint_path, 30 | opt=opt.model, 31 | meta_info=meta_info 32 | ).cuda() 33 | # use all bones for initialization during testing 34 | model.deformer.init_bones = np.arange(24) 35 | 36 | # pose format conversion 37 | smplx_to_smpl = list(range(66)) + [72, 73, 74, 117, 118, 119] # SMPLH to SMPL 38 | 39 | # load motion sequence 40 | motion_path = hydra.utils.to_absolute_path(opt.demo.motion_path) 41 | if os.path.isdir(motion_path): 42 | motion_files = sorted(glob.glob(os.path.join(motion_path, '*.npz'))) 43 | smpl_params_all = [] 44 | for f in motion_files: 45 | f = np.load(f) 46 | smpl_params = np.zeros(86) 47 | smpl_params[0], smpl_params[4:76] = 1, f['pose'] 48 | smpl_params = torch.tensor(smpl_params).float().cuda() 49 | smpl_params_all.append(smpl_params) 50 | smpl_params_all = torch.stack(smpl_params_all) 51 | 52 | elif '.npz' in motion_path: 53 | f = np.load(motion_path) 54 | smpl_params_all = np.zeros( (f['poses'].shape[0], 86) ) 55 | smpl_params_all[:,0] = 1 56 | if f['poses'].shape[-1] == 72: 57 | smpl_params_all[:, 4:76] = f['poses'] 58 | elif f['poses'].shape[-1] == 156: 59 | smpl_params_all[:, 4:76] = f['poses'][:,smplx_to_smpl] 60 | 61 | root_abs = smpl_params_all[0, 4:7].copy() 62 | for i in range(smpl_params_all.shape[0]): 63 | smpl_params_all[i, 4:7] = rectify_pose(smpl_params_all[i, 4:7], root_abs) 64 | 65 | smpl_params_all = torch.tensor(smpl_params_all).float().cuda() 66 | 67 | smpl_params_all = smpl_params_all[::opt.demo.every_n_frames] 68 | # generate data batch 69 | images = [] 70 | for i in trange(smpl_params_all.shape[0]): 71 | 72 | smpl_params = smpl_params_all[[i]] 73 | data = model.smpl_server.forward(smpl_params, absolute=True) 74 | data['smpl_thetas'] = smpl_params[:, 4:76] 75 | results = model.plot(data, res=opt.demo.resolution, verbose=opt.demo.verbose, fast_mode=opt.demo.fast_mode) 76 | 77 | images.append(results['img_all']) 78 | 79 | if not os.path.exists('images'): 80 | os.makedirs('images') 81 | imageio.imwrite('images/%04d.png'%i, results['img_all']) 82 | 83 | if opt.demo.save_mesh: 84 | if not os.path.exists('meshes'): 85 | os.makedirs('meshes') 86 | results['mesh_def'].export('meshes/%04d_def.ply'%i) 87 | if opt.demo.verbose: 88 | results['mesh_cano'].export('meshes/%04d_cano.ply'%i) 89 | 90 | 
imageio.mimsave('%s.mp4'%opt.demo.output_video_name, images) 91 | 92 | if __name__ == '__main__': 93 | main() -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | 3 | echo Downloading motion sequences... 4 | 5 | wget https://scanimate.is.tue.mpg.de/media/upload/demo_data/aist_demo_seq.zip 6 | unzip aist_demo_seq.zip -d ./data/ 7 | rm aist_demo_seq.zip 8 | mv ./data/gLO_sBM_cAll_d14_mLO1_ch05 ./data/aist_demo 9 | 10 | echo Done! 11 | 12 | 13 | mkdir outputs 14 | 15 | echo Downloading pretrained models ... 16 | wget https://dataset.ait.ethz.ch/downloads/fOUiBuCXJy/pretrained_models.zip 17 | unzip pretrained_models.zip -d ./outputs/ 18 | mv outputs/pretrained_models/* outputs/ 19 | rm outputs/pretrained_models -rf 20 | rm pretrained_models.zip 21 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: snarf 2 | channels: 3 | - anaconda 4 | - conda-forge 5 | - pytorch 6 | - pytorch3d 7 | - defaults 8 | dependencies: 9 | - python=3.7 10 | - pytorch=1.7.0 11 | - torchvision=0.8.0 12 | - cudatoolkit=11.0 13 | - iopath 14 | - fvcore 15 | - pytorch3d 16 | - pip 17 | - pip: 18 | - chumpy 19 | - scikit-image 20 | - imageio-ffmpeg 21 | - trimesh 22 | - wandb 23 | - tqdm 24 | - pytorch-lightning==1.3.3 25 | - opencv-python==4.5.2.54 26 | - hydra-core==1.0.6 27 | - usd-core==20.11 # for kaolin 28 | - Cython==0.29.20 # for kaolin -------------------------------------------------------------------------------- /lib/dataset/cape.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | 5 | from torch.utils.data import DataLoader, Dataset 6 | import numpy as np 7 | import os 8 | import glob 9 | import hydra 10 | 11 | import pickle 12 | import kaolin 13 | 14 | 15 | class CAPEDataSet(Dataset): 16 | 17 | def __init__(self, dataset_path, subject=32, clothing='longshort'): 18 | 19 | dataset_path = hydra.utils.to_absolute_path(dataset_path) 20 | 21 | self.regstr_list = glob.glob(os.path.join(dataset_path, 'cape_release', 'sequences', '%05d'%subject, clothing+'_**/*.npz'), recursive=True) 22 | 23 | genders_list = os.path.join(dataset_path, 'cape_release', 'misc', 'subj_genders.pkl') 24 | with open(genders_list,'rb') as f: 25 | self.gender = pickle.load(f, encoding='latin1')['%05d'%subject] 26 | 27 | minimal_body_path = os.path.join(dataset_path, 'cape_release', 'minimal_body_shape', '%05d'%subject, '%05d_minimal.npy'%subject) 28 | self.v_template = np.load(minimal_body_path) 29 | 30 | self.meta_info = {'v_template': self.v_template, 'gender': self.gender} 31 | 32 | 33 | def __getitem__(self, index): 34 | 35 | data = {} 36 | 37 | while True: 38 | try: 39 | regstr = np.load(self.regstr_list[index]) 40 | poses = regstr['pose'] 41 | break 42 | except: 43 | index = np.random.randint(self.__len__()) 44 | print('corrupted npz') 45 | 46 | verts = regstr['v_posed'] - regstr['transl'][None,:] 47 | verts = torch.tensor(verts).float() 48 | 49 | smpl_params = torch.zeros([86]).float() 50 | smpl_params[0] = 1 51 | smpl_params[4:76] = torch.tensor(poses).float() 52 | 53 | data['scan_verts'] = verts 54 | data['smpl_params'] = smpl_params 55 | data['smpl_thetas'] = smpl_params[4:76] 56 | data['smpl_betas'] = smpl_params[76:] 57 | 58 | return data 59 | 60 | 
def __len__(self): 61 | return len(self.regstr_list) 62 | 63 | ''' Used to generate groud-truth occupancy and bone transformations in batchs during training ''' 64 | class CAPEDataProcessor(): 65 | 66 | def __init__(self, opt, meta_info, **kwargs): 67 | from lib.model.smpl import SMPLServer 68 | 69 | self.opt = opt 70 | self.gender = meta_info['gender'] 71 | self.v_template =meta_info['v_template'] 72 | 73 | self.smpl_server = SMPLServer(gender=self.gender, v_template=self.v_template) 74 | self.smpl_faces = torch.tensor(self.smpl_server.smpl.faces.astype('int')).unsqueeze(0).cuda() 75 | self.sampler = hydra.utils.instantiate(opt.sampler) 76 | 77 | def process(self, data): 78 | 79 | smpl_output = self.smpl_server(data['smpl_params'], absolute=True) 80 | data.update(smpl_output) 81 | 82 | num_batch, num_verts, num_dim = smpl_output['smpl_verts'].shape 83 | 84 | random_idx = torch.randint(0, num_verts, [num_batch, self.opt.points_per_frame,1], device=smpl_output['smpl_verts'].device) 85 | 86 | random_pts = torch.gather(data['scan_verts'], 1, random_idx.expand(-1, -1, num_dim)) 87 | data['pts_d'] = self.sampler.get_points(random_pts) 88 | 89 | data['occ_gt'] = kaolin.ops.mesh.check_sign(data['scan_verts'], self.smpl_faces[0], data['pts_d']).float().unsqueeze(-1) 90 | 91 | return data 92 | 93 | class CAPEDataModule(pl.LightningDataModule): 94 | 95 | def __init__(self, opt, **kwargs): 96 | super().__init__() 97 | self.opt = opt 98 | 99 | def setup(self, stage=None): 100 | 101 | if stage == 'fit': 102 | self.dataset_train = CAPEDataSet(dataset_path=self.opt.dataset_path, 103 | subject=self.opt.subject, 104 | clothing=self.opt.clothing) 105 | 106 | self.dataset_val = CAPEDataSet(dataset_path=self.opt.dataset_path, 107 | subject=self.opt.subject, 108 | clothing=self.opt.clothing) 109 | 110 | self.meta_info = self.dataset_val.meta_info 111 | 112 | 113 | def train_dataloader(self): 114 | dataloader = DataLoader(self.dataset_train, 115 | batch_size=self.opt.batch_size, 116 | num_workers=self.opt.num_workers, 117 | shuffle=True, 118 | drop_last=True, 119 | pin_memory=True) 120 | return dataloader 121 | 122 | def val_dataloader(self): 123 | dataloader = DataLoader(self.dataset_val, 124 | batch_size=self.opt.batch_size, 125 | num_workers=self.opt.num_workers, 126 | shuffle=False, 127 | drop_last=False, 128 | pin_memory=True) 129 | return dataloader 130 | 131 | def test_dataloader(self): 132 | dataloader = DataLoader(self.dataset_val, 133 | batch_size=1, 134 | num_workers=self.opt.num_workers, 135 | shuffle=False, 136 | drop_last=False, 137 | pin_memory=True) 138 | return dataloader 139 | -------------------------------------------------------------------------------- /lib/dataset/dfaust.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import yaml 4 | import hydra 5 | import torch 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | class DFaustDataset(Dataset): 12 | def __init__(self, dataset_path, stage, subject, points_per_frame=2000): 13 | 14 | dataset_path = hydra.utils.to_absolute_path(dataset_path) 15 | split_path = hydra.utils.to_absolute_path('lib/dataset/dfaust_split.yml') 16 | with open(split_path, 'r') as stream: 17 | split = yaml.safe_load(stream) 18 | 19 | self.stage = stage 20 | 21 | shape = np.load(os.path.join(dataset_path, 'shapes', '%d_shape.npz'%subject)) 22 | self.betas = torch.tensor(shape['betas'][:10]) 23 | self.gender = str(shape['gender']) 
24 | self.meta_info = {'betas': shape['betas'][:10], 'gender': self.gender} 25 | 26 | self.frame_names = [] 27 | for name in split[str(subject)][stage]: 28 | frame_name_act = glob.glob(os.path.join(dataset_path, 'points', str(subject), str(subject)+ '_' + name+'_poses'+'*.npz')) 29 | self.frame_names.extend(frame_name_act) 30 | 31 | self.points_per_frame = points_per_frame 32 | self.total_points = 100000 33 | 34 | 35 | def __getitem__(self, i): 36 | name = self.frame_names[i] 37 | data = {} 38 | dataset = np.load(name) 39 | if self.stage == 'train': 40 | random_idx = torch.cat([torch.randint(0,self.total_points,[self.points_per_frame//8]), # 1//8 for bbox samples 41 | torch.randint(0,self.total_points,[self.points_per_frame])+self.total_points], # 1 for surface samples 42 | 0) 43 | data['pts_d'] = torch.tensor(dataset['points'][0, random_idx]).float() 44 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, random_idx]).float().unsqueeze(-1) 45 | elif self.stage == 'val': 46 | data['pts_d'] = torch.tensor(dataset['points'][0, ::20]).float() 47 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, ::20]).float().unsqueeze(-1) 48 | elif self.stage == 'test': 49 | data['pts_d'] = torch.tensor(dataset['points'][0, :]).float() 50 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, :]).float().unsqueeze(-1) 51 | 52 | data['smpl_verts'] = torch.tensor(dataset['vertices']) 53 | data['smpl_tfs'] = torch.tensor(dataset['bone_transforms']).inverse() 54 | data['smpl_jnts'] = torch.tensor(dataset['joints']) 55 | data['smpl_thetas'] = torch.tensor(dataset['pose']) 56 | data['smpl_betas'] = self.betas 57 | 58 | return data 59 | 60 | def __len__(self): 61 | return len(self.frame_names) 62 | 63 | class DFaustDataModule(pl.LightningDataModule): 64 | 65 | def __init__(self, opt, **kwargs): 66 | super().__init__() 67 | self.opt = opt 68 | 69 | def setup(self, stage=None): 70 | 71 | if stage == 'fit': 72 | self.dataset_train = DFaustDataset(dataset_path=self.opt.dataset_path, 73 | stage = 'train', 74 | subject=self.opt.subject, 75 | points_per_frame=self.opt.points_per_frame, 76 | ) 77 | 78 | self.dataset_val = DFaustDataset(dataset_path=self.opt.dataset_path, 79 | stage = 'val', 80 | subject=self.opt.subject, 81 | ) 82 | 83 | self.meta_info = self.dataset_train.meta_info 84 | 85 | elif stage == 'test': 86 | self.dataset_test = DFaustDataset(dataset_path=self.opt.dataset_path, 87 | stage = 'test', 88 | subject=self.opt.subject, 89 | ) 90 | 91 | self.meta_info = self.dataset_test.meta_info 92 | 93 | def train_dataloader(self): 94 | dataloader = DataLoader(self.dataset_train, 95 | batch_size=self.opt.batch_size, 96 | num_workers=self.opt.num_workers, 97 | shuffle=True, 98 | drop_last=True, 99 | pin_memory=True) 100 | return dataloader 101 | 102 | def val_dataloader(self): 103 | dataloader = DataLoader(self.dataset_val, 104 | batch_size=1, 105 | num_workers=self.opt.num_workers, 106 | shuffle=False, 107 | drop_last=False, 108 | pin_memory=True) 109 | return dataloader 110 | 111 | def test_dataloader(self): 112 | dataloader = DataLoader(self.dataset_test, 113 | batch_size=1, 114 | num_workers=self.opt.num_workers, 115 | shuffle=False, 116 | drop_last=False, 117 | pin_memory=True) 118 | return dataloader 119 | -------------------------------------------------------------------------------- /lib/dataset/dfaust_split.yml: -------------------------------------------------------------------------------- 1 | '50002': 2 | test: 3 | - chicken_wings 4 | train: 5 | - hips 6 | - jiggle_on_toes 7 | - 
jumping_jacks 8 | - knees 9 | - light_hopping_loose 10 | - light_hopping_stiff 11 | - one_leg_jump 12 | - one_leg_loose 13 | - punching 14 | val: 15 | - chicken_wings 16 | '50004': 17 | test: 18 | - hips 19 | train: 20 | - chicken_wings 21 | - jiggle_on_toes 22 | - jumping_jacks 23 | - knees 24 | - light_hopping_loose 25 | - light_hopping_stiff 26 | - one_leg_jump 27 | - one_leg_loose 28 | - punching 29 | val: 30 | - hips 31 | '50007': 32 | test: 33 | - jiggle_on_toes 34 | train: 35 | - chicken_wings 36 | - jumping_jacks 37 | - knees 38 | - light_hopping_loose 39 | - light_hopping_stiff 40 | - one_leg_jump 41 | - one_leg_loose 42 | - punching 43 | - running_on_spot 44 | val: 45 | - jiggle_on_toes 46 | '50009': 47 | test: 48 | - jumping_jacks 49 | train: 50 | - chicken_wings 51 | - hips 52 | - jiggle_on_toes 53 | - light_hopping_loose 54 | - light_hopping_stiff 55 | - one_leg_jump 56 | - one_leg_loose 57 | - punching 58 | - running_on_spot 59 | val: 60 | - jumping_jacks 61 | '50020': 62 | test: 63 | - knees 64 | train: 65 | - chicken_wings 66 | - hips 67 | - jiggle_on_toes 68 | - light_hopping_loose 69 | - light_hopping_stiff 70 | - one_leg_jump 71 | - one_leg_loose 72 | - punching 73 | - running_on_spot 74 | val: 75 | - knees 76 | '50021': 77 | test: 78 | - light_hopping_stiff 79 | train: 80 | - chicken_wings 81 | - hips 82 | - knees 83 | - one_leg_jump 84 | - one_leg_loose 85 | - punching 86 | - running_on_spot 87 | - shake_arms 88 | - shake_hips 89 | val: 90 | - light_hopping_stiff 91 | '50022': 92 | test: 93 | - light_hopping_loose 94 | train: 95 | - hips 96 | - jiggle_on_toes 97 | - jumping_jacks 98 | - knees 99 | - light_hopping_stiff 100 | - one_leg_jump 101 | - one_leg_loose 102 | - punching 103 | - running_on_spot 104 | val: 105 | - light_hopping_loose 106 | '50025': 107 | test: 108 | - one_leg_jump 109 | train: 110 | - chicken_wings 111 | - hips 112 | - jiggle_on_toes 113 | - knees 114 | - light_hopping_loose 115 | - light_hopping_stiff 116 | - one_leg_loose 117 | - punching 118 | - running_on_spot 119 | val: 120 | - one_leg_jump 121 | '50026': 122 | test: 123 | - one_leg_loose 124 | train: 125 | - chicken_wings 126 | - hips 127 | - jiggle_on_toes 128 | - jumping_jacks 129 | - knees 130 | - light_hopping_loose 131 | - light_hopping_stiff 132 | - one_leg_jump 133 | - punching 134 | val: 135 | - one_leg_loose 136 | '50027': 137 | test: 138 | - punching 139 | train: 140 | - hips 141 | - jiggle_on_toes 142 | - jumping_jacks 143 | - knees 144 | - light_hopping_loose 145 | - light_hopping_stiff 146 | - one_leg_jump 147 | - one_leg_loose 148 | - running_on_spot 149 | val: 150 | - punching -------------------------------------------------------------------------------- /lib/dataset/jointlim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | 5 | from torch.utils.data import DataLoader, Dataset 6 | import numpy as np 7 | import os 8 | import glob 9 | import hydra 10 | 11 | class JointLimDataSet(Dataset): 12 | def __init__(self, dataset_path, subject): 13 | 14 | dataset_path = hydra.utils.to_absolute_path(dataset_path) 15 | 16 | seqs = ['op2', 'op3', 'op4', 'op5', 'op7', 'op8', 'op9'] 17 | 18 | shape = np.load(os.path.join(dataset_path, 'shapes', '%d_shape.npz'%subject)) 19 | self.betas = torch.tensor(shape['betas'][:10]) 20 | self.gender = str(shape['gender']) 21 | self.meta_info = {'betas': shape['betas'][:10], 'gender': self.gender} 22 | 23 | self.frame_names = [] 24 | 25 | for seq in 
seqs: 26 | self.frame_names += sorted(glob.glob(os.path.join(dataset_path, 'points', str(subject), seq + '_poses'+'*.npz'))) 27 | 28 | 29 | def __getitem__(self, index): 30 | 31 | name = self.frame_names[index] 32 | data = {} 33 | 34 | dataset = np.load(name) 35 | 36 | data['pts_d'] = torch.tensor(dataset['points'][0]).float() 37 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, :]).float().unsqueeze(-1) 38 | 39 | data['smpl_verts'] = torch.tensor(dataset['vertices']) 40 | data['smpl_tfs'] = torch.tensor(dataset['bone_transforms']).inverse() 41 | data['smpl_jnts'] = torch.tensor(dataset['joints']) 42 | data['smpl_thetas'] = torch.tensor(dataset['pose']) 43 | data['smpl_betas'] = self.betas 44 | 45 | return data 46 | 47 | def __len__(self): 48 | return len(self.frame_names) 49 | 50 | 51 | class JointLimDataModule(pl.LightningDataModule): 52 | 53 | def __init__(self, opt, **kwargs): 54 | super().__init__() 55 | self.opt = opt 56 | 57 | def setup(self, stage=None): 58 | 59 | if stage == 'test': 60 | self.dataset_test = JointLimDataSet(dataset_path=self.opt.dataset_path, 61 | subject=self.opt.subject, 62 | ) 63 | 64 | self.meta_info = self.dataset_test.meta_info 65 | 66 | def test_dataloader(self): 67 | dataloader = DataLoader(self.dataset_test, 68 | batch_size=1, 69 | num_workers=self.opt.num_workers, 70 | shuffle=False, 71 | drop_last=False, 72 | pin_memory=True) 73 | return dataloader 74 | -------------------------------------------------------------------------------- /lib/libmise/mise.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | cimport cython 3 | from cython.operator cimport dereference as dref 4 | from libcpp.vector cimport vector 5 | from libcpp.map cimport map 6 | from libc.math cimport isnan, NAN 7 | import numpy as np 8 | 9 | 10 | cdef struct Vector3D: 11 | int x, y, z 12 | 13 | 14 | cdef struct Voxel: 15 | Vector3D loc 16 | unsigned int level 17 | bint is_leaf 18 | unsigned long children[2][2][2] 19 | 20 | 21 | cdef struct GridPoint: 22 | Vector3D loc 23 | double value 24 | bint known 25 | 26 | 27 | cdef inline unsigned long vec_to_idx(Vector3D coord, long resolution): 28 | cdef unsigned long idx 29 | idx = resolution * resolution * coord.x + resolution * coord.y + coord.z 30 | return idx 31 | 32 | 33 | cdef class MISE: 34 | cdef vector[Voxel] voxels 35 | cdef vector[GridPoint] grid_points 36 | cdef map[long, long] grid_point_hash 37 | cdef readonly int resolution_0 38 | cdef readonly int depth 39 | cdef readonly double threshold 40 | cdef readonly int voxel_size_0 41 | cdef readonly int resolution 42 | 43 | def __cinit__(self, int resolution_0, int depth, double threshold): 44 | self.resolution_0 = resolution_0 45 | self.depth = depth 46 | self.threshold = threshold 47 | self.voxel_size_0 = (1 << depth) 48 | self.resolution = resolution_0 * self.voxel_size_0 49 | 50 | # Create initial voxels 51 | self.voxels.reserve(resolution_0 * resolution_0 * resolution_0) 52 | 53 | cdef Voxel voxel 54 | cdef GridPoint point 55 | cdef Vector3D loc 56 | cdef int i, j, k 57 | for i in range(resolution_0): 58 | for j in range(resolution_0): 59 | for k in range (resolution_0): 60 | loc = Vector3D( 61 | i * self.voxel_size_0, 62 | j * self.voxel_size_0, 63 | k * self.voxel_size_0, 64 | ) 65 | voxel = Voxel( 66 | loc=loc, 67 | level=0, 68 | is_leaf=True, 69 | ) 70 | 71 | assert(self.voxels.size() == vec_to_idx(Vector3D(i, j, k), resolution_0)) 72 | self.voxels.push_back(voxel) 73 | 74 | # Create initial grid points 
75 | self.grid_points.reserve((resolution_0 + 1) * (resolution_0 + 1) * (resolution_0 + 1)) 76 | for i in range(resolution_0 + 1): 77 | for j in range(resolution_0 + 1): 78 | for k in range(resolution_0 + 1): 79 | loc = Vector3D( 80 | i * self.voxel_size_0, 81 | j * self.voxel_size_0, 82 | k * self.voxel_size_0, 83 | ) 84 | assert(self.grid_points.size() == vec_to_idx(Vector3D(i, j, k), resolution_0 + 1)) 85 | self.add_grid_point(loc) 86 | 87 | def update(self, long[:, :] points, double[:] values): 88 | """Update points and set their values. Also determine all active voxels and subdivide them.""" 89 | assert(points.shape[0] == values.shape[0]) 90 | assert(points.shape[1] == 3) 91 | cdef Vector3D loc 92 | cdef long idx 93 | cdef int i 94 | 95 | # Find all indices of point and set value 96 | for i in range(points.shape[0]): 97 | loc = Vector3D(points[i, 0], points[i, 1], points[i, 2]) 98 | idx = self.get_grid_point_idx(loc) 99 | if idx == -1: 100 | raise ValueError('Point not in grid!') 101 | self.grid_points[idx].value = values[i] 102 | self.grid_points[idx].known = True 103 | # Subdivide activate voxels and add new points 104 | self.subdivide_voxels() 105 | 106 | def query(self): 107 | """Query points to evaluate.""" 108 | # Find all points with unknown value 109 | cdef vector[Vector3D] points 110 | cdef int n_unknown = 0 111 | for p in self.grid_points: 112 | if not p.known: 113 | n_unknown += 1 114 | 115 | points.reserve(n_unknown) 116 | for p in self.grid_points: 117 | if not p.known: 118 | points.push_back(p.loc) 119 | 120 | # Convert to numpy 121 | points_np = np.zeros((points.size(), 3), dtype=np.int64) 122 | cdef long[:, :] points_view = points_np 123 | for i in range(points.size()): 124 | points_view[i, 0] = points[i].x 125 | points_view[i, 1] = points[i].y 126 | points_view[i, 2] = points[i].z 127 | 128 | return points_np 129 | 130 | def to_dense(self): 131 | """Output dense matrix at highest resolution.""" 132 | out_array = np.full((self.resolution + 1,) * 3, np.nan) 133 | cdef double[:, :, :] out_view = out_array 134 | cdef GridPoint point 135 | cdef int i, j, k 136 | 137 | for point in self.grid_points: 138 | # Take voxel for which points is upper left corner 139 | # assert(point.known) 140 | out_view[point.loc.x, point.loc.y, point.loc.z] = point.value 141 | 142 | # Complete along x axis 143 | for i in range(1, self.resolution + 1): 144 | for j in range(self.resolution + 1): 145 | for k in range(self.resolution + 1): 146 | if isnan(out_view[i, j, k]): 147 | out_view[i, j, k] = out_view[i-1, j, k] 148 | 149 | # Complete along y axis 150 | for i in range(self.resolution + 1): 151 | for j in range(1, self.resolution + 1): 152 | for k in range(self.resolution + 1): 153 | if isnan(out_view[i, j, k]): 154 | out_view[i, j, k] = out_view[i, j-1, k] 155 | 156 | 157 | # Complete along z axis 158 | for i in range(self.resolution + 1): 159 | for j in range(self.resolution + 1): 160 | for k in range(1, self.resolution + 1): 161 | if isnan(out_view[i, j, k]): 162 | out_view[i, j, k] = out_view[i, j, k-1] 163 | assert(not isnan(out_view[i, j, k])) 164 | return out_array 165 | 166 | def get_points(self): 167 | points_np = np.zeros((self.grid_points.size(), 3), dtype=np.int64) 168 | values_np = np.zeros((self.grid_points.size()), dtype=np.float64) 169 | 170 | cdef long[:, :] points_view = points_np 171 | cdef double[:] values_view = values_np 172 | cdef Vector3D loc 173 | cdef int i 174 | 175 | for i in range(self.grid_points.size()): 176 | loc = self.grid_points[i].loc 177 | points_view[i, 0] 
= loc.x 178 | points_view[i, 1] = loc.y 179 | points_view[i, 2] = loc.z 180 | values_view[i] = self.grid_points[i].value 181 | 182 | return points_np, values_np 183 | 184 | cdef void subdivide_voxels(self) except +: 185 | cdef vector[bint] next_to_positive 186 | cdef vector[bint] next_to_negative 187 | cdef int i, j, k 188 | cdef long idx 189 | cdef Vector3D loc, adj_loc 190 | 191 | # Initialize vectors 192 | next_to_positive.resize(self.voxels.size(), False) 193 | next_to_negative.resize(self.voxels.size(), False) 194 | 195 | # Iterate over grid points and mark voxels active 196 | # TODO: can move this to update operation and add attibute to voxel 197 | for grid_point in self.grid_points: 198 | loc = grid_point.loc 199 | if not grid_point.known: 200 | continue 201 | 202 | # Iterate over the 8 adjacent voxels 203 | for i in range(-1, 1): 204 | for j in range(-1, 1): 205 | for k in range(-1, 1): 206 | adj_loc = Vector3D( 207 | x=loc.x + i, 208 | y=loc.y + j, 209 | z=loc.z + k, 210 | ) 211 | idx = self.get_voxel_idx(adj_loc) 212 | if idx == -1: 213 | continue 214 | 215 | if grid_point.value >= self.threshold: 216 | next_to_positive[idx] = True 217 | if grid_point.value <= self.threshold: 218 | next_to_negative[idx] = True 219 | 220 | cdef int n_subdivide = 0 221 | 222 | for idx in range(self.voxels.size()): 223 | if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth: 224 | continue 225 | if next_to_positive[idx] and next_to_negative[idx]: 226 | n_subdivide += 1 227 | 228 | self.voxels.reserve(self.voxels.size() + 8 * n_subdivide) 229 | self.grid_points.reserve(self.voxels.size() + 19 * n_subdivide) 230 | 231 | for idx in range(self.voxels.size()): 232 | if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth: 233 | continue 234 | if next_to_positive[idx] and next_to_negative[idx]: 235 | self.subdivide_voxel(idx) 236 | 237 | cdef void subdivide_voxel(self, long idx): 238 | cdef Voxel voxel 239 | cdef GridPoint point 240 | cdef Vector3D loc0 = self.voxels[idx].loc 241 | cdef Vector3D loc 242 | cdef int new_level = self.voxels[idx].level + 1 243 | cdef int new_size = 1 << (self.depth - new_level) 244 | assert(new_level <= self.depth) 245 | assert(1 <= new_size <= self.voxel_size_0) 246 | 247 | # Current voxel is not leaf anymore 248 | self.voxels[idx].is_leaf = False 249 | # Add new voxels 250 | cdef int i, j, k 251 | for i in range(2): 252 | for j in range(2): 253 | for k in range(2): 254 | loc = Vector3D( 255 | x=loc0.x + i * new_size, 256 | y=loc0.y + j * new_size, 257 | z=loc0.z + k * new_size, 258 | ) 259 | voxel = Voxel( 260 | loc=loc, 261 | level=new_level, 262 | is_leaf=True 263 | ) 264 | 265 | self.voxels[idx].children[i][j][k] = self.voxels.size() 266 | self.voxels.push_back(voxel) 267 | 268 | # Add new grid points 269 | for i in range(3): 270 | for j in range(3): 271 | for k in range(3): 272 | loc = Vector3D( 273 | loc0.x + i * new_size, 274 | loc0.y + j * new_size, 275 | loc0.z + k * new_size, 276 | ) 277 | 278 | # Only add new grid points 279 | if self.get_grid_point_idx(loc) == -1: 280 | self.add_grid_point(loc) 281 | 282 | 283 | @cython.cdivision(True) 284 | cdef long get_voxel_idx(self, Vector3D loc) except +: 285 | """Utility function for getting voxel index corresponding to 3D coordinates.""" 286 | # Shorthands 287 | cdef long resolution = self.resolution 288 | cdef long resolution_0 = self.resolution_0 289 | cdef long depth = self.depth 290 | cdef long voxel_size_0 = self.voxel_size_0 291 | 292 | # Return -1 if point lies outside bounds 293 | 
if not (0 <= loc.x < resolution and 0<= loc.y < resolution and 0 <= loc.z < resolution): 294 | return -1 295 | 296 | # Coordinates in coarse voxel grid 297 | cdef Vector3D loc0 = Vector3D( 298 | x=loc.x >> depth, 299 | y=loc.y >> depth, 300 | z=loc.z >> depth, 301 | ) 302 | 303 | # Initial voxels 304 | cdef int idx = vec_to_idx(loc0, resolution_0) 305 | cdef Voxel voxel = self.voxels[idx] 306 | assert(voxel.loc.x == loc0.x * voxel_size_0) 307 | assert(voxel.loc.y == loc0.y * voxel_size_0) 308 | assert(voxel.loc.z == loc0.z * voxel_size_0) 309 | 310 | # Relative coordinates 311 | cdef Vector3D loc_rel = Vector3D( 312 | x=loc.x - (loc0.x << depth), 313 | y=loc.y - (loc0.y << depth), 314 | z=loc.z - (loc0.z << depth), 315 | ) 316 | 317 | cdef Vector3D loc_offset 318 | cdef long voxel_size = voxel_size_0 319 | 320 | while not voxel.is_leaf: 321 | voxel_size = voxel_size >> 1 322 | assert(voxel_size >= 1) 323 | 324 | # Determine child 325 | loc_offset = Vector3D( 326 | x=1 if (loc_rel.x >= voxel_size) else 0, 327 | y=1 if (loc_rel.y >= voxel_size) else 0, 328 | z=1 if (loc_rel.z >= voxel_size) else 0, 329 | ) 330 | # New voxel 331 | idx = voxel.children[loc_offset.x][loc_offset.y][loc_offset.z] 332 | voxel = self.voxels[idx] 333 | 334 | # New relative coordinates 335 | loc_rel = Vector3D( 336 | x=loc_rel.x - loc_offset.x * voxel_size, 337 | y=loc_rel.y - loc_offset.y * voxel_size, 338 | z=loc_rel.z - loc_offset.z * voxel_size, 339 | ) 340 | 341 | assert(0<= loc_rel.x < voxel_size) 342 | assert(0<= loc_rel.y < voxel_size) 343 | assert(0<= loc_rel.z < voxel_size) 344 | 345 | 346 | # Return idx 347 | return idx 348 | 349 | 350 | cdef inline void add_grid_point(self, Vector3D loc): 351 | cdef GridPoint point = GridPoint( 352 | loc=loc, 353 | value=0., 354 | known=False, 355 | ) 356 | self.grid_point_hash[vec_to_idx(loc, self.resolution + 1)] = self.grid_points.size() 357 | self.grid_points.push_back(point) 358 | 359 | cdef inline int get_grid_point_idx(self, Vector3D loc): 360 | p_idx = self.grid_point_hash.find(vec_to_idx(loc, self.resolution + 1)) 361 | if p_idx == self.grid_point_hash.end(): 362 | return -1 363 | 364 | cdef int idx = dref(p_idx).second 365 | assert(self.grid_points[idx].loc.x == loc.x) 366 | assert(self.grid_points[idx].loc.y == loc.y) 367 | assert(self.grid_points[idx].loc.z == loc.z) 368 | 369 | return idx -------------------------------------------------------------------------------- /lib/model/broyden.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def broyden(g, x_init, J_inv_init, max_steps=50, cvg_thresh=1e-5, dvg_thresh=1, eps=1e-6): 5 | """Find roots of the given function g(x) = 0. 6 | This function is impleneted based on https://github.com/locuslab/deq. 7 | 8 | Tensor shape abbreviation: 9 | N: number of points 10 | D: space dimension 11 | Args: 12 | g (function): the function of which the roots are to be determined. shape: [N, D, 1]->[N, D, 1] 13 | x_init (tensor): initial value of the parameters. shape: [N, D, 1] 14 | J_inv_init (tensor): initial value of the inverse Jacobians. shape: [N, D, D] 15 | 16 | max_steps (int, optional): max number of iterations. Defaults to 50. 17 | cvg_thresh (float, optional): covergence threshold. Defaults to 1e-5. 18 | dvg_thresh (float, optional): divergence threshold. Defaults to 1. 19 | eps (float, optional): a small number added to the denominator to prevent numerical error. Defaults to 1e-6. 20 | 21 | Returns: 22 | result (tensor): root of the given function. 
shape: [N, D, 1] 23 | diff (tensor): corresponding loss. [N] 24 | valid_ids (tensor): identifiers of converged points. [N] 25 | """ 26 | 27 | # initialization 28 | x = x_init.clone().detach() 29 | J_inv = J_inv_init.clone().detach() 30 | 31 | ids_val = torch.ones(x.shape[0]).bool() 32 | 33 | gx = g(x, mask=ids_val) 34 | update = -J_inv.bmm(gx) 35 | 36 | x_opt = x 37 | gx_norm_opt = torch.linalg.norm(gx.squeeze(-1), dim=-1) 38 | 39 | delta_gx = torch.zeros_like(gx) 40 | delta_x = torch.zeros_like(x) 41 | 42 | ids_val = torch.ones_like(gx_norm_opt).bool() 43 | 44 | for _ in range(max_steps): 45 | 46 | # update paramter values 47 | delta_x[ids_val] = update 48 | x[ids_val] += delta_x[ids_val] 49 | delta_gx[ids_val] = g(x, mask=ids_val) - gx[ids_val] 50 | gx[ids_val] += delta_gx[ids_val] 51 | 52 | # store values with minial loss 53 | gx_norm = torch.linalg.norm(gx.squeeze(-1), dim=-1) 54 | ids_opt = gx_norm < gx_norm_opt 55 | gx_norm_opt[ids_opt] = gx_norm.clone().detach()[ids_opt] 56 | x_opt[ids_opt] = x.clone().detach()[ids_opt] 57 | 58 | # exclude converged and diverged points from furture iterations 59 | ids_val = (gx_norm_opt > cvg_thresh) & (gx_norm < dvg_thresh) 60 | if ids_val.sum() <= 0: 61 | break 62 | 63 | # compute paramter update for next iter 64 | vT = (delta_x[ids_val]).transpose(-1, -2).bmm(J_inv[ids_val]) 65 | a = delta_x[ids_val] - J_inv[ids_val].bmm(delta_gx[ids_val]) 66 | b = vT.bmm(delta_gx[ids_val]) 67 | b[b >= 0] += eps 68 | b[b < 0] -= eps 69 | u = a / b 70 | J_inv[ids_val] += u.bmm(vT) 71 | update = -J_inv[ids_val].bmm(gx[ids_val]) 72 | 73 | return {'result': x_opt, 'diff': gx_norm_opt, 'valid_ids': gx_norm_opt < cvg_thresh} 74 | -------------------------------------------------------------------------------- /lib/model/deformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import einsum 3 | import torch.nn.functional as F 4 | 5 | from lib.model.broyden import broyden 6 | from lib.model.network import ImplicitNetwork 7 | from lib.model.helpers import hierarchical_softmax 8 | 9 | 10 | class ForwardDeformer(torch.nn.Module): 11 | """ 12 | Tensor shape abbreviation: 13 | B: batch size 14 | N: number of points 15 | J: number of bones 16 | I: number of init 17 | D: space dimension 18 | """ 19 | 20 | def __init__(self, opt, **kwargs): 21 | super().__init__() 22 | 23 | self.opt = opt 24 | 25 | self.lbs_network = ImplicitNetwork(**self.opt.network) 26 | 27 | self.soft_blend = 20 28 | 29 | self.init_bones = [0, 1, 2, 4, 5, 16, 17, 18, 19] 30 | 31 | def forward(self, xd, cond, tfs, eval_mode=False): 32 | """Given deformed point return its caonical correspondence 33 | 34 | Args: 35 | xd (tensor): deformed points in batch. shape: [B, N, D] 36 | cond (dict): conditional input. 37 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] 38 | 39 | Returns: 40 | xc (tensor): canonical correspondences. shape: [B, N, I, D] 41 | others (dict): other useful outputs. 
42 | """ 43 | xc_init = self.init(xd, tfs) 44 | 45 | xc_opt, others = self.search(xd, xc_init, cond, tfs, eval_mode=eval_mode) 46 | 47 | if eval_mode: 48 | return xc_opt, others 49 | 50 | # compute correction term for implicit differentiation during training 51 | 52 | # do not back-prop through broyden 53 | xc_opt = xc_opt.detach() 54 | 55 | # reshape to [B,?,D] for network query 56 | n_batch, n_point, n_init, n_dim = xc_init.shape 57 | xc_opt = xc_opt.reshape((n_batch, n_point * n_init, n_dim)) 58 | 59 | xd_opt = self.forward_skinning(xc_opt, cond, tfs) 60 | 61 | grad_inv = self.gradient(xc_opt, cond, tfs).inverse() 62 | 63 | correction = xd_opt - xd_opt.detach() 64 | correction = einsum("bnij,bnj->bni", -grad_inv.detach(), correction) 65 | 66 | # trick for implicit diff with autodiff: 67 | # xc = xc_opt + 0 and xc' = correction' 68 | xc = xc_opt + correction 69 | 70 | # reshape back to [B,N,I,D] 71 | xc = xc.reshape(xc_init.shape) 72 | 73 | return xc, others 74 | 75 | def init(self, xd, tfs): 76 | """Transform xd to canonical space for initialization 77 | 78 | Args: 79 | xd (tensor): deformed points in batch. shape: [B, N, D] 80 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] 81 | 82 | Returns: 83 | xc_init (tensor): gradients. shape: [B, N, I, D] 84 | """ 85 | n_batch, n_point, _ = xd.shape 86 | _, n_joint, _, _ = tfs.shape 87 | 88 | xc_init = [] 89 | for i in self.init_bones: 90 | w = torch.zeros((n_batch, n_point, n_joint), device=xd.device) 91 | w[:, :, i] = 1 92 | xc_init.append(skinning(xd, w, tfs, inverse=True)) 93 | 94 | xc_init = torch.stack(xc_init, dim=2) 95 | 96 | return xc_init 97 | 98 | def search(self, xd, xc_init, cond, tfs, eval_mode=False): 99 | """Search correspondences. 100 | 101 | Args: 102 | xd (tensor): deformed points in batch. shape: [B, N, D] 103 | xc_init (tensor): deformed points in batch. shape: [B, N, I, D] 104 | cond (dict): conditional input. 105 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] 106 | 107 | Returns: 108 | xc_opt (tensor): canonoical correspondences of xd. shape: [B, N, I, D] 109 | valid_ids (tensor): identifiers of converged points. [B, N, I] 110 | """ 111 | # reshape to [B,?,D] for other functions 112 | n_batch, n_point, n_init, n_dim = xc_init.shape 113 | xc_init = xc_init.reshape(n_batch, n_point * n_init, n_dim) 114 | xd_tgt = xd.repeat_interleave(n_init, dim=1) 115 | 116 | # compute init jacobians 117 | if not eval_mode: 118 | J_inv_init = self.gradient(xc_init, cond, tfs).inverse() 119 | else: 120 | w = self.query_weights(xc_init, cond, mask=None) 121 | J_inv_init = einsum("bpn,bnij->bpij", w, tfs)[:, :, :3, :3].inverse() 122 | 123 | # reshape init to [?,D,...] 
for boryden 124 | xc_init = xc_init.reshape(-1, n_dim, 1) 125 | J_inv_init = J_inv_init.flatten(0, 1) 126 | 127 | # construct function for root finding 128 | def _func(xc_opt, mask=None): 129 | # reshape to [B,?,D] for other functions 130 | xc_opt = xc_opt.reshape(n_batch, n_point * n_init, n_dim) 131 | xd_opt = self.forward_skinning(xc_opt, cond, tfs, mask=mask) 132 | error = xd_opt - xd_tgt 133 | # reshape to [?,D,1] for boryden 134 | error = error.flatten(0, 1)[mask].unsqueeze(-1) 135 | return error 136 | 137 | # run broyden without grad 138 | with torch.no_grad(): 139 | result = broyden(_func, xc_init, J_inv_init) 140 | 141 | # reshape back to [B,N,I,D] 142 | xc_opt = result["result"].reshape(n_batch, n_point, n_init, n_dim) 143 | result["valid_ids"] = result["valid_ids"].reshape(n_batch, n_point, n_init) 144 | 145 | return xc_opt, result 146 | 147 | def forward_skinning(self, xc, cond, tfs, mask=None): 148 | """Canonical point -> deformed point 149 | 150 | Args: 151 | xc (tensor): canonoical points in batch. shape: [B, N, D] 152 | cond (dict): conditional input. 153 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] 154 | 155 | Returns: 156 | xd (tensor): deformed point. shape: [B, N, D] 157 | """ 158 | w = self.query_weights(xc, cond, mask=mask) 159 | xd = skinning(xc, w, tfs, inverse=False) 160 | return xd 161 | 162 | def query_weights(self, xc, cond, mask=None): 163 | """Get skinning weights in canonical space 164 | 165 | Args: 166 | xc (tensor): canonical points. shape: [B, N, D] 167 | cond (dict): conditional input. 168 | mask (tensor, optional): valid indices. shape: [B, N] 169 | 170 | Returns: 171 | w (tensor): skinning weights. shape: [B, N, J] 172 | """ 173 | 174 | w = self.lbs_network(xc, cond, mask) 175 | w = self.soft_blend * w 176 | 177 | if self.opt.softmax_mode == "hierarchical": 178 | w = hierarchical_softmax(w) 179 | else: 180 | w = F.softmax(w, dim=-1) 181 | 182 | return w 183 | 184 | def gradient(self, xc, cond, tfs): 185 | """Get gradients df/dx 186 | 187 | Args: 188 | xc (tensor): canonical points. shape: [B, N, D] 189 | cond (dict): conditional input. 190 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] 191 | 192 | Returns: 193 | grad (tensor): gradients. shape: [B, N, D, D] 194 | """ 195 | xc.requires_grad_(True) 196 | 197 | xd = self.forward_skinning(xc, cond, tfs) 198 | 199 | grads = [] 200 | for i in range(xd.shape[-1]): 201 | d_out = torch.zeros_like(xd, requires_grad=False, device=xd.device) 202 | d_out[:, :, i] = 1 203 | grad = torch.autograd.grad( 204 | outputs=xd, 205 | inputs=xc, 206 | grad_outputs=d_out, 207 | create_graph=False, 208 | retain_graph=True, 209 | only_inputs=True, 210 | )[0] 211 | grads.append(grad) 212 | 213 | return torch.stack(grads, dim=-2) 214 | 215 | 216 | def skinning(x, w, tfs, inverse=False): 217 | """Linear blend skinning 218 | 219 | Args: 220 | x (tensor): canonical points. shape: [B, N, D] 221 | w (tensor): conditional input. [B, N, J] 222 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] 223 | Returns: 224 | x (tensor): skinned points. 
shape: [B, N, D] 225 | """ 226 | x_h = F.pad(x, (0, 1), value=1.0) 227 | 228 | if inverse: 229 | # p:n_point, n:n_bone, i,k: n_dim+1 230 | w_tf = einsum("bpn,bnij->bpij", w, tfs) 231 | x_h = einsum("bpij,bpj->bpi", w_tf.inverse(), x_h) 232 | else: 233 | x_h = einsum("bpn,bnij,bpj->bpi", w, tfs, x_h) 234 | 235 | return x_h[:, :, :3] 236 | -------------------------------------------------------------------------------- /lib/model/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import cv2 4 | import numpy as np 5 | 6 | def masked_softmax(vec, mask, dim=-1, mode='softmax', soft_blend=1): 7 | if mode == 'softmax': 8 | 9 | vec = torch.distributions.Bernoulli(logits=vec).probs 10 | 11 | masked_exps = torch.exp(soft_blend*vec) * mask.float() 12 | masked_exps_sum = masked_exps.sum(dim) 13 | 14 | output = torch.zeros_like(vec) 15 | output[masked_exps_sum>0,:] = masked_exps[masked_exps_sum>0,:]/ masked_exps_sum[masked_exps_sum>0].unsqueeze(-1) 16 | 17 | output = (output * vec).sum(dim, keepdim=True) 18 | 19 | output = torch.distributions.Bernoulli(probs=output).logits 20 | 21 | elif mode == 'max': 22 | vec[~mask] = -math.inf 23 | output = torch.max(vec, dim, keepdim=True)[0] 24 | 25 | return output 26 | 27 | 28 | ''' Hierarchical softmax following the kinematic tree of the human body. Imporves convergence speed''' 29 | def hierarchical_softmax(x): 30 | def softmax(x): 31 | return torch.nn.functional.softmax(x, dim=-1) 32 | 33 | def sigmoid(x): 34 | return torch.sigmoid(x) 35 | 36 | n_batch, n_point, n_dim = x.shape 37 | x = x.flatten(0,1) 38 | 39 | prob_all = torch.ones(n_batch * n_point, 24, device=x.device) 40 | 41 | prob_all[:, [1, 2, 3]] = prob_all[:, [0]] * sigmoid(x[:, [0]]) * softmax(x[:, [1, 2, 3]]) 42 | prob_all[:, [0]] = prob_all[:, [0]] * (1 - sigmoid(x[:, [0]])) 43 | 44 | prob_all[:, [4, 5, 6]] = prob_all[:, [1, 2, 3]] * (sigmoid(x[:, [4, 5, 6]])) 45 | prob_all[:, [1, 2, 3]] = prob_all[:, [1, 2, 3]] * (1 - sigmoid(x[:, [4, 5, 6]])) 46 | 47 | prob_all[:, [7, 8, 9]] = prob_all[:, [4, 5, 6]] * (sigmoid(x[:, [7, 8, 9]])) 48 | prob_all[:, [4, 5, 6]] = prob_all[:, [4, 5, 6]] * (1 - sigmoid(x[:, [7, 8, 9]])) 49 | 50 | prob_all[:, [10, 11]] = prob_all[:, [7, 8]] * (sigmoid(x[:, [10, 11]])) 51 | prob_all[:, [7, 8]] = prob_all[:, [7, 8]] * (1 - sigmoid(x[:, [10, 11]])) 52 | 53 | prob_all[:, [12, 13, 14]] = prob_all[:, [9]] * sigmoid(x[:, [24]]) * softmax(x[:, [12, 13, 14]]) 54 | prob_all[:, [9]] = prob_all[:, [9]] * (1 - sigmoid(x[:, [24]])) 55 | 56 | prob_all[:, [15]] = prob_all[:, [12]] * (sigmoid(x[:, [15]])) 57 | prob_all[:, [12]] = prob_all[:, [12]] * (1 - sigmoid(x[:, [15]])) 58 | 59 | prob_all[:, [16, 17]] = prob_all[:, [13, 14]] * (sigmoid(x[:, [16, 17]])) 60 | prob_all[:, [13, 14]] = prob_all[:, [13, 14]] * (1 - sigmoid(x[:, [16, 17]])) 61 | 62 | prob_all[:, [18, 19]] = prob_all[:, [16, 17]] * (sigmoid(x[:, [18, 19]])) 63 | prob_all[:, [16, 17]] = prob_all[:, [16, 17]] * (1 - sigmoid(x[:, [18, 19]])) 64 | 65 | prob_all[:, [20, 21]] = prob_all[:, [18, 19]] * (sigmoid(x[:, [20, 21]])) 66 | prob_all[:, [18, 19]] = prob_all[:, [18, 19]] * (1 - sigmoid(x[:, [20, 21]])) 67 | 68 | prob_all[:, [22, 23]] = prob_all[:, [20, 21]] * (sigmoid(x[:, [22, 23]])) 69 | prob_all[:, [20, 21]] = prob_all[:, [20, 21]] * (1 - sigmoid(x[:, [22, 23]])) 70 | 71 | prob_all = prob_all.reshape(n_batch, n_point, prob_all.shape[-1]) 72 | return prob_all 73 | 74 | def rectify_pose(pose, root_abs): 75 | """ 76 | Rectify AMASS pose in global 
coord adapted from https://github.com/akanazawa/hmr/issues/50. 77 | 78 | Args: 79 | pose (72,): Pose. 80 | 81 | Returns: 82 | Rotated pose. 83 | """ 84 | pose = pose.copy() 85 | R_abs = cv2.Rodrigues(root_abs)[0] 86 | R_root = cv2.Rodrigues(pose[:3])[0] 87 | new_root = np.linalg.inv(R_abs).dot(R_root) 88 | pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3) 89 | return pose -------------------------------------------------------------------------------- /lib/model/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calculate_iou(gt, prediction): 4 | intersection = torch.logical_and(gt, prediction) 5 | union = torch.logical_or(gt, prediction) 6 | return torch.sum(intersection) / torch.sum(union) -------------------------------------------------------------------------------- /lib/model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | """ MLP for neural implicit shapes. The code is based on https://github.com/lioryariv/idr with adaption. """ 5 | class ImplicitNetwork(torch.nn.Module): 6 | def __init__( 7 | self, 8 | d_in, 9 | d_out, 10 | width, 11 | depth, 12 | geometric_init=True, 13 | bias=1.0, 14 | weight_norm=True, 15 | multires=0, 16 | skip_layer=[], 17 | cond_layer=[], 18 | cond_dim=69, 19 | dim_cond_embed=-1, 20 | ): 21 | super().__init__() 22 | 23 | dims = [d_in] + [width] * depth + [d_out] 24 | self.num_layers = len(dims) 25 | 26 | self.embed_fn = None 27 | if multires > 0: 28 | embed_fn, input_ch = get_embedder(multires) 29 | self.embed_fn = embed_fn 30 | dims[0] = input_ch 31 | 32 | self.cond_layer = cond_layer 33 | self.cond_dim = cond_dim 34 | 35 | self.dim_cond_embed = dim_cond_embed 36 | if dim_cond_embed > 0: 37 | self.lin_p0 = torch.nn.Linear(self.cond_dim, dim_cond_embed) 38 | self.cond_dim = dim_cond_embed 39 | 40 | self.skip_layer = skip_layer 41 | 42 | for l in range(0, self.num_layers - 1): 43 | if l + 1 in self.skip_layer: 44 | out_dim = dims[l + 1] - dims[0] 45 | else: 46 | out_dim = dims[l + 1] 47 | 48 | if l in self.cond_layer: 49 | lin = torch.nn.Linear(dims[l] + self.cond_dim, out_dim) 50 | else: 51 | lin = torch.nn.Linear(dims[l], out_dim) 52 | 53 | if geometric_init: 54 | if l == self.num_layers - 2: 55 | torch.nn.init.normal_( 56 | lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001 57 | ) 58 | torch.nn.init.constant_(lin.bias, bias) 59 | elif multires > 0 and l == 0: 60 | torch.nn.init.constant_(lin.bias, 0.0) 61 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 62 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 63 | elif multires > 0 and l in self.skip_layer: 64 | torch.nn.init.constant_(lin.bias, 0.0) 65 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 66 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) 67 | else: 68 | torch.nn.init.constant_(lin.bias, 0.0) 69 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 70 | 71 | if weight_norm: 72 | lin = torch.nn.utils.weight_norm(lin) 73 | 74 | setattr(self, "lin" + str(l), lin) 75 | 76 | self.softplus = torch.nn.Softplus(beta=100) 77 | 78 | def forward(self, input, cond, mask=None): 79 | """MPL query. 80 | 81 | Tensor shape abbreviation: 82 | B: batch size 83 | N: number of points 84 | D: input dimension 85 | 86 | Args: 87 | input (tensor): network input. shape: [B, N, D] 88 | cond (dict): conditional input. 
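(Added note: in this codebase the dict is expected to carry cond['smpl'], a [B, cond_dim] pose-conditioning tensor with cond_dim defaulting to 69; the forward pass broadcasts it to every point and, when dim_cond_embed > 0, projects it with lin_p0 before concatenating it at the layers listed in cond_layer.)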
89 | mask (tensor, optional): only masked inputs are fed into the network. shape: [B, N] 90 | 91 | Returns: 92 | output (tensor): network output. Might contains placehold if mask!=None shape: [N, D, ?] 93 | """ 94 | 95 | 96 | n_batch, n_point, n_dim = input.shape 97 | 98 | if n_batch * n_point == 0: 99 | return input 100 | 101 | # reshape to [N,?] 102 | input = input.reshape(n_batch * n_point, n_dim) 103 | if mask is not None: 104 | input = input[mask] 105 | 106 | input_embed = input if self.embed_fn is None else self.embed_fn(input) 107 | 108 | if len(self.cond_layer): 109 | cond = cond["smpl"] 110 | n_batch, n_cond = cond.shape 111 | input_cond = cond.unsqueeze(1).expand(n_batch, n_point, n_cond) 112 | input_cond = input_cond.reshape(n_batch * n_point, n_cond) 113 | 114 | if mask is not None: 115 | input_cond = input_cond[mask] 116 | 117 | if self.dim_cond_embed > 0: 118 | input_cond = self.lin_p0(input_cond) 119 | 120 | x = input_embed 121 | 122 | for l in range(0, self.num_layers - 1): 123 | lin = getattr(self, "lin" + str(l)) 124 | if l in self.cond_layer: 125 | x = torch.cat([x, input_cond], dim=-1) 126 | 127 | if l in self.skip_layer: 128 | x = torch.cat([x, input_embed], 1) / np.sqrt(2) 129 | 130 | x = lin(x) 131 | 132 | if l < self.num_layers - 2: 133 | x = self.softplus(x) 134 | 135 | # add placeholder for masked prediction 136 | if mask is not None: 137 | x_full = torch.zeros(n_batch * n_point, x.shape[-1], device=x.device) 138 | x_full[mask] = x 139 | else: 140 | x_full = x 141 | 142 | return x_full.reshape(n_batch, n_point, -1) 143 | 144 | 145 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ 146 | class Embedder: 147 | def __init__(self, **kwargs): 148 | self.kwargs = kwargs 149 | self.create_embedding_fn() 150 | 151 | def create_embedding_fn(self): 152 | embed_fns = [] 153 | d = self.kwargs["input_dims"] 154 | out_dim = 0 155 | if self.kwargs["include_input"]: 156 | embed_fns.append(lambda x: x) 157 | out_dim += d 158 | 159 | max_freq = self.kwargs["max_freq_log2"] 160 | N_freqs = self.kwargs["num_freqs"] 161 | 162 | if self.kwargs["log_sampling"]: 163 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs) 164 | else: 165 | freq_bands = torch.linspace(2.0 ** 0.0, 2.0 ** max_freq, N_freqs) 166 | 167 | for freq in freq_bands: 168 | for p_fn in self.kwargs["periodic_fns"]: 169 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 170 | out_dim += d 171 | 172 | self.embed_fns = embed_fns 173 | self.out_dim = out_dim 174 | 175 | def embed(self, inputs): 176 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 177 | 178 | 179 | def get_embedder(multires): 180 | embed_kwargs = { 181 | "include_input": True, 182 | "input_dims": 3, 183 | "max_freq_log2": multires - 1, 184 | "num_freqs": multires, 185 | "log_sampling": True, 186 | "periodic_fns": [torch.sin, torch.cos], 187 | } 188 | 189 | embedder_obj = Embedder(**embed_kwargs) 190 | 191 | def embed(x, eo=embedder_obj): 192 | return eo.embed(x) 193 | 194 | return embed, embedder_obj.out_dim 195 | -------------------------------------------------------------------------------- /lib/model/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PointOnBones: 5 | def __init__(self, bone_ids): 6 | self.bone_ids = bone_ids 7 | 8 | def get_points(self, joints, num_per_bone=5): 9 | """Sample points on bones in canonical space. 
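(Added note: points are drawn at num_per_bone evenly spaced positions along each parent-to-child bone segment, jittered with small Gaussian noise and clamped back onto the segment, as implemented in the loop below.)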
10 | 11 | Args: 12 | joints (tensor): joint positions to define the bone positions. shape: [B, J, D] 13 | num_per_bone (int, optional): number of sample points on each bone. Defaults to 5. 14 | 15 | Returns: 16 | samples (tensor): sampled points in canoncial space. shape: [B, ?, 3] 17 | probs (tensor): ground truth occupancy for samples (all 1). shape: [B, ?] 18 | """ 19 | 20 | num_batch, _, _ = joints.shape 21 | 22 | samples = [] 23 | 24 | for bone_id in self.bone_ids: 25 | 26 | if bone_id[0] < 0 or bone_id[1] < 0: 27 | continue 28 | 29 | bone_dir = joints[:, bone_id[1]] - joints[:, bone_id[0]] 30 | 31 | scalars = ( 32 | torch.linspace(0, 1, steps=num_per_bone, device=joints.device) 33 | .unsqueeze(0) 34 | .expand(num_batch, -1) 35 | ) 36 | scalars = ( 37 | scalars + torch.randn((num_batch, num_per_bone), device=joints.device) * 0.1 38 | ).clamp_(0, 1) 39 | 40 | samples.append( 41 | joints[:, bone_id[0]].unsqueeze(1).expand(-1, scalars.shape[-1], -1) 42 | + torch.einsum("bn,bi->bni", scalars, bone_dir) # b: num_batch, n: num_per_bone, i: 3-dim 43 | ) 44 | 45 | samples = torch.cat(samples, dim=1) 46 | 47 | probs = torch.ones((num_batch, samples.shape[1]), device=joints.device) 48 | 49 | return samples, probs 50 | 51 | def get_joints(self, joints): 52 | """Sample joints in canonical space. 53 | 54 | Args: 55 | joints (tensor): joint positions to define the bone positions. shape: [B, J, D] 56 | 57 | Returns: 58 | samples (tensor): sampled points in canoncial space. shape: [B, ?, 3] 59 | weights (tensor): ground truth skinning weights for samples (all 1). shape: [B, ?, J] 60 | """ 61 | num_batch, num_joints, _ = joints.shape 62 | 63 | samples = [] 64 | weights = [] 65 | 66 | for k in range(num_joints): 67 | samples.append(joints[:, k]) 68 | weight = torch.zeros((num_batch, num_joints), device=joints.device) 69 | weight[:, k] = 1 70 | weights.append(weight) 71 | 72 | for bone_id in self.bone_ids: 73 | 74 | if bone_id[0] < 0 or bone_id[1] < 0: 75 | continue 76 | 77 | samples.append(joints[:, bone_id[1]]) 78 | 79 | weight = torch.zeros((num_batch, num_joints), device=joints.device) 80 | weight[:, bone_id[0]] = 1 81 | weights.append(weight) 82 | 83 | samples = torch.stack(samples, dim=1) 84 | weights = torch.stack(weights, dim=1) 85 | 86 | return samples, weights 87 | 88 | 89 | class PointInSpace: 90 | def __init__(self, global_sigma=1.8, local_sigma=0.01): 91 | self.global_sigma = global_sigma 92 | self.local_sigma = local_sigma 93 | 94 | def get_points(self, pc_input): 95 | """Sample one point near each of the given point + 1/8 uniformly. 96 | 97 | Args: 98 | pc_input (tensor): sampling centers. shape: [B, N, D] 99 | 100 | Returns: 101 | samples (tensor): sampled points. 
shape: [B, N + N / 8, D] 102 | """ 103 | 104 | batch_size, sample_size, dim = pc_input.shape 105 | 106 | sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma) 107 | 108 | sample_global = ( 109 | torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device) 110 | * (self.global_sigma * 2) 111 | ) - self.global_sigma 112 | 113 | sample = torch.cat([sample_local, sample_global], dim=1) 114 | 115 | return sample 116 | -------------------------------------------------------------------------------- /lib/model/smpl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import numpy as np 4 | from lib.smpl.body_models import SMPL 5 | 6 | class SMPLServer(torch.nn.Module): 7 | 8 | def __init__(self, gender='neutral', betas=None, v_template=None): 9 | super().__init__() 10 | 11 | 12 | self.smpl = SMPL(model_path=hydra.utils.to_absolute_path('lib/smpl/smpl_model'), 13 | gender=gender, 14 | batch_size=1, 15 | use_hands=False, 16 | use_feet_keypoints=False, 17 | dtype=torch.float32).cuda() 18 | 19 | self.bone_parents = self.smpl.bone_parents.astype(int) 20 | self.bone_parents[0] = -1 21 | self.bone_ids = [] 22 | for i in range(24): self.bone_ids.append([self.bone_parents[i], i]) 23 | 24 | if v_template is not None: 25 | self.v_template = torch.tensor(v_template).float().cuda() 26 | else: 27 | self.v_template = None 28 | 29 | if betas is not None: 30 | self.betas = torch.tensor(betas).float().cuda() 31 | else: 32 | self.betas = None 33 | 34 | # define the canonical pose 35 | param_canonical = torch.zeros((1, 86),dtype=torch.float32).cuda() 36 | param_canonical[0, 0] = 1 37 | param_canonical[0, 9] = np.pi / 6 38 | param_canonical[0, 12] = -np.pi / 6 39 | if self.betas is not None and self.v_template is None: 40 | param_canonical[0,-10:] = self.betas 41 | self.param_canonical = param_canonical 42 | 43 | output = self.forward(param_canonical, absolute=True) 44 | self.verts_c = output['smpl_verts'] 45 | self.joints_c = output['smpl_jnts'] 46 | self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse() 47 | 48 | 49 | def forward(self, smpl_params, absolute=False): 50 | """return SMPL output from params 51 | 52 | Args: 53 | smpl_params : smpl parameters. shape: [B, 86]. [0-scale,1:4-trans, 4:76-thetas,76:86-betas] 54 | absolute (bool): if true return smpl_tfs wrt thetas=0. else wrt thetas=thetas_canonical. 55 | 56 | Returns: 57 | smpl_verts: vertices. shape: [B, 6893. 3] 58 | smpl_tfs: bone transformations. shape: [B, 24, 4, 4] 59 | smpl_jnts: joint positions. 
shape: [B, 25, 3] 60 | """ 61 | 62 | output = {} 63 | 64 | scale, transl, thetas, betas = torch.split(smpl_params, [1, 3, 72, 10], dim=1) 65 | 66 | # ignore betas if v_template is provided 67 | if self.v_template is not None: 68 | betas = torch.zeros_like(betas) 69 | 70 | smpl_output = self.smpl.forward(betas=betas, 71 | transl=torch.zeros_like(transl), 72 | body_pose=thetas[:, 3:], 73 | global_orient=thetas[:, :3], 74 | return_verts=True, 75 | return_full_pose=True, 76 | v_template=self.v_template) 77 | 78 | verts = smpl_output.vertices.clone() 79 | output['smpl_verts'] = verts * scale.unsqueeze(1) + transl.unsqueeze(1) 80 | 81 | joints = smpl_output.joints.clone() 82 | output['smpl_jnts'] = joints * scale.unsqueeze(1) + transl.unsqueeze(1) 83 | 84 | tf_mats = smpl_output.T.clone() 85 | tf_mats[:, :, :3, :] *= scale.unsqueeze(1).unsqueeze(1) 86 | tf_mats[:, :, :3, 3] += transl.unsqueeze(1) 87 | 88 | if not absolute: 89 | tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv) 90 | 91 | output['smpl_tfs'] = tf_mats 92 | 93 | return output -------------------------------------------------------------------------------- /lib/smpl/body_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems and the Max Planck Institute for Biological 14 | # Cybernetics. All rights reserved. 
15 | # 16 | # Contact: ps-license@tuebingen.mpg.de 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import division 21 | 22 | import os 23 | import os.path as osp 24 | 25 | 26 | import pickle 27 | 28 | import numpy as np 29 | 30 | from collections import namedtuple 31 | 32 | import torch 33 | import torch.nn as nn 34 | 35 | from .lbs import ( 36 | lbs, vertices2joints, blend_shapes) 37 | 38 | from .vertex_ids import vertex_ids as VERTEX_IDS 39 | from .utils import Struct, to_np, to_tensor 40 | from .vertex_joint_selector import VertexJointSelector 41 | 42 | 43 | ModelOutput = namedtuple('ModelOutput', 44 | ['vertices','faces', 'joints', 'full_pose', 'betas', 45 | 'global_orient', 46 | 'body_pose', 'expression', 47 | 'left_hand_pose', 'right_hand_pose', 48 | 'jaw_pose', 'T', 'T_weighted', 'weights']) 49 | ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields) 50 | 51 | class SMPL(nn.Module): 52 | 53 | NUM_JOINTS = 23 54 | NUM_BODY_JOINTS = 23 55 | NUM_BETAS = 10 56 | 57 | def __init__(self, model_path, data_struct=None, 58 | create_betas=True, 59 | betas=None, 60 | create_global_orient=True, 61 | global_orient=None, 62 | create_body_pose=True, 63 | body_pose=None, 64 | create_transl=True, 65 | transl=None, 66 | dtype=torch.float32, 67 | batch_size=1, 68 | joint_mapper=None, gender='neutral', 69 | vertex_ids=None, 70 | pose_blend=True, 71 | **kwargs): 72 | ''' SMPL model constructor 73 | 74 | Parameters 75 | ---------- 76 | model_path: str 77 | The path to the folder or to the file where the model 78 | parameters are stored 79 | data_struct: Strct 80 | A struct object. If given, then the parameters of the model are 81 | read from the object. Otherwise, the model tries to read the 82 | parameters from the given `model_path`. (default = None) 83 | create_global_orient: bool, optional 84 | Flag for creating a member variable for the global orientation 85 | of the body. (default = True) 86 | global_orient: torch.tensor, optional, Bx3 87 | The default value for the global orientation variable. 88 | (default = None) 89 | create_body_pose: bool, optional 90 | Flag for creating a member variable for the pose of the body. 91 | (default = True) 92 | body_pose: torch.tensor, optional, Bx(Body Joints * 3) 93 | The default value for the body pose variable. 94 | (default = None) 95 | create_betas: bool, optional 96 | Flag for creating a member variable for the shape space 97 | (default = True). 98 | betas: torch.tensor, optional, Bx10 99 | The default value for the shape member variable. 100 | (default = None) 101 | create_transl: bool, optional 102 | Flag for creating a member variable for the translation 103 | of the body. (default = True) 104 | transl: torch.tensor, optional, Bx3 105 | The default value for the transl variable. 106 | (default = None) 107 | dtype: torch.dtype, optional 108 | The data type for the created variables 109 | batch_size: int, optional 110 | The batch size used for creating the member variables 111 | joint_mapper: object, optional 112 | An object that re-maps the joints. Useful if one wants to 113 | re-order the SMPL joints to some other convention (e.g. 
MSCOCO) 114 | (default = None) 115 | gender: str, optional 116 | Which gender to load 117 | vertex_ids: dict, optional 118 | A dictionary containing the indices of the extra vertices that 119 | will be selected 120 | ''' 121 | 122 | self.gender = gender 123 | self.pose_blend = pose_blend 124 | 125 | if data_struct is None: 126 | if osp.isdir(model_path): 127 | model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl') 128 | smpl_path = os.path.join(model_path, model_fn) 129 | else: 130 | smpl_path = model_path 131 | assert osp.exists(smpl_path), 'Path {} does not exist!'.format( 132 | smpl_path) 133 | 134 | with open(smpl_path, 'rb') as smpl_file: 135 | data_struct = Struct(**pickle.load(smpl_file,encoding='latin1')) 136 | super(SMPL, self).__init__() 137 | self.batch_size = batch_size 138 | 139 | if vertex_ids is None: 140 | # SMPL and SMPL-H share the same topology, so any extra joints can 141 | # be drawn from the same place 142 | vertex_ids = VERTEX_IDS['smplh'] 143 | 144 | self.dtype = dtype 145 | 146 | self.joint_mapper = joint_mapper 147 | 148 | self.vertex_joint_selector = VertexJointSelector( 149 | vertex_ids=vertex_ids, **kwargs) 150 | 151 | self.faces = data_struct.f 152 | self.register_buffer('faces_tensor', 153 | to_tensor(to_np(self.faces, dtype=np.int64), 154 | dtype=torch.long)) 155 | 156 | if create_betas: 157 | if betas is None: 158 | default_betas = torch.zeros([batch_size, self.NUM_BETAS], 159 | dtype=dtype) 160 | else: 161 | if 'torch.Tensor' in str(type(betas)): 162 | default_betas = betas.clone().detach() 163 | else: 164 | default_betas = torch.tensor(betas, 165 | dtype=dtype) 166 | 167 | self.register_parameter('betas', nn.Parameter(default_betas, 168 | requires_grad=True)) 169 | 170 | # The tensor that contains the global rotation of the model 171 | # It is separated from the pose of the joints in case we wish to 172 | # optimize only over one of them 173 | if create_global_orient: 174 | if global_orient is None: 175 | default_global_orient = torch.zeros([batch_size, 3], 176 | dtype=dtype) 177 | else: 178 | if 'torch.Tensor' in str(type(global_orient)): 179 | default_global_orient = global_orient.clone().detach() 180 | else: 181 | default_global_orient = torch.tensor(global_orient, 182 | dtype=dtype) 183 | 184 | global_orient = nn.Parameter(default_global_orient, 185 | requires_grad=True) 186 | self.register_parameter('global_orient', global_orient) 187 | 188 | if create_body_pose: 189 | if body_pose is None: 190 | default_body_pose = torch.zeros( 191 | [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) 192 | else: 193 | if 'torch.Tensor' in str(type(body_pose)): 194 | default_body_pose = body_pose.clone().detach() 195 | else: 196 | default_body_pose = torch.tensor(body_pose, 197 | dtype=dtype) 198 | self.register_parameter( 199 | 'body_pose', 200 | nn.Parameter(default_body_pose, requires_grad=True)) 201 | 202 | if create_transl: 203 | if transl is None: 204 | default_transl = torch.zeros([batch_size, 3], 205 | dtype=dtype, 206 | requires_grad=True) 207 | else: 208 | default_transl = torch.tensor(transl, dtype=dtype) 209 | self.register_parameter( 210 | 'transl', 211 | nn.Parameter(default_transl, requires_grad=True)) 212 | 213 | # The vertices of the template model 214 | self.register_buffer('v_template', 215 | to_tensor(to_np(data_struct.v_template), 216 | dtype=dtype)) 217 | 218 | # The shape components 219 | shapedirs = data_struct.shapedirs[:, :, :self.NUM_BETAS] 220 | # The shape components 221 | self.register_buffer( 222 | 'shapedirs', 223 | 
to_tensor(to_np(shapedirs), dtype=dtype)) 224 | 225 | 226 | j_regressor = to_tensor(to_np( 227 | data_struct.J_regressor), dtype=dtype) 228 | self.register_buffer('J_regressor', j_regressor) 229 | 230 | # if self.gender == 'neutral': 231 | # joint_regressor = to_tensor(to_np( 232 | # data_struct.cocoplus_regressor), dtype=dtype).permute(1,0) 233 | # self.register_buffer('joint_regressor', joint_regressor) 234 | 235 | # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 236 | num_pose_basis = data_struct.posedirs.shape[-1] 237 | # 207 x 20670 238 | posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T 239 | self.register_buffer('posedirs', 240 | to_tensor(to_np(posedirs), dtype=dtype)) 241 | 242 | # indices of parents for each joints 243 | parents = to_tensor(to_np(data_struct.kintree_table[0])).long() 244 | parents[0] = -1 245 | self.register_buffer('parents', parents) 246 | 247 | self.bone_parents = to_np(data_struct.kintree_table[0]) 248 | 249 | self.register_buffer('lbs_weights', 250 | to_tensor(to_np(data_struct.weights), dtype=dtype)) 251 | 252 | def create_mean_pose(self, data_struct): 253 | pass 254 | 255 | @torch.no_grad() 256 | def reset_params(self, **params_dict): 257 | for param_name, param in self.named_parameters(): 258 | if param_name in params_dict: 259 | param[:] = torch.tensor(params_dict[param_name]) 260 | else: 261 | param.fill_(0) 262 | 263 | def get_T_hip(self, betas=None): 264 | v_shaped = self.v_template + blend_shapes(betas, self.shapedirs) 265 | J = vertices2joints(self.J_regressor, v_shaped) 266 | T_hip = J[0,0] 267 | return T_hip 268 | 269 | def get_num_verts(self): 270 | return self.v_template.shape[0] 271 | 272 | def get_num_faces(self): 273 | return self.faces.shape[0] 274 | 275 | def extra_repr(self): 276 | return 'Number of betas: {}'.format(self.NUM_BETAS) 277 | 278 | def forward(self, betas=None, body_pose=None, global_orient=None, 279 | transl=None, return_verts=True, return_full_pose=False,displacement=None,v_template=None, 280 | **kwargs): 281 | ''' Forward pass for the SMPL model 282 | 283 | Parameters 284 | ---------- 285 | global_orient: torch.tensor, optional, shape Bx3 286 | If given, ignore the member variable and use it as the global 287 | rotation of the body. Useful if someone wishes to predicts this 288 | with an external model. (default=None) 289 | betas: torch.tensor, optional, shape Bx10 290 | If given, ignore the member variable `betas` and use it 291 | instead. For example, it can used if shape parameters 292 | `betas` are predicted from some external model. 293 | (default=None) 294 | body_pose: torch.tensor, optional, shape Bx(J*3) 295 | If given, ignore the member variable `body_pose` and use it 296 | instead. For example, it can used if someone predicts the 297 | pose of the body joints are predicted from some external model. 298 | It should be a tensor that contains joint rotations in 299 | axis-angle format. (default=None) 300 | transl: torch.tensor, optional, shape Bx3 301 | If given, ignore the member variable `transl` and use it 302 | instead. For example, it can used if the translation 303 | `transl` is predicted from some external model. 304 | (default=None) 305 | return_verts: bool, optional 306 | Return the vertices. 
(default=True) 307 | return_full_pose: bool, optional 308 | Returns the full axis-angle pose vector (default=False) 309 | 310 | Returns 311 | ------- 312 | ''' 313 | # If no shape and pose parameters are passed along, then use the 314 | # ones from the module 315 | global_orient = (global_orient if global_orient is not None else 316 | self.global_orient) 317 | body_pose = body_pose if body_pose is not None else self.body_pose 318 | betas = betas if betas is not None else self.betas 319 | 320 | apply_trans = transl is not None or hasattr(self, 'transl') 321 | if transl is None and hasattr(self, 'transl'): 322 | transl = self.transl 323 | 324 | full_pose = torch.cat([global_orient, body_pose], dim=1) 325 | 326 | # if betas.shape[0] != self.batch_size: 327 | # num_repeats = int(self.batch_size / betas.shape[0]) 328 | # betas = betas.expand(num_repeats, -1) 329 | 330 | if v_template is None: 331 | v_template = self.v_template 332 | 333 | if displacement is not None: 334 | vertices, joints_smpl, T_weighted, W, T = lbs(betas, full_pose, v_template+displacement, 335 | self.shapedirs, self.posedirs, 336 | self.J_regressor, self.parents, 337 | self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend) 338 | else: 339 | 340 | 341 | vertices, joints_smpl,T_weighted, W, T = lbs(betas, full_pose, v_template, 342 | self.shapedirs, self.posedirs, 343 | self.J_regressor, self.parents, 344 | self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend) 345 | 346 | # if self.gender is not 'neutral': 347 | joints = self.vertex_joint_selector(vertices, joints_smpl) 348 | # else: 349 | # joints = torch.matmul(vertices.permute(0,2,1),self.joint_regressor).permute(0,2,1) 350 | # Map the joints to the current dataset 351 | if self.joint_mapper is not None: 352 | joints = self.joint_mapper(joints) 353 | 354 | if apply_trans: 355 | joints_smpl += transl.unsqueeze(dim=1) 356 | joints += transl.unsqueeze(dim=1) 357 | vertices += transl.unsqueeze(dim=1) 358 | 359 | output = ModelOutput(vertices=vertices if return_verts else None, 360 | faces=self.faces, 361 | global_orient=global_orient, 362 | body_pose=body_pose, 363 | joints=joints_smpl, 364 | betas=self.betas, 365 | full_pose=full_pose if return_full_pose else None, 366 | T=T, T_weighted=T_weighted, weights=W) 367 | 368 | return output -------------------------------------------------------------------------------- /lib/smpl/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems and the Max Planck Institute for Biological 14 | # Cybernetics. All rights reserved. 
15 | # 16 | # Contact: ps-license@tuebingen.mpg.de 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import division 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | from .utils import rot_mat_to_euler 28 | 29 | 30 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, 31 | dynamic_lmk_b_coords, 32 | neck_kin_chain, dtype=torch.float32): 33 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks 34 | 35 | 36 | To do so, we first compute the rotation of the neck around the y-axis 37 | and then use a pre-computed look-up table to find the faces and the 38 | barycentric coordinates that will be used. 39 | 40 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) 41 | for providing the original TensorFlow implementation and for the LUT. 42 | 43 | Parameters 44 | ---------- 45 | vertices: torch.tensor BxVx3, dtype = torch.float32 46 | The tensor of input vertices 47 | pose: torch.tensor Bx(Jx3), dtype = torch.float32 48 | The current pose of the body model 49 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long 50 | The look-up table from neck rotation to faces 51 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 52 | The look-up table from neck rotation to barycentric coordinates 53 | neck_kin_chain: list 54 | A python list that contains the indices of the joints that form the 55 | kinematic chain of the neck. 56 | dtype: torch.dtype, optional 57 | 58 | Returns 59 | ------- 60 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long 61 | A tensor of size BxL that contains the indices of the faces that 62 | will be used to compute the current dynamic landmarks. 63 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 64 | A tensor of size BxL that contains the indices of the faces that 65 | will be used to compute the current dynamic landmarks. 66 | ''' 67 | 68 | batch_size = vertices.shape[0] 69 | 70 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, 71 | neck_kin_chain) 72 | rot_mats = batch_rodrigues( 73 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) 74 | 75 | rel_rot_mat = torch.eye(3, device=vertices.device, 76 | dtype=dtype).unsqueeze_(dim=0) 77 | for idx in range(len(neck_kin_chain)): 78 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) 79 | 80 | y_rot_angle = torch.round( 81 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, 82 | max=39)).to(dtype=torch.long) 83 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) 84 | mask = y_rot_angle.lt(-39).to(dtype=torch.long) 85 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) 86 | y_rot_angle = (neg_mask * neg_vals + 87 | (1 - neg_mask) * y_rot_angle) 88 | 89 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 90 | 0, y_rot_angle) 91 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 92 | 0, y_rot_angle) 93 | 94 | return dyn_lmk_faces_idx, dyn_lmk_b_coords 95 | 96 | 97 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): 98 | ''' Calculates landmarks by barycentric interpolation 99 | 100 | Parameters 101 | ---------- 102 | vertices: torch.tensor BxVx3, dtype = torch.float32 103 | The tensor of input vertices 104 | faces: torch.tensor Fx3, dtype = torch.long 105 | The faces of the mesh 106 | lmk_faces_idx: torch.tensor L, dtype = torch.long 107 | The tensor with the indices of the faces used to calculate the 108 | landmarks. 
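(Added note: one face index per landmark; the three vertices of each selected face are gathered per batch element and blended with lmk_bary_coords by the einsum at the end of this function.)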
109 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 110 | The tensor of barycentric coordinates that are used to interpolate 111 | the landmarks 112 | 113 | Returns 114 | ------- 115 | landmarks: torch.tensor BxLx3, dtype = torch.float32 116 | The coordinates of the landmarks for each mesh in the batch 117 | ''' 118 | # Extract the indices of the vertices for each face 119 | # BxLx3 120 | batch_size, num_verts = vertices.shape[:2] 121 | device = vertices.device 122 | 123 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).expand( 124 | batch_size, -1, -1).long() 125 | 126 | lmk_faces = lmk_faces + torch.arange( 127 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts 128 | 129 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( 130 | batch_size, -1, 3, 3) 131 | 132 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) 133 | return landmarks 134 | 135 | 136 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 137 | lbs_weights, pose2rot=True, dtype=torch.float32, pose_blend=True): 138 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 139 | 140 | Parameters 141 | ---------- 142 | betas : torch.tensor BxNB 143 | The tensor of shape parameters 144 | pose : torch.tensor Bx(J + 1) * 3 145 | The pose parameters in axis-angle format 146 | v_template torch.tensor BxVx3 147 | The template mesh that will be deformed 148 | shapedirs : torch.tensor 1xNB 149 | The tensor of PCA shape displacements 150 | posedirs : torch.tensor Px(V * 3) 151 | The pose PCA coefficients 152 | J_regressor : torch.tensor JxV 153 | The regressor array that is used to calculate the joints from 154 | the position of the vertices 155 | parents: torch.tensor J 156 | The array that describes the kinematic tree for the model 157 | lbs_weights: torch.tensor N x V x (J + 1) 158 | The linear blend skinning weights that represent how much the 159 | rotation matrix of each part affects each vertex 160 | pose2rot: bool, optional 161 | Flag on whether to convert the input pose tensor to rotation 162 | matrices. The default value is True. If False, then the pose tensor 163 | should already contain rotation matrices and have a size of 164 | Bx(J + 1)x9 165 | dtype: torch.dtype, optional 166 | 167 | Returns 168 | ------- 169 | verts: torch.tensor BxVx3 170 | The vertices of the mesh after applying the shape and pose 171 | displacements. 172 | joints: torch.tensor BxJx3 173 | The joints of the model 174 | ''' 175 | 176 | batch_size = max(betas.shape[0], pose.shape[0]) 177 | device = betas.device 178 | 179 | # Add shape contribution 180 | v_shaped = v_template + blend_shapes(betas, shapedirs) 181 | 182 | # Get the joints 183 | # NxJx3 array 184 | J = vertices2joints(J_regressor, v_shaped) 185 | 186 | # 3. 
Add pose blend shapes 187 | # N x J x 3 x 3 188 | ident = torch.eye(3, dtype=dtype, device=device) 189 | 190 | 191 | if pose2rot: 192 | rot_mats = batch_rodrigues( 193 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) 194 | 195 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 196 | # (N x P) x (P, V * 3) -> N x V x 3 197 | pose_offsets = torch.matmul(pose_feature, posedirs) \ 198 | .view(batch_size, -1, 3) 199 | else: 200 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 201 | rot_mats = pose.view(batch_size, -1, 3, 3) 202 | 203 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 204 | posedirs).view(batch_size, -1, 3) 205 | 206 | if pose_blend: 207 | v_posed = pose_offsets + v_shaped 208 | else: 209 | v_posed = v_shaped 210 | 211 | # 4. Get the global joint location 212 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 213 | 214 | # 5. Do skinning: 215 | # W is N x V x (J + 1) 216 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 217 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 218 | num_joints = J_regressor.shape[0] 219 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ 220 | .view(batch_size, -1, 4, 4) 221 | 222 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], 223 | dtype=dtype, device=device) 224 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 225 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 226 | 227 | verts = v_homo[:, :, :3, 0] 228 | 229 | return verts, J_transformed, T, W, A.view(batch_size, num_joints, 4,4) 230 | 231 | 232 | def vertices2joints(J_regressor, vertices): 233 | ''' Calculates the 3D joint locations from the vertices 234 | 235 | Parameters 236 | ---------- 237 | J_regressor : torch.tensor JxV 238 | The regressor array that is used to calculate the joints from the 239 | position of the vertices 240 | vertices : torch.tensor BxVx3 241 | The tensor of mesh vertices 242 | 243 | Returns 244 | ------- 245 | torch.tensor BxJx3 246 | The location of the joints 247 | ''' 248 | 249 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 250 | 251 | 252 | def blend_shapes(betas, shape_disps): 253 | ''' Calculates the per vertex displacement due to the blend shapes 254 | 255 | 256 | Parameters 257 | ---------- 258 | betas : torch.tensor Bx(num_betas) 259 | Blend shape coefficients 260 | shape_disps: torch.tensor Vx3x(num_betas) 261 | Blend shapes 262 | 263 | Returns 264 | ------- 265 | torch.tensor BxVx3 266 | The per-vertex displacement due to shape deformation 267 | ''' 268 | 269 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 270 | # i.e. Multiply each shape displacement by its corresponding beta and 271 | # then sum them. 
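# (Added comment) betas has shape [B, num_betas] and shape_disps has shape [V, 3, num_betas];
# the einsum below contracts over the shared index l and yields a [B, V, 3] per-vertex offset.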
272 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 273 | return blend_shape 274 | 275 | 276 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 277 | ''' Calculates the rotation matrices for a batch of rotation vectors 278 | Parameters 279 | ---------- 280 | rot_vecs: torch.tensor Nx3 281 | array of N axis-angle vectors 282 | Returns 283 | ------- 284 | R: torch.tensor Nx3x3 285 | The rotation matrices for the given axis-angle parameters 286 | ''' 287 | 288 | batch_size = rot_vecs.shape[0] 289 | device = rot_vecs.device 290 | 291 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 292 | rot_dir = rot_vecs / angle 293 | 294 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 295 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 296 | 297 | # Bx1 arrays 298 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 299 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 300 | 301 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 302 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 303 | .view((batch_size, 3, 3)) 304 | 305 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 306 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 307 | return rot_mat 308 | 309 | 310 | def transform_mat(R, t): 311 | ''' Creates a batch of transformation matrices 312 | Args: 313 | - R: Bx3x3 array of a batch of rotation matrices 314 | - t: Bx3x1 array of a batch of translation vectors 315 | Returns: 316 | - T: Bx4x4 Transformation matrix 317 | ''' 318 | # No padding left or right, only add an extra row 319 | return torch.cat([F.pad(R, [0, 0, 0, 1]), 320 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 321 | 322 | 323 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 324 | """ 325 | Applies a batch of rigid transformations to the joints 326 | 327 | Parameters 328 | ---------- 329 | rot_mats : torch.tensor BxNx3x3 330 | Tensor of rotation matrices 331 | joints : torch.tensor BxNx3 332 | Locations of joints 333 | parents : torch.tensor BxN 334 | The kinematic tree of each object 335 | dtype : torch.dtype, optional: 336 | The data type of the created tensors, the default is torch.float32 337 | 338 | Returns 339 | ------- 340 | posed_joints : torch.tensor BxNx3 341 | The locations of the joints after applying the pose rotations 342 | rel_transforms : torch.tensor BxNx4x4 343 | The relative (with respect to the root joint) rigid transformations 344 | for all the joints 345 | """ 346 | 347 | joints = torch.unsqueeze(joints, dim=-1) 348 | 349 | rel_joints = joints.clone() 350 | rel_joints[:, 1:] -= joints[:, parents[1:]] 351 | 352 | transforms_mat = transform_mat( 353 | rot_mats.reshape(-1, 3, 3), 354 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) 355 | 356 | transform_chain = [transforms_mat[:, 0]] 357 | for i in range(1, parents.shape[0]): 358 | # Subtract the joint location at the rest pose 359 | # No need for rotation, since it's identity when at rest 360 | curr_res = torch.matmul(transform_chain[parents[i]], 361 | transforms_mat[:, i]) 362 | transform_chain.append(curr_res) 363 | 364 | transforms = torch.stack(transform_chain, dim=1) 365 | 366 | # The last column of the transformations contains the posed joints 367 | posed_joints = transforms[:, :, :3, 3] 368 | 369 | # The last column of the transformations contains the posed joints 370 | posed_joints = transforms[:, :, :3, 3] 371 | 372 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 373 | 374 | rel_transforms = transforms 
- F.pad( 375 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) 376 | 377 | return posed_joints, rel_transforms 378 | -------------------------------------------------------------------------------- /lib/smpl/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems and the Max Planck Institute for Biological 14 | # Cybernetics. All rights reserved. 15 | # 16 | # Contact: ps-license@tuebingen.mpg.de 17 | 18 | from __future__ import print_function 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def to_tensor(array, dtype=torch.float32): 27 | if 'torch.tensor' not in str(type(array)): 28 | return torch.tensor(array, dtype=dtype) 29 | 30 | 31 | class Struct(object): 32 | def __init__(self, **kwargs): 33 | for key, val in kwargs.items(): 34 | setattr(self, key, val) 35 | 36 | 37 | def to_np(array, dtype=np.float32): 38 | if 'scipy.sparse' in str(type(array)): 39 | array = array.todense() 40 | return np.array(array, dtype=dtype) 41 | 42 | 43 | def rot_mat_to_euler(rot_mats): 44 | # Calculates rotation matrix to euler angles 45 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 46 | 47 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 48 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 49 | return torch.atan2(-rot_mats[:, 2, 0], sy) 50 | -------------------------------------------------------------------------------- /lib/smpl/vertex_ids.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems and the Max Planck Institute for Biological 14 | # Cybernetics. All rights reserved. 15 | # 16 | # Contact: ps-license@tuebingen.mpg.de 17 | 18 | from __future__ import print_function 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | # Joint name to vertex mapping. 
SMPL/SMPL-H/SMPL-X vertices that correspond to 23 | # MSCOCO and OpenPose joints 24 | vertex_ids = { 25 | 'smplh': { 26 | 'nose': 332, 27 | 'reye': 6260, 28 | 'leye': 2800, 29 | 'rear': 4071, 30 | 'lear': 583, 31 | 'rthumb': 6191, 32 | 'rindex': 5782, 33 | 'rmiddle': 5905, 34 | 'rring': 6016, 35 | 'rpinky': 6133, 36 | 'lthumb': 2746, 37 | 'lindex': 2319, 38 | 'lmiddle': 2445, 39 | 'lring': 2556, 40 | 'lpinky': 2673, 41 | 'LBigToe': 3216, 42 | 'LSmallToe': 3226, 43 | 'LHeel': 3387, 44 | 'RBigToe': 6617, 45 | 'RSmallToe': 6624, 46 | 'RHeel': 6787 47 | }, 48 | 'smplx': { 49 | 'nose': 9120, 50 | 'reye': 9929, 51 | 'leye': 9448, 52 | 'rear': 616, 53 | 'lear': 6, 54 | 'rthumb': 8079, 55 | 'rindex': 7669, 56 | 'rmiddle': 7794, 57 | 'rring': 7905, 58 | 'rpinky': 8022, 59 | 'lthumb': 5361, 60 | 'lindex': 4933, 61 | 'lmiddle': 5058, 62 | 'lring': 5169, 63 | 'lpinky': 5286, 64 | 'LBigToe': 5770, 65 | 'LSmallToe': 5780, 66 | 'LHeel': 8846, 67 | 'RBigToe': 8463, 68 | 'RSmallToe': 8474, 69 | 'RHeel': 8635 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /lib/smpl/vertex_joint_selector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems and the Max Planck Institute for Biological 14 | # Cybernetics. All rights reserved. 
15 | # 16 | # Contact: ps-license@tuebingen.mpg.de 17 | 18 | from __future__ import absolute_import 19 | from __future__ import print_function 20 | from __future__ import division 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | from .utils import to_tensor 28 | 29 | 30 | class VertexJointSelector(nn.Module): 31 | 32 | def __init__(self, vertex_ids=None, 33 | use_hands=True, 34 | use_feet_keypoints=True, **kwargs): 35 | super(VertexJointSelector, self).__init__() 36 | 37 | extra_joints_idxs = [] 38 | 39 | face_keyp_idxs = np.array([ 40 | vertex_ids['nose'], 41 | vertex_ids['reye'], 42 | vertex_ids['leye'], 43 | vertex_ids['rear'], 44 | vertex_ids['lear']], dtype=np.int64) 45 | 46 | extra_joints_idxs = np.concatenate([extra_joints_idxs, 47 | face_keyp_idxs]) 48 | 49 | if use_feet_keypoints: 50 | feet_keyp_idxs = np.array([vertex_ids['LBigToe'], 51 | vertex_ids['LSmallToe'], 52 | vertex_ids['LHeel'], 53 | vertex_ids['RBigToe'], 54 | vertex_ids['RSmallToe'], 55 | vertex_ids['RHeel']], dtype=np.int32) 56 | 57 | extra_joints_idxs = np.concatenate( 58 | [extra_joints_idxs, feet_keyp_idxs]) 59 | 60 | if use_hands: 61 | self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] 62 | 63 | tips_idxs = [] 64 | for hand_id in ['l', 'r']: 65 | for tip_name in self.tip_names: 66 | tips_idxs.append(vertex_ids[hand_id + tip_name]) 67 | 68 | extra_joints_idxs = np.concatenate( 69 | [extra_joints_idxs, tips_idxs]) 70 | 71 | self.register_buffer('extra_joints_idxs', 72 | to_tensor(extra_joints_idxs, dtype=torch.long)) 73 | 74 | def forward(self, vertices, joints): 75 | extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) 76 | joints = torch.cat([joints, extra_joints], dim=1) 77 | return joints 78 | -------------------------------------------------------------------------------- /lib/snarf_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | import torch 4 | import wandb 5 | import imageio 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | 9 | from lib.model.smpl import SMPLServer 10 | from lib.model.sample import PointOnBones 11 | from lib.model.network import ImplicitNetwork 12 | from lib.model.metrics import calculate_iou 13 | from lib.utils.meshing import generate_mesh 14 | from lib.model.helpers import masked_softmax 15 | from lib.model.deformer import ForwardDeformer, skinning 16 | from lib.utils.render import render_trimesh, render_joint, weights2colors 17 | 18 | class SNARFModel(pl.LightningModule): 19 | 20 | def __init__(self, opt, meta_info, data_processor=None): 21 | super().__init__() 22 | 23 | self.opt = opt 24 | 25 | self.network = ImplicitNetwork(**opt.network) 26 | self.deformer = ForwardDeformer(opt.deformer) 27 | 28 | print(self.network) 29 | print(self.deformer) 30 | 31 | gender = str(meta_info['gender']) 32 | betas = meta_info['betas'] if 'betas' in meta_info else None 33 | v_template = meta_info['v_template'] if 'v_template' in meta_info else None 34 | 35 | self.smpl_server = SMPLServer(gender=gender, betas=betas, v_template=v_template) 36 | self.smpl_faces = torch.tensor(self.smpl_server.smpl.faces.astype('int')).unsqueeze(0).cuda() 37 | self.sampler_bone = PointOnBones(self.smpl_server.bone_ids) 38 | 39 | self.data_processor = data_processor 40 | 41 | def configure_optimizers(self): 42 | optimizer = torch.optim.Adam(self.parameters(), lr=self.opt.optim.lr) 43 | return optimizer 44 | 45 | def forward(self, pts_d, smpl_tfs, smpl_thetas, 
eval_mode=True): 46 | 47 | # rectify rest pose 48 | smpl_tfs = torch.einsum('bnij,njk->bnik', smpl_tfs, self.smpl_server.tfs_c_inv) 49 | 50 | cond = {'smpl': smpl_thetas[:,3:]/np.pi} 51 | 52 | batch_points = 60000 53 | 54 | accum_pred = [] 55 | # split to prevent out of memory 56 | for pts_d_split in torch.split(pts_d, batch_points, dim=1): 57 | 58 | # compute canonical correspondences 59 | pts_c, intermediates = self.deformer(pts_d_split, cond, smpl_tfs, eval_mode=eval_mode) 60 | 61 | # query occuancy in canonical space 62 | num_batch, num_point, num_init, num_dim = pts_c.shape 63 | pts_c = pts_c.reshape(num_batch, num_point * num_init, num_dim) 64 | occ_pd = self.network(pts_c, cond).reshape(num_batch, num_point, num_init) 65 | 66 | # aggregate occupancy probablities 67 | mask = intermediates['valid_ids'] 68 | if eval_mode: 69 | occ_pd = masked_softmax(occ_pd, mask, dim=-1, mode='max') 70 | else: 71 | occ_pd = masked_softmax(occ_pd, mask, dim=-1, mode='softmax', soft_blend=self.opt.soft_blend) 72 | 73 | accum_pred.append(occ_pd) 74 | 75 | accum_pred = torch.cat(accum_pred, 1) 76 | 77 | return accum_pred 78 | 79 | def training_step(self, data, data_idx): 80 | 81 | # Data prep 82 | if self.data_processor is not None: 83 | data = self.data_processor.process(data) 84 | 85 | # BCE loss 86 | occ_pd = self.forward(data['pts_d'], data['smpl_tfs'], data['smpl_thetas'], eval_mode=False) 87 | loss_bce = torch.nn.functional.binary_cross_entropy_with_logits(occ_pd, data['occ_gt']) 88 | self.log('train_bce', loss_bce) 89 | loss = loss_bce 90 | 91 | # Bootstrapping 92 | num_batch = data['pts_d'].shape[0] 93 | cond = {'smpl': data['smpl_thetas'][:,3:]/np.pi} 94 | 95 | # Bone occupancy loss 96 | if self.current_epoch < self.opt.nepochs_pretrain: 97 | if self.opt.lambda_bone_occ > 0: 98 | 99 | pts_c, occ_gt = self.sampler_bone.get_points(self.smpl_server.joints_c.expand(num_batch, -1, -1)) 100 | occ_pd = self.network(pts_c, cond) 101 | loss_bone_occ = torch.nn.functional.binary_cross_entropy_with_logits(occ_pd, occ_gt.unsqueeze(-1)) 102 | 103 | loss = loss + self.opt.lambda_bone_occ * loss_bone_occ 104 | self.log('train_bone_occ', loss_bone_occ) 105 | 106 | # Joint weight loss 107 | if self.opt.lambda_bone_w > 0: 108 | 109 | pts_c, w_gt = self.sampler_bone.get_joints(self.smpl_server.joints_c.expand(num_batch, -1, -1)) 110 | w_pd = self.deformer.query_weights(pts_c, cond) 111 | loss_bone_w = torch.nn.functional.mse_loss(w_pd, w_gt) 112 | 113 | loss = loss + self.opt.lambda_bone_w * loss_bone_w 114 | self.log('train_bone_w', loss_bone_w) 115 | 116 | return loss 117 | 118 | def validation_step(self, data, data_idx): 119 | 120 | if self.data_processor is not None: 121 | data = self.data_processor.process(data) 122 | 123 | with torch.no_grad(): 124 | if data_idx == 0: 125 | img_all = self.plot(data)['img_all'] 126 | self.logger.experiment.log({"vis":[wandb.Image(img_all)]}) 127 | 128 | occ_pd = self.forward(data['pts_d'], data['smpl_tfs'], data['smpl_thetas'], eval_mode=True) 129 | 130 | _, num_point, _ = data['occ_gt'].shape 131 | bbox_iou = calculate_iou(data['occ_gt'][:,:num_point//2]>0.5, occ_pd[:,:num_point//2]>0) 132 | surf_iou = calculate_iou(data['occ_gt'][:,num_point//2:]>0.5, occ_pd[:,num_point//2:]>0) 133 | 134 | return {'bbox_iou':bbox_iou, 'surf_iou':surf_iou} 135 | 136 | def validation_epoch_end(self, validation_step_outputs): 137 | 138 | bbox_ious, surf_ious = [], [] 139 | for output in validation_step_outputs: 140 | bbox_ious.append(output['bbox_iou']) 141 | 
surf_ious.append(output['surf_iou']) 142 | 143 | self.log('valid_bbox_iou', torch.stack(bbox_ious).mean()) 144 | self.log('valid_surf_iou', torch.stack(surf_ious).mean()) 145 | 146 | def test_step(self, data, data_idx): 147 | 148 | with torch.no_grad(): 149 | 150 | occ_pd = self.forward(data['pts_d'], data['smpl_tfs'], data['smpl_thetas'], eval_mode=True) 151 | 152 | _, num_point, _ = data['occ_gt'].shape 153 | bbox_iou = calculate_iou(data['occ_gt'][:,:num_point//2]>0.5, occ_pd[:,:num_point//2]>0) 154 | surf_iou = calculate_iou(data['occ_gt'][:,num_point//2:]>0.5, occ_pd[:,num_point//2:]>0) 155 | 156 | return {'bbox_iou':bbox_iou, 'surf_iou':surf_iou} 157 | 158 | def test_epoch_end(self, test_step_outputs): 159 | return self.validation_epoch_end(test_step_outputs) 160 | 161 | def plot(self, data, res=128, verbose=True, fast_mode=False): 162 | 163 | res_up = np.log2(res//32) 164 | 165 | if verbose: 166 | surf_pred_cano = self.extract_mesh(self.smpl_server.verts_c, data['smpl_tfs'][[0]], data['smpl_thetas'][[0]], res_up=res_up, canonical=True, with_weights=True) 167 | surf_pred_def = self.extract_mesh(data['smpl_verts'][[0]], data['smpl_tfs'][[0]], data['smpl_thetas'][[0]], res_up=res_up, canonical=False, with_weights=False) 168 | 169 | img_pred_cano = render_trimesh(surf_pred_cano) 170 | img_pred_def = render_trimesh(surf_pred_def) 171 | 172 | img_joint = render_joint(data['smpl_jnts'].data.cpu().numpy()[0],self.smpl_server.bone_ids) 173 | img_pred_def[1024:,:,:3] = 255 174 | img_pred_def[1024:-512,:, :3] = img_joint 175 | img_pred_def[1024:-512,:, -1] = 255 176 | 177 | results = { 178 | 'img_all': np.concatenate([img_pred_cano, img_pred_def], axis=1), 179 | 'mesh_cano': surf_pred_cano, 180 | 'mesh_def' : surf_pred_def 181 | } 182 | else: 183 | smpl_verts = self.smpl_server.verts_c if fast_mode else data['smpl_verts'][[0]] 184 | 185 | surf_pred_def = self.extract_mesh(smpl_verts, data['smpl_tfs'][[0]], data['smpl_thetas'][[0]], res_up=res_up, canonical=False, with_weights=False, fast_mode=fast_mode) 186 | 187 | img_pred_def = render_trimesh(surf_pred_def, mode='p') 188 | results = { 189 | 'img_all': img_pred_def, 190 | 'mesh_def' : surf_pred_def 191 | } 192 | 193 | 194 | return results 195 | 196 | def extract_mesh(self, smpl_verts, smpl_tfs, smpl_thetas, canonical=False, with_weights=False, res_up=2, fast_mode=False): 197 | ''' 198 | In fast mode, we extract canonical mesh and then forward skin it to posed space. 199 | This is faster as it bypasses root finding. 200 | However, it's not deforming the continuous field, but the discrete mesh. 
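A minimal illustrative call, mirroring how plot() above invokes this method (added for clarity, not part of the original docstring):

            mesh = self.extract_mesh(smpl_verts, data['smpl_tfs'][[0]], data['smpl_thetas'][[0]],
                                      res_up=res_up, canonical=False, with_weights=False,
                                      fast_mode=True)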
201 | ''' 202 | if canonical or fast_mode: 203 | occ_func = lambda x: self.network(x, {'smpl': smpl_thetas[:,3:]/np.pi}).reshape(-1, 1) 204 | else: 205 | occ_func = lambda x: self.forward(x, smpl_tfs, smpl_thetas, eval_mode=True).reshape(-1, 1) 206 | 207 | mesh = generate_mesh(occ_func, smpl_verts.squeeze(0),res_up=res_up) 208 | 209 | 210 | if fast_mode: 211 | verts = torch.tensor(mesh.vertices).type_as(smpl_verts) 212 | weights = self.deformer.query_weights(verts[None], None).clamp(0,1)[0] 213 | 214 | smpl_tfs = torch.einsum('bnij,njk->bnik', smpl_tfs, self.smpl_server.tfs_c_inv) 215 | 216 | verts_mesh_deformed = skinning(verts.unsqueeze(0), weights.unsqueeze(0), smpl_tfs).data.cpu().numpy()[0] 217 | mesh.vertices = verts_mesh_deformed 218 | 219 | if with_weights: 220 | verts = torch.tensor(mesh.vertices).cuda().float() 221 | weights = self.deformer.query_weights(verts[None], None).clamp(0,1)[0] 222 | mesh.visual.vertex_colors = weights2colors(weights.data.cpu().numpy()) 223 | 224 | return mesh 225 | -------------------------------------------------------------------------------- /lib/utils/meshing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage import measure 4 | from lib.libmise import mise 5 | import trimesh 6 | 7 | ''' Code adapted from NASA https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/projects/nasa/lib/utils.py''' 8 | def generate_mesh(func, verts, level_set=0, res_init=32, res_up=3): 9 | 10 | scale = 1.1 # Scale of the padded bbox regarding the tight one. 11 | 12 | verts = verts.data.cpu().numpy() 13 | gt_bbox = np.stack([verts.min(axis=0), verts.max(axis=0)], axis=0) 14 | gt_center = (gt_bbox[0] + gt_bbox[1]) * 0.5 15 | gt_scale = (gt_bbox[1] - gt_bbox[0]).max() 16 | 17 | mesh_extractor = mise.MISE(res_init, res_up, level_set) 18 | points = mesh_extractor.query() 19 | 20 | # query occupancy grid 21 | with torch.no_grad(): 22 | while points.shape[0] != 0: 23 | 24 | orig_points = points 25 | points = points.astype(np.float32) 26 | points = (points / mesh_extractor.resolution - 0.5) * scale 27 | points = points * gt_scale + gt_center 28 | points = torch.tensor(points).float().cuda() 29 | 30 | values = func(points.unsqueeze(0))[:,0] 31 | values = values.data.cpu().numpy().astype(np.float64) 32 | 33 | mesh_extractor.update(orig_points, values) 34 | 35 | points = mesh_extractor.query() 36 | 37 | value_grid = mesh_extractor.to_dense() 38 | # value_grid = np.pad(value_grid, 1, "constant", constant_values=-1e6) 39 | 40 | # marching cube 41 | verts, faces, normals, values = measure.marching_cubes_lewiner( 42 | volume=value_grid, 43 | gradient_direction='ascent', 44 | level=min(level_set, value_grid.max())) 45 | 46 | verts = (verts / mesh_extractor.resolution - 0.5) * scale 47 | verts = verts * gt_scale + gt_center 48 | 49 | meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values) 50 | 51 | # remove disconnect part 52 | connected_comp = meshexport.split(only_watertight=False) 53 | max_area = 0 54 | max_comp = None 55 | for comp in connected_comp: 56 | if comp.area > max_area: 57 | max_area = comp.area 58 | max_comp = comp 59 | meshexport = max_comp 60 | 61 | return meshexport -------------------------------------------------------------------------------- /lib/utils/render.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import cv2 5 | 6 | from pytorch3d.renderer import ( 7 | 
FoVOrthographicCameras, 8 | RasterizationSettings, 9 | MeshRenderer, 10 | MeshRasterizer, 11 | HardPhongShader, 12 | PointLights 13 | ) 14 | from pytorch3d.structures import Meshes 15 | from pytorch3d.renderer.mesh import Textures 16 | 17 | 18 | class Renderer(): 19 | def __init__(self, image_size=512): 20 | super().__init__() 21 | 22 | self.image_size = image_size 23 | 24 | self.device = torch.device("cuda:0") 25 | torch.cuda.set_device(self.device) 26 | 27 | R = torch.from_numpy(np.array([[-1., 0., 0.], 28 | [0., 1., 0.], 29 | [0., 0., -1.]])).cuda().float().unsqueeze(0) 30 | 31 | 32 | t = torch.from_numpy(np.array([[0., 0.3, 5.]])).cuda().float() 33 | 34 | self.cameras = FoVOrthographicCameras(R=R, T=t,device=self.device) 35 | 36 | self.lights = PointLights(device=self.device,location=[[0.0, 0.0, 3.0]], 37 | ambient_color=((1,1,1),),diffuse_color=((0,0,0),),specular_color=((0,0,0),)) 38 | 39 | self.raster_settings = RasterizationSettings(image_size=image_size,faces_per_pixel=100,blur_radius=0) 40 | self.rasterizer = MeshRasterizer(cameras=self.cameras, raster_settings=self.raster_settings) 41 | 42 | self.shader = HardPhongShader(device=self.device, cameras=self.cameras, lights=self.lights) 43 | 44 | self.renderer = MeshRenderer(rasterizer=self.rasterizer, shader=self.shader) 45 | 46 | def render_mesh(self, verts, faces, colors=None, mode='npat'): 47 | ''' 48 | mode: normal, phong, texture 49 | ''' 50 | with torch.no_grad(): 51 | 52 | mesh = Meshes(verts, faces) 53 | 54 | normals = torch.stack(mesh.verts_normals_list()) 55 | front_light = torch.tensor([0,0,1]).float().to(verts.device) 56 | shades = (normals * front_light.view(1,1,3)).sum(-1).clamp(min=0).unsqueeze(-1).expand(-1,-1,3) 57 | results = [] 58 | 59 | # normal 60 | if 'n' in mode: 61 | normals_vis = normals* 0.5 + 0.5 62 | mesh_normal = Meshes(verts, faces, textures=Textures(verts_rgb=normals_vis)) 63 | image_normal = self.renderer(mesh_normal) 64 | results.append(image_normal) 65 | 66 | # shading 67 | if 'p' in mode: 68 | mesh_shading = Meshes(verts, faces, textures=Textures(verts_rgb=shades)) 69 | image_phong = self.renderer(mesh_shading) 70 | results.append(image_phong) 71 | 72 | # albedo 73 | if 'a' in mode: 74 | assert(colors is not None) 75 | mesh_albido = Meshes(verts, faces, textures=Textures(verts_rgb=colors)) 76 | image_color = self.renderer(mesh_albido) 77 | results.append(image_color) 78 | 79 | # albedo*shading 80 | if 't' in mode: 81 | assert(colors is not None) 82 | mesh_teture = Meshes(verts, faces, textures=Textures(verts_rgb=colors*shades)) 83 | image_color = self.renderer(mesh_teture) 84 | results.append(image_color) 85 | 86 | return torch.cat(results, axis=1) 87 | 88 | image_size = 512 89 | torch.cuda.set_device(torch.device("cuda:0")) 90 | renderer = Renderer(image_size) 91 | 92 | def render(verts, faces, colors=None): 93 | return renderer.render_mesh(verts, faces, colors) 94 | 95 | def render_trimesh(mesh, mode='npta'): 96 | verts = torch.tensor(mesh.vertices).cuda().float()[None] 97 | faces = torch.tensor(mesh.faces).cuda()[None] 98 | colors = torch.tensor(mesh.visual.vertex_colors).float().cuda()[None,...,:3]/255 99 | image = renderer.render_mesh(verts, faces, colors=colors, mode=mode)[0] 100 | image = (255*image).data.cpu().numpy().astype(np.uint8) 101 | return image 102 | 103 | 104 | def render_joint(smpl_jnts, bone_ids): 105 | marker_sz = 6 106 | line_wd = 2 107 | 108 | image = np.ones((image_size, image_size,3), dtype=np.uint8)*255 109 | smpl_jnts[:,1] += 0.3 110 | smpl_jnts[:,1] = -smpl_jnts[:,1] 111 
| smpl_jnts = smpl_jnts[:,:2]*image_size/2 + image_size/2 112 | 113 | for b in bone_ids: 114 | if b[0]<0 : continue 115 | joint = smpl_jnts[b[0]] 116 | cv2.circle(image, joint.astype('int32'), color=(0,0,0), radius=marker_sz, thickness=-1) 117 | 118 | joint2 = smpl_jnts[b[1]] 119 | cv2.circle(image, joint2.astype('int32'), color=(0,0,0), radius=marker_sz, thickness=-1) 120 | 121 | cv2.line(image, joint2.astype('int32'), joint.astype('int32'), color=(0,0,0), thickness=int(line_wd)) 122 | 123 | return image 124 | 125 | 126 | 127 | def weights2colors(weights): 128 | import matplotlib.pyplot as plt 129 | 130 | cmap = plt.get_cmap('Paired') 131 | 132 | colors = [ 'pink', #0 133 | 'blue', #1 134 | 'green', #2 135 | 'red', #3 136 | 'pink', #4 137 | 'pink', #5 138 | 'pink', #6 139 | 'green', #7 140 | 'blue', #8 141 | 'red', #9 142 | 'pink', #10 143 | 'pink', #11 144 | 'pink', #12 145 | 'blue', #13 146 | 'green', #14 147 | 'red', #15 148 | 'cyan', #16 149 | 'darkgreen', #17 150 | 'pink', #18 151 | 'pink', #19 152 | 'blue', #20 153 | 'green', #21 154 | 'pink', #22 155 | 'pink' #23 156 | ] 157 | 158 | 159 | color_mapping = {'cyan': cmap.colors[3], 160 | 'blue': cmap.colors[1], 161 | 'darkgreen': cmap.colors[1], 162 | 'green':cmap.colors[3], 163 | 'pink': [1,1,1], 164 | 'red':cmap.colors[5], 165 | } 166 | 167 | for i in range(len(colors)): 168 | colors[i] = np.array(color_mapping[colors[i]]) 169 | 170 | colors = np.stack(colors)[None]# [1x24x3] 171 | verts_colors = weights[:,:,None] * colors 172 | verts_colors = verts_colors.sum(1) 173 | return verts_colors -------------------------------------------------------------------------------- /preprocess/body_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
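A minimal usage sketch for the render helpers defined in lib/utils/render.py above (not part of the original code): it assumes a CUDA device with PyTorch3D installed, and uses placeholder inputs — `mesh` is any `trimesh.Trimesh` and `weights` a `(V, 24)` numpy array of skinning weights (e.g. from `deformer.query_weights`). Note that importing the module instantiates the global `Renderer` on `cuda:0`.

```python
# Hedged usage sketch -- `mesh` and `weights` are placeholder inputs, not repo symbols.
import cv2
from lib.utils.render import render_trimesh, weights2colors  # import builds the global Renderer on cuda:0

mesh.visual.vertex_colors = weights2colors(weights)   # blend one color per SMPL joint into per-vertex RGB in [0, 1]
image = render_trimesh(mesh, mode='npta')              # stacked normal / shaded / textured / albedo renderings (uint8, RGBA)
cv2.imwrite('mesh_vis.png', image[..., [2, 1, 0]])     # drop alpha and swap RGB -> BGR for OpenCV
```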
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # 19 | # Code Developed by: 20 | # Nima Ghorbani 21 | # 22 | # 2018.12.13 23 | 24 | import numpy as np 25 | 26 | import torch 27 | import torch.nn as nn 28 | 29 | # from smplx.lbs import lbs 30 | from lbs import lbs 31 | import pickle 32 | 33 | class Struct(object): 34 | def __init__(self, **kwargs): 35 | for key, val in kwargs.items(): 36 | setattr(self, key, val) 37 | 38 | def to_tensor(array, dtype=torch.float32): 39 | if 'torch.tensor' not in str(type(array)): 40 | return torch.tensor(array, dtype=dtype) 41 | 42 | def to_np(array, dtype=np.float32): 43 | if 'scipy.sparse' in str(type(array)): 44 | array = array.todense() 45 | return np.array(array, dtype=dtype) 46 | 47 | class BodyModel(nn.Module): 48 | 49 | def __init__(self, 50 | bm_path, 51 | params=None, 52 | num_betas=10, 53 | batch_size=1, v_template = None, 54 | num_dmpls=None, path_dmpl=None, 55 | num_expressions=10, 56 | use_posedirs=True, 57 | dtype=torch.float32): 58 | 59 | super(BodyModel, self).__init__() 60 | 61 | ''' 62 | :param bm_path: path to a SMPL model as pkl file 63 | :param num_betas: number of shape parameters to include. 64 | if betas are provided in params, num_betas would be overloaded with number of thoes betas 65 | :param batch_size: number of smpl vertices to get 66 | :param device: default on gpu 67 | :param dtype: float precision of the compuations 68 | :return: verts, trans, pose, betas 69 | ''' 70 | # Todo: if params the batchsize should be read from one of the params 71 | 72 | self.dtype = dtype 73 | 74 | if params is None: params = {} 75 | # -- Load SMPL params -- 76 | if '.npz' in bm_path: 77 | smpl_dict = np.load(bm_path, encoding='latin1') 78 | elif '.pkl' in bm_path: 79 | with open(bm_path, 'rb') as smpl_file: 80 | smpl_dict = Struct(**pickle.load(smpl_file,encoding='latin1')) 81 | else: 82 | raise ValueError('bm_path should be either a .pkl nor .npz file') 83 | 84 | njoints = smpl_dict.posedirs.shape[2] // 3 85 | self.model_type = {69: 'smpl', 153: 'smplh', 162: 'smplx', 45: 'mano'}[njoints] 86 | 87 | assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano', 'mano'], ValueError( 88 | 'model_type should be in smpl/smplh/smplx/mano.') 89 | 90 | self.use_dmpl = False 91 | if num_dmpls is not None: 92 | if path_dmpl is not None: 93 | self.use_dmpl = True 94 | else: 95 | raise (ValueError('path_dmpl should be provided when using dmpls!')) 96 | 97 | if self.use_dmpl and self.model_type in ['smplx', 'mano']: raise ( 98 | NotImplementedError('DMPLs only work with SMPL/SMPLH models for now.')) 99 | 100 | # Mean template vertices 101 | if v_template is None: 102 | v_template = np.repeat(smpl_dict.v_template[np.newaxis], batch_size, axis=0) 103 | else: 104 | v_template = np.repeat(v_template[np.newaxis], batch_size, axis=0) 105 | 106 | self.register_buffer('v_template', torch.tensor(v_template, dtype=dtype)) 107 | 108 | self.register_buffer('f', torch.tensor(smpl_dict.f.astype(np.int32), dtype=torch.int32)) 109 | 110 | if len(params): 111 | if 'betas' in params.keys(): 112 | num_betas = params['betas'].shape[1] 113 | if 'dmpls' in params.keys(): 114 | num_dmpls = params['dmpls'].shape[1] 115 | 116 | num_total_betas = smpl_dict.shapedirs.shape[-1] 117 | if num_betas < 1: 118 | num_betas = num_total_betas 119 | 120 | shapedirs = smpl_dict.shapedirs[:, :, 
:num_betas] 121 | 122 | self.register_buffer('shapedirs', torch.tensor(to_np(shapedirs), dtype=dtype)) 123 | 124 | if self.model_type == 'smplx': 125 | begin_shape_id = 300 if smpl_dict.shapedirs.shape[-1] > 300 else 10 126 | exprdirs = smpl_dict.shapedirs[:, :, begin_shape_id:(begin_shape_id + num_expressions)] 127 | self.register_buffer('exprdirs', torch.tensor(exprdirs, dtype=dtype)) 128 | 129 | expression = torch.tensor(np.zeros((batch_size, num_expressions)), dtype=dtype, requires_grad=True) 130 | self.register_parameter('expression', nn.Parameter(expression, requires_grad=True)) 131 | 132 | if self.use_dmpl: 133 | dmpldirs = np.load(path_dmpl)['eigvec'] 134 | 135 | dmpldirs = dmpldirs[:, :, :num_dmpls] 136 | self.register_buffer('dmpldirs', torch.tensor(dmpldirs, dtype=dtype)) 137 | 138 | # Regressor for joint locations given shape - 6890 x 24 139 | self.register_buffer('J_regressor', to_tensor(to_np( 140 | smpl_dict.J_regressor), dtype=dtype)) 141 | 142 | # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*30 x 207 143 | if use_posedirs: 144 | posedirs = smpl_dict.posedirs 145 | posedirs = posedirs.reshape([posedirs.shape[0] * 3, -1]).T 146 | self.register_buffer('posedirs', torch.tensor(posedirs, dtype=dtype)) 147 | else: 148 | self.posedirs = None 149 | 150 | # indices of parents for each joints 151 | kintree_table = smpl_dict.kintree_table.astype(np.int32) 152 | self.register_buffer('kintree_table', torch.tensor(kintree_table, dtype=torch.int32)) 153 | 154 | # LBS weights 155 | # weights = np.repeat(smpl_dict.weights[np.newaxis], batch_size, axis=0) 156 | weights = smpl_dict.weights 157 | self.register_buffer('weights', torch.tensor(weights, dtype=dtype)) 158 | 159 | if 'trans' in params.keys(): 160 | trans = params['trans'] 161 | else: 162 | trans = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True) 163 | self.register_parameter('trans', nn.Parameter(trans, requires_grad=True)) 164 | 165 | # root_orient 166 | # if self.model_type in ['smpl', 'smplh']: 167 | root_orient = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True) 168 | self.register_parameter('root_orient', nn.Parameter(root_orient, requires_grad=True)) 169 | 170 | # pose_body 171 | if self.model_type in ['smpl', 'smplh', 'smplx']: 172 | pose_body = torch.tensor(np.zeros((batch_size, 63)), dtype=dtype, requires_grad=True) 173 | self.register_parameter('pose_body', nn.Parameter(pose_body, requires_grad=True)) 174 | 175 | # pose_hand 176 | if 'pose_hand' in params.keys(): 177 | pose_hand = params['pose_hand'] 178 | else: 179 | if self.model_type in ['smpl']: 180 | pose_hand = torch.tensor(np.zeros((batch_size, 1 * 3 * 2)), dtype=dtype, requires_grad=True) 181 | elif self.model_type in ['smplh', 'smplx']: 182 | pose_hand = torch.tensor(np.zeros((batch_size, 15 * 3 * 2)), dtype=dtype, requires_grad=True) 183 | elif self.model_type in ['mano']: 184 | pose_hand = torch.tensor(np.zeros((batch_size, 15 * 3)), dtype=dtype, requires_grad=True) 185 | self.register_parameter('pose_hand', nn.Parameter(pose_hand, requires_grad=True)) 186 | 187 | # face poses 188 | if self.model_type == 'smplx': 189 | pose_jaw = torch.tensor(np.zeros((batch_size, 1 * 3)), dtype=dtype, requires_grad=True) 190 | self.register_parameter('pose_jaw', nn.Parameter(pose_jaw, requires_grad=True)) 191 | pose_eye = torch.tensor(np.zeros((batch_size, 2 * 3)), dtype=dtype, requires_grad=True) 192 | self.register_parameter('pose_eye', nn.Parameter(pose_eye, requires_grad=True)) 193 | 194 | if 'betas' in 
params.keys(): 195 | betas = params['betas'] 196 | else: 197 | betas = torch.tensor(np.zeros((batch_size, num_betas)), dtype=dtype, requires_grad=True) 198 | self.register_parameter('betas', nn.Parameter(betas, requires_grad=True)) 199 | 200 | if self.use_dmpl: 201 | if 'dmpls' in params.keys(): 202 | dmpls = params['dmpls'] 203 | else: 204 | dmpls = torch.tensor(np.zeros((batch_size, num_dmpls)), dtype=dtype, requires_grad=True) 205 | self.register_parameter('dmpls', nn.Parameter(dmpls, requires_grad=True)) 206 | self.batch_size = batch_size 207 | 208 | def r(self): 209 | from human_body_prior.tools.omni_tools import copy2cpu as c2c 210 | return c2c(self.forward().v) 211 | 212 | def forward(self, root_orient=None, pose_body=None, pose_hand=None, pose_jaw=None, pose_eye=None, betas=None, 213 | trans=None, dmpls=None, expression=None, return_dict=False, v_template =None, **kwargs): 214 | ''' 215 | 216 | :param root_orient: Nx3 217 | :param pose_body: 218 | :param pose_hand: 219 | :param pose_jaw: 220 | :param pose_eye: 221 | :param kwargs: 222 | :return: 223 | ''' 224 | assert not (v_template is not None and betas is not None), ValueError('vtemplate and betas could not be used jointly.') 225 | assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano', 'mano'], ValueError( 226 | 'model_type should be in smpl/smplh/smplx/mano') 227 | if root_orient is None: root_orient = self.root_orient 228 | if self.model_type in ['smplh', 'smpl']: 229 | if pose_body is None: pose_body = self.pose_body 230 | if pose_hand is None: pose_hand = self.pose_hand 231 | elif self.model_type == 'smplx': 232 | if pose_body is None: pose_body = self.pose_body 233 | if pose_hand is None: pose_hand = self.pose_hand 234 | if pose_jaw is None: pose_jaw = self.pose_jaw 235 | if pose_eye is None: pose_eye = self.pose_eye 236 | elif self.model_type in ['mano', 'mano']: 237 | if pose_hand is None: pose_hand = self.pose_hand 238 | 239 | if pose_hand is None: pose_hand = self.pose_hand 240 | 241 | if trans is None: trans = self.trans 242 | if v_template is None: v_template = self.v_template 243 | if betas is None: betas = self.betas 244 | 245 | if v_template.size(0) != pose_body.size(0): 246 | v_template = v_template[:pose_body.size(0)] # this is fine since actual batch size will 247 | # only be equal to or less than specified batch 248 | # size 249 | 250 | if self.model_type in ['smplh', 'smpl']: 251 | full_pose = torch.cat([root_orient, pose_body, pose_hand], dim=1) 252 | elif self.model_type == 'smplx': 253 | full_pose = torch.cat([root_orient, pose_body, pose_jaw, pose_eye, pose_hand], 254 | dim=1) # orient:3, body:63, jaw:3, eyel:3, eyer:3, handl, handr 255 | elif self.model_type in ['mano', 'mano']: 256 | full_pose = torch.cat([root_orient, pose_hand], dim=1) 257 | 258 | if self.use_dmpl: 259 | if dmpls is None: dmpls = self.dmpls 260 | shape_components = torch.cat([betas, dmpls], dim=-1) 261 | shapedirs = torch.cat([self.shapedirs, self.dmpldirs], dim=-1) 262 | elif self.model_type == 'smplx': 263 | if expression is None: expression = self.expression 264 | shape_components = torch.cat([betas, expression], dim=-1) 265 | shapedirs = torch.cat([self.shapedirs, self.exprdirs], dim=-1) 266 | else: 267 | shape_components = betas 268 | shapedirs = self.shapedirs 269 | 270 | 271 | verts, joints, bone_transforms = lbs(betas=shape_components, pose=full_pose, v_template=v_template, 272 | shapedirs=shapedirs, posedirs=self.posedirs, 273 | J_regressor=self.J_regressor, parents=self.kintree_table[0].long(), 274 | 
lbs_weights=self.weights, 275 | dtype=self.dtype) 276 | 277 | Jtr = joints + trans.unsqueeze(dim=1) 278 | verts = verts + trans.unsqueeze(dim=1) 279 | 280 | res = {} 281 | res['v'] = verts 282 | res['f'] = self.f 283 | res['bone_transforms'] = bone_transforms 284 | res['betas'] = self.betas 285 | res['Jtr'] = Jtr # Todo: ik can be made with vposer 286 | 287 | if self.model_type == 'smpl': 288 | res['pose_body'] = pose_body 289 | elif self.model_type == 'smplh': 290 | res['pose_body'] = pose_body 291 | res['pose_hand'] = pose_hand 292 | elif self.model_type == 'smplx': 293 | res['pose_body'] = pose_body 294 | res['pose_hand'] = pose_hand 295 | res['pose_jaw'] = pose_jaw 296 | res['pose_eye'] = pose_eye 297 | elif self.model_type in ['mano', 'mano']: 298 | res['pose_hand'] = pose_hand 299 | res['full_pose'] = full_pose 300 | 301 | if not return_dict: 302 | class result_meta(object): 303 | pass 304 | 305 | res_class = result_meta() 306 | for k, v in res.items(): 307 | res_class.__setattr__(k, v) 308 | res = res_class 309 | 310 | return res 311 | 312 | 313 | -------------------------------------------------------------------------------- /preprocess/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), 4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the 5 | # Max Planck Institute for Biological Cybernetics. All rights reserved. 6 | # 7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights 8 | # on this computer program. You can only use this computer program if you have closed a license agreement 9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. 10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution. 
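A hedged sketch of driving the preprocessing `BodyModel` defined above. Run it from inside `preprocess/` so that `lbs.py` resolves (the class imports it from the same folder); the SMPL `.pkl` path is an assumption and must point at the model file you downloaded.

```python
# Hedged sketch -- the bm_path below is hypothetical; use your own SMPL .pkl.
import torch
from body_model import BodyModel   # run from preprocess/ so `from lbs import lbs` works

bm = BodyModel(bm_path='../lib/smpl/smpl_model/SMPL_NEUTRAL.pkl', num_betas=10, batch_size=1)

with torch.no_grad():
    body = bm(root_orient=torch.zeros(1, 3),      # global orientation, axis-angle
              pose_body=torch.zeros(1, 63),        # 21 body joints x 3 axis-angle params
              pose_hand=torch.zeros(1, 6),         # 2 hand joints x 3 for the 'smpl' model type
              betas=torch.zeros(1, 10))            # shape coefficients

print(body.v.shape)                # (1, 6890, 3)   posed vertices
print(body.Jtr.shape)              # (1, 24, 3)     joint locations
print(body.bone_transforms.shape)  # (1, 24, 4, 4)  per-bone rigid transforms from lbs()
```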
11 | # Contact: ps-license@tuebingen.mpg.de 12 | # 13 | # 14 | # If you use this code in a research publication please consider citing the following: 15 | # 16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image 17 | # 18 | # Code Developed by: 19 | # Vassilis Choutas 20 | # For the original and better implementation please refer to : https://github.com/vchoutas/smplx 21 | # 22 | # 2018.01.02 23 | 24 | from __future__ import absolute_import 25 | from __future__ import print_function 26 | from __future__ import division 27 | 28 | import time 29 | import torch 30 | import torch.nn.functional as F 31 | 32 | # TODO: Create merged c++ and CUDA kernel 33 | # TODO: Add argument to choose between python impl and c++/CUDA merged op 34 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, 35 | lbs_weights, num_joints=23, dtype=torch.float32): 36 | ''' Performs Linear Blend Skinning with the given shape and pose parameters 37 | 38 | Parameters 39 | ---------- 40 | betas : torch.tensor BxNB 41 | The tensor of shape parameters 42 | pose : torch.tensor Bx(J + 1) * 3 43 | The pose parameters in axis-angle format 44 | v_template torch.tensor BxVx3 45 | The template mesh that will be deformed 46 | shapedirs : torch.tensor 1xNB 47 | The tensor of PCA shape displacements 48 | posedirs : torch.tensor Px(V * 3) 49 | The pose PCA coefficients 50 | J_regressor : torch.tensor JxV 51 | The regressor array that is used to calculate the joints from 52 | the position of the vertices 53 | parents: torch.tensor J 54 | The array that describes the kinematic tree for the model 55 | lbs_weights: torch.tensor N x V x (J + 1) 56 | The linear blend skinning weights that represent how much the 57 | rotation matrix of each part affects each vertex 58 | pose2rot: bool, optional 59 | Flag on whether to convert the input pose tensor to rotation 60 | matrices. The default value is True. If False, then the pose tensor 61 | should already contain rotation matrices and have a size of 62 | Bx(J + 1)x9 63 | num_joints : int, optional 64 | The number of joints of the model. The default value is equal 65 | to the number of joints of the SMPL body model 66 | dtype: torch.dtype, optional 67 | 68 | Returns 69 | ------- 70 | verts: torch.tensor BxVx3 71 | The vertices of the mesh after applying the shape and pose 72 | displacements. 73 | joints: torch.tensor BxJx3 74 | The joints of the model 75 | ''' 76 | 77 | batch_size = betas.shape[0] 78 | device = betas.device 79 | 80 | # Add shape contribution 81 | v_shaped = v_template + blend_shapes(betas, shapedirs) 82 | 83 | # Get the joints 84 | # NxJx3 array 85 | J = vertices2joints(J_regressor, v_shaped) 86 | 87 | # import numpy as np 88 | 89 | rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3]) 90 | 91 | if posedirs is not None: 92 | # 3. Add pose blend shapes 93 | # N x J x 3 x 3 94 | ident = torch.eye(3, dtype=dtype, device=device) 95 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 96 | 97 | # (N x P) x (P, V * 3) -> N x V x 3 98 | pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3) 99 | v_posed = pose_offsets + v_shaped 100 | else: 101 | v_posed = v_shaped 102 | 103 | # 4. Get the global joint location 104 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 105 | 106 | # 5. 
Do skinning: 107 | # W is N x V x (J + 1) 108 | W = lbs_weights.unsqueeze(dim=0).repeat([batch_size, 1, 1]) 109 | num_joints = J_regressor.shape[0] 110 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 111 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4) 112 | 113 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device) 114 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 115 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 116 | 117 | verts = v_homo[:, :, :3, 0] 118 | 119 | return verts, J_transformed, A 120 | 121 | 122 | def vertices2joints(J_regressor, vertices): 123 | ''' Calculates the 3D joint locations from the vertices 124 | 125 | Parameters 126 | ---------- 127 | J_regressor : torch.tensor JxV 128 | The regressor array that is used to calculate the joints from the 129 | position of the vertices 130 | vertices : torch.tensor BxVx3 131 | The tensor of mesh vertices 132 | 133 | Returns 134 | ------- 135 | torch.tensor BxJx3 136 | The location of the joints 137 | ''' 138 | 139 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) 140 | 141 | 142 | def blend_shapes(betas, shape_disps): 143 | ''' Calculates the per vertex displacement due to the blend shapes 144 | 145 | 146 | Parameters 147 | ---------- 148 | betas : torch.tensor Bx(num_betas) 149 | Blend shape coefficients 150 | shape_disps: torch.tensor Vx3x(num_betas) 151 | Blend shapes 152 | 153 | Returns 154 | ------- 155 | torch.tensor BxVx3 156 | The per-vertex displacement due to shape deformation 157 | ''' 158 | 159 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 160 | # i.e. Multiply each shape displacement by its corresponding beta and 161 | # then sum them. 162 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) 163 | return blend_shape 164 | 165 | 166 | def batch_rodrigues(aa_rots): 167 | ''' 168 | convert batch of rotations in axis-angle representation to matrix representation 169 | :param aa_rots: Nx3 170 | :return: mat_rots: Nx3x3 171 | ''' 172 | 173 | dtype = aa_rots.dtype 174 | device = aa_rots.device 175 | 176 | batch_size = aa_rots.shape[0] 177 | 178 | angle = torch.norm(aa_rots + 1e-8, dim=1, keepdim=True) 179 | rot_dir = aa_rots / angle 180 | 181 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 182 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 183 | 184 | # Bx1 arrays 185 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 186 | 187 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 188 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ 189 | .view((batch_size, 3, 3)) 190 | 191 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 192 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 193 | return rot_mat 194 | 195 | 196 | def transform_mat(R, t): 197 | ''' Creates a batch of transformation matrices 198 | Args: 199 | - R: Bx3x3 array of a batch of rotation matrices 200 | - t: Bx3x1 array of a batch of translation vectors 201 | Returns: 202 | - T: Bx4x4 Transformation matrix 203 | ''' 204 | # No padding left or right, only add an extra row 205 | return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 206 | 207 | 208 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 209 | """ 210 | Applies a batch of rigid transformations to the joints 211 | 212 | Parameters 213 | ---------- 214 | rot_mats : torch.tensor BxNx3x3 215 | Tensor of rotation matrices 216 | joints : torch.tensor 
BxNx3 217 | Locations of joints 218 | parents : torch.tensor BxN 219 | The kinematic tree of each object 220 | dtype : torch.dtype, optional: 221 | The data type of the created tensors, the default is torch.float32 222 | 223 | Returns 224 | ------- 225 | posed_joints : torch.tensor BxNx3 226 | The locations of the joints after applying the pose rotations 227 | rel_transforms : torch.tensor BxNx4x4 228 | The relative (with respect to the root joint) rigid transformations 229 | for all the joints 230 | """ 231 | 232 | batch_size = rot_mats.shape[0] 233 | num_joints = joints.shape[1] 234 | device = rot_mats.device 235 | 236 | joints = torch.unsqueeze(joints, dim=-1) 237 | 238 | rel_joints = joints.clone() 239 | rel_joints[:, 1:] -= joints[:, parents[1:]] 240 | 241 | transforms_mat = transform_mat( 242 | rot_mats.reshape(-1, 3, 3), 243 | rel_joints.reshape(-1, 3, 1)).view(-1, joints.shape[1], 4, 4) 244 | 245 | transform_chain = [transforms_mat[:, 0]] 246 | for i in range(1, parents.shape[0]): 247 | # Subtract the joint location at the rest pose 248 | # No need for rotation, since it's identity when at rest 249 | curr_res = torch.matmul(transform_chain[parents[i]], 250 | transforms_mat[:, i]) 251 | transform_chain.append(curr_res) 252 | 253 | transforms = torch.stack(transform_chain, dim=1) 254 | 255 | # The last column of the transformations contains the posed joints 256 | posed_joints = transforms[:, :, :3, 3] 257 | 258 | joints_homogen = torch.cat([joints, torch.zeros([batch_size, num_joints, 1, 1], dtype=dtype, device=device)],dim=2) 259 | init_bone = torch.matmul(transforms, joints_homogen) 260 | init_bone = F.pad(init_bone, [3, 0, 0, 0, 0, 0, 0, 0]) 261 | rel_transforms = transforms - init_bone 262 | 263 | return posed_joints, rel_transforms 264 | -------------------------------------------------------------------------------- /preprocess/sample_points.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | import torch 5 | import trimesh 6 | 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | from body_model import BodyModel 11 | from utils import export_points 12 | 13 | from tqdm import tqdm,trange 14 | from shutil import copyfile 15 | 16 | parser = argparse.ArgumentParser('Read and sample AMASS dataset.') 17 | parser.add_argument('--dataset_path', type=str, default='data/', 18 | help='Path to AMASS dataset.') 19 | 20 | parser.add_argument('--poseprior', action='store_true', 21 | help='Generate data for posprior dataset') 22 | 23 | parser.add_argument('--bm_path', type=str, default='lib/smpl/smpl_model', 24 | help='Path to body model') 25 | 26 | parser.add_argument('--bbox_padding', type=float, default=0., 27 | help='Padding for bounding box') 28 | 29 | parser.add_argument('--output_folder', type=str, default='data/DFaust_processed', 30 | help='Output path for points.') 31 | 32 | parser.add_argument('--points_size', type=int, default=200000, 33 | help='Size of points.') 34 | parser.add_argument('--points_uniform_ratio', type=float, default=.5, 35 | help='Ratio of points to sample uniformly' 36 | 'in bounding box.') 37 | parser.add_argument('--points_sigma', type=float, default=0.01, 38 | help='Standard deviation of gaussian noise added to points' 39 | 'samples on the surfaces.') 40 | parser.add_argument('--points_padding', type=float, default=0.1, 41 | help='Additional padding applied to the uniformly' 42 | 'sampled points on both sides (in total).') 43 | 44 | 
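`batch_rigid_transform` above composes each joint's local transform with its parent's down the kinematic tree, then subtracts the transformed rest-pose joint (`transforms - init_bone`) so the returned `rel_transforms` act directly on rest-pose coordinates during skinning. A small hedged check of that property (not part of the original code; run from inside `preprocess/` so `lbs.py` is importable):

```python
# Because rel_transforms already subtracts the rest-pose joint location, applying them
# to the homogeneous rest joints must reproduce the posed joints returned alongside them.
import torch
from lbs import batch_rodrigues, batch_rigid_transform

B, J = 2, 24
parents = torch.tensor([-1] + list(range(J - 1)))          # toy chain: parent of joint i is i-1
rest_joints = torch.randn(B, J, 3)
rot_mats = batch_rodrigues(0.3 * torch.randn(B * J, 3)).view(B, J, 3, 3)

posed_joints, rel_transforms = batch_rigid_transform(rot_mats, rest_joints, parents)

joints_h = torch.cat([rest_joints, torch.ones(B, J, 1)], dim=-1)           # homogeneous coordinates
recovered = torch.matmul(rel_transforms, joints_h.unsqueeze(-1))[..., :3, 0]
assert torch.allclose(recovered, posed_joints, atol=1e-4)
```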
parser.add_argument('--overwrite', action='store_true', 45 | help='Whether to overwrite output.') 46 | parser.add_argument('--float16', action='store_true', 47 | help='Whether to use half precision.') 48 | parser.add_argument('--packbits', action='store_true', 49 | help='Whether to save truth values as bit array.') 50 | 51 | parser.add_argument('--skip', type=int, default=1, 52 | help='Take every x frames.') 53 | 54 | def process_single_file(vertices, root_orient, pose, joints, bone_transforms, root_loc, frame_name, betas, gender, skinning_weights_vertices, faces, subset, args): 55 | body_mesh = trimesh.Trimesh(vertices=vertices, faces=faces) 56 | # Get extents of model. 57 | bb_min = np.min(vertices, axis=0) 58 | bb_max = np.max(vertices, axis=0) 59 | # total_size = np.sqrt(np.square(bb_max - bb_min).sum()) 60 | total_size = (bb_max - bb_min).max() 61 | 62 | # Set the center (although this should usually be the origin already). 63 | loc = np.array( 64 | [(bb_min[0] + bb_max[0]) / 2, 65 | (bb_min[1] + bb_max[1]) / 2, 66 | (bb_min[2] + bb_max[2]) / 2] 67 | ) 68 | # Scales all dimensions equally. 69 | scale = total_size / (1 - args.bbox_padding) 70 | export_points(body_mesh, subset, frame_name, loc, scale, args, joints=joints, bone_transforms=bone_transforms, root_loc=root_loc, root_orient=root_orient, pose=pose, betas=betas, gender=gender, skinning_weights_vertices=skinning_weights_vertices, vertices=vertices, faces=faces) 71 | 72 | 73 | def amass_extract(args): 74 | dfaust_dir = os.path.join(args.dataset_path, 'DFaust_67') 75 | subjects = [os.path.basename(s_dir) for s_dir in sorted(glob.glob(os.path.join(dfaust_dir, '*')))] 76 | 77 | for subject in tqdm(subjects): 78 | subject_dir = os.path.join(dfaust_dir, subject) 79 | 80 | shape_data = np.load(os.path.join(subject_dir, 'shape.npz')) 81 | 82 | # Save shape data 83 | output_shape_folder = os.path.join(args.output_folder, 'shapes') 84 | if not os.path.exists(output_shape_folder): 85 | os.makedirs(output_shape_folder) 86 | copyfile(os.path.join(subject_dir, 'shape.npz'), os.path.join(output_shape_folder, '%s_shape.npz'%subject)) 87 | 88 | # Generate and save rest-pose for current subject 89 | gender = shape_data['gender'].item() 90 | betas = torch.Tensor(shape_data['betas'][:10]).unsqueeze(0).cuda() 91 | bm_path = os.path.join(args.bm_path, 'SMPL_%s.pkl'%(gender.upper())) 92 | bm = BodyModel(bm_path=bm_path, num_betas=10, batch_size=1).cuda() 93 | 94 | # Get skinning weights 95 | with torch.no_grad(): 96 | body = bm(betas=betas) 97 | vertices = body.v.detach().cpu().numpy()[0] 98 | faces = bm.f.detach().cpu().numpy() 99 | 100 | skinning_weights_vertices = bm.weights 101 | skinning_weights_vertices = skinning_weights_vertices.detach().cpu().numpy() 102 | 103 | # Read pose sequences 104 | sequences = [] 105 | if args.poseprior: 106 | pose_dir = os.path.join(args.dataset_path, 'MPI_Limits', '03099') 107 | for s_dir in glob.glob(os.path.join(pose_dir, '*.npz')): 108 | sequence = os.path.basename(s_dir) 109 | if sequence in ['op2_poses.npz', 'op3_poses.npz', 'op4_poses.npz', 'op5_poses.npz', 'op7_poses.npz', 'op8_poses.npz', 'op9_poses.npz']: 110 | sequences.append(sequence) 111 | else: 112 | pose_dir = subject_dir 113 | for s_dir in glob.glob(os.path.join(pose_dir, '*.npz')): 114 | sequence = os.path.basename(s_dir) 115 | if sequence not in ['shape.npz']: 116 | sequences.append(sequence) 117 | 118 | for sequence in tqdm(sequences): 119 | sequence_path = os.path.join(pose_dir, sequence) 120 | sequence_name = sequence[:] 121 | data = 
np.load(sequence_path, allow_pickle=True) 122 | 123 | poses = data['poses'][::args.skip] 124 | trans = data['trans'][::args.skip] 125 | 126 | batch_size = poses.shape[0] 127 | bm = BodyModel(bm_path=bm_path, num_betas=10, batch_size=batch_size).cuda() 128 | faces = bm.f.detach().cpu().numpy() 129 | 130 | pose_body = torch.Tensor(poses[:, 3:66]).cuda() 131 | pose_hand = torch.Tensor(poses[:, 66:72]).cuda() 132 | pose = torch.Tensor(poses[:, :72]).cuda() 133 | root_orient = torch.Tensor(poses[:, :3]).cuda() 134 | trans = torch.zeros(batch_size, 3, dtype=torch.float32).cuda() 135 | 136 | with torch.no_grad(): 137 | 138 | body = bm(root_orient=root_orient, pose_body=pose_body, pose_hand=pose_hand, betas=betas.expand(batch_size,-1), trans=trans) 139 | 140 | trans_ = F.pad(trans, [0, 1]).view(batch_size, 1, -1, 1) 141 | trans_ = torch.cat([torch.zeros(batch_size, 1, 4, 3, device=trans_.device), trans_], dim=-1) 142 | bone_transforms = body.bone_transforms + trans_ 143 | bone_transforms = torch.inverse(bone_transforms).detach().cpu().numpy() 144 | bone_transforms = bone_transforms[:, :, :] 145 | 146 | pose_body = pose_body.detach().cpu().numpy() 147 | pose = pose.detach().cpu().numpy() 148 | joints = body.Jtr.detach().cpu().numpy() 149 | vertices = body.v.detach().cpu().numpy() 150 | trans = trans.detach().cpu().numpy() 151 | root_orient = root_orient.detach().cpu().numpy() 152 | 153 | for f_idx in trange(batch_size): 154 | frame_name = sequence_name + '_{:06d}'.format(f_idx) 155 | process_single_file(vertices[f_idx], 156 | root_orient[f_idx], 157 | pose[f_idx], 158 | joints[f_idx], 159 | bone_transforms[f_idx], 160 | trans[f_idx], 161 | frame_name, 162 | betas[0].detach().cpu().numpy(), 163 | gender, 164 | skinning_weights_vertices, 165 | faces, 166 | subject, 167 | args) 168 | 169 | def main(args): 170 | amass_extract(args) 171 | 172 | 173 | if __name__ == '__main__': 174 | args = parser.parse_args() 175 | main(args) 176 | -------------------------------------------------------------------------------- /preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def export_points(mesh, subject, modelname, loc, scale, args, **kwargs): 8 | if not mesh.is_watertight: 9 | print('Warning: mesh %s is not watertight!' 10 | 'Cannot sample points.' 
% modelname) 11 | return 12 | 13 | kwargs_new = {} 14 | for k, v in kwargs.items(): 15 | if v is not None: 16 | kwargs_new[k] = v 17 | 18 | filename = os.path.join(args.output_folder, 'points', subject, modelname + '.npz') 19 | 20 | if not args.overwrite and os.path.exists(filename): 21 | print('Points already exist: %s' % filename) 22 | return 23 | 24 | n_points_uniform = int(args.points_size * args.points_uniform_ratio) 25 | n_points_surface = args.points_size - n_points_uniform 26 | 27 | boxsize = 1 + args.points_padding 28 | points_uniform = np.random.rand(n_points_uniform, 3) 29 | points_uniform = boxsize * (points_uniform - 0.5) 30 | # Scale points in (padded) unit box back to the original space 31 | points_uniform *= scale 32 | points_uniform += np.expand_dims(loc, axis=0) 33 | # Sample points around mesh surface 34 | points_surface = mesh.sample(n_points_surface) 35 | points_surface += args.points_sigma * np.random.randn(n_points_surface, 3) 36 | points = np.concatenate([points_uniform, points_surface], axis=0) 37 | points = torch.tensor(points).cuda().float().unsqueeze(0) 38 | vertices = torch.tensor(kwargs['vertices']).cuda().float().unsqueeze(0) 39 | faces = torch.tensor(kwargs['faces'], dtype=torch.int64).cuda() 40 | 41 | import kaolin 42 | occupancies = kaolin.ops.mesh.check_sign(vertices, faces, points) 43 | 44 | 45 | points = points.cpu().numpy() 46 | occupancies = occupancies.cpu().numpy() 47 | 48 | # Compress 49 | if args.float16: 50 | dtype = np.float16 51 | else: 52 | dtype = np.float32 53 | 54 | points = points.astype(dtype) 55 | 56 | if args.packbits: 57 | occupancies = np.packbits(occupancies) 58 | 59 | if not os.path.exists(os.path.dirname(filename)): 60 | os.makedirs(os.path.dirname(filename)) 61 | 62 | # print('Writing points: %s' % filename) 63 | np.savez(filename, points=points, occupancies=occupancies, 64 | loc=loc, scale=scale, 65 | **kwargs_new) 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Set-up script for installing extension modules.""" 15 | from Cython.Build import cythonize 16 | import numpy 17 | from setuptools import Extension 18 | from setuptools import setup 19 | 20 | # Get the numpy include directory. 
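`export_points` above writes one compressed `.npz` per frame. A hedged read-back sketch (not part of the repo): the path is a placeholder under `--output_folder`, and the unpacking assumes the `--float16` and `--packbits` flags were used when sampling.

```python
# Hedged read-back of one sample written by export_points(); path and flags are assumptions.
import numpy as np

data = np.load('data/DFaust_processed/points/<subject>/<frame>.npz')     # placeholder path
points = data['points'].astype(np.float32)                    # (1, N, 3), stored as float16 if --float16
occ = np.unpackbits(data['occupancies'])[:points.shape[1]]    # undo --packbits; 1 = point lies inside the mesh
loc, scale = data['loc'], data['scale']                        # bbox centre and size used for normalization
```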
21 | numpy_include_dir = numpy.get_include() 22 | 23 | # mise (efficient mesh extraction) 24 | mise_module = Extension( 25 | "lib.libmise.mise", 26 | sources=["lib/libmise/mise.pyx"], 27 | ) 28 | 29 | # Gather all extension modules 30 | ext_modules = [ 31 | mise_module, 32 | ] 33 | 34 | setup(ext_modules=cythonize(ext_modules),) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import glob 4 | import hydra 5 | import torch 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | 9 | from lib.snarf_model import SNARFModel 10 | 11 | @hydra.main(config_path="config", config_name="config") 12 | def main(opt): 13 | 14 | print(opt.pretty()) 15 | pl.seed_everything(42, workers=True) 16 | torch.set_num_threads(10) 17 | 18 | datamodule = hydra.utils.instantiate(opt.datamodule, opt.datamodule) 19 | datamodule.setup(stage='test') 20 | 21 | trainer = pl.Trainer(**opt.trainer) 22 | 23 | if opt.epoch == 'last': 24 | checkpoint_path = './checkpoints/last.ckpt' 25 | else: 26 | checkpoint_path = glob.glob('./checkpoints/epoch=%d*.ckpt'%opt.epoch)[0] 27 | 28 | model = SNARFModel.load_from_checkpoint( 29 | checkpoint_path=checkpoint_path, 30 | opt=opt.model, 31 | meta_info=datamodule.meta_info 32 | ) 33 | # use all bones for initialization during testing 34 | model.deformer.init_bones = np.arange(24) 35 | 36 | results = trainer.test(model, datamodule=datamodule, verbose=True) 37 | 38 | np.savetxt('./results_%s_%s_%s.txt'%(os.path.basename(opt.datamodule.dataset_path),opt.datamodule.subject, str(opt.epoch)), np.array([results[0]['valid_bbox_iou'], results[0]['valid_surf_iou']])) 39 | 40 | if __name__ == '__main__': 41 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | import hydra 4 | import torch 5 | import yaml 6 | import os 7 | import numpy as np 8 | 9 | from lib.snarf_model import SNARFModel 10 | 11 | @hydra.main(config_path="config", config_name="config") 12 | def main(opt): 13 | 14 | print(opt.pretty()) 15 | 16 | pl.seed_everything(42, workers=True) 17 | 18 | torch.set_num_threads(10) 19 | 20 | # dataset 21 | datamodule = hydra.utils.instantiate(opt.datamodule, opt.datamodule) 22 | datamodule.setup(stage='fit') 23 | np.savez('meta_info.npz', **datamodule.meta_info) 24 | 25 | data_processor = None 26 | if 'processor' in opt.datamodule: 27 | data_processor = hydra.utils.instantiate(opt.datamodule.processor, 28 | opt.datamodule.processor, 29 | meta_info=datamodule.meta_info) 30 | 31 | # logger 32 | with open('.hydra/config.yaml', 'r') as f: 33 | config = yaml.load(f, Loader=yaml.FullLoader) 34 | logger = pl.loggers.WandbLogger(project='snarf', config=config) 35 | 36 | # checkpoint 37 | checkpoint_path = './checkpoints/last.ckpt' 38 | if not os.path.exists(checkpoint_path) or not opt.resume: 39 | checkpoint_path = None 40 | 41 | checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=-1, 42 | monitor=None, 43 | dirpath='./checkpoints', 44 | save_last=True, 45 | every_n_val_epochs=1) 46 | 47 | 48 | trainer = pl.Trainer(logger=logger, 49 | callbacks=[checkpoint_callback], 50 | accelerator=None, 51 | resume_from_checkpoint=checkpoint_path, 52 | **opt.trainer) 53 | 54 | model = SNARFModel(opt=opt.model, 55 | meta_info=datamodule.meta_info, 56 | data_processor=data_processor) 57 | 58 | 
trainer.fit(model, datamodule=datamodule) 59 | 60 | if __name__ == '__main__': 61 | main() --------------------------------------------------------------------------------
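The setup.py shown earlier only compiles the Cython MISE module that lib/utils/meshing.py relies on for multi-resolution occupancy evaluation. A hedged smoke test (not part of the repo) that exercises the same query/update/to_dense loop as `generate_mesh`, here with a toy sphere occupancy; it assumes the extension was built in place first (e.g. `python setup.py build_ext --inplace`) and that it runs from the repo root.

```python
# Hedged smoke test for the compiled lib.libmise extension with a toy occupancy function.
import numpy as np
from lib.libmise import mise

extractor = mise.MISE(32, 2, 0.0)          # res_init=32, two upsampling steps, level set 0
points = extractor.query()
while points.shape[0] != 0:
    # evaluate a toy occupancy (signed distance to a sphere of radius 0.25) only where MISE asks
    xyz = points.astype(np.float32) / extractor.resolution - 0.5
    values = (0.25 - np.linalg.norm(xyz, axis=-1)).astype(np.float64)
    extractor.update(points, values)       # feed values back at the queried integer grid points
    points = extractor.query()

grid = extractor.to_dense()                # dense value grid, ready for marching cubes
print(grid.shape)
```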