├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── 1.gif
│   ├── 2.gif
│   └── 3.gif
├── config
│   ├── config.yaml
│   ├── datamodule
│   │   ├── cape.yaml
│   │   ├── dfaust.yaml
│   │   └── jointlim.yaml
│   └── experiments
│       └── cape.yaml
├── demo.py
├── download_data.sh
├── environment.yml
├── lib
│   ├── dataset
│   │   ├── cape.py
│   │   ├── dfaust.py
│   │   ├── dfaust_split.yml
│   │   └── jointlim.py
│   ├── libmise
│   │   ├── mise.cpp
│   │   └── mise.pyx
│   ├── model
│   │   ├── broyden.py
│   │   ├── deformer.py
│   │   ├── helpers.py
│   │   ├── metrics.py
│   │   ├── network.py
│   │   ├── sample.py
│   │   └── smpl.py
│   ├── smpl
│   │   ├── body_models.py
│   │   ├── lbs.py
│   │   ├── utils.py
│   │   ├── vertex_ids.py
│   │   └── vertex_joint_selector.py
│   ├── snarf_model.py
│   └── utils
│       ├── meshing.py
│       └── render.py
├── preprocess
│   ├── body_model.py
│   ├── lbs.py
│   ├── sample_points.py
│   └── utils.py
├── setup.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | share/python-wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | *.py,cover
49 | .hypothesis/
50 | .pytest_cache/
51 | cover/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | .pybuilder/
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | # For a library or package, you might want to ignore these files since the code is
86 | # intended to run in multiple environments; otherwise, check them in:
87 | # .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # pytype static type analyzer
134 | .pytype/
135 |
136 | # Cython debug symbols
137 | cython_debug/
138 |
139 | **.pyc
140 | **.in
141 | **.rst
142 | **.ini
143 |
144 | # Custom
145 | **.out
146 | data
147 | outputs
148 | wandb
149 | multirun
150 | lib/smpl/smpl_model
151 | **.pkl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Xu Chen, Yufeng Zheng, Michael J. Black, Otmar Hilliges, Andreas Geiger
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SNARF: Differentiable Forward Skinning for Animating Non-rigid Neural Implicit Shapes
2 | ## [Paper](https://arxiv.org/pdf/2104.03953.pdf) | [Supp](https://bit.ly/3t1Tk6F) | [Video](https://youtu.be/rCEpFTKjFHE) | [Project Page](https://xuchen-ethz.github.io/snarf) | Blog ([AIT](https://eth-ait.medium.com/animate-implicit-shapes-with-forward-skinning-c7ebbf355694), [AVG](https://autonomousvision.github.io/snarf/))
3 |
4 |
5 |
6 |
7 | Official code release for ICCV 2021 paper [*SNARF: Differentiable Forward Skinning for Animating Non-rigid Neural Implicit Shapes*](https://arxiv.org/pdf/2104.03953.pdf). We propose a novel forward skinning module to animate neural implicit shapes with good generalization to unseen poses.
8 |
9 | **Update:** we have released an improved version, FastSNARF, which is 150x faster than SNARF. Check it out [here](https://github.com/xuchen-ethz/fast-snarf).
10 |
11 | If you find our code or paper useful, please cite as
12 | ```
13 | @inproceedings{chen2021snarf,
14 | title={SNARF: Differentiable Forward Skinning for Animating Non-Rigid Neural Implicit Shapes},
15 | author={Chen, Xu and Zheng, Yufeng and Black, Michael J and Hilliges, Otmar and Geiger, Andreas},
16 | booktitle={International Conference on Computer Vision (ICCV)},
17 | year={2021}
18 | }
19 | ```
20 |
21 | # Quick Start
22 | Clone this repo:
23 | ```
24 | git clone https://github.com/xuchen-ethz/snarf.git
25 | cd snarf
26 | ```
27 |
28 | Install environment:
29 | ```
30 | conda env create -f environment.yml
31 | conda activate snarf
32 | python setup.py install
33 | ```
34 |
35 |
36 | Download [SMPL models](https://smpl.is.tue.mpg.de) (1.0.0 for Python 2.7 (10 shape PCs)) and move them to the corresponding places:
37 | ```
38 | mkdir lib/smpl/smpl_model/
39 | mv /path/to/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl lib/smpl/smpl_model/SMPL_FEMALE.pkl
40 | mv /path/to/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl lib/smpl/smpl_model/SMPL_MALE.pkl
41 | ```
42 |
43 | Download our pretrained models and test motion sequences:
44 | ```
45 | sh ./download_data.sh
46 | ```
47 |
48 | Run a quick demo for a clothed human:
49 | ```
50 | python demo.py expname=cape subject=3375 demo.motion_path=data/aist_demo/seqs +experiments=cape
51 | ```
52 | You can find the video in `outputs/cape/3375/demo.mp4` and the images in `outputs/cape/3375/images/`. To also save the meshes, add `demo.save_mesh=true` to the command.
53 |
54 | You can also try other subjects (see `outputs/data/cape` for available options) by setting `subject=xx`, and other motion sequences from [AMASS](https://amass.is.tue.mpg.de/download.php) by setting `demo.motion_path=/path/to/amass_motion.npz`.
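
For reference, you can inspect an AMASS SMPL-H sequence like this (a minimal sketch; the path below is a placeholder):
```
import numpy as np

seq = np.load("/path/to/amass_motion.npz")  # any AMASS SMPL-H sequence
print(seq["poses"].shape)                   # (n_frames, 156) for SMPL-H
# demo.py accepts either a 72-dim SMPL pose or a 156-dim SMPL-H pose per frame
# and converts the latter to the 72-dim SMPL format.
```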
55 |
56 | Some motion sequences have a high frame rate, and you may want to skip frames. To do this, add `demo.every_n_frames=x` to use only every x-th frame of the motion sequence (e.g. `demo.every_n_frames=10` for PosePrior sequences).
57 |
58 | By default, we use `demo.fast_mode=true` for fast mesh extraction. In this mode, we first extract the mesh in canonical space and then forward-skin it to posed space. This bypasses root finding during inference and is therefore faster, but it does not actually deform a continuous field. To first deform the continuous field and then extract the mesh in deformed space, use `demo.fast_mode=false` instead.
59 |
60 | # Training and Evaluation
61 |
62 | ## Install Additional Dependencies
63 | Install [kaolin](https://kaolin.readthedocs.io/en/latest/notes/installation.html) for fast occupancy query from meshes.
64 | ```
65 | git clone https://github.com/NVIDIAGameWorks/kaolin
66 | cd kaolin
67 | git checkout v0.9.0
68 | python setup.py develop
69 | ```
70 | ## Minimally Clothed Human
71 | ### Prepare Datasets
72 | Download the [AMASS](https://amass.is.tue.mpg.de/download.php) dataset. We use the "DFaust Synthetic" and "PosePrior" subsets in SMPL-H format. Unzip the dataset into the `data` folder.
73 | ```
74 | tar -xf DFaust67.tar.bz2 -C data
75 | tar -xf MPILimits.tar.bz2 -C data
76 | ```
77 |
78 | Preprocess dataset:
79 | ```
80 | python preprocess/sample_points.py --output_folder data/DFaust_processed
81 | python preprocess/sample_points.py --output_folder data/MPI_processed --skip 10 --poseprior
82 | ```
83 |
84 |
85 | ### Training
86 | Run the following command to train for a specified subject:
87 | ```
88 | python train.py subject=50002
89 | ```
90 | Training logs are available on [wandb](https://wandb.ai/home) (registration needed, free of charge). It should take ~12h on a single 2080Ti.
91 |
92 | ### Evaluation
93 | Run the following command to evaluate the method for a specified subject on within-distribution data (the DFaust test split):
94 | ```
95 | python test.py subject=50002
96 | ```
97 | and on out-of-distribution data (PosePrior):
98 | ```
99 | python test.py subject=50002 datamodule=jointlim
100 | ```
101 |
102 | ### Generate Animation
103 | You can use the trained model to generate animation (same as in Quick Start):
104 | ```
105 | python demo.py expname='dfaust' subject=50002 demo.motion_path='data/aist_demo/seqs'
106 | ```
107 |
108 |
109 | ## Clothed Human
110 |
111 | ### Training
112 | Download the [CAPE](https://cape.is.tue.mpg.de/) dataset and unzip it into the `data` folder.
113 |
114 | Run the following command to train for a specified subject and clothing type:
115 | ```
116 | python train.py datamodule=cape subject=3375 datamodule.clothing='blazerlong' +experiments=cape
117 | ```
118 | Training logs are available on [wandb](https://wandb.ai/home). It should take ~24h on a single 2080Ti.
119 |
120 | ### Generate Animation
121 | You can use the trained model to generate animation (same as in Quick Start):
122 | ```
123 | python demo.py expname=cape subject=3375 demo.motion_path=data/aist_demo/seqs +experiments=cape
124 | ```
125 |
126 | # Acknowledgement
127 | We use the pre-processing code in [PTF](https://github.com/taconite/PTF) and [LEAP](https://github.com/neuralbodies/leap) with some adaptations (`./preprocess`). The network and sampling part of the code (`lib/model/network.py` and `lib/model/sample.py`) is implemented based on [IGR](https://github.com/amosgropp/IGR) and [IDR](https://github.com/lioryariv/idr). The code for extracting meshes (`lib/utils/meshing.py`) is adapted from [NASA](https://github.com/tensorflow/graphics/tree/master/tensorflow_graphics/projects/nasa). Our implementation of Broyden's method (`lib/model/broyden.py`) is based on [DEQ](https://github.com/locuslab/deq). We sincerely thank these authors for their awesome work.
128 |
--------------------------------------------------------------------------------
/assets/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuchen-ethz/snarf/ae0c893cc049f0f8270eaa401e138dff5d4637b9/assets/1.gif
--------------------------------------------------------------------------------
/assets/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuchen-ethz/snarf/ae0c893cc049f0f8270eaa401e138dff5d4637b9/assets/2.gif
--------------------------------------------------------------------------------
/assets/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuchen-ethz/snarf/ae0c893cc049f0f8270eaa401e138dff5d4637b9/assets/3.gif
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - datamodule: dfaust
3 |
4 | hydra:
5 | run:
6 | dir: outputs/${expname}/${subject}
7 |
8 | expname: dfaust
9 | subject: 50002
10 | epoch: last
11 | resume: false
12 |
13 | trainer:
14 | gradient_clip_val: 0.1
15 | check_val_every_n_epoch: 5
16 | deterministic: true
17 | max_steps: 45000
18 | gpus: 1
19 | model:
20 | # shape MLP
21 | network:
22 | d_in: 3
23 | d_out: 1
24 | depth: 8
25 | width: 256
26 | multires: 0
27 | skip_layer: [4]
28 | cond_layer: [4]
29 | dim_cond_embed: 8
30 | weight_norm: true
31 | geometric_init: true
32 | bias: 1
33 | deformer:
34 | softmax_mode: hierarchical
35 | # LBS MLP
36 | network:
37 | d_in: 3
38 | d_out: 25
39 | depth: 4
40 | width: 128
41 | multires: 0
42 | skip_layer: []
43 | cond_layer: []
44 | dim_cond_embed: 0
45 | weight_norm: true
46 | geometric_init: false
47 | bias: 1
48 | optim:
49 | lr: 1e-3
50 | soft_blend: 5
51 | nepochs_pretrain: 1
52 | lambda_bone_occ: 1
53 | lambda_bone_w: 10
54 |
55 | demo:
56 | motion_path: data/aist_demo/seqs
57 | resolution: 256
58 | save_mesh: false
59 | every_n_frames: 2
60 | output_video_name: aist
61 | verbose: false
62 | fast_mode: true
--------------------------------------------------------------------------------
/config/datamodule/cape.yaml:
--------------------------------------------------------------------------------
1 | datamodule:
2 | _target_: lib.dataset.cape.CAPEDataModule
3 | dataset_path: ./data/CAPE
4 | num_workers: 10
5 | subject: ${subject}
6 | clothing: 'longshort'
7 | batch_size: 8
8 | processor:
9 | _target_: lib.dataset.cape.CAPEDataProcessor
10 | points_per_frame: 2000
11 | sampler:
12 | _target_: lib.model.sample.PointInSpace
13 | global_sigma: 1.8
14 | local_sigma: 0.01
15 |
--------------------------------------------------------------------------------
/config/datamodule/dfaust.yaml:
--------------------------------------------------------------------------------
1 | datamodule:
2 | _target_: lib.dataset.dfaust.DFaustDataModule
3 | dataset_path: ./data/DFaust_processed
4 | subject: ${subject}
5 | num_workers: 10
6 | batch_size: 8
7 | points_per_frame: 2000
8 |
9 |
--------------------------------------------------------------------------------
/config/datamodule/jointlim.yaml:
--------------------------------------------------------------------------------
1 | datamodule:
2 | _target_: lib.dataset.jointlim.JointLimDataModule
3 | dataset_path: ./data/MPILimits_processed
4 | subject: ${subject}
5 | num_workers: 10
--------------------------------------------------------------------------------
/config/experiments/cape.yaml:
--------------------------------------------------------------------------------
1 | expname: cape
2 | trainer:
3 | max_steps: 100000
4 | model:
5 | network:
6 | multires: 4
7 | cond_layer: [0]
8 | dim_cond_embed: -1
9 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | import hydra
5 | import torch
6 | import imageio
7 | import numpy as np
8 | import pytorch_lightning as pl
9 |
10 | from tqdm import trange
11 | from lib.snarf_model import SNARFModel
12 | from lib.model.helpers import rectify_pose
13 |
14 | @hydra.main(config_path="config", config_name="config")
15 | def main(opt):
16 |
17 | print(opt.pretty())
18 | pl.seed_everything(42, workers=True)
19 |
20 | # set up model
21 | meta_info = np.load('meta_info.npz')
22 |
23 | if opt.epoch == 'last':
24 | checkpoint_path = './checkpoints/last.ckpt'
25 | else:
26 | checkpoint_path = glob.glob('./checkpoints/epoch=%d*.ckpt'%opt.epoch)[0]
27 |
28 | model = SNARFModel.load_from_checkpoint(
29 | checkpoint_path=checkpoint_path,
30 | opt=opt.model,
31 | meta_info=meta_info
32 | ).cuda()
33 | # use all bones for initialization during testing
34 | model.deformer.init_bones = np.arange(24)
35 |
36 | # pose format conversion
37 | smplx_to_smpl = list(range(66)) + [72, 73, 74, 117, 118, 119] # SMPLH to SMPL
38 |
39 | # load motion sequence
40 | motion_path = hydra.utils.to_absolute_path(opt.demo.motion_path)
41 | if os.path.isdir(motion_path):
42 | motion_files = sorted(glob.glob(os.path.join(motion_path, '*.npz')))
43 | smpl_params_all = []
44 | for f in motion_files:
45 | f = np.load(f)
46 | smpl_params = np.zeros(86)
47 | smpl_params[0], smpl_params[4:76] = 1, f['pose']
48 | smpl_params = torch.tensor(smpl_params).float().cuda()
49 | smpl_params_all.append(smpl_params)
50 | smpl_params_all = torch.stack(smpl_params_all)
51 |
52 | elif '.npz' in motion_path:
53 | f = np.load(motion_path)
54 | smpl_params_all = np.zeros( (f['poses'].shape[0], 86) )
55 | smpl_params_all[:,0] = 1
56 | if f['poses'].shape[-1] == 72:
57 | smpl_params_all[:, 4:76] = f['poses']
58 | elif f['poses'].shape[-1] == 156:
59 | smpl_params_all[:, 4:76] = f['poses'][:,smplx_to_smpl]
60 |
61 | root_abs = smpl_params_all[0, 4:7].copy()
62 | for i in range(smpl_params_all.shape[0]):
63 | smpl_params_all[i, 4:7] = rectify_pose(smpl_params_all[i, 4:7], root_abs)
64 |
65 | smpl_params_all = torch.tensor(smpl_params_all).float().cuda()
66 |
67 | smpl_params_all = smpl_params_all[::opt.demo.every_n_frames]
68 | # generate data batch
69 | images = []
70 | for i in trange(smpl_params_all.shape[0]):
71 |
72 | smpl_params = smpl_params_all[[i]]
73 | data = model.smpl_server.forward(smpl_params, absolute=True)
74 | data['smpl_thetas'] = smpl_params[:, 4:76]
75 | results = model.plot(data, res=opt.demo.resolution, verbose=opt.demo.verbose, fast_mode=opt.demo.fast_mode)
76 |
77 | images.append(results['img_all'])
78 |
79 | if not os.path.exists('images'):
80 | os.makedirs('images')
81 | imageio.imwrite('images/%04d.png'%i, results['img_all'])
82 |
83 | if opt.demo.save_mesh:
84 | if not os.path.exists('meshes'):
85 | os.makedirs('meshes')
86 | results['mesh_def'].export('meshes/%04d_def.ply'%i)
87 | if opt.demo.verbose:
88 | results['mesh_cano'].export('meshes/%04d_cano.ply'%i)
89 |
90 | imageio.mimsave('%s.mp4'%opt.demo.output_video_name, images)
91 |
92 | if __name__ == '__main__':
93 | main()
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | mkdir data
2 |
3 | echo Downloading motion sequences...
4 |
5 | wget https://scanimate.is.tue.mpg.de/media/upload/demo_data/aist_demo_seq.zip
6 | unzip aist_demo_seq.zip -d ./data/
7 | rm aist_demo_seq.zip
8 | mv ./data/gLO_sBM_cAll_d14_mLO1_ch05 ./data/aist_demo
9 |
10 | echo Done!
11 |
12 |
13 | mkdir outputs
14 |
15 | echo Downloading pretrained models ...
16 | wget https://dataset.ait.ethz.ch/downloads/fOUiBuCXJy/pretrained_models.zip
17 | unzip pretrained_models.zip -d ./outputs/
18 | mv outputs/pretrained_models/* outputs/
19 | rm outputs/pretrained_models -rf
20 | rm pretrained_models.zip
21 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: snarf
2 | channels:
3 | - anaconda
4 | - conda-forge
5 | - pytorch
6 | - pytorch3d
7 | - defaults
8 | dependencies:
9 | - python=3.7
10 | - pytorch=1.7.0
11 | - torchvision=0.8.0
12 | - cudatoolkit=11.0
13 | - iopath
14 | - fvcore
15 | - pytorch3d
16 | - pip
17 | - pip:
18 | - chumpy
19 | - scikit-image
20 | - imageio-ffmpeg
21 | - trimesh
22 | - wandb
23 | - tqdm
24 | - pytorch-lightning==1.3.3
25 | - opencv-python==4.5.2.54
26 | - hydra-core==1.0.6
27 | - usd-core==20.11 # for kaolin
28 | - Cython==0.29.20 # for kaolin
--------------------------------------------------------------------------------
/lib/dataset/cape.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 |
5 | from torch.utils.data import DataLoader, Dataset
6 | import numpy as np
7 | import os
8 | import glob
9 | import hydra
10 |
11 | import pickle
12 | import kaolin
13 |
14 |
15 | class CAPEDataSet(Dataset):
16 |
17 | def __init__(self, dataset_path, subject=32, clothing='longshort'):
18 |
19 | dataset_path = hydra.utils.to_absolute_path(dataset_path)
20 |
21 | self.regstr_list = glob.glob(os.path.join(dataset_path, 'cape_release', 'sequences', '%05d'%subject, clothing+'_**/*.npz'), recursive=True)
22 |
23 | genders_list = os.path.join(dataset_path, 'cape_release', 'misc', 'subj_genders.pkl')
24 | with open(genders_list,'rb') as f:
25 | self.gender = pickle.load(f, encoding='latin1')['%05d'%subject]
26 |
27 | minimal_body_path = os.path.join(dataset_path, 'cape_release', 'minimal_body_shape', '%05d'%subject, '%05d_minimal.npy'%subject)
28 | self.v_template = np.load(minimal_body_path)
29 |
30 | self.meta_info = {'v_template': self.v_template, 'gender': self.gender}
31 |
32 |
33 | def __getitem__(self, index):
34 |
35 | data = {}
36 |
37 | while True:
38 | try:
39 | regstr = np.load(self.regstr_list[index])
40 | poses = regstr['pose']
41 | break
42 |             except Exception:
43 | index = np.random.randint(self.__len__())
44 | print('corrupted npz')
45 |
46 | verts = regstr['v_posed'] - regstr['transl'][None,:]
47 | verts = torch.tensor(verts).float()
48 |
49 | smpl_params = torch.zeros([86]).float()
50 | smpl_params[0] = 1
51 | smpl_params[4:76] = torch.tensor(poses).float()
52 |
53 | data['scan_verts'] = verts
54 | data['smpl_params'] = smpl_params
55 | data['smpl_thetas'] = smpl_params[4:76]
56 | data['smpl_betas'] = smpl_params[76:]
57 |
58 | return data
59 |
60 | def __len__(self):
61 | return len(self.regstr_list)
62 |
63 | ''' Used to generate ground-truth occupancy and bone transformations in batches during training '''
64 | class CAPEDataProcessor():
65 |
66 | def __init__(self, opt, meta_info, **kwargs):
67 | from lib.model.smpl import SMPLServer
68 |
69 | self.opt = opt
70 | self.gender = meta_info['gender']
71 |         self.v_template = meta_info['v_template']
72 |
73 | self.smpl_server = SMPLServer(gender=self.gender, v_template=self.v_template)
74 | self.smpl_faces = torch.tensor(self.smpl_server.smpl.faces.astype('int')).unsqueeze(0).cuda()
75 | self.sampler = hydra.utils.instantiate(opt.sampler)
76 |
77 | def process(self, data):
78 |
79 | smpl_output = self.smpl_server(data['smpl_params'], absolute=True)
80 | data.update(smpl_output)
81 |
82 | num_batch, num_verts, num_dim = smpl_output['smpl_verts'].shape
83 |
84 | random_idx = torch.randint(0, num_verts, [num_batch, self.opt.points_per_frame,1], device=smpl_output['smpl_verts'].device)
85 |
86 | random_pts = torch.gather(data['scan_verts'], 1, random_idx.expand(-1, -1, num_dim))
87 | data['pts_d'] = self.sampler.get_points(random_pts)
88 |
89 | data['occ_gt'] = kaolin.ops.mesh.check_sign(data['scan_verts'], self.smpl_faces[0], data['pts_d']).float().unsqueeze(-1)
90 |
91 | return data
92 |
93 | class CAPEDataModule(pl.LightningDataModule):
94 |
95 | def __init__(self, opt, **kwargs):
96 | super().__init__()
97 | self.opt = opt
98 |
99 | def setup(self, stage=None):
100 |
101 | if stage == 'fit':
102 | self.dataset_train = CAPEDataSet(dataset_path=self.opt.dataset_path,
103 | subject=self.opt.subject,
104 | clothing=self.opt.clothing)
105 |
106 | self.dataset_val = CAPEDataSet(dataset_path=self.opt.dataset_path,
107 | subject=self.opt.subject,
108 | clothing=self.opt.clothing)
109 |
110 | self.meta_info = self.dataset_val.meta_info
111 |
112 |
113 | def train_dataloader(self):
114 | dataloader = DataLoader(self.dataset_train,
115 | batch_size=self.opt.batch_size,
116 | num_workers=self.opt.num_workers,
117 | shuffle=True,
118 | drop_last=True,
119 | pin_memory=True)
120 | return dataloader
121 |
122 | def val_dataloader(self):
123 | dataloader = DataLoader(self.dataset_val,
124 | batch_size=self.opt.batch_size,
125 | num_workers=self.opt.num_workers,
126 | shuffle=False,
127 | drop_last=False,
128 | pin_memory=True)
129 | return dataloader
130 |
131 | def test_dataloader(self):
132 | dataloader = DataLoader(self.dataset_val,
133 | batch_size=1,
134 | num_workers=self.opt.num_workers,
135 | shuffle=False,
136 | drop_last=False,
137 | pin_memory=True)
138 | return dataloader
139 |
--------------------------------------------------------------------------------
/lib/dataset/dfaust.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import yaml
4 | import hydra
5 | import torch
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 |
11 | class DFaustDataset(Dataset):
12 | def __init__(self, dataset_path, stage, subject, points_per_frame=2000):
13 |
14 | dataset_path = hydra.utils.to_absolute_path(dataset_path)
15 | split_path = hydra.utils.to_absolute_path('lib/dataset/dfaust_split.yml')
16 | with open(split_path, 'r') as stream:
17 | split = yaml.safe_load(stream)
18 |
19 | self.stage = stage
20 |
21 | shape = np.load(os.path.join(dataset_path, 'shapes', '%d_shape.npz'%subject))
22 | self.betas = torch.tensor(shape['betas'][:10])
23 | self.gender = str(shape['gender'])
24 | self.meta_info = {'betas': shape['betas'][:10], 'gender': self.gender}
25 |
26 | self.frame_names = []
27 | for name in split[str(subject)][stage]:
28 | frame_name_act = glob.glob(os.path.join(dataset_path, 'points', str(subject), str(subject)+ '_' + name+'_poses'+'*.npz'))
29 | self.frame_names.extend(frame_name_act)
30 |
31 | self.points_per_frame = points_per_frame
32 | self.total_points = 100000
33 |
34 |
35 | def __getitem__(self, i):
36 | name = self.frame_names[i]
37 | data = {}
38 | dataset = np.load(name)
39 | if self.stage == 'train':
40 |             random_idx = torch.cat([torch.randint(0,self.total_points,[self.points_per_frame//8]), # bbox samples (first total_points entries)
41 |                                      torch.randint(0,self.total_points,[self.points_per_frame])+self.total_points], # near-surface samples (second total_points entries)
42 | 0)
43 | data['pts_d'] = torch.tensor(dataset['points'][0, random_idx]).float()
44 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, random_idx]).float().unsqueeze(-1)
45 | elif self.stage == 'val':
46 | data['pts_d'] = torch.tensor(dataset['points'][0, ::20]).float()
47 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, ::20]).float().unsqueeze(-1)
48 | elif self.stage == 'test':
49 | data['pts_d'] = torch.tensor(dataset['points'][0, :]).float()
50 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, :]).float().unsqueeze(-1)
51 |
52 | data['smpl_verts'] = torch.tensor(dataset['vertices'])
53 | data['smpl_tfs'] = torch.tensor(dataset['bone_transforms']).inverse()
54 | data['smpl_jnts'] = torch.tensor(dataset['joints'])
55 | data['smpl_thetas'] = torch.tensor(dataset['pose'])
56 | data['smpl_betas'] = self.betas
57 |
58 | return data
59 |
60 | def __len__(self):
61 | return len(self.frame_names)
62 |
63 | class DFaustDataModule(pl.LightningDataModule):
64 |
65 | def __init__(self, opt, **kwargs):
66 | super().__init__()
67 | self.opt = opt
68 |
69 | def setup(self, stage=None):
70 |
71 | if stage == 'fit':
72 | self.dataset_train = DFaustDataset(dataset_path=self.opt.dataset_path,
73 | stage = 'train',
74 | subject=self.opt.subject,
75 | points_per_frame=self.opt.points_per_frame,
76 | )
77 |
78 | self.dataset_val = DFaustDataset(dataset_path=self.opt.dataset_path,
79 | stage = 'val',
80 | subject=self.opt.subject,
81 | )
82 |
83 | self.meta_info = self.dataset_train.meta_info
84 |
85 | elif stage == 'test':
86 | self.dataset_test = DFaustDataset(dataset_path=self.opt.dataset_path,
87 | stage = 'test',
88 | subject=self.opt.subject,
89 | )
90 |
91 | self.meta_info = self.dataset_test.meta_info
92 |
93 | def train_dataloader(self):
94 | dataloader = DataLoader(self.dataset_train,
95 | batch_size=self.opt.batch_size,
96 | num_workers=self.opt.num_workers,
97 | shuffle=True,
98 | drop_last=True,
99 | pin_memory=True)
100 | return dataloader
101 |
102 | def val_dataloader(self):
103 | dataloader = DataLoader(self.dataset_val,
104 | batch_size=1,
105 | num_workers=self.opt.num_workers,
106 | shuffle=False,
107 | drop_last=False,
108 | pin_memory=True)
109 | return dataloader
110 |
111 | def test_dataloader(self):
112 | dataloader = DataLoader(self.dataset_test,
113 | batch_size=1,
114 | num_workers=self.opt.num_workers,
115 | shuffle=False,
116 | drop_last=False,
117 | pin_memory=True)
118 | return dataloader
119 |
--------------------------------------------------------------------------------
/lib/dataset/dfaust_split.yml:
--------------------------------------------------------------------------------
1 | '50002':
2 | test:
3 | - chicken_wings
4 | train:
5 | - hips
6 | - jiggle_on_toes
7 | - jumping_jacks
8 | - knees
9 | - light_hopping_loose
10 | - light_hopping_stiff
11 | - one_leg_jump
12 | - one_leg_loose
13 | - punching
14 | val:
15 | - chicken_wings
16 | '50004':
17 | test:
18 | - hips
19 | train:
20 | - chicken_wings
21 | - jiggle_on_toes
22 | - jumping_jacks
23 | - knees
24 | - light_hopping_loose
25 | - light_hopping_stiff
26 | - one_leg_jump
27 | - one_leg_loose
28 | - punching
29 | val:
30 | - hips
31 | '50007':
32 | test:
33 | - jiggle_on_toes
34 | train:
35 | - chicken_wings
36 | - jumping_jacks
37 | - knees
38 | - light_hopping_loose
39 | - light_hopping_stiff
40 | - one_leg_jump
41 | - one_leg_loose
42 | - punching
43 | - running_on_spot
44 | val:
45 | - jiggle_on_toes
46 | '50009':
47 | test:
48 | - jumping_jacks
49 | train:
50 | - chicken_wings
51 | - hips
52 | - jiggle_on_toes
53 | - light_hopping_loose
54 | - light_hopping_stiff
55 | - one_leg_jump
56 | - one_leg_loose
57 | - punching
58 | - running_on_spot
59 | val:
60 | - jumping_jacks
61 | '50020':
62 | test:
63 | - knees
64 | train:
65 | - chicken_wings
66 | - hips
67 | - jiggle_on_toes
68 | - light_hopping_loose
69 | - light_hopping_stiff
70 | - one_leg_jump
71 | - one_leg_loose
72 | - punching
73 | - running_on_spot
74 | val:
75 | - knees
76 | '50021':
77 | test:
78 | - light_hopping_stiff
79 | train:
80 | - chicken_wings
81 | - hips
82 | - knees
83 | - one_leg_jump
84 | - one_leg_loose
85 | - punching
86 | - running_on_spot
87 | - shake_arms
88 | - shake_hips
89 | val:
90 | - light_hopping_stiff
91 | '50022':
92 | test:
93 | - light_hopping_loose
94 | train:
95 | - hips
96 | - jiggle_on_toes
97 | - jumping_jacks
98 | - knees
99 | - light_hopping_stiff
100 | - one_leg_jump
101 | - one_leg_loose
102 | - punching
103 | - running_on_spot
104 | val:
105 | - light_hopping_loose
106 | '50025':
107 | test:
108 | - one_leg_jump
109 | train:
110 | - chicken_wings
111 | - hips
112 | - jiggle_on_toes
113 | - knees
114 | - light_hopping_loose
115 | - light_hopping_stiff
116 | - one_leg_loose
117 | - punching
118 | - running_on_spot
119 | val:
120 | - one_leg_jump
121 | '50026':
122 | test:
123 | - one_leg_loose
124 | train:
125 | - chicken_wings
126 | - hips
127 | - jiggle_on_toes
128 | - jumping_jacks
129 | - knees
130 | - light_hopping_loose
131 | - light_hopping_stiff
132 | - one_leg_jump
133 | - punching
134 | val:
135 | - one_leg_loose
136 | '50027':
137 | test:
138 | - punching
139 | train:
140 | - hips
141 | - jiggle_on_toes
142 | - jumping_jacks
143 | - knees
144 | - light_hopping_loose
145 | - light_hopping_stiff
146 | - one_leg_jump
147 | - one_leg_loose
148 | - running_on_spot
149 | val:
150 | - punching
--------------------------------------------------------------------------------
/lib/dataset/jointlim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 |
5 | from torch.utils.data import DataLoader, Dataset
6 | import numpy as np
7 | import os
8 | import glob
9 | import hydra
10 |
11 | class JointLimDataSet(Dataset):
12 | def __init__(self, dataset_path, subject):
13 |
14 | dataset_path = hydra.utils.to_absolute_path(dataset_path)
15 |
16 | seqs = ['op2', 'op3', 'op4', 'op5', 'op7', 'op8', 'op9']
17 |
18 | shape = np.load(os.path.join(dataset_path, 'shapes', '%d_shape.npz'%subject))
19 | self.betas = torch.tensor(shape['betas'][:10])
20 | self.gender = str(shape['gender'])
21 | self.meta_info = {'betas': shape['betas'][:10], 'gender': self.gender}
22 |
23 | self.frame_names = []
24 |
25 | for seq in seqs:
26 | self.frame_names += sorted(glob.glob(os.path.join(dataset_path, 'points', str(subject), seq + '_poses'+'*.npz')))
27 |
28 |
29 | def __getitem__(self, index):
30 |
31 | name = self.frame_names[index]
32 | data = {}
33 |
34 | dataset = np.load(name)
35 |
36 | data['pts_d'] = torch.tensor(dataset['points'][0]).float()
37 | data['occ_gt'] = torch.tensor(dataset['occupancies'][0, :]).float().unsqueeze(-1)
38 |
39 | data['smpl_verts'] = torch.tensor(dataset['vertices'])
40 | data['smpl_tfs'] = torch.tensor(dataset['bone_transforms']).inverse()
41 | data['smpl_jnts'] = torch.tensor(dataset['joints'])
42 | data['smpl_thetas'] = torch.tensor(dataset['pose'])
43 | data['smpl_betas'] = self.betas
44 |
45 | return data
46 |
47 | def __len__(self):
48 | return len(self.frame_names)
49 |
50 |
51 | class JointLimDataModule(pl.LightningDataModule):
52 |
53 | def __init__(self, opt, **kwargs):
54 | super().__init__()
55 | self.opt = opt
56 |
57 | def setup(self, stage=None):
58 |
59 | if stage == 'test':
60 | self.dataset_test = JointLimDataSet(dataset_path=self.opt.dataset_path,
61 | subject=self.opt.subject,
62 | )
63 |
64 | self.meta_info = self.dataset_test.meta_info
65 |
66 | def test_dataloader(self):
67 | dataloader = DataLoader(self.dataset_test,
68 | batch_size=1,
69 | num_workers=self.opt.num_workers,
70 | shuffle=False,
71 | drop_last=False,
72 | pin_memory=True)
73 | return dataloader
74 |
--------------------------------------------------------------------------------
/lib/libmise/mise.pyx:
--------------------------------------------------------------------------------
1 | # distutils: language = c++
2 | cimport cython
3 | from cython.operator cimport dereference as dref
4 | from libcpp.vector cimport vector
5 | from libcpp.map cimport map
6 | from libc.math cimport isnan, NAN
7 | import numpy as np
8 |
9 |
10 | cdef struct Vector3D:
11 | int x, y, z
12 |
13 |
14 | cdef struct Voxel:
15 | Vector3D loc
16 | unsigned int level
17 | bint is_leaf
18 | unsigned long children[2][2][2]
19 |
20 |
21 | cdef struct GridPoint:
22 | Vector3D loc
23 | double value
24 | bint known
25 |
26 |
27 | cdef inline unsigned long vec_to_idx(Vector3D coord, long resolution):
28 | cdef unsigned long idx
29 | idx = resolution * resolution * coord.x + resolution * coord.y + coord.z
30 | return idx
31 |
32 |
33 | cdef class MISE:
34 | cdef vector[Voxel] voxels
35 | cdef vector[GridPoint] grid_points
36 | cdef map[long, long] grid_point_hash
37 | cdef readonly int resolution_0
38 | cdef readonly int depth
39 | cdef readonly double threshold
40 | cdef readonly int voxel_size_0
41 | cdef readonly int resolution
42 |
43 | def __cinit__(self, int resolution_0, int depth, double threshold):
44 | self.resolution_0 = resolution_0
45 | self.depth = depth
46 | self.threshold = threshold
47 | self.voxel_size_0 = (1 << depth)
48 | self.resolution = resolution_0 * self.voxel_size_0
49 |
50 | # Create initial voxels
51 | self.voxels.reserve(resolution_0 * resolution_0 * resolution_0)
52 |
53 | cdef Voxel voxel
54 | cdef GridPoint point
55 | cdef Vector3D loc
56 | cdef int i, j, k
57 | for i in range(resolution_0):
58 | for j in range(resolution_0):
59 | for k in range (resolution_0):
60 | loc = Vector3D(
61 | i * self.voxel_size_0,
62 | j * self.voxel_size_0,
63 | k * self.voxel_size_0,
64 | )
65 | voxel = Voxel(
66 | loc=loc,
67 | level=0,
68 | is_leaf=True,
69 | )
70 |
71 | assert(self.voxels.size() == vec_to_idx(Vector3D(i, j, k), resolution_0))
72 | self.voxels.push_back(voxel)
73 |
74 | # Create initial grid points
75 | self.grid_points.reserve((resolution_0 + 1) * (resolution_0 + 1) * (resolution_0 + 1))
76 | for i in range(resolution_0 + 1):
77 | for j in range(resolution_0 + 1):
78 | for k in range(resolution_0 + 1):
79 | loc = Vector3D(
80 | i * self.voxel_size_0,
81 | j * self.voxel_size_0,
82 | k * self.voxel_size_0,
83 | )
84 | assert(self.grid_points.size() == vec_to_idx(Vector3D(i, j, k), resolution_0 + 1))
85 | self.add_grid_point(loc)
86 |
87 | def update(self, long[:, :] points, double[:] values):
88 | """Update points and set their values. Also determine all active voxels and subdivide them."""
89 | assert(points.shape[0] == values.shape[0])
90 | assert(points.shape[1] == 3)
91 | cdef Vector3D loc
92 | cdef long idx
93 | cdef int i
94 |
95 | # Find all indices of point and set value
96 | for i in range(points.shape[0]):
97 | loc = Vector3D(points[i, 0], points[i, 1], points[i, 2])
98 | idx = self.get_grid_point_idx(loc)
99 | if idx == -1:
100 | raise ValueError('Point not in grid!')
101 | self.grid_points[idx].value = values[i]
102 | self.grid_points[idx].known = True
103 |         # Subdivide active voxels and add new points
104 | self.subdivide_voxels()
105 |
106 | def query(self):
107 | """Query points to evaluate."""
108 | # Find all points with unknown value
109 | cdef vector[Vector3D] points
110 | cdef int n_unknown = 0
111 | for p in self.grid_points:
112 | if not p.known:
113 | n_unknown += 1
114 |
115 | points.reserve(n_unknown)
116 | for p in self.grid_points:
117 | if not p.known:
118 | points.push_back(p.loc)
119 |
120 | # Convert to numpy
121 | points_np = np.zeros((points.size(), 3), dtype=np.int64)
122 | cdef long[:, :] points_view = points_np
123 | for i in range(points.size()):
124 | points_view[i, 0] = points[i].x
125 | points_view[i, 1] = points[i].y
126 | points_view[i, 2] = points[i].z
127 |
128 | return points_np
129 |
130 | def to_dense(self):
131 | """Output dense matrix at highest resolution."""
132 | out_array = np.full((self.resolution + 1,) * 3, np.nan)
133 | cdef double[:, :, :] out_view = out_array
134 | cdef GridPoint point
135 | cdef int i, j, k
136 |
137 | for point in self.grid_points:
138 |             # Take voxel for which this point is the upper-left corner
139 | # assert(point.known)
140 | out_view[point.loc.x, point.loc.y, point.loc.z] = point.value
141 |
142 | # Complete along x axis
143 | for i in range(1, self.resolution + 1):
144 | for j in range(self.resolution + 1):
145 | for k in range(self.resolution + 1):
146 | if isnan(out_view[i, j, k]):
147 | out_view[i, j, k] = out_view[i-1, j, k]
148 |
149 | # Complete along y axis
150 | for i in range(self.resolution + 1):
151 | for j in range(1, self.resolution + 1):
152 | for k in range(self.resolution + 1):
153 | if isnan(out_view[i, j, k]):
154 | out_view[i, j, k] = out_view[i, j-1, k]
155 |
156 |
157 | # Complete along z axis
158 | for i in range(self.resolution + 1):
159 | for j in range(self.resolution + 1):
160 | for k in range(1, self.resolution + 1):
161 | if isnan(out_view[i, j, k]):
162 | out_view[i, j, k] = out_view[i, j, k-1]
163 | assert(not isnan(out_view[i, j, k]))
164 | return out_array
165 |
166 | def get_points(self):
167 | points_np = np.zeros((self.grid_points.size(), 3), dtype=np.int64)
168 | values_np = np.zeros((self.grid_points.size()), dtype=np.float64)
169 |
170 | cdef long[:, :] points_view = points_np
171 | cdef double[:] values_view = values_np
172 | cdef Vector3D loc
173 | cdef int i
174 |
175 | for i in range(self.grid_points.size()):
176 | loc = self.grid_points[i].loc
177 | points_view[i, 0] = loc.x
178 | points_view[i, 1] = loc.y
179 | points_view[i, 2] = loc.z
180 | values_view[i] = self.grid_points[i].value
181 |
182 | return points_np, values_np
183 |
184 | cdef void subdivide_voxels(self) except +:
185 | cdef vector[bint] next_to_positive
186 | cdef vector[bint] next_to_negative
187 | cdef int i, j, k
188 | cdef long idx
189 | cdef Vector3D loc, adj_loc
190 |
191 | # Initialize vectors
192 | next_to_positive.resize(self.voxels.size(), False)
193 | next_to_negative.resize(self.voxels.size(), False)
194 |
195 | # Iterate over grid points and mark voxels active
196 |         # TODO: can move this to update operation and add attribute to voxel
197 | for grid_point in self.grid_points:
198 | loc = grid_point.loc
199 | if not grid_point.known:
200 | continue
201 |
202 | # Iterate over the 8 adjacent voxels
203 | for i in range(-1, 1):
204 | for j in range(-1, 1):
205 | for k in range(-1, 1):
206 | adj_loc = Vector3D(
207 | x=loc.x + i,
208 | y=loc.y + j,
209 | z=loc.z + k,
210 | )
211 | idx = self.get_voxel_idx(adj_loc)
212 | if idx == -1:
213 | continue
214 |
215 | if grid_point.value >= self.threshold:
216 | next_to_positive[idx] = True
217 | if grid_point.value <= self.threshold:
218 | next_to_negative[idx] = True
219 |
220 | cdef int n_subdivide = 0
221 |
222 | for idx in range(self.voxels.size()):
223 | if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth:
224 | continue
225 | if next_to_positive[idx] and next_to_negative[idx]:
226 | n_subdivide += 1
227 |
228 | self.voxels.reserve(self.voxels.size() + 8 * n_subdivide)
229 | self.grid_points.reserve(self.voxels.size() + 19 * n_subdivide)
230 |
231 | for idx in range(self.voxels.size()):
232 | if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth:
233 | continue
234 | if next_to_positive[idx] and next_to_negative[idx]:
235 | self.subdivide_voxel(idx)
236 |
237 | cdef void subdivide_voxel(self, long idx):
238 | cdef Voxel voxel
239 | cdef GridPoint point
240 | cdef Vector3D loc0 = self.voxels[idx].loc
241 | cdef Vector3D loc
242 | cdef int new_level = self.voxels[idx].level + 1
243 | cdef int new_size = 1 << (self.depth - new_level)
244 | assert(new_level <= self.depth)
245 | assert(1 <= new_size <= self.voxel_size_0)
246 |
247 | # Current voxel is not leaf anymore
248 | self.voxels[idx].is_leaf = False
249 | # Add new voxels
250 | cdef int i, j, k
251 | for i in range(2):
252 | for j in range(2):
253 | for k in range(2):
254 | loc = Vector3D(
255 | x=loc0.x + i * new_size,
256 | y=loc0.y + j * new_size,
257 | z=loc0.z + k * new_size,
258 | )
259 | voxel = Voxel(
260 | loc=loc,
261 | level=new_level,
262 | is_leaf=True
263 | )
264 |
265 | self.voxels[idx].children[i][j][k] = self.voxels.size()
266 | self.voxels.push_back(voxel)
267 |
268 | # Add new grid points
269 | for i in range(3):
270 | for j in range(3):
271 | for k in range(3):
272 | loc = Vector3D(
273 | loc0.x + i * new_size,
274 | loc0.y + j * new_size,
275 | loc0.z + k * new_size,
276 | )
277 |
278 | # Only add new grid points
279 | if self.get_grid_point_idx(loc) == -1:
280 | self.add_grid_point(loc)
281 |
282 |
283 | @cython.cdivision(True)
284 | cdef long get_voxel_idx(self, Vector3D loc) except +:
285 | """Utility function for getting voxel index corresponding to 3D coordinates."""
286 | # Shorthands
287 | cdef long resolution = self.resolution
288 | cdef long resolution_0 = self.resolution_0
289 | cdef long depth = self.depth
290 | cdef long voxel_size_0 = self.voxel_size_0
291 |
292 | # Return -1 if point lies outside bounds
293 | if not (0 <= loc.x < resolution and 0<= loc.y < resolution and 0 <= loc.z < resolution):
294 | return -1
295 |
296 | # Coordinates in coarse voxel grid
297 | cdef Vector3D loc0 = Vector3D(
298 | x=loc.x >> depth,
299 | y=loc.y >> depth,
300 | z=loc.z >> depth,
301 | )
302 |
303 | # Initial voxels
304 | cdef int idx = vec_to_idx(loc0, resolution_0)
305 | cdef Voxel voxel = self.voxels[idx]
306 | assert(voxel.loc.x == loc0.x * voxel_size_0)
307 | assert(voxel.loc.y == loc0.y * voxel_size_0)
308 | assert(voxel.loc.z == loc0.z * voxel_size_0)
309 |
310 | # Relative coordinates
311 | cdef Vector3D loc_rel = Vector3D(
312 | x=loc.x - (loc0.x << depth),
313 | y=loc.y - (loc0.y << depth),
314 | z=loc.z - (loc0.z << depth),
315 | )
316 |
317 | cdef Vector3D loc_offset
318 | cdef long voxel_size = voxel_size_0
319 |
320 | while not voxel.is_leaf:
321 | voxel_size = voxel_size >> 1
322 | assert(voxel_size >= 1)
323 |
324 | # Determine child
325 | loc_offset = Vector3D(
326 | x=1 if (loc_rel.x >= voxel_size) else 0,
327 | y=1 if (loc_rel.y >= voxel_size) else 0,
328 | z=1 if (loc_rel.z >= voxel_size) else 0,
329 | )
330 | # New voxel
331 | idx = voxel.children[loc_offset.x][loc_offset.y][loc_offset.z]
332 | voxel = self.voxels[idx]
333 |
334 | # New relative coordinates
335 | loc_rel = Vector3D(
336 | x=loc_rel.x - loc_offset.x * voxel_size,
337 | y=loc_rel.y - loc_offset.y * voxel_size,
338 | z=loc_rel.z - loc_offset.z * voxel_size,
339 | )
340 |
341 | assert(0<= loc_rel.x < voxel_size)
342 | assert(0<= loc_rel.y < voxel_size)
343 | assert(0<= loc_rel.z < voxel_size)
344 |
345 |
346 | # Return idx
347 | return idx
348 |
349 |
350 | cdef inline void add_grid_point(self, Vector3D loc):
351 | cdef GridPoint point = GridPoint(
352 | loc=loc,
353 | value=0.,
354 | known=False,
355 | )
356 | self.grid_point_hash[vec_to_idx(loc, self.resolution + 1)] = self.grid_points.size()
357 | self.grid_points.push_back(point)
358 |
359 | cdef inline int get_grid_point_idx(self, Vector3D loc):
360 | p_idx = self.grid_point_hash.find(vec_to_idx(loc, self.resolution + 1))
361 | if p_idx == self.grid_point_hash.end():
362 | return -1
363 |
364 | cdef int idx = dref(p_idx).second
365 | assert(self.grid_points[idx].loc.x == loc.x)
366 | assert(self.grid_points[idx].loc.y == loc.y)
367 | assert(self.grid_points[idx].loc.z == loc.z)
368 |
369 | return idx
--------------------------------------------------------------------------------
/lib/model/broyden.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def broyden(g, x_init, J_inv_init, max_steps=50, cvg_thresh=1e-5, dvg_thresh=1, eps=1e-6):
5 | """Find roots of the given function g(x) = 0.
6 |     This function is implemented based on https://github.com/locuslab/deq.
7 |
8 | Tensor shape abbreviation:
9 | N: number of points
10 | D: space dimension
11 | Args:
12 | g (function): the function of which the roots are to be determined. shape: [N, D, 1]->[N, D, 1]
13 | x_init (tensor): initial value of the parameters. shape: [N, D, 1]
14 | J_inv_init (tensor): initial value of the inverse Jacobians. shape: [N, D, D]
15 |
16 | max_steps (int, optional): max number of iterations. Defaults to 50.
17 |         cvg_thresh (float, optional): convergence threshold. Defaults to 1e-5.
18 | dvg_thresh (float, optional): divergence threshold. Defaults to 1.
19 | eps (float, optional): a small number added to the denominator to prevent numerical error. Defaults to 1e-6.
20 |
21 | Returns:
22 | result (tensor): root of the given function. shape: [N, D, 1]
23 | diff (tensor): corresponding loss. [N]
24 | valid_ids (tensor): identifiers of converged points. [N]
25 | """
26 |
27 | # initialization
28 | x = x_init.clone().detach()
29 | J_inv = J_inv_init.clone().detach()
30 |
31 | ids_val = torch.ones(x.shape[0]).bool()
32 |
33 | gx = g(x, mask=ids_val)
34 | update = -J_inv.bmm(gx)
35 |
36 | x_opt = x
37 | gx_norm_opt = torch.linalg.norm(gx.squeeze(-1), dim=-1)
38 |
39 | delta_gx = torch.zeros_like(gx)
40 | delta_x = torch.zeros_like(x)
41 |
42 | ids_val = torch.ones_like(gx_norm_opt).bool()
43 |
44 | for _ in range(max_steps):
45 |
46 |         # update parameter values
47 | delta_x[ids_val] = update
48 | x[ids_val] += delta_x[ids_val]
49 | delta_gx[ids_val] = g(x, mask=ids_val) - gx[ids_val]
50 | gx[ids_val] += delta_gx[ids_val]
51 |
52 |         # store values with minimal loss
53 | gx_norm = torch.linalg.norm(gx.squeeze(-1), dim=-1)
54 | ids_opt = gx_norm < gx_norm_opt
55 | gx_norm_opt[ids_opt] = gx_norm.clone().detach()[ids_opt]
56 | x_opt[ids_opt] = x.clone().detach()[ids_opt]
57 |
58 |         # exclude converged and diverged points from future iterations
59 | ids_val = (gx_norm_opt > cvg_thresh) & (gx_norm < dvg_thresh)
60 | if ids_val.sum() <= 0:
61 | break
62 |
63 |         # compute parameter update for the next iteration
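        # Broyden rank-1 update of the inverse Jacobian:
        #   J_inv <- J_inv + (dx - J_inv @ dg) @ (dx^T @ J_inv) / (dx^T @ J_inv @ dg)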
64 | vT = (delta_x[ids_val]).transpose(-1, -2).bmm(J_inv[ids_val])
65 | a = delta_x[ids_val] - J_inv[ids_val].bmm(delta_gx[ids_val])
66 | b = vT.bmm(delta_gx[ids_val])
67 | b[b >= 0] += eps
68 | b[b < 0] -= eps
69 | u = a / b
70 | J_inv[ids_val] += u.bmm(vT)
71 | update = -J_inv[ids_val].bmm(gx[ids_val])
72 |
73 | return {'result': x_opt, 'diff': gx_norm_opt, 'valid_ids': gx_norm_opt < cvg_thresh}
74 |
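# Typical call pattern (cf. ForwardDeformer.search in lib/model/deformer.py); shapes as in the docstring:
#   g = lambda x, mask=None: forward_skinning(x, ...) - x_target   # residual g(x) of shape [N, D, 1]
#   res = broyden(g, x_init, J_inv_init)
#   x_root, converged = res['result'], res['valid_ids']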
--------------------------------------------------------------------------------
/lib/model/deformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import einsum
3 | import torch.nn.functional as F
4 |
5 | from lib.model.broyden import broyden
6 | from lib.model.network import ImplicitNetwork
7 | from lib.model.helpers import hierarchical_softmax
8 |
9 |
10 | class ForwardDeformer(torch.nn.Module):
11 | """
12 | Tensor shape abbreviation:
13 | B: batch size
14 | N: number of points
15 | J: number of bones
16 | I: number of init
17 | D: space dimension
18 | """
19 |
20 | def __init__(self, opt, **kwargs):
21 | super().__init__()
22 |
23 | self.opt = opt
24 |
25 | self.lbs_network = ImplicitNetwork(**self.opt.network)
26 |
27 | self.soft_blend = 20
28 |
29 | self.init_bones = [0, 1, 2, 4, 5, 16, 17, 18, 19]
30 |
31 | def forward(self, xd, cond, tfs, eval_mode=False):
32 |         """Given a deformed point, return its canonical correspondence
33 |
34 | Args:
35 | xd (tensor): deformed points in batch. shape: [B, N, D]
36 | cond (dict): conditional input.
37 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
38 |
39 | Returns:
40 | xc (tensor): canonical correspondences. shape: [B, N, I, D]
41 | others (dict): other useful outputs.
42 | """
43 | xc_init = self.init(xd, tfs)
44 |
45 | xc_opt, others = self.search(xd, xc_init, cond, tfs, eval_mode=eval_mode)
46 |
47 | if eval_mode:
48 | return xc_opt, others
49 |
50 | # compute correction term for implicit differentiation during training
51 |
52 | # do not back-prop through broyden
53 | xc_opt = xc_opt.detach()
54 |
55 | # reshape to [B,?,D] for network query
56 | n_batch, n_point, n_init, n_dim = xc_init.shape
57 | xc_opt = xc_opt.reshape((n_batch, n_point * n_init, n_dim))
58 |
59 | xd_opt = self.forward_skinning(xc_opt, cond, tfs)
60 |
61 | grad_inv = self.gradient(xc_opt, cond, tfs).inverse()
62 |
63 | correction = xd_opt - xd_opt.detach()
64 | correction = einsum("bnij,bnj->bni", -grad_inv.detach(), correction)
65 |
66 | # trick for implicit diff with autodiff:
67 | # xc = xc_opt + 0 and xc' = correction'
68 | xc = xc_opt + correction
69 |
70 | # reshape back to [B,N,I,D]
71 | xc = xc.reshape(xc_init.shape)
72 |
73 | return xc, others
74 |
75 | def init(self, xd, tfs):
76 | """Transform xd to canonical space for initialization
77 |
78 | Args:
79 | xd (tensor): deformed points in batch. shape: [B, N, D]
80 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
81 |
82 | Returns:
83 |             xc_init (tensor): initial canonical points. shape: [B, N, I, D]
84 | """
85 | n_batch, n_point, _ = xd.shape
86 | _, n_joint, _, _ = tfs.shape
87 |
88 | xc_init = []
89 | for i in self.init_bones:
90 | w = torch.zeros((n_batch, n_point, n_joint), device=xd.device)
91 | w[:, :, i] = 1
92 | xc_init.append(skinning(xd, w, tfs, inverse=True))
93 |
94 | xc_init = torch.stack(xc_init, dim=2)
95 |
96 | return xc_init
97 |
98 | def search(self, xd, xc_init, cond, tfs, eval_mode=False):
99 | """Search correspondences.
100 |
101 | Args:
102 | xd (tensor): deformed points in batch. shape: [B, N, D]
103 |             xc_init (tensor): initial canonical points. shape: [B, N, I, D]
104 | cond (dict): conditional input.
105 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
106 |
107 | Returns:
108 |             xc_opt (tensor): canonical correspondences of xd. shape: [B, N, I, D]
109 | valid_ids (tensor): identifiers of converged points. [B, N, I]
110 | """
111 | # reshape to [B,?,D] for other functions
112 | n_batch, n_point, n_init, n_dim = xc_init.shape
113 | xc_init = xc_init.reshape(n_batch, n_point * n_init, n_dim)
114 | xd_tgt = xd.repeat_interleave(n_init, dim=1)
115 |
116 | # compute init jacobians
117 | if not eval_mode:
118 | J_inv_init = self.gradient(xc_init, cond, tfs).inverse()
119 | else:
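            # in eval mode, approximate the Jacobian of forward skinning by the
            # linearly blended bone rotations (no autograd pass needed)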
120 | w = self.query_weights(xc_init, cond, mask=None)
121 | J_inv_init = einsum("bpn,bnij->bpij", w, tfs)[:, :, :3, :3].inverse()
122 |
123 |         # reshape init to [?,D,...] for broyden
124 | xc_init = xc_init.reshape(-1, n_dim, 1)
125 | J_inv_init = J_inv_init.flatten(0, 1)
126 |
127 | # construct function for root finding
128 | def _func(xc_opt, mask=None):
129 | # reshape to [B,?,D] for other functions
130 | xc_opt = xc_opt.reshape(n_batch, n_point * n_init, n_dim)
131 | xd_opt = self.forward_skinning(xc_opt, cond, tfs, mask=mask)
132 | error = xd_opt - xd_tgt
133 |             # reshape to [?,D,1] for broyden
134 | error = error.flatten(0, 1)[mask].unsqueeze(-1)
135 | return error
136 |
137 | # run broyden without grad
138 | with torch.no_grad():
139 | result = broyden(_func, xc_init, J_inv_init)
140 |
141 | # reshape back to [B,N,I,D]
142 | xc_opt = result["result"].reshape(n_batch, n_point, n_init, n_dim)
143 | result["valid_ids"] = result["valid_ids"].reshape(n_batch, n_point, n_init)
144 |
145 | return xc_opt, result
146 |
147 | def forward_skinning(self, xc, cond, tfs, mask=None):
148 | """Canonical point -> deformed point
149 |
150 | Args:
151 |             xc (tensor): canonical points in batch. shape: [B, N, D]
152 | cond (dict): conditional input.
153 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
154 |
155 | Returns:
156 | xd (tensor): deformed point. shape: [B, N, D]
157 | """
158 | w = self.query_weights(xc, cond, mask=mask)
159 | xd = skinning(xc, w, tfs, inverse=False)
160 | return xd
161 |
162 | def query_weights(self, xc, cond, mask=None):
163 | """Get skinning weights in canonical space
164 |
165 | Args:
166 | xc (tensor): canonical points. shape: [B, N, D]
167 | cond (dict): conditional input.
168 | mask (tensor, optional): valid indices. shape: [B, N]
169 |
170 | Returns:
171 | w (tensor): skinning weights. shape: [B, N, J]
172 | """
173 |
174 | w = self.lbs_network(xc, cond, mask)
175 | w = self.soft_blend * w
176 |
177 | if self.opt.softmax_mode == "hierarchical":
178 | w = hierarchical_softmax(w)
179 | else:
180 | w = F.softmax(w, dim=-1)
181 |
182 | return w
183 |
184 | def gradient(self, xc, cond, tfs):
185 | """Get gradients df/dx
186 |
187 | Args:
188 | xc (tensor): canonical points. shape: [B, N, D]
189 | cond (dict): conditional input.
190 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
191 |
192 | Returns:
193 | grad (tensor): gradients. shape: [B, N, D, D]
194 | """
195 | xc.requires_grad_(True)
196 |
197 | xd = self.forward_skinning(xc, cond, tfs)
198 |
199 | grads = []
200 | for i in range(xd.shape[-1]):
201 | d_out = torch.zeros_like(xd, requires_grad=False, device=xd.device)
202 | d_out[:, :, i] = 1
203 | grad = torch.autograd.grad(
204 | outputs=xd,
205 | inputs=xc,
206 | grad_outputs=d_out,
207 | create_graph=False,
208 | retain_graph=True,
209 | only_inputs=True,
210 | )[0]
211 | grads.append(grad)
212 |
213 | return torch.stack(grads, dim=-2)
214 |
215 |
216 | def skinning(x, w, tfs, inverse=False):
217 | """Linear blend skinning
218 |
219 | Args:
220 | x (tensor): canonical points. shape: [B, N, D]
221 |         w (tensor): skinning weights. shape: [B, N, J]
222 | tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
223 | Returns:
224 | x (tensor): skinned points. shape: [B, N, D]
225 | """
226 | x_h = F.pad(x, (0, 1), value=1.0)
227 |
228 | if inverse:
229 | # p:n_point, n:n_bone, i,k: n_dim+1
230 | w_tf = einsum("bpn,bnij->bpij", w, tfs)
231 | x_h = einsum("bpij,bpj->bpi", w_tf.inverse(), x_h)
232 | else:
233 | x_h = einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)
234 |
235 | return x_h[:, :, :3]
236 |
--------------------------------------------------------------------------------
/lib/model/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import cv2
4 | import numpy as np
5 |
6 | def masked_softmax(vec, mask, dim=-1, mode='softmax', soft_blend=1):
7 | if mode == 'softmax':
8 |
9 | vec = torch.distributions.Bernoulli(logits=vec).probs
10 |
11 | masked_exps = torch.exp(soft_blend*vec) * mask.float()
12 | masked_exps_sum = masked_exps.sum(dim)
13 |
14 | output = torch.zeros_like(vec)
15 | output[masked_exps_sum>0,:] = masked_exps[masked_exps_sum>0,:]/ masked_exps_sum[masked_exps_sum>0].unsqueeze(-1)
16 |
17 | output = (output * vec).sum(dim, keepdim=True)
18 |
19 | output = torch.distributions.Bernoulli(probs=output).logits
20 |
21 | elif mode == 'max':
22 | vec[~mask] = -math.inf
23 | output = torch.max(vec, dim, keepdim=True)[0]
24 |
25 | return output
26 |
27 |
28 | ''' Hierarchical softmax following the kinematic tree of the human body. Improves convergence speed. '''
29 | def hierarchical_softmax(x):
30 | def softmax(x):
31 | return torch.nn.functional.softmax(x, dim=-1)
32 |
33 | def sigmoid(x):
34 | return torch.sigmoid(x)
35 |
36 | n_batch, n_point, n_dim = x.shape
37 | x = x.flatten(0,1)
38 |
39 | prob_all = torch.ones(n_batch * n_point, 24, device=x.device)
40 |
41 | prob_all[:, [1, 2, 3]] = prob_all[:, [0]] * sigmoid(x[:, [0]]) * softmax(x[:, [1, 2, 3]])
42 | prob_all[:, [0]] = prob_all[:, [0]] * (1 - sigmoid(x[:, [0]]))
43 |
44 | prob_all[:, [4, 5, 6]] = prob_all[:, [1, 2, 3]] * (sigmoid(x[:, [4, 5, 6]]))
45 | prob_all[:, [1, 2, 3]] = prob_all[:, [1, 2, 3]] * (1 - sigmoid(x[:, [4, 5, 6]]))
46 |
47 | prob_all[:, [7, 8, 9]] = prob_all[:, [4, 5, 6]] * (sigmoid(x[:, [7, 8, 9]]))
48 | prob_all[:, [4, 5, 6]] = prob_all[:, [4, 5, 6]] * (1 - sigmoid(x[:, [7, 8, 9]]))
49 |
50 | prob_all[:, [10, 11]] = prob_all[:, [7, 8]] * (sigmoid(x[:, [10, 11]]))
51 | prob_all[:, [7, 8]] = prob_all[:, [7, 8]] * (1 - sigmoid(x[:, [10, 11]]))
52 |
53 | prob_all[:, [12, 13, 14]] = prob_all[:, [9]] * sigmoid(x[:, [24]]) * softmax(x[:, [12, 13, 14]])
54 | prob_all[:, [9]] = prob_all[:, [9]] * (1 - sigmoid(x[:, [24]]))
55 |
56 | prob_all[:, [15]] = prob_all[:, [12]] * (sigmoid(x[:, [15]]))
57 | prob_all[:, [12]] = prob_all[:, [12]] * (1 - sigmoid(x[:, [15]]))
58 |
59 | prob_all[:, [16, 17]] = prob_all[:, [13, 14]] * (sigmoid(x[:, [16, 17]]))
60 | prob_all[:, [13, 14]] = prob_all[:, [13, 14]] * (1 - sigmoid(x[:, [16, 17]]))
61 |
62 | prob_all[:, [18, 19]] = prob_all[:, [16, 17]] * (sigmoid(x[:, [18, 19]]))
63 | prob_all[:, [16, 17]] = prob_all[:, [16, 17]] * (1 - sigmoid(x[:, [18, 19]]))
64 |
65 | prob_all[:, [20, 21]] = prob_all[:, [18, 19]] * (sigmoid(x[:, [20, 21]]))
66 | prob_all[:, [18, 19]] = prob_all[:, [18, 19]] * (1 - sigmoid(x[:, [20, 21]]))
67 |
68 | prob_all[:, [22, 23]] = prob_all[:, [20, 21]] * (sigmoid(x[:, [22, 23]]))
69 | prob_all[:, [20, 21]] = prob_all[:, [20, 21]] * (1 - sigmoid(x[:, [22, 23]]))
70 |
71 | prob_all = prob_all.reshape(n_batch, n_point, prob_all.shape[-1])
72 | return prob_all
73 |
74 | def rectify_pose(pose, root_abs):
75 | """
76 | Rectify an AMASS pose in global coordinates. Adapted from https://github.com/akanazawa/hmr/issues/50.
77 |
78 | Args:
79 | pose (72,): pose parameters in axis-angle format; the first 3 entries are the global root rotation.
80 | root_abs (3,): absolute root rotation (axis-angle) to be removed from the pose.
81 | Returns:
82 | Rotated pose.
83 | """
84 | pose = pose.copy()
85 | R_abs = cv2.Rodrigues(root_abs)[0]
86 | R_root = cv2.Rodrigues(pose[:3])[0]
87 | new_root = np.linalg.inv(R_abs).dot(R_root)
88 | pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3)
89 | return pose
--------------------------------------------------------------------------------
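
A minimal sketch of hierarchical_softmax in isolation: the function indexes logit 24, so each per-point logit vector has 25 entries (the extra logit gates the split of joint 9's probability mass to joints 12, 13 and 14), and the output is a distribution over the 24 SMPL bones, so each row should sum to one.

    import torch
    from lib.model.helpers import hierarchical_softmax

    x = torch.randn(2, 100, 25)          # [B, N, 25] logits from the LBS network
    w = hierarchical_softmax(x)          # [B, N, 24] skinning weights
    print(w.shape, torch.allclose(w.sum(-1), torch.ones(2, 100), atol=1e-5))
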
/lib/model/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def calculate_iou(gt, prediction):
4 | intersection = torch.logical_and(gt, prediction)
5 | union = torch.logical_or(gt, prediction)
6 | return torch.sum(intersection) / torch.sum(union)
--------------------------------------------------------------------------------
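
A tiny usage example: calculate_iou expects two boolean (or 0/1) tensors of the same shape and returns the intersection-over-union as a scalar tensor.

    import torch
    from lib.model.metrics import calculate_iou

    gt   = torch.tensor([True, True, False, False])
    pred = torch.tensor([True, False, True, False])
    print(calculate_iou(gt, pred))   # 1 element in common / 3 in the union -> tensor(0.3333)
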
/lib/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | """ MLP for neural implicit shapes. The code is based on https://github.com/lioryariv/idr with adaption. """
5 | class ImplicitNetwork(torch.nn.Module):
6 | def __init__(
7 | self,
8 | d_in,
9 | d_out,
10 | width,
11 | depth,
12 | geometric_init=True,
13 | bias=1.0,
14 | weight_norm=True,
15 | multires=0,
16 | skip_layer=[],
17 | cond_layer=[],
18 | cond_dim=69,
19 | dim_cond_embed=-1,
20 | ):
21 | super().__init__()
22 |
23 | dims = [d_in] + [width] * depth + [d_out]
24 | self.num_layers = len(dims)
25 |
26 | self.embed_fn = None
27 | if multires > 0:
28 | embed_fn, input_ch = get_embedder(multires)
29 | self.embed_fn = embed_fn
30 | dims[0] = input_ch
31 |
32 | self.cond_layer = cond_layer
33 | self.cond_dim = cond_dim
34 |
35 | self.dim_cond_embed = dim_cond_embed
36 | if dim_cond_embed > 0:
37 | self.lin_p0 = torch.nn.Linear(self.cond_dim, dim_cond_embed)
38 | self.cond_dim = dim_cond_embed
39 |
40 | self.skip_layer = skip_layer
41 |
42 | for l in range(0, self.num_layers - 1):
43 | if l + 1 in self.skip_layer:
44 | out_dim = dims[l + 1] - dims[0]
45 | else:
46 | out_dim = dims[l + 1]
47 |
48 | if l in self.cond_layer:
49 | lin = torch.nn.Linear(dims[l] + self.cond_dim, out_dim)
50 | else:
51 | lin = torch.nn.Linear(dims[l], out_dim)
52 |
53 | if geometric_init:
54 | if l == self.num_layers - 2:
55 | torch.nn.init.normal_(
56 | lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001
57 | )
58 | torch.nn.init.constant_(lin.bias, bias)
59 | elif multires > 0 and l == 0:
60 | torch.nn.init.constant_(lin.bias, 0.0)
61 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
62 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
63 | elif multires > 0 and l in self.skip_layer:
64 | torch.nn.init.constant_(lin.bias, 0.0)
65 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
66 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
67 | else:
68 | torch.nn.init.constant_(lin.bias, 0.0)
69 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
70 |
71 | if weight_norm:
72 | lin = torch.nn.utils.weight_norm(lin)
73 |
74 | setattr(self, "lin" + str(l), lin)
75 |
76 | self.softplus = torch.nn.Softplus(beta=100)
77 |
78 | def forward(self, input, cond, mask=None):
79 | """MPL query.
80 |
81 | Tensor shape abbreviation:
82 | B: batch size
83 | N: number of points
84 | D: input dimension
85 |
86 | Args:
87 | input (tensor): network input. shape: [B, N, D]
88 | cond (dict): conditional input.
89 | mask (tensor, optional): only masked inputs are fed into the network. shape: [B, N]
90 |
91 | Returns:
92 | output (tensor): network output. May contain zero placeholders where mask == False. shape: [B, N, ?]
93 | """
94 |
95 |
96 | n_batch, n_point, n_dim = input.shape
97 |
98 | if n_batch * n_point == 0:
99 | return input
100 |
101 | # reshape to [N,?]
102 | input = input.reshape(n_batch * n_point, n_dim)
103 | if mask is not None:
104 | input = input[mask]
105 |
106 | input_embed = input if self.embed_fn is None else self.embed_fn(input)
107 |
108 | if len(self.cond_layer):
109 | cond = cond["smpl"]
110 | n_batch, n_cond = cond.shape
111 | input_cond = cond.unsqueeze(1).expand(n_batch, n_point, n_cond)
112 | input_cond = input_cond.reshape(n_batch * n_point, n_cond)
113 |
114 | if mask is not None:
115 | input_cond = input_cond[mask]
116 |
117 | if self.dim_cond_embed > 0:
118 | input_cond = self.lin_p0(input_cond)
119 |
120 | x = input_embed
121 |
122 | for l in range(0, self.num_layers - 1):
123 | lin = getattr(self, "lin" + str(l))
124 | if l in self.cond_layer:
125 | x = torch.cat([x, input_cond], dim=-1)
126 |
127 | if l in self.skip_layer:
128 | x = torch.cat([x, input_embed], 1) / np.sqrt(2)
129 |
130 | x = lin(x)
131 |
132 | if l < self.num_layers - 2:
133 | x = self.softplus(x)
134 |
135 | # add placeholder for masked prediction
136 | if mask is not None:
137 | x_full = torch.zeros(n_batch * n_point, x.shape[-1], device=x.device)
138 | x_full[mask] = x
139 | else:
140 | x_full = x
141 |
142 | return x_full.reshape(n_batch, n_point, -1)
143 |
144 |
145 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
146 | class Embedder:
147 | def __init__(self, **kwargs):
148 | self.kwargs = kwargs
149 | self.create_embedding_fn()
150 |
151 | def create_embedding_fn(self):
152 | embed_fns = []
153 | d = self.kwargs["input_dims"]
154 | out_dim = 0
155 | if self.kwargs["include_input"]:
156 | embed_fns.append(lambda x: x)
157 | out_dim += d
158 |
159 | max_freq = self.kwargs["max_freq_log2"]
160 | N_freqs = self.kwargs["num_freqs"]
161 |
162 | if self.kwargs["log_sampling"]:
163 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
164 | else:
165 | freq_bands = torch.linspace(2.0 ** 0.0, 2.0 ** max_freq, N_freqs)
166 |
167 | for freq in freq_bands:
168 | for p_fn in self.kwargs["periodic_fns"]:
169 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
170 | out_dim += d
171 |
172 | self.embed_fns = embed_fns
173 | self.out_dim = out_dim
174 |
175 | def embed(self, inputs):
176 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
177 |
178 |
179 | def get_embedder(multires):
180 | embed_kwargs = {
181 | "include_input": True,
182 | "input_dims": 3,
183 | "max_freq_log2": multires - 1,
184 | "num_freqs": multires,
185 | "log_sampling": True,
186 | "periodic_fns": [torch.sin, torch.cos],
187 | }
188 |
189 | embedder_obj = Embedder(**embed_kwargs)
190 |
191 | def embed(x, eo=embedder_obj):
192 | return eo.embed(x)
193 |
194 | return embed, embedder_obj.out_dim
195 |
--------------------------------------------------------------------------------
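
A hedged sketch of building and querying ImplicitNetwork as the canonical occupancy MLP. The hyper-parameters below (width, depth, skip and conditioning layers, embedding size) are illustrative placeholders; the real values come from the Hydra configs under config/.

    import torch
    from lib.model.network import ImplicitNetwork

    net = ImplicitNetwork(
        d_in=3, d_out=1, width=256, depth=8,
        multires=0,                  # > 0 would enable the NeRF-style positional encoding
        cond_layer=[0], cond_dim=69, dim_cond_embed=8,
        skip_layer=[4],
    )

    pts  = torch.randn(2, 1024, 3)              # [B, N, 3] canonical points
    cond = {"smpl": torch.randn(2, 69)}         # pose conditioning (23 joints * 3)
    occ  = net(pts, cond)                       # [B, N, 1] occupancy logits
    print(occ.shape)                            # torch.Size([2, 1024, 1])
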
/lib/model/sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class PointOnBones:
5 | def __init__(self, bone_ids):
6 | self.bone_ids = bone_ids
7 |
8 | def get_points(self, joints, num_per_bone=5):
9 | """Sample points on bones in canonical space.
10 |
11 | Args:
12 | joints (tensor): joint positions to define the bone positions. shape: [B, J, D]
13 | num_per_bone (int, optional): number of sample points on each bone. Defaults to 5.
14 |
15 | Returns:
16 | samples (tensor): sampled points in canonical space. shape: [B, ?, 3]
17 | probs (tensor): ground truth occupancy for samples (all 1). shape: [B, ?]
18 | """
19 |
20 | num_batch, _, _ = joints.shape
21 |
22 | samples = []
23 |
24 | for bone_id in self.bone_ids:
25 |
26 | if bone_id[0] < 0 or bone_id[1] < 0:
27 | continue
28 |
29 | bone_dir = joints[:, bone_id[1]] - joints[:, bone_id[0]]
30 |
31 | scalars = (
32 | torch.linspace(0, 1, steps=num_per_bone, device=joints.device)
33 | .unsqueeze(0)
34 | .expand(num_batch, -1)
35 | )
36 | scalars = (
37 | scalars + torch.randn((num_batch, num_per_bone), device=joints.device) * 0.1
38 | ).clamp_(0, 1)
39 |
40 | samples.append(
41 | joints[:, bone_id[0]].unsqueeze(1).expand(-1, scalars.shape[-1], -1)
42 | + torch.einsum("bn,bi->bni", scalars, bone_dir) # b: num_batch, n: num_per_bone, i: 3-dim
43 | )
44 |
45 | samples = torch.cat(samples, dim=1)
46 |
47 | probs = torch.ones((num_batch, samples.shape[1]), device=joints.device)
48 |
49 | return samples, probs
50 |
51 | def get_joints(self, joints):
52 | """Sample joints in canonical space.
53 |
54 | Args:
55 | joints (tensor): joint positions to define the bone positions. shape: [B, J, D]
56 |
57 | Returns:
58 | samples (tensor): sampled points in canonical space. shape: [B, ?, 3]
59 | weights (tensor): ground truth skinning weights for samples (one-hot per joint). shape: [B, ?, J]
60 | """
61 | num_batch, num_joints, _ = joints.shape
62 |
63 | samples = []
64 | weights = []
65 |
66 | for k in range(num_joints):
67 | samples.append(joints[:, k])
68 | weight = torch.zeros((num_batch, num_joints), device=joints.device)
69 | weight[:, k] = 1
70 | weights.append(weight)
71 |
72 | for bone_id in self.bone_ids:
73 |
74 | if bone_id[0] < 0 or bone_id[1] < 0:
75 | continue
76 |
77 | samples.append(joints[:, bone_id[1]])
78 |
79 | weight = torch.zeros((num_batch, num_joints), device=joints.device)
80 | weight[:, bone_id[0]] = 1
81 | weights.append(weight)
82 |
83 | samples = torch.stack(samples, dim=1)
84 | weights = torch.stack(weights, dim=1)
85 |
86 | return samples, weights
87 |
88 |
89 | class PointInSpace:
90 | def __init__(self, global_sigma=1.8, local_sigma=0.01):
91 | self.global_sigma = global_sigma
92 | self.local_sigma = local_sigma
93 |
94 | def get_points(self, pc_input):
95 | """Sample one point near each of the given point + 1/8 uniformly.
96 |
97 | Args:
98 | pc_input (tensor): sampling centers. shape: [B, N, D]
99 |
100 | Returns:
101 | samples (tensor): sampled points. shape: [B, N + N / 8, D]
102 | """
103 |
104 | batch_size, sample_size, dim = pc_input.shape
105 |
106 | sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)
107 |
108 | sample_global = (
109 | torch.rand(batch_size, sample_size // 8, dim, device=pc_input.device)
110 | * (self.global_sigma * 2)
111 | ) - self.global_sigma
112 |
113 | sample = torch.cat([sample_local, sample_global], dim=1)
114 |
115 | return sample
116 |
--------------------------------------------------------------------------------
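
A small sketch of both samplers. PointInSpace jitters each input point with local Gaussian noise and adds N//8 uniform samples from a [-global_sigma, global_sigma] box; PointOnBones needs parent/child joint index pairs (SMPLServer.bone_ids in practice; the two-bone chain below is a made-up example).

    import torch
    from lib.model.sample import PointInSpace, PointOnBones

    sampler = PointInSpace(global_sigma=1.8, local_sigma=0.01)
    pts = torch.randn(2, 512, 3)                   # [B, N, 3] surface points
    samples = sampler.get_points(pts)
    print(samples.shape)                           # [2, 512 + 512 // 8, 3] = [2, 576, 3]

    # Bone sampling with a hypothetical two-bone chain 0 -> 1 -> 2:
    bone_sampler = PointOnBones(bone_ids=[[0, 1], [1, 2]])
    joints = torch.randn(2, 3, 3)                  # [B, J, 3] canonical joints
    pts_on_bones, occ_gt = bone_sampler.get_points(joints, num_per_bone=5)
    print(pts_on_bones.shape, occ_gt.shape)        # [2, 10, 3], [2, 10]
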
/lib/model/smpl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import hydra
3 | import numpy as np
4 | from lib.smpl.body_models import SMPL
5 |
6 | class SMPLServer(torch.nn.Module):
7 |
8 | def __init__(self, gender='neutral', betas=None, v_template=None):
9 | super().__init__()
10 |
11 |
12 | self.smpl = SMPL(model_path=hydra.utils.to_absolute_path('lib/smpl/smpl_model'),
13 | gender=gender,
14 | batch_size=1,
15 | use_hands=False,
16 | use_feet_keypoints=False,
17 | dtype=torch.float32).cuda()
18 |
19 | self.bone_parents = self.smpl.bone_parents.astype(int)
20 | self.bone_parents[0] = -1
21 | self.bone_ids = []
22 | for i in range(24): self.bone_ids.append([self.bone_parents[i], i])
23 |
24 | if v_template is not None:
25 | self.v_template = torch.tensor(v_template).float().cuda()
26 | else:
27 | self.v_template = None
28 |
29 | if betas is not None:
30 | self.betas = torch.tensor(betas).float().cuda()
31 | else:
32 | self.betas = None
33 |
34 | # define the canonical pose
35 | param_canonical = torch.zeros((1, 86),dtype=torch.float32).cuda()
36 | param_canonical[0, 0] = 1
37 | param_canonical[0, 9] = np.pi / 6
38 | param_canonical[0, 12] = -np.pi / 6
39 | if self.betas is not None and self.v_template is None:
40 | param_canonical[0,-10:] = self.betas
41 | self.param_canonical = param_canonical
42 |
43 | output = self.forward(param_canonical, absolute=True)
44 | self.verts_c = output['smpl_verts']
45 | self.joints_c = output['smpl_jnts']
46 | self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse()
47 |
48 |
49 | def forward(self, smpl_params, absolute=False):
50 | """return SMPL output from params
51 |
52 | Args:
53 | smpl_params (tensor): SMPL parameters. shape: [B, 86]. Layout: [0] scale, [1:4] translation, [4:76] pose thetas, [76:86] betas.
54 | absolute (bool): if true return smpl_tfs wrt thetas=0. else wrt thetas=thetas_canonical.
55 |
56 | Returns:
57 | smpl_verts: vertices. shape: [B, 6890, 3]
58 | smpl_tfs: bone transformations. shape: [B, 24, 4, 4]
59 | smpl_jnts: joint positions. shape: [B, 25, 3]
60 | """
61 |
62 | output = {}
63 |
64 | scale, transl, thetas, betas = torch.split(smpl_params, [1, 3, 72, 10], dim=1)
65 |
66 | # ignore betas if v_template is provided
67 | if self.v_template is not None:
68 | betas = torch.zeros_like(betas)
69 |
70 | smpl_output = self.smpl.forward(betas=betas,
71 | transl=torch.zeros_like(transl),
72 | body_pose=thetas[:, 3:],
73 | global_orient=thetas[:, :3],
74 | return_verts=True,
75 | return_full_pose=True,
76 | v_template=self.v_template)
77 |
78 | verts = smpl_output.vertices.clone()
79 | output['smpl_verts'] = verts * scale.unsqueeze(1) + transl.unsqueeze(1)
80 |
81 | joints = smpl_output.joints.clone()
82 | output['smpl_jnts'] = joints * scale.unsqueeze(1) + transl.unsqueeze(1)
83 |
84 | tf_mats = smpl_output.T.clone()
85 | tf_mats[:, :, :3, :] *= scale.unsqueeze(1).unsqueeze(1)
86 | tf_mats[:, :, :3, 3] += transl.unsqueeze(1)
87 |
88 | if not absolute:
89 | tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv)
90 |
91 | output['smpl_tfs'] = tf_mats
92 |
93 | return output
--------------------------------------------------------------------------------
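
A hedged sketch of driving SMPLServer directly. It assumes a CUDA device and that the SMPL model pickle is in place under lib/smpl/smpl_model; the 86-dimensional parameter layout follows the forward() docstring above.

    import torch
    from lib.model.smpl import SMPLServer

    server = SMPLServer(gender='neutral')

    params = torch.zeros(1, 86, device='cuda')
    params[0, 0] = 1.0                     # scale
    params[0, 4:76] = 0.1                  # a mildly bent pose (axis-angle thetas)

    out = server(params, absolute=False)
    print(out['smpl_verts'].shape,         # [1, 6890, 3]
          out['smpl_tfs'].shape)           # [1, 24, 4, 4], relative to the canonical pose
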
/lib/smpl/body_models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 | from __future__ import absolute_import
19 | from __future__ import print_function
20 | from __future__ import division
21 |
22 | import os
23 | import os.path as osp
24 |
25 |
26 | import pickle
27 |
28 | import numpy as np
29 |
30 | from collections import namedtuple
31 |
32 | import torch
33 | import torch.nn as nn
34 |
35 | from .lbs import (
36 | lbs, vertices2joints, blend_shapes)
37 |
38 | from .vertex_ids import vertex_ids as VERTEX_IDS
39 | from .utils import Struct, to_np, to_tensor
40 | from .vertex_joint_selector import VertexJointSelector
41 |
42 |
43 | ModelOutput = namedtuple('ModelOutput',
44 | ['vertices','faces', 'joints', 'full_pose', 'betas',
45 | 'global_orient',
46 | 'body_pose', 'expression',
47 | 'left_hand_pose', 'right_hand_pose',
48 | 'jaw_pose', 'T', 'T_weighted', 'weights'])
49 | ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields)
50 |
51 | class SMPL(nn.Module):
52 |
53 | NUM_JOINTS = 23
54 | NUM_BODY_JOINTS = 23
55 | NUM_BETAS = 10
56 |
57 | def __init__(self, model_path, data_struct=None,
58 | create_betas=True,
59 | betas=None,
60 | create_global_orient=True,
61 | global_orient=None,
62 | create_body_pose=True,
63 | body_pose=None,
64 | create_transl=True,
65 | transl=None,
66 | dtype=torch.float32,
67 | batch_size=1,
68 | joint_mapper=None, gender='neutral',
69 | vertex_ids=None,
70 | pose_blend=True,
71 | **kwargs):
72 | ''' SMPL model constructor
73 |
74 | Parameters
75 | ----------
76 | model_path: str
77 | The path to the folder or to the file where the model
78 | parameters are stored
79 | data_struct: Struct
80 | A struct object. If given, then the parameters of the model are
81 | read from the object. Otherwise, the model tries to read the
82 | parameters from the given `model_path`. (default = None)
83 | create_global_orient: bool, optional
84 | Flag for creating a member variable for the global orientation
85 | of the body. (default = True)
86 | global_orient: torch.tensor, optional, Bx3
87 | The default value for the global orientation variable.
88 | (default = None)
89 | create_body_pose: bool, optional
90 | Flag for creating a member variable for the pose of the body.
91 | (default = True)
92 | body_pose: torch.tensor, optional, Bx(Body Joints * 3)
93 | The default value for the body pose variable.
94 | (default = None)
95 | create_betas: bool, optional
96 | Flag for creating a member variable for the shape space
97 | (default = True).
98 | betas: torch.tensor, optional, Bx10
99 | The default value for the shape member variable.
100 | (default = None)
101 | create_transl: bool, optional
102 | Flag for creating a member variable for the translation
103 | of the body. (default = True)
104 | transl: torch.tensor, optional, Bx3
105 | The default value for the transl variable.
106 | (default = None)
107 | dtype: torch.dtype, optional
108 | The data type for the created variables
109 | batch_size: int, optional
110 | The batch size used for creating the member variables
111 | joint_mapper: object, optional
112 | An object that re-maps the joints. Useful if one wants to
113 | re-order the SMPL joints to some other convention (e.g. MSCOCO)
114 | (default = None)
115 | gender: str, optional
116 | Which gender to load
117 | vertex_ids: dict, optional
118 | A dictionary containing the indices of the extra vertices that
119 | will be selected
120 | '''
121 |
122 | self.gender = gender
123 | self.pose_blend = pose_blend
124 |
125 | if data_struct is None:
126 | if osp.isdir(model_path):
127 | model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
128 | smpl_path = os.path.join(model_path, model_fn)
129 | else:
130 | smpl_path = model_path
131 | assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
132 | smpl_path)
133 |
134 | with open(smpl_path, 'rb') as smpl_file:
135 | data_struct = Struct(**pickle.load(smpl_file,encoding='latin1'))
136 | super(SMPL, self).__init__()
137 | self.batch_size = batch_size
138 |
139 | if vertex_ids is None:
140 | # SMPL and SMPL-H share the same topology, so any extra joints can
141 | # be drawn from the same place
142 | vertex_ids = VERTEX_IDS['smplh']
143 |
144 | self.dtype = dtype
145 |
146 | self.joint_mapper = joint_mapper
147 |
148 | self.vertex_joint_selector = VertexJointSelector(
149 | vertex_ids=vertex_ids, **kwargs)
150 |
151 | self.faces = data_struct.f
152 | self.register_buffer('faces_tensor',
153 | to_tensor(to_np(self.faces, dtype=np.int64),
154 | dtype=torch.long))
155 |
156 | if create_betas:
157 | if betas is None:
158 | default_betas = torch.zeros([batch_size, self.NUM_BETAS],
159 | dtype=dtype)
160 | else:
161 | if 'torch.Tensor' in str(type(betas)):
162 | default_betas = betas.clone().detach()
163 | else:
164 | default_betas = torch.tensor(betas,
165 | dtype=dtype)
166 |
167 | self.register_parameter('betas', nn.Parameter(default_betas,
168 | requires_grad=True))
169 |
170 | # The tensor that contains the global rotation of the model
171 | # It is separated from the pose of the joints in case we wish to
172 | # optimize only over one of them
173 | if create_global_orient:
174 | if global_orient is None:
175 | default_global_orient = torch.zeros([batch_size, 3],
176 | dtype=dtype)
177 | else:
178 | if 'torch.Tensor' in str(type(global_orient)):
179 | default_global_orient = global_orient.clone().detach()
180 | else:
181 | default_global_orient = torch.tensor(global_orient,
182 | dtype=dtype)
183 |
184 | global_orient = nn.Parameter(default_global_orient,
185 | requires_grad=True)
186 | self.register_parameter('global_orient', global_orient)
187 |
188 | if create_body_pose:
189 | if body_pose is None:
190 | default_body_pose = torch.zeros(
191 | [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
192 | else:
193 | if 'torch.Tensor' in str(type(body_pose)):
194 | default_body_pose = body_pose.clone().detach()
195 | else:
196 | default_body_pose = torch.tensor(body_pose,
197 | dtype=dtype)
198 | self.register_parameter(
199 | 'body_pose',
200 | nn.Parameter(default_body_pose, requires_grad=True))
201 |
202 | if create_transl:
203 | if transl is None:
204 | default_transl = torch.zeros([batch_size, 3],
205 | dtype=dtype,
206 | requires_grad=True)
207 | else:
208 | default_transl = torch.tensor(transl, dtype=dtype)
209 | self.register_parameter(
210 | 'transl',
211 | nn.Parameter(default_transl, requires_grad=True))
212 |
213 | # The vertices of the template model
214 | self.register_buffer('v_template',
215 | to_tensor(to_np(data_struct.v_template),
216 | dtype=dtype))
217 |
218 | # The shape components
219 | shapedirs = data_struct.shapedirs[:, :, :self.NUM_BETAS]
220 | # The shape components
221 | self.register_buffer(
222 | 'shapedirs',
223 | to_tensor(to_np(shapedirs), dtype=dtype))
224 |
225 |
226 | j_regressor = to_tensor(to_np(
227 | data_struct.J_regressor), dtype=dtype)
228 | self.register_buffer('J_regressor', j_regressor)
229 |
230 | # if self.gender == 'neutral':
231 | # joint_regressor = to_tensor(to_np(
232 | # data_struct.cocoplus_regressor), dtype=dtype).permute(1,0)
233 | # self.register_buffer('joint_regressor', joint_regressor)
234 |
235 | # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
236 | num_pose_basis = data_struct.posedirs.shape[-1]
237 | # 207 x 20670
238 | posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
239 | self.register_buffer('posedirs',
240 | to_tensor(to_np(posedirs), dtype=dtype))
241 |
242 | # indices of parents for each joints
243 | parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
244 | parents[0] = -1
245 | self.register_buffer('parents', parents)
246 |
247 | self.bone_parents = to_np(data_struct.kintree_table[0])
248 |
249 | self.register_buffer('lbs_weights',
250 | to_tensor(to_np(data_struct.weights), dtype=dtype))
251 |
252 | def create_mean_pose(self, data_struct):
253 | pass
254 |
255 | @torch.no_grad()
256 | def reset_params(self, **params_dict):
257 | for param_name, param in self.named_parameters():
258 | if param_name in params_dict:
259 | param[:] = torch.tensor(params_dict[param_name])
260 | else:
261 | param.fill_(0)
262 |
263 | def get_T_hip(self, betas=None):
264 | v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)
265 | J = vertices2joints(self.J_regressor, v_shaped)
266 | T_hip = J[0,0]
267 | return T_hip
268 |
269 | def get_num_verts(self):
270 | return self.v_template.shape[0]
271 |
272 | def get_num_faces(self):
273 | return self.faces.shape[0]
274 |
275 | def extra_repr(self):
276 | return 'Number of betas: {}'.format(self.NUM_BETAS)
277 |
278 | def forward(self, betas=None, body_pose=None, global_orient=None,
279 | transl=None, return_verts=True, return_full_pose=False,displacement=None,v_template=None,
280 | **kwargs):
281 | ''' Forward pass for the SMPL model
282 |
283 | Parameters
284 | ----------
285 | global_orient: torch.tensor, optional, shape Bx3
286 | If given, ignore the member variable and use it as the global
287 | rotation of the body. Useful if someone wishes to predict this
288 | with an external model. (default=None)
289 | betas: torch.tensor, optional, shape Bx10
290 | If given, ignore the member variable `betas` and use it
291 | instead. For example, it can be used if shape parameters
292 | `betas` are predicted from some external model.
293 | (default=None)
294 | body_pose: torch.tensor, optional, shape Bx(J*3)
295 | If given, ignore the member variable `body_pose` and use it
296 | instead. For example, it can be used if the pose of the body
297 | joints is predicted from some external model.
298 | It should be a tensor that contains joint rotations in
299 | axis-angle format. (default=None)
300 | transl: torch.tensor, optional, shape Bx3
301 | If given, ignore the member variable `transl` and use it
302 | instead. For example, it can be used if the translation
303 | `transl` is predicted from some external model.
304 | (default=None)
305 | return_verts: bool, optional
306 | Return the vertices. (default=True)
307 | return_full_pose: bool, optional
308 | Returns the full axis-angle pose vector (default=False)
309 |
310 | Returns
311 | -------
312 | '''
313 | # If no shape and pose parameters are passed along, then use the
314 | # ones from the module
315 | global_orient = (global_orient if global_orient is not None else
316 | self.global_orient)
317 | body_pose = body_pose if body_pose is not None else self.body_pose
318 | betas = betas if betas is not None else self.betas
319 |
320 | apply_trans = transl is not None or hasattr(self, 'transl')
321 | if transl is None and hasattr(self, 'transl'):
322 | transl = self.transl
323 |
324 | full_pose = torch.cat([global_orient, body_pose], dim=1)
325 |
326 | # if betas.shape[0] != self.batch_size:
327 | # num_repeats = int(self.batch_size / betas.shape[0])
328 | # betas = betas.expand(num_repeats, -1)
329 |
330 | if v_template is None:
331 | v_template = self.v_template
332 |
333 | if displacement is not None:
334 | vertices, joints_smpl, T_weighted, W, T = lbs(betas, full_pose, v_template+displacement,
335 | self.shapedirs, self.posedirs,
336 | self.J_regressor, self.parents,
337 | self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend)
338 | else:
339 |
340 |
341 | vertices, joints_smpl,T_weighted, W, T = lbs(betas, full_pose, v_template,
342 | self.shapedirs, self.posedirs,
343 | self.J_regressor, self.parents,
344 | self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend)
345 |
346 | # if self.gender is not 'neutral':
347 | joints = self.vertex_joint_selector(vertices, joints_smpl)
348 | # else:
349 | # joints = torch.matmul(vertices.permute(0,2,1),self.joint_regressor).permute(0,2,1)
350 | # Map the joints to the current dataset
351 | if self.joint_mapper is not None:
352 | joints = self.joint_mapper(joints)
353 |
354 | if apply_trans:
355 | joints_smpl += transl.unsqueeze(dim=1)
356 | joints += transl.unsqueeze(dim=1)
357 | vertices += transl.unsqueeze(dim=1)
358 |
359 | output = ModelOutput(vertices=vertices if return_verts else None,
360 | faces=self.faces,
361 | global_orient=global_orient,
362 | body_pose=body_pose,
363 | joints=joints_smpl,
364 | betas=self.betas,
365 | full_pose=full_pose if return_full_pose else None,
366 | T=T, T_weighted=T_weighted, weights=W)
367 |
368 | return output
--------------------------------------------------------------------------------
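
A minimal sketch of calling this SMPL module on its own (assuming the SMPL pickle, e.g. SMPL_NEUTRAL.pkl, is available under the given path). Besides vertices and joints, the ModelOutput here also carries the per-joint bone transforms T (and per-vertex transforms T_weighted) that SMPLServer and the deformer rely on.

    import torch
    from lib.smpl.body_models import SMPL

    smpl = SMPL(model_path='lib/smpl/smpl_model', gender='neutral', batch_size=1,
                use_hands=False, use_feet_keypoints=False, dtype=torch.float32)

    betas = torch.zeros(1, 10)
    pose  = torch.zeros(1, 72)             # axis-angle: global_orient (3) + body_pose (69)
    out = smpl(betas=betas, global_orient=pose[:, :3], body_pose=pose[:, 3:],
               transl=torch.zeros(1, 3), return_verts=True, return_full_pose=True)
    print(out.vertices.shape, out.T.shape) # [1, 6890, 3], [1, 24, 4, 4]
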
/lib/smpl/lbs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 | from __future__ import absolute_import
19 | from __future__ import print_function
20 | from __future__ import division
21 |
22 | import numpy as np
23 |
24 | import torch
25 | import torch.nn.functional as F
26 |
27 | from .utils import rot_mat_to_euler
28 |
29 |
30 | def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
31 | dynamic_lmk_b_coords,
32 | neck_kin_chain, dtype=torch.float32):
33 | ''' Compute the faces, barycentric coordinates for the dynamic landmarks
34 |
35 |
36 | To do so, we first compute the rotation of the neck around the y-axis
37 | and then use a pre-computed look-up table to find the faces and the
38 | barycentric coordinates that will be used.
39 |
40 | Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
41 | for providing the original TensorFlow implementation and for the LUT.
42 |
43 | Parameters
44 | ----------
45 | vertices: torch.tensor BxVx3, dtype = torch.float32
46 | The tensor of input vertices
47 | pose: torch.tensor Bx(Jx3), dtype = torch.float32
48 | The current pose of the body model
49 | dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
50 | The look-up table from neck rotation to faces
51 | dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
52 | The look-up table from neck rotation to barycentric coordinates
53 | neck_kin_chain: list
54 | A python list that contains the indices of the joints that form the
55 | kinematic chain of the neck.
56 | dtype: torch.dtype, optional
57 |
58 | Returns
59 | -------
60 | dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
61 | A tensor of size BxL that contains the indices of the faces that
62 | will be used to compute the current dynamic landmarks.
63 | dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
64 | A tensor of size BxL that contains the barycentric coordinates that
65 | will be used to compute the current dynamic landmarks.
66 | '''
67 |
68 | batch_size = vertices.shape[0]
69 |
70 | aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
71 | neck_kin_chain)
72 | rot_mats = batch_rodrigues(
73 | aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
74 |
75 | rel_rot_mat = torch.eye(3, device=vertices.device,
76 | dtype=dtype).unsqueeze_(dim=0)
77 | for idx in range(len(neck_kin_chain)):
78 | rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
79 |
80 | y_rot_angle = torch.round(
81 | torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
82 | max=39)).to(dtype=torch.long)
83 | neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
84 | mask = y_rot_angle.lt(-39).to(dtype=torch.long)
85 | neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
86 | y_rot_angle = (neg_mask * neg_vals +
87 | (1 - neg_mask) * y_rot_angle)
88 |
89 | dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
90 | 0, y_rot_angle)
91 | dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
92 | 0, y_rot_angle)
93 |
94 | return dyn_lmk_faces_idx, dyn_lmk_b_coords
95 |
96 |
97 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
98 | ''' Calculates landmarks by barycentric interpolation
99 |
100 | Parameters
101 | ----------
102 | vertices: torch.tensor BxVx3, dtype = torch.float32
103 | The tensor of input vertices
104 | faces: torch.tensor Fx3, dtype = torch.long
105 | The faces of the mesh
106 | lmk_faces_idx: torch.tensor L, dtype = torch.long
107 | The tensor with the indices of the faces used to calculate the
108 | landmarks.
109 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
110 | The tensor of barycentric coordinates that are used to interpolate
111 | the landmarks
112 |
113 | Returns
114 | -------
115 | landmarks: torch.tensor BxLx3, dtype = torch.float32
116 | The coordinates of the landmarks for each mesh in the batch
117 | '''
118 | # Extract the indices of the vertices for each face
119 | # BxLx3
120 | batch_size, num_verts = vertices.shape[:2]
121 | device = vertices.device
122 |
123 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).expand(
124 | batch_size, -1, -1).long()
125 |
126 | lmk_faces = lmk_faces + torch.arange(
127 | batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
128 |
129 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
130 | batch_size, -1, 3, 3)
131 |
132 | landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
133 | return landmarks
134 |
135 |
136 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
137 | lbs_weights, pose2rot=True, dtype=torch.float32, pose_blend=True):
138 | ''' Performs Linear Blend Skinning with the given shape and pose parameters
139 |
140 | Parameters
141 | ----------
142 | betas : torch.tensor BxNB
143 | The tensor of shape parameters
144 | pose : torch.tensor Bx(J + 1) * 3
145 | The pose parameters in axis-angle format
146 | v_template torch.tensor BxVx3
147 | The template mesh that will be deformed
148 | shapedirs : torch.tensor Vx3xNB
149 | The tensor of PCA shape displacements
150 | posedirs : torch.tensor Px(V * 3)
151 | The pose PCA coefficients
152 | J_regressor : torch.tensor JxV
153 | The regressor array that is used to calculate the joints from
154 | the position of the vertices
155 | parents: torch.tensor J
156 | The array that describes the kinematic tree for the model
157 | lbs_weights: torch.tensor N x V x (J + 1)
158 | The linear blend skinning weights that represent how much the
159 | rotation matrix of each part affects each vertex
160 | pose2rot: bool, optional
161 | Flag on whether to convert the input pose tensor to rotation
162 | matrices. The default value is True. If False, then the pose tensor
163 | should already contain rotation matrices and have a size of
164 | Bx(J + 1)x9
165 | dtype: torch.dtype, optional
166 |
167 | Returns
168 | -------
169 | verts: torch.tensor BxVx3
170 | The vertices of the mesh after applying the shape and pose
171 | displacements.
172 | joints: torch.tensor BxJx3
173 | The joints of the model
174 | '''
175 |
176 | batch_size = max(betas.shape[0], pose.shape[0])
177 | device = betas.device
178 |
179 | # Add shape contribution
180 | v_shaped = v_template + blend_shapes(betas, shapedirs)
181 |
182 | # Get the joints
183 | # NxJx3 array
184 | J = vertices2joints(J_regressor, v_shaped)
185 |
186 | # 3. Add pose blend shapes
187 | # N x J x 3 x 3
188 | ident = torch.eye(3, dtype=dtype, device=device)
189 |
190 |
191 | if pose2rot:
192 | rot_mats = batch_rodrigues(
193 | pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
194 |
195 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
196 | # (N x P) x (P, V * 3) -> N x V x 3
197 | pose_offsets = torch.matmul(pose_feature, posedirs) \
198 | .view(batch_size, -1, 3)
199 | else:
200 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
201 | rot_mats = pose.view(batch_size, -1, 3, 3)
202 |
203 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
204 | posedirs).view(batch_size, -1, 3)
205 |
206 | if pose_blend:
207 | v_posed = pose_offsets + v_shaped
208 | else:
209 | v_posed = v_shaped
210 |
211 | # 4. Get the global joint location
212 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
213 |
214 | # 5. Do skinning:
215 | # W is N x V x (J + 1)
216 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
217 | # (N x V x (J + 1)) x (N x (J + 1) x 16)
218 | num_joints = J_regressor.shape[0]
219 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
220 | .view(batch_size, -1, 4, 4)
221 |
222 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
223 | dtype=dtype, device=device)
224 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
225 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
226 |
227 | verts = v_homo[:, :, :3, 0]
228 |
229 | return verts, J_transformed, T, W, A.view(batch_size, num_joints, 4,4)
230 |
231 |
232 | def vertices2joints(J_regressor, vertices):
233 | ''' Calculates the 3D joint locations from the vertices
234 |
235 | Parameters
236 | ----------
237 | J_regressor : torch.tensor JxV
238 | The regressor array that is used to calculate the joints from the
239 | position of the vertices
240 | vertices : torch.tensor BxVx3
241 | The tensor of mesh vertices
242 |
243 | Returns
244 | -------
245 | torch.tensor BxJx3
246 | The location of the joints
247 | '''
248 |
249 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
250 |
251 |
252 | def blend_shapes(betas, shape_disps):
253 | ''' Calculates the per vertex displacement due to the blend shapes
254 |
255 |
256 | Parameters
257 | ----------
258 | betas : torch.tensor Bx(num_betas)
259 | Blend shape coefficients
260 | shape_disps: torch.tensor Vx3x(num_betas)
261 | Blend shapes
262 |
263 | Returns
264 | -------
265 | torch.tensor BxVx3
266 | The per-vertex displacement due to shape deformation
267 | '''
268 |
269 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
270 | # i.e. Multiply each shape displacement by its corresponding beta and
271 | # then sum them.
272 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
273 | return blend_shape
274 |
275 |
276 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
277 | ''' Calculates the rotation matrices for a batch of rotation vectors
278 | Parameters
279 | ----------
280 | rot_vecs: torch.tensor Nx3
281 | array of N axis-angle vectors
282 | Returns
283 | -------
284 | R: torch.tensor Nx3x3
285 | The rotation matrices for the given axis-angle parameters
286 | '''
287 |
288 | batch_size = rot_vecs.shape[0]
289 | device = rot_vecs.device
290 |
291 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
292 | rot_dir = rot_vecs / angle
293 |
294 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
295 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
296 |
297 | # Bx1 arrays
298 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
299 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
300 |
301 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
302 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
303 | .view((batch_size, 3, 3))
304 |
305 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
306 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
307 | return rot_mat
308 |
309 |
310 | def transform_mat(R, t):
311 | ''' Creates a batch of transformation matrices
312 | Args:
313 | - R: Bx3x3 array of a batch of rotation matrices
314 | - t: Bx3x1 array of a batch of translation vectors
315 | Returns:
316 | - T: Bx4x4 Transformation matrix
317 | '''
318 | # No padding left or right, only add an extra row
319 | return torch.cat([F.pad(R, [0, 0, 0, 1]),
320 | F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
321 |
322 |
323 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
324 | """
325 | Applies a batch of rigid transformations to the joints
326 |
327 | Parameters
328 | ----------
329 | rot_mats : torch.tensor BxNx3x3
330 | Tensor of rotation matrices
331 | joints : torch.tensor BxNx3
332 | Locations of joints
333 | parents : torch.tensor BxN
334 | The kinematic tree of each object
335 | dtype : torch.dtype, optional:
336 | The data type of the created tensors, the default is torch.float32
337 |
338 | Returns
339 | -------
340 | posed_joints : torch.tensor BxNx3
341 | The locations of the joints after applying the pose rotations
342 | rel_transforms : torch.tensor BxNx4x4
343 | The relative (with respect to the root joint) rigid transformations
344 | for all the joints
345 | """
346 |
347 | joints = torch.unsqueeze(joints, dim=-1)
348 |
349 | rel_joints = joints.clone()
350 | rel_joints[:, 1:] -= joints[:, parents[1:]]
351 |
352 | transforms_mat = transform_mat(
353 | rot_mats.reshape(-1, 3, 3),
354 | rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
355 |
356 | transform_chain = [transforms_mat[:, 0]]
357 | for i in range(1, parents.shape[0]):
358 | # Subtract the joint location at the rest pose
359 | # No need for rotation, since it's identity when at rest
360 | curr_res = torch.matmul(transform_chain[parents[i]],
361 | transforms_mat[:, i])
362 | transform_chain.append(curr_res)
363 |
364 | transforms = torch.stack(transform_chain, dim=1)
365 |
366 | # The last column of the transformations contains the posed joints
367 | posed_joints = transforms[:, :, :3, 3]
368 |
369 | # The last column of the transformations contains the posed joints
370 | posed_joints = transforms[:, :, :3, 3]
371 |
372 | joints_homogen = F.pad(joints, [0, 0, 0, 1])
373 |
374 | rel_transforms = transforms - F.pad(
375 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
376 |
377 | return posed_joints, rel_transforms
378 |
--------------------------------------------------------------------------------
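
A quick numeric check of the rotation helpers (a minimal sketch): the axis-angle vector [0, 0, pi/2] should map to a 90-degree rotation about z, and transform_mat stacks a rotation and translation into a homogeneous 4x4 matrix.

    import torch
    import numpy as np
    from lib.smpl.lbs import batch_rodrigues, transform_mat

    aa = torch.tensor([[0.0, 0.0, np.pi / 2]])
    R = batch_rodrigues(aa)                               # [1, 3, 3]
    expected = torch.tensor([[[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]]])
    print(torch.allclose(R, expected, atol=1e-6))

    t = torch.tensor([[[1.0], [2.0], [3.0]]])             # [1, 3, 1] translation
    T = transform_mat(R, t)                               # [1, 4, 4], last row [0, 0, 0, 1]
    print(T[0])
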
/lib/smpl/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 | from __future__ import print_function
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def to_tensor(array, dtype=torch.float32):
27 | if 'torch.Tensor' not in str(type(array)):
28 | return torch.tensor(array, dtype=dtype)
29 | return array
30 |
31 | class Struct(object):
32 | def __init__(self, **kwargs):
33 | for key, val in kwargs.items():
34 | setattr(self, key, val)
35 |
36 |
37 | def to_np(array, dtype=np.float32):
38 | if 'scipy.sparse' in str(type(array)):
39 | array = array.todense()
40 | return np.array(array, dtype=dtype)
41 |
42 |
43 | def rot_mat_to_euler(rot_mats):
44 | # Convert a rotation matrix to Euler angles
45 | # Careful with extreme cases of Euler angles like [0.0, pi, 0.0]
46 |
47 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
48 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
49 | return torch.atan2(-rot_mats[:, 2, 0], sy)
50 |
--------------------------------------------------------------------------------
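
A tiny example of rot_mat_to_euler: for the identity rotation the extracted angle (used by the dynamic-landmark lookup to measure rotation about y) is zero.

    import torch
    from lib.smpl.utils import rot_mat_to_euler

    R = torch.eye(3).unsqueeze(0)          # [1, 3, 3] identity rotation
    print(rot_mat_to_euler(R))             # tensor([0.])
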
/lib/smpl/vertex_ids.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 | from __future__ import print_function
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | # Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
23 | # MSCOCO and OpenPose joints
24 | vertex_ids = {
25 | 'smplh': {
26 | 'nose': 332,
27 | 'reye': 6260,
28 | 'leye': 2800,
29 | 'rear': 4071,
30 | 'lear': 583,
31 | 'rthumb': 6191,
32 | 'rindex': 5782,
33 | 'rmiddle': 5905,
34 | 'rring': 6016,
35 | 'rpinky': 6133,
36 | 'lthumb': 2746,
37 | 'lindex': 2319,
38 | 'lmiddle': 2445,
39 | 'lring': 2556,
40 | 'lpinky': 2673,
41 | 'LBigToe': 3216,
42 | 'LSmallToe': 3226,
43 | 'LHeel': 3387,
44 | 'RBigToe': 6617,
45 | 'RSmallToe': 6624,
46 | 'RHeel': 6787
47 | },
48 | 'smplx': {
49 | 'nose': 9120,
50 | 'reye': 9929,
51 | 'leye': 9448,
52 | 'rear': 616,
53 | 'lear': 6,
54 | 'rthumb': 8079,
55 | 'rindex': 7669,
56 | 'rmiddle': 7794,
57 | 'rring': 7905,
58 | 'rpinky': 8022,
59 | 'lthumb': 5361,
60 | 'lindex': 4933,
61 | 'lmiddle': 5058,
62 | 'lring': 5169,
63 | 'lpinky': 5286,
64 | 'LBigToe': 5770,
65 | 'LSmallToe': 5780,
66 | 'LHeel': 8846,
67 | 'RBigToe': 8463,
68 | 'RSmallToe': 8474,
69 | 'RHeel': 8635
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/lib/smpl/vertex_joint_selector.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems and the Max Planck Institute for Biological
14 | # Cybernetics. All rights reserved.
15 | #
16 | # Contact: ps-license@tuebingen.mpg.de
17 |
18 | from __future__ import absolute_import
19 | from __future__ import print_function
20 | from __future__ import division
21 |
22 | import numpy as np
23 |
24 | import torch
25 | import torch.nn as nn
26 |
27 | from .utils import to_tensor
28 |
29 |
30 | class VertexJointSelector(nn.Module):
31 |
32 | def __init__(self, vertex_ids=None,
33 | use_hands=True,
34 | use_feet_keypoints=True, **kwargs):
35 | super(VertexJointSelector, self).__init__()
36 |
37 | extra_joints_idxs = []
38 |
39 | face_keyp_idxs = np.array([
40 | vertex_ids['nose'],
41 | vertex_ids['reye'],
42 | vertex_ids['leye'],
43 | vertex_ids['rear'],
44 | vertex_ids['lear']], dtype=np.int64)
45 |
46 | extra_joints_idxs = np.concatenate([extra_joints_idxs,
47 | face_keyp_idxs])
48 |
49 | if use_feet_keypoints:
50 | feet_keyp_idxs = np.array([vertex_ids['LBigToe'],
51 | vertex_ids['LSmallToe'],
52 | vertex_ids['LHeel'],
53 | vertex_ids['RBigToe'],
54 | vertex_ids['RSmallToe'],
55 | vertex_ids['RHeel']], dtype=np.int32)
56 |
57 | extra_joints_idxs = np.concatenate(
58 | [extra_joints_idxs, feet_keyp_idxs])
59 |
60 | if use_hands:
61 | self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky']
62 |
63 | tips_idxs = []
64 | for hand_id in ['l', 'r']:
65 | for tip_name in self.tip_names:
66 | tips_idxs.append(vertex_ids[hand_id + tip_name])
67 |
68 | extra_joints_idxs = np.concatenate(
69 | [extra_joints_idxs, tips_idxs])
70 |
71 | self.register_buffer('extra_joints_idxs',
72 | to_tensor(extra_joints_idxs, dtype=torch.long))
73 |
74 | def forward(self, vertices, joints):
75 | extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)
76 | joints = torch.cat([joints, extra_joints], dim=1)
77 | return joints
78 |
--------------------------------------------------------------------------------
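
A tiny sketch of VertexJointSelector in isolation: it gathers extra keypoints (face, and optionally feet and finger tips) directly from mesh vertices and appends them to the regressed joints.

    import torch
    from lib.smpl.vertex_ids import vertex_ids
    from lib.smpl.vertex_joint_selector import VertexJointSelector

    selector = VertexJointSelector(vertex_ids=vertex_ids['smplh'],
                                   use_hands=False, use_feet_keypoints=False)
    vertices = torch.randn(1, 6890, 3)
    joints   = torch.randn(1, 24, 3)
    print(selector(vertices, joints).shape)   # [1, 24 + 5, 3] with only face keypoints
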
/lib/snarf_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hydra
3 | import torch
4 | import wandb
5 | import imageio
6 | import numpy as np
7 | import pytorch_lightning as pl
8 |
9 | from lib.model.smpl import SMPLServer
10 | from lib.model.sample import PointOnBones
11 | from lib.model.network import ImplicitNetwork
12 | from lib.model.metrics import calculate_iou
13 | from lib.utils.meshing import generate_mesh
14 | from lib.model.helpers import masked_softmax
15 | from lib.model.deformer import ForwardDeformer, skinning
16 | from lib.utils.render import render_trimesh, render_joint, weights2colors
17 |
18 | class SNARFModel(pl.LightningModule):
19 |
20 | def __init__(self, opt, meta_info, data_processor=None):
21 | super().__init__()
22 |
23 | self.opt = opt
24 |
25 | self.network = ImplicitNetwork(**opt.network)
26 | self.deformer = ForwardDeformer(opt.deformer)
27 |
28 | print(self.network)
29 | print(self.deformer)
30 |
31 | gender = str(meta_info['gender'])
32 | betas = meta_info['betas'] if 'betas' in meta_info else None
33 | v_template = meta_info['v_template'] if 'v_template' in meta_info else None
34 |
35 | self.smpl_server = SMPLServer(gender=gender, betas=betas, v_template=v_template)
36 | self.smpl_faces = torch.tensor(self.smpl_server.smpl.faces.astype('int')).unsqueeze(0).cuda()
37 | self.sampler_bone = PointOnBones(self.smpl_server.bone_ids)
38 |
39 | self.data_processor = data_processor
40 |
41 | def configure_optimizers(self):
42 | optimizer = torch.optim.Adam(self.parameters(), lr=self.opt.optim.lr)
43 | return optimizer
44 |
45 | def forward(self, pts_d, smpl_tfs, smpl_thetas, eval_mode=True):
46 |
47 | # rectify rest pose
48 | smpl_tfs = torch.einsum('bnij,njk->bnik', smpl_tfs, self.smpl_server.tfs_c_inv)
49 |
50 | cond = {'smpl': smpl_thetas[:,3:]/np.pi}
51 |
52 | batch_points = 60000
53 |
54 | accum_pred = []
55 | # split to prevent out of memory
56 | for pts_d_split in torch.split(pts_d, batch_points, dim=1):
57 |
58 | # compute canonical correspondences
59 | pts_c, intermediates = self.deformer(pts_d_split, cond, smpl_tfs, eval_mode=eval_mode)
60 |
61 | # query occupancy in canonical space
62 | num_batch, num_point, num_init, num_dim = pts_c.shape
63 | pts_c = pts_c.reshape(num_batch, num_point * num_init, num_dim)
64 | occ_pd = self.network(pts_c, cond).reshape(num_batch, num_point, num_init)
65 |
66 | # aggregate occupancy probabilities
67 | mask = intermediates['valid_ids']
68 | if eval_mode:
69 | occ_pd = masked_softmax(occ_pd, mask, dim=-1, mode='max')
70 | else:
71 | occ_pd = masked_softmax(occ_pd, mask, dim=-1, mode='softmax', soft_blend=self.opt.soft_blend)
72 |
73 | accum_pred.append(occ_pd)
74 |
75 | accum_pred = torch.cat(accum_pred, 1)
76 |
77 | return accum_pred
78 |
79 | def training_step(self, data, data_idx):
80 |
81 | # Data prep
82 | if self.data_processor is not None:
83 | data = self.data_processor.process(data)
84 |
85 | # BCE loss
86 | occ_pd = self.forward(data['pts_d'], data['smpl_tfs'], data['smpl_thetas'], eval_mode=False)
87 | loss_bce = torch.nn.functional.binary_cross_entropy_with_logits(occ_pd, data['occ_gt'])
88 | self.log('train_bce', loss_bce)
89 | loss = loss_bce
90 |
91 | # Bootstrapping
92 | num_batch = data['pts_d'].shape[0]
93 | cond = {'smpl': data['smpl_thetas'][:,3:]/np.pi}
94 |
95 | # Bone occupancy loss
96 | if self.current_epoch < self.opt.nepochs_pretrain:
97 | if self.opt.lambda_bone_occ > 0:
98 |
99 | pts_c, occ_gt = self.sampler_bone.get_points(self.smpl_server.joints_c.expand(num_batch, -1, -1))
100 | occ_pd = self.network(pts_c, cond)
101 | loss_bone_occ = torch.nn.functional.binary_cross_entropy_with_logits(occ_pd, occ_gt.unsqueeze(-1))
102 |
103 | loss = loss + self.opt.lambda_bone_occ * loss_bone_occ
104 | self.log('train_bone_occ', loss_bone_occ)
105 |
106 | # Joint weight loss
107 | if self.opt.lambda_bone_w > 0:
108 |
109 | pts_c, w_gt = self.sampler_bone.get_joints(self.smpl_server.joints_c.expand(num_batch, -1, -1))
110 | w_pd = self.deformer.query_weights(pts_c, cond)
111 | loss_bone_w = torch.nn.functional.mse_loss(w_pd, w_gt)
112 |
113 | loss = loss + self.opt.lambda_bone_w * loss_bone_w
114 | self.log('train_bone_w', loss_bone_w)
115 |
116 | return loss
117 |
118 | def validation_step(self, data, data_idx):
119 |
120 | if self.data_processor is not None:
121 | data = self.data_processor.process(data)
122 |
123 | with torch.no_grad():
124 | if data_idx == 0:
125 | img_all = self.plot(data)['img_all']
126 | self.logger.experiment.log({"vis":[wandb.Image(img_all)]})
127 |
128 | occ_pd = self.forward(data['pts_d'], data['smpl_tfs'], data['smpl_thetas'], eval_mode=True)
129 |
130 | _, num_point, _ = data['occ_gt'].shape
131 | bbox_iou = calculate_iou(data['occ_gt'][:,:num_point//2]>0.5, occ_pd[:,:num_point//2]>0)
132 | surf_iou = calculate_iou(data['occ_gt'][:,num_point//2:]>0.5, occ_pd[:,num_point//2:]>0)
133 |
134 | return {'bbox_iou':bbox_iou, 'surf_iou':surf_iou}
135 |
136 | def validation_epoch_end(self, validation_step_outputs):
137 |
138 | bbox_ious, surf_ious = [], []
139 | for output in validation_step_outputs:
140 | bbox_ious.append(output['bbox_iou'])
141 | surf_ious.append(output['surf_iou'])
142 |
143 | self.log('valid_bbox_iou', torch.stack(bbox_ious).mean())
144 | self.log('valid_surf_iou', torch.stack(surf_ious).mean())
145 |
146 | def test_step(self, data, data_idx):
147 |
148 | with torch.no_grad():
149 |
150 | occ_pd = self.forward(data['pts_d'], data['smpl_tfs'], data['smpl_thetas'], eval_mode=True)
151 |
152 | _, num_point, _ = data['occ_gt'].shape
153 | bbox_iou = calculate_iou(data['occ_gt'][:,:num_point//2]>0.5, occ_pd[:,:num_point//2]>0)
154 | surf_iou = calculate_iou(data['occ_gt'][:,num_point//2:]>0.5, occ_pd[:,num_point//2:]>0)
155 |
156 | return {'bbox_iou':bbox_iou, 'surf_iou':surf_iou}
157 |
158 | def test_epoch_end(self, test_step_outputs):
159 | return self.validation_epoch_end(test_step_outputs)
160 |
161 | def plot(self, data, res=128, verbose=True, fast_mode=False):
162 |
163 | res_up = np.log2(res//32)
164 |
165 | if verbose:
166 | surf_pred_cano = self.extract_mesh(self.smpl_server.verts_c, data['smpl_tfs'][[0]], data['smpl_thetas'][[0]], res_up=res_up, canonical=True, with_weights=True)
167 | surf_pred_def = self.extract_mesh(data['smpl_verts'][[0]], data['smpl_tfs'][[0]], data['smpl_thetas'][[0]], res_up=res_up, canonical=False, with_weights=False)
168 |
169 | img_pred_cano = render_trimesh(surf_pred_cano)
170 | img_pred_def = render_trimesh(surf_pred_def)
171 |
172 | img_joint = render_joint(data['smpl_jnts'].data.cpu().numpy()[0],self.smpl_server.bone_ids)
173 | img_pred_def[1024:,:,:3] = 255
174 | img_pred_def[1024:-512,:, :3] = img_joint
175 | img_pred_def[1024:-512,:, -1] = 255
176 |
177 | results = {
178 | 'img_all': np.concatenate([img_pred_cano, img_pred_def], axis=1),
179 | 'mesh_cano': surf_pred_cano,
180 | 'mesh_def' : surf_pred_def
181 | }
182 | else:
183 | smpl_verts = self.smpl_server.verts_c if fast_mode else data['smpl_verts'][[0]]
184 |
185 | surf_pred_def = self.extract_mesh(smpl_verts, data['smpl_tfs'][[0]], data['smpl_thetas'][[0]], res_up=res_up, canonical=False, with_weights=False, fast_mode=fast_mode)
186 |
187 | img_pred_def = render_trimesh(surf_pred_def, mode='p')
188 | results = {
189 | 'img_all': img_pred_def,
190 | 'mesh_def' : surf_pred_def
191 | }
192 |
193 |
194 | return results
195 |
196 | def extract_mesh(self, smpl_verts, smpl_tfs, smpl_thetas, canonical=False, with_weights=False, res_up=2, fast_mode=False):
197 | '''
198 | In fast mode, we extract canonical mesh and then forward skin it to posed space.
199 | This is faster as it bypasses root finding.
200 | However, it's not deforming the continuous field, but the discrete mesh.
201 | '''
202 | if canonical or fast_mode:
203 | occ_func = lambda x: self.network(x, {'smpl': smpl_thetas[:,3:]/np.pi}).reshape(-1, 1)
204 | else:
205 | occ_func = lambda x: self.forward(x, smpl_tfs, smpl_thetas, eval_mode=True).reshape(-1, 1)
206 |
207 | mesh = generate_mesh(occ_func, smpl_verts.squeeze(0),res_up=res_up)
208 |
209 |
210 | if fast_mode:
211 | verts = torch.tensor(mesh.vertices).type_as(smpl_verts)
212 | weights = self.deformer.query_weights(verts[None], None).clamp(0,1)[0]
213 |
214 | smpl_tfs = torch.einsum('bnij,njk->bnik', smpl_tfs, self.smpl_server.tfs_c_inv)
215 |
216 | verts_mesh_deformed = skinning(verts.unsqueeze(0), weights.unsqueeze(0), smpl_tfs).data.cpu().numpy()[0]
217 | mesh.vertices = verts_mesh_deformed
218 |
219 | if with_weights:
220 | verts = torch.tensor(mesh.vertices).cuda().float()
221 | weights = self.deformer.query_weights(verts[None], None).clamp(0,1)[0]
222 | mesh.visual.vertex_colors = weights2colors(weights.data.cpu().numpy())
223 |
224 | return mesh
225 |
--------------------------------------------------------------------------------
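A note on the fast mode in extract_mesh above: instead of running root finding for every query point, it extracts the canonical mesh once and then forward-skins its vertices with the queried skinning weights. A minimal sketch of that skinning step, assuming per-vertex weights w of shape [N, J] and bone transforms tfs of shape [J, 4, 4] (function and variable names here are illustrative, not the repo's API):

import torch

def skin_vertices(verts, w, tfs):
    # verts: [N, 3] canonical vertices, w: [N, J] skinning weights, tfs: [J, 4, 4] bone transforms
    verts_h = torch.cat([verts, torch.ones_like(verts[:, :1])], dim=-1)  # homogeneous coords [N, 4]
    T = torch.einsum('nj,jkl->nkl', w, tfs)                              # per-vertex blended transform [N, 4, 4]
    return torch.einsum('nij,nj->ni', T, verts_h)[:, :3]                 # posed vertices [N, 3]

In the actual model the canonical-to-bone-space correction (tfs_c_inv) is folded into the transforms before this step, as done in extract_mesh.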
/lib/utils/meshing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from skimage import measure
4 | from lib.libmise import mise
5 | import trimesh
6 |
7 | ''' Code adapted from NASA https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/projects/nasa/lib/utils.py'''
8 | def generate_mesh(func, verts, level_set=0, res_init=32, res_up=3):
9 |
10 | scale = 1.1 # Scale of the padded bbox regarding the tight one.
11 |
12 | verts = verts.data.cpu().numpy()
13 | gt_bbox = np.stack([verts.min(axis=0), verts.max(axis=0)], axis=0)
14 | gt_center = (gt_bbox[0] + gt_bbox[1]) * 0.5
15 | gt_scale = (gt_bbox[1] - gt_bbox[0]).max()
16 |
17 | mesh_extractor = mise.MISE(res_init, res_up, level_set)
18 | points = mesh_extractor.query()
19 |
20 | # query occupancy grid
21 | with torch.no_grad():
22 | while points.shape[0] != 0:
23 |
24 | orig_points = points
25 | points = points.astype(np.float32)
26 | points = (points / mesh_extractor.resolution - 0.5) * scale
27 | points = points * gt_scale + gt_center
28 | points = torch.tensor(points).float().cuda()
29 |
30 | values = func(points.unsqueeze(0))[:,0]
31 | values = values.data.cpu().numpy().astype(np.float64)
32 |
33 | mesh_extractor.update(orig_points, values)
34 |
35 | points = mesh_extractor.query()
36 |
37 | value_grid = mesh_extractor.to_dense()
38 | # value_grid = np.pad(value_grid, 1, "constant", constant_values=-1e6)
39 |
40 | # marching cube
41 | verts, faces, normals, values = measure.marching_cubes_lewiner(
42 | volume=value_grid,
43 | gradient_direction='ascent',
44 | level=min(level_set, value_grid.max()))
45 |
46 | verts = (verts / mesh_extractor.resolution - 0.5) * scale
47 | verts = verts * gt_scale + gt_center
48 |
49 | meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values)
50 |
51 | # remove disconnected parts (keep only the largest component)
52 | connected_comp = meshexport.split(only_watertight=False)
53 | max_area = 0
54 | max_comp = None
55 | for comp in connected_comp:
56 | if comp.area > max_area:
57 | max_area = comp.area
58 | max_comp = comp
59 | meshexport = max_comp
60 |
61 | return meshexport
--------------------------------------------------------------------------------
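The occupancy queries and the marching-cubes vertices in generate_mesh above live on the MISE integer grid; both are mapped into the padded bounding box of the input vertices with the same affine transform. A small sketch of that mapping and its inverse (helper names are illustrative), assuming res, scale, gt_scale and gt_center as defined in the function:

import numpy as np

def grid_to_world(points, res, scale, gt_scale, gt_center):
    # integer grid coordinates in [0, res] -> world coordinates in the padded bbox
    return ((points / res) - 0.5) * scale * gt_scale + gt_center

def world_to_grid(points, res, scale, gt_scale, gt_center):
    # inverse mapping, useful for checking that a world-space point lands back on the grid
    return ((points - gt_center) / (scale * gt_scale) + 0.5) * res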
/lib/utils/render.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import cv2
5 |
6 | from pytorch3d.renderer import (
7 | FoVOrthographicCameras,
8 | RasterizationSettings,
9 | MeshRenderer,
10 | MeshRasterizer,
11 | HardPhongShader,
12 | PointLights
13 | )
14 | from pytorch3d.structures import Meshes
15 | from pytorch3d.renderer.mesh import Textures
16 |
17 |
18 | class Renderer():
19 | def __init__(self, image_size=512):
20 | super().__init__()
21 |
22 | self.image_size = image_size
23 |
24 | self.device = torch.device("cuda:0")
25 | torch.cuda.set_device(self.device)
26 |
27 | R = torch.from_numpy(np.array([[-1., 0., 0.],
28 | [0., 1., 0.],
29 | [0., 0., -1.]])).cuda().float().unsqueeze(0)
30 |
31 |
32 | t = torch.from_numpy(np.array([[0., 0.3, 5.]])).cuda().float()
33 |
34 | self.cameras = FoVOrthographicCameras(R=R, T=t,device=self.device)
35 |
36 | self.lights = PointLights(device=self.device,location=[[0.0, 0.0, 3.0]],
37 | ambient_color=((1,1,1),),diffuse_color=((0,0,0),),specular_color=((0,0,0),))
38 |
39 | self.raster_settings = RasterizationSettings(image_size=image_size,faces_per_pixel=100,blur_radius=0)
40 | self.rasterizer = MeshRasterizer(cameras=self.cameras, raster_settings=self.raster_settings)
41 |
42 | self.shader = HardPhongShader(device=self.device, cameras=self.cameras, lights=self.lights)
43 |
44 | self.renderer = MeshRenderer(rasterizer=self.rasterizer, shader=self.shader)
45 |
46 | def render_mesh(self, verts, faces, colors=None, mode='npat'):
47 | '''
48 | mode: 'n' normal, 'p' shaded, 'a' albedo, 't' albedo*shading
49 | '''
50 | with torch.no_grad():
51 |
52 | mesh = Meshes(verts, faces)
53 |
54 | normals = torch.stack(mesh.verts_normals_list())
55 | front_light = torch.tensor([0,0,1]).float().to(verts.device)
56 | shades = (normals * front_light.view(1,1,3)).sum(-1).clamp(min=0).unsqueeze(-1).expand(-1,-1,3)
57 | results = []
58 |
59 | # normal
60 | if 'n' in mode:
61 | normals_vis = normals* 0.5 + 0.5
62 | mesh_normal = Meshes(verts, faces, textures=Textures(verts_rgb=normals_vis))
63 | image_normal = self.renderer(mesh_normal)
64 | results.append(image_normal)
65 |
66 | # shading
67 | if 'p' in mode:
68 | mesh_shading = Meshes(verts, faces, textures=Textures(verts_rgb=shades))
69 | image_phong = self.renderer(mesh_shading)
70 | results.append(image_phong)
71 |
72 | # albedo
73 | if 'a' in mode:
74 | assert(colors is not None)
76 | mesh_albedo = Meshes(verts, faces, textures=Textures(verts_rgb=colors))
77 | image_color = self.renderer(mesh_albedo)
77 | results.append(image_color)
78 |
79 | # albedo*shading
80 | if 't' in mode:
81 | assert(colors is not None)
82 | mesh_texture = Meshes(verts, faces, textures=Textures(verts_rgb=colors*shades))
83 | image_color = self.renderer(mesh_texture)
84 | results.append(image_color)
85 |
86 | return torch.cat(results, axis=1)
87 |
88 | image_size = 512
89 | torch.cuda.set_device(torch.device("cuda:0"))
90 | renderer = Renderer(image_size)
91 |
92 | def render(verts, faces, colors=None):
93 | return renderer.render_mesh(verts, faces, colors)
94 |
95 | def render_trimesh(mesh, mode='npta'):
96 | verts = torch.tensor(mesh.vertices).cuda().float()[None]
97 | faces = torch.tensor(mesh.faces).cuda()[None]
98 | colors = torch.tensor(mesh.visual.vertex_colors).float().cuda()[None,...,:3]/255
99 | image = renderer.render_mesh(verts, faces, colors=colors, mode=mode)[0]
100 | image = (255*image).data.cpu().numpy().astype(np.uint8)
101 | return image
102 |
103 |
104 | def render_joint(smpl_jnts, bone_ids):
105 | marker_sz = 6
106 | line_wd = 2
107 |
108 | image = np.ones((image_size, image_size,3), dtype=np.uint8)*255
109 | smpl_jnts[:,1] += 0.3
110 | smpl_jnts[:,1] = -smpl_jnts[:,1]
111 | smpl_jnts = smpl_jnts[:,:2]*image_size/2 + image_size/2
112 |
113 | for b in bone_ids:
114 | if b[0]<0 : continue
115 | joint = smpl_jnts[b[0]]
116 | cv2.circle(image, joint.astype('int32'), color=(0,0,0), radius=marker_sz, thickness=-1)
117 |
118 | joint2 = smpl_jnts[b[1]]
119 | cv2.circle(image, joint2.astype('int32'), color=(0,0,0), radius=marker_sz, thickness=-1)
120 |
121 | cv2.line(image, joint2.astype('int32'), joint.astype('int32'), color=(0,0,0), thickness=int(line_wd))
122 |
123 | return image
124 |
125 |
126 |
127 | def weights2colors(weights):
128 | import matplotlib.pyplot as plt
129 |
130 | cmap = plt.get_cmap('Paired')
131 |
132 | colors = [ 'pink', #0
133 | 'blue', #1
134 | 'green', #2
135 | 'red', #3
136 | 'pink', #4
137 | 'pink', #5
138 | 'pink', #6
139 | 'green', #7
140 | 'blue', #8
141 | 'red', #9
142 | 'pink', #10
143 | 'pink', #11
144 | 'pink', #12
145 | 'blue', #13
146 | 'green', #14
147 | 'red', #15
148 | 'cyan', #16
149 | 'darkgreen', #17
150 | 'pink', #18
151 | 'pink', #19
152 | 'blue', #20
153 | 'green', #21
154 | 'pink', #22
155 | 'pink' #23
156 | ]
157 |
158 |
159 | color_mapping = {'cyan': cmap.colors[3],
160 | 'blue': cmap.colors[1],
161 | 'darkgreen': cmap.colors[1],
162 | 'green':cmap.colors[3],
163 | 'pink': [1,1,1],
164 | 'red':cmap.colors[5],
165 | }
166 |
167 | for i in range(len(colors)):
168 | colors[i] = np.array(color_mapping[colors[i]])
169 |
170 | colors = np.stack(colors)[None]# [1x24x3]
171 | verts_colors = weights[:,:,None] * colors
172 | verts_colors = verts_colors.sum(1)
173 | return verts_colors
--------------------------------------------------------------------------------
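weights2colors above colors each vertex by blending one RGB color per SMPL bone with the vertex's skinning weights, i.e. a weighted sum over the 24 bones. The same blend for an arbitrary palette (shapes only; illustrative, not part of the repo):

import numpy as np

def blend_weight_colors(weights, palette):
    # weights: [V, J] skinning weights, palette: [J, 3] one RGB color per bone
    # equivalent to (weights[:, :, None] * palette[None]).sum(1)
    return weights @ palette  # [V, 3] per-vertex colors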
/preprocess/body_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the
5 | # Max Planck Institute for Biological Cybernetics. All rights reserved.
6 | #
7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
8 | # on this computer program. You can only use this computer program if you have closed a license agreement
9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution.
11 | # Contact: ps-license@tuebingen.mpg.de
12 | #
13 | #
14 | # If you use this code in a research publication please consider citing the following:
15 | #
16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image
17 | #
18 | #
19 | # Code Developed by:
20 | # Nima Ghorbani
21 | #
22 | # 2018.12.13
23 |
24 | import numpy as np
25 |
26 | import torch
27 | import torch.nn as nn
28 |
29 | # from smplx.lbs import lbs
30 | from lbs import lbs
31 | import pickle
32 |
33 | class Struct(object):
34 | def __init__(self, **kwargs):
35 | for key, val in kwargs.items():
36 | setattr(self, key, val)
37 |
38 | def to_tensor(array, dtype=torch.float32):
39 | if not torch.is_tensor(array):
40 | return torch.tensor(array, dtype=dtype)
41 | return array
42 | def to_np(array, dtype=np.float32):
43 | if 'scipy.sparse' in str(type(array)):
44 | array = array.todense()
45 | return np.array(array, dtype=dtype)
46 |
47 | class BodyModel(nn.Module):
48 |
49 | def __init__(self,
50 | bm_path,
51 | params=None,
52 | num_betas=10,
53 | batch_size=1, v_template = None,
54 | num_dmpls=None, path_dmpl=None,
55 | num_expressions=10,
56 | use_posedirs=True,
57 | dtype=torch.float32):
58 |
59 | super(BodyModel, self).__init__()
60 |
61 | '''
62 | :param bm_path: path to a SMPL model as pkl file
63 | :param num_betas: number of shape parameters to include.
64 | if betas are provided in params, num_betas is overridden by the number of those betas
65 | :param batch_size: number of SMPL bodies (batch dimension) to create
66 | :param device: default on gpu
67 | :param dtype: float precision of the computations
68 | :return: verts, trans, pose, betas
69 | '''
70 | # TODO: if params is given, the batch size should be read from one of the params
71 |
72 | self.dtype = dtype
73 |
74 | if params is None: params = {}
75 | # -- Load SMPL params --
76 | if '.npz' in bm_path:
77 | smpl_dict = np.load(bm_path, encoding='latin1')
78 | elif '.pkl' in bm_path:
79 | with open(bm_path, 'rb') as smpl_file:
80 | smpl_dict = Struct(**pickle.load(smpl_file,encoding='latin1'))
81 | else:
82 | raise ValueError('bm_path should be either a .pkl or a .npz file')
83 |
84 | njoints = smpl_dict.posedirs.shape[2] // 3
85 | self.model_type = {69: 'smpl', 153: 'smplh', 162: 'smplx', 45: 'mano'}[njoints]
86 |
87 | assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano'], ValueError(
88 | 'model_type should be in smpl/smplh/smplx/mano.')
89 |
90 | self.use_dmpl = False
91 | if num_dmpls is not None:
92 | if path_dmpl is not None:
93 | self.use_dmpl = True
94 | else:
95 | raise (ValueError('path_dmpl should be provided when using dmpls!'))
96 |
97 | if self.use_dmpl and self.model_type in ['smplx', 'mano']: raise (
98 | NotImplementedError('DMPLs only work with SMPL/SMPLH models for now.'))
99 |
100 | # Mean template vertices
101 | if v_template is None:
102 | v_template = np.repeat(smpl_dict.v_template[np.newaxis], batch_size, axis=0)
103 | else:
104 | v_template = np.repeat(v_template[np.newaxis], batch_size, axis=0)
105 |
106 | self.register_buffer('v_template', torch.tensor(v_template, dtype=dtype))
107 |
108 | self.register_buffer('f', torch.tensor(smpl_dict.f.astype(np.int32), dtype=torch.int32))
109 |
110 | if len(params):
111 | if 'betas' in params.keys():
112 | num_betas = params['betas'].shape[1]
113 | if 'dmpls' in params.keys():
114 | num_dmpls = params['dmpls'].shape[1]
115 |
116 | num_total_betas = smpl_dict.shapedirs.shape[-1]
117 | if num_betas < 1:
118 | num_betas = num_total_betas
119 |
120 | shapedirs = smpl_dict.shapedirs[:, :, :num_betas]
121 |
122 | self.register_buffer('shapedirs', torch.tensor(to_np(shapedirs), dtype=dtype))
123 |
124 | if self.model_type == 'smplx':
125 | begin_shape_id = 300 if smpl_dict.shapedirs.shape[-1] > 300 else 10
126 | exprdirs = smpl_dict.shapedirs[:, :, begin_shape_id:(begin_shape_id + num_expressions)]
127 | self.register_buffer('exprdirs', torch.tensor(exprdirs, dtype=dtype))
128 |
129 | expression = torch.tensor(np.zeros((batch_size, num_expressions)), dtype=dtype, requires_grad=True)
130 | self.register_parameter('expression', nn.Parameter(expression, requires_grad=True))
131 |
132 | if self.use_dmpl:
133 | dmpldirs = np.load(path_dmpl)['eigvec']
134 |
135 | dmpldirs = dmpldirs[:, :, :num_dmpls]
136 | self.register_buffer('dmpldirs', torch.tensor(dmpldirs, dtype=dtype))
137 |
138 | # Regressor for joint locations given shape - 6890 x 24
139 | self.register_buffer('J_regressor', to_tensor(to_np(
140 | smpl_dict.J_regressor), dtype=dtype))
141 |
142 | # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*30 x 207
143 | if use_posedirs:
144 | posedirs = smpl_dict.posedirs
145 | posedirs = posedirs.reshape([posedirs.shape[0] * 3, -1]).T
146 | self.register_buffer('posedirs', torch.tensor(posedirs, dtype=dtype))
147 | else:
148 | self.posedirs = None
149 |
150 | # indices of parents for each joints
151 | kintree_table = smpl_dict.kintree_table.astype(np.int32)
152 | self.register_buffer('kintree_table', torch.tensor(kintree_table, dtype=torch.int32))
153 |
154 | # LBS weights
155 | # weights = np.repeat(smpl_dict.weights[np.newaxis], batch_size, axis=0)
156 | weights = smpl_dict.weights
157 | self.register_buffer('weights', torch.tensor(weights, dtype=dtype))
158 |
159 | if 'trans' in params.keys():
160 | trans = params['trans']
161 | else:
162 | trans = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True)
163 | self.register_parameter('trans', nn.Parameter(trans, requires_grad=True))
164 |
165 | # root_orient
166 | # if self.model_type in ['smpl', 'smplh']:
167 | root_orient = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True)
168 | self.register_parameter('root_orient', nn.Parameter(root_orient, requires_grad=True))
169 |
170 | # pose_body
171 | if self.model_type in ['smpl', 'smplh', 'smplx']:
172 | pose_body = torch.tensor(np.zeros((batch_size, 63)), dtype=dtype, requires_grad=True)
173 | self.register_parameter('pose_body', nn.Parameter(pose_body, requires_grad=True))
174 |
175 | # pose_hand
176 | if 'pose_hand' in params.keys():
177 | pose_hand = params['pose_hand']
178 | else:
179 | if self.model_type in ['smpl']:
180 | pose_hand = torch.tensor(np.zeros((batch_size, 1 * 3 * 2)), dtype=dtype, requires_grad=True)
181 | elif self.model_type in ['smplh', 'smplx']:
182 | pose_hand = torch.tensor(np.zeros((batch_size, 15 * 3 * 2)), dtype=dtype, requires_grad=True)
183 | elif self.model_type in ['mano']:
184 | pose_hand = torch.tensor(np.zeros((batch_size, 15 * 3)), dtype=dtype, requires_grad=True)
185 | self.register_parameter('pose_hand', nn.Parameter(pose_hand, requires_grad=True))
186 |
187 | # face poses
188 | if self.model_type == 'smplx':
189 | pose_jaw = torch.tensor(np.zeros((batch_size, 1 * 3)), dtype=dtype, requires_grad=True)
190 | self.register_parameter('pose_jaw', nn.Parameter(pose_jaw, requires_grad=True))
191 | pose_eye = torch.tensor(np.zeros((batch_size, 2 * 3)), dtype=dtype, requires_grad=True)
192 | self.register_parameter('pose_eye', nn.Parameter(pose_eye, requires_grad=True))
193 |
194 | if 'betas' in params.keys():
195 | betas = params['betas']
196 | else:
197 | betas = torch.tensor(np.zeros((batch_size, num_betas)), dtype=dtype, requires_grad=True)
198 | self.register_parameter('betas', nn.Parameter(betas, requires_grad=True))
199 |
200 | if self.use_dmpl:
201 | if 'dmpls' in params.keys():
202 | dmpls = params['dmpls']
203 | else:
204 | dmpls = torch.tensor(np.zeros((batch_size, num_dmpls)), dtype=dtype, requires_grad=True)
205 | self.register_parameter('dmpls', nn.Parameter(dmpls, requires_grad=True))
206 | self.batch_size = batch_size
207 |
208 | def r(self):
209 | from human_body_prior.tools.omni_tools import copy2cpu as c2c
210 | return c2c(self.forward().v)
211 |
212 | def forward(self, root_orient=None, pose_body=None, pose_hand=None, pose_jaw=None, pose_eye=None, betas=None,
213 | trans=None, dmpls=None, expression=None, return_dict=False, v_template =None, **kwargs):
214 | '''
215 |
216 | :param root_orient: Nx3
217 | :param pose_body:
218 | :param pose_hand:
219 | :param pose_jaw:
220 | :param pose_eye:
221 | :param kwargs:
222 | :return:
223 | '''
224 | assert not (v_template is not None and betas is not None), ValueError('v_template and betas cannot be used jointly.')
225 | assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano'], ValueError(
226 | 'model_type should be in smpl/smplh/smplx/mano')
227 | if root_orient is None: root_orient = self.root_orient
228 | if self.model_type in ['smplh', 'smpl']:
229 | if pose_body is None: pose_body = self.pose_body
230 | if pose_hand is None: pose_hand = self.pose_hand
231 | elif self.model_type == 'smplx':
232 | if pose_body is None: pose_body = self.pose_body
233 | if pose_hand is None: pose_hand = self.pose_hand
234 | if pose_jaw is None: pose_jaw = self.pose_jaw
235 | if pose_eye is None: pose_eye = self.pose_eye
236 | elif self.model_type == 'mano':
237 | if pose_hand is None: pose_hand = self.pose_hand
238 |
239 | if pose_hand is None: pose_hand = self.pose_hand
240 |
241 | if trans is None: trans = self.trans
242 | if v_template is None: v_template = self.v_template
243 | if betas is None: betas = self.betas
244 |
245 | if v_template.size(0) != pose_body.size(0):
246 | v_template = v_template[:pose_body.size(0)] # this is fine since actual batch size will
247 | # only be equal to or less than specified batch
248 | # size
249 |
250 | if self.model_type in ['smplh', 'smpl']:
251 | full_pose = torch.cat([root_orient, pose_body, pose_hand], dim=1)
252 | elif self.model_type == 'smplx':
253 | full_pose = torch.cat([root_orient, pose_body, pose_jaw, pose_eye, pose_hand],
254 | dim=1) # orient:3, body:63, jaw:3, eyel:3, eyer:3, handl, handr
255 | elif self.model_type == 'mano':
256 | full_pose = torch.cat([root_orient, pose_hand], dim=1)
257 |
258 | if self.use_dmpl:
259 | if dmpls is None: dmpls = self.dmpls
260 | shape_components = torch.cat([betas, dmpls], dim=-1)
261 | shapedirs = torch.cat([self.shapedirs, self.dmpldirs], dim=-1)
262 | elif self.model_type == 'smplx':
263 | if expression is None: expression = self.expression
264 | shape_components = torch.cat([betas, expression], dim=-1)
265 | shapedirs = torch.cat([self.shapedirs, self.exprdirs], dim=-1)
266 | else:
267 | shape_components = betas
268 | shapedirs = self.shapedirs
269 |
270 |
271 | verts, joints, bone_transforms = lbs(betas=shape_components, pose=full_pose, v_template=v_template,
272 | shapedirs=shapedirs, posedirs=self.posedirs,
273 | J_regressor=self.J_regressor, parents=self.kintree_table[0].long(),
274 | lbs_weights=self.weights,
275 | dtype=self.dtype)
276 |
277 | Jtr = joints + trans.unsqueeze(dim=1)
278 | verts = verts + trans.unsqueeze(dim=1)
279 |
280 | res = {}
281 | res['v'] = verts
282 | res['f'] = self.f
283 | res['bone_transforms'] = bone_transforms
284 | res['betas'] = self.betas
285 | res['Jtr'] = Jtr # Todo: ik can be made with vposer
286 |
287 | if self.model_type == 'smpl':
288 | res['pose_body'] = pose_body
289 | elif self.model_type == 'smplh':
290 | res['pose_body'] = pose_body
291 | res['pose_hand'] = pose_hand
292 | elif self.model_type == 'smplx':
293 | res['pose_body'] = pose_body
294 | res['pose_hand'] = pose_hand
295 | res['pose_jaw'] = pose_jaw
296 | res['pose_eye'] = pose_eye
297 | elif self.model_type == 'mano':
298 | res['pose_hand'] = pose_hand
299 | res['full_pose'] = full_pose
300 |
301 | if not return_dict:
302 | class result_meta(object):
303 | pass
304 |
305 | res_class = result_meta()
306 | for k, v in res.items():
307 | res_class.__setattr__(k, v)
308 | res = res_class
309 |
310 | return res
311 |
312 |
313 |
--------------------------------------------------------------------------------
/preprocess/lbs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG),
4 | # acting on behalf of its Max Planck Institute for Intelligent Systems and the
5 | # Max Planck Institute for Biological Cybernetics. All rights reserved.
6 | #
7 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights
8 | # on this computer program. You can only use this computer program if you have closed a license agreement
9 | # with MPG or you get the right to use the computer program from someone who is authorized to grant you that right.
10 | # Any use of the computer program without a valid license is prohibited and liable to prosecution.
11 | # Contact: ps-license@tuebingen.mpg.de
12 | #
13 | #
14 | # If you use this code in a research publication please consider citing the following:
15 | #
16 | # Expressive Body Capture: 3D Hands, Face, and Body from a Single Image
17 | #
18 | # Code Developed by:
19 | # Vassilis Choutas
20 | # For the original and better implementation please refer to : https://github.com/vchoutas/smplx
21 | #
22 | # 2018.01.02
23 |
24 | from __future__ import absolute_import
25 | from __future__ import print_function
26 | from __future__ import division
27 |
28 | import time
29 | import torch
30 | import torch.nn.functional as F
31 |
32 | # TODO: Create merged c++ and CUDA kernel
33 | # TODO: Add argument to choose between python impl and c++/CUDA merged op
34 | def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
35 | lbs_weights, num_joints=23, dtype=torch.float32):
36 | ''' Performs Linear Blend Skinning with the given shape and pose parameters
37 |
38 | Parameters
39 | ----------
40 | betas : torch.tensor BxNB
41 | The tensor of shape parameters
42 | pose : torch.tensor Bx(J + 1) * 3
43 | The pose parameters in axis-angle format
44 | v_template torch.tensor BxVx3
45 | The template mesh that will be deformed
46 | shapedirs : torch.tensor 1xNB
47 | The tensor of PCA shape displacements
48 | posedirs : torch.tensor Px(V * 3)
49 | The pose PCA coefficients
50 | J_regressor : torch.tensor JxV
51 | The regressor array that is used to calculate the joints from
52 | the position of the vertices
53 | parents: torch.tensor J
54 | The array that describes the kinematic tree for the model
55 | lbs_weights: torch.tensor N x V x (J + 1)
56 | The linear blend skinning weights that represent how much the
57 | rotation matrix of each part affects each vertex
58 | pose2rot: bool, optional
59 | Flag on whether to convert the input pose tensor to rotation
60 | matrices. The default value is True. If False, then the pose tensor
61 | should already contain rotation matrices and have a size of
62 | Bx(J + 1)x9
63 | num_joints : int, optional
64 | The number of joints of the model. The default value is equal
65 | to the number of joints of the SMPL body model
66 | dtype: torch.dtype, optional
67 |
68 | Returns
69 | -------
70 | verts: torch.tensor BxVx3
71 | The vertices of the mesh after applying the shape and pose
72 | displacements.
73 | joints: torch.tensor BxJx3
74 | The joints of the model
75 | '''
76 |
77 | batch_size = betas.shape[0]
78 | device = betas.device
79 |
80 | # Add shape contribution
81 | v_shaped = v_template + blend_shapes(betas, shapedirs)
82 |
83 | # Get the joints
84 | # NxJx3 array
85 | J = vertices2joints(J_regressor, v_shaped)
86 |
87 | # import numpy as np
88 |
89 | rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
90 |
91 | if posedirs is not None:
92 | # 3. Add pose blend shapes
93 | # N x J x 3 x 3
94 | ident = torch.eye(3, dtype=dtype, device=device)
95 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
96 |
97 | # (N x P) x (P, V * 3) -> N x V x 3
98 | pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
99 | v_posed = pose_offsets + v_shaped
100 | else:
101 | v_posed = v_shaped
102 |
103 | # 4. Get the global joint location
104 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
105 |
106 | # 5. Do skinning:
107 | # W is N x V x (J + 1)
108 | W = lbs_weights.unsqueeze(dim=0).repeat([batch_size, 1, 1])
109 | num_joints = J_regressor.shape[0]
110 | # (N x V x (J + 1)) x (N x (J + 1) x 16)
111 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
112 |
113 | homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device)
114 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
115 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
116 |
117 | verts = v_homo[:, :, :3, 0]
118 |
119 | return verts, J_transformed, A
120 |
121 |
122 | def vertices2joints(J_regressor, vertices):
123 | ''' Calculates the 3D joint locations from the vertices
124 |
125 | Parameters
126 | ----------
127 | J_regressor : torch.tensor JxV
128 | The regressor array that is used to calculate the joints from the
129 | position of the vertices
130 | vertices : torch.tensor BxVx3
131 | The tensor of mesh vertices
132 |
133 | Returns
134 | -------
135 | torch.tensor BxJx3
136 | The location of the joints
137 | '''
138 |
139 | return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
140 |
141 |
142 | def blend_shapes(betas, shape_disps):
143 | ''' Calculates the per vertex displacement due to the blend shapes
144 |
145 |
146 | Parameters
147 | ----------
148 | betas : torch.tensor Bx(num_betas)
149 | Blend shape coefficients
150 | shape_disps: torch.tensor Vx3x(num_betas)
151 | Blend shapes
152 |
153 | Returns
154 | -------
155 | torch.tensor BxVx3
156 | The per-vertex displacement due to shape deformation
157 | '''
158 |
159 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
160 | # i.e. Multiply each shape displacement by its corresponding beta and
161 | # then sum them.
162 | blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
163 | return blend_shape
164 |
165 |
166 | def batch_rodrigues(aa_rots):
167 | '''
168 | convert batch of rotations in axis-angle representation to matrix representation
169 | :param aa_rots: Nx3
170 | :return: mat_rots: Nx3x3
171 | '''
172 |
173 | dtype = aa_rots.dtype
174 | device = aa_rots.device
175 |
176 | batch_size = aa_rots.shape[0]
177 |
178 | angle = torch.norm(aa_rots + 1e-8, dim=1, keepdim=True)
179 | rot_dir = aa_rots / angle
180 |
181 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
182 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
183 |
184 | # Bx1 arrays
185 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
186 |
187 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
188 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
189 | .view((batch_size, 3, 3))
190 |
191 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
192 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
193 | return rot_mat
194 |
195 |
196 | def transform_mat(R, t):
197 | ''' Creates a batch of transformation matrices
198 | Args:
199 | - R: Bx3x3 array of a batch of rotation matrices
200 | - t: Bx3x1 array of a batch of translation vectors
201 | Returns:
202 | - T: Bx4x4 Transformation matrix
203 | '''
204 | # No padding left or right, only add an extra row
205 | return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
206 |
207 |
208 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
209 | """
210 | Applies a batch of rigid transformations to the joints
211 |
212 | Parameters
213 | ----------
214 | rot_mats : torch.tensor BxNx3x3
215 | Tensor of rotation matrices
216 | joints : torch.tensor BxNx3
217 | Locations of joints
218 | parents : torch.tensor BxN
219 | The kinematic tree of each object
220 | dtype : torch.dtype, optional:
221 | The data type of the created tensors, the default is torch.float32
222 |
223 | Returns
224 | -------
225 | posed_joints : torch.tensor BxNx3
226 | The locations of the joints after applying the pose rotations
227 | rel_transforms : torch.tensor BxNx4x4
228 | The relative (with respect to the root joint) rigid transformations
229 | for all the joints
230 | """
231 |
232 | batch_size = rot_mats.shape[0]
233 | num_joints = joints.shape[1]
234 | device = rot_mats.device
235 |
236 | joints = torch.unsqueeze(joints, dim=-1)
237 |
238 | rel_joints = joints.clone()
239 | rel_joints[:, 1:] -= joints[:, parents[1:]]
240 |
241 | transforms_mat = transform_mat(
242 | rot_mats.reshape(-1, 3, 3),
243 | rel_joints.reshape(-1, 3, 1)).view(-1, joints.shape[1], 4, 4)
244 |
245 | transform_chain = [transforms_mat[:, 0]]
246 | for i in range(1, parents.shape[0]):
247 | # Subtract the joint location at the rest pose
248 | # No need for rotation, since it's identity when at rest
249 | curr_res = torch.matmul(transform_chain[parents[i]],
250 | transforms_mat[:, i])
251 | transform_chain.append(curr_res)
252 |
253 | transforms = torch.stack(transform_chain, dim=1)
254 |
255 | # The last column of the transformations contains the posed joints
256 | posed_joints = transforms[:, :, :3, 3]
257 |
258 | joints_homogen = torch.cat([joints, torch.zeros([batch_size, num_joints, 1, 1], dtype=dtype, device=device)],dim=2)
259 | init_bone = torch.matmul(transforms, joints_homogen)
260 | init_bone = F.pad(init_bone, [3, 0, 0, 0, 0, 0, 0, 0])
261 | rel_transforms = transforms - init_bone
262 |
263 | return posed_joints, rel_transforms
264 |
--------------------------------------------------------------------------------
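batch_rodrigues above implements the Rodrigues formula R = I + sin(theta) K + (1 - cos(theta)) K^2, with K the skew-symmetric matrix of the rotation axis. A quick stand-alone sanity check (illustrative, not part of the repo):

import math
import torch

def rodrigues_single(aa):
    # aa: length-3 axis-angle vector; returns the 3x3 rotation matrix
    angle = aa.norm() + 1e-8
    x, y, z = (aa / angle).tolist()
    K = torch.tensor([[0., -z, y], [z, 0., -x], [-y, x, 0.]])
    return torch.eye(3) + angle.sin() * K + (1 - angle.cos()) * (K @ K)

# 90 degrees about z should map the x-axis onto the y-axis
R = rodrigues_single(torch.tensor([0., 0., math.pi / 2]))
assert torch.allclose(R @ torch.tensor([1., 0., 0.]), torch.tensor([0., 1., 0.]), atol=1e-6)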
/preprocess/sample_points.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import torch
5 | import trimesh
6 |
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 | from body_model import BodyModel
11 | from utils import export_points
12 |
13 | from tqdm import tqdm,trange
14 | from shutil import copyfile
15 |
16 | parser = argparse.ArgumentParser('Read and sample AMASS dataset.')
17 | parser.add_argument('--dataset_path', type=str, default='data/',
18 | help='Path to AMASS dataset.')
19 |
20 | parser.add_argument('--poseprior', action='store_true',
21 | help='Generate data for the pose prior dataset')
22 |
23 | parser.add_argument('--bm_path', type=str, default='lib/smpl/smpl_model',
24 | help='Path to body model')
25 |
26 | parser.add_argument('--bbox_padding', type=float, default=0.,
27 | help='Padding for bounding box')
28 |
29 | parser.add_argument('--output_folder', type=str, default='data/DFaust_processed',
30 | help='Output path for points.')
31 |
32 | parser.add_argument('--points_size', type=int, default=200000,
33 | help='Size of points.')
34 | parser.add_argument('--points_uniform_ratio', type=float, default=.5,
35 | help='Ratio of points to sample uniformly '
36 | 'in the bounding box.')
37 | parser.add_argument('--points_sigma', type=float, default=0.01,
38 | help='Standard deviation of Gaussian noise added to point '
39 | 'samples on the surface.')
40 | parser.add_argument('--points_padding', type=float, default=0.1,
41 | help='Additional padding applied to the uniformly '
42 | 'sampled points on both sides (in total).')
43 |
44 | parser.add_argument('--overwrite', action='store_true',
45 | help='Whether to overwrite output.')
46 | parser.add_argument('--float16', action='store_true',
47 | help='Whether to use half precision.')
48 | parser.add_argument('--packbits', action='store_true',
49 | help='Whether to save truth values as bit array.')
50 |
51 | parser.add_argument('--skip', type=int, default=1,
52 | help='Take every x frames.')
53 |
54 | def process_single_file(vertices, root_orient, pose, joints, bone_transforms, root_loc, frame_name, betas, gender, skinning_weights_vertices, faces, subset, args):
55 | body_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
56 | # Get extents of model.
57 | bb_min = np.min(vertices, axis=0)
58 | bb_max = np.max(vertices, axis=0)
59 | # total_size = np.sqrt(np.square(bb_max - bb_min).sum())
60 | total_size = (bb_max - bb_min).max()
61 |
62 | # Set the center (although this should usually be the origin already).
63 | loc = np.array(
64 | [(bb_min[0] + bb_max[0]) / 2,
65 | (bb_min[1] + bb_max[1]) / 2,
66 | (bb_min[2] + bb_max[2]) / 2]
67 | )
68 | # Scales all dimensions equally.
69 | scale = total_size / (1 - args.bbox_padding)
70 | export_points(body_mesh, subset, frame_name, loc, scale, args, joints=joints, bone_transforms=bone_transforms, root_loc=root_loc, root_orient=root_orient, pose=pose, betas=betas, gender=gender, skinning_weights_vertices=skinning_weights_vertices, vertices=vertices, faces=faces)
71 |
72 |
73 | def amass_extract(args):
74 | dfaust_dir = os.path.join(args.dataset_path, 'DFaust_67')
75 | subjects = [os.path.basename(s_dir) for s_dir in sorted(glob.glob(os.path.join(dfaust_dir, '*')))]
76 |
77 | for subject in tqdm(subjects):
78 | subject_dir = os.path.join(dfaust_dir, subject)
79 |
80 | shape_data = np.load(os.path.join(subject_dir, 'shape.npz'))
81 |
82 | # Save shape data
83 | output_shape_folder = os.path.join(args.output_folder, 'shapes')
84 | if not os.path.exists(output_shape_folder):
85 | os.makedirs(output_shape_folder)
86 | copyfile(os.path.join(subject_dir, 'shape.npz'), os.path.join(output_shape_folder, '%s_shape.npz'%subject))
87 |
88 | # Generate and save rest-pose for current subject
89 | gender = shape_data['gender'].item()
90 | betas = torch.Tensor(shape_data['betas'][:10]).unsqueeze(0).cuda()
91 | bm_path = os.path.join(args.bm_path, 'SMPL_%s.pkl'%(gender.upper()))
92 | bm = BodyModel(bm_path=bm_path, num_betas=10, batch_size=1).cuda()
93 |
94 | # Get skinning weights
95 | with torch.no_grad():
96 | body = bm(betas=betas)
97 | vertices = body.v.detach().cpu().numpy()[0]
98 | faces = bm.f.detach().cpu().numpy()
99 |
100 | skinning_weights_vertices = bm.weights
101 | skinning_weights_vertices = skinning_weights_vertices.detach().cpu().numpy()
102 |
103 | # Read pose sequences
104 | sequences = []
105 | if args.poseprior:
106 | pose_dir = os.path.join(args.dataset_path, 'MPI_Limits', '03099')
107 | for s_dir in glob.glob(os.path.join(pose_dir, '*.npz')):
108 | sequence = os.path.basename(s_dir)
109 | if sequence in ['op2_poses.npz', 'op3_poses.npz', 'op4_poses.npz', 'op5_poses.npz', 'op7_poses.npz', 'op8_poses.npz', 'op9_poses.npz']:
110 | sequences.append(sequence)
111 | else:
112 | pose_dir = subject_dir
113 | for s_dir in glob.glob(os.path.join(pose_dir, '*.npz')):
114 | sequence = os.path.basename(s_dir)
115 | if sequence not in ['shape.npz']:
116 | sequences.append(sequence)
117 |
118 | for sequence in tqdm(sequences):
119 | sequence_path = os.path.join(pose_dir, sequence)
120 | sequence_name = sequence[:]
121 | data = np.load(sequence_path, allow_pickle=True)
122 |
123 | poses = data['poses'][::args.skip]
124 | trans = data['trans'][::args.skip]
125 |
126 | batch_size = poses.shape[0]
127 | bm = BodyModel(bm_path=bm_path, num_betas=10, batch_size=batch_size).cuda()
128 | faces = bm.f.detach().cpu().numpy()
129 |
130 | pose_body = torch.Tensor(poses[:, 3:66]).cuda()
131 | pose_hand = torch.Tensor(poses[:, 66:72]).cuda()
132 | pose = torch.Tensor(poses[:, :72]).cuda()
133 | root_orient = torch.Tensor(poses[:, :3]).cuda()
134 | trans = torch.zeros(batch_size, 3, dtype=torch.float32).cuda()
135 |
136 | with torch.no_grad():
137 |
138 | body = bm(root_orient=root_orient, pose_body=pose_body, pose_hand=pose_hand, betas=betas.expand(batch_size,-1), trans=trans)
139 |
140 | trans_ = F.pad(trans, [0, 1]).view(batch_size, 1, -1, 1)
141 | trans_ = torch.cat([torch.zeros(batch_size, 1, 4, 3, device=trans_.device), trans_], dim=-1)
142 | bone_transforms = body.bone_transforms + trans_
143 | bone_transforms = torch.inverse(bone_transforms).detach().cpu().numpy()
144 | bone_transforms = bone_transforms[:, :, :]
145 |
146 | pose_body = pose_body.detach().cpu().numpy()
147 | pose = pose.detach().cpu().numpy()
148 | joints = body.Jtr.detach().cpu().numpy()
149 | vertices = body.v.detach().cpu().numpy()
150 | trans = trans.detach().cpu().numpy()
151 | root_orient = root_orient.detach().cpu().numpy()
152 |
153 | for f_idx in trange(batch_size):
154 | frame_name = sequence_name + '_{:06d}'.format(f_idx)
155 | process_single_file(vertices[f_idx],
156 | root_orient[f_idx],
157 | pose[f_idx],
158 | joints[f_idx],
159 | bone_transforms[f_idx],
160 | trans[f_idx],
161 | frame_name,
162 | betas[0].detach().cpu().numpy(),
163 | gender,
164 | skinning_weights_vertices,
165 | faces,
166 | subject,
167 | args)
168 |
169 | def main(args):
170 | amass_extract(args)
171 |
172 |
173 | if __name__ == '__main__':
174 | args = parser.parse_args()
175 | main(args)
176 |
--------------------------------------------------------------------------------
/preprocess/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def export_points(mesh, subject, modelname, loc, scale, args, **kwargs):
8 | if not mesh.is_watertight:
9 | print('Warning: mesh %s is not watertight! '
10 | 'Cannot sample points.' % modelname)
11 | return
12 |
13 | kwargs_new = {}
14 | for k, v in kwargs.items():
15 | if v is not None:
16 | kwargs_new[k] = v
17 |
18 | filename = os.path.join(args.output_folder, 'points', subject, modelname + '.npz')
19 |
20 | if not args.overwrite and os.path.exists(filename):
21 | print('Points already exist: %s' % filename)
22 | return
23 |
24 | n_points_uniform = int(args.points_size * args.points_uniform_ratio)
25 | n_points_surface = args.points_size - n_points_uniform
26 |
27 | boxsize = 1 + args.points_padding
28 | points_uniform = np.random.rand(n_points_uniform, 3)
29 | points_uniform = boxsize * (points_uniform - 0.5)
30 | # Scale points in (padded) unit box back to the original space
31 | points_uniform *= scale
32 | points_uniform += np.expand_dims(loc, axis=0)
33 | # Sample points around mesh surface
34 | points_surface = mesh.sample(n_points_surface)
35 | points_surface += args.points_sigma * np.random.randn(n_points_surface, 3)
36 | points = np.concatenate([points_uniform, points_surface], axis=0)
37 | points = torch.tensor(points).cuda().float().unsqueeze(0)
38 | vertices = torch.tensor(kwargs['vertices']).cuda().float().unsqueeze(0)
39 | faces = torch.tensor(kwargs['faces'], dtype=torch.int64).cuda()
40 |
41 | import kaolin
42 | occupancies = kaolin.ops.mesh.check_sign(vertices, faces, points)
43 |
44 |
45 | points = points.cpu().numpy()
46 | occupancies = occupancies.cpu().numpy()
47 |
48 | # Compress
49 | if args.float16:
50 | dtype = np.float16
51 | else:
52 | dtype = np.float32
53 |
54 | points = points.astype(dtype)
55 |
56 | if args.packbits:
57 | occupancies = np.packbits(occupancies)
58 |
59 | if not os.path.exists(os.path.dirname(filename)):
60 | os.makedirs(os.path.dirname(filename))
61 |
62 | # print('Writing points: %s' % filename)
63 | np.savez(filename, points=points, occupancies=occupancies,
64 | loc=loc, scale=scale,
65 | **kwargs_new)
66 |
--------------------------------------------------------------------------------
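export_points above draws half of the query points uniformly inside a padded bounding box and half near the mesh surface with Gaussian noise, then labels them inside/outside with kaolin.ops.mesh.check_sign. A CPU-only sketch of the same sampling scheme that uses trimesh's containment test instead of kaolin (illustrative; the defaults mirror the script's flags):

import numpy as np
import trimesh

def sample_occupancy(mesh, n_points=200000, uniform_ratio=0.5, padding=0.1, sigma=0.01):
    # mesh: a watertight trimesh.Trimesh, already in the frame used by export_points
    n_uniform = int(n_points * uniform_ratio)
    n_surface = n_points - n_uniform
    loc = (mesh.bounds[0] + mesh.bounds[1]) / 2
    scale = (mesh.bounds[1] - mesh.bounds[0]).max()
    points_uniform = (np.random.rand(n_uniform, 3) - 0.5) * (1 + padding) * scale + loc
    points_surface = mesh.sample(n_surface) + sigma * np.random.randn(n_surface, 3)
    points = np.concatenate([points_uniform, points_surface], axis=0)
    occ = mesh.contains(points)  # boolean occupancy labels
    return points.astype(np.float32), occ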
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The TensorFlow Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Set-up script for installing extension modules."""
15 | from Cython.Build import cythonize
16 | import numpy
17 | from setuptools import Extension
18 | from setuptools import setup
19 |
20 | # Get the numpy include directory.
21 | numpy_include_dir = numpy.get_include()
22 |
23 | # mise (efficient mesh extraction)
24 | mise_module = Extension(
25 | "lib.libmise.mise",
26 | sources=["lib/libmise/mise.pyx"],
27 | )
28 |
29 | # Gather all extension modules
30 | ext_modules = [
31 | mise_module,
32 | ]
33 |
34 | setup(ext_modules=cythonize(ext_modules),)
--------------------------------------------------------------------------------
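The Cython extension declared above (lib.libmise.mise) has to be compiled before the mesh extraction utilities can import it; with a standard setuptools/Cython setup this is typically done in place with "python setup.py build_ext --inplace" (the repo's README may prescribe a different invocation).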
/test.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import glob
4 | import hydra
5 | import torch
6 | import numpy as np
7 | import pytorch_lightning as pl
8 |
9 | from lib.snarf_model import SNARFModel
10 |
11 | @hydra.main(config_path="config", config_name="config")
12 | def main(opt):
13 |
14 | print(opt.pretty())
15 | pl.seed_everything(42, workers=True)
16 | torch.set_num_threads(10)
17 |
18 | datamodule = hydra.utils.instantiate(opt.datamodule, opt.datamodule)
19 | datamodule.setup(stage='test')
20 |
21 | trainer = pl.Trainer(**opt.trainer)
22 |
23 | if opt.epoch == 'last':
24 | checkpoint_path = './checkpoints/last.ckpt'
25 | else:
26 | checkpoint_path = glob.glob('./checkpoints/epoch=%d*.ckpt'%opt.epoch)[0]
27 |
28 | model = SNARFModel.load_from_checkpoint(
29 | checkpoint_path=checkpoint_path,
30 | opt=opt.model,
31 | meta_info=datamodule.meta_info
32 | )
33 | # use all bones for initialization during testing
34 | model.deformer.init_bones = np.arange(24)
35 |
36 | results = trainer.test(model, datamodule=datamodule, verbose=True)
37 |
38 | np.savetxt('./results_%s_%s_%s.txt'%(os.path.basename(opt.datamodule.dataset_path),opt.datamodule.subject, str(opt.epoch)), np.array([results[0]['valid_bbox_iou'], results[0]['valid_surf_iou']]))
39 |
40 | if __name__ == '__main__':
41 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import pytorch_lightning as pl
3 | import hydra
4 | import torch
5 | import yaml
6 | import os
7 | import numpy as np
8 |
9 | from lib.snarf_model import SNARFModel
10 |
11 | @hydra.main(config_path="config", config_name="config")
12 | def main(opt):
13 |
14 | print(opt.pretty())
15 |
16 | pl.seed_everything(42, workers=True)
17 |
18 | torch.set_num_threads(10)
19 |
20 | # dataset
21 | datamodule = hydra.utils.instantiate(opt.datamodule, opt.datamodule)
22 | datamodule.setup(stage='fit')
23 | np.savez('meta_info.npz', **datamodule.meta_info)
24 |
25 | data_processor = None
26 | if 'processor' in opt.datamodule:
27 | data_processor = hydra.utils.instantiate(opt.datamodule.processor,
28 | opt.datamodule.processor,
29 | meta_info=datamodule.meta_info)
30 |
31 | # logger
32 | with open('.hydra/config.yaml', 'r') as f:
33 | config = yaml.load(f, Loader=yaml.FullLoader)
34 | logger = pl.loggers.WandbLogger(project='snarf', config=config)
35 |
36 | # checkpoint
37 | checkpoint_path = './checkpoints/last.ckpt'
38 | if not os.path.exists(checkpoint_path) or not opt.resume:
39 | checkpoint_path = None
40 |
41 | checkpoint_callback = pl.callbacks.ModelCheckpoint(save_top_k=-1,
42 | monitor=None,
43 | dirpath='./checkpoints',
44 | save_last=True,
45 | every_n_val_epochs=1)
46 |
47 |
48 | trainer = pl.Trainer(logger=logger,
49 | callbacks=[checkpoint_callback],
50 | accelerator=None,
51 | resume_from_checkpoint=checkpoint_path,
52 | **opt.trainer)
53 |
54 | model = SNARFModel(opt=opt.model,
55 | meta_info=datamodule.meta_info,
56 | data_processor=data_processor)
57 |
58 | trainer.fit(model, datamodule=datamodule)
59 |
60 | if __name__ == '__main__':
61 | main()
--------------------------------------------------------------------------------