├── .gitignore
├── LICENSE
├── README.md
├── assets
    ├── doge.gif
    └── teaser.gif
├── checkpoints
    └── .gitkeep
├── config.yaml
├── data
    ├── Custom_train.json
    ├── SIZER_test.json
    ├── THUMAN_train.json
    ├── lowerbody.json
    ├── smpl_mesh.pkl
    └── thuman_smpl_mesh.pkl
├── demo.py
├── generate_dataset.py
├── lib
    ├── datasets
    │   └── customhumans_dataset.py
    ├── models
    │   ├── evaluator.py
    │   ├── feature_dictionary.py
    │   ├── losses.py
    │   ├── networks
    │   │   ├── discriminator.py
    │   │   ├── layers.py
    │   │   ├── mlps.py
    │   │   └── positional_encoding.py
    │   ├── neural_fields.py
    │   ├── tracer.py
    │   └── trainer.py
    ├── ops
    │   └── mesh
    │       ├── __init__.py
    │       ├── area_weighted_distribution.py
    │       ├── barycentric_coordinates.py
    │       ├── closest_point.py
    │       ├── closest_tex.py
    │       ├── compute_sdf.py
    │       ├── load_obj.py
    │       ├── normalize.py
    │       ├── per_face_normals.py
    │       ├── per_vertex_normals.py
    │       ├── point_sample.py
    │       ├── random_face.py
    │       ├── sample_near_surface.py
    │       ├── sample_surface.py
    │       ├── sample_tex.py
    │       └── sample_uniform.py
    └── utils
        ├── camera.py
        ├── config.py
        └── image.py
├── requirements.txt
├── smplx
    └── .gitkeep
├── tools
    ├── align_thuman.py
    ├── evaluate.py
    ├── load_json_to_smplx.py
    └── prepare_dataset.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
  1 | .DS_Store
  2 | *.DS_Store
  3 | **.DS_Store
  4 | # Byte-compiled / optimized / DLL files
  5 | __pycache__/
  6 | *.py[cod]
  7 | *$py.class
  8 | 
  9 | # C extensions
 10 | *.so
 11 | 
 12 | # Distribution / packaging
 13 | .Python
 14 | build/
 15 | develop-eggs/
 16 | dist/
 17 | downloads/
 18 | eggs/
 19 | .eggs/
 20 | lib64/
 21 | parts/
 22 | sdist/
 23 | var/
 24 | wheels/
 25 | share/python-wheels/
 26 | *.egg-info/
 27 | .installed.cfg
 28 | *.egg
 29 | MANIFEST
 30 | 
 31 | # PyInstaller
 32 | #  Usually these files are written by a python script from a template
 33 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 34 | *.manifest
 35 | *.spec
 36 | 
 37 | # Installer logs
 38 | pip-log.txt
 39 | pip-delete-this-directory.txt
 40 | 
 41 | # Unit test / coverage reports
 42 | htmlcov/
 43 | .tox/
 44 | .nox/
 45 | .coverage
 46 | .coverage.*
 47 | .cache
 48 | nosetests.xml
 49 | coverage.xml
 50 | *.cover
 51 | *.py,cover
 52 | .hypothesis/
 53 | .pytest_cache/
 54 | cover/
 55 | 
 56 | # Translations
 57 | *.mo
 58 | *.pot
 59 | 
 60 | # Django stuff:
 61 | *.log
 62 | local_settings.py
 63 | db.sqlite3
 64 | db.sqlite3-journal
 65 | 
 66 | # Flask stuff:
 67 | instance/
 68 | .webassets-cache
 69 | 
 70 | # Scrapy stuff:
 71 | .scrapy
 72 | 
 73 | # Sphinx documentation
 74 | docs/_build/
 75 | 
 76 | # PyBuilder
 77 | .pybuilder/
 78 | target/
 79 | 
 80 | # Jupyter Notebook
 81 | .ipynb_checkpoints
 82 | 
 83 | # IPython
 84 | profile_default/
 85 | ipython_config.py
 86 | 
 87 | # pyenv
 88 | #   For a library or package, you might want to ignore these files since the code is
 89 | #   intended to run in multiple environments; otherwise, check them in:
 90 | # .python-version
 91 | 
 92 | # pipenv
 93 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 94 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 95 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 96 | #   install all needed dependencies.
 97 | #Pipfile.lock
 98 | 
 99 | # poetry
100 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
102 | #   commonly ignored for libraries.
103 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 | 
106 | # pdm
107 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | #   in version control.
111 | #   https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 | 
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 | 
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 | 
121 | # SageMath parsed files
122 | *.sage.py
123 | 
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 | 
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 | 
137 | # Rope project settings
138 | .ropeproject
139 | 
140 | # mkdocs documentation
141 | /site
142 | 
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 | 
148 | # Pyre type checker
149 | .pyre/
150 | 
151 | # pytype static type analyzer
152 | .pytype/
153 | 
154 | # Cython debug symbols
155 | cython_debug/
156 | 
157 | # PyCharm
158 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | #  and can be added to the global gitignore or merged into this file.  For a more nuclear
161 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 custom-humans
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Learning Locally Editable Virtual Humans
  2 | 
  3 | ## [Project Page](https://custom-humans.github.io/) | [Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Ho_Learning_Locally_Editable_Virtual_Humans_CVPR_2023_paper.pdf) | [Youtube(3min)](https://youtu.be/aT8ql5hB3ZM), [Shorts(18sec)](https://youtube.com/shorts/6LTXma_wn4c) | [Dataset](https://custom-humans.ait.ethz.ch/)
  4 | 
  5 | ![teaser](assets/teaser.gif)
  6 | 
  7 | Official code release for the CVPR 2023 paper [*Learning Locally Editable Virtual Humans*](https://custom-humans.github.io/).
  8 | 
  9 | If you find our code, dataset, and paper useful, please cite as
 10 | ```
 11 | @inproceedings{ho2023custom,
 12 |     title={Learning Locally Editable Virtual Humans},
 13 |     author={Ho, Hsuan-I and Xue, Lixin and Song, Jie and Hilliges, Otmar},
 14 |     booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
 15 |     year={2023}
 16 |   }
 17 | ```
 18 | 
 19 | ## Installation
 20 | Our code has been tested with PyTorch 1.11.0, CUDA 11.3, and an RTX 3090 GPU.
 21 | 
 22 | ```bash
 23 | pip install -r requirements.txt
 24 | ```
 25 | 
 26 | ## Quick Start
 27 | 
 28 | ⚠️ The model checkpoint contains several real human bodies and faces. To download the checkpoint file, you need to agree to the CustomHumans Dataset Terms of Use. Click [here](https://custom-humans.ait.ethz.ch/) to apply for the dataset. You will find the checkpoint file in the dataset download link.
 29 | 
 30 | 1. Download and put the checkpoint file into the `checkpoints` folder.
 31 | 
 32 | 2. Download the test meshes and images from [here](https://files.ait.ethz.ch/projects/custom-humans/test.zip) and put them into the `data` folder.
 33 | 
 34 | 3. Run a quick demo that fits the model to an unseen 3D scan and 2D images.
 35 | ```bash
 36 | python demo.py --pretrained-root checkpoints/demo --model-name model-1000.pth
 37 | ```
 38 | You should be able to put a Doge T-shirt on the scanned subject.
 39 | 
 40 | ![doge](assets/doge.gif)
 41 | 
 42 | 4. Try out different functions such as reposing and cloth transfer in `demo.py`. 
 43 | 
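For reference, the reposing and clothing-transfer calls are already sketched (commented out) near the bottom of `demo.py`; a minimal excerpt is shown below. It assumes the `evaluator` created earlier in `main()` and reuses the codebook slots (32 and 33) and the test meshes from the demo data.

```python
# Repose subject 32 to a new SMPL-X pose.
evaluator.reposing(32, 'data/test/mesh/mesh-f00181_smpl.obj', epoch=997)

# Clothing transfer: fit a second scan into codebook slot 33, then copy the
# lower-body features (vertex indices from data/lowerbody.json) from 32 to 33.
idx = json.load(open('data/lowerbody.json'))
evaluator.fitting_3D(33, 'data/test/mesh/mesh-f00181.obj', 'data/test/mesh/mesh-f00181_smpl.obj', fit_rgb=True)
evaluator.transfer_features(32, 33, idx)
evaluator.reconstruction(33, epoch=996)
```
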
 44 | ## Data Preparation
 45 | 
 46 | ### CustomHumans
 47 | Apply for our dataset by sending a [request](https://custom-humans.ait.ethz.ch/). After downloading, you should get 646 textured meshes and their SMPL-X meshes. We use only 100 meshes for training. We provide the indices of the training meshes [here](https://github.com/custom-humans/editable-humans/blob/main/data/Custom_train.json).
 48 | 
 49 | 1. Prepare the training data following the folder structure:
 50 | ```
 51 | 	training_dataset
 52 | 	├── 0003
 53 | 	│   ├── mesh-f00101.obj
 54 | 	│   ├── mesh-f00101.mtl
 55 | 	│   ├── mesh-f00101.png
 56 | 	│   ├── mesh-f00101.json
 57 | 	│   └── mesh-f00101_smpl.obj
 58 | 	├── 0007
 59 | 	│   ...
 60 | 
 61 | ```
 62 | You can use the following script to generate the training dataset folder:
 63 | ```bash
 64 | python tools/prepare_dataset.py
 65 | ```
 66 | 
 67 | 2. Download [SMPL-X](https://smpl-x.is.tue.mpg.de/) models and move them to the `smplx` folder.
 68 | You should have the following data structure:
 69 | ```
 70 | 	smplx
 71 | 	├── SMPLX_NEUTRAL.pkl
 72 | 	├── SMPLX_NEUTRAL.npz
 73 | 	├── SMPLX_MALE.pkl
 74 | 	├── SMPLX_MALE.npz
 75 | 	├── SMPLX_FEMALE.pkl
 76 | 	└── SMPLX_FEMALE.npz
 77 | ```
 78 | 3. Since sampling points on meshes online during training can be slow, we pre-sample 18M points per mesh and cache them in an h5 file for training. Run the following script to generate the h5 file:
 79 | 
 80 | ```bash
 81 | python generate_dataset.py -i /path/to/dataset/folder
 82 | ```
 83 | 
 84 | ⚠️ The script will generate a large h5 file (>80GB). If you don't want to generate that many points, you can adjust the `NUM_SAMPLES` parameter [here](https://github.com/custom-humans/editable-humans/blob/main/generate_dataset.py#L18).
 85 | 
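If you want to sanity-check the generated file, a minimal inspection with `h5py` looks like the sketch below (illustrative only; the dataset keys and shapes match what `generate_dataset.py` writes, with `NUM_SAMPLES = 3000000` by default):

```python
import h5py

# Illustrative check of the cached training data (not part of the training pipeline).
with h5py.File('CustomHumans.h5', 'r') as f:
    print('num_subjects:', f['num_subjects'][()])
    print('pts:   ', f['pts'].shape)        # (num_subjects, NUM_SAMPLES*6, 3) sampled points
    print('rgb:   ', f['rgb'].shape)        # colors at the sampled points
    print('nrm:   ', f['nrm'].shape)        # normals at the sampled points
    print('d:     ', f['d'].shape)          # signed distances, (num_subjects, NUM_SAMPLES*6, 1)
    print('smpl_v:', f['smpl_v'].shape)     # SMPL-X vertices, (num_subjects, 10475, 3)
    print('rgb_image:', f['rgb_image'].shape)  # joint-centered patches used for the 2D losses
```
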
 86 | ### THuman2.0
 87 | 
 88 | We also train our model using 150 scans from THuman2.0; you can find their indices [here](https://github.com/custom-humans/editable-humans/blob/main/data/THUMAN_train.json). Please apply for the dataset and the SMPL-X registrations through their [official repo](https://github.com/ytrock/THuman2.0-Dataset).
 89 | 
 90 | ⚠️ Note that the scans in THuman2.0 come in various scales. We rescale them to -1~1 based on the SMPL-X models. You can find the rescaling script [here](https://github.com/custom-humans/editable-humans/blob/main/tools/align_thuman.py).
 91 | 
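The alignment itself is done by `tools/align_thuman.py`. As a rough illustration of the idea only (not the actual script; file names below are placeholders), one can center a scan on its fitted SMPL-X mesh and scale both so the body fits roughly inside -1~1:

```python
import trimesh

# Illustrative normalization sketch; the real logic lives in tools/align_thuman.py.
scan = trimesh.load('scan.obj', process=False)            # THuman2.0 scan (placeholder path)
smplx_fit = trimesh.load('smplx_fit.obj', process=False)  # its fitted SMPL-X mesh (placeholder path)

center = smplx_fit.bounds.mean(axis=0)                           # bounding-box center of the SMPL-X fit
scale = 2.0 / (smplx_fit.bounds[1] - smplx_fit.bounds[0]).max()  # map the largest extent to length 2

scan.vertices = (scan.vertices - center) * scale
smplx_fit.vertices = (smplx_fit.vertices - center) * scale
scan.export('scan_normalized.obj')
smplx_fit.export('smplx_fit_normalized.obj')
```
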
 92 | ⚠️ THuman2.0 uses different settings for creating SMPL-X body meshes. When generating the h5 file, please set `flat_hand_mean=False` in the [`generate_dataset.py`](https://github.com/custom-humans/editable-humans/blob/main/generate_dataset.py#L42) script.
 93 | 
 94 | ## Training
 95 | 
 96 | Once your h5 dataset is ready, simply run the following command to train the model:
 97 | ```
 98 | python train.py 
 99 | ```
100 | Here are some configuration flags you can use; they will override the settings in `config.yaml`:
101 | * `--config`: path to the config file. Default is `config.yaml`
102 | * `--wandb`: we use wandb for monitoring the training. Activate this flag if you want to use it.
103 | * `--save-root`: path to the folder to save the checkpoints. Default is `checkpoints`
104 | * `--data_root`: path to the training h5 dataset. Default is `CustomHumans.h5`
105 | * `--use_2d_from_epoch`: use 2D adversarial loss after this epoch. -1 means never use 2D loss. Default is 10.
106 | 
107 | ## Evaluation
108 | 
109 | We use SIZER to evaluate the geometry fitting performance. Please follow their instructions to download the [dataset](https://github.com/garvita-tiwari/sizer).
110 | 
111 | We provide subjects' [indices](https://github.com/custom-humans/editable-humans/blob/main/data/SIZER_test.json) and [scripts](https://github.com/custom-humans/editable-humans/blob/main/tools/evaluate.py) for evaluation.
112 | 
113 | ## Acknowledgement
114 | We have used code from other great research works, including [gdna](https://github.com/xuchen-ethz/gdna), [kaolin-wisp](https://github.com/NVIDIAGameWorks/kaolin-wisp), [SMPL-X](https://github.com/vchoutas/smplx), [ML-GSN](https://github.com/apple/ml-gsn/), [StyleGAN-Ada](https://github.com/NVlabs/stylegan2-ada-pytorch), and [Occupancy Networks](https://github.com/autonomousvision/occupancy_networks).
115 | 
116 | We created all the videos using the powerful [aitviewer](https://eth-ait.github.io/aitviewer/).
117 | 
118 | We sincerely thank the authors for their awesome work!
119 | 
--------------------------------------------------------------------------------
/assets/doge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/assets/doge.gif
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/assets/teaser.gif
--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
 1 | global:
 2 |   save_root: './checkpoints'
 3 |   exp_name: 'test-release'
 4 | 
 5 | dataset:
 6 |   data_root: 'CustomHumans.h5'
 7 |   num_samples: 20480
 8 |   repeat_times: 8
 9 | 
10 | optimizer:
11 |   lr_codebook: 0.0005
12 |   lr_decoder: 0.001
13 |   lr_dis: 0.001
14 |   beta1: 0.5
15 |   beta2: 0.999
16 | 
17 | 
18 | train:
19 |   epochs: 5000
20 |   batch_size: 4
21 |   workers: 8
22 |   save_every: 50
23 |   log_every: 100
24 |   use_2d_from_epoch: 10
25 |   train_2d_every_iter: 1
26 |   use_nrm_dis: False
27 |   use_cached_pts: True
28 | 
29 | dictionary:
30 |   shape_dim: 32
31 |   color_dim: 32
32 |   feature_std: 0.1
33 |   feature_bias: 0.0
34 |   shape_pca_dim: 16
35 |   color_pca_dim: 16
36 | 
37 | 
38 | network:
39 |   pos_dim: 3
40 |   c_dim: 3
41 |   num_layers: 4
42 |   hidden_dim: 128
43 |   skip:
44 |     - 2
45 |   activation: 'relu'
46 |   layer_type: 'none'
47 | 
48 | 
49 | embedder:
50 |   shape_freq: 5
51 |   color_freq: 10
52 | 
53 | 
54 | losses:
55 |   lambda_sdf: 100.
56 |   lambda_rgb: 10.
57 |   lambda_nrm: 10.
58 |   lambda_reg: 1.
59 | 
60 |   gan_loss_type: 'logistic'
61 |   lambda_gan: 1.
62 |   lambda_grad: 10.
63 | 
64 | 
65 | validation:
66 |   valid_every: 50
67 |   subdivide: True
68 |   grid_size: 400
69 |   width: 1024
70 |   fov: 20.0
71 |   n_views: 1
72 | 
73 | wandb:
74 |   wandb: False
75 |   wandb_name: 'custom-test'
76 | 
--------------------------------------------------------------------------------
/data/Custom_train.json:
--------------------------------------------------------------------------------
1 | ["0003", "0007", "0011", "0016", "0019", "0023", "0028", "0035", "0041", "0043", "0052", "0056", "0062", "0067", "0071", "0075", "0084", "0088", "0095", "0099", "0104", "0110", "0113", "0120", "0126", "0132", "0138", "0152", "0157", "0164", "0169", "0170", "0176", "0181", "0186", "0190", "0195", "0205", "0208", "0214", "0221", "0225", "0232", "0234", "0242", "0253", "0257", "0264", "0272", "0275", "0281", "0286", "0291", "0300", "0313", "0320", "0325", "0331", "0333", "0345", "0351", "0363", "0367", "0370", "0375", "0380", "0387", "0407", "0413", "0422", "0428", "0431", "0446", "0450", "0459", "0467", "0474", "0485", "0493", "0498", "0508", "0518", "0525", "0539", "0555", "0565", "0571", "0584", "0586", "0590", "0596", "0602", "0609", "0612", "0618", "0621", "0628", "0635", "0637", "0644"]
--------------------------------------------------------------------------------
/data/SIZER_test.json:
--------------------------------------------------------------------------------
1 | ["10032-3612", "10037-4262", "10040-4311", "10041-4457", "10071-7028", "10090-8110", "10091-8164", "10115-9709"]
--------------------------------------------------------------------------------
/data/THUMAN_train.json:
--------------------------------------------------------------------------------
1 | ["0000", "0001", "0005", "0006", "0007", "0008", "0017", "0021", "0024", "0025", "0028", "0034", "0037", "0038", "0052", "0053", "0054", "0056", "0057", "0060", "0070", "0071", "0078", "0083", "0087", "0088", "0092", "0095", "0099", "0103", "0107", "0108", "0110", "0116", "0119", "0121", "0125", "0126", "0128", "0129", "0132", "0136", "0139", "0144", "0146", "0151", "0155", "0160", "0164", "0167", "0168", "0173", "0176", "0181", "0184", "0185", "0187", "0193", "0197", "0200", "0203", "0204", "0210", "0215", "0216", "0228", "0229", "0241", "0243", "0252", "0266", "0273", "0282", "0283", "0285", "0286", "0293", "0296", "0299", "0303", "0307", "0308", "0311", "0314", "0318", "0322", "0327", "0329", "0330", "0338", "0339", "0342", "0345", "0348", "0351", "0354", "0356", "0362", "0365", "0369", "0376", "0377", "0378", "0381", "0384", "0387", "0391", "0393", "0394", "0398", "0401", "0402", "0405", "0412", "0415", "0421", "0425", "0426", "0428", "0430", "0433", "0434", "0435", "0437", "0440", "0441", "0445", "0448", "0453", "0455", "0459", "0460", "0462", "0463", "0467", "0470", "0471", "0476", "0480", "0482", "0488", "0491", "0494", "0496", "0499", "0501", "0502", "0503", "0507", "0522"]
--------------------------------------------------------------------------------
/data/smpl_mesh.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/data/smpl_mesh.pkl
--------------------------------------------------------------------------------
/data/thuman_smpl_mesh.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/data/thuman_smpl_mesh.pkl
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
 1 | import os, sys
 2 | import logging as log
 3 | import numpy as np
 4 | import torch
 5 | import pickle
 6 | import random
 7 | import json
 8 | from lib.models.evaluator import Evaluator
 9 | from lib.models.trainer import Trainer
10 | 
11 | from lib.utils.config import *
12 | from lib.utils.image import update_edited_images
13 | 
14 | def main(config):
15 |     
16 |     # Set random seed.
17 |     random.seed(config.seed)
18 |     np.random.seed(config.seed)
19 |     torch.manual_seed(config.seed)
20 | 
21 |     log_dir = config.pretrained_root
22 | 
23 |     with open('data/smpl_mesh.pkl', 'rb') as f:
24 |         smpl_mesh = pickle.load(f)
25 | 
26 |     trainer = Trainer(config, smpl_mesh['smpl_V'], smpl_mesh['smpl_F'], log_dir)
27 | 
28 |     trainer.load_checkpoint(os.path.join(config.pretrained_root, config.model_name))
29 | 
30 | 
31 |     evaluator = Evaluator(config, log_dir, mode='test')
32 | 
33 |     evaluator.init_models(trainer)
34 | 
35 |     # Fitting the 32nd feature codebook to the unseen 3D mesh
36 |     evaluator.fitting_3D(32, 'data/test/mesh/mesh-f00194.obj', 'data/test/mesh/mesh-f00194_smpl.obj', fit_rgb=True)
37 |     
38 |     # Generate the 3D mesh using marching cube
39 |     evaluator.reconstruction(32, epoch=999)
40 | 
41 |     # Render the 3D mesh to 2D images
42 |     #rendered = evaluator.render_2D(32, epoch=999)
43 | 
44 |     # Get the training points from the edited images
45 |     rendered = update_edited_images('data/test/images', 'data/test/render_dict.pkl')
46 | 
47 |     # Fitting the 32nd texture codebook to the edited images
48 |     evaluator.fitting_2D(32, rendered, 'data/test/mesh/mesh-f00194_smpl.obj')
49 | 
50 |     # Generate the edited 3D mesh using marching cube
51 |     evaluator.reconstruction(32, epoch=998)
52 | 
53 |     # Repose the 32nd subject to a new SMPL-X pose
54 |     #evaluator.reposing(32, 'data/test/mesh/mesh-f00181_smpl.obj',  epoch=997)
55 | 
56 |     # Clothing transfer
57 |     # Load the indices of the lower body vertices
58 |     #idx = json.load(open('data/lowerbody.json'))
59 | 
60 |     # Fitting the 33rd feature codebook to the other 3D scan
61 |     #evaluator.fitting_3D(33, 'data/test/mesh/mesh-f00181.obj', 'data/test/mesh/mesh-f00181_smpl.obj', fit_rgb=True)
62 | 
63 |     # Transfer the clothing (idx) from the 32nd subject to the 33rd subject
64 |     #evaluator.transfer_features(32, 33, idx)
65 | 
66 |     # Generate the transferred 3D mesh using marching cube
67 |     #evaluator.reconstruction(33, epoch=996)
68 | 
69 | 
70 | if __name__ == "__main__":
71 | 
72 |     parser = parse_options()
73 |     parser.add_argument('--pretrained-root', type=str, required=True, help='pretrained model path')
74 |     parser.add_argument('--model-name', type=str, required=True, help='load model name')
75 | 
76 |     args, args_str = argparse_to_str(parser)
77 |     handlers = [log.StreamHandler(sys.stdout)]
78 |     log.basicConfig(level=args.log_level,
79 |                     format='%(asctime)s|%(levelname)8s| %(message)s',
80 |                     handlers=handlers)
81 |     log.info(f'Info: \n{args_str}')
82 |     main(args)
--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import h5py
  3 | import numpy as np
  4 | import json
  5 | from tqdm import tqdm
  6 | import argparse
  7 | import torch
  8 | import pickle
  9 | import kaolin as kal
 10 | from kaolin.render.camera import *
 11 | 
 12 | from lib.utils.camera import *
 13 | from lib.ops.mesh import *
 14 | from smplx import SMPLX
 15 | 
 16 | SMPL_PATH = 'smplx/'
 17 | 
 18 | NUM_SAMPLES = 3000000
 19 | 
 20 | N_VIEWS = 4
 21 | FOV = 20
 22 | HEIGHT = 1024
 23 | WIDTH = 1024
 24 | RATIO = 1.0
 25 | 
 26 | N_JOINTS = 25
 27 | HALF_PATCH_SIZE = 64
 28 | 
 29 | def _get_smpl_vertices(smpl_data):
 30 |     device = torch.device('cuda')
 31 |     param_betas = torch.tensor(smpl_data['betas'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 32 |     param_poses = torch.tensor(smpl_data['body_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 33 |     param_left_hand_pose = torch.tensor(smpl_data['left_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 34 |     param_right_hand_pose = torch.tensor(smpl_data['right_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 35 |             
 36 |     param_expression = torch.tensor(smpl_data['expression'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 37 |     param_jaw_pose = torch.tensor(smpl_data['jaw_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 38 |     param_leye_pose = torch.tensor(smpl_data['leye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 39 |     param_reye_pose = torch.tensor(smpl_data['reye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
 40 | 
 41 | 
 42 |     body_model = SMPLX(model_path=SMPL_PATH, gender='male', use_pca=True, num_pca_comps=12, flat_hand_mean=True).to(device)
 43 |                 
 44 |     J_0 = body_model(body_pose = param_poses, betas=param_betas).joints.contiguous().detach()
 45 | 
 46 | 
 47 |     output = body_model(betas=param_betas,
 48 |                                    body_pose=param_poses,
 49 |                                    transl=-J_0[:,0,:],
 50 |                                    left_hand_pose=param_left_hand_pose,
 51 |                                    right_hand_pose=param_right_hand_pose,
 52 |                                    expression=param_expression,
 53 |                                    jaw_pose=param_jaw_pose,
 54 |                                    leye_pose=param_leye_pose,
 55 |                                    reye_pose=param_reye_pose,
 56 |                                    )
 57 |     return output.vertices.contiguous()[0].detach(), \
 58 |            output.joints.contiguous()[0].detach()[:25]
 59 | 
 60 | 
 61 | #########################################################################################################################
 62 | 
 63 | def main(args):
 64 |     device = torch.device('cuda')
 65 | 
 66 |     outfile = h5py.File(os.path.join(args.output_path), 'w')
 67 | 
 68 |     subject_list  = [x for x in sorted(os.listdir(args.input_path)) if os.path.isdir(os.path.join(args.input_path, x))]
 69 |     num_subjects = len(subject_list)
 70 | 
 71 |     outfile.create_dataset( 'num_subjects', data=num_subjects, dtype=np.int32)
 72 | 
 73 | 
 74 | 
 75 |     dataset_pts = outfile.create_dataset( 'pts', shape=(num_subjects, NUM_SAMPLES*6, 3),
 76 |                                  chunks=True, dtype=np.float32)
 77 |     dataset_rgb = outfile.create_dataset( 'rgb',shape=(num_subjects, NUM_SAMPLES*6, 3),
 78 |                                  chunks=True, dtype=np.float32)
 79 |     dataset_nrm = outfile.create_dataset( 'nrm', shape=(num_subjects, NUM_SAMPLES*6, 3),
 80 |                                  chunks=True, dtype=np.float32)
 81 |     dataset_d = outfile.create_dataset( 'd', shape=(num_subjects, NUM_SAMPLES*6, 1),
 82 |                                  chunks=True, dtype=np.float32)
 83 |     
 84 | 
 85 |     dataset_smpl_v = outfile.create_dataset( 'smpl_v', shape=(num_subjects, 10475, 3),
 86 |                                  chunks=True, dtype=np.float32)
 87 | 
 88 |     dataset_ray_ori_image = outfile.create_dataset( 'ray_ori_image', shape=(num_subjects, N_JOINTS*4,
 89 |                                  HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3),
 90 |                                  chunks=True, dtype=np.float32)
 91 |     
 92 |     dataset_ray_dir_image = outfile.create_dataset( 'ray_dir_image', shape=(num_subjects, N_JOINTS*4,
 93 |                                  HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3),
 94 |                                  chunks=True, dtype=np.float32)
 95 | 
 96 | 
 97 |     dataset_xyz_image = outfile.create_dataset( 'xyz_image', shape=(num_subjects, N_JOINTS*4,
 98 |                                  HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3),
 99 |                                  chunks=True, dtype=np.float32)
100 |     dataset_nrm_image = outfile.create_dataset( 'nrm_image', shape=(num_subjects, N_JOINTS*4,
101 |                                  HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3),
102 |                                  chunks=True, dtype=np.float32)
103 |     dataset_rgb_image = outfile.create_dataset( 'rgb_image', shape=(num_subjects, N_JOINTS*4,
104 |                                  HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3),
105 |                                  chunks=True, dtype=np.float32)
106 |     dataset_mask_image = outfile.create_dataset( 'mask_image', shape=(num_subjects, N_JOINTS*4,\
107 |                                  HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 1),
108 |                                     chunks=True, dtype=bool)
109 |     
110 |     for s, subject in enumerate(tqdm(subject_list)):
111 |         subject_path = os.path.join(args.input_path, subject)
112 |         json_file  = [x for x in sorted(os.listdir(subject_path)) if x.endswith('.json')][0]
113 |         filename = json_file.split('.')[0]
114 | 
115 |         smpl_data = json.load(open(os.path.join(subject_path, filename+'.json')))
116 |         smpl_V, smpl_J = _get_smpl_vertices(smpl_data)
117 |         with open('data/smpl_mesh.pkl', 'rb') as f:
118 |             smpl_mesh = pickle.load(f)
119 | 
120 |         smpl_F = smpl_mesh['smpl_F'].cuda().detach()
121 | 
122 | 
123 |         mesh_data = os.path.join(subject_path, filename+'.obj')
124 |         out = load_obj(mesh_data, load_materials=True)
125 |         V, F, texv, texf, mats = out
126 |         FN = per_face_normals(V, F).cuda()
127 |         
128 | 
129 |         pts1 = point_sample( V.cuda(), F.cuda(), ['near', 'near', 'trace'], NUM_SAMPLES, 0.01)
130 |         pts2 = point_sample(smpl_V, smpl_F, ['rand', 'near', 'trace'], NUM_SAMPLES, 0.1)
131 | 
132 |         rgb1, nrm1, d1 = closest_tex(V.cuda(), F.cuda(), 
133 |                                            texv.cuda(), texf.cuda(), mats, pts1.cuda())
134 |         rgb2, nrm2, d2 = closest_tex(V.cuda(), F.cuda(), 
135 |                                                texv.cuda(), texf.cuda(), mats, pts2.cuda())
136 | 
137 |             
138 |         look_at = torch.zeros( (N_VIEWS, 3), dtype=torch.float32, device=device)
139 | 
140 | 
141 |         camera_position = torch.tensor( [ [0, 0, 2],
142 |                                              [2, 0, 0],
143 |                                              [0, 0, -2],
144 |                                              [-2, 0, 0]  ]  , dtype=torch.float32, device=device)
145 | 
146 |         camera_up_direction = torch.tensor( [[0, 1, 0]], dtype=torch.float32, device=device).repeat(N_VIEWS, 1,)
147 | 
148 |         cam_transform = generate_transformation_matrix(camera_position, look_at, camera_up_direction)
149 |         cam_proj = generate_perspective_projection(FOV, RATIO)
150 | 
151 |         face_vertices_camera, face_vertices_image, face_normals = \
152 |                 kal.render.mesh.prepare_vertices(
153 |                 V.unsqueeze(0).repeat(N_VIEWS, 1, 1).cuda(),
154 |                 F.cuda(), cam_proj.cuda(), camera_transform=cam_transform
155 |             )
156 |         face_uvs = texv[texf[...,:3]].unsqueeze(0).cuda()
157 | 
158 |         ### Perform Rasterization ###
159 |         # Construct attributes that the DIB-R rasterizer will interpolate:
160 |         # the first is the UVs associated with each face,
161 |         # the second will make a hard segmentation mask
162 |         face_attributes = [
163 |                 V[F].unsqueeze(0).cuda().repeat(N_VIEWS, 1, 1, 1),
164 |                 face_uvs.repeat(N_VIEWS, 1, 1, 1),
165 |                 FN.unsqueeze(0).unsqueeze(2).repeat(N_VIEWS, 1, 3, 1),
166 |         ]            
167 | 
168 |         padded_joints = torch.nn.functional.pad(
169 |         smpl_J.unsqueeze(0).repeat(N_VIEWS, 1, 1), (0, 1), mode='constant', value=1.)
170 | 
171 |         joints_camera = (padded_joints @ cam_transform)
172 |         # Project the joints onto the camera image plane
173 |         joints_image = perspective_camera(joints_camera, cam_proj.cuda())
174 |         joints_image = (joints_image * torch.tensor([1, -1], device=device) + 1) * \
175 |                            torch.tensor([WIDTH//2, HEIGHT//2], device=device)
176 |         # If you have nvdiffrast installed you can change rast_backend to
177 |         # nvdiffrast or nvdiffrast_fwd
178 |         image_features, face_idx = kal.render.mesh.rasterize(
179 |         HEIGHT, WIDTH, face_vertices_camera[:, :, :, -1],
180 |         face_vertices_image, face_attributes, backend='cuda', multiplier=1000)
181 | 
182 |         coords, uv, normal= image_features
183 | 
184 |         TM = torch.zeros((N_VIEWS, HEIGHT, WIDTH, 1), dtype=torch.long, device=device)
185 | 
186 |         rgb = sample_tex(uv.view(-1, 2), TM.view(-1), mats).view(N_VIEWS, HEIGHT, WIDTH, 3)
187 |         mask = (face_idx != -1).unsqueeze(-1)
188 | 
189 | 
190 |         ray_dir_patches = []
191 |         ray_ori_patches = []
192 |         xyz_patches = []
193 |         rgb_patches = []
194 |         nrm_patches = []
195 |         mask_patches = []
196 | 
197 |         for i in range(N_VIEWS):
198 | 
199 |             camera = Camera.from_args(eye=camera_position[i],
200 |                                       at=look_at[i],
201 |                                       up=camera_up_direction[i],
202 |                                       fov=FOV,
203 |                                       width=WIDTH,
204 |                                       height=HEIGHT,
205 |                                       dtype=torch.float32)
206 | 
207 |             ray_grid = generate_centered_pixel_coords(camera.width, camera.height,
208 |                                                   camera.width, camera.height, device=device)
209 | 
210 |             ray_orig, ray_dir = \
211 |                 generate_pinhole_rays(camera.to(ray_grid[0].device), ray_grid)
212 | 
213 |             ray_orig = ray_orig.reshape(camera.height, camera.width, -1)
214 |             ray_dir = ray_dir.reshape(camera.height, camera.width, -1)
215 | 
216 |             for j in range(N_JOINTS):
217 |                 x = min(max(int(joints_image[i, j, 0]), HALF_PATCH_SIZE), WIDTH - HALF_PATCH_SIZE)
218 |                 y = min(max(int(joints_image[i, j, 1]), HALF_PATCH_SIZE), HEIGHT - HALF_PATCH_SIZE)
219 | 
220 |                 ray_ori_patches.append( ray_orig[y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] )
221 |                 ray_dir_patches.append( ray_dir[y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] )
222 |                 xyz_patches.append( coords[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] )
223 |                 rgb_patches.append( rgb[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] )
224 |                 nrm_patches.append( normal[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] )
225 |                 mask_patches.append( mask[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] )
226 | 
227 |             
228 |         dataset_pts[s] = torch.cat([pts1, pts2], dim=0).detach().cpu().numpy()
229 |         dataset_rgb[s] = torch.cat([rgb1, rgb2], dim=0).detach().cpu().numpy()
230 |         dataset_nrm[s] = torch.cat([nrm1, nrm2], dim=0).detach().cpu().numpy()
231 |         dataset_d[s] = torch.cat([d1, d2], dim=0).detach().cpu().numpy()
232 |         dataset_smpl_v[s] = smpl_V.detach().cpu().numpy()
233 |         dataset_xyz_image[s] = torch.stack(xyz_patches).detach().cpu().numpy()
234 |         dataset_rgb_image[s] = torch.stack(rgb_patches).detach().cpu().numpy()
235 |         dataset_nrm_image[s] = torch.stack(nrm_patches).detach().cpu().numpy()
236 |         dataset_mask_image[s] = torch.stack(mask_patches).detach().cpu().numpy()
237 |         dataset_ray_ori_image[s] = torch.stack(ray_ori_patches).detach().cpu().numpy()
238 |         dataset_ray_dir_image[s] = torch.stack(ray_dir_patches).detach().cpu().numpy()
239 | 
240 | 
241 |     outfile.close()
242 | 
243 | 
244 | if __name__ == "__main__":
245 |     parser = argparse.ArgumentParser(description='Process dataset to H5 file')
246 | 
247 |     parser.add_argument("-i", "--input_path", default='./CustomHumans/training_dataset', type=str, help="Path of the input mesh folder")
248 |     parser.add_argument("-o", "--output_path", default='./CustomHumans.h5', type=str, help="Path of the output h5 file")
249 | 
250 |     main(parser.parse_args())
251 | 
--------------------------------------------------------------------------------
/lib/datasets/customhumans_dataset.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import h5py
  3 | import numpy as np
  4 | import torch
  5 | from torch.utils.data import Dataset
  6 | import logging as log
  7 | import time
  8 | 
  9 | class CustomHumanDataset(Dataset):
 10 |     """Dataset of pre-sampled surface points and image patches loaded from an h5 file (see generate_dataset.py).
 11 |     """
 12 | 
 13 |     def __init__(self, 
 14 |         num_samples        : int = 20480,
 15 |         repeat_times      : int = 8,
 16 |     ):
 17 |         """Construct the dataset. It must still be initialized via `init_from_h5` before use.
 18 |         """
 19 |         self.repeat_times = repeat_times    # how many times each subject is repeated per epoch
 20 |         self.num_samples = num_samples      # number of points per subject
 21 | 
 22 |         self.initialization_mode = None
 23 |         self.label_map = {
 24 |             '0': 1, '1': 2, '2': 2, '3': 1, '4': 2,
 25 |             '5': 2, '6': 1, '7': 2, '8': 2, '9': 1,
 26 |             '10': 2, '11': 2, '12': 1, '13': 1, '14': 1,
 27 |             '15': 0, '16': 1, '17': 1, '18': 1, '19': 1,
 28 |             '20': 1, '21': 1, '22': 0, '23': 0, '24': 0,
 29 |             }
 30 | 
 31 |     def init_from_h5(self, dataset_path):
 32 |         """Initializes the dataset from a h5 file.
 33 |            copy smpl_v from h5 file.
 34 |         """
 35 | 
 36 |         self.h5_path = dataset_path
 37 |         with h5py.File(dataset_path, "r") as f:
 38 |             try:
 39 |                 self.num_subjects = f['num_subjects'][()]
 40 |                 self.num_pts = f['d'].shape[1]
 41 |                 self.smpl_V = torch.tensor(np.array(f['smpl_v']))
 42 |             except:
 43 |                 raise ValueError("[Error] Can't load from h5 dataset")
 44 |         self.resample()
 45 |         self.initialization_mode = "h5"
 46 | 
 47 |     def resample(self):
 48 |         """Resamples a new working set of indices.
 49 |         """
 50 |         
 51 |         start = time.time()
 52 |         log.info(f"Resampling...")
 53 | 
 54 |         self.id = np.random.randint(0, self.num_subjects, self.num_subjects * self.repeat_times)
 55 | 
 56 |         log.info(f"Time: {time.time() - start}")
 57 | 
 58 |     def _get_h5_data(self, subject_id, pts_id, img_id):
 59 |         with h5py.File(self.h5_path, "r") as f:
 60 |             try:
 61 |                 pts = np.array(f['pts'][subject_id,pts_id])
 62 |                 d = np.array(f['d'][subject_id,pts_id])
 63 |                 nrm = np.array(f['nrm'][subject_id,pts_id])
 64 |                 rgb = np.array(f['rgb'][subject_id,pts_id])
 65 |                 image_label = self.label_map[str(img_id[0] % 25)]
 66 | 
 67 |                 xyz_image = np.array(f['xyz_image'][subject_id,img_id])
 68 |                 rgb_image = np.array(f['rgb_image'][subject_id,img_id])
 69 |                 nrm_image = np.array(f['nrm_image'][subject_id,img_id])
 70 |                 mask_image = np.array(f['mask_image'][subject_id,img_id])
 71 |                 ray_ori_image = np.array(f['ray_ori_image'][subject_id,img_id])
 72 |                 ray_dir_image = np.array(f['ray_dir_image'][subject_id,img_id])
 73 | 
 74 |             except:
 75 |                 raise ValueError("[Error] Can't read key (%s, %s, %s) from h5 dataset" % (subject_id, pts_id, img_id))
 76 | 
 77 |         return {
 78 |                 'pts' : pts, 'sdf' : d, 'nrm' : nrm, 'rgb' : rgb, 'idx' : subject_id, 'label' : image_label,
 79 |                 'xyz_image' : xyz_image, 'rgb_image' : rgb_image,  'nrm_image' : nrm_image,
 80 |                 'mask_image' : mask_image, 'ray_ori_image' : ray_ori_image, 'ray_dir_image' : ray_dir_image
 81 |         }
 82 | 
 83 |     def __getitem__(self, idx: int):
 84 |         """Retrieve point sample."""
 85 |         if self.initialization_mode is None:
 86 |             raise Exception("The dataset is not initialized.")
 87 |         
 88 |         subject_id = self.id[idx]
 89 |         # point ids need to be in ascending order
 90 |         pts_id = np.random.randint(self.num_pts - self.num_samples, size=1)
 91 |         img_id = np.random.randint(100, size=1)
 92 | 
 93 |         return self._get_h5_data(subject_id, np.arange(pts_id, pts_id + self.num_samples), img_id)
 94 |     
 95 |     def __len__(self):
 96 |         """Return length of dataset (number of _samples_)."""
 97 |         if self.initialization_mode is None:
 98 |             raise Exception("The dataset is not initialized.")
 99 | 
100 |         return self.num_subjects * self.repeat_times
101 | 
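# Example usage (illustrative sketch only; see train.py and config.yaml for the actual setup):
#   dataset = CustomHumanDataset(num_samples=20480, repeat_times=8)
#   dataset.init_from_h5('CustomHumans.h5')
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=8)
#   batch = next(iter(loader))   # dict with 'pts', 'sdf', 'nrm', 'rgb', image patches, etc.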
--------------------------------------------------------------------------------
/lib/models/evaluator.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import time
  3 | import copy
  4 | import pickle
  5 | import torch
  6 | import trimesh
  7 | 
  8 | import numpy as np
  9 | import logging as log
 10 | from tqdm import tqdm
 11 | from PIL import Image
 12 | 
 13 | from .tracer import SDFTracer
 14 | from ..ops.mesh import load_obj, point_sample, closest_tex
 15 | from ..utils.camera import *
 16 | 
 17 | from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
 18 | from kaolin.ops.mesh import subdivide_trianglemesh
 19 | 
 20 | class Evaluator(object):
 21 | 
 22 |     def __init__(self, config, log_dir, mode='valid'):
 23 |         super().__init__()
 24 | 
 25 |         self.cfg = config
 26 |         self.log_dir = log_dir
 27 |         self.mesh_dir = os.path.join(log_dir, mode, 'meshes')
 28 |         os.makedirs(self.mesh_dir, exist_ok=True)
 29 |         self.image_dir = os.path.join(log_dir, mode, 'images')
 30 |         os.makedirs(self.image_dir, exist_ok=True)
 31 | 
 32 |         self.sdf_field = None
 33 |         self.rgb_field = None
 34 | 
 35 |         self.tracer = SDFTracer(self.cfg)
 36 |         self.subdivide = self.cfg.subdivide
 37 |         self.res = self.cfg.grid_size
 38 | 
 39 | 
 40 |     def init_models(self, trainer):
 41 |         '''Initialize the models for evaluation.
 42 |         Args:
 43 |             trainer (Trainer): the trainer providing the trained sdf_field,
 44 |                 rgb_field, and SMPL-X faces used for evaluation.
 45 |         '''
 46 | 
 47 |         self.sdf_field = copy.deepcopy(trainer.sdf_field)
 48 |         self.rgb_field = copy.deepcopy(trainer.rgb_field)
 49 |         self.smpl_F = trainer.smpl_F.clone().detach().cpu()
 50 | 
 51 |     def _marching_cubes (self, geo_idx=0, tex_idx=None, subdivide=True, res=300):
 52 |         '''Marching cubes to generate mesh.
 53 |         Args:
 54 |             geo_idx (int): the index of geometry to be generated.
 55 |             tex_idx (int): the index of texture to be generated.
 56 |             subdivide (bool): whether to subdivide the mesh.
 57 |             res (int): the resolution of the marching cubes.
 58 |         Returns:
 59 |             mesh (trimesh): the generated mesh.
 60 |         '''
 61 | 
 62 |         width = res
 63 |         window_x = torch.linspace(-1., 1., steps=width, device='cuda')
 64 |         window_y = torch.linspace(-1., 1., steps=width, device='cuda')
 65 |         window_z = torch.linspace(-1., 1., steps=width, device='cuda')
 66 | 
 67 |         coord = torch.stack(torch.meshgrid(window_x, window_y, window_z, indexing='ij')).permute(1, 2, 3, 0).reshape(1, -1, 3).contiguous()
 68 | 
 69 |         
 70 |         # Debug smpl grid
 71 |         #smpl_vertice = self.sdf_field.get_smpl_vertices_by_idx(geo_idx)
 72 |         #d = trimesh.Trimesh(vertices=smpl_vertice.cpu().detach().numpy(), 
 73 |         #            faces=self.smpl_F.cpu().detach().numpy())
 74 |         #d.export(os.path.join(self.log_dir, 'smpl_sub_%03d.obj' % (geo_idx)) )
 75 |         
 76 |         if tex_idx is None:
 77 |             tex_idx = geo_idx
 78 |         geo_idx = torch.tensor([geo_idx], dtype=torch.long, device = torch.device('cuda')).view(1).detach()
 79 |         tex_idx = torch.tensor([tex_idx], dtype=torch.long, device = torch.device('cuda')).view(1).detach()
 80 | 
 81 |         _points = torch.split(coord, int(2*1e6), dim=1)
 82 |         voxels = []
 83 |         for _p in _points:
 84 |             pred_sdf = self.sdf_field(_p, geo_idx)
 85 |             voxels.append(pred_sdf)
 86 | 
 87 |         voxels = torch.cat(voxels, dim=1)
 88 |         voxels = voxels.reshape(1, width, width, width)
 89 |         
 90 |         vertices, faces = voxelgrids_to_trianglemeshes(voxels, iso_value=0.)
 91 |         vertices = ((vertices[0].reshape(1, -1, 3) - 0.5) / (width/2)) - 1.0
 92 |         faces = faces[0]
 93 | 
 94 |         if subdivide:
 95 |             vertices, faces = subdivide_trianglemesh(vertices, faces, iterations=1)
 96 | 
 97 |         pred_rgb = self.rgb_field(vertices, tex_idx, pose_idx=geo_idx)            
 98 |         
 99 |         h = trimesh.Trimesh(vertices=vertices[0].cpu().detach().numpy(), 
100 |                 faces=faces.cpu().detach().numpy(), 
101 |                 vertex_colors=pred_rgb[0].cpu().detach().numpy())
102 | 
103 |         # remove disconnected parts of the mesh, keeping the largest component
104 |         connected_comp = h.split(only_watertight=False)
105 |         max_area = 0
106 |         max_comp = None
107 |         for comp in connected_comp:
108 |             if comp.area > max_area:
109 |                 max_area = comp.area
110 |                 max_comp = comp
111 |         h = max_comp
112 |     
113 |         trimesh.repair.fix_inversion(h)
114 | 
115 |         return h
116 |     
117 |     def _get_camera_rays(self, n_views=4, fov=20, width=1024):
118 |         '''Get camera rays for rendering.
119 |         Args:
120 |             n_views (int): the number of views.
121 |             fov (float): the field of view.
122 |             width (int): the width of the image.
123 |         Returns:
124 |             ray_o_images : the origin of the rays of n_views*height*width*3
125 |             ray_d_images : the direction of the rays of n_views*height*width*3
126 |         '''
127 |             
128 |         look_at = torch.zeros( (n_views, 3), dtype=torch.float32, device=torch.device('cuda'))
129 |         camera_up_direction = torch.tensor( [[0, 1, 0]], dtype=torch.float32, device=torch.device('cuda')).repeat(n_views, 1,)
130 |         angle = torch.linspace(0, 2*np.pi, n_views+1)[:-1]
131 |         camera_position = torch.stack( (2*torch.sin(angle), torch.zeros_like(angle), 2*torch.cos(angle)), dim=1).cuda()
132 | 
133 |         ray_o_images = []
134 |         ray_d_images = []
135 |         for i in range(n_views):
136 |             camera = Camera.from_args(eye=camera_position[i],
137 |                                       at=look_at[i],
138 |                                       up=camera_up_direction[i],
139 |                                       fov=fov,
140 |                                       width=width,
141 |                                       height=width,
142 |                                       dtype=torch.float32)
143 | 
144 |             ray_grid = generate_centered_pixel_coords(camera.width, camera.height,
145 |                                                   camera.width, camera.height, device=torch.device('cuda'))
146 | 
147 |             ray_orig, ray_dir = \
148 |                 generate_pinhole_rays(camera.to(ray_grid[0].device), ray_grid)
149 | 
150 |             ray_o_images.append(ray_orig.reshape(camera.height, camera.width, -1))
151 |             ray_d_images.append(ray_dir.reshape(camera.height, camera.width, -1))
152 | 
153 |         return torch.stack(ray_o_images, dim=0), torch.stack(ray_d_images, dim=0)
154 | 
155 |     def reconstruction(self, idx, epoch=None):
156 |         '''
157 |         Reconstruct the mesh of the idx-th subject.
158 |         '''
159 |         if epoch is None:
160 |             epoch = 0
161 |         log.info(f"Reconstructing {idx}th mesh at epoch {epoch}...")
162 |         start = time.time()
163 |         
164 |         with torch.no_grad():
165 |             h = self._marching_cubes (geo_idx=idx, subdivide=self.subdivide, res=self.res)
166 |         
167 |         h.export(os.path.join(self.mesh_dir, '%03d_reco_src-%03d.obj' % (epoch, idx)) )
168 |         end = time.time()
169 |         log.info(f"Reconstruction finished in {end-start} seconds.")
170 | 
171 |     def render_2D(self, idx, epoch=None):
172 |         '''
173 |         Render the 2D images of the idx-th subject.
174 |         '''
175 |         torch.cuda.empty_cache()
176 | 
177 |         log.info(f"Rendering {idx}th subject at epoch {epoch}...")
178 |         start = time.time()
179 | 
180 |         with torch.no_grad():
181 | 
182 |             ray_o_images, ray_d_images = self._get_camera_rays(n_views=self.cfg.n_views, fov=self.cfg.fov, width=self.cfg.width)
183 |             _idx = torch.tensor([idx], dtype=torch.long, device = torch.device('cuda')).repeat(self.cfg.n_views).detach()
184 |             x, hit = self.tracer(self.sdf_field.forward, _idx,
185 |                               ray_o_images.view(self.cfg.n_views, -1, 3),
186 |                               ray_d_images.view(self.cfg.n_views, -1, 3))
187 |             log.info(f"Ray tracing finished in {time.time()-start} seconds.")
188 |             start = time.time()
189 |             rgb_2d = self.rgb_field.forward(x.detach(), _idx) * hit
190 | 
191 |         rgb_img = rgb_2d.reshape(self.cfg.n_views, self.cfg.width, self.cfg.width, 3).cpu().detach().numpy() * 255
192 | 
193 |         for i in range(self.cfg.n_views):
194 |             Image.fromarray(rgb_img[i].astype(np.uint8)).save(
195 |                 os.path.join(self.image_dir, '%03d_render_src-%03d_view-%03d.png' % (epoch, idx, i)) )
196 | 
197 |         log.info(f"Rendering finished in {time.time()-start} seconds.")
198 |         render_dict = {'coord': x.cpu().detach(), 'rgb': rgb_2d.cpu().detach(), 'mask': hit.cpu().detach()}
199 |         with open(os.path.join(self.image_dir, 'render_dict.pkl'), 'wb') as f:
200 |             pickle.dump(render_dict, f)
201 | 
202 |         return render_dict
203 | 
204 | 
205 |     def reposing(self, idx, target_smpl_obj, epoch=None):
206 |         '''
207 |         Reconstruct the mesh of the idx-th subject, reposed to the given target SMPL-X obj.
208 |         '''
209 |         
210 |         if epoch is None:
211 |             epoch = 0
212 |         smpl_V, _ = load_obj(target_smpl_obj, load_materials=False)
213 |         log.info(f"Reposing {idx}th mesh at epoch {epoch}...")
214 |         start = time.time()
215 | 
216 |         with torch.no_grad():
217 | 
218 |             tmp_smpl_V = self.sdf_field.get_smpl_vertices_by_idx(idx)
219 | 
220 |             self.sdf_field.replace_smpl_vertices_by_idx(idx, smpl_V)
221 |             self.rgb_field.replace_smpl_vertices_by_idx(idx, smpl_V)
222 | 
223 |             h = self._marching_cubes (geo_idx=idx, subdivide=self.subdivide, res=self.res)
224 | 
225 |             self.sdf_field.replace_smpl_vertices_by_idx(idx, tmp_smpl_V)
226 |             self.rgb_field.replace_smpl_vertices_by_idx(idx, tmp_smpl_V)
227 | 
228 |         h.export(os.path.join(self.mesh_dir, '%03d_repose_src-%03d.obj' % (epoch, idx)) )
229 |         end = time.time()
230 |         log.info(f"Reposing finished in {end-start} seconds.")
231 | 
232 |     def transfer_features(self, src_idx, tar_idx, vert_idx=None):
233 |         '''
234 |         Copy the features from src_idx to tar_idx at vert_idx.
235 |         '''
236 |         with torch.no_grad():
237 |             src_geo = self.sdf_field.get_feature_by_idx(src_idx, vert_idx=vert_idx).clone()
238 |             src_tex = self.rgb_field.get_feature_by_idx(src_idx, vert_idx=vert_idx).clone()
239 |             self.sdf_field.replace_feature_by_idx(tar_idx, src_geo, vert_idx=vert_idx)
240 |             self.rgb_field.replace_feature_by_idx(tar_idx, src_tex, vert_idx=vert_idx)
241 | 
242 | 
243 |     def fitting_3D(self, code_idx, target_mesh, target_smpl_obj, num_steps=300, fit_nrm=False, fit_rgb=False):
244 |         """Fit the latent codes to the target mesh.
245 |            Store the optimized codes in the code_idx-th entry of the codebooks.
246 |         """
247 | 
248 |         torch.cuda.empty_cache()
249 | 
250 |         geo_code = self.sdf_field.get_mean_feature().clone().unsqueeze(0).detach().data
251 |         tex_code = self.rgb_field.get_mean_feature().clone().unsqueeze(0).detach().data
252 | 
253 |         geo_code.requires_grad = True
254 |         tex_code.requires_grad = True
255 | 
256 |         V, F, texv, texf, mats = load_obj(target_mesh, load_materials=True)
257 |         smpl_V, _ = load_obj(target_smpl_obj, load_materials=False)
258 |         smpl_V = smpl_V.cuda()
259 | 
260 |         params = []
261 |         params.append({'params': geo_code, 'lr': 0.005})
262 |         params.append({'params': tex_code, 'lr': 0.01})
263 | 
264 |         optimizer = torch.optim.Adam(params, betas=(0.9, 0.999))
265 |         loop = tqdm(range(num_steps))
266 |         log.info(f"Start fitting latent code to the target mesh...")
267 |         for i in loop:
268 |             coord_1 = point_sample(V.cuda(), F.cuda(), ['near', 'trace', 'rand'], 20000, 0.01)
269 |             coord_2 = point_sample(smpl_V, self.smpl_F.cuda(), ['near', 'trace'], 50000, 0.2)
270 |             coord = torch.cat((coord_1, coord_2), dim=0)
271 |             rgb, nrm, sdf = closest_tex(V.cuda(), F.cuda(), texv.cuda(), texf.cuda(), mats, coord.cuda())
272 |             coord = coord.unsqueeze(0)
273 |             sdf = sdf.unsqueeze(0)
274 |             rgb = rgb.unsqueeze(0)
275 |             nrm = nrm.unsqueeze(0)
276 | 
277 |             sdf_loss = torch.tensor(0.0).cuda()
278 |             nrm_loss = torch.tensor(0.0).cuda()
279 |             rgb_loss = torch.tensor(0.0).cuda()
280 | 
281 |             optimizer.zero_grad()
282 | 
283 |             pred_sdf = self.sdf_field.forward_fitting(coord, geo_code, smpl_V.unsqueeze(0))
284 |             sdf_loss += torch.abs(pred_sdf - sdf).mean()
285 | 
286 |             if fit_rgb:
287 |                 pred_rgb = self.rgb_field.forward_fitting(coord, tex_code, smpl_V.unsqueeze(0))
288 |                 rgb_loss += torch.abs(pred_rgb - rgb).mean()
289 | 
290 |             if fit_nrm:
291 |                 pred_nrm = self.sdf_field.normal_fitting(coord, tex_code, smpl_V.unsqueeze(0))
292 |                 nrm_loss += torch.abs(pred_nrm - nrm).mean()
293 |             
294 | 
295 |             loss = 10*sdf_loss + rgb_loss + nrm_loss
296 |             loss.backward()
297 |             optimizer.step()
298 |             loop.set_description('Step [{}/{}] Total Loss: {:.4f} - L1:{:.4f} - RGB:{:.4f} - NRM:{:.4f}'
299 |                            .format(i, num_steps, loss.item(), sdf_loss.item(), rgb_loss.item(), nrm_loss.item()))
300 | 
301 |         log.info(f"Fitting finished. Storing the optimized code and the new SMPL-X pose in the codebook.")
302 | 
303 |         with torch.no_grad():
304 |             self.sdf_field.replace_feature_by_idx(code_idx, geo_code)
305 |             self.rgb_field.replace_feature_by_idx(code_idx, tex_code)
306 |             self.sdf_field.replace_smpl_vertices_by_idx(code_idx, smpl_V)
307 |             self.rgb_field.replace_smpl_vertices_by_idx(code_idx, smpl_V)
308 | 
309 | 
310 |     def fitting_2D(self, code_idx, target_dict, target_smpl_obj=None, num_steps=500):
311 |         """Fit the color latent code to the rendered images.
312 |            Store the optimized code in the code_idx-th entry of the codebook.
313 |         """
314 | 
315 |         torch.cuda.empty_cache()
316 | 
317 |         tex_code = self.rgb_field.get_feature_by_idx(code_idx).clone().unsqueeze(0).detach().data
318 |         tex_code.requires_grad = True
319 | 
320 |         rgb = target_dict['rgb'].cuda()
321 |         coord = target_dict['coord'].cuda()
322 |         mask = target_dict['mask'].cuda()
323 | 
324 |         b_size = rgb.shape[0] # b_size = n_views
325 | 
326 | 
327 |         inputs = []
328 |         targets = []
329 |         for i in range(b_size):
330 |             _xyz = coord[i]
331 |             _rgb = rgb[i]
332 |             _mask = mask[i, :, 0]
333 |             inputs.append(_xyz[_mask].view(1,-1,3))
334 |             targets.append(_rgb[_mask].view(1,-1,3))
335 | 
336 |         inputs = torch.cat(inputs, dim=1)
337 |         targets = torch.cat(targets, dim=1)
338 | 
339 |         if target_smpl_obj is not None:
340 |             smpl_V, _ = load_obj(target_smpl_obj, load_materials=False)
341 |             smpl_V = smpl_V.cuda()
342 |         else:
343 |             smpl_V = self.rgb_field.get_smpl_vertices_by_idx(code_idx)
344 | 
345 |         params = []
346 |         params.append({'params': tex_code, 'lr': 0.005})
347 | 
348 |         optimizer = torch.optim.Adam(params, betas=(0.9, 0.999))
349 |         loop = tqdm(range(num_steps))
350 | 
351 | 
352 |         for i in loop:
353 | 
354 |             rgb_loss = torch.tensor(0.0).cuda()
355 | 
356 |             optimizer.zero_grad()
357 | 
358 |             pred_rgb = self.rgb_field.forward_fitting(inputs, tex_code, smpl_V.unsqueeze(0))
359 |             rgb_loss += torch.abs(pred_rgb - targets).mean()
360 | 
361 |             rgb_loss.backward()
362 |             optimizer.step()
363 |             loop.set_description('Step [{}/{}] Total Loss: {:.4f}'.format(i, num_steps, rgb_loss.item()))
364 | 
365 |         with torch.no_grad():
366 |             self.rgb_field.replace_feature_by_idx(code_idx, tex_code)
367 |             #self.rgb_field.replace_smpl_vertices_by_idx(code_idx, smpl_V)
368 | 
--------------------------------------------------------------------------------
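Note: a minimal sketch of how fitting_2D gathers the masked pixels of all rendered views into a single
batch before optimizing the texture code. The tensors below are random stand-ins for target_dict['rgb'],
target_dict['coord'] and target_dict['mask']; the shapes are illustrative only.

import torch

n_views, n_pix = 4, 64 * 64
rgb = torch.rand(n_views, n_pix, 3)             # per-view colors
coord = torch.rand(n_views, n_pix, 3)           # per-pixel 3D sample coordinates
mask = torch.rand(n_views, n_pix, 1) > 0.5      # foreground mask

inputs, targets = [], []
for i in range(n_views):
    m = mask[i, :, 0]                           # boolean mask of view i
    inputs.append(coord[i][m].view(1, -1, 3))
    targets.append(rgb[i][m].view(1, -1, 3))

inputs = torch.cat(inputs, dim=1)               # [1, total_masked_pixels, 3]
targets = torch.cat(targets, dim=1)             # [1, total_masked_pixels, 3]
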
/lib/models/feature_dictionary.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn.functional as F
  3 | import torch.nn as nn
  4 | import logging as log
  5 | from ..ops.mesh import *
  6 | 
  7 | 
  8 | class FeatureDictionary(nn.Module):
  9 | 
 10 |     def __init__(self, 
 11 |         feature_dim        : int,
 12 |         feature_std        : float = 0.1,
 13 |         feature_bias       : float = 0.0,
 14 |     ):
 15 |         super().__init__()
 16 |         self.feature_dim = feature_dim
 17 |         self.feature_std = feature_std
 18 |         self.feature_bias = feature_bias
 19 | 
 20 |     def init_from_smpl_vertices(self, smpl_vertices):
 21 |       
 22 |         self.num_subjets = smpl_vertices.shape[0]
 23 |         self.num_vertices = smpl_vertices.shape[1]
 24 | 
 25 |         # Initialize feature codebooks
 26 |         fts = torch.zeros(self.num_subjets,self.num_vertices, self.feature_dim) + self.feature_bias
 27 |         fts += torch.randn_like(fts) * self.feature_std
 28 |         self.feature_codebooks = nn.Parameter(fts)
 29 | 
 30 |         log.info(f"Initalized feature codebooks with shape {self.feature_codebooks.shape}")
 31 | 
 32 |     def interpolate(self, coords, idx, smpl_V, smpl_F, input_code=None):
 33 | 
 34 |         """Query local features using the feature codebook, or the given input_code.
 35 |         Args:
 36 |             coords (torch.FloatTensor): coords of shape [batch, num_samples, 3]
 37 |             idx (torch.LongTensor): index of shape [batch, 1]
 38 |             smpl_V (torch.FloatTensor): SMPL vertices of shape [batch, num_vertices, 3]
 39 |             smpl_F (torch.LongTensor): SMPL faces of shape [num_faces, 3]
 40 |             input_code (torch.FloatTensor): input code of shape [batch, num_vertices, feature_dim]
 41 |         Returns:
 42 |             (torch.FloatTensor): interpolated features of shape [batch, num_samples, feature_dim]
 43 |         """
 44 | 
 45 |         sdf, hitpt, fid, weights = batched_closest_point_fast(smpl_V, smpl_F,
 46 |                                                               coords) # [B, Ns, 1], [B, Ns, 3], [B, Ns, 1], [B, Ns, 3]
 47 |         
 48 |         normal = torch.nn.functional.normalize( hitpt - coords, eps=1e-6, dim=2) # [B x Ns x 3]
 49 |         hitface = smpl_F[fid] # [B, Ns, 3]
 50 | 
 51 |         if input_code is None:
 52 |             inputs_feat = self.feature_codebooks[idx].unsqueeze(2).expand(-1, -1, hitface.shape[-1], -1) 
 53 |         else:
 54 |             inputs_feat = input_code.unsqueeze(2).expand(-1, -1, hitface.shape[-1], -1)
 55 |             
 56 |         indices = hitface.unsqueeze(-1).expand(-1, -1, -1, inputs_feat.shape[-1])
 57 |         nearest_feats = torch.gather(input=inputs_feat, index=indices, dim=1) # [B, Ns, 3, D]
 58 | 
 59 |         weighted_feats = torch.sum(nearest_feats * weights[...,None], dim=2) # K-weighted sum by: [B x Ns x 32]
 60 |         
 61 |         coords_feats = torch.cat([weights[...,1:], sdf], dim=-1) # [B, Ns, 3]
 62 |         return weighted_feats, coords_feats, normal
 63 |     
 64 |     def interpolate_random(self, coords, smpl_V, smpl_F, low_rank=32):
 65 |         """Query local features using PCA random sampling.
 66 | 
 67 |         Args:
 68 |             coords (torch.FloatTensor): coords of shape [batch, num_samples, 3]
 69 |             smpl_V (torch.FloatTensor): SMPL vertices of shape [batch, num_vertices, 3]
 70 |             smpl_F (torch.LongTensor): SMPL faces of shape [num_faces, 3]
 71 | 
 72 |         Returns:
 73 |             (torch.FloatTensor): interpolated features of shape [batch, num_samples, feature_dim]
 74 |         """
 75 |         b_size = coords.shape[0]
 76 | 
 77 |         sdf, hitpt, fid, weights = batched_closest_point_fast(smpl_V, smpl_F,
 78 |                                                               coords) # [B, Ns, 1], [B, Ns, 3], [B, Ns, 1], [B, Ns, 3]
 79 |         normal = torch.nn.functional.normalize( hitpt - coords, eps=1e-6, dim=2) # [B x Ns x 3]
 80 |         hitface = smpl_F[fid] # [B, Ns, 3]
 81 |         inputs_feat = self._pca_sample(low_rank=low_rank, batch_size=b_size).unsqueeze(2).expand(-1, -1, hitface.shape[-1], -1) 
 82 |         indices = hitface.unsqueeze(-1).expand(-1, -1, -1, inputs_feat.shape[-1])
 83 |         nearest_feats = torch.gather(input=inputs_feat, index=indices, dim=1) # [B, Ns, 3, D]
 84 | 
 85 |         weighted_feats = torch.sum(nearest_feats * weights[...,None], dim=2) # K-weighted sum by: [B x Ns x 32]
 86 |         
 87 |         coords_feats = torch.cat([weights[...,1:], sdf], dim=-1) # [B, Ns, 3]
 88 |         return weighted_feats, coords_feats, normal
 89 | 
 90 | 
 91 |     def _pca_sample(self, low_rank=32, batch_size=1):
 92 | 
 93 |         A = self.feature_codebooks.clone()
 94 |         num_subjects, num_vertices, dim = A.shape
 95 | 
 96 |         A = A.view(num_subjects, -1)
 97 | 
 98 |         (U, S, V) = torch.pca_lowrank(A, q=low_rank, center=True, niter=1)
 99 | 
100 |         params = torch.matmul(A, V) # (num_subjects, low_rank)
101 |         mean = params.mean(dim=0)
102 |         cov = torch.cov(params.T)
103 | 
104 |         m = torch.distributions.multivariate_normal.MultivariateNormal(mean, cov)
105 |         random_codes = m.sample((batch_size,)).to(self.feature_codebooks.device)
106 | 
107 |         return torch.matmul(random_codes.detach(), V.t()).view(-1, num_vertices, dim)
108 | 
109 | 
--------------------------------------------------------------------------------
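Note: a self-contained sketch of the gather-and-weight step used by FeatureDictionary.interpolate, with
toy tensors standing in for the outputs of batched_closest_point_fast (the names B, Ns, V, D are
illustrative, not part of the repository API).

import torch

B, Ns, V, D = 1, 4, 6890, 32
vertex_feats = torch.randn(B, V, D)                              # per-vertex codes of one subject
hitface = torch.randint(0, V, (B, Ns, 3))                        # vertex ids of the closest triangle
weights = torch.rand(B, Ns, 3)
weights = weights / weights.sum(dim=-1, keepdim=True)            # barycentric weights

inputs_feat = vertex_feats.unsqueeze(2).expand(-1, -1, 3, -1)    # [B, V, 3, D]
indices = hitface.unsqueeze(-1).expand(-1, -1, -1, D)            # [B, Ns, 3, D]
nearest_feats = torch.gather(inputs_feat, dim=1, index=indices)  # [B, Ns, 3, D]
weighted_feats = (nearest_feats * weights[..., None]).sum(dim=2) # [B, Ns, D]
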
/lib/models/losses.py:
--------------------------------------------------------------------------------
  1 | """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """
  2 | 
  3 | import torch
  4 | import torch.nn as nn
  5 | from torch import autograd
  6 | import logging as log
  7 | import torch.nn.functional as F
  8 | 
  9 | from .networks.discriminator import StyleDiscriminator
 10 | 
 11 | def hinge_loss(fake_pred, real_pred, mode):
 12 |     if mode == 'd':
 13 |         # Discriminator update
 14 |         d_loss_fake = F.relu(1.0 + fake_pred).mean()
 15 |         d_loss_real = F.relu(1.0 - real_pred).mean()
 16 |         d_loss = d_loss_fake + d_loss_real
 17 |     elif mode == 'g':
 18 |         # Generator update
 19 |         d_loss = -torch.mean(fake_pred)
 20 |     return d_loss
 21 | 
 22 | def logistic_loss(fake_pred, real_pred, mode):
 23 |     if mode == 'd':
 24 |         # Discriminator update
 25 |         d_loss_fake = F.softplus(fake_pred).mean()
 26 |         d_loss_real = F.softplus(-real_pred).mean()
 27 |         d_loss = d_loss_fake + d_loss_real
 28 |     elif mode == 'g':
 29 |         # Generator update
 30 |         d_loss = F.softplus(-fake_pred).mean()
 31 |     return d_loss
 32 | 
 33 | 
 34 | def r1_loss(real_pred, real_img):
 35 |     (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
 36 |     grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
 37 |     return grad_penalty
 38 | 
 39 | 
 40 | class GANLoss(nn.Module):
 41 |     def __init__(
 42 |         self,
 43 |         cfg,
 44 |         disc_loss='logistic',
 45 |         auxillary=False
 46 |     ):
 47 |         super().__init__()
 48 | 
 49 | 
 50 |         self.cfg = cfg
 51 |         self.discriminator = StyleDiscriminator(3, 128, auxilary=auxillary)
 52 |         log.info("Total number of parameters {}".format(
 53 |             sum(p.numel() for p in self.discriminator.parameters()))\
 54 |         )
 55 | 
 56 |         if disc_loss == 'hinge':
 57 |             self.disc_loss = hinge_loss
 58 |         elif disc_loss == 'logistic':
 59 |             self.disc_loss = logistic_loss
 60 | 
 61 |         self.auxillary = auxillary
 62 | 
 63 |     def forward(self, disc_in_real, disc_in_fake, mode='g', gt_label=None):
 64 | 
 65 |         if mode == 'g':  # optimize generator
 66 |             loss = 0
 67 |             log = {}
 68 |             if self.auxillary:
 69 |                 logits_fake, _ = self.discriminator(disc_in_fake)
 70 |             else:
 71 |                 logits_fake = self.discriminator(disc_in_fake)
 72 | 
 73 |             g_loss = self.disc_loss(logits_fake, None, mode='g')
 74 |             log["loss_train/g_loss"] = g_loss.item()
 75 |             loss += g_loss * self.cfg.lambda_gan
 76 | 
 77 |             return loss, log
 78 | 
 79 |         if mode == 'd' :  # optimize discriminator
 80 |             if self.auxillary:
 81 |                 logits_real, aux_real = self.discriminator(disc_in_real)
 82 |                 logits_fake, aux_fake = self.discriminator(disc_in_fake.detach().clone())
 83 |             else:
 84 |                 logits_real = self.discriminator(disc_in_real)
 85 |                 logits_fake = self.discriminator(disc_in_fake.detach().clone())
 86 | 
 87 |             disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d')
 88 | 
 89 |             # lazy regularization so we don't need to compute grad penalty every iteration
 90 |             if self.cfg.lambda_grad > 0:
 91 |                 grad_penalty = r1_loss(logits_real, disc_in_real)
 92 | 
 93 |                 # the 0 * logits_real is to trigger DDP allgather
 94 |                 # https://github.com/rosinality/stylegan2-pytorch/issues/76
 95 |                 grad_penalty = grad_penalty + (0 * logits_real.sum())
 96 |             else:
 97 |                 grad_penalty = torch.tensor(0.0).type_as(disc_loss)
 98 | 
 99 |             d_loss = disc_loss * self.cfg.lambda_gan + grad_penalty * self.cfg.lambda_grad / 2
100 |             if self.auxillary:
101 |                 d_loss += F.cross_entropy(aux_real, gt_label)
102 |                 d_loss += F.cross_entropy(aux_fake, gt_label)
103 | 
104 |             log = {
105 |                 "loss_train/disc_loss": disc_loss.item(),
106 |                 "loss_train/r1_loss": grad_penalty.item(),
107 |                 "loss_train/logits_real": logits_real.mean().item(),
108 |                 "loss_train/logits_fake": logits_fake.mean().item(),
109 |             }
110 | 
111 |             return d_loss, log
112 | 
--------------------------------------------------------------------------------
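Note: the R1 penalty above differentiates the real logits w.r.t. the real images, so requires_grad must
be enabled on the images before the discriminator pass. A minimal sketch with a stand-in for the
discriminator (assuming the repository root is on PYTHONPATH; shapes are illustrative):

import torch
from lib.models.losses import r1_loss

real_img = torch.rand(4, 3, 128, 128, requires_grad=True)
real_pred = (real_img ** 2).mean(dim=[1, 2, 3])    # stand-in for discriminator logits, shape [4]
penalty = r1_loss(real_pred, real_img)             # scalar gradient penalty
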
/lib/models/networks/discriminator.py:
--------------------------------------------------------------------------------
  1 | """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """
  2 | 
  3 | import math
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | 
  8 | class StyleDiscriminator(nn.Module):
  9 |     def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, auxilary=False, **kwargs):
 10 |         super().__init__()
 11 | 
 12 |         log_size_in = int(math.log(in_res, 2))
 13 |         log_size_out = int(math.log(4, 2))
 14 |         self.auxilary = auxilary
 15 | 
 16 |         self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3)
 17 | 
 18 |         # each resblock will half the resolution and double the number of features (until a maximum of ch_max)
 19 |         self.layers = []
 20 |         in_channels = ch_mul
 21 |         for i in range(log_size_in, log_size_out, -1):
 22 |             out_channels = int(min(in_channels * 2, ch_max))
 23 |             self.layers.append(ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True))
 24 |             in_channels = out_channels
 25 |         self.layers = nn.Sequential(*self.layers)
 26 | 
 27 |         self.disc_out = DiscriminatorHead(in_channel=in_channels, disc_stddev=True, auxilary=auxilary)
 28 | 
 29 |     def forward(self, x):
 30 |         x = self.conv_in(x)
 31 |         x = self.layers(x)
 32 |         if self.auxilary:
 33 |             out, aux = self.disc_out(x)
 34 |             return out, aux
 35 |         else:
 36 |             out = self.disc_out(x)
 37 |             return out
 38 | 
 39 | class DiscriminatorHead(nn.Module):
 40 |     def __init__(self, in_channel, disc_stddev=False, auxilary=False):
 41 |         super().__init__()
 42 | 
 43 |         self.disc_stddev = disc_stddev
 44 |         self.auxilary = auxilary
 45 |         stddev_dim = 1 if disc_stddev else 0
 46 | 
 47 |         self.conv_stddev = ConvLayer2d(
 48 |             in_channel=in_channel + stddev_dim, out_channel=in_channel, kernel_size=3, activate=True
 49 |         )
 50 | 
 51 |         self.final_linear = nn.Sequential(
 52 |             nn.Flatten(),
 53 |             EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
 54 |             EqualLinear(in_channel=in_channel, out_channel=1),
 55 |         )
 56 |         if self.auxilary:
 57 |             self.aux_layer = nn.Sequential(
 58 |                 nn.Flatten(),
 59 |                 EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
 60 |                 EqualLinear(in_channel=in_channel, out_channel=3),
 61 |         )
 62 | 
 63 |     def cat_stddev(self, x, stddev_group=4, stddev_feat=1):
 64 |         perm = torch.randperm(len(x))
 65 |         inv_perm = torch.argsort(perm)
 66 | 
 67 |         batch, channel, height, width = x.shape
 68 |         x = x[perm]  # shuffle inputs so that all views in a single trajectory don't get put together
 69 | 
 70 |         group = min(batch, stddev_group)
 71 |         stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
 72 |         stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
 73 |         stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
 74 |         stddev = stddev.repeat(group, 1, height, width)
 75 | 
 76 |         stddev = stddev[inv_perm]  # reorder inputs
 77 |         x = x[inv_perm]
 78 | 
 79 |         out = torch.cat([x, stddev], 1)
 80 |         return out
 81 | 
 82 |     def forward(self, x):
 83 |         if self.disc_stddev:
 84 |             x = self.cat_stddev(x)
 85 |         x = self.conv_stddev(x)
 86 |         out = self.final_linear(x)
 87 |         if self.auxilary:
 88 |             aux = self.aux_layer(x)
 89 |             return out, aux
 90 |         else:
 91 |             return out
 92 | 
 93 | 
 94 | class ConvDecoder(nn.Module):
 95 |     def __init__(self, in_channel, out_channel, in_res, out_res):
 96 |         super().__init__()
 97 | 
 98 |         log_size_in = int(math.log(in_res, 2))
 99 |         log_size_out = int(math.log(out_res, 2))
100 | 
101 |         self.layers = []
102 |         in_ch = in_channel
103 |         for i in range(log_size_in, log_size_out):
104 |             out_ch = in_ch // 2
105 |             self.layers.append(
106 |                 ConvLayer2d(
107 |                     in_channel=in_ch, out_channel=out_ch, kernel_size=3, upsample=True, bias=True, activate=True
108 |                 )
109 |             )
110 |             in_ch = out_ch
111 | 
112 |         self.layers.append(
113 |             ConvLayer2d(in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False)
114 |         )
115 |         self.layers = nn.Sequential(*self.layers)
116 | 
117 |     def forward(self, x):
118 |         return self.layers(x)
119 | 
120 | class FusedLeakyReLU(nn.Module):
121 |     def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
122 |         super().__init__()
123 | 
124 |         if bias:
125 |             self.bias = nn.Parameter(torch.zeros(channel))
126 | 
127 |         else:
128 |             self.bias = None
129 | 
130 |         self.negative_slope = negative_slope
131 |         self.scale = scale
132 | 
133 |     def forward(self, input):
134 |         return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
135 | 
136 | 
137 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
138 |     if input.dtype == torch.float16:
139 |         bias = bias.half()
140 | 
141 |     if bias is not None:
142 |         rest_dim = [1] * (input.ndim - bias.ndim - 1)
143 |         return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale
144 | 
145 |     else:
146 |         return F.leaky_relu(input, negative_slope=0.2) * scale
147 | 
148 | 
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 |     up_x, up_y = up, up
151 |     down_x, down_y = down, down
152 |     pad_x0, pad_x1, pad_y0, pad_y1 = pad[0], pad[1], pad[0], pad[1]
153 | 
154 |     _, channel, in_h, in_w = input.shape
155 |     input = input.reshape(-1, in_h, in_w, 1)
156 | 
157 |     _, in_h, in_w, minor = input.shape
158 |     kernel_h, kernel_w = kernel.shape
159 | 
160 |     out = input.view(-1, in_h, 1, in_w, 1, minor)
161 |     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
162 |     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
163 | 
164 |     out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
165 |     out = out[
166 |         :,
167 |         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
168 |         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
169 |         :,
170 |     ]
171 | 
172 |     out = out.permute(0, 3, 1, 2)
173 |     out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
174 |     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
175 |     out = F.conv2d(out, w)
176 |     out = out.reshape(
177 |         -1,
178 |         minor,
179 |         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
180 |         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
181 |     )
182 |     out = out.permute(0, 2, 3, 1)
183 |     out = out[:, ::down_y, ::down_x, :]
184 | 
185 |     out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
186 |     out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
187 | 
188 |     return out.view(-1, channel, out_h, out_w)
189 | 
190 | 
191 | 
192 | def make_kernel(k):
193 |     k = torch.tensor(k, dtype=torch.float32)
194 | 
195 |     if k.ndim == 1:
196 |         k = k[None, :] * k[:, None]
197 | 
198 |     k /= k.sum()
199 | 
200 |     return k
201 | 
202 | 
203 | class Blur(nn.Module):
204 |     """Blur layer.
205 |     Applies a blur kernel to input image using finite impulse response filter. Blurring feature maps after
206 |     convolutional upsampling or before convolutional downsampling helps produce models that are more robust to
207 |     shifting inputs (https://richzhang.github.io/antialiased-cnns/). In the context of GANs, this can provide
208 |     cleaner gradients, and therefore more stable training.
209 |     Args:
210 |     ----
211 |     kernel: list, int
212 |         A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
213 |     pad: tuple, int
214 |         A tuple of integers representing the number of rows/columns of padding to be added to the top/left and
215 |         the bottom/right respectively.
216 |     upsample_factor: int
217 |         Upsample factor.
218 |     """
219 | 
220 |     def __init__(self, kernel, pad, upsample_factor=1):
221 |         super().__init__()
222 | 
223 |         kernel = make_kernel(kernel)
224 | 
225 |         if upsample_factor > 1:
226 |             kernel = kernel * (upsample_factor ** 2)
227 | 
228 |         self.register_buffer("kernel", kernel)
229 |         self.pad = pad
230 | 
231 |     def forward(self, input):
232 |         out = upfirdn2d(input, self.kernel, pad=self.pad)
233 |         return out
234 | 
235 | 
236 | class Upsample(nn.Module):
237 |     """Upsampling layer.
238 |     Perform upsampling using a blur kernel.
239 |     Args:
240 |     ----
241 |     kernel: list, int
242 |         A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
243 |     factor: int
244 |         Upsampling factor.
245 |     """
246 | 
247 |     def __init__(self, kernel=[1, 3, 3, 1], factor=2):
248 |         super().__init__()
249 | 
250 |         self.factor = factor
251 |         kernel = make_kernel(kernel) * (factor ** 2)
252 |         self.register_buffer("kernel", kernel)
253 | 
254 |         p = kernel.shape[0] - factor
255 |         pad0 = (p + 1) // 2 + factor - 1
256 |         pad1 = p // 2
257 |         self.pad = (pad0, pad1)
258 | 
259 |     def forward(self, input):
260 |         out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
261 |         return out
262 | 
263 | 
264 | class Downsample(nn.Module):
265 |     """Downsampling layer.
266 |     Perform downsampling using a blur kernel.
267 |     Args:
268 |     ----
269 |     kernel: list, int
270 |         A list of integers representing a blur kernel. For example: [1, 3, 3, 1].
271 |     factor: int
272 |         Downsampling factor.
273 |     """
274 | 
275 |     def __init__(self, kernel=[1, 3, 3, 1], factor=2):
276 |         super().__init__()
277 | 
278 |         self.factor = factor
279 |         kernel = make_kernel(kernel)
280 |         self.register_buffer("kernel", kernel)
281 | 
282 |         p = kernel.shape[0] - factor
283 |         pad0 = (p + 1) // 2
284 |         pad1 = p // 2
285 |         self.pad = (pad0, pad1)
286 | 
287 |     def forward(self, input):
288 |         out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
289 |         return out
290 | 
291 | 
292 | class EqualLinear(nn.Module):
293 |     """Linear layer with equalized learning rate.
294 |     During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
295 |     prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
296 |     activation functions.
297 |     Args:
298 |     ----
299 |     in_channel: int
300 |         Input channels.
301 |     out_channel: int
302 |         Output channels.
303 |     bias: bool
304 |         Use bias term.
305 |     bias_init: float
306 |         Initial value for the bias.
307 |     lr_mul: float
308 |         Learning rate multiplier. By scaling weights and the bias we can proportionally scale the magnitude of
309 |         the gradients, effectively increasing/decreasing the learning rate for this layer.
310 |     activate: bool
311 |         Apply leakyReLU activation.
312 |     """
313 | 
314 |     def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
315 |         super().__init__()
316 | 
317 |         self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))
318 | 
319 |         if bias:
320 |             self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
321 |         else:
322 |             self.bias = None
323 | 
324 |         self.activate = activate
325 |         self.scale = (1 / math.sqrt(in_channel)) * lr_mul
326 |         self.lr_mul = lr_mul
327 | 
328 |     def forward(self, input):
329 |         if self.activate:
330 |             out = F.linear(input, self.weight * self.scale)
331 |             out = fused_leaky_relu(out, self.bias * self.lr_mul)
332 |         else:
333 |             out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
334 |         return out
335 | 
336 |     def __repr__(self):
337 |         return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
338 | 
339 | 
340 | class EqualConv2d(nn.Module):
341 |     """2D convolution layer with equalized learning rate.
342 |     During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
343 |     prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
344 |     activation functions.
345 |     Args:
346 |     ----
347 |     in_channel: int
348 |         Input channels.
349 |     out_channel: int
350 |         Output channels.
351 |     kernel_size: int
352 |         Kernel size.
353 |     stride: int
354 |         Stride of convolutional kernel across the input.
355 |     padding: int
356 |         Amount of zero padding applied to both sides of the input.
357 |     bias: bool
358 |         Use bias term.
359 |     """
360 | 
361 |     def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
362 |         super().__init__()
363 | 
364 |         self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
365 |         self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
366 | 
367 |         self.stride = stride
368 |         self.padding = padding
369 | 
370 |         if bias:
371 |             self.bias = nn.Parameter(torch.zeros(out_channel))
372 |         else:
373 |             self.bias = None
374 | 
375 |     def forward(self, input):
376 |         out = F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
377 |         return out
378 | 
379 |     def __repr__(self):
380 |         return (
381 |             f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
382 |             f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
383 |         )
384 | 
385 | 
386 | class EqualConvTranspose2d(nn.Module):
387 |     """2D transpose convolution layer with equalized learning rate.
388 |     During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to
389 |     prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU
390 |     activation functions.
391 |     Args:
392 |     ----
393 |     in_channel: int
394 |         Input channels.
395 |     out_channel: int
396 |         Output channels.
397 |     kernel_size: int
398 |         Kernel size.
399 |     stride: int
400 |         Stride of convolutional kernel across the input.
401 |     padding: int
402 |         Amount of zero padding applied to both sides of the input.
403 |     output_padding: int
404 |         Extra padding added to input to achieve the desired output size.
405 |     bias: bool
406 |         Use bias term.
407 |     """
408 | 
409 |     def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, output_padding=0, bias=True):
410 |         super().__init__()
411 | 
412 |         self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
413 |         self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
414 | 
415 |         self.stride = stride
416 |         self.padding = padding
417 |         self.output_padding = output_padding
418 | 
419 |         if bias:
420 |             self.bias = nn.Parameter(torch.zeros(out_channel))
421 |         else:
422 |             self.bias = None
423 | 
424 |     def forward(self, input):
425 |         out = F.conv_transpose2d(
426 |             input,
427 |             self.weight * self.scale,
428 |             bias=self.bias,
429 |             stride=self.stride,
430 |             padding=self.padding,
431 |             output_padding=self.output_padding,
432 |         )
433 |         return out
434 | 
435 |     def __repr__(self):
436 |         return (
437 |             f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
438 |             f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
439 |         )
440 | 
441 | 
442 | class ConvLayer2d(nn.Sequential):
443 |     def __init__(
444 |         self,
445 |         in_channel,
446 |         out_channel,
447 |         kernel_size=3,
448 |         upsample=False,
449 |         downsample=False,
450 |         blur_kernel=[1, 3, 3, 1],
451 |         bias=True,
452 |         activate=True,
453 |     ):
454 |         assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
455 |         layers = []
456 | 
457 |         if upsample:
458 |             factor = 2
459 |             p = (len(blur_kernel) - factor) - (kernel_size - 1)
460 |             pad0 = (p + 1) // 2 + factor - 1
461 |             pad1 = p // 2 + 1
462 | 
463 |             layers.append(
464 |                 EqualConvTranspose2d(
465 |                     in_channel, out_channel, kernel_size, padding=0, stride=2, bias=bias and not activate
466 |                 )
467 |             )
468 |             layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor))
469 | 
470 |         if downsample:
471 |             factor = 2
472 |             p = (len(blur_kernel) - factor) + (kernel_size - 1)
473 |             pad0 = (p + 1) // 2
474 |             pad1 = p // 2
475 | 
476 |             layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
477 |             layers.append(
478 |                 EqualConv2d(in_channel, out_channel, kernel_size, padding=0, stride=2, bias=bias and not activate)
479 |             )
480 | 
481 |         if (not downsample) and (not upsample):
482 |             padding = kernel_size // 2
483 | 
484 |             layers.append(
485 |                 EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=1, bias=bias and not activate)
486 |             )
487 | 
488 |         if activate:
489 |             layers.append(FusedLeakyReLU(out_channel, bias=bias))
490 | 
491 |         super().__init__(*layers)
492 | 
493 | 
494 | class ConvResBlock2d(nn.Module):
495 |     """2D convolutional residual block with equalized learning rate.
496 |     Residual block composed of 3x3 convolutions and leaky ReLUs.
497 |     Args:
498 |     ----
499 |     in_channel: int
500 |         Input channels.
501 |     out_channel: int
502 |         Output channels.
503 |     upsample: bool
504 |         Apply upsampling via strided convolution in the first conv.
505 |     downsample: bool
506 |         Apply downsampling via strided convolution in the second conv.
507 |     """
508 | 
509 |     def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
510 |         super().__init__()
511 | 
512 |         assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
513 |         mid_ch = in_channel if downsample else out_channel
514 | 
515 |         self.conv1 = ConvLayer2d(in_channel, mid_ch, upsample=upsample, kernel_size=3)
516 |         self.conv2 = ConvLayer2d(mid_ch, out_channel, downsample=downsample, kernel_size=3)
517 | 
518 |         if (in_channel != out_channel) or upsample or downsample:
519 |             self.skip = ConvLayer2d(
520 |                 in_channel,
521 |                 out_channel,
522 |                 upsample=upsample,
523 |                 downsample=downsample,
524 |                 kernel_size=1,
525 |                 activate=False,
526 |                 bias=False,
527 |             )
528 | 
529 |     def forward(self, input):
530 |         out = self.conv1(input)
531 |         out = self.conv2(out)
532 | 
533 |         if hasattr(self, 'skip'):
534 |             skip = self.skip(input)
535 |             out = (out + skip) / math.sqrt(2)
536 |         else:
537 |             out = (out + input) / math.sqrt(2)
538 |         return out
539 | 
--------------------------------------------------------------------------------
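Note: a minimal instantiation sketch (assuming the repository root is on PYTHONPATH). The arguments
mirror the ones GANLoss passes above, and the input resolution must match in_res.

import torch
from lib.models.networks.discriminator import StyleDiscriminator

disc = StyleDiscriminator(in_channel=3, in_res=128)
logits = disc(torch.rand(4, 3, 128, 128))               # [4, 1]

disc_aux = StyleDiscriminator(3, 128, auxilary=True)    # adds the 3-way auxiliary classification head
logits, aux = disc_aux(torch.rand(4, 3, 128, 128))      # [4, 1] and [4, 3]
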
/lib/models/networks/layers.py:
--------------------------------------------------------------------------------
  1 | # The code is adapted from https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/models/layers.py
  2 | 
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | 
  7 | def normalize_frobenius(x):
  8 |     """Normalizes the matrix according to the Frobenius norm.
  9 | 
 10 |     Args:
 11 |         x (torch.FloatTensor): A matrix.
 12 | 
 13 |     Returns:
 14 |         (torch.FloatTensor): A normalized matrix.
 15 |     """
 16 |     norm = torch.sqrt((torch.abs(x)**2).sum())
 17 |     return x / norm
 18 | 
 19 | def normalize_L_1(x):
 20 |     """Normalizes the matrix according to the L1 norm.
 21 | 
 22 |     Args:
 23 |         x (torch.FloatTensor): A matrix.
 24 | 
 25 |     Returns:
 26 |         (torch.FloatTensor): A normalized matrix.
 27 |     """
 28 |     abscolsum = torch.sum(torch.abs(x), dim=0)
 29 |     abscolsum = torch.min(torch.stack([1.0/abscolsum, torch.ones_like(abscolsum)], dim=0), dim=0)[0]
 30 |     return x * abscolsum[None,:]
 31 | 
 32 | def normalize_L_inf(x):    
 33 |     """Normalizes the matrix according to the Linf norm.
 34 | 
 35 |     Args:
 36 |         x (torch.FloatTensor): A matrix.
 37 | 
 38 |     Returns:
 39 |         (torch.FloatTensor): A normalized matrix.
 40 |     """
 41 |     absrowsum = torch.sum(torch.abs(x), axis=1)
 42 |     absrowsum = torch.min(torch.stack([1.0/absrowsum, torch.ones_like(absrowsum)], dim=0), dim=0)[0]
 43 |     return x * absrowsum[:,None]
 44 | 
 45 | class FrobeniusLinear(nn.Module):
 46 |     """A standard Linear layer which applies a Frobenius normalization in the forward pass.
 47 |     """
 48 |     def __init__(self, *args, **kwargs):
 49 |         super().__init__()
 50 |         self.linear = nn.Linear(*args, **kwargs)
 51 | 
 52 |     def forward(self, x):
 53 |         weight = normalize_frobenius(self.linear.weight)
 54 |         return F.linear(x, weight, self.linear.bias)
 55 | 
 56 | class L_1_Linear(nn.Module):
 57 |     """A standard Linear layer which applies a L1 normalization in the forward pass.
 58 |     """
 59 |     def __init__(self, *args, **kwargs):
 60 |         super().__init__()
 61 |         self.linear = nn.Linear(*args, **kwargs)
 62 | 
 63 |     def forward(self, x):
 64 |         weight = normalize_L_1(self.linear.weight)
 65 |         return F.linear(x, weight, self.linear.bias)
 66 | 
 67 | class L_inf_Linear(nn.Module):
 68 |     """A standard Linear layer which applies a Linf normalization in the forward pass.
 69 |     """
 70 |     def __init__(self, *args, **kwargs):
 71 |         super().__init__()
 72 |         self.linear = nn.Linear(*args, **kwargs)
 73 | 
 74 |     def forward(self, x):
 75 |         weight = normalize_L_inf(self.linear.weight)
 76 |         return F.linear(x, weight, self.linear.bias)
 77 |         
 78 | def spectral_norm_(*args, **kwargs):
 79 |     """Initializes a spectral norm layer.
 80 |     """
 81 |     return nn.utils.spectral_norm(nn.Linear(*args, **kwargs))
 82 | 
 83 | def get_layer_class(layer_type):
 84 |     """Convenience function to return the layer class name from text.
 85 | 
 86 |     Args:
 87 |         layer_type (str): Text name for the layer.
 88 | 
 89 |     Returns:
 90 |         (nn.Module): The layer to be used for the decoder.
 91 |     """
 92 |     if layer_type == 'none':
 93 |         return nn.Linear
 94 |     elif layer_type == 'spectral_norm':
 95 |         return spectral_norm_
 96 |     elif layer_type == 'frobenius_norm':
 97 |         return FrobeniusLinear
 98 |     elif layer_type == "l_1_norm":
 99 |         return L_1_Linear
100 |     elif layer_type == "l_inf_norm":
101 |         return L_inf_Linear
102 |     else:
103 |         assert False, "layer type does not exist"
104 | 
--------------------------------------------------------------------------------
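Note: a minimal usage sketch of the layer factory (assuming the repository root is on PYTHONPATH). Each
normalized layer wraps a plain nn.Linear and re-normalizes its weight on every forward pass.

import torch
from lib.models.networks.layers import get_layer_class

layer_cls = get_layer_class('frobenius_norm')   # -> FrobeniusLinear
fc = layer_cls(16, 32)                          # wraps nn.Linear(16, 32)
out = fc(torch.rand(8, 16))                     # [8, 32], computed with the Frobenius-normalized weight
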
/lib/models/networks/mlps.py:
--------------------------------------------------------------------------------
  1 | # The code is adapted from https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/models/decoders/basic_decoders.py
  2 | import torch
  3 | import torch.nn as nn
  4 | import torch.nn.functional as F
  5 | 
  6 | 
  7 | class MLP(nn.Module):
  8 |     """Super basic but super useful MLP class.
  9 |     """
 10 |     def __init__(self, 
 11 |         input_dim, 
 12 |         output_dim, 
 13 |         activation = torch.relu,
 14 |         bias = True,
 15 |         layer = nn.Linear,
 16 |         num_layers = 4, 
 17 |         hidden_dim = 128, 
 18 |         skip       = [2]
 19 |     ):
 20 |         """Initialize the MLP.
 21 | 
 22 |         Args:
 23 |             input_dim (int): Input dimension of the MLP.
 24 |             output_dim (int): Output dimension of the MLP.
 25 |             activation (function): The activation function to use.
 26 |             bias (bool): If True, use bias.
 27 |             layer (nn.Module): The MLP layer module to use.
 28 |             num_layers (int): The number of hidden layers in the MLP.
 29 |             hidden_dim (int): The hidden dimension of the MLP.
 30 |             skip (List[int]): List of layer indices where the input dimension is concatenated.
 31 | 
 32 |         Returns:
 33 |             (void): Initializes the class.
 34 |         """
 35 |         super().__init__()
 36 |         
 37 |         self.input_dim = input_dim
 38 |         self.output_dim = output_dim        
 39 |         self.activation = activation
 40 |         self.bias = bias
 41 |         self.layer = layer
 42 |         self.num_layers = num_layers
 43 |         self.hidden_dim = hidden_dim
 44 |         self.skip = skip
 45 |         if self.skip is None:
 46 |             self.skip = []
 47 |         
 48 |         self.make()
 49 | 
 50 |     def make(self):
 51 |         """Builds the actual MLP.
 52 |         """
 53 |         layers = []
 54 |         for i in range(self.num_layers):
 55 |             if i == 0: 
 56 |                 layers.append(self.layer(self.input_dim, self.hidden_dim, bias=self.bias))
 57 |             elif i in self.skip:
 58 |                 layers.append(self.layer(self.hidden_dim+self.input_dim, self.hidden_dim, bias=self.bias))
 59 |             else:
 60 |                 layers.append(self.layer(self.hidden_dim, self.hidden_dim, bias=self.bias))
 61 |         self.layers = nn.ModuleList(layers)
 62 |         self.lout = self.layer(self.hidden_dim, self.output_dim, bias=self.bias)
 63 | 
 64 |     def forward(self, x, return_h=False, sigmoid=False):
 65 |         """Run the MLP!
 66 | 
 67 |         Args:
 68 |             x (torch.FloatTensor): Some tensor of shape [batch, ..., input_dim]
 69 |             return_h (bool): If True, also returns the last hidden layer.
 70 | 
 71 |         Returns:
 72 |             (torch.FloatTensor, (optional) torch.FloatTensor):
 73 |                 - The output tensor of shape [batch, ..., output_dim]
 74 |                 - The last hidden layer of shape [batch, ..., hidden_dim]
 75 |         """
 76 |         N = x.shape[0]
 77 | 
 78 |         for i, l in enumerate(self.layers):
 79 |             if i == 0:
 80 |                 h = self.activation(l(x))
 81 |             elif i in self.skip:
 82 |                 h = torch.cat([x, h], dim=-1)
 83 |                 h = self.activation(l(h))
 84 |             else:
 85 |                 h = self.activation(l(h))
 86 |         
 87 |         out = self.lout(h)
 88 |         out = torch.sigmoid(out) if sigmoid else out  # optionally squash the output, mirroring Conditional_MLP
 89 |         if return_h:
 90 |             return out, h
 91 |         else:
 92 |             return out
 93 | 
 94 | 
 95 | 
 96 | class Conditional_MLP(nn.Module):
 97 |     """Super basic but super useful MLP class.
 98 |     """
 99 |     def __init__(self, 
100 |         input_dim, 
101 |         cond_dim,
102 |         output_dim, 
103 |         activation = torch.relu,
104 |         bias = True,
105 |         layer = nn.Linear,
106 |         num_layers = 4, 
107 |         hidden_dim = 128, 
108 |         skip       = [2]
109 |     ):
110 |         """Initialize the MLP.
111 | 
112 |         Args:
113 |             input_dim (int): Input dimension of the MLP.
114 |             output_dim (int): Output dimension of the MLP.
115 |             activation (function): The activation function to use.
116 |             bias (bool): If True, use bias.
117 |             layer (nn.Module): The MLP layer module to use.
118 |             num_layers (int): The number of hidden layers in the MLP.
119 |             hidden_dim (int): The hidden dimension of the MLP.
120 |             skip (List[int]): List of layer indices where the input dimension is concatenated.
121 | 
122 |         Returns:
123 |             (void): Initializes the class.
124 |         """
125 |         super().__init__()
126 |         
127 |         self.input_dim = input_dim
128 |         self.cond_dim = cond_dim
129 |         self.output_dim = output_dim        
130 |         self.activation = activation
131 |         self.bias = bias
132 |         self.layer = layer
133 |         self.num_layers = num_layers
134 |         self.hidden_dim = hidden_dim
135 |         self.skip = skip
136 |         if self.skip is None:
137 |             self.skip = []
138 |         
139 |         self.make()
140 | 
141 |     def make(self):
142 |         """Builds the actual MLP.
143 |         """
144 |         layers = []
145 |         for i in range(self.num_layers):
146 |             if i == 0: 
147 |                 layers.append(self.layer(self.input_dim, self.hidden_dim, bias=self.bias))
148 |             elif i in self.skip:
149 |                 layers.append(self.layer(self.hidden_dim+self.cond_dim, self.hidden_dim, bias=self.bias))
150 |             else:
151 |                 layers.append(self.layer(self.hidden_dim, self.hidden_dim, bias=self.bias))
152 |         self.layers = nn.ModuleList(layers)
153 |         self.lout = self.layer(self.hidden_dim, self.output_dim, bias=self.bias)
154 | 
155 |     def forward(self, x, c, return_h=False, sigmoid=False):
156 |         """Run the MLP!
157 | 
158 |         Args:
159 |             x (torch.FloatTensor): Some tensor of shape [batch, ..., input_dim]
160 |             return_h (bool): If True, also returns the last hidden layer.
161 | 
162 |         Returns:
163 |             (torch.FloatTensor, (optional) torch.FloatTensor):
164 |                 - The output tensor of shape [batch, ..., output_dim]
165 |                 - The last hidden layer of shape [batch, ..., hidden_dim]
166 |         """
167 |         N = x.shape[0]
168 | 
169 |         for i, l in enumerate(self.layers):
170 |             if i == 0:
171 |                 h = self.activation(l(x))
172 |             elif i in self.skip:
173 |                 h = torch.cat([h, c], dim=-1)
174 |                 h = self.activation(l(h))
175 |             else:
176 |                 h = self.activation(l(h))
177 |         
178 |         out = self.lout(h)
179 |         if sigmoid:
180 |             out = torch.sigmoid(out) 
181 | 
182 |         if return_h:
183 |             return out, h
184 |         else:
185 |             return out
186 | 
187 | 
188 | def get_activation_class(activation_type):
189 |     """Utility function to return an activation function class based on the string description.
190 | 
191 |     Args:
192 |         activation_type (str): The name for the activation function.
193 |     
194 |     Returns:
195 |         (Function): The activation function to be used. 
196 |     """
197 |     if activation_type == 'relu':
198 |         return torch.relu
199 |     elif activation_type == 'sin':
200 |         return torch.sin
201 |     elif activation_type == 'softplus':
202 |         return torch.nn.functional.softplus
203 |     elif activation_type == 'lrelu':
204 |         return torch.nn.functional.leaky_relu
205 |     else:
206 |         assert False, "activation type does not exist"
--------------------------------------------------------------------------------
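Note: a minimal usage sketch (assuming the repository root is on PYTHONPATH). MLP concatenates the raw
input at the skip layers, while Conditional_MLP concatenates the conditioning tensor instead; the
dimensions below are illustrative only.

import torch
from lib.models.networks.mlps import MLP, Conditional_MLP

mlp = MLP(input_dim=39, output_dim=1, num_layers=4, hidden_dim=128, skip=[2])
sdf = mlp(torch.rand(2, 1000, 39))                        # [2, 1000, 1]

cmlp = Conditional_MLP(input_dim=39, cond_dim=3, output_dim=3, skip=[2])
cond = torch.rand(2, 1000, 3)                             # e.g. per-point normals as the condition
rgb = cmlp(torch.rand(2, 1000, 39), cond, sigmoid=True)   # [2, 1000, 3], squashed to (0, 1)
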
/lib/models/networks/positional_encoding.py:
--------------------------------------------------------------------------------
 1 | # The code is adapted from https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/models/embedders/positional_embedder.py
 2 | import torch
 3 | import torch.nn as nn
 4 | 
 5 | class PositionalEncoding(nn.Module):
 6 |     """PyTorch implementation of positional embedding.
 7 |     """
 8 |     def __init__(self, num_freq, max_freq_log2, log_sampling=True, include_input=True, input_dim=3):
 9 |         """Initialize the module.
10 | 
11 |         Args:
12 |             num_freq (int): The number of frequency bands to sample. 
13 |             max_freq_log2 (int): The maximum frequency. The bands will be sampled between [0, 2^max_freq_log2].
14 |             log_sampling (bool): If true, will sample frequency bands in log space.
15 |             include_input (bool): If true, will concatenate the input.
16 |             input_dim (int): The dimension of the input coordinate space.
17 | 
18 |         Returns:
19 |             (void): Initializes the encoding.
20 |         """
21 |         super().__init__()
22 | 
23 |         self.num_freq = num_freq
24 |         self.max_freq_log2 = max_freq_log2
25 |         self.log_sampling = log_sampling
26 |         self.include_input = include_input
27 |         self.out_dim = 0
28 |         if include_input:
29 |             self.out_dim += input_dim
30 | 
31 |         if self.log_sampling:
32 |             self.bands = 2.0**torch.linspace(0.0, max_freq_log2, steps=num_freq)
33 |         else:
34 |             self.bands = torch.linspace(1, 2.0**max_freq_log2, steps=num_freq)
35 | 
36 |         # The out_dim is really just input_dim + num_freq * input_dim * 2 (for sin and cos)
37 |         self.out_dim += self.bands.shape[0] * input_dim * 2
38 |         self.bands = nn.Parameter(self.bands).requires_grad_(False)
39 |     
40 |     def forward(self, coords):
41 |         """Embded the coordinates.
42 | 
43 |         Args:
44 |             coords (torch.FloatTensor): Coordinates of shape [..., input_dim]
45 | 
46 |         Returns:
47 |             (torch.FloatTensor): Embeddings of shape [..., input_dim + out_dim] or [..., out_dim].
48 |         """
49 |         shape = coords.shape
50 |         # Flatten the coordinates
51 |         assert len(shape) > 1
52 |         if len(shape) > 2:
53 |             coords = coords.reshape(-1, shape[-1])
54 |         N = coords.shape[0]
55 |         winded = (coords[:,None] * self.bands[None,:,None]).reshape(N, -1)
56 |         encoded = torch.cat([torch.sin(winded), torch.cos(winded)], dim=-1)
57 |         if self.include_input:
58 |             encoded = torch.cat([coords, encoded], dim=-1)
59 |         # Reshape back to original
60 |         if len(shape) > 2:
61 |             encoded = encoded.reshape(*shape[:-1], -1)
62 |         return encoded
63 | 
64 | 
--------------------------------------------------------------------------------
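Note: a minimal usage sketch (assuming the repository root is on PYTHONPATH). With include_input=True,
out_dim = input_dim + num_freq * input_dim * 2, e.g. 3 + 6 * 3 * 2 = 39 for the values below.

import torch
from lib.models.networks.positional_encoding import PositionalEncoding

enc = PositionalEncoding(num_freq=6, max_freq_log2=5, input_dim=3)
coords = torch.rand(2, 100, 3)
embedded = enc(coords)          # [2, 100, 39]; leading dimensions are preserved
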
/lib/models/neural_fields.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn.functional as F
  3 | import torch.nn as nn
  4 | import logging as log
  5 | 
  6 | from .feature_dictionary import FeatureDictionary
  7 | from .networks.positional_encoding import PositionalEncoding
  8 | from .networks.mlps import MLP, Conditional_MLP
  9 | from .networks.layers import get_layer_class
 10 | 
 11 | 
 12 | def get_activation_class(activation_type):
 13 |     """Utility function to return an activation function class based on the string description.
 14 | 
 15 |     Args:
 16 |         activation_type (str): The name for the activation function.
 17 |     
 18 |     Returns:
 19 |         (Function): The activation function to be used. 
 20 |     """
 21 |     if activation_type == 'relu':
 22 |         return torch.relu
 23 |     elif activation_type == 'sin':
 24 |         return torch.sin
 25 |     elif activation_type == 'softplus':
 26 |         return torch.nn.functional.softplus
 27 |     elif activation_type == 'lrelu':
 28 |         return torch.nn.functional.leaky_relu
 29 |     else:
 30 |         assert False, "activation type does not exist"
 31 | 
 32 | 
 33 | ####################################################
 34 | class NeuralField(nn.Module):
 35 | 
 36 |     def __init__(self,
 37 |         cfg          :dict,
 38 |         smpl_V       :torch.Tensor,
 39 |         smpl_F       :torch.Tensor,
 40 |         feat_dim     : int,
 41 |         out_dim      : int,
 42 |         pos_freq     : int,
 43 |         low_rank     : int,
 44 |         sigmoid      : bool = False,
 45 |     ):
 46 |         
 47 |         super().__init__()
 48 |         self.cfg = cfg
 49 |         self.smpl_V = smpl_V
 50 |         self.smpl_F = smpl_F
 51 |         self.feat_dim = feat_dim
 52 |         self.out_dim = out_dim
 53 |         self.pos_freq = pos_freq
 54 |         self.low_rank = low_rank
 55 |         self.sigmoid = sigmoid
 56 | 
 57 |         self.pos_dim = self.cfg.pos_dim
 58 |         self.c_dim = self.cfg.c_dim
 59 |         self.activation = self.cfg.activation
 60 |         self.layer_type = self.cfg.layer_type
 61 |         self.hidden_dim = self.cfg.hidden_dim
 62 |         self.num_layers = self.cfg.num_layers
 63 |         self.skip = self.cfg.skip
 64 |         self.feature_std = self.cfg.feature_std
 65 |         self.feature_bias = self.cfg.feature_bias
 66 | 
 67 | 
 68 |         self._init_dictionary()
 69 |         self._init_embedder()
 70 |         self._init_decoder()
 71 | 
 72 | 
 73 |     def _init_dictionary(self):
 74 |         """Initialize the feature dictionary object.
 75 |         """
 76 | 
 77 |         self.dictionary = FeatureDictionary(self.feat_dim, self.feature_std, self.feature_bias)
 78 |         self.dictionary.init_from_smpl_vertices(self.smpl_V)
 79 | 
 80 |     def _init_embedder(self):
 81 |         """Initialize positional embedding objects.
 82 |         """
 83 |         self.embedder = PositionalEncoding(self.pos_freq, self.pos_freq -1, input_dim=self.pos_dim)
 84 |         self.embed_dim = self.embedder.out_dim
 85 | 
 86 |     def _init_decoder(self):
 87 |         """Initialize the decoder object.
 88 |         """
 89 |         self.input_dim = self.embed_dim + self.feat_dim
 90 | 
 91 |         if self.c_dim <= 0:
 92 |             self.decoder = MLP(self.input_dim, self.out_dim, activation=get_activation_class(self.activation),
 93 |                                     bias=True, layer=get_layer_class(self.layer_type), num_layers=self.num_layers,
 94 |                                     hidden_dim=self.hidden_dim, skip=self.skip)
 95 |         else:
 96 |             self.decoder = Conditional_MLP(self.input_dim, self.c_dim, self.out_dim,  activation=get_activation_class(self.activation),
 97 |                                         bias=True, layer=get_layer_class(self.layer_type), num_layers=self.num_layers,
 98 |                                         hidden_dim=self.hidden_dim, skip=self.skip)    
 99 | 
100 | 
101 |         log.info("Total number of parameters {}".format(
102 |             sum(p.numel() for p in self.decoder.parameters()))\
103 |         )
104 | 
105 |     def forward_decoder(self, feats, local_coords, normal, return_h=False, f=None):
106 |         """Forward pass through the MLP decoder.
107 |             Args:
108 |                 feats (torch.FloatTensor): Feature tensor of shape [B, N, feat_dim]
109 |                 local_coords (torch.FloatTensor): Local coordinate tensor of shape [B, N, 3]
110 |                 normal (torch.FloatTensor): Normal tensor of shape [B, N, 3]
111 |                 return_h (bool): Whether to return the hidden states of the network.
112 |                 f (torch.FloatTensor): The conditional feature tensor of shape [B, c_dim]
113 |         
114 |         """
115 | 
116 |         if self.c_dim <= 0:
117 |             input = torch.cat([self.embedder(local_coords), feats], dim=-1)
118 |             return self.decoder(input, return_h=return_h, sigmoid=self.sigmoid)
119 |         else:
120 |             input = torch.cat([self.embedder(local_coords), feats], dim=-1)
121 |             if f is not None:
122 |                 c = torch.cat([f, normal], dim=-1)
123 |             else:
124 |                 c = normal
125 |             return self.decoder(input, c, return_h=return_h, sigmoid=self.sigmoid)
126 |         
127 | 
128 |     def forward(self, x, code_idx, pose_idx=None, return_h=False, f=None):
129 |         """Forward pass through the network.
130 |             Args:
131 |                 x (torch.FloatTensor): Coordinate tensor of shape [B, N, 3]
132 |                 code_idx (torch.LongTensor): Code index tensor of shape [B, 1]
133 |                 pose_idx (torch.LongTensor): SMPL_V index tensor of shape [B, 1]
134 |                 return_h (bool): Whether to return the hidden states of the network.
135 |                 f (torch.FloatTensor): The conditional feature tensor of shape [B, c_dim]
136 |         """
137 |         if pose_idx is None:
138 |             pose_idx = code_idx
139 |         feats, local_coords, normal = self.dictionary.interpolate(x, code_idx, self.smpl_V[pose_idx], self.smpl_F)
140 | 
141 |         return self.forward_decoder(feats, local_coords, normal, return_h=return_h, f=f)
142 |         
143 |     def sample(self, x, idx, return_h=False, f=None):
144 |         """Sample from the network.
145 |         """
146 |         feats, local_coords, normal = self.dictionary.interpolate_random(x, self.smpl_V[idx], self.smpl_F, self.low_rank)
147 | 
148 |         return self.forward_decoder(feats, local_coords, normal, return_h=return_h, f=f)
149 | 
150 |             
151 |     def regularization_loss(self, idx=None):
152 |         """Compute the L2 regularization loss.
153 |         """
154 | 
155 |         if idx is None:
156 |             return (self.dictionary.feature_codebooks**2).mean()
157 |         else:
158 |             return (self.dictionary.feature_codebooks[idx]**2).mean()
159 | 
160 | 
161 |     def finitediff_gradient(self, x, idx, eps=0.005, sample=False):
162 |         """Compute 3D gradient using finite difference.
163 | 
164 |         Args:
165 |             x (torch.FloatTensor): Coordinate tensor of shape [B, N, 3]
166 |         """
167 |         shape = x.shape
168 | 
169 |         eps_x = torch.tensor([eps, 0.0, 0.0], device=x.device)
170 |         eps_y = torch.tensor([0.0, eps, 0.0], device=x.device)
171 |         eps_z = torch.tensor([0.0, 0.0, eps], device=x.device)
172 | 
173 |         # shape: [B, 6, N, 3] -> [B, 6*N, 3]
174 |         x_new = torch.stack([x + eps_x, x + eps_y, x + eps_z,
175 |                            x - eps_x, x - eps_y, x - eps_z], dim=1).reshape(shape[0], -1, shape[-1])
176 |         
177 |         # shape: [B, 6*N, 3] -> [B, 6, N, 3]
178 |         if sample:
179 |             pred = self.sample(x_new, idx).reshape(shape[0], 6, -1)
180 |         else:
181 |             pred = self.forward(x_new, idx).reshape(shape[0], 6, -1)
182 |         grad_x = (pred[:, 0, ...] - pred[:, 3, ...]) / (eps * 2.0)
183 |         grad_y = (pred[:, 1, ...] - pred[:, 4, ...]) / (eps * 2.0)
184 |         grad_z = (pred[:, 2, ...] - pred[:, 5, ...]) / (eps * 2.0)
185 | 
186 |         return torch.stack([grad_x, grad_y, grad_z], dim=-1)
187 |     
188 |     
189 |     def forward_fitting(self, x, code, smpl_V, return_h=False, f=None):
190 |         """Forward pass through the network with a latent code input.
191 |             Args:
192 |                 x (torch.FloatTensor): Coordinate tensor of shape [1, N, 3]
193 |                 code (torch.FloatTensor): Latent code tensor of shape [1, n_vertices, c_dim]
194 |                 smpl_V (torch.FloatTensor): SMPL_V tensor of shape [1, n_vertices, 3]
195 |         """
196 | 
197 |         feats, local_coords, normal = self.dictionary.interpolate(x, 0, smpl_V, self.smpl_F, input_code=code)
198 | 
199 |         return self.forward_decoder(feats, local_coords, normal, return_h=return_h, f=f)
200 | 
201 |     def normal_fitting(self, x, code, smpl_V, eps=0.005):
202 |         shape = x.shape
203 | 
204 |         eps_x = torch.tensor([eps, 0.0, 0.0], device=x.device)
205 |         eps_y = torch.tensor([0.0, eps, 0.0], device=x.device)
206 |         eps_z = torch.tensor([0.0, 0.0, eps], device=x.device)
207 | 
208 |         # shape: [B, 6, N, 3] -> [B, 6*N, 3]
209 |         x_new = torch.stack([x + eps_x, x + eps_y, x + eps_z,
210 |                            x - eps_x, x - eps_y, x - eps_z], dim=1).reshape(shape[0], -1, shape[-1])
211 |         
212 |         pred = self.forward_fitting(x_new, code, smpl_V).reshape(shape[0], 6, -1)
213 |         grad_x = (pred[:, 0, ...] - pred[:, 3, ...]) / (eps * 2.0)
214 |         grad_y = (pred[:, 1, ...] - pred[:, 4, ...]) / (eps * 2.0)
215 |         grad_z = (pred[:, 2, ...] - pred[:, 5, ...]) / (eps * 2.0)
216 | 
217 |         return torch.stack([grad_x, grad_y, grad_z], dim=-1)
218 | 
219 |     def get_mean_feature(self, vert_idx=None):
220 |         if vert_idx is None:
221 |             return self.dictionary.feature_codebooks.mean(dim=0)
222 |         else:
223 |             return self.dictionary.feature_codebooks[:, vert_idx].mean(dim=0)
224 | 
225 |     def get_feature_by_idx(self, idx, vert_idx=None):
226 |         if vert_idx is None:
227 |             return self.dictionary.feature_codebooks[idx]
228 |         else:
229 |             return self.dictionary.feature_codebooks[idx][vert_idx]
230 |     
231 |     def replace_feature_by_idx(self, idx, feature, vert_idx=None):
232 |         if vert_idx is None:
233 |             self.dictionary.feature_codebooks[idx] = feature
234 |         else:
235 |             self.dictionary.feature_codebooks[idx][vert_idx] = feature
236 | 
237 |     def get_smpl_vertices_by_idx(self, idx):
238 |         return self.smpl_V[idx]
239 | 
240 |     def replace_smpl_vertices_by_idx(self, idx, smpl_V):
241 |         self.smpl_V[idx] = smpl_V
242 | 
--------------------------------------------------------------------------------
/lib/models/tracer.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn.functional as F
  3 | import torch.nn as nn
  4 | 
  5 | 
  6 | class SDFTracer(object):
  7 | 
  8 |     def __init__(self,
  9 |         cfg                  = None,
 10 |         camera_clamp : list  = [-4, 4],
 11 |         step_size    : float = 1.0,
 12 |         num_steps    : int   = 64, # samples for raymarching, iterations for sphere tracing
 13 |         min_dis      : float = 1e-3): 
 14 | 
 15 |         self.camera_clamp = camera_clamp
 16 |         self.step_size = step_size
 17 |         self.num_steps = num_steps
 18 |         self.min_dis = min_dis
 19 | 
 20 |         self.inv_num_steps = 1.0 / self.num_steps
 21 | 
 22 |     def __call__(self, *args, **kwargs):
 23 |         return self.forward(*args, **kwargs)
 24 | 
 25 |     def forward(self, nef, idx, ray_o, ray_d):
 26 |         """PyTorch implementation of sphere tracing.
 27 |             Args:
 28 |                 nef: Neural field object
 29 |                 idx: index of the subject, shape (B, )
 30 |                 ray_o: ray origin, shape (B, N, 3)
 31 |                 ray_d: ray direction, shape (B, N, 3)
 32 |         """
 33 | 
 34 |         # Distance from the ray origin
 35 |         t = torch.zeros(ray_o.shape[0], ray_o.shape[1], 1, device=ray_o.device)
 36 | 
 37 |         # Position in model space
 38 |         x = torch.addcmul(ray_o, ray_d, t)
 39 | 
 40 |         cond = torch.ones_like(t).bool()
 41 |         
 42 |         normal = torch.zeros_like(x)
 43 |         # This function is in fact differentiable, but we treat it as if it's not, because
 44 |         # it evaluates a very long chain of recursive neural networks (essentially a NN with depth of
 45 |         # ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit
 46 |         # locations, where additional quantities (normal, depth, segmentation) can be determined. The
 47 |         # gradients will propagate only to these locations. 
 48 |         with torch.no_grad():
 49 | 
 50 |             d = nef(x, idx)
 51 |             
 52 |             dprev = d.clone()
 53 | 
 54 |             # If cond is TRUE, the corresponding ray has neither hit the surface
 55 |             # nor exited the far clipping plane, i.e. it is still being traced.
 56 |             #cond = torch.ones_like(d).bool()[:,0]
 57 | 
 58 |             # hit is updated each step; it marks rays still within the far clipping plane.
 59 |             hit = torch.zeros_like(d).bool()
 60 |             
 61 |             for i in range(self.num_steps):
 62 |                 # 1. Check if ray hits.
 63 |                 #hit = (torch.abs(d) < self._MIN_DIS)[:,0] 
 64 |                 # 2. Check that the sphere tracing is not oscillating
 65 |                 #hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0]
 66 |                 
 67 |                 # 3. Check that the ray has not exited the far clipping plane.
 68 |                 #cond = (torch.abs(t) < self.clamp[1])[:,0]
 69 |                 
 70 |                 hit = (torch.abs(t) < self.camera_clamp[1])
 71 |                 
 72 |                 # 1. not hit surface
 73 |                 cond = cond & (torch.abs(d) > self.min_dis)
 74 | 
 75 |                 # 2. not oscillating
 76 |                 cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3)
 77 |                 
 78 |                 # 3. still within the far clipping plane
 79 |                 cond = cond & hit
 80 |                 
 81 |                 #cond = cond & ~hit
 82 |                 
 83 |                 # If no rays remain active, all rays have either hit or missed.
 84 |                 if not cond.any():
 85 |                     break
 86 | 
 87 |                 # Advance the x, by updating with a new t
 88 |                 x = torch.where(cond, torch.addcmul(ray_o, ray_d, t), x)
 89 |                 
 90 |                 # Store the previous distance
 91 |                 dprev = torch.where(cond, d, dprev)
 92 | 
 93 |                 # Update the distance to surface at x
 94 |                 d[cond] = nef(x, idx)[cond] * self.step_size
 95 | 
 96 |                 # Update the distance from origin 
 97 |                 t = torch.where(cond, t+d, t)
 98 |     
 99 |         # AABB cull 
100 | 
101 |         hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1,keepdim=True)
102 |         #hit = torch.ones_like(d).byte()[...,0]
103 |         
104 |         # The function returns
105 |         #  x: the final model-space coordinate of each ray
106 |         #  hit: a boolean tensor indicating whether each ray hit the surface
107 |         # (t, the distance from the origin, and d, the last SDF value, are
108 |         #  intermediate quantities and are not returned.)
109 |         
110 |         #if hit.any():
111 |         #    grad = nef.finitediff_gradient(x[hit], idx)
112 |         #    _normal = F.normalize(grad, p=2, dim=-1, eps=1e-5)
113 |         #    normal[hit] = _normal
114 |         
115 |         return x, hit
116 | 
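
# Minimal standalone usage sketch: trace a small batch of rays against a toy
# analytic SDF. The lambda below stands in for a NeuralField (or sdf_field.sample);
# the tracer only needs a callable with the signature nef(x, idx) -> sdf of shape [B, N, 1].
if __name__ == "__main__":
    tracer = SDFTracer()
    ray_o = torch.tensor([[[0.0, 0.0, -2.0]]]).repeat(1, 4, 1)   # [1, 4, 3] origins on -z
    ray_d = torch.tensor([[[0.0, 0.0, 1.0]]]).repeat(1, 4, 1)    # [1, 4, 3] all rays point +z
    sphere_sdf = lambda x, idx: x.norm(dim=-1, keepdim=True) - 0.5
    x, hit = tracer(sphere_sdf, None, ray_o, ray_d)
    print(x[0, 0], hit.all())   # hit point close to (0, 0, -0.5); all rays hit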
--------------------------------------------------------------------------------
/lib/models/trainer.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import logging as log
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | import numpy as np  
  7 | 
  8 | from .neural_fields import NeuralField
  9 | from .tracer import SDFTracer
 10 | from .losses import GANLoss
 11 | from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
 12 | from kaolin.ops.mesh import subdivide_trianglemesh
 13 | 
 14 | import wandb
 15 | 
 16 | class Trainer(nn.Module):
 17 | 
 18 |     def __init__(self, config, smpl_V, smpl_F, log_dir):
 19 | 
 20 |         super().__init__()
 21 | 
 22 |         # Set device to use
 23 |         self.device = torch.device('cuda')
 24 |         device_name = torch.cuda.get_device_name(device=self.device)
 25 |         log.info(f'Using {device_name} with CUDA v{torch.version.cuda}')
 26 | 
 27 |         self.cfg = config
 28 |         self.use_2d = self.cfg.use_2d_from_epoch >= 0
 29 |         self.use_2d_nrm = self.cfg.use_nrm_dis
 30 | 
 31 |         self.log_dir = log_dir
 32 |         self.log_dict = {}
 33 | 
 34 |         self.smpl_F = smpl_F.to(self.device).detach()
 35 |         self.smpl_V = smpl_V.to(self.device).detach()
 36 | 
 37 |         self.epoch = 0
 38 |         self.global_step = 0
 39 | 
 40 |         self.init_model()
 41 |         self.init_optimizer()
 42 |         self.init_log_dict()
 43 | 
 44 | 
 45 |     def init_model(self):
 46 |         """Initialize model.
 47 |         """
 48 | 
 49 |         log.info("Initializing geometry neural field...")
 50 | 
 51 |         self.sdf_field = NeuralField(self.cfg,
 52 |                                      self.smpl_V, 
 53 |                                      self.smpl_F,
 54 |                                      self.cfg.shape_dim,
 55 |                                      1,
 56 |                                      self.cfg.shape_freq,
 57 |                                      self.cfg.shape_pca_dim).to(self.device)
 58 | 
 59 |         log.info("Initializing texture neural field...")
 60 | 
 61 |         self.rgb_field = NeuralField(self.cfg,
 62 |                                      self.smpl_V, 
 63 |                                      self.smpl_F,
 64 |                                      self.cfg.color_dim,
 65 |                                      3,
 66 |                                      self.cfg.color_freq,
 67 |                                      self.cfg.color_pca_dim,
 68 |                                      sigmoid=True).to(self.device)
 69 |         
 70 |         self.tracer = SDFTracer(self.cfg)
 71 | 
 72 | 
 73 |         if self.use_2d:
 74 |             log.info("Initializing RGB discriminators...")
 75 |             self.gan_loss_rgb = GANLoss(self.cfg, self.cfg.gan_loss_type, auxillary=True).to(self.device)
 76 |             if self.use_2d_nrm:
 77 |                 log.info("Initializing normal discriminators...")
 78 |                 self.gan_loss_nrm = GANLoss(self.cfg, self.cfg.gan_loss_type).to(self.device)
 79 | 
 80 | 
 81 | 
 82 |     def init_optimizer(self):
 83 |         """Initialize optimizer.
 84 |         """
 85 |     
 86 |         decoder_params = []
 87 |         decoder_params.extend(list(self.sdf_field.decoder.parameters()))
 88 |         decoder_params.extend(list(self.rgb_field.decoder.parameters()))
 89 |         dictionary_params = []
 90 |         dictionary_params.extend(list(self.sdf_field.dictionary.parameters()))
 91 |         dictionary_params.extend(list(self.rgb_field.dictionary.parameters()))
 92 | 
 93 |         params = []
 94 |         params.append({'params': decoder_params,
 95 |                           'lr': self.cfg.lr_decoder,
 96 |                           "weight_decay": self.cfg.weight_decay})
 97 |         params.append({'params': dictionary_params,
 98 |                             'lr': self.cfg.lr_codebook})
 99 |         
100 |         
101 |         self.optimizer = torch.optim.Adam(params,
102 |                                     betas=(self.cfg.beta1, self.cfg.beta2))
103 |         
104 |         if self.use_2d:
105 |             dis_params = list(self.gan_loss_rgb.discriminator.parameters())
106 |             if self.use_2d_nrm:
107 |                 dis_params += list(self.gan_loss_nrm.discriminator.parameters())
108 |             
109 |             self.optimizer_d = torch.optim.Adam(dis_params,
110 |                                     lr=self.cfg.lr_dis,
111 |                                     betas=(0.0, self.cfg.beta2))
112 |     def init_log_dict(self):
113 |         """Custom logging dictionary.
114 |         """
115 |         self.log_dict['total_iter_count'] = 0
116 |         # 3D Loss
117 |         self.log_dict['Loss_3D/rgb_loss'] = 0
118 |         self.log_dict['Loss_3D/nrm_loss'] = 0
119 |         self.log_dict['Loss_3D/reco_loss'] = 0
120 |         self.log_dict['Loss_3D/reg_loss'] = 0
121 |         self.log_dict['Loss_3D/total_loss'] = 0
122 |         
123 |         # RGB Discriminator
124 |         self.log_dict['total_2D_count'] = 0
125 | 
126 |         self.log_dict['RGB_dis/D_loss'] = 0
127 |         self.log_dict['RGB_dis/penalty_loss']= 0
128 |         self.log_dict['RGB_dis/logits_real']= 0
129 |         self.log_dict['RGB_dis/logits_fake']= 0
130 |         
131 |         # Nrm Discriminator
132 |         self.log_dict['Nrm_dis/penalty_loss'] = 0
133 |         self.log_dict['Nrm_dis/loss_D'] = 0
134 |         self.log_dict['Nrm_dis/logits_real'] = 0
135 |         self.log_dict['Nrm_dis/logits_fake'] = 0
136 |         
137 |         # 2D Loss
138 |         self.log_dict['Loss_2D/RGB_G_loss'] = 0 
139 |         self.log_dict['Loss_2D/Nrm_G_loss'] = 0 
140 | 
141 | 
142 |     def step(self, epoch, n_iter, data):
143 |         """Training step.
144 |             1. 3D forward
145 |             2. 3D backward
146 |             3. 2D forward
147 |             4. 2D backward
148 |         """
149 |         # record stats
150 |         self.epoch = epoch
151 |         self.global_step = n_iter
152 | 
153 |         # Set inputs to device
154 |         self.set_inputs(data)
155 | 
156 |         # Train
157 |         self.optimizer.zero_grad()
158 |         self.forward_3D()
159 |         self.backward_3D()
160 | 
161 |         if self.use_2d and \
162 |            epoch >= self.cfg.use_2d_from_epoch and \
163 |            n_iter % self.cfg.train_2d_every_iter == 0:
164 |             self.forward_2D_rgb()
165 |             self.backward_2D_rgb()
166 |             if self.use_2d_nrm:
167 |                 self.forward_2D_nrm()
168 |                 self.backward_2D_nrm()
169 |             self.log_dict['total_2D_count'] += 1
170 | 
171 |         self.optimizer.step()
172 |         self.log_dict['total_iter_count'] += 1
173 | 
174 |     def set_inputs(self, data):
175 |         """Set inputs for training.
176 |         """
177 |         self.b_szie, self.n_vertice, _ = data['pts'].shape
178 |         self.idx = data['idx'].to(self.device)
179 | 
180 |         self.pts = data['pts'].to(self.device)
181 |         self.gts = data['sdf'].to(self.device)
182 |         self.rgb = data['rgb'].to(self.device) 
183 | 
184 |         # Downsample normals for faster training
185 |         self.nrm_pts = self.pts[:, :self.n_vertice//10].to(self.device)
186 |         self.nrm = data['nrm'][:, :self.n_vertice//10].to(self.device)
187 | 
188 |         if self.use_2d:
189 |             self.width =  data['rgb_image'].shape[2]
190 |             
191 |             self.label = data['label'].view(self.b_szie).to(self.device)
192 |             self.ray_dir = data['ray_dir_image'].view(self.b_szie,-1,3).to(self.device)
193 |             self.ray_ori = data['ray_ori_image'].view(self.b_szie,-1,3).to(self.device)
194 |             self.gt_xyz = data['xyz_image'].view(self.b_szie,-1,3).to(self.device)
195 |             self.gt_nrm = data['nrm_image'].view(self.b_szie,-1,3).to(self.device)
196 |             self.gt_rgb = data['rgb_image'].view(self.b_szie,-1,3).to(self.device)
197 |             self.gt_mask = data['mask_image'].view(self.b_szie,-1,1).to(self.device)
198 | 
199 |     def forward_3D(self):
200 |         """Forward pass for 3D.
201 |             predict sdf, rgb, nrm
202 |         """
203 |         self.pred_sdf, geo_h = self.sdf_field(self.pts, self.idx, return_h=True)
204 |         self.pred_rgb = self.rgb_field(self.pts, self.idx)
205 |         self.pred_nrm = self.sdf_field.finitediff_gradient(self.nrm_pts, self.idx)
206 |         self.pred_nrm = F.normalize(self.pred_nrm, p=2, dim=-1, eps=1e-5)
207 | 
208 |     def backward_3D(self):
209 |         """Backward pass for 3D.
210 |             Compute 3D loss
211 |         """
212 |         total_loss = 0.0
213 |         reco_loss = 0.0
214 |         rgb_loss = 0.0
215 |         reg_loss = 0.0
216 | 
217 |         reco_loss += torch.abs(self.pred_sdf - self.gts).mean()
218 | 
219 |         rgb_loss += torch.abs(self.pred_rgb - self.rgb).mean()
220 | 
221 |         #nrm_loss = torch.abs(1 - F.cosine_similarity(self.pred_nrm, self.nrm, dim=-1)).mean()
222 |         nrm_loss = torch.abs(self.pred_nrm - self.nrm).mean()
223 | 
224 |         reg_loss += self.sdf_field.regularization_loss()
225 |         reg_loss += self.rgb_field.regularization_loss()
226 | 
227 |         total_loss += reco_loss * self.cfg.lambda_sdf + \
228 |                       rgb_loss * self.cfg.lambda_rgb + \
229 |                       nrm_loss * self.cfg.lambda_nrm + \
230 |                       reg_loss * self.cfg.lambda_reg
231 | 
232 |         total_loss.backward()
233 | 
234 |         # Update logs
235 |         self.log_dict['Loss_3D/reco_loss'] += reco_loss.item()
236 |         self.log_dict['Loss_3D/rgb_loss'] += rgb_loss.item()
237 |         self.log_dict['Loss_3D/nrm_loss'] += nrm_loss.item()
238 |         self.log_dict['Loss_3D/reg_loss'] += reg_loss.item()
239 | 
240 |         self.log_dict['Loss_3D/total_loss'] += total_loss.item()
241 | 
242 |     def forward_2D_rgb(self):
243 |         """Forward pass for 2D rgb images.
244 |            Fix geometry (3D coordinates) and randomly sample texture.
245 |         """
246 |         x = self.gt_xyz
247 |         hit = self.gt_mask
248 | 
249 |         self.rgb_2d = self.rgb_field.sample(x.detach(), self.idx) * hit
250 | 
251 |     def forward_2D_nrm(self):
252 |         """Forward pass for 2D normal images. Randomly sample geometry and output normals.
253 |             This requires online ray tracing and is slow.
254 |             Cached points can be used as an approximation.
255 |         """
256 |         if self.cfg.use_cached_pts:
257 |             x = self.gt_xyz
258 |             hit = self.gt_mask
259 |         else:
260 |             x, hit = self.tracer(self.sdf_field.sample, self.idx, self.ray_ori, self.ray_dir)
261 | 
262 |         _normal = self.sdf_field.finitediff_gradient(x, self.idx, sample=True)
263 |         _normal = F.normalize(_normal, p=2, dim=-1, eps=1e-5)
264 |         self.nrm_2d = _normal * hit   
265 | 
266 |     def backward_2D_rgb(self):
267 |         """Backward pass for 2D rgb images.
268 |             Compute 2D adversarial loss for the discriminator and generator.
269 |         """
270 |    
271 |         total_2D_loss = 0.0
272 | 
273 |         # RGB GAN loss
274 |         disc_in_fake = self.rgb_2d.view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2)
275 |         disc_in_real = (self.gt_rgb * self.gt_mask).view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2)
276 |         disc_in_real.requires_grad = True  # for R1 gradient penalty
277 | 
278 |         self.optimizer_d.zero_grad()
279 |         d_loss, log = self.gan_loss_rgb(disc_in_real, disc_in_fake, mode='d', gt_label=self.label)
280 |         d_loss.backward()
281 |         self.optimizer_d.step()
282 | 
283 |         self.log_dict['RGB_dis/D_loss'] += log['loss_train/disc_loss']
284 |         self.log_dict['RGB_dis/penalty_loss'] += log['loss_train/r1_loss']
285 |         self.log_dict['RGB_dis/logits_real'] += log['loss_train/logits_real']
286 |         self.log_dict['RGB_dis/logits_fake'] += log['loss_train/logits_fake']
287 | 
288 |         g_loss, log = self.gan_loss_rgb(None, disc_in_fake, mode='g')
289 |         total_2D_loss += g_loss
290 |         total_2D_loss.backward()
291 | 
292 |         self.log_dict['Loss_2D/RGB_G_loss'] += log['loss_train/g_loss']
293 |     
294 |     def backward_2D_nrm(self):
295 |         """Backward pass for 2D normal images.
296 |             Compute 2D adversarial loss for the discriminator and generator.
297 |         """
298 | 
299 |         # Nrm GAN loss
300 |         total_2D_loss = 0.0
301 | 
302 |         disc_in_fake = self.nrm_2d.view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2)
303 |         disc_in_real = (self.gt_nrm * self.gt_mask).view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2)
304 |         disc_in_real.requires_grad = True  # for R1 gradient penalty
305 | 
306 |         self.optimizer_d.zero_grad()
307 |         d_loss, log = self.gan_loss_nrm(disc_in_real, disc_in_fake, mode='d')
308 |         d_loss.backward()
309 |         self.optimizer_d.step()
310 | 
311 |         self.log_dict['Nrm_dis/loss_D'] += log['loss_train/disc_loss']
312 |         self.log_dict['Nrm_dis/penalty_loss'] += log['loss_train/r1_loss']
313 |         self.log_dict['Nrm_dis/logits_real'] += log['loss_train/logits_real']
314 |         self.log_dict['Nrm_dis/logits_fake'] += log['loss_train/logits_fake']
315 |         
316 |         g_loss, log = self.gan_loss_nrm(None, disc_in_fake, mode='g')
317 |         total_2D_loss += g_loss
318 |         total_2D_loss.backward()
319 | 
320 |         self.log_dict['Loss_2D/Nrm_G_loss'] += log['loss_train/g_loss']
321 | 
322 |     def log(self, step, epoch):
323 |         """Log the training information.
324 |         """
325 |         log_text = 'STEP {} - EPOCH {}/{}'.format(step, epoch, self.cfg.epochs)
326 |         self.log_dict['Loss_3D/total_loss'] /= self.log_dict['total_iter_count'] + 1e-6
327 |         log_text += ' | total loss: {:>.3E}'.format(self.log_dict['Loss_3D/total_loss'])
328 |         self.log_dict['Loss_3D/reco_loss'] /= self.log_dict['total_iter_count'] + 1e-6
329 |         log_text += ' | Reco loss: {:>.3E}'.format(self.log_dict['Loss_3D/reco_loss'])
330 |         self.log_dict['Loss_3D/rgb_loss'] /= self.log_dict['total_iter_count'] + 1e-6
331 |         log_text += ' | rgb loss: {:>.3E}'.format(self.log_dict['Loss_3D/rgb_loss'])
332 |         self.log_dict['Loss_3D/nrm_loss'] /= self.log_dict['total_iter_count'] + 1e-6
333 |         log_text += ' | nrm loss: {:>.3E}'.format(self.log_dict['Loss_3D/nrm_loss'])
334 |         self.log_dict['Loss_3D/reg_loss'] /= self.log_dict['total_iter_count'] + 1e-6
335 |         log_text += ' | reg loss: {:>.3E}'.format(self.log_dict['Loss_3D/reg_loss'])
336 |         log.info(log_text)
337 | 
338 |         for key, value in self.log_dict.items():
339 |             if key.split('/')[0] in ('RGB_dis', 'Nrm_dis', 'Loss_2D'):
340 |                 value /= self.log_dict['total_2D_count'] + 1e-6
341 |             wandb.log({key: value}, step=step)
342 |         self.init_log_dict()
343 | 
344 |     def write_images(self, i):
345 |         """Write images to wandb.
346 |         """    
347 |         gen_img = self.rgb_2d.view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy()
348 |         gt_img = (self.gt_rgb * self.gt_mask).view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy()
349 |         wandb.log({"Generated Images": [wandb.Image(gen_img[i]) for i in range(self.b_szie)]}, step=i)
350 |         wandb.log({"Ground Truth Images": [wandb.Image(gt_img[i]) for i in range(self.b_szie)]}, step=i)
351 | 
352 |         if self.use_2d_nrm:
353 |             gen_nrm = self.nrm_2d.view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy() * 0.5 + 0.5
354 |             gt_nrm = (self.gt_nrm * self.gt_mask).view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy() * 0.5 + 0.5
355 |             gen_nrm = np.clip(gen_nrm, 0, 1)
356 |             gt_nrm = np.clip(gt_nrm, 0, 1)
357 |             wandb.log({"Generated Normals": [wandb.Image(gen_nrm[i]) for i in range(self.b_szie)]}, step=i)
358 |             wandb.log({"Ground Truth Normals": [wandb.Image(gt_nrm[i]) for i in range(self.b_szie)]}, step=i)
359 | 
360 |     def save_checkpoint(self, full=True, replace=False):
361 |         """Save the model checkpoint.
362 |         """
363 | 
364 |         if replace:
365 |             model_fname = os.path.join(self.log_dir, f'model-.pth')
366 |         else:
367 |             model_fname = os.path.join(self.log_dir, f'model-{self.epoch:04d}.pth')
368 | 
369 |         state = {
370 |             'epoch': self.epoch,
371 |             'global_step': self.global_step,
372 |             'log_dir': self.log_dir
373 |         }
374 | 
375 |         if full:
376 |             state['optimizer'] = self.optimizer.state_dict()
377 |             if self.use_2d:
378 |                 state['optimizer_d'] = self.optimizer_d.state_dict()
379 |     
380 | 
381 |         state['sdf'] = self.sdf_field.state_dict()
382 |         state['rgb'] = self.rgb_field.state_dict()
383 |         if self.use_2d:
384 |             state['D_rgb'] = self.gan_loss_rgb.state_dict()
385 |             if self.use_2d_nrm:
386 |                 state['D_nrm'] = self.gan_loss_nrm.state_dict()
387 | 
388 |         log.info(f'Saving model checkpoint to: {model_fname}')
389 |         torch.save(state, model_fname)
390 | 
391 | 
392 |     def load_checkpoint(self, fname):
393 |         """Load checkpoint.
394 |         """
395 |         try:
396 |             checkpoint = torch.load(fname, map_location=self.device)
397 |             log.info(f'Loading model checkpoint from: {fname}')
398 |         except FileNotFoundError:
399 |             log.warning(f'No checkpoint found at: {fname}, model randomly initialized.')
400 |             return
401 | 
402 |         # update meta info
403 |         self.epoch = checkpoint['epoch']
404 |         self.global_step = checkpoint['global_step']
405 |         self.log_dir = checkpoint['log_dir']
406 |         
407 |         self.sdf_field.load_state_dict(checkpoint['sdf'])
408 |         self.rgb_field.load_state_dict(checkpoint['rgb'])
409 |         if self.use_2d:
410 |             if 'D_rgb' in checkpoint:
411 |                 self.gan_loss_rgb.load_state_dict(checkpoint['D_rgb'])
412 |             if self.use_2d_nrm and 'D_nrm' in checkpoint:
413 |                 self.gan_loss_nrm.load_state_dict(checkpoint['D_nrm'])
414 | 
415 |         if 'optimizer' in checkpoint:
416 |             self.optimizer.load_state_dict(checkpoint['optimizer'])
417 |             if self.use_2d:
418 |                 self.optimizer_d.load_state_dict(checkpoint['optimizer_d'])
419 | 
420 |         log.info(f'Loaded checkpoint at epoch {self.epoch} with global step {self.global_step}.')
421 | 
422 | '''
423 | #######################################################################################################################################
424 |     
425 |     def reconstruction(self, epoch, i, subdivide, res=300):
426 |         
427 |         torch.cuda.empty_cache()
428 | 
429 |         with torch.no_grad():
430 |             h = self._marching_cubes (i, subdivide=subdivide, res=res)
431 |         h.export(os.path.join(self.log_dir, '%03d_reco_src-%03d.obj' % (epoch, i)) )
432 |         
433 |         torch.cuda.empty_cache()
434 | 
435 | 
436 |     def _marching_cubes (self, i, subdivide=True, res=300):
437 | 
438 |         width = res
439 |         window_x = torch.linspace(-1., 1., steps=width, device='cuda')
440 |         window_y = torch.linspace(-1., 1., steps=width, device='cuda')
441 |         window_z = torch.linspace(-1., 1., steps=width, device='cuda')
442 | 
443 |         coord = torch.stack(torch.meshgrid(window_x, window_y, window_z)).permute(1, 2, 3, 0).reshape(1, -1, 3).contiguous()
444 | 
445 |         
446 |         # Debug smpl grid
447 |         smpl_vertice = self.smpl_V[i]
448 |         d = trimesh.Trimesh(vertices=smpl_vertice.cpu().detach().numpy(), 
449 |                     faces=self.smpl_F.cpu().detach().numpy())
450 |         d.export(os.path.join(self.log_dir, 'smpl_sub_%03d.obj' % (i)) )
451 |         
452 | 
453 |         idx = torch.tensor([i], dtype=torch.long, device = torch.device('cuda')).view(1).detach()
454 |         _points = torch.split(coord, int(2*1e6), dim=1)
455 |         voxels = []
456 |         for _p in _points:
457 |             pred_sdf = self.sdf_field(_p, idx)
458 |             voxels.append(pred_sdf)
459 | 
460 |         voxels = torch.cat(voxels, dim=1)
461 |         voxels = voxels.reshape(1, width, width, width)
462 |         
463 |         vertices, faces = voxelgrids_to_trianglemeshes(voxels, iso_value=0.)
464 |         vertices = ((vertices[0].reshape(1, -1, 3) - 0.5) / (width/2)) - 1.0
465 |         faces = faces[0]
466 | 
467 |         if subdivide:
468 |             vertices, faces = subdivide_trianglemesh(vertices, faces, iterations=1)
469 | 
470 |         pred_rgb = self.rgb_field(vertices, idx+1, pose_idx=idx)            
471 |         
472 |         h = trimesh.Trimesh(vertices=vertices[0].cpu().detach().numpy(), 
473 |                 faces=faces.cpu().detach().numpy(), 
474 |                 vertex_colors=pred_rgb[0].cpu().detach().numpy())
475 | 
476 |         # remove disconnected parts of the mesh, keeping the largest component
477 |         connected_comp = h.split(only_watertight=False)
478 |         max_area = 0
479 |         max_comp = None
480 |         for comp in connected_comp:
481 |             if comp.area > max_area:
482 |                 max_area = comp.area
483 |                 max_comp = comp
484 |         h = max_comp
485 |     
486 |         trimesh.repair.fix_inversion(h)
487 | 
488 |         return h
489 | '''
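
# Minimal standalone sketch of the per-parameter-group optimizer pattern used in
# init_optimizer above: decoder weights get weight decay, codebook entries get their
# own learning rate. The modules and rates below are placeholders, not the values
# from config.yaml.
if __name__ == "__main__":
    decoder = nn.Linear(8, 1)
    codebook = nn.Embedding(10, 8)
    optimizer = torch.optim.Adam(
        [{'params': decoder.parameters(), 'lr': 1e-3, 'weight_decay': 1e-4},
         {'params': codebook.parameters(), 'lr': 1e-2}],
        betas=(0.9, 0.999))
    print([g['lr'] for g in optimizer.param_groups])   # [0.001, 0.01]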
--------------------------------------------------------------------------------
/lib/ops/mesh/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | from .area_weighted_distribution import area_weighted_distribution
10 | from .random_face import random_face
11 | from .point_sample import point_sample
12 | from .sample_surface import sample_surface
13 | from .sample_near_surface import sample_near_surface
14 | from .sample_uniform import sample_uniform
15 | from .load_obj import load_obj
16 | from .normalize import normalize
17 | from .closest_point import *
18 | from .closest_tex import closest_tex
19 | from .barycentric_coordinates import barycentric_coordinates
20 | from .sample_tex import sample_tex
21 | from .per_face_normals import per_face_normals
22 | from .per_vertex_normals import per_vertex_normals
23 | 
--------------------------------------------------------------------------------
/lib/ops/mesh/area_weighted_distribution.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | from .per_face_normals import per_face_normals
11 | 
12 | def area_weighted_distribution(
13 |     V : torch.Tensor,
14 |     F : torch.Tensor, 
15 |     normals : torch.Tensor = None):
16 |     """Construct discrete area weighted distribution over triangle mesh.
17 | 
18 |     Args:
19 |         V (torch.Tensor): #V, 3 array of vertices
20 |         F (torch.Tensor): #F, 3 array of indices
21 |         normals (torch.Tensor): normals (if precomputed)
22 |         eps (float): epsilon
23 |     
24 |     Returns:
25 |         (torch.distributions): Distribution to be used
26 |     """
27 | 
28 |     if normals is None:
29 |         normals = per_face_normals(V, F)
30 |     areas = torch.norm(normals, p=2, dim=1) * 0.5
31 |     areas /= torch.sum(areas) + 1e-10
32 |     
33 |     # Discrete PDF over triangles
34 |     return torch.distributions.Categorical(areas.view(-1))
35 | 
36 | 
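
# Minimal standalone usage sketch: two triangles, where the second has four times
# the area and should therefore be sampled roughly four times as often. The vertex
# data below is illustrative only.
if __name__ == "__main__":
    V = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                      [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
    F = torch.tensor([[0, 1, 2], [0, 3, 4]])
    dist = area_weighted_distribution(V, F)
    samples = dist.sample((10000,))
    print((samples == 1).float().mean())   # approximately 0.8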
--------------------------------------------------------------------------------
/lib/ops/mesh/barycentric_coordinates.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | import numpy as np
11 | 
12 | # Same API as https://github.com/libigl/libigl/blob/main/include/igl/barycentric_coordinates.cpp
13 | 
14 | def barycentric_coordinates(
15 |     points : torch.Tensor, 
16 |     A : torch.Tensor,
17 |     B : torch.Tensor,
18 |     C : torch.Tensor):
19 |     """
20 |     Return barycentric coordinates for a given set of points and triangle vertices
21 | 
22 |     Args:
23 |         points (torch.FloatTensor): [N, 3]
24 |         A (torch.FloatTensor): [N, 3] vertex0
25 |         B (torch.FloatTensor): [N, 3] vertex1
26 |         C (torch.FloatTensor): [N, 3] vertex2
27 |     
28 |     Returns:
 29 |         (torch.FloatTensor): barycentric coordinates of shape [N, 3]
30 |     """
31 | 
32 |     v0 = B-A
33 |     v1 = C-A
34 |     v2 = points-A
35 |     d00 = (v0*v0).sum(dim=-1)
36 |     d01 = (v0*v1).sum(dim=-1)
37 |     d11 = (v1*v1).sum(dim=-1)
38 |     d20 = (v2*v0).sum(dim=-1)
39 |     d21 = (v2*v1).sum(dim=-1)
40 |     denom = d00*d11 - d01*d01
41 |     L = torch.zeros(points.shape[0], 3, device=points.device)
42 |     # Warning: This clipping may cause undesired behaviour
43 |     L[...,1] = torch.clip((d11*d20 - d01*d21)/denom, 0.0, 1.0)
44 |     L[...,2] = torch.clip((d00*d21 - d01*d20)/denom, 0.0, 1.0)
45 |     L[...,0] = torch.clip(1.0 - (L[...,1] + L[...,2]), 0.0, 1.0)
46 |     return L
47 | 
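
# Minimal standalone usage sketch: the centroid of a triangle should come back with
# weights close to [1/3, 1/3, 1/3]. Inputs are batched, one triangle per row.
if __name__ == "__main__":
    A = torch.tensor([[0.0, 0.0, 0.0]])
    B = torch.tensor([[1.0, 0.0, 0.0]])
    C = torch.tensor([[0.0, 1.0, 0.0]])
    p = (A + B + C) / 3.0
    print(barycentric_coordinates(p, A, B, C))   # ~[[0.3333, 0.3333, 0.3333]]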
--------------------------------------------------------------------------------
/lib/ops/mesh/closest_point.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
  2 | #
  3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
  4 | # and proprietary rights in and to this software, related documentation
  5 | # and any modifications thereto.  Any use, reproduction, disclosure or
  6 | # distribution of this software and related documentation without an express
  7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
  8 | 
  9 | # Closest point function + texture sampling
 10 | # https://en.wikipedia.org/wiki/Closest_point_method
 11 | 
 12 | import torch
 13 | import numpy as np
 14 | from .barycentric_coordinates import barycentric_coordinates
 15 | from tqdm import tqdm
 16 | from kaolin.ops.mesh import index_vertices_by_faces, check_sign
 17 | from kaolin import _C
 18 | 
 19 | 
 20 | class _UnbatchedTriangleDistanceCuda(torch.autograd.Function):
 21 |     @staticmethod
 22 |     def forward(ctx, points, face_vertices):
 23 |         num_points = points.shape[0]
 24 |         num_faces = face_vertices.shape[0]
 25 |         min_dist = torch.zeros((num_points), device=points.device, dtype=points.dtype)
 26 |         min_dist_idx = torch.zeros((num_points), device=points.device, dtype=torch.long)
 27 |         dist_type = torch.zeros((num_points), device=points.device, dtype=torch.int32)
 28 |         _C.metrics.unbatched_triangle_distance_forward_cuda(
 29 |             points, face_vertices, min_dist, min_dist_idx, dist_type)
 30 |         ctx.save_for_backward(points.contiguous(), face_vertices.contiguous(),
 31 |                               min_dist_idx, dist_type)
 32 |         ctx.mark_non_differentiable(min_dist_idx, dist_type)
 33 |         return min_dist, min_dist_idx, dist_type
 34 | 
 35 |     @staticmethod
 36 |     def backward(ctx, grad_dist, grad_face_idx, grad_dist_type):
 37 |         points, face_vertices, face_idx, dist_type = ctx.saved_tensors
 38 |         grad_dist = grad_dist.contiguous()
 39 |         grad_points = torch.zeros_like(points)
 40 |         grad_face_vertices = torch.zeros_like(face_vertices)
 41 |         _C.metrics.unbatched_triangle_distance_backward_cuda(
 42 |             grad_dist, points, face_vertices, face_idx, dist_type,
 43 |             grad_points, grad_face_vertices)
 44 |         return grad_points, grad_face_vertices
 45 | 
 46 | 
 47 | def _compute_dot(p1, p2):
 48 |     return p1[..., 0] * p2[..., 0] + \
 49 |         p1[..., 1] * p2[..., 1] + \
 50 |         p1[..., 2] * p2[..., 2]
 51 | 
 52 | def _project_edge(vertex, edge, point):
 53 |     point_vec = point - vertex
 54 |     length = _compute_dot(edge, edge)
 55 |     return _compute_dot(point_vec, edge) / length
 56 | 
 57 | def _project_plane(vertex, normal, point):
 58 |     point_vec = point - vertex
 59 |     unit_normal = normal / torch.norm(normal, dim=-1, keepdim=True)
 60 |     dist = _compute_dot(point_vec, unit_normal)
 61 |     return point - unit_normal * dist.view(-1, 1)
 62 | 
 63 | def _is_not_above(vertex, edge, norm, point):
 64 |     edge_norm = torch.cross(norm, edge, dim=-1)
 65 |     return _compute_dot(edge_norm.view(1, -1, 3),
 66 |                         point.view(-1, 1, 3) - vertex.view(1, -1, 3)) <= 0
 67 | 
 68 | def _point_at(vertex, edge, proj):
 69 |     return vertex + edge * proj.view(-1, 1)
 70 | 
 71 | 
 72 | def _unbatched_naive_point_to_mesh_distance(points, face_vertices):
 73 |     """
 74 |     description of distance type:
 75 |         - 0: distance to face
  76 |         - 1: distance to vertex 0
  77 |         - 2: distance to vertex 1
  78 |         - 3: distance to vertex 2
 79 |         - 4: distance to edge 0-1
 80 |         - 5: distance to edge 1-2
 81 |         - 6: distance to edge 2-0
 82 |     Args:
 83 |         points (torch.Tensor): of shape (num_points, 3).
  84 |         face_vertices (torch.FloatTensor): of shape (num_faces, 3, 3).
 85 |     Returns:
  86 |         (torch.Tensor, torch.LongTensor, torch.IntTensor, torch.Tensor):
  87 |             - distance, of shape (num_points).
  88 |             - face_idx, of shape (num_points).
  89 |             - distance_type, of shape (num_points).
  90 |             - counter_p, the closest point on the mesh, of shape (num_points, 3).
 91 |     """
 92 |     num_points = points.shape[0]
 93 |     num_faces = face_vertices.shape[0]
 94 | 
 95 |     device = points.device
 96 |     dtype = points.dtype
 97 | 
 98 |     v1 = face_vertices[:, 0]
 99 |     v2 = face_vertices[:, 1]
100 |     v3 = face_vertices[:, 2]
101 | 
102 |     e21 = v2 - v1
103 |     e32 = v3 - v2
104 |     e13 = v1 - v3
105 | 
106 |     normals = -torch.cross(e21, e13)
107 | 
108 |     uab = _project_edge(v1.view(1, -1, 3), e21.view(1, -1, 3), points.view(-1, 1, 3))
109 |     ubc = _project_edge(v2.view(1, -1, 3), e32.view(1, -1, 3), points.view(-1, 1, 3))
110 |     uca = _project_edge(v3.view(1, -1, 3), e13.view(1, -1, 3), points.view(-1, 1, 3))
111 | 
112 |     is_type1 = (uca > 1.) & (uab < 0.)
113 |     is_type2 = (uab > 1.) & (ubc < 0.)
114 |     is_type3 = (ubc > 1.) & (uca < 0.)
115 |     is_type4 = (uab >= 0.) & (uab <= 1.) & _is_not_above(v1, e21, normals, points)
116 |     is_type5 = (ubc >= 0.) & (ubc <= 1.) & _is_not_above(v2, e32, normals, points)
117 |     is_type6 = (uca >= 0.) & (uca <= 1.) & _is_not_above(v3, e13, normals, points)
118 |     is_type0 = ~(is_type1 | is_type2 | is_type3 | is_type4 | is_type5 | is_type6)
119 | 
120 |     face_idx = torch.zeros(num_points, device=device, dtype=torch.long)
121 |     all_closest_points = torch.zeros((num_points, num_faces, 3), device=device,
122 |                                      dtype=dtype)
123 | 
124 |     all_type0_idx = torch.where(is_type0)
125 |     all_type1_idx = torch.where(is_type1)
126 |     all_type2_idx = torch.where(is_type2)
127 |     all_type3_idx = torch.where(is_type3)
128 |     all_type4_idx = torch.where(is_type4)
129 |     all_type5_idx = torch.where(is_type5)
130 |     all_type6_idx = torch.where(is_type6)
131 | 
132 |     all_types = is_type1.int() + is_type2.int() * 2 + is_type3.int() * 3 + \
133 |         is_type4.int() * 4 + is_type5.int() * 5 + is_type6.int() * 6
134 | 
135 |     all_closest_points[all_type0_idx] = _project_plane(
136 |         v1[all_type0_idx[1]], normals[all_type0_idx[1]], points[all_type0_idx[0]])
137 |     all_closest_points[all_type1_idx] = v1.view(-1, 3)[all_type1_idx[1]]
138 |     all_closest_points[all_type2_idx] = v2.view(-1, 3)[all_type2_idx[1]]
139 |     all_closest_points[all_type3_idx] = v3.view(-1, 3)[all_type3_idx[1]]
140 |     all_closest_points[all_type4_idx] = _point_at(v1[all_type4_idx[1]], e21[all_type4_idx[1]],
141 |                                                   uab[all_type4_idx])
142 |     all_closest_points[all_type5_idx] = _point_at(v2[all_type5_idx[1]], e32[all_type5_idx[1]],
143 |                                                   ubc[all_type5_idx])
144 |     all_closest_points[all_type6_idx] = _point_at(v3[all_type6_idx[1]], e13[all_type6_idx[1]],
145 |                                                   uca[all_type6_idx])
146 |     all_vec = (all_closest_points - points.view(-1, 1, 3))
147 |     all_dist = _compute_dot(all_vec, all_vec)
148 | 
149 |     _, min_dist_idx = torch.min(all_dist, dim=-1)
150 |     dist_type = all_types[torch.arange(num_points, device=device), min_dist_idx]
151 |     torch.cuda.synchronize()
152 | 
153 |     # Recompute the shortest distances
154 |     # This reduce the backward pass to the closest faces instead of all faces
155 |     # O(num_points) vs O(num_points * num_faces)
156 |     selected_face_vertices = face_vertices[min_dist_idx]
157 |     v1 = selected_face_vertices[:, 0]
158 |     v2 = selected_face_vertices[:, 1]
159 |     v3 = selected_face_vertices[:, 2]
160 | 
161 |     e21 = v2 - v1
162 |     e32 = v3 - v2
163 |     e13 = v1 - v3
164 | 
165 |     normals = -torch.cross(e21, e13)
166 | 
167 |     uab = _project_edge(v1, e21, points)
168 |     ubc = _project_edge(v2, e32, points)
169 |     uca = _project_edge(v3, e13, points)
170 | 
171 |     counter_p = torch.zeros((num_points, 3), device=device, dtype=dtype)
172 | 
173 |     cond = (dist_type == 1)
174 |     counter_p[cond] = v1[cond]
175 | 
176 |     cond = (dist_type == 2)
177 |     counter_p[cond] = v2[cond]
178 | 
179 |     cond = (dist_type == 3)
180 |     counter_p[cond] = v3[cond]
181 | 
182 |     cond = (dist_type == 4)
183 |     counter_p[cond] = _point_at(v1, e21, uab)[cond]
184 | 
185 |     cond = (dist_type == 5)
186 |     counter_p[cond] = _point_at(v2, e32, ubc)[cond]
187 | 
188 |     cond = (dist_type == 6)
189 |     counter_p[cond] = _point_at(v3, e13, uca)[cond]
190 | 
191 |     cond = (dist_type == 0)
192 |     counter_p[cond] = _project_plane(v1, normals, points)[cond]
193 |     min_dist = torch.sum((counter_p - points) ** 2, dim=-1)
194 | 
195 |     return min_dist, min_dist_idx, dist_type, counter_p
196 | 
197 | 
198 | def _find_closest_point(points, face_vertices, cur_face_idx, cur_dist_type):
 199 |     """Returns the closest point on the mesh for a set of query points.
 200 |         points (torch.Tensor): of shape (num_points, 3).
 201 |         face_vertices (torch.FloatTensor): of shape (num_faces, 3, 3).
202 |         cur_face_idx (torch.LongTensor): of shape (num_points,).
203 |         cur_dist_type (torch.LongTensor): of shape (num_points,).
204 | 
205 |     Returns:
206 |         (torch.FloatTensor): counter_p of shape (num_points, 3).
207 |     """
208 |     num_points = points.shape[0]
209 |     device = points.device
210 |     dtype = points.dtype
211 |     selected_face_vertices = face_vertices[cur_face_idx]
212 | 
213 |     v1 = selected_face_vertices[:, 0]
214 |     v2 = selected_face_vertices[:, 1]
215 |     v3 = selected_face_vertices[:, 2]
216 | 
217 |     e21 = v2 - v1
218 |     e32 = v3 - v2
219 |     e13 = v1 - v3
220 | 
221 |     normals = -torch.cross(e21, e13)
222 | 
223 |     uab = _project_edge(v1, e21, points)
224 |     ubc = _project_edge(v2, e32, points)
225 |     uca = _project_edge(v3, e13, points)
226 | 
227 |     counter_p = torch.zeros((num_points, 3), device=device, dtype=dtype)
228 | 
229 |     cond = (cur_dist_type == 1)
230 |     counter_p[cond] = v1[cond]
231 | 
232 |     cond = (cur_dist_type == 2)
233 |     counter_p[cond] = v2[cond]
234 | 
235 |     cond = (cur_dist_type == 3)
236 |     counter_p[cond] = v3[cond]
237 | 
238 |     cond = (cur_dist_type == 4)
239 |     counter_p[cond] = _point_at(v1, e21, uab)[cond]
240 | 
241 |     cond = (cur_dist_type == 5)
242 |     counter_p[cond] = _point_at(v2, e32, ubc)[cond]
243 | 
244 |     cond = (cur_dist_type == 6)
245 |     counter_p[cond] = _point_at(v3, e13, uca)[cond]
246 | 
247 |     cond = (cur_dist_type == 0)
248 |     counter_p[cond] = _project_plane(v1, normals, points)[cond]
249 | 
250 | 
251 |     return counter_p
252 | 
253 | def closest_point(
254 |     V : torch.Tensor, 
255 |     F : torch.Tensor,
256 |     points : torch.Tensor,
257 |     split_size : int = 5*10**3):
258 | 
 259 |     """Returns the signed distance and closest point on the mesh for a set of points.
260 | 
261 |         V (torch.FloatTensor): mesh vertices of shape [V, 3] 
262 |         F (torch.LongTensor): mesh face indices of shape [F, 3]
263 |         points (torch.FloatTensor): sample locations of shape [N, 3]
264 | 
265 |     Returns:
 266 |         (torch.FloatTensor): signed distances of shape [N, 1]
 267 |         (torch.FloatTensor): projected points of shape [N, 3]
 268 |         (torch.LongTensor): face indices of shape [N]
269 |     """
270 | 
271 |     V = V.cuda().contiguous()
272 |     F = F.cuda().contiguous()
273 | 
274 |     mesh = index_vertices_by_faces(V.unsqueeze(0), F).squeeze(0)
275 | 
276 |     _points = torch.split(points, split_size)
277 | 
278 |     dists = []
279 |     pts = []
280 |     indices = []
281 |     for _p in _points:
282 |         p = _p.cuda().contiguous()
283 |         sign = check_sign(V.unsqueeze(0), F, p.unsqueeze(0)).squeeze(0)
284 |         dist, hit_tidx, dist_type, hit_pts = _unbatched_naive_point_to_mesh_distance(p, mesh)
285 |         dist = torch.where (sign, -torch.sqrt(dist), torch.sqrt(dist))
286 |         dists.append(dist)
287 |         pts.append(hit_pts)
288 |         indices.append(hit_tidx)
289 | 
290 |     return torch.cat(dists)[...,None], torch.cat(pts), torch.cat(indices)
291 | 
292 | def batched_closest_point(
293 |     V : torch.Tensor, 
294 |     F : torch.Tensor,
295 |     points : torch.Tensor):
296 | 
 297 |     """Returns the signed distance, closest point, face index, and barycentric weights for a set of points.
298 | 
299 |         V (torch.FloatTensor): mesh vertices of shape [B, V, 3] 
300 |         F (torch.LongTensor): mesh face indices of shape [F, 3]
301 |         points (torch.FloatTensor): sample locations of shape [B, N, 3]
302 | 
303 |     Returns:
 304 |         (torch.FloatTensor): signed distances of shape [B, N, 1]
 305 |         (torch.FloatTensor): projected points of shape [B, N, 3]
 306 |         (torch.LongTensor): face indices of shape [B, N]; (torch.FloatTensor): barycentric weights of shape [B, N, 3]
307 |     """
308 | 
309 |     V = V.cuda().contiguous()
310 |     F = F.cuda().contiguous()
311 | 
312 |     batch_size = V.shape[0]
313 |     num_points = V.shape[1]
314 | 
315 |     dists = []
316 |     pts = []
317 |     indices = []
318 |     weights = []
319 | 
320 |     sign = check_sign(V, F, points)
321 | 
322 |     for i in range(batch_size):
323 |         mesh = V[i][F]
324 |         p = points[i]
325 |         dist, hit_tidx, dist_type, hit_pts = _unbatched_naive_point_to_mesh_distance(p, mesh)
326 |         dist = torch.where (sign[i], -torch.sqrt(dist), torch.sqrt(dist))
327 |         hitface = F[hit_tidx.view(-1)] # [ Ns , 3]
328 | 
329 | 
330 |         BC = barycentric_coordinates(hit_pts, V[i][hitface[:,0]],
331 |                                     V[i][hitface[:,1]], V[i][hitface[:,2]])
332 | 
333 |         dists.append(dist)
334 |         pts.append(hit_pts)
335 |         indices.append(hit_tidx)
336 |         weights.append(BC)
337 |     
338 |     return torch.stack(dists)[...,None], torch.stack(pts), torch.stack(indices), torch.stack(weights)
339 | 
340 | 
341 | def closest_point_fast(
342 |     V : torch.Tensor, 
343 |     F : torch.Tensor,
344 |     points : torch.Tensor):
345 | 
 346 |     """Returns the signed distance and closest point on the mesh for a set of points.
347 | 
348 |         V (torch.FloatTensor): mesh vertices of shape [V, 3] 
349 |         F (torch.LongTensor): mesh face indices of shape [F, 3]
350 |         points (torch.FloatTensor): sample locations of shape [N, 3]
351 | 
352 |     Returns:
353 |         (torch.FloatTensor): signed distances of shape [N, 1]
354 |         (torch.FloatTensor): projected points of shape [N, 3]
 355 |         (torch.LongTensor): face indices of shape [N]
356 |     """
357 | 
358 |     face_vertices =  V[F]
359 |     sign = check_sign(V.unsqueeze(0), F, points.unsqueeze(0)).squeeze(0)
360 | 
361 |     if points.is_cuda:
362 |         cur_dist, cur_face_idx, cur_dist_type = _UnbatchedTriangleDistanceCuda.apply(
363 |                 points, face_vertices)
364 |     else:
365 |         cur_dist, cur_face_idx, cur_dist_type = _unbatched_naive_point_to_mesh_distance(
366 |                 points, face_vertices)
367 | 
368 |     hit_point = _find_closest_point(points, face_vertices, cur_face_idx, cur_dist_type)
369 | 
370 |     dist = torch.where (sign, -torch.sqrt(cur_dist), torch.sqrt(cur_dist))
371 | 
372 | 
373 |     return dist[...,None], hit_point, cur_face_idx
374 | 
375 | 
376 | def batched_closest_point_fast(
377 |     V : torch.Tensor, 
378 |     F : torch.Tensor,
379 |     points : torch.Tensor):
380 | 
 381 |     """Returns the signed distance, closest point, face index, and barycentric weights for a set of points.
382 | 
383 |         V (torch.FloatTensor): mesh vertices of shape [B, V, 3] 
384 |         F (torch.LongTensor): mesh face indices of shape [F, 3]
385 |         points (torch.FloatTensor): sample locations of shape [B, N, 3]
386 | 
387 |     Returns:
 388 |         (torch.FloatTensor): signed distances of shape [B, N, 1]
 389 |         (torch.FloatTensor): projected points of shape [B, N, 3]
 390 |         (torch.LongTensor): face indices of shape [B, N]; (torch.FloatTensor): barycentric weights of shape [B, N, 3]
391 |     """
392 | 
393 |     batch_size = V.shape[0]
394 | 
395 |     dists = []
396 |     indices = []
397 |     weights = []
398 |     pts = []
399 | 
400 |     for i in range(batch_size):
401 |         cur_dist, hit_point, cur_face_idx = closest_point_fast (V[i], F, points[i])
402 |         hitface = F[cur_face_idx.view(-1)] # [ N , 3]
403 | 
404 |         dists.append(cur_dist)
405 |         pts.append(hit_point)
406 |         indices.append(cur_face_idx)
407 |         weights.append(barycentric_coordinates(hit_point, V[i][hitface[:,0]],
408 |                                     V[i][hitface[:,1]], V[i][hitface[:,2]]))
409 |     
410 |     return torch.stack(dists, dim=0), torch.stack(pts, dim=0), \
411 |            torch.stack(indices, dim=0), torch.stack(weights, dim=0)
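
# Minimal standalone usage sketch (requires a CUDA device and kaolin, like the rest
# of this module): signed distances from two query points to a small tetrahedron.
# The inside point should come back negative, the outside point positive.
if __name__ == "__main__":
    V = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], device='cuda')
    F = torch.tensor([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]], device='cuda')
    pts = torch.tensor([[0.1, 0.1, 0.1], [1.0, 1.0, 1.0]], device='cuda')
    sd, proj, fidx = closest_point_fast(V, F, pts)
    print(sd.view(-1))   # negative for the inside point, positive for the outside one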
--------------------------------------------------------------------------------
/lib/ops/mesh/closest_tex.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | import numpy as np
11 | from .barycentric_coordinates import barycentric_coordinates
12 | from .closest_point import closest_point, closest_point_fast
13 | from .sample_tex import sample_tex
14 | from .per_face_normals import per_face_normals
15 | 
16 | import time
17 | def closest_tex(
18 |     V : torch.Tensor, 
19 |     F : torch.Tensor,
20 |     TV : torch.Tensor,
21 |     TF : torch.Tensor,
22 |     materials,
23 |     points : torch.Tensor):
24 |     """Returns the closest texture for a set of points.
25 | 
26 |         V (torch.FloatTensor): mesh vertices of shape [V, 3] 
27 |         F (torch.LongTensor): mesh face indices of shape [F, 3]
 28 |         TV (torch.FloatTensor): texture (UV) coordinates of shape [VT, 2]
 29 |         TF (torch.LongTensor): per-face texture indices of shape [F, 4] (3 UV indices + material id)
 30 |         materials: per-material textures/colors as returned by load_obj
31 |         points (torch.FloatTensor): sample locations of shape [N, 3]
32 | 
33 |     Returns:
 34 |         (torch.FloatTensor): texture samples of shape [N, 3]; (torch.FloatTensor): face normals of shape [N, 3]; (torch.FloatTensor): signed distances of shape [N, 1]
35 |     """
36 | 
37 |     TV = TV.cuda()
38 |     TF = TF.cuda()
39 |     points = points.to(V.device)
40 |     
41 |     with torch.no_grad():
42 |         dist, hit_pts, hit_tidx = closest_point_fast(V, F, points)
43 | 
44 |     hit_F = F[hit_tidx]
45 |     hit_V = V[hit_F].cuda()
46 |     nrm = per_face_normals(V, hit_F).cuda()
47 | 
48 |     BC = barycentric_coordinates(hit_pts.cuda(), hit_V[:,0], hit_V[:,1], hit_V[:,2])
49 | 
50 |     hit_TF = TF[hit_tidx]
51 |     hit_TM = hit_TF[...,3]
52 |     hit_TF = hit_TF[...,:3]
53 | 
54 |     if TV.shape[0] > 0:
55 |         hit_TV = TV[hit_TF]
56 |         hit_Tp = (hit_TV * BC.unsqueeze(-1)).sum(1)
57 |     else:
58 |         hit_Tp = BC
59 |     rgb = sample_tex(hit_Tp, hit_TM, materials)
60 |     
61 |     return rgb, nrm, dist
62 | 
--------------------------------------------------------------------------------
/lib/ops/mesh/compute_sdf.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import math
10 | import contextlib
11 | import os
12 | import sys
13 | 
14 | import torch
15 | import numpy as np
16 | import wisp._C as _C
17 | 
18 | def compute_sdf(
19 |     V : torch.Tensor,
20 |     F : torch.Tensor,
21 |     points : torch.Tensor,
22 |     split_size : int = 10**6):
23 |     """Computes SDF given point samples and a mesh.
24 | 
25 |     Args:
26 |         V (torch.FloatTensor): #V, 3 array of vertices
27 |         F (torch.LongTensor): #F, 3 array of indices
28 |         points (torch.FloatTensor): [N, 3] array of points to sample
29 |         split_size (int): The batch at which the SDF will be computed. The kernel will break for too large
30 |                           batches; when in doubt use the default.
31 | 
32 |     Returns:
33 |         (torch.FloatTensor): [N, 1] array of computed SDF values.
34 |     """
35 |     mesh = V[F]
36 | 
37 |     _points = torch.split(points, split_size)
38 |     sdfs = []
39 |     for _p in _points:
40 |         sdfs.append(_C.external.mesh_to_sdf_cuda(_p.cuda().contiguous(), mesh.cuda().contiguous())[0])
41 |     return torch.cat(sdfs)[...,None]
42 | 
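
# Minimal standalone usage sketch (requires the wisp CUDA extension imported above
# and a CUDA device): SDF of uniformly sampled points w.r.t. a small tetrahedron.
if __name__ == "__main__":
    V = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    F = torch.tensor([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])
    points = torch.rand(1024, 3) * 2.0 - 1.0
    sdf = compute_sdf(V, F, points)
    print(sdf.shape)   # torch.Size([1024, 1])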
--------------------------------------------------------------------------------
/lib/ops/mesh/load_obj.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
  2 | #
  3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
  4 | # and proprietary rights in and to this software, related documentation
  5 | # and any modifications thereto.  Any use, reproduction, disclosure or
  6 | # distribution of this software and related documentation without an express
  7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
  8 | 
  9 | import os
 10 | import sys
 11 | 
 12 | import numpy as np
 13 | import tinyobjloader
 14 | import torch
 15 | 
 16 | from PIL import Image
 17 | 
 18 | import logging as log
 19 | import time
 20 | Image.MAX_IMAGE_PIXELS = None
 21 | 
 22 | # Refer to 
 23 | # https://github.com/tinyobjloader/tinyobjloader/blob/master/tiny_obj_loader.h
 24 | # for conventions for tinyobjloader data structures.
 25 | 
 26 | texopts = [
 27 |     'ambient_texname',
 28 |     'diffuse_texname',
 29 |     'specular_texname',
 30 |     'specular_highlight_texname',
 31 |     'bump_texname',
 32 |     'displacement_texname',
 33 |     'alpha_texname',
 34 |     'reflection_texname',
 35 |     'roughness_texname',
 36 |     'metallic_texname',
 37 |     'sheen_texname',
 38 |     'emissive_texname',
 39 |     'normal_texname'
 40 | ]
 41 | 
 42 | def load_mat(fname : str):
 43 |     """Loads material.
 44 |     """
 45 |     img = torch.ByteTensor(np.array(Image.open(fname)))
 46 |     #img = torch.ByteTensor(np.array(Image.open(fname).resize((2048,2048), Image.ANTIALIAS)))
 47 |     #img = img / 255.0
 48 | 
 49 |     return img
 50 | 
 51 | 
 52 | def load_obj(
 53 |     fname : str, 
 54 |     load_materials : bool = False):
 55 |     """Load a .obj file using TinyOBJ and extract its info.
 56 |     This loader is robust to polygon meshes, since tinyobjloader
 57 |     triangulates faces with up to 255 sides.
 58 |     
 59 |     Args:
 60 |         fname (str): path to Wavefront .obj file
 61 |     """
 62 | 
 63 |     assert os.path.exists(fname), \
 64 |         'Invalid file path and/or format, must be an existing Wavefront .obj'
 65 |     
 66 |     reader = tinyobjloader.ObjReader()
 67 |     config = tinyobjloader.ObjReaderConfig()
 68 |     config.triangulate = True # Ensure we don't have any polygons
 69 | 
 70 |     reader.ParseFromFile(fname, config)
 71 | 
 72 |     # Get vertices
 73 |     attrib = reader.GetAttrib()
 74 |     vertices = torch.FloatTensor(attrib.vertices).reshape(-1, 3)
 75 | 
 76 |     # Get triangle face indices
 77 |     shapes = reader.GetShapes()
 78 |     faces = []
 79 |     for shape in shapes:
 80 |         faces += [idx.vertex_index for idx in shape.mesh.indices]
 81 |     faces = torch.LongTensor(faces).reshape(-1, 3)
 82 |     
 83 |     mats = {}
 84 | 
 85 |     if load_materials:
 86 |         # Load per-face texture coordinate indices
 87 |         texf = []
 88 |         matf = []
 89 |         for shape in shapes:
 90 |             texf += [idx.texcoord_index for idx in shape.mesh.indices]
 91 |             matf.extend(shape.mesh.material_ids)
 92 |         # texf will store [tex_idx0, tex_idx1, tex_idx2, mat_idx] once matf is concatenated below
 93 |         texf = torch.LongTensor(texf).reshape(-1, 3)
 94 |         matf = torch.LongTensor(matf).reshape(-1, 1)
 95 |         texf = torch.cat([texf, matf], dim=-1)
 96 | 
 97 |         # Load texcoords
 98 |         texv = torch.FloatTensor(attrib.texcoords).reshape(-1, 2)
 99 |         
100 |         # Load texture maps
101 |         parent_path = os.path.dirname(fname) 
102 |         materials = reader.GetMaterials()
103 |         for i, material in enumerate(materials):
104 |             mats[i] = {}
105 |             diffuse = getattr(material, 'diffuse')
106 |             if diffuse != '':
107 |                 mats[i]['diffuse'] = torch.FloatTensor(diffuse)
108 | 
109 |             for texopt in texopts:
110 |                 mat_path = getattr(material, texopt)
111 |                 if mat_path != '':
112 |                     img = load_mat(os.path.join(parent_path, mat_path))
113 |                     mats[i][texopt] = img
114 |                     #mats[i][texopt.split('_')[0]] = img
115 |         return vertices, faces, texv, texf, mats
116 | 
117 |     return vertices, faces
118 | 
119 | 
--------------------------------------------------------------------------------
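A usage sketch for load_obj (illustrative; the path is hypothetical). With load_materials=True it additionally returns texture coordinates, per-face texture/material indices, and the material dictionary:

    from lib.ops.mesh.load_obj import load_obj

    V, F = load_obj('scan.obj')                                   # geometry only
    V, F, texv, texf, mats = load_obj('scan.obj', load_materials=True)
    # texf packs [tex_idx0, tex_idx1, tex_idx2, material_id] per face;
    # mats[i]['diffuse_texname'] holds the texture image as a uint8 tensor.
    print(V.shape, F.shape, texv.shape, texf.shape, len(mats))
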
/lib/ops/mesh/normalize.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | 
11 | def normalize(
12 |     V : torch.Tensor,
13 |     F : torch.Tensor,
14 |     mode : str):
15 |     """Normalizes a mesh.
16 | 
17 |     Args:
18 |         V (torch.FloatTensor): Vertices of shape [V, 3]
19 |         F (torch.LongTensor): Faces of shape [F, 3]
20 |         mode (str): Normalization mode, one of 'sphere', 'aabb', 'planar', or 'none'.
21 | 
22 |     Returns:
23 |         (torch.FloatTensor, torch.LongTensor):
24 |         - Normalized Vertices
25 |         - Faces
26 |     """
27 | 
28 |     if mode == 'sphere':
29 | 
30 |         V_max, _ = torch.max(V, dim=0)
31 |         V_min, _ = torch.min(V, dim=0)
32 |         V_center = (V_max + V_min) / 2.
33 |         V = V - V_center
34 | 
35 |         # Find the max distance to origin
36 |         max_dist = torch.sqrt(torch.max(torch.sum(V**2, dim=-1)))
37 |         V_scale = 1. / max_dist
38 |         V *= V_scale
39 |         return V, F
40 | 
41 |     elif mode == 'aabb':
42 |         
43 |         V_min, _ = torch.min(V, dim=0)
44 |         V = V - V_min
45 | 
46 |         max_dist = torch.max(V)
47 |         V *= 1.0 / max_dist
48 | 
49 |         V = V * 2.0 - 1.0
50 | 
51 |         return V, F
52 | 
53 |     elif mode == 'planar':
54 |         
55 |         V_min, _ = torch.min(V, dim=0)
56 |         V = V - V_min
57 | 
58 |         x_max = torch.max(V[...,0])
59 |         z_max = torch.max(V[...,2])
60 | 
61 |         V[...,0] *= 1.0 / x_max
62 |         V[...,2] *= 1.0 / z_max
63 | 
64 |         max_dist = torch.max(V)
65 |         V[...,1] *= 1.0 / max_dist
66 |         #V *= 1.0 / max_dist
67 | 
68 |         V = V * 2.0 - 1.0
69 | 
70 |         y_min = torch.min(V[...,1])
71 | 
72 |         V[...,1] -= y_min
73 | 
74 |         return V, F
75 | 
76 |     elif mode == 'none':
77 | 
78 |         return V, F
79 | 
80 | 
81 | 
82 | 
83 | 
--------------------------------------------------------------------------------
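A minimal sketch of the normalization modes, assuming V and F were loaded as above: 'sphere' recenters the bounding box and rescales so every vertex lies inside the unit sphere, while 'aabb' scales by the longest bounding-box extent so vertices land in [-1, 1]:

    from lib.ops.mesh.normalize import normalize

    V_sph, F = normalize(V, F, 'sphere')
    assert V_sph.norm(dim=-1).max() <= 1.0 + 1e-6   # inside the unit sphere
    V_box, F = normalize(V, F, 'aabb')              # longest extent spans [-1, 1]
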
/lib/ops/mesh/per_face_normals.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | 
11 | def per_face_normals(
12 |     V : torch.Tensor,
13 |     F : torch.Tensor):
14 |     """Compute normals per face.
15 |     
16 |     Args:
17 |         V (torch.FloatTensor): Vertices of shape [V, 3]
18 |         F (torch.LongTensor): Faces of shape [F, 3]
19 |     
20 |     Returns:
21 |         (torch.FloatTensor): Normals of shape [F, 3]
22 |     """
23 |     mesh = V[F]
24 | 
25 |     vec_a = mesh[:, 0] - mesh[:, 1]
26 |     vec_b = mesh[:, 1] - mesh[:, 2]
27 |     normals = torch.cross(vec_a, vec_b, dim=1)
28 |     return torch.nn.functional.normalize(
29 |         normals, eps=1e-6, dim=1
30 |     )
31 | 
32 | 
--------------------------------------------------------------------------------
/lib/ops/mesh/per_vertex_normals.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | 
11 | def per_vertex_normals(
12 |     V : torch.Tensor,
13 |     F : torch.Tensor):
14 |     """Compute normals per vertex.
15 |     
16 |     Args:
17 |         V (torch.FloatTensor): Vertices of shape [V, 3]
18 |         F (torch.LongTensor): Faces of shape [F, 3]
19 |     
20 |     Returns:
21 |         (torch.FloatTensor): Normals of shape [V, 3]
22 |     """
23 |     verts_normals = torch.zeros_like(V)
24 |     mesh = V[F]
25 | 
26 |     faces_normals = torch.cross(
27 |         mesh[:, 2] - mesh[:, 1],
28 |         mesh[:, 0] - mesh[:, 1],
29 |         dim=1,
30 |     )
31 | 
32 |     verts_normals.index_add_(0, F[:, 0], faces_normals)
33 |     verts_normals.index_add_(0, F[:, 1], faces_normals)
34 |     verts_normals.index_add_(0, F[:, 2], faces_normals)
35 |     
36 |     return torch.nn.functional.normalize(
37 |         verts_normals, eps=1e-6, dim=1
38 |     )
--------------------------------------------------------------------------------
/lib/ops/mesh/point_sample.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | from .sample_near_surface import sample_near_surface
11 | from .sample_surface import sample_surface
12 | from .sample_uniform import sample_uniform
13 | from .area_weighted_distribution import area_weighted_distribution
14 | 
15 | def point_sample(
16 |     V : torch.Tensor, 
17 |     F : torch.Tensor, 
18 |     techniques : list, 
19 |     num_samples : int,
20 |     variance: float = 0.005):
21 |     """Sample points from a mesh.
22 | 
23 |     Args:
24 |         V (torch.Tensor): #V, 3 array of vertices
25 |         F (torch.Tensor): #F, 3 array of indices
26 |         techniques (list[str]): list of techniques to sample with
27 |         num_samples (int): points to sample per technique
28 |     
29 |     Returns:
30 |         (torch.FloatTensor): Samples of shape [len(techniques)*num_samples, 3]
31 |     """
32 |     if 'trace' in techniques or 'near' in techniques:
33 |         # Precompute face distribution
34 |         distrib = area_weighted_distribution(V, F)
35 | 
36 |     samples = []
37 |     for technique in techniques:
38 |         if technique == 'trace':
39 |             samples.append(sample_surface(V, F, num_samples, distrib=distrib)[0])
40 |         elif technique == 'near':
41 |             samples.append(sample_near_surface(V, F, num_samples, distrib=distrib, variance=variance))
42 |         elif technique == 'rand':
43 |             samples.append(sample_uniform(num_samples).to(V.device))
44 |     samples = torch.cat(samples, dim=0)
45 |     return samples
46 | 
47 | 
--------------------------------------------------------------------------------
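A sketch of combined sampling, assuming V and F from a loaded mesh: each technique contributes num_samples points, so the output stacks to [len(techniques) * num_samples, 3] (surface points, near-surface points jittered with std variance, and uniform points in the [-1, 1] cube):

    from lib.ops.mesh.point_sample import point_sample

    pts = point_sample(V, F, ['trace', 'near', 'rand'], num_samples=2048, variance=0.005)
    assert pts.shape == (3 * 2048, 3)
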
/lib/ops/mesh/random_face.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | from .area_weighted_distribution import area_weighted_distribution
11 | from .per_face_normals import per_face_normals
12 | 
13 | def random_face(
14 |     V : torch.Tensor, 
15 |     F : torch.Tensor, 
16 |     num_samples : int, 
17 |     distrib=None):
18 |     """Return an area weighted random sample of faces and their normals from the mesh.
19 | 
20 |     Args:
21 |         V (torch.Tensor): #V, 3 array of vertices
22 |         F (torch.Tensor): #F, 3 array of indices
23 |         num_samples (int): number of samples to return
24 |         distrib: distribution to use. By default, area-weighted distribution is used.
25 |     
26 |     Returns:
27 |         (torch.LongTensor, torch.FloatTensor):
28 |         - Faces
29 |         - Normals
30 |     """
31 |     if distrib is None:
32 |         distrib = area_weighted_distribution(V, F)
33 | 
34 |     normals = per_face_normals(V, F)
35 | 
36 |     idx = distrib.sample([num_samples])
37 | 
38 |     return F[idx], normals[idx]
39 | 
40 | 
--------------------------------------------------------------------------------
/lib/ops/mesh/sample_near_surface.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | from .sample_surface import sample_surface
11 | from .area_weighted_distribution import area_weighted_distribution
12 | 
13 | def sample_near_surface(
14 |     V : torch.Tensor,
15 |     F : torch.Tensor, 
16 |     num_samples: int, 
17 |     variance : float = 0.005,
18 |     distrib=None):
19 |     """Sample points near the mesh surface.
20 | 
21 |     Args:
22 |         V (torch.Tensor): #V, 3 array of vertices
23 |         F (torch.Tensor): #F, 3 array of indices
24 |         num_samples (int): number of surface samples
25 |         distrib: distribution to use. By default, area-weighted distribution is used
26 |     
27 |     Returns:
28 |         (torch.FloatTensor): samples of shape [num_samples, 3]
29 |     """
30 |     if distrib is None:
31 |         distrib = area_weighted_distribution(V, F)
32 |     samples = sample_surface(V, F, num_samples, distrib)[0]
33 |     samples += torch.randn_like(samples) * variance
34 |     return samples
35 | 
--------------------------------------------------------------------------------
/lib/ops/mesh/sample_surface.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | from .random_face import random_face
11 | from .area_weighted_distribution import area_weighted_distribution
12 | 
13 | def sample_surface(
14 |     V : torch.Tensor,
15 |     F : torch.Tensor,
16 |     num_samples : int,
17 |     distrib = None):
18 |     """Sample points and their normals on mesh surface.
19 | 
20 |     Args:
21 |         V (torch.Tensor): #V, 3 array of vertices
22 |         F (torch.Tensor): #F, 3 array of indices
23 |         num_samples (int): number of surface samples
24 |         distrib: distribution to use. By default, area-weighted distribution is used
25 |     
26 |     Returns:
27 |         (torch.FloatTensor, torch.FloatTensor): samples of shape [num_samples, 3] and their face normals of shape [num_samples, 3]
28 |     """
29 |     if distrib is None:
30 |         distrib = area_weighted_distribution(V, F)
31 | 
32 |     # Select faces & sample their surface
33 |     fidx, normals = random_face(V, F, num_samples, distrib)
34 |     f = V[fidx]
35 | 
36 |     u = torch.sqrt(torch.rand(num_samples)).to(V.device).unsqueeze(-1)
37 |     v = torch.rand(num_samples).to(V.device).unsqueeze(-1)
38 | 
39 |     samples = (1 - u) * f[:,0,:] + (u * (1 - v)) * f[:,1,:] + u * v * f[:,2,:]
40 |     
41 |     return samples, normals
42 | 
43 | 
--------------------------------------------------------------------------------
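The u = sqrt(r1) reparameterization above is the standard trick for sampling uniformly over a triangle's area; a quick, purely illustrative check that the resulting weights always form a convex combination of the three vertices:

    import torch

    u = torch.sqrt(torch.rand(5)).unsqueeze(-1)
    v = torch.rand(5).unsqueeze(-1)
    w = torch.cat([1 - u, u * (1 - v), u * v], dim=-1)   # per-vertex weights
    assert torch.allclose(w.sum(-1), torch.ones(5))      # weights sum to one
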
/lib/ops/mesh/sample_tex.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | import torch.nn.functional as F
11 | import time
12 | def sample_tex(
13 |     Tp : torch.Tensor, # points [N ,2] 
14 |     TM : torch.Tensor, # material indices [N]
15 |     materials):
16 |     """Sample from a texture.
17 | 
18 |     Args:
19 |         Tp (torch.FloatTensor): 2D coordinates to sample of shape [N, 2]
20 |         TM (torch.LongTensor): Indices of the material to sample of shape [N]
21 |         materials (list of material): Materials
22 | 
23 |     Returns:
24 |         (torch.FloatTensor): RGB samples of shape [N, 3]
25 |     """
26 | 
27 |     max_idx = TM.max()
28 |     assert max_idx > -1, "No materials detected! Check the material definition on your mesh."
29 | 
30 |     rgb = torch.zeros(Tp.shape[0], 3, device=Tp.device)  # output colors, filled per material below
31 | 
32 |     Tp = (Tp * 2.0) - 1.0
33 |     # The y axis is flipped between the usual UV-map convention and PyTorch's grid_sample convention
34 |     Tp[...,1] *= -1
35 | 
36 |     for i in range(max_idx+1):
37 |         mask = (TM == i)
38 |         if mask.sum() == 0:
39 |             continue
40 |         if 'diffuse_texname' not in materials[i]:
41 |             if 'diffuse' in materials[i]:
42 |                 rgb[mask] = materials[i]['diffuse'].to(Tp.device)
43 |             continue
44 | 
45 |         map = materials[i]['diffuse_texname'][...,:3].permute(2, 0, 1)[None].float().to(Tp.device) / 255.0
46 |         grid = Tp[mask]
47 |         grid = grid.reshape(1, grid.shape[0], 1, grid.shape[1])
48 |         _rgb = F.grid_sample(map, grid, mode='bilinear', padding_mode='reflection', align_corners=True)
49 |         _rgb = _rgb[0,:,:,0].permute(1,0)
50 |         rgb[mask] = _rgb
51 | 
52 | 
53 |     return rgb
54 | 
55 | 
56 | 
--------------------------------------------------------------------------------
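A minimal sketch of querying colors with sample_tex, assuming mats came from load_obj(..., load_materials=True) and that material 0 has a diffuse texture; Tp are UV coordinates in [0, 1] and TM the per-point material ids:

    import torch
    from lib.ops.mesh.sample_tex import sample_tex

    Tp = torch.rand(100, 2)                   # hypothetical UV queries
    TM = torch.zeros(100, dtype=torch.long)   # all points use material 0
    rgb = sample_tex(Tp, TM, mats)            # [100, 3] colors in [0, 1]
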
/lib/ops/mesh/sample_uniform.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
 2 | #
 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
 4 | # and proprietary rights in and to this software, related documentation
 5 | # and any modifications thereto.  Any use, reproduction, disclosure or
 6 | # distribution of this software and related documentation without an express
 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
 8 | 
 9 | import torch
10 | 
11 | def sample_uniform(num_samples : int):
12 |     """Sample uniformly in [-1,1] bounding volume.
13 |     
14 |     Args:
15 |         num_samples(int) : number of points to sample
16 |     
17 |     Returns:
18 |         (torch.FloatTensor): samples of shape [num_samples, 3]
19 |     """
20 |     return torch.rand(num_samples, 3) * 2.0 - 1.0
21 | 
22 | 
--------------------------------------------------------------------------------
/lib/utils/camera.py:
--------------------------------------------------------------------------------
 1 | # Adapted from: https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/ops/raygen/raygen.py
 2 | 
 3 | from kaolin.render.camera import Camera
 4 | from kaolin.render.camera.intrinsics import CameraFOV
 5 | import torch
 6 | 
 7 | ################################ Ray Sampling Function ########################################
 8 | 
 9 | 
10 | def _generate_default_grid(width, height, device=None):
11 |     h_coords = torch.arange(height, device=device)
12 |     w_coords = torch.arange(width, device=device)
13 |     return torch.meshgrid(h_coords, w_coords, indexing='ij')  # return pixel_y, pixel_x
14 | 
15 | 
16 | def generate_centered_pixel_coords(img_width, img_height, res_x=None, res_y=None, device=None):
17 |     pixel_y, pixel_x = _generate_default_grid(res_x, res_y, device)
18 |     scale_x = 1.0 if res_x is None else float(img_width) / res_x
19 |     scale_y = 1.0 if res_y is None else float(img_height) / res_y
20 |     pixel_x = pixel_x * scale_x + 0.5   # scale and add bias to pixel center
21 |     pixel_y = pixel_y * scale_y + 0.5   # scale and add bias to pixel center
22 |     return pixel_y, pixel_x
23 | 
24 | 
25 | # -- Ray gen --
26 | 
27 | def _to_ndc_coords(pixel_x, pixel_y, camera):
28 |     pixel_x = 2 * (pixel_x / camera.width) - 1.0
29 |     pixel_y = 2 * (pixel_y / camera.height) - 1.0
30 |     return pixel_x, pixel_y
31 | 
32 | 
33 | def generate_pinhole_rays(camera: Camera, coords_grid: torch.Tensor):
34 |     """Default ray generation function for pinhole cameras.
35 | 
36 |     This function assumes that the principal point (the pinhole location) is specified by a 
37 |     displacement (camera.x0, camera.y0) in pixel coordinates from the center of the image. 
38 | 
39 |     The Kaolin camera class does not enforce a coordinate space for how the principal point is specified,
40 |     so users will need to make sure that the correct principal point conventions are followed for 
41 |     the cameras passed into this function.
42 | 
43 |     Args:
44 |         camera (kaolin.render.camera): The camera class. 
45 |         coords_grid (torch.FloatTensor): Grid of coordinates of shape [H, W, 2].
46 | 
47 |     Returns:
48 |         (wisp.core.Rays): The generated pinhole rays for the camera.
49 |     """
50 |     if camera.device != coords_grid[0].device:
51 |         raise Exception(f"Expected camera and coords_grid[0] to be on the same device, but found {camera.device} and {coords_grid[0].device}.")
52 |     if camera.device != coords_grid[1].device:
53 |         raise Exception(f"Expected camera and coords_grid[1] to be on the same device, but found {camera.device} and {coords_grid[1].device}.")
54 |     # coords_grid should remain immutable (a new tensor is implicitly created here)
55 |     pixel_y, pixel_x = coords_grid
56 |     pixel_x = pixel_x.to(camera.device, camera.dtype)
57 |     pixel_y = pixel_y.to(camera.device, camera.dtype)
58 | 
59 |     # Account for principal point (offsets from the center)
60 |     pixel_x = pixel_x - camera.x0
61 |     pixel_y = pixel_y + camera.y0
62 | 
63 |     # pixel values are now in range [-1, 1], both tensors are of shape res_y x res_x
64 |     pixel_x, pixel_y = _to_ndc_coords(pixel_x, pixel_y, camera)
65 | 
66 |     ray_dir = torch.stack((pixel_x * camera.tan_half_fov(CameraFOV.HORIZONTAL),
67 |                            -pixel_y * camera.tan_half_fov(CameraFOV.VERTICAL),
68 |                            -torch.ones_like(pixel_x)), dim=-1)
69 | 
70 |     ray_dir = ray_dir.reshape(-1, 3)    # Flatten grid rays to 1D array
71 |     ray_orig = torch.zeros_like(ray_dir)
72 | 
73 |     # Transform from camera to world coordinates
74 |     ray_orig, ray_dir = camera.extrinsics.inv_transform_rays(ray_orig, ray_dir)
75 |     ray_dir /= torch.linalg.norm(ray_dir, dim=-1, keepdim=True)
76 |     ray_orig, ray_dir = ray_orig[0], ray_dir[0]  # Assume a single camera
77 | 
78 |     return ray_orig, ray_dir
79 | 
80 | #########################################################################################################################
--------------------------------------------------------------------------------
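A minimal ray-generation sketch, assuming kaolin is installed; the camera parameters are illustrative (Camera.from_args builds the extrinsics from eye/at/up and the pinhole intrinsics from fov/width/height):

    import math
    import torch
    from kaolin.render.camera import Camera
    from lib.utils.camera import generate_centered_pixel_coords, generate_pinhole_rays

    camera = Camera.from_args(
        eye=torch.tensor([0.0, 0.0, 2.5]),
        at=torch.tensor([0.0, 0.0, 0.0]),
        up=torch.tensor([0.0, 1.0, 0.0]),
        fov=math.radians(20.0),          # vertical field of view, in radians
        width=512, height=512,
    )
    coords = generate_centered_pixel_coords(512, 512, 512, 512, device='cpu')
    ray_orig, ray_dir = generate_pinhole_rays(camera, coords)   # each of shape [512*512, 3]
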
/lib/utils/config.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import pprint
  3 | import yaml
  4 | import logging
  5 | 
  6 | def parse_options():
  7 | 
  8 |     parser = argparse.ArgumentParser(description='Custom Humans Code')
  9 | 
 10 | 
 11 |     ###################
 12 |     # Global arguments
 13 |     ###################
 14 |     global_group = parser.add_argument_group('global')
 15 |     global_group.add_argument('--config', type=str, default='config.yaml', 
 16 |                                help='Path to config file to replace defaults')
 17 |     global_group.add_argument('--save-root', type=str, default='./checkpoints/', 
 18 |                                help="outputs path")
 19 |     global_group.add_argument('--exp-name', type=str, default='test',
 20 |                                help="Experiment name.")
 21 |     global_group.add_argument('--seed', type=int, default=123)
 22 |     global_group.add_argument('--resume', type=str, default=None,
 23 |                                 help='Resume from the checkpoint.')
 24 |     global_group.add_argument(
 25 |         '--log_level', action='store', type=int, default=logging.INFO,
 26 |         help='Logging level to use globally, DEBUG: 10, INFO: 20, WARN: 30, ERROR: 40.')
 27 |         
 28 |     ###################
 29 |     # Arguments for dataset
 30 |     ###################
 31 |     data_group = parser.add_argument_group('dataset')
 32 |     data_group.add_argument('--data-root', type=str, default='CustomHumans.h5',
 33 |                             help='Path to dataset')
 34 |     data_group.add_argument('--num-samples', type=int, default=20480,
 35 |                             help='Number of samples to use for each subject during training')
 36 |     data_group.add_argument('--repeat-times', type=int, default=8,
 37 |                             help='Number of times to repeat each subject during training')
 38 | 
 39 | 
 40 |     ###################
 41 |     # Arguments for optimizer
 42 |     ###################
 43 |     optim_group = parser.add_argument_group('optimizer')
 44 |     optim_group.add_argument('--lr-codebook', type=float, default=0.001, 
 45 |                              help='Learning rate for the codebook.')
 46 |     optim_group.add_argument('--lr-decoder', type=float, default=0.001, 
 47 |                              help='Learning rate for the decoder.')
 48 |     optim_group.add_argument('--lr-dis', type=float, default=0.004,
 49 |                                 help='Learning rate for the discriminator.')
 50 |     optim_group.add_argument('--beta1', type=float, default=0.5,
 51 |                                 help='Beta1.')
 52 |     optim_group.add_argument('--beta2', type=float, default=0.999,
 53 |                                 help='Beta2.')
 54 |     optim_group.add_argument('--weight-decay', type=float, default=0, 
 55 |                              help='Weight decay.')
 56 | 
 57 | 
 58 |     ###################
 59 |     # Arguments for training
 60 |     ###################
 61 |     train_group = parser.add_argument_group('train')
 62 |     train_group.add_argument('--epochs', type=int, default=800, 
 63 |                              help='Number of epochs to run the training.')
 64 |     train_group.add_argument('--batch-size', type=int, default=2, 
 65 |                              help='Batch size for the training.')
 66 |     train_group.add_argument('--workers', type=int, default=0,
 67 |                              help='Number of workers for the data loader. 0 means single process.')
 68 |     train_group.add_argument('--save-every', type=int, default=50, 
 69 |                              help='Save the model at every N epoch.')
 70 |     train_group.add_argument('--log-every', type=int, default=100,
 71 |                              help='write logs to wandb at every N iters')
 72 |     train_group.add_argument('--use-2d-from-epoch', type=int, default=-1,
 73 |                              help='Adding 2D loss from this epoch. -1 indicates not using 2D loss.')
 74 |     train_group.add_argument('--train-2d-every-iter', type=int, default=1,
 75 |                              help='Train 2D loss every N iterations.')
 76 |     train_group.add_argument('--use-nrm-dis', action='store_true',
 77 |                              help='train with normal loss discriminator.')
 78 |     train_group.add_argument('--use-cached-pts', action='store_true',
 79 |                              help='Use cached point coordinates instead of online raytracing during training.')
 80 | 
 81 |     ###################
 82 |     # Arguments for Feature Dictionary
 83 |     ###################
 84 |     sample_group = parser.add_argument_group('dictionary')
 85 |     sample_group.add_argument('--shape-dim', type=int, default=32,
 86 |                                 help='Dimension of the shape feature code.')
 87 |     sample_group.add_argument('--color-dim', type=int, default=32,
 88 |                                 help='Dimension of the color feature code.')
 89 |     sample_group.add_argument('--feature-std', type=float, default=0.1,
 90 |                                 help='Standard deviation for initializing the feature code.')
 91 |     sample_group.add_argument('--feature-bias', type=float, default=0.1,
 92 |                                 help='Bias for initializing the feature code.')
 93 |     sample_group.add_argument('--shape-pca-dim', type=int, default=8,
 94 |                                 help='Dimension of the shape pca code.')
 95 |     sample_group.add_argument('--color-pca-dim', type=int, default=16,
 96 |                                 help='Dimension of the color pca code.')
 97 |     
 98 |     ###################
 99 |     # Arguments for Network
100 |     ###################
101 |     net_group = parser.add_argument_group('network')
102 |     net_group.add_argument('--pos-dim', type=int, default=3,
103 |                           help='input position dimension')
104 |     net_group.add_argument('--c-dim', type=int, default=0,
105 |                           help='conditional input dimension, if 0, no conditional input')
106 |     net_group.add_argument('--num-layers', type=int, default=4, 
107 |                              help='Number of layers for the MLPs.')
108 |     net_group.add_argument('--hidden-dim', type=int, default=128,
109 |                           help='Network width')
110 |     net_group.add_argument('--activation', type=str, default='relu',
111 |                             choices=['relu', 'sin', 'softplus', 'lrelu'])
112 |     net_group.add_argument('--layer-type', type=str, default='none',
113 |                             choices=['none', 'spectral_norm', 'frobenius_norm', 'l_1_norm', 'l_inf_norm'])
114 |     net_group.add_argument('--skip', type=int, nargs='*', default=[2],
115 |                           help='Layer to have skip connection.')
116 | 
117 |     ###################
118 |     # Embedder arguments
119 |     ###################
120 |     embedder_group = parser.add_argument_group('embedder')
121 |     embedder_group.add_argument('--shape-freq', type=int, default=5,
122 |                                 help='log2 of max freq')
123 |     embedder_group.add_argument('--color-freq', type=int, default=10,
124 |                                 help='log2 of max freq')
125 | 
126 | 
127 |     ###################
128 |     # Losses arguments
129 |     ###################
130 |     embedder_group = parser.add_argument_group('losses')
131 |     embedder_group.add_argument('--lambda-sdf', type=float, default=1000,
132 |                                 help='lambda for sdf loss')
133 |     embedder_group.add_argument('--lambda-rgb', type=float, default=150,
134 |                                 help='lambda for rgb loss')
135 |     embedder_group.add_argument('--lambda-nrm', type=float, default=10,
136 |                                 help='lambda for normal loss')
137 |     embedder_group.add_argument('--lambda-reg', type=float, default=1,
138 |                                 help='lambda for regularization loss')
139 |     embedder_group.add_argument('--gan-loss-type', type=str, default='logistic',
140 |                                 choices=['logistic', 'hinge'],
141 |                                 help='loss type for gan loss')
142 |     embedder_group.add_argument('--lambda-gan', type=float, default=1,  
143 |                                 help='lambda for gan loss')
144 |     embedder_group.add_argument('--lambda-grad', type=float, default=10,
145 |                                 help='lambda for gradient penalty')
146 | 
147 |     ###################
148 |     # Arguments for validation
149 |     ###################
150 |     valid_group = parser.add_argument_group('validation')
151 |     valid_group.add_argument('--valid-every', type=int, default=10,
152 |                              help='Frequency of running validation.')
153 |     valid_group.add_argument('--subdivide', type=bool, default=True, 
154 |                             help='Subdivide the mesh before marching cubes')
155 |     valid_group.add_argument('--grid-size', type=int, default=300, 
156 |                             help='Grid size for marching cubes')
157 |     valid_group.add_argument('--width', type=int, default=1024, 
158 |                             help='Image width (height) for rendering')
159 |     valid_group.add_argument('--fov', type=float, default=20.0, 
160 |                             help='Field of view for rendering')
161 |     valid_group.add_argument('--n_views', type=int, default=4, 
162 |                             help='Number of views for rendering')
163 | 
164 |     ###################
165 |     # Arguments for wandb
166 |     ###################
167 |     wandb_group = parser.add_argument_group('wandb')
168 |     
169 |     wandb_group.add_argument('--wandb-id', type=str, default=None,
170 |                              help='wandb id')
171 |     wandb_group.add_argument('--wandb', action='store_true',
172 |                              help='Use wandb')
173 |     wandb_group.add_argument('--wandb-name', default='default', type=str,
174 |                              help='wandb_name')
175 | 
176 |     return parser
177 | 
178 | 
179 | def parse_yaml_config(config_path, parser):
180 |     """Parses and sets the parser defaults with a yaml config file.
181 | 
182 |     Args:
183 |         config_path : path to the yaml config file.
184 |         parser : The parser for which the defaults will be set.
185 |                  Keys in the yaml file are grouped by the parser's argument groups.
186 |     """
187 |     with open(config_path) as f:
188 |         config_dict = yaml.safe_load(f)
189 | 
190 |     list_of_valid_fields = []
191 |     for group in parser._action_groups:
192 |         list_of_valid_fields.extend(a.dest for a in group._group_actions)
193 |     list_of_valid_fields = set(list_of_valid_fields)
194 |     
195 |     defaults_dict = {}
196 | 
197 |     # Load values from the config file and overwrite the parser defaults.
198 |     # The yaml file mirrors the argument groups, which aren't actually nested.
199 |     for key in config_dict:
200 |         for field in config_dict[key]:
201 |             if field not in list_of_valid_fields:
202 |                 raise ValueError(
203 |                     f"ERROR: {field} is not a valid option. Check for typos in the config."
204 |                 )
205 |             defaults_dict[field] = config_dict[key][field]
206 | 
207 | 
208 |     parser.set_defaults(**defaults_dict)
209 | 
210 | def argparse_to_str(parser, args=None):
211 |     """Convert parser to string representation for Tensorboard logging.
212 | 
213 |     Args:
214 |         parser (argparse.parser): Parser object. Needed for the argument groups.
215 |         args : The parsed arguments. Will compute from the parser if None.
216 |     
217 |     Returns:
218 |         args    : The parsed arguments.
219 |         arg_str : The string to be printed.
220 |     """
221 |     
222 |     if args is None:
223 |         args = parser.parse_args()
224 | 
225 |     if args.config is not None:
226 |         parse_yaml_config(args.config, parser)
227 | 
228 |     args = parser.parse_args()
229 | 
230 |     args_dict = {}
231 |     for group in parser._action_groups:
232 |         group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
233 |         args_dict[group.title] = vars(argparse.Namespace(**group_dict))
234 | 
235 |     pp = pprint.PrettyPrinter(indent=2)
236 |     args_str = pp.pformat(args_dict)
237 |     args_str = f'```{args_str}```'
238 | 
239 |     return args, args_str
--------------------------------------------------------------------------------
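A sketch of the configuration flow: parse_options() defines the defaults, the yaml file named by --config overrides those defaults, and explicit CLI flags override both (argparse_to_str re-parses after applying the yaml):

    from lib.utils.config import parse_options, argparse_to_str

    parser = parse_options()
    args, args_str = argparse_to_str(parser)   # reads --config (config.yaml by default)
    print(args.exp_name, args.lr_decoder, args.lambda_sdf)
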
/lib/utils/image.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import pickle
 3 | import cv2
 4 | 
 5 | import torchvision.transforms as transforms
 6 | 
 7 | def update_edited_images(image_path, pickle_path):
 8 |     with open(pickle_path, 'rb') as f:
 9 |         data = pickle.load(f)
10 | 
11 |     img_list = [ os.path.join(image_path, f) for f in sorted(os.listdir(image_path)) if f.endswith('.png') ]
12 |     transform = transforms.Compose([
13 |                 transforms.ToTensor()
14 |                 ])
15 |     for i, img in enumerate(img_list):
16 |         rgb_img = cv2.imread(img)
17 |         rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB)
18 |         rgb = transform(rgb_img).permute(1,2,0).view(-1, 3)
19 |         data['rgb'][i] = rgb
20 | 
21 |     return data
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch==1.11.0
 2 | h5py
 3 | smplx
 4 | wandb
 5 | trimesh
 6 | opencv-python
 7 | git+https://github.com/tinyobjloader/tinyobjloader.git@v2.0.0rc8#subdirectory=python
 8 | 
 9 | --extra-index-url https://download.pytorch.org/whl/cu113
10 | torch==1.11.0+cu113 
11 | torchvision==0.12.0+cu113
12 | 
13 | -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.11.0_cu113.html
14 | kaolin==0.12.0
--------------------------------------------------------------------------------
/smplx/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/smplx/.gitkeep
--------------------------------------------------------------------------------
/tools/align_thuman.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | import os
 4 | import trimesh
 5 | from PIL import Image
 6 | import pickle
 7 | import cv2
 8 | from tqdm import tqdm
 9 | from smplx import SMPLX
10 | import json
11 | device = torch.device('cuda')
12 | 
13 | SMPLX_PATH = 'smplx'
14 | OUT_PATH = 'new_thuman'
15 | 
16 | body_model = SMPLX(model_path=SMPLX_PATH, num_pca_comps=12, gender='male')
17 | 
18 | for id in tqdm(range(526)):
19 |     name_id = "%04d" % id
20 |     input_file = os.path.join('THuman2.0', name_id, name_id + '.obj')
21 |     tex_file = os.path.join('THuman2.0', name_id, 'material0.jpeg')
22 |     smpl_file = os.path.join('THuman2.0_smplx', name_id, 'smplx_param.pkl')
23 | 
24 |     smpl_data = pickle.load(open(smpl_file,'rb'))
25 |     out_file_name = os.path.splitext(os.path.basename(input_file))[0]
26 |     output_aligned_path = os.path.join(OUT_PATH, out_file_name)
27 |     os.makedirs(output_aligned_path, exist_ok=True)
28 | 
29 |     
30 |     textured_mesh = trimesh.load(input_file)
31 | 
32 | 
33 |     output = body_model(body_pose = torch.tensor(smpl_data['body_pose']),
34 |                                 betas = torch.tensor(smpl_data['betas']),
35 |                                 left_hand_pose = torch.tensor(smpl_data['left_hand_pose']),
36 |                                 right_hand_pose = torch.tensor(smpl_data['right_hand_pose']),
37 |                                )
38 |     J_0 = output.joints.detach().cpu().numpy()[0,0,:]
39 | 
40 |     d = trimesh.Trimesh(vertices=output.vertices.detach().cpu().numpy()[0] -J_0 ,
41 |                                     faces=body_model.faces)
42 | 
43 | 
44 |     R = np.asarray(smpl_data['global_orient'][0])
45 |     rot_mat = np.zeros(shape=(3,3))
46 |     rot_mat, _ = cv2.Rodrigues(R)
47 |     scale = smpl_data['scale']
48 | 
49 |     T = -np.asarray(smpl_data['translation'])
50 |     S = np.eye(4)
51 |     S[:3, 3] = T
52 |     textured_mesh.apply_transform(S)
53 |     
54 |     S = np.eye(4)
55 |     S[:3, :3] *= 1./scale
56 |     textured_mesh.apply_transform(S)
57 | 
58 |     T = -J_0
59 |     S = np.eye(4)
60 |     S[:3, 3] = T
61 |     textured_mesh.apply_transform(S)
62 | 
63 |     S = np.eye(4)
64 |     S[:3, :3] = np.linalg.inv(rot_mat)
65 |     textured_mesh.apply_transform(S)
66 | 
67 | 
68 | 
69 |     visual = trimesh.visual.texture.TextureVisuals(uv=textured_mesh.visual.uv, image=Image.open(tex_file))
70 |     
71 |     t = trimesh.Trimesh(vertices=textured_mesh.vertices,
72 |                                  faces=textured_mesh.faces,
73 |                                  vertex_normals=textured_mesh.vertex_normals,
74 |                                  visual=visual)
75 | 
76 |     #t = t.simplify_quadratic_decimation(50000)
77 |     #t.visual.material.name = out_file_name
78 | 
79 | 
80 |     d.export(os.path.join(output_aligned_path, out_file_name + '_smplx.obj')  )
81 |     t.export(os.path.join(output_aligned_path, out_file_name + '.obj')  )
82 |     with open(os.path.join(output_aligned_path, out_file_name + '.mtl'), 'w') as f:
83 |             f.write('newmtl {}\n'.format(out_file_name))
84 |             f.write('map_Kd {}.jpeg\n'.format(out_file_name))
85 |     
86 |     result = {}
87 |     result['transl'] = [0., 0., 0.]
88 |     for key, val in smpl_data.items():
89 |         if key not in ['scale', 'translation']:
90 |             result[key] = val[0].tolist()
91 | 
92 |     json_file = os.path.join(output_aligned_path, out_file_name + '_smplx.json')
93 |     json.dump(result, open(json_file, 'w'), indent=4)
94 | 
--------------------------------------------------------------------------------
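The loop above undoes the SMPL-X fitting transform in four steps (translate by -translation, scale by 1/scale, re-center at the root joint, rotate by the inverse global orientation). Since trimesh.apply_transform multiplies homogeneous vertices by the given 4x4 matrix, the same alignment can be written as one composite matrix; a sketch with a hypothetical helper name, variables matching the loop:

    import numpy as np

    def compose_alignment(translation, scale, J_0, rot_mat):
        T1 = np.eye(4); T1[:3, 3] = -np.asarray(translation)   # undo fitting translation
        S  = np.eye(4); S[:3, :3] /= scale                     # undo fitting scale
        T2 = np.eye(4); T2[:3, 3] = -J_0                       # re-center at the root joint
        R  = np.eye(4); R[:3, :3] = np.linalg.inv(rot_mat)     # undo global orientation
        # apply_transform(A) followed by apply_transform(B) equals apply_transform(B @ A)
        return R @ T2 @ S @ T1
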
/tools/evaluate.py:
--------------------------------------------------------------------------------
  1 | import os 
  2 | import torch
  3 | import scipy as sp
  4 | import numpy as np
  5 | import argparse
  6 | import trimesh
  7 | 
  8 | 
  9 | def calculate_iou(gt, prediction):
 10 |     intersection = torch.logical_and(gt, prediction)
 11 |     union = torch.logical_or(gt, prediction)
 12 |     return torch.sum(intersection) / torch.sum(union)
 13 | 
 14 | def compute_surface_metrics(mesh_pred, mesh_gt):
 15 |     """Compute surface metrics (chamfer distance, normal consistency and f-score) for one example.
 16 |     Args:
 17 |     mesh_pred, mesh_gt: trimesh.Trimesh, the predicted and ground-truth meshes.
 18 |     Returns:
 19 |     c1, c2: float, directional chamfer distances (pred-to-gt and gt-to-pred), scaled by 1000.
 20 |     normal_consistency, fscore: float, averaged normal consistency and f-score.
 21 |     """
 22 |     # Chamfer
 23 |     eval_points = 1000000
 24 | 
 25 |     point_gt, idx_gt = mesh_gt.sample(eval_points, return_index=True)
 26 |     normal_gt = mesh_gt.face_normals[idx_gt]
 27 |     point_gt = point_gt.astype(np.float32)
 28 | 
 29 |     point_pred, idx_pred = mesh_pred.sample(eval_points, return_index=True)
 30 |     normal_pred = mesh_pred.face_normals[idx_pred]
 31 |     point_pred = point_pred.astype(np.float32)
 32 | 
 33 |     dist_pred_to_gt, normal_pred_to_gt = distance_field_helper(point_pred, point_gt, normal_pred, normal_gt)
 34 |     dist_gt_to_pred, normal_gt_to_pred = distance_field_helper(point_gt, point_pred, normal_gt, normal_pred)
 35 | 
 36 |     # TODO: divide by 2 following OccNet
 37 |     # https://github.com/autonomousvision/occupancy_networks/blob/406f79468fb8b57b3e76816aaa73b1915c53ad22/im2mesh/eval.py#L136
 38 |     chamfer_l1 = np.mean(dist_pred_to_gt) + np.mean(dist_gt_to_pred)
 39 | 
 40 |     c1 = np.mean(dist_pred_to_gt)
 41 |     c2 = np.mean(dist_gt_to_pred)
 42 | 
 43 |     normal_consistency = np.mean(normal_pred_to_gt) + np.mean(normal_gt_to_pred)
 44 | 
 45 |     # Fscore
 46 |     tau = 1e-4
 47 |     eps = 1e-9
 48 | 
 49 |     dist_pred_to_gt = (dist_pred_to_gt**2)
 50 |     dist_gt_to_pred = (dist_gt_to_pred**2)
 51 | 
 52 |     prec_tau = (dist_pred_to_gt <= tau).astype(np.float32).mean() * 100.
 53 |     recall_tau = (dist_gt_to_pred <= tau).astype(np.float32).mean() * 100.
 54 | 
 55 |     fscore = (2 * prec_tau * recall_tau) / max(prec_tau + recall_tau, eps)
 56 | 
 57 |     # Following convention, scale the chamfer distances up by 1000.
 58 |     return c1 * 1000., c2 * 1000., normal_consistency / 2., fscore
 59 | 
 60 | def distance_field_helper(source, target, normals_src=None, normals_tgt=None):
 61 |     target_kdtree = sp.spatial.cKDTree(target)
 62 |     distances, idx = target_kdtree.query(source, n_jobs=-1)
 63 | 
 64 |     if normals_src is not None and normals_tgt is not None:
 65 |         
 66 |         normals_src = \
 67 |             normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
 68 |         normals_tgt = \
 69 |             normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
 70 | 
 71 |         normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
 72 |         # Handle normals that point in the wrong direction gracefully
 73 |         # (mostly due to the method not caring about this during generation)
 74 |         normals_dot_product = np.abs(normals_dot_product)
 75 | 
 76 |     else:
 77 |         normals_dot_product = np.array(
 78 |             [np.nan] * source.shape[0], dtype=np.float32)
 79 | 
 80 |     return distances, normals_dot_product
 81 | 
 82 | 
 83 | 
 84 | def main(args):
 85 | 
 86 |     input_subfolder =  [x for x in sorted(os.listdir(args.input_path)) if x.endswith('obj')]
 87 |     gt_subfolder = [x for x in sorted(os.listdir(args.gt_path)) if x.endswith('obj')]
 88 | 
 89 |     mean_c1 = 0.
 90 |     mean_c2 = 0.
 91 |     mean_fscore = 0.
 92 |     mean_normal_consistency = 0.
 93 | 
 94 |     for pred, gt in zip(input_subfolder, gt_subfolder):
 95 |         mesh_pred = trimesh.load(os.path.join(args.input_path, pred))
 96 |         mesh_gt = trimesh.load(os.path.join(args.gt_path, gt))
 97 | 
 98 |         pred_2_scan, scan_2_pred, normal_consistency, fscore = compute_surface_metrics(mesh_pred, mesh_gt)
 99 |         print('Chamfer: {:.3f}, {:.3f}, Normal Consistency: {:.3f}, Fscore: {:.3f}'.format(pred_2_scan, scan_2_pred, normal_consistency, fscore))
100 |         mean_c1 += pred_2_scan
101 |         mean_c2 += scan_2_pred
102 |         mean_fscore += fscore
103 |         mean_normal_consistency += normal_consistency
104 |     
105 |     mean_c1 /= len(input_subfolder)
106 |     mean_c2 /= len(input_subfolder)
107 |     mean_fscore /= len(input_subfolder)
108 |     mean_normal_consistency /= len(input_subfolder)
109 |     print('Mean Chamfer: {:.3f}, {:.3f}, Normal Consistency: {:.3f}, Fscore: {:.3f}'.format(mean_c1, mean_c2, mean_normal_consistency, mean_fscore))
110 |     print('{:.6f}, {:.6f}, {:.6f}, {:.6f}'.format(mean_c1, mean_c2, mean_normal_consistency, mean_fscore))
111 | 
112 | if __name__ == '__main__':
113 |     
114 |     parser = argparse.ArgumentParser()
115 | 
116 |     parser.add_argument('-i', '--input_path', required=True ,type=str)
117 |     parser.add_argument('-g', '--gt_path', required=True ,type=str)
118 | 
119 |     main(parser.parse_args())
120 | 
--------------------------------------------------------------------------------
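A smoke-test sketch for compute_surface_metrics, run from within tools/evaluate.py (or with the function imported from it); the two icospheres are illustrative, and with eval_points = 1,000,000 the call takes a moment:

    import trimesh

    pred = trimesh.creation.icosphere(subdivisions=4, radius=1.00)
    gt = trimesh.creation.icosphere(subdivisions=4, radius=1.02)
    c1, c2, nc, f = compute_surface_metrics(pred, gt)
    print(f'chamfer {c1:.3f}/{c2:.3f}, normal consistency {nc:.3f}, f-score {f:.1f}')
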
/tools/load_json_to_smplx.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from smplx import SMPLX
 3 | import torch
 4 | import json
 5 | import trimesh
 6 | import argparse
 7 | 
 8 | SMPL_PATH = 'body_model/smplx/'
 9 | '''
10 | We use the following minimal code snippet to generate the SMPL-X meshes for all our scans
11 | '''
12 | def main(args):
13 | 
14 |     smpl_data = json.load(open(os.path.join(args.input_file)))
15 | 
16 | 
17 |     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
18 | 
19 |     param_betas = torch.tensor(smpl_data['betas'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
20 |     param_poses = torch.tensor(smpl_data['body_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
21 |     param_left_hand_pose = torch.tensor(smpl_data['left_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
22 |     param_right_hand_pose = torch.tensor(smpl_data['right_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
23 |             
24 |     param_expression = torch.tensor(smpl_data['expression'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
25 |     param_jaw_pose = torch.tensor(smpl_data['jaw_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
26 |     param_leye_pose = torch.tensor(smpl_data['leye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
27 |     param_reye_pose = torch.tensor(smpl_data['reye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous()
28 | 
29 | 
30 |     body_model = SMPLX(model_path=SMPL_PATH, gender='male', use_pca=True, num_pca_comps=12, flat_hand_mean=True).to(device)
31 |                 
32 |     J_0 = body_model(body_pose = param_poses, betas=param_betas).joints.contiguous().detach()
33 | 
34 | 
35 |     output = body_model(betas=param_betas,
36 |                                    body_pose=param_poses,
37 |                                    transl=-J_0[:,0,:],
38 |                                    left_hand_pose=param_left_hand_pose,
39 |                                    right_hand_pose=param_right_hand_pose,
40 |                                    expression=param_expression,
41 |                                    jaw_pose=param_jaw_pose,
42 |                                    leye_pose=param_leye_pose,
43 |                                    reye_pose=param_reye_pose,
44 |                                    )
45 |     
46 |     d = trimesh.Trimesh(vertices=output.vertices.detach().cpu().numpy()[0], faces=body_model.faces)
47 |     d.export('smplx.obj')
48 | 
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     parser = argparse.ArgumentParser(description='Minimal code snippet to generate SMPL-X mesh from json file')
53 | 
54 |     parser.add_argument("-i", "--input-file", default='./mesh-f00021.json', type=str, help="Input json file")
55 | 
56 |     main(parser.parse_args())
57 | 
--------------------------------------------------------------------------------
/tools/prepare_dataset.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | import shutil
 4 | 
 5 | DATASET_PATH = 'CustomHumans'
 6 | OUTPUT_PATH = 'CustomHumans/training_dataset'
 7 | os.makedirs(OUTPUT_PATH, exist_ok=True)
 8 | 
 9 | mesh_path = { x.split('_')[0]:x for x in sorted(os.listdir(os.path.join(DATASET_PATH, 'mesh'))) }
10 | subject_idx = json.load(open('data/Custom_train.json'))
11 | 
12 | for idx in subject_idx:
13 |     folder_name = mesh_path[idx]
14 |     shutil.copytree(os.path.join(DATASET_PATH, 'mesh', folder_name), os.path.join(OUTPUT_PATH, idx), dirs_exist_ok = True)
15 |     shutil.copytree(os.path.join(DATASET_PATH, 'smplx', folder_name), os.path.join(OUTPUT_PATH, idx), dirs_exist_ok = True)
16 | 
17 | 
18 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | import os, sys
  2 | from datetime import datetime
  3 | import logging as log
  4 | import numpy as np
  5 | import torch
  6 | import random
  7 | import shutil
  8 | import tempfile
  9 | import wandb
 10 | import pickle
 11 | 
 12 | from torch.utils.data import DataLoader
 13 | from lib.datasets.customhumans_dataset import CustomHumanDataset
 14 | from lib.models.trainer import Trainer
 15 | from lib.models.evaluator import Evaluator
 16 | from lib.utils.config import *
 17 | 
 18 | 
 19 | def create_archive(save_dir, config):
 20 | 
 21 |     with tempfile.TemporaryDirectory() as tmpdir:
 22 | 
 23 |         shutil.copy(config, os.path.join(tmpdir, 'config.yaml'))
 24 |         shutil.copy('train.py', os.path.join(tmpdir, 'train.py'))
 25 |         shutil.copy('test.py', os.path.join(tmpdir, 'test.py'))
 26 | 
 27 |         shutil.copytree(
 28 |             os.path.join('lib'),
 29 |             os.path.join(tmpdir, 'lib'),
 30 |             ignore=shutil.ignore_patterns('__pycache__'))
 31 | 
 32 |         shutil.make_archive(
 33 |             os.path.join(save_dir, 'code_copy'),
 34 |             'zip',
 35 |             tmpdir) 
 36 | 
 37 | 
 38 | def main(config):
 39 | 
 40 |     # Set random seed.
 41 |     random.seed(config.seed)
 42 |     np.random.seed(config.seed)
 43 |     torch.manual_seed(config.seed)
 44 | 
 45 |     log_dir = os.path.join(
 46 |             config.save_root,
 47 |             config.exp_name,
 48 |             f'{datetime.now().strftime("%Y%m%d-%H%M%S")}'
 49 |         )
 50 | 
 51 |     # Backup code.
 52 |     create_archive(log_dir, config.config)
 53 |     
 54 |     # Initialize dataset and dataloader.
 55 | 
 56 |     with open('data/smpl_mesh.pkl', 'rb') as f:
 57 |         smpl_mesh = pickle.load(f)
 58 | 
 59 |     dataset = CustomHumanDataset(config.num_samples, config.repeat_times)
 60 |     dataset.init_from_h5(config.data_root)
 61 | 
 62 |     loader = DataLoader(dataset=dataset, 
 63 |                         batch_size=config.batch_size, 
 64 |                         shuffle=True, 
 65 |                         num_workers=config.workers,
 66 |                         pin_memory=True)
 67 |     
 68 | 
 69 |     trainer = Trainer(config, dataset.smpl_V, smpl_mesh['smpl_F'], log_dir)
 70 | 
 71 |     evaluator = Evaluator(config, log_dir)
 72 | 
 73 | 
 74 |     if config.wandb_id is not None:
 75 |         wandb_id = config.wandb_id
 76 |     else:
 77 |         wandb_id = wandb.util.generate_id()
 78 |         with open(os.path.join(log_dir, 'wandb_id.txt'), 'w+') as f:
 79 |             f.write(wandb_id)
 80 | 
 81 |     wandb_mode = "disabled" if (not config.wandb) else "online"
 82 |     wandb.init(id=wandb_id,
 83 |                project=config.wandb_name,
 84 |                config=config,
 85 |                name=os.path.basename(log_dir),
 86 |                resume="allow",
 87 |                settings=wandb.Settings(start_method="fork"),
 88 |                mode=wandb_mode,
 89 |                dir=log_dir)
 90 |     wandb.watch(trainer)
 91 | 
 92 |     if config.resume:
 93 |         trainer.load_checkpoint(config.resume)
 94 | 
 95 | 
 96 |     global_step = trainer.global_step
 97 |     start_epoch = trainer.epoch
 98 | 
 99 | 
100 |     for epoch in range(start_epoch, config.epochs):
101 |         for data in loader:
102 |             trainer.step(epoch=epoch, n_iter=global_step, data=data)
103 |             
104 |             if global_step % config.log_every == 0:
105 |                 trainer.log(global_step, epoch)
106 | 
107 |             if config.use_2d_from_epoch >= 0 and \
108 |                 epoch >= config.use_2d_from_epoch and \
109 |                 global_step % config.log_every == 0:
110 |                 trainer.write_images(global_step)
111 | 
112 |             global_step += 1
113 | 
114 |         if epoch % config.save_every == 0:
115 |             trainer.save_checkpoint(full=False)
116 | 
117 |         if epoch % config.valid_every == 0 and epoch > 0:
118 |             evaluator.init_models(trainer)
119 |             evaluator.reconstruction(32, epoch=epoch)
120 | 
121 |     wandb.finish()
122 | 
123 | if __name__ == "__main__":
124 | 
125 |     parser = parse_options()
126 |     args, args_str = argparse_to_str(parser)
127 |     handlers = [log.StreamHandler(sys.stdout)]
128 |     log.basicConfig(level=args.log_level,
129 |                         format='%(asctime)s|%(levelname)8s| %(message)s',
130 |                         handlers=handlers)
131 |     log.info(f'Info: \n{args_str}')
132 |     main(args)
--------------------------------------------------------------------------------
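Training is launched from the repository root; a typical invocation (flags as defined in lib/utils/config.py, values illustrative):

    python train.py --config config.yaml --exp-name my_experiment --batch-size 2 --wandb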