├── pgdvs ├── models │ ├── gnt │ │ ├── __init__.py │ │ ├── common.py │ │ └── model.py │ ├── cotracker │ │ ├── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ │ ├── __init__.py │ │ │ │ ├── cotracker │ │ │ │ │ └── __init__.py │ │ │ │ ├── embeddings.py │ │ │ │ └── model_utils.py │ │ │ └── build_cotracker.py │ │ ├── interface.py │ │ └── predictor.py │ └── tapnet │ │ ├── utils │ │ └── transforms.py │ │ ├── interface.py │ │ └── models │ │ └── tsm_utils.py ├── utils │ ├── nsff_lpips │ │ ├── weights │ │ │ ├── v0.0 │ │ │ │ ├── alex.pth │ │ │ │ ├── vgg.pth │ │ │ │ └── squeeze.pth │ │ │ └── v0.1 │ │ │ │ ├── alex.pth │ │ │ │ ├── vgg.pth │ │ │ │ └── squeeze.pth │ │ ├── base_model.py │ │ ├── __init__.py │ │ └── pretrained_networks.py │ ├── pytorch3d_utils.py │ ├── flow_vis_utils.py │ ├── rendering.py │ ├── comm.py │ ├── dycheck │ │ ├── camera.py │ │ └── misc.py │ └── vis_utils.py ├── datasets │ └── combined.py ├── preprocess │ ├── dycheck_mono_info_extractor.py │ ├── convert_casual_sam_output.py │ ├── convert_dyn_video_depth_output.py │ ├── colmap_processor.py │ └── convert_colmap_output.py ├── renderers │ ├── pgdvs_renderer_base.py │ └── st_geo_renderer.py └── engines │ ├── abstract.py │ └── visualizer_pgdvs.py ├── configs ├── tracker │ ├── dummy.yaml │ ├── cotracker.yaml │ └── tapnet.yaml ├── static_renderer │ ├── dummy.yaml │ ├── geo.yaml │ └── gnt.yaml ├── engine │ ├── trainer_pgdvs.yaml │ ├── evaluator_pgdvs.yaml │ └── visualizer_pgdvs.yaml ├── model │ └── pgdvs_renderer.yaml ├── pgdvs.yaml ├── _basic.yaml └── dataset │ └── combined.yaml ├── media └── teaser.jpg ├── envs ├── requirements_jax.txt └── pgdvs.yaml ├── CONTRIBUTING.md ├── third_parties └── flowformer_6ba7ea82.patch ├── scripts ├── download_ckpts.sh ├── preprocess │ └── install_colmap.sh └── visualize.sh ├── .gitignore ├── LICENSE ├── docs ├── BENCHMARK_iPhone.md ├── IN_THE_WILD.md └── BENCHMARK_NVIDIA.md ├── CODE_OF_CONDUCT.md └── README.md /pgdvs/models/gnt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/tracker/dummy.yaml: -------------------------------------------------------------------------------- 1 | _target_: null -------------------------------------------------------------------------------- /configs/static_renderer/dummy.yaml: -------------------------------------------------------------------------------- 1 | _target_: null -------------------------------------------------------------------------------- /media/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/media/teaser.jpg -------------------------------------------------------------------------------- /configs/static_renderer/geo.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.renderers.st_geo_renderer.StaticGeoPointRenderer 2 | 3 | model_cfg: 4 | ckpt_path: null -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/pgdvs/utils/nsff_lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/weights/v0.0/vgg.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/pgdvs/utils/nsff_lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/pgdvs/utils/nsff_lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/pgdvs/utils/nsff_lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/pgdvs/utils/nsff_lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-pgdvs/HEAD/pgdvs/utils/nsff_lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /pgdvs/models/gnt/common.py: -------------------------------------------------------------------------------- 1 | HUGE_NUMBER = 1e10 2 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 3 | 4 | PIX_COORD_DECIMALS = 1 5 | -------------------------------------------------------------------------------- /configs/engine/trainer_pgdvs.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.engines.trainer_pgdvs.PGDVSTrainer 2 | 3 | engine_cfg: 4 | 5 | for_overfit: false 6 | 7 | lr_init: 0.0 -------------------------------------------------------------------------------- /configs/tracker/cotracker.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.models.cotracker.interface.CoTrackerInterface 2 | 3 | ckpt_path: 'cotracker_stride_4_wind_8.pth' 4 | 5 | query_chunk_size: 4096 -------------------------------------------------------------------------------- /configs/model/pgdvs_renderer.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.renderers.pgdvs_renderer.PGDVSRenderer 2 | 3 | softsplat_metric_abs_alpha: 100.0 # this value needs to be reasonaly large 4 | flag_debug: false -------------------------------------------------------------------------------- /configs/tracker/tapnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.models.tapnet.interface.TAPNetInterface 2 | 3 | ckpt_path: 'tapir_checkpoint_panning.npy' 4 | 5 | query_chunk_size: 4096 6 | 7 | flag_keep_raw_res: false -------------------------------------------------------------------------------- /pgdvs/models/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/models/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/models/core/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /configs/static_renderer/gnt.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.models.gnt.renderer.BaseRenderer 2 | 3 | model_cfg: 4 | _target_: pgdvs.models.gnt.model.GNTModel 5 | 6 | # https://github.com/VITA-Group/GNT/blob/7b63996cb807dbb5c95ab6898e8093996588e73a/configs/gnt_full.txt 7 | 8 | netwidth: 64 9 | transformer_depth: 8 10 | coarse_feat_dim: 32 11 | fine_feat_dim: 32 12 | single_net: True 13 | 14 | posenc_max_freq_log2: 9 15 | pos_enc_n_freqs: 10 16 | view_enc_n_freqs: 10 17 | 18 | ckpt_path: null -------------------------------------------------------------------------------- /envs/requirements_jax.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/deepmind/einshape 2 | jax==0.4.13 3 | https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.13+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl 4 | flax==0.7.1 5 | jaxline==0.0.5 6 | ml_collections 7 | opencv-python==4.7.0.68 8 | optax==0.1.4 9 | chex==0.1.82 10 | dm-haiku==0.0.10 11 | dm-tree==0.1.8 12 | typing_extensions==4.7.1 13 | 14 | # # The following combination also works. Feel free to choose it if it fits your CUDA version. 15 | # jax==0.4.7 16 | # https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl 17 | # flax==0.5.1 18 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 
6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /third_parties/flowformer_6ba7ea82.patch: -------------------------------------------------------------------------------- 1 | diff --git a/core/FlowFormer/LatentCostFormer/twins.py b/core/FlowFormer/LatentCostFormer/twins.py 2 | index e002ba2..af266c9 100644 3 | --- a/core/FlowFormer/LatentCostFormer/twins.py 4 | +++ b/core/FlowFormer/LatentCostFormer/twins.py 5 | @@ -22,7 +22,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 6 | from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import Attention 9 | -from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg 10 | +# from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg 11 | from .attention import MultiHeadAttention, LinearPositionEmbeddingSine 12 | from utils.utils import coords_grid, bilinear_sampler, upflow8 13 | 14 | -------------------------------------------------------------------------------- /configs/pgdvs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _basic # inherit configs from _basic.yaml 3 | - engine: trainer 4 | - model: pgdvs_renderer 5 | - static_renderer: gnt 6 | - tracker: dummy 7 | - dataset: combined 8 | # put __self__ in the end will have any configurations specified in this file 9 | # OVERWRITE same configurations appeared in the defaults list above. 10 | - _self_ 11 | 12 | hydra: 13 | # output_subdir: null # null disables creation of .hydra for saving configuration 14 | job: 15 | chdir: true 16 | name: "pgdvs" 17 | config: 18 | override_dirname: 19 | kv_sep: "=" # original is = 20 | item_sep: "-" # original is , 21 | exclude_keys: 22 | - config_name 23 | 24 | run: 25 | # set the saving / loading directory 26 | dir: "experiments/\ 27 | ${hydra.job.name}/\ 28 | ${now:%Y%m%d}_${now:%H%M%S%f}/\ 29 | " 30 | # ${hydra.job.override_dirname}/\ 31 | 32 | job_logging: 33 | handlers: 34 | file: 35 | filename: pgdvs.log -------------------------------------------------------------------------------- /configs/_basic.yaml: -------------------------------------------------------------------------------- 1 | seed: 1234 2 | 3 | verbose: true 4 | 5 | percision: float32 6 | 7 | distributed: true 8 | 9 | resume: "none" # [none, train, eval, vis] 10 | resume_epoch: null 11 | resume_dir: null 12 | 13 | series_eval: false 14 | 15 | max_epochs: 1000 16 | 17 | rgb_range: "0_1" # ["-1_1", "0_1", "0_255"] 18 | 19 | use_grad_clip: false 20 | grad_clip_val: null 21 | 22 | log_every_iter: 10 23 | save_every_iter: -1 24 | save_every_epoch: 50 25 | 26 | vis_every_epoch: 50 27 | vis_every_iter: -1 28 | 29 | n_ckpts_local: -1 30 | 31 | train_batch_size: 1 32 | eval_batch_size: 1 33 | eval_save_individual: true 34 | n_max_eval_data: -1 # negative values mean evaluting on all data 35 | 36 | n_dataloader_workers: 4 37 | 38 | engine: evaluator_pgdvs 39 | model: ??? 
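# `???` is OmegaConf's MISSING marker: Hydra will refuse to run until this value is provided, here by the `- model: pgdvs_renderer` entry in the defaults list of configs/pgdvs.yaml, or by a command-line override such as `model=pgdvs_renderer`.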
40 | 41 | dataset: combined 42 | dataset_max_hw: -1 43 | dataset_flow_consist_thres: 1.0 44 | 45 | n_src_views_spatial: 10 46 | n_src_views_temporal_track_one_side: 5 47 | 48 | flag_debug: false 49 | 50 | vis_specifics: 51 | n_render_frames: 200 52 | vis_center_time: 50 53 | vis_time_interval: 10 54 | vis_bt_max_disp: 64 55 | -------------------------------------------------------------------------------- /envs/pgdvs.yaml: -------------------------------------------------------------------------------- 1 | name: pgdvs 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.10 9 | - pytorch::pytorch=2.0.1 10 | - pytorch::torchaudio=2.0.2 11 | - pytorch::torchvision=0.15.2 12 | - pytorch::pytorch-cuda=11.8 13 | - scikit-image=0.20.0 14 | - ffmpeg=4.3 15 | - imageio-ffmpeg=0.4.8 16 | - jupyterlab=4.0.1 17 | - matplotlib=3.7.1 18 | - numpy=1.25.2 19 | - scipy=1.10.1 20 | - pillow=9.4.0 21 | - setuptools=67.7.2 22 | - tqdm=4.65.0 23 | - cmake=3.22.1 24 | - mediapy=1.1.8 25 | - pip==23.2.1 26 | - pip: 27 | - lpips==0.1.4 28 | - opencv-python==4.7.0.68 29 | - protobuf==3.20.3 30 | - pyyaml==6.0.1 31 | - tensorboard==2.13.0 32 | - hydra-core==1.2.0 33 | - einops==0.6.0 34 | - trimesh==3.21.7 35 | - joblib==1.2.0 36 | - cupy-cuda11x==12.1.0 37 | - loguru==0.7.0 38 | - diffdist==0.1 39 | - timm==0.6.12 40 | - kornia==0.6.12 41 | - gdown==4.6.0 # https://github.com/wkentaro/gdown/issues/43#issuecomment-1892954390 42 | - awscli==1.29.46 43 | - boto3==1.28.46 44 | - scikit-learn==1.3.1 45 | - immutabledict==3.0.0 -------------------------------------------------------------------------------- /configs/engine/evaluator_pgdvs.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.engines.evaluator_pgdvs.PGDVSEvaluator 2 | 3 | engine_cfg: 4 | 5 | for_overfit: false 6 | 7 | lr_init: 0.0 8 | 9 | quant_type: "nvidia" 10 | 11 | render_cfg: 12 | render_stride: 1 13 | 14 | chunk_size: 1024 15 | sample_inv_uniform: true 16 | n_coarse_samples_per_ray: 256 17 | n_fine_samples_per_ray: 0 18 | 19 | pure_gnt: false 20 | pure_gnt_with_dyn_mask: false 21 | 22 | gnt_use_dyn_mask: false 23 | gnt_use_masked_spatial_src: true 24 | 25 | mask_oob_n_proj_thres: 1 26 | mask_invalid_n_proj_thres: 4 27 | 28 | st_pcl_remove_outlier: false 29 | st_pcl_outlier_knn: 50 30 | st_pcl_outlier_std_thres: 0.1 31 | 32 | st_render_pcl_pt_radius: 0.01 33 | st_render_pcl_pts_per_pixel: 1 34 | 35 | dyn_pcl_remove_outlier: false 36 | dyn_pcl_outlier_knn: 50 37 | dyn_pcl_outlier_std_thres: 0.1 38 | 39 | dyn_render_type: "softsplat" # ["softsplat", "mesh", "pcl"] 40 | 41 | dyn_render_pcl_pt_radius: 0.01 42 | dyn_render_pcl_pts_per_pixel: 1 43 | 44 | dyn_render_track_temporal: "none" # ["none", "no_tgt"] 45 | 46 | dyn_pcl_track_track2base_thres_mult: 50 47 | 48 | dyn_render_use_flow_consistency: false -------------------------------------------------------------------------------- /configs/engine/visualizer_pgdvs.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.engines.visualizer_pgdvs.PGDVSVisualizer 2 | 3 | engine_cfg: 4 | 5 | for_overfit: false 6 | 7 | lr_init: 0.0 8 | 9 | quant_type: "nvidia" 10 | 11 | render_cfg: 12 | render_stride: 1 13 | 14 | chunk_size: 1024 15 | sample_inv_uniform: true 16 | n_coarse_samples_per_ray: 256 17 | n_fine_samples_per_ray: 0 18 | 19 | pure_gnt: false 20 | pure_gnt_with_dyn_mask: false 21 | 22 | gnt_use_dyn_mask: false 23 | gnt_use_masked_spatial_src: 
true 24 | 25 | mask_oob_n_proj_thres: 1 26 | mask_invalid_n_proj_thres: 4 27 | 28 | st_pcl_remove_outlier: false 29 | st_pcl_outlier_knn: 50 30 | st_pcl_outlier_std_thres: 0.1 31 | 32 | st_render_pcl_pt_radius: 0.01 33 | st_render_pcl_pts_per_pixel: 1 34 | 35 | dyn_pcl_remove_outlier: false 36 | dyn_pcl_outlier_knn: 50 37 | dyn_pcl_outlier_std_thres: 0.1 38 | 39 | dyn_render_type: "softsplat" # ["softsplat", "mesh", "pcl"] 40 | 41 | dyn_render_pcl_pt_radius: 0.01 42 | dyn_render_pcl_pts_per_pixel: 1 43 | 44 | dyn_render_track_temporal: "none" # ["none", "no_tgt"] 45 | 46 | dyn_pcl_track_track2base_thres_mult: 50 47 | 48 | dyn_render_use_flow_consistency: false 49 | -------------------------------------------------------------------------------- /scripts/download_ckpts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | { 3 | 4 | DATA_ROOT="$1" 5 | FLAG_ORIGINAL="$2" 6 | 7 | printf '\nDownload checkpoints\n' 8 | 9 | printf '\nDATA_ROOT: %s' ${DATA_ROOT} 10 | printf '\nFLAG_ORIGINAL: %s\n\n' ${FLAG_ORIGINAL} 11 | 12 | mkdir -p ${DATA_ROOT} 13 | 14 | eval "$(conda shell.bash hook)" 15 | conda activate pgdvs 16 | 17 | download_start="$(date -u +%s)" 18 | 19 | if [ "${FLAG_ORIGINAL}" == "1" ]; then 20 | # GNT 21 | if [ ! -f ${DATA_ROOT}/gnt/generalized_model_720000.pth ]; then 22 | gdown 1AMN0diPeHvf2fw53IO5EE2Qp4os5SkoX -O ${DATA_ROOT}/gnt/ 23 | fi 24 | 25 | # TAPIR 26 | if [ ! -f ${DATA_ROOT}/tapnet/tapir_checkpoint_panning.npy ]; then 27 | wget https://storage.googleapis.com/dm-tapnet/tapir_checkpoint_panning.npy -P ${DATA_ROOT}/tapnet/ 28 | fi 29 | 30 | # CoTracker 31 | if [ ! -f ${DATA_ROOT}/cotracker/cotracker_stride_4_wind_8.pth ]; then 32 | wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth -P ${DATA_ROOT}/cotracker/ 33 | wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_12.pth -P ${DATA_ROOT}/cotracker/ 34 | wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_8_wind_16.pth -P ${DATA_ROOT}/cotracker/ 35 | fi 36 | 37 | elif [ "${FLAG_ORIGINAL}" == "0" ]; then 38 | wget https://github.com/apple/ml-pgdvs/releases/download/v0.1/pgdvs_ckpts.zip -P ${DATA_ROOT}/ 39 | unzip ${DATA_ROOT}/pgdvs_ckpts.zip -d ${DATA_ROOT}/ 40 | 41 | else 42 | 43 | printf '\nHello\n' 44 | 45 | fi 46 | 47 | download_end="$(date -u +%s)" 48 | download_elapsed="$(($download_end-$download_start))" 49 | printf "\nDownload time elapsed %f\n" $download_elapsed 50 | printf "\n\n" 51 | 52 | exit; 53 | } -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/base_model.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_exp/models/base_model.py 2 | 3 | 4 | import os 5 | import torch 6 | from torch.autograd import Variable 7 | from pdb import set_trace as st 8 | 9 | 10 | class BaseModel(torch.nn.Module): 11 | # def __init__(self): 12 | # pass 13 | 14 | def name(self): 15 | return "BaseModel" 16 | 17 | def initialize(self, use_gpu=True, gpu_ids=[0]): 18 | self.use_gpu = use_gpu 19 | self.gpu_ids = gpu_ids 20 | 21 | def forward(self): 22 | pass 23 | 24 | def get_image_paths(self): 25 | pass 26 | 27 | def optimize_parameters(self): 28 | pass 29 | 30 | def get_current_visuals(self): 31 | return self.input 32 | 33 | def get_current_errors(self): 34 | return {} 35 | 36 | def save(self, label): 37 | pass 38 | 39 | # 
helper saving function that can be used by subclasses 40 | def save_network(self, network, path, network_label, epoch_label): 41 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 42 | save_path = os.path.join(path, save_filename) 43 | torch.save(network.state_dict(), save_path) 44 | 45 | # helper loading function that can be used by subclasses 46 | def load_network(self, network, network_label, epoch_label): 47 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 48 | save_path = os.path.join(self.save_dir, save_filename) 49 | print("Loading network from %s" % save_path) 50 | network.load_state_dict(torch.load(save_path)) 51 | 52 | def update_learning_rate(): 53 | pass 54 | 55 | def get_image_paths(self): 56 | return self.image_paths 57 | 58 | def save_done(self, flag=False): 59 | np.save(os.path.join(self.save_dir, "done_flag"), flag) 60 | np.savetxt( 61 | os.path.join(self.save_dir, "done_flag"), 62 | [ 63 | flag, 64 | ], 65 | fmt="%i", 66 | ) 67 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/models/build_cotracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from pgdvs.models.cotracker.models.core.cotracker.cotracker import CoTracker 10 | 11 | 12 | def build_cotracker( 13 | checkpoint: str, 14 | ): 15 | if checkpoint is None: 16 | return build_cotracker_stride_4_wind_8() 17 | model_name = checkpoint.split("/")[-1].split(".")[0] 18 | if model_name == "cotracker_stride_4_wind_8": 19 | return build_cotracker_stride_4_wind_8(checkpoint=checkpoint) 20 | elif model_name == "cotracker_stride_4_wind_12": 21 | return build_cotracker_stride_4_wind_12(checkpoint=checkpoint) 22 | elif model_name == "cotracker_stride_8_wind_16": 23 | return build_cotracker_stride_8_wind_16(checkpoint=checkpoint) 24 | else: 25 | raise ValueError(f"Unknown model name {model_name}") 26 | 27 | 28 | # model used to produce the results in the paper 29 | def build_cotracker_stride_4_wind_8(checkpoint=None): 30 | return _build_cotracker( 31 | stride=4, 32 | sequence_len=8, 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_cotracker_stride_4_wind_12(checkpoint=None): 38 | return _build_cotracker( 39 | stride=4, 40 | sequence_len=12, 41 | checkpoint=checkpoint, 42 | ) 43 | 44 | 45 | # the fastest model 46 | def build_cotracker_stride_8_wind_16(checkpoint=None): 47 | return _build_cotracker( 48 | stride=8, 49 | sequence_len=16, 50 | checkpoint=checkpoint, 51 | ) 52 | 53 | 54 | def _build_cotracker( 55 | stride, 56 | sequence_len, 57 | checkpoint=None, 58 | ): 59 | cotracker = CoTracker( 60 | stride=stride, 61 | S=sequence_len, 62 | add_space_attn=True, 63 | space_depth=6, 64 | time_depth=6, 65 | ) 66 | if checkpoint is not None: 67 | with open(checkpoint, "rb") as f: 68 | state_dict = torch.load(f, map_location="cpu") 69 | if "model" in state_dict: 70 | state_dict = state_dict["model"] 71 | cotracker.load_state_dict(state_dict) 72 | return cotracker 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .vscode 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | 
# C extensions 10 | *.so 11 | *.o 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /configs/dataset/combined.yaml: -------------------------------------------------------------------------------- 1 | _target_: pgdvs.datasets.combined.CombinedDataset 2 | 3 | dataset_list: 4 | train: null 5 | eval: ["nvidia_eval"] 6 | vis: ["nvidia_vis"] 7 | 8 | data_root: null 9 | 10 | max_hw: ${dataset_max_hw} 11 | rgb_range: ${rgb_range} 12 | use_aug: false 13 | 14 | dataset_specifics: 15 | 16 | nvidia_eval: 17 | scene_ids: null 18 | raw_data_dir: "nvidia_long" 19 | depth_data_dir: "nvidia_long/Depths" 20 | mask_data_dir: "nvidia_long_flow_mask" 21 | flow_data_dir: "nvidia_long_flow_mask" 22 | n_src_views_spatial: ${n_src_views_spatial} 23 | n_src_views_temporal_track_one_side: ${n_src_views_temporal_track_one_side} 24 | use_zoe_depth: "none" # ["none", "moe", "nk_share_med", "nk_share_trim", "nk_indiv_med", "nk_indiv_trim"] 25 | zoe_depth_data_path: "nvidia_long_zoedepth.zip" 26 | flow_consist_thres: ${dataset_flow_consist_thres} 27 | 28 | nvidia_eval_pure_geo: 29 | scene_ids: null 30 | raw_data_dir: "nvidia_long" 31 | depth_data_dir: "nvidia_long/Depths" 32 | mask_data_dir: 
"nvidia_long_flow_mask" 33 | flow_data_dir: "nvidia_long_flow_mask" 34 | flow_consist_thres: ${dataset_flow_consist_thres} 35 | 36 | nvidia_vis: 37 | scene_ids: null 38 | raw_data_dir: "nvidia_long" 39 | depth_data_dir: "nvidia_long/Depths" 40 | mask_data_dir: "nvidia_long_flow_mask" 41 | flow_data_dir: "nvidia_long_flow_mask" 42 | n_src_views_spatial: ${n_src_views_spatial} 43 | n_render_frames: ${vis_specifics.n_render_frames} 44 | vis_center_time: ${vis_specifics.vis_center_time} 45 | vis_time_interval: ${vis_specifics.vis_time_interval} 46 | vis_bt_max_disp: ${vis_specifics.vis_bt_max_disp} 47 | flow_consist_thres: ${dataset_flow_consist_thres} 48 | 49 | dycheck_iphone_eval: 50 | scene_ids: null 51 | raw_data_dir: "iphone" 52 | mask_data_dir: "dycheck_iphone_flow_mask" 53 | flow_data_dir: "dycheck_iphone_flow_mask" 54 | n_src_views_spatial: ${n_src_views_spatial} 55 | spatial_src_view_type: "clustered" # [closest_wo_temporal, closest_with_temporal, clustered] 56 | n_src_views_spatial_cluster: null 57 | n_src_views_temporal_track_one_side: ${n_src_views_temporal_track_one_side} 58 | flow_consist_thres: ${dataset_flow_consist_thres} 59 | 60 | mono_vis: 61 | scene_ids: null 62 | n_src_views_spatial: ${n_src_views_spatial} 63 | n_render_frames: ${vis_specifics.n_render_frames} 64 | vis_center_time: ${vis_specifics.vis_center_time} 65 | vis_time_interval: ${vis_specifics.vis_time_interval} 66 | vis_bt_max_disp: ${vis_specifics.vis_bt_max_disp} 67 | flow_consist_thres: ${dataset_flow_consist_thres} 68 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. 
APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | -------------------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH PGDVS: 43 | 44 | The PGDVS software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the "Acknowledgements" sections in README. 46 | -------------------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /pgdvs/datasets/combined.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from pgdvs.datasets.nvidia_eval import NvidiaDynEvaluationDataset 8 | from pgdvs.datasets.nvidia_eval_pure_geo import NvidiaDynPureGeoEvaluationDataset 9 | from pgdvs.datasets.nvidia_vis import NvidiaDynVisualizationDataset 10 | from pgdvs.datasets.dycheck_iphone_eval import DyCheckiPhoneEvaluationDataset 11 | from pgdvs.datasets.mono_vis import MonoVisualizationDataset 12 | 13 | 14 | DEBUG_DIR = pathlib.Path(__file__).absolute().parent.parent.parent / "debug" 15 | DEBUG_DIR.mkdir(parents=True, exist_ok=True) 16 | 17 | 18 | DATASET_DICT = { 19 | "nvidia_eval": NvidiaDynEvaluationDataset, 20 | "nvidia_eval_pure_geo": NvidiaDynPureGeoEvaluationDataset, 21 | "nvidia_vis": NvidiaDynVisualizationDataset, 22 | "dycheck_iphone_eval": DyCheckiPhoneEvaluationDataset, 23 | "mono_vis": MonoVisualizationDataset, 24 | } 25 | 26 | 27 | LOGGER = logging.getLogger(__name__) 28 | 29 | 30 | class CombinedDataset(Dataset): 31 | def __init__( 32 | self, 33 | *, 34 | data_root, 35 | dataset_list, 36 | mode="train", 37 | max_hw=-1, 38 | rgb_range="0_1", 39 | use_aug=False, 40 | dataset_specifics={}, 41 | ): 42 | if mode in ["eval", "vis"]: 43 | use_aug = False 44 | 45 | assert mode in ["train", "eval", "vis"], mode 46 | cur_dataset_list = dataset_list[mode] 47 | 48 | self.datasets = {} 49 | for dataset_name in cur_dataset_list: 50 | self.datasets[dataset_name] = DATASET_DICT[dataset_name]( 51 | data_root=data_root, 52 | max_hw=max_hw, 53 | rgb_range=rgb_range, 54 | use_aug=use_aug, 55 | mode=mode, 56 | **dataset_specifics[dataset_name], 57 | ) 58 | 59 | data_cnt = 0 60 | self.data_idxs = [] 61 | sorted_dataset_names = sorted( 62 | list(self.datasets.keys()) 63 | ) # to ensure all workers have the same order 64 | for tmp_name in sorted_dataset_names: 65 | for tmp_i in range(len(self.datasets[tmp_name])): 66 | self.data_idxs.append((data_cnt, tmp_name, tmp_i)) 67 | data_cnt = data_cnt + 1 68 | 69 | def __len__(self): 70 | cur_len = np.sum([len(self.datasets[_]) for _ in 
self.datasets]) 71 | assert cur_len == len(self.data_idxs), f"{cur_len}, {len(self.data_idxs)}" 72 | return cur_len 73 | 74 | def __getitem__(self, index): 75 | global_i, dataset_name, dataset_i = self.data_idxs[index] 76 | assert index == global_i, f"{index}, {global_i}" 77 | 78 | ret_dict = self.datasets[dataset_name][dataset_i] 79 | 80 | return ret_dict 81 | -------------------------------------------------------------------------------- /pgdvs/utils/pytorch3d_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch3d 3 | 4 | 5 | def cameras_from_opencv_to_pytorch3d( 6 | R: torch.Tensor, 7 | tvec: torch.Tensor, 8 | camera_matrix: torch.Tensor, 9 | image_size: torch.Tensor, 10 | ): 11 | # Ref: https://github.com/facebookresearch/pytorch3d/blob/57f6e79280e78b6e8308f750e64d32984ddeaba4/pytorch3d/renderer/camera_conversions.py#L19 12 | 13 | focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1) 14 | principal_point = camera_matrix[:, :2, 2] 15 | 16 | # Retype the image_size correctly and flip to width, height. 17 | image_size_wh = image_size.to(R).flip(dims=(1,)) 18 | 19 | # Screen to NDC conversion: 20 | # For non square images, we scale the points such that smallest side 21 | # has range [-1, 1] and the largest side has range [-u, u], with u > 1. 22 | # This convention is consistent with the PyTorch3D renderer, as well as 23 | # the transformation function `get_ndc_to_screen_transform`. 24 | scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0 25 | scale = scale.expand(-1, 2) 26 | c0 = image_size_wh / 2.0 27 | 28 | # Get the PyTorch3D focal length and principal point. 29 | focal_pytorch3d = focal_length / scale 30 | p0_pytorch3d = -(principal_point - c0) / scale 31 | 32 | # For R, T we flip x, y axes (opencv screen space has an opposite 33 | # orientation of screen axes). 34 | # We also transpose R (opencv multiplies points from the opposite=left side). 
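    # Equivalently: OpenCV maps world points with column vectors, x_cam = R @ X_world + tvec, while
    # PyTorch3D uses row vectors, x_cam = X_world @ R_pytorch3d + T_pytorch3d, hence R_pytorch3d = R^T.
    # Negating the first two columns of R^T (and the first two entries of tvec) below then flips the
    # x/y screen axes to match PyTorch3D's convention, mirroring PyTorch3D's own
    # `_cameras_from_opencv_projection` helper referenced above.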
35 | R_pytorch3d = R.clone().permute(0, 2, 1) # PyTorch3D is row-major 36 | T_pytorch3d = tvec.clone() 37 | R_pytorch3d[:, :, :2] *= -1 38 | T_pytorch3d[:, :2] *= -1 39 | 40 | return pytorch3d.renderer.PerspectiveCameras( 41 | R=R_pytorch3d, 42 | T=T_pytorch3d, 43 | focal_length=focal_pytorch3d, 44 | principal_point=p0_pytorch3d, 45 | image_size=image_size, 46 | device=R.device, 47 | ) 48 | 49 | 50 | class SimpleShader(torch.nn.Module): 51 | # - https://github.com/facebookresearch/pytorch3d/issues/84#issuecomment-590118666 52 | # - https://github.com/facebookresearch/pytorch3d/issues/607#issuecomment-801241305 53 | def __init__(self, device="cpu", blend_params=None): 54 | super().__init__() 55 | self.blend_params = ( 56 | blend_params 57 | if blend_params is not None 58 | else pytorch3d.renderer.BlendParams(background_color=(0.0, 0.0, 0.0)) 59 | ) 60 | 61 | def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: 62 | blend_params = kwargs.get("blend_params", self.blend_params) 63 | pixel_colors = meshes.sample_textures(fragments) 64 | images = pytorch3d.renderer.blending.hard_rgb_blend( 65 | pixel_colors, fragments, blend_params 66 | ) 67 | return images # (N, H, W, 3) RGBA image 68 | -------------------------------------------------------------------------------- /pgdvs/preprocess/dycheck_mono_info_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import math 5 | import tqdm 6 | import shutil 7 | import pathlib 8 | import argparse 9 | import PIL.Image 10 | import numpy as np 11 | 12 | from pgdvs.datasets.dycheck_utils import iPhoneParser 13 | 14 | 15 | def extract_frame_info(parser, time_id, camera_id): 16 | rgba = parser.load_rgba(time_id, camera_id) 17 | # rgb = rgba.astype(np.float32)[..., :3] / 255.0 18 | rgb = rgba[..., :3] 19 | 20 | depth = parser.load_depth(time_id, camera_id)[..., 0] # [H, W, 1] -> [H, W] 21 | 22 | cam = parser.load_camera(time_id, camera_id) # [H, W, 1] 23 | K = cam.intrin # [3, 3] 24 | w2c = cam.extrin # [4, 4] 25 | 26 | return rgb, depth, K, w2c 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--data_dir", type=str, default=".") 32 | parser.add_argument("--save_dir", type=str, default=".") 33 | parser.add_argument("--scene_id", type=str, default="apple") 34 | 35 | args = parser.parse_args() 36 | 37 | iphone_parser = iPhoneParser(args.scene_id, data_root=args.data_dir) 38 | 39 | train_frame_names, train_time_ids, train_camera_ids = iphone_parser.load_split( 40 | "train" 41 | ) 42 | 43 | n_train_frame_names = len(train_frame_names) 44 | n_train_time_ids = len(train_time_ids) 45 | n_train_camera_ids = len(train_camera_ids) 46 | 47 | assert ( 48 | n_train_frame_names == n_train_time_ids 49 | ), f"{n_train_frame_names}, {n_train_time_ids}" 50 | assert ( 51 | n_train_frame_names == n_train_camera_ids 52 | ), f"{n_train_frame_names}, {n_train_camera_ids}" 53 | 54 | save_dir = pathlib.Path(args.save_dir) / args.scene_id 55 | save_dir.mkdir(parents=True, exist_ok=True) 56 | 57 | rgb_dir = save_dir / "rgbs" 58 | rgb_dir.mkdir(parents=True, exist_ok=True) 59 | 60 | depth_dir = save_dir / "depths" 61 | depth_dir.mkdir(parents=True, exist_ok=True) 62 | 63 | all_K = [] 64 | all_w2c = [] 65 | 66 | for i in tqdm.tqdm(range(n_train_camera_ids), desc="#train_frames converting"): 67 | tmp_time_id = train_time_ids[i] 68 | tmp_cam_id = train_camera_ids[i] 69 | tmp_frame_name = iphone_parser.get_frame_name(tmp_time_id, 
tmp_cam_id) 70 | 71 | tmp_rgb, tmp_depth, tmp_K, tmp_w2c = extract_frame_info( 72 | iphone_parser, tmp_time_id, tmp_cam_id 73 | ) 74 | 75 | PIL.Image.fromarray(tmp_rgb).save(rgb_dir / f"{tmp_frame_name}.png") 76 | 77 | with open(depth_dir / f"{tmp_frame_name}.npy", "wb") as f: 78 | np.save(f, tmp_depth) 79 | 80 | all_K.append(tmp_K) 81 | all_w2c.append(tmp_w2c) 82 | 83 | all_K = np.array(all_K) # [#frame, 3, 3] 84 | all_w2c = np.array(all_w2c) # [#frame, 4, 4] 85 | 86 | print("\nall_K: ", all_K.shape, all_w2c.shape, "\n") 87 | np.savez(save_dir / "camera.npz", all_K=all_K, all_w2c=all_w2c) 88 | -------------------------------------------------------------------------------- /docs/BENCHMARK_iPhone.md: -------------------------------------------------------------------------------- 1 | # Benchmark: DyCheck iPhone Dataset 2 | 3 | ## Table of Contents 4 | 5 | - [1 Download Data](#1-download-data) 6 | - [2 Preprocess](#2-preprocess) 7 | - [2.1 Use Precomputed Flow and Mask](#21-use-precomputed-flow-and-mask) 8 | - [2.2 Reproduce Preprocessed Results](#22-reproduce-preprocessed-results) 9 | - [3 Run Benchmark](#3-run-benchmark) 10 | 11 | 12 | ## 1 Download Data 13 | 14 | ```bash 15 | # this environment variable is used for demonstration 16 | cd /path/to/this/repo 17 | export PGDVS_ROOT=$PWD 18 | ``` 19 | 20 | Please follow [DyCheck's official tutorial](https://github.com/KAIR-BAIR/dycheck/blob/ddf77a4e006fdbc5aed28e0859c216da0de5aff5/docs/DATASETS.md#2-iphone-dataset) to download the iPhone dataset to `${PGDVS_ROOT}/data`. After this, you should have a directory structure like: 21 | ``` 22 | . 23 | +-- data 24 | | +-- iphone 25 | | | +-- apple 26 | | | +-- block 27 | | | ... 28 | ``` 29 | 30 | ## 2 Preprocess 31 | 32 | ### 2.1 Use Precomputed Flow and Mask 33 | 34 | ```bash 35 | gdown 1SgvqDJcuFaGJr6Lr3bE9B-knjbOnADQs -O ${PGDVS_ROOT}/data/ 36 | unzip ${PGDVS_ROOT}/data/dycheck_iphone_flow_mask.zip -d ${PGDVS_ROOT}/data 37 | ``` 38 | 39 | ### 2.2 Reproduce Preprocessed Results 40 | 41 | Our pseudo-generalized approach requires optical flow and masks for potentially dynamic content. For this, we need several third-party repositories and checkpoints. **NOTE**: `CUDA_HOME` must be set correctly for [`detectron2`](https://github.com/facebookresearch/detectron2)'s installation and, consequently, [`OneFormer`](https://github.com/SHI-Labs/OneFormer)'s usage. 42 | ```bash 43 | CUDA_HOME=/usr/local/cuda # set to your own CUDA_HOME, where nvcc is installed 44 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \ 45 | ${CUDA_HOME} \ 46 | ${PGDVS_ROOT} \ 47 | ${PGDVS_ROOT}/data \ 48 | prepare 49 | ``` 50 | After running the command, repositories and pretrained checkpoints will be saved to `${PGDVS_ROOT}/third_parties` and `${PGDVS_ROOT}/ckpts`, respectively. 51 | 52 | We then compute optical flow and masks with 53 | ```bash 54 | cd ${PGDVS_ROOT} 55 | 56 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \ 57 | /usr/local/cuda/ \ 58 | ${PGDVS_ROOT} \ 59 | ${PGDVS_ROOT}/data/iphone \ 60 | execute_on_dycheck \ 61 | ${PGDVS_ROOT}/data/dycheck_iphone_flow_mask \ 62 | apple # can be one of [apple block paper-windmill space-out spin teddy wheel] 63 | ``` 64 | The computed optical flows and masks will be saved to `${PGDVS_ROOT}/data/dycheck_iphone_flow_mask`. 65 | 66 | ## 3 Run Benchmark 67 | 68 | To obtain quantitative results, run the following command. All results will be saved to `${PGDVS_ROOT}/experiments`. 
69 | 70 | ```bash 71 | benchmark_type=st_gnt_masked_attn_dy_cvd_pcl_clean_render_point 72 | scene_id='[apple]' # or 'null' to evaluate on all scenes 73 | 74 | bash ${PGDVS_ROOT}/scripts/benchmark.sh \ 75 | ${PGDVS_ROOT} \ 76 | ${PGDVS_ROOT}/ckpts \ 77 | ${PGDVS_ROOT}/data \ 78 | dycheck_iphone \ 79 | ${scene_id} \ 80 | ${benchmark_type} 81 | ``` -------------------------------------------------------------------------------- /pgdvs/preprocess/convert_casual_sam_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import tqdm 5 | import pathlib 6 | import argparse 7 | import PIL.Image 8 | import numpy as np 9 | 10 | 11 | def extract_depth(base_dir, rgb_dir, save_dir): 12 | result_dir = base_dir / "BA_full" 13 | result_f_list = sorted(list(result_dir.glob("*.npz"))) 14 | 15 | img_exts = PIL.Image.registered_extensions() 16 | supported_img_exts = {ex for ex, f in img_exts.items() if f in PIL.Image.OPEN} 17 | img_f_list = [] 18 | for tmp_ext in supported_img_exts: 19 | img_f_list = img_f_list + list(rgb_dir.glob(f"*{tmp_ext}")) 20 | 21 | img_f_list = sorted(img_f_list) 22 | 23 | assert len(img_f_list) == len( 24 | result_f_list 25 | ), f"{len(img_f_list)}, {len(result_f_list)}" 26 | 27 | K0 = None 28 | 29 | pose_dir = save_dir / "poses" 30 | pose_dir.mkdir(exist_ok=True, parents=True) 31 | 32 | depth_dir = save_dir / "depths" 33 | depth_dir.mkdir(exist_ok=True, parents=True) 34 | 35 | for tmp_i, tmp_result_f in enumerate(tqdm.tqdm(result_f_list)): 36 | tmp_idx = int(tmp_result_f.stem) # 0002.npz 37 | assert tmp_idx == tmp_i, f"{tmp_idx}, {tmp_i}" 38 | 39 | # disp (192, 384) float32 40 | # R (3, 3) float32 41 | # t (3,) float32 42 | # K (3, 3) float32 43 | # img (192, 384, 3) float32 44 | # mask_motion (192, 384) float32 45 | # uncertainty_pred (2, 192, 384) float32 46 | # gt_K (3, 3) float32 47 | tmp_info = np.load(tmp_result_f) 48 | disp = tmp_info["disp"] 49 | depth = 1 / (disp + 1e-8) 50 | cam_c2w = np.eye(4) 51 | cam_c2w[:3, :3] = tmp_info["R"] 52 | cam_c2w[:3, 3] = tmp_info["t"] 53 | K = np.eye(4) 54 | K[:3, :3] = tmp_info["K"] # this is important 55 | 56 | if K0 is None: 57 | K0 = K 58 | else: 59 | assert np.sum(np.abs(K0 - K)) < 1e-5, f"{K0}, {K}" 60 | 61 | pose_f = pose_dir / f"{img_f_list[tmp_i].stem}.npz" 62 | pose_dict = {"c2w": cam_c2w, "K": K} 63 | np.savez(pose_f, **pose_dict) 64 | 65 | depth_f = depth_dir / f"{img_f_list[tmp_i].stem}.npz" 66 | np.savez(depth_f, depth=depth) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument( 72 | "--casual_sam_dir", 73 | type=str, 74 | help="directory for results from CasualSAM", 75 | ) 76 | parser.add_argument("--rgb_dir", type=str, help="rgb directory") 77 | parser.add_argument("--save_dir", type=str, help="save directory") 78 | 79 | args = parser.parse_args() 80 | 81 | casual_sam_dir = pathlib.Path(args.casual_sam_dir) 82 | rgb_dir = pathlib.Path(args.rgb_dir) 83 | save_dir = pathlib.Path(args.save_dir) 84 | save_dir.mkdir(parents=True, exist_ok=True) 85 | 86 | extract_depth(casual_sam_dir, rgb_dir, save_dir) 87 | print("Done with extracing poses from CasualSAM") 88 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/interface.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/facebookresearch/co-tracker/blob/0a0596b277545625054cb041f00419bcd3693ea5/demo.py 2 | import tqdm 3 
| import logging 4 | 5 | import torch 6 | 7 | from pgdvs.models.cotracker.predictor import CoTrackerPredictor 8 | from pgdvs.utils.rendering import modify_rgb_range 9 | 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | class CoTrackerInterface(torch.nn.Module): 15 | def __init__( 16 | self, 17 | ckpt_path, 18 | ori_rgb_range="0_1", 19 | query_chunk_size=4096, 20 | local_rank=0, 21 | ): 22 | super().__init__() 23 | self.ori_rgb_range = ori_rgb_range 24 | self.query_chunk_size = query_chunk_size 25 | 26 | self.model = CoTrackerPredictor(checkpoint=ckpt_path) 27 | 28 | LOGGER.info(f"[CoTracker] Done loading checkpoint from {ckpt_path}") 29 | 30 | def __call__(self, *, frames, query_points): 31 | # frames: [N, H, W, 3] 32 | # query_points: [#pt, 3], 3 for [time, row, col] 33 | 34 | frames = frames.permute(0, 3, 1, 2)[None, ...] # [1, N, 3, H, W] 35 | 36 | frames = modify_rgb_range( 37 | frames, 38 | src_range=self.ori_rgb_range, 39 | tgt_range="0_255", 40 | check_range=False, 41 | enforce_range=True, 42 | ) 43 | 44 | # query points are formulated as [t, x, y] 45 | # https://github.com/facebookresearch/co-tracker/blob/0a0596b277545625054cb041f00419bcd3693ea5/cotracker/predictor.py#L37 46 | query_points = query_points[ 47 | None, :, [0, 2, 1] 48 | ] # [1, #pt, 3]; [t, row, col] -> [t, x, y] 49 | 50 | pred_tracks = [] 51 | pred_visibility = [] 52 | 53 | n_queries = query_points.shape[1] 54 | 55 | for start_i in tqdm.tqdm( 56 | range(0, n_queries, self.query_chunk_size), disable=True 57 | ): 58 | end_i = min(n_queries, start_i + self.query_chunk_size) 59 | tmp_queries = query_points[:, start_i:end_i, :] 60 | tmp_pred_tracks, tmp_pred_visibility = self.model( 61 | frames, 62 | queries=tmp_queries, 63 | grid_size=0, 64 | grid_query_frame=0, 65 | backward_tracking=False, 66 | segm_mask=None, 67 | ) # tracks: [1, #frame, #pt, 2], float32; visibility: [1, #frame, #pt], bool 68 | 69 | pred_tracks.append(tmp_pred_tracks) 70 | pred_visibility.append(tmp_pred_visibility) 71 | 72 | pred_tracks = torch.cat(pred_tracks, dim=2)[0, ...].permute( 73 | 1, 0, 2 74 | ) # [1, #frame, #pt, 2] -> [#pt, #frame, 2] 75 | pred_visibility = torch.cat(pred_visibility, dim=2)[0, ...].permute( 76 | 1, 0 77 | ) # [1, #frame, #pt] -> [#pt, #frame] 78 | 79 | # there may be negative values and we clip them 80 | pred_tracks = torch.clip(pred_tracks, 0.0) 81 | 82 | return pred_tracks, pred_visibility 83 | -------------------------------------------------------------------------------- /scripts/preprocess/install_colmap.sh: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/consistent_depth/blob/e2c9b724d3221aa7c0bf89aa9449ae33b418d943/scripts/install_colmap_ubuntu.sh 2 | 3 | export BASE_DIR=$1 # root directory where you prefer to install for colmap and related repositories 4 | echo "Dir: $BASE_DIR" 5 | 6 | # IMPORTANT !!! 
7 | # See issues here: https://github.com/pism/pism/issues/356 8 | # Make sure that we have a PATH without conda's pollution 9 | # Put the following two lines in .bashrc or .zshrc before conda overwrite the PATH 10 | # export NOCONDA_PATH=$PATH 11 | export CUR_NOCONDA_PATH=$NOCONDA_PATH 12 | 13 | # conda deactivate 14 | 15 | sudo apt-get update 16 | sudo apt-get install -y libgl1-mesa-glx ffmpeg libsm6 libxext6 17 | 18 | # https://colmap.github.io/install.html 19 | sudo apt-get install -y \ 20 | git \ 21 | cmake \ 22 | build-essential \ 23 | libboost-program-options-dev \ 24 | libboost-filesystem-dev \ 25 | libboost-graph-dev \ 26 | libboost-system-dev \ 27 | libboost-test-dev \ 28 | libeigen3-dev \ 29 | libsuitesparse-dev \ 30 | libfreeimage-dev \ 31 | libgoogle-glog-dev \ 32 | libgflags-dev \ 33 | libglew-dev \ 34 | qtbase5-dev \ 35 | libqt5opengl5-dev \ 36 | libcgal-dev 37 | 38 | sudo apt-get install -y libcgal-qt5-dev 39 | 40 | sudo apt-get install -y libatlas-base-dev libsuitesparse-dev 41 | 42 | eval "$(conda shell.bash hook)" 43 | conda activate base 44 | conda install cmake -y 45 | 46 | cd $BASE_DIR 47 | git clone https://ceres-solver.googlesource.com/ceres-solver 48 | cd ceres-solver 49 | # git checkout $(git describe --tags) # Checkout the latest release 50 | git checkout 2.1.0 51 | mkdir build 52 | cd build 53 | PATH=$CUR_NOCONDA_PATH cmake .. -DBUILD_TESTING=OFF -DBUILD_EXAMPLES=OFF -DCMAKE_INSTALL_PREFIX=$BASE_DIR/ceres-solver/build 54 | make -j 55 | make install 56 | 57 | # resolve building bugs 58 | sudo apt-get install -y libgoogle-glog-dev 59 | sudo apt-get install -y libfreeimage-dev 60 | sudo apt-get install -y libglu1-mesa-dev freeglut3-dev mesa-common-dev 61 | sudo apt-get install -y libglew-dev 62 | sudo apt-get install -y qt5-default 63 | 64 | # https://github.com/facebookresearch/habitat-sim/issues/971#issuecomment-818560795 65 | # https://github.com/colmap/colmap/issues/1271#issuecomment-931900582 66 | # https://github.com/NVIDIA/libglvnd 67 | sudo apt-get install -y lsb-core 68 | sudo apt-get install -y autoconf 69 | sudo apt-get install -y libxext-dev libx11-dev x11proto-gl-dev 70 | cd $BASE_DIR 71 | git clone https://github.com/NVIDIA/libglvnd.git 72 | cd libglvnd 73 | git checkout c8ee005 74 | ./autogen.sh 75 | mkdir build 76 | ./configure --prefix=$BASE_DIR/libglvnd/build 77 | make -j 78 | make install 79 | 80 | # # https://github.com/colmap/colmap/issues/188 81 | # conda activate base 82 | # conda uninstall libtiff 83 | 84 | cd $BASE_DIR 85 | git clone https://github.com/colmap/colmap.git 86 | cd colmap 87 | git checkout ea40ef9a # 3.7 88 | mkdir build 89 | cd build 90 | PATH=$CUR_NOCONDA_PATH cmake .. 
\ 91 | -DCMAKE_INSTALL_PREFIX=$BASE_DIR/colmap/build \ 92 | -DCMAKE_PREFIX_PATH=$BASE_DIR/ceres-solver/build \ 93 | -DDCMAKE_INCLUDE_PATH=$BASE_DIR/libglvnd/build/include \ 94 | -DCMAKE_LIBRARY_PATH=$BASE_DIR/libglvnd/build/lib 95 | make -j 96 | make install -------------------------------------------------------------------------------- /pgdvs/preprocess/convert_dyn_video_depth_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import tqdm 5 | import pathlib 6 | import argparse 7 | import PIL.Image 8 | import numpy as np 9 | 10 | 11 | def extract_depth(base_dir, rgb_dir, save_dir): 12 | result_dir_list = list( 13 | (base_dir / "test").glob("scene_flow_motion_field*/epoch*_test") 14 | ) 15 | assert len(result_dir_list) == 1, f"{len(result_dir_list)}, {result_dir_list}" 16 | 17 | result_dir = result_dir_list[0] 18 | result_f_list = sorted(list(result_dir.glob("*.npz"))) 19 | 20 | img_exts = PIL.Image.registered_extensions() 21 | supported_img_exts = {ex for ex, f in img_exts.items() if f in PIL.Image.OPEN} 22 | img_f_list = [] 23 | for tmp_ext in supported_img_exts: 24 | img_f_list = img_f_list + list(rgb_dir.glob(f"*{tmp_ext}")) 25 | 26 | img_f_list = sorted(img_f_list) 27 | 28 | assert len(img_f_list) == len( 29 | result_f_list 30 | ), f"{len(img_f_list)}, {len(result_f_list)}" 31 | 32 | K0 = None 33 | 34 | pose_dir = save_dir / "poses" 35 | pose_dir.mkdir(exist_ok=True, parents=True) 36 | 37 | depth_dir = save_dir / "depths" 38 | depth_dir.mkdir(exist_ok=True, parents=True) 39 | 40 | for tmp_i, tmp_result_f in enumerate(tqdm.tqdm(result_f_list)): 41 | tmp_idx = int(tmp_result_f.stem.split("batch")[1]) # batch0013.npz 42 | assert tmp_idx == tmp_i, f"{tmp_idx}, {tmp_i}" 43 | 44 | # batch_size () int64 45 | # img_1 (1, 3, 458, 816) float32 46 | # img_2 (1, 3, 458, 816) float32 47 | # depth (1, 1, 458, 816) float32 48 | # sf_1_2 (1, 3, 458, 816) float32 49 | # depth_nn (1, 1, 458, 816) float32 50 | # depth_gt (1, 1, 458, 816) float32 51 | # cam_c2w (1, 4, 4) float32 52 | # K (1, 1, 1, 3, 3) float32 53 | # pair_path (1,) 0: 83 | flag_strict = False 84 | else: 85 | flag_strict = True 86 | 87 | missed_keys, unexpcted_keys = self.load_state_dict( 88 | tgt_state_dict, strict=flag_strict 89 | ) 90 | 91 | LOGGER.info(f"[GNT] Done loading from {path}") 92 | 93 | if len(ignore_keys) > 0: 94 | assert set(ignore_keys) == set(missed_keys), f"{ignore_keys}, {missed_keys}" 95 | 96 | if len(missed_keys) > 0: 97 | LOGGER.info("[GNT] Missing keys:") 98 | LOGGER.info(missed_keys) 99 | if len(unexpcted_keys) > 0: 100 | LOGGER.info("[GNT] Unexpected keys:") 101 | LOGGER.info(unexpcted_keys) 102 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual 11 | identity and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 
15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the overall 27 | community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or advances of 32 | any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email address, 36 | without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official email address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 64 | complaints will be reviewed and investigated and will result in a response that 65 | is deemed necessary and appropriate to the circumstances. The project team is 66 | obligated to maintain confidentiality with regard to the reporter of an incident. 67 | Further details of specific enforcement policies may be posted separately. 68 | 69 | Project maintainers who do not follow or enforce the Code of Conduct in good 70 | faith may face temporary or permanent repercussions as determined by other 71 | members of the project's leadership. 72 | 73 | ## Attribution 74 | 75 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 76 | version 2.1, available at 77 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 78 | 79 | Community Impact Guidelines were inspired by 80 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 81 | 82 | For answers to common questions about this code of conduct, see the FAQ at 83 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 84 | [https://www.contributor-covenant.org/translations][translations]. 
85 | 86 | [homepage]: https://www.contributor-covenant.org 87 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 88 | [Mozilla CoC]: https://github.com/mozilla/diversity 89 | [FAQ]: https://www.contributor-covenant.org/faq 90 | [translations]: https://www.contributor-covenant.org/translations 91 | 92 | -------------------------------------------------------------------------------- /scripts/visualize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | { 3 | 4 | ALL_ARGS=("$@") # treat input as array 5 | 6 | printf '\nAll args: %s' ${ALL_ARGS[@]} 7 | 8 | printf '\n' 9 | 10 | REPO_ROOT=(${ALL_ARGS[0]}) 11 | CKPT_ROOT=(${ALL_ARGS[1]}) 12 | DATA_ROOT=(${ALL_ARGS[2]}) 13 | DATASET=(${ALL_ARGS[3]}) 14 | SCENE_ID=(${ALL_ARGS[4]}) 15 | # RUN_TYPE=(${ALL_ARGS[5]}) 16 | # CUDA_HOME=(${ALL_ARGS[4]}) 17 | 18 | N_IN_ARGS=5 19 | 20 | printf '\nREPO_ROOT: %s' ${REPO_ROOT} 21 | printf '\nCKPT_ROOT: %s' ${CKPT_ROOT} 22 | printf '\nDATA_ROOT: %s' ${DATA_ROOT} 23 | printf '\nDATASET: %s' ${DATASET} 24 | printf '\nSCENE_ID: %s' ${SCENE_ID} 25 | # printf '\nRUN_TYPE: %s' ${RUN_TYPE} 26 | # printf '\nCUDA_HOME\n: %s' ${CUDA_HOME} 27 | 28 | eval "$(conda shell.bash hook)" 29 | conda activate pgdvs 30 | 31 | cd ${REPO_ROOT} 32 | export PYTHONPATH=${REPO_ROOT}:${PYTHONPATH} 33 | 34 | ulimit -n 65000; 35 | ulimit -c 0; # Disable core file creation 36 | 37 | export MKL_THREADING_LAYER=GNU; 38 | export NCCL_P2P_DISABLE=1; 39 | export HYDRA_FULL_ERROR=1; 40 | export OC_CAUSE=1; 41 | # export CUDA_HOME=${CUDA_HOME} 42 | 43 | # for jax 44 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 45 | export TF_CPP_MIN_LOG_LEVEL=0 46 | 47 | VALID_DATA=(nvidia_vis mono_vis) 48 | 49 | if [[ ${VALID_DATA[@]} =~ ${DATASET} ]] 50 | then 51 | printf "\n\ndataset ${DATASET} is valid\n" 52 | else 53 | printf "\n\ndataset ${DATASET} is NOT supported\n" 54 | fi 55 | 56 | OVERWRITE_CONF=() 57 | 58 | if [ "${DATASET}" == "nvidia_vis" ]; then 59 | 60 | OVERWRITE_CONF+=('dataset.dataset_list.vis=[nvidia_vis]') 61 | OVERWRITE_CONF+=("dataset.dataset_specifics.nvidia_vis.scene_ids=${SCENE_ID}") 62 | 63 | elif [ "${DATASET}" == "mono_vis" ]; then 64 | 65 | OVERWRITE_CONF+=('dataset.dataset_list.vis=[mono_vis]') 66 | OVERWRITE_CONF+=("dataset.dataset_specifics.mono_vis.scene_ids=${SCENE_ID}") 67 | 68 | fi 69 | 70 | OVERWRITE_CONF+=('vis_specifics.n_render_frames=400') 71 | OVERWRITE_CONF+=('vis_specifics.vis_center_time=50') 72 | OVERWRITE_CONF+=('vis_specifics.vis_time_interval=50') 73 | OVERWRITE_CONF+=('vis_specifics.vis_bt_max_disp=32') 74 | 75 | OVERWRITE_ARGS=("${ALL_ARGS[@]:${N_IN_ARGS}}") # remove the first several elements 76 | OVERWRITE_CONF+=(${OVERWRITE_ARGS[@]}) 77 | 78 | OVERWRITE_STR=$(printf " %s" "${OVERWRITE_CONF[@]}") 79 | OVERWRITE_STR=${OVERWRITE_STR:1} 80 | 81 | printf "\n\nOverwrites: %s" ${OVERWRITE_STR} 82 | printf "\n" 83 | 84 | # https://stackoverflow.com/a/76742577 85 | N_GPU_SPLIT_STR=(${CUDA_VISIBLE_DEVICES//,/ }) 86 | NUM_GPUS=${#N_GPU_SPLIT_STR[@]} 87 | 88 | if [ "${NUM_GPUS}" == "0" ]; then 89 | NUM_GPUS=$(nvidia-smi --list-gpus | wc -l) 90 | fi 91 | 92 | printf "\nNumber of GPU is ${NUM_GPUS}\n\n" 93 | 94 | export IMAGEIO_FFMPEG_EXE=/usr/bin/ffmpeg && \ 95 | python ${REPO_ROOT}/pgdvs/run.py \ 96 | verbose=true \ 97 | distributed=true \ 98 | seed=0 \ 99 | resume="vis_wo_resume" \ 100 | resume_dir=null \ 101 | engine=visualizer_pgdvs \ 102 | model=pgdvs_renderer \ 103 | model.softsplat_metric_abs_alpha=100.0 \ 104 | static_renderer=gnt \ 
105 | static_renderer.model_cfg.ckpt_path=${CKPT_ROOT}/gnt/model_720000.pth \ 106 | series_eval=false \ 107 | eval_batch_size=${NUM_GPUS} \ 108 | n_max_eval_data=-1 \ 109 | eval_save_individual=true \ 110 | engine.engine_cfg.render_cfg.render_stride=1 \ 111 | engine.engine_cfg.render_cfg.chunk_size=2048 \ 112 | engine.engine_cfg.render_cfg.sample_inv_uniform=true \ 113 | engine.engine_cfg.render_cfg.n_coarse_samples_per_ray=256 \ 114 | engine.engine_cfg.render_cfg.n_fine_samples_per_ray=0 \ 115 | engine.engine_cfg.render_cfg.mask_oob_n_proj_thres=1 \ 116 | engine.engine_cfg.render_cfg.mask_invalid_n_proj_thres=4 \ 117 | engine.engine_cfg.render_cfg.dyn_pcl_remove_outlier=true \ 118 | engine.engine_cfg.render_cfg.dyn_pcl_outlier_knn=50 \ 119 | engine.engine_cfg.render_cfg.dyn_pcl_outlier_std_thres=0.1 \ 120 | engine.engine_cfg.render_cfg.gnt_use_dyn_mask=true \ 121 | engine.engine_cfg.render_cfg.gnt_use_masked_spatial_src=false \ 122 | engine.engine_cfg.render_cfg.dyn_render_use_flow_consistency=false \ 123 | dataset=combined \ 124 | 'dataset.dataset_list.train=[nvidia_eval]' \ 125 | 'dataset.dataset_list.eval=[nvidia_eval]' \ 126 | 'dataset.dataset_list.vis=[nvidia_vis]' \ 127 | "dataset.dataset_specifics.mono_vis.scene_ids=${SCENE_ID}" \ 128 | dataset.data_root=${DATA_ROOT} \ 129 | n_dataloader_workers=1 \ 130 | dataset_max_hw=-1 \ 131 | dataset.use_aug=false \ 132 | ${OVERWRITE_STR} 133 | 134 | exit; 135 | } -------------------------------------------------------------------------------- /pgdvs/utils/flow_vis_utils.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | 21 | def make_colorwheel(): 22 | """ 23 | Generates a color wheel for optical flow visualization as presented in: 24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 26 | 27 | Code follows the original C++ source code of Daniel Scharstein. 28 | Code follows the the Matlab source code of Deqing Sun. 
29 | 30 | Returns: 31 | np.ndarray: Color wheel 32 | """ 33 | 34 | RY = 15 35 | YG = 6 36 | GC = 4 37 | CB = 11 38 | BM = 13 39 | MR = 6 40 | 41 | ncols = RY + YG + GC + CB + BM + MR 42 | colorwheel = np.zeros((ncols, 3)) 43 | col = 0 44 | 45 | # RY 46 | colorwheel[0:RY, 0] = 255 47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 48 | col = col + RY 49 | # YG 50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 51 | colorwheel[col : col + YG, 1] = 255 52 | col = col + YG 53 | # GC 54 | colorwheel[col : col + GC, 1] = 255 55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 56 | col = col + GC 57 | # CB 58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 59 | colorwheel[col : col + CB, 2] = 255 60 | col = col + CB 61 | # BM 62 | colorwheel[col : col + BM, 2] = 255 63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 64 | col = col + BM 65 | # MR 66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 67 | colorwheel[col : col + MR, 0] = 255 68 | return colorwheel 69 | 70 | 71 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 72 | """ 73 | Applies the flow color wheel to (possibly clipped) flow components u and v. 74 | 75 | According to the C++ source code of Daniel Scharstein 76 | According to the Matlab source code of Deqing Sun 77 | 78 | Args: 79 | u (np.ndarray): Input horizontal flow of shape [H,W] 80 | v (np.ndarray): Input vertical flow of shape [H,W] 81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 82 | 83 | Returns: 84 | np.ndarray: Flow visualization image of shape [H,W,3] 85 | """ 86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 87 | colorwheel = make_colorwheel() # shape [55x3] 88 | ncols = colorwheel.shape[0] 89 | rad = np.sqrt(np.square(u) + np.square(v)) 90 | a = np.arctan2(-v, -u) / np.pi 91 | fk = (a + 1) / 2 * (ncols - 1) 92 | k0 = np.floor(fk).astype(np.int32) 93 | k1 = k0 + 1 94 | k1[k1 == ncols] = 0 95 | f = fk - k0 96 | for i in range(colorwheel.shape[1]): 97 | tmp = colorwheel[:, i] 98 | col0 = tmp[k0] / 255.0 99 | col1 = tmp[k1] / 255.0 100 | col = (1 - f) * col0 + f * col1 101 | idx = rad <= 1 102 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 103 | col[~idx] = col[~idx] * 0.75 # out of range 104 | # Note the 2-i => BGR instead of RGB 105 | ch_idx = 2 - i if convert_to_bgr else i 106 | flow_image[:, :, ch_idx] = np.floor(255 * col) 107 | return flow_image 108 | 109 | 110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 111 | """ 112 | Expects a two dimensional flow image of shape. 113 | 114 | Args: 115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
118 | 119 | Returns: 120 | np.ndarray: Flow visualization image of shape [H,W,3] 121 | """ 122 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 124 | if clip_flow is not None: 125 | flow_uv = np.clip(flow_uv, 0, clip_flow) 126 | u = flow_uv[:, :, 0] 127 | v = flow_uv[:, :, 1] 128 | rad = np.sqrt(np.square(u) + np.square(v)) 129 | rad_max = np.max(rad) 130 | epsilon = 1e-5 131 | u = u / (rad_max + epsilon) 132 | v = v / (rad_max + epsilon) 133 | return flow_uv_to_colors(u, v, convert_to_bgr) 134 | -------------------------------------------------------------------------------- /pgdvs/preprocess/colmap_processor.py: -------------------------------------------------------------------------------- 1 | # Modified from 2 | # - https://github.com/kwea123/nsff_pl/blob/0e7b75543a4c3f0782332cf64c96fadce03ae34d/preprocess.py 3 | # - https://github.com/facebookresearch/consistent_depth/blob/e2c9b724d3221aa7c0bf89aa9449ae33b418d943/tools/colmap_processor.py 4 | 5 | import os 6 | import pathlib 7 | import argparse 8 | import subprocess 9 | import PIL.Image 10 | 11 | 12 | def run_cmd(cmd): 13 | cmd_str = " ".join(cmd) 14 | print(f"\n{cmd_str}\n") 15 | 16 | new_env = os.environ.copy() 17 | new_env["LD_LIBRARY_PATH"] = f"/opt/conda/lib:{new_env['LD_LIBRARY_PATH']}" 18 | subprocess.run(cmd, env=new_env) 19 | 20 | 21 | def run_colmap(args): 22 | max_num_matches = 132768 # colmap setting 23 | 24 | workspace_path = pathlib.Path(args.workspace_path) 25 | workspace_path.mkdir(parents=True, exist_ok=True) 26 | 27 | rgb_dir = pathlib.Path(args.image_path) 28 | mask_dir = pathlib.Path(args.mask_path) 29 | 30 | exts = PIL.Image.registered_extensions() 31 | supported_exts = {ex for ex, f in exts.items() if f in PIL.Image.OPEN} 32 | 33 | rgb_f_list = [] 34 | for tmp_ext in supported_exts: 35 | rgb_f_list = rgb_f_list + list(sorted(rgb_dir.glob(f"*{tmp_ext}"))) 36 | for tmp_rgb_f in rgb_f_list: 37 | tmp_mask_f = mask_dir / f"{tmp_rgb_f.name}.png" 38 | assert tmp_mask_f.exists(), rgb_f_list 39 | 40 | if not (workspace_path / "database.db").exists() or args.overwrite: 41 | # https://colmap.github.io/faq.html#mask-image-regions 42 | # Features will only be extracted from areas with mask values of 1 43 | 44 | # fmt: off 45 | cmd_feat = [ 46 | f"{args.colmap_bin}", 47 | "feature_extractor", 48 | "--database_path", f"{args.workspace_path}/database.db", 49 | "--image_path", f"{args.image_path}", 50 | "--ImageReader.mask_path", f"{args.mask_path}", 51 | "--ImageReader.camera_model", "SIMPLE_RADIAL", 52 | "--ImageReader.single_camera", "1", 53 | "--SiftExtraction.num_threads", "1", 54 | "--SiftExtraction.gpu_index", "0", 55 | ] 56 | 57 | # "--ImageReader.default_focal_length_factor", "0.95", 58 | # "--SiftExtraction.peak_threshold", "0.004", 59 | # "--SiftExtraction.max_num_features", "8192", 60 | # "--SiftExtraction.edge_threshold", "16", 61 | 62 | # fmt: on 63 | run_cmd(cmd_feat) 64 | 65 | # fmt: off 66 | cmd_match = [ 67 | f"{args.colmap_bin}", 68 | "exhaustive_matcher", 69 | "--database_path", f"{args.workspace_path}/database.db", 70 | "--SiftMatching.multiple_models", "1", 71 | "--SiftMatching.guided_matching", "1", 72 | ] 73 | 74 | # "--SiftMatching.max_ratio", "0.8", 75 | # "--SiftMatching.max_error", "4.0", 76 | # "--SiftMatching.max_distance", "0.7", 77 | # "--SiftMatching.max_num_matches", f"{max_num_matches}", 78 | 79 | # fmt: on 80 | run_cmd(cmd_match) 81 | 82 | if not (workspace_path / "sparse").exists() or args.overwrite: 
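        # Build the sparse reconstruction with COLMAP's incremental mapper; this block
        # is skipped when a previous run already produced `sparse/` and --overwrite is not set.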
83 | (workspace_path / "sparse").mkdir(exist_ok=True, parents=True) 84 | 85 | # fmt: off 86 | cmd_map = [ 87 | f"{args.colmap_bin}", 88 | "mapper", 89 | "--database_path", f"{args.workspace_path}/database.db", 90 | "--image_path", f"{args.image_path}", 91 | "--output_path", f"{args.workspace_path}/sparse", 92 | "--Mapper.abs_pose_min_inlier_ratio", "0.5", 93 | "--Mapper.abs_pose_min_num_inliers", "50", 94 | "--Mapper.init_max_forward_motion", "1", 95 | "--Mapper.ba_local_num_images", "15", 96 | ] 97 | # fmt: on 98 | run_cmd(cmd_map) 99 | 100 | if not pathlib.Path(args.undistort_path).exists() or args.overwrite: 101 | pathlib.Path(args.undistort_path).mkdir(exist_ok=True, parents=True) 102 | 103 | # fmt: off 104 | cmd_undist = [ 105 | f"{args.colmap_bin}", 106 | "image_undistorter", 107 | "--input_path", f"{args.workspace_path}/sparse/0", 108 | "--image_path", f"{args.image_path}", 109 | "--output_path", f"{args.undistort_path}", 110 | "--output_type", "COLMAP", 111 | ] 112 | # fmt: on 113 | run_cmd(cmd_undist) 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description="run colmap") 118 | parser.add_argument("--colmap_bin", type=str, default="colmap") 119 | parser.add_argument("--cuda_device", type=int, default=-1) 120 | parser.add_argument("--image_path", type=str) 121 | parser.add_argument("--mask_path", type=str) 122 | parser.add_argument("--workspace_path", type=str, default="colmap") 123 | parser.add_argument("--undistort_path", type=str, default="undistorted") 124 | parser.add_argument( 125 | "--overwrite", default=False, action="store_true", help="overwrite cache" 126 | ) 127 | 128 | args = parser.parse_args() 129 | 130 | run_colmap(args) 131 | -------------------------------------------------------------------------------- /pgdvs/renderers/pgdvs_renderer_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import torch 5 | 6 | import pgdvs.utils.softsplat as softsplat 7 | 8 | 9 | DEBUG_DIR = pathlib.Path(__file__).absolute().parent.parent.parent / "debug" 10 | DEBUG_DIR.mkdir(parents=True, exist_ok=True) 11 | 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | class PGDVSBaseRenderer(torch.nn.Module): 17 | def get_batched_rays( 18 | self, *, device, batch_size, H, W, render_stride, intrinsics, c2w 19 | ): 20 | """ 21 | :param H: image height 22 | :param W: image width 23 | :param intrinsics: 4 by 4 intrinsic matrix 24 | :param c2w: 4 by 4 camera to world extrinsic matrix 25 | :return: 26 | """ 27 | u, v = torch.meshgrid( 28 | torch.arange(W, device=device)[::render_stride], 29 | torch.arange(H, device=device)[::render_stride], 30 | indexing="xy", 31 | ) # both are [H, W] 32 | 33 | render_h, render_w = u.shape 34 | 35 | u = u.reshape(-1).float() # + 0.5 # add half pixel 36 | v = v.reshape(-1).float() # + 0.5 37 | pixels = torch.stack((u, v, torch.ones_like(u)), dim=0) # (3, H*W) 38 | batched_pixels = pixels.unsqueeze(0).repeat(batch_size, 1, 1) # [B, 3, HxW] 39 | 40 | rays_d = ( 41 | c2w[:, :3, :3].bmm(torch.inverse(intrinsics[:, :3, :3])).bmm(batched_pixels) 42 | ).transpose( 43 | 1, 2 44 | ) # [B, 3, 3] x [B, 3, 3] x [B, 3, HxW] -> [B, 3, HxW] -> [B, HxW, 3] 45 | rays_o = c2w[:, :3, 3].unsqueeze(1).repeat(1, rays_d.shape[1], 1) # [B, HxW, 3] 46 | rays_d = rays_d.reshape(-1, 3) # [BxHxW, 3] 47 | rays_o = rays_o.reshape(-1, 3) # [BxHxW, 3] 48 | uvs = batched_pixels[:, :2, :].permute(0, 2, 1).reshape((-1, 2)) 49 | 50 | batch_refs = ( 51 | 
torch.arange(batch_size) 52 | .reshape((batch_size, 1)) 53 | .expand(-1, u.shape[0]) 54 | .reshape(-1) 55 | ) # [BxHxW] 56 | 57 | return rays_o, rays_d, uvs, batch_refs, (render_h, render_w) 58 | 59 | def softsplat_img( 60 | self, 61 | *, 62 | rgb_src1, 63 | flow_src1_to_tgt, 64 | rgb_src2=None, 65 | flow_src1_to_src2=None, 66 | softsplat_metric_src1_to_src2=None, 67 | ): 68 | if softsplat_metric_src1_to_src2 is None: 69 | backwarp_img_for_softsplt_metric = self.backwarp_for_softsplat_metric( 70 | tenIn=rgb_src2, tenFlow=flow_src1_to_src2 71 | ) # [B, 3, H, W] 72 | softsplat_metric_src1_to_src2 = torch.nn.functional.l1_loss( 73 | input=rgb_src1, 74 | target=backwarp_img_for_softsplt_metric, 75 | reduction="none", 76 | ).mean( 77 | dim=1, keepdim=True 78 | ) # [B, 1, H, W] 79 | 80 | splat_img_src1_to_tgt = softsplat.softsplat( 81 | tenIn=rgb_src1, 82 | tenFlow=flow_src1_to_tgt, 83 | tenMetric=( 84 | -self.softsplat_metric_abs_alpha * softsplat_metric_src1_to_src2 85 | ).clip(-self.softsplat_metric_abs_alpha, self.softsplat_metric_abs_alpha), 86 | strMode="soft", 87 | ) # [B, 3, H, W] 88 | 89 | return splat_img_src1_to_tgt, softsplat_metric_src1_to_src2 90 | 91 | def backwarp_for_softsplat_metric(self, tenIn, tenFlow): 92 | if not hasattr(self, "backwarp_grid_dict"): 93 | self.backwarp_grid_dict = {} 94 | 95 | if str(tenFlow.shape) not in self.backwarp_grid_dict: 96 | tenHor = ( 97 | torch.linspace( 98 | start=-1.0, 99 | end=1.0, 100 | steps=tenFlow.shape[3], 101 | dtype=tenFlow.dtype, 102 | device=tenFlow.device, 103 | ) 104 | .view(1, 1, 1, -1) 105 | .repeat(1, 1, tenFlow.shape[2], 1) 106 | ) 107 | tenVer = ( 108 | torch.linspace( 109 | start=-1.0, 110 | end=1.0, 111 | steps=tenFlow.shape[2], 112 | dtype=tenFlow.dtype, 113 | device=tenFlow.device, 114 | ) 115 | .view(1, 1, -1, 1) 116 | .repeat(1, 1, 1, tenFlow.shape[3]) 117 | ) 118 | self.backwarp_grid_dict[str(tenFlow.shape)] = torch.cat( 119 | [tenHor, tenVer], 1 120 | ).to(tenIn.device) 121 | 122 | tenFlow = torch.cat( 123 | [ 124 | tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), 125 | tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0), 126 | ], 127 | 1, 128 | ) 129 | 130 | return torch.nn.functional.grid_sample( 131 | input=tenIn, 132 | grid=(self.backwarp_grid_dict[str(tenFlow.shape)] + tenFlow).permute( 133 | 0, 2, 3, 1 134 | ), 135 | mode="bilinear", 136 | padding_mode="zeros", 137 | align_corners=True, 138 | ) 139 | -------------------------------------------------------------------------------- /pgdvs/renderers/st_geo_renderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import pathlib 4 | import hydra 5 | import hydra.utils 6 | import numpy as np 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import pytorch3d 11 | import pytorch3d.utils 12 | import pytorch3d.ops as p3d_ops 13 | 14 | from pgdvs.utils.training import disabled_train 15 | from pgdvs.utils.vis_utils import draw_cam_mesh 16 | from pgdvs.models.gnt.common import TINY_NUMBER, HUGE_NUMBER 17 | from pgdvs.models.gnt.projector import Projector 18 | 19 | 20 | class StaticGeoPointRenderer(torch.nn.Module): 21 | def __init__(self, model_cfg=None): 22 | super().__init__() 23 | 24 | self.projector = Projector() 25 | 26 | def forward(self, *, tgt_h, tgt_w, flat_tgt_cam, st_pcl_rgb, render_cfg): 27 | assert st_pcl_rgb.ndim == 2, f"{st_pcl_rgb.shape}" 28 | 29 | st_pcl = st_pcl_rgb[:, :3] # [#pt, 3] 30 | st_rgb = st_pcl_rgb[:, 3:] # [#pt, 3] 31 | 32 | if 
render_cfg.st_pcl_remove_outlier: 33 | # removal of outliers 34 | # - https://github.com/facebookresearch/pytorch3d/issues/511#issuecomment-1152970392 35 | # - http://www.open3d.org/docs/release/tutorial/geometry/pointcloud_outlier_removal.html 36 | # - https://pcl.readthedocs.io/en/latest/statistical_outlier.html 37 | nn_dists, nn_idxs, nn_pts = p3d_ops.knn_points( 38 | st_pcl[None, ...], 39 | st_pcl[None, ...], 40 | K=(render_cfg.st_pcl_outlier_knn + 1), 41 | return_nn=True, 42 | ) # nn_dists/idxs: [1, #pts, K]; nn_pts: [1, #pts, K, 3] 43 | 44 | # The 1st distance is always 0 as the nearest is the point itself. 45 | nn_dists = nn_dists[0, :, 1:] 46 | nn_idxs = nn_idxs[0, :, 1:] 47 | nn_pts = nn_pts[0, :, 1:, :] 48 | 49 | # We mimic Open3D's statistical removal 50 | # https://github.com/isl-org/Open3D/blob/6ddbcd5c9b8bf0b496e4151c7d7766af09e3dba7/cpp/open3d/geometry/PointCloud.cpp#L636-L653 51 | avg_nn_dist = torch.mean(nn_dists, dim=1) # [#pt, ] 52 | st_pcl_nn_dist_med = torch.median(avg_nn_dist) 53 | st_pcl_nn_dist_std = torch.std(avg_nn_dist) 54 | nn_dist_thres = ( 55 | st_pcl_nn_dist_med 56 | + st_pcl_nn_dist_std * render_cfg.st_pcl_outlier_std_thres 57 | ) 58 | 59 | flag_not_outlier = avg_nn_dist < nn_dist_thres 60 | assert ( 61 | flag_not_outlier.shape[0] == st_pcl.shape[0] 62 | ), f"{flag_not_outlier.shape}, {st_pcl.shape}" 63 | 64 | st_pcl_clean = st_pcl[flag_not_outlier, :] 65 | else: 66 | st_pcl_clean = st_pcl 67 | flag_not_outlier = torch.ones( 68 | (st_pcl.shape[0]), dtype=bool, device=st_pcl.device 69 | ) 70 | 71 | assert ( 72 | flag_not_outlier.shape[0] == st_rgb.shape[0] 73 | ), f"{flag_not_outlier.shape}, {st_rgb.shape}" 74 | 75 | st_rgb_clean = st_rgb[flag_not_outlier, :] 76 | 77 | K = flat_tgt_cam[2:18].reshape((4, 4)) 78 | c2w = flat_tgt_cam[18:34].reshape((4, 4)) 79 | w2c = torch.inverse(c2w) 80 | 81 | if st_pcl_clean.shape[0] == 0: 82 | mesh_img = torch.zeros((tgt_h, tgt_w, 3)) # [H, W, 3] 83 | mesh_mask = torch.zeros((tgt_h, tgt_w, 1)) # [H, W, 1] 84 | else: 85 | img_size = torch.LongTensor([tgt_h, tgt_w]).reshape((1, 2)) 86 | cameras_pytorch3d = pytorch3d.utils.cameras_from_opencv_projection( 87 | w2c[None, :3, :3], w2c[None, :3, 3], K[None, :3, :3], img_size 88 | ) 89 | 90 | # for bin size, see https://github.com/facebookresearch/pytorch3d/issues/1064 91 | raster_settings = pytorch3d.renderer.PointsRasterizationSettings( 92 | image_size=(tgt_h, tgt_w), 93 | radius=render_cfg.st_render_pcl_pt_radius, 94 | points_per_pixel=render_cfg.st_render_pcl_pts_per_pixel, 95 | bin_size=0, 96 | ) 97 | 98 | # Create a points renderer by compositing points using an alpha compositor (nearer points 99 | # are weighted more heavily). See [1] for an explanation. 100 | rasterizer = pytorch3d.renderer.PointsRasterizer( 101 | cameras=cameras_pytorch3d, raster_settings=raster_settings 102 | ) 103 | point_renderer = pytorch3d.renderer.PointsRenderer( 104 | rasterizer=rasterizer, 105 | compositor=pytorch3d.renderer.NormWeightedCompositor( 106 | background_color=(0, 0, 0) 107 | ), 108 | ) 109 | 110 | dy_mesh = pytorch3d.structures.Pointclouds( 111 | points=st_pcl_clean[None, ...], # [1, #pts, 3] 112 | features=st_rgb_clean[None, ...], # [1, #pt, 3] 113 | ) 114 | 115 | mesh_img = point_renderer(dy_mesh)[0, :, :, :3] # [H, W, 3] 116 | 117 | dy_mesh.features = torch.ones_like(st_rgb_clean)[None, ...] 
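            # Render the point cloud a second time with all-ones features: pixels that
            # receive any splatted weight become 1, giving a binary coverage mask for
            # the static content at the target view.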
118 | mesh_mask = ( 119 | point_renderer(dy_mesh)[0, :, :, :1] > 0.0 120 | ).float() # [H, W, 1] 121 | 122 | return mesh_img, mesh_mask 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Pseudo-Generalized Dynamic View Synthesis 2 | 3 | ### ICLR 2024 4 | 5 | 6 | 7 | 8 | 9 | **Pseudo-Generalized Dynamic View Synthesis from a Video, ICLR 2024.**
10 | [Xiaoming Zhao](https://xiaoming-zhao.com/), [Alex Colburn](https://www.colburn.org/), [Fangchang Ma](https://fangchangma.github.io/), [Miguel Ángel Bautista](https://scholar.google.com/citations?user=ZrRs-qoAAAAJ&hl=en), [Joshua M Susskind](https://scholar.google.com/citations?user=Sv2TGqsAAAAJ&hl=en), and [Alexander G. Schwing](https://www.alexander-schwing.de/). 11 | 12 | ### [Project Page](https://xiaoming-zhao.github.io/projects/pgdvs/) | [Paper](https://arxiv.org/abs/2310.08587) 13 | 14 | ## Table of Contents 15 | 16 | - [Environment Setup](#environment-setup) 17 | - [Try PGDVS on Video in the Wild](#try-pgdvs-on-video-in-the-wild) 18 | - [Benchmarking](#benchmarking) 19 | - [Citation](#citation) 20 | - [License](#license) 21 | - [Acknowledgements](#acknowledgements) 22 | 23 | ## Environment Setup 24 | 25 | This code has been tested on Ubuntu 20.04 with CUDA 11.8 on NVIDIA A100-SXM4-80GB GPU (driver 470.82.01). 26 | 27 | We recommend using `conda` for virtual environment control and [`libmamba`](https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community) for a faster dependency check. 28 | 29 | ```bash 30 | # setup libmamba 31 | conda install -n base conda-libmamba-solver -y 32 | conda config --set solver libmamba 33 | 34 | # create virtual environment 35 | conda env create -f envs/pgdvs.yaml 36 | 37 | conda activate pgdvs 38 | conda install pytorch3d=0.7.4 -c pytorch3d -y 39 | ``` 40 | 41 | **[optional]** Run the following to install JAX if you want to 42 | 1. try [TAPIR](https://github.com/google-deepmind/tapnet) 43 | 2. evaluate with metrics computation from [DyCheck](https://github.com/KAIR-BAIR/dycheck) 44 | ```bash 45 | conda activate pgdvs 46 | pip install -r envs/requirements_jax.txt --verbose 47 | ``` 48 | To check that JAX is installed correctly, run the following. 49 | **NOTE**: the first `import torch` is important since it will make sure that JAX finds the cuDNN installed by `conda`. 50 | ```bash 51 | conda activate pgdvs 52 | python -c "import torch; from jax import random; key = random.PRNGKey(0); x = random.normal(key, (10,)); print(x)" 53 | ``` 54 | 55 | ## Try PGDVS on Video in the Wild 56 | 57 | ### Download Checkpoints 58 | 59 | ```bash 60 | # this environment variable is used for demonstration 61 | cd /path/to/this/repo 62 | export PGDVS_ROOT=$PWD 63 | ``` 64 | 65 | Since we use third parties's pretrained models, we provide two ways to download them: 66 | 1. Directly download from those official repositories; 67 | 2. Download from our copy for reproducing results in the paper just in case those official repositories's checkpoints are modified in the future. 68 | ```bash 69 | FLAG_ORIGINAL=1 # set to 0 if you want to download from our copy 70 | bash ${PGDVS_ROOT}/scripts/download_ckpts.sh ${PGDVS_ROOT}/ckpts ${FLAG_ORIGINAL} 71 | ``` 72 | 73 | ### Example of DAVIS 74 | 75 | We use [DAVIS](https://davischallenge.org/) as an example to illustrate how to render novel view from monocular videos in the wild. Please see [IN_THE_WILD.md](./docs/IN_THE_WILD.md) for details. 76 | 77 | 78 | ## Benchmarking 79 | 80 | Please see [BENCHMARK_NVIDIA.md](./docs/BENCHMARK_NVIDIA.md) and [BENCHMARK_iPhone.md](./docs/BENCHMARK_iPhone.md) for details about reproducing results on [NVIDIA Dynamic Scenes](https://gorokee.github.io/jsyoon/dynamic_synth/) and [DyCheck's iPhone Dataset](https://github.com/KAIR-BAIR/dycheck) in the paper. 81 | 82 | ## Citation 83 | >Xiaoming Zhao, Alex Colburn, Fangchang Ma, Miguel Ángel Bautista, Joshua M Susskind, and Alexander G. Schwing. 
Pseudo-Generalized Dynamic View Synthesis from a Video. ICLR 2024. 84 | ``` 85 | @inproceedings{Zhao2024PGDVS, 86 | title={{Pseudo-Generalized Dynamic View Synthesis from a Video}}, 87 | author={Xiaoming Zhao and Alex Colburn and Fangchang Ma and Miguel Angel Bautista and Joshua M. Susskind and Alexander G. Schwing}, 88 | booktitle={ICLR}, 89 | year={2024}, 90 | } 91 | ``` 92 | 93 | ## License 94 | 95 | This sample code is released under the [LICENSE](./LICENSE) terms. 96 | 97 | ## Acknowledgements 98 | 99 | Our project is not possible without the following ones: 100 | - [GNT](https://github.com/VITA-Group/GNT) (commit `7b63996cb807dbb5c95ab6898e8093996588e73a`) 101 | - [RAFT](https://github.com/princeton-vl/RAFT) (commit `3fa0bb0a9c633ea0a9bb8a79c576b6785d4e6a02`) 102 | - [OneFormer](https://github.com/SHI-Labs/OneFormer) (commit `56799ef9e02968af4c7793b30deabcbeec29ffc0`) 103 | - [segment-anything](https://github.com/facebookresearch/segment-anything) (commit `6fdee8f2727f4506cfbbe553e23b895e27956588`) 104 | - [ZoeDepth](https://github.com/isl-org/ZoeDepth) (commit `edb6daf45458569e24f50250ef1ed08c015f17a7`) 105 | - [TAPIR](https://github.com/deepmind/tapnet) (commit `4ac6b2acd0aed36c0762f4247de9e8630340e2e0`) 106 | - [CoTracker](https://github.com/facebookresearch/co-tracker) (commit `0a0596b277545625054cb041f00419bcd3693ea5`) 107 | - [casualSAM](https://github.com/ztzhang/casualSAM) (we use [our modified version](https://github.com/Xiaoming-Zhao/casualSAM)) 108 | - [dynamic-video-depth](https://github.com/google/dynamic-video-depth) (we use [our modified version](https://github.com/Xiaoming-Zhao/dynamic-video-depth)) -------------------------------------------------------------------------------- /pgdvs/engines/abstract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import random 4 | import boto3 5 | import botocore 6 | import logging 7 | import numpy as np 8 | from abc import abstractmethod 9 | from omegaconf import DictConfig 10 | 11 | import torch 12 | 13 | from pgdvs.utils.training import resume_from_ckpt, upload_to_s3 14 | 15 | LOGGER = logging.getLogger(__name__) 16 | 17 | 18 | def combine(l): 19 | if isinstance(l[0], torch.Tensor): 20 | return torch.stack(l, dim=0) 21 | if isinstance(l[0], float): 22 | return torch.Tensor(l) 23 | if isinstance(l[0], list) and isinstance(l[0][0], float): 24 | return torch.Tensor(l) 25 | else: 26 | return l 27 | 28 | 29 | def default_collate_fn(batch): 30 | return {k: combine([x[k] for x in batch]) for k in batch[0].keys()} 31 | 32 | 33 | class AbstractEngine: 34 | def __init__( 35 | self, 36 | cfg: DictConfig, 37 | hydra_config: DictConfig, 38 | *, 39 | global_rank: int, 40 | local_rank: int, 41 | world_size: int, 42 | run_dir: str, 43 | hydra_config_dir: str = None, 44 | hydra_log_f: str = None, 45 | is_training: bool = False, 46 | flag_verbose: bool = True, 47 | ) -> None: 48 | self.global_rank = global_rank 49 | self.local_rank = local_rank 50 | self.world_size = world_size 51 | 52 | self.cfg = cfg 53 | self.hydra_config = hydra_config 54 | 55 | if torch.cuda.is_available(): 56 | self.device = "cuda:%d" % local_rank 57 | else: 58 | self.device = torch.device("cpu") 59 | 60 | self.is_training = is_training 61 | 62 | self.verbose = flag_verbose 63 | 64 | self.hydra_log_f = hydra_log_f 65 | 66 | self.LOG_DIR = run_dir # Essentially hydra_config.runtime.output_dir 67 | 68 | self.INFO_DIR = os.path.join(self.LOG_DIR, f"infos") 69 | self.CHECKPOINT_FOLDER = 
os.path.join(self.LOG_DIR, "checkpoints") 70 | self.TENSORBOARD_DIR = os.path.join(self.LOG_DIR, "tb") 71 | self.VIS_DIR = os.path.join(self.LOG_DIR, "vis") 72 | 73 | os.makedirs(self.INFO_DIR, exist_ok=True) 74 | os.makedirs(self.CHECKPOINT_FOLDER, exist_ok=True) 75 | os.makedirs(self.TENSORBOARD_DIR, exist_ok=True) 76 | os.makedirs(self.VIS_DIR, exist_ok=True) 77 | 78 | LOGGER.info("All files are saved to %s" % self.LOG_DIR) 79 | 80 | # setup model 81 | self.model_modules_not_to_save = None # this must be placed before build_model 82 | self.model, self.optimizer = self.build_model() 83 | 84 | self.init_epoch = 0 85 | self.init_total_steps = 0 86 | self.init_total_steps_on_epoch_start = 0 87 | 88 | self.max_epochs = self.cfg.max_epochs 89 | 90 | self.s3_info = None 91 | 92 | if ( 93 | self.cfg.resume 94 | not in ["none", "elastic", "eval_wo_resume", "vis_wo_resume"] 95 | ) or (self.cfg.resume in ["eval"] and not self.cfg.series_eval): 96 | self.run_resume(self.cfg.resume_epoch) 97 | 98 | if isinstance(self.model, dict): 99 | for k in self.model: 100 | self.model[k].to(self.device) 101 | else: 102 | self.model.to(self.device) 103 | 104 | # setup dataset 105 | self.datasets = self.build_dataset() 106 | LOGGER.info( 107 | f"Initialization done in Rank {self.global_rank} | Local rank {self.local_rank}" 108 | ) 109 | 110 | def run_resume(self, resume_epoch): 111 | if self.cfg.resume_dir is not None: 112 | # load other pretrain models. 113 | ckpt_dir = self.cfg.resume_dir 114 | else: 115 | # resume from the default path. 116 | ckpt_dir = self.CHECKPOINT_FOLDER 117 | 118 | ( 119 | self.init_epoch, 120 | _, 121 | self.init_total_steps, 122 | self.init_total_steps_on_epoch_start, 123 | ) = resume_from_ckpt( 124 | ckpt_dir=ckpt_dir, 125 | model=self.model, 126 | modules_not_saved=self.model_modules_not_to_save, 127 | optimizer=self.optimizer if self.cfg.resume_dir is not None else None, 128 | epoch=resume_epoch, 129 | strict=False, 130 | cfg=self.cfg, 131 | device=self.device, 132 | ) 133 | 134 | def upload_info_to_s3(self): 135 | for cur_root, cur_dirs, cur_files in os.walk(self.LOG_DIR): 136 | for tmp_f in cur_files: 137 | local_f = os.path.join(cur_root, tmp_f) 138 | if ("tb" in local_f) or ("hydra" in local_f): 139 | upload_to_s3(local_f, self.LOG_DIR, **self.s3_info) 140 | 141 | def run_preparation(self): 142 | pass 143 | 144 | def _set_seed(self, seed): 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | torch.manual_seed(seed) 148 | if torch.cuda.is_available(): 149 | torch.cuda.manual_seed_all(seed) 150 | torch.backends.cudnn.deterministic = True 151 | torch.backends.cudnn.benchmark = False 152 | 153 | def _to_gpu_func(self, batch, device): 154 | return { 155 | k: v.to(device) if isinstance(v, torch.Tensor) else v 156 | for k, v in batch.items() 157 | } 158 | 159 | @abstractmethod 160 | def build_model(self): 161 | raise NotImplementedError 162 | 163 | @abstractmethod 164 | def build_dataset(self): 165 | raise NotImplementedError 166 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/models/core/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
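# Sinusoidal positional-embedding helpers: dense 2D-grid embeddings
# (get_2d_sincos_pos_embed) and per-point 2D/3D/4D encodings used by CoTracker.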
6 | 7 | import torch 8 | import numpy as np 9 | 10 | 11 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 12 | """ 13 | grid_size: int of the grid height and width 14 | return: 15 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 16 | """ 17 | if isinstance(grid_size, tuple): 18 | grid_size_h, grid_size_w = grid_size 19 | else: 20 | grid_size_h = grid_size_w = grid_size 21 | grid_h = np.arange(grid_size_h, dtype=np.float32) 22 | grid_w = np.arange(grid_size_w, dtype=np.float32) 23 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 24 | grid = np.stack(grid, axis=0) 25 | 26 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 27 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 28 | if cls_token and extra_tokens > 0: 29 | pos_embed = np.concatenate( 30 | [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 31 | ) 32 | return pos_embed 33 | 34 | 35 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 36 | assert embed_dim % 2 == 0 37 | 38 | # use half of dimensions to encode grid_h 39 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 40 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 41 | 42 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 43 | return emb 44 | 45 | 46 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 47 | """ 48 | embed_dim: output dimension for each position 49 | pos: a list of positions to be encoded: size (M,) 50 | out: (M, D) 51 | """ 52 | assert embed_dim % 2 == 0 53 | omega = np.arange(embed_dim // 2, dtype=np.float64) 54 | omega /= embed_dim / 2.0 55 | omega = 1.0 / 10000**omega # (D/2,) 56 | 57 | pos = pos.reshape(-1) # (M,) 58 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 59 | 60 | emb_sin = np.sin(out) # (M, D/2) 61 | emb_cos = np.cos(out) # (M, D/2) 62 | 63 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 64 | return emb 65 | 66 | 67 | def get_2d_embedding(xy, C, cat_coords=True): 68 | B, N, D = xy.shape 69 | assert D == 2 70 | 71 | x = xy[:, :, 0:1] 72 | y = xy[:, :, 1:2] 73 | div_term = ( 74 | torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) 75 | ).reshape(1, 1, int(C / 2)) 76 | 77 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 78 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 79 | 80 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 81 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 82 | 83 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 84 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 85 | 86 | pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3 87 | if cat_coords: 88 | pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3 89 | return pe 90 | 91 | 92 | def get_3d_embedding(xyz, C, cat_coords=True): 93 | B, N, D = xyz.shape 94 | assert D == 3 95 | 96 | x = xyz[:, :, 0:1] 97 | y = xyz[:, :, 1:2] 98 | z = xyz[:, :, 2:3] 99 | div_term = ( 100 | torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C) 101 | ).reshape(1, 1, int(C / 2)) 102 | 103 | pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) 104 | pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) 105 | pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32) 106 | 107 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 108 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 109 | 110 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 111 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 
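    # The z coordinate is encoded with the same interleaved sin/cos pattern as x and y.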
112 | 113 | pe_z[:, :, 0::2] = torch.sin(z * div_term) 114 | pe_z[:, :, 1::2] = torch.cos(z * div_term) 115 | 116 | pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3 117 | if cat_coords: 118 | pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3 119 | return pe 120 | 121 | 122 | def get_4d_embedding(xyzw, C, cat_coords=True): 123 | B, N, D = xyzw.shape 124 | assert D == 4 125 | 126 | x = xyzw[:, :, 0:1] 127 | y = xyzw[:, :, 1:2] 128 | z = xyzw[:, :, 2:3] 129 | w = xyzw[:, :, 3:4] 130 | div_term = ( 131 | torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C) 132 | ).reshape(1, 1, int(C / 2)) 133 | 134 | pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) 135 | pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) 136 | pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) 137 | pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32) 138 | 139 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 140 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 141 | 142 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 143 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 144 | 145 | pe_z[:, :, 0::2] = torch.sin(z * div_term) 146 | pe_z[:, :, 1::2] = torch.cos(z * div_term) 147 | 148 | pe_w[:, :, 0::2] = torch.sin(w * div_term) 149 | pe_w[:, :, 1::2] = torch.cos(w * div_term) 150 | 151 | pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3 152 | if cat_coords: 153 | pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3 154 | return pe 155 | -------------------------------------------------------------------------------- /pgdvs/preprocess/convert_colmap_output.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_scripts/save_poses_nerf.py 2 | 3 | import os 4 | import sys 5 | import json 6 | import tqdm 7 | import pathlib 8 | import argparse 9 | import numpy as np 10 | 11 | import pgdvs.preprocess.colmap_reader as colmap_reader 12 | 13 | 14 | def hwf_to_K(hwf): 15 | h, w, f = hwf[:, 0].tolist() # [3, 1] -> [3,] 16 | print(h, w, f) 17 | K = np.eye(3) 18 | K[0, 0] = f 19 | K[1, 1] = f 20 | K[0, 2] = w / 2.0 21 | K[1, 2] = h / 2.0 22 | 23 | return K 24 | 25 | 26 | def get_bbox_corners(points): 27 | lower = points.min(axis=0) 28 | upper = points.max(axis=0) 29 | return np.stack([lower, upper]) 30 | 31 | 32 | def filter_outlier_points(points, inner_percentile): 33 | """Filters outlier points.""" 34 | outer = 1.0 - inner_percentile 35 | lower = outer / 2.0 36 | upper = 1.0 - lower 37 | centers_min = np.quantile(points, lower, axis=0) 38 | centers_max = np.quantile(points, upper, axis=0) 39 | result = points.copy() 40 | 41 | too_near = np.any(result < centers_min[None, :], axis=1) 42 | too_far = np.any(result > centers_max[None, :], axis=1) 43 | 44 | return result[~(too_near | too_far)] 45 | 46 | 47 | def load_colmap_data(realdir, save_dir): 48 | camerasfile = os.path.join(realdir, "sparse/cameras.bin") 49 | camdata = colmap_reader.read_cameras_binary(camerasfile) 50 | 51 | list_of_keys = list(camdata.keys()) 52 | assert ( 53 | len(list_of_keys) == 1 54 | ), f"{list_of_keys}" # check that there is only one track 55 | cam = camdata[list_of_keys[0]] 56 | print("Cameras", len(cam)) 57 | 58 | h, w, f = cam.height, cam.width, cam.params[0] 59 | # w, h, f = factor * w, factor * h, factor * f 60 | hwf = np.array([h, w, f]).reshape([3, 1]) 61 | 62 | imagesfile = os.path.join(realdir, "sparse/images.bin") 63 
| imdata = colmap_reader.read_images_binary(imagesfile) 64 | 65 | w2c_mats = [] 66 | bottom = np.array([0, 0, 0, 1.0]).reshape([1, 4]) 67 | 68 | names = [imdata[k].name for k in imdata] 69 | img_keys = [k for k in imdata] 70 | 71 | print("Images #", len(names)) 72 | perm = np.argsort(names) 73 | 74 | points3dfile = os.path.join(realdir, "sparse/points3D.bin") 75 | pts3d = colmap_reader.read_points3d_binary(points3dfile) 76 | 77 | # extract point 3D xyz 78 | point_cloud = [] 79 | for key in pts3d: 80 | point_cloud.append(pts3d[key].xyz) 81 | 82 | point_cloud = np.stack(point_cloud, 0) 83 | point_cloud = filter_outlier_points(point_cloud, 0.95) 84 | 85 | bounds_mats = [] 86 | 87 | upper_bound = 1000 88 | 89 | if upper_bound < len(img_keys): 90 | print("Only keeping " + str(upper_bound) + " images!") 91 | 92 | for i in tqdm.tqdm(perm[0 : min(upper_bound, len(img_keys))], desc="#images"): 93 | im = imdata[img_keys[i]] 94 | # print(im.name) 95 | R = im.qvec2rotmat() 96 | t = im.tvec.reshape([3, 1]) 97 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 98 | w2c_mats.append(m) 99 | 100 | pts_3d_idx = im.point3D_ids 101 | pts_3d_vis_idx = pts_3d_idx[pts_3d_idx >= 0] 102 | 103 | # 104 | depth_list = [] 105 | for k in range(len(pts_3d_vis_idx)): 106 | point_info = pts3d[pts_3d_vis_idx[k]] 107 | P_g = point_info.xyz 108 | P_c = np.dot(R, P_g.reshape(3, 1)) + t.reshape(3, 1) 109 | depth_list.append(P_c[2]) 110 | 111 | zs = np.array(depth_list) 112 | close_depth, inf_depth = np.percentile(zs, 5), np.percentile(zs, 95) 113 | bounds = np.array([close_depth, inf_depth]) 114 | bounds_mats.append(bounds) 115 | 116 | w2c_mats = np.stack(w2c_mats, 0) 117 | # bounds_mats = np.stack(bounds_mats, 0) 118 | c2w_mats = np.linalg.inv(w2c_mats) # [#frame, 4, 4] 119 | 120 | # bbox_corners = get_bbox_corners(point_cloud) 121 | # also add camera 122 | bbox_corners = get_bbox_corners( 123 | np.concatenate([point_cloud, c2w_mats[:, :3, 3]], axis=0) 124 | ) 125 | 126 | scene_center = np.mean(bbox_corners, axis=0) 127 | scene_scale = 1.0 / np.sqrt(np.sum((bbox_corners[1] - bbox_corners[0]) ** 2)) 128 | 129 | print("bbox_corners ", bbox_corners) 130 | print("scene_center ", scene_center, scene_scale) 131 | 132 | K = np.eye(4) 133 | K[:3, :3] = hwf_to_K(hwf) # [4, 4] 134 | 135 | n_frames = c2w_mats.shape[0] 136 | 137 | tiled_K = np.tile(K[np.newaxis, ...], [n_frames, 1, 1]) # [#frame, 4, 4] 138 | 139 | save_arr = np.concatenate( 140 | [c2w_mats.reshape((n_frames, 16)), tiled_K.reshape((n_frames, 16))], 1 141 | ) # [#frame, 32] 142 | 143 | print(save_arr.shape) 144 | np.save(save_dir / "poses.npy", save_arr) 145 | 146 | with open(save_dir / "scene.json", "w") as f: 147 | json.dump( 148 | { 149 | "scale": scene_scale, 150 | "center": scene_center.tolist(), 151 | "bbox": bbox_corners.tolist(), 152 | }, 153 | f, 154 | indent=2, 155 | ) 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument("--data_path", type=str, help="COLMAP Directory") 161 | parser.add_argument("--save_dir", type=str, help="save Directory") 162 | 163 | args = parser.parse_args() 164 | 165 | base_dir = args.data_path 166 | save_dir = pathlib.Path(args.save_dir) 167 | save_dir.mkdir(parents=True, exist_ok=True) 168 | load_colmap_data(base_dir, save_dir) 169 | print("Done with imgs2poses") 170 | -------------------------------------------------------------------------------- /pgdvs/utils/rendering.py: -------------------------------------------------------------------------------- 1 | import os 2 
| import imageio 3 | import tqdm 4 | import numpy as np 5 | from typing import List, Optional 6 | 7 | import torch 8 | 9 | 10 | def clamp_rgb(rgb, tgt_range="0_255"): 11 | assert tgt_range in ["-1_1", "0_1", "0_255"], tgt_range 12 | 13 | min_val = float(tgt_range.split("_")[0]) 14 | max_val = float(tgt_range.split("_")[1]) 15 | 16 | if isinstance(rgb, np.ndarray): 17 | rgb = np.clip(rgb, min_val, max_val) 18 | elif isinstance(rgb, torch.Tensor): 19 | rgb = torch.clamp(rgb, min_val, max_val) 20 | else: 21 | raise TypeError 22 | 23 | return rgb 24 | 25 | 26 | def modify_rgb_range( 27 | rgb, src_range="0_255", tgt_range="0_255", check_range=True, enforce_range=False 28 | ): 29 | assert src_range in ["-1_1", "0_1", "0_255"], src_range 30 | assert tgt_range in ["-1_1", "0_1", "0_255"], tgt_range 31 | 32 | if src_range == tgt_range: 33 | return rgb 34 | 35 | if isinstance(rgb, np.ndarray): 36 | rgb = rgb.astype(np.float32) 37 | elif isinstance(rgb, torch.Tensor): 38 | rgb = rgb.float() 39 | else: 40 | raise TypeError 41 | 42 | if check_range: 43 | src_min = float(src_range.split("_")[0]) 44 | src_max = float(src_range.split("_")[1]) 45 | assert rgb.min() >= src_min, f"{rgb.min()}, {src_min}" 46 | assert rgb.max() <= src_max, f"{rgb.max()}, {src_max}" 47 | 48 | # We change ranges to be in [0, 1] 49 | if src_range == "0_255": 50 | rgb = rgb / 255.0 51 | elif src_range == "-1_1": 52 | rgb = (rgb + 1.0) / 2.0 53 | else: 54 | pass 55 | 56 | if enforce_range: 57 | if isinstance(rgb, np.ndarray): 58 | rgb = np.clip(rgb, 0.0, 1.0) 59 | elif isinstance(rgb, torch.Tensor): 60 | rgb = rgb.clamp(0.0, 1.0) 61 | else: 62 | raise TypeError 63 | 64 | # Now, RGB is in range of [0, 1] 65 | if tgt_range == "-1_1": 66 | rgb = 2.0 * rgb - 1.0 67 | elif tgt_range == "0_255": 68 | rgb = rgb * 255.0 69 | 70 | if check_range: 71 | tgt_min = float(tgt_range.split("_")[0]) 72 | tgt_max = float(tgt_range.split("_")[1]) 73 | assert rgb.min() >= tgt_min, f"{rgb.min()}, {tgt_min}" 74 | assert rgb.max() <= tgt_max, f"{rgb.max()}, {tgt_max}" 75 | 76 | return rgb 77 | 78 | 79 | def images_to_video( 80 | images: List[np.ndarray], 81 | output_dir: str, 82 | video_name: str, 83 | fps: int = 10, 84 | quality: Optional[float] = 5, 85 | disable_tqdm=False, 86 | **kwargs, 87 | ): 88 | r"""Calls imageio to run FFMPEG on a list of images. For more info on 89 | parameters, see https://imageio.readthedocs.io/en/stable/format_ffmpeg.html 90 | Args: 91 | images: The list of images. Images should be HxWx3 in RGB order. 92 | output_dir: The folder to put the video in. 93 | video_name: The name for the video. 94 | fps: Frames per second for the video. Not all values work with FFMPEG, 95 | use at your own risk. 96 | quality: Default is 5. Uses variable bit rate. Highest quality is 10, 97 | lowest is 0. Set to None to prevent variable bitrate flags to 98 | FFMPEG so you can manually specify them using output_params 99 | instead. Specifying a fixed bitrate using ‘bitrate’ disables 100 | this parameter. 
101 | """ 102 | assert 0 <= quality <= 10 103 | if not os.path.exists(output_dir): 104 | os.makedirs(output_dir) 105 | video_name = video_name.replace(" ", "_").replace("\n", "_") + ".mp4" 106 | writer = imageio.get_writer( 107 | os.path.join(output_dir, video_name), 108 | fps=fps, 109 | quality=quality, 110 | macro_block_size=1, 111 | **kwargs, 112 | ) 113 | # print(f"Video created: {os.path.join(output_dir, video_name)}") 114 | for im in tqdm.tqdm(images, disable=disable_tqdm): 115 | writer.append_data(im) 116 | writer.close() 117 | 118 | 119 | def CreateRenderPoses(xCamTarget, N, device): 120 | c2w = xCamTarget.transform[0, ...] 121 | up = xCamTarget.transform[0, :3, 1] 122 | 123 | rads = torch.cat( 124 | ( 125 | 0.8 * torch.abs(xCamTarget.transform[0, :3, 3]), 126 | torch.tensor( 127 | [ 128 | 1, 129 | ], 130 | device=device, 131 | ), 132 | ), 133 | dim=0, 134 | ) 135 | rots = 2 136 | focal = 1 137 | flip = False 138 | 139 | def normalize(v): 140 | return v / torch.linalg.norm(v) 141 | 142 | def viewmatrix(z, up, pos): 143 | vec2 = normalize(z) 144 | vec1_avg = up 145 | vec0 = normalize(torch.cross(vec1_avg, vec2)) 146 | vec1 = normalize(torch.cross(vec2, vec0)) 147 | m = torch.stack([vec0, vec1, vec2, pos], 1) 148 | return m 149 | 150 | render_poses = [] 151 | for theta in torch.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]: 152 | c = torch.matmul( 153 | c2w[:3, :4], 154 | torch.tensor( 155 | [torch.cos(theta), -torch.sin(theta), -torch.sin(theta * 0.5), 1.0], 156 | device=device, 157 | ) 158 | * rads, 159 | ) 160 | 161 | if flip: 162 | z = normalize( 163 | torch.matmul( 164 | c2w[:3, :4], torch.tensor([0, 0, focal, 1.0], device=device) 165 | ) 166 | - c 167 | ) 168 | else: 169 | z = normalize( 170 | c 171 | - torch.matmul( 172 | c2w[:3, :4], torch.tensor([0, 0, -focal, 1.0], device=device) 173 | ) 174 | ) 175 | 176 | render_poses.append(viewmatrix(z, up, c)) 177 | return render_poses 178 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/models/core/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | EPS = 1e-6 10 | 11 | 12 | def smart_cat(tensor1, tensor2, dim): 13 | if tensor1 is None: 14 | return tensor2 15 | return torch.cat([tensor1, tensor2], dim=dim) 16 | 17 | 18 | def normalize_single(d): 19 | # d is a whatever shape torch tensor 20 | dmin = torch.min(d) 21 | dmax = torch.max(d) 22 | d = (d - dmin) / (EPS + (dmax - dmin)) 23 | return d 24 | 25 | 26 | def normalize(d): 27 | # d is B x whatever. 
normalize within each element of the batch 28 | out = torch.zeros(d.size()) 29 | if d.is_cuda: 30 | out = out.cuda() 31 | B = list(d.size())[0] 32 | for b in list(range(B)): 33 | out[b] = normalize_single(d[b]) 34 | return out 35 | 36 | 37 | def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"): 38 | # returns a meshgrid sized B x Y x X 39 | 40 | grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device)) 41 | grid_y = torch.reshape(grid_y, [1, Y, 1]) 42 | grid_y = grid_y.repeat(B, 1, X) 43 | 44 | grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device)) 45 | grid_x = torch.reshape(grid_x, [1, 1, X]) 46 | grid_x = grid_x.repeat(B, Y, 1) 47 | 48 | if stack: 49 | # note we stack in xy order 50 | # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample) 51 | grid = torch.stack([grid_x, grid_y], dim=-1) 52 | return grid 53 | else: 54 | return grid_y, grid_x 55 | 56 | 57 | def reduce_masked_mean(x, mask, dim=None, keepdim=False): 58 | # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting 59 | # returns shape-1 60 | # axis can be a list of axes 61 | for a, b in zip(x.size(), mask.size()): 62 | assert a == b # some shape mismatch! 63 | prod = x * mask 64 | if dim is None: 65 | numer = torch.sum(prod) 66 | denom = EPS + torch.sum(mask) 67 | else: 68 | numer = torch.sum(prod, dim=dim, keepdim=keepdim) 69 | denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim) 70 | 71 | mean = numer / denom 72 | return mean 73 | 74 | 75 | def bilinear_sample2d(im, x, y, return_inbounds=False): 76 | # x and y are each B, N 77 | # output is B, C, N 78 | if len(im.shape) == 5: 79 | B, N, C, H, W = list(im.shape) 80 | else: 81 | B, C, H, W = list(im.shape) 82 | N = list(x.shape)[1] 83 | 84 | x = x.float() 85 | y = y.float() 86 | H_f = torch.tensor(H, dtype=torch.float32) 87 | W_f = torch.tensor(W, dtype=torch.float32) 88 | 89 | # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x -0.5).byte() & (x < float(W_f - 0.5)).byte() 162 | y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte() 163 | inbounds = (x_valid & y_valid).float() 164 | inbounds = inbounds.reshape( 165 | B, N 166 | ) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1) 167 | return output, inbounds 168 | 169 | return output # B, C, N 170 | -------------------------------------------------------------------------------- /docs/IN_THE_WILD.md: -------------------------------------------------------------------------------- 1 | # Run PGDVS on an in-the-wild Video 2 | 3 | ## Table of Contents 4 | 5 | - [1 Download Data](#1-download-data) 6 | - [2 Prepare Dependencies](#2-prepare-dependencies) 7 | - [3 Preprocess](#3-preprocess) 8 | - [3.1 Two-step Camera Pose and Depth Estimations](#31-two-step-camera-pose-and-depth-estimations) 9 | - [3.2 One-step Camera Pose and Depth estimations](#32-one-step-camera-pose-and-depth-estimations) 10 | - [4 Run Spatial Temporal Interpolation Visualization](#4-run-spatial-temporal-interpolation-visualization) 11 | 12 | ## 1 Download Data 13 | 14 | ```bash 15 | # this environment variable is used for demonstration 16 | cd /path/to/this/repo 17 | export PGDVS_ROOT=$PWD 18 | ``` 19 | 20 | We use [DAVIS](https://davischallenge.org/) as an example to illustrate how to render novel view from monocular videos in the wild. 
First we download the dataset: 21 | ```bash 22 | wget https://graphics.ethz.ch/Downloads/Data/Davis/DAVIS-data.zip -P ${PGDVS_ROOT}/data 23 | unzip ${PGDVS_ROOT}/data/DAVIS-data.zip -d ${PGDVS_ROOT}/data 24 | ``` 25 | 26 | ## 2 Prepare Dependencies 27 | 28 | We need several third parties's repositories and checkpoints. **NOTE**: the `CUDA_HOME` must be set correctly for [`detectron2`](https://github.com/facebookresearch/detectron2)'s installation and consequentially [`OneFormer`](https://github.com/SHI-Labs/OneFormer)'s usage. 29 | ```bash 30 | CUDA_HOME=/usr/local/cuda # set to your own CUDA_HOME, where nvcc is installed 31 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \ 32 | ${CUDA_HOME} \ 33 | ${PGDVS_ROOT} \ 34 | ${PGDVS_ROOT}/data \ 35 | prepare 36 | ``` 37 | After running the command, repositories and pretrained checkpoints will be saved to `${PGDVS_ROOT}/third_parties` and `${PGDVS_ROOT}/ckpts` respectively. 38 | 39 | ## 3 Preprocess 40 | 41 | For a monocular video, we provide two ways to preprocess, i.e., obtaining camera poses, consistent depths, optical flows, and potentially dynamic masks: 42 | 1. **Two-step** camera pose and depth estimations: we need [`COLMAP`](https://github.com/colmap/colmap) for this. We first run `COLMAP` and then apply [Consistent Depth of Moving Objects in Video](https://arxiv.org/abs/2108.01166) (official code is [here](https://github.com/google/dynamic-video-depth) and our modified version is [here](https://github.com/Xiaoming-Zhao/dynamic-video-depth)). 43 | 2. **One-step** camera pose and depth estimations: we directly run [Structure and Motion from Casual Videos](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136930020.pdf) (official code is [here](https://github.com/ztzhang/casualSAM) and our modified version is [here](https://github.com/Xiaoming-Zhao/casualSAM)). 44 | 45 | **Note 1**: we modify the two consistent depth estimation tools above to tailor to our needs. For the preprocessing, please use **our forked/modified versions** ([1](https://github.com/Xiaoming-Zhao/dynamic-video-depth), [2](https://github.com/Xiaoming-Zhao/casualSAM)), which will be automatically handled by our [preprocess.sh](../scripts/preprocess/preprocess.sh). 46 | 47 | **Note 2**: we empirically find that the Two-Step version works better. 48 | 49 | ### 3.1 Two-step Camera Pose and Depth Estimations 50 | 51 | We need [`COLMAP`](https://github.com/colmap/colmap) for this. If `COLMAP` has not been installed yet, you can refer to [install_colmap.sh](../scripts/preprocess/install_colmap.sh) on how to manually install it. You may need to first set an environment variable `NOCONDA_PATH` by putting `export NOCONDA_PATH=$PATH` in your `.bashrc` (or equivalent shell setup file) before `conda` changes `PATH` (see [this issue](https://github.com/pism/pism/issues/356)). 52 | 53 | ```bash 54 | # Though the script should be able to run all steps automatically, 55 | # for debugging purpose, we would recommend running those commands one by one. 56 | # Namely, you could run each command by commenting out the rest. 
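# The two-step route first estimates camera poses with COLMAP and then refines depth
# with the modified dynamic-video-depth code; preprocess.sh below drives both stages.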
57 | 
58 | CUDA_HOME=/usr/local/cuda # set to your own CUDA_HOME
59 | SCENE_ID=dog
60 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \
61 |     ${CUDA_HOME} \
62 |     ${PGDVS_ROOT} \
63 |     ${PGDVS_ROOT}/data/DAVIS/JPEGImages/480p \
64 |     execute_on_mono_two_step_pose_depth \
65 |     ${PGDVS_ROOT}/data/DAVIS_processed_two_step_pose_depth \
66 |     ${SCENE_ID} \
67 |     /usr/bin/colmap # set to your own COLMAP binary file path
68 | ```
69 | 
70 | To ease future comparisons with PGDVS, we provide four processed scenes from the DAVIS dataset on the [release page](https://github.com/apple/ml-pgdvs/releases/tag/v0.2).
71 | 
72 | ### 3.2 One-step Camera Pose and Depth Estimations
73 | 
74 | ```bash
75 | # Though the script should be able to run all steps automatically,
76 | # for debugging purposes, we recommend running the commands one by one.
77 | # Namely, you could run each command by commenting out the rest.
78 | 
79 | CUDA_HOME=/usr/local/cuda # set to your own CUDA_HOME
80 | SCENE_ID=dog
81 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \
82 |     ${CUDA_HOME} \
83 |     ${PGDVS_ROOT} \
84 |     ${PGDVS_ROOT}/data/DAVIS/JPEGImages/480p \
85 |     execute_on_mono_one_step_pose_depth \
86 |     ${PGDVS_ROOT}/data/DAVIS_processed_one_step_pose_depth \
87 |     ${SCENE_ID}
88 | ```
89 | 
90 | ## 4 Run Spatial Temporal Interpolation Visualization
91 | 
92 | After completing preprocessing, we can run spatial-temporal interpolation. Here we use the output path of the two-step camera pose and depth estimation (Section 3.1) as an example. The results will be saved to `${PGDVS_ROOT}/experiments`.
93 | ```bash
94 | # vis_bt_max_disp:
95 | # - boat: 48
96 | # - dog: 48
97 | # - stroller: 96
98 | # - train: 48
99 | 
100 | SCENE_ID='[dog]'
101 | 
102 | bash ${PGDVS_ROOT}/scripts/visualize.sh \
103 |     ${PGDVS_ROOT} \
104 |     ${PGDVS_ROOT}/ckpts \
105 |     ${PGDVS_ROOT}/data/DAVIS_processed_two_step_pose_depth/ \
106 |     mono_vis \
107 |     ${SCENE_ID} \
108 |     engine.engine_cfg.render_cfg.render_stride=1 \
109 |     vis_specifics.vis_center_time=40 \
110 |     vis_specifics.vis_time_interval=30 \
111 |     vis_specifics.vis_bt_max_disp=48 \
112 |     vis_specifics.n_render_frames=100
113 | ```
-------------------------------------------------------------------------------- /pgdvs/models/tapnet/utils/transforms.py: --------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Utilities for transforming image coordinates."""
17 | 
18 | from typing import Sequence
19 | import numpy as np
20 | 
21 | import chex
22 | 
23 | import torch
24 | 
25 | 
26 | def convert_grid_coordinates(
27 |     coords: chex.Array,
28 |     input_grid_size: Sequence[int],
29 |     output_grid_size: Sequence[int],
30 |     coordinate_format: str = "xy",
31 | ) -> chex.Array:
32 |     """Convert image coordinates between image grids of different sizes.
33 | 34 | By default, it assumes that the image corners are aligned. Therefore, 35 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 36 | cell), multiplies by the size ratio, and then subtracts .5. 37 | 38 | Args: 39 | coords: The coordinates to be converted. It is of shape [..., 2] if 40 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 41 | input_grid_size: The size of the image/grid that the coordinates currently 42 | are with respect to. This is a 2-tuple of the format [width, height] 43 | if coordinate_format is 'xy' or a 3-tuple of the format 44 | [num_frames, height, width] if coordinate_format is 'tyx'. 45 | output_grid_size: The size of the target image/grid that you want the 46 | coordinates to be with respect to. This is a 2-tuple of the format 47 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 48 | [num_frames, height, width] if coordinate_format is 'tyx'. 49 | coordinate_format: Which format the coordinates are in. This can be one 50 | of 'xy' (the default) or 'tyx', which are the only formats used in this 51 | project. 52 | 53 | Returns: 54 | The transformed coordinates, of the same shape as coordinates. 55 | 56 | Raises: 57 | ValueError: if coordinates don't match the given format. 58 | """ 59 | if isinstance(input_grid_size, tuple): 60 | input_grid_size = np.array(input_grid_size) 61 | if isinstance(output_grid_size, tuple): 62 | output_grid_size = np.array(output_grid_size) 63 | 64 | if coordinate_format == "xy": 65 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 66 | raise ValueError("If coordinate_format is xy, the shapes must be length 2.") 67 | elif coordinate_format == "tyx": 68 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 69 | raise ValueError( 70 | "If coordinate_format is tyx, the shapes must be length 3." 71 | ) 72 | if input_grid_size[0] != output_grid_size[0]: 73 | raise ValueError("converting frame count is not supported.") 74 | else: 75 | raise ValueError("Recognized coordinate formats are xy and tyx.") 76 | 77 | position_in_grid = coords 78 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 79 | 80 | return position_in_grid 81 | 82 | 83 | def convert_grid_coordinates_torch( 84 | coords: torch.Tensor, 85 | input_grid_size: Sequence[int], 86 | output_grid_size: Sequence[int], 87 | coordinate_format: str = "xy", 88 | ) -> torch.Tensor: 89 | """Convert image coordinates between image grids of different sizes. 90 | 91 | By default, it assumes that the image corners are aligned. Therefore, 92 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 93 | cell), multiplies by the size ratio, and then subtracts .5. 94 | 95 | Args: 96 | coords: The coordinates to be converted. It is of shape [..., 2] if 97 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 98 | input_grid_size: The size of the image/grid that the coordinates currently 99 | are with respect to. This is a 2-tuple of the format [width, height] 100 | if coordinate_format is 'xy' or a 3-tuple of the format 101 | [num_frames, height, width] if coordinate_format is 'tyx'. 102 | output_grid_size: The size of the target image/grid that you want the 103 | coordinates to be with respect to. This is a 2-tuple of the format 104 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 105 | [num_frames, height, width] if coordinate_format is 'tyx'. 106 | coordinate_format: Which format the coordinates are in. 
This can be one 107 | of 'xy' (the default) or 'tyx', which are the only formats used in this 108 | project. 109 | 110 | Returns: 111 | The transformed coordinates, of the same shape as coordinates. 112 | 113 | Raises: 114 | ValueError: if coordinates don't match the given format. 115 | """ 116 | if isinstance(input_grid_size, tuple): 117 | input_grid_size = torch.FloatTensor(input_grid_size).to(coords.device) 118 | if isinstance(output_grid_size, tuple): 119 | output_grid_size = torch.FloatTensor(output_grid_size).to(coords.device) 120 | 121 | if coordinate_format == "xy": 122 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 123 | raise ValueError("If coordinate_format is xy, the shapes must be length 2.") 124 | elif coordinate_format == "tyx": 125 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 126 | raise ValueError( 127 | "If coordinate_format is tyx, the shapes must be length 3." 128 | ) 129 | if input_grid_size[0] != output_grid_size[0]: 130 | raise ValueError("converting frame count is not supported.") 131 | else: 132 | raise ValueError("Recognized coordinate formats are xy and tyx.") 133 | 134 | position_in_grid = coords 135 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 136 | 137 | return position_in_grid 138 | -------------------------------------------------------------------------------- /pgdvs/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/facebookresearch/detectron2/blob/45b3fcea6e76bf7a351e54e01c7d6e1a3a0100a5/detectron2/utils/comm.py 2 | 3 | 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | """ 6 | This file contains primitives for multi-gpu communication. 7 | This is useful when doing distributed training. 8 | """ 9 | 10 | import functools 11 | import numpy as np 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_PROCESS_GROUP = None 16 | """ 17 | A torch process group which only includes processes that on the same machine as the current process. 18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 19 | """ 20 | 21 | 22 | def get_world_size() -> int: 23 | if not dist.is_available(): 24 | return 1 25 | if not dist.is_initialized(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def get_rank() -> int: 31 | if not dist.is_available(): 32 | return 0 33 | if not dist.is_initialized(): 34 | return 0 35 | return dist.get_rank() 36 | 37 | 38 | def get_local_rank() -> int: 39 | """ 40 | Returns: 41 | The rank of the current process within the local (per-machine) process group. 42 | """ 43 | if not dist.is_available(): 44 | return 0 45 | if not dist.is_initialized(): 46 | return 0 47 | assert ( 48 | _LOCAL_PROCESS_GROUP is not None 49 | ), "Local process group is not created! Please use launch() to spawn processes!" 50 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 51 | 52 | 53 | def get_local_size() -> int: 54 | """ 55 | Returns: 56 | The size of the per-machine process group, 57 | i.e. the number of processes per machine. 
58 | """ 59 | if not dist.is_available(): 60 | return 1 61 | if not dist.is_initialized(): 62 | return 1 63 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 64 | 65 | 66 | def is_main_process() -> bool: 67 | return get_rank() == 0 68 | 69 | 70 | def synchronize(): 71 | """ 72 | Helper function to synchronize (barrier) among all processes when 73 | using distributed training 74 | """ 75 | if not dist.is_available(): 76 | return 77 | if not dist.is_initialized(): 78 | return 79 | world_size = dist.get_world_size() 80 | if world_size == 1: 81 | return 82 | if dist.get_backend() == dist.Backend.NCCL: 83 | # This argument is needed to avoid warnings. 84 | # It's valid only for NCCL backend. 85 | dist.barrier(device_ids=[torch.cuda.current_device()]) 86 | else: 87 | dist.barrier() 88 | 89 | 90 | @functools.lru_cache() 91 | def _get_global_gloo_group(): 92 | """ 93 | Return a process group based on gloo backend, containing all the ranks 94 | The result is cached. 95 | """ 96 | if dist.get_backend() == "nccl": 97 | return dist.new_group(backend="gloo") 98 | else: 99 | return dist.group.WORLD 100 | 101 | 102 | def all_gather(data, group=None): 103 | """ 104 | Run all_gather on arbitrary picklable data (not necessarily tensors). 105 | 106 | Args: 107 | data: any picklable object 108 | group: a torch process group. By default, will use a group which 109 | contains all ranks on gloo backend. 110 | 111 | Returns: 112 | list[data]: list of data gathered from each rank 113 | """ 114 | if get_world_size() == 1: 115 | return [data] 116 | if group is None: 117 | group = ( 118 | _get_global_gloo_group() 119 | ) # use CPU group by default, to reduce GPU RAM usage. 120 | world_size = dist.get_world_size(group) 121 | if world_size == 1: 122 | return [data] 123 | 124 | output = [None for _ in range(world_size)] 125 | dist.all_gather_object(output, data, group=group) 126 | return output 127 | 128 | 129 | def gather(data, dst=0, group=None): 130 | """ 131 | Run gather on arbitrary picklable data (not necessarily tensors). 132 | 133 | Args: 134 | data: any picklable object 135 | dst (int): destination rank 136 | group: a torch process group. By default, will use a group which 137 | contains all ranks on gloo backend. 138 | 139 | Returns: 140 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 141 | an empty list. 142 | """ 143 | if get_world_size() == 1: 144 | return [data] 145 | if group is None: 146 | group = _get_global_gloo_group() 147 | world_size = dist.get_world_size(group=group) 148 | if world_size == 1: 149 | return [data] 150 | rank = dist.get_rank(group=group) 151 | 152 | if rank == dst: 153 | output = [None for _ in range(world_size)] 154 | dist.gather_object(data, output, dst=dst, group=group) 155 | return output 156 | else: 157 | dist.gather_object(data, None, dst=dst, group=group) 158 | return [] 159 | 160 | 161 | def shared_random_seed(): 162 | """ 163 | Returns: 164 | int: a random number that is the same across all workers. 165 | If workers need a shared RNG, they can use this shared seed to 166 | create one. 167 | 168 | All workers must call this function, otherwise it will deadlock. 169 | """ 170 | ints = np.random.randint(2**31) 171 | all_ints = all_gather(ints) 172 | return all_ints[0] 173 | 174 | 175 | def reduce_dict(input_dict, average=True): 176 | """ 177 | Reduce the values in the dictionary from all processes so that process with rank 178 | 0 has the reduced results. 179 | 180 | Args: 181 | input_dict (dict): inputs to be reduced. 
All the values must be scalar CUDA Tensor. 182 | average (bool): whether to do average or sum 183 | 184 | Returns: 185 | a dict with the same keys as input_dict, after reduction. 186 | """ 187 | world_size = get_world_size() 188 | if world_size < 2: 189 | return input_dict 190 | with torch.no_grad(): 191 | names = [] 192 | values = [] 193 | # sort the keys so that they are consistent across processes 194 | for k in sorted(input_dict.keys()): 195 | names.append(k) 196 | values.append(input_dict[k]) 197 | values = torch.stack(values, dim=0) 198 | dist.reduce(values, dst=0) 199 | if dist.get_rank() == 0 and average: 200 | # only main process gets accumulated, so only divide by 201 | # world_size in this case 202 | values /= world_size 203 | reduced_dict = {k: v for k, v in zip(names, values)} 204 | return reduced_dict 205 | -------------------------------------------------------------------------------- /pgdvs/engines/visualizer_pgdvs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import pathlib 4 | import tqdm 5 | import traceback 6 | import imageio_ffmpeg 7 | import PIL.Image 8 | import numpy as np 9 | from collections import OrderedDict 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | import torchvision 14 | 15 | from pgdvs.engines.abstract import default_collate_fn 16 | from pgdvs.engines.trainer_pgdvs import PGDVSTrainer 17 | from pgdvs.utils.rendering import images_to_video 18 | 19 | 20 | LOGGER = logging.getLogger(__name__) 21 | 22 | 23 | class PGDVSVisualizer(PGDVSTrainer): 24 | @torch.no_grad() 25 | def run(self): 26 | self.vis_model() 27 | 28 | @torch.no_grad() 29 | def vis_model(self, epoch=None, flag_full_vis=False, dataset_split="vis"): 30 | try: 31 | LOGGER.info("Check FFMPEG exists.") 32 | imageio_ffmpeg.get_ffmpeg_exe() 33 | flag_ffmpeg_exe = True 34 | except RuntimeError: 35 | traceback.print_exc() 36 | err = sys.exc_info()[0] 37 | LOGGER.info(err) 38 | LOGGER.info( 39 | f"FFMPEG is not properly set therefore we do not automatically generate videos." 
40 | ) 41 | flag_ffmpeg_exe = False 42 | 43 | sampler = torch.utils.data.distributed.DistributedSampler( 44 | self.datasets[dataset_split], 45 | num_replicas=self.world_size, 46 | rank=self.global_rank, 47 | shuffle=False, 48 | ) 49 | 50 | batch_size_per_proc = self.cfg.eval_batch_size // self.world_size 51 | real_eval_batch_size = batch_size_per_proc * self.world_size 52 | 53 | eval_dataloader = torch.utils.data.DataLoader( 54 | self.datasets[dataset_split], 55 | sampler=sampler, 56 | num_workers=self.cfg.n_dataloader_workers, 57 | batch_size=batch_size_per_proc, # self.cfg.eval_batch_size, 58 | collate_fn=default_collate_fn, 59 | ) 60 | 61 | if self.cfg.n_max_eval_data > 0: 62 | n_all_data = min( 63 | len(self.datasets[dataset_split]), self.cfg.n_max_eval_data 64 | ) 65 | else: 66 | n_all_data = len(self.datasets[dataset_split]) 67 | n_batches_per_epoch = int(np.ceil(n_all_data / real_eval_batch_size)) 68 | 69 | LOGGER.info( 70 | f"Rank {self.global_rank} | Local rank {self.local_rank} batch: #per_epoch {n_batches_per_epoch}, " 71 | f"#all_data {n_all_data}, #batch_size_per_proc {batch_size_per_proc}, #real_batch_size {real_eval_batch_size}" 72 | ) 73 | 74 | vis_dir_dict = {} 75 | 76 | for step, data in tqdm.tqdm( 77 | enumerate(eval_dataloader), 78 | total=n_batches_per_epoch, 79 | disable=not self.is_main_proc, 80 | leave=False, 81 | ): 82 | if step >= n_batches_per_epoch: 83 | break 84 | 85 | data_gpu = self._to_gpu_func(data, self.device) 86 | 87 | self.model.eval() 88 | 89 | n_batch, n_context, raw_h, raw_w, _ = data["rgb_src_temporal"].shape 90 | 91 | ret_dict = self._get_model_module(self.model).forward( 92 | data_gpu, 93 | render_cfg=self.engine_cfg.render_cfg, 94 | disable_tqdm=not self.verbose, 95 | for_debug=False, 96 | ) 97 | 98 | if False: 99 | self.debug_ret(ret_dict) 100 | 101 | rgb_pred_dict = OrderedDict( 102 | {"combined": ret_dict["combined_rgb"].clamp(0.0, 1.0)} 103 | ) 104 | 105 | for i_b in tqdm.tqdm(np.arange(n_batch), disable=True): 106 | if "split" in data["misc"][i_b]: 107 | tmp_split = data["misc"][i_b]["split"] 108 | else: 109 | tmp_split = "" 110 | tmp_scene_id = data["misc"][i_b]["scene_id"] 111 | tmp_tgt_idx = data["misc"][i_b]["tgt_idx"] 112 | 113 | tmp_fname = f"{tmp_tgt_idx:05d}" 114 | 115 | tmp_scene_vis_dir = ( 116 | pathlib.Path(self.VIS_DIR) / tmp_split / tmp_scene_id 117 | ) 118 | tmp_scene_vis_dir.mkdir(parents=True, exist_ok=True) 119 | 120 | vis_dir_dict[tmp_scene_id] = tmp_scene_vis_dir 121 | 122 | tmp_vis_combined_f = tmp_scene_vis_dir / f"{tmp_fname}_combined.png" 123 | torchvision.utils.save_image( 124 | rgb_pred_dict["combined"][i_b : (i_b + 1), ...], tmp_vis_combined_f 125 | ) 126 | 127 | if "static_coarse_rgb" in ret_dict: 128 | # save pure GNT results 129 | img_gnt = ( 130 | ret_dict["static_coarse_rgb"][i_b, ...] 
131 | .clamp(0.0, 1.0) 132 | .permute(1, 2, 0) 133 | .cpu() 134 | .numpy() 135 | * 255 136 | ).astype(np.uint8) 137 | PIL.Image.fromarray(img_gnt).save( 138 | tmp_scene_vis_dir / f"{tmp_fname}_gnt.png" 139 | ) 140 | 141 | if flag_ffmpeg_exe: 142 | try: 143 | LOGGER.info("Converting images to video.") 144 | 145 | for tmp_scene_id in vis_dir_dict: 146 | tmp_vis_dir = vis_dir_dict[tmp_scene_id] 147 | 148 | tmp_combined_save_f = ( 149 | tmp_scene_vis_dir.parent / f"{tmp_scene_id}_combined.mp4" 150 | ) 151 | self._compute_video( 152 | tmp_vis_dir, "*_combined.png", tmp_combined_save_f 153 | ) 154 | except: 155 | traceback.print_exc() 156 | err = sys.exc_info()[0] 157 | LOGGER.info(err) 158 | LOGGER.info(f"Video generation fails.") 159 | 160 | def _compute_video(self, data_dir, glob_str, save_f, disable_tqdm=False): 161 | raw_combined_vis_f_list = sorted( 162 | list(data_dir.glob(glob_str)) 163 | ) # e.g., combined_00062_rank_6.png 164 | tmp_dict = {} 165 | for tmp_f in raw_combined_vis_f_list: 166 | tmp_idx = int(tmp_f.stem.split("_")[1]) 167 | tmp_dict[tmp_idx] = np.array(PIL.Image.open(tmp_f)) 168 | 169 | tmp_all_imgs = [tmp_dict[_] for _ in sorted(list(tmp_dict.keys()))] 170 | images_to_video( 171 | tmp_all_imgs, 172 | save_f.parent, 173 | save_f.stem, 174 | fps=10, 175 | quality=9, 176 | disable_tqdm=disable_tqdm, 177 | ) 178 | -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/__init__.py: -------------------------------------------------------------------------------- 1 | # Copy from https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_exp/models/__init__.py 2 | 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import numpy as np 9 | 10 | # from skimage.measure import compare_ssim 11 | from skimage.metrics import structural_similarity as compare_ssim 12 | import torch 13 | from torch.autograd import Variable 14 | 15 | from . import dist_model 16 | 17 | 18 | class PerceptualLoss(torch.nn.Module): 19 | def __init__( 20 | self, 21 | model="net-lin", 22 | net="alex", 23 | colorspace="rgb", 24 | spatial=False, 25 | use_gpu=True, 26 | gpu_ids=[0], 27 | version="0.1", 28 | ): # VGG using our perceptually-learned weights (LPIPS metric) 29 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 30 | super(PerceptualLoss, self).__init__() 31 | print("Setting up Perceptual loss...") 32 | self.use_gpu = use_gpu 33 | self.spatial = spatial 34 | self.gpu_ids = gpu_ids 35 | self.model = dist_model.DistModel() 36 | self.model.initialize( 37 | model=model, 38 | net=net, 39 | use_gpu=use_gpu, 40 | colorspace=colorspace, 41 | spatial=self.spatial, 42 | gpu_ids=gpu_ids, 43 | version=version, 44 | ) 45 | print("...[%s] initialized" % self.model.name()) 46 | print("...Done") 47 | 48 | def forward(self, pred, target, mask=None, normalize=False): 49 | """ 50 | Pred and target are Variables. 
51 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 52 | If normalize is False, assumes the images are already between [-1,+1] 53 | 54 | Inputs pred and target are Nx3xHxW 55 | Output pytorch Variable N long 56 | """ 57 | 58 | if normalize: 59 | target = 2 * target - 1 60 | pred = 2 * pred - 1 61 | 62 | return self.model.forward(target, pred, mask=mask) 63 | 64 | 65 | def normalize_tensor(in_feat, eps=1e-10): 66 | norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) 67 | return in_feat / (norm_factor + eps) 68 | 69 | 70 | def l2(p0, p1, range=255.0): 71 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 72 | 73 | 74 | def psnr(p0, p1, peak=255.0): 75 | return 10 * np.log10(peak**2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 76 | 77 | 78 | def dssim(p0, p1, range=255.0): 79 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 80 | 81 | 82 | def rgb2lab(in_img, mean_cent=False): 83 | from skimage import color 84 | 85 | img_lab = color.rgb2lab(in_img) 86 | if mean_cent: 87 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 88 | return img_lab 89 | 90 | 91 | def tensor2np(tensor_obj): 92 | # change dimension of a tensor object into a numpy array 93 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 94 | 95 | 96 | def np2tensor(np_obj): 97 | # change dimenion of np array into tensor array 98 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 99 | 100 | 101 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 102 | # image tensor to lab tensor 103 | from skimage import color 104 | 105 | img = tensor2im(image_tensor) 106 | img_lab = color.rgb2lab(img) 107 | if mc_only: 108 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 109 | if to_norm and not mc_only: 110 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 111 | img_lab = img_lab / 100.0 112 | 113 | return np2tensor(img_lab) 114 | 115 | 116 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 117 | from skimage import color 118 | import warnings 119 | 120 | warnings.filterwarnings("ignore") 121 | 122 | lab = tensor2np(lab_tensor) * 100.0 123 | lab[:, :, 0] = lab[:, :, 0] + 50 124 | 125 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 126 | if return_inbnd: 127 | # convert back to lab, see if we match 128 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 129 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 130 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 131 | return (im2tensor(rgb_back), mask) 132 | else: 133 | return im2tensor(rgb_back) 134 | 135 | 136 | def rgb2lab(input): 137 | from skimage import color 138 | 139 | return color.rgb2lab(input / 255.0) 140 | 141 | 142 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 143 | image_numpy = image_tensor[0].cpu().float().numpy() 144 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 145 | return image_numpy.astype(imtype) 146 | 147 | 148 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 149 | return torch.Tensor( 150 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 151 | ) 152 | 153 | 154 | def tensor2vec(vector_tensor): 155 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 156 | 157 | 158 | def voc_ap(rec, prec, use_07_metric=False): 159 | """ap = voc_ap(rec, prec, [use_07_metric]) 160 | Compute VOC AP given precision and recall. 161 | If use_07_metric is true, uses the 162 | VOC 07 11 point method (default:False). 
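    A quick worked example (added for clarity, not in the original NSFF copy):
    with rec = [0.5, 1.0] and prec = [1.0, 0.5], the exact (non-VOC07) branch
    integrates the precision envelope and yields ap = 0.5 * 1.0 + 0.5 * 0.5 = 0.75.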
163 | """ 164 | if use_07_metric: 165 | # 11 point metric 166 | ap = 0.0 167 | for t in np.arange(0.0, 1.1, 0.1): 168 | if np.sum(rec >= t) == 0: 169 | p = 0 170 | else: 171 | p = np.max(prec[rec >= t]) 172 | ap = ap + p / 11.0 173 | else: 174 | # correct AP calculation 175 | # first append sentinel values at the end 176 | mrec = np.concatenate(([0.0], rec, [1.0])) 177 | mpre = np.concatenate(([0.0], prec, [0.0])) 178 | 179 | # compute the precision envelope 180 | for i in range(mpre.size - 1, 0, -1): 181 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 182 | 183 | # to calculate area under PR curve, look for points 184 | # where X axis (recall) changes value 185 | i = np.where(mrec[1:] != mrec[:-1])[0] 186 | 187 | # and sum (\Delta recall) * prec 188 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 189 | return ap 190 | 191 | 192 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 193 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 194 | image_numpy = image_tensor[0].cpu().float().numpy() 195 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 196 | return image_numpy.astype(imtype) 197 | 198 | 199 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 200 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 201 | return torch.Tensor( 202 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 203 | ) 204 | -------------------------------------------------------------------------------- /pgdvs/utils/nsff_lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | # Copy from https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_exp/models/pretrained_networks.py 2 | 3 | 4 | from collections import namedtuple 5 | import torch 6 | from torchvision import models as tv 7 | 8 | 9 | class squeezenet(torch.nn.Module): 10 | def __init__(self, requires_grad=False, pretrained=True): 11 | super(squeezenet, self).__init__() 12 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 13 | self.slice1 = torch.nn.Sequential() 14 | self.slice2 = torch.nn.Sequential() 15 | self.slice3 = torch.nn.Sequential() 16 | self.slice4 = torch.nn.Sequential() 17 | self.slice5 = torch.nn.Sequential() 18 | self.slice6 = torch.nn.Sequential() 19 | self.slice7 = torch.nn.Sequential() 20 | self.N_slices = 7 21 | for x in range(2): 22 | self.slice1.add_module(str(x), pretrained_features[x]) 23 | for x in range(2, 5): 24 | self.slice2.add_module(str(x), pretrained_features[x]) 25 | for x in range(5, 8): 26 | self.slice3.add_module(str(x), pretrained_features[x]) 27 | for x in range(8, 10): 28 | self.slice4.add_module(str(x), pretrained_features[x]) 29 | for x in range(10, 11): 30 | self.slice5.add_module(str(x), pretrained_features[x]) 31 | for x in range(11, 12): 32 | self.slice6.add_module(str(x), pretrained_features[x]) 33 | for x in range(12, 13): 34 | self.slice7.add_module(str(x), pretrained_features[x]) 35 | if not requires_grad: 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, X): 40 | h = self.slice1(X) 41 | h_relu1 = h 42 | h = self.slice2(h) 43 | h_relu2 = h 44 | h = self.slice3(h) 45 | h_relu3 = h 46 | h = self.slice4(h) 47 | h_relu4 = h 48 | h = self.slice5(h) 49 | h_relu5 = h 50 | h = self.slice6(h) 51 | h_relu6 = h 52 | h = self.slice7(h) 53 | h_relu7 = h 54 | vgg_outputs = namedtuple( 55 | "SqueezeOutputs", 56 | ["relu1", "relu2", 
"relu3", "relu4", "relu5", "relu6", "relu7"], 57 | ) 58 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 59 | 60 | return out 61 | 62 | 63 | class alexnet(torch.nn.Module): 64 | def __init__(self, requires_grad=False, pretrained=True): 65 | super(alexnet, self).__init__() 66 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 67 | self.slice1 = torch.nn.Sequential() 68 | self.slice2 = torch.nn.Sequential() 69 | self.slice3 = torch.nn.Sequential() 70 | self.slice4 = torch.nn.Sequential() 71 | self.slice5 = torch.nn.Sequential() 72 | self.N_slices = 5 73 | for x in range(2): 74 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(2, 5): 76 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 77 | for x in range(5, 8): 78 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 79 | for x in range(8, 10): 80 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 81 | for x in range(10, 12): 82 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 83 | if not requires_grad: 84 | for param in self.parameters(): 85 | param.requires_grad = False 86 | 87 | def forward(self, X): 88 | h = self.slice1(X) 89 | h_relu1 = h 90 | h = self.slice2(h) 91 | h_relu2 = h 92 | h = self.slice3(h) 93 | h_relu3 = h 94 | h = self.slice4(h) 95 | h_relu4 = h 96 | h = self.slice5(h) 97 | h_relu5 = h 98 | alexnet_outputs = namedtuple( 99 | "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] 100 | ) 101 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 102 | 103 | return out 104 | 105 | 106 | class vgg16(torch.nn.Module): 107 | def __init__(self, requires_grad=False, pretrained=True): 108 | super(vgg16, self).__init__() 109 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 110 | self.slice1 = torch.nn.Sequential() 111 | self.slice2 = torch.nn.Sequential() 112 | self.slice3 = torch.nn.Sequential() 113 | self.slice4 = torch.nn.Sequential() 114 | self.slice5 = torch.nn.Sequential() 115 | self.N_slices = 5 116 | for x in range(4): 117 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 118 | for x in range(4, 9): 119 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 120 | for x in range(9, 16): 121 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 122 | for x in range(16, 23): 123 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 124 | for x in range(23, 30): 125 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 126 | if not requires_grad: 127 | for param in self.parameters(): 128 | param.requires_grad = False 129 | 130 | def forward(self, X): 131 | h = self.slice1(X) 132 | h_relu1_2 = h 133 | h = self.slice2(h) 134 | h_relu2_2 = h 135 | h = self.slice3(h) 136 | h_relu3_3 = h 137 | h = self.slice4(h) 138 | h_relu4_3 = h 139 | h = self.slice5(h) 140 | h_relu5_3 = h 141 | vgg_outputs = namedtuple( 142 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 143 | ) 144 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 145 | 146 | return out 147 | 148 | 149 | class resnet(torch.nn.Module): 150 | def __init__(self, requires_grad=False, pretrained=True, num=18): 151 | super(resnet, self).__init__() 152 | if num == 18: 153 | self.net = tv.resnet18(pretrained=pretrained) 154 | elif num == 34: 155 | self.net = tv.resnet34(pretrained=pretrained) 156 | elif num == 50: 157 | self.net = tv.resnet50(pretrained=pretrained) 158 | elif num == 101: 159 | self.net = 
tv.resnet101(pretrained=pretrained) 160 | elif num == 152: 161 | self.net = tv.resnet152(pretrained=pretrained) 162 | self.N_slices = 5 163 | 164 | self.conv1 = self.net.conv1 165 | self.bn1 = self.net.bn1 166 | self.relu = self.net.relu 167 | self.maxpool = self.net.maxpool 168 | self.layer1 = self.net.layer1 169 | self.layer2 = self.net.layer2 170 | self.layer3 = self.net.layer3 171 | self.layer4 = self.net.layer4 172 | 173 | def forward(self, X): 174 | h = self.conv1(X) 175 | h = self.bn1(h) 176 | h = self.relu(h) 177 | h_relu1 = h 178 | h = self.maxpool(h) 179 | h = self.layer1(h) 180 | h_conv2 = h 181 | h = self.layer2(h) 182 | h_conv3 = h 183 | h = self.layer3(h) 184 | h_conv4 = h 185 | h = self.layer4(h) 186 | h_conv5 = h 187 | 188 | outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) 189 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 190 | 191 | return out 192 | -------------------------------------------------------------------------------- /pgdvs/models/cotracker/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from tqdm import tqdm 11 | from pgdvs.models.cotracker.models.core.cotracker.cotracker import ( 12 | get_points_on_a_grid, 13 | ) 14 | from pgdvs.models.cotracker.models.core.model_utils import smart_cat 15 | from pgdvs.models.cotracker.models.build_cotracker import ( 16 | build_cotracker, 17 | ) 18 | 19 | 20 | class CoTrackerPredictor(torch.nn.Module): 21 | def __init__( 22 | self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" 23 | ): 24 | super().__init__() 25 | self.interp_shape = (384, 512) 26 | self.support_grid_size = 6 27 | model = build_cotracker(checkpoint) 28 | 29 | self.model = model 30 | self.model.eval() 31 | 32 | @torch.no_grad() 33 | def forward( 34 | self, 35 | video, # (1, T, 3, H, W) 36 | # input prompt types: 37 | # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. 38 | # *backward_tracking=True* will compute tracks in both directions. 39 | # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates. 40 | # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. 41 | # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. 
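        # A minimal usage sketch (illustrative only, added for clarity; the
        # `ckpts/` prefix and the tensor sizes are placeholders, not values
        # taken from this file):
        #   predictor = CoTrackerPredictor(checkpoint="ckpts/cotracker_stride_4_wind_8.pth")
        #   video = torch.zeros(1, 8, 3, 384, 512)            # (1, T, 3, H, W)
        #   queries = torch.tensor([[[0.0, 100.0, 150.0]]])   # (1, N, 3) as (t, x, y)
        #   tracks, visibilities = predictor(video, queries=queries)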
42 | queries: torch.Tensor = None, 43 | segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W) 44 | grid_size: int = 0, 45 | grid_query_frame: int = 0, # only for dense and regular grid tracks 46 | backward_tracking: bool = False, 47 | ): 48 | if queries is None and grid_size == 0: 49 | tracks, visibilities = self._compute_dense_tracks( 50 | video, 51 | grid_query_frame=grid_query_frame, 52 | backward_tracking=backward_tracking, 53 | ) 54 | else: 55 | tracks, visibilities = self._compute_sparse_tracks( 56 | video, 57 | queries, 58 | segm_mask, 59 | grid_size, 60 | add_support_grid=(grid_size == 0 or segm_mask is not None), 61 | grid_query_frame=grid_query_frame, 62 | backward_tracking=backward_tracking, 63 | ) 64 | 65 | return tracks, visibilities 66 | 67 | def _compute_dense_tracks( 68 | self, video, grid_query_frame, grid_size=30, backward_tracking=False 69 | ): 70 | *_, H, W = video.shape 71 | grid_step = W // grid_size 72 | grid_width = W // grid_step 73 | grid_height = H // grid_step 74 | tracks = visibilities = None 75 | grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) 76 | grid_pts[0, :, 0] = grid_query_frame 77 | for offset in tqdm(range(grid_step * grid_step), disable=True): 78 | ox = offset % grid_step 79 | oy = offset // grid_step 80 | grid_pts[0, :, 1] = ( 81 | torch.arange(grid_width).repeat(grid_height) * grid_step + ox 82 | ) 83 | grid_pts[0, :, 2] = ( 84 | torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy 85 | ) 86 | tracks_step, visibilities_step = self._compute_sparse_tracks( 87 | video=video, 88 | queries=grid_pts, 89 | backward_tracking=backward_tracking, 90 | ) 91 | tracks = smart_cat(tracks, tracks_step, dim=2) 92 | visibilities = smart_cat(visibilities, visibilities_step, dim=2) 93 | 94 | return tracks, visibilities 95 | 96 | def _compute_sparse_tracks( 97 | self, 98 | video, 99 | queries, 100 | segm_mask=None, 101 | grid_size=0, 102 | add_support_grid=False, 103 | grid_query_frame=0, 104 | backward_tracking=False, 105 | ): 106 | B, T, C, H, W = video.shape 107 | assert B == 1 108 | 109 | video = video.reshape(B * T, C, H, W) 110 | video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear") 111 | video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) 112 | 113 | if queries is not None: 114 | queries = queries.clone() 115 | B, N, D = queries.shape 116 | assert D == 3 117 | queries[:, :, 1] *= self.interp_shape[1] / W 118 | queries[:, :, 2] *= self.interp_shape[0] / H 119 | elif grid_size > 0: 120 | grid_pts = get_points_on_a_grid( 121 | grid_size, self.interp_shape, device=video.device 122 | ) 123 | if segm_mask is not None: 124 | segm_mask = F.interpolate( 125 | segm_mask, tuple(self.interp_shape), mode="nearest" 126 | ) 127 | point_mask = segm_mask[0, 0][ 128 | (grid_pts[0, :, 1]).round().long().cpu(), 129 | (grid_pts[0, :, 0]).round().long().cpu(), 130 | ].bool() 131 | grid_pts = grid_pts[:, point_mask] 132 | 133 | queries = torch.cat( 134 | [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], 135 | dim=2, 136 | ) 137 | 138 | if add_support_grid: 139 | grid_pts = get_points_on_a_grid( 140 | self.support_grid_size, self.interp_shape, device=video.device 141 | ) 142 | grid_pts = torch.cat( 143 | [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 144 | ) 145 | queries = torch.cat([queries, grid_pts], dim=1) 146 | 147 | tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6) 148 | 149 | if backward_tracking: 150 | tracks, 
visibilities = self._compute_backward_tracks( 151 | video, queries, tracks, visibilities 152 | ) 153 | if add_support_grid: 154 | queries[:, -self.support_grid_size**2 :, 0] = T - 1 155 | if add_support_grid: 156 | tracks = tracks[:, :, : -self.support_grid_size**2] 157 | visibilities = visibilities[:, :, : -self.support_grid_size**2] 158 | thr = 0.9 159 | visibilities = visibilities > thr 160 | tracks[:, :, :, 0] *= W / float(self.interp_shape[1]) 161 | tracks[:, :, :, 1] *= H / float(self.interp_shape[0]) 162 | return tracks, visibilities 163 | 164 | def _compute_backward_tracks(self, video, queries, tracks, visibilities): 165 | inv_video = video.flip(1).clone() 166 | inv_queries = queries.clone() 167 | inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 168 | 169 | inv_tracks, __, inv_visibilities, __ = self.model( 170 | rgbs=inv_video, queries=inv_queries, iters=6 171 | ) 172 | 173 | inv_tracks = inv_tracks.flip(1) 174 | inv_visibilities = inv_visibilities.flip(1) 175 | 176 | mask = tracks == 0 177 | 178 | tracks[mask] = inv_tracks[mask] 179 | visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] 180 | return tracks, visibilities 181 | -------------------------------------------------------------------------------- /pgdvs/models/tapnet/interface.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/deepmind/tapnet/blob/4ac6b2acd0aed36c0762f4247de9e8630340e2e0/colabs/tapir_demo.ipynb 2 | 3 | import tree 4 | import functools 5 | import numpy as np 6 | from typing import Any 7 | 8 | import torch 9 | 10 | import jax 11 | import jax.dlpack 12 | import haiku as hk 13 | from jax.lib import xla_bridge 14 | 15 | from pgdvs.models.tapnet import tapir_model 16 | from pgdvs.models.tapnet.utils import transforms 17 | from pgdvs.utils.rendering import modify_rgb_range 18 | 19 | 20 | def jax2torch(x): 21 | return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x)) 22 | 23 | 24 | def torch2jax(x): 25 | return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x.contiguous())) 26 | 27 | 28 | class TAPNetInterface: 29 | def __init__( 30 | self, 31 | ckpt_path, 32 | ori_rgb_range="0_1", 33 | query_chunk_size=4096, 34 | flag_keep_raw_res=False, 35 | local_rank=0, 36 | ): 37 | self.ori_rgb_range = ori_rgb_range 38 | self.query_chunk_size = query_chunk_size 39 | 40 | self.resize_h = 256 41 | self.resize_w = 256 42 | 43 | ckpt_state = np.load(ckpt_path, allow_pickle=True).item() 44 | self.params, self.state = ckpt_state["params"], ckpt_state["state"] 45 | 46 | build_model_func = functools.partial( 47 | build_model, query_chunk_size=self.query_chunk_size 48 | ) 49 | 50 | model = hk.transform_with_state(build_model_func) 51 | 52 | if xla_bridge.get_backend().platform == "gpu": 53 | device = jax.devices("gpu")[local_rank] 54 | else: 55 | device = jax.devices("cpu") 56 | 57 | self.model_apply = jax.jit(model.apply, device=device) 58 | 59 | self.flag_keep_raw_res = flag_keep_raw_res 60 | 61 | self.network_mult = 8 62 | 63 | def __call__(self, *, frames, query_points): 64 | # frames: [N, H, W, 3] 65 | # query_points: [#pt, 3], 3 for [time, row, col] 66 | 67 | _, orig_h, orig_w, _ = frames.shape 68 | 69 | if self.flag_keep_raw_res: 70 | if (orig_h % self.network_mult != 0) or (orig_w % self.network_mult != 0): 71 | resize_h = int(np.ceil(orig_h / self.network_mult) * self.network_mult) 72 | resize_w = int(np.ceil(orig_w / self.network_mult) * self.network_mult) 73 | else: 74 | resize_h = orig_h 75 | resize_w = 
orig_w 76 | else: 77 | resize_h = self.resize_h 78 | resize_w = self.resize_w 79 | 80 | if (orig_h != resize_h) or (orig_w != resize_w): 81 | frames = ( 82 | torch.nn.functional.interpolate( 83 | frames.permute(0, 3, 1, 2), 84 | size=(resize_h, resize_w), 85 | mode="bicubic", 86 | antialias=True, 87 | ) 88 | .permute(0, 2, 3, 1) 89 | .contiguous() 90 | ) 91 | query_cols_rows = transforms.convert_grid_coordinates_torch( 92 | query_points[:, [2, 1]], 93 | (orig_w, orig_h), 94 | (resize_w, resize_h), 95 | ) # transform accepts [u, v] or [col, row] format 96 | query_points[:, 1] = query_cols_rows[:, 1] # for rows 97 | query_points[:, 2] = query_cols_rows[:, 0] # for cols 98 | 99 | frames = modify_rgb_range( 100 | frames, 101 | src_range=self.ori_rgb_range, 102 | tgt_range="0_255", 103 | check_range=False, 104 | enforce_range=True, 105 | ) 106 | 107 | frames = torch2jax(frames) 108 | query_points = torch2jax(query_points) 109 | 110 | tracks, visibles = inference( 111 | self.model_apply, self.params, self.state, frames, query_points 112 | ) # tracks: [#pt, #frame, 2], float32; visibles: [#pt, #frame], bool 113 | 114 | if (orig_h != resize_h) or (orig_w != resize_w): 115 | tracks = transforms.convert_grid_coordinates_torch( 116 | tracks, (resize_w, resize_h), (orig_w, orig_h) 117 | ) 118 | 119 | return tracks, visibles 120 | 121 | 122 | def build_model(frames, query_points, query_chunk_size=64): 123 | """Compute point tracks and occlusions given frames and query points.""" 124 | model = tapir_model.TAPIR( 125 | bilinear_interp_with_depthwise_conv=False, pyramid_level=0 126 | ) 127 | outputs = model( 128 | video=frames, 129 | is_training=False, 130 | query_points=query_points, 131 | query_chunk_size=query_chunk_size, 132 | ) 133 | return outputs 134 | 135 | 136 | def preprocess_frames(frames): 137 | """Preprocess frames to model inputs. 138 | 139 | Args: 140 | frames: [num_frames, height, width, 3], [0, 255], np.uint8 141 | 142 | Returns: 143 | frames: [num_frames, height, width, 3], [-1, 1], np.float32 144 | """ 145 | frames = frames.astype(np.float32) 146 | frames = frames / 255 * 2 - 1 147 | return frames 148 | 149 | 150 | def postprocess_occlusions(occlusions, expected_dist): 151 | """Postprocess occlusions to boolean visible flag. 152 | 153 | Args: 154 | occlusions: [num_points, num_frames], [-inf, inf], np.float32 155 | expected_dist: [num_points, num_frames], [-inf, inf], np.float32 156 | 157 | Returns: 158 | visibles: [num_points, num_frames], bool 159 | """ 160 | visibles = (1 - jax.nn.sigmoid(occlusions)) * ( 161 | 1 - jax.nn.sigmoid(expected_dist) 162 | ) > 0.5 163 | return visibles 164 | 165 | 166 | def postprocess_occlusions_torch(occlusions, expected_dist): 167 | """Postprocess occlusions to boolean visible flag. 168 | 169 | Args: 170 | occlusions: [num_points, num_frames], [-inf, inf], np.float32 171 | expected_dist: [num_points, num_frames], [-inf, inf], np.float32 172 | 173 | Returns: 174 | visibles: [num_points, num_frames], bool 175 | """ 176 | visibles = (1 - torch.nn.functional.sigmoid(occlusions)) * ( 177 | 1 - torch.nn.functional.sigmoid(expected_dist) 178 | ) > 0.5 179 | return visibles 180 | 181 | 182 | def inference(model_apply, params, state, frames, query_points): 183 | """Inference on one video. 
184 | 
185 |     Args:
186 |       frames: [num_frames, height, width, 3], [0, 255], np.uint8
187 |       query_points: [num_points, 3], [0, num_frames/height/width], [t, y, x]
188 | 
189 |     Returns:
190 |       tracks: [num_points, num_frames, 2], float32, point coordinates in [x, y] order
191 |       visibles: [num_points, num_frames], bool
192 |     """
193 |     # Preprocess video to match model inputs format
194 |     frames = preprocess_frames(frames)
195 |     num_frames, height, width = frames.shape[0:3]
196 |     query_points = query_points.astype(np.float32)
197 |     frames, query_points = frames[None], query_points[None]  # Add batch dimension
198 | 
199 |     # Model inference
200 |     rng = jax.random.PRNGKey(42)
201 |     outputs, _ = model_apply(params, state, rng, frames, query_points)
202 | 
203 |     if False:
204 |         outputs = tree.map_structure(lambda x: np.array(x[0]), outputs)
205 |     else:
206 |         for k in ["tracks", "occlusion", "expected_dist"]:
207 |             outputs[k] = jax2torch(outputs[k])[0, ...]
208 | 
209 |     tracks, occlusions, expected_dist = (
210 |         outputs["tracks"],
211 |         outputs["occlusion"],
212 |         outputs["expected_dist"],
213 |     )
214 | 
215 |     # Binarize occlusions
216 |     if False:
217 |         visibles = postprocess_occlusions(occlusions, expected_dist)
218 |     else:
219 |         visibles = postprocess_occlusions_torch(occlusions, expected_dist)
220 |     return tracks, visibles
221 | 
-------------------------------------------------------------------------------- /pgdvs/utils/dycheck/camera.py: --------------------------------------------------------------------------------
1 | import json
2 | import copy
3 | import PIL.Image
4 | import numpy as np
5 | import os.path as osp
6 | from typing import Optional, Sequence, Tuple, Union, Dict
7 | 
8 | 
9 | class DyCheckCamera(object):
10 |     """A generic camera class that potentially distorts rays.
11 | 
12 |     This camera class uses the OpenCV camera model, where the local-to-world
13 |     transform assumes (right, down, forward).
14 | 
15 |     Attributes:
16 |         orientation (np.ndarray): The orientation of the camera of shape (3, 3)
17 |             that maps the world coordinates to local coordinates.
18 |         position (np.ndarray): The position of the camera of shape (3,) in the
19 |             world coordinates.
20 |         focal_length (Union[np.ndarray, float]): The focal length of the camera.
21 |         principal_point (np.ndarray): The principal point of the camera of
22 |             shape (2,).
23 |         image_size (np.ndarray): The image size (W, H).
24 |         skew (Union[np.ndarray, float]): The skewness of the camera.
25 |         pixel_aspect_ratio (Union[np.ndarray, float]): The pixel aspect ratio.
26 |         radial_distortion (Optional[np.ndarray]): The radial distortion of the
27 |             camera of shape (3,).
28 |         tangential_distortion (Optional[np.ndarray]): The tangential distortion
29 |             of the camera of shape (2,).
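    For reference (this restates the `intrin` property defined below in this
    class), the corresponding 3x3 intrinsic matrix is
        [[focal_length, skew,                              principal_point[0]],
         [0,            focal_length * pixel_aspect_ratio, principal_point[1]],
         [0,            0,                                 1                 ]].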
30 | 31 | Modified from https://github.com/KAIR-BAIR/dycheck/blob/ddf77a4e006fdbc5aed28e0859c216da0de5aff5/dycheck/geometry/camera.py#L245 32 | """ 33 | 34 | def __init__( 35 | self, 36 | orientation: np.ndarray, 37 | position: np.ndarray, 38 | focal_length: Union[np.ndarray, float], 39 | principal_point: np.ndarray, 40 | image_size: np.ndarray, 41 | skew: Union[np.ndarray, float] = 0.0, 42 | pixel_aspect_ratio: Union[np.ndarray, float] = 1.0, 43 | radial_distortion: Optional[np.ndarray] = None, 44 | tangential_distortion: Optional[np.ndarray] = None, 45 | *, 46 | use_center: bool = True, 47 | use_projective_depth: bool = True, 48 | ): 49 | """Constructor for camera class.""" 50 | if radial_distortion is None: 51 | radial_distortion = np.array([0, 0, 0], np.float32) 52 | if tangential_distortion is None: 53 | tangential_distortion = np.array([0, 0], np.float32) 54 | 55 | self.orientation = np.array(orientation, np.float32) 56 | self.position = np.array(position, np.float32) 57 | self.focal_length = np.array(focal_length, np.float32) 58 | self.principal_point = np.array(principal_point, np.float32) 59 | self.image_size = np.array(image_size, np.uint32) 60 | 61 | # Distortion parameters. 62 | self.skew = np.array(skew, np.float32) 63 | self.pixel_aspect_ratio = np.array(pixel_aspect_ratio, np.float32) 64 | self.radial_distortion = np.array(radial_distortion, np.float32) 65 | self.tangential_distortion = np.array(tangential_distortion, np.float32) 66 | 67 | self.use_center = use_center 68 | self.use_projective_depth = use_projective_depth 69 | 70 | @classmethod 71 | def fromjson(cls, filename): 72 | with open(filename) as f: 73 | camera_dict = json.load(f) 74 | 75 | # Fix old camera JSON. 76 | if "tangential" in camera_dict: 77 | camera_dict["tangential_distortion"] = camera_dict["tangential"] 78 | 79 | return cls( 80 | orientation=np.asarray(camera_dict["orientation"]), 81 | position=np.asarray(camera_dict["position"]), 82 | focal_length=camera_dict["focal_length"], 83 | principal_point=np.asarray(camera_dict["principal_point"]), 84 | image_size=np.asarray(camera_dict["image_size"]), 85 | skew=camera_dict["skew"], 86 | pixel_aspect_ratio=camera_dict["pixel_aspect_ratio"], 87 | radial_distortion=np.asarray(camera_dict["radial_distortion"]), 88 | tangential_distortion=np.asarray(camera_dict["tangential_distortion"]), 89 | ) 90 | 91 | def rescale_image_domain(self, scale: float) -> "DyCheckCamera": 92 | """Rescale the image domain of the camera.""" 93 | if scale <= 0: 94 | raise ValueError("scale needs to be positive.") 95 | 96 | camera = self.copy() 97 | camera.focal_length *= scale 98 | camera.principal_point *= scale 99 | camera.image_size = np.array( 100 | ( 101 | int(round(self.image_size[0] * scale)), 102 | int(round(self.image_size[1] * scale)), 103 | ) 104 | ) 105 | return camera 106 | 107 | def translate(self, transl: np.ndarray) -> "DyCheckCamera": 108 | """Translate the camera.""" 109 | camera = self.copy() 110 | camera.position += transl 111 | return camera 112 | 113 | def rescale(self, scale: float) -> "DyCheckCamera": 114 | """Rescale the camera.""" 115 | if scale <= 0: 116 | raise ValueError("scale needs to be positive.") 117 | 118 | camera = self.copy() 119 | camera.position *= scale 120 | return camera 121 | 122 | @property 123 | def has_tangential_distortion(self): 124 | return any(self.tangential_distortion != 0) 125 | 126 | @property 127 | def has_radial_distortion(self): 128 | return any(self.radial_distortion != 0) 129 | 130 | @property 131 | def distortion(self): 132 | 
"""Camera distortion parameters compatible with OpenCV. 133 | 134 | Reference: 135 | https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html 136 | """ 137 | return np.concatenate( 138 | [ 139 | self.radial_distortion[:2], 140 | self.tangential_distortion, 141 | self.radial_distortion[-1:], 142 | ] 143 | ) 144 | 145 | def copy(self) -> "DyCheckCamera": 146 | return copy.deepcopy(self) 147 | 148 | @property 149 | def scale_factor_x(self): 150 | return self.focal_length 151 | 152 | @property 153 | def scale_factor_y(self): 154 | return self.focal_length * self.pixel_aspect_ratio 155 | 156 | @property 157 | def principal_point_x(self): 158 | return self.principal_point[0] 159 | 160 | @property 161 | def principal_point_y(self): 162 | return self.principal_point[1] 163 | 164 | @property 165 | def translation(self): 166 | return -self.orientation @ self.position 167 | 168 | @property 169 | def optical_axis(self): 170 | return self.orientation[2, :] 171 | 172 | @property 173 | def up_axis(self): 174 | return -self.orientation[1, :] 175 | 176 | @property 177 | def intrin(self): 178 | return np.array( 179 | [ 180 | [self.scale_factor_x, self.skew, self.principal_point_x], 181 | [0, self.scale_factor_y, self.principal_point_y], 182 | [0, 0, 1], 183 | ], 184 | np.float32, 185 | ) 186 | 187 | @property 188 | def extrin(self): 189 | # 4x4 world-to-camera transform. 190 | return np.concatenate( 191 | [ 192 | np.concatenate( 193 | [self.orientation, self.translation[..., None]], axis=-1 194 | ), 195 | np.array([[0, 0, 0, 1]], np.float32), 196 | ], 197 | axis=-2, 198 | ) 199 | 200 | def asdict(self): 201 | return { 202 | "orientation": self.orientation, 203 | "position": self.position, 204 | "focal_length": self.focal_length, 205 | "principal_point": self.principal_point, 206 | "image_size": self.image_size, 207 | "skew": self.skew, 208 | "pixel_aspect_ratio": self.pixel_aspect_ratio, 209 | "radial_distortion": self.radial_distortion, 210 | "tangential_distortion": self.tangential_distortion, 211 | } 212 | -------------------------------------------------------------------------------- /pgdvs/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import trimesh 4 | import numpy as np 5 | 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | from matplotlib import cm 9 | from matplotlib.figure import Figure 10 | from matplotlib.backends.backend_agg import FigureCanvasAgg 11 | 12 | import torch 13 | 14 | HUGE_NUMBER = 1e10 15 | TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision 16 | 17 | 18 | def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2): 19 | """ 20 | :param w: pixels 21 | :param h: pixels 22 | :param vmin: min value 23 | :param vmax: max value 24 | :param cmap_name: 25 | :param label 26 | :return: 27 | """ 28 | fig = Figure(figsize=(2, 8), dpi=100) 29 | fig.subplots_adjust(right=1.5) 30 | canvas = FigureCanvasAgg(fig) 31 | 32 | # Do some plotting. 
33 | ax = fig.add_subplot(111) 34 | cmap = cm.get_cmap(cmap_name) 35 | norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) 36 | 37 | tick_cnt = 6 38 | tick_loc = np.linspace(vmin, vmax, tick_cnt) 39 | cb1 = matplotlib.colorbar.ColorbarBase( 40 | ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical" 41 | ) 42 | 43 | tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc] 44 | if cbar_precision == 0: 45 | tick_label = [x[:-2] for x in tick_label] 46 | 47 | cb1.set_ticklabels(tick_label) 48 | 49 | cb1.ax.tick_params(labelsize=18, rotation=0) 50 | 51 | if label is not None: 52 | cb1.set_label(label) 53 | 54 | fig.tight_layout() 55 | 56 | canvas.draw() 57 | s, (width, height) = canvas.print_to_buffer() 58 | 59 | im = np.frombuffer(s, np.uint8).reshape((height, width, 4)) 60 | 61 | im = im[:, :, :3].astype(np.float32) / 255.0 62 | if h != im.shape[0]: 63 | w = int(im.shape[1] / im.shape[0] * h) 64 | im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA) 65 | 66 | return im 67 | 68 | 69 | def colorize_np( 70 | x, 71 | cmap_name="jet", 72 | mask=None, 73 | range=None, 74 | append_cbar=False, 75 | cbar_in_image=False, 76 | cbar_precision=2, 77 | ): 78 | """ 79 | turn a grayscale image into a color image 80 | :param x: input grayscale, [H, W] 81 | :param cmap_name: the colorization method 82 | :param mask: the mask image, [H, W] 83 | :param range: the range for scaling, automatic if None, [min, max] 84 | :param append_cbar: if append the color bar 85 | :param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image 86 | :return: colorized image, [H, W] 87 | """ 88 | if range is not None: 89 | vmin, vmax = range 90 | elif mask is not None: 91 | # vmin, vmax = np.percentile(x[mask], (2, 100)) 92 | vmin = np.min(x[mask][np.nonzero(x[mask])]) 93 | vmax = np.max(x[mask]) 94 | # vmin = vmin - np.abs(vmin) * 0.01 95 | x[np.logical_not(mask)] = vmin 96 | # print(vmin, vmax) 97 | else: 98 | vmin, vmax = np.percentile(x, (1, 100)) 99 | vmax += TINY_NUMBER 100 | 101 | x = np.clip(x, vmin, vmax) 102 | x = (x - vmin) / (vmax - vmin) 103 | # x = np.clip(x, 0., 1.) 
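    # (added note: at this point `x` has been clipped to [vmin, vmax] and
    # rescaled to [0, 1], so it can be passed directly to the matplotlib
    # colormap lookup below)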
104 | 105 | cmap = cm.get_cmap(cmap_name) 106 | x_new = cmap(x)[:, :, :3] 107 | 108 | if mask is not None: 109 | mask = np.float32(mask[:, :, np.newaxis]) 110 | x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask) 111 | 112 | cbar = get_vertical_colorbar( 113 | h=x.shape[0], 114 | vmin=vmin, 115 | vmax=vmax, 116 | cmap_name=cmap_name, 117 | cbar_precision=cbar_precision, 118 | ) 119 | 120 | if append_cbar: 121 | if cbar_in_image: 122 | x_new[:, -cbar.shape[1] :, :] = cbar 123 | else: 124 | x_new = np.concatenate( 125 | (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1 126 | ) 127 | return x_new 128 | else: 129 | return x_new 130 | 131 | 132 | # tensor 133 | def colorize( 134 | x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False 135 | ): 136 | device = x.device 137 | x = x.cpu().numpy() 138 | if mask is not None: 139 | mask = mask.cpu().numpy() > 0.99 140 | kernel = np.ones((3, 3), np.uint8) 141 | mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool) 142 | 143 | x = colorize_np(x, cmap_name, mask, range, append_cbar, cbar_in_image) 144 | x = torch.from_numpy(x).to(device) 145 | return x 146 | 147 | 148 | class MplColorHelper: 149 | def __init__(self, cmap_name, start_val, stop_val): 150 | self.cmap_name = cmap_name 151 | self.cmap = plt.get_cmap(cmap_name) 152 | self.norm = matplotlib.colors.Normalize(vmin=start_val, vmax=stop_val) 153 | self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap) 154 | 155 | def get_rgb(self, val): 156 | return self.scalarMap.to_rgba(val) 157 | 158 | 159 | def hex2rgb(hex_val): 160 | hex_val = hex_val.lstrip("#") 161 | rgb = tuple(int(hex_val[i : i + 2], 16) for i in (0, 2, 4)) 162 | return np.array(rgb) 163 | 164 | 165 | COLOR_GRADS = [ 166 | "YlOrBr", 167 | "Blues", 168 | "Purples", 169 | "Greens", 170 | "YlOrBr", 171 | "Reds", 172 | "Oranges", 173 | "YlOrRd", 174 | "OrRd", 175 | "PuRd", 176 | "RdPu", 177 | "BuPu", 178 | "GnBu", 179 | "PuBu", 180 | "YlGnBu", 181 | "PuBuGn", 182 | "BuGn", 183 | "YlGn", 184 | "Greys", 185 | ] 186 | 187 | 188 | def draw_ray_pcl(idx, px_ray_pts, save_f): 189 | # px_ray_pts: [N, 3] 190 | tmp_n = px_ray_pts.shape[0] 191 | tmp_helper = MplColorHelper(COLOR_GRADS[idx], 0, tmp_n) 192 | px_ray_pt_rgbs = np.array([tmp_helper.get_rgb(_) for _ in range(tmp_n)])[ 193 | ::-1, :3 194 | ] # [N, 3] 195 | 196 | px_ray_pt_rgbs = (px_ray_pt_rgbs * 255).astype(np.uint8) 197 | px_pcl = trimesh.PointCloud( 198 | vertices=px_ray_pts, colors=px_ray_pt_rgbs, process=False 199 | ) 200 | _ = px_pcl.export(save_f) 201 | return px_ray_pt_rgbs 202 | 203 | 204 | def interp_head_tail(head_v, tail_v, N=100): 205 | line_dir = head_v - tail_v 206 | interp_vs = ( 207 | tail_v[None, :] + np.linspace(0, 1, N).reshape((N, 1)) * line_dir[None, :] 208 | ) # [N, 3] 209 | return interp_vs 210 | 211 | 212 | def draw_cam_mesh(w2c_mat, save_f, tmp_coord=0.1, flag_save=True): 213 | N = 100 214 | 215 | c2w = np.linalg.inv(w2c_mat) 216 | c_world = np.matmul(c2w, np.array([0, 0, 0, 1]).reshape((4, 1)))[:3, 0] 217 | x_world = np.matmul(c2w, np.array([tmp_coord, 0, 0, 1]).reshape((4, 1)))[:3, 0] 218 | y_world = np.matmul(c2w, np.array([0, tmp_coord, 0, 1]).reshape((4, 1)))[:3, 0] 219 | z_world = np.matmul(c2w, np.array([0, 0, tmp_coord, 1]).reshape((4, 1)))[:3, 0] 220 | 221 | x_interp_vs = interp_head_tail(x_world, c_world, N=N) 222 | y_interp_vs = interp_head_tail(y_world, c_world, N=N) 223 | z_interp_vs = interp_head_tail(z_world, c_world, N=N) 224 | 225 | # RGB for XYZ, [3 * N, 3] 226 | all_vs = 
np.concatenate((x_interp_vs, y_interp_vs, z_interp_vs), axis=0) 227 | all_colors = np.zeros((3 * N, 4), dtype=np.uint8) 228 | all_colors[:, 3] = 255 229 | all_colors[:N, 0] = 255 # red, X 230 | all_colors[N : 2 * N, 1] = 255 # green, Y 231 | all_colors[2 * N :, 2] = 255 # blue, Z 232 | 233 | if flag_save: 234 | frame = trimesh.points.PointCloud(vertices=all_vs, colors=all_colors) 235 | _ = frame.export(save_f) 236 | else: 237 | return all_vs, all_colors 238 | 239 | 240 | def draw_set_poses(poses, save_f, tmp_coord=1.0): 241 | n = poses.shape[0] 242 | tmp_placebolder = np.zeros((n, 1, 4)) 243 | tmp_placebolder[:, 0, 3] = 1 244 | poses = np.concatenate((poses[:, :3, :4], tmp_placebolder), axis=1) 245 | 246 | all_verts = [] 247 | all_colors = [] 248 | for i in range(n): 249 | tmp_c2w = poses[i, :4, :4] # [3, 4] 250 | tmp_verts, tmp_colors = draw_cam_mesh( 251 | np.linalg.inv(tmp_c2w), None, tmp_coord=tmp_coord, flag_save=False 252 | ) 253 | tmp_colors = (tmp_colors.astype(float) * (i + 1) / (n + 2)).astype(np.uint8) 254 | all_verts.append(tmp_verts) 255 | all_colors.append(tmp_colors) 256 | 257 | all_verts = np.concatenate(all_verts, axis=0) 258 | all_colors = np.concatenate(all_colors, axis=0) 259 | 260 | frame = trimesh.points.PointCloud(vertices=all_verts, colors=all_colors) 261 | _ = frame.export(save_f) 262 | -------------------------------------------------------------------------------- /pgdvs/utils/dycheck/misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Functions are gathered and modified from https://github.com/KAIR-BAIR/dycheck/tree/main 4 | # 5 | # File : image.py 6 | # Author : Hang Gao 7 | # Email : hangg.sv7@gmail.com 8 | # 9 | # Copyright 2022 Adobe. All rights reserved. 10 | # 11 | # This file is licensed to you under the Apache License, Version 2.0 (the 12 | # "License"); you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 18 | # WARRANTIES OR REPRESENTATIONS OF ANY KIND, either express or implied. See the 19 | # License for the specific language governing permissions and limitations under 20 | # the License. 21 | 22 | import math 23 | from absl import logging 24 | import itertools 25 | from typing import Optional, Sequence, Any, Tuple 26 | import tqdm 27 | import cv2 28 | import numpy as np 29 | 30 | from pgdvs.utils.dycheck.camera import DyCheckCamera 31 | 32 | 33 | UINT8_MAX = 255 34 | UINT16_MAX = 65535 35 | 36 | 37 | def to_uint8(img: np.ndarray) -> np.ndarray: 38 | img = np.array(img) 39 | if img.dtype == np.uint8: 40 | return img 41 | if not issubclass(img.dtype.type, np.floating): 42 | raise ValueError( 43 | f"Input image should be a floating type but is of type " f"{img.dtype!r}." 
44 | ) 45 | return (img * UINT8_MAX).clip(0.0, UINT8_MAX).astype(np.uint8) 46 | 47 | 48 | def to_float32(img: np.ndarray) -> np.ndarray: 49 | img = np.array(img) 50 | if img.dtype == np.float32: 51 | return img 52 | 53 | dtype = img.dtype 54 | img = img.astype(np.float32) 55 | if dtype == np.uint8: 56 | return img / UINT8_MAX 57 | elif dtype == np.uint16: 58 | return img / UINT16_MAX 59 | elif dtype == np.float64: 60 | return img 61 | elif dtype == np.float16: 62 | return img 63 | 64 | raise ValueError(f"Unexpected dtype: {dtype}.") 65 | 66 | 67 | def downscale(img: np.ndarray, scale: int) -> np.ndarray: 68 | if scale == 1: 69 | return img 70 | 71 | height, width = img.shape[:2] 72 | if height % scale > 0 or width % scale > 0: 73 | raise ValueError( 74 | f"Image shape ({height},{width}) must be divisible by the" 75 | f" scale ({scale})." 76 | ) 77 | out_height, out_width = height // scale, width // scale 78 | resized = cv2.resize(img, (out_width, out_height), cv2.INTER_AREA) 79 | return resized 80 | 81 | 82 | def upscale(img: np.ndarray, scale: int) -> np.ndarray: 83 | if scale == 1: 84 | return img 85 | 86 | height, width = img.shape[:2] 87 | out_height, out_width = height * scale, width * scale 88 | resized = cv2.resize(img, (out_width, out_height), cv2.INTER_AREA) 89 | return resized 90 | 91 | 92 | def rescale( 93 | img: np.ndarray, scale_factor: float, interpolation: Any = cv2.INTER_AREA 94 | ) -> np.ndarray: 95 | scale_factor = float(scale_factor) 96 | 97 | if scale_factor <= 0.0: 98 | raise ValueError("scale_factor must be a non-negative number.") 99 | if scale_factor == 1.0: 100 | return img 101 | 102 | height, width = img.shape[:2] 103 | if scale_factor.is_integer(): 104 | return upscale(img, int(scale_factor)) 105 | 106 | inv_scale = 1.0 / scale_factor 107 | if ( 108 | inv_scale.is_integer() 109 | and (scale_factor * height).is_integer() 110 | and (scale_factor * width).is_integer() 111 | ): 112 | return downscale(img, int(inv_scale)) 113 | 114 | logging.warning( 115 | "Resizing image by non-integer factor %f, this may lead to artifacts.", 116 | scale_factor, 117 | ) 118 | 119 | height, width = img.shape[:2] 120 | out_height = math.ceil(height * scale_factor) 121 | out_height -= out_height % 2 122 | out_width = math.ceil(width * scale_factor) 123 | out_width -= out_width % 2 124 | 125 | return resize(img, (out_height, out_width), interpolation) 126 | 127 | 128 | def resize( 129 | img: np.ndarray, 130 | shape: Tuple[int, int], 131 | interpolation: Any = cv2.INTER_AREA, 132 | ) -> np.ndarray: 133 | out_height, out_width = shape 134 | return cv2.resize( 135 | img, 136 | (out_width, out_height), 137 | interpolation=interpolation, 138 | ) 139 | 140 | 141 | def sobel_by_quantile(img_points: np.ndarray, q: float): 142 | """Return a boundary mask where 255 indicates boundaries (where gradient is 143 | bigger than quantile). 
144 | """ 145 | dx0 = np.linalg.norm(img_points[1:-1, 1:-1] - img_points[1:-1, :-2], axis=-1) 146 | dx1 = np.linalg.norm(img_points[1:-1, 1:-1] - img_points[1:-1, 2:], axis=-1) 147 | dy0 = np.linalg.norm(img_points[1:-1, 1:-1] - img_points[:-2, 1:-1], axis=-1) 148 | dy1 = np.linalg.norm(img_points[1:-1, 1:-1] - img_points[2:, 1:-1], axis=-1) 149 | dx01 = (dx0 + dx1) / 2 150 | dy01 = (dy0 + dy1) / 2 151 | dxy01 = np.linalg.norm(np.stack([dx01, dy01], axis=-1), axis=-1) 152 | 153 | # (H, W, 1) uint8 154 | boundary_mask = (dxy01 > np.quantile(dxy01, q)).astype(np.float32) 155 | boundary_mask = ( 156 | np.pad(boundary_mask, ((1, 1), (1, 1)), constant_values=False)[ 157 | ..., None 158 | ].astype(np.uint8) 159 | * 255 160 | ) 161 | return boundary_mask 162 | 163 | 164 | def dilate(img: np.ndarray, kernel_size: Optional[int]): 165 | if kernel_size is None: 166 | return img 167 | is_float = np.issubdtype(img.dtype, np.floating) 168 | img = to_uint8(img) 169 | kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) 170 | dilated = cv2.dilate(img, kernel, iterations=1) 171 | if is_float: 172 | dilated = to_float32(dilated) 173 | return dilated 174 | 175 | 176 | def tsdf_fusion( 177 | imgs: np.ndarray, 178 | depths: np.ndarray, 179 | cameras: Sequence[DyCheckCamera], 180 | *, 181 | voxel_length: float = 1, 182 | sdf_trunc: float = 0.01, 183 | depth_far: float = 1e5, 184 | ): 185 | import open3d as o3d 186 | 187 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 188 | voxel_length=voxel_length, 189 | sdf_trunc=sdf_trunc, 190 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8, 191 | ) 192 | 193 | for rgb, depth, camera in zip( 194 | tqdm.tqdm(imgs, desc="* Fusing RGBDs"), 195 | depths, 196 | cameras, 197 | ): 198 | if (depth != 0).sum() == 0: 199 | continue 200 | # Make sure that the RGBD image is contiguous. 201 | rgb = o3d.geometry.Image(np.array(rgb)) 202 | depth = o3d.geometry.Image(np.array(depth)) 203 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 204 | rgb, 205 | depth, 206 | depth_scale=1, 207 | depth_trunc=depth_far, 208 | convert_rgb_to_intensity=False, 209 | ) 210 | w2c = camera.extrin 211 | W, H = camera.image_size 212 | fx = fy = camera.focal_length 213 | cx, cy = camera.principal_point 214 | volume.integrate( 215 | rgbd, o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy), w2c 216 | ) 217 | 218 | pcd = volume.extract_point_cloud() 219 | return np.asarray(pcd.points), np.asarray(pcd.colors) 220 | 221 | 222 | def get_bbox_segments(bbox: np.ndarray): 223 | points = x000, x001, x010, x011, x100, x101, x110, x111 = np.array( 224 | list(itertools.product(*bbox.T.tolist())) 225 | ) 226 | end_points = [x001, x011, x000, x010, x101, x111, x100, x110] 227 | points = points.tolist() 228 | points += [x000, x001, x010, x011] 229 | end_points += [x100, x101, x110, x111] 230 | 231 | return np.array(points), np.array(end_points) 232 | 233 | 234 | def tringulate_rays(origins: np.ndarray, viewdirs: np.ndarray) -> np.ndarray: 235 | """Triangulate a set of rays to find a single lookat point. 236 | 237 | Args: 238 | origins (np.ndarray): A (N, 3) array of ray origins. 239 | viewdirs (np.ndarray): A (N, 3) array of ray view directions. 240 | 241 | Returns: 242 | np.ndarray: A (3,) lookat point. 
243 | """ 244 | import tensorflow as tf 245 | from tensorflow_graphics.geometry.representation.ray import ( 246 | triangulate as ray_triangulate, 247 | ) 248 | 249 | tf.config.set_visible_devices([], "GPU") 250 | 251 | gpus = tf.config.list_physical_devices("GPU") 252 | for gpu in gpus: 253 | tf.config.experimental.set_memory_growth(gpu, True) 254 | 255 | origins = np.array(origins[None], np.float32) 256 | viewdirs = np.array(viewdirs[None], np.float32) 257 | weights = np.ones(origins.shape[:2], dtype=np.float32) 258 | points = np.array(ray_triangulate(origins, origins + viewdirs, weights)) 259 | return points[0] 260 | -------------------------------------------------------------------------------- /docs/BENCHMARK_NVIDIA.md: -------------------------------------------------------------------------------- 1 | # Benchmark: NVIDIA Dynamic Scenes 2 | 3 | ## Table of Contents 4 | 5 | - [1 Download Data](#1-download-data) 6 | - [2 Preprocess](#2-preprocess) 7 | - [2.1 Use Precomputed Flow and Mask](#21-use-precomputed-flow-and-mask) 8 | - [2.2 Reproduce Preprocessed Results](#22-reproduce-preprocessed-results) 9 | - [2.3 [Optional] Compute ZoeDepth](#23-optional-compute-zoedepth) 10 | - [3 Run Benchmark](#3-run-benchmark) 11 | - [4 Run Spatial Temporal Interpolation Visualizations](#3-run-spatial-temporal-interpolation-visualizations) 12 | 13 | ## 1 Download Data 14 | 15 | ```bash 16 | # this environment variable is used for demonstration 17 | cd /path/to/this/repo 18 | export PGDVS_ROOT=$PWD 19 | ``` 20 | 21 | For a fair comparison to scene-specific approaches, we evaluate our proposed pipeline on the [DynIBaR](https://github.com/google/dynibar)'s [processed data](https://drive.google.com/drive/folders/1Gv6j_RvDG2WrpqEJWtx73u1tlCZKsPiM): 22 | 23 | ```bash 24 | # download processed data 25 | conda activate pgdvs 26 | gdown https://drive.google.com/drive/folders/1Gv6j_RvDG2WrpqEJWtx73u1tlCZKsPiM -O ${PGDVS_ROOT}/data/nvidia_long --folder 27 | 28 | # unzip 29 | ALL_SCENE_IDS=(Balloon1 Balloon2 Jumping Playground Skating Truck Umbrella dynamicFace) 30 | # rgb, poses 31 | printf "%s\0" "${ALL_SCENE_IDS[@]}" | xargs -0 -n 1 -I {} -P 8 unzip ${PGDVS_ROOT}/data/nvidia_long/{}.zip -d ${PGDVS_ROOT}/data/nvidia_long/ 32 | # depth 33 | printf "%s\0" "${ALL_SCENE_IDS[@]}" | xargs -0 -n 1 -I {} -P 8 unzip ${PGDVS_ROOT}/data/nvidia_long/Depths/{}_disp.zip -d ${PGDVS_ROOT}/data/nvidia_long/Depths/{}/ 34 | ``` 35 | 36 | After running the above command, you should have a structure as the following: 37 | ``` 38 | . 39 | +-- data 40 | | +-- nvidia_long 41 | | | +-- Balloon1 42 | | | +-- Balloon2 43 | | | ... 44 | | | +-- Depths 45 | | | | +-- Balloon1 46 | | | | | +-- disp 47 | | | | +-- Balloon2 48 | | | | | +-- disp 49 | | | | ... 50 | ``` 51 | 52 | ## 2 Preprocess 53 | 54 | ### 2.1 Use Precomputed Flow and Mask 55 | 56 | ```bash 57 | gdown 1wn9rgCRDWqOJHZmViFYP5vHUbGk81fVp -O ${PGDVS_ROOT}/data/ 58 | unzip ${PGDVS_ROOT}/data/nvidia_long_flow_mask.zip -d ${PGDVS_ROOT}/data 59 | ``` 60 | 61 | ### 2.2 Reproduce Preprocessed Results 62 | 63 | Our pseudo-generalized approach requires optical flow and mask for potentially dynamic content. For this, we need several third parties's repositories and checkpoints. **NOTE**: the `CUDA_HOME` must be set correctly for [`detectron2`](https://github.com/facebookresearch/detectron2)'s installation and consequentially [`OneFormer`](https://github.com/SHI-Labs/OneFormer)'s usage. 
64 | ```bash 65 | CUDA_HOME=/usr/local/cuda # set to your own CUDA_HOME, where nvcc is installed 66 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \ 67 | ${CUDA_HOME} \ 68 | ${PGDVS_ROOT} \ 69 | ${PGDVS_ROOT}/data \ 70 | prepare 71 | ``` 72 | After running the command, repositories and pretrained checkpoints will be saved to `${PGDVS_ROOT}/third_parties` and `${PGDVS_ROOT}/ckpts` respectively. 73 | 74 | We then compute optical flows and masks with 75 | ```bash 76 | CUDA_HOME=/usr/local/cuda # set to your own CUDA_HOME 77 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \ 78 | ${CUDA_HOME} \ 79 | ${PGDVS_ROOT} \ 80 | ${PGDVS_ROOT}/data/nvidia_long \ 81 | execute_on_nvidia \ 82 | ${PGDVS_ROOT}/data/nvidia_long_flow_mask \ 83 | Balloon1 # can be one of [Balloon1 Balloon2 Jumping Playground Skating Truck Umbrella dynamicFace] 84 | ``` 85 | The computed optical flows and masks will be saved to `${PGDVS_ROOT}/data/nvidia_long_flow_mask`. 86 | 87 | ### 2.3 \[Optional\] Compute ZoeDepth 88 | 89 | If you want to try using [ZoeDepth](https://github.com/isl-org/ZoeDepth), we provide preprocessed data: 90 | ```bash 91 | gdown 1EcOzg8TSIbQtS7hw9-2RNIvq6L59P4RR -O ${PGDVS_ROOT}/data/ 92 | ``` 93 | 94 | Alternatively, you can compute the depths yourself. **Note**: the following command requires the masks from either [2.1 Use Precomputed Flow and Mask](#21-use-precomputed-flow-and-mask) or [2.2 Reproduce Preprocessed Results](#22-reproduce-preprocessed-results) in order to align the static area and mitigate scale and shift inconsistencies: 95 | ```bash 96 | bash ${PGDVS_ROOT}/scripts/preprocess/preprocess.sh \ 97 | /usr/local/cuda/ \ 98 | ${PGDVS_ROOT} \ 99 | ${PGDVS_ROOT}/data/nvidia_long \ 100 | execute_on_nvidia_zoedepth \ 101 | ${PGDVS_ROOT}/data/nvidia_long_zoedepth \ 102 | Balloon1 # can be one of [Balloon1 Balloon2 Jumping Playground Skating Truck Umbrella dynamicFace] 103 | ``` 104 | 105 | The computed depths will be saved to `${PGDVS_ROOT}/data/nvidia_long_zoedepth`.
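For intuition, the scale-and-shift alignment mentioned above can be sketched as a simple least-squares fit on static pixels. The snippet below is an illustrative assumption only; the function name, the disparity-space formulation, and the array conventions are not taken from the repository's preprocessing code:

```python
# Hypothetical sketch: align a monocular (e.g., ZoeDepth) depth map to a
# reference depth map using only pixels marked as static, by fitting a global
# scale s and shift t in disparity space with least squares.
import numpy as np


def align_scale_shift(mono_depth, ref_depth, static_mask, eps=1e-6):
    # Monocular predictors are typically consistent only up to an affine
    # transform in disparity (inverse depth), so fit there.
    d_mono = 1.0 / np.clip(mono_depth, eps, None)
    d_ref = 1.0 / np.clip(ref_depth, eps, None)

    m = static_mask.astype(bool) & (ref_depth > eps)
    A = np.stack([d_mono[m], np.ones(m.sum())], axis=1)  # [N, 2]
    s, t = np.linalg.lstsq(A, d_ref[m], rcond=None)[0]   # affine fit on static area

    aligned_disp = s * d_mono + t
    return 1.0 / np.clip(aligned_disp, eps, None)
```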
106 | 107 | ## 3 Run Benchmark 108 | 109 | To obtain quantitative results, run the following command: 110 | 111 | ```bash 112 | benchmark_type=default 113 | scene_id='[Balloon1]' # or 'null' to evaluate on all scenes 114 | 115 | bash ${PGDVS_ROOT}/scripts/benchmark.sh \ 116 | ${PGDVS_ROOT} \ 117 | ${PGDVS_ROOT}/ckpts \ 118 | ${PGDVS_ROOT}/data \ 119 | nvidia \ 120 | ${scene_id} \ 121 | ${benchmark_type} 122 | ``` 123 | You can set `benchmark_type` to one of the following: 124 | 125 | | benchmark_type | static rendering | dynamic rendering | 126 | |:----------|:-------------| :-------------| 127 | | st_cvd_dy_cvd | **point renderer** from **consistent** depth | softsplat from **consistent** depth | 128 | | st_cvd_dy_cvd_pcl_clean | **point renderer** from **consistent** depth | softsplat from **consistent** depth and outlier removal for point cloud | 129 | | st_cvd_pcl_clean_dy_cvd_pcl_clean | **point renderer** from **consistent** depth and outlier removal for point cloud | softsplat from **consistent** depth and outlier removal for point cloud | 130 | | st_gnt | GNT with **full** input | none | 131 | | st_gnt_masked_attn | GNT with **full** input and **masked** attention | none | 132 | | st_gnt_dy_cvd | GNT with **full** input | softsplat from **consistent** depth | 133 | | st_gnt_dy_cvd_pcl_clean | GNT with **full** input | softsplat from **consistent** depth and outlier removal for point cloud | 134 | | st_gnt_masked_input_dy_cvd | GNT with **masked** input | softsplat from **consistent** depth | 135 | | st_gnt_masked_input_dy_cvd_pcl_clean | GNT with **masked** input | softsplat from **consistent** depth and outlier removal for point cloud | 136 | | st_gnt_masked_input_attn_dy_cvd_pcl_clean | GNT with **masked** input and **masked** attention | softsplat from **consistent** depth and outlier removal for point cloud | 137 | | st_gnt_masked_attn_dy_cvd_pcl_clean **(default)** | GNT with **full** input and **masked** attention | softsplat from **consistent** depth and outlier removal for point cloud | 138 | | st_gnt_masked_attn_dy_cvd_pcl_clean_render_point | GNT with **full** input and **masked** attention | **point renderer** from **consistent** depth and outlier removal for point cloud | 139 | | st_gnt_masked_attn_dy_cvd_pcl_clean_render_mesh | GNT with **full** input and **masked** attention | **mesh renderer** from **consistent** depth and outlier removal for point cloud | 140 | | st_gnt_masked_attn_dy_zoed_pcl_clean | GNT with **full** input and **masked** attention | softsplat from **ZoeDepth**, outlier removal for point cloud | 141 | | st_gnt_masked_attn_dy_cvd_pcl_clean_track_tapir | GNT with **full** input and **masked** attention | softsplat from **consistent** depth, outlier removal for point cloud, and **tracking** with TAPIR | 142 | | st_gnt_masked_attn_dy_cvd_pcl_clean_track_cotracker | GNT with **full** input and **masked** attention | softsplat from **consistent** depth, outlier removal for point cloud, and **tracking** with CoTracker | 143 | 144 | All results will be saved to `${PGDVS_ROOT}/experiments`, where the subfolder `infos` contains image-wise quantitative results and the subfolder `vis` stores image-wise renderings. 145 | 146 | Regarding the number of GPUs `benchmark.sh` uses: it first checks whether the environment variable `CUDA_VISIBLE_DEVICES` has been set. If so, it uses the GPUs specified by `CUDA_VISIBLE_DEVICES`. Otherwise, it will use all GPUs on the server.
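To make the GPU-selection behaviour above concrete, a launcher could determine the GPU count roughly as follows. This is an illustrative sketch only, not the actual logic inside `benchmark.sh`:

```python
# Hypothetical illustration of the described behaviour: prefer the GPUs listed
# in CUDA_VISIBLE_DEVICES, otherwise fall back to every GPU on the machine.
import os

import torch

visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
if visible:
    num_gpus = len([d for d in visible.split(",") if d])  # e.g. "0,1" -> 2
else:
    num_gpus = torch.cuda.device_count()                  # all GPUs on the server
print(f"Running evaluation with {num_gpus} GPU(s)")
```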
147 | 148 | With 8 A100 GPUs, for evaluation on the whole dataset of 8 scenes (15840 images in total, counting both training and test views) at a resolution of about 288x550: 149 | - for each `benchmark_type` **without tracking**, the evaluation takes around 2 days 150 | - for each `benchmark_type` **with tracking**, tracking with [TAPIR](https://github.com/deepmind/tapnet) needs around 5 days and tracking with [CoTracker](https://github.com/facebookresearch/co-tracker) needs 10 days due to the costly dense tracking. Therefore, for evaluation on these types, we highly recommend parallelizing the evaluation as much as possible, e.g., evaluating one scene with 8 GPUs. 151 | 152 | 153 | ## 4 Run Spatial Temporal Interpolation Visualizations 154 | 155 | All results will be saved to `${PGDVS_ROOT}/experiments`. 156 | 157 | ```bash 158 | scene_id='[Balloon1]' 159 | 160 | bash ${PGDVS_ROOT}/scripts/visualize.sh \ 161 | ${PGDVS_ROOT} \ 162 | ${PGDVS_ROOT}/ckpts \ 163 | ${PGDVS_ROOT}/data \ 164 | nvidia_vis \ 165 | ${scene_id} 166 | ``` -------------------------------------------------------------------------------- /pgdvs/models/tapnet/models/tsm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utils functions for TSM.""" 17 | 18 | from typing import Tuple 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | 25 | def prepare_inputs(inputs: chex.Array) -> Tuple[jnp.ndarray, str, int]: 26 | """Deduces input mode for TSM.""" 27 | # Deduce if we run on TPU based on input shape. 28 | if len(inputs.shape) == 5: 29 | # Input is given in the standard [B, T, H, W, 3] format. 30 | tsm_mode = "gpu" 31 | num_frames = inputs.shape[1] 32 | inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:])) 33 | else: 34 | # Input is given in the [T * B, H, W, 3] format. 35 | tsm_mode = "tpu" 36 | num_frames = None 37 | return inputs, tsm_mode, num_frames 38 | 39 | 40 | def prepare_outputs( 41 | outputs: chex.Array, 42 | tsm_mode: str, 43 | num_frames: int, 44 | reduce_mean: bool = True, 45 | ) -> jnp.ndarray: 46 | """Processes output of TSM to undo the merging of batch and time.""" 47 | # Get the shape without the batch/time dimension (for TSM batch and time are 48 | # merged in the first dimension). 49 | shape_no_bt = list(outputs.shape[1:]) 50 | if tsm_mode == "tpu": 51 | # Outputs are of the shape [num_frames * B, ..., n_channels] 52 | outputs = jnp.reshape(outputs, [num_frames, -1] + shape_no_bt) 53 | if reduce_mean: 54 | # We average over time and space.
55 | outputs = jnp.mean(outputs, axis=[0] + list(range(2, len(shape_no_bt) + 1))) 56 | else: 57 | outputs = jnp.transpose( 58 | outputs, axes=[1, 0] + list(range(2, len(shape_no_bt) + 2)) 59 | ) 60 | elif tsm_mode == "gpu": 61 | # Outputs are of the shape [B * num_frames, ..., n_channels]. 62 | outputs = jnp.reshape(outputs, [-1, num_frames] + shape_no_bt) 63 | if reduce_mean: 64 | outputs = jnp.mean(outputs, axis=[1] + list(range(2, len(shape_no_bt) + 1))) 65 | elif tsm_mode.startswith("deflated"): 66 | # In deflated mode, outputs are already in the right format. 67 | pass 68 | else: 69 | raise ValueError( 70 | "`tsm_mode` should be 'tpu' or 'gpu' or " 71 | f"'deflated_0.x' ({tsm_mode} given)" 72 | ) 73 | return outputs # pytype: disable=bad-return-type # numpy-scalars 74 | 75 | 76 | def apply_temporal_shift( 77 | x: chex.Array, 78 | tsm_mode: str, 79 | num_frames: int, 80 | channel_shift_fraction: float = 0.125, 81 | ) -> jnp.ndarray: 82 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode.""" 83 | if tsm_mode == "tpu": 84 | outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction) 85 | elif tsm_mode == "gpu": 86 | outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction) 87 | elif tsm_mode.startswith("deflated"): 88 | alpha = float(tsm_mode.split("_")[1]) 89 | outputs = temporal_shift_image_mode(x, channel_shift_fraction, alpha) 90 | else: 91 | raise ValueError( 92 | "`tsm_mode` should be 'tpu' or 'gpu' or " 93 | f"'deflated_0.x' ({tsm_mode} given)" 94 | ) 95 | return outputs 96 | 97 | 98 | def temporal_shift_image_mode(x, channel_shift_fraction=0.125, alpha=0.3): 99 | """Temporal shift applied on single image (to emulate a fixed video).""" 100 | # B, H, W, C = batch_size, im_height, im_width, channels. 101 | # Input is (B, H, W, C). 102 | orig_shp = tuple(x.shape) 103 | n_channels = orig_shp[-1] 104 | n_shift = int(n_channels * channel_shift_fraction) 105 | # Alpha emulates the effect of the padding when using a single frame. 106 | shifted_backward = alpha * x[:, :, :, -n_shift:] 107 | shifted_forward = alpha * x[:, :, :, :n_shift] 108 | no_shift = x[:, :, :, n_shift:-n_shift] 109 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], axis=3) 110 | return shifted_x 111 | 112 | 113 | def temporal_shift_gpu( 114 | x: chex.Array, 115 | num_frames: int, 116 | channel_shift_fraction: float = 0.125, 117 | ) -> jnp.ndarray: 118 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383.""" 119 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels. 120 | # Input is (B * T, H, W, C). 121 | orig_shp = tuple(x.shape) 122 | reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:]) 123 | n_channels = orig_shp[-1] 124 | n_shift = int(n_channels * channel_shift_fraction) 125 | 126 | new_shp = tuple(reshaped_x.shape) 127 | 128 | # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:]. 129 | shifted_backward = jax.lax.slice( 130 | reshaped_x, 131 | (0, 1, 0, 0, new_shp[4] - n_shift), 132 | (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4]), 133 | ) 134 | shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0)) 135 | shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding) 136 | 137 | # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift]. 
138 | shifted_forward = jax.lax.slice( 139 | reshaped_x, 140 | (0, 0, 0, 0, 0), 141 | (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift), 142 | ) 143 | shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0)) 144 | shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding) 145 | 146 | no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift] 147 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], axis=4) 148 | return jnp.reshape(shifted_x, (-1,) + orig_shp[1:]) 149 | 150 | 151 | def temporal_shift_tpu( 152 | x: chex.Array, 153 | num_frames: int, 154 | channel_shift_fraction: float = 0.125, 155 | ) -> jnp.ndarray: 156 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383. 157 | 158 | TPU optimized version of TSM. Reshape is avoided by having the images 159 | reshaped in [T * B, :] so that frames corresponding to same time frame in 160 | videos are contiguous in memory. Finally, to avoid concatenate that prevent 161 | some fusion from happening we simply sum masked version of the features. 162 | Args: 163 | x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped 164 | from a time major version of the input). 165 | num_frames: number of frames T per video. 166 | channel_shift_fraction: fraction of the channel to shift forward and 167 | backward. 168 | 169 | Returns: 170 | The temporal shifted version of x. 171 | """ 172 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels. 173 | # Input is (T * B, H, W, C). 174 | original_dtype = x.dtype 175 | original_shape = list(x.shape) 176 | 177 | batch_size = int(original_shape[0] / num_frames) 178 | n_channels = int(original_shape[-1]) 179 | n_shift = int(n_channels * channel_shift_fraction) 180 | 181 | # Cast to bfloat16. 182 | x = x.astype(jnp.bfloat16) 183 | 184 | # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1. 185 | # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0]. 186 | orig_shp = list(x.shape) 187 | 188 | shifted_backward_padding = ( 189 | (0, batch_size, 0), 190 | (0, 0, 0), 191 | (0, 0, 0), 192 | (0, n_channels - n_shift, 0), 193 | ) 194 | x_backward_padding = jax.lax.pad( 195 | x, padding_value=jnp.bfloat16(0.0), padding_config=shifted_backward_padding 196 | ) 197 | # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future). 198 | shifted_backward = jax.lax.slice( 199 | x_backward_padding, 200 | (batch_size, 0, 0, n_channels - n_shift), 201 | (orig_shp[0] + batch_size, orig_shp[1], orig_shp[2], 2 * n_channels - n_shift), 202 | ) 203 | # Shift forward, we first pad by zeros [0, 0, x1, x2, x3]. 204 | shifted_forward_padding = ( 205 | (batch_size, 0, 0), 206 | (0, 0, 0), 207 | (0, 0, 0), 208 | (n_channels - n_shift, 0, 0), 209 | ) 210 | x_forward_padding = jax.lax.pad( 211 | x, padding_value=jnp.bfloat16(0.0), padding_config=shifted_forward_padding 212 | ) 213 | # The following shift gets to [0, 0, x1^-1] (where -1 means from the past). 214 | shifted_forward = jax.lax.slice( 215 | x_forward_padding, 216 | (0, 0, 0, 0), 217 | (orig_shp[0], orig_shp[1], orig_shp[2], n_channels), 218 | ) 219 | # No shift is in the middle, this gets [0, x2, 0]. 220 | mask_noshift = ( 221 | jnp.reshape( 222 | (jnp.arange(n_channels) >= n_shift) 223 | & (jnp.arange(n_channels) < n_channels - n_shift), 224 | (1, 1, 1, -1), 225 | ) 226 | ).astype(jnp.bfloat16) 227 | no_shift = mask_noshift * x 228 | # By summing everything together, we end up with [x3^+1, x2, x1^-1]. 
229 | # Note: channels have been reordered but that doesn't matter for the model. 230 | shifted_x = shifted_backward + shifted_forward + no_shift 231 | 232 | return shifted_x.astype(original_dtype) 233 | --------------------------------------------------------------------------------
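As a closing usage note for the TSM helpers above: the three functions are typically chained as prepare, shift, then restore. The following is a minimal sketch with arbitrary shapes chosen only for illustration; it is not code taken from the repository:

```python
# Minimal sketch: GPU-mode temporal shift on a dummy clip, then restore the
# [B, T, H, W, C] layout afterwards.
import jax.numpy as jnp

from pgdvs.models.tapnet.models.tsm_utils import (
    apply_temporal_shift,
    prepare_inputs,
    prepare_outputs,
)

video = jnp.zeros((2, 8, 32, 32, 16))             # [B, T, H, W, C]
x, tsm_mode, num_frames = prepare_inputs(video)   # -> [B*T, H, W, C], "gpu", 8
x = apply_temporal_shift(x, tsm_mode, num_frames, channel_shift_fraction=0.125)
feats = prepare_outputs(x, tsm_mode, num_frames, reduce_mean=False)
print(feats.shape)                                # (2, 8, 32, 32, 16)
```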