├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── asset
│   ├── flame
│   │   ├── head_template_mesh.obj
│   │   ├── landmark_embedding_with_eyes.npy
│   │   ├── tex_mean_painted.png
│   │   └── uv_masks.npz
│   ├── flame_editor.png
│   ├── flame_viewer.png
│   ├── monocular.jpg
│   ├── monocular_obama.gif
│   ├── monocular_obama.mp4
│   ├── monocular_person_0004.gif
│   ├── monocular_person_0004.mp4
│   ├── nersemble.jpg
│   ├── nersemble_038_EMO-1.gif
│   ├── nersemble_038_EMO-1.mp4
│   ├── nersemble_074_EMO-1.gif
│   ├── nersemble_074_EMO-1.mp4
│   └── teaser.gif
├── doc
│   ├── monocular.md
│   ├── nersemble.md
│   └── nersemble_v2.md
├── jobs
│   ├── combine_nersemble.sh
│   ├── run_monocular.sh
│   └── run_nersemble.sh
├── pyproject.toml
└── vhap
    ├── combine_nerf_datasets.py
    ├── config
    │   ├── base.py
    │   ├── nersemble.py
    │   └── nersemble_v2.py
    ├── data
    │   ├── image_folder_dataset.py
    │   ├── nerf_dataset.py
    │   ├── nersemble_dataset.py
    │   ├── nersemble_v2_dataset.py
    │   └── video_dataset.py
    ├── export_as_nerf_dataset.py
    ├── flame_editor.py
    ├── flame_viewer.py
    ├── generate_flame_uvmask.py
    ├── model
    │   ├── flame.py
    │   ├── lbs.py
    │   └── tracker.py
    ├── preprocess_video.py
    ├── track.py
    ├── track_nersemble.py
    ├── track_nersemble_v2.py
    └── util
        ├── camera.py
        ├── color_correction.py
        ├── landmark_detector_fa.py
        ├── landmark_detector_star.py
        ├── log.py
        ├── mesh.py
        ├── render_nvdiffrast.py
        ├── render_uvmap.py
        ├── vector_ops.py
        └── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .pyenv
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 | bin/
117 | include/
118 | share/
119 | pip-selfcheck.json
120 | lib64
121 | pyvenv.cfg
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
141 | # PyTest
142 | .pytest_cache
143 |
144 | # project specific files
145 | deps/InsightFace-PyTorch
146 | **/.idea/*
147 | tags
148 | .vscode/
149 |
150 | .DS_Store
151 | .pkl
152 | tmp
153 |
154 | output/
155 | export/
156 | /data/
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Qian"
5 | given-names: "Shenhan"
6 | orcid: "https://orcid.org/0000-0003-0416-7548"
7 | title: "VHAP: Versatile Head Alignment with Adaptive Appearance Priors"
8 | version: 0.0.3
9 | doi: 10.5281/zenodo.14988309
10 | date-released: 2024-09-05
11 | url: "https://github.com/ShenhanQian/VHAP"
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VHAP: Versatile Head Alignment with Adaptive Appearance Priors
2 |
3 |
4 |
5 |
6 |
7 | ## TL;DR
8 |
9 | - A photometric optimization pipeline based on differentiable mesh rasterization, applied to human head alignment.
10 | - A perturbation mechanism that implicitly extracts and injects regional appearance priors adaptively during rendering, enabling alignment of regions purely based on their appearance consistency, such as the hair, ears, neck, and shoulders, where no pre-defined landmarks are available.
11 | - The exported tracking results can be directly used to create your own [GaussianAvatars](https://github.com/ShenhanQian/GaussianAvatars).
12 |
13 | ## License
14 |
15 | This work is made available under [CC-BY-NC-SA-4.0](./LICENSE). The repository is derived from the [multi-view head tracker of GaussianAvatars](https://github.com/ShenhanQian/GaussianAvatars/tree/main/reference_tracker), which is subject to the following statement:
16 |
17 | > Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual property and proprietary rights in and to this software and related documentation. Any commercial use, reproduction, disclosure or distribution of this software and related documentation without an express license agreement from Toyota Motor Europe NV/SA is strictly prohibited.
18 |
19 | On top of the original repository, we add support for monocular videos and provide a complete set of scripts from video preprocessing to result export for NeRF/3DGS-style applications.
20 |
21 | ## Setup
22 |
23 | ```shell
24 | git clone git@github.com:ShenhanQian/VHAP.git
25 | cd VHAP
26 |
27 | conda create --name VHAP -y python=3.10
28 | conda activate VHAP
29 |
30 | # Install CUDA and ninja for compilation
31 | conda install -c "nvidia/label/cuda-12.1.1" cuda-toolkit ninja cmake # use the right CUDA version
32 | ln -s "$CONDA_PREFIX/lib" "$CONDA_PREFIX/lib64" # to avoid error "/usr/bin/ld: cannot find -lcudart"
33 | conda env config vars set CUDA_HOME=$CONDA_PREFIX # for compilation
34 |
35 | # Install PyTorch (make sure that the CUDA version matches with "Step 1")
36 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
37 | # or
38 | conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
39 | # make sure torch.cuda.is_available() returns True
40 |
41 | pip install -e .
42 | ```
43 |
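To quickly confirm that the CUDA-enabled PyTorch build was picked up (as noted in the comment above), you can run:

```shell
# Should print "True"; if it prints "False", the installed torch build does not match your CUDA setup.
python -c "import torch; print(torch.cuda.is_available())"
```
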
44 | > [!NOTE]
45 | > - We use an adjusted version of [nvdiffrast](https://github.com/ShenhanQian/nvdiffrast/tree/backface-culling) for backface-culling. If you have other versions installed before, you can reinstall as follows:
46 | > ```shell
47 | > pip install nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling --force-reinstall
48 | > rm -r ~/.cache/torch_extensions/*/nvdiffrast*
49 | > ```
50 | > - We use [STAR](https://github.com/ShenhanQian/STAR/) for landmark detection by default. Alternatively, [face-alignment](https://github.com/1adrianb/face-alignment) is faster but less accurate (see the example below).
51 |
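If you want to try the faster face-alignment detector mentioned above, the `landmark_source` option in `vhap/config/base.py` suggests it can be selected per run. A hypothetical invocation, assuming the same `--data.*` flag style as the other options in this README (sequence and output folder are placeholders):

```shell
# Illustrative only: switch the landmark detector from STAR (default) to face-alignment.
python vhap/track.py --data.root_folder "data/monocular" \
    --exp.output_folder output/monocular/obama_faceAlignment \
    --data.sequence obama \
    --data.landmark_source face-alignment
```
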
52 | ## Download
53 |
54 | ### FLAME
55 |
56 | Our code relies on FLAME. Please download assets from the [official website](https://flame.is.tue.mpg.de/download.php) and store them in the paths below:
57 |
58 | - FLAME 2023 (the version w/ jaw rotation) -> `asset/flame/flame2023.pkl`
59 | - FLAME Vertex Masks -> `asset/flame/FLAME_masks.pkl`
60 |
61 | > [!NOTE]
62 | > It is possible to use FLAME 2020 by downloading it to `asset/flame/generic_model.pkl`. The `FLAME_MODEL_PATH` in `flame.py` needs to be updated accordingly. A quick check for the required files is shown below.
63 |
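A minimal, optional sanity check that the downloaded FLAME assets ended up in the expected locations:

```shell
# Verify the manually downloaded FLAME files exist before starting a run.
for f in asset/flame/flame2023.pkl asset/flame/FLAME_masks.pkl; do
    [ -f "$f" ] && echo "found: $f" || echo "MISSING: $f"
done
```
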
64 | ### Video Data
65 |
66 | #### Multiview
67 |
68 | To get access to [NeRSemble](https://tobias-kirschstein.github.io/nersemble/) dataset, please request via the [Google Form](https://forms.gle/rYRoGNh2ed51TDWX9). The directory structure is expected to be like [this](https://github.com/ShenhanQian/VHAP/blob/c9ea660c6c6719110eca5ffdaf9029a2596cc5ca/vhap/data/nersemble_dataset.py#L32-L54).
69 |
70 | > [!NOTE]
71 | > The NeRSemble dataset has been updated to Version 2. Its folder structure and color correction algorithm differ from those in Version 1, so please be careful not to confuse the two.
72 |
73 | #### Monocular
74 |
75 | We use monocular video sequences following [INSTA](https://zielon.github.io/insta/). You can download raw videos from [LRZ](https://syncandshare.lrz.de/getlink/fiJE46wKrG6oTVZ16CUmMr/VHAP).
76 |
77 | ## Usage
78 |
79 | ### Monocular
80 | [For Monocular Videos](doc/monocular.md)
81 |
82 |
83 |
84 |
85 |
86 | ### Multiview
87 | [For NeRSemble Dataset](doc/nersemble.md)
88 |
89 | [For NeRSemble Dataset V2](doc/nersemble_v2.md)
90 |
91 |
92 |
93 |
94 |
95 | ## Discussions
96 |
97 | Photometric alignment is versatile but sometimes sensitive.
98 |
99 | **Texture map regularization**: Our method relies on a total-variation regularization on the texture map. Its loss weight is by default `1e4` for a monocular video and `1e5` for the NeRSemble dataset (16 views). For your own multi-view dataset with fewer views, you should lower the regularization by passing `--w.reg_tex_tv 1e4` or `3e4`. Otherwise, you may encounter corrupted shapes and blurry textures similar to https://github.com/ShenhanQian/VHAP/issues/10#issue-2558743737 and https://github.com/ShenhanQian/VHAP/issues/6#issue-2524833245.
100 |
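For example, if you have adapted one of the tracking entry points (e.g. `vhap/track_nersemble.py`) to a NeRSemble-style layout of your own capture, a hypothetical run with a lowered weight could look like this (subject, sequence, and folders are placeholders):

```shell
# Illustrative only: weaken the texture total-variation regularization for a capture with few views.
python vhap/track_nersemble.py --data.root_folder "data/my_multiview" \
    --exp.output_folder output/my_multiview/my_subject_my_sequence \
    --data.subject my_subject --data.sequence my_sequence \
    --w.reg_tex_tv 3e4
```
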
101 | **Color affinity:** If the color of a point on the foreground contour is too close to the background, the [`static_offset`](https://github.com/ShenhanQian/VHAP/blob/64c18060e7aad104bf05a2c06aab7818f54af6bd/vhap/model/flame.py#L583) can go wild. You may try a different background color with `--data.background_color white` or `--data.background_color black`. You can also disable `static_offset` with `--model.no_use_static_offset`.
102 |
103 | **Occlusion:** When the neck is occluded by collars, the photometric gradients may squeeze and stretch the neck into unnatural shapes. Usually, this problem can be relieved by disabling photometric alignment in certain regions. We hard-coded the occlusion status for some subjects in the NeRSemble dataset with the [`occluded_table`](https://github.com/ShenhanQian/VHAP/blob/51a2792bd3ad3f920d9cd8f1b107a56b92349520/vhap/config/nersemble.py#L71). You can extend the table or temporarily override it with, e.g., `--model.occluded neck_lower boundary`.
104 |
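Putting the knobs from the two paragraphs above together, a hypothetical invocation could look like this (the flags are the ones documented above; the sequence and folders are placeholders):

```shell
# Illustrative only: switch to a black background and disable photometric alignment
# for regions occluded by a collar.
python vhap/track.py --data.root_folder "data/monocular" \
    --exp.output_folder output/monocular/my_subject_blackBg \
    --data.sequence my_subject \
    --data.background_color black \
    --model.occluded neck_lower boundary
```
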
105 | **Limited degree of freedom:** Another limitation comes from the FLAME model. FLAME is great since it covers the whole head and neck. However, there is only one joint for the neck, between the neck and the head. This means the lower part of the neck cannot move relative to the torso, which limits the model's ability to capture large head movements. For example, it is very hard to achieve good alignment of the lower neck and the head at the same time for the *EXP-1-head* sequence in the NeRSemble dataset because of this lack of degrees of freedom.
106 |
107 | **You are welcome to report more failure cases and help us improve the tracker.**
108 |
109 | ## Interactive Viewers
110 |
111 | Our method relies on vertex masks defined on FLAME. We add custom masks to enrich the original ones. You can play with `regions` in our FLAME Editor to see what each mask looks like.
112 |
113 | ```shell
114 | python vhap/flame_editor.py
115 | ```
116 |
117 | We also provide a FLAME viewer for you to interact with a tracked sequence.
118 |
119 | ```shell
120 | python vhap/flame_viewer.py \
121 | --param_path output/nersemble/074_EMO-1_v16_DS4_wBg_staticOffset/2024-09-09_15-49-02/tracked_flame_params_30.npz \
122 | ```
123 |
124 | Optionally, you can enable colored rendering by specifying a texture image with `--tex_path`.
125 |
126 | For both viewers, you can switch to flat shading with `--no-shade-smooth`.
127 |
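For instance, combining the two options above with the viewer command from this section (the texture path is only an illustration; any texture image, e.g. the painted mean texture shipped in `asset/flame`, should work):

```shell
python vhap/flame_viewer.py \
    --param_path output/nersemble/074_EMO-1_v16_DS4_wBg_staticOffset/2024-09-09_15-49-02/tracked_flame_params_30.npz \
    --tex_path asset/flame/tex_mean_painted.png \
    --no-shade-smooth
```
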
128 |
129 |
130 |
131 |
132 |
133 | ## Cite
134 |
135 | Please kindly cite our repository and preceding paper if you find our software or algorithm useful for your research.
136 |
137 | ```bibtex
138 | @misc{qian2024vhap,
139 | title={VHAP: Versatile Head Alignment with Adaptive Appearance Priors},
140 | author={Qian, Shenhan},
141 | year={2024},
142 | month={sep},
143 | doi={10.5281/zenodo.14988309},
144 | url={https://github.com/ShenhanQian/VHAP}
145 | }
146 | ```
147 |
148 | ```bibtex
149 | @inproceedings{qian2024gaussianavatars,
150 | title={Gaussianavatars: Photorealistic head avatars with rigged 3d gaussians},
151 | author={Qian, Shenhan and Kirschstein, Tobias and Schoneveld, Liam and Davoli, Davide and Giebenhain, Simon and Nie{\ss}ner, Matthias},
152 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
153 | pages={20299--20309},
154 | year={2024}
155 | }
156 | ```
157 |
--------------------------------------------------------------------------------
/asset/flame/landmark_embedding_with_eyes.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame/landmark_embedding_with_eyes.npy
--------------------------------------------------------------------------------
/asset/flame/tex_mean_painted.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame/tex_mean_painted.png
--------------------------------------------------------------------------------
/asset/flame/uv_masks.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame/uv_masks.npz
--------------------------------------------------------------------------------
/asset/flame_editor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame_editor.png
--------------------------------------------------------------------------------
/asset/flame_viewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/flame_viewer.png
--------------------------------------------------------------------------------
/asset/monocular.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular.jpg
--------------------------------------------------------------------------------
/asset/monocular_obama.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_obama.gif
--------------------------------------------------------------------------------
/asset/monocular_obama.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_obama.mp4
--------------------------------------------------------------------------------
/asset/monocular_person_0004.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_person_0004.gif
--------------------------------------------------------------------------------
/asset/monocular_person_0004.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/monocular_person_0004.mp4
--------------------------------------------------------------------------------
/asset/nersemble.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble.jpg
--------------------------------------------------------------------------------
/asset/nersemble_038_EMO-1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_038_EMO-1.gif
--------------------------------------------------------------------------------
/asset/nersemble_038_EMO-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_038_EMO-1.mp4
--------------------------------------------------------------------------------
/asset/nersemble_074_EMO-1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_074_EMO-1.gif
--------------------------------------------------------------------------------
/asset/nersemble_074_EMO-1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/nersemble_074_EMO-1.mp4
--------------------------------------------------------------------------------
/asset/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShenhanQian/VHAP/0c1ad4247e1f1d54a66ee53d3144165dcbc85521/asset/teaser.gif
--------------------------------------------------------------------------------
/doc/monocular.md:
--------------------------------------------------------------------------------
1 | ## For Monocular Videos
2 |
3 |
4 |
5 |
6 |
7 | ### 1. Preprocess
8 |
9 | This step extracts frames from the video(s), then runs foreground matting on each frame, which requires a GPU.
10 |
11 | ```shell
12 | SEQUENCE="obama.mp4"
13 |
14 | python vhap/preprocess_video.py \
15 | --input data/monocular/${SEQUENCE} \
16 | --matting_method robust_video_matting
17 | ```
18 |
19 | - `--matting_method robust_video_matting`: Use RobustVideoMatting since no background image is available.
20 | - (Optional) `--downsample_scales 2`: Generate downsampled versions of the images at the given scale, e.g., 2 (an image size below 1024 is preferred for efficiency); see the example below.
21 |
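As referenced above, a hypothetical preprocessing run that also writes 2x-downsampled frames, together with the matching flag to use them later during tracking:

```shell
SEQUENCE="obama.mp4"

# Illustrative only: also generate 2x-downsampled frames during preprocessing.
python vhap/preprocess_video.py \
    --input data/monocular/${SEQUENCE} \
    --matting_method robust_video_matting \
    --downsample_scales 2

# During tracking (step 2), add `--data.n_downsample_rgb 2` to load the downsampled images.
```
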
22 | ### 2. Align and track faces
23 |
24 | This step automatically detects facial landmarks if they are absent, then begins FLAME tracking. We initialize shape and appearance parameters on the first frame, then track the following frames sequentially. After the sequential tracking, we run 30 epochs of global tracking, which optimizes all parameters on a random frame in each iteration.
25 |
26 | ```shell
27 | SEQUENCE="obama"
28 | TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_whiteBg_staticOffset"
29 |
30 | python vhap/track.py --data.root_folder "data/monocular" \
31 | --exp.output_folder $TRACK_OUTPUT_FOLDER \
32 | --data.sequence $SEQUENCE \
33 | # --data.n_downsample_rgb 2 # Only specify this if you have generated downsampled images during preprocessing.
34 | ```
35 |
36 | Optional arguments
37 |
38 | - `--model.no_use_static_offset`: disable static offset for FLAME (very stable, but less aligned facial geometry)
39 |
40 | > Disabling static offset automatically triggers `--model.occluded hair`, which is crucial to prevent the head from growing too large to align with the top of the hair.
41 |
42 | - `--exp.no_photometric`: track only with landmarks (very fast, but coarse); see the example below.
43 |
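As a concrete example of the landmark-only mode mentioned above (a quick, coarse pass; the folder name is a placeholder):

```shell
SEQUENCE="obama"
TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_lmkOnly"

# Illustrative only: skip photometric optimization and fit FLAME to landmarks alone.
python vhap/track.py --data.root_folder "data/monocular" \
    --exp.output_folder $TRACK_OUTPUT_FOLDER \
    --data.sequence $SEQUENCE \
    --exp.no_photometric
```
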
44 | ### 3. Export tracking results into a NeRF-style dataset
45 |
46 | Given the tracked FLAME parameters from the above step, you can export the results to form a NeRF/3DGS style sequence, consisting of image folders and a `transforms.json`.
47 |
48 | ```shell
49 | SEQUENCE="obama"
50 | TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_whiteBg_staticOffset"
51 | EXPORT_OUTPUT_FOLDER="export/monocular/${SEQUENCE}_whiteBg_staticOffset_maskBelowLine"
52 |
53 | python vhap/export_as_nerf_dataset.py \
54 | --src_folder ${TRACK_OUTPUT_FOLDER} \
55 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white
56 | ```
57 |
--------------------------------------------------------------------------------
/doc/nersemble.md:
--------------------------------------------------------------------------------
1 | ## For NeRSemble Dataset
2 |
3 |
4 |
5 |
6 |
7 | ### 1. Preprocess
8 |
9 | This step extracts frames from the video(s), then runs foreground matting on each frame, which requires a GPU.
10 |
11 | ```shell
12 | SUBJECT="074"
13 | SEQUENCE="EMO-1"
14 |
15 | python vhap/preprocess_video.py \
16 | --input data/nersemble/${SUBJECT}/${SEQUENCE}* \
17 | --downsample_scales 2 4 \
18 | --matting_method background_matting_v2
19 | ```
20 |
21 | - `--downsample_scales 2 4`: Generate downsampled versions of the images at scales 2 and 4.
22 | - `--matting_method background_matting_v2`: Use BackgroundMattingV2 since background images are available.
23 |
24 | After preprocessing, you can inspect images, masks, and landmarks with our [NeRSemble Data Viewer](https://github.com/ShenhanQian/nersemble-data-viewer).
25 |
26 | ### 2. Align and track faces
27 |
28 | This step automatically detects facial landmarks if they are absent, then begins FLAME tracking. We initialize shape and appearance parameters on the first frame, then track the following frames sequentially. After the sequential tracking, we run 30 epochs of global tracking, which optimizes all parameters on a random frame in each iteration.
29 |
30 | ```shell
31 | SUBJECT="074"
32 | SEQUENCE="EMO-1"
33 | TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset"
34 |
35 | python vhap/track_nersemble.py --data.root_folder "data/nersemble" \
36 | --exp.output_folder $TRACK_OUTPUT_FOLDER \
37 | --data.subject $SUBJECT --data.sequence $SEQUENCE \
38 | --data.n_downsample_rgb 4
39 | ```
40 |
41 | Optional arguments
42 |
43 | - `--model.no_use_static_offset`: disable static offset for FLAME (very stable, but less aligned facial geometry); see the example after the note below.
44 |
45 | > Disabling static offset automatically triggers `--model.occluded hair`, which is crucial to prevent the head from growing too large to align with the top of the hair.
46 |
47 | - `--exp.no_photometric`: track only with landmarks (very fast, but coarse)
48 |
49 | > [!NOTE]
50 | > We use all 16 views for the optimization, but we only visualize 3 views for efficiency.
51 |
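As referenced in the list above, a hypothetical variant of the tracking command without static offsets (only the output folder name and the extra flag change):

```shell
SUBJECT="074"
SEQUENCE="EMO-1"
TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_noStaticOffset"

# Illustrative only: more stable, but with less tightly aligned facial geometry.
python vhap/track_nersemble.py --data.root_folder "data/nersemble" \
    --exp.output_folder $TRACK_OUTPUT_FOLDER \
    --data.subject $SUBJECT --data.sequence $SEQUENCE \
    --data.n_downsample_rgb 4 \
    --model.no_use_static_offset
```
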
52 | ### 3. Export tracking results into a NeRF-style dataset
53 |
54 | Given the tracked FLAME parameters from the above step, you can export the results to form a NeRF/3DGS style sequence, consisting of image folders and a `transforms.json`.
55 |
56 | ```shell
57 | SUBJECT="074"
58 | SEQUENCE="EMO-1"
59 | TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset"
60 | EXPORT_OUTPUT_FOLDER="export/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_whiteBg_staticOffset_maskBelowLine"
61 |
62 | python vhap/export_as_nerf_dataset.py \
63 | --src_folder ${TRACK_OUTPUT_FOLDER} \
64 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white
65 | ```
66 |
67 | ### 4. Combine exported sequences of the same person as a union dataset
68 |
69 | ```shell
70 | SUBJECT="074"
71 |
72 | python vhap/combine_nerf_datasets.py \
73 | --src_folders \
74 | export/nersemble/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
75 | export/nersemble/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
76 | export/nersemble/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
77 | export/nersemble/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
78 | export/nersemble/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
79 | export/nersemble/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
80 | export/nersemble/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
81 | export/nersemble/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \
82 | export/nersemble/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \
83 | export/nersemble/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \
84 | --tgt_folder \
85 | export/nersemble/UNION10_${SUBJECT}_EMO1234EXP234589_v16_DS4_whiteBg_staticOffset_maskBelowLine
86 | ```
87 |
88 | > [!NOTE]
89 | > The `tgt_folder` must be in the same parent folder as the `src_folders` because the union dataset reads the original image files via relative paths.
90 |
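The combine script also exposes a `division_mode` parameter in `vhap/combine_nerf_datasets.py` (`random_single`, `random_group`, or `last`) that controls how the held-out test sequence is chosen. A minimal sketch with only two sequences, assuming the flag follows the same underscore naming as `--src_folders`:

```shell
SUBJECT="074"

# Illustrative only: combine two sequences and hold out the last one as the test split.
python vhap/combine_nerf_datasets.py \
    --src_folders \
        export/nersemble/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
        export/nersemble/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
    --tgt_folder \
        export/nersemble/UNION2_${SUBJECT}_EMO1EXP2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
    --division_mode last
```
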
--------------------------------------------------------------------------------
/doc/nersemble_v2.md:
--------------------------------------------------------------------------------
1 | ## For NeRSemble Dataset V2
2 |
3 |
4 |
5 |
6 |
7 | ### 1. Preprocess
8 |
9 | This step extracts frames from the video(s), then runs foreground matting on each frame, which requires a GPU.
10 |
11 | ```shell
12 | SUBJECT="074"
13 | SEQUENCE="EMO-1"
14 |
15 | python vhap/preprocess_video.py \
16 | --input data/nersemble_v2/${SUBJECT}/sequences/${SEQUENCE}* \
17 | --downsample_scales 2 4 \
18 | --matting_method background_matting_v2
19 | ```
20 |
21 | - `--downsample_scales 2 4`: Generate downsampled versions of the images at scales 2 and 4.
22 | - `--matting_method background_matting_v2`: Use BackgroundMattingV2 since background images are available.
23 |
24 | After preprocessing, you can inspect images, masks, and landmarks with our [NeRSemble Data Viewer](https://github.com/ShenhanQian/nersemble-data-viewer).
25 |
26 | ### 2. Align and track faces
27 |
28 | This step automatically detects facial landmarks if they are absent, then begins FLAME tracking. We initialize shape and appearance parameters on the first frame, then track the following frames sequentially. After the sequential tracking, we run 30 epochs of global tracking, which optimizes all parameters on a random frame in each iteration.
29 |
30 | ```shell
31 | SUBJECT="074"
32 | SEQUENCE="EMO-1"
33 | TRACK_OUTPUT_FOLDER="output/nersemble_v2/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset"
34 |
35 | python vhap/track_nersemble_v2.py --data.root_folder "data/nersemble_v2" \
36 | --exp.output_folder $TRACK_OUTPUT_FOLDER \
37 | --data.subject $SUBJECT --data.sequence $SEQUENCE \
38 | --data.n_downsample_rgb 4
39 | ```
40 |
41 | Optional arguments
42 |
43 | - `--model.no_use_static_offset`: disable static offset for FLAME (very stable, but less aligned facial geometry)
44 |
45 | > Disabling static offset automatically triggers `--model.occluded hair`, which is crucial to prevent the head from growing too large to align with the top of the hair.
46 |
47 | - `--exp.no_photometric`: track only with landmarks (very fast, but coarse)
48 |
49 | > [!NOTE]
50 | > We use all 16 views for the optimization, but we only visualize 3 views for efficiency.
51 |
52 | > [!WARNING]
53 | > NeRSemble Dataset V2 comes with improved color calibration. However, this may considerably slow down image loading, particularly at full resolution.
54 |
55 |
56 | ### 3. Export tracking results into a NeRF-style dataset
57 |
58 | Given the tracked FLAME parameters from the above step, you can export the results to form a NeRF/3DGS style sequence, consisting of image folders and a `transforms.json`.
59 |
60 | ```shell
61 | SUBJECT="074"
62 | SEQUENCE="EMO-1"
63 | TRACK_OUTPUT_FOLDER="output/nersemble_v2/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset"
64 | EXPORT_OUTPUT_FOLDER="export/nersemble_v2/${SUBJECT}_${SEQUENCE}_v16_DS4_whiteBg_staticOffset_maskBelowLine"
65 |
66 | python vhap/export_as_nerf_dataset.py \
67 | --src_folder ${TRACK_OUTPUT_FOLDER} \
68 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white
69 | ```
70 |
71 | ### 4. Combine exported sequences of the same person as a union dataset
72 |
73 | ```shell
74 | SUBJECT="074"
75 |
76 | python vhap/combine_nerf_datasets.py \
77 | --src_folders \
78 | export/nersemble_v2/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
79 | export/nersemble_v2/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
80 | export/nersemble_v2/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
81 | export/nersemble_v2/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
82 | export/nersemble_v2/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
83 | export/nersemble_v2/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
84 | export/nersemble_v2/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
85 | export/nersemble_v2/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \
86 | export/nersemble_v2/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \
87 | export/nersemble_v2/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \
88 | --tgt_folder \
89 | export/nersemble_v2/UNION10_${SUBJECT}_EMO1234EXP234589_v16_DS4_whiteBg_staticOffset_maskBelowLine
90 | ```
91 |
92 | > [!NOTE]
93 | > The `tgt_folder` must be in the same parent folder as the `src_folders` because the union dataset reads the original image files via relative paths.
94 |
--------------------------------------------------------------------------------
/jobs/combine_nersemble.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # Define a list of subjects and sequences
5 | # SUBJECTS=("074" "104" "140" "165" "210" "218" "238" "253" "264" "302" "304" "306") # 12
6 | # SUBJECTS=("074" "104" "140" "165" "175" "210" "218" "238" "253" "264" "302" "304" "306" "460") # 14
7 | # SUBJECTS=("306") # tmp
8 |
9 | # Loop through subjects and sequences and modify the command accordingly
10 | for SUBJECT in "${SUBJECTS[@]}"; do
11 | export_dir=export/nersemble
12 |
13 | #------- combine 10 -------#
14 | COMMAND="python vhap/combine_nerf_datasets.py \
15 | --src_folders \
16 | $export_dir/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
17 | $export_dir/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
18 | $export_dir/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
19 | $export_dir/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
20 | $export_dir/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
21 | $export_dir/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
22 | $export_dir/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
23 | $export_dir/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \
24 | $export_dir/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \
25 | $export_dir/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \
26 | --tgt_folder \
27 | $export_dir/UNION10_${SUBJECT}_EMO1234EXP234589_v16_DS4_whiteBg_staticOffset_maskBelowLine \
28 |
29 | "
30 |
31 | #------- combine 20 -------#
32 | # COMMAND="python vhap/combine_nerf_datasets.py \
33 | # --src_folders \
34 | # $export_dir/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
35 | # $export_dir/${SUBJECT}_EMO-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
36 | # $export_dir/${SUBJECT}_EMO-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
37 | # $export_dir/${SUBJECT}_EMO-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
38 | # $export_dir/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
39 | # $export_dir/${SUBJECT}_EXP-3_v16_DS4_whiteBg_staticOffset_maskBelowLine \
40 | # $export_dir/${SUBJECT}_EXP-4_v16_DS4_whiteBg_staticOffset_maskBelowLine \
41 | # $export_dir/${SUBJECT}_EXP-5_v16_DS4_whiteBg_staticOffset_maskBelowLine \
42 | # $export_dir/${SUBJECT}_EXP-8_v16_DS4_whiteBg_staticOffset_maskBelowLine \
43 | # $export_dir/${SUBJECT}_EXP-9_v16_DS4_whiteBg_staticOffset_maskBelowLine \
44 | # $export_dir/${SUBJECT}_SEN-01_v16_DS4_whiteBg_staticOffset_maskBelowLine \
45 | # $export_dir/${SUBJECT}_SEN-02_v16_DS4_whiteBg_staticOffset_maskBelowLine \
46 | # $export_dir/${SUBJECT}_SEN-03_v16_DS4_whiteBg_staticOffset_maskBelowLine \
47 | # $export_dir/${SUBJECT}_SEN-04_v16_DS4_whiteBg_staticOffset_maskBelowLine \
48 | # $export_dir/${SUBJECT}_SEN-05_v16_DS4_whiteBg_staticOffset_maskBelowLine \
49 | # $export_dir/${SUBJECT}_SEN-06_v16_DS4_whiteBg_staticOffset_maskBelowLine \
50 | # $export_dir/${SUBJECT}_SEN-07_v16_DS4_whiteBg_staticOffset_maskBelowLine \
51 | # $export_dir/${SUBJECT}_SEN-08_v16_DS4_whiteBg_staticOffset_maskBelowLine \
52 | # $export_dir/${SUBJECT}_SEN-09_v16_DS4_whiteBg_staticOffset_maskBelowLine \
53 | # $export_dir/${SUBJECT}_SEN-10_v16_DS4_whiteBg_staticOffset_maskBelowLine \
54 | # --tgt_folder \
55 | # $export_dir/UNION20_${SUBJECT}_EMO1234EXP234589SEN_v16_DS4_whiteBg_staticOffset_maskBelowLine \
56 |
57 | # "
58 |
59 | #------- combine short -------#
60 | # COMMAND="python vhap/combine_nerf_datasets.py \
61 | # --src_folders \
62 | # $export_dir/${SUBJECT}_EMO-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
63 | # $export_dir/${SUBJECT}_EXP-1_v16_DS4_whiteBg_staticOffset_maskBelowLine \
64 | # $export_dir/${SUBJECT}_EXP-2_v16_DS4_whiteBg_staticOffset_maskBelowLine \
65 | # $export_dir/${SUBJECT}_SEN-10_v16_DS4_whiteBg_staticOffset_maskBelowLine \
66 | # $export_dir/${SUBJECT}_FREE_v16_DS4_whiteBg_staticOffset_maskBelowLine \
67 | # --tgt_folder \
68 | # $export_dir/UNIONshort_${SUBJECT}_EMO1EXP12SEN10FREE_v16_DS4_whiteBg_staticOffset_maskBelowLine \
69 |
70 | # "
71 |
72 | #------- zip -------#
73 | # COMMAND="zip -r $export_dir/$SUBJECT.zip $export_dir/$SUBJECT* $export_dir/UNION_$SUBJECT*"
74 |
75 | #======= Run =======#
76 | # Execute the command (remove the echo if you want to actually run the command)
77 |
78 | $COMMAND
79 | # echo $COMMAND
80 |
81 | done
82 |
--------------------------------------------------------------------------------
/jobs/run_monocular.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define a list of sequences
4 |
5 | # SEQUENCES=("obama" "biden" "justin" "nf_01" "nf_03" "person_0004" "malte_1" "bala" "wojtek_1" "marcel") # 10
6 | # SEQUENCES=("obama")
7 |
8 | DATA_FOLDER="data/monocular"
9 |
10 | # Loop through sequences and modify the command accordingly
11 | for SEQUENCE in "${SEQUENCES[@]}"; do
12 |
13 | JOB_NAME="vhap_${SEQUENCE}"
14 |
15 | #======= Preprocess =======#
16 | RAW_VIDEO_PATH="${DATA_FOLDER}/${SEQUENCE}.mp4"
17 | PREPROCESS_COMMAND="
18 | python vhap/preprocess_video.py --input ${RAW_VIDEO_PATH} \
19 | --matting_method robust_video_matting
20 | "
21 |
22 | #======= Track =======#
23 | TRACK_OUTPUT_FOLDER="output/monocular/${SEQUENCE}_whiteBg_staticOffset"
24 | TRACK_COMMAND="
25 | python vhap/track.py --data.root_folder ${DATA_FOLDER} \
26 | --exp.output_folder $TRACK_OUTPUT_FOLDER \
27 | --data.sequence $SEQUENCE \
28 |
29 | "
30 |
31 | #======= Export =======#
32 | EXPORT_OUTPUT_FOLDER="export/monocular/${SEQUENCE}_whiteBg_staticOffset_maskBelowLine"
33 | EXPORT_COMMAND="python vhap/export_as_nerf_dataset.py \
34 | --src_folder ${TRACK_OUTPUT_FOLDER} \
35 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white \
36 |
37 | "
38 |
39 | #======= Run =======#
40 | # Execute the command (remove the echo if you want to actually run the command)
41 |
42 | #------- check completeness -------#
43 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1)
44 | # # if [ ! -d $last_folder/eval_30 ]; then
45 | # if [ ! -e $last_folder/tracked_flame_params_30.npz ]; then
46 | # echo $last_folder
47 | # fi
48 |
49 | #------- create video -------#
50 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1)
51 | # video_folder=$last_folder/eval_30
52 | # ffmpeg -y -framerate 25 -f image2 -pattern_type glob -i "${video_folder}/image_grid/*.jpg" -pix_fmt yuv420p $video_folder/image_grid.mp4
53 |
54 | #------- rename -------#
55 | # mv $TRACK_OUTPUT_FOLDER $TRACK_OUTPUT_FOLDER
56 |
57 | #------- only preprocess -------#
58 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $PREPROCESS_COMMAND"
59 | # COMMAND="$PREPROCESS_COMMAND"
60 |
61 | #------- only track -------#
62 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND"
63 | # COMMAND="$TRACK_COMMAND"
64 |
65 | #------- only export -------#
66 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $EXPORT_COMMAND"
67 | # COMMAND="$EXPORT_COMMAND"
68 |
69 | #------- track and export -------#
70 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND && $EXPORT_COMMAND"
71 | # COMMAND="$TRACK_COMMAND && $EXPORT_COMMAND"
72 |
73 |
74 | # echo $COMMAND
75 | $COMMAND
76 | sleep 1
77 | done
78 |
--------------------------------------------------------------------------------
/jobs/run_nersemble.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define a list of subjects and sequences
4 |
5 | # SUBJECTS=("074" "104" "140" "165" "175" "210" "218" "238" "253" "264" "302" "304" "306" "460") # 14
6 | # SUBJECTS=("306")
7 |
8 | # SEQUENCES=("EMO-1" "EMO-2" "EMO-3" "EMO-4" "EXP-1" "EXP-2" "EXP-3" "EXP-4" "EXP-5" "EXP-8" "EXP-9" "FREE") # 11
9 | # SEQUENCES=("SEN-01" "SEN-02" "SEN-03" "SEN-04" "SEN-05" "SEN-06" "SEN-07" "SEN-08" "SEN-09" "SEN-10") # 10
10 | # SEQUENCES=("EMO-1")
11 |
12 |
13 | DATA_FOLDER="data/nersemble"
14 |
15 | # Loop through subjects and sequences and modify the command accordingly
16 | for SUBJECT in "${SUBJECTS[@]}"; do
17 | for SEQUENCE in "${SEQUENCES[@]}"; do
18 |
19 | JOB_NAME="vhap_${SUBJECT}_${SEQUENCE}"
20 |
21 | #======= Preprocess =======#
22 | VIDEO_FOLDER_PATH="${DATA_FOLDER}/${SUBJECT}/${SEQUENCE}"
23 | PREPROCESS_COMMAND="
24 | python vhap/preprocess_video.py --input $VIDEO_FOLDER_PATH \
25 | --downsample_scales 2 4 \
26 | --matting_method background_matting_v2
27 | "
28 |
29 | #======= Track =======#
30 | TRACK_OUTPUT_FOLDER="output/nersemble/${SUBJECT}_${SEQUENCE}_v16_DS4_wBg_staticOffset"
31 | TRACK_COMMAND="
32 | python vhap/track_nersemble.py --data.root_folder $DATA_FOLDER \
33 | --exp.output_folder $TRACK_OUTPUT_FOLDER \
34 | --data.subject $SUBJECT --data.sequence $SEQUENCE \
35 | --data.n_downsample_rgb 4
36 |
37 | "
38 | #======= Export =======#
39 | EXPORT_OUTPUT_FOLDER="export/${SUBJECT}_${SEQUENCE}_v16_DS4_whiteBg_staticOffset_maskBelowLine"
40 | EXPORT_COMMAND="
41 | python vhap/export_as_nerf_dataset.py \
42 | --src_folder ${TRACK_OUTPUT_FOLDER} \
43 | --tgt_folder ${EXPORT_OUTPUT_FOLDER} --background-color white
44 | "
45 |
46 | #======= Run =======#
47 | # Execute the command (remove the echo if you want to actually run the command)
48 |
49 | #------- check completeness -------#
50 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1)
51 | # # if [ ! -d $last_folder/eval_30 ]; then
52 | # if [ ! -e $last_folder/tracked_flame_params_30.npz ]; then
53 | # echo $last_folder
54 | # fi
55 |
56 | #------- create video -------#
57 | # last_folder=$(find "$TRACK_OUTPUT_FOLDER" -maxdepth 1 -type d | sort | tail -n 1)
58 | # video_folder=$last_folder/eval_30
59 | # ffmpeg -y -framerate 25 -f image2 -pattern_type glob -i "${video_folder}/image_grid/*.jpg" -pix_fmt yuv420p $video_folder/image_grid.mp4
60 |
61 | #------- rename -------#
62 | # mv $TRACK_OUTPUT_FOLDER $TRACK_OUTPUT_FOLDER
63 |
64 | #------- only preprocess -------#
65 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $PREPROCESS_COMMAND"
66 | # COMMAND="$PREPROCESS_COMMAND"
67 |
68 | #------- only track -------#
69 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND"
70 | # COMMAND="$TRACK_COMMAND"
71 |
72 | #------- only export -------#
73 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $EXPORT_COMMAND"
74 | # COMMAND="$EXPORT_COMMAND"
75 |
76 | #------- track and export -------#
77 | # COMMAND="$HOME/local/usr/bin/isbatch.sh $JOB_NAME $TRACK_COMMAND && $EXPORT_COMMAND"
78 | # COMMAND="$TRACK_COMMAND && $EXPORT_COMMAND"
79 |
80 |
81 | # echo $COMMAND
82 | $COMMAND
83 | sleep 1
84 | done
85 | done
86 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.metadata]
6 | allow-direct-references = true
7 |
8 | [tool.hatch.build]
9 | include = ["vhap/**/*.py"]
10 |
11 | [project]
12 | name = "VHAP"
13 | version = "0.0.4"
14 | requires-python = ">=3.9"
15 | dependencies = [
16 | "tyro",
17 | "pyyaml",
18 | "numpy==1.22.3",
19 | "matplotlib==3.8.0",
20 | "scipy",
21 | "pillow",
22 | "opencv-python",
23 | "ffmpeg-python",
24 | "colour",
25 | "torch", # manually install to avoid CUDA version mismatch
26 | "torchvision", # manually install to avoid CUDA version mismatch
27 | "tensorboard",
28 | "chumpy",
29 | "trimesh",
30 | "nvdiffrast@git+https://github.com/ShenhanQian/nvdiffrast@backface-culling",
31 | "BackgroundMattingV2@git+https://github.com/ShenhanQian/BackgroundMattingV2",
32 | "STAR@git+https://github.com/ShenhanQian/STAR/",
33 | "dlib", # for STAR
34 | "pandas", # for STAR
35 | "gdown", # for STAR
36 | "face-alignment",
37 | "face-detection-tflite", # for face-alignment
38 | "pytorch3d@git+https://github.com/facebookresearch/pytorch3d.git",
39 | "dearpygui",
40 | ]
41 | authors = [
42 | {name = "Shenhan Qian", email = "shenhan.qian@tum.de"},
43 | ]
44 |
45 | description = "A complete head tracking pipeline from videos to NeRF-ready datasets."
46 | readme = "README.md"
47 |
--------------------------------------------------------------------------------
/vhap/combine_nerf_datasets.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from typing import Optional, Literal, List
11 | from copy import deepcopy
12 | import json
13 | import tyro
14 | from pathlib import Path
15 | import shutil
16 | import random
17 |
18 |
19 | class NeRFDatasetAssembler:
20 | def __init__(self, src_folders: List[Path], tgt_folder: Path, division_mode: Literal['random_single', 'random_group', 'last']='random_group'):
21 | self.src_folders = src_folders
22 | self.tgt_folder = tgt_folder
23 | self.num_timestep = 0
24 |
25 | # use the subject name as the random seed to sample the test sequence
26 | subjects = [sf.name.split('_')[0] for sf in src_folders]
27 | for s in subjects:
28 | assert s == subjects[0], f"Cannot combine datasets from different subjects: {subjects}"
29 | subject = subjects[0]
30 | random.seed(subject)
31 |
32 | if division_mode == 'random_single':
33 | self.src_folders_test = [self.src_folders.pop(int(random.uniform(0, 1) * len(src_folders)))]
34 | elif division_mode == 'random_group':
35 | # sample one sequence as the test sequence every `group_size` sequences
36 | self.src_folders_test = []
37 | num_all = len(self.src_folders)
38 | group_size = 10
39 | num_test = max(1, num_all // group_size)
40 | indices_test = []
41 | for gi in range(num_test):
42 | idx = min(num_all - 1, random.randint(0, group_size - 1) + gi * group_size)
43 | indices_test.append(idx)
44 |
45 | for idx in indices_test:
46 | self.src_folders_test.append(self.src_folders.pop(idx))
47 | elif division_mode == 'last':
48 | self.src_folders_test = [self.src_folders.pop(-1)]
49 | else:
50 | raise ValueError(f"Unknown division mode: {division_mode}")
51 |
52 | self.src_folders_train = self.src_folders
53 |
54 | def write(self):
55 | self.combine_dbs(self.src_folders_train, division='train')
56 | self.combine_dbs(self.src_folders_test, division='test')
57 |
58 | def combine_dbs(self, src_folders, division: Optional[Literal['train', 'test']] = None):
59 | db = None
60 | for i, src_folder in enumerate(src_folders):
61 | dbi_path = src_folder / "transforms.json"
62 | assert dbi_path.exists(), f"Could not find {dbi_path}"
63 | # print(f"Loading database: {dbi_path}")
64 | dbi = json.load(open(dbi_path, "r"))
65 |
66 | dbi['timestep_indices'] = [t + self.num_timestep for t in dbi['timestep_indices']]
67 | self.num_timestep += len(dbi['timestep_indices'])
68 | for frame in dbi['frames']:
69 | # drop keys that are irrelevant for a combined dataset
70 | frame.pop('timestep_index_original')
71 | frame.pop('timestep_id')
72 |
73 | # accumulate timestep indices
74 | frame['timestep_index'] = dbi['timestep_indices'][frame['timestep_index']]
75 |
76 | # complement the parent folder
77 | frame['file_path'] = str(Path('..') / Path(src_folder.name) / frame['file_path'])
78 | frame['flame_param_path'] = str(Path('..') / Path(src_folder.name) / frame['flame_param_path'])
79 | frame['fg_mask_path'] = str(Path('..') / Path(src_folder.name) / frame['fg_mask_path'])
80 |
81 | if db is None:
82 | db = dbi
83 | else:
84 | db['frames'] += dbi['frames']
85 | db['timestep_indices'] += dbi['timestep_indices']
86 |
87 | if not self.tgt_folder.exists():
88 | self.tgt_folder.mkdir(parents=True)
89 |
90 | if division == 'train':
91 | # copy the canonical flame param
92 | cano_flame_param_path = src_folders[0] / "canonical_flame_param.npz"
93 | tgt_flame_param_path = self.tgt_folder / f"canonical_flame_param.npz"
94 | print(f"Copying canonical flame param: {tgt_flame_param_path}")
95 | shutil.copy(cano_flame_param_path, tgt_flame_param_path)
96 |
97 | # leave one camera for validation
98 | db_train = {k: v for k, v in db.items() if k not in ['frames', 'camera_indices']}
99 | db_train['frames'] = []
100 | db_val = deepcopy(db_train)
101 |
102 | if len(db['camera_indices']) > 1:
103 | # when having multiple cameras, leave one camera for validation (novel-view synthesis)
104 | if 8 in db['camera_indices']:
105 | # use camera 8 for validation (front-view of the NeRSemble dataset)
106 | db_train['camera_indices'] = [i for i in db['camera_indices'] if i != 8]
107 | db_val['camera_indices'] = [8]
108 | else:
109 | # use the last camera for validation
110 | db_train['camera_indices'] = db['camera_indices'][:-1]
111 | db_val['camera_indices'] = [db['camera_indices'][-1]]
112 | else:
113 | # when only having one camera, we create an empty validation set
114 | db_train['camera_indices'] = db['camera_indices']
115 | db_val['camera_indices'] = []
116 |
117 | for frame in db['frames']:
118 | if frame['camera_index'] in db_train['camera_indices']:
119 | db_train['frames'].append(frame)
120 | elif frame['camera_index'] in db_val['camera_indices']:
121 | db_val['frames'].append(frame)
122 | else:
123 | raise ValueError(f"Unknown camera index: {frame['camera_index']}")
124 |
125 | write_json(db_train, self.tgt_folder, 'train')
126 | write_json(db_val, self.tgt_folder, 'val')
127 |
128 | with open(self.tgt_folder / 'sequences_trainval.txt', 'w') as f:
129 | for folder in src_folders:
130 | f.write(folder.name + '\n')
131 | else:
132 | db['timestep_indices'] = sorted(db['timestep_indices'])
133 | write_json(db, self.tgt_folder, division)
134 |
135 | with open(self.tgt_folder / f'sequences_{division}.txt', 'w') as f:
136 | for folder in src_folders:
137 | f.write(folder.name + '\n')
138 |
139 |
140 | def write_json(db, tgt_folder, division=None):
141 | fname = "transforms.json" if division is None else f"transforms_{division}.json"
142 | json_path = tgt_folder / fname
143 | print(f"Writing database: {json_path}")
144 | with open(json_path, "w") as f:
145 | json.dump(db, f, indent=4)
146 |
147 | def main(
148 | src_folders: List[Path],
149 | tgt_folder: Path,
150 | division_mode: Literal['random_single', 'random_group', 'last']='random_group',
151 | ):
152 | incomplete = False
153 | print("==== Begin assembling datasets ====")
154 | print(f"Division mode: {division_mode}")
155 | for src_folder in src_folders:
156 | try:
157 | assert src_folder.exists(), f"Error: could not find {src_folder}"
158 | assert src_folder.parent == tgt_folder.parent, "All source folders must be in the same parent folder as the target folder"
159 | # print(src_folder)
160 | except AssertionError as e:
161 | print(e)
162 | incomplete = True
163 |
164 | if incomplete:
165 | return
166 |
167 | nerf_dataset_assembler = NeRFDatasetAssembler(src_folders, tgt_folder, division_mode)
168 | nerf_dataset_assembler.write()
169 |
170 | print("Done!")
171 |
172 |
173 | if __name__ == "__main__":
174 | tyro.cli(main)
175 |
--------------------------------------------------------------------------------
/vhap/config/base.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from dataclasses import dataclass
11 | from pathlib import Path
12 | from typing import Optional, Literal, Tuple
13 | import tyro
14 | import importlib
15 | from vhap.util.log import get_logger
16 | logger = get_logger(__name__)
17 |
18 |
19 | def import_module(module_name: str):
20 | module_name, class_name = module_name.rsplit(".", 1)
21 | module = getattr(importlib.import_module(module_name), class_name)
22 | return module
23 |
24 |
25 | class Config:
26 | def __getitem__(self, __name: str):
27 | if hasattr(self, __name):
28 | return getattr(self, __name)
29 | else:
30 | raise AttributeError(f"{self.__class__.__name__} has no attribute '{__name}'")
31 |
32 |
33 | @dataclass()
34 | class DataConfig(Config):
35 | root_folder: Path
36 | """The root folder for the dataset."""
37 | sequence: str
38 | """The sequence name"""
39 | _target: str = "vhap.data.video_dataset.VideoDataset"
40 | """The target dataset class"""
41 | division: Optional[str] = None
42 | subset: Optional[str] = None
43 | calibrated: bool = False
44 | """Whether the cameras parameters are available"""
45 | align_cameras_to_axes: bool = True
46 | """Adjust how cameras distribute in the space with a global rotation"""
47 | camera_convention_conversion: str = 'opencv->opengl'
48 | target_extrinsic_type: Literal['w2c', 'c2w'] = 'w2c'
49 | n_downsample_rgb: Optional[int] = None
50 | """Load from downsampled RGB images to save data IO time"""
51 | scale_factor: float = 1.0
52 | """Further apply a scaling transformation after the downsampling of RGB"""
53 | background_color: Optional[Literal['white', 'black']] = 'white'
54 | use_alpha_map: bool = False
55 | use_landmark: bool = True
56 | landmark_source: Optional[Literal['face-alignment', 'star']] = "star"
57 |
58 |
59 | @dataclass()
60 | class ModelConfig(Config):
61 | n_shape: int = 300
62 | n_expr: int = 100
63 | n_tex: int = 100
64 |
65 | use_static_offset: bool = True
66 | """Optimize static offsets on top of FLAME vertices in the canonical space"""
67 | use_dynamic_offset: bool = False
68 | """Optimize dynamic offsets on top of the FLAME vertices in the canonical space"""
69 | add_teeth: bool = True
70 | """Add teeth to the FLAME model"""
71 | remove_lip_inside: bool = False
72 | """Remove the inner part of the lips from the FLAME model"""
73 |
74 | tex_resolution: int = 2048
75 | """The resolution of the extra texture map"""
76 | tex_painted: bool = True
77 | """Use a painted texture map instead the pca texture space as the base texture map"""
78 | tex_extra: bool = True
79 | """Optimize an extra texture map as the base texture map or the residual texture map"""
80 | # tex_clusters: tuple[str, ...] = ("skin", "hair", "sclerae", "lips_tight", "boundary")
81 | tex_clusters: tuple[str, ...] = ("skin", "hair", "boundary", "lips_tight", "teeth", "sclerae", "irises")
82 | """Regions that are supposed to share a similar color inside"""
83 | residual_tex: bool = True
84 | """Use the extra texture map as a residual component on top of the base texture"""
85 | occluded: tuple[str, ...] = () # to be used for updating stage configs in __post_init__
86 | """The regions that are occluded by the hair or garments"""
87 |
88 | flame_params_path: Optional[Path] = None
89 |
90 |
91 | @dataclass()
92 | class RenderConfig(Config):
93 | backend: Literal['nvdiffrast', 'pytorch3d'] = 'nvdiffrast'
94 | """The rendering backend"""
95 | use_opengl: bool = False
96 | """Use OpenGL for NVDiffRast"""
97 | background_train: Literal['white', 'black', 'target'] = 'target'
98 | """Background color/image for training"""
99 | disturb_rate_fg: Optional[float] = 0.5
100 | """The rate of disturbance for the foreground"""
101 | disturb_rate_bg: Optional[float] = 0.5
102 | """The rate of disturbance for the background. 0.6 best for multi-view, 0.3 best for single-view"""
103 | background_eval: Literal['white', 'black', 'target'] = 'target'
104 | """Background color/image for evaluation"""
105 | lighting_type: Literal['constant', 'front', 'front-range', 'SH'] = 'SH'
106 | """The type of lighting"""
107 | lighting_space: Literal['world', 'camera'] = 'world'
108 | """The space of lighting"""
109 |
110 |
111 | @dataclass()
112 | class LearningRateConfig(Config):
113 | base: float = 5e-3
114 | """shape, texture, rotation, eyes, neck, jaw"""
115 | translation: float = 1e-3
116 | expr: float = 5e-2
117 | static_offset: float = 5e-4
118 | dynamic_offset: float = 5e-4
119 | camera: float = 5e-3
120 | light: float = 5e-3
121 |
122 |
123 | @dataclass()
124 | class LossWeightConfig(Config):
125 | landmark: Optional[float] = 10.
126 | always_enable_jawline_landmarks: bool = True
127 | """Always enable the landmark loss for the jawline landmarks. Ignore disable_jawline_landmarks in stages."""
128 |
129 | photo: Optional[float] = 30.
130 |
131 | # L2 regularization
132 | reg_shape: float = 3e-1
133 | reg_neck: float = 3e-1
134 | reg_jaw: float = 3e-1
135 | reg_eyes: float = 3e-2
136 | reg_expr: float = 3e-2
137 |
138 | # regularize the texture map
139 | reg_tex_res_clusters: Optional[float] = 1e1
140 | """Regularize the residual texture map inside each texture cluster"""
141 | reg_tex_res_for: tuple[str, ...] = ("sclerae", "teeth")
142 | """Regularize the residual texture map for the clusters specified"""
143 | reg_tex_tv: Optional[float] = 1e4 # important to split regions apart
144 | """Regularize the total variation of the texture map"""
145 | reg_tex_pca: float = 1e-4 # will make it hard to model hair color when too high
146 | """Regularize the pca texture map (not effective when model.tex_painted is True)"""
147 |
148 | # regularize the lighting
149 | reg_light: Optional[float] = None
150 | """Regularize lighting parameters"""
151 | reg_diffuse: Optional[float] = 1e2
152 | """Regularize lighting parameters by the diffuse term"""
153 |
154 | # L2 regularization for static_offset
155 | reg_offset: Optional[float] = 3e2
156 | """Regularize the norm of offsets"""
157 | reg_offset_relax_coef: float = 1.
158 | """The coefficient for relaxing reg_offset for the regions specified"""
159 | reg_offset_relax_for: tuple[str, ...] = ("hair", "ears")
160 | """Relax the offset loss for the regions specified"""
161 |
162 | # laplacian regularization for static_offset
163 | reg_offset_lap: Optional[float] = 1e6
164 | """Regularize the difference of laplacian coordinate caused by offsets"""
165 | reg_offset_lap_relax_coef: float = 0.1
166 | """The coefficient for relaxing reg_offset_lap for the regions specified"""
167 | reg_offset_lap_relax_for: tuple[str, ...] = ("hair", "ears")
168 | """Relax the offset loss for the regions specified"""
169 |
170 | # local rigidity regularization for static_offset
171 | reg_offset_rigid: Optional[float] = 3e2
172 | """Regularize the the offsets to be as-rigid-as-possible"""
173 | reg_offset_rigid_for: tuple[str, ...] = ("left_ear", "right_ear", "neck", "left_eye", "right_eye", "lips_tight")
174 | """Regularize the the offsets to be as-rigid-as-possible for the regions specified"""
175 |
176 | reg_offset_dynamic: Optional[float] = 3e5
177 | """Regularize the dynamic offsets to be temporally smooth"""
178 |
179 | blur_iter: int = 0
180 | """The number of iterations for blurring vertex weights"""
181 |
182 | # temporal smoothness
183 | smooth_trans: float = 3e2
184 | """global translation"""
185 | smooth_rot: float = 3e1
186 | """global rotation"""
187 | smooth_neck: float = 3e1
188 | """neck joint"""
189 | smooth_jaw: float = 1e-1
190 | """jaw joint"""
191 | smooth_eyes: float = 0
192 | """eyes joints"""
193 | smooth_expr: float = 1e0
194 | """expression"""
195 |
196 |
197 | @dataclass()
198 | class LogConfig(Config):
199 | interval_scalar: Optional[int] = 100
200 | """The step interval of scalar logging. Using an interval of stage_tracking.num_steps // 5 unless specified."""
201 | interval_media: Optional[int] = 500
202 | """The step interval of media logging. Using an interval of stage_tracking.num_steps unless specified."""
203 | image_format: Literal['jpg', 'png'] = 'jpg'
204 | """Output image format"""
205 | view_indices: Tuple[int, ...] = ()
206 | """Manually specify the view indices for log"""
207 | max_num_views: int = 3
208 | """The maximum number of views for log"""
209 | stack_views_in_rows: bool = True
210 |
211 |
212 | @dataclass()
213 | class ExperimentConfig(Config):
214 | output_folder: Path = Path('output/track')
215 | reuse_landmarks: bool = True
216 | keyframes: Tuple[int, ...] = tuple()
217 | photometric: bool = True
218 | """enable photometric optimization, otherwise only landmark optimization"""
219 |
220 | @dataclass()
221 | class StageConfig(Config):
222 | disable_jawline_landmarks: bool = False
223 | """Disable the landmark loss for the jawline landmarks since they are not accurate"""
224 |
225 | @dataclass()
226 | class StageLmkInitRigidConfig(StageConfig):
227 | """The stage for initializing the rigid parameters"""
228 | num_steps: int = 500
229 | optimizable_params: tuple[str, ...] = ("cam", "pose")
230 |
231 | @dataclass()
232 | class StageLmkInitAllConfig(StageConfig):
233 | """The stage for initializing all the parameters optimizable with landmark loss"""
234 | num_steps: int = 500
235 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")
236 |
237 | @dataclass()
238 | class StageLmkSequentialTrackingConfig(StageConfig):
239 | """The stage for sequential tracking with landmark loss"""
240 | num_steps: int = 50
241 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr")
242 |
243 | @dataclass()
244 | class StageLmkGlobalTrackingConfig(StageConfig):
245 | """The stage for global tracking with landmark loss"""
246 | num_epochs: int = 30
247 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr")
248 |
249 | @dataclass()
250 | class PhotometricStageConfig(StageConfig):
251 | align_texture_except: tuple[str, ...] = ()
252 | """Align the inner region of rendered FLAME to the image, except for the regions specified"""
253 |     align_boundary_except: tuple[str, ...] = ("bottomline",)  # necessary to prevent the bottomline of FLAME from being stretched to the bottom of the image
254 | """Align the boundary of FLAME to the image, except for the regions specified"""
255 |
256 | @dataclass()
257 | class StageRgbInitTextureConfig(PhotometricStageConfig):
258 | """The stage for initializing the texture map with photometric loss"""
259 | num_steps: int = 500
260 | optimizable_params: tuple[str, ...] = ("cam", "shape", "texture", "lights")
261 | align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
262 | align_boundary_except: tuple[str, ...] = ("hair", "boundary")
263 |
264 | @dataclass()
265 | class StageRgbInitAllConfig(PhotometricStageConfig):
266 | """The stage for initializing all the parameters except the offsets with photometric loss"""
267 | num_steps: int = 500
268 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights")
269 | disable_jawline_landmarks: bool = True
270 | align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
271 | align_boundary_except: tuple[str, ...] = ("hair", "bottomline")
272 |
273 | @dataclass()
274 | class StageRgbInitOffsetConfig(PhotometricStageConfig):
275 | """The stage for initializing the offsets with photometric loss"""
276 | num_steps: int = 500
277 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset")
278 | disable_jawline_landmarks: bool = True
279 | align_texture_except: tuple[str, ...] = ("hair", "boundary", "neck")
280 |
281 | @dataclass()
282 | class StageRgbSequentialTrackingConfig(PhotometricStageConfig):
283 | """The stage for sequential tracking with photometric loss"""
284 | num_steps: int = 50
285 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "texture", "dynamic_offset")
286 | disable_jawline_landmarks: bool = True
287 |
288 | @dataclass()
289 | class StageRgbGlobalTrackingConfig(PhotometricStageConfig):
290 | """The stage for global tracking with photometric loss"""
291 | num_epochs: int = 30
292 | optimizable_params: tuple[str, ...] = ("cam", "pose", "shape", "joints", "expr", "texture", "lights", "static_offset", "dynamic_offset")
293 | disable_jawline_landmarks: bool = True
294 |
295 | @dataclass()
296 | class PipelineConfig(Config):
297 | lmk_init_rigid: StageLmkInitRigidConfig
298 | lmk_init_all: StageLmkInitAllConfig
299 | lmk_sequential_tracking: StageLmkSequentialTrackingConfig
300 | lmk_global_tracking: StageLmkGlobalTrackingConfig
301 | rgb_init_texture: StageRgbInitTextureConfig
302 | rgb_init_all: StageRgbInitAllConfig
303 | rgb_init_offset: StageRgbInitOffsetConfig
304 | rgb_sequential_tracking: StageRgbSequentialTrackingConfig
305 | rgb_global_tracking: StageRgbGlobalTrackingConfig
306 |
307 |
308 | @dataclass()
309 | class BaseTrackingConfig(Config):
310 | data: DataConfig
311 | model: ModelConfig
312 | render: RenderConfig
313 | log: LogConfig
314 | exp: ExperimentConfig
315 | lr: LearningRateConfig
316 | w: LossWeightConfig
317 | pipeline: PipelineConfig
318 |
319 | begin_stage: Optional[str] = None
320 | """Begin from the specified stage for debugging"""
321 | begin_frame_idx: int = 0
322 | """Begin from the specified frame index for debugging"""
323 | async_func: bool = True
324 | """Allow asynchronous function calls for speed up"""
325 | device: Literal['cuda', 'cpu'] = 'cuda'
326 |
327 | def get_occluded(self):
328 | occluded_table = {
329 | }
330 | if self.data.sequence in occluded_table:
331 | logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.sequence]}")
332 | self.model.occluded = occluded_table[self.data.sequence]
333 |
334 | def __post_init__(self):
335 | self.get_occluded()
336 |
337 | if not self.model.use_static_offset and not self.model.use_dynamic_offset:
338 | self.model.occluded = tuple(list(self.model.occluded) + ['hair']) # disable boundary alignment for the hair region if no offset is used
339 |
340 | for cfg_stage in self.pipeline.__dict__.values():
341 | if isinstance(cfg_stage, PhotometricStageConfig):
342 | cfg_stage.align_texture_except = tuple(list(cfg_stage.align_texture_except) + list(self.model.occluded))
343 | cfg_stage.align_boundary_except = tuple(list(cfg_stage.align_boundary_except) + list(self.model.occluded))
344 |
345 | if self.begin_stage is not None:
346 | skip = True
347 | for cfg_stage in self.pipeline.__dict__.values():
348 | if cfg_stage.__class__.__name__.lower() == self.begin_stage:
349 | skip = False
350 | if skip:
351 | cfg_stage.num_steps = 0
352 |
353 |
354 | if __name__ == "__main__":
355 | config = tyro.cli(BaseTrackingConfig)
356 | print(tyro.to_yaml(config))
--------------------------------------------------------------------------------
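Note on how the config classes above compose: BaseTrackingConfig.__post_init__ appends model.occluded to the align_texture_except and align_boundary_except tuples of every photometric stage, and begin_stage zeroes out num_steps of all stages that precede it. The following standalone snippet re-states the merge on plain tuples for illustration; the variable names mirror the dataclass fields above but nothing is imported from the repository.

# Standalone illustration of the merge done in BaseTrackingConfig.__post_init__
occluded = ("hair", "neck_lower")                    # e.g. cfg.model.occluded
align_texture_except = ("hair", "boundary", "neck")  # e.g. StageRgbInitTextureConfig default
align_boundary_except = ("hair", "boundary")

align_texture_except = tuple(list(align_texture_except) + list(occluded))
align_boundary_except = tuple(list(align_boundary_except) + list(occluded))

print(align_texture_except)   # ('hair', 'boundary', 'neck', 'hair', 'neck_lower')
print(align_boundary_except)  # ('hair', 'boundary', 'hair', 'neck_lower')

Duplicates are kept, which is harmless since the tuples are only used as exclusion sets downstream.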
/vhap/config/nersemble.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from typing import Optional, Literal
11 | from dataclasses import dataclass
12 | import tyro
13 |
14 | from vhap.config.base import (
15 | StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig,
16 | DataConfig, LossWeightConfig, BaseTrackingConfig,
17 | )
18 | from vhap.util.log import get_logger
19 | logger = get_logger(__name__)
20 |
21 |
22 | @dataclass()
23 | class NersembleDataConfig(DataConfig):
24 | _target: str = "vhap.data.nersemble_dataset.NeRSembleDataset"
25 | calibrated: bool = True
26 | image_size_during_calibration: Optional[tuple[int, int]] = (3208, 2200)
27 | """(height, width). Will be use to convert principle points when the image size is not included in the camera parameters."""
28 | background_color: Optional[Literal['white', 'black']] = None
29 | landmark_source: Optional[Literal["face-alignment", 'star']] = "star"
30 |
31 | subject: str = ""
32 | """Subject ID. Such as 018, 218, 251, 253"""
33 | use_color_correction: bool = True
34 | """Whether to use color correction to harmonize the color of the input images."""
35 |
36 | @dataclass()
37 | class NersembleLossWeightConfig(LossWeightConfig):
38 |     landmark: Optional[float] = 3.  # should not be lower, otherwise tracking may collapse
39 |     always_enable_jawline_landmarks: bool = False  # allow disable_jawline_landmarks in StageConfig to take effect
40 |     reg_expr: float = 1e-2  # for best expressiveness
41 |     reg_tex_tv: Optional[float] = 1e5  # 10x the base value
42 |     smooth_expr: float = 0  # for best expressiveness
43 |
44 | @dataclass()
45 | class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig):
46 | optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset")
47 |
48 | align_texture_except: tuple[str, ...] = ("boundary",)
49 | align_boundary_except: tuple[str, ...] = ("boundary",)
50 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the
51 | alignment constraints for better alignment in the face region.
52 | """
53 |
54 | @dataclass()
55 | class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig):
56 | align_texture_except: tuple[str, ...] = ("boundary",)
57 | align_boundary_except: tuple[str, ...] = ("boundary",)
58 | """Due to the limited flexibility in the lower neck region of FLAME, we relax the
59 | alignment constraints for better alignment in the face region.
60 | """
61 |
62 | @dataclass()
63 | class NersemblePipelineConfig(PipelineConfig):
64 | rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig
65 | rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig
66 |
67 | @dataclass()
68 | class NersembleTrackingConfig(BaseTrackingConfig):
69 | data: NersembleDataConfig
70 | w: NersembleLossWeightConfig
71 | pipeline: NersemblePipelineConfig
72 |
73 | def get_occluded(self):
74 | occluded_table = {
75 | '018': ('neck_lower',),
76 | '218': ('neck_lower',),
77 | '251': ('neck_lower', 'boundary'),
78 | '253': ('neck_lower',),
79 | }
80 | if self.data.subject in occluded_table:
81 | logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}")
82 | self.model.occluded = occluded_table[self.data.subject]
83 |
84 |
85 | if __name__ == "__main__":
86 | config = tyro.cli(NersembleTrackingConfig)
87 | print(tyro.to_yaml(config))
--------------------------------------------------------------------------------
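As a quick illustration of how the NeRSemble config narrows the base defaults: NersembleLossWeightConfig overrides only a handful of loss weights and inherits everything else. A minimal sketch, assuming the Config base class adds no required fields so both dataclasses can be default-constructed:

from vhap.config.base import LossWeightConfig
from vhap.config.nersemble import NersembleLossWeightConfig

base_w = LossWeightConfig()
ns_w = NersembleLossWeightConfig()

print(base_w.landmark, ns_w.landmark)        # 10.0 vs 3.0
print(base_w.reg_expr, ns_w.reg_expr)        # 0.03 vs 0.01
print(base_w.smooth_expr, ns_w.smooth_expr)  # 1.0 vs 0
print(ns_w.reg_offset == base_w.reg_offset)  # True: unchanged fields are inherited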
/vhap/config/nersemble_v2.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from dataclasses import dataclass
11 | import tyro
12 |
13 | from vhap.config.nersemble import NersembleDataConfig, NersembleTrackingConfig
14 | from vhap.util.log import get_logger
15 | logger = get_logger(__name__)
16 |
17 |
18 | @dataclass()
19 | class NersembleV2DataConfig(NersembleDataConfig):
20 | _target: str = "vhap.data.nersemble_v2_dataset.NeRSembleV2Dataset"
21 |
22 |
23 | @dataclass()
24 | class NersembleV2TrackingConfig(NersembleTrackingConfig):
25 | data: NersembleV2DataConfig
26 |
27 |
28 | if __name__ == "__main__":
29 | config = tyro.cli(NersembleV2TrackingConfig)
30 | print(tyro.to_yaml(config))
--------------------------------------------------------------------------------
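The V2 variant only swaps the dataset class via the `_target` string; the tracking code resolves it with the `import_module` helper from `vhap.config.base`, the same dispatch used in the `__main__` blocks of the dataset modules below. A minimal sketch, assuming `import_module` resolves a dotted path to the class (as that usage suggests):

from vhap.config.base import import_module
from vhap.config.nersemble_v2 import NersembleV2DataConfig

# `_target` has a plain class-level default, so it can be read without instantiating the config.
DatasetClass = import_module(NersembleV2DataConfig._target)
print(DatasetClass.__name__)  # NeRSembleV2Dataset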
/vhap/data/image_folder_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 | import numpy as np
4 | import PIL.Image as Image
5 | from torch.utils.data import Dataset
6 | from vhap.util.log import get_logger
7 |
8 |
9 | logger = get_logger(__name__)
10 |
11 |
12 | class ImageFolderDataset(Dataset):
13 | def __init__(
14 | self,
15 | image_folder: Path,
16 | background_folder: Optional[Path]=None,
17 | background_fname2camId=lambda x: x,
18 | image_fname2camId=lambda x: x,
19 | ):
20 | """
21 | Args:
22 |             image_folder: Path to a folder of images with the following directory layout
23 |                 <image_folder>/
24 |                 |---xx.jpg
25 | |---...
26 | """
27 | super().__init__()
28 | self.image_fname2camId = image_fname2camId
29 |         self.background_folder = background_folder
30 |
31 | logger.info(f"Initializing dataset from folder {image_folder}")
32 |
33 | self.image_paths = sorted(list(image_folder.glob('*.jpg')))
34 |
35 | if background_folder is not None:
36 | self.backgrounds = {}
37 | background_paths = sorted(list((image_folder / background_folder).glob('*.jpg')))
38 |
39 | for background_path in background_paths:
40 | bg = np.array(Image.open(background_path))
41 | cam_id = background_fname2camId(background_path.name)
42 | self.backgrounds[cam_id] = bg
43 |
44 | def __len__(self):
45 | return len(self.image_paths)
46 |
47 | def __getitem__(self, i):
48 | image_path = self.image_paths[i]
49 | cam_id = self.image_fname2camId(image_path.name)
50 | rgb = np.array(Image.open(image_path))
51 | item = {
52 | "rgb": rgb,
53 | 'image_path': str(image_path),
54 | }
55 |
56 |         if self.background_folder is not None:
57 | item['background'] = self.backgrounds[cam_id]
58 |
59 | return item
60 |
61 |
62 | if __name__ == "__main__":
63 | from tqdm import tqdm
64 | from torch.utils.data import DataLoader
65 |
66 | dataset = ImageFolderDataset(
67 |         image_folder=Path('./xx'),
68 |         background_folder=None,
69 | )
70 |
71 | print(len(dataset))
72 |
73 | sample = dataset[0]
74 | print(sample.keys())
75 | print(sample["rgb"].shape)
76 |
77 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
78 | for item in tqdm(dataloader):
79 | pass
80 |
--------------------------------------------------------------------------------
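The two *_fname2camId callables are what tie each image to its per-camera background; by default they are identities. A minimal sketch with hypothetical folder paths and file names (the naming scheme below is an assumption for illustration, not part of the repository):

from pathlib import Path
from vhap.data.image_folder_dataset import ImageFolderDataset

dataset = ImageFolderDataset(
    image_folder=Path("data/example/images"),             # contains e.g. cam_A_000000.jpg (hypothetical)
    background_folder=Path("backgrounds"),                 # resolved relative to image_folder, contains e.g. A.jpg
    image_fname2camId=lambda name: name.split("_")[1],     # "cam_A_000000.jpg" -> "A"
    background_fname2camId=lambda name: Path(name).stem,   # "A.jpg" -> "A"
)
item = dataset[0]  # {'rgb': ..., 'image_path': ..., 'background': ...}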
/vhap/data/nerf_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import json
3 | import numpy as np
4 | import PIL.Image as Image
5 | import torch
6 | import torchvision.transforms.functional as F
7 | from torch.utils.data import Dataset
8 | from vhap.util.log import get_logger
9 |
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | class NeRFDataset(Dataset):
15 | def __init__(
16 | self,
17 | root_folder,
18 | division=None,
19 | camera_convention_conversion=None,
20 | target_extrinsic_type='w2c',
21 | use_fg_mask=False,
22 | use_flame_param=False,
23 | ):
24 | """
25 | Args:
26 | root_folder: Path to dataset with the following directory layout
27 |                 <root_folder>/
28 |                 |
29 |                 |---<image folder>/
30 |                 |   |---00000.jpg
31 |                 |   |...
32 |                 |
33 |                 |---<fg_mask folder>/
34 |                 |   |---00000.png
35 |                 |   |...
36 |                 |
37 |                 |---<flame_param folder>/
38 |                 |   |---00000.npz
39 |                 |   |...
40 |                 |
41 | |---transforms_backup.json # backup of the original transforms.json
42 | |---transforms_backup_flame.json # backup of the original transforms.json with flame_param
43 | |---transforms.json # the final transforms.json
44 | |---transforms_train.json # the final transforms.json for training
45 | |---transforms_val.json # the final transforms.json for validation
46 | |---transforms_test.json # the final transforms.json for testing
47 |
48 |
49 | """
50 |
51 | super().__init__()
52 | self.root_folder = Path(root_folder)
53 | self.division = division
54 | self.camera_convention_conversion = camera_convention_conversion
55 | self.target_extrinsic_type = target_extrinsic_type
56 | self.use_fg_mask = use_fg_mask
57 | self.use_flame_param = use_flame_param
58 |
59 | logger.info(f"Loading NeRF scene from: {root_folder}")
60 |
61 | # data division
62 | if division is None:
63 | tranform_path = self.root_folder / "transforms.json"
64 | elif division == "train":
65 | tranform_path = self.root_folder / "transforms_train.json"
66 | elif division == "val":
67 | tranform_path = self.root_folder / "transforms_val.json"
68 | elif division == "test":
69 | tranform_path = self.root_folder / "transforms_test.json"
70 | else:
71 | raise NotImplementedError(f"Unknown division type: {division}")
72 | logger.info(f"division: {division}")
73 |
74 | self.transforms = json.load(open(tranform_path, "r"))
75 | logger.info(f"number of timesteps: {len(self.transforms['timestep_indices'])}, number of cameras: {len(self.transforms['camera_indices'])}")
76 |
77 | assert len(self.transforms['timestep_indices']) == max(self.transforms['timestep_indices']) + 1
78 |
79 | def __len__(self):
80 | return len(self.transforms['frames'])
81 |
82 | def __getitem__(self, i):
83 | frame = self.transforms['frames'][i]
84 |
85 |         # frame keys: ['timestep_index', 'timestep_index_original', 'timestep_id', 'camera_index', 'camera_id', 'cx', 'cy', 'fl_x', 'fl_y', 'h', 'w', 'camera_angle_x', 'camera_angle_y', 'transform_matrix', 'file_path', 'fg_mask_path', 'flame_param_path']
86 |
87 | K = torch.eye(3)
88 | K[[0, 1, 0, 1], [0, 1, 2, 2]] = torch.tensor(
89 | [frame["fl_x"], frame["fl_y"], frame["cx"], frame["cy"]]
90 | )
91 |
92 | c2w = torch.tensor(frame['transform_matrix'])
93 | if self.target_extrinsic_type == "w2c":
94 | extrinsic = c2w.inverse()
95 | elif self.target_extrinsic_type == "c2w":
96 | extrinsic = c2w
97 | else:
98 | raise NotImplementedError(f"Unknown extrinsic type: {self.target_extrinsic_type}")
99 |
100 | img_path = self.root_folder / frame['file_path']
101 |
102 | item = {
103 | 'timestep_index': frame['timestep_index'],
104 | 'camera_index': frame['camera_index'],
105 | 'intrinsics': K,
106 | 'extrinsics': extrinsic,
107 | 'image_height': frame['h'],
108 | 'image_width': frame['w'],
109 | 'image': np.array(Image.open(img_path)),
110 | 'image_path': img_path,
111 | }
112 |
113 | if self.use_fg_mask and 'fg_mask_path' in frame:
114 | fg_mask_path = self.root_folder / frame['fg_mask_path']
115 | item["fg_mask"] = np.array(Image.open(fg_mask_path))
116 | item["fg_mask_path"] = fg_mask_path
117 |
118 | if self.use_flame_param and 'flame_param_path' in frame:
119 | npz = np.load(self.root_folder / frame['flame_param_path'], allow_pickle=True)
120 | item["flame_param"] = dict(npz)
121 |
122 | return item
123 |
124 | def apply_to_tensor(self, item):
125 |         if getattr(self, "img_to_tensor", False):  # this dataset does not set img_to_tensor in __init__; default to a no-op
126 | if "rgb" in item:
127 | item["rgb"] = F.to_tensor(item["rgb"])
128 | # if self.rgb_range_shift:
129 | # item["rgb"] = (item["rgb"] - 0.5) / 0.5
130 |
131 | if "alpha_map" in item:
132 | item["alpha_map"] = F.to_tensor(item["alpha_map"])
133 | return item
134 |
135 |
136 | if __name__ == "__main__":
137 | from tqdm import tqdm
138 | from dataclasses import dataclass
139 | import tyro
140 | from torch.utils.data import DataLoader
141 |
142 | @dataclass
143 | class Args:
144 | root_folder: str
145 | subject: str
146 | sequence: str
147 | use_landmark: bool = False
148 | batchify_all_views: bool = False
149 |
150 | args = tyro.cli(Args)
151 |
152 | dataset = NeRFDataset(root_folder=args.root_folder)
153 |
154 | print(len(dataset))
155 |
156 | sample = dataset[0]
157 | print(sample.keys())
158 |
159 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
160 | for item in tqdm(dataloader):
161 | pass
162 |
--------------------------------------------------------------------------------
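For reference, a minimal way to read back an exported scene with this class; the folder path is hypothetical and should point to the output of export_as_nerf_dataset.py:

from vhap.data.nerf_dataset import NeRFDataset

dataset = NeRFDataset(
    root_folder="output/export/EXAMPLE",  # hypothetical export folder
    division="train",                     # reads transforms_train.json
    target_extrinsic_type="c2w",          # keep camera-to-world matrices
    use_fg_mask=True,
    use_flame_param=True,
)
item = dataset[0]
print(item["intrinsics"])  # 3x3 K assembled from fl_x, fl_y, cx, cy
print(item["extrinsics"].shape, item["image"].shape)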
/vhap/data/nersemble_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import json
11 | import numpy as np
12 | import torch
13 | from vhap.data.video_dataset import VideoDataset
14 | from vhap.config.nersemble import NersembleDataConfig
15 | from vhap.util import camera
16 | from vhap.util.log import get_logger
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | class NeRSembleDataset(VideoDataset):
23 | def __init__(
24 | self,
25 | cfg: NersembleDataConfig,
26 | img_to_tensor: bool = False,
27 | batchify_all_views: bool = False,
28 | ):
29 | """
30 | Folder layout for NeRSemble dataset:
31 |
32 |         <root_folder>/
33 |         |---camera_params/
34 |         |   |---<subject>/
35 |         |       |---camera_params.json
36 |         |
37 |         |---color_correction/
38 |         |   |---<subject>/
39 |         |       |---<camera_id>.npy
40 |         |
41 |         |---<subject>/
42 |             |---<sequence>/
43 |                 |---images/
44 |                 |   |---cam_<camera_id>_<timestep_id>.jpg
45 |                 |
46 |                 |---alpha_maps/
47 |                 |   |---cam_<camera_id>_<timestep_id>.png
48 |                 |
49 |                 |---landmark2d/
50 |                     |---face-alignment/
51 |                     |   |---<camera_id>.npz
52 |                     |
53 |                     |---STAR/
54 |                         |---<camera_id>.npz
55 | """
56 | self.cfg = cfg
57 | assert cfg.subject != "", "Please specify the subject name"
58 |
59 | super().__init__(
60 | cfg=cfg,
61 | img_to_tensor=img_to_tensor,
62 | batchify_all_views=batchify_all_views,
63 | )
64 | self.load_color_correction()
65 |
66 | def match_sequences(self):
67 | logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}")
68 | return list(filter(lambda x: x.is_dir(), (self.cfg.root_folder / self.cfg.subject).glob(f"{self.cfg.sequence}*")))
69 |
70 | def define_properties(self):
71 | super().define_properties()
72 | self.properties['rgb']['cam_id_prefix'] = "cam_"
73 | self.properties['alpha_map']['cam_id_prefix'] = "cam_"
74 |
75 | def load_camera_params(self, camera_params_path=None):
76 | if camera_params_path is None:
77 | camera_params_path = self.cfg.root_folder / "camera_params" / self.cfg.subject / "camera_params.json"
78 |
79 | assert camera_params_path.exists()
80 | param = json.load(open(camera_params_path))
81 |
82 | K = torch.Tensor(param["intrinsics"])
83 |
84 | if "height" not in param or "width" not in param:
85 | assert self.cfg.image_size_during_calibration is not None
86 | H, W = self.cfg.image_size_during_calibration
87 | else:
88 | H, W = param["height"], param["width"]
89 |
90 | self.camera_ids = list(param["world_2_cam"].keys())
91 | w2c = torch.tensor([param["world_2_cam"][k] for k in self.camera_ids]) # (N, 4, 4)
92 | R = w2c[..., :3, :3]
93 | T = w2c[..., :3, 3]
94 |
95 | orientation = R.transpose(-1, -2) # (N, 3, 3)
96 | location = R.transpose(-1, -2) @ -T[..., None] # (N, 3, 1)
97 |
98 | # adjust how cameras distribute in the space with a global rotation
99 | if self.cfg.align_cameras_to_axes:
100 | orientation, location = camera.align_cameras_to_axes(
101 | orientation, location, target_convention="opengl"
102 | )
103 |
104 | # modify the local orientation of cameras to fit in different camera conventions
105 | if self.cfg.camera_convention_conversion is not None:
106 | orientation, K = camera.convert_camera_convention(
107 | self.cfg.camera_convention_conversion, orientation, K, H, W
108 | )
109 |
110 | c2w = torch.cat([orientation, location], dim=-1) # camera-to-world transformation
111 |
112 | if self.cfg.target_extrinsic_type == "w2c":
113 | R = orientation.transpose(-1, -2)
114 | T = orientation.transpose(-1, -2) @ -location
115 | w2c = torch.cat([R, T], dim=-1) # world-to-camera transformation
116 | extrinsic = w2c
117 | elif self.cfg.target_extrinsic_type == "c2w":
118 | extrinsic = c2w
119 | else:
120 | raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")
121 |
122 | self.camera_params = {}
123 | for i, camera_id in enumerate(self.camera_ids):
124 | self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]}
125 |
126 | def load_color_correction(self):
127 | if self.cfg.use_color_correction:
128 | self.color_correction = {}
129 |
130 | for camera_id in self.camera_ids:
131 | color_correction_path = self.cfg.root_folder / 'color_correction' / self.cfg.subject / f'{camera_id}.npy'
132 | assert color_correction_path.exists(), f"Color correction file not found: {color_correction_path}"
133 | self.color_correction[camera_id] = np.load(color_correction_path)
134 |
135 | def filter_division(self, division):
136 | if division is not None:
137 | cam_for_train = [8, 7, 9, 4, 10, 5, 13, 2, 12, 1, 14, 0]
138 | if division == "train":
139 | self.camera_ids = [
140 | self.camera_ids[i]
141 | for i in range(len(self.camera_ids))
142 | if i in cam_for_train
143 | ]
144 | elif division == "val":
145 | self.camera_ids = [
146 | self.camera_ids[i]
147 | for i in range(len(self.camera_ids))
148 | if i not in cam_for_train
149 | ]
150 | elif division == "front-view":
151 | self.camera_ids = self.camera_ids[8:9]
152 | elif division == "side-view":
153 | self.camera_ids = self.camera_ids[0:1]
154 | elif division == "six-view":
155 | self.camera_ids = [self.camera_ids[i] for i in [0, 1, 7, 8, 14, 15]]
156 | else:
157 | raise NotImplementedError(f"Unknown division type: {division}")
158 | logger.info(f"division: {division}")
159 |
160 | def apply_transforms(self, item):
161 | item = self.apply_color_correction(item)
162 | item = super().apply_transforms(item)
163 | return item
164 |
165 | def apply_color_correction(self, item):
166 | if self.cfg.use_color_correction:
167 | affine_color_transform = self.color_correction[item["camera_id"]]
168 | rgb = item["rgb"] / 255
169 | rgb = rgb @ affine_color_transform[:3, :3] + affine_color_transform[np.newaxis, :3, 3]
170 | item["rgb"] = (np.clip(rgb, 0, 1) * 255).astype(np.uint8)
171 | return item
172 |
173 |
174 | if __name__ == "__main__":
175 | import tyro
176 | from tqdm import tqdm
177 | from torch.utils.data import DataLoader
178 | from vhap.config.nersemble import NersembleDataConfig
179 | from vhap.config.base import import_module
180 |
181 | cfg = tyro.cli(NersembleDataConfig)
182 | cfg.use_landmark = False
183 | dataset = import_module(cfg._target)(
184 | cfg=cfg,
185 | img_to_tensor=False,
186 | batchify_all_views=True,
187 | )
188 |
189 | print(len(dataset))
190 |
191 | sample = dataset[0]
192 | print(sample.keys())
193 | print(sample["rgb"].shape)
194 |
195 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
196 | for item in tqdm(dataloader):
197 | pass
198 |
--------------------------------------------------------------------------------
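The per-camera color correction above is a plain affine transform in RGB space: each pixel's color row-vector is multiplied by the 3x3 part of the matrix and shifted by its translation column. A small numpy sketch of the same operation; the (3, 4) matrix shape is an assumption consistent with the indexing in apply_color_correction:

import numpy as np

A = np.hstack([np.eye(3), np.zeros((3, 1))])  # identity affine transform, shape (3, 4)
rgb = np.random.rand(4, 4, 3)                 # H x W x 3 image scaled to [0, 1]

corrected = rgb @ A[:3, :3] + A[:3, 3]        # same broadcasting as in apply_color_correction
corrected = (np.clip(corrected, 0, 1) * 255).astype(np.uint8)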
/vhap/data/nersemble_v2_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import json
11 | import numpy as np
12 | import colour
13 | from vhap.data.nersemble_dataset import NeRSembleDataset
14 | from vhap.util.log import get_logger
15 | from vhap.util.color_correction import color_correction_Cheung2004_precomputed
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | class NeRSembleV2Dataset(NeRSembleDataset):
22 | """
23 | Folder layout for NeRSembleV2 dataset:
24 |
25 |     <root_folder>/
26 |     |---<subject>/
27 |         |---calibration/
28 |         |   |---camera_params.json
29 |         |   |---color_calibration.json
30 |         |
31 |         |---sequences/
32 |             |---<sequence>/
33 |                 |---images/
34 |                 |   |---cam_<camera_id>_<timestep_id>.jpg
35 |                 |
36 |                 |---alpha_maps/
37 |                 |   |---cam_<camera_id>_<timestep_id>.png
38 |                 |
39 |                 |---landmark2d/
40 |                     |---face-alignment/
41 |                     |   |---<camera_id>.npz
42 |                     |
43 |                     |---STAR/
44 |                         |---<camera_id>.npz
45 |
46 | """
47 |
48 | def match_sequences(self):
49 | logger.info(f"Subject: {self.cfg.subject}, sequence: {self.cfg.sequence}")
50 | return list(filter(lambda x: x.is_dir(), (self.cfg.root_folder / self.cfg.subject / 'sequences').glob(f"{self.cfg.sequence}*")))
51 |
52 | def load_camera_params(self):
53 | super().load_camera_params(self.cfg.root_folder / self.cfg.subject / "calibration" / "camera_params.json")
54 |
55 | def load_color_correction(self):
56 | if self.cfg.use_color_correction:
57 |             color_correction_path = self.cfg.root_folder / self.cfg.subject / 'calibration' / 'color_calibration.json'
58 | self.color_correction = {serial: np.array(ccm) for serial, ccm in json.load(open(color_correction_path)).items()}
59 |
60 | def apply_color_correction(self, item):
61 | if self.cfg.use_color_correction:
62 | rgb = item["rgb"] / 255
63 | image_linear = colour.cctf_decoding(rgb)
64 | ccm = self.color_correction[item["camera_id"]]
65 | image_corrected = color_correction_Cheung2004_precomputed(image_linear, ccm)
66 | image_corrected = colour.cctf_encoding(image_corrected)
67 |             item["rgb"] = (np.clip(image_corrected, 0, 1) * 255).astype(np.uint8)
68 | return item
69 |
70 |
71 | if __name__ == "__main__":
72 | import tyro
73 | from tqdm import tqdm
74 | from torch.utils.data import DataLoader
75 | from vhap.data.nersemble_v2_dataset import NeRSembleV2Dataset
76 | from vhap.config.nersemble_v2 import NersembleV2DataConfig
77 | from vhap.config.base import import_module
78 |
79 | cfg = tyro.cli(NersembleV2DataConfig)
80 | cfg.use_landmark = False
81 | dataset = import_module(cfg._target)(
82 | cfg=cfg,
83 | img_to_tensor=False,
84 | batchify_all_views=True,
85 | )
86 |
87 | print(len(dataset))
88 |
89 | sample = dataset[0]
90 | print(sample.keys())
91 | print(sample["rgb"].shape)
92 |
93 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
94 | for item in tqdm(dataloader):
95 | pass
96 |
--------------------------------------------------------------------------------
/vhap/data/video_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from copy import deepcopy
4 | from typing import Optional
5 | import numpy as np
6 | import PIL.Image as Image
7 | import torch
8 | import torchvision.transforms.functional as F
9 | from torch.utils.data import Dataset, default_collate
10 |
11 | from vhap.util.log import get_logger
12 | from vhap.config.base import DataConfig
13 |
14 |
15 | logger = get_logger(__name__)
16 |
17 |
18 | class VideoDataset(Dataset):
19 | def __init__(
20 | self,
21 | cfg: DataConfig,
22 | img_to_tensor: bool = False,
23 | batchify_all_views: bool = False,
24 | ):
25 | """
26 | Args:
27 | root_folder: Path to dataset with the following directory layout
28 |                 <root_folder>/
29 |                 |---images/
30 |                 |   |---<timestep_id>.jpg
31 |                 |
32 |                 |---alpha_maps/
33 |                 |   |---<timestep_id>.png
34 |                 |
35 |                 |---landmark2d/
36 |                     |---face-alignment/
37 |                     |   |---<camera_id>.npz
38 |                     |
39 |                     |---STAR/
40 |                         |---<camera_id>.npz
41 | """
42 | super().__init__()
43 | self.cfg = cfg
44 | self.img_to_tensor = img_to_tensor
45 | self.batchify_all_views = batchify_all_views
46 |
47 | sequence_paths = self.match_sequences()
48 | if len(sequence_paths) > 1:
49 | logger.info(f"Found multiple sequences: {sequence_paths}")
50 | raise ValueError(f"Found multiple sequences by '{cfg.sequence}': \n" + "\n\t".join([str(x) for x in sequence_paths]))
51 | elif len(sequence_paths) == 0:
52 | raise ValueError(f"Cannot find sequence: {cfg.sequence}")
53 | self.sequence_path = sequence_paths[0]
54 | logger.info(f"Initializing dataset from {self.sequence_path}")
55 |
56 | self.define_properties()
57 | self.load_camera_params()
58 |
59 | # timesteps
60 | self.timestep_ids = set(
61 | f.split('.')[0].split('_')[-1]
62 | for f in os.listdir(self.sequence_path / self.properties['rgb']['folder']) if f.endswith(self.properties['rgb']['suffix'])
63 | )
64 | self.timestep_ids = sorted(self.timestep_ids)
65 | self.timestep_indices = list(range(len(self.timestep_ids)))
66 |
67 | self.filter_division(cfg.division)
68 | self.filter_subset(cfg.subset)
69 |
70 | logger.info(f"number of timesteps: {self.num_timesteps}, number of cameras: {self.num_cameras}")
71 |
72 | # collect
73 | self.items = []
74 | for fi, timestep_index in enumerate(self.timestep_indices):
75 | for ci, camera_id in enumerate(self.camera_ids):
76 | self.items.append(
77 | {
78 | "timestep_index": fi, # new index after filtering
79 | "timestep_index_original": timestep_index, # original index
80 | "timestep_id": self.timestep_ids[timestep_index],
81 | "camera_index": ci,
82 | "camera_id": camera_id,
83 | }
84 | )
85 |
86 | def match_sequences(self):
87 | logger.info(f"Looking for sequence '{self.cfg.sequence}' at {self.cfg.root_folder}")
88 | return list(filter(lambda x: x.is_dir(), self.cfg.root_folder.glob(f"{self.cfg.sequence}*")))
89 |
90 | def define_properties(self):
91 | self.properties = {
92 | "rgb": {
93 | "folder": f"images_{self.cfg.n_downsample_rgb}"
94 | if self.cfg.n_downsample_rgb
95 | else "images",
96 | "per_timestep": True,
97 | "suffix": "jpg",
98 | },
99 | "alpha_map": {
100 | "folder": "alpha_maps",
101 | "per_timestep": True,
102 | "suffix": "jpg",
103 | },
104 | "landmark2d/face-alignment": {
105 | "folder": "landmark2d/face-alignment",
106 | "per_timestep": False,
107 | "suffix": "npz",
108 | },
109 | "landmark2d/STAR": {
110 | "folder": "landmark2d/STAR",
111 | "per_timestep": False,
112 | "suffix": "npz",
113 | },
114 | }
115 |
116 | @staticmethod
117 | def get_number_after_prefix(string, prefix):
118 | i = string.find(prefix)
119 | if i != -1:
120 | number_begin = i + len(prefix)
121 | assert number_begin < len(string), f"No number found behind prefix '{prefix}'"
122 | assert string[number_begin].isdigit(), f"No number found behind prefix '{prefix}'"
123 |
124 | non_digit_indices = [i for i, c in enumerate(string[number_begin:]) if not c.isdigit()]
125 | if len(non_digit_indices) > 0:
126 | number_end = number_begin + min(non_digit_indices)
127 | return int(string[number_begin:number_end])
128 | else:
129 | return int(string[number_begin:])
130 | else:
131 | return None
132 |
133 | def filter_division(self, division):
134 | pass
135 |
136 | def filter_subset(self, subset):
137 | if subset is not None:
138 | if 'ti' in subset:
139 | ti = self.get_number_after_prefix(subset, 'ti')
140 | if 'tj' in subset:
141 | tj = self.get_number_after_prefix(subset, 'tj')
142 | self.timestep_indices = self.timestep_indices[ti:tj+1]
143 | else:
144 | self.timestep_indices = self.timestep_indices[ti:ti+1]
145 | elif 'tn' in subset:
146 | tn = self.get_number_after_prefix(subset, 'tn')
147 | tn_all = len(self.timestep_indices)
148 | tn = min(tn, tn_all)
149 | self.timestep_indices = self.timestep_indices[::tn_all // tn][:tn]
150 | elif 'ts' in subset:
151 | ts = self.get_number_after_prefix(subset, 'ts')
152 | self.timestep_indices = self.timestep_indices[::ts]
153 | if 'ci' in subset:
154 | ci = self.get_number_after_prefix(subset, 'ci')
155 | self.camera_ids = self.camera_ids[ci:ci+1]
156 | elif 'cn' in subset:
157 | cn = self.get_number_after_prefix(subset, 'cn')
158 | cn_all = len(self.camera_ids)
159 | cn = min(cn, cn_all)
160 | self.camera_ids = self.camera_ids[::cn_all // cn][:cn]
161 | elif 'cs' in subset:
162 | cs = self.get_number_after_prefix(subset, 'cs')
163 | self.camera_ids = self.camera_ids[::cs]
164 |
165 | def load_camera_params(self):
166 | self.camera_ids = ['0']
167 |
168 | # Guessed focal length, height, width. Should be optimized or replaced by real values
169 | f, h, w = 512, 512, 512
170 | K = torch.Tensor([
171 | [f, 0, w],
172 | [0, f, h],
173 | [0, 0, 1]
174 | ])
175 |
176 | orientation = torch.eye(3)[None, ...] # (1, 3, 3)
177 | location = torch.Tensor([0, 0, 1])[None, ..., None] # (1, 3, 1)
178 |
179 | c2w = torch.cat([orientation, location], dim=-1) # camera-to-world transformation
180 |
181 | if self.cfg.target_extrinsic_type == "w2c":
182 | R = orientation.transpose(-1, -2)
183 | T = orientation.transpose(-1, -2) @ -location
184 | w2c = torch.cat([R, T], dim=-1) # world-to-camera transformation
185 | extrinsic = w2c
186 | elif self.cfg.target_extrinsic_type == "c2w":
187 | extrinsic = c2w
188 | else:
189 | raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}")
190 |
191 | self.camera_params = {}
192 | for i, camera_id in enumerate(self.camera_ids):
193 | self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]}
194 |
195 | return self.camera_params
196 |
197 | def __len__(self):
198 | if self.batchify_all_views:
199 | return self.num_timesteps
200 | else:
201 | return len(self.items)
202 |
203 | def __getitem__(self, i):
204 | if self.batchify_all_views:
205 | return self.getitem_by_timestep(i)
206 | else:
207 | return self.getitem_single_image(i)
208 |
209 | def getitem_single_image(self, i):
210 | item = deepcopy(self.items[i])
211 |
212 | rgb_path = self.get_property_path("rgb", i)
213 | item["rgb"] = np.array(Image.open(rgb_path))
214 |
215 | camera_param = self.camera_params[item["camera_id"]]
216 | item["intrinsic"] = camera_param["intrinsic"].clone()
217 | item["extrinsic"] = camera_param["extrinsic"].clone()
218 |
219 | if self.cfg.use_alpha_map or self.cfg.background_color is not None:
220 | alpha_path = self.get_property_path("alpha_map", i)
221 | item["alpha_map"] = np.array(Image.open(alpha_path))
222 |
223 | if self.cfg.use_landmark:
224 | timestep_index = self.items[i]["timestep_index"]
225 |
226 | if self.cfg.landmark_source == "face-alignment":
227 | landmark_path = self.get_property_path("landmark2d/face-alignment", i)
228 | elif self.cfg.landmark_source == "star":
229 | landmark_path = self.get_property_path("landmark2d/STAR", i)
230 | else:
231 | raise NotImplementedError(f"Unknown landmark source: {self.cfg.landmark_source}")
232 | landmark_npz = np.load(landmark_path)
233 |
234 | item["lmk2d"] = landmark_npz["face_landmark_2d"][timestep_index] # (num_points, 3)
235 | if (item["lmk2d"][:, :2] == -1).sum() > 0:
236 | item["lmk2d"][:, 2:] = 0.0
237 | else:
238 | item["lmk2d"][:, 2:] = 1.0
239 |
240 | item = self.apply_transforms(item)
241 | return item
242 |
243 | def getitem_by_timestep(self, timestep_index):
244 | begin = timestep_index * self.num_cameras
245 | indices = range(begin, begin + self.num_cameras)
246 | item = default_collate([self.getitem_single_image(i) for i in indices])
247 |
248 | item["num_cameras"] = self.num_cameras
249 | return item
250 |
251 | def apply_transforms(self, item):
252 | item = self.apply_scale_factor(item)
253 | item = self.apply_background_color(item)
254 | item = self.apply_to_tensor(item)
255 | return item
256 |
257 | def apply_to_tensor(self, item):
258 | if self.img_to_tensor:
259 | if "rgb" in item:
260 | item["rgb"] = F.to_tensor(item["rgb"])
261 |
262 | if "alpha_map" in item:
263 | item["alpha_map"] = F.to_tensor(item["alpha_map"])
264 | return item
265 |
266 | def apply_scale_factor(self, item):
267 | assert self.cfg.scale_factor <= 1.0
268 |
269 | if "rgb" in item:
270 | H, W, _ = item["rgb"].shape
271 | h, w = int(H * self.cfg.scale_factor), int(W * self.cfg.scale_factor)
272 | rgb = Image.fromarray(item["rgb"]).resize(
273 | (w, h), resample=Image.BILINEAR
274 | )
275 | item["rgb"] = np.array(rgb)
276 |
277 | # properties that are defined based on image size
278 | if "lmk2d" in item:
279 | item["lmk2d"][..., 0] *= w
280 | item["lmk2d"][..., 1] *= h
281 |
282 | if "lmk2d_iris" in item:
283 | item["lmk2d_iris"][..., 0] *= w
284 | item["lmk2d_iris"][..., 1] *= h
285 |
286 | if "bbox_2d" in item:
287 | item["bbox_2d"][[0, 2]] *= w
288 | item["bbox_2d"][[1, 3]] *= h
289 |
290 | # properties need to be scaled down when rgb is downsampled
291 | n_downsample_rgb = self.cfg.n_downsample_rgb if self.cfg.n_downsample_rgb else 1
292 | scale_factor = self.cfg.scale_factor / n_downsample_rgb
293 | item["scale_factor"] = scale_factor # NOTE: not self.cfg.scale_factor
294 | if scale_factor < 1.0:
295 | if "intrinsic" in item:
296 | item["intrinsic"][:2] *= scale_factor
297 | if "alpha_map" in item:
298 | h, w = item["rgb"].shape[:2]
299 | alpha_map = Image.fromarray(item["alpha_map"]).resize(
300 | (w, h), Image.Resampling.BILINEAR
301 | )
302 | item["alpha_map"] = np.array(alpha_map)
303 | return item
304 |
305 | def apply_background_color(self, item):
306 | if self.cfg.background_color is not None:
307 | assert (
308 | "alpha_map" in item
309 | ), "'alpha_map' is required to apply background color."
310 | fg = item["rgb"]
311 | if self.cfg.background_color == "white":
312 | bg = np.ones_like(fg) * 255
313 | elif self.cfg.background_color == "black":
314 | bg = np.zeros_like(fg)
315 | else:
316 | raise NotImplementedError(
317 | f"Unknown background color: {self.cfg.background_color}."
318 | )
319 |
320 | w = item["alpha_map"][..., None] / 255
321 | img = (w * fg + (1 - w) * bg).astype(np.uint8)
322 | item["rgb"] = img
323 | return item
324 |
325 | def get_property_path(
326 | self,
327 | name,
328 | index: Optional[int] = None,
329 | timestep_id: Optional[str] = None,
330 | camera_id: Optional[str] = None,
331 | ):
332 | p = self.properties[name]
333 | folder = p["folder"] if "folder" in p else None
334 | per_timestep = p["per_timestep"]
335 | suffix = p["suffix"]
336 |
337 | path = self.sequence_path
338 | if folder is not None:
339 | path = path / folder
340 |
341 | if self.num_cameras > 1:
342 | if camera_id is None:
343 | assert (
344 | index is not None), "index is required when camera_id is not provided."
345 | camera_id = self.items[index]["camera_id"]
346 | if "cam_id_prefix" in p:
347 | camera_id = p["cam_id_prefix"] + camera_id
348 | else:
349 | camera_id = ""
350 |
351 | if per_timestep:
352 | if timestep_id is None:
353 | assert index is not None, "index is required when timestep_id is not provided."
354 | timestep_id = self.items[index]["timestep_id"]
355 | if len(camera_id) > 0:
356 | path /= f"{camera_id}_{timestep_id}.{suffix}"
357 | else:
358 | path /= f"{timestep_id}.{suffix}"
359 | else:
360 | if len(camera_id) > 0:
361 | path /= f"{camera_id}.{suffix}"
362 | else:
363 | path = Path(str(path) + f".{suffix}")
364 |
365 | return path
366 |
367 | def get_property_path_list(self, name):
368 | paths = []
369 | for i in range(len(self.items)):
370 | img_path = self.get_property_path(name, i)
371 | paths.append(img_path)
372 | return paths
373 |
374 | @property
375 | def num_timesteps(self):
376 | return len(self.timestep_indices)
377 |
378 | @property
379 | def num_cameras(self):
380 | return len(self.camera_ids)
381 |
382 |
383 | if __name__ == "__main__":
384 | import tyro
385 | from tqdm import tqdm
386 | from torch.utils.data import DataLoader
387 | from vhap.config.base import DataConfig, import_module
388 |
389 | cfg = tyro.cli(DataConfig)
390 | cfg.use_landmark = False
391 | dataset = import_module(cfg._target)(
392 | cfg=cfg,
393 | img_to_tensor=False,
394 | batchify_all_views=True,
395 | )
396 |
397 | print(len(dataset))
398 |
399 | sample = dataset[0]
400 | print(sample.keys())
401 | print(sample["rgb"].shape)
402 |
403 | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1)
404 | for item in tqdm(dataloader):
405 | pass
406 |
--------------------------------------------------------------------------------
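The subset string parsed by filter_subset packs timestep and camera filters into prefixed numbers. The concrete strings below are hypothetical examples of the grammar implemented above:

# Timestep filters (choose one):
#   "ti10"      keep only timestep 10
#   "ti10tj20"  keep timesteps 10..20 (inclusive)
#   "tn50"      keep 50 evenly spaced timesteps
#   "ts4"       keep every 4th timestep
# Camera filters (choose one, may be combined with a timestep filter):
#   "ci8"       keep only camera index 8
#   "cn6"       keep 6 evenly spaced cameras
#   "cs2"       keep every 2nd camera
# Example: subset="tn50ci8" keeps 50 evenly spaced timesteps of camera index 8.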
/vhap/flame_editor.py:
--------------------------------------------------------------------------------
1 | import tyro
2 | from dataclasses import dataclass
3 | from typing import Optional
4 | from pathlib import Path
5 | import time
6 | import dearpygui.dearpygui as dpg
7 | import numpy as np
8 | import torch
9 |
10 | from vhap.util.camera import OrbitCamera
11 | from vhap.model.flame import FlameHead
12 | from vhap.config.base import ModelConfig
13 | from vhap.util.render_nvdiffrast import NVDiffRenderer
14 |
15 |
16 | @dataclass
17 | class Config:
18 | model: ModelConfig
19 | """FLAME model configuration"""
20 | param_path: Optional[Path] = None
21 | """Path to the npz file for FLAME parameters"""
22 | W: int = 1024
23 | """GUI width"""
24 | H: int = 1024
25 | """GUI height"""
26 | radius: float = 1
27 | """default GUI camera radius from center"""
28 | fovy: float = 30
29 | """default GUI camera fovy"""
30 | background_color: tuple[float] = (1., 1., 1.)
31 | """default GUI background color"""
32 | use_opengl: bool = False
33 | """use OpenGL or CUDA rasterizer"""
34 | shade_smooth: bool = True
35 | """smooth shading or flat shading"""
36 |
37 |
38 | class FlameViewer:
39 | def __init__(self, cfg: Config):
40 | self.cfg = cfg # shared with the trainer's cfg to support in-place modification of rendering parameters.
41 |
42 | # flame model
43 | self.flame_model = FlameHead(
44 | cfg.model.n_shape,
45 | cfg.model.n_expr,
46 | add_teeth=True,
47 | include_lbs_color=True,
48 | )
49 | self.reset_flame_param()
50 |
51 | # viewer settings
52 | self.W = cfg.W
53 | self.H = cfg.H
54 | self.cam = OrbitCamera(self.W, self.H, r=cfg.radius, fovy=cfg.fovy, convention="opengl")
55 | self.last_time_fresh = None
56 | self.render_mode = '-'
57 | self.selected_regions = '-'
58 |         self.render_buffer = np.ones((self.H, self.W, 3), dtype=np.float32)
59 | self.need_update = True # camera moved, should reset accumulation
60 |
61 | # buffers for mouse interaction
62 | self.cursor_x = None
63 | self.cursor_y = None
64 | self.drag_begin_x = None
65 | self.drag_begin_y = None
66 | self.drag_button = None
67 |
68 | # rendering settings
69 | self.mesh_renderer = NVDiffRenderer(use_opengl=cfg.use_opengl, lighting_space='camera', shade_smooth=cfg.shade_smooth)
70 |
71 | self.define_gui()
72 |
73 | def __del__(self):
74 | dpg.destroy_context()
75 |
76 | def refresh(self):
77 | dpg.set_value("_texture", self.render_buffer)
78 |
79 | if self.last_time_fresh is not None:
80 | elapsed = time.time() - self.last_time_fresh
81 | fps = 1 / elapsed
82 | dpg.set_value("_log_fps", f'{fps:.1f}')
83 | self.last_time_fresh = time.time()
84 |
85 | def define_gui(self):
86 | dpg.create_context()
87 |
88 | # register texture =================================================================================================
89 | with dpg.texture_registry(show=False):
90 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
91 |
92 | # register window ==================================================================================================
93 | # the window to display the rendered image
94 | with dpg.window(label="viewer", tag="_render_window", width=self.W, height=self.H, no_title_bar=True, no_move=True, no_bring_to_front_on_focus=True, no_resize=True):
95 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image")
96 |
97 | # control window ==================================================================================================
98 | with dpg.window(label="Control", tag="_control_window", autosize=True):
99 |
100 | with dpg.group(horizontal=True):
101 | dpg.add_text("FPS: ")
102 | dpg.add_text("", tag="_log_fps")
103 |
104 | # rendering options
105 | with dpg.collapsing_header(label="Render", default_open=True):
106 |
107 | def callback_set_render_mode(sender, app_data):
108 | self.render_mode = app_data
109 | self.need_update = True
110 | dpg.add_combo(('-', 'lbs weights'), label='render mode', default_value=self.render_mode, tag="_combo_render_mode", callback=callback_set_render_mode)
111 |
112 | def callback_select_regions(sender, app_data):
113 | self.selected_regions = app_data
114 | self.need_update = True
115 | dpg.add_combo(['-']+sorted(self.flame_model.mask.v.keys()), label='regions', default_value='-', tag="_combo_regions", callback=callback_select_regions)
116 |
117 | # fov slider
118 | def callback_set_fovy(sender, app_data):
119 | self.cam.fovy = app_data
120 | self.need_update = True
121 | dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy, tag="_slider_fovy")
122 |
123 | def callback_reset_camera(sender, app_data):
124 | self.cam.reset()
125 | self.need_update = True
126 | dpg.set_value("_slider_fovy", self.cam.fovy)
127 |
128 | with dpg.group(horizontal=True):
129 | dpg.add_button(label="reset camera", tag="_button_reset_pose", callback=callback_reset_camera)
130 |
131 |
132 |             # FLAME parameter options
133 | with dpg.collapsing_header(label="Parameters", default_open=True):
134 |
135 | def callback_set_pose(sender, app_data):
136 | joint, axis = sender.split('-')[1:3]
137 | axis_idx = {'x': 0, 'y': 1, 'z': 2}[axis]
138 | self.flame_param[joint][0, axis_idx] = app_data
139 | self.need_update = True
140 | self.pose_sliders = []
141 | slider_width = 87
142 | for joint in ['neck', 'jaw']:
143 | dpg.add_text(f'{joint:9s}')
144 | if joint in self.flame_param:
145 | with dpg.group(horizontal=True):
146 | dpg.add_slider_float(label="x", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 0], callback=callback_set_pose, tag=f"_slider-{joint}-x", width=slider_width)
147 | dpg.add_slider_float(label="y", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 1], callback=callback_set_pose, tag=f"_slider-{joint}-y", width=slider_width)
148 | dpg.add_slider_float(label="z", min_value=-1, max_value=1, format="%.2f", default_value=self.flame_param[joint][0, 2], callback=callback_set_pose, tag=f"_slider-{joint}-z", width=slider_width)
149 | self.pose_sliders.append(f"_slider-{joint}-x")
150 | self.pose_sliders.append(f"_slider-{joint}-y")
151 | self.pose_sliders.append(f"_slider-{joint}-z")
152 |
153 | def callback_set_expr(sender, app_data):
154 | expr_i = int(sender.split('-')[2])
155 | self.flame_param['expr'][0, expr_i] = app_data
156 | self.need_update = True
157 | self.expr_sliders = []
158 |                 dpg.add_text('expr')
159 | for i in range(5):
160 | dpg.add_slider_float(label=f"{i}", min_value=-5, max_value=5, format="%.2f", default_value=0, callback=callback_set_expr, tag=f"_slider-expr-{i}", width=300)
161 | self.expr_sliders.append(f"_slider-expr-{i}")
162 |
163 | def callback_reset_flame(sender, app_data):
164 | self.reset_flame_param()
165 | self.need_update = True
166 | for slider in self.pose_sliders + self.expr_sliders:
167 | dpg.set_value(slider, 0)
168 | dpg.add_button(label="reset FLAME", tag="_button_reset_flame", callback=callback_reset_flame)
169 |
170 | ### register mouse handlers ========================================================================================
171 |
172 | def callback_mouse_move(sender, app_data):
173 | self.cursor_x, self.cursor_y = app_data
174 | if not dpg.is_item_focused("_render_window"):
175 | return
176 |
177 | if self.drag_begin_x is None or self.drag_begin_y is None:
178 | self.drag_begin_x = self.cursor_x
179 | self.drag_begin_y = self.cursor_y
180 | else:
181 | dx = self.cursor_x - self.drag_begin_x
182 | dy = self.cursor_y - self.drag_begin_y
183 |
184 | # button=dpg.mvMouseButton_Left
185 | if self.drag_button is dpg.mvMouseButton_Left:
186 | self.cam.orbit(dx, dy)
187 | self.need_update = True
188 | elif self.drag_button is dpg.mvMouseButton_Middle:
189 | self.cam.pan(dx, dy)
190 | self.need_update = True
191 |
192 | def callback_mouse_button_down(sender, app_data):
193 | if not dpg.is_item_focused("_render_window"):
194 | return
195 | self.drag_begin_x = self.cursor_x
196 | self.drag_begin_y = self.cursor_y
197 | self.drag_button = app_data[0]
198 |
199 | def callback_mouse_release(sender, app_data):
200 | self.drag_begin_x = None
201 | self.drag_begin_y = None
202 | self.drag_button = None
203 |
204 | self.dx_prev = None
205 | self.dy_prev = None
206 |
207 | def callback_mouse_drag(sender, app_data):
208 | if not dpg.is_item_focused("_render_window"):
209 | return
210 |
211 | button, dx, dy = app_data
212 | if self.dx_prev is None or self.dy_prev is None:
213 | ddx = dx
214 | ddy = dy
215 | else:
216 | ddx = dx - self.dx_prev
217 | ddy = dy - self.dy_prev
218 |
219 | self.dx_prev = dx
220 | self.dy_prev = dy
221 |
222 | if ddx != 0 and ddy != 0:
223 | if button is dpg.mvMouseButton_Left:
224 | self.cam.orbit(ddx, ddy)
225 | self.need_update = True
226 | elif button is dpg.mvMouseButton_Middle:
227 | self.cam.pan(ddx, ddy)
228 | self.need_update = True
229 |
230 | def callback_camera_wheel_scale(sender, app_data):
231 | if not dpg.is_item_focused("_render_window"):
232 | return
233 | delta = app_data
234 | self.cam.scale(delta)
235 | self.need_update = True
236 |
237 | with dpg.handler_registry():
238 | # this registry order helps avoid false fire
239 | dpg.add_mouse_release_handler(callback=callback_mouse_release)
240 | # dpg.add_mouse_drag_handler(callback=callback_mouse_drag) # not using the drag callback, since it does not return the starting point
241 | dpg.add_mouse_move_handler(callback=callback_mouse_move)
242 | dpg.add_mouse_down_handler(callback=callback_mouse_button_down)
243 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
244 |
245 | # key press handlers
246 | # dpg.add_key_press_handler(dpg.mvKey_Left, callback=callback_set_current_frame, tag='_mvKey_Left')
247 | # dpg.add_key_press_handler(dpg.mvKey_Right, callback=callback_set_current_frame, tag='_mvKey_Right')
248 | # dpg.add_key_press_handler(dpg.mvKey_Home, callback=callback_set_current_frame, tag='_mvKey_Home')
249 | # dpg.add_key_press_handler(dpg.mvKey_End, callback=callback_set_current_frame, tag='_mvKey_End')
250 |
251 | def callback_viewport_resize(sender, app_data):
252 | while self.rendering:
253 | time.sleep(0.01)
254 | self.need_update = False
255 | self.W = app_data[0]
256 | self.H = app_data[1]
257 | self.cam.image_width = self.W
258 | self.cam.image_height = self.H
259 | self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32)
260 |
261 | # delete and re-add the texture and image
262 | dpg.delete_item("_texture")
263 | dpg.delete_item("_image")
264 |
265 | with dpg.texture_registry(show=False):
266 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
267 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image", parent="_render_window")
268 | dpg.configure_item("_render_window", width=self.W, height=self.H)
269 | self.need_update = True
270 | dpg.set_viewport_resize_callback(callback_viewport_resize)
271 |
272 | ### global theme ==================================================================================================
273 | with dpg.theme() as theme_no_padding:
274 | with dpg.theme_component(dpg.mvAll):
275 | # set all padding to 0 to avoid scroll bar
276 | dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
277 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
278 | dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
279 | dpg.bind_item_theme("_render_window", theme_no_padding)
280 |
281 | ### finish setup ==================================================================================================
282 | dpg.create_viewport(title='FLAME Editor', width=self.W, height=self.H, resizable=True)
283 | dpg.setup_dearpygui()
284 | dpg.show_viewport()
285 |
286 | def reset_flame_param(self):
287 | self.flame_param = {
288 | 'shape': torch.zeros(1, self.cfg.model.n_shape),
289 | 'expr': torch.zeros(1, self.cfg.model.n_expr),
290 | 'rotation': torch.zeros(1, 3),
291 | 'neck': torch.zeros(1, 3),
292 | 'jaw': torch.zeros(1, 3),
293 | 'eyes': torch.zeros(1, 6),
294 | 'translation': torch.zeros(1, 3),
295 | 'static_offset': torch.zeros(1, 3),
296 | 'dynamic_offset': torch.zeros(1, 3),
297 | }
298 |
299 | def forward_flame(self, flame_param):
300 | N = flame_param['expr'].shape[0]
301 |
302 | self.verts, self.verts_cano = self.flame_model(
303 | **flame_param,
304 | zero_centered_at_root_node=False,
305 | return_landmarks=False,
306 | return_verts_cano=True,
307 | )
308 |
309 | def prepare_camera(self):
310 | @dataclass
311 | class Cam:
312 | FoVx = float(np.radians(self.cam.fovx))
313 | FoVy = float(np.radians(self.cam.fovy))
314 | image_height = self.cam.image_height
315 | image_width = self.cam.image_width
316 | world_view_transform = torch.tensor(self.cam.world_view_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
317 | full_proj_transform = torch.tensor(self.cam.full_proj_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
318 | camera_center = torch.tensor(self.cam.pose[:3, 3]).cuda()
319 | return Cam
320 |
321 | def run(self):
322 |
323 | while dpg.is_dearpygui_running():
324 |
325 | if self.need_update:
326 | self.rendering = True
327 |
328 | with torch.no_grad():
329 | # mesh
330 | self.forward_flame(self.flame_param)
331 | verts = self.verts.cuda()
332 | faces = self.flame_model.faces.cuda()
333 |
334 | # camera
335 | RT = torch.from_numpy(self.cam.world_view_transform).cuda()[None]
336 | K = torch.from_numpy(self.cam.intrinsics).cuda()[None]
337 | image_size = self.cam.image_height, self.cam.image_width
338 |
339 | if self.render_mode == 'lbs weights':
340 | v_color = self.flame_model.lbs_color.cuda()
341 | else:
342 | v_color = torch.ones_like(verts)
343 |
344 | if self.selected_regions != '-':
345 | vid = self.flame_model.mask.get_vid_except_region(self.selected_regions)
346 | v_color[..., vid, :] *= 0.3
347 |
348 | out_dict = self.mesh_renderer.render_rgba_vis(verts, faces, RT, K, image_size, self.cfg.background_color, v_color=v_color)
349 |
350 |                     rgba_mesh = out_dict['rgba'].squeeze(0).permute(2, 0, 1)  # (C, H, W)
351 | rgb_mesh = rgba_mesh[:3, :, :]
352 |
353 | self.render_buffer = rgb_mesh.permute(1, 2, 0).cpu().numpy()
354 | self.refresh()
355 |
356 | self.rendering = False
357 | self.need_update = False
358 | dpg.render_dearpygui_frame()
359 |
360 |
361 | if __name__ == "__main__":
362 | cfg = tyro.cli(Config)
363 | gui = FlameViewer(cfg)
364 | gui.run()
365 |
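The editor above drives FLAME through a flat dictionary of batched tensors (see `reset_flame_param`). The following standalone sketch mirrors that layout for preparing parameters outside the GUI; the helper name `make_neutral_flame_param` and the default dimensions (300 shape, 100 expression coefficients) are illustrative, not part of the file.

import torch

def make_neutral_flame_param(n_shape: int = 300, n_expr: int = 100) -> dict:
    """All-zero (neutral) FLAME parameters for a single frame, mirroring reset_flame_param()."""
    return {
        'shape': torch.zeros(1, n_shape),    # identity coefficients
        'expr': torch.zeros(1, n_expr),      # expression coefficients
        'rotation': torch.zeros(1, 3),       # global rotation (axis-angle)
        'neck': torch.zeros(1, 3),           # neck joint (axis-angle)
        'jaw': torch.zeros(1, 3),            # jaw joint (axis-angle)
        'eyes': torch.zeros(1, 6),           # two eyeballs (axis-angle each)
        'translation': torch.zeros(1, 3),    # global translation
        'static_offset': torch.zeros(1, 3),  # per-vertex offsets, broadcastable
        'dynamic_offset': torch.zeros(1, 3),
    }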
--------------------------------------------------------------------------------
/vhap/flame_viewer.py:
--------------------------------------------------------------------------------
1 | import tyro
2 | from dataclasses import dataclass
3 | from typing import Optional
4 | from pathlib import Path
5 | import time
6 | import dearpygui.dearpygui as dpg
7 | import numpy as np
8 | import torch
9 | import PIL.Image
10 |
11 | from vhap.util.camera import OrbitCamera
12 | from vhap.model.flame import FlameHead
13 | from vhap.config.base import ModelConfig
14 | from vhap.util.render_nvdiffrast import NVDiffRenderer
15 |
16 |
17 | @dataclass
18 | class Config:
19 | model: ModelConfig
20 | """FLAME model configuration"""
21 | param_path: Optional[Path] = None
22 | """Path to the npz file for FLAME parameters"""
23 | tex_path: Optional[Path] = None
24 | """Path to the texture image"""
25 | W: int = 1024
26 | """GUI width"""
27 | H: int = 1024
28 | """GUI height"""
29 | radius: float = 1
30 | """default GUI camera radius from center"""
31 | fovy: float = 30
32 | """default GUI camera fovy"""
33 | background_color: tuple[float] = (1., 1., 1.)
34 | """default GUI background color"""
35 | use_opengl: bool = False
36 | """use OpenGL or CUDA rasterizer"""
37 | shade_smooth: bool = True
38 | """smooth shading or flat shading"""
39 |
40 |
41 | class FlameViewer:
42 | def __init__(self, cfg: Config):
43 | self.cfg = cfg # shared with the trainer's cfg to support in-place modification of rendering parameters.
44 |
45 | # flame model
46 | self.flame_model = FlameHead(cfg.model.n_shape, cfg.model.n_expr, add_teeth=True).cuda()
47 |
48 | # viewer settings
49 | self.W = cfg.W
50 | self.H = cfg.H
51 | self.cam = OrbitCamera(self.W, self.H, r=cfg.radius, fovy=cfg.fovy, convention="opengl")
52 | self.last_time_fresh = None
53 |         self.render_buffer = np.ones((self.H, self.W, 3), dtype=np.float32)  # (H, W, 3) buffer backing the raw texture
54 |         self.need_update = True  # trigger a re-render when the camera or parameters change
55 |         self.rendering = False  # guards the viewport resize callback while a frame is being rendered
55 |
56 | # buffers for mouse interaction
57 | self.cursor_x = None
58 | self.cursor_y = None
59 | self.drag_begin_x = None
60 | self.drag_begin_y = None
61 | self.drag_button = None
62 |
63 | # rendering settings
64 | self.mesh_renderer = NVDiffRenderer(use_opengl=cfg.use_opengl, lighting_space='camera', shade_smooth=cfg.shade_smooth)
65 | self.num_timesteps = 1
66 | self.timestep = 0
67 |
68 | self.define_gui()
69 |
70 | def __del__(self):
71 | dpg.destroy_context()
72 |
73 | def refresh(self):
74 | dpg.set_value("_texture", self.render_buffer)
75 |
76 | if self.last_time_fresh is not None:
77 | elapsed = time.time() - self.last_time_fresh
78 | fps = 1 / elapsed
79 | dpg.set_value("_log_fps", f'{fps:.1f}')
80 | self.last_time_fresh = time.time()
81 |
82 | def define_gui(self):
83 | dpg.create_context()
84 |
85 | # register texture =================================================================================================
86 | with dpg.texture_registry(show=False):
87 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
88 |
89 | # register window ==================================================================================================
90 | # the window to display the rendered image
91 | with dpg.window(label="viewer", tag="_render_window", width=self.W, height=self.H, no_title_bar=True, no_move=True, no_bring_to_front_on_focus=True, no_resize=True):
92 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image")
93 |
94 | # control window ==================================================================================================
95 | with dpg.window(label="Control", tag="_control_window", autosize=True):
96 |
97 | with dpg.group(horizontal=True):
98 | dpg.add_text("FPS: ")
99 | dpg.add_text("", tag="_log_fps")
100 |
101 | # rendering options
102 | with dpg.collapsing_header(label="Render", default_open=True):
103 |
104 | # timestep slider and buttons
105 |                 if self.num_timesteps is not None:
106 | def callback_set_current_frame(sender, app_data):
107 | if sender == "_slider_timestep":
108 | self.timestep = app_data
109 | elif sender in ["_button_timestep_plus", "_mvKey_Right"]:
110 | self.timestep = min(self.timestep + 1, self.num_timesteps - 1)
111 | elif sender in ["_button_timestep_minus", "_mvKey_Left"]:
112 | self.timestep = max(self.timestep - 1, 0)
113 | elif sender == "_mvKey_Home":
114 | self.timestep = 0
115 | elif sender == "_mvKey_End":
116 | self.timestep = self.num_timesteps - 1
117 |
118 | dpg.set_value("_slider_timestep", self.timestep)
119 |
120 | self.need_update = True
121 | with dpg.group(horizontal=True):
122 | dpg.add_button(label='-', tag="_button_timestep_minus", callback=callback_set_current_frame)
123 | dpg.add_button(label='+', tag="_button_timestep_plus", callback=callback_set_current_frame)
124 | dpg.add_slider_int(label="timestep", tag='_slider_timestep', width=162, min_value=0, max_value=self.num_timesteps - 1, format="%d", default_value=0, callback=callback_set_current_frame)
125 |
126 | # fov slider
127 | def callback_set_fovy(sender, app_data):
128 | self.cam.fovy = app_data
129 | self.need_update = True
130 | dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy, tag="_slider_fovy")
131 |
132 | def callback_reset_camera(sender, app_data):
133 | self.cam.reset()
134 | self.need_update = True
135 | dpg.set_value("_slider_fovy", self.cam.fovy)
136 |
137 | with dpg.group(horizontal=True):
138 | dpg.add_button(label="reset camera", tag="_button_reset_pose", callback=callback_reset_camera)
139 |
140 |
141 | ### register mouse handlers ========================================================================================
142 |
143 | def callback_mouse_move(sender, app_data):
144 | self.cursor_x, self.cursor_y = app_data
145 | if not dpg.is_item_focused("_render_window"):
146 | return
147 |
148 | if self.drag_begin_x is None or self.drag_begin_y is None:
149 | self.drag_begin_x = self.cursor_x
150 | self.drag_begin_y = self.cursor_y
151 | else:
152 | dx = self.cursor_x - self.drag_begin_x
153 | dy = self.cursor_y - self.drag_begin_y
154 |
155 | # button=dpg.mvMouseButton_Left
156 | if self.drag_button is dpg.mvMouseButton_Left:
157 | self.cam.orbit(dx, dy)
158 | self.need_update = True
159 | elif self.drag_button is dpg.mvMouseButton_Middle:
160 | self.cam.pan(dx, dy)
161 | self.need_update = True
162 |
163 | def callback_mouse_button_down(sender, app_data):
164 | if not dpg.is_item_focused("_render_window"):
165 | return
166 | self.drag_begin_x = self.cursor_x
167 | self.drag_begin_y = self.cursor_y
168 | self.drag_button = app_data[0]
169 |
170 | def callback_mouse_release(sender, app_data):
171 | self.drag_begin_x = None
172 | self.drag_begin_y = None
173 | self.drag_button = None
174 |
175 | self.dx_prev = None
176 | self.dy_prev = None
177 |
178 | def callback_mouse_drag(sender, app_data):
179 | if not dpg.is_item_focused("_render_window"):
180 | return
181 |
182 | button, dx, dy = app_data
183 | if self.dx_prev is None or self.dy_prev is None:
184 | ddx = dx
185 | ddy = dy
186 | else:
187 | ddx = dx - self.dx_prev
188 | ddy = dy - self.dy_prev
189 |
190 | self.dx_prev = dx
191 | self.dy_prev = dy
192 |
193 | if ddx != 0 and ddy != 0:
194 | if button is dpg.mvMouseButton_Left:
195 | self.cam.orbit(ddx, ddy)
196 | self.need_update = True
197 | elif button is dpg.mvMouseButton_Middle:
198 | self.cam.pan(ddx, ddy)
199 | self.need_update = True
200 |
201 | def callback_camera_wheel_scale(sender, app_data):
202 | if not dpg.is_item_focused("_render_window"):
203 | return
204 | delta = app_data
205 | self.cam.scale(delta)
206 | self.need_update = True
207 |
208 | with dpg.handler_registry():
209 |             # this registration order helps avoid false triggering
210 | dpg.add_mouse_release_handler(callback=callback_mouse_release)
211 | # dpg.add_mouse_drag_handler(callback=callback_mouse_drag) # not using the drag callback, since it does not return the starting point
212 | dpg.add_mouse_move_handler(callback=callback_mouse_move)
213 | dpg.add_mouse_down_handler(callback=callback_mouse_button_down)
214 | dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
215 |
216 | # key press handlers
217 | dpg.add_key_press_handler(dpg.mvKey_Left, callback=callback_set_current_frame, tag='_mvKey_Left')
218 | dpg.add_key_press_handler(dpg.mvKey_Right, callback=callback_set_current_frame, tag='_mvKey_Right')
219 | dpg.add_key_press_handler(dpg.mvKey_Home, callback=callback_set_current_frame, tag='_mvKey_Home')
220 | dpg.add_key_press_handler(dpg.mvKey_End, callback=callback_set_current_frame, tag='_mvKey_End')
221 |
222 | def callback_viewport_resize(sender, app_data):
223 | while self.rendering:
224 | time.sleep(0.01)
225 | self.need_update = False
226 | self.W = app_data[0]
227 | self.H = app_data[1]
228 | self.cam.image_width = self.W
229 | self.cam.image_height = self.H
230 | self.render_buffer = np.zeros((self.H, self.W, 3), dtype=np.float32)
231 |
232 | # delete and re-add the texture and image
233 | dpg.delete_item("_texture")
234 | dpg.delete_item("_image")
235 |
236 | with dpg.texture_registry(show=False):
237 | dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture")
238 | dpg.add_image("_texture", width=self.W, height=self.H, tag="_image", parent="_render_window")
239 | dpg.configure_item("_render_window", width=self.W, height=self.H)
240 | self.need_update = True
241 | dpg.set_viewport_resize_callback(callback_viewport_resize)
242 |
243 | ### global theme ==================================================================================================
244 | with dpg.theme() as theme_no_padding:
245 | with dpg.theme_component(dpg.mvAll):
246 | # set all padding to 0 to avoid scroll bar
247 | dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core)
248 | dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core)
249 | dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core)
250 | dpg.bind_item_theme("_render_window", theme_no_padding)
251 |
252 | ### finish setup ==================================================================================================
253 | dpg.create_viewport(title='FLAME Sequence Viewer', width=self.W, height=self.H, resizable=True)
254 | dpg.setup_dearpygui()
255 | dpg.show_viewport()
256 |
257 | def forward_flame(self, flame_param):
258 | N = flame_param['expr'].shape[0]
259 |
260 | self.verts, self.verts_cano = self.flame_model(
261 | flame_param['shape'][None, ...].expand(N, -1).cuda(),
262 | flame_param['expr'].cuda(),
263 | flame_param['rotation'].cuda(),
264 | flame_param['neck_pose'].cuda(),
265 | flame_param['jaw_pose'].cuda(),
266 | flame_param['eyes_pose'].cuda(),
267 | flame_param['translation'].cuda(),
268 | zero_centered_at_root_node=False,
269 | return_landmarks=False,
270 | return_verts_cano=True,
271 | static_offset=flame_param['static_offset'].cuda(),
272 | # dynamic_offset=flame_param['dynamic_offset'].cuda(),
273 | )
274 |
275 | self.num_timesteps = N
276 | dpg.configure_item("_slider_timestep", max_value=self.num_timesteps - 1)
277 |
278 | def prepare_camera(self):
279 | @dataclass
280 | class Cam:
281 | FoVx = float(np.radians(self.cam.fovx))
282 | FoVy = float(np.radians(self.cam.fovy))
283 | image_height = self.cam.image_height
284 | image_width = self.cam.image_width
285 | world_view_transform = torch.tensor(self.cam.world_view_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
286 | full_proj_transform = torch.tensor(self.cam.full_proj_transform).float().cuda().T # the transpose is required by gaussian splatting rasterizer
287 | camera_center = torch.tensor(self.cam.pose[:3, 3]).cuda()
288 | return Cam
289 |
290 | def run(self):
291 | assert self.cfg.param_path is not None, 'param_path must be provided.'
292 | assert self.cfg.param_path.exists(), f'{self.cfg.param_path} does not exist.'
293 | self.flame_param = dict(np.load(self.cfg.param_path))
294 | for k, v in self.flame_param.items():
295 | if v.dtype in [np.float64, np.float32]:
296 | self.flame_param[k] = torch.from_numpy(v).float()
297 | self.forward_flame(self.flame_param)
298 |
299 |         # UV coordinates (needed by the renderer); flip v to match the image convention
300 |         self.verts_uv = self.flame_model.verts_uvs.clone()
301 |         self.verts_uv[:, 1] = 1 - self.verts_uv[:, 1]
302 |         self.faces_uv = self.flame_model.textures_idx.int()
303 | 
304 |         # optional texture image
305 |         if self.cfg.tex_path is not None:
306 |             assert self.cfg.tex_path.exists(), f'{self.cfg.tex_path} does not exist.'
307 |             # load the texture image with PIL and turn it into a pytorch tensor
308 |             tex = np.array(PIL.Image.open(self.cfg.tex_path)) / 255.
309 |             self.tex = torch.tensor(tex, dtype=torch.float32).permute(2, 0, 1)[None].cuda()
310 | 
311 |         # optional lighting coefficients stored alongside the FLAME parameters
312 |         if 'lights' in self.flame_param:
313 |             self.lights = self.flame_param['lights'].cuda()[None]
312 |
313 | while dpg.is_dearpygui_running():
314 |
315 | if self.need_update:
316 | self.rendering = True
317 |
318 | with torch.no_grad():
319 | RT = torch.from_numpy(self.cam.world_view_transform).cuda()[None]
320 | K = torch.from_numpy(self.cam.intrinsics).cuda()[None]
321 | image_size = self.cam.image_height, self.cam.image_width
322 | verts = self.verts[[self.timestep]]
323 | faces = self.flame_model.faces
324 | tex = self.tex if hasattr(self, 'tex') else None
325 | lights = self.lights if hasattr(self, 'lights') else None
326 |
327 | out_dict = self.mesh_renderer.render_rgba_vis(
328 | verts, faces, RT, K, image_size, self.cfg.background_color,
329 | verts_uv=self.verts_uv, faces_uv=self.faces_uv, tex=tex,
330 | lights=lights,
331 | )
332 |
333 |                     rgba_mesh = out_dict['rgba'].squeeze(0).permute(2, 0, 1)  # (C, H, W)
334 | rgb_mesh = rgba_mesh[:3, :, :]
335 |
336 | self.render_buffer = rgb_mesh.permute(1, 2, 0).cpu().numpy()
337 | self.refresh()
338 |
339 | self.rendering = False
340 | self.need_update = False
341 | dpg.render_dearpygui_frame()
342 |
343 |
344 | if __name__ == "__main__":
345 | cfg = tyro.cli(Config)
346 | gui = FlameViewer(cfg)
347 | gui.run()
348 |
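For reference, `forward_flame` above expects the npz passed via `--param_path` to contain arrays keyed `shape`, `expr`, `rotation`, `neck_pose`, `jaw_pose`, `eyes_pose`, `translation`, and `static_offset` (plus optional `lights`). A minimal sketch that writes a dummy sequence the viewer can load; the frame count, parameter dimensions, and output filename are placeholders.

import numpy as np

N, n_shape, n_expr = 10, 300, 100  # placeholder sequence length and FLAME dimensions
np.savez(
    'dummy_flame_params.npz',
    shape=np.zeros(n_shape, dtype=np.float32),         # shared identity, expanded to N frames in forward_flame()
    expr=np.zeros((N, n_expr), dtype=np.float32),
    rotation=np.zeros((N, 3), dtype=np.float32),
    neck_pose=np.zeros((N, 3), dtype=np.float32),
    jaw_pose=np.zeros((N, 3), dtype=np.float32),
    eyes_pose=np.zeros((N, 6), dtype=np.float32),
    translation=np.zeros((N, 3), dtype=np.float32),
    static_offset=np.zeros((1, 3), dtype=np.float32),  # broadcastable per-vertex offset (all zeros)
)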
--------------------------------------------------------------------------------
/vhap/generate_flame_uvmask.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from typing import Literal
11 | import tyro
12 | import numpy as np
13 | from PIL import Image
14 | from pathlib import Path
15 | import torch
16 | import nvdiffrast.torch as dr
17 | from vhap.util.render_uvmap import render_uvmap_vtex
18 | from vhap.model.flame import FlameHead
19 |
20 |
21 | FLAME_UV_MASK_FOLDER = "asset/flame/uv_masks"
22 | FLAME_UV_MASK_NPZ = "asset/flame/uv_masks.npz"
23 |
24 |
25 | def main(
26 | use_opengl: bool = False,
27 | device: Literal['cuda', 'cpu'] = 'cuda',
28 | ):
29 | n_shape = 300
30 | n_expr = 100
31 | print("Initializing FLAME model")
32 |     flame_model = FlameHead(
33 |         n_shape,
34 |         n_expr,
35 |         add_teeth=True,
36 |     ).cuda()
39 |
40 | faces = flame_model.faces.int().cuda()
41 | verts_uv = flame_model.verts_uvs.cuda()
42 | # verts_uv[:, 1] = 1 - verts_uv[:, 1]
43 | faces_uv = flame_model.textures_idx.int().cuda()
44 | col_idx = faces_uv
45 |
46 | # Rasterizer context
47 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
48 |
49 | h, w = 2048, 2048
50 | resolution = (h, w)
51 |
52 | if not Path(FLAME_UV_MASK_FOLDER).exists():
53 | Path(FLAME_UV_MASK_FOLDER).mkdir(parents=True)
54 |
55 | # alpha_maps = {}
56 | masks = {}
57 | for region, vt_mask in flame_model.mask.vt:
58 | v_color = torch.zeros(verts_uv.shape[0], 1).to(device) # alpha channel
59 | v_color[vt_mask] = 1
60 |
61 | alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)[0]
62 | alpha = alpha.flip(0)
63 | # alpha_maps[region] = alpha.cpu().numpy()
64 | mask = (alpha > 0.5) # to avoid overlap between hair and face
65 | mask = mask.squeeze(-1).cpu().numpy()
66 | masks[region] = mask # (h, w)
67 |
68 | print(f"Saving uv mask for {region}...")
69 | # rgba = mask.expand(-1, -1, 4) # (h, w, 4)
70 | # rgb = torch.ones_like(mask).expand(-1, -1, 3) # (h, w, 3)
71 | # rgba = torch.cat([rgb, mask], dim=-1).cpu().numpy() # (h, w, 4)
72 | img = mask
73 | img = Image.fromarray((img * 255).astype(np.uint8))
74 | img.save(Path(FLAME_UV_MASK_FOLDER) / f"{region}.png")
75 |
76 | print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}")
77 | np.savez_compressed(FLAME_UV_MASK_NPZ, **masks)
78 |
79 |
80 | if __name__ == "__main__":
81 | tyro.cli(main)
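
A short sketch of consuming the output written by this script: each entry of the saved npz is a boolean (h, w) mask keyed by the FLAME region name.

import numpy as np

masks = np.load("asset/flame/uv_masks.npz")
for region in masks.files:
    mask = masks[region]  # boolean array of shape (2048, 2048)
    print(f"{region}: shape={mask.shape}, coverage={mask.mean():.3f}")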
--------------------------------------------------------------------------------
/vhap/model/lbs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from __future__ import absolute_import
18 | from __future__ import print_function
19 | from __future__ import division
20 |
21 | import torch
22 | import torch.nn.functional as F
23 |
24 |
25 | def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
26 | """Calculates the rotation matrices for a batch of rotation vectors
27 | Parameters
28 | ----------
29 | rot_vecs: torch.tensor Nx3
30 | array of N axis-angle vectors
31 | Returns
32 | -------
33 | R: torch.tensor Nx3x3
34 | The rotation matrices for the given axis-angle parameters
35 | """
36 |
37 | batch_size = rot_vecs.shape[0]
38 | device = rot_vecs.device
39 |
40 |     angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
41 | rot_dir = rot_vecs / angle
42 |
43 | cos = torch.unsqueeze(torch.cos(angle), dim=1)
44 | sin = torch.unsqueeze(torch.sin(angle), dim=1)
45 |
46 | # Bx1 arrays
47 | rx, ry, rz = torch.split(rot_dir, 1, dim=1)
48 | K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
49 |
50 | zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
51 | K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
52 | (batch_size, 3, 3)
53 | )
54 |
55 | ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
56 | rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
57 | return rot_mat
58 |
59 |
60 | def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
61 | """Calculates landmarks by barycentric interpolation
62 |
63 | Parameters
64 | ----------
65 | vertices: torch.tensor BxVx3, dtype = torch.float32
66 | The tensor of input vertices
67 | faces: torch.tensor Fx3, dtype = torch.long
68 | The faces of the mesh
69 | lmk_faces_idx: torch.tensor L, dtype = torch.long
70 | The tensor with the indices of the faces used to calculate the
71 | landmarks.
72 | lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
73 | The tensor of barycentric coordinates that are used to interpolate
74 | the landmarks
75 |
76 | Returns
77 | -------
78 | landmarks: torch.tensor BxLx3, dtype = torch.float32
79 | The coordinates of the landmarks for each mesh in the batch
80 | """
81 | # Extract the indices of the vertices for each face
82 | # BxLx3
83 | batch_size, num_verts = vertices.shape[:2]
84 | device = vertices.device
85 |
86 | lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(
87 | batch_size, -1, 3
88 | )
89 |
90 | lmk_faces += (
91 | torch.arange(batch_size, dtype=torch.long, device=device).view(-1, 1, 1)
92 | * num_verts
93 | )
94 |
95 | lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(batch_size, -1, 3, 3)
96 |
97 | landmarks = torch.einsum("blfi,blf->bli", [lmk_vertices, lmk_bary_coords])
98 | return landmarks
99 |
100 |
101 | def lbs(
102 | pose,
103 | v_shaped,
104 | posedirs,
105 | J_regressor,
106 | parents,
107 | lbs_weights,
108 | pose2rot=True,
109 | dtype=torch.float32,
110 | ):
111 | """Performs Linear Blend Skinning with the given shape and pose parameters
112 |
113 | Parameters
114 | ----------
115 |     pose : torch.tensor Bx(J + 1) * 3
116 |         The pose parameters in axis-angle format
117 |     v_shaped : torch.tensor BxVx3
118 |         The mesh vertices after applying the shape (and expression) blend shapes
119 |     posedirs : torch.tensor Px(V * 3)
120 |         The pose blend shape basis
121 |     J_regressor : torch.tensor JxV
122 |         The regressor array that is used to calculate the joints from
123 |         the position of the vertices
124 |     parents: torch.tensor J
125 |         The array that describes the kinematic tree for the model
126 |     lbs_weights: torch.tensor V x (J + 1)
127 |         The linear blend skinning weights that represent how much the
128 |         rotation matrix of each part affects each vertex
129 |     pose2rot: bool, optional
130 |         Flag on whether to convert the input pose tensor to rotation
131 |         matrices. The default value is True. If False, then the pose tensor
132 |         should already contain rotation matrices and have a size of
133 |         Bx(J + 1)x9
134 |     dtype: torch.dtype, optional
135 | 
136 |     Returns
137 |     -------
138 |     verts: torch.tensor BxVx3
139 |         The vertices of the mesh after applying the pose blend shapes
140 |         and linear blend skinning
141 |     J_transformed: torch.tensor BxJx3
142 |         The joints of the model after posing
143 |     A[:, 1]: torch.tensor Bx4x4
144 |         The relative rigid transformation of joint 1 (with respect to the root joint)
145 |     """
148 |
149 | batch_size = pose.shape[0]
150 | device = pose.device
151 |
152 | # Get the joints
153 | # NxJx3 array
154 | J = vertices2joints(J_regressor, v_shaped)
155 |
156 | # 3. Add pose blend shapes
157 | # N x J x 3 x 3
158 | ident = torch.eye(3, dtype=dtype, device=device)
159 | if pose2rot:
160 | rot_mats = batch_rodrigues(pose.view(-1, 3), dtype=dtype).view(
161 | [batch_size, -1, 3, 3]
162 | )
163 |
164 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
165 | # (N x P) x (P, V * 3) -> N x V x 3
166 | pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3)
167 | else:
168 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
169 | rot_mats = pose.view(batch_size, -1, 3, 3)
170 |
171 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(
172 | batch_size, -1, 3
173 | )
174 |
175 | v_posed = pose_offsets + v_shaped
176 |
177 | # 4. Get the global joint location
178 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
179 |
180 | # 5. Do skinning:
181 | # W is N x V x (J + 1)
182 | W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
183 | # (N x V x (J + 1)) x (N x (J + 1) x 16)
184 | num_joints = J_regressor.shape[0]
185 | T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
186 |
187 | homogen_coord = torch.ones(
188 | [batch_size, v_posed.shape[1], 1], dtype=dtype, device=device
189 | )
190 | v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
191 | v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
192 |
193 | verts = v_homo[:, :, :3, 0]
194 |
195 | return verts, J_transformed, A[:, 1]
196 |
197 |
198 | def vertices2joints(J_regressor, vertices):
199 | """Calculates the 3D joint locations from the vertices
200 |
201 | Parameters
202 | ----------
203 | J_regressor : torch.tensor JxV
204 | The regressor array that is used to calculate the joints from the
205 | position of the vertices
206 | vertices : torch.tensor BxVx3
207 | The tensor of mesh vertices
208 |
209 | Returns
210 | -------
211 | torch.tensor BxJx3
212 | The location of the joints
213 | """
214 |
215 | return torch.einsum("bik,ji->bjk", [vertices, J_regressor])
216 |
217 |
218 | def blend_shapes(betas, shape_disps):
219 | """Calculates the per vertex displacement due to the blend shapes
220 |
221 |
222 | Parameters
223 | ----------
224 | betas : torch.tensor Bx(num_betas)
225 | Blend shape coefficients
226 | shape_disps: torch.tensor Vx3x(num_betas)
227 | Blend shapes
228 |
229 | Returns
230 | -------
231 | torch.tensor BxVx3
232 | The per-vertex displacement due to shape deformation
233 | """
234 |
235 | # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
236 | # i.e. Multiply each shape displacement by its corresponding beta and
237 | # then sum them.
238 | blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps])
239 | return blend_shape
240 |
241 |
242 | def transform_mat(R, t):
243 | """Creates a batch of transformation matrices
244 | Args:
245 | - R: Bx3x3 array of a batch of rotation matrices
246 | - t: Bx3x1 array of a batch of translation vectors
247 | Returns:
248 | - T: Bx4x4 Transformation matrix
249 | """
250 | # No padding left or right, only add an extra row
251 | return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
252 |
253 |
254 | def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
255 | """
256 | Applies a batch of rigid transformations to the joints
257 |
258 | Parameters
259 | ----------
260 | rot_mats : torch.tensor BxNx3x3
261 | Tensor of rotation matrices
262 | joints : torch.tensor BxNx3
263 | Locations of joints
264 | parents : torch.tensor BxN
265 | The kinematic tree of each object
266 | dtype : torch.dtype, optional:
267 | The data type of the created tensors, the default is torch.float32
268 |
269 | Returns
270 | -------
271 | posed_joints : torch.tensor BxNx3
272 | The locations of the joints after applying the pose rotations
273 | rel_transforms : torch.tensor BxNx4x4
274 | The relative (with respect to the root joint) rigid transformations
275 | for all the joints
276 | """
277 |
278 | joints = torch.unsqueeze(joints, dim=-1)
279 |
280 | rel_joints = joints.clone().contiguous()
281 | rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]]
282 |
283 | transforms_mat = transform_mat(rot_mats.view(-1, 3, 3), rel_joints.view(-1, 3, 1))
284 | transforms_mat = transforms_mat.view(-1, joints.shape[1], 4, 4)
285 |
286 | transform_chain = [transforms_mat[:, 0]]
287 | for i in range(1, parents.shape[0]):
288 | # Subtract the joint location at the rest pose
289 | # No need for rotation, since it's identity when at rest
290 | curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i])
291 | transform_chain.append(curr_res)
292 |
293 | transforms = torch.stack(transform_chain, dim=1)
294 |
295 | # The last column of the transformations contains the posed joints
296 | posed_joints = transforms[:, :, :3, 3]
297 |
298 | joints_homogen = F.pad(joints, [0, 0, 0, 1])
299 |
300 | rel_transforms = transforms - F.pad(
301 | torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]
302 | )
303 |
304 | return posed_joints, rel_transforms
305 |
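As a sanity check (not part of the module), `batch_rodrigues` should agree with SciPy's axis-angle conversion up to the small numerical shift it applies to the rotation vectors.

import torch
from scipy.spatial.transform import Rotation
from vhap.model.lbs import batch_rodrigues

rot_vecs = torch.randn(4, 3)
R_ours = batch_rodrigues(rot_vecs)  # (4, 3, 3)
R_scipy = torch.tensor(Rotation.from_rotvec(rot_vecs.numpy()).as_matrix(), dtype=torch.float32)
assert torch.allclose(R_ours, R_scipy, atol=1e-5)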
--------------------------------------------------------------------------------
/vhap/preprocess_video.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from tqdm import tqdm
3 | from typing import Literal, Optional, List
4 | import tyro
5 | import ffmpeg
6 | from PIL import Image
7 | import torch
8 | from vhap.data.image_folder_dataset import ImageFolderDataset
9 | from torch.utils.data import DataLoader
10 | from BackgroundMattingV2.model import MattingRefine
11 | from BackgroundMattingV2.asset import get_weights_path
12 |
13 |
14 | def video2frames(video_path: Path, image_dir: Path, keep_video_name: bool=False, target_fps: int=30, n_downsample: int=1):
15 | print(f'Converting video {video_path} to frames with downsample scale {n_downsample}')
16 | if not image_dir.exists():
17 | image_dir.mkdir(parents=True)
18 |
19 | file_path_stem = video_path.stem + '_' if keep_video_name else ''
20 |
21 | probe = ffmpeg.probe(str(video_path))
22 |
23 | video_fps = int(probe['streams'][0]['r_frame_rate'].split('/')[0])
24 |     if video_fps == 0:
25 | video_fps = int(probe['streams'][0]['avg_frame_rate'].split('/')[0])
26 | if video_fps == 0:
27 | # nb_frames / duration
28 | video_fps = int(probe['streams'][0]['nb_frames']) / float(probe['streams'][0]['duration'])
29 | if video_fps == 0:
30 | raise ValueError('Cannot get valid video fps')
31 |
32 | num_frames = int(probe['streams'][0]['nb_frames'])
33 | video = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
34 | W = int(video['width'])
35 | H = int(video['height'])
36 | w = W // n_downsample
37 | h = H // n_downsample
38 | print(f'[Video] FPS: {video_fps} | number of frames: {num_frames} | resolution: {W}x{H}')
39 | print(f'[Target] FPS: {target_fps} | number of frames: {round(num_frames * target_fps / int(video_fps))} | resolution: {w}x{h}')
40 |
41 | (ffmpeg
42 | .input(str(video_path))
43 | .filter('fps', fps=f'{target_fps}')
44 | .filter('scale', width=w, height=h)
45 | .output(
46 | str(image_dir / f'{file_path_stem}%06d.jpg'),
47 | start_number=0,
48 | qscale=1, # lower values mean higher quality (1 is the best, 31 is the worst).
49 | )
50 | .overwrite_output()
51 | .run(quiet=True)
52 | )
53 |
54 | def robust_video_matting(image_dir: Path, N_warmup: Optional[int]=10):
55 | print(f'Running robust video matting on images in {image_dir}')
56 | # model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda()
57 | model = torch.hub.load("PeterL1n/RobustVideoMatting", "resnet50").cuda()
58 |
59 | dataset = ImageFolderDataset(image_folder=image_dir)
60 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
61 |
62 | # bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Green background.
63 | rec = [None] * 4 # Initial recurrent states.
64 | downsample_ratio = 0.5 #(for videos in 512x512)
65 | # downsample_ratio = 0.125 #(for videos in 3k)
66 | for item in tqdm(dataloader):
67 | rgb = item['rgb']
68 | rgb = rgb.permute(0, 3, 1, 2).float().cuda() / 255
69 | with torch.no_grad():
70 | while N_warmup:
71 | # use the first frame to warm up the recurrent states as if the video
72 | # has N_warmup identical frames at the beginning. This trick effectively
73 | # removes the artifacts for the first frame.
74 | fgr, pha, *rec = model(rgb, *rec, downsample_ratio) # Cycle the recurrent states.
75 | N_warmup -= 1
76 |
77 | fgr, pha, *rec = model(rgb, *rec, downsample_ratio) # Cycle the recurrent states.
78 | # fgr, pha, *rec = model(rgb, *rec) # Cycle the recurrent states.
79 | # fg = fgr * pha + bgr * (1 - pha)
80 |
81 | alpha = (pha[0, 0] * 255).cpu().numpy()
82 | alpha = Image.fromarray(alpha.astype('uint8'))
83 | alpha_path = item['image_path'][0].replace('images', 'alpha_maps')
84 | if not Path(alpha_path).parent.exists():
85 | Path(alpha_path).parent.mkdir(parents=True)
86 | alpha.save(alpha_path)
87 |
88 | def background_matting_v2(
89 | image_dir: Path,
90 | background_folder: Path=Path('../../BACKGROUND'),
91 | model_backbone: Literal['resnet101', 'resnet50', 'mobilenetv2']='resnet101',
92 | model_backbone_scale: float=0.25,
93 | model_refine_mode: Literal['full', 'sampling', 'thresholding']='thresholding',
94 | model_refine_sample_pixels: int=80_000,
95 | model_refine_threshold: float=0.01,
96 | model_refine_kernel_size: int=3,
97 | ):
98 | model = MattingRefine(
99 | model_backbone,
100 | model_backbone_scale,
101 | model_refine_mode,
102 | model_refine_sample_pixels,
103 | model_refine_threshold,
104 | model_refine_kernel_size
105 | )
106 |
107 | weights_path = get_weights_path(model_backbone)
108 |
109 | model = model.cuda().eval()
110 | model.load_state_dict(torch.load(weights_path, map_location='cuda', weights_only=True))
111 |
112 | dataset = ImageFolderDataset(
113 | image_folder=image_dir,
114 | background_folder=background_folder,
115 | background_fname2camId=lambda x: x.split('.')[0].split('_')[1], # image_00001.jpg -> 00001
116 | image_fname2camId=lambda x: x.split('.')[0].split('_')[1], # cam_00001.jpg -> 00001
117 | )
118 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
119 |
120 | for item in tqdm(dataloader):
121 | src = item['rgb']
122 | bgr = item['background']
123 | src = src.permute(0, 3, 1, 2).float().cuda() / 255
124 | bgr = bgr.permute(0, 3, 1, 2).float().cuda() / 255
125 |
126 | with torch.no_grad():
127 | pha, fgr, _, _, err, ref = model(src, bgr)
128 |
129 | alpha = (pha[0, 0] * 255).cpu().numpy()
130 | alpha = Image.fromarray(alpha.astype('uint8'))
131 | alpha_path = item['image_path'][0].replace('images', 'alpha_maps')
132 | if not Path(alpha_path).parent.exists():
133 | Path(alpha_path).parent.mkdir(parents=True)
134 | alpha.save(alpha_path)
135 |
136 | def downsample_frames(image_dir: Path, n_downsample: int):
137 | print(f'Downsample frames in {image_dir} by {n_downsample}')
138 | assert n_downsample in [2, 4, 8]
139 |
140 | image_paths = sorted(list(image_dir.glob('*.jpg')))
141 | for i, image_path in tqdm(enumerate(image_paths), total=len(image_paths)):
142 |         # downsample the resolution of images
143 | img = Image.open(image_path)
144 | W, H = img.size
145 | img = img.resize((W // n_downsample, H // n_downsample))
146 | img.save(image_path)
147 |
148 | def main(
149 | input: Path,
150 | target_fps: int=25,
151 | downsample_scales: List[int]=[],
152 | matting_method: Optional[Literal['robust_video_matting', 'background_matting_v2']]=None,
153 | background_folder: Path=Path('../../BACKGROUND'),
154 | ):
155 | if not input.exists():
156 | matched_paths = list(input.parent.glob(f"{input.name}"))
157 | if len(matched_paths) == 0:
158 | raise FileNotFoundError(f"Cannot find the directory: {input}")
159 | elif len(matched_paths) == 1:
160 | input = matched_paths[0]
161 | else:
162 | raise FileNotFoundError(f"Found multiple matched folders: {matched_paths}")
163 |
164 | # prepare path
165 | if input.suffix in ['.mov', '.mp4']:
166 | print(f'Processing video file: {input}')
167 | videos = [input]
168 | image_dir = input.parent / input.stem / 'images'
169 | elif input.is_dir():
170 | # if input is a directory, assume all contained videos are synchronized multiview of the same scene
171 | print(f'Processing directory: {input}')
172 | videos = list(input.glob('cam_*.mp4')) + list(input.glob('images/cam_*.mp4'))
173 | image_dir = input / 'images'
174 | else:
175 | raise ValueError(f"Input should be a video file or a directory containing video files: {input}")
176 | assert len(videos) > 0, f'No video files found in {input}'
177 |
178 | # extract frames
179 | for i, video_path in enumerate(videos):
180 | print(f'\n[{i}/{len(videos)}] Processing video file: {video_path}')
181 |
182 | for n_downsample in [1] + downsample_scales:
183 | image_dir_ = image_dir if n_downsample == 1 else Path(str(image_dir) + f'_{n_downsample}')
184 | video2frames(video_path, image_dir_, keep_video_name=len(videos) > 1, target_fps=target_fps, n_downsample=n_downsample)
185 |
186 | # foreground matting
187 | if matting_method == 'robust_video_matting':
188 | robust_video_matting(image_dir)
189 | elif matting_method == 'background_matting_v2':
190 | background_matting_v2(image_dir, background_folder=background_folder)
191 | elif matting_method is not None:
192 | raise ValueError(f'Unknown matting method: {matting_method}')
193 |
194 |
195 | if __name__ == '__main__':
196 | tyro.cli(main)
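
A hypothetical usage sketch with placeholder paths, mirroring what `main` does for a single monocular video: extract frames at the target fps, then run matting to produce alpha maps next to the images.

from pathlib import Path
from vhap.preprocess_video import video2frames, robust_video_matting

video_path = Path('data/monocular/obama.mp4')  # placeholder input video
image_dir = video_path.parent / video_path.stem / 'images'

video2frames(video_path, image_dir, target_fps=25, n_downsample=1)
robust_video_matting(image_dir)  # writes alpha maps to a sibling 'alpha_maps' folder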
--------------------------------------------------------------------------------
/vhap/track.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 |
12 | from vhap.config.base import BaseTrackingConfig
13 | from vhap.model.tracker import GlobalTracker
14 |
15 |
16 | if __name__ == "__main__":
17 | tyro.extras.set_accent_color("bright_yellow")
18 | cfg = tyro.cli(BaseTrackingConfig)
19 |
20 | tracker = GlobalTracker(cfg)
21 | tracker.optimize()
22 |
--------------------------------------------------------------------------------
/vhap/track_nersemble.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 |
12 | from vhap.config.nersemble import NersembleTrackingConfig
13 | from vhap.model.tracker import GlobalTracker
14 |
15 |
16 | if __name__ == "__main__":
17 | tyro.extras.set_accent_color("bright_yellow")
18 | cfg = tyro.cli(NersembleTrackingConfig)
19 |
20 | tracker = GlobalTracker(cfg)
21 | tracker.optimize()
22 |
--------------------------------------------------------------------------------
/vhap/track_nersemble_v2.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 |
12 | from vhap.config.nersemble_v2 import NersembleV2TrackingConfig
13 | from vhap.model.tracker import GlobalTracker
14 |
15 |
16 | if __name__ == "__main__":
17 | tyro.extras.set_accent_color("bright_yellow")
18 | cfg = tyro.cli(NersembleV2TrackingConfig)
19 |
20 | tracker = GlobalTracker(cfg)
21 | tracker.optimize()
22 |
--------------------------------------------------------------------------------
/vhap/util/camera.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from typing import Optional, Tuple, Literal
11 | import torch
12 | import torch.nn.functional as F
13 | import math
14 | import numpy as np
15 | from scipy.spatial.transform import Rotation
16 |
17 |
18 | def align_cameras_to_axes(
19 | R: torch.Tensor,
20 | T: torch.Tensor,
21 |     target_convention: Optional[Literal["opengl", "opencv"]] = None,
22 | ):
23 | """align the averaged axes of cameras with the world axes.
24 |
25 | Args:
26 | R: rotation matrix (N, 3, 3)
27 |         T: translation vector (N, 3)
28 |         target_convention: if given, additionally flip the averaged axes to match this convention
29 |     """
29 | # The column vectors of R are the basis vectors of each camera.
30 | # We construct new bases by taking the mean directions of axes, then use Gram-Schmidt
31 | # process to make them orthonormal
32 | bases_c2w = gram_schmidt_orthogonalization(R.mean(0))
33 | if target_convention == "opengl":
34 | bases_c2w[:, [1, 2]] *= -1 # flip y and z axes
35 | elif target_convention == "opencv":
36 | pass
37 | bases_w2c = bases_c2w.t()
38 |
39 | # convert the camera poses into the new coordinate system
40 | R = bases_w2c[None, ...] @ R
41 | T = bases_w2c[None, ...] @ T
42 | return R, T
43 |
44 |
45 | def convert_camera_convention(camera_convention_conversion: str, R: torch.Tensor, K: torch.Tensor, H: int, W: int):
46 | if camera_convention_conversion is not None:
47 | if camera_convention_conversion == "opencv->opengl":
48 | R[:, :3, [1, 2]] *= -1
49 | # flip y of the principal point
50 | K[..., 1, 2] = H - K[..., 1, 2]
51 | elif camera_convention_conversion == "opencv->pytorch3d":
52 | R[:, :3, [0, 1]] *= -1
53 | # flip x and y of the principal point
54 | K[..., 0, 2] = W - K[..., 0, 2]
55 | K[..., 1, 2] = H - K[..., 1, 2]
56 | elif camera_convention_conversion == "opengl->pytorch3d":
57 | R[:, :3, [0, 2]] *= -1
58 | # flip x of the principal point
59 | K[..., 0, 2] = W - K[..., 0, 2]
60 | else:
61 | raise ValueError(
62 | f"Unknown camera coordinate conversion: {camera_convention_conversion}."
63 | )
64 | return R, K
65 |
66 |
67 | def gram_schmidt_orthogonalization(M: torch.Tensor):
68 |     """Conduct the Gram-Schmidt process to turn the column vectors into orthonormal bases.
69 | 
70 |     Args:
71 |         M: a matrix (num_rows, num_cols)
72 |     Return:
73 |         M: a matrix with orthonormal column vectors (num_rows, num_cols)
74 |     """
75 | num_rows, num_cols = M.shape
76 | for c in range(1, num_cols):
77 | M[:, [c - 1, c]] = F.normalize(M[:, [c - 1, c]], p=2, dim=0)
78 | M[:, [c]] -= M[:, :c] @ (M[:, :c].T @ M[:, [c]])
79 |
80 | M[:, -1] = F.normalize(M[:, -1], p=2, dim=0)
81 | return M
82 |
83 |
84 | def projection_from_intrinsics(K: np.ndarray, image_size: Tuple[int, int], near: float=0.01, far: float=10, flip_y: bool=False, z_sign=-1):
85 | """
86 | Transform points from camera space (x: right, y: up, z: out) to clip space (x: right, y: down, z: in)
87 | Args:
88 | K: Intrinsic matrix, (N, 3, 3)
89 | K = [[
90 | [fx, 0, cx],
91 | [0, fy, cy],
92 | [0, 0, 1],
93 | ]
94 | ]
95 | image_size: (height, width)
96 | Output:
97 | proj = [[
98 | [2*fx/w, 0.0, (w - 2*cx)/w, 0.0 ],
99 | [0.0, 2*fy/h, (h - 2*cy)/h, 0.0 ],
100 | [0.0, 0.0, z_sign*(far+near) / (far-near), -2*far*near / (far-near)],
101 | [0.0, 0.0, z_sign, 0.0 ]
102 | ]
103 | ]
104 | """
105 |
106 | B = K.shape[0]
107 | h, w = image_size
108 |
109 | if K.shape[-2:] == (3, 3):
110 | fx = K[..., 0, 0]
111 | fy = K[..., 1, 1]
112 | cx = K[..., 0, 2]
113 | cy = K[..., 1, 2]
114 | elif K.shape[-1] == 4:
115 | # fx, fy, cx, cy = K[..., [0, 1, 2, 3]].split(1, dim=-1)
116 | fx = K[..., [0]]
117 | fy = K[..., [1]]
118 | cx = K[..., [2]]
119 | cy = K[..., [3]]
120 | else:
121 | raise ValueError(f"Expected K to be (N, 3, 3) or (N, 4) but got: {K.shape}")
122 |
123 | proj = np.zeros([B, 4, 4])
124 | proj[:, 0, 0] = fx * 2 / w
125 | proj[:, 1, 1] = fy * 2 / h
126 | proj[:, 0, 2] = (w - 2 * cx) / w
127 | proj[:, 1, 2] = (h - 2 * cy) / h
128 | proj[:, 2, 2] = z_sign * (far+near) / (far-near)
129 | proj[:, 2, 3] = -2*far*near / (far-near)
130 | proj[:, 3, 2] = z_sign
131 |
132 | if flip_y:
133 | proj[:, 1, 1] *= -1
134 | return proj
135 |
136 |
137 | class OrbitCamera:
138 | def __init__(self, W, H, r=2, fovy=60, znear=1e-8, zfar=10, convention: Literal["opengl", "opencv"]="opengl"):
139 | self.image_width = W
140 | self.image_height = H
141 | self.radius_default = r
142 | self.fovy_default = fovy
143 | self.znear = znear
144 | self.zfar = zfar
145 | self.convention = convention
146 |
147 | self.up = np.array([0, 1, 0], dtype=np.float32)
148 | self.reset()
149 |
150 | def reset(self):
151 | """ The internal state of the camera is based on the OpenGL convention, but
152 | properties are converted to the target convention when queried.
153 | """
154 | self.rot = Rotation.from_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) # OpenGL convention
155 | self.look_at = np.array([0, 0, 0], dtype=np.float32) # look at this point
156 | self.radius = self.radius_default # camera distance from center
157 | self.fovy = self.fovy_default
158 | if self.convention == "opencv":
159 | self.z_sign = 1
160 | self.y_sign = 1
161 | elif self.convention == "opengl":
162 | self.z_sign = -1
163 | self.y_sign = -1
164 | else:
165 | raise ValueError(f"Unknown convention: {self.convention}")
166 |
167 | @property
168 | def fovx(self):
169 | return self.fovy / self.image_height * self.image_width
170 |
171 | @property
172 | def intrinsics(self):
173 | focal = self.image_height / (2 * np.tan(np.radians(self.fovy) / 2))
174 | return np.array([focal, focal, self.image_width // 2, self.image_height // 2])
175 |
176 | @property
177 | def projection_matrix(self):
178 | return projection_from_intrinsics(self.intrinsics[None], (self.image_height, self.image_width), self.znear, self.zfar, z_sign=self.z_sign)[0]
179 |
180 | @property
181 | def world_view_transform(self):
182 | return np.linalg.inv(self.pose) # world2cam
183 |
184 | @property
185 | def full_proj_transform(self):
186 | return self.projection_matrix @ self.world_view_transform
187 |
188 | @property
189 | def pose(self):
190 | # first move camera to radius
191 | pose = np.eye(4, dtype=np.float32)
192 | pose[2, 3] += self.radius
193 |
194 | # rotate
195 | rot = np.eye(4, dtype=np.float32)
196 | rot[:3, :3] = self.rot.as_matrix()
197 | pose = rot @ pose
198 |
199 | # translate
200 | pose[:3, 3] -= self.look_at
201 |
202 | if self.convention == "opencv":
203 | pose[:, [1, 2]] *= -1
204 | elif self.convention == "opengl":
205 | pass
206 | else:
207 | raise ValueError(f"Unknown convention: {self.convention}")
208 | return pose
209 |
210 | def orbit(self, dx, dy):
211 | # rotate along camera up/side axis!
212 | side = self.rot.as_matrix()[:3, 0]
213 | rotvec_x = self.up * np.radians(-0.3 * dx)
214 | rotvec_y = side * np.radians(-0.3 * dy)
215 | self.rot = Rotation.from_rotvec(rotvec_x) * Rotation.from_rotvec(rotvec_y) * self.rot
216 |
217 | def scale(self, delta):
218 | self.radius *= 1.1 ** (-delta)
219 |
220 | def pan(self, dx, dy, dz=0):
221 | # pan in camera coordinate system (careful on the sensitivity!)
222 | d = np.array([dx, -dy, dz]) # the y axis is flipped
223 | self.look_at += 2 * self.rot.as_matrix()[:3, :3] @ d * self.radius / self.image_height * math.tan(np.radians(self.fovy) / 2)
224 |
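A minimal sketch of driving `OrbitCamera` the way the viewers above do: mouse deltas feed `orbit`/`pan`/`scale`, and the renderer reads the pose and intrinsics properties.

from vhap.util.camera import OrbitCamera

cam = OrbitCamera(1024, 1024, r=1.0, fovy=30, convention="opengl")
cam.orbit(dx=40, dy=-10)  # rotate around the look-at point
cam.pan(dx=5, dy=0)       # shift the look-at point in the image plane
cam.scale(delta=2)        # zoom in: the radius is divided by 1.1**2

RT = cam.world_view_transform  # (4, 4) world-to-camera matrix
K = cam.intrinsics             # np.array([fx, fy, cx, cy])
print(RT.shape, K)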
--------------------------------------------------------------------------------
/vhap/util/color_correction.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/tobias-kirschstein/nersemble-data/blob/f96aa8d9d482df53c40c51ecc07203646265e4f0/src/nersemble_data/util/color_correction.py
2 |
3 | import colour
4 | import numpy as np
5 | from colour.characterisation import matrix_augmented_Cheung2004
6 | from colour.utilities import as_float_array
7 |
8 |
9 | def color_correction_Cheung2004_precomputed(
10 | image: np.ndarray,
11 | CCM: np.ndarray,
12 | ) -> np.ndarray:
13 | terms = CCM.shape[-1]
14 | RGB = as_float_array(image)
15 | shape = RGB.shape
16 |
17 | RGB = np.reshape(RGB, (-1, 3))
18 |
19 | RGB_e = matrix_augmented_Cheung2004(RGB, terms)
20 |
21 | return np.reshape(np.transpose(np.dot(CCM, np.transpose(RGB_e))), shape)
22 |
23 |
24 | def correct_color(image: np.ndarray, ccm: np.ndarray) -> np.ndarray:
25 | is_uint8 = image.dtype == np.uint8
26 | if is_uint8:
27 | image = image / 255.
28 | image_linear = colour.cctf_decoding(image)
29 | image_corrected = color_correction_Cheung2004_precomputed(image_linear, ccm)
30 | image_corrected = colour.cctf_encoding(image_corrected)
31 | if is_uint8:
32 | image_corrected = np.clip(image_corrected * 255, 0, 255).astype(np.uint8)
33 |
34 | return image_corrected
35 |
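A minimal sketch (not from the repository's pipeline): with a 3x3 identity CCM, `matrix_augmented_Cheung2004` uses only the plain RGB terms, so `correct_color` merely round-trips the image through the decoding/encoding transfer functions.

import numpy as np
from vhap.util.color_correction import correct_color

image = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)  # dummy uint8 image
ccm = np.eye(3)                                           # identity colour correction matrix (terms=3)
corrected = correct_color(image, ccm)
print(corrected.shape, corrected.dtype)                   # (8, 8, 3) uint8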
--------------------------------------------------------------------------------
/vhap/util/landmark_detector_fa.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from vhap.util.log import get_logger
11 |
12 | from typing import Literal
13 | from tqdm import tqdm
14 |
15 | import face_alignment
16 | import numpy as np
17 | import matplotlib.path as mpltPath
18 |
19 | from fdlite import (
20 | FaceDetection,
21 | FaceLandmark,
22 | face_detection_to_roi,
23 | IrisLandmark,
24 | iris_roi_from_face_landmarks,
25 | )
26 |
27 | logger = get_logger(__name__)
28 |
29 |
30 | class LandmarkDetectorFA:
31 |
32 | IMAGE_FILE_NAME = "image_0000.png"
33 | LMK_FILE_NAME = "keypoints_static_0000.json"
34 |
35 | def __init__(
36 | self,
37 |         face_detector: Literal["sfd", "blazeface"]="sfd",
38 |     ):
39 |         """
40 |         Initializes the 68-point facial landmark detector (face_alignment)
41 |         :param face_detector: face detection backbone used by face_alignment ("sfd" or "blazeface")
42 |         """
44 |
45 | logger.info("Initialize FaceAlignment module...")
46 | # 68 facial landmark detector
47 | self.fa = face_alignment.FaceAlignment(
48 | face_alignment.LandmarksType.TWO_HALF_D,
49 | face_detector=face_detector,
50 | flip_input=True,
51 | device="cuda"
52 | )
53 |
54 | def detect_single_image(self, img):
55 | bbox = self.fa.face_detector.detect_from_image(img)
56 |
57 | if len(bbox) == 0:
58 |             lmks = np.zeros([68, 3]) - 1  # set to -1 when landmarks are unavailable
59 |
60 | else:
61 | if len(bbox) > 1:
62 | # if multiple boxes detected, use the one with highest confidence
63 | bbox = [bbox[np.argmax(np.array(bbox)[:, -1])]]
64 |
65 | lmks = self.fa.get_landmarks_from_image(img, detected_faces=bbox)[0]
66 | lmks = np.concatenate([lmks, np.ones_like(lmks[:, :1])], axis=1)
67 |
68 | if (lmks[:, :2] == -1).sum() > 0:
69 | lmks[:, 2:] = 0.0
70 | else:
71 | lmks[:, 2:] = 1.0
72 |
73 | h, w = img.shape[:2]
74 | lmks[:, 0] /= w
75 | lmks[:, 1] /= h
76 | bbox[0][[0, 2]] /= w
77 | bbox[0][[1, 3]] /= h
78 | return bbox, lmks
79 |
80 | def detect_dataset(self, dataloader):
81 | """
82 | Annotates each frame with 68 facial landmarks
83 | :return: dict mapping frame number to landmarks numpy array and the same thing for bboxes
84 | """
85 | landmarks = {}
86 | bboxes = {}
87 |
88 | logger.info("Begin annotating landmarks...")
89 | for item in tqdm(dataloader):
90 | timestep_id = item["timestep_id"][0]
91 | camera_id = item["camera_id"][0]
92 | scale_factor = item["scale_factor"][0]
93 |
94 | logger.info(
95 | f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}"
96 | )
97 | img = item["rgb"][0].numpy()
98 |
99 | bbox, lmks = self.detect_single_image(img)
100 |
101 | if len(bbox) == 0:
102 | logger.error(
103 | f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1."
104 | )
105 |
106 | if camera_id not in landmarks:
107 | landmarks[camera_id] = {}
108 | if camera_id not in bboxes:
109 | bboxes[camera_id] = {}
110 | landmarks[camera_id][timestep_id] = lmks
111 | bboxes[camera_id][timestep_id] = bbox[0] if len(bbox) > 0 else np.zeros(5) - 1
112 | return landmarks, bboxes
113 |
114 | def annotate_iris_landmarks(self, dataloader):
115 | """
116 | Annotates each frame with 2 iris landmarks
117 | :return: dict mapping frame number to landmarks numpy array
118 | """
119 |
120 | # iris detector
121 | detect_faces = FaceDetection()
122 | detect_face_landmarks = FaceLandmark()
123 | detect_iris_landmarks = IrisLandmark()
124 |
125 | landmarks = {}
126 |
127 | for item in tqdm(dataloader):
128 | timestep_id = item["timestep_id"][0]
129 | camera_id = item["camera_id"][0]
130 | scale_factor = item["scale_factor"][0]
131 | if timestep_id not in landmarks:
132 | landmarks[timestep_id] = {}
133 | logger.info(
134 | f"Annotate iris landmarks for timestep: {timestep_id}, camera: {camera_id}"
135 | )
136 |
137 | img = item["rgb"][0].numpy()
138 |
139 | height, width = img.shape[:2]
140 | img_size = (width, height)
141 |
142 | face_detections = detect_faces(img)
143 | if len(face_detections) != 1:
144 | logger.error("Empty iris landmarks (type 1)")
145 | landmarks[timestep_id][camera_id] = None
146 | else:
147 | for face_detection in face_detections:
148 | try:
149 | face_roi = face_detection_to_roi(face_detection, img_size)
150 | except ValueError:
151 | logger.error("Empty iris landmarks (type 2)")
152 | landmarks[timestep_id][camera_id] = None
153 | break
154 |
155 | face_landmarks = detect_face_landmarks(img, face_roi)
156 | if len(face_landmarks) == 0:
157 | logger.error("Empty iris landmarks (type 3)")
158 | landmarks[timestep_id][camera_id] = None
159 | break
160 |
161 | iris_rois = iris_roi_from_face_landmarks(face_landmarks, img_size)
162 |
163 | if len(iris_rois) != 2:
164 | logger.error("Empty iris landmarks (type 4)")
165 | landmarks[timestep_id][camera_id] = None
166 | break
167 |
168 | lmks = []
169 | for iris_roi in iris_rois[::-1]:
170 | try:
171 | iris_landmarks = detect_iris_landmarks(img, iris_roi).iris[
172 | 0:1
173 | ]
174 | except np.linalg.LinAlgError:
175 | logger.error("Failed to get iris landmarks")
176 | landmarks[timestep_id][camera_id] = None
177 | break
178 |
179 | for landmark in iris_landmarks:
180 | lmks.append([landmark.x * width, landmark.y * height, 1.0])
181 |
182 | lmks = np.array(lmks, dtype=np.float32)
183 |
184 | h, w = img.shape[:2]
185 | lmks[:, 0] /= w
186 | lmks[:, 1] /= h
187 |
188 | landmarks[timestep_id][camera_id] = lmks
189 |
190 | return landmarks
191 |
192 | def iris_consistency(self, lm_iris, lm_eye):
193 | """
194 | Checks if landmarks for eye and iris are consistent
195 | :param lm_iris:
196 | :param lm_eye:
197 | :return:
198 | """
199 | lm_iris = lm_iris[:, :2]
200 | lm_eye = lm_eye[:, :2]
201 |
202 | polygon_eye = mpltPath.Path(lm_eye)
203 | valid = polygon_eye.contains_points(lm_iris)
204 |
205 | return valid[0]
206 |
207 | def annotate_landmarks(self, dataloader, add_iris=False):
208 | """
209 | Annotates each frame with landmarks for face and iris. Assumes frames have been extracted
210 | :param add_iris:
211 | :return:
212 | """
213 | lmks_face, bboxes_faces = self.detect_dataset(dataloader)
214 |
215 | if add_iris:
216 | lmks_iris = self.annotate_iris_landmarks(dataloader)
217 |
218 |             # check consistency of iris landmarks and facial keypoints
219 | for camera_id, lmk_face_camera in lmks_face.items():
220 | for timestep_id in lmk_face_camera.keys():
221 |
222 | discard_iris_lmks = False
223 | bboxes_face_i = bboxes_faces[camera_id][timestep_id]
224 | if bboxes_face_i is not None:
225 | lmks_face_i = lmks_face[camera_id][timestep_id]
226 |                         lmks_iris_i = lmks_iris[timestep_id][camera_id]
227 | if lmks_iris_i is not None:
228 |
229 | # validate iris landmarks
230 | left_face = lmks_face_i[36:42]
231 | right_face = lmks_face_i[42:48]
232 |
233 | right_iris = lmks_iris_i[:1]
234 | left_iris = lmks_iris_i[1:]
235 |
236 | if not (
237 | self.iris_consistency(left_iris, left_face)
238 | and self.iris_consistency(right_iris, right_face)
239 | ):
240 | logger.error(
241 | f"Inconsistent iris landmarks for timestep: {timestep_id}, camera: {camera_id}"
242 | )
243 | discard_iris_lmks = True
244 | else:
245 | logger.error(
246 | f"No iris landmarks detected for timestep: {timestep_id}, camera: {camera_id}"
247 | )
248 | discard_iris_lmks = True
249 |
250 | else:
251 | logger.error(
252 | f"Discarding iris landmarks because no face landmark is available for timestep: {timestep_id}, camera: {camera_id}"
253 | )
254 | discard_iris_lmks = True
255 |
256 | if discard_iris_lmks:
257 | lmks_iris[timestep_id][camera_id] = (
258 | np.zeros([2, 3]) - 1
259 | ) # set to -1 for inconsistent iris landmarks
260 |
261 | # construct final json
262 | for camera_id, lmk_face_camera in lmks_face.items():
263 | bounding_box = []
264 | face_landmark_2d = []
265 | iris_landmark_2d = []
266 | for timestep_id in lmk_face_camera.keys():
267 | bounding_box.append(bboxes_faces[camera_id][timestep_id][None])
268 | face_landmark_2d.append(lmks_face[camera_id][timestep_id][None])
269 |
270 | if add_iris:
271 |                     iris_landmark_2d.append(lmks_iris[timestep_id][camera_id][None])
272 |
273 | lmk_dict = {
274 | "bounding_box": bounding_box,
275 | "face_landmark_2d": face_landmark_2d,
276 | }
277 | if len(iris_landmark_2d) > 0:
278 | lmk_dict["iris_landmark_2d"] = iris_landmark_2d
279 |
280 | for k, v in lmk_dict.items():
281 | if len(v) > 0:
282 | lmk_dict[k] = np.concatenate(v, axis=0)
283 | out_path = dataloader.dataset.get_property_path(
284 | "landmark2d/face-alignment", camera_id=camera_id
285 | )
286 | logger.info(f"Saving landmarks to: {out_path}")
287 | if not out_path.parent.exists():
288 | out_path.parent.mkdir(parents=True)
289 | np.savez(out_path, **lmk_dict)
290 |
291 |
292 | if __name__ == "__main__":
293 | import tyro
294 | from tqdm import tqdm
295 | from torch.utils.data import DataLoader
296 | from vhap.config.base import DataConfig, import_module
297 |
298 | cfg = tyro.cli(DataConfig)
299 | dataset = import_module(cfg._target)(
300 | cfg=cfg,
301 | img_to_tensor=False,
302 | batchify_all_views=True,
303 | )
304 |     dataset.items = dataset.items[:2]  # limit to the first two items for a quick test
305 |
306 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
307 |
308 | detector = LandmarkDetectorFA()
309 | detector.annotate_landmarks(dataloader)
310 |
--------------------------------------------------------------------------------
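
For reference, a minimal sketch (not part of the repository) of how the archive written by LandmarkDetectorFA.annotate_landmarks above can be loaded back. The file path is a hypothetical example; the real location is whatever dataset.get_property_path("landmark2d/face-alignment", camera_id=...) resolves to, and the keys follow the lmk_dict construction in the code above.

import numpy as np

# hypothetical path; in practice it is returned by
# dataset.get_property_path("landmark2d/face-alignment", camera_id=...)
lmk_file = np.load("landmark2d/face-alignment/cam0.npz")

bboxes = lmk_file["bounding_box"]            # one bbox per timestep, -1 entries when detection failed
face_lmks = lmk_file["face_landmark_2d"]     # one 68-point landmark array per timestep
if "iris_landmark_2d" in lmk_file:           # only present when add_iris=True was used
    iris_lmks = lmk_file["iris_landmark_2d"] # two iris points per timestep, -1 when discarded
    print("iris:", iris_lmks.shape)

print("bboxes:", bboxes.shape, "face landmarks:", face_lmks.shape)
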
/vhap/util/landmark_detector_star.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | from tqdm import tqdm
11 | import copy
12 | import argparse
13 | import torch
14 | import math
15 | import cv2
16 | import numpy as np
17 | import dlib
18 |
19 | from star.lib import utility
20 | from star.asset import predictor_path, model_path
21 |
22 | from vhap.util.log import get_logger
23 | logger = get_logger(__name__)
24 |
25 |
26 | class GetCropMatrix():
27 | """
28 | from_shape -> transform_matrix
29 | """
30 |
31 | def __init__(self, image_size, target_face_scale, align_corners=False):
32 | self.image_size = image_size
33 | self.target_face_scale = target_face_scale
34 | self.align_corners = align_corners
35 |
36 | def _compose_rotate_and_scale(self, angle, scale, shift_xy, from_center, to_center):
37 | cosv = math.cos(angle)
38 | sinv = math.sin(angle)
39 |
40 | fx, fy = from_center
41 | tx, ty = to_center
42 |
43 | acos = scale * cosv
44 | asin = scale * sinv
45 |
46 | a0 = acos
47 | a1 = -asin
48 | a2 = tx - acos * fx + asin * fy + shift_xy[0]
49 |
50 | b0 = asin
51 | b1 = acos
52 | b2 = ty - asin * fx - acos * fy + shift_xy[1]
53 |
54 | rot_scale_m = np.array([
55 | [a0, a1, a2],
56 | [b0, b1, b2],
57 | [0.0, 0.0, 1.0]
58 | ], np.float32)
59 | return rot_scale_m
60 |
61 | def process(self, scale, center_w, center_h):
62 | if self.align_corners:
63 | to_w, to_h = self.image_size - 1, self.image_size - 1
64 | else:
65 | to_w, to_h = self.image_size, self.image_size
66 |
67 | rot_mu = 0
68 | scale_mu = self.image_size / (scale * self.target_face_scale * 200.0)
69 | shift_xy_mu = (0, 0)
70 | matrix = self._compose_rotate_and_scale(
71 | rot_mu, scale_mu, shift_xy_mu,
72 | from_center=[center_w, center_h],
73 | to_center=[to_w / 2.0, to_h / 2.0])
74 | return matrix
75 |
76 |
77 | class TransformPerspective():
78 | """
79 | image, matrix3x3 -> transformed_image
80 | """
81 |
82 | def __init__(self, image_size):
83 | self.image_size = image_size
84 |
85 | def process(self, image, matrix):
86 | return cv2.warpPerspective(
87 | image, matrix, dsize=(self.image_size, self.image_size),
88 | flags=cv2.INTER_LINEAR, borderValue=0)
89 |
90 |
91 | class TransformPoints2D():
92 | """
93 | points (nx2), matrix (3x3) -> points (nx2)
94 | """
95 |
96 | def process(self, srcPoints, matrix):
97 | # nx3
98 | desPoints = np.concatenate([srcPoints, np.ones_like(srcPoints[:, [0]])], axis=1)
99 | desPoints = desPoints @ np.transpose(matrix) # nx3
100 | desPoints = desPoints[:, :2] / desPoints[:, [2, 2]]
101 | return desPoints.astype(srcPoints.dtype)
102 |
103 |
104 | class Alignment:
105 | def __init__(self, args, model_path, dl_framework, device_ids):
106 | self.input_size = 256
107 | self.target_face_scale = 1.0
108 | self.dl_framework = dl_framework
109 |
110 | # model
111 | if self.dl_framework == "pytorch":
112 | # conf
113 | self.config = utility.get_config(args)
114 | self.config.device_id = device_ids[0]
115 | # set environment
116 | utility.set_environment(self.config)
117 | self.config.init_instance()
118 | if self.config.logger is not None:
119 |                 self.config.logger.info("Loaded config file %s: %s" % (args.config_name, self.config.id))
120 | self.config.logger.info("\n" + "\n".join(["%s: %s" % item for item in self.config.__dict__.items()]))
121 |
122 | net = utility.get_net(self.config)
123 | if device_ids == [-1]:
124 | checkpoint = torch.load(model_path, map_location="cpu")
125 | else:
126 | checkpoint = torch.load(model_path)
127 | net.load_state_dict(checkpoint["net"])
128 | net = net.to(self.config.device_id)
129 | net.eval()
130 | self.alignment = net
131 | else:
132 |             raise NotImplementedError(f"Unsupported dl_framework: {self.dl_framework}")
133 |
134 | self.getCropMatrix = GetCropMatrix(image_size=self.input_size, target_face_scale=self.target_face_scale,
135 | align_corners=True)
136 | self.transformPerspective = TransformPerspective(image_size=self.input_size)
137 | self.transformPoints2D = TransformPoints2D()
138 |
139 | def norm_points(self, points, align_corners=False):
140 | if align_corners:
141 | # [0, SIZE-1] -> [-1, +1]
142 | return points / torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2) * 2 - 1
143 | else:
144 | # [-0.5, SIZE-0.5] -> [-1, +1]
145 | return (points * 2 + 1) / torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1
146 |
147 | def denorm_points(self, points, align_corners=False):
148 | if align_corners:
149 | # [-1, +1] -> [0, SIZE-1]
150 | return (points + 1) / 2 * torch.tensor([self.input_size - 1, self.input_size - 1]).to(points).view(1, 1, 2)
151 | else:
152 | # [-1, +1] -> [-0.5, SIZE-0.5]
153 | return ((points + 1) * torch.tensor([self.input_size, self.input_size]).to(points).view(1, 1, 2) - 1) / 2
154 |
155 | def preprocess(self, image, scale, center_w, center_h):
156 | matrix = self.getCropMatrix.process(scale, center_w, center_h)
157 | input_tensor = self.transformPerspective.process(image, matrix)
158 | input_tensor = input_tensor[np.newaxis, :]
159 |
160 | input_tensor = torch.from_numpy(input_tensor)
161 | input_tensor = input_tensor.float().permute(0, 3, 1, 2)
162 | input_tensor = input_tensor / 255.0 * 2.0 - 1.0
163 | input_tensor = input_tensor.to(self.config.device_id)
164 | return input_tensor, matrix
165 |
166 | def postprocess(self, srcPoints, coeff):
167 | # dstPoints = self.transformPoints2D.process(srcPoints, coeff)
168 | # matrix^(-1) * src = dst
169 | # src = matrix * dst
170 | dstPoints = np.zeros(srcPoints.shape, dtype=np.float32)
171 | for i in range(srcPoints.shape[0]):
172 | dstPoints[i][0] = coeff[0][0] * srcPoints[i][0] + coeff[0][1] * srcPoints[i][1] + coeff[0][2]
173 | dstPoints[i][1] = coeff[1][0] * srcPoints[i][0] + coeff[1][1] * srcPoints[i][1] + coeff[1][2]
174 | return dstPoints
175 |
176 | def analyze(self, image, scale, center_w, center_h):
177 | input_tensor, matrix = self.preprocess(image, scale, center_w, center_h)
178 |
179 | if self.dl_framework == "pytorch":
180 | with torch.no_grad():
181 | output = self.alignment(input_tensor)
182 | landmarks = output[-1][0]
183 | else:
184 |             raise NotImplementedError(f"Unsupported dl_framework: {self.dl_framework}")
185 |
186 | landmarks = self.denorm_points(landmarks)
187 | landmarks = landmarks.data.cpu().numpy()[0]
188 | landmarks = self.postprocess(landmarks, np.linalg.inv(matrix))
189 |
190 | return landmarks
191 |
192 |
193 | def draw_pts(img, pts, mode="pts", shift=4, color=(0, 255, 0), radius=1, thickness=1, save_path=None, dif=0,
194 | scale=0.3, concat=False, ):
195 | img_draw = copy.deepcopy(img)
196 | for cnt, p in enumerate(pts):
197 | if mode == "index":
198 | cv2.putText(img_draw, str(cnt), (int(float(p[0] + dif)), int(float(p[1] + dif))), cv2.FONT_HERSHEY_SIMPLEX,
199 | scale, color, thickness)
200 | elif mode == 'pts':
201 | if len(img_draw.shape) > 2:
202 |                 # convert back and forth here to work around an OpenCV bug
203 | img_draw = cv2.cvtColor(img_draw, cv2.COLOR_BGR2RGB)
204 | img_draw = cv2.cvtColor(img_draw, cv2.COLOR_RGB2BGR)
205 | cv2.circle(img_draw, (int(p[0] * (1 << shift)), int(p[1] * (1 << shift))), radius << shift, color, -1,
206 | cv2.LINE_AA, shift=shift)
207 | else:
208 | raise NotImplementedError
209 | if concat:
210 | img_draw = np.concatenate((img, img_draw), axis=1)
211 | if save_path is not None:
212 | cv2.imwrite(save_path, img_draw)
213 | return img_draw
214 |
215 |
216 | class LandmarkDetectorSTAR:
217 | def __init__(
218 | self,
219 | ):
220 | self.detector = dlib.get_frontal_face_detector()
221 | self.shape_predictor = dlib.shape_predictor(predictor_path)
222 |
223 | # facial landmark detector
224 | args = argparse.Namespace()
225 | args.config_name = 'alignment'
226 | # could be downloaded here: https://drive.google.com/file/d/1aOx0wYEZUfBndYy_8IYszLPG_D2fhxrT/view
227 | # model_path = '/path/to/WFLW_STARLoss_NME_4_02_FR_2_32_AUC_0_605.pkl'
228 | device_ids = '0'
229 | device_ids = list(map(int, device_ids.split(",")))
230 | self.alignment = Alignment(args, model_path, dl_framework="pytorch", device_ids=device_ids)
231 |
232 | def detect_single_image(self, img):
233 | bbox = self.detector(img, 1)
234 |
235 | if len(bbox) == 0:
236 |             bbox = np.zeros(5, dtype=np.float32) - 1  # set to -1 when no face is detected
237 |             return bbox, np.zeros([68, 3], dtype=np.float32) - 1  # return early: the code below expects a dlib rectangle and unnormalized landmarks
238 | else:
239 | face = self.shape_predictor(img, bbox[0])
240 | shape = []
241 | for i in range(68):
242 | x = face.part(i).x
243 | y = face.part(i).y
244 | shape.append((x, y))
245 | shape = np.array(shape)
246 | x1, x2 = shape[:, 0].min(), shape[:, 0].max()
247 | y1, y2 = shape[:, 1].min(), shape[:, 1].max()
248 | scale = min(x2 - x1, y2 - y1) / 200 * 1.05
249 | center_w = (x2 + x1) / 2
250 | center_h = (y2 + y1) / 2
251 |
252 | scale, center_w, center_h = float(scale), float(center_w), float(center_h)
253 | lmks = self.alignment.analyze(img, scale, center_w, center_h)
254 |
255 | h, w = img.shape[:2]
256 |
257 | lmks = np.concatenate([lmks, np.ones([lmks.shape[0], 1])], axis=1).astype(np.float32) # (x, y, 1)
258 | lmks[:, 0] /= w
259 | lmks[:, 1] /= h
260 |
261 | bbox = np.array([bbox[0].left(), bbox[0].top(), bbox[0].right(), bbox[0].bottom(), 1.]).astype(np.float32) # (x1, y1, x2, y2, score)
262 | bbox[[0, 2]] /= w
263 | bbox[[1, 3]] /= h
264 |
265 | return bbox, lmks
266 |
267 | def detect_dataset(self, dataloader):
268 | """
269 | Annotates each frame with 68 facial landmarks
270 |         :return: two nested dicts mapping camera id and timestep id to the landmarks array and the bounding box, respectively
271 | """
272 | logger.info("Initialize Landmark Detector (STAR)...")
273 | # 68 facial landmark detector
274 |
275 | landmarks = {}
276 | bboxes = {}
277 |
278 | logger.info("Begin annotating landmarks...")
279 | for item in tqdm(dataloader):
280 | timestep_id = item["timestep_id"][0]
281 | camera_id = item["camera_id"][0]
282 |
283 | logger.info(
284 | f"Annotate facial landmarks for timestep: {timestep_id}, camera: {camera_id}"
285 | )
286 | img = item["rgb"][0].numpy()
287 |
288 | bbox, lmks = self.detect_single_image(img)
289 |             if (bbox < 0).all():  # placeholder returned when no face was detected
290 | logger.error(
291 | f"No bbox found for frame: {timestep_id}, camera: {camera_id}. Setting landmarks to all -1."
292 | )
293 |
294 | if camera_id not in landmarks:
295 | landmarks[camera_id] = {}
296 | if camera_id not in bboxes:
297 | bboxes[camera_id] = {}
298 | landmarks[camera_id][timestep_id] = lmks
299 | bboxes[camera_id][timestep_id] = bbox
300 | return landmarks, bboxes
301 |
302 | def annotate_landmarks(self, dataloader):
303 | """
304 |         Annotates each frame with facial landmarks and saves them per camera as an npz file. Assumes frames have been extracted
305 | :return:
306 | """
307 | lmks_face, bboxes_faces = self.detect_dataset(dataloader)
308 |
309 | # construct final json
310 | for camera_id, lmk_face_camera in lmks_face.items():
311 | bounding_box = []
312 | face_landmark_2d = []
313 | for timestep_id in lmk_face_camera.keys():
314 | bounding_box.append(bboxes_faces[camera_id][timestep_id][None])
315 | face_landmark_2d.append(lmks_face[camera_id][timestep_id][None])
316 |
317 | lmk_dict = {
318 | "bounding_box": bounding_box,
319 | "face_landmark_2d": face_landmark_2d,
320 | }
321 |
322 | for k, v in lmk_dict.items():
323 | if len(v) > 0:
324 | lmk_dict[k] = np.concatenate(v, axis=0)
325 | out_path = dataloader.dataset.get_property_path(
326 | "landmark2d/STAR", camera_id=camera_id
327 | )
328 | logger.info(f"Saving landmarks to: {out_path}")
329 | if not out_path.parent.exists():
330 | out_path.parent.mkdir(parents=True)
331 | np.savez(out_path, **lmk_dict)
332 |
333 |
334 | if __name__ == "__main__":
335 | import tyro
336 | from tqdm import tqdm
337 | from torch.utils.data import DataLoader
338 | from vhap.config.base import DataConfig, import_module
339 |
340 | cfg = tyro.cli(DataConfig)
341 | dataset = import_module(cfg._target)(
342 | cfg=cfg,
343 | img_to_tensor=False,
344 | batchify_all_views=True,
345 | )
346 |     dataset.items = dataset.items[:2]  # limit to the first two items for a quick test
347 |
348 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
349 |
350 | detector = LandmarkDetectorSTAR()
351 | detector.annotate_landmarks(dataloader)
352 |
--------------------------------------------------------------------------------
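
A minimal sketch (not part of the repository) showing how the normalized outputs of LandmarkDetectorSTAR.detect_single_image can be mapped back to pixel coordinates; detect_single_image divides x by the image width and y by the image height, so the inverse is a simple multiplication. The image size and the dummy values below are hypothetical.

import numpy as np

def denormalize(bbox, lmks, img_h, img_w):
    """Undo the width/height normalization applied in detect_single_image."""
    bbox_px = bbox.copy()
    bbox_px[[0, 2]] *= img_w   # x1, x2
    bbox_px[[1, 3]] *= img_h   # y1, y2
    lmks_px = lmks.copy()
    lmks_px[:, 0] *= img_w     # x
    lmks_px[:, 1] *= img_h     # y
    return bbox_px, lmks_px

# dummy normalized detections for a hypothetical 1080p frame
bbox = np.array([0.3, 0.2, 0.7, 0.8, 1.0], dtype=np.float32)  # (x1, y1, x2, y2, score)
lmks = np.random.rand(68, 3).astype(np.float32)               # dummy (x, y, score) values in [0, 1]
bbox_px, lmks_px = denormalize(bbox, lmks, img_h=1080, img_w=1920)
print(bbox_px[:4], lmks_px[0])
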
/vhap/util/log.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import logging
11 | import sys
12 | from datetime import datetime
13 | import atexit
14 | from pathlib import Path
15 |
16 |
17 | def _colored(msg, color):
18 | colors = {'red': '\033[91m', 'green': '\033[92m', 'yellow': '\033[93m', 'normal': '\033[0m'}
19 | return colors[color] + msg + colors["normal"]
20 |
21 |
22 | class ColorFormatter(logging.Formatter):
23 | """
24 | Class to make command line log entries more appealing
25 | Inspired by https://github.com/facebookresearch/detectron2
26 | """
27 |
28 | def formatMessage(self, record):
29 | """
30 | Print warnings yellow and errors red
31 | :param record:
32 | :return:
33 | """
34 | log = super().formatMessage(record)
35 | if record.levelno == logging.WARNING:
36 | prefix = _colored("WARNING", "yellow")
37 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
38 | prefix = _colored("ERROR", "red")
39 | else:
40 | return log
41 | return prefix + " " + log
42 |
43 |
44 | def get_logger(name, level=logging.DEBUG, root=False, log_dir=None):
45 | """
46 | Replaces the standard library logging.getLogger call in order to make some configuration
47 | for all loggers.
48 | :param name: pass the __name__ variable
49 | :param level: the desired log level
50 |     :param root: if True, attach console/file handlers; set this only once per program
51 | :param log_dir: if root is set to True, this defines the directory where a log file is going
52 | to be created that contains all logging output
53 | :return: the logger object
54 | """
55 | logger = logging.getLogger(name)
56 | logger.setLevel(level)
57 |
58 | if root:
59 | # create handler for console
60 | console_handler = logging.StreamHandler(sys.stdout)
61 | console_handler.setLevel(level)
62 | formatter = ColorFormatter(_colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
63 | datefmt="%m/%d %H:%M:%S")
64 | console_handler.setFormatter(formatter)
65 | logger.addHandler(console_handler)
66 | logger.propagate = False # otherwise root logger prints things again
67 |
68 | if log_dir is not None:
69 | # add handler to log to a file
70 | log_dir = Path(log_dir)
71 | if not log_dir.exists():
72 | logger.info(f"Logging directory {log_dir} does not exist and will be created")
73 | log_dir.mkdir(parents=True)
74 | timestamp = datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
75 | log_file = log_dir / f"{timestamp}.log"
76 |
77 | # open stream and make sure it will be closed
78 | stream = log_file.open(mode="w")
79 | atexit.register(stream.close)
80 |
81 | formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
82 | datefmt="%m/%d %H:%M:%S")
83 | file_handler = logging.StreamHandler(stream)
84 | file_handler.setLevel(level)
85 | file_handler.setFormatter(formatter)
86 | logger.addHandler(file_handler)
87 |
88 | return logger
89 |
--------------------------------------------------------------------------------
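
A minimal usage sketch (not part of the repository) for get_logger above, assuming the entry point configures a parent logger named "vhap" (a hypothetical choice; the actual scripts may use a different name) so that per-module loggers created with get_logger(__name__) propagate to its handlers through the standard logging hierarchy.

import logging

from vhap.util.log import get_logger

# entry point: root=True attaches the colored console handler and, because
# log_dir is given, a timestamped log file ("logs" is a hypothetical directory)
get_logger("vhap", level=logging.INFO, root=True, log_dir="logs")

# any logger whose name starts with "vhap." (e.g. get_logger(__name__) inside
# vhap/...) is a child of the parent and propagates records to its handlers
logger = get_logger("vhap.example")
logger.info("plain message")
logger.warning("prefixed with a yellow WARNING by ColorFormatter")
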
/vhap/util/mesh.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import torch
11 |
12 |
13 | def get_mtl_content(tex_fname):
14 | return f'newmtl Material\nmap_Kd {tex_fname}\n'
15 |
16 | def get_obj_content(vertices, faces, uv_coordinates=None, uv_indices=None, mtl_fname=None):
17 | obj = ('# Generated with multi-view-head-tracker\n')
18 |
19 | if mtl_fname is not None:
20 | obj += f'mtllib {mtl_fname}\n'
21 | obj += 'usemtl Material\n'
22 |
23 | # Write the vertices
24 | for vertex in vertices:
25 | obj += f"v {vertex[0]} {vertex[1]} {vertex[2]}\n"
26 |
27 | # Write the UV coordinates
28 | if uv_coordinates is not None:
29 | for uv in uv_coordinates:
30 | obj += f"vt {uv[0]} {uv[1]}\n"
31 |
32 | # Write the faces with UV indices
33 | if uv_indices is not None:
34 |         for face, uv_idx in zip(faces, uv_indices):
35 |             obj += f"f {face[0]+1}/{uv_idx[0]+1} {face[1]+1}/{uv_idx[1]+1} {face[2]+1}/{uv_idx[2]+1}\n"
36 | else:
37 | for face in faces:
38 | obj += f"f {face[0]+1} {face[1]+1} {face[2]+1}\n"
39 | return obj
40 |
41 | def normalize_image_points(u, v, resolution):
42 | """
43 |     normalizes u, v coordinates from [0, image_size] to [-1, 1]
44 |     :param u: horizontal pixel coordinates
45 |     :param v: vertical pixel coordinates
46 |     :param resolution: (height, width) of the image
47 |     :return: normalized u, v
48 | """
49 | u = 2 * (u - resolution[1] / 2.0) / resolution[1]
50 | v = 2 * (v - resolution[0] / 2.0) / resolution[0]
51 | return u, v
52 |
53 |
54 | def face_vertices(vertices, faces):
55 | """
56 | :param vertices: [batch size, number of vertices, 3]
57 | :param faces: [batch size, number of faces, 3]
58 | :return: [batch size, number of faces, 3, 3]
59 | """
60 | assert vertices.ndimension() == 3
61 | assert faces.ndimension() == 3
62 | assert vertices.shape[0] == faces.shape[0]
63 | assert vertices.shape[2] == 3
64 | assert faces.shape[2] == 3
65 |
66 | bs, nv = vertices.shape[:2]
67 | bs, nf = faces.shape[:2]
68 | device = vertices.device
69 | faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
70 | vertices = vertices.reshape((bs * nv, 3))
71 | # pytorch only supports long and byte tensors for indexing
72 | return vertices[faces.long()]
73 |
74 |
--------------------------------------------------------------------------------
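
A minimal sketch (not part of the repository) that writes a single textured triangle with get_obj_content and get_mtl_content from the file above; the output directory, file names, and geometry are hypothetical.

from pathlib import Path

import numpy as np

from vhap.util.mesh import get_mtl_content, get_obj_content

# a single textured triangle (0-based indices; the writer adds 1 for the OBJ format)
vertices = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])
uv_coordinates = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
uv_indices = np.array([[0, 1, 2]])

out_dir = Path("export")  # hypothetical output directory
out_dir.mkdir(parents=True, exist_ok=True)
(out_dir / "mesh.mtl").write_text(get_mtl_content("texture.png"))
(out_dir / "mesh.obj").write_text(
    get_obj_content(vertices, faces, uv_coordinates, uv_indices, mtl_fname="mesh.mtl")
)
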
/vhap/util/render_uvmap.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import tyro
11 | import matplotlib.pyplot as plt
12 | import numpy as np
13 | import torch
14 | import nvdiffrast.torch as dr
15 |
16 | from vhap.model.flame import FlameHead
17 |
18 |
19 | FLAME_TEX_PATH = "asset/flame/FLAME_texture.npz"
20 |
21 |
22 | def transform_vt(vt):
23 | """Transform uv vertices to clip space"""
24 | xy = vt * 2 - 1
25 | w = torch.ones([1, vt.shape[-2], 1]).to(vt)
26 |     z = -w  # in the clip space of OpenGL, the camera looks along -z
27 | xyzw = torch.cat([xy[None, :, :], z, w], axis=-1)
28 | return xyzw
29 |
30 | def render_uvmap_vtex(glctx, pos, pos_idx, v_color, col_idx, resolution):
31 | """Render uv map with vertex color"""
32 | pos_clip = transform_vt(pos)
33 | rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
34 |
35 | color, _ = dr.interpolate(v_color, rast_out, col_idx)
36 | color = dr.antialias(color, rast_out, pos_clip, pos_idx)
37 | return color
38 |
39 | def render_uvmap_texmap(glctx, pos, pos_idx, verts_uv, faces_uv, tex, resolution, enable_mip=True, max_mip_level=None):
40 | """Render uv map with texture map"""
41 | pos_clip = transform_vt(pos)
42 | rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
43 |
44 | if enable_mip:
45 | texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all')
46 | color = dr.texture(tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level)
47 | else:
48 | texc, _ = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv)
49 | color = dr.texture(tex[None, ...], texc, filter_mode='linear')
50 | color = dr.antialias(color, rast_out, pos_clip, pos_idx)
51 | return color
52 |
53 |
54 | def main(
55 | use_texmap: bool = False,
56 | use_opengl: bool = False,
57 | ):
58 | n_shape = 300
59 | n_expr = 100
60 |     print("Initializing FLAME model")
61 | flame_model = FlameHead(n_shape, n_expr)
62 |
63 | verts_uv = flame_model.verts_uvs.cuda()
64 | verts_uv[:, 1] = 1 - verts_uv[:, 1]
65 | faces_uv = flame_model.textures_idx.int().cuda()
66 |
67 | # Rasterizer context
68 | glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()
69 |
70 | h, w = 512, 512
71 | resolution = (h, w)
72 |
73 | if use_texmap:
74 | tex = torch.from_numpy(np.load(FLAME_TEX_PATH)['mean']).cuda().float().flip(dims=[-1]) / 255
75 | rgb = render_uvmap_texmap(glctx, verts_uv, faces_uv, verts_uv, faces_uv, tex, resolution, enable_mip=True)
76 | else:
77 | v_color = torch.ones(verts_uv.shape[0], 3).to(verts_uv)
78 | col_idx = faces_uv
79 | rgb = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)
80 |
81 | plt.imshow(rgb[0, :, :].cpu())
82 | plt.show()
83 |
84 |
85 | if __name__ == "__main__":
86 | tyro.cli(main)
87 |
--------------------------------------------------------------------------------
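
A minimal sketch (not part of the repository) that checks the UV-to-clip-space convention of transform_vt above on the CPU; importing vhap.util.render_uvmap pulls in nvdiffrast, so that package must be installed even though no rasterization is performed here.

import torch

from vhap.util.render_uvmap import transform_vt

vt = torch.tensor([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]])  # UV corners and center
clip = transform_vt(vt)                                   # (1, N, 4) homogeneous clip coordinates

# [0, 1] UVs map to [-1, 1] in x/y, and z = -w because the OpenGL camera looks along -z
print(clip[0])
# tensor([[-1., -1., -1.,  1.],
#         [ 1.,  1., -1.,  1.],
#         [ 0.,  0., -1.,  1.]])
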
/vhap/util/vector_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
5 | return torch.sum(x*y, -1, keepdim=True)
6 |
7 | def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
8 | return 2*dot(x, n)*n - x
9 |
10 | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
11 | return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
12 |
13 | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
14 | return x / length(x, eps)
15 |
16 | def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
17 | return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)
18 |
--------------------------------------------------------------------------------
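
A minimal sketch (not part of the repository) exercising the helpers in vhap/util/vector_ops.py with hand-picked vectors.

import torch

from vhap.util.vector_ops import dot, reflect, safe_normalize, to_hvec

x = torch.tensor([[1.0, -1.0, 0.0]])
n = torch.tensor([[0.0, 1.0, 0.0]])  # unit normal

print(dot(x, n))                          # tensor([[-1.]]); keepdim=True keeps the last axis
print(reflect(x, n))                      # tensor([[-1., -1., 0.]]); 2*dot(x, n)*n - x flips the part of x perpendicular to n
print(safe_normalize(torch.zeros(1, 3)))  # tensor([[0., 0., 0.]]); the eps clamp avoids NaNs for zero-length input
print(to_hvec(x, 1.0))                    # tensor([[ 1., -1.,  0.,  1.]]); appends the homogeneous coordinate
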
/vhap/util/visualization.py:
--------------------------------------------------------------------------------
1 | #
2 | # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
3 | # property and proprietary rights in and to this software and related documentation.
4 | # Any commercial use, reproduction, disclosure or distribution of this software and
5 | # related documentation without an express license agreement from Toyota Motor Europe NV/SA
6 | # is strictly prohibited.
7 | #
8 |
9 |
10 | import matplotlib.pyplot as plt
11 | import torch
12 | from torchvision.utils import draw_bounding_boxes, draw_keypoints
13 |
14 |
15 | connectivity_face = (
16 | [(i, i + 1) for i in list(range(0, 16))]
17 | + [(i, i + 1) for i in list(range(17, 21))]
18 | + [(i, i + 1) for i in list(range(22, 26))]
19 | + [(i, i + 1) for i in list(range(27, 30))]
20 | + [(i, i + 1) for i in list(range(31, 35))]
21 | + [(i, i + 1) for i in list(range(36, 41))]
22 | + [(36, 41)]
23 | + [(i, i + 1) for i in list(range(42, 47))]
24 | + [(42, 47)]
25 | + [(i, i + 1) for i in list(range(48, 59))]
26 | + [(48, 59)]
27 | + [(i, i + 1) for i in list(range(60, 67))]
28 | + [(60, 67)]
29 | )
30 |
31 |
32 | def plot_landmarks_2d(
33 | img: torch.tensor,
34 | lmks: torch.tensor,
35 | connectivity=None,
36 | colors="white",
37 | unit=1,
38 | input_float=False,
39 | ):
40 | if input_float:
41 | img = (img * 255).byte()
42 |
43 | img = draw_keypoints(
44 | img,
45 | lmks,
46 | connectivity=connectivity,
47 | colors=colors,
48 | radius=2 * unit,
49 | width=2 * unit,
50 | )
51 |
52 | if input_float:
53 | img = img.float() / 255
54 | return img
55 |
56 |
57 | def blend(a, b, w):
58 | return (a * w + b * (1 - w)).byte()
59 |
60 |
61 | if __name__ == "__main__":
62 | from argparse import ArgumentParser
63 | from torch.utils.data import DataLoader
64 | from matplotlib import pyplot as plt
65 |
66 | from vhap.data.nersemble_dataset import NeRSembleDataset
67 |
68 | parser = ArgumentParser()
69 | parser.add_argument("--root_folder", type=str, required=True)
70 | parser.add_argument("--subject", type=str, required=True)
71 | parser.add_argument("--sequence", type=str, required=True)
72 | parser.add_argument("--division", default=None)
73 | parser.add_argument("--subset", default=None)
74 | parser.add_argument("--scale_factor", type=float, default=1.0)
75 | parser.add_argument("--blend_weight", type=float, default=0.6)
76 | args = parser.parse_args()
77 |
78 | dataset = NeRSembleDataset(
79 | root_folder=args.root_folder,
80 | subject=args.subject,
81 | sequence=args.sequence,
82 | division=args.division,
83 | subset=args.subset,
84 | n_downsample_rgb=2,
85 | scale_factor=args.scale_factor,
86 | use_landmark=True,
87 | )
88 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
89 |
90 | for item in dataloader:
91 | unit = int(item["scale_factor"][0] * 3) + 1
92 |
93 | rgb = item["rgb"][0].permute(2, 0, 1)
94 | vis = rgb
95 |
96 | if "bbox_2d" in item:
97 | bbox = item["bbox_2d"][0][:4]
98 | tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
99 | vis = blend(tmp, vis, args.blend_weight)
100 |
101 | if "lmk2d" in item:
102 | face_landmark = item["lmk2d"][0][:, :2]
103 | tmp = plot_landmarks_2d(
104 | vis,
105 | face_landmark[None, ...],
106 | connectivity=connectivity_face,
107 | colors="white",
108 | unit=unit,
109 | )
110 | vis = blend(tmp, vis, args.blend_weight)
111 |
112 | if "lmk2d_iris" in item:
113 | iris_landmark = item["lmk2d_iris"][0][:, :2]
114 | tmp = plot_landmarks_2d(
115 | vis,
116 | iris_landmark[None, ...],
117 | colors="blue",
118 | unit=unit,
119 | )
120 | vis = blend(tmp, vis, args.blend_weight)
121 |
122 | vis = vis.permute(1, 2, 0).numpy()
123 | plt.imshow(vis)
124 | plt.draw()
125 | while not plt.waitforbuttonpress(timeout=-1):
126 | pass
127 |
--------------------------------------------------------------------------------
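
A minimal sketch (not part of the repository) that runs plot_landmarks_2d and blend from the file above on synthetic data, without needing the NeRSemble dataset used in __main__.

import torch

from vhap.util.visualization import blend, connectivity_face, plot_landmarks_2d

img = torch.zeros(3, 256, 256, dtype=torch.uint8)  # black RGB image in (C, H, W) layout
lmks = torch.rand(1, 68, 2) * 255                  # one set of 68 random keypoints in pixel coordinates

vis = plot_landmarks_2d(img, lmks, connectivity=connectivity_face, colors="white", unit=1)
vis = blend(vis, img, 0.6)                         # soften the overlay, as done in __main__ above
print(vis.shape, vis.dtype)                        # torch.Size([3, 256, 256]) torch.uint8
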