├── .github └── workflows │ └── stale.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── bash ├── assets │ ├── checksum.json │ └── urls │ │ ├── cropped_images.txt │ │ ├── feat.txt │ │ ├── images.txt │ │ ├── mano.txt │ │ ├── misc.txt │ │ ├── mocap.txt │ │ ├── models.txt │ │ ├── smplx.txt │ │ └── splits.txt ├── clean_downloads.sh ├── download_baselines.sh ├── download_body_models.sh ├── download_cropped_images.sh ├── download_dry_run.sh ├── download_feat.sh ├── download_images.sh ├── download_misc.sh ├── download_mocap.sh └── download_splits.sh ├── common ├── .gitignore ├── ___init___.py ├── abstract_pl.py ├── args_utils.py ├── body_models.py ├── camera.py ├── comet_utils.py ├── data_utils.py ├── ld_utils.py ├── list_utils.py ├── mesh.py ├── metrics.py ├── np_utils.py ├── object_tensors.py ├── pl_utils.py ├── rend_utils.py ├── rot.py ├── sys_utils.py ├── thing.py ├── torch_utils.py ├── transforms.py ├── viewer.py ├── vis_utils.py └── xdict.py ├── docs ├── data │ ├── README.md │ ├── data_doc.md │ ├── mano_right.png │ ├── processing.md │ └── visualize.md ├── faq.md ├── leaderboard.md ├── leaderboard_format.md ├── model │ ├── README.md │ ├── extraction.md │ └── train.md ├── purchase.md ├── setup.md ├── static │ ├── aitviewer-logo.svg │ ├── arctic-logo.svg │ ├── dexterous.gif │ ├── hold │ │ ├── mug_ours.gif │ │ └── mug_ref.png │ ├── misalignment.png │ ├── teaser.jpeg │ └── viewer_demo.gif └── stock_photos │ ├── .DS_Store │ ├── box.jpg │ ├── coffee_machine.jpg │ ├── expresso_machine.jpg │ ├── google.png │ ├── ketchup.jpg │ ├── laptop.jpg │ ├── microwave.jpg │ ├── mixer.jpg │ ├── notebook.jpg │ ├── phone.jpg │ ├── scissors.jpg │ └── waffleiron.jpg ├── requirements.txt ├── scripts_data ├── build_splits.py ├── checksum.py ├── crop_images.py ├── download_data.py ├── mocap_viewer.py ├── process_seqs.py ├── unzip_download.py └── visualizer.py ├── scripts_method ├── build_feat_split.py ├── evaluate_metrics.py ├── extract_predicts.py ├── train.py └── visualizer.py └── src ├── arctic ├── preprocess_dataset.py ├── processing.py └── split.py ├── callbacks ├── loss │ ├── loss_arctic_lstm.py │ ├── loss_arctic_sf.py │ └── loss_field.py ├── process │ ├── process_arctic.py │ ├── process_field.py │ └── process_generic.py └── vis │ ├── visualize_arctic.py │ └── visualize_field.py ├── datasets ├── arctic_dataset.py ├── arctic_dataset_eval.py ├── dataset_utils.py ├── tempo_dataset.py ├── tempo_inference_dataset.py └── tempo_inference_dataset_eval.py ├── extraction ├── interface.py └── keys │ ├── eval_field.py │ ├── eval_pose.py │ ├── feat_field.py │ ├── feat_pose.py │ ├── submit_field.py │ ├── submit_pose.py │ ├── vis_field.py │ └── vis_pose.py ├── factory.py ├── mesh_loaders ├── arctic.py ├── field.py └── pose.py ├── models ├── __init__.py ├── arctic_lstm │ ├── model.py │ └── wrapper.py ├── arctic_sf │ ├── model.py │ └── wrapper.py ├── field_lstm │ ├── model.py │ └── wrapper.py ├── field_sf │ ├── model.py │ └── wrapper.py └── generic │ └── wrapper.py ├── nets ├── backbone │ ├── __init__.py │ ├── resnet.py │ └── utils.py ├── hand_heads │ ├── hand_hmr.py │ └── mano_head.py ├── hmr_layer.py ├── obj_heads │ ├── obj_head.py │ └── obj_hmr.py └── pointnet.py ├── parsers ├── configs │ ├── arctic_lstm.py │ ├── arctic_sf.py │ ├── field_lstm.py │ ├── field_sf.py │ └── generic.py ├── generic_parser.py └── parser.py └── utils ├── const.py ├── eval_modules.py ├── interfield.py ├── loss_modules.py └── mdev.py /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 
1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '39 22 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v5 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'Stale issue message' 25 | stale-pr-message: 'Stale pull request message' 26 | stale-issue-label: 'no-issue-activity' 27 | stale-pr-label: 'no-pr-activity' 28 | days-before-stale: 30 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | /data 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean fix imports sort 2 | 3 | ## Delete all compiled Python files 4 | clean: 5 | find . -type f -name "*.py[co]" -delete 6 | find . -type d -name "__pycache__" -delete 7 | rm -rf condor_logs/* 8 | find run_scripts/* -delete 9 | find logs/* -delete 10 | fix: 11 | black src common scripts_method scripts_data 12 | sort: 13 | isort src common scripts_method scripts_data --wrap-length=1 --combine-as --trailing-comma --use-parentheses 14 | imports: 15 | autoflake -i -r --remove-all-unused-imports src common scripts_method scripts_data 16 | -------------------------------------------------------------------------------- /bash/assets/urls/feat.txt: -------------------------------------------------------------------------------- 1 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/feat.zip -------------------------------------------------------------------------------- /bash/assets/urls/mano.txt: -------------------------------------------------------------------------------- 1 | https://download.is.tue.mpg.de/download.php?domain=mano&resume=1&sfile=mano_v1_2.zip -------------------------------------------------------------------------------- /bash/assets/urls/misc.txt: -------------------------------------------------------------------------------- 1 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/splits_json.zip 2 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/raw_seqs.zip 3 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/meta.zip 4 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/images_zips/backgrounds.zip 5 | -------------------------------------------------------------------------------- /bash/assets/urls/mocap.txt: -------------------------------------------------------------------------------- 1 | 
https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/mocap/mocap_c3d.zip 2 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/mocap/mocap_npy.zip 3 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/mocap/smplx_corres.zip 4 | -------------------------------------------------------------------------------- /bash/assets/urls/models.txt: -------------------------------------------------------------------------------- 1 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/models.zip 2 | -------------------------------------------------------------------------------- /bash/assets/urls/smplx.txt: -------------------------------------------------------------------------------- 1 | https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip -------------------------------------------------------------------------------- /bash/assets/urls/splits.txt: -------------------------------------------------------------------------------- 1 | https://download.is.tue.mpg.de/download.php?domain=arctic&resume=1&sfile=arctic_release/c7216c3b205186106a1f8326ed7b948f838e4907e69b21c8b3c87bb69d87206e/v1_0/data/splits.zip -------------------------------------------------------------------------------- /bash/clean_downloads.sh: -------------------------------------------------------------------------------- 1 | find downloads unpack render_out outputs -delete # clear dry run data 2 | -------------------------------------------------------------------------------- /bash/download_baselines.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading model weights" 5 | mkdir -p downloads/ 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/models.txt --out_folder downloads 7 | -------------------------------------------------------------------------------- /bash/download_body_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading SMPLX" 5 | mkdir -p downloads 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/smplx.txt --out_folder downloads 7 | unzip downloads/models_smplx_v1_1.zip 8 | mv models body_models 9 | 10 | echo "Downloading MANO" 11 | python scripts_data/download_data.py --url_file ./bash/assets/urls/mano.txt --out_folder downloads 12 | 13 | mkdir -p unpack 14 | cd downloads 15 | unzip mano_v1_2.zip 16 | mv mano_v1_2/models ../body_models/mano 17 | cd .. 
18 | mv body_models unpack 19 | -------------------------------------------------------------------------------- /bash/download_cropped_images.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading cropped images" 5 | mkdir -p downloads/data/cropped_images_zips 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/cropped_images.txt --out_folder downloads/data/cropped_images_zips 7 | -------------------------------------------------------------------------------- /bash/download_dry_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading smaller files" 5 | mkdir -p downloads/data 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/misc.txt --out_folder downloads/data --dry_run 7 | 8 | echo "Downloading model weights" 9 | mkdir -p downloads/ 10 | python scripts_data/download_data.py --url_file ./bash/assets/urls/models.txt --out_folder downloads --dry_run 11 | 12 | echo "Downloading cropped images" 13 | mkdir -p downloads/data/cropped_images_zips 14 | python scripts_data/download_data.py --url_file ./bash/assets/urls/cropped_images.txt --out_folder downloads/data/cropped_images_zips --dry_run 15 | 16 | echo "Downloading full resolution images" 17 | mkdir -p downloads/data/images_zips 18 | python scripts_data/download_data.py --url_file ./bash/assets/urls/images.txt --out_folder downloads/data/images_zips --dry_run 19 | 20 | echo "Downloading SMPLX" 21 | mkdir -p downloads 22 | python scripts_data/download_data.py --url_file ./bash/assets/urls/smplx.txt --out_folder downloads 23 | unzip downloads/models_smplx_v1_1.zip 24 | mv models body_models 25 | 26 | echo "Downloading MANO" 27 | python scripts_data/download_data.py --url_file ./bash/assets/urls/mano.txt --out_folder downloads 28 | 29 | mkdir unpack 30 | cd downloads 31 | unzip mano_v1_2.zip 32 | mv mano_v1_2/models ../body_models/mano 33 | cd .. 
34 | mv body_models unpack 35 | -------------------------------------------------------------------------------- /bash/download_feat.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading features files" 2 | mkdir -p downloads/data 3 | python scripts_data/download_data.py --url_file ./bash/assets/urls/feat.txt --out_folder downloads/data -------------------------------------------------------------------------------- /bash/download_images.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading full resolution images" 5 | mkdir -p downloads/data/images_zips 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/images.txt --out_folder downloads/data/images_zips 7 | -------------------------------------------------------------------------------- /bash/download_misc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading smaller files" 5 | mkdir -p downloads/data 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/misc.txt --out_folder downloads/data 7 | -------------------------------------------------------------------------------- /bash/download_mocap.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading features files" 2 | mkdir -p downloads/data 3 | python scripts_data/download_data.py --url_file ./bash/assets/urls/mocap.txt --out_folder downloads/data 4 | -------------------------------------------------------------------------------- /bash/download_splits.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "Downloading preprocessed splits" 5 | mkdir -p downloads/data 6 | python scripts_data/download_data.py --url_file ./bash/assets/urls/splits.txt --out_folder downloads/data -------------------------------------------------------------------------------- /common/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /common/___init___.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/common/___init___.py -------------------------------------------------------------------------------- /common/args_utils.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | 4 | def set_default_params(args, default_args): 5 | # if a val is not set on argparse, use default val 6 | # else, use the one in the argparse 7 | custom_dict = {} 8 | for key, val in args.items(): 9 | if val is None: 10 | args[key] = default_args[key] 11 | else: 12 | custom_dict[key] = val 13 | 14 | logger.info(f"Using custom values: {custom_dict}") 15 | return args 16 | -------------------------------------------------------------------------------- /common/body_models.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import torch 5 | from smplx import MANO 6 | 7 | from common.mesh import Mesh 8 | 9 | 10 | class MANODecimator: 11 | def __init__(self): 12 | data = np.load( 13 | "./data/arctic_data/data/meta/mano_decimator_195.npy", allow_pickle=True 14 | 
).item() 15 | mydata = {} 16 | for key, val in data.items(): 17 | # only consider decimation matrix so far 18 | if "D" in key: 19 | mydata[key] = torch.FloatTensor(val) 20 | self.data = mydata 21 | 22 | def downsample(self, verts, is_right): 23 | dev = verts.device 24 | flag = "right" if is_right else "left" 25 | if self.data[f"D_{flag}"].device != dev: 26 | self.data[f"D_{flag}"] = self.data[f"D_{flag}"].to(dev) 27 | D = self.data[f"D_{flag}"] 28 | batch_size = verts.shape[0] 29 | D_batch = D[None, :, :].repeat(batch_size, 1, 1) 30 | verts_sub = torch.bmm(D_batch, verts) 31 | return verts_sub 32 | 33 | 34 | MODEL_DIR = "./data/body_models/mano" 35 | 36 | SEAL_FACES_R = [ 37 | [120, 108, 778], 38 | [108, 79, 778], 39 | [79, 78, 778], 40 | [78, 121, 778], 41 | [121, 214, 778], 42 | [214, 215, 778], 43 | [215, 279, 778], 44 | [279, 239, 778], 45 | [239, 234, 778], 46 | [234, 92, 778], 47 | [92, 38, 778], 48 | [38, 122, 778], 49 | [122, 118, 778], 50 | [118, 117, 778], 51 | [117, 119, 778], 52 | [119, 120, 778], 53 | ] 54 | 55 | # vertex ids around the ring of the wrist 56 | CIRCLE_V_ID = np.array( 57 | [108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120], 58 | dtype=np.int64, 59 | ) 60 | 61 | 62 | def seal_mano_mesh(v3d, faces, is_rhand): 63 | # v3d: B, 778, 3 64 | # faces: 1538, 3 65 | # output: v3d(B, 779, 3); faces (1554, 3) 66 | 67 | seal_faces = torch.LongTensor(np.array(SEAL_FACES_R)).to(faces.device) 68 | if not is_rhand: 69 | # left hand 70 | seal_faces = seal_faces[:, np.array([1, 0, 2])] # invert face normal 71 | centers = v3d[:, CIRCLE_V_ID].mean(dim=1)[:, None, :] 72 | sealed_vertices = torch.cat((v3d, centers), dim=1) 73 | faces = torch.cat((faces, seal_faces), dim=0) 74 | return sealed_vertices, faces 75 | 76 | 77 | def build_layers(device=None): 78 | from common.object_tensors import ObjectTensors 79 | 80 | layers = { 81 | "right": build_mano_aa(True), 82 | "left": build_mano_aa(False), 83 | "object_tensors": ObjectTensors(), 84 | } 85 | 86 | if device is not None: 87 | layers["right"] = layers["right"].to(device) 88 | layers["left"] = layers["left"].to(device) 89 | layers["object_tensors"].to(device) 90 | return layers 91 | 92 | 93 | MANO_MODEL_DIR = "./data/body_models/mano" 94 | SMPLX_MODEL_P = { 95 | "male": "./data/body_models/smplx/SMPLX_MALE.npz", 96 | "female": "./data/body_models/smplx/SMPLX_FEMALE.npz", 97 | "neutral": "./data/body_models/smplx/SMPLX_NEUTRAL.npz", 98 | } 99 | 100 | 101 | def build_smplx(batch_size, gender, vtemplate): 102 | import smplx 103 | 104 | subj_m = smplx.create( 105 | model_path=SMPLX_MODEL_P[gender], 106 | model_type="smplx", 107 | gender=gender, 108 | num_pca_comps=45, 109 | v_template=vtemplate, 110 | flat_hand_mean=True, 111 | use_pca=False, 112 | batch_size=batch_size, 113 | # batch_size=320, 114 | ) 115 | return subj_m 116 | 117 | 118 | def build_subject_smplx(batch_size, subject_id): 119 | with open("./data/arctic_data/data/meta/misc.json", "r") as f: 120 | misc = json.load(f) 121 | vtemplate_p = f"./data/arctic_data/data/meta/subject_vtemplates/{subject_id}.obj" 122 | mesh = Mesh(filename=vtemplate_p) 123 | vtemplate = mesh.v 124 | gender = misc[subject_id]["gender"] 125 | return build_smplx(batch_size, gender, vtemplate) 126 | 127 | 128 | def build_mano_aa(is_rhand, create_transl=False, flat_hand=False): 129 | return MANO( 130 | MODEL_DIR, 131 | create_transl=create_transl, 132 | use_pca=False, 133 | flat_hand_mean=flat_hand, 134 | is_rhand=is_rhand, 135 | ) 136 | 137 | 138 | def construct_layers(dev): 139 | 
mano_layers = { 140 | "right": build_mano_aa(True, create_transl=True, flat_hand=False), 141 | "left": build_mano_aa(False, create_transl=True, flat_hand=False), 142 | "smplx": build_smplx(1, "neutral", None), 143 | } 144 | for layer in mano_layers.values(): 145 | layer.to(dev) 146 | return mano_layers 147 | -------------------------------------------------------------------------------- /common/comet_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as op 4 | import time 5 | 6 | import comet_ml 7 | import numpy as np 8 | import torch 9 | from loguru import logger 10 | from tqdm import tqdm 11 | 12 | from src.datasets.dataset_utils import copy_repo_arctic 13 | 14 | # folder used for debugging 15 | DUMMY_EXP = "xxxxxxxxx" 16 | 17 | 18 | def add_paths(args): 19 | exp_key = args.exp_key 20 | args_p = f"./logs/{exp_key}/args.json" 21 | ckpt_p = f"./logs/{exp_key}/checkpoints/last.ckpt" 22 | if not op.exists(ckpt_p) or DUMMY_EXP in ckpt_p: 23 | ckpt_p = "" 24 | if args.resume_ckpt != "": 25 | ckpt_p = args.resume_ckpt 26 | args.ckpt_p = ckpt_p 27 | args.log_dir = f"./logs/{exp_key}" 28 | 29 | if args.infer_ckpt != "": 30 | basedir = "/".join(args.infer_ckpt.split("/")[:2]) 31 | basename = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt") 32 | args.interface_p = op.join(basedir, basename) 33 | args.args_p = args_p 34 | if args.cluster: 35 | args.run_p = op.join(args.log_dir, "condor", "run.sh") 36 | args.submit_p = op.join(args.log_dir, "condor", "submit.sub") 37 | args.repo_p = op.join(args.log_dir, "repo") 38 | 39 | return args 40 | 41 | 42 | def save_args(args, save_keys): 43 | args_save = {} 44 | for key, val in args.items(): 45 | if key in save_keys: 46 | args_save[key] = val 47 | with open(args.args_p, "w") as f: 48 | json.dump(args_save, f, indent=4) 49 | logger.info(f"Saved args at {args.args_p}") 50 | 51 | 52 | def create_files(args): 53 | os.makedirs(args.log_dir, exist_ok=True) 54 | if args.cluster: 55 | os.makedirs(op.dirname(args.run_p), exist_ok=True) 56 | copy_repo_arctic(args.exp_key) 57 | 58 | 59 | def log_exp_meta(args): 60 | tags = [args.method] 61 | logger.info(f"Experiment tags: {tags}") 62 | args.experiment.set_name(args.exp_key) 63 | args.experiment.add_tags(tags) 64 | args.experiment.log_parameters(args) 65 | 66 | 67 | def init_experiment(args): 68 | if args.resume_ckpt != "": 69 | args.exp_key = args.resume_ckpt.split("/")[1] 70 | if args.fast_dev_run: 71 | args.exp_key = DUMMY_EXP 72 | if args.exp_key == "": 73 | args.exp_key = generate_exp_key() 74 | args = add_paths(args) 75 | if op.exists(args.args_p) and args.exp_key not in [DUMMY_EXP]: 76 | with open(args.args_p, "r") as f: 77 | args_disk = json.load(f) 78 | if "comet_key" in args_disk.keys(): 79 | args.comet_key = args_disk["comet_key"] 80 | 81 | create_files(args) 82 | 83 | project_name = args.project 84 | disabled = args.mute 85 | comet_url = args["comet_key"] if "comet_key" in args.keys() else None 86 | 87 | api_key = os.environ["COMET_API_KEY"] 88 | workspace = os.environ["COMET_WORKSPACE"] 89 | if not args.cluster: 90 | if comet_url is None: 91 | experiment = comet_ml.Experiment( 92 | api_key=api_key, 93 | workspace=workspace, 94 | project_name=project_name, 95 | disabled=disabled, 96 | display_summary_level=0, 97 | ) 98 | args.comet_key = experiment.get_key() 99 | else: 100 | experiment = comet_ml.ExistingExperiment( 101 | previous_experiment=comet_url, 102 | api_key=api_key, 103 | project_name=project_name, 
104 | workspace=workspace, 105 | disabled=disabled, 106 | display_summary_level=0, 107 | ) 108 | 109 | device = "cuda" if torch.cuda.is_available() else "cpu" 110 | logger.add( 111 | os.path.join(args.log_dir, "train.log"), 112 | level="INFO", 113 | colorize=True, 114 | ) 115 | logger.info(torch.cuda.get_device_properties(device)) 116 | args.gpu = torch.cuda.get_device_properties(device).name 117 | else: 118 | experiment = None 119 | args.experiment = experiment 120 | return experiment, args 121 | 122 | 123 | def log_dict(experiment, metric_dict, step, postfix=None): 124 | if experiment is None: 125 | return 126 | for key, value in metric_dict.items(): 127 | if postfix is not None: 128 | key = key + postfix 129 | if isinstance(value, torch.Tensor) and len(value.view(-1)) == 1: 130 | value = value.item() 131 | 132 | if isinstance(value, (int, float, np.float32)): 133 | experiment.log_metric(key, value, step=step) 134 | 135 | 136 | def generate_exp_key(): 137 | import random 138 | 139 | hash = random.getrandbits(128) 140 | key = "%032x" % (hash) 141 | key = key[:9] 142 | return key 143 | 144 | 145 | def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True): 146 | if verbose: 147 | print("Pushing PIL images") 148 | tic = time.time() 149 | iterator = all_im_list if no_tqdm else tqdm(all_im_list) 150 | for im in iterator: 151 | im_np = np.array(im["im"]) 152 | if "fig_name" in im.keys(): 153 | experiment.log_image(im_np, im["fig_name"], step=global_step) 154 | else: 155 | experiment.log_image(im_np, "unnamed", step=global_step) 156 | if verbose: 157 | toc = time.time() 158 | print("Done pushing PIL images (%.1fs)" % (toc - tic)) 159 | -------------------------------------------------------------------------------- /common/ld_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sort_dict(disordered): 8 | sorted_dict = {k: disordered[k] for k in sorted(disordered)} 9 | return sorted_dict 10 | 11 | 12 | def prefix_dict(mydict, prefix): 13 | out = {prefix + k: v for k, v in mydict.items()} 14 | return out 15 | 16 | 17 | def postfix_dict(mydict, postfix): 18 | out = {k + postfix: v for k, v in mydict.items()} 19 | return out 20 | 21 | 22 | def unsort(L, sort_idx): 23 | assert isinstance(sort_idx, list) 24 | assert isinstance(L, list) 25 | LL = zip(sort_idx, L) 26 | LL = sorted(LL, key=lambda x: x[0]) 27 | _, L = zip(*LL) 28 | return list(L) 29 | 30 | 31 | def cat_dl(out_list, dim, verbose=True, squeeze=True): 32 | out = {} 33 | for key, val in out_list.items(): 34 | if isinstance(val[0], torch.Tensor): 35 | out[key] = torch.cat(val, dim=dim) 36 | if squeeze: 37 | out[key] = out[key].squeeze() 38 | elif isinstance(val[0], np.ndarray): 39 | out[key] = np.concatenate(val, axis=dim) 40 | if squeeze: 41 | out[key] = np.squeeze(out[key]) 42 | elif isinstance(val[0], list): 43 | out[key] = sum(val, []) 44 | else: 45 | if verbose: 46 | print(f"Ignoring {key} undefined type {type(val[0])}") 47 | return out 48 | 49 | 50 | def stack_dl(out_list, dim, verbose=True, squeeze=True): 51 | out = {} 52 | for key, val in out_list.items(): 53 | if isinstance(val[0], torch.Tensor): 54 | out[key] = torch.stack(val, dim=dim) 55 | if squeeze: 56 | out[key] = out[key].squeeze() 57 | elif isinstance(val[0], np.ndarray): 58 | out[key] = np.stack(val, axis=dim) 59 | if squeeze: 60 | out[key] = np.squeeze(out[key]) 61 | elif isinstance(val[0], list): 62 | out[key] = sum(val, []) 63 
| else: 64 | out[key] = val 65 | if verbose: 66 | print(f"Processing {key} undefined type {type(val[0])}") 67 | return out 68 | 69 | 70 | def add_prefix_postfix(mydict, prefix="", postfix=""): 71 | assert isinstance(mydict, dict) 72 | return dict((prefix + key + postfix, value) for (key, value) in mydict.items()) 73 | 74 | 75 | def ld2dl(LD): 76 | assert isinstance(LD, list) 77 | assert isinstance(LD[0], dict) 78 | """ 79 | A list of dict (same keys) to a dict of lists 80 | """ 81 | dict_list = {k: [dic[k] for dic in LD] for k in LD[0]} 82 | return dict_list 83 | 84 | 85 | class NameSpace(object): 86 | def __init__(self, adict): 87 | self.__dict__.update(adict) 88 | 89 | 90 | def dict2ns(mydict): 91 | """ 92 | Convert dict objec to namespace 93 | """ 94 | return NameSpace(mydict) 95 | 96 | 97 | def ld2dev(ld, dev): 98 | """ 99 | Convert tensors in a list or dict to a device recursively 100 | """ 101 | if isinstance(ld, torch.Tensor): 102 | return ld.to(dev) 103 | if isinstance(ld, dict): 104 | for k, v in ld.items(): 105 | ld[k] = ld2dev(v, dev) 106 | return ld 107 | if isinstance(ld, list): 108 | return [ld2dev(x, dev) for x in ld] 109 | return ld 110 | 111 | 112 | def all_comb_dict(hyper_dict): 113 | assert isinstance(hyper_dict, dict) 114 | keys, values = zip(*hyper_dict.items()) 115 | permute_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)] 116 | return permute_dicts 117 | -------------------------------------------------------------------------------- /common/list_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def chunks_by_len(L, n): 5 | """ 6 | Split a list into n chunks 7 | """ 8 | num_chunks = int(math.ceil(float(len(L)) / n)) 9 | splits = [L[x : x + num_chunks] for x in range(0, len(L), num_chunks)] 10 | return splits 11 | 12 | 13 | def chunks_by_size(L, n): 14 | """Yield successive n-sized chunks from lst.""" 15 | seqs = [] 16 | for i in range(0, len(L), n): 17 | seqs.append(L[i : i + n]) 18 | return seqs 19 | 20 | 21 | def unsort(L, sort_idx): 22 | assert isinstance(sort_idx, list) 23 | assert isinstance(L, list) 24 | LL = zip(sort_idx, L) 25 | LL = sorted(LL, key=lambda x: x[0]) 26 | _, L = zip(*LL) 27 | return list(L) 28 | 29 | 30 | def add_prefix_postfix(mydict, prefix="", postfix=""): 31 | assert isinstance(mydict, dict) 32 | return dict((prefix + key + postfix, value) for (key, value) in mydict.items()) 33 | 34 | 35 | def ld2dl(LD): 36 | assert isinstance(LD, list) 37 | assert isinstance(LD[0], dict) 38 | """ 39 | A list of dict (same keys) to a dict of lists 40 | """ 41 | dict_list = {k: [dic[k] for dic in LD] for k in LD[0]} 42 | return dict_list 43 | 44 | 45 | def chunks(lst, n): 46 | """Yield successive n-sized chunks from lst.""" 47 | seqs = [] 48 | for i in range(0, len(lst), n): 49 | seqs.append(lst[i : i + n]) 50 | seqs_chunked = sum(seqs, []) 51 | assert set(seqs_chunked) == set(lst) 52 | return seqs 53 | -------------------------------------------------------------------------------- /common/mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | 4 | colors = { 5 | "pink": [1.00, 0.75, 0.80], 6 | "purple": [0.63, 0.13, 0.94], 7 | "red": [1.0, 0.0, 0.0], 8 | "green": [0.0, 1.0, 0.0], 9 | "yellow": [1.0, 1.0, 0], 10 | "brown": [1.00, 0.25, 0.25], 11 | "blue": [0.0, 0.0, 1.0], 12 | "white": [1.0, 1.0, 1.0], 13 | "orange": [1.00, 0.65, 0.00], 14 | "grey": [0.75, 0.75, 0.75], 15 | "black": [0.0, 0.0, 
0.0], 16 | } 17 | 18 | 19 | class Mesh(trimesh.Trimesh): 20 | def __init__( 21 | self, 22 | filename=None, 23 | v=None, 24 | f=None, 25 | vc=None, 26 | fc=None, 27 | process=False, 28 | visual=None, 29 | **kwargs 30 | ): 31 | if filename is not None: 32 | mesh = trimesh.load(filename, process=process) 33 | v = mesh.vertices 34 | f = mesh.faces 35 | visual = mesh.visual 36 | 37 | super(Mesh, self).__init__( 38 | vertices=v, faces=f, visual=visual, process=process, **kwargs 39 | ) 40 | 41 | self.v = self.vertices 42 | self.f = self.faces 43 | assert self.v is self.vertices 44 | assert self.f is self.faces 45 | 46 | if vc is not None: 47 | self.set_vc(vc) 48 | self.vc = self.visual.vertex_colors 49 | assert self.vc is self.visual.vertex_colors 50 | if fc is not None: 51 | self.set_fc(fc) 52 | self.fc = self.visual.face_colors 53 | assert self.fc is self.visual.face_colors 54 | 55 | def rot_verts(self, vertices, rxyz): 56 | return np.array(vertices * rxyz.T) 57 | 58 | def colors_like(self, color, array, ids): 59 | color = np.array(color) 60 | 61 | if color.max() <= 1.0: 62 | color = color * 255 63 | color = color.astype(np.int8) 64 | 65 | n_color = color.shape[0] 66 | n_ids = ids.shape[0] 67 | 68 | new_color = np.array(array) 69 | if n_color <= 4: 70 | new_color[ids, :n_color] = np.repeat(color[np.newaxis], n_ids, axis=0) 71 | else: 72 | new_color[ids, :] = color 73 | 74 | return new_color 75 | 76 | def set_vc(self, vc, vertex_ids=None): 77 | all_ids = np.arange(self.vertices.shape[0]) 78 | if vertex_ids is None: 79 | vertex_ids = all_ids 80 | 81 | vertex_ids = all_ids[vertex_ids] 82 | new_vc = self.colors_like(vc, self.visual.vertex_colors, vertex_ids) 83 | self.visual.vertex_colors[:] = new_vc 84 | 85 | def set_fc(self, fc, face_ids=None): 86 | if face_ids is None: 87 | face_ids = np.arange(self.faces.shape[0]) 88 | 89 | new_fc = self.colors_like(fc, self.visual.face_colors, face_ids) 90 | self.visual.face_colors[:] = new_fc 91 | 92 | @staticmethod 93 | def cat(meshes): 94 | return trimesh.util.concatenate(meshes) 95 | -------------------------------------------------------------------------------- /common/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def compute_v2v_dist_no_reduce(v3d_cam_gt, v3d_cam_pred, is_valid): 8 | assert isinstance(v3d_cam_gt, list) 9 | assert isinstance(v3d_cam_pred, list) 10 | assert len(v3d_cam_gt) == len(v3d_cam_pred) 11 | assert len(v3d_cam_gt) == len(is_valid) 12 | v2v = [] 13 | for v_gt, v_pred, valid in zip(v3d_cam_gt, v3d_cam_pred, is_valid): 14 | if valid: 15 | dist = ((v_gt - v_pred) ** 2).sum(dim=1).sqrt().cpu().numpy() # meter 16 | else: 17 | dist = None 18 | v2v.append(dist) 19 | return v2v 20 | 21 | 22 | def compute_joint3d_error(joints3d_cam_gt, joints3d_cam_pred, valid_jts): 23 | valid_jts = valid_jts.view(-1) 24 | assert joints3d_cam_gt.shape == joints3d_cam_pred.shape 25 | assert joints3d_cam_gt.shape[0] == valid_jts.shape[0] 26 | dist = ((joints3d_cam_gt - joints3d_cam_pred) ** 2).sum(dim=2).sqrt() 27 | invalid_idx = torch.nonzero((1 - valid_jts).long()).view(-1) 28 | dist[invalid_idx, :] = float("nan") 29 | dist = dist.cpu().numpy() 30 | return dist 31 | 32 | 33 | def compute_mrrpe(root_r_gt, root_l_gt, root_r_pred, root_l_pred, is_valid): 34 | rel_vec_gt = root_l_gt - root_r_gt 35 | rel_vec_pred = root_l_pred - root_r_pred 36 | 37 | invalid_idx = torch.nonzero((1 - is_valid).long()).view(-1) 38 | mrrpe = ((rel_vec_pred - 
rel_vec_gt) ** 2).sum(dim=1).sqrt() 39 | mrrpe[invalid_idx] = float("nan") 40 | mrrpe = mrrpe.cpu().numpy() 41 | return mrrpe 42 | 43 | 44 | def compute_arti_deg_error(pred_radian, gt_radian): 45 | assert pred_radian.shape == gt_radian.shape 46 | 47 | # articulation error in degree 48 | pred_degree = pred_radian / math.pi * 180 # degree 49 | gt_degree = gt_radian / math.pi * 180 # degree 50 | err_deg = torch.abs(pred_degree - gt_degree).tolist() 51 | return np.array(err_deg, dtype=np.float32) 52 | -------------------------------------------------------------------------------- /common/np_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def permute_np(x, idx): 5 | original_perm = tuple(range(len(x.shape))) 6 | x = np.moveaxis(x, original_perm, idx) 7 | return x 8 | -------------------------------------------------------------------------------- /common/pl_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import torch 5 | 6 | import common.thing as thing 7 | from common.ld_utils import ld2dl 8 | 9 | 10 | def reweight_loss_by_keys(loss_dict, keys, alpha): 11 | for key in keys: 12 | val, weight = loss_dict[key] 13 | weight_new = weight * alpha 14 | loss_dict[key] = (val, weight_new) 15 | return loss_dict 16 | 17 | 18 | def select_loss_group(groups, agent_id, alphas): 19 | random.seed(1) 20 | random.shuffle(groups) 21 | 22 | keys = groups[agent_id % len(groups)] 23 | 24 | random.seed(time.time()) 25 | alpha = random.choice(alphas) 26 | random.seed(1) 27 | return keys, alpha 28 | 29 | 30 | def push_checkpoint_metric(key, val): 31 | val = float(val) 32 | checkpt_metric = torch.FloatTensor([val]) 33 | result = {key: checkpt_metric} 34 | return result 35 | 36 | 37 | def avg_losses_cpu(outputs): 38 | outputs = ld2dl(outputs) 39 | for key, val in outputs.items(): 40 | val = [v.cpu() for v in val] 41 | val = torch.cat(val, dim=0).view(-1) 42 | outputs[key] = val.mean() 43 | return outputs 44 | 45 | 46 | def reform_outputs(out_list): 47 | out_list_dict = ld2dl(out_list) 48 | outputs = ld2dl(out_list_dict["out_dict"]) 49 | losses = ld2dl(out_list_dict["loss"]) 50 | 51 | for k, tensor in outputs.items(): 52 | if isinstance(tensor[0], list): 53 | outputs[k] = sum(tensor, []) 54 | else: 55 | outputs[k] = torch.cat(tensor) 56 | 57 | for k, tensor in losses.items(): 58 | tensor = [ten.view(-1) for ten in tensor] 59 | losses[k] = torch.cat(tensor) 60 | 61 | outputs = {k: thing.thing2np(v) for k, v in outputs.items()} 62 | loss_dict = {k: v.mean().item() for k, v in losses.items()} 63 | return outputs, loss_dict 64 | -------------------------------------------------------------------------------- /common/rend_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import numpy as np 5 | import pyrender 6 | import trimesh 7 | 8 | # offline rendering 9 | os.environ["PYOPENGL_PLATFORM"] = "egl" 10 | 11 | 12 | def flip_meshes(meshes): 13 | rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]) 14 | for mesh in meshes: 15 | mesh.apply_transform(rot) 16 | return meshes 17 | 18 | 19 | def color2material(mesh_color: list): 20 | material = pyrender.MetallicRoughnessMaterial( 21 | metallicFactor=0.1, 22 | alphaMode="OPAQUE", 23 | baseColorFactor=( 24 | mesh_color[0] / 255.0, 25 | mesh_color[1] / 255.0, 26 | mesh_color[2] / 255.0, 27 | 0.5, 28 | ), 29 | ) 30 | return 
material 31 | 32 | 33 | class Renderer: 34 | def __init__(self, img_res: int) -> None: 35 | self.renderer = pyrender.OffscreenRenderer( 36 | viewport_width=img_res, viewport_height=img_res, point_size=1.0 37 | ) 38 | 39 | self.img_res = img_res 40 | 41 | def render_meshes_pose( 42 | self, 43 | meshes, 44 | image=None, 45 | cam_transl=None, 46 | cam_center=None, 47 | K=None, 48 | materials=None, 49 | sideview_angle=None, 50 | ): 51 | # unpack 52 | if cam_transl is not None: 53 | cam_trans = np.copy(cam_transl) 54 | cam_trans[0] *= -1.0 55 | else: 56 | cam_trans = None 57 | meshes = copy.deepcopy(meshes) 58 | meshes = flip_meshes(meshes) 59 | 60 | if sideview_angle is not None: 61 | # center around the final mesh 62 | anchor_mesh = meshes[-1] 63 | center = anchor_mesh.vertices.mean(axis=0) 64 | 65 | rot = trimesh.transformations.rotation_matrix( 66 | np.radians(sideview_angle), [0, 1, 0] 67 | ) 68 | out_meshes = [] 69 | for mesh in copy.deepcopy(meshes): 70 | mesh.vertices -= center 71 | mesh.apply_transform(rot) 72 | mesh.vertices += center 73 | # further away to see more 74 | mesh.vertices += np.array([0, 0, -0.10]) 75 | out_meshes.append(mesh) 76 | meshes = out_meshes 77 | 78 | # setting up 79 | self.create_scene() 80 | self.setup_light() 81 | self.position_camera(cam_trans, K) 82 | if materials is not None: 83 | meshes = [ 84 | pyrender.Mesh.from_trimesh(mesh, material=material) 85 | for mesh, material in zip(meshes, materials) 86 | ] 87 | else: 88 | meshes = [pyrender.Mesh.from_trimesh(mesh) for mesh in meshes] 89 | 90 | for mesh in meshes: 91 | self.scene.add(mesh) 92 | 93 | color, valid_mask = self.render_rgb() 94 | if image is None: 95 | output_img = color[:, :, :3] 96 | else: 97 | output_img = self.overlay_image(color, valid_mask, image) 98 | rend_img = (output_img * 255).astype(np.uint8) 99 | return rend_img 100 | 101 | def render_rgb(self): 102 | color, rend_depth = self.renderer.render( 103 | self.scene, flags=pyrender.RenderFlags.RGBA 104 | ) 105 | color = color.astype(np.float32) / 255.0 106 | valid_mask = (rend_depth > 0)[:, :, None] 107 | return color, valid_mask 108 | 109 | def overlay_image(self, color, valid_mask, image): 110 | output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image 111 | return output_img 112 | 113 | def position_camera(self, cam_transl, K): 114 | camera_pose = np.eye(4) 115 | if cam_transl is not None: 116 | camera_pose[:3, 3] = cam_transl 117 | 118 | fx = K[0, 0] 119 | fy = K[1, 1] 120 | cx = K[0, 2] 121 | cy = K[1, 2] 122 | camera = pyrender.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy) 123 | self.scene.add(camera, pose=camera_pose) 124 | 125 | def setup_light(self): 126 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1) 127 | light_pose = np.eye(4) 128 | 129 | light_pose[:3, 3] = np.array([0, -1, 1]) 130 | self.scene.add(light, pose=light_pose) 131 | 132 | light_pose[:3, 3] = np.array([0, 1, 1]) 133 | self.scene.add(light, pose=light_pose) 134 | 135 | light_pose[:3, 3] = np.array([1, 1, 2]) 136 | self.scene.add(light, pose=light_pose) 137 | 138 | def create_scene(self): 139 | self.scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5)) 140 | -------------------------------------------------------------------------------- /common/sys_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import shutil 4 | from glob import glob 5 | 6 | from loguru import logger 7 | 8 | 9 | def copy(src, dst): 10 | if os.path.islink(src): 11 | linkto = 
os.readlink(src) 12 | os.symlink(linkto, dst) 13 | else: 14 | if os.path.isdir(src): 15 | shutil.copytree(src, dst) 16 | else: 17 | shutil.copy(src, dst) 18 | 19 | 20 | def copy_repo(src_files, dst_folder, filter_keywords): 21 | src_files = [ 22 | f for f in src_files if not any(keyword in f for keyword in filter_keywords) 23 | ] 24 | dst_files = [op.join(dst_folder, op.basename(f)) for f in src_files] 25 | for src_f, dst_f in zip(src_files, dst_files): 26 | logger.info(f"FROM: {src_f}\nTO:{dst_f}") 27 | copy(src_f, dst_f) 28 | 29 | 30 | def mkdir(directory): 31 | if not os.path.exists(directory): 32 | os.makedirs(directory) 33 | 34 | 35 | def mkdir_p(exp_path): 36 | os.makedirs(exp_path, exist_ok=True) 37 | 38 | 39 | def count_files(path): 40 | """ 41 | Non-recursively count number of files in a folder. 42 | """ 43 | files = glob(path) 44 | return len(files) 45 | -------------------------------------------------------------------------------- /common/thing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | """ 5 | This file stores functions for conversion between numpy and torch, torch, list, etc. 6 | Also deal with general operations such as to(dev), detach, etc. 7 | """ 8 | 9 | 10 | def thing2list(thing): 11 | if isinstance(thing, torch.Tensor): 12 | return thing.tolist() 13 | if isinstance(thing, np.ndarray): 14 | return thing.tolist() 15 | if isinstance(thing, dict): 16 | return {k: thing2list(v) for k, v in md.items()} 17 | if isinstance(thing, list): 18 | return [thing2list(ten) for ten in thing] 19 | return thing 20 | 21 | 22 | def thing2dev(thing, dev): 23 | if hasattr(thing, "to"): 24 | thing = thing.to(dev) 25 | return thing 26 | if isinstance(thing, list): 27 | return [thing2dev(ten, dev) for ten in thing] 28 | if isinstance(thing, tuple): 29 | return tuple(thing2dev(list(thing), dev)) 30 | if isinstance(thing, dict): 31 | return {k: thing2dev(v, dev) for k, v in thing.items()} 32 | if isinstance(thing, torch.Tensor): 33 | return thing.to(dev) 34 | return thing 35 | 36 | 37 | def thing2np(thing): 38 | if isinstance(thing, list): 39 | return np.array(thing) 40 | if isinstance(thing, torch.Tensor): 41 | return thing.cpu().detach().numpy() 42 | if isinstance(thing, dict): 43 | return {k: thing2np(v) for k, v in thing.items()} 44 | return thing 45 | 46 | 47 | def thing2torch(thing): 48 | if isinstance(thing, list): 49 | return torch.tensor(np.array(thing)) 50 | if isinstance(thing, np.ndarray): 51 | return torch.from_numpy(thing) 52 | if isinstance(thing, dict): 53 | return {k: thing2torch(v) for k, v in thing.items()} 54 | return thing 55 | 56 | 57 | def detach_thing(thing): 58 | if isinstance(thing, torch.Tensor): 59 | return thing.cpu().detach() 60 | if isinstance(thing, list): 61 | return [detach_thing(ten) for ten in thing] 62 | if isinstance(thing, tuple): 63 | return tuple(detach_thing(list(thing))) 64 | if isinstance(thing, dict): 65 | return {k: detach_thing(v) for k, v in thing.items()} 66 | return thing 67 | -------------------------------------------------------------------------------- /common/vis_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.cm as cm 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from PIL import Image 5 | 6 | # connection between the 8 points of 3d bbox 7 | BONES_3D_BBOX = [ 8 | (0, 1), 9 | (1, 2), 10 | (2, 3), 11 | (3, 0), 12 | (0, 4), 13 | (1, 5), 14 | (2, 6), 15 | (3, 7), 16 | (4, 5), 
17 | (5, 6), 18 | (6, 7), 19 | (7, 4), 20 | ] 21 | 22 | 23 | def plot_2d_bbox(bbox_2d, bones, color, ax): 24 | if ax is None: 25 | axx = plt 26 | else: 27 | axx = ax 28 | colors = cm.rainbow(np.linspace(0, 1, len(bbox_2d))) 29 | for pt, c in zip(bbox_2d, colors): 30 | axx.scatter(pt[0], pt[1], color=c, s=50) 31 | 32 | if bones is None: 33 | bones = BONES_3D_BBOX 34 | for bone in bones: 35 | sidx, eidx = bone 36 | # bottom of bbox is white 37 | if min(sidx, eidx) >= 4: 38 | color = "w" 39 | axx.plot( 40 | [bbox_2d[sidx][0], bbox_2d[eidx][0]], 41 | [bbox_2d[sidx][1], bbox_2d[eidx][1]], 42 | color, 43 | ) 44 | return axx 45 | 46 | 47 | # http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure 48 | def fig2data(fig): 49 | """ 50 | @brief Convert a Matplotlib figure to a 4D 51 | numpy array with RGBA channels and return it 52 | @param fig a matplotlib figure 53 | @return a numpy 3D array of RGBA values 54 | """ 55 | # draw the renderer 56 | fig.canvas.draw() 57 | 58 | # Get the RGBA buffer from the figure 59 | w, h = fig.canvas.get_width_height() 60 | buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8) 61 | buf.shape = (w, h, 4) 62 | 63 | # canvas.tostring_argb give pixmap in ARGB mode. 64 | # Roll the ALPHA channel to have it in RGBA mode 65 | buf = np.roll(buf, 3, axis=2) 66 | return buf 67 | 68 | 69 | # http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure 70 | def fig2img(fig): 71 | """ 72 | @brief Convert a Matplotlib figure to a PIL Image 73 | in RGBA format and return it 74 | @param fig a matplotlib figure 75 | @return a Python Imaging Library ( PIL ) image 76 | """ 77 | # put the figure pixmap into a numpy array 78 | buf = fig2data(fig) 79 | w, h, _ = buf.shape 80 | return Image.frombytes("RGBA", (w, h), buf.tobytes()) 81 | 82 | 83 | def concat_pil_images(images): 84 | """ 85 | Put a list of PIL images next to each other 86 | """ 87 | assert isinstance(images, list) 88 | widths, heights = zip(*(i.size for i in images)) 89 | 90 | total_width = sum(widths) 91 | max_height = max(heights) 92 | 93 | new_im = Image.new("RGB", (total_width, max_height)) 94 | 95 | x_offset = 0 96 | for im in images: 97 | new_im.paste(im, (x_offset, 0)) 98 | x_offset += im.size[0] 99 | return new_im 100 | 101 | 102 | def stack_pil_images(images): 103 | """ 104 | Stack a list of PIL images next to each other 105 | """ 106 | assert isinstance(images, list) 107 | widths, heights = zip(*(i.size for i in images)) 108 | 109 | total_height = sum(heights) 110 | max_width = max(widths) 111 | 112 | new_im = Image.new("RGB", (max_width, total_height)) 113 | 114 | y_offset = 0 115 | for im in images: 116 | new_im.paste(im, (0, y_offset)) 117 | y_offset += im.size[1] 118 | return new_im 119 | 120 | 121 | def im_list_to_plt(image_list, figsize, title_list=None): 122 | fig, axes = plt.subplots(nrows=1, ncols=len(image_list), figsize=figsize) 123 | for idx, (ax, im) in enumerate(zip(axes, image_list)): 124 | ax.imshow(im) 125 | ax.set_title(title_list[idx]) 126 | fig.tight_layout() 127 | im = fig2img(fig) 128 | plt.close() 129 | return im 130 | -------------------------------------------------------------------------------- /docs/data/mano_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/data/mano_right.png -------------------------------------------------------------------------------- /docs/data/processing.md: 
-------------------------------------------------------------------------------- 1 | # Data processing & splits 2 | 3 | ## Data splits 4 | 5 | **CVPR paper splits** 6 | 7 | - protocol 1: allocentric split (test set GT is hidden) 8 | - protocol 2: egocentric split (test set GT is hidden) 9 | 10 | Note that allocentric training images in protocol 1 can be used for protocol 2 training as per our evaluation protocol in the paper. In our paper, we first pre-train on protocol 1 images and then finetune on protocol 2 for the egocentric regressor. If you want to train directly on both allocentric and egocentric training images for protocol 2 evaluation, you can create a custom split. 11 | 12 | See [`docs/data_doc.md`](../data_doc.md) for an explanation of each file in the `arctic_data` folder. 13 | 14 | ## Advanced usage 15 | 16 | ### Process raw sequences 17 | 18 | ```bash 19 | # process a specific seq; do not save vertices for smaller storage 20 | python scripts_data/process_seqs.py --mano_p ./unpack/arctic_data/data/raw_seqs/s01/espressomachine_use_01.mano.npy 21 | 22 | # process all seqs; do not save vertices for smaller storage 23 | python scripts_data/process_seqs.py 24 | 25 | # process all seqs while exporting the vertices for visualization 26 | python scripts_data/process_seqs.py --export_verts 27 | ``` 28 | 29 | ### Create data split from processed sequences 30 | 31 | Our baselines load the pre-processed splits from `data/arctic_data/data/splits`. In case you need a custom split, you can build one following the example below (here we show the validation set split), which generates the split files under `outputs/splits/`. 32 | 33 | Build a data split from processed sequences: 34 | 35 | ```bash 36 | # Build validation set based on protocol p1 defined at arctic_data/data/splits_json/protocol_p1.json 37 | python scripts_data/build_splits.py --protocol p1 --split val --process_folder ./outputs/processed/seqs 38 | 39 | # Same as above, but build with vertices too 40 | python scripts_data/build_splits.py --protocol p1 --split val --process_folder ./outputs/processed_verts/seqs 41 | ``` 42 | 43 | ⚠️ The dataloader for the models in our CVPR paper does not require vertices in the split files. If the processed sequences were built with `--export_verts`, this script will try to aggregate the vertices as well, leading to a large storage requirement. 44 | 45 | ### Crop images for faster data loading 46 | 47 | Since our images are of high resolution, if reading speed is a bottleneck for training models on your machine, you can crop the images to a larger region centered at the bounding boxes to reduce the data loading cost during training. We provide download links for pre-cropped images.
In case of a custom crop, one can use the script below: 48 | 49 | ```bash 50 | # crop all images from all sequences using bbox defined in the process folder on a single machine 51 | python scripts_data/crop_images.py --task_id -1 --process_folder ./outputs/processed/seqs 52 | 53 | # crop all images from one sequence using bbox defined in the process folder 54 | # this is used for cluster preprocessing where AGENT_ID is from 0 to num_nodes-1 55 | python scripts_data/crop_images.py --task_id AGENT_ID --process_folder ./outputs/processed/seqs 56 | ``` 57 | -------------------------------------------------------------------------------- /docs/data/visualize.md: -------------------------------------------------------------------------------- 1 | # AIT Viewer with ARCTIC 2 | 3 | Our visualization is powered by: 4 | 5 | ![aitviewer](../static/aitviewer-logo.svg) 6 | 7 | ## Examples 8 | 9 | ```bash 10 | # render object and MANO for a given sequence 11 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --mano 12 | 13 | # render object and MANO for a given sequence on view 2 14 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --mano --view_idx 2 15 | 16 | # render object and MANO for a given sequence on the egocentric view 17 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --mano --view_idx 0 18 | 19 | # render object and MANO for a given sequence on the egocentric view while taking lens distortion into account 20 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --mano --view_idx 0 --distort 21 | 22 | # render in headless mode to obtain RGB images (with meshes), depth, segmentation masks, and an mp4 video of the visualization 23 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --mano --headless 24 | 25 | # render object and SMPLX for a given sequence without images 26 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --smplx --no_image 27 | 28 | # render all sequences into videos, RGB images with meshes, depth maps, and segmentation masks 29 | python scripts_data/visualizer.py --object --smplx --headless 30 | 31 | # visualize raw mocap data 32 | python scripts_data/mocap_viewer.py --mocap_p unpack/arctic_data/data/mocap_npy/s01_ketchup_use_01.npy 33 | ``` 34 | 35 | ## Options 36 | 37 | - `view_idx`: camera view to visualize; `0` is the egocentric view; `{1, ..., 8}` are the 3rd-person views. 38 | - `seq_p`: path to the processed sequence to visualize. When this option is not specified, the program runs on all sequences (e.g., when you want to render depth masks for all sequences). 39 | - `headless`: when it is off, the user gets an interactive viewer; when it is on, we render and save images with GT, depth maps, segmentation masks, and videos to disk. 40 | - `mano`: include MANO in the scene 41 | - `smplx`: include SMPLX in the scene 42 | - `object`: include the object in the scene 43 | - `no_image`: do not show images. 44 | - `distort`: in the egocentric view, lens distortion is severe as the camera is close to the 3D objects, leading to a mismatch between the 3D geometry and the images.
When turned on, this option uses the lens distortion parameters to distort the 3D geometry via ["vertex displacement for distortion correction"](https://stackoverflow.com/questions/44489686/camera-lens-distortion-in-opengl), giving better GT-image overlap. However, this method creates artifacts when the 3D geometry is very close to the camera. 45 | 46 | Segmentation mask IDs in the scene are defined [here](https://github.com/zc-alexfan/arctic-private/blob/arctic/common/viewer.py#L24). 47 | 48 | ## Controls to interact with the viewer 49 | 50 | [AITViewer](https://github.com/eth-ait/aitviewer) has lots of useful built-in controls. For an explanation of the frontend and controls, see [here](https://eth-ait.github.io/aitviewer/frontend.html). Here we assume you are in interactive mode (`--headless` is turned off). 51 | 52 | - To play/pause the animation, hit `<SPACE>`. 53 | - To center around an object, click the mesh you want to center and press `X`. 54 | - To step to the previous or next frame, press `<` and `>`. 55 | 56 | More documentation can be found in the [aitviewer github](https://github.com/eth-ait/aitviewer) and in the [viewer docs](https://eth-ait.github.io/aitviewer/frontend.html). 57 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | QUESTION: **Why do the groundtruth hands not overlap perfectly with the image in the visualization from `scripts_method/train.py` (see below)?** 4 | ANSWER: Like most hand-object reconstruction methods, ArcticNet does not assume known camera intrinsics; we use a weak-perspective camera model with a fixed focal length. The 2D misalignment between the groundtruth and the image is caused by these approximate weak-perspective intrinsics. 5 | 6 | 7 |

![Misalignment between the groundtruth meshes and the image](./static/misalignment.png)

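For intuition, below is a minimal, illustrative sketch (not code from this repository; the function and variable names are made up) contrasting a full perspective projection with the weak-perspective approximation described above. Because a single scale replaces the per-point depth division, the projected groundtruth cannot align exactly with the image, especially when depth varies a lot across the hand and object:

```python
import numpy as np


def perspective_project(pts_cam, fx, fy, cx, cy):
    # Full perspective projection: each point is divided by its own depth.
    x = fx * pts_cam[:, 0] / pts_cam[:, 2] + cx
    y = fy * pts_cam[:, 1] / pts_cam[:, 2] + cy
    return np.stack([x, y], axis=1)


def weak_perspective_project(pts_cam, scale, trans_xy):
    # Weak perspective: a single global scale (fixed focal length divided by a
    # mean depth) plus a 2D translation; per-point depth variation is ignored.
    return scale * pts_cam[:, :2] + trans_xy
```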
10 | 11 | QUESTION: **Why is there more misalignment in the egocentric view?** 12 | ANSWER: It is mainly caused by image distortion in the rendering, as the 3D geometry is close to the camera. We estimated distortion parameters from camera calibration sequences, which can be used to apply a distortion effect to the meshes with ["vertex displacement for distortion correction"](https://stackoverflow.com/questions/44489686/camera-lens-distortion-in-opengl). 13 | 14 | See the `--distort` flag for details: 15 | 16 | ```bash 17 | python scripts_data/visualizer.py --seq_p ./outputs/processed_verts/seqs/s01/capsulemachine_use_01.npy --object --mano --view_idx 0 --distort 18 | ``` 19 | 20 | QUESTION: **Assertion error related to 21 joints** 21 | ANSWER: This is because smplx gives 16 joints for the MANO hand by default. See the [setup instructions](setup.md) to allow 21 joints. 22 | 23 | 24 | QUESTION: **How are the 2D space, the 3D camera space, and the 3D world space related? What about cropping?** 25 | ANSWER: See [here](https://github.com/zc-alexfan/arctic/issues/29#issuecomment-1751657365). 26 | -------------------------------------------------------------------------------- /docs/leaderboard.md: -------------------------------------------------------------------------------- 1 | # ARCTIC Leaderboard 2 | 3 | This page contains instructions for submitting your results to the ARCTIC leaderboard. **The leaderboard is currently in beta release. Should you encounter any issues, feel free to contact us.** 4 | 5 | ## Getting an account 6 | 7 | To get started, go to our [leaderboard website](https://arctic-leaderboard.is.tuebingen.mpg.de/). Click on the `Sign up` button at the top to register a new account. You will receive an email to confirm your account. Note that this is not the same account system as on the [ARCTIC website](https://arctic.is.tue.mpg.de/), so you have to register a separate one. 8 | 9 | > ICCV challenge participants: Use the same email you registered for the challenge to register the evaluation account. Further, your "Algorithm Name" should be your team name. 10 | 11 | After activating your account, you may now log into the website. 12 | 13 | ## Creating an algorithm 14 | 15 | 16 | After logging in, click "My algorithms" at the top to manage your algorithms. To add one, click "Add algorithm" and enter your algorithm details, with only the "short name" field being mandatory. This information will appear on the leaderboard if published. Algorithm scores remain private by default unless published. Click save to create the algorithm. 17 | 18 | **IMPORTANT: When an algorithm is created, you can submit to multiple sub-tasks. You do not need to create a separate algorithm for each sub-task.** 19 | 20 | ## Submitting to the leaderboard 21 | 22 | After the step above, you'll reach a page to upload prediction results. Initially, use our provided CVPR model zip files below for a trial evaluation based on your chosen task. After the trial, you can submit your own zip files.
We recommend starting with egocentric tasks due to their smaller file sizes: 23 | 24 | - [Consistent motion reconstruction: allocentric](https://download.is.tue.mpg.de/arctic/submission/pose_p1_test.zip) 25 | - [Consistent motion reconstruction: egocentric](https://download.is.tue.mpg.de/arctic/submission/pose_p2_test.zip) 26 | - [Interaction field estimation: allocentric](https://download.is.tue.mpg.de/arctic/submission/field_p1_test.zip) 27 | - [Interaction field estimation: egocentric](https://download.is.tue.mpg.de/arctic/submission/field_p2_test.zip) 28 | 29 | 30 | 31 | Click "Upload" and select the relevant task to upload your zip file for evaluation. The evaluation time may vary based on the task and the number of requests. You'll see the results in a table, which can be downloaded as a JSON file by clicking "evaluation result". 32 | 33 | Your numbers should closely align with our CVPR models, serving as a sanity check for the file format. Results remain private unless you select "publish", allowing evaluation against the test set ground truth. 34 | 35 | To generate zip files for evaluation, create a custom script using the provided zip files as a template. If using our original codebase, utilize the extraction scripts below to create the zip files. Find detailed data format documentation for the leaderboard [here](leaderboard_format.md). 36 | 37 | To avoid excessive hyperparameter tuning on the test set, each account can only submit to the server for **10 successful evaluations in total every month**. 38 | 39 | ## Preparing submission file with original codebase 40 | 41 | We demonstrate preparing submission files using ARCTIC models as an example. First, run inference on each sequence and save the model predictions to disk. These predictions are then compiled into a zip file for submission. 42 | 43 | > If you're using a different codebase and prefer to write your own script for generating the zip files, you can inspect the example zip files above. 44 | 45 | To submit predictions, we need to use the extraction script `scripts_method/extract_predicts.py`. Detailed documentation on the extraction script is available [here](model/extraction.md). 46 | 47 | To perform a trial submission, you can try to reproduce the numbers of our model `28bf3642f`. It is an ArcticNet-SF model for the egocentric setting in our CVPR paper. See details on the [data documentation](data/data_doc.md) page. 48 | 49 | If you have prepared the ARCTIC data following our standard instructions [here](data/README.md), you can copy the pre-trained model `28bf3642f` via: 50 | 51 | ```bash 52 | cp -r data/arctic_data/models/28bf3642f logs/ 53 | ``` 54 | 55 | Then run this command to perform inference on the test set: 56 | 57 | ```bash 58 | python scripts_method/extract_predicts.py --setup p2 --method arctic_sf --load_ckpt logs/28bf3642f/checkpoints/last.ckpt --run_on test --extraction_mode submit_pose 59 | ``` 60 | 61 | A zip file will be produced, which can be uploaded to the evaluation server. 62 | 63 | Explanation of the options above: 64 | 65 | - `--setup`: allocentric setting (`p1`) or egocentric setting (`p2`) to run on.
66 | - `--method`: the model to construct 67 | - `--load_ckpt`: path to model checkpoint 68 | - `--run_on`: test set evaluation 69 | - `--extraction_mode {submit_pose, submit_field}`: this specifies the extraction is for submission 70 | -------------------------------------------------------------------------------- /docs/leaderboard_format.md: -------------------------------------------------------------------------------- 1 | # Submission format 2 | 3 | ## Consistent motion reconstruction 4 | 5 | ### File structure 6 | 7 | To submit for evaluation, you need to prepare prediction files and store them in a folder (here the folder is `pose_p2_test`), and zip the folder for submission. The following shows the tree structure of a folder before zipping to `pose_p2_test.zip`. 8 | 9 | The folder contains a single subfolder named `eval`. We refer `pose_p2_test` as the `TASK_NAME` to indicate different tasks to evaluate your submission on. The `$TASK_NAME/eval` folder then stores prediction from each sequence in a particular view. 10 | 11 | ``` 12 | pose_p2_test 13 | -- eval 14 | |-- s03_box_grab_01_0 15 | | |-- meta_info 16 | | | `-- meta_info.imgname.pt 17 | | `-- preds 18 | | |-- pred.mano.beta.l.pt 19 | | |-- pred.mano.beta.r.pt 20 | | |-- pred.mano.cam_t.l.pt 21 | | |-- pred.mano.cam_t.r.pt 22 | | |-- pred.mano.pose.l.pt 23 | | |-- pred.mano.pose.r.pt 24 | | |-- pred.object.cam_t.pt 25 | | |-- pred.object.radian.pt 26 | | `-- pred.object.rot.pt 27 | |-- s03_box_use_01_0 28 | | |-- meta_info 29 | | | `-- meta_info.imgname.pt 30 | | `-- preds 31 | | |-- pred.mano.beta.l.pt 32 | | |-- pred.mano.beta.r.pt 33 | | |-- pred.mano.cam_t.l.pt 34 | | |-- pred.mano.cam_t.r.pt 35 | | |-- pred.mano.pose.l.pt 36 | | |-- pred.mano.pose.r.pt 37 | | |-- pred.object.cam_t.pt 38 | | |-- pred.object.radian.pt 39 | | `-- pred.object.rot.pt 40 | ... 41 | ``` 42 | 43 | Lets take `pose_p2_test/eval/s03_box_use_01_0` as an example. The `TASK_NAME` is `pose_p2_test` and `s03_box_use_01_0` means that the folder is for predictions of the sequence `s03_box_use_01` in camera view `0`. Since this is an egocentric task, you will expect the view is always 0, but for allocentric tasks it will range from 1 to 8. 44 | 45 | You will use one of the following `TASK_NAME`: 46 | - `pose_p1_test`: motion reconstruction task, allocentric setting evaluation on the test set 47 | - `pose_p2_test`: motion reconstruction task, egocentric setting evaluation on the test set 48 | - `field_p1_test`: interaction field estimation task, allocentric setting evaluation on the test set 49 | - `field_p2_test`: interaction field estimation task, egocentric setting evaluation on the test set 50 | 51 | Say you want to store your prediction on the motion reconstruction task in allocentric camera setting on the test set for camera 2 and the sequence `s03_capsulemachine_use_04`. The folder to store the prediction will be `pose_p1_test/eval/s03_capsulemachine_use_04_2`. 52 | 53 | ### File formats 54 | 55 | Looking at the tree structure above, you can see that there are two folders `meta_info` and `preds`. The former stores information that is not prediction. In this case, it is only the image paths. The latter folder stores the predictions of the MANO model and the object model. Each `.pt` file is from `torch.save`. 
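For illustration, the sketch below (hypothetical, not part of the codebase) writes one such sequence folder with `torch.save`, using placeholder values; the expected file names and tensor shapes are documented in the list right after it. Zip the whole `TASK_NAME` folder afterwards, as described above.

```python
import os
import torch

# TASK_NAME/eval/<sequence>_<view>; names follow the tree structure above
seq_dir = "pose_p2_test/eval/s03_box_use_01_0"
os.makedirs(os.path.join(seq_dir, "preds"), exist_ok=True)
os.makedirs(os.path.join(seq_dir, "meta_info"), exist_ok=True)

num_frames = 100  # placeholder; use the real number of frames of the sequence
preds = {
    "pred.mano.beta.l.pt": torch.zeros(num_frames, 10),
    "pred.mano.beta.r.pt": torch.zeros(num_frames, 10),
    "pred.mano.cam_t.l.pt": torch.zeros(num_frames, 3),
    "pred.mano.cam_t.r.pt": torch.zeros(num_frames, 3),
    "pred.mano.pose.l.pt": torch.eye(3).expand(num_frames, 16, 3, 3).clone(),
    "pred.mano.pose.r.pt": torch.eye(3).expand(num_frames, 16, 3, 3).clone(),
    "pred.object.cam_t.pt": torch.zeros(num_frames, 3),
    "pred.object.radian.pt": torch.zeros(num_frames),
    "pred.object.rot.pt": torch.zeros(num_frames, 3),
}
for fname, tensor in preds.items():
    torch.save(tensor, os.path.join(seq_dir, "preds", fname))

# one image path per frame (placeholder frame indices)
imgnames = [
    f"./data/arctic_data/data/cropped_images/s03/box_use_01/0/{i + 10:05d}.jpg"
    for i in range(num_frames)
]
torch.save(imgnames, os.path.join(seq_dir, "meta_info", "meta_info.imgname.pt"))
```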
56 | 57 | - `pred.mano.beta.l.pt`: (num_frames, 10); MANO betas for the left hand for each frame; FloatTensor 58 | - `pred.mano.cam_t.l.pt`: (num_frames, 3); MANO [translation](https://github.com/zc-alexfan/arctic/blob/08c5e9396087c4529b448cdf736b65fae600866e/src/nets/hand_heads/mano_head.py#L51) for the left hand; FloatTensor 59 | - `pred.mano.pose.l.pt`: (num_frames, 16, 3, 3); MANO hand rotations for the left hand; FloatTensor; assumes `flat_hand_mean=False`; this includes the global orientation; rotation matrix format. 60 | - `pred.object.cam_t.pt`: (num_frames, 3); Object [translation](https://github.com/zc-alexfan/arctic/blob/08c5e9396087c4529b448cdf736b65fae600866e/src/nets/obj_heads/obj_head.py#L60C27-L60C32); FloatTensor 61 | - `pred.object.radian.pt`: (num_frames); Object articulation in radians. 62 | - `pred.object.rot.pt`: (num_frames, 3); Object orientation in axis-angle; FloatTensor 63 | - `meta_info.imgname.pt`: (num_frames); A list of strings for image paths 64 | 65 | Example of the first image path: 66 | 67 | ``` 68 | './data/arctic_data/data/cropped_images/s03/box_use_01/0/00010.jpg' 69 | ``` 70 | 71 | You can also refer to our hand and object model classes for a reference on these variables. 72 | 73 | -------------------------------------------------------------------------------- /docs/model/extraction.md: -------------------------------------------------------------------------------- 1 | 2 | # Extraction 3 | 4 | To run our training (for LSTM models), evaluation, and visualization pipelines, we need to save certain predictions to disk in advance. Here we detail the extraction script options. 5 | 6 | ## Script options 7 | 8 | Options: 9 | - `--setup`: the split to use; `{p1, p2}` 10 | - `--method`: model name; `{arctic_sf, arctic_lstm, field_sf, field_lstm}` 11 | - `--load_ckpt`: checkpoint path 12 | - `--run_on`: split to extract predictions on; `{train, val, test}` 13 | - `--extraction_mode`: this defines what predicted variables to extract 14 | 15 | Explanation of `setup`: 16 | - `p1`: allocentric split in our CVPR paper 17 | - `p2`: egocentric split in our CVPR paper 18 | 19 | Explanation of `--extraction_mode`: 20 | - `eval_pose`: dump predicted variables needed for evaluating pose reconstruction. The evaluation will be done locally (assuming GT is provided). 21 | - `eval_field`: dump predicted variables needed for evaluating interaction field estimation. The evaluation will be done locally (assuming GT is provided). 22 | - `submit_pose`: dump predicted variables needed for evaluating pose reconstruction. The evaluation will be done via a submission server for test set evaluation. 23 | - `submit_field`: dump predicted variables needed for evaluating interaction field estimation. The evaluation will be done via a submission server for test set evaluation. 24 | - `feat_pose`: extract image feature vectors for pose estimation (e.g., these features are inputs to the LSTM model to avoid running a backbone in the training process, for a speedup). 25 | - `feat_field`: extract image feature vectors for interaction field estimation. 26 | - `vis_pose`: extract predictions for visualizing pose predictions in our viewer. 27 | - `vis_field`: extract predictions for visualizing interaction field predictions in our viewer. 28 | 29 | ## Extraction examples 30 | 31 | Here we show extraction examples using our pre-trained models.
To start, copy our pre-trained models to `./logs`: 32 | 33 | ```bash 34 | mkdir -p logs 35 | cp -r data/arctic_data/models/* logs/ 36 | ``` 37 | 38 | **Example**: Suppose that I want to: 39 | - evaluate the *ArcticNet-SF* pose estimation model (`3558f1342`) 40 | - run on the *val* set 41 | - use the split `p1` to evaluate locally (therefore, `eval_pose`) 42 | - use the checkpoint at `logs/3558f1342/checkpoints/last.ckpt` 43 | 44 | ```bash 45 | python scripts_method/extract_predicts.py --setup p1 --method arctic_sf --load_ckpt logs/3558f1342/checkpoints/last.ckpt --run_on val --extraction_mode eval_pose 46 | ``` 47 | 48 | **Example**: Suppose that I want to: 49 | - evaluate the *ArcticNet-SF* pose estimation model (`3558f1342`) 50 | - run on the *test* set 51 | - use the CVPR split `p1` to evaluate so that we submit to the evaluation server later (therefore, `submit_pose`) 52 | - use the checkpoint at `logs/3558f1342/checkpoints/last.ckpt` 53 | 54 | ```bash 55 | python scripts_method/extract_predicts.py --setup p1 --method arctic_sf --load_ckpt logs/3558f1342/checkpoints/last.ckpt --run_on test --extraction_mode submit_pose 56 | ``` 57 | 58 | **Example**: Suppose that I want to: 59 | - visualize the prediction of the *ArcticNet-SF* pose estimation model (`3558f1342`); therefore, `vis_pose` 60 | - run on the *val* set 61 | - use the split `p1` to evaluate 62 | - use the checkpoint at `logs/3558f1342/checkpoints/last.ckpt` 63 | 64 | ```bash 65 | python scripts_method/extract_predicts.py --setup p1 --method arctic_sf --load_ckpt logs/3558f1342/checkpoints/last.ckpt --run_on val --extraction_mode vis_pose 66 | ``` 67 | 68 | **Example**: Suppose that I want to: 69 | - extract image features for the *ArcticNet-LSTM* pose estimation model (using the *ArcticNet-SF* model `3558f1342`) on the training and val sets.
70 | - use the split `p1` 71 | - we need to first save the visual features of the *ArcticNet-SF* model to disk; therefore, `feat_pose` 72 | 73 | ```bash 74 | # extract for training 75 | python scripts_method/extract_predicts.py --setup p1 --method arctic_sf --load_ckpt logs/3558f1342/checkpoints/last.ckpt --run_on train --extraction_mode feat_pose 76 | 77 | # extract for evaluation on val set 78 | python scripts_method/extract_predicts.py --setup p1 --method arctic_sf --load_ckpt logs/3558f1342/checkpoints/last.ckpt --run_on val --extraction_mode feat_pose 79 | ``` 80 | 81 | -------------------------------------------------------------------------------- /docs/purchase.md: -------------------------------------------------------------------------------- 1 | # Guide to purchase ARCTIC objects 2 | 3 | Here is a list of links that I used to purchase the ARCTIC objects: 4 | 5 | - [Small Foot Toy Kitchen Set](https://www.amazon.de/-/en/dp/B0756C59FR?ref=ppx_yo2ov_dt_b_fed_asin_title&th=1) 6 | - [Titanium Scissors, Non-Stick, 205 mm, SB, Black](https://www.amazon.de/-/en/dp/B00P1F7QVU?ref=ppx_yo2ov_dt_b_fed_asin_title&th=1) 7 | - [Xucker Tomato Ketchup with Xylitol, No Added Sugar: 1 x 500 ml - GMO Free, Vegan](https://www.amazon.de/dp/B07YQ987P4?ref=ppx_yo2ov_dt_b_fed_asin_title) 8 | - [Small Foot Toy Kitchen Set](https://www.amazon.de/dp/B08HK7ZDYL?ref=ppx_yo2ov_dt_b_fed_asin_title&th=1) 9 | - [Autorenplaner | Buch schreiben und veröffentlichen | Handbuch für Autoren & Schriftsteller | Buch schreiben lernen | mit vielen Tipps & Checklisten | für Anfänger geeignet: Der All-in-one Planer](https://www.amazon.de/dp/3966985934?ref=ppx_yo2ov_dt_b_fed_asin_title) 10 | - [Eichhorn 100002575 Wooden Laptop with Puzzle, 14 Pieces, Screen Surface for Writing on with Chalk, Keyboard Consisting of 6 Puzzle Pieces, 32 x 20 cm, Includes 6 Chalks and Sponge](https://www.amazon.de/dp/B00BLEG3SW?ref=ppx_yo2ov_dt_b_fed_asin_title) 11 | - [Creative Deco A4 Wooden Box with Lid | 33.8 x 24.8 x 10 cm (+/- 1 cm) | Unfinished Storage Box | Large Box | Large Wooden Box Ideal for Storing Valuables, Toys and Tools](https://www.amazon.de/dp/B075X4YHZB?ref=ppx_yo2ov_dt_b_fed_asin_title&th=1) 12 | - [Baby Lips Balm Crayon](https://www.amazon.de/dp/B006PG68EU?ref=ppx_yo2ov_dt_b_fed_asin_title&th=1) 13 | 14 | Unfortunately, the links for the other objects do not work anymore. However, you can find the stock photos of the objects [here](stock_photos/). Using these photos, you can then use Google Image Search to find other vendors selling the same items. For example, 15 | 16 |

17 | Image 18 |

19 | 20 | 21 | -------------------------------------------------------------------------------- /docs/setup.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Getting started 4 | 5 | General Requirements: 6 | 7 | - Python 3.10 8 | - torch 1.13.0 9 | - CUDA 11.6 (check `nvcc --version`) 10 | - pytorch3d 0.7.3 11 | - pytorch-lightning 2.0.0 12 | - aitviewer 1.8.0 13 | 14 | Install the environment: 15 | 16 | ```bash 17 | ENV_NAME=arctic_env 18 | conda create -n $ENV_NAME python=3.10 19 | conda activate $ENV_NAME 20 | ``` 21 | 22 | Check your CUDA `nvcc` version: 23 | 24 | ``` 25 | nvcc --version # should be 11.6 26 | ``` 27 | 28 | You can install nvcc and cuda via [runfile](https://developer.nvidia.com/cuda-11-6-0-download-archive). If `nvcc --version` still does not show `11.6`, check whether you are referring to the right nvcc with `which nvcc`. Assuming you have an NVIDIA driver installed, usually you only need to run the following command to install `nvcc` (as an example): 29 | 30 | ```bash 31 | sudo bash cuda_11.6.0_510.39.01_linux.run --toolkit --silent --override 32 | ``` 33 | 34 | After the installation, make sure the paths point to the current CUDA toolkit location. For example: 35 | 36 | ```bash 37 | export CUDA_HOME=/usr/local/cuda-11.6 38 | export PATH="/usr/local/cuda-11.6/bin:$PATH" 39 | export CPATH="/usr/local/cuda-11.6/include:$CPATH" 40 | export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda-11.6/lib64/" 41 | ``` 42 | 43 | Install packages: 44 | 45 | ```bash 46 | pip install -r requirements.txt 47 | conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia 48 | ``` 49 | 50 | Install PyTorch3D: 51 | 52 | ```bash 53 | # pytorch3d 0.7.3 54 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 55 | conda install -c bottler nvidiacub 56 | conda install pytorch3d -c pytorch3d 57 | ``` 58 | 59 | Install this version of numpy to avoid conflicts: 60 | 61 | ```bash 62 | pip install numpy==1.22.4 63 | ``` 64 | 65 | Modify the `smplx` package to return 21 joints instead of 16: 66 | 67 | ```bash 68 | vim /home//anaconda3/envs//lib//site-packages/smplx/body_models.py 69 | 70 | # uncomment L1681 71 | joints = self.vertex_joint_selector(vertices, joints) 72 | ``` 73 | 74 | If you are unsure about where `body_models.py` is, run the following in a terminal: 75 | 76 | ```bash 77 | python 78 | >>> import smplx 79 | >>> print(smplx.__file__) 80 | ``` 81 | 82 | -------------------------------------------------------------------------------- /docs/static/aitviewer-logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/static/dexterous.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/static/dexterous.gif -------------------------------------------------------------------------------- /docs/static/hold/mug_ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/static/hold/mug_ours.gif -------------------------------------------------------------------------------- /docs/static/hold/mug_ref.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/static/hold/mug_ref.png -------------------------------------------------------------------------------- /docs/static/misalignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/static/misalignment.png -------------------------------------------------------------------------------- /docs/static/teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/static/teaser.jpeg -------------------------------------------------------------------------------- /docs/static/viewer_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/static/viewer_demo.gif -------------------------------------------------------------------------------- /docs/stock_photos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/.DS_Store -------------------------------------------------------------------------------- /docs/stock_photos/box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/box.jpg -------------------------------------------------------------------------------- /docs/stock_photos/coffee_machine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/coffee_machine.jpg -------------------------------------------------------------------------------- /docs/stock_photos/expresso_machine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/expresso_machine.jpg -------------------------------------------------------------------------------- /docs/stock_photos/google.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/google.png -------------------------------------------------------------------------------- /docs/stock_photos/ketchup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/ketchup.jpg -------------------------------------------------------------------------------- /docs/stock_photos/laptop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/laptop.jpg -------------------------------------------------------------------------------- /docs/stock_photos/microwave.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/microwave.jpg -------------------------------------------------------------------------------- /docs/stock_photos/mixer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/mixer.jpg -------------------------------------------------------------------------------- /docs/stock_photos/notebook.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/notebook.jpg -------------------------------------------------------------------------------- /docs/stock_photos/phone.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/phone.jpg -------------------------------------------------------------------------------- /docs/stock_photos/scissors.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/scissors.jpg -------------------------------------------------------------------------------- /docs/stock_photos/waffleiron.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/docs/stock_photos/waffleiron.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | comet_ml==3.32.8 2 | jpeg4py==0.1.4 3 | loguru 4 | matplotlib 5 | numpy>=1.16.5,<1.23.0 6 | opencv_python 7 | chardet 8 | Pillow 9 | pyrender 10 | pytorch_lightning==2.0.0 11 | scipy 12 | smplx==0.1.28 13 | tqdm 14 | trimesh==3.9.21 15 | scikit-image 16 | imgui==1.4.1 17 | aitviewer==1.8.1 18 | chumpy 19 | black 20 | autopep8 21 | flake8 22 | pylint 23 | isort 24 | easydict 25 | pygit2==1.7 26 | ipdb 27 | opencv-python-headless 28 | -------------------------------------------------------------------------------- /scripts_data/build_splits.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | sys.path = ["."] + sys.path 5 | from src.arctic.split import build_split 6 | 7 | 8 | def construct_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | "--protocol", 12 | type=str, 13 | default=None, 14 | ) 15 | parser.add_argument( 16 | "--split", 17 | type=str, 18 | choices=["train", "val", "test", "all"], 19 | default=None, 20 | ) 21 | parser.add_argument( 22 | "--request_keys", 23 | type=str, 24 | default="cam_coord.2d.bbox.params", 25 | help="save data with these keys (separated by .)", 26 | ) 27 | parser.add_argument( 28 | "--process_folder", type=str, default="./outputs/processed/seqs" 29 | ) 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | if __name__ == "__main__": 35 | args = construct_args() 36 | protocol = args.protocol 37 | split = args.split 38 | request_keys = args.request_keys.split(".") 39 | if protocol == "all": 40 | protocols = [ 41 | "p1", # allocentric 42 | "p2", # egocentric 43 | ] 44 | else: 45 | protocols = 
[protocol] 46 | 47 | if split == "all": 48 | if protocol in ["p1", "p2"]: 49 | splits = ["train", "val", "test"] 50 | else: 51 | raise ValueError("Unknown protocol for option 'all'") 52 | else: 53 | splits = [split] 54 | 55 | for protocol in protocols: 56 | for split in splits: 57 | if protocol in ["p1", "p2"]: 58 | assert split not in ["test"], "val/test are hidden" 59 | build_split(protocol, split, request_keys, args.process_folder) 60 | -------------------------------------------------------------------------------- /scripts_data/checksum.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as op 3 | import traceback 4 | from glob import glob 5 | from hashlib import sha256 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | def main(): 11 | release_folder = "./downloads" 12 | 13 | print("Globing files...") 14 | fnames = glob(op.join(release_folder, "**/*"), recursive=True) 15 | print("Number of files to checksum: ", len(fnames)) 16 | pbar = tqdm(fnames) 17 | 18 | with open("./bash/assets/checksum.json", "r") as f: 19 | gt_checksum = json.load(f) 20 | 21 | hash_dict = {} 22 | for fname in pbar: 23 | if op.isdir(fname): 24 | continue 25 | if ".zip" not in fname: 26 | continue 27 | if "models_smplx_v1_1.zip" in fname: 28 | continue 29 | if "mano_v1_2.zip" in fname: 30 | continue 31 | 32 | try: 33 | with open(fname, "rb") as f: 34 | pbar.set_description(f"Reading {fname}") 35 | data = f.read() 36 | hashcode = sha256(data).hexdigest() 37 | key = fname.replace(release_folder, "") 38 | hash_dict[key] = hashcode 39 | if hashcode != gt_checksum[key]: 40 | print(f"Error: {fname} has different checksum!") 41 | else: 42 | pbar.set_description(f"Hashcode of {fname} is correct!") 43 | # print(f'Hashcode of {fname} is correct!') 44 | except: 45 | print(f"Error processing {fname}") 46 | traceback.print_exc() 47 | continue 48 | 49 | out_p = op.join(release_folder, "checksum.json") 50 | with open(out_p, "w") as f: 51 | json.dump(hash_dict, f, indent=4, sort_keys=True) 52 | print(f"Checksum file saved to {out_p}!") 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /scripts_data/crop_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as op 5 | import time 6 | import traceback 7 | from glob import glob 8 | 9 | import numpy as np 10 | from loguru import logger 11 | from PIL import Image 12 | from tqdm import tqdm 13 | 14 | logger.add("file_{time}.log") 15 | 16 | 17 | EGO_IMAGE_SCALE = 0.3 18 | 19 | with open( 20 | f"./arctic_data/meta/misc.json", 21 | "r", 22 | ) as f: 23 | misc = json.load(f) 24 | 25 | 26 | def transform_image(im, bbox_loose, cap_dim): 27 | cx, cy, dim = bbox_loose.copy() 28 | dim *= 200 29 | im_cropped = im.crop((cx - dim / 2, cy - dim / 2, cx + dim / 2, cy + dim / 2)) 30 | 31 | im_cropped_cap = im_cropped.resize((cap_dim, cap_dim)) 32 | return im_cropped_cap 33 | 34 | 35 | def process_fname(fname, bbox_loose, sid, view_idx, pbar): 36 | vidx = int(op.basename(fname).split(".")[0]) - misc[sid]["ioi_offset"] 37 | out_p = fname.replace("./data/arctic_data/data/images", "./outputs/croppped_images") 38 | num_frames = bbox_loose.shape[0] 39 | 40 | if vidx < 0: 41 | # expected 42 | return True 43 | 44 | if vidx >= num_frames: 45 | # not expected 46 | return False 47 | 48 | if op.exists(out_p): 49 | return True 50 | 51 | 
pbar.set_description(f"Croppping {fname}") 52 | im = Image.open(fname) 53 | if view_idx > 0: 54 | im_cap = transform_image(im, bbox_loose[vidx], cap_dim=1000) 55 | else: 56 | width, height = im.size 57 | width_new = int(width * EGO_IMAGE_SCALE) 58 | height_new = int(height * EGO_IMAGE_SCALE) 59 | im_cap = im.resize((width_new, height_new)) 60 | out_folder = op.dirname(out_p) 61 | if not op.exists(out_folder): 62 | os.makedirs(out_folder) 63 | 64 | im_cap.save(out_p) 65 | return True 66 | 67 | 68 | def process_seq(seq_p): 69 | print(f"Start {seq_p}") 70 | 71 | seq_data = np.load(seq_p, allow_pickle=True).item() 72 | sid, seq_name = seq_p.split("/")[-2:] 73 | 74 | seq_name = seq_name.split(".")[0] 75 | stamp = time.time() 76 | 77 | for view_idx in range(9): 78 | print(f"Processing view#{view_idx}") 79 | bbox = seq_data["bbox"][:, view_idx] 80 | bbox_loose = bbox.copy() 81 | bbox_loose[:, 2] *= 1.5 # 1.5X around the bbox 82 | 83 | fnames = glob( 84 | f"./data/arctic_data/data/images/{sid}/{seq_name}/{view_idx}/*.jpg" 85 | ) 86 | fnames = sorted(fnames) 87 | if len(fnames) == 0: 88 | logger.info(f"No images in {sid}/{seq_name}/{view_idx}") 89 | 90 | pbar = tqdm(fnames) 91 | for fname in pbar: 92 | try: 93 | status = process_fname(fname, bbox_loose, sid, view_idx, pbar) 94 | if status is False: 95 | logger.info(f"Skip due to no GT: {fname}") 96 | except: 97 | traceback.print_exc() 98 | logger.info(f"Skip due to Exception: {fname}") 99 | time.sleep(1.0) 100 | 101 | print(f"Done! Elapsed {time.time() - stamp:.2f}s") 102 | 103 | 104 | def construct_args(): 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument("--task_id", type=int, default=None) 107 | parser.add_argument( 108 | "--process_folder", type=str, default="./outputs/processed/seqs" 109 | ) 110 | args = parser.parse_args() 111 | return args 112 | 113 | 114 | if __name__ == "__main__": 115 | args = construct_args() 116 | seq_ps = glob(op.join(args.process_folder, "*/*.npy")) 117 | seq_ps = sorted(seq_ps) 118 | assert len(seq_ps) > 0 119 | 120 | if args.task_id < 0: 121 | for seq_p in seq_ps: 122 | process_seq(seq_p) 123 | else: 124 | seq_p = seq_ps[args.task_id] 125 | process_seq(seq_p) 126 | -------------------------------------------------------------------------------- /scripts_data/download_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as op 4 | import warnings 5 | 6 | import requests 7 | from loguru import logger 8 | from tqdm import tqdm 9 | 10 | warnings.filterwarnings("ignore", message="Unverified HTTPS request") 11 | 12 | 13 | def download_data(url_file, out_folder, dry_run): 14 | # Define the username and password 15 | if "smplx" in url_file: 16 | flag = "SMPLX" 17 | elif "mano" in url_file: 18 | flag = "MANO" 19 | else: 20 | flag = "ARCTIC" 21 | 22 | username = os.environ[f"{flag}_USERNAME"] 23 | password = os.environ[f"{flag}_PASSWORD"] 24 | password_fake = "*" * len(password) 25 | 26 | logger.info(f"Username: {username}") 27 | logger.info(f"Password: {password_fake}") 28 | 29 | post_data = {"username": username, "password": password} 30 | # Read the URLs from the file 31 | with open(url_file, "r") as f: 32 | urls = f.readlines() 33 | 34 | # Strip newline characters from the URLs 35 | urls = [url.strip() for url in urls] 36 | 37 | if dry_run and "images" in url_file: 38 | urls = urls[:5] 39 | 40 | # Loop through the URLs and download the files 41 | logger.info(f"Start downloading from {url_file}") 42 | pbar = tqdm(urls) 
43 | for url in pbar: 44 | pbar.set_description(f"Downloading {url[-40:]}") 45 | # Make a POST request with the username and password 46 | response = requests.post( 47 | url, 48 | data=post_data, 49 | stream=True, 50 | verify=False, 51 | allow_redirects=True, 52 | ) 53 | 54 | if response.status_code == 401: 55 | logger.warning( 56 | f"Authentication failed for URLs in {url_file}. Username/password correct?" 57 | ) 58 | break 59 | 60 | # Get the filename from the URL 61 | filename = url.split("/")[-1] 62 | if "models_smplx_v1_1" in url: 63 | filename = "models_smplx_v1_1.zip" 64 | elif "mano_v1_2" in url: 65 | filename = "mano_v1_2.zip" 66 | elif "image" in url: 67 | filename = "/".join(url.split("/")[-2:]) 68 | 69 | # Write the contents of the response to a file 70 | out_p = op.join(out_folder, filename) 71 | os.makedirs(op.dirname(out_p), exist_ok=True) 72 | with open(out_p, "wb") as f: 73 | f.write(response.content) 74 | 75 | logger.info("Done") 76 | 77 | 78 | def main(): 79 | parser = argparse.ArgumentParser(description="Download files from a list of URLs") 80 | parser.add_argument( 81 | "--url_file", 82 | type=str, 83 | help="Path to file containing list of URLs", 84 | required=True, 85 | ) 86 | parser.add_argument( 87 | "--out_folder", 88 | type=str, 89 | help="Path to folder to store downloaded files", 90 | required=True, 91 | ) 92 | parser.add_argument( 93 | "--dry_run", 94 | action="store_true", 95 | help="Select top 5 URLs if enabled and 'images' is in url_file", 96 | ) 97 | args = parser.parse_args() 98 | if args.dry_run: 99 | logger.info("Running in dry-run mode") 100 | 101 | download_data(args.url_file, args.out_folder, args.dry_run) 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /scripts_data/mocap_viewer.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import argparse 4 | from aitviewer.renderables.spheres import Spheres 5 | from aitviewer.utils.so3 import aa2rot_numpy 6 | from aitviewer.scene.material import Material 7 | from easydict import EasyDict 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--mocap_p", type=str, default=None) 12 | config = parser.parse_args() 13 | args = EasyDict(vars(config)) 14 | return args 15 | 16 | 17 | def main(): 18 | args = parse_args() 19 | mocap_p = args.mocap_p 20 | data = np.load(mocap_p, allow_pickle=True).item() 21 | 22 | from aitviewer.viewer import Viewer 23 | v = Viewer(size=(2048, 1024)) 24 | materials = { 25 | "subject": Material(color=(0.44, 0.56, 0.89, 1.0), ambient=0.35), 26 | "object": Material(color=(0.969, 0.969, 0.969, 1.0), ambient=0.35), 27 | "egocam": Material(color=(0.24, 0.2, 0.2, 1.0), ambient=0.35), 28 | "table": Material(color=(0.24, 0.2, 0.2, 1.0), ambient=0.35), 29 | "support": Material(color=(0.969, 0.106, 0.059, 1.0), ambient=0.35), 30 | } 31 | 32 | 33 | for key, subject in data.items(): 34 | marker_names = subject['labels'] 35 | print(marker_names) 36 | marker_pos = subject['points']/1000 # frame, marker, xyz 37 | rotation_flip = aa2rot_numpy(np.array([-1/2, 0, 0]) * np.pi) 38 | spheres = Spheres(marker_pos, rotation=rotation_flip, name=key, material=materials[key]) 39 | v.scene.add(spheres) 40 | 41 | fps = 30 42 | v.playback_fps = fps 43 | v.scene.fps = fps 44 | v.run() 45 | 46 | 47 | if __name__ == "__main__": 48 | main() -------------------------------------------------------------------------------- 
/scripts_data/process_seqs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import sys 4 | import time 5 | import traceback 6 | from glob import glob 7 | 8 | import numpy as np 9 | import torch 10 | from loguru import logger 11 | from tqdm import tqdm 12 | 13 | sys.path = ["."] + sys.path 14 | from common.body_models import construct_layers 15 | 16 | # from src.arctic.models.object_tensors import ObjectTensors 17 | from common.object_tensors import ObjectTensors 18 | from src.arctic.processing import process_seq 19 | 20 | 21 | def construct_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--export_verts", action="store_true") 24 | parser.add_argument("--mano_p", type=str, default=None) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def main(): 30 | dev = "cuda:0" 31 | args = construct_args() 32 | 33 | with open( 34 | f"./data/arctic_data/data/meta/misc.json", 35 | "r", 36 | ) as f: 37 | misc = json.load(f) 38 | 39 | statcams = {} 40 | for sub in misc.keys(): 41 | statcams[sub] = { 42 | "world2cam": torch.FloatTensor(np.array(misc[sub]["world2cam"])), 43 | "intris_mat": torch.FloatTensor(np.array(misc[sub]["intris_mat"])), 44 | } 45 | 46 | if args.mano_p is not None: 47 | mano_ps = [args.mano_p] 48 | else: 49 | mano_ps = glob(f"./data/arctic_data/data/raw_seqs/*/*.mano.npy") 50 | 51 | layers = construct_layers(dev) 52 | # object_tensor = ObjectTensors('', './arctic_data/data') 53 | object_tensor = ObjectTensors() 54 | object_tensor.to(dev) 55 | layers["object"] = object_tensor 56 | 57 | pbar = tqdm(mano_ps) 58 | for mano_p in pbar: 59 | pbar.set_description("Processing %s" % mano_p) 60 | try: 61 | task = [mano_p, dev, statcams, layers, pbar] 62 | process_seq(task, export_verts=args.export_verts) 63 | except Exception as e: 64 | logger.info(traceback.format_exc()) 65 | time.sleep(2) 66 | logger.info(f"Failed at {mano_p}") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /scripts_data/unzip_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import zipfile 4 | from glob import glob 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def unzip(zip_p, out_dir): 10 | os.makedirs(out_dir, exist_ok=True) 11 | with zipfile.ZipFile(zip_p, "r") as zip_ref: 12 | zip_ref.extractall(out_dir) 13 | 14 | 15 | def main(): 16 | fnames = glob(op.join("downloads/data/", "**/*"), recursive=True) 17 | 18 | full_img_zips = [] 19 | cropped_images_zips = [] 20 | misc_zips = [] 21 | models_zips = [] 22 | for fname in fnames: 23 | if not (".zip" in fname or ".npy" in fname): 24 | continue 25 | if "/images_zips/" in fname: 26 | full_img_zips.append(fname) 27 | elif "/cropped_images_zips/" in fname: 28 | cropped_images_zips.append(fname) 29 | elif "raw_seqs.zip" in fname: 30 | misc_zips.append(fname) 31 | elif "splits_json.zip" in fname: 32 | misc_zips.append(fname) 33 | elif "meta.zip" in fname: 34 | misc_zips.append(fname) 35 | elif "splits.zip" in fname: 36 | misc_zips.append(fname) 37 | elif "feat.zip" in fname: 38 | misc_zips.append(fname) 39 | elif "mocap" in fname or 'smplx_corres.zip' in fname: 40 | misc_zips.append(fname) 41 | elif "models.zip" in fname: 42 | models_zips.append(fname) 43 | else: 44 | print(f"Unknown zip: {fname}") 45 | 46 | out_dir = "./unpack/arctic_data/data" 47 | os.makedirs(out_dir, exist_ok=True) 48 | 49 | # 
unzip misc files 50 | for zip_p in misc_zips: 51 | print(f"Unzipping {zip_p} to {out_dir}") 52 | unzip(zip_p, out_dir) 53 | 54 | # unzip models files 55 | for zip_p in models_zips: 56 | model_out = out_dir.replace("/data", "") 57 | print(f"Unzipping {zip_p} to {model_out}") 58 | unzip(zip_p, model_out) 59 | 60 | # unzip images 61 | pbar = tqdm(cropped_images_zips) 62 | for zip_p in pbar: 63 | out_p = op.join( 64 | out_dir, 65 | zip_p.replace("downloads/data/", "") 66 | .replace(".zip", "") 67 | .replace("cropped_images_zips/", "cropped_images/"), 68 | ) 69 | pbar.set_description(f"Unzipping {zip_p} to {out_dir}") 70 | unzip(zip_p, out_p) 71 | 72 | # unzip images 73 | pbar = tqdm(full_img_zips) 74 | for zip_p in pbar: 75 | pbar.set_description(f"Unzipping {zip_p} to {out_dir}") 76 | out_p = op.join( 77 | out_dir, 78 | zip_p.replace("downloads/data/", "") 79 | .replace(".zip", "") 80 | .replace("images_zips/", "images/"), 81 | ) 82 | unzip(zip_p, out_p) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /scripts_data/visualizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os.path as op 4 | import random 5 | import sys 6 | from glob import glob 7 | 8 | import torch 9 | from easydict import EasyDict 10 | from loguru import logger 11 | 12 | sys.path = ["."] + sys.path 13 | 14 | from common.body_models import construct_layers 15 | from common.viewer import ARCTICViewer 16 | 17 | 18 | class DataViewer(ARCTICViewer): 19 | def __init__( 20 | self, 21 | render_types=["rgb", "depth", "mask"], 22 | interactive=True, 23 | size=(2024, 2024), 24 | ): 25 | dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | self.layers = construct_layers(dev) 27 | super().__init__(render_types, interactive, size) 28 | 29 | def load_data( 30 | self, 31 | seq_p, 32 | use_mano, 33 | use_object, 34 | use_smplx, 35 | no_image, 36 | use_distort, 37 | view_idx, 38 | subject_meta, 39 | ): 40 | logger.info("Creating meshes") 41 | from src.mesh_loaders.arctic import construct_meshes 42 | 43 | batch = construct_meshes( 44 | seq_p, 45 | self.layers, 46 | use_mano, 47 | use_object, 48 | use_smplx, 49 | no_image, 50 | use_distort, 51 | view_idx, 52 | subject_meta, 53 | ) 54 | self.check_format(batch) 55 | logger.info("Done") 56 | return batch 57 | 58 | 59 | def parse_args(): 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--view_idx", type=int, default=1) 62 | parser.add_argument("--seq_p", type=str, default=None) 63 | parser.add_argument("--headless", action="store_true") 64 | parser.add_argument("--mano", action="store_true") 65 | parser.add_argument("--smplx", action="store_true") 66 | parser.add_argument("--object", action="store_true") 67 | parser.add_argument("--no_image", action="store_true") 68 | parser.add_argument("--distort", action="store_true") 69 | config = parser.parse_args() 70 | args = EasyDict(vars(config)) 71 | return args 72 | 73 | 74 | def main(): 75 | with open( 76 | f"./data/arctic_data/data/meta/misc.json", 77 | "r", 78 | ) as f: 79 | subject_meta = json.load(f) 80 | 81 | args = parse_args() 82 | random.seed(1) 83 | 84 | viewer = DataViewer(interactive=not args.headless, size=(2024, 2024)) 85 | if args.seq_p is None: 86 | seq_ps = glob("./outputs/processed_verts/seqs/*/*.npy") 87 | else: 88 | seq_ps = [args.seq_p] 89 | assert len(seq_ps) > 0, f"No seqs found on {args.seq_p}" 90 | 91 | for seq_idx, 
seq_p in enumerate(seq_ps): 92 | logger.info(f"Rendering seq#{seq_idx+1}, seq: {seq_p}, view: {args.view_idx}") 93 | seq_name = seq_p.split("/")[-1].split(".")[0] 94 | sid = seq_p.split("/")[-2] 95 | out_name = f"{sid}_{seq_name}_{args.view_idx}" 96 | batch = viewer.load_data( 97 | seq_p, 98 | args.mano, 99 | args.object, 100 | args.smplx, 101 | args.no_image, 102 | args.distort, 103 | args.view_idx, 104 | subject_meta, 105 | ) 106 | viewer.render_seq(batch, out_folder=op.join("render_out", out_name)) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /scripts_method/build_feat_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as op 5 | from glob import glob 6 | 7 | import numpy as np 8 | import torch 9 | from easydict import EasyDict 10 | from tqdm import tqdm 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--eval_p", type=str, default="") 16 | parser.add_argument("--split", type=str, default="") 17 | parser.add_argument("--protocol", type=str, default="") 18 | config = parser.parse_args() 19 | args = EasyDict(vars(config)) 20 | return args 21 | 22 | 23 | def check_imgname_match(imgnames_feat, setup, split): 24 | print("Verifying") 25 | imgnames_feat = ["/".join(imgname.split("/")[-4:]) for imgname in imgnames_feat] 26 | data = np.load( 27 | op.join(f"data/arctic_data/data/splits/{setup}_{split}.npy"), allow_pickle=True 28 | ).item() 29 | imgnames = data["imgnames"] 30 | imgnames_npy = ["/".join(imgname.split("/")[-4:]) for imgname in imgnames] 31 | assert set(imgnames_npy) == set(imgnames_feat) 32 | print("Passed verifcation") 33 | 34 | 35 | def main(split, protocol, eval_p): 36 | if protocol in ["p1"]: 37 | views = [1, 2, 3, 4, 5, 6, 7, 8] 38 | elif protocol in ["p2"]: 39 | views = [0] 40 | else: 41 | assert False, "Undefined protocol" 42 | 43 | short_split = split.replace("mini", "").replace("tiny", "") 44 | exp_key = eval_p.split("/")[-2] 45 | 46 | load_ps = glob(op.join(eval_p, "*")) 47 | with open( 48 | f"./data/arctic_data/data/splits_json/protocol_{protocol}.json", "r" 49 | ) as f: 50 | seq_names = json.load(f)[short_split] 51 | 52 | # needed seq/view pairs 53 | seq_view_specs = [] 54 | for seq_name in seq_names: 55 | for view_idx in views: 56 | seq_view_specs.append(f"{seq_name}/{view_idx}") 57 | seq_view_specs = set(seq_view_specs) 58 | 59 | if "mini" in split: 60 | import random 61 | 62 | random.seed(1) 63 | random.shuffle(seq_names) 64 | seq_names = seq_names[:10] 65 | 66 | if "tiny" in split: 67 | import random 68 | 69 | random.seed(1) 70 | random.shuffle(seq_names) 71 | seq_names = seq_names[:20] 72 | 73 | # filter seqs within split 74 | _load_ps = [] 75 | for load_p in load_ps: 76 | curr_seq = list(op.basename(load_p)) 77 | view_id = int(curr_seq[-1]) 78 | curr_seq[3] = "/" 79 | curr_seq = "".join(curr_seq)[:-2] # rm view id 80 | if curr_seq in seq_names and view_id in views: 81 | _load_ps.append(load_p) 82 | 83 | load_ps = _load_ps 84 | assert len(load_ps) == len(set(load_ps)) 85 | 86 | assert len(load_ps) > 0 87 | print("Loading image feat") 88 | vecs_list = [] 89 | imgnames_list = [] 90 | for load_p in tqdm(load_ps): 91 | feat_vec = torch.load(op.join(load_p, "preds", "pred.feat_vec.pt")) 92 | imgnames = torch.load(op.join(load_p, "meta_info", "meta_info.imgname.pt")) 93 | vecs_list.append(feat_vec) 94 | 
imgnames_list.append(imgnames) 95 | vecs_list = torch.cat(vecs_list, dim=0) 96 | imgnames_list = sum(imgnames_list, []) 97 | 98 | if short_split == split: 99 | check_imgname_match(imgnames_list, protocol, split) 100 | 101 | out = {"imgnames": imgnames_list, "feat_vec": vecs_list} 102 | out_folder = "./data/arctic_data/data/feat" 103 | out_p = op.join(out_folder, exp_key, f"{protocol}_{split}.pt") 104 | assert not op.exists(out_p), f"{out_p} already exists" 105 | os.makedirs(op.dirname(out_p), exist_ok=True) 106 | print(f"Dumping into {out_p}") 107 | torch.save(out, out_p) 108 | 109 | 110 | if __name__ == "__main__": 111 | args = parse_args() 112 | split = args.split 113 | if split in ["all"]: 114 | splits = ["minitrain", "minival", "tinytest", "tinyval", "train", "val", "test"] 115 | else: 116 | splits = [split] 117 | 118 | for split in splits: 119 | print(f"Processing {split}") 120 | main(split, args.protocol, args.eval_p) 121 | -------------------------------------------------------------------------------- /scripts_method/evaluate_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | import torch 6 | 7 | sys.path = ["."] + sys.path 8 | import argparse 9 | import os.path as op 10 | 11 | import numpy as np 12 | from easydict import EasyDict 13 | from loguru import logger 14 | from tqdm import tqdm 15 | 16 | import common.thing as thing 17 | from common.ld_utils import cat_dl, ld2dl 18 | from common.xdict import xdict 19 | from src.extraction.interface import prepare_data 20 | from src.utils.eval_modules import eval_fn_dict 21 | 22 | 23 | def evalute_results( 24 | layers, split, exp_key, setup, device, metrics, data_keys, task, eval_p 25 | ): 26 | with open(f"./data/arctic_data/data/splits_json/protocol_{setup}.json", "r") as f: 27 | protocols = json.load(f) 28 | 29 | seqs_val = protocols[split] 30 | 31 | if setup in ["p1"]: 32 | views = [1, 2, 3, 4, 5, 6, 7, 8] 33 | elif setup in ["p2"]: 34 | views = [0] 35 | else: 36 | assert False 37 | 38 | with torch.no_grad(): 39 | all_metrics = {} 40 | pbar = tqdm(seqs_val) 41 | for seq_val in pbar: 42 | for view in views: 43 | curr_seq = seq_val.replace("/", "_") + f"_{view}" 44 | pbar.set_description(f"Processing {curr_seq}: load data") 45 | data = prepare_data( 46 | curr_seq, exp_key, data_keys, layers, device, task, eval_p 47 | ) 48 | pred = data.search("pred.", replace_to="") 49 | targets = data.search("targets.", replace_to="") 50 | meta_info = data.search("meta_info.", replace_to="") 51 | metric_dict = xdict() 52 | for metric in metrics: 53 | pbar.set_description(f"Processing {curr_seq}: {metric}") 54 | # each metric returns a tensor with shape (N, ) 55 | out = eval_fn_dict[metric](pred, targets, meta_info) 56 | metric_dict.merge(out) 57 | metric_dict = metric_dict.to_np() 58 | all_metrics[curr_seq] = metric_dict 59 | 60 | agg_metrics = cat_dl(ld2dl(list(all_metrics.values())), dim=0) 61 | for key, val in agg_metrics.items(): 62 | agg_metrics[key] = float(np.nanmean(thing.thing2np(val))) 63 | 64 | out_folder = eval_p.replace("/eval", "/results") 65 | if not op.exists(out_folder): 66 | os.makedirs(out_folder, exist_ok=True) 67 | np.save(op.join(out_folder, f"all_metrics_{split}_{setup}.npy"), all_metrics) 68 | with open(op.join(out_folder, f"agg_metrics_{split}_{setup}.json"), "w") as f: 69 | json.dump(agg_metrics, f, indent=4) 70 | logger.info(f"Exported results to {out_folder}") 71 | 72 | 73 | def parse_args(): 74 | parser = argparse.ArgumentParser() 
75 | parser.add_argument("--task", type=str, default="") 76 | parser.add_argument("--eval_p", type=str, default="") 77 | parser.add_argument("--split", type=str, default="") 78 | parser.add_argument("--setup", type=str, default="") 79 | config = parser.parse_args() 80 | args = EasyDict(vars(config)) 81 | return args 82 | 83 | 84 | def main(): 85 | args = parse_args() 86 | from common.body_models import build_layers 87 | 88 | device = "cuda" 89 | layers = build_layers(device) 90 | 91 | eval_p = args.eval_p 92 | exp_key = eval_p.split("/")[-2] 93 | split = args.split 94 | setup = args.setup 95 | 96 | if "pose" in args.task: 97 | from src.extraction.keys.eval_pose import KEYS 98 | 99 | metrics = [ 100 | "aae", 101 | "mpjpe.ra", 102 | "mrrpe", 103 | "success_rate", 104 | "cdev", 105 | "mdev", 106 | "acc_err_pose", 107 | ] 108 | elif "field" in args.task: 109 | from src.extraction.keys.eval_field import KEYS 110 | 111 | metrics = ["avg_err_field", "acc_err_field"] 112 | else: 113 | assert False 114 | 115 | logger.info(f"Evaluating {exp_key} {split} on setup {setup}") 116 | evalute_results( 117 | layers, split, exp_key, setup, device, metrics, KEYS, args.task, eval_p 118 | ) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /scripts_method/extract_predicts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import os.path as op 4 | import sys 5 | from pprint import pformat 6 | 7 | import torch 8 | from loguru import logger 9 | from tqdm import tqdm 10 | 11 | sys.path.append(".") 12 | import common.thing as thing 13 | import src.extraction.interface as interface 14 | import src.factory as factory 15 | from common.xdict import xdict 16 | from src.parsers.parser import construct_args 17 | 18 | 19 | # LSTM models are trained using image features from single-frame models 20 | # this specify the single-frame model features that the LSTM model was trained on 21 | # model_dependencies[lstm_model_id] = single_frame_model_id 22 | model_dependencies = { 23 | "423c6057b": "3558f1342", 24 | "40ae50712": "28bf3642f", 25 | "546c1e997": "1f9ac0b15", 26 | "701a72569": "58e200d16", 27 | "fdc34e6c3": "66417ff6e", 28 | "49abdaee9": "7d09884c6", 29 | "5e6f6aeb9": "fb59bac27", 30 | "ec90691f8": "782c39821", 31 | } 32 | 33 | 34 | def main(): 35 | args = construct_args() 36 | 37 | args.experiment = None 38 | args.exp_key = "xxxxxxx" 39 | 40 | device = "cuda:0" 41 | wrapper = factory.fetch_model(args).to(device) 42 | assert args.load_ckpt != "" 43 | wrapper.load_state_dict(torch.load(args.load_ckpt)["state_dict"]) 44 | logger.info(f"Loaded weights from {args.load_ckpt}") 45 | wrapper.eval() 46 | wrapper.to(device) 47 | wrapper.model.arti_head.object_tensors.to(device) 48 | # wrapper.metric_dict = [] 49 | 50 | exp_key = op.abspath(args.load_ckpt).split("/")[-3] 51 | if exp_key in model_dependencies.keys(): 52 | assert ( 53 | args.img_feat_version == model_dependencies[exp_key] 54 | ), f"Image features used for training ({model_dependencies[exp_key]}) do not match the ones used for the current inference ({args.img_feat_version})" 55 | 56 | out_dir = op.join(args.load_ckpt.split("checkpoints")[0], "eval") 57 | 58 | with open( 59 | f"./data/arctic_data/data/splits_json/protocol_{args.setup}.json", 60 | "r", 61 | ) as f: 62 | seqs = json.load(f)[args.run_on] 63 | 64 | logger.info(f"Hyperparameters: \n {pformat(args)}") 65 | logger.info(f"Seqs to process 
({len(seqs)}): {seqs}") 66 | 67 | if args.extraction_mode in ["eval_pose"]: 68 | from src.extraction.keys.eval_pose import KEYS 69 | elif args.extraction_mode in ["eval_field"]: 70 | from src.extraction.keys.eval_field import KEYS 71 | elif args.extraction_mode in ["submit_pose"]: 72 | from src.extraction.keys.submit_pose import KEYS 73 | elif args.extraction_mode in ["submit_field"]: 74 | from src.extraction.keys.submit_field import KEYS 75 | elif args.extraction_mode in ["feat_pose"]: 76 | from src.extraction.keys.feat_pose import KEYS 77 | elif args.extraction_mode in ["feat_field"]: 78 | from src.extraction.keys.feat_field import KEYS 79 | elif args.extraction_mode in ["vis_pose"]: 80 | from src.extraction.keys.vis_pose import KEYS 81 | elif args.extraction_mode in ["vis_field"]: 82 | from src.extraction.keys.vis_field import KEYS 83 | else: 84 | assert False, f"Invalid extract ({args.extraction_mode})" 85 | 86 | if "submit_" in args.extraction_mode: 87 | task = args.extraction_mode.replace('submit_', '') 88 | task_name = f'{task}_{args.setup}_test' 89 | out_dir = out_dir.replace('/eval', f'/submit/{task_name}/eval') 90 | os.makedirs(out_dir, exist_ok=True) 91 | 92 | for seq_idx, seq in enumerate(seqs): 93 | logger.info(f"Processing seq {seq} {seq_idx + 1}/{len(seqs)}") 94 | out_list = [] 95 | val_loader = factory.fetch_dataloader(args, "val", seq) 96 | with torch.no_grad(): 97 | for idx, batch in tqdm(enumerate(val_loader), total=len(val_loader)): 98 | batch = thing.thing2dev(batch, device) 99 | inputs, targets, meta_info = batch 100 | if "submit_" in args.extraction_mode: 101 | out_dict = wrapper.inference(inputs, meta_info) 102 | else: 103 | out_dict = wrapper.forward(inputs, targets, meta_info, "extract") 104 | out_dict = xdict(out_dict) 105 | out_dict = out_dict.subset(KEYS) 106 | out_list.append(out_dict) 107 | 108 | out = interface.std_interface(out_list) 109 | interface.save_results(out, out_dir) 110 | logger.info("Done") 111 | 112 | if 'submit_' in args.extraction_mode: 113 | import shutil 114 | zip_name = f'{task_name}' 115 | zip_path = op.join(out_dir, zip_name).replace(f'/submit/{task_name}/eval/', '/submit/') 116 | shutil.make_archive(zip_path, 'zip', root_dir=op.dirname(zip_path), base_dir=op.basename(zip_path)) 117 | logger.info(f"Your submission file as exported at {zip_path}.zip") 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /scripts_method/train.py: -------------------------------------------------------------------------------- 1 | import comet_ml 2 | import os.path as op 3 | import sys 4 | from pprint import pformat 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | from loguru import logger 9 | from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary 10 | 11 | sys.path.append(".") 12 | 13 | import common.comet_utils as comet_utils 14 | import src.factory as factory 15 | from common.torch_utils import reset_all_seeds 16 | from src.utils.const import args 17 | 18 | 19 | def main(args): 20 | if args.experiment is not None: 21 | comet_utils.log_exp_meta(args) 22 | reset_all_seeds(args.seed) 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | wrapper = factory.fetch_model(args).to(device) 25 | if args.load_ckpt != "": 26 | ckpt = torch.load(args.load_ckpt) 27 | wrapper.load_state_dict(ckpt["state_dict"]) 28 | logger.info(f"Loaded weights from {args.load_ckpt}") 29 | 30 | wrapper.model.arti_head.object_tensors.to(device) 31 | 32 | 
ckpt_callback = ModelCheckpoint( 33 | monitor="loss__val", 34 | verbose=True, 35 | save_top_k=5, 36 | mode="min", 37 | every_n_epochs=args.eval_every_epoch, 38 | save_last=True, 39 | dirpath=op.join(args.log_dir, "checkpoints"), 40 | ) 41 | 42 | pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=1) 43 | 44 | model_summary_cb = ModelSummary(max_depth=3) 45 | callbacks = [ckpt_callback, pbar_cb, model_summary_cb] 46 | trainer = pl.Trainer( 47 | gradient_clip_val=args.grad_clip, 48 | gradient_clip_algorithm="norm", 49 | accumulate_grad_batches=args.acc_grad, 50 | devices=1, 51 | accelerator="gpu", 52 | logger=None, 53 | min_epochs=args.num_epoch, 54 | max_epochs=args.num_epoch, 55 | callbacks=callbacks, 56 | log_every_n_steps=args.log_every, 57 | default_root_dir=args.log_dir, 58 | check_val_every_n_epoch=args.eval_every_epoch, 59 | num_sanity_val_steps=0, 60 | enable_model_summary=False, 61 | ) 62 | 63 | reset_all_seeds(args.seed) 64 | train_loader = factory.fetch_dataloader(args, "train") 65 | logger.info(f"Hyperparameters: \n {pformat(args)}") 66 | logger.info("*** Started training ***") 67 | reset_all_seeds(args.seed) 68 | ckpt_path = None if args.ckpt_p == "" else args.ckpt_p 69 | val_loaders = [factory.fetch_dataloader(args, "val")] 70 | wrapper.set_training_flags() # load weights if needed 71 | trainer.fit(wrapper, train_loader, val_loaders, ckpt_path=ckpt_path) 72 | 73 | 74 | if __name__ == "__main__": 75 | main(args) 76 | -------------------------------------------------------------------------------- /scripts_method/visualizer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | from easydict import EasyDict 5 | 6 | sys.path = ["."] + sys.path 7 | import os.path as op 8 | from glob import glob 9 | 10 | import numpy as np 11 | from loguru import logger 12 | 13 | from common.viewer import ARCTICViewer, ViewerData 14 | from common.xdict import xdict 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--exp_folder", type=str, default="") 20 | parser.add_argument("--angle", type=float, default=None) 21 | parser.add_argument("--zoom_out", type=float, default=0.5) 22 | parser.add_argument("--seq_name", type=str, default="") 23 | parser.add_argument( 24 | "--mode", 25 | type=str, 26 | default="", 27 | choices=[ 28 | "gt_mesh", 29 | "pred_mesh", 30 | "gt_field_r", 31 | "gt_field_l", 32 | "pred_field_r", 33 | "pred_field_l", 34 | ], 35 | ) 36 | parser.add_argument("--headless", action="store_true") 37 | config = parser.parse_args() 38 | args = EasyDict(vars(config)) 39 | return args 40 | 41 | 42 | class MethodViewer(ARCTICViewer): 43 | def load_data(self, exp_folder, seq_name, mode): 44 | logger.info("Creating meshes") 45 | 46 | # check if we are loading gt or pred 47 | if "pred_mesh" in mode or "pred_field" in mode: 48 | flag = "pred" 49 | elif "gt_mesh" in mode or "gt_field" in mode: 50 | flag = "targets" 51 | else: 52 | assert False, f"Unknown mode {mode}" 53 | 54 | exp_key = exp_folder.split("/")[1] 55 | images_path = op.join(exp_folder, "eval", seq_name, "images") 56 | 57 | # load mesh 58 | meshes_all = xdict() 59 | print(f"Specs: {exp_key} {seq_name} {flag}") 60 | if "_mesh" in mode: 61 | from src.mesh_loaders.pose import construct_meshes 62 | 63 | meshes, data = construct_meshes( 64 | exp_folder, seq_name, flag, None, zoom_out=None 65 | ) 66 | meshes_all.merge(meshes) 67 | elif "_field" in mode: 68 | from src.mesh_loaders.field import construct_meshes 69 | 
70 | meshes, data = construct_meshes( 71 | exp_folder, seq_name, flag, mode, None, zoom_out=None 72 | ) 73 | meshes_all.merge(meshes) 74 | if "_r" in mode: 75 | meshes_all.pop("left", None) 76 | if "_l" in mode: 77 | meshes_all.pop("right", None) 78 | else: 79 | assert False, f"Unknown mode {mode}" 80 | 81 | imgnames = sorted(glob(images_path + "/*")) 82 | num_frames = min(len(imgnames), data[f"{flag}.object.cam_t"].shape[0]) 83 | 84 | # setup camera 85 | focal = 1000.0 86 | rows = 224 87 | cols = 224 88 | K = np.array([[focal, 0, rows / 2.0], [0, focal, cols / 2.0], [0, 0, 1]]) 89 | cam_t = data[f"{flag}.object.cam_t"] 90 | cam_t = cam_t[:num_frames] 91 | Rt = np.zeros((num_frames, 3, 4)) 92 | Rt[:, :, 3] = cam_t 93 | Rt[:, :3, :3] = np.eye(3) 94 | Rt[:, 1:3, :3] *= -1.0 95 | 96 | # pack data 97 | data = ViewerData(Rt=Rt, K=K, cols=cols, rows=rows, imgnames=imgnames) 98 | batch = meshes_all, data 99 | self.check_format(batch) 100 | logger.info("Done") 101 | return batch 102 | 103 | 104 | def main(): 105 | args = parse_args() 106 | exp_folder = args.exp_folder 107 | seq_name = args.seq_name 108 | mode = args.mode 109 | viewer = MethodViewer( 110 | interactive=not args.headless, 111 | size=(2048, 2048), 112 | render_types=["rgb", "video"], 113 | ) 114 | logger.info(f"Rendering {seq_name} {mode}") 115 | batch = viewer.load_data(exp_folder, seq_name, mode) 116 | viewer.render_seq(batch, out_folder=op.join(exp_folder, "render", seq_name, mode)) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /src/callbacks/loss/loss_field.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from common.xdict import xdict 4 | 5 | l1_loss = nn.L1Loss(reduction="none") 6 | mse_loss = nn.MSELoss(reduction="none") 7 | ce_loss = nn.CrossEntropyLoss(reduction="none") 8 | 9 | 10 | def dist_loss(loss_dict, pred, gt, meta_info): 11 | is_valid = gt["is_valid"] 12 | mask_o = meta_info["mask"] 13 | 14 | # interfield 15 | loss_ro = mse_loss(pred[f"dist.ro"], gt["dist.ro"]) 16 | loss_lo = mse_loss(pred[f"dist.lo"], gt["dist.lo"]) 17 | 18 | pad_olen = min(pred[f"dist.or"].shape[1], gt["dist.or"].shape[1]) 19 | 20 | loss_or = mse_loss(pred[f"dist.or"][:, :pad_olen], gt["dist.or"][:, :pad_olen]) 21 | loss_ol = mse_loss(pred[f"dist.ol"][:, :pad_olen], gt["dist.ol"][:, :pad_olen]) 22 | 23 | # too many 10cm. 
Skip them in the loss to prevent overfitting 24 | bnd = 0.1 # 10cm 25 | bnd_idx_ro = gt["dist.ro"] == bnd 26 | bnd_idx_lo = gt["dist.lo"] == bnd 27 | bnd_idx_or = gt["dist.or"][:, :pad_olen] == bnd 28 | bnd_idx_ol = gt["dist.ol"][:, :pad_olen] == bnd 29 | 30 | loss_or = loss_or * mask_o * is_valid[:, None] 31 | loss_ol = loss_ol * mask_o * is_valid[:, None] 32 | 33 | loss_ro = loss_ro * is_valid[:, None] 34 | loss_lo = loss_lo * is_valid[:, None] 35 | 36 | loss_or[bnd_idx_or] *= 0.1 37 | loss_ol[bnd_idx_ol] *= 0.1 38 | loss_ro[bnd_idx_ro] *= 0.1 39 | loss_lo[bnd_idx_lo] *= 0.1 40 | 41 | weight = 100.0 42 | loss_dict[f"loss/dist/ro"] = (loss_ro.mean(), weight) 43 | loss_dict[f"loss/dist/lo"] = (loss_lo.mean(), weight) 44 | loss_dict[f"loss/dist/or"] = (loss_or.mean(), weight) 45 | loss_dict[f"loss/dist/ol"] = (loss_ol.mean(), weight) 46 | return loss_dict 47 | 48 | 49 | def compute_loss(pred, gt, meta_info, args): 50 | loss_dict = xdict() 51 | loss_dict = dist_loss(loss_dict, pred, gt, meta_info) 52 | return loss_dict 53 | -------------------------------------------------------------------------------- /src/callbacks/process/process_field.py: -------------------------------------------------------------------------------- 1 | import src.callbacks.process.process_arctic as process_arctic 2 | import src.callbacks.process.process_generic as generic 3 | 4 | 5 | def process_data(models, inputs, targets, meta_info, mode, args): 6 | batch_size = meta_info["intrinsics"].shape[0] 7 | 8 | ( 9 | v0_r, 10 | v0_l, 11 | v0_o, 12 | pidx, 13 | v0_r_full, 14 | v0_l_full, 15 | v0_o_full, 16 | mask, 17 | cams, 18 | ) = generic.prepare_templates( 19 | batch_size, 20 | models["mano_r"], 21 | models["mano_l"], 22 | models["mesh_sampler"], 23 | models["arti_head"], 24 | meta_info["query_names"], 25 | ) 26 | 27 | meta_info["v0.r"] = v0_r 28 | meta_info["v0.l"] = v0_l 29 | meta_info["v0.o"] = v0_o 30 | meta_info["cams0"] = cams 31 | meta_info["parts_idx"] = pidx 32 | meta_info["v0.r.full"] = v0_r_full 33 | meta_info["v0.l.full"] = v0_l_full 34 | meta_info["v0.o.full"] = v0_o_full 35 | meta_info["mask"] = mask 36 | 37 | inputs, targets, meta_info = process_arctic.process_data( 38 | models, inputs, targets, meta_info, mode, args, field_max=args.max_dist 39 | ) 40 | 41 | return inputs, targets, meta_info 42 | -------------------------------------------------------------------------------- /src/callbacks/process/process_generic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import src.utils.interfield as inter 4 | 5 | 6 | def prepare_mano_template(batch_size, mano_layer, mesh_sampler, is_right): 7 | root_idx = 0 8 | 9 | # Generate T-pose template mesh 10 | template_pose = torch.zeros((1, 48)) 11 | template_pose = template_pose.cuda() 12 | template_betas = torch.zeros((1, 10)).cuda() 13 | out = mano_layer( 14 | betas=template_betas, 15 | hand_pose=template_pose[:, 3:], 16 | global_orient=template_pose[:, :3], 17 | transl=None, 18 | ) 19 | template_3d_joints = out.joints 20 | template_vertices = out.vertices 21 | template_vertices_sub = mesh_sampler.downsample(template_vertices, is_right) 22 | 23 | # normalize 24 | template_root = template_3d_joints[:, root_idx, :] 25 | template_3d_joints = template_3d_joints - template_root[:, None, :] 26 | template_vertices = template_vertices - template_root[:, None, :] 27 | template_vertices_sub = template_vertices_sub - template_root[:, None, :] 28 | 29 | # concatinate template joints and template vertices, and then 
duplicate to batch size 30 | ref_vertices = torch.cat([template_3d_joints, template_vertices_sub], dim=1) 31 | ref_vertices = ref_vertices.expand(batch_size, -1, -1) 32 | 33 | ref_vertices_full = torch.cat([template_3d_joints, template_vertices], dim=1) 34 | ref_vertices_full = ref_vertices_full.expand(batch_size, -1, -1) 35 | return ref_vertices, ref_vertices_full 36 | 37 | 38 | def prepare_templates( 39 | batch_size, 40 | mano_r, 41 | mano_l, 42 | mesh_sampler, 43 | arti_head, 44 | query_names, 45 | ): 46 | v0_r, v0_r_full = prepare_mano_template( 47 | batch_size, mano_r, mesh_sampler, is_right=True 48 | ) 49 | v0_l, v0_l_full = prepare_mano_template( 50 | batch_size, mano_l, mesh_sampler, is_right=False 51 | ) 52 | (v0_o, pidx, v0_full, mask) = prepare_object_template( 53 | batch_size, 54 | arti_head.object_tensors, 55 | query_names, 56 | ) 57 | CAM_R, CAM_L, CAM_O = list(range(100))[-3:] 58 | cams = ( 59 | torch.FloatTensor([CAM_R, CAM_L, CAM_O]).view(1, 3, 1).repeat(batch_size, 1, 3) 60 | / 100 61 | ) 62 | cams = cams.to(v0_r.device) 63 | return ( 64 | v0_r, 65 | v0_l, 66 | v0_o, 67 | pidx, 68 | v0_r_full, 69 | v0_l_full, 70 | v0_full, 71 | mask, 72 | cams, 73 | ) 74 | 75 | 76 | def prepare_object_template(batch_size, object_tensors, query_names): 77 | template_angles = torch.zeros((batch_size, 1)).cuda() 78 | template_rot = torch.zeros((batch_size, 3)).cuda() 79 | out = object_tensors.forward( 80 | angles=template_angles, 81 | global_orient=template_rot, 82 | transl=None, 83 | query_names=query_names, 84 | ) 85 | ref_vertices = out["v_sub"] 86 | parts_idx = out["parts_ids"] 87 | 88 | mask = out["mask"] 89 | 90 | ref_mean = ref_vertices.mean(dim=1)[:, None, :] 91 | ref_vertices -= ref_mean 92 | 93 | v_template = out["v"] 94 | return (ref_vertices, parts_idx, v_template, mask) 95 | 96 | 97 | def prepare_interfield(targets, max_dist): 98 | dist_min = 0.0 99 | dist_max = max_dist 100 | dist_ro, dist_ro_idx = inter.compute_dist_mano_to_obj( 101 | targets["mano.v3d.cam.r"], 102 | targets["object.v.cam"], 103 | targets["object.v_len"], 104 | dist_min, 105 | dist_max, 106 | ) 107 | dist_lo, dist_lo_idx = inter.compute_dist_mano_to_obj( 108 | targets["mano.v3d.cam.l"], 109 | targets["object.v.cam"], 110 | targets["object.v_len"], 111 | dist_min, 112 | dist_max, 113 | ) 114 | dist_or, dist_or_idx = inter.compute_dist_obj_to_mano( 115 | targets["mano.v3d.cam.r"], 116 | targets["object.v.cam"], 117 | targets["object.v_len"], 118 | dist_min, 119 | dist_max, 120 | ) 121 | dist_ol, dist_ol_idx = inter.compute_dist_obj_to_mano( 122 | targets["mano.v3d.cam.l"], 123 | targets["object.v.cam"], 124 | targets["object.v_len"], 125 | dist_min, 126 | dist_max, 127 | ) 128 | 129 | targets["dist.ro"] = dist_ro 130 | targets["dist.lo"] = dist_lo 131 | targets["dist.or"] = dist_or 132 | targets["dist.ol"] = dist_ol 133 | 134 | targets["idx.ro"] = dist_ro_idx 135 | targets["idx.lo"] = dist_lo_idx 136 | targets["idx.or"] = dist_or_idx 137 | targets["idx.ol"] = dist_ol_idx 138 | return targets 139 | -------------------------------------------------------------------------------- /src/datasets/arctic_dataset_eval.py: -------------------------------------------------------------------------------- 1 | from src.datasets.arctic_dataset import ArcticDataset 2 | 3 | 4 | class ArcticDatasetEval(ArcticDataset): 5 | def getitem(self, imgname, load_rgb=True): 6 | return self.getitem_eval(imgname, load_rgb=load_rgb) 7 | -------------------------------------------------------------------------------- 
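# dataset_utils.py (next file) groups helpers shared by the dataset classes:
# bbox / 2D-keypoint transforms for the pre-cropped "speedup" images vs. the
# rescaled egocentric view, deterministic subsampling for the mini/tiny/small
# debug splits, and per-frame validity lookups.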
/src/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | from glob import glob 4 | 5 | import numpy as np 6 | from loguru import logger 7 | 8 | import common.data_utils as data_utils 9 | from common.sys_utils import copy_repo 10 | 11 | 12 | def transform_bbox_for_speedup( 13 | speedup, 14 | is_egocam, 15 | _bbox_crop, 16 | ego_image_scale, 17 | ): 18 | bbox_crop = np.array(_bbox_crop) 19 | # bbox is normalized in scale 20 | 21 | if speedup: 22 | if is_egocam: 23 | bbox_crop = [num * ego_image_scale for num in bbox_crop] 24 | else: 25 | # change to new coord system 26 | bbox_crop[0] = 500 27 | bbox_crop[1] = 500 28 | bbox_crop[2] = 1000 / (1.5 * 200) 29 | 30 | # bbox is normalized in scale 31 | return bbox_crop 32 | 33 | 34 | def transform_2d_for_speedup( 35 | speedup, 36 | is_egocam, 37 | _joints2d_r, 38 | _joints2d_l, 39 | _kp2d_b, 40 | _kp2d_t, 41 | _bbox2d_b, 42 | _bbox2d_t, 43 | _bbox_crop, 44 | ego_image_scale, 45 | ): 46 | joints2d_r = np.copy(_joints2d_r) 47 | joints2d_l = np.copy(_joints2d_l) 48 | kp2d_b = np.copy(_kp2d_b) 49 | kp2d_t = np.copy(_kp2d_t) 50 | bbox2d_b = np.copy(_bbox2d_b) 51 | bbox2d_t = np.copy(_bbox2d_t) 52 | bbox_crop = np.array(_bbox_crop) 53 | # bbox is normalized in scale 54 | 55 | if speedup: 56 | if is_egocam: 57 | joints2d_r[:, :2] *= ego_image_scale 58 | joints2d_l[:, :2] *= ego_image_scale 59 | kp2d_b[:, :2] *= ego_image_scale 60 | kp2d_t[:, :2] *= ego_image_scale 61 | bbox2d_b[:, :2] *= ego_image_scale 62 | bbox2d_t[:, :2] *= ego_image_scale 63 | 64 | bbox_crop = [num * ego_image_scale for num in bbox_crop] 65 | else: 66 | # change to new coord system 67 | joints2d_r = data_utils.transform_kp2d(joints2d_r, bbox_crop) 68 | joints2d_l = data_utils.transform_kp2d(joints2d_l, bbox_crop) 69 | kp2d_b = data_utils.transform_kp2d(kp2d_b, bbox_crop) 70 | kp2d_t = data_utils.transform_kp2d(kp2d_t, bbox_crop) 71 | bbox2d_b = data_utils.transform_kp2d(bbox2d_b, bbox_crop) 72 | bbox2d_t = data_utils.transform_kp2d(bbox2d_t, bbox_crop) 73 | 74 | bbox_crop[0] = 500 75 | bbox_crop[1] = 500 76 | bbox_crop[2] = 1000 / (1.5 * 200) 77 | 78 | # bbox is normalized in scale 79 | return ( 80 | joints2d_r, 81 | joints2d_l, 82 | kp2d_b, 83 | kp2d_t, 84 | bbox2d_b, 85 | bbox2d_t, 86 | bbox_crop, 87 | ) 88 | 89 | 90 | def copy_repo_arctic(exp_key): 91 | dst_folder = f"/is/cluster/work/fzicong/chiral_data/cache/logs/{exp_key}/repo" 92 | 93 | if not op.exists(dst_folder): 94 | logger.info("Copying repo") 95 | src_files = glob("./*") 96 | os.makedirs(dst_folder) 97 | filter_keywords = [".ipynb", ".obj", ".pt", "run_scripts", ".sub", ".txt"] 98 | copy_repo(src_files, dst_folder, filter_keywords) 99 | logger.info("Done") 100 | 101 | 102 | def get_num_images(split, num_images): 103 | if split in ["train", "val", "test"]: 104 | return num_images 105 | 106 | if split == "smalltrain": 107 | return 100000 108 | 109 | if split == "tinytrain": 110 | return 12000 111 | 112 | if split == "minitrain": 113 | return 300 114 | 115 | if split == "smallval": 116 | return 12000 117 | 118 | if split == "tinyval": 119 | return 500 120 | 121 | if split == "minival": 122 | return 80 123 | 124 | if split == "smalltest": 125 | return 12000 126 | 127 | if split == "tinytest": 128 | return 6000 129 | 130 | if split == "minitest": 131 | return 200 132 | 133 | assert False, f"Invalid split {split}" 134 | 135 | 136 | def pad_jts2d(jts): 137 | num_jts = jts.shape[0] 138 | jts_pad = np.ones((num_jts, 3)) 139 | 
jts_pad[:, :2] = jts 140 | return jts_pad 141 | 142 | 143 | def get_valid(data_2d, data_cam, vidx, view_idx, imgname): 144 | assert ( 145 | vidx < data_2d["joints.right"].shape[0] 146 | ), "The requested vidx does not exist in annotation" 147 | is_valid = data_cam["is_valid"][vidx, view_idx] 148 | right_valid = data_cam["right_valid"][vidx, view_idx] 149 | left_valid = data_cam["left_valid"][vidx, view_idx] 150 | return vidx, is_valid, right_valid, left_valid 151 | 152 | 153 | def downsample(fnames, split): 154 | if "small" not in split and "mini" not in split and "tiny" not in split: 155 | return fnames 156 | import random 157 | 158 | random.seed(1) 159 | assert ( 160 | random.randint(0, 100) == 17 161 | ), "Same seed but different results; Subsampling might be different." 162 | 163 | num_samples = get_num_images(split, len(fnames)) 164 | curr_keys = random.sample(fnames, num_samples) 165 | return curr_keys 166 | -------------------------------------------------------------------------------- /src/datasets/tempo_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | 3 | import numpy as np 4 | import torch 5 | from loguru import logger 6 | from torch.utils.data import Dataset 7 | 8 | import common.ld_utils as ld_utils 9 | import src.datasets.dataset_utils as dataset_utils 10 | from src.datasets.arctic_dataset import ArcticDataset 11 | 12 | 13 | class TempoDataset(ArcticDataset): 14 | def _load_data(self, args, split): 15 | data_p = f"./data/arctic_data/data/feat/{args.img_feat_version}/{args.setup}_{split}.pt" 16 | logger.info(f"Loading: {data_p}") 17 | data = torch.load(data_p) 18 | imgnames = data["imgnames"] 19 | vecs_list = data["feat_vec"] 20 | assert len(imgnames) == len(vecs_list) 21 | vec_dict = {} 22 | for imgname, vec in zip(imgnames, vecs_list): 23 | key = "/".join(imgname.split("/")[-4:]) 24 | vec_dict[key] = vec 25 | self.vec_dict = vec_dict 26 | 27 | assert len(imgnames) == len(vec_dict.keys()) 28 | self.aug_data = False 29 | self.window_size = args.window_size 30 | 31 | def __init__(self, args, split, seq=None): 32 | Dataset.__init__(self) 33 | super()._load_data(args, split, seq) 34 | self._load_data(args, split) 35 | 36 | imgnames = list(self.vec_dict.keys()) 37 | imgnames = dataset_utils.downsample(imgnames, split) 38 | 39 | self.imgnames = imgnames 40 | logger.info( 41 | f"TempoDataset Loaded {self.split} split, num samples {len(imgnames)}" 42 | ) 43 | 44 | def __getitem__(self, index): 45 | imgname = self.imgnames[index] 46 | img_idx = int(op.basename(imgname).split(".")[0]) 47 | ind = ( 48 | np.arange(self.window_size) - (self.window_size - 1) / 2 + img_idx 49 | ).astype(np.int64) 50 | num_frames = self.data["/".join(imgname.split("/")[:2])]["params"][ 51 | "rot_r" 52 | ].shape[0] 53 | ind = np.clip( 54 | ind, 10, num_frames - 10 - 1 55 | ) # skip first and last 10 frames as they are not useful 56 | imgnames = [op.join(op.dirname(imgname), "%05d.jpg" % (idx)) for idx in ind] 57 | 58 | targets_list = [] 59 | meta_list = [] 60 | img_feats = [] 61 | inputs_list = [] 62 | load_rgb = True if self.args.method in ["tempo_ft"] else False 63 | for imgname in imgnames: 64 | img_folder = f"./data/arctic_data/data/images/" 65 | inputs, targets, meta_info = self.getitem( 66 | op.join(img_folder, imgname), load_rgb=load_rgb 67 | ) 68 | if load_rgb: 69 | inputs_list.append(inputs) 70 | else: 71 | img_feats.append(self.vec_dict[imgname].type(torch.FloatTensor)) 72 | targets_list.append(targets) 73 | 
meta_list.append(meta_info) 74 | 75 | if load_rgb: 76 | inputs_list = ld_utils.stack_dl( 77 | ld_utils.ld2dl(inputs_list), dim=0, verbose=False 78 | ) 79 | inputs = {"img": inputs_list["img"]} 80 | else: 81 | img_feats = torch.stack(img_feats, dim=0) 82 | inputs = {"img_feat": img_feats} 83 | 84 | targets_list = ld_utils.stack_dl( 85 | ld_utils.ld2dl(targets_list), dim=0, verbose=False 86 | ) 87 | meta_list = ld_utils.stack_dl(ld_utils.ld2dl(meta_list), dim=0, verbose=False) 88 | 89 | targets_list["is_valid"] = torch.FloatTensor(np.array(targets_list["is_valid"])) 90 | targets_list["left_valid"] = torch.FloatTensor( 91 | np.array(targets_list["left_valid"]) 92 | ) 93 | targets_list["right_valid"] = torch.FloatTensor( 94 | np.array(targets_list["right_valid"]) 95 | ) 96 | targets_list["joints_valid_r"] = torch.FloatTensor( 97 | np.array(targets_list["joints_valid_r"]) 98 | ) 99 | targets_list["joints_valid_l"] = torch.FloatTensor( 100 | np.array(targets_list["joints_valid_l"]) 101 | ) 102 | meta_list["center"] = torch.FloatTensor(np.array(meta_list["center"])) 103 | meta_list["is_flipped"] = torch.FloatTensor(np.array(meta_list["is_flipped"])) 104 | meta_list["rot_angle"] = torch.FloatTensor(np.array(meta_list["rot_angle"])) 105 | return inputs, targets_list, meta_list 106 | -------------------------------------------------------------------------------- /src/datasets/tempo_inference_dataset_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import common.ld_utils as ld_utils 5 | from src.datasets.tempo_inference_dataset import TempoInferenceDataset 6 | 7 | 8 | class TempoInferenceDatasetEval(TempoInferenceDataset): 9 | def __getitem__(self, index): 10 | imgnames = self.windows[index] 11 | inputs_list = [] 12 | targets_list = [] 13 | meta_list = [] 14 | img_feats = [] 15 | # load_rgb = not self.args.eval # test.py do not load rgb 16 | load_rgb = False 17 | for imgname in imgnames: 18 | short_imgname = "/".join(imgname.split("/")[-4:]) 19 | # always load rgb because in training, we need to visualize 20 | # too complicated if not load rgb in eval or other situations 21 | # thus: load both rgb and features 22 | inputs, targets, meta_info = self.getitem_eval(imgname, load_rgb=load_rgb) 23 | img_feats.append(self.vec_dict[short_imgname]) 24 | inputs_list.append(inputs) 25 | targets_list.append(targets) 26 | meta_list.append(meta_info) 27 | 28 | if load_rgb: 29 | inputs_list = ld_utils.stack_dl( 30 | ld_utils.ld2dl(inputs_list), dim=0, verbose=False 31 | ) 32 | else: 33 | inputs_list = {} 34 | targets_list = ld_utils.stack_dl( 35 | ld_utils.ld2dl(targets_list), dim=0, verbose=False 36 | ) 37 | meta_list = ld_utils.stack_dl(ld_utils.ld2dl(meta_list), dim=0, verbose=False) 38 | img_feats = torch.stack(img_feats, dim=0).float() 39 | 40 | inputs_list["img_feat"] = img_feats 41 | meta_list["center"] = torch.FloatTensor(np.array(meta_list["center"])) 42 | meta_list["is_flipped"] = torch.FloatTensor(np.array(meta_list["is_flipped"])) 43 | meta_list["rot_angle"] = torch.FloatTensor(np.array(meta_list["rot_angle"])) 44 | return inputs_list, targets_list, meta_list 45 | -------------------------------------------------------------------------------- /src/extraction/keys/eval_field.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "pred.dist.ro", 3 | "pred.dist.lo", 4 | "pred.dist.or", 5 | "pred.dist.ol", 6 | "targets.mano.pose.r", 7 | "targets.mano.pose.l", 8 | 
"targets.mano.beta.r", 9 | "targets.mano.beta.l", 10 | "targets.object.radian", 11 | "targets.object.rot", 12 | "targets.is_valid", 13 | "targets.left_valid", 14 | "targets.right_valid", 15 | "targets.joints_valid_r", 16 | "targets.joints_valid_l", 17 | "targets.mano.cam_t.r", 18 | "targets.mano.cam_t.l", 19 | "targets.object.cam_t", 20 | "meta_info.imgname", 21 | "meta_info.query_names", 22 | "meta_info.window_size", 23 | "meta_info.center", 24 | "meta_info.is_flipped", 25 | "meta_info.rot_angle", 26 | ] 27 | -------------------------------------------------------------------------------- /src/extraction/keys/eval_pose.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "pred.mano.cam_t.r", 3 | "pred.mano.beta.r", 4 | "pred.mano.pose.r", 5 | "pred.mano.cam_t.l", 6 | "pred.mano.beta.l", 7 | "pred.mano.pose.l", 8 | "pred.object.rot", 9 | "pred.object.cam_t", 10 | "pred.object.radian", 11 | "targets.mano.pose.r", 12 | "targets.mano.pose.l", 13 | "targets.mano.beta.r", 14 | "targets.mano.beta.l", 15 | "targets.object.radian", 16 | "targets.object.rot", 17 | "targets.is_valid", 18 | "targets.left_valid", 19 | "targets.right_valid", 20 | "targets.joints_valid_r", 21 | "targets.joints_valid_l", 22 | "targets.mano.cam_t.r", 23 | "targets.mano.cam_t.l", 24 | "targets.object.cam_t", 25 | "targets.object.bbox3d.cam", 26 | "meta_info.imgname", 27 | "meta_info.query_names", 28 | "meta_info.window_size", 29 | "meta_info.center", 30 | "meta_info.is_flipped", 31 | "meta_info.rot_angle", 32 | "meta_info.diameter", 33 | ] 34 | -------------------------------------------------------------------------------- /src/extraction/keys/feat_field.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "pred.feat_vec", 3 | "meta_info.imgname", 4 | ] 5 | -------------------------------------------------------------------------------- /src/extraction/keys/feat_pose.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "pred.feat_vec", 3 | "meta_info.imgname", 4 | ] 5 | -------------------------------------------------------------------------------- /src/extraction/keys/submit_field.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "pred.dist.ro", 3 | "pred.dist.lo", 4 | "pred.dist.or", 5 | "pred.dist.ol", 6 | "meta_info.imgname", 7 | ] 8 | -------------------------------------------------------------------------------- /src/extraction/keys/submit_pose.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "pred.mano.cam_t.r", 3 | "pred.mano.beta.r", 4 | "pred.mano.pose.r", 5 | "pred.mano.cam_t.l", 6 | "pred.mano.beta.l", 7 | "pred.mano.pose.l", 8 | "pred.object.rot", 9 | "pred.object.cam_t", 10 | "pred.object.radian", 11 | "meta_info.imgname", 12 | ] 13 | -------------------------------------------------------------------------------- /src/extraction/keys/vis_field.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "inputs.img", 3 | "pred.dist.lo", 4 | "pred.dist.ol", 5 | "pred.dist.or", 6 | "pred.dist.ro", 7 | "targets.is_valid", 8 | "targets.left_valid", 9 | "targets.right_valid", 10 | "targets.mano.beta.r", 11 | "targets.mano.beta.l", 12 | "targets.mano.pose.r", 13 | "targets.mano.pose.l", 14 | "targets.mano.cam_t.r", 15 | "targets.mano.cam_t.l", 16 | "targets.object.radian", 17 | "targets.object.rot", 18 | 
"targets.object.cam_t", 19 | "meta_info.imgname", 20 | "meta_info.query_names", 21 | ] 22 | -------------------------------------------------------------------------------- /src/extraction/keys/vis_pose.py: -------------------------------------------------------------------------------- 1 | KEYS = [ 2 | "inputs.img", 3 | "pred.mano.cam_t.r", 4 | "pred.mano.beta.r", 5 | "pred.mano.pose.r", 6 | "pred.mano.cam_t.l", 7 | "pred.mano.beta.l", 8 | "pred.mano.pose.l", 9 | "pred.object.rot", 10 | "pred.object.cam_t", 11 | "pred.object.radian", 12 | "targets.mano.pose.r", 13 | "targets.mano.pose.l", 14 | "targets.mano.beta.r", 15 | "targets.mano.beta.l", 16 | "targets.object.radian", 17 | "targets.object.rot", 18 | "targets.is_valid", 19 | "targets.left_valid", 20 | "targets.right_valid", 21 | "targets.joints_valid_r", 22 | "targets.joints_valid_l", 23 | "targets.mano.cam_t.r", 24 | "targets.mano.cam_t.l", 25 | "targets.object.cam_t", 26 | "meta_info.imgname", 27 | "meta_info.query_names", 28 | "meta_info.window_size", 29 | "meta_info.intrinsics", 30 | "meta_info.dist", 31 | "meta_info.center", 32 | "meta_info.is_flipped", 33 | "meta_info.rot_angle", 34 | "meta_info.diameter", 35 | ] 36 | -------------------------------------------------------------------------------- /src/factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from common.torch_utils import reset_all_seeds 5 | from src.datasets.arctic_dataset import ArcticDataset 6 | from src.datasets.arctic_dataset_eval import ArcticDatasetEval 7 | from src.datasets.tempo_dataset import TempoDataset 8 | from src.datasets.tempo_inference_dataset import TempoInferenceDataset 9 | from src.datasets.tempo_inference_dataset_eval import TempoInferenceDatasetEval 10 | 11 | 12 | def fetch_dataset_eval(args, seq=None): 13 | if args.method in ["arctic_sf"]: 14 | DATASET = ArcticDatasetEval 15 | elif args.method in ["field_sf"]: 16 | DATASET = ArcticDatasetEval 17 | elif args.method in ["arctic_lstm", "field_lstm"]: 18 | DATASET = TempoInferenceDatasetEval 19 | else: 20 | assert False 21 | if seq is not None: 22 | split = args.run_on 23 | ds = DATASET(args=args, split=split, seq=seq) 24 | return ds 25 | 26 | 27 | def fetch_dataset_devel(args, is_train, seq=None): 28 | split = args.trainsplit if is_train else args.valsplit 29 | if args.method in ["arctic_sf"]: 30 | if is_train: 31 | DATASET = ArcticDataset 32 | else: 33 | DATASET = ArcticDataset 34 | elif args.method in ["field_sf"]: 35 | if is_train: 36 | DATASET = ArcticDataset 37 | else: 38 | DATASET = ArcticDataset 39 | elif args.method in ["field_lstm", "arctic_lstm"]: 40 | if is_train: 41 | DATASET = TempoDataset 42 | else: 43 | DATASET = TempoInferenceDataset 44 | else: 45 | assert False 46 | if seq is not None: 47 | split = args.run_on 48 | ds = DATASET(args=args, split=split, seq=seq) 49 | return ds 50 | 51 | 52 | def collate_custom_fn(data_list): 53 | data = data_list[0] 54 | _inputs, _targets, _meta_info = data 55 | out_inputs = {} 56 | out_targets = {} 57 | out_meta_info = {} 58 | 59 | for key in _inputs.keys(): 60 | out_inputs[key] = [] 61 | 62 | for key in _targets.keys(): 63 | out_targets[key] = [] 64 | 65 | for key in _meta_info.keys(): 66 | out_meta_info[key] = [] 67 | 68 | for data in data_list: 69 | inputs, targets, meta_info = data 70 | for key, val in inputs.items(): 71 | out_inputs[key].append(val) 72 | 73 | for key, val in targets.items(): 74 | out_targets[key].append(val) 75 | 76 | 
for key, val in meta_info.items(): 77 | out_meta_info[key].append(val) 78 | 79 | for key in _inputs.keys(): 80 | out_inputs[key] = torch.cat(out_inputs[key], dim=0) 81 | 82 | for key in _targets.keys(): 83 | out_targets[key] = torch.cat(out_targets[key], dim=0) 84 | 85 | for key in _meta_info.keys(): 86 | if key not in ["imgname", "query_names"]: 87 | out_meta_info[key] = torch.cat(out_meta_info[key], dim=0) 88 | else: 89 | out_meta_info[key] = sum(out_meta_info[key], []) 90 | 91 | return out_inputs, out_targets, out_meta_info 92 | 93 | 94 | def fetch_dataloader(args, mode, seq=None): 95 | if mode == "train": 96 | reset_all_seeds(args.seed) 97 | dataset = fetch_dataset_devel(args, is_train=True) 98 | if type(dataset) == ArcticDataset: 99 | collate_fn = None 100 | else: 101 | collate_fn = collate_custom_fn 102 | return DataLoader( 103 | dataset=dataset, 104 | batch_size=args.batch_size, 105 | num_workers=args.num_workers, 106 | pin_memory=args.pin_memory, 107 | shuffle=args.shuffle_train, 108 | collate_fn=collate_fn, 109 | ) 110 | 111 | elif mode == "val" or mode == "eval": 112 | if "submit_" in args.extraction_mode: 113 | dataset = fetch_dataset_eval(args, seq=seq) 114 | else: 115 | dataset = fetch_dataset_devel(args, is_train=False, seq=seq) 116 | if type(dataset) in [ArcticDataset, ArcticDatasetEval]: 117 | collate_fn = None 118 | else: 119 | collate_fn = collate_custom_fn 120 | return DataLoader( 121 | dataset=dataset, 122 | batch_size=args.test_batch_size, 123 | shuffle=False, 124 | num_workers=args.num_workers, 125 | collate_fn=collate_fn, 126 | ) 127 | else: 128 | assert False 129 | 130 | 131 | def fetch_model(args): 132 | if args.method in ["arctic_sf"]: 133 | from src.models.arctic_sf.wrapper import ArcticSFWrapper as Wrapper 134 | elif args.method in ["arctic_lstm"]: 135 | from src.models.arctic_lstm.wrapper import ArcticLSTMWrapper as Wrapper 136 | elif args.method in ["field_sf"]: 137 | from src.models.field_sf.wrapper import FieldSFWrapper as Wrapper 138 | elif args.method in ["field_lstm"]: 139 | from src.models.field_lstm.wrapper import FieldLSTMWrapper as Wrapper 140 | else: 141 | assert False, f"Invalid method ({args.method})" 142 | model = Wrapper(args) 143 | return model 144 | -------------------------------------------------------------------------------- /src/mesh_loaders/arctic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | import common.viewer as viewer_utils 6 | from common.mesh import Mesh 7 | from common.viewer import ViewerData 8 | 9 | 10 | def construct_hand_meshes(cam_data, layers, view_idx, distort): 11 | if view_idx == 0 and distort: 12 | view_idx = 9 13 | v3d_r = cam_data["verts.right"][:, view_idx] 14 | v3d_l = cam_data["verts.left"][:, view_idx] 15 | 16 | right = { 17 | "v3d": v3d_r, 18 | "f3d": layers["right"].faces, 19 | "vc": None, 20 | "name": "right", 21 | "color": "white", 22 | } 23 | left = { 24 | "v3d": v3d_l, 25 | "f3d": layers["left"].faces, 26 | "vc": None, 27 | "name": "left", 28 | "color": "white", 29 | } 30 | return right, left 31 | 32 | 33 | def construct_object_meshes(cam_data, obj_name, layers, view_idx, distort): 34 | if view_idx == 0 and distort: 35 | view_idx = 9 36 | v3d_o = cam_data["verts.object"][:, view_idx] 37 | f3d_o = Mesh( 38 | filename=f"./data/arctic_data/data/meta/object_vtemplates/{obj_name}/mesh.obj" 39 | ).faces 40 | 41 | obj = { 42 | "v3d": v3d_o, 43 | "f3d": f3d_o, 44 | "vc": None, 45 | "name": "object", 46 | 
"color": "light-blue", 47 | } 48 | return obj 49 | 50 | 51 | def construct_smplx_meshes(cam_data, layers, view_idx, distort): 52 | assert not distort, "Distortion rendering not supported for SMPL-X" 53 | # We use the following algorithm to render meshes with distortion effects: 54 | # VR Distortion Correction Using Vertex Displacement 55 | # https://stackoverflow.com/questions/44489686/camera-lens-distortion-in-opengl 56 | # However, this method creates artifacts when vertices are too close to the camera. 57 | 58 | if view_idx == 0 and distort: 59 | view_idx = 9 60 | 61 | v3d_s = cam_data["verts.smplx"][:, view_idx] 62 | 63 | smplx_mesh = { 64 | "v3d": v3d_s, 65 | "f3d": layers["smplx"].faces, 66 | "vc": None, 67 | "name": "smplx", 68 | "color": "rice", 69 | } 70 | 71 | return smplx_mesh 72 | 73 | 74 | def construct_meshes( 75 | seq_p, 76 | layers, 77 | use_mano, 78 | use_object, 79 | use_smplx, 80 | no_image, 81 | use_distort, 82 | view_idx, 83 | subject_meta, 84 | ): 85 | # load 86 | data = np.load(seq_p, allow_pickle=True).item() 87 | cam_data = data["cam_coord"] 88 | data_params = data["params"] 89 | # unpack 90 | subject = seq_p.split("/")[-2] 91 | seq_name = seq_p.split("/")[-1].split(".")[0] 92 | obj_name = seq_name.split("_")[0] 93 | 94 | num_frames = cam_data["verts.right"].shape[0] 95 | 96 | # camera intrinsics 97 | if view_idx == 0: 98 | K = torch.FloatTensor(data_params["K_ego"][0].copy()) 99 | else: 100 | K = torch.FloatTensor( 101 | np.array(subject_meta[subject]["intris_mat"][view_idx - 1]) 102 | ) 103 | 104 | # image names 105 | vidx = np.arange(num_frames) 106 | image_idx = vidx + subject_meta[subject]["ioi_offset"] 107 | imgnames = [ 108 | f"./data/arctic_data/data/images/{subject}/{seq_name}/{view_idx}/{idx:05d}.jpg" 109 | for idx in image_idx 110 | ] 111 | 112 | # construct meshes 113 | vis_dict = {} 114 | if use_mano: 115 | right, left = construct_hand_meshes(cam_data, layers, view_idx, use_distort) 116 | vis_dict["right"] = right 117 | vis_dict["left"] = left 118 | if use_smplx: 119 | smplx_mesh = construct_smplx_meshes(cam_data, layers, view_idx, use_distort) 120 | vis_dict["smplx"] = smplx_mesh 121 | if use_object: 122 | obj = construct_object_meshes(cam_data, obj_name, layers, view_idx, use_distort) 123 | vis_dict["object"] = obj 124 | 125 | meshes = viewer_utils.construct_viewer_meshes( 126 | vis_dict, draw_edges=False, flat_shading=False 127 | ) 128 | 129 | num_frames = len(imgnames) 130 | Rt = np.zeros((num_frames, 3, 4)) 131 | Rt[:, :3, :3] = np.eye(3) 132 | Rt[:, 1:3, :3] *= -1.0 133 | 134 | im = Image.open(imgnames[0]) 135 | cols, rows = im.size 136 | if no_image: 137 | imgnames = None 138 | 139 | data = ViewerData(Rt, K, cols, rows, imgnames) 140 | return meshes, data 141 | -------------------------------------------------------------------------------- /src/mesh_loaders/field.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | 3 | import matplotlib 4 | import numpy as np 5 | import torch 6 | 7 | import common.viewer as viewer_utils 8 | from common.body_models import build_layers, seal_mano_mesh 9 | from common.mesh import Mesh 10 | from common.xdict import xdict 11 | from src.extraction.interface import prepare_data 12 | from src.extraction.keys.vis_field import KEYS as keys 13 | 14 | 15 | def dist2vc(dist_r, dist_l, dist_o, ccmap): 16 | vc_r, vc_l, vc_o = viewer_utils.dist2vc(dist_r, dist_l, dist_o, ccmap) 17 | 18 | vc_r_pad = np.zeros((vc_r.shape[0], vc_r.shape[1] + 1, 4)) 19 | vc_l_pad = 
np.zeros((vc_l.shape[0], vc_l.shape[1] + 1, 4)) 20 | 21 | # sealed vertex to pre-defined color 22 | vc_r_pad[:, -1, 0] = 0.4 23 | vc_l_pad[:, -1, 0] = 0.4 24 | vc_r_pad[:, -1, 1] = 0.2 25 | vc_l_pad[:, -1, 1] = 0.2 26 | vc_r_pad[:, -1, 2] = 0.3 27 | vc_l_pad[:, -1, 2] = 0.3 28 | vc_r_pad[:, -1, 3] = 1.0 29 | vc_l_pad[:, -1, 3] = 1.0 30 | vc_r_pad[:, :-1, :] = vc_r 31 | vc_l_pad[:, :-1, :] = vc_l 32 | 33 | vc_r = vc_r_pad 34 | vc_l = vc_l_pad 35 | return vc_r, vc_l, vc_o 36 | 37 | 38 | def construct_meshes(exp_folder, seq_name, flag, mode, side_angle=None, zoom_out=0.5): 39 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 40 | layers = build_layers(device) 41 | 42 | exp_key = exp_folder.split("/")[1] 43 | # load data 44 | data = prepare_data( 45 | seq_name, 46 | exp_key, 47 | keys, 48 | layers, 49 | device, 50 | task="field", 51 | eval_p=op.join(exp_folder, "eval"), 52 | ) 53 | 54 | # load object faces 55 | obj_name = seq_name.split("_")[1] 56 | f3d_o = Mesh( 57 | filename=f"./data/arctic_data/data/meta/object_vtemplates/{obj_name}/mesh.obj" 58 | ).f 59 | 60 | # only show predicted dist < 0.1 61 | num_frames = data["targets.dist.or"].shape[0] 62 | num_verts = data["targets.dist.or"].shape[1] 63 | data["pred.dist.or"][:num_frames, :num_verts][data["targets.dist.or"] == 0.1] = 0.1 64 | data["pred.dist.ol"][:num_frames, :num_verts][data["targets.dist.ol"] == 0.1] = 0.1 65 | data["pred.dist.ro"][:num_frames, :num_verts][data["targets.dist.ro"] == 0.1] = 0.1 66 | data["pred.dist.lo"][:num_frames, :num_verts][data["targets.dist.lo"] == 0.1] = 0.1 67 | 68 | # center verts 69 | v3d_r = data[f"targets.mano.v3d.cam.r"] 70 | v3d_l = data[f"targets.mano.v3d.cam.l"] 71 | v3d_o = data[f"targets.object.v.cam"] 72 | cam_t = data[f"targets.object.cam_t"] 73 | v3d_r -= cam_t[:, None, :] 74 | v3d_l -= cam_t[:, None, :] 75 | v3d_o -= cam_t[:, None, :] 76 | 77 | # seal MANO meshes 78 | f3d_r = torch.LongTensor(layers["right"].faces.astype(np.int64)) 79 | f3d_l = torch.LongTensor(layers["left"].faces.astype(np.int64)) 80 | v3d_r, f3d_r = seal_mano_mesh(v3d_r, f3d_r, True) 81 | v3d_l, f3d_l = seal_mano_mesh(v3d_l, f3d_l, False) 82 | 83 | if "_l" in mode: 84 | mydist_o = data[f"{flag}.dist.ol"] 85 | else: 86 | mydist_o = data[f"{flag}.dist.or"] 87 | 88 | ccmap = matplotlib.cm.get_cmap("plasma") 89 | vc_r, vc_l, vc_o = dist2vc( 90 | data[f"{flag}.dist.ro"], data[f"{flag}.dist.lo"], mydist_o, ccmap 91 | ) 92 | 93 | right = { 94 | "v3d": v3d_r.numpy(), 95 | "f3d": f3d_r.numpy(), 96 | "vc": vc_r, 97 | "name": "right", 98 | "color": "none", 99 | } 100 | left = { 101 | "v3d": v3d_l.numpy(), 102 | "f3d": f3d_l.numpy(), 103 | "vc": vc_l, 104 | "name": "left", 105 | "color": "none", 106 | } 107 | obj = { 108 | "v3d": v3d_o.numpy(), 109 | "f3d": f3d_o, 110 | "vc": vc_o, 111 | "name": "object", 112 | "color": "none", 113 | } 114 | meshes = viewer_utils.construct_viewer_meshes( 115 | {"right": right, "left": left, "object": obj}, 116 | draw_edges=False, 117 | flat_shading=True, 118 | ) 119 | data = xdict(data).to_np() 120 | 121 | # pred_field uses GT cam_t for vis 122 | data["pred.object.cam_t"] = data["targets.object.cam_t"] 123 | return meshes, data 124 | -------------------------------------------------------------------------------- /src/mesh_loaders/pose.py: -------------------------------------------------------------------------------- 1 | import os.path as op 2 | 3 | import numpy as np 4 | import torch 5 | import trimesh 6 | 7 | import common.viewer as viewer_utils 8 | from common.body_models import 
build_layers, seal_mano_mesh 9 | from common.xdict import xdict 10 | from src.extraction.interface import prepare_data 11 | from src.extraction.keys.vis_pose import KEYS as keys 12 | 13 | 14 | def construct_meshes(exp_folder, seq_name, flag, side_angle=None, zoom_out=0.5): 15 | exp_key = exp_folder.split("/")[1] 16 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 17 | layers = build_layers(device) 18 | 19 | data = prepare_data( 20 | seq_name, 21 | exp_key, 22 | keys, 23 | layers, 24 | device, 25 | task="pose", 26 | eval_p=op.join(exp_folder, "eval"), 27 | ) 28 | 29 | # load object faces 30 | obj_name = seq_name.split("_")[1] 31 | f3d_o = trimesh.load( 32 | f"./data/arctic_data/data/meta/object_vtemplates/{obj_name}/mesh.obj", 33 | process=False, 34 | ).faces 35 | 36 | # center verts 37 | v3d_r = data[f"{flag}.mano.v3d.cam.r"] 38 | v3d_l = data[f"{flag}.mano.v3d.cam.l"] 39 | v3d_o = data[f"{flag}.object.v.cam"] 40 | cam_t = data[f"{flag}.object.cam_t"] 41 | v3d_r -= cam_t[:, None, :] 42 | v3d_l -= cam_t[:, None, :] 43 | v3d_o -= cam_t[:, None, :] 44 | 45 | # seal MANO mesh 46 | f3d_r = torch.LongTensor(layers["right"].faces.astype(np.int64)) 47 | f3d_l = torch.LongTensor(layers["left"].faces.astype(np.int64)) 48 | v3d_r, f3d_r = seal_mano_mesh(v3d_r, f3d_r, True) 49 | v3d_l, f3d_l = seal_mano_mesh(v3d_l, f3d_l, False) 50 | 51 | # AIT meshes 52 | hand_color = "white" 53 | object_color = "light-blue" 54 | right = { 55 | "v3d": v3d_r.numpy(), 56 | "f3d": f3d_r.numpy(), 57 | "vc": None, 58 | "name": "right", 59 | "color": hand_color, 60 | } 61 | left = { 62 | "v3d": v3d_l.numpy(), 63 | "f3d": f3d_l.numpy(), 64 | "vc": None, 65 | "name": "left", 66 | "color": hand_color, 67 | } 68 | obj = { 69 | "v3d": v3d_o.numpy(), 70 | "f3d": f3d_o, 71 | "vc": None, 72 | "name": "object", 73 | "color": object_color, 74 | } 75 | 76 | meshes = viewer_utils.construct_viewer_meshes( 77 | { 78 | "right": right, 79 | "left": left, 80 | "object": obj, 81 | }, 82 | draw_edges=False, 83 | flat_shading=True, 84 | ) 85 | data = xdict(data).to_np() 86 | return meshes, data 87 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/arctic_lstm/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import common.ld_utils as ld_utils 5 | import src.callbacks.process.process_generic as generic 6 | from common.xdict import xdict 7 | from src.nets.hand_heads.hand_hmr import HandHMR 8 | from src.nets.hand_heads.mano_head import MANOHead 9 | from src.nets.obj_heads.obj_head import ArtiHead 10 | from src.nets.obj_heads.obj_hmr import ObjectHMR 11 | 12 | 13 | class ArcticLSTM(nn.Module): 14 | def __init__(self, focal_length, img_res, args): 15 | super().__init__() 16 | self.args = args 17 | feat_dim = 2048 18 | self.head_r = HandHMR(feat_dim, is_rhand=True, n_iter=3) 19 | self.head_l = HandHMR(feat_dim, is_rhand=False, n_iter=3) 20 | 21 | self.head_o = ObjectHMR(feat_dim, n_iter=3) 22 | 23 | self.mano_r = MANOHead( 24 | is_rhand=True, focal_length=focal_length, img_res=img_res 25 | ) 26 | 27 | self.mano_l = MANOHead( 28 | is_rhand=False, focal_length=focal_length, img_res=img_res 29 | ) 30 | 31 | 
self.arti_head = ArtiHead(focal_length=focal_length, img_res=img_res) 32 | self.mode = "train" 33 | self.img_res = img_res 34 | self.focal_length = focal_length 35 | self.feat_dim = feat_dim 36 | self.lstm = nn.LSTM( 37 | input_size=2048, 38 | hidden_size=1024, 39 | num_layers=2, 40 | bidirectional=True, 41 | batch_first=True, 42 | ) 43 | 44 | def _fetch_img_feat(self, inputs): 45 | feat_vec = inputs["img_feat"] 46 | return feat_vec 47 | 48 | def forward(self, inputs, meta_info): 49 | window_size = self.args.window_size 50 | query_names = meta_info["query_names"] 51 | K = meta_info["intrinsics"] 52 | device = K.device 53 | feat_vec = self._fetch_img_feat(inputs) 54 | feat_vec = feat_vec.view(-1, window_size, self.feat_dim) 55 | batch_size = feat_vec.shape[0] 56 | 57 | # bidirectional 58 | h0 = torch.randn(2 * 2, batch_size, self.feat_dim // 2, device=device) 59 | c0 = torch.randn(2 * 2, batch_size, self.feat_dim // 2, device=device) 60 | feat_vec, (hn, cn) = self.lstm(feat_vec, (h0, c0)) # batch, seq, 2*dim 61 | feat_vec = feat_vec.reshape(batch_size * window_size, self.feat_dim) 62 | 63 | hmr_output_r = self.head_r(feat_vec, use_pool=False) 64 | hmr_output_l = self.head_l(feat_vec, use_pool=False) 65 | hmr_output_o = self.head_o(feat_vec, use_pool=False) 66 | 67 | # weak perspective 68 | root_r = hmr_output_r["cam_t.wp"] 69 | root_l = hmr_output_l["cam_t.wp"] 70 | root_o = hmr_output_o["cam_t.wp"] 71 | 72 | mano_output_r = self.mano_r( 73 | rotmat=hmr_output_r["pose"], 74 | shape=hmr_output_r["shape"], 75 | K=K, 76 | cam=root_r, 77 | ) 78 | 79 | mano_output_l = self.mano_l( 80 | rotmat=hmr_output_l["pose"], 81 | shape=hmr_output_l["shape"], 82 | K=K, 83 | cam=root_l, 84 | ) 85 | 86 | # fwd mesh when in val or vis 87 | arti_output = self.arti_head( 88 | rot=hmr_output_o["rot"], 89 | angle=hmr_output_o["radian"], 90 | query_names=query_names, 91 | cam=root_o, 92 | K=K, 93 | ) 94 | 95 | root_r_init = hmr_output_r["cam_t.wp.init"] 96 | root_l_init = hmr_output_l["cam_t.wp.init"] 97 | root_o_init = hmr_output_o["cam_t.wp.init"] 98 | mano_output_r["cam_t.wp.init.r"] = root_r_init 99 | mano_output_l["cam_t.wp.init.l"] = root_l_init 100 | arti_output["cam_t.wp.init"] = root_o_init 101 | 102 | mano_output_r = ld_utils.prefix_dict(mano_output_r, "mano.") 103 | mano_output_l = ld_utils.prefix_dict(mano_output_l, "mano.") 104 | arti_output = ld_utils.prefix_dict(arti_output, "object.") 105 | output = xdict() 106 | output.merge(mano_output_r) 107 | output.merge(mano_output_l) 108 | output.merge(arti_output) 109 | output = generic.prepare_interfield(output, self.args.max_dist) 110 | return output 111 | -------------------------------------------------------------------------------- /src/models/arctic_lstm/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loguru import logger 3 | 4 | import common.torch_utils as torch_utils 5 | from common.xdict import xdict 6 | from src.callbacks.loss.loss_arctic_lstm import compute_loss 7 | from src.callbacks.process.process_arctic import process_data 8 | from src.callbacks.vis.visualize_arctic import visualize_all 9 | from src.models.arctic_lstm.model import ArcticLSTM 10 | from src.models.generic.wrapper import GenericWrapper 11 | 12 | 13 | class ArcticLSTMWrapper(GenericWrapper): 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.model = ArcticLSTM( 17 | focal_length=args.focal_length, 18 | img_res=args.img_res, 19 | args=args, 20 | ) 21 | self.process_fn = process_data 22 | 
self.loss_fn = compute_loss 23 | self.metric_dict = [ 24 | "cdev", 25 | "mrrpe", 26 | "mpjpe.ra", 27 | "aae", 28 | "success_rate", 29 | ] 30 | 31 | self.vis_fns = [visualize_all] 32 | self.num_vis_train = 0 33 | self.num_vis_val = 1 34 | 35 | def set_training_flags(self): 36 | if not self.started_training: 37 | sd_p = f"./logs/{self.args.img_feat_version}/checkpoints/last.ckpt" 38 | sd = torch.load(sd_p)["state_dict"] 39 | msd = xdict(sd).search("model.head") 40 | 41 | wd = msd.search("weight") 42 | bd = msd.search("bias") 43 | wd.merge(bd) 44 | self.load_state_dict(wd, strict=False) 45 | torch_utils.toggle_parameters(self, True) 46 | logger.info(f"Loaded: {sd_p}") 47 | self.started_training = True 48 | 49 | def inference(self, inputs, meta_info): 50 | return super().inference_pose(inputs, meta_info) 51 | -------------------------------------------------------------------------------- /src/models/arctic_sf/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import common.ld_utils as ld_utils 4 | from common.xdict import xdict 5 | from src.nets.backbone.utils import get_backbone_info 6 | from src.nets.hand_heads.hand_hmr import HandHMR 7 | from src.nets.hand_heads.mano_head import MANOHead 8 | from src.nets.obj_heads.obj_head import ArtiHead 9 | from src.nets.obj_heads.obj_hmr import ObjectHMR 10 | 11 | 12 | class ArcticSF(nn.Module): 13 | def __init__(self, backbone, focal_length, img_res, args): 14 | super(ArcticSF, self).__init__() 15 | self.args = args 16 | if backbone == "resnet50": 17 | from src.nets.backbone.resnet import resnet50 as resnet 18 | elif backbone == "resnet18": 19 | from src.nets.backbone.resnet import resnet18 as resnet 20 | else: 21 | assert False 22 | self.backbone = resnet(pretrained=True) 23 | feat_dim = get_backbone_info(backbone)["n_output_channels"] 24 | self.head_r = HandHMR(feat_dim, is_rhand=True, n_iter=3) 25 | self.head_l = HandHMR(feat_dim, is_rhand=False, n_iter=3) 26 | 27 | self.head_o = ObjectHMR(feat_dim, n_iter=3) 28 | 29 | self.mano_r = MANOHead( 30 | is_rhand=True, focal_length=focal_length, img_res=img_res 31 | ) 32 | 33 | self.mano_l = MANOHead( 34 | is_rhand=False, focal_length=focal_length, img_res=img_res 35 | ) 36 | 37 | self.arti_head = ArtiHead(focal_length=focal_length, img_res=img_res) 38 | self.mode = "train" 39 | self.img_res = img_res 40 | self.focal_length = focal_length 41 | 42 | def forward(self, inputs, meta_info): 43 | images = inputs["img"] 44 | query_names = meta_info["query_names"] 45 | K = meta_info["intrinsics"] 46 | features = self.backbone(images) 47 | feat_vec = features.view(features.shape[0], features.shape[1], -1).sum(dim=2) 48 | 49 | hmr_output_r = self.head_r(features) 50 | hmr_output_l = self.head_l(features) 51 | hmr_output_o = self.head_o(features) 52 | 53 | # weak perspective 54 | root_r = hmr_output_r["cam_t.wp"] 55 | root_l = hmr_output_l["cam_t.wp"] 56 | root_o = hmr_output_o["cam_t.wp"] 57 | 58 | mano_output_r = self.mano_r( 59 | rotmat=hmr_output_r["pose"], 60 | shape=hmr_output_r["shape"], 61 | K=K, 62 | cam=root_r, 63 | ) 64 | 65 | mano_output_l = self.mano_l( 66 | rotmat=hmr_output_l["pose"], 67 | shape=hmr_output_l["shape"], 68 | K=K, 69 | cam=root_l, 70 | ) 71 | 72 | # fwd mesh when in val or vis 73 | arti_output = self.arti_head( 74 | rot=hmr_output_o["rot"], 75 | angle=hmr_output_o["radian"], 76 | query_names=query_names, 77 | cam=root_o, 78 | K=K, 79 | ) 80 | 81 | root_r_init = hmr_output_r["cam_t.wp.init"] 82 | root_l_init = 
hmr_output_l["cam_t.wp.init"] 83 | root_o_init = hmr_output_o["cam_t.wp.init"] 84 | mano_output_r["cam_t.wp.init.r"] = root_r_init 85 | mano_output_l["cam_t.wp.init.l"] = root_l_init 86 | arti_output["cam_t.wp.init"] = root_o_init 87 | 88 | mano_output_r = ld_utils.prefix_dict(mano_output_r, "mano.") 89 | mano_output_l = ld_utils.prefix_dict(mano_output_l, "mano.") 90 | arti_output = ld_utils.prefix_dict(arti_output, "object.") 91 | output = xdict() 92 | output.merge(mano_output_r) 93 | output.merge(mano_output_l) 94 | output.merge(arti_output) 95 | output["feat_vec"] = feat_vec.cpu().detach() 96 | return output 97 | -------------------------------------------------------------------------------- /src/models/arctic_sf/wrapper.py: -------------------------------------------------------------------------------- 1 | from src.callbacks.loss.loss_arctic_sf import compute_loss 2 | from src.callbacks.process.process_arctic import process_data 3 | from src.callbacks.vis.visualize_arctic import visualize_all 4 | from src.models.arctic_sf.model import ArcticSF 5 | from src.models.generic.wrapper import GenericWrapper 6 | 7 | 8 | class ArcticSFWrapper(GenericWrapper): 9 | def __init__(self, args): 10 | super().__init__(args) 11 | self.model = ArcticSF( 12 | backbone="resnet50", 13 | focal_length=args.focal_length, 14 | img_res=args.img_res, 15 | args=args, 16 | ) 17 | self.process_fn = process_data 18 | self.loss_fn = compute_loss 19 | self.metric_dict = [ 20 | "cdev", 21 | "mrrpe", 22 | "mpjpe.ra", 23 | "aae", 24 | "success_rate", 25 | ] 26 | 27 | self.vis_fns = [visualize_all] 28 | 29 | self.num_vis_train = 1 30 | self.num_vis_val = 1 31 | 32 | def inference(self, inputs, meta_info): 33 | return super().inference_pose(inputs, meta_info) 34 | -------------------------------------------------------------------------------- /src/models/field_lstm/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from common.xdict import xdict 5 | from src.models.field_sf.model import RegressHead, Upsampler 6 | from src.nets.backbone.utils import get_backbone_info 7 | from src.nets.obj_heads.obj_head import ArtiHead 8 | from src.nets.pointnet import PointNetfeat 9 | 10 | 11 | class FieldLSTM(nn.Module): 12 | def __init__(self, backbone, focal_length, img_res, window_size): 13 | super().__init__() 14 | assert backbone in ["resnet18", "resnet50"] 15 | feat_dim = get_backbone_info(backbone)["n_output_channels"] 16 | self.arti_head = ArtiHead(focal_length=focal_length, img_res=img_res) 17 | 18 | img_down_dim = 512 19 | img_mid_dim = 512 20 | pt_out_dim = 512 21 | self.down = nn.Sequential( 22 | nn.Linear(feat_dim, img_mid_dim), 23 | nn.ReLU(), 24 | nn.Linear(img_mid_dim, img_down_dim), 25 | nn.ReLU(), 26 | ) # downsize image features 27 | 28 | pt_shallow_dim = 512 29 | pt_mid_dim = 512 30 | self.point_backbone = PointNetfeat( 31 | input_dim=3 + img_down_dim, 32 | shallow_dim=pt_shallow_dim, 33 | mid_dim=pt_mid_dim, 34 | out_dim=pt_out_dim, 35 | ) 36 | pts_dim = pt_shallow_dim + pt_out_dim 37 | self.dist_head_or = RegressHead(pts_dim) 38 | self.dist_head_ol = RegressHead(pts_dim) 39 | self.dist_head_ro = RegressHead(pts_dim) 40 | self.dist_head_lo = RegressHead(pts_dim) 41 | self.avgpool = nn.AdaptiveAvgPool2d(1) 42 | 43 | self.num_v_sub = 195 # mano subsampled 44 | self.num_v_o_sub = 300 * 2 # object subsampled 45 | self.num_v_o = 4000 # object 46 | self.upsampling_r = Upsampler(self.num_v_sub, 778) 47 | self.upsampling_l = 
Upsampler(self.num_v_sub, 778) 48 | self.upsampling_o = Upsampler(self.num_v_o_sub, self.num_v_o) 49 | self.lstm = nn.LSTM( 50 | input_size=2048, 51 | hidden_size=1024, 52 | num_layers=2, 53 | bidirectional=True, 54 | batch_first=True, 55 | ) 56 | 57 | self.feat_dim = feat_dim 58 | self.window_size = window_size 59 | 60 | def forward(self, inputs, meta_info): 61 | window_size = self.window_size 62 | device = meta_info["v0.r"].device 63 | 64 | feat_vec = inputs["img_feat"].view(-1, window_size, self.feat_dim) 65 | batch_size = feat_vec.shape[0] 66 | 67 | points_r = meta_info["v0.r"].permute(0, 2, 1)[:, :, 21:] 68 | points_l = meta_info["v0.l"].permute(0, 2, 1)[:, :, 21:] 69 | points_o = meta_info["v0.o"].permute(0, 2, 1) 70 | points_all = torch.cat((points_r, points_l, points_o), dim=2) 71 | 72 | # bidirectional 73 | h0 = torch.randn(2 * 2, batch_size, self.feat_dim // 2, device=device) 74 | c0 = torch.randn(2 * 2, batch_size, self.feat_dim // 2, device=device) 75 | feat_vec, (hn, cn) = self.lstm(feat_vec, (h0, c0)) # batch, seq, 2*dim 76 | feat_vec = feat_vec.reshape(batch_size * window_size, self.feat_dim) 77 | 78 | img_feat = self.down(feat_vec) 79 | num_mano_pts = points_r.shape[2] 80 | num_object_pts = points_o.shape[2] 81 | 82 | img_feat_all = img_feat[:, :, None].repeat( 83 | 1, 1, num_mano_pts * 2 + num_object_pts 84 | ) 85 | pts_all_feat = self.point_backbone( 86 | torch.cat((points_all, img_feat_all), dim=1) 87 | )[0] 88 | pts_r_feat, pts_l_feat, pts_o_feat = torch.split( 89 | pts_all_feat, [num_mano_pts, num_mano_pts, num_object_pts], dim=2 90 | ) 91 | 92 | dist_ro = self.dist_head_ro(pts_r_feat) 93 | dist_lo = self.dist_head_lo(pts_l_feat) 94 | dist_or = self.dist_head_or(pts_o_feat) 95 | dist_ol = self.dist_head_ol(pts_o_feat) 96 | 97 | dist_ro = self.upsampling_r(dist_ro[:, :, None])[:, :, 0] 98 | dist_lo = self.upsampling_l(dist_lo[:, :, None])[:, :, 0] 99 | dist_or = self.upsampling_o(dist_or[:, :, None])[:, :, 0] 100 | dist_ol = self.upsampling_o(dist_ol[:, :, None])[:, :, 0] 101 | 102 | out = xdict() 103 | out["dist.ro"] = dist_ro 104 | out["dist.lo"] = dist_lo 105 | out["dist.or"] = dist_or 106 | out["dist.ol"] = dist_ol 107 | return out 108 | -------------------------------------------------------------------------------- /src/models/field_lstm/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loguru import logger 3 | 4 | import common.torch_utils as torch_utils 5 | from common.xdict import xdict 6 | from src.callbacks.loss.loss_field import compute_loss 7 | from src.callbacks.process.process_field import process_data 8 | from src.callbacks.vis.visualize_field import visualize_all 9 | from src.models.field_lstm.model import FieldLSTM 10 | from src.models.generic.wrapper import GenericWrapper 11 | 12 | 13 | class FieldLSTMWrapper(GenericWrapper): 14 | def __init__(self, args): 15 | super().__init__(args) 16 | self.model = FieldLSTM( 17 | "resnet50", 18 | args.focal_length, 19 | args.img_res, 20 | args.window_size, 21 | ) 22 | self.process_fn = process_data 23 | self.loss_fn = compute_loss 24 | self.metric_dict = ["avg_err_field"] 25 | 26 | self.vis_fns = [visualize_all] 27 | self.num_vis_train = 0 28 | self.num_vis_val = 1 29 | 30 | def set_training_flags(self): 31 | if not self.started_training: 32 | sd_p = f"./logs/{self.args.img_feat_version}/checkpoints/last.ckpt" 33 | sd = torch.load(sd_p)["state_dict"] 34 | msd = xdict(sd).search("model.").rm("model.backbone") 35 | 36 | wd = msd.search("weight") 37 
| bd = msd.search("bias") 38 | wd.merge(bd) 39 | self.load_state_dict(wd, strict=False) 40 | torch_utils.toggle_parameters(self, True) 41 | logger.info(f"Loaded: {sd_p}") 42 | self.started_training = True 43 | 44 | def inference(self, inputs, meta_info): 45 | return super().inference_field(inputs, meta_info) 46 | -------------------------------------------------------------------------------- /src/models/field_sf/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from common.xdict import xdict 5 | from src.nets.backbone.utils import get_backbone_info 6 | from src.nets.obj_heads.obj_head import ArtiHead 7 | from src.nets.pointnet import PointNetfeat 8 | 9 | 10 | class Upsampler(nn.Module): 11 | def __init__(self, in_dim, out_dim): 12 | super().__init__() 13 | self.upsampling = torch.nn.Linear(in_dim, out_dim) 14 | 15 | def forward(self, pred_vertices_sub): 16 | temp_transpose = pred_vertices_sub.transpose(1, 2) 17 | pred_vertices = self.upsampling(temp_transpose) 18 | pred_vertices = pred_vertices.transpose(1, 2) 19 | return pred_vertices 20 | 21 | 22 | class RegressHead(nn.Module): 23 | def __init__(self, input_dim): 24 | super().__init__() 25 | 26 | self.network = nn.Sequential( 27 | nn.Conv1d(input_dim, 512, 1), 28 | nn.BatchNorm1d(512), 29 | nn.ReLU(), 30 | nn.Conv1d(512, 128, 1), 31 | nn.BatchNorm1d(128), 32 | nn.ReLU(), 33 | nn.Conv1d(128, 1, 1), 34 | ) 35 | 36 | def forward(self, x): 37 | dist = self.network(x).permute(0, 2, 1)[:, :, 0] 38 | return dist 39 | 40 | 41 | class FieldSF(nn.Module): 42 | def __init__(self, backbone, focal_length, img_res): 43 | super().__init__() 44 | if backbone == "resnet18": 45 | from src.nets.backbone.resnet import resnet18 as resnet 46 | elif backbone == "resnet50": 47 | from src.nets.backbone.resnet import resnet50 as resnet 48 | else: 49 | assert False 50 | self.backbone = resnet(pretrained=True) 51 | feat_dim = get_backbone_info(backbone)["n_output_channels"] 52 | self.arti_head = ArtiHead(focal_length=focal_length, img_res=img_res) 53 | 54 | img_down_dim = 512 55 | img_mid_dim = 512 56 | pt_out_dim = 512 57 | self.down = nn.Sequential( 58 | nn.Linear(feat_dim, img_mid_dim), 59 | nn.ReLU(), 60 | nn.Linear(img_mid_dim, img_down_dim), 61 | nn.ReLU(), 62 | ) # downsize image features 63 | 64 | pt_shallow_dim = 512 65 | pt_mid_dim = 512 66 | self.point_backbone = PointNetfeat( 67 | input_dim=3 + img_down_dim, 68 | shallow_dim=pt_shallow_dim, 69 | mid_dim=pt_mid_dim, 70 | out_dim=pt_out_dim, 71 | ) 72 | pts_dim = pt_shallow_dim + pt_out_dim 73 | self.dist_head_or = RegressHead(pts_dim) 74 | self.dist_head_ol = RegressHead(pts_dim) 75 | self.dist_head_ro = RegressHead(pts_dim) 76 | self.dist_head_lo = RegressHead(pts_dim) 77 | self.avgpool = nn.AdaptiveAvgPool2d(1) 78 | 79 | self.num_v_sub = 195 # mano subsampled 80 | self.num_v_o_sub = 300 * 2 # object subsampled 81 | self.num_v_o = 4000 # object 82 | self.upsampling_r = Upsampler(self.num_v_sub, 778) 83 | self.upsampling_l = Upsampler(self.num_v_sub, 778) 84 | self.upsampling_o = Upsampler(self.num_v_o_sub, self.num_v_o) 85 | 86 | def _decode(self, pts_all_feat): 87 | pts_all_feat = self.point_backbone(pts_all_feat)[0] 88 | pts_r_feat, pts_l_feat, pts_o_feat = torch.split( 89 | pts_all_feat, 90 | [self.num_mano_pts, self.num_mano_pts, self.num_object_pts], 91 | dim=2, 92 | ) 93 | 94 | dist_ro = self.dist_head_ro(pts_r_feat) 95 | dist_lo = self.dist_head_lo(pts_l_feat) 96 | dist_or = self.dist_head_or(pts_o_feat) 
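# ro/lo are per-hand-vertex distances to the object and or/ol are
# per-object-vertex distances to the right/left hand; all four are regressed
# at the subsampled template resolution and upsampled to the full meshes
# (778 MANO / 4000 object vertices) in forward() below.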
97 | dist_ol = self.dist_head_ol(pts_o_feat) 98 | return dist_ro, dist_lo, dist_or, dist_ol 99 | 100 | def forward(self, inputs, meta_info): 101 | images = inputs["img"] 102 | points_r = meta_info["v0.r"].permute(0, 2, 1)[:, :, 21:] 103 | points_l = meta_info["v0.l"].permute(0, 2, 1)[:, :, 21:] 104 | points_o = meta_info["v0.o"].permute(0, 2, 1) 105 | points_all = torch.cat((points_r, points_l, points_o), dim=2) 106 | 107 | img_feat = self.backbone(images) 108 | img_feat = self.avgpool(img_feat).view(img_feat.shape[0], -1) 109 | pred_vec = img_feat.clone() 110 | img_feat = self.down(img_feat) 111 | 112 | self.num_mano_pts = points_r.shape[2] 113 | self.num_object_pts = points_o.shape[2] 114 | 115 | img_feat_all = img_feat[:, :, None].repeat( 116 | 1, 1, self.num_mano_pts * 2 + self.num_object_pts 117 | ) 118 | 119 | pts_all_feat = torch.cat((points_all, img_feat_all), dim=1) 120 | dist_ro, dist_lo, dist_or, dist_ol = self._decode(pts_all_feat) 121 | dist_ro = self.upsampling_r(dist_ro[:, :, None])[:, :, 0] 122 | dist_lo = self.upsampling_l(dist_lo[:, :, None])[:, :, 0] 123 | dist_or = self.upsampling_o(dist_or[:, :, None])[:, :, 0] 124 | dist_ol = self.upsampling_o(dist_ol[:, :, None])[:, :, 0] 125 | 126 | out = xdict() 127 | out["dist.ro"] = dist_ro 128 | out["dist.lo"] = dist_lo 129 | out["dist.or"] = dist_or 130 | out["dist.ol"] = dist_ol 131 | out["feat_vec"] = pred_vec 132 | return out 133 | -------------------------------------------------------------------------------- /src/models/field_sf/wrapper.py: -------------------------------------------------------------------------------- 1 | from src.callbacks.loss.loss_field import compute_loss 2 | from src.callbacks.process.process_field import process_data 3 | from src.callbacks.vis.visualize_field import visualize_all 4 | from src.models.field_sf.model import FieldSF 5 | from src.models.generic.wrapper import GenericWrapper 6 | 7 | 8 | class FieldSFWrapper(GenericWrapper): 9 | def __init__(self, args): 10 | super().__init__(args) 11 | self.model = FieldSF("resnet50", args.focal_length, args.img_res) 12 | self.process_fn = process_data 13 | self.loss_fn = compute_loss 14 | self.metric_dict = ["avg_err_field"] 15 | 16 | self.vis_fns = [visualize_all] 17 | self.num_vis_train = 1 18 | self.num_vis_val = 1 19 | 20 | def inference(self, inputs, meta_info): 21 | return super().inference_field(inputs, meta_info) 22 | -------------------------------------------------------------------------------- /src/nets/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zc-alexfan/arctic/e904c9695911d0e4a2025ebaeff1d91e5caf4d69/src/nets/backbone/__init__.py -------------------------------------------------------------------------------- /src/nets/backbone/utils.py: -------------------------------------------------------------------------------- 1 | def get_backbone_info(backbone): 2 | info = { 3 | "resnet18": {"n_output_channels": 512, "downsample_rate": 4}, 4 | "resnet34": {"n_output_channels": 512, "downsample_rate": 4}, 5 | "resnet50": {"n_output_channels": 2048, "downsample_rate": 4}, 6 | "resnet50_adf_dropout": {"n_output_channels": 2048, "downsample_rate": 4}, 7 | "resnet50_dropout": {"n_output_channels": 2048, "downsample_rate": 4}, 8 | "resnet101": {"n_output_channels": 2048, "downsample_rate": 4}, 9 | "resnet152": {"n_output_channels": 2048, "downsample_rate": 4}, 10 | "resnext50_32x4d": {"n_output_channels": 2048, "downsample_rate": 4}, 11 | "resnext101_32x8d": 
{"n_output_channels": 2048, "downsample_rate": 4}, 12 | "wide_resnet50_2": {"n_output_channels": 2048, "downsample_rate": 4}, 13 | "wide_resnet101_2": {"n_output_channels": 2048, "downsample_rate": 4}, 14 | "mobilenet_v2": {"n_output_channels": 1280, "downsample_rate": 4}, 15 | "hrnet_w32": {"n_output_channels": 480, "downsample_rate": 4}, 16 | "hrnet_w48": {"n_output_channels": 720, "downsample_rate": 4}, 17 | # 'hrnet_w64': {'n_output_channels': 2048, 'downsample_rate': 4}, 18 | "dla34": {"n_output_channels": 512, "downsample_rate": 4}, 19 | } 20 | return info[backbone] 21 | -------------------------------------------------------------------------------- /src/nets/hand_heads/hand_hmr.py: -------------------------------------------------------------------------------- 1 | import pytorch3d.transforms.rotation_conversions as rot_conv 2 | import torch 3 | import torch.nn as nn 4 | 5 | from common.xdict import xdict 6 | from src.nets.hmr_layer import HMRLayer 7 | 8 | 9 | class HandHMR(nn.Module): 10 | def __init__(self, feat_dim, is_rhand, n_iter): 11 | super().__init__() 12 | self.is_rhand = is_rhand 13 | 14 | hand_specs = {"pose_6d": 6 * 16, "cam_t/wp": 3, "shape": 10} 15 | self.hmr_layer = HMRLayer(feat_dim, 1024, hand_specs) 16 | 17 | self.cam_init = nn.Sequential( 18 | nn.Linear(feat_dim, 512), 19 | nn.ReLU(), 20 | nn.Linear(512, 512), 21 | nn.ReLU(), 22 | nn.Linear(512, 3), 23 | ) 24 | 25 | self.hand_specs = hand_specs 26 | self.n_iter = n_iter 27 | self.avgpool = nn.AdaptiveAvgPool2d(1) 28 | 29 | def init_vector_dict(self, features): 30 | batch_size = features.shape[0] 31 | dev = features.device 32 | init_pose = ( 33 | rot_conv.matrix_to_rotation_6d( 34 | rot_conv.axis_angle_to_matrix(torch.zeros(16, 3)) 35 | ) 36 | .reshape(1, -1) 37 | .repeat(batch_size, 1) 38 | ) 39 | init_shape = torch.zeros(1, 10).repeat(batch_size, 1) 40 | init_transl = self.cam_init(features) 41 | 42 | out = {} 43 | out["pose_6d"] = init_pose 44 | out["shape"] = init_shape 45 | out["cam_t/wp"] = init_transl 46 | out = xdict(out).to(dev) 47 | return out 48 | 49 | def forward(self, features, use_pool=True): 50 | batch_size = features.shape[0] 51 | if use_pool: 52 | feat = self.avgpool(features) 53 | feat = feat.view(feat.size(0), -1) 54 | else: 55 | feat = features 56 | 57 | init_vdict = self.init_vector_dict(feat) 58 | init_cam_t = init_vdict["cam_t/wp"].clone() 59 | pred_vdict = self.hmr_layer(feat, init_vdict, self.n_iter) 60 | 61 | pred_rotmat = rot_conv.rotation_6d_to_matrix( 62 | pred_vdict["pose_6d"].reshape(-1, 6) 63 | ).view(batch_size, 16, 3, 3) 64 | 65 | pred_vdict["pose"] = pred_rotmat 66 | pred_vdict["cam_t.wp.init"] = init_cam_t 67 | pred_vdict = pred_vdict.replace_keys("/", ".") 68 | return pred_vdict 69 | -------------------------------------------------------------------------------- /src/nets/hand_heads/mano_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import common.camera as camera 4 | import common.data_utils as data_utils 5 | import common.rot as rot 6 | import common.transforms as tf 7 | from common.body_models import build_mano_aa 8 | from common.xdict import xdict 9 | 10 | 11 | class MANOHead(nn.Module): 12 | def __init__(self, is_rhand, focal_length, img_res): 13 | super(MANOHead, self).__init__() 14 | self.mano = build_mano_aa(is_rhand) 15 | self.add_module("mano", self.mano) 16 | self.focal_length = focal_length 17 | self.img_res = img_res 18 | self.is_rhand = is_rhand 19 | 20 | def forward(self, rotmat, 
shape, cam, K): 21 | """ 22 | :param rotmat: MANO rotation matrices (N, 16, 3, 3) 23 | :param shape: MANO shape parameters (betas) 24 | :param cam: weak perspective camera 25 | :param K: camera intrinsics used to project 2D keypoints 26 | :return: xdict with 3D vertices/joints, camera translation and normalized 2D joints, postfixed with '.r' or '.l' 27 | """ 28 | 29 | rotmat_original = rotmat.clone() 30 | rotmat = rot.matrix_to_axis_angle(rotmat.reshape(-1, 3, 3)).reshape(-1, 48) 31 | 32 | mano_output = self.mano( 33 | betas=shape, 34 | hand_pose=rotmat[:, 3:], 35 | global_orient=rotmat[:, :3], 36 | ) 37 | output = xdict() 38 | 39 | avg_focal_length = (K[:, 0, 0] + K[:, 1, 1]) / 2.0 40 | cam_t = camera.weak_perspective_to_perspective_torch( 41 | cam, focal_length=avg_focal_length, img_res=self.img_res, min_s=0.1 42 | ) 43 | 44 | joints3d_cam = mano_output.joints + cam_t[:, None, :] 45 | v3d_cam = mano_output.vertices + cam_t[:, None, :] 46 | 47 | joints2d = tf.project2d_batch(K, joints3d_cam) 48 | joints2d = data_utils.normalize_kp2d(joints2d, self.img_res) 49 | 50 | output["cam_t.wp"] = cam 51 | output["cam_t"] = cam_t 52 | output["joints3d"] = mano_output.joints 53 | output["vertices"] = mano_output.vertices 54 | output["j3d.cam"] = joints3d_cam 55 | output["v3d.cam"] = v3d_cam 56 | output["j2d.norm"] = joints2d 57 | output["beta"] = shape 58 | output["pose"] = rotmat_original 59 | 60 | postfix = ".r" if self.is_rhand else ".l" 61 | output_pad = output.postfix(postfix) 62 | return output_pad 63 | -------------------------------------------------------------------------------- /src/nets/hmr_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class HMRLayer(nn.Module): 6 | def __init__(self, feat_dim, mid_dim, specs_dict): 7 | super().__init__() 8 | 9 | self.feat_dim = feat_dim 10 | self.avgpool = nn.AdaptiveAvgPool2d(1) 11 | self.specs_dict = specs_dict 12 | 13 | vector_dim = sum(list(zip(*specs_dict.items()))[1]) 14 | hmr_dim = feat_dim + vector_dim 15 | 16 | # construct refine 17 | self.refine = nn.Sequential( 18 | nn.Linear(hmr_dim, mid_dim), 19 | nn.ReLU(), 20 | nn.Dropout(), 21 | nn.Linear(mid_dim, mid_dim), 22 | nn.ReLU(), 23 | nn.Dropout(), 24 | ) 25 | 26 | # construct decoders 27 | decoders = {} 28 | for key, vec_size in specs_dict.items(): 29 | decoders[key] = nn.Linear(mid_dim, vec_size) 30 | self.decoders = nn.ModuleDict(decoders) 31 | 32 | self.init_weights() 33 | 34 | def init_weights(self): 35 | for key, decoder in self.decoders.items(): 36 | nn.init.xavier_uniform_(decoder.weight, gain=0.01) 37 | self.decoders[key] = decoder 38 | 39 | def forward(self, feat, init_vector_dict, n_iter): 40 | pred_vector_dict = init_vector_dict 41 | for i in range(n_iter): 42 | vectors = list(zip(*pred_vector_dict.items()))[1] 43 | xc = torch.cat([feat] + list(vectors), dim=1) 44 | xc = self.refine(xc) 45 | for key, decoder in self.decoders.items(): 46 | pred_vector_dict.overwrite(key, decoder(xc) + pred_vector_dict[key]) 47 | 48 | pred_vector_dict.has_invalid() 49 | return pred_vector_dict 50 | -------------------------------------------------------------------------------- /src/nets/obj_heads/obj_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import common.camera as camera 4 | import common.data_utils as data_utils 5 | import common.transforms as tf 6 | from common.object_tensors import ObjectTensors 7 | from common.xdict import xdict 8 | 9 | 10 | class
ArtiHead(nn.Module): 11 | def __init__(self, focal_length, img_res): 12 | super().__init__() 13 | self.object_tensors = ObjectTensors() 14 | self.focal_length = focal_length 15 | self.img_res = img_res 16 | 17 | def forward( 18 | self, 19 | rot, 20 | angle, 21 | query_names, 22 | cam, 23 | K, 24 | transl=None, 25 | ): 26 | if self.object_tensors.dev != rot.device: 27 | self.object_tensors.to(rot.device) 28 | 29 | out = self.object_tensors.forward(angle.view(-1, 1), rot, transl, query_names) 30 | 31 | # after adding relative transl 32 | bbox3d = out["bbox3d"] 33 | kp3d = out["kp3d"] 34 | 35 | # right hand translation 36 | avg_focal_length = (K[:, 0, 0] + K[:, 1, 1]) / 2.0 37 | cam_t = camera.weak_perspective_to_perspective_torch( 38 | cam, focal_length=avg_focal_length, img_res=self.img_res, min_s=0.1 39 | ) 40 | 41 | # camera coord 42 | bbox3d_cam = bbox3d + cam_t[:, None, :] 43 | kp3d_cam = kp3d + cam_t[:, None, :] 44 | 45 | # 2d keypoints 46 | kp2d = tf.project2d_batch(K, kp3d_cam) 47 | bbox2d = tf.project2d_batch(K, bbox3d_cam) 48 | 49 | kp2d = data_utils.normalize_kp2d(kp2d, self.img_res) 50 | bbox2d = data_utils.normalize_kp2d(bbox2d, self.img_res) 51 | num_kps = kp2d.shape[1] // 2 52 | 53 | output = xdict() 54 | output["rot"] = rot 55 | if transl is not None: 56 | # relative transl 57 | output["transl"] = transl # mete 58 | 59 | output["cam_t.wp"] = cam 60 | output["cam_t"] = cam_t 61 | output["kp3d"] = kp3d 62 | output["bbox3d"] = bbox3d 63 | output["bbox3d.cam"] = bbox3d_cam 64 | output["kp3d.cam"] = kp3d_cam 65 | output["kp2d.norm"] = kp2d 66 | output["kp2d.norm.t"] = kp2d[:, :num_kps] 67 | output["kp2d.norm.b"] = kp2d[:, num_kps:] 68 | output["bbox2d.norm.t"] = bbox2d[:, :8] 69 | output["bbox2d.norm.b"] = bbox2d[:, 8:] 70 | output["radian"] = angle 71 | 72 | output["v.cam"] = out["v"] + cam_t[:, None, :] 73 | output["v_len"] = out["v_len"] 74 | output["f"] = out["f"] 75 | output["f_len"] = out["f_len"] 76 | 77 | return output 78 | -------------------------------------------------------------------------------- /src/nets/obj_heads/obj_hmr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from common.xdict import xdict 5 | from src.nets.hmr_layer import HMRLayer 6 | 7 | 8 | class ObjectHMR(nn.Module): 9 | def __init__(self, feat_dim, n_iter): 10 | super().__init__() 11 | 12 | obj_specs = {"rot": 3, "cam_t/wp": 3, "radian": 1} 13 | self.hmr_layer = HMRLayer(feat_dim, 1024, obj_specs) 14 | 15 | self.cam_init = nn.Sequential( 16 | nn.Linear(feat_dim, 512), 17 | nn.ReLU(), 18 | nn.Linear(512, 512), 19 | nn.ReLU(), 20 | nn.Linear(512, 3), 21 | ) 22 | 23 | self.obj_specs = obj_specs 24 | self.n_iter = n_iter 25 | self.avgpool = nn.AdaptiveAvgPool2d(1) 26 | 27 | def init_vector_dict(self, features): 28 | batch_size = features.shape[0] 29 | dev = features.device 30 | init_rot = torch.zeros(batch_size, 3) 31 | init_angle = torch.zeros(batch_size, 1) 32 | init_transl = self.cam_init(features) 33 | 34 | out = {} 35 | out["rot"] = init_rot 36 | out["radian"] = init_angle 37 | out["cam_t/wp"] = init_transl 38 | out = xdict(out).to(dev) 39 | return out 40 | 41 | def forward(self, features, use_pool=True): 42 | if use_pool: 43 | feat = self.avgpool(features) 44 | feat = feat.view(feat.size(0), -1) 45 | else: 46 | feat = features 47 | 48 | init_vdict = self.init_vector_dict(feat) 49 | init_cam_t = init_vdict["cam_t/wp"].clone() 50 | pred_vdict = self.hmr_layer(feat, init_vdict, self.n_iter) 51 | 
pred_vdict["cam_t.wp.init"] = init_cam_t 52 | pred_vdict = pred_vdict.replace_keys("/", ".") 53 | return pred_vdict 54 | -------------------------------------------------------------------------------- /src/nets/pointnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.parallel 8 | import torch.utils.data 9 | from torch.autograd import Variable 10 | 11 | """ 12 | Source: https://github.com/fxia22/pointnet.pytorch/blob/f0c2430b0b1529e3f76fb5d6cd6ca14be763d975/pointnet/model.py 13 | """ 14 | 15 | 16 | class STN3d(nn.Module): 17 | def __init__(self): 18 | super(STN3d, self).__init__() 19 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 20 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 21 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 22 | self.fc1 = nn.Linear(1024, 512) 23 | self.fc2 = nn.Linear(512, 256) 24 | self.fc3 = nn.Linear(256, 9) 25 | self.relu = nn.ReLU() 26 | 27 | self.bn1 = nn.BatchNorm1d(64) 28 | self.bn2 = nn.BatchNorm1d(128) 29 | self.bn3 = nn.BatchNorm1d(1024) 30 | self.bn4 = nn.BatchNorm1d(512) 31 | self.bn5 = nn.BatchNorm1d(256) 32 | 33 | def forward(self, x): 34 | batchsize = x.size()[0] 35 | x = F.relu(self.bn1(self.conv1(x))) 36 | x = F.relu(self.bn2(self.conv2(x))) 37 | x = F.relu(self.bn3(self.conv3(x))) 38 | x = torch.max(x, 2, keepdim=True)[0] 39 | x = x.view(-1, 1024) 40 | 41 | x = F.relu(self.bn4(self.fc1(x))) 42 | x = F.relu(self.bn5(self.fc2(x))) 43 | x = self.fc3(x) 44 | 45 | iden = ( 46 | Variable( 47 | torch.from_numpy( 48 | np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32) 49 | ) 50 | ) 51 | .view(1, 9) 52 | .repeat(batchsize, 1) 53 | ) 54 | if x.is_cuda: 55 | iden = iden.cuda() 56 | x = x + iden 57 | x = x.view(-1, 3, 3) 58 | return x 59 | 60 | 61 | class STNkd(nn.Module): 62 | def __init__(self, k=64): 63 | super(STNkd, self).__init__() 64 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 65 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 66 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 67 | self.fc1 = nn.Linear(1024, 512) 68 | self.fc2 = nn.Linear(512, 256) 69 | self.fc3 = nn.Linear(256, k * k) 70 | self.relu = nn.ReLU() 71 | 72 | self.bn1 = nn.BatchNorm1d(64) 73 | self.bn2 = nn.BatchNorm1d(128) 74 | self.bn3 = nn.BatchNorm1d(1024) 75 | self.bn4 = nn.BatchNorm1d(512) 76 | self.bn5 = nn.BatchNorm1d(256) 77 | 78 | self.k = k 79 | 80 | def forward(self, x): 81 | batchsize = x.size()[0] 82 | x = F.relu(self.bn1(self.conv1(x))) 83 | x = F.relu(self.bn2(self.conv2(x))) 84 | x = F.relu(self.bn3(self.conv3(x))) 85 | x = torch.max(x, 2, keepdim=True)[0] 86 | x = x.view(-1, 1024) 87 | 88 | x = F.relu(self.bn4(self.fc1(x))) 89 | x = F.relu(self.bn5(self.fc2(x))) 90 | x = self.fc3(x) 91 | 92 | iden = ( 93 | Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))) 94 | .view(1, self.k * self.k) 95 | .repeat(batchsize, 1) 96 | ) 97 | if x.is_cuda: 98 | iden = iden.cuda() 99 | x = x + iden 100 | x = x.view(-1, self.k, self.k) 101 | return x 102 | 103 | 104 | class PointNetfeat(nn.Module): 105 | def __init__(self, input_dim, shallow_dim, mid_dim, out_dim, global_feat=False): 106 | super(PointNetfeat, self).__init__() 107 | self.shallow_layer = nn.Sequential( 108 | nn.Conv1d(input_dim, shallow_dim, 1), nn.BatchNorm1d(shallow_dim) 109 | ) 110 | 111 | self.base_layer = nn.Sequential( 112 | nn.Conv1d(shallow_dim, mid_dim, 1), 113 | nn.BatchNorm1d(mid_dim), 114 | nn.ReLU(), 115 | 
nn.Conv1d(mid_dim, out_dim, 1), 116 | nn.BatchNorm1d(out_dim), 117 | ) 118 | 119 | self.global_feat = global_feat 120 | self.out_dim = out_dim 121 | 122 | def forward(self, x): 123 | n_pts = x.size()[2] 124 | x = self.shallow_layer(x) 125 | pointfeat = x 126 | 127 | x = self.base_layer(x) 128 | x = torch.max(x, 2, keepdim=True)[0] 129 | x = x.view(-1, self.out_dim) 130 | 131 | trans_feat = None 132 | trans = None 133 | if self.global_feat: 134 | return x, trans, trans_feat 135 | else: 136 | x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts) 137 | return torch.cat([x, pointfeat], 1), trans, trans_feat 138 | -------------------------------------------------------------------------------- /src/parsers/configs/arctic_lstm.py: -------------------------------------------------------------------------------- 1 | from src.parsers.configs.generic import DEFAULT_ARGS_ALLO, DEFAULT_ARGS_EGO 2 | 3 | DEFAULT_ARGS_EGO["batch_size"] = 64 4 | DEFAULT_ARGS_EGO["test_batch_size"] = 64 5 | DEFAULT_ARGS_EGO["img_feat_version"] = "28bf3642f" 6 | DEFAULT_ARGS_EGO["num_epoch"] = 80 7 | 8 | DEFAULT_ARGS_ALLO["batch_size"] = 64 9 | DEFAULT_ARGS_ALLO["test_batch_size"] = 64 10 | DEFAULT_ARGS_ALLO["img_feat_version"] = "3558f1342" 11 | DEFAULT_ARGS_ALLO["num_epoch"] = 10 12 | -------------------------------------------------------------------------------- /src/parsers/configs/arctic_sf.py: -------------------------------------------------------------------------------- 1 | from src.parsers.configs.generic import DEFAULT_ARGS_ALLO, DEFAULT_ARGS_EGO 2 | 3 | DEFAULT_ARGS_EGO["img_feat_version"] = "" # should use ArcticDataset 4 | DEFAULT_ARGS_ALLO["img_feat_version"] = "" # should use ArcticDataset 5 | -------------------------------------------------------------------------------- /src/parsers/configs/field_lstm.py: -------------------------------------------------------------------------------- 1 | from src.parsers.configs.generic import DEFAULT_ARGS_ALLO, DEFAULT_ARGS_EGO 2 | 3 | DEFAULT_ARGS_EGO["batch_size"] = 32 4 | DEFAULT_ARGS_EGO["test_batch_size"] = 32 5 | DEFAULT_ARGS_EGO["img_feat_version"] = "58e200d16" 6 | DEFAULT_ARGS_EGO["num_epoch"] = 50 7 | 8 | DEFAULT_ARGS_ALLO["batch_size"] = 32 9 | DEFAULT_ARGS_ALLO["test_batch_size"] = 32 10 | DEFAULT_ARGS_ALLO["img_feat_version"] = "1f9ac0b15" 11 | DEFAULT_ARGS_ALLO["num_epoch"] = 6 12 | -------------------------------------------------------------------------------- /src/parsers/configs/field_sf.py: -------------------------------------------------------------------------------- 1 | from src.parsers.configs.generic import DEFAULT_ARGS_ALLO, DEFAULT_ARGS_EGO 2 | 3 | DEFAULT_ARGS_EGO["img_feat_version"] = "" # should use ArcticDataset 4 | DEFAULT_ARGS_ALLO["img_feat_version"] = "" # should use ArcticDataset 5 | -------------------------------------------------------------------------------- /src/parsers/configs/generic.py: -------------------------------------------------------------------------------- 1 | DEFAULT_ARGS_EGO = { 2 | "run_on": "", 3 | "trainsplit": "train", 4 | "valsplit": "tinyval", 5 | "setup": "p2a", 6 | "method": "arctic", 7 | "log_every": 50, 8 | "eval_every_epoch": 5, 9 | "lr_dec_epoch": [], 10 | "num_epoch": 100, 11 | "lr": 1e-5, 12 | "lr_dec_factor": 10, 13 | "lr_decay": 0.1, 14 | "num_exp": 1, 15 | "exp_key": "", 16 | "batch_size": 64, 17 | "test_batch_size": 128, 18 | "temp_loader": False, 19 | "window_size": 11, 20 | "num_workers": 16, 21 | "img_feat_version": "", 22 | "eval_on": "", 23 | "acc_grad": 1, 24 | "load_from": "", 25 | 
"load_ckpt": "", 26 | "infer_ckpt": "", 27 | "resume_ckpt": "", 28 | "gpu_ids": [0], 29 | "agent_id": 0, 30 | "cluster_node": "", 31 | "bid": 21, 32 | "gpu_arch": "ampere", 33 | "gpu_min_mem": 20000, 34 | "extraction_mode": "", 35 | } 36 | DEFAULT_ARGS_ALLO = { 37 | "run_on": "", 38 | "trainsplit": "train", 39 | "valsplit": "tinyval", 40 | "setup": "p1a", 41 | "method": "arctic", 42 | "log_every": 50, 43 | "eval_every_epoch": 1, 44 | "lr_dec_epoch": [], 45 | "num_epoch": 20, 46 | "lr": 1e-5, 47 | "lr_dec_factor": 10, 48 | "lr_decay": 0.1, 49 | "num_exp": 1, 50 | "exp_key": "", 51 | "batch_size": 64, 52 | "test_batch_size": 128, 53 | "window_size": 11, 54 | "num_workers": 16, 55 | "img_feat_version": "", 56 | "eval_on": "", 57 | "acc_grad": 1, 58 | "load_from": "", 59 | "load_ckpt": "", 60 | "infer_ckpt": "", 61 | "resume_ckpt": "", 62 | "gpu_ids": [0], 63 | "agent_id": 0, 64 | "cluster_node": "", 65 | "bid": 21, 66 | "gpu_arch": "ampere", 67 | "gpu_min_mem": 20000, 68 | "extraction_mode": "", 69 | } 70 | -------------------------------------------------------------------------------- /src/parsers/generic_parser.py: -------------------------------------------------------------------------------- 1 | def add_generic_args(parser): 2 | """ 3 | Generic options that are non-specific to a project. 4 | """ 5 | parser.add_argument("--agent_id", type=int, default=None) 6 | parser.add_argument( 7 | "--load_from", type=str, default=None, help="Load weights from InterHand format" 8 | ) 9 | parser.add_argument( 10 | "--load_ckpt", type=str, default=None, help="Load checkpoints from PL format" 11 | ) 12 | parser.add_argument( 13 | "--infer_ckpt", type=str, default=None, help="This is for the interface" 14 | ) 15 | parser.add_argument( 16 | "--resume_ckpt", 17 | type=str, 18 | default=None, 19 | help="Resume training from checkpoint and keep logging in the same comet exp", 20 | ) 21 | parser.add_argument( 22 | "-f", 23 | "--fast", 24 | dest="fast_dev_run", 25 | help="single batch for development", 26 | action="store_true", 27 | ) 28 | parser.add_argument( 29 | "--trainsplit", 30 | type=str, 31 | default=None, 32 | choices=[None, "train", "smalltrain", "minitrain", "tinytrain"], 33 | help="Amount to subsample training set.", 34 | ) 35 | parser.add_argument( 36 | "--valsplit", 37 | type=str, 38 | default=None, 39 | choices=[None, "val", "smallval", "tinyval", "minival"], 40 | help="Amount to subsample validation set.", 41 | ) 42 | parser.add_argument( 43 | "--run_on", 44 | type=str, 45 | default=None, 46 | help="split for extraction", 47 | ) 48 | parser.add_argument("--setup", type=str, default=None) 49 | 50 | parser.add_argument("--log_every", type=int, default=None, help="log every k steps") 51 | parser.add_argument( 52 | "--eval_every_epoch", type=int, default=None, help="Eval every k epochs" 53 | ) 54 | parser.add_argument( 55 | "--lr_dec_epoch", 56 | type=int, 57 | nargs="+", 58 | default=None, 59 | help="Learning rate decay epoch.", 60 | ) 61 | parser.add_argument("--num_epoch", type=int, default=None) 62 | parser.add_argument("--lr", type=float, default=None) 63 | parser.add_argument( 64 | "--lr_dec_factor", type=int, default=None, help="Learning rate decay factor" 65 | ) 66 | parser.add_argument( 67 | "--lr_decay", type=float, default=None, help="Learning rate decay factor" 68 | ) 69 | parser.add_argument("--num_exp", type=int, default=None) 70 | parser.add_argument("--acc_grad", type=int, default=None) 71 | parser.add_argument("--batch_size", type=int, default=None) 72 | 
parser.add_argument("--test_batch_size", type=int, default=None) 73 | parser.add_argument("--num_workers", type=int, default=None) 74 | parser.add_argument( 75 | "--eval_on", 76 | type=str, 77 | default=None, 78 | choices=[None, "val", "test", "minival", "minitest"], 79 | help="Test mode set to eval on", 80 | ) 81 | 82 | parser.add_argument("--mute", help="No logging", action="store_true") 83 | parser.add_argument("--no_vis", help="Stop visualization", action="store_true") 84 | parser.add_argument("--cluster", action="store_true") 85 | parser.add_argument("--cluster_node", type=str, default=None) 86 | parser.add_argument("--bid", type=int, default=None, help="log every k steps") 87 | return parser 88 | -------------------------------------------------------------------------------- /src/parsers/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from easydict import EasyDict 4 | 5 | from common.args_utils import set_default_params 6 | from src.parsers.generic_parser import add_generic_args 7 | 8 | 9 | def construct_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument( 12 | "--method", 13 | type=str, 14 | default=None, 15 | choices=[None, "arctic_sf", "arctic_lstm", "field_sf", "field_lstm"], 16 | ) 17 | parser.add_argument("--exp_key", type=str, default=None) 18 | parser.add_argument("--extraction_mode", type=str, default=None) 19 | parser.add_argument("--img_feat_version", type=str, default=None) 20 | parser.add_argument("--window_size", type=int, default=None) 21 | parser.add_argument("--eval", action="store_true") 22 | parser = add_generic_args(parser) 23 | args = EasyDict(vars(parser.parse_args())) 24 | 25 | if args.method in ["arctic_sf"]: 26 | import src.parsers.configs.arctic_sf as config 27 | elif args.method in ["arctic_lstm"]: 28 | import src.parsers.configs.arctic_lstm as config 29 | elif args.method in ["field_sf"]: 30 | import src.parsers.configs.field_sf as config 31 | elif args.method in ["field_lstm"]: 32 | import src.parsers.configs.field_lstm as config 33 | else: 34 | assert False 35 | 36 | default_args = ( 37 | config.DEFAULT_ARGS_EGO if args.setup in ["p2"] else config.DEFAULT_ARGS_ALLO 38 | ) 39 | args = set_default_params(args, default_args) 40 | 41 | args.focal_length = 1000.0 42 | args.img_res = 224 43 | args.rot_factor = 30.0 44 | args.noise_factor = 0.4 45 | args.scale_factor = 0.25 46 | args.flip_prob = 0.0 47 | args.img_norm_mean = [0.485, 0.456, 0.406] 48 | args.img_norm_std = [0.229, 0.224, 0.225] 49 | args.pin_memory = True 50 | args.shuffle_train = True 51 | args.seed = 1 52 | args.grad_clip = 150.0 53 | args.use_gt_k = False # use weak perspective camera or the actual intrinsics 54 | args.speedup = True # load cropped images for faster training 55 | # args.speedup = False # uncomment this to load full images instead 56 | args.max_dist = 0.10 # distance range the model predicts on 57 | args.ego_image_scale = 0.3 58 | 59 | if args.method in ["field_sf", "field_lstm"]: 60 | args.project = "interfield" 61 | else: 62 | args.project = "arctic" 63 | args.interface_p = None 64 | 65 | if args.fast_dev_run: 66 | args.num_workers = 0 67 | args.batch_size = 8 68 | args.trainsplit = "minitrain" 69 | args.valsplit = "minival" 70 | args.log_every = 5 71 | args.window_size = 3 72 | 73 | return args 74 | -------------------------------------------------------------------------------- /src/utils/const.py: -------------------------------------------------------------------------------- 1 | 
import common.comet_utils as comet_utils 2 | from src.parsers.parser import construct_args 3 | 4 | args = construct_args() 5 | experiment, args = comet_utils.init_experiment(args) 6 | comet_utils.save_args(args, save_keys=["comet_key"]) 7 | -------------------------------------------------------------------------------- /src/utils/interfield.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch3d.ops import knn_points 3 | 4 | 5 | def compute_dist_mano_to_obj(batch_mano_v, batch_v, batch_v_len, dist_min, dist_max): 6 | knn_dists, knn_idx, _ = knn_points( 7 | batch_mano_v, batch_v, None, batch_v_len, K=1, return_nn=True 8 | ) 9 | knn_dists = knn_dists.sqrt()[:, :, 0] 10 | 11 | knn_dists = torch.clamp(knn_dists, dist_min, dist_max) 12 | return knn_dists, knn_idx[:, :, 0] 13 | 14 | 15 | def compute_dist_obj_to_mano(batch_mano_v, batch_v, batch_v_len, dist_min, dist_max): 16 | knn_dists, knn_idx, _ = knn_points( 17 | batch_v, batch_mano_v, batch_v_len, None, K=1, return_nn=True 18 | ) 19 | 20 | knn_dists = knn_dists.sqrt() 21 | knn_dists = torch.clamp(knn_dists, dist_min, dist_max) 22 | return knn_dists[:, :, 0], knn_idx[:, :, 0] 23 | 24 | 25 | def dist2contact(dist, contact_bnd): 26 | contact = (dist < contact_bnd).long() 27 | return contact 28 | -------------------------------------------------------------------------------- /src/utils/loss_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import common.torch_utils as torch_utils 5 | from common.torch_utils import nanmean 6 | 7 | l1_loss = nn.L1Loss(reduction="none") 8 | mse_loss = nn.MSELoss(reduction="none") 9 | 10 | 11 | def subtract_root_batch(joints: torch.Tensor, root_idx: int): 12 | assert len(joints.shape) == 3 13 | assert joints.shape[2] == 3 14 | joints_ra = joints.clone() 15 | root = joints_ra[:, root_idx : root_idx + 1].clone() 16 | joints_ra = joints_ra - root 17 | return joints_ra 18 | 19 | 20 | def compute_contact_devi_loss(pred, targets): 21 | cd_ro = contact_deviation( 22 | pred["object.v.cam"], 23 | pred["mano.v3d.cam.r"], 24 | targets["dist.ro"], 25 | targets["idx.ro"], 26 | targets["is_valid"], 27 | targets["right_valid"], 28 | ) 29 | 30 | cd_lo = contact_deviation( 31 | pred["object.v.cam"], 32 | pred["mano.v3d.cam.l"], 33 | targets["dist.lo"], 34 | targets["idx.lo"], 35 | targets["is_valid"], 36 | targets["left_valid"], 37 | ) 38 | cd_ro = nanmean(cd_ro) 39 | cd_lo = nanmean(cd_lo) 40 | cd_ro = torch.nan_to_num(cd_ro) 41 | cd_lo = torch.nan_to_num(cd_lo) 42 | return cd_ro, cd_lo 43 | 44 | 45 | def contact_deviation(pred_v3d_o, pred_v3d_r, dist_ro, idx_ro, is_valid, _right_valid): 46 | right_valid = _right_valid.clone() * is_valid 47 | contact_dist = 3 * 1e-3 # 3mm considered in contact 48 | vo_r_corres = torch.gather(pred_v3d_o, 1, idx_ro[:, :, None].repeat(1, 1, 3)) 49 | 50 | # displacement vector H->O 51 | disp_ro = vo_r_corres - pred_v3d_r # batch, num_v, 3 52 | invalid_ridx = (1 - right_valid).nonzero()[:, 0] 53 | disp_ro[invalid_ridx] = float("nan") 54 | disp_ro[dist_ro > contact_dist] = float("nan") 55 | cd = (disp_ro**2).sum(dim=2).sqrt() 56 | err_ro = torch_utils.nanmean(cd, axis=1) # .cpu().numpy() # m 57 | return err_ro 58 | 59 | 60 | def keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, criterion, jts_valid): 61 | """ 62 | Compute 3D keypoint loss for the examples that 3D keypoint annotations are available. 63 | The loss is weighted by the confidence. 
64 | """ 65 | 66 | gt_root = gt_keypoints_3d[:, :1, :] 67 | gt_keypoints_3d = gt_keypoints_3d - gt_root 68 | pred_root = pred_keypoints_3d[:, :1, :] 69 | pred_keypoints_3d = pred_keypoints_3d - pred_root 70 | 71 | return joints_loss(pred_keypoints_3d, gt_keypoints_3d, criterion, jts_valid) 72 | 73 | 74 | def object_kp3d_loss(pred_3d, gt_3d, criterion, is_valid): 75 | num_kps = pred_3d.shape[1] // 2 76 | pred_3d_ra = subtract_root_batch(pred_3d, root_idx=num_kps) 77 | gt_3d_ra = subtract_root_batch(gt_3d, root_idx=num_kps) 78 | loss_kp = vector_loss( 79 | pred_3d_ra, 80 | gt_3d_ra, 81 | criterion=criterion, 82 | is_valid=is_valid, 83 | ) 84 | return loss_kp 85 | 86 | 87 | def hand_kp3d_loss(pred_3d, gt_3d, criterion, jts_valid): 88 | pred_3d_ra = subtract_root_batch(pred_3d, root_idx=0) 89 | gt_3d_ra = subtract_root_batch(gt_3d, root_idx=0) 90 | loss_kp = keypoint_3d_loss( 91 | pred_3d_ra, gt_3d_ra, criterion=criterion, jts_valid=jts_valid 92 | ) 93 | return loss_kp 94 | 95 | 96 | def vector_loss(pred_vector, gt_vector, criterion, is_valid=None): 97 | dist = criterion(pred_vector, gt_vector) 98 | if is_valid.sum() == 0: 99 | return torch.zeros((1)).to(gt_vector.device) 100 | if is_valid is not None: 101 | valid_idx = is_valid.long().bool() 102 | dist = dist[valid_idx] 103 | loss = dist.mean().view(-1) 104 | return loss 105 | 106 | 107 | def joints_loss(pred_vector, gt_vector, criterion, jts_valid): 108 | dist = criterion(pred_vector, gt_vector) 109 | if jts_valid is not None: 110 | dist = dist * jts_valid[:, :, None] 111 | loss = dist.mean().view(-1) 112 | return loss 113 | 114 | 115 | def mano_loss(pred_rotmat, pred_betas, gt_rotmat, gt_betas, criterion, is_valid=None): 116 | loss_regr_pose = vector_loss(pred_rotmat, gt_rotmat, criterion, is_valid) 117 | loss_regr_betas = vector_loss(pred_betas, gt_betas, criterion, is_valid) 118 | return loss_regr_pose, loss_regr_betas 119 | --------------------------------------------------------------------------------