├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs └── config_vimo.yaml ├── data ├── colors.txt ├── pretrain │ └── cascade_mask_rcnn_vitdet_h_75ep.py ├── smpl │ ├── J_regressor_extra.npy │ ├── J_regressor_h36m.npy │ ├── kintree_table.pkl │ └── smpl_mean_params.npz └── teaser.jpg ├── data_config.py ├── install.sh ├── lib ├── __init__.py ├── camera │ ├── __init__.py │ ├── est_gravity.py │ ├── est_scale.py │ ├── masked_droid_slam.py │ └── slam_utils.py ├── core │ ├── __init__.py │ ├── base_trainer.py │ ├── config.py │ ├── constants.py │ ├── data_loader.py │ └── losses.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── coco_occlusion.py │ ├── detect_dataset.py │ ├── image_dataset.py │ ├── mixed_dataset.py │ ├── track_dataset.py │ └── video_dataset.py ├── get_videoloader.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ ├── pose_transformer.py │ │ └── t_cond_mlp.py │ ├── configs │ │ └── config_vimo.yaml │ ├── hmr_vimo.py │ ├── modules.py │ ├── smpl.py │ └── vit.py ├── pipeline │ ├── __init__.py │ ├── deva_track.py │ ├── tools.py │ └── visualization.py ├── trainer.py ├── utils │ ├── __init__.py │ ├── eval_utils.py │ ├── geometry.py │ ├── imutils.py │ ├── misc.py │ ├── pose_utils.py │ ├── rotation_conversions.py │ ├── utils.py │ ├── utils_detectron2.py │ └── visualizer.py └── vis │ ├── __init__.py │ ├── renderer.py │ ├── renderer_img.py │ ├── tools.py │ └── traj.py ├── scripts ├── crop_datasets.py ├── download_models.sh ├── download_pretrain.sh ├── emdb │ ├── run.sh │ ├── run_cam.py │ ├── run_eval.py │ └── run_smpl.py ├── estimate_camera.py ├── estimate_humans.py ├── extract_bedlam_jpg.py └── visualize_tram.py ├── thirdparty ├── DROID-SLAM │ ├── .gitignore │ ├── .gitmodules │ ├── LICENSE │ ├── README.md │ ├── demo.py │ ├── droid_slam │ │ ├── data_readers │ │ │ ├── __init__.py │ │ │ ├── augmentation.py │ │ │ ├── base.py │ │ │ ├── factory.py │ │ │ ├── rgbd_utils.py │ │ │ ├── stream.py │ │ │ ├── tartan.py │ │ │ └── tartan_test.txt │ │ ├── depth_video.py │ │ ├── droid.py │ │ ├── droid_backend.py │ │ ├── droid_frontend.py │ │ ├── droid_net.py │ │ ├── factor_graph.py │ │ ├── geom │ │ │ ├── __init__.py │ │ │ ├── ba.py │ │ │ ├── chol.py │ │ │ ├── graph_utils.py │ │ │ ├── losses.py │ │ │ └── projective_ops.py │ │ ├── logger.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── clipping.py │ │ │ ├── corr.py │ │ │ ├── extractor.py │ │ │ └── gru.py │ │ ├── motion_filter.py │ │ ├── trajectory_filler.py │ │ ├── vis_headless.py │ │ └── visualization.py │ ├── environment.yaml │ ├── environment_novis.yaml │ ├── evaluation_scripts │ │ ├── test_eth3d.py │ │ ├── test_euroc.py │ │ ├── test_tum.py │ │ └── validate_tartanair.py │ ├── misc │ │ ├── DROID.png │ │ ├── renderoption.json │ │ └── screenshot.png │ ├── setup.py │ ├── src │ │ ├── altcorr_kernel.cu │ │ ├── correlation_kernels.cu │ │ ├── droid.cpp │ │ └── droid_kernels.cu │ ├── thirdparty │ │ └── tartanair_tools │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── data_type.md │ │ │ ├── download_cvpr_slam_test.txt │ │ │ ├── download_training.py │ │ │ ├── download_training_zipfiles.txt │ │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── evaluate_ate_scale.py │ │ │ ├── evaluate_kitti.py │ │ │ ├── evaluate_rpe.py │ │ │ ├── evaluator_base.py │ │ │ ├── pose_est.txt │ │ │ ├── pose_gt.txt │ │ │ ├── tartanair_evaluator.py │ │ │ ├── trajectory_transform.py │ │ │ └── transformation.py │ │ │ └── seg_rgbs.txt │ ├── tools │ │ ├── download_sample_data.sh │ │ ├── evaluate_eth3d.sh │ │ ├── evaluate_euroc.sh │ │ ├── evaluate_tum.sh │ │ 
└── validate_tartanair.sh │ └── train.py └── camcalib │ ├── __init__.py │ ├── cam_utils.py │ ├── camcalib_demo.py │ ├── model.py │ └── resnet.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific data and submodule 2 | .vscode/ 3 | **/.DS_Store 4 | data/pretrain/*.pth 5 | data/pretrain/*.pth.tar 6 | data/pretrain/*.ckpt 7 | data/pretrain/*/*.ckpt 8 | data/smpl/SMPL_*.pkl 9 | data/*.pkl 10 | *.mov 11 | *.mp4 12 | example_video/ 13 | results/ 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # Jupyter Notebook 39 | .ipynb_checkpoints 40 | *.ipynb 41 | 42 | # IPython 43 | profile_default/ 44 | ipython_config.py 45 | 46 | # pyenv 47 | .python-version 48 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/DROID-SLAM/thirdparty/lietorch"] 2 | path = thirdparty/DROID-SLAM/thirdparty/lietorch 3 | url = https://github.com/princeton-vl/lietorch 4 | [submodule "thirdparty/Tracking-Anything-with-DEVA"] 5 | path = thirdparty/Tracking-Anything-with-DEVA 6 | url = https://github.com/hkchengrex/Tracking-Anything-with-DEVA 7 | [submodule "thirdparty/DROID-SLAM/thirdparty/eigen"] 8 | path = thirdparty/DROID-SLAM/thirdparty/eigen 9 | url = https://gitlab.com/libeigen/eigen.git 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yufu Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/config_vimo.yaml: -------------------------------------------------------------------------------- 1 | LOGDIR: '' 2 | DEVICE: 'cuda' 3 | OUTPUT_DIR: 'results' 4 | NUM_WORKERS: 15 5 | SEED_VALUE: 0 6 | IMG_RES: 256 7 | 8 | DATASET: 9 | LIST: ['3dpw_vid', 'h36m_vid', 'bedlam_vid'] 10 | PARTITION: [0.165, 0.165, 0.67] 11 | SEQ_LEN: 16 12 | TEST: 'emdb_1' 13 | 14 | LOSS: 15 | KPT2D: 5.0 16 | KPT3D: 5.0 17 | SMPL_PLUS: 1.0 18 | V3D: 1.0 19 | 20 | 21 | TRAIN: 22 | RESUME: None 23 | LOAD_LATEST: True 24 | BATCH_SIZE: 24 25 | MULTI_LR: True 26 | LR: 1e-5 27 | LR2: 3e-5 28 | WARMUP_STEPS: 3000 29 | SUMMARY_STEP: 50 30 | VALID_STEP: 250 31 | SAVE_STEP: 1000 32 | MAX_STEP: 250000 33 | GAMMA: 1 34 | UPDATE_ITER: 1 35 | CLIP_GRADIENT: True 36 | CLIP_NORM: 1.0 37 | WD: 0.01 38 | OPT: 'AdamW' 39 | LOSS_SCALE: 1 40 | 41 | 42 | MODEL: 43 | CHECKPOINT: 'data/pretrain/hmr2b/epoch=35-step=1000000.ckpt' 44 | ST_MODULE: True 45 | MOTION_MODULE: True 46 | ST_HDIM: 512 47 | MOTION_HDIM: 384 48 | ST_NLAYER: 6 49 | MOTION_NLAYER: 6 50 | 51 | 52 | EXP_NAME: 'tram_0' 53 | COMMENT: 54 | 'Default configs.' 55 | 56 | 57 | -------------------------------------------------------------------------------- /data/colors.txt: -------------------------------------------------------------------------------- 1 | 229 229 229 255 2 | 0 187 249 255 3 | 254 228 64 255 4 | 0 245 212 255 5 | 222 158 54 255 6 | 255 133 141 255 7 | 93 169 233 255 8 | 241 91 181 255 9 | 155 93 229 255 10 | 24 242 178 255 11 | 252 116 186 255 12 | 138 225 252 255 13 | 178 171 242 255 14 | 24 231 98 255 15 | 232 93 117 255 16 | 207 250 231 255 17 | 253 208 242 255 18 | 109 157 197 255 19 | 166 189 219 255 20 | 253 146 207 255 21 | 167 201 87 255 22 | 117 68 177 255 23 | 255 229 50 255 24 | 251 202 239 255 25 | 58 134 255 255 26 | 255 0 110 255 27 | 251 86 7 255 28 | 188 51 209 255 29 | 122 229 130 255 30 | 0 48 73 255 31 | 214 40 40 255 32 | 229 179 179 255 33 | 0 187 249 255 34 | 255 190 11 255 35 | 204 213 174 255 36 | 0 245 212 255 37 | 255 153 200 255 38 | 144 251 146 255 39 | 189 211 147 255 40 | 230 0 86 255 41 | 0 95 57 255 42 | 0 174 126 255 43 | 255 116 163 255 44 | 189 198 255 255 45 | 90 219 255 255 46 | 158 0 142 255 47 | 255 147 126 255 48 | 164 36 0 255 49 | 0 21 68 255 50 | 145 208 203 255 51 | 95 173 78 255 52 | 107 104 130 255 53 | 0 125 181 255 54 | 106 130 108 255 55 | 252 246 189 255 56 | 208 244 222 255 57 | 169 222 249 255 58 | 228 193 249 255 59 | 122 204 174 255 60 | 194 140 159 255 61 | 0 143 156 255 62 | 235 0 0 255 63 | 255 2 157 255 64 | 104 61 59 255 65 | 150 138 232 255 66 | 152 255 82 255 67 | 167 87 64 255 68 | 1 255 254 255 69 | 255 238 232 255 70 | 254 137 0 255 71 | 1 208 255 255 72 | 187 136 0 255 73 | 165 255 210 255 74 | 255 166 254 255 75 | 119 77 0 255 76 | 122 71 130 255 77 | 38 52 0 255 78 | 0 71 84 255 79 | 67 0 44 255 80 | 181 0 255 255 81 | 255 177 103 255 82 | 255 219 102 255 83 | 126 45 210 255 84 | 229 111 254 255 85 | 222 255 116 255 86 | 0 255 120 255 87 | 0 155 255 255 88 | 0 100 1 255 89 | 0 118 255 255 90 | 133 169 0 255 91 | 0 185 23 255 92 | 120 130 49 255 93 | 0 255 198 255 94 | 255 110 65 255 95 | 232 94 190 255 96 | 1 0 103 255 97 | 149 0 58 255 98 | 98 14 0 255 99 | 0 0 0 255 100 | -------------------------------------------------------------------------------- /data/pretrain/cascade_mask_rcnn_vitdet_h_75ep.py: -------------------------------------------------------------------------------- 1 | ## 
coco_loader_lsj.py 2 | 3 | import detectron2.data.transforms as T 4 | from detectron2 import model_zoo 5 | from detectron2.config import LazyCall as L 6 | 7 | # Data using LSJ 8 | image_size = 1024 9 | dataloader = model_zoo.get_config("common/data/coco.py").dataloader 10 | dataloader.train.mapper.augmentations = [ 11 | L(T.RandomFlip)(horizontal=True), # flip first 12 | L(T.ResizeScale)( 13 | min_scale=0.1, max_scale=2.0, target_height=image_size, target_width=image_size 14 | ), 15 | L(T.FixedSizeCrop)(crop_size=(image_size, image_size), pad=False), 16 | ] 17 | dataloader.train.mapper.image_format = "RGB" 18 | dataloader.train.total_batch_size = 64 19 | # recompute boxes due to cropping 20 | dataloader.train.mapper.recompute_boxes = True 21 | 22 | dataloader.test.mapper.augmentations = [ 23 | L(T.ResizeShortestEdge)(short_edge_length=image_size, max_size=image_size), 24 | ] 25 | 26 | from functools import partial 27 | from fvcore.common.param_scheduler import MultiStepParamScheduler 28 | 29 | from detectron2 import model_zoo 30 | from detectron2.config import LazyCall as L 31 | from detectron2.solver import WarmupParamScheduler 32 | from detectron2.modeling.backbone.vit import get_vit_lr_decay_rate 33 | 34 | # mask_rcnn_vitdet_b_100ep.py 35 | 36 | model = model_zoo.get_config("common/models/mask_rcnn_vitdet.py").model 37 | 38 | # Initialization and trainer settings 39 | train = model_zoo.get_config("common/train.py").train 40 | train.amp.enabled = True 41 | train.ddp.fp16_compression = True 42 | train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth" 43 | 44 | 45 | # Schedule 46 | # 100 ep = 184375 iters * 64 images/iter / 118000 images/ep 47 | train.max_iter = 184375 48 | 49 | lr_multiplier = L(WarmupParamScheduler)( 50 | scheduler=L(MultiStepParamScheduler)( 51 | values=[1.0, 0.1, 0.01], 52 | milestones=[163889, 177546], 53 | num_updates=train.max_iter, 54 | ), 55 | warmup_length=250 / train.max_iter, 56 | warmup_factor=0.001, 57 | ) 58 | 59 | # Optimizer 60 | optimizer = model_zoo.get_config("common/optim.py").AdamW 61 | optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7) 62 | optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} 63 | 64 | # cascade_mask_rcnn_vitdet_b_100ep.py 65 | 66 | from detectron2.config import LazyCall as L 67 | from detectron2.layers import ShapeSpec 68 | from detectron2.modeling.box_regression import Box2BoxTransform 69 | from detectron2.modeling.matcher import Matcher 70 | from detectron2.modeling.roi_heads import ( 71 | FastRCNNOutputLayers, 72 | FastRCNNConvFCHead, 73 | CascadeROIHeads, 74 | ) 75 | 76 | # arguments that don't exist for Cascade R-CNN 77 | [model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] 78 | 79 | model.roi_heads.update( 80 | _target_=CascadeROIHeads, 81 | box_heads=[ 82 | L(FastRCNNConvFCHead)( 83 | input_shape=ShapeSpec(channels=256, height=7, width=7), 84 | conv_dims=[256, 256, 256, 256], 85 | fc_dims=[1024], 86 | conv_norm="LN", 87 | ) 88 | for _ in range(3) 89 | ], 90 | box_predictors=[ 91 | L(FastRCNNOutputLayers)( 92 | input_shape=ShapeSpec(channels=1024), 93 | test_score_thresh=0.05, 94 | box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), 95 | cls_agnostic_bbox_reg=True, 96 | num_classes="${...num_classes}", 97 | ) 98 | for (w1, w2) in [(10, 5), (20, 10), (30, 15)] 99 | ], 100 | proposal_matchers=[ 101 | L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) 102 | for th in 
[0.5, 0.6, 0.7] 103 | ], 104 | ) 105 | 106 | # cascade_mask_rcnn_vitdet_h_75ep.py 107 | 108 | from functools import partial 109 | 110 | train.init_checkpoint = "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth" 111 | 112 | model.backbone.net.embed_dim = 1280 113 | model.backbone.net.depth = 32 114 | model.backbone.net.num_heads = 16 115 | model.backbone.net.drop_path_rate = 0.5 116 | # 7, 15, 23, 31 for global attention 117 | model.backbone.net.window_block_indexes = ( 118 | list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)) 119 | ) 120 | 121 | optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, lr_decay_rate=0.9, num_layers=32) 122 | optimizer.params.overrides = {} 123 | optimizer.params.weight_decay_norm = None 124 | 125 | train.max_iter = train.max_iter * 3 // 4 # 100ep -> 75ep 126 | lr_multiplier.scheduler.milestones = [ 127 | milestone * 3 // 4 for milestone in lr_multiplier.scheduler.milestones 128 | ] 129 | lr_multiplier.scheduler.num_updates = train.max_iter 130 | -------------------------------------------------------------------------------- /data/smpl/J_regressor_extra.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/data/smpl/J_regressor_extra.npy -------------------------------------------------------------------------------- /data/smpl/J_regressor_h36m.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/data/smpl/J_regressor_h36m.npy -------------------------------------------------------------------------------- /data/smpl/kintree_table.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/data/smpl/kintree_table.pkl -------------------------------------------------------------------------------- /data/smpl/smpl_mean_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/data/smpl/smpl_mean_params.npz -------------------------------------------------------------------------------- /data/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/data/teaser.jpg -------------------------------------------------------------------------------- /data_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains definitions of useful data stuctures and the paths 3 | for the datasets and data files necessary to run the code. 
4 | Things you need to change: *_ROOT that indicate the path to each dataset 5 | """ 6 | from os.path import join 7 | 8 | # You should have directores like this 9 | # - datasets 10 | # ---- dataset_ann 11 | # ---- 3dpw 12 | # ---- h36m 13 | # ---- bedlam_30fps 14 | # ---- emdb 15 | 16 | # Please change these two lines for your directories 17 | ROOT = './datasets' 18 | DATASET_NPZ_PATH = './datasets/dataset_ann' 19 | 20 | H36M_ROOT = join(ROOT, 'h36m') 21 | PW3D_ROOT = join(ROOT, '3dpw') 22 | BEDLAM_ROOT = join(ROOT, 'bedlam_30fps') 23 | EMDB_ROOT = join(ROOT, 'emdb') 24 | 25 | # Path to test/train npz files 26 | DATASET_FILES = [ { 27 | 'emdb_1': join(DATASET_NPZ_PATH , 'emdb_1.npz'), 28 | '3dpw_vid_test': join(DATASET_NPZ_PATH , '3dpw_vid_test.npz'), 29 | }, 30 | 31 | { 32 | '3dpw_vid': join(DATASET_NPZ_PATH , '3dpw_vid_train.npz'), 33 | 'h36m_vid': join(DATASET_NPZ_PATH , 'h36m_train.npz'), 34 | 'bedlam_vid': join(DATASET_NPZ_PATH , 'bedlam_vid.npz'), 35 | } 36 | ] 37 | 38 | DATASET_FOLDERS = {'h36m_vid': H36M_ROOT, 39 | '3dpw_vid': PW3D_ROOT, 40 | 'bedlam_vid': BEDLAM_ROOT, 41 | 'emdb_1': EMDB_ROOT, 42 | '3dpw_vid_test': PW3D_ROOT, 43 | } 44 | 45 | 46 | PASCAL_OCCLUDERS = 'data/pascal_occluders.pkl' 47 | JOINT_REGRESSOR_TRAIN_EXTRA = 'data/smpl/J_regressor_extra.npy' 48 | JOINT_REGRESSOR_H36M = 'data/smpl/J_regressor_h36m.npy' 49 | SMPL_MEAN_PARAMS = 'data/smpl/smpl_mean_params.npz' 50 | SMPL_MODEL_DIR = 'data/smpl' 51 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # conda remove -n tram --all -y 2 | # conda create -n tram python=3.10 -y 3 | # conda activate tram 4 | 5 | conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit 6 | pip install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 7 | pip install pytorch-lightning 8 | pip install 'git+https://github.com/facebookresearch/detectron2.git@a59f05630a8f205756064244bf5beb8661f96180' 9 | pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" 10 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu118.html 11 | 12 | conda install -c conda-forge suitesparse 13 | 14 | pip install pulp 15 | pip install supervision 16 | 17 | pip install open3d 18 | pip install opencv-python 19 | pip install loguru 20 | pip install git+https://github.com/mattloper/chumpy 21 | pip install einops 22 | pip install plyfile 23 | pip install pyrender 24 | pip install segment_anything 25 | pip install scikit-image 26 | pip install smplx 27 | pip install timm==0.6.7 28 | pip install evo 29 | pip install pytorch-minimize 30 | pip install imageio[ffmpeg] 31 | pip install numpy==1.23 32 | pip install gdown 33 | pip install openpyxl 34 | # pip install git+https://github.com/princeton-vl/lietorch.git 35 | 36 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/lib/__init__.py -------------------------------------------------------------------------------- /lib/camera/__init__.py: -------------------------------------------------------------------------------- 1 | from .masked_droid_slam import run_metric_slam, calibrate_intrinsics 2 | from .est_gravity import align_cam_to_world 
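
# A minimal usage sketch with assumptions made explicit: align_cam_to_world()
# (defined in est_gravity.py) takes an RGB frame plus per-frame camera rotations
# and translations and returns (world_cam_R, world_cam_T, f_pix). The camera.npy
# layout written below only mirrors what lib/pipeline/visualization.py expects to
# load; the helper name _example_align_camera is illustrative and not part of the
# pipeline, and run_metric_slam / calibrate_intrinsics (masked_droid_slam.py)
# presumably produce cam_R, cam_T and the intrinsics upstream of this step.
def _example_align_camera(first_img, cam_R, cam_T, save_path='camera.npy'):
    """Hedged sketch: first_img is an (H, W, 3) RGB frame; cam_R is a (T, 3, 3)
    and cam_T a (T, 3) CPU torch tensor of per-frame camera poses."""
    import numpy as np

    # Rotate the SLAM camera trajectory into the gravity-aligned world frame.
    world_cam_R, world_cam_T, f_pix = align_cam_to_world(first_img, cam_R, cam_T)

    # Same dictionary layout that visualize_tram() later reads back with
    # np.load(..., allow_pickle=True).item().
    camera = {'img_focal': np.array(f_pix),              # scalar, read with .item()
              'world_cam_R': world_cam_R.cpu().numpy(),  # (T, 3, 3)
              'world_cam_T': world_cam_T.cpu().numpy()}  # (T, 3)
    np.save(save_path, camera)
    return camera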
-------------------------------------------------------------------------------- /lib/camera/est_gravity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | import torch 5 | from thirdparty.camcalib.model import CameraRegressorNetwork 6 | 7 | 8 | def run_spec(img): 9 | # Predict gravity direction + fov using SPEC 10 | spec = CameraRegressorNetwork() 11 | spec = spec.load_ckpt('data/pretrain/camcalib_sa_biased_l2.ckpt').to('cuda') 12 | 13 | with torch.no_grad(): 14 | if isinstance(img, str): 15 | img = cv2.imread(img)[:,:,::-1] 16 | 17 | preds = spec(img, transform_data=True) 18 | vfov, pitch, roll = preds 19 | f_pix = img.shape[0] / (2 * np.tan(vfov / 2.)) 20 | 21 | return [f_pix, pitch, roll] 22 | 23 | 24 | def cam_wrt_gravity(pitch, roll): 25 | # Convert pitch-roll from SPEC to cam pose wrt to gravity direction 26 | Rpitch = rotation_about_x(-pitch)[:3, :3] 27 | Rroll = rotation_about_y(roll)[:3, :3] 28 | R_gc = Rpitch @ Rroll 29 | return R_gc 30 | 31 | 32 | def cam_wrt_world(pitch, roll): 33 | # Cam from gravity frame to world frame 34 | R_gc = cam_wrt_gravity(pitch, roll) 35 | R_wg = torch.Tensor([[1,0,0], 36 | [0,-1,0], 37 | [0,0,-1]]) 38 | R_wc = R_wg @ R_gc 39 | return R_wc 40 | 41 | 42 | def align_cam_to_world(img, cam_R, cam_T): 43 | f_pix, pitch, roll = run_spec(img) 44 | R_wc = cam_wrt_world(pitch, roll) 45 | 46 | world_cam_R = torch.einsum('ij,bjk->bik', R_wc, cam_R) 47 | world_cam_T = torch.einsum('ij,bj->bi', R_wc, cam_T) 48 | 49 | return world_cam_R, world_cam_T, f_pix 50 | 51 | 52 | def rotation_about_x(angle: float) -> torch.Tensor: 53 | cos = math.cos(angle) 54 | sin = math.sin(angle) 55 | return torch.tensor([[1, 0, 0, 0], [0, cos, -sin, 0], [0, sin, cos, 0], [0, 0, 0, 1]]) 56 | 57 | 58 | def rotation_about_y(angle: float) -> torch.Tensor: 59 | cos = math.cos(angle) 60 | sin = math.sin(angle) 61 | return torch.tensor([[cos, 0, sin, 0], [0, 1, 0, 0], [-sin, 0, cos, 0], [0, 0, 0, 1]]) 62 | 63 | 64 | def rotation_about_z(angle: float) -> torch.Tensor: 65 | cos = math.cos(angle) 66 | sin = math.sin(angle) 67 | return torch.tensor([[cos, -sin, 0, 0], [sin, cos, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 68 | -------------------------------------------------------------------------------- /lib/camera/est_scale.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | from torchmin import minimize 5 | 6 | 7 | def est_scale_iterative(slam_depth, pred_depth, iters=10, msk=None): 8 | """ Simple depth-align by iterative median and thresholding """ 9 | s = pred_depth / slam_depth 10 | 11 | if msk is None: 12 | msk = np.zeros_like(pred_depth) 13 | else: 14 | msk = cv2.resize(msk, (pred_depth.shape[1], pred_depth.shape[0])) 15 | 16 | robust = (msk<0.5) * (0shape[:,0]) + (bbox[:,1]>shape[:,1]) 118 | invalid = invalid + (self.data['valid']!=1) 119 | return invalid 120 | 121 | -------------------------------------------------------------------------------- /lib/datasets/detect_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision.transforms import Normalize, ToTensor, Compose 4 | import numpy as np 5 | import cv2 6 | 7 | from lib.core import constants 8 | from lib.utils.imutils import crop, boxes_2_cs 9 | 10 | 11 | class DetectDataset(Dataset): 12 | """ 13 | Detection Dataset Class - Handles data loading 
from detections. 14 | """ 15 | def __init__(self, img, boxes, crop_size=256, dilate=1.2, 16 | img_focal=None, img_center=None, normalize=True): 17 | super(DetectDataset, self).__init__() 18 | 19 | self.img = img 20 | self.crop_size = crop_size 21 | self.orig_shape = img.shape[:2] 22 | self.normalize = normalize 23 | self.normalize_img = Compose([ 24 | ToTensor(), 25 | Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD) 26 | ]) 27 | 28 | self.boxes = boxes 29 | self.box_dilate = dilate 30 | self.centers, self.scales = boxes_2_cs(boxes) 31 | 32 | if img_focal is None: 33 | self.img_focal = self.est_focal(self.orig_shape) 34 | else: 35 | self.img_focal = img_focal 36 | 37 | if img_center is None: 38 | self.img_center = self.est_center(self.orig_shape) 39 | else: 40 | self.img_center = img_center 41 | 42 | 43 | def __getitem__(self, index): 44 | item = {} 45 | scale = self.scales[index] * self.box_dilate 46 | center = self.centers[index] 47 | img_focal = self.img_focal 48 | img_center = self.img_center 49 | 50 | img = crop(self.img, center, scale, 51 | [self.crop_size, self.crop_size], rot=0).astype('uint8') 52 | origin_crop = img.copy() 53 | if self.normalize: 54 | img = self.normalize_img(img) 55 | 56 | 57 | item['img'] = img 58 | item['origin_crop'] = origin_crop 59 | item['scale'] = torch.tensor(scale).float() 60 | item['center'] = torch.tensor(center).float() 61 | item['img_focal'] = torch.tensor(img_focal).float() 62 | item['img_center'] = torch.tensor(img_center).float() 63 | item['orig_shape'] = torch.tensor(self.orig_shape).float() 64 | 65 | return item 66 | 67 | 68 | def __len__(self): 69 | return len(self.boxes) 70 | 71 | 72 | def est_focal(self, orig_shape): 73 | h, w = orig_shape 74 | focal = np.sqrt(h**2 + w**2) 75 | return focal 76 | 77 | def est_center(self, orig_shape): 78 | h, w = orig_shape 79 | center = np.array([w/2., h/2.]) 80 | return center 81 | 82 | 83 | -------------------------------------------------------------------------------- /lib/datasets/image_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision.transforms import Normalize, ToTensor, Compose 4 | import numpy as np 5 | import cv2 6 | from os.path import join 7 | from skimage.util.shape import view_as_windows 8 | 9 | from lib.core import constants, config 10 | from lib.utils.imutils import crop, boxes_2_cs 11 | 12 | 13 | class ImageDataset(Dataset): 14 | """ 15 | Image Dataset Class - Handles data loading from image files. 
16 | """ 17 | def __init__(self, imgfiles, boxes, crop_size=256, dilate=1.0, 18 | img_focal=None, img_center=None, normalization=False, step=8): 19 | super(ImageDataset, self).__init__() 20 | 21 | self.imgfiles = imgfiles 22 | self.crop_size = crop_size 23 | self.normalization = normalization 24 | self.normalize_img = Compose([ 25 | ToTensor(), 26 | Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD) 27 | ]) 28 | 29 | self.boxes = boxes 30 | self.box_dilate = dilate 31 | self.centers, self.scales = boxes_2_cs(boxes) 32 | 33 | self.img_focal = img_focal 34 | self.img_center = img_center 35 | 36 | idx = np.arange(0, len(imgfiles)) 37 | self.seq_idx = view_as_windows(idx, (16,), step=step) 38 | 39 | # leftover 40 | self.leftover = len(imgfiles) % step 41 | if self.leftover != 0: 42 | self.seq_idx = np.append(self.seq_idx, idx[-16:][None], axis=0) 43 | 44 | 45 | def __len__(self): 46 | return len(self.imgfiles) 47 | 48 | 49 | def __getitem__(self, index): 50 | return self.get_item(index) 51 | 52 | 53 | def get_item(self, index): 54 | item = {} 55 | scale = self.scales[index] * self.box_dilate 56 | center = self.centers[index] 57 | img_focal = self.img_focal 58 | img_center = self.img_center 59 | 60 | imgfile = self.imgfiles[index] 61 | img = cv2.imread(imgfile)[:,:,::-1] 62 | img_crop = crop(img, center, scale, 63 | [self.crop_size, self.crop_size], 64 | rot=0).astype('uint8') 65 | 66 | if self.normalization: 67 | img_crop = self.normalize_img(img_crop) 68 | else: 69 | img_crop = torch.from_numpy(img_crop) 70 | 71 | if self.img_focal is None: 72 | orig_shape = img.shape[:2] 73 | img_focal = self.est_focal(orig_shape) 74 | 75 | if self.img_center is None: 76 | orig_shape = img.shape[:2] 77 | img_center = self.est_center(orig_shape) 78 | 79 | item['img'] = img_crop 80 | item['img_idx'] = torch.tensor(index).long() 81 | item['scale'] = torch.tensor(scale).float() 82 | item['center'] = torch.tensor(center).float() 83 | item['img_focal'] = torch.tensor(img_focal).float() 84 | item['img_center'] = torch.tensor(img_center).float() 85 | 86 | return item 87 | 88 | 89 | def est_focal(self, orig_shape): 90 | h, w = orig_shape 91 | focal = np.sqrt(h**2 + w**2) 92 | return focal 93 | 94 | def est_center(self, orig_shape): 95 | h, w = orig_shape 96 | center = np.array([w/2., h/2.]) 97 | return center 98 | 99 | 100 | -------------------------------------------------------------------------------- /lib/datasets/mixed_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .video_dataset import VideoDataset 4 | from .image_dataset import ImageDataset 5 | 6 | 7 | class MixedVidDataset(torch.utils.data.Dataset): 8 | 9 | def __init__(self, dataset_list, partition, **kwargs): 10 | 11 | self.dataset_list = dataset_list 12 | self.nds = len(self.dataset_list) 13 | 14 | self.datasets = [VideoDataset(ds, **kwargs) for ds in self.dataset_list] 15 | self.length = max([len(ds) for ds in self.datasets]) 16 | 17 | self.partition = partition 18 | self.partition = np.array(self.partition).cumsum() 19 | 20 | 21 | def __getitem__(self, index): 22 | p = np.random.rand() 23 | for i in range(self.nds): 24 | if p <= self.partition[i]: 25 | item = self.datasets[i][index % len(self.datasets[i])] 26 | 27 | return item 28 | 29 | def __len__(self): 30 | return self.length 31 | 32 | class MixedImgDataset(torch.utils.data.Dataset): 33 | 34 | def __init__(self, dataset_list, partition, **kwargs): 35 | 36 | self.dataset_list = dataset_list 
37 | self.nds = len(self.dataset_list) 38 | 39 | self.datasets = [ImageDataset(ds, **kwargs) for ds in self.dataset_list] 40 | self.length = max([len(ds) for ds in self.datasets]) 41 | 42 | self.partition = partition 43 | self.partition = np.array(self.partition).cumsum() 44 | 45 | 46 | def __getitem__(self, index): 47 | p = np.random.rand() 48 | for i in range(self.nds): 49 | if p <= self.partition[i]: 50 | item = self.datasets[i][index % len(self.datasets[i])] 51 | 52 | return item 53 | 54 | def __len__(self): 55 | return self.length 56 | -------------------------------------------------------------------------------- /lib/datasets/track_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision.transforms import Normalize, ToTensor, Compose 4 | import numpy as np 5 | import cv2 6 | 7 | from lib.core import constants 8 | from lib.utils.imutils import crop, boxes_2_cs 9 | 10 | 11 | class TrackDataset(Dataset): 12 | """ 13 | Track Dataset Class - Load images/crops of the tracked boxes. 14 | """ 15 | def __init__(self, imgfiles, boxes, crop_size=256, dilate=1.0, 16 | img_focal=None, img_center=None, normalization=True): 17 | super(TrackDataset, self).__init__() 18 | 19 | self.imgfiles = imgfiles 20 | self.crop_size = crop_size 21 | self.normalization = normalization 22 | self.normalize_img = Compose([ 23 | ToTensor(), 24 | Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD) 25 | ]) 26 | 27 | self.boxes = boxes 28 | self.box_dilate = dilate 29 | self.centers, self.scales = boxes_2_cs(boxes) 30 | 31 | self.img_focal = img_focal 32 | self.img_center = img_center 33 | 34 | 35 | def __len__(self): 36 | return len(self.imgfiles) 37 | 38 | 39 | def __getitem__(self, index): 40 | item = {} 41 | imgfile = self.imgfiles[index] 42 | scale = self.scales[index] * self.box_dilate 43 | center = self.centers[index] 44 | img_focal = self.img_focal 45 | img_center = self.img_center 46 | 47 | img = cv2.imread(imgfile)[:,:,::-1] 48 | img_crop = crop(img, center, scale, 49 | [self.crop_size, self.crop_size], 50 | rot=0).astype('uint8') 51 | 52 | if self.normalization: 53 | img_crop = self.normalize_img(img_crop) 54 | else: 55 | img_crop = torch.from_numpy(img_crop) 56 | 57 | if self.img_focal is None: 58 | orig_shape = img.shape[:2] 59 | img_focal = self.est_focal(orig_shape) 60 | 61 | if self.img_center is None: 62 | orig_shape = img.shape[:2] 63 | img_center = self.est_center(orig_shape) 64 | 65 | item['img'] = img_crop 66 | item['img_idx'] = torch.tensor(index).long() 67 | item['scale'] = torch.tensor(scale).float() 68 | item['center'] = torch.tensor(center).float() 69 | item['img_focal'] = torch.tensor(img_focal).float() 70 | item['img_center'] = torch.tensor(img_center).float() 71 | 72 | return item 73 | 74 | 75 | def est_focal(self, orig_shape): 76 | h, w = orig_shape 77 | focal = np.sqrt(h**2 + w**2) 78 | return focal 79 | 80 | def est_center(self, orig_shape): 81 | h, w = orig_shape 82 | center = np.array([w/2., h/2.]) 83 | return center 84 | 85 | 86 | -------------------------------------------------------------------------------- /lib/get_videoloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from lib.datasets.mixed_dataset import MixedVidDataset 3 | from lib.datasets.video_dataset import VideoDataset 4 | from lib.core.data_loader import CheckpointDataLoader 5 | 6 | def 
get_dataloaders(cfg=None): 7 | 8 | train_bs = cfg.TRAIN.BATCH_SIZE 9 | num_workers = cfg.NUM_WORKERS 10 | crop_size = cfg.IMG_RES 11 | dataset_list = cfg.DATASET.LIST 12 | seqlen = cfg.DATASET.SEQ_LEN 13 | stride = cfg.DATASET.STRIDE 14 | valid_set = cfg.DATASET.TEST 15 | partition = cfg.DATASET.PARTITION 16 | 17 | print('Num of data loading workers:', num_workers) 18 | print('Sequence length:', seqlen) 19 | print('Sequence stride:', stride) 20 | 21 | print('Datasets:', dataset_list) 22 | print('Partition:', partition) 23 | 24 | train = MixedVidDataset(dataset_list, partition, is_train=True, use_augmentation=True, 25 | normalization=True, cropped=True, crop_size=crop_size, 26 | seqlen=seqlen, stride=stride) 27 | train_loader = CheckpointDataLoader(train, shuffle=True, batch_size=train_bs, num_workers=num_workers) 28 | 29 | test = VideoDataset(valid_set, is_train=False, use_augmentation=False, 30 | normalization=True, cropped=True, crop_size=crop_size, seqlen=16, stride=16) 31 | test_loader = DataLoader(test, batch_size=8, shuffle=False, num_workers=num_workers) 32 | 33 | return [train_loader, test_loader] 34 | 35 | 36 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from yacs.config import CfgNode as CN 4 | from .hmr_vimo import HMR_VIMO 5 | 6 | 7 | def get_default_config(): 8 | cfg_file = os.path.join( 9 | os.path.dirname(__file__), 10 | 'configs/config_vimo.yaml' 11 | ) 12 | 13 | cfg = CN() 14 | cfg.set_new_allowed(True) 15 | cfg.merge_from_file(cfg_file) 16 | return cfg 17 | 18 | 19 | def get_hmr_vimo(checkpoint=None, device='cuda'): 20 | cfg = get_default_config() 21 | cfg.device = device 22 | model = HMR_VIMO(cfg) 23 | 24 | if checkpoint is not None: 25 | ckpt = torch.load(checkpoint, map_location='cpu') 26 | _ = model.load_state_dict(ckpt['model'], strict=False) 27 | 28 | model = model.to(device) 29 | _ = model.eval() 30 | 31 | return model 32 | 33 | -------------------------------------------------------------------------------- /lib/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/lib/models/components/__init__.py -------------------------------------------------------------------------------- /lib/models/configs/config_vimo.yaml: -------------------------------------------------------------------------------- 1 | LOGDIR: '' 2 | DEVICE: 'cuda' 3 | OUTPUT_DIR: 'results' 4 | NUM_WORKERS: 15 5 | SEED_VALUE: 0 6 | IMG_RES: 256 7 | 8 | DATASET: 9 | LIST: ['3dpw_vid', 'h36m', 'bedlam_vid'] 10 | PARTITION: [0.165, 0.165, 0.67] 11 | SEQ_LEN: 16 12 | TEST: 'emdb_1' 13 | 14 | LOSS: 15 | KPT2D: 5.0 16 | KPT3D: 5.0 17 | SMPL: 1.0 18 | V3D: 1.0 19 | 20 | 21 | TRAIN: 22 | RESUME: None 23 | LOAD_LATEST: True 24 | BATCH_SIZE: 24 25 | MULTI_LR: True 26 | LR: 1e-5 27 | LR2: 3e-5 28 | WARMUP_STEPS: 3000 29 | SUMMARY_STEP: 50 30 | VALID_STEP: 250 31 | SAVE_STEP: 1000 32 | MAX_STEP: 250000 33 | GAMMA: 1 34 | UPDATE_ITER: 1 35 | CLIP_GRADIENT: True 36 | CLIP_NORM: 1.0 37 | WD: 0.01 38 | OPT: 'AdamW' 39 | LOSS_SCALE: 1 40 | 41 | 42 | MODEL: 43 | CHECKPOINT: 'data/pretrain/hmr2b/epoch=35-step=1000000.ckpt' 44 | ST_MODULE: True 45 | MOTION_MODULE: True 46 | ST_HDIM: 512 47 | MOTION_HDIM: 384 48 | ST_NLAYER: 6 49 | MOTION_NLAYER: 6 50 | 51 | 52 | EXP_NAME: 'hmr_vimo' 53 | COMMENT: 54 | 
'Default hmr_vimo configs.' 55 | 56 | 57 | -------------------------------------------------------------------------------- /lib/models/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import einops 3 | import torch 4 | import torch.nn as nn 5 | from .components.pose_transformer import TransformerDecoder 6 | 7 | 8 | class SMPLTransformerDecoderHead(nn.Module): 9 | """ HMR2 Cross-attention based SMPL Transformer decoder 10 | """ 11 | def __init__(self, ): 12 | super().__init__() 13 | transformer_args = dict( 14 | depth = 6, # originally 6 15 | heads = 8, 16 | mlp_dim = 1024, 17 | dim_head = 64, 18 | dropout = 0.0, 19 | emb_dropout = 0.0, 20 | norm = "layer", 21 | context_dim = 1280, 22 | num_tokens = 1, 23 | token_dim = 1, 24 | dim = 1024 25 | ) 26 | self.transformer = TransformerDecoder(**transformer_args) 27 | 28 | dim = 1024 29 | npose = 24*6 30 | self.decpose = nn.Linear(dim, npose) 31 | self.decshape = nn.Linear(dim, 10) 32 | self.deccam = nn.Linear(dim, 3) 33 | nn.init.xavier_uniform_(self.decpose.weight, gain=0.01) 34 | nn.init.xavier_uniform_(self.decshape.weight, gain=0.01) 35 | nn.init.xavier_uniform_(self.deccam.weight, gain=0.01) 36 | 37 | mean_params = np.load('data/smpl/smpl_mean_params.npz') 38 | init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0) 39 | init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0) 40 | init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0) 41 | self.register_buffer('init_body_pose', init_body_pose) 42 | self.register_buffer('init_betas', init_betas) 43 | self.register_buffer('init_cam', init_cam) 44 | 45 | 46 | def forward(self, x, **kwargs): 47 | 48 | batch_size = x.shape[0] 49 | # vit pretrained backbone is channel-first. 
Change to token-first 50 | x = einops.rearrange(x, 'b c h w -> b (h w) c') 51 | 52 | init_body_pose = self.init_body_pose.expand(batch_size, -1) 53 | init_betas = self.init_betas.expand(batch_size, -1) 54 | init_cam = self.init_cam.expand(batch_size, -1) 55 | 56 | # Pass through transformer 57 | token = torch.zeros(batch_size, 1, 1).to(x.device) 58 | token_out = self.transformer(token, context=x) 59 | token_out = token_out.squeeze(1) # (B, C) 60 | 61 | # Readout from token_out 62 | pred_pose = self.decpose(token_out) + init_body_pose 63 | pred_shape = self.decshape(token_out) + init_betas 64 | pred_cam = self.deccam(token_out) + init_cam 65 | 66 | return pred_pose, pred_shape, pred_cam 67 | 68 | 69 | class temporal_attention(nn.Module): 70 | def __init__(self, in_dim=1280, out_dim=1280, hdim=512, nlayer=6, nhead=4, residual=False): 71 | super(temporal_attention, self).__init__() 72 | self.hdim = hdim 73 | self.out_dim = out_dim 74 | self.residual = residual 75 | self.l1 = nn.Linear(in_dim, hdim) 76 | self.l2 = nn.Linear(hdim, out_dim) 77 | 78 | self.pos_embedding = PositionalEncoding(hdim, dropout=0.1) 79 | TranLayer = nn.TransformerEncoderLayer(d_model=hdim, nhead=nhead, dim_feedforward=1024, 80 | dropout=0.1, activation='gelu') 81 | self.trans = nn.TransformerEncoder(TranLayer, num_layers=nlayer) 82 | 83 | nn.init.xavier_uniform_(self.l1.weight, gain=0.01) 84 | nn.init.xavier_uniform_(self.l2.weight, gain=0.01) 85 | 86 | def forward(self, x): 87 | x = x.permute(1,0,2) # (b,t,c) -> (t,b,c) 88 | 89 | h = self.l1(x) 90 | h = self.pos_embedding(h) 91 | h = self.trans(h) 92 | h = self.l2(h) 93 | 94 | if self.residual: 95 | x = x[..., :self.out_dim] + h 96 | else: 97 | x = h 98 | x = x.permute(1,0,2) 99 | 100 | return x 101 | 102 | 103 | class PositionalEncoding(nn.Module): 104 | def __init__(self, d_model, dropout=0.1, max_len=100): 105 | super(PositionalEncoding, self).__init__() 106 | self.dropout = nn.Dropout(p=dropout) 107 | 108 | pe = torch.zeros(max_len, d_model) 109 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 110 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 111 | pe[:, 0::2] = torch.sin(position * div_term) 112 | pe[:, 1::2] = torch.cos(position * div_term) 113 | pe = pe.unsqueeze(0).transpose(0, 1) 114 | 115 | self.register_buffer('pe', pe) 116 | 117 | def forward(self, x): 118 | # not used in the final model 119 | x = x + self.pe[:x.shape[0], :] 120 | return self.dropout(x) 121 | -------------------------------------------------------------------------------- /lib/models/smpl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | import contextlib 6 | 7 | from smplx import SMPL as _SMPL 8 | from smplx import SMPLLayer as _SMPLLayer 9 | from smplx.body_models import SMPLOutput 10 | from smplx.lbs import vertices2joints 11 | 12 | from lib.core.constants import JOINT_MAP, JOINT_NAMES 13 | 14 | 15 | # SMPL data path 16 | SMPL_DATA_PATH = "data/smpl/" 17 | 18 | SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl") 19 | SMPL_MEAN_PARAMS = os.path.join(SMPL_DATA_PATH, "smpl_mean_params.npz") 20 | SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl") 21 | JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy') 22 | JOINT_REGRESSOR_H36M = os.path.join(SMPL_DATA_PATH, 'J_regressor_h36m.npy') 23 | 24 | 25 | class SMPL(_SMPL): 26 | 27 | def 
__init__(self, create_default=False, *args, **kwargs): 28 | kwargs["model_path"] = "data/smpl" 29 | 30 | # remove the verbosity for the 10-shapes beta parameters 31 | with contextlib.redirect_stdout(None): 32 | super(SMPL, self).__init__( 33 | create_body_pose=create_default, 34 | create_betas=create_default, 35 | create_global_orient=create_default, 36 | create_transl=create_default, 37 | *args, 38 | **kwargs 39 | ) 40 | 41 | # SPIN 49(25 OP + 24) joints 42 | joints = [JOINT_MAP[i] for i in JOINT_NAMES] 43 | J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA) 44 | self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32)) 45 | self.joint_map = torch.tensor(joints, dtype=torch.long) 46 | 47 | 48 | def forward(self, default_smpl=False, *args, **kwargs): 49 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 50 | if default_smpl: 51 | return smpl_output 52 | 53 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 54 | joints = torch.cat([smpl_output.joints, extra_joints], dim=1) 55 | joints = joints[:, self.joint_map, :] 56 | 57 | output = SMPLOutput(vertices=smpl_output.vertices, 58 | global_orient=smpl_output.global_orient, 59 | body_pose=smpl_output.body_pose, 60 | betas=smpl_output.betas, 61 | full_pose=smpl_output.full_pose, 62 | joints=joints) 63 | 64 | return output 65 | 66 | 67 | def query(self, hmr_output, default_smpl=False): 68 | pred_rotmat = hmr_output['pred_rotmat'] 69 | pred_shape = hmr_output['pred_shape'] 70 | 71 | smpl_out = self(global_orient=pred_rotmat[:, [0]], 72 | body_pose = pred_rotmat[:, 1:], 73 | betas = pred_shape, 74 | default_smpl = default_smpl, 75 | pose2rot=False) 76 | return smpl_out 77 | 78 | 79 | -------------------------------------------------------------------------------- /lib/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import video2frames, detect_segment_track 2 | from .visualization import visualize_tram -------------------------------------------------------------------------------- /lib/pipeline/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import cv2 5 | from tqdm import tqdm 6 | from glob import glob 7 | import imageio 8 | 9 | from lib.vis.traj import * 10 | from lib.models.smpl import SMPL 11 | from lib.vis.renderer import Renderer 12 | 13 | 14 | def visualize_tram(seq_folder, floor_scale=2, bin_size=-1, max_faces_per_bin=30000): 15 | img_folder = f'{seq_folder}/images' 16 | hps_folder = f'{seq_folder}/hps' 17 | imgfiles = sorted(glob(f'{img_folder}/*.jpg')) 18 | hps_files = sorted(glob(f'{hps_folder}/*.npy')) 19 | 20 | device = 'cuda' 21 | smpl = SMPL().to(device) 22 | colors = np.loadtxt('data/colors.txt')/255 23 | colors = torch.from_numpy(colors).float() 24 | 25 | max_track = len(hps_files) 26 | tstamp = [t for t in range(len(imgfiles))] 27 | track_verts = {i:[] for i in tstamp} 28 | track_joints = {i:[] for i in tstamp} 29 | track_tid = {i:[] for i in tstamp} 30 | locations = [] 31 | lowest = [] 32 | 33 | ##### TRAM + VIMO ##### 34 | pred_cam = np.load(f'{seq_folder}/camera.npy', allow_pickle=True).item() 35 | img_focal = pred_cam['img_focal'].item() 36 | world_cam_R = torch.tensor(pred_cam['world_cam_R']).to(device) 37 | world_cam_T = torch.tensor(pred_cam['world_cam_T']).to(device) 38 | 39 | for i in range(max_track): 40 | hps_file = hps_files[i] 41 | 42 | pred_smpl = 
np.load(hps_file, allow_pickle=True).item() 43 | pred_rotmat = pred_smpl['pred_rotmat'].to(device) 44 | pred_shape = pred_smpl['pred_shape'].to(device) 45 | pred_trans = pred_smpl['pred_trans'].to(device) 46 | frame = pred_smpl['frame'] 47 | 48 | mean_shape = pred_shape.mean(dim=0, keepdim=True) 49 | pred_shape = mean_shape.repeat(len(pred_shape), 1) 50 | 51 | pred = smpl(body_pose=pred_rotmat[:,1:], 52 | global_orient=pred_rotmat[:,[0]], 53 | betas=pred_shape, 54 | transl=pred_trans.squeeze(), 55 | pose2rot=False, 56 | default_smpl=True) 57 | pred_vert = pred.vertices 58 | pred_j3d = pred.joints[:, :24] 59 | 60 | cam_r = world_cam_R[frame] 61 | cam_t = world_cam_T[frame] 62 | 63 | pred_vert_w = torch.einsum('bij,bnj->bni', cam_r, pred_vert) + cam_t[:,None] 64 | pred_j3d_w = torch.einsum('bij,bnj->bni', cam_r, pred_j3d) + cam_t[:,None] 65 | pred_vert_w, pred_j3d_w = traj_filter(pred_vert_w.cpu(), 66 | pred_j3d_w.cpu()) 67 | locations.append(pred_j3d_w.mean(1)) 68 | lowest.append(pred_vert_w[:, :, 1].min()) 69 | 70 | for j, f in enumerate(frame.tolist()): 71 | track_tid[f].append(i) 72 | track_verts[f].append(pred_vert_w[j]) 73 | track_joints[f].append(pred_j3d_w[j]) 74 | 75 | 76 | offset = torch.min(torch.stack(lowest)) 77 | offset = torch.tensor([0, offset, 0]).to(device) 78 | 79 | locations = torch.cat(locations).to(device) 80 | cx, cz = (locations.max(0)[0] + locations.min(0)[0])[[0, 2]] / 2.0 81 | sx, sz = (locations.max(0)[0] - locations.min(0)[0])[[0, 2]] 82 | scale = max(sx.item(), sz.item()) * floor_scale 83 | 84 | ##### Viewing Camera ##### 85 | world_cam_T = world_cam_T - offset 86 | view_cam_R = world_cam_R.mT.to('cuda') 87 | view_cam_T = - torch.einsum('bij,bj->bi', world_cam_R, world_cam_T).to('cuda') 88 | 89 | ##### Render video for visualization ##### 90 | writer = imageio.get_writer(f'{seq_folder}/tram_output.mp4', fps=30, mode='I', 91 | format='FFMPEG', macro_block_size=1) 92 | img = cv2.imread(imgfiles[0]) 93 | renderer = Renderer(img.shape[1], img.shape[0], img_focal-100, 'cuda', 94 | smpl.faces, bin_size=bin_size, max_faces_per_bin=max_faces_per_bin) 95 | renderer.set_ground(scale, cx.item(), cz.item()) 96 | 97 | for i in tqdm(range(len(imgfiles))): 98 | img = cv2.imread(imgfiles[i])[:,:,::-1] 99 | 100 | verts_list = track_verts[i] 101 | if len(verts_list)>0: 102 | verts_list = torch.stack(track_verts[i])[:,None].to('cuda') 103 | verts_list -= offset 104 | 105 | tid = track_tid[i] 106 | verts_colors = torch.stack([colors[t] for t in tid]).to('cuda') 107 | 108 | faces = renderer.faces.clone().squeeze(0) 109 | cameras, lights = renderer.create_camera_from_cv(view_cam_R[[i]], 110 | view_cam_T[[i]]) 111 | rend = renderer.render_with_ground_multiple(verts_list, faces, verts_colors, 112 | cameras, lights) 113 | 114 | out = np.concatenate([img, rend], axis=1) 115 | writer.append_data(out) 116 | 117 | writer.close() -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def fixseed(seed): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | 10 | def to_numpy(tensor): 11 
| if torch.is_tensor(tensor): 12 | return tensor.cpu().numpy() 13 | elif type(tensor).__module__ != 'numpy': 14 | raise ValueError("Cannot convert {} to numpy array".format( 15 | type(tensor))) 16 | return tensor 17 | 18 | 19 | def to_torch(ndarray): 20 | if type(ndarray).__module__ == 'numpy': 21 | return torch.from_numpy(ndarray) 22 | elif not torch.is_tensor(ndarray): 23 | raise ValueError("Cannot convert {} to torch tensor".format( 24 | type(ndarray))) 25 | return ndarray 26 | 27 | 28 | def cleanexit(): 29 | import sys 30 | import os 31 | try: 32 | sys.exit(0) 33 | except SystemExit: 34 | os._exit(0) 35 | 36 | -------------------------------------------------------------------------------- /lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import logging 5 | 6 | 7 | def prepare_output_dir(cfg): 8 | logfolder = cfg.EXP_NAME 9 | logdir = os.path.join(cfg.OUTPUT_DIR, logfolder) 10 | cfg.LOGDIR = logdir 11 | os.makedirs(logdir, exist_ok=True) 12 | 13 | shutil.copy(src=cfg.cfg_file, dst=f'{cfg.LOGDIR}/config.yaml') 14 | return cfg 15 | 16 | 17 | def create_logger(logdir, phase='train'): 18 | os.makedirs(logdir, exist_ok=True) 19 | 20 | log_file = os.path.join(logdir, f'{phase}_log.txt') 21 | head = '%(asctime)-15s %(message)s' 22 | 23 | logging.basicConfig(filename=log_file, 24 | format=head) 25 | logger = logging.getLogger() 26 | logger.setLevel(logging.INFO) 27 | console = logging.StreamHandler() 28 | logging.getLogger('').addHandler(console) 29 | 30 | return logger 31 | 32 | 33 | def move_dict_to_device(dict, device, tensor2float=False): 34 | for k,v in dict.items(): 35 | if isinstance(v, torch.Tensor): 36 | if tensor2float: 37 | dict[k] = v.float().to(device) 38 | else: 39 | dict[k] = v.to(device) 40 | 41 | 42 | def concatenate_dicts(dict_list, dim=0): 43 | rdict = dict.fromkeys(dict_list[0].keys()) 44 | for k in rdict.keys(): 45 | rdict[k] = torch.cat([d[k] for d in dict_list], dim=dim) 46 | return rdict 47 | 48 | 49 | class AverageMeter(object): 50 | def __init__(self): 51 | self.val = 0 52 | self.avg = 0 53 | self.sum = 0 54 | self.count = 0 55 | 56 | def update(self, val, n=1): 57 | self.val = val 58 | self.sum += val * n 59 | self.count += n 60 | self.avg = self.sum / self.count 61 | -------------------------------------------------------------------------------- /lib/utils/utils_detectron2.py: -------------------------------------------------------------------------------- 1 | import detectron2.data.transforms as T 2 | import torch 3 | from detectron2.checkpoint import DetectionCheckpointer 4 | from detectron2.config import CfgNode, instantiate 5 | from detectron2.data import MetadataCatalog 6 | from omegaconf import OmegaConf 7 | 8 | 9 | class DefaultPredictor_Lazy: 10 | """Create a simple end-to-end predictor with the given config that runs on single device for a 11 | single input image. 12 | 13 | Compared to using the model directly, this class does the following additions: 14 | 15 | 1. Load checkpoint from the weights specified in config (cfg.MODEL.WEIGHTS). 16 | 2. Always take BGR image as the input and apply format conversion internally. 17 | 3. Apply resizing defined by the config (`cfg.INPUT.{MIN,MAX}_SIZE_TEST`). 18 | 4. Take one input image and produce a single output, instead of a batch. 19 | 20 | This is meant for simple demo purposes, so it does the above steps automatically. 21 | This is not meant for benchmarks or running complicated inference logic. 
22 | If you'd like to do anything more complicated, please refer to its source code as 23 | examples to build and use the model manually. 24 | 25 | Attributes: 26 | metadata (Metadata): the metadata of the underlying dataset, obtained from 27 | test dataset name in the config. 28 | 29 | 30 | Examples: 31 | :: 32 | pred = DefaultPredictor(cfg) 33 | inputs = cv2.imread("input.jpg") 34 | outputs = pred(inputs) 35 | """ 36 | 37 | def __init__(self, cfg): 38 | """ 39 | Args: 40 | cfg: omegaconf dict object. 41 | """ 42 | # new LazyConfig 43 | self.cfg = cfg 44 | self.model = instantiate(cfg.model) 45 | test_dataset = OmegaConf.select(cfg, "dataloader.test.dataset.names", default=None) 46 | if isinstance(test_dataset, (list, tuple)): 47 | test_dataset = test_dataset[0] 48 | 49 | checkpointer = DetectionCheckpointer(self.model) 50 | checkpointer.load(OmegaConf.select(cfg, "train.init_checkpoint", default="")) 51 | 52 | mapper = instantiate(cfg.dataloader.test.mapper) 53 | self.aug = mapper.augmentations 54 | self.input_format = mapper.image_format 55 | 56 | self.model.eval().cuda() 57 | if test_dataset: 58 | self.metadata = MetadataCatalog.get(test_dataset) 59 | assert self.input_format in ["RGB", "BGR"], self.input_format 60 | 61 | def __call__(self, original_image): 62 | """ 63 | Args: 64 | original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). 65 | 66 | Returns: 67 | predictions (dict): 68 | the output of the model for one image only. 69 | See :doc:`/tutorials/models` for details about the format. 70 | """ 71 | with torch.no_grad(): 72 | if self.input_format == "RGB": 73 | original_image = original_image[:, :, ::-1] 74 | height, width = original_image.shape[:2] 75 | image = self.aug(T.AugInput(original_image)).apply_image(original_image) 76 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 77 | inputs = {"image": image, "height": height, "width": width} 78 | predictions = self.model([inputs])[0] 79 | return predictions 80 | -------------------------------------------------------------------------------- /lib/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from plyfile import PlyData, PlyElement 5 | from matplotlib.colors import Normalize 6 | import matplotlib.cm as cm 7 | 8 | 9 | def draw_kpts(img, kpts, r=5, thickness=5, color=(255,0,0), confidence=1e-6): 10 | if isinstance(img, np.ndarray): 11 | img = img.copy().astype(np.uint8) 12 | if isinstance(img, torch.Tensor): 13 | img = img.numpy() 14 | img = img.copy().astype(np.uint8) 15 | 16 | for kpt in kpts: 17 | if len(kpt)>2: 18 | x, y, c = kpt 19 | else: 20 | x, y = kpt 21 | c = 1 22 | 23 | if c >= confidence: 24 | cv2.circle(img, (int(x), int(y)), r, color, thickness) 25 | 26 | return img 27 | 28 | def draw_boxes(img, boxes, thickness=5, color=(0,255,0)): 29 | img_box = img.copy() 30 | for box in boxes: 31 | x1, y1, x2, y2 = box[:4] 32 | img_box = cv2.rectangle(img_box, (int(x1),int(y1)), (int(x2),int(y2)), 33 | color, thickness) 34 | return img_box 35 | 36 | 37 | def to_rgb(grey, cmap='YlGnBu', resize=[224, 224], normalize=True): 38 | # cmap_list = ['YlGnBu', 'coolwarm', 'RdBu'] 39 | g = np.array(grey) 40 | cmap = cm.get_cmap(cmap) 41 | 42 | if normalize: 43 | norm = Normalize(vmin=g.min(), vmax=g.max()) 44 | g = norm(g) 45 | rgb = cmap(g)[:,:,:3] 46 | 47 | if resize is not None: 48 | rgb = cv2.resize(rgb, resize) 49 | 50 | rgb = (rgb * 255).astype(int) 51 | return rgb 52 | 53 | 54 | def 
to_rgb_norm(grey, cmap='YlGnBu', resize=[224, 224], min_v=0.0, max_v=1.0, normalize=True): 55 | # cmap_list = ['YlGnBu', 'coolwarm', 'RdBu'] 56 | g = np.array(grey) 57 | cmap = cm.get_cmap(cmap) 58 | 59 | if normalize: 60 | norm = Normalize(vmin=min_v, vmax=max_v) 61 | g = norm(g) 62 | rgb = cmap(g)[:,:,:3] 63 | 64 | if resize is not None: 65 | rgb = cv2.resize(rgb, resize) 66 | 67 | rgb = (rgb * 255).astype(int) 68 | return rgb 69 | 70 | 71 | ### Save for visualization 72 | def save_ply(vert, face=None, color=None, filename='file.ply'): 73 | 74 | # Colors 75 | if color is None: 76 | vtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 77 | else: 78 | vtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 79 | vert = np.concatenate([vert, color], axis=-1) 80 | 81 | # Vertices 82 | if isinstance(vert, np.ndarray): 83 | vert = vert.tolist() 84 | vert = [tuple(v) for v in vert] 85 | vert = np.array(vert, dtype=vtype) 86 | vert = PlyElement.describe(vert, 'vertex') 87 | 88 | # Faces 89 | if face is not None: 90 | if isinstance(face, np.ndarray): 91 | face = face.tolist() 92 | face = [(face[i], 255, 255, 255) for i in range(len(face))] 93 | face = np.array(face, dtype=[('vertex_indices', 'i4', (3,)), 94 | ('red', 'u1'), 95 | ('green', 'u1'), 96 | ('blue', 'u1')]) 97 | face = PlyElement.describe(face, 'face') 98 | 99 | # Save 100 | if face is not None: 101 | with open(filename, 'wb') as f: 102 | PlyData([vert, face]).write(f) 103 | else: 104 | with open(filename, 'wb') as f: 105 | PlyData([vert]).write(f) 106 | 107 | 108 | def read_ply(plyfile): 109 | plydata = PlyData.read(plyfile) 110 | v = plydata['vertex'].data 111 | v = [list(i) for i in v] 112 | v = np.array(v) 113 | f = plydata['face'].data 114 | f = [list(i) for i in f] 115 | f = np.array(f).squeeze() 116 | return v, f 117 | 118 | 119 | # from transforms3d.euler import euler2mat 120 | 121 | # def novel_view(s_out, angle=-0.25*np.pi, axis='y', trans=8): 122 | # vertices = s_out.vertices.clone().cpu() 123 | # joints = s_out.joints.clone().cpu() 124 | 125 | # j3d = joints[:, 25:] 126 | # pelvis3d = j3d[:, [14]] 127 | # verts = vertices - pelvis3d 128 | 129 | # if axis=='x': 130 | # rot = euler2mat(angle, 0, 0, "sxyz") 131 | # elif axis=='y': 132 | # rot = euler2mat(0, angle, 0, "sxyz") 133 | # else: 134 | # rot = euler2mat(0, 0, angle, "sxyz") 135 | 136 | # rot = torch.from_numpy(rot).float() 137 | # verts = torch.einsum('ij, bvj->bvi', rot, verts) 138 | 139 | # # verts[:,:,1] -= 0.1 140 | # verts[:,:,2] += trans 141 | 142 | # return verts 143 | 144 | 145 | -------------------------------------------------------------------------------- /lib/vis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/lib/vis/__init__.py -------------------------------------------------------------------------------- /lib/vis/renderer_img.py: -------------------------------------------------------------------------------- 1 | from optparse import Option 2 | import os 3 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 4 | #os.environ['PYOPENGL_PLATFORM'] = 'osmesa' 5 | import torch 6 | import numpy as np 7 | import pyrender 8 | import trimesh 9 | import cv2 10 | from typing import List, Optional 11 | 12 | # RGB-A 13 | blue = (0.3, 0.5, 0.9, 1.0) 14 | green = (0.45, 0.75, 0.533, 1.0) 15 | yellow = (0.88, 0.85, 0.528, 1.0) 16 | 17 | 18 | class Renderer: 19 | 20 | def __init__(self, faces, color=(0.3, 0.5, 
0.9, 1.0), size=None): 21 | """ 22 | Wrapper around the pyrender renderer to render SMPL meshes. 23 | Args: 24 | faces (np.array): Array of shape (F, 3) containing the mesh faces. 25 | """ 26 | 27 | self.size = size 28 | self.faces = faces 29 | 30 | self.light_nodes = create_raymond_lights() 31 | 32 | self.material = pyrender.MetallicRoughnessMaterial( 33 | metallicFactor=0.1, 34 | alphaMode='OPAQUE', 35 | baseColorFactor=color) 36 | self.renderer = None 37 | 38 | 39 | def init_renderer(self, height=None, width=None, image=None): 40 | if height is None or width is None: 41 | height, width = image.shape[:2] 42 | self.renderer = pyrender.OffscreenRenderer(viewport_width=width, 43 | viewport_height=height, 44 | point_size=1.0) 45 | 46 | 47 | def __call__(self, vertices, camera_translation, image=None, focal=None, center=None, return_depth=False) : 48 | """ 49 | Render meshes on input image 50 | Args: 51 | vertices (np.array): Array of shape (V, 3) containing the mesh vertices. 52 | camera_translation (np.array): Array of shape (3,) with the camera translation. 53 | image (np.array): Array of shape (H, W, 3) containing the image crop with normalized pixel values. 54 | """ 55 | 56 | height, width = image.shape[:2] 57 | 58 | if self.renderer is None: 59 | renderer = pyrender.OffscreenRenderer(viewport_width=width, 60 | viewport_height=height, 61 | point_size=1.0) 62 | else: 63 | renderer = self.renderer 64 | 65 | 66 | scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], 67 | ambient_light=(0.5, 0.5, 0.5)) 68 | 69 | if focal is None: 70 | focal = np.sqrt(width**2 + height**2) 71 | 72 | if center is None: 73 | center = [width / 2., height / 2.] 74 | 75 | camera_translation = np.array(camera_translation) # also make a copy 76 | camera_translation[0] *= -1. 77 | 78 | # Create mesh 79 | if len(vertices.shape) == 2: 80 | vertices = vertices[None] 81 | 82 | for vert in vertices: 83 | mesh = trimesh.Trimesh(vert, self.faces, process=False) 84 | rot = trimesh.transformations.rotation_matrix( 85 | np.radians(180), [1, 0, 0]) 86 | 87 | mesh.apply_transform(rot) 88 | mesh = pyrender.Mesh.from_trimesh(mesh, material=self.material, smooth=True) 89 | scene.add(mesh, 'mesh') 90 | 91 | # Create camera 92 | camera_pose = np.eye(4) 93 | camera_pose[:3, 3] = camera_translation 94 | camera = pyrender.IntrinsicsCamera(fx=focal, fy=focal, 95 | cx=center[0], cy=center[1], zfar=1000) 96 | scene.add(camera, pose=camera_pose) 97 | 98 | # Create light 99 | for node in self.light_nodes: scene.add_node(node) 100 | 101 | # Render 102 | color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.NONE) 103 | 104 | 105 | # Composite 106 | if image is None: 107 | output_img = color[:, :, :3] 108 | else: 109 | valid_mask = (rend_depth > 0)[:, :, np.newaxis].astype(np.uint8) 110 | output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image) 111 | 112 | del scene 113 | 114 | if return_depth: 115 | return output_img, rend_depth 116 | else: 117 | return output_img 118 | 119 | 120 | def create_raymond_lights() -> List[pyrender.Node]: 121 | """ 122 | Return raymond light nodes for the scene. 
123 | """ 124 | thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0]) 125 | phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0]) 126 | 127 | nodes = [] 128 | 129 | for phi, theta in zip(phis, thetas): 130 | xp = np.sin(theta) * np.cos(phi) 131 | yp = np.sin(theta) * np.sin(phi) 132 | zp = np.cos(theta) 133 | 134 | z = np.array([xp, yp, zp]) 135 | z = z / np.linalg.norm(z) 136 | x = np.array([-z[1], z[0], 0.0]) 137 | if np.linalg.norm(x) == 0: 138 | x = np.array([1.0, 0.0, 0.0]) 139 | x = x / np.linalg.norm(x) 140 | y = np.cross(z, x) 141 | 142 | matrix = np.eye(4) 143 | matrix[:3,:3] = np.c_[x,y,z] 144 | nodes.append(pyrender.Node( 145 | light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0), 146 | matrix=matrix 147 | )) 148 | 149 | return nodes 150 | -------------------------------------------------------------------------------- /scripts/crop_datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/..') 4 | 5 | import cv2 6 | import random 7 | import numpy as np 8 | from tqdm import tqdm 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | from data_config import ROOT 13 | from lib.datasets.base_dataset import BaseDataset 14 | 15 | 16 | SEED_VALUE = 0 17 | os.environ['PYTHONHASHSEED'] = str(SEED_VALUE) 18 | random.seed(SEED_VALUE) 19 | torch.manual_seed(SEED_VALUE) 20 | np.random.seed(SEED_VALUE) 21 | 22 | # Datasets 23 | # ds_list = ['3dpw_vid', 'h36m_vid', 'bedlam_vid'] 24 | # ds_list = ['3dpw_vid_test', 'emdb_1'] 25 | ds_list = ['bedlam_vid'] 26 | 27 | save_dir = {'h36m_vid': ROOT + '/h36m/crops', 28 | '3dpw_vid': ROOT + '/3dpw/crops', 29 | 'bedlam_vid': ROOT + '/bedlam_30fps/crops', 30 | '3dpw_vid_test': ROOT + '/3dpw/crops_test', 31 | 'emdb_1': ROOT + '/emdb/crops_1'} 32 | 33 | for ds in ds_list: 34 | print(f'Processing (crop) {ds} ...') 35 | 36 | # DATASET 37 | db = BaseDataset(ds, is_train=True, crop_size=256) 38 | loader = DataLoader(db, batch_size=64, num_workers=15, shuffle=False) 39 | 40 | imgdir = save_dir[ds] 41 | os.makedirs(imgdir, exist_ok=True) 42 | 43 | c = 0 44 | for i, batch in enumerate(tqdm(loader)): 45 | images = batch['img'].numpy() 46 | for img in images: 47 | cv2.imwrite(f'{imgdir}/{c:08d}.jpg', img) 48 | c += 1 49 | -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 3 | 4 | # SMPL Neutral model 5 | echo -e "\nYou need to register at https://smplify.is.tue.mpg.de" 6 | read -p "Username (SMPLify):" username 7 | read -p "Password (SMPLify):" password 8 | username=$(urle $username) 9 | password=$(urle $password) 10 | 11 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplify&resume=1&sfile=mpips_smplify_public_v2.zip' -O './data/smpl/smplify.zip' --no-check-certificate --continue 12 | unzip data/smpl/smplify.zip -d data/smpl/smplify 13 | mv data/smpl/smplify/smplify_public/code/models/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl data/smpl/SMPL_NEUTRAL.pkl 14 | rm -rf data/smpl/smplify 15 | rm -rf data/smpl/smplify.zip 16 | 17 | # SMPL Male and Female model 18 | echo -e "\nYou need to register at 
https://smpl.is.tue.mpg.de" 19 | read -p "Username (SMPL):" username 20 | read -p "Password (SMPL):" password 21 | username=$(urle $username) 22 | password=$(urle $password) 23 | 24 | wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smpl&sfile=SMPL_python_v.1.0.0.zip' -O './data/smpl/smpl.zip' --no-check-certificate --continue 25 | unzip data/smpl/smpl.zip -d data/smpl/smpl 26 | mv data/smpl/smpl/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl data/smpl/SMPL_FEMALE.pkl 27 | mv data/smpl/smpl/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl data/smpl/SMPL_MALE.pkl 28 | rm -rf data/smpl/smpl 29 | rm -rf data/smpl/smpl.zip 30 | 31 | # Thirdparty checkpoints 32 | wget -P ./data/pretrain/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/DEVA-propagation.pth 33 | wget -P ./data/pretrain/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 34 | gdown --fuzzy -O ./data/pretrain/droid.pth https://drive.google.com/file/d/1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh/view?usp=sharing 35 | gdown --fuzzy -O ./data/pretrain/camcalib_sa_biased_l2.ckpt https://drive.google.com/file/d/1t4tO0OM5s8XDvAzPW-5HaOkQuV3dHBdO/view?usp=sharing 36 | 37 | # Our checkpoint and an example video 38 | gdown --fuzzy -O ./data/pretrain/vimo_checkpoint.pth.tar https://drive.google.com/file/d/1fdeUxn_hK4ERGFwuksFpV_-_PHZJuoiW/view?usp=share_link 39 | gdown --fuzzy -O ./example_video.mov https://drive.google.com/file/d/1H6gyykajrk2JsBBxBIdt9Z49oKgYAuYJ/view?usp=share_link -------------------------------------------------------------------------------- /scripts/download_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } 3 | 4 | # FOR TRAINING ONLY 5 | # VIBE occlusion augmentation 6 | gdown --fuzzy -O ./data/pascal_occluders.pkl https://drive.google.com/file/d/1_Qv9eAKVkfvZjdl9qaRyxVrIeAnAwevE/view?usp=sharing 7 | 8 | # HMR2b checkpoint 9 | mkdir -p data/pretrain/hmr2b 10 | gdown --fuzzy -O ./data/pretrain/hmr2b/epoch=35-step=1000000.ckpt https://drive.google.com/file/d/1W4fcp8mwS19Rg_A7MoTS1lc7JafqTGu-/view?usp=sharing -------------------------------------------------------------------------------- /scripts/emdb/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | split=2 4 | cam_output_dir="results/emdb/camera" 5 | smpl_output_dir="results/emdb/smpl" 6 | eval_input_dir="results/emdb" 7 | 8 | python scripts/emdb/run_cam.py --split $split --output_dir "$cam_output_dir" 9 | python scripts/emdb/run_smpl.py --split $split --output_dir "$smpl_output_dir" 10 | python scripts/emdb/run_eval.py --split $split --input_dir "$eval_input_dir" 11 | -------------------------------------------------------------------------------- /scripts/emdb/run_cam.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/../..') 4 | 5 | import cv2 6 | import torch 7 | import argparse 8 | import numpy as np 9 | import pickle as pkl 10 | from glob import glob 11 | from tqdm import tqdm 12 | 13 | from lib.camera import run_metric_slam, align_cam_to_world 14 | from lib.pipeline.tools import arrange_boxes 15 | from lib.utils.utils_detectron2 
import DefaultPredictor_Lazy 16 | 17 | from torch.amp import autocast 18 | from segment_anything import SamPredictor, sam_model_registry 19 | from detectron2.config import LazyConfig 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--split', type=int, default=2) 23 | parser.add_argument('--output_dir', type=str, default='results/emdb/camera') 24 | args = parser.parse_args() 25 | 26 | 27 | # EMDB dataset and splits 28 | roots = [] 29 | for p in range(10): 30 | folder = f'/mnt/kostas-graid/datasets/yufu/emdb/P{p}' 31 | root = sorted(glob(f'{folder}/*')) 32 | roots.extend(root) 33 | 34 | emdb = [] 35 | spl = args.split 36 | for root in roots: 37 | annfile = f'{root}/{root.split("/")[-2]}_{root.split("/")[-1]}_data.pkl' 38 | ann = pkl.load(open(annfile, 'rb')) 39 | if ann[f'emdb{spl}']: 40 | emdb.append(root) 41 | 42 | 43 | # Save folder 44 | savefolder = args.output_dir 45 | os.makedirs(savefolder, exist_ok=True) 46 | 47 | # ViTDet 48 | device = 'cuda' 49 | cfg_path = 'data/pretrain/cascade_mask_rcnn_vitdet_h_75ep.py' 50 | detectron2_cfg = LazyConfig.load(str(cfg_path)) 51 | detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl" 52 | for i in range(3): 53 | detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25 54 | detector = DefaultPredictor_Lazy(detectron2_cfg) 55 | 56 | # SAM 57 | sam = sam_model_registry["vit_h"](checkpoint="data/pretrain/sam_vit_h_4b8939.pth") 58 | _ = sam.to(device) 59 | predictor = SamPredictor(sam) 60 | 61 | 62 | # Estimate camera motion on EMDB (subset: spl) 63 | for root in emdb: 64 | print(f'Running on {root}...') 65 | 66 | seq = root.split('/')[-1] 67 | img_folder = f'{root}/images' 68 | imgfiles = sorted(glob(f'{img_folder}/*.jpg')) 69 | 70 | masks_ = [] 71 | for t, imgpath in enumerate(tqdm(imgfiles)): 72 | img_cv2 = cv2.imread(imgpath) 73 | 74 | ### --- Detection --- 75 | with torch.no_grad(): 76 | with autocast('cuda'): 77 | det_out = detector(img_cv2) 78 | det_instances = det_out['instances'] 79 | valid_idx = (det_instances.pred_classes==0) & (det_instances.scores > 0.5) 80 | boxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy() 81 | confs = det_instances.scores[valid_idx].cpu().numpy() 82 | 83 | boxes = np.hstack([boxes, confs[:, None]]) 84 | boxes = arrange_boxes(boxes, mode='size', min_size=100) 85 | 86 | 87 | ### --- SAM --- 88 | if len(boxes)>0: 89 | with autocast('cuda'): 90 | predictor.set_image(img_cv2, image_format='BGR') 91 | 92 | # multiple boxes 93 | bb = torch.tensor(boxes[:, :4]).cuda() 94 | bb = predictor.transform.apply_boxes_torch(bb, img_cv2.shape[:2]) 95 | masks, scores, _ = predictor.predict_torch( 96 | point_coords=None, 97 | point_labels=None, 98 | boxes=bb, 99 | multimask_output=False 100 | ) 101 | scores = scores.cpu() 102 | masks = masks.cpu().squeeze(1) 103 | mask = masks.sum(dim=0) 104 | else: 105 | mask = torch.zeros(img_cv2.shape[:2])  # no person detected in this frame 106 | 107 | masks_.append(mask.byte()) 108 | 109 | masks = torch.stack(masks_) 110 | 111 | 112 | ### --- Camera Motion --- 113 | annfile = f'{root}/{root.split("/")[-2]}_{root.split("/")[-1]}_data.pkl' 114 | ann = pkl.load(open(annfile, 'rb')) 115 | intr = ann['camera']['intrinsics'] 116 | 117 | cam_int = [intr[0,0], intr[1,1], intr[0,2], intr[1,2]] 118 | cam_R, cam_T = run_metric_slam(img_folder, masks=masks, calib=cam_int) 119 | wd_cam_R, wd_cam_T, spec_f = align_cam_to_world(imgfiles[0], cam_R, cam_T) 120 | 121 | camera = {'pred_cam_R': cam_R.numpy(),
'pred_cam_T': cam_T.numpy(), 122 | 'world_cam_R': wd_cam_R.numpy(), 'world_cam_T': wd_cam_T.numpy(), 123 | 'img_focal': cam_int[0], 'img_center': cam_int[2:], 'spec_focal': spec_f} 124 | 125 | 126 | ### --- Save results --- 127 | savefile = f'{savefolder}/{seq}.npz' 128 | np.savez(savefile, **camera) 129 | 130 | -------------------------------------------------------------------------------- /scripts/emdb/run_smpl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/../..') 4 | 5 | import torch 6 | import argparse 7 | import numpy as np 8 | import pickle as pkl 9 | from glob import glob 10 | from tqdm import tqdm 11 | 12 | from torch.utils.data import default_collate 13 | from lib.models import get_hmr_vimo 14 | from lib.datasets.image_dataset import ImageDataset 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--split', type=int, default=2) 18 | parser.add_argument('--output_dir', type=str, default='results/emdb/smpl') 19 | parser.add_argument('--efficient', action='store_true', help='efficient option, but increase ACC error.') 20 | args = parser.parse_args() 21 | 22 | 23 | # EMDB dataset and splits 24 | roots = [] 25 | for p in range(10): 26 | folder = f'/mnt/kostas-graid/datasets/yufu/emdb/P{p}' 27 | root = sorted(glob(f'{folder}/*')) 28 | roots.extend(root) 29 | 30 | emdb = [] 31 | spl = args.split 32 | for root in roots: 33 | annfile = f'{root}/{root.split("/")[-2]}_{root.split("/")[-1]}_data.pkl' 34 | ann = pkl.load(open(annfile, 'rb')) 35 | if ann[f'emdb{spl}']: 36 | emdb.append(root) 37 | 38 | 39 | # Save folder 40 | savefolder = args.output_dir 41 | os.makedirs(savefolder, exist_ok=True) 42 | 43 | # HPS model 44 | device = 'cuda' 45 | model = get_hmr_vimo(checkpoint='data/pretrain/vimo_checkpoint.pth.tar').to(device) 46 | 47 | 48 | # Predict SMPL on EMDB (subset: spl) 49 | for i, root in enumerate(emdb): 50 | print('Running HPS on', root) 51 | 52 | seq = root.split('/')[-1] 53 | imgfiles = sorted(glob(f'{root}/images/*.jpg')) 54 | annfile = f'{root}/{root.split("/")[-2]}_{root.split("/")[-1]}_data.pkl' 55 | ann = pkl.load(open(annfile, 'rb')) 56 | 57 | ext = ann['camera']['extrinsics'] 58 | intr = ann['camera']['intrinsics'] 59 | ann_boxes = ann['bboxes']['bboxes'] 60 | img_focal = (intr[0,0] + intr[1,1]) / 2. 
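    # intr is the 3x3 pinhole intrinsics K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]:
    # the single focal length above is the mean of fx and fy (e.g. fx=1000, fy=996 -> 998.0),
    # and the principal point (cx, cy) below is K[:2, 2].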
61 | img_center = intr[:2, 2] 62 | 63 | db = ImageDataset(imgfiles, ann_boxes, img_focal=img_focal, 64 | img_center=img_center, normalization=True) 65 | dataloader = torch.utils.data.DataLoader(db, batch_size=64, shuffle=False, num_workers=12) 66 | 67 | # Results 68 | pred_cam = [] 69 | pred_pose = [] 70 | pred_shape = [] 71 | pred_rotmat = [] 72 | pred_trans = [] 73 | 74 | ### Efficient option: none-overlap sliding window (higher acc error from 4.5 to 7) 75 | if args.efficient: 76 | for batch in tqdm(dataloader): 77 | batch = {k: v.to(device) for k, v in batch.items() if type(v)==torch.Tensor} 78 | 79 | # Last batch 80 | n = len(batch['img']) 81 | if n < 64: 82 | for k in batch: 83 | batch[k] = torch.cat([previous_batch[k][n-64:], 84 | batch[k]], dim=0) 85 | 86 | with torch.no_grad(): 87 | out, _ = model(batch) 88 | 89 | # Last batch 90 | if n < 64: 91 | for k in out: 92 | out[k] = out[k][64-n:] 93 | 94 | pred_cam.append(out['pred_cam'].cpu()) 95 | pred_pose.append(out['pred_pose'].cpu()) 96 | pred_shape.append(out['pred_shape'].cpu()) 97 | pred_rotmat.append(out['pred_rotmat'].cpu()) 98 | pred_trans.append(out['trans_full'].cpu()) 99 | previous_batch = batch 100 | 101 | ### Maximum overlapping sliding window 102 | else: 103 | items = [] 104 | for i in tqdm(range(len(db))): 105 | item = db[i] 106 | items.append(item) 107 | 108 | if len(items) < 16: 109 | continue 110 | elif len(items) == 16: 111 | batch = default_collate(items) 112 | else: 113 | items.pop(0) 114 | batch = default_collate(items) 115 | 116 | with torch.no_grad(): 117 | batch = {k: v.to(device) for k, v in batch.items() if type(v)==torch.Tensor} 118 | out, _ = model.forward(batch) 119 | 120 | if i == 15: 121 | out = {k:v[:9] for k,v in out.items()} 122 | elif i == len(db) - 1: 123 | out = {k:v[8:] for k,v in out.items()} 124 | else: 125 | out = {k:v[[8]] for k,v in out.items()} 126 | 127 | pred_cam.append(out['pred_cam'].cpu()) 128 | pred_pose.append(out['pred_pose'].cpu()) 129 | pred_shape.append(out['pred_shape'].cpu()) 130 | pred_rotmat.append(out['pred_rotmat'].cpu()) 131 | pred_trans.append(out['trans_full'].cpu()) 132 | 133 | 134 | results = {'pred_cam': torch.cat(pred_cam), 135 | 'pred_pose': torch.cat(pred_pose), 136 | 'pred_shape': torch.cat(pred_shape), 137 | 'pred_rotmat': torch.cat(pred_rotmat), 138 | 'pred_trans': torch.cat(pred_trans), 139 | 'img_focal': img_focal, 140 | 'img_center': img_center} 141 | 142 | np.savez(f'{savefolder}/{seq}.npz', **results) 143 | 144 | 145 | -------------------------------------------------------------------------------- /scripts/estimate_camera.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/..') 4 | 5 | import torch 6 | import argparse 7 | import numpy as np 8 | from glob import glob 9 | from pycocotools import mask as masktool 10 | 11 | from lib.pipeline import video2frames, detect_segment_track, visualize_tram 12 | from lib.camera import run_metric_slam, calibrate_intrinsics, align_cam_to_world 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--video", type=str, default='./example_video.mov', help='input video') 17 | parser.add_argument("--static_camera", action='store_true', help='whether the camera is static') 18 | parser.add_argument("--visualize_mask", action='store_true', help='save deva vos for visualization') 19 | args = parser.parse_args() 20 | 21 | # File and folders 22 | file = args.video 23 | root = os.path.dirname(file) 
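# Derive the sequence name from the video filename (e.g. './example_video.mov' -> 'example_video').
# Everything this script produces goes under results/<seq>/: extracted frames in images/,
# plus camera.npy, boxes.npy, masks.npy and tracks.npy saved at the end.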
24 | seq = os.path.basename(file).split('.')[0] 25 | 26 | seq_folder = f'results/{seq}' 27 | img_folder = f'{seq_folder}/images' 28 | os.makedirs(seq_folder, exist_ok=True) 29 | os.makedirs(img_folder, exist_ok=True) 30 | 31 | ##### Extract Frames ##### 32 | print('Extracting frames ...') 33 | nframes = video2frames(file, img_folder) 34 | 35 | ##### Detection + SAM + DEVA-Track-Anything ##### 36 | print('Detect, Segment, and Track ...') 37 | imgfiles = sorted(glob(f'{img_folder}/*.jpg')) 38 | boxes_, masks_, tracks_ = detect_segment_track(imgfiles, seq_folder, thresh=0.25, 39 | min_size=100, save_vos=args.visualize_mask) 40 | 41 | ##### Run Masked DROID-SLAM ##### 42 | print('Masked Metric SLAM ...') 43 | masks = np.array([masktool.decode(m) for m in masks_]) 44 | masks = torch.from_numpy(masks) 45 | 46 | cam_int, is_static = calibrate_intrinsics(img_folder, masks, is_static=args.static_camera) 47 | cam_R, cam_T = run_metric_slam(img_folder, masks=masks, calib=cam_int, is_static=is_static) 48 | wd_cam_R, wd_cam_T, spec_f = align_cam_to_world(imgfiles[0], cam_R, cam_T) 49 | 50 | camera = {'pred_cam_R': cam_R.numpy(), 'pred_cam_T': cam_T.numpy(), 51 | 'world_cam_R': wd_cam_R.numpy(), 'world_cam_T': wd_cam_T.numpy(), 52 | 'img_focal': cam_int[0], 'img_center': cam_int[2:], 'spec_focal': spec_f} 53 | 54 | np.save(f'{seq_folder}/camera.npy', camera) 55 | np.save(f'{seq_folder}/boxes.npy', boxes_) 56 | np.save(f'{seq_folder}/masks.npy', masks_) 57 | np.save(f'{seq_folder}/tracks.npy', tracks_) 58 | 59 | -------------------------------------------------------------------------------- /scripts/estimate_humans.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/..') 4 | 5 | import torch 6 | import argparse 7 | import numpy as np 8 | from glob import glob 9 | 10 | from lib.models import get_hmr_vimo 11 | from lib.pipeline import visualize_tram 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--video', type=str, default='./example_video.mov', help='input video') 16 | parser.add_argument('--max_humans', type=int, default=20, help='maximum number of humans to reconstruct') 17 | args = parser.parse_args() 18 | 19 | # File and folders 20 | file = args.video 21 | root = os.path.dirname(file) 22 | seq = os.path.basename(file).split('.')[0] 23 | 24 | seq_folder = f'results/{seq}' 25 | img_folder = f'{seq_folder}/images' 26 | hps_folder = f'{seq_folder}/hps' 27 | os.makedirs(hps_folder, exist_ok=True) 28 | 29 | ##### Preprocess results from estimate_camera.py ##### 30 | imgfiles = sorted(glob(f'{img_folder}/*.jpg')) 31 | camera = np.load(f'{seq_folder}/camera.npy', allow_pickle=True).item() 32 | tracks = np.load(f'{seq_folder}/tracks.npy', allow_pickle=True).item() 33 | 34 | img_focal = camera['img_focal'] 35 | img_center = camera['img_center'] 36 | 37 | # Sort the tracks by length 38 | tid = [k for k in tracks.keys()] 39 | lens = [len(trk) for trk in tracks.values()] 40 | rank = np.argsort(lens)[::-1] 41 | tracks = [tracks[tid[r]] for r in rank] 42 | 43 | ##### Run HPS (here we use tram) ##### 44 | print('Estimate HPS ...') 45 | model = get_hmr_vimo(checkpoint='data/pretrain/vimo_checkpoint.pth.tar') 46 | 47 | for k, trk in enumerate(tracks): 48 | valid = np.array([t['det'] for t in trk]) 49 | boxes = np.concatenate([t['det_box'] for t in trk]) 50 | frame = np.array([t['frame'] for t in trk]) 51 | results = model.inference(imgfiles, boxes, valid=valid, frame=frame, 52 
| img_focal=img_focal, img_center=img_center) 53 | 54 | if results is not None: 55 | np.save(f'{hps_folder}/hps_track_{k}.npy', results) 56 | 57 | if k+1 >= args.max_humans: 58 | break 59 | -------------------------------------------------------------------------------- /scripts/extract_bedlam_jpg.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/..') 4 | 5 | import os 6 | import cv2 7 | import numpy as np 8 | from glob import glob 9 | from tqdm import tqdm 10 | import ffmpeg 11 | 12 | from data_config import ROOT 13 | 14 | def mp4_to_jpg(v, frame_folder): 15 | cap = cv2.VideoCapture(v) 16 | count = 0 17 | while(cap.isOpened()): 18 | ret, frame = cap.read() 19 | if ret == True: 20 | cv2.imwrite(f'{frame_folder}/{count:04d}.jpg', frame) 21 | count += 1 22 | else: 23 | break 24 | cap.release() 25 | 26 | def mp4_to_jpg_ffmpeg(v, frame_folder): 27 | ( 28 | ffmpeg.input(v) 29 | .output(f'{frame_folder}/%04d.jpg', 30 | vf='fps=30', 31 | start_number=0, 32 | qscale=1) 33 | .run(quiet=True) 34 | ) 35 | 36 | root = f'{ROOT}/bedlam_30fps' 37 | mp4_scene = sorted(glob(f'{root}/mp4/*')) 38 | 39 | for scene in mp4_scene: 40 | mp4_files = sorted(glob(f'{scene}/mp4/*.mp4')) 41 | for file in tqdm(mp4_files): 42 | s = file.split('/')[-3] 43 | seq = file.split('/')[-1][:-4] 44 | 45 | img_outdir = f'{root}/bedlam_data/images/{s}/jpg/{seq}' 46 | os.makedirs(img_outdir, exist_ok=True) 47 | 48 | # mp4_to_jpg(file, img_outdir) 49 | mp4_to_jpg_ffmpeg(file, img_outdir) 50 | 51 | -------------------------------------------------------------------------------- /scripts/visualize_tram.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.dirname(__file__) + '/..') 4 | 5 | import argparse 6 | import numpy as np 7 | from glob import glob 8 | from lib.pipeline import visualize_tram 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--video', type=str, default='./example_video.mov', help='input video') 12 | parser.add_argument('--bin_size', type=int, default=-1, help='rasterization bin_size; set to [64,128,...] 
to increase speed') 13 | parser.add_argument('--floor_scale', type=int, default=3, help='size of the floor') 14 | args = parser.parse_args() 15 | 16 | # File and folders 17 | file = args.video 18 | root = os.path.dirname(file) 19 | seq = os.path.basename(file).split('.')[0] 20 | 21 | seq_folder = f'results/{seq}' 22 | img_folder = f'{seq_folder}/images' 23 | imgfiles = sorted(glob(f'{img_folder}/*.jpg')) 24 | 25 | ##### Combine camera & human motion ##### 26 | # Render video 27 | print('Visualize results ...') 28 | visualize_tram(seq_folder, floor_scale=args.floor_scale, bin_size=args.bin_size) 29 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/.gitignore: -------------------------------------------------------------------------------- 1 | a# Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | 141 | 142 | __pycache__ 143 | build 144 | dist 145 | *.egg-info 146 | *.vscode/ 147 | *.pth 148 | tests 149 | checkpoints 150 | datasets 151 | runs 152 | cache 153 | *.out 154 | *.o 155 | data 156 | figures/*.pdf 157 | 158 | 159 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/lietorch"] 2 | path = thirdparty/lietorch 3 | url = https://github.com/princeton-vl/lietorch 4 | [submodule "thirdparty/eigen"] 5 | path = thirdparty/eigen 6 | url = https://gitlab.com/libeigen/eigen.git 7 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Princeton Vision & Learning Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/README.md: -------------------------------------------------------------------------------- 1 | # DROID-SLAM 2 | 3 | 4 | 5 | 6 | 7 | [![IMAGE ALT TEXT HERE](misc/screenshot.png)](https://www.youtube.com/watch?v=GG78CSlSHSA) 8 | 9 | 10 | 11 | [DROID-SLAM: Deep Visual SLAM for Monocular, Stereo, and RGB-D Cameras](https://arxiv.org/abs/2108.10869) 12 | Zachary Teed and Jia Deng 13 | 14 | ``` 15 | @article{teed2021droid, 16 | title={{DROID-SLAM: Deep Visual SLAM for Monocular, Stereo, and RGB-D Cameras}}, 17 | author={Teed, Zachary and Deng, Jia}, 18 | journal={Advances in neural information processing systems}, 19 | year={2021} 20 | } 21 | ``` 22 | 23 | **Initial Code Release:** This repo currently provides a single GPU implementation of our monocular, stereo, and RGB-D SLAM systems. It currently contains demos, training, and evaluation scripts. 24 | 25 | 26 | ## Requirements 27 | 28 | To run the code you will need ... 29 | * **Inference:** Running the demos will require a GPU with at least 11G of memory. 30 | 31 | * **Training:** Training requires a GPU with at least 24G of memory. We train on 4 x RTX-3090 GPUs. 32 | 33 | ## Getting Started 34 | 1. Clone the repo using the `--recursive` flag 35 | ```Bash 36 | git clone --recursive https://github.com/princeton-vl/DROID-SLAM.git 37 | ``` 38 | 39 | 2. Creating a new anaconda environment using the provided .yaml file. Use `environment_novis.yaml` to if you do not want to use the visualization 40 | ```Bash 41 | conda env create -f environment.yaml 42 | pip install evo --upgrade --no-binary evo 43 | pip install gdown 44 | ``` 45 | 46 | 3. Compile the extensions (takes about 10 minutes) 47 | ```Bash 48 | python setup.py install 49 | ``` 50 | 51 | 52 | ## Demos 53 | 54 | 1. Download the model from google drive: [droid.pth](https://drive.google.com/file/d/1PpqVt1H4maBa_GbPJp4NwxRsd9jk-elh/view?usp=sharing) 55 | 56 | 2. Download some sample videos using the provided script. 57 | ```Bash 58 | ./tools/download_sample_data.sh 59 | ``` 60 | 61 | Run the demo on any of the samples (all demos can be run on a GPU with 11G of memory). While running, press the "s" key to increase the filtering threshold (= more points) and "a" to decrease the filtering threshold (= fewer points). To save the reconstruction with full resolution depth maps use the `--reconstruction_path` flag. 62 | 63 | 64 | ```Python 65 | python demo.py --imagedir=data/abandonedfactory --calib=calib/tartan.txt --stride=2 66 | ``` 67 | 68 | ```Python 69 | python demo.py --imagedir=data/sfm_bench/rgb --calib=calib/eth.txt 70 | ``` 71 | 72 | ```Python 73 | python demo.py --imagedir=data/Barn --calib=calib/barn.txt --stride=1 --backend_nms=4 74 | ``` 75 | 76 | ```Python 77 | python demo.py --imagedir=data/mav0/cam0/data --calib=calib/euroc.txt --t0=150 78 | ``` 79 | 80 | ```Python 81 | python demo.py --imagedir=data/rgbd_dataset_freiburg3_cabinet/rgb --calib=calib/tum3.txt 82 | ``` 83 | 84 | 85 | **Running on your own data:** All you need is a calibration file. Calibration files are in the form 86 | ``` 87 | fx fy cx cy [k1 k2 p1 p2 [ k3 [ k4 k5 k6 ]]] 88 | ``` 89 | with parameters in brackets optional. 90 | 91 | ## Evaluation 92 | We provide evaluation scripts for TartanAir, EuRoC, and TUM. EuRoC and TUM can be run on a 1080Ti. The TartanAir and ETH will require 24G of memory. 
93 | 94 | ### TartanAir (Mono + Stereo) 95 | Download the [TartanAir](https://theairlab.org/tartanair-dataset/) dataset using the script `thirdparty/tartanair_tools/download_training.py` and put it in `datasets/TartanAir` 96 | ```Bash 97 | ./tools/validate_tartanair.sh --plot_curve # monocular eval 98 | ./tools/validate_tartanair.sh --plot_curve --stereo # stereo eval 99 | ``` 100 | 101 | ### EuRoC (Mono + Stereo) 102 | Download the [EuRoC](https://projects.asl.ethz.ch/datasets/doku.php?id=kmavvisualinertialdatasets) sequences (ASL format) and put them in `datasets/EuRoC` 103 | ```Bash 104 | ./tools/evaluate_euroc.sh # monocular eval 105 | ./tools/evaluate_euroc.sh --stereo # stereo eval 106 | ``` 107 | 108 | ### TUM-RGBD (Mono) 109 | Download the fr1 sequences from [TUM-RGBD](https://vision.in.tum.de/data/datasets/rgbd-dataset/download) and put them in `datasets/TUM-RGBD` 110 | ```Bash 111 | ./tools/evaluate_tum.sh # monocular eval 112 | ``` 113 | 114 | ### ETH3D (RGB-D) 115 | Download the [ETH3D](https://www.eth3d.net/slam_datasets) dataset 116 | ```Bash 117 | ./tools/evaluate_eth3d.sh # RGB-D eval 118 | ``` 119 | 120 | ## Training 121 | 122 | First download the TartanAir dataset. The download script can be found in `thirdparty/tartanair_tools/download_training.py`. You will only need the `rgb` and `depth` data. 123 | 124 | ``` 125 | python download_training.py --rgb --depth 126 | ``` 127 | 128 | You can then run the training script. We use 4 x RTX-3090 GPUs for training, which takes approximately 1 week. If you use a different number of GPUs, adjust the learning rate accordingly. 129 | 130 | **Note:** On the first training run, covisibility is computed between all pairs of frames. This can take several hours, but the results are cached so that future training runs will start immediately. 131 | 132 | 133 | ``` 134 | python train.py --datapath= --gpus=4 --lr=0.00025 135 | ``` 136 | 137 | 138 | ## Acknowledgements 139 | Data from [TartanAir](https://theairlab.org/tartanair-dataset/) was used to train our model. We additionally use evaluation tools from [evo](https://github.com/MichaelGrupp/evo) and [tartanair_tools](https://github.com/castacks/tartanair_tools).
140 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/data_readers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/data_readers/augmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class RGBDAugmentor: 8 | """ perform augmentation on RGB-D video """ 9 | 10 | def __init__(self, crop_size): 11 | self.crop_size = crop_size 12 | self.augcolor = transforms.Compose([ 13 | transforms.ToPILImage(), 14 | transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.4/3.14), 15 | transforms.RandomGrayscale(p=0.1), 16 | transforms.ToTensor()]) 17 | 18 | self.max_scale = 0.25 19 | 20 | def spatial_transform(self, images, depths, poses, intrinsics): 21 | """ cropping and resizing """ 22 | ht, wd = images.shape[2:] 23 | 24 | max_scale = self.max_scale 25 | min_scale = np.log2(np.maximum( 26 | (self.crop_size[0] + 1) / float(ht), 27 | (self.crop_size[1] + 1) / float(wd))) 28 | 29 | scale = 2 ** np.random.uniform(min_scale, max_scale) 30 | intrinsics = scale * intrinsics 31 | depths = depths.unsqueeze(dim=1) 32 | 33 | images = F.interpolate(images, scale_factor=scale, mode='bilinear', 34 | align_corners=False, recompute_scale_factor=True) 35 | 36 | depths = F.interpolate(depths, scale_factor=scale, recompute_scale_factor=True) 37 | 38 | # always perform center crop (TODO: try non-center crops) 39 | y0 = (images.shape[2] - self.crop_size[0]) // 2 40 | x0 = (images.shape[3] - self.crop_size[1]) // 2 41 | 42 | intrinsics = intrinsics - torch.tensor([0.0, 0.0, x0, y0]) 43 | images = images[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 44 | depths = depths[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 45 | 46 | depths = depths.squeeze(dim=1) 47 | return images, poses, depths, intrinsics 48 | 49 | def color_transform(self, images): 50 | """ color jittering """ 51 | num, ch, ht, wd = images.shape 52 | images = images.permute(1, 2, 3, 0).reshape(ch, ht, wd*num) 53 | images = 255 * self.augcolor(images[[2,1,0]] / 255.0) 54 | return images[[2,1,0]].reshape(ch, ht, wd, num).permute(3,0,1,2).contiguous() 55 | 56 | def __call__(self, images, poses, depths, intrinsics): 57 | images = self.color_transform(images) 58 | return self.spatial_transform(images, depths, poses, intrinsics) 59 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/data_readers/factory.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import os 4 | import os.path as osp 5 | 6 | # RGBD-Dataset 7 | from .tartan import TartanAir 8 | 9 | from .stream import ImageStream 10 | from .stream import StereoStream 11 | from .stream import RGBDStream 12 | 13 | # streaming datasets for inference 14 | from .tartan import TartanAirStream 15 | from .tartan import TartanAirTestStream 16 | 17 | def dataset_factory(dataset_list, **kwargs): 18 | """ create a combined dataset """ 19 | 20 | from torch.utils.data import ConcatDataset 21 | 22 | dataset_map = { 'tartan': (TartanAir, ) } 23 | db_list = [] 24 | for key in dataset_list: 25 | # cache datasets for faster future loading 26 | db = 
dataset_map[key][0](**kwargs) 27 | 28 | print("Dataset {} has {} images".format(key, len(db))) 29 | db_list.append(db) 30 | 31 | return ConcatDataset(db_list) 32 | 33 | 34 | def create_datastream(dataset_path, **kwargs): 35 | """ create data_loader to stream images 1 by 1 """ 36 | 37 | from torch.utils.data import DataLoader 38 | 39 | if osp.isfile(osp.join(dataset_path, 'calibration.txt')): 40 | db = ETH3DStream(dataset_path, **kwargs) 41 | 42 | elif osp.isdir(osp.join(dataset_path, 'image_left')): 43 | db = TartanAirStream(dataset_path, **kwargs) 44 | 45 | elif osp.isfile(osp.join(dataset_path, 'rgb.txt')): 46 | db = TUMStream(dataset_path, **kwargs) 47 | 48 | elif osp.isdir(osp.join(dataset_path, 'mav0')): 49 | db = EurocStream(dataset_path, **kwargs) 50 | 51 | elif osp.isfile(osp.join(dataset_path, 'calib.txt')): 52 | db = KITTIStream(dataset_path, **kwargs) 53 | 54 | else: 55 | # db = TartanAirStream(dataset_path, **kwargs) 56 | db = TartanAirTestStream(dataset_path, **kwargs) 57 | 58 | stream = DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 59 | return stream 60 | 61 | 62 | def create_imagestream(dataset_path, **kwargs): 63 | """ create data_loader to stream images 1 by 1 """ 64 | from torch.utils.data import DataLoader 65 | 66 | db = ImageStream(dataset_path, **kwargs) 67 | return DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 68 | 69 | def create_stereostream(dataset_path, **kwargs): 70 | """ create data_loader to stream images 1 by 1 """ 71 | from torch.utils.data import DataLoader 72 | 73 | db = StereoStream(dataset_path, **kwargs) 74 | return DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 75 | 76 | def create_rgbdstream(dataset_path, **kwargs): 77 | """ create data_loader to stream images 1 by 1 """ 78 | from torch.utils.data import DataLoader 79 | 80 | db = RGBDStream(dataset_path, **kwargs) 81 | return DataLoader(db, shuffle=False, batch_size=1, num_workers=4) 82 | 83 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/data_readers/tartan.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import glob 5 | import cv2 6 | import os 7 | import os.path as osp 8 | 9 | from lietorch import SE3 10 | from .base import RGBDDataset 11 | from .stream import RGBDStream 12 | 13 | cur_path = osp.dirname(osp.abspath(__file__)) 14 | test_split = osp.join(cur_path, 'tartan_test.txt') 15 | test_split = open(test_split).read().split() 16 | 17 | 18 | class TartanAir(RGBDDataset): 19 | 20 | # scale depths to balance rot & trans 21 | DEPTH_SCALE = 5.0 22 | 23 | def __init__(self, mode='training', **kwargs): 24 | self.mode = mode 25 | self.n_frames = 2 26 | super(TartanAir, self).__init__(name='TartanAir', **kwargs) 27 | 28 | @staticmethod 29 | def is_test_scene(scene): 30 | # print(scene, any(x in scene for x in test_split)) 31 | return any(x in scene for x in test_split) 32 | 33 | def _build_dataset(self): 34 | from tqdm import tqdm 35 | print("Building TartanAir dataset") 36 | 37 | scene_info = {} 38 | scenes = glob.glob(osp.join(self.root, '*/*/*/*')) 39 | for scene in tqdm(sorted(scenes)): 40 | images = sorted(glob.glob(osp.join(scene, 'image_left/*.png'))) 41 | depths = sorted(glob.glob(osp.join(scene, 'depth_left/*.npy'))) 42 | 43 | poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ') 44 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]] 45 | poses[:,:3] /= TartanAir.DEPTH_SCALE 46 | intrinsics = 
[TartanAir.calib_read()] * len(images) 47 | 48 | # graph of co-visible frames based on flow 49 | graph = self.build_frame_graph(poses, depths, intrinsics) 50 | 51 | scene = '/'.join(scene.split('/')) 52 | scene_info[scene] = {'images': images, 'depths': depths, 53 | 'poses': poses, 'intrinsics': intrinsics, 'graph': graph} 54 | 55 | return scene_info 56 | 57 | @staticmethod 58 | def calib_read(): 59 | return np.array([320.0, 320.0, 320.0, 240.0]) 60 | 61 | @staticmethod 62 | def image_read(image_file): 63 | return cv2.imread(image_file) 64 | 65 | @staticmethod 66 | def depth_read(depth_file): 67 | depth = np.load(depth_file) / TartanAir.DEPTH_SCALE 68 | depth[depth==np.nan] = 1.0 69 | depth[depth==np.inf] = 1.0 70 | return depth 71 | 72 | 73 | class TartanAirStream(RGBDStream): 74 | def __init__(self, datapath, **kwargs): 75 | super(TartanAirStream, self).__init__(datapath=datapath, **kwargs) 76 | 77 | def _build_dataset_index(self): 78 | """ build list of images, poses, depths, and intrinsics """ 79 | self.root = 'datasets/TartanAir' 80 | 81 | scene = osp.join(self.root, self.datapath) 82 | image_glob = osp.join(scene, 'image_left/*.png') 83 | images = sorted(glob.glob(image_glob)) 84 | 85 | poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ') 86 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]] 87 | 88 | poses = SE3(torch.as_tensor(poses)) 89 | poses = poses[[0]].inv() * poses 90 | poses = poses.data.cpu().numpy() 91 | 92 | intrinsic = self.calib_read(self.datapath) 93 | intrinsics = np.tile(intrinsic[None], (len(images), 1)) 94 | 95 | self.images = images[::int(self.frame_rate)] 96 | self.poses = poses[::int(self.frame_rate)] 97 | self.intrinsics = intrinsics[::int(self.frame_rate)] 98 | 99 | @staticmethod 100 | def calib_read(datapath): 101 | return np.array([320.0, 320.0, 320.0, 240.0]) 102 | 103 | @staticmethod 104 | def image_read(image_file): 105 | return cv2.imread(image_file) 106 | 107 | 108 | class TartanAirTestStream(RGBDStream): 109 | def __init__(self, datapath, **kwargs): 110 | super(TartanAirTestStream, self).__init__(datapath=datapath, **kwargs) 111 | 112 | def _build_dataset_index(self): 113 | """ build list of images, poses, depths, and intrinsics """ 114 | self.root = 'datasets/mono' 115 | image_glob = osp.join(self.root, self.datapath, '*.png') 116 | images = sorted(glob.glob(image_glob)) 117 | 118 | poses = np.loadtxt(osp.join(self.root, 'mono_gt', self.datapath + '.txt'), delimiter=' ') 119 | poses = poses[:, [1, 2, 0, 4, 5, 3, 6]] 120 | 121 | poses = SE3(torch.as_tensor(poses)) 122 | poses = poses[[0]].inv() * poses 123 | poses = poses.data.cpu().numpy() 124 | 125 | intrinsic = self.calib_read(self.datapath) 126 | intrinsics = np.tile(intrinsic[None], (len(images), 1)) 127 | 128 | self.images = images[::int(self.frame_rate)] 129 | self.poses = poses[::int(self.frame_rate)] 130 | self.intrinsics = intrinsics[::int(self.frame_rate)] 131 | 132 | @staticmethod 133 | def calib_read(datapath): 134 | return np.array([320.0, 320.0, 320.0, 240.0]) 135 | 136 | @staticmethod 137 | def image_read(image_file): 138 | return cv2.imread(image_file) -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/data_readers/tartan_test.txt: -------------------------------------------------------------------------------- 1 | abandonedfactory/abandonedfactory/Easy/P011 2 | abandonedfactory/abandonedfactory/Hard/P011 3 | abandonedfactory_night/abandonedfactory_night/Easy/P013 4 | 
abandonedfactory_night/abandonedfactory_night/Hard/P014 5 | amusement/amusement/Easy/P008 6 | amusement/amusement/Hard/P007 7 | carwelding/carwelding/Easy/P007 8 | endofworld/endofworld/Easy/P009 9 | gascola/gascola/Easy/P008 10 | gascola/gascola/Hard/P009 11 | hospital/hospital/Easy/P036 12 | hospital/hospital/Hard/P049 13 | japanesealley/japanesealley/Easy/P007 14 | japanesealley/japanesealley/Hard/P005 15 | neighborhood/neighborhood/Easy/P021 16 | neighborhood/neighborhood/Hard/P017 17 | ocean/ocean/Easy/P013 18 | ocean/ocean/Hard/P009 19 | office2/office2/Easy/P011 20 | office2/office2/Hard/P010 21 | office/office/Hard/P007 22 | oldtown/oldtown/Easy/P007 23 | oldtown/oldtown/Hard/P008 24 | seasidetown/seasidetown/Easy/P009 25 | seasonsforest/seasonsforest/Easy/P011 26 | seasonsforest/seasonsforest/Hard/P006 27 | seasonsforest_winter/seasonsforest_winter/Easy/P009 28 | seasonsforest_winter/seasonsforest_winter/Hard/P018 29 | soulcity/soulcity/Easy/P012 30 | soulcity/soulcity/Hard/P009 31 | westerndesert/westerndesert/Easy/P013 32 | westerndesert/westerndesert/Hard/P007 33 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/droid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from droid_net import DroidNet 6 | from depth_video import DepthVideo 7 | from motion_filter import MotionFilter 8 | from droid_frontend import DroidFrontend 9 | from droid_backend import DroidBackend 10 | from trajectory_filler import PoseTrajectoryFiller 11 | 12 | from collections import OrderedDict 13 | from torch.multiprocessing import Process 14 | 15 | 16 | class Droid: 17 | def __init__(self, args): 18 | super(Droid, self).__init__() 19 | self.load_weights(args.weights) 20 | self.args = args 21 | self.disable_vis = args.disable_vis 22 | 23 | # store images, depth, poses, intrinsics (shared between processes) 24 | self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo) 25 | 26 | # filter incoming frames so that there is enough motion 27 | self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh) 28 | 29 | # frontend process 30 | self.frontend = DroidFrontend(self.net, self.video, self.args) 31 | 32 | # backend process 33 | self.backend = DroidBackend(self.net, self.video, self.args) 34 | 35 | # visualizer 36 | if not self.disable_vis: 37 | # from visualization import droid_visualization 38 | from vis_headless import droid_visualization 39 | print('Using headless ...') 40 | self.visualizer = Process(target=droid_visualization, args=(self.video, '.')) 41 | self.visualizer.start() 42 | 43 | # post processor - fill in poses for non-keyframes 44 | self.traj_filler = PoseTrajectoryFiller(self.net, self.video) 45 | 46 | 47 | def load_weights(self, weights): 48 | """ load trained model weights """ 49 | 50 | self.net = DroidNet() 51 | state_dict = OrderedDict([ 52 | (k.replace("module.", ""), v) for (k, v) in torch.load(weights).items()]) 53 | 54 | state_dict["update.weight.2.weight"] = state_dict["update.weight.2.weight"][:2] 55 | state_dict["update.weight.2.bias"] = state_dict["update.weight.2.bias"][:2] 56 | state_dict["update.delta.2.weight"] = state_dict["update.delta.2.weight"][:2] 57 | state_dict["update.delta.2.bias"] = state_dict["update.delta.2.bias"][:2] 58 | 59 | self.net.load_state_dict(state_dict) 60 | self.net.to("cuda:0").eval() 61 | 62 | def track(self, tstamp, image, depth=None, intrinsics=None, 
mask=None): 63 | """ main thread - update map """ 64 | 65 | with torch.no_grad(): 66 | # check there is enough motion 67 | self.filterx.track(tstamp, image, depth, intrinsics, mask) 68 | 69 | # local bundle adjustment 70 | self.frontend() 71 | 72 | # global bundle adjustment 73 | # self.backend() 74 | 75 | def terminate(self, stream=None, backend=True): 76 | """ terminate the visualization process, return poses [t, q] """ 77 | 78 | del self.frontend 79 | 80 | if backend: 81 | torch.cuda.empty_cache() 82 | # print("#" * 32) 83 | self.backend(7) 84 | 85 | torch.cuda.empty_cache() 86 | # print("#" * 32) 87 | self.backend(12) 88 | 89 | camera_trajectory = self.traj_filler(stream) 90 | return camera_trajectory.inv().data.cpu().numpy() 91 | 92 | def compute_error(self): 93 | """ compute slam reprojection error """ 94 | 95 | del self.frontend 96 | 97 | torch.cuda.empty_cache() 98 | self.backend(12) 99 | 100 | return self.backend.errors[-1] 101 | 102 | 103 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/droid_backend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from lietorch import SE3 6 | from factor_graph import FactorGraph 7 | 8 | 9 | class DroidBackend: 10 | def __init__(self, net, video, args): 11 | self.video = video 12 | self.update_op = net.update 13 | 14 | # global optimization window 15 | self.t0 = 0 16 | self.t1 = 0 17 | 18 | self.upsample = args.upsample 19 | self.beta = args.beta 20 | self.backend_thresh = args.backend_thresh 21 | self.backend_radius = args.backend_radius 22 | self.backend_nms = args.backend_nms 23 | self.errors = [] 24 | 25 | @torch.no_grad() 26 | def __call__(self, steps=12): 27 | """ main update """ 28 | 29 | t = self.video.counter.value 30 | if not self.video.stereo and not torch.any(self.video.disps_sens): 31 | self.video.normalize() 32 | 33 | graph = FactorGraph(self.video, self.update_op, corr_impl="alt", max_factors=16*t, upsample=self.upsample) 34 | 35 | graph.add_proximity_factors(rad=self.backend_radius, 36 | nms=self.backend_nms, 37 | thresh=self.backend_thresh, 38 | beta=self.beta) 39 | 40 | graph.update_lowmem(steps=steps) 41 | self.errors.append(self.cal_err(graph)) 42 | graph.clear_edges() 43 | self.video.dirty[:t] = True 44 | 45 | return 46 | 47 | def cal_err(self, graph): 48 | coord, _ = graph.video.reproject(graph.ii, graph.jj) 49 | diff = graph.target - coord 50 | err = diff.norm(dim=-1).mean().item() 51 | return err 52 | 53 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/droid_frontend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lietorch 3 | import numpy as np 4 | 5 | from lietorch import SE3 6 | from factor_graph import FactorGraph 7 | 8 | 9 | class DroidFrontend: 10 | def __init__(self, net, video, args): 11 | self.video = video 12 | self.update_op = net.update 13 | self.graph = FactorGraph(video, net.update, max_factors=48, upsample=args.upsample) 14 | 15 | # local optimization window 16 | self.t0 = 0 17 | self.t1 = 0 18 | 19 | # frontent variables 20 | self.is_initialized = False 21 | self.count = 0 22 | 23 | self.max_age = 25 24 | self.iters1 = 4 25 | self.iters2 = 2 26 | 27 | self.warmup = args.warmup 28 | self.beta = args.beta 29 | self.frontend_nms = args.frontend_nms 30 | self.keyframe_thresh = args.keyframe_thresh 31 | 
self.frontend_window = args.frontend_window 32 | self.frontend_thresh = args.frontend_thresh 33 | self.frontend_radius = args.frontend_radius 34 | 35 | def __update(self): 36 | """ add edges, perform update """ 37 | 38 | self.count += 1 39 | self.t1 += 1 40 | 41 | if self.graph.corr is not None: 42 | self.graph.rm_factors(self.graph.age > self.max_age, store=True) 43 | 44 | self.graph.add_proximity_factors(self.t1-5, max(self.t1-self.frontend_window, 0), 45 | rad=self.frontend_radius, nms=self.frontend_nms, thresh=self.frontend_thresh, beta=self.beta, remove=True) 46 | 47 | self.video.disps[self.t1-1] = torch.where(self.video.disps_sens[self.t1-1] > 0, 48 | self.video.disps_sens[self.t1-1], self.video.disps[self.t1-1]) 49 | 50 | for itr in range(self.iters1): 51 | self.graph.update(None, None, use_inactive=True) 52 | 53 | # set initial pose for next frame 54 | poses = SE3(self.video.poses) 55 | d = self.video.distance([self.t1-3], [self.t1-2], beta=self.beta, bidirectional=True) 56 | 57 | if d.item() < self.keyframe_thresh: 58 | self.graph.rm_keyframe(self.t1 - 2) 59 | 60 | with self.video.get_lock(): 61 | self.video.counter.value -= 1 62 | self.t1 -= 1 63 | 64 | else: 65 | for itr in range(self.iters2): 66 | self.graph.update(None, None, use_inactive=True) 67 | 68 | # set pose for next itration 69 | self.video.poses[self.t1] = self.video.poses[self.t1-1] 70 | self.video.disps[self.t1] = self.video.disps[self.t1-1].mean() 71 | 72 | # update visualization 73 | self.video.dirty[self.graph.ii.min():self.t1] = True 74 | 75 | def __initialize(self): 76 | """ initialize the SLAM system """ 77 | 78 | self.t0 = 0 79 | self.t1 = self.video.counter.value 80 | 81 | self.graph.add_neighborhood_factors(self.t0, self.t1, r=3) 82 | 83 | for itr in range(8): 84 | self.graph.update(1, use_inactive=True) 85 | 86 | self.graph.add_proximity_factors(0, 0, rad=2, nms=2, thresh=self.frontend_thresh, remove=False) 87 | 88 | for itr in range(8): 89 | self.graph.update(1, use_inactive=True) 90 | 91 | 92 | # self.video.normalize() 93 | self.video.poses[self.t1] = self.video.poses[self.t1-1].clone() 94 | self.video.disps[self.t1] = self.video.disps[self.t1-4:self.t1].mean() 95 | 96 | # initialization complete 97 | self.is_initialized = True 98 | self.last_pose = self.video.poses[self.t1-1].clone() 99 | self.last_disp = self.video.disps[self.t1-1].clone() 100 | self.last_time = self.video.tstamp[self.t1-1].clone() 101 | 102 | with self.video.get_lock(): 103 | self.video.ready.value = 1 104 | self.video.dirty[:self.t1] = True 105 | 106 | self.graph.rm_factors(self.graph.ii < self.warmup-4, store=True) 107 | 108 | def __call__(self): 109 | """ main update """ 110 | 111 | # do initialization 112 | if not self.is_initialized and self.video.counter.value == self.warmup: 113 | self.__initialize() 114 | 115 | # do update 116 | elif self.is_initialized and self.t1 < self.video.counter.value: 117 | self.__update() 118 | 119 | 120 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/geom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/thirdparty/DROID-SLAM/droid_slam/geom/__init__.py -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/geom/ba.py: -------------------------------------------------------------------------------- 1 | import lietorch 2 | import torch 3 | 
import torch.nn.functional as F 4 | 5 | from .chol import block_solve, schur_solve 6 | import geom.projective_ops as pops 7 | 8 | from torch_scatter import scatter_sum 9 | 10 | 11 | # utility functions for scattering ops 12 | def safe_scatter_add_mat(A, ii, jj, n, m): 13 | v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m) 14 | return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m) 15 | 16 | def safe_scatter_add_vec(b, ii, n): 17 | v = (ii >= 0) & (ii < n) 18 | return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n) 19 | 20 | # apply retraction operator to inv-depth maps 21 | def disp_retr(disps, dz, ii): 22 | ii = ii.to(device=dz.device) 23 | return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1]) 24 | 25 | # apply retraction operator to poses 26 | def pose_retr(poses, dx, ii): 27 | ii = ii.to(device=dx.device) 28 | return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1])) 29 | 30 | 31 | def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): 32 | """ Full Bundle Adjustment """ 33 | 34 | B, P, ht, wd = disps.shape 35 | N = ii.shape[0] 36 | D = poses.manifold_dim 37 | 38 | ### 1: commpute jacobians and residuals ### 39 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform( 40 | poses, disps, intrinsics, ii, jj, jacobian=True) 41 | 42 | r = (target - coords).view(B, N, -1, 1) 43 | w = .001 * (valid * weight).view(B, N, -1, 1) 44 | 45 | ### 2: construct linear system ### 46 | Ji = Ji.reshape(B, N, -1, D) 47 | Jj = Jj.reshape(B, N, -1, D) 48 | wJiT = (w * Ji).transpose(2,3) 49 | wJjT = (w * Jj).transpose(2,3) 50 | 51 | Jz = Jz.reshape(B, N, ht*wd, -1) 52 | 53 | Hii = torch.matmul(wJiT, Ji) 54 | Hij = torch.matmul(wJiT, Jj) 55 | Hji = torch.matmul(wJjT, Ji) 56 | Hjj = torch.matmul(wJjT, Jj) 57 | 58 | vi = torch.matmul(wJiT, r).squeeze(-1) 59 | vj = torch.matmul(wJjT, r).squeeze(-1) 60 | 61 | Ei = (wJiT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) 62 | Ej = (wJjT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1) 63 | 64 | w = w.view(B, N, ht*wd, -1) 65 | r = r.view(B, N, ht*wd, -1) 66 | wk = torch.sum(w*r*Jz, dim=-1) 67 | Ck = torch.sum(w*Jz*Jz, dim=-1) 68 | 69 | kx, kk = torch.unique(ii, return_inverse=True) 70 | M = kx.shape[0] 71 | 72 | # only optimize keyframe poses 73 | P = P // rig - fixedp 74 | ii = ii // rig - fixedp 75 | jj = jj // rig - fixedp 76 | 77 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ 78 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \ 79 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \ 80 | safe_scatter_add_mat(Hjj, jj, jj, P, P) 81 | 82 | E = safe_scatter_add_mat(Ei, ii, kk, P, M) + \ 83 | safe_scatter_add_mat(Ej, jj, kk, P, M) 84 | 85 | v = safe_scatter_add_vec(vi, ii, P) + \ 86 | safe_scatter_add_vec(vj, jj, P) 87 | 88 | C = safe_scatter_add_vec(Ck, kk, M) 89 | w = safe_scatter_add_vec(wk, kk, M) 90 | 91 | C = C + eta.view(*C.shape) + 1e-7 92 | 93 | H = H.view(B, P, P, D, D) 94 | E = E.view(B, P, M, D, ht*wd) 95 | 96 | ### 3: solve the system ### 97 | dx, dz = schur_solve(H, E, C, v, w) 98 | 99 | ### 4: apply retraction ### 100 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp) 101 | disps = disp_retr(disps, dz.view(B,-1,ht,wd), kx) 102 | 103 | disps = torch.where(disps > 10, torch.zeros_like(disps), disps) 104 | disps = disps.clamp(min=0.0) 105 | 106 | return poses, disps 107 | 108 | 109 | def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1): 110 | """ Motion only bundle adjustment """ 111 | 112 | B, P, ht, wd = disps.shape 113 | N = ii.shape[0] 114 | D = poses.manifold_dim 115 
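# Both BA above and MoBA below assemble the same damped Gauss-Newton step;
# in block form (a sketch of the math, using the code's own names) it solves
#
#     [ H    E ] [ dx ]   [ v ]        r = target - coords
#     [ E^T  C ] [ dz ] = [ w ],       W = 0.001 * valid * weight
#
# where H = J_pose^T W J_pose is the pose-pose block, C = J_z^T W J_z + eta is
# diagonal over per-pixel inverse depths, E is the pose-depth coupling,
# v = J_pose^T W r and w = J_z^T W r.  MoBA drops the depth variables and
# solves only H dx = v via block_solve.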
| 116 | ### 1: commpute jacobians and residuals ### 117 | coords, valid, (Ji, Jj, Jz) = pops.projective_transform( 118 | poses, disps, intrinsics, ii, jj, jacobian=True) 119 | 120 | r = (target - coords).view(B, N, -1, 1) 121 | w = .001 * (valid * weight).view(B, N, -1, 1) 122 | 123 | ### 2: construct linear system ### 124 | Ji = Ji.reshape(B, N, -1, D) 125 | Jj = Jj.reshape(B, N, -1, D) 126 | wJiT = (w * Ji).transpose(2,3) 127 | wJjT = (w * Jj).transpose(2,3) 128 | 129 | Hii = torch.matmul(wJiT, Ji) 130 | Hij = torch.matmul(wJiT, Jj) 131 | Hji = torch.matmul(wJjT, Ji) 132 | Hjj = torch.matmul(wJjT, Jj) 133 | 134 | vi = torch.matmul(wJiT, r).squeeze(-1) 135 | vj = torch.matmul(wJjT, r).squeeze(-1) 136 | 137 | # only optimize keyframe poses 138 | P = P // rig - fixedp 139 | ii = ii // rig - fixedp 140 | jj = jj // rig - fixedp 141 | 142 | H = safe_scatter_add_mat(Hii, ii, ii, P, P) + \ 143 | safe_scatter_add_mat(Hij, ii, jj, P, P) + \ 144 | safe_scatter_add_mat(Hji, jj, ii, P, P) + \ 145 | safe_scatter_add_mat(Hjj, jj, jj, P, P) 146 | 147 | v = safe_scatter_add_vec(vi, ii, P) + \ 148 | safe_scatter_add_vec(vj, jj, P) 149 | 150 | H = H.view(B, P, P, D, D) 151 | 152 | ### 3: solve the system ### 153 | dx = block_solve(H, v) 154 | 155 | ### 4: apply retraction ### 156 | poses = pose_retr(poses, dx, torch.arange(P) + fixedp) 157 | return poses 158 | 159 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/geom/chol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import geom.projective_ops as pops 4 | 5 | class CholeskySolver(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, H, b): 8 | # don't crash training if cholesky decomp fails 9 | try: 10 | U = torch.linalg.cholesky(H) 11 | xs = torch.cholesky_solve(b, U) 12 | ctx.save_for_backward(U, xs) 13 | ctx.failed = False 14 | except Exception as e: 15 | print(e) 16 | ctx.failed = True 17 | xs = torch.zeros_like(b) 18 | 19 | return xs 20 | 21 | @staticmethod 22 | def backward(ctx, grad_x): 23 | if ctx.failed: 24 | return None, None 25 | 26 | U, xs = ctx.saved_tensors 27 | dz = torch.cholesky_solve(grad_x, U) 28 | dH = -torch.matmul(xs, dz.transpose(-1,-2)) 29 | 30 | return dH, dz 31 | 32 | def block_solve(H, b, ep=0.1, lm=0.0001): 33 | """ solve normal equations """ 34 | B, N, _, D, _ = H.shape 35 | I = torch.eye(D).to(H.device) 36 | H = H + (ep + lm*H) * I 37 | 38 | H = H.permute(0,1,3,2,4) 39 | H = H.reshape(B, N*D, N*D) 40 | b = b.reshape(B, N*D, 1) 41 | 42 | x = CholeskySolver.apply(H,b) 43 | return x.reshape(B, N, D) 44 | 45 | 46 | def schur_solve(H, E, C, v, w, ep=0.1, lm=0.0001, sless=False): 47 | """ solve using shur complement """ 48 | 49 | B, P, M, D, HW = E.shape 50 | H = H.permute(0,1,3,2,4).reshape(B, P*D, P*D) 51 | E = E.permute(0,1,3,2,4).reshape(B, P*D, M*HW) 52 | Q = (1.0 / C).view(B, M*HW, 1) 53 | 54 | # damping 55 | I = torch.eye(P*D).to(H.device) 56 | H = H + (ep + lm*H) * I 57 | 58 | v = v.reshape(B, P*D, 1) 59 | w = w.reshape(B, M*HW, 1) 60 | 61 | Et = E.transpose(1,2) 62 | S = H - torch.matmul(E, Q*Et) 63 | v = v - torch.matmul(E, Q*w) 64 | 65 | dx = CholeskySolver.apply(S, v) 66 | if sless: 67 | return dx.reshape(B, P, D) 68 | 69 | dz = Q * (w - Et @ dx) 70 | dx = dx.reshape(B, P, D) 71 | dz = dz.reshape(B, M, HW) 72 | 73 | return dx, dz -------------------------------------------------------------------------------- 
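In geom/chol.py above, schur_solve eliminates the diagonal depth block of the bundle-adjustment system assembled in geom/ba.py via the Schur complement. Written out (a sketch consistent with the code, where $Q = C^{-1}$):

$$ S = H - E\,C^{-1}E^{\top}, \qquad S\,\Delta\xi = v - E\,C^{-1}w, \qquad \Delta z = C^{-1}\bigl(w - E^{\top}\Delta\xi\bigr), $$

with the damping $H \leftarrow H + (\epsilon + \lambda H)\,I$ applied first and the reduced pose system factored by CholeskySolver; BA then retracts $\Delta\xi$ onto the poses and $\Delta z$ onto the inverse-depth maps.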
/thirdparty/DROID-SLAM/droid_slam/geom/graph_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import lietorch 7 | from data_readers.rgbd_utils import compute_distance_matrix_flow, compute_distance_matrix_flow2 8 | 9 | 10 | def graph_to_edge_list(graph): 11 | ii, jj, kk = [], [], [] 12 | for s, u in enumerate(graph): 13 | for v in graph[u]: 14 | ii.append(u) 15 | jj.append(v) 16 | kk.append(s) 17 | 18 | ii = torch.as_tensor(ii) 19 | jj = torch.as_tensor(jj) 20 | kk = torch.as_tensor(kk) 21 | return ii, jj, kk 22 | 23 | def keyframe_indicies(graph): 24 | return torch.as_tensor([u for u in graph]) 25 | 26 | def meshgrid(m, n, device='cuda'): 27 | ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n), indexing='ij') 28 | return ii.reshape(-1).to(device), jj.reshape(-1).to(device) 29 | 30 | def neighbourhood_graph(n, r): 31 | ii, jj = meshgrid(n, n) 32 | d = (ii - jj).abs() 33 | keep = (d >= 1) & (d <= r) 34 | return ii[keep], jj[keep] 35 | 36 | 37 | def build_frame_graph(poses, disps, intrinsics, num=16, thresh=24.0, r=2): 38 | """ construct a frame graph between co-visible frames """ 39 | N = poses.shape[1] 40 | poses = poses[0].cpu().numpy() 41 | disps = disps[0][:,3::8,3::8].cpu().numpy() 42 | intrinsics = intrinsics[0].cpu().numpy() / 8.0 43 | d = compute_distance_matrix_flow(poses, disps, intrinsics) 44 | 45 | count = 0 46 | graph = OrderedDict() 47 | 48 | for i in range(N): 49 | graph[i] = [] 50 | d[i,i] = np.inf 51 | for j in range(i-r, i+r+1): 52 | if 0 <= j < N and i != j: 53 | graph[i].append(j) 54 | d[i,j] = np.inf 55 | count += 1 56 | 57 | while count < num: 58 | ix = np.argmin(d) 59 | i, j = ix // N, ix % N 60 | 61 | if d[i,j] < thresh: 62 | graph[i].append(j) 63 | d[i,j] = np.inf 64 | count += 1 65 | else: 66 | break 67 | 68 | return graph 69 | 70 | 71 | 72 | def build_frame_graph_v2(poses, disps, intrinsics, num=16, thresh=24.0, r=2): 73 | """ construct a frame graph between co-visible frames """ 74 | N = poses.shape[1] 75 | # poses = poses[0].cpu().numpy() 76 | # disps = disps[0].cpu().numpy() 77 | # intrinsics = intrinsics[0].cpu().numpy() 78 | d = compute_distance_matrix_flow2(poses, disps, intrinsics) 79 | 80 | # import matplotlib.pyplot as plt 81 | # plt.imshow(d) 82 | # plt.show() 83 | 84 | count = 0 85 | graph = OrderedDict() 86 | 87 | for i in range(N): 88 | graph[i] = [] 89 | d[i,i] = np.inf 90 | for j in range(i-r, i+r+1): 91 | if 0 <= j < N and i != j: 92 | graph[i].append(j) 93 | d[i,j] = np.inf 94 | count += 1 95 | 96 | while 1: 97 | ix = np.argmin(d) 98 | i, j = ix // N, ix % N 99 | 100 | if d[i,j] < thresh: 101 | graph[i].append(j) 102 | 103 | for i1 in range(i-1, i+2): 104 | for j1 in range(j-1, j+2): 105 | if 0 <= i1 < N and 0 <= j1 < N: 106 | d[i1, j1] = np.inf 107 | 108 | count += 1 109 | else: 110 | break 111 | 112 | return graph 113 | 114 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/geom/losses.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | import torch 4 | from lietorch import SO3, SE3, Sim3 5 | from .graph_utils import graph_to_edge_list 6 | from .projective_ops import projective_transform 7 | 8 | 9 | def pose_metrics(dE): 10 | """ Translation/Rotation/Scaling metrics from Sim3 """ 11 | t, q, s = dE.data.split([3, 4, 1], -1) 12 | ang = 
SO3(q).log().norm(dim=-1) 13 | 14 | # convert radians to degrees 15 | r_err = (180 / np.pi) * ang 16 | t_err = t.norm(dim=-1) 17 | s_err = (s - 1.0).abs() 18 | return r_err, t_err, s_err 19 | 20 | 21 | def fit_scale(Ps, Gs): 22 | b = Ps.shape[0] 23 | t1 = Ps.data[...,:3].detach().reshape(b, -1) 24 | t2 = Gs.data[...,:3].detach().reshape(b, -1) 25 | 26 | s = (t1*t2).sum(-1) / ((t2*t2).sum(-1) + 1e-8) 27 | return s 28 | 29 | 30 | def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True): 31 | """ Loss function for training network """ 32 | 33 | # relative pose 34 | ii, jj, kk = graph_to_edge_list(graph) 35 | dP = Ps[:,jj] * Ps[:,ii].inv() 36 | 37 | n = len(Gs) 38 | geodesic_loss = 0.0 39 | 40 | for i in range(n): 41 | w = gamma ** (n - i - 1) 42 | dG = Gs[i][:,jj] * Gs[i][:,ii].inv() 43 | 44 | if do_scale: 45 | s = fit_scale(dP, dG) 46 | dG = dG.scale(s[:,None]) 47 | 48 | # pose error 49 | d = (dG * dP.inv()).log() 50 | 51 | if isinstance(dG, SE3): 52 | tau, phi = d.split([3,3], dim=-1) 53 | geodesic_loss += w * ( 54 | tau.norm(dim=-1).mean() + 55 | phi.norm(dim=-1).mean()) 56 | 57 | elif isinstance(dG, Sim3): 58 | tau, phi, sig = d.split([3,3,1], dim=-1) 59 | geodesic_loss += w * ( 60 | tau.norm(dim=-1).mean() + 61 | phi.norm(dim=-1).mean() + 62 | 0.05 * sig.norm(dim=-1).mean()) 63 | 64 | dE = Sim3(dG * dP.inv()).detach() 65 | r_err, t_err, s_err = pose_metrics(dE) 66 | 67 | metrics = { 68 | 'rot_error': r_err.mean().item(), 69 | 'tr_error': t_err.mean().item(), 70 | 'bad_rot': (r_err < .1).float().mean().item(), 71 | 'bad_tr': (t_err < .01).float().mean().item(), 72 | } 73 | 74 | return geodesic_loss, metrics 75 | 76 | 77 | def residual_loss(residuals, gamma=0.9): 78 | """ loss on system residuals """ 79 | residual_loss = 0.0 80 | n = len(residuals) 81 | 82 | for i in range(n): 83 | w = gamma ** (n - i - 1) 84 | residual_loss += w * residuals[i].abs().mean() 85 | 86 | return residual_loss, {'residual': residual_loss.item()} 87 | 88 | 89 | def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9): 90 | """ optical flow loss """ 91 | 92 | N = Ps.shape[1] 93 | graph = OrderedDict() 94 | for i in range(N): 95 | graph[i] = [j for j in range(N) if abs(i-j)==1] 96 | 97 | ii, jj, kk = graph_to_edge_list(graph) 98 | coords0, val0 = projective_transform(Ps, disps, intrinsics, ii, jj) 99 | val0 = val0 * (disps[:,ii] > 0).float().unsqueeze(dim=-1) 100 | 101 | n = len(poses_est) 102 | flow_loss = 0.0 103 | 104 | for i in range(n): 105 | w = gamma ** (n - i - 1) 106 | coords1, val1 = projective_transform(poses_est[i], disps_est[i], intrinsics, ii, jj) 107 | 108 | v = (val0 * val1).squeeze(dim=-1) 109 | epe = v * (coords1 - coords0).norm(dim=-1) 110 | flow_loss += w * epe.mean() 111 | 112 | epe = epe.reshape(-1)[v.reshape(-1) > 0.5] 113 | metrics = { 114 | 'f_error': epe.mean().item(), 115 | '1px': (epe<1.0).float().mean().item(), 116 | } 117 | 118 | return flow_loss, metrics 119 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/geom/projective_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from lietorch import SE3, Sim3 5 | 6 | MIN_DEPTH = 0.2 7 | 8 | def extract_intrinsics(intrinsics): 9 | return intrinsics[...,None,None,:].unbind(dim=-1) 10 | 11 | def coords_grid(ht, wd, **kwargs): 12 | y, x = torch.meshgrid( 13 | torch.arange(ht).to(**kwargs).float(), 14 | torch.arange(wd).to(**kwargs).float(), indexing='ij') 15 | 
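# The fit_scale helper in geom/losses.py above is the closed-form solution of a
# one-dimensional least-squares problem: the scale s minimizing
# || t_gt - s * t_est ||^2 over the stacked relative translations, i.e.
# s = <t_gt, t_est> / (<t_est, t_est> + 1e-8).  geodesic_loss rescales the
# estimated relative poses by s before comparing them to the ground truth, so
# that monocular scale ambiguity does not dominate the pose error.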
16 | return torch.stack([x, y], dim=-1) 17 | 18 | def iproj(disps, intrinsics, jacobian=False): 19 | """ pinhole camera inverse projection """ 20 | ht, wd = disps.shape[2:] 21 | fx, fy, cx, cy = extract_intrinsics(intrinsics) 22 | 23 | y, x = torch.meshgrid( 24 | torch.arange(ht).to(disps.device).float(), 25 | torch.arange(wd).to(disps.device).float(), indexing='ij') 26 | 27 | i = torch.ones_like(disps) 28 | X = (x - cx) / fx 29 | Y = (y - cy) / fy 30 | pts = torch.stack([X, Y, i, disps], dim=-1) 31 | 32 | if jacobian: 33 | J = torch.zeros_like(pts) 34 | J[...,-1] = 1.0 35 | return pts, J 36 | 37 | return pts, None 38 | 39 | def proj(Xs, intrinsics, jacobian=False, return_depth=False): 40 | """ pinhole camera projection """ 41 | fx, fy, cx, cy = extract_intrinsics(intrinsics) 42 | X, Y, Z, D = Xs.unbind(dim=-1) 43 | 44 | Z = torch.where(Z < 0.5*MIN_DEPTH, torch.ones_like(Z), Z) 45 | d = 1.0 / Z 46 | 47 | x = fx * (X * d) + cx 48 | y = fy * (Y * d) + cy 49 | if return_depth: 50 | coords = torch.stack([x, y, D*d], dim=-1) 51 | else: 52 | coords = torch.stack([x, y], dim=-1) 53 | 54 | if jacobian: 55 | B, N, H, W = d.shape 56 | o = torch.zeros_like(d) 57 | proj_jac = torch.stack([ 58 | fx*d, o, -fx*X*d*d, o, 59 | o, fy*d, -fy*Y*d*d, o, 60 | # o, o, -D*d*d, d, 61 | ], dim=-1).view(B, N, H, W, 2, 4) 62 | 63 | return coords, proj_jac 64 | 65 | return coords, None 66 | 67 | def actp(Gij, X0, jacobian=False): 68 | """ action on point cloud """ 69 | X1 = Gij[:,:,None,None] * X0 70 | 71 | if jacobian: 72 | X, Y, Z, d = X1.unbind(dim=-1) 73 | o = torch.zeros_like(d) 74 | B, N, H, W = d.shape 75 | 76 | if isinstance(Gij, SE3): 77 | Ja = torch.stack([ 78 | d, o, o, o, Z, -Y, 79 | o, d, o, -Z, o, X, 80 | o, o, d, Y, -X, o, 81 | o, o, o, o, o, o, 82 | ], dim=-1).view(B, N, H, W, 4, 6) 83 | 84 | elif isinstance(Gij, Sim3): 85 | Ja = torch.stack([ 86 | d, o, o, o, Z, -Y, X, 87 | o, d, o, -Z, o, X, Y, 88 | o, o, d, Y, -X, o, Z, 89 | o, o, o, o, o, o, o 90 | ], dim=-1).view(B, N, H, W, 4, 7) 91 | 92 | return X1, Ja 93 | 94 | return X1, None 95 | 96 | def projective_transform(poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False): 97 | """ map points from ii->jj """ 98 | 99 | # inverse project (pinhole) 100 | X0, Jz = iproj(depths[:,ii], intrinsics[:,ii], jacobian=jacobian) 101 | 102 | # transform 103 | Gij = poses[:,jj] * poses[:,ii].inv() 104 | 105 | Gij.data[:,ii==jj] = torch.as_tensor([-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda") 106 | X1, Ja = actp(Gij, X0, jacobian=jacobian) 107 | 108 | # project (pinhole) 109 | x1, Jp = proj(X1, intrinsics[:,jj], jacobian=jacobian, return_depth=return_depth) 110 | 111 | # exclude points too close to camera 112 | valid = ((X1[...,2] > MIN_DEPTH) & (X0[...,2] > MIN_DEPTH)).float() 113 | valid = valid.unsqueeze(-1) 114 | 115 | if jacobian: 116 | # Ji transforms according to dual adjoint 117 | Jj = torch.matmul(Jp, Ja) 118 | Ji = -Gij[:,:,None,None,None].adjT(Jj) 119 | 120 | Jz = Gij[:,:,None,None] * Jz 121 | Jz = torch.matmul(Jp, Jz.unsqueeze(-1)) 122 | 123 | return x1, valid, (Ji, Jj, Jz) 124 | 125 | return x1, valid 126 | 127 | def induced_flow(poses, disps, intrinsics, ii, jj): 128 | """ optical flow induced by camera motion """ 129 | 130 | ht, wd = disps.shape[2:] 131 | y, x = torch.meshgrid( 132 | torch.arange(ht).to(disps.device).float(), 133 | torch.arange(wd).to(disps.device).float(), indexing='ij') 134 | 135 | coords0 = torch.stack([x, y], dim=-1) 136 | coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False) 137 | 138 
| return coords1[...,:2] - coords0, valid 139 | 140 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | 6 | SUM_FREQ = 100 7 | 8 | class Logger: 9 | def __init__(self, name, scheduler): 10 | self.total_steps = 0 11 | self.running_loss = {} 12 | self.writer = None 13 | self.name = name 14 | self.scheduler = scheduler 15 | 16 | def _print_training_status(self): 17 | if self.writer is None: 18 | self.writer = SummaryWriter('runs/%s' % self.name) 19 | print([k for k in self.running_loss]) 20 | 21 | lr = self.scheduler.get_lr().pop() 22 | metrics_data = [self.running_loss[k]/SUM_FREQ for k in self.running_loss.keys()] 23 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, lr) 24 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 25 | 26 | # print the training status 27 | print(training_str + metrics_str) 28 | 29 | for key in self.running_loss: 30 | val = self.running_loss[key] / SUM_FREQ 31 | self.writer.add_scalar(key, val, self.total_steps) 32 | self.running_loss[key] = 0.0 33 | 34 | def push(self, metrics): 35 | 36 | for key in metrics: 37 | if key not in self.running_loss: 38 | self.running_loss[key] = 0.0 39 | 40 | self.running_loss[key] += metrics[key] 41 | 42 | if self.total_steps % SUM_FREQ == SUM_FREQ-1: 43 | self._print_training_status() 44 | self.running_loss = {} 45 | 46 | self.total_steps += 1 47 | 48 | def write_dict(self, results): 49 | for key in results: 50 | self.writer.add_scalar(key, results[key], self.total_steps) 51 | 52 | def close(self): 53 | self.writer.close() 54 | 55 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/thirdparty/DROID-SLAM/droid_slam/modules/__init__.py -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/modules/clipping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | GRAD_CLIP = .01 6 | 7 | class GradClip(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x): 10 | return x 11 | 12 | @staticmethod 13 | def backward(ctx, grad_x): 14 | o = torch.zeros_like(grad_x) 15 | grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x) 16 | grad_x = torch.where(torch.isnan(grad_x), o, grad_x) 17 | return grad_x 18 | 19 | class GradientClip(nn.Module): 20 | def __init__(self): 21 | super(GradientClip, self).__init__() 22 | 23 | def forward(self, x): 24 | return GradClip.apply(x) -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/modules/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import droid_backends 5 | 6 | class CorrSampler(torch.autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, volume, coords, radius): 10 | ctx.save_for_backward(volume,coords) 11 | ctx.radius = radius 12 | corr, = droid_backends.corr_index_forward(volume, coords, radius) 13 | return corr 14 | 15 | 
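# CorrSampler wraps the CUDA kernels in droid_backends as a custom autograd
# Function: forward() above looks up a local window (set by radius) of the
# precomputed correlation volume around each coordinate, and backward() below
# scatters the incoming gradient back into the volume (coords and radius
# receive no gradient).  CorrBlock builds a 4-level average-pooled pyramid of
# the all-pairs correlation and concatenates the lookups from every level.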
@staticmethod 16 | def backward(ctx, grad_output): 17 | volume, coords = ctx.saved_tensors 18 | grad_output = grad_output.contiguous() 19 | grad_volume, = droid_backends.corr_index_backward(volume, coords, grad_output, ctx.radius) 20 | return grad_volume, None, None 21 | 22 | 23 | class CorrBlock: 24 | def __init__(self, fmap1, fmap2, num_levels=4, radius=3): 25 | self.num_levels = num_levels 26 | self.radius = radius 27 | self.corr_pyramid = [] 28 | 29 | # all pairs correlation 30 | corr = CorrBlock.corr(fmap1, fmap2) 31 | 32 | batch, num, h1, w1, h2, w2 = corr.shape 33 | corr = corr.reshape(batch*num*h1*w1, 1, h2, w2) 34 | 35 | for i in range(self.num_levels): 36 | self.corr_pyramid.append( 37 | corr.view(batch*num, h1, w1, h2//2**i, w2//2**i)) 38 | corr = F.avg_pool2d(corr, 2, stride=2) 39 | 40 | def __call__(self, coords): 41 | out_pyramid = [] 42 | batch, num, ht, wd, _ = coords.shape 43 | coords = coords.permute(0,1,4,2,3) 44 | coords = coords.contiguous().view(batch*num, 2, ht, wd) 45 | 46 | for i in range(self.num_levels): 47 | corr = CorrSampler.apply(self.corr_pyramid[i], coords/2**i, self.radius) 48 | out_pyramid.append(corr.view(batch, num, -1, ht, wd)) 49 | 50 | return torch.cat(out_pyramid, dim=2) 51 | 52 | def cat(self, other): 53 | for i in range(self.num_levels): 54 | self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0) 55 | return self 56 | 57 | def __getitem__(self, index): 58 | for i in range(self.num_levels): 59 | self.corr_pyramid[i] = self.corr_pyramid[i][index] 60 | return self 61 | 62 | 63 | @staticmethod 64 | def corr(fmap1, fmap2): 65 | """ all-pairs correlation """ 66 | batch, num, dim, ht, wd = fmap1.shape 67 | fmap1 = fmap1.reshape(batch*num, dim, ht*wd) / 4.0 68 | fmap2 = fmap2.reshape(batch*num, dim, ht*wd) / 4.0 69 | 70 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 71 | return corr.view(batch, num, ht, wd, ht, wd) 72 | 73 | 74 | class CorrLayer(torch.autograd.Function): 75 | @staticmethod 76 | def forward(ctx, fmap1, fmap2, coords, r): 77 | ctx.r = r 78 | ctx.save_for_backward(fmap1, fmap2, coords) 79 | corr, = droid_backends.altcorr_forward(fmap1, fmap2, coords, ctx.r) 80 | return corr 81 | 82 | @staticmethod 83 | def backward(ctx, grad_corr): 84 | fmap1, fmap2, coords = ctx.saved_tensors 85 | grad_corr = grad_corr.contiguous() 86 | fmap1_grad, fmap2_grad, coords_grad = \ 87 | droid_backends.altcorr_backward(fmap1, fmap2, coords, grad_corr, ctx.r) 88 | return fmap1_grad, fmap2_grad, coords_grad, None 89 | 90 | 91 | class AltCorrBlock: 92 | def __init__(self, fmaps, num_levels=4, radius=3): 93 | self.num_levels = num_levels 94 | self.radius = radius 95 | 96 | B, N, C, H, W = fmaps.shape 97 | fmaps = fmaps.view(B*N, C, H, W) / 4.0 98 | 99 | self.pyramid = [] 100 | for i in range(self.num_levels): 101 | sz = (B, N, H//2**i, W//2**i, C) 102 | fmap_lvl = fmaps.permute(0, 2, 3, 1).contiguous() 103 | self.pyramid.append(fmap_lvl.view(*sz)) 104 | fmaps = F.avg_pool2d(fmaps, 2, stride=2) 105 | 106 | def corr_fn(self, coords, ii, jj): 107 | B, N, H, W, S, _ = coords.shape 108 | coords = coords.permute(0, 1, 4, 2, 3, 5) 109 | 110 | corr_list = [] 111 | for i in range(self.num_levels): 112 | r = self.radius 113 | fmap1_i = self.pyramid[0][:, ii] 114 | fmap2_i = self.pyramid[i][:, jj] 115 | 116 | coords_i = (coords / 2**i).reshape(B*N, S, H, W, 2).contiguous() 117 | fmap1_i = fmap1_i.reshape((B*N,) + fmap1_i.shape[2:]) 118 | fmap2_i = fmap2_i.reshape((B*N,) + fmap2_i.shape[2:]) 119 | 120 | corr = CorrLayer.apply(fmap1_i.float(), 
fmap2_i.float(), coords_i, self.radius) 121 | corr = corr.view(B, N, S, -1, H, W).permute(0, 1, 3, 4, 5, 2) 122 | corr_list.append(corr) 123 | 124 | corr = torch.cat(corr_list, dim=2) 125 | return corr 126 | 127 | 128 | def __call__(self, coords, ii, jj): 129 | squeeze_output = False 130 | if len(coords.shape) == 5: 131 | coords = coords.unsqueeze(dim=-2) 132 | squeeze_output = True 133 | 134 | corr = self.corr_fn(coords, ii, jj) 135 | 136 | if squeeze_output: 137 | corr = corr.squeeze(dim=-1) 138 | 139 | return corr.contiguous() 140 | 141 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/modules/gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvGRU(nn.Module): 6 | def __init__(self, h_planes=128, i_planes=128): 7 | super(ConvGRU, self).__init__() 8 | self.do_checkpoint = False 9 | self.convz = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 10 | self.convr = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 11 | self.convq = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1) 12 | 13 | self.w = nn.Conv2d(h_planes, h_planes, 1, padding=0) 14 | 15 | self.convz_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 16 | self.convr_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 17 | self.convq_glo = nn.Conv2d(h_planes, h_planes, 1, padding=0) 18 | 19 | def forward(self, net, *inputs): 20 | inp = torch.cat(inputs, dim=1) 21 | net_inp = torch.cat([net, inp], dim=1) 22 | 23 | b, c, h, w = net.shape 24 | glo = torch.sigmoid(self.w(net)) * net 25 | glo = glo.view(b, c, h*w).mean(-1).view(b, c, 1, 1) 26 | 27 | z = torch.sigmoid(self.convz(net_inp) + self.convz_glo(glo)) 28 | r = torch.sigmoid(self.convr(net_inp) + self.convr_glo(glo)) 29 | q = torch.tanh(self.convq(torch.cat([r*net, inp], dim=1)) + self.convq_glo(glo)) 30 | 31 | net = (1-z) * net + z * q 32 | return net 33 | 34 | 35 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/motion_filter.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import lietorch 4 | 5 | from collections import OrderedDict 6 | from droid_net import DroidNet 7 | 8 | import geom.projective_ops as pops 9 | from modules.corr import CorrBlock 10 | 11 | 12 | class MotionFilter: 13 | """ This class is used to filter incoming frames and extract features """ 14 | 15 | def __init__(self, net, video, thresh=2.5, device="cuda:0"): 16 | 17 | # split net modules 18 | self.cnet = net.cnet 19 | self.fnet = net.fnet 20 | self.update = net.update 21 | 22 | self.video = video 23 | self.thresh = thresh 24 | self.device = device 25 | 26 | self.count = 0 27 | 28 | # mean, std for image normalization 29 | self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None] 30 | self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None] 31 | 32 | @torch.amp.autocast('cuda', enabled=True) 33 | def __context_encoder(self, image): 34 | """ context features """ 35 | net, inp = self.cnet(image).split([128,128], dim=2) 36 | return net.tanh().squeeze(0), inp.relu().squeeze(0) 37 | 38 | @torch.amp.autocast('cuda', enabled=True) 39 | def __feature_encoder(self, image): 40 | """ features for correlation volume """ 41 | return self.fnet(image).squeeze(0) 42 | 43 | @torch.amp.autocast('cuda', enabled=True) 44 | @torch.no_grad() 45 | def track(self, 
tstamp, image, depth=None, intrinsics=None, mask=None): 46 | """ main update operation - run on every frame in video """ 47 | 48 | Id = lietorch.SE3.Identity(1,).data.squeeze() 49 | ht = image.shape[-2] // 8 50 | wd = image.shape[-1] // 8 51 | 52 | # normalize images 53 | inputs = image[None, :, [2,1,0]].to(self.device) / 255.0 54 | inputs = inputs.sub_(self.MEAN).div_(self.STDV) 55 | 56 | # extract features 57 | gmap = self.__feature_encoder(inputs) # [1, 128, gh, gw] 58 | if mask is None: 59 | mask = torch.zeros([gmap.shape[-2], gmap.shape[-1]]).to(gmap) 60 | # if mask is not None: 61 | # # bias = self.fnet.conv2.bias.detach().clone().half() 62 | # # gmap[:,:,mask>0.0] = bias[:, None].repeat(1, (mask>0.0).sum()) 63 | # gmap[:,:,mask>0.0] = 0 64 | 65 | ### always add first frame to the depth video ### 66 | if self.video.counter.value == 0: 67 | net, inp = self.__context_encoder(inputs[:,[0]]) 68 | self.net, self.inp, self.fmap = net, inp, gmap 69 | self.video.append(tstamp, image[0], Id, 1.0, depth, intrinsics / 8.0, gmap, net[0,0], inp[0,0], mask) 70 | # msk: torch.Size([64, 48]) 71 | # gmap: torch.Size([1, 128, 64, 48]) 72 | # net: torch.Size([1, 128, 64, 48]) 73 | # inp: torch.Size([1, 128, 64, 48]) 74 | 75 | ### only add new frame if there is enough motion ### 76 | else: 77 | # index correlation volume 78 | coords0 = pops.coords_grid(ht, wd, device=self.device)[None,None] 79 | corr = CorrBlock(self.fmap[None,[0]], gmap[None,[0]])(coords0) 80 | 81 | # approximate flow magnitude using 1 update iteration 82 | _, delta, weight = self.update(self.net[None], self.inp[None], corr) 83 | 84 | # check motion magnitue / add new frame to video 85 | if delta.norm(dim=-1).mean().item() > self.thresh: 86 | self.count = 0 87 | net, inp = self.__context_encoder(inputs[:,[0]]) 88 | self.net, self.inp, self.fmap = net, inp, gmap 89 | self.video.append(tstamp, image[0], None, None, depth, intrinsics / 8.0, gmap, net[0], inp[0], mask) 90 | 91 | else: 92 | self.count += 1 93 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/droid_slam/trajectory_filler.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import lietorch 4 | 5 | from lietorch import SE3 6 | from collections import OrderedDict 7 | from factor_graph import FactorGraph 8 | from droid_net import DroidNet 9 | import geom.projective_ops as pops 10 | 11 | 12 | class PoseTrajectoryFiller: 13 | """ This class is used to fill in non-keyframe poses """ 14 | 15 | def __init__(self, net, video, device="cuda:0"): 16 | 17 | # split net modules 18 | self.cnet = net.cnet 19 | self.fnet = net.fnet 20 | self.update = net.update 21 | 22 | self.count = 0 23 | self.video = video 24 | self.device = device 25 | 26 | # mean, std for image normalization 27 | self.MEAN = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None] 28 | self.STDV = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None] 29 | 30 | @torch.amp.autocast('cuda', enabled=True) 31 | def __feature_encoder(self, image): 32 | """ features for correlation volume """ 33 | return self.fnet(image) 34 | 35 | def __fill(self, tstamps, images, intrinsics): 36 | """ fill operator """ 37 | 38 | tt = torch.as_tensor(tstamps, device="cuda") 39 | images = torch.stack(images, 0) 40 | intrinsics = torch.stack(intrinsics, 0) 41 | inputs = images[:,:,[2,1,0]].to(self.device) / 255.0 42 | 43 | ### linear pose interpolation ### 44 | N = 
self.video.counter.value # number of keyframes 45 | M = len(tstamps) # 16 frames to fill 46 | 47 | ts = self.video.tstamp[:N] # tstamp of keyframes 48 | Ps = SE3(self.video.poses[:N]) # pose of keyframes 49 | 50 | t0 = torch.as_tensor([ts[ts<=t].shape[0] - 1 for t in tstamps]) 51 | t1 = torch.where(t0 0: 108 | pose_list += self.__fill(tstamps, images, intrinsics) 109 | 110 | # stitch pose segments together 111 | return lietorch.cat(pose_list, 0) 112 | 113 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/environment.yaml: -------------------------------------------------------------------------------- 1 | name: droidenv 2 | channels: 3 | - rusty1s 4 | - pytorch 5 | - open3d-admin 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - pytorch-scatter 11 | - torchaudio 12 | - torchvision 13 | - open3d 14 | - pytorch=1.10 15 | - cudatoolkit=11.3 16 | - tensorboard 17 | - scipy 18 | - opencv 19 | - tqdm 20 | - suitesparse 21 | - matplotlib 22 | - pyyaml 23 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/environment_novis.yaml: -------------------------------------------------------------------------------- 1 | name: droidenv 2 | channels: 3 | - rusty1s 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - pytorch-scatter 10 | - torchaudio 11 | - torchvision 12 | - pytorch=1.10 13 | - cudatoolkit=11.3 14 | - tensorboard 15 | - scipy 16 | - opencv 17 | - tqdm 18 | - suitesparse 19 | - matplotlib 20 | - pyyaml 21 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/evaluation_scripts/test_eth3d.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import lietorch 8 | import cv2 9 | import os 10 | import glob 11 | import time 12 | import argparse 13 | 14 | import torch.nn.functional as F 15 | from droid import Droid 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def show_image(image): 21 | image = image.permute(1, 2, 0).cpu().numpy() 22 | cv2.imshow('image', image / 255.0) 23 | cv2.waitKey(1) 24 | 25 | def image_stream(datapath, use_depth=False, stride=1): 26 | """ image generator """ 27 | 28 | fx, fy, cx, cy = np.loadtxt(os.path.join(datapath, 'calibration.txt')).tolist() 29 | image_list = sorted(glob.glob(os.path.join(datapath, 'rgb', '*.png')))[::stride] 30 | depth_list = sorted(glob.glob(os.path.join(datapath, 'depth', '*.png')))[::stride] 31 | 32 | for t, (image_file, depth_file) in enumerate(zip(image_list, depth_list)): 33 | image = cv2.imread(image_file) 34 | depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH) / 5000.0 35 | 36 | h0, w0, _ = image.shape 37 | h1 = int(h0 * np.sqrt((384 * 512) / (h0 * w0))) 38 | w1 = int(w0 * np.sqrt((384 * 512) / (h0 * w0))) 39 | 40 | image = cv2.resize(image, (w1, h1)) 41 | image = image[:h1-h1%8, :w1-w1%8] 42 | image = torch.as_tensor(image).permute(2, 0, 1) 43 | 44 | depth = torch.as_tensor(depth) 45 | depth = F.interpolate(depth[None,None], (h1, w1)).squeeze() 46 | depth = depth[:h1-h1%8, :w1-w1%8] 47 | 48 | intrinsics = torch.as_tensor([fx, fy, cx, cy]) 49 | intrinsics[0::2] *= (w1 / w0) 50 | intrinsics[1::2] *= (h1 / h0) 51 | 52 | if use_depth: 53 | yield t, image[None], depth, intrinsics 54 | 55 | else: 56 | yield t, image[None], intrinsics 57 | 58 | if __name__ == '__main__': 59 | 
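# Note on image_stream above: frames are resized so the pixel count is close
# to 384*512 and then cropped so height and width are multiples of 8, because
# the network extracts features at 1/8 resolution (see MotionFilter.track
# above); the intrinsics are rescaled by the same width/height factors, e.g.
#   fx' = fx * w1 / w0,  cx' = cx * w1 / w0   (and likewise with h1/h0 for fy, cy).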
parser = argparse.ArgumentParser() 60 | parser.add_argument("--datapath") 61 | parser.add_argument("--weights", default="droid.pth") 62 | parser.add_argument("--buffer", type=int, default=1024) 63 | parser.add_argument("--image_size", default=[240, 320]) 64 | parser.add_argument("--disable_vis", action="store_true") 65 | 66 | parser.add_argument("--beta", type=float, default=0.5) 67 | parser.add_argument("--filter_thresh", type=float, default=2.0) 68 | parser.add_argument("--warmup", type=int, default=8) 69 | parser.add_argument("--keyframe_thresh", type=float, default=3.5) 70 | parser.add_argument("--frontend_thresh", type=float, default=16.0) 71 | parser.add_argument("--frontend_window", type=int, default=16) 72 | parser.add_argument("--frontend_radius", type=int, default=1) 73 | parser.add_argument("--frontend_nms", type=int, default=0) 74 | 75 | parser.add_argument("--stereo", action="store_true") 76 | parser.add_argument("--depth", action="store_true") 77 | 78 | parser.add_argument("--backend_thresh", type=float, default=22.0) 79 | parser.add_argument("--backend_radius", type=int, default=2) 80 | parser.add_argument("--backend_nms", type=int, default=3) 81 | args = parser.parse_args() 82 | 83 | torch.multiprocessing.set_start_method('spawn') 84 | 85 | print("Running evaluation on {}".format(args.datapath)) 86 | print(args) 87 | 88 | # this can usually be set to 2-3 except for "camera_shake" scenes 89 | # set to 2 for test scenes 90 | stride = 1 91 | 92 | tstamps = [] 93 | for (t, image, depth, intrinsics) in tqdm(image_stream(args.datapath, use_depth=True, stride=stride)): 94 | if not args.disable_vis: 95 | show_image(image[0]) 96 | 97 | if t == 0: 98 | args.image_size = [image.shape[2], image.shape[3]] 99 | droid = Droid(args) 100 | 101 | droid.track(t, image, depth, intrinsics=intrinsics) 102 | 103 | traj_est = droid.terminate(image_stream(args.datapath, use_depth=False, stride=stride)) 104 | 105 | ### run evaluation ### 106 | 107 | print("#"*20 + " Results...") 108 | 109 | import evo 110 | from evo.core.trajectory import PoseTrajectory3D 111 | from evo.tools import file_interface 112 | from evo.core import sync 113 | import evo.main_ape as main_ape 114 | from evo.core.metrics import PoseRelation 115 | 116 | image_path = os.path.join(args.datapath, 'rgb') 117 | images_list = sorted(glob.glob(os.path.join(image_path, '*.png')))[::stride] 118 | tstamps = [float(x.split('/')[-1][:-4]) for x in images_list] 119 | 120 | traj_est = PoseTrajectory3D( 121 | positions_xyz=traj_est[:,:3], 122 | orientations_quat_wxyz=traj_est[:,3:], 123 | timestamps=np.array(tstamps)) 124 | 125 | gt_file = os.path.join(args.datapath, 'groundtruth.txt') 126 | traj_ref = file_interface.read_tum_trajectory_file(gt_file) 127 | 128 | traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est) 129 | 130 | result = main_ape.ape(traj_ref, traj_est, est_name='traj', 131 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=False) 132 | 133 | print(result.stats) 134 | 135 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/evaluation_scripts/test_tum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | import lietorch 8 | import cv2 9 | import os 10 | import glob 11 | import time 12 | import argparse 13 | 14 | import torch.nn.functional as F 15 | from droid import Droid 16 | 17 | 18 | 
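# The evo evaluation above reports the absolute trajectory error (ATE): the
# estimated and ground-truth trajectories are associated by timestamp, aligned
# with a rigid transform (plus a scale factor when correct_scale=True), and the
# reported statistic is the RMSE of the remaining translation differences,
#
#   ATE = sqrt( (1/N) * sum_i || t_i_est(aligned) - t_i_gt ||^2 ).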
def show_image(image): 19 | image = image.permute(1, 2, 0).cpu().numpy() 20 | cv2.imshow('image', image / 255.0) 21 | cv2.waitKey(1) 22 | 23 | def image_stream(datapath, image_size=[320, 512]): 24 | """ image generator """ 25 | 26 | fx, fy, cx, cy = 517.3, 516.5, 318.6, 255.3 27 | 28 | K_l = np.array([fx, 0.0, cx, 0.0, fy, cy, 0.0, 0.0, 1.0]).reshape(3,3) 29 | d_l = np.array([0.2624, -0.9531, -0.0054, 0.0026, 1.1633]) 30 | 31 | # read all png images in folder 32 | images_list = sorted(glob.glob(os.path.join(datapath, 'rgb', '*.png')))[::2] 33 | 34 | for t, imfile in enumerate(images_list): 35 | image = cv2.imread(imfile) 36 | ht0, wd0, _ = image.shape 37 | image = cv2.undistort(image, K_l, d_l) 38 | image = cv2.resize(image, (320+32, 240+16)) 39 | image = torch.from_numpy(image).permute(2,0,1) 40 | 41 | intrinsics = torch.as_tensor([fx, fy, cx, cy]).cuda() 42 | intrinsics[0] *= image.shape[2] / 640.0 43 | intrinsics[1] *= image.shape[1] / 480.0 44 | intrinsics[2] *= image.shape[2] / 640.0 45 | intrinsics[3] *= image.shape[1] / 480.0 46 | 47 | # crop image to remove distortion boundary 48 | intrinsics[2] -= 16 49 | intrinsics[3] -= 8 50 | image = image[:, 8:-8, 16:-16] 51 | 52 | yield t, image[None], intrinsics 53 | 54 | if __name__ == '__main__': 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--datapath") 57 | parser.add_argument("--weights", default="droid.pth") 58 | parser.add_argument("--buffer", type=int, default=512) 59 | parser.add_argument("--image_size", default=[240, 320]) 60 | parser.add_argument("--disable_vis", action="store_true") 61 | 62 | parser.add_argument("--beta", type=float, default=0.6) 63 | parser.add_argument("--filter_thresh", type=float, default=1.75) 64 | parser.add_argument("--warmup", type=int, default=12) 65 | parser.add_argument("--keyframe_thresh", type=float, default=2.25) 66 | parser.add_argument("--frontend_thresh", type=float, default=12.0) 67 | parser.add_argument("--frontend_window", type=int, default=25) 68 | parser.add_argument("--frontend_radius", type=int, default=2) 69 | parser.add_argument("--frontend_nms", type=int, default=1) 70 | 71 | parser.add_argument("--backend_thresh", type=float, default=15.0) 72 | parser.add_argument("--backend_radius", type=int, default=2) 73 | parser.add_argument("--backend_nms", type=int, default=3) 74 | args = parser.parse_args() 75 | 76 | args.stereo = False 77 | torch.multiprocessing.set_start_method('spawn') 78 | 79 | print("Running evaluation on {}".format(args.datapath)) 80 | print(args) 81 | 82 | droid = Droid(args) 83 | time.sleep(5) 84 | 85 | tstamps = [] 86 | for (t, image, intrinsics) in tqdm(image_stream(args.datapath)): 87 | if not args.disable_vis: 88 | show_image(image) 89 | droid.track(t, image, intrinsics=intrinsics) 90 | 91 | 92 | traj_est = droid.terminate(image_stream(args.datapath)) 93 | 94 | ### run evaluation ### 95 | 96 | print("#"*20 + " Results...") 97 | 98 | import evo 99 | from evo.core.trajectory import PoseTrajectory3D 100 | from evo.tools import file_interface 101 | from evo.core import sync 102 | import evo.main_ape as main_ape 103 | from evo.core.metrics import PoseRelation 104 | 105 | image_path = os.path.join(args.datapath, 'rgb') 106 | images_list = sorted(glob.glob(os.path.join(image_path, '*.png')))[::2] 107 | tstamps = [float(x.split('/')[-1][:-4]) for x in images_list] 108 | 109 | traj_est = PoseTrajectory3D( 110 | positions_xyz=traj_est[:,:3], 111 | orientations_quat_wxyz=traj_est[:,3:], 112 | timestamps=np.array(tstamps)) 113 | 114 | gt_file = 
os.path.join(args.datapath, 'groundtruth.txt') 115 | traj_ref = file_interface.read_tum_trajectory_file(gt_file) 116 | 117 | traj_ref, traj_est = sync.associate_trajectories(traj_ref, traj_est) 118 | result = main_ape.ape(traj_ref, traj_est, est_name='traj', 119 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True) 120 | 121 | 122 | print(result) 123 | 124 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/evaluation_scripts/validate_tartanair.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('droid_slam') 3 | sys.path.append('thirdparty/tartanair_tools') 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | import lietorch 9 | import cv2 10 | import os 11 | import glob 12 | import time 13 | import yaml 14 | import argparse 15 | 16 | from droid import Droid 17 | 18 | def image_stream(datapath, image_size=[384, 512], intrinsics_vec=[320.0, 320.0, 320.0, 240.0], stereo=False): 19 | """ image generator """ 20 | 21 | # read all png images in folder 22 | ht0, wd0 = [480, 640] 23 | images_left = sorted(glob.glob(os.path.join(datapath, 'image_left/*.png'))) 24 | images_right = sorted(glob.glob(os.path.join(datapath, 'image_right/*.png'))) 25 | 26 | data = [] 27 | for t in range(len(images_left)): 28 | images = [ cv2.resize(cv2.imread(images_left[t]), (image_size[1], image_size[0])) ] 29 | if stereo: 30 | images += [ cv2.resize(cv2.imread(images_right[t]), (image_size[1], image_size[0])) ] 31 | 32 | images = torch.from_numpy(np.stack(images, 0)).permute(0,3,1,2) 33 | intrinsics = .8 * torch.as_tensor(intrinsics_vec) 34 | 35 | data.append((t, images, intrinsics)) 36 | 37 | return data 38 | 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--datapath", default="datasets/TartanAir") 43 | parser.add_argument("--weights", default="droid.pth") 44 | parser.add_argument("--buffer", type=int, default=1000) 45 | parser.add_argument("--image_size", default=[384,512]) 46 | parser.add_argument("--stereo", action="store_true") 47 | parser.add_argument("--disable_vis", action="store_true") 48 | parser.add_argument("--plot_curve", action="store_true") 49 | parser.add_argument("--id", type=int, default=-1) 50 | 51 | parser.add_argument("--beta", type=float, default=0.3) 52 | parser.add_argument("--filter_thresh", type=float, default=2.4) 53 | parser.add_argument("--warmup", type=int, default=12) 54 | parser.add_argument("--keyframe_thresh", type=float, default=3.5) 55 | parser.add_argument("--frontend_thresh", type=float, default=15) 56 | parser.add_argument("--frontend_window", type=int, default=20) 57 | parser.add_argument("--frontend_radius", type=int, default=1) 58 | parser.add_argument("--frontend_nms", type=int, default=1) 59 | 60 | parser.add_argument("--backend_thresh", type=float, default=20.0) 61 | parser.add_argument("--backend_radius", type=int, default=2) 62 | parser.add_argument("--backend_nms", type=int, default=3) 63 | 64 | args = parser.parse_args() 65 | torch.multiprocessing.set_start_method('spawn') 66 | 67 | from data_readers.tartan import test_split 68 | from evaluation.tartanair_evaluator import TartanAirEvaluator 69 | 70 | if not os.path.isdir("figures"): 71 | os.mkdir("figures") 72 | 73 | if args.id >= 0: 74 | test_split = [ test_split[args.id] ] 75 | 76 | ate_list = [] 77 | for scene in test_split: 78 | print("Performing evaluation on {}".format(scene)) 79 | 
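# Note on image_stream above: TartanAir frames are 640x480 with intrinsics
# (fx, fy, cx, cy) = (320, 320, 320, 240); resizing to 512x384 scales both
# axes by 0.8, which is why the intrinsics vector is simply multiplied by 0.8
# before being passed to the tracker.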
torch.cuda.empty_cache() 80 | droid = Droid(args) 81 | 82 | scenedir = os.path.join(args.datapath, scene) 83 | 84 | for (tstamp, image, intrinsics) in tqdm(image_stream(scenedir, stereo=args.stereo)): 85 | droid.track(tstamp, image, intrinsics=intrinsics) 86 | 87 | # fill in non-keyframe poses + global BA 88 | traj_est = droid.terminate(image_stream(scenedir)) 89 | 90 | ### do evaluation ### 91 | evaluator = TartanAirEvaluator() 92 | gt_file = os.path.join(scenedir, "pose_left.txt") 93 | traj_ref = np.loadtxt(gt_file, delimiter=' ')[:, [1, 2, 0, 4, 5, 3, 6]] # ned -> xyz 94 | 95 | # usually stereo should not be scale corrected, but we are comparing monocular and stereo here 96 | results = evaluator.evaluate_one_trajectory( 97 | traj_ref, traj_est, scale=True, title=scenedir[-20:].replace('/', '_')) 98 | 99 | print(results) 100 | ate_list.append(results["ate_score"]) 101 | 102 | print("Results") 103 | print(ate_list) 104 | 105 | if args.plot_curve: 106 | import matplotlib.pyplot as plt 107 | ate = np.array(ate_list) 108 | xs = np.linspace(0.0, 1.0, 512) 109 | ys = [np.count_nonzero(ate < t) / ate.shape[0] for t in xs] 110 | 111 | plt.plot(xs, ys) 112 | plt.xlabel("ATE [m]") 113 | plt.ylabel("% runs") 114 | plt.show() 115 | 116 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/misc/DROID.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/thirdparty/DROID-SLAM/misc/DROID.png -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/misc/renderoption.json: -------------------------------------------------------------------------------- 1 | { 2 | "background_color" : [ 1, 1, 1 ], 3 | "class_name" : "RenderOption", 4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], 5 | "image_max_depth" : 3000, 6 | "image_stretch_option" : 0, 7 | "interpolation_option" : 0, 8 | "light0_color" : [ 1, 1, 1 ], 9 | "light0_diffuse_power" : 20, 10 | "light0_position" : [ 0, 0, 20 ], 11 | "light0_specular_power" : 2.20000000000000001, 12 | "light0_specular_shininess" : 100, 13 | "light1_color" : [ 1, 1, 1 ], 14 | "light1_diffuse_power" : 0.66000000000000003, 15 | "light1_position" : [ 0, 0, 2 ], 16 | "light1_specular_power" : 2.20000000000000001, 17 | "light1_specular_shininess" : 100, 18 | "light2_color" : [ 1, 1, 1 ], 19 | "light2_diffuse_power" : 20, 20 | "light2_position" : [ 0, 0, -20 ], 21 | "light2_specular_power" : 2.20000000000000001, 22 | "light2_specular_shininess" : 100, 23 | "light3_color" : [ 1, 1, 1 ], 24 | "light3_diffuse_power" : 20, 25 | "light3_position" : [ 0, 0, -20 ], 26 | "light3_specular_power" : 2.20000000000000001, 27 | "light3_specular_shininess" : 100, 28 | "light_ambient_color" : [ 0, 0, 0 ], 29 | "light_on" : true, 30 | "mesh_color_option" : 1, 31 | "mesh_shade_option" : 0, 32 | "mesh_show_back_face" : false, 33 | "mesh_show_wireframe" : false, 34 | "point_color_option" : 7, 35 | "point_show_normal" : false, 36 | "point_size" : 2, 37 | "show_coordinate_frame" : false, 38 | "version_major" : 1, 39 | "version_minor" : 0 40 | } 41 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/misc/screenshot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/thirdparty/DROID-SLAM/misc/screenshot.png -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | import os.path as osp 5 | ROOT = osp.dirname(osp.abspath(__file__)) 6 | 7 | setup( 8 | name='droid_backends', 9 | ext_modules=[ 10 | CUDAExtension('droid_backends', 11 | include_dirs=[osp.join(ROOT, 'thirdparty/eigen')], 12 | sources=[ 13 | 'src/droid.cpp', 14 | 'src/droid_kernels.cu', 15 | 'src/correlation_kernels.cu', 16 | 'src/altcorr_kernel.cu', 17 | ], 18 | extra_compile_args={ 19 | 'cxx': ['-O3'], 20 | 'nvcc': ['-O3', 21 | '-gencode=arch=compute_60,code=sm_60', 22 | '-gencode=arch=compute_61,code=sm_61', 23 | '-gencode=arch=compute_70,code=sm_70', 24 | '-gencode=arch=compute_75,code=sm_75', 25 | '-gencode=arch=compute_80,code=sm_80', 26 | '-gencode=arch=compute_86,code=sm_86', 27 | ] 28 | }), 29 | ], 30 | cmdclass={ 'build_ext' : BuildExtension } 31 | ) 32 | 33 | setup( 34 | name='lietorch', 35 | version='0.2', 36 | description='Lie Groups for PyTorch', 37 | packages=['lietorch'], 38 | package_dir={'': 'thirdparty/lietorch'}, 39 | ext_modules=[ 40 | CUDAExtension('lietorch_backends', 41 | include_dirs=[ 42 | osp.join(ROOT, 'thirdparty/lietorch/lietorch/include'), 43 | osp.join(ROOT, 'thirdparty/eigen')], 44 | sources=[ 45 | 'thirdparty/lietorch/lietorch/src/lietorch.cpp', 46 | 'thirdparty/lietorch/lietorch/src/lietorch_gpu.cu', 47 | 'thirdparty/lietorch/lietorch/src/lietorch_cpu.cpp'], 48 | extra_compile_args={ 49 | 'cxx': ['-O2'], 50 | 'nvcc': ['-O2', 51 | '-gencode=arch=compute_60,code=sm_60', 52 | '-gencode=arch=compute_61,code=sm_61', 53 | '-gencode=arch=compute_70,code=sm_70', 54 | '-gencode=arch=compute_75,code=sm_75', 55 | '-gencode=arch=compute_80,code=sm_80', 56 | '-gencode=arch=compute_86,code=sm_86', 57 | ] 58 | }), 59 | ], 60 | cmdclass={ 'build_ext' : BuildExtension } 61 | ) 62 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Carnegie Mellon University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/data_type.md: -------------------------------------------------------------------------------- 1 | ### GRB Image 2 | 3 | The color images are stored as 640x480 8-bit RGB images in PNG format. 4 | 5 | * Load the image using OpenCV: 6 | ``` 7 | import cv2 8 | img = cv2.imread(FILENAME) 9 | cv2.imshow('img', img) 10 | cv2.waitKey(0) 11 | ``` 12 | 13 | * Load the image using Pillow: 14 | ``` 15 | from PIL import Image 16 | img = Image.open(FILENAME) 17 | img.show() 18 | ``` 19 | 20 | ### Camera intrinsics 21 | ``` 22 | fx = 320.0 # focal length x 23 | fy = 320.0 # focal length y 24 | cx = 320.0 # optical center x 25 | cy = 240.0 # optical center y 26 | 27 | fov = 90 deg # field of view 28 | 29 | width = 640 30 | height = 480 31 | ``` 32 | 33 | ### Depth image 34 | 35 | The depth maps are stored as 640x480 16-bit numpy array in NPY format. In the Unreal Engine, the environment usually has a sky sphere at a large distance. So the infinite distant object such as the sky has a large depth value (e.g. 10000) instead of an infinite number. 36 | 37 | The unit of the depth value is meter. The baseline between the left and right cameras is 0.25m. 38 | 39 | * Load the depth image: 40 | ``` 41 | import numpy as np 42 | depth = np.load(FILENAME) 43 | 44 | # change to disparity image 45 | disparity = 80.0 / depth 46 | ``` 47 | 48 | ### Segmentation image 49 | 50 | The segmentation images are saved as a uint8 numpy array. AirSim assigns value 0 to 255 to each mesh available in the environment. 51 | 52 | [More details](https://github.com/microsoft/AirSim/blob/master/docs/image_apis.md#segmentation) 53 | 54 | * Load the segmentation image 55 | ``` 56 | import numpy as np 57 | depth = np.load(FILENAME) 58 | ``` 59 | 60 | ### Optical flow 61 | 62 | The optical flow maps are saved as a float32 numpy array, which is calculated based on the ground truth depth and ground truth camera motion, using [this](https://github.com/huyaoyu/ImageFlow) code. Dynamic objects and occlusions are masked by the mask file, which is a uint8 numpy array. We currently provide the optical flow for the left camera. 63 | 64 | * Load the optical flow 65 | ``` 66 | import numpy as np 67 | flow = np.load(FILENAME) 68 | 69 | # load the mask 70 | mask = np.load(MASKFILENAME) 71 | ``` 72 | 73 | ### Pose file 74 | 75 | The camera pose file is a text file containing the translation and orientation of the camera in a fixed coordinate frame. Note that our automatic evaluation tool expects both the ground truth trajectory and the estimated trajectory to be in this format.  76 | 77 | * Each line in the text file contains a single pose. 78 | 79 | * The number of lines/poses is the same as the number of image frames in that trajectory.  80 | 81 | * The format of each line is '**tx ty tz qx qy qz qw**'.  
82 | 83 | * **tx ty tz** (3 floats) give the position of the optical center of the color camera with respect to the world origin in the world frame. 84 | 85 | * **qx qy qz qw** (4 floats) give the orientation of the optical center of the color camera in the form of a unit quaternion with respect to the world frame.  86 | 87 | * The camera motion is defined in the NED frame. That is to say, the x-axis is pointing to the camera's forward, the y-axis is pointing to the camera's right, the z-axis is pointing to the camera's downward. 88 | 89 | * Load the pose file: 90 | ``` 91 | import numpy as np 92 | flow = np.loadtxt(FILENAME) 93 | ``` 94 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/download_cvpr_slam_test.txt: -------------------------------------------------------------------------------- 1 | https://tartanair.blob.core.windows.net/tartanair-testing1/tartanair-test-mono-release.tar.gz 2 | https://tartanair.blob.core.windows.net/tartanair-testing1/tartanair-test-stereo-release.tar.gz 3 | https://tartanair.blob.core.windows.net/tartanair-testing1/tartanair-test-release.tar.gz -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/thirdparty/DROID-SLAM/thirdparty/tartanair_tools/evaluation/__init__.py -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/evaluation/evaluate_ate_scale.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Modified by Wenshan Wang 4 | # Modified by Raul Mur-Artal 5 | # Automatically compute the optimal scale factor for monocular VO/SLAM. 6 | 7 | # Software License Agreement (BSD License) 8 | # 9 | # Copyright (c) 2013, Juergen Sturm, TUM 10 | # All rights reserved. 11 | # 12 | # Redistribution and use in source and binary forms, with or without 13 | # modification, are permitted provided that the following conditions 14 | # are met: 15 | # 16 | # * Redistributions of source code must retain the above copyright 17 | # notice, this list of conditions and the following disclaimer. 18 | # * Redistributions in binary form must reproduce the above 19 | # copyright notice, this list of conditions and the following 20 | # disclaimer in the documentation and/or other materials provided 21 | # with the distribution. 22 | # * Neither the name of TUM nor the names of its 23 | # contributors may be used to endorse or promote products derived 24 | # from this software without specific prior written permission. 25 | # 26 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 27 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 28 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 29 | # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 30 | # COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 31 | # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 32 | # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 33 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 35 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 36 | # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 37 | # POSSIBILITY OF SUCH DAMAGE. 38 | # 39 | # Requirements: 40 | # sudo apt-get install python-argparse 41 | 42 | """ 43 | This script computes the absolute trajectory error from the ground truth 44 | trajectory and the estimated trajectory. 45 | """ 46 | 47 | import numpy 48 | 49 | def align(model,data,calc_scale=False): 50 | """Align two trajectories using the method of Horn (closed-form). 51 | 52 | Input: 53 | model -- first trajectory (3xn) 54 | data -- second trajectory (3xn) 55 | calc_scale -- if True, also estimate a global scale factor for data 56 | Output: 57 | rot -- rotation matrix (3x3) 58 | trans -- translation vector (3x1) 59 | trans_error -- translational error per point (1xn) 60 | s -- estimated scale factor (1.0 if calc_scale is False) 61 | """ 62 | numpy.set_printoptions(precision=3,suppress=True) 63 | model_zerocentered = model - model.mean(1) 64 | data_zerocentered = data - data.mean(1) 65 | 66 | W = numpy.zeros( (3,3) ) 67 | for column in range(model.shape[1]): 68 | W += numpy.outer(model_zerocentered[:,column],data_zerocentered[:,column]) 69 | U,d,Vh = numpy.linalg.svd(W.transpose()) 70 | S = numpy.matrix(numpy.identity( 3 )) 71 | if(numpy.linalg.det(U) * numpy.linalg.det(Vh)<0): 72 | S[2,2] = -1 73 | rot = U*S*Vh 74 | 75 | if calc_scale: 76 | rotmodel = rot*model_zerocentered 77 | dots = 0.0 78 | norms = 0.0 79 | for column in range(data_zerocentered.shape[1]): 80 | dots += numpy.dot(data_zerocentered[:,column].transpose(),rotmodel[:,column]) 81 | normi = numpy.linalg.norm(model_zerocentered[:,column]) 82 | norms += normi*normi 83 | # s = float(dots/norms) 84 | s = float(norms/dots) 85 | else: 86 | s = 1.0 87 | 88 | # trans = data.mean(1) - s*rot * model.mean(1) 89 | # model_aligned = s*rot * model + trans 90 | # alignment_error = model_aligned - data 91 | 92 | # scale the est to the gt, otherwise the ATE could be very small if the est scale is small 93 | trans = s*data.mean(1) - rot * model.mean(1) 94 | model_aligned = rot * model + trans 95 | data_aligned = s * data 96 | alignment_error = model_aligned - data_aligned 97 | 98 | trans_error = numpy.sqrt(numpy.sum(numpy.multiply(alignment_error,alignment_error),0)).A[0] 99 | 100 | return rot,trans,trans_error, s 101 | 102 | def plot_traj(ax,stamps,traj,style,color,label): 103 | """ 104 | Plot a trajectory using matplotlib.
105 | 106 | Input: 107 | ax -- the plot 108 | stamps -- time stamps (1xn) 109 | traj -- trajectory (3xn) 110 | style -- line style 111 | color -- line color 112 | label -- plot legend 113 | 114 | """ 115 | stamps.sort() 116 | interval = numpy.median([s-t for s,t in zip(stamps[1:],stamps[:-1])]) 117 | x = [] 118 | y = [] 119 | last = stamps[0] 120 | for i in range(len(stamps)): 121 | if stamps[i]-last < 2*interval: 122 | x.append(traj[i][0]) 123 | y.append(traj[i][1]) 124 | elif len(x)>0: 125 | ax.plot(x,y,style,color=color,label=label) 126 | label="" 127 | x=[] 128 | y=[] 129 | last = stamps[i] 130 | if len(x)>0: 131 | ax.plot(x,y,style,color=color,label=label) 132 | 133 | 134 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/evaluation/evaluate_kitti.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang 2 | # For License information please see the LICENSE file in the root directory. 3 | # This is a Python reimplementation of the KITTI odometry metric: http://www.cvlibs.net/datasets/kitti/eval_odometry.php 4 | # Credit: Xiangwei Wang https://github.com/TimingSpace 5 | 6 | import numpy as np 7 | import sys 8 | 9 | def trajectory_distances(poses): 10 | distances = [] 11 | distances.append(0) 12 | for i in range(1,len(poses)): 13 | p1 = poses[i-1] 14 | p2 = poses[i] 15 | delta = p1[0:3,3] - p2[0:3,3] 16 | distances.append(distances[i-1]+np.linalg.norm(delta)) 17 | return distances 18 | 19 | def last_frame_from_segment_length(dist,first_frame,length): 20 | for i in range(first_frame,len(dist)): 21 | if dist[i]>dist[first_frame]+length: 22 | return i 23 | return -1 24 | 25 | def rotation_error(pose_error): 26 | a = pose_error[0,0] 27 | b = pose_error[1,1] 28 | c = pose_error[2,2] 29 | d = 0.5*(a+b+c-1) 30 | rot_error = np.arccos(max(min(d,1.0),-1.0)) 31 | return rot_error 32 | 33 | def translation_error(pose_error): 34 | dx = pose_error[0,3] 35 | dy = pose_error[1,3] 36 | dz = pose_error[2,3] 37 | return np.sqrt(dx*dx+dy*dy+dz*dz) 38 | 39 | # def line2matrix(pose_line): 40 | # pose_line = np.array(pose_line) 41 | # pose_m = np.matrix(np.eye(4)) 42 | # pose_m[0:3,:] = pose_line.reshape(3,4) 43 | # return pose_m 44 | 45 | def calculate_sequence_error(poses_gt,poses_result,lengths=[10,20,30,40,50,60,70,80]): 46 | # error vector 47 | errors = [] 48 | 49 | # parameters 50 | step_size = 1 # 10 = every second 51 | num_lengths = len(lengths) 52 | 53 | # import ipdb;ipdb.set_trace() 54 | # pre-compute distances (from ground truth as reference) 55 | dist = trajectory_distances(poses_gt) 56 | # for all start positions do 57 | for first_frame in range(0, len(poses_gt), step_size): 58 | # for all segment lengths do 59 | for i in range(0,num_lengths): 60 | # current length 61 | length = lengths[i] 62 | 63 | # compute last frame 64 | last_frame = last_frame_from_segment_length(dist,first_frame,length) 65 | # continue, if sequence not long enough 66 | if last_frame == -1: 67 | continue 68 | 69 | # compute rotational and translational errors 70 | pose_delta_gt = np.linalg.inv(poses_gt[first_frame]).dot(poses_gt[last_frame]) 71 | pose_delta_result = np.linalg.inv(poses_result[first_frame]).dot(poses_result[last_frame]) 72 | pose_error = np.linalg.inv(pose_delta_result).dot(pose_delta_gt) 73 | r_err = rotation_error(pose_error) 74 | t_err = translation_error(pose_error) 75 | 76 | # compute speed 77 | num_frames = 
float(last_frame-first_frame+1) 78 | speed = length/(0.1*num_frames) 79 | 80 | # record the error for this segment 81 | error = [first_frame,r_err/length,t_err/length,length,speed] 82 | errors.append(error) 83 | # return error vector 84 | return errors 85 | 86 | def calculate_ave_errors(errors,lengths=[10,20,30,40,50,60,70,80]): 87 | rot_errors=[] 88 | tra_errors=[] 89 | for length in lengths: 90 | rot_error_each_length =[] 91 | tra_error_each_length =[] 92 | for error in errors: 93 | if abs(error[3]-length)<0.1: 94 | rot_error_each_length.append(error[1]) 95 | tra_error_each_length.append(error[2]) 96 | 97 | if len(rot_error_each_length)==0: 98 | # import ipdb;ipdb.set_trace() 99 | continue 100 | else: 101 | rot_errors.append(sum(rot_error_each_length)/len(rot_error_each_length)) 102 | tra_errors.append(sum(tra_error_each_length)/len(tra_error_each_length)) 103 | return np.array(rot_errors)*180/np.pi, tra_errors 104 | 105 | def evaluate(gt, data, rescale_=False): 106 | lens = [5,10,15,20,25,30,35,40] #[1,2,3,4,5,6] # 107 | errors = calculate_sequence_error(gt, data, lengths=lens) 108 | rot,tra = calculate_ave_errors(errors, lengths=lens) 109 | return np.mean(rot), np.mean(tra) 110 | 111 | def main(): 112 | # usage: python evaluate_kitti.py path_to_ground_truth path_to_predicted_pose 113 | # load and preprocess data 114 | ground_truth_data = np.loadtxt(sys.argv[1]) 115 | predict_pose_data = np.loadtxt(sys.argv[2]) 116 | errors = calculate_sequence_error(ground_truth_data,predict_pose_data) 117 | rot,tra = calculate_ave_errors(errors) 118 | print(rot,'\n',tra) 119 | #print(error) 120 | # evaluate the VO result 121 | # save and visualize the evaluation result 122 | 123 | if __name__ == "__main__": 124 | main() 125 | 126 | 127 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/evaluation/evaluator_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang 2 | # For License information please see the LICENSE file in the root directory.
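# ATEEvaluator aligns the estimated trajectory to the ground-truth frame with the closed-form Horn method (align() in evaluate_ate_scale.py), optionally estimating a global scale, and reports the RMSE of the per-frame translational error; RPEEvaluator and KittiEvaluator wrap evaluate_rpe.evaluate_trajectory and evaluate_kitti.evaluate on the 4x4 pose matrices produced by quats2SEs().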
3 | 4 | import numpy as np 5 | from .trajectory_transform import trajectory_transform, rescale 6 | from .transformation import pos_quats2SE_matrices, SE2pos_quat 7 | 8 | 9 | np.set_printoptions(suppress=True, precision=2, threshold=100000) 10 | 11 | def transform_trajs(gt_traj, est_traj, cal_scale): 12 | gt_traj, est_traj = trajectory_transform(gt_traj, est_traj) 13 | if cal_scale : 14 | est_traj, s = rescale(gt_traj, est_traj) 15 | print(' Scale, {}'.format(s)) 16 | else: 17 | s = 1.0 18 | return gt_traj, est_traj, s 19 | 20 | def quats2SEs(gt_traj, est_traj): 21 | gt_SEs = pos_quats2SE_matrices(gt_traj) 22 | est_SEs = pos_quats2SE_matrices(est_traj) 23 | return gt_SEs, est_SEs 24 | 25 | from .evaluate_ate_scale import align, plot_traj 26 | 27 | 28 | class ATEEvaluator(object): 29 | def __init__(self): 30 | super(ATEEvaluator, self).__init__() 31 | 32 | 33 | def evaluate(self, gt_traj, est_traj, scale): 34 | gt_xyz = np.matrix(gt_traj[:,0:3].transpose()) 35 | est_xyz = np.matrix(est_traj[:, 0:3].transpose()) 36 | 37 | rot, trans, trans_error, s = align(gt_xyz, est_xyz, scale) 38 | print(' ATE scale: {}'.format(s)) 39 | error = np.sqrt(np.dot(trans_error,trans_error) / len(trans_error)) 40 | 41 | # align two trajs 42 | est_SEs = pos_quats2SE_matrices(est_traj) 43 | T = np.eye(4) 44 | T[:3,:3] = rot 45 | T[:3,3:] = trans 46 | T = np.linalg.inv(T) 47 | est_traj_aligned = [] 48 | for se in est_SEs: 49 | se[:3,3] = se[:3,3] * s 50 | se_new = T.dot(se) 51 | se_new = SE2pos_quat(se_new) 52 | est_traj_aligned.append(se_new) 53 | 54 | 55 | return error, gt_traj, est_traj_aligned 56 | 57 | # ======================= 58 | 59 | from .evaluate_rpe import evaluate_trajectory 60 | 61 | class RPEEvaluator(object): 62 | def __init__(self): 63 | super(RPEEvaluator, self).__init__() 64 | 65 | 66 | def evaluate(self, gt_SEs, est_SEs): 67 | result = evaluate_trajectory(gt_SEs, est_SEs) 68 | 69 | trans_error = np.array(result)[:,2] 70 | rot_error = np.array(result)[:,3] 71 | 72 | trans_error_mean = np.mean(trans_error) 73 | rot_error_mean = np.mean(rot_error) 74 | 75 | # import ipdb;ipdb.set_trace() 76 | 77 | return (rot_error_mean, trans_error_mean) 78 | 79 | # ======================= 80 | 81 | from .evaluate_kitti import evaluate as kittievaluate 82 | 83 | class KittiEvaluator(object): 84 | def __init__(self): 85 | super(KittiEvaluator, self).__init__() 86 | 87 | # return rot_error, tra_error 88 | def evaluate(self, gt_SEs, est_SEs): 89 | # trajectory_scale(est_SEs, 0.831984631412) 90 | error = kittievaluate(gt_SEs, est_SEs) 91 | return error 92 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/evaluation/tartanair_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Carnegie Mellon University, Wenshan Wang 2 | # For License information please see the LICENSE file in the root directory. 
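# TartanAirEvaluator.evaluate_one_trajectory expects gt_traj and est_traj as Nx7 numpy arrays in the 'tx ty tz qx qy qz qw' format described in data_type.md (see the shape checks below); pass scale=True for monocular tracks so a global scale is estimated during ATE alignment. It returns a dict with 'ate_score', 'rpe_score' and 'kitti_score', and saves a top-down trajectory plot under figures/.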
3 | 4 | import numpy as np 5 | from os.path import isdir, isfile 6 | 7 | from .evaluator_base import ATEEvaluator, RPEEvaluator, KittiEvaluator, transform_trajs, quats2SEs 8 | 9 | # from trajectory_transform import timestamp_associate 10 | 11 | 12 | def plot_traj(gtposes, estposes, vis=False, savefigname=None, title=''): 13 | import matplotlib.pyplot as plt 14 | fig = plt.figure(figsize=(4,4)) 15 | 16 | 17 | cm = plt.cm.get_cmap('Spectral') 18 | 19 | plt.subplot(111) 20 | plt.plot(gtposes[:,2],gtposes[:,0], linestyle='dashed',c='k') 21 | plt.plot(estposes[:, 2], estposes[:, 0],c='#ff7f0e') 22 | plt.xlabel('x (m)') 23 | plt.ylabel('y (m)') 24 | plt.legend(['Ground Truth','Ours']) 25 | plt.title(title) 26 | 27 | plt.axis('equal') 28 | 29 | if savefigname is not None: 30 | plt.savefig(savefigname) 31 | 32 | if vis: 33 | plt.show() 34 | 35 | plt.close(fig) 36 | 37 | 38 | # 39 | 40 | class TartanAirEvaluator: 41 | def __init__(self, scale = False, round=1): 42 | self.ate_eval = ATEEvaluator() 43 | self.rpe_eval = RPEEvaluator() 44 | self.kitti_eval = KittiEvaluator() 45 | 46 | def evaluate_one_trajectory(self, gt_traj, est_traj, scale=False, title=''): 47 | """ 48 | scale = True: calculate a global scale 49 | """ 50 | 51 | if gt_traj.shape[0] != est_traj.shape[0]: 52 | raise Exception("POSEFILE_LENGTH_ILLEGAL") 53 | 54 | if gt_traj.shape[1] != 7 or est_traj.shape[1] != 7: 55 | raise Exception("POSEFILE_FORMAT_ILLEGAL") 56 | 57 | gt_traj = gt_traj.astype(np.float64) 58 | est_traj = est_traj.astype(np.float64) 59 | 60 | ate_score, gt_ate_aligned, est_ate_aligned = self.ate_eval.evaluate(gt_traj, est_traj, scale) 61 | 62 | plot_traj(np.matrix(gt_ate_aligned), np.matrix(est_ate_aligned), vis=False, savefigname="figures/%s.pdf"%title, title=title) 63 | 64 | est_ate_aligned = np.array(est_ate_aligned) 65 | gt_SEs, est_SEs = quats2SEs(gt_ate_aligned, est_ate_aligned) 66 | 67 | 68 | 69 | rpe_score = self.rpe_eval.evaluate(gt_SEs, est_SEs) 70 | kitti_score = self.kitti_eval.evaluate(gt_SEs, est_SEs) 71 | 72 | return {'ate_score': ate_score, 'rpe_score': rpe_score, 'kitti_score': kitti_score} 73 | 74 | 75 | if __name__ == "__main__": 76 | 77 | # scale = True for monocular track, scale = False for stereo track 78 | aicrowd_evaluator = TartanAirEvaluator() 79 | result = aicrowd_evaluator.evaluate_one_trajectory('pose_gt.txt', 'pose_est.txt', scale=True) 80 | print(result) 81 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/thirdparty/tartanair_tools/seg_rgbs.txt: -------------------------------------------------------------------------------- 1 | 0 0 0 2 | 153 108 6 3 | 112 105 191 4 | 89 121 72 5 | 190 225 64 6 | 206 190 59 7 | 81 13 36 8 | 115 176 195 9 | 161 171 27 10 | 135 169 180 11 | 29 26 199 12 | 102 16 239 13 | 242 107 146 14 | 156 198 23 15 | 49 89 160 16 | 68 218 116 17 | 11 236 9 18 | 196 30 8 19 | 121 67 28 20 | 0 53 65 21 | 146 52 70 22 | 226 149 143 23 | 151 126 171 24 | 194 39 7 25 | 205 120 161 26 | 212 51 60 27 | 211 80 208 28 | 189 135 188 29 | 54 72 205 30 | 103 252 157 31 | 124 21 123 32 | 19 132 69 33 | 195 237 132 34 | 94 253 175 35 | 182 251 87 36 | 90 162 242 37 | 199 29 1 38 | 254 12 229 39 | 35 196 244 40 | 220 163 49 41 | 86 254 214 42 | 152 3 129 43 | 92 31 106 44 | 207 229 90 45 | 125 75 48 46 | 98 55 74 47 | 126 129 238 48 | 222 153 109 49 | 85 152 34 50 | 173 69 31 51 | 37 128 125 52 | 58 19 33 53 | 134 57 119 54 | 218 124 115 55 | 120 0 200 56 | 225 131 92 57 | 246 90 16 58 | 51 155 241 59 | 202 97 155 60 | 184 145 
182 61 | 96 232 44 62 | 133 244 133 63 | 180 191 29 64 | 1 222 192 65 | 99 242 104 66 | 91 168 219 67 | 65 54 217 68 | 148 66 130 69 | 203 102 204 70 | 216 78 75 71 | 234 20 250 72 | 109 206 24 73 | 164 194 17 74 | 157 23 236 75 | 158 114 88 76 | 245 22 110 77 | 67 17 35 78 | 181 213 93 79 | 170 179 42 80 | 52 187 148 81 | 247 200 111 82 | 25 62 174 83 | 100 25 240 84 | 191 195 144 85 | 252 36 67 86 | 241 77 149 87 | 237 33 141 88 | 119 230 85 89 | 28 34 108 90 | 78 98 254 91 | 114 161 30 92 | 75 50 243 93 | 66 226 253 94 | 46 104 76 95 | 8 234 216 96 | 15 241 102 97 | 93 14 71 98 | 192 255 193 99 | 253 41 164 100 | 24 175 120 101 | 185 243 231 102 | 169 233 97 103 | 243 215 145 104 | 72 137 21 105 | 160 113 101 106 | 214 92 13 107 | 167 140 147 108 | 101 109 181 109 | 53 118 126 110 | 3 177 32 111 | 40 63 99 112 | 186 139 153 113 | 88 207 100 114 | 71 146 227 115 | 236 38 187 116 | 215 4 215 117 | 18 211 66 118 | 113 49 134 119 | 47 42 63 120 | 219 103 127 121 | 57 240 137 122 | 227 133 211 123 | 145 71 201 124 | 217 173 183 125 | 250 40 113 126 | 208 125 68 127 | 224 186 249 128 | 69 148 46 129 | 239 85 20 130 | 108 116 224 131 | 56 214 26 132 | 179 147 43 133 | 48 188 172 134 | 221 83 47 135 | 155 166 218 136 | 62 217 189 137 | 198 180 122 138 | 201 144 169 139 | 132 2 14 140 | 128 189 114 141 | 163 227 112 142 | 45 157 177 143 | 64 86 142 144 | 118 193 163 145 | 14 32 79 146 | 200 45 170 147 | 74 81 2 148 | 59 37 212 149 | 73 35 225 150 | 95 224 39 151 | 84 170 220 152 | 159 58 173 153 | 17 91 237 154 | 31 95 84 155 | 34 201 248 156 | 63 73 209 157 | 129 235 107 158 | 231 115 40 159 | 36 74 95 160 | 238 228 154 161 | 61 212 54 162 | 13 94 165 163 | 141 174 0 164 | 140 167 255 165 | 117 93 91 166 | 183 10 186 167 | 165 28 61 168 | 144 238 194 169 | 12 158 41 170 | 76 110 234 171 | 150 9 121 172 | 142 1 246 173 | 230 136 198 174 | 5 60 233 175 | 232 250 80 176 | 143 112 56 177 | 187 70 156 178 | 2 185 62 179 | 138 223 226 180 | 122 183 222 181 | 166 245 3 182 | 175 6 140 183 | 240 59 210 184 | 248 44 10 185 | 83 82 52 186 | 223 248 167 187 | 87 15 150 188 | 111 178 117 189 | 197 84 22 190 | 235 208 124 191 | 9 76 45 192 | 176 24 50 193 | 154 159 251 194 | 149 111 207 195 | 168 231 15 196 | 209 247 202 197 | 80 205 152 198 | 178 221 213 199 | 27 8 38 200 | 244 117 51 201 | 107 68 190 202 | 23 199 139 203 | 171 88 168 204 | 136 202 58 205 | 6 46 86 206 | 105 127 176 207 | 174 249 197 208 | 172 172 138 209 | 228 142 81 210 | 7 204 185 211 | 22 61 247 212 | 233 100 78 213 | 127 65 105 214 | 33 87 158 215 | 139 156 252 216 | 42 7 136 217 | 20 99 179 218 | 79 150 223 219 | 131 182 184 220 | 110 123 37 221 | 60 138 96 222 | 210 96 94 223 | 123 48 18 224 | 137 197 162 225 | 188 18 5 226 | 39 219 151 227 | 204 143 135 228 | 249 79 73 229 | 77 64 178 230 | 41 246 77 231 | 16 154 4 232 | 116 134 19 233 | 4 122 235 234 | 177 106 230 235 | 21 119 12 236 | 104 5 98 237 | 50 130 53 238 | 30 192 25 239 | 26 165 166 240 | 10 160 82 241 | 106 43 131 242 | 44 216 103 243 | 255 101 221 244 | 32 151 196 245 | 213 220 89 246 | 70 209 228 247 | 97 184 83 248 | 82 239 232 249 | 251 164 128 250 | 193 11 245 251 | 38 27 159 252 | 229 141 203 253 | 130 56 55 254 | 147 210 11 255 | 162 203 118 256 | 255 255 255 -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/tools/download_sample_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data && cd data 4 | 5 | 6 | gdown 
https://drive.google.com/uc?id=1AlfhZnGmlsKWGcNHFB1i8i8Jzn4VHB15 7 | unzip abandonedfactory.zip 8 | rm abandonedfactory.zip 9 | 10 | gdown https://drive.google.com/uc?id=0B-ePgl6HF260NzQySklGdXZyQzA 11 | unzip Barn.zip 12 | rm Barn.zip 13 | 14 | wget https://www.eth3d.net/data/slam/datasets/sfm_bench_mono.zip 15 | unzip sfm_bench_mono.zip 16 | rm sfm_bench_mono.zip 17 | 18 | wget https://vision.in.tum.de/rgbd/dataset/freiburg3/rgbd_dataset_freiburg3_cabinet.tgz 19 | tar -zxvf rgbd_dataset_freiburg3_cabinet.tgz 20 | rm rgbd_dataset_freiburg3_cabinet.tgz 21 | 22 | wget http://robotics.ethz.ch/~asl-datasets/ijrr_euroc_mav_dataset/machine_hall/MH_03_medium/MH_03_medium.zip 23 | unzip MH_03_medium.zip 24 | rm MH_03_medium.zip 25 | 26 | cd .. 27 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/tools/evaluate_eth3d.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | ETH_PATH=datasets/ETH3D-SLAM/training 5 | 6 | # all "non-dark" training scenes 7 | evalset=( 8 | cables_1 9 | cables_2 10 | cables_3 11 | camera_shake_1 12 | camera_shake_2 13 | camera_shake_3 14 | ceiling_1 15 | ceiling_2 16 | desk_3 17 | desk_changing_1 18 | einstein_1 19 | einstein_2 20 | # einstein_dark 21 | einstein_flashlight 22 | einstein_global_light_changes_1 23 | einstein_global_light_changes_2 24 | einstein_global_light_changes_3 25 | kidnap_1 26 | # kidnap_dark 27 | large_loop_1 28 | mannequin_1 29 | mannequin_3 30 | mannequin_4 31 | mannequin_5 32 | mannequin_7 33 | mannequin_face_1 34 | mannequin_face_2 35 | mannequin_face_3 36 | mannequin_head 37 | motion_1 38 | planar_2 39 | planar_3 40 | plant_1 41 | plant_2 42 | plant_3 43 | plant_4 44 | plant_5 45 | # plant_dark 46 | plant_scene_1 47 | plant_scene_2 48 | plant_scene_3 49 | reflective_1 50 | repetitive 51 | sfm_bench 52 | sfm_garden 53 | sfm_house_loop 54 | sfm_lab_room_1 55 | sfm_lab_room_2 56 | sofa_1 57 | sofa_2 58 | sofa_3 59 | sofa_4 60 | # sofa_dark_1 61 | # sofa_dark_2 62 | # sofa_dark_3 63 | sofa_shake 64 | table_3 65 | table_4 66 | table_7 67 | vicon_light_1 68 | vicon_light_2 69 | ) 70 | 71 | for seq in ${evalset[@]}; do 72 | python evaluation_scripts/test_eth3d.py --datapath=$ETH_PATH/$seq --weights=droid.pth --disable_vis $@ 73 | done 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/tools/evaluate_euroc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | EUROC_PATH=datasets/EuRoC 5 | 6 | evalset=( 7 | MH_01_easy 8 | MH_02_easy 9 | MH_03_medium 10 | MH_04_difficult 11 | MH_05_difficult 12 | V1_01_easy 13 | V1_02_medium 14 | V1_03_difficult 15 | V2_01_easy 16 | V2_02_medium 17 | V2_03_difficult 18 | ) 19 | 20 | for seq in ${evalset[@]}; do 21 | python evaluation_scripts/test_euroc.py --datapath=$EUROC_PATH/$seq --gt=data/euroc_groundtruth/$seq.txt --weights=droid.pth $@ 22 | done 23 | 24 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/tools/evaluate_tum.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | TUM_PATH=datasets/TUM-RGBD/$seq 5 | 6 | evalset=( 7 | rgbd_dataset_freiburg1_360 8 | rgbd_dataset_freiburg1_desk 9 | rgbd_dataset_freiburg1_desk2 10 | rgbd_dataset_freiburg1_floor 11 | rgbd_dataset_freiburg1_plant 12 | rgbd_dataset_freiburg1_room 13 | rgbd_dataset_freiburg1_rpy 14 | 
rgbd_dataset_freiburg1_teddy 15 | rgbd_dataset_freiburg1_xyz 16 | ) 17 | 18 | for seq in ${evalset[@]}; do 19 | python evaluation_scripts/test_tum.py --datapath=$TUM_PATH/$seq --weights=droid.pth --disable_vis $@ 20 | done 21 | 22 | -------------------------------------------------------------------------------- /thirdparty/DROID-SLAM/tools/validate_tartanair.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | TARTANAIR_PATH=datasets/TartanAir 5 | 6 | python evaluation_scripts/validate_tartanair.py --datapath=$TARTANAIR_PATH --weights=droid.pth --disable_vis $@ 7 | 8 | -------------------------------------------------------------------------------- /thirdparty/camcalib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yufu-wang/tram/1714d96fa1da8011d506c26fa6c74a3dc27d1af8/thirdparty/camcalib/__init__.py -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pprint 4 | import random 5 | import numpy as np 6 | import torch.backends.cudnn as cudnn 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | from lib.core.config import parse_args 11 | from lib.core.losses import compile_criterion 12 | from lib.utils.utils import prepare_output_dir, create_logger 13 | from lib.trainer import Trainer 14 | 15 | from lib.get_videoloader import get_dataloaders 16 | from lib.models.hmr_vimo import HMR_VIMO 17 | 18 | 19 | def main(cfg): 20 | if cfg.SEED_VALUE >= 0: 21 | os.environ['PYTHONHASHSEED'] = str(cfg.SEED_VALUE) 22 | random.seed(cfg.SEED_VALUE) 23 | torch.manual_seed(cfg.SEED_VALUE) 24 | np.random.seed(cfg.SEED_VALUE) 25 | 26 | # create logger 27 | logger = create_logger(cfg.LOGDIR, phase='train') 28 | logger.info(f'GPU name -> {torch.cuda.get_device_name()}') 29 | logger.info(f'GPU feat -> {torch.cuda.get_device_properties("cuda")}') 30 | logger.info(pprint.pformat(cfg)) 31 | 32 | # cudnn related setting 33 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 34 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 35 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 36 | 37 | writer = SummaryWriter(log_dir=cfg.LOGDIR) 38 | writer.add_text('config', pprint.pformat(cfg), 0) 39 | 40 | # Dataloaders 41 | data_loaders = get_dataloaders(cfg) 42 | 43 | # Compile Loss 44 | criterion = compile_criterion(cfg) 45 | 46 | # Networks and optimizers 47 | model = HMR_VIMO(cfg=cfg) 48 | checkpoint = cfg.MODEL.CHECKPOINT 49 | state_dict = torch.load(checkpoint, map_location=cfg.DEVICE, weights_only=True) 50 | _ = model.load_state_dict(state_dict['state_dict'], strict=False) 51 | 52 | model = model.to(cfg.DEVICE) 53 | model.frozen_modules = [model.backbone] 54 | model.freeze_modules() 55 | 56 | logger.info(f'Loaded pretrained checkpoint {checkpoint}') 57 | logger.info(f'Freeze pretrained backbone') 58 | 59 | if cfg.TRAIN.MULTI_LR: 60 | params = [{'params': [p for p in model.smpl_head.parameters() if p.requires_grad]}] 61 | 62 | if cfg.MODEL.MOTION_MODULE: 63 | params.append({'params': [p for p in model.motion_module.parameters() if p.requires_grad], 64 | 'lr':cfg.TRAIN.LR2}) 65 | 66 | if cfg.MODEL.ST_MODULE: 67 | params.append({'params': [p for p in model.st_module.parameters() if p.requires_grad], 68 | 'lr':cfg.TRAIN.LR2}) 69 | 70 | optimizer = torch.optim.AdamW(params, lr=cfg.TRAIN.LR, 
weight_decay=cfg.TRAIN.WD) 71 | 72 | logger.info(f'Using multiple learning rates:[{cfg.TRAIN.LR}, {cfg.TRAIN.LR2}] and WD: {cfg.TRAIN.WD}') 73 | else: 74 | optimizer = torch.optim.AdamW(params=[p for p in model.parameters() if p.requires_grad], 75 | lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WD) 76 | 77 | 78 | # ========= Start Training ========= # 79 | Trainer( 80 | cfg=cfg, 81 | data_loaders=data_loaders, 82 | model=model, 83 | criterion=criterion, 84 | optimizer=optimizer, 85 | writer=writer, 86 | lr_scheduler=None, 87 | ).train() 88 | 89 | 90 | 91 | if __name__ == '__main__': 92 | cfg = parse_args() 93 | cfg = prepare_output_dir(cfg) 94 | 95 | main(cfg) 96 | --------------------------------------------------------------------------------
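For reference, the multi-learning-rate setup in train.py amounts to passing per-module parameter groups to torch.optim.AdamW, with the frozen backbone excluded via requires_grad filtering. Below is a minimal, self-contained sketch of that pattern; the toy module names and the learning-rate/weight-decay values are illustrative placeholders, not the actual HMR_VIMO structure or the repo's config defaults.

```
import torch
import torch.nn as nn

# Toy stand-in for a model with a frozen backbone and two trainable heads
# (the real HMR_VIMO uses smpl_head / motion_module / st_module analogously).
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)     # frozen, left out of the optimizer
        self.smpl_head = nn.Linear(16, 8)     # trained with the base learning rate
        self.motion_module = nn.Linear(8, 8)  # trained with a second learning rate

model = ToyModel()
for p in model.backbone.parameters():
    p.requires_grad = False  # mimics model.freeze_modules()

base_lr, lr2, wd = 1e-5, 3e-5, 1e-2  # placeholder values

param_groups = [
    {'params': [p for p in model.smpl_head.parameters() if p.requires_grad]},                 # uses base_lr
    {'params': [p for p in model.motion_module.parameters() if p.requires_grad], 'lr': lr2},  # own lr
]
optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=wd)

# Groups without an explicit 'lr' inherit the optimizer-level base_lr.
for group in optimizer.param_groups:
    print(group['lr'], len(group['params']))
```

Filtering on requires_grad keeps the frozen backbone parameters out of the optimizer entirely, which mirrors what the training script above does in both its single- and multi-learning-rate branches.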