├── .gitignore ├── LICENSE.md ├── README.md ├── environment.yml ├── scripts ├── configs │ ├── kitti-06.yaml │ ├── paper │ │ ├── nvs │ │ │ ├── 0.25 │ │ │ │ ├── kitti-01.yaml │ │ │ │ ├── kitti-02.yaml │ │ │ │ ├── kitti-06.yaml │ │ │ │ ├── vkitti2-02.yaml │ │ │ │ ├── vkitti2-06.yaml │ │ │ │ └── vkitti2-18.yaml │ │ │ ├── 0.5 │ │ │ │ ├── kitti-01.yaml │ │ │ │ ├── kitti-02.yaml │ │ │ │ ├── kitti-06.yaml │ │ │ │ ├── vkitti2-02.yaml │ │ │ │ ├── vkitti2-06.yaml │ │ │ │ └── vkitti2-18.yaml │ │ │ └── 0.75 │ │ │ │ ├── kitti-01.yaml │ │ │ │ ├── kitti-02.yaml │ │ │ │ ├── kitti-06.yaml │ │ │ │ ├── vkitti2-02.yaml │ │ │ │ ├── vkitti2-06.yaml │ │ │ │ └── vkitti2-18.yaml │ │ └── reconstruction │ │ │ ├── kitti-01.yaml │ │ │ ├── kitti-02.yaml │ │ │ └── kitti-06.yaml │ └── vkitti2-06.yaml ├── create_kitti_depth_maps.py ├── create_kitti_feature_clusters.py ├── create_kitti_masks.py ├── create_kitti_metadata.py ├── create_vkitti2_feature_clusters.py ├── create_vkitti2_masks.py ├── create_vkitti2_metadata.py ├── extract_dino_correspondences.py ├── extract_dino_features.py ├── metadata_utils.py └── run_pca.py ├── setup.py └── suds ├── __init__.py ├── composite_proposal_network_sampler.py ├── cpp ├── suds_cpp.cpp └── suds_cuda.cu ├── data ├── __init__.py ├── dataset_utils.py ├── image_metadata.py ├── stream_input_dataset.py ├── suds_datamanager.py ├── suds_dataparser.py ├── suds_dataset.py ├── suds_eval_dataloader.py └── suds_pipeline.py ├── draw_utils.py ├── eval.py ├── fields ├── __init__.py ├── dynamic_field.py ├── dynamic_proposal_field.py ├── env_map_field.py ├── sharded │ ├── sharded_dynamic_field.py │ ├── sharded_dynamic_proposal_field.py │ ├── sharded_static_field.py │ └── sharded_static_proposal_field.py ├── static_field.py ├── static_proposal_field.py ├── suds_field_head_names.py └── video_embedding.py ├── interpolate.py ├── kmeans.py ├── metrics.py ├── render_images.py ├── sample_utils.py ├── stream_utils.py ├── suds_collider.py ├── suds_constants.py ├── suds_depth_renderer.py ├── suds_model.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Haithem Turki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SUDS: Scalable Urban Dynamic Scenes 2 | 3 | [Haithem Turki](https://haithemturki.com), [Jason Y. 
Zhang](https://jasonyzhang.com/), [Francesco Ferroni](https://www.francescoferroni.com/), [Deva Ramanan](http://www.cs.cmu.edu/~deva) 4 | 5 | [Project Page](https://haithemturki.com/suds) / [Paper](https://haithemturki.com/suds/paper.pdf) 6 | 7 | 8 | This repository contains the code needed to train [SUDS](https://haithemturki.com/suds/) models. 9 | 10 | ## Citation 11 | 12 | ``` 13 | @misc{turki2023suds, 14 | title={SUDS: Scalable Urban Dynamic Scenes}, 15 | author={Haithem Turki and Jason Y. Zhang and Francesco Ferroni and Deva Ramanan}, 16 | year={2023}, 17 | eprint={2303.14536}, 18 | archivePrefix={arXiv}, 19 | primaryClass={cs.CV} 20 | } 21 | ``` 22 | 23 | ## Setup 24 | 25 | ``` 26 | conda env create -f environment.yml 27 | conda activate suds 28 | python setup.py install 29 | ``` 30 | 31 | The codebase has been mainly tested against CUDA >= 11.3 and A100/A6000 GPUs. GPUs with compute capability greater or equal to 7.5 should generally work, although you may need to adjust batch sizes to fit within GPU memory constraints. 32 | 33 | ## Data Preparation 34 | 35 | ### KITTI 36 | 37 | 1. Download the following from the [KITTI MOT dataset](http://www.cvlibs.net/datasets/kitti/eval_tracking.php): 38 | 1. [Left color images](http://www.cvlibs.net/download.php?file=data_tracking_image_2.zip) 39 | 2. [Right color images](http://www.cvlibs.net/download.php?file=data_tracking_image_3.zip) 40 | 3. [GPS/IMU data](http://www.cvlibs.net/download.php?file=data_tracking_oxts.zip) 41 | 4. [Camera calibration files](http://www.cvlibs.net/download.php?file=data_tracking_calib.zip) 42 | 5. [Velodyne point clouds](http://www.cvlibs.net/download.php?file=data_tracking_velodyne.zip) 43 | 6. (Optional) [Semantic labels](https://storage.googleapis.com/gresearch/tf-deeplab/data/kitti-step.tar.gz) 44 | 45 | 2. Extract everything to ```./data/kitti``` and keep the data structure 46 | 3. Generate depth maps from the Velodyne point clouds: ```python scripts/create_kitti_depth_maps.py --kitti_sequence $SEQUENCE``` 47 | 4. (Optional) Generate sky and static masks from semantic labels: ```python scripts/create_kitti_masks.py --kitti_sequence $SEQUENCE``` 48 | 5. Create metadata file: ```python scripts/create_kitti_metadata.py --config_file scripts/configs/$CONFIG_FILE``` 49 | 6. Extract DINO features: 50 | 1. ```python scripts/extract_dino_features.py --metadata_path $METADATA_PATH``` or ```python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS scripts/extract_dino_features.py --metadata_path $METADATA_PATH``` for multi-GPU extraction 51 | 2. ```python scripts/run_pca.py --metadata_path $METADATA_PATH``` 52 | 7. Extract DINO correspondences: ```python scripts/extract_dino_correspondences.py --metadata_path $METADATA_PATH``` or ```python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS scripts/extract_dino_correspondences.py --metadata_path $METADATA_PATH``` for multi-GPU extraction 53 | 8. (Optional) Generate feature clusters for visualization: ```python scripts/create_kitti_feature_clusters.py --metadata_path $METADATA_PATH --output_path $OUTPUT_PATH``` 54 | 55 | ### VKITTI2 56 | 57 | 1. Download the following from the [VKITTI2 dataset](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/): 58 | 1. [RGB images](http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_rgb.tar) 59 | 2. [Depth images](http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_depth.tar) 60 | 3. 
[Camera intrinsics/extrinsics](http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_textgt.tar.gz) 61 | 4. (Optional) [Ground truth forward flow](http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_forwardFlow.tar) 62 | 5. (Optional) [Ground truth backward flow](http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_backwardFlow.tar) 63 | 6. (Optional) [Semantic labels](http://download.europe.naverlabs.com//virtual_kitti_2.0.3/vkitti_2.0.3_classSegmentation.tar) 64 | 65 | 2. Extract everything to ```./data/vkitti2``` and keep the data structure 66 | 3. (Optional) Generate sky and static masks from semantic labels: ```python scripts/create_vkitti2_masks.py --vkitti2_path $SCENE_PATH``` 67 | 4. Create metadata file: ```python scripts/create_vkitti2_metadata.py --config_file scripts/configs/$CONFIG_FILE``` 68 | 5. Extract DINO features: 69 | 1. ```python scripts/extract_dino_features.py --metadata_path $METADATA_PATH``` or ```python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS scripts/extract_dino_features.py --metadata_path $METADATA_PATH``` for multi-GPU extraction 70 | 2. ```python scripts/run_pca.py --metadata_path $METADATA_PATH``` 71 | 6. If not using the ground truth flow provided by VKITTI2, extract DINO correspondences: ```python scripts/extract_dino_correspondences.py --metadata_path $METADATA_PATH``` or ```python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS scripts/extract_dino_correspondences.py --metadata_path $METADATA_PATH``` for multi-GPU extraction 72 | 7. (Optional) Generate feature clusters for visualization: ```python scripts/create_vkitti2_feature_clusters.py --metadata_path $METADATA_PATH --vkitti2_path $SCENE_PATH --output_path $OUTPUT_PATH``` 73 | 74 | ## Training 75 | 76 | ```python suds/train.py suds --experiment-name $EXPERIMENT_NAME --pipeline.datamanager.dataparser.metadata_path $METADATA_PATH [--pipeline.feature_clusters $FEATURE_CLUSTERS]``` 77 | 78 | ## Evaluation 79 | 80 | ```python suds/eval.py --load_config $SAVED_MODEL_PATH``` or ```python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node $NUM_GPUS suds/eval.py --load_config $SAVED_MODEL_PATH``` for multi-GPU evaluation 81 | 82 | ## Acknowledgements 83 | 84 | This project is built on [Nerfstudio](https://github.com/nerfstudio-project/nerfstudio) and [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn). The DINO feature extraction scripts are based on [ShirAmir's implementation](https://github.com/ShirAmir/dino-vit-features) and parts of the KITTI processing code from [Neural Scene Graphs](https://github.com/princeton-computational-imaging/neural-scene-graphs). 
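 

## Example: KITTI Sequence 0006

For concreteness, the commands below chain the data preparation, training, and evaluation steps above for KITTI sequence `0006`, using the provided `scripts/configs/kitti-06.yaml` (which writes its metadata to `metadata/kitti-06.json`). This is only a sketch: the feature-cluster output path and experiment name are arbitrary placeholders, and `$SAVED_MODEL_PATH` stands for the configuration file saved by the training run.

```
# 1. Depth maps and (optional) sky/static masks from the extracted KITTI data
python scripts/create_kitti_depth_maps.py --kitti_sequence 0006
python scripts/create_kitti_masks.py --kitti_sequence 0006

# 2. Metadata file (written to metadata/kitti-06.json per the config)
python scripts/create_kitti_metadata.py --config_file scripts/configs/kitti-06.yaml

# 3. DINO features, PCA, and correspondences
python scripts/extract_dino_features.py --metadata_path metadata/kitti-06.json
python scripts/run_pca.py --metadata_path metadata/kitti-06.json
python scripts/extract_dino_correspondences.py --metadata_path metadata/kitti-06.json

# 4. (Optional) feature clusters for visualization (output path is a placeholder)
python scripts/create_kitti_feature_clusters.py --metadata_path metadata/kitti-06.json --output_path metadata/kitti-06-clusters.pt

# 5. Train, then evaluate the saved model
python suds/train.py suds --experiment-name kitti-06 --pipeline.datamanager.dataparser.metadata_path metadata/kitti-06.json --pipeline.feature_clusters metadata/kitti-06-clusters.pt
python suds/eval.py --load_config $SAVED_MODEL_PATH
```

The multi-GPU variants shown above (via `torch.distributed.run`) can be substituted for the feature extraction, correspondence, and evaluation commands without changing the other arguments.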
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: suds 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotlipy=0.7.0=py39h27cfd23_1003 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2023.01.10=h06a4308_0 13 | - certifi=2022.12.7=py39h06a4308_0 14 | - cffi=1.15.1=py39h5eee18b_3 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cryptography=39.0.1=py39h9ce1e76_0 17 | - cuda=11.6.1=0 18 | - cuda-cccl=11.6.55=hf6102b2_0 19 | - cuda-command-line-tools=11.6.2=0 20 | - cuda-compiler=11.6.2=0 21 | - cuda-cudart=11.6.55=he381448_0 22 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 23 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 24 | - cuda-cupti=11.6.124=h86345e5_0 25 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 26 | - cuda-driver-dev=11.6.55=0 27 | - cuda-gdb=12.1.55=0 28 | - cuda-libraries=11.6.1=0 29 | - cuda-libraries-dev=11.6.1=0 30 | - cuda-memcheck=11.8.86=0 31 | - cuda-nsight=12.1.55=0 32 | - cuda-nsight-compute=12.1.0=0 33 | - cuda-nvcc=11.6.124=hbba6d2d_0 34 | - cuda-nvdisasm=12.1.55=0 35 | - cuda-nvml-dev=11.6.55=haa9ef22_0 36 | - cuda-nvprof=12.1.55=0 37 | - cuda-nvprune=11.6.124=he22ec0a_0 38 | - cuda-nvrtc=11.6.124=h020bade_0 39 | - cuda-nvrtc-dev=11.6.124=h249d397_0 40 | - cuda-nvtx=11.6.124=h0630a44_0 41 | - cuda-nvvp=12.1.55=0 42 | - cuda-runtime=11.6.1=0 43 | - cuda-samples=11.6.101=h8efea70_0 44 | - cuda-sanitizer-api=12.1.55=0 45 | - cuda-toolkit=11.6.1=0 46 | - cuda-tools=11.6.1=0 47 | - cuda-visual-tools=11.6.1=0 48 | - ffmpeg=4.3=hf484d3e_0 49 | - flit-core=3.8.0=py39h06a4308_0 50 | - freetype=2.12.1=h4a9f257_0 51 | - gds-tools=1.6.0.25=0 52 | - giflib=5.2.1=h5eee18b_3 53 | - gmp=6.2.1=h295c915_3 54 | - gnutls=3.6.15=he1e5248_0 55 | - idna=3.4=py39h06a4308_0 56 | - intel-openmp=2021.4.0=h06a4308_3561 57 | - jpeg=9e=h5eee18b_1 58 | - lame=3.100=h7b6447c_0 59 | - lcms2=2.12=h3be6417_0 60 | - ld_impl_linux-64=2.38=h1181459_1 61 | - lerc=3.0=h295c915_0 62 | - libcublas=11.9.2.110=h5e84587_0 63 | - libcublas-dev=11.9.2.110=h5c901ab_0 64 | - libcufft=10.7.1.112=hf425ae0_0 65 | - libcufft-dev=10.7.1.112=ha5ce4c0_0 66 | - libcufile=1.6.0.25=0 67 | - libcufile-dev=1.6.0.25=0 68 | - libcurand=10.3.2.56=0 69 | - libcurand-dev=10.3.2.56=0 70 | - libcusolver=11.3.4.124=h33c3c4e_0 71 | - libcusparse=11.7.2.124=h7538f96_0 72 | - libcusparse-dev=11.7.2.124=hbbe9722_0 73 | - libdeflate=1.17=h5eee18b_0 74 | - libffi=3.4.2=h6a678d5_6 75 | - libgcc-ng=11.2.0=h1234567_1 76 | - libgomp=11.2.0=h1234567_1 77 | - libiconv=1.16=h7f8727e_2 78 | - libidn2=2.3.2=h7f8727e_0 79 | - libnpp=11.6.3.124=hd2722f0_0 80 | - libnpp-dev=11.6.3.124=h3c42840_0 81 | - libnvjpeg=11.6.2.124=hd473ad6_0 82 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0 83 | - libpng=1.6.39=h5eee18b_0 84 | - libstdcxx-ng=11.2.0=h1234567_1 85 | - libtasn1=4.16.0=h27cfd23_0 86 | - libtiff=4.5.0=h6a678d5_2 87 | - libunistring=0.9.10=h27cfd23_0 88 | - libwebp=1.2.4=h11a3e52_1 89 | - libwebp-base=1.2.4=h5eee18b_1 90 | - lz4-c=1.9.4=h6a678d5_0 91 | - mkl=2021.4.0=h06a4308_640 92 | - mkl-service=2.4.0=py39h7f8727e_0 93 | - mkl_fft=1.3.1=py39hd3c417c_0 94 | - mkl_random=1.2.2=py39h51133e4_0 95 | - ncurses=6.4=h6a678d5_0 96 | - nettle=3.7.3=hbbd107a_1 97 | - nsight-compute=2023.1.0.15=0 98 | - numpy=1.23.5=py39h14f4228_0 99 | - numpy-base=1.23.5=py39h31eccc5_0 100 | - openh264=2.1.1=h4ff587b_0 101 | - 
openssl=1.1.1t=h7f8727e_0 102 | - pillow=9.4.0=py39h6a678d5_0 103 | - pip=23.0.1=py39h06a4308_0 104 | - pycparser=2.21=pyhd3eb1b0_0 105 | - pyopenssl=23.0.0=py39h06a4308_0 106 | - pysocks=1.7.1=py39h06a4308_0 107 | - python=3.9.16=h7a1cb2a_2 108 | - pytorch=1.13.1=py3.9_cuda11.6_cudnn8.3.2_0 109 | - pytorch-cuda=11.6=h867d48c_1 110 | - pytorch-mutex=1.0=cuda 111 | - readline=8.2=h5eee18b_0 112 | - requests=2.28.1=py39h06a4308_1 113 | - setuptools=65.6.3=py39h06a4308_0 114 | - six=1.16.0=pyhd3eb1b0_1 115 | - sqlite=3.41.1=h5eee18b_0 116 | - tk=8.6.12=h1ccaba5_0 117 | - torchvision=0.14.1=py39_cu116 118 | - typing_extensions=4.4.0=py39h06a4308_0 119 | - tzdata=2022g=h04d1e81_0 120 | - urllib3=1.26.14=py39h06a4308_0 121 | - wheel=0.38.4=py39h06a4308_0 122 | - xz=5.2.10=h5eee18b_1 123 | - zlib=1.2.13=h5eee18b_0 124 | - zstd=1.5.2=ha4553b6_0 125 | - pip: 126 | - absl-py==1.4.0 127 | - aiofiles==22.1.0 128 | - aiosqlite==0.18.0 129 | - anyio==3.6.2 130 | - appdirs==1.4.4 131 | - argon2-cffi==21.3.0 132 | - argon2-cffi-bindings==21.2.0 133 | - arrow==1.2.3 134 | - asttokens==2.2.1 135 | - attrs==22.2.0 136 | - av==10.0.0 137 | - babel==2.12.1 138 | - backcall==0.2.0 139 | - beautifulsoup4==4.12.0 140 | - bidict==0.22.1 141 | - bleach==6.0.0 142 | - cachetools==5.3.0 143 | - click==8.1.3 144 | - comm==0.1.3 145 | - configargparse==1.5.3 146 | - contourpy==1.0.7 147 | - cycler==0.11.0 148 | - debugpy==1.6.6 149 | - decorator==5.1.1 150 | - defusedxml==0.7.1 151 | - descartes==1.1.0 152 | - docker-pycreds==0.4.0 153 | - docstring-parser==0.14.1 154 | - executing==1.2.0 155 | - fastjsonschema==2.16.3 156 | - filelock==3.10.3 157 | - fire==0.5.0 158 | - fonttools==4.39.2 159 | - fqdn==1.5.1 160 | - frozendict==2.3.6 161 | - gdown==4.6.4 162 | - gitdb==4.0.10 163 | - gitpython==3.1.31 164 | - google-auth==2.16.3 165 | - google-auth-oauthlib==0.4.6 166 | - grpcio==1.51.3 167 | - h5py==3.8.0 168 | - huggingface-hub==0.13.3 169 | - imageio==2.26.1 170 | - importlib-metadata==6.1.0 171 | - importlib-resources==5.12.0 172 | - ipykernel==6.22.0 173 | - ipython==8.11.0 174 | - ipython-genutils==0.2.0 175 | - ipywidgets==8.0.5 176 | - isoduration==20.11.0 177 | - jedi==0.18.2 178 | - jinja2==3.1.2 179 | - joblib==1.2.0 180 | - json5==0.9.11 181 | - jsonpointer==2.3 182 | - jsonschema==4.17.3 183 | - jupyter==1.0.0 184 | - jupyter-client==8.1.0 185 | - jupyter-console==6.6.3 186 | - jupyter-core==5.3.0 187 | - jupyter-events==0.6.3 188 | - jupyter-server==2.5.0 189 | - jupyter-server-fileid==0.8.0 190 | - jupyter-server-terminals==0.4.4 191 | - jupyter-server-ydoc==0.8.0 192 | - jupyter-ydoc==0.2.3 193 | - jupyterlab==3.6.2 194 | - jupyterlab-pygments==0.2.2 195 | - jupyterlab-server==2.21.0 196 | - jupyterlab-widgets==3.0.6 197 | - kiwisolver==1.4.4 198 | - lazy-loader==0.2 199 | - lpips==0.1.4 200 | - markdown==3.4.3 201 | - markdown-it-py==2.2.0 202 | - markupsafe==2.1.2 203 | - matplotlib==3.7.1 204 | - matplotlib-inline==0.1.6 205 | - mdurl==0.1.2 206 | - mediapy==1.1.6 207 | - mistune==2.0.5 208 | - msgpack==1.0.5 209 | - msgpack-numpy==0.4.8 210 | - nbclassic==0.5.3 211 | - nbclient==0.7.2 212 | - nbconvert==7.2.10 213 | - nbformat==5.8.0 214 | - nerfacc==0.3.5 215 | - nest-asyncio==1.5.6 216 | - networkx==3.0 217 | - ninja==1.11.1 218 | - notebook==6.5.3 219 | - notebook-shim==0.2.2 220 | - nuscenes-devkit==1.1.9 221 | - oauthlib==3.2.2 222 | - opencv-python==4.6.0.66 223 | - packaging==23.0 224 | - pandas==1.5.3 225 | - pandocfilters==1.5.0 226 | - parso==0.8.3 227 | - pathtools==0.1.2 228 | - 
pexpect==4.8.0 229 | - pickleshare==0.7.5 230 | - platformdirs==3.1.1 231 | - plotly==5.13.1 232 | - prometheus-client==0.16.0 233 | - prompt-toolkit==3.0.38 234 | - protobuf==3.20.3 235 | - psutil==5.9.4 236 | - ptyprocess==0.7.0 237 | - pure-eval==0.2.2 238 | - pyarrow==11.0.0 239 | - pyasn1==0.4.8 240 | - pyasn1-modules==0.2.8 241 | - pycocotools==2.0.6 242 | - pyequilib==0.5.6 243 | - pygments==2.14.0 244 | - pymeshlab==2022.2.post3 245 | - pyngrok==5.2.1 246 | - pyparsing==3.0.9 247 | - pyquaternion==0.9.9 248 | - pyrsistent==0.19.3 249 | - python-dateutil==2.8.2 250 | - python-engineio==4.4.0 251 | - python-json-logger==2.0.7 252 | - python-socketio==5.8.0 253 | - pytz==2022.7.1 254 | - pywavelets==1.4.1 255 | - pyyaml==6.0 256 | - pyzmq==25.0.2 257 | - qtconsole==5.4.1 258 | - qtpy==2.3.0 259 | - requests-oauthlib==1.3.1 260 | - rfc3339-validator==0.1.4 261 | - rfc3986-validator==0.1.1 262 | - rich==13.3.2 263 | - rsa==4.9 264 | - scikit-image==0.20.0 265 | - scikit-learn==1.2.2 266 | - scipy==1.9.1 267 | - send2trash==1.8.0 268 | - sentry-sdk==1.17.0 269 | - setproctitle==1.3.2 270 | - shapely==2.0.1 271 | - shtab==1.5.8 272 | - smart-open==6.3.0 273 | - smmap==5.0.0 274 | - sniffio==1.3.0 275 | - soupsieve==2.4 276 | - stack-data==0.6.2 277 | - tenacity==8.2.2 278 | - tensorboard==2.9.0 279 | - tensorboard-data-server==0.6.1 280 | - tensorboard-plugin-wit==1.8.1 281 | - termcolor==2.2.0 282 | - terminado==0.17.1 283 | - threadpoolctl==3.1.0 284 | - tifffile==2023.3.21 285 | - timm==0.6.13 286 | - tinycss2==1.2.1 287 | - tomli==2.0.1 288 | - torch-fidelity==0.3.0 289 | - torch-scatter==2.1.1 290 | - torchmetrics==0.11.4 291 | - torchtyping==0.1.4 292 | - tornado==6.2 293 | - tqdm==4.65.0 294 | - traitlets==5.9.0 295 | - typeguard==3.0.2 296 | - tyro==0.4.2 297 | - u-msgpack-python==2.7.2 298 | - uri-template==1.2.0 299 | - wandb==0.14.0 300 | - wcwidth==0.2.6 301 | - webcolors==1.12 302 | - webencodings==0.5.1 303 | - websocket-client==1.5.1 304 | - werkzeug==2.2.3 305 | - widgetsnbextension==4.0.6 306 | - xatlas==0.0.7 307 | - y-py==0.5.9 308 | - ypy-websocket==0.8.2 309 | - zipp==3.15.0 310 | - git+https://github.com/hturki/nerfstudio.git@ht/suds 311 | - git+https://github.com/hturki/tiny-cuda-nn.git@ht/res-grid#subdirectory=bindings/torch 312 | -------------------------------------------------------------------------------- /scripts/configs/kitti-06.yaml: -------------------------------------------------------------------------------- 1 | output_path: metadata/kitti-06.json 2 | kitti_sequence: '0006' 3 | frame_ranges: [0, 269] 4 | train_every: 1 5 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.25/kitti-01.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0001' 2 | frame_ranges: [380, 431] 3 | train_every: 4 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.25/kitti-02.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0002' 2 | frame_ranges: [140, 224] 3 | train_every: 4 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.25/kitti-06.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0006' 2 | frame_ranges: [65, 120] 3 | train_every: 4 4 | -------------------------------------------------------------------------------- 
/scripts/configs/paper/nvs/0.25/vkitti2-02.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [100, 200] 2 | train_every: 4 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.25/vkitti2-06.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [0, 100] 2 | train_every: 4 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.25/vkitti2-18.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [273, 338] 2 | train_every: 4 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.5/kitti-01.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0001' 2 | frame_ranges: [380, 431] 3 | test_every: 2 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.5/kitti-02.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0002' 2 | frame_ranges: [140, 224] 3 | test_every: 2 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.5/kitti-06.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0006' 2 | frame_ranges: [65, 120] 3 | test_every: 2 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.5/vkitti2-02.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [100, 200] 2 | test_every: 2 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.5/vkitti2-06.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [0, 100] 2 | test_every: 2 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.5/vkitti2-18.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [273, 338] 2 | test_every: 2 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.75/kitti-01.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0001' 2 | frame_ranges: [380, 431] 3 | test_every: 4 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.75/kitti-02.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0002' 2 | frame_ranges: [140, 224] 3 | test_every: 4 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.75/kitti-06.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0006' 2 | frame_ranges: [65, 120] 3 | test_every: 4 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.75/vkitti2-02.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [100, 200] 2 | test_every: 4 3 | -------------------------------------------------------------------------------- 
/scripts/configs/paper/nvs/0.75/vkitti2-06.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [0, 100] 2 | test_every: 4 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/nvs/0.75/vkitti2-18.yaml: -------------------------------------------------------------------------------- 1 | frame_ranges: [273, 338] 2 | test_every: 4 3 | -------------------------------------------------------------------------------- /scripts/configs/paper/reconstruction/kitti-01.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0001' 2 | frame_ranges: [380, 431] 3 | train_every: 1 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/reconstruction/kitti-02.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0002' 2 | frame_ranges: [140, 224] 3 | train_every: 1 4 | -------------------------------------------------------------------------------- /scripts/configs/paper/reconstruction/kitti-06.yaml: -------------------------------------------------------------------------------- 1 | kitti_sequence: '0006' 2 | frame_ranges: [0, 20, 65, 120] 3 | train_every: 1 4 | -------------------------------------------------------------------------------- /scripts/configs/vkitti2-06.yaml: -------------------------------------------------------------------------------- 1 | output_path: metadata/vkitti2-06.json 2 | vkitti2_path: data/vkitti2/Scene06/clone 3 | frame_ranges: [0, 269] 4 | train_every: 1 5 | -------------------------------------------------------------------------------- /scripts/create_kitti_depth_maps.py: -------------------------------------------------------------------------------- 1 | from smart_open import open 2 | from tqdm import tqdm 3 | 4 | MOVER_CLASSES = [11, 12, 13, 14, 15, 16, 17, 18, 255] 5 | SKY_CLASS = 10 6 | 7 | from argparse import Namespace 8 | from pathlib import Path 9 | 10 | import configargparse 11 | import numpy as np 12 | import pyarrow as pa 13 | import pyarrow.parquet as pq 14 | import torch 15 | from torch_scatter import scatter_min 16 | 17 | from suds.stream_utils import image_from_stream, get_filesystem, buffer_from_stream 18 | 19 | 20 | @torch.inference_mode() 21 | def write_depth_maps(kitti_root: str, kitti_sequence: str) -> None: 22 | fs = get_filesystem(kitti_root) 23 | 24 | if fs is None: 25 | (Path(kitti_root) / 'depth_02' / kitti_sequence).mkdir(parents=True) 26 | (Path(kitti_root) / 'depth_03' / kitti_sequence).mkdir(parents=True) 27 | 28 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 29 | with open('{}/calib/{}.txt'.format(kitti_root, kitti_sequence), 'r') as f: 30 | for line in f: 31 | tokens = line.strip().split() 32 | if tokens[0] == 'P2:': 33 | P2 = torch.eye(4, dtype=torch.float64, device=device) 34 | P2[:3] = torch.DoubleTensor([float(x) for x in tokens[1:]]).view(3, 4) 35 | if tokens[0] == 'P3:': 36 | P3 = torch.eye(4, dtype=torch.float64, device=device) 37 | P3[:3] = torch.DoubleTensor([float(x) for x in tokens[1:]]).view(3, 4) 38 | if tokens[0] == 'R_rect': 39 | R_rect = torch.eye(4, dtype=torch.float64, device=device) 40 | R_rect[:3, :3] = torch.DoubleTensor([float(x) for x in tokens[1:]]).view(3, 3) 41 | if tokens[0] == 'Tr_velo_cam': 42 | Tr_velo_cam = torch.eye(4, dtype=torch.float64, device=device) 43 | Tr_velo_cam[:3] = torch.DoubleTensor([float(x) for x in tokens[1:]]).view(3, 4) 
44 | 45 | with open('{}/oxts/{}.txt'.format(kitti_root, kitti_sequence), 'r') as f: 46 | for frame, line in enumerate(tqdm(f)): 47 | lidar_points = np.frombuffer(buffer_from_stream( 48 | '{0}/velodyne/{1}/{2:06d}.bin'.format(kitti_root, kitti_sequence, frame)).getbuffer(), 49 | dtype=np.float32).reshape(-1, 4) 50 | lidar_points = torch.DoubleTensor(lidar_points).to(device) 51 | lidar_points[:, 3] = 1 52 | lidar_points_T = lidar_points.T 53 | for camera, transform in [('2', P2), ('3', P3)]: 54 | points_cam = (transform @ R_rect @ Tr_velo_cam @ lidar_points_T).T[:, :3] 55 | 56 | image_path = '{0}/image_0{1}/{2}/{3:06d}.png'.format(kitti_root, camera, kitti_sequence, frame) 57 | image = image_from_stream(image_path) 58 | depth_map = torch.ones(image.size[1], image.size[0], device=device, dtype=torch.float64) \ 59 | * torch.finfo(torch.float64).max 60 | points_cam[:, :2] = points_cam[:, :2] / points_cam[:, 2].view(-1, 1) 61 | is_valid_x = torch.logical_and(0 <= points_cam[:, 0], points_cam[:, 0] < image.size[0] - 1) 62 | is_valid_y = torch.logical_and(0 <= points_cam[:, 1], points_cam[:, 1] < image.size[1] - 1) 63 | is_valid_z = points_cam[:, 2] > 0 64 | is_valid_points = torch.logical_and(torch.logical_and(is_valid_x, is_valid_y), is_valid_z) 65 | 66 | assert is_valid_points.sum() > 0 67 | 68 | u = torch.round(points_cam[:, 0][is_valid_points]).long().cuda() 69 | v = torch.round(points_cam[:, 1][is_valid_points]).long().cuda() 70 | z = points_cam[:, 2][is_valid_points].cuda() 71 | scatter_min(z, v * image.size[0] + u, out=depth_map.view(-1)) 72 | 73 | depth_map[depth_map >= torch.finfo(torch.float64).max - 1e-5] = 0 74 | pq.write_table(pa.table({'depth': depth_map.cpu().float().numpy().flatten()}), 75 | '{0}/depth_0{1}/{2}/{3:06d}.parquet'.format(kitti_root, camera, kitti_sequence, frame), 76 | filesystem=fs, compression='BROTLI') 77 | 78 | 79 | def _get_opts() -> Namespace: 80 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 81 | parser.add_argument('--config_file', is_config_file=True) 82 | 83 | parser.add_argument('--kitti_root', type=str, default='data/kitti/training') 84 | parser.add_argument('--kitti_sequence', type=str, required=True) 85 | 86 | return parser.parse_known_args()[0] 87 | 88 | 89 | def main(hparams: Namespace) -> None: 90 | write_depth_maps(hparams.kitti_root, hparams.kitti_sequence) 91 | 92 | 93 | if __name__ == '__main__': 94 | main(_get_opts()) 95 | -------------------------------------------------------------------------------- /scripts/create_kitti_feature_clusters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import random 4 | from argparse import Namespace 5 | from collections import defaultdict 6 | from concurrent.futures import ThreadPoolExecutor 7 | from io import BytesIO 8 | from pathlib import Path 9 | from typing import Dict, Tuple, Any 10 | 11 | import configargparse 12 | import numpy as np 13 | import pyarrow.parquet as pq 14 | import torch 15 | import torch.nn.functional as F 16 | from smart_open import open 17 | from tqdm import tqdm 18 | 19 | from suds.stream_utils import get_filesystem, image_from_stream 20 | 21 | CITYSCAPE_COLORS = torch.ByteTensor([[128, 64, 128], 22 | [244, 35, 232], 23 | [70, 70, 70], 24 | [102, 102, 156], 25 | [190, 153, 153], 26 | [153, 153, 153], 27 | [250, 170, 30], 28 | [220, 220, 0], 29 | [107, 142, 35], 30 | [152, 251, 152], 31 | [70, 130, 180], 32 | [220, 20, 60], 33 | [255, 0, 0], 34 | [0, 0, 
142], 35 | [0, 0, 70], 36 | [0, 60, 100], 37 | [0, 80, 100], 38 | [0, 0, 230], 39 | [119, 11, 32]]) 40 | 41 | 42 | def load_class_features(sequence_path: str, item: Dict[str, Any]) -> Tuple[Dict[int, int], Dict[int, torch.Tensor]]: 43 | feature_path = item['feature_path'] 44 | 45 | table = pq.read_table(feature_path) 46 | features = torch.FloatTensor(table['pca'].to_numpy()).view( 47 | [int(x) for x in table.schema.metadata[b'shape'].split()]) 48 | 49 | if (features.shape[0] != item['H'] or features.shape[1] != item['W']): 50 | features = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=(item['H'], item['W'])).squeeze() \ 51 | .permute(1, 2, 0) 52 | 53 | frame = Path(feature_path).stem 54 | categories = torch.LongTensor(np.asarray(image_from_stream('{}/{}.png'.format(sequence_path, frame))))[:, :, 55 | 0].view(-1) 56 | 57 | sorted_categories, ordering = categories.sort() 58 | unique_categories, counts = torch.unique_consecutive(sorted_categories, return_counts=True) 59 | 60 | category_features = {} 61 | category_counts = {} 62 | 63 | offset = 0 64 | for category, category_count in zip(unique_categories, counts): 65 | if category > 18: 66 | continue 67 | 68 | category_counts[category.item()] = category_count.item() 69 | category_features[category.item()] = features.view(-1, features.shape[-1])[ 70 | ordering[offset:offset + category_count]].sum(dim=0) 71 | offset += category_count 72 | 73 | return category_features, category_counts 74 | 75 | 76 | def _get_opts() -> Namespace: 77 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 78 | parser.add_argument('--config_file', is_config_file=True) 79 | 80 | parser.add_argument('--metadata_path', type=str, required=True) 81 | parser.add_argument('--output_path', type=str, required=True) 82 | parser.add_argument('--kitti_step_path', type=str, default='data/kitti/kitti-step/panoptic_maps') 83 | parser.add_argument('--subset_ratio', type=float, default=0.1) 84 | 85 | return parser.parse_known_args()[0] 86 | 87 | 88 | def main(hparams: Namespace) -> None: 89 | with open(hparams.metadata_path) as f: 90 | metadata = json.load(f) 91 | 92 | class_features = {} 93 | class_counts = defaultdict(int) 94 | 95 | frames = metadata['frames'] 96 | frames_with_sem = list(filter(lambda x: '/dino_02/' in x['feature_path'], frames)) 97 | indices = np.linspace(0, len(frames_with_sem), int(len(frames) * hparams.subset_ratio), endpoint=False, 98 | dtype=np.int32) 99 | 100 | random.seed(42) 101 | random.shuffle(indices) 102 | 103 | kitti_sequence = Path(metadata['frames'][0]['rgb_path']).parent.name 104 | sequence_path = '{}/train/{}'.format(hparams.kitti_step_path, kitti_sequence) 105 | fs = get_filesystem(sequence_path) 106 | if (fs is None and (not Path(sequence_path).exists())) or (fs is not None and (not fs.exists(sequence_path))): 107 | sequence_path = '{}/val/{}'.format(hparams.kitti_step_path, kitti_sequence) 108 | 109 | with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor: 110 | futures = {} 111 | for index in indices: 112 | futures[index] = executor.submit(load_class_features, sequence_path, frames[index]) 113 | 114 | for index in tqdm(indices): 115 | frame_class_features, frame_class_counts = futures[index].result() 116 | for key, val in frame_class_features.items(): 117 | if key not in class_features: 118 | class_features[key] = val 119 | else: 120 | class_features[key] = class_features[key] + val 121 | 122 | class_counts[key] += frame_class_counts[key] 123 | del futures[index] 124 
| 125 | class_clusters = [] 126 | category_colors = [] 127 | for key in sorted(class_features.keys()): 128 | class_clusters.append((class_features[key] / class_counts[key]).unsqueeze(0)) 129 | category_colors.append(CITYSCAPE_COLORS[key:key + 1]) 130 | 131 | class_clusters = torch.cat(class_clusters) 132 | category_colors = torch.cat(category_colors) 133 | 134 | buffer = BytesIO() 135 | torch.save({'centroids': class_clusters, 'colors': category_colors}, buffer) 136 | 137 | if get_filesystem(hparams.output_path) is None: 138 | Path(hparams.output_path).parent.mkdir(parents=True, exist_ok=True) 139 | 140 | with open(hparams.output_path, 'wb') as f: 141 | f.write(buffer.getbuffer()) 142 | 143 | 144 | if __name__ == '__main__': 145 | main(_get_opts()) 146 | -------------------------------------------------------------------------------- /scripts/create_kitti_masks.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from pathlib import Path 3 | 4 | import configargparse 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from suds.stream_utils import image_from_stream, image_to_stream, get_filesystem 12 | 13 | MOVER_CLASSES = [11, 12, 13, 14, 15, 16, 17, 18, 255] 14 | SKY_CLASS = 10 15 | 16 | 17 | def write_static_masks(kitti_root: str, kitti_step_path: str, kitti_sequence: str, dilation: int) -> None: 18 | if get_filesystem(kitti_root) is None: 19 | (Path(kitti_root) / 'static_02' / kitti_sequence).mkdir(parents=True) 20 | (Path(kitti_root) / 'sky_02' / kitti_sequence).mkdir(parents=True) 21 | 22 | sequence_path = '{}/train/{}'.format(kitti_step_path, kitti_sequence) 23 | fs = get_filesystem(sequence_path) 24 | if (fs is None and (not Path(sequence_path).exists())) or (fs is not None and (not fs.exists(sequence_path))): 25 | sequence_path = '{}/val/{}'.format(kitti_step_path, kitti_sequence) 26 | 27 | with open('{}/oxts/{}.txt'.format(kitti_root, kitti_sequence), 'r') as f: 28 | for frame, line in enumerate(tqdm(f)): 29 | category_path_2 = '{0}/{1:06d}.png'.format(sequence_path, frame) 30 | category = torch.LongTensor(np.asarray(image_from_stream(category_path_2)))[:, :, 0] 31 | 32 | mover = torch.zeros_like(category, dtype=torch.bool) 33 | for mover_class in MOVER_CLASSES: 34 | mover[category == mover_class] = True 35 | mover = mover.float().numpy() 36 | kernel = np.ones((dilation, dilation), dtype=np.float32) 37 | mover = cv2.dilate(mover, kernel) 38 | static_mask = Image.fromarray(mover <= 0) 39 | 40 | image_to_stream(static_mask, '{0}/static_02/{1}/{2:06d}.png'.format(kitti_root, kitti_sequence, frame)) 41 | 42 | sky_mask = Image.fromarray((category == SKY_CLASS).numpy()) 43 | image_to_stream(sky_mask, '{0}/sky_02/{1}/{2:06d}.png'.format(kitti_root, kitti_sequence, frame)) 44 | 45 | if frame == 0: 46 | # create all-false for camera 3 47 | all_false = Image.fromarray(np.zeros_like(mover <= 0)) 48 | image_to_stream(all_false, '{0}/all-false.png'.format(kitti_root)) 49 | 50 | 51 | def _get_opts() -> Namespace: 52 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 53 | parser.add_argument('--config_file', is_config_file=True) 54 | 55 | parser.add_argument('--kitti_root', type=str, default='data/kitti/training') 56 | parser.add_argument('--kitti_step_path', type=str, default='data/kitti/kitti-step/panoptic_maps') 57 | parser.add_argument('--kitti_sequence', type=str, required=True) 58 | 
parser.add_argument('--dilation', type=int, default=30) 59 | 60 | return parser.parse_known_args()[0] 61 | 62 | 63 | def main(hparams: Namespace) -> None: 64 | write_static_masks(hparams.kitti_root, hparams.kitti_step_path, hparams.kitti_sequence, hparams.dilation) 65 | 66 | 67 | if __name__ == '__main__': 68 | main(_get_opts()) 69 | -------------------------------------------------------------------------------- /scripts/create_vkitti2_feature_clusters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import random 4 | from argparse import Namespace 5 | from collections import defaultdict 6 | from concurrent.futures import ThreadPoolExecutor 7 | from io import BytesIO 8 | from pathlib import Path 9 | from typing import Dict, Tuple, Any 10 | 11 | import configargparse 12 | import numpy as np 13 | import pyarrow.parquet as pq 14 | import torch 15 | import torch.nn.functional as F 16 | from smart_open import open 17 | from tqdm import tqdm 18 | 19 | from suds.stream_utils import get_filesystem, image_from_stream 20 | 21 | 22 | def load_class_features(item: Dict[str, Any]) -> Tuple[Dict[int, int], Dict[int, torch.Tensor]]: 23 | feature_path = item['feature_path'] 24 | 25 | table = pq.read_table(feature_path) 26 | features = torch.FloatTensor(table['pca'].to_numpy()).view( 27 | [int(x) for x in table.schema.metadata[b'shape'].split()]) 28 | 29 | if (features.shape[0] != item['H'] or features.shape[1] != item['W']): 30 | features = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=(item['H'], item['W'])).squeeze() \ 31 | .permute(1, 2, 0) 32 | 33 | frame = Path(feature_path).stem 34 | categories = torch.LongTensor(np.asarray(image_from_stream( 35 | item['depth_path'].replace('/depth/', '/classSegmentation/').replace('depth_', 'classgt_')))).sum(dim=-1) 36 | 37 | sorted_categories, ordering = categories.sort() 38 | unique_categories, counts = torch.unique_consecutive(sorted_categories, return_counts=True) 39 | 40 | category_features = {} 41 | category_counts = {} 42 | 43 | offset = 0 44 | for category, category_count in zip(unique_categories, counts): 45 | category_counts[category.item()] = category_count.item() 46 | category_features[category.item()] = features.view(-1, features.shape[-1])[ 47 | ordering[offset:offset + category_count]].sum(dim=0) 48 | offset += category_count 49 | 50 | return category_features, category_counts 51 | 52 | 53 | def _get_opts() -> Namespace: 54 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 55 | parser.add_argument('--config_file', is_config_file=True) 56 | 57 | parser.add_argument('--metadata_path', type=str, required=True) 58 | parser.add_argument('--output_path', type=str, required=True) 59 | parser.add_argument('--vkitti2_path', type=str, required=True) 60 | parser.add_argument('--subset_ratio', type=float, default=0.1) 61 | 62 | return parser.parse_known_args()[0] 63 | 64 | 65 | def main(hparams: Namespace) -> None: 66 | with open(hparams.metadata_path) as f: 67 | metadata = json.load(f) 68 | 69 | class_features = {} 70 | class_counts = defaultdict(int) 71 | 72 | frames = metadata['frames'] 73 | 74 | with open('{}/colors.txt'.format(hparams.vkitti2_path)) as f: 75 | # Category r g b 76 | next(f) # skip header 77 | # Terrain 210 0 200 78 | color_mappings = [] 79 | for line in f: 80 | color_mappings.append([int(x) for x in line.strip().split()[1:]]) 81 | color_mappings = torch.LongTensor(color_mappings) 82 | colors = 
color_mappings.sum(dim=-1) 83 | assert torch.unique(colors).shape[0] == colors.shape[0] 84 | 85 | all_category_colors = torch.zeros(colors.max() + 2, 3, dtype=torch.uint8) 86 | for color, color_mapping in zip(colors, color_mappings): 87 | all_category_colors[color] = color_mapping 88 | 89 | indices = np.linspace(0, len(frames), int(len(frames) * hparams.subset_ratio), endpoint=False, 90 | dtype=np.int32) 91 | 92 | random.seed(42) 93 | random.shuffle(indices) 94 | 95 | with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor: 96 | futures = {} 97 | for index in indices: 98 | futures[index] = executor.submit(load_class_features, frames[index]) 99 | 100 | for index in tqdm(indices): 101 | frame_class_features, frame_class_counts = futures[index].result() 102 | for key, val in frame_class_features.items(): 103 | if key not in class_features: 104 | class_features[key] = val 105 | else: 106 | class_features[key] = class_features[key] + val 107 | 108 | class_counts[key] += frame_class_counts[key] 109 | del futures[index] 110 | 111 | class_clusters = [] 112 | category_colors = [] 113 | for key in sorted(class_features.keys()): 114 | class_clusters.append((class_features[key] / class_counts[key]).unsqueeze(0)) 115 | category_colors.append(all_category_colors[key:key + 1]) 116 | 117 | class_clusters = torch.cat(class_clusters) 118 | category_colors = torch.cat(category_colors) 119 | 120 | buffer = BytesIO() 121 | torch.save({'centroids': class_clusters, 'colors': category_colors}, buffer) 122 | 123 | if get_filesystem(hparams.output_path) is None: 124 | Path(hparams.output_path).parent.mkdir(parents=True, exist_ok=True) 125 | 126 | with open(hparams.output_path, 'wb') as f: 127 | f.write(buffer.getbuffer()) 128 | 129 | 130 | if __name__ == '__main__': 131 | main(_get_opts()) 132 | -------------------------------------------------------------------------------- /scripts/create_vkitti2_masks.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from pathlib import Path 3 | 4 | import configargparse 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from smart_open import open 10 | from tqdm import tqdm 11 | 12 | from suds.stream_utils import image_from_stream, image_to_stream, get_filesystem 13 | 14 | MOVER_CLASSES = [280, 462, 278] 15 | SKY_CLASS = 545 16 | 17 | def write_masks(vkitti2_path: str, dilation: int) -> None: 18 | if get_filesystem(vkitti2_path) is None: 19 | for i in range(2): 20 | (Path(vkitti2_path) / 'frames' / 'static_mask' / 'Camera_{}'.format(i)).mkdir(parents=True) 21 | (Path(vkitti2_path) / 'frames' / 'sky_mask' / 'Camera_{}'.format(i)).mkdir(parents=True) 22 | 23 | with open('{}/intrinsic.txt'.format(vkitti2_path), 'r') as in_f, \ 24 | open('{}/extrinsic.txt'.format(vkitti2_path), 'r') as ex_f: 25 | # frame cameraID K[0,0] K[1,1] K[0,2] K[1,2] 26 | next(in_f) 27 | # frame cameraID r1,1 r1,2 r1,3 t1 r2,1 r2,2 r2,3 t2 r3,1 r3,2 r3,3 t3 0 0 0 1 28 | next(ex_f) 29 | 30 | for in_line, ex_line in tqdm(zip(in_f, ex_f)): 31 | in_entry = in_line.strip().split() 32 | frame = int(in_entry[0]) 33 | cameraID = int(in_entry[1]) 34 | 35 | category_path = '{0}/frames/classSegmentation/Camera_{1}/classgt_{2:05d}.png'.format( 36 | vkitti2_path, cameraID, frame) 37 | category = torch.LongTensor(np.asarray(image_from_stream(category_path))).sum(dim=-1) 38 | mover = torch.zeros_like(category, dtype=torch.bool) 39 | for mover_class in MOVER_CLASSES: 40 | mover[category == 
mover_class] = True 41 | mover = mover.float().numpy() 42 | kernel = np.ones((dilation, dilation), dtype=np.float32) 43 | mover = cv2.dilate(mover, kernel) 44 | static_mask = Image.fromarray(mover <= 0) 45 | 46 | image_to_stream(static_mask, '{0}/frames/static_mask/Camera_{1}/static_mask_{2:05d}.png'.format( 47 | vkitti2_path, cameraID, frame)) 48 | 49 | sky_mask = Image.fromarray((category == SKY_CLASS).numpy()) 50 | image_to_stream(sky_mask, '{0}/frames/sky_mask/Camera_{1}/sky_mask_{2:05d}.png'.format( 51 | vkitti2_path, cameraID, frame)) 52 | 53 | 54 | def _get_opts() -> Namespace: 55 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 56 | parser.add_argument('--config_file', is_config_file=True) 57 | 58 | parser.add_argument('--vkitti2_path', type=str, required=True) 59 | parser.add_argument('--dilation', type=int, default=30) 60 | 61 | return parser.parse_known_args()[0] 62 | 63 | 64 | def main(hparams: Namespace) -> None: 65 | write_masks(hparams.vkitti2_path, hparams.dilation) 66 | 67 | 68 | if __name__ == '__main__': 69 | main(_get_opts()) 70 | -------------------------------------------------------------------------------- /scripts/create_vkitti2_metadata.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from pathlib import Path 3 | from typing import List, Optional, Tuple 4 | 5 | import configargparse 6 | import torch 7 | from smart_open import open 8 | from tqdm import tqdm 9 | 10 | from metadata_utils import get_frame_range, get_bounds_from_depth, get_neighbor, \ 11 | write_metadata, get_val_frames, scale_bounds, OPENCV_TO_OPENGL, normalize_timestamp 12 | from suds.data.image_metadata import ImageMetadata 13 | from suds.stream_utils import image_from_stream, get_filesystem 14 | 15 | GROUND_PLANE_Z = torch.DoubleTensor([[1, 0, 0, 0], 16 | [0, 0, 1, 0], 17 | [0, -1, 0, 0], 18 | [0, 0, 0, 1]]) 19 | 20 | 21 | def get_vkitti2_items(vkitti2_path: str, 22 | frame_ranges: Optional[List[Tuple[int]]], 23 | train_every: Optional[int], 24 | test_every: Optional[int], 25 | use_gt_flow: bool) -> \ 26 | Tuple[List[ImageMetadata], List[str], torch.Tensor, float, torch.Tensor]: 27 | with open('{}/intrinsic.txt'.format(vkitti2_path), 'r') as in_f, \ 28 | open('{}/extrinsic.txt'.format(vkitti2_path), 'r') as ex_f: 29 | # frame cameraID K[0,0] K[1,1] K[0,2] K[1,2] 30 | next(in_f) 31 | # frame cameraID r1,1 r1,2 r1,3 t1 r2,1 r2,2 r2,3 t2 r3,1 r3,2 r3,3 t3 0 0 0 1 32 | next(ex_f) 33 | 34 | num_frames = 0 35 | for in_line, ex_line in zip(in_f, ex_f): 36 | in_entry = in_line.strip().split() 37 | frame = int(in_entry[0]) 38 | if frame_ranges is not None and get_frame_range(frame_ranges, frame) is None: 39 | continue 40 | num_frames += 1 41 | 42 | val_frames = get_val_frames(num_frames, test_every, train_every) 43 | metadata_items: List[ImageMetadata] = [] 44 | item_frame_ranges: List[Tuple[int]] = [] 45 | static_masks = [] 46 | min_bounds = None 47 | max_bounds = None 48 | 49 | use_masks = True 50 | with open('{}/intrinsic.txt'.format(vkitti2_path), 'r') as in_f, \ 51 | open('{}/extrinsic.txt'.format(vkitti2_path), 'r') as ex_f: 52 | # frame cameraID K[0,0] K[1,1] K[0,2] K[1,2] 53 | next(in_f) 54 | # frame cameraID r1,1 r1,2 r1,3 t1 r2,1 r2,2 r2,3 t2 r3,1 r3,2 r3,3 t3 0 0 0 1 55 | next(ex_f) 56 | 57 | min_frame = None 58 | max_frame = None 59 | for in_line, ex_line in tqdm(zip(in_f, ex_f)): 60 | in_entry = in_line.strip().split() 61 | frame = int(in_entry[0]) 62 | frame_range = 
get_frame_range(frame_ranges, frame) if frame_ranges is not None else None 63 | if frame_ranges is not None and frame_range is None: 64 | continue 65 | 66 | min_frame = min(frame, min_frame) if min_frame is not None else frame 67 | max_frame = max(frame, max_frame) if max_frame is not None else frame 68 | cameraID = int(in_entry[1]) 69 | 70 | w2c = torch.DoubleTensor([float(x) for x in ex_line.strip().split()[2:]]).view(4, 4) 71 | c2w = (GROUND_PLANE_Z @ (torch.inverse(w2c) @ OPENCV_TO_OPENGL))[:3] 72 | 73 | image_index = len(metadata_items) 74 | is_val = image_index // 2 in val_frames 75 | 76 | if is_val: 77 | backward_neighbor = image_index - 2 78 | forward_neighbor = image_index + 2 79 | else: 80 | backward_neighbor = get_neighbor(image_index, val_frames, -2) 81 | forward_neighbor = get_neighbor(image_index, val_frames, 2) 82 | 83 | backward_suffix = '' if (image_index - backward_neighbor) // 2 == 1 else '-{}'.format( 84 | (image_index - backward_neighbor) // 2) 85 | forward_suffix = '' if (forward_neighbor - image_index) // 2 == 1 else '-{}'.format( 86 | (forward_neighbor - image_index) // 2) 87 | 88 | if use_gt_flow: 89 | backward_flow_path = '{0}/frames/backwardFlow{1}/Camera_{2}/backwardFlow_{3:05d}.png'.format( 90 | vkitti2_path, backward_suffix, cameraID, frame) 91 | forward_flow_path = '{0}/frames/forwardFlow{1}/Camera_{2}/flow_{3:05d}.png'.format(vkitti2_path, 92 | forward_suffix, 93 | cameraID, frame) 94 | else: 95 | backward_flow_path = '{0}/frames/dino_correspondences{1}/Camera_{2}/rgb_{3:05d}.parquet'.format( 96 | vkitti2_path, forward_suffix, cameraID, frame - (image_index - backward_neighbor) // 2) 97 | forward_flow_path = '{0}/frames/dino_correspondences{1}/Camera_{2}/rgb_{3:05d}.parquet'.format( 98 | vkitti2_path, forward_suffix, cameraID, frame) 99 | 100 | image_path = '{0}/frames/rgb/Camera_{1}/rgb_{2:05d}.jpg'.format(vkitti2_path, cameraID, frame) 101 | image = image_from_stream(image_path) 102 | 103 | sky_mask_path = '{0}/frames/sky_mask/Camera_{1}/sky_mask_{2:05d}.png'.format(vkitti2_path, cameraID, frame) \ 104 | if use_masks else None 105 | if sky_mask_path is not None and use_masks: 106 | fs = get_filesystem(sky_mask_path) 107 | if (fs is None and (not Path(sky_mask_path).exists())) or \ 108 | (fs is not None and (not fs.exists(sky_mask_path))): 109 | print('Did not find sky mask at {} - not including static or sky masks in metadata'.format( 110 | sky_mask_path)) 111 | use_masks = False 112 | sky_mask_path = None 113 | 114 | item = ImageMetadata( 115 | image_path, 116 | c2w, 117 | image.size[0], 118 | image.size[1], 119 | torch.FloatTensor([float(x) for x in in_line.strip().split()[2:]]), 120 | image_index, 121 | frame, 122 | 0, 123 | '{0}/frames/depth/Camera_{1}/depth_{2:05d}.png'.format(vkitti2_path, cameraID, frame), 124 | None, 125 | sky_mask_path, 126 | '{0}/frames/dino/Camera_{1}/dino_{2:05d}.parquet'.format(vkitti2_path, cameraID, frame), 127 | backward_flow_path, 128 | forward_flow_path, 129 | backward_neighbor, 130 | forward_neighbor, 131 | is_val, 132 | 1, 133 | None 134 | ) 135 | 136 | metadata_items.append(item) 137 | item_frame_ranges.append(frame_range) 138 | 139 | if use_masks: 140 | static_mask_path = '{0}/frames/static_mask/Camera_{1}/static_mask_{2:05d}.png'.format(vkitti2_path, 141 | cameraID, 142 | frame) 143 | 144 | static_masks.append(static_mask_path) 145 | 146 | min_bounds, max_bounds = get_bounds_from_depth(item, min_bounds, max_bounds) 147 | 148 | for item in metadata_items: 149 | normalize_timestamp(item, min_frame, max_frame) 150 | 151 
| for item in metadata_items: 152 | if item.backward_neighbor_index < 0 \ 153 | or item_frame_ranges[item.image_index] != item_frame_ranges[item.backward_neighbor_index]: 154 | item.backward_flow_path = None 155 | item.backward_neighbor_index = None 156 | 157 | if item.forward_neighbor_index >= len(metadata_items) \ 158 | or item_frame_ranges[item.image_index] != item_frame_ranges[item.forward_neighbor_index]: 159 | item.forward_flow_path = None 160 | item.forward_neighbor_index = None 161 | 162 | origin, pose_scale_factor, scene_bounds = scale_bounds(metadata_items, min_bounds, max_bounds) 163 | 164 | return metadata_items, static_masks, origin, pose_scale_factor, scene_bounds 165 | 166 | 167 | def _get_opts() -> Namespace: 168 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 169 | parser.add_argument('--config_file', is_config_file=True) 170 | 171 | parser.add_argument('--output_path', type=str, required=True) 172 | parser.add_argument('--vkitti2_path', type=str, required=True) 173 | parser.add_argument('--frame_ranges', type=int, nargs='+', default=None) 174 | parser.add_argument('--train_every', type=int, default=None) 175 | parser.add_argument('--test_every', type=int, default=None) 176 | parser.add_argument('--use_gt_flow', default=False, action='store_true') 177 | 178 | return parser.parse_args() 179 | 180 | 181 | def main(hparams: Namespace) -> None: 182 | assert hparams.train_every is not None or hparams.test_every is not None, \ 183 | 'Exactly one of train_every or test_every must be specified' 184 | 185 | assert hparams.train_every is None or hparams.test_every is None, \ 186 | 'Only one of train_every or test_every must be specified' 187 | 188 | if hparams.frame_ranges is not None: 189 | frame_ranges = [] 190 | for i in range(0, len(hparams.frame_ranges), 2): 191 | frame_ranges.append([hparams.frame_ranges[i], hparams.frame_ranges[i + 1]]) 192 | else: 193 | frame_ranges = None 194 | 195 | metadata_items, static_masks, origin, pose_scale_factor, scene_bounds = get_vkitti2_items(hparams.vkitti2_path, 196 | frame_ranges, 197 | hparams.train_every, 198 | hparams.test_every, 199 | hparams.use_gt_flow) 200 | 201 | write_metadata(hparams.output_path, metadata_items, static_masks, origin, pose_scale_factor, scene_bounds) 202 | 203 | 204 | if __name__ == '__main__': 205 | main(_get_opts()) 206 | -------------------------------------------------------------------------------- /scripts/extract_dino_correspondences.py: -------------------------------------------------------------------------------- 1 | """Code adapted and modified from https://github.com/ShirAmir/dino-vit-features""" 2 | 3 | import argparse 4 | import datetime 5 | import json 6 | import os 7 | import traceback 8 | from pathlib import Path 9 | from typing import List, Tuple 10 | 11 | import configargparse 12 | import numpy as np 13 | import pyarrow as pa 14 | import pyarrow.parquet as pq 15 | import torch 16 | import torch.distributed as dist 17 | from PIL import Image 18 | from sklearn.cluster import KMeans 19 | from smart_open import open 20 | from tqdm import tqdm 21 | 22 | from extract_dino_features import ViTExtractor 23 | from suds.stream_utils import get_filesystem 24 | 25 | 26 | def find_correspondences(extractor: ViTExtractor, image_path1: str, image_path2: str, num_pairs: int = 10, 27 | load_size: int = 224, layer: int = 9, 28 | facet: str = 'key', bin: bool = True, thresh: float = 0.05) -> Tuple[ 29 | List[Tuple[float, float]], List[Tuple[float, float]], 
Image.Image, Image.Image]: 30 | """ 31 | finding point correspondences between two images. 32 | :param image_path1: path to the first image. 33 | :param image_path2: path to the second image. 34 | :param num_pairs: number of outputted corresponding pairs. 35 | :param load_size: size of the smaller edge of loaded images. If None, does not resize. 36 | :param layer: layer to extract descriptors from. 37 | :param facet: facet to extract descriptors from. 38 | :param bin: if True use a log-binning descriptor. 39 | :param thresh: threshold of saliency maps to distinguish fg and bg. 40 | :param model_type: type of model to extract descriptors from. 41 | :param stride: stride of the model. 42 | :return: list of points from image_path1, list of corresponding points from image_path2, the processed pil image of 43 | image_path1, and the processed pil image of image_path2. 44 | """ 45 | # extracting descriptors for each image 46 | image1_batch, image1_pil = extractor.preprocess(image_path1, load_size) 47 | descriptors1 = extractor.extract_descriptors(image1_batch.to(extractor.device), layer, facet, bin) 48 | num_patches1, load_size1 = extractor.num_patches, extractor.load_size 49 | image2_batch, image2_pil = extractor.preprocess(image_path2, load_size) 50 | descriptors2 = extractor.extract_descriptors(image2_batch.to(extractor.device), layer, facet, bin) 51 | num_patches2, load_size2 = extractor.num_patches, extractor.load_size 52 | 53 | # extracting saliency maps for each image 54 | saliency_map1 = extractor.extract_saliency_maps(image1_batch.to(extractor.device))[0] 55 | saliency_map2 = extractor.extract_saliency_maps(image2_batch.to(extractor.device))[0] 56 | # threshold saliency maps to get fg / bg masks 57 | fg_mask1 = saliency_map1 > thresh 58 | fg_mask2 = saliency_map2 > thresh 59 | 60 | # calculate similarity between image1 and image2 descriptors 61 | similarities = chunk_cosine_sim(descriptors1, descriptors2) 62 | 63 | # calculate best buddies 64 | image_idxs = torch.arange(num_patches1[0] * num_patches1[1], device=extractor.device) 65 | sim_1, nn_1 = torch.max(similarities, dim=-1) # nn_1 - indices of block2 closest to block1 66 | sim_2, nn_2 = torch.max(similarities, dim=-2) # nn_2 - indices of block1 closest to block2 67 | sim_1, nn_1 = sim_1[0, 0], nn_1[0, 0] 68 | sim_2, nn_2 = sim_2[0, 0], nn_2[0, 0] 69 | bbs_mask = nn_2[nn_1] == image_idxs 70 | 71 | # remove best buddies where at least one descriptor is marked bg by saliency mask. 72 | fg_mask2_new_coors = nn_2[fg_mask2] 73 | fg_mask2_mask_new_coors = torch.zeros(num_patches1[0] * num_patches1[1], dtype=torch.bool, device=extractor.device) 74 | fg_mask2_mask_new_coors[fg_mask2_new_coors] = True 75 | bbs_mask = torch.bitwise_and(bbs_mask, fg_mask1) 76 | bbs_mask = torch.bitwise_and(bbs_mask, fg_mask2_mask_new_coors) 77 | 78 | # applying k-means to extract k high quality well distributed correspondence pairs 79 | bb_descs1 = descriptors1[0, 0, bbs_mask, :].cpu().numpy() 80 | bb_descs2 = descriptors2[0, 0, nn_1[bbs_mask], :].cpu().numpy() 81 | # apply k-means on a concatenation of a pairs descriptors. 82 | all_keys_together = np.concatenate((bb_descs1, bb_descs2), axis=1) 83 | n_clusters = min(num_pairs, len(all_keys_together)) # if not enough pairs, show all found pairs. 
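# The pair descriptors are L2-normalized below so that Euclidean k-means behaves like clustering
# by cosine similarity (for unit vectors, ||a - b||^2 = 2 - 2 * a.dot(b)). Keeping only the most
# salient best-buddy pair per cluster (the ranking loop further down) spreads the final
# correspondences across distinct regions instead of concentrating them on a single strong match.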
84 | length = np.sqrt((all_keys_together ** 2).sum(axis=1))[:, None] 85 | normalized = all_keys_together / length 86 | kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10).fit(normalized) 87 | bb_topk_sims = np.full((n_clusters), -np.inf) 88 | bb_indices_to_show = np.full((n_clusters), -np.inf) 89 | 90 | # rank pairs by their mean saliency value 91 | bb_cls_attn1 = saliency_map1[bbs_mask] 92 | bb_cls_attn2 = saliency_map2[nn_1[bbs_mask]] 93 | bb_cls_attn = (bb_cls_attn1 + bb_cls_attn2) / 2 94 | ranks = bb_cls_attn 95 | 96 | for i, (label, rank) in enumerate(zip(kmeans.labels_, ranks)): 97 | if rank > bb_topk_sims[label]: 98 | bb_topk_sims[label] = rank 99 | bb_indices_to_show[label] = i 100 | 101 | # get coordinates to show 102 | indices_to_show = torch.nonzero(bbs_mask, as_tuple=False).squeeze(dim=1)[ 103 | bb_indices_to_show] # close bbs 104 | img1_indices_to_show = torch.arange(num_patches1[0] * num_patches1[1], device=extractor.device)[indices_to_show] 105 | img2_indices_to_show = nn_1[indices_to_show] 106 | # coordinates in descriptor map's dimensions 107 | img1_y_to_show = (img1_indices_to_show / num_patches1[1]).cpu().numpy() 108 | img1_x_to_show = (img1_indices_to_show % num_patches1[1]).cpu().numpy() 109 | img2_y_to_show = (img2_indices_to_show / num_patches2[1]).cpu().numpy() 110 | img2_x_to_show = (img2_indices_to_show % num_patches2[1]).cpu().numpy() 111 | points1, points2 = [], [] 112 | for y1, x1, y2, x2 in zip(img1_y_to_show, img1_x_to_show, img2_y_to_show, img2_x_to_show): 113 | x1_show = (int(x1) - 1) * extractor.stride[1] + extractor.stride[1] + extractor.p // 2 114 | y1_show = (int(y1) - 1) * extractor.stride[0] + extractor.stride[0] + extractor.p // 2 115 | x2_show = (int(x2) - 1) * extractor.stride[1] + extractor.stride[1] + extractor.p // 2 116 | y2_show = (int(y2) - 1) * extractor.stride[0] + extractor.stride[0] + extractor.p // 2 117 | points1.append((y1_show, x1_show)) 118 | points2.append((y2_show, x2_show)) 119 | return points1, points2, image1_pil, image2_pil 120 | 121 | 122 | def chunk_cosine_sim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 123 | """ Computes cosine similarity between all possible pairs in two sets of vectors. 124 | Operates on chunks so no large amount of GPU RAM is required. 125 | :param x: a tensor of descriptors of shape Bx1x(t_x)xd' where d' is the dimensionality of the descriptors and t_x 126 | is the number of tokens in x. 127 | :param y: a tensor of descriptors of shape Bx1x(t_y)xd' where d' is the dimensionality of the descriptors and t_y 128 | is the number of tokens in y. 129 | :return: cosine similarity between all descriptors in x and all descriptors in y. Has shape of Bx1x(t_x)x(t_y) """ 130 | result_list = [] 131 | num_token_x = x.shape[2] 132 | for token_idx in range(num_token_x): 133 | token = x[:, :, token_idx, :].unsqueeze(dim=2) # Bx1x1xd' 134 | result_list.append(torch.nn.CosineSimilarity(dim=3)(token, y)) # Bx1xt 135 | return torch.stack(result_list, dim=2) # Bx1x(t_x)x(t_y) 136 | 137 | 138 | def _get_opts() -> argparse.Namespace: 139 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 140 | parser.add_argument('--config_file', is_config_file=True) 141 | 142 | parser.add_argument('--metadata_path', type=str, required=True) 143 | parser.add_argument('--load_size', default=375, type=int, help='load size of the input image.') 144 | parser.add_argument('--stride', default=4, type=int, help="""stride of first convolution layer. 
145 | small stride -> higher resolution.""") 146 | parser.add_argument('--model_type', default='dino_vits8', type=str, 147 | help="""type of model to extract. 148 | Choose from [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 | 149 | vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]""") 150 | parser.add_argument('--facet', default='key', type=str, help="""facet to create descriptors from. 151 | options: ['key' | 'query' | 'value' | 'token']""") 152 | parser.add_argument('--layer', default=11, type=int, help="layer to create descriptors from.") 153 | parser.add_argument('--bin', default=False, action='store_true', help='create a binned descriptor if True.') 154 | parser.add_argument('--thresh', default=0.05, type=float, help='saliency maps threshold to distinguish fg / bg.') 155 | parser.add_argument('--num_pairs', default=50000, type=int, help='Final number of correspondences.') 156 | 157 | return parser.parse_known_args()[0] 158 | 159 | 160 | @torch.inference_mode() 161 | def main(hparams: argparse.Namespace) -> None: 162 | if 'RANK' in os.environ: 163 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(0, hours=24)) 164 | torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) 165 | rank = int(os.environ['RANK']) 166 | world_size = int(os.environ['WORLD_SIZE']) 167 | else: 168 | rank = 0 169 | world_size = 1 170 | 171 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 172 | 173 | if rank == 0: 174 | # Create the extractor on only one device to avoid race conditions when downloading the weights across 175 | # multiple processes 176 | extractor = ViTExtractor(hparams.model_type, hparams.stride, device=device) 177 | if world_size > 1: 178 | dist.barrier() 179 | else: 180 | if world_size > 1: 181 | dist.barrier() 182 | extractor = ViTExtractor(hparams.model_type, hparams.stride, device=device) 183 | 184 | with open(hparams.metadata_path) as f: 185 | metadata = json.load(f) 186 | 187 | frames = metadata['frames'] 188 | for i in tqdm(np.arange(rank, len(frames), world_size)): 189 | frame = frames[i] 190 | for flow_path, neighbor_index in [(frame['forward_flow_path'], frame['forward_neighbor_index']), ( 191 | frame['backward_flow_path'], frame['backward_neighbor_index'])]: 192 | if flow_path is None: 193 | continue 194 | 195 | fs = get_filesystem(flow_path) 196 | 197 | if (fs is None and Path(flow_path).exists()) or (fs is not None and fs.exists(flow_path)): 198 | try: 199 | pq.read_table(flow_path, filesystem=fs)['point1_x'].to_numpy().sum() 200 | continue 201 | except: 202 | traceback.print_exc() 203 | 204 | if fs is None: 205 | parent = Path(flow_path).parent 206 | if not parent.exists(): 207 | parent.mkdir(parents=True, exist_ok=True) 208 | 209 | first = frames[min(i, neighbor_index)]['rgb_path'] 210 | second = frames[max(i, neighbor_index)]['rgb_path'] 211 | points1, points2, image1_pil, image2_pil = find_correspondences(extractor, first, second, 212 | hparams.num_pairs, hparams.load_size, 213 | hparams.layer, hparams.facet, hparams.bin, 214 | hparams.thresh) 215 | 216 | pq.write_table(pa.table({ 217 | 'point1_x': np.array([x[1] for x in points1]), 218 | 'point1_y': np.array([x[0] for x in points1]), 219 | 'point2_x': np.array([x[1] for x in points2]), 220 | 'point2_y': np.array([x[0] for x in points2]) 221 | }, metadata={'shape': ' '.join([str(image1_pil.size[1]), str(image1_pil.size[0])])}), 222 | flow_path, filesystem=fs, compression='BROTLI') 223 | 224 | if world_size > 1: 225 | dist.barrier() 226 | 227 | 228 | if 
__name__ == '__main__': 229 | main(_get_opts()) 230 | -------------------------------------------------------------------------------- /scripts/metadata_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Tuple, Optional, Set 4 | 5 | import numpy as np 6 | import torch 7 | from nerfstudio.cameras.cameras import CameraType, Cameras 8 | from smart_open import open 9 | 10 | from suds.data.image_metadata import ImageMetadata 11 | from suds.stream_utils import get_filesystem 12 | 13 | OPENCV_TO_OPENGL = torch.DoubleTensor([[1, 0, 0, 0], 14 | [0, -1, 0, 0], 15 | [0, 0, -1, 0], 16 | [0, 0, 0, 1]]) 17 | 18 | 19 | def write_metadata(output_path: str, metadata_items: List[ImageMetadata], static_masks: List[str], origin: torch.Tensor, 20 | pose_scale_factor: float, scene_bounds: torch.Tensor) -> None: 21 | if len(static_masks) > 0: 22 | assert len(metadata_items) == len(static_masks), \ 23 | 'Number of metadata items and static masks not equal: {} {}'.format(len(metadata_items), len(static_masks)) 24 | 25 | frames = [] 26 | for i, item in enumerate(metadata_items): 27 | frame_metadata = { 28 | 'image_index': item.image_index, 29 | 'rgb_path': item.image_path, 30 | 'depth_path': item.depth_path, 31 | 'feature_path': item.feature_path, 32 | 'backward_flow_path': item.backward_flow_path, 33 | 'forward_flow_path': item.forward_flow_path, 34 | 'backward_neighbor_index': item.backward_neighbor_index, 35 | 'forward_neighbor_index': item.forward_neighbor_index, 36 | 'c2w': item.c2w.tolist(), 37 | 'W': item.W, 38 | 'H': item.H, 39 | 'intrinsics': item.intrinsics.tolist(), 40 | 'time': item.time, 41 | 'video_id': item.video_id, 42 | 'is_val': item.is_val 43 | } 44 | 45 | if len(static_masks) > 0: 46 | frame_metadata['static_mask_path'] = static_masks[i] 47 | 48 | if item.mask_path is not None: 49 | frame_metadata['mask_path'] = item.mask_path 50 | 51 | if item.sky_mask_path is not None: 52 | frame_metadata['sky_mask_path'] = item.sky_mask_path 53 | 54 | frames.append(frame_metadata) 55 | 56 | if get_filesystem(output_path) is None: 57 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 58 | 59 | with open(output_path, 'w') as f: 60 | json.dump({ 61 | 'origin': origin.tolist(), 62 | 'scene_bounds': scene_bounds.tolist(), 63 | 'pose_scale_factor': pose_scale_factor, 64 | 'frames': frames 65 | }, f, indent=2) 66 | 67 | 68 | def get_bounds_from_depth(item: ImageMetadata, cur_min_bounds: Optional[torch.Tensor], 69 | cur_max_bounds: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: 70 | ray_bundle = Cameras(camera_to_worlds=item.c2w, 71 | fx=item.intrinsics[0], 72 | fy=item.intrinsics[1], 73 | cx=item.intrinsics[2], 74 | cy=item.intrinsics[3], 75 | width=item.W, 76 | height=item.H, 77 | camera_type=CameraType.PERSPECTIVE).generate_rays(0) 78 | 79 | directions = ray_bundle.directions.view(-1, 3) 80 | depth = item.load_depth().view(-1) 81 | 82 | filtered_directions = directions[depth > 0] 83 | filtered_depth = depth[depth > 0].unsqueeze(-1) 84 | filtered_z_scale = ray_bundle.metadata['directions_norm'].view(-1, 1)[depth > 0] 85 | 86 | points = item.c2w[:, 3].unsqueeze(0) + filtered_directions * filtered_depth * filtered_z_scale 87 | bounds = [item.c2w[:, 3].unsqueeze(0), points] 88 | 89 | if cur_min_bounds is not None: 90 | bounds.append(cur_min_bounds.unsqueeze(0)) 91 | bounds.append(cur_max_bounds.unsqueeze(0)) 92 | 93 | bounds = torch.cat(bounds) 94 | return bounds.min(dim=0)[0], 
bounds.max(dim=0)[0] 95 | 96 | 97 | def scale_bounds( 98 | all_items: List[ImageMetadata], 99 | min_bounds: torch.Tensor, 100 | max_bounds: torch.Tensor) -> Tuple[torch.Tensor, float, torch.Tensor]: 101 | positions = torch.cat([x.c2w[:, 3].unsqueeze(0) for x in all_items]) 102 | 103 | print('Camera range in metric space: {} {}'.format(positions.min(dim=0)[0], positions.max(dim=0)[0])) 104 | 105 | origin = (max_bounds + min_bounds) * 0.5 106 | print('Calculated origin: {} {} {}'.format(origin, min_bounds, max_bounds)) 107 | 108 | pose_scale_factor = torch.linalg.norm((max_bounds - min_bounds) * 0.5).item() 109 | print('Calculated pose scale factor: {}'.format(pose_scale_factor)) 110 | 111 | for item in all_items: 112 | item.c2w[:, 3] = (item.c2w[:, 3] - origin) / pose_scale_factor 113 | assert torch.logical_and(item.c2w >= -1, item.c2w <= 1).all(), item.c2w 114 | 115 | scene_bounds = (torch.stack([min_bounds, max_bounds]) - origin) / pose_scale_factor 116 | 117 | return origin, pose_scale_factor, scene_bounds 118 | 119 | 120 | def normalize_timestamp(item: ImageMetadata, min_frame: int, max_frame: int) -> None: 121 | divisor = 0.5 * (max_frame - min_frame) 122 | assert divisor > 0 123 | item.time = (item.time - min_frame) / divisor - 1 124 | assert -1 <= item.time <= 1 125 | 126 | 127 | def get_frame_range(frame_ranges: List[Tuple[int]], frame: int) -> Optional[Tuple[int]]: 128 | for frame_range in frame_ranges: 129 | if frame_range[0] <= frame <= frame_range[1]: 130 | return frame_range 131 | 132 | return None 133 | 134 | 135 | def get_val_frames(num_frames: int, test_every: int, train_every: int) -> Set[int]: 136 | assert train_every is None or test_every is None 137 | if train_every is None: 138 | val_frames = set(np.arange(test_every, num_frames, test_every)) 139 | else: 140 | train_frames = set(np.arange(0, num_frames, train_every)) 141 | val_frames = (set(np.arange(num_frames)) - train_frames) if train_every > 1 else train_frames 142 | 143 | return val_frames 144 | 145 | 146 | def get_neighbor(image_index: int, val_frames: Set[int], dir: int) -> int: 147 | diff = dir 148 | while (image_index + diff) // 2 in val_frames: 149 | diff += dir 150 | 151 | return image_index + diff 152 | -------------------------------------------------------------------------------- /scripts/run_pca.py: -------------------------------------------------------------------------------- 1 | """Code adapted and modified from https://github.com/ShirAmir/dino-vit-features""" 2 | 3 | import argparse 4 | import datetime 5 | import json 6 | import os 7 | import traceback 8 | from io import BytesIO 9 | from pathlib import Path 10 | 11 | import configargparse 12 | import numpy as np 13 | import pyarrow as pa 14 | import pyarrow.parquet as pq 15 | import torch 16 | import torch.distributed as dist 17 | from sklearn.decomposition import PCA 18 | from smart_open import open 19 | from tqdm import tqdm 20 | 21 | from suds.stream_utils import get_filesystem, buffer_from_stream 22 | 23 | 24 | def _get_opts() -> argparse.Namespace: 25 | parser = configargparse.ArgParser(config_file_parser_class=configargparse.YAMLConfigFileParser) 26 | parser.add_argument('--config_file', is_config_file=True) 27 | 28 | parser.add_argument('--metadata_path', type=str, required=True) 29 | parser.add_argument('--n_components', default=64, type=int, help="number of pca components to produce.") 30 | parser.add_argument('--no_tmp_cleanup', dest='tmp_cleanup', default=True, action='store_false') 31 | 32 | return parser.parse_known_args()[0] 33 | 34 | 
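# Sketch of the per-image split performed in main() below: descriptors from all frames are
# concatenated into a single (sum_i H_i * W_i, C) matrix for one PCA fit, then split back using
# cumulative patch counts. For hypothetical patch grids of (3, 4) and (5, 5):
#
#     split_idxs = np.cumsum([3 * 4, 5 * 5])                              # array([12, 37])
#     pca_per_image = np.split(pca_descriptors, split_idxs[:-1], axis=0)  # chunks of 12 and 25 rows
#
# Each chunk is reshaped to (H_i, W_i, n_components) and written back as a parquet table, after
# which the temporary .pt descriptor file is optionally removed.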
35 | @torch.inference_mode() 36 | def main(hparams: argparse.Namespace) -> None: 37 | with open(hparams.metadata_path) as f: 38 | metadata = json.load(f) 39 | 40 | descriptors_list = [] 41 | num_patches_list = [] 42 | 43 | frames = metadata['frames'] 44 | for frame in tqdm(frames): 45 | descriptor = torch.load(buffer_from_stream('{}.pt'.format(frame['feature_path'])), map_location='cpu') 46 | num_patches_list.append([descriptor.shape[0], descriptor.shape[1]]) 47 | descriptor /= descriptor.norm(dim=-1, keepdim=True) 48 | descriptors_list.append(descriptor.view(-1, descriptor.shape[2]).numpy()) 49 | 50 | descriptors = np.concatenate(descriptors_list, axis=0) 51 | print('Running PCA on descriptors of dim: {}'.format(descriptors.shape)) 52 | pca_descriptors = PCA(n_components=hparams.n_components, random_state=42).fit_transform(descriptors) 53 | split_idxs = np.array([num_patches[0] * num_patches[1] for num_patches in num_patches_list]) 54 | split_idxs = np.cumsum(split_idxs) 55 | pca_per_image = np.split(pca_descriptors, split_idxs[:-1], axis=0) 56 | 57 | results = [(frame, img_pca.reshape((num_patches[0], num_patches[1], hparams.n_components))) for 58 | (frame, img_pca, num_patches) in zip(frames, pca_per_image, num_patches_list)] 59 | 60 | for frame, img_pca in tqdm(results): 61 | fs = get_filesystem(frame['feature_path']) 62 | pq.write_table( 63 | pa.table({'pca': img_pca.flatten()}, metadata={'shape': ' '.join([str(x) for x in img_pca.shape])}), 64 | frame['feature_path'], filesystem=fs, compression='BROTLI') 65 | 66 | if hparams.tmp_cleanup: 67 | tmp_path = '{}.pt'.format(frame['feature_path']) 68 | if fs is None: 69 | Path(tmp_path).unlink() 70 | else: 71 | fs.rm(tmp_path) 72 | 73 | 74 | if __name__ == '__main__': 75 | main(_get_opts()) 76 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='suds_cuda', 6 | ext_modules=[ 7 | CUDAExtension('suds_cuda', [ 8 | 'suds/cpp/suds_cpp.cpp', 9 | 'suds/cpp/suds_cuda.cu' 10 | ]) 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /suds/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hturki/suds/9b472ad5e4fd4d810682af984ef65bfaf4d75188/suds/__init__.py -------------------------------------------------------------------------------- /suds/composite_proposal_network_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, List, Callable 2 | 3 | import torch 4 | from nerfstudio.cameras.rays import RayBundle, RaySamples 5 | from nerfstudio.model_components.ray_samplers import ProposalNetworkSampler 6 | 7 | 8 | class CompositeProposalNetworkSampler(ProposalNetworkSampler): 9 | 10 | def generate_ray_samples( 11 | self, 12 | ray_bundle: Optional[RayBundle], 13 | static_density_fns: List[Callable], 14 | dynamic_density_fns: List[Callable], 15 | static_only: bool, 16 | dynamic_only: bool, 17 | filter_fn: Optional[Callable] 18 | ) -> Tuple[RaySamples, List, List, List, List]: 19 | weights_list = [] 20 | ray_samples_list = [] 21 | static_weights_list = [] 22 | dynamic_weights_list = [] 23 | 24 | self.initial_sampler.training = False 25 | self.pdf_sampler.training 
= False 26 | 27 | n = self.num_proposal_network_iterations 28 | weights = None 29 | ray_samples = None 30 | 31 | for i_level in range(n + 1): 32 | is_prop = i_level < n 33 | num_samples = self.num_proposal_samples_per_ray[i_level] if is_prop else self.num_nerf_samples_per_ray 34 | if i_level == 0: 35 | # Uniform sampling because we need to start with some samples 36 | ray_samples = self.initial_sampler(ray_bundle, num_samples=num_samples) 37 | else: 38 | # PDF sampling based on the last samples and their weights 39 | # Perform annealing to the weights. This will be a no-op if self._anneal is 1.0. 40 | assert weights is not None 41 | annealed_weights = torch.pow(weights, self._anneal) 42 | ray_samples = self.pdf_sampler(ray_bundle, ray_samples, annealed_weights, num_samples=num_samples) 43 | 44 | if is_prop: 45 | if not dynamic_only: 46 | static_density = static_density_fns[i_level](ray_samples.frustums.get_positions()) 47 | 48 | if not static_only: 49 | dynamic_density = dynamic_density_fns[i_level](ray_samples.frustums.get_positions()) 50 | 51 | if static_only: 52 | to_use = static_density 53 | elif dynamic_only: 54 | to_use = dynamic_density 55 | else: 56 | to_use = static_density + dynamic_density 57 | 58 | if filter_fn is not None: 59 | to_keep = filter_fn(ray_samples) 60 | to_use[to_keep <= 0] = 0 61 | 62 | weights = ray_samples.get_weights(to_use) 63 | weights_list.append(weights) # (num_rays, num_samples) 64 | ray_samples_list.append(ray_samples) 65 | 66 | if not dynamic_only: 67 | static_weights_list.append(ray_samples.get_weights(static_density)) 68 | 69 | if not static_only: 70 | dynamic_weights_list.append(ray_samples.get_weights(dynamic_density)) 71 | 72 | assert ray_samples is not None 73 | return ray_samples, weights_list, ray_samples_list, static_weights_list, dynamic_weights_list 74 | -------------------------------------------------------------------------------- /suds/cpp/suds_cpp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 8 | 9 | #define CHECK_LAST_DIM(x, y) TORCH_CHECK(x.size(-1) == y, #x " should have last dim = " #y) 10 | 11 | torch::Tensor video_embedding_forward_cuda( 12 | torch::Tensor time, 13 | torch::Tensor video_id, 14 | torch::Tensor weights, 15 | const int num_frequencies); 16 | 17 | torch::Tensor video_embedding_backward_cuda( 18 | torch::Tensor d_loss_embedding, 19 | torch::Tensor time, 20 | torch::Tensor video_id, 21 | const int num_sequences, 22 | const int num_frequencies); 23 | 24 | torch::Tensor video_embedding_forward( 25 | torch::Tensor time, 26 | torch::Tensor video_id, 27 | torch::Tensor weights, 28 | const int num_frequencies) { 29 | CHECK_INPUT(time); 30 | CHECK_INPUT(video_id); 31 | CHECK_INPUT(weights); 32 | CHECK_LAST_DIM(weights, num_frequencies * 2 + 1); 33 | const at::cuda::OptionalCUDAGuard device_guard(device_of(time)); 34 | return video_embedding_forward_cuda(time, video_id, weights, num_frequencies); 35 | } 36 | 37 | torch::Tensor video_embedding_backward( 38 | torch::Tensor d_loss_embedding, 39 | torch::Tensor time, 40 | torch::Tensor video_id, 41 | const int num_sequences, 42 | const int num_frequencies) { 43 | CHECK_INPUT(d_loss_embedding); 44 | CHECK_INPUT(time); 45 | CHECK_INPUT(video_id); 46 | 47 | const 
at::cuda::OptionalCUDAGuard device_guard(device_of(time)); 48 | return video_embedding_backward_cuda(d_loss_embedding, time, video_id, num_sequences, num_frequencies); 49 | } 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("video_embedding_forward", &video_embedding_forward); 53 | m.def("video_embedding_backward", &video_embedding_backward); 54 | } 55 | -------------------------------------------------------------------------------- /suds/cpp/suds_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #define N_BLOCKS_NEEDED(Q, N_CUDA_THREADS) ((Q - 1) / N_CUDA_THREADS + 1) 8 | #define CUDA_GET_THREAD_ID(tid, Q) \ 9 | const int tid = blockIdx.x * blockDim.x + threadIdx.x; \ 10 | if (tid >= Q) return 11 | 12 | // Automatically choose number of CUDA threads based on HW CUDA kernel count 13 | int cuda_n_threads = -1; 14 | 15 | __host__ int get_sp_cores(cudaDeviceProp devProp) { 16 | int cores = 0; 17 | int mp = devProp.multiProcessorCount; 18 | switch (devProp.major) { 19 | case 2: // Fermi 20 | if (devProp.minor == 1) 21 | cores = mp * 48; 22 | else 23 | cores = mp * 32; 24 | break; 25 | case 3: // Kepler 26 | cores = mp * 192; 27 | break; 28 | case 5: // Maxwell 29 | cores = mp * 128; 30 | break; 31 | case 6: // Pascal 32 | if ((devProp.minor == 1) || (devProp.minor == 2)) 33 | cores = mp * 128; 34 | else if (devProp.minor == 0) 35 | cores = mp * 64; 36 | break; 37 | case 7: // Volta and Turing 38 | if ((devProp.minor == 0) || (devProp.minor == 5)) cores = mp * 64; 39 | break; 40 | case 8: // Ampere 41 | if (devProp.minor == 0) 42 | cores = mp * 64; 43 | else if (devProp.minor == 6) 44 | cores = mp * 128; 45 | break; 46 | default: 47 | break; 48 | } 49 | return cores; 50 | } 51 | 52 | __host__ void auto_cuda_threads() { 53 | if (~cuda_n_threads) return; 54 | cudaDeviceProp dev_prop; 55 | cudaGetDeviceProperties(&dev_prop, 0); 56 | const int n_cores = get_sp_cores(dev_prop); 57 | // Optimize number of CUDA threads per block 58 | if (n_cores < 2048) { 59 | cuda_n_threads = 256; 60 | } 61 | if (n_cores < 8192) { 62 | cuda_n_threads = 512; 63 | } else { 64 | cuda_n_threads = 1024; 65 | } 66 | } 67 | 68 | template 69 | __global__ void video_embedding_forward_kernel( 70 | const torch::PackedTensorAccessor32 time, 71 | const torch::PackedTensorAccessor32 video_id, 72 | const torch::PackedTensorAccessor32 weights, 73 | torch::PackedTensorAccessor32 embedding, 74 | const int appearance_dim, 75 | const int num_frequencies, 76 | size_t num_items 77 | ) { 78 | CUDA_GET_THREAD_ID(idx, num_items); 79 | 80 | const int i = idx / appearance_dim; 81 | const int j = idx - i * appearance_dim; 82 | 83 | scalar_t result = weights[video_id[i]][j][0] * time[i]; 84 | for (uint32_t log2_frequency = 0; log2_frequency < num_frequencies; log2_frequency++) { 85 | const scalar_t x = scalbn(time[i], log2_frequency); 86 | result += weights[video_id[i]][j][log2_frequency * 2 + 1] * sin(x); 87 | result += weights[video_id[i]][j][log2_frequency * 2 + 2] * cos(x); 88 | } 89 | 90 | embedding[i][j] = result; 91 | } 92 | 93 | template 94 | __global__ void video_embedding_backward_kernel( 95 | const torch::PackedTensorAccessor32 d_loss_embedding, 96 | const torch::PackedTensorAccessor32 time, 97 | const torch::PackedTensorAccessor32 video_id, 98 | torch::PackedTensorAccessor32 d_loss_weights, 99 | const int appearance_dim, 100 | const int num_frequencies, 101 | size_t num_items 102 | ) { 103 | CUDA_GET_THREAD_ID(idx, 
num_items); 104 | 105 | const int i = idx / appearance_dim; 106 | const int j = idx - i * appearance_dim; 107 | 108 | atomicAdd(&d_loss_weights[video_id[i]][j][0], d_loss_embedding[i][j] * time[i]); 109 | for (uint32_t log2_frequency = 0; log2_frequency < num_frequencies; log2_frequency++) { 110 | const scalar_t x = scalbn(time[i], log2_frequency); 111 | atomicAdd(&d_loss_weights[video_id[i]][j][log2_frequency * 2 + 1], d_loss_embedding[i][j] * sin(x)); 112 | atomicAdd(&d_loss_weights[video_id[i]][j][log2_frequency * 2 + 2], d_loss_embedding[i][j] * cos(x)); 113 | } 114 | } 115 | 116 | torch::Tensor video_embedding_forward_cuda( 117 | torch::Tensor time, 118 | torch::Tensor video_id, 119 | torch::Tensor weights, 120 | const int num_frequencies) { 121 | auto embedding = torch::empty({time.size(0), weights.size(1)}, 122 | torch::TensorOptions().device(weights.device()).dtype(weights.scalar_type())); 123 | 124 | auto_cuda_threads(); 125 | const int blocks = N_BLOCKS_NEEDED(time.size(0) * weights.size(1), cuda_n_threads); 126 | 127 | AT_DISPATCH_FLOATING_TYPES(weights.scalar_type(), "video_embedding_forward_cuda", ([&] { 128 | video_embedding_forward_kernel<<>>( 129 | time.packed_accessor32(), 130 | video_id.packed_accessor32(), 131 | weights.packed_accessor32(), 132 | embedding.packed_accessor32(), 133 | weights.size(1), 134 | num_frequencies, 135 | time.size(0) * weights.size(1)); 136 | })); 137 | 138 | return embedding; 139 | } 140 | 141 | torch::Tensor video_embedding_backward_cuda( 142 | torch::Tensor d_loss_embedding, 143 | torch::Tensor time, 144 | torch::Tensor video_id, 145 | const int num_sequences, 146 | const int num_frequencies) { 147 | auto d_loss_weights = torch::zeros({num_sequences, d_loss_embedding.size(1), num_frequencies * 2 + 1}, 148 | torch::TensorOptions().device(d_loss_embedding.device()).dtype(d_loss_embedding.scalar_type())); 149 | 150 | auto_cuda_threads(); 151 | const int blocks = N_BLOCKS_NEEDED(time.size(0) * d_loss_embedding.size(1), cuda_n_threads); 152 | 153 | AT_DISPATCH_FLOATING_TYPES(d_loss_embedding.scalar_type(), "video_embedding_backward_cuda", ([&] { 154 | video_embedding_backward_kernel<<>>( 155 | d_loss_embedding.packed_accessor32(), 156 | time.packed_accessor32(), 157 | video_id.packed_accessor32(), 158 | d_loss_weights.packed_accessor32(), 159 | d_loss_embedding.size(1), 160 | num_frequencies, 161 | time.size(0) * d_loss_embedding.size(1)); 162 | })); 163 | 164 | return d_loss_weights; 165 | } 166 | -------------------------------------------------------------------------------- /suds/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hturki/suds/9b472ad5e4fd4d810682af984ef65bfaf4d75188/suds/data/__init__.py -------------------------------------------------------------------------------- /suds/data/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | from suds.data.image_metadata import ImageMetadata 6 | 7 | 8 | def get_w2c_and_K(item: ImageMetadata) -> Tuple[torch.Tensor, torch.Tensor]: 9 | K = torch.eye(3) 10 | K[0, 0] = item.intrinsics[0] 11 | K[1, 1] = item.intrinsics[1] 12 | K[0, 2] = item.intrinsics[2] 13 | K[1, 2] = item.intrinsics[3] 14 | 15 | c2w_4x4 = torch.eye(4) 16 | c2w_4x4[:3] = item.c2w 17 | w2c = torch.inverse(c2w_4x4) 18 | 19 | return w2c, K 20 | -------------------------------------------------------------------------------- 
/suds/data/image_metadata.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import shutil 4 | import uuid 5 | from pathlib import Path 6 | from typing import Optional, Tuple 7 | 8 | import cv2 9 | import numpy as np 10 | import pyarrow.parquet as pq 11 | import torch 12 | import torch.nn.functional as F 13 | from PIL import Image 14 | 15 | from suds.stream_utils import (buffer_from_stream, image_from_stream, get_filesystem, table_from_stream) 16 | 17 | 18 | class ImageMetadata: 19 | def __init__(self, image_path: str, c2w: torch.Tensor, W: int, H: int, intrinsics: torch.Tensor, image_index: int, 20 | time: float, video_id: int, depth_path: str, mask_path: Optional[str], sky_mask_path: Optional[str], 21 | feature_path: Optional[str], backward_flow_path: Optional[str], forward_flow_path: Optional[str], 22 | backward_neighbor_index: Optional[int], forward_neighbor_index: Optional[int], is_val: bool, 23 | pose_scale_factor: float, local_cache: Optional[Path]): 24 | self.image_path = image_path 25 | self.c2w = c2w 26 | self.W = W 27 | self.H = H 28 | self.intrinsics = intrinsics 29 | self.image_index = image_index 30 | self.time = time 31 | self.video_id = video_id 32 | self.depth_path = depth_path 33 | self.mask_path = mask_path 34 | self.sky_mask_path = sky_mask_path 35 | self.feature_path = feature_path 36 | self.backward_flow_path = backward_flow_path 37 | self.forward_flow_path = forward_flow_path 38 | self.backward_neighbor_index = backward_neighbor_index 39 | self.forward_neighbor_index = forward_neighbor_index 40 | self.is_val = is_val 41 | 42 | self._pose_scale_factor = pose_scale_factor 43 | self._local_cache = local_cache 44 | 45 | def load_image(self) -> torch.Tensor: 46 | if self._local_cache is not None and not self.image_path.startswith(str(self._local_cache)): 47 | self.image_path = self._load_from_cache(self.image_path) 48 | 49 | rgbs = image_from_stream(self.image_path).convert('RGB') 50 | size = rgbs.size 51 | 52 | if size[0] != self.W or size[1] != self.H: 53 | rgbs = rgbs.resize((self.W, self.H), Image.LANCZOS) 54 | 55 | return torch.ByteTensor(np.asarray(rgbs)) 56 | 57 | def load_mask(self) -> torch.Tensor: 58 | if self.mask_path is None: 59 | return torch.ones(self.H, self.W, dtype=torch.bool) 60 | 61 | if self._local_cache is not None and not self.mask_path.startswith(str(self._local_cache)): 62 | self.mask_path = self._load_from_cache(self.mask_path) 63 | 64 | mask = image_from_stream(self.mask_path) 65 | size = mask.size 66 | 67 | if size[0] != self.W or size[1] != self.H: 68 | mask = mask.resize((self.W, self.H), Image.NEAREST) 69 | 70 | return torch.BoolTensor(np.asarray(mask)) 71 | 72 | def load_sky_mask(self) -> torch.Tensor: 73 | if self.sky_mask_path is None: 74 | return torch.zeros(self.H, self.W, dtype=torch.bool) 75 | 76 | if self._local_cache is not None and not self.sky_mask_path.startswith(str(self._local_cache)): 77 | self.sky_mask_path = self._load_from_cache(self.sky_mask_path) 78 | 79 | sky_mask = image_from_stream(self.sky_mask_path) 80 | size = sky_mask.size 81 | 82 | if size[0] != self.W or size[1] != self.H: 83 | sky_mask = sky_mask.resize((self.W, self.H), Image.NEAREST) 84 | 85 | return torch.BoolTensor(np.asarray(sky_mask)) 86 | 87 | def load_features(self, resize: bool = True) -> torch.Tensor: 88 | assert self.feature_path is not None 89 | 90 | if self._local_cache is not None and not self.feature_path.startswith(str(self._local_cache)): 91 | self.feature_path = 
self._load_from_cache(self.feature_path) 92 | 93 | table = table_from_stream(self.feature_path) 94 | features = torch.FloatTensor(table['pca'].to_numpy()).view( 95 | [int(x) for x in table.schema.metadata[b'shape'].split()]) 96 | 97 | if (features.shape[0] != self.H or features.shape[1] != self.W) and resize: 98 | features = F.interpolate(features.permute(2, 0, 1).unsqueeze(0), size=(self.H, self.W)).squeeze() \ 99 | .permute(1, 2, 0) 100 | 101 | return features 102 | 103 | def load_depth(self) -> torch.Tensor: 104 | if self._local_cache is not None and not self.depth_path.startswith(str(self._local_cache)): 105 | self.depth_path = self._load_from_cache(self.depth_path) 106 | 107 | if self.depth_path.endswith('.parquet'): 108 | table = table_from_stream(self.depth_path) 109 | 110 | # Get original depth dimensions 111 | size = image_from_stream(self.image_path).size 112 | 113 | depth = torch.FloatTensor(table['depth'].to_numpy()).view(size[1], size[0]) 114 | else: 115 | # Assume it's vkitti2 format 116 | depth = np.array(image_from_stream(self.depth_path)) 117 | depth[depth == 65535] = -1 118 | 119 | # depth is in cm - convert to meters 120 | depth = torch.FloatTensor(depth / 100) 121 | 122 | if depth.shape[0] != self.H or depth.shape[1] != self.W: 123 | depth = F.interpolate(depth.unsqueeze(0).unsqueeze(0), size=(self.H, self.W)).squeeze() 124 | 125 | return depth / self._pose_scale_factor 126 | 127 | def load_backward_flow(self) -> Tuple[torch.Tensor, torch.Tensor]: 128 | return self._load_flow(self.backward_flow_path, False) 129 | 130 | def load_forward_flow(self) -> Tuple[torch.Tensor, torch.Tensor]: 131 | return self._load_flow(self.forward_flow_path, True) 132 | 133 | def _load_flow(self, flow_path: Optional[str], is_forward: bool) -> Tuple[torch.Tensor, torch.Tensor]: 134 | if flow_path is None: 135 | return torch.zeros(self.H, self.W, 2), torch.zeros(self.H, self.W, dtype=torch.bool) 136 | 137 | if self._local_cache is not None and not flow_path.startswith(str(self._local_cache)): 138 | flow_path = self._load_from_cache(flow_path) 139 | if is_forward: 140 | self.forward_flow_path = flow_path 141 | else: 142 | self.backward_flow_path = flow_path 143 | 144 | if flow_path.endswith('.parquet'): 145 | table = pq.read_table(flow_path, filesystem=get_filesystem(flow_path)) 146 | 147 | if 'flow' in table.column_names: 148 | flow = torch.FloatTensor(table['flow'].to_numpy()).view( 149 | [int(x) for x in table.schema.metadata[b'shape'].split()]) 150 | if len(flow.shape) == 4: 151 | flow = flow.squeeze().permute(1, 2, 0) 152 | 153 | flow_valid = torch.ones_like(flow[:, :, 0], dtype=torch.bool) 154 | else: 155 | point1 = torch.LongTensor(table.to_pandas()[['point1_x', 'point1_y']].to_numpy()) 156 | point2 = torch.LongTensor(table.to_pandas()[['point2_x', 'point2_y']].to_numpy()) 157 | 158 | correspondences = (point2 - point1) if is_forward else (point1 - point2) 159 | to_index = point1 if is_forward else point2 160 | 161 | orig_H, orig_W = [int(x) for x in table.schema.metadata[b'shape'].split()] 162 | flow = torch.zeros(orig_H, orig_W, 2) 163 | flow_valid = torch.zeros(orig_H, orig_W, dtype=torch.bool) 164 | flow.view(-1, 2)[to_index[:, 0] + to_index[:, 1] * orig_W] = correspondences.float() 165 | flow_valid.view(-1)[to_index[:, 0] + to_index[:, 1] * orig_W] = True 166 | else: 167 | quantized_flow = cv2.imdecode(np.frombuffer(buffer_from_stream(flow_path).getbuffer(), dtype=np.uint8), 168 | cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 169 | # From 
https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/ 170 | _, _, _c = quantized_flow.shape 171 | assert quantized_flow.dtype == np.uint16 and _c == 3 172 | # b == invalid flow flag == 0 for sky or other invalid flow 173 | invalid = quantized_flow[:, :, 0] == 0 174 | # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1] 175 | flow = 2.0 / (2 ** 16 - 1.0) * quantized_flow[:, :, 2:0:-1].astype(np.float32) - 1 176 | flow[:, :, 0] *= flow.shape[1] - 1 177 | flow[:, :, 1] *= flow.shape[0] - 1 178 | flow[invalid] = 0 # or another value (e.g., np.nan) 179 | 180 | flow = torch.FloatTensor(flow) 181 | flow_valid = torch.BoolTensor(quantized_flow[:, :, 0] != 0) 182 | 183 | if flow.shape[0] != self.H or flow.shape[1] != self.W: 184 | flow[:, :, 0] *= (self.W / flow.shape[1]) 185 | flow[:, :, 1] *= (self.H / flow.shape[0]) 186 | flow = F.interpolate(flow.permute(2, 0, 1).unsqueeze(0), size=(self.H, self.W)).squeeze().permute(1, 2, 0) 187 | flow_valid = F.interpolate(flow_valid.unsqueeze(0).unsqueeze(0).float(), 188 | size=(self.H, self.W)).bool().squeeze() 189 | 190 | return flow, flow_valid 191 | 192 | def _load_from_cache(self, remote_path: str) -> str: 193 | sha_hash = hashlib.sha256() 194 | sha_hash.update(remote_path.encode('utf-8')) 195 | hashed = sha_hash.hexdigest() 196 | cache_path = self._local_cache / hashed[:2] / hashed[2:4] / '{}{}'.format(hashed, Path(remote_path).suffix) 197 | 198 | if cache_path.exists(): 199 | return str(cache_path) 200 | 201 | cache_path.parent.mkdir(parents=True, exist_ok=True) 202 | tmp_path = '{}.{}'.format(cache_path, uuid.uuid4()) 203 | remote_filesystem = get_filesystem(remote_path) 204 | 205 | if remote_filesystem is not None: 206 | remote_filesystem.get(remote_path, tmp_path) 207 | else: 208 | shutil.copy(remote_path, tmp_path) 209 | 210 | os.rename(tmp_path, cache_path) 211 | return str(cache_path) 212 | -------------------------------------------------------------------------------- /suds/data/stream_input_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from nerfstudio.data.datasets.base_dataset import InputDataset 3 | 4 | import numpy as np 5 | import numpy.typing as npt 6 | 7 | from suds.stream_utils import image_from_stream 8 | 9 | 10 | class StreamInputDataset(InputDataset): 11 | 12 | def get_numpy_image(self, image_idx: int) -> npt.NDArray[np.uint8]: 13 | """Returns the image of shape (H, W, 3 or 4). 14 | 15 | Args: 16 | image_idx: The image index in the dataset. 17 | """ 18 | image_filename = str(self._dataparser_outputs.image_filenames[image_idx]) 19 | # Converting to path mangles s3:// -> s3:/ 20 | pil_image = image_from_stream(image_filename.replace('s3:/', 's3://').replace('gs:/', 'gs://')) 21 | if self.scale_factor != 1.0: 22 | width, height = pil_image.size 23 | newsize = (int(width * self.scale_factor), int(height * self.scale_factor)) 24 | pil_image = pil_image.resize(newsize, resample=Image.BILINEAR) 25 | image = np.array(pil_image, dtype="uint8") # shape is (h, w, 3 or 4) 26 | assert len(image.shape) == 3 27 | assert image.dtype == np.uint8 28 | assert image.shape[2] in [3, 4], f"Image shape of {image.shape} is incorrect." 29 | return image -------------------------------------------------------------------------------- /suds/data/suds_dataparser.py: -------------------------------------------------------------------------------- 1 | """ Data parser for SUDS datasets. 
""" 2 | 3 | from __future__ import annotations 4 | 5 | import json 6 | from dataclasses import dataclass, field 7 | from pathlib import Path 8 | from typing import Type, List, Optional, Set, Dict, Any 9 | 10 | import torch 11 | import tyro 12 | from nerfstudio.cameras.cameras import Cameras, CameraType 13 | from nerfstudio.configs.config_utils import to_immutable_dict 14 | from nerfstudio.data.dataparsers.base_dataparser import ( 15 | DataParser, 16 | DataParserConfig, 17 | DataparserOutputs, 18 | ) 19 | from nerfstudio.data.scene_box import SceneBox 20 | from rich.console import Console 21 | from smart_open import open 22 | 23 | from suds.data.image_metadata import ImageMetadata 24 | 25 | CONSOLE = Console(width=120) 26 | ALL_ITEMS = 'all_items' 27 | ALL_CAMERAS = 'all_cameras' 28 | POSE_SCALE_FACTOR = 'pose_scale_factor' 29 | ORIGIN = 'origin' 30 | 31 | 32 | @dataclass 33 | class SUDSDataParserConfig(DataParserConfig): 34 | """SUDS dataset config""" 35 | 36 | _target: Type = field(default_factory=lambda: SUDSDataParser) 37 | """target class to instantiate""" 38 | metadata_path: str = 'metadata.json' 39 | """Directory specifying location of data.""" 40 | scale_factor: float = 1.0 41 | """How much to scale the camera origins by.""" 42 | scene_scale: float = 1.0 43 | """How much to scale the region of interest by.""" 44 | train_downscale_factor: float = 1 45 | """How much to downscale images used for training.""" 46 | eval_downscale_factor: float = 1 47 | """How much to downscale images used for evaluation.""" 48 | train_with_val_images: bool = False 49 | """Whether to include the validation images when training.""" 50 | static_only: bool = False 51 | """Whether to include static pixels when training.""" 52 | local_cache_path: Optional[str] = None 53 | """Caches images and metadata in specific path if set.""" 54 | 55 | metadata: tyro.conf.Suppress[Optional[Dict[str, Any]]] = None 56 | 57 | 58 | @dataclass 59 | class SUDSDataParser(DataParser): 60 | """SUDS DatasetParser""" 61 | 62 | config: SUDSDataParserConfig 63 | 64 | def get_dataparser_outputs(self, split='train', indices: Optional[Set[int]] = None) -> DataparserOutputs: 65 | # Cache json load - SUDSManager will clear it when it's no longer needed 66 | if self.config.metadata is None: 67 | with open(self.config.metadata_path) as f: 68 | self.config.metadata = json.load(f) 69 | 70 | if all([f['is_val'] for f in self.config.metadata['frames']]): 71 | self.config.train_with_val_images = True 72 | 73 | downscale_factor = self.config.train_downscale_factor if split == 'train' else self.config.eval_downscale_factor 74 | all_items = [] 75 | split_items = [] 76 | image_filenames = [] 77 | mask_filenames = [] 78 | 79 | local_cache_path = Path(self.config.local_cache_path) if self.config.local_cache_path is not None else None 80 | frames = self.config.metadata['frames'] 81 | for frame_index in range(len(frames)): 82 | frame = frames[frame_index] 83 | c2w = torch.FloatTensor(frame['c2w']) 84 | c2w[:, 3] /= self.config.scale_factor 85 | 86 | item = ImageMetadata(frame['rgb_path'], 87 | c2w, 88 | int(frame['W'] // downscale_factor), 89 | int(frame['H'] // downscale_factor), 90 | torch.FloatTensor(frame['intrinsics']) / downscale_factor, 91 | frame['image_index'], 92 | frame['time'], 93 | frame['video_id'], 94 | frame['depth_path'], 95 | frame.get('static_mask_path' if self.config.static_only else 'mask_path', None), 96 | frame.get('sky_mask_path', None), 97 | frame.get('feature_path', None), 98 | frame.get('backward_flow_path', None), 99 | 
frame.get('forward_flow_path', None), 100 | frame.get('backward_neighbor_index', None), 101 | frame.get('forward_neighbor_index', None), 102 | frame['is_val'], 103 | self.config.metadata['pose_scale_factor'], 104 | local_cache_path) 105 | 106 | all_items.append(item) 107 | 108 | # Keep the image indices consistent between training and validation 109 | if split == 'train': 110 | if frame['is_val'] and not self.config.train_with_val_images: 111 | continue 112 | elif not frame['is_val']: 113 | continue 114 | 115 | if indices is not None and frame_index not in indices: 116 | continue 117 | 118 | split_items.append(item) 119 | image_filenames.append(Path(item.image_path)) 120 | if item.mask_path is not None: 121 | mask_filenames.append(Path(item.mask_path)) 122 | 123 | assert ( 124 | len(image_filenames) != 0 125 | ), """ 126 | No image files found. 127 | You should check the file_paths in the transforms.json file to make sure they are correct. 128 | """ 129 | assert len(mask_filenames) == 0 or ( 130 | len(mask_filenames) == len(image_filenames) 131 | ), """ 132 | Different number of image and mask filenames. 133 | You should check that mask_path is specified for every frame (or zero frames) in transforms.json. 134 | """ 135 | 136 | scene_box = SceneBox( 137 | aabb=torch.tensor(self.config.metadata['scene_bounds']) * self.config.scene_scale 138 | ) 139 | 140 | dataparser_outputs = DataparserOutputs( 141 | image_filenames=image_filenames, 142 | cameras=self.create_cameras(split_items), 143 | scene_box=scene_box, 144 | mask_filenames=mask_filenames if len(mask_filenames) > 0 else None, 145 | metadata={ 146 | ALL_ITEMS: all_items, 147 | ALL_CAMERAS: self.create_cameras(all_items), 148 | POSE_SCALE_FACTOR: self.config.metadata['pose_scale_factor'], 149 | ORIGIN: self.config.metadata['origin'] 150 | } 151 | ) 152 | 153 | return dataparser_outputs 154 | 155 | @staticmethod 156 | def create_cameras(metadata_items: List[ImageMetadata]) -> Cameras: 157 | return Cameras( 158 | camera_to_worlds=torch.stack([x.c2w for x in metadata_items]), 159 | fx=torch.FloatTensor([x.intrinsics[0] for x in metadata_items]), 160 | fy=torch.FloatTensor([x.intrinsics[1] for x in metadata_items]), 161 | cx=torch.FloatTensor([x.intrinsics[2] for x in metadata_items]), 162 | cy=torch.FloatTensor([x.intrinsics[3] for x in metadata_items]), 163 | width=torch.IntTensor([x.W for x in metadata_items]), 164 | height=torch.IntTensor([x.H for x in metadata_items]), 165 | camera_type=CameraType.PERSPECTIVE, 166 | times=torch.FloatTensor([x.time for x in metadata_items]).unsqueeze(-1) 167 | ) 168 | -------------------------------------------------------------------------------- /suds/data/suds_eval_dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Optional, List, Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from nerfstudio.cameras.cameras import Cameras 6 | from nerfstudio.cameras.rays import RayBundle 7 | from nerfstudio.utils.comms import get_rank, get_world_size 8 | from torch.utils.data import DataLoader, TensorDataset 9 | 10 | from suds.data.dataset_utils import get_w2c_and_K 11 | from suds.data.image_metadata import ImageMetadata 12 | from suds.data.suds_dataset import RGB, DEPTH, BACKWARD_FLOW, FORWARD_FLOW, VIDEO_ID 13 | from suds.kmeans import kmeans_predict 14 | from suds.suds_constants import MASK, FEATURES, RAY_INDEX, BACKWARD_NEIGHBOR_TIME_DIFF, FORWARD_NEIGHBOR_TIME_DIFF, \ 15 | BACKWARD_NEIGHBOR_W2C, 
BACKWARD_NEIGHBOR_K, FORWARD_NEIGHBOR_W2C, FORWARD_NEIGHBOR_K, SKY 16 | 17 | 18 | class SUDSEvalDataLoader(DataLoader): 19 | 20 | def __init__( 21 | self, 22 | all_items: List[ImageMetadata], 23 | cameras: Cameras, 24 | load_depth: bool, 25 | load_features: bool, 26 | load_flow: bool, 27 | load_sky: bool, 28 | feature_clusters: Dict[str, torch.Tensor], 29 | feature_colors: Dict[str, torch.Tensor], 30 | image_indices: Optional[Tuple[int]] = None, 31 | device: Union[torch.device, str] = 'cpu' 32 | ): 33 | if image_indices is None: 34 | self.image_indices = [] 35 | val_items = list(filter(lambda x: x.is_val, all_items)) 36 | for item_index in range(get_rank(), len(val_items), get_world_size()): 37 | self.image_indices.append(val_items[item_index].image_index) 38 | else: 39 | self.image_indices = image_indices 40 | 41 | super().__init__(dataset=TensorDataset(torch.LongTensor(self.image_indices))) 42 | 43 | self.all_items = all_items 44 | self.cameras = cameras.to(device) 45 | self.load_depth = load_depth 46 | self.load_features = load_features 47 | self.load_flow = load_flow 48 | self.load_sky = load_sky 49 | self.feature_clusters = feature_clusters 50 | self.feature_colors = feature_colors 51 | self.device = device 52 | self.count = 0 53 | 54 | def __iter__(self): 55 | self.count = 0 56 | return self 57 | 58 | def __next__(self) -> Tuple[RayBundle, Dict]: 59 | if self.count < len(self.image_indices): 60 | data = self.get_image_data(self.image_indices[self.count]) 61 | self.count += 1 62 | return data 63 | 64 | raise StopIteration 65 | 66 | def get_image_data(self, image_index: int) -> Tuple[RayBundle, Dict]: 67 | metadata_item = self.all_items[image_index] 68 | 69 | batch = { 70 | RGB: metadata_item.load_image().float().to(self.device) / 255., 71 | MASK: metadata_item.load_mask().to(self.device) 72 | } 73 | 74 | if self.load_depth: 75 | batch[DEPTH] = metadata_item.load_depth().to(self.device) 76 | 77 | if self.load_features: 78 | features = metadata_item.load_features(False).to(self.device) 79 | for key, val in self.feature_clusters.items(): 80 | feature_colors = self.feature_colors[key].to(self.device)[ 81 | kmeans_predict(features.view(-1, features.shape[-1]), val.to(self.device), device=self.device, 82 | tqdm_flag=False)].view((*features.shape[:-1], 3)) 83 | 84 | if feature_colors.shape[0] != metadata_item.H or feature_colors.shape[1] != metadata_item.W: 85 | feature_colors = F.interpolate(feature_colors.permute(2, 0, 1).unsqueeze(0), 86 | size=(metadata_item.H, metadata_item.W)).squeeze().permute(1, 2, 0) 87 | 88 | batch[f'{FEATURES}_{key}'] = feature_colors 89 | 90 | ray_bundle = self.cameras.generate_rays(camera_indices=image_index, keep_shape=True) 91 | 92 | if self.load_flow: 93 | if metadata_item.backward_neighbor_index is not None: 94 | batch[BACKWARD_FLOW] = metadata_item.load_backward_flow()[0].to(self.device) 95 | backward_w2c, backward_K = get_w2c_and_K(self.all_items[metadata_item.backward_neighbor_index]) 96 | ray_bundle.metadata[BACKWARD_NEIGHBOR_W2C] = backward_w2c.to(self.device).reshape(1, 1, 16).expand( 97 | metadata_item.H, metadata_item.W, -1) 98 | ray_bundle.metadata[BACKWARD_NEIGHBOR_K] = backward_K.to(self.device).reshape(1, 1, 9).expand( 99 | metadata_item.H, metadata_item.W, -1) 100 | ray_bundle.metadata[BACKWARD_NEIGHBOR_TIME_DIFF] = \ 101 | torch.ones_like(ray_bundle.origins[..., 0:1], dtype=torch.long) * \ 102 | (metadata_item.time - self.all_items[metadata_item.backward_neighbor_index].time) 103 | 104 | if metadata_item.forward_neighbor_index is not None: 105 
| batch[FORWARD_FLOW] = metadata_item.load_forward_flow()[0].to(self.device) 106 | forward_w2c, forward_K = get_w2c_and_K(self.all_items[metadata_item.forward_neighbor_index]) 107 | ray_bundle.metadata[FORWARD_NEIGHBOR_W2C] = forward_w2c.to(self.device).reshape(1, 1, 16).expand( 108 | metadata_item.H, metadata_item.W, -1) 109 | ray_bundle.metadata[FORWARD_NEIGHBOR_K] = forward_K.to(self.device).reshape(1, 1, 9).expand( 110 | metadata_item.H, metadata_item.W, -1) 111 | ray_bundle.metadata[FORWARD_NEIGHBOR_TIME_DIFF] = \ 112 | torch.ones_like(ray_bundle.origins[..., 0:1], dtype=torch.long) * \ 113 | (self.all_items[metadata_item.forward_neighbor_index].time - metadata_item.time) 114 | 115 | pixel_indices = torch.arange(metadata_item.W * metadata_item.H, device=self.device).unsqueeze(-1) 116 | v = pixel_indices // metadata_item.W 117 | u = pixel_indices % metadata_item.W 118 | image_indices = torch.ones_like(pixel_indices) * metadata_item.image_index 119 | batch[RAY_INDEX] = torch.cat([image_indices, v, u], -1).view(metadata_item.H, metadata_item.W, 3) 120 | 121 | if self.load_sky: 122 | batch[SKY] = metadata_item.load_sky_mask().to(self.device) 123 | 124 | ray_bundle.times = torch.ones_like(ray_bundle.origins[..., 0:1]) * metadata_item.time 125 | ray_bundle.metadata[VIDEO_ID] = torch.ones_like(ray_bundle.times, dtype=torch.int32) * metadata_item.video_id 126 | return ray_bundle, batch 127 | -------------------------------------------------------------------------------- /suds/data/suds_pipeline.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from pathlib import Path 3 | from time import time 4 | from typing import Type, List, Optional 5 | 6 | import torch 7 | from PIL import Image 8 | from nerfstudio.pipelines.base_pipeline import VanillaPipeline, VanillaPipelineConfig 9 | from nerfstudio.utils import profiler 10 | from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, MofNCompleteColumn 11 | from typing_extensions import Literal 12 | 13 | from suds.data.suds_datamanager import SUDSDataManagerConfig 14 | from suds.draw_utils import label_colormap 15 | from suds.stream_utils import buffer_from_stream 16 | from suds.suds_constants import DEPTH, FEATURES, SKY 17 | from suds.suds_model import SUDSModelConfig 18 | 19 | 20 | @dataclass 21 | class SUDSPipelineConfig(VanillaPipelineConfig): 22 | """Configuration for pipeline instantiation""" 23 | 24 | _target: Type = field(default_factory=lambda: SUDSPipeline) 25 | """target class to instantiate""" 26 | datamanager: SUDSDataManagerConfig = SUDSDataManagerConfig() 27 | """specifies the datamanager config""" 28 | model: SUDSModelConfig = SUDSModelConfig() 29 | """specifies the model config""" 30 | 31 | feature_clusters: List[str] = field(default_factory=lambda: []) 32 | """clusters to use for feature visualization""" 33 | 34 | 35 | class SUDSPipeline(VanillaPipeline): 36 | config: SUDSPipelineConfig 37 | 38 | def __init__( 39 | self, 40 | config: SUDSPipelineConfig, 41 | device: str, 42 | test_mode: Literal['test', 'val', 'inference'] = 'val', 43 | world_size: int = 1, 44 | local_rank: int = 0): 45 | 46 | feature_clusters = {} 47 | feature_colors = {} 48 | for feature_cluster_path in config.feature_clusters: 49 | feature_cluster = torch.load(buffer_from_stream(feature_cluster_path), map_location='cpu') 50 | feature_clusters[Path(feature_cluster_path).stem] = feature_cluster['centroids'] 51 | cluster_colors = feature_cluster['colors'] / 255. 
if 'colors' in feature_cluster \ 52 | else label_colormap(feature_cluster['centroids'].shape[0]) 53 | feature_colors[Path(feature_cluster_path).stem] = cluster_colors 54 | 55 | config.model.feature_clusters = feature_clusters 56 | config.model.feature_colors = feature_colors 57 | 58 | config.datamanager.feature_clusters = feature_clusters 59 | config.datamanager.feature_colors = feature_colors 60 | config.datamanager.load_depth = config.model.loss_coefficients[DEPTH] > 0 61 | config.datamanager.load_features = config.model.loss_coefficients[FEATURES] > 0 62 | config.datamanager.load_flow = config.model.predict_flow 63 | config.datamanager.load_sky = config.model.loss_coefficients[SKY] > 0 64 | 65 | super().__init__(config, device, test_mode, world_size, local_rank) 66 | 67 | @profiler.time_function 68 | def get_average_eval_image_metrics(self, step: Optional[int] = None, image_save_dir: Optional[Path] = None): 69 | """Iterate over all the images in the eval dataset and get the average. 70 | 71 | Returns: 72 | metrics_dict: dictionary of metrics 73 | """ 74 | self.eval() 75 | metrics_dict_list = [] 76 | num_images = len(self.datamanager.fixed_indices_eval_dataloader) 77 | with Progress( 78 | TextColumn("[progress.description]{task.description}"), 79 | BarColumn(), 80 | TimeElapsedColumn(), 81 | MofNCompleteColumn(), 82 | transient=True, 83 | ) as progress: 84 | task = progress.add_task("[green]Evaluating all eval images...", total=num_images) 85 | for camera_ray_bundle, batch in self.datamanager.fixed_indices_eval_dataloader: 86 | # time this the following line 87 | inner_start = time() 88 | height, width = camera_ray_bundle.shape 89 | num_rays = height * width 90 | outputs = self.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle) 91 | metrics_dict, images = self.model.get_image_metrics_and_images(outputs, batch) 92 | 93 | if image_save_dir is not None: 94 | for key, val in images.items(): 95 | Image.fromarray((val * 255).byte().cpu().numpy()).save( 96 | image_save_dir / '{0:06d}-{1}.jpg'.format(int(camera_ray_bundle.camera_indices[0, 0, 0]), 97 | key)) 98 | 99 | assert "num_rays_per_sec" not in metrics_dict 100 | metrics_dict["num_rays_per_sec"] = num_rays / (time() - inner_start) 101 | fps_str = "fps" 102 | assert fps_str not in metrics_dict 103 | metrics_dict[fps_str] = metrics_dict["num_rays_per_sec"] / (height * width) 104 | metrics_dict_list.append(metrics_dict) 105 | progress.advance(task) 106 | # average the metrics list 107 | metrics_dict = {} 108 | for key in metrics_dict_list[0].keys(): 109 | metrics_dict[key] = float( 110 | torch.mean(torch.tensor([metrics_dict[key] for metrics_dict in metrics_dict_list])) 111 | ) 112 | self.train() 113 | return metrics_dict 114 | -------------------------------------------------------------------------------- /suds/draw_utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from itertools import accumulate 3 | from typing import Optional 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | DEFAULT_TRANSITIONS = (15, 6, 4, 11, 13, 6) 10 | 11 | 12 | def _make_colorwheel(transitions: tuple = DEFAULT_TRANSITIONS) -> torch.Tensor: 13 | '''Creates a colorwheel (borrowed/modified from flowpy). 14 | A colorwheel defines the transitions between the six primary hues: 15 | Red(255, 0, 0), Yellow(255, 255, 0), Green(0, 255, 0), Cyan(0, 255, 255), Blue(0, 0, 255) and Magenta(255, 0, 255). 
16 | Args: 17 | transitions: Contains the length of the six transitions, based on human color perception. 18 | Returns: 19 | colorwheel: The RGB values of the transitions in the color space. 20 | Notes: 21 | For more information, see: 22 | https://web.archive.org/web/20051107102013/http://members.shaw.ca/quadibloc/other/colint.htm 23 | http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 24 | ''' 25 | colorwheel_length = sum(transitions) 26 | # The red hue is repeated to make the colorwheel cyclic 27 | base_hues = map( 28 | np.array, ([255, 0, 0], [255, 255, 0], [0, 255, 0], [0, 255, 255], [0, 0, 255], [255, 0, 255], [255, 0, 0]) 29 | ) 30 | colorwheel = np.zeros((colorwheel_length, 3), dtype='uint8') 31 | hue_from = next(base_hues) 32 | start_index = 0 33 | for hue_to, end_index in zip(base_hues, accumulate(transitions)): 34 | transition_length = end_index - start_index 35 | colorwheel[start_index:end_index] = np.linspace(hue_from, hue_to, transition_length, endpoint=False) 36 | hue_from = hue_to 37 | start_index = end_index 38 | return torch.FloatTensor(colorwheel) 39 | 40 | 41 | WHEEL = _make_colorwheel() 42 | N_COLS = len(WHEEL) 43 | WHEEL = torch.vstack((WHEEL, WHEEL[0])) # Make the wheel cyclic for interpolation 44 | 45 | # Adapted from https://github.com/facebookresearch/banmo/blob/main/third_party/ext_utils/flowlib.py 46 | 47 | UNKNOWN_FLOW_THRESH = 1e7 48 | SMALLFLOW = 0.0 49 | LARGEFLOW = 1e8 50 | 51 | 52 | def cat_imgflo(img: torch.Tensor, flow: torch.Tensor, skip: int = None) -> torch.Tensor: 53 | """ 54 | img in (0,1) 55 | flo in non-normalized coordinate 56 | """ 57 | flow = flow.clone() 58 | flow[:, :, 0] /= flow.shape[1] 59 | flow[:, :, 1] /= flow.shape[0] 60 | 61 | img = img.clone() * 255 62 | h, w = img.shape[:2] 63 | flow = flow.clone() 64 | flow[:, :, 0] = flow[:, :, 0] * 0.5 * w 65 | flow[:, :, 1] = flow[:, :, 1] * 0.5 * h 66 | imgflo = _point_vec(img, flow, skip) 67 | return imgflo 68 | 69 | 70 | def _point_vec(img: torch.Tensor, flow: torch.Tensor, skip: int = None) -> torch.Tensor: 71 | if skip is None: 72 | skip = min(10, 10 * img.shape[1] // 500) 73 | 74 | dispimg = img.clone().cpu().numpy() 75 | meshgrid = np.meshgrid(range(dispimg.shape[1]), range(dispimg.shape[0])) 76 | 77 | colorflow = _flow_to_image(flow.clone().cpu()).int() 78 | for i in range(dispimg.shape[1]): # x 79 | for j in range(dispimg.shape[0]): # y 80 | if flow.shape[-1] == 3 and flow[j, i, 2] != 1: continue 81 | if j % skip != 0 or i % skip != 0: continue 82 | leng = torch.linalg.norm(flow[j, i, :2]).item() 83 | if leng < 1: 84 | continue 85 | xend = int((meshgrid[0][j, i] + flow[j, i, 0])) 86 | yend = int((meshgrid[1][j, i] + flow[j, i, 1])) 87 | dispimg = cv2.arrowedLine(dispimg, (meshgrid[0][j, i], meshgrid[1][j, i]), \ 88 | (xend, yend), 89 | (int(colorflow[j, i, 2]), int(colorflow[j, i, 1]), int(colorflow[j, i, 0])), 1, 90 | tipLength=4 / leng, line_type=cv2.LINE_AA) 91 | return torch.FloatTensor(dispimg).to(img.device) / 255. 
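# Usage sketch: a minimal, hypothetical example of exercising the arrow-overlay
# helper above. The tensor shapes and magnitudes are assumptions chosen purely
# for illustration -- an (H, W, 3) RGB image with values in (0, 1) and an
# (H, W, 2) flow field in non-normalized pixel displacements, as cat_imgflo expects.
def _example_flow_overlay() -> torch.Tensor:
    """Hypothetical demo: draw flow arrows on top of a synthetic image."""
    img = torch.rand(120, 240, 3)           # RGB image in (0, 1)
    flow = torch.randn(120, 240, 2) * 5.0   # per-pixel (dx, dy) displacements
    # cat_imgflo rescales the flow by the image size before drawing the arrows
    return cat_imgflo(img, flow, skip=10)   # (120, 240, 3) overlay in (0, 1)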
92 | 93 | 94 | def _flow_to_image(flow: torch.Tensor) -> torch.Tensor: 95 | """ 96 | Convert flow into middlebury color code image 97 | :param flow: optical flow map 98 | :return: optical flow image in middlebury color 99 | """ 100 | u = flow[:, :, 0] 101 | v = flow[:, :, 1] 102 | 103 | idxUnknow = (u.abs() > UNKNOWN_FLOW_THRESH) | (v.abs() > UNKNOWN_FLOW_THRESH) 104 | u[idxUnknow] = 0 105 | v[idxUnknow] = 0 106 | 107 | rad = torch.sqrt(u ** 2 + v ** 2) 108 | maxrad = max(-1, rad.max()) 109 | 110 | u = u / (maxrad + 1e-8) 111 | v = v / (maxrad + 1e-8) 112 | 113 | img = _compute_color(u, v) 114 | 115 | idx = idxUnknow.unsqueeze(-1).expand(-1, -1, 3) 116 | img[idx] = 0 117 | 118 | return img.byte() 119 | 120 | 121 | def _compute_color(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 122 | """ 123 | compute optical flow color map 124 | :param u: optical flow horizontal map 125 | :param v: optical flow vertical map 126 | :return: optical flow in color code 127 | """ 128 | [h, w] = u.shape 129 | img = torch.zeros(h, w, 3, device=u.device) 130 | nanIdx = torch.isnan(u) | torch.isnan(v) 131 | u[nanIdx] = 0 132 | v[nanIdx] = 0 133 | 134 | rad = torch.sqrt(u ** 2 + v ** 2) 135 | 136 | a = torch.arctan2(-v, -u) / np.pi 137 | 138 | fk = (a + 1) / 2 * (N_COLS - 1) + 1 139 | 140 | k0 = torch.floor(fk).long() 141 | 142 | k1 = k0 + 1 143 | k1[k1 == N_COLS + 1] = 1 144 | f = fk - k0 145 | 146 | for i in range(0, WHEEL.shape[1]): 147 | tmp = WHEEL[:, i] 148 | col0 = tmp[k0 - 1] / 255 149 | col1 = tmp[k1 - 1] / 255 150 | col = (1 - f) * col0 + f * col1 151 | 152 | idx = rad <= 1 153 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 154 | notidx = torch.logical_not(idx) 155 | 156 | col[notidx] *= 0.75 157 | img[:, :, i] = (torch.floor(255 * col * (1 - nanIdx.int()))).byte() 158 | 159 | return img 160 | 161 | 162 | # Adapted from https://github.com/Lilac-Lee/Neural_Scene_Flow_Prior/blob/main/visualize.py 163 | def scene_flow_to_rgb( 164 | flow: torch.Tensor, 165 | flow_max_radius: Optional[float] = None, 166 | background: Optional[str] = 'dark', 167 | ) -> torch.Tensor: 168 | '''Creates a RGB representation of an optical flow (borrowed/modified from flowpy). 169 | Args: 170 | flow: scene flow. 171 | flow[..., 0] should be the x-displacement 172 | flow[..., 1] should be the y-displacement 173 | flow[..., 2] should be the z-displacement 174 | flow_max_radius: Set the radius that gives the maximum color intensity, useful for comparing different flows. 175 | Default: The normalization is based on the input flow maximum radius. 176 | background: States if zero-valued flow should look 'bright' or 'dark'. 177 | Returns: An array of RGB colors. 178 | ''' 179 | valid_backgrounds = ('bright', 'dark') 180 | if background not in valid_backgrounds: 181 | raise ValueError(f'background should be one the following: {valid_backgrounds}, not {background}.') 182 | 183 | # For scene flow, it's reasonable to assume displacements in x and y directions only for visualization pursposes. 
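# NOTE: packing the (x, y) displacement into a complex number makes the polar
# decomposition trivial -- torch.abs gives the flow magnitude (radius) and
# torch.angle gives its direction -- so the direction can index a hue on the
# colorwheel while the magnitude controls how far the color is pushed from the
# background toward the fully saturated hue (the usual Middlebury-style coding).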
184 | complex_flow = flow[..., 0] + 1j * flow[..., 1] 185 | radius, angle = torch.abs(complex_flow), torch.angle(complex_flow) 186 | if flow_max_radius is None: 187 | flow_max_radius = torch.max(radius) 188 | if flow_max_radius > 0: 189 | radius /= flow_max_radius 190 | # Map the angles from (-pi, pi] to [0, 2pi) to [0, ncols - 1) 191 | angle[angle < 0] += 2 * np.pi 192 | angle = angle * ((N_COLS - 1) / (2 * np.pi)) 193 | 194 | # Interpolate the hues 195 | angle_fractional, angle_floor, angle_ceil = torch.fmod(angle, 1), angle.trunc(), torch.ceil(angle) 196 | angle_fractional = angle_fractional.unsqueeze(-1) 197 | wheel = WHEEL.to(angle_floor.device) 198 | float_hue = ( 199 | wheel[angle_floor.long()] * (1 - angle_fractional) + wheel[angle_ceil.long()] * angle_fractional 200 | ) 201 | ColorizationArgs = namedtuple( 202 | 'ColorizationArgs', ['move_hue_valid_radius', 'move_hue_oversized_radius', 'invalid_color'] 203 | ) 204 | 205 | def move_hue_on_V_axis(hues, factors): 206 | return hues * factors.unsqueeze(-1) 207 | 208 | def move_hue_on_S_axis(hues, factors): 209 | return 255. - factors.unsqueeze(-1) * (255. - hues) 210 | 211 | if background == 'dark': 212 | parameters = ColorizationArgs( 213 | move_hue_on_V_axis, move_hue_on_S_axis, torch.FloatTensor([255, 255, 255]) 214 | ) 215 | else: 216 | parameters = ColorizationArgs(move_hue_on_S_axis, move_hue_on_V_axis, torch.zeros(3)) 217 | colors = parameters.move_hue_valid_radius(float_hue, radius) 218 | oversized_radius_mask = radius > 1 219 | colors[oversized_radius_mask] = parameters.move_hue_oversized_radius( 220 | float_hue[oversized_radius_mask], 221 | 1 / radius[oversized_radius_mask] 222 | ) 223 | return colors / 255. 224 | 225 | 226 | def label_colormap(N: int) -> torch.FloatTensor: 227 | cmap = torch.zeros(N, 3) 228 | for i in range(0, N): 229 | id = i 230 | r, g, b = 0, 0, 0 231 | for j in range(0, 8): 232 | r = np.bitwise_or(r, (_bitget(id, 0) << 7 - j)) 233 | g = np.bitwise_or(g, (_bitget(id, 1) << 7 - j)) 234 | b = np.bitwise_or(b, (_bitget(id, 2) << 7 - j)) 235 | id = (id >> 3) 236 | cmap[i, 0] = r 237 | cmap[i, 1] = g 238 | cmap[i, 2] = b 239 | 240 | return cmap / 255. 241 | 242 | 243 | def _bitget(byteval: int, idx: int) -> int: 244 | return ((byteval & (1 << idx)) != 0) 245 | -------------------------------------------------------------------------------- /suds/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | eval.py 4 | """ 5 | from __future__ import annotations 6 | 7 | import datetime 8 | import json 9 | import os 10 | from dataclasses import dataclass 11 | from pathlib import Path 12 | 13 | import torch 14 | import torch.distributed as dist 15 | import tyro 16 | from nerfstudio.utils.comms import get_world_size, get_rank, is_main_process 17 | from nerfstudio.utils.eval_utils import eval_setup 18 | from rich.console import Console 19 | 20 | CONSOLE = Console(width=120) 21 | 22 | 23 | @dataclass 24 | class ComputePSNR: 25 | """Load a checkpoint, compute some PSNR metrics, and save it to a JSON file.""" 26 | 27 | # Path to config YAML file. 28 | load_config: Path 29 | # Name of the output file. 
30 | output_path: Path 31 | 32 | def main(self) -> None: 33 | """Main function.""" 34 | config, pipeline, checkpoint_path = eval_setup(self.load_config) 35 | 36 | self.output_path.mkdir(parents=True, exist_ok=True) 37 | metrics_dict = pipeline.get_average_eval_image_metrics(image_save_dir=self.output_path) 38 | 39 | output_json_path = self.output_path / 'metrics.json' 40 | 41 | if is_main_process(): 42 | if get_world_size() > 1: 43 | dist.barrier() 44 | num_images = len(pipeline.datamanager.fixed_indices_eval_dataloader) 45 | for key in metrics_dict: 46 | metrics_dict[key] = metrics_dict[key] * num_images 47 | 48 | for i in range(1, get_world_size()): 49 | shard_path = Path(str(output_json_path) + f'.{i}') 50 | with shard_path.open() as f: 51 | shard_results = json.load(f) 52 | 53 | for key in metrics_dict: 54 | metrics_dict[key] += (shard_results['results'][key] * shard_results['num_images']) 55 | 56 | num_images += shard_results['num_images'] 57 | 58 | shard_path.unlink() 59 | 60 | for key in metrics_dict: 61 | metrics_dict[key] = metrics_dict[key] / num_images 62 | 63 | benchmark_info = { 64 | "experiment_name": config.experiment_name, 65 | "method_name": config.method_name, 66 | "checkpoint": str(checkpoint_path), 67 | "results": metrics_dict, 68 | "num_images": num_images 69 | } 70 | 71 | output_json_path.write_text(json.dumps(benchmark_info, indent=2), "utf8") 72 | 73 | CONSOLE.print(f"Saved results to: {self.output_path}") 74 | else: 75 | shard_output_path = Path(str(output_json_path) + f'.{get_rank()}.tmp') 76 | # Get the output and define the names to save to 77 | benchmark_info = { 78 | "results": metrics_dict, 79 | "num_images": len(pipeline.datamanager.fixed_indices_eval_dataloader) 80 | } 81 | 82 | # Save output to output file 83 | 84 | shard_output_path.write_text(json.dumps(benchmark_info, indent=2), "utf8") 85 | shard_output_path = shard_output_path.rename(Path(str(output_json_path) + f'.{get_rank()}')) 86 | CONSOLE.print(f"Saved shard results to: {shard_output_path}") 87 | if get_world_size() > 1: 88 | dist.barrier() 89 | 90 | 91 | def entrypoint(): 92 | if 'RANK' in os.environ: 93 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(0, hours=24)) 94 | torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) 95 | 96 | """Entrypoint for use with pyproject scripts.""" 97 | tyro.extras.set_accent_color("bright_yellow") 98 | tyro.cli(ComputePSNR).main() 99 | 100 | 101 | if __name__ == "__main__": 102 | entrypoint() 103 | 104 | # For sphinx docs 105 | get_parser_fn = lambda: tyro.extras.get_parser(ComputePSNR) # noqa 106 | -------------------------------------------------------------------------------- /suds/fields/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hturki/suds/9b472ad5e4fd4d810682af984ef65bfaf4d75188/suds/fields/__init__.py -------------------------------------------------------------------------------- /suds/fields/dynamic_proposal_field.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Proposal network field. 17 | """ 18 | 19 | from typing import Optional 20 | 21 | import numpy as np 22 | import tinycudann as tcnn 23 | import torch 24 | import torch.nn.functional as F 25 | from nerfstudio.cameras.rays import RaySamples, Frustums 26 | from nerfstudio.fields.base_field import Field 27 | from torchtyping import TensorType 28 | 29 | from suds.suds_constants import VIDEO_ID 30 | 31 | 32 | class DynamicProposalField(Field): 33 | """A lightweight density field module. 34 | 35 | Args: 36 | aabb: parameters of scene aabb bounds 37 | num_layers: number of hidden layers 38 | hidden_dim: dimension of hidden layers 39 | spatial_distortion: spatial distortion module 40 | use_linear: whether to skip the MLP and use a single linear layer instead 41 | """ 42 | 43 | def __init__( 44 | self, 45 | num_layers: int = 2, 46 | hidden_dim: int = 16, 47 | num_levels: int = 5, 48 | base_resolution: int = 16, 49 | max_resolution: int = 256, 50 | log2_hashmap_size: int = 18, 51 | features_per_level: int = 2, 52 | network_activation: str = 'ReLU' 53 | ) -> None: 54 | super().__init__() 55 | 56 | growth_factor = np.exp((np.log(max_resolution) - np.log(base_resolution)) / (num_levels - 1)) 57 | 58 | self.encoding = tcnn.Encoding( 59 | n_input_dims=5, 60 | encoding_config={ 61 | 'otype': 'SequentialGrid', 62 | 'n_levels': num_levels, 63 | 'n_features_per_level': features_per_level, 64 | 'log2_hashmap_size': log2_hashmap_size, 65 | 'base_resolution': base_resolution, 66 | 'per_level_scale': growth_factor, 67 | 'include_static': False 68 | } 69 | ) 70 | 71 | self.mlp_base = tcnn.Network( 72 | n_input_dims=num_levels * features_per_level, 73 | n_output_dims=1, 74 | network_config={ 75 | 'otype': 'FullyFusedMLP', 76 | 'activation': network_activation, 77 | 'output_activation': 'None', 78 | 'n_neurons': hidden_dim, 79 | 'n_hidden_layers': num_layers - 1, 80 | } 81 | ) 82 | 83 | def density_fn(self, positions: TensorType["bs":..., 3], times: TensorType, video_ids: TensorType) -> \ 84 | TensorType["bs":..., 1]: 85 | ray_samples = RaySamples( 86 | frustums=Frustums( 87 | origins=positions, 88 | directions=torch.ones_like(positions), 89 | starts=torch.zeros_like(positions[..., :1]), 90 | ends=torch.zeros_like(positions[..., :1]), 91 | pixel_area=torch.ones_like(positions[..., :1]), 92 | ), 93 | times=times.unsqueeze(-2).expand(*times.shape[:-1], positions.shape[-2], -1), 94 | metadata={VIDEO_ID: video_ids.unsqueeze(-2).expand(*video_ids.shape[:-1], positions.shape[-2], -1)} 95 | ) 96 | density, _ = self.get_density(ray_samples) 97 | return density 98 | 99 | def get_density(self, ray_samples: RaySamples): 100 | if ray_samples.times is None: 101 | raise AttributeError('Times are not provided.') 102 | if VIDEO_ID not in ray_samples.metadata: 103 | raise AttributeError('Video ids are not provided.') 104 | 105 | positions = ray_samples.frustums.get_positions() 106 | times = ray_samples.times 107 | 108 | base_input = torch.cat( 109 | [positions.view(-1, 3), times.reshape(-1, 1), ray_samples.metadata[VIDEO_ID].reshape(-1, 1)], -1) 110 | density_before_activation 
= self.mlp_base(self.encoding(base_input)).view(*ray_samples.frustums.shape, -1) 111 | 112 | # Rectifying the density with an exponential is much more stable than a ReLU or 113 | # softplus, because it enables high post-activation (float32) density outputs 114 | # from smaller internal (float16) parameters. 115 | density = F.softplus(density_before_activation.to(ray_samples.frustums.directions) - 1) 116 | 117 | return density, None 118 | 119 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None): 120 | return {} 121 | -------------------------------------------------------------------------------- /suds/fields/env_map_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Optional 2 | 3 | import tinycudann as tcnn 4 | import torch 5 | from nerfstudio.field_components.field_heads import FieldHeadNames 6 | from torch import nn 7 | 8 | from suds.fields.suds_field_head_names import SUDSFieldHeadNames 9 | from suds.suds_constants import RGB, FEATURES, FILTER_FEATURES 10 | 11 | 12 | class EnvMapField(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | num_layers: int = 2, 17 | hidden_dim: int = 64, 18 | appearance_embedding_dim: int = 32, 19 | num_levels: int = 16, 20 | features_per_level: int = 2, 21 | log2_hashmap_size: int = 16, 22 | base_resolution: int = 16, 23 | feature_dim: int = 64, 24 | network_activation: str = 'ReLU', 25 | feature_output_activation: str = 'Tanh' 26 | ) -> None: 27 | super().__init__() 28 | self.appearance_embedding_dim = appearance_embedding_dim 29 | self.feature_dim = feature_dim 30 | self.feature_output_activation = feature_output_activation 31 | 32 | self.encoding = tcnn.Encoding( 33 | n_input_dims=4, 34 | encoding_config={ 35 | 'otype': 'SequentialGrid', 36 | 'n_levels': num_levels, 37 | 'n_features_per_level': features_per_level, 38 | 'log2_hashmap_size': log2_hashmap_size, 39 | 'base_resolution': base_resolution, 40 | 'include_static': False 41 | } 42 | ) 43 | 44 | self.mlp_head = tcnn.Network( 45 | n_input_dims=features_per_level * num_levels + appearance_embedding_dim, 46 | n_output_dims=3, 47 | network_config={ 48 | 'otype': 'FullyFusedMLP', 49 | 'activation': network_activation, 50 | 'output_activation': 'Sigmoid', 51 | 'n_neurons': hidden_dim, 52 | 'n_hidden_layers': num_layers - 1, 53 | }, 54 | ) 55 | 56 | if feature_dim > 0: 57 | self.mlp_feature = tcnn.Network( 58 | n_input_dims=features_per_level * num_levels, 59 | n_output_dims=feature_dim, 60 | network_config={ 61 | 'otype': 'FullyFusedMLP', 62 | 'activation': network_activation, 63 | 'output_activation': self.feature_output_activation, 64 | 'n_neurons': hidden_dim, 65 | 'n_hidden_layers': num_layers - 1, 66 | }, 67 | ) 68 | 69 | def forward(self, directions: torch.Tensor, video_ids: torch.Tensor, appearance_embedding: Optional[torch.Tensor], 70 | output_type: Optional[str], filter_features: bool) \ 71 | -> Dict[Union[FieldHeadNames, SUDSFieldHeadNames], torch.Tensor]: 72 | embedding = self.encoding(torch.cat([directions, video_ids], -1)) 73 | 74 | outputs = {} 75 | 76 | if output_type is None or RGB in output_type: 77 | outputs[FieldHeadNames.RGB] = self.mlp_head(torch.cat([embedding, appearance_embedding], -1) \ 78 | if self.appearance_embedding_dim > 0 else embedding) \ 79 | * (1 + 2e-3) - 1e-3 80 | 81 | if self.feature_dim > 0 and (filter_features or (output_type is None or FEATURES in output_type)): 82 | features = self.mlp_feature(embedding) 83 | if 
self.feature_output_activation.casefold() == 'tanh': 84 | features = features * 1.1 85 | 86 | outputs[SUDSFieldHeadNames.FEATURES] = features 87 | 88 | return outputs 89 | -------------------------------------------------------------------------------- /suds/fields/sharded/sharded_dynamic_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from nerfstudio.cameras.rays import RaySamples 6 | from nerfstudio.field_components.field_heads import FieldHeadNames 7 | from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn 8 | from torch import nn 9 | from torchtyping import TensorType 10 | 11 | from suds.fields.dynamic_field import DynamicField 12 | from suds.fields.suds_field_head_names import SUDSFieldHeadNames 13 | from suds.suds_constants import BACKWARD_NEIGHBOR_TIME_DIFF, FORWARD_NEIGHBOR_TIME_DIFF, VIDEO_ID, OUTPUT_TYPE, RGB, \ 14 | FEATURES, STATIC_RGB, NO_ENV_MAP_RGB, BACKWARD_FLOW, FORWARD_FLOW 15 | 16 | 17 | class ShardedDynamicField(Field): 18 | 19 | def __init__( 20 | self, 21 | centroids: torch.Tensor, 22 | origin: torch.Tensor, 23 | centroid_origins: torch.Tensor, 24 | scale: float, 25 | centroid_scales: List[float], 26 | delegates: List[DynamicField] 27 | ) -> None: 28 | super().__init__() 29 | 30 | self.register_buffer('centroids', centroids) 31 | self.register_buffer('origin', origin) 32 | self.register_buffer('centroid_origins', centroid_origins) 33 | 34 | self.scale = scale 35 | self.centroid_scales = centroid_scales 36 | self.delegates = nn.ModuleList(delegates) 37 | 38 | def get_density(self, ray_samples: RaySamples): 39 | if ray_samples.times is None: 40 | raise AttributeError('Times are not provided.') 41 | if VIDEO_ID not in ray_samples.metadata: 42 | raise AttributeError('Video ids are not provided.') 43 | 44 | positions = ray_samples.frustums.get_positions().view(-1, 3) 45 | times = ray_samples.times.reshape(-1, 1) 46 | video_ids = ray_samples.metadata[VIDEO_ID].reshape(-1, 1) 47 | base_input = torch.empty(positions.shape[0], 5, dtype=positions.dtype, device=positions.device) 48 | 49 | density = None 50 | base_mlp_out = None 51 | cluster_assignments = torch.cdist(positions, self.centroids).argmin(dim=1) 52 | for i, delegate in enumerate(self.delegates): 53 | cluster_mask = cluster_assignments == i 54 | 55 | if torch.any(cluster_mask): 56 | shifted_positions = (positions[cluster_mask].double() * self.scale + self.origin - 57 | self.centroid_origins[i]) / self.centroid_scales[i] 58 | del_input = torch.cat([shifted_positions.float(), times[cluster_mask], video_ids[cluster_mask]], -1) 59 | base_input[cluster_mask] = del_input 60 | 61 | h = delegate.mlp_base(delegate.encoding(del_input)) 62 | del_density_before_activation, del_base_mlp_out = torch.split(h, [1, delegate.geo_feat_dim], dim=-1) 63 | del_density = F.softplus(del_density_before_activation.to(ray_samples.frustums.directions) - 1) 64 | 65 | if density is None: 66 | density = torch.empty(ray_samples.frustums.starts.shape, dtype=del_density.dtype, 67 | device=del_density.device) 68 | base_mlp_out = torch.empty(*ray_samples.frustums.starts.shape[:-1], del_base_mlp_out.shape[-1], 69 | dtype=del_base_mlp_out.dtype, device=del_base_mlp_out.device) 70 | 71 | density.view(-1, 1)[cluster_mask] = del_density 72 | base_mlp_out.view(-1, base_mlp_out.shape[-1])[cluster_mask] = del_base_mlp_out 73 | 74 | return density, (base_mlp_out, base_input, cluster_assignments) 
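# NOTE on the sharding scheme used in get_density above (get_outputs below follows
# the same pattern): each sample is routed to the delegate field whose centroid is
# nearest (torch.cdist + argmin), and its position is re-expressed in that
# delegate's local coordinate frame before querying the delegate's hash grid,
# roughly:
#     local_p = (global_p * scale + origin - centroid_origins[i]) / centroid_scales[i]
# The per-delegate outputs are then scattered back into dense tensors through the
# boolean cluster masks, so callers see a single contiguous result.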
75 | 76 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Tuple[TensorType, TensorType, TensorType]) \ 77 | -> Dict[Union[FieldHeadNames, SUDSFieldHeadNames], TensorType]: 78 | density_embedding, base_input, cluster_assignments = density_embedding 79 | density_embedding = density_embedding.view(-1, self.delegates[0].geo_feat_dim) 80 | directions = ray_samples.frustums.directions 81 | 82 | outputs = {} 83 | 84 | if ray_samples.metadata[OUTPUT_TYPE] is None or RGB in ray_samples.metadata[OUTPUT_TYPE]: 85 | rgb_inputs = [] 86 | if self.delegates[0].num_directions > 0: 87 | # Using spherical harmonics - need to map directions to [0, 1] 88 | rgb_inputs.append(shift_directions_for_tcnn(directions).view(-1, 3)) 89 | 90 | rgb_inputs.append(density_embedding) 91 | rgb_inputs = torch.cat(rgb_inputs, -1) 92 | 93 | rgb = torch.empty_like(directions) 94 | for i, delegate in enumerate(self.delegates): 95 | cluster_mask = cluster_assignments == i 96 | 97 | if torch.any(cluster_mask): 98 | rgb.view(-1, 3)[cluster_mask] = delegate.mlp_head(rgb_inputs[cluster_mask]).to(directions) 99 | 100 | outputs[FieldHeadNames.RGB] = rgb * (1 + 2e-3) - 1e-3 101 | 102 | if self.delegates[0].feature_dim > 0 \ 103 | and (ray_samples.metadata[OUTPUT_TYPE] is None or FEATURES in ray_samples.metadata[OUTPUT_TYPE]): 104 | features = torch.empty(*directions.shape[:-1], self.delegates[0].feature_dim, dtype=directions.dtype, 105 | device=directions.device) 106 | for i, delegate in enumerate(self.delegates): 107 | cluster_mask = cluster_assignments == i 108 | 109 | if torch.any(cluster_mask): 110 | features.view(-1, delegate.feature_dim)[cluster_mask] = delegate.mlp_feature( 111 | delegate.encoding_feature(base_input[cluster_mask])).to(directions) 112 | 113 | if self.delegates[0].feature_output_activation.casefold() == 'tanh': 114 | features = features * 1.1 115 | 116 | outputs[SUDSFieldHeadNames.FEATURES] = features 117 | 118 | if self.delegates[0].predict_shadow and (ray_samples.metadata[OUTPUT_TYPE] is None 119 | or ray_samples.metadata[OUTPUT_TYPE] in {RGB, STATIC_RGB, 120 | NO_ENV_MAP_RGB}): 121 | shadows = torch.empty(*directions.shape[:-1], 1, dtype=density_embedding.dtype, device=directions.device) 122 | for i, delegate in enumerate(self.delegates): 123 | cluster_mask = cluster_assignments == i 124 | 125 | if torch.any(cluster_mask): 126 | shadows.view(-1, 1)[cluster_mask] = delegate.mlp_shadow(density_embedding[cluster_mask]).to( 127 | directions) 128 | 129 | outputs[SUDSFieldHeadNames.SHADOWS] = shadows 130 | 131 | if self.delegates[0].predict_flow and (ray_samples.metadata[OUTPUT_TYPE] is None 132 | or ray_samples.metadata[OUTPUT_TYPE] in {BACKWARD_FLOW, FORWARD_FLOW}): 133 | flow = torch.empty(*directions.shape[:-1], 6, dtype=density_embedding.dtype, device=directions.device) 134 | for i, delegate in enumerate(self.delegates): 135 | cluster_mask = cluster_assignments == i 136 | 137 | if torch.any(cluster_mask): 138 | flow.view(-1, 6)[cluster_mask] = delegate.mlp_flow( 139 | delegate.encoding_flow(base_input[cluster_mask])).to(directions) 140 | 141 | flow = torch.tanh(flow) 142 | 143 | backward_flow = flow[..., :3] 144 | forward_flow = flow[..., 3:] 145 | 146 | if BACKWARD_NEIGHBOR_TIME_DIFF in ray_samples.metadata: 147 | backward_time_diff = ray_samples.metadata[BACKWARD_NEIGHBOR_TIME_DIFF] 148 | backward_flow = backward_flow * backward_time_diff / self.delegates[0].flow_unit 149 | 150 | if FORWARD_NEIGHBOR_TIME_DIFF in ray_samples.metadata: 151 | forward_time_diff = 
ray_samples.metadata[FORWARD_NEIGHBOR_TIME_DIFF] 152 | forward_flow = forward_flow * forward_time_diff / self.delegates[0].flow_unit 153 | 154 | outputs[SUDSFieldHeadNames.BACKWARD_FLOW] = backward_flow 155 | outputs[SUDSFieldHeadNames.FORWARD_FLOW] = forward_flow 156 | 157 | return outputs 158 | -------------------------------------------------------------------------------- /suds/fields/sharded/sharded_dynamic_proposal_field.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Proposal network field. 17 | """ 18 | 19 | from typing import Optional, List 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | from nerfstudio.cameras.rays import RaySamples, Frustums 24 | from nerfstudio.fields.base_field import Field 25 | from torch import nn 26 | from torchtyping import TensorType 27 | 28 | from suds.fields.dynamic_proposal_field import DynamicProposalField 29 | from suds.suds_constants import VIDEO_ID 30 | 31 | 32 | class ShardedDynamicProposalField(Field): 33 | def __init__( 34 | self, 35 | centroids: torch.Tensor, 36 | origin: torch.Tensor, 37 | centroid_origins: torch.Tensor, 38 | scale: float, 39 | centroid_scales: List[float], 40 | delegates: List[DynamicProposalField] 41 | ) -> None: 42 | super().__init__() 43 | 44 | self.register_buffer('centroids', centroids) 45 | self.register_buffer('origin', origin) 46 | self.register_buffer('centroid_origins', centroid_origins) 47 | 48 | self.scale = scale 49 | self.centroid_scales = centroid_scales 50 | self.delegates = nn.ModuleList(delegates) 51 | 52 | def density_fn(self, positions: TensorType["bs":..., 3], times: TensorType, video_ids: TensorType) -> \ 53 | TensorType["bs":..., 1]: 54 | ray_samples = RaySamples( 55 | frustums=Frustums( 56 | origins=positions, 57 | directions=torch.ones_like(positions), 58 | starts=torch.zeros_like(positions[..., :1]), 59 | ends=torch.zeros_like(positions[..., :1]), 60 | pixel_area=torch.ones_like(positions[..., :1]), 61 | ), 62 | times=times.unsqueeze(-2).expand(*times.shape[:-1], positions.shape[-2], -1), 63 | metadata={VIDEO_ID: video_ids.unsqueeze(-2).expand(*video_ids.shape[:-1], positions.shape[-2], -1)} 64 | ) 65 | density, _ = self.get_density(ray_samples) 66 | return density 67 | 68 | def get_density(self, ray_samples: RaySamples): 69 | if ray_samples.times is None: 70 | raise AttributeError('Times are not provided.') 71 | if VIDEO_ID not in ray_samples.metadata: 72 | raise AttributeError('Video ids are not provided.') 73 | 74 | positions = ray_samples.frustums.get_positions().view(-1, 3) 75 | times = ray_samples.times.reshape(-1, 1) 76 | video_ids = ray_samples.metadata[VIDEO_ID].reshape(-1, 1) 77 | 78 | density = None 79 | cluster_assignments = torch.cdist(positions, self.centroids).argmin(dim=1) 80 | for i, delegate in enumerate(self.delegates): 81 | cluster_mask = cluster_assignments == i 
82 | 83 | if torch.any(cluster_mask): 84 | shifted_positions = (positions[cluster_mask].double() * self.scale + self.origin - 85 | self.centroid_origins[i]) / self.centroid_scales[i] 86 | del_input = torch.cat([shifted_positions.float(), times[cluster_mask], video_ids[cluster_mask]], -1) 87 | 88 | del_density_before_activation = delegate.mlp_base(delegate.encoding(del_input)) 89 | del_density = F.softplus(del_density_before_activation.to(ray_samples.frustums.directions) - 1) 90 | 91 | if density is None: 92 | density = torch.empty(ray_samples.frustums.starts.shape, dtype=del_density.dtype, 93 | device=del_density.device) 94 | 95 | density.view(-1, 1)[cluster_mask] = del_density 96 | 97 | return density, None 98 | 99 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None): 100 | return {} 101 | -------------------------------------------------------------------------------- /suds/fields/sharded/sharded_static_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from nerfstudio.cameras.rays import RaySamples 6 | from nerfstudio.field_components.field_heads import FieldHeadNames 7 | from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn 8 | from torch import nn 9 | from torchtyping import TensorType 10 | 11 | from suds.fields.static_field import StaticField 12 | from suds.fields.suds_field_head_names import SUDSFieldHeadNames 13 | from suds.suds_constants import APPEARANCE_EMBEDDING, OUTPUT_TYPE, RGB, FEATURES 14 | 15 | 16 | class ShardedStaticField(Field): 17 | 18 | def __init__( 19 | self, 20 | centroids: torch.Tensor, 21 | origin: torch.Tensor, 22 | centroid_origins: torch.Tensor, 23 | scale: float, 24 | centroid_scales: List[float], 25 | delegates: List[StaticField] 26 | ) -> None: 27 | super().__init__() 28 | 29 | self.register_buffer('centroids', centroids) 30 | self.register_buffer('origin', origin) 31 | self.register_buffer('centroid_origins', centroid_origins) 32 | 33 | self.scale = scale 34 | self.centroid_scales = centroid_scales 35 | self.delegates = nn.ModuleList(delegates) 36 | 37 | def get_density(self, ray_samples: RaySamples): 38 | positions = ray_samples.frustums.get_positions().view(-1, 3) 39 | base_input = torch.empty_like(positions) 40 | 41 | density = None 42 | base_mlp_out = None 43 | cluster_assignments = torch.cdist(positions, self.centroids).argmin(dim=1) 44 | for i, delegate in enumerate(self.delegates): 45 | cluster_mask = cluster_assignments == i 46 | 47 | if torch.any(cluster_mask): 48 | shifted_positions = ((positions[cluster_mask].double() * self.scale + self.origin - 49 | self.centroid_origins[i]) / self.centroid_scales[i]).float() 50 | base_input[cluster_mask] = shifted_positions 51 | 52 | h = delegate.mlp_base(delegate.encoding(shifted_positions)) 53 | del_density_before_activation, del_base_mlp_out = torch.split(h, [1, delegate.geo_feat_dim], dim=-1) 54 | del_density = F.softplus(del_density_before_activation.to(ray_samples.frustums.directions) - 1) 55 | 56 | if density is None: 57 | density = torch.empty(ray_samples.frustums.starts.shape, dtype=del_density.dtype, 58 | device=del_density.device) 59 | base_mlp_out = torch.empty(*ray_samples.frustums.starts.shape[:-1], del_base_mlp_out.shape[-1], 60 | dtype=del_base_mlp_out.dtype, device=del_base_mlp_out.device) 61 | 62 | density.view(-1, 1)[cluster_mask] = del_density 63 | base_mlp_out.view(-1, 
base_mlp_out.shape[-1])[cluster_mask] = del_base_mlp_out 64 | 65 | return density, (base_mlp_out, base_input, cluster_assignments) 66 | 67 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Tuple[TensorType, TensorType, TensorType]) \ 68 | -> Dict[Union[FieldHeadNames, SUDSFieldHeadNames], TensorType]: 69 | density_embedding, base_input, cluster_assignments = density_embedding 70 | density_embedding = density_embedding.view(-1, self.delegates[0].geo_feat_dim) 71 | directions = ray_samples.frustums.directions 72 | 73 | outputs = {} 74 | 75 | if ray_samples.metadata[OUTPUT_TYPE] is None or RGB in ray_samples.metadata[OUTPUT_TYPE]: 76 | rgb_inputs = [] 77 | if self.delegates[0].num_directions > 0: 78 | # Using spherical harmonics - need to map directions to [0, 1] 79 | rgb_inputs.append(shift_directions_for_tcnn(directions).view(-1, 3)) 80 | 81 | rgb_inputs.append(density_embedding) 82 | if self.delegates[0].appearance_embedding_dim > 0: 83 | rgb_inputs.append( 84 | ray_samples.metadata[APPEARANCE_EMBEDDING].reshape(-1, self.delegates[0].appearance_embedding_dim)) 85 | 86 | rgb_inputs = torch.cat(rgb_inputs, -1) 87 | 88 | rgb = torch.empty_like(directions) 89 | for i, delegate in enumerate(self.delegates): 90 | cluster_mask = cluster_assignments == i 91 | 92 | if torch.any(cluster_mask): 93 | rgb.view(-1, 3)[cluster_mask] = delegate.mlp_head(rgb_inputs[cluster_mask]).to(directions) 94 | 95 | outputs[FieldHeadNames.RGB] = rgb * (1 + 2e-3) - 1e-3 96 | 97 | if self.delegates[0].feature_dim > 0 \ 98 | and (ray_samples.metadata[OUTPUT_TYPE] is None or FEATURES in ray_samples.metadata[OUTPUT_TYPE]): 99 | features = torch.empty(*directions.shape[:-1], self.delegates[0].feature_dim, dtype=directions.dtype, 100 | device=directions.device) 101 | for i, delegate in enumerate(self.delegates): 102 | cluster_mask = cluster_assignments == i 103 | 104 | if torch.any(cluster_mask): 105 | features.view(-1, delegate.feature_dim)[cluster_mask] = delegate.mlp_feature( 106 | delegate.encoding_feature(base_input[cluster_mask])).to(directions) 107 | 108 | if self.delegates[0].feature_output_activation.casefold() == 'tanh': 109 | features = features * 1.1 110 | 111 | outputs[SUDSFieldHeadNames.FEATURES] = features 112 | 113 | return outputs 114 | -------------------------------------------------------------------------------- /suds/fields/sharded/sharded_static_proposal_field.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from nerfstudio.cameras.rays import RaySamples 6 | from nerfstudio.fields.base_field import Field 7 | from torch import nn 8 | from torchtyping import TensorType 9 | 10 | from suds.fields.static_proposal_field import StaticProposalField 11 | 12 | 13 | class ShardedStaticProposalField(Field): 14 | 15 | def __init__( 16 | self, 17 | centroids: torch.Tensor, 18 | origin: torch.Tensor, 19 | centroid_origins: torch.Tensor, 20 | scale: float, 21 | centroid_scales: List[float], 22 | delegates: List[StaticProposalField] 23 | ) -> None: 24 | super().__init__() 25 | 26 | self.register_buffer('centroids', centroids) 27 | self.register_buffer('origin', origin) 28 | self.register_buffer('centroid_origins', centroid_origins) 29 | 30 | self.scale = scale 31 | self.centroid_scales = centroid_scales 32 | self.delegates = nn.ModuleList(delegates) 33 | 34 | def get_density(self, ray_samples: RaySamples): 35 | positions = ray_samples.frustums.get_positions().view(-1, 
3) 36 | 37 | density = None 38 | cluster_assignments = torch.cdist(positions, self.centroids).argmin(dim=1) 39 | for i, delegate in enumerate(self.delegates): 40 | cluster_mask = cluster_assignments == i 41 | 42 | if torch.any(cluster_mask): 43 | shifted_positions = ((positions[cluster_mask].double() * self.scale + self.origin - 44 | self.centroid_origins[i]) / self.centroid_scales[i]).float() 45 | 46 | del_density_before_activation = delegate.mlp_base(delegate.encoding(shifted_positions)) 47 | del_density = F.softplus(del_density_before_activation.to(ray_samples.frustums.directions) - 1) 48 | 49 | if density is None: 50 | density = torch.empty(ray_samples.frustums.starts.shape, dtype=del_density.dtype, 51 | device=del_density.device) 52 | 53 | density.view(-1, 1)[cluster_mask] = del_density 54 | 55 | return density, None 56 | 57 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None): 58 | return {} 59 | -------------------------------------------------------------------------------- /suds/fields/static_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple 2 | 3 | import tinycudann as tcnn 4 | import torch 5 | import torch.nn.functional as F 6 | from nerfstudio.cameras.rays import RaySamples 7 | from nerfstudio.field_components.field_heads import FieldHeadNames 8 | from nerfstudio.fields.base_field import Field, shift_directions_for_tcnn 9 | from torchtyping import TensorType 10 | 11 | from suds.fields.suds_field_head_names import SUDSFieldHeadNames 12 | from suds.suds_constants import APPEARANCE_EMBEDDING, OUTPUT_TYPE, RGB, FEATURES, FILTER_FEATURES 13 | 14 | 15 | class StaticField(Field): 16 | 17 | def __init__( 18 | self, 19 | num_layers: int = 2, 20 | hidden_dim: int = 64, 21 | geo_feat_dim: int = 15, 22 | num_layers_color: int = 3, 23 | hidden_dim_color: int = 64, 24 | appearance_embedding_dim: int = 32, 25 | num_levels: int = 16, 26 | features_per_level: int = 2, 27 | log2_hashmap_size: int = 19, 28 | base_resolution: int = 16, 29 | num_directions: int = 4, 30 | feature_dim: int = 64, 31 | num_layers_feature: int = 3, 32 | hidden_dim_feature: int = 64, 33 | network_activation: str = 'ReLU', 34 | feature_output_activation: str = 'Tanh', 35 | ) -> None: 36 | super().__init__() 37 | 38 | self.geo_feat_dim = geo_feat_dim 39 | self.appearance_embedding_dim = appearance_embedding_dim 40 | self.feature_output_activation = feature_output_activation 41 | 42 | self.encoding = tcnn.Encoding( 43 | n_input_dims=3, 44 | encoding_config={ 45 | 'otype': 'HashGrid', 46 | 'n_levels': num_levels, 47 | 'n_features_per_level': features_per_level, 48 | 'log2_hashmap_size': log2_hashmap_size, 49 | 'base_resolution': base_resolution, 50 | } 51 | ) 52 | 53 | self.mlp_base = tcnn.Network( 54 | n_input_dims=features_per_level * num_levels, 55 | n_output_dims=1 + self.geo_feat_dim, 56 | network_config={ 57 | 'otype': 'FullyFusedMLP', 58 | 'activation': network_activation, 59 | 'output_activation': 'None', 60 | 'n_neurons': hidden_dim, 61 | 'n_hidden_layers': num_layers - 1, 62 | }, 63 | ) 64 | 65 | color_network_config = { 66 | 'otype': 'FullyFusedMLP', 67 | 'activation': network_activation, 68 | 'output_activation': 'Sigmoid', 69 | 'n_neurons': hidden_dim_color, 70 | 'n_hidden_layers': num_layers_color - 1, 71 | } 72 | 73 | self.num_directions = num_directions 74 | if num_directions > 0: 75 | dir_encoding = { 76 | 'n_dims_to_encode': 3, 77 | 'otype': 'SphericalHarmonics', 78 | 'degree': 
num_directions 79 | } 80 | 81 | self.mlp_head = tcnn.NetworkWithInputEncoding( 82 | n_input_dims=3 + geo_feat_dim + appearance_embedding_dim, 83 | n_output_dims=3, 84 | encoding_config={ 85 | 'otype': 'Composite', 86 | 'nested': [ 87 | dir_encoding, 88 | { 89 | 'n_dims_to_encode': geo_feat_dim + appearance_embedding_dim, 90 | 'otype': 'Identity' 91 | } 92 | ] 93 | }, 94 | network_config=color_network_config, 95 | ) 96 | else: 97 | self.mlp_head = tcnn.Network( 98 | n_input_dims=geo_feat_dim + appearance_embedding_dim, 99 | n_output_dims=3, 100 | network_config=color_network_config, 101 | ) 102 | 103 | self.feature_dim = feature_dim 104 | if feature_dim > 0: 105 | self.encoding_feature = tcnn.Encoding( 106 | n_input_dims=3, 107 | encoding_config={ 108 | 'otype': 'HashGrid', 109 | 'n_levels': num_levels, 110 | 'n_features_per_level': features_per_level, 111 | 'log2_hashmap_size': log2_hashmap_size, 112 | 'base_resolution': base_resolution, 113 | }, 114 | ) 115 | 116 | self.mlp_feature = tcnn.Network( 117 | n_input_dims=(features_per_level * num_levels), 118 | n_output_dims=feature_dim, 119 | network_config={ 120 | 'otype': 'FullyFusedMLP', 121 | 'activation': network_activation, 122 | 'output_activation': self.feature_output_activation, 123 | 'n_neurons': hidden_dim_feature, 124 | 'n_hidden_layers': num_layers_feature - 1, 125 | }, 126 | ) 127 | 128 | def get_density(self, ray_samples: RaySamples): 129 | positions = ray_samples.frustums.get_positions() 130 | positions_flat = positions.view(-1, 3) 131 | 132 | h = self.mlp_base(self.encoding(positions_flat)).view(*ray_samples.frustums.shape, -1) 133 | density_before_activation, base_mlp_out = torch.split(h, [1, self.geo_feat_dim], dim=-1) 134 | 135 | density = F.softplus(density_before_activation.to(positions) - 1) 136 | return density, (base_mlp_out, positions_flat) 137 | 138 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Tuple[TensorType, TensorType]) \ 139 | -> Dict[Union[FieldHeadNames, SUDSFieldHeadNames], TensorType]: 140 | density_embedding, base_input = density_embedding 141 | density_embedding = density_embedding.view(-1, self.geo_feat_dim) 142 | directions = ray_samples.frustums.directions 143 | 144 | outputs = {} 145 | 146 | if ray_samples.metadata[OUTPUT_TYPE] is None or RGB in ray_samples.metadata[OUTPUT_TYPE]: 147 | rgb_inputs = [] 148 | if self.num_directions > 0: 149 | rgb_inputs.append(shift_directions_for_tcnn(directions).view(-1, 3)) 150 | 151 | rgb_inputs.append(density_embedding) 152 | 153 | if self.appearance_embedding_dim > 0: 154 | rgb_inputs.append(ray_samples.metadata[APPEARANCE_EMBEDDING].reshape(-1, self.appearance_embedding_dim)) 155 | 156 | rgb = self.mlp_head(torch.cat(rgb_inputs, -1)).view(*directions.shape[:-1], -1).to(directions) 157 | outputs[FieldHeadNames.RGB] = rgb * (1 + 2e-3) - 1e-3 158 | 159 | if self.feature_dim > 0 \ 160 | and (ray_samples.metadata[FILTER_FEATURES] 161 | or (ray_samples.metadata[OUTPUT_TYPE] is None or FEATURES in ray_samples.metadata[OUTPUT_TYPE])): 162 | features = self.mlp_feature(self.encoding_feature(base_input)).view(*directions.shape[:-1], -1).to( 163 | directions) 164 | 165 | if self.feature_output_activation.casefold() == 'Tanh': 166 | features = features * 1.1 167 | 168 | outputs[SUDSFieldHeadNames.FEATURES] = features 169 | 170 | return outputs 171 | -------------------------------------------------------------------------------- /suds/fields/static_proposal_field.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Proposal network field. 17 | """ 18 | 19 | from typing import Optional 20 | 21 | import numpy as np 22 | import tinycudann as tcnn 23 | import torch.nn.functional as F 24 | from nerfstudio.cameras.rays import RaySamples 25 | from nerfstudio.fields.base_field import Field 26 | from torchtyping import TensorType 27 | 28 | 29 | class StaticProposalField(Field): 30 | """A lightweight density field module. 31 | 32 | Args: 33 | aabb: parameters of scene aabb bounds 34 | num_layers: number of hidden layers 35 | hidden_dim: dimension of hidden layers 36 | spatial_distortion: spatial distortion module 37 | use_linear: whether to skip the MLP and use a single linear layer instead 38 | """ 39 | 40 | def __init__( 41 | self, 42 | num_layers: int = 2, 43 | hidden_dim: int = 16, 44 | num_levels: int = 5, 45 | base_resolution: int = 16, 46 | max_resolution: int = 256, 47 | log2_hashmap_size: int = 18, 48 | features_per_level: int = 2, 49 | network_activation: str = 'ReLU' 50 | ) -> None: 51 | super().__init__() 52 | 53 | growth_factor = np.exp((np.log(max_resolution) - np.log(base_resolution)) / (num_levels - 1)) 54 | 55 | self.encoding = tcnn.Encoding( 56 | n_input_dims=3, 57 | encoding_config={ 58 | 'otype': 'HashGrid', 59 | 'n_levels': num_levels, 60 | 'n_features_per_level': features_per_level, 61 | 'log2_hashmap_size': log2_hashmap_size, 62 | 'base_resolution': base_resolution, 63 | 'per_level_scale': growth_factor 64 | } 65 | ) 66 | 67 | self.mlp_base = tcnn.Network( 68 | n_input_dims=num_levels * features_per_level, 69 | n_output_dims=1, 70 | network_config={ 71 | 'otype': 'FullyFusedMLP', 72 | 'activation': network_activation, 73 | 'output_activation': 'None', 74 | 'n_neurons': hidden_dim, 75 | 'n_hidden_layers': num_layers - 1, 76 | } 77 | ) 78 | 79 | def get_density(self, ray_samples: RaySamples): 80 | positions = ray_samples.frustums.get_positions() 81 | density_before_activation = self.mlp_base(self.encoding(positions.view(-1, 3))).view( 82 | *ray_samples.frustums.shape, -1) 83 | 84 | density = F.softplus(density_before_activation.to(positions) - 1) 85 | return density, None 86 | 87 | def get_outputs(self, ray_samples: RaySamples, density_embedding: Optional[TensorType] = None): 88 | return {} 89 | -------------------------------------------------------------------------------- /suds/fields/suds_field_head_names.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class SUDSFieldHeadNames(Enum): 5 | FEATURES = 'features' 6 | SHADOWS = 'shadows' 7 | BACKWARD_FLOW = 'backward_flow' 8 | FORWARD_FLOW = 'forward_flow' 9 | DYNAMIC_WEIGHT = 'dynamic_weight' 10 | 11 | BACKWARD_RGB = 'backward_rgb' 12 | BACKWARD_DENSITY = 'backward_density' 13 | BACKWARD_FEATURES = 'backward_features' 14 | BACKWARD_FLOW_CYCLE_DIFF = 'backward_flow_cycle_diff' 15 | BACKWARD_DYNAMIC_WEIGHT = 
'backward_dynamic_weight' 16 | 17 | FORWARD_RGB = 'forward_rgb' 18 | FORWARD_DENSITY = 'forward_density' 19 | FORWARD_FEATURES = 'forward_features' 20 | FORWARD_FLOW_CYCLE_DIFF = 'forward_flow_cycle_diff' 21 | FORWARD_DYNAMIC_WEIGHT = 'forward_dynamic_weight' 22 | 23 | FLOW_SLOW = 'flow_slow' 24 | FLOW_SMOOTH_TEMPORAL = 'flow_smooth_temporal' 25 | 26 | NO_SHADOW_RGB = 'no_shadow_rgb' 27 | -------------------------------------------------------------------------------- /suds/fields/video_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import suds_cuda 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class VideoEmbeddingFunction(torch.autograd.Function): 9 | 10 | @staticmethod 11 | def forward(ctx: Any, times: torch.Tensor, video_ids: torch.Tensor, weights: torch.Tensor, 12 | num_frequencies: int) -> torch.Tensor: 13 | embeddings = suds_cuda.video_embedding_forward(times, video_ids, weights, num_frequencies) 14 | ctx.save_for_backward(times, video_ids, torch.IntTensor([weights.shape[0], num_frequencies])) 15 | return embeddings 16 | 17 | @staticmethod 18 | def backward(ctx: Any, d_loss_embedding: torch.Tensor): 19 | times, video_ids, num_sequences_and_frequencies = ctx.saved_tensors 20 | d_loss_weights = suds_cuda.video_embedding_backward(d_loss_embedding.contiguous(), times, video_ids, 21 | num_sequences_and_frequencies[0].item(), 22 | num_sequences_and_frequencies[1].item()) 23 | 24 | return None, None, d_loss_weights, None 25 | 26 | 27 | class VideoEmbedding(nn.Module): 28 | 29 | def __init__(self, num_videos: int, num_frequencies: int, embedding_dim: int): 30 | super(VideoEmbedding, self).__init__() 31 | 32 | self.num_frequencies = num_frequencies 33 | self.sequence_code_weights = nn.Parameter( 34 | torch.empty(size=(num_videos, embedding_dim, num_frequencies * 2 + 1), dtype=torch.float32), 35 | requires_grad=True) 36 | torch.nn.init.normal_(self.sequence_code_weights) 37 | 38 | def forward(self, times: torch.Tensor, video_ids: torch.Tensor) -> torch.Tensor: 39 | return VideoEmbeddingFunction.apply(times, video_ids, self.sequence_code_weights, self.num_frequencies) 40 | -------------------------------------------------------------------------------- /suds/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.spatial import Delaunay 3 | 4 | 5 | # Modified from https://github.com/pytorch/pytorch/issues/50339#issuecomment-1339910414 6 | class GridInterpolator: 7 | def __init__(self, height: int, width: int) -> None: 8 | self.height = height 9 | self.width = width 10 | 11 | self.Y_grid, self.X_grid = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij') 12 | 13 | self.X_grid = self.X_grid.reshape(-1) 14 | self.Y_grid = self.Y_grid.reshape(-1) 15 | 16 | def forward(self, X: torch.Tensor, Y: torch.Tensor, value: torch.Tensor) -> torch.Tensor: 17 | # Tesselate grid points 18 | pos = torch.stack([X, Y], dim=-1).cpu().numpy() 19 | tri = Delaunay(pos, furthest_site=False) 20 | 21 | # Find the corners of each simplice 22 | corners_X = X[tri.simplices] 23 | corners_Y = Y[tri.simplices] 24 | corners_F = value[tri.simplices] 25 | 26 | # Find simplice ID for each query pixel in the original grid 27 | pos_orig = torch.stack([self.X_grid, self.Y_grid], dim=-1).numpy() 28 | simplice_id = tri.find_simplex(pos_orig) 29 | 30 | # Find X,Y,F values of the 3 nearest grid points for each 31 | # pixel in the original grid 32 | 
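# NOTE: lambda1/lambda2/lambda3 computed below are the barycentric coordinates of
# each query pixel with respect to the Delaunay triangle that contains it, so the
# interpolated value is the barycentric-weighted sum of the values at the three
# triangle corners; pixels falling outside the triangulation (simplex id == -1)
# are set to zero.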
corners_X_pq = corners_X[simplice_id] 33 | corners_Y_pq = corners_Y[simplice_id] 34 | corners_F_pq = corners_F[simplice_id] 35 | 36 | x1, y1 = corners_X_pq[:, 0], corners_Y_pq[:, 0] 37 | x2, y2 = corners_X_pq[:, 1], corners_Y_pq[:, 1] 38 | x3, y3 = corners_X_pq[:, 2], corners_Y_pq[:, 2] 39 | 40 | x_grid_gpu = self.X_grid.to(X) 41 | y_grid_gpu = self.Y_grid.to(X) 42 | lambda1 = ((y2 - y3) * (x_grid_gpu - x3) + (x3 - x2) * (y_grid_gpu - y3)) / \ 43 | ((y2 - y3) * (x1 - x3) + (x3 - x2) * (y1 - y3)) 44 | 45 | lambda2 = ((y3 - y1) * (x_grid_gpu - x3) + (x1 - x3) * (y_grid_gpu - y3)) / \ 46 | ((y2 - y3) * (x1 - x3) + (x3 - x2) * (y1 - y3)) 47 | 48 | lambda3 = 1 - lambda1 - lambda2 49 | 50 | out = lambda1 * corners_F_pq[:, 0] + lambda2 * corners_F_pq[:, 1] + lambda3 * corners_F_pq[:, 2] 51 | out[simplice_id == -1] = 0 52 | 53 | return out.reshape(self.height, self.width) 54 | -------------------------------------------------------------------------------- /suds/kmeans.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/subhadarship/kmeans_pytorch 2 | 3 | from functools import partial 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | def initialize(X, num_clusters, seed): 10 | """ 11 | initialize cluster centers 12 | :param X: (torch.tensor) matrix 13 | :param num_clusters: (int) number of clusters 14 | :param seed: (int) seed for kmeans 15 | :return: (np.array) initial state 16 | """ 17 | num_samples = len(X) 18 | if seed == None: 19 | indices = np.random.choice(num_samples, num_clusters, replace=False) 20 | else: 21 | np.random.seed(seed) ; indices = np.random.choice(num_samples, num_clusters, replace=False) 22 | initial_state = X[indices] 23 | return initial_state 24 | 25 | 26 | def kmeans( 27 | X, 28 | num_clusters, 29 | distance='euclidean', 30 | cluster_centers=[], 31 | tol=1e-4, 32 | tqdm_flag=True, 33 | iter_limit=0, 34 | device=torch.device('cpu'), 35 | seed=None, 36 | ): 37 | """ 38 | perform kmeans 39 | :param X: (torch.tensor) matrix 40 | :param num_clusters: (int) number of clusters 41 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 42 | :param seed: (int) seed for kmeans 43 | :param tol: (float) threshold [default: 0.0001] 44 | :param device: (torch.device) device [default: cpu] 45 | :param tqdm_flag: Allows to turn logs on and off 46 | :param iter_limit: hard limit for max number of iterations 47 | :return: (torch.tensor, torch.tensor) cluster ids, cluster centers 48 | """ 49 | if tqdm_flag: 50 | print(f'running k-means on {device}..') 51 | 52 | if distance == 'euclidean': 53 | pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag) 54 | elif distance == 'cosine': 55 | pairwise_distance_function = partial(pairwise_cosine, device=device) 56 | else: 57 | raise NotImplementedError 58 | 59 | # convert to float 60 | X = X.float() 61 | 62 | # transfer to device 63 | X = X.to(device) 64 | 65 | # initialize 66 | if type(cluster_centers) == list: # ToDo: make this less annoyingly weird 67 | initial_state = initialize(X, num_clusters, seed=seed) 68 | else: 69 | if tqdm_flag: 70 | print('resuming') 71 | # find data point closest to the initial cluster center 72 | initial_state = cluster_centers 73 | dis = pairwise_distance_function(X, initial_state) 74 | choice_points = torch.argmin(dis, dim=0) 75 | initial_state = X[choice_points] 76 | initial_state = initial_state.to(device) 77 | 78 | iteration = 0 79 | if tqdm_flag: 80 | 
tqdm_meter = tqdm(desc='[running kmeans]') 81 | while True: 82 | 83 | dis = pairwise_distance_function(X, initial_state) 84 | 85 | choice_cluster = torch.argmin(dis, dim=1) 86 | 87 | initial_state_pre = initial_state.clone() 88 | 89 | for index in range(num_clusters): 90 | selected = torch.nonzero(choice_cluster == index).squeeze().to(device) 91 | 92 | selected = torch.index_select(X, 0, selected) 93 | 94 | # https://github.com/subhadarship/kmeans_pytorch/issues/16 95 | if selected.shape[0] == 0: 96 | selected = X[torch.randint(len(X), (1,))] 97 | 98 | initial_state[index] = selected.mean(dim=0) 99 | 100 | center_shift = torch.sum( 101 | torch.sqrt( 102 | torch.sum((initial_state - initial_state_pre) ** 2, dim=1) 103 | )) 104 | 105 | # increment iteration 106 | iteration = iteration + 1 107 | 108 | # update tqdm meter 109 | if tqdm_flag: 110 | tqdm_meter.set_postfix( 111 | iteration=f'{iteration}', 112 | center_shift=f'{center_shift ** 2:0.6f}', 113 | tol=f'{tol:0.6f}' 114 | ) 115 | tqdm_meter.update() 116 | if center_shift ** 2 < tol: 117 | break 118 | if iter_limit != 0 and iteration >= iter_limit: 119 | break 120 | 121 | return choice_cluster.cpu(), initial_state.cpu() 122 | 123 | 124 | def kmeans_predict( 125 | X, 126 | cluster_centers, 127 | distance='euclidean', 128 | device=torch.device('cpu'), 129 | tqdm_flag=True 130 | ): 131 | """ 132 | predict using cluster centers 133 | :param X: (torch.tensor) matrix 134 | :param cluster_centers: (torch.tensor) cluster centers 135 | :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean'] 136 | :param device: (torch.device) device [default: 'cpu'] 137 | :return: (torch.tensor) cluster ids 138 | """ 139 | if tqdm_flag: 140 | print(f'predicting on {device}..') 141 | 142 | if distance == 'euclidean': 143 | pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag) 144 | elif distance == 'cosine': 145 | pairwise_distance_function = partial(pairwise_cosine, device=device) 146 | else: 147 | raise NotImplementedError 148 | 149 | # convert to float 150 | X = X.float() 151 | 152 | # transfer to device 153 | X = X.to(device) 154 | 155 | dis = pairwise_distance_function(X, cluster_centers) 156 | choice_cluster = torch.argmin(dis, dim=1) 157 | 158 | return choice_cluster.to(device) 159 | 160 | 161 | def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True): 162 | if tqdm_flag: 163 | print(f'device is :{device}') 164 | 165 | # transfer to device 166 | data1, data2 = data1.to(device), data2.to(device) 167 | 168 | # N*1*M 169 | A = data1.unsqueeze(dim=1) 170 | 171 | # 1*N*M 172 | B = data2.unsqueeze(dim=0) 173 | 174 | dis = (A - B) ** 2.0 175 | # return N*N matrix for pairwise distance 176 | dis = dis.sum(dim=-1).squeeze() 177 | return dis 178 | 179 | 180 | def pairwise_cosine(data1, data2, device=torch.device('cpu')): 181 | # transfer to device 182 | data1, data2 = data1.to(device), data2.to(device) 183 | 184 | # N*1*M 185 | A = data1.unsqueeze(dim=1) 186 | 187 | # 1*N*M 188 | B = data2.unsqueeze(dim=0) 189 | 190 | # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5] 191 | A_normalized = A / torch.linalg.norm(A, dim=-1, keepdim=True) 192 | B_normalized = B / torch.linalg.norm(B, dim=-1, keepdim=True) 193 | 194 | cosine = A_normalized * B_normalized 195 | 196 | # return N*N matrix for pairwise distance 197 | cosine_dis = 1 - cosine.sum(dim=-1).squeeze() 198 | return cosine_dis 
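# Usage sketch: a minimal, hypothetical example of fitting the k-means helpers
# above and reusing the learned centroids on new data. The feature dimensionality
# and cluster count are arbitrary assumptions made purely for illustration.
def _example_kmeans_usage():
    """Hypothetical demo of kmeans() followed by kmeans_predict()."""
    features = torch.randn(1000, 64)  # 1000 samples with 64-dim features
    cluster_ids, centroids = kmeans(
        features, num_clusters=8, distance='euclidean',
        device=torch.device('cpu'), seed=0, tqdm_flag=False)
    # Assign previously unseen samples to the nearest learned centroid
    new_ids = kmeans_predict(
        torch.randn(16, 64), centroids, distance='euclidean',
        device=torch.device('cpu'), tqdm_flag=False)
    return cluster_ids, centroids, new_ids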
-------------------------------------------------------------------------------- /suds/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | 28 | def ssim( 29 | target_rgbs: torch.Tensor, 30 | rgbs: torch.Tensor, 31 | max_val: float = 1, 32 | filter_size: int = 11, 33 | filter_sigma: float = 1.5, 34 | k1: float = 0.01, 35 | k2: float = 0.03, 36 | ) -> float: 37 | """Computes SSIM from two images. 38 | This function was modeled after tf.image.ssim, and should produce comparable 39 | output. 40 | Args: 41 | rgbs: torch.tensor. An image of size [..., width, height, num_channels]. 42 | target_rgbs: torch.tensor. An image of size [..., width, height, num_channels]. 43 | max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. 44 | filter_size: int >= 1. Window size. 45 | filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. 46 | k1: float > 0. One of the SSIM dampening parameters. 47 | k2: float > 0. One of the SSIM dampening parameters. 48 | Returns: 49 | Each image's mean SSIM. 50 | """ 51 | device = rgbs.device 52 | ori_shape = rgbs.size() 53 | width, height, num_channels = ori_shape[-3:] 54 | rgbs = rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 55 | target_rgbs = target_rgbs.view(-1, width, height, num_channels).permute(0, 3, 1, 2) 56 | 57 | # Construct a 1D Gaussian blur filter. 58 | hw = filter_size // 2 59 | shift = (2 * hw - filter_size + 1) / 2 60 | f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 61 | filt = torch.exp(-0.5 * f_i) 62 | filt /= torch.sum(filt) 63 | 64 | # Blur in x and y (faster than the 2D convolution). 
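# The two 1-D passes below implement a separable Gaussian blur: convolving with a
# [filter_size, 1] kernel and then a [1, filter_size] kernel matches the full 2-D
# Gaussian window while costing O(filter_size) per pixel per pass instead of
# O(filter_size ** 2).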
65 | # z is a tensor of size [B, H, W, C] 66 | filt_fn1 = lambda z: F.conv2d( 67 | z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), 68 | padding=[hw, 0], groups=num_channels) 69 | filt_fn2 = lambda z: F.conv2d( 70 | z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), 71 | padding=[0, hw], groups=num_channels) 72 | 73 | # Vmap the blurs to the tensor size, and then compose them. 74 | filt_fn = lambda z: filt_fn1(filt_fn2(z)) 75 | mu0 = filt_fn(rgbs) 76 | mu1 = filt_fn(target_rgbs) 77 | mu00 = mu0 * mu0 78 | mu11 = mu1 * mu1 79 | mu01 = mu0 * mu1 80 | sigma00 = filt_fn(rgbs ** 2) - mu00 81 | sigma11 = filt_fn(target_rgbs ** 2) - mu11 82 | sigma01 = filt_fn(rgbs * target_rgbs) - mu01 83 | 84 | # Clip the variances and covariances to valid values. 85 | # Variance must be non-negative: 86 | sigma00 = torch.clamp(sigma00, min=0.0) 87 | sigma11 = torch.clamp(sigma11, min=0.0) 88 | sigma01 = torch.sign(sigma01) * torch.min( 89 | torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) 90 | ) 91 | 92 | c1 = (k1 * max_val) ** 2 93 | c2 = (k2 * max_val) ** 2 94 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 95 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 96 | ssim_map = numer / denom 97 | 98 | return torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1).item() 99 | -------------------------------------------------------------------------------- /suds/render_images.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | eval.py 4 | """ 5 | from __future__ import annotations 6 | 7 | import datetime 8 | import os 9 | from dataclasses import dataclass 10 | from pathlib import Path 11 | from typing import Optional, List, Set 12 | 13 | import numpy as np 14 | import torch 15 | import torch.distributed as dist 16 | import tyro 17 | from PIL import Image 18 | from nerfstudio.utils.comms import get_world_size 19 | from nerfstudio.utils.eval_utils import eval_setup 20 | from rich.console import Console 21 | from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, MofNCompleteColumn 22 | 23 | from suds.suds_constants import RGB, DEPTH, FEATURES, VIDEO_ID 24 | 25 | CONSOLE = Console(width=120) 26 | 27 | 28 | @dataclass 29 | class RenderImages: 30 | # Path to config YAML file. 31 | load_config: Path 32 | # Name of the output file. 
33 | output_path: Path 34 | 35 | generate_ring_view: bool 36 | video_ids: Optional[Set[int]] = None 37 | start_frame: Optional[int] = None 38 | end_frame: Optional[int] = None 39 | 40 | focal_mult: Optional[float] = None 41 | pos_shift: Optional[List[float]] = None 42 | 43 | feature_filter: Optional[List[int]] = None 44 | sigma_threshold: Optional[float] = None 45 | max_altitude: Optional[float] = None 46 | static_only: bool = False 47 | 48 | @torch.inference_mode() 49 | def main(self) -> None: 50 | """Main function.""" 51 | config, pipeline, checkpoint_path = eval_setup(self.load_config) 52 | pipeline.eval() 53 | 54 | dataloader = pipeline.datamanager.all_indices_eval_dataloader(self.generate_ring_view, self.video_ids, 55 | self.start_frame, self.end_frame, self.focal_mult, 56 | torch.FloatTensor(self.pos_shift) 57 | if self.pos_shift is not None else None) 58 | num_images = len(dataloader) 59 | 60 | render_options = {'static_only': self.static_only} 61 | if self.sigma_threshold is not None: 62 | render_options['sigma_threshold'] = self.sigma_threshold 63 | if self.max_altitude is not None: 64 | render_options['max_altitude'] = self.max_altitude 65 | if self.feature_filter is not None: 66 | render_options['feature_filter'] = self.feature_filter 67 | 68 | with Progress( 69 | TextColumn("[progress.description]{task.description}"), 70 | BarColumn(), 71 | TimeElapsedColumn(), 72 | MofNCompleteColumn(), 73 | transient=True, 74 | ) as progress: 75 | task = progress.add_task("[green]Evaluating all eval images...", total=num_images) 76 | 77 | ring_buffer = [] 78 | for camera_ray_bundle, batch in dataloader: 79 | images = {} 80 | 81 | to_check = [RGB, DEPTH] 82 | for key in pipeline.model.config.feature_clusters: 83 | to_check.append(f'{FEATURES}_{key}') 84 | 85 | frame_id = int(camera_ray_bundle.camera_indices[0, 0, 0]) 86 | video_id = str(camera_ray_bundle.metadata[VIDEO_ID][0, 0, 0].item()) 87 | 88 | all_present = True 89 | for key in to_check: 90 | candidiate = (self.output_path / video_id / '{0}-{1:06d}.jpg'.format(key, frame_id)) 91 | if candidiate.exists(): 92 | images[key] = np.asarray(Image.open(candidiate)) 93 | else: 94 | all_present = False 95 | break 96 | 97 | if not all_present: 98 | outputs = pipeline.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle, 99 | render_options=render_options) 100 | 101 | images[RGB] = (outputs[RGB] * 255).byte().cpu().numpy() 102 | images[DEPTH] = (pipeline.model.apply_depth_colormap(outputs[DEPTH]) * 255).byte().cpu().numpy() 103 | for key in pipeline.model.config.feature_clusters: 104 | images[f'{FEATURES}_{key}'] = (outputs[f'{FEATURES}_{key}'] * 255).byte().cpu().numpy() 105 | 106 | (self.output_path / video_id).mkdir(parents=True, exist_ok=True) 107 | for key, val in images.items(): 108 | Image.fromarray(val).save( 109 | self.output_path / video_id / '{0}-{1:06d}.jpg'.format(key, frame_id)) 110 | 111 | if self.generate_ring_view: 112 | ring_buffer.append(images) 113 | 114 | if len(ring_buffer) == 7: 115 | merged_W = ring_buffer[1][RGB].shape[1] + \ 116 | ring_buffer[0][RGB].shape[1] + ring_buffer[2][RGB].shape[1] 117 | merged_H = max(ring_buffer[0][RGB].shape[0], ring_buffer[1][RGB].shape[0] + 118 | ring_buffer[5][RGB].shape[0] + ring_buffer[3][RGB].shape[0]) 119 | 120 | offsets = [ 121 | (ring_buffer[1][RGB].shape[1], 0), 122 | (0, 0), 123 | (ring_buffer[1][RGB].shape[1] + ring_buffer[0][RGB].shape[1], 0), 124 | (0, ring_buffer[1][RGB].shape[0] + ring_buffer[5][RGB].shape[0]), 125 | (ring_buffer[1][RGB].shape[1] + 
ring_buffer[0][RGB].shape[1], 126 | ring_buffer[1][RGB].shape[0] + ring_buffer[5][RGB].shape[0]), 127 | (0, ring_buffer[1][RGB].shape[0]), 128 | (ring_buffer[1][RGB].shape[1] + ring_buffer[0][RGB].shape[1], ring_buffer[1][RGB].shape[0]) 129 | ] 130 | 131 | merged_images = [] 132 | for key, val in ring_buffer[0].items(): 133 | merged = np.zeros((merged_H, merged_W, 3), dtype=np.uint8) 134 | for i, (offset_W, offset_H) in enumerate(offsets): 135 | image = ring_buffer[i][key] 136 | merged[offset_H:offset_H + image.shape[0], offset_W:offset_W + image.shape[1]] = image 137 | 138 | merged_images.append(merged) 139 | Image.fromarray(merged).save( 140 | self.output_path / video_id / 'merged-{0}-{1:06d}.jpg'.format(key, frame_id // 7)) 141 | 142 | Image.fromarray(np.concatenate(merged_images, 1)).save( 143 | self.output_path / video_id / 'merged-all-{0:06d}.jpg'.format(frame_id // 7)) 144 | 145 | ring_buffer = [] 146 | 147 | progress.advance(task) 148 | 149 | if get_world_size() > 1: 150 | dist.barrier() 151 | 152 | 153 | def entrypoint(): 154 | if 'RANK' in os.environ: 155 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(0, hours=24)) 156 | torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) 157 | 158 | """Entrypoint for use with pyproject scripts.""" 159 | tyro.extras.set_accent_color("bright_yellow") 160 | tyro.cli(RenderImages).main() 161 | 162 | 163 | if __name__ == "__main__": 164 | entrypoint() 165 | 166 | # For sphinx docs 167 | get_parser_fn = lambda: tyro.extras.get_parser(RenderImages) # noqa 168 | -------------------------------------------------------------------------------- /suds/sample_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # From RaySamplers.get_weights 5 | @torch.jit.script 6 | def _get_weights(deltas: torch.Tensor, density: torch.Tensor, filter_nan: bool) -> torch.Tensor: 7 | delta_density = deltas * density 8 | alphas = 1 - torch.exp(-delta_density) 9 | transmittance = torch.cumsum(delta_density[..., :-1, :], dim=-2) 10 | transmittance = torch.cat( 11 | [torch.zeros((transmittance.shape[0], 1, 1), device=density.device), 12 | transmittance], dim=-2 13 | ) 14 | transmittance = torch.exp(-transmittance) # [..., "num_samples"] 15 | weights = alphas * transmittance # [..., "num_samples"] 16 | 17 | if filter_nan: 18 | weights = torch.nan_to_num(weights) 19 | 20 | return weights 21 | -------------------------------------------------------------------------------- /suds/stream_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import traceback 3 | from io import BytesIO 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | from rich.console import Console 8 | 9 | try: 10 | import gcsfs 11 | except: 12 | pass 13 | 14 | try: 15 | import s3fs 16 | except: 17 | pass 18 | 19 | from PIL import Image 20 | from smart_open import open 21 | import pyarrow.parquet as pq 22 | 23 | NUM_RETRIES = 3 24 | 25 | CONSOLE = Console(width=120) 26 | 27 | 28 | def table_from_stream(path: str) -> Any: 29 | for i in range(NUM_RETRIES): 30 | try: 31 | return pq.read_table(path, filesystem=get_filesystem(path)) 32 | except Exception as e: 33 | CONSOLE.log('Download failed for {} (attempt {})'.format(path, i + 1)) 34 | if i == NUM_RETRIES - 1: 35 | raise e 36 | traceback.print_exc() 37 | time.sleep(10) 38 | 39 | 40 | def buffer_from_stream(path: str) -> BytesIO: 41 | for i in range(NUM_RETRIES): 42 | try: 43 | buffer = BytesIO() 44 | with 
open(path, 'rb') as f: 45 | buffer.write(f.read()) 46 | buffer.seek(0) 47 | return buffer 48 | except Exception as e: 49 | CONSOLE.log('Download failed for {} (attempt {})'.format(path, i + 1)) 50 | if i == NUM_RETRIES - 1: 51 | raise e 52 | traceback.print_exc() 53 | time.sleep(10) 54 | 55 | 56 | def image_from_stream(path: str) -> Image: 57 | return Image.open(buffer_from_stream(path)) 58 | 59 | 60 | def image_to_stream(img: Image, path: str) -> None: 61 | for i in range(NUM_RETRIES): 62 | try: 63 | extension = Path(path).suffix 64 | if extension == '.png': 65 | format = 'PNG' 66 | elif extension == '.jpg': 67 | format = 'JPEG' 68 | else: 69 | raise Exception(path) 70 | 71 | buffer = BytesIO() 72 | img.save(buffer, format=format) 73 | with open(path, 'wb') as f: 74 | f.write(buffer.getbuffer()) 75 | return 76 | except Exception as e: 77 | CONSOLE.log('Download failed for {} (attempt {})'.format(path, i + 1)) 78 | if i == NUM_RETRIES - 1: 79 | raise e 80 | traceback.print_exc() 81 | time.sleep(10) 82 | 83 | 84 | def get_filesystem(path: str) -> Any: 85 | if path.startswith('s3://'): 86 | return s3fs.S3FileSystem() 87 | elif path.startswith('gs://'): 88 | return gcsfs.GCSFileSystem() 89 | else: 90 | return None 91 | -------------------------------------------------------------------------------- /suds/suds_collider.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from nerfstudio.cameras.rays import RayBundle 5 | from nerfstudio.model_components.scene_colliders import NearFarCollider, _intersect_with_sphere 6 | 7 | from suds.suds_constants import BG_INTERSECTION 8 | 9 | 10 | class SUDSCollider(NearFarCollider): 11 | 12 | def __init__(self, 13 | near: float, 14 | far: float, 15 | scene_bounds: Optional[torch.Tensor], 16 | sphere_center: Optional[torch.Tensor], 17 | sphere_radius: Optional[torch.Tensor]) -> None: 18 | super().__init__(near, far if sphere_center is None else 1e10) 19 | self.far = far # we clamp to far after finding sphere intersections 20 | self.scene_bounds = scene_bounds 21 | self.sphere_center = sphere_center 22 | self.sphere_radius = sphere_radius 23 | 24 | def forward(self, ray_bundle: RayBundle) -> RayBundle: 25 | ray_bundle = super().forward(ray_bundle) 26 | 27 | if self.scene_bounds is not None: 28 | _truncate_with_plane_intersection(ray_bundle.origins, ray_bundle.directions, self.scene_bounds[1, 2], 29 | ray_bundle.nears) 30 | _truncate_with_plane_intersection(ray_bundle.origins, ray_bundle.directions, self.scene_bounds[0, 2], 31 | ray_bundle.fars) 32 | 33 | if self.sphere_center is not None: 34 | device = ray_bundle.origins.device 35 | rays_d, rays_o = _ellipse_to_sphere_coords(ray_bundle.origins.view(-1, 3), 36 | ray_bundle.directions.view(-1, 3), 37 | self.sphere_center, 38 | self.sphere_radius, 39 | device) 40 | 41 | _, sphere_fars = _intersect_with_sphere(rays_o, rays_d, torch.zeros(3, device=device)) 42 | ray_bundle.metadata[BG_INTERSECTION] = torch.zeros_like(sphere_fars) 43 | rays_with_bg = ray_bundle.fars > sphere_fars 44 | ray_bundle.metadata[BG_INTERSECTION][rays_with_bg] = sphere_fars[rays_with_bg] 45 | ray_bundle.fars = torch.minimum(ray_bundle.fars, sphere_fars) 46 | 47 | assert ray_bundle.nears.min() >= 0, ray_bundle.nears.min() 48 | assert ray_bundle.fars.min() >= 0, ray_bundle.fars.min() 49 | 50 | ray_bundle.nears = ray_bundle.nears.clamp_min(self.near_plane) 51 | ray_bundle.fars = ray_bundle.fars.clamp_min(ray_bundle.nears + 1e-6).clamp_max(self.far) 
52 | 53 | return ray_bundle 54 | 55 | 56 | @torch.jit.script 57 | def _ellipse_to_sphere_coords(rays_o: torch.Tensor, rays_d: torch.Tensor, sphere_center: torch.Tensor, 58 | sphere_radius: torch.Tensor, device: torch.device) \ 59 | -> Tuple[torch.Tensor, torch.Tensor]: 60 | sphere_radius = sphere_radius.to(device) 61 | rays_o = (rays_o - sphere_center.to(device)) / sphere_radius 62 | rays_d = rays_d / sphere_radius 63 | return rays_d, rays_o 64 | 65 | 66 | @torch.jit.script 67 | def _truncate_with_plane_intersection(rays_o: torch.Tensor, rays_d: torch.Tensor, altitude: float, 68 | default_bounds: torch.Tensor) -> None: 69 | starts_before = rays_o[..., 2] > altitude 70 | goes_down = rays_d[..., 2] < 0 71 | 72 | boundable_rays = torch.minimum(starts_before, goes_down) 73 | 74 | ray_points = rays_o[boundable_rays] 75 | ray_dirs = rays_d[boundable_rays] 76 | if ray_points.shape[0] == 0: 77 | return 78 | 79 | default_bounds[boundable_rays] = ((altitude - ray_points[..., 2]) / ray_dirs[..., 2]).unsqueeze(-1) 80 | 81 | assert torch.all(default_bounds[boundable_rays] > 0) 82 | -------------------------------------------------------------------------------- /suds/suds_constants.py: -------------------------------------------------------------------------------- 1 | IMAGE_INDEX = 'image_index' 2 | PIXEL_INDEX = 'pixel_index' 3 | RGB = 'rgb' 4 | DEPTH = 'depth' 5 | FEATURES = 'features' 6 | BACKWARD_FLOW = 'backward_flow' 7 | BACKWARD_FLOW_VALID = 'backward_flow_valid' 8 | FORWARD_FLOW = 'forward_flow' 9 | FORWARD_FLOW_VALID = 'forward_flow_valid' 10 | 11 | BACKWARD_NEIGHBOR_TIME_DIFF = 'backward_neighbor_time_diff' 12 | BACKWARD_NEIGHBOR_W2C = 'backward_neighbor_w2c' 13 | BACKWARD_NEIGHBOR_K = 'backward_neighbor_K' 14 | 15 | FORWARD_NEIGHBOR_TIME_DIFF = 'forward_neighbor_time_diff' 16 | FORWARD_NEIGHBOR_W2C = 'forward_neighbor_w2c' 17 | FORWARD_NEIGHBOR_K = 'forward_neighbor_K' 18 | 19 | RAY_INDEX = 'ray_index' 20 | TIME = 'time' 21 | VIDEO_ID = 'video_id' 22 | 23 | MASK = 'mask' 24 | 25 | APPEARANCE_EMBEDDING = 'appearance_embedding' 26 | BG_INTERSECTION = 'bg_intersection' 27 | 28 | SKY = 'sky' 29 | 30 | STATIC_RGB = f'{RGB}_static' 31 | NO_ENV_MAP_RGB = f'{RGB}_no_env_map' 32 | DYNAMIC_RGB = f'{RGB}_dynamic' 33 | NO_SHADOW_RGB = f'{RGB}_no_shadow' 34 | 35 | FILTER_FEATURES = 'filter_features' 36 | OUTPUT_TYPE = 'output_type' -------------------------------------------------------------------------------- /suds/suds_depth_renderer.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | 3 | import nerfacc 4 | import torch 5 | from torch import nn 6 | from torchtyping import TensorType 7 | 8 | 9 | class SUDSDepthRenderer(nn.Module): 10 | def __init__(self, method: Literal["median", "expected"] = "expected") -> None: 11 | super().__init__() 12 | self.method = method 13 | 14 | def forward( 15 | self, 16 | weights: TensorType[..., "num_samples", 1], 17 | z_vals: torch.Tensor, 18 | ray_indices: Optional[TensorType["num_samples"]] = None, 19 | num_rays: Optional[int] = None, 20 | ) -> TensorType[..., 1]: 21 | """Composite samples along ray and calculate depths. 22 | 23 | Args: 24 | weights: Weights for each sample. 25 | ray_samples: Set of ray samples. 26 | ray_indices: Ray index for each sample, used when samples are packed. 27 | num_rays: Number of rays, used when samples are packed. 28 | 29 | Returns: 30 | Outputs of depth values. 
31 | """ 32 | 33 | if self.method == "median": 34 | # steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2 35 | 36 | if ray_indices is not None and num_rays is not None: 37 | raise NotImplementedError("Median depth calculation is not implemented for packed samples.") 38 | cumulative_weights = torch.cumsum(weights[..., 0], dim=-1) # [..., num_samples] 39 | split = torch.ones((*weights.shape[:-2], 1), device=weights.device) * 0.5 # [..., 1] 40 | median_index = torch.searchsorted(cumulative_weights, split, side="left") # [..., 1] 41 | median_index = torch.clamp(median_index, 0, z_vals.shape[-2] - 1) # [..., 1] 42 | median_depth = torch.gather(z_vals[..., 0], dim=-1, index=median_index) # [..., 1] 43 | return median_depth 44 | if self.method == "expected": 45 | eps = 1e-10 46 | # steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2 47 | 48 | if ray_indices is not None and num_rays is not None: 49 | # Necessary for packed samples from volumetric ray sampler 50 | depth = nerfacc.accumulate_along_rays(weights, ray_indices, z_vals, num_rays) 51 | accumulation = nerfacc.accumulate_along_rays(weights, ray_indices, None, num_rays) 52 | depth = depth / (accumulation + eps) 53 | else: 54 | depth = torch.sum(weights * z_vals, dim=-2) / (torch.sum(weights, -2) + eps) 55 | 56 | depth = torch.clip(depth, z_vals.min(), z_vals.max()) 57 | 58 | return depth 59 | 60 | raise NotImplementedError(f"Method {self.method} not implemented") 61 | -------------------------------------------------------------------------------- /suds/train.py: -------------------------------------------------------------------------------- 1 | import faulthandler 2 | import signal 3 | 4 | import nerfstudio.configs.method_configs 5 | import nerfstudio.data.datamanagers.base_datamanager 6 | import tyro 7 | from nerfstudio.configs.config_utils import convert_markup_to_ansi 8 | from nerfstudio.configs.method_configs import descriptions, method_configs 9 | from nerfstudio.data.dataparsers.arkitscenes_dataparser import ARKitScenesDataParserConfig 10 | from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig 11 | from nerfstudio.data.dataparsers.dnerf_dataparser import DNeRFDataParserConfig 12 | from nerfstudio.data.dataparsers.dycheck_dataparser import DycheckDataParserConfig 13 | from nerfstudio.data.dataparsers.instant_ngp_dataparser import InstantNGPDataParserConfig 14 | from nerfstudio.data.dataparsers.minimal_dataparser import MinimalDataParserConfig 15 | from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig 16 | from nerfstudio.data.dataparsers.nuscenes_dataparser import NuScenesDataParserConfig 17 | from nerfstudio.data.dataparsers.phototourism_dataparser import PhototourismDataParserConfig 18 | from nerfstudio.data.dataparsers.scannet_dataparser import ScanNetDataParserConfig 19 | from nerfstudio.data.dataparsers.sdfstudio_dataparser import SDFStudioDataParserConfig 20 | from nerfstudio.data.dataparsers.sitcoms3d_dataparser import Sitcoms3DDataParserConfig 21 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 22 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 23 | from nerfstudio.engine.trainer import TrainerConfig 24 | from scripts.train import main 25 | 26 | from suds.data.suds_datamanager import SUDSDataManagerConfig 27 | from suds.data.suds_dataparser import SUDSDataParserConfig 28 | from suds.data.suds_pipeline import SUDSPipelineConfig 29 | from suds.suds_model import SUDSModelConfig 30 | 31 | 32 
| def suds_entrypoint(): 33 | faulthandler.register(signal.SIGUSR1) 34 | 35 | descriptions['suds'] = 'Scalable Urban Dynamic Scenes' 36 | 37 | method_configs['suds'] = TrainerConfig( 38 | method_name='suds', 39 | steps_per_eval_batch=500, 40 | steps_per_save=10000, 41 | max_num_iterations=250001, 42 | mixed_precision=True, 43 | steps_per_eval_all_images=250000, 44 | steps_per_eval_image=1000, 45 | log_gradients=True, 46 | pipeline=SUDSPipelineConfig( 47 | datamanager=SUDSDataManagerConfig( 48 | dataparser=SUDSDataParserConfig(), 49 | ), 50 | model=SUDSModelConfig() 51 | ), 52 | optimizers={ 53 | 'fields': { 54 | 'optimizer': AdamOptimizerConfig(lr=5e-3, eps=1e-8), 55 | 'scheduler': ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=250000), 56 | }, 57 | 'mlps': { 58 | 'optimizer': AdamOptimizerConfig(lr=5e-3, eps=1e-8, weight_decay=1e-8), 59 | 'scheduler': ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=250000), 60 | }, 61 | 'flow': { 62 | 'optimizer': AdamOptimizerConfig(lr=1e-4, eps=1e-8), 63 | 'scheduler': ExponentialDecaySchedulerConfig(lr_final=1e-5, max_steps=250000), 64 | }, 65 | 'flow_mlp': { 66 | 'optimizer': AdamOptimizerConfig(lr=1e-4, eps=1e-8, weight_decay=1e-8), 67 | 'scheduler': ExponentialDecaySchedulerConfig(lr_final=1e-5, max_steps=250000), 68 | } 69 | } 70 | ) 71 | 72 | AnnotatedBaseConfigUnion = tyro.conf.SuppressFixed[ 73 | # Don't show unparseable (fixed) arguments in helptext. 74 | tyro.conf.FlagConversionOff[ 75 | tyro.extras.subcommand_type_from_defaults(defaults=method_configs, descriptions=descriptions) 76 | ] 77 | ] 78 | 79 | nerfstudio.data.datamanagers.base_datamanager.AnnotatedDataParserUnion = tyro.conf.OmitSubcommandPrefixes[ 80 | # Omit prefixes of flags in subcommands. 81 | tyro.extras.subcommand_type_from_defaults( 82 | { 83 | "nerfstudio-data": NerfstudioDataParserConfig(), 84 | "minimal-parser": MinimalDataParserConfig(), 85 | "arkit-data": ARKitScenesDataParserConfig(), 86 | "blender-data": BlenderDataParserConfig(), 87 | "instant-ngp-data": InstantNGPDataParserConfig(), 88 | "nuscenes-data": NuScenesDataParserConfig(), 89 | "dnerf-data": DNeRFDataParserConfig(), 90 | "phototourism-data": PhototourismDataParserConfig(), 91 | "dycheck-data": DycheckDataParserConfig(), 92 | "scannet-data": ScanNetDataParserConfig(), 93 | "sdfstudio-data": SDFStudioDataParserConfig(), 94 | "sitcoms3d-data": Sitcoms3DDataParserConfig(), 95 | "suds-data": SUDSDataParserConfig(), 96 | }, 97 | prefix_names=False, # Omit prefixes in subcommands themselves. 98 | ) 99 | ] 100 | 101 | tyro.extras.set_accent_color("bright_yellow") 102 | main( 103 | tyro.cli( 104 | AnnotatedBaseConfigUnion, 105 | description=convert_markup_to_ansi(__doc__), 106 | ) 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | suds_entrypoint() 112 | --------------------------------------------------------------------------------
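suds_entrypoint() above registers the SUDS method with nerfstudio and then hands the tyro-parsed TrainerConfig to scripts.train.main(). A hedged sketch of launching the same configuration programmatically, bypassing the CLI, is shown below; the helper name is illustrative, and the dataparser is left at its defaults, so in practice its fields (metadata location, experiment output settings, and so on) would still need to point at prepared SUDS data:

from nerfstudio.engine.optimizers import AdamOptimizerConfig
from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig
from nerfstudio.engine.trainer import TrainerConfig
from scripts.train import main

from suds.data.suds_datamanager import SUDSDataManagerConfig
from suds.data.suds_dataparser import SUDSDataParserConfig
from suds.data.suds_pipeline import SUDSPipelineConfig
from suds.suds_model import SUDSModelConfig


def train_suds_programmatically() -> None:
    # Mirrors the TrainerConfig registered in suds_entrypoint(); only the CLI layer is skipped.
    config = TrainerConfig(
        method_name='suds',
        max_num_iterations=250001,
        mixed_precision=True,
        pipeline=SUDSPipelineConfig(
            datamanager=SUDSDataManagerConfig(dataparser=SUDSDataParserConfig()),
            model=SUDSModelConfig(),
        ),
        optimizers={
            'fields': {
                'optimizer': AdamOptimizerConfig(lr=5e-3, eps=1e-8),
                'scheduler': ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=250000),
            },
            'mlps': {
                'optimizer': AdamOptimizerConfig(lr=5e-3, eps=1e-8, weight_decay=1e-8),
                'scheduler': ExponentialDecaySchedulerConfig(lr_final=5e-4, max_steps=250000),
            },
            'flow': {
                'optimizer': AdamOptimizerConfig(lr=1e-4, eps=1e-8),
                'scheduler': ExponentialDecaySchedulerConfig(lr_final=1e-5, max_steps=250000),
            },
            'flow_mlp': {
                'optimizer': AdamOptimizerConfig(lr=1e-4, eps=1e-8, weight_decay=1e-8),
                'scheduler': ExponentialDecaySchedulerConfig(lr_final=1e-5, max_steps=250000),
            },
        },
    )
    main(config)


if __name__ == '__main__':
    train_suds_programmatically()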