├── .gitignore ├── LICENSE ├── README.md ├── assets ├── doge.gif └── teaser.gif ├── checkpoints └── .gitkeep ├── config.yaml ├── data ├── Custom_train.json ├── SIZER_test.json ├── THUMAN_train.json ├── lowerbody.json ├── smpl_mesh.pkl └── thuman_smpl_mesh.pkl ├── demo.py ├── generate_dataset.py ├── lib ├── datasets │ └── customhumans_dataset.py ├── models │ ├── evaluator.py │ ├── feature_dictionary.py │ ├── losses.py │ ├── networks │ │ ├── discriminator.py │ │ ├── layers.py │ │ ├── mlps.py │ │ └── positional_encoding.py │ ├── neural_fields.py │ ├── tracer.py │ └── trainer.py ├── ops │ └── mesh │ │ ├── __init__.py │ │ ├── area_weighted_distribution.py │ │ ├── barycentric_coordinates.py │ │ ├── closest_point.py │ │ ├── closest_tex.py │ │ ├── compute_sdf.py │ │ ├── load_obj.py │ │ ├── normalize.py │ │ ├── per_face_normals.py │ │ ├── per_vertex_normals.py │ │ ├── point_sample.py │ │ ├── random_face.py │ │ ├── sample_near_surface.py │ │ ├── sample_surface.py │ │ ├── sample_tex.py │ │ └── sample_uniform.py └── utils │ ├── camera.py │ ├── config.py │ └── image.py ├── requirements.txt ├── smplx └── .gitkeep ├── tools ├── align_thuman.py ├── evaluate.py ├── load_json_to_smplx.py └── prepare_dataset.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.DS_Store 3 | **.DS_Store 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 custom-humans 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Locally Editable Virtual Humans
2 | 
3 | ## [Project Page](https://custom-humans.github.io/) | [Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Ho_Learning_Locally_Editable_Virtual_Humans_CVPR_2023_paper.pdf) | [Youtube(3min)](https://youtu.be/aT8ql5hB3ZM), [Shorts(18sec)](https://youtube.com/shorts/6LTXma_wn4c) | [Dataset](https://custom-humans.ait.ethz.ch/)
4 | 
5 | 
6 | 
7 | Official code release for the CVPR 2023 paper [*Learning Locally Editable Virtual Humans*](https://custom-humans.github.io/).
8 | 
9 | If you find our code, dataset, and paper useful, please cite:
10 | ```
11 | @inproceedings{ho2023custom,
12 | title={Learning Locally Editable Virtual Humans},
13 | author={Ho, Hsuan-I and Xue, Lixin and Song, Jie and Hilliges, Otmar},
14 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
15 | year={2023}
16 | }
17 | ```
18 | 
19 | ## Installation
20 | Our code has been tested with PyTorch 1.11.0, CUDA 11.3, and an RTX 3090 GPU.
21 | 
22 | ```bash
23 | pip install -r requirements.txt
24 | ```
25 | 
26 | ## Quick Start
27 | 
28 | ⚠️ The model checkpoint contains several real human bodies and faces. To download the checkpoint file, you need to agree to the CustomHumans Dataset Terms of Use. Click [here](https://custom-humans.ait.ethz.ch/) to apply for the dataset. You will find the checkpoint file in the dataset download link.
29 | 
30 | 1. Download the checkpoint file and put it into the `checkpoints` folder.
31 | 
32 | 2. Download the test meshes and images from [here](https://files.ait.ethz.ch/projects/custom-humans/test.zip) and put them into the `data` folder.
33 | 
34 | 3. Run a quick demo that fits the model to an unseen 3D scan and 2D images:
35 | ```bash!
36 | python demo.py --pretrained-root checkpoints/demo --model-name model-1000.pth
37 | ```
38 | You should be able to dress me in a Doge T-shirt.
39 | 
40 | 
41 | 
42 | 4. Try out other functions such as reposing and clothing transfer in `demo.py`.
43 | 
44 | ## Data Preparation
45 | 
46 | ### CustomHumans
47 | Apply for our dataset by sending a [request](https://custom-humans.ait.ethz.ch/). After downloading, you should get 646 textured meshes and SMPL-X meshes. We use only 100 meshes for training. We provide the indices of the training meshes [here](https://github.com/custom-humans/editable-humans/blob/main/data/Custom_train.json).
48 | 
49 | 1. Prepare the training data following the folder structure below (a snippet for checking the prepared folders is included in the appendix at the end of this README):
50 | ```
51 | training_dataset
52 | ├── 0003
53 | │   ├── mesh-f00101.obj
54 | │   ├── mesh-f00101.mtl
55 | │   ├── mesh-f00101.png
56 | │   ├── mesh-f00101.json
57 | │   └── mesh-f00101_smpl.obj
58 | ├── 0007
59 | │   ...
60 | 
61 | ```
62 | You can use the following script to generate the training dataset folder:
63 | ```bash!
64 | python tools/prepare_dataset.py
65 | ```
66 | 
67 | 2. Download the [SMPL-X](https://smpl-x.is.tue.mpg.de/) models and move them to the `smplx` folder.
68 | You should have the following data structure:
69 | ```
70 | smplx
71 | ├── SMPLX_NEUTRAL.pkl
72 | ├── SMPLX_NEUTRAL.npz
73 | ├── SMPLX_MALE.pkl
74 | ├── SMPLX_MALE.npz
75 | ├── SMPLX_FEMALE.pkl
76 | └── SMPLX_FEMALE.npz
77 | ```
78 | 3. Since sampling points on meshes on the fly during training can be slow, we sample 18M points per mesh and cache them in an h5 file for training. Run the following script to generate the h5 file:
79 | 
80 | ```bash!
81 | python generate_dataset.py -i /path/to/dataset/folder
82 | ```
83 | 
84 | ⚠️ The script will generate a large h5 file (>80GB). If you don't want to generate that many points, you can adjust the `NUM_SAMPLES` parameter [here](https://github.com/custom-humans/editable-humans/blob/main/generate_dataset.py#L18). A short snippet for sanity-checking the generated file is included in the appendix at the end of this README.
85 | 
86 | ### THuman2.0
87 | 
88 | We also train our model on 150 scans from THuman2.0; you can find their indices [here](https://github.com/custom-humans/editable-humans/blob/main/data/THUMAN_train.json). Please apply for the dataset and SMPL-X registrations through their [official repo](https://github.com/ytrock/THuman2.0-Dataset).
89 | 
90 | ⚠️ Note that the scans in THuman2.0 come in various scales. We rescale them to the range [-1, 1] based on the SMPL-X models. You can find the rescaling script [here](https://github.com/custom-humans/editable-humans/blob/main/tools/align_thuman.py).
91 | 
92 | ⚠️ THuman2.0 uses different settings for creating the SMPL-X body meshes. When generating the h5 file, please set `flat_hand_mean=False` in the [`generate_dataset.py`](https://github.com/custom-humans/editable-humans/blob/main/generate_dataset.py#L42) script (see the example in the appendix at the end of this README).
93 | 
94 | ## Training
95 | 
96 | Once your h5 dataset is ready, simply run the following command to train the model:
97 | ```
98 | python train.py
99 | ```
100 | Here are some configuration flags you can use; they will override the settings in `config.yaml`:
101 | * `--config`: path to the config file. Default is `config.yaml`.
102 | * `--wandb`: we use wandb to monitor training. Activate this flag if you want to use it.
103 | * `--save-root`: path to the folder where checkpoints are saved. Default is `checkpoints`.
104 | * `--data_root`: path to the training h5 dataset. Default is `CustomHumans.h5`.
105 | * `--use_2d_from_epoch`: use the 2D adversarial loss after this epoch. -1 means the 2D loss is never used. Default is 10.
106 | 
107 | ## Evaluation
108 | 
109 | We use SIZER to evaluate geometry fitting performance. Please follow their instructions to download the [dataset](https://github.com/garvita-tiwari/sizer).
110 | 
111 | We provide the subjects' [indices](https://github.com/custom-humans/editable-humans/blob/main/data/SIZER_test.json) and [scripts](https://github.com/custom-humans/editable-humans/blob/main/tools/evaluate.py) for evaluation.
112 | 
113 | ## Acknowledgement
114 | We have used code from other great research projects, including [gdna](https://github.com/xuchen-ethz/gdna), [kaolin-wisp](https://github.com/NVIDIAGameWorks/kaolin-wisp), [SMPL-X](https://github.com/vchoutas/smplx), [ML-GSN](https://github.com/apple/ml-gsn/), [StyleGAN-Ada](https://github.com/NVlabs/stylegan2-ada-pytorch), and [Occupancy Networks](https://github.com/autonomousvision/occupancy_networks).
115 | 
116 | We created all the videos with the powerful [aitviewer](https://eth-ait.github.io/aitviewer/).
117 | 
118 | We sincerely thank the authors for their awesome work!
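
## Appendix: example snippets

The snippets below are small editorial sketches added for clarity; they are not part of the released pipeline. Paths, folder names, and default values are assumptions taken from this README, `config.yaml`, and `generate_dataset.py`.

**Checking the prepared training folders.** `data/Custom_train.json` (and `data/THUMAN_train.json`) are plain lists of subject folder names. Assuming the `training_dataset` layout from the Data Preparation section, a quick consistency check could look like this:

```python
import json
import os

# Load the list of training subject folder names, e.g. ["0003", "0007", ...].
with open("data/Custom_train.json") as f:
    subjects = json.load(f)

# Report any listed subject whose folder is missing from the prepared dataset.
missing = [s for s in subjects
           if not os.path.isdir(os.path.join("training_dataset", s))]
print(f"{len(subjects)} subjects listed, {len(missing)} missing: {missing[:5]}")
```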
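
**Inspecting the cached h5 file.** After running `generate_dataset.py`, the cached samples can be sanity-checked with `h5py`. The keys below mirror the datasets created in `generate_dataset.py`; the output path is assumed to be the default `CustomHumans.h5`:

```python
import h5py

# Print the number of subjects and the shape/dtype of each cached dataset.
with h5py.File("CustomHumans.h5", "r") as f:
    print("num_subjects:", f["num_subjects"][()])
    for key in ["pts", "rgb", "nrm", "d", "smpl_v",
                "rgb_image", "nrm_image", "mask_image"]:
        print(key, f[key].shape, f[key].dtype)
```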
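
**SMPL-X settings for THuman2.0.** The only intended change for THuman2.0 is the `flat_hand_mean` argument of the SMPL-X body model constructed in `generate_dataset.py`; the other arguments below simply mirror that call:

```python
from smplx import SMPLX

# CustomHumans (the default in generate_dataset.py):
body_model = SMPLX(model_path='smplx/', gender='male', use_pca=True,
                   num_pca_comps=12, flat_hand_mean=True)

# THuman2.0: switch flat_hand_mean to False; everything else stays the same.
body_model = SMPLX(model_path='smplx/', gender='male', use_pca=True,
                   num_pca_comps=12, flat_hand_mean=False)
```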
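
**Loading the cached dataset directly.** `train.py` is the supported entry point, but for quick experiments the dataset class can also be used on its own. This is a minimal sketch; `num_samples`, `repeat_times`, and `batch_size` mirror the `dataset` and `train` sections of `config.yaml`:

```python
from torch.utils.data import DataLoader
from lib.datasets.customhumans_dataset import CustomHumanDataset

# Build the dataset and point it at the cached h5 file.
dataset = CustomHumanDataset(num_samples=20480, repeat_times=8)
dataset.init_from_h5("CustomHumans.h5")

# Each item contains sampled points (pts/sdf/rgb/nrm) and image patches.
loader = DataLoader(dataset, batch_size=4, num_workers=0, shuffle=True)
batch = next(iter(loader))
print(batch["pts"].shape, batch["rgb_image"].shape)
```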
119 | -------------------------------------------------------------------------------- /assets/doge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/assets/doge.gif -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/assets/teaser.gif -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/checkpoints/.gitkeep -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | save_root: './checkpoints' 3 | exp_name: 'test-release' 4 | 5 | dataset: 6 | data_root: 'CustomHumans.h5' 7 | num_samples: 20480 8 | repeat_times: 8 9 | 10 | optimizer: 11 | lr_codebook: 0.0005 12 | lr_decoder: 0.001 13 | lr_dis: 0.001 14 | beta1: 0.5 15 | beta2: 0.999 16 | 17 | 18 | train: 19 | epochs: 5000 20 | batch_size: 4 21 | workers: 8 22 | save_every: 50 23 | log_every: 100 24 | use_2d_from_epoch: 10 25 | train_2d_every_iter: 1 26 | use_nrm_dis: False 27 | use_cached_pts: True 28 | 29 | dictionary: 30 | shape_dim: 32 31 | color_dim: 32 32 | feature_std: 0.1 33 | feature_bias: 0.0 34 | shape_pca_dim: 16 35 | color_pca_dim: 16 36 | 37 | 38 | network: 39 | pos_dim: 3 40 | c_dim: 3 41 | num_layers: 4 42 | hidden_dim: 128 43 | skip: 44 | - 2 45 | activation: 'relu' 46 | layer_type: 'none' 47 | 48 | 49 | embedder: 50 | shape_freq: 5 51 | color_freq: 10 52 | 53 | 54 | losses: 55 | lambda_sdf: 100. 56 | lambda_rgb: 10. 57 | lambda_nrm: 10. 58 | lambda_reg: 1. 59 | 60 | gan_loss_type: 'logistic' 61 | lambda_gan: 1. 62 | lambda_grad: 10. 
63 | 64 | 65 | validation: 66 | valid_every: 50 67 | subdivide: True 68 | grid_size: 400 69 | width: 1024 70 | fov: 20.0 71 | n_views: 1 72 | 73 | wandb: 74 | wandb: False 75 | wandb_name: 'custom-test' 76 | -------------------------------------------------------------------------------- /data/Custom_train.json: -------------------------------------------------------------------------------- 1 | ["0003", "0007", "0011", "0016", "0019", "0023", "0028", "0035", "0041", "0043", "0052", "0056", "0062", "0067", "0071", "0075", "0084", "0088", "0095", "0099", "0104", "0110", "0113", "0120", "0126", "0132", "0138", "0152", "0157", "0164", "0169", "0170", "0176", "0181", "0186", "0190", "0195", "0205", "0208", "0214", "0221", "0225", "0232", "0234", "0242", "0253", "0257", "0264", "0272", "0275", "0281", "0286", "0291", "0300", "0313", "0320", "0325", "0331", "0333", "0345", "0351", "0363", "0367", "0370", "0375", "0380", "0387", "0407", "0413", "0422", "0428", "0431", "0446", "0450", "0459", "0467", "0474", "0485", "0493", "0498", "0508", "0518", "0525", "0539", "0555", "0565", "0571", "0584", "0586", "0590", "0596", "0602", "0609", "0612", "0618", "0621", "0628", "0635", "0637", "0644"] -------------------------------------------------------------------------------- /data/SIZER_test.json: -------------------------------------------------------------------------------- 1 | ["10032-3612", "10037-4262", "10040-4311", "10041-4457", "10071-7028", "10090-8110", "10091-8164", "10115-9709"] -------------------------------------------------------------------------------- /data/THUMAN_train.json: -------------------------------------------------------------------------------- 1 | ["0000", "0001", "0005", "0006", "0007", "0008", "0017", "0021", "0024", "0025", "0028", "0034", "0037", "0038", "0052", "0053", "0054", "0056", "0057", "0060", "0070", "0071", "0078", "0083", "0087", "0088", "0092", "0095", "0099", "0103", "0107", "0108", "0110", "0116", "0119", "0121", "0125", "0126", "0128", "0129", "0132", "0136", "0139", "0144", "0146", "0151", "0155", "0160", "0164", "0167", "0168", "0173", "0176", "0181", "0184", "0185", "0187", "0193", "0197", "0200", "0203", "0204", "0210", "0215", "0216", "0228", "0229", "0241", "0243", "0252", "0266", "0273", "0282", "0283", "0285", "0286", "0293", "0296", "0299", "0303", "0307", "0308", "0311", "0314", "0318", "0322", "0327", "0329", "0330", "0338", "0339", "0342", "0345", "0348", "0351", "0354", "0356", "0362", "0365", "0369", "0376", "0377", "0378", "0381", "0384", "0387", "0391", "0393", "0394", "0398", "0401", "0402", "0405", "0412", "0415", "0421", "0425", "0426", "0428", "0430", "0433", "0434", "0435", "0437", "0440", "0441", "0445", "0448", "0453", "0455", "0459", "0460", "0462", "0463", "0467", "0470", "0471", "0476", "0480", "0482", "0488", "0491", "0494", "0496", "0499", "0501", "0502", "0503", "0507", "0522"] -------------------------------------------------------------------------------- /data/smpl_mesh.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/data/smpl_mesh.pkl -------------------------------------------------------------------------------- /data/thuman_smpl_mesh.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/data/thuman_smpl_mesh.pkl 
-------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import logging as log 3 | import numpy as np 4 | import torch 5 | import pickle 6 | import random 7 | import json 8 | from lib.models.evaluator import Evaluator 9 | from lib.models.trainer import Trainer 10 | 11 | from lib.utils.config import * 12 | from lib.utils.image import update_edited_images 13 | 14 | def main(config): 15 | 16 | # Set random seed. 17 | random.seed(config.seed) 18 | np.random.seed(config.seed) 19 | torch.manual_seed(config.seed) 20 | 21 | log_dir = config.pretrained_root 22 | 23 | with open('data/smpl_mesh.pkl', 'rb') as f: 24 | smpl_mesh = pickle.load(f) 25 | 26 | trainer = Trainer(config, smpl_mesh['smpl_V'], smpl_mesh['smpl_F'], log_dir) 27 | 28 | trainer.load_checkpoint(os.path.join(config.pretrained_root, config.model_name)) 29 | 30 | 31 | evaluator = Evaluator(config, log_dir, mode='test') 32 | 33 | evaluator.init_models(trainer) 34 | 35 | # Fitting the 32th feature codebook to the unseen 3D mesh 36 | evaluator.fitting_3D(32, 'data/test/mesh/mesh-f00194.obj', 'data/test/mesh/mesh-f00194_smpl.obj', fit_rgb=True) 37 | 38 | # Generate the 3D mesh using marching cube 39 | evaluator.reconstruction(32, epoch=999) 40 | 41 | # Render the 3D mesh to 2D images 42 | #rendered = evaluator.render_2D(32, epoch=999) 43 | 44 | # Get the training points from the edited images 45 | rendered = update_edited_images('data/test/images', 'data/test/render_dict.pkl') 46 | 47 | # Fitting the 32th texture codebook to the edited images 48 | evaluator.fitting_2D(32, rendered, 'data/test/mesh/mesh-f00194_smpl.obj') 49 | 50 | # Generate the edited 3D mesh using marching cube 51 | evaluator.reconstruction(32, epoch=998) 52 | 53 | # Repose the 32th subject to a new SMPL-X pose 54 | #evaluator.reposing(32, 'data/test/mesh/mesh-f00181_smpl.obj', epoch=997) 55 | 56 | # Clothing transfer 57 | # Load the indices of the lower body vertices 58 | #idx = json.load(open('data/lowerbody.json')) 59 | 60 | # Fitting the 33th feature codebook to the other 3D scan 61 | #evaluator.fitting_3D(33, 'data/test/mesh/mesh-f00181.obj', 'data/test/mesh/mesh-f00181_smpl.obj', fit_rgb=True) 62 | 63 | # Transfer the clothing (idx) from the 32th subject to the 33th subject 64 | #evaluator.transfer_features(32, 33, idx) 65 | 66 | # Generate the transferred 3D mesh using marching cube 67 | #evaluator.reconstruction(33, epoch=996) 68 | 69 | 70 | if __name__ == "__main__": 71 | 72 | parser = parse_options() 73 | parser.add_argument('--pretrained-root', type=str, required=True, help='pretrained model path') 74 | parser.add_argument('--model-name', type=str, required=True, help='load model name') 75 | 76 | args, args_str = argparse_to_str(parser) 77 | handlers = [logging.StreamHandler(sys.stdout)] 78 | logging.basicConfig(level=args.log_level, 79 | format='%(asctime)s|%(levelname)8s| %(message)s', 80 | handlers=handlers) 81 | logging.info(f'Info: \n{args_str}') 82 | main(args) -------------------------------------------------------------------------------- /generate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | import json 5 | from tqdm import tqdm 6 | import argparse 7 | import torch 8 | import pickle 9 | import kaolin as kal 10 | from kaolin.render.camera import * 11 | 12 | from lib.utils.camera import * 13 | from lib.ops.mesh 
import * 14 | from smplx import SMPLX 15 | 16 | SMPL_PATH = 'smplx/' 17 | 18 | NUM_SAMPLES = 3000000 19 | 20 | N_VIEWS = 4 21 | FOV = 20 22 | HEIGHT = 1024 23 | WIDTH = 1024 24 | RATIO = 1.0 25 | 26 | N_JOINTS = 25 27 | HALF_PATCH_SIZE = 64 28 | 29 | def _get_smpl_vertices(smpl_data): 30 | device = torch.device('cuda') 31 | param_betas = torch.tensor(smpl_data['betas'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 32 | param_poses = torch.tensor(smpl_data['body_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 33 | param_left_hand_pose = torch.tensor(smpl_data['left_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 34 | param_right_hand_pose = torch.tensor(smpl_data['right_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 35 | 36 | param_expression = torch.tensor(smpl_data['expression'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 37 | param_jaw_pose = torch.tensor(smpl_data['jaw_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 38 | param_leye_pose = torch.tensor(smpl_data['leye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 39 | param_reye_pose = torch.tensor(smpl_data['reye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 40 | 41 | 42 | body_model = SMPLX(model_path=SMPL_PATH, gender='male', use_pca=True, num_pca_comps=12, flat_hand_mean=True).to(device) 43 | 44 | J_0 = body_model(body_pose = param_poses, betas=param_betas).joints.contiguous().detach() 45 | 46 | 47 | output = body_model(betas=param_betas, 48 | body_pose=param_poses, 49 | transl=-J_0[:,0,:], 50 | left_hand_pose=param_left_hand_pose, 51 | right_hand_pose=param_right_hand_pose, 52 | expression=param_expression, 53 | jaw_pose=param_jaw_pose, 54 | leye_pose=param_leye_pose, 55 | reye_pose=param_reye_pose, 56 | ) 57 | return output.vertices.contiguous()[0].detach(), \ 58 | output.joints.contiguous()[0].detach()[:25] 59 | 60 | 61 | ######################################################################################################################### 62 | 63 | def main(args): 64 | device = torch.device('cuda') 65 | 66 | outfile = h5py.File(os.path.join(args.output_path), 'w') 67 | 68 | subject_list = [x for x in sorted(os.listdir(args.input_path)) if os.path.isdir(os.path.join(args.input_path, x))] 69 | num_subjects = len(subject_list) 70 | 71 | outfile.create_dataset( 'num_subjects', data=num_subjects, dtype=np.int32) 72 | 73 | 74 | 75 | dataset_pts = outfile.create_dataset( 'pts', shape=(num_subjects, NUM_SAMPLES*6, 3), 76 | chunks=True, dtype=np.float32) 77 | dataset_rgb = outfile.create_dataset( 'rgb',shape=(num_subjects, NUM_SAMPLES*6, 3), 78 | chunks=True, dtype=np.float32) 79 | dataset_nrm = outfile.create_dataset( 'nrm', shape=(num_subjects, NUM_SAMPLES*6, 3), 80 | chunks=True, dtype=np.float32) 81 | dataset_d = outfile.create_dataset( 'd', shape=(num_subjects, NUM_SAMPLES*6, 1), 82 | chunks=True, dtype=np.float32) 83 | 84 | 85 | dataset_smpl_v = outfile.create_dataset( 'smpl_v', shape=(num_subjects, 10475, 3), 86 | chunks=True, dtype=np.float32) 87 | 88 | dataset_ray_ori_image = outfile.create_dataset( 'ray_ori_image', shape=(num_subjects, N_JOINTS*4, 89 | HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3), 90 | chunks=True, dtype=np.float32) 91 | 92 | dataset_ray_dir_image = outfile.create_dataset( 'ray_dir_image', shape=(num_subjects, N_JOINTS*4, 93 | HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3), 94 | chunks=True, dtype=np.float32) 95 | 96 | 97 | 
dataset_xyz_image = outfile.create_dataset( 'xyz_image', shape=(num_subjects, N_JOINTS*4, 98 | HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3), 99 | chunks=True, dtype=np.float32) 100 | dataset_nrm_image = outfile.create_dataset( 'nrm_image', shape=(num_subjects, N_JOINTS*4, 101 | HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3), 102 | chunks=True, dtype=np.float32) 103 | dataset_rgb_image = outfile.create_dataset( 'rgb_image', shape=(num_subjects, N_JOINTS*4, 104 | HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 3), 105 | chunks=True, dtype=np.float32) 106 | dataset_mask_image = outfile.create_dataset( 'mask_image', shape=(num_subjects, N_JOINTS*4,\ 107 | HALF_PATCH_SIZE*2, HALF_PATCH_SIZE*2, 1), 108 | chunks=True, dtype=np.bool) 109 | 110 | for s, subject in enumerate(tqdm(subject_list)): 111 | subject_path = os.path.join(args.input_path, subject) 112 | json_file = [x for x in sorted(os.listdir(subject_path)) if x.endswith('.json')][0] 113 | filename = json_file.split('.')[0] 114 | 115 | smpl_data = json.load(open(os.path.join(subject_path, filename+'.json'))) 116 | smpl_V, smpl_J = _get_smpl_vertices(smpl_data) 117 | with open('data/smpl_mesh.pkl', 'rb') as f: 118 | smpl_mesh = pickle.load(f) 119 | 120 | smpl_F = smpl_mesh['smpl_F'].cuda().detach() 121 | 122 | 123 | mesh_data = os.path.join(subject_path, filename+'.obj') 124 | out = load_obj(mesh_data, load_materials=True) 125 | V, F, texv, texf, mats = out 126 | FN = per_face_normals(V, F).cuda() 127 | 128 | 129 | pts1 = point_sample( V.cuda(), F.cuda(), ['near', 'near', 'trace'], NUM_SAMPLES, 0.01) 130 | pts2 = point_sample(smpl_V, smpl_F, ['rand', 'near', 'trace'], NUM_SAMPLES, 0.1) 131 | 132 | rgb1, nrm1, d1 = closest_tex(V.cuda(), F.cuda(), 133 | texv.cuda(), texf.cuda(), mats, pts1.cuda()) 134 | rgb2, nrm2, d2 = closest_tex(V.cuda(), F.cuda(), 135 | texv.cuda(), texf.cuda(), mats, pts2.cuda()) 136 | 137 | 138 | look_at = torch.zeros( (N_VIEWS, 3), dtype=torch.float32, device=device) 139 | 140 | 141 | camera_position = torch.tensor( [ [0, 0, 2], 142 | [2, 0, 0], 143 | [0, 0, -2], 144 | [-2, 0, 0] ] , dtype=torch.float32, device=device) 145 | 146 | camera_up_direction = torch.tensor( [[0, 1, 0]], dtype=torch.float32, device=device).repeat(N_VIEWS, 1,) 147 | 148 | cam_transform = generate_transformation_matrix(camera_position, look_at, camera_up_direction) 149 | cam_proj = generate_perspective_projection(FOV, RATIO) 150 | 151 | face_vertices_camera, face_vertices_image, face_normals = \ 152 | kal.render.mesh.prepare_vertices( 153 | V.unsqueeze(0).repeat(N_VIEWS, 1, 1).cuda(), 154 | F.cuda(), cam_proj.cuda(), camera_transform=cam_transform 155 | ) 156 | face_uvs = texv[texf[...,:3]].unsqueeze(0).cuda() 157 | 158 | ### Perform Rasterization ### 159 | # Construct attributes that DIB-R rasterizer will interpolate. 160 | # the first is the UVS associated to each face 161 | # the second will make a hard segmentation mask 162 | face_attributes = [ 163 | V[F].unsqueeze(0).cuda().repeat(N_VIEWS, 1, 1, 1), 164 | face_uvs.repeat(N_VIEWS, 1, 1, 1), 165 | FN.unsqueeze(0).unsqueeze(2).repeat(N_VIEWS, 1, 3, 1), 166 | ] 167 | 168 | padded_joints = torch.nn.functional.pad( 169 | smpl_J.unsqueeze(0).repeat(N_VIEWS, 1, 1), (0, 1), mode='constant', value=1.) 
170 | 171 | joints_camera = (padded_joints @ cam_transform) 172 | # Project the vertices on the camera image plan 173 | jonts_image = perspective_camera(joints_camera, cam_proj.cuda()) 174 | jonts_image = ((jonts_image) * torch.tensor([1, -1], device=device) + 1 ) * \ 175 | torch.tensor([WIDTH//2, HEIGHT//2], device=device) 176 | # If you have nvdiffrast installed you can change rast_backend to 177 | # nvdiffrast or nvdiffrast_fwd 178 | image_features, face_idx = kal.render.mesh.rasterize( 179 | HEIGHT, WIDTH, face_vertices_camera[:, :, :, -1], 180 | face_vertices_image, face_attributes, backend='cuda', multiplier=1000) 181 | 182 | coords, uv, normal= image_features 183 | 184 | TM = torch.zeros((N_VIEWS, HEIGHT, WIDTH, 1), dtype=torch.long, device=device) 185 | 186 | rgb = sample_tex(uv.view(-1, 2), TM.view(-1), mats).view(N_VIEWS, HEIGHT, WIDTH, 3) 187 | mask = (face_idx != -1).unsqueeze(-1) 188 | 189 | 190 | ray_dir_patches = [] 191 | ray_ori_patches = [] 192 | xyz_patches = [] 193 | rgb_patches = [] 194 | nrm_patches = [] 195 | mask_patches = [] 196 | 197 | for i in range(N_VIEWS): 198 | 199 | camera = Camera.from_args(eye=camera_position[i], 200 | at=look_at[i], 201 | up=camera_up_direction[i], 202 | fov=FOV, 203 | width=WIDTH, 204 | height=HEIGHT, 205 | dtype=torch.float32) 206 | 207 | ray_grid = generate_centered_pixel_coords(camera.width, camera.height, 208 | camera.width, camera.height, device=device) 209 | 210 | ray_orig, ray_dir = \ 211 | generate_pinhole_rays(camera.to(ray_grid[0].device), ray_grid) 212 | 213 | ray_orig = ray_orig.reshape(camera.height, camera.width, -1) 214 | ray_dir = ray_dir.reshape(camera.height, camera.width, -1) 215 | 216 | for j in range(N_JOINTS): 217 | x = min (max( int(jonts_image[i, j, 0]), HALF_PATCH_SIZE), WIDTH - HALF_PATCH_SIZE) 218 | y = min (max( int(jonts_image[i, j, 1]), HALF_PATCH_SIZE), HEIGHT - HALF_PATCH_SIZE) 219 | 220 | ray_ori_patches.append( ray_orig[y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] ) 221 | ray_dir_patches.append( ray_dir[y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] ) 222 | xyz_patches.append( coords[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] ) 223 | rgb_patches.append( rgb[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] ) 224 | nrm_patches.append( normal[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] ) 225 | mask_patches.append( mask[i, y-HALF_PATCH_SIZE:y+HALF_PATCH_SIZE, x-HALF_PATCH_SIZE:x+HALF_PATCH_SIZE] ) 226 | 227 | 228 | dataset_pts[s] = torch.cat([pts1, pts2], dim=0).detach().cpu().numpy() 229 | dataset_rgb[s] = torch.cat([rgb1, rgb2], dim=0).detach().cpu().numpy() 230 | dataset_nrm[s] = torch.cat([nrm1, nrm2], dim=0).detach().cpu().numpy() 231 | dataset_d[s] = torch.cat([d1, d2], dim=0).detach().cpu().numpy() 232 | dataset_smpl_v[s] = smpl_V.detach().cpu().numpy() 233 | dataset_xyz_image[s] = torch.stack(xyz_patches).detach().cpu().numpy() 234 | dataset_rgb_image[s] = torch.stack(rgb_patches).detach().cpu().numpy() 235 | dataset_nrm_image[s] = torch.stack(nrm_patches).detach().cpu().numpy() 236 | dataset_mask_image[s] = torch.stack(mask_patches).detach().cpu().numpy() 237 | dataset_ray_ori_image[s] = torch.stack(ray_ori_patches).detach().cpu().numpy() 238 | dataset_ray_dir_image[s] = torch.stack(ray_dir_patches).detach().cpu().numpy() 239 | 240 | 241 | outfile.close() 242 | 243 | 244 | if __name__ == "__main__": 245 | parser = 
argparse.ArgumentParser(description='Process dataset to H5 file') 246 | 247 | parser.add_argument("-i", "--input_path", default='./CustomHumans/training_dataset', type=str, help="Path of the input mesh folder") 248 | parser.add_argument("-o", "--output_path", default='./CustomHumans.h5', type=str, help="Path of the output h5 file") 249 | 250 | main(parser.parse_args()) 251 | -------------------------------------------------------------------------------- /lib/datasets/customhumans_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import h5py 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | import logging as log 7 | import time 8 | 9 | class CustomHumanDataset(Dataset): 10 | """Base class for single mesh datasets with points sampled only at a given octree sampling region. 11 | """ 12 | 13 | def __init__(self, 14 | num_samples : int = 20480, 15 | repeat_times : int = 8, 16 | ): 17 | """Construct dataset. This dataset also needs to be initialized. 18 | """ 19 | self.repeat_times = repeat_times # epeate how many times each epoch 20 | self.num_samples = num_samples # number of points per subject 21 | 22 | self.initialization_mode = None 23 | self.label_map = { 24 | '0': 1, '1': 2, '2': 2, '3': 1, '4': 2, 25 | '5': 2, '6': 1, '7': 2, '8': 2, '9': 1, 26 | '10': 2, '11': 2, '12': 1, '13': 1, '14': 1, 27 | '15': 0, '16': 1, '17': 1, '18': 1, '19': 1, 28 | '20': 1, '21': 1, '22': 0, '23': 0, '24': 0, 29 | } 30 | 31 | def init_from_h5(self, dataset_path): 32 | """Initializes the dataset from a h5 file. 33 | copy smpl_v from h5 file. 34 | """ 35 | 36 | self.h5_path = dataset_path 37 | with h5py.File(dataset_path, "r") as f: 38 | try: 39 | self.num_subjects = f['num_subjects'][()] 40 | self.num_pts = f['d'].shape[1] 41 | self.smpl_V = torch.tensor(np.array(f['smpl_v'])) 42 | except: 43 | raise ValueError("[Error] Can't load from h5 dataset") 44 | self.resample() 45 | self.initialization_mode = "h5" 46 | 47 | def resample(self): 48 | """Resamples a new working set of indices. 
49 | """ 50 | 51 | start = time.time() 52 | log.info(f"Resampling...") 53 | 54 | self.id = np.random.randint(0, self.num_subjects, self.num_subjects * self.repeat_times) 55 | 56 | log.info(f"Time: {time.time() - start}") 57 | 58 | def _get_h5_data(self, subject_id, pts_id, img_id): 59 | with h5py.File(self.h5_path, "r") as f: 60 | try: 61 | pts = np.array(f['pts'][subject_id,pts_id]) 62 | d = np.array(f['d'][subject_id,pts_id]) 63 | nrm = np.array(f['nrm'][subject_id,pts_id]) 64 | rgb = np.array(f['rgb'][subject_id,pts_id]) 65 | image_label = self.label_map[str(img_id[0] % 25)] 66 | 67 | xyz_image = np.array(f['xyz_image'][subject_id,img_id]) 68 | rgb_image = np.array(f['rgb_image'][subject_id,img_id]) 69 | nrm_image = np.array(f['nrm_image'][subject_id,img_id]) 70 | mask_image = np.array(f['mask_image'][subject_id,img_id]) 71 | ray_ori_image = np.array(f['ray_ori_image'][subject_id,img_id]) 72 | ray_dir_image = np.array(f['ray_dir_image'][subject_id,img_id]) 73 | 74 | except: 75 | raise ValueError("[Error] Can't read key (%s, %s, %s) from h5 dataset" % (subject_id, pts_id, img_id)) 76 | 77 | return { 78 | 'pts' : pts, 'sdf' : d, 'nrm' : nrm, 'rgb' : rgb, 'idx' : subject_id, 'label' : image_label, 79 | 'xyz_image' : xyz_image, 'rgb_image' : rgb_image, 'nrm_image' : nrm_image, 80 | 'mask_image' : mask_image, 'ray_ori_image' : ray_ori_image, 'ray_dir_image' : ray_dir_image 81 | } 82 | 83 | def __getitem__(self, idx: int): 84 | """Retrieve point sample.""" 85 | if self.initialization_mode is None: 86 | raise Exception("The dataset is not initialized.") 87 | 88 | subject_id = self.id[idx] 89 | # points id need to be in accending order 90 | pts_id = np.random.randint(self.num_pts - self.num_samples, size=1) 91 | img_id = np.random.randint(100, size=1) 92 | 93 | return self._get_h5_data(subject_id, np.arange(pts_id, pts_id + self.num_samples), img_id) 94 | 95 | def __len__(self): 96 | """Return length of dataset (number of _samples_).""" 97 | if self.initialization_mode is None: 98 | raise Exception("The dataset is not initialized.") 99 | 100 | return self.num_subjects * self.repeat_times 101 | -------------------------------------------------------------------------------- /lib/models/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import copy 4 | import pickle 5 | import torch 6 | import trimesh 7 | 8 | import numpy as np 9 | import logging as log 10 | from tqdm import tqdm 11 | from PIL import Image 12 | 13 | from .tracer import SDFTracer 14 | from ..ops.mesh import load_obj, point_sample, closest_tex 15 | from ..utils.camera import * 16 | 17 | from kaolin.ops.conversions import voxelgrids_to_trianglemeshes 18 | from kaolin.ops.mesh import subdivide_trianglemesh 19 | 20 | class Evaluator(object): 21 | 22 | def __init__(self, config, log_dir, mode='valid'): 23 | super().__init__() 24 | 25 | self.cfg = config 26 | self.log_dir = log_dir 27 | self.mesh_dir = os.path.join(log_dir, mode, 'meshes') 28 | os.makedirs(self.mesh_dir, exist_ok=True) 29 | self.image_dir = os.path.join(log_dir, mode, 'images') 30 | os.makedirs(self.image_dir, exist_ok=True) 31 | 32 | self.sdf_field = None 33 | self.rgb_field = None 34 | 35 | self.tracer = SDFTracer(self.cfg) 36 | self.subdivide = self.cfg.subdivide 37 | self.res = self.cfg.grid_size 38 | 39 | 40 | def init_models(self, trainer): 41 | '''Initialize the models for evaluation. 42 | Args: 43 | sdf_field (SDFNet): the sdf field model from trainer. 
44 | rgb_field (RGBNet): the rgb field model from trainer. 45 | ''' 46 | 47 | self.sdf_field = copy.deepcopy(trainer.sdf_field) 48 | self.rgb_field = copy.deepcopy(trainer.rgb_field) 49 | self.smpl_F = trainer.smpl_F.clone().detach().cpu() 50 | 51 | def _marching_cubes (self, geo_idx=0, tex_idx=None, subdivide=True, res=300): 52 | '''Marching cubes to generate mesh. 53 | Args: 54 | geo_idx (int): the index of geometry to be generated. 55 | tex_idx (int): the index of texture to be generated. 56 | subdivide (bool): whether to subdivide the mesh. 57 | res (int): the resolution of the marching cubes. 58 | Returns: 59 | mesh (trimesh): the generated mesh. 60 | ''' 61 | 62 | width = res 63 | window_x = torch.linspace(-1., 1., steps=width, device='cuda') 64 | window_y = torch.linspace(-1., 1., steps=width, device='cuda') 65 | window_z = torch.linspace(-1., 1., steps=width, device='cuda') 66 | 67 | coord = torch.stack(torch.meshgrid(window_x, window_y, window_z, indexing='ij')).permute(1, 2, 3, 0).reshape(1, -1, 3).contiguous() 68 | 69 | 70 | # Debug smpl grid 71 | #smpl_vertice = self.sdf_field.get_smpl_vertices_by_idx(geo_idx) 72 | #d = trimesh.Trimesh(vertices=smpl_vertice.cpu().detach().numpy(), 73 | # faces=self.smpl_F.cpu().detach().numpy()) 74 | #d.export(os.path.join(self.log_dir, 'smpl_sub_%03d.obj' % (geo_idx)) ) 75 | 76 | if tex_idx is None: 77 | tex_idx = geo_idx 78 | geo_idx = torch.tensor([geo_idx], dtype=torch.long, device = torch.device('cuda')).view(1).detach() 79 | tex_idx = torch.tensor([tex_idx], dtype=torch.long, device = torch.device('cuda')).view(1).detach() 80 | 81 | _points = torch.split(coord, int(2*1e6), dim=1) 82 | voxels = [] 83 | for _p in _points: 84 | pred_sdf = self.sdf_field(_p, geo_idx) 85 | voxels.append(pred_sdf) 86 | 87 | voxels = torch.cat(voxels, dim=1) 88 | voxels = voxels.reshape(1, width, width, width) 89 | 90 | vertices, faces = voxelgrids_to_trianglemeshes(voxels, iso_value=0.) 91 | vertices = ((vertices[0].reshape(1, -1, 3) - 0.5) / (width/2)) - 1.0 92 | faces = faces[0] 93 | 94 | if subdivide: 95 | vertices, faces = subdivide_trianglemesh(vertices, faces, iterations=1) 96 | 97 | pred_rgb = self.rgb_field(vertices, tex_idx, pose_idx=geo_idx) 98 | 99 | h = trimesh.Trimesh(vertices=vertices[0].cpu().detach().numpy(), 100 | faces=faces.cpu().detach().numpy(), 101 | vertex_colors=pred_rgb[0].cpu().detach().numpy()) 102 | 103 | # remove disconnect par of mesh 104 | connected_comp = h.split(only_watertight=False) 105 | max_area = 0 106 | max_comp = None 107 | for comp in connected_comp: 108 | if comp.area > max_area: 109 | max_area = comp.area 110 | max_comp = comp 111 | h = max_comp 112 | 113 | trimesh.repair.fix_inversion(h) 114 | 115 | return h 116 | 117 | def _get_camera_rays(self, n_views=4, fov=20, width=1024): 118 | '''Get camera rays for rendering. 119 | Args: 120 | n_views (int): the number of views. 121 | fov (float): the field of view. 122 | width (int): the width of the image. 
123 | Returns: 124 | ray_o_images : the origin of the rays of n_views*height*width*3 125 | ray_d_images : the direction of the rays of n_views*height*width*3 126 | ''' 127 | 128 | look_at = torch.zeros( (n_views, 3), dtype=torch.float32, device=torch.device('cuda')) 129 | camera_up_direction = torch.tensor( [[0, 1, 0]], dtype=torch.float32, device=torch.device('cuda')).repeat(n_views, 1,) 130 | angle = torch.linspace(0, 2*np.pi, n_views+1)[:-1] 131 | camera_position = torch.stack( (2*torch.sin(angle), torch.zeros_like(angle), 2*torch.cos(angle)), dim=1).cuda() 132 | 133 | ray_o_images = [] 134 | ray_d_images = [] 135 | for i in range(n_views): 136 | camera = Camera.from_args(eye=camera_position[i], 137 | at=look_at[i], 138 | up=camera_up_direction[i], 139 | fov=fov, 140 | width=width, 141 | height=width, 142 | dtype=torch.float32) 143 | 144 | ray_grid = generate_centered_pixel_coords(camera.width, camera.height, 145 | camera.width, camera.height, device=torch.device('cuda')) 146 | 147 | ray_orig, ray_dir = \ 148 | generate_pinhole_rays(camera.to(ray_grid[0].device), ray_grid) 149 | 150 | ray_o_images.append(ray_orig.reshape(camera.height, camera.width, -1)) 151 | ray_d_images.append(ray_dir.reshape(camera.height, camera.width, -1)) 152 | 153 | return torch.stack(ray_o_images, dim=0), torch.stack(ray_d_images, dim=0) 154 | 155 | def reconstruction(self, idx, epoch=None): 156 | ''' 157 | Reconstruct the mesh the idx-th subject. 158 | ''' 159 | if epoch is None: 160 | epoch = 0 161 | log.info(f"Reconstructing {idx}th mesh at epoch {epoch}...") 162 | start = time.time() 163 | 164 | with torch.no_grad(): 165 | h = self._marching_cubes (geo_idx=idx, subdivide=self.subdivide, res=self.res) 166 | 167 | h.export(os.path.join(self.mesh_dir, '%03d_reco_src-%03d.obj' % (epoch, idx)) ) 168 | end = time.time() 169 | log.info(f"Reconstruction finished in {end-start} seconds.") 170 | 171 | def render_2D(self, idx, epoch=None): 172 | ''' 173 | Render the 2D images of the idx-th subject. 174 | ''' 175 | torch.cuda.empty_cache() 176 | 177 | log.info(f"Rendering {idx}th subject at epoch {epoch}...") 178 | start = time.time() 179 | 180 | with torch.no_grad(): 181 | 182 | ray_o_images, ray_d_images = self._get_camera_rays(n_views=self.cfg.n_views, fov=self.cfg.fov, width=self.cfg.width) 183 | _idx = torch.tensor([idx], dtype=torch.long, device = torch.device('cuda')).repeat(self.cfg.n_views).detach() 184 | x, hit = self.tracer(self.sdf_field.forward, _idx, 185 | ray_o_images.view(self.cfg.n_views, -1, 3), 186 | ray_d_images.view(self.cfg.n_views, -1, 3)) 187 | log.info(f"Rat tracing finished in {time.time()-start} seconds.") 188 | start = time.time() 189 | rgb_2d = self.rgb_field.forward(x.detach(), _idx) * hit 190 | 191 | rgb_img = rgb_2d.reshape(self.cfg.n_views, self.cfg.width, self.cfg.width, 3).cpu().detach().numpy() * 255 192 | 193 | for i in range(self.cfg.n_views): 194 | Image.fromarray(rgb_img[i].astype(np.uint8)).save( 195 | os.path.join(self.image_dir, '%03d_render_src-%03d_view-%03d.png' % (epoch, idx, i)) ) 196 | 197 | log.info(f"Rendering finished in {time.time()-start} seconds.") 198 | render_dict = {'coord': x.cpu().detach(), 'rgb': rgb_2d.cpu().detach(), 'mask': hit.cpu().detach()} 199 | with open(os.path.join(self.image_dir, 'render_dict.pkl'), 'wb') as f: 200 | pickle.dump(render_dict, f) 201 | 202 | return render_dict 203 | 204 | 205 | def reposing(self, idx, target_smpl_obj, epoch=None): 206 | ''' 207 | Reconstruct the mesh the idx-th subject. given the target smpl obj. 
208 | ''' 209 | 210 | if epoch is None: 211 | epoch = 0 212 | smpl_V, _ = load_obj(target_smpl_obj, load_materials=False) 213 | log.info(f"Reposing {idx}th mesh at epoch {epoch}...") 214 | start = time.time() 215 | 216 | with torch.no_grad(): 217 | 218 | tmp_smpl_V = self.sdf_field.get_smpl_vertices_by_idx(idx) 219 | 220 | self.sdf_field.replace_smpl_vertices_by_idx(idx, smpl_V) 221 | self.rgb_field.replace_smpl_vertices_by_idx(idx, smpl_V) 222 | 223 | h = self._marching_cubes (geo_idx=idx, subdivide=self.subdivide, res=self.res) 224 | 225 | self.sdf_field.replace_smpl_vertices_by_idx(idx, tmp_smpl_V) 226 | self.rgb_field.replace_smpl_vertices_by_idx(idx, tmp_smpl_V) 227 | 228 | h.export(os.path.join(self.mesh_dir, '%03d_repose_src-%03d.obj' % (epoch, idx)) ) 229 | end = time.time() 230 | log.info(f"Reposing finished in {end-start} seconds.") 231 | 232 | def transfer_features(self, src_idx, tar_idx, vert_idx=None): 233 | ''' 234 | Copy the features from src_idx to tar_idx at vert_idx. 235 | ''' 236 | with torch.no_grad(): 237 | src_geo = self.sdf_field.get_feature_by_idx(src_idx, vert_idx=vert_idx).clone() 238 | src_tex = self.rgb_field.get_feature_by_idx(src_idx, vert_idx=vert_idx).clone() 239 | self.sdf_field.replace_feature_by_idx(tar_idx, src_geo, vert_idx=vert_idx) 240 | self.rgb_field.replace_feature_by_idx(tar_idx, src_tex, vert_idx=vert_idx) 241 | 242 | 243 | def fitting_3D(self, code_idx, target_mesh, target_smpl_obj, num_steps=300, fit_nrm=False, fit_rgb=False): 244 | """Fitting the latent code to the target mesh. 245 | Store the optimzed code in the code_idx-th entry of the codebook. 246 | """ 247 | 248 | torch.cuda.empty_cache() 249 | 250 | geo_code = self.sdf_field.get_mean_feature().clone().unsqueeze(0).detach().data 251 | tex_code = self.rgb_field.get_mean_feature().clone().unsqueeze(0).detach().data 252 | 253 | geo_code.requires_grad = True 254 | tex_code.requires_grad = True 255 | 256 | V, F, texv, texf, mats = load_obj(target_mesh, load_materials=True) 257 | smpl_V, _ = load_obj(target_smpl_obj, load_materials=False) 258 | smpl_V = smpl_V.cuda() 259 | 260 | params = [] 261 | params.append({'params': geo_code, 'lr': 0.005}) 262 | params.append({'params': tex_code, 'lr': 0.01}) 263 | 264 | optimizer = torch.optim.Adam(params, betas=(0.9, 0.999)) 265 | loop = tqdm(range(num_steps)) 266 | log.info(f"Start fitting latent code to the target mesh...") 267 | for i in loop: 268 | coord_1 = point_sample(V.cuda(), F.cuda(), ['near', 'trace', 'rand'], 20000, 0.01) 269 | coord_2 = point_sample(smpl_V, self.smpl_F.cuda(), ['near', 'trace'], 50000, 0.2) 270 | coord = torch.cat((coord_1, coord_2), dim=0) 271 | rgb, nrm, sdf = closest_tex(V.cuda(), F.cuda(), texv.cuda(), texf.cuda(), mats, coord.cuda()) 272 | coord = coord.unsqueeze(0) 273 | sdf = sdf.unsqueeze(0) 274 | rgb = rgb.unsqueeze(0) 275 | nrm = nrm.unsqueeze(0) 276 | 277 | sdf_loss = torch.tensor(0.0).cuda() 278 | nrm_loss = torch.tensor(0.0).cuda() 279 | rgb_loss = torch.tensor(0.0).cuda() 280 | 281 | optimizer.zero_grad() 282 | 283 | pred_sdf = self.sdf_field.forward_fitting(coord, geo_code, smpl_V.unsqueeze(0)) 284 | sdf_loss += torch.abs(pred_sdf - sdf).mean() 285 | 286 | if fit_rgb: 287 | pred_rgb = self.rgb_field.forward_fitting(coord, tex_code, smpl_V.unsqueeze(0)) 288 | rgb_loss += torch.abs(pred_rgb - rgb).mean() 289 | 290 | if fit_nrm: 291 | pred_nrm = self.sdf_field.normal_fitting(coord, tex_code, smpl_V.unsqueeze(0)) 292 | nrm_loss += torch.abs(pred_nrm - nrm).mean() 293 | 294 | 295 | loss = 10*sdf_loss + rgb_loss + 
nrm_loss 296 | loss.backward() 297 | optimizer.step() 298 | loop.set_description('Step [{}/{}] Total Loss: {:.4f} - L1:{:.4f} - RGB:{:.4f} - NRM:{:.4f}' 299 | .format(i, num_steps, loss.item(), sdf_loss.item(), rgb_loss.item(), nrm_loss.item())) 300 | 301 | log.info(f"Fitting finished. Store the optimized code and the new SMPL pose in the codebook.") 302 | 303 | with torch.no_grad(): 304 | self.sdf_field.replace_feature_by_idx(code_idx, geo_code) 305 | self.rgb_field.replace_feature_by_idx(code_idx, tex_code) 306 | self.sdf_field.replace_smpl_vertices_by_idx(code_idx, smpl_V) 307 | self.rgb_field.replace_smpl_vertices_by_idx(code_idx, smpl_V) 308 | 309 | 310 | def fitting_2D(self, code_idx, target_dict, target_smpl_obj=None, num_steps=500): 311 | """Fitting the color latent code to the rendered images 312 | Store the optimzed code in the code_idx-th entry of the codebook. 313 | """ 314 | 315 | torch.cuda.empty_cache() 316 | 317 | tex_code = self.rgb_field.get_feature_by_idx(code_idx).clone().unsqueeze(0).detach().data 318 | tex_code.requires_grad = True 319 | 320 | rgb = target_dict['rgb'].cuda() 321 | coord = target_dict['coord'].cuda() 322 | mask = target_dict['mask'].cuda() 323 | 324 | b_size = rgb.shape[0] # b_size = n_views 325 | 326 | 327 | inputs = [] 328 | targets = [] 329 | for i in range(b_size): 330 | _xyz = coord[i] 331 | _rgb = rgb[i] 332 | _mask = mask[i, :, 0] 333 | inputs.append(_xyz[_mask].view(1,-1,3)) 334 | targets.append(_rgb[_mask].view(1,-1,3)) 335 | 336 | inputs = torch.cat(inputs, dim=1) 337 | targets = torch.cat(targets, dim=1) 338 | 339 | if target_smpl_obj is not None: 340 | smpl_V, _ = load_obj(target_smpl_obj, load_materials=False) 341 | smpl_V = smpl_V.cuda() 342 | else: 343 | smpl_V = self.rgb_field.get_smpl_vertices_by_idx(code_idx) 344 | 345 | params = [] 346 | params.append({'params': tex_code, 'lr': 0.005}) 347 | 348 | optimizer = torch.optim.Adam(params, betas=(0.9, 0.999)) 349 | loop = tqdm(range(num_steps)) 350 | 351 | 352 | for i in loop: 353 | 354 | rgb_loss = torch.tensor(0.0).cuda() 355 | 356 | optimizer.zero_grad() 357 | 358 | pred_rgb = self.rgb_field.forward_fitting(inputs, tex_code, smpl_V.unsqueeze(0)) 359 | rgb_loss += torch.abs(pred_rgb - targets).mean() 360 | 361 | rgb_loss.backward() 362 | optimizer.step() 363 | loop.set_description('Step [{}/{}] Total Loss: {:.4f}'.format(i, num_steps, rgb_loss.item())) 364 | 365 | with torch.no_grad(): 366 | self.rgb_field.replace_feature_by_idx(code_idx, tex_code) 367 | #self.rgb_field.replace_smpl_vertices_by_idx(code_idx, smpl_V) 368 | -------------------------------------------------------------------------------- /lib/models/feature_dictionary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import logging as log 5 | from ..ops.mesh import * 6 | 7 | 8 | class FeatureDictionary(nn.Module): 9 | 10 | def __init__(self, 11 | feature_dim : int, 12 | feature_std : float = 0.1, 13 | feature_bias : float = 0.0, 14 | ): 15 | super().__init__() 16 | self.feature_dim = feature_dim 17 | self.feature_std = feature_std 18 | self.feature_bias = feature_bias 19 | 20 | def init_from_smpl_vertices(self, smpl_vertices): 21 | 22 | self.num_subjets = smpl_vertices.shape[0] 23 | self.num_vertices = smpl_vertices.shape[1] 24 | 25 | # Initialize feature codebooks 26 | fts = torch.zeros(self.num_subjets,self.num_vertices, self.feature_dim) + self.feature_bias 27 | fts += torch.randn_like(fts) * 
self.feature_std 28 | self.feature_codebooks = nn.Parameter(fts) 29 | 30 | log.info(f"Initalized feature codebooks with shape {self.feature_codebooks.shape}") 31 | 32 | def interpolate(self, coords, idx, smpl_V, smpl_F, input_code=None): 33 | 34 | """Query local features using the feature codebook, or the given input_code. 35 | Args: 36 | coords (torch.FloatTensor): coords of shape [batch, num_samples, 3] 37 | idx (torch.LongTensor): index of shape [batch, 1] 38 | smpl_V (torch.FloatTensor): SMPL vertices of shape [batch, num_vertices, 3] 39 | smpl_F (torch.LongTensor): SMPL faces of shape [num_faces, 3] 40 | input_code (torch.FloatTensor): input code of shape [batch, num_vertices, feature_dim] 41 | Returns: 42 | (torch.FloatTensor): interpolated features of shape [batch, num_samples, feature_dim] 43 | """ 44 | 45 | sdf, hitpt, fid, weights = batched_closest_point_fast(smpl_V, smpl_F, 46 | coords) # [B, Ns, 1], [B, Ns, 3], [B, Ns, 1], [B, Ns, 3] 47 | 48 | normal = torch.nn.functional.normalize( hitpt - coords, eps=1e-6, dim=2) # [B x Ns x 3] 49 | hitface = smpl_F[fid] # [B, Ns, 3] 50 | 51 | if input_code is None: 52 | inputs_feat = self.feature_codebooks[idx].unsqueeze(2).expand(-1, -1, hitface.shape[-1], -1) 53 | else: 54 | inputs_feat = input_code.unsqueeze(2).expand(-1, -1, hitface.shape[-1], -1) 55 | 56 | indices = hitface.unsqueeze(-1).expand(-1, -1, -1, inputs_feat.shape[-1]) 57 | nearest_feats = torch.gather(input=inputs_feat, index=indices, dim=1) # [B, Ns, 3, D] 58 | 59 | weighted_feats = torch.sum(nearest_feats * weights[...,None], dim=2) # K-weighted sum by: [B x Ns x 32] 60 | 61 | coords_feats = torch.cat([weights[...,1:], sdf], dim=-1) # [B, Ns, 3] 62 | return weighted_feats, coords_feats, normal 63 | 64 | def interpolate_random(self, coords, smpl_V, smpl_F, low_rank=32): 65 | """Query local features using PCA random sampling. 
66 | 67 | Args: 68 | coords (torch.FloatTensor): coords of shape [batch, num_samples, 3] 69 | smpl_V (torch.FloatTensor): SMPL vertices of shape [batch, num_vertices, 3] 70 | smpl_F (torch.LongTensor): SMPL faces of shape [num_faces, 3] 71 | 72 | Returns: 73 | (torch.FloatTensor): interpolated features of shape [batch, num_samples, feature_dim] 74 | """ 75 | b_size = coords.shape[0] 76 | 77 | sdf, hitpt, fid, weights = batched_closest_point_fast(smpl_V, smpl_F, 78 | coords) # [B, Ns, 1], [B, Ns, 3], [B, Ns, 1], [B, Ns, 3] 79 | normal = torch.nn.functional.normalize( hitpt - coords, eps=1e-6, dim=2) # [B x Ns x 3] 80 | hitface = smpl_F[fid] # [B, Ns, 3] 81 | inputs_feat = self._pca_sample(low_rank=low_rank, batch_size=b_size).unsqueeze(2).expand(-1, -1, hitface.shape[-1], -1) 82 | indices = hitface.unsqueeze(-1).expand(-1, -1, -1, inputs_feat.shape[-1]) 83 | nearest_feats = torch.gather(input=inputs_feat, index=indices, dim=1) # [B, Ns, 3, D] 84 | 85 | weighted_feats = torch.sum(nearest_feats * weights[...,None], dim=2) # K-weighted sum by: [B x Ns x 32] 86 | 87 | coords_feats = torch.cat([weights[...,1:], sdf], dim=-1) # [B, Ns, 3] 88 | return weighted_feats, coords_feats, normal 89 | 90 | 91 | def _pca_sample(self, low_rank=32, batch_size=1): 92 | 93 | A = self.feature_codebooks.clone() 94 | num_subjects, num_vertices, dim = A.shape 95 | 96 | A = A.view(num_subjects, -1) 97 | 98 | (U, S, V) = torch.pca_lowrank(A, q=low_rank, center=True, niter=1) 99 | 100 | params = torch.matmul(A, V) # (N, 128) 101 | mean = params.mean(dim=0) 102 | cov = torch.cov(params.T) 103 | 104 | m = torch.distributions.multivariate_normal.MultivariateNormal(mean, cov) 105 | random_codes = m.sample((batch_size,)).to(self.feature_codebooks.device) 106 | 107 | return torch.matmul(random_codes.detach(), V.t()).view(-1, num_vertices, dim) 108 | 109 | -------------------------------------------------------------------------------- /lib/models/losses.py: -------------------------------------------------------------------------------- 1 | """ The code is based on https://github.com/apple/ml-gsn/ with adaption. 
""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import autograd 6 | import logging as log 7 | import torch.nn.functional as F 8 | 9 | from .networks.discriminator import StyleDiscriminator 10 | 11 | def hinge_loss(fake_pred, real_pred, mode): 12 | if mode == 'd': 13 | # Discriminator update 14 | d_loss_fake = F.relu(1.0 + fake_pred).mean() 15 | d_loss_real = F.relu(1.0 - real_pred).mean() 16 | d_loss = d_loss_fake + d_loss_real 17 | elif mode == 'g': 18 | # Generator update 19 | d_loss = -torch.mean(fake_pred) 20 | return d_loss 21 | 22 | def logistic_loss(fake_pred, real_pred, mode): 23 | if mode == 'd': 24 | # Discriminator update 25 | d_loss_fake = F.softplus(fake_pred).mean() 26 | d_loss_real = F.softplus(-real_pred).mean() 27 | d_loss = d_loss_fake + d_loss_real 28 | elif mode == 'g': 29 | # Generator update 30 | d_loss = F.softplus(-fake_pred).mean() 31 | return d_loss 32 | 33 | 34 | def r1_loss(real_pred, real_img): 35 | (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True) 36 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 37 | return grad_penalty 38 | 39 | 40 | class GANLoss(nn.Module): 41 | def __init__( 42 | self, 43 | cfg, 44 | disc_loss='logistic', 45 | auxillary=False 46 | ): 47 | super().__init__() 48 | 49 | 50 | self.cfg = cfg 51 | self.discriminator = StyleDiscriminator(3, 128, auxilary=auxillary) 52 | log.info("Total number of parameters {}".format( 53 | sum(p.numel() for p in self.discriminator.parameters()))\ 54 | ) 55 | 56 | if disc_loss == 'hinge': 57 | self.disc_loss = hinge_loss 58 | elif disc_loss == 'logistic': 59 | self.disc_loss = logistic_loss 60 | 61 | self.auxillary = auxillary 62 | 63 | def forward(self, disc_in_real, disc_in_fake, mode='g', gt_label=None): 64 | 65 | if mode == 'g': # optimize generator 66 | loss = 0 67 | log = {} 68 | if self.auxillary: 69 | logits_fake, _ = self.discriminator(disc_in_fake) 70 | else: 71 | logits_fake = self.discriminator(disc_in_fake) 72 | 73 | g_loss = self.disc_loss(logits_fake, None, mode='g') 74 | log["loss_train/g_loss"] = g_loss.item() 75 | loss += g_loss * self.cfg.lambda_gan 76 | 77 | return loss, log 78 | 79 | if mode == 'd' : # optimize discriminator 80 | if self.auxillary: 81 | logits_real, aux_real = self.discriminator(disc_in_real) 82 | logits_fake, aux_fake = self.discriminator(disc_in_fake.detach().clone()) 83 | else: 84 | logits_real = self.discriminator(disc_in_real) 85 | logits_fake = self.discriminator(disc_in_fake.detach().clone()) 86 | 87 | disc_loss = self.disc_loss(fake_pred=logits_fake, real_pred=logits_real, mode='d') 88 | 89 | # lazy regularization so we don't need to compute grad penalty every iteration 90 | if self.cfg.lambda_grad > 0: 91 | grad_penalty = r1_loss(logits_real, disc_in_real) 92 | 93 | # the 0 * logits_real is to trigger DDP allgather 94 | # https://github.com/rosinality/stylegan2-pytorch/issues/76 95 | grad_penalty = grad_penalty + (0 * logits_real.sum()) 96 | else: 97 | grad_penalty = torch.tensor(0.0).type_as(disc_loss) 98 | 99 | d_loss = disc_loss * self.cfg.lambda_gan + grad_penalty * self.cfg.lambda_grad / 2 100 | if self.auxillary: 101 | d_loss += F.cross_entropy(aux_real, gt_label) 102 | d_loss += F.cross_entropy(aux_fake, gt_label) 103 | 104 | log = { 105 | "loss_train/disc_loss": disc_loss.item(), 106 | "loss_train/r1_loss": grad_penalty.item(), 107 | "loss_train/logits_real": logits_real.mean().item(), 108 | "loss_train/logits_fake": logits_fake.mean().item(), 109 | } 110 | 111 | return 
d_loss, log 112 | -------------------------------------------------------------------------------- /lib/models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """ 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class StyleDiscriminator(nn.Module): 9 | def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, auxilary=False, **kwargs): 10 | super().__init__() 11 | 12 | log_size_in = int(math.log(in_res, 2)) 13 | log_size_out = int(math.log(4, 2)) 14 | self.auxilary = auxilary 15 | 16 | self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3) 17 | 18 | # each resblock will half the resolution and double the number of features (until a maximum of ch_max) 19 | self.layers = [] 20 | in_channels = ch_mul 21 | for i in range(log_size_in, log_size_out, -1): 22 | out_channels = int(min(in_channels * 2, ch_max)) 23 | self.layers.append(ConvResBlock2d(in_channel=in_channels, out_channel=out_channels, downsample=True)) 24 | in_channels = out_channels 25 | self.layers = nn.Sequential(*self.layers) 26 | 27 | self.disc_out = DiscriminatorHead(in_channel=in_channels, disc_stddev=True, auxilary=auxilary) 28 | 29 | def forward(self, x): 30 | x = self.conv_in(x) 31 | x = self.layers(x) 32 | if self.auxilary: 33 | out, aux = self.disc_out(x) 34 | return out, aux 35 | else: 36 | out = self.disc_out(x) 37 | return out 38 | 39 | class DiscriminatorHead(nn.Module): 40 | def __init__(self, in_channel, disc_stddev=False, auxilary=False): 41 | super().__init__() 42 | 43 | self.disc_stddev = disc_stddev 44 | self.auxilary = auxilary 45 | stddev_dim = 1 if disc_stddev else 0 46 | 47 | self.conv_stddev = ConvLayer2d( 48 | in_channel=in_channel + stddev_dim, out_channel=in_channel, kernel_size=3, activate=True 49 | ) 50 | 51 | self.final_linear = nn.Sequential( 52 | nn.Flatten(), 53 | EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True), 54 | EqualLinear(in_channel=in_channel, out_channel=1), 55 | ) 56 | if self.auxilary: 57 | self.aux_layer = nn.Sequential( 58 | nn.Flatten(), 59 | EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True), 60 | EqualLinear(in_channel=in_channel, out_channel=3), 61 | ) 62 | 63 | def cat_stddev(self, x, stddev_group=4, stddev_feat=1): 64 | perm = torch.randperm(len(x)) 65 | inv_perm = torch.argsort(perm) 66 | 67 | batch, channel, height, width = x.shape 68 | x = x[perm] # shuffle inputs so that all views in a single trajectory don't get put together 69 | 70 | group = min(batch, stddev_group) 71 | stddev = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width) 72 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 73 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 74 | stddev = stddev.repeat(group, 1, height, width) 75 | 76 | stddev = stddev[inv_perm] # reorder inputs 77 | x = x[inv_perm] 78 | 79 | out = torch.cat([x, stddev], 1) 80 | return out 81 | 82 | def forward(self, x): 83 | if self.disc_stddev: 84 | x = self.cat_stddev(x) 85 | x = self.conv_stddev(x) 86 | out = self.final_linear(x) 87 | if self.auxilary: 88 | aux = self.aux_layer(x) 89 | return out, aux 90 | else: 91 | return out 92 | 93 | 94 | class ConvDecoder(nn.Module): 95 | def __init__(self, in_channel, out_channel, in_res, out_res): 96 | super().__init__() 97 | 98 | log_size_in = int(math.log(in_res, 2)) 99 | 
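# Editor's illustration (not part of the repository): a minimal sketch of how the
# StyleDiscriminator defined above is typically driven, assuming 128x128 RGB crops as
# passed in by GANLoss in losses.py. Shapes are for orientation only.
import torch
disc = StyleDiscriminator(in_channel=3, in_res=128)   # resblocks halve 128 -> 64 -> ... -> 4
fake = torch.randn(4, 3, 128, 128)                    # a batch of generated image crops
logits = disc(fake)                                   # shape [4, 1]: per-image real/fake scores
# with auxilary=True, forward additionally returns 3-way class logits: logits, aux = disc(fake)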
log_size_out = int(math.log(out_res, 2)) 100 | 101 | self.layers = [] 102 | in_ch = in_channel 103 | for i in range(log_size_in, log_size_out): 104 | out_ch = in_ch // 2 105 | self.layers.append( 106 | ConvLayer2d( 107 | in_channel=in_ch, out_channel=out_ch, kernel_size=3, upsample=True, bias=True, activate=True 108 | ) 109 | ) 110 | in_ch = out_ch 111 | 112 | self.layers.append( 113 | ConvLayer2d(in_channel=in_ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False) 114 | ) 115 | self.layers = nn.Sequential(*self.layers) 116 | 117 | def forward(self, x): 118 | return self.layers(x) 119 | 120 | class FusedLeakyReLU(nn.Module): 121 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 122 | super().__init__() 123 | 124 | if bias: 125 | self.bias = nn.Parameter(torch.zeros(channel)) 126 | 127 | else: 128 | self.bias = None 129 | 130 | self.negative_slope = negative_slope 131 | self.scale = scale 132 | 133 | def forward(self, input): 134 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 135 | 136 | 137 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 138 | if input.dtype == torch.float16: 139 | bias = bias.half() 140 | 141 | if bias is not None: 142 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 143 | return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale 144 | 145 | else: 146 | return F.leaky_relu(input, negative_slope=0.2) * scale 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | up_x, up_y = up, up 151 | down_x, down_y = down, down 152 | pad_x0, pad_x1, pad_y0, pad_y1 = pad[0], pad[1], pad[0], pad[1] 153 | 154 | _, channel, in_h, in_w = input.shape 155 | input = input.reshape(-1, in_h, in_w, 1) 156 | 157 | _, in_h, in_w, minor = input.shape 158 | kernel_h, kernel_w = kernel.shape 159 | 160 | out = input.view(-1, in_h, 1, in_w, 1, minor) 161 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 162 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 163 | 164 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 165 | out = out[ 166 | :, 167 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 169 | :, 170 | ] 171 | 172 | out = out.permute(0, 3, 1, 2) 173 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | out = out[:, ::down_y, ::down_x, :] 184 | 185 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 186 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 187 | 188 | return out.view(-1, channel, out_h, out_w) 189 | 190 | 191 | 192 | def make_kernel(k): 193 | k = torch.tensor(k, dtype=torch.float32) 194 | 195 | if k.ndim == 1: 196 | k = k[None, :] * k[:, None] 197 | 198 | k /= k.sum() 199 | 200 | return k 201 | 202 | 203 | class Blur(nn.Module): 204 | """Blur layer. 205 | Applies a blur kernel to input image using finite impulse response filter. Blurring feature maps after 206 | convolutional upsampling or before convolutional downsampling helps produces models that are more robust to 207 | shifting inputs (https://richzhang.github.io/antialiased-cnns/). 
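# Editor's illustration (not part of the repository): what make_kernel above produces from the
# default [1, 3, 3, 1] binomial taps: a separable 2D kernel normalized to sum to one.
import torch
k = torch.tensor([1., 3., 3., 1.])
k2d = k[None, :] * k[:, None]          # outer product: 4x4 blur kernel
k2d = k2d / k2d.sum()                  # normalized so blurring preserves overall intensity
print(k2d.sum())                       # tensor(1.)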
In the context of GANs, this can provide 208 | cleaner gradients, and therefore more stable training. 209 | Args: 210 | ---- 211 | kernel: list, int 212 | A list of integers representing a blur kernel. For exmaple: [1, 3, 3, 1]. 213 | pad: tuple, int 214 | A tuple of integers representing the number of rows/columns of padding to be added to the top/left and 215 | the bottom/right respectively. 216 | upsample_factor: int 217 | Upsample factor. 218 | """ 219 | 220 | def __init__(self, kernel, pad, upsample_factor=1): 221 | super().__init__() 222 | 223 | kernel = make_kernel(kernel) 224 | 225 | if upsample_factor > 1: 226 | kernel = kernel * (upsample_factor ** 2) 227 | 228 | self.register_buffer("kernel", kernel) 229 | self.pad = pad 230 | 231 | def forward(self, input): 232 | out = upfirdn2d(input, self.kernel, pad=self.pad) 233 | return out 234 | 235 | 236 | class Upsample(nn.Module): 237 | """Upsampling layer. 238 | Perform upsampling using a blur kernel. 239 | Args: 240 | ---- 241 | kernel: list, int 242 | A list of integers representing a blur kernel. For exmaple: [1, 3, 3, 1]. 243 | factor: int 244 | Upsampling factor. 245 | """ 246 | 247 | def __init__(self, kernel=[1, 3, 3, 1], factor=2): 248 | super().__init__() 249 | 250 | self.factor = factor 251 | kernel = make_kernel(kernel) * (factor ** 2) 252 | self.register_buffer("kernel", kernel) 253 | 254 | p = kernel.shape[0] - factor 255 | pad0 = (p + 1) // 2 + factor - 1 256 | pad1 = p // 2 257 | self.pad = (pad0, pad1) 258 | 259 | def forward(self, input): 260 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 261 | return out 262 | 263 | 264 | class Downsample(nn.Module): 265 | """Downsampling layer. 266 | Perform downsampling using a blur kernel. 267 | Args: 268 | ---- 269 | kernel: list, int 270 | A list of integers representing a blur kernel. For exmaple: [1, 3, 3, 1]. 271 | factor: int 272 | Downsampling factor. 273 | """ 274 | 275 | def __init__(self, kernel=[1, 3, 3, 1], factor=2): 276 | super().__init__() 277 | 278 | self.factor = factor 279 | kernel = make_kernel(kernel) 280 | self.register_buffer("kernel", kernel) 281 | 282 | p = kernel.shape[0] - factor 283 | pad0 = (p + 1) // 2 284 | pad1 = p // 2 285 | self.pad = (pad0, pad1) 286 | 287 | def forward(self, input): 288 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 289 | return out 290 | 291 | 292 | class EqualLinear(nn.Module): 293 | """Linear layer with equalized learning rate. 294 | During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to 295 | prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU 296 | activation functions. 297 | Args: 298 | ---- 299 | in_channel: int 300 | Input channels. 301 | out_channel: int 302 | Output channels. 303 | bias: bool 304 | Use bias term. 305 | bias_init: float 306 | Initial value for the bias. 307 | lr_mul: float 308 | Learning rate multiplier. By scaling weights and the bias we can proportionally scale the magnitude of 309 | the gradients, effectively increasing/decreasing the learning rate for this layer. 310 | activate: bool 311 | Apply leakyReLU activation. 
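# Editor's illustration (not part of the repository): the equalized learning-rate trick described
# above, in isolation. Weights are stored with unit variance and rescaled by the He constant
# (1/sqrt(fan_in)) on every forward pass, so activations stay well-scaled for any layer width.
import math
import torch
fan_in = 512
w = torch.randn(256, fan_in)              # stored weights ~ N(0, 1)
scale = 1.0 / math.sqrt(fan_in)           # He constant
x = torch.randn(8, fan_in)
y = x @ (w * scale).t()                   # equivalent to F.linear(x, w * scale)
print(y.std())                            # ~1.0, independent of fan_in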
312 | """ 313 | 314 | def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False): 315 | super().__init__() 316 | 317 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul)) 318 | 319 | if bias: 320 | self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init)) 321 | else: 322 | self.bias = None 323 | 324 | self.activate = activate 325 | self.scale = (1 / math.sqrt(in_channel)) * lr_mul 326 | self.lr_mul = lr_mul 327 | 328 | def forward(self, input): 329 | if self.activate: 330 | out = F.linear(input, self.weight * self.scale) 331 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 332 | else: 333 | out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) 334 | return out 335 | 336 | def __repr__(self): 337 | return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 338 | 339 | 340 | class EqualConv2d(nn.Module): 341 | """2D convolution layer with equalized learning rate. 342 | During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to 343 | prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU 344 | activation functions. 345 | Args: 346 | ---- 347 | in_channel: int 348 | Input channels. 349 | out_channel: int 350 | Output channels. 351 | kernel_size: int 352 | Kernel size. 353 | stride: int 354 | Stride of convolutional kernel across the input. 355 | padding: int 356 | Amount of zero padding applied to both sides of the input. 357 | bias: bool 358 | Use bias term. 359 | """ 360 | 361 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): 362 | super().__init__() 363 | 364 | self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) 365 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 366 | 367 | self.stride = stride 368 | self.padding = padding 369 | 370 | if bias: 371 | self.bias = nn.Parameter(torch.zeros(out_channel)) 372 | else: 373 | self.bias = None 374 | 375 | def forward(self, input): 376 | out = F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) 377 | return out 378 | 379 | def __repr__(self): 380 | return ( 381 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 382 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 383 | ) 384 | 385 | 386 | class EqualConvTranspose2d(nn.Module): 387 | """2D transpose convolution layer with equalized learning rate. 388 | During the forward pass the weights are scaled by the inverse of the He constant (i.e. sqrt(in_dim)) to 389 | prevent vanishing gradients and accelerate training. This constant only works for ReLU or LeakyReLU 390 | activation functions. 391 | Args: 392 | ---- 393 | in_channel: int 394 | Input channels. 395 | out_channel: int 396 | Output channels. 397 | kernel_size: int 398 | Kernel size. 399 | stride: int 400 | Stride of convolutional kernel across the input. 401 | padding: int 402 | Amount of zero padding applied to both sides of the input. 403 | output_padding: int 404 | Extra padding added to input to achieve the desired output size. 405 | bias: bool 406 | Use bias term. 
407 | """ 408 | 409 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, output_padding=0, bias=True): 410 | super().__init__() 411 | 412 | self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size)) 413 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 414 | 415 | self.stride = stride 416 | self.padding = padding 417 | self.output_padding = output_padding 418 | 419 | if bias: 420 | self.bias = nn.Parameter(torch.zeros(out_channel)) 421 | else: 422 | self.bias = None 423 | 424 | def forward(self, input): 425 | out = F.conv_transpose2d( 426 | input, 427 | self.weight * self.scale, 428 | bias=self.bias, 429 | stride=self.stride, 430 | padding=self.padding, 431 | output_padding=self.output_padding, 432 | ) 433 | return out 434 | 435 | def __repr__(self): 436 | return ( 437 | f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},' 438 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 439 | ) 440 | 441 | 442 | class ConvLayer2d(nn.Sequential): 443 | def __init__( 444 | self, 445 | in_channel, 446 | out_channel, 447 | kernel_size=3, 448 | upsample=False, 449 | downsample=False, 450 | blur_kernel=[1, 3, 3, 1], 451 | bias=True, 452 | activate=True, 453 | ): 454 | assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously' 455 | layers = [] 456 | 457 | if upsample: 458 | factor = 2 459 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 460 | pad0 = (p + 1) // 2 + factor - 1 461 | pad1 = p // 2 + 1 462 | 463 | layers.append( 464 | EqualConvTranspose2d( 465 | in_channel, out_channel, kernel_size, padding=0, stride=2, bias=bias and not activate 466 | ) 467 | ) 468 | layers.append(Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)) 469 | 470 | if downsample: 471 | factor = 2 472 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 473 | pad0 = (p + 1) // 2 474 | pad1 = p // 2 475 | 476 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 477 | layers.append( 478 | EqualConv2d(in_channel, out_channel, kernel_size, padding=0, stride=2, bias=bias and not activate) 479 | ) 480 | 481 | if (not downsample) and (not upsample): 482 | padding = kernel_size // 2 483 | 484 | layers.append( 485 | EqualConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=1, bias=bias and not activate) 486 | ) 487 | 488 | if activate: 489 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 490 | 491 | super().__init__(*layers) 492 | 493 | 494 | class ConvResBlock2d(nn.Module): 495 | """2D convolutional residual block with equalized learning rate. 496 | Residual block composed of 3x3 convolutions and leaky ReLUs. 497 | Args: 498 | ---- 499 | in_channel: int 500 | Input channels. 501 | out_channel: int 502 | Output channels. 503 | upsample: bool 504 | Apply upsampling via strided convolution in the first conv. 505 | downsample: bool 506 | Apply downsampling via strided convolution in the second conv. 
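# Editor's illustration (not part of the repository): a quick shape check of ConvLayer2d above
# with the default [1, 3, 3, 1] blur kernel (values follow the padding arithmetic in this file).
import torch
x = torch.randn(2, 64, 32, 32)
down = ConvLayer2d(64, 128, kernel_size=3, downsample=True)
up = ConvLayer2d(64, 128, kernel_size=3, upsample=True)
print(down(x).shape)   # torch.Size([2, 128, 16, 16]): blur, then stride-2 EqualConv2d
print(up(x).shape)     # torch.Size([2, 128, 64, 64]): stride-2 EqualConvTranspose2d, then blur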
507 | """ 508 | 509 | def __init__(self, in_channel, out_channel, upsample=False, downsample=False): 510 | super().__init__() 511 | 512 | assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously' 513 | mid_ch = in_channel if downsample else out_channel 514 | 515 | self.conv1 = ConvLayer2d(in_channel, mid_ch, upsample=upsample, kernel_size=3) 516 | self.conv2 = ConvLayer2d(mid_ch, out_channel, downsample=downsample, kernel_size=3) 517 | 518 | if (in_channel != out_channel) or upsample or downsample: 519 | self.skip = ConvLayer2d( 520 | in_channel, 521 | out_channel, 522 | upsample=upsample, 523 | downsample=downsample, 524 | kernel_size=1, 525 | activate=False, 526 | bias=False, 527 | ) 528 | 529 | def forward(self, input): 530 | out = self.conv1(input) 531 | out = self.conv2(out) 532 | 533 | if hasattr(self, 'skip'): 534 | skip = self.skip(input) 535 | out = (out + skip) / math.sqrt(2) 536 | else: 537 | out = (out + input) / math.sqrt(2) 538 | return out 539 | -------------------------------------------------------------------------------- /lib/models/networks/layers.py: -------------------------------------------------------------------------------- 1 | # The code is adapted from https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/models/layers.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def normalize_frobenius(x): 8 | """Normalizes the matrix according to the Frobenius norm. 9 | 10 | Args: 11 | x (torch.FloatTensor): A matrix. 12 | 13 | Returns: 14 | (torch.FloatTensor): A normalized matrix. 15 | """ 16 | norm = torch.sqrt((torch.abs(x)**2).sum()) 17 | return x / norm 18 | 19 | def normalize_L_1(x): 20 | """Normalizes the matrix according to the L1 norm. 21 | 22 | Args: 23 | x (torch.FloatTensor): A matrix. 24 | 25 | Returns: 26 | (torch.FloatTensor): A normalized matrix. 27 | """ 28 | abscolsum = torch.sum(torch.abs(x), dim=0) 29 | abscolsum = torch.min(torch.stack([1.0/abscolsum, torch.ones_like(abscolsum)], dim=0), dim=0)[0] 30 | return x * abscolsum[None,:] 31 | 32 | def normalize_L_inf(x): 33 | """Normalizes the matrix according to the Linf norm. 34 | 35 | Args: 36 | x (torch.FloatTensor): A matrix. 37 | 38 | Returns: 39 | (torch.FloatTensor): A normalized matrix. 40 | """ 41 | absrowsum = torch.sum(torch.abs(x), axis=1) 42 | absrowsum = torch.min(torch.stack([1.0/absrowsum, torch.ones_like(absrowsum)], dim=0), dim=0)[0] 43 | return x * absrowsum[:,None] 44 | 45 | class FrobeniusLinear(nn.Module): 46 | """A standard Linear layer which applies a Frobenius normalization in the forward pass. 47 | """ 48 | def __init__(self, *args, **kwargs): 49 | super().__init__() 50 | self.linear = nn.Linear(*args, **kwargs) 51 | 52 | def forward(self, x): 53 | weight = normalize_frobenius(self.linear.weight) 54 | return F.linear(x, weight, self.linear.bias) 55 | 56 | class L_1_Linear(nn.Module): 57 | """A standard Linear layer which applies a L1 normalization in the forward pass. 58 | """ 59 | def __init__(self, *args, **kwargs): 60 | super().__init__() 61 | self.linear = nn.Linear(*args, **kwargs) 62 | 63 | def forward(self, x): 64 | weight = normalize_L_1(self.linear.weight) 65 | return F.linear(x, weight, self.linear.bias) 66 | 67 | class L_inf_Linear(nn.Module): 68 | """A standard Linear layer which applies a Linf normalization in the forward pass. 
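# Editor's illustration (not part of the repository): the row-sum cap applied by normalize_L_inf
# above, which bounds the layer's L-infinity operator norm (useful for Lipschitz-style decoders).
import torch
W = torch.tensor([[0.5, 0.5], [2.0, 2.0]])
absrowsum = W.abs().sum(dim=1)                                   # tensor([1., 4.])
scale = torch.minimum(1.0 / absrowsum, torch.ones_like(absrowsum))
print((W * scale[:, None]).abs().sum(dim=1))                     # tensor([1., 1.]): rows capped at 1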
69 | """ 70 | def __init__(self, *args, **kwargs): 71 | super().__init__() 72 | self.linear = nn.Linear(*args, **kwargs) 73 | 74 | def forward(self, x): 75 | weight = normalize_L_inf(self.linear.weight) 76 | return F.linear(x, weight, self.linear.bias) 77 | 78 | def spectral_norm_(*args, **kwargs): 79 | """Initializes a spectral norm layer. 80 | """ 81 | return nn.utils.spectral_norm(nn.Linear(*args, **kwargs)) 82 | 83 | def get_layer_class(layer_type): 84 | """Convenience function to return the layer class name from text. 85 | 86 | Args: 87 | layer_type (str): Text name for the layer. 88 | 89 | Retunrs: 90 | (nn.Module): The layer to be used for the decoder. 91 | """ 92 | if layer_type == 'none': 93 | return nn.Linear 94 | elif layer_type == 'spectral_norm': 95 | return spectral_norm_ 96 | elif layer_type == 'frobenius_norm': 97 | return FrobeniusLinear 98 | elif layer_type == "l_1_norm": 99 | return L_1_Linear 100 | elif layer_type == "l_inf_norm": 101 | return L_inf_Linear 102 | else: 103 | assert(False and "layer type does not exist") 104 | -------------------------------------------------------------------------------- /lib/models/networks/mlps.py: -------------------------------------------------------------------------------- 1 | # The code is adapted from https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/models/decoders/basic_decoders.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MLP(nn.Module): 8 | """Super basic but super useful MLP class. 9 | """ 10 | def __init__(self, 11 | input_dim, 12 | output_dim, 13 | activation = torch.relu, 14 | bias = True, 15 | layer = nn.Linear, 16 | num_layers = 4, 17 | hidden_dim = 128, 18 | skip = [2] 19 | ): 20 | """Initialize the MLP. 21 | 22 | Args: 23 | input_dim (int): Input dimension of the MLP. 24 | output_dim (int): Output dimension of the MLP. 25 | activation (function): The activation function to use. 26 | bias (bool): If True, use bias. 27 | layer (nn.Module): The MLP layer module to use. 28 | num_layers (int): The number of hidden layers in the MLP. 29 | hidden_dim (int): The hidden dimension of the MLP. 30 | skip (List[int]): List of layer indices where the input dimension is concatenated. 31 | 32 | Returns: 33 | (void): Initializes the class. 34 | """ 35 | super().__init__() 36 | 37 | self.input_dim = input_dim 38 | self.output_dim = output_dim 39 | self.activation = activation 40 | self.bias = bias 41 | self.layer = layer 42 | self.num_layers = num_layers 43 | self.hidden_dim = hidden_dim 44 | self.skip = skip 45 | if self.skip is None: 46 | self.skip = [] 47 | 48 | self.make() 49 | 50 | def make(self): 51 | """Builds the actual MLP. 52 | """ 53 | layers = [] 54 | for i in range(self.num_layers): 55 | if i == 0: 56 | layers.append(self.layer(self.input_dim, self.hidden_dim, bias=self.bias)) 57 | elif i in self.skip: 58 | layers.append(self.layer(self.hidden_dim+self.input_dim, self.hidden_dim, bias=self.bias)) 59 | else: 60 | layers.append(self.layer(self.hidden_dim, self.hidden_dim, bias=self.bias)) 61 | self.layers = nn.ModuleList(layers) 62 | self.lout = self.layer(self.hidden_dim, self.output_dim, bias=self.bias) 63 | 64 | def forward(self, x, return_h=False): 65 | """Run the MLP! 66 | 67 | Args: 68 | x (torch.FloatTensor): Some tensor of shape [batch, ..., input_dim] 69 | return_h (bool): If True, also returns the last hidden layer. 
70 | 71 | Returns: 72 | (torch.FloatTensor, (optional) torch.FloatTensor): 73 | - The output tensor of shape [batch, ..., output_dim] 74 | - The last hidden layer of shape [batch, ..., hidden_dim] 75 | """ 76 | N = x.shape[0] 77 | 78 | for i, l in enumerate(self.layers): 79 | if i == 0: 80 | h = self.activation(l(x)) 81 | elif i in self.skip: 82 | h = torch.cat([x, h], dim=-1) 83 | h = self.activation(l(h)) 84 | else: 85 | h = self.activation(l(h)) 86 | 87 | out = self.lout(h) 88 | 89 | if return_h: 90 | return out, h 91 | else: 92 | return out 93 | 94 | 95 | 96 | class Conditional_MLP(nn.Module): 97 | """Super basic but super useful MLP class. 98 | """ 99 | def __init__(self, 100 | input_dim, 101 | cond_dim, 102 | output_dim, 103 | activation = torch.relu, 104 | bias = True, 105 | layer = nn.Linear, 106 | num_layers = 4, 107 | hidden_dim = 128, 108 | skip = [2] 109 | ): 110 | """Initialize the MLP. 111 | 112 | Args: 113 | input_dim (int): Input dimension of the MLP. 114 | output_dim (int): Output dimension of the MLP. 115 | activation (function): The activation function to use. 116 | bias (bool): If True, use bias. 117 | layer (nn.Module): The MLP layer module to use. 118 | num_layers (int): The number of hidden layers in the MLP. 119 | hidden_dim (int): The hidden dimension of the MLP. 120 | skip (List[int]): List of layer indices where the input dimension is concatenated. 121 | 122 | Returns: 123 | (void): Initializes the class. 124 | """ 125 | super().__init__() 126 | 127 | self.input_dim = input_dim 128 | self.cond_dim = cond_dim 129 | self.output_dim = output_dim 130 | self.activation = activation 131 | self.bias = bias 132 | self.layer = layer 133 | self.num_layers = num_layers 134 | self.hidden_dim = hidden_dim 135 | self.skip = skip 136 | if self.skip is None: 137 | self.skip = [] 138 | 139 | self.make() 140 | 141 | def make(self): 142 | """Builds the actual MLP. 143 | """ 144 | layers = [] 145 | for i in range(self.num_layers): 146 | if i == 0: 147 | layers.append(self.layer(self.input_dim, self.hidden_dim, bias=self.bias)) 148 | elif i in self.skip: 149 | layers.append(self.layer(self.hidden_dim+self.cond_dim, self.hidden_dim, bias=self.bias)) 150 | else: 151 | layers.append(self.layer(self.hidden_dim, self.hidden_dim, bias=self.bias)) 152 | self.layers = nn.ModuleList(layers) 153 | self.lout = self.layer(self.hidden_dim, self.output_dim, bias=self.bias) 154 | 155 | def forward(self, x, c, return_h=False, sigmoid=False): 156 | """Run the MLP! 157 | 158 | Args: 159 | x (torch.FloatTensor): Some tensor of shape [batch, ..., input_dim] 160 | return_h (bool): If True, also returns the last hidden layer. 161 | 162 | Returns: 163 | (torch.FloatTensor, (optional) torch.FloatTensor): 164 | - The output tensor of shape [batch, ..., output_dim] 165 | - The last hidden layer of shape [batch, ..., hidden_dim] 166 | """ 167 | N = x.shape[0] 168 | 169 | for i, l in enumerate(self.layers): 170 | if i == 0: 171 | h = self.activation(l(x)) 172 | elif i in self.skip: 173 | h = torch.cat([h, c], dim=-1) 174 | h = self.activation(l(h)) 175 | else: 176 | h = self.activation(l(h)) 177 | 178 | out = self.lout(h) 179 | if sigmoid: 180 | out = torch.sigmoid(out) 181 | 182 | if return_h: 183 | return out, h 184 | else: 185 | return out 186 | 187 | 188 | def get_activation_class(activation_type): 189 | """Utility function to return an activation function class based on the string description. 190 | 191 | Args: 192 | activation_type (str): The name for the activation function. 
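# Editor's illustration (not part of the repository): a shape sanity check for the MLP above,
# with the default skip=[2] re-injecting the raw input at the third hidden layer. The numbers
# (39-dim input, 128 hidden) are assumptions chosen only to resemble a positional-encoded input.
import torch
mlp = MLP(input_dim=39, output_dim=1, num_layers=4, hidden_dim=128, skip=[2])
x = torch.randn(2, 1024, 39)
out, h = mlp(x, return_h=True)
print(out.shape, h.shape)   # torch.Size([2, 1024, 1]) torch.Size([2, 1024, 128])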
193 | 194 | Returns: 195 | (Function): The activation function to be used. 196 | """ 197 | if activation_type == 'relu': 198 | return torch.relu 199 | elif activation_type == 'sin': 200 | return torch.sin 201 | elif activation_type == 'softplus': 202 | return torch.nn.functional.softplus 203 | elif activation_type == 'lrelu': 204 | return torch.nn.functional.leaky_relu 205 | else: 206 | assert False and "activation type does not exist" -------------------------------------------------------------------------------- /lib/models/networks/positional_encoding.py: -------------------------------------------------------------------------------- 1 | # The code is adapted from https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/models/embedders/positional_embedder.py 2 | import torch 3 | import torch.nn as nn 4 | 5 | class PositionalEncoding(nn.Module): 6 | """PyTorch implementation of positional embedding. 7 | """ 8 | def __init__(self, num_freq, max_freq_log2, log_sampling=True, include_input=True, input_dim=3): 9 | """Initialize the module. 10 | 11 | Args: 12 | num_freq (int): The number of frequency bands to sample. 13 | max_freq_log2 (int): The maximum frequency. The bands will be sampled between [0, 2^max_freq_log2]. 14 | log_sampling (bool): If true, will sample frequency bands in log space. 15 | include_input (bool): If true, will concatenate the input. 16 | input_dim (int): The dimension of the input coordinate space. 17 | 18 | Returns: 19 | (void): Initializes the encoding. 20 | """ 21 | super().__init__() 22 | 23 | self.num_freq = num_freq 24 | self.max_freq_log2 = max_freq_log2 25 | self.log_sampling = log_sampling 26 | self.include_input = include_input 27 | self.out_dim = 0 28 | if include_input: 29 | self.out_dim += input_dim 30 | 31 | if self.log_sampling: 32 | self.bands = 2.0**torch.linspace(0.0, max_freq_log2, steps=num_freq) 33 | else: 34 | self.bands = torch.linspace(1, 2.0**max_freq_log2, steps=num_freq) 35 | 36 | # The out_dim is really just input_dim + num_freq * input_dim * 2 (for sin and cos) 37 | self.out_dim += self.bands.shape[0] * input_dim * 2 38 | self.bands = nn.Parameter(self.bands).requires_grad_(False) 39 | 40 | def forward(self, coords): 41 | """Embded the coordinates. 42 | 43 | Args: 44 | coords (torch.FloatTensor): Coordinates of shape [..., input_dim] 45 | 46 | Returns: 47 | (torch.FloatTensor): Embeddings of shape [..., input_dim + out_dim] or [..., out_dim]. 
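# Editor's illustration (not part of the repository): the embedding width follows
# out_dim = input_dim + 2 * num_freq * input_dim when include_input=True (num_freq=6 is an
# arbitrary example value, not a repository default).
import torch
enc = PositionalEncoding(num_freq=6, max_freq_log2=5, input_dim=3)
print(enc.out_dim)            # 3 + 2 * 6 * 3 = 39
x = torch.rand(2, 1024, 3)
print(enc(x).shape)           # torch.Size([2, 1024, 39])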
48 | """ 49 | shape = coords.shape 50 | # Flatten the coordinates 51 | assert len(shape) > 1 52 | if len(shape) > 2: 53 | coords = coords.reshape(-1, shape[-1]) 54 | N = coords.shape[0] 55 | winded = (coords[:,None] * self.bands[None,:,None]).reshape(N, -1) 56 | encoded = torch.cat([torch.sin(winded), torch.cos(winded)], dim=-1) 57 | if self.include_input: 58 | encoded = torch.cat([coords, encoded], dim=-1) 59 | # Reshape back to original 60 | if len(shape) > 2: 61 | encoded = encoded.reshape(*shape[:-1], -1) 62 | return encoded 63 | 64 | -------------------------------------------------------------------------------- /lib/models/neural_fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import logging as log 5 | 6 | from .feature_dictionary import FeatureDictionary 7 | from .networks.positional_encoding import PositionalEncoding 8 | from .networks.mlps import MLP, Conditional_MLP 9 | from .networks.layers import get_layer_class 10 | 11 | 12 | def get_activation_class(activation_type): 13 | """Utility function to return an activation function class based on the string description. 14 | 15 | Args: 16 | activation_type (str): The name for the activation function. 17 | 18 | Returns: 19 | (Function): The activation function to be used. 20 | """ 21 | if activation_type == 'relu': 22 | return torch.relu 23 | elif activation_type == 'sin': 24 | return torch.sin 25 | elif activation_type == 'softplus': 26 | return torch.nn.functional.softplus 27 | elif activation_type == 'lrelu': 28 | return torch.nn.functional.leaky_relu 29 | else: 30 | assert False and "activation type does not exist" 31 | 32 | 33 | #################################################### 34 | class NeuralField(nn.Module): 35 | 36 | def __init__(self, 37 | cfg :dict, 38 | smpl_V :torch.Tensor, 39 | smpl_F :torch.Tensor, 40 | feat_dim : int, 41 | out_dim : int, 42 | pos_freq : int, 43 | low_rank : int, 44 | sigmoid : bool = False, 45 | ): 46 | 47 | super().__init__() 48 | self.cfg = cfg 49 | self.smpl_V = smpl_V 50 | self.smpl_F = smpl_F 51 | self.feat_dim = feat_dim 52 | self.out_dim = out_dim 53 | self.pos_freq = pos_freq 54 | self.low_rank = low_rank 55 | self.sigmoid = sigmoid 56 | 57 | self.pos_dim = self.cfg.pos_dim 58 | self.c_dim = self.cfg.c_dim 59 | self.activation = self.cfg.activation 60 | self.layer_type = self.cfg.layer_type 61 | self.hidden_dim = self.cfg.hidden_dim 62 | self.num_layers = self.cfg.num_layers 63 | self.skip = self.cfg.skip 64 | self.feature_std = self.cfg.feature_std 65 | self.feature_bias = self.cfg.feature_bias 66 | 67 | 68 | self._init_dictionary() 69 | self._init_embedder() 70 | self._init_decoder() 71 | 72 | 73 | def _init_dictionary(self): 74 | """Initialize the feature dictionary object. 75 | """ 76 | 77 | self.dictionary = FeatureDictionary(self.feat_dim, self.feature_std, self.feature_bias) 78 | self.dictionary.init_from_smpl_vertices(self.smpl_V) 79 | 80 | def _init_embedder(self): 81 | """Initialize positional embedding objects. 82 | """ 83 | self.embedder = PositionalEncoding(self.pos_freq, self.pos_freq -1, input_dim=self.pos_dim) 84 | self.embed_dim = self.embedder.out_dim 85 | 86 | def _init_decoder(self): 87 | """Initialize the decoder object. 
88 | """ 89 | self.input_dim = self.embed_dim + self.feat_dim 90 | 91 | if self.c_dim <= 0: 92 | self.decoder = MLP(self.input_dim, self.out_dim, activation=get_activation_class(self.activation), 93 | bias=True, layer=get_layer_class(self.layer_type), num_layers=self.num_layers, 94 | hidden_dim=self.hidden_dim, skip=self.skip) 95 | else: 96 | self.decoder = Conditional_MLP(self.input_dim, self.c_dim, self.out_dim, activation=get_activation_class(self.activation), 97 | bias=True, layer=get_layer_class(self.layer_type), num_layers=self.num_layers, 98 | hidden_dim=self.hidden_dim, skip=self.skip) 99 | 100 | 101 | log.info("Total number of parameters {}".format( 102 | sum(p.numel() for p in self.decoder.parameters()))\ 103 | ) 104 | 105 | def forward_decoder(self, feats, local_coords, normal, return_h=False, f=None): 106 | """Forward pass through the MLP decoder. 107 | Args: 108 | feats (torch.FloatTensor): Feature tensor of shape [B, N, feat_dim] 109 | local_coords (torch.FloatTensor): Local coordinate tensor of shape [B, N, 3] 110 | normal (torch.FloatTensor): Normal tensor of shape [B, N, 3] 111 | return_h (bool): Whether to return the hidden states of the network. 112 | f (torch.FloatTensor): The conditional feature tensor of shape [B, c_dim] 113 | 114 | """ 115 | 116 | if self.c_dim <= 0: 117 | input = torch.cat([self.embedder(local_coords), feats], dim=-1) 118 | return self.decoder(input, return_h=return_h, sigmoid=self.sigmoid) 119 | else: 120 | input = torch.cat([self.embedder(local_coords), feats], dim=-1) 121 | if f is not None: 122 | c = torch.cat([f, normal], dim=-1) 123 | else: 124 | c = normal 125 | return self.decoder(input, c, return_h=return_h, sigmoid=self.sigmoid) 126 | 127 | 128 | def forward(self, x, code_idx, pose_idx=None, return_h=False, f=None): 129 | """Forward pass through the network. 130 | Args: 131 | x (torch.FloatTensor): Coordinate tensor of shape [B, N, 3] 132 | code_idx (torch.LongTensor): Code index tensor of shape [B, 1] 133 | pose_idx (torch.LongTensor): SMPL_V index tensor of shape [B, 1] 134 | return_h (bool): Whether to return the hidden states of the network. 135 | f (torch.FloatTensor): The conditional feature tensor of shape [B, c_dim] 136 | """ 137 | if pose_idx is None: 138 | pose_idx = code_idx 139 | feats, local_coords, normal = self.dictionary.interpolate(x, code_idx, self.smpl_V[pose_idx], self.smpl_F) 140 | 141 | return self.forward_decoder(feats, local_coords, normal, return_h=return_h, f=f) 142 | 143 | def sample(self, x, idx, return_h=False, f=None): 144 | """Sample from the network. 145 | """ 146 | feats, local_coords, normal = self.dictionary.interpolate_random(x, self.smpl_V[idx], self.smpl_F, self.low_rank) 147 | 148 | return self.forward_decoder(feats, local_coords, normal, return_h=return_h, f=f) 149 | 150 | 151 | def regularization_loss(self, idx=None): 152 | """Compute the L2 regularization loss. 153 | """ 154 | 155 | if idx is None: 156 | return (self.dictionary.feature_codebooks**2).mean() 157 | else: 158 | return (self.dictionary.feature_codebooks[idx]**2).mean() 159 | 160 | 161 | def finitediff_gradient(self, x, idx, eps=0.005, sample=False): 162 | """Compute 3D gradient using finite difference. 
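# Editor's illustration (not part of the repository): the central-difference stencil used by
# finitediff_gradient, shown on a toy sphere SDF for clarity.
import torch
def toy_sdf(p):
    return p.norm(dim=-1, keepdim=True) - 1.0        # signed distance to a unit sphere
p = torch.tensor([[0.0, 0.0, 2.0]])
eps = 0.005
grad = []
for axis in range(3):
    e = torch.zeros(3)
    e[axis] = eps
    grad.append((toy_sdf(p + e) - toy_sdf(p - e)) / (2 * eps))
print(torch.cat(grad, dim=-1))                        # ~tensor([[0., 0., 1.]]): the outward normal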
163 | 164 | Args: 165 | x (torch.FloatTensor): Coordinate tensor of shape [B, N, 3] 166 | """ 167 | shape = x.shape 168 | 169 | eps_x = torch.tensor([eps, 0.0, 0.0], device=x.device) 170 | eps_y = torch.tensor([0.0, eps, 0.0], device=x.device) 171 | eps_z = torch.tensor([0.0, 0.0, eps], device=x.device) 172 | 173 | # shape: [B, 6, N, 3] -> [B, 6*N, 3] 174 | x_new = torch.stack([x + eps_x, x + eps_y, x + eps_z, 175 | x - eps_x, x - eps_y, x - eps_z], dim=1).reshape(shape[0], -1, shape[-1]) 176 | 177 | # shape: [B, 6*N, 3] -> [B, 6, N, 3] 178 | if sample: 179 | pred = self.sample(x_new, idx).reshape(shape[0], 6, -1) 180 | else: 181 | pred = self.forward(x_new, idx).reshape(shape[0], 6, -1) 182 | grad_x = (pred[:, 0, ...] - pred[:, 3, ...]) / (eps * 2.0) 183 | grad_y = (pred[:, 1, ...] - pred[:, 4, ...]) / (eps * 2.0) 184 | grad_z = (pred[:, 2, ...] - pred[:, 5, ...]) / (eps * 2.0) 185 | 186 | return torch.stack([grad_x, grad_y, grad_z], dim=-1) 187 | 188 | 189 | def forward_fitting(self, x, code, smpl_V, return_h=False, f=None): 190 | """Forward pass through the network with a latent code input. 191 | Args: 192 | x (torch.FloatTensor): Coordinate tensor of shape [1, N, 3] 193 | code (torch.FloatTensor): Latent code tensor of shape [1, n_vertices, c_dim] 194 | smpl_V (torch.FloatTensor): SMPL_V tensor of shape [1, n_vertices, 3] 195 | """ 196 | 197 | feats, local_coords, normal = self.dictionary.interpolate(x, 0, smpl_V, self.smpl_F, input_code=code) 198 | 199 | return self.forward_decoder(feats, local_coords, normal, return_h=return_h, f=f) 200 | 201 | def normal_fitting(self, x, code, smpl_V, eps=0.005): 202 | shape = x.shape 203 | 204 | eps_x = torch.tensor([eps, 0.0, 0.0], device=x.device) 205 | eps_y = torch.tensor([0.0, eps, 0.0], device=x.device) 206 | eps_z = torch.tensor([0.0, 0.0, eps], device=x.device) 207 | 208 | # shape: [B, 6, N, 3] -> [B, 6*N, 3] 209 | x_new = torch.stack([x + eps_x, x + eps_y, x + eps_z, 210 | x - eps_x, x - eps_y, x - eps_z], dim=1).reshape(shape[0], -1, shape[-1]) 211 | 212 | pred = self.forward_fitting(x_new, code, smpl_V).reshape(shape[0], 6, -1) 213 | grad_x = (pred[:, 0, ...] - pred[:, 3, ...]) / (eps * 2.0) 214 | grad_y = (pred[:, 1, ...] - pred[:, 4, ...]) / (eps * 2.0) 215 | grad_z = (pred[:, 2, ...] 
- pred[:, 5, ...]) / (eps * 2.0) 216 | 217 | return torch.stack([grad_x, grad_y, grad_z], dim=-1) 218 | 219 | def get_mean_feature(self, vert_idx=None): 220 | if vert_idx is None: 221 | return self.dictionary.feature_codebooks.mean(dim=0) 222 | else: 223 | return self.dictionary.feature_codebooks[:, vert_idx].mean(dim=0) 224 | 225 | def get_feature_by_idx(self, idx, vert_idx=None): 226 | if vert_idx is None: 227 | return self.dictionary.feature_codebooks[idx] 228 | else: 229 | return self.dictionary.feature_codebooks[idx][vert_idx] 230 | 231 | def replace_feature_by_idx(self, idx, feature, vert_idx=None): 232 | if vert_idx is None: 233 | self.dictionary.feature_codebooks[idx] = feature 234 | else: 235 | self.dictionary.feature_codebooks[idx][vert_idx] = feature 236 | 237 | def get_smpl_vertices_by_idx(self, idx): 238 | return self.smpl_V[idx] 239 | 240 | def replace_smpl_vertices_by_idx(self, idx, smpl_V): 241 | self.smpl_V[idx] = smpl_V 242 | -------------------------------------------------------------------------------- /lib/models/tracer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | 6 | class SDFTracer(object): 7 | 8 | def __init__(self, 9 | cfg = None, 10 | camera_clamp : list = [-4, 4], 11 | step_size : float = 1.0, 12 | num_steps : int = 64, # samples for raymaching, iterations for sphere trace 13 | min_dis : float = 1e-3): 14 | 15 | self.camera_clamp = camera_clamp 16 | self.step_size = step_size 17 | self.num_steps = num_steps 18 | self.min_dis = min_dis 19 | 20 | self.inv_num_steps = 1.0 / self.num_steps 21 | 22 | def __call__(self, *args, **kwargs): 23 | return self.forward(*args, **kwargs) 24 | 25 | def forward(self, nef, idx, ray_o, ray_d): 26 | """PyTorch implementation of sphere tracing. 27 | Args: 28 | nef: Neural field object 29 | idx: index of the subject, shape (B, ) 30 | ray_o: ray origin, shape (B, N, 3) 31 | ray_d: ray direction, shape (B, N, 3) 32 | """ 33 | 34 | # Distanace from ray origin 35 | t = torch.zeros(ray_o.shape[0], ray_o.shape[1], 1, device=ray_o.device) 36 | 37 | # Position in model space 38 | x = torch.addcmul(ray_o, ray_d, t) 39 | 40 | cond = torch.ones_like(t).bool() 41 | 42 | normal = torch.zeros_like(x) 43 | # This function is in fact differentiable, but we treat it as if it's not, because 44 | # it evaluates a very long chain of recursive neural networks (essentially a NN with depth of 45 | # ~1600 layers or so). This is not sustainable in terms of memory use, so we return the final hit 46 | # locations, where additional quantities (normal, depth, segmentation) can be determined. The 47 | # gradients will propagate only to these locations. 48 | with torch.no_grad(): 49 | 50 | d = nef(x, idx) 51 | 52 | dprev = d.clone() 53 | 54 | # If cond is TRUE, then the corresponding ray has not hit yet. 55 | # OR, the corresponding ray has exit the clipping plane. 56 | #cond = torch.ones_like(d).bool()[:,0] 57 | 58 | # If miss is TRUE, then the corresponding ray has missed entirely. 59 | hit = torch.zeros_like(d).bool() 60 | 61 | for i in range(self.num_steps): 62 | # 1. Check if ray hits. 63 | #hit = (torch.abs(d) < self._MIN_DIS)[:,0] 64 | # 2. Check that the sphere tracing is not oscillating 65 | #hit = hit | (torch.abs((d + dprev) / 2.0) < self._MIN_DIS * 3)[:,0] 66 | 67 | # 3. Check that the ray has not exit the far clipping plane. 
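# Editor's note (added comment, not in the original source): each iteration below advances the
# sample point by the predicted signed distance, t <- t + step_size * SDF(x) and x <- ray_o + t * ray_d.
# For an exact SDF this can never overshoot the surface; step_size < 1 can damp the march for
# imperfect, learned fields. Rays stop when |SDF| < min_dis, when the estimate oscillates, or
# when t leaves the camera_clamp range.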
68 | #cond = (torch.abs(t) < self.clamp[1])[:,0] 69 | 70 | hit = (torch.abs(t) < self.camera_clamp[1]) 71 | 72 | # 1. not hit surface 73 | cond = cond & (torch.abs(d) > self.min_dis) 74 | 75 | # 2. not oscillating 76 | cond = cond & (torch.abs((d + dprev) / 2.0) > self.min_dis * 3) 77 | 78 | # 3. not a hit 79 | cond = cond & hit 80 | 81 | #cond = cond & ~hit 82 | 83 | # If the sum is 0, that means that all rays have hit, or missed. 84 | if not cond.any(): 85 | break 86 | 87 | # Advance the x, by updating with a new t 88 | x = torch.where(cond, torch.addcmul(ray_o, ray_d, t), x) 89 | 90 | # Store the previous distance 91 | dprev = torch.where(cond, d, dprev) 92 | 93 | # Update the distance to surface at x 94 | d[cond] = nef(x, idx)[cond] * self.step_size 95 | 96 | # Update the distance from origin 97 | t = torch.where(cond, t+d, t) 98 | 99 | # AABB cull 100 | 101 | hit = hit & ~(torch.abs(x) > 1.0).any(dim=-1,keepdim=True) 102 | #hit = torch.ones_like(d).byte()[...,0] 103 | 104 | # The function will return 105 | # x: the final model-space coordinate of the render 106 | # t: the final distance from origin 107 | # d: the final distance value from 108 | # miss: a vector containing bools of whether each ray was a hit or miss 109 | 110 | #if hit.any(): 111 | # grad = nef.finitediff_gradient(x[hit], idx) 112 | # _normal = F.normalize(grad, p=2, dim=-1, eps=1e-5) 113 | # normal[hit] = _normal 114 | 115 | return x, hit 116 | -------------------------------------------------------------------------------- /lib/models/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging as log 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from .neural_fields import NeuralField 9 | from .tracer import SDFTracer 10 | from .losses import GANLoss 11 | from kaolin.ops.conversions import voxelgrids_to_trianglemeshes 12 | from kaolin.ops.mesh import subdivide_trianglemesh 13 | 14 | import wandb 15 | 16 | class Trainer(nn.Module): 17 | 18 | def __init__(self, config, smpl_V, smpl_F, log_dir): 19 | 20 | super().__init__() 21 | 22 | # Set device to use 23 | self.device = torch.device('cuda') 24 | device_name = torch.cuda.get_device_name(device=self.device) 25 | log.info(f'Using {device_name} with CUDA v{torch.version.cuda}') 26 | 27 | self.cfg = config 28 | self.use_2d = self.cfg.use_2d_from_epoch >= 0 29 | self.use_2d_nrm = self.cfg.use_nrm_dis 30 | 31 | self.log_dir = log_dir 32 | self.log_dict = {} 33 | 34 | self.smpl_F = smpl_F.to(self.device).detach() 35 | self.smpl_V = smpl_V.to(self.device).detach() 36 | 37 | self.epoch = 0 38 | self.global_step = 0 39 | 40 | self.init_model() 41 | self.init_optimizer() 42 | self.init_log_dict() 43 | 44 | 45 | def init_model(self): 46 | """Initialize model. 
47 | """ 48 | 49 | log.info("Initializing geometry neural field...") 50 | 51 | self.sdf_field = NeuralField(self.cfg, 52 | self.smpl_V, 53 | self.smpl_F, 54 | self.cfg.shape_dim, 55 | 1, 56 | self.cfg.shape_freq, 57 | self.cfg.shape_pca_dim).to(self.device) 58 | 59 | log.info("Initializing texture neural field...") 60 | 61 | self.rgb_field = NeuralField(self.cfg, 62 | self.smpl_V, 63 | self.smpl_F, 64 | self.cfg.color_dim, 65 | 3, 66 | self.cfg.color_freq, 67 | self.cfg.color_pca_dim, 68 | sigmoid=True).to(self.device) 69 | 70 | self.tracer = SDFTracer(self.cfg) 71 | 72 | 73 | if self.use_2d: 74 | log.info("Initializing RGB discriminators...") 75 | self.gan_loss_rgb = GANLoss(self.cfg, self.cfg.gan_loss_type, auxillary=True).to(self.device) 76 | if self.use_2d_nrm: 77 | log.info("Initializing normal discriminators...") 78 | self.gan_loss_nrm = GANLoss(self.cfg, self.cfg.gan_loss_type).to(self.device) 79 | 80 | 81 | 82 | def init_optimizer(self): 83 | """Initialize optimizer. 84 | """ 85 | 86 | decoder_params = [] 87 | decoder_params.extend(list(self.sdf_field.decoder.parameters())) 88 | decoder_params.extend(list(self.rgb_field.decoder.parameters())) 89 | dictionary_params = [] 90 | dictionary_params.extend(list(self.sdf_field.dictionary.parameters())) 91 | dictionary_params.extend(list(self.rgb_field.dictionary.parameters())) 92 | 93 | params = [] 94 | params.append({'params': decoder_params, 95 | 'lr': self.cfg.lr_decoder, 96 | "weight_decay": self.cfg.weight_decay}) 97 | params.append({'params': dictionary_params, 98 | 'lr': self.cfg.lr_codebook}) 99 | 100 | 101 | self.optimizer = torch.optim.Adam(params, 102 | betas=(self.cfg.beta1, self.cfg.beta2)) 103 | 104 | if self.use_2d: 105 | dis_params = list(self.gan_loss_rgb.discriminator.parameters()) 106 | if self.use_2d_nrm: 107 | dis_params += list(self.gan_loss_nrm.discriminator.parameters()) 108 | 109 | self.optimizer_d = torch.optim.Adam(dis_params, 110 | lr=self.cfg.lr_dis, 111 | betas=(0.0, self.cfg.beta2)) 112 | def init_log_dict(self): 113 | """Custom logging dictionary. 114 | """ 115 | self.log_dict['total_iter_count'] = 0 116 | # 3D Loss 117 | self.log_dict['Loss_3D/rgb_loss'] = 0 118 | self.log_dict['Loss_3D/nrm_loss'] = 0 119 | self.log_dict['Loss_3D/reco_loss'] = 0 120 | self.log_dict['Loss_3D/reg_loss'] = 0 121 | self.log_dict['Loss_3D/total_loss'] = 0 122 | 123 | # RGB Discriminator 124 | self.log_dict['total_2D_count'] = 0 125 | 126 | self.log_dict['RGB_dis/D_loss'] = 0 127 | self.log_dict['RGB_dis/penalty_loss']= 0 128 | self.log_dict['RGB_dis/logits_real']= 0 129 | self.log_dict['RGB_dis/logits_fake']= 0 130 | 131 | # Nrm Discriminator 132 | self.log_dict['Nrm_dis/penalty_loss'] = 0 133 | self.log_dict['Nrm_dis/loss_D'] = 0 134 | self.log_dict['Nrm_dis/logits_real'] = 0 135 | self.log_dict['Nrm_dis/logits_fake'] = 0 136 | 137 | # 2D Loss 138 | self.log_dict['Loss_2D/RGB_G_loss'] = 0 139 | self.log_dict['Loss_2D/Nrm_G_loss'] = 0 140 | 141 | 142 | def step(self, epoch, n_iter, data): 143 | """Training step. 144 | 1. 3D forward 145 | 2. 3D backward 146 | 3. 2D forward 147 | 4. 
2D backward 148 | """ 149 | # record stats 150 | self.epoch = epoch 151 | self.global_step = n_iter 152 | 153 | # Set inputs to device 154 | self.set_inputs(data) 155 | 156 | # Train 157 | self.optimizer.zero_grad() 158 | self.forward_3D() 159 | self.backward_3D() 160 | 161 | if self.use_2d and \ 162 | epoch >= self.cfg.use_2d_from_epoch and \ 163 | n_iter % self.cfg.train_2d_every_iter == 0: 164 | self.forward_2D_rgb() 165 | self.backward_2D_rgb() 166 | if self.use_2d_nrm: 167 | self.forward_2D_nrm() 168 | self.backward_2D_nrm() 169 | self.log_dict['total_2D_count'] += 1 170 | 171 | self.optimizer.step() 172 | self.log_dict['total_iter_count'] += 1 173 | 174 | def set_inputs(self, data): 175 | """Set inputs for training. 176 | """ 177 | self.b_szie, self.n_vertice, _ = data['pts'].shape 178 | self.idx = data['idx'].to(self.device) 179 | 180 | self.pts = data['pts'].to(self.device) 181 | self.gts = data['sdf'].to(self.device) 182 | self.rgb = data['rgb'].to(self.device) 183 | 184 | # Downsample normal for faster training 185 | self.nrm_pts = self.pts[:, :self.n_vertice//10].to(self.device) 186 | self.nrm = data['nrm'][:, :self.n_vertice//10].to(self.device) 187 | 188 | if self.use_2d: 189 | self.width = data['rgb_image'].shape[2] 190 | 191 | self.label = data['label'].view(self.b_szie).to(self.device) 192 | self.ray_dir = data['ray_dir_image'].view(self.b_szie,-1,3).to(self.device) 193 | self.ray_ori = data['ray_ori_image'].view(self.b_szie,-1,3).to(self.device) 194 | self.gt_xyz = data['xyz_image'].view(self.b_szie,-1,3).to(self.device) 195 | self.gt_nrm = data['nrm_image'].view(self.b_szie,-1,3).to(self.device) 196 | self.gt_rgb = data['rgb_image'].view(self.b_szie,-1,3).to(self.device) 197 | self.gt_mask = data['mask_image'].view(self.b_szie,-1,1).to(self.device) 198 | 199 | def forward_3D(self): 200 | """Forward pass for 3D. 201 | predict sdf, rgb, nrm 202 | """ 203 | self.pred_sdf, geo_h = self.sdf_field(self.pts, self.idx, return_h=True) 204 | self.pred_rgb = self.rgb_field(self.pts, self.idx) 205 | self.pred_nrm = self.sdf_field.finitediff_gradient(self.nrm_pts, self.idx) 206 | self.pred_nrm = F.normalize(self.pred_nrm, p=2, dim=-1, eps=1e-5) 207 | 208 | def backward_3D(self): 209 | """Backward pass for 3D. 210 | Compute 3D loss 211 | """ 212 | total_loss = 0.0 213 | reco_loss = 0.0 214 | rgb_loss = 0.0 215 | reg_loss = 0.0 216 | 217 | reco_loss += torch.abs(self.pred_sdf - self.gts).mean() 218 | 219 | rgb_loss += torch.abs(self.pred_rgb - self.rgb).mean() 220 | 221 | #nrm_loss = torch.abs(1 - F.cosine_similarity(self.pred_nrm, self.nrm, dim=-1)).mean() 222 | nrm_loss = torch.abs(self.pred_nrm - self.nrm).mean() 223 | 224 | reg_loss += self.sdf_field.regularization_loss() 225 | reg_loss += self.rgb_field.regularization_loss() 226 | 227 | total_loss += reco_loss * self.cfg.lambda_sdf + \ 228 | rgb_loss * self.cfg.lambda_rgb + \ 229 | nrm_loss * self.cfg.lambda_nrm + \ 230 | reg_loss * self.cfg.lambda_reg 231 | 232 | total_loss.backward() 233 | 234 | # Update logs 235 | self.log_dict['Loss_3D/reco_loss'] += reco_loss.item() 236 | self.log_dict['Loss_3D/rgb_loss'] += rgb_loss.item() 237 | self.log_dict['Loss_3D/nrm_loss'] += nrm_loss.item() 238 | self.log_dict['Loss_3D/reg_loss'] += reg_loss.item() 239 | 240 | self.log_dict['Loss_3D/total_loss'] += total_loss.item() 241 | 242 | def forward_2D_rgb(self): 243 | """Forward pass for 2D rgb images. 
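# Editor's note (added comment, not in the original source): backward_3D above optimizes
#   total = lambda_sdf * |sdf - sdf_gt| + lambda_rgb * |rgb - rgb_gt|
#         + lambda_nrm * |nrm - nrm_gt| + lambda_reg * mean(codebook**2),
# all as plain means over the sampled points. The 2D adversarial terms handled below are only
# added once epoch >= cfg.use_2d_from_epoch and every cfg.train_2d_every_iter iterations.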
244 | Fix geroemtry (3D coordinates) and random sample texture 245 | """ 246 | x = self.gt_xyz 247 | hit = self.gt_mask 248 | 249 | self.rgb_2d = self.rgb_field.sample(x.detach(), self.idx) * hit 250 | 251 | def forward_2D_nrm(self): 252 | """Forward pass for 2D nrm images. Random sample geometry and output normal. 253 | This requires online ray tracing and is slow. 254 | Cached points can be used as an approximation. 255 | """ 256 | if self.cfg.use_cached_pts: 257 | x = self.gt_xyz 258 | hit = self.gt_mask 259 | else: 260 | x, hit = self.tracer(self.sdf_field.sample, self.idx, self.ray_ori, self.ray_dir) 261 | 262 | _normal = self.sdf_field.finitediff_gradient(x, self.idx, sample=True) 263 | _normal = F.normalize(_normal, p=2, dim=-1, eps=1e-5) 264 | self.nrm_2d = _normal * hit 265 | 266 | def backward_2D_rgb(self): 267 | """Backward pass for 2D rgb images. 268 | Compute 2D adversarial loss for the discriminator and generator. 269 | """ 270 | 271 | total_2D_loss = 0.0 272 | 273 | # RGB GAN loss 274 | disc_in_fake = self.rgb_2d.view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2) 275 | disc_in_real = (self.gt_rgb * self.gt_mask).view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2) 276 | disc_in_real.requires_grad = True # for R1 gradient penalty 277 | 278 | self.optimizer_d.zero_grad() 279 | d_loss, log = self.gan_loss_rgb(disc_in_real, disc_in_fake, mode='d', gt_label=self.label) 280 | d_loss.backward() 281 | self.optimizer_d.step() 282 | 283 | self.log_dict['RGB_dis/D_loss'] += log['loss_train/disc_loss'] 284 | self.log_dict['RGB_dis/penalty_loss'] += log['loss_train/r1_loss'] 285 | self.log_dict['RGB_dis/logits_real'] += log['loss_train/logits_real'] 286 | self.log_dict['RGB_dis/logits_fake'] += log['loss_train/logits_fake'] 287 | 288 | g_loss, log = self.gan_loss_rgb(None, disc_in_fake, mode='g') 289 | total_2D_loss += g_loss 290 | total_2D_loss.backward() 291 | 292 | self.log_dict['Loss_2D/RGB_G_loss'] += log['loss_train/g_loss'] 293 | 294 | def backward_2D_nrm(self): 295 | """Backward pass for 2D normal images. 296 | Compute 2D adversarial loss for the discriminator and generator. 297 | """ 298 | 299 | # Nrm GAN loss 300 | total_2D_loss = 0.0 301 | 302 | disc_in_fake = self.nrm_2d.view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2) 303 | disc_in_real = (self.gt_nrm * self.gt_mask).view(self.b_szie, self.width, self.width, 3).permute(0,3,1,2) 304 | disc_in_real.requires_grad = True # for R1 gradient penalty 305 | 306 | self.optimizer_d.zero_grad() 307 | d_loss, log = self.gan_loss_nrm(disc_in_real, disc_in_fake, mode='d') 308 | d_loss.backward() 309 | self.optimizer_d.step() 310 | 311 | self.log_dict['Nrm_dis/loss_D'] += log['loss_train/disc_loss'] 312 | self.log_dict['Nrm_dis/penalty_loss'] += log['loss_train/r1_loss'] 313 | self.log_dict['Nrm_dis/logits_real'] += log['loss_train/logits_real'] 314 | self.log_dict['Nrm_dis/logits_fake'] += log['loss_train/logits_fake'] 315 | 316 | g_loss, log = self.gan_loss_rgb(None, disc_in_fake, mode='g') 317 | total_2D_loss += g_loss 318 | total_2D_loss.backward() 319 | 320 | self.log_dict['Loss_2D/Nrm_G_loss'] += log['loss_train/g_loss'] 321 | 322 | def log(self, step, epoch): 323 | """Log the training information. 
324 | """ 325 | log_text = 'STEP {} - EPOCH {}/{}'.format(step, epoch, self.cfg.epochs) 326 | self.log_dict['Loss_3D/total_loss'] /= self.log_dict['total_iter_count'] + 1e-6 327 | log_text += ' | total loss: {:>.3E}'.format(self.log_dict['Loss_3D/total_loss']) 328 | self.log_dict['Loss_3D/reco_loss'] /= self.log_dict['total_iter_count'] + 1e-6 329 | log_text += ' | Reco loss: {:>.3E}'.format(self.log_dict['Loss_3D/reco_loss']) 330 | self.log_dict['Loss_3D/rgb_loss'] /= self.log_dict['total_iter_count'] + 1e-6 331 | log_text += ' | rgb loss: {:>.3E}'.format(self.log_dict['Loss_3D/rgb_loss']) 332 | self.log_dict['Loss_3D/nrm_loss'] /= self.log_dict['total_iter_count'] + 1e-6 333 | log_text += ' | nrm loss: {:>.3E}'.format(self.log_dict['Loss_3D/nrm_loss']) 334 | self.log_dict['Loss_3D/reg_loss'] /= self.log_dict['total_iter_count'] + 1e-6 335 | 336 | log.info(log_text) 337 | 338 | for key, value in self.log_dict.items(): 339 | if ['RGB_dis', 'Nrm_dis', 'Loss_2D'].count(key.split('/')[0]) > 0: 340 | value /= self.log_dict['total_2D_count'] + 1e-6 341 | wandb.log({key: value}, step=step) 342 | self.init_log_dict() 343 | 344 | def write_images(self, i): 345 | """Write images to wandb. 346 | """ 347 | gen_img = self.rgb_2d.view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy() 348 | gt_img = (self.gt_rgb * self.gt_mask).view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy() 349 | wandb.log({"Generated Images": [wandb.Image(gen_img[i]) for i in range(self.b_szie)]}, step=i) 350 | wandb.log({"Ground Truth Images": [wandb.Image(gt_img[i]) for i in range(self.b_szie)]}, step=i) 351 | 352 | if self.use_2d_nrm: 353 | gen_nrm = self.nrm_2d.view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy() * 0.5 + 0.5 354 | gt_nrm = (self.gt_nrm * self.gt_mask).view(self.b_szie, self.width , self.width , 3).clone().detach().cpu().numpy() * 0.5 + 0.5 355 | gen_nrm = np.clip(gen_nrm, 0, 1) 356 | gt_nrm = np.clip(gt_nrm, 0, 1) 357 | wandb.log({"Generated Normals": [wandb.Image(gen_nrm[i]) for i in range(self.b_szie)]}, step=i) 358 | wandb.log({"Ground Truth Normals": [wandb.Image(gt_nrm[i]) for i in range(self.b_szie)]}, step=i) 359 | 360 | def save_checkpoint(self, full=True, replace=False): 361 | """Save the model checkpoint. 362 | """ 363 | 364 | if replace: 365 | model_fname = os.path.join(self.log_dir, f'model-.pth') 366 | else: 367 | model_fname = os.path.join(self.log_dir, f'model-{self.epoch:04d}.pth') 368 | 369 | state = { 370 | 'epoch': self.epoch, 371 | 'global_step': self.global_step, 372 | 'log_dir': self.log_dir 373 | } 374 | 375 | if full: 376 | state['optimizer'] = self.optimizer.state_dict() 377 | if self.use_2d: 378 | state['optimizer_d'] = self.optimizer_d.state_dict() 379 | 380 | 381 | state['sdf'] = self.sdf_field.state_dict() 382 | state['rgb'] = self.rgb_field.state_dict() 383 | if self.use_2d: 384 | state['D_rgb'] = self.gan_loss_rgb.state_dict() 385 | if self.use_2d_nrm: 386 | state['D_nrm'] = self.gan_loss_nrm.state_dict() 387 | 388 | log.info(f'Saving model checkpoint to: {model_fname}') 389 | torch.save(state, model_fname) 390 | 391 | 392 | def load_checkpoint(self, fname): 393 | """Load checkpoint. 
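# Editor's illustration (not part of the repository): resuming a run from disk. The paths and
# the `cfg`, `smpl_V`, `smpl_F` objects are placeholders, assumed to be prepared elsewhere.
trainer = Trainer(cfg, smpl_V, smpl_F, log_dir='checkpoints/run1')
trainer.load_checkpoint('checkpoints/run1/model-0100.pth')   # restores fields, discriminators and, if saved, optimizers
trainer.save_checkpoint(full=False, replace=True)            # weights only, overwrites model-.pth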
394 | """ 395 | try: 396 | checkpoint = torch.load(fname, map_location=self.device) 397 | log.info(f'Loading model checkpoint from: {fname}') 398 | except FileNotFoundError: 399 | log.warning(f'No checkpoint found at: {fname}, model randomly initialized.') 400 | return 401 | 402 | # update meta info 403 | self.epoch = checkpoint['epoch'] 404 | self.global_step = checkpoint['global_step'] 405 | self.log_dir = checkpoint['log_dir'] 406 | 407 | self.sdf_field.load_state_dict(checkpoint['sdf']) 408 | self.rgb_field.load_state_dict(checkpoint['rgb']) 409 | if self.use_2d: 410 | if 'D_rgb' in checkpoint: 411 | self.gan_loss_rgb.load_state_dict(checkpoint['D_rgb']) 412 | if self.use_2d_nrm and 'D_nrm' in checkpoint: 413 | self.gan_loss_nrm.load_state_dict(checkpoint['D_nrm']) 414 | 415 | if 'optimizer' in checkpoint: 416 | self.optimizer.load_state_dict(checkpoint['optimizer']) 417 | if self.use_2d: 418 | self.optimizer_d.load_state_dict(checkpoint['optimizer_d']) 419 | 420 | log.info(f'Loaded checkpoint at epoch {self.epoch} with global step {self.global_step}.') 421 | 422 | ''' 423 | ####################################################################################################################################### 424 | 425 | def reconstruction(self, epoch, i, subdivide, res=300): 426 | 427 | torch.cuda.empty_cache() 428 | 429 | with torch.no_grad(): 430 | h = self._marching_cubes (i, subdivide=subdivide, res=res) 431 | h.export(os.path.join(self.log_dir, '%03d_reco_src-%03d.obj' % (epoch, i)) ) 432 | 433 | torch.cuda.empty_cache() 434 | 435 | 436 | def _marching_cubes (self, i, subdivide=True, res=300): 437 | 438 | width = res 439 | window_x = torch.linspace(-1., 1., steps=width, device='cuda') 440 | window_y = torch.linspace(-1., 1., steps=width, device='cuda') 441 | window_z = torch.linspace(-1., 1., steps=width, device='cuda') 442 | 443 | coord = torch.stack(torch.meshgrid(window_x, window_y, window_z)).permute(1, 2, 3, 0).reshape(1, -1, 3).contiguous() 444 | 445 | 446 | # Debug smpl grid 447 | smpl_vertice = self.smpl_V[i] 448 | d = trimesh.Trimesh(vertices=smpl_vertice.cpu().detach().numpy(), 449 | faces=self.smpl_F.cpu().detach().numpy()) 450 | d.export(os.path.join(self.log_dir, 'smpl_sub_%03d.obj' % (i)) ) 451 | 452 | 453 | idx = torch.tensor([i], dtype=torch.long, device = torch.device('cuda')).view(1).detach() 454 | _points = torch.split(coord, int(2*1e6), dim=1) 455 | voxels = [] 456 | for _p in _points: 457 | pred_sdf = self.sdf_field(_p, idx) 458 | voxels.append(pred_sdf) 459 | 460 | voxels = torch.cat(voxels, dim=1) 461 | voxels = voxels.reshape(1, width, width, width) 462 | 463 | vertices, faces = voxelgrids_to_trianglemeshes(voxels, iso_value=0.) 
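# The vertices returned by voxelgrids_to_trianglemeshes appear to be in voxel-index units; the rescaling below maps them (approximately) back into the [-1, 1]^3 cube used to build the SDF query grid.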
464 | vertices = ((vertices[0].reshape(1, -1, 3) - 0.5) / (width/2)) - 1.0 465 | faces = faces[0] 466 | 467 | if subdivide: 468 | vertices, faces = subdivide_trianglemesh(vertices, faces, iterations=1) 469 | 470 | pred_rgb = self.rgb_field(vertices, idx+1, pose_idx=idx) 471 | 472 | h = trimesh.Trimesh(vertices=vertices[0].cpu().detach().numpy(), 473 | faces=faces.cpu().detach().numpy(), 474 | vertex_colors=pred_rgb[0].cpu().detach().numpy()) 475 | 476 | # remove disconnect par of mesh 477 | connected_comp = h.split(only_watertight=False) 478 | max_area = 0 479 | max_comp = None 480 | for comp in connected_comp: 481 | if comp.area > max_area: 482 | max_area = comp.area 483 | max_comp = comp 484 | h = max_comp 485 | 486 | trimesh.repair.fix_inversion(h) 487 | 488 | return h 489 | ''' -------------------------------------------------------------------------------- /lib/ops/mesh/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | from .area_weighted_distribution import area_weighted_distribution 10 | from .random_face import random_face 11 | from .point_sample import point_sample 12 | from .sample_surface import sample_surface 13 | from .sample_near_surface import sample_near_surface 14 | from .sample_uniform import sample_uniform 15 | from .load_obj import load_obj 16 | from .normalize import normalize 17 | from .closest_point import * 18 | from .closest_tex import closest_tex 19 | from .barycentric_coordinates import barycentric_coordinates 20 | from .sample_tex import sample_tex 21 | from .per_face_normals import per_face_normals 22 | from .per_vertex_normals import per_vertex_normals 23 | -------------------------------------------------------------------------------- /lib/ops/mesh/area_weighted_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from .per_face_normals import per_face_normals 11 | 12 | def area_weighted_distribution( 13 | V : torch.Tensor, 14 | F : torch.Tensor, 15 | normals : torch.Tensor = None): 16 | """Construct discrete area weighted distribution over triangle mesh. 
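Face areas are taken from the per-face normal magnitudes (half the cross-product norm) and normalized, so the returned categorical distribution picks larger triangles proportionally more often.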
17 | 18 | Args: 19 | V (torch.Tensor): #V, 3 array of vertices 20 | F (torch.Tensor): #F, 3 array of indices 21 | normals (torch.Tensor): normals (if precomputed) 22 | eps (float): epsilon 23 | 24 | Returns: 25 | (torch.distributions): Distribution to be used 26 | """ 27 | 28 | if normals is None: 29 | normals = per_face_normals(V, F) 30 | areas = torch.norm(normals, p=2, dim=1) * 0.5 31 | areas /= torch.sum(areas) + 1e-10 32 | 33 | # Discrete PDF over triangles 34 | return torch.distributions.Categorical(areas.view(-1)) 35 | 36 | -------------------------------------------------------------------------------- /lib/ops/mesh/barycentric_coordinates.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import numpy as np 11 | 12 | # Same API as https://github.com/libigl/libigl/blob/main/include/igl/barycentric_coordinates.cpp 13 | 14 | def barycentric_coordinates( 15 | points : torch.Tensor, 16 | A : torch.Tensor, 17 | B : torch.Tensor, 18 | C : torch.Tensor): 19 | """ 20 | Return barycentric coordinates for a given set of points and triangle vertices 21 | 22 | Args: 23 | points (torch.FloatTensor): [N, 3] 24 | A (torch.FloatTensor): [N, 3] vertex0 25 | B (torch.FloatTensor): [N, 3] vertex1 26 | C (torch.FloatTensor): [N, 3] vertex2 27 | 28 | Returns: 29 | (torch.FloatTensor): barycentric coordinates of [N, 2] 30 | """ 31 | 32 | v0 = B-A 33 | v1 = C-A 34 | v2 = points-A 35 | d00 = (v0*v0).sum(dim=-1) 36 | d01 = (v0*v1).sum(dim=-1) 37 | d11 = (v1*v1).sum(dim=-1) 38 | d20 = (v2*v0).sum(dim=-1) 39 | d21 = (v2*v1).sum(dim=-1) 40 | denom = d00*d11 - d01*d01 41 | L = torch.zeros(points.shape[0], 3, device=points.device) 42 | # Warning: This clipping may cause undesired behaviour 43 | L[...,1] = torch.clip((d11*d20 - d01*d21)/denom, 0.0, 1.0) 44 | L[...,2] = torch.clip((d00*d21 - d01*d20)/denom, 0.0, 1.0) 45 | L[...,0] = torch.clip(1.0 - (L[...,1] + L[...,2]), 0.0, 1.0) 46 | return L 47 | -------------------------------------------------------------------------------- /lib/ops/mesh/closest_point.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
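# A minimal usage sketch for the barycentric weights defined above; `verts`, `attrs`,
# `hit_pts` and `hit_faces` are hypothetical tensors, not part of this module. Per-vertex
# attributes can be interpolated at projected points roughly as follows:
#
#     A, B, C = verts[hit_faces[:, 0]], verts[hit_faces[:, 1]], verts[hit_faces[:, 2]]  # [N, 3] each
#     w = barycentric_coordinates(hit_pts, A, B, C)                                     # [N, 3] weights
#     interp = (attrs[hit_faces] * w.unsqueeze(-1)).sum(dim=1)                          # [N, C] blended values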
8 | 9 | # Closest point function + texture sampling 10 | # https://en.wikipedia.org/wiki/Closest_point_method 11 | 12 | import torch 13 | import numpy as np 14 | from .barycentric_coordinates import barycentric_coordinates 15 | from tqdm import tqdm 16 | from kaolin.ops.mesh import index_vertices_by_faces, check_sign 17 | from kaolin import _C 18 | 19 | 20 | class _UnbatchedTriangleDistanceCuda(torch.autograd.Function): 21 | @staticmethod 22 | def forward(ctx, points, face_vertices): 23 | num_points = points.shape[0] 24 | num_faces = face_vertices.shape[0] 25 | min_dist = torch.zeros((num_points), device=points.device, dtype=points.dtype) 26 | min_dist_idx = torch.zeros((num_points), device=points.device, dtype=torch.long) 27 | dist_type = torch.zeros((num_points), device=points.device, dtype=torch.int32) 28 | _C.metrics.unbatched_triangle_distance_forward_cuda( 29 | points, face_vertices, min_dist, min_dist_idx, dist_type) 30 | ctx.save_for_backward(points.contiguous(), face_vertices.contiguous(), 31 | min_dist_idx, dist_type) 32 | ctx.mark_non_differentiable(min_dist_idx, dist_type) 33 | return min_dist, min_dist_idx, dist_type 34 | 35 | @staticmethod 36 | def backward(ctx, grad_dist, grad_face_idx, grad_dist_type): 37 | points, face_vertices, face_idx, dist_type = ctx.saved_tensors 38 | grad_dist = grad_dist.contiguous() 39 | grad_points = torch.zeros_like(points) 40 | grad_face_vertices = torch.zeros_like(face_vertices) 41 | _C.metrics.unbatched_triangle_distance_backward_cuda( 42 | grad_dist, points, face_vertices, face_idx, dist_type, 43 | grad_points, grad_face_vertices) 44 | return grad_points, grad_face_vertices 45 | 46 | 47 | def _compute_dot(p1, p2): 48 | return p1[..., 0] * p2[..., 0] + \ 49 | p1[..., 1] * p2[..., 1] + \ 50 | p1[..., 2] * p2[..., 2] 51 | 52 | def _project_edge(vertex, edge, point): 53 | point_vec = point - vertex 54 | length = _compute_dot(edge, edge) 55 | return _compute_dot(point_vec, edge) / length 56 | 57 | def _project_plane(vertex, normal, point): 58 | point_vec = point - vertex 59 | unit_normal = normal / torch.norm(normal, dim=-1, keepdim=True) 60 | dist = _compute_dot(point_vec, unit_normal) 61 | return point - unit_normal * dist.view(-1, 1) 62 | 63 | def _is_not_above(vertex, edge, norm, point): 64 | edge_norm = torch.cross(norm, edge, dim=-1) 65 | return _compute_dot(edge_norm.view(1, -1, 3), 66 | point.view(-1, 1, 3) - vertex.view(1, -1, 3)) <= 0 67 | 68 | def _point_at(vertex, edge, proj): 69 | return vertex + edge * proj.view(-1, 1) 70 | 71 | 72 | def _unbatched_naive_point_to_mesh_distance(points, face_vertices): 73 | """ 74 | description of distance type: 75 | - 0: distance to face 76 | - 1: distance to vertice 0 77 | - 2: distance to vertice 1 78 | - 3: distance to vertice 2 79 | - 4: distance to edge 0-1 80 | - 5: distance to edge 1-2 81 | - 6: distance to edge 2-0 82 | Args: 83 | points (torch.Tensor): of shape (num_points, 3). 84 | faces_vertices (torch.LongTensor): of shape (num_faces, 3, 3). 85 | Returns: 86 | (torch.Tensor, torch.LongTensor, torch.IntTensor): 87 | - distance, of shape (num_points). 88 | - face_idx, of shape (num_points). 89 | - distance_type, of shape (num_points). 
90 | - conter P 91 | """ 92 | num_points = points.shape[0] 93 | num_faces = face_vertices.shape[0] 94 | 95 | device = points.device 96 | dtype = points.dtype 97 | 98 | v1 = face_vertices[:, 0] 99 | v2 = face_vertices[:, 1] 100 | v3 = face_vertices[:, 2] 101 | 102 | e21 = v2 - v1 103 | e32 = v3 - v2 104 | e13 = v1 - v3 105 | 106 | normals = -torch.cross(e21, e13) 107 | 108 | uab = _project_edge(v1.view(1, -1, 3), e21.view(1, -1, 3), points.view(-1, 1, 3)) 109 | ubc = _project_edge(v2.view(1, -1, 3), e32.view(1, -1, 3), points.view(-1, 1, 3)) 110 | uca = _project_edge(v3.view(1, -1, 3), e13.view(1, -1, 3), points.view(-1, 1, 3)) 111 | 112 | is_type1 = (uca > 1.) & (uab < 0.) 113 | is_type2 = (uab > 1.) & (ubc < 0.) 114 | is_type3 = (ubc > 1.) & (uca < 0.) 115 | is_type4 = (uab >= 0.) & (uab <= 1.) & _is_not_above(v1, e21, normals, points) 116 | is_type5 = (ubc >= 0.) & (ubc <= 1.) & _is_not_above(v2, e32, normals, points) 117 | is_type6 = (uca >= 0.) & (uca <= 1.) & _is_not_above(v3, e13, normals, points) 118 | is_type0 = ~(is_type1 | is_type2 | is_type3 | is_type4 | is_type5 | is_type6) 119 | 120 | face_idx = torch.zeros(num_points, device=device, dtype=torch.long) 121 | all_closest_points = torch.zeros((num_points, num_faces, 3), device=device, 122 | dtype=dtype) 123 | 124 | all_type0_idx = torch.where(is_type0) 125 | all_type1_idx = torch.where(is_type1) 126 | all_type2_idx = torch.where(is_type2) 127 | all_type3_idx = torch.where(is_type3) 128 | all_type4_idx = torch.where(is_type4) 129 | all_type5_idx = torch.where(is_type5) 130 | all_type6_idx = torch.where(is_type6) 131 | 132 | all_types = is_type1.int() + is_type2.int() * 2 + is_type3.int() * 3 + \ 133 | is_type4.int() * 4 + is_type5.int() * 5 + is_type6.int() * 6 134 | 135 | all_closest_points[all_type0_idx] = _project_plane( 136 | v1[all_type0_idx[1]], normals[all_type0_idx[1]], points[all_type0_idx[0]]) 137 | all_closest_points[all_type1_idx] = v1.view(-1, 3)[all_type1_idx[1]] 138 | all_closest_points[all_type2_idx] = v2.view(-1, 3)[all_type2_idx[1]] 139 | all_closest_points[all_type3_idx] = v3.view(-1, 3)[all_type3_idx[1]] 140 | all_closest_points[all_type4_idx] = _point_at(v1[all_type4_idx[1]], e21[all_type4_idx[1]], 141 | uab[all_type4_idx]) 142 | all_closest_points[all_type5_idx] = _point_at(v2[all_type5_idx[1]], e32[all_type5_idx[1]], 143 | ubc[all_type5_idx]) 144 | all_closest_points[all_type6_idx] = _point_at(v3[all_type6_idx[1]], e13[all_type6_idx[1]], 145 | uca[all_type6_idx]) 146 | all_vec = (all_closest_points - points.view(-1, 1, 3)) 147 | all_dist = _compute_dot(all_vec, all_vec) 148 | 149 | _, min_dist_idx = torch.min(all_dist, dim=-1) 150 | dist_type = all_types[torch.arange(num_points, device=device), min_dist_idx] 151 | torch.cuda.synchronize() 152 | 153 | # Recompute the shortest distances 154 | # This reduce the backward pass to the closest faces instead of all faces 155 | # O(num_points) vs O(num_points * num_faces) 156 | selected_face_vertices = face_vertices[min_dist_idx] 157 | v1 = selected_face_vertices[:, 0] 158 | v2 = selected_face_vertices[:, 1] 159 | v3 = selected_face_vertices[:, 2] 160 | 161 | e21 = v2 - v1 162 | e32 = v3 - v2 163 | e13 = v1 - v3 164 | 165 | normals = -torch.cross(e21, e13) 166 | 167 | uab = _project_edge(v1, e21, points) 168 | ubc = _project_edge(v2, e32, points) 169 | uca = _project_edge(v3, e13, points) 170 | 171 | counter_p = torch.zeros((num_points, 3), device=device, dtype=dtype) 172 | 173 | cond = (dist_type == 1) 174 | counter_p[cond] = v1[cond] 175 | 176 | cond = (dist_type 
== 2) 177 | counter_p[cond] = v2[cond] 178 | 179 | cond = (dist_type == 3) 180 | counter_p[cond] = v3[cond] 181 | 182 | cond = (dist_type == 4) 183 | counter_p[cond] = _point_at(v1, e21, uab)[cond] 184 | 185 | cond = (dist_type == 5) 186 | counter_p[cond] = _point_at(v2, e32, ubc)[cond] 187 | 188 | cond = (dist_type == 6) 189 | counter_p[cond] = _point_at(v3, e13, uca)[cond] 190 | 191 | cond = (dist_type == 0) 192 | counter_p[cond] = _project_plane(v1, normals, points)[cond] 193 | min_dist = torch.sum((counter_p - points) ** 2, dim=-1) 194 | 195 | return min_dist, min_dist_idx, dist_type, counter_p 196 | 197 | 198 | def _find_closest_point(points, face_vertices, cur_face_idx, cur_dist_type): 199 | """Returns the closest point given a querypoints and meshes. 200 | points (torch.Tensor): of shape (num_points, 3). 201 | faces_vertices (torch.LongTensor): of shape (num_faces, 3, 3). 202 | cur_face_idx (torch.LongTensor): of shape (num_points,). 203 | cur_dist_type (torch.LongTensor): of shape (num_points,). 204 | 205 | Returns: 206 | (torch.FloatTensor): counter_p of shape (num_points, 3). 207 | """ 208 | num_points = points.shape[0] 209 | device = points.device 210 | dtype = points.dtype 211 | selected_face_vertices = face_vertices[cur_face_idx] 212 | 213 | v1 = selected_face_vertices[:, 0] 214 | v2 = selected_face_vertices[:, 1] 215 | v3 = selected_face_vertices[:, 2] 216 | 217 | e21 = v2 - v1 218 | e32 = v3 - v2 219 | e13 = v1 - v3 220 | 221 | normals = -torch.cross(e21, e13) 222 | 223 | uab = _project_edge(v1, e21, points) 224 | ubc = _project_edge(v2, e32, points) 225 | uca = _project_edge(v3, e13, points) 226 | 227 | counter_p = torch.zeros((num_points, 3), device=device, dtype=dtype) 228 | 229 | cond = (cur_dist_type == 1) 230 | counter_p[cond] = v1[cond] 231 | 232 | cond = (cur_dist_type == 2) 233 | counter_p[cond] = v2[cond] 234 | 235 | cond = (cur_dist_type == 3) 236 | counter_p[cond] = v3[cond] 237 | 238 | cond = (cur_dist_type == 4) 239 | counter_p[cond] = _point_at(v1, e21, uab)[cond] 240 | 241 | cond = (cur_dist_type == 5) 242 | counter_p[cond] = _point_at(v2, e32, ubc)[cond] 243 | 244 | cond = (cur_dist_type == 6) 245 | counter_p[cond] = _point_at(v3, e13, uca)[cond] 246 | 247 | cond = (cur_dist_type == 0) 248 | counter_p[cond] = _project_plane(v1, normals, points)[cond] 249 | 250 | 251 | return counter_p 252 | 253 | def closest_point( 254 | V : torch.Tensor, 255 | F : torch.Tensor, 256 | points : torch.Tensor, 257 | split_size : int = 5*10**3): 258 | 259 | """Returns the closest texture for a set of points. 
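In practice this computes signed distances to the mesh (negative inside), the projected closest surface points, and the indices of the hit triangles, processing the query points in GPU chunks of split_size.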
260 | 261 | V (torch.FloatTensor): mesh vertices of shape [V, 3] 262 | F (torch.LongTensor): mesh face indices of shape [F, 3] 263 | points (torch.FloatTensor): sample locations of shape [N, 3] 264 | 265 | Returns: 266 | (torch.FloatTensor): distances of shape [N, 1] 267 | (torch.FloatTensor): projected points of shape [N, 3] 268 | (torch.FloatTensor): face indices of shape [N, 1] 269 | """ 270 | 271 | V = V.cuda().contiguous() 272 | F = F.cuda().contiguous() 273 | 274 | mesh = index_vertices_by_faces(V.unsqueeze(0), F).squeeze(0) 275 | 276 | _points = torch.split(points, split_size) 277 | 278 | dists = [] 279 | pts = [] 280 | indices = [] 281 | for _p in _points: 282 | p = _p.cuda().contiguous() 283 | sign = check_sign(V.unsqueeze(0), F, p.unsqueeze(0)).squeeze(0) 284 | dist, hit_tidx, dist_type, hit_pts = _unbatched_naive_point_to_mesh_distance(p, mesh) 285 | dist = torch.where (sign, -torch.sqrt(dist), torch.sqrt(dist)) 286 | dists.append(dist) 287 | pts.append(hit_pts) 288 | indices.append(hit_tidx) 289 | 290 | return torch.cat(dists)[...,None], torch.cat(pts), torch.cat(indices) 291 | 292 | def batched_closest_point( 293 | V : torch.Tensor, 294 | F : torch.Tensor, 295 | points : torch.Tensor): 296 | 297 | """Returns the closest texture for a set of points. 298 | 299 | V (torch.FloatTensor): mesh vertices of shape [B, V, 3] 300 | F (torch.LongTensor): mesh face indices of shape [F, 3] 301 | points (torch.FloatTensor): sample locations of shape [B, N, 3] 302 | 303 | Returns: 304 | (torch.FloatTensor): distances of shape [B, N, 1] 305 | (torch.FloatTensor): projected points of shape [B, N, 3] 306 | (torch.FloatTensor): face indices of shape [B, N, 1] 307 | """ 308 | 309 | V = V.cuda().contiguous() 310 | F = F.cuda().contiguous() 311 | 312 | batch_size = V.shape[0] 313 | num_points = V.shape[1] 314 | 315 | dists = [] 316 | pts = [] 317 | indices = [] 318 | weights = [] 319 | 320 | sign = check_sign(V, F, points) 321 | 322 | for i in range(batch_size): 323 | mesh = V[i][F] 324 | p = points[i] 325 | dist, hit_tidx, dist_type, hit_pts = _unbatched_naive_point_to_mesh_distance(p, mesh) 326 | dist = torch.where (sign[i], -torch.sqrt(dist), torch.sqrt(dist)) 327 | hitface = F[hit_tidx.view(-1)] # [ Ns , 3] 328 | 329 | 330 | BC = barycentric_coordinates(hit_pts, V[i][hitface[:,0]], 331 | V[i][hitface[:,1]], V[i][hitface[:,2]]) 332 | 333 | dists.append(dist) 334 | pts.append(hit_pts) 335 | indices.append(hit_tidx) 336 | weights.append(BC) 337 | 338 | return torch.stack(dists)[...,None], torch.stack(pts), torch.stack(indices), torch.stack(weights) 339 | 340 | 341 | def closest_point_fast( 342 | V : torch.Tensor, 343 | F : torch.Tensor, 344 | points : torch.Tensor): 345 | 346 | """Returns the closest texture for a set of points. 
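Computes signed distances, closest surface points, and hit-face indices in a single pass, dispatching to the CUDA triangle-distance kernel when the query points are on the GPU and to the naive PyTorch implementation otherwise.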
347 | 348 | V (torch.FloatTensor): mesh vertices of shape [V, 3] 349 | F (torch.LongTensor): mesh face indices of shape [F, 3] 350 | points (torch.FloatTensor): sample locations of shape [N, 3] 351 | 352 | Returns: 353 | (torch.FloatTensor): signed distances of shape [N, 1] 354 | (torch.FloatTensor): projected points of shape [N, 3] 355 | (torch.FloatTensor): face indices of shape [N, ] 356 | """ 357 | 358 | face_vertices = V[F] 359 | sign = check_sign(V.unsqueeze(0), F, points.unsqueeze(0)).squeeze(0) 360 | 361 | if points.is_cuda: 362 | cur_dist, cur_face_idx, cur_dist_type = _UnbatchedTriangleDistanceCuda.apply( 363 | points, face_vertices) 364 | else: 365 | cur_dist, cur_face_idx, cur_dist_type = _unbatched_naive_point_to_mesh_distance( 366 | points, face_vertices) 367 | 368 | hit_point = _find_closest_point(points, face_vertices, cur_face_idx, cur_dist_type) 369 | 370 | dist = torch.where (sign, -torch.sqrt(cur_dist), torch.sqrt(cur_dist)) 371 | 372 | 373 | return dist[...,None], hit_point, cur_face_idx 374 | 375 | 376 | def batched_closest_point_fast( 377 | V : torch.Tensor, 378 | F : torch.Tensor, 379 | points : torch.Tensor): 380 | 381 | """Returns the closest texture for a set of points. 382 | 383 | V (torch.FloatTensor): mesh vertices of shape [B, V, 3] 384 | F (torch.LongTensor): mesh face indices of shape [F, 3] 385 | points (torch.FloatTensor): sample locations of shape [B, N, 3] 386 | 387 | Returns: 388 | (torch.FloatTensor): distances of shape [B, N, 1] 389 | (torch.FloatTensor): projected points of shape [B, N, 3] 390 | (torch.FloatTensor): face indices of shape [B, N, 1] 391 | """ 392 | 393 | batch_size = V.shape[0] 394 | 395 | dists = [] 396 | indices = [] 397 | weights = [] 398 | pts = [] 399 | 400 | for i in range(batch_size): 401 | cur_dist, hit_point, cur_face_idx = closest_point_fast (V[i], F, points[i]) 402 | hitface = F[cur_face_idx.view(-1)] # [ N , 3] 403 | 404 | dists.append(cur_dist) 405 | pts.append(hit_point) 406 | indices.append(cur_face_idx) 407 | weights.append(barycentric_coordinates(hit_point, V[i][hitface[:,0]], 408 | V[i][hitface[:,1]], V[i][hitface[:,2]])) 409 | 410 | return torch.stack(dists, dim=0), torch.stack(pts, dim=0), \ 411 | torch.stack(indices, dim=0), torch.stack(weights, dim=0) -------------------------------------------------------------------------------- /lib/ops/mesh/closest_tex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import numpy as np 11 | from .barycentric_coordinates import barycentric_coordinates 12 | from .closest_point import closest_point, closest_point_fast 13 | from .sample_tex import sample_tex 14 | from .per_face_normals import per_face_normals 15 | 16 | import time 17 | def closest_tex( 18 | V : torch.Tensor, 19 | F : torch.Tensor, 20 | TV : torch.Tensor, 21 | TF : torch.Tensor, 22 | materials, 23 | points : torch.Tensor): 24 | """Returns the closest texture for a set of points. 
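Query points are projected onto the mesh with closest_point_fast, barycentric weights are computed on the hit triangles, and the texture (or a flat diffuse color) is sampled at the interpolated UVs; per-face normals and signed distances are returned alongside the RGB values.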
25 | 26 | V (torch.FloatTensor): mesh vertices of shape [V, 3] 27 | F (torch.LongTensor): mesh face indices of shape [F, 3] 28 | TV (torch.FloatTensor): 29 | TF (torch.FloatTensor): 30 | materials: 31 | points (torch.FloatTensor): sample locations of shape [N, 3] 32 | 33 | Returns: 34 | (torch.FloatTensor): texture samples of shape [N, 3] 35 | """ 36 | 37 | TV = TV.cuda() 38 | TF = TF.cuda() 39 | points = points.to(V.device) 40 | 41 | with torch.no_grad(): 42 | dist, hit_pts, hit_tidx = closest_point_fast(V, F, points) 43 | 44 | hit_F = F[hit_tidx] 45 | hit_V = V[hit_F].cuda() 46 | nrm = per_face_normals(V, hit_F).cuda() 47 | 48 | BC = barycentric_coordinates(hit_pts.cuda(), hit_V[:,0], hit_V[:,1], hit_V[:,2]) 49 | 50 | hit_TF = TF[hit_tidx] 51 | hit_TM = hit_TF[...,3] 52 | hit_TF = hit_TF[...,:3] 53 | 54 | if TV.shape[0] > 0: 55 | hit_TV = TV[hit_TF] 56 | hit_Tp = (hit_TV * BC.unsqueeze(-1)).sum(1) 57 | else: 58 | hit_Tp = BC 59 | rgb = sample_tex(hit_Tp, hit_TM, materials) 60 | 61 | return rgb, nrm, dist 62 | -------------------------------------------------------------------------------- /lib/ops/mesh/compute_sdf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import math 10 | import contextlib 11 | import os 12 | import sys 13 | 14 | import torch 15 | import numpy as np 16 | import wisp._C as _C 17 | 18 | def compute_sdf( 19 | V : torch.Tensor, 20 | F : torch.Tensor, 21 | points : torch.Tensor, 22 | split_size : int = 10**6): 23 | """Computes SDF given point samples and a mesh. 24 | 25 | Args: 26 | V (torch.FloatTensor): #V, 3 array of vertices 27 | F (torch.LongTensor): #F, 3 array of indices 28 | points (torch.FloatTensor): [N, 3] array of points to sample 29 | split_size (int): The batch at which the SDF will be computed. The kernel will break for too large 30 | batches; when in doubt use the default. 31 | 32 | Returns: 33 | (torch.FloatTensor): [N, 1] array of computed SDF values. 34 | """ 35 | mesh = V[F] 36 | 37 | _points = torch.split(points, split_size) 38 | sdfs = [] 39 | for _p in _points: 40 | sdfs.append(_C.external.mesh_to_sdf_cuda(_p.cuda().contiguous(), mesh.cuda().contiguous())[0]) 41 | return torch.cat(sdfs)[...,None] 42 | -------------------------------------------------------------------------------- /lib/ops/mesh/load_obj.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 
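# A minimal usage sketch; the file path, `query_points`, and the pairing with closest_tex are
# illustrative assumptions rather than part of this module:
#
#     V, F, TV, TF, materials = load_obj('scan.obj', load_materials=True)
#     rgb, nrm, sdf = closest_tex(V.cuda(), F.cuda(), TV, TF, materials, query_points)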
8 | 9 | import os 10 | import sys 11 | 12 | import numpy as np 13 | import tinyobjloader 14 | import torch 15 | 16 | from PIL import Image 17 | 18 | import logging as log 19 | import time 20 | Image.MAX_IMAGE_PIXELS = None 21 | 22 | # Refer to 23 | # https://github.com/tinyobjloader/tinyobjloader/blob/master/tiny_obj_loader.h 24 | # for conventions for tinyobjloader data structures. 25 | 26 | texopts = [ 27 | 'ambient_texname', 28 | 'diffuse_texname', 29 | 'specular_texname', 30 | 'specular_highlight_texname', 31 | 'bump_texname', 32 | 'displacement_texname', 33 | 'alpha_texname', 34 | 'reflection_texname', 35 | 'roughness_texname', 36 | 'metallic_texname', 37 | 'sheen_texname', 38 | 'emissive_texname', 39 | 'normal_texname' 40 | ] 41 | 42 | def load_mat(fname : str): 43 | """Loads material. 44 | """ 45 | img = torch.ByteTensor(np.array(Image.open(fname))) 46 | #img = torch.ByteTensor(np.array(Image.open(fname).resize((2048,2048), Image.ANTIALIAS))) 47 | #img = img / 255.0 48 | 49 | return img 50 | 51 | 52 | def load_obj( 53 | fname : str, 54 | load_materials : bool = False): 55 | """Load .obj file using TinyOBJ and extract info. 56 | This is more robust since it can triangulate polygon meshes 57 | with up to 255 sides per face. 58 | 59 | Args: 60 | fname (str): path to Wavefront .obj file 61 | """ 62 | 63 | assert os.path.exists(fname), \ 64 | 'Invalid file path and/or format, must be an existing Wavefront .obj' 65 | 66 | reader = tinyobjloader.ObjReader() 67 | config = tinyobjloader.ObjReaderConfig() 68 | config.triangulate = True # Ensure we don't have any polygons 69 | 70 | reader.ParseFromFile(fname, config) 71 | 72 | # Get vertices 73 | attrib = reader.GetAttrib() 74 | vertices = torch.FloatTensor(attrib.vertices).reshape(-1, 3) 75 | 76 | # Get triangle face indices 77 | shapes = reader.GetShapes() 78 | faces = [] 79 | for shape in shapes: 80 | faces += [idx.vertex_index for idx in shape.mesh.indices] 81 | faces = torch.LongTensor(faces).reshape(-1, 3) 82 | 83 | mats = {} 84 | 85 | if load_materials: 86 | # Load per-faced texture coordinate indices 87 | texf = [] 88 | matf = [] 89 | for shape in shapes: 90 | texf += [idx.texcoord_index for idx in shape.mesh.indices] 91 | matf.extend(shape.mesh.material_ids) 92 | # texf stores [tex_idx0, tex_idx1, tex_idx2, mat_idx] 93 | texf = torch.LongTensor(texf).reshape(-1, 3) 94 | matf = torch.LongTensor(matf).reshape(-1, 1) 95 | texf = torch.cat([texf, matf], dim=-1) 96 | 97 | # Load texcoords 98 | texv = torch.FloatTensor(attrib.texcoords).reshape(-1, 2) 99 | 100 | # Load texture maps 101 | parent_path = os.path.dirname(fname) 102 | materials = reader.GetMaterials() 103 | for i, material in enumerate(materials): 104 | mats[i] = {} 105 | diffuse = getattr(material, 'diffuse') 106 | if diffuse != '': 107 | mats[i]['diffuse'] = torch.FloatTensor(diffuse) 108 | 109 | for texopt in texopts: 110 | mat_path = getattr(material, texopt) 111 | if mat_path != '': 112 | img = load_mat(os.path.join(parent_path, mat_path)) 113 | mats[i][texopt] = img 114 | #mats[i][texopt.split('_')[0]] = img 115 | return vertices, faces, texv, texf, mats 116 | 117 | return vertices, faces 118 | 119 | -------------------------------------------------------------------------------- /lib/ops/mesh/normalize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | 11 | def normalize( 12 | V : torch.Tensor, 13 | F : torch.Tensor, 14 | mode : str): 15 | """Normalizes a mesh. 16 | 17 | Args: 18 | V (torch.FloatTensor): Vertices of shape [V, 3] 19 | F (torch.LongTensor): Faces of shape [F, 3] 20 | mode (str): Different methods of normalization. 21 | 22 | Returns: 23 | (torch.FloatTensor, torch.LongTensor): 24 | - Normalized Vertices 25 | - Faces 26 | """ 27 | 28 | if mode == 'sphere': 29 | 30 | V_max, _ = torch.max(V, dim=0) 31 | V_min, _ = torch.min(V, dim=0) 32 | V_center = (V_max + V_min) / 2. 33 | V = V - V_center 34 | 35 | # Find the max distance to origin 36 | max_dist = torch.sqrt(torch.max(torch.sum(V**2, dim=-1))) 37 | V_scale = 1. / max_dist 38 | V *= V_scale 39 | return V, F 40 | 41 | elif mode == 'aabb': 42 | 43 | V_min, _ = torch.min(V, dim=0) 44 | V = V - V_min 45 | 46 | max_dist = torch.max(V) 47 | V *= 1.0 / max_dist 48 | 49 | V = V * 2.0 - 1.0 50 | 51 | return V, F 52 | 53 | elif mode == 'planar': 54 | 55 | V_min, _ = torch.min(V, dim=0) 56 | V = V - V_min 57 | 58 | x_max = torch.max(V[...,0]) 59 | z_max = torch.max(V[...,2]) 60 | 61 | V[...,0] *= 1.0 / x_max 62 | V[...,2] *= 1.0 / z_max 63 | 64 | max_dist = torch.max(V) 65 | V[...,1] *= 1.0 / max_dist 66 | #V *= 1.0 / max_dist 67 | 68 | V = V * 2.0 - 1.0 69 | 70 | y_min = torch.min(V[...,1]) 71 | 72 | V[...,1] -= y_min 73 | 74 | return V, F 75 | 76 | elif mode == 'none': 77 | 78 | return V, F 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /lib/ops/mesh/per_face_normals.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | 11 | def per_face_normals( 12 | V : torch.Tensor, 13 | F : torch.Tensor): 14 | """Compute normals per face. 15 | 16 | Args: 17 | V (torch.FloatTensor): Vertices of shape [V, 3] 18 | F (torch.LongTensor): Faces of shape [F, 3] 19 | 20 | Returns: 21 | (torch.FloatTensor): Normals of shape [F, 3] 22 | """ 23 | mesh = V[F] 24 | 25 | vec_a = mesh[:, 0] - mesh[:, 1] 26 | vec_b = mesh[:, 1] - mesh[:, 2] 27 | normals = torch.cross(vec_a, vec_b) 28 | return torch.nn.functional.normalize( 29 | normals, eps=1e-6, dim=1 30 | ) 31 | 32 | -------------------------------------------------------------------------------- /lib/ops/mesh/per_vertex_normals.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | 11 | def per_vertex_normals( 12 | V : torch.Tensor, 13 | F : torch.Tensor): 14 | """Compute normals per face. 15 | 16 | Args: 17 | V (torch.FloatTensor): Vertices of shape [V, 3] 18 | F (torch.LongTensor): Faces of shape [F, 3] 19 | 20 | Returns: 21 | (torch.FloatTensor): Normals of shape [F, 3] 22 | """ 23 | verts_normals = torch.zeros_like(V) 24 | mesh = V[F] 25 | 26 | faces_normals = torch.cross( 27 | mesh[:, 2] - mesh[:, 1], 28 | mesh[:, 0] - mesh[:, 1], 29 | dim=1, 30 | ) 31 | 32 | verts_normals.index_add_(0, F[:, 0], faces_normals) 33 | verts_normals.index_add_(0, F[:, 1], faces_normals) 34 | verts_normals.index_add_(0, F[:, 2], faces_normals) 35 | 36 | return torch.nn.functional.normalize( 37 | verts_normals, eps=1e-6, dim=1 38 | ) -------------------------------------------------------------------------------- /lib/ops/mesh/point_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from .sample_near_surface import sample_near_surface 11 | from .sample_surface import sample_surface 12 | from .sample_uniform import sample_uniform 13 | from .area_weighted_distribution import area_weighted_distribution 14 | 15 | def point_sample( 16 | V : torch.Tensor, 17 | F : torch.Tensor, 18 | techniques : list, 19 | num_samples : int, 20 | variance: float = 0.005): 21 | """Sample points from a mesh. 22 | 23 | Args: 24 | V (torch.Tensor): #V, 3 array of vertices 25 | F (torch.Tensor): #F, 3 array of indices 26 | techniques (list[str]): list of techniques to sample with 27 | num_samples (int): points to sample per technique 28 | 29 | Returns: 30 | (torch.FloatTensor): Samples of shape [len(techniques)*num_samples, 3] 31 | """ 32 | if 'trace' in techniques or 'near' in techniques: 33 | # Precompute face distribution 34 | distrib = area_weighted_distribution(V, F) 35 | 36 | samples = [] 37 | for technique in techniques: 38 | if technique =='trace': 39 | samples.append(sample_surface(V, F, num_samples, distrib=distrib)[0]) 40 | elif technique == 'near': 41 | samples.append(sample_near_surface(V, F, num_samples, distrib=distrib, variance=variance)) 42 | elif technique == 'rand': 43 | samples.append(sample_uniform(num_samples).to(V.device)) 44 | samples = torch.cat(samples, dim=0) 45 | return samples 46 | 47 | -------------------------------------------------------------------------------- /lib/ops/mesh/random_face.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from .area_weighted_distribution import area_weighted_distribution 11 | from .per_face_normals import per_face_normals 12 | 13 | def random_face( 14 | V : torch.Tensor, 15 | F : torch.Tensor, 16 | num_samples : int, 17 | distrib=None): 18 | """Return an area weighted random sample of faces and their normals from the mesh. 19 | 20 | Args: 21 | V (torch.Tensor): #V, 3 array of vertices 22 | F (torch.Tensor): #F, 3 array of indices 23 | num_samples (int): num of samples to return 24 | distrib: distribution to use. By default, area-weighted distribution is used. 25 | 26 | Returns: 27 | (torch.LongTensor, torch.FloatTensor): 28 | - Faces 29 | - Normals 30 | """ 31 | if distrib is None: 32 | distrib = area_weighted_distribution(V, F) 33 | 34 | normals = per_face_normals(V, F) 35 | 36 | idx = distrib.sample([num_samples]) 37 | 38 | return F[idx], normals[idx] 39 | 40 | -------------------------------------------------------------------------------- /lib/ops/mesh/sample_near_surface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from .sample_surface import sample_surface 11 | from .area_weighted_distribution import area_weighted_distribution 12 | 13 | def sample_near_surface( 14 | V : torch.Tensor, 15 | F : torch.Tensor, 16 | num_samples: int, 17 | variance : float = 0.005, 18 | distrib=None): 19 | """Sample points near the mesh surface. 20 | 21 | Args: 22 | V (torch.Tensor): #V, 3 array of vertices 23 | F (torch.Tensor): #F, 3 array of indices 24 | num_samples (int): number of surface samples 25 | distrib: distribution to use. By default, area-weighted distribution is used 26 | 27 | Returns: 28 | (torch.FloatTensor): samples of shape [num_samples, 3] 29 | """ 30 | if distrib is None: 31 | distrib = area_weighted_distribution(V, F) 32 | samples = sample_surface(V, F, num_samples, distrib)[0] 33 | samples += torch.randn_like(samples) * variance 34 | return samples 35 | -------------------------------------------------------------------------------- /lib/ops/mesh/sample_surface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | from .random_face import random_face 11 | from .area_weighted_distribution import area_weighted_distribution 12 | 13 | def sample_surface( 14 | V : torch.Tensor, 15 | F : torch.Tensor, 16 | num_samples : int, 17 | distrib = None): 18 | """Sample points and their normals on mesh surface. 19 | 20 | Args: 21 | V (torch.Tensor): #V, 3 array of vertices 22 | F (torch.Tensor): #F, 3 array of indices 23 | num_samples (int): number of surface samples 24 | distrib: distribution to use. By default, area-weighted distribution is used 25 | 26 | Returns: 27 | (torch.FloatTensor, torch.FloatTensor): samples of shape [num_samples, 3] and the corresponding face normals of shape [num_samples, 3] 28 | """ 29 | if distrib is None: 30 | distrib = area_weighted_distribution(V, F) 31 | 32 | # Select faces & sample their surface 33 | fidx, normals = random_face(V, F, num_samples, distrib) 34 | f = V[fidx] 35 | 36 | u = torch.sqrt(torch.rand(num_samples)).to(V.device).unsqueeze(-1) 37 | v = torch.rand(num_samples).to(V.device).unsqueeze(-1) 38 | 39 | samples = (1 - u) * f[:,0,:] + (u * (1 - v)) * f[:,1,:] + u * v * f[:,2,:] 40 | 41 | return samples, normals 42 | 43 | -------------------------------------------------------------------------------- /lib/ops/mesh/sample_tex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import time 12 | def sample_tex( 13 | Tp : torch.Tensor, # points [N, 2] 14 | TM : torch.Tensor, # material indices [N] 15 | materials): 16 | """Sample from a texture. 17 | 18 | Args: 19 | Tp (torch.FloatTensor): 2D coordinates to sample of shape [N, 2] 20 | TM (torch.LongTensor): Indices of the material to sample of shape [N] 21 | materials (list of material): Materials 22 | 23 | Returns: 24 | (torch.FloatTensor): RGB samples of shape [N, 3] 25 | """ 26 | 27 | max_idx = TM.max() 28 | assert max_idx > -1, "No materials detected! Check the material definition on your mesh." 29 | 30 | rgb = torch.zeros(Tp.shape[0], 3, device=Tp.device)
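# Tp holds UV coordinates (typically in [0, 1]); below they are remapped to grid_sample's [-1, 1] convention (with the v axis flipped) and each material's diffuse map is sampled only at the points assigned to it.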
31 | 32 | Tp = (Tp * 2.0) - 1.0 33 | # The y axis is flipped from what UV maps generally expects vs in PyTorch 34 | Tp[...,1] *= -1 35 | 36 | for i in range(max_idx+1): 37 | mask = (TM == i) 38 | if mask.sum() == 0: 39 | continue 40 | if 'diffuse_texname' not in materials[i]: 41 | if 'diffuse' in materials[i]: 42 | rgb[mask] = materials[i]['diffuse'].to(Tp.device) 43 | continue 44 | 45 | map = materials[i]['diffuse_texname'][...,:3].permute(2, 0, 1)[None].float().to(Tp.device) / 255.0 46 | grid = Tp[mask] 47 | grid = grid.reshape(1, grid.shape[0], 1, grid.shape[1]) 48 | _rgb = F.grid_sample(map, grid, mode='bilinear', padding_mode='reflection', align_corners=True) 49 | _rgb = _rgb[0,:,:,0].permute(1,0) 50 | rgb[mask] = _rgb 51 | 52 | 53 | return rgb 54 | 55 | 56 | -------------------------------------------------------------------------------- /lib/ops/mesh/sample_uniform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. 8 | 9 | import torch 10 | 11 | def sample_uniform(num_samples : int): 12 | """Sample uniformly in [-1,1] bounding volume. 13 | 14 | Args: 15 | num_samples(int) : number of points to sample 16 | 17 | Returns: 18 | (torch.FloatTensor): samples of shape [num_samples, 3] 19 | """ 20 | return torch.rand(num_samples, 3) * 2.0 - 1.0 21 | 22 | -------------------------------------------------------------------------------- /lib/utils/camera.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/ops/raygen/raygen.py 2 | 3 | from kaolin.render.camera import Camera 4 | from kaolin.render.camera.intrinsics import CameraFOV 5 | import torch 6 | 7 | ################################ Ray Sampling Function ######################################## 8 | 9 | 10 | def _generate_default_grid(width, height, device=None): 11 | h_coords = torch.arange(height, device=device) 12 | w_coords = torch.arange(width, device=device) 13 | return torch.meshgrid(h_coords, w_coords, indexing='ij') # return pixel_y, pixel_x 14 | 15 | 16 | def generate_centered_pixel_coords(img_width, img_height, res_x=None, res_y=None, device=None): 17 | pixel_y, pixel_x = _generate_default_grid(res_x, res_y, device) 18 | scale_x = 1.0 if res_x is None else float(img_width) / res_x 19 | scale_y = 1.0 if res_y is None else float(img_height) / res_y 20 | pixel_x = pixel_x * scale_x + 0.5 # scale and add bias to pixel center 21 | pixel_y = pixel_y * scale_y + 0.5 # scale and add bias to pixel center 22 | return pixel_y, pixel_x 23 | 24 | 25 | # -- Ray gen -- 26 | 27 | def _to_ndc_coords(pixel_x, pixel_y, camera): 28 | pixel_x = 2 * (pixel_x / camera.width) - 1.0 29 | pixel_y = 2 * (pixel_y / camera.height) - 1.0 30 | return pixel_x, pixel_y 31 | 32 | 33 | def generate_pinhole_rays(camera: Camera, coords_grid: torch.Tensor): 34 | """Default ray generation function for pinhole cameras. 
35 | 36 | This function assumes that the principal point (the pinhole location) is specified by a 37 | displacement (camera.x0, camera.y0) in pixel coordinates from the center of the image. 38 | 39 | The Kaolin camera class does not enforce a coordinate space for how the principal point is specified, 40 | so users will need to make sure that the correct principal point conventions are followed for 41 | the cameras passed into this function. 42 | 43 | Args: 44 | camera (kaolin.render.camera): The camera class. 45 | coords_grid (torch.FloatTensor): Grid of coordinates of shape [H, W, 2]. 46 | 47 | Returns: 48 | (wisp.core.Rays): The generated pinhole rays for the camera. 49 | """ 50 | if camera.device != coords_grid[0].device: 51 | raise Exception(f"Expected camera and coords_grid[0] to be on the same device, but found {camera.device} and {coords_grid[0].device}.") 52 | if camera.device != coords_grid[1].device: 53 | raise Exception(f"Expected camera and coords_grid[1] to be on the same device, but found {camera.device} and {coords_grid[1].device}.") 54 | # coords_grid should remain immutable (a new tensor is implicitly created here) 55 | pixel_y, pixel_x = coords_grid 56 | pixel_x = pixel_x.to(camera.device, camera.dtype) 57 | pixel_y = pixel_y.to(camera.device, camera.dtype) 58 | 59 | # Account for principal point (offsets from the center) 60 | pixel_x = pixel_x - camera.x0 61 | pixel_y = pixel_y + camera.y0 62 | 63 | # pixel values are now in range [-1, 1], both tensors are of shape res_y x res_x 64 | pixel_x, pixel_y = _to_ndc_coords(pixel_x, pixel_y, camera) 65 | 66 | ray_dir = torch.stack((pixel_x * camera.tan_half_fov(CameraFOV.HORIZONTAL), 67 | -pixel_y * camera.tan_half_fov(CameraFOV.VERTICAL), 68 | -torch.ones_like(pixel_x)), dim=-1) 69 | 70 | ray_dir = ray_dir.reshape(-1, 3) # Flatten grid rays to 1D array 71 | ray_orig = torch.zeros_like(ray_dir) 72 | 73 | # Transform from camera to world coordinates 74 | ray_orig, ray_dir = camera.extrinsics.inv_transform_rays(ray_orig, ray_dir) 75 | ray_dir /= torch.linalg.norm(ray_dir, dim=-1, keepdim=True) 76 | ray_orig, ray_dir = ray_orig[0], ray_dir[0] # Assume a single camera 77 | 78 | return ray_orig, ray_dir 79 | 80 | ######################################################################################################################### -------------------------------------------------------------------------------- /lib/utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pprint 3 | import yaml 4 | import logging 5 | 6 | def parse_options(): 7 | 8 | parser = argparse.ArgumentParser(description='Custom Humans Code') 9 | 10 | 11 | ################### 12 | # Global arguments 13 | ################### 14 | global_group = parser.add_argument_group('global') 15 | global_group.add_argument('--config', type=str, default='config.yaml', 16 | help='Path to config file to replace defaults') 17 | global_group.add_argument('--save-root', type=str, default='./checkpoints/', 18 | help="outputs path") 19 | global_group.add_argument('--exp-name', type=str, default='test', 20 | help="Experiment name.") 21 | global_group.add_argument('--seed', type=int, default=123) 22 | global_group.add_argument('--resume', type=str, default=None, 23 | help='Resume from the checkpoint.') 24 | global_group.add_argument( 25 | '--log_level', action='store', type=int, default=logging.INFO, 26 | help='Logging level to use globally, DEBUG: 10, INFO: 20, WARN: 30, ERROR: 40.') 27 | 28 | 
################### 29 | # Arguments for dataset 30 | ################### 31 | data_group = parser.add_argument_group('dataset') 32 | data_group.add_argument('--data-root', type=str, default='CustomHumans.h5', 33 | help='Path to dataset') 34 | data_group.add_argument('--num-samples', type=int, default=20480, 35 | help='Number of samples to use for each subject during training') 36 | data_group.add_argument('--repeat-times', type=int, default=8, 37 | help='Number of times to repeat each subject during training') 38 | 39 | 40 | ################### 41 | # Arguments for optimizer 42 | ################### 43 | optim_group = parser.add_argument_group('optimizer') 44 | optim_group.add_argument('--lr-codebook', type=float, default=0.001, 45 | help='Learning rate for the codebook.') 46 | optim_group.add_argument('--lr-decoder', type=float, default=0.001, 47 | help='Learning rate for the decoder.') 48 | optim_group.add_argument('--lr-dis', type=float, default=0.004, 49 | help='Learning rate for the discriminator.') 50 | optim_group.add_argument('--beta1', type=float, default=0.5, 51 | help='Beta1.') 52 | optim_group.add_argument('--beta2', type=float, default=0.999, 53 | help='Beta2.') 54 | optim_group.add_argument('--weight-decay', type=float, default=0, 55 | help='Weight decay.') 56 | 57 | 58 | ################### 59 | # Arguments for training 60 | ################### 61 | train_group = parser.add_argument_group('train') 62 | train_group.add_argument('--epochs', type=int, default=800, 63 | help='Number of epochs to run the training.') 64 | train_group.add_argument('--batch-size', type=int, default=2, 65 | help='Batch size for the training.') 66 | train_group.add_argument('--workers', type=int, default=0, 67 | help='Number of workers for the data loader. 0 means single process.') 68 | train_group.add_argument('--save-every', type=int, default=50, 69 | help='Save the model at every N epoch.') 70 | train_group.add_argument('--log-every', type=int, default=100, 71 | help='write logs to wandb at every N iters') 72 | train_group.add_argument('--use-2d-from-epoch', type=int, default=-1, 73 | help='Adding 2D loss from this epoch. 
-1 indicates not using 2D loss.') 74 | train_group.add_argument('--train-2d-every-iter', type=int, default=1, 75 | help='Train 2D loss every N iterations.') 76 | train_group.add_argument('--use-nrm-dis', action='store_true', 77 | help='train with normal loss discriminator.') 78 | train_group.add_argument('--use-cached-pts', action='store_true', 79 | help='Use cached point coordinates instead of online raytracing during training.') 80 | 81 | ################### 82 | # Arguments for Feature Dictionary 83 | ################### 84 | sample_group = parser.add_argument_group('dictionary') 85 | sample_group.add_argument('--shape-dim', type=int, default=32, 86 | help='Dimension of the shape feature code.') 87 | sample_group.add_argument('--color-dim', type=int, default=32, 88 | help='Dimension of the color feature code.') 89 | sample_group.add_argument('--feature-std', type=float, default=0.1, 90 | help='Standard deviation for initializing the feature code.') 91 | sample_group.add_argument('--feature-bias', type=float, default=0.1, 92 | help='Bias for initializing the feature code.') 93 | sample_group.add_argument('--shape-pca-dim', type=int, default=8, 94 | help='Dimension of the shape pca code.') 95 | sample_group.add_argument('--color-pca-dim', type=int, default=16, 96 | help='Dimension of the color pca code.') 97 | 98 | ################### 99 | # Arguments for Network 100 | ################### 101 | net_group = parser.add_argument_group('network') 102 | net_group.add_argument('--pos-dim', type=int, default=3, 103 | help='input position dimension') 104 | net_group.add_argument('--c-dim', type=int, default=0, 105 | help='conditional input dimension, if 0, no conditional input') 106 | net_group.add_argument('--num-layers', type=int, default=4, 107 | help='Number of layers for the MLPs.') 108 | net_group.add_argument('--hidden-dim', type=int, default=128, 109 | help='Network width') 110 | net_group.add_argument('--activation', type=str, default='relu', 111 | choices=['relu', 'sin', 'softplus', 'lrelu']) 112 | net_group.add_argument('--layer-type', type=str, default='none', 113 | choices=['none', 'spectral_norm', 'frobenius_norm', 'l_1_norm', 'l_inf_norm']) 114 | net_group.add_argument('--skip', type=int, nargs='*', default=[2], 115 | help='Layer to have skip connection.') 116 | 117 | ################### 118 | # Embedder arguments 119 | ################### 120 | embedder_group = parser.add_argument_group('embedder') 121 | embedder_group.add_argument('--shape-freq', type=int, default=5, 122 | help='log2 of max freq') 123 | embedder_group.add_argument('--color-freq', type=int, default=10, 124 | help='log2 of max freq') 125 | 126 | 127 | ################### 128 | # Losses arguments 129 | ################### 130 | embedder_group = parser.add_argument_group('losses') 131 | embedder_group.add_argument('--lambda-sdf', type=float, default=1000, 132 | help='lambda for sdf loss') 133 | embedder_group.add_argument('--lambda-rgb', type=float, default=150, 134 | help='lambda for rgb loss') 135 | embedder_group.add_argument('--lambda-nrm', type=float, default=10, 136 | help='lambda for normal loss') 137 | embedder_group.add_argument('--lambda-reg', type=float, default=1, 138 | help='lambda for regularization loss') 139 | embedder_group.add_argument('--gan-loss-type', type=str, default='logistic', 140 | choices=['logistic', 'hinge'], 141 | help='loss type for gan loss') 142 | embedder_group.add_argument('--lambda-gan', type=float, default=1, 143 | help='lambda for gan loss') 144 | 
embedder_group.add_argument('--lambda-grad', type=float, default=10, 145 | help='lambda for gradient penalty') 146 | 147 | ################### 148 | # Arguments for validation 149 | ################### 150 | valid_group = parser.add_argument_group('validation') 151 | valid_group.add_argument('--valid-every', type=int, default=10, 152 | help='Frequency of running validation.') 153 | valid_group.add_argument('--subdivide', type=bool, default=True, 154 | help='Subdivide the mesh before marching cubes') 155 | valid_group.add_argument('--grid-size', type=int, default=300, 156 | help='Grid size for marching cubes') 157 | valid_group.add_argument('--width', type=int, default=1024, 158 | help='Image width (height) for rendering') 159 | valid_group.add_argument('--fov', type=float, default=20.0, 160 | help='Field of view for rendering') 161 | valid_group.add_argument('--n_views', type=int, default=4, 162 | help='Number of views for rendering') 163 | 164 | ################### 165 | # Arguments for wandb 166 | ################### 167 | wandb_group = parser.add_argument_group('wandb') 168 | 169 | wandb_group.add_argument('--wandb-id', type=str, default=None, 170 | help='wandb id') 171 | wandb_group.add_argument('--wandb', action='store_true', 172 | help='Use wandb') 173 | wandb_group.add_argument('--wandb-name', default='default', type=str, 174 | help='wandb project name') 175 | 176 | return parser 177 | 178 | 179 | def parse_yaml_config(config_path, parser): 180 | """Parses and sets the parser defaults with a yaml config file. 181 | 182 | Args: 183 | config_path : path to the yaml config file. 184 | parser : The parser for which the defaults will be set. 185 | A ValueError is raised if the config contains a field that is not a valid parser option. 186 | """ 187 | with open(config_path) as f: 188 | config_dict = yaml.safe_load(f) 189 | 190 | list_of_valid_fields = [] 191 | for group in parser._action_groups: 192 | list_of_valid_fields.extend(a.dest for a in group._group_actions) 193 | list_of_valid_fields = set(list_of_valid_fields) 194 | 195 | defaults_dict = {} 196 | 197 | # Flatten the grouped config entries and use them to override the parser defaults. 198 | # The yaml file mirrors the parser's argument groups, which aren't actually nested in the parser. 199 | for key in config_dict: 200 | for field in config_dict[key]: 201 | if field not in list_of_valid_fields: 202 | raise ValueError( 203 | f"ERROR: {field} is not a valid option. Check for typos in the config." 204 | ) 205 | defaults_dict[field] = config_dict[key][field] 206 | 207 | 208 | parser.set_defaults(**defaults_dict) 209 | 210 | def argparse_to_str(parser, args=None): 211 | """Convert parser to string representation for logging. 212 | 213 | Args: 214 | parser (argparse.parser): Parser object. Needed for the argument groups. 215 | args : The parsed arguments. Will compute from the parser if None. 216 | 217 | Returns: 218 | args : The parsed arguments. 219 | args_str : The string to be printed. 
220 | """ 221 | 222 | if args is None: 223 | args = parser.parse_args() 224 | 225 | if args.config is not None: 226 | parse_yaml_config(args.config, parser) 227 | 228 | args = parser.parse_args() 229 | 230 | args_dict = {} 231 | for group in parser._action_groups: 232 | group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions} 233 | args_dict[group.title] = vars(argparse.Namespace(**group_dict)) 234 | 235 | pp = pprint.PrettyPrinter(indent=2) 236 | args_str = pp.pformat(args_dict) 237 | args_str = f'```{args_str}```' 238 | 239 | return args, args_str -------------------------------------------------------------------------------- /lib/utils/image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import cv2 4 | 5 | import torchvision.transforms as transforms 6 | 7 | def update_edited_images(image_path, pickle_path): 8 | with open(pickle_path, 'rb') as f: 9 | data = pickle.load(f) 10 | 11 | img_list = [ os.path.join(image_path, f) for f in sorted(os.listdir(image_path)) if f.endswith('.png') ] 12 | transform = transforms.Compose([ 13 | transforms.ToTensor() 14 | ]) 15 | for i, img in enumerate(img_list): 16 | rgb_img = cv2.imread(img) 17 | rgb_img = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB) 18 | rgb = transform(rgb_img).permute(1,2,0).view(-1, 3) 19 | data['rgb'][i] = rgb 20 | 21 | return data -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | h5py 3 | smplx 4 | wandb 5 | trimesh 6 | opencv-python 7 | git+https://github.com/tinyobjloader/tinyobjloader.git@v2.0.0rc8#subdirectory=python 8 | 9 | --extra-index-url https://download.pytorch.org/whl/cu113 10 | torch==1.11.0+cu113 11 | torchvision==0.12.0+cu113 12 | 13 | -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.11.0_cu113.html 14 | kaolin==0.12.0 -------------------------------------------------------------------------------- /smplx/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/custom-humans/editable-humans/97ac85b1e5c995ca0c7a16b2a3887992aba838d0/smplx/.gitkeep -------------------------------------------------------------------------------- /tools/align_thuman.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import trimesh 5 | from PIL import Image 6 | import pickle 7 | import cv2 8 | from tqdm import tqdm 9 | from smplx import SMPLX 10 | import json 11 | device = torch.device('cuda') 12 | 13 | SMPLX_PATH = 'smplx' 14 | OUT_PATH = 'new_thuman' 15 | 16 | body_model = SMPLX(model_path=SMPLX_PATH, num_pca_comps=12,gender='male') 17 | 18 | for id in tqdm(range(526)): 19 | name_id = "%04d" % id 20 | input_file = os.path.join('THuman2.0', name_id, name_id + '.obj') 21 | tex_file = os.path.join('THuman2.0', name_id, 'material0.jpeg') 22 | smpl_file = os.path.join('THuman2.0_smplx', name_id, 'smplx_param.pkl') 23 | 24 | smpl_data = pickle.load(open(smpl_file,'rb')) 25 | out_file_name = os.path.splitext(os.path.basename(input_file))[0] 26 | output_aligned_path = os.path.join(OUT_PATH, out_file_name) 27 | os.makedirs(output_aligned_path, exist_ok=True) 28 | 29 | 30 | textured_mesh = trimesh.load(input_file) 31 | 32 | 33 | output = body_model(body_pose = torch.tensor(smpl_data['body_pose']), 34 | betas = 
torch.tensor(smpl_data['betas']), 35 | left_hand_pose = torch.tensor(smpl_data['left_hand_pose']), 36 | right_hand_pose = torch.tensor(smpl_data['right_hand_pose']), 37 | ) 38 | J_0 = output.joints.detach().cpu().numpy()[0,0,:] 39 | 40 | d = trimesh.Trimesh(vertices=output.vertices.detach().cpu().numpy()[0] -J_0 , 41 | faces=body_model.faces) 42 | 43 | 44 | R = np.asarray(smpl_data['global_orient'][0]) 45 | rot_mat = np.zeros(shape=(3,3)) 46 | rot_mat, _ = cv2.Rodrigues(R) 47 | scale = smpl_data['scale'] 48 | 49 | T = -np.asarray(smpl_data['translation']) 50 | S = np.eye(4) 51 | S[:3, 3] = T 52 | textured_mesh.apply_transform(S) 53 | 54 | S = np.eye(4) 55 | S[:3, :3] *= 1./scale 56 | textured_mesh.apply_transform(S) 57 | 58 | T = -J_0 59 | S = np.eye(4) 60 | S[:3, 3] = T 61 | textured_mesh.apply_transform(S) 62 | 63 | S = np.eye(4) 64 | S[:3, :3] = np.linalg.inv(rot_mat) 65 | textured_mesh.apply_transform(S) 66 | 67 | 68 | 69 | visual = trimesh.visual.texture.TextureVisuals(uv=textured_mesh.visual.uv, image=Image.open(tex_file)) 70 | 71 | t = trimesh.Trimesh(vertices=textured_mesh.vertices, 72 | faces=textured_mesh.faces, 73 | vertex_normals=textured_mesh.vertex_normals, 74 | visual=visual) 75 | 76 | #t = t.simplify_quadratic_decimation(50000) 77 | #t.visual.material.name = out_file_name 78 | 79 | 80 | d.export(os.path.join(output_aligned_path, out_file_name + '_smplx.obj') ) 81 | t.export(os.path.join(output_aligned_path, out_file_name + '.obj') ) 82 | with open(os.path.join(output_aligned_path, out_file_name + '.mtl'), 'w') as f: 83 | f.write('newmtl {}\n'.format(out_file_name)) 84 | f.write('map_Kd {}.jpeg\n'.format(out_file_name)) 85 | 86 | result = {} 87 | result ['transl'] = [0.,0.,0.] 88 | for key, val in smpl_data.items(): 89 | if key not in ['scale', 'translation']: 90 | result[key] = val[0].tolist() 91 | 92 | json_file = os.path.join(output_aligned_path, out_file_name + '_smplx.json') 93 | json.dump(result, open(json_file, 'w'), indent=4) 94 | -------------------------------------------------------------------------------- /tools/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import scipy as sp 4 | import numpy as np 5 | import argparse 6 | import trimesh 7 | 8 | 9 | def calculate_iou(gt, prediction): 10 | intersection = torch.logical_and(gt, prediction) 11 | union = torch.logical_or(gt, prediction) 12 | return torch.sum(intersection) / torch.sum(union) 13 | 14 | def compute_surface_metrics(mesh_pred, mesh_gt): 15 | """Compute surface metrics (chamfer distance and f-score) for one example. 16 | Args: 17 | mesh_pred, mesh_gt: trimesh.Trimesh, the predicted and ground-truth meshes. 18 | Returns: 19 | pred_to_gt, gt_to_pred: floats, directed chamfer distances (scaled by 1000). 20 | normal_consistency, fscore: floats. 
21 | """ 22 | # Chamfer 23 | eval_points = 1000000 24 | 25 | point_gt, idx_gt = mesh_gt.sample(eval_points, return_index=True) 26 | normal_gt = mesh_gt.face_normals[idx_gt] 27 | point_gt = point_gt.astype(np.float32) 28 | 29 | point_pred, idx_pred = mesh_pred.sample(eval_points, return_index=True) 30 | normal_pred = mesh_pred.face_normals[idx_pred] 31 | point_pred = point_pred.astype(np.float32) 32 | 33 | dist_pred_to_gt, normal_pred_to_gt = distance_field_helper(point_pred, point_gt, normal_pred, normal_gt) 34 | dist_gt_to_pred, normal_gt_to_pred = distance_field_helper(point_gt, point_pred, normal_gt, normal_pred) 35 | 36 | # TODO: divide by 2 following OccNet 37 | # https://github.com/autonomousvision/occupancy_networks/blob/406f79468fb8b57b3e76816aaa73b1915c53ad22/im2mesh/eval.py#L136 38 | chamfer_l1 = np.mean(dist_pred_to_gt) + np.mean(dist_gt_to_pred) 39 | 40 | c1 = np.mean(dist_pred_to_gt) 41 | c2 = np.mean(dist_gt_to_pred) 42 | 43 | normal_consistency = np.mean(normal_pred_to_gt) + np.mean(normal_gt_to_pred) 44 | 45 | # Fscore 46 | tau = 1e-4 47 | eps = 1e-9 48 | 49 | dist_pred_to_gt = (dist_pred_to_gt**2) 50 | dist_gt_to_pred = (dist_gt_to_pred**2) 51 | 52 | prec_tau = (dist_pred_to_gt <= tau).astype(np.float32).mean() * 100. 53 | recall_tau = (dist_gt_to_pred <= tau).astype(np.float32).mean() * 100. 54 | 55 | fscore = (2 * prec_tau * recall_tau) / max(prec_tau + recall_tau, eps) 56 | 57 | # Following the convention of reporting scaled chamfer distances; scaled up by 1000 here. 58 | return c1 * 1000., c2 * 1000., normal_consistency / 2., fscore 59 | 60 | def distance_field_helper(source, target, normals_src=None, normals_tgt=None): 61 | target_kdtree = sp.spatial.cKDTree(target) 62 | distances, idx = target_kdtree.query(source, n_jobs=-1) 63 | 64 | if normals_src is not None and normals_tgt is not None: 65 | 66 | normals_src = \ 67 | normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) 68 | normals_tgt = \ 69 | normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) 70 | 71 | normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) 72 | # Handle normals that point in the wrong direction gracefully 73 | # (mostly because the method does not care about this during generation) 74 | normals_dot_product = np.abs(normals_dot_product) 75 | 76 | else: 77 | normals_dot_product = np.array( 78 | [np.nan] * source.shape[0], dtype=np.float32) 79 | 80 | return distances, normals_dot_product 81 | 82 | 83 | 84 | def main(args): 85 | 86 | input_subfolder = [x for x in sorted(os.listdir(args.input_path)) if x.endswith('obj')] 87 | gt_subfolder = [x for x in sorted(os.listdir(args.gt_path)) if x.endswith('obj')] 88 | 89 | mean_c1 = 0. 90 | mean_c2 = 0. 91 | mean_fscore = 0. 92 | mean_normal_consistency = 0. 
93 | 94 | for pred, gt in zip(input_subfolder, gt_subfolder): 95 | mesh_pred = trimesh.load(os.path.join(args.input_path, pred)) 96 | mesh_gt = trimesh.load(os.path.join(args.gt_path, gt)) 97 | 98 | pred_2_scan, scan_2_pred, normal_consistency, fscore = compute_surface_metrics(mesh_pred, mesh_gt) 99 | print('Chamfer: {:.3f}, {:.3f}, Normal Consistency: {:.3f}, Fscore: {:.3f}'.format(pred_2_scan, scan_2_pred, normal_consistency, fscore)) 100 | mean_c1 += pred_2_scan 101 | mean_c2 += scan_2_pred 102 | mean_fscore += fscore 103 | mean_normal_consistency += normal_consistency 104 | 105 | mean_c1 /= len(input_subfolder) 106 | mean_c2 /= len(input_subfolder) 107 | mean_fscore /= len(input_subfolder) 108 | mean_normal_consistency /= len(input_subfolder) 109 | print('Mean Chamfer: {:.3f}, {:.3f}, Normal Consistency: {:.3f}, Fscore: {:.3f}'.format(mean_c1, mean_c2, mean_normal_consistency, mean_fscore)) 110 | print('{:.6f}, {:.6f}, {:.6f}, {:.6f}'.format(mean_c1, mean_c2, mean_normal_consistency, mean_fscore)) 111 | 112 | if __name__ == '__main__': 113 | 114 | parser = argparse.ArgumentParser() 115 | 116 | parser.add_argument('-i', '--input_path', required=True ,type=str) 117 | parser.add_argument('-g', '--gt_path', required=True ,type=str) 118 | 119 | main(parser.parse_args()) 120 | -------------------------------------------------------------------------------- /tools/load_json_to_smplx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from smplx import SMPLX 3 | import torch 4 | import json 5 | import trimesh 6 | import argparse 7 | 8 | SMPL_PATH = 'body_model/smplx/' 9 | ''' 10 | We use the following minimal code snippet to generate the SMPL-X model across all our scans 11 | ''' 12 | def main(args): 13 | 14 | smpl_data = json.load(open(os.path.join(args.input_file))) 15 | 16 | 17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 18 | 19 | param_betas = torch.tensor(smpl_data['betas'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 20 | param_poses = torch.tensor(smpl_data['body_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 21 | param_left_hand_pose = torch.tensor(smpl_data['left_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 22 | param_right_hand_pose = torch.tensor(smpl_data['right_hand_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 23 | 24 | param_expression = torch.tensor(smpl_data['expression'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 25 | param_jaw_pose = torch.tensor(smpl_data['jaw_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 26 | param_leye_pose = torch.tensor(smpl_data['leye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 27 | param_reye_pose = torch.tensor(smpl_data['reye_pose'], dtype=torch.float32, device=device).unsqueeze(0).contiguous() 28 | 29 | 30 | body_model = SMPLX(model_path=SMPL_PATH, gender='male', use_pca=True, num_pca_comps=12, flat_hand_mean=True).to(device) 31 | 32 | J_0 = body_model(body_pose = param_poses, betas=param_betas).joints.contiguous().detach() 33 | 34 | 35 | output = body_model(betas=param_betas, 36 | body_pose=param_poses, 37 | transl=-J_0[:,0,:], 38 | left_hand_pose=param_left_hand_pose, 39 | right_hand_pose=param_right_hand_pose, 40 | expression=param_expression, 41 | jaw_pose=param_jaw_pose, 42 | leye_pose=param_leye_pose, 43 | reye_pose=param_reye_pose, 44 | ) 45 | 46 | d = 
trimesh.Trimesh(vertices=output.vertices.detach().cpu().numpy()[0], faces=body_model.faces) 47 | d.export('smplx.obj') 48 | 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser(description='Minimal code snippet to generate SMPL-X mesh from json file') 53 | 54 | parser.add_argument("-i", "--input-file", default='./mesh-f00021.json', type=str, help="Input json file") 55 | 56 | main(parser.parse_args()) 57 | -------------------------------------------------------------------------------- /tools/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | 5 | DATASET_PATH = 'CustomHumans' 6 | OUTPUT_PATH = 'CustomHumans/training_dataset' 7 | os.makedirs(OUTPUT_PATH, exist_ok=True) 8 | 9 | mesh_path = { x.split('_')[0]:x for x in sorted(os.listdir(os.path.join(DATASET_PATH, 'mesh'))) } 10 | subject_idx = json.load(open('data/Custom_train.json')) 11 | 12 | for idx in subject_idx: 13 | folder_name = mesh_path[idx] 14 | shutil.copytree(os.path.join(DATASET_PATH, 'mesh', folder_name), os.path.join(OUTPUT_PATH, idx), dirs_exist_ok = True) 15 | shutil.copytree(os.path.join(DATASET_PATH, 'smplx', folder_name), os.path.join(OUTPUT_PATH, idx), dirs_exist_ok = True) 16 | 17 | 18 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from datetime import datetime 3 | import logging as log 4 | import numpy as np 5 | import torch 6 | import random 7 | import shutil 8 | import tempfile 9 | import wandb 10 | import pickle 11 | 12 | from torch.utils.data import DataLoader 13 | from lib.datasets.customhumans_dataset import CustomHumanDataset 14 | from lib.models.trainer import Trainer 15 | from lib.models.evaluator import Evaluator 16 | from lib.utils.config import * 17 | 18 | 19 | def create_archive(save_dir, config): 20 | 21 | with tempfile.TemporaryDirectory() as tmpdir: 22 | 23 | shutil.copy(config, os.path.join(tmpdir, 'config.yaml')) 24 | shutil.copy('train.py', os.path.join(tmpdir, 'train.py')) 25 | shutil.copy('test.py', os.path.join(tmpdir, 'test.py')) 26 | 27 | shutil.copytree( 28 | os.path.join('lib'), 29 | os.path.join(tmpdir, 'lib'), 30 | ignore=shutil.ignore_patterns('__pycache__')) 31 | 32 | shutil.make_archive( 33 | os.path.join(save_dir, 'code_copy'), 34 | 'zip', 35 | tmpdir) 36 | 37 | 38 | def main(config): 39 | 40 | # Set random seed. 41 | random.seed(config.seed) 42 | np.random.seed(config.seed) 43 | torch.manual_seed(config.seed) 44 | 45 | log_dir = os.path.join( 46 | config.save_root, 47 | config.exp_name, 48 | f'{datetime.now().strftime("%Y%m%d-%H%M%S")}' 49 | ) 50 | 51 | # Backup code. 52 | create_archive(log_dir, config.config) 53 | 54 | # Initialize dataset and dataloader. 
55 | 56 | with open('data/smpl_mesh.pkl', 'rb') as f: 57 | smpl_mesh = pickle.load(f) 58 | 59 | dataset = CustomHumanDataset(config.num_samples, config.repeat_times) 60 | dataset.init_from_h5(config.data_root) 61 | 62 | loader = DataLoader(dataset=dataset, 63 | batch_size=config.batch_size, 64 | shuffle=True, 65 | num_workers=config.workers, 66 | pin_memory=True) 67 | 68 | 69 | trainer = Trainer(config, dataset.smpl_V, smpl_mesh['smpl_F'], log_dir) 70 | 71 | evaluator = Evaluator(config, log_dir) 72 | 73 | 74 | if config.wandb_id is not None: 75 | wandb_id = config.wandb_id 76 | else: 77 | wandb_id = wandb.util.generate_id() 78 | with open(os.path.join(log_dir, 'wandb_id.txt'), 'w+') as f: 79 | f.write(wandb_id) 80 | 81 | wandb_mode = "disabled" if (not config.wandb) else "online" 82 | wandb.init(id=wandb_id, 83 | project=config.wandb_name, 84 | config=config, 85 | name=os.path.basename(log_dir), 86 | resume="allow", 87 | settings=wandb.Settings(start_method="fork"), 88 | mode=wandb_mode, 89 | dir=log_dir) 90 | wandb.watch(trainer) 91 | 92 | if config.resume: 93 | trainer.load_checkpoint(config.resume) 94 | 95 | 96 | global_step = trainer.global_step 97 | start_epoch = trainer.epoch 98 | 99 | 100 | for epoch in range(start_epoch, config.epochs): 101 | for data in loader: 102 | trainer.step(epoch=epoch, n_iter=global_step, data=data) 103 | 104 | if global_step % config.log_every == 0: 105 | trainer.log(global_step, epoch) 106 | 107 | if config.use_2d_from_epoch >= 0 and \ 108 | epoch >= config.use_2d_from_epoch and \ 109 | global_step % config.log_every == 0: 110 | trainer.write_images(global_step) 111 | 112 | global_step += 1 113 | 114 | if epoch % config.save_every == 0: 115 | trainer.save_checkpoint(full=False) 116 | 117 | if epoch % config.valid_every == 0 and epoch > 0: 118 | evaluator.init_models(trainer) 119 | evaluator.reconstruction(32, epoch=epoch) 120 | 121 | wandb.finish() 122 | 123 | if __name__ == "__main__": 124 | 125 | parser = parse_options() 126 | args, args_str = argparse_to_str(parser) 127 | handlers = [log.StreamHandler(sys.stdout)] 128 | log.basicConfig(level=args.log_level, 129 | format='%(asctime)s|%(levelname)8s| %(message)s', 130 | handlers=handlers) 131 | log.info(f'Info: \n{args_str}') 132 | main(args) --------------------------------------------------------------------------------
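A note on the YAML config consumed by parse_yaml_config() in lib/utils/config.py: the top-level keys group the fields (the function iterates over them without checking the group names against the parser), and every nested key must equal an argument's argparse dest, e.g. lr_codebook for --lr-codebook. The snippet below is an illustrative sketch only, not a file from this repository; it redeclares two small stand-in argument groups and uses placeholder values to show how the grouped YAML is flattened into parser defaults.

import argparse
import yaml

# Placeholder YAML in the grouped layout described above (values are illustrative).
example_yaml = """
optimizer:
  lr_codebook: 0.0005
  lr_decoder: 0.0005
train:
  epochs: 400
  batch_size: 4
"""

# Stand-in parser with only two of the argument groups defined in lib/utils/config.py.
parser = argparse.ArgumentParser()
optim_group = parser.add_argument_group('optimizer')
optim_group.add_argument('--lr-codebook', type=float, default=0.001)
optim_group.add_argument('--lr-decoder', type=float, default=0.001)
train_group = parser.add_argument_group('train')
train_group.add_argument('--epochs', type=int, default=800)
train_group.add_argument('--batch-size', type=int, default=2)

# Flatten the groups and overwrite the defaults, mirroring parse_yaml_config().
defaults = {}
for fields in yaml.safe_load(example_yaml).values():
    defaults.update(fields)  # keys are argparse dests (hyphens become underscores)
parser.set_defaults(**defaults)

print(parser.parse_args([]))  # defaults now come from the YAML, e.g. epochs=400, batch_size=4

train.py follows the same path via argparse_to_str(): when a config file is given, parse_yaml_config() sets the parser defaults and the command line is parsed again, so explicitly passed flags still override the YAML values.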