├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── asset ├── flame │ ├── head_template_mesh.obj │ ├── landmark_embedding_with_eyes.npy │ ├── tex_mean_painted.png │ └── uv_masks.npz ├── flame_editor.png ├── flame_viewer.png ├── monocular.jpg ├── monocular_obama.gif ├── monocular_obama.mp4 ├── monocular_person_0004.gif ├── monocular_person_0004.mp4 ├── nersemble.jpg ├── nersemble_038_EMO-1.gif ├── nersemble_038_EMO-1.mp4 ├── nersemble_074_EMO-1.gif ├── nersemble_074_EMO-1.mp4 └── teaser.gif ├── doc ├── monocular.md ├── nersemble.md └── nersemble_v2.md ├── jobs ├── combine_nersemble.sh ├── run_monocular.sh └── run_nersemble.sh ├── pyproject.toml └── vhap ├── combine_nerf_datasets.py ├── config ├── base.py ├── nersemble.py └── nersemble_v2.py ├── data ├── image_folder_dataset.py ├── nerf_dataset.py ├── nersemble_dataset.py ├── nersemble_v2_dataset.py └── video_dataset.py ├── export_as_nerf_dataset.py ├── flame_editor.py ├── flame_viewer.py ├── generate_flame_uvmask.py ├── model ├── flame.py ├── lbs.py └── tracker.py ├── preprocess_video.py ├── track.py ├── track_nersemble.py ├── track_nersemble_v2.py └── util ├── camera.py ├── color_correction.py ├── landmark_detector_fa.py ├── landmark_detector_star.py ├── log.py ├── mesh.py ├── render_nvdiffrast.py ├── render_uvmap.py ├── vector_ops.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .pyenv 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | bin/ 117 | include/ 118 | share/ 119 | pip-selfcheck.json 120 | lib64 121 | pyvenv.cfg 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # PyTest 142 | .pytest_cache 143 | 144 | # project specific files 145 | deps/InsightFace-PyTorch 146 | **/.idea/* 147 | tags 148 | .vscode/ 149 | 150 | .DS_Store 151 | .pkl 152 | tmp 153 | 154 | output/ 155 | export/ 156 | /data/ -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Qian" 5 | given-names: "Shenhan" 6 | orcid: "https://orcid.org/0000-0003-0416-7548" 7 | title: "VHAP: Versatile Head Alignment with Adaptive Appearance Priors" 8 | version: 0.0.3 9 | doi: 10.5281/zenodo.14988309 10 | date-released: 2024-09-05 11 | url: "https://github.com/ShenhanQian/VHAP" 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VHAP: Versatile Head Alignment with Adaptive Appearance Priors 2 | 3 |
4 | 5 |
6 | 7 | ## TL;DR 8 | 9 | - A photometric optimization pipeline based on differentiable mesh rasterization, applied to human head alignment. 10 | - A perturbation mechanism that implicitly extracts and injects regional appearance priors adaptively during rendering, enabling alignment of regions purely based on their appearance consistency, such as the hair, ears, neck, and shoulders, where no pre-defined landmarks are available. 11 | - The exported tracking results can be directly used to create your own [GaussianAvatars](https://github.com/ShenhanQian/GaussianAvatars). 12 | 13 | ## License 14 | 15 | This work is made available under [CC-BY-NC-SA-4.0](./LICENSE). The repository is derived from the [multi-view head tracker of GaussianAvatars](https://github.com/ShenhanQian/GaussianAvatars/tree/main/reference_tracker), which is subject to the following statement: 16 | 17 | > Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual property and proprietary rights in and to this software and related documentation. Any commercial use, reproduction, disclosure or distribution of this software and related documentation without an express license agreement from Toyota Motor Europe NV/SA is strictly prohibited. 18 | 19 | On top of the original repository, we add support for monocular videos and provide a complete set of scripts from video preprocessing to result export for NeRF/3DGS-style applications. 20 | 21 | ## Setup 22 | 23 | ```shell 24 | git clone git@github.com:ShenhanQian/VHAP.git 25 | cd VHAP 26 | 27 | conda create --name VHAP -y python=3.10 28 | conda activate VHAP 29 | 30 | # Install CUDA and ninja for compilation 31 | conda install -c "nvidia/label/cuda-12.1.1" cuda-toolkit ninja cmake # use the right CUDA version 32 | ln -s "$CONDA_PREFIX/lib" "$CONDA_PREFIX/lib64" # to avoid error "/usr/bin/ld: cannot find -lcudart" 33 | conda env config vars set CUDA_HOME=$CONDA_PREFIX # for compilation 34 | 35 | # Install PyTorch (make sure its CUDA version matches the toolkit installed above) 36 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 37 | # or 38 | conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia 39 | # make sure torch.cuda.is_available() returns True 40 | 41 | pip install -e . 42 | ``` 43 | 44 | > [!NOTE] 45 | > - We use an adjusted version of [nvdiffrast](https://github.com/ShenhanQian/nvdiffrast/tree/backface-culling) for backface culling. If you already have another version installed, you can reinstall it as follows: 46 | > ```shell 47 | > pip install nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling --force-reinstall 48 | > rm -r ~/.cache/torch_extensions/*/nvdiffrast* 49 | > ``` 50 | > - We use [STAR](https://github.com/ShenhanQian/STAR/) for landmark detection by default. Alternatively, [face-alignment](https://github.com/1adrianb/face-alignment) is faster but less accurate. 51 | 52 | ## Download 53 | 54 | ### FLAME 55 | 56 | Our code relies on FLAME. Please download the assets from the [official website](https://flame.is.tue.mpg.de/download.php) and store them in the paths below: 57 | 58 | - FLAME 2023 (version w/ jaw rotation) -> `asset/flame/flame2023.pkl` 59 | - FLAME Vertex Masks -> `asset/flame/FLAME_masks.pkl` 60 | 61 | > [!NOTE] 62 | > It is also possible to use FLAME 2020 by downloading it to `asset/flame/generic_model.pkl`. The `FLAME_MODEL_PATH` in `flame.py` then needs to be updated accordingly.
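If you are unsure whether the FLAME assets ended up in the right place, a quick sanity check like the one below can help before you start tracking. This is only a minimal sketch based on the paths listed above (run it from the repository root); only one of `flame2023.pkl` and `generic_model.pkl` is required, depending on which FLAME version you use.

```python
from pathlib import Path

# Paths follow the download instructions above; adjust them if you change FLAME_MODEL_PATH.
assets = {
    "FLAME 2023": "asset/flame/flame2023.pkl",
    "FLAME 2020 (optional fallback)": "asset/flame/generic_model.pkl",
    "FLAME Vertex Masks": "asset/flame/FLAME_masks.pkl",
}
for name, path in assets.items():
    status = "found" if Path(path).exists() else "missing"
    print(f"{name:30s} {path:40s} {status}")
```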
63 | 64 | ### Video Data 65 | 66 | #### Multiview 67 | 68 | To get access to the [NeRSemble](https://tobias-kirschstein.github.io/nersemble/) dataset, please request access via the [Google Form](https://forms.gle/rYRoGNh2ed51TDWX9). The directory structure is expected to look like [this](https://github.com/ShenhanQian/VHAP/blob/c9ea660c6c6719110eca5ffdaf9029a2596cc5ca/vhap/data/nersemble_dataset.py#L32-L54). 69 | 70 | > [!NOTE] 71 | > The NeRSemble dataset has been updated to Version 2. Its folder structure and color correction algorithm differ from those in Version 1, so please be careful not to confuse the two. 72 | 73 | #### Monocular 74 | 75 | We use monocular video sequences following [INSTA](https://zielon.github.io/insta/). You can download the raw videos from [LRZ](https://syncandshare.lrz.de/getlink/fiJE46wKrG6oTVZ16CUmMr/VHAP). 76 | 77 | ## Usage 78 | 79 | ### Monocular 80 | [For Monocular Videos](doc/monocular.md) 81 | 82 |
83 | 84 |
85 | 86 | ### Multiview 87 | [For NeRSemble Dataset](doc/nersemble.md) 88 | 89 | [For NeRSemble Dataset V2](doc/nersemble_v2.md) 90 | 91 |
92 | 93 |
94 | 95 | ## Discussions 96 | 97 | Photometric alignment is versatile but sometimes sensitive. 98 | 99 | **Texture map regularization:** Our method relies on a total-variation regularization on the texture map. Its loss weight is by default `1e4` for a monocular video and `1e5` for the NeRSemble dataset (16 views). For your own multi-view dataset with fewer views, you should lower the regularization by passing `--w.reg_tex_tv 1e4` or `3e4`. Otherwise, you may encounter corrupted shapes and blurry textures similar to https://github.com/ShenhanQian/VHAP/issues/10#issue-2558743737 and https://github.com/ShenhanQian/VHAP/issues/6#issue-2524833245. 100 | 101 | **Color affinity:** If the color of a point on the foreground contour is too close to the background, the [`static_offset`](https://github.com/ShenhanQian/VHAP/blob/64c18060e7aad104bf05a2c06aab7818f54af6bd/vhap/model/flame.py#L583) can go wild. You may try a different background color with `--data.background_color white` or `--data.background_color black`. You can also disable `static_offset` with `--model.no_use_static_offset`. 102 | 103 | **Occlusion:** When the neck is occluded by collars, the photometric gradients may squeeze and stretch the neck into unnatural shapes. Usually, this problem can be relieved by disabling photometric alignment in certain regions. We hard-coded the occlusion status for some subjects in the NeRSemble dataset with the [`occluded_table`](https://github.com/ShenhanQian/VHAP/blob/51a2792bd3ad3f920d9cd8f1b107a56b92349520/vhap/config/nersemble.py#L71). You can extend the table or temporarily override it with, e.g., `--model.occluded neck_lower boundary`. 104 | 105 | **Limited degrees of freedom:** Another limitation comes from the FLAME model. FLAME is great since it covers the whole head and neck. However, there is only one joint for the neck, between the neck and the head. This means the lower part of the neck cannot move relative to the torso, which limits the model's ability to capture large head movements. For example, it is very hard to achieve good alignment of the lower neck and the head at the same time for the *EXP-1-head* sequence in the NeRSemble dataset because of this lack of degrees of freedom. 106 | 107 | **You are welcome to report more failure cases and help us improve the tracker.** 108 | 109 | ## Interactive Viewers 110 | 111 | Our method relies on vertex masks defined on FLAME. We add custom masks to enrich the original ones. You can play with `regions` in our FLAME Editor to see what each mask looks like. 112 | 113 | ```shell 114 | python vhap/flame_editor.py 115 | ``` 116 | 117 | We also provide a FLAME viewer for you to interact with a tracked sequence. 118 | 119 | ```shell 120 | python vhap/flame_viewer.py \ 121 | --param_path output/nersemble/074_EMO-1_v16_DS4_wBg_staticOffset/2024-09-09_15-49-02/tracked_flame_params_30.npz 122 | ``` 123 | 124 | Optionally, you can enable colored rendering by specifying a texture image with `--tex_path`. 125 | 126 | For both viewers, you can switch to flat shading with `--no-shade-smooth`. 127 | 128 |
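If you want to peek into a tracked sequence before opening it in the viewer, the parameters are stored as a plain NumPy `.npz` archive. The snippet below is a minimal sketch that only assumes such an archive exists (the path reuses the example from the command above); it lists the stored arrays without relying on any specific key names.

```python
import numpy as np

# Example path taken from the flame_viewer command above; replace it with your own run.
param_path = "output/nersemble/074_EMO-1_v16_DS4_wBg_staticOffset/2024-09-09_15-49-02/tracked_flame_params_30.npz"

params = np.load(param_path)
for key in params.files:
    print(f"{key:20s} {tuple(params[key].shape)}")
```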
129 | 130 | 131 |
132 | 133 | ## Cite 134 | 135 | Please kindly cite our repository and preceding paper if you find our software or algorithm useful for your research. 136 | 137 | ```bibtex 138 | @misc{qian2024vhap, 139 | title={VHAP: Versatile Head Alignment with Adaptive Appearance Priors}, 140 | author={Qian, Shenhan}, 141 | year={2024}, 142 | month={sep}, 143 | doi={10.5281/zenodo.14988309} 144 | url={https://github.com/ShenhanQian/VHAP} 145 | } 146 | ``` 147 | 148 | ```bibtex 149 | @inproceedings{qian2024gaussianavatars, 150 | title={Gaussianavatars: Photorealistic head avatars with rigged 3d gaussians}, 151 | author={Qian, Shenhan and Kirschstein, Tobias and Schoneveld, Liam and Davoli, Davide and Giebenhain, Simon and Nie{\ss}ner, Matthias}, 152 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 153 | pages={20299--20309}, 154 | year={2024} 155 | } 156 | ``` 157 | -------------------------------------------------------------------------------- /asset/flame/landmark_embedding_with_eyes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame/landmark_embedding_with_eyes.npy -------------------------------------------------------------------------------- /asset/flame/tex_mean_painted.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame/tex_mean_painted.png -------------------------------------------------------------------------------- /asset/flame/uv_masks.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame/uv_masks.npz -------------------------------------------------------------------------------- /asset/flame_editor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame_editor.png -------------------------------------------------------------------------------- /asset/flame_viewer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame_viewer.png -------------------------------------------------------------------------------- /asset/monocular.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular.jpg -------------------------------------------------------------------------------- /asset/monocular_obama.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_obama.gif -------------------------------------------------------------------------------- /asset/monocular_obama.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_obama.mp4 -------------------------------------------------------------------------------- /asset/monocular_person_0004.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_person_0004.gif -------------------------------------------------------------------------------- /asset/monocular_person_0004.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_person_0004.mp4 -------------------------------------------------------------------------------- /asset/nersemble.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble.jpg -------------------------------------------------------------------------------- /asset/nersemble_038_EMO-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_038_EMO-1.gif -------------------------------------------------------------------------------- /asset/nersemble_038_EMO-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_038_EMO-1.mp4 -------------------------------------------------------------------------------- /asset/nersemble_074_EMO-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_074_EMO-1.gif -------------------------------------------------------------------------------- /asset/nersemble_074_EMO-1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_074_EMO-1.mp4 -------------------------------------------------------------------------------- /asset/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/teaser.gif -------------------------------------------------------------------------------- /doc/monocular.md: -------------------------------------------------------------------------------- 1 | ## For Monocular Videos 2 | 3 |
4 | 5 |
6 | 7 | ### 1. Preprocess 8 | 9 | This step extracts frames from video(s), then runs foreground matting on each frame, which requires a GPU. 10 | 11 | ```shell 12 | SEQUENCE="obama.mp4" 13 | 14 | python vhap/preprocess_video.py \ 15 | --input data/monocular/${SEQUENCE} \ 16 | --matting_method robust_video_matting 17 | ``` 18 | 19 | - `--matting_method robust_video_matting`: Use RobustVideoMatting due to the lack of a background image. 20 | - (Optional) `--downsample_scales 2`: Generate downsampled versions of the images at a scale such as 2. (An image size below 1024 is preferred for efficiency.) 21 | 22 | ### 2. Align and track faces 23 | 24 | This step automatically detects facial landmarks if they are absent, then begins FLAME tracking. We initialize shape and appearance parameters on the first frame, then sequentially track the following frames. After the sequential tracking, we conduct 30 epochs of global tracking, which optimizes all parameters on a random frame in each iteration. 25 | 26 | ```shell 27 | SEQUENCE="obama" 28 | TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_whiteBg_staticOffset" 29 | 30 | python vhap/track.py --data.root_folder "data/monocular" \ 31 | --exp.output_folder $TRACK_OUTPUT_FOLDER \ 32 | --data.sequence $SEQUENCE \ 33 | # --data.n_downsample_rgb 2 # Only specify this if you have generated downsampled images during preprocessing. 34 | ``` 35 | 36 | Optional arguments 37 | 38 | - `--model.no_use_static_offset`: disable the static offset for FLAME (very stable, but less aligned facial geometry) 39 | 40 | > Disabling the static offset automatically triggers `--model.occluded hair`, which is crucial to prevent the head from growing too large in an attempt to align with the top of the hair. 41 | 42 | - `--exp.no_photometric`: track only with landmarks (very fast, but coarse) 43 | 44 | ### 3. Export tracking results into a NeRF-style dataset 45 | 46 | Given the tracked FLAME parameters from the above step, you can export the results to form a NeRF/3DGS-style sequence, consisting of image folders and a `transforms.json`. 47 | 48 | ```shell 49 | SEQUENCE="obama" 50 | TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_whiteBg_staticOffset" 51 | EXPORT_OUTPUT_FOLDER="export/monocular/${SEQUENCE}_whiteBg_staticOffset_maskBelowLine" 52 | 53 | python vhap/export_as_nerf_dataset.py \ 54 | --src_folder ${TRACK_OUTPUT_FOLDER} \ 55 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white 56 | ``` 57 | -------------------------------------------------------------------------------- /doc/nersemble.md: -------------------------------------------------------------------------------- 1 | ## For NeRSemble Dataset 2 | 3 |
4 | 5 |
6 | 7 | ### 1. Preprocess 8 | 9 | This step extracts frames from video(s), then runs foreground matting on each frame, which requires a GPU. 10 | 11 | ```shell 12 | SUBJECT="074" 13 | SEQUENCE="EMO-1" 14 | 15 | python vhap/preprocess_video.py \ 16 | --input data/nersemble/${SUBJECT}/${SEQUENCE}* \ 17 | --downsample_scales 2 4 \ 18 | --matting_method background_matting_v2 19 | ``` 20 | 21 | - `--downsample_scales 2 4`: Generate downsampled versions of the images at scales 2 and 4. 22 | - `--matting_method background_matting_v2`: Use BackgroundMattingV2 due to the availability of background images. 23 | 24 | After preprocessing, you can inspect images, masks, and landmarks with our [NeRSemble Data Viewer](https://github.com/ShenhanQian/nersemble-data-viewer). 25 | 26 | ### 2. Align and track faces 27 | 28 | This step automatically detects facial landmarks if they are absent, then begins FLAME tracking. We initialize shape and appearance parameters on the first frame, then sequentially track the following frames. After the sequential tracking, we conduct 30 epochs of global tracking, which optimizes all parameters on a random frame in each iteration. 29 | 30 | ```shell 31 | SUBJECT="074" 32 | SEQUENCE="EMO-1" 33 | TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset" 34 | 35 | python vhap/track_nersemble.py --data.root_folder "data/nersemble" \ 36 | --exp.output_folder $TRACK_OUTPUT_FOLDER \ 37 | --data.subject $SUBJECT --data.sequence $SEQUENCE \ 38 | --data.n_downsample_rgb 4 39 | ``` 40 | 41 | Optional arguments 42 | 43 | - `--model.no_use_static_offset`: disable the static offset for FLAME (very stable, but less aligned facial geometry) 44 | 45 | > Disabling the static offset automatically triggers `--model.occluded hair`, which is crucial to prevent the head from growing too large in an attempt to align with the top of the hair. 46 | 47 | - `--exp.no_photometric`: track only with landmarks (very fast, but coarse) 48 | 49 | > [!NOTE] 50 | > We use all 16 views for the optimization, but we only visualize 3 views for efficiency. 51 | 52 | ### 3. Export tracking results into a NeRF-style dataset 53 | 54 | Given the tracked FLAME parameters from the above step, you can export the results to form a NeRF/3DGS-style sequence, consisting of image folders and a `transforms.json`. 55 | 56 | ```shell 57 | SUBJECT="074" 58 | SEQUENCE="EMO-1" 59 | TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset" 60 | EXPORT_OUTPUT_FOLDER="export/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_whiteBg_staticOffset_maskBelowLine" 61 | 62 | python vhap/export_as_nerf_dataset.py \ 63 | --src_folder ${TRACK_OUTPUT_FOLDER} \ 64 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white 65 | ``` 66 | 67 | ### 4.
Combine exported sequences of the same person as a union dataset 68 | 69 | ```shell 70 | SUBJECT="074" 71 | 72 | python vhap/combine_nerf_datasets.py \ 73 | --src_folders \ 74 | export/nersemble/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 75 | export/nersemble/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 76 | export/nersemble/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 77 | export/nersemble/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 78 | export/nersemble/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 79 | export/nersemble/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 80 | export/nersemble/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 81 | export/nersemble/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 82 | export/nersemble/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 83 | export/nersemble/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 84 | --tgt_folder \ 85 | export/nersemble/UNION10_${SUBJECT}_EMO1234EXP234589_v16_DS4_whiteBg_staticOffset_maskBelowLine 86 | ``` 87 | 88 | > [!NOTE] 89 | > The `tgt_folder` must be in the same parent folder as `src_folders` because the union dataset read from the original image files by relative paths. 90 | -------------------------------------------------------------------------------- /doc/nersemble_v2.md: -------------------------------------------------------------------------------- 1 | ## For NeRSemble Dataset V2 2 | 3 |
4 | 5 |
6 | 7 | ### 1. Preprocess 8 | 9 | This step extracts frames from video(s), then runs foreground matting on each frame, which requires a GPU. 10 | 11 | ```shell 12 | SUBJECT="074" 13 | SEQUENCE="EMO-1" 14 | 15 | python vhap/preprocess_video.py \ 16 | --input data/nersemble_v2/${SUBJECT}/sequences/${SEQUENCE}* \ 17 | --downsample_scales 2 4 \ 18 | --matting_method background_matting_v2 19 | ``` 20 | 21 | - `--downsample_scales 2 4`: Generate downsampled versions of the images at scales 2 and 4. 22 | - `--matting_method background_matting_v2`: Use BackgroundMattingV2 due to the availability of background images. 23 | 24 | After preprocessing, you can inspect images, masks, and landmarks with our [NeRSemble Data Viewer](https://github.com/ShenhanQian/nersemble-data-viewer). 25 | 26 | ### 2. Align and track faces 27 | 28 | This step automatically detects facial landmarks if they are absent, then begins FLAME tracking. We initialize shape and appearance parameters on the first frame, then sequentially track the following frames. After the sequential tracking, we conduct 30 epochs of global tracking, which optimizes all parameters on a random frame in each iteration. 29 | 30 | ```shell 31 | SUBJECT="074" 32 | SEQUENCE="EMO-1" 33 | TRACK_OUTPUT_FOLDER="output/nersemble_v2/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset" 34 | 35 | python vhap/track_nersemble_v2.py --data.root_folder "data/nersemble_v2" \ 36 | --exp.output_folder $TRACK_OUTPUT_FOLDER \ 37 | --data.subject $SUBJECT --data.sequence $SEQUENCE \ 38 | --data.n_downsample_rgb 4 39 | ``` 40 | 41 | Optional arguments 42 | 43 | - `--model.no_use_static_offset`: disable the static offset for FLAME (very stable, but less aligned facial geometry) 44 | 45 | > Disabling the static offset automatically triggers `--model.occluded hair`, which is crucial to prevent the head from growing too large in an attempt to align with the top of the hair. 46 | 47 | - `--exp.no_photometric`: track only with landmarks (very fast, but coarse) 48 | 49 | > [!NOTE] 50 | > We use all 16 views for the optimization, but we only visualize 3 views for efficiency. 51 | 52 | > [!WARNING] 53 | > NeRSemble Dataset V2 comes with improved color calibration. However, this may considerably slow down image loading, particularly at full resolution. 54 | 55 | 56 | ### 3. Export tracking results into a NeRF-style dataset 57 | 58 | Given the tracked FLAME parameters from the above step, you can export the results to form a NeRF/3DGS-style sequence, consisting of image folders and a `transforms.json`. 59 | 60 | ```shell 61 | SUBJECT="074" 62 | SEQUENCE="EMO-1" 63 | TRACK_OUTPUT_FOLDER="output/nersemble_v2/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset" 64 | EXPORT_OUTPUT_FOLDER="export/nersemble_v2/${SUBJECT}_${SEQUENCE}_v16_DS4_whiteBg_staticOffset_maskBelowLine" 65 | 66 | python vhap/export_as_nerf_dataset.py \ 67 | --src_folder ${TRACK_OUTPUT_FOLDER} \ 68 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white 69 | ``` 70 | 71 | ### 4.
Combine exported sequences of the same person as a union dataset 72 | 73 | ```shell 74 | SUBJECT="074" 75 | 76 | python vhap/combine_nerf_datasets.py \ 77 | --src_folders \ 78 | export/nersemble_v2/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 79 | export/nersemble_v2/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 80 | export/nersemble_v2/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 81 | export/nersemble_v2/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 82 | export/nersemble_v2/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 83 | export/nersemble_v2/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 84 | export/nersemble_v2/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 85 | export/nersemble_v2/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 86 | export/nersemble_v2/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 87 | export/nersemble_v2/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 88 | --tgt_folder \ 89 | export/nersemble_v2/UNION10_${SUBJECT}_EMO1234EXP234589_v16_DS4_whiteBg_staticOffset_maskBelowLine 90 | ``` 91 | 92 | > [!NOTE] 93 | > The `tgt_folder` must be in the same parent folder as `src_folders` because the union dataset read from the original image files by relative paths. 94 | -------------------------------------------------------------------------------- /jobs/combine_nersemble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Define a list of subjects and sequences 5 | # SUBJECTS=("074" "104" "140" "165" "210" "218" "238" "253" "264" "302" "304" "306") # 12 6 | # SUBJECTS=("074" "104" "140" "165" "175" "210" "218" "238" "253" "264" "302" "304" "306" "460") # 14 7 | # SUBJECTS=("306") # tmp 8 | 9 | # Loop through subjects and sequences and modify the command accordingly 10 | for SUBJECT in "${SUBJECTS[@]}"; do 11 | export_dir=export/nersemble 12 | 13 | #------- combine 10 -------# 14 | COMMAND="python vhap/combine_nerf_datasets.py \ 15 | --src_folders \ 16 | $export_dir/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 17 | $export_dir/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 18 | $export_dir/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 19 | $export_dir/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 20 | $export_dir/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 21 | $export_dir/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 22 | $export_dir/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 23 | $export_dir/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 24 | $export_dir/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 25 | $export_dir/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 26 | --tgt_folder \ 27 | $export_dir/UNION10_${SUBJECT}_EMO1234EXP234589_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 28 | 29 | " 30 | 31 | #------- combine 20 -------# 32 | # COMMAND="python vhap/combine_nerf_datasets.py \ 33 | # --src_folders \ 34 | # $export_dir/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 35 | # $export_dir/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 36 | # $export_dir/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 37 | # $export_dir/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 38 | # $export_dir/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 
39 | # $export_dir/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 40 | # $export_dir/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 41 | # $export_dir/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 42 | # $export_dir/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 43 | # $export_dir/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 44 | # $export_dir/${SUBJECT}_SEN-01_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 45 | # $export_dir/${SUBJECT}_SEN-02_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 46 | # $export_dir/${SUBJECT}_SEN-03_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 47 | # $export_dir/${SUBJECT}_SEN-04_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 48 | # $export_dir/${SUBJECT}_SEN-05_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 49 | # $export_dir/${SUBJECT}_SEN-06_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 50 | # $export_dir/${SUBJECT}_SEN-07_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 51 | # $export_dir/${SUBJECT}_SEN-08_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 52 | # $export_dir/${SUBJECT}_SEN-09_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 53 | # $export_dir/${SUBJECT}_SEN-10_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 54 | # --tgt_folder \ 55 | # $export_dir/UNION20_${SUBJECT}_EMO1234EXP234589SEN_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 56 | 57 | # " 58 | 59 | #------- combine short -------# 60 | # COMMAND="python vhap/combine_nerf_datasets.py \ 61 | # --src_folders \ 62 | # $export_dir/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 63 | # $export_dir/${SUBJECT}_EXP-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 64 | # $export_dir/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 65 | # $export_dir/${SUBJECT}_SEN-10_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 66 | # $export_dir/${SUBJECT}_FREE_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 67 | # --tgt_folder \ 68 | # $export_dir/UNIONshort_${SUBJECT}_EMO1EXP12SEN10FREE_v16_DS4_whiteBg_staticOffset_maskBelowLine \ 69 | 70 | # " 71 | 72 | #------- zip -------# 73 | # COMMAND="zip -r $export_dir/$SUBJECT.zip $export_dir/$SUBJECT* $export_dir/UNION_$SUBJECT*" 74 | 75 | #======= Run =======# 76 | # Execute the command (remove the echo if you want to actually run the command) 77 | 78 | $COMMAND 79 | # echo $COMMAND 80 | 81 | done 82 | -------------------------------------------------------------------------------- /jobs/run_monocular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define a list of sequences 4 | 5 | # SEQUENCES=("obama" "biden" "justin" "nf_01" "nf_03" "person_0004" "malte_1" "bala" "wojtek_1" "marcel") # 10 6 | # SEQUENCES=("obama") 7 | 8 | DATA_FOLDER="data/monocular" 9 | 10 | # Loop through sequences and modify the command accordingly 11 | for SEQUENCE in "${SEQUENCES[@]}"; do 12 | 13 | JOB_NAME="vhap_${SEQUENCE}" 14 | 15 | #======= Preprocess =======# 16 | RAW_VIDEO_PATH="${DATA_FOLDER}/${SEQUENCE}.mp4" 17 | PREPROCESS_COMMAND=" 18 | python vhap/preprocess_video.py --input ${RAW_VIDEO_PATH} \ 19 | --matting_method robust_video_matting 20 | " 21 | 22 | #======= Track =======# 23 | TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_whiteBg_staticOffset" 24 | TRACK_COMMAND=" 25 | python vhap/track.py --data.root_folder ${DATA_FOLDER} \ 26 | --exp.output_folder $TRACK_OUTPUT_FOLDER \ 27 | --data.sequence $SEQUENCE \ 28 | 29 | " 30 | 31 | #======= Export =======# 32 | 
EXPORT_OUTPUT_FOLDER="export/monocular/${SEQUENCE}_whiteBg_staticOffset_maskBelowLine" 33 | EXPORT_COMMAND="python vhap/export_as_nerf_dataset.py \ 34 | --src_folder ${TRACK_OUTPUT_FOLDER} \ 35 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white \ 36 | 37 | " 38 | 39 | #======= Run =======# 40 | # Execute the command (remove the echo if you want to actually run the command) 41 | 42 | #------- check completeness -------# 43 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1) 44 | # # if [ ! -d $last_folder/eval_30 ]; then 45 | # if [ ! -e $last_folder/tracked_flame_params_30.npz ]; then 46 | # echo $last_folder 47 | # fi 48 | 49 | #------- create video -------# 50 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1) 51 | # video_folder=$last_folder/eval_30 52 | # ffmpeg -y -framerate 25 -f image2 -pattern_type glob -i "${video_folder}/image_grid/*.jpg" -pix_fmt yuv420p $video_folder/image_grid.mp4 53 | 54 | #------- rename -------# 55 | # mv $TRACK_OUTPUT_FOLDER $TRACK_OUTPUT_FOLDER 56 | 57 | #------- only preprocess -------# 58 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $PREPROCESS_COMMAND" 59 | # COMMAND="$PREPROCESS_COMMAND" 60 | 61 | #------- only track -------# 62 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND" 63 | # COMMAND="$TRACK_COMMAND" 64 | 65 | #------- only export -------# 66 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $EXPORT_COMMAND" 67 | # COMMAND="$EXPORT_COMMAND" 68 | 69 | #------- track and export -------# 70 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND && $EXPORT_COMMAND" 71 | # COMMAND="$TRACK_COMMAND && $EXPORT_COMMAND" 72 | 73 | 74 | # echo $COMMAND 75 | $COMMAND 76 | sleep 1 77 | done 78 | -------------------------------------------------------------------------------- /jobs/run_nersemble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define a list of subjects and sequences 4 | 5 | # SUBJECTS=("074" "104" "140" "165" "175" "210" "218" "238" "253" "264" "302" "304" "306" "460") # 14 6 | # SUBJECTS=("306") 7 | 8 | # SEQUENCES=("EMO-1" "EMO-2" "EMO-3" "EMO-4" "EXP-1" "EXP-2" "EXP-3" "EXP-4" "EXP-5" "EXP-8" "EXP-9" "FREE") # 11 9 | # SEQUENCES=("SEN-01" "SEN-02" "SEN-03" "SEN-04" "SEN-05" "SEN-06" "SEN-07" "SEN-08" "SEN-09" "SEN-10") # 10 10 | # SEQUENCES=("EMO-1") 11 | 12 | 13 | DATA_FOLDER="data/nersemble" 14 | 15 | # Loop through subjects and sequences and modify the command accordingly 16 | for SUBJECT in "${SUBJECTS[@]}"; do 17 | for SEQUENCE in "${SEQUENCES[@]}"; do 18 | 19 | JOB_NAME="vhap_${SUBJECT}_${SEQUENCE}" 20 | 21 | #======= Preprocess =======# 22 | VIDEO_FOLDER_PATH="${DATA_FOLDER}/${SUBJECT}/${SEQUENCE}" 23 | PREPROCESS_COMMAND=" 24 | python vhap/preprocess_video.py --input $VIDEO_FOLDER_PATH \ 25 | --downsample_scales 2 4 \ 26 | --matting_method background_matting_v2 27 | " 28 | 29 | #======= Track =======# 30 | TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset" 31 | TRACK_COMMAND=" 32 | python vhap/track_nersemble.py --data.root_folder $DATA_FOLDER \ 33 | --exp.output_folder $TRACK_OUTPUT_FOLDER \ 34 | --data.subject $SUBJECT --data.sequence $SEQUENCE \ 35 | --data.n_downsample_rgb 4 36 | 37 | " 38 | #======= Export =======# 39 | EXPORT_OUTPUT_FOLDER="export/${SUBJECT}_${SEQUENCE}_v16_DS4_whiteBg_staticOffset_maskBelowLine" 40 | EXPORT_COMMAND=" 41 | python vhap/export_as_nerf_dataset.py \ 42 | --src_folder 
${TRACK_OUTPUT_FOLDER} \ 43 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white 44 | " 45 | 46 | #======= Run =======# 47 | # Execute the command (remove the echo if you want to actually run the command) 48 | 49 | #------- check completeness -------# 50 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1) 51 | # # if [ ! -d $last_folder/eval_30 ]; then 52 | # if [ ! -e $last_folder/tracked_flame_params_30.npz ]; then 53 | # echo $last_folder 54 | # fi 55 | 56 | #------- create video -------# 57 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1) 58 | # video_folder=$last_folder/eval_30 59 | # ffmpeg -y -framerate 25 -f image2 -pattern_type glob -i "${video_folder}/image_grid/*.jpg" -pix_fmt yuv420p $video_folder/image_grid.mp4 60 | 61 | #------- rename -------# 62 | # mv $TRACK_OUTPUT_FOLDER $TRACK_OUTPUT_FOLDER 63 | 64 | #------- only preprocess -------# 65 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $PREPROCESS_COMMAND" 66 | # COMMAND="$PREPROCESS_COMMAND" 67 | 68 | #------- only track -------# 69 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND" 70 | # COMMAND="$TRACK_COMMAND" 71 | 72 | #------- only export -------# 73 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $EXPORT_COMMAND" 74 | # COMMAND="$EXPORT_COMMAND" 75 | 76 | #------- track and export -------# 77 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND && $EXPORT_COMMAND" 78 | # COMMAND="$TRACK_COMMAND && $EXPORT_COMMAND" 79 | 80 | 81 | # echo $COMMAND 82 | $COMMAND 83 | sleep 1 84 | done 85 | done 86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.metadata] 6 | allow-direct-references = true 7 | 8 | [tool.hatch.build] 9 | include = ["vhap/**/*.py"] 10 | 11 | [project] 12 | name = "VHAP" 13 | version = "0.0.4" 14 | requires-python = ">=3.9" 15 | dependencies = [ 16 | "tyro", 17 | "pyyaml", 18 | "numpy==1.22.3", 19 | "matplotlib==3.8.0", 20 | "scipy", 21 | "pillow", 22 | "opencv-python", 23 | "ffmpeg-python", 24 | "colour", 25 | "torch", # manually install to avoid CUDA version mismatch 26 | "torchvision", # manually install to avoid CUDA version mismatch 27 | "tensorboard", 28 | "chumpy", 29 | "trimesh", 30 | "nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling", 31 | "BackgroundMattingV2@git+https://github.com/ShenhanQian/BackgroundMattingV2", 32 | "STAR@git+https://github.com/ShenhanQian/STAR/", 33 | "dlib", # for STAR 34 | "pandas", # for STAR 35 | "gdown", # for STAR 36 | "face-alignment", 37 | "face-detection-tflite", # for face-alignment 38 | "pytorch3d@git+https://github.com/facebookresearch/pytorch3d.git", 39 | "dearpygui", 40 | ] 41 | authors = [ 42 | {name = "Shenhan Qian", email = "shenhan.qian@tum.de"}, 43 | ] 44 | 45 | description = "A complete head tracking pipeline from videos to NeRF-ready datasets." 46 | readme = "README.md" 47 | -------------------------------------------------------------------------------- /vhap/combine_nerf_datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 
4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from typing import Optional, Literal, List 11 | from copy import deepcopy 12 | import json 13 | import tyro 14 | from pathlib import Path 15 | import shutil 16 | import random 17 | 18 | 19 | class NeRFDatasetAssembler: 20 | def __init__(self, src_folders: List[Path], tgt_folder: Path, division_mode: Literal['random_single', 'random_group', 'last']='random_group'): 21 | self.src_folders = src_folders 22 | self.tgt_folder = tgt_folder 23 | self.num_timestep = 0 24 | 25 | # use the subject name as the random seed to sample the test sequence 26 | subjects = [sf.name.split('_')[0] for sf in src_folders] 27 | for s in subjects: 28 | assert s == subjects[0], f"Cannot combine datasets from different subjects: {subjects}" 29 | subject = subjects[0] 30 | random.seed(subject) 31 | 32 | if division_mode == 'random_single': 33 | self.src_folders_test = [self.src_folders.pop(int(random.uniform(0, 1) * len(src_folders)))] 34 | elif division_mode == 'random_group': 35 | # sample one sequence as the test sequence every `group_size` sequences 36 | self.src_folders_test = [] 37 | num_all = len(self.src_folders) 38 | group_size = 10 39 | num_test = max(1, num_all // group_size) 40 | indices_test = [] 41 | for gi in range(num_test): 42 | idx = min(num_all - 1, random.randint(0, group_size - 1) + gi * group_size) 43 | indices_test.append(idx) 44 | 45 | for idx in indices_test: 46 | self.src_folders_test.append(self.src_folders.pop(idx)) 47 | elif division_mode == 'last': 48 | self.src_folders_test = [self.src_folders.pop(-1)] 49 | else: 50 | raise ValueError(f"Unknown division mode: {division_mode}") 51 | 52 | self.src_folders_train = self.src_folders 53 | 54 | def write(self): 55 | self.combine_dbs(self.src_folders_train, division='train') 56 | self.combine_dbs(self.src_folders_test, division='test') 57 | 58 | def combine_dbs(self, src_folders, division: Optional[Literal['train', 'test']] = None): 59 | db = None 60 | for i, src_folder in enumerate(src_folders): 61 | dbi_path = src_folder / "transforms.json" 62 | assert dbi_path.exists(), f"Could not find {dbi_path}" 63 | # print(f"Loading database: {dbi_path}") 64 | dbi = json.load(open(dbi_path, "r")) 65 | 66 | dbi['timestep_indices'] = [t + self.num_timestep for t in dbi['timestep_indices']] 67 | self.num_timestep += len(dbi['timestep_indices']) 68 | for frame in dbi['frames']: 69 | # drop keys that are irrelevant for a combined dataset 70 | frame.pop('timestep_index_original') 71 | frame.pop('timestep_id') 72 | 73 | # accumulate timestep indices 74 | frame['timestep_index'] = dbi['timestep_indices'][frame['timestep_index']] 75 | 76 | # complement the parent folder 77 | frame['file_path'] = str(Path('..') / Path(src_folder.name) / frame['file_path']) 78 | frame['flame_param_path'] = str(Path('..') / Path(src_folder.name) / frame['flame_param_path']) 79 | frame['fg_mask_path'] = str(Path('..') / Path(src_folder.name) / frame['fg_mask_path']) 80 | 81 | if db is None: 82 | db = dbi 83 | else: 84 | db['frames'] += dbi['frames'] 85 | db['timestep_indices'] += dbi['timestep_indices'] 86 | 87 | if not self.tgt_folder.exists(): 88 | self.tgt_folder.mkdir(parents=True) 89 | 90 | if division == 'train': 91 | # copy the canonical flame param 92 | cano_flame_param_path = src_folders[0] / "canonical_flame_param.npz" 93 | tgt_flame_param_path 
= self.tgt_folder / f"canonical_flame_param.npz" 94 | print(f"Copying canonical flame param: {tgt_flame_param_path}") 95 | shutil.copy(cano_flame_param_path, tgt_flame_param_path) 96 | 97 | # leave one camera for validation 98 | db_train = {k: v for k, v in db.items() if k not in ['frames', 'camera_indices']} 99 | db_train['frames'] = [] 100 | db_val = deepcopy(db_train) 101 | 102 | if len(db['camera_indices']) > 1: 103 | # when having multiple cameras, leave one camera for validation (novel-view sythesis) 104 | if 8 in db['camera_indices']: 105 | # use camera 8 for validation (front-view of the NeRSemble dataset) 106 | db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8] 107 | db_val['camera_indices'] = [8] 108 | else: 109 | # use the last camera for validation 110 | db_train['camera_indices'] = db['camera_indices'][:-1] 111 | db_val['camera_indices'] = [db['camera_indices'][-1]] 112 | else: 113 | # when only having one camera, we create an empty validation set 114 | db_train['camera_indices'] = db['camera_indices'] 115 | db_val['camera_indices'] = [] 116 | 117 | for frame in db['frames']: 118 | if frame['camera_index'] in db_train['camera_indices']: 119 | db_train['frames'].append(frame) 120 | elif frame['camera_index'] in db_val['camera_indices']: 121 | db_val['frames'].append(frame) 122 | else: 123 | raise ValueError(f"Unknown camera index: {frame['camera_index']}") 124 | 125 | write_json(db_train, self.tgt_folder, 'train') 126 | write_json(db_val, self.tgt_folder, 'val') 127 | 128 | with open(self.tgt_folder / 'sequences_trainval.txt', 'w') as f: 129 | for folder in src_folders: 130 | f.write(folder.name + '\n') 131 | else: 132 | db['timestep_indices'] = sorted(db['timestep_indices']) 133 | write_json(db, self.tgt_folder, division) 134 | 135 | with open(self.tgt_folder / f'sequences_{division}.txt', 'w') as f: 136 | for folder in src_folders: 137 | f.write(folder.name + '\n') 138 | 139 | 140 | def write_json(db, tgt_folder, division=None): 141 | fname = "transforms.json" if division is None else f"transforms_{division}.json" 142 | json_path = tgt_folder / fname 143 | print(f"Writing database: {json_path}") 144 | with open(json_path, "w") as f: 145 | json.dump(db, f, indent=4) 146 | 147 | def main( 148 | src_folders: List[Path], 149 | tgt_folder: Path, 150 | division_mode: Literal['random_single', 'random_group', 'last']='random_group', 151 | ): 152 | incomplete = False 153 | print("==== Begin assembling datasets ====") 154 | print(f"Division mode: {division_mode}") 155 | for src_folder in src_folders: 156 | try: 157 | assert src_folder.exists(), f"Error: could not find {src_folder}" 158 | assert src_folder.parent == tgt_folder.parent, "All source folders must be in the same parent folder as the target folder" 159 | # print(src_folder) 160 | except AssertionError as e: 161 | print(e) 162 | incomplete = True 163 | 164 | if incomplete: 165 | return 166 | 167 | nerf_dataset_assembler = NeRFDatasetAssembler(src_folders, tgt_folder, division_mode) 168 | nerf_dataset_assembler.write() 169 | 170 | print("Done!") 171 | 172 | 173 | if __name__ == "__main__": 174 | tyro.cli(main) 175 | -------------------------------------------------------------------------------- /vhap/config/base.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 
4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from dataclasses import dataclass 11 | from pathlib import Path 12 | from typing import Optional, Literal, Tuple 13 | import tyro 14 | import importlib 15 | from vhap.util.log import get_logger 16 | logger = get_logger(__name__) 17 | 18 | 19 | def import_module(module_name: str): 20 | module_name, class_name = module_name.rsplit(".", 1) 21 | module = getattr(importlib.import_module(module_name), class_name) 22 | return module 23 | 24 | 25 | class Config: 26 | def __getitem__(self, __name: str): 27 | if hasattr(self, __name): 28 | return getattr(self, __name) 29 | else: 30 | raise AttributeError(f"{self.__class__.__name__} has no attribute '{__name}'") 31 | 32 | 33 | @dataclass() 34 | class DataConfig(Config): 35 | root_folder: Path 36 | """The root folder for the dataset.""" 37 | sequence: str 38 | """The sequence name""" 39 | _target: str = "vhap.data.video_dataset.VideoDataset" 40 | """The target dataset class""" 41 | division: Optional[str] = None 42 | subset: Optional[str] = None 43 | calibrated: bool = False 44 | """Whether the cameras parameters are available""" 45 | align_cameras_to_axes: bool = True 46 | """Adjust how cameras distribute in the space with a global rotation""" 47 | camera_convention_conversion: str = 'opencv->opengl' 48 | target_extrinsic_type: Literal['w2c', 'c2w'] = 'w2c' 49 | n_downsample_rgb: Optional[int] = None 50 | """Load from downsampled RGB images to save data IO time""" 51 | scale_factor: float = 1.0 52 | """Further apply a scaling transformation after the downsampling of RGB""" 53 | background_color: Optional[Literal['white', 'black']] = 'white' 54 | use_alpha_map: bool = False 55 | use_landmark: bool = True 56 | landmark_source: Optional[Literal['face-alignment', 'star']] = "star" 57 | 58 | 59 | @dataclass() 60 | class ModelConfig(Config): 61 | n_shape: int = 300 62 | n_expr: int = 100 63 | n_tex: int = 100 64 | 65 | use_static_offset: bool = True 66 | """Optimize static offsets on top of FLAME vertices in the canonical space""" 67 | use_dynamic_offset: bool = False 68 | """Optimize dynamic offsets on top of the FLAME vertices in the canonical space""" 69 | add_teeth: bool = True 70 | """Add teeth to the FLAME model""" 71 | remove_lip_inside: bool = False 72 | """Remove the inner part of the lips from the FLAME model""" 73 | 74 | tex_resolution: int = 2048 75 | """The resolution of the extra texture map""" 76 | tex_painted: bool = True 77 | """Use a painted texture map instead the pca texture space as the base texture map""" 78 | tex_extra: bool = True 79 | """Optimize an extra texture map as the base texture map or the residual texture map""" 80 | # tex_clusters: tuple[str, ...] = ("skin", "hair", "sclerae", "lips_tight", "boundary") 81 | tex_clusters: tuple[str, ...] = ("skin", "hair", "boundary", "lips_tight", "teeth", "sclerae", "irises") 82 | """Regions that are supposed to share a similar color inside""" 83 | residual_tex: bool = True 84 | """Use the extra texture map as a residual component on top of the base texture""" 85 | occluded: tuple[str, ...] 
= () # to be used for updating stage configs in __post_init__ 86 | """The regions that are occluded by the hair or garments""" 87 | 88 | flame_params_path: Optional[Path] = None 89 | 90 | 91 | @dataclass() 92 | class RenderConfig(Config): 93 | backend: Literal['nvdiffrast', 'pytorch3d'] = 'nvdiffrast' 94 | """The rendering backend""" 95 | use_opengl: bool = False 96 | """Use OpenGL for NVDiffRast""" 97 | background_train: Literal['white', 'black', 'target'] = 'target' 98 | """Background color/image for training""" 99 | disturb_rate_fg: Optional[float] = 0.5 100 | """The rate of disturbance for the foreground""" 101 | disturb_rate_bg: Optional[float] = 0.5 102 | """The rate of disturbance for the background. 0.6 best for multi-view, 0.3 best for single-view""" 103 | background_eval: Literal['white', 'black', 'target'] = 'target' 104 | """Background color/image for evaluation""" 105 | lighting_type: Literal['constant', 'front', 'front-range', 'SH'] = 'SH' 106 | """The type of lighting""" 107 | lighting_space: Literal['world', 'camera'] = 'world' 108 | """The space of lighting""" 109 | 110 | 111 | @dataclass() 112 | class LearningRateConfig(Config): 113 | base: float = 5e-3 114 | """shape, texture, rotation, eyes, neck, jaw""" 115 | translation: float = 1e-3 116 | expr: float = 5e-2 117 | static_offset: float = 5e-4 118 | dynamic_offset: float = 5e-4 119 | camera: float = 5e-3 120 | light: float = 5e-3 121 | 122 | 123 | @dataclass() 124 | class LossWeightConfig(Config): 125 | landmark: Optional[float] = 10. 126 | always_enable_jawline_landmarks: bool = True 127 | """Always enable the landmark loss for the jawline landmarks. Ignore disable_jawline_landmarks in stages.""" 128 | 129 | photo: Optional[float] = 30. 130 | 131 | # L2 regularization 132 | reg_shape: float = 3e-1 133 | reg_neck: float = 3e-1 134 | reg_jaw: float = 3e-1 135 | reg_eyes: float = 3e-2 136 | reg_expr: float = 3e-2 137 | 138 | # regularize the texture map 139 | reg_tex_res_clusters: Optional[float] = 1e1 140 | """Regularize the residual texture map inside each texture cluster""" 141 | reg_tex_res_for: tuple[str, ...] = ("sclerae", "teeth") 142 | """Regularize the residual texture map for the clusters specified""" 143 | reg_tex_tv: Optional[float] = 1e4 # important to split regions apart 144 | """Regularize the total variation of the texture map""" 145 | reg_tex_pca: float = 1e-4 # will make it hard to model hair color when too high 146 | """Regularize the pca texture map (not effective when model.tex_painted is True)""" 147 | 148 | # regularize the lighting 149 | reg_light: Optional[float] = None 150 | """Regularize lighting parameters""" 151 | reg_diffuse: Optional[float] = 1e2 152 | """Regularize lighting parameters by the diffuse term""" 153 | 154 | # L2 regularization for static_offset 155 | reg_offset: Optional[float] = 3e2 156 | """Regularize the norm of offsets""" 157 | reg_offset_relax_coef: float = 1. 158 | """The coefficient for relaxing reg_offset for the regions specified""" 159 | reg_offset_relax_for: tuple[str, ...] = ("hair", "ears") 160 | """Relax the offset loss for the regions specified""" 161 | 162 | # laplacian regularization for static_offset 163 | reg_offset_lap: Optional[float] = 1e6 164 | """Regularize the difference of laplacian coordinate caused by offsets""" 165 | reg_offset_lap_relax_coef: float = 0.1 166 | """The coefficient for relaxing reg_offset_lap for the regions specified""" 167 | reg_offset_lap_relax_for: tuple[str, ...] 
= ("hair", "ears") 168 | """Relax the offset loss for the regions specified""" 169 | 170 | # local rigidity regularization for static_offset 171 | reg_offset_rigid: Optional[float] = 3e2 172 | """Regularize the the offsets to be as-rigid-as-possible""" 173 | reg_offset_rigid_for: tuple[str, ...] = ("left_ear", "right_ear", "neck", "left_eye", "right_eye", "lips_tight") 174 | """Regularize the the offsets to be as-rigid-as-possible for the regions specified""" 175 | 176 | reg_offset_dynamic: Optional[float] = 3e5 177 | """Regularize the dynamic offsets to be temporally smooth""" 178 | 179 | blur_iter: int = 0 180 | """The number of iterations for blurring vertex weights""" 181 | 182 | # temporal smoothness 183 | smooth_trans: float = 3e2 184 | """global translation""" 185 | smooth_rot: float = 3e1 186 | """global rotation""" 187 | smooth_neck: float = 3e1 188 | """neck joint""" 189 | smooth_jaw: float = 1e-1 190 | """jaw joint""" 191 | smooth_eyes: float = 0 192 | """eyes joints""" 193 | smooth_expr: float = 1e0 194 | """expression""" 195 | 196 | 197 | @dataclass() 198 | class LogConfig(Config): 199 | interval_scalar: Optional[int] = 100 200 | """The step interval of scalar logging. Using an interval of stage_tracking.num_steps // 5 unless specified.""" 201 | interval_media: Optional[int] = 500 202 | """The step interval of media logging. Using an interval of stage_tracking.num_steps unless specified.""" 203 | image_format: Literal['jpg', 'png'] = 'jpg' 204 | """Output image format""" 205 | view_indices: Tuple[int, ...] = () 206 | """Manually specify the view indices for log""" 207 | max_num_views: int = 3 208 | """The maximum number of views for log""" 209 | stack_views_in_rows: bool = True 210 | 211 | 212 | @dataclass() 213 | class ExperimentConfig(Config): 214 | output_folder: Path = Path('output/track') 215 | reuse_landmarks: bool = True 216 | keyframes: Tuple[int, ...] = tuple() 217 | photometric: bool = True 218 | """enable photometric optimization, otherwise only landmark optimization""" 219 | 220 | @dataclass() 221 | class StageConfig(Config): 222 | disable_jawline_landmarks: bool = False 223 | """Disable the landmark loss for the jawline landmarks since they are not accurate""" 224 | 225 | @dataclass() 226 | class StageLmkInitRigidConfig(StageConfig): 227 | """The stage for initializing the rigid parameters""" 228 | num_steps: int = 500 229 | optimizable_params: tuple[str, ...] = ("cam", "pose") 230 | 231 | @dataclass() 232 | class StageLmkInitAllConfig(StageConfig): 233 | """The stage for initializing all the parameters optimizable with landmark loss""" 234 | num_steps: int = 500 235 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr") 236 | 237 | @dataclass() 238 | class StageLmkSequentialTrackingConfig(StageConfig): 239 | """The stage for sequential tracking with landmark loss""" 240 | num_steps: int = 50 241 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr") 242 | 243 | @dataclass() 244 | class StageLmkGlobalTrackingConfig(StageConfig): 245 | """The stage for global tracking with landmark loss""" 246 | num_epochs: int = 30 247 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr") 248 | 249 | @dataclass() 250 | class PhotometricStageConfig(StageConfig): 251 | align_texture_except: tuple[str, ...] = () 252 | """Align the inner region of rendered FLAME to the image, except for the regions specified""" 253 | align_boundary_except: tuple[str, ...] 
= ("bottomline",) # necessary to avoid the bottomline of FLAME from being stretched to the bottom of the image 254 | """Align the boundary of FLAME to the image, except for the regions specified""" 255 | 256 | @dataclass() 257 | class StageRgbInitTextureConfig(PhotometricStageConfig): 258 | """The stage for initializing the texture map with photometric loss""" 259 | num_steps: int = 500 260 | optimizable_params: tuple[str, ...] = ("cam", "shape", "texture", "lights") 261 | align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck") 262 | align_boundary_except: tuple[str, ...] = ("hair", "boundary") 263 | 264 | @dataclass() 265 | class StageRgbInitAllConfig(PhotometricStageConfig): 266 | """The stage for initializing all the parameters except the offsets with photometric loss""" 267 | num_steps: int = 500 268 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights") 269 | disable_jawline_landmarks: bool = True 270 | align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck") 271 | align_boundary_except: tuple[str, ...] = ("hair", "bottomline") 272 | 273 | @dataclass() 274 | class StageRgbInitOffsetConfig(PhotometricStageConfig): 275 | """The stage for initializing the offsets with photometric loss""" 276 | num_steps: int = 500 277 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset") 278 | disable_jawline_landmarks: bool = True 279 | align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck") 280 | 281 | @dataclass() 282 | class StageRgbSequentialTrackingConfig(PhotometricStageConfig): 283 | """The stage for sequential tracking with photometric loss""" 284 | num_steps: int = 50 285 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "texture", "dynamic_offset") 286 | disable_jawline_landmarks: bool = True 287 | 288 | @dataclass() 289 | class StageRgbGlobalTrackingConfig(PhotometricStageConfig): 290 | """The stage for global tracking with photometric loss""" 291 | num_epochs: int = 30 292 | optimizable_params: tuple[str, ...] 
= ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset", "dynamic_offset") 293 | disable_jawline_landmarks: bool = True 294 | 295 | @dataclass() 296 | class PipelineConfig(Config): 297 | lmk_init_rigid: StageLmkInitRigidConfig 298 | lmk_init_all: StageLmkInitAllConfig 299 | lmk_sequential_tracking: StageLmkSequentialTrackingConfig 300 | lmk_global_tracking: StageLmkGlobalTrackingConfig 301 | rgb_init_texture: StageRgbInitTextureConfig 302 | rgb_init_all: StageRgbInitAllConfig 303 | rgb_init_offset: StageRgbInitOffsetConfig 304 | rgb_sequential_tracking: StageRgbSequentialTrackingConfig 305 | rgb_global_tracking: StageRgbGlobalTrackingConfig 306 | 307 | 308 | @dataclass() 309 | class BaseTrackingConfig(Config): 310 | data: DataConfig 311 | model: ModelConfig 312 | render: RenderConfig 313 | log: LogConfig 314 | exp: ExperimentConfig 315 | lr: LearningRateConfig 316 | w: LossWeightConfig 317 | pipeline: PipelineConfig 318 | 319 | begin_stage: Optional[str] = None 320 | """Begin from the specified stage for debugging""" 321 | begin_frame_idx: int = 0 322 | """Begin from the specified frame index for debugging""" 323 | async_func: bool = True 324 | """Allow asynchronous function calls for speed up""" 325 | device: Literal['cuda', 'cpu'] = 'cuda' 326 | 327 | def get_occluded(self): 328 | occluded_table = { 329 | } 330 | if self.data.sequence in occluded_table: 331 | logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.sequence]}") 332 | self.model.occluded = occluded_table[self.data.sequence] 333 | 334 | def __post_init__(self): 335 | self.get_occluded() 336 | 337 | if not self.model.use_static_offset and not self.model.use_dynamic_offset: 338 | self.model.occluded = tuple(list(self.model.occluded) + ['hair']) # disable boundary alignment for the hair region if no offset is used 339 | 340 | for cfg_stage in self.pipeline.__dict__.values(): 341 | if isinstance(cfg_stage, PhotometricStageConfig): 342 | cfg_stage.align_texture_except = tuple(list(cfg_stage.align_texture_except) + list(self.model.occluded)) 343 | cfg_stage.align_boundary_except = tuple(list(cfg_stage.align_boundary_except) + list(self.model.occluded)) 344 | 345 | if self.begin_stage is not None: 346 | skip = True 347 | for cfg_stage in self.pipeline.__dict__.values(): 348 | if cfg_stage.__class__.__name__.lower() == self.begin_stage: 349 | skip = False 350 | if skip: 351 | cfg_stage.num_steps = 0 352 | 353 | 354 | if __name__ == "__main__": 355 | config = tyro.cli(BaseTrackingConfig) 356 | print(tyro.to_yaml(config)) -------------------------------------------------------------------------------- /vhap/config/nersemble.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | from typing import Optional, Literal 11 | from dataclasses import dataclass 12 | import tyro 13 | 14 | from vhap.config.base import ( 15 | StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig, 16 | DataConfig, LossWeightConfig, BaseTrackingConfig, 17 | ) 18 | from vhap.util.log import get_logger 19 | logger = get_logger(__name__) 20 | 21 | 22 | @dataclass() 23 | class NersembleDataConfig(DataConfig): 24 | _target: str = "vhap.data.nersemble_dataset.NeRSembleDataset" 25 | calibrated: bool = True 26 | image_size_during_calibration: Optional[tuple[int, int]] = (3208, 2200) 27 | """(height, width). Will be use to convert principle points when the image size is not included in the camera parameters.""" 28 | background_color: Optional[Literal['white', 'black']] = None 29 | landmark_source: Optional[Literal["face-alignment", 'star']] = "star" 30 | 31 | subject: str = "" 32 | """Subject ID. Such as 018, 218, 251, 253""" 33 | use_color_correction: bool = True 34 | """Whether to use color correction to harmonize the color of the input images.""" 35 | 36 | @dataclass() 37 | class NersembleLossWeightConfig(LossWeightConfig): 38 | landmark: Optional[float] = 3. # should not be lower to avoid collapse 39 | always_enable_jawline_landmarks: bool = False # allow disable_jawline_landmarks in StageConfig to work 40 | reg_expr: float = 1e-2 # for best expressivness 41 | reg_tex_tv: Optional[float] = 1e5 # 10x of the base value 42 | smooth_expr: float = 0 # for best expressivness 43 | 44 | @dataclass() 45 | class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig): 46 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset") 47 | 48 | align_texture_except: tuple[str, ...] = ("boundary",) 49 | align_boundary_except: tuple[str, ...] = ("boundary",) 50 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the 51 | alignment constraints for better alignment in the face region. 52 | """ 53 | 54 | @dataclass() 55 | class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig): 56 | align_texture_except: tuple[str, ...] = ("boundary",) 57 | align_boundary_except: tuple[str, ...] = ("boundary",) 58 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the 59 | alignment constraints for better alignment in the face region. 
60 | """ 61 | 62 | @dataclass() 63 | class NersemblePipelineConfig(PipelineConfig): 64 | rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig 65 | rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig 66 | 67 | @dataclass() 68 | class NersembleTrackingConfig(BaseTrackingConfig): 69 | data: NersembleDataConfig 70 | w: NersembleLossWeightConfig 71 | pipeline: NersemblePipelineConfig 72 | 73 | def get_occluded(self): 74 | occluded_table = { 75 | '018': ('neck_lower',), 76 | '218': ('neck_lower',), 77 | '251': ('neck_lower', 'boundary'), 78 | '253': ('neck_lower',), 79 | } 80 | if self.data.subject in occluded_table: 81 | logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}") 82 | self.model.occluded = occluded_table[self.data.subject] 83 | 84 | 85 | if __name__ == "__main__": 86 | config = tyro.cli(NersembleTrackingConfig) 87 | print(tyro.to_yaml(config)) -------------------------------------------------------------------------------- /vhap/config/nersemble_v2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from dataclasses import dataclass 11 | import tyro 12 | 13 | from vhap.config.nersemble import NersembleDataConfig, NersembleTrackingConfig 14 | from vhap.util.log import get_logger 15 | logger = get_logger(__name__) 16 | 17 | 18 | @dataclass() 19 | class NersembleV2DataConfig(NersembleDataConfig): 20 | _target: str = "vhap.data.nersemble_v2_dataset.NeRSembleV2Dataset" 21 | 22 | 23 | @dataclass() 24 | class NersembleV2TrackingConfig(NersembleTrackingConfig): 25 | data: NersembleV2DataConfig 26 | 27 | 28 | if __name__ == "__main__": 29 | config = tyro.cli(NersembleV2TrackingConfig) 30 | print(tyro.to_yaml(config)) -------------------------------------------------------------------------------- /vhap/data/image_folder_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | import numpy as np 4 | import PIL.Image as Image 5 | from torch.utils.data import Dataset 6 | from vhap.util.log import get_logger 7 | 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | class ImageFolderDataset(Dataset): 13 | def __init__( 14 | self, 15 | image_folder: Path, 16 | background_folder: Optional[Path]=None, 17 | background_fname2camId=lambda x: x, 18 | image_fname2camId=lambda x: x, 19 | ): 20 | """ 21 | Args: 22 | root_folder: Path to dataset with the following directory layout 23 | / 24 | |---xx.jpg 25 | |---... 
26 | """ 27 | super().__init__() 28 | self.image_fname2camId = image_fname2camId 29 | self.background_foler = background_folder 30 | 31 | logger.info(f"Initializing dataset from folder {image_folder}") 32 | 33 | self.image_paths = sorted(list(image_folder.glob('*.jpg'))) 34 | 35 | if background_folder is not None: 36 | self.backgrounds = {} 37 | background_paths = sorted(list((image_folder / background_folder).glob('*.jpg'))) 38 | 39 | for background_path in background_paths: 40 | bg = np.array(Image.open(background_path)) 41 | cam_id = background_fname2camId(background_path.name) 42 | self.backgrounds[cam_id] = bg 43 | 44 | def __len__(self): 45 | return len(self.image_paths) 46 | 47 | def __getitem__(self, i): 48 | image_path = self.image_paths[i] 49 | cam_id = self.image_fname2camId(image_path.name) 50 | rgb = np.array(Image.open(image_path)) 51 | item = { 52 | "rgb": rgb, 53 | 'image_path': str(image_path), 54 | } 55 | 56 | if self.background_foler is not None: 57 | item['background'] = self.backgrounds[cam_id] 58 | 59 | return item 60 | 61 | 62 | if __name__ == "__main__": 63 | from tqdm import tqdm 64 | from torch.utils.data import DataLoader 65 | 66 | dataset = ImageFolderDataset( 67 | image_folder='./xx', 68 | img_to_tensor=True, 69 | ) 70 | 71 | print(len(dataset)) 72 | 73 | sample = dataset[0] 74 | print(sample.keys()) 75 | print(sample["rgb"].shape) 76 | 77 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) 78 | for item in tqdm(dataloader): 79 | pass 80 | -------------------------------------------------------------------------------- /vhap/data/nerf_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import numpy as np 4 | import PIL.Image as Image 5 | import torch 6 | import torchvision.transforms.functional as F 7 | from torch.utils.data import Dataset 8 | from vhap.util.log import get_logger 9 | 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | class NeRFDataset(Dataset): 15 | def __init__( 16 | self, 17 | root_folder, 18 | division=None, 19 | camera_convention_conversion=None, 20 | target_extrinsic_type='w2c', 21 | use_fg_mask=False, 22 | use_flame_param=False, 23 | ): 24 | """ 25 | Args: 26 | root_folder: Path to dataset with the following directory layout 27 | / 28 | | 29 | |---/ 30 | | |---00000.jpg 31 | | |... 32 | | 33 | |---/ 34 | | |---00000.png 35 | | |... 36 | | 37 | |---/ 38 | | |---00000.npz 39 | | |... 
40 | | 41 | |---transforms_backup.json # backup of the original transforms.json 42 | |---transforms_backup_flame.json # backup of the original transforms.json with flame_param 43 | |---transforms.json # the final transforms.json 44 | |---transforms_train.json # the final transforms.json for training 45 | |---transforms_val.json # the final transforms.json for validation 46 | |---transforms_test.json # the final transforms.json for testing 47 | 48 | 49 | """ 50 | 51 | super().__init__() 52 | self.root_folder = Path(root_folder) 53 | self.division = division 54 | self.camera_convention_conversion = camera_convention_conversion 55 | self.target_extrinsic_type = target_extrinsic_type 56 | self.use_fg_mask = use_fg_mask 57 | self.use_flame_param = use_flame_param 58 | 59 | logger.info(f"Loading NeRF scene from: {root_folder}") 60 | 61 | # data division 62 | if division is None: 63 | tranform_path = self.root_folder / "transforms.json" 64 | elif division == "train": 65 | tranform_path = self.root_folder / "transforms_train.json" 66 | elif division == "val": 67 | tranform_path = self.root_folder / "transforms_val.json" 68 | elif division == "test": 69 | tranform_path = self.root_folder / "transforms_test.json" 70 | else: 71 | raise NotImplementedError(f"Unknown division type: {division}") 72 | logger.info(f"division: {division}") 73 | 74 | self.transforms = json.load(open(tranform_path, "r")) 75 | logger.info(f"number of timesteps: {len(self.transforms['timestep_indices'])}, number of cameras: {len(self.transforms['camera_indices'])}") 76 | 77 | assert len(self.transforms['timestep_indices']) == max(self.transforms['timestep_indices']) + 1 78 | 79 | def __len__(self): 80 | return len(self.transforms['frames']) 81 | 82 | def __getitem__(self, i): 83 | frame = self.transforms['frames'][i] 84 | 85 | # 'timestep_index', 'timestep_index_original', 'timestep_id', 'camera_index', 'camera_id', 'cx', 'cy', 'fl_x', 'fl_y', 'h', 'w', 'camera_angle_x', 'camera_angle_y', 'transform_matrix', 'file_path', 'fg_mask_path', 'flame_param_path'] 86 | 87 | K = torch.eye(3) 88 | K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor( 89 | [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]] 90 | ) 91 | 92 | c2w = torch.tensor(frame['transform_matrix']) 93 | if self.target_extrinsic_type == "w2c": 94 | extrinsic = c2w.inverse() 95 | elif self.target_extrinsic_type == "c2w": 96 | extrinsic = c2w 97 | else: 98 | raise NotImplementedError(f"Unknown extrinsic type: {self.target_extrinsic_type}") 99 | 100 | img_path = self.root_folder / frame['file_path'] 101 | 102 | item = { 103 | 'timestep_index': frame['timestep_index'], 104 | 'camera_index': frame['camera_index'], 105 | 'intrinsics': K, 106 | 'extrinsics': extrinsic, 107 | 'image_height': frame['h'], 108 | 'image_width': frame['w'], 109 | 'image': np.array(Image.open(img_path)), 110 | 'image_path': img_path, 111 | } 112 | 113 | if self.use_fg_mask and 'fg_mask_path' in frame: 114 | fg_mask_path = self.root_folder / frame['fg_mask_path'] 115 | item["fg_mask"] = np.array(Image.open(fg_mask_path)) 116 | item["fg_mask_path"] = fg_mask_path 117 | 118 | if self.use_flame_param and 'flame_param_path' in frame: 119 | npz = np.load(self.root_folder / frame['flame_param_path'], allow_pickle=True) 120 | item["flame_param"] = dict(npz) 121 | 122 | return item 123 | 124 | def apply_to_tensor(self, item): 125 | if self.img_to_tensor: 126 | if "rgb" in item: 127 | item["rgb"] = F.to_tensor(item["rgb"]) 128 | # if self.rgb_range_shift: 129 | # item["rgb"] = (item["rgb"] - 0.5) / 0.5 130 | 
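For reference, a hedged sketch of how one entry of `transforms['frames']` maps to the intrinsics and extrinsics assembled in `__getitem__` above; all numeric values are placeholders:

```python
import torch

# One frame entry (keys mirror the comment in __getitem__; values are illustrative)
frame = {
    "fl_x": 1222.0, "fl_y": 1222.0, "cx": 256.0, "cy": 256.0,
    "h": 512, "w": 512,
    "transform_matrix": torch.eye(4).tolist(),   # camera-to-world pose
}

K = torch.eye(3)
K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor(
    [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]]
)
# K == [[fl_x,    0, cx],
#       [   0, fl_y, cy],
#       [   0,    0,  1]]

c2w = torch.tensor(frame["transform_matrix"])
w2c = c2w.inverse()   # returned when target_extrinsic_type == 'w2c'
```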
131 | if "alpha_map" in item: 132 | item["alpha_map"] = F.to_tensor(item["alpha_map"]) 133 | return item 134 | 135 | 136 | if __name__ == "__main__": 137 | from tqdm import tqdm 138 | from dataclasses import dataclass 139 | import tyro 140 | from torch.utils.data import DataLoader 141 | 142 | @dataclass 143 | class Args: 144 | root_folder: str 145 | subject: str 146 | sequence: str 147 | use_landmark: bool = False 148 | batchify_all_views: bool = False 149 | 150 | args = tyro.cli(Args) 151 | 152 | dataset = NeRFDataset(root_folder=args.root_folder) 153 | 154 | print(len(dataset)) 155 | 156 | sample = dataset[0] 157 | print(sample.keys()) 158 | 159 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) 160 | for item in tqdm(dataloader): 161 | pass 162 | -------------------------------------------------------------------------------- /vhap/data/nersemble_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import json 11 | import numpy as np 12 | import torch 13 | from vhap.data.video_dataset import VideoDataset 14 | from vhap.config.nersemble import NersembleDataConfig 15 | from vhap.util import camera 16 | from vhap.util.log import get_logger 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | class NeRSembleDataset(VideoDataset): 23 | def __init__( 24 | self, 25 | cfg: NersembleDataConfig, 26 | img_to_tensor: bool = False, 27 | batchify_all_views: bool = False, 28 | ): 29 | """ 30 | Folder layout for NeRSemble dataset: 31 | 32 | / 33 | |---camera_params/ 34 | | |---/ 35 | | |---camera_params.json 36 | | 37 | |---color_correction/ 38 | | |---/ 39 | | |---.npy 40 | | 41 | |---/ 42 | |---/ 43 | |---images/ 44 | | |---cam__.jpg 45 | | 46 | |---alpha_maps/ 47 | | |---cam__.png 48 | | 49 | |---landmark2d/ 50 | |---face-alignment/ 51 | | |---.npz 52 | | 53 | |---STAR/ 54 | |---.npz 55 | """ 56 | self.cfg = cfg 57 | assert cfg.subject != "", "Please specify the subject name" 58 | 59 | super().__init__( 60 | cfg=cfg, 61 | img_to_tensor=img_to_tensor, 62 | batchify_all_views=batchify_all_views, 63 | ) 64 | self.load_color_correction() 65 | 66 | def match_sequences(self): 67 | logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}") 68 | return list(filter(lambda x: x.is_dir(), (self.cfg.root_folder / self.cfg.subject).glob(f"{self.cfg.sequence}*"))) 69 | 70 | def define_properties(self): 71 | super().define_properties() 72 | self.properties['rgb']['cam_id_prefix'] = "cam_" 73 | self.properties['alpha_map']['cam_id_prefix'] = "cam_" 74 | 75 | def load_camera_params(self, camera_params_path=None): 76 | if camera_params_path is None: 77 | camera_params_path = self.cfg.root_folder / "camera_params" / self.cfg.subject / "camera_params.json" 78 | 79 | assert camera_params_path.exists() 80 | param = json.load(open(camera_params_path)) 81 | 82 | K = torch.Tensor(param["intrinsics"]) 83 | 84 | if "height" not in param or "width" not in param: 85 | assert self.cfg.image_size_during_calibration is not None 86 | H, W = self.cfg.image_size_during_calibration 87 | else: 88 | H, W = param["height"], param["width"] 
89 | 90 | self.camera_ids = list(param["world_2_cam"].keys()) 91 | w2c = torch.tensor([param["world_2_cam"][k] for k in self.camera_ids]) # (N, 4, 4) 92 | R = w2c[..., :3, :3] 93 | T = w2c[..., :3, 3] 94 | 95 | orientation = R.transpose(-1, -2) # (N, 3, 3) 96 | location = R.transpose(-1, -2) @ -T[..., None] # (N, 3, 1) 97 | 98 | # adjust how cameras distribute in the space with a global rotation 99 | if self.cfg.align_cameras_to_axes: 100 | orientation, location = camera.align_cameras_to_axes( 101 | orientation, location, target_convention="opengl" 102 | ) 103 | 104 | # modify the local orientation of cameras to fit in different camera conventions 105 | if self.cfg.camera_convention_conversion is not None: 106 | orientation, K = camera.convert_camera_convention( 107 | self.cfg.camera_convention_conversion, orientation, K, H, W 108 | ) 109 | 110 | c2w = torch.cat([orientation, location], dim=-1) # camera-to-world transformation 111 | 112 | if self.cfg.target_extrinsic_type == "w2c": 113 | R = orientation.transpose(-1, -2) 114 | T = orientation.transpose(-1, -2) @ -location 115 | w2c = torch.cat([R, T], dim=-1) # world-to-camera transformation 116 | extrinsic = w2c 117 | elif self.cfg.target_extrinsic_type == "c2w": 118 | extrinsic = c2w 119 | else: 120 | raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}") 121 | 122 | self.camera_params = {} 123 | for i, camera_id in enumerate(self.camera_ids): 124 | self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]} 125 | 126 | def load_color_correction(self): 127 | if self.cfg.use_color_correction: 128 | self.color_correction = {} 129 | 130 | for camera_id in self.camera_ids: 131 | color_correction_path = self.cfg.root_folder / 'color_correction' / self.cfg.subject / f'{camera_id}.npy' 132 | assert color_correction_path.exists(), f"Color correction file not found: {color_correction_path}" 133 | self.color_correction[camera_id] = np.load(color_correction_path) 134 | 135 | def filter_division(self, division): 136 | if division is not None: 137 | cam_for_train = [8, 7, 9, 4, 10, 5, 13, 2, 12, 1, 14, 0] 138 | if division == "train": 139 | self.camera_ids = [ 140 | self.camera_ids[i] 141 | for i in range(len(self.camera_ids)) 142 | if i in cam_for_train 143 | ] 144 | elif division == "val": 145 | self.camera_ids = [ 146 | self.camera_ids[i] 147 | for i in range(len(self.camera_ids)) 148 | if i not in cam_for_train 149 | ] 150 | elif division == "front-view": 151 | self.camera_ids = self.camera_ids[8:9] 152 | elif division == "side-view": 153 | self.camera_ids = self.camera_ids[0:1] 154 | elif division == "six-view": 155 | self.camera_ids = [self.camera_ids[i] for i in [0, 1, 7, 8, 14, 15]] 156 | else: 157 | raise NotImplementedError(f"Unknown division type: {division}") 158 | logger.info(f"division: {division}") 159 | 160 | def apply_transforms(self, item): 161 | item = self.apply_color_correction(item) 162 | item = super().apply_transforms(item) 163 | return item 164 | 165 | def apply_color_correction(self, item): 166 | if self.cfg.use_color_correction: 167 | affine_color_transform = self.color_correction[item["camera_id"]] 168 | rgb = item["rgb"] / 255 169 | rgb = rgb @ affine_color_transform[:3, :3] + affine_color_transform[np.newaxis, :3, 3] 170 | item["rgb"] = (np.clip(rgb, 0, 1) * 255).astype(np.uint8) 171 | return item 172 | 173 | 174 | if __name__ == "__main__": 175 | import tyro 176 | from tqdm import tqdm 177 | from torch.utils.data import DataLoader 178 | from vhap.config.nersemble import 
NersembleDataConfig 179 | from vhap.config.base import import_module 180 | 181 | cfg = tyro.cli(NersembleDataConfig) 182 | cfg.use_landmark = False 183 | dataset = import_module(cfg._target)( 184 | cfg=cfg, 185 | img_to_tensor=False, 186 | batchify_all_views=True, 187 | ) 188 | 189 | print(len(dataset)) 190 | 191 | sample = dataset[0] 192 | print(sample.keys()) 193 | print(sample["rgb"].shape) 194 | 195 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) 196 | for item in tqdm(dataloader): 197 | pass 198 | -------------------------------------------------------------------------------- /vhap/data/nersemble_v2_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import json 11 | import numpy as np 12 | import colour 13 | from vhap.data.nersemble_dataset import NeRSembleDataset 14 | from vhap.util.log import get_logger 15 | from vhap.util.color_correction import color_correction_Cheung2004_precomputed 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | class NeRSembleV2Dataset(NeRSembleDataset): 22 | """ 23 | Folder layout for NeRSembleV2 dataset: 24 | 25 | / 26 | |---/ 27 | |---calibration/ 28 | | |---camera_params.json 29 | | |---color_calibration.json 30 | | 31 | |---sequences/ 32 | |---/ 33 | |---images/ 34 | | |---cam__.jpg 35 | | 36 | |---alpha_maps/ 37 | | |---cam__.png 38 | | 39 | |---landmark2d/ 40 | |---face-alignment/ 41 | | |---.npz 42 | | 43 | |---STAR/ 44 | |---.npz 45 | 46 | """ 47 | 48 | def match_sequences(self): 49 | logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}") 50 | return list(filter(lambda x: x.is_dir(), (self.cfg.root_folder / self.cfg.subject / 'sequences').glob(f"{self.cfg.sequence}*"))) 51 | 52 | def load_camera_params(self): 53 | super().load_camera_params(self.cfg.root_folder / self.cfg.subject / "calibration" / "camera_params.json") 54 | 55 | def load_color_correction(self): 56 | if self.cfg.use_color_correction: 57 | color_correction_path = self.cfg.root_folder / self.cfg.subject / 'calibration' / f'color_calibration.json' 58 | self.color_correction = {serial: np.array(ccm) for serial, ccm in json.load(open(color_correction_path)).items()} 59 | 60 | def apply_color_correction(self, item): 61 | if self.cfg.use_color_correction: 62 | rgb = item["rgb"] / 255 63 | image_linear = colour.cctf_decoding(rgb) 64 | ccm = self.color_correction[item["camera_id"]] 65 | image_corrected = color_correction_Cheung2004_precomputed(image_linear, ccm) 66 | image_corrected = colour.cctf_encoding(image_corrected) 67 | item["rgb"] = (np.clip(rgb, 0, 1) * 255).astype(np.uint8) 68 | return item 69 | 70 | 71 | if __name__ == "__main__": 72 | import tyro 73 | from tqdm import tqdm 74 | from torch.utils.data import DataLoader 75 | from vhap.data.nersemble_v2_dataset import NeRSembleV2Dataset 76 | from vhap.config.nersemble_v2 import NersembleV2DataConfig 77 | from vhap.config.base import import_module 78 | 79 | cfg = tyro.cli(NersembleV2DataConfig) 80 | cfg.use_landmark = False 81 | dataset = import_module(cfg._target)( 82 | cfg=cfg, 83 | img_to_tensor=False, 84 | 
batchify_all_views=True, 85 | ) 86 | 87 | print(len(dataset)) 88 | 89 | sample = dataset[0] 90 | print(sample.keys()) 91 | print(sample["rgb"].shape) 92 | 93 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) 94 | for item in tqdm(dataloader): 95 | pass 96 | -------------------------------------------------------------------------------- /vhap/data/video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from copy import deepcopy 4 | from typing import Optional 5 | import numpy as np 6 | import PIL.Image as Image 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torch.utils.data import Dataset, default_collate 10 | 11 | from vhap.util.log import get_logger 12 | from vhap.config.base import DataConfig 13 | 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | class VideoDataset(Dataset): 19 | def __init__( 20 | self, 21 | cfg: DataConfig, 22 | img_to_tensor: bool = False, 23 | batchify_all_views: bool = False, 24 | ): 25 | """ 26 | Args: 27 | root_folder: Path to dataset with the following directory layout 28 | / 29 | |---images/ 30 | | |---.jpg 31 | | 32 | |---alpha_maps/ 33 | | |---.png 34 | | 35 | |---landmark2d/ 36 | |---face-alignment/ 37 | | |---.npz 38 | | 39 | |---STAR/ 40 | |---.npz 41 | """ 42 | super().__init__() 43 | self.cfg = cfg 44 | self.img_to_tensor = img_to_tensor 45 | self.batchify_all_views = batchify_all_views 46 | 47 | sequence_paths = self.match_sequences() 48 | if len(sequence_paths) > 1: 49 | logger.info(f"Found multiple sequences: {sequence_paths}") 50 | raise ValueError(f"Found multiple sequences by '{cfg.sequence}': \n" + "\n\t".join([str(x) for x in sequence_paths])) 51 | elif len(sequence_paths) == 0: 52 | raise ValueError(f"Cannot find sequence: {cfg.sequence}") 53 | self.sequence_path = sequence_paths[0] 54 | logger.info(f"Initializing dataset from {self.sequence_path}") 55 | 56 | self.define_properties() 57 | self.load_camera_params() 58 | 59 | # timesteps 60 | self.timestep_ids = set( 61 | f.split('.')[0].split('_')[-1] 62 | for f in os.listdir(self.sequence_path / self.properties['rgb']['folder']) if f.endswith(self.properties['rgb']['suffix']) 63 | ) 64 | self.timestep_ids = sorted(self.timestep_ids) 65 | self.timestep_indices = list(range(len(self.timestep_ids))) 66 | 67 | self.filter_division(cfg.division) 68 | self.filter_subset(cfg.subset) 69 | 70 | logger.info(f"number of timesteps: {self.num_timesteps}, number of cameras: {self.num_cameras}") 71 | 72 | # collect 73 | self.items = [] 74 | for fi, timestep_index in enumerate(self.timestep_indices): 75 | for ci, camera_id in enumerate(self.camera_ids): 76 | self.items.append( 77 | { 78 | "timestep_index": fi, # new index after filtering 79 | "timestep_index_original": timestep_index, # original index 80 | "timestep_id": self.timestep_ids[timestep_index], 81 | "camera_index": ci, 82 | "camera_id": camera_id, 83 | } 84 | ) 85 | 86 | def match_sequences(self): 87 | logger.info(f"Looking for sequence '{self.cfg.sequence}' at {self.cfg.root_folder}") 88 | return list(filter(lambda x: x.is_dir(), self.cfg.root_folder.glob(f"{self.cfg.sequence}*"))) 89 | 90 | def define_properties(self): 91 | self.properties = { 92 | "rgb": { 93 | "folder": f"images_{self.cfg.n_downsample_rgb}" 94 | if self.cfg.n_downsample_rgb 95 | else "images", 96 | "per_timestep": True, 97 | "suffix": "jpg", 98 | }, 99 | "alpha_map": { 100 | "folder": "alpha_maps", 101 | "per_timestep": True, 
102 | "suffix": "jpg", 103 | }, 104 | "landmark2d/face-alignment": { 105 | "folder": "landmark2d/face-alignment", 106 | "per_timestep": False, 107 | "suffix": "npz", 108 | }, 109 | "landmark2d/STAR": { 110 | "folder": "landmark2d/STAR", 111 | "per_timestep": False, 112 | "suffix": "npz", 113 | }, 114 | } 115 | 116 | @staticmethod 117 | def get_number_after_prefix(string, prefix): 118 | i = string.find(prefix) 119 | if i != -1: 120 | number_begin = i + len(prefix) 121 | assert number_begin < len(string), f"No number found behind prefix '{prefix}'" 122 | assert string[number_begin].isdigit(), f"No number found behind prefix '{prefix}'" 123 | 124 | non_digit_indices = [i for i, c in enumerate(string[number_begin:]) if not c.isdigit()] 125 | if len(non_digit_indices) > 0: 126 | number_end = number_begin + min(non_digit_indices) 127 | return int(string[number_begin:number_end]) 128 | else: 129 | return int(string[number_begin:]) 130 | else: 131 | return None 132 | 133 | def filter_division(self, division): 134 | pass 135 | 136 | def filter_subset(self, subset): 137 | if subset is not None: 138 | if 'ti' in subset: 139 | ti = self.get_number_after_prefix(subset, 'ti') 140 | if 'tj' in subset: 141 | tj = self.get_number_after_prefix(subset, 'tj') 142 | self.timestep_indices = self.timestep_indices[ti:tj+1] 143 | else: 144 | self.timestep_indices = self.timestep_indices[ti:ti+1] 145 | elif 'tn' in subset: 146 | tn = self.get_number_after_prefix(subset, 'tn') 147 | tn_all = len(self.timestep_indices) 148 | tn = min(tn, tn_all) 149 | self.timestep_indices = self.timestep_indices[::tn_all // tn][:tn] 150 | elif 'ts' in subset: 151 | ts = self.get_number_after_prefix(subset, 'ts') 152 | self.timestep_indices = self.timestep_indices[::ts] 153 | if 'ci' in subset: 154 | ci = self.get_number_after_prefix(subset, 'ci') 155 | self.camera_ids = self.camera_ids[ci:ci+1] 156 | elif 'cn' in subset: 157 | cn = self.get_number_after_prefix(subset, 'cn') 158 | cn_all = len(self.camera_ids) 159 | cn = min(cn, cn_all) 160 | self.camera_ids = self.camera_ids[::cn_all // cn][:cn] 161 | elif 'cs' in subset: 162 | cs = self.get_number_after_prefix(subset, 'cs') 163 | self.camera_ids = self.camera_ids[::cs] 164 | 165 | def load_camera_params(self): 166 | self.camera_ids = ['0'] 167 | 168 | # Guessed focal length, height, width. Should be optimized or replaced by real values 169 | f, h, w = 512, 512, 512 170 | K = torch.Tensor([ 171 | [f, 0, w], 172 | [0, f, h], 173 | [0, 0, 1] 174 | ]) 175 | 176 | orientation = torch.eye(3)[None, ...] 
# (1, 3, 3) 177 | location = torch.Tensor([0, 0, 1])[None, ..., None] # (1, 3, 1) 178 | 179 | c2w = torch.cat([orientation, location], dim=-1) # camera-to-world transformation 180 | 181 | if self.cfg.target_extrinsic_type == "w2c": 182 | R = orientation.transpose(-1, -2) 183 | T = orientation.transpose(-1, -2) @ -location 184 | w2c = torch.cat([R, T], dim=-1) # world-to-camera transformation 185 | extrinsic = w2c 186 | elif self.cfg.target_extrinsic_type == "c2w": 187 | extrinsic = c2w 188 | else: 189 | raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}") 190 | 191 | self.camera_params = {} 192 | for i, camera_id in enumerate(self.camera_ids): 193 | self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]} 194 | 195 | return self.camera_params 196 | 197 | def __len__(self): 198 | if self.batchify_all_views: 199 | return self.num_timesteps 200 | else: 201 | return len(self.items) 202 | 203 | def __getitem__(self, i): 204 | if self.batchify_all_views: 205 | return self.getitem_by_timestep(i) 206 | else: 207 | return self.getitem_single_image(i) 208 | 209 | def getitem_single_image(self, i): 210 | item = deepcopy(self.items[i]) 211 | 212 | rgb_path = self.get_property_path("rgb", i) 213 | item["rgb"] = np.array(Image.open(rgb_path)) 214 | 215 | camera_param = self.camera_params[item["camera_id"]] 216 | item["intrinsic"] = camera_param["intrinsic"].clone() 217 | item["extrinsic"] = camera_param["extrinsic"].clone() 218 | 219 | if self.cfg.use_alpha_map or self.cfg.background_color is not None: 220 | alpha_path = self.get_property_path("alpha_map", i) 221 | item["alpha_map"] = np.array(Image.open(alpha_path)) 222 | 223 | if self.cfg.use_landmark: 224 | timestep_index = self.items[i]["timestep_index"] 225 | 226 | if self.cfg.landmark_source == "face-alignment": 227 | landmark_path = self.get_property_path("landmark2d/face-alignment", i) 228 | elif self.cfg.landmark_source == "star": 229 | landmark_path = self.get_property_path("landmark2d/STAR", i) 230 | else: 231 | raise NotImplementedError(f"Unknown landmark source: {self.cfg.landmark_source}") 232 | landmark_npz = np.load(landmark_path) 233 | 234 | item["lmk2d"] = landmark_npz["face_landmark_2d"][timestep_index] # (num_points, 3) 235 | if (item["lmk2d"][:, :2] == -1).sum() > 0: 236 | item["lmk2d"][:, 2:] = 0.0 237 | else: 238 | item["lmk2d"][:, 2:] = 1.0 239 | 240 | item = self.apply_transforms(item) 241 | return item 242 | 243 | def getitem_by_timestep(self, timestep_index): 244 | begin = timestep_index * self.num_cameras 245 | indices = range(begin, begin + self.num_cameras) 246 | item = default_collate([self.getitem_single_image(i) for i in indices]) 247 | 248 | item["num_cameras"] = self.num_cameras 249 | return item 250 | 251 | def apply_transforms(self, item): 252 | item = self.apply_scale_factor(item) 253 | item = self.apply_background_color(item) 254 | item = self.apply_to_tensor(item) 255 | return item 256 | 257 | def apply_to_tensor(self, item): 258 | if self.img_to_tensor: 259 | if "rgb" in item: 260 | item["rgb"] = F.to_tensor(item["rgb"]) 261 | 262 | if "alpha_map" in item: 263 | item["alpha_map"] = F.to_tensor(item["alpha_map"]) 264 | return item 265 | 266 | def apply_scale_factor(self, item): 267 | assert self.cfg.scale_factor <= 1.0 268 | 269 | if "rgb" in item: 270 | H, W, _ = item["rgb"].shape 271 | h, w = int(H * self.cfg.scale_factor), int(W * self.cfg.scale_factor) 272 | rgb = Image.fromarray(item["rgb"]).resize( 273 | (w, h), resample=Image.BILINEAR 274 | ) 275 | 
item["rgb"] = np.array(rgb) 276 | 277 | # properties that are defined based on image size 278 | if "lmk2d" in item: 279 | item["lmk2d"][..., 0] *= w 280 | item["lmk2d"][..., 1] *= h 281 | 282 | if "lmk2d_iris" in item: 283 | item["lmk2d_iris"][..., 0] *= w 284 | item["lmk2d_iris"][..., 1] *= h 285 | 286 | if "bbox_2d" in item: 287 | item["bbox_2d"][[0, 2]] *= w 288 | item["bbox_2d"][[1, 3]] *= h 289 | 290 | # properties need to be scaled down when rgb is downsampled 291 | n_downsample_rgb = self.cfg.n_downsample_rgb if self.cfg.n_downsample_rgb else 1 292 | scale_factor = self.cfg.scale_factor / n_downsample_rgb 293 | item["scale_factor"] = scale_factor # NOTE: not self.cfg.scale_factor 294 | if scale_factor < 1.0: 295 | if "intrinsic" in item: 296 | item["intrinsic"][:2] *= scale_factor 297 | if "alpha_map" in item: 298 | h, w = item["rgb"].shape[:2] 299 | alpha_map = Image.fromarray(item["alpha_map"]).resize( 300 | (w, h), Image.Resampling.BILINEAR 301 | ) 302 | item["alpha_map"] = np.array(alpha_map) 303 | return item 304 | 305 | def apply_background_color(self, item): 306 | if self.cfg.background_color is not None: 307 | assert ( 308 | "alpha_map" in item 309 | ), "'alpha_map' is required to apply background color." 310 | fg = item["rgb"] 311 | if self.cfg.background_color == "white": 312 | bg = np.ones_like(fg) * 255 313 | elif self.cfg.background_color == "black": 314 | bg = np.zeros_like(fg) 315 | else: 316 | raise NotImplementedError( 317 | f"Unknown background color: {self.cfg.background_color}." 318 | ) 319 | 320 | w = item["alpha_map"][..., None] / 255 321 | img = (w * fg + (1 - w) * bg).astype(np.uint8) 322 | item["rgb"] = img 323 | return item 324 | 325 | def get_property_path( 326 | self, 327 | name, 328 | index: Optional[int] = None, 329 | timestep_id: Optional[str] = None, 330 | camera_id: Optional[str] = None, 331 | ): 332 | p = self.properties[name] 333 | folder = p["folder"] if "folder" in p else None 334 | per_timestep = p["per_timestep"] 335 | suffix = p["suffix"] 336 | 337 | path = self.sequence_path 338 | if folder is not None: 339 | path = path / folder 340 | 341 | if self.num_cameras > 1: 342 | if camera_id is None: 343 | assert ( 344 | index is not None), "index is required when camera_id is not provided." 345 | camera_id = self.items[index]["camera_id"] 346 | if "cam_id_prefix" in p: 347 | camera_id = p["cam_id_prefix"] + camera_id 348 | else: 349 | camera_id = "" 350 | 351 | if per_timestep: 352 | if timestep_id is None: 353 | assert index is not None, "index is required when timestep_id is not provided." 
354 | timestep_id = self.items[index]["timestep_id"] 355 | if len(camera_id) > 0: 356 | path /= f"{camera_id}_{timestep_id}.{suffix}" 357 | else: 358 | path /= f"{timestep_id}.{suffix}" 359 | else: 360 | if len(camera_id) > 0: 361 | path /= f"{camera_id}.{suffix}" 362 | else: 363 | path = Path(str(path) + f".{suffix}") 364 | 365 | return path 366 | 367 | def get_property_path_list(self, name): 368 | paths = [] 369 | for i in range(len(self.items)): 370 | img_path = self.get_property_path(name, i) 371 | paths.append(img_path) 372 | return paths 373 | 374 | @property 375 | def num_timesteps(self): 376 | return len(self.timestep_indices) 377 | 378 | @property 379 | def num_cameras(self): 380 | return len(self.camera_ids) 381 | 382 | 383 | if __name__ == "__main__": 384 | import tyro 385 | from tqdm import tqdm 386 | from torch.utils.data import DataLoader 387 | from vhap.config.base import DataConfig, import_module 388 | 389 | cfg = tyro.cli(DataConfig) 390 | cfg.use_landmark = False 391 | dataset = import_module(cfg._target)( 392 | cfg=cfg, 393 | img_to_tensor=False, 394 | batchify_all_views=True, 395 | ) 396 | 397 | print(len(dataset)) 398 | 399 | sample = dataset[0] 400 | print(sample.keys()) 401 | print(sample["rgb"].shape) 402 | 403 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) 404 | for item in tqdm(dataloader): 405 | pass 406 | -------------------------------------------------------------------------------- /vhap/flame_editor.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | from pathlib import Path 5 | import time 6 | import dearpygui.dearpygui as dpg 7 | import numpy as np 8 | import torch 9 | 10 | from vhap.util.camera import OrbitCamera 11 | from vhap.model.flame import FlameHead 12 | from vhap.config.base import ModelConfig 13 | from vhap.util.render_nvdiffrast import NVDiffRenderer 14 | 15 | 16 | @dataclass 17 | class Config: 18 | model: ModelConfig 19 | """FLAME model configuration""" 20 | param_path: Optional[Path] = None 21 | """Path to the npz file for FLAME parameters""" 22 | W: int = 1024 23 | """GUI width""" 24 | H: int = 1024 25 | """GUI height""" 26 | radius: float = 1 27 | """default GUI camera radius from center""" 28 | fovy: float = 30 29 | """default GUI camera fovy""" 30 | background_color: tuple[float] = (1., 1., 1.) 31 | """default GUI background color""" 32 | use_opengl: bool = False 33 | """use OpenGL or CUDA rasterizer""" 34 | shade_smooth: bool = True 35 | """smooth shading or flat shading""" 36 | 37 | 38 | class FlameViewer: 39 | def __init__(self, cfg: Config): 40 | self.cfg = cfg # shared with the trainer's cfg to support in-place modification of rendering parameters. 
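For orientation, a hedged sketch of the FLAME parameter dictionary this editor manipulates; key names and tensor shapes follow `reset_flame_param` further below, while the shape/expression dimensions are assumed defaults:

```python
import torch

n_shape, n_expr = 300, 100      # assumed values of cfg.model.n_shape / cfg.model.n_expr
flame_param = {
    'shape':          torch.zeros(1, n_shape),
    'expr':           torch.zeros(1, n_expr),
    'rotation':       torch.zeros(1, 3),   # global rotation
    'neck':           torch.zeros(1, 3),   # neck joint, driven by the pose sliders
    'jaw':            torch.zeros(1, 3),   # jaw joint, driven by the pose sliders
    'eyes':           torch.zeros(1, 6),
    'translation':    torch.zeros(1, 3),
    'static_offset':  torch.zeros(1, 3),
    'dynamic_offset': torch.zeros(1, 3),
}
# forward_flame() below unpacks this dict straight into FlameHead:
#   verts, verts_cano = flame_model(**flame_param, zero_centered_at_root_node=False,
#                                   return_landmarks=False, return_verts_cano=True)
```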
41 | 42 | # flame model 43 | self.flame_model = FlameHead( 44 | cfg.model.n_shape, 45 | cfg.model.n_expr, 46 | add_teeth=True, 47 | include_lbs_color=True, 48 | ) 49 | self.reset_flame_param() 50 | 51 | # viewer settings 52 | self.W = cfg.W 53 | self.H = cfg.H 54 | self.cam = OrbitCamera(self.W, self.H, r=cfg.radius, fovy=cfg.fovy, convention="opengl") 55 | self.last_time_fresh = None 56 | self.render_mode = '-' 57 | self.selected_regions = '-' 58 | self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32) 59 | self.need_update = True # camera moved, should reset accumulation 60 | 61 | # buffers for mouse interaction 62 | self.cursor_x = None 63 | self.cursor_y = None 64 | self.drag_begin_x = None 65 | self.drag_begin_y = None 66 | self.drag_button = None 67 | 68 | # rendering settings 69 | self.mesh_renderer = NVDiffRenderer(use_opengl=cfg.use_opengl, lighting_space='camera', shade_smooth=cfg.shade_smooth) 70 | 71 | self.define_gui() 72 | 73 | def __del__(self): 74 | dpg.destroy_context() 75 | 76 | def refresh(self): 77 | dpg.set_value("_texture", self.render_buffer) 78 | 79 | if self.last_time_fresh is not None: 80 | elapsed = time.time() - self.last_time_fresh 81 | fps = 1 / elapsed 82 | dpg.set_value("_log_fps", f'{fps:.1f}') 83 | self.last_time_fresh = time.time() 84 | 85 | def define_gui(self): 86 | dpg.create_context() 87 | 88 | # register texture ================================================================================================= 89 | with dpg.texture_registry(show=False): 90 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") 91 | 92 | # register window ================================================================================================== 93 | # the window to display the rendered image 94 | with dpg.window(label="viewer", tag="_render_window", width=self.W, height=self.H, no_title_bar=True, no_move=True, no_bring_to_front_on_focus=True, no_resize=True): 95 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image") 96 | 97 | # control window ================================================================================================== 98 | with dpg.window(label="Control", tag="_control_window", autosize=True): 99 | 100 | with dpg.group(horizontal=True): 101 | dpg.add_text("FPS: ") 102 | dpg.add_text("", tag="_log_fps") 103 | 104 | # rendering options 105 | with dpg.collapsing_header(label="Render", default_open=True): 106 | 107 | def callback_set_render_mode(sender, app_data): 108 | self.render_mode = app_data 109 | self.need_update = True 110 | dpg.add_combo(('-', 'lbs weights'), label='render mode', default_value=self.render_mode, tag="_combo_render_mode", callback=callback_set_render_mode) 111 | 112 | def callback_select_regions(sender, app_data): 113 | self.selected_regions = app_data 114 | self.need_update = True 115 | dpg.add_combo(['-']+sorted(self.flame_model.mask.v.keys()), label='regions', default_value='-', tag="_combo_regions", callback=callback_select_regions) 116 | 117 | # fov slider 118 | def callback_set_fovy(sender, app_data): 119 | self.cam.fovy = app_data 120 | self.need_update = True 121 | dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy, tag="_slider_fovy") 122 | 123 | def callback_reset_camera(sender, app_data): 124 | self.cam.reset() 125 | self.need_update = True 126 | dpg.set_value("_slider_fovy", self.cam.fovy) 127 | 128 | with dpg.group(horizontal=True): 129 | 
dpg.add_button(label="reset camera", tag="_button_reset_pose", callback=callback_reset_camera) 130 | 131 | 132 | # FLAME paraemter options 133 | with dpg.collapsing_header(label="Parameters", default_open=True): 134 | 135 | def callback_set_pose(sender, app_data): 136 | joint, axis = sender.split('-')[1:3] 137 | axis_idx = {'x': 0, 'y': 1, 'z': 2}[axis] 138 | self.flame_param[joint][0, axis_idx] = app_data 139 | self.need_update = True 140 | self.pose_sliders = [] 141 | slider_width = 87 142 | for joint in ['neck', 'jaw']: 143 | dpg.add_text(f'{joint:9s}') 144 | if joint in self.flame_param: 145 | with dpg.group(horizontal=True): 146 | dpg.add_slider_float(label="x", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 0], callback=callback_set_pose, tag=f"_slider-{joint}-x", width=slider_width) 147 | dpg.add_slider_float(label="y", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 1], callback=callback_set_pose, tag=f"_slider-{joint}-y", width=slider_width) 148 | dpg.add_slider_float(label="z", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 2], callback=callback_set_pose, tag=f"_slider-{joint}-z", width=slider_width) 149 | self.pose_sliders.append(f"_slider-{joint}-x") 150 | self.pose_sliders.append(f"_slider-{joint}-y") 151 | self.pose_sliders.append(f"_slider-{joint}-z") 152 | 153 | def callback_set_expr(sender, app_data): 154 | expr_i = int(sender.split('-')[2]) 155 | self.flame_param['expr'][0, expr_i] = app_data 156 | self.need_update = True 157 | self.expr_sliders = [] 158 | dpg.add_text(f'expr') 159 | for i in range(5): 160 | dpg.add_slider_float(label=f"{i}", min_value=-5, max_value=5, format="%.2f", default_value=0, callback=callback_set_expr, tag=f"_slider-expr-{i}", width=300) 161 | self.expr_sliders.append(f"_slider-expr-{i}") 162 | 163 | def callback_reset_flame(sender, app_data): 164 | self.reset_flame_param() 165 | self.need_update = True 166 | for slider in self.pose_sliders + self.expr_sliders: 167 | dpg.set_value(slider, 0) 168 | dpg.add_button(label="reset FLAME", tag="_button_reset_flame", callback=callback_reset_flame) 169 | 170 | ### register mouse handlers ======================================================================================== 171 | 172 | def callback_mouse_move(sender, app_data): 173 | self.cursor_x, self.cursor_y = app_data 174 | if not dpg.is_item_focused("_render_window"): 175 | return 176 | 177 | if self.drag_begin_x is None or self.drag_begin_y is None: 178 | self.drag_begin_x = self.cursor_x 179 | self.drag_begin_y = self.cursor_y 180 | else: 181 | dx = self.cursor_x - self.drag_begin_x 182 | dy = self.cursor_y - self.drag_begin_y 183 | 184 | # button=dpg.mvMouseButton_Left 185 | if self.drag_button is dpg.mvMouseButton_Left: 186 | self.cam.orbit(dx, dy) 187 | self.need_update = True 188 | elif self.drag_button is dpg.mvMouseButton_Middle: 189 | self.cam.pan(dx, dy) 190 | self.need_update = True 191 | 192 | def callback_mouse_button_down(sender, app_data): 193 | if not dpg.is_item_focused("_render_window"): 194 | return 195 | self.drag_begin_x = self.cursor_x 196 | self.drag_begin_y = self.cursor_y 197 | self.drag_button = app_data[0] 198 | 199 | def callback_mouse_release(sender, app_data): 200 | self.drag_begin_x = None 201 | self.drag_begin_y = None 202 | self.drag_button = None 203 | 204 | self.dx_prev = None 205 | self.dy_prev = None 206 | 207 | def callback_mouse_drag(sender, app_data): 208 | if not dpg.is_item_focused("_render_window"): 
209 | return 210 | 211 | button, dx, dy = app_data 212 | if self.dx_prev is None or self.dy_prev is None: 213 | ddx = dx 214 | ddy = dy 215 | else: 216 | ddx = dx - self.dx_prev 217 | ddy = dy - self.dy_prev 218 | 219 | self.dx_prev = dx 220 | self.dy_prev = dy 221 | 222 | if ddx != 0 and ddy != 0: 223 | if button is dpg.mvMouseButton_Left: 224 | self.cam.orbit(ddx, ddy) 225 | self.need_update = True 226 | elif button is dpg.mvMouseButton_Middle: 227 | self.cam.pan(ddx, ddy) 228 | self.need_update = True 229 | 230 | def callback_camera_wheel_scale(sender, app_data): 231 | if not dpg.is_item_focused("_render_window"): 232 | return 233 | delta = app_data 234 | self.cam.scale(delta) 235 | self.need_update = True 236 | 237 | with dpg.handler_registry(): 238 | # this registry order helps avoid false fire 239 | dpg.add_mouse_release_handler(callback=callback_mouse_release) 240 | # dpg.add_mouse_drag_handler(callback=callback_mouse_drag) # not using the drag callback, since it does not return the starting point 241 | dpg.add_mouse_move_handler(callback=callback_mouse_move) 242 | dpg.add_mouse_down_handler(callback=callback_mouse_button_down) 243 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) 244 | 245 | # key press handlers 246 | # dpg.add_key_press_handler(dpg.mvKey_Left, callback=callback_set_current_frame, tag='_mvKey_Left') 247 | # dpg.add_key_press_handler(dpg.mvKey_Right, callback=callback_set_current_frame, tag='_mvKey_Right') 248 | # dpg.add_key_press_handler(dpg.mvKey_Home, callback=callback_set_current_frame, tag='_mvKey_Home') 249 | # dpg.add_key_press_handler(dpg.mvKey_End, callback=callback_set_current_frame, tag='_mvKey_End') 250 | 251 | def callback_viewport_resize(sender, app_data): 252 | while self.rendering: 253 | time.sleep(0.01) 254 | self.need_update = False 255 | self.W = app_data[0] 256 | self.H = app_data[1] 257 | self.cam.image_width = self.W 258 | self.cam.image_height = self.H 259 | self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32) 260 | 261 | # delete and re-add the texture and image 262 | dpg.delete_item("_texture") 263 | dpg.delete_item("_image") 264 | 265 | with dpg.texture_registry(show=False): 266 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") 267 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image", parent="_render_window") 268 | dpg.configure_item("_render_window", width=self.W, height=self.H) 269 | self.need_update = True 270 | dpg.set_viewport_resize_callback(callback_viewport_resize) 271 | 272 | ### global theme ================================================================================================== 273 | with dpg.theme() as theme_no_padding: 274 | with dpg.theme_component(dpg.mvAll): 275 | # set all padding to 0 to avoid scroll bar 276 | dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) 277 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) 278 | dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) 279 | dpg.bind_item_theme("_render_window", theme_no_padding) 280 | 281 | ### finish setup ================================================================================================== 282 | dpg.create_viewport(title='FLAME Editor', width=self.W, height=self.H, resizable=True) 283 | dpg.setup_dearpygui() 284 | dpg.show_viewport() 285 | 286 | def reset_flame_param(self): 287 | self.flame_param = { 288 | 'shape': torch.zeros(1, 
self.cfg.model.n_shape), 289 | 'expr': torch.zeros(1, self.cfg.model.n_expr), 290 | 'rotation': torch.zeros(1, 3), 291 | 'neck': torch.zeros(1, 3), 292 | 'jaw': torch.zeros(1, 3), 293 | 'eyes': torch.zeros(1, 6), 294 | 'translation': torch.zeros(1, 3), 295 | 'static_offset': torch.zeros(1, 3), 296 | 'dynamic_offset': torch.zeros(1, 3), 297 | } 298 | 299 | def forward_flame(self, flame_param): 300 | N = flame_param['expr'].shape[0] 301 | 302 | self.verts, self.verts_cano = self.flame_model( 303 | **flame_param, 304 | zero_centered_at_root_node=False, 305 | return_landmarks=False, 306 | return_verts_cano=True, 307 | ) 308 | 309 | def prepare_camera(self): 310 | @dataclass 311 | class Cam: 312 | FoVx = float(np.radians(self.cam.fovx)) 313 | FoVy = float(np.radians(self.cam.fovy)) 314 | image_height = self.cam.image_height 315 | image_width = self.cam.image_width 316 | world_view_transform = torch.tensor(self.cam.world_view_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer 317 | full_proj_transform = torch.tensor(self.cam.full_proj_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer 318 | camera_center = torch.tensor(self.cam.pose[:3, 3]).cuda() 319 | return Cam 320 | 321 | def run(self): 322 | 323 | while dpg.is_dearpygui_running(): 324 | 325 | if self.need_update: 326 | self.rendering = True 327 | 328 | with torch.no_grad(): 329 | # mesh 330 | self.forward_flame(self.flame_param) 331 | verts = self.verts.cuda() 332 | faces = self.flame_model.faces.cuda() 333 | 334 | # camera 335 | RT = torch.from_numpy(self.cam.world_view_transform).cuda()[None] 336 | K = torch.from_numpy(self.cam.intrinsics).cuda()[None] 337 | image_size = self.cam.image_height, self.cam.image_width 338 | 339 | if self.render_mode == 'lbs weights': 340 | v_color = self.flame_model.lbs_color.cuda() 341 | else: 342 | v_color = torch.ones_like(verts) 343 | 344 | if self.selected_regions != '-': 345 | vid = self.flame_model.mask.get_vid_except_region(self.selected_regions) 346 | v_color[..., vid, :] *= 0.3 347 | 348 | out_dict = self.mesh_renderer.render_rgba_vis(verts, faces, RT, K, image_size, self.cfg.background_color, v_color=v_color) 349 | 350 | rgba_mesh = out_dict['rgba'].squeeze(0).permute(2, 0, 1) # (C, W, H) 351 | rgb_mesh = rgba_mesh[:3, :, :] 352 | 353 | self.render_buffer = rgb_mesh.permute(1, 2, 0).cpu().numpy() 354 | self.refresh() 355 | 356 | self.rendering = False 357 | self.need_update = False 358 | dpg.render_dearpygui_frame() 359 | 360 | 361 | if __name__ == "__main__": 362 | cfg = tyro.cli(Config) 363 | gui = FlameViewer(cfg) 364 | gui.run() 365 | -------------------------------------------------------------------------------- /vhap/flame_viewer.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | from pathlib import Path 5 | import time 6 | import dearpygui.dearpygui as dpg 7 | import numpy as np 8 | import torch 9 | import PIL.Image 10 | 11 | from vhap.util.camera import OrbitCamera 12 | from vhap.model.flame import FlameHead 13 | from vhap.config.base import ModelConfig 14 | from vhap.util.render_nvdiffrast import NVDiffRenderer 15 | 16 | 17 | @dataclass 18 | class Config: 19 | model: ModelConfig 20 | """FLAME model configuration""" 21 | param_path: Optional[Path] = None 22 | """Path to the npz file for FLAME parameters""" 23 | tex_path: Optional[Path] = None 24 | """Path to the texture image""" 25 | W: int = 
1024 26 | """GUI width""" 27 | H: int = 1024 28 | """GUI height""" 29 | radius: float = 1 30 | """default GUI camera radius from center""" 31 | fovy: float = 30 32 | """default GUI camera fovy""" 33 | background_color: tuple[float] = (1., 1., 1.) 34 | """default GUI background color""" 35 | use_opengl: bool = False 36 | """use OpenGL or CUDA rasterizer""" 37 | shade_smooth: bool = True 38 | """smooth shading or flat shading""" 39 | 40 | 41 | class FlameViewer: 42 | def __init__(self, cfg: Config): 43 | self.cfg = cfg # shared with the trainer's cfg to support in-place modification of rendering parameters. 44 | 45 | # flame model 46 | self.flame_model = FlameHead(cfg.model.n_shape, cfg.model.n_expr, add_teeth=True).cuda() 47 | 48 | # viewer settings 49 | self.W = cfg.W 50 | self.H = cfg.H 51 | self.cam = OrbitCamera(self.W, self.H, r=cfg.radius, fovy=cfg.fovy, convention="opengl") 52 | self.last_time_fresh = None 53 | self.render_buffer = np.ones((self.W, self.H, 3), dtype=np.float32) 54 | self.need_update = True # camera moved, should reset accumulation 55 | 56 | # buffers for mouse interaction 57 | self.cursor_x = None 58 | self.cursor_y = None 59 | self.drag_begin_x = None 60 | self.drag_begin_y = None 61 | self.drag_button = None 62 | 63 | # rendering settings 64 | self.mesh_renderer = NVDiffRenderer(use_opengl=cfg.use_opengl, lighting_space='camera', shade_smooth=cfg.shade_smooth) 65 | self.num_timesteps = 1 66 | self.timestep = 0 67 | 68 | self.define_gui() 69 | 70 | def __del__(self): 71 | dpg.destroy_context() 72 | 73 | def refresh(self): 74 | dpg.set_value("_texture", self.render_buffer) 75 | 76 | if self.last_time_fresh is not None: 77 | elapsed = time.time() - self.last_time_fresh 78 | fps = 1 / elapsed 79 | dpg.set_value("_log_fps", f'{fps:.1f}') 80 | self.last_time_fresh = time.time() 81 | 82 | def define_gui(self): 83 | dpg.create_context() 84 | 85 | # register texture ================================================================================================= 86 | with dpg.texture_registry(show=False): 87 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") 88 | 89 | # register window ================================================================================================== 90 | # the window to display the rendered image 91 | with dpg.window(label="viewer", tag="_render_window", width=self.W, height=self.H, no_title_bar=True, no_move=True, no_bring_to_front_on_focus=True, no_resize=True): 92 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image") 93 | 94 | # control window ================================================================================================== 95 | with dpg.window(label="Control", tag="_control_window", autosize=True): 96 | 97 | with dpg.group(horizontal=True): 98 | dpg.add_text("FPS: ") 99 | dpg.add_text("", tag="_log_fps") 100 | 101 | # rendering options 102 | with dpg.collapsing_header(label="Render", default_open=True): 103 | 104 | # timestep slider and buttons 105 | if self.num_timesteps != None: 106 | def callback_set_current_frame(sender, app_data): 107 | if sender == "_slider_timestep": 108 | self.timestep = app_data 109 | elif sender in ["_button_timestep_plus", "_mvKey_Right"]: 110 | self.timestep = min(self.timestep + 1, self.num_timesteps - 1) 111 | elif sender in ["_button_timestep_minus", "_mvKey_Left"]: 112 | self.timestep = max(self.timestep - 1, 0) 113 | elif sender == "_mvKey_Home": 114 | self.timestep = 0 115 | elif sender == "_mvKey_End": 116 | 
self.timestep = self.num_timesteps - 1 117 | 118 | dpg.set_value("_slider_timestep", self.timestep) 119 | 120 | self.need_update = True 121 | with dpg.group(horizontal=True): 122 | dpg.add_button(label='-', tag="_button_timestep_minus", callback=callback_set_current_frame) 123 | dpg.add_button(label='+', tag="_button_timestep_plus", callback=callback_set_current_frame) 124 | dpg.add_slider_int(label="timestep", tag='_slider_timestep', width=162, min_value=0, max_value=self.num_timesteps - 1, format="%d", default_value=0, callback=callback_set_current_frame) 125 | 126 | # fov slider 127 | def callback_set_fovy(sender, app_data): 128 | self.cam.fovy = app_data 129 | self.need_update = True 130 | dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy, tag="_slider_fovy") 131 | 132 | def callback_reset_camera(sender, app_data): 133 | self.cam.reset() 134 | self.need_update = True 135 | dpg.set_value("_slider_fovy", self.cam.fovy) 136 | 137 | with dpg.group(horizontal=True): 138 | dpg.add_button(label="reset camera", tag="_button_reset_pose", callback=callback_reset_camera) 139 | 140 | 141 | ### register mouse handlers ======================================================================================== 142 | 143 | def callback_mouse_move(sender, app_data): 144 | self.cursor_x, self.cursor_y = app_data 145 | if not dpg.is_item_focused("_render_window"): 146 | return 147 | 148 | if self.drag_begin_x is None or self.drag_begin_y is None: 149 | self.drag_begin_x = self.cursor_x 150 | self.drag_begin_y = self.cursor_y 151 | else: 152 | dx = self.cursor_x - self.drag_begin_x 153 | dy = self.cursor_y - self.drag_begin_y 154 | 155 | # button=dpg.mvMouseButton_Left 156 | if self.drag_button is dpg.mvMouseButton_Left: 157 | self.cam.orbit(dx, dy) 158 | self.need_update = True 159 | elif self.drag_button is dpg.mvMouseButton_Middle: 160 | self.cam.pan(dx, dy) 161 | self.need_update = True 162 | 163 | def callback_mouse_button_down(sender, app_data): 164 | if not dpg.is_item_focused("_render_window"): 165 | return 166 | self.drag_begin_x = self.cursor_x 167 | self.drag_begin_y = self.cursor_y 168 | self.drag_button = app_data[0] 169 | 170 | def callback_mouse_release(sender, app_data): 171 | self.drag_begin_x = None 172 | self.drag_begin_y = None 173 | self.drag_button = None 174 | 175 | self.dx_prev = None 176 | self.dy_prev = None 177 | 178 | def callback_mouse_drag(sender, app_data): 179 | if not dpg.is_item_focused("_render_window"): 180 | return 181 | 182 | button, dx, dy = app_data 183 | if self.dx_prev is None or self.dy_prev is None: 184 | ddx = dx 185 | ddy = dy 186 | else: 187 | ddx = dx - self.dx_prev 188 | ddy = dy - self.dy_prev 189 | 190 | self.dx_prev = dx 191 | self.dy_prev = dy 192 | 193 | if ddx != 0 and ddy != 0: 194 | if button is dpg.mvMouseButton_Left: 195 | self.cam.orbit(ddx, ddy) 196 | self.need_update = True 197 | elif button is dpg.mvMouseButton_Middle: 198 | self.cam.pan(ddx, ddy) 199 | self.need_update = True 200 | 201 | def callback_camera_wheel_scale(sender, app_data): 202 | if not dpg.is_item_focused("_render_window"): 203 | return 204 | delta = app_data 205 | self.cam.scale(delta) 206 | self.need_update = True 207 | 208 | with dpg.handler_registry(): 209 | # this registry order helps avoid false fire 210 | dpg.add_mouse_release_handler(callback=callback_mouse_release) 211 | # dpg.add_mouse_drag_handler(callback=callback_mouse_drag) # not using the drag callback, since it does not 
return the starting point 212 | dpg.add_mouse_move_handler(callback=callback_mouse_move) 213 | dpg.add_mouse_down_handler(callback=callback_mouse_button_down) 214 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) 215 | 216 | # key press handlers 217 | dpg.add_key_press_handler(dpg.mvKey_Left, callback=callback_set_current_frame, tag='_mvKey_Left') 218 | dpg.add_key_press_handler(dpg.mvKey_Right, callback=callback_set_current_frame, tag='_mvKey_Right') 219 | dpg.add_key_press_handler(dpg.mvKey_Home, callback=callback_set_current_frame, tag='_mvKey_Home') 220 | dpg.add_key_press_handler(dpg.mvKey_End, callback=callback_set_current_frame, tag='_mvKey_End') 221 | 222 | def callback_viewport_resize(sender, app_data): 223 | while self.rendering: 224 | time.sleep(0.01) 225 | self.need_update = False 226 | self.W = app_data[0] 227 | self.H = app_data[1] 228 | self.cam.image_width = self.W 229 | self.cam.image_height = self.H 230 | self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32) 231 | 232 | # delete and re-add the texture and image 233 | dpg.delete_item("_texture") 234 | dpg.delete_item("_image") 235 | 236 | with dpg.texture_registry(show=False): 237 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") 238 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image", parent="_render_window") 239 | dpg.configure_item("_render_window", width=self.W, height=self.H) 240 | self.need_update = True 241 | dpg.set_viewport_resize_callback(callback_viewport_resize) 242 | 243 | ### global theme ================================================================================================== 244 | with dpg.theme() as theme_no_padding: 245 | with dpg.theme_component(dpg.mvAll): 246 | # set all padding to 0 to avoid scroll bar 247 | dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) 248 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) 249 | dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) 250 | dpg.bind_item_theme("_render_window", theme_no_padding) 251 | 252 | ### finish setup ================================================================================================== 253 | dpg.create_viewport(title='FLAME Sequence Viewer', width=self.W, height=self.H, resizable=True) 254 | dpg.setup_dearpygui() 255 | dpg.show_viewport() 256 | 257 | def forward_flame(self, flame_param): 258 | N = flame_param['expr'].shape[0] 259 | 260 | self.verts, self.verts_cano = self.flame_model( 261 | flame_param['shape'][None, ...].expand(N, -1).cuda(), 262 | flame_param['expr'].cuda(), 263 | flame_param['rotation'].cuda(), 264 | flame_param['neck_pose'].cuda(), 265 | flame_param['jaw_pose'].cuda(), 266 | flame_param['eyes_pose'].cuda(), 267 | flame_param['translation'].cuda(), 268 | zero_centered_at_root_node=False, 269 | return_landmarks=False, 270 | return_verts_cano=True, 271 | static_offset=flame_param['static_offset'].cuda(), 272 | # dynamic_offset=flame_param['dynamic_offset'].cuda(), 273 | ) 274 | 275 | self.num_timesteps = N 276 | dpg.configure_item("_slider_timestep", max_value=self.num_timesteps - 1) 277 | 278 | def prepare_camera(self): 279 | @dataclass 280 | class Cam: 281 | FoVx = float(np.radians(self.cam.fovx)) 282 | FoVy = float(np.radians(self.cam.fovy)) 283 | image_height = self.cam.image_height 284 | image_width = self.cam.image_width 285 | world_view_transform = 
torch.tensor(self.cam.world_view_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer 286 | full_proj_transform = torch.tensor(self.cam.full_proj_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer 287 | camera_center = torch.tensor(self.cam.pose[:3, 3]).cuda() 288 | return Cam 289 | 290 | def run(self): 291 | assert self.cfg.param_path is not None, 'param_path must be provided.' 292 | assert self.cfg.param_path.exists(), f'{self.cfg.param_path} does not exist.' 293 | self.flame_param = dict(np.load(self.cfg.param_path)) 294 | for k, v in self.flame_param.items(): 295 | if v.dtype in [np.float64, np.float32]: 296 | self.flame_param[k] = torch.from_numpy(v).float() 297 | self.forward_flame(self.flame_param) 298 | 299 | # tex_path 300 | if self.cfg.tex_path is not None: 301 | assert self.cfg.tex_path.exists(), f'{self.cfg.tex_path} does not exist.' 302 | # load the texture image with PIL and turn into pytorch tensor 303 | tex = np.array(PIL.Image.open(self.cfg.tex_path)) / 255. 304 | self.tex = torch.tensor(tex, dtype=torch.float32).permute(2, 0, 1)[None].cuda() 305 | 306 | self.verts_uv = self.flame_model.verts_uvs.clone() 307 | self.verts_uv[:, 1] = 1 - self.verts_uv[:, 1] 308 | 309 | self.faces_uv = self.flame_model.textures_idx.int() 310 | 311 | self.lights = self.flame_param['lights'].cuda()[None] 312 | 313 | while dpg.is_dearpygui_running(): 314 | 315 | if self.need_update: 316 | self.rendering = True 317 | 318 | with torch.no_grad(): 319 | RT = torch.from_numpy(self.cam.world_view_transform).cuda()[None] 320 | K = torch.from_numpy(self.cam.intrinsics).cuda()[None] 321 | image_size = self.cam.image_height, self.cam.image_width 322 | verts = self.verts[[self.timestep]] 323 | faces = self.flame_model.faces 324 | tex = self.tex if hasattr(self, 'tex') else None 325 | lights = self.lights if hasattr(self, 'lights') else None 326 | 327 | out_dict = self.mesh_renderer.render_rgba_vis( 328 | verts, faces, RT, K, image_size, self.cfg.background_color, 329 | verts_uv=self.verts_uv, faces_uv=self.faces_uv, tex=tex, 330 | lights=lights, 331 | ) 332 | 333 | rgba_mesh = out_dict['rgba'].squeeze(0).permute(2, 0, 1) # (C, W, H) 334 | rgb_mesh = rgba_mesh[:3, :, :] 335 | 336 | self.render_buffer = rgb_mesh.permute(1, 2, 0).cpu().numpy() 337 | self.refresh() 338 | 339 | self.rendering = False 340 | self.need_update = False 341 | dpg.render_dearpygui_frame() 342 | 343 | 344 | if __name__ == "__main__": 345 | cfg = tyro.cli(Config) 346 | gui = FlameViewer(cfg) 347 | gui.run() 348 | -------------------------------------------------------------------------------- /vhap/generate_flame_uvmask.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | from typing import Literal 11 | import tyro 12 | import numpy as np 13 | from PIL import Image 14 | from pathlib import Path 15 | import torch 16 | import nvdiffrast.torch as dr 17 | from vhap.util.render_uvmap import render_uvmap_vtex 18 | from vhap.model.flame import FlameHead 19 | 20 | 21 | FLAME_UV_MASK_FOLDER = "asset/flame/uv_masks" 22 | FLAME_UV_MASK_NPZ = "asset/flame/uv_masks.npz" 23 | 24 | 25 | def main( 26 | use_opengl: bool = False, 27 | device: Literal['cuda', 'cpu'] = 'cuda', 28 | ): 29 | n_shape = 300 30 | n_expr = 100 31 | print("Initializing FLAME model") 32 | flame_model = FlameHead(n_shape, n_expr, add_teeth=True) 33 | 34 | flame_model = FlameHead( 35 | n_shape, 36 | n_expr, 37 | add_teeth=True, 38 | ).cuda() 39 | 40 | faces = flame_model.faces.int().cuda() 41 | verts_uv = flame_model.verts_uvs.cuda() 42 | # verts_uv[:, 1] = 1 - verts_uv[:, 1] 43 | faces_uv = flame_model.textures_idx.int().cuda() 44 | col_idx = faces_uv 45 | 46 | # Rasterizer context 47 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext() 48 | 49 | h, w = 2048, 2048 50 | resolution = (h, w) 51 | 52 | if not Path(FLAME_UV_MASK_FOLDER).exists(): 53 | Path(FLAME_UV_MASK_FOLDER).mkdir(parents=True) 54 | 55 | # alpha_maps = {} 56 | masks = {} 57 | for region, vt_mask in flame_model.mask.vt: 58 | v_color = torch.zeros(verts_uv.shape[0], 1).to(device) # alpha channel 59 | v_color[vt_mask] = 1 60 | 61 | alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)[0] 62 | alpha = alpha.flip(0) 63 | # alpha_maps[region] = alpha.cpu().numpy() 64 | mask = (alpha > 0.5) # to avoid overlap between hair and face 65 | mask = mask.squeeze(-1).cpu().numpy() 66 | masks[region] = mask # (h, w) 67 | 68 | print(f"Saving uv mask for {region}...") 69 | # rgba = mask.expand(-1, -1, 4) # (h, w, 4) 70 | # rgb = torch.ones_like(mask).expand(-1, -1, 3) # (h, w, 3) 71 | # rgba = torch.cat([rgb, mask], dim=-1).cpu().numpy() # (h, w, 4) 72 | img = mask 73 | img = Image.fromarray((img * 255).astype(np.uint8)) 74 | img.save(Path(FLAME_UV_MASK_FOLDER) / f"{region}.png") 75 | 76 | print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}") 77 | np.savez_compressed(FLAME_UV_MASK_NPZ, **masks) 78 | 79 | 80 | if __name__ == "__main__": 81 | tyro.cli(main) -------------------------------------------------------------------------------- /vhap/model/lbs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | 24 | 25 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): 26 | """Calculates the rotation matrices for a batch of rotation vectors 27 | Parameters 28 | ---------- 29 | rot_vecs: torch.tensor Nx3 30 | array of N axis-angle vectors 31 | Returns 32 | ------- 33 | R: torch.tensor Nx3x3 34 | The rotation matrices for the given axis-angle parameters 35 | """ 36 | 37 | batch_size = rot_vecs.shape[0] 38 | device = rot_vecs.device 39 | 40 | angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) 41 | rot_dir = rot_vecs / angle 42 | 43 | cos = torch.unsqueeze(torch.cos(angle), dim=1) 44 | sin = torch.unsqueeze(torch.sin(angle), dim=1) 45 | 46 | # Bx1 arrays 47 | rx, ry, rz = torch.split(rot_dir, 1, dim=1) 48 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) 49 | 50 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) 51 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view( 52 | (batch_size, 3, 3) 53 | ) 54 | 55 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) 56 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) 57 | return rot_mat 58 | 59 | 60 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): 61 | """Calculates landmarks by barycentric interpolation 62 | 63 | Parameters 64 | ---------- 65 | vertices: torch.tensor BxVx3, dtype = torch.float32 66 | The tensor of input vertices 67 | faces: torch.tensor Fx3, dtype = torch.long 68 | The faces of the mesh 69 | lmk_faces_idx: torch.tensor L, dtype = torch.long 70 | The tensor with the indices of the faces used to calculate the 71 | landmarks. 
72 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 73 | The tensor of barycentric coordinates that are used to interpolate 74 | the landmarks 75 | 76 | Returns 77 | ------- 78 | landmarks: torch.tensor BxLx3, dtype = torch.float32 79 | The coordinates of the landmarks for each mesh in the batch 80 | """ 81 | # Extract the indices of the vertices for each face 82 | # BxLx3 83 | batch_size, num_verts = vertices.shape[:2] 84 | device = vertices.device 85 | 86 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( 87 | batch_size, -1, 3 88 | ) 89 | 90 | lmk_faces += ( 91 | torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1) 92 | * num_verts 93 | ) 94 | 95 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3) 96 | 97 | landmarks = torch.einsum("blfi,blf->bli", [lmk_vertices, lmk_bary_coords]) 98 | return landmarks 99 | 100 | 101 | def lbs( 102 | pose, 103 | v_shaped, 104 | posedirs, 105 | J_regressor, 106 | parents, 107 | lbs_weights, 108 | pose2rot=True, 109 | dtype=torch.float32, 110 | ): 111 | """Performs Linear Blend Skinning with the given shape and pose parameters 112 | 113 | Parameters 114 | ---------- 115 | betas : torch.tensor BxNB 116 | The tensor of shape parameters 117 | pose : torch.tensor Bx(J + 1) * 3 118 | The pose parameters in axis-angle format 119 | v_template: torch.tensor BxVx3 120 | The template mesh that will be deformed 121 | shapedirs : torch.tensor 1xNB 122 | The tensor of PCA shape displacements 123 | posedirs : torch.tensor Px(V * 3) 124 | The pose PCA coefficients 125 | J_regressor : torch.tensor JxV 126 | The regressor array that is used to calculate the joints from 127 | the position of the vertices 128 | parents: torch.tensor J 129 | The array that describes the kinematic tree for the model 130 | lbs_weights: torch.tensor N x V x (J + 1) 131 | The linear blend skinning weights that represent how much the 132 | rotation matrix of each part affects each vertex 133 | pose2rot: bool, optional 134 | Flag on whether to convert the input pose tensor to rotation 135 | matrices. The default value is True. If False, then the pose tensor 136 | should already contain rotation matrices and have a size of 137 | Bx(J + 1)x9 138 | dtype: torch.dtype, optional 139 | 140 | Returns 141 | ------- 142 | verts: torch.tensor BxVx3 143 | The vertices of the mesh after applying the shape and pose 144 | displacements. 145 | joints: torch.tensor BxJx3 146 | The joints of the model 147 | """ 148 | 149 | batch_size = pose.shape[0] 150 | device = pose.device 151 | 152 | # Get the joints 153 | # NxJx3 array 154 | J = vertices2joints(J_regressor, v_shaped) 155 | 156 | # 3. Add pose blend shapes 157 | # N x J x 3 x 3 158 | ident = torch.eye(3, dtype=dtype, device=device) 159 | if pose2rot: 160 | rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view( 161 | [batch_size, -1, 3, 3] 162 | ) 163 | 164 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 165 | # (N x P) x (P, V * 3) -> N x V x 3 166 | pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3) 167 | else: 168 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 169 | rot_mats = pose.view(batch_size, -1, 3, 3) 170 | 171 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view( 172 | batch_size, -1, 3 173 | ) 174 | 175 | v_posed = pose_offsets + v_shaped 176 | 177 | # 4. 
Get the global joint location 178 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 179 | 180 | # 5. Do skinning: 181 | # W is N x V x (J + 1) 182 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) 183 | # (N x V x (J + 1)) x (N x (J + 1) x 16) 184 | num_joints = J_regressor.shape[0] 185 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4) 186 | 187 | homogen_coord = torch.ones( 188 | [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device 189 | ) 190 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) 191 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) 192 | 193 | verts = v_homo[:, :, :3, 0] 194 | 195 | return verts, J_transformed, A[:, 1] 196 | 197 | 198 | def vertices2joints(J_regressor, vertices): 199 | """Calculates the 3D joint locations from the vertices 200 | 201 | Parameters 202 | ---------- 203 | J_regressor : torch.tensor JxV 204 | The regressor array that is used to calculate the joints from the 205 | position of the vertices 206 | vertices : torch.tensor BxVx3 207 | The tensor of mesh vertices 208 | 209 | Returns 210 | ------- 211 | torch.tensor BxJx3 212 | The location of the joints 213 | """ 214 | 215 | return torch.einsum("bik,ji->bjk", [vertices, J_regressor]) 216 | 217 | 218 | def blend_shapes(betas, shape_disps): 219 | """Calculates the per vertex displacement due to the blend shapes 220 | 221 | 222 | Parameters 223 | ---------- 224 | betas : torch.tensor Bx(num_betas) 225 | Blend shape coefficients 226 | shape_disps: torch.tensor Vx3x(num_betas) 227 | Blend shapes 228 | 229 | Returns 230 | ------- 231 | torch.tensor BxVx3 232 | The per-vertex displacement due to shape deformation 233 | """ 234 | 235 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] 236 | # i.e. Multiply each shape displacement by its corresponding beta and 237 | # then sum them. 
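# For intuition: the einsum below computes
#     blend_shape[b, m, k] = sum_l betas[b, l] * shape_disps[m, k, l],
# i.e. with betas of shape (B, num_betas) and shape_disps of shape (V, 3, num_betas)
# it is equivalent in effect to the explicit (but slower) loop sketched here:
#     blend_shape = torch.zeros(betas.shape[0], shape_disps.shape[0], 3,
#                               dtype=betas.dtype, device=betas.device)
#     for l in range(betas.shape[1]):
#         blend_shape += betas[:, l, None, None] * shape_disps[None, :, :, l]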
238 | blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps]) 239 | return blend_shape 240 | 241 | 242 | def transform_mat(R, t): 243 | """Creates a batch of transformation matrices 244 | Args: 245 | - R: Bx3x3 array of a batch of rotation matrices 246 | - t: Bx3x1 array of a batch of translation vectors 247 | Returns: 248 | - T: Bx4x4 Transformation matrix 249 | """ 250 | # No padding left or right, only add an extra row 251 | return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2) 252 | 253 | 254 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): 255 | """ 256 | Applies a batch of rigid transformations to the joints 257 | 258 | Parameters 259 | ---------- 260 | rot_mats : torch.tensor BxNx3x3 261 | Tensor of rotation matrices 262 | joints : torch.tensor BxNx3 263 | Locations of joints 264 | parents : torch.tensor BxN 265 | The kinematic tree of each object 266 | dtype : torch.dtype, optional: 267 | The data type of the created tensors, the default is torch.float32 268 | 269 | Returns 270 | ------- 271 | posed_joints : torch.tensor BxNx3 272 | The locations of the joints after applying the pose rotations 273 | rel_transforms : torch.tensor BxNx4x4 274 | The relative (with respect to the root joint) rigid transformations 275 | for all the joints 276 | """ 277 | 278 | joints = torch.unsqueeze(joints, dim=-1) 279 | 280 | rel_joints = joints.clone().contiguous() 281 | rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]] 282 | 283 | transforms_mat = transform_mat(rot_mats.view(-1, 3, 3), rel_joints.view(-1, 3, 1)) 284 | transforms_mat = transforms_mat.view(-1, joints.shape[1], 4, 4) 285 | 286 | transform_chain = [transforms_mat[:, 0]] 287 | for i in range(1, parents.shape[0]): 288 | # Subtract the joint location at the rest pose 289 | # No need for rotation, since it's identity when at rest 290 | curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i]) 291 | transform_chain.append(curr_res) 292 | 293 | transforms = torch.stack(transform_chain, dim=1) 294 | 295 | # The last column of the transformations contains the posed joints 296 | posed_joints = transforms[:, :, :3, 3] 297 | 298 | joints_homogen = F.pad(joints, [0, 0, 0, 1]) 299 | 300 | rel_transforms = transforms - F.pad( 301 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0] 302 | ) 303 | 304 | return posed_joints, rel_transforms 305 | -------------------------------------------------------------------------------- /vhap/preprocess_video.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | from typing import Literal, Optional, List 4 | import tyro 5 | import ffmpeg 6 | from PIL import Image 7 | import torch 8 | from vhap.data.image_folder_dataset import ImageFolderDataset 9 | from torch.utils.data import DataLoader 10 | from BackgroundMattingV2.model import MattingRefine 11 | from BackgroundMattingV2.asset import get_weights_path 12 | 13 | 14 | def video2frames(video_path: Path, image_dir: Path, keep_video_name: bool=False, target_fps: int=30, n_downsample: int=1): 15 | print(f'Converting video {video_path} to frames with downsample scale {n_downsample}') 16 | if not image_dir.exists(): 17 | image_dir.mkdir(parents=True) 18 | 19 | file_path_stem = video_path.stem + '_' if keep_video_name else '' 20 | 21 | probe = ffmpeg.probe(str(video_path)) 22 | 23 | video_fps = int(probe['streams'][0]['r_frame_rate'].split('/')[0]) 24 | if video_fps ==0: 25 
| video_fps = int(probe['streams'][0]['avg_frame_rate'].split('/')[0]) 26 | if video_fps == 0: 27 | # nb_frames / duration 28 | video_fps = int(probe['streams'][0]['nb_frames']) / float(probe['streams'][0]['duration']) 29 | if video_fps == 0: 30 | raise ValueError('Cannot get valid video fps') 31 | 32 | num_frames = int(probe['streams'][0]['nb_frames']) 33 | video = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None) 34 | W = int(video['width']) 35 | H = int(video['height']) 36 | w = W // n_downsample 37 | h = H // n_downsample 38 | print(f'[Video] FPS: {video_fps} | number of frames: {num_frames} | resolution: {W}x{H}') 39 | print(f'[Target] FPS: {target_fps} | number of frames: {round(num_frames * target_fps / int(video_fps))} | resolution: {w}x{h}') 40 | 41 | (ffmpeg 42 | .input(str(video_path)) 43 | .filter('fps', fps=f'{target_fps}') 44 | .filter('scale', width=w, height=h) 45 | .output( 46 | str(image_dir / f'{file_path_stem}%06d.jpg'), 47 | start_number=0, 48 | qscale=1, # lower values mean higher quality (1 is the best, 31 is the worst). 49 | ) 50 | .overwrite_output() 51 | .run(quiet=True) 52 | ) 53 | 54 | def robust_video_matting(image_dir: Path, N_warmup: Optional[int]=10): 55 | print(f'Running robust video matting on images in {image_dir}') 56 | # model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda() 57 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").cuda() 58 | 59 | dataset = ImageFolderDataset(image_folder=image_dir) 60 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 61 | 62 | # bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background. 63 | rec = [None] * 4 # Initial recurrent states. 64 | downsample_ratio = 0.5 #(for videos in 512x512) 65 | # downsample_ratio = 0.125 #(for videos in 3k) 66 | for item in tqdm(dataloader): 67 | rgb = item['rgb'] 68 | rgb = rgb.permute(0, 3, 1, 2).float().cuda() / 255 69 | with torch.no_grad(): 70 | while N_warmup: 71 | # use the first frame to warm up the recurrent states as if the video 72 | # has N_warmup identical frames at the beginning. This trick effectively 73 | # removes the artifacts for the first frame. 74 | fgr, pha, *rec = model(rgb, *rec, downsample_ratio) # Cycle the recurrent states. 75 | N_warmup -= 1 76 | 77 | fgr, pha, *rec = model(rgb, *rec, downsample_ratio) # Cycle the recurrent states. 78 | # fgr, pha, *rec = model(rgb, *rec) # Cycle the recurrent states. 
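# `pha` is the per-frame alpha matte in [0, 1] (indexed as pha[0, 0] below to obtain an
# H x W map) and `fgr` the predicted foreground; `rec` holds the four recurrent states
# that are cycled from frame to frame. The commented-out line below shows how a visual
# preview could be composited over the (commented-out) solid background `bgr` above.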
79 | # fg = fgr * pha + bgr * (1 - pha) 80 | 81 | alpha = (pha[0, 0] * 255).cpu().numpy() 82 | alpha = Image.fromarray(alpha.astype('uint8')) 83 | alpha_path = item['image_path'][0].replace('images', 'alpha_maps') 84 | if not Path(alpha_path).parent.exists(): 85 | Path(alpha_path).parent.mkdir(parents=True) 86 | alpha.save(alpha_path) 87 | 88 | def background_matting_v2( 89 | image_dir: Path, 90 | background_folder: Path=Path('../../BACKGROUND'), 91 | model_backbone: Literal['resnet101', 'resnet50', 'mobilenetv2']='resnet101', 92 | model_backbone_scale: float=0.25, 93 | model_refine_mode: Literal['full', 'sampling', 'thresholding']='thresholding', 94 | model_refine_sample_pixels: int=80_000, 95 | model_refine_threshold: float=0.01, 96 | model_refine_kernel_size: int=3, 97 | ): 98 | model = MattingRefine( 99 | model_backbone, 100 | model_backbone_scale, 101 | model_refine_mode, 102 | model_refine_sample_pixels, 103 | model_refine_threshold, 104 | model_refine_kernel_size 105 | ) 106 | 107 | weights_path = get_weights_path(model_backbone) 108 | 109 | model = model.cuda().eval() 110 | model.load_state_dict(torch.load(weights_path, map_location='cuda', weights_only=True)) 111 | 112 | dataset = ImageFolderDataset( 113 | image_folder=image_dir, 114 | background_folder=background_folder, 115 | background_fname2camId=lambda x: x.split('.')[0].split('_')[1], # image_00001.jpg -> 00001 116 | image_fname2camId=lambda x: x.split('.')[0].split('_')[1], # cam_00001.jpg -> 00001 117 | ) 118 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1) 119 | 120 | for item in tqdm(dataloader): 121 | src = item['rgb'] 122 | bgr = item['background'] 123 | src = src.permute(0, 3, 1, 2).float().cuda() / 255 124 | bgr = bgr.permute(0, 3, 1, 2).float().cuda() / 255 125 | 126 | with torch.no_grad(): 127 | pha, fgr, _, _, err, ref = model(src, bgr) 128 | 129 | alpha = (pha[0, 0] * 255).cpu().numpy() 130 | alpha = Image.fromarray(alpha.astype('uint8')) 131 | alpha_path = item['image_path'][0].replace('images', 'alpha_maps') 132 | if not Path(alpha_path).parent.exists(): 133 | Path(alpha_path).parent.mkdir(parents=True) 134 | alpha.save(alpha_path) 135 | 136 | def downsample_frames(image_dir: Path, n_downsample: int): 137 | print(f'Downsample frames in {image_dir} by {n_downsample}') 138 | assert n_downsample in [2, 4, 8] 139 | 140 | image_paths = sorted(list(image_dir.glob('*.jpg'))) 141 | for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)): 142 | # downasample the resolution of images 143 | img = Image.open(image_path) 144 | W, H = img.size 145 | img = img.resize((W // n_downsample, H // n_downsample)) 146 | img.save(image_path) 147 | 148 | def main( 149 | input: Path, 150 | target_fps: int=25, 151 | downsample_scales: List[int]=[], 152 | matting_method: Optional[Literal['robust_video_matting', 'background_matting_v2']]=None, 153 | background_folder: Path=Path('../../BACKGROUND'), 154 | ): 155 | if not input.exists(): 156 | matched_paths = list(input.parent.glob(f"{input.name}")) 157 | if len(matched_paths) == 0: 158 | raise FileNotFoundError(f"Cannot find the directory: {input}") 159 | elif len(matched_paths) == 1: 160 | input = matched_paths[0] 161 | else: 162 | raise FileNotFoundError(f"Found multiple matched folders: {matched_paths}") 163 | 164 | # prepare path 165 | if input.suffix in ['.mov', '.mp4']: 166 | print(f'Processing video file: {input}') 167 | videos = [input] 168 | image_dir = input.parent / input.stem / 'images' 169 | elif input.is_dir(): 170 | # if input 
is a directory, assume all contained videos are synchronized multiview of the same scene 171 | print(f'Processing directory: {input}') 172 | videos = list(input.glob('cam_*.mp4')) + list(input.glob('images/cam_*.mp4')) 173 | image_dir = input / 'images' 174 | else: 175 | raise ValueError(f"Input should be a video file or a directory containing video files: {input}") 176 | assert len(videos) > 0, f'No video files found in {input}' 177 | 178 | # extract frames 179 | for i, video_path in enumerate(videos): 180 | print(f'\n[{i}/{len(videos)}] Processing video file: {video_path}') 181 | 182 | for n_downsample in [1] + downsample_scales: 183 | image_dir_ = image_dir if n_downsample == 1 else Path(str(image_dir) + f'_{n_downsample}') 184 | video2frames(video_path, image_dir_, keep_video_name=len(videos) > 1, target_fps=target_fps, n_downsample=n_downsample) 185 | 186 | # foreground matting 187 | if matting_method == 'robust_video_matting': 188 | robust_video_matting(image_dir) 189 | elif matting_method == 'background_matting_v2': 190 | background_matting_v2(image_dir, background_folder=background_folder) 191 | elif matting_method is not None: 192 | raise ValueError(f'Unknown matting method: {matting_method}') 193 | 194 | 195 | if __name__ == '__main__': 196 | tyro.cli(main) -------------------------------------------------------------------------------- /vhap/track.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import tyro 11 | 12 | from vhap.config.base import BaseTrackingConfig 13 | from vhap.model.tracker import GlobalTracker 14 | 15 | 16 | if __name__ == "__main__": 17 | tyro.extras.set_accent_color("bright_yellow") 18 | cfg = tyro.cli(BaseTrackingConfig) 19 | 20 | tracker = GlobalTracker(cfg) 21 | tracker.optimize() 22 | -------------------------------------------------------------------------------- /vhap/track_nersemble.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import tyro 11 | 12 | from vhap.config.nersemble import NersembleTrackingConfig 13 | from vhap.model.tracker import GlobalTracker 14 | 15 | 16 | if __name__ == "__main__": 17 | tyro.extras.set_accent_color("bright_yellow") 18 | cfg = tyro.cli(NersembleTrackingConfig) 19 | 20 | tracker = GlobalTracker(cfg) 21 | tracker.optimize() 22 | -------------------------------------------------------------------------------- /vhap/track_nersemble_v2.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 
4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import tyro 11 | 12 | from vhap.config.nersemble_v2 import NersembleV2TrackingConfig 13 | from vhap.model.tracker import GlobalTracker 14 | 15 | 16 | if __name__ == "__main__": 17 | tyro.extras.set_accent_color("bright_yellow") 18 | cfg = tyro.cli(NersembleV2TrackingConfig) 19 | 20 | tracker = GlobalTracker(cfg) 21 | tracker.optimize() 22 | -------------------------------------------------------------------------------- /vhap/util/camera.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from typing import Tuple, Literal 11 | import torch 12 | import torch.nn.functional as F 13 | import math 14 | import numpy as np 15 | from scipy.spatial.transform import Rotation 16 | 17 | 18 | def align_cameras_to_axes( 19 | R: torch.Tensor, 20 | T: torch.Tensor, 21 | target_convention: Literal["opengl", "opencv"] = None, 22 | ): 23 | """align the averaged axes of cameras with the world axes. 24 | 25 | Args: 26 | R: rotation matrix (N, 3, 3) 27 | T: translation vector (N, 3) 28 | """ 29 | # The column vectors of R are the basis vectors of each camera. 30 | # We construct new bases by taking the mean directions of axes, then use Gram-Schmidt 31 | # process to make them orthonormal 32 | bases_c2w = gram_schmidt_orthogonalization(R.mean(0)) 33 | if target_convention == "opengl": 34 | bases_c2w[:, [1, 2]] *= -1 # flip y and z axes 35 | elif target_convention == "opencv": 36 | pass 37 | bases_w2c = bases_c2w.t() 38 | 39 | # convert the camera poses into the new coordinate system 40 | R = bases_w2c[None, ...] @ R 41 | T = bases_w2c[None, ...] @ T 42 | return R, T 43 | 44 | 45 | def convert_camera_convention(camera_convention_conversion: str, R: torch.Tensor, K: torch.Tensor, H: int, W: int): 46 | if camera_convention_conversion is not None: 47 | if camera_convention_conversion == "opencv->opengl": 48 | R[:, :3, [1, 2]] *= -1 49 | # flip y of the principal point 50 | K[..., 1, 2] = H - K[..., 1, 2] 51 | elif camera_convention_conversion == "opencv->pytorch3d": 52 | R[:, :3, [0, 1]] *= -1 53 | # flip x and y of the principal point 54 | K[..., 0, 2] = W - K[..., 0, 2] 55 | K[..., 1, 2] = H - K[..., 1, 2] 56 | elif camera_convention_conversion == "opengl->pytorch3d": 57 | R[:, :3, [0, 2]] *= -1 58 | # flip x of the principal point 59 | K[..., 0, 2] = W - K[..., 0, 2] 60 | else: 61 | raise ValueError( 62 | f"Unknown camera coordinate conversion: {camera_convention_conversion}." 
63 | ) 64 | return R, K 65 | 66 | 67 | def gram_schmidt_orthogonalization(M: torch.tensor): 68 | """conducting Gram-Schmidt process to transform column vectors into orthogonal bases 69 | 70 | Args: 71 | M: An matrix (num_rows, num_cols) 72 | Return: 73 | M: An matrix with orthonormal column vectors (num_rows, num_cols) 74 | """ 75 | num_rows, num_cols = M.shape 76 | for c in range(1, num_cols): 77 | M[:, [c - 1, c]] = F.normalize(M[:, [c - 1, c]], p=2, dim=0) 78 | M[:, [c]] -= M[:, :c] @ (M[:, :c].T @ M[:, [c]]) 79 | 80 | M[:, -1] = F.normalize(M[:, -1], p=2, dim=0) 81 | return M 82 | 83 | 84 | def projection_from_intrinsics(K: np.ndarray, image_size: Tuple[int], near: float=0.01, far:float=10, flip_y: bool=False, z_sign=-1): 85 | """ 86 | Transform points from camera space (x: right, y: up, z: out) to clip space (x: right, y: down, z: in) 87 | Args: 88 | K: Intrinsic matrix, (N, 3, 3) 89 | K = [[ 90 | [fx, 0, cx], 91 | [0, fy, cy], 92 | [0, 0, 1], 93 | ] 94 | ] 95 | image_size: (height, width) 96 | Output: 97 | proj = [[ 98 | [2*fx/w, 0.0, (w - 2*cx)/w, 0.0 ], 99 | [0.0, 2*fy/h, (h - 2*cy)/h, 0.0 ], 100 | [0.0, 0.0, z_sign*(far+near) / (far-near), -2*far*near / (far-near)], 101 | [0.0, 0.0, z_sign, 0.0 ] 102 | ] 103 | ] 104 | """ 105 | 106 | B = K.shape[0] 107 | h, w = image_size 108 | 109 | if K.shape[-2:] == (3, 3): 110 | fx = K[..., 0, 0] 111 | fy = K[..., 1, 1] 112 | cx = K[..., 0, 2] 113 | cy = K[..., 1, 2] 114 | elif K.shape[-1] == 4: 115 | # fx, fy, cx, cy = K[..., [0, 1, 2, 3]].split(1, dim=-1) 116 | fx = K[..., [0]] 117 | fy = K[..., [1]] 118 | cx = K[..., [2]] 119 | cy = K[..., [3]] 120 | else: 121 | raise ValueError(f"Expected K to be (N, 3, 3) or (N, 4) but got: {K.shape}") 122 | 123 | proj = np.zeros([B, 4, 4]) 124 | proj[:, 0, 0] = fx * 2 / w 125 | proj[:, 1, 1] = fy * 2 / h 126 | proj[:, 0, 2] = (w - 2 * cx) / w 127 | proj[:, 1, 2] = (h - 2 * cy) / h 128 | proj[:, 2, 2] = z_sign * (far+near) / (far-near) 129 | proj[:, 2, 3] = -2*far*near / (far-near) 130 | proj[:, 3, 2] = z_sign 131 | 132 | if flip_y: 133 | proj[:, 1, 1] *= -1 134 | return proj 135 | 136 | 137 | class OrbitCamera: 138 | def __init__(self, W, H, r=2, fovy=60, znear=1e-8, zfar=10, convention: Literal["opengl", "opencv"]="opengl"): 139 | self.image_width = W 140 | self.image_height = H 141 | self.radius_default = r 142 | self.fovy_default = fovy 143 | self.znear = znear 144 | self.zfar = zfar 145 | self.convention = convention 146 | 147 | self.up = np.array([0, 1, 0], dtype=np.float32) 148 | self.reset() 149 | 150 | def reset(self): 151 | """ The internal state of the camera is based on the OpenGL convention, but 152 | properties are converted to the target convention when queried. 
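Under the OpenGL convention the camera looks along its -z axis with +y pointing up,
while under the OpenCV convention it looks along +z with +y pointing down; the
y_sign and z_sign values chosen below follow this distinction.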
153 | """ 154 | self.rot = Rotation.from_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) # OpenGL convention 155 | self.look_at = np.array([0, 0, 0], dtype=np.float32) # look at this point 156 | self.radius = self.radius_default # camera distance from center 157 | self.fovy = self.fovy_default 158 | if self.convention == "opencv": 159 | self.z_sign = 1 160 | self.y_sign = 1 161 | elif self.convention == "opengl": 162 | self.z_sign = -1 163 | self.y_sign = -1 164 | else: 165 | raise ValueError(f"Unknown convention: {self.convention}") 166 | 167 | @property 168 | def fovx(self): 169 | return self.fovy / self.image_height * self.image_width 170 | 171 | @property 172 | def intrinsics(self): 173 | focal = self.image_height / (2 * np.tan(np.radians(self.fovy) / 2)) 174 | return np.array([focal, focal, self.image_width // 2, self.image_height // 2]) 175 | 176 | @property 177 | def projection_matrix(self): 178 | return projection_from_intrinsics(self.intrinsics[None], (self.image_height, self.image_width), self.znear, self.zfar, z_sign=self.z_sign)[0] 179 | 180 | @property 181 | def world_view_transform(self): 182 | return np.linalg.inv(self.pose) # world2cam 183 | 184 | @property 185 | def full_proj_transform(self): 186 | return self.projection_matrix @ self.world_view_transform 187 | 188 | @property 189 | def pose(self): 190 | # first move camera to radius 191 | pose = np.eye(4, dtype=np.float32) 192 | pose[2, 3] += self.radius 193 | 194 | # rotate 195 | rot = np.eye(4, dtype=np.float32) 196 | rot[:3, :3] = self.rot.as_matrix() 197 | pose = rot @ pose 198 | 199 | # translate 200 | pose[:3, 3] -= self.look_at 201 | 202 | if self.convention == "opencv": 203 | pose[:, [1, 2]] *= -1 204 | elif self.convention == "opengl": 205 | pass 206 | else: 207 | raise ValueError(f"Unknown convention: {self.convention}") 208 | return pose 209 | 210 | def orbit(self, dx, dy): 211 | # rotate along camera up/side axis! 212 | side = self.rot.as_matrix()[:3, 0] 213 | rotvec_x = self.up * np.radians(-0.3 * dx) 214 | rotvec_y = side * np.radians(-0.3 * dy) 215 | self.rot = Rotation.from_rotvec(rotvec_x) * Rotation.from_rotvec(rotvec_y) * self.rot 216 | 217 | def scale(self, delta): 218 | self.radius *= 1.1 ** (-delta) 219 | 220 | def pan(self, dx, dy, dz=0): 221 | # pan in camera coordinate system (careful on the sensitivity!) 
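# A drag of (dx, dy) pixels translates the look-at point within the camera's x-y
# plane; the scale factor 2 * radius * tan(fovy / 2) / image_height converts one
# pixel of screen-space drag into roughly one pixel's worth of world-space motion
# at the orbit distance `radius` (dy is negated because screen y grows downward).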
222 | d = np.array([dx, -dy, dz]) # the y axis is flipped 223 | self.look_at += 2 * self.rot.as_matrix()[:3, :3] @ d * self.radius / self.image_height * math.tan(np.radians(self.fovy) / 2) 224 | -------------------------------------------------------------------------------- /vhap/util/color_correction.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/tobias-kirschstein/nersemble-data/blob/f96aa8d9d482df53c40c51ecc07203646265e4f0/src/nersemble_data/util/color_correction.py 2 | 3 | import colour 4 | import numpy as np 5 | from colour.characterisation import matrix_augmented_Cheung2004 6 | from colour.utilities import as_float_array 7 | 8 | 9 | def color_correction_Cheung2004_precomputed( 10 | image: np.ndarray, 11 | CCM: np.ndarray, 12 | ) -> np.ndarray: 13 | terms = CCM.shape[-1] 14 | RGB = as_float_array(image) 15 | shape = RGB.shape 16 | 17 | RGB = np.reshape(RGB, (-1, 3)) 18 | 19 | RGB_e = matrix_augmented_Cheung2004(RGB, terms) 20 | 21 | return np.reshape(np.transpose(np.dot(CCM, np.transpose(RGB_e))), shape) 22 | 23 | 24 | def correct_color(image: np.ndarray, ccm: np.ndarray) -> np.ndarray: 25 | is_uint8 = image.dtype == np.uint8 26 | if is_uint8: 27 | image = image / 255. 28 | image_linear = colour.cctf_decoding(image) 29 | image_corrected = color_correction_Cheung2004_precomputed(image_linear, ccm) 30 | image_corrected = colour.cctf_encoding(image_corrected) 31 | if is_uint8: 32 | image_corrected = np.clip(image_corrected * 255, 0, 255).astype(np.uint8) 33 | 34 | return image_corrected 35 | -------------------------------------------------------------------------------- /vhap/util/landmark_detector_fa.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | from vhap.util.log import get_logger 11 | 12 | from typing import Literal 13 | from tqdm import tqdm 14 | 15 | import face_alignment 16 | import numpy as np 17 | import matplotlib.path as mpltPath 18 | 19 | from fdlite import ( 20 | FaceDetection, 21 | FaceLandmark, 22 | face_detection_to_roi, 23 | IrisLandmark, 24 | iris_roi_from_face_landmarks, 25 | ) 26 | 27 | logger = get_logger(__name__) 28 | 29 | 30 | class LandmarkDetectorFA: 31 | 32 | IMAGE_FILE_NAME = "image_0000.png" 33 | LMK_FILE_NAME = "keypoints_static_0000.json" 34 | 35 | def __init__( 36 | self, 37 | face_detector:Literal["sfd", "blazeface"]="sfd", 38 | ): 39 | """ 40 | Creates dataset_path where all results are stored 41 | :param video_path: path to video file 42 | :param dataset_path: path to results directory 43 | """ 44 | 45 | logger.info("Initialize FaceAlignment module...") 46 | # 68 facial landmark detector 47 | self.fa = face_alignment.FaceAlignment( 48 | face_alignment.LandmarksType.TWO_HALF_D, 49 | face_detector=face_detector, 50 | flip_input=True, 51 | device="cuda" 52 | ) 53 | 54 | def detect_single_image(self, img): 55 | bbox = self.fa.face_detector.detect_from_image(img) 56 | 57 | if len(bbox) == 0: 58 | lmks = np.zeros([68, 3]) - 1 # set to -1 when landmarks is inavailable 59 | 60 | else: 61 | if len(bbox) > 1: 62 | # if multiple boxes detected, use the one with highest confidence 63 | bbox = [bbox[np.argmax(np.array(bbox)[:, -1])]] 64 | 65 | lmks = self.fa.get_landmarks_from_image(img, detected_faces=bbox)[0] 66 | lmks = np.concatenate([lmks, np.ones_like(lmks[:, :1])], axis=1) 67 | 68 | if (lmks[:, :2] == -1).sum() > 0: 69 | lmks[:, 2:] = 0.0 70 | else: 71 | lmks[:, 2:] = 1.0 72 | 73 | h, w = img.shape[:2] 74 | lmks[:, 0] /= w 75 | lmks[:, 1] /= h 76 | bbox[0][[0, 2]] /= w 77 | bbox[0][[1, 3]] /= h 78 | return bbox, lmks 79 | 80 | def detect_dataset(self, dataloader): 81 | """ 82 | Annotates each frame with 68 facial landmarks 83 | :return: dict mapping frame number to landmarks numpy array and the same thing for bboxes 84 | """ 85 | landmarks = {} 86 | bboxes = {} 87 | 88 | logger.info("Begin annotating landmarks...") 89 | for item in tqdm(dataloader): 90 | timestep_id = item["timestep_id"][0] 91 | camera_id = item["camera_id"][0] 92 | scale_factor = item["scale_factor"][0] 93 | 94 | logger.info( 95 | f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}" 96 | ) 97 | img = item["rgb"][0].numpy() 98 | 99 | bbox, lmks = self.detect_single_image(img) 100 | 101 | if len(bbox) == 0: 102 | logger.error( 103 | f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1." 
104 | ) 105 | 106 | if camera_id not in landmarks: 107 | landmarks[camera_id] = {} 108 | if camera_id not in bboxes: 109 | bboxes[camera_id] = {} 110 | landmarks[camera_id][timestep_id] = lmks 111 | bboxes[camera_id][timestep_id] = bbox[0] if len(bbox) > 0 else np.zeros(5) - 1 112 | return landmarks, bboxes 113 | 114 | def annotate_iris_landmarks(self, dataloader): 115 | """ 116 | Annotates each frame with 2 iris landmarks 117 | :return: dict mapping frame number to landmarks numpy array 118 | """ 119 | 120 | # iris detector 121 | detect_faces = FaceDetection() 122 | detect_face_landmarks = FaceLandmark() 123 | detect_iris_landmarks = IrisLandmark() 124 | 125 | landmarks = {} 126 | 127 | for item in tqdm(dataloader): 128 | timestep_id = item["timestep_id"][0] 129 | camera_id = item["camera_id"][0] 130 | scale_factor = item["scale_factor"][0] 131 | if timestep_id not in landmarks: 132 | landmarks[timestep_id] = {} 133 | logger.info( 134 | f"Annotate iris landmarks for timestep: {timestep_id}, camera: {camera_id}" 135 | ) 136 | 137 | img = item["rgb"][0].numpy() 138 | 139 | height, width = img.shape[:2] 140 | img_size = (width, height) 141 | 142 | face_detections = detect_faces(img) 143 | if len(face_detections) != 1: 144 | logger.error("Empty iris landmarks (type 1)") 145 | landmarks[timestep_id][camera_id] = None 146 | else: 147 | for face_detection in face_detections: 148 | try: 149 | face_roi = face_detection_to_roi(face_detection, img_size) 150 | except ValueError: 151 | logger.error("Empty iris landmarks (type 2)") 152 | landmarks[timestep_id][camera_id] = None 153 | break 154 | 155 | face_landmarks = detect_face_landmarks(img, face_roi) 156 | if len(face_landmarks) == 0: 157 | logger.error("Empty iris landmarks (type 3)") 158 | landmarks[timestep_id][camera_id] = None 159 | break 160 | 161 | iris_rois = iris_roi_from_face_landmarks(face_landmarks, img_size) 162 | 163 | if len(iris_rois) != 2: 164 | logger.error("Empty iris landmarks (type 4)") 165 | landmarks[timestep_id][camera_id] = None 166 | break 167 | 168 | lmks = [] 169 | for iris_roi in iris_rois[::-1]: 170 | try: 171 | iris_landmarks = detect_iris_landmarks(img, iris_roi).iris[ 172 | 0:1 173 | ] 174 | except np.linalg.LinAlgError: 175 | logger.error("Failed to get iris landmarks") 176 | landmarks[timestep_id][camera_id] = None 177 | break 178 | 179 | for landmark in iris_landmarks: 180 | lmks.append([landmark.x * width, landmark.y * height, 1.0]) 181 | 182 | lmks = np.array(lmks, dtype=np.float32) 183 | 184 | h, w = img.shape[:2] 185 | lmks[:, 0] /= w 186 | lmks[:, 1] /= h 187 | 188 | landmarks[timestep_id][camera_id] = lmks 189 | 190 | return landmarks 191 | 192 | def iris_consistency(self, lm_iris, lm_eye): 193 | """ 194 | Checks if landmarks for eye and iris are consistent 195 | :param lm_iris: 196 | :param lm_eye: 197 | :return: 198 | """ 199 | lm_iris = lm_iris[:, :2] 200 | lm_eye = lm_eye[:, :2] 201 | 202 | polygon_eye = mpltPath.Path(lm_eye) 203 | valid = polygon_eye.contains_points(lm_iris) 204 | 205 | return valid[0] 206 | 207 | def annotate_landmarks(self, dataloader, add_iris=False): 208 | """ 209 | Annotates each frame with landmarks for face and iris. 
Assumes frames have been extracted 210 | :param add_iris: 211 | :return: 212 | """ 213 | lmks_face, bboxes_faces = self.detect_dataset(dataloader) 214 | 215 | if add_iris: 216 | lmks_iris = self.annotate_iris_landmarks(dataloader) 217 | 218 | # check conistency of iris landmarks and facial keypoints 219 | for camera_id, lmk_face_camera in lmks_face.items(): 220 | for timestep_id in lmk_face_camera.keys(): 221 | 222 | discard_iris_lmks = False 223 | bboxes_face_i = bboxes_faces[camera_id][timestep_id] 224 | if bboxes_face_i is not None: 225 | lmks_face_i = lmks_face[camera_id][timestep_id] 226 | lmks_iris_i = lmks_iris[camera_id][timestep_id] 227 | if lmks_iris_i is not None: 228 | 229 | # validate iris landmarks 230 | left_face = lmks_face_i[36:42] 231 | right_face = lmks_face_i[42:48] 232 | 233 | right_iris = lmks_iris_i[:1] 234 | left_iris = lmks_iris_i[1:] 235 | 236 | if not ( 237 | self.iris_consistency(left_iris, left_face) 238 | and self.iris_consistency(right_iris, right_face) 239 | ): 240 | logger.error( 241 | f"Inconsistent iris landmarks for timestep: {timestep_id}, camera: {camera_id}" 242 | ) 243 | discard_iris_lmks = True 244 | else: 245 | logger.error( 246 | f"No iris landmarks detected for timestep: {timestep_id}, camera: {camera_id}" 247 | ) 248 | discard_iris_lmks = True 249 | 250 | else: 251 | logger.error( 252 | f"Discarding iris landmarks because no face landmark is available for timestep: {timestep_id}, camera: {camera_id}" 253 | ) 254 | discard_iris_lmks = True 255 | 256 | if discard_iris_lmks: 257 | lmks_iris[timestep_id][camera_id] = ( 258 | np.zeros([2, 3]) - 1 259 | ) # set to -1 for inconsistent iris landmarks 260 | 261 | # construct final json 262 | for camera_id, lmk_face_camera in lmks_face.items(): 263 | bounding_box = [] 264 | face_landmark_2d = [] 265 | iris_landmark_2d = [] 266 | for timestep_id in lmk_face_camera.keys(): 267 | bounding_box.append(bboxes_faces[camera_id][timestep_id][None]) 268 | face_landmark_2d.append(lmks_face[camera_id][timestep_id][None]) 269 | 270 | if add_iris: 271 | iris_landmark_2d.append(lmks_iris[camera_id][timestep_id][None]) 272 | 273 | lmk_dict = { 274 | "bounding_box": bounding_box, 275 | "face_landmark_2d": face_landmark_2d, 276 | } 277 | if len(iris_landmark_2d) > 0: 278 | lmk_dict["iris_landmark_2d"] = iris_landmark_2d 279 | 280 | for k, v in lmk_dict.items(): 281 | if len(v) > 0: 282 | lmk_dict[k] = np.concatenate(v, axis=0) 283 | out_path = dataloader.dataset.get_property_path( 284 | "landmark2d/face-alignment", camera_id=camera_id 285 | ) 286 | logger.info(f"Saving landmarks to: {out_path}") 287 | if not out_path.parent.exists(): 288 | out_path.parent.mkdir(parents=True) 289 | np.savez(out_path, **lmk_dict) 290 | 291 | 292 | if __name__ == "__main__": 293 | import tyro 294 | from tqdm import tqdm 295 | from torch.utils.data import DataLoader 296 | from vhap.config.base import DataConfig, import_module 297 | 298 | cfg = tyro.cli(DataConfig) 299 | dataset = import_module(cfg._target)( 300 | cfg=cfg, 301 | img_to_tensor=False, 302 | batchify_all_views=True, 303 | ) 304 | dataset.items = dataset.items[:2] 305 | 306 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 307 | 308 | detector = LandmarkDetectorFA() 309 | detector.annotate_landmarks(dataloader) 310 | -------------------------------------------------------------------------------- /vhap/util/landmark_detector_star.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA 
and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | from tqdm import tqdm 11 | import copy 12 | import argparse 13 | import torch 14 | import math 15 | import cv2 16 | import numpy as np 17 | import dlib 18 | 19 | from star.lib import utility 20 | from star.asset import predictor_path, model_path 21 | 22 | from vhap.util.log import get_logger 23 | logger = get_logger(__name__) 24 | 25 | 26 | class GetCropMatrix(): 27 | """ 28 | from_shape -> transform_matrix 29 | """ 30 | 31 | def __init__(self, image_size, target_face_scale, align_corners=False): 32 | self.image_size = image_size 33 | self.target_face_scale = target_face_scale 34 | self.align_corners = align_corners 35 | 36 | def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center): 37 | cosv = math.cos(angle) 38 | sinv = math.sin(angle) 39 | 40 | fx, fy = from_center 41 | tx, ty = to_center 42 | 43 | acos = scale * cosv 44 | asin = scale * sinv 45 | 46 | a0 = acos 47 | a1 = -asin 48 | a2 = tx - acos * fx + asin * fy + shift_xy[0] 49 | 50 | b0 = asin 51 | b1 = acos 52 | b2 = ty - asin * fx - acos * fy + shift_xy[1] 53 | 54 | rot_scale_m = np.array([ 55 | [a0, a1, a2], 56 | [b0, b1, b2], 57 | [0.0, 0.0, 1.0] 58 | ], np.float32) 59 | return rot_scale_m 60 | 61 | def process(self, scale, center_w, center_h): 62 | if self.align_corners: 63 | to_w, to_h = self.image_size - 1, self.image_size - 1 64 | else: 65 | to_w, to_h = self.image_size, self.image_size 66 | 67 | rot_mu = 0 68 | scale_mu = self.image_size / (scale * self.target_face_scale * 200.0) 69 | shift_xy_mu = (0, 0) 70 | matrix = self._compose_rotate_and_scale( 71 | rot_mu, scale_mu, shift_xy_mu, 72 | from_center=[center_w, center_h], 73 | to_center=[to_w / 2.0, to_h / 2.0]) 74 | return matrix 75 | 76 | 77 | class TransformPerspective(): 78 | """ 79 | image, matrix3x3 -> transformed_image 80 | """ 81 | 82 | def __init__(self, image_size): 83 | self.image_size = image_size 84 | 85 | def process(self, image, matrix): 86 | return cv2.warpPerspective( 87 | image, matrix, dsize=(self.image_size, self.image_size), 88 | flags=cv2.INTER_LINEAR, borderValue=0) 89 | 90 | 91 | class TransformPoints2D(): 92 | """ 93 | points (nx2), matrix (3x3) -> points (nx2) 94 | """ 95 | 96 | def process(self, srcPoints, matrix): 97 | # nx3 98 | desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1) 99 | desPoints = desPoints @ np.transpose(matrix) # nx3 100 | desPoints = desPoints[:, :2] / desPoints[:, [2, 2]] 101 | return desPoints.astype(srcPoints.dtype) 102 | 103 | 104 | class Alignment: 105 | def __init__(self, args, model_path, dl_framework, device_ids): 106 | self.input_size = 256 107 | self.target_face_scale = 1.0 108 | self.dl_framework = dl_framework 109 | 110 | # model 111 | if self.dl_framework == "pytorch": 112 | # conf 113 | self.config = utility.get_config(args) 114 | self.config.device_id = device_ids[0] 115 | # set environment 116 | utility.set_environment(self.config) 117 | self.config.init_instance() 118 | if self.config.logger is not None: 119 | self.config.logger.info("Loaded configure file %s: %s" % (args.config_name, self.config.id)) 120 | self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in 
self.config.__dict__.items()])) 121 | 122 | net = utility.get_net(self.config) 123 | if device_ids == [-1]: 124 | checkpoint = torch.load(model_path, map_location="cpu") 125 | else: 126 | checkpoint = torch.load(model_path) 127 | net.load_state_dict(checkpoint["net"]) 128 | net = net.to(self.config.device_id) 129 | net.eval() 130 | self.alignment = net 131 | else: 132 | assert False 133 | 134 | self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale, 135 | align_corners=True) 136 | self.transformPerspective = TransformPerspective(image_size=self.input_size) 137 | self.transformPoints2D = TransformPoints2D() 138 | 139 | def norm_points(self, points, align_corners=False): 140 | if align_corners: 141 | # [0, SIZE-1] -> [-1, +1] 142 | return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1 143 | else: 144 | # [-0.5, SIZE-0.5] -> [-1, +1] 145 | return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1 146 | 147 | def denorm_points(self, points, align_corners=False): 148 | if align_corners: 149 | # [-1, +1] -> [0, SIZE-1] 150 | return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) 151 | else: 152 | # [-1, +1] -> [-0.5, SIZE-0.5] 153 | return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2 154 | 155 | def preprocess(self, image, scale, center_w, center_h): 156 | matrix = self.getCropMatrix.process(scale, center_w, center_h) 157 | input_tensor = self.transformPerspective.process(image, matrix) 158 | input_tensor = input_tensor[np.newaxis, :] 159 | 160 | input_tensor = torch.from_numpy(input_tensor) 161 | input_tensor = input_tensor.float().permute(0, 3, 1, 2) 162 | input_tensor = input_tensor / 255.0 * 2.0 - 1.0 163 | input_tensor = input_tensor.to(self.config.device_id) 164 | return input_tensor, matrix 165 | 166 | def postprocess(self, srcPoints, coeff): 167 | # dstPoints = self.transformPoints2D.process(srcPoints, coeff) 168 | # matrix^(-1) * src = dst 169 | # src = matrix * dst 170 | dstPoints = np.zeros(srcPoints.shape, dtype=np.float32) 171 | for i in range(srcPoints.shape[0]): 172 | dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2] 173 | dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2] 174 | return dstPoints 175 | 176 | def analyze(self, image, scale, center_w, center_h): 177 | input_tensor, matrix = self.preprocess(image, scale, center_w, center_h) 178 | 179 | if self.dl_framework == "pytorch": 180 | with torch.no_grad(): 181 | output = self.alignment(input_tensor) 182 | landmarks = output[-1][0] 183 | else: 184 | assert False 185 | 186 | landmarks = self.denorm_points(landmarks) 187 | landmarks = landmarks.data.cpu().numpy()[0] 188 | landmarks = self.postprocess(landmarks, np.linalg.inv(matrix)) 189 | 190 | return landmarks 191 | 192 | 193 | def draw_pts(img, pts, mode="pts", shift=4, color=(0, 255, 0), radius=1, thickness=1, save_path=None, dif=0, 194 | scale=0.3, concat=False, ): 195 | img_draw = copy.deepcopy(img) 196 | for cnt, p in enumerate(pts): 197 | if mode == "index": 198 | cv2.putText(img_draw, str(cnt), (int(float(p[0] + dif)), int(float(p[1] + dif))), cv2.FONT_HERSHEY_SIMPLEX, 199 | scale, color, thickness) 200 | elif mode == 'pts': 201 | if len(img_draw.shape) > 2: 202 | # convert back and forth here to work around an OpenCV bug 203 | img_draw = cv2.cvtColor(img_draw,
cv2.COLOR_BGR2RGB) 204 | img_draw = cv2.cvtColor(img_draw, cv2.COLOR_RGB2BGR) 205 | cv2.circle(img_draw, (int(p[0] * (1 << shift)), int(p[1] * (1 << shift))), radius << shift, color, -1, 206 | cv2.LINE_AA, shift=shift) 207 | else: 208 | raise NotImplementedError 209 | if concat: 210 | img_draw = np.concatenate((img, img_draw), axis=1) 211 | if save_path is not None: 212 | cv2.imwrite(save_path, img_draw) 213 | return img_draw 214 | 215 | 216 | class LandmarkDetectorSTAR: 217 | def __init__( 218 | self, 219 | ): 220 | self.detector = dlib.get_frontal_face_detector() 221 | self.shape_predictor = dlib.shape_predictor(predictor_path) 222 | 223 | # facial landmark detector 224 | args = argparse.Namespace() 225 | args.config_name = 'alignment' 226 | # could be downloaded here: https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view 227 | # model_path = '/path/to/WFLW_STARLoss_NME_4_02_FR_2_32_AUC_0_605.pkl' 228 | device_ids = '0' 229 | device_ids = list(map(int, device_ids.split(","))) 230 | self.alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids) 231 | 232 | def detect_single_image(self, img): 233 | bbox = self.detector(img, 1) 234 | 235 | if len(bbox) == 0: 236 | bbox = np.zeros(5) - 1 237 | lmks = np.zeros([68, 3]) - 1 # set to -1 when landmarks is inavailable 238 | else: 239 | face = self.shape_predictor(img, bbox[0]) 240 | shape = [] 241 | for i in range(68): 242 | x = face.part(i).x 243 | y = face.part(i).y 244 | shape.append((x, y)) 245 | shape = np.array(shape) 246 | x1, x2 = shape[:, 0].min(), shape[:, 0].max() 247 | y1, y2 = shape[:, 1].min(), shape[:, 1].max() 248 | scale = min(x2 - x1, y2 - y1) / 200 * 1.05 249 | center_w = (x2 + x1) / 2 250 | center_h = (y2 + y1) / 2 251 | 252 | scale, center_w, center_h = float(scale), float(center_w), float(center_h) 253 | lmks = self.alignment.analyze(img, scale, center_w, center_h) 254 | 255 | h, w = img.shape[:2] 256 | 257 | lmks = np.concatenate([lmks, np.ones([lmks.shape[0], 1])], axis=1).astype(np.float32) # (x, y, 1) 258 | lmks[:, 0] /= w 259 | lmks[:, 1] /= h 260 | 261 | bbox = np.array([bbox[0].left(), bbox[0].top(), bbox[0].right(), bbox[0].bottom(), 1.]).astype(np.float32) # (x1, y1, x2, y2, score) 262 | bbox[[0, 2]] /= w 263 | bbox[[1, 3]] /= h 264 | 265 | return bbox, lmks 266 | 267 | def detect_dataset(self, dataloader): 268 | """ 269 | Annotates each frame with 68 facial landmarks 270 | :return: dict mapping frame number to landmarks numpy array and the same thing for bboxes 271 | """ 272 | logger.info("Initialize Landmark Detector (STAR)...") 273 | # 68 facial landmark detector 274 | 275 | landmarks = {} 276 | bboxes = {} 277 | 278 | logger.info("Begin annotating landmarks...") 279 | for item in tqdm(dataloader): 280 | timestep_id = item["timestep_id"][0] 281 | camera_id = item["camera_id"][0] 282 | 283 | logger.info( 284 | f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}" 285 | ) 286 | img = item["rgb"][0].numpy() 287 | 288 | bbox, lmks = self.detect_single_image(img) 289 | if len(bbox) == 0: 290 | logger.error( 291 | f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1." 
292 | ) 293 | 294 | if camera_id not in landmarks: 295 | landmarks[camera_id] = {} 296 | if camera_id not in bboxes: 297 | bboxes[camera_id] = {} 298 | landmarks[camera_id][timestep_id] = lmks 299 | bboxes[camera_id][timestep_id] = bbox 300 | return landmarks, bboxes 301 | 302 | def annotate_landmarks(self, dataloader): 303 | """ 304 | Annotates each frame with landmarks for face and iris. Assumes frames have been extracted 305 | :return: 306 | """ 307 | lmks_face, bboxes_faces = self.detect_dataset(dataloader) 308 | 309 | # construct final json 310 | for camera_id, lmk_face_camera in lmks_face.items(): 311 | bounding_box = [] 312 | face_landmark_2d = [] 313 | for timestep_id in lmk_face_camera.keys(): 314 | bounding_box.append(bboxes_faces[camera_id][timestep_id][None]) 315 | face_landmark_2d.append(lmks_face[camera_id][timestep_id][None]) 316 | 317 | lmk_dict = { 318 | "bounding_box": bounding_box, 319 | "face_landmark_2d": face_landmark_2d, 320 | } 321 | 322 | for k, v in lmk_dict.items(): 323 | if len(v) > 0: 324 | lmk_dict[k] = np.concatenate(v, axis=0) 325 | out_path = dataloader.dataset.get_property_path( 326 | "landmark2d/STAR", camera_id=camera_id 327 | ) 328 | logger.info(f"Saving landmarks to: {out_path}") 329 | if not out_path.parent.exists(): 330 | out_path.parent.mkdir(parents=True) 331 | np.savez(out_path, **lmk_dict) 332 | 333 | 334 | if __name__ == "__main__": 335 | import tyro 336 | from tqdm import tqdm 337 | from torch.utils.data import DataLoader 338 | from vhap.config.base import DataConfig, import_module 339 | 340 | cfg = tyro.cli(DataConfig) 341 | dataset = import_module(cfg._target)( 342 | cfg=cfg, 343 | img_to_tensor=False, 344 | batchify_all_views=True, 345 | ) 346 | dataset.items = dataset.items[:2] 347 | 348 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 349 | 350 | detector = LandmarkDetectorSTAR() 351 | detector.annotate_landmarks(dataloader) 352 | -------------------------------------------------------------------------------- /vhap/util/log.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | import logging 11 | import sys 12 | from datetime import datetime 13 | import atexit 14 | from pathlib import Path 15 | 16 | 17 | def _colored(msg, color): 18 | colors = {'red': '\033[91m', 'green': '\033[92m', 'yellow': '\033[93m', 'normal': '\033[0m'} 19 | return colors[color] + msg + colors["normal"] 20 | 21 | 22 | class ColorFormatter(logging.Formatter): 23 | """ 24 | Class to make command line log entries more appealing 25 | Inspired by https://github.com/facebookresearch/detectron2 26 | """ 27 | 28 | def formatMessage(self, record): 29 | """ 30 | Print warnings yellow and errors red 31 | :param record: 32 | :return: 33 | """ 34 | log = super().formatMessage(record) 35 | if record.levelno == logging.WARNING: 36 | prefix = _colored("WARNING", "yellow") 37 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 38 | prefix = _colored("ERROR", "red") 39 | else: 40 | return log 41 | return prefix + " " + log 42 | 43 | 44 | def get_logger(name, level=logging.DEBUG, root=False, log_dir=None): 45 | """ 46 | Replaces the standard library logging.getLogger call in order to make some configuration 47 | for all loggers. 48 | :param name: pass the __name__ variable 49 | :param level: the desired log level 50 | :param root: call only once in the program 51 | :param log_dir: if root is set to True, this defines the directory where a log file is going 52 | to be created that contains all logging output 53 | :return: the logger object 54 | """ 55 | logger = logging.getLogger(name) 56 | logger.setLevel(level) 57 | 58 | if root: 59 | # create handler for console 60 | console_handler = logging.StreamHandler(sys.stdout) 61 | console_handler.setLevel(level) 62 | formatter = ColorFormatter(_colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S") 64 | console_handler.setFormatter(formatter) 65 | logger.addHandler(console_handler) 66 | logger.propagate = False # otherwise root logger prints things again 67 | 68 | if log_dir is not None: 69 | # add handler to log to a file 70 | log_dir = Path(log_dir) 71 | if not log_dir.exists(): 72 | logger.info(f"Logging directory {log_dir} does not exist and will be created") 73 | log_dir.mkdir(parents=True) 74 | timestamp = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") 75 | log_file = log_dir / f"{timestamp}.log" 76 | 77 | # open stream and make sure it will be closed 78 | stream = log_file.open(mode="w") 79 | atexit.register(stream.close) 80 | 81 | formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s", 82 | datefmt="%m/%d %H:%M:%S") 83 | file_handler = logging.StreamHandler(stream) 84 | file_handler.setLevel(level) 85 | file_handler.setFormatter(formatter) 86 | logger.addHandler(file_handler) 87 | 88 | return logger 89 | -------------------------------------------------------------------------------- /vhap/util/mesh.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | import torch 11 | 12 | 13 | def get_mtl_content(tex_fname): 14 | return f'newmtl Material\nmap_Kd {tex_fname}\n' 15 | 16 | def get_obj_content(vertices, faces, uv_coordinates=None, uv_indices=None, mtl_fname=None): 17 | obj = ('# Generated with multi-view-head-tracker\n') 18 | 19 | if mtl_fname is not None: 20 | obj += f'mtllib {mtl_fname}\n' 21 | obj += 'usemtl Material\n' 22 | 23 | # Write the vertices 24 | for vertex in vertices: 25 | obj += f"v {vertex[0]} {vertex[1]} {vertex[2]}\n" 26 | 27 | # Write the UV coordinates 28 | if uv_coordinates is not None: 29 | for uv in uv_coordinates: 30 | obj += f"vt {uv[0]} {uv[1]}\n" 31 | 32 | # Write the faces with UV indices 33 | if uv_indices is not None: 34 | for face, uv_indices in zip(faces, uv_indices): 35 | obj += f"f {face[0]+1}/{uv_indices[0]+1} {face[1]+1}/{uv_indices[1]+1} {face[2]+1}/{uv_indices[2]+1}\n" 36 | else: 37 | for face in faces: 38 | obj += f"f {face[0]+1} {face[1]+1} {face[2]+1}\n" 39 | return obj 40 | 41 | def normalize_image_points(u, v, resolution): 42 | """ 43 | normalizes u, v coordinates from [0 ,image_size] to [-1, 1] 44 | :param u: 45 | :param v: 46 | :param resolution: 47 | :return: 48 | """ 49 | u = 2 * (u - resolution[1] / 2.0) / resolution[1] 50 | v = 2 * (v - resolution[0] / 2.0) / resolution[0] 51 | return u, v 52 | 53 | 54 | def face_vertices(vertices, faces): 55 | """ 56 | :param vertices: [batch size, number of vertices, 3] 57 | :param faces: [batch size, number of faces, 3] 58 | :return: [batch size, number of faces, 3, 3] 59 | """ 60 | assert vertices.ndimension() == 3 61 | assert faces.ndimension() == 3 62 | assert vertices.shape[0] == faces.shape[0] 63 | assert vertices.shape[2] == 3 64 | assert faces.shape[2] == 3 65 | 66 | bs, nv = vertices.shape[:2] 67 | bs, nf = faces.shape[:2] 68 | device = vertices.device 69 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None] 70 | vertices = vertices.reshape((bs * nv, 3)) 71 | # pytorch only supports long and byte tensors for indexing 72 | return vertices[faces.long()] 73 | 74 | -------------------------------------------------------------------------------- /vhap/util/render_uvmap.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 
7 | # 8 | 9 | 10 | import tyro 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import torch 14 | import nvdiffrast.torch as dr 15 | 16 | from vhap.model.flame import FlameHead 17 | 18 | 19 | FLAME_TEX_PATH = "asset/flame/FLAME_texture.npz" 20 | 21 | 22 | def transform_vt(vt): 23 | """Transform uv vertices to clip space""" 24 | xy = vt * 2 - 1 25 | w = torch.ones([1, vt.shape[-2], 1]).to(vt) 26 | z = -w # In the clip spcae of OpenGL, the camera looks at -z 27 | xyzw = torch.cat([xy[None, :, :], z, w], axis=-1) 28 | return xyzw 29 | 30 | def render_uvmap_vtex(glctx, pos, pos_idx, v_color, col_idx, resolution): 31 | """Render uv map with vertex color""" 32 | pos_clip = transform_vt(pos) 33 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution) 34 | 35 | color, _ = dr.interpolate(v_color, rast_out, col_idx) 36 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 37 | return color 38 | 39 | def render_uvmap_texmap(glctx, pos, pos_idx, verts_uv, faces_uv, tex, resolution, enable_mip=True, max_mip_level=None): 40 | """Render uv map with texture map""" 41 | pos_clip = transform_vt(pos) 42 | rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, resolution) 43 | 44 | if enable_mip: 45 | texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all') 46 | color = dr.texture(tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level) 47 | else: 48 | texc, _ = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv) 49 | color = dr.texture(tex[None, ...], texc, filter_mode='linear') 50 | color = dr.antialias(color, rast_out, pos_clip, pos_idx) 51 | return color 52 | 53 | 54 | def main( 55 | use_texmap: bool = False, 56 | use_opengl: bool = False, 57 | ): 58 | n_shape = 300 59 | n_expr = 100 60 | print("Initialization FLAME model") 61 | flame_model = FlameHead(n_shape, n_expr) 62 | 63 | verts_uv = flame_model.verts_uvs.cuda() 64 | verts_uv[:, 1] = 1 - verts_uv[:, 1] 65 | faces_uv = flame_model.textures_idx.int().cuda() 66 | 67 | # Rasterizer context 68 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext() 69 | 70 | h, w = 512, 512 71 | resolution = (h, w) 72 | 73 | if use_texmap: 74 | tex = torch.from_numpy(np.load(FLAME_TEX_PATH)['mean']).cuda().float().flip(dims=[-1]) / 255 75 | rgb = render_uvmap_texmap(glctx, verts_uv, faces_uv, verts_uv, faces_uv, tex, resolution, enable_mip=True) 76 | else: 77 | v_color = torch.ones(verts_uv.shape[0], 3).to(verts_uv) 78 | col_idx = faces_uv 79 | rgb = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution) 80 | 81 | plt.imshow(rgb[0, :, :].cpu()) 82 | plt.show() 83 | 84 | 85 | if __name__ == "__main__": 86 | tyro.cli(main) 87 | -------------------------------------------------------------------------------- /vhap/util/vector_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 5 | return torch.sum(x*y, -1, keepdim=True) 6 | 7 | def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: 8 | return 2*dot(x, n)*n - x 9 | 10 | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: 11 | return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN 12 | 13 | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: 14 | return x / length(x, eps) 15 | 16 | def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor: 17 
| return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w) 18 | -------------------------------------------------------------------------------- /vhap/util/visualization.py: -------------------------------------------------------------------------------- 1 | # 2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 3 | # property and proprietary rights in and to this software and related documentation. 4 | # Any commercial use, reproduction, disclosure or distribution of this software and 5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA 6 | # is strictly prohibited. 7 | # 8 | 9 | 10 | import matplotlib.pyplot as plt 11 | import torch 12 | from torchvision.utils import draw_bounding_boxes, draw_keypoints 13 | 14 | 15 | connectivity_face = ( 16 | [(i, i + 1) for i in list(range(0, 16))] 17 | + [(i, i + 1) for i in list(range(17, 21))] 18 | + [(i, i + 1) for i in list(range(22, 26))] 19 | + [(i, i + 1) for i in list(range(27, 30))] 20 | + [(i, i + 1) for i in list(range(31, 35))] 21 | + [(i, i + 1) for i in list(range(36, 41))] 22 | + [(36, 41)] 23 | + [(i, i + 1) for i in list(range(42, 47))] 24 | + [(42, 47)] 25 | + [(i, i + 1) for i in list(range(48, 59))] 26 | + [(48, 59)] 27 | + [(i, i + 1) for i in list(range(60, 67))] 28 | + [(60, 67)] 29 | ) 30 | 31 | 32 | def plot_landmarks_2d( 33 | img: torch.tensor, 34 | lmks: torch.tensor, 35 | connectivity=None, 36 | colors="white", 37 | unit=1, 38 | input_float=False, 39 | ): 40 | if input_float: 41 | img = (img * 255).byte() 42 | 43 | img = draw_keypoints( 44 | img, 45 | lmks, 46 | connectivity=connectivity, 47 | colors=colors, 48 | radius=2 * unit, 49 | width=2 * unit, 50 | ) 51 | 52 | if input_float: 53 | img = img.float() / 255 54 | return img 55 | 56 | 57 | def blend(a, b, w): 58 | return (a * w + b * (1 - w)).byte() 59 | 60 | 61 | if __name__ == "__main__": 62 | from argparse import ArgumentParser 63 | from torch.utils.data import DataLoader 64 | from matplotlib import pyplot as plt 65 | 66 | from vhap.data.nersemble_dataset import NeRSembleDataset 67 | 68 | parser = ArgumentParser() 69 | parser.add_argument("--root_folder", type=str, required=True) 70 | parser.add_argument("--subject", type=str, required=True) 71 | parser.add_argument("--sequence", type=str, required=True) 72 | parser.add_argument("--division", default=None) 73 | parser.add_argument("--subset", default=None) 74 | parser.add_argument("--scale_factor", type=float, default=1.0) 75 | parser.add_argument("--blend_weight", type=float, default=0.6) 76 | args = parser.parse_args() 77 | 78 | dataset = NeRSembleDataset( 79 | root_folder=args.root_folder, 80 | subject=args.subject, 81 | sequence=args.sequence, 82 | division=args.division, 83 | subset=args.subset, 84 | n_downsample_rgb=2, 85 | scale_factor=args.scale_factor, 86 | use_landmark=True, 87 | ) 88 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) 89 | 90 | for item in dataloader: 91 | unit = int(item["scale_factor"][0] * 3) + 1 92 | 93 | rgb = item["rgb"][0].permute(2, 0, 1) 94 | vis = rgb 95 | 96 | if "bbox_2d" in item: 97 | bbox = item["bbox_2d"][0][:4] 98 | tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit) 99 | vis = blend(tmp, vis, args.blend_weight) 100 | 101 | if "lmk2d" in item: 102 | face_landmark = item["lmk2d"][0][:, :2] 103 | tmp = plot_landmarks_2d( 104 | vis, 105 | face_landmark[None, ...], 106 | connectivity=connectivity_face, 107 | colors="white", 108 | unit=unit, 109 | ) 110 | 
vis = blend(tmp, vis, args.blend_weight) 111 | 112 | if "lmk2d_iris" in item: 113 | iris_landmark = item["lmk2d_iris"][0][:, :2] 114 | tmp = plot_landmarks_2d( 115 | vis, 116 | iris_landmark[None, ...], 117 | colors="blue", 118 | unit=unit, 119 | ) 120 | vis = blend(tmp, vis, args.blend_weight) 121 | 122 | vis = vis.permute(1, 2, 0).numpy() 123 | plt.imshow(vis) 124 | plt.draw() 125 | while not plt.waitforbuttonpress(timeout=-1): 126 | pass 127 | --------------------------------------------------------------------------------
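
Editor's note: the landmark detectors above save one .npz file per camera with the keys bounding_box, face_landmark_2d and, for the face-alignment variant, iris_landmark_2d, where x/y coordinates are normalized to [0, 1] by the image width and height. The sketch below shows how such a file could be loaded and overlaid on a frame with plot_landmarks_2d from vhap/util/visualization.py. It is a minimal, hypothetical example, not part of the repository: the .npz and image paths are placeholders, and only the array keys and the [0, 1] normalization follow from the code above.

# Minimal inspection sketch (not part of the repository). The paths below are
# hypothetical placeholders; the array keys and the [0, 1]-normalized x/y
# coordinates follow from LandmarkDetectorFA / LandmarkDetectorSTAR above.
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt

from vhap.util.visualization import connectivity_face, plot_landmarks_2d

lmk_file = "output/landmark2d/STAR/cam0.npz"    # hypothetical path
img_file = "output/images/cam0/frame_0000.jpg"  # hypothetical path

lmks = np.load(lmk_file)
img = torchvision.io.read_image(img_file)       # uint8 tensor of shape (3, H, W)
H, W = img.shape[-2:]

# face_landmark_2d has shape (num_timesteps, 68, 3); x and y are in [0, 1],
# so scale them back to pixel coordinates before drawing the first frame.
xy = torch.from_numpy(lmks["face_landmark_2d"][0, :, :2]) * torch.tensor([W, H])

vis = plot_landmarks_2d(img, xy[None], connectivity=connectivity_face, colors="white")
plt.imshow(vis.permute(1, 2, 0).numpy())
plt.show()

The landmark file location shown here is only illustrative; in practice it is determined by the dataset's get_property_path("landmark2d/STAR", camera_id=...) (or "landmark2d/face-alignment"), which depends on the dataset configuration.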