├── README 2.md ├── README.md ├── assets └── framework.jpg ├── compare.py ├── concat.py ├── configs ├── finetune_ggrt_stable.yaml ├── pixelsplat │ ├── decoder │ │ └── splatting_cuda.yaml │ └── encoder │ │ ├── backbone │ │ ├── dino.yaml │ │ └── resnet.yaml │ │ └── epipolar.yaml └── pretrain_ggrt_stable.yaml ├── eval ├── __init__.py ├── dbarf_compute_poses.py ├── eval.py ├── eval_abs_pose_accuracy.py ├── eval_dbarf.py ├── eval_deepvoxels.sh ├── eval_ggrt.py ├── eval_llff_all.sh ├── eval_nerf_synthetic_all.sh ├── finetune_dbarf_llff.sh ├── finetune_dbarf_scannet.sh ├── finetune_llff.sh ├── render_dbarf_llff_video.py ├── render_llff.sh └── render_llff_video.py ├── eval_crop.py ├── eval_crop_final.py ├── finetune_ggrt_stable.py ├── ggrt ├── __init__.py ├── base │ ├── checkpoint_manager.py │ ├── functools.py │ ├── model_base.py │ └── trainer.py ├── config.py ├── data_loaders │ ├── __init__.py │ ├── base_utils.py │ ├── colmap_read_model.py │ ├── create_training_dataset.py │ ├── data_utils.py │ ├── data_verifier.py │ ├── deepvoxels.py │ ├── google_scanned_objects.py │ ├── ibrnet_collected.py │ ├── kitti.py │ ├── llff.py │ ├── llff_data_utils.py │ ├── llff_test.py │ ├── nerf_synthetic.py │ ├── realestate.py │ ├── scannet.py │ ├── spaces_dataset.py │ └── waymo.py ├── dataset │ ├── __init__.py │ ├── data_module.py │ ├── dataset.py │ ├── shims │ │ ├── augmentation_shim.py │ │ ├── bounds_shim.py │ │ ├── crop_shim.py │ │ └── patch_shim.py │ ├── types.py │ ├── validation_wrapper.py │ └── view_sampler │ │ ├── __init__.py │ │ ├── view_sampler.py │ │ ├── view_sampler_all.py │ │ ├── view_sampler_arbitrary.py │ │ ├── view_sampler_bounded.py │ │ └── view_sampler_evaluation.py ├── depth_pose_network.py ├── geometry │ ├── align_poses.py │ ├── camera.py │ ├── depth.py │ ├── epipolar_lines.py │ ├── lie_group │ │ ├── __init__.py │ │ ├── liegroupbase.py │ │ ├── se3.py │ │ ├── se3_common.py │ │ ├── se3q.py │ │ ├── so3.py │ │ ├── so3_common.py │ │ ├── so3q.py │ │ └── utils.py │ ├── projection.py │ ├── rotation.py │ ├── track.py │ └── utils.py ├── global_cfg.py ├── hack_torch │ └── custom_grid.py ├── loss │ ├── criterion.py │ ├── photometric_loss.py │ └── ssim_torch.py ├── misc │ ├── LocalLogger.py │ ├── benchmarker.py │ ├── collation.py │ ├── discrete_probability_distribution.py │ ├── heterogeneous_pairings.py │ ├── image_io.py │ ├── nn_module_tools.py │ ├── sh_rotation.py │ ├── step_tracker.py │ └── wandb_tools.py ├── model │ ├── __init__.py │ ├── barf.py │ ├── dbarf.py │ ├── dgaussian.py │ ├── feature_network.py │ ├── gaussian.py │ ├── ibrnet.py │ ├── mlp_network.py │ ├── nerf.py │ └── pixelsplat │ │ ├── decoder │ │ ├── __init__.py │ │ ├── cuda_splatting.py │ │ ├── decoder.py │ │ └── decoder_splatting_cuda.py │ │ ├── encoder │ │ ├── __init__.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── backbone.py │ │ │ ├── backbone_dino.py │ │ │ └── backbone_resnet.py │ │ ├── common │ │ │ ├── depth_predictor.py │ │ │ ├── gaussian_adapter.py │ │ │ ├── gaussians.py │ │ │ └── sampler.py │ │ ├── encoder.py │ │ ├── encoder_epipolar.py │ │ ├── epipolar │ │ │ ├── conversions.py │ │ │ ├── depth_predictor_monocular.py │ │ │ ├── distribution.py │ │ │ ├── distribution_sampler.py │ │ │ ├── epipolar_sampler.py │ │ │ ├── epipolar_transformer.py │ │ │ └── image_self_attention.py │ │ └── visualization │ │ │ ├── encoder_visualizer.py │ │ │ ├── encoder_visualizer_epipolar.py │ │ │ └── encoder_visualizer_epipolar_cfg.py │ │ ├── encodings │ │ └── positional_encoding.py │ │ ├── interpolatation.py │ │ ├── pixelsplat.py │ │ ├── pixelsplat_crop.py 
│ │ ├── ply_export.py │ │ ├── transformer │ │ ├── attention.py │ │ ├── feed_forward.py │ │ ├── pre_norm.py │ │ └── transformer.py │ │ ├── types.py │ │ └── wobble.py ├── optimizer.py ├── pose_util.py ├── projection.py ├── render_image.py ├── render_ray.py ├── sample_ray.py ├── utils │ ├── read_colmap_model.py │ └── union_find.py └── visualization │ ├── annotation.py │ ├── camera_trajectory │ ├── interpolation.py │ ├── spin.py │ └── wobble.py │ ├── color_map.py │ ├── colors.py │ ├── drawing │ ├── cameras.py │ ├── coordinate_conversion.py │ ├── lines.py │ ├── points.py │ ├── rendering.py │ └── types.py │ ├── feature_visualizer.py │ ├── layout.py │ ├── pose_visualizer.py │ └── validation_in_3d.py ├── scripts ├── __init__.py ├── colmap_model_to_poses_bounds.py ├── env │ └── dependencies.sh ├── extract_features.py ├── extract_relative_poses.py ├── filter_matches.py ├── match_features.py ├── pairs_from_retrieval.py ├── preprocess_dbarf_dataset.py ├── reconstruction.py ├── shell │ ├── eval_coarse_llff_all.sh │ ├── eval_coarse_nerf_synthetic_all.sh │ ├── eval_coarse_scannet.sh │ ├── eval_dbarf_ibr_collected_all.sh │ ├── eval_dbarf_llff_all.sh │ ├── eval_dbarf_scannet.sh │ ├── eval_llff_all.sh │ ├── render_coarse_llff_all.sh │ ├── render_dbarf_llff_all.sh │ ├── train_coarse_ibrnet.sh │ ├── train_dbarf.sh │ └── train_ibrnet.sh └── utils.py ├── train_ggrt_stable.py └── utils_loc.py /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifuguan/GGRt_official/01886261b6b6b6175b6ea88f44a85c640564ae9f/assets/framework.jpg -------------------------------------------------------------------------------- /compare.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import math 4 | TINY_NUMBER = 1e-6 5 | mse2psnr = lambda x: -10. * np.log(x+TINY_NUMBER) / np.log(10.) 
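# Editor's note (illustrative, not part of the original file): mse2psnr above
# implements PSNR = -10 * log10(MSE + eps) for images scaled to [0, 1]. For
# example, an MSE of 0.01 maps to -10 * log10(0.01) = 20 dB and 0.001 to 30 dB;
# TINY_NUMBER only guards against log(0).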
6 | 
7 | def img2mse(x, y, mask=None):
8 |     '''
9 |     :param x: img 1, [(...), 3]
10 |     :param y: img 2, [(...), 3]
11 |     :param mask: optional, [(...)]
12 |     :return: mse score
13 |     '''
14 | 
15 |     if mask is None:
16 |         return np.mean((x - y) * (x - y))
17 |     else:
18 |         return np.sum((x - y) * (x - y) * mask[..., None]) / (np.sum(mask) * x.shape[-1] + TINY_NUMBER)
19 | 
20 | def img2psnr(x, y, mask=None):
21 |     return mse2psnr(img2mse(x, y, mask).item())
22 | 
23 | def psnr(img1, img2):
24 |     psnr=img2psnr(img1,img2)
25 |     # mse = numpy.mean( (img1 - img2) ** 2 )
26 |     # if mse == 0:
27 |     #     return 100
28 |     # PIXEL_MAX = 255.0
29 |     return psnr
30 | def compare(root):
31 |     pics=['pic_006','pic_419','pic_008','pic_241','pic_168','pic_113']
32 |     t=0
33 |     i=0
34 |     for pic in pics:
35 |         img1=Image.open(f'{root}/{pic}/test.png')
36 |         img2=Image.open(f'{root}/{pic}/gt.png')
37 |         i=i+1
38 |         i1_array = np.array(img1)
39 |         i2_array = np.array(img2)
40 |         img1 = i1_array.astype(np.float32)
41 |         img2 = i2_array.astype(np.float32)
42 |         img1=img1/255
43 |         img2=img2/255
44 |         img1=(np.clip(img1[None, ...], a_min=0., a_max=1.))
45 | 
46 |         r12=psnr(img1,img2)
47 |         t=t+r12
48 |         print(r12)
49 |     print(f'mean:{t/i}')
--------------------------------------------------------------------------------
/concat.py:
--------------------------------------------------------------------------------
1 | from logging import root
2 | from math import ceil
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | def concat(root_1,file_id,crop_h,crop_w):
7 |     # images = ['00.png','01.png','02.png','10.png','11.png','12.png','20.png','21.png','22.png']
8 |     img=''
9 |     img_array=''
10 |     row=ceil(378/crop_h)
11 |     col=ceil(504/crop_w)
12 |     images =[]
13 |     for i in range(row):
14 |         for j in range(col):
15 |             images.append(f'{i}{j}.png')
16 |     # pic=['pic_168','pic_008','pic_006','pic_113','pic_419','pic_241']
17 |     # for im in pic:
18 |     root=f'{root_1}/pic_'+file_id[-3:]
19 |     for index,value in enumerate(images):
20 |         image = os.path.join(root,value)
21 |         if index==0:
22 |             img_array = np.array(Image.open(image))
23 |         elif index==1:
24 |             img_array01 = np.array(Image.open(image))
25 |             img_array = np.concatenate((img_array,img_array01),axis=1)  # concatenate horizontally
26 |             #img_array = np.concatenate((img_array,img_array2),axis=0)  # concatenate vertically
27 |         elif index==2:
28 |             img_array02 = np.array(Image.open(image))
29 |             img_array = np.concatenate((img_array,img_array02[:,crop_w*2-496:,:]),axis=1)  # concatenate horizontally
30 |         elif index==3:
31 |             img_array1=np.array(Image.open(image))
32 |         elif index==4:
33 |             img_array11 = np.array(Image.open(image))
34 |             img_array1 = np.concatenate((img_array1,img_array11),axis=1)
35 |         elif index==5:
36 |             img_array12 = np.array(Image.open(image))
37 |             img_array1 = np.concatenate((img_array1,img_array12[:,crop_w*2-496:,:]),axis=1)
38 |         elif index==6:
39 |             img_array2=np.array(Image.open(image))[crop_h*2-368:,:,:]
40 |         elif index==7:
41 |             img_array21 = np.array(Image.open(image))[crop_h*2-368:,:,:]
42 |             img_array2 = np.concatenate((img_array2,img_array21),axis=1)
43 |         elif index==8:
44 |             img_array22 = np.array(Image.open(image))[crop_h*2-368:,:,:]
45 |             img_array2 = np.concatenate((img_array2,img_array22[:,crop_w*2-496:,:]),axis=1)
46 |     img_array= np.concatenate((img_array,img_array1,img_array2),axis=0)  # stack the three rows vertically
47 | 
48 |     img = Image.fromarray(img_array)
49 | 
50 |     img.save(f'{root_1}/pic_{file_id[-3:]}/test.png')  # save the stitched image as PNG
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     root_1=''
55 |     file_id=''
56 |     crop_h=160
57 |     crop_w=224
58 |     concat(root_1,file_id,crop_h,crop_w)
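Editor's note: compare.py and concat.py above are meant to be chained, so a minimal driver sketch follows (not part of the repository; evaluate_crops, the output root 'out/eval_crops', and the crop sizes are illustrative assumptions based on the hard-coded {root}/pic_XXX/{ij}.png and gt.png layout used above).

# Editor's usage sketch (assumptions as stated in the note above).
from compare import compare
from concat import concat

def evaluate_crops(root, file_ids, crop_h=160, crop_w=224):
    # Stitch each picture's grid of crop renderings into {root}/pic_XXX/test.png ...
    for file_id in file_ids:
        concat(root, file_id, crop_h, crop_w)
    # ... then print per-image PSNR against {root}/pic_XXX/gt.png and the mean.
    # Note that compare() iterates over its own hard-coded pic list.
    compare(root)

if __name__ == '__main__':
    evaluate_crops('out/eval_crops',
                   ['pic_006', 'pic_419', 'pic_008', 'pic_241', 'pic_168', 'pic_113'],
                   crop_h=160, crop_w=224)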
-------------------------------------------------------------------------------- /configs/finetune_ggrt_stable.yaml: -------------------------------------------------------------------------------- 1 | ### Model Config 2 | defaults: 3 | - pixelsplat/encoder: epipolar 4 | - pixelsplat/decoder: splatting_cuda 5 | 6 | ### INPUT 7 | expname : finetune_dgaussian_stable_room 8 | rootdir : data/ibrnet/train 9 | render_stride : 2 10 | distributed : False 11 | enable_tensorboard : True 12 | enable_visdom : False 13 | seed : 3407 14 | pretrained : False 15 | local_rank : 0 16 | 17 | ## dataset 18 | train_dataset : llff_test 19 | train_scenes : [room] 20 | dataset_weights : [1] 21 | eval_dataset : llff_test 22 | eval_scenes : [room] 23 | # eval_scenes : [trex, fern, flower, leaves, room, fortress, horns, orchids] 24 | 25 | num_source_views : 7 26 | workers : 16 27 | 28 | selection_rule : pose 29 | random_crop : False 30 | outlier_ratio : 0.2 31 | noise_level : 0.15 32 | testskip : 8 33 | 34 | no_load_opt: True 35 | no_load_scheduler: True 36 | 37 | ### TRAINING 38 | n_iters: 5000 39 | N_rand : 500 40 | lrate_feature : 0.001 41 | lrate_mlp : 0.0005 42 | lrate_pose : 0.00002 43 | lrate_decay_factor : 0.5 44 | lrate_decay_steps : 2000 45 | lrate_decay_pose_steps : 2000 46 | coarse_only : True 47 | rectify_inplane_rotation: False 48 | coarse_feat_dim : 64 # original:32 49 | fine_feat_dim : 32 # original:128 50 | anti_alias_pooling : 1 51 | 52 | use_pred_pose: False 53 | use_depth_loss: False 54 | 55 | optimizer: 56 | lr: 5e-5 57 | warm_up_steps: 500 58 | ### TESTING 59 | chunk_size : 2000 60 | 61 | ### RENDERING 62 | N_importance : 0 #64 63 | N_samples : 64 64 | inv_uniform : True 65 | white_bkgd : False 66 | sample_mode : uniform 67 | center_ratio : 0.8 68 | feat_loss_scale : 1e1 69 | crop_size : 2 70 | ### CONSOLE AND TENSORBOARD 71 | n_validation : 1000 72 | n_tensorboard : 2 73 | n_checkpoint : 500 74 | visdom_port : 9000 75 | 76 | ### evaluation options 77 | llffhold : 8 -------------------------------------------------------------------------------- /configs/pixelsplat/decoder/splatting_cuda.yaml: -------------------------------------------------------------------------------- 1 | name: splatting_cuda 2 | -------------------------------------------------------------------------------- /configs/pixelsplat/encoder/backbone/dino.yaml: -------------------------------------------------------------------------------- 1 | name: dino 2 | 3 | model: dino_vitb8 4 | d_out: 512 5 | -------------------------------------------------------------------------------- /configs/pixelsplat/encoder/backbone/resnet.yaml: -------------------------------------------------------------------------------- 1 | name: resnet 2 | 3 | model: resnet50 4 | num_layers: 5 5 | use_first_pool: false 6 | d_out: 512 7 | -------------------------------------------------------------------------------- /configs/pixelsplat/encoder/epipolar.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - backbone: dino 3 | 4 | name: epipolar 5 | 6 | opacity_mapping: 7 | initial: 0.0 8 | final: 0.0 9 | warm_up: 1 10 | 11 | num_monocular_samples: 32 12 | num_surfaces: 1 13 | predict_opacity: false 14 | near_disparity: 3.0 15 | 16 | gaussians_per_pixel: 3 17 | 18 | gaussian_adapter: 19 | gaussian_scale_min: 0.5 20 | gaussian_scale_max: 15.0 21 | sh_degree: 4 22 | 23 | d_feature: 128 24 | 25 | epipolar_transformer: 26 | self_attention: 27 | patch_size: 4 28 | num_octaves: 10 29 | num_layers: 2 30 | 
num_heads: 4 31 | d_token: 128 32 | d_dot: 128 33 | d_mlp: 256 34 | num_octaves: 10 35 | num_layers: 2 36 | num_heads: 4 37 | num_samples: 32 38 | num_context_views: 2 39 | d_dot: 128 40 | d_mlp: 256 41 | downscale: 4 42 | 43 | visualizer: 44 | num_samples: 8 45 | min_resolution: 256 46 | export_ply: false 47 | 48 | apply_bounds_shim: true 49 | 50 | # Use this to ablate the epipolar transformer. 51 | use_epipolar_transformer: true 52 | 53 | use_transmittance: false 54 | -------------------------------------------------------------------------------- /configs/pretrain_ggrt_stable.yaml: -------------------------------------------------------------------------------- 1 | ### Model Config 2 | defaults: 3 | - pixelsplat/encoder: epipolar 4 | - pixelsplat/decoder: splatting_cuda 5 | 6 | ### INPUT 7 | expname : pretrain_llff 8 | rootdir : data/ibrnet/train 9 | render_stride : 2 10 | distributed : False 11 | enable_tensorboard : True 12 | enable_visdom : False 13 | seed : 3407 14 | pretrained : False 15 | local_rank : 0 16 | ckpt_path : model_zoo/generalized_llff_best.pth 17 | 18 | ## dataset 19 | train_dataset : llff+ibrnet_collected 20 | train_scenes : [] 21 | dataset_weights : [0.5, 0.5] 22 | eval_dataset : llff_test 23 | # eval_scenes : [room] 24 | eval_scenes : [trex, fern, flower, leaves, room, fortress, horns, orchids] 25 | 26 | num_source_views : 5 27 | workers : 8 28 | 29 | selection_rule : pose 30 | random_crop : False 31 | outlier_ratio : 0.2 32 | noise_level : 0.15 33 | testskip : 8 34 | 35 | no_load_opt: True 36 | no_load_scheduler: True 37 | 38 | ### TRAINING 39 | n_iters: 6000 40 | N_rand : 500 41 | lrate_feature : 0.001 42 | lrate_mlp : 0.0005 43 | lrate_pose : 0.00002 44 | lrate_decay_factor : 0.5 45 | lrate_decay_steps : 50000 46 | lrate_decay_pose_steps : 50000 47 | coarse_only : True 48 | rectify_inplane_rotation: False 49 | coarse_feat_dim : 64 # original:32 50 | fine_feat_dim : 32 # original:128 51 | anti_alias_pooling : 1 52 | 53 | use_pred_pose: True 54 | use_depth_loss: True 55 | 56 | optimizer: 57 | lr: 1.5e-4 58 | warm_up_steps: 2000 59 | ### TESTING 60 | chunk_size : 2000 61 | 62 | ### RENDERING 63 | N_importance : 0 #64 64 | N_samples : 64 65 | inv_uniform : True 66 | white_bkgd : False 67 | sample_mode : uniform 68 | center_ratio : 0.8 69 | feat_loss_scale : 1e1 70 | 71 | ### CONSOLE AND TENSORBOARD 72 | n_validation : 1000 73 | n_tensorboard : 2 74 | n_checkpoint : 500 75 | visdom_port : 9000 76 | 77 | ### evaluation options 78 | llffhold : 8 -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifuguan/GGRt_official/01886261b6b6b6175b6ea88f44a85c640564ae9f/eval/__init__.py -------------------------------------------------------------------------------- /eval/eval_deepvoxels.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd eval/ 3 | CUDA_VISIBLE_DEVICES=0 python eval.py --config ../configs/eval_deepvoxels.txt --eval_scenes cube & 4 | CUDA_VISIBLE_DEVICES=1 python eval.py --config ../configs/eval_deepvoxels.txt --eval_scenes vase & 5 | CUDA_VISIBLE_DEVICES=2 python eval.py --config ../configs/eval_deepvoxels.txt --eval_scenes greek & 6 | CUDA_VISIBLE_DEVICES=3 python eval.py --config ../configs/eval_deepvoxels.txt --eval_scenes armchair & -------------------------------------------------------------------------------- /eval/eval_llff_all.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd eval/ 3 | CUDA_VISIBLE_DEVICES=0 python eval.py --config ../configs/eval_llff.txt --eval_scenes horns & 4 | CUDA_VISIBLE_DEVICES=1 python eval.py --config ../configs/eval_llff.txt --eval_scenes trex & 5 | CUDA_VISIBLE_DEVICES=2 python eval.py --config ../configs/eval_llff.txt --eval_scenes room & 6 | CUDA_VISIBLE_DEVICES=3 python eval.py --config ../configs/eval_llff.txt --eval_scenes flower & 7 | CUDA_VISIBLE_DEVICES=4 python eval.py --config ../configs/eval_llff.txt --eval_scenes orchids & 8 | CUDA_VISIBLE_DEVICES=5 python eval.py --config ../configs/eval_llff.txt --eval_scenes leaves & 9 | CUDA_VISIBLE_DEVICES=6 python eval.py --config ../configs/eval_llff.txt --eval_scenes fern & 10 | CUDA_VISIBLE_DEVICES=7 python eval.py --config ../configs/eval_llff.txt --eval_scenes fortress & 11 | 12 | -------------------------------------------------------------------------------- /eval/eval_nerf_synthetic_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd eval/ 3 | 4 | CUDA_VISIBLE_DEVICES=0 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes mic & 5 | CUDA_VISIBLE_DEVICES=1 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes chair & 6 | CUDA_VISIBLE_DEVICES=2 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes lego & 7 | CUDA_VISIBLE_DEVICES=3 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes ficus & 8 | CUDA_VISIBLE_DEVICES=4 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes materials & 9 | CUDA_VISIBLE_DEVICES=5 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes hotdog & 10 | CUDA_VISIBLE_DEVICES=6 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes ship & 11 | CUDA_VISIBLE_DEVICES=7 python eval.py --config ../configs/eval_nerf_synthetic.txt --eval_scenes drums & 12 | 13 | -------------------------------------------------------------------------------- /eval/finetune_dbarf_llff.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | GPU_ID=$1 4 | 5 | HOME_DIR=$HOME #'/home/chenyu' 6 | echo $HOME_DIR 7 | CODE_DIR=${HOME_DIR}/'Projects/dbarf/ibrnet' 8 | CONFIG_DIR=$CODE_DIR/'configs' 9 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval' 10 | EXPNAME='finetune_dbarf_llff' 11 | PRETRAINED_MODEL_PATH=${HOME_DIR}/'Datasets/IBRNet/pretrained_model/dbarf_model_200000.pth' 12 | 13 | DATASET_NAME='llff_test' 14 | scenes=("fern" "flower" "fortress" "horns" "leaves" "orchids" "room" "trex") 15 | 16 | cd $CODE_DIR 17 | 18 | for((i=0;i<${#scenes[@]};i++)); 19 | do 20 | echo "Finetuning ${scenes[i]} on single machine" 21 | CUDA_VISIBLE_DEVICES=${GPU_ID} python train_dbarf.py \ 22 | --config ${CONFIG_DIR}/finetune_dbarf.txt \ 23 | --expname ${EXPNAME}_${scenes[i]}_test_depth \ 24 | --rootdir $ROOT_DIR \ 25 | --ckpt_path $PRETRAINED_MODEL_PATH \ 26 | --train_dataset $DATASET_NAME \ 27 | --train_scenes ${scenes[i]} \ 28 | --eval_dataset $DATASET_NAME \ 29 | --eval_scenes ${scenes[i]} 30 | done 31 | -------------------------------------------------------------------------------- /eval/finetune_dbarf_scannet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | GPU_ID=$1 4 | 5 | HOME_DIR=$HOME #'/home/chenyu' 6 | echo $HOME_DIR 7 | 
CODE_DIR=${HOME_DIR}/'Projects/PoseNeRF/ibrnet' 8 | CONFIG_DIR=$CODE_DIR/'configs' 9 | ROOT_DIR=${HOME_DIR}/'Datasets/scannet' 10 | EXPNAME='finetune_dbarf_scannet' 11 | PRETRAINED_MODEL_PATH=${HOME_DIR}/'Datasets/IBRNet/pretrained_model/dbarf_model_200000.pth' 12 | 13 | DATASET_NAME='scannet' 14 | scenes=("scene0671_00" "scene0673_03" "scene0675_00" "scene_0675_01" "scene0680_00" "scene0684_00" "scene0684_01") 15 | 16 | cd $CODE_DIR 17 | 18 | for((i=0;i<${#scenes[@]};i++)); 19 | do 20 | echo "Finetuning ${scenes[i]} on single machine" 21 | CUDA_VISIBLE_DEVICES=${GPU_ID} python train_dbarf.py \ 22 | --config ${CONFIG_DIR}/finetune_dbarf.txt \ 23 | --expname ${EXPNAME}_${scenes[i]} \ 24 | --rootdir $ROOT_DIR \ 25 | --ckpt_path $PRETRAINED_MODEL_PATH \ 26 | --train_dataset $DATASET_NAME \ 27 | --train_scenes ${scenes[i]} \ 28 | --eval_dataset $DATASET_NAME \ 29 | --eval_scenes ${scenes[i]} 30 | done 31 | -------------------------------------------------------------------------------- /eval/finetune_llff.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SFX='1' 4 | FILE='run_nerf' 5 | CUDA_VISIBLE_DEVICES=0 python $FILE.py --config configs/finetune_llff.txt --train_scenes orchids --eval_scenes orchids --expname finetune_orchids_$SFX & 6 | CUDA_VISIBLE_DEVICES=1 python $FILE.py --config configs/finetune_llff.txt --train_scenes horns --eval_scenes horns --expname finetune_horns_$SFX & 7 | CUDA_VISIBLE_DEVICES=2 python $FILE.py --config configs/finetune_llff.txt --train_scenes trex --eval_scenes trex --expname finetune_trex_$SFX & 8 | CUDA_VISIBLE_DEVICES=3 python $FILE.py --config configs/finetune_llff.txt --train_scenes room --eval_scenes room --expname finetune_room_$SFX & 9 | CUDA_VISIBLE_DEVICES=4 python $FILE.py --config configs/finetune_llff.txt --train_scenes flower --eval_scenes flower --expname finetune_flower_$SFX & 10 | CUDA_VISIBLE_DEVICES=5 python $FILE.py --config configs/finetune_llff.txt --train_scenes leaves --eval_scenes leaves --expname finetune_leaves_$SFX & 11 | CUDA_VISIBLE_DEVICES=6 python $FILE.py --config configs/finetune_llff.txt --train_scenes fern --eval_scenes fern --expname finetune_fern_$SFX & 12 | CUDA_VISIBLE_DEVICES=7 python $FILE.py --config configs/finetune_llff.txt --train_scenes fortress --eval_scenes fortress --expname finetune_fortress_$SFX 13 | 14 | -------------------------------------------------------------------------------- /eval/render_llff.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd eval/ 4 | FILE='render_llff_video' 5 | CUDA_VISIBLE_DEVICES=0 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes orchids --eval_scenes orchids & 6 | CUDA_VISIBLE_DEVICES=1 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes horns --eval_scenes horns & 7 | CUDA_VISIBLE_DEVICES=2 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes trex --eval_scenes trex & 8 | CUDA_VISIBLE_DEVICES=3 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes room --eval_scenes room & 9 | CUDA_VISIBLE_DEVICES=4 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes flower --eval_scenes flower & 10 | CUDA_VISIBLE_DEVICES=5 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes leaves --eval_scenes leaves & 11 | CUDA_VISIBLE_DEVICES=6 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes fern --eval_scenes fern & 12 | 
CUDA_VISIBLE_DEVICES=7 python $FILE.py --config ../configs/render_llff_video.txt --train_scenes fortress --eval_scenes fortress 13 | 14 | -------------------------------------------------------------------------------- /ggrt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifuguan/GGRt_official/01886261b6b6b6175b6ea88f44a85c640564ae9f/ggrt/__init__.py -------------------------------------------------------------------------------- /ggrt/base/model_base.py: -------------------------------------------------------------------------------- 1 | 2 | class Model(object): 3 | def __init__(self, args) -> None: 4 | self.args = args 5 | 6 | def to_distributed(self): 7 | raise NotImplementedError 8 | 9 | def switch_to_eval(self): 10 | raise NotImplementedError 11 | 12 | def switch_to_train(self): 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /ggrt/data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .google_scanned_objects import * 17 | from .realestate import * 18 | from .deepvoxels import * 19 | from .realestate import * 20 | from .llff import * 21 | from .llff_test import * 22 | from .ibrnet_collected import * 23 | from .realestate import * 24 | from .spaces_dataset import * 25 | from .nerf_synthetic import * 26 | from .scannet import * 27 | from .waymo import * 28 | from .kitti import * 29 | # from .scanet_test import ScannetTrainDataset 30 | 31 | 32 | dataset_dict = { 33 | 'spaces': SpacesFreeDataset, 34 | 'google_scanned': GoogleScannedDataset, 35 | 'realestate': RealEstateDataset, 36 | 'deepvoxels': DeepVoxelsDataset, 37 | 'nerf_synthetic': NerfSyntheticDataset, 38 | 'llff': LLFFDataset, 39 | 'ibrnet_collected': IBRNetCollectedDataset, 40 | 'scannet': ScannetDataset, 41 | 'llff_test': LLFFTestDataset, 42 | 'waymo':WaymoStaticDataset, 43 | 'kitti':KittiPixelSource, 44 | # "scannet": ScannetTrainDataset 45 | } -------------------------------------------------------------------------------- /ggrt/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | -------------------------------------------------------------------------------- /ggrt/dataset/data_module.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | from torch import Generator, nn 8 | from torch.utils.data import DataLoader, Dataset, IterableDataset 9 | 10 | from ggrt.misc.step_tracker import StepTracker 11 | from .types import DataShim, Stage 12 | from .validation_wrapper import ValidationWrapper 13 | 14 | 15 | def get_data_shim(encoder: nn.Module) 
-> DataShim: 16 | """Get functions that modify the batch. It's sometimes necessary to modify batches 17 | outside the data loader because GPU computations are required to modify the batch or 18 | because the modification depends on something outside the data loader. 19 | """ 20 | 21 | shims: list[DataShim] = [] 22 | if hasattr(encoder, "get_data_shim"): 23 | shims.append(encoder.get_data_shim()) 24 | 25 | def combined_shim(batch): 26 | for shim in shims: 27 | batch = shim(batch) 28 | return batch 29 | 30 | return combined_shim 31 | 32 | 33 | @dataclass 34 | class DataLoaderStageCfg: 35 | batch_size: int 36 | num_workers: int 37 | persistent_workers: bool 38 | seed: int | None 39 | 40 | 41 | @dataclass 42 | class DataLoaderCfg: 43 | train: DataLoaderStageCfg 44 | test: DataLoaderStageCfg 45 | val: DataLoaderStageCfg 46 | 47 | 48 | DatasetShim = Callable[[Dataset, Stage], Dataset] 49 | 50 | 51 | def worker_init_fn(worker_id: int) -> None: 52 | random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 53 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 54 | -------------------------------------------------------------------------------- /ggrt/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .view_sampler import ViewSamplerCfg 4 | 5 | 6 | @dataclass 7 | class DatasetCfgCommon: 8 | image_shape: list[int] 9 | background_color: list[float] 10 | cameras_are_circular: bool 11 | overfit_to_scene: str | None 12 | view_sampler: ViewSamplerCfg 13 | -------------------------------------------------------------------------------- /ggrt/dataset/shims/augmentation_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float 3 | from torch import Tensor 4 | 5 | from ..types import AnyExample, AnyViews 6 | 7 | 8 | def reflect_extrinsics( 9 | extrinsics: Float[Tensor, "*batch 4 4"], 10 | ) -> Float[Tensor, "*batch 4 4"]: 11 | reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) 12 | reflect[0, 0] = -1 13 | return reflect @ extrinsics @ reflect 14 | 15 | 16 | def reflect_views(views: AnyViews) -> AnyViews: 17 | return { 18 | **views, 19 | "image": views["image"].flip(-1), 20 | "extrinsics": reflect_extrinsics(views["extrinsics"]), 21 | } 22 | 23 | 24 | def apply_augmentation_shim( 25 | example: AnyExample, 26 | generator: torch.Generator | None = None, 27 | ) -> AnyExample: 28 | """Randomly augment the training images.""" 29 | # Do not augment with 50% chance. 30 | if torch.rand(tuple(), generator=generator) < 0.5: 31 | return example 32 | 33 | return { 34 | **example, 35 | "context": reflect_views(example["context"]), 36 | "target": reflect_views(example["target"]), 37 | } 38 | -------------------------------------------------------------------------------- /ggrt/dataset/shims/bounds_shim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, reduce, repeat 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | from ..types import BatchedExample 7 | 8 | 9 | def compute_depth_for_disparity( 10 | extrinsics: Float[Tensor, "batch view 4 4"], 11 | intrinsics: Float[Tensor, "batch view 3 3"], 12 | image_shape: tuple[int, int], 13 | disparity: float, 14 | delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. 
15 | ) -> Float[Tensor, " batch"]: 16 | """Compute the depth at which moving the maximum distance between cameras 17 | corresponds to the specified disparity (in pixels). 18 | """ 19 | 20 | # Use the furthest distance between cameras as the baseline. 21 | origins = extrinsics[:, :, :3, 3] 22 | deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) 23 | deltas = deltas.clip(min=delta_min) 24 | baselines = reduce(deltas, "b v ov -> b", "max") 25 | 26 | # Compute a single pixel's size at depth 1. 27 | h, w = image_shape 28 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) 29 | pixel_size = einsum( 30 | intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" 31 | ) 32 | 33 | # This wouldn't make sense with non-square pixels, but then again, non-square pixels 34 | # don't make much sense anyway. 35 | mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") 36 | 37 | return baselines / (disparity * mean_pixel_size) 38 | 39 | 40 | def apply_bounds_shim( 41 | batch: BatchedExample, 42 | near_disparity: float, 43 | far_disparity: float, 44 | ) -> BatchedExample: 45 | """Compute reasonable near and far planes (lower and upper bounds on depth). This 46 | assumes that all of an example's views are of roughly the same thing. 47 | """ 48 | 49 | context = batch["context"] 50 | _, cv, _, h, w = context["image"].shape 51 | 52 | # Compute near and far planes using the context views. 53 | near = compute_depth_for_disparity( 54 | context["extrinsics"], 55 | context["intrinsics"], 56 | (h, w), 57 | near_disparity, 58 | ) 59 | far = compute_depth_for_disparity( 60 | context["extrinsics"], 61 | context["intrinsics"], 62 | (h, w), 63 | far_disparity, 64 | ) 65 | 66 | target = batch["target"] 67 | _, tv, _, _, _ = target["image"].shape 68 | return { 69 | **batch, 70 | "context": { 71 | **context, 72 | "near": repeat(near, "b -> b v", v=cv), 73 | "far": repeat(far, "b -> b v", v=cv), 74 | }, 75 | "target": { 76 | **target, 77 | "near": repeat(near, "b -> b v", v=tv), 78 | "far": repeat(far, "b -> b v", v=tv), 79 | }, 80 | } 81 | -------------------------------------------------------------------------------- /ggrt/dataset/shims/crop_shim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from PIL import Image 6 | from torch import Tensor 7 | 8 | from ..types import AnyExample, AnyViews 9 | 10 | 11 | def rescale( 12 | image: Float[Tensor, "3 h_in w_in"], 13 | shape: tuple[int, int], 14 | ) -> Float[Tensor, "3 h_out w_out"]: 15 | h, w = shape 16 | image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) 17 | image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() 18 | image_new = Image.fromarray(image_new) 19 | image_new = image_new.resize((w, h), Image.LANCZOS) 20 | image_new = np.array(image_new) / 255 21 | image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) 22 | return rearrange(image_new, "h w c -> c h w") 23 | 24 | 25 | def center_crop( 26 | images: Float[Tensor, "*#batch c h w"], 27 | intrinsics: Float[Tensor, "*#batch 3 3"], 28 | shape: tuple[int, int], 29 | ) -> tuple[ 30 | Float[Tensor, "*#batch c h_out w_out"], # updated images 31 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 32 | ]: 33 | *_, h_in, w_in = images.shape 34 | h_out, w_out = shape 35 | 36 | # Note that odd input dimensions induce half-pixel misalignments. 
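    # Editor's note (assumption: intrinsics here are normalized by the image
    # size, as elsewhere in this pixelSplat-derived code): a center crop keeps
    # the focal length in pixels unchanged, so the normalized fx scales by
    # w_in / w_out and fy by h_in / h_out, which is what the intrinsics update
    # below does; the principal point is assumed to stay at the image center,
    # so cx/cy are left untouched.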
37 | row = (h_in - h_out) // 2 38 | col = (w_in - w_out) // 2 39 | 40 | # Center-crop the image. 41 | images = images[..., :, row : row + h_out, col : col + w_out] 42 | 43 | # Adjust the intrinsics to account for the cropping. 44 | intrinsics = intrinsics.clone() 45 | intrinsics[..., 0, 0] *= w_in / w_out # fx 46 | intrinsics[..., 1, 1] *= h_in / h_out # fy 47 | 48 | return images, intrinsics 49 | 50 | 51 | def rescale_and_crop( 52 | images: Float[Tensor, "*#batch c h w"], 53 | intrinsics: Float[Tensor, "*#batch 3 3"], 54 | shape: tuple[int, int], 55 | ) -> tuple[ 56 | Float[Tensor, "*#batch c h_out w_out"], # updated images 57 | Float[Tensor, "*#batch 3 3"], # updated intrinsics 58 | ]: 59 | *_, h_in, w_in = images.shape 60 | h_out, w_out = shape 61 | assert h_out <= h_in and w_out <= w_in 62 | 63 | scale_factor = max(h_out / h_in, w_out / w_in) 64 | h_scaled = round(h_in * scale_factor) 65 | w_scaled = round(w_in * scale_factor) 66 | assert h_scaled == h_out or w_scaled == w_out 67 | 68 | # Reshape the images to the correct size. Assume we don't have to worry about 69 | # changing the intrinsics based on how the images are rounded. 70 | *batch, c, h, w = images.shape 71 | images = images.reshape(-1, c, h, w) 72 | images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) 73 | images = images.reshape(*batch, c, h_scaled, w_scaled) 74 | 75 | return center_crop(images, intrinsics, shape) 76 | 77 | 78 | def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: 79 | images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape) 80 | return { 81 | **views, 82 | "image": images, 83 | "intrinsics": intrinsics, 84 | } 85 | 86 | 87 | def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: 88 | """Crop images in the example.""" 89 | return { 90 | **example, 91 | "context": apply_crop_shim_to_views(example["context"], shape), 92 | "target": apply_crop_shim_to_views(example["target"], shape), 93 | } 94 | -------------------------------------------------------------------------------- /ggrt/dataset/shims/patch_shim.py: -------------------------------------------------------------------------------- 1 | from ..types import BatchedExample, BatchedViews 2 | 3 | 4 | def apply_patch_shim_to_views(views: BatchedViews, patch_size: int) -> BatchedViews: 5 | _, _, _, h, w = views["image"].shape 6 | 7 | # Image size must be even so that naive center-cropping does not cause misalignment. 8 | assert h % 2 == 0 and w % 2 == 0 9 | 10 | h_new = (h // patch_size) * patch_size 11 | row = (h - h_new) // 2 12 | w_new = (w // patch_size) * patch_size 13 | col = (w - w_new) // 2 14 | 15 | # Center-crop the image. 16 | image = views["image"][:, :, :, row : row + h_new, col : col + w_new] 17 | 18 | # Adjust the intrinsics to account for the cropping. 19 | intrinsics = views["intrinsics"].clone() 20 | intrinsics[:, :, 0, 0] *= w / w_new # fx 21 | intrinsics[:, :, 1, 1] *= h / h_new # fy 22 | 23 | return { 24 | **views, 25 | "image": image, 26 | "intrinsics": intrinsics, 27 | } 28 | 29 | 30 | def apply_patch_shim(batch: BatchedExample, patch_size: int) -> BatchedExample: 31 | """Crop images in the batch so that their dimensions are cleanly divisible by the 32 | specified patch size. 
33 | """ 34 | return { 35 | **batch, 36 | "context": apply_patch_shim_to_views(batch["context"], patch_size), 37 | "target": apply_patch_shim_to_views(batch["target"], patch_size), 38 | } 39 | -------------------------------------------------------------------------------- /ggrt/dataset/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Literal, TypedDict 2 | 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | Stage = Literal["train", "val", "test"] 7 | 8 | 9 | # The following types mainly exist to make type-hinted keys show up in VS Code. Some 10 | # dimensions are annotated as "_" because either: 11 | # 1. They're expected to change as part of a function call (e.g., resizing the dataset). 12 | # 2. They're expected to vary within the same function call (e.g., the number of views, 13 | # which differs between context and target BatchedViews). 14 | 15 | 16 | class BatchedViews(TypedDict, total=False): 17 | extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 18 | intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 19 | image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width 20 | near: Float[Tensor, "batch _"] # batch view 21 | far: Float[Tensor, "batch _"] # batch view 22 | index: Int64[Tensor, "batch _"] # batch view 23 | 24 | 25 | class BatchedExample(TypedDict, total=False): 26 | target: BatchedViews 27 | context: BatchedViews 28 | scene: list[str] 29 | 30 | 31 | class UnbatchedViews(TypedDict, total=False): 32 | extrinsics: Float[Tensor, "_ 4 4"] 33 | intrinsics: Float[Tensor, "_ 3 3"] 34 | image: Float[Tensor, "_ 3 height width"] 35 | near: Float[Tensor, " _"] 36 | far: Float[Tensor, " _"] 37 | index: Int64[Tensor, " _"] 38 | 39 | 40 | class UnbatchedExample(TypedDict, total=False): 41 | target: UnbatchedViews 42 | context: UnbatchedViews 43 | scene: str 44 | 45 | 46 | # A data shim modifies the example after it's been returned from the data loader. 47 | DataShim = Callable[[BatchedExample], BatchedExample] 48 | 49 | AnyExample = BatchedExample | UnbatchedExample 50 | AnyViews = BatchedViews | UnbatchedViews 51 | -------------------------------------------------------------------------------- /ggrt/dataset/validation_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from torch.utils.data import Dataset, IterableDataset 5 | 6 | 7 | class ValidationWrapper(Dataset): 8 | """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a 9 | visualization step. 
10 | """ 11 | 12 | dataset: Dataset 13 | dataset_iterator: Optional[Iterator] 14 | length: int 15 | 16 | def __init__(self, dataset: Dataset, length: int) -> None: 17 | super().__init__() 18 | self.dataset = dataset 19 | self.length = length 20 | self.dataset_iterator = None 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index: int): 26 | if isinstance(self.dataset, IterableDataset): 27 | if self.dataset_iterator is None: 28 | self.dataset_iterator = iter(self.dataset) 29 | return next(self.dataset_iterator) 30 | 31 | random_index = torch.randint(0, len(self.dataset), tuple()) 32 | return self.dataset[random_index.item()] 33 | -------------------------------------------------------------------------------- /ggrt/dataset/view_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .dbarf.misc.step_tracker import StepTracker 4 | from ..types import Stage 5 | from .view_sampler import ViewSampler 6 | from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg 7 | from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg 8 | from .view_sampler_bounded import ViewSamplerBounded, ViewSamplerBoundedCfg 9 | from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg 10 | 11 | VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { 12 | "all": ViewSamplerAll, 13 | "arbitrary": ViewSamplerArbitrary, 14 | "bounded": ViewSamplerBounded, 15 | "evaluation": ViewSamplerEvaluation, 16 | } 17 | 18 | ViewSamplerCfg = ( 19 | ViewSamplerArbitraryCfg 20 | | ViewSamplerBoundedCfg 21 | | ViewSamplerEvaluationCfg 22 | | ViewSamplerAllCfg 23 | ) 24 | 25 | 26 | def get_view_sampler( 27 | cfg: ViewSamplerCfg, 28 | stage: Stage, 29 | overfit: bool, 30 | cameras_are_circular: bool, 31 | step_tracker: StepTracker | None, 32 | ) -> ViewSampler[Any]: 33 | return VIEW_SAMPLERS[cfg.name]( 34 | cfg, 35 | stage, 36 | overfit, 37 | cameras_are_circular, 38 | step_tracker, 39 | ) 40 | -------------------------------------------------------------------------------- /ggrt/dataset/view_sampler/view_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .dbarf.misc.step_tracker import StepTracker 9 | from ..types import Stage 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | class ViewSampler(ABC, Generic[T]): 15 | cfg: T 16 | stage: Stage 17 | is_overfitting: bool 18 | cameras_are_circular: bool 19 | step_tracker: StepTracker | None 20 | 21 | def __init__( 22 | self, 23 | cfg: T, 24 | stage: Stage, 25 | is_overfitting: bool, 26 | cameras_are_circular: bool, 27 | step_tracker: StepTracker | None, 28 | ) -> None: 29 | self.cfg = cfg 30 | self.stage = stage 31 | self.is_overfitting = is_overfitting 32 | self.cameras_are_circular = cameras_are_circular 33 | self.step_tracker = step_tracker 34 | 35 | @abstractmethod 36 | def sample( 37 | self, 38 | scene: str, 39 | extrinsics: Float[Tensor, "view 4 4"], 40 | intrinsics: Float[Tensor, "view 3 3"], 41 | device: torch.device = torch.device("cpu"), 42 | ) -> tuple[ 43 | Int64[Tensor, " context_view"], # indices for context views 44 | Int64[Tensor, " target_view"], # indices for target views 45 | ]: 46 | pass 47 | 48 | @property 49 | @abstractmethod 50 | def num_target_views(self) -> int: 51 | pass 52 | 53 | 
@property 54 | @abstractmethod 55 | def num_context_views(self) -> int: 56 | pass 57 | 58 | @property 59 | def global_step(self) -> int: 60 | return 0 if self.step_tracker is None else self.step_tracker.get_step() 61 | -------------------------------------------------------------------------------- /ggrt/dataset/view_sampler/view_sampler_all.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerAllCfg: 13 | name: Literal["all"] 14 | 15 | 16 | class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): 17 | def sample( 18 | self, 19 | scene: str, 20 | extrinsics: Float[Tensor, "view 4 4"], 21 | intrinsics: Float[Tensor, "view 3 3"], 22 | device: torch.device = torch.device("cpu"), 23 | ) -> tuple[ 24 | Int64[Tensor, " context_view"], # indices for context views 25 | Int64[Tensor, " target_view"], # indices for target views 26 | ]: 27 | v, _, _ = extrinsics.shape 28 | all_frames = torch.arange(v, device=device) 29 | return all_frames, all_frames 30 | 31 | @property 32 | def num_context_views(self) -> int: 33 | return 0 34 | 35 | @property 36 | def num_target_views(self) -> int: 37 | return 0 38 | -------------------------------------------------------------------------------- /ggrt/dataset/view_sampler/view_sampler_arbitrary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerArbitraryCfg: 13 | name: Literal["arbitrary"] 14 | num_context_views: int 15 | num_target_views: int 16 | context_views: list[int] | None 17 | target_views: list[int] | None 18 | 19 | 20 | class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): 21 | def sample( 22 | self, 23 | scene: str, 24 | extrinsics: Float[Tensor, "view 4 4"], 25 | intrinsics: Float[Tensor, "view 3 3"], 26 | device: torch.device = torch.device("cpu"), 27 | ) -> tuple[ 28 | Int64[Tensor, " context_view"], # indices for context views 29 | Int64[Tensor, " target_view"], # indices for target views 30 | ]: 31 | """Arbitrarily sample context and target views.""" 32 | num_views, _, _ = extrinsics.shape 33 | 34 | index_context = torch.randint( 35 | 0, 36 | num_views, 37 | size=(self.cfg.num_context_views,), 38 | device=device, 39 | ) 40 | 41 | # Allow the context views to be fixed. 42 | if self.cfg.context_views is not None: 43 | assert len(self.cfg.context_views) == self.cfg.num_context_views 44 | index_context = torch.tensor( 45 | self.cfg.context_views, dtype=torch.int64, device=device 46 | ) 47 | 48 | index_target = torch.randint( 49 | 0, 50 | num_views, 51 | size=(self.cfg.num_target_views,), 52 | device=device, 53 | ) 54 | 55 | # Allow the target views to be fixed. 
56 | if self.cfg.target_views is not None: 57 | assert len(self.cfg.target_views) == self.cfg.num_target_views 58 | index_target = torch.tensor( 59 | self.cfg.target_views, dtype=torch.int64, device=device 60 | ) 61 | 62 | return index_context, index_target 63 | 64 | @property 65 | def num_context_views(self) -> int: 66 | return self.cfg.num_context_views 67 | 68 | @property 69 | def num_target_views(self) -> int: 70 | return self.cfg.num_target_views 71 | -------------------------------------------------------------------------------- /ggrt/dataset/view_sampler/view_sampler_bounded.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from jaxtyping import Float, Int64 6 | from torch import Tensor 7 | 8 | from .view_sampler import ViewSampler 9 | 10 | 11 | @dataclass 12 | class ViewSamplerBoundedCfg: 13 | name: Literal["bounded"] 14 | num_context_views: int 15 | num_target_views: int 16 | min_distance_between_context_views: int 17 | max_distance_between_context_views: int 18 | min_distance_to_context_views: int 19 | warm_up_steps: int 20 | initial_min_distance_between_context_views: int 21 | initial_max_distance_between_context_views: int 22 | 23 | 24 | class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): 25 | def schedule(self, initial: int, final: int) -> int: 26 | fraction = self.global_step / self.cfg.warm_up_steps 27 | return min(initial + int((final - initial) * fraction), final) 28 | 29 | def sample( 30 | self, 31 | scene: str, 32 | extrinsics: Float[Tensor, "view 4 4"], 33 | intrinsics: Float[Tensor, "view 3 3"], 34 | device: torch.device = torch.device("cpu"), 35 | ) -> tuple[ 36 | Int64[Tensor, " context_view"], # indices for context views 37 | Int64[Tensor, " target_view"], # indices for target views 38 | ]: 39 | num_views, _, _ = extrinsics.shape 40 | 41 | # Compute the context view spacing based on the current global step. 42 | if self.stage == "test": 43 | # When testing, always use the full gap. 44 | max_gap = self.cfg.max_distance_between_context_views 45 | min_gap = self.cfg.max_distance_between_context_views 46 | elif self.cfg.warm_up_steps > 0: 47 | max_gap = self.schedule( 48 | self.cfg.initial_max_distance_between_context_views, 49 | self.cfg.max_distance_between_context_views, 50 | ) 51 | min_gap = self.schedule( 52 | self.cfg.initial_min_distance_between_context_views, 53 | self.cfg.min_distance_between_context_views, 54 | ) 55 | else: 56 | max_gap = self.cfg.max_distance_between_context_views 57 | min_gap = self.cfg.min_distance_between_context_views 58 | 59 | # Pick the gap between the context views. 60 | if not self.cameras_are_circular: 61 | max_gap = min(num_views - 1, min_gap) 62 | min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) 63 | if max_gap < min_gap: 64 | raise ValueError("Example does not have enough frames!") 65 | context_gap = torch.randint( 66 | min_gap, 67 | max_gap + 1, 68 | size=tuple(), 69 | device=device, 70 | ).item() 71 | 72 | # Pick the left and right context indices. 
73 | index_context_left = torch.randint( 74 | num_views if self.cameras_are_circular else num_views - context_gap, 75 | size=tuple(), 76 | device=device, 77 | ).item() 78 | if self.stage == "test": 79 | index_context_left = index_context_left * 0 80 | index_context_right = index_context_left + context_gap 81 | 82 | if self.is_overfitting: 83 | index_context_left *= 0 84 | index_context_right *= 0 85 | index_context_right += max_gap 86 | 87 | # Pick the target view indices. 88 | if self.stage == "test": 89 | # When testing, pick all. 90 | index_target = torch.arange( 91 | index_context_left, 92 | index_context_right + 1, 93 | device=device, 94 | ) 95 | else: 96 | # When training or validating (visualizing), pick at random. 97 | index_target = torch.randint( 98 | index_context_left + self.cfg.min_distance_to_context_views, 99 | index_context_right + 1 - self.cfg.min_distance_to_context_views, 100 | size=(self.cfg.num_target_views,), 101 | device=device, 102 | ) 103 | 104 | # Apply modulo for circular datasets. 105 | if self.cameras_are_circular: 106 | index_target %= num_views 107 | index_context_right %= num_views 108 | 109 | return ( 110 | torch.tensor((index_context_left, index_context_right)), 111 | index_target, 112 | ) 113 | 114 | @property 115 | def num_context_views(self) -> int: 116 | return 2 117 | 118 | @property 119 | def num_target_views(self) -> int: 120 | return self.cfg.num_target_views 121 | -------------------------------------------------------------------------------- /ggrt/dataset/view_sampler/view_sampler_evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Literal 5 | 6 | import torch 7 | from dacite import Config, from_dict 8 | from jaxtyping import Float, Int64 9 | from torch import Tensor 10 | 11 | from ...evaluation.evaluation_index_generator import IndexEntry 12 | from .dbarf.misc.step_tracker import StepTracker 13 | from ..types import Stage 14 | from .view_sampler import ViewSampler 15 | 16 | 17 | @dataclass 18 | class ViewSamplerEvaluationCfg: 19 | name: Literal["evaluation"] 20 | index_path: Path 21 | num_context_views: int 22 | 23 | 24 | class ViewSamplerEvaluation(ViewSampler[ViewSamplerEvaluationCfg]): 25 | index: dict[str, IndexEntry | None] 26 | 27 | def __init__( 28 | self, 29 | cfg: ViewSamplerEvaluationCfg, 30 | stage: Stage, 31 | is_overfitting: bool, 32 | cameras_are_circular: bool, 33 | step_tracker: StepTracker | None, 34 | ) -> None: 35 | super().__init__(cfg, stage, is_overfitting, cameras_are_circular, step_tracker) 36 | 37 | dacite_config = Config(cast=[tuple]) 38 | with cfg.index_path.open("r") as f: 39 | self.index = { 40 | k: None if v is None else from_dict(IndexEntry, v, dacite_config) 41 | for k, v in json.load(f).items() 42 | } 43 | 44 | def sample( 45 | self, 46 | scene: str, 47 | extrinsics: Float[Tensor, "view 4 4"], 48 | intrinsics: Float[Tensor, "view 3 3"], 49 | device: torch.device = torch.device("cpu"), 50 | ) -> tuple[ 51 | Int64[Tensor, " context_view"], # indices for context views 52 | Int64[Tensor, " target_view"], # indices for target views 53 | ]: 54 | entry = self.index.get(scene) 55 | if entry is None: 56 | raise ValueError(f"No indices available for scene {scene}.") 57 | context_indices = torch.tensor(entry.context, dtype=torch.int64, device=device) 58 | target_indices = torch.tensor(entry.target, dtype=torch.int64, device=device) 59 | return context_indices, target_indices 
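    # Editor's note (illustrative): the JSON file at cfg.index_path is expected
    # to map each scene name to an entry with "context" and "target" index
    # lists, e.g. {"room": {"context": [0, 45], "target": [10, 20, 30]},
    # "fern": null}; scenes mapped to null cause sample() above to raise.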
60 | 61 | @property 62 | def num_context_views(self) -> int: 63 | return 0 64 | 65 | @property 66 | def num_target_views(self) -> int: 67 | return 0 68 | -------------------------------------------------------------------------------- /ggrt/geometry/depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def is_tuple(data): 6 | """Checks if data is a tuple.""" 7 | return isinstance(data, tuple) 8 | 9 | 10 | def is_list(data): 11 | """Checks if data is a list.""" 12 | return isinstance(data, list) 13 | 14 | 15 | def is_seq(data): 16 | """Checks if data is a list or tuple.""" 17 | return is_tuple(data) or is_list(data) 18 | 19 | 20 | def inv2depth(inv_depth): 21 | """ 22 | Invert an inverse depth map to produce a depth map 23 | 24 | Parameters 25 | ---------- 26 | inv_depth : torch.Tensor or list of torch.Tensor [B,1,H,W] 27 | Inverse depth map 28 | 29 | Returns 30 | ------- 31 | depth : torch.Tensor or list of torch.Tensor [B,1,H,W] 32 | Depth map 33 | """ 34 | if is_seq(inv_depth): 35 | return [inv2depth(item) for item in inv_depth] 36 | else: 37 | depth = 1. / inv_depth.clamp(min=1e-6) 38 | depth[inv_depth <= 0.] = 0. 39 | return depth 40 | 41 | 42 | def depth2inv(depth): 43 | """ 44 | Invert a depth map to produce an inverse depth map 45 | 46 | Parameters 47 | ---------- 48 | depth : torch.Tensor or list of torch.Tensor [B,1,H,W] 49 | Depth map 50 | 51 | Returns 52 | ------- 53 | inv_depth : torch.Tensor or list of torch.Tensor [B,1,H,W] 54 | Inverse depth map 55 | 56 | """ 57 | if is_seq(depth): 58 | return [depth2inv(item) for item in depth] 59 | else: 60 | inv_depth = 1. / depth.clamp(min=1e-6) 61 | inv_depth[depth <= 0.] = 0. 62 | return inv_depth 63 | 64 | 65 | def disp_to_depth(disp, min_depth, max_depth): 66 | """Convert network's sigmoid output into depth prediction 67 | The formula for this conversion is given in the 'additional considerations' 68 | section of the paper. 
69 | """ 70 | min_disp = 1 / max_depth 71 | max_disp = 1 / min_depth 72 | scaled_disp = min_disp + (max_disp - min_disp) * disp 73 | depth = 1 / scaled_disp 74 | return scaled_disp, depth 75 | 76 | 77 | def gradient_x(image): 78 | """ 79 | Calculates the gradient of an image in the x dimension 80 | Parameters 81 | ---------- 82 | image : torch.Tensor [B,3,H,W] 83 | Input image 84 | 85 | Returns 86 | ------- 87 | gradient_x : torch.Tensor [B,3,H,W-1] 88 | Gradient of image with respect to x 89 | """ 90 | return image[:, :, :, :-1] - image[:, :, :, 1:] 91 | 92 | def gradient_y(image): 93 | """ 94 | Calculates the gradient of an image in the y dimension 95 | Parameters 96 | ---------- 97 | image : torch.Tensor [B,3,H,W] 98 | Input image 99 | 100 | Returns 101 | ------- 102 | gradient_y : torch.Tensor [B,3,H-1,W] 103 | Gradient of image with respect to y 104 | """ 105 | return image[:, :, :-1, :] - image[:, :, 1:, :] 106 | 107 | 108 | def inv_depths_normalize(inv_depths): 109 | """ 110 | Inverse depth normalization 111 | 112 | Parameters 113 | ---------- 114 | inv_depths : list of torch.Tensor [B,1,H,W] 115 | Inverse depth maps 116 | 117 | Returns 118 | ------- 119 | norm_inv_depths : list of torch.Tensor [B,1,H,W] 120 | Normalized inverse depth maps 121 | """ 122 | mean_inv_depths = [inv_depth.mean(2, True).mean(3, True) for inv_depth in inv_depths] 123 | return [inv_depth / mean_inv_depth.clamp(min=1e-6) 124 | for inv_depth, mean_inv_depth in zip(inv_depths, mean_inv_depths)] 125 | 126 | 127 | def calc_smoothness(inv_depths, images, num_scales): 128 | """ 129 | Calculate smoothness values for inverse depths 130 | 131 | Parameters 132 | ---------- 133 | inv_depths : list of torch.Tensor [B,1,H,W] 134 | Inverse depth maps 135 | images : list of torch.Tensor [B,3,H,W] 136 | Inverse depth maps 137 | num_scales : int 138 | Number of scales considered 139 | 140 | Returns 141 | ------- 142 | smoothness_x : list of torch.Tensor [B,1,H,W] 143 | Smoothness values in direction x 144 | smoothness_y : list of torch.Tensor [B,1,H,W] 145 | Smoothness values in direction y 146 | """ 147 | inv_depths_norm = inv_depths_normalize(inv_depths) 148 | inv_depth_gradients_x = [gradient_x(d) for d in inv_depths_norm] 149 | inv_depth_gradients_y = [gradient_y(d) for d in inv_depths_norm] 150 | 151 | image_gradients_x = [gradient_x(image) for image in images] 152 | image_gradients_y = [gradient_y(image) for image in images] 153 | 154 | weights_x = [torch.exp(-torch.mean(torch.abs(g), 1, keepdim=True)) for g in image_gradients_x] 155 | weights_y = [torch.exp(-torch.mean(torch.abs(g), 1, keepdim=True)) for g in image_gradients_y] 156 | 157 | # Note: Fix gradient addition 158 | smoothness_x = [inv_depth_gradients_x[i] * weights_x[i] for i in range(num_scales)] 159 | smoothness_y = [inv_depth_gradients_y[i] * weights_y[i] for i in range(num_scales)] 160 | return smoothness_x, smoothness_y -------------------------------------------------------------------------------- /ggrt/geometry/lie_group/__init__.py: -------------------------------------------------------------------------------- 1 | from ggrt.geometry.lie_group.liegroupbase import LieGroupBase 2 | from ggrt.geometry.lie_group.so3 import SO3 3 | from ggrt.geometry.lie_group.so3q import SO3q 4 | from ggrt.geometry.lie_group.se3 import SE3 5 | from ggrt.geometry.lie_group.se3q import SE3q 6 | -------------------------------------------------------------------------------- /ggrt/geometry/lie_group/liegroupbase.py: 
-------------------------------------------------------------------------------- 1 | """Generic transformation class (Torch) 2 | 3 | References: 4 | [1] "A tutorial on SE(3) transformation parameterizations and on-manifold 5 | optimization 6 | """ 7 | from typing import Dict, List 8 | 9 | import torch 10 | 11 | 12 | class LieGroupBase(object): 13 | DIM = None 14 | DOF = None 15 | N = None # Group transformation is NxN matrix, e.g. 3 for SO(3) 16 | name = 'LieGroupBaseTorch' 17 | 18 | def __init__(self, data: torch.Tensor): 19 | """Constructor for the Lie group instance. 20 | Note that you should NOT call this directly, but should use one 21 | of the from_* methods, which will perform the appropriate checking. 22 | """ 23 | self.data = data 24 | 25 | @staticmethod 26 | def identity(size: int = None, dtype=None, device=None) -> 'LieGroupBase': 27 | raise NotImplementedError 28 | 29 | @staticmethod 30 | def sample_uniform(size: int = None, device=None) -> 'LieGroupBase': 31 | raise NotImplementedError 32 | 33 | @staticmethod 34 | def from_matrix(mat: torch.Tensor, normalize: bool = False, check: bool = True) -> 'LieGroupBase': 35 | raise NotImplementedError 36 | 37 | def inv(self) -> 'LieGroupBase': 38 | raise NotImplementedError 39 | 40 | @staticmethod 41 | def pexp(omega: torch.Tensor) -> 'LieGroupBase': 42 | raise NotImplementedError 43 | 44 | @staticmethod 45 | def exp(omega: torch.Tensor) -> 'LieGroupBase': 46 | raise NotImplementedError 47 | 48 | def log(self) -> torch.Tensor: 49 | raise NotImplementedError 50 | 51 | def boxplus_left(self, delta: torch.Tensor, pseudo=False) -> 'LieGroupBase': 52 | """Left variant of box plus operator""" 53 | if pseudo: 54 | return self.__class__.pexp(delta) * self 55 | else: 56 | return self.__class__.exp(delta) * self 57 | 58 | def boxplus_right(self, delta: torch.Tensor, pseudo=False) -> 'LieGroupBase': 59 | """Right variant of box plus operator, i.e. 60 | x boxplus delta = x * exp(delta) 61 | See Eq (10.6) in [1] 62 | """ 63 | if pseudo: 64 | return self * self.__class__.pexp(delta) 65 | else: 66 | return self * self.__class__.exp(delta) 67 | 68 | def __mul__(self, other: 'LieGroupBase') -> 'LieGroupBase': 69 | return self.__class__(self.data @ other.data) 70 | 71 | def transform(self, pts: torch.Tensor) -> torch.Tensor: 72 | """Applies the transformation on points 73 | 74 | Args: 75 | pts: Points to transform. Should have the size [N, N_pts, 3] if 76 | transform is batched else, [N_pts, 3] 77 | """ 78 | raise NotImplementedError 79 | 80 | def compare(self, other: 'LieGroupBase') -> Dict: 81 | """Compare with another instance""" 82 | raise NotImplementedError 83 | 84 | def vec(self) -> torch.Tensor: 85 | """Returns the flattened representation""" 86 | raise NotImplementedError 87 | 88 | def as_matrix(self) -> torch.Tensor: 89 | """Return the matrix form of the transform (e.g. 3x3 for SO(3))""" 90 | return self.data 91 | 92 | def is_valid(self) -> bool: 93 | """Check whether the data is valid, e.g. 
if the underlying SE(3) 94 | representation has a valid rotation""" 95 | raise NotImplementedError 96 | 97 | def make_valid(self): 98 | """Rectifies the data so that the representation is valid""" 99 | pass 100 | 101 | """Misc methods""" 102 | def __getitem__(self, item) -> 'LieGroupBase': 103 | return self.__class__(self.data[item]) 104 | 105 | def __setitem__(self, key, value): 106 | if isinstance(value, torch.Tensor): 107 | self.data[key] = value 108 | else: 109 | self.data[key] = value.data 110 | 111 | def __repr__(self): 112 | return '{} containing {}'.format(self.name, str(self.data)) 113 | 114 | def __str__(self): 115 | return '{}{}'.format(self.name, list(self.data.shape[:-2])) 116 | 117 | @property 118 | def shape(self): 119 | return self.data.shape[:-2] 120 | 121 | def __len__(self): 122 | shape = self.shape 123 | return self.shape[0] if len(shape) >= 1 else 1 124 | 125 | @classmethod 126 | def stack(cls, transforms: List['LieGroupBase']): 127 | """Concatenates transforms into a single transform""" 128 | stacked = torch.cat([t.data for t in transforms], dim=0) 129 | return cls(stacked) 130 | 131 | """Torch specific methods""" 132 | def to(self, device) -> 'LieGroupBase': 133 | """Move instance to device""" 134 | self.data = self.data.to(device) 135 | return self 136 | 137 | def type(self, dtype) -> 'LieGroupBase': 138 | self.data = self.data.type(dtype) 139 | return self 140 | 141 | def detach(self) -> 'LieGroupBase': 142 | return self.__class__(self.data.detach()) 143 | 144 | -------------------------------------------------------------------------------- /ggrt/geometry/lie_group/se3_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ggrt.geometry.lie_group import so3_common as so3c 4 | 5 | 6 | def is_valid_quat_trans(vec: torch.tensor) -> bool: 7 | """7D vec contains a valid quaternion""" 8 | assert vec.shape[-1] == 7 9 | return so3c.is_valid_quaternion(vec[..., :4]) 10 | 11 | 12 | def normalize_quat_trans(vec: torch.tensor) -> torch.tensor: 13 | """Normalizes SE(3) &D vec to have a valid rotation component""" 14 | 15 | trans = vec[..., 4:] 16 | rot = so3c.normalize_quaternion(vec[..., :4]) 17 | 18 | vec = torch.cat([rot, trans], dim=-1) 19 | return vec 20 | 21 | 22 | def is_valid_matrix(mat: torch.Tensor) -> bool: 23 | """Checks if 4x4 matrix is a valid SE(3) matrix""" 24 | return so3c.is_valid_rotmat(mat[..., :3, :3]) 25 | 26 | 27 | def normalize_matrix(mat: torch.Tensor) -> torch.Tensor: 28 | """Normalizes SE(3) matrix to have a valid rotation component""" 29 | trans = mat[..., :3, 3:] 30 | rot = so3c.normalize_rotmat(mat[..., :3, :3]) 31 | 32 | mat = torch.cat([rot, trans], dim=-1) 33 | bottom_row = torch.zeros_like(mat[..., :1, :]) 34 | bottom_row[..., -1, -1] = 1.0 35 | return torch.cat([mat, bottom_row], dim=-2) 36 | 37 | 38 | def hat(v: torch.Tensor): 39 | """hat-operator for SE(3) 40 | Specifically, it takes in the 6-vector representation (= twist) and returns 41 | the corresponding matrix representation of Lie algebra element. 42 | 43 | Args: 44 | v: Twist vector of size ([*,] 6). As with common convention, first 3 45 | elements denote translation. 46 | 47 | Returns: 48 | mat: se(3) element of size ([*,] 4, 4) 49 | """ 50 | mat = torch.zeros((*v.shape[:-1], 4, 4)) 51 | mat[..., :3, :3] = so3c.hat(v[..., 3:]) # Rotation 52 | mat[..., :3, 3] = v[..., :3] # Translation 53 | 54 | return mat 55 | 56 | 57 | def vee(mat: torch.Tensor): 58 | """vee-operator for SE(3), i.e. inverse of hat() operator. 
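By construction, vee(hat(v)) recovers the original twist v (assuming the so3_common vee/hat pair is likewise inverse for the rotational block).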
59 | 60 | Args: 61 | mat: ([*, ] 4, 4) matrix containing the 4x4-matrix lie algebra 62 | representation. Omega must have the following structure: 63 | | 0 -f e a | 64 | | f 0 -d b | 65 | | -e d 0 c | 66 | | 0 0 0 0 | . 67 | 68 | Returns: 69 | v: twist vector of size ([*,] 6) 70 | 71 | """ 72 | v = torch.zeros((*mat.shape[:-2], 6)) 73 | v[..., 3:] = so3c.vee(mat[..., :3, :3]) 74 | v[..., :3] = mat[..., :3, 3] 75 | return v 76 | 77 | 78 | def quattrans2mat(vec: torch.Tensor) -> torch.Tensor: 79 | """Convert 7D quaternion+translation to a 4x4 SE(3) matrix""" 80 | rot, trans = vec[..., :4], vec[..., 4:] 81 | rotmat = so3c.quat2rotmat(rot) 82 | top = torch.cat([rotmat, trans[..., None]], dim=-1) 83 | bottom_row = torch.zeros_like(top[..., :1, :]) 84 | bottom_row[..., -1, -1] = 1.0 85 | mat = torch.cat([top, bottom_row], dim=-2) 86 | return mat 87 | 88 | 89 | def mat2quattrans(mat: torch.Tensor) -> torch.Tensor: 90 | """Convert 4x4 SE(3) matrix to 7D quaternion+translation""" 91 | assert mat.shape[-2:] == (4, 4), 'Matrix should be of shape ([*,] 4, 4)' 92 | quat = so3c.rotmat2quat(mat[..., :3, :3]).data 93 | trans = mat[..., :3, 3] 94 | vec = torch.cat([quat, trans], dim=-1) 95 | return vec 96 | -------------------------------------------------------------------------------- /ggrt/geometry/lie_group/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | 5 | _EPS = 1e-3 # larger epsilon for float32 6 | 7 | 8 | def allclose(mat1: torch.Tensor, mat2: Union[torch.Tensor, float], tol=_EPS): 9 | """Check if all elements of two tensors are close within some tolerance. 10 | 11 | Either tensor can be replaced by a scalar. 12 | 13 | Note: 14 | This is similar to torch.allclose(), but considers just the absolute 15 | difference at a larger tolerance more suitable for float32. 16 | """ 17 | return isclose(mat1, mat2, tol).all() 18 | 19 | 20 | def isclose(mat1: torch.Tensor, mat2: Union[torch.Tensor, float], tol=_EPS): 21 | """Check element-wise if two tensors are close within some tolerance. 22 | 23 | Either tensor can be replaced by a scalar. 24 | 25 | Note: 26 | This is similar to torch.isclose(), but considers just the absolute 27 | difference at a larger tolerance more suitable for float32. 
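For example, with the default tol of 1e-3, isclose(torch.tensor(1.0), 1.0005) yields True, while isclose(torch.tensor(1.0), 1.01) yields False.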
28 | """ 29 | return (mat1 - mat2).abs_().lt(tol) 30 | -------------------------------------------------------------------------------- /ggrt/global_cfg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | 5 | cfg: Optional[DictConfig] = None 6 | 7 | 8 | def get_cfg() -> DictConfig: 9 | global cfg 10 | return cfg 11 | 12 | 13 | def set_cfg(new_cfg: DictConfig) -> None: 14 | global cfg 15 | cfg = new_cfg 16 | 17 | 18 | def get_seed() -> int: 19 | return cfg.seed 20 | -------------------------------------------------------------------------------- /ggrt/hack_torch/custom_grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from packaging import version as pver 4 | 5 | 6 | def custom_meshgrid(*args): 7 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid 8 | if pver.parse(torch.__version__) < pver.parse('1.10'): 9 | return torch.meshgrid(*args) 10 | else: 11 | return torch.meshgrid(*args, indexing='ij') 12 | -------------------------------------------------------------------------------- /ggrt/loss/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from utils_loc import img2mse 20 | from ggrt.geometry.depth import depth2inv 21 | 22 | 23 | class MaskedL2ImageLoss(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def forward(self, outputs, ray_batch): 28 | ''' 29 | training criterion 30 | ''' 31 | pred_rgb = outputs['rgb'] 32 | if 'mask' in outputs: 33 | pred_mask = outputs['mask'].float() 34 | else: 35 | pred_mask = None 36 | gt_rgb = ray_batch['rgb'] 37 | 38 | loss = img2mse(pred_rgb, gt_rgb, pred_mask) 39 | 40 | return loss 41 | 42 | 43 | def pseudo_huber_loss(residual, scale=10): 44 | trunc_residual = residual / scale 45 | return torch.sqrt(trunc_residual * trunc_residual + 1) - 1 46 | 47 | 48 | class FeatureMetricLoss(nn.Module): 49 | def __init__(self) -> None: 50 | super().__init__() 51 | 52 | def forward(self, target_rgb_feat, nearby_view_rgb_feat, mask=None): 53 | ''' 54 | Args: 55 | target_rgb_feat: [n_rays, n_samples=1, n_views+1, d+3] 56 | nearby_view_rgb_feat: [n_rays, n_samples=1, n_views+1, d+3] 57 | ''' 58 | if mask is None: 59 | l1_loss = nn.L1Loss(reduction='mean') 60 | # mse_loss = nn.MSELoss(reduction='mean') 61 | # loss = mse_loss(nearby_view_rgb_feat, target_rgb_feat) 62 | 63 | loss = l1_loss(nearby_view_rgb_feat, target_rgb_feat) 64 | 65 | else: 66 | feat_diff = target_rgb_feat - nearby_view_rgb_feat 67 | feat_diff_square = (feat_diff * feat_diff).squeeze(1) 68 | mask = mask.repeat(1, 1, 1).permute(2, 0, 1) 69 | n_views, n_dims = target_rgb_feat.shape[-2], target_rgb_feat.shape[-1] 70 | loss = torch.sum(feat_diff_square * mask) / (torch.sum(mask.squeeze(-1)) * n_views * n_dims + 1e-6) 71 | 72 | # feat_diff_huber = pseudo_huber_loss(feat_diff, scale=0.8).squeeze(1) 73 | # mask = mask.repeat(1, 1, 1).permute(2, 0, 1) 74 | # n_views, n_dims = target_rgb_feat.shape[-2], target_rgb_feat.shape[-1] 75 | # loss = torch.sum(feat_diff_huber * mask) / (torch.sum(mask.squeeze(-1)) * n_views * n_dims + 1e-6) 76 | 77 | return loss 78 | 79 | 80 | def self_sup_depth_loss(inv_depth_prior, rendered_depth, min_depth, max_depth): 81 | min_disparity = 1.0 / max_depth 82 | max_disparity = 1.0 / min_depth 83 | valid = ((inv_depth_prior > min_disparity) & (inv_depth_prior < max_disparity)).detach() 84 | 85 | inv_rendered_depth = depth2inv(rendered_depth) 86 | 87 | loss_depth = torch.mean(valid * torch.abs(inv_depth_prior - inv_rendered_depth)) 88 | 89 | return loss_depth 90 | 91 | 92 | def sup_depth_loss(ego_motion_inv_depths, gt_depth, min_depth, max_depth): 93 | num_iters = len(ego_motion_inv_depths) 94 | total_loss = 0 95 | total_w = 0 96 | gamma = 0.85 97 | min_disp = 1.0 / max_depth 98 | max_disp = 1.0 / min_depth 99 | 100 | gt_inv_depth = depth2inv(gt_depth) 101 | 102 | valid = ((gt_inv_depth > min_disp) & (gt_inv_depth < max_disp)).detach() 103 | 104 | for i, inv_depth in enumerate(ego_motion_inv_depths): 105 | w = gamma ** (num_iters - i - 1) 106 | total_w += w 107 | 108 | loss_depth = torch.mean(valid * torch.abs(gt_inv_depth - inv_depth.squeeze(0))) 109 | loss_i = loss_depth 110 | total_loss += w * loss_i 111 | loss = total_loss / total_w 112 | return loss 113 | -------------------------------------------------------------------------------- /ggrt/loss/ssim_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from math import exp 6 | from torch.autograd import Variable 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | 
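"""Build a 1D Gaussian kernel of length `window_size` with the given sigma, normalized to sum to 1."""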
gauss = torch.Tensor([ 11 | exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \ 12 | for x in range(window_size) 13 | ]) 14 | return gauss / gauss.sum() 15 | 16 | 17 | def create_window(window_size, channel): 18 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 19 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 20 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 21 | return window 22 | 23 | 24 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 25 | mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel) 26 | mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel) 27 | 28 | mu1_sq = mu1.pow(2) 29 | mu2_sq = mu2.pow(2) 30 | mu1_mu2 = mu1*mu2 31 | 32 | sigma1_sq = F.conv2d( 33 | img1 * img1, window, padding=window_size//2, groups=channel 34 | ) - mu1_sq 35 | sigma2_sq = F.conv2d( 36 | img2 * img2, window, padding=window_size//2, groups=channel 37 | ) - mu2_sq 38 | sigma12 = F.conv2d( 39 | img1 * img2, window, padding=window_size//2, groups=channel 40 | ) - mu1_mu2 41 | 42 | C1 = 0.01 ** 2 43 | C2 = 0.03 ** 2 44 | 45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (( 46 | mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 47 | 48 | if size_average: 49 | return ssim_map.mean() 50 | else: 51 | return ssim_map.mean(1).mean(1).mean(1) 52 | 53 | 54 | class SSIM(torch.nn.Module): 55 | def __init__(self, window_size=11, size_average = True): 56 | super(SSIM, self).__init__() 57 | self.window_size = window_size 58 | self.size_average = size_average 59 | self.channel = 1 60 | self.window = create_window(window_size, self.channel) 61 | 62 | def forward(self, img1, img2): 63 | (_, channel, _, _) = img1.size() 64 | 65 | if channel == self.channel and self.window.data.type() == img1.data.type(): 66 | window = self.window 67 | else: 68 | window = create_window(self.window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | self.window = window 75 | self.channel = channel 76 | 77 | 78 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 79 | 80 | 81 | def ssim(img1, img2, window_size = 11, size_average = True): 82 | (_, channel, _, _) = img1.size() 83 | window = create_window(window_size, channel) 84 | 85 | if img1.is_cuda: 86 | window = window.cuda(img1.get_device()) 87 | window = window.type_as(img1) 88 | 89 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /ggrt/misc/LocalLogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Optional 4 | 5 | from PIL import Image 6 | from pytorch_lightning.loggers.logger import Logger 7 | from pytorch_lightning.utilities import rank_zero_only 8 | 9 | LOG_PATH = Path("outputs/local") 10 | 11 | 12 | class LocalLogger(Logger): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.experiment = None 16 | os.system(f"rm -r {LOG_PATH}") 17 | 18 | @property 19 | def name(self): 20 | return "LocalLogger" 21 | 22 | @property 23 | def version(self): 24 | return 0 25 | 26 | @rank_zero_only 27 | def log_hyperparams(self, params): 28 | pass 29 | 30 | @rank_zero_only 31 | def log_metrics(self, metrics, step): 32 | pass 33 | 34 | @rank_zero_only 35 | def log_image( 36 | self, 37 | key: str, 38 | images: list[Any], 39 | 
step: Optional[int] = None, 40 | **kwargs, 41 | ): 42 | # The function signature is the same as the wandb logger's, but the step is 43 | # actually required. 44 | assert step is not None 45 | for index, image in enumerate(images): 46 | path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" 47 | path.parent.mkdir(exist_ok=True, parents=True) 48 | Image.fromarray(image).save(path) 49 | -------------------------------------------------------------------------------- /ggrt/misc/benchmarker.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from time import time 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Benchmarker: 12 | def __init__(self): 13 | self.execution_times = defaultdict(list) 14 | 15 | @contextmanager 16 | def time(self, tag: str, num_calls: int = 1): 17 | try: 18 | start_time = time() 19 | yield 20 | finally: 21 | end_time = time() 22 | for _ in range(num_calls): 23 | self.execution_times[tag].append((end_time - start_time) / num_calls) 24 | 25 | def dump(self, path: Path) -> None: 26 | path.parent.mkdir(exist_ok=True, parents=True) 27 | with path.open("w") as f: 28 | json.dump(dict(self.execution_times), f) 29 | 30 | def dump_memory(self, path: Path) -> None: 31 | path.parent.mkdir(exist_ok=True, parents=True) 32 | with path.open("w") as f: 33 | json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f) 34 | 35 | def summarize(self) -> None: 36 | for tag, times in self.execution_times.items(): 37 | print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") 38 | -------------------------------------------------------------------------------- /ggrt/misc/collation.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Union 2 | 3 | from torch import Tensor 4 | 5 | Tree = Union[Dict[str, "Tree"], Tensor] 6 | 7 | 8 | def collate(trees: list[Tree], merge_fn: Callable[[list[Tensor]], Tensor]) -> Tree: 9 | """Merge nested dictionaries of tensors.""" 10 | if isinstance(trees[0], Tensor): 11 | return merge_fn(trees) 12 | else: 13 | return { 14 | key: collate([tree[key] for tree in trees], merge_fn) for key in trees[0] 15 | } 16 | -------------------------------------------------------------------------------- /ggrt/misc/discrete_probability_distribution.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import reduce 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor 5 | 6 | 7 | def sample_discrete_distribution( 8 | pdf: Float[Tensor, "*batch bucket"], 9 | num_samples: int, 10 | eps: float = torch.finfo(torch.float32).eps, 11 | ) -> tuple[ 12 | Int64[Tensor, "*batch sample"], # index 13 | Float[Tensor, "*batch sample"], # probability density 14 | ]: 15 | *batch, bucket = pdf.shape 16 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... 
()", "sum")) 17 | cdf = normalized_pdf.cumsum(dim=-1) 18 | samples = torch.rand((*batch, num_samples), device=pdf.device) 19 | index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1) 20 | return index, normalized_pdf.gather(dim=-1, index=index) 21 | 22 | 23 | def gather_discrete_topk( 24 | pdf: Float[Tensor, "*batch bucket"], 25 | num_samples: int, 26 | eps: float = torch.finfo(torch.float32).eps, 27 | ) -> tuple[ 28 | Int64[Tensor, "*batch sample"], # index 29 | Float[Tensor, "*batch sample"], # probability density 30 | ]: 31 | normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum")) 32 | index = pdf.topk(k=num_samples, dim=-1).indices 33 | return index, normalized_pdf.gather(dim=-1, index=index) 34 | -------------------------------------------------------------------------------- /ggrt/misc/heterogeneous_pairings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat 3 | from jaxtyping import Int 4 | from torch import Tensor 5 | 6 | Index = Int[Tensor, "n n-1"] 7 | 8 | 9 | def generate_heterogeneous_index( 10 | n: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> tuple[Index, Index]: 13 | """Generate indices for all pairs except self-pairs.""" 14 | arange = torch.arange(n, device=device) 15 | 16 | # Generate an index that represents the item itself. 17 | index_self = repeat(arange, "h -> h w", w=n - 1) 18 | 19 | # Generate an index that represents the other items. 20 | index_other = repeat(arange, "w -> h w", h=n).clone() 21 | index_other += torch.ones((n, n), device=device, dtype=torch.int64).triu() 22 | index_other = index_other[:, :-1] 23 | 24 | return index_self, index_other 25 | 26 | 27 | def generate_heterogeneous_index_transpose( 28 | n: int, 29 | device: torch.device = torch.device("cpu"), 30 | ) -> tuple[Index, Index]: 31 | """Generate an index that can be used to "transpose" the heterogeneous index. 32 | Applying the index a second time inverts the "transpose." 
33 | """ 34 | arange = torch.arange(n, device=device) 35 | ones = torch.ones((n, n), device=device, dtype=torch.int64) 36 | 37 | index_self = repeat(arange, "w -> h w", h=n).clone() 38 | index_self = index_self + ones.triu() 39 | 40 | index_other = repeat(arange, "h -> h w", w=n) 41 | index_other = index_other - (1 - ones.triu()) 42 | 43 | return index_self[:, :-1], index_other[:, :-1] 44 | -------------------------------------------------------------------------------- /ggrt/misc/image_io.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as tf 8 | from einops import rearrange, repeat 9 | from jaxtyping import Float, UInt8 10 | from matplotlib.figure import Figure 11 | from PIL import Image 12 | from torch import Tensor 13 | 14 | FloatImage = Union[ 15 | Float[Tensor, "height width"], 16 | Float[Tensor, "channel height width"], 17 | Float[Tensor, "batch channel height width"], 18 | ] 19 | 20 | 21 | def fig_to_image( 22 | fig: Figure, 23 | dpi: int = 100, 24 | device: torch.device = torch.device("cpu"), 25 | ) -> Float[Tensor, "3 height width"]: 26 | buffer = io.BytesIO() 27 | fig.savefig(buffer, format="raw", dpi=dpi) 28 | buffer.seek(0) 29 | data = np.frombuffer(buffer.getvalue(), dtype=np.uint8) 30 | h = int(fig.bbox.bounds[3]) 31 | w = int(fig.bbox.bounds[2]) 32 | data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4) 33 | buffer.close() 34 | return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3] 35 | 36 | 37 | def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 38 | # Handle batched images. 39 | if image.ndim == 4: 40 | image = rearrange(image, "b c h w -> c h (b w)") 41 | 42 | # Handle single-channel images. 43 | if image.ndim == 2: 44 | image = rearrange(image, "h w -> () h w") 45 | 46 | # Ensure that there are 3 or 4 channels. 47 | channel, _, _ = image.shape 48 | if channel == 1: 49 | image = repeat(image, "() h w -> c h w", c=3) 50 | assert image.shape[0] in (3, 4) 51 | 52 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 53 | return rearrange(image, "c h w -> h w c").cpu().numpy() 54 | 55 | 56 | def save_image( 57 | image: FloatImage, 58 | path: Union[Path, str], 59 | ) -> None: 60 | """Save an image. Assumed to be in range 0-1.""" 61 | 62 | # Create the parent directory if it doesn't already exist. 63 | path = Path(path) 64 | path.parent.mkdir(exist_ok=True, parents=True) 65 | 66 | # Save the image. 67 | Image.fromarray(prep_image(image)).save(path) 68 | 69 | 70 | def load_image( 71 | path: Union[Path, str], 72 | ) -> Float[Tensor, "3 height width"]: 73 | return tf.ToTensor()(Image.open(path))[:3] 74 | -------------------------------------------------------------------------------- /ggrt/misc/nn_module_tools.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def convert_to_buffer(module: nn.Module, persistent: bool = True): 5 | # Recurse over child modules. 6 | for name, child in list(module.named_children()): 7 | convert_to_buffer(child, persistent) 8 | 9 | # Also re-save buffers to change persistence. 
10 | for name, parameter_or_buffer in ( 11 | *module.named_parameters(recurse=False), 12 | *module.named_buffers(recurse=False), 13 | ): 14 | value = parameter_or_buffer.detach().clone() 15 | delattr(module, name) 16 | module.register_buffer(name, value, persistent=persistent) 17 | -------------------------------------------------------------------------------- /ggrt/misc/sh_rotation.py: -------------------------------------------------------------------------------- 1 | from math import isqrt 2 | 3 | import torch 4 | from e3nn.o3 import matrix_to_angles, wigner_D 5 | from einops import einsum 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | 10 | def rotate_sh( 11 | sh_coefficients: Float[Tensor, "*#batch n"], 12 | rotations: Float[Tensor, "*#batch 3 3"], 13 | ) -> Float[Tensor, "*batch n"]: 14 | device = sh_coefficients.device 15 | dtype = sh_coefficients.dtype 16 | 17 | *_, n = sh_coefficients.shape 18 | alpha, beta, gamma = matrix_to_angles(rotations) 19 | result = [] 20 | for degree in range(isqrt(n)): 21 | sh_rotations = wigner_D(torch.tensor(degree).to(device), alpha, beta, gamma).type(dtype) 22 | sh_rotated = einsum( 23 | sh_rotations, 24 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 25 | "... i j, ... j -> ... i", 26 | ) 27 | result.append(sh_rotated) 28 | 29 | return torch.cat(result, dim=-1) 30 | 31 | 32 | if __name__ == "__main__": 33 | from pathlib import Path 34 | 35 | import matplotlib.pyplot as plt 36 | from e3nn.o3 import spherical_harmonics 37 | from matplotlib import cm 38 | from scipy.spatial.transform.rotation import Rotation as R 39 | 40 | device = torch.device("cuda") 41 | 42 | # Generate random spherical harmonics coefficients. 43 | degree = 4 44 | coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) 45 | 46 | def plot_sh(sh_coefficients, path: Path) -> None: 47 | phi = torch.linspace(0, torch.pi, 100, device=device) 48 | theta = torch.linspace(0, 2 * torch.pi, 100, device=device) 49 | phi, theta = torch.meshgrid(phi, theta, indexing="xy") 50 | x = torch.sin(phi) * torch.cos(theta) 51 | y = torch.sin(phi) * torch.sin(theta) 52 | z = torch.cos(phi) 53 | xyz = torch.stack([x, y, z], dim=-1) 54 | sh = spherical_harmonics(list(range(degree + 1)), xyz, True) 55 | result = einsum(sh, sh_coefficients, "... 
n, n -> ...") 56 | result = (result - result.min()) / (result.max() - result.min()) 57 | 58 | # Set the aspect ratio to 1 so our sphere looks spherical 59 | fig = plt.figure(figsize=plt.figaspect(1.0)) 60 | ax = fig.add_subplot(111, projection="3d") 61 | ax.plot_surface( 62 | x.cpu().numpy(), 63 | y.cpu().numpy(), 64 | z.cpu().numpy(), 65 | rstride=1, 66 | cstride=1, 67 | facecolors=cm.seismic(result.cpu().numpy()), 68 | ) 69 | # Turn off the axis planes 70 | ax.set_axis_off() 71 | path.parent.mkdir(exist_ok=True, parents=True) 72 | plt.savefig(path) 73 | 74 | for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): 75 | rotation = torch.tensor( 76 | R.from_euler("x", angle.item()).as_matrix(), device=device 77 | ) 78 | plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) 79 | 80 | print("Done!") 81 | -------------------------------------------------------------------------------- /ggrt/misc/step_tracker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import RLock 2 | 3 | import torch 4 | from jaxtyping import Int64 5 | from torch import Tensor 6 | from torch.multiprocessing import Manager 7 | 8 | 9 | class StepTracker: 10 | lock: RLock 11 | step: Int64[Tensor, ""] 12 | 13 | def __init__(self): 14 | self.lock = Manager().RLock() 15 | self.step = torch.tensor(0, dtype=torch.int64).share_memory_() 16 | 17 | def set_step(self, step: int) -> None: 18 | with self.lock: 19 | self.step.fill_(step) 20 | 21 | def get_step(self) -> int: 22 | with self.lock: 23 | return self.step.item() 24 | -------------------------------------------------------------------------------- /ggrt/misc/wandb_tools.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import wandb 4 | 5 | 6 | def version_to_int(artifact) -> int: 7 | """Convert versions of the form vX to X. For example, v12 to 12.""" 8 | return int(artifact.version[1:]) 9 | 10 | 11 | def download_checkpoint( 12 | run_id: str, 13 | download_dir: Path, 14 | version: str | None, 15 | ) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_id) 18 | 19 | # Find the latest saved model checkpoint. 20 | chosen = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | # If no version is specified, use the latest. 26 | if version is None: 27 | if chosen is None or version_to_int(artifact) > version_to_int(chosen): 28 | chosen = artifact 29 | 30 | # If a specific verison is specified, look for it. 31 | elif version == artifact.version: 32 | chosen = artifact 33 | break 34 | 35 | # Download the checkpoint. 
36 | download_dir.mkdir(exist_ok=True, parents=True) 37 | root = download_dir / run_id 38 | chosen.download(root=root) 39 | return root / "model.ckpt" 40 | 41 | 42 | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: 43 | if path is None: 44 | return None 45 | 46 | if not str(path).startswith("wandb://"): 47 | return Path(path) 48 | 49 | run_id, *version = path[len("wandb://") :].split(":") 50 | if len(version) == 0: 51 | version = None 52 | elif len(version) == 1: 53 | version = version[0] 54 | else: 55 | raise ValueError("Invalid version specifier!") 56 | 57 | project = wandb_cfg["project"] 58 | return download_checkpoint( 59 | f"{project}/{run_id}", 60 | Path("checkpoints"), 61 | version, 62 | ) 63 | -------------------------------------------------------------------------------- /ggrt/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifuguan/GGRt_official/01886261b6b6b6175b6ea88f44a85c640564ae9f/ggrt/model/__init__.py -------------------------------------------------------------------------------- /ggrt/model/dbarf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import torch 5 | 6 | from ggrt.model.ibrnet import IBRNetModel 7 | from ggrt.depth_pose_network import DepthPoseNet 8 | from ggrt.loss.photometric_loss import MultiViewPhotometricDecayLoss 9 | 10 | 11 | class DBARFModel(IBRNetModel): 12 | def __init__(self, args, load_opt=True, load_scheduler=True, pretrained=True): 13 | device = torch.device(f'cuda:{args.local_rank}') 14 | 15 | # create pose optimizer. 16 | self.pose_learner = DepthPoseNet(iters=12, pretrained=pretrained).to(device) 17 | self.photometric_loss = MultiViewPhotometricDecayLoss() 18 | 19 | super(DBARFModel, self).__init__(args, load_opt, load_scheduler, half_feat_dim=True) 20 | 21 | def to_distributed(self): 22 | super().to_distributed() 23 | 24 | if self.args.distributed: 25 | self.pose_learner = torch.nn.parallel.DistributedDataParallel( 26 | self.pose_learner, 27 | device_ids=[self.args.local_rank], 28 | output_device=[self.args.local_rank] 29 | ) 30 | 31 | def correct_poses(self, fmaps, target_image, ref_imgs, target_camera, ref_cameras, 32 | min_depth=0.1, max_depth=100, scaled_shape=(378, 504)): 33 | """ 34 | Args: 35 | fmaps: [n_views+1, c, h, w] 36 | target_image: [1, h, w, 3] 37 | ref_imgs: [1, n_views, h, w, 3] 38 | target_camera: [1, 34] 39 | ref_cameras: [1, n_views, 34] 40 | Return: 41 | inv_depths: n_iters*[1, 1, h, w] if training else [1, 1, h, w] 42 | rel_poses: [n_views, n_iters, 6] if training else [n_views, 6] 43 | """ 44 | target_intrinsics = target_camera[:, 2:18].reshape(-1, 4, 4)[..., :3, :3] # [1, 3, 3] 45 | ref_intrinsics = ref_cameras.squeeze(0)[:, 2:18].reshape(-1, 4, 4)[..., :3, :3] # [n_views, 3, 3] 46 | target_image = target_image.permute(0, 3, 1, 2) # [1, 3, h, w] 47 | ref_imgs = ref_imgs.squeeze(0).permute(0, 3, 1, 2) # [n_views, 3, h, w] 48 | 49 | inv_depths, rel_poses, fmap = self.pose_learner( 50 | fmaps=None, # fmaps, 51 | target_image=target_image, 52 | ref_imgs=ref_imgs, 53 | target_intrinsics=target_intrinsics, 54 | ref_intrinsics=ref_intrinsics, 55 | min_depth=min_depth, max_depth=max_depth, 56 | scaled_shape=scaled_shape) 57 | rel_poses = rel_poses.squeeze(0) 58 | 59 | sfm_loss = 0 60 | if self.pose_learner.training: 61 | sfm_loss = self.photometric_loss(target_image, ref_imgs, inv_depths, target_intrinsics, ref_intrinsics, rel_poses) 62 | 
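# sfm_loss stays 0 at evaluation time; the photometric SfM loss is only computed while the pose learner is in training mode.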
63 | return inv_depths, rel_poses, sfm_loss, fmap 64 | 65 | def switch_to_eval(self): 66 | super().switch_to_eval() 67 | self.pose_learner.eval() 68 | 69 | def switch_to_train(self): 70 | super().switch_to_train() 71 | self.pose_learner.train() 72 | 73 | def switch_state_machine(self, state='joint') -> str: 74 | if state == 'pose_only': 75 | self._set_pose_learner_state(opt=True) 76 | self._set_nerf_state(opt=False) 77 | 78 | elif state == 'nerf_only': 79 | self._set_pose_learner_state(opt=False) 80 | self._set_nerf_state(opt=True) 81 | 82 | elif state == 'joint': 83 | self._set_pose_learner_state(opt=True) 84 | self._set_nerf_state(opt=True) 85 | 86 | else: 87 | raise NotImplementedError("Not supported state") 88 | 89 | return state 90 | 91 | def _set_pose_learner_state(self, opt=True): 92 | for param in self.pose_learner.parameters(): 93 | param.requires_grad = opt 94 | 95 | def _set_nerf_state(self, opt=True): 96 | for param in self.net_coarse.parameters(): 97 | param.requires_grad = opt 98 | 99 | for param in self.feature_net.parameters(): 100 | param.requires_grad = opt 101 | 102 | if self.net_fine is not None: 103 | for param in self.net_fine.parameters(): 104 | param.requires_grad = opt 105 | 106 | def compose_joint_loss(self, sfm_loss, nerf_loss, step, coefficient=1e-5): 107 | # The jointly training loss is composed by the convex_combination: 108 | # L = a * L1 + (1-a) * L2 109 | alpha = math.pow(2.0, -coefficient * step) 110 | loss = alpha * sfm_loss + (1 - alpha) * nerf_loss 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /ggrt/model/dgaussian.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import torch 5 | 6 | from ggrt.model.feature_network import ResUNet 7 | from ggrt.depth_pose_network import DepthPoseNet 8 | from ggrt.loss.photometric_loss import MultiViewPhotometricDecayLoss 9 | 10 | from ggrt.base.model_base import Model 11 | 12 | from ggrt.model.pixelsplat.decoder import get_decoder 13 | from ggrt.model.pixelsplat.encoder import get_encoder 14 | from ggrt.model.pixelsplat.pixelsplat import PixelSplat 15 | 16 | class DGaussianModel(Model): 17 | def __init__(self, args, load_opt=True, load_scheduler=True, pretrained=True): 18 | device = torch.device(f'cuda:{args.local_rank}') 19 | 20 | # create pose optimizer. 21 | self.pose_learner = DepthPoseNet(iters=12, pretrained=pretrained).to(device) 22 | 23 | # create generalized 3d gaussian. 
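# The epipolar encoder predicts pixel-aligned 3D Gaussians from the context views, and the
# CUDA splatting decoder renders novel views from them (see ggrt/model/pixelsplat).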
24 | encoder, encoder_visualizer = get_encoder(args.pixelsplat.encoder) 25 | decoder = get_decoder(args.pixelsplat.decoder) 26 | self.gaussian_model = PixelSplat(encoder, decoder, encoder_visualizer) 27 | self.gaussian_model.to(device) 28 | # self.gaussian_model.load_state_dict(torch.load('model_zoo/re10k.ckpt')['state_dict']) 29 | 30 | self.photometric_loss = MultiViewPhotometricDecayLoss() 31 | 32 | def to_distributed(self): 33 | super().to_distributed() 34 | 35 | if self.args.distributed: 36 | self.pose_learner = torch.nn.parallel.DistributedDataParallel( 37 | self.pose_learner, 38 | device_ids=[self.args.local_rank], 39 | output_device=[self.args.local_rank] 40 | ) 41 | self.gaussian_model = torch.nn.parallel.DistributedDataParallel( 42 | self.gaussian_model, 43 | device_ids=[self.args.local_rank], 44 | output_device=[self.args.local_rank] 45 | ) 46 | 47 | def switch_to_eval(self): 48 | self.pose_learner.eval() 49 | self.gaussian_model.eval() 50 | 51 | def switch_to_train(self): 52 | self.pose_learner.train() 53 | self.gaussian_model.train() 54 | 55 | def iponet(self, fmaps, target_image, ref_imgs, target_camera, ref_cameras, 56 | min_depth=0.1, max_depth=100, scaled_shape=(378, 504)): 57 | """ 58 | Args: 59 | fmaps: [n_views+1, c, h, w] 60 | target_image: [1, h, w, 3] 61 | ref_imgs: [1, n_views, h, w, 3] 62 | target_camera: [1, 34] 63 | ref_cameras: [1, n_views, 34] 64 | Return: 65 | inv_depths: n_iters*[1, 1, h, w] if training else [1, 1, h, w] 66 | rel_poses: [n_views, n_iters, 6] if training else [n_views, 6] 67 | """ 68 | target_intrinsics = target_camera[:, 2:18].reshape(-1, 4, 4)[..., :3, :3] # [1, 3, 3] 69 | ref_intrinsics = ref_cameras.squeeze(0)[:, 2:18].reshape(-1, 4, 4)[..., :3, :3] # [n_views, 3, 3] 70 | target_image = target_image.permute(0, 3, 1, 2) # [1, 3, h, w] 71 | ref_imgs = ref_imgs.squeeze(0).permute(0, 3, 1, 2) # [n_views, 3, h, w] 72 | 73 | inv_depths, rel_poses, fmap = self.pose_learner( 74 | fmaps=None, # fmaps, 75 | target_image=target_image, 76 | ref_imgs=ref_imgs, 77 | target_intrinsics=target_intrinsics, 78 | ref_intrinsics=ref_intrinsics, 79 | min_depth=min_depth, max_depth=max_depth, 80 | scaled_shape=scaled_shape) 81 | rel_poses = rel_poses.squeeze(0) 82 | 83 | sfm_loss = 0 84 | if self.pose_learner.training: 85 | sfm_loss = self.photometric_loss(target_image, ref_imgs, inv_depths, target_intrinsics, ref_intrinsics, rel_poses) 86 | # sfm_loss = 0 87 | return inv_depths, rel_poses, sfm_loss, fmap 88 | 89 | def switch_state_machine(self, state='joint') -> str: 90 | if state == 'pose_only': 91 | self._set_pose_learner_state(opt=True) 92 | self._set_gaussian_state(opt=False) 93 | 94 | elif state == 'nerf_only': 95 | self._set_pose_learner_state(opt=False) 96 | self._set_gaussian_state(opt=True) 97 | 98 | elif state == 'joint': 99 | self._set_pose_learner_state(opt=True) 100 | self._set_gaussian_state(opt=True) 101 | 102 | else: 103 | raise NotImplementedError("Not supported state") 104 | 105 | return state 106 | 107 | def _set_pose_learner_state(self, opt=True): 108 | for param in self.pose_learner.parameters(): 109 | param.requires_grad = opt 110 | 111 | def _set_gaussian_state(self, opt=True): 112 | for param in self.gaussian_model.parameters(): 113 | param.requires_grad = opt 114 | 115 | def compose_joint_loss(self, sfm_loss, nerf_loss, step, coefficient=1e-5): 116 | # The jointly training loss is composed by the convex_combination: 117 | # L = a * L1 + (1-a) * L2 118 | alpha = math.pow(2.0, -coefficient * step) 119 | loss = alpha * sfm_loss + (1 - 
alpha) * nerf_loss 120 | 121 | return loss 122 | -------------------------------------------------------------------------------- /ggrt/model/gaussian.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import torch 5 | 6 | from ggrt.model.feature_network import ResUNet 7 | from ggrt.depth_pose_network import DepthPoseNet 8 | from ggrt.loss.photometric_loss import MultiViewPhotometricDecayLoss 9 | 10 | from ggrt.base.model_base import Model 11 | 12 | from ggrt.model.pixelsplat.decoder import get_decoder 13 | from ggrt.model.pixelsplat.encoder import get_encoder 14 | from ggrt.model.pixelsplat.pixelsplat import PixelSplat 15 | 16 | class GaussianModel(Model): 17 | def __init__(self, args, load_opt=True, load_scheduler=True, pretrained=True): 18 | device = torch.device(f'cuda:{args.local_rank}') 19 | 20 | 21 | # create generalized 3d gaussian. 22 | encoder, encoder_visualizer = get_encoder(args.pixelsplat.encoder) 23 | decoder = get_decoder(args.pixelsplat.decoder) 24 | self.gaussian_model = PixelSplat(encoder, decoder, encoder_visualizer) 25 | self.gaussian_model.to(device) 26 | # self.gaussian_model.load_state_dict(torch.load('model_zoo/re10k.ckpt')['state_dict']) 27 | 28 | self.photometric_loss = MultiViewPhotometricDecayLoss() 29 | 30 | def to_distributed(self): 31 | super().to_distributed() 32 | 33 | if self.args.distributed: 34 | self.gaussian_model = torch.nn.parallel.DistributedDataParallel( 35 | self.gaussian_model, 36 | device_ids=[self.args.local_rank], 37 | output_device=[self.args.local_rank] 38 | ) 39 | 40 | def switch_to_eval(self): 41 | 42 | self.gaussian_model.eval() 43 | 44 | def switch_to_train(self): 45 | self.gaussian_model.train() 46 | 47 | def switch_state_machine(self, state='joint') -> str: 48 | if state == 'pose_only': 49 | self._set_gaussian_state(opt=False) 50 | 51 | elif state == 'nerf_only': 52 | self._set_gaussian_state(opt=True) 53 | 54 | elif state == 'joint': 55 | self._set_gaussian_state(opt=True) 56 | 57 | else: 58 | raise NotImplementedError("Not supported state") 59 | 60 | return state 61 | 62 | def _set_gaussian_state(self, opt=True): 63 | for param in self.gaussian_model.parameters(): 64 | param.requires_grad = opt 65 | 66 | def compose_joint_loss(self, sfm_loss, nerf_loss, step, coefficient=1e-5): 67 | # The jointly training loss is composed by the convex_combination: 68 | # L = a * L1 + (1-a) * L2 69 | alpha = math.pow(2.0, -coefficient * step) 70 | loss = alpha * sfm_loss + (1 - alpha) * nerf_loss 71 | 72 | return loss 73 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import Decoder 2 | from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg 3 | 4 | DECODERS = { 5 | "splatting_cuda": DecoderSplattingCUDA, 6 | } 7 | 8 | DecoderCfg = DecoderSplattingCUDACfg 9 | 10 | 11 | def get_decoder(decoder_cfg: DecoderCfg) -> Decoder: 12 | return DECODERS[decoder_cfg.name](decoder_cfg) 13 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Generic, Literal, TypeVar 4 | 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 
import torch 8 | 9 | # from ...dataset import DatasetCfg 10 | from ..types import Gaussians 11 | 12 | DepthRenderingMode = Literal[ 13 | "depth", 14 | "log", 15 | "disparity", 16 | "relative_disparity", 17 | ] 18 | 19 | 20 | @dataclass 21 | class DecoderOutput: 22 | color: Float[Tensor, "batch view 3 height width"] 23 | depth: Float[Tensor, "batch view height width"] | None 24 | 25 | 26 | T = TypeVar("T") 27 | 28 | 29 | class Decoder(nn.Module, ABC, Generic[T]): 30 | cfg: T 31 | 32 | def __init__(self, cfg: T) -> None: 33 | super().__init__() 34 | self.cfg = cfg 35 | 36 | @abstractmethod 37 | def forward( 38 | self, 39 | gaussians: Gaussians, 40 | extrinsics: Float[Tensor, "batch view 4 4"], 41 | intrinsics: Float[Tensor, "batch view 3 3"], 42 | near: Float[Tensor, "batch view"], 43 | far: Float[Tensor, "batch view"], 44 | image_shape: tuple[int, int], 45 | depth_mode: DepthRenderingMode | None = None, 46 | ) -> DecoderOutput: 47 | pass 48 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/decoder/decoder_splatting_cuda.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor 8 | 9 | from ..types import Gaussians 10 | from .cuda_splatting import DepthRenderingMode, render_cuda, render_depth_cuda 11 | from .decoder import Decoder, DecoderOutput 12 | 13 | 14 | @dataclass 15 | class DecoderSplattingCUDACfg: 16 | name: Literal["splatting_cuda"] 17 | 18 | 19 | class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): 20 | background_color: Float[Tensor, "3"] 21 | 22 | def __init__( 23 | self, 24 | cfg: DecoderSplattingCUDACfg, 25 | ) -> None: 26 | super().__init__(cfg) 27 | self.background_color = torch.tensor(([0.0, 0.0, 0.0]), dtype=torch.float32) 28 | 29 | def forward( 30 | self, 31 | gaussians: Gaussians, 32 | extrinsics: Float[Tensor, "batch view 4 4"], 33 | intrinsics: Float[Tensor, "batch view 3 3"], 34 | near: Float[Tensor, "batch view"], 35 | far: Float[Tensor, "batch view"], 36 | image_shape: tuple[int, int], 37 | depth_mode: DepthRenderingMode | None = None, 38 | ) -> DecoderOutput: 39 | b, v, _, _ = extrinsics.shape 40 | color = render_cuda( 41 | rearrange(extrinsics, "b v i j -> (b v) i j"), 42 | rearrange(intrinsics, "b v i j -> (b v) i j"), 43 | rearrange(near, "b v -> (b v)"), 44 | rearrange(far, "b v -> (b v)"), 45 | image_shape, 46 | repeat(self.background_color.to(far.device), "c -> (b v) c", b=b, v=v), 47 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 48 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 49 | repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), 50 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 51 | ) 52 | color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v) 53 | 54 | return DecoderOutput( 55 | color, 56 | None 57 | if depth_mode is None 58 | else self.render_depth( 59 | gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode 60 | ), 61 | ) 62 | 63 | def render_depth( 64 | self, 65 | gaussians: Gaussians, 66 | extrinsics: Float[Tensor, "batch view 4 4"], 67 | intrinsics: Float[Tensor, "batch view 3 3"], 68 | near: Float[Tensor, "batch view"], 69 | far: Float[Tensor, "batch view"], 70 | image_shape: tuple[int, int], 71 | mode: DepthRenderingMode = "depth", 72 | ) -> Float[Tensor, "batch view height width"]: 73 | b, v, 
_, _ = extrinsics.shape 74 | result = render_depth_cuda( 75 | rearrange(extrinsics, "b v i j -> (b v) i j"), 76 | rearrange(intrinsics, "b v i j -> (b v) i j"), 77 | rearrange(near, "b v -> (b v)"), 78 | rearrange(far, "b v -> (b v)"), 79 | image_shape, 80 | repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), 81 | repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), 82 | repeat(gaussians.opacities, "b g -> (b v) g", v=v), 83 | mode=mode, 84 | ) 85 | return rearrange(result, "(b v) h w -> b v h w", b=b, v=v) 86 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .encoder import Encoder 4 | from .encoder_epipolar import EncoderEpipolar, EncoderEpipolarCfg 5 | from .visualization.encoder_visualizer import EncoderVisualizer 6 | from .visualization.encoder_visualizer_epipolar import EncoderVisualizerEpipolar 7 | 8 | ENCODERS = { 9 | "epipolar": (EncoderEpipolar, EncoderVisualizerEpipolar), 10 | } 11 | 12 | EncoderCfg = EncoderEpipolarCfg 13 | 14 | 15 | def get_encoder(cfg: EncoderCfg) -> tuple[Encoder, Optional[EncoderVisualizer]]: 16 | encoder, visualizer = ENCODERS[cfg.name] 17 | encoder = encoder(cfg) 18 | if visualizer is not None: 19 | visualizer = visualizer(cfg.visualizer, encoder) 20 | return encoder, visualizer 21 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .backbone import Backbone 4 | from .backbone_dino import BackboneDino, BackboneDinoCfg 5 | from .backbone_resnet import BackboneResnet, BackboneResnetCfg 6 | 7 | BACKBONES: dict[str, Backbone[Any]] = { 8 | "resnet": BackboneResnet, 9 | "dino": BackboneDino, 10 | } 11 | 12 | BackboneCfg = BackboneResnetCfg | BackboneDinoCfg 13 | 14 | 15 | def get_backbone(cfg: BackboneCfg, d_in: int) -> Backbone[Any]: 16 | return BACKBONES[cfg.name](cfg, d_in) 17 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor, nn 6 | 7 | from ggrt.dataset.types import BatchedViews 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Backbone(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | ) -> Float[Tensor, "batch view d_out height width"]: 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def d_out(self) -> int: 29 | pass 30 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/backbone/backbone_dino.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | 4 | import torch 5 | from einops import rearrange, repeat 6 | from jaxtyping import Float 7 | from torch import Tensor, nn 8 | 9 | from ggrt.dataset.types import BatchedViews 10 | from .backbone import Backbone 11 | from .backbone_resnet import BackboneResnet, 
BackboneResnetCfg 12 | 13 | 14 | @dataclass 15 | class BackboneDinoCfg: 16 | name: Literal["dino"] 17 | model: Literal["dino_vits16", "dino_vits8", "dino_vitb16", "dino_vitb8"] 18 | d_out: int 19 | 20 | 21 | class BackboneDino(Backbone[BackboneDinoCfg]): 22 | def __init__(self, cfg: BackboneDinoCfg, d_in: int) -> None: 23 | super().__init__(cfg) 24 | assert d_in == 3 25 | self.resnet_backbone = BackboneResnet( 26 | BackboneResnetCfg("resnet", "dino_resnet50", 4, False, cfg.d_out), 27 | d_in, 28 | ) 29 | 30 | def forward( 31 | self, 32 | context: BatchedViews, 33 | ) -> Float[Tensor, "batch view d_out height width"]: 34 | # Compute features from the DINO-pretrained resnet50. 35 | resnet_features = self.resnet_backbone(context) 36 | 37 | return resnet_features.to(torch.float) 38 | # return resnet_features + local_tokens + global_token 39 | 40 | @property 41 | def patch_size(self) -> int: 42 | return int("".join(filter(str.isdigit, self.cfg.model))) 43 | 44 | @property 45 | def d_out(self) -> int: 46 | return self.cfg.d_out 47 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/backbone/backbone_resnet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from typing import Literal 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from einops import rearrange 9 | from jaxtyping import Float 10 | from torch import Tensor, nn 11 | from torchvision.models import ResNet 12 | 13 | from ggrt.dataset.types import BatchedViews 14 | from .backbone import Backbone 15 | 16 | 17 | @dataclass 18 | class BackboneResnetCfg: 19 | name: Literal["resnet"] 20 | model: Literal[ 21 | "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "dino_resnet50" 22 | ] 23 | num_layers: int 24 | use_first_pool: bool 25 | d_out: int 26 | 27 | 28 | class BackboneResnet(Backbone[BackboneResnetCfg]): 29 | model: ResNet 30 | 31 | def __init__(self, cfg: BackboneResnetCfg, d_in: int) -> None: 32 | super().__init__(cfg) 33 | 34 | assert d_in == 3 35 | 36 | norm_layer = functools.partial( 37 | nn.InstanceNorm2d, 38 | affine=False, 39 | track_running_stats=False, 40 | ) 41 | 42 | if cfg.model == "dino_resnet50": 43 | self.model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50") 44 | else: 45 | self.model = getattr(torchvision.models, cfg.model)(norm_layer=norm_layer) 46 | 47 | # Set up projections 48 | self.projections = nn.ModuleDict({}) 49 | for index in range(1, cfg.num_layers): 50 | key = f"layer{index}" 51 | block = getattr(self.model, key) 52 | conv_index = 1 53 | try: 54 | while True: 55 | d_layer_out = getattr(block[-1], f"conv{conv_index}").out_channels 56 | conv_index += 1 57 | except AttributeError: 58 | pass 59 | self.projections[key] = nn.Conv2d(d_layer_out, cfg.d_out, 1) 60 | 61 | # Add a projection for the first layer. 62 | self.projections["layer0"] = nn.Conv2d( 63 | self.model.conv1.out_channels, cfg.d_out, 1 64 | ) 65 | 66 | def forward( 67 | self, 68 | context: BatchedViews, 69 | ) -> Float[Tensor, "batch view d_out height width"]: 70 | # Merge the batch dimensions. 71 | b, v, _, h, w = context["image"].shape 72 | x = rearrange(context["image"], "b v c h w -> (b v) c h w") 73 | 74 | # Run the images through the resnet. 
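# Stem: conv1 -> bn1 -> relu, then the stem output is projected to d_out channels as the "layer0" feature map.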
75 | x = self.model.conv1(x) 76 | x = self.model.bn1(x) 77 | x = self.model.relu(x) 78 | features = [self.projections["layer0"](x)] 79 | 80 | # Propagate the input through the resnet's layers. 81 | for index in range(1, self.cfg.num_layers): 82 | key = f"layer{index}" 83 | if index == 0 and self.cfg.use_first_pool: 84 | x = self.model.maxpool(x) 85 | x = getattr(self.model, key)(x) 86 | features.append(self.projections[key](x)) 87 | 88 | # Upscale the features. 89 | features = [ 90 | F.interpolate(f.to(dtype=torch.float32), (h, w), mode="bilinear", align_corners=True).to(dtype=torch.bfloat16) 91 | for f in features 92 | ] 93 | features = torch.stack(features).sum(dim=0) 94 | 95 | # Separate batch dimensions. 96 | return rearrange(features, "(b v) c h w -> b v c h w", b=b, v=v) 97 | 98 | @property 99 | def d_out(self) -> int: 100 | return self.cfg.d_out 101 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/common/depth_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float, Int64 4 | from torch import Tensor, nn 5 | 6 | from ..epipolar.conversions import relative_disparity_to_depth 7 | from ..epipolar.distribution_sampler import DistributionSampler 8 | 9 | 10 | class DepthPredictor(nn.Module): 11 | sampler: DistributionSampler 12 | 13 | def __init__( 14 | self, 15 | use_transmittance: bool, 16 | ) -> None: 17 | super().__init__() 18 | self.sampler = DistributionSampler() 19 | self.to_pdf = nn.Softmax(dim=-1) 20 | self.to_offset = nn.Sigmoid() 21 | self.use_transmittance = use_transmittance 22 | 23 | def forward( 24 | self, 25 | features: Float[Tensor, "batch view ray surface depth 2"], 26 | near: Float[Tensor, "batch view"], 27 | far: Float[Tensor, "batch view"], 28 | deterministic: bool, 29 | gaussians_per_pixel: int, 30 | ) -> tuple[ 31 | Float[Tensor, "batch view ray surface sample"], # depth 32 | Float[Tensor, "batch view ray surface sample"], # opacity 33 | Int64[Tensor, "batch view ray surface sample"], # index 34 | ]: 35 | # Convert the features into a depth distribution plus intra-bucket offsets. 36 | pdf_raw, offset_raw = features.unbind(dim=-1) 37 | pdf = self.to_pdf(pdf_raw) 38 | offset = self.to_offset(offset_raw) 39 | 40 | # Sample from the depth distribution. 41 | index, pdf_i = self.sampler.sample(pdf, deterministic, gaussians_per_pixel) 42 | offset = self.sampler.gather(index, offset) 43 | 44 | # Convert the sampled bucket and offset to a depth. 45 | *_, num_depths, _ = features.shape 46 | relative_disparity = (index + offset) / num_depths 47 | depth = relative_disparity_to_depth( 48 | relative_disparity, 49 | rearrange(near, "b v -> b v () () ()"), 50 | rearrange(far, "b v -> b v () () ()"), 51 | ) 52 | 53 | # Compute opacity from PDF. 
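# With use_transmittance, each bucket's probability is divided by the transmittance remaining
# after earlier buckets (1 minus their cumulative probability), mirroring alpha-compositing
# weights; otherwise the sampled PDF value is used directly as the opacity.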
54 | if self.use_transmittance: 55 | partial = pdf.cumsum(dim=-1) 56 | partial = torch.cat( 57 | (torch.zeros_like(partial[..., :1]), partial[..., :-1]), dim=-1 58 | ) 59 | opacity = pdf / (1 - partial + 1e-10) 60 | opacity = self.sampler.gather(index, opacity) 61 | else: 62 | opacity = pdf_i 63 | 64 | return depth, opacity, index 65 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/common/gaussian_adapter.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from einops import einsum, rearrange 5 | from jaxtyping import Float 6 | from torch import Tensor, nn 7 | 8 | from ggrt.geometry.projection import get_world_rays 9 | from ggrt.misc.sh_rotation import rotate_sh 10 | from .gaussians import build_covariance 11 | 12 | 13 | @dataclass 14 | class Gaussians: 15 | means: Float[Tensor, "*batch 3"] 16 | covariances: Float[Tensor, "*batch 3 3"] 17 | scales: Float[Tensor, "*batch 3"] 18 | rotations: Float[Tensor, "*batch 4"] 19 | harmonics: Float[Tensor, "*batch 3 _"] 20 | opacities: Float[Tensor, " *batch"] 21 | 22 | 23 | @dataclass 24 | class GaussianAdapterCfg: 25 | gaussian_scale_min: float 26 | gaussian_scale_max: float 27 | sh_degree: int 28 | 29 | 30 | class GaussianAdapter(nn.Module): 31 | cfg: GaussianAdapterCfg 32 | 33 | def __init__(self, cfg: GaussianAdapterCfg): 34 | super().__init__() 35 | self.cfg = cfg 36 | 37 | # Create a mask for the spherical harmonics coefficients. This ensures that at 38 | # initialization, the coefficients are biased towards having a large DC 39 | # component and small view-dependent components. 40 | self.register_buffer( 41 | "sh_mask", 42 | torch.ones((self.d_sh,), dtype=torch.float32), 43 | persistent=False, 44 | ) 45 | for degree in range(1, self.cfg.sh_degree + 1): 46 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree 47 | 48 | def forward( 49 | self, 50 | extrinsics: Float[Tensor, "*#batch 4 4"], 51 | intrinsics: Float[Tensor, "*#batch 3 3"], 52 | coordinates: Float[Tensor, "*#batch 2"], 53 | depths: Float[Tensor, "*#batch"], 54 | opacities: Float[Tensor, "*#batch"], 55 | raw_gaussians: Float[Tensor, "*#batch _"], 56 | image_shape: tuple[int, int], 57 | eps: float = 1e-8, 58 | ) -> Gaussians: 59 | device = extrinsics.device 60 | scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1) 61 | 62 | # Map scale features to valid scale range. 63 | scale_min = self.cfg.gaussian_scale_min 64 | scale_max = self.cfg.gaussian_scale_max 65 | scales = scale_min + (scale_max - scale_min) * scales.sigmoid() 66 | h, w = image_shape 67 | pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=device) 68 | multiplier = self.get_scale_multiplier(intrinsics, pixel_size) 69 | scales = scales * depths[..., None] * multiplier[..., None] 70 | 71 | # Normalize the quaternion features to yield a valid quaternion. 72 | rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) 73 | 74 | # Apply sigmoid to get valid colors. 75 | sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) 76 | sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask 77 | 78 | # Create world-space covariance matrices. 79 | covariances = build_covariance(scales, rotations) 80 | c2w_rotations = extrinsics[..., :3, :3] 81 | covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2) 82 | 83 | # Compute Gaussian means. 
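# Each Gaussian is placed by unprojecting its pixel coordinate along the corresponding
# world-space camera ray to the predicted depth: mean = origin + depth * direction.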
84 | origins, directions = get_world_rays(coordinates, extrinsics, intrinsics) 85 | means = origins + directions * depths[..., None] 86 | 87 | return Gaussians( 88 | means=means, 89 | covariances=covariances, 90 | harmonics=rotate_sh(sh, c2w_rotations[..., None, :, :]), 91 | opacities=opacities, 92 | # Note: These aren't yet rotated into world space, but they're only used for 93 | # exporting Gaussians to ply files. This needs to be fixed... 94 | scales=scales, 95 | rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), 96 | ) 97 | 98 | def get_scale_multiplier( 99 | self, 100 | intrinsics: Float[Tensor, "*#batch 3 3"], 101 | pixel_size: Float[Tensor, "*#batch 2"], 102 | multiplier: float = 0.1, 103 | ) -> Float[Tensor, " *batch"]: 104 | xy_multipliers = multiplier * einsum( 105 | intrinsics[..., :2, :2].inverse(), 106 | pixel_size, 107 | "... i j, j -> ... i", 108 | ) 109 | return xy_multipliers.sum(dim=-1) 110 | 111 | @property 112 | def d_sh(self) -> int: 113 | return (self.cfg.sh_degree + 1) ** 2 114 | 115 | @property 116 | def d_in(self) -> int: 117 | return 7 + 3 * self.d_sh 118 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/common/gaussians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 8 | def quaternion_to_matrix( 9 | quaternions: Float[Tensor, "*batch 4"], 10 | eps: float = 1e-8, 11 | ) -> Float[Tensor, "*batch 3 3"]: 12 | # Order changed to match scipy format! 13 | i, j, k, r = torch.unbind(quaternions, dim=-1) 14 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 15 | 16 | o = torch.stack( 17 | ( 18 | 1 - two_s * (j * j + k * k), 19 | two_s * (i * j - k * r), 20 | two_s * (i * k + j * r), 21 | two_s * (i * j + k * r), 22 | 1 - two_s * (i * i + k * k), 23 | two_s * (j * k - i * r), 24 | two_s * (i * k - j * r), 25 | two_s * (j * k + i * r), 26 | 1 - two_s * (i * i + j * j), 27 | ), 28 | -1, 29 | ) 30 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 31 | 32 | 33 | def build_covariance( 34 | scale: Float[Tensor, "*#batch 3"], 35 | rotation_xyzw: Float[Tensor, "*#batch 4"], 36 | ) -> Float[Tensor, "*batch 3 3"]: 37 | scale = scale.diag_embed() 38 | rotation = quaternion_to_matrix(rotation_xyzw) 39 | return ( 40 | rotation 41 | @ scale 42 | @ rearrange(scale, "... i j -> ... j i") 43 | @ rearrange(rotation, "... i j -> ... 
j i") 44 | ) 45 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/common/sampler.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float, Int64, Shaped 2 | from torch import Tensor, nn 3 | 4 | from ggrt.misc.discrete_probability_distribution import ( 5 | gather_discrete_topk, 6 | sample_discrete_distribution, 7 | ) 8 | 9 | 10 | class Sampler(nn.Module): 11 | def forward( 12 | self, 13 | probabilities: Float[Tensor, "*batch bucket"], 14 | num_samples: int, 15 | deterministic: bool, 16 | ) -> tuple[ 17 | Int64[Tensor, "*batch 1"], # index 18 | Float[Tensor, "*batch 1"], # probability density 19 | ]: 20 | return ( 21 | gather_discrete_topk(probabilities, num_samples) 22 | if deterministic 23 | else sample_discrete_distribution(probabilities, num_samples) 24 | ) 25 | 26 | def gather( 27 | self, 28 | index: Int64[Tensor, "*batch sample"], 29 | target: Shaped[Tensor, "..."], # *batch bucket *shape 30 | ) -> Shaped[Tensor, "..."]: # *batch sample *shape 31 | """Gather from the target according to the specified index. Handle the 32 | broadcasting needed for the gather to work. See the comments for the actual 33 | expected input/output shapes since jaxtyping doesn't support multiple variadic 34 | lengths in annotations. 35 | """ 36 | bucket_dim = index.ndim - 1 37 | while len(index.shape) < len(target.shape): 38 | index = index[..., None] 39 | broadcasted_index_shape = list(target.shape) 40 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 41 | index = index.broadcast_to(broadcasted_index_shape) 42 | return target.gather(dim=bucket_dim, index=index) 43 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from torch import Tensor, nn 5 | 6 | from ....dataset.types import BatchedViews, DataShim 7 | from ..types import Gaussians 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class Encoder(nn.Module, ABC, Generic[T]): 13 | cfg: T 14 | 15 | def __init__(self, cfg: T) -> None: 16 | super().__init__() 17 | self.cfg = cfg 18 | 19 | @abstractmethod 20 | def forward( 21 | self, 22 | context: BatchedViews, 23 | features: Tensor, 24 | clip_h: int, 25 | clip_w: int, 26 | deterministic: bool, 27 | just_return_future: bool = False, 28 | crop_size=None, 29 | ) -> Gaussians: 30 | pass 31 | 32 | def get_data_shim(self) -> DataShim: 33 | """The default shim doesn't modify the batch.""" 34 | return lambda x: x 35 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/epipolar/conversions.py: -------------------------------------------------------------------------------- 1 | from jaxtyping import Float 2 | from torch import Tensor 3 | 4 | 5 | def relative_disparity_to_depth( 6 | relative_disparity: Float[Tensor, "*#batch"], 7 | near: Float[Tensor, "*#batch"], 8 | far: Float[Tensor, "*#batch"], 9 | eps: float = 1e-10, 10 | ) -> Float[Tensor, " *batch"]: 11 | """Convert relative disparity, where 0 is near and 1 is far, to depth.""" 12 | disp_near = 1 / (near + eps) 13 | disp_far = 1 / (far + eps) 14 | return 1 / ((1 - relative_disparity) * (disp_near - disp_far) + disp_far + eps) 15 | 16 | 17 | def depth_to_relative_disparity( 18 | depth: Float[Tensor, "*#batch"], 19 | near: Float[Tensor,
"*#batch"], 20 | far: Float[Tensor, "*#batch"], 21 | eps: float = 1e-10, 22 | ) -> Float[Tensor, " *batch"]: 23 | """Convert depth to relative disparity, where 0 is near and 1 is far""" 24 | disp_near = 1 / (near + eps) 25 | disp_far = 1 / (far + eps) 26 | disp = 1 / (depth + eps) 27 | return 1 - (disp - disp_far) / (disp_near - disp_far + eps) 28 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/epipolar/depth_predictor_monocular.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor, nn 5 | 6 | from .conversions import relative_disparity_to_depth 7 | from .distribution_sampler import DistributionSampler 8 | 9 | 10 | class DepthPredictorMonocular(nn.Module): 11 | projection: nn.Sequential 12 | sampler: DistributionSampler 13 | num_samples: int 14 | num_surfaces: int 15 | 16 | def __init__( 17 | self, 18 | d_in: int, 19 | num_samples: int, 20 | num_surfaces: int, 21 | use_transmittance: bool, 22 | ) -> None: 23 | super().__init__() 24 | self.projection = nn.Sequential( 25 | nn.ReLU(), 26 | nn.Linear(d_in, 2 * num_samples * num_surfaces), 27 | ) 28 | self.sampler = DistributionSampler() 29 | self.num_samples = num_samples 30 | self.num_surfaces = num_surfaces 31 | self.use_transmittance = use_transmittance 32 | 33 | # This exists for hooks to latch onto. 34 | self.to_pdf = nn.Softmax(dim=-1) 35 | self.to_offset = nn.Sigmoid() 36 | 37 | def forward( 38 | self, 39 | features: Float[Tensor, "batch view ray channel"], 40 | near: Float[Tensor, "batch view"], 41 | far: Float[Tensor, "batch view"], 42 | deterministic: bool, 43 | gaussians_per_pixel: int, 44 | ) -> tuple[ 45 | Float[Tensor, "batch view ray surface sample"], # depth 46 | Float[Tensor, "batch view ray surface sample"], # pdf 47 | ]: 48 | s = self.num_samples 49 | 50 | # Convert the features into a depth distribution plus intra-bucket offsets. 51 | features = self.projection(features) 52 | pdf_raw, offset_raw = rearrange( 53 | features, "... (dpt srf c) -> c ... srf dpt", c=2, srf=self.num_surfaces 54 | ) 55 | pdf = self.to_pdf(pdf_raw) 56 | offset = self.to_offset(offset_raw) 57 | 58 | # Sample from the depth distribution. 59 | index, pdf_i = self.sampler.sample(pdf, deterministic, gaussians_per_pixel) 60 | offset = self.sampler.gather(index, offset) 61 | 62 | # Convert the sampled bucket and offset to a depth. 63 | relative_disparity = (index + offset) / s 64 | depth = relative_disparity_to_depth( 65 | relative_disparity, 66 | rearrange(near, "b v -> b v () () ()"), 67 | rearrange(far, "b v -> b v () () ()"), 68 | ) 69 | 70 | # Compute opacity from PDF. 
71 | if self.use_transmittance: 72 | partial = pdf.cumsum(dim=-1) 73 | partial = torch.cat( 74 | (torch.zeros_like(partial[..., :1]), partial[..., :-1]), dim=-1 75 | ) 76 | opacity = pdf / (1 - partial + 1e-10) 77 | opacity = self.sampler.gather(index, opacity) 78 | else: 79 | opacity = pdf_i 80 | 81 | return depth, opacity 82 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/epipolar/distribution.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import einsum 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor, nn 7 | 8 | 9 | class Distribution(nn.Module): 10 | scale: float 11 | to_q: nn.Linear 12 | to_k: nn.Linear 13 | 14 | def __init__( 15 | self, 16 | dim_q: int, 17 | dim_k: int, 18 | dim_inner: int = 64, 19 | ): 20 | super().__init__() 21 | self.scale = dim_inner**-0.5 22 | self.to_q = nn.Linear(dim_q, dim_inner, bias=False) 23 | self.to_k = nn.Linear(dim_k, dim_inner, bias=False) 24 | 25 | def forward( 26 | self, 27 | queries: Float[Tensor, "batch token_query dim_query"], 28 | keys: Float[Tensor, "batch token_key dim_key"], 29 | force_last_token: Optional[Bool[Tensor, " batch"]] = None, 30 | ) -> Float[Tensor, "batch token_query token_key"]: 31 | # Compute softmax attention. 32 | q = self.to_q(queries) 33 | k = self.to_k(keys) 34 | weights = einsum(q, k, "b q d, b k d -> b q k").softmax(dim=-1) 35 | 36 | if force_last_token is None: 37 | return weights 38 | 39 | # Where applicable, force the last token to be selected. 40 | last_token = torch.zeros( 41 | keys.shape[1], device=queries.device, dtype=queries.dtype 42 | ) 43 | last_token[-1] = 1 44 | mask = force_last_token[:, None, None] 45 | return last_token * mask + weights * ~mask 46 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/epipolar/distribution_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Int64, Shaped 3 | from torch import Tensor 4 | 5 | from ggrt.misc.discrete_probability_distribution import ( 6 | gather_discrete_topk, 7 | sample_discrete_distribution, 8 | ) 9 | 10 | 11 | class DistributionSampler: 12 | def sample( 13 | self, 14 | pdf: Float[Tensor, "*batch bucket"], 15 | deterministic: bool, 16 | num_samples: int, 17 | ) -> tuple[ 18 | Int64[Tensor, "*batch sample"], # index 19 | Float[Tensor, "*batch sample"], # probability density 20 | ]: 21 | """Sample from the given probability distribution. Return sampled indices and 22 | their corresponding probability densities. 23 | """ 24 | if deterministic: 25 | index, densities = gather_discrete_topk(pdf, num_samples) 26 | else: 27 | index, densities = sample_discrete_distribution(pdf, num_samples) 28 | return index, densities 29 | 30 | def gather( 31 | self, 32 | index: Int64[Tensor, "*batch sample"], 33 | target: Shaped[Tensor, "..."], # *batch bucket *shape 34 | ) -> Shaped[Tensor, "..."]: # *batch *shape 35 | """Gather from the target according to the specified index. Handle the 36 | broadcasting needed for the gather to work. See the comments for the actual 37 | expected input/output shapes since jaxtyping doesn't support multiple variadic 38 | lengths in annotations. 
39 | """ 40 | bucket_dim = index.ndim - 1 41 | while len(index.shape) < len(target.shape): 42 | index = index[..., None] 43 | broadcasted_index_shape = list(target.shape) 44 | broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] 45 | index = index.broadcast_to(broadcasted_index_shape) 46 | 47 | # Add the ability to broadcast. 48 | if target.shape[bucket_dim] == 1: 49 | index = torch.zeros_like(index) 50 | 51 | return target.gather(dim=bucket_dim, index=index) 52 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/epipolar/image_self_attention.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from torch import Tensor, nn 6 | 7 | from ggrt.geometry.projection import sample_image_grid 8 | from ...encodings.positional_encoding import PositionalEncoding 9 | from ...transformer.transformer import Transformer 10 | 11 | 12 | @dataclass 13 | class ImageSelfAttentionCfg: 14 | patch_size: int 15 | num_octaves: int 16 | num_layers: int 17 | num_heads: int 18 | d_token: int 19 | d_dot: int 20 | d_mlp: int 21 | 22 | 23 | class ImageSelfAttention(nn.Module): 24 | positional_encoding: nn.Sequential 25 | patch_embedder: nn.Sequential 26 | transformer: Transformer 27 | 28 | def __init__( 29 | self, 30 | cfg: ImageSelfAttentionCfg, 31 | d_in: int, 32 | d_out: int, 33 | ): 34 | super().__init__() 35 | self.positional_encoding = nn.Sequential( 36 | (pe := PositionalEncoding(cfg.num_octaves)), 37 | nn.Linear(pe.d_out(2), cfg.d_token), 38 | ) 39 | self.patch_embedder = nn.Sequential( 40 | nn.Conv2d(d_in, cfg.d_token, cfg.patch_size, cfg.patch_size), 41 | nn.ReLU(), 42 | ) 43 | self.transformer = Transformer( 44 | cfg.d_token, 45 | cfg.num_layers, 46 | cfg.num_heads, 47 | cfg.d_dot, 48 | cfg.d_mlp, 49 | ) 50 | self.resampler = nn.ConvTranspose2d( 51 | cfg.d_token, 52 | d_out, 53 | cfg.patch_size, 54 | cfg.patch_size, 55 | ) 56 | self.index = 0 57 | def forward( 58 | self, 59 | image: Float[Tensor, "batch d_in height width"], 60 | ) -> Float[Tensor, "batch d_out height width"]: 61 | # Embed patches so they become tokens. 62 | tokens = self.patch_embedder.forward(image) 63 | 64 | # Append positional information to the tokens. 65 | _, _, nh, nw = tokens.shape 66 | crop_size = 2 67 | # if nh < 20: 68 | # index = self.index // 3 # Determine which crop this is. 69 | # self.index = self.index + 1 70 | # i = index // crop_size # Row index of the crop. 71 | # j = index % crop_size # Column index of the crop. 72 | # xy, _ = sample_image_grid((nh*crop_size, nw*crop_size), device=image.device) 73 | # xy = self.positional_encoding.forward(xy)[i*nh:(i+1)*nh, j*nw:(j+1)*nw, :] 74 | 75 | # else: # No-grad full-image path: reset the index to 0. 76 | self.index = 0 77 | xy, _ = sample_image_grid((nh, nw), device=image.device) 78 | xy = self.positional_encoding.forward(xy) 79 | 80 | # Put the tokens through a transformer. 81 | _, _, nh, nw = tokens.shape 82 | # if nh >= 20: 83 | # for i in range(crop_size): 84 | # for j in range(crop_size): 85 | # tokens_1 = tokens[:, :, i*nh//crop_size:(i+1)*nh//crop_size, j*nw//crop_size:(j+1)*nw//crop_size] 86 | # tokens_1 = rearrange(tokens_1, "b c nh nw -> b (nh nw) c") 87 | # tokens_1 = self.transformer.forward(tokens_1) 88 | # tokens_1 = rearrange(tokens_1, "b (nh nw) c -> b c nh nw", nh=nh//crop_size, nw=nw//crop_size) 89 | # tokens[:, :, i*nh//crop_size:(i+1)*nh//crop_size, j*nw//crop_size:(j+1)*nw//crop_size] = tokens_1 90 | # Resample the tokens back to the original resolution.
91 | # tokens = rearrange(tokens, "b (nh nw) c -> b c nh nw", nh=nh, nw=nw) 92 | # else: 93 | tokens = rearrange(tokens, "b c nh nw -> b (nh nw) c") 94 | tokens = self.transformer.forward(tokens) 95 | tokens = rearrange(tokens, "b (nh nw) c -> b c nh nw", nh=nh, nw=nw) 96 | 97 | tokens = self.resampler.forward(tokens) 98 | 99 | return tokens 100 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/visualization/encoder_visualizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Generic, TypeVar 3 | 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | T_cfg = TypeVar("T_cfg") 8 | T_encoder = TypeVar("T_encoder") 9 | 10 | 11 | class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]): 12 | cfg: T_cfg 13 | encoder: T_encoder 14 | 15 | def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None: 16 | self.cfg = cfg 17 | self.encoder = encoder 18 | 19 | @abstractmethod 20 | def visualize( 21 | self, 22 | context: dict, 23 | global_step: int, 24 | ) -> dict[str, Float[Tensor, "3 _ _"]]: 25 | pass 26 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encoder/visualization/encoder_visualizer_epipolar_cfg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | # This is in a separate file to avoid circular imports. 4 | 5 | 6 | @dataclass 7 | class EncoderVisualizerEpipolarCfg: 8 | num_samples: int 9 | min_resolution: int 10 | export_ply: bool 11 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/encodings/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import einsum, rearrange, repeat 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """For the sake of simplicity, this encodes values in the range [0, 1].""" 10 | 11 | frequencies: Float[Tensor, "frequency phase"] 12 | phases: Float[Tensor, "frequency phase"] 13 | 14 | def __init__(self, num_octaves: int): 15 | super().__init__() 16 | octaves = torch.arange(num_octaves).float() 17 | 18 | # The lowest frequency has a period of 1. 19 | frequencies = 2 * torch.pi * 2**octaves 20 | frequencies = repeat(frequencies, "f -> f p", p=2) 21 | self.register_buffer("frequencies", frequencies, persistent=False) 22 | 23 | # Choose the phases to match sine and cosine. 24 | phases = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32) 25 | phases = repeat(phases, "p -> f p", f=num_octaves) 26 | self.register_buffer("phases", phases, persistent=False) 27 | 28 | def forward( 29 | self, 30 | samples: Float[Tensor, "*batch dim"], 31 | ) -> Float[Tensor, "*batch embedded_dim"]: 32 | samples = einsum(samples, self.frequencies, "... d, f p -> ... d f p") 33 | return rearrange(torch.sin(samples + self.phases), "... d f p -> ... 
(d f p)") 34 | 35 | def d_out(self, dimensionality: int): 36 | return self.frequencies.numel() * dimensionality 37 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/ply_export.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from einops import einsum, rearrange 6 | from jaxtyping import Float 7 | from plyfile import PlyData, PlyElement 8 | from scipy.spatial.transform import Rotation as R 9 | from torch import Tensor 10 | 11 | 12 | def construct_list_of_attributes(num_rest: int) -> list[str]: 13 | attributes = ["x", "y", "z", "nx", "ny", "nz"] 14 | for i in range(3): 15 | attributes.append(f"f_dc_{i}") 16 | for i in range(num_rest): 17 | attributes.append(f"f_rest_{i}") 18 | attributes.append("opacity") 19 | for i in range(3): 20 | attributes.append(f"scale_{i}") 21 | for i in range(4): 22 | attributes.append(f"rot_{i}") 23 | return attributes 24 | 25 | 26 | def export_ply( 27 | extrinsics: Float[Tensor, "4 4"], 28 | means: Float[Tensor, "gaussian 3"], 29 | scales: Float[Tensor, "gaussian 3"], 30 | rotations: Float[Tensor, "gaussian 4"], 31 | harmonics: Float[Tensor, "gaussian 3 d_sh"], 32 | opacities: Float[Tensor, " gaussian"], 33 | path: Path, 34 | ): 35 | # Shift the scene so that the median Gaussian is at the origin. 36 | means = means - means.median(dim=0).values 37 | 38 | # Rescale the scene so that most Gaussians are within range [-1, 1]. 39 | scale_factor = means.abs().quantile(0.95, dim=0).max() 40 | means = means / scale_factor 41 | scales = scales / scale_factor 42 | 43 | # Define a rotation that makes +Z be the world up vector. 44 | rotation = [ 45 | [0, 0, 1], 46 | [-1, 0, 0], 47 | [0, -1, 0], 48 | ] 49 | rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device) 50 | 51 | # The Polycam viewer seems to start at a 45 degree angle. Since we want to be 52 | # looking directly at the object, we compose a 45 degree rotation onto the above 53 | # rotation. 54 | adjustment = torch.tensor( 55 | R.from_rotvec([0, 0, -45], True).as_matrix(), 56 | dtype=torch.float32, 57 | device=means.device, 58 | ) 59 | rotation = adjustment @ rotation 60 | 61 | # We also want to see the scene in camera space (as the default view). We therefore 62 | # compose the w2c rotation onto the above rotation. 63 | rotation = rotation @ extrinsics[:3, :3].inverse() 64 | 65 | # Apply the rotation to the means (Gaussian positions). 66 | means = einsum(rotation, means, "i j, ... j -> ... i") 67 | 68 | # Apply the rotation to the Gaussian rotations. 69 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 70 | rotations = rotation.detach().cpu().numpy() @ rotations 71 | rotations = R.from_matrix(rotations).as_quat() 72 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 73 | rotations = np.stack((w, x, y, z), axis=-1) 74 | 75 | # Since our axes are swizzled for the spherical harmonics, we only export the DC 76 | # band. 
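# The degree-0 (DC) spherical-harmonic coefficient at index 0 is view-independent and
# rotation-invariant, so it is unaffected by the axis swizzle; all higher-order bands
# are simply dropped from the export (num_rest = 0).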
77 | harmonics_view_invariant = harmonics[..., 0] 78 | 79 | dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)] 80 | elements = np.empty(means.shape[0], dtype=dtype_full) 81 | attributes = ( 82 | means.detach().cpu().numpy(), 83 | torch.zeros_like(means).detach().cpu().numpy(), 84 | harmonics_view_invariant.detach().cpu().contiguous().numpy(), 85 | opacities[..., None].detach().cpu().numpy(), 86 | scales.log().detach().cpu().numpy(), 87 | rotations, 88 | ) 89 | attributes = np.concatenate(attributes, axis=1) 90 | elements[:] = list(map(tuple, attributes)) 91 | path.parent.mkdir(exist_ok=True, parents=True) 92 | PlyData([PlyElement.describe(elements, "vertex")]).write(path) 93 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/transformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | 28 | class FeedForward(nn.Module): 29 | def __init__(self, dim, hidden_dim, dropout=0.0): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.GELU(), 34 | nn.Dropout(dropout), 35 | nn.Linear(hidden_dim, dim), 36 | nn.Dropout(dropout), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.net(x) 41 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/transformer/pre_norm.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | 28 | class PreNorm(nn.Module): 29 | def __init__(self, dim, fn): 30 | super().__init__() 31 | self.norm = nn.LayerNorm(dim) 32 | self.fn = fn 33 | 34 | def forward(self, x, **kwargs): 35 | return self.fn(self.norm(x), **kwargs) 36 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2022 Karl Stelzner 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | # This file comes from https://github.com/stelzner/srt 24 | 25 | from torch import nn 26 | 27 | from .attention import Attention 28 | from .feed_forward import FeedForward 29 | from .pre_norm import PreNorm 30 | 31 | 32 | class Transformer(nn.Module): 33 | def __init__( 34 | self, 35 | dim, 36 | depth, 37 | heads, 38 | dim_head, 39 | mlp_dim, 40 | dropout=0.0, 41 | selfatt=True, 42 | kv_dim=None, 43 | feed_forward_layer=FeedForward, 44 | ): 45 | super().__init__() 46 | self.layers = nn.ModuleList([]) 47 | for _ in range(depth): 48 | self.layers.append( 49 | nn.ModuleList( 50 | [ 51 | PreNorm( 52 | dim, 53 | Attention( 54 | dim, 55 | heads=heads, 56 | dim_head=dim_head, 57 | dropout=dropout, 58 | selfatt=selfatt, 59 | kv_dim=kv_dim, 60 | ), 61 | ), 62 | PreNorm(dim, feed_forward_layer(dim, mlp_dim, dropout=dropout)), 63 | ] 64 | ) 65 | ) 66 | 67 | def forward(self, x, z=None, **kwargs): 68 | a = 1 69 | for i, layer in enumerate(self.layers): 70 | x = layer[0](x, z=z) + x 71 | x = layer[1](x, **kwargs) + x 72 | return x 73 | -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | import torch 6 | 7 | @dataclass 8 | class Gaussians: 9 | means: Float[Tensor, "batch gaussian dim"] 10 | covariances: Float[Tensor, "batch gaussian dim dim"] 11 | harmonics: Float[Tensor, "batch gaussian 3 d_sh"] 12 | opacities: Float[Tensor, "batch gaussian"] 13 | 14 | def to(self, type=torch.bfloat16) -> "Gaussians": 15 | return Gaussians(means=self.means.to(type), covariances=self.covariances.to(type), harmonics=self.harmonics.to(type), opacities=self.opacities.to(type)) 16 | 17 | def detach(self) -> "Gaussians": 18 | return Gaussians(means=self.means.detach(), covariances=self.covariances.detach(), harmonics=self.harmonics.detach(), opacities=self.opacities.detach()) -------------------------------------------------------------------------------- /ggrt/model/pixelsplat/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | @torch.no_grad() 24 | def generate_wobble( 25 | extrinsics: Float[Tensor, "*#batch 4 4"], 26 | radius: Float[Tensor, "*#batch"], 27 | t: Float[Tensor, " time_step"], 28 | ) -> Float[Tensor, "*batch time_step 4 4"]: 29 | tf = generate_wobble_transformation(radius, t) 30 | return rearrange(extrinsics, "... i j -> ... 
() i j") @ tf 31 | -------------------------------------------------------------------------------- /ggrt/render_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | from collections import OrderedDict 18 | 19 | from ggrt.render_ray import render_rays 20 | 21 | 22 | def render_single_image(ray_sampler, 23 | ray_batch, 24 | model, 25 | projector, 26 | chunk_size, 27 | N_samples, 28 | inv_uniform=False, 29 | N_importance=0, 30 | det=False, 31 | white_bkgd=False, 32 | render_stride=1, 33 | feat_maps=None, 34 | inv_depth_prior=None, 35 | rel_poses=None): 36 | ''' 37 | :param ray_sampler: RaySamplingSingleImage for this view 38 | :param model: {'net_coarse': , 'net_fine': , ...} 39 | :param chunk_size: number of rays in a chunk 40 | :param N_samples: samples along each ray (for both coarse and fine model) 41 | :param inv_uniform: if True, uniformly sample inverse depth for coarse model 42 | :param N_importance: additional samples along each ray produced by importance sampling (for fine model) 43 | :return: {'outputs_coarse': {'rgb': numpy, 'depth': numpy, ...}, 'outputs_fine': {}} 44 | ''' 45 | 46 | all_ret = OrderedDict([('outputs_coarse', OrderedDict()), 47 | ('outputs_fine', OrderedDict())]) 48 | 49 | N_rays = ray_batch['ray_o'].shape[0] 50 | 51 | for i in range(0, N_rays, chunk_size): 52 | chunk = OrderedDict() 53 | for k in ray_batch: 54 | if k in ['camera', 'depth_range', 'src_rgbs', 'src_cameras']: 55 | chunk[k] = ray_batch[k] 56 | elif ray_batch[k] is not None: 57 | chunk[k] = ray_batch[k][i:i+chunk_size] 58 | else: 59 | chunk[k] = None 60 | 61 | inv_depth_prior_chunk = None 62 | if inv_depth_prior is not None: 63 | inv_depth_prior_chunk = inv_depth_prior[i:i+chunk_size] 64 | 65 | ret = render_rays(chunk, model, feat_maps, 66 | projector=projector, 67 | N_samples=N_samples, 68 | inv_uniform=inv_uniform, 69 | N_importance=N_importance, 70 | det=det, 71 | white_bkgd=white_bkgd, 72 | inv_depth_prior=inv_depth_prior_chunk, 73 | rel_poses=rel_poses) 74 | 75 | # handle both coarse and fine outputs 76 | # cache chunk results on cpu 77 | if i == 0: 78 | for k in ret['outputs_coarse']: 79 | all_ret['outputs_coarse'][k] = [] 80 | 81 | if ret['outputs_fine'] is None: 82 | all_ret['outputs_fine'] = None 83 | else: 84 | for k in ret['outputs_fine']: 85 | all_ret['outputs_fine'][k] = [] 86 | 87 | for k in ret['outputs_coarse']: 88 | all_ret['outputs_coarse'][k].append(ret['outputs_coarse'][k].cpu()) 89 | 90 | if ret['outputs_fine'] is not None: 91 | for k in ret['outputs_fine']: 92 | all_ret['outputs_fine'][k].append(ret['outputs_fine'][k].cpu()) 93 | 94 | rgb_strided = torch.ones(ray_sampler.H, ray_sampler.W, 3)[::render_stride, ::render_stride, :] 95 | # merge chunk results and reshape 96 | for k in all_ret['outputs_coarse']: 97 | if k == 'random_sigma': 98 | continue 99 | tmp = 
torch.cat(all_ret['outputs_coarse'][k], dim=0).reshape((rgb_strided.shape[0], 100 | rgb_strided.shape[1], -1)) 101 | all_ret['outputs_coarse'][k] = tmp.squeeze() 102 | 103 | all_ret['outputs_coarse']['rgb'][all_ret['outputs_coarse']['mask'] == 0] = 1. 104 | if all_ret['outputs_fine'] is not None: 105 | for k in all_ret['outputs_fine']: 106 | if k == 'random_sigma': 107 | continue 108 | tmp = torch.cat(all_ret['outputs_fine'][k], dim=0).reshape((rgb_strided.shape[0], 109 | rgb_strided.shape[1], -1)) 110 | 111 | all_ret['outputs_fine'][k] = tmp.squeeze() 112 | 113 | return all_ret 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /ggrt/utils/union_find.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class UnionFind(): 5 | def __init__(self, size: int, max_num_per_set=None) -> None: 6 | # The maximum size for each component. 7 | self.max_num_per_set = max_num_per_set 8 | 9 | # Tracking the rank for each node. 10 | self.ranks = [0 for i in range(size)] 11 | 12 | # Tracking the root node for each node. 13 | self.parents = [i for i in range(size)] 14 | 15 | self.nodes = [i for i in range(size)] 16 | 17 | # For nodes which indices are not in sequential, we map the id of each 18 | # node into a sequential index. 19 | self.node_mapper = {i:i for i in range(size)} 20 | 21 | # Tracking of the size of each component, such that we are able to 22 | # truncate too large components. 23 | self.component_size = {i:1 for i in range(size)} 24 | 25 | def init_with_nodes(self, nodes: list): 26 | self.node_mapper.clear() 27 | self.nodes = nodes 28 | for i, node_idx in enumerate(nodes): 29 | self.node_mapper[node_idx] = i 30 | 31 | def union(self, x, y): 32 | x = self.find_root(x) 33 | y = self.find_root(y) 34 | 35 | # If the nodes are already part of the same connected component then do nothing. 36 | if x == y: 37 | return 38 | 39 | # If merging the connected components will create a connected component larger 40 | # than the maximum size then do nothing. 
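# Union by rank with a size cap: the lower-rank root is attached under the higher-rank
# root, component sizes accumulate on the surviving root, and ranks grow only when two
# equal-rank trees merge. A minimal usage sketch (hypothetical values):
#   uf = UnionFind(4, max_num_per_set=2); uf.union(0, 1); uf.union(2, 3)
#   uf.union(1, 2)  # no-op: the merged component would exceed max_num_per_set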
41 | if (self.max_num_per_set != None) and \ 42 | (self.component_size[x] + self.component_size[y] > self.max_num_per_set): 43 | return 44 | 45 | if self.ranks[x] < self.ranks[y]: 46 | self.component_size[y] += self.component_size[x] 47 | self.parents[x] = y 48 | else: 49 | self.component_size[x] += self.component_size[y] 50 | self.parents[y] = x 51 | if self.ranks[x] == self.ranks[y]: 52 | self.ranks[x] += 1 53 | 54 | def find_root(self, x): 55 | idx = self.node_mapper[x] 56 | if self.parents[idx] == idx: 57 | return idx 58 | else: 59 | self.parents[idx] = self.find_root(self.nodes[self.parents[idx]]) 60 | return self.parents[idx] 61 | 62 | def get_connected_components(): 63 | return None 64 | 65 | def validate(self): 66 | union_set = {} 67 | for node_id in self.nodes: 68 | root_id = self.find_root(node_id) 69 | if root_id not in union_set.keys(): 70 | union_set[root_id] = [] 71 | union_set[root_id].append(node_id) 72 | 73 | for key in union_set.keys(): 74 | assert len(union_set[key]) <= self.max_num_per_set, \ 75 | f"elements number of set {key}: {len(union_set[key])}" 76 | -------------------------------------------------------------------------------- /ggrt/visualization/annotation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from jaxtyping import Float 8 | from PIL import Image, ImageDraw, ImageFont 9 | from torch import Tensor 10 | 11 | from .layout import vcat 12 | 13 | EXPECTED_CHARACTERS = digits + punctuation + ascii_letters 14 | 15 | 16 | def draw_label( 17 | text: str, 18 | font: Path, 19 | font_size: int, 20 | device: torch.device = torch.device("cpu"), 21 | ) -> Float[Tensor, "3 height width"]: 22 | """Draw a black label on a white background with no border.""" 23 | try: 24 | font = ImageFont.truetype(str(font), font_size) 25 | except OSError: 26 | font = ImageFont.load_default() 27 | left, _, right, _ = font.getbbox(text) 28 | width = right - left 29 | _, top, _, bottom = font.getbbox(EXPECTED_CHARACTERS) 30 | height = bottom - top 31 | image = Image.new("RGB", (width, height), color="white") 32 | draw = ImageDraw.Draw(image) 33 | draw.text((0, 0), text, font=font, fill="black") 34 | image = torch.tensor(np.array(image) / 255, dtype=torch.float32, device=device) 35 | return rearrange(image, "h w c -> c h w") 36 | 37 | 38 | def add_label( 39 | image: Float[Tensor, "3 width height"], 40 | label: str, 41 | font: Path = Path("assets/Inter-Regular.otf"), 42 | font_size: int = 24, 43 | ) -> Float[Tensor, "3 width_with_label height_with_label"]: 44 | return vcat( 45 | draw_label(label, font, font_size, image.device), 46 | image, 47 | align="left", 48 | gap=4, 49 | ) 50 | -------------------------------------------------------------------------------- /ggrt/visualization/camera_trajectory/spin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import repeat 4 | from jaxtyping import Float 5 | from scipy.spatial.transform import Rotation as R 6 | from torch import Tensor 7 | 8 | 9 | def generate_spin( 10 | num_frames: int, 11 | device: torch.device, 12 | elevation: float, 13 | radius: float, 14 | ) -> Float[Tensor, "frame 4 4"]: 15 | # Translate back along the camera's look vector. 
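# Each frame's pose composes tf_azimuth @ tf_elevation @ tf_translation: push the camera
# back by `radius` (the translation also flips the x and y axes), tilt it by `elevation`
# degrees, then rotate about the vertical axis by an angle that sweeps a full circle
# over `num_frames` frames.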
16 | tf_translation = torch.eye(4, dtype=torch.float32, device=device) 17 | tf_translation[:2] *= -1 18 | tf_translation[2, 3] = -radius 19 | 20 | # Generate the transformation for the azimuth. 21 | phi = 2 * np.pi * (np.arange(num_frames) / num_frames) 22 | rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1) 23 | 24 | azimuth = R.from_rotvec(rotation_vectors).as_matrix() 25 | azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device) 26 | tf_azimuth = torch.eye(4, dtype=torch.float32, device=device) 27 | tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone() 28 | tf_azimuth[:, :3, :3] = azimuth 29 | 30 | # Generate the transformation for the elevation. 31 | deg_elevation = np.deg2rad(elevation) 32 | elevation = R.from_rotvec(np.array([deg_elevation, 0, 0], dtype=np.float32)) 33 | elevation = torch.tensor(elevation.as_matrix()) 34 | tf_elevation = torch.eye(4, dtype=torch.float32, device=device) 35 | tf_elevation[:3, :3] = elevation 36 | 37 | return tf_azimuth @ tf_elevation @ tf_translation 38 | -------------------------------------------------------------------------------- /ggrt/visualization/camera_trajectory/wobble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from jaxtyping import Float 4 | from torch import Tensor 5 | 6 | 7 | @torch.no_grad() 8 | def generate_wobble_transformation( 9 | radius: Float[Tensor, "*#batch"], 10 | t: Float[Tensor, " time_step"], 11 | num_rotations: int = 1, 12 | scale_radius_with_t: bool = True, 13 | ) -> Float[Tensor, "*batch time_step 4 4"]: 14 | # Generate a translation in the image plane. 15 | tf = torch.eye(4, dtype=torch.float32, device=t.device) 16 | tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() 17 | radius = radius[..., None] 18 | if scale_radius_with_t: 19 | radius = radius * t 20 | tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius 21 | tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius 22 | return tf 23 | 24 | 25 | @torch.no_grad() 26 | def generate_wobble( 27 | extrinsics: Float[Tensor, "*#batch 4 4"], 28 | radius: Float[Tensor, "*#batch"], 29 | t: Float[Tensor, " time_step"], 30 | ) -> Float[Tensor, "*batch time_step 4 4"]: 31 | tf = generate_wobble_transformation(radius, t) 32 | return rearrange(extrinsics, "... i j -> ... () i j") @ tf 33 | -------------------------------------------------------------------------------- /ggrt/visualization/color_map.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colorspacious import cspace_convert 3 | from einops import rearrange 4 | from jaxtyping import Float 5 | from matplotlib import cm 6 | from torch import Tensor 7 | 8 | 9 | def apply_color_map( 10 | x: Float[Tensor, " *batch"], 11 | color_map: str = "inferno", 12 | ) -> Float[Tensor, "*batch 3"]: 13 | cmap = cm.get_cmap(color_map) 14 | 15 | # Convert to NumPy so that Matplotlib color maps can be used. 16 | mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3] 17 | 18 | # Convert back to the original format. 19 | return torch.tensor(mapped, device=x.device, dtype=torch.float32) 20 | 21 | 22 | def apply_color_map_to_image( 23 | image: Float[Tensor, "*batch height width"], 24 | color_map: str = "inferno", 25 | ) -> Float[Tensor, "*batch 3 height with"]: 26 | image = apply_color_map(image, color_map) 27 | return rearrange(image, "... h w c -> ... 
c h w") 28 | 29 | 30 | def apply_color_map_2d( 31 | x: Float[Tensor, "*#batch"], 32 | y: Float[Tensor, "*#batch"], 33 | ) -> Float[Tensor, "*batch 3"]: 34 | red = cspace_convert((189, 0, 0), "sRGB255", "CIELab") 35 | blue = cspace_convert((0, 45, 255), "sRGB255", "CIELab") 36 | white = cspace_convert((255, 255, 255), "sRGB255", "CIELab") 37 | x_np = x.detach().clip(min=0, max=1).cpu().numpy()[..., None] 38 | y_np = y.detach().clip(min=0, max=1).cpu().numpy()[..., None] 39 | 40 | # Interpolate between red and blue on the x axis. 41 | interpolated = x_np * red + (1 - x_np) * blue 42 | 43 | # Interpolate between color and white on the y axis. 44 | interpolated = y_np * interpolated + (1 - y_np) * white 45 | 46 | # Convert to RGB. 47 | rgb = cspace_convert(interpolated, "CIELab", "sRGB1") 48 | return torch.tensor(rgb, device=x.device, dtype=torch.float32).clip(min=0, max=1) 49 | -------------------------------------------------------------------------------- /ggrt/visualization/colors.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageColor 2 | 3 | # https://sashamaps.net/docs/resources/20-colors/ 4 | DISTINCT_COLORS = [ 5 | "#e6194b", 6 | "#3cb44b", 7 | "#ffe119", 8 | "#4363d8", 9 | "#f58231", 10 | "#911eb4", 11 | "#46f0f0", 12 | "#f032e6", 13 | "#bcf60c", 14 | "#fabebe", 15 | "#008080", 16 | "#e6beff", 17 | "#9a6324", 18 | "#fffac8", 19 | "#800000", 20 | "#aaffc3", 21 | "#808000", 22 | "#ffd8b1", 23 | "#000075", 24 | "#808080", 25 | "#ffffff", 26 | "#000000", 27 | ] 28 | 29 | 30 | def get_distinct_color(index: int) -> tuple[float, float, float]: 31 | hex = DISTINCT_COLORS[index % len(DISTINCT_COLORS)] 32 | return tuple(x / 255 for x in ImageColor.getcolor(hex, "RGB")) 33 | -------------------------------------------------------------------------------- /ggrt/visualization/drawing/coordinate_conversion.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, runtime_checkable 2 | 3 | import torch 4 | from jaxtyping import Float 5 | from torch import Tensor 6 | 7 | from .types import Pair, sanitize_pair 8 | 9 | 10 | @runtime_checkable 11 | class ConversionFunction(Protocol): 12 | def __call__( 13 | self, 14 | xy: Float[Tensor, "*batch 2"], 15 | ) -> Float[Tensor, "*batch 2"]: 16 | pass 17 | 18 | 19 | def generate_conversions( 20 | shape: tuple[int, int], 21 | device: torch.device, 22 | x_range: Optional[Pair] = None, 23 | y_range: Optional[Pair] = None, 24 | ) -> tuple[ 25 | ConversionFunction, # conversion from world coordinates to pixel coordinates 26 | ConversionFunction, # conversion from pixel coordinates to world coordinates 27 | ]: 28 | h, w = shape 29 | x_range = sanitize_pair((0, w) if x_range is None else x_range, device) 30 | y_range = sanitize_pair((0, h) if y_range is None else y_range, device) 31 | minima, maxima = torch.stack((x_range, y_range), dim=-1) 32 | wh = torch.tensor((w, h), dtype=torch.float32, device=device) 33 | 34 | def convert_world_to_pixel( 35 | xy: Float[Tensor, "*batch 2"], 36 | ) -> Float[Tensor, "*batch 2"]: 37 | return (xy - minima) / (maxima - minima) * wh 38 | 39 | def convert_pixel_to_world( 40 | xy: Float[Tensor, "*batch 2"], 41 | ) -> Float[Tensor, "*batch 2"]: 42 | return xy / wh * (maxima - minima) + minima 43 | 44 | return convert_world_to_pixel, convert_pixel_to_world 45 | -------------------------------------------------------------------------------- /ggrt/visualization/drawing/lines.py: 
-------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import torch 4 | from einops import einsum, repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_lines( 14 | image: Float[Tensor, "3 height width"], 15 | start: Vector, 16 | end: Vector, 17 | color: Vector, 18 | width: Scalar, 19 | cap: Literal["butt", "round", "square"] = "round", 20 | num_msaa_passes: int = 1, 21 | x_range: Optional[Pair] = None, 22 | y_range: Optional[Pair] = None, 23 | ) -> Float[Tensor, "3 height width"]: 24 | device = image.device 25 | start = sanitize_vector(start, 2, device) 26 | end = sanitize_vector(end, 2, device) 27 | color = sanitize_vector(color, 3, device) 28 | width = sanitize_scalar(width, device) 29 | (num_lines,) = torch.broadcast_shapes( 30 | start.shape[0], 31 | end.shape[0], 32 | color.shape[0], 33 | width.shape, 34 | ) 35 | 36 | # Convert world-space points to pixel space. 37 | _, h, w = image.shape 38 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 39 | start = world_to_pixel(start) 40 | end = world_to_pixel(end) 41 | 42 | def color_function( 43 | xy: Float[Tensor, "point 2"], 44 | ) -> Float[Tensor, "point 4"]: 45 | # Define a vector between the start and end points. 46 | delta = end - start 47 | delta_norm = delta.norm(dim=-1, keepdim=True) 48 | u_delta = delta / delta_norm 49 | 50 | # Define a vector between each sample and the start point. 51 | indicator = xy - start[:, None] 52 | 53 | # Determine whether each sample is inside the line in the parallel direction. 54 | extra = 0.5 * width[:, None] if cap == "square" else 0 55 | parallel = einsum(u_delta, indicator, "l xy, l s xy -> l s") 56 | parallel_inside_line = (parallel <= delta_norm + extra) & (parallel > -extra) 57 | 58 | # Determine whether each sample is inside the line perpendicularly. 59 | perpendicular = indicator - parallel[..., None] * u_delta[:, None] 60 | perpendicular_inside_line = perpendicular.norm(dim=-1) < 0.5 * width[:, None] 61 | 62 | inside_line = parallel_inside_line & perpendicular_inside_line 63 | 64 | # Compute round caps. 65 | if cap == "round": 66 | near_start = indicator.norm(dim=-1) < 0.5 * width[:, None] 67 | inside_line |= near_start 68 | end_indicator = indicator = xy - end[:, None] 69 | near_end = end_indicator.norm(dim=-1) < 0.5 * width[:, None] 70 | inside_line |= near_end 71 | 72 | # Determine the sample's color. 
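# Where several lines cover a sample, the highest line index wins: the boolean coverage
# mask is multiplied by each line's index, argmax over lines picks the topmost one, and
# its color is gathered; alpha is 1 wherever any line covers the sample.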
73 | selectable_color = color.broadcast_to((num_lines, 3)) 74 | arrangement = inside_line * torch.arange(num_lines, device=device)[:, None] 75 | top_color = selectable_color.gather( 76 | dim=0, 77 | index=repeat(arrangement.argmax(dim=0), "s -> s c", c=3), 78 | ) 79 | rgba = torch.cat((top_color, inside_line.any(dim=0).float()[:, None]), dim=-1) 80 | 81 | return rgba 82 | 83 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 84 | -------------------------------------------------------------------------------- /ggrt/visualization/drawing/points.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float 6 | from torch import Tensor 7 | 8 | from .coordinate_conversion import generate_conversions 9 | from .rendering import render_over_image 10 | from .types import Pair, Scalar, Vector, sanitize_scalar, sanitize_vector 11 | 12 | 13 | def draw_points( 14 | image: Float[Tensor, "3 height width"], 15 | points: Vector, 16 | color: Vector = [1, 1, 1], 17 | radius: Scalar = 1, 18 | inner_radius: Scalar = 0, 19 | num_msaa_passes: int = 1, 20 | x_range: Optional[Pair] = None, 21 | y_range: Optional[Pair] = None, 22 | ) -> Float[Tensor, "3 height width"]: 23 | device = image.device 24 | points = sanitize_vector(points, 2, device) 25 | color = sanitize_vector(color, 3, device) 26 | radius = sanitize_scalar(radius, device) 27 | inner_radius = sanitize_scalar(inner_radius, device) 28 | (num_points,) = torch.broadcast_shapes( 29 | points.shape[0], 30 | color.shape[0], 31 | radius.shape, 32 | inner_radius.shape, 33 | ) 34 | 35 | # Convert world-space points to pixel space. 36 | _, h, w = image.shape 37 | world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range) 38 | points = world_to_pixel(points) 39 | 40 | def color_function( 41 | xy: Float[Tensor, "point 2"], 42 | ) -> Float[Tensor, "point 4"]: 43 | # Define a vector between the start and end points. 44 | delta = xy[:, None] - points[None] 45 | delta_norm = delta.norm(dim=-1) 46 | mask = (delta_norm >= inner_radius[None]) & (delta_norm <= radius[None]) 47 | 48 | # Determine the sample's color. 
49 | selectable_color = color.broadcast_to((num_points, 3)) 50 | arrangement = mask * torch.arange(num_points, device=device) 51 | top_color = selectable_color.gather( 52 | dim=0, 53 | index=repeat(arrangement.argmax(dim=1), "s -> s c", c=3), 54 | ) 55 | rgba = torch.cat((top_color, mask.any(dim=1).float()[:, None]), dim=-1) 56 | 57 | return rgba 58 | 59 | return render_over_image(image, color_function, device, num_passes=num_msaa_passes) 60 | -------------------------------------------------------------------------------- /ggrt/visualization/drawing/rendering.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import torch 4 | from einops import rearrange, reduce 5 | from jaxtyping import Bool, Float 6 | from torch import Tensor 7 | 8 | 9 | @runtime_checkable 10 | class ColorFunction(Protocol): 11 | def __call__( 12 | self, 13 | xy: Float[Tensor, "point 2"], 14 | ) -> Float[Tensor, "point 4"]: # RGBA color 15 | pass 16 | 17 | 18 | def generate_sample_grid( 19 | shape: tuple[int, int], 20 | device: torch.device, 21 | ) -> Float[Tensor, "height width 2"]: 22 | h, w = shape 23 | x = torch.arange(w, device=device) + 0.5 24 | y = torch.arange(h, device=device) + 0.5 25 | x, y = torch.meshgrid(x, y, indexing="xy") 26 | return torch.stack([x, y], dim=-1) 27 | 28 | 29 | def detect_msaa_pixels( 30 | image: Float[Tensor, "batch 4 height width"], 31 | ) -> Bool[Tensor, "batch height width"]: 32 | b, _, h, w = image.shape 33 | 34 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) 35 | 36 | # Detect horizontal differences. 37 | horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) 38 | mask[:, :, 1:] |= horizontal 39 | mask[:, :, :-1] |= horizontal 40 | 41 | # Detect vertical differences. 42 | vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) 43 | mask[:, 1:, :] |= vertical 44 | mask[:, :-1, :] |= vertical 45 | 46 | # Detect diagonal (top left to bottom right) differences. 47 | tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) 48 | mask[:, 1:, 1:] |= tlbr 49 | mask[:, :-1, :-1] |= tlbr 50 | 51 | # Detect diagonal (top right to bottom left) differences. 52 | trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) 53 | mask[:, :-1, 1:] |= trbl 54 | mask[:, 1:, :-1] |= trbl 55 | 56 | return mask 57 | 58 | 59 | def reduce_straight_alpha( 60 | rgba: Float[Tensor, "batch 4 height width"], 61 | ) -> Float[Tensor, "batch 4"]: 62 | color, alpha = rgba.split((3, 1), dim=1) 63 | 64 | # Color becomes a weighted average of color (weighted by alpha). 65 | weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") 66 | alpha_sum = reduce(alpha, "b c h w -> b c", "sum") 67 | color = weighted_color / (alpha_sum + 1e-10) 68 | 69 | # Alpha becomes mean alpha. 70 | alpha = reduce(alpha, "b c h w -> b c", "mean") 71 | 72 | return torch.cat((color, alpha), dim=-1) 73 | 74 | 75 | @torch.no_grad() 76 | def run_msaa_pass( 77 | xy: Float[Tensor, "batch height width 2"], 78 | color_function: ColorFunction, 79 | scale: float, 80 | subdivision: int, 81 | remaining_passes: int, 82 | device: torch.device, 83 | batch_size: int = int(2**16), 84 | ) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) 85 | # Sample the color function. 
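# The sample grid is flattened and evaluated in chunks of `batch_size` points so that
# large canvases fit in memory, then reassembled into a (batch, 4, height, width) RGBA
# image. If MSAA passes remain, only pixels whose neighborhood shows a color difference
# are subdivided and re-rendered at finer scale.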
86 | b, h, w, _ = xy.shape 87 | color = [ 88 | color_function(batch) 89 | for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size) 90 | ] 91 | color = torch.cat(color, dim=0) 92 | color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) 93 | 94 | # If any MSAA passes remain, subdivide. 95 | if remaining_passes > 0: 96 | mask = detect_msaa_pixels(color) 97 | batch_index, row_index, col_index = torch.where(mask) 98 | xy = xy[batch_index, row_index, col_index] 99 | 100 | offsets = generate_sample_grid((subdivision, subdivision), device) 101 | offsets = (offsets / subdivision - 0.5) * scale 102 | 103 | color_fine = run_msaa_pass( 104 | xy[:, None, None] + offsets, 105 | color_function, 106 | scale / subdivision, 107 | subdivision, 108 | remaining_passes - 1, 109 | device, 110 | batch_size=batch_size, 111 | ) 112 | color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) 113 | 114 | return color 115 | 116 | 117 | @torch.no_grad() 118 | def render( 119 | shape: tuple[int, int], 120 | color_function: ColorFunction, 121 | device: torch.device, 122 | subdivision: int = 8, 123 | num_passes: int = 2, 124 | ) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) 125 | xy = generate_sample_grid(shape, device) 126 | return run_msaa_pass( 127 | xy[None], 128 | color_function, 129 | 1.0, 130 | subdivision, 131 | num_passes, 132 | device, 133 | )[0] 134 | 135 | 136 | def render_over_image( 137 | image: Float[Tensor, "3 height width"], 138 | color_function: ColorFunction, 139 | device: torch.device, 140 | subdivision: int = 8, 141 | num_passes: int = 1, 142 | ) -> Float[Tensor, "3 height width"]: 143 | _, h, w = image.shape 144 | overlay = render( 145 | (h, w), 146 | color_function, 147 | device, 148 | subdivision=subdivision, 149 | num_passes=num_passes, 150 | ) 151 | color, alpha = overlay.split((3, 1), dim=0) 152 | return image * (1 - alpha) + color * alpha 153 | -------------------------------------------------------------------------------- /ggrt/visualization/drawing/types.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 4 | from einops import repeat 5 | from jaxtyping import Float, Shaped 6 | from torch import Tensor 7 | 8 | Real = Union[float, int] 9 | 10 | Vector = Union[ 11 | Real, 12 | Iterable[Real], 13 | Shaped[Tensor, "3"], 14 | Shaped[Tensor, "batch 3"], 15 | ] 16 | 17 | 18 | def sanitize_vector( 19 | vector: Vector, 20 | dim: int, 21 | device: torch.device, 22 | ) -> Float[Tensor, "*#batch dim"]: 23 | if isinstance(vector, Tensor): 24 | vector = vector.type(torch.float32).to(device) 25 | else: 26 | vector = torch.tensor(vector, dtype=torch.float32, device=device) 27 | while vector.ndim < 2: 28 | vector = vector[None] 29 | if vector.shape[-1] == 1: 30 | vector = repeat(vector, "... () -> ... 
c", c=dim) 31 | assert vector.shape[-1] == dim 32 | assert vector.ndim == 2 33 | return vector 34 | 35 | 36 | Scalar = Union[ 37 | Real, 38 | Iterable[Real], 39 | Shaped[Tensor, ""], 40 | Shaped[Tensor, " batch"], 41 | ] 42 | 43 | 44 | def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: 45 | if isinstance(scalar, Tensor): 46 | scalar = scalar.type(torch.float32).to(device) 47 | else: 48 | scalar = torch.tensor(scalar, dtype=torch.float32, device=device) 49 | while scalar.ndim < 1: 50 | scalar = scalar[None] 51 | assert scalar.ndim == 1 52 | return scalar 53 | 54 | 55 | Pair = Union[ 56 | Iterable[Real], 57 | Shaped[Tensor, "2"], 58 | ] 59 | 60 | 61 | def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: 62 | if isinstance(pair, Tensor): 63 | pair = pair.type(torch.float32).to(device) 64 | else: 65 | pair = torch.tensor(pair, dtype=torch.float32, device=device) 66 | assert pair.shape == (2,) 67 | return pair 68 | -------------------------------------------------------------------------------- /ggrt/visualization/pose_visualizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ggrt.pose_util import to_hom 4 | 5 | 6 | def get_camera_mesh(pose, depth=1): 7 | vertices = torch.tensor([[-0.5, -0.5, 1], 8 | [ 0.5, -0.5, 1], 9 | [ 0.5, 0.5, 1], 10 | [-0.5, 0.5, 1], 11 | [ 0, 0, 0]]) * depth 12 | 13 | faces = torch.tensor([[0, 1, 2], 14 | [0, 2, 3], 15 | [0, 1, 4], 16 | [1, 2, 4], 17 | [2, 3, 4], 18 | [3, 0, 4]]) 19 | 20 | # vertices = camera.cam2world(vertices[None], pose) 21 | vertices = to_hom(vertices[None]) @ pose.transpose(-1, -2) 22 | 23 | wire_frame = vertices[:, [0,1,2,3,0,4,1,2,4,3]] 24 | 25 | return vertices, faces, wire_frame 26 | 27 | 28 | def merge_wire_frames(wire_frame): 29 | wire_frame_merged = [[], [], []] 30 | for w in wire_frame: 31 | wire_frame_merged[0] += [float(n) for n in w[:, 0]] + [None] 32 | wire_frame_merged[1] += [float(n) for n in w[:, 1]] + [None] 33 | wire_frame_merged[2] += [float(n) for n in w[:, 2]] + [None] 34 | 35 | return wire_frame_merged 36 | 37 | 38 | def merge_meshes(vertices, faces): 39 | mesh_N, vertex_N = vertices.shape[:2] 40 | faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)], dim=0) 41 | vertices_merged = vertices.view(-1, vertices.shape[-1]) 42 | 43 | return vertices_merged,faces_merged 44 | 45 | 46 | def merge_centers(centers): 47 | center_merged = [[], [], []] 48 | 49 | for c1, c2 in zip(*centers): 50 | center_merged[0] += [float(c1[0]), float(c2[0]), None] 51 | center_merged[1] += [float(c1[1]), float(c2[1]), None] 52 | center_merged[2] += [float(c1[2]), float(c2[2]), None] 53 | 54 | return center_merged 55 | 56 | 57 | @torch.no_grad() 58 | def visualize_cameras(vis, step, poses=[], cam_depth=0.5, colors=["blue", "magenta"], plot_dist=True): 59 | win_name = "gt_pred" 60 | data = [] 61 | 62 | # set up plots 63 | centers = [] 64 | for pose, color in zip(poses, colors): 65 | pose = pose.detach().cpu() 66 | vertices, faces, wire_frame = get_camera_mesh(pose, depth=cam_depth) 67 | center = vertices[:, -1] 68 | centers.append(center) 69 | 70 | # camera centers 71 | data.append(dict( 72 | type="scatter3d", 73 | x=[float(n) for n in center[:, 0]], 74 | y=[float(n) for n in center[:, 1]], 75 | z=[float(n) for n in center[:, 2]], 76 | mode="markers", 77 | marker=dict(color=color, size=3), 78 | )) 79 | 80 | # colored camera mesh 81 | vertices_merged, faces_merged = merge_meshes(vertices, faces) 82 | 83 | data.append(dict( 
84 | type="mesh3d", 85 | x=[float(n) for n in vertices_merged[:, 0]], 86 | y=[float(n) for n in vertices_merged[:, 1]], 87 | z=[float(n) for n in vertices_merged[:, 2]], 88 | i=[int(n) for n in faces_merged[:, 0]], 89 | j=[int(n) for n in faces_merged[:, 1]], 90 | k=[int(n) for n in faces_merged[:, 2]], 91 | flatshading=True, 92 | color=color, 93 | opacity=0.05, 94 | )) 95 | 96 | # camera wire_frame 97 | wire_frame_merged = merge_wire_frames(wire_frame) 98 | data.append(dict( 99 | type="scatter3d", 100 | x=wire_frame_merged[0], 101 | y=wire_frame_merged[1], 102 | z=wire_frame_merged[2], 103 | mode="lines", 104 | line=dict(color=color,), 105 | opacity=0.3, 106 | )) 107 | 108 | if plot_dist: 109 | # distance between two poses (camera centers) 110 | center_merged = merge_centers(centers[:2]) 111 | data.append(dict( 112 | type="scatter3d", 113 | x=center_merged[0], 114 | y=center_merged[1], 115 | z=center_merged[2], 116 | mode="lines", 117 | line=dict(color="red",width=4,), 118 | )) 119 | 120 | if len(centers)==4: 121 | center_merged = merge_centers(centers[2:4]) 122 | data.append(dict( 123 | type="scatter3d", 124 | x=center_merged[0], 125 | y=center_merged[1], 126 | z=center_merged[2], 127 | mode="lines", 128 | line=dict(color="red",width=4,), 129 | )) 130 | 131 | # send data to visdom 132 | vis._send(dict( 133 | data=data, 134 | win="poses", 135 | eid=win_name, 136 | layout=dict( 137 | title="({})".format(step), 138 | autosize=True, 139 | margin=dict(l=30,r=30,b=30,t=30,), 140 | showlegend=False, 141 | yaxis=dict( 142 | scaleanchor="x", 143 | scaleratio=1, 144 | ) 145 | ), 146 | opts=dict(title="{} poses ({})".format(win_name, step),), 147 | )) 148 | -------------------------------------------------------------------------------- /ggrt/visualization/validation_in_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from jaxtyping import Float, Shaped 3 | from torch import Tensor 4 | 5 | from ..model.decoder.cuda_splatting import render_cuda_orthographic 6 | from ..model.types import Gaussians 7 | from ..visualization.annotation import add_label 8 | from ..visualization.drawing.cameras import draw_cameras 9 | from .drawing.cameras import compute_equal_aabb_with_margin 10 | 11 | 12 | def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]: 13 | shapes = torch.stack([torch.tensor(x.shape) for x in images]) 14 | padded_shape = shapes.max(dim=0)[0] 15 | results = [ 16 | torch.ones(padded_shape.tolist(), dtype=x.dtype, device=x.device) 17 | for x in images 18 | ] 19 | for image, result in zip(images, results): 20 | slices = [slice(0, x) for x in image.shape] 21 | result[slices] = image[slices] 22 | return results 23 | 24 | 25 | def render_projections( 26 | gaussians: Gaussians, 27 | resolution: int, 28 | margin: float = 0.1, 29 | draw_label: bool = True, 30 | extra_label: str = "", 31 | ) -> Float[Tensor, "batch 3 3 height width"]: 32 | device = gaussians.means.device 33 | b, _, _ = gaussians.means.shape 34 | 35 | # Compute the minima and maxima of the scene. 36 | minima = gaussians.means.min(dim=1).values 37 | maxima = gaussians.means.max(dim=1).values 38 | scene_minima, scene_maxima = compute_equal_aabb_with_margin( 39 | minima, maxima, margin=margin 40 | ) 41 | 42 | projections = [] 43 | for look_axis in range(3): 44 | right_axis = (look_axis + 1) % 3 45 | down_axis = (look_axis + 2) % 3 46 | 47 | # Define the extrinsics for rendering. 
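        # The orthographic camera for this projection looks down `look_axis`: its right and
        # down directions are mapped to the other two world axes, it is centred on the scene
        # AABB within the image plane, and it sits at the scene minimum along the look axis
        # so that near = 0 and far = the AABB extent (set just below).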
48 | extrinsics = torch.zeros((b, 4, 4), dtype=torch.float32, device=device) 49 | extrinsics[:, right_axis, 0] = 1 50 | extrinsics[:, down_axis, 1] = 1 51 | extrinsics[:, look_axis, 2] = 1 52 | extrinsics[:, right_axis, 3] = 0.5 * ( 53 | scene_minima[:, right_axis] + scene_maxima[:, right_axis] 54 | ) 55 | extrinsics[:, down_axis, 3] = 0.5 * ( 56 | scene_minima[:, down_axis] + scene_maxima[:, down_axis] 57 | ) 58 | extrinsics[:, look_axis, 3] = scene_minima[:, look_axis] 59 | extrinsics[:, 3, 3] = 1 60 | 61 | # Define the intrinsics for rendering. 62 | extents = scene_maxima - scene_minima 63 | far = extents[:, look_axis] 64 | near = torch.zeros_like(far) 65 | width = extents[:, right_axis] 66 | height = extents[:, down_axis] 67 | 68 | projection = render_cuda_orthographic( 69 | extrinsics, 70 | width, 71 | height, 72 | near, 73 | far, 74 | (resolution, resolution), 75 | torch.zeros((b, 3), dtype=torch.float32, device=device), 76 | gaussians.means, 77 | gaussians.covariances, 78 | gaussians.harmonics, 79 | gaussians.opacities, 80 | fov_degrees=10.0, 81 | ) 82 | if draw_label: 83 | right_axis_name = "XYZ"[right_axis] 84 | down_axis_name = "XYZ"[down_axis] 85 | label = f"{right_axis_name}{down_axis_name} Projection {extra_label}" 86 | projection = torch.stack([add_label(x, label) for x in projection]) 87 | 88 | projections.append(projection) 89 | 90 | return torch.stack(pad(projections), dim=1) 91 | 92 | 93 | def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]: 94 | # Define colors for context and target views. 95 | num_context_views = batch["context"]["extrinsics"].shape[1] 96 | num_target_views = batch["target"]["extrinsics"].shape[1] 97 | color = torch.ones( 98 | (num_target_views + num_context_views, 3), 99 | dtype=torch.float32, 100 | device=batch["target"]["extrinsics"].device, 101 | ) 102 | color[num_context_views:, 1:] = 0 103 | 104 | return draw_cameras( 105 | resolution, 106 | torch.cat( 107 | (batch["context"]["extrinsics"][0], batch["target"]["extrinsics"][0]) 108 | ), 109 | torch.cat( 110 | (batch["context"]["intrinsics"][0], batch["target"]["intrinsics"][0]) 111 | ), 112 | color, 113 | torch.cat((batch["context"]["near"][0], batch["target"]["near"][0])), 114 | torch.cat((batch["context"]["far"][0], batch["target"]["far"][0])), 115 | ) 116 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lifuguan/GGRt_official/01886261b6b6b6175b6ea88f44a85c640564ae9f/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/env/dependencies.sh: -------------------------------------------------------------------------------- 1 | conda create -n dbarf python=3.9 2 | conda activate dbarf 3 | 4 | # install pytorch 5 | # # CUDA 10.2 6 | # conda install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=10.2 -c pytorch 7 | 8 | # CUDA 11.3 9 | conda install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=11.3 -c pytorch 10 | 11 | git clone https://github.com/cvg/sfm-disambiguation-colmap.git 12 | cd sfm-disambiguation-colmap 13 | python -m pip install -e . 14 | 15 | # HLoc is used for extracting keypoints and matching features. 16 | git clone --recursive https://github.com/cvg/Hierarchical-Localization/ 17 | cd Hierarchical-Localization/ 18 | python -m pip install -e . 19 | cd .. 
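# Optional sanity check (assumes the editable installs above succeeded):
# python -c "import hloc; print('hloc OK')"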
20 | 21 | conda install -c conda-forge ffmpeg # imageio-ffmpeg 22 | pip install opencv-python matplotlib easydict tqdm networkx einops \ 23 | imageio visdom tensorboard tensorboardX configargparse lpips 24 | -------------------------------------------------------------------------------- /scripts/pairs_from_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | from typing import Optional 5 | import h5py 6 | import numpy as np 7 | import torch 8 | import collections.abc as collections 9 | 10 | from hloc import logger 11 | from hloc.utils.parsers import parse_image_lists 12 | from hloc.utils.read_write_model import read_images_binary 13 | from hloc.utils.io import list_h5_names 14 | 15 | 16 | def parse_names(prefix, names, names_all): 17 | if prefix is not None: 18 | if not isinstance(prefix, str): 19 | prefix = tuple(prefix) 20 | names = [n for n in names_all if n.startswith(prefix)] 21 | elif names is not None: 22 | if isinstance(names, (str, Path)): 23 | names = parse_image_lists(names) 24 | elif isinstance(names, collections.Iterable): 25 | names = list(names) 26 | else: 27 | raise ValueError(f'Unknown type of image list: {names}.' 28 | 'Provide either a list or a path to a list file.') 29 | else: 30 | names = names_all 31 | return names 32 | 33 | 34 | def get_descriptors(names, path, name2idx=None, key='global_descriptor'): 35 | if name2idx is None: 36 | with h5py.File(str(path), 'r') as fd: 37 | desc = [fd[n][key].__array__() for n in names] 38 | else: 39 | desc = [] 40 | for n in names: 41 | with h5py.File(str(path[name2idx[n]]), 'r') as fd: 42 | desc.append(fd[n][key].__array__()) 43 | return torch.from_numpy(np.stack(desc, 0)).float() 44 | 45 | 46 | def pairs_from_score_matrix(scores: torch.Tensor, 47 | invalid: np.array, 48 | num_select: int, 49 | min_score: Optional[float] = None): 50 | assert scores.shape == invalid.shape 51 | if isinstance(scores, np.ndarray): 52 | scores = torch.from_numpy(scores) 53 | invalid = torch.from_numpy(invalid).to(scores.device) 54 | if min_score is not None: 55 | invalid |= scores < min_score 56 | scores.masked_fill_(invalid, float('-inf')) 57 | 58 | topk = torch.topk(scores, num_select, dim=1) 59 | indices = topk.indices.cpu().numpy() 60 | valid = topk.values.isfinite().cpu().numpy() 61 | 62 | pairs = [] 63 | for i, j in zip(*np.where(valid)): 64 | pairs.append((i, indices[i, j])) 65 | return pairs 66 | 67 | 68 | def main(descriptors, output, num_matched, 69 | query_prefix=None, query_list=None, 70 | db_prefix=None, db_list=None, db_model=None, db_descriptors=None, 71 | device='cuda'): 72 | logger.info('Extracting image pairs from a retrieval database.') 73 | 74 | # We handle multiple reference feature files. 75 | # We only assume that names are unique among them and map names to files. 
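    # Illustrative example (hypothetical file names): with db_descriptors = ['a.h5', 'b.h5'],
    # where 'img0.jpg' is stored in a.h5 and 'img1.jpg' in b.h5, name2db becomes
    # {'img0.jpg': 0, 'img1.jpg': 1}, mapping each image name to the index of the
    # descriptor file that contains it.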
76 | if db_descriptors is None: 77 | db_descriptors = descriptors 78 | if isinstance(db_descriptors, (Path, str)): 79 | db_descriptors = [db_descriptors] 80 | name2db = {n: i for i, p in enumerate(db_descriptors) 81 | for n in list_h5_names(p)} 82 | db_names_h5 = list(name2db.keys()) 83 | query_names_h5 = list_h5_names(descriptors) 84 | 85 | if db_model: 86 | images = read_images_binary(os.path.join(db_model, 'images.bin')) 87 | db_names = [i.name for i in images.values()] 88 | else: 89 | db_names = parse_names(db_prefix, db_list, db_names_h5) 90 | 91 | num_images = len(db_names) 92 | if num_images == 0: 93 | raise ValueError('Could not find any database image.') 94 | query_names = parse_names(query_prefix, query_list, query_names_h5) 95 | 96 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 97 | db_desc = get_descriptors(db_names, db_descriptors, name2db) 98 | query_desc = get_descriptors(query_names, descriptors) 99 | sim = torch.einsum('id,jd->ij', query_desc.to(device), db_desc.to(device)) 100 | 101 | # Avoid self-matching 102 | self = np.array(query_names)[:, None] == np.array(db_names)[None] 103 | num_matched = min(num_images, num_matched) 104 | pairs = pairs_from_score_matrix(sim, self, num_matched, min_score=0) 105 | pairs = [(query_names[i], db_names[j]) for i, j in pairs] 106 | 107 | logger.info(f'Found {len(pairs)} pairs.') 108 | with open(output, 'w') as f: 109 | f.write('\n'.join(' '.join([i, j]) for i, j in pairs)) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument('--descriptors', type=Path, required=True) 115 | parser.add_argument('--output', type=Path, required=True) 116 | parser.add_argument('--num_matched', type=int, required=True) 117 | parser.add_argument('--query_prefix', type=str, nargs='+') 118 | parser.add_argument('--query_list', type=Path) 119 | parser.add_argument('--db_prefix', type=str, nargs='+') 120 | parser.add_argument('--db_list', type=Path) 121 | parser.add_argument('--db_model', type=Path) 122 | parser.add_argument('--db_descriptors', type=Path) 123 | args = parser.parse_args() 124 | main(**args.__dict__) 125 | -------------------------------------------------------------------------------- /scripts/preprocess_dbarf_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | from ggrt.geometry.track import load_track_elements, TrackBuilder 6 | from scripts import extract_relative_poses 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--dataset_dir', type=Path, default='datasets', 12 | help='Path to the dataset, default: %(default)s') 13 | parser.add_argument('--outputs', type=Path, default='outputs', 14 | help='Path to the output directory, default: %(default)s') 15 | parser.add_argument('--num_matches', type=int, default=30, 16 | help='Number of image pairs for loc, default: %(default)s') 17 | parser.add_argument('--disambiguate', action="store_true", 18 | help='Enable/Disable disambiguating wrong matches.') 19 | parser.add_argument('--min_track_length', type=int, default=3) 20 | parser.add_argument('--max_track_length', type=int, default=40) 21 | parser.add_argument('--track_degree', type=int, default=3) 22 | parser.add_argument('--coverage_thres', type=float, default=0.9) 23 | parser.add_argument('--alpha', type=float, default=0.1) 24 | parser.add_argument('--minimal_views', type=int, default=5) 25 | parser.add_argument('--ds', 
type=str, 26 | choices=['dict', 'smallarray', 'largearray'], 27 | default='largearray') 28 | parser.add_argument('--filter_type', type=str, choices=[ 29 | 'threshold', 'knn', 'mst_min', 'mst_mean', 'percentile'], 30 | default='threshold') 31 | parser.add_argument('--threshold', type=float, default=0.15) 32 | parser.add_argument('--topk', type=int, default=3) 33 | parser.add_argument('--percentile', type=float) 34 | parser.add_argument('--colmap_path', type=Path, default='colmap') 35 | parser.add_argument('--geometric_verification_type', 36 | type=str, 37 | choices=['default', 'strict'], 38 | default='default') 39 | parser.add_argument('--recon', action="store_true", 40 | help='Indicates whether to reconstruct the scene.') 41 | parser.add_argument('--visualize', action="store_true", 42 | help='Whether to visualize the reconstruction.') 43 | parser.add_argument('--gpu_idx', type=str, default='0') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | # Extracting relative poses and store as g2o file. 51 | view_graph_path, database_path, num_view_pairs = extract_relative_poses.main(args=args) 52 | 53 | # Extracting tracks from colmap database. 54 | track_elements, track_element_pairs = \ 55 | load_track_elements(database_path=database_path) 56 | track_builder = TrackBuilder(args.min_track_length, args.max_track_length) 57 | track_builder.build(track_elements, track_element_pairs) 58 | track_builder.filter() 59 | track_filename = os.path.join(args.outputs, 'track.txt') 60 | track_builder.write_to_file(track_filename) 61 | print(f'Tracks are written to: {track_filename}') 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /scripts/reconstruction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from typing import Optional, List 4 | import multiprocessing 5 | from pathlib import Path 6 | import pycolmap 7 | 8 | from hloc import logger 9 | from hloc.utils.database import COLMAPDatabase 10 | from hloc.triangulation import ( 11 | import_features, import_matches, geometric_verification, OutputCapture) 12 | 13 | 14 | def create_empty_db(database_path): 15 | if database_path.exists(): 16 | logger.warning('The database already exists, deleting it.') 17 | database_path.unlink() 18 | logger.info('Creating an empty database...') 19 | db = COLMAPDatabase.connect(database_path) 20 | db.create_tables() 21 | db.commit() 22 | db.close() 23 | 24 | 25 | def import_images(image_dir, database_path, camera_mode, image_list=None): 26 | logger.info('Importing images into the database...') 27 | images = list(image_dir.iterdir()) 28 | if len(images) == 0: 29 | raise IOError(f'No images found in {image_dir}.') 30 | with pycolmap.ostream(): 31 | pycolmap.import_images(database_path, image_dir, camera_mode, 32 | image_list=image_list or []) 33 | 34 | 35 | def get_image_ids(database_path): 36 | db = COLMAPDatabase.connect(database_path) 37 | images = {} 38 | for name, image_id in db.execute("SELECT name, image_id FROM images;"): 39 | images[name] = image_id 40 | db.close() 41 | return images 42 | 43 | 44 | def run_reconstruction(sfm_dir, database_path, image_dir, verbose=False): 45 | models_path = sfm_dir / 'models' 46 | models_path.mkdir(exist_ok=True, parents=True) 47 | logger.info('Running 3D reconstruction...') 48 | with OutputCapture(verbose): 49 | with pycolmap.ostream(): 50 | reconstructions = 
pycolmap.incremental_mapping( 51 | database_path, image_dir, models_path, 52 | num_threads=min(multiprocessing.cpu_count(), 16)) 53 | 54 | if len(reconstructions) == 0: 55 | logger.error('Could not reconstruct any model!') 56 | return None 57 | logger.info(f'Reconstructed {len(reconstructions)} model(s).') 58 | 59 | largest_index = None 60 | largest_num_images = 0 61 | for index, rec in reconstructions.items(): 62 | num_images = rec.num_reg_images() 63 | if num_images > largest_num_images: 64 | largest_index = index 65 | largest_num_images = num_images 66 | assert largest_index is not None 67 | logger.info(f'Largest model is #{largest_index} ' 68 | f'with {largest_num_images} images.') 69 | 70 | for filename in ['images.bin', 'cameras.bin', 'points3D.bin']: 71 | if (sfm_dir / filename).exists(): 72 | (sfm_dir / filename).unlink() 73 | shutil.move( 74 | str(models_path / str(largest_index) / filename), str(models_path)) 75 | return reconstructions[largest_index] 76 | 77 | 78 | def main(database, output_dir, image_dir, pairs, features, matches, 79 | camera_mode=pycolmap.CameraMode.AUTO, verbose=False, 80 | skip_geometric_verification=False, min_match_score=None, 81 | image_list: Optional[List[str]] = None): 82 | 83 | assert features.exists(), features 84 | assert pairs.exists(), pairs 85 | assert matches.exists(), matches 86 | 87 | output_dir.mkdir(parents=True, exist_ok=True) 88 | 89 | # create_empty_db(database) 90 | # import_images(image_dir, database, camera_mode, image_list) 91 | image_ids = get_image_ids(database) 92 | # import_features(image_ids, database, features) 93 | # import_matches(image_ids, database, pairs, matches, 94 | # min_match_score, skip_geometric_verification) 95 | # if not skip_geometric_verification: 96 | # geometric_verification(database, pairs, verbose) 97 | reconstruction = run_reconstruction(output_dir, database, image_dir, verbose) 98 | if reconstruction is not None: 99 | logger.info(f'Reconstruction statistics:\n{reconstruction.summary()}' 100 | + f'\n\tnum_input_images = {len(image_ids)}') 101 | return reconstruction 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument('--output_dir', type=Path, required=True) 107 | parser.add_argument('--image_dir', type=Path, required=True) 108 | 109 | parser.add_argument('--pairs', type=Path, required=True) 110 | parser.add_argument('--features', type=Path, required=True) 111 | parser.add_argument('--matches', type=Path, required=True) 112 | 113 | parser.add_argument('--camera_mode', type=str, default="AUTO", 114 | choices=list(pycolmap.CameraMode.__members__.keys())) 115 | parser.add_argument('--skip_geometric_verification', action='store_true') 116 | parser.add_argument('--min_match_score', type=float) 117 | parser.add_argument('--verbose', action='store_true') 118 | args = parser.parse_args() 119 | 120 | main(**args.__dict__) 121 | -------------------------------------------------------------------------------- /scripts/shell/eval_coarse_llff_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 # [200000, 260000] 4 | GPU_ID=$2 5 | 6 | export PYTHONDONTWRITEBYTECODE=1 7 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 8 | 9 | HOME_DIR=$HOME 10 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 11 | cd $EVAL_CODE_DIR 12 | 13 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 14 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval' 15 | CKPT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval/out' 16 | 17 | 
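# Example invocation (the two positional arguments are the checkpoint iteration and GPU id),
# assuming it is run from the repository root:
#   bash scripts/shell/eval_coarse_llff_all.sh 260000 0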
EXPNAME='eval_coarse_llff_finetune' 18 | 19 | scenes=('fern' 'flower' 'fortress' 'horns' 'leaves' 'orchids' 'room' 'trex') 20 | 21 | for((i=0;i<${#scenes[@]};i++)); 22 | do 23 | echo ${scenes[i]} 24 | checkpoint_path=${CKPT_DIR}/'finetune_coarse_ibr_llff_'${scenes[i]}/'model_'$ITER'.pth' 25 | eval_config_file=$CONFIG_DIR/eval_coarse_llff.txt 26 | 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py \ 28 | --config ${eval_config_file} \ 29 | --expname $EXPNAME \ 30 | --rootdir $ROOT_DIR \ 31 | --ckpt_path ${checkpoint_path} \ 32 | --eval_scenes ${scenes[i]} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/shell/eval_coarse_nerf_synthetic_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 4 | GPU_ID=$2 5 | 6 | export PYTHONDONTWRITEBYTECODE=1 7 | 8 | HOME_DIR=$HOME 9 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 10 | cd $EVAL_CODE_DIR 11 | 12 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 13 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval' 14 | CKPT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval/out' 15 | 16 | EXPNAME='eval_coarse_nerf_synthetic' 17 | 18 | scenes=("fern" "flower" "fortress" "horns" "leaves" "orchids" "room" "trex") 19 | 20 | for((i=0;i<${#scenes[@]};i++)); 21 | do 22 | echo ${scenes[i]} 23 | # For pretrained model. 24 | checkpoint_path=$HOME_DIR/Datasets/IBRNet/pretraining_dbarf/model/model_${ITER}.pth 25 | # For finetuned checkpoint. 26 | # checkpoint_path=${CKPT_DIR}/'finetune_ibrnet_llff_'${scenes[i]}_200k/'model_'$ITER'.pth' 27 | 28 | echo 'Computing metrics for NeRF...' 29 | CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py \ 30 | --config $CONFIG_DIR/eval_coarse_nerf_synthetic.txt \ 31 | --expname $EXPNAME \ 32 | --rootdir $ROOT_DIR \ 33 | --ckpt_path ${checkpoint_path} \ 34 | --eval_scenes ${scenes[i]} 35 | 36 | done 37 | -------------------------------------------------------------------------------- /scripts/shell/eval_coarse_scannet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 # [200000, 260000] 4 | GPU_ID=$2 5 | 6 | export PYTHONDONTWRITEBYTECODE=1 7 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 8 | 9 | HOME_DIR=$HOME 10 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 11 | cd $EVAL_CODE_DIR 12 | 13 | GRAPH_OPTIM_ROOT_DIR=${HOME_DIR}/'Projects/GraphOptim' 14 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 15 | ROOT_DIR=${HOME_DIR}/'Datasets/scannet' 16 | CKPT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval/out' 17 | 18 | EXPNAME='eval_coarse_ibrnet_scannet' 19 | 20 | scenes=("scene0671_00" "scene0673_03" "scene0675_00" "scene0675_01" "scene0680_00" "scene0684_00" "scene0684_01") 21 | 22 | for((i=0;i<${#scenes[@]};i++)); 23 | do 24 | echo ${scenes[i]} 25 | checkpoint_path=${CKPT_DIR}/'finetune_dbarf_llff_'${scenes[i]}/'model_'$ITER'.pth' 26 | eval_config_file=$CONFIG_DIR/eval_dbarf_llff.txt 27 | 28 | echo 'Computing metrics for IBRNet...'
29 | CUDA_VISIBLE_DEVICES=0 python eval.py \ 30 | --config ${eval_config_file} \ 31 | --expname $EXPNAME \ 32 | --rootdir $ROOT_DIR \ 33 | --ckpt_path ${checkpoint_path} \ 34 | --eval_scenes ${scenes[i]} 35 | done 36 | -------------------------------------------------------------------------------- /scripts/shell/eval_dbarf_ibr_collected_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | GPU_ID=$1 4 | 5 | HOME_DIR=$HOME 6 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 7 | cd $EVAL_CODE_DIR 8 | 9 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 10 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/train' 11 | CKPT_DIR=${HOME_DIR}/'Datasets/IBRNet/pretrained_model/dbarf_model_200000.pth' 12 | ITER=200000 13 | EXPNAME='eval_dbarf_ibr_collected_pretrain' 14 | 15 | scenes=("howardzhou_001_yellow_roses" "howardzhou_2_002_giraffe_plush" "howardzhou_2_003_yamaha_piano" \ 16 | "howardzhou_2_004_sony_camera" "howardzhou_2_010_Japanese_camilia" "howardzhou_2_014_silver_toyota_rav4_scaled_model" \ 17 | "howardzhou_2_018_dumbbell_jumprope" "howardzhou_2_021_hat_on_fur" "howardzhou_2_022_roses" \ 18 | "howardzhou_002_stonecrops" "howardzhou_003_stream" "howardzhou_004_wooden_moose" "howardzhou_005_ladder" \ 19 | "howardzhou_009_girl_head_bust" "howardzhou_012_ground_plant" "howardzhou_014_pink_camilia" \ 20 | "howardzhou_015_valves" "howardzhou_017_vending_machine_02" "howardzhou_019_red_18wheeler_truck" \ 21 | "howardzhou_020_yellow_beetle_and_rv" "howardzhou_021_crystal_light" "howardzhou_023_2_plush_toys" \ 22 | "howardzhou_024_android_figurine" "howardzhou_025_mug_with_pink_drink" "howardzhou_026_metal_alarm_clock" \ 23 | "qq1" "qq2" "qq3" "qq4" "qq5" "qq6" "qq7" "qq8" "qq9" "qq10" "qq11" "qq12" "qq13" "qq14" "qq15" "qq16" "qq17" \ 24 | "qq18" "qq19" "qq20" "qq21" "qq37" "qq40" "qq44" "zc02" "zc03" "zc04" "zc05" "zc06" "zc07" "zc08" "zc09" "zc10" \ 25 | "zc11" "zc12" "zc16" "zc17" "zc18") 26 | 27 | for((i=0;i<${#scenes[@]};i++)); 28 | do 29 | echo ${scenes[i]} 30 | 31 | eval_config_file=$CONFIG_DIR/eval_dbarf_llff.txt 32 | 33 | # (1) Compute metrics for NeRF. 34 | echo 'Computing metrics for NeRF...' 35 | CUDA_VISIBLE_DEVICES=$GPU_ID python eval_dbarf.py \ 36 | --config ${eval_config_file} \ 37 | --expname $EXPNAME \ 38 | --rootdir $ROOT_DIR \ 39 | --ckpt_path ${CKPT_DIR} \ 40 | --eval_dataset 'ibrnet_collected' \ 41 | --eval_scenes ${scenes[i]} 42 | 43 | # (2) Generate view graph. 44 | echo 'Generating view graph from pose estimator...' 
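    # dbarf_compute_poses.py exports the predicted and ground-truth camera poses for this
    # scene as g2o view graphs (the pred_view_graph.g2o / gt_view_graph.g2o paths referenced
    # below), presumably so that pose accuracy can be evaluated separately from rendering quality.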
45 | CUDA_VISIBLE_DEVICES=$GPU_ID python dbarf_compute_poses.py \ 46 | --config ${eval_config_file} \ 47 | --expname $EXPNAME \ 48 | --rootdir $ROOT_DIR \ 49 | --ckpt_path ${CKPT_DIR} \ 50 | --eval_dataset 'ibrnet_collected' \ 51 | --eval_scenes ${scenes[i]} 52 | 53 | pred_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'pred_view_graph.g2o' 54 | gt_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'gt_view_graph.g2o' 55 | updated_pred_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'updated_pred_view_graph.g2o' 56 | 57 | done 58 | -------------------------------------------------------------------------------- /scripts/shell/eval_dbarf_llff_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 # [200000, 260000] 4 | GPU_ID=$2 5 | 6 | export PYTHONDONTWRITEBYTECODE=1 7 | 8 | HOME_DIR=$HOME 9 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 10 | cd $EVAL_CODE_DIR 11 | 12 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 13 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval' 14 | CKPT_DIR=/home/chenyu/Datasets/IBRNet/eval/experimental_results/out/dbarf # ${HOME_DIR}/'Datasets/IBRNet/eval/out' 15 | 16 | EXPNAME='eval_dbarf_llff_finetune' 17 | 18 | scenes=("fern" "flower" "fortress" "horns" "leaves" "orchids" "room" "trex") 19 | 20 | for((i=0;i<${#scenes[@]};i++)); 21 | do 22 | echo ${scenes[i]} 23 | # For pretrained model. 24 | checkpoint_path=$HOME_DIR/Datasets/IBRNet/pretraining_dbarf/model/model_${ITER}.pth 25 | # ITER=200000 26 | # For fintuned checkpoint. 27 | ## checkpoint_path=${CKPT_DIR}/'finetune_dbarf_llff_'${scenes[i]}_200k/'model_'$ITER'.pth' 28 | eval_config_file=$CONFIG_DIR/eval_dbarf_llff.txt 29 | 30 | # (1) Compute metrics for NeRF. 31 | echo 'Computing metrics for NeRF...' 32 | CUDA_VISIBLE_DEVICES=$GPU_ID python eval_dbarf.py \ 33 | --config ${eval_config_file} \ 34 | --expname $EXPNAME \ 35 | --rootdir $ROOT_DIR \ 36 | --ckpt_path ${checkpoint_path} \ 37 | --eval_scenes ${scenes[i]} 38 | 39 | # (2) Generate view graph. 40 | echo 'Generating view graph from pose estimator...' 41 | CUDA_VISIBLE_DEVICES=$GPU_ID python dbarf_compute_poses.py \ 42 | --config ${eval_config_file} \ 43 | --expname $EXPNAME \ 44 | --rootdir $ROOT_DIR \ 45 | --ckpt_path ${checkpoint_path} \ 46 | --eval_scenes ${scenes[i]} 47 | 48 | pred_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'pred_view_graph.g2o' 49 | gt_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'gt_view_graph.g2o' 50 | 51 | done 52 | -------------------------------------------------------------------------------- /scripts/shell/eval_dbarf_scannet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 # [200000, 260000] 4 | GPU_ID=$2 5 | 6 | HOME_DIR=$HOME 7 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 8 | cd $EVAL_CODE_DIR 9 | 10 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 11 | ROOT_DIR=${HOME_DIR}/'Datasets/scannet' 12 | CKPT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval/out' 13 | 14 | EXPNAME='eval_dbarf_scannet_pred_pose' 15 | 16 | scenes=("scene0671_00" "scene0673_03" "scene0675_00" "scene0675_01" "scene0680_00" "scene0684_00" "scene0684_01") 17 | 18 | for((i=0;i<${#scenes[@]};i++)); 19 | do 20 | echo ${scenes[i]} 21 | checkpoint_path=${CKPT_DIR}/'finetune_dbarf_scannet_'${scenes[i]}/'model_'$ITER'.pth' 22 | eval_config_file=$CONFIG_DIR/eval_dbarf_llff.txt 23 | 24 | # (1) Compute metrics for NeRF. 25 | echo 'Computing metrics for DBARF...' 
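    # Step (1) below evaluates rendering for the current ScanNet scene with the finetuned
    # checkpoint; step (2) then exports predicted and ground-truth poses as g2o view graphs.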
26 | CUDA_VISIBLE_DEVICES=${GPU_ID} python eval_dbarf.py \ 27 | --config ${eval_config_file} \ 28 | --expname $EXPNAME \ 29 | --rootdir $ROOT_DIR \ 30 | --ckpt_path ${checkpoint_path} \ 31 | --eval_dataset 'scannet' \ 32 | --eval_scenes ${scenes[i]} 33 | 34 | # (2) Generate view graph. 35 | echo 'Generating view graph from pose estimator...' 36 | CUDA_VISIBLE_DEVICES=${GPU_ID} python dbarf_compute_poses.py \ 37 | --config ${eval_config_file} \ 38 | --expname $EXPNAME \ 39 | --rootdir $ROOT_DIR \ 40 | --ckpt_path ${checkpoint_path} \ 41 | --eval_dataset 'scannet' \ 42 | --eval_scenes ${scenes[i]} 43 | 44 | pred_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'pred_view_graph.g2o' 45 | gt_view_graph_path=$ROOT_DIR/$EXPNAME/${scenes[i]}_${ITER}/'gt_view_graph.g2o' 46 | done 47 | -------------------------------------------------------------------------------- /scripts/shell/eval_llff_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG_DIR='/home/chenyu/Projects/PoseNeRF/ibrnet/configs' 4 | EVAL_CODE_DIR='/home/chenyu/Projects/PoseNeRF/ibrnet/eval' 5 | cd $EVAL_CODE_DIR 6 | 7 | scenes=('fern' 'flower' 'fortress' 'horns' 'leaves' 'orchids' 'room' 'trex') 8 | 9 | for((i=0;i<${#scenes[@]};i++)); 10 | do 11 | CUDA_VISIBLE_DEVICES=2 python eval.py \ 12 | --config $CONFIG_DIR/eval_llff.txt \ 13 | --eval_scenes ${scenes[i]} 14 | done 15 | 16 | 17 | # CUDA_VISIBLE_DEVICES=0 python eval.py --config ../configs/eval_llff.txt --eval_scenes horns & 18 | # CUDA_VISIBLE_DEVICES=1 python eval.py --config ../configs/eval_llff.txt --eval_scenes trex & 19 | # CUDA_VISIBLE_DEVICES=2 python eval.py --config ../configs/eval_llff.txt --eval_scenes room & 20 | # CUDA_VISIBLE_DEVICES=3 python eval.py --config ../configs/eval_llff.txt --eval_scenes flower & 21 | # CUDA_VISIBLE_DEVICES=4 python eval.py --config ../configs/eval_llff.txt --eval_scenes orchids & 22 | # CUDA_VISIBLE_DEVICES=5 python eval.py --config ../configs/eval_llff.txt --eval_scenes leaves & 23 | # CUDA_VISIBLE_DEVICES=6 python eval.py --config ../configs/eval_llff.txt --eval_scenes fern & 24 | # CUDA_VISIBLE_DEVICES=7 python eval.py --config ../configs/eval_llff.txt --eval_scenes fortress & 25 | -------------------------------------------------------------------------------- /scripts/shell/render_coarse_llff_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 # [200000, 260000] 4 | GPU_ID=$2 5 | 6 | export PYTHONDONTWRITEBYTECODE=1 7 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 8 | 9 | HOME_DIR=$HOME 10 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 11 | cd $EVAL_CODE_DIR 12 | 13 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 14 | ROOT_DIR='/media/chenyu/Data/datasets/IBRNet/eval' 15 | CKPT_DIR=$ROOT_DIR'/out'/'coarse_ibr' 16 | 17 | EXPNAME='eval_coarse_llff_finetune' 18 | 19 | scenes=('fern' 'flower' 'fortress' 'horns' 'leaves' 'orchids' 'room' 'trex') 20 | 21 | for((i=0;i<${#scenes[@]};i++)); 22 | do 23 | echo ${scenes[i]} 24 | checkpoint_path=${CKPT_DIR}/'finetune_coarse_ibr_llff_'${scenes[i]}/'model_'$ITER'.pth' 25 | eval_config_file=$CONFIG_DIR/eval_coarse_llff.txt 26 | 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python render_llff_video.py \ 28 | --config ${eval_config_file} \ 29 | --expname $EXPNAME \ 30 | --rootdir $ROOT_DIR \ 31 | --ckpt_path ${checkpoint_path} \ 32 | --eval_scenes ${scenes[i]} 33 | done 34 | --------------------------------------------------------------------------------
/scripts/shell/render_dbarf_llff_all.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ITER=$1 # [200000, 260000] 4 | GPU_ID=$2 5 | 6 | export PYTHONDONTWRITEBYTECODE=1 7 | export CUDA_VISIBLE_DEVICES=${GPU_ID} 8 | 9 | HOME_DIR=$HOME 10 | EVAL_CODE_DIR=${HOME_DIR}/'Projects/dbarf/eval' 11 | cd $EVAL_CODE_DIR 12 | 13 | CONFIG_DIR=${HOME_DIR}/'Projects/dbarf/configs' 14 | ROOT_DIR='/media/chenyu/Data/datasets/IBRNet/eval' 15 | CKPT_DIR=$ROOT_DIR'/out'/'dbarf' 16 | 17 | EXPNAME='eval_dbarf_llff_finetune' 18 | 19 | scenes=('fern' 'flower' 'fortress' 'horns' 'leaves' 'orchids' 'room' 'trex') 20 | 21 | for((i=0;i<${#scenes[@]};i++)); 22 | do 23 | echo ${scenes[i]} 24 | checkpoint_path=${CKPT_DIR}/'finetune_dbarf_llff_'${scenes[i]}_200k/'model_'$ITER'.pth' 25 | eval_config_file=$CONFIG_DIR/eval_dbarf_llff.txt 26 | 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python render_dbarf_llff_video.py \ 28 | --config ${eval_config_file} \ 29 | --expname $EXPNAME \ 30 | --rootdir $ROOT_DIR \ 31 | --ckpt_path ${checkpoint_path} \ 32 | --eval_scenes ${scenes[i]} 33 | done 34 | -------------------------------------------------------------------------------- /scripts/shell/train_coarse_ibrnet.sh: -------------------------------------------------------------------------------- 1 | TASK_TYPE=$1 # {'pretrain', 'finetune'} 2 | DISTRIBUTED=$2 # {'True', 'False'} 3 | CUDA_IDS=$3 # {'0,1,2,...'} 4 | 5 | export PYTHONDONTWRITEBYTECODE=1 6 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS} 7 | 8 | HOME_DIR=$HOME 9 | echo $HOME_DIR 10 | 11 | if [ $TASK_TYPE = 'pretrain' ] 12 | then 13 | CONFIG_FILENAME="pretrain_coarse_ibr" 14 | else 15 | CONFIG_FILENAME="finetune_coarse_ibr" 16 | fi 17 | 18 | CODE_DIR=${HOME_DIR}'/Projects/dbarf' 19 | cd $CODE_DIR 20 | 21 | if [ $DISTRIBUTED = "True" ]; then 22 | echo "Training in distributed mode" 23 | python -m torch.distributed.launch \ 24 | --nproc_per_node=2 train_ibrnet.py \ 25 | --config configs/$CONFIG_FILENAME.txt 26 | else 27 | echo "Training on single machine" 28 | python -m train_ibrnet \ 29 | --config configs/$CONFIG_FILENAME.txt 30 | fi 31 | -------------------------------------------------------------------------------- /scripts/shell/train_dbarf.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TASK_TYPE=$1 # {'pretrain', 'finetune'} 4 | DISTRIBUTED=$2 # {'True', 'False'} 5 | CUDA_IDS=$3 # {'0,1,2,...'} 6 | 7 | export PYTHONDONTWRITEBYTECODE=1 8 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS} 9 | 10 | HOME_DIR=$HOME 11 | echo $HOME_DIR 12 | 13 | if [ $TASK_TYPE = 'pretrain' ] 14 | then 15 | CONFIG_FILENAME="pretrain_dbarf" 16 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/train' 17 | else 18 | CONFIG_FILENAME="finetune_dbarf" 19 | ROOT_DIR=${HOME_DIR}/'Datasets/IBRNet/eval' 20 | fi 21 | 22 | CODE_DIR=${HOME_DIR}'/Projects/dbarf' 23 | cd $CODE_DIR 24 | 25 | if [ $DISTRIBUTED = "True" ]; then 26 | echo "Training in distributed mode" 27 | python -m torch.distributed.launch \ 28 | --nproc_per_node=2 train_dbarf.py \ 29 | --config configs/$CONFIG_FILENAME.txt \ 30 | --rootdir $ROOT_DIR 31 | else 32 | echo "Training on single machine" 33 | python -m train_dbarf \ 34 | --config configs/$CONFIG_FILENAME.txt \ 35 | --rootdir $ROOT_DIR 36 | fi 37 | -------------------------------------------------------------------------------- /scripts/shell/train_ibrnet.sh: -------------------------------------------------------------------------------- 1 | TASK_TYPE=$1 # {'pretrain', 'finetune'} 
2 | DISTRIBUTED=$2 # {'True', 'False'} 3 | CUDA_IDS=$3 # {'0,1,2,...'} 4 | 5 | export PYTHONDONTWRITEBYTECODE=1 6 | export CUDA_VISIBLE_DEVICES=${CUDA_IDS} 7 | 8 | HOME_DIR=$HOME 9 | echo $HOME_DIR 10 | 11 | if [ $TASK_TYPE = 'pretrain' ] 12 | then 13 | CONFIG_FILENAME="pretrain" 14 | else 15 | CONFIG_FILENAME="finetune_llff" 16 | fi 17 | 18 | CODE_DIR=${HOME_DIR}'/Projects/dbarf' 19 | cd $CODE_DIR 20 | 21 | if [ $DISTRIBUTED = "True" ]; then 22 | echo "Training in distributed mode" 23 | python -m torch.distributed.launch \ 24 | --nproc_per_node=2 train_ibrnet.py \ 25 | --config configs/$CONFIG_FILENAME.txt 26 | else 27 | echo "Training on single machine" 28 | python -m train_ibrnet \ 29 | --config configs/$CONFIG_FILENAME.txt 30 | fi 31 | --------------------------------------------------------------------------------