├── lib
│   ├── utils
│   │   ├── __init__.py
│   │   ├── prof_utils.py
│   │   ├── parallel_utils.py
│   │   ├── base_utils.py
│   │   ├── ray_utils.py
│   │   ├── sem_utils.py
│   │   ├── easy_utils.py
│   │   └── log_utils.py
│   ├── networks
│   │   ├── deform
│   │   │   └── __init__.py
│   │   ├── relight
│   │   │   ├── __init__.py
│   │   │   └── relight_network.py
│   │   ├── renderer
│   │   │   ├── __init__.py
│   │   │   ├── make_renderer.py
│   │   │   └── base_renderer.py
│   │   ├── __init__.py
│   │   ├── make_network.py
│   │   └── transformer.py
│   ├── config
│   │   └── __init__.py
│   ├── evaluators
│   │   ├── __init__.py
│   │   ├── make_evaluator.py
│   │   ├── mesh_evaluator.py
│   │   └── base_evaluator.py
│   ├── train
│   │   ├── trainers
│   │   │   ├── __init__.py
│   │   │   ├── make_trainer.py
│   │   │   ├── base_trainer.py
│   │   │   ├── relight_trainer.py
│   │   │   └── trainer.py
│   │   ├── __init__.py
│   │   ├── scheduler.py
│   │   ├── optimizer.py
│   │   ├── optimizers
│   │   │   └── lr_scheduler.py
│   │   └── recorder.py
│   ├── visualizers
│   │   ├── __init__.py
│   │   ├── make_visualizer.py
│   │   ├── pose_visualizer.py
│   │   ├── demo_visualizer.py
│   │   ├── mesh_visualizer.py
│   │   └── light_visualizer.py
│   └── datasets
│       ├── __init__.py
│       ├── mesh_dataset.py
│       ├── demo_dataset.py
│       ├── make_dataset.py
│       ├── pose_dataset.py
│       └── samplers.py
├── scripts
│   └── tools
│       ├── __init__.py
│       ├── evaluate_lighting.py
│       ├── json_to_xlsx.py
│       ├── prepare_envmap.py
│       ├── visualize_npfiles.py
│       ├── prepare_config.py
│       └── prepare_annots.py
├── setup.cfg
├── .vscode
│   ├── phdeform.code-workspace
│   ├── settings.json
│   └── launch.json
├── pyrightconfig.json
├── .tabnineignore
├── .gitignore
├── environment.yml
├── configs
│   ├── synthetic_human
│   │   ├── base_synthetic_jody_1v.yaml
│   │   ├── base_synthetic_josh_1v.yaml
│   │   ├── base_synthetic_jody.yaml
│   │   ├── base_synthetic_josh.yaml
│   │   ├── base_synthetic_leonard_1v.yaml
│   │   ├── base_synthetic_leonard.yaml
│   │   ├── base_synthetic_megan_1v.yaml
│   │   ├── base_synthetic_megan.yaml
│   │   ├── base_synthetic_malcolm_1v.yaml
│   │   ├── base_synthetic_nathan_1v.yaml
│   │   ├── base_synthetic_malcolm.yaml
│   │   ├── base_synthetic_manuel.yaml
│   │   ├── base_synthetic_nathan.yaml
│   │   ├── base_synthetic_manuel_1v.yaml
│   │   ├── nathan_1v_geo.yaml
│   │   ├── jody_1v_geo.yaml
│   │   ├── josh_1v_geo.yaml
│   │   ├── manuel_1v_geo.yaml
│   │   ├── jody_2v_geo.yaml
│   │   ├── malcolm_1v_geo.yaml
│   │   ├── jody_5v_geo.yaml
│   │   ├── megan_1v_geo.yaml
│   │   ├── leonard_1v_geo.yaml
│   │   ├── jody_10v_geo.yaml
│   │   ├── josh_10v_geo.yaml
│   │   ├── jody_10v_1f_geo.yaml
│   │   ├── megan_10v_geo.yaml
│   │   ├── jody_10v_20f_geo.yaml
│   │   ├── jody_10v_50f_geo.yaml
│   │   ├── leonard_10v_geo.yaml
│   │   └── jody_fv_1f_geo.yaml
│   ├── my_zju_mocap
│   │   ├── base_zjumocap_377_1v.yaml
│   │   ├── base_zjumocap_386_1v.yaml
│   │   ├── my_313_4v_geo.yaml
│   │   ├── my_377_4v_geo.yaml
│   │   └── my_387_4v_geo.yaml
│   ├── mobile_stage
│   │   ├── base_mobile_white.yaml
│   │   ├── base_mobile_purple.yaml
│   │   ├── base_mobile_black.yaml
│   │   ├── base_mobile_dress.yaml
│   │   ├── base_mobile_xuzhen.yaml
│   │   ├── base_mobile_move.yaml
│   │   ├── white_12v_geo.yaml
│   │   ├── purple_12v_geo.yaml
│   │   ├── black_12v_geo.yaml
│   │   └── xuzhen_12v_geo.yaml
│   └── base.yaml
├── requirements.txt
├── run.py
├── train.py
└── readme.md
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/networks/deform/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/networks/relight/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import cfg, args
2 |
--------------------------------------------------------------------------------
/lib/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_evaluator import make_evaluator
2 |
--------------------------------------------------------------------------------
/lib/networks/renderer/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_renderer import make_renderer
--------------------------------------------------------------------------------
/lib/train/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_trainer import make_trainer
2 |
--------------------------------------------------------------------------------
/lib/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_visualizer import make_visualizer
2 |
--------------------------------------------------------------------------------
/lib/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_network import make_network
2 | # import embedder
3 | # import transformer
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pdbr]
2 | ; use_traceback = True
3 | theme = ansi_dark
4 | ; style = dim
5 | store_history = .pdbr_history
--------------------------------------------------------------------------------
/.vscode/phdeform.code-workspace:
--------------------------------------------------------------------------------
1 | {
2 |     "folders": [
3 |         {
4 |             "path": ".."
5 |         }
6 |     ],
7 |     "settings": {}
8 | }
--------------------------------------------------------------------------------
/scripts/tools/evaluate_lighting.py:
--------------------------------------------------------------------------------
1 | # This script reads EXR files, evaluates them against the rendered ground truth, and then reports the results
--------------------------------------------------------------------------------
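Note that evaluate_lighting.py is only a stub in this dump: a single comment and no code. Below is a minimal sketch of what such an evaluation loop could look like. It assumes predictions and ground truth are EXR files with matching names under two directories, uses imageio for decoding (EXR support may require the freeimage/OpenEXR plugin), and reports PSNR; the directory defaults and the psnr helper are invented for illustration and are not part of the repository.

# Hypothetical sketch only -- not the repository's implementation of evaluate_lighting.py.
import argparse
from pathlib import Path

import imageio.v2 as imageio
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--pred_dir', type=str, default='data/relight/pred')  # assumed layout
parser.add_argument('--gt_dir', type=str, default='data/relight/gt')      # assumed layout
args = parser.parse_args()


def psnr(x: np.ndarray, y: np.ndarray) -> float:
    # PSNR on [0, 1] images; guard against a zero MSE
    mse = float(np.mean((x - y) ** 2))
    return 10 * np.log10(1.0 / max(mse, 1e-10))


scores = []
for pred_path in sorted(Path(args.pred_dir).glob('*.exr')):
    gt_path = Path(args.gt_dir) / pred_path.name
    pred = np.clip(imageio.imread(pred_path), 0, 1)  # clip HDR values before comparison
    gt = np.clip(imageio.imread(gt_path), 0, 1)
    scores.append(psnr(pred, gt))

print(f'mean PSNR over {len(scores)} images: {np.mean(scores):.3f}')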
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |     "exclude": [
3 |         "**/node_modules",
4 |         "**/__pycache__",
5 |         "data",
6 |     ],
7 | }
--------------------------------------------------------------------------------
/lib/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainers import make_trainer
2 | from .optimizer import make_optimizer
3 | from .scheduler import make_lr_scheduler, set_lr_scheduler
4 | from .recorder import make_recorder
5 |
6 |
--------------------------------------------------------------------------------
/lib/networks/make_network.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from torch import nn
3 |
4 | def make_network(cfg) -> nn.Module:
5 |     module = cfg.network_module
6 |     network = importlib.import_module(module).Network()
7 |     return network
8 |
--------------------------------------------------------------------------------
/.tabnineignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .idea/
3 | .ipynb_checkpoints/
4 | *.py[cod]
5 | *.so
6 | *.orig
7 | *.o
8 | *.json
9 | *.pth
10 | *.npy
11 | *.ipynb
12 | *.th
13 |
14 | # imgui stuff
15 | imgui.ini
16 |
17 | *.ply
18 | *.png
19 | *.jpy
20 |
21 | data
22 |
--------------------------------------------------------------------------------
/lib/visualizers/make_visualizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | from .base_visualizer import Visualizer
4 |
5 | def make_visualizer(cfg) -> Visualizer:
6 |     module = cfg.visualizer_module
7 |     visualizer = importlib.import_module(module).Visualizer()
8 |     return visualizer
9 |
--------------------------------------------------------------------------------
/lib/networks/renderer/make_renderer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | from .base_renderer import Renderer
4 |
5 | def make_renderer(cfg, network) -> Renderer:
6 |     module = cfg.renderer_module
7 |     renderer = importlib.import_module(module).Renderer(network)
8 |     return renderer
9 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_dataset import make_data_loader
2 |
3 | from os.path import dirname, basename, isfile, join
4 | import glob
5 | modules = glob.glob(join(dirname(__file__), "*.py"))
6 | __all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
7 |
--------------------------------------------------------------------------------
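The make_network, make_visualizer and make_renderer factories above (and make_trainer / make_evaluator below) all share the same importlib pattern: the config supplies a dotted module path, and the factory imports that module and instantiates a class with a fixed name (Network, Visualizer, Renderer, NetworkWrapper, Evaluator). The following self-contained sketch illustrates that contract; the demo_module and its Network class are invented stand-ins, since the real module paths come from the files under configs/ and are not spelled out here.

# Illustration of the cfg.*_module -> importlib factory contract; everything named
# 'demo' here is made up for the sketch, not part of the repository.
import importlib
import sys
import types

# Stand-in for a module such as the one cfg.network_module would point to.
demo = types.ModuleType('demo_module')


class Network:  # the factory only requires that the module exposes this class name
    def __init__(self):
        self.name = 'demo'


demo.Network = Network
sys.modules['demo_module'] = demo


class Cfg:  # stand-in for the repo's cfg object
    network_module = 'demo_module'


# This is exactly the call make_network(cfg) performs:
network = importlib.import_module(Cfg.network_module).Network()
print(network.name)  # -> 'demo'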
/lib/train/trainers/make_trainer.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 | import importlib
3 |
4 |
5 | def _wrapper_factory(cfg, network):
6 |     module = cfg.trainer_module
7 |     network_wrapper = importlib.import_module(module).NetworkWrapper(network)
8 |     return network_wrapper
9 |
10 |
11 | def make_trainer(cfg, network):
12 |     network = _wrapper_factory(cfg, network)
13 |     return Trainer(network)
14 |
--------------------------------------------------------------------------------
/lib/evaluators/make_evaluator.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from .base_evaluator import Evaluator
3 |
4 | def _evaluator_factory(cfg):
5 |     module = cfg.evaluator_module
6 |     evaluator = importlib.import_module(module).Evaluator()
7 |     return evaluator
8 |
9 |
10 | def make_evaluator(cfg) -> Evaluator:
11 |     if cfg.skip_eval:
12 |         return None
13 |     else:
14 |         return _evaluator_factory(cfg)
15 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "search.exclude": {
3 |         "**/.git": true,
4 |         "**/node_modules": true,
5 |         "**/data": true,
6 |     },
7 |     "files.exclude": {
8 |         "**/.git": true,
9 |         "**/node_modules": true,
10 |         "**/data": true,
11 |     },
12 |     "files.watcherExclude": {
13 |         "**/data": true,
14 |         "**/.git/objects/**": true,
15 |         "**/.git/subtree-cache/**": true,
16 |         "**/node_modules/*/**": true
17 |     }
18 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .idea/
3 | .ipynb_checkpoints/
4 | *.py[cod]
5 | *.so
6 | *.orig
7 | *.o
8 | # why ignore json
9 | # *.json
10 | *.pth
11 | *.npy
12 | *.ipynb
13 |
14 | # graphviz
15 | *.dot
16 | *.svg
17 | *.eps
18 |
19 | # macOS
20 | .DS_Store
21 |
22 | # mesh
23 | *.ply
24 |
25 | # images
26 | *.png
27 | *.jpg
28 | *.jpeg
29 | *.gif
30 | .fuse*
31 | *.tar.gz
32 | *.pkl
33 | *.swp
34 |
35 | *.obj
36 | *.mtl
37 | *.blend
38 | *.blend1
39 | *.npz
40 |
41 | *.xls
42 | *.xlsx
43 | condaenv*
44 | *.hdr
45 |
46 | data
47 | *.pdf
48 | model_graph
49 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: phdeform
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 |   - conda-forge
6 |   - defaults
7 |
8 | dependencies:
9 |   - python>=3.10
10 |   # pytorch3d does not support pytorch 1.13 yet
11 |   # so we install 1.12.1 for now
12 |   - pytorch==1.12.1 # conda forge version of pytorch does not seem to have kineto
13 |   - cudatoolkit==11.6
14 |   - torchvision
15 |   - torchaudio
16 |
17 |   - git
18 |   - conda-forge::ncurses # vim: /home/xuzhen/miniconda3/envs/phdeform/bin/../lib/libtinfo.so.6: no version information available (required by vim)
19 |   - vim
20 |   - cmake
21 |
22 |   - pip
23 |   # - pip:
24 |   #   - -r requirements.txt
25 |
--------------------------------------------------------------------------------
/scripts/tools/json_to_xlsx.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import pandas as pd
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--json', type=str, default='data/metrics_ablation.json')
8 | args = parser.parse_args()
9 |
10 | metrics_json = json.load(open(args.json))
11 | table = {}
12 | for i, exp in enumerate(metrics_json):
13 |     for j, key in enumerate(metrics_json[exp]):
14 |         for k, met in enumerate(metrics_json[exp][key]):
15 |             if f'{key}_{met}' not in table:
16 | table[f'{key}_{met}'] = {} 17 | table[f'{key}_{met}'][exp] = metrics_json[exp][key][met] 18 | 19 | # __import__('pdbr').set_trace() 20 | 21 | w = pd.ExcelWriter(args.json.replace('.json', '.xlsx')) 22 | df = pd.DataFrame(table) 23 | df.to_excel(w) 24 | w.close() -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_jody_1v.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_synthetic_jody_1v 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | # Relighting Configuration 33 | relighting_cfg: 34 | exp_name: relight_synthetic_jody_1v 35 | geometry_mesh: data/animation/deform/base_synthetic_jody_1v/can_mesh.npz 36 | geometry_pretrain: data/trained_model/deform/base_synthetic_jody_1v 37 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_josh_1v.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_synthetic_josh_1v 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/josh 7 | human: josh 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/josh 13 | human: josh 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | # Relighting Configuration 33 | relighting_cfg: 34 | exp_name: relight_synthetic_josh_1v 35 | geometry_mesh: data/animation/deform/base_synthetic_josh_1v/can_mesh.npz 36 | geometry_pretrain: data/trained_model/deform/base_synthetic_josh_1v 37 | -------------------------------------------------------------------------------- /configs/my_zju_mocap/base_zjumocap_377_1v.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_zjumocap_377_1v 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/my_zju_mocap/my_377 7 | human: my_377 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/my_zju_mocap/my_377 13 | human: my_377 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | frame_interval: 5 26 | num_train_frame: 100 27 | num_eval_frame: 100 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 6 32 | 
33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_377_1v 36 | geometry_mesh: data/animation/deform/base_zjumocap_377_1v/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_zjumocap_377_1v 38 | -------------------------------------------------------------------------------- /configs/my_zju_mocap/base_zjumocap_386_1v.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_zjumocap_386_1v 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/my_zju_mocap/my_386 7 | human: my_386 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/my_zju_mocap/my_386 13 | human: my_386 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | frame_interval: 5 26 | num_train_frame: 100 27 | num_eval_frame: 100 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 6 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_386_1v 36 | geometry_mesh: data/animation/deform/base_zjumocap_386_1v/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_zjumocap_386_1v 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_jody.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_synthetic_jody 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | # Relighting Configuration 33 | relighting_cfg: 34 | exp_name: relight_synthetic_jody 35 | geometry_mesh: data/animation/deform/base_synthetic_jody/can_mesh.npz 36 | geometry_pretrain: data/trained_model/deform/base_synthetic_jody 37 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_josh.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_synthetic_josh 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/josh 7 | human: josh 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/josh 13 | human: josh 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 
31 | 32 | # Relighting Configuration 33 | relighting_cfg: 34 | exp_name: relight_synthetic_josh 35 | geometry_mesh: data/animation/deform/base_synthetic_josh/can_mesh.npz 36 | geometry_pretrain: data/trained_model/deform/base_synthetic_josh 37 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_leonard_1v.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_synthetic_leonard_1v 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/leonard 7 | human: leonard 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/leonard 13 | human: leonard 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | # Relighting Configuration 33 | relighting_cfg: 34 | exp_name: relight_synthetic_leonard_1v 35 | geometry_mesh: data/animation/deform/base_synthetic_leonard_1v/can_mesh.npz 36 | geometry_pretrain: data/trained_model/deform/base_synthetic_leonard_1v 37 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_leonard.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_synthetic_leonard 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/leonard 7 | human: leonard 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/leonard 13 | human: leonard 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | # Relighting Configuration 33 | relighting_cfg: 34 | exp_name: relight_synthetic_leonard 35 | geometry_mesh: data/animation/deform/base_synthetic_leonard/can_mesh.npz 36 | geometry_pretrain: data/trained_model/deform/base_synthetic_leonard 37 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_megan_1v.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_megan_1v 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/megan 8 | human: megan 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/megan 14 | human: megan 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 4, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 100 27 | num_eval_frame: 100 28 | num_render_view: 300 29 | test: 30 | 
view_sampler_interval: 1 31 | frame_sampler_interval: 21 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_megan_1v 36 | geometry_mesh: data/animation/deform/base_synthetic_megan_1v/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_megan_1v 38 | -------------------------------------------------------------------------------- /configs/mobile_stage/base_mobile_white.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_mobile_white 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/white 7 | human: model4 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/white 13 | human: model4 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33] # 12 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 300 25 | begin_ith_frame: 300 26 | num_train_frame: 300 27 | num_eval_frame: 600 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 80 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_mobile_white 36 | geometry_mesh: data/animation/deform/base_mobile_white/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_mobile_white 38 | -------------------------------------------------------------------------------- /configs/mobile_stage/base_mobile_purple.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_mobile_purple 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/purple 7 | human: model2 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/purple 13 | human: model2 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33] # 12 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 600 27 | num_eval_frame: 700 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 110 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_mobile_purple 36 | geometry_mesh: data/animation/deform/base_mobile_purple/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_mobile_purple 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_megan.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_megan 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/megan 8 | human: megan 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/megan 14 | human: megan 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 0 25 | 
begin_ith_frame: 0 26 | num_train_frame: 100 27 | num_eval_frame: 100 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 21 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_megan 36 | geometry_mesh: data/animation/deform/base_synthetic_megan/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_megan 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_malcolm_1v.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_malcolm_1v 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/malcolm 8 | human: malcolm 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/malcolm 14 | human: malcolm 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 4, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 1 25 | begin_ith_frame: 1 26 | num_train_frame: 69 27 | num_eval_frame: 69 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 13 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_malcolm_1v 36 | geometry_mesh: data/animation/deform/base_synthetic_malcolm_1v/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_malcolm_1v 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_nathan_1v.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_nathan_1v 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/nathan 8 | human: nathan 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/nathan 14 | human: nathan 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 4, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 1 25 | begin_ith_frame: 1 26 | num_train_frame: 68 # strange... 
27 | num_eval_frame: 68 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 13 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_nathan_1v 36 | geometry_mesh: data/animation/deform/base_synthetic_nathan_1v/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_nathan_1v 38 | -------------------------------------------------------------------------------- /configs/mobile_stage/base_mobile_black.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_mobile_black 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/black 7 | human: model2 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/black 13 | human: model2 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33] # 12 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 300 # only 400 valid frames 27 | num_eval_frame: 400 28 | num_render_view: 100 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 60 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_mobile_black 36 | geometry_mesh: data/animation/deform/base_mobile_black/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_mobile_black 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_malcolm.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_malcolm 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/malcolm 8 | human: malcolm 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/malcolm 14 | human: malcolm 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 1 25 | begin_ith_frame: 1 26 | num_train_frame: 69 27 | num_eval_frame: 69 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 13 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_malcolm 36 | geometry_mesh: data/animation/deform/base_synthetic_malcolm/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_malcolm 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_manuel.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_manuel 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/manuel 8 | human: manuel 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/manuel 14 | human: manuel 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 0, 1, 
2, 3, 4, 5, 6, 7, 8, 9, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 100 25 | begin_ith_frame: 100 26 | num_train_frame: 800 27 | num_eval_frame: 1100 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 100 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_manuel 36 | geometry_mesh: data/animation/deform/base_synthetic_manuel/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_manuel 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_nathan.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_nathan 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/nathan 8 | human: nathan 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | data_root: data/synthetic_human/nathan 14 | human: nathan 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 22 | # prettier-ignore 23 | test_view: [0, 4, 8, 12, 15, 19] 24 | fix_material: 1 25 | begin_ith_frame: 1 26 | num_train_frame: 68 # strange... 27 | num_eval_frame: 68 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 1 31 | frame_sampler_interval: 13 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_synthetic_nathan 36 | geometry_mesh: data/animation/deform/base_synthetic_nathan/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_synthetic_nathan 38 | -------------------------------------------------------------------------------- /configs/mobile_stage/base_mobile_dress.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_mobile_dress 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/220609_162358_weiyu_round+001700+002600 7 | human: weiyu 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/220609_162358_weiyu_round+001700+002600 13 | human: weiyu 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30] # 12 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 600 27 | num_eval_frame: 900 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 120 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_mobile_dress 36 | geometry_mesh: data/animation/deform/base_mobile_dress/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_mobile_dress 38 | -------------------------------------------------------------------------------- /configs/synthetic_human/base_synthetic_manuel_1v.yaml: -------------------------------------------------------------------------------- 1 | # Dataset Configuration 2 | exp_name: base_synthetic_manuel_1v 3 | parent_cfg: configs/base.yaml 4 | 5 | # Data Configuration 6 | train_dataset: 7 | data_root: data/synthetic_human/manuel 8 | human: manuel 9 | ann_file: annots.npy 10 | split: train 11 | 12 | test_dataset: 13 | 
data_root: data/synthetic_human/manuel 14 | human: manuel 15 | ann_file: annots.npy 16 | split: test 17 | 18 | # Selection Configuration 19 | ratio: 1.0 20 | # prettier-ignore 21 | # training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 22 | training_view: [4,] 23 | # prettier-ignore 24 | test_view: [0, 4, 8, 12, 15, 19] 25 | fix_material: 250 26 | begin_ith_frame: 250 27 | num_train_frame: 400 28 | num_eval_frame: 1050 29 | num_render_view: 300 30 | test: 31 | view_sampler_interval: 1 32 | frame_sampler_interval: 100 33 | 34 | # Relighting Configuration 35 | relighting_cfg: 36 | exp_name: relight_synthetic_manuel_1v 37 | geometry_mesh: data/animation/deform/base_synthetic_manuel_1v/can_mesh.npz 38 | geometry_pretrain: data/trained_model/deform/base_synthetic_manuel_1v 39 | -------------------------------------------------------------------------------- /configs/mobile_stage/base_mobile_xuzhen.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_mobile_xuzhen 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/mobile_stage/xuzhen 7 | human: xuzhen 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/xuzhen 13 | human: xuzhen 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Configuration 20 | ratio: 0.5 21 | # prettier-ignore 22 | training_view: [0, 3, 6, 7, 12, 17, 18, 21, 24, 26, 30, 34] 23 | # prettier-ignore 24 | test_view: [] # use training views 25 | fix_material: 0 26 | begin_ith_frame: 0 27 | num_train_frame: 1600 28 | num_eval_frame: 2000 29 | num_render_view: 400 30 | test: 31 | view_sampler_interval: 8 32 | frame_sampler_interval: 300 33 | 34 | # Loss Configuration 35 | resd_loss_weight: 0.1 36 | lambertian: False # important for final result when view is dense enough 37 | 38 | # Relighting Configuration 39 | relighting_cfg: 40 | exp_name: relight_mobile_xuzhen 41 | geometry_mesh: data/animation/deform/base_mobile_xuzhen/can_mesh.npz 42 | geometry_pretrain: data/trained_model/deform/base_mobile_xuzhen 43 | -------------------------------------------------------------------------------- /lib/visualizers/pose_visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from lib.config import cfg 3 | from lib.config.config import Output 4 | from lib.utils.data_utils import save_image 5 | from lib.utils.base_utils import dotdict 6 | from . 
import base_visualizer 7 | 8 | 9 | class Visualizer(base_visualizer.Visualizer): 10 | def prepare_result_paths(self): 11 | img_path = f'data/pose_sequence/{cfg.exp_name}/view_{{view:04d}}/{{type}}/frame_{{frame:04d}}{cfg.vis_ext}' 12 | self.img_path = img_path 13 | self.result_dir = os.path.dirname(self.img_path) 14 | 15 | def visualize_single_type(self, output: dotdict, batch: dotdict, type: Output = Output.Rendering): 16 | img_pred = Visualizer.generate_image(output, batch, type) 17 | frame_index = batch.meta.frame_index.item() 18 | view_index = batch.meta.view_index.item() 19 | self.view_index = view_index # for generating video 20 | self.frame_index = frame_index # for generating video 21 | 22 | img_path = self.img_path.format(type=type.name.lower(), frame=frame_index, view=view_index) 23 | 24 | os.makedirs(os.path.dirname(img_path), exist_ok=True) 25 | save_image(img_path, img_pred) 26 | -------------------------------------------------------------------------------- /configs/synthetic_human/nathan_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: nathan_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/nathan 7 | human: nathan 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/nathan 13 | human: nathan 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | begin_ith_frame: 1 24 | num_train_frame: 68 25 | num_eval_frame: 68 26 | num_render_view: 300 27 | test: 28 | view_sampler_interval: 1 29 | frame_sampler_interval: 21 30 | 31 | mesh_simp: True # match smpl? 32 | mesh_simp_face: 16384 33 | 34 | # Relighting Configuration 35 | relighting_cfg: 36 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 37 | obj_lvis: 38 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 39 | 40 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 41 | exp_name: nathan_1v_geo_fix_mat 42 | geometry_mesh: data/animation/deform/nathan_1v_geo/can_mesh.npz 43 | geometry_pretrain: data/trained_model/deform/nathan_1v_geo 44 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_1v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_1v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_1v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/josh_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: josh_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/josh 7 | human: josh 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/josh 13 | human: josh 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: josh_1v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/josh_1v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/josh_1v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/manuel_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: manuel_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/manuel 7 | human: manuel 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/manuel 13 | human: manuel 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | begin_ith_frame: 250 24 | num_train_frame: 400 25 | num_eval_frame: 1050 26 | num_render_view: 300 27 | test: 28 | view_sampler_interval: 1 29 | frame_sampler_interval: 21 30 | 31 | mesh_simp: True # match smpl? 
32 | mesh_simp_face: 16384 33 | 34 | # Relighting Configuration 35 | relighting_cfg: 36 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 37 | obj_lvis: 38 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 39 | 40 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 41 | exp_name: manuel_1v_geo_fix_mat 42 | geometry_mesh: data/animation/deform/manuel_1v_geo/can_mesh.npz 43 | geometry_pretrain: data/trained_model/deform/manuel_1v_geo 44 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_2v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_2v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 2, 6, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_2v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_2v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_2v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/malcolm_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: malcolm_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/malcolm 7 | human: malcolm 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/malcolm 13 | human: malcolm 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | begin_ith_frame: 1 24 | num_train_frame: 69 25 | num_eval_frame: 69 26 | num_render_view: 300 27 | test: 28 | view_sampler_interval: 1 29 | frame_sampler_interval: 21 30 | 31 | mesh_simp: True # match smpl? 
32 | mesh_simp_face: 16384 33 | 34 | # Relighting Configuration 35 | relighting_cfg: 36 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 37 | obj_lvis: 38 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 39 | 40 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 41 | exp_name: malcolm_1v_geo_fix_mat 42 | geometry_mesh: data/animation/deform/malcolm_1v_geo/can_mesh.npz 43 | geometry_pretrain: data/trained_model/deform/malcolm_1v_geo 44 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_5v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_5v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 2, 4, 6, 8 ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_5v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_5v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_5v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/megan_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: megan_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/megan 7 | human: megan 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/megan 13 | human: megan 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: megan_1v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/megan_1v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/megan_1v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/leonard_1v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: leonard_1v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/leonard 7 | human: leonard 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/leonard 13 | human: leonard 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 4, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: leonard_1v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/leonard_1v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/leonard_1v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_10v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_10v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_10v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_10v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/josh_10v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: josh_10v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/josh 7 | human: josh 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/josh 13 | human: josh 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: josh_10v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/josh_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/josh_10v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_10v_1f_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_10v_1f_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 1 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_10v_1f_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_10v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/megan_10v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: megan_10v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/megan 7 | human: megan 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/megan 13 | human: megan 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: megan_10v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/megan_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/megan_10v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_10v_20f_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_10v_20f_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 20 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_10v_20f_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_10v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_10v_50f_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_10v_50f_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 50 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: jody_10v_50f_geo_fix_mat 43 | geometry_mesh: data/animation/deform/jody_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/jody_10v_geo 45 | -------------------------------------------------------------------------------- /configs/synthetic_human/leonard_10v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: leonard_10v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/leonard 7 | human: leonard 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/leonard 13 | human: leonard 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ] 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 100 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: leonard_10v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/leonard_10v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/leonard_10v_geo 45 | -------------------------------------------------------------------------------- /configs/my_zju_mocap/my_313_4v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: my_313_4v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/my_zju_mocap/my_313 7 | human: my_313 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/my_zju_mocap/my_313 13 | human: my_313 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 6, 12, 18 ] 21 | # prettier-ignore 22 | test_view: [1,2,3,4,5,7,8,9,10,11,13,14,15,16,17,19,20] # use other views as training view 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 60 26 | num_eval_frame: 60 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 10 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: my_313_4v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/my_313_4v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/my_313_4v_geo 45 | -------------------------------------------------------------------------------- /configs/my_zju_mocap/my_377_4v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: my_377_4v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/my_zju_mocap/my_377 7 | human: my_377 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/my_zju_mocap/my_377 13 | human: my_377 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 6, 12, 18 ] 21 | # prettier-ignore 22 | test_view: [1,2,3,4,5,7,8,9,10,11,13,14,15,16,17,19,20] # use other views as training view 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 60 26 | num_eval_frame: 60 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 10 31 | 32 | mesh_simp: True # match smpl? 
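# NOTE (added commentary, not part of the original config): the my_zju_mocap configs train from
# only four roughly evenly spaced cameras (0, 6, 12, 18) and list every remaining camera id under
# test_view, so the held-out views serve for novel-view evaluation -- the original comment
# "use other views as training view" reads as shorthand for exactly this split. Frame counts are
# also shorter here (60 training / 60 evaluation frames) than in the synthetic_human configs.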
33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: my_377_4v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/my_377_4v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/my_377_4v_geo 45 | -------------------------------------------------------------------------------- /configs/my_zju_mocap/my_387_4v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: my_387_4v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/my_zju_mocap/my_387 7 | human: my_387 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/my_zju_mocap/my_387 13 | human: my_387 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 0, 6, 12, 18 ] 21 | # prettier-ignore 22 | test_view: [1,2,3,4,5,7,8,9,10,11,13,14,15,16,17,19,20] # use other views as training view 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 60 26 | num_eval_frame: 60 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 10 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | # Relighting Configuration 36 | relighting_cfg: 37 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 38 | obj_lvis: 39 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 40 | 41 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 42 | exp_name: my_387_4v_geo_fix_mat 43 | geometry_mesh: data/animation/deform/my_387_4v_geo/can_mesh.npz 44 | geometry_pretrain: data/trained_model/deform/my_387_4v_geo 45 | -------------------------------------------------------------------------------- /scripts/tools/prepare_envmap.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | import os 3 | import cv2 4 | # fmt: on 5 | 6 | import argparse 7 | from tqdm import tqdm 8 | from glob import glob 9 | from os.path import join 10 | 11 | 12 | # fmt: off 13 | import sys 14 | sys.path.append('.') 15 | from lib.utils.log_utils import log, run 16 | from lib.utils.data_utils import load_unchanged, save_unchanged 17 | from lib.utils.parallel_utils import parallel_execution 18 | # fmt: on 19 | 20 | 21 | def resize_hdri(img_path, out_path, height=16, width=32): 22 | log(f'loading {img_path}') 23 | hdri = load_unchanged(img_path) 24 | log(f'resizing {img_path}') 25 | hdri = cv2.resize(hdri, (width, height), interpolation=cv2.INTER_AREA) 26 | log(f'saving {img_path} to {out_path}') 27 | save_unchanged(out_path, hdri) 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--data_root', default='data/lighting') 33 | parser.add_argument('--dir_8k', default='8k') 34 | parser.add_argument('--dir_16x32', default='16x32') 35 | args = parser.parse_args() 36 | img_paths = glob(join(args.data_root, args.dir_8k, "*.hdr")) + glob(join(args.data_root, args.dir_8k, "*.exr")) 37 | out_paths = [join(args.data_root, args.dir_16x32, os.path.basename(img_path)) for 
img_path in img_paths] 38 | parallel_execution(img_paths, out_paths, action=resize_hdri, print_progress=True) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /configs/synthetic_human/jody_fv_1f_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: jody_fv_1f_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/synthetic_human/jody 7 | human: jody 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/synthetic_human/jody 13 | human: jody 14 | ann_file: annots.npy 15 | split: test 16 | 17 | # Selection Configuration 18 | ratio: 1.0 19 | # prettier-ignore 20 | training_view: [ 3, 4, 5, 6, 7 ] # frontal views 21 | # prettier-ignore 22 | test_view: [0, 4, 8, 12, 15, 19] 23 | fix_material: 0 24 | begin_ith_frame: 0 25 | num_train_frame: 1 26 | num_eval_frame: 100 27 | num_render_view: 300 28 | test: 29 | view_sampler_interval: 1 30 | frame_sampler_interval: 21 31 | 32 | mesh_simp: True # match smpl? 33 | mesh_simp_face: 16384 34 | 35 | train: 36 | epoch: 100 37 | 38 | # Relighting Configuration 39 | relighting_cfg: 40 | train: 41 | epoch: 100 42 | 43 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 44 | obj_lvis: 45 | dist_th: 0.025 # this affects performance greatly, slow, good looking and smooth, fast, janky 46 | 47 | use_geometry: True # when using learned geometry, could have much higher tolerance for distance 48 | exp_name: jody_fv_1f_geo_fix_mat 49 | geometry_mesh: data/animation/deform/jody_fv_1f_geo/can_mesh.npz 50 | geometry_pretrain: data/trained_model/deform/jody_fv_1f_geo 51 | -------------------------------------------------------------------------------- /configs/mobile_stage/base_mobile_move.yaml: -------------------------------------------------------------------------------- 1 | exp_name: base_mobile_move 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/220608_163646_model3_move+001400+002600 7 | human: model3 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/220608_163646_model3_move+001400+002600 13 | human: model3 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33] # 12 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 1000 27 | num_eval_frame: 1200 28 | num_render_view: 200 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 180 32 | 33 | # Relighting Configuration 34 | relighting_cfg: 35 | exp_name: relight_mobile_move 36 | geometry_mesh: data/animation/deform/base_mobile_move/can_mesh.npz 37 | geometry_pretrain: data/trained_model/deform/base_mobile_move 38 | train: 39 | lr_table: 40 | residual_deformation_network: 1.0e-6 # base geometry should not change much 41 | signed_distance_network: 1.0e-6 42 | eikonal_loss_weight: 0.25 # smoother canonical mesh 43 | observed_eikonal_loss_weight: 0.25 # smoother residual deformation -> also smoother canonical mesh 44 | roughness_smooth_weight: 5.0e-2 45 | -------------------------------------------------------------------------------- /configs/mobile_stage/white_12v_geo.yaml: 
-------------------------------------------------------------------------------- 1 | exp_name: white_12v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/white 7 | human: model4 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/white 13 | human: model4 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 21, 10, 33, 22, 27, 17, 13, 3, 16, 20, 26, 15] # 13 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 300 25 | begin_ith_frame: 300 26 | num_train_frame: 300 27 | num_eval_frame: 600 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 80 32 | 33 | eikonal_loss_weight: 0.01 34 | observed_eikonal_loss_weight: 0.005 35 | 36 | mesh_simp: True # match smpl? 37 | mesh_simp_face: 16384 38 | novel_view_z_off: 1.5 39 | 40 | # Relighting Configuration 41 | relighting_cfg: 42 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 43 | obj_lvis: 44 | dist_th: 0.1 # this affects performance greatly, slow, good looking and smooth, fast, janky 45 | # relight_network_width: 256 46 | # relight_network_depth: 8 47 | # use_geometry: True # when using learned geometry, could have much higher tolerance for distance 48 | achro_light: True 49 | exp_name: white_12v_geo_fix_mat 50 | geometry_mesh: data/animation/deform/white_12v_geo/can_mesh.npz 51 | geometry_pretrain: data/trained_model/deform/white_12v_geo 52 | -------------------------------------------------------------------------------- /configs/mobile_stage/purple_12v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: purple_12v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/purple 7 | human: model2 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/purple 13 | human: model2 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 21, 10, 33, 22, 27, 17, 13, 3, 16, 20, 26, 15] # 13 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 600 27 | num_eval_frame: 700 28 | num_render_view: 300 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 110 32 | 33 | eikonal_loss_weight: 0.01 34 | observed_eikonal_loss_weight: 0.005 35 | 36 | mesh_simp: True # match smpl? 
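# NOTE (added commentary, not part of the original config; field meanings inferred from names and
# surrounding comments): the mobile_stage configs add a few knobs over the synthetic ones --
# mask: rvm presumably selects matting-based (RVM) segmentation masks,
# eikonal_loss_weight / observed_eikonal_loss_weight regularize the learned SDF geometry,
# novel_view_z_off offsets the synthesized novel-view camera path along z, and
# achro_light: True in relighting_cfg presumably constrains the recovered environment light to be
# achromatic (grayscale).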
37 | mesh_simp_face: 16384 38 | novel_view_z_off: 1.5 39 | 40 | # Relighting Configuration 41 | relighting_cfg: 42 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 43 | obj_lvis: 44 | dist_th: 0.1 # this affects performance greatly, slow, good looking and smooth, fast, janky 45 | # relight_network_width: 256 46 | # relight_network_depth: 8 47 | # use_geometry: True # when using learned geometry, could have much higher tolerance for distance 48 | achro_light: True 49 | exp_name: purple_12v_geo_fix_mat 50 | geometry_mesh: data/animation/deform/purple_12v_geo/can_mesh.npz 51 | geometry_pretrain: data/trained_model/deform/purple_12v_geo 52 | -------------------------------------------------------------------------------- /configs/mobile_stage/black_12v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: black_12v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data options 5 | train_dataset: 6 | data_root: data/mobile_stage/black 7 | human: model2 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/black 13 | human: model2 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Options 20 | # prettier-ignore 21 | training_view: [0, 21, 10, 33, 22, 27, 17, 13, 3, 16, 20, 26, 15] # 13 views 22 | # prettier-ignore 23 | test_view: [] # will use all view except training views 24 | fix_material: 0 25 | begin_ith_frame: 0 26 | num_train_frame: 800 # only 400 valid frames 27 | num_eval_frame: 100 28 | num_render_view: 100 29 | test: 30 | view_sampler_interval: 8 31 | frame_sampler_interval: 60 32 | 33 | eikonal_loss_weight: 0.01 34 | observed_eikonal_loss_weight: 0.005 35 | 36 | mesh_simp: True # match smpl? 37 | mesh_simp_face: 16384 38 | novel_view_z_off: 1.5 39 | 40 | # Relighting Configuration 41 | relighting_cfg: 42 | dist_th: 0.1 # closer to surface, 2cm, and this is not helping very much 43 | obj_lvis: 44 | dist_th: 0.1 # this affects performance greatly, slow, good looking and smooth, fast, janky 45 | # relight_network_width: 256 46 | # relight_network_depth: 8 47 | # use_geometry: True # when using learned geometry, could have much higher tolerance for distance 48 | achro_light: True 49 | exp_name: black_12v_geo_fix_mat 50 | geometry_mesh: data/animation/deform/black_12v_geo/can_mesh.npz 51 | geometry_pretrain: data/trained_model/deform/black_12v_geo 52 | -------------------------------------------------------------------------------- /configs/mobile_stage/xuzhen_12v_geo.yaml: -------------------------------------------------------------------------------- 1 | exp_name: xuzhen_12v_geo 2 | parent_cfg: configs/base.yaml 3 | 4 | # Data Configuration 5 | train_dataset: 6 | data_root: data/mobile_stage/xuzhen 7 | human: xuzhen 8 | ann_file: annots.npy 9 | split: train 10 | 11 | test_dataset: 12 | data_root: data/mobile_stage/xuzhen 13 | human: xuzhen 14 | ann_file: annots.npy 15 | split: test 16 | 17 | mask: rvm 18 | 19 | # Selection Configuration 20 | ratio: 1.0 21 | # prettier-ignore 22 | training_view: [0, 2, 4, 6, 7, 8, 11, 14, 15, 19, 21, 24, 27, 30] 23 | # prettier-ignore 24 | test_view: [] # use training views 25 | fix_material: 0 26 | begin_ith_frame: 0 27 | num_train_frame: 1600 28 | num_eval_frame: 2000 29 | num_render_view: 400 30 | test: 31 | view_sampler_interval: 8 32 | frame_sampler_interval: 300 33 | 34 | eikonal_loss_weight: 0.01 35 | observed_eikonal_loss_weight: 0.005 36 | 37 | mesh_simp: True # match smpl? 
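# NOTE (added commentary, not part of the original config): relative to the other mobile_stage
# subjects, this config also tunes the loss side -- resd_loss_weight constrains the residual
# deformation, lambertian: False allows view-dependent shading (the original comment notes this
# matters once views are dense enough), and albedo_smooth_weight / albedo_sparsity in
# relighting_cfg read as smoothness and sparsity regularizers on the predicted albedo; their
# exact definitions live in the training code rather than in this file.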
38 | mesh_simp_face: 16384 39 | novel_view_z_off: 1.5 40 | 41 | # Loss Configuration 42 | resd_loss_weight: 0.1 43 | lambertian: False # important for final result when view is dense enough 44 | 45 | # Relighting Configuration 46 | relighting_cfg: 47 | dist_th: 0.125 # closer to surface, 2cm, and this is not helping very much 48 | obj_lvis: 49 | dist_th: 0.125 # this affects performance greatly, slow, good looking and smooth, fast, janky 50 | # relight_network_width: 256 51 | # relight_network_depth: 8 52 | # use_geometry: True # when using learned geometry, could have much higher tolerance for distance 53 | achro_light: True 54 | albedo_smooth_weight: 5.0e-4 55 | albedo_sparsity: 5.0e-5 56 | exp_name: xuzhen_12v_geo_fix_mat 57 | geometry_mesh: data/animation/deform/xuzhen_12v_geo/can_mesh.npz 58 | geometry_pretrain: data/trained_model/deform/xuzhen_12v_geo 59 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Training mi11_1k_feat", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "train_net.py", 12 | "console": "integratedTerminal", 13 | "justMyCode": false, 14 | "env": { 15 | // "CUDA_VISIBLE_DEVICES": "0" 16 | }, 17 | "args": [ 18 | "--cfg_file", 19 | "configs/rendering/mi11_1k_feat.yaml", 20 | "exp_name", 21 | "mi11_1k_feat_test", 22 | "suppress_timer", 23 | "False", 24 | "gpus", 25 | "5,", 26 | ] 27 | }, 28 | { 29 | "name": "Python: Current File", 30 | "type": "python", 31 | "request": "launch", 32 | "program": "${file}", 33 | "console": "integratedTerminal", 34 | "justMyCode": false, 35 | }, 36 | { 37 | "name": "Training Tracking", 38 | "type": "python", 39 | "request": "launch", 40 | "program": "train_net.py", 41 | "console": "integratedTerminal", 42 | "justMyCode": false, 43 | "env": { 44 | // "CUDA_VISIBLE_DEVICES": "0" 45 | }, 46 | "args": [ 47 | "--cfg_file", 48 | "configs/rendering/base_zju_a1s0_smplh_tracking.yaml", 49 | "gpus", 50 | "1,", 51 | ] 52 | }, 53 | ] 54 | } -------------------------------------------------------------------------------- /lib/visualizers/demo_visualizer.py: -------------------------------------------------------------------------------- 1 | # This file is reused when we're performing textured rendering or just plain old rendering 2 | import os 3 | 4 | from lib.config import cfg 5 | from lib.config.config import Output 6 | from lib.utils.base_utils import dotdict 7 | from lib.utils.log_utils import log 8 | from lib.utils.data_utils import save_image 9 | from . 
import base_visualizer 10 | 11 | 12 | class Visualizer(base_visualizer.Visualizer): 13 | def prepare_result_paths(self): 14 | data_dir = f'data/novel_view/{cfg.exp_name}' 15 | motion_name = os.path.splitext(os.path.basename(cfg.test_motion))[0] 16 | 17 | if cfg.perform: 18 | img_path = f'{data_dir}/{motion_name}/{{type}}/frame{{view:04d}}_view{{frame:04d}}{cfg.vis_ext}' 19 | elif 'sfm' in cfg.test_dataset_module or 'mipnerf360' in cfg.test_dataset_module: # special treatment for sfm datasets 20 | img_path = f'{data_dir}/{{type}}/frame{{frame:04d}}_view{{view:04d}}{cfg.vis_ext}' # TODO: this is evil 21 | else: 22 | img_path = f'{data_dir}/frame_{{frame:04d}}/{{type}}/{{view:04d}}{cfg.vis_ext}' 23 | 24 | self.result_dir = os.path.dirname(img_path) 25 | self.img_path = img_path 26 | 27 | def visualize_single_type(self, output: dotdict, batch: dotdict, type: Output=Output.Rendering): 28 | img_pred = Visualizer.generate_image(output, batch, type) 29 | frame_index = batch.meta.frame_index.item() 30 | view_index = batch.meta.view_index.item() 31 | self.view_index = view_index # for generating video 32 | self.frame_index = frame_index # for generating video 33 | 34 | img_path = self.img_path.format(type=type.name.lower(), frame=frame_index, view=view_index) 35 | os.makedirs(os.path.dirname(img_path), exist_ok=True) 36 | save_image(img_path, img_pred) 37 | -------------------------------------------------------------------------------- /lib/train/scheduler.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from .optimizers.lr_scheduler import WarmupMultiStepLR, MultiStepLR, ExponentialLR, WarmupExponentialLR 3 | 4 | 5 | def make_lr_scheduler(cfg, optimizer): 6 | cfg_scheduler = cfg.train.scheduler 7 | if cfg_scheduler.type == 'multi_step': 8 | scheduler = MultiStepLR(optimizer, 9 | milestones=cfg_scheduler.milestones, 10 | gamma=cfg_scheduler.gamma) 11 | elif cfg_scheduler.type == 'exponential': 12 | scheduler = ExponentialLR(optimizer, 13 | decay_epochs=cfg_scheduler.decay_epochs, 14 | gamma=cfg_scheduler.gamma) 15 | elif cfg_scheduler.type == 'warmup_exponential': 16 | scheduler = WarmupExponentialLR(optimizer, 17 | warmup_factor=cfg_scheduler.warmup_factor, 18 | warmup_epochs=cfg_scheduler.warmup_epochs, 19 | warmup_method=cfg_scheduler.warmup_method, 20 | decay_epochs=cfg_scheduler.decay_epochs, 21 | gamma=cfg_scheduler.gamma) 22 | return scheduler 23 | 24 | 25 | def set_lr_scheduler(cfg, scheduler): 26 | cfg_scheduler = cfg.train.scheduler 27 | if cfg_scheduler.type == 'multi_step': 28 | scheduler.milestones = Counter(cfg_scheduler.milestones) 29 | elif cfg_scheduler.type == 'exponential': 30 | scheduler.decay_epochs = cfg_scheduler.decay_epochs 31 | elif cfg_scheduler.type == 'warmup_exponential': 32 | scheduler.warmup_factor = cfg_scheduler.warmup_factor 33 | scheduler.warmup_epochs = cfg_scheduler.warmup_epochs 34 | scheduler.warmup_method = cfg_scheduler.warmup_method 35 | scheduler.decay_epochs = cfg_scheduler.decay_epochs 36 | scheduler.gamma = cfg_scheduler.gamma 37 | -------------------------------------------------------------------------------- /scripts/tools/visualize_npfiles.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from functools import reduce 5 | # This script serves the purpose of reading the content and meaningfully visualize the data a numpy dict 6 | # that is, a dict whose keys are numpy arrays 7 
| 8 | import argparse 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('input', default='data/xuzhen36/talk/lbs/smpl_params.npy') 11 | parser.add_argument('-o', '--output', default='') 12 | args = parser.parse_args() 13 | 14 | if not args.output: 15 | args.output = os.path.splitext(args.input)[0] + '.xlsx' 16 | 17 | d = np.load(args.input, allow_pickle=True) 18 | if isinstance(d, np.ndarray) and np.squeeze(d).size == 1: 19 | # the outer most scope is already a dict 20 | d = d.item() 21 | assert isinstance(d, dict) 22 | elif isinstance(d, np.lib.npyio.NpzFile): 23 | d = {**d} 24 | else: 25 | # should apply an augmentation 26 | n = os.path.splitext(os.path.basename(args.input))[0] 27 | d = {n: d} 28 | 29 | 30 | w = pd.ExcelWriter(args.output) 31 | 32 | 33 | def get_indices(*shapes): 34 | inds = np.stack(np.meshgrid(*[np.arange(s) for s in shapes], indexing='ij'), axis=-1) 35 | inds = reduce(np.char.add, np.split(inds.astype(str), inds.shape[-1], axis=-1)) 36 | return inds.ravel() 37 | 38 | 39 | def traverse(d, w, key_prefix=''): 40 | for key, item in d.items(): 41 | key_full = key_prefix + '.' + key if key_prefix else key 42 | if isinstance(item, dict): 43 | traverse(item, w, key_full) 44 | else: 45 | item = np.array(item) 46 | if item.ndim == 1: 47 | item = item[:, None] 48 | # item might be a high-d np array 49 | df = pd.DataFrame(item.reshape(item.shape[0], -1), columns=get_indices(*item.shape[1:])) 50 | df.to_excel(w, key_full) 51 | 52 | # apply tree traversal 53 | traverse(d, w) 54 | print(f'writing to: {args.output}') 55 | w.close() 56 | -------------------------------------------------------------------------------- /lib/utils/prof_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from lib.utils.log_utils import log 4 | from lib.utils.base_utils import dotdict, context 5 | from torch.profiler import profile, record_function, ProfilerActivity, schedule 6 | 7 | context.prof_cfg = dotdict() 8 | context.prof_cfg.enabled = False 9 | 10 | 11 | def profiler_step(): 12 | if context.prof_cfg.enabled: 13 | context.profiler.step() 14 | 15 | 16 | def profiler_start(): 17 | if context.prof_cfg.enabled: 18 | context.profiler.start() 19 | 20 | 21 | def profiler_stop(): 22 | if context.prof_cfg.enabled: 23 | context.profiler.stop() 24 | 25 | 26 | def setup_profiling(prof_cfg): 27 | if prof_cfg.enabled: 28 | log(f"profiling results will be saved to: {prof_cfg.record_dir}", 'yellow') 29 | if prof_cfg.clear_previous: 30 | log(f'removing profiling result in: {prof_cfg.record_dir}', 'red') 31 | os.system(f'rm -rf {prof_cfg.record_dir}') 32 | profiler = profile(schedule=schedule(skip_first=prof_cfg.skip_first, 33 | wait=prof_cfg.wait, 34 | warmup=prof_cfg.warmup, 35 | active=prof_cfg.active, 36 | repeat=prof_cfg.repeat, 37 | ), 38 | activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 39 | on_trace_ready=torch.profiler.tensorboard_trace_handler(prof_cfg.record_dir), 40 | record_shapes=prof_cfg.record_shapes, 41 | profile_memory=prof_cfg.profile_memory, 42 | with_stack=prof_cfg.with_stack, # sometimes with_stack causes segmentation fault 43 | with_flops=prof_cfg.with_flops, 44 | with_modules=prof_cfg.with_modules 45 | ) 46 | context.profiler = profiler 47 | context.prof_cfg = prof_cfg 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # please install pip requirements using: 2 
| # cat requirements.txt | sed -e '/^\s*#.*$/d' -e '/^\s*$/d' | xargs -n 1 pip install 3 | # otherwise one single compilation error could eat away like 20 minutes for nothing. 4 | 5 | # other requirements 6 | yacs 7 | tqdm 8 | rich 9 | sympy 10 | pillow 11 | trimesh 12 | imageio 13 | tensorboard 14 | scikit-image 15 | scikit-learn 16 | torch-tb-profiler 17 | 18 | # other requirements not available in conda 19 | smplx 20 | pymcubes 21 | opencv-python 22 | 23 | # dev requirements 24 | h5py 25 | ninja 26 | lpips 27 | ujson 28 | pandas 29 | # for unwrapping to get StVK properly 30 | kornia 31 | jupyter 32 | autopep8 33 | pyntcloud 34 | matplotlib 35 | ruamel.yaml 36 | commentjson 37 | 38 | # external dependency: easymocap-public (this repo is not publicly available yet) 39 | # for easymocap's vposer: human_pose_prior, this looks like my DotDict implementation... just way more complex 40 | dotmap 41 | # for easymocap loading of SMPL (maybe all pickle loading of SMPL?) 42 | chumpy 43 | mediapipe 44 | func_timeout 45 | pycocotools 46 | spconv-cu116 47 | tensorboardX 48 | git+https://github.com/mmatl/pyopengl 49 | git+https://github.com/nghorbani/human_body_prior 50 | git+https://github.com/zju3dv/EasyMocap 51 | 52 | # !: prone to change pytorch version, please install these on demand and manually 53 | # functorch 54 | torch-scatter 55 | 56 | # https://storage.googleapis.com/open3d-releases-master/python-wheels/open3d-0.16.0-cp310-cp310-manylinux_2_27_x86_64.whl # TODO: fix the quirky install 57 | # http://www.open3d.org/docs/release/getting_started.html (install the development version from here if the previous link is expired and python is too new) 58 | # python3.10 support for open3d finally here 59 | # if failed to install open3d (even when installing from latest release?), try to skip it using 60 | # pip install $(grep -v '^ *#\|^open3d' requirements.txt | grep .) 61 | # open3d 62 | open3d 63 | 64 | # pip install $(grep -v '^ *#\|^.*open3d\|^torch-sparse\|^torch-geometric\|^.*cholespy\|^.*pytorch3d\|^.*pyopengl' requirements.txt | grep .)
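# NOTE (added commentary, not part of the original requirements): the grep pipelines above all do
# the same thing -- strip comments, blank lines and the known-troublesome packages, then hand the
# rest to pip one package at a time so a single build failure does not abort the whole install.
# An equivalent plain-shell form (hypothetical, mirroring the sed/xargs command at the top):
#   for p in $(grep -v '^ *#' requirements.txt | grep .); do pip install "$p"; done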
65 | -------------------------------------------------------------------------------- /lib/utils/parallel_utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from typing import Callable 3 | from multiprocessing.pool import ThreadPool 4 | 5 | 6 | def parallel_execution(*args, action: Callable, num_processes=16, print_progress=False, sequential=False, **kwargs): 7 | # NOTE: we expect first arg / or kwargs to be distributed 8 | # NOTE: print_progress arg is reserved 9 | 10 | def get_valid_arg(args, kwargs): return args[0] if isinstance(args[0], list) else next(iter(kwargs.values())) # TODO: search through them all 11 | 12 | def get_action_args(valid_arg, args, kwargs, i): 13 | action_args = [(arg[i] if isinstance(arg, list) and len(arg) == len(valid_arg) else arg) for arg in args] 14 | action_kwargs = {key: (kwargs[key][i] if isinstance(kwargs[key], list) and len(kwargs[key]) == len(valid_arg) else kwargs[key]) for key in kwargs} 15 | return action_args, action_kwargs 16 | 17 | def maybe_tqdm(x): return tqdm(x) if print_progress else x 18 | 19 | if not sequential: 20 | # Create ThreadPool 21 | pool = ThreadPool(processes=num_processes) 22 | 23 | # Spawn threads 24 | results = [] 25 | asyncs = [] 26 | valid_arg = get_valid_arg(args, kwargs) 27 | for i in range(len(valid_arg)): 28 | action_args, action_kwargs = get_action_args(valid_arg, args, kwargs, i) 29 | async_result = pool.apply_async(action, action_args, action_kwargs) 30 | asyncs.append(async_result) 31 | 32 | # Join threads and get return values 33 | for async_result in maybe_tqdm(asyncs): 34 | results.append(async_result.get()) # will sync the corresponding thread 35 | pool.close() 36 | pool.join() 37 | return results 38 | else: 39 | results = [] 40 | valid_arg = get_valid_arg(args, kwargs) 41 | for i in maybe_tqdm(range(len(valid_arg))): 42 | action_args, action_kwargs = get_action_args(valid_arg, args, kwargs, i) 43 | async_result = action(*action_args, **action_kwargs) 44 | results.append(async_result) 45 | return results 46 | -------------------------------------------------------------------------------- /lib/datasets/mesh_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from lib.config import cfg 6 | from lib.utils.data_utils import project, read_mask_by_img_path 7 | from lib.utils.base_utils import dotdict 8 | 9 | from . 
import pose_dataset 10 | 11 | 12 | class Dataset(pose_dataset.Dataset): 13 | def __init__(self, data_root, human, ann_file, split, **kwargs): 14 | super(Dataset, self).__init__(data_root, human, ann_file, split) 15 | 16 | def load_image_size(self): 17 | pass 18 | 19 | def load_view(self): 20 | self.view = [0, ] 21 | self.num_cams = 1 # only one camera for mesh extraction 22 | 23 | def load_camera(self): 24 | pass 25 | 26 | def get_indices(self, index): 27 | latent_index = index 28 | i = latent_index 29 | if latent_index == -1: 30 | i = 0 # load data of first frame if index is -1 31 | frame_index = self.i + i * self.i_intv # recompute frame index for i 32 | return latent_index, frame_index, -1, -1 33 | 34 | def __getitem__(self, index): # TODO: might get -1 in optimization, but doesn't matter now 35 | latent_index, frame_index, _, _ = self.get_indices(index) 36 | i = frame_index 37 | 38 | # load SMPL & pose & human related parameters 39 | ret = self.get_blend(i) 40 | 41 | voxel_size = cfg.voxel_size 42 | if cfg.vis_can_mesh or cfg.vis_tpose_mesh: 43 | bounds = ret.tbounds 44 | else: 45 | bounds = ret.wbounds 46 | x = torch.arange(bounds[0, 0], bounds[1, 0] + voxel_size[0], voxel_size[0]) 47 | y = torch.arange(bounds[0, 1], bounds[1, 1] + voxel_size[1], voxel_size[1]) 48 | z = torch.arange(bounds[0, 2], bounds[1, 2] + voxel_size[2], voxel_size[2]) 49 | pts = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).float().numpy() 50 | 51 | meta = { 52 | 'pts': pts, 53 | 'latent_index': latent_index, 54 | 'frame_index': frame_index, 55 | 'view_index': 0, # no view 56 | } 57 | ret.update(meta) 58 | ret.meta.update(meta) 59 | 60 | return dotdict(ret) 61 | -------------------------------------------------------------------------------- /lib/visualizers/mesh_visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lib.config import cfg 4 | from lib.utils.data_utils import export_dotdict, export_mesh 5 | from lib.utils.log_utils import log, print_colorful_stacktrace, run 6 | from . 
import base_visualizer 7 | 8 | 9 | class Visualizer(base_visualizer.Visualizer): 10 | def __init__(self): 11 | self.result_dir = 'data/animation/{}/{}'.format(cfg.task, cfg.exp_name) 12 | log('the results are saved at {}'.format(self.result_dir), 'yellow') 13 | 14 | def visualize(self, output, batch): 15 | result_dir = self.result_dir 16 | 17 | if (batch.meta.latent_index.item() == -1 and cfg.vis_tpose_mesh and cfg.track_tpose_mesh) or cfg.vis_can_mesh: 18 | os.makedirs(result_dir, exist_ok=True) 19 | meta_path = os.path.join(result_dir, 'can_mesh.npz') 20 | mesh_path = os.path.join(result_dir, 'can_mesh.ply') 21 | elif cfg.vis_posed_mesh or cfg.vis_tpose_mesh: 22 | if cfg.track_tpose_mesh: 23 | result_dir = os.path.join(result_dir, 'track_mesh') 24 | elif cfg.vis_posed_mesh: 25 | result_dir = os.path.join(result_dir, 'posed_mesh') 26 | elif cfg.vis_tpose_mesh: 27 | result_dir = os.path.join(result_dir, 'tpose_mesh') 28 | 29 | os.makedirs(result_dir, exist_ok=True) 30 | frame_index = batch.meta.frame_index.item() 31 | view_index = batch.meta.view_index.item() 32 | meta_path = os.path.join(result_dir, '{:04d}.npz'.format(frame_index)) 33 | mesh_path = os.path.join(result_dir, '{:04d}.ply'.format(frame_index)) 34 | 35 | export_dotdict(output, meta_path) 36 | export_mesh(output.verts, output.faces, filename=mesh_path) 37 | 38 | try: 39 | run(f'blender --background --python-expr \"import sys; sys.path.append(\'.\'); from lib.utils.blender_utils import replace_weights; replace_weights(\'{meta_path}\', \'{meta_path}\')\"') 40 | except Exception as e: 41 | log('blender not found or returned error, will use SMPL blend weights and they might be janky. Maybe try to install blender from https://www.blender.org/download/. Use the log above for more info.', 'red') 42 | 43 | def prepare_result_paths(self): pass 44 | 45 | def update_result_dir(self): pass 46 | 47 | def summarize(self): pass 48 | -------------------------------------------------------------------------------- /lib/train/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from termcolor import colored 4 | from lib.utils.log_utils import log 5 | from .optimizers.radam import RAdam 6 | from torch.distributed.optim import ZeroRedundancyOptimizer 7 | 8 | 9 | _optimizer_factory = { 10 | 'adam': torch.optim.Adam, 11 | 'radam': RAdam, 12 | 'sgd': torch.optim.SGD 13 | } 14 | 15 | 16 | def make_optimizer(cfg, net: nn.Module, lr=None, eps=None, weight_decay=None): 17 | params = [] 18 | eps = cfg.train.eps if eps is None else eps 19 | lr = cfg.train.lr if lr is None else lr 20 | weight_decay = cfg.train.weight_decay if weight_decay is None else weight_decay 21 | 22 | special = [] 23 | for key, value in net.named_parameters(): 24 | if not value.requires_grad: 25 | continue 26 | 27 | v_lr = lr 28 | v_eps = eps 29 | v_weight_decay = weight_decay 30 | 31 | keys = key.split('.') 32 | for item in keys: 33 | if item in cfg.train.lr_table: 34 | v_lr = cfg.train.lr_table[item] 35 | special.append(f'{key}: {colored(f"{v_lr:g}", "magenta")}') 36 | break 37 | for item in keys: 38 | if item in cfg.train.eps_table: 39 | v_eps = cfg.train.eps_table[item] 40 | break 41 | for item in keys: 42 | if item in cfg.train.weight_decap_table: 43 | v_weight_decay = cfg.train.weight_decap_table[item] 44 | break 45 | params += [{"params": [value], "lr": v_lr, "weight_decay": v_weight_decay, 'eps': v_eps}] 46 | 47 | log(f'default learning rate: {colored(f"{lr:g}", "magenta")}') 48 | if 
len(special): 49 | log(f'special learning rate loaded from lr table: \n' + '\n'.join(special)) 50 | 51 | if 'adam' in cfg.train.optim: 52 | # if cfg.distributed: 53 | # optimizer = ZeroRedundancyOptimizer(params, optimizer_class=_optimizer_factory[cfg.train.optim], lr=lr, weight_decay=weight_decay) 54 | # else: 55 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, weight_decay=weight_decay) 56 | else: 57 | # if cfg.distributed: 58 | # optimizer = ZeroRedundancyOptimizer(params, optimizer_class=_optimizer_factory[cfg.train.optim], lr=lr) 59 | # else: 60 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, momentum=0.9) 61 | 62 | return optimizer 63 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import time 2 | import tqdm 3 | import torch 4 | import numpy as np 5 | 6 | from lib.utils.net_utils import load_network 7 | from lib.utils.data_utils import to_cuda 8 | from lib.evaluators import make_evaluator 9 | from lib.visualizers import make_visualizer 10 | from lib.networks.renderer import make_renderer 11 | from lib.datasets import make_data_loader 12 | from lib.networks import make_network 13 | from lib.config import cfg, args 14 | from lib.utils.log_utils import log 15 | 16 | import cv2 17 | cv2.setNumThreads(1) 18 | cfg.train.num_workers = 0 # no multi-process dataloading needed when visualizing 19 | 20 | 21 | @torch.no_grad() 22 | def run_dataset(): 23 | from lib.datasets import make_data_loader 24 | import tqdm 25 | 26 | data_loader = make_data_loader(cfg, is_train=False) 27 | for batch in tqdm.tqdm(data_loader): 28 | pass 29 | 30 | 31 | @torch.no_grad() 32 | def run_network(): 33 | network = make_network(cfg).cuda() 34 | load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch) 35 | network.eval() 36 | renderer = make_renderer(cfg, network) 37 | 38 | data_loader = make_data_loader(cfg, is_train=False) 39 | total_time = 0 40 | for batch in tqdm.tqdm(data_loader): 41 | batch = to_cuda(batch) 42 | 43 | torch.cuda.synchronize() 44 | start = time.time() 45 | output = renderer.render(batch) 46 | torch.cuda.synchronize() 47 | total_time += time.time() - start 48 | 49 | log(total_time / len(data_loader)) 50 | 51 | 52 | @torch.no_grad() 53 | def run_evaluate(): 54 | network = make_network(cfg).cuda() 55 | load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch) 56 | network.eval() 57 | 58 | data_loader = make_data_loader(cfg, is_train=False) 59 | renderer = make_renderer(cfg, network) 60 | evaluator = make_evaluator(cfg) 61 | for batch in tqdm.tqdm(data_loader): 62 | batch = to_cuda(batch) 63 | output = renderer.render(batch) 64 | evaluator.evaluate(output, batch) 65 | evaluator.summarize() 66 | 67 | 68 | @torch.no_grad() 69 | def run_visualize(): 70 | network = make_network(cfg).cuda() 71 | load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch) 72 | network.eval() 73 | diffs = [] 74 | 75 | data_loader = make_data_loader(cfg, is_train=False) 76 | renderer = make_renderer(cfg, network) 77 | visualizer = make_visualizer(cfg) 78 | for batch in tqdm.tqdm(data_loader): 79 | batch = to_cuda(batch, 'cuda') 80 | output = renderer.render(batch) 81 | 82 | if 'diff' in output: 83 | diffs.append(output.diff) 84 | del output.diff 85 | visualizer.visualize(output, batch) 86 | visualizer.summarize() 87 | 88 | if len(diffs): 89 | log(f'###################{cfg.exp_name}###################', 'green') 90 | log(f'Net work rendering 
time: {np.mean(diffs)}', 'green') 91 | 92 | 93 | if __name__ == '__main__': 94 | try: 95 | globals()['run_' + args.type]() 96 | except: 97 | import pdbr 98 | pdbr.post_mortem() 99 | -------------------------------------------------------------------------------- /lib/visualizers/light_visualizer.py: -------------------------------------------------------------------------------- 1 | # This file is reused when we're performing textured rendering or just plain old rendering 2 | import os 3 | import torch 4 | from os.path import join 5 | 6 | from lib.config import cfg 7 | from lib.config.config import Output 8 | from lib.utils.base_utils import dotdict 9 | from lib.utils.data_utils import save_image, to_cuda 10 | from lib.utils.parallel_utils import parallel_execution 11 | from . import base_visualizer 12 | 13 | 14 | class Visualizer(base_visualizer.Visualizer): 15 | def prepare_result_paths(self): 16 | data_dir = f'data/novel_light/{cfg.exp_name}' 17 | data_dir = join(data_dir, cfg.extra_prefix) if cfg.extra_prefix else data_dir # differentiate between video and evals 18 | motion_name = os.path.splitext(os.path.basename(cfg.test_motion))[0] 19 | 20 | if len(cfg.test_view) == 1: 21 | img_path = f'{data_dir}/view_{{view:04d}}/{{type}}/{{frame:04d}}{cfg.vis_ext}' 22 | elif cfg.num_eval_frame == 1: 23 | img_path = f'{data_dir}/frame{{frame:04d}}/{{type}}/{{view:04d}}{cfg.vis_ext}' 24 | else: 25 | img_path = f'{data_dir}/{motion_name}/{{type}}/frame{{frame:04d}}_view{{view:04d}}{cfg.vis_ext}' 26 | 27 | self.result_dir = os.path.dirname(img_path) 28 | self.img_path = img_path 29 | 30 | def visualize_single_type(self, output: dotdict, batch: dotdict, type: Output = Output.Rendering): 31 | frame_index = batch.meta.frame_index.item() 32 | view_index = batch.meta.view_index.item() 33 | self.view_index = view_index # for generating video 34 | self.frame_index = frame_index # for generating video 35 | 36 | img_path = self.img_path.format(type=type.name.lower(), frame=frame_index, view=view_index) 37 | os.makedirs(os.path.dirname(img_path), exist_ok=True) 38 | 39 | img_preds = [] 40 | img_paths = [] 41 | for name, out in output.items(): 42 | out = to_cuda(out, batch.latent_index.device) # only add light probe involves interaction between out and batch 43 | img_pred = Visualizer.generate_image(out, batch, type) 44 | if cfg.vis_rotate_light: 45 | img_path_light = os.path.join(os.path.dirname(img_path), name + "_" + os.path.basename(img_path)) 46 | else: 47 | img_path_light = os.path.join(os.path.dirname(img_path), name, os.path.basename(img_path)) 48 | img_paths.append(img_path_light) 49 | img_preds.append(img_pred) 50 | 51 | parallel_execution(img_paths, img_preds, action=save_image) 52 | 53 | def summarize(self): 54 | for type in self.types: 55 | result_dir = os.path.dirname(self.img_path).format(type=type.name.lower(), view=self.view_index, frame=self.frame_index) 56 | for light in cfg.test_light: 57 | result_str = join(result_dir, light) 58 | if cfg.vis_rotate_light: 59 | result_str = f'"{result_str}-*{cfg.vis_ext}"' 60 | else: 61 | result_str = f'"{result_str}/*{cfg.vis_ext}"' 62 | Visualizer.generate_video(result_str) 63 | -------------------------------------------------------------------------------- /lib/evaluators/mesh_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import trimesh 4 | import numpy as np 5 | from PIL import Image 6 | import trimesh.sample 7 | import trimesh.proximity 8 | 9 | from lib.utils.log_utils 
import log 10 | 11 | 12 | class Evaluator: 13 | def __init__(self) -> None: 14 | self.p2ss = [] 15 | self.chamfers = [] 16 | self.mesh_eval = MeshEvaluator() 17 | 18 | def evaluate(self, output, batch): 19 | src_mesh = output["mesh"] 20 | self.mesh_eval.set_src_mesh(src_mesh) 21 | self.mesh_eval.set_tgt_mesh(batch["obj"][0]) 22 | chamfer = self.mesh_eval.get_chamfer_dist() 23 | p2s = self.mesh_eval.get_surface_dist() 24 | self.chamfers.append(chamfer) 25 | self.p2ss.append(p2s) 26 | log("Chamfer: {}, P2S: {}".format(chamfer, p2s)) 27 | 28 | def summarize(self): 29 | self.p2ss = np.array(self.p2ss) 30 | self.chamfers = np.array(self.chamfers) 31 | np.save("p2s.npy", self.p2ss) 32 | np.save("chamfer.npy", self.chamfers) 33 | pass 34 | 35 | 36 | class MeshEvaluator: 37 | """ 38 | From https://github.com/facebookresearch/pifuhd/blob/master/lib/evaluator.py 39 | """ 40 | _normal_render = None 41 | 42 | def __init__(self, scale_factor=1.0, offset=0): 43 | self.scale_factor = scale_factor 44 | self.offset = offset 45 | pass 46 | 47 | def set_mesh(self, src_path, tgt_path): 48 | self.src_mesh = trimesh.load(src_path) 49 | self.tgt_mesh = trimesh.load(tgt_path) 50 | 51 | def apply_registration(self): 52 | transform, _ = trimesh.registration.mesh_other(self.src_mesh, 53 | self.tgt_mesh) 54 | self.src_mesh.apply_transform(transform) 55 | 56 | def set_src_mesh(self, mesh): 57 | self.src_mesh = mesh 58 | 59 | def set_tgt_mesh(self, tgt_path: str): 60 | self.tgt_mesh = trimesh.load(tgt_path) 61 | 62 | def get_chamfer_dist(self, num_samples=1000): 63 | # breakpoint() 64 | # Chamfer 65 | src_surf_pts, _ = trimesh.sample.sample_surface( 66 | self.src_mesh, num_samples) 67 | # self.src_mesh.show() 68 | tgt_surf_pts, _ = trimesh.sample.sample_surface( 69 | self.tgt_mesh, num_samples) 70 | 71 | _, src_tgt_dist, _ = trimesh.proximity.closest_point( 72 | self.tgt_mesh, src_surf_pts) 73 | _, tgt_src_dist, _ = trimesh.proximity.closest_point( 74 | self.src_mesh, tgt_surf_pts) 75 | 76 | src_tgt_dist[np.isnan(src_tgt_dist)] = 0 77 | tgt_src_dist[np.isnan(tgt_src_dist)] = 0 78 | 79 | src_tgt_dist = src_tgt_dist.mean() 80 | tgt_src_dist = tgt_src_dist.mean() 81 | 82 | chamfer_dist = (src_tgt_dist + tgt_src_dist) / 2 83 | 84 | return chamfer_dist 85 | 86 | def get_surface_dist(self, num_samples=10000): 87 | # P2S 88 | src_surf_pts, _ = trimesh.sample.sample_surface( 89 | self.src_mesh, num_samples) 90 | 91 | _, src_tgt_dist, _ = trimesh.proximity.closest_point( 92 | self.tgt_mesh, src_surf_pts) 93 | 94 | src_tgt_dist[np.isnan(src_tgt_dist)] = 0 95 | 96 | src_tgt_dist = src_tgt_dist.mean() 97 | 98 | return src_tgt_dist 99 | -------------------------------------------------------------------------------- /lib/datasets/demo_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | 5 | from lib.config import cfg 6 | from lib.utils.log_utils import log 7 | from lib.utils.render_utils import gen_path 8 | from lib.utils.data_utils import get_rays_within_bounds 9 | from lib.datasets import pose_dataset 10 | 11 | 12 | class Dataset(pose_dataset.Dataset): 13 | def __init__(self, data_root, human, ann_file, split): 14 | super(Dataset, self).__init__(data_root, human, ann_file, split) 15 | self.load_render() 16 | 17 | def load_render(self): 18 | self.render_w2c = gen_path(self.RT, cfg.novel_view_center, cfg.novel_view_z_off) 19 | self.num_cams = len(self.render_w2c) 20 | self.K = self.Ks[0] 21 | self.K[0, 0] *= cfg.novel_view_ixt_ratio 
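# NOTE (added commentary, not in the original source): novel_view_ixt_ratio scales the
# focal-length entries of the first camera's intrinsics (fx on the line above, fy on the line
# below), effectively zooming the synthesized novel-view camera; the principal point is left
# unchanged here.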
22 | self.K[1, 1] *= cfg.novel_view_ixt_ratio 23 | if len(self.render_w2c) > len(self.motion.poses) and cfg.perform: 24 | log(f'will render {len(self.render_w2c)} views on only {len(self.motion.poses)} poses', 'yellow') 25 | 26 | def get_indices(self, index): 27 | if cfg.perform: # ? BUG 28 | latent_index = index 29 | else: 30 | latent_index = 0 31 | view_index = index 32 | frame_index = self.i + latent_index * self.i_intv # recompute frame index for i 33 | cam_index = view_index 34 | return latent_index, frame_index, view_index, cam_index 35 | 36 | def __getitem__(self, index): 37 | latent_index, frame_index, view_index, cam_index = self.get_indices(index) 38 | ret = self.get_blend(frame_index) 39 | 40 | # reduce the image resolution by ratio 41 | if cfg.H <= 0 or cfg.W <= 0: 42 | img_path = os.path.join(self.data_root, self.annots['ims'][0]['ims'][0]) # need to get H, W 43 | img = imageio.imread(img_path) 44 | H, W = img.shape[:2] 45 | H, W = int(H * cfg.ratio), int(W * cfg.ratio) 46 | K = self.K 47 | else: 48 | H, W = cfg.H, cfg.W 49 | K = np.zeros((3, 3), dtype=np.float32) 50 | K[2, 2] = 1 51 | K[0, 0] = H * cfg.novel_view_ixt_ratio 52 | K[1, 1] = H * cfg.novel_view_ixt_ratio 53 | K[0, 2] = H / 2 54 | K[1, 2] = H / 2 55 | 56 | RT = self.render_w2c[view_index] 57 | R, T = RT[:3, :3], RT[:3, 3:] 58 | ray_o, ray_d, near, far, mask_at_box = get_rays_within_bounds(H, W, K, R, T, ret.wbounds) 59 | 60 | # store camera parameters 61 | meta = { 62 | "cam_K": K, 63 | "cam_R": R, 64 | "cam_T": T, 65 | "cam_RT": np.concatenate([R, T], axis=1), 66 | 'H': H, 67 | 'W': W, 68 | 'RT': self.RT, 69 | 'Ks': self.Ks, 70 | } 71 | ret.update(meta) 72 | ret.meta.update(meta) 73 | 74 | # store ray data 75 | meta = { 76 | 'ray_o': ray_o, 77 | 'ray_d': ray_d, 78 | 'near': near, 79 | 'far': far, 80 | 'mask_at_box': mask_at_box, 81 | } 82 | ret.update(meta) 83 | 84 | # store index data 85 | meta = { 86 | 'latent_index': latent_index, 87 | 'frame_index': frame_index, 88 | 'view_index': view_index, 89 | } 90 | ret.update(meta) 91 | ret.meta.update(meta) 92 | return ret 93 | 94 | def __len__(self): 95 | return len(self.render_w2c) 96 | -------------------------------------------------------------------------------- /scripts/tools/prepare_config.py: -------------------------------------------------------------------------------- 1 | # The configuration system used in this project: yacs (yaml based) 2 | # does not provide a good enough support for loading multiple parent configuration files 3 | # however, the baseline experiments might result in bloated configs 4 | # we need two different kinds of config source: data and experiment 5 | 6 | import os 7 | import ruamel.yaml as yaml 8 | 9 | from glob import glob 10 | from os.path import join 11 | import argparse 12 | 13 | 14 | def walk_config(exp, data, exp_name, data_name, exp_keys): 15 | for key in exp_keys: 16 | if key in exp and key in data: 17 | if isinstance(exp[key], dict) and isinstance(data[key], dict): 18 | walk_config(exp[key], data[key], exp_name, data_name, exp_keys) 19 | elif isinstance(exp[key], str) and isinstance(data[key], str): 20 | data[key] = exp[key].replace(exp_name, data_name) 21 | else: 22 | raise NotImplementedError('Unsupported config type to replace') 23 | 24 | 25 | def main(): 26 | 27 | exp_keys = [ 28 | 'relighting_cfg', 29 | 'exp_name', 30 | 'parent_cfg', 31 | 'geometry_mesh', 32 | 'geometry_pretrain', 33 | ] # other keys are data related keys (shared across experiments) 34 | datasets = ['mobile_stage', 'synthetic_human'] 35 | experiments = 
['nerf', 'neuralbody', 'brute'] 36 | data_file_prefix = 'base' # we define data entries here 37 | exp_file_template = 'configs/synthetic_human/base_synthetic_jody.yaml' # this means we've defined exp entries on jody 38 | configs_root = 'configs' 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--exp_keys', nargs='+', default=exp_keys) 42 | parser.add_argument('--experiments', nargs='+', default=experiments) 43 | parser.add_argument('--datasets', nargs='+', default=datasets) 44 | parser.add_argument('--data_file_prefix', type=str, default=data_file_prefix) 45 | parser.add_argument('--exp_file_template', type=str, default=exp_file_template) 46 | parser.add_argument('--configs_root', type=str, default=configs_root) 47 | args = parser.parse_args() 48 | exp_keys = args.exp_keys 49 | datasets = args.datasets 50 | experiments = args.experiments 51 | data_file_prefix = args.data_file_prefix 52 | exp_file_template = args.exp_file_template 53 | configs_root = args.configs_root 54 | 55 | for dataset in datasets: 56 | data_files = glob(join(configs_root, dataset, f'{data_file_prefix}*')) 57 | for experiment in experiments: 58 | exp_files = [exp_file_template.replace(data_file_prefix, experiment) for f in data_files] 59 | for data_file, exp_file in zip(data_files, exp_files): 60 | exp_name = os.path.splitext(exp_file)[0].split('_') # something like synthetic_jody 61 | exp_name = '_'.join(exp_name[-2:]) 62 | data_name = os.path.splitext(data_file)[0].split('_') # something like jody / josh 63 | data_name = '_'.join(data_name[-2:]) 64 | out_file = data_file.replace(data_file_prefix, experiment) 65 | 66 | exp = yaml.round_trip_load(open(exp_file)) 67 | data = yaml.round_trip_load(open(data_file)) 68 | 69 | # inplace modification 70 | walk_config(exp, data, exp_name, data_name, exp_keys) 71 | yaml.round_trip_dump(data, open(out_file, 'w')) 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /lib/utils/base_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, TypeVar, Union, Iterable, Callable 2 | # these are generic type vars to tell mapping to accept any type vars when creating a type 3 | KT = TypeVar("KT") # key type 4 | VT = TypeVar("VT") # value type 5 | 6 | 7 | class dotdict(dict, Mapping[KT, VT]): 8 | """ 9 | a dictionary that supports dot notation 10 | as well as dictionary access notation 11 | usage: d = make_dotdict() or d = make_dotdict{'val1':'first'}) 12 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 13 | get attributes: d.val2 or d['val2'] 14 | """ 15 | 16 | def update(self, dct=None, **kwargs): 17 | if dct is None: 18 | dct = kwargs 19 | elif isinstance(dct, Mapping): 20 | dct.update(kwargs) 21 | else: 22 | raise TypeError("dct must be a mapping") 23 | for k, v in dct.items(): 24 | if k in self: 25 | target_type = type(self[k]) 26 | if not isinstance(v, target_type): 27 | # NOTE: bool('False') will be True 28 | if target_type == bool and isinstance(v, str): 29 | dct[k] = v == 'True' 30 | else: 31 | dct[k] = target_type(v) 32 | dict.update(self, dct) 33 | 34 | # def __hash__(self): 35 | # # return hash(''.join([str(self.values().__hash__())])) 36 | # return super(dotdict, self).__hash__() 37 | 38 | # def __init__(self, *args, **kwargs): 39 | # super(dotdict, self).__init__(*args, **kwargs) 40 | 41 | """ 42 | Uncomment following lines and 43 | comment out __getattr__ = dict.__getitem__ to get feature: 44 
| 45 | returns empty numpy array for undefined keys, so that you can easily copy things around 46 | TODO: potential caveat, harder to trace where this is set to np.array([], dtype=np.float32) 47 | """ 48 | 49 | def __getitem__(self, key): 50 | try: 51 | return dict.__getitem__(self, key) 52 | except KeyError as e: 53 | raise AttributeError(e) 54 | # MARK: Might encounter exception in newer version of pytorch 55 | # Traceback (most recent call last): 56 | # File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 245, in _feed 57 | # obj = _ForkingPickler.dumps(obj) 58 | # File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps 59 | # cls(buf, protocol).dump(obj) 60 | # KeyError: '__getstate__' 61 | # MARK: Because you allow your __getattr__() implementation to raise the wrong kind of exception. 62 | # FIXME: not working typing hinting code 63 | __getattr__: Callable[..., 'torch.Tensor'] = __getitem__ # type: ignore # overidden dict.__getitem__ 64 | __getattribute__: Callable[..., 'torch.Tensor'] # type: ignore 65 | # __getattr__ = dict.__getitem__ 66 | __setattr__ = dict.__setitem__ 67 | __delattr__ = dict.__delitem__ 68 | 69 | 70 | class default_dotdict(dotdict): 71 | def __init__(self, type=object, *arg, **kwargs): 72 | super().__init__(*arg, **kwargs) 73 | dict.__setattr__(self, 'type', type) 74 | 75 | def __getitem__(self, key): 76 | try: 77 | return super().__getitem__(key) 78 | except (AttributeError, KeyError) as e: 79 | super().__setitem__(key, dict.__getattribute__(self, 'type')()) 80 | return super().__getitem__(key) 81 | 82 | 83 | context = dotdict() 84 | -------------------------------------------------------------------------------- /lib/utils/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from lib.utils.net_utils import MLP, GradModule 4 | from lib.networks.embedder import get_embedder 5 | 6 | from lib.config import cfg 7 | 8 | 9 | def parameterize(x: torch.Tensor, v: torch.Tensor): 10 | s = (-x * v).sum(dim=-1, keepdim=True) # closest distance along ray to origin 11 | o = s * v + x # closest point along ray to origin 12 | return o, s 13 | 14 | 15 | class DirectionalDistance(GradModule): 16 | def __init__(self, 17 | xyz_res=8, # use middle resolution? 18 | view_res=8, # are there still view encoding needs? 19 | cond_dim=cfg.cond_dim, 20 | ): 21 | super(DirectionalDistance, self).__init__() 22 | # our directional distance module for fast sphere tracing (pre-computation) 23 | # before learning them in a brute-force way, we need to analyze the properties of these modules 24 | # for directional distance, occ along the same ray should stay the same, we might need a specific network structure to ensure that 25 | # for directional intersection, distance to intersection point along the same ray should be constrained by a ray equation, too needs special structure 26 | # is it possible to define a neural field parameterized by ray instead of points and directions? 
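# NOTE (added, not part of the original file): a concrete sketch of the ray parameterization computed by parameterize() above.
# For a ray x + t * v with unit direction v, the closest point to the world origin lies at t = s = -(x . v),
# giving o = s * v + x with o . v == 0. Every point on the same ray maps to the same (o, v) pair, so a field
# that only consumes (o, v) is constant along the ray by construction.
# Example: x = (1, 0, 2), v = (0, 0, 1) -> s = -2, o = (1, 0, 0); x = (1, 0, 5) on the same ray -> s = -5, o = (1, 0, 0).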
27 | # * a good enough parameterization should be: find closest points along ray, compute through network and then convert 28 | 29 | # this module stores the closest distance from a point in space along a ray direction 30 | self.xyz_embedder, xyz_dim = get_embedder(xyz_res, 3) # no parameters 31 | self.view_embedder, view_dim = get_embedder(view_res, 3) # no parameters 32 | # self.mlp = MLP(input_ch=xyz_dim + view_dim + cond_dim, W=256, D=8, out_ch=1 + 1, actvn=nn.Softplus(), out_actvn=nn.Identity()) 33 | self.directional_distance = MLP(input_ch=xyz_dim + view_dim + cond_dim, W=256, D=8, out_ch=1, actvn=nn.Softplus(), out_actvn=nn.Identity()) 34 | # this module stores the surface intersection point distance from a point in space along a ray direction 35 | self.directional_intersection = MLP(input_ch=xyz_dim + view_dim + cond_dim, W=256, D=8, out_ch=1, actvn=nn.Softplus(), out_actvn=nn.Identity()) 36 | # self.intersection_probability = MLP(input_ch=xyz_dim + view_dim + cond_dim, W=256, D=8, out_ch=1, actvn=nn.Softplus(), out_actvn=nn.Sigmoid()) 37 | 38 | def forward(self, x: torch.Tensor, v: torch.Tensor, c: torch.Tensor): 39 | 40 | # maybe expand condition vector to match the ray dimension 41 | if c.ndim == 2: 42 | c = c[:, None].expand(*v.shape[:2], -1) # B, D -> B, P, D 43 | 44 | # find parameterization for a particular ray 45 | o, s = parameterize(x, v) # closest point to the origin and the signed distance to it along the ray 46 | 47 | # forward through the network 48 | ebd_o = self.xyz_embedder(o) 49 | ebd_v = self.view_embedder(v) 50 | input = torch.cat([ebd_o, ebd_v, c], dim=-1) 51 | # out = self.mlp(input) 52 | # dd, di = out.split([1, 1], dim=-1) 53 | dd = self.directional_distance(input) 54 | di = self.directional_intersection(input) 55 | # dd, di, pi = out.split([1, 1, 1], dim=-1) 56 | dd = dd.tanh() # one meter for closest distance should be enough?
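# NOTE (added): tanh bounds the predicted closest distance dd to (-1, 1), i.e. roughly one meter of slack around the ray;
# di below is bounded to (-2 * clip_far, 2 * clip_far) so it can reach the far plane, and is then shifted by s so the
# intersection distance is measured from the query point x rather than from the canonical ray point o.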
57 | di = di.tanh() * cfg.clip_far * 2 # larger range for intersection distance (to cover far plane) 58 | # pi = pi.sigmoid() 59 | 60 | # intersection_mask = pi > 0.5 # MARK: GRAD 61 | # dd = ~intersection_mask * dd # not intersection -> use original value, intersection -> zero 62 | # di = pi * di + ~intersection_mask * cfg.clip_far # intersection -> use original value, not intersection -> use far 63 | di = di + s # plus the distance along ray to origin 64 | 65 | return dd, di 66 | -------------------------------------------------------------------------------- /lib/utils/sem_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from functools import lru_cache 5 | from lib.utils.net_utils import linear_gather 6 | 7 | # SCHP definitions 8 | semantic_list = [ 9 | 'background', 10 | 'hat', 11 | 'hair', 12 | 'glove', 13 | 'sunglasses', 14 | 'upper_cloth', 15 | 'dress', 16 | 'coat', 17 | 'sock', 18 | 'pant', 19 | 'jumpsuit', 20 | 'scarf', 21 | 'skirt', 22 | 'face', 23 | 'left_leg', 24 | 'right_leg', 25 | 'left_arm', 26 | 'right_arm', 27 | 'left_shoe', 28 | 'right_shoe', 29 | ] 30 | 31 | semantic_dim = len(semantic_list) 32 | 33 | # Conversion between the semantic map or the color used to represent those semantic maps 34 | 35 | 36 | def color_to_semantic(schp: torch.Tensor, # B, H, W, 3 37 | palette: torch.Tensor, # 256, 3 38 | ): 39 | sem_msk = schp.new_zeros(schp.shape[:3]) 40 | for i, rgb in enumerate(palette): 41 | belong = (schp - rgb).sum(axis=-1) == 0 42 | sem_msk[belong] = i 43 | return sem_msk # B, H, W 44 | 45 | 46 | def semantics_to_color(semantic: torch.Tensor, # V, 47 | palette: torch.Tensor, # 256, 3 48 | ): 49 | return linear_gather(palette, semantic) # V, 3 50 | 51 | 52 | def palette_to_index(sem: np.ndarray, semantic_dim=semantic_dim): 53 | # convert color coded semantic map to semantic index 54 | palette = get_schp_palette(semantic_dim) 55 | sem_msk = np.zeros(sem.shape[:2], dtype=np.uint8) 56 | for i, rgb in enumerate(palette): 57 | belong = (sem - rgb).sum(axis=-1) == 0 58 | sem_msk[belong] = i 59 | 60 | return sem_msk 61 | 62 | 63 | def palette_to_onehot(sem: np.ndarray, semantic_dim=semantic_dim): 64 | sem_msk = palette_to_index(sem, semantic_dim) 65 | # convert semantic index to one-hot vectors 66 | sem = torch.from_numpy(sem_msk) 67 | sem: torch.Tensor = F.one_hot(sem.long(), semantic_dim) 68 | sem = sem.float().numpy() 69 | return sem 70 | 71 | 72 | @lru_cache 73 | def get_schp_palette(num_cls=256): 74 | # Copied from SCHP 75 | """ Returns the color map for visualizing the segmentation mask. 76 | Inputs: 77 | =num_cls= 78 | Number of classes. 79 | Returns: 80 | The color map. 81 | """ 82 | n = num_cls 83 | palette = [0] * (n * 3) 84 | for j in range(0, n): 85 | lab = j 86 | palette[j * 3 + 0] = 0 87 | palette[j * 3 + 1] = 0 88 | palette[j * 3 + 2] = 0 89 | i = 0 90 | while lab: 91 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 92 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 93 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 94 | i += 1 95 | lab >>= 3 96 | 97 | palette = np.array(palette, dtype=np.uint8) 98 | palette = palette.reshape(-1, 3) # n_cls, 3 99 | return palette 100 | 101 | @lru_cache 102 | def get_schp_palette_tensor_float(num_cls=semantic_dim, device='cuda'): 103 | # Copied from SCHP 104 | """ Returns the color map for visualizing the segmentation mask. 105 | Inputs: 106 | =num_cls= 107 | Number of classes. 
108 | Returns: 109 | The color map. 110 | """ 111 | n = num_cls 112 | palette = [0] * (n * 3) 113 | for j in range(0, n): 114 | lab = j 115 | palette[j * 3 + 0] = 0 116 | palette[j * 3 + 1] = 0 117 | palette[j * 3 + 2] = 0 118 | i = 0 119 | while lab: 120 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 121 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 122 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 123 | i += 1 124 | lab >>= 3 125 | 126 | palette = torch.tensor(palette, dtype=torch.float, device=device) / 255.0 127 | palette = palette.reshape(-1, 3) # n_cls, 3 128 | return palette 129 | -------------------------------------------------------------------------------- /lib/train/optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from collections import Counter 3 | 4 | import torch 5 | 6 | 7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | warmup_factor=1.0 / 3, 14 | warmup_iters=5, 15 | warmup_method="linear", 16 | last_epoch=-1, 17 | ): 18 | if not list(milestones) == sorted(milestones): 19 | raise ValueError( 20 | "Milestones should be a list of" " increasing integers. Got {}", 21 | milestones, 22 | ) 23 | 24 | if warmup_method not in ("constant", "linear"): 25 | raise ValueError( 26 | "Only 'constant' or 'linear' warmup_method accepted" 27 | "got {}".format(warmup_method) 28 | ) 29 | self.milestones = milestones 30 | self.gamma = gamma 31 | self.warmup_factor = warmup_factor 32 | self.warmup_iters = warmup_iters 33 | self.warmup_method = warmup_method 34 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | warmup_factor = 1 38 | if self.last_epoch < self.warmup_iters: 39 | if self.warmup_method == "constant": 40 | warmup_factor = self.warmup_factor 41 | elif self.warmup_method == "linear": 42 | alpha = float(self.last_epoch) / self.warmup_iters 43 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 44 | return [ 45 | base_lr 46 | * warmup_factor 47 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): 53 | 54 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 55 | self.milestones = Counter(milestones) 56 | self.gamma = gamma 57 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 58 | 59 | def get_lr(self): 60 | if self.last_epoch not in self.milestones: 61 | return [group['lr'] for group in self.optimizer.param_groups] 62 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 63 | 64 | 65 | class ExponentialLR(torch.optim.lr_scheduler._LRScheduler): 66 | 67 | def __init__(self, optimizer, decay_epochs, gamma=0.1, last_epoch=-1): 68 | self.decay_epochs = decay_epochs 69 | self.gamma = gamma 70 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 71 | 72 | def get_lr(self): 73 | return [base_lr * self.gamma ** (self.last_epoch / self.decay_epochs) for base_lr in self.base_lrs] 74 | 75 | 76 | class WarmupExponentialLR(torch.optim.lr_scheduler._LRScheduler): 77 | 78 | def __init__(self, 79 | optimizer, 80 | decay_epochs, 81 | warmup_factor=1.0 / 3, 82 | warmup_epochs=1, 83 | warmup_method="linear", 84 | gamma=0.1, 85 | last_epoch=-1): 86 | self.warmup_factor = warmup_factor 87 | self.warmup_epochs = 
warmup_epochs 88 | self.warmup_method = warmup_method 89 | self.decay_epochs = decay_epochs 90 | self.gamma = gamma 91 | super(WarmupExponentialLR, self).__init__(optimizer, last_epoch) 92 | 93 | def get_lr(self): 94 | warmup_factor = 1 95 | if self.last_epoch < self.warmup_epochs: 96 | if self.warmup_method == "constant": 97 | warmup_factor = self.warmup_factor 98 | elif self.warmup_method == "linear": 99 | alpha = float(self.last_epoch) / self.warmup_epochs 100 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 101 | return [base_lr * warmup_factor * self.gamma ** (self.last_epoch / self.decay_epochs) for base_lr in self.base_lrs] 102 | -------------------------------------------------------------------------------- /scripts/tools/prepare_annots.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import os 4 | import glob 5 | import numpy as np 6 | import cv2 7 | import argparse 8 | from ruamel.yaml import YAML 9 | yaml = YAML() 10 | 11 | 12 | def get_cams(): 13 | intri = cv2.FileStorage('intri.yml', cv2.FILE_STORAGE_READ) 14 | extri = cv2.FileStorage('extri.yml', cv2.FILE_STORAGE_READ) 15 | cams = {'K': [], 'D': [], 'R': [], 'T': []} 16 | for i in range(len(camera_names)): 17 | camera_name = camera_names[i] 18 | cams['K'].append(intri.getNode(f'K_{camera_name}').mat()) 19 | try: 20 | cams['D'].append( 21 | intri.getNode(f'D_{camera_name}').mat().T) 22 | except: 23 | cams['D'].append( 24 | intri.getNode(f'dist_{camera_name}').mat().T) 25 | cams['R'].append(extri.getNode(f'Rot_{camera_name}').mat()) 26 | cams['T'].append(extri.getNode(f'T_{camera_name}').mat() * 1000) 27 | return cams 28 | 29 | 30 | def get_img_paths(): 31 | all_ims = [] 32 | for i in range(len(camera_names)): 33 | camera_name = camera_names[i] 34 | cam_dir = f'{image_dir}/{camera_name}' 35 | ims = glob.glob(os.path.join(cam_dir, f'*{args.ext}')) 36 | ims = np.array(sorted(ims)) 37 | all_ims.append(ims) 38 | num_img = min([len(ims) for ims in all_ims]) 39 | all_ims = [ims[:num_img] for ims in all_ims] 40 | all_ims = np.stack(all_ims, axis=1) 41 | return all_ims 42 | 43 | 44 | def get_kpts2d(): 45 | def _get_kpts2d(paths): 46 | kpts2d_list = [] 47 | for path in paths: 48 | with open(path, 'r') as f: 49 | d = json.load(f) 50 | kpts2d = np.array(d['people'][0]['pose_keypoints_2d']).reshape( 51 | -1, 3) 52 | kpts2d_list.append(kpts2d) 53 | kpts2d = np.array(kpts2d_list) 54 | return kpts2d 55 | 56 | all_kpts = [] 57 | for i in range(len(camera_names)): 58 | camera_name = camera_names[i] 59 | cur_dump = f'keypoints2d/{camera_name}' 60 | paths = sorted(glob.glob(os.path.join(cur_dump, '*.json'))) 61 | kpts2d = _get_kpts2d(paths[:1400]) 62 | all_kpts.append(kpts2d) 63 | 64 | num_img = min([len(kpt) for kpt in all_kpts]) 65 | all_kpts = [kpt[:num_img] for kpt in all_kpts] 66 | all_kpts = np.stack(all_kpts, axis=1) 67 | 68 | return all_kpts 69 | 70 | 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--data_dir', type=str, default='data/subset_mi11') 73 | parser.add_argument('--image_dir', type=str, default='images') 74 | parser.add_argument('--humans', type=str, nargs='+', default=['talking_cont', 'talking_val', 'talking_step']) 75 | parser.add_argument('--ext', type=str, default='.jpg', choices=['.jpg', '.png']) 76 | parser.add_argument('--use_existing_cam', action='store_true', default=False) 77 | parser.add_argument('--use_existing_ims', action='store_true', default=False) 78 | args = parser.parse_args() 79 | camera_names = [] 80 | 
image_dir = args.image_dir 81 | for human_ind in range(len(args.humans)): 82 | human = args.humans[human_ind] 83 | 84 | root = os.path.join(args.data_dir, human) 85 | old = os.getcwd() 86 | os.chdir(root) 87 | 88 | if args.use_existing_cam: 89 | existing = np.load('annots.npy', allow_pickle=True).item() 90 | camera_names = [ 91 | f'Camera ({i+1})' for i in range(23) 92 | ] 93 | print(camera_names) 94 | cams = existing['cams'] 95 | else: 96 | camera_names = yaml.load(open('extri.yml'))['names'] 97 | print(camera_names) 98 | cams = get_cams() 99 | if args.use_existing_ims: 100 | img_paths = existing['ims'] 101 | else: 102 | camera_names = sorted(os.listdir(args.image_dir)) 103 | img_paths = get_img_paths() 104 | annot = {} 105 | annot['cams'] = cams 106 | 107 | ims = [] 108 | for img_path in img_paths: 109 | data = {} 110 | data['ims'] = img_path.tolist() 111 | # data['kpts2d'] = kpt.tolist() # TODO: inefficient but minimal code change 112 | ims.append(data) 113 | annot['ims'] = ims 114 | 115 | np.save('annots.npy', annot) 116 | # np.save('annots_python2.npy', annot, fix_imports=True) 117 | os.chdir(old) 118 | 119 | 120 | """ 121 | python tools/prepare_annots.py --data_dir data/my_zju_mocap --humans my_313 my_315 my_377 my_386 my_387 my_390 my_392 my_393 my_394 122 | """ 123 | -------------------------------------------------------------------------------- /lib/datasets/make_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import torch 4 | import importlib 5 | import numpy as np 6 | 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data._utils.collate import default_collate, default_convert 9 | from torch.utils.data.sampler import RandomSampler, SequentialSampler, BatchSampler 10 | 11 | from lib.config.config import cfg 12 | from lib.datasets import samplers 13 | 14 | torch.multiprocessing.set_sharing_strategy('file_system') 15 | cv2.setNumThreads(1) 16 | 17 | 18 | def _dataset_factory(cfg, is_train): 19 | if is_train: 20 | module = cfg.train_dataset_module 21 | args = cfg.train_dataset 22 | else: 23 | module = cfg.test_dataset_module 24 | args = cfg.test_dataset 25 | # __import__ is like import ... 26 | # __import_module__ is like from ... import ... 
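# NOTE (added): a minimal illustration of the difference, using the existing module path 'lib.datasets.pose_dataset' as an example:
#   __import__('lib.datasets.pose_dataset')                  # returns the top-level package 'lib'
#   __import__('lib.datasets.pose_dataset', fromlist=[None]) # returns the leaf submodule itself
#   importlib.import_module('lib.datasets.pose_dataset')     # returns the leaf submodule and preserves module identity,
#                                                            # so isinstance()/typing checks on Dataset keep working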
27 | # dataset = __import__(module, fromlist=[None]).Dataset(**args) 28 | # The __import__ function will return the top level module of a package, unless you pass a nonempty fromlist argument: 29 | # MARK: imp.load_source breaks the typing in Dataset, occ_dataset.Dataset would not == occ_dataset.Dataset.Dataset 30 | # switching to import lib can solve this problem 31 | # dataset = __import__(module, fromlist=[None]).Dataset(**args) 32 | dataset = importlib.import_module(module).Dataset(**args) 33 | return dataset 34 | 35 | 36 | def make_dataset(cfg, is_train=True): 37 | dataset = _dataset_factory(cfg, is_train) 38 | return dataset 39 | 40 | 41 | def make_data_sampler(dataset, shuffle, is_distributed, is_train): 42 | if not is_train and cfg.test.sampler == 'FrameSampler': 43 | sampler = samplers.FrameSampler(dataset) 44 | return sampler 45 | if not is_train and cfg.test.sampler == 'MeshFrameSampler': 46 | sampler = samplers.MeshFrameSampler(dataset) 47 | return sampler 48 | if is_distributed: 49 | return samplers.DistributedSampler(dataset, shuffle=shuffle) 50 | if shuffle: 51 | sampler = RandomSampler(dataset) 52 | else: 53 | sampler = SequentialSampler(dataset) 54 | return sampler 55 | 56 | 57 | def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter, is_train): 58 | if is_train: 59 | batch_sampler = cfg.train.batch_sampler 60 | sampler_meta = cfg.train.sampler_meta 61 | else: 62 | batch_sampler = cfg.test.batch_sampler 63 | sampler_meta = cfg.test.sampler_meta 64 | 65 | if batch_sampler == 'default': 66 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 67 | elif batch_sampler == 'image_size': 68 | batch_sampler = samplers.ImageSizeBatchSampler(sampler, batch_size, drop_last, sampler_meta) 69 | if max_iter != -1: 70 | batch_sampler = samplers.IterationBasedBatchSampler( 71 | batch_sampler, max_iter) 72 | return batch_sampler 73 | 74 | 75 | def worker_init_fn(worker_id): 76 | cv2.setNumThreads(1) # MARK: OpenCV undistort is why all cores are taken 77 | # previous randomness issue might just come from here 78 | if cfg.fix_random: 79 | np.random.seed(worker_id) 80 | else: 81 | np.random.seed(worker_id + (int(round(time.time() * 1000) % (2**16)))) 82 | 83 | 84 | def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1) -> DataLoader: 85 | if is_train: 86 | batch_size = cfg.train.batch_size 87 | shuffle = cfg.train.shuffle 88 | drop_last = False 89 | else: 90 | batch_size = cfg.test.batch_size 91 | shuffle = True if is_distributed else False 92 | drop_last = False 93 | 94 | dataset = make_dataset(cfg, is_train) 95 | sampler = make_data_sampler(dataset, shuffle, is_distributed, is_train) 96 | batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter, is_train) 97 | num_workers = min(cfg.train.num_workers, max_iter) if max_iter > 0 else cfg.train.num_workers 98 | prefetch_factor = cfg.prefetch_factor if num_workers > 1 else (None if torch.__version__ >= '2' else 2) 99 | pin_memory = cfg.pin_memory 100 | collate = default_collate if cfg.collate else default_convert 101 | data_loader = DataLoader(dataset, 102 | batch_sampler=batch_sampler, 103 | num_workers=num_workers, 104 | worker_init_fn=worker_init_fn, 105 | collate_fn=collate, 106 | pin_memory=pin_memory, 107 | prefetch_factor=prefetch_factor 108 | ) 109 | 110 | return data_loader 111 | -------------------------------------------------------------------------------- /lib/utils/easy_utils.py: 
-------------------------------------------------------------------------------- 1 | # easymocap utility functions 2 | import os 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | class FileStorage(object): 8 | def __init__(self, filename, isWrite=False): 9 | version = cv2.__version__ 10 | self.major_version = int(version.split('.')[0]) 11 | self.second_version = int(version.split('.')[1]) 12 | 13 | if isWrite: 14 | os.makedirs(os.path.dirname(filename), exist_ok=True) 15 | self.fs = open(filename, 'w') 16 | self.fs.write('%YAML:1.0\r\n') 17 | self.fs.write('---\r\n') 18 | else: 19 | assert os.path.exists(filename), filename 20 | self.fs = cv2.FileStorage(filename, cv2.FILE_STORAGE_READ) 21 | self.isWrite = isWrite 22 | 23 | def __del__(self): 24 | if self.isWrite: 25 | self.fs.close() 26 | else: 27 | cv2.FileStorage.release(self.fs) 28 | 29 | def _write(self, out): 30 | self.fs.write(out+'\r\n') 31 | 32 | def write(self, key, value, dt='mat'): 33 | if dt == 'mat': 34 | self._write('{}: !!opencv-matrix'.format(key)) 35 | self._write(' rows: {}'.format(value.shape[0])) 36 | self._write(' cols: {}'.format(value.shape[1])) 37 | self._write(' dt: d') 38 | self._write(' data: [{}]'.format(', '.join(['{:.8f}'.format(i) for i in value.reshape(-1)]))) 39 | elif dt == 'list': 40 | self._write('{}:'.format(key)) 41 | for elem in value: 42 | self._write(' - "{}"'.format(elem)) 43 | 44 | def read(self, key, dt='mat'): 45 | if dt == 'mat': 46 | output = self.fs.getNode(key).mat() 47 | elif dt == 'list': 48 | results = [] 49 | n = self.fs.getNode(key) 50 | for i in range(n.size()): 51 | val = n.at(i).string() 52 | if val == '': 53 | val = str(int(n.at(i).real())) 54 | if val != 'none': 55 | results.append(val) 56 | output = results 57 | else: 58 | raise NotImplementedError 59 | return output 60 | 61 | def close(self): 62 | self.__del__(self) 63 | 64 | 65 | def read_camera(intri_path: str, extri_path: str, cam_names=[]): 66 | assert os.path.exists(intri_path), intri_path 67 | assert os.path.exists(extri_path), extri_path 68 | 69 | intri = FileStorage(intri_path) 70 | extri = FileStorage(extri_path) 71 | cams, P = {}, {} 72 | cam_names = intri.read('names', dt='list') 73 | for cam in cam_names: 74 | # 内参只读子码流的 75 | cams[cam] = {} 76 | cams[cam]['K'] = intri.read('K_{}'.format(cam)) 77 | cams[cam]['invK'] = np.linalg.inv(cams[cam]['K']) 78 | Tvec = extri.read('T_{}'.format(cam)) 79 | 80 | Rvec = extri.read('R_{}'.format(cam)) 81 | if Rvec is not None: 82 | R = cv2.Rodrigues(Rvec)[0] 83 | else: 84 | R = extri.read('Rot_{}'.format(cam)) 85 | Rvec = cv2.Rodrigues(R)[0] 86 | RT = np.hstack((R, Tvec)) 87 | 88 | cams[cam]['RT'] = RT 89 | cams[cam]['R'] = R 90 | cams[cam]['Rvec'] = Rvec 91 | cams[cam]['T'] = Tvec 92 | cams[cam]['center'] = - Rvec.T @ Tvec 93 | P[cam] = cams[cam]['K'] @ cams[cam]['RT'] 94 | cams[cam]['P'] = P[cam] 95 | 96 | cams[cam]['dist'] = intri.read('dist_{}'.format(cam)) 97 | cams['basenames'] = cam_names 98 | return cams 99 | 100 | 101 | def write_camera(cameras: dict, path: str): 102 | from os.path import join 103 | os.makedirs(path, exist_ok=True) 104 | intri_name = join(path, 'intri.yml') # TODO: make them arguments 105 | extri_name = join(path, 'extri.yml') 106 | intri = FileStorage(intri_name, True) 107 | extri = FileStorage(extri_name, True) 108 | cam_names = [key_.split('.')[0] for key_ in cameras.keys()] 109 | intri.write('names', cam_names, 'list') 110 | extri.write('names', cam_names, 'list') 111 | for key_, val in cameras.items(): 112 | if key_ == 'basenames': 113 | continue 114 | 
key = key_.split('.')[0] 115 | intri.write('K_{}'.format(key), val['K']) 116 | if 'dist' not in val: 117 | val['dist'] = np.zeros((5, 1)) 118 | intri.write('dist_{}'.format(key), val['dist']) 119 | if 'R' not in val.keys(): 120 | val['R'] = cv2.Rodrigues(val['Rvec'])[0] 121 | if 'Rvec' not in val.keys(): 122 | val['Rvec'] = cv2.Rodrigues(val['R'])[0] 123 | extri.write('R_{}'.format(key), val['Rvec']) 124 | extri.write('Rot_{}'.format(key), val['R']) 125 | extri.write('T_{}'.format(key), val['T']) 126 | -------------------------------------------------------------------------------- /lib/datasets/pose_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import imageio 4 | import numpy as np 5 | from lib.config import cfg 6 | from lib.utils.render_utils import load_cam 7 | from lib.utils.data_utils import read_mask_by_img_path, get_bounds, get_rays_within_bounds, load_image 8 | 9 | from . import base_dataset 10 | 11 | 12 | class Dataset(base_dataset.Dataset): 13 | def __init__(self, data_root, human, ann_file, split): 14 | super(Dataset, self).__init__(data_root, human, ann_file, split) 15 | self.load_camera() 16 | 17 | def load_ims_data(self): 18 | pass 19 | 20 | def load_camera(self): 21 | self.Ks = np.array(self.cams['K'])[self.view].astype(np.float32) 22 | self.Rs = np.array(self.cams['R'])[self.view].astype(np.float32) 23 | self.Ts = np.array(self.cams['T'])[self.view].astype(np.float32) / 1000.0 24 | self.Ds = np.array(self.cams['D'])[self.view].astype(np.float32) 25 | 26 | self.Ks[:, :2] = self.Ks[:, :2] * cfg.ratio # prepare for rendering at different scale 27 | lower_row = np.array([[[0., 0., 0., 1.]]], dtype=np.float32).repeat(len(self.Ks), axis=0) # 1, 1, 4 -> N, 1, 4 28 | self.RT = np.concatenate([self.Rs, self.Ts], axis=-1) # N, 3, 3 + N, 1, 3 29 | self.RT = np.concatenate([self.RT, lower_row], axis=-2) # N, 3, 4 + N, 1, 4 30 | 31 | if hasattr(self, 'BGs'): 32 | self.BGs = self.BGs[self.view].astype(np.float32) 33 | BGs = [] 34 | for v in range(len(self.view)): 35 | D = self.Ds[v] 36 | K = self.Ks[v] 37 | BG = self.BGs[v] 38 | H, W = BG.shape[:2] 39 | H, W = int(H * cfg.ratio), int(W * cfg.ratio) 40 | BG = cv2.resize(BG, (W, H), interpolation=cv2.INTER_AREA) 41 | BG = cv2.undistort(BG, K, D) 42 | BGs.append(BG) 43 | self.BGs = np.stack(BGs) 44 | 45 | def __getitem__(self, index): 46 | # ? 
BUG 47 | latent_index, frame_index, view_index, cam_index = self.get_indices(index) 48 | 49 | # These are SMPL bw, bounds, vertices 50 | wpts, ppts, A, joints, Rh, Th, poses, shapes = self.get_lbs_params(frame_index) 51 | wbounds = get_bounds(wpts) 52 | 53 | if cfg.H <= 0 or cfg.W <= 0: 54 | H, W = self.H, self.W 55 | H, W = int(H * cfg.ratio), int(W * cfg.ratio) 56 | K = self.Ks[view_index] 57 | else: 58 | H, W = cfg.H, cfg.W 59 | K = np.zeros((3, 3), dtype=np.float32) 60 | K[2, 2] = 1 61 | K[0, 0] = H * cfg.novel_view_ixt_ratio 62 | K[1, 1] = H * cfg.novel_view_ixt_ratio 63 | K[0, 2] = H / 2 64 | K[1, 2] = H / 2 65 | 66 | RT = self.RT[view_index] 67 | R, T = RT[:3, :3], RT[:3, 3:] 68 | ray_o, ray_d, near, far, mask_at_box = get_rays_within_bounds(H, W, K, R, T, wbounds) 69 | 70 | # load SMPL & pose & human related parameters 71 | ret = self.get_blend(frame_index) 72 | 73 | # store camera parameters 74 | meta = { 75 | "cam_K": K, 76 | "cam_R": R, 77 | "cam_T": T, 78 | "cam_RT": np.concatenate([R, T], axis=1), 79 | 'H': H, 80 | 'W': W, 81 | 'RT': self.RT, 82 | 'Ks': self.Ks, 83 | } 84 | ret.update(meta) 85 | ret.meta.update(meta) 86 | 87 | # store camera background images 88 | if hasattr(self, 'BGs'): 89 | BG = self.BGs[view_index] 90 | meta = { 91 | "cam_BG": BG, 92 | } 93 | ret.update(meta) 94 | ret.meta.update(meta) 95 | 96 | # store ray data 97 | meta = { 98 | 'ray_o': ray_o, 99 | 'ray_d': ray_d, 100 | 'near': near, 101 | 'far': far, 102 | 'mask_at_box': mask_at_box, 103 | } 104 | ret.update(meta) 105 | latent_index = index // len(self.view) 106 | meta = { 107 | 'latent_index': latent_index, 108 | 'frame_index': frame_index, 109 | 'view_index': self.view[view_index], 110 | } 111 | ret.update(meta) 112 | ret.meta.update(meta) 113 | return ret 114 | 115 | def get_indices(self, index): 116 | view_index = index % len(self.view) 117 | latent_index = index // len(self.view) 118 | frame_index = self.i + latent_index * self.i_intv # recompute frame index for i 119 | cam_index = view_index 120 | return latent_index, frame_index, view_index, cam_index 121 | 122 | def __len__(self): 123 | # return self.ims.size # number of elements, regardless of dimensions 124 | # pose dataset should consider arbitrary length novel pose 125 | return self.ni * self.num_cams 126 | -------------------------------------------------------------------------------- /lib/train/recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from collections import deque, defaultdict 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | from lib.config.config import cfg 8 | from lib.utils.base_utils import dotdict, default_dotdict 9 | from lib.utils.log_utils import log 10 | 11 | 12 | class SmoothedValue(object): 13 | """Track a series of values and provide access to smoothed values over a 14 | window or the global series average. 
15 | """ 16 | 17 | def __init__(self, window_size=20): 18 | self.deque = deque(maxlen=window_size) 19 | self.total = 0.0 20 | self.count = 0 21 | self.value = 0 22 | 23 | def update(self, value): 24 | self.deque.append(value) 25 | self.count += 1 26 | self.total += value 27 | self.value = value 28 | 29 | @property 30 | def latest(self): 31 | d = torch.tensor(list(self.deque)) 32 | return d[-1].item() 33 | 34 | @property 35 | def median(self): 36 | d = torch.tensor(list(self.deque)) 37 | return d.float().median().item() 38 | 39 | @property 40 | def avg(self): 41 | d = torch.tensor(list(self.deque)) 42 | return d.float().mean().item() 43 | 44 | @property 45 | def global_avg(self): 46 | return self.total / self.count 47 | 48 | @property 49 | def val(self): 50 | return self.value 51 | 52 | 53 | class Recorder(object): 54 | def __init__(self, cfg): 55 | if cfg.local_rank > 0: 56 | return 57 | 58 | log_dir = cfg.record_dir 59 | if not cfg.resume: 60 | log(f'removing training record: {log_dir}', 'red') 61 | os.system(f'rm -rf {log_dir}') 62 | self.writer = SummaryWriter(log_dir=log_dir) 63 | 64 | # scalars 65 | self.epoch = 0 66 | self.step = 0 67 | self.record_stats = default_dotdict(SmoothedValue) 68 | 69 | # images 70 | self.image_stats = default_dotdict(object) 71 | if 'process_' + cfg.task in globals(): 72 | self.processor = globals()['process_' + cfg.task] 73 | else: 74 | self.processor = None 75 | 76 | def update_record_stats(self, record_stats: dotdict): 77 | if cfg.local_rank > 0: 78 | return 79 | for k, v in record_stats.items(): 80 | self.record_stats[k].update(v) 81 | 82 | def update_image_stats(self, image_stats: dotdict): 83 | if cfg.local_rank > 0: 84 | return 85 | if self.processor is None: 86 | return 87 | image_stats = self.processor(image_stats) 88 | for k, v in image_stats.items(): 89 | self.image_stats[k] = v 90 | 91 | def record(self, prefix, step=-1, record_stats: dotdict = None, image_stats: dotdict = None): 92 | if cfg.local_rank > 0: 93 | return 94 | 95 | pattern = prefix + '/{}' 96 | step = step if step >= 0 else self.step 97 | record_stats = record_stats if record_stats else self.record_stats 98 | 99 | for k, v in record_stats.items(): 100 | if isinstance(v, SmoothedValue): 101 | self.writer.add_scalar(pattern.format(k), v.median, step) 102 | else: 103 | self.writer.add_scalar(pattern.format(k), v, step) 104 | 105 | if self.processor is None: 106 | return 107 | image_stats = self.processor(image_stats) if image_stats else self.image_stats 108 | for k, v in image_stats.items(): 109 | self.writer.add_image(pattern.format(k), v, step) 110 | 111 | def state_dict(self): 112 | if cfg.local_rank > 0: 113 | return 114 | scalar_dict = {} 115 | scalar_dict['step'] = self.step 116 | return scalar_dict 117 | 118 | def load_state_dict(self, scalar_dict): 119 | if cfg.local_rank > 0: 120 | return 121 | self.step = scalar_dict['step'] 122 | 123 | @property 124 | def log_stats(self): 125 | if cfg.local_rank > 0: 126 | return 127 | log_stats = dotdict() 128 | log_stats.epoch = str(self.epoch) 129 | log_stats.step = str(self.step) 130 | for k, v in self.record_stats.items(): 131 | if isinstance(v, SmoothedValue): 132 | log_stats[k] = f'{v.avg:.6f}' 133 | else: 134 | log_stats[k] = v 135 | log_stats.lr = f'{self.record_stats.lr.val:.6f}' 136 | log_stats.data = f'{self.record_stats.data.val:.4f}' 137 | log_stats.batch = f'{self.record_stats.batch.val:.4f}' 138 | log_stats.max_mem = f'{self.record_stats.max_mem.val:.0f}' 139 | return log_stats 140 | 141 | def __str__(self): 142 | return ' 
'.join([k + ': ' + v for k, v in self.log_stats.items()]) 143 | 144 | 145 | def make_recorder(cfg): 146 | recorder = Recorder(cfg) 147 | 148 | return recorder 149 | -------------------------------------------------------------------------------- /lib/evaluators/base_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from skimage.metrics import structural_similarity as compare_ssim 6 | 7 | from lib.config import cfg 8 | from lib.utils.log_utils import log 9 | from ..visualizers import base_visualizer 10 | 11 | 12 | class Evaluator(base_visualizer.Visualizer): 13 | def __init__(self): 14 | if cfg.local_rank > 0: 15 | return 16 | 17 | super(Evaluator, self).__init__() 18 | 19 | import lpips 20 | self.mse = [] 21 | self.psnr = [] 22 | self.ssim = [] 23 | self.lpips = [] 24 | self.compute_lpips = lpips.LPIPS(verbose=False) 25 | 26 | def psnr_metric(self, img_pred, img_gt): 27 | mse = np.mean((img_pred - img_gt)**2) 28 | psnr = -10 * np.log(mse) / np.log(10) 29 | return psnr 30 | 31 | def ssim_metric(self, img_pred, img_gt, batch): 32 | if not cfg.eval_whole_img: 33 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 34 | H, W = batch['H'].item(), batch['W'].item() 35 | mask_at_box = mask_at_box.reshape(H, W) 36 | # crop the object region 37 | x, y, w, h = cv2.boundingRect(mask_at_box.astype(np.uint8)) 38 | img_pred = img_pred[y:y + h, x:x + w] 39 | img_gt = img_gt[y:y + h, x:x + w] 40 | 41 | if 'crop_bbox' in batch: 42 | img_pred = Evaluator.fill_image(img_pred, batch) 43 | img_gt = Evaluator.fill_image(img_gt, batch) 44 | 45 | # compute the ssim 46 | ssim = compare_ssim(img_pred, img_gt, channel_axis=-1, data_range=1) 47 | 48 | return ssim 49 | 50 | def lpips_metric(self, img_pred, img_gt, batch): 51 | if not cfg.eval_whole_img: 52 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 53 | H, W = batch['H'].item(), batch['W'].item() 54 | mask_at_box = mask_at_box.reshape(H, W) 55 | # crop the object region 56 | x, y, w, h = cv2.boundingRect(mask_at_box.astype(np.uint8)) 57 | img_pred = img_pred[y:y + h, x:x + w] 58 | img_gt = img_gt[y:y + h, x:x + w] 59 | 60 | if 'crop_bbox' in batch: 61 | img_pred = Evaluator.fill_image(img_pred, batch) 62 | img_gt = Evaluator.fill_image(img_gt, batch) 63 | 64 | # compute the lpips 65 | with torch.no_grad(): 66 | lpips = self.compute_lpips(torch.Tensor(img_pred.transpose((2, 0, 1))[None]), torch.Tensor(img_gt.transpose((2, 0, 1)))[None])[0] 67 | lpips = lpips.item() 68 | 69 | return lpips 70 | 71 | def evaluate(self, output, batch): 72 | rgb_pred = output['rgb_map'][0].detach().cpu().numpy() 73 | rgb_gt = batch['rgb'][0].detach().cpu().numpy() 74 | 75 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 76 | H, W = batch['H'].item(), batch['W'].item() 77 | mask_at_box = mask_at_box.reshape(H, W) 78 | # convert the pixels into an image 79 | white_bkgd = cfg.bg_brightness 80 | 81 | if rgb_pred.ndim == 2: 82 | img_pred = np.zeros((H, W, 3)) + white_bkgd 83 | img_pred[mask_at_box] = rgb_pred 84 | img_gt = np.zeros((H, W, 3)) + white_bkgd 85 | img_gt[mask_at_box] = rgb_gt 86 | else: 87 | img_pred = rgb_pred 88 | img_gt = rgb_gt 89 | 90 | if cfg.eval_whole_img: 91 | rgb_pred = img_pred 92 | rgb_gt = img_gt 93 | 94 | mse = np.mean((rgb_pred - rgb_gt)**2) 95 | self.mse.append(mse) 96 | 97 | psnr = self.psnr_metric(rgb_pred, rgb_gt) 98 | self.psnr.append(psnr) 99 | 100 | ssim = self.ssim_metric(rgb_pred, rgb_gt, batch) 101 
| self.ssim.append(ssim) 102 | 103 | lpips = self.lpips_metric(rgb_pred, rgb_gt, batch) 104 | self.lpips.append(lpips) 105 | 106 | self.visualize(output, batch) 107 | 108 | def summarize(self): 109 | super(Evaluator, self).summarize() # will save images 110 | 111 | result_dir = cfg.result_dir 112 | log('the results are saved at {}'.format(result_dir), 'yellow') 113 | 114 | result_path = os.path.join(cfg.result_dir, 'metrics.npy') 115 | os.makedirs(os.path.dirname(result_path), exist_ok=True) 116 | metrics = {'mse': self.mse, 'psnr': self.psnr, 'ssim': self.ssim, 'lpips': self.lpips} 117 | np.save(result_path, metrics) 118 | mse, psnr, ssim, lpips = np.mean(self.mse), np.mean(self.psnr), np.mean(self.ssim), np.mean(self.lpips) 119 | mean_metrics = {'mse': mse, 'psnr': psnr, 'ssim': ssim, 'lpips': lpips} 120 | log('mse: {}'.format(np.mean(self.mse))) 121 | log('psnr: {}'.format(np.mean(self.psnr))) 122 | log('ssim: {}'.format(np.mean(self.ssim))) 123 | 124 | self.mse = [] 125 | self.psnr = [] 126 | self.ssim = [] 127 | self.lpips = [] 128 | 129 | return mean_metrics 130 | -------------------------------------------------------------------------------- /lib/networks/renderer/base_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from lib.config import cfg 5 | from lib.utils.base_utils import dotdict 6 | from lib.utils.net_utils import volume_rendering, chunkify 7 | from lib.networks.deform import base_network 8 | 9 | 10 | class Renderer(nn.Module): 11 | def __init__(self, net: base_network.Network): 12 | super(Renderer, self).__init__() 13 | self.net = net 14 | 15 | def get_wsampling_points(self, ray_o, ray_d, near, far): 16 | # calculate the steps for each ray 17 | t_vals = torch.linspace(0., 1., steps=cfg.n_samples, device=near.device, dtype=near.dtype) 18 | z_vals = near[..., None] * (1. - t_vals) + far[..., None] * t_vals 19 | 20 | if cfg.perturb > 0. 
and self.net.training: 21 | # get intervals between samples 22 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 23 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 24 | lower = torch.cat([z_vals[..., :1], mids], -1) 25 | # stratified samples in those intervals 26 | t_rand = torch.rand(z_vals.shape, device=upper.device, dtype=upper.dtype) 27 | z_vals = lower + (upper - lower) * t_rand 28 | 29 | pts = ray_o[:, :, None] + ray_d[:, :, None] * z_vals[..., None] 30 | 31 | return pts, z_vals 32 | 33 | def get_density_color(self, wpts, viewdir, z_vals, raw_decoder) -> dotdict: 34 | """ 35 | wpts: n_batch, n_pixel, n_sample, 3 36 | viewdir: n_batch, n_pixel, 3 37 | z_vals: n_batch, n_pixel, n_sample 38 | """ 39 | B, P, S = wpts.shape[:3] 40 | wpts = wpts.view(B, P * S, -1) 41 | viewdir = viewdir[:, :, None].expand(-1, -1, S, -1) 42 | viewdir = viewdir.reshape(B, P * S, -1) 43 | 44 | # calculate dists for the opacity computation 45 | dists = z_vals[..., 1:] - z_vals[..., :-1] 46 | dists = torch.cat([dists, dists[..., -1:]], dim=2) 47 | dists = dists.view(B, P * S) 48 | 49 | ret = raw_decoder(wpts, viewdir, dists) 50 | 51 | return ret 52 | 53 | def get_pixel_value(self, ray_o, ray_d, near, far, batch) -> dotdict: 54 | # sampling points for nerf training 55 | wpts, z_vals = self.get_wsampling_points(ray_o, ray_d, near, far) 56 | B, P, S = wpts.shape[:3] 57 | 58 | # viewing direction, ray_d has been normalized in the dataset 59 | viewdir = ray_d 60 | 61 | def raw_decoder(wpts_val, viewdir_val, dists_val): return self.net(wpts_val, viewdir_val, dists_val, batch) 62 | 63 | # compute the color and density 64 | ret = self.get_density_color(wpts, viewdir, z_vals, raw_decoder) 65 | 66 | # reshape to [num_rays, num_samples along ray, 4] 67 | raw: torch.Tensor = ret.raw 68 | rgb = raw[..., :-1].view(B, P, S, raw.shape[-1] - 1) # B, P, S, 3 69 | occ = raw[..., -1:].view(B, P, S) # B, P, S, 1 70 | 71 | # volume rendering of rgb values 72 | weights, raw_map, acc_map = volume_rendering(rgb, occ, bg_brightness=cfg.bg_brightness) 73 | 74 | depth_map = torch.sum(weights * z_vals, dim=-1) 75 | 76 | raw_map = raw_map.view(B, P, -1) 77 | acc_map = acc_map.view(B, P) 78 | depth_map = depth_map.view(B, P) 79 | 80 | # prepare for regulariaztion on distortion loss 81 | ret.weights = weights 82 | ret.z_vals = z_vals 83 | 84 | # when not training, construct new return values (discard previously cached data) 85 | if not self.net.training: 86 | ret = dotdict() # save some memory 87 | 88 | # add more visualization 89 | if not self.net.training: 90 | ret.depth_map = depth_map 91 | 92 | # for visualization 93 | raw = raw_map # return to rgb_map before volume rendering? 
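# NOTE (added): the channel layout assumed by the splits below is [cpts(3), bpts(3), resd(3)] (optional),
# then [norm(3)] (optional), with the last 3 channels always being the rgb color; each branch peels its
# block off the front of raw and leaves the remainder for the next one.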
94 | 95 | # not training visualization 96 | if raw.shape[-1] >= 9: 97 | cpts, bpts, resd, raw = raw[..., :3], raw[..., 3:6], raw[..., 6:9], raw[..., 9:] 98 | if not self.net.training: 99 | ret.cpts_map = cpts 100 | ret.bpts_map = bpts 101 | ret.resd_map = resd 102 | 103 | # another type of network output, no need to explicitly render 104 | if raw.shape[-1] >= 6: 105 | norm, raw = raw.split([3, 3], dim=-1) 106 | if not self.net.training: 107 | ret.norm_map = norm 108 | 109 | # training or not, always add in these 110 | ret.rgb_map = raw 111 | ret.acc_map = acc_map # for mask loss 112 | 113 | return ret 114 | 115 | def render(self, batch): 116 | ray_o = batch.ray_o 117 | ray_d = batch.ray_d 118 | near = batch.near 119 | far = batch.far 120 | near = near.clip(min=cfg.clip_near) # do not go back the camera 121 | far = far.clip(max=cfg.clip_far) # do not go back the camera 122 | 123 | # volume rendering for each pixel 124 | chunk = cfg.train_chunk_size if self.net.training else cfg.render_chunk_size 125 | @chunkify(chunk, dim=-2, merge_dims=True) 126 | def chunked_get_pixel_value(ray_o, ray_d, near, far, batch): return self.get_pixel_value(ray_o, ray_d, near, far, batch) 127 | ret = chunked_get_pixel_value(ray_o, ray_d, near, far, batch) 128 | 129 | return dotdict(ret) 130 | -------------------------------------------------------------------------------- /lib/networks/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from lib.utils.net_utils import make_buffer 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | r"""Inject some information about the relative or absolute position of the tokens in the sequence. 10 | The positional encodings have the same dimension as the embeddings, so that the two can be summed. 11 | Here, we use sine and cosine functions of different frequencies. 12 | .. math: 13 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 14 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 15 | \text{where pos is the word position and i is the embed idx) 16 | Args: 17 | d_model: the embed dim (required). 18 | dropout: the dropout value (default=0.1). 19 | max_len: the max. length of the incoming sequence (default=5000). 20 | Examples: 21 | >>> pos_encoder = PositionalEncoding(d_model) 22 | """ 23 | 24 | def __init__(self, d_model, max_len=30, dropout=0.1): 25 | super(PositionalEncoding, self).__init__() 26 | self.dropout = nn.Dropout(p=dropout) 27 | 28 | pe = torch.zeros(max_len, d_model) 29 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 30 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 31 | pe[:, 0::2] = torch.sin(position * div_term) 32 | pe[:, 1::2] = torch.cos(position * div_term) 33 | pe = pe.unsqueeze(0) 34 | self.pe = make_buffer(pe) 35 | 36 | def forward(self, x): 37 | r"""Inputs of forward function 38 | Args: 39 | x: the sequence fed to the positional encoder model (required). 
40 | Shape: 41 | x: [batch size, sequence length, embed dim] 42 | output: [batch size, sequence length, embed dim] 43 | Examples: 44 | >>> output = pos_encoder(x) 45 | """ 46 | 47 | x = x + self.pe[:, :x.shape[-2], :] # B, L, D + B, L, D 48 | return self.dropout(x) 49 | 50 | 51 | class MergingTransformer(nn.Module): 52 | """Container module with an encoder, a recurrent or transformer module, and a decoder.""" 53 | 54 | def __init__(self, max_len, in_dim, d_model=256, nhead=4, dim_feedforward=256, num_encoder_layers=4, dropout=0.0): 55 | super(MergingTransformer, self).__init__() 56 | self.d_model = d_model 57 | self.linear_mapping = nn.Linear(in_dim, d_model) 58 | self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout) 59 | encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) 60 | self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers) 61 | self.linear_decoder = nn.Conv1d(max_len, 1, 1) 62 | 63 | def forward(self, src: torch.Tensor) -> torch.Tensor: 64 | # src: [batch size, sequence length, embed dim] 65 | # out: [batch size, embed dim] 66 | src = src * np.sqrt(self.d_model) # B, L, D 67 | src = self.linear_mapping(src) 68 | src = self.positional_encoding(src) # B, L, D 69 | out = self.transformer_encoder(src) # B, L, D 70 | out = self.linear_decoder(out) # B, 1, D 71 | return out # B, 1, D 72 | 73 | 74 | def generate_square_subsequent_mask(sz1, sz2, device='cuda'): 75 | return torch.triu(torch.full((sz1, sz2), float('-inf'), device=device), diagonal=1) 76 | 77 | 78 | class SequenceTransformer(nn.Module): 79 | 80 | def __init__(self, in_dim, max_len=30, d_model=256, dropout=0.0): 81 | super(SequenceTransformer, self).__init__() 82 | self.d_model = d_model 83 | self.linear_mapping = nn.Linear(in_dim, d_model) 84 | self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len) 85 | self.transformer = nn.Transformer(d_model=d_model, 86 | dropout=dropout, 87 | batch_first=True, 88 | # nhead=4, 89 | # num_encoder_layers=4, 90 | # num_decoder_layers=4, 91 | # dim_feedforward=256, 92 | ) 93 | 94 | def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: 95 | # src: B, L1, D 96 | # tgt: B, L2, D 97 | # out: B, L2, D 98 | 99 | B, L1, D = src.shape 100 | B, L2, D = tgt.shape 101 | mem_msk = generate_square_subsequent_mask(L2, L1, src.device) # prevent the network from looking forward in time? 
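# NOTE (added): generate_square_subsequent_mask(3, 3) produces
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# which, added to the attention logits, lets position i attend only to positions j <= i, i.e. the causal
# (no looking forward in time) behaviour mentioned above.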
102 | tgt_msk = generate_square_subsequent_mask(L2, L2, src.device) 103 | 104 | src = self.linear_mapping(src) # convert input to the required embedding dimension 105 | src = self.positional_encoding(src) # add positional information to the input variables 106 | tgt = self.positional_encoding(tgt) # add positional information to the target variables 107 | tgt = self.transformer(src, tgt, 108 | memory_mask=mem_msk, 109 | tgt_mask=tgt_msk, 110 | ) # B, L2, D 111 | return tgt 112 | -------------------------------------------------------------------------------- /lib/train/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from lib.config import cfg 8 | from lib.utils.base_utils import dotdict 9 | from lib.utils.loss_utils import l1, l2, elastic_crit, eikonal, mse, l1_reg, dot, mIoU_loss, huber, cross_entropy, reg_raw_crit, anneal_loss_weight 10 | from lib.utils.net_utils import normalize 11 | from lib.networks.renderer import base_renderer, make_renderer 12 | 13 | 14 | class NetworkWrapper(nn.Module): 15 | def __init__(self, net): 16 | super(NetworkWrapper, self).__init__() 17 | self.renderer: base_renderer.Renderer = make_renderer(cfg, net) 18 | 19 | def forward(self, batch): 20 | ret = self.renderer.render(batch) 21 | scalar_stats = dotdict() 22 | loss = 0 23 | 24 | # mipnerf360 regularzation 25 | if 'distortion' in ret: 26 | dist_loss = ret['distortion'].mean() 27 | scalar_stats.update({'dist_loss': dist_loss}) 28 | loss += cfg.dist_loss_weight * dist_loss 29 | 30 | # human-nerf optimizing monocular pose 31 | if 'presd' in ret: 32 | presd_loss = ret['presd'].view(-1, 3).norm(dim=-1).mean() 33 | scalar_stats.update({'presd_loss': presd_loss}) 34 | loss += cfg.presd_loss_weight * presd_loss 35 | 36 | # deforming aninerf loss 37 | if 'oresd' in ret and ret.oresd.numel(): 38 | oresd, weight = reg_raw_crit(ret.oresd, batch.iter_step) # svd of jacobian elastic loss 39 | scalar_stats.update({'oresd': oresd}) 40 | loss += oresd * weight 41 | 42 | if 'jac' in ret and ret.jac.numel(): 43 | jac, weight = reg_raw_crit(ret.jac, batch.iter_step) # length of difference in value of neighbor points 44 | scalar_stats.update({'jac': jac}) 45 | loss += jac * weight # TODO: remove weights 46 | 47 | if 'ograd' in ret and ret.ograd.numel(): 48 | ograd, weight = reg_raw_crit(ret.ograd, batch.iter_step) # length of difference in value of neighbor points 49 | scalar_stats.update({'ograd': ograd}) 50 | loss += ograd * weight 51 | 52 | if 'cgrad' in ret and ret.cgrad.numel(): 53 | cgrad, weight = reg_raw_crit(ret.cgrad, batch.iter_step) # length of difference in value of neighbor points 54 | scalar_stats.update({'cgrad': cgrad}) 55 | loss += cgrad * weight 56 | 57 | # anisdf loss 58 | if 'residuals' in ret: # residual offset 59 | resd_loss = ret['residuals'].norm(dim=-1).mean() 60 | resd_loss_weight = anneal_loss_weight(cfg.resd_loss_weight, cfg.resd_loss_weight_gamma, batch.meta.iter_step, cfg.resd_loss_weight_milestone) 61 | scalar_stats.update({'resd_loss': resd_loss}) 62 | if cfg.resd_loss_weight_gamma != 1.0: 63 | scalar_stats.update({'resd_loss_weight': resd_loss_weight}) 64 | loss += resd_loss_weight * resd_loss 65 | 66 | if 'gradients' in ret: # gradients 67 | gradients = ret['gradients'] 68 | grad_loss = eikonal(gradients) 69 | scalar_stats.update({'grad_loss': grad_loss}) 70 | loss += cfg.eikonal_loss_weight * grad_loss 71 | 72 | if 
'observed_gradients' in ret and ret['observed_gradients'].numel(): 73 | ogradients = ret['observed_gradients'] 74 | ograd_loss = eikonal(ogradients) 75 | scalar_stats.update({'ograd_loss': ograd_loss}) 76 | loss += cfg.observed_eikonal_loss_weight * ograd_loss 77 | 78 | if 'norm_map' in ret and 'norm' in batch: 79 | # image derivative only available when we're doing patch sampling, how to achieve this? 80 | norm_map = normalize(ret['norm_map']) # world space normal B, N, 3 @ B, 3, 3 81 | norm = normalize(batch['norm']) # in world space right? 82 | view_map = batch['ray_d'] # B, N, 3 83 | view_dot = dot(norm_map, -view_map).clip(0, 1) # B, N, this serves as weight to the normal loss 84 | norm_loss = ((norm_map - norm).abs().sum(dim=-1) + (1 - dot(norm_map, norm))) * view_dot 85 | norm_loss = norm_loss.mean() 86 | # norm_loss = l1(norm_map, norm) 87 | scalar_stats.update({'norm_loss': norm_loss}) 88 | loss += cfg.norm_loss_weight * norm_loss 89 | 90 | if 'sem_map' in ret and 'sem' in batch: 91 | sem_loss = cross_entropy(ret['sem_map'], batch['sem']) 92 | scalar_stats.update({'sem_loss': sem_loss}) 93 | loss += cfg.sem_loss_weight * sem_loss 94 | 95 | if 'acc_map' in ret and 'msk' in batch: 96 | msk_loss = mIoU_loss(ret['acc_map'], batch['msk']) 97 | scalar_stats.update({'msk_loss': msk_loss}) 98 | loss += cfg.msk_loss_weight * msk_loss 99 | 100 | if 'rgb_map' in ret: 101 | img_loss = mse(ret['rgb_map'], batch['rgb']) 102 | psnr = (1 / img_loss).log() * 10 / np.log(10) 103 | scalar_stats.update({'img_loss': img_loss}) 104 | scalar_stats.update({'psnr': psnr}) 105 | loss += cfg.img_loss_weight * img_loss 106 | 107 | scalar_stats.update({'loss': loss}) 108 | image_stats = {} 109 | 110 | return ret, loss, scalar_stats, image_stats 111 | -------------------------------------------------------------------------------- /lib/train/trainers/relight_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from lib.config import cfg 8 | from lib.utils.base_utils import dotdict 9 | from lib.utils.loss_utils import l1, l2, mse, gaussian_entropy, gaussian_entropy_relighting4d, eikonal, mIoU_loss 10 | from lib.utils.net_utils import normalize 11 | from lib.networks.renderer import base_renderer, make_renderer 12 | 13 | 14 | class NetworkWrapper(nn.Module): 15 | def __init__(self, net): 16 | super(NetworkWrapper, self).__init__() 17 | self.renderer: base_renderer.Renderer = make_renderer(cfg, net) 18 | 19 | def forward(self, batch): 20 | ret = self.renderer.render(batch) 21 | scalar_stats = dotdict() 22 | loss = 0 23 | 24 | # directional distance trainer 25 | if 'dd_gt' in ret and 'dd_pred' in ret: 26 | pred = ret.dd_pred 27 | gt = ret.dd_gt 28 | dd_loss = mse(pred, gt) 29 | scalar_stats.update({'dd_loss': dd_loss}) 30 | loss += cfg.dd_loss_weight * dd_loss 31 | 32 | if 'di_gt' in ret and 'di_pred' in ret: 33 | pred = ret.di_pred 34 | gt = ret.di_gt 35 | di_loss = mse(pred, gt) 36 | scalar_stats.update({'di_loss': di_loss}) 37 | loss += cfg.di_loss_weight * di_loss 38 | 39 | if 'pi_gt' in ret and 'pi_pred' in ret: 40 | pred = ret.pi_pred 41 | gt = ret.pi_gt 42 | pi_loss = mIoU_loss(pred, gt) 43 | scalar_stats.update({'pi_loss': pi_loss}) 44 | loss += cfg.pi_loss_weight * pi_loss 45 | 46 | # anisdf loss 47 | if 'residuals' in ret: # residual offset 48 | resd_loss = ret['residuals'].norm(dim=-1).mean() 49 | scalar_stats.update({'resd_loss': resd_loss}) 50 | 
loss += cfg.resd_loss_weight * resd_loss 51 | 52 | if 'gradients' in ret: # gradients 53 | gradients = ret['gradients'] 54 | grad_loss = eikonal(gradients) 55 | scalar_stats.update({'grad_loss': grad_loss}) 56 | loss += cfg.eikonal_loss_weight * grad_loss 57 | 58 | if 'observed_gradients' in ret and ret['observed_gradients'].numel(): 59 | ogradients = ret['observed_gradients'] 60 | ograd_loss = eikonal(ogradients) 61 | scalar_stats.update({'ograd_loss': ograd_loss}) 62 | loss += cfg.observed_eikonal_loss_weight * ograd_loss 63 | 64 | if 'acc_map' in ret and 'msk' in batch: 65 | msk_loss = mIoU_loss(ret['acc_map'], batch['msk']) 66 | scalar_stats.update({'msk_loss': msk_loss}) 67 | loss += cfg.msk_loss_weight * msk_loss 68 | 69 | # relighting loss 70 | if 'albedo' in ret: 71 | albedo_entropy = gaussian_entropy(ret.albedo) 72 | scalar_stats.update({'albedo_entropy': albedo_entropy}) 73 | # print(ret.albedo.max()) 74 | # if albedo_entropy == 0.0: 75 | # breakpoint() 76 | loss += cfg.albedo_sparsity * albedo_entropy 77 | 78 | if 'volume_albedo' in ret: 79 | albedo_entropy = gaussian_entropy(ret.volume_albedo) 80 | scalar_stats.update({'volume_entropy': albedo_entropy}) 81 | loss += cfg.albedo_sparsity * albedo_entropy 82 | 83 | if 'albedo' in ret and 'albedo_jitter' in ret: 84 | albedo_smooth = l1(ret['albedo'], ret['albedo_jitter']) 85 | scalar_stats.update({'albedo_smooth': albedo_smooth}) 86 | loss += cfg.albedo_smooth_weight * albedo_smooth 87 | 88 | if 'roughness' in ret and 'roughness_jitter' in ret: 89 | roughness_smooth = l1(ret['roughness'], ret['roughness_jitter']) 90 | scalar_stats.update({'roughness_smooth': roughness_smooth}) 91 | loss += cfg.roughness_smooth_weight * roughness_smooth 92 | 93 | if 'normal' in ret and 'normal_jitter' in ret: 94 | normal_smooth = l1(ret['normal'], ret['normal_jitter']) 95 | scalar_stats.update({'normal_smooth': normal_smooth}) 96 | loss += cfg.normal_smooth_weight * normal_smooth 97 | 98 | if 'visibility' in ret and 'visibility_jitter' in ret: 99 | visibility_smooth = l1(ret['visibility'], ret['visibility_jitter']) 100 | scalar_stats.update({'visibility_smooth': visibility_smooth}) 101 | loss += cfg.visibility_smooth_weight * visibility_smooth 102 | 103 | if 'normal' in ret and 'normal_geometry' in ret: 104 | normal_smooth = l2(ret['normal'], ret['normal_geometry']) 105 | scalar_stats.update({'normal_geometry': normal_smooth}) 106 | loss += cfg.normal_geometry_weight * normal_smooth 107 | 108 | if 'visibility' in ret and 'visibility_geometry' in ret: 109 | visibility_smooth = l2(ret['visibility'], ret['visibility_geometry']) 110 | scalar_stats.update({'visibility_geometry': visibility_smooth}) 111 | loss += cfg.visibility_geometry_weight * visibility_smooth 112 | 113 | if 'rgb_map' in ret: 114 | img_loss = mse(ret['rgb_map'], batch['rgb']) 115 | psnr = (1 / img_loss).log() * 10 / np.log(10) 116 | scalar_stats.update({'img_loss': img_loss}) 117 | scalar_stats.update({'psnr': psnr}) 118 | loss += cfg.img_loss_weight * img_loss 119 | 120 | scalar_stats.update({'loss': loss}) 121 | image_stats = {} 122 | 123 | return ret, loss, scalar_stats, image_stats 124 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | # cfg should be imported first before importing torch, to avoid cfg mismatch 3 | # and also to avoid loading cuda_driver before proper CUDA_VISIBLE_DEVICES is set 4 | from lib.config import cfg, args 5 | # 
fmt: on 6 | 7 | from lib.networks import make_network 8 | from lib.evaluators import make_evaluator 9 | from lib.datasets import make_data_loader 10 | from lib.utils.log_utils import log, print_colorful_stacktrace 11 | from lib.utils.net_utils import load_model, save_model, load_network, fix_random, number_of_params 12 | from lib.utils.prof_utils import setup_profiling, profiler_start, profiler_step, profiler_stop 13 | from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler 14 | 15 | import os 16 | import torch 17 | import torch.multiprocessing 18 | import torch.distributed as dist 19 | 20 | from torch import nn 21 | from rich.pretty import pprint 22 | from easyvolcap.utils.console_utils import * 23 | 24 | @catch_throw 25 | def train(cfg, network: nn.Module): 26 | setup_profiling(cfg.profiling) 27 | fix_random(cfg.fix_random) 28 | trainer = make_trainer(cfg, network) 29 | optimizer = make_optimizer(cfg, network) 30 | scheduler = make_lr_scheduler(cfg, optimizer) 31 | recorder = make_recorder(cfg) 32 | evaluator = make_evaluator(cfg) 33 | 34 | optims = [optimizer] 35 | begin_epoch = load_model(network, 36 | optims, 37 | scheduler, 38 | recorder, 39 | cfg.trained_model_dir, 40 | epoch=cfg.train.load_epoch, 41 | resume=cfg.resume, 42 | load_others=cfg.load_others, 43 | ) 44 | set_lr_scheduler(cfg, scheduler) 45 | 46 | train_loader = make_data_loader(cfg, 47 | is_train=True, 48 | is_distributed=cfg.distributed, 49 | max_iter=cfg.ep_iter * (cfg.train.epoch - begin_epoch)) 50 | val_loader = make_data_loader(cfg, is_train=False) 51 | 52 | fix_random(cfg.fix_random) 53 | profiler_start() 54 | 55 | do_train = trainer.train(begin_epoch, train_loader, optims, recorder) 56 | for epoch in range(begin_epoch, cfg.train.epoch): 57 | recorder.epoch = epoch 58 | if cfg.distributed: 59 | train_loader.batch_sampler.sampler.set_epoch(epoch) 60 | 61 | network.train() # select training mode 62 | try: 63 | next(do_train) # actual training 64 | except RuntimeError as e: 65 | print_colorful_stacktrace() 66 | import pdbr; pdbr.post_mortem() # break on the last exception's stack for inspection 67 | scheduler.step() 68 | 69 | if (epoch + 1) % cfg.save_ep == 0: 70 | # for ZeroRedundancyOptimizer, we need to consolidate the gradients before saving the state dicts 71 | save_model(network, optims, scheduler, recorder, cfg.trained_model_dir, epoch, rank=cfg.local_rank, distributed=cfg.distributed) 72 | 73 | if (epoch + 1) % cfg.save_latest_ep == 0: 74 | save_model(network, optims, scheduler, recorder, cfg.trained_model_dir, epoch, latest=True, rank=cfg.local_rank, distributed=cfg.distributed) 75 | 76 | if (epoch + 1) % cfg.eval_ep == 0: 77 | try: 78 | trainer.val(epoch, val_loader, evaluator, recorder) # sometimes during early testing stage, evaluation is not implemented 79 | except Exception as e: 80 | print_colorful_stacktrace() 81 | log(f"exception when evaluating: {type(e)}: {e}, check your eval impl", 'red') 82 | # do not disrupt training even if validation raised an exception 83 | 84 | profiler_stop() 85 | return network 86 | 87 | 88 | def test(cfg, network): 89 | trainer = make_trainer(cfg, network) 90 | val_loader = make_data_loader(cfg, is_train=False) 91 | evaluator = make_evaluator(cfg) 92 | epoch = load_network(network, 93 | cfg.trained_model_dir, 94 | resume=cfg.resume, 95 | epoch=cfg.test.epoch) 96 | trainer.val(epoch, val_loader, evaluator) 97 | 98 | 99 | def synchronize(): 100 | """ 101 | Helper function to synchronize (barrier) among all processes when 102 | 
using distributed training 103 | """ 104 | if not dist.is_available(): 105 | return 106 | if not dist.is_initialized(): 107 | return 108 | world_size = dist.get_world_size() 109 | if world_size == 1: 110 | return 111 | dist.barrier() 112 | 113 | 114 | def main(): 115 | fix_random(cfg.fix_random) 116 | if cfg.distributed: 117 | cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count() 118 | cfg.world_size = int(os.environ['WORLD_SIZE']) 119 | log(f'local rank: {cfg.local_rank}, world_size: {cfg.world_size}') 120 | torch.cuda.set_device(cfg.local_rank) 121 | dist.init_process_group(backend="nccl", init_method="env://") 122 | synchronize() 123 | 124 | network = make_network(cfg) 125 | 126 | if cfg.print_network: 127 | nop = number_of_params(network) 128 | log('') 129 | pprint(network) 130 | log(f'number of parameters: {nop}({nop / 1e6:.2f}M)') 131 | 132 | if cfg.dry_run: 133 | return 134 | 135 | if args.test: 136 | test(cfg, network) 137 | else: 138 | train(cfg, network) 139 | 140 | 141 | if __name__ == "__main__": 142 | if cfg.detect_anomaly: 143 | with torch.autograd.detect_anomaly(): 144 | main() 145 | else: 146 | main() 147 | -------------------------------------------------------------------------------- /lib/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | # everybody loves colorization... right? ... right? 2 | import os 3 | import io 4 | import re 5 | import sys 6 | from typing import List 7 | from collections import deque 8 | from termcolor import colored 9 | 10 | from rich.live import Live 11 | from rich.table import Table 12 | from rich.text import Text 13 | from tqdm import tqdm 14 | 15 | from lib.utils.base_utils import default_dotdict, dotdict 16 | 17 | NoneType = type(None) 18 | 19 | # fmt: off 20 | def red(x): return colored(x, 'red') 21 | def green(x): return colored(x, 'green') 22 | def blue(x): return colored(x, 'blue') 23 | def yellow(x): return colored(x, 'yellow') 24 | def cyan(x): return colored(x, 'cyan') 25 | def magenta(x): return colored(x, 'magenta') 26 | # fmt: on 27 | 28 | 29 | def trim_ansi(a): 30 | ESC = r'\x1b' 31 | CSI = ESC + r'\[' 32 | OSC = ESC + r'\]' 33 | CMD = '[@-~]' 34 | ST = ESC + r'\\' 35 | BEL = r'\x07' 36 | pattern = '(' + CSI + '.*?' + CMD + '|' + OSC + '.*?' 
+ '(' + ST + '|' + BEL + ')' + ')' 37 | return re.sub(pattern, '', a) 38 | 39 | 40 | def stop_live_table(): 41 | if hasattr(update_log_stats, 'live'): 42 | update_log_stats.live.stop() # avoid strange screen duplicates 43 | 44 | 45 | def print_colorful_stacktrace(): 46 | stop_live_table() 47 | from rich.console import Console 48 | console = Console() 49 | console.print_exception() 50 | 51 | 52 | def colored_rgb(fg_color, bg_color, text): 53 | r, g, b = fg_color 54 | result = f'\033[38;2;{r};{g};{b}m{text}' 55 | r, g, b = bg_color 56 | result = f'\033[48;2;{r};{g};{b}m{result}\033[0m' 57 | return result 58 | 59 | 60 | def run_if_not_exists(cmd, outname, *args, **kwargs): 61 | # whether a file exists, whether a directory has more than 3 elements 62 | # if (os.path.exists(outname) and os.path.isfile(outname)) or (os.path.isdir(outname) and len(os.listdir(outname)) >= 3): 63 | if os.path.exists(outname): 64 | log(f'Skip: {cmd}', 'yellow') 65 | else: 66 | run(cmd, *args, **kwargs) 67 | 68 | 69 | def run(cmd, quite=False, dry_run=False): 70 | if isinstance(cmd, list): 71 | cmd = ' '.join(list(map(str, cmd))) 72 | func = sys._getframe(1).f_code.co_name 73 | if not quite: 74 | cmd_color = 'blue' if not cmd.startswith('rm') else 'red' 75 | cmd_color = 'green' if dry_run else cmd_color 76 | dry_msg = colored('[dry_run]: ', 'magenta') if dry_run else '' 77 | log(colored(func, 'yellow') + ": " + dry_msg + colored(cmd, cmd_color)) 78 | if not dry_run: 79 | code = os.system(cmd) 80 | else: 81 | code = 0 82 | if code != 0: 83 | log(colored(str(code), 'red') + " <- " + colored(func, 'yellow') + ": " + colored(cmd, 'red')) 84 | raise RuntimeError(f'{code} <- {func}: {cmd}') 85 | 86 | 87 | def log(msg, color=None, attrs=None, log_file=None): 88 | func = sys._getframe(1).f_code.co_name 89 | frame = sys._getframe(1) 90 | module = frame.f_globals['__name__'] if frame is not None else '' 91 | content = colored(module, 'blue') + " -> " + colored(func, 'green') + ": " + colored(str(msg), color, attrs) 92 | if isinstance(log_file, str): 93 | with open(log_file, 'a+') as f: 94 | f.write(trim_ansi(content) + '\n') 95 | elif isinstance(log_file, io.TextIOWrapper): 96 | log_file.write(trim_ansi(content) + '\n') 97 | tqdm.write(content) # be compatible with existing tqdm loops 98 | 99 | 100 | def create_table(name: str, 101 | columns: List[str], 102 | rows: List[List[str]] = [], 103 | styles: default_dotdict[str, NoneType] = default_dotdict(NoneType), 104 | ): 105 | table = Table(title=name, show_footer=True, show_header=False) 106 | for col in columns: 107 | table.add_column(footer=Text(col, styles[col]), style=styles[col], justify="center") 108 | 109 | for row in rows: 110 | table.add_row(*row) 111 | return table 112 | 113 | 114 | def create_live(*args, **kwargs): 115 | table = create_table(*args, **kwargs) 116 | live = Live(table, auto_refresh=False) 117 | return live 118 | 119 | 120 | def update_log_stats(states: dotdict, 121 | styles: default_dotdict[str, NoneType] = default_dotdict( 122 | NoneType, 123 | { 124 | 'eta': 'cyan', 125 | 'epoch': 'cyan', 126 | 127 | 'img_loss': 'magenta', 128 | 'psnr': 'magenta', 129 | 'loss': 'magenta', 130 | 131 | 'data': 'blue', 132 | 'batch': 'blue', 133 | } 134 | ), 135 | table_row_limit=200, 136 | ): 137 | 138 | name = states.name 139 | del states.name 140 | keys = list(states.keys()) 141 | values = list(map(str, states.values())) 142 | 143 | if not hasattr(update_log_stats, 'live'): 144 | update_log_stats.live = create_live(name, keys, [values], styles) 145 | 
update_log_stats.live.start() 146 | 147 | width, height = os.get_terminal_size() 148 | maxlen = max(min(height - 8, table_row_limit), 1) # 5 would fill the terminal 149 | if not hasattr(update_log_stats, 'rows'): 150 | update_log_stats.rows = deque(maxlen=maxlen) # save space for header and footer 151 | elif update_log_stats.rows.maxlen != maxlen: 152 | update_log_stats.rows = deque(list(update_log_stats.rows)[-maxlen + 1:], maxlen=maxlen) # save space for header and footer 153 | update_log_stats.live.start() 154 | update_log_stats.rows.append(values) 155 | update_log_stats.live.update(create_table(name, keys, update_log_stats.rows, styles), refresh=True) # disabled autorefresh 156 | -------------------------------------------------------------------------------- /lib/train/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import time 3 | import torch 4 | import datetime 5 | import numpy as np 6 | from typing import List 7 | from torch.optim import Optimizer 8 | 9 | from lib.config import cfg 10 | from lib.train.recorder import Recorder 11 | from lib.utils.base_utils import dotdict 12 | from lib.utils.prof_utils import profiler_step 13 | from lib.utils.data_utils import add_iter_step, to_cuda 14 | from lib.utils.log_utils import log, update_log_stats, print_colorful_stacktrace 15 | 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | 19 | class Trainer(object): 20 | def __init__(self, network): 21 | device = torch.device('cuda:{}'.format(cfg.local_rank)) 22 | network = network.to(device) 23 | if cfg.distributed: 24 | network = DDP( 25 | network, 26 | device_ids=[cfg.local_rank], 27 | output_device=cfg.local_rank, 28 | find_unused_parameters=cfg.find_unused_parameters, 29 | ) 30 | self.network = network 31 | self.device = device 32 | self.local_rank = cfg.local_rank 33 | 34 | def reduce_record_stats(self, record_stats: dotdict): 35 | reduced_stats = dotdict() 36 | for k, v in record_stats.items(): 37 | if isinstance(v, torch.Tensor): 38 | reduced_stats[k] = v.item() # MARK: will cause sync 39 | else: 40 | reduced_stats[k] = v 41 | return reduced_stats 42 | 43 | # NOTE: this is now a generator instead of simple function, to preserve the local variables 44 | def train(self, epoch, data_loader, optimizers: List[Optimizer], recorder: Recorder): 45 | ep_iter = cfg.ep_iter 46 | self.network.train() 47 | end = time.perf_counter() 48 | 49 | for index, batch in enumerate(data_loader): 50 | if hasattr(recorder, 'step'): 51 | iteration = recorder.step + 1 52 | else: 53 | iteration = 1 # the start of the training process 54 | 55 | batch = add_iter_step(batch, iteration) # DotDict 56 | batch = to_cuda(batch) # will also return as DotDict 57 | 58 | data_time = time.perf_counter() - end # data time 59 | output, loss, record_stats, image_stats = self.network(batch) 60 | 61 | for optimizer in optimizers: 62 | optimizer.zero_grad(set_to_none=True) 63 | loss = loss.mean() 64 | loss.backward() # where the actual work is done 65 | 66 | torch.nn.utils.clip_grad_norm_(self.network.parameters(), cfg.clip_grad_norm) 67 | torch.nn.utils.clip_grad_value_(self.network.parameters(), cfg.clip_grad_value) 68 | 69 | for optimizer in optimizers: 70 | optimizer.step() # all optimizers on all GPUs should perform backward step? 
71 | 72 | if cfg.local_rank > 0: 73 | if iteration % ep_iter == 0: 74 | yield 75 | continue 76 | 77 | recorder.step += 1 78 | if iteration % cfg.log_interval == 0 or iteration % ep_iter == 0: 79 | record_stats = self.reduce_record_stats(record_stats) 80 | 81 | # data recording stage: loss_stats, time, image_stats 82 | batch_time = time.perf_counter() - end 83 | lr = optimizer.param_groups[0]['lr'] 84 | max_mem = torch.cuda.max_memory_allocated() / 2**20 85 | 86 | record_stats.data = data_time 87 | record_stats.batch = batch_time 88 | record_stats.lr = lr 89 | record_stats.max_mem = max_mem 90 | recorder.update_record_stats(record_stats) 91 | 92 | eta_seconds = recorder.record_stats.batch.global_avg * (cfg.train.epoch * ep_iter - recorder.step) 93 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 94 | log_stats = dotdict() 95 | log_stats.name = cfg.exp_name 96 | log_stats.eta = eta_string 97 | log_stats.update(recorder.log_stats) 98 | 99 | update_log_stats(log_stats, table_row_limit=cfg.table_row_limit) 100 | 101 | if iteration % cfg.record_interval == 0 or iteration % ep_iter == 0: 102 | # record loss_stats and image_dict 103 | recorder.update_image_stats(image_stats) 104 | recorder.record('train') 105 | 106 | profiler_step() 107 | end = time.perf_counter() 108 | 109 | if iteration % ep_iter == 0: 110 | yield 111 | 112 | def val(self, epoch, data_loader, evaluator=None, recorder=None): 113 | self.network.eval() 114 | val_loss_stats = {} 115 | data_size = len(data_loader) 116 | for batch in tqdm.tqdm(data_loader): 117 | batch = add_iter_step(batch, epoch * cfg.ep_iter) # DotDict 118 | batch = to_cuda(batch) 119 | with torch.no_grad(): 120 | 121 | output, loss, loss_stats, image_stats = self.network(batch) 122 | 123 | if evaluator is not None: 124 | evaluator.evaluate(output, batch) 125 | 126 | loss_stats = self.reduce_record_stats(loss_stats) 127 | for k, v in loss_stats.items(): 128 | val_loss_stats.setdefault(k, 0) 129 | val_loss_stats[k] += v 130 | 131 | loss_state = [] 132 | for k in val_loss_stats.keys(): 133 | val_loss_stats[k] /= data_size 134 | loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k])) 135 | log(loss_state) 136 | 137 | if evaluator is not None: 138 | result = evaluator.summarize() 139 | val_loss_stats.update(result) 140 | 141 | if recorder: 142 | recorder.record('val', epoch, val_loss_stats, image_stats) 143 | 144 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2024 (Highlight)] Relightable and Animatable Neural Avatar from Sparse-View Video 2 | 3 | [Paper](https://arxiv.org/abs/2308.07903) | [Project Page](https://zju3dv.github.io/relightable_avatar) | [Video](https://youtu.be/BQ3pL7Uwbdk) 4 | 5 | 6 | 7 | ![teaser_video](https://github.com/dendenxu/relightable_avatar/assets/43734697/874231a3-3366-4d2f-a081-05ccf05e4096) 8 | 9 | This paper tackles the challenge of creating relightable and animatable neural avatars from sparse-view (or even monocular) videos of dynamic humans under unknown illumination. 10 | Compared to studio environments, this setting is more practical and accessible but poses an extremely challenging ill-posed problem. 
11 | 12 | 13 | ## Quick Start 14 | 15 | ### Prepare Trained Model 16 | 17 | We provide an example trained model for the *xuzhen* sequence of the MobileStage dataset: 18 | - The base `AniSDF` model can be downloaded from here: [anisdf.zip](https://github.com/zju3dv/RelightableAvatar/files/15020629/anisdf.zip). 19 | - The `RelightableAvatar` model can be downloaded from here: [relightable.zip](https://github.com/zju3dv/RelightableAvatar/files/15020631/relightable.zip). 20 | - Furthermore, you'll need to download a skeleton dataset (very small, containing only the basic information needed to run `relightable_avatar`) here: [minimal.tar.gz](https://github.com/zju3dv/RelightableAvatar/files/15020633/minimal.tar.gz). 21 | - The skeleton dataset is only required if the full dataset hasn't been downloaded and placed at its corresponding location. 22 | - For relighting, we also provide the downscaled environment map: [16x32.zip](https://github.com/zju3dv/RelightableAvatar/files/15020635/16x32.zip). If you see errors about `data/lighting`, download this. 23 | 24 | Trained model and skeleton data placement: 25 | - The base `AniSDF` model should be put in `data/trained_model/deform/xuzhen_12v_geo`, after which we expect `latest.pth` to reside at `data/trained_model/deform/xuzhen_12v_geo/latest.pth`. 26 | - The `RelightableAvatar` model should be put in `data/trained_model/relight/xuzhen_12v_geo_fix_mat`, after which we expect `latest.pth` to reside at `data/trained_model/relight/xuzhen_12v_geo_fix_mat/latest.pth`. 27 | - The skeleton dataset should be extracted at `data/mobile_stage/xuzhen`, leading to `data/mobile_stage/xuzhen/motion.npz...`. 28 | - The environment map should be placed at `data/lighting`, after which a `data/lighting/16x32` folder is expected. 29 | 30 | ### Prepare Custom Pose 31 | 32 | For the human pose, we use a compact `motion.npz` to store the pose, shape and global translation parameters. 33 | You can find an example file at `data/mobile_stage/xuzhen/motion.npz`. 34 | If you've downloaded the skeleton data provided above, you should also see other motion files ending with `.npz`. 35 | 36 | We also provide a script for preparing other common motion formats into our `motion.npz` structure at `scripts/tools/prepare_motion.py`. 37 | You can learn more about the structure of `motion.npz` in this script. 38 | 39 | ### Run the AniSDF Model With Custom Pose 40 | 41 | ```shell 42 | # Fixed view + novel pose 43 | python run.py -t visualize -c configs/mobile_stage/xuzhen_12v_geo.yaml ground_attach_envmap False vis_pose_sequence True num_eval_frame 100 H 512 W 512 novel_view_ixt_ratio 0.80 vis_ext .png test_view 0, test_motion gPO_sFM_cAll_d12_mPO1_ch16.npz 44 | 45 | # Novel rotating view + novel pose 46 | python run.py -t visualize -c configs/mobile_stage/xuzhen_12v_geo.yaml ground_attach_envmap False vis_novel_view True perform True num_render_view 100 H 512 W 512 novel_view_ixt_ratio 0.80 vis_ext .png test_motion gPO_sFM_cAll_d12_mPO1_ch16.npz 47 | 48 | # For faster rendering, use sphere tracing instead of volume rendering by adding the `vis_sphere_tracing True` entry 49 | # Will speed up the rendering, but might produce artifacts 50 | ``` 51 | 52 | Try to tune these entries `H 512 W 512 novel_view_ixt_ratio 0.80` to customize your output image. 53 | Moreover, select the source view using `test_view 0,` and the motion using `test_motion gPO_sFM_cAll_d12_mPO1_ch16.npz`. 
54 | `num_eval_frame` and `num_render_view` control the number of rendered images for the novel pose and novel view setting, respectively. 55 | 56 | Example motions files are provided at `data/mobile_stage/xuzhen/*.npz`. 57 | To use skeleton data, customize your dataset root using `test_dataset.data_root `. 58 | The recommended way of switching to another set of motions is to put the prepared motion file into `` (wherever the `test_dataset.data_root` points to) and set `test_motion`. 59 | You can also use `test_motion` to specify a motion file outside the dataset root by providing an absolute path to the motion file. 60 | 61 | ### Run the Relightable Model With Custom Pose 62 | 63 | ```shell 64 | python run.py -t visualize -c configs/mobile_stage/xuzhen_12v_geo.yaml relighting True vis_novel_light True vis_pose_sequence True vis_rendering_map True vis_shading_map True vis_albedo_map True vis_normal_map True vis_envmap_map True vis_roughness_map True vis_specular_map True vis_surface_map True vis_residual_map True vis_depth_map True num_eval_frame 100 H 512 W 512 novel_view_ixt_ratio 0.80 vis_ext .png vis_ground_shading True test_light '["main", "venetian_crossroads", "pink_sunrise", "shanghai_bund", "venice_sunrise", "quattro_canti", "olat0002-0027", "olat0004-0019"]' test_view 0, extra_prefix "gPO_sFM_cAll_d12_mPO1_ch16" test_motion gPO_sFM_cAll_d12_mPO1_ch16.npz 65 | ``` 66 | 67 | ## Todo 68 | 69 | - [ ] Add documentation on training on the SyntheticHuman++ dataset 70 | - [ ] Add documentation on training on the MobileStage dataset 71 | 72 | ## Citation 73 | 74 | If you find this code useful for your research, please cite us using the following BibTeX entry. 75 | 76 | ```bibtex 77 | @inproceedings{xu2024relightable, 78 | title={Relightable and Animatable Neural Avatar from Sparse-View Video}, 79 | author={Xu, Zhen and Peng, Sida and Geng, Chen and Mou, Linzhan and Yan, Zihan and Sun, Jiaming and Bao, Hujun and Zhou, Xiaowei}, 80 | booktitle={CVPR}, 81 | year={2024} 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /lib/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from torch.utils.data.sampler import BatchSampler 3 | import numpy as np 4 | import torch 5 | import math 6 | import torch.distributed as dist 7 | from lib.config import cfg 8 | 9 | 10 | 11 | class ImageSizeBatchSampler(Sampler): 12 | def __init__(self, sampler, batch_size, drop_last, sampler_meta): 13 | self.sampler = sampler 14 | self.batch_size = batch_size 15 | self.drop_last = drop_last 16 | self.strategy = sampler_meta.strategy 17 | self.hmin, self.wmin = sampler_meta.min_hw 18 | self.hmax, self.wmax = sampler_meta.max_hw 19 | self.divisor = 32 20 | 21 | def generate_height_width(self): 22 | if self.strategy == 'origin': 23 | return -1, -1 24 | h = np.random.randint(self.hmin, self.hmax + 1) 25 | w = np.random.randint(self.wmin, self.wmax + 1) 26 | h = (h | (self.divisor - 1)) + 1 27 | w = (w | (self.divisor - 1)) + 1 28 | return h, w 29 | 30 | def __iter__(self): 31 | batch = [] 32 | h, w = self.generate_height_width() 33 | for idx in self.sampler: 34 | batch.append((idx, h, w)) 35 | if len(batch) == self.batch_size: 36 | h, w = self.generate_height_width() 37 | yield batch 38 | batch = [] 39 | if len(batch) > 0 and not self.drop_last: 40 | yield batch 41 | 42 | def __len__(self): 43 | if self.drop_last: 44 | return len(self.sampler) // self.batch_size 45 | 
else: 46 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 47 | 48 | 49 | class IterationBasedBatchSampler(BatchSampler): 50 | """ 51 | Wraps a BatchSampler, resampling from it until 52 | a specified number of iterations have been sampled 53 | """ 54 | 55 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 56 | self.batch_sampler = batch_sampler 57 | self.sampler = self.batch_sampler.sampler 58 | self.num_iterations = num_iterations 59 | self.start_iter = start_iter 60 | 61 | def __iter__(self): 62 | iteration = self.start_iter 63 | while iteration <= self.num_iterations: 64 | for batch in self.batch_sampler: 65 | iteration += 1 66 | if iteration > self.num_iterations: 67 | break 68 | yield batch 69 | 70 | def __len__(self): 71 | return self.num_iterations 72 | 73 | 74 | class DistributedSampler(Sampler): 75 | """Sampler that restricts data loading to a subset of the dataset. 76 | It is especially useful in conjunction with 77 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 78 | process can pass a DistributedSampler instance as a DataLoader sampler, 79 | and load a subset of the original dataset that is exclusive to it. 80 | .. note:: 81 | Dataset is assumed to be of constant size. 82 | Arguments: 83 | dataset: Dataset used for sampling. 84 | num_replicas (optional): Number of processes participating in 85 | distributed training. 86 | rank (optional): Rank of the current process within num_replicas. 87 | """ 88 | 89 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 90 | if num_replicas is None: 91 | if not dist.is_available(): 92 | raise RuntimeError("Requires distributed package to be available") 93 | num_replicas = dist.get_world_size() 94 | if rank is None: 95 | if not dist.is_available(): 96 | raise RuntimeError("Requires distributed package to be available") 97 | rank = dist.get_rank() 98 | self.dataset = dataset 99 | self.num_replicas = num_replicas 100 | self.rank = rank 101 | self.epoch = 0 102 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 103 | self.total_size = self.num_samples * self.num_replicas 104 | self.shuffle = shuffle 105 | 106 | def __iter__(self): 107 | if self.shuffle: 108 | # deterministically shuffle based on epoch 109 | g = torch.Generator() 110 | g.manual_seed(self.epoch) 111 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 112 | else: 113 | indices = torch.arange(len(self.dataset)).tolist() 114 | 115 | # add extra samples to make it evenly divisible 116 | indices += indices[: (self.total_size - len(indices))] 117 | assert len(indices) == self.total_size 118 | 119 | # subsample 120 | offset = self.num_samples * self.rank 121 | indices = indices[offset:offset+self.num_samples] 122 | assert len(indices) == self.num_samples 123 | 124 | return iter(indices) 125 | 126 | def __len__(self): 127 | return self.num_samples 128 | 129 | def set_epoch(self, epoch): 130 | self.epoch = epoch 131 | 132 | 133 | class FrameSampler(Sampler): 134 | """Sampler certain frames for test 135 | """ 136 | 137 | def __init__(self, dataset): 138 | inds = np.arange(0, len(dataset)) 139 | ni = len(dataset) // dataset.num_cams 140 | inds = inds.reshape(ni, -1)[::cfg.test.frame_sampler_interval, ::cfg.test.view_sampler_interval][cfg.begin_ith_latent:] 141 | self.inds = inds.ravel() 142 | 143 | def __iter__(self): 144 | return iter(self.inds) 145 | 146 | def __len__(self): 147 | return len(self.inds) 148 | 149 | 150 | class MeshFrameSampler(FrameSampler): 151 | 
"""Sampler certain frames for test, including a -1 frame for original mesh 152 | """ 153 | 154 | def __init__(self, dataset): 155 | super(MeshFrameSampler, self).__init__(dataset) 156 | if cfg.vis_can_mesh: 157 | self.inds = [-1] 158 | elif cfg.vis_tpose_mesh and cfg.track_tpose_mesh: 159 | self.inds = np.concatenate([[-1], self.inds]) 160 | -------------------------------------------------------------------------------- /lib/networks/relight/relight_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from functools import partial 5 | 6 | from lib.config import cfg 7 | from lib.utils.base_utils import dotdict 8 | from lib.utils.data_utils import load_mesh 9 | from lib.utils.blend_utils import pose_dirs_to_tpose_dirs, pose_dirs_to_world_dirs, tpose_dirs_to_pose_dirs 10 | from lib.utils.net_utils import normalize, MLP, load_network, freeze_module, make_params, make_buffer, take_gradient, multi_scatter, batch_aware_indexing, multi_gather, unfreeze_module, multi_scatter_, GradModule, torch_dot 11 | from lib.utils.relight_utils import Microfacet, gen_light_xyz 12 | from lib.networks.deform import base_network 13 | from lib.networks.embedder import get_embedder 14 | from lib.utils.log_utils import log 15 | 16 | 17 | class Network(base_network.Network): 18 | def __init__( 19 | self, 20 | fresnel_f0=cfg.fresnel_f0, 21 | geometry_mesh=cfg.geometry_mesh, 22 | geometry_pretrain=cfg.geometry_pretrain, 23 | xyz_res=cfg.relight_xyz_res, 24 | view_res=cfg.relight_view_res, # are there still view encoding needs? 25 | env_h=cfg.env_h, 26 | env_w=cfg.env_w, 27 | env_r=cfg.env_r, 28 | achro_light=cfg.achro_light, 29 | xyz_noise_std=cfg.xyz_noise_std, 30 | *args, 31 | **kwargs, 32 | ): 33 | super(Network, self).__init__(*args, **kwargs) # do not inherite network structure 34 | 35 | # load the pretrained model of the network 36 | load_network(self, geometry_pretrain, strict=False) 37 | freeze_module(self.render_network) 38 | 39 | self.xyz_embedder, self.xyz_dim = get_embedder(xyz_res, 3) # no parameters 40 | self.view_embedder, self.view_dim = get_embedder(view_res, 3) # no parameters 41 | 42 | self.prepare_relight_network() 43 | self.prepare_relight_metadata() 44 | 45 | def prepare_relight_network(self): 46 | self.albedo_network = MLP(input_ch=self.feature_dim, W=cfg.relight_network_width, D=cfg.relight_network_depth, out_ch=3, actvn=nn.Softplus(beta=100), out_actvn=lambda x: cfg.albedo_slope * torch.sigmoid(x) + cfg.albedo_bias, init=nn.init.kaiming_normal_) # albedo and roughness can only be in [0, 1] 47 | self.roughness_network = MLP(input_ch=self.feature_dim, W=cfg.relight_network_width, D=cfg.relight_network_depth, out_ch=1, actvn=nn.Softplus(beta=100), out_actvn=lambda x: cfg.roughness_slope * torch.sigmoid(x) + cfg.roughness_bias, init=nn.init.kaiming_normal_) # albedo and roughness can only be in [0, 1] 48 | 49 | def prepare_relight_metadata(self, 50 | fresnel_f0=cfg.fresnel_f0, 51 | env_h=cfg.env_h, 52 | env_w=cfg.env_w, 53 | env_r=cfg.env_r, 54 | achro_light=cfg.achro_light, 55 | xyz_noise_std=cfg.xyz_noise_std, 56 | lambert_only=cfg.lambert_only, 57 | glossy_only=cfg.glossy_only, 58 | ): 59 | 60 | self.microfacet = Microfacet(f0=fresnel_f0, lambert_only=lambert_only, glossy_only=glossy_only) # no parameters 61 | 62 | # use an optimizable fixed environment lighting system for now 63 | if achro_light: 64 | self.global_env_map_ = make_params(torch.rand(env_h * 
cfg.envmap_upscale, env_w * cfg.envmap_upscale, 1) * cfg.envmap_init_intensity) 65 | else: 66 | self.global_env_map_ = make_params(torch.rand(env_h * cfg.envmap_upscale, env_w * cfg.envmap_upscale, 3) * cfg.envmap_init_intensity) 67 | 68 | xyz, area = gen_light_xyz(env_h, env_w, env_r, device='cpu') # eH, eW, 3; eH, eW 69 | sharp = 1 / (area / torch.pi).sqrt() # as in tangent, H, W, how much was obsecured 70 | self.light_xyz_ = make_buffer(xyz) 71 | self.light_area = make_buffer(area) 72 | self.light_sharp = make_buffer(sharp) 73 | 74 | # other regularization related configuration entry 75 | self.xyz_noise_std = xyz_noise_std 76 | self.achro_light = achro_light 77 | self.env_h, self.env_w, self.env_r = env_h, env_w, env_r 78 | 79 | @property 80 | def light_xyz(self): 81 | if self.training: 82 | return self.light_xyz_ + torch.randn_like(self.light_xyz_) * cfg.light_xyz_noise_std 83 | else: 84 | return self.light_xyz_ 85 | 86 | @property 87 | def global_env_map(self): 88 | # TODO: fix this ugly impl 89 | return F.softplus(self.global_env_map_.expand(*self.global_env_map_.shape[:2], 3)) 90 | 91 | def forward(self, x: torch.Tensor, v: torch.Tensor, d: torch.Tensor, batch: dotdict): # NOTE: viewdirection is not used here 92 | ret, out = self.forward_geometry(x, None, d, batch) 93 | 94 | # first we need to get the albedo and specular parameters from the MLP 95 | # ebd = self.xyz_embedder(out.cpts) # albedo lives in canonical 96 | # input = torch.cat([ebd, out.feat], dim=-1) 97 | albedo = self.albedo_network(out.feat) 98 | roughness = self.roughness_network(out.feat) 99 | 100 | # apply back the previously done filtering 101 | raw = torch.cat([albedo, roughness, out.norm, out.occ], dim=-1) 102 | if not self.training: 103 | raw = torch.cat([out.cpts, out.bpts, out.resd, raw], dim=-1) 104 | ret.raw = multi_scatter(raw.new_zeros(*x.shape[:-1], raw.shape[-1]), out.inds, raw) 105 | 106 | # apply regularization 107 | if self.training: 108 | # define smoothness loss on jitter points 109 | cpts: torch.Tensor = out.cpts # B, P, 3 110 | xyz_noise = torch.normal(mean=0, std=self.xyz_noise_std, size=cpts.shape, device=cpts.device) 111 | xyz_jitter = cpts + xyz_noise 112 | # ebd_jitter = self.xyz_embedder(xyz_jitter) # albedo lives in canonical 113 | feat_jitter = self.signed_distance_network.feat(xyz_jitter) 114 | # input = torch.cat([ebd_jitter, feat_jitter], dim=-1) 115 | ret.albedo = albedo 116 | ret.roughness = roughness 117 | ret.albedo_jitter = self.albedo_network(feat_jitter) 118 | ret.roughness_jitter = self.roughness_network(feat_jitter) 119 | 120 | return ret 121 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | # Task Configuration 2 | task: deform 3 | exp_name: base_my_313 4 | 5 | # Module Configuration 6 | train_dataset_module: lib.datasets.base_dataset 7 | test_dataset_module: lib.datasets.base_dataset 8 | network_module: lib.networks.deform.base_network 9 | renderer_module: lib.networks.renderer.base_renderer 10 | trainer_module: lib.train.trainers.base_trainer 11 | evaluator_module: lib.evaluators.base_evaluator 12 | visualizer_module: lib.visualizers.base_visualizer 13 | 14 | # Data Configuration 15 | training_view: [0, 3, 6, 9, 12, 15, 18] # try 7 views 16 | # prettier-ignore 17 | test_view: [] # to get a consistent visualization 18 | 19 | ratio: 1.0 20 | frame_interval: 1 21 | begin_ith_frame: 0 22 | num_train_frame: 300 23 | num_eval_frame: 600 24 | 
num_render_view: 300 25 | 26 | train_dataset: 27 | data_root: data/my_zju_mocap/my_313 28 | human: my_313 29 | ann_file: annots.npy 30 | split: train 31 | test_dataset: 32 | data_root: data/my_zju_mocap/my_313 33 | human: my_313 34 | ann_file: annots.npy 35 | split: test 36 | 37 | # Mask Configuration 38 | mask: mask 39 | erode_dilate_mask: False 40 | 41 | # SMPL & Pose related configuration 42 | train_motion: motion.npz # relative to data root 43 | test_motion: motion.npz # relative to data root 44 | body_model: easymocap/output-output-smpl-3d/cfg_model.yml # controls all previous parameters 45 | 46 | # Network Configuration 47 | xyz_res: 10 48 | sdf_res: 8 49 | view_res: 4 50 | 51 | # Loss Configuration 52 | resd_loss_weight: 0.1 53 | img_loss_weight: 1.0 54 | eikonal_loss_weight: 0.01 # smoother canonical mesh 55 | observed_eikonal_loss_weight: 0.005 # smoother residual deformation -> also smoother canonical mesh 56 | msk_loss_weight: 0.01 57 | 58 | # Training Configuration 59 | train: 60 | batch_size: 4 # ddp 8 61 | collator: '' 62 | lr: 5e-4 63 | weight_decay: 0.0 64 | epoch: 400 65 | scheduler: 66 | type: exponential 67 | gamma: 0.1 68 | decay_epochs: 400 69 | num_workers: 4 # avoid excessive memory usage 70 | sampler: RandomSampler 71 | test: 72 | sampler: FrameSampler 73 | frame_sampler_interval: 30 74 | batch_size: 1 75 | collator: '' 76 | 77 | n_rays: 1024 # 32 x 32? 78 | n_samples: 128 79 | save_ep: 50 80 | eval_ep: 400 81 | ep_iter: 500 82 | log_interval: 1 83 | save_latest_ep: 1 84 | record_interval: 1 85 | 86 | # Threshold Configuration 87 | norm_th: 0.1 88 | dist_th: 0.1 89 | surf_reg_th: 0.02 90 | clip_near: 0.02 91 | perturb: 1.0 92 | bg_brightness: 0.0 93 | 94 | # Chunkify Configuration 95 | train_chunk_size: 4096 96 | render_chunk_size: 8192 97 | network_chunk_size: 262144 98 | voxel_size: [0.005, 0.005, 0.005] 99 | 100 | # Visualization Configuration 101 | pose_seq_cfg: 102 | train_dataset_module: lib.datasets.pose_dataset 103 | test_dataset_module: lib.datasets.pose_dataset 104 | visualizer_module: lib.visualizers.pose_visualizer 105 | test: 106 | frame_sampler_interval: 1 107 | view_sampler_interval: 1 108 | test_view: [0] 109 | 110 | novel_view_cfg: 111 | train_dataset_module: lib.datasets.demo_dataset 112 | test_dataset_module: lib.datasets.demo_dataset 113 | visualizer_module: lib.visualizers.demo_visualizer 114 | test: 115 | frame_sampler_interval: 1 116 | view_sampler_interval: 1 117 | 118 | mesh_cfg: 119 | renderer_module: lib.networks.renderer.mesh_renderer 120 | train_dataset_module: lib.datasets.mesh_dataset 121 | test_dataset_module: lib.datasets.mesh_dataset 122 | evaluator_module: lib.evaluators.mesh_evaluator 123 | visualizer_module: lib.visualizers.mesh_visualizer 124 | test: 125 | sampler: MeshFrameSampler 126 | view_sampler_interval: 1 127 | frame_sampler_interval: 100 128 | dist_th: 0.1 129 | mesh_th: 0.5 130 | mesh_th_to_sdf: True 131 | 132 | sphere_tracing_cfg: # will force a sphere tracing renderer 133 | n_samples: 3 134 | render_chunk_size: 65536 135 | network_chunk_size: 1048576 136 | renderer_module: lib.networks.renderer.sphere_tracing_renderer 137 | 138 | relighting_cfg: 139 | # Experiment Configuration 140 | task: relight 141 | exp_name: relight_my_313 142 | geometry_mesh: data/animation/deform/base_my_313/can_mesh.npz 143 | geometry_pretrain: data/trained_model/deform/base_my_313 144 | 145 | # Module Configuration 146 | trainer_module: lib.train.trainers.relight_trainer # loss computation 147 | network_module: 
lib.networks.relight.relight_network # chunk of the network 148 | renderer_module: lib.networks.renderer.sphere_tracing_renderer # general purpose rendering (no params) 149 | 150 | # Training Configuration 151 | train: 152 | # torchrun --nproc_per_node=2 train.py -c configs/relight_my_313.yaml distributed True train.batch_size 4 153 | batch_size: 2 # typical batch size is 8, use 2x ddp for that 154 | epoch: 100 155 | lr: 5.0e-3 # turns out this is important for reconstructing a good environment map 156 | scheduler: 157 | type: exponential 158 | # type: warmup_exponential 159 | # warmup_epochs: 2 160 | # warmup_factor: 0.1 161 | # warmup_method: linear 162 | gamma: 0.1 163 | decay_epochs: 100 164 | lr_table: 165 | residual_deformation_network: 5.0e-6 # base geometry should not change much 166 | signed_distance_network: 5.0e-6 167 | roughness_network: 5.0e-5 # roughness should be learned slowly, more learning rate tuning 168 | # albedo_network: 5.0e-4 # albedo should be learned slowly, more learning rate tuning 169 | n_samples: 3 # 3 surface volume samples 170 | render_chunk_size: 65536 171 | network_chunk_size: 1048576 # large render_chunk size? 172 | eval_ep: 100 173 | save_ep: 10 174 | 175 | # Loss Configuration 176 | albedo_sparsity: 5.0e-5 177 | albedo_smooth_weight: 5.0e-3 178 | roughness_smooth_weight: 5.0e-5 179 | img_loss_weight: 10.0 180 | eikonal_loss_weight: 0.05 # larger values for better regularization 181 | observed_eikonal_loss_weight: 0.025 182 | msk_loss_weight: 0.1 # avoid worse geometry 183 | 184 | mesh_simp_face: -1 185 | mesh_th_to_sdf: False 186 | mesh_th: 0.0 187 | 188 | # Visualization Configuration 189 | novel_view_cfg: 190 | renderer_module: lib.networks.renderer.sphere_tracing_renderer # general purpose rendering (no params) 191 | pose_seq_cfg: 192 | renderer_module: lib.networks.renderer.sphere_tracing_renderer # general purpose rendering (no params) 193 | novel_light_cfg: # will use a relighting renderer to store relighting results 194 | renderer_module: lib.networks.renderer.novel_light_sphere_tracing 195 | visualizer_module: lib.visualizers.light_visualizer # do not try to remove visualizer 196 | # test_dataset_module: lib.datasets.pose_dataset 197 | # test: 198 | # frame_sampler_interval: 1 199 | # perform: True 200 | # num_eval_frame: 1 201 | # num_render_view: 1 202 | # test_view: [0] 203 | --------------------------------------------------------------------------------
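
For reference, the two-stage workflow implied by `relighting_cfg` above can be sketched as follows. This is a minimal, hypothetical example: it assumes that the `relighting True` override switches `train.py` to `relighting_cfg` in the same way it switches `run.py` to the relighting renderer in `readme.md`, and it reuses the command conventions from the `torchrun` comment inside `relighting_cfg`; the exact config names and overrides (e.g. the `_fix_mat` variant of the released trained model) may differ in practice.

```shell
# Stage 1: train the base AniSDF geometry model
# (trained models are expected under data/trained_model/deform/, see readme.md)
python train.py -c configs/mobile_stage/xuzhen_12v_geo.yaml

# Stage 2: train the relightable model on top of the pretrained geometry
# (assumes relighting True activates relighting_cfg during training)
python train.py -c configs/mobile_stage/xuzhen_12v_geo.yaml relighting True

# Multi-GPU variant, following the torchrun comment inside relighting_cfg
torchrun --nproc_per_node=2 train.py -c configs/mobile_stage/xuzhen_12v_geo.yaml relighting True distributed True train.batch_size 4
```

Whichever configuration is used, `geometry_pretrain` should point at the directory produced by stage 1 (e.g. `data/trained_model/deform/xuzhen_12v_geo`), matching the trained-model placement described in `readme.md`.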