├── .gitignore ├── README.md ├── assets └── teaser.png ├── cfg ├── dataset │ └── scannet │ │ ├── split.npz │ │ └── split_cl.npz ├── env │ └── env.yml └── exp │ ├── multi_step │ ├── cl_base.yml │ └── cl_base_novel_viewpoints.yml │ ├── one_step_finetune_nerf │ ├── s00_lr1e-5.yml │ ├── s10_lr1e-5.yml │ ├── s20_lr1e-5.yml │ ├── s30_lr1e-5.yml │ ├── s40_lr1e-5.yml │ ├── s50_lr1e-5.yml │ ├── s60_lr1e-5.yml │ ├── s70_lr1e-5.yml │ ├── s80_lr1e-5.yml │ └── s90_lr1e-5.yml │ ├── one_step_joint │ ├── s00_lr1e-5.yml │ ├── s10_lr1e-5.yml │ ├── s20_lr1e-5.yml │ ├── s30_lr1e-5.yml │ ├── s40_lr1e-5.yml │ ├── s50_lr1e-5.yml │ ├── s60_lr1e-5.yml │ ├── s70_lr1e-5.yml │ ├── s80_lr1e-5.yml │ └── s90_lr1e-5.yml │ └── pretrain_scannet_25k_deeplabv3.yml ├── nr4seg ├── __init__.py ├── dataset │ ├── __init__.py │ ├── create_split.py │ ├── helper.py │ ├── label_loader.py │ ├── ngp_utils.py │ ├── scannet.py │ ├── scannet_cl.py │ ├── scannet_cl_joint.py │ ├── scannet_ngp.py │ └── scannet_ngp_joint.py ├── lightning │ ├── __init__.py │ ├── finetune_data_module.py │ ├── joint_train_data_module.py │ ├── joint_train_lightning_net.py │ ├── pretrain_data_module.py │ └── semantics_lightning_net.py ├── nerf │ ├── __init__.py │ ├── activation.py │ ├── network_tcnn_semantics.py │ ├── raymarching │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── raymarching.py │ │ └── src │ │ │ ├── bindings.cpp │ │ │ ├── pcg32.h │ │ │ ├── raymarching.cu │ │ │ └── raymarching.h │ └── renderer_semantics.py ├── network │ ├── __init__.py │ └── deeplabv3.py ├── utils │ ├── __init__.py │ ├── flatten_dict.py │ ├── get_logger.py │ ├── loading.py │ └── metrics.py └── visualizer │ ├── __init__.py │ ├── colormaps.py │ └── visualizer.py ├── preprocessing_scripts ├── scannet2nerf.py ├── scannet2transform.py └── utils.py ├── requirements.txt ├── run_scripts ├── multi_step.sh ├── one_step_finetune_train.sh ├── one_step_joint_train.sh ├── one_step_nerf_only_train.sh ├── preprocess_scannet.sh └── pretrain.sh ├── scripts ├── cl_deeplab.py ├── eval_utils.py ├── pretrain.py ├── train_finetune.py └── train_joint.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # no IntelliJ files 2 | .idea 3 | 4 | # don't upload macOS folder info 5 | *.DS_Store 6 | 7 | # don't upload node_modules from npm test 8 | node_modules/* 9 | flow-typed/* 10 | 11 | # potential files generated by golang 12 | bin/ 13 | 14 | # don't upload webpack bundle file 15 | app/dist/ 16 | 17 | # default data directory 18 | local-data/ 19 | 20 | # potential integration testing data directory 21 | test_data/ 22 | 23 | #python 24 | .pyc 25 | __pycache__/ 26 | 27 | # pytype 28 | .pytype 29 | 30 | # vscode sftp settings 31 | .vscode/sftp.json 32 | 33 | # redis 34 | *.rdb 35 | 36 | # mypy 37 | .mypy_cache 38 | 39 | # jest coverage cache 40 | coverage/ 41 | 42 | # python cache 43 | __pycache__/ 44 | 45 | # python virtual environment 46 | env/ 47 | 48 | # local data folder 49 | data/* 50 | 51 | # vscode workspace configuration 52 | *.code-workspace 53 | 54 | # sphinx build folder 55 | _build/ 56 | 57 | # media files are not in this repo 58 | doc/media 59 | 60 | # ignore rope db cache 61 | .vscode/.ropeproject/objectdb 62 | 63 | # ignore vscode launch file 64 | .vscode/launch.json 65 | 66 | # python build 67 | build/ 68 | dist/ 69 | 70 | # package egg info 71 | *.egg-info 72 | 73 | # legacy files 74 | archived 75 | 76 | # unrelated files 77 | *.pkl 78 | .dist_tmp 79 | 80 | # project files 81 | experiments 82 | ckpts 83 | wandb 84 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Unsupervised Continual Semantic Adaptation through Neural Rendering

2 | 3 |

4 | Zhizheng Liu*, Francesco Milano*, Jonas Frey, Roland Siegwart, Hermann Blum, Cesar Cadena 5 |

6 | 7 |

CVPR 2023

8 |

Paper | Video | Project Page

9 | 10 |

11 | 12 | Unsupervised Continual Semantic Adaptation through Neural Rendering 13 | 14 |

15 | 16 | We present a framework to improve semantic scene understanding for agents that are deployed across a _sequence of scenes_. In particular, our method performs unsupervised continual semantic adaptation by jointly training a _2-D segmentation model_ and a _Semantic-NeRF network_. 17 | 18 | - Our framework allows successfully adapting the 2-D segmentation model across _multiple, previously unseen scenes_ and with _no ground-truth supervision_, reducing the domain gap in the new scenes and improving on the initial performance of the model. 19 | - By rendering training and novel views, the pipeline can effectively _mitigate forgetting_ and even _gain additional knowledge_ about the previous scenes. 20 | 21 | ## Table of Contents 22 | 23 | 1. [Installation](#installation) 24 | 2. [Running experiments](#running-experiments) 25 | 3. [Citation](#citation) 26 | 4. [Acknowledgements](#acknowledgements) 27 | 5. [Contact](#contact) 28 | 29 | ## Installation 30 | 31 | ### Workspace setup 32 | 33 | We recommend configuring your workspace with a conda environment. You can then install the project and its dependencies as follows. The instructions were tested on Ubuntu 20.04 and 22.04, with CUDA 11.3. 34 | 35 | - Clone this repo to a folder of your choice, which in the following we will refer to through the environment variable `REPO_ROOT`: 36 | ```bash 37 | export REPO_ROOT=/path/to/your/chosen/folder 38 | cd ${REPO_ROOT}; 39 | git clone git@github.com:ethz-asl/nr_semantic_segmentation.git 40 | ``` 41 | - Create a conda environment and install [PyTorch](https://pytorch.org/), [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) and other dependencies: 42 | 43 | ```bash 44 | conda create -n nr4seg python=3.8 45 | conda activate nr4seg 46 | python -m pip install --upgrade pip 47 | 48 | # For CUDA 11.3. 49 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 50 | # Install tiny-cuda-nn 51 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch 52 | 53 | pip install -r requirements.txt 54 | 55 | python setup.py develop 56 | ``` 57 | 58 | ### Setting up the dataset 59 | 60 | We use the [ScanNet v2](http://www.scan-net.org/) [1] dataset for our experiments. 61 | 62 | > [1] Angela Dai, Angel X. Chang, Manolis Savva, Maciej Halber, Thomas Funkhouser, and Matthias Nießner, "ScanNet: Richly-annotated 3D Reconstructions of Indoor Scenes", in _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)_, pp. 2432-2443, 2017. 63 | 64 | #### Dataset download 65 | 66 | - To get started, visit the official [ScanNet](https://github.com/ScanNet/ScanNet#scannet-data) dataset website to obtain the data downloading permission and script. We keep all the ScanNet data in the `${REPO_ROOT}/data/scannet` folder. You may use a symbolic link if necessary (_e.g._, `ln -s /path/to/scannet/data ${REPO_ROOT}/data/scannet`). 67 | - As detailed in the paper, we use scenes `0000` to `0009` to perform continual semantic adaptation, and a subset of data from the remaining scenes (`0010` to `0706`) to pre-train the segmentation network. 68 | - The data from the pre-training scenes are conveniently provided by the ScanNet dataset already in the correct format, as a `scannet_frames_25k.zip` file. You can download this file using the official download script that you should have received after requesting access to the dataset, specifying the `--preprocessed_frames` flag.
Once downloaded, the content of the file should be extracted to the subfolder `${REPO_ROOT}/data/scannet/scannet_frames_25k`. 69 | - For the scenes used to perform continual semantic adaptation, the full data are required. To obtain them, run the official download script, specifying through the flag `--id` the scene to download (_e.g._, `--id scene0000_00` to download scene `0000`) and including the `--label_map` flag, to download also the label mapping file `scannetv2-labels.combined.tsv` (cf. [here](https://github.com/ScanNet/ScanNet#labels)). The downloaded data should be stored in the subfolder `${REPO_ROOT}/data/scannet/scans`. Next, extract all the sensor data (depth images, color images, poses, intrinsics) using the [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) tool provided by ScanNet, for each of the downloaded scenes from `0000` to `0009`. For instance, for scene `0000`, run 70 | ```bash 71 | python2 reader.py --filename ${REPO_ROOT}/data/scannet/scans/scene0000_00/scene0000_00.sens --output_path ${REPO_ROOT}/data/scannet/scans/scene0000_00 --export_depth_images --export_color_images --export_poses --export_intrinsics 72 | ``` 73 | To obtain the raw labels (for evaluation purposes) for each of the continual adaptation scenes, also extract the content of the `sceneXXXX_XX_2d-label-filt.zip` file, so that a `${REPO_ROOT}/data/scannet/scans/sceneXXXX_XX/label-filt` folder is created. 74 | - Copy the `scannetv2-labels.combined.tsv` file to each scene folder under `${REPO_ROOT}/data/scannet/scans`, as well as to the subfolder `${REPO_ROOT}/data/scannet/scannet_frames_25k`. 75 | 76 | - At the end of the process, the `${REPO_ROOT}/data` folder should contain _at least_ the following data, structured as below: 77 | 78 | ```shell 79 | scannet 80 | scannet_frames_25k 81 | scene0010_00 82 | color 83 | 000000.jpg 84 | ... 85 | XXXXXX.jpg 86 | label 87 | 000000.png 88 | ... 89 | XXXXXX.png 90 | ... 91 | ... 92 | scene0706_00 93 | ... 94 | scannetv2-labels.combined.tsv 95 | scans 96 | scene0000_00 97 | color 98 | 000000.jpg 99 | ... 100 | XXXXXX.jpg 101 | depth 102 | 000000.png 103 | ... 104 | XXXXXX.png 105 | label-filt 106 | 000000.png 107 | ... 108 | XXXXXX.png 109 | pose 110 | 000000.txt 111 | ... 112 | XXXXXX.txt 113 | intrinsics 114 | intriniscs_color.txt 115 | intrinsics_depth.txt 116 | scannetv2-labels.combined.tsv 117 | ... 118 | scene0009_00 119 | ... 120 | ``` 121 | 122 | You may define the data subfolders differently by adjusting the `scannet` and `scannet_frames_25k` fields in [`cfg/env/env.yml`](./cfg/env/env.yml). You may also define several config files and set the configuration to use by specifying the `ENV_WORKSTATION_NAME` environmental variable before running the code (_e.g._, `export ENV_WORKSTATION_NAME="gpu_machine"` to use the config in `cfg/env/gpu_machine.yml`). 123 | 124 | - Copy the files [`split.npz`](./cfg/dataset/scannet/split.npz) and [`split_cl.npz`](./cfg/dataset/scannet/split_cl.npz) from the `${REPO_ROOT}/cfg/dataset/scannet/` folder to the `${REPO_ROOT}/data/scannet/scannet_frames_25k` folder. These files contain the indices of the samples that define the train/validation splits used in pre-training and to form the replay buffer in continual adaptation, to ensure reproducibility. 
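To quickly verify that the split files are in place, you can inspect them with NumPy: they are plain `.npz` archives whose entries list the image paths belonging to each split. The snippet below is a minimal sketch for doing so (it is not part of the pipeline); the key names it mentions (`train`/`val`/`test` and `train_cl`) are taken from [`nr4seg/dataset/create_split.py`](./nr4seg/dataset/create_split.py), the script that generates these files from the `image_regex` and `val_ratio` settings of an experiment config.

```python
import numpy as np

# Minimal sanity check (illustrative only): list the contents of the two
# split files copied to the scannet_frames_25k folder. Adjust the root path
# if you configured the data folders differently in cfg/env/env.yml.
root = "data/scannet/scannet_frames_25k"
for name in ("split.npz", "split_cl.npz"):
    archive = np.load(f"{root}/{name}")
    for key in archive.files:
        paths = archive[key]
        print(f"{name}: key '{key}' holds {len(paths)} frame paths, "
              f"e.g. {paths[0]}")
```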
125 | 126 | #### Dataset pre-processing 127 | 128 | After organizing the ScanNet files as detailed above, run the following script to pre-process the files: 129 | 130 | ```bash 131 | bash run_scripts/preprocess_scannet.sh 132 | ``` 133 | 134 | After pre-processing, the folder structure for each `sceneXXXX_XX` from `scene0000_00` to `scene0009_00` should look as follows: 135 | 136 | ```shell 137 | sceneXXXX_XX 138 | color 139 | 000000.jpg 140 | ... 141 | XXXXXX.jpg 142 | color_scaled 143 | 000000.jpg 144 | ... 145 | XXXXXX.jpg 146 | depth 147 | 000000.png 148 | ... 149 | XXXXXX.png 150 | label_40 151 | 000000.png 152 | ... 153 | XXXXXX.png 154 | label_40_scaled 155 | 000000.png 156 | ... 157 | XXXXXX.png 158 | label-filt 159 | 000000.png 160 | ... 161 | XXXXXX.png 162 | pose 163 | 000000.txt 164 | ... 165 | XXXXXX.txt 166 | intrinsics 167 | intriniscs_color.txt 168 | intrinsics_depth.txt 169 | scannetv2-labels.combined.tsv 170 | transforms_test.json 171 | transforms_test_scaled_semantics_40_raw.json 172 | transforms_train.json 173 | transforms_train_scaled_semantics_40_raw.json 174 | ``` 175 | 176 | ## Running experiments 177 | 178 | By default, the data produced when running the code is stored in the `${REPO_ROOT}/experiments` folder. You can modify this by changing the `results` field in [`cfg/env/env.yml`](./cfg/env/env.yml). 179 | 180 | ### DeepLabv3 pre-training 181 | 182 | To pre-train the DeepLabv3 segmentation network on scenes `0010` to `0706`, run the following script: 183 | 184 | ```bash 185 | bash run_scripts/pretrain.sh --exp cfg/exp/pretrain_scannet_25k_deeplabv3.yml 186 | ``` 187 | 188 | Alternatively, we provide a pre-trained DeepLabv3 [checkpoint](https://www.research-collection.ethz.ch/bitstream/handle/20.500.11850/637142/best-epoch143-step175536.ckpt), which you may download to the `${REPO_ROOT}/ckpts` folder. 189 | 190 | ### One-step experiments 191 | 192 | This section contains instructions on how to perform one-step adaptation experiments (cf. Sec. 4.4 in the main paper). 193 | 194 | #### Fine-tuning 195 | 196 | For fine-tuning, NeRF pseudo-labels should first be generated by running NeRF-only training: 197 | 198 | ```bash 199 | bash run_scripts/one_step_nerf_only_train.sh 200 | ``` 201 | 202 | Next, run 203 | 204 | ```bash 205 | bash run_scripts/one_step_finetune_train.sh 206 | ``` 207 | 208 | to fine-tune DeepLabv3 with the NeRF pseudo-labels. Please make sure the variable `prev_exp_name` defined in the [fine-tuning script](./run_scripts/one_step_finetune_train.sh) matches the variable `name` in the [NeRF-only script](./run_scripts/one_step_nerf_only_train.sh). 209 | 210 | #### Joint-training 211 | 212 | To perform one-step joint training, run 213 | 214 | ```bash 215 | bash run_scripts/one_step_joint_train.sh 216 | ``` 217 | 218 | ### Multi-step experiments 219 | 220 | To perform multi-step adaptation experiments (cf. Sec. 4.5 in the main paper), run the following commands: 221 | 222 | ```bash 223 | # Using training views for replay. 224 | bash run_scripts/multi_step.sh --exp cfg/exp/multi_step/cl_base.yml 225 | # Using novel views for "replay". 226 | bash run_scripts/multi_step.sh --exp cfg/exp/multi_step/cl_base_novel_viewpoints.yml 227 | ``` 228 | 229 | ### Logging 230 | 231 | By default, we use [WandB](https://wandb.ai/site) to log our experiments. You can initialize WandB logging by running 232 | 233 | ```bash 234 | wandb init -e ${YOUR_WANDB_ENTITY} 235 | ``` 236 | in the terminal.
Alternatively, you can disable all logging by defining `export WANDB_MODE=disabled` before launching the experiments. 237 | 238 | ### Seeding 239 | 240 | To obtain the variances of the results, we run the above experiments multiple times with different seeds by specifying `--seed` in the argument. 241 | 242 | ## Citation 243 | 244 | If you find our code or paper useful, please cite: 245 | 246 | ```bibtex 247 | @inproceedings{Liu2023UnsupervisedContinualSemanticAdaptationNR, 248 | author = {Liu, Zhizheng and Milano, Francesco and Frey, Jonas and Siegwart, Roland and Blum, Hermann and Cadena, Cesar}, 249 | title = {Unsupervised Continual Semantic Adaptation through Neural Rendering}, 250 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 251 | year = {2023} 252 | } 253 | ``` 254 | 255 | ## Acknowledgements 256 | 257 | Parts of the NeRF implementation are adapted from [torch-ngp](https://github.com/ashawkey/torch-ngp), [Semantic-NeRF](https://github.com/Harry-Zhi/semantic_nerf/), and [Instant-NGP](https://github.com/NVlabs/instant-ngp). 258 | 259 | ## Contact 260 | 261 | Contact [Zhizheng Liu](mailto:liuzhi@student.ethz.ch) and [Francesco Milano](mailto:francesco.milano@mavt.ethz.ch) for questions, comments, and reporting bugs. 262 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/assets/teaser.png -------------------------------------------------------------------------------- /cfg/dataset/scannet/split.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/cfg/dataset/scannet/split.npz -------------------------------------------------------------------------------- /cfg/dataset/scannet/split_cl.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/cfg/dataset/scannet/split_cl.npz -------------------------------------------------------------------------------- /cfg/env/env.yml: -------------------------------------------------------------------------------- 1 | results: experiments 2 | scannet: data/scannet/scans 3 | scannet_frames_25k: data/scannet/scannet_frames_25k 4 | -------------------------------------------------------------------------------- /cfg/exp/multi_step/cl_base.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0000_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 2 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | 
image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | 60 | cl: 61 | active: true 62 | 25k_fraction: 0.1 63 | ngp_25k_ratio: 1 64 | use_novel_viewpoints: False 65 | replay_buffer_size: 100 66 | -------------------------------------------------------------------------------- /cfg/exp/multi_step/cl_base_novel_viewpoints.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0000_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 2 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | 48 | scenes: 49 | - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | 61 | cl: 62 | active: true 63 | 25k_fraction: 0.1 64 | ngp_25k_ratio: 1 65 | use_novel_viewpoints: True 66 | replay_buffer_size: 100 67 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s00_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0000_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - 
scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s10_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0001_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s20_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0002_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s30_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0003_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | 
model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s40_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0004_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s50_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0005_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | 
drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s60_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0006_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s70_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0007_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 
| # - scene0006_00 56 | - scene0007_00 57 | # - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s80_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0008_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | - scene0008_00 58 | # - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_finetune_nerf/s90_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: finetune_nerf_train/scannet_scene0009_00 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr: 1.0e-5 16 | name: Adam 17 | 18 | trainer: 19 | max_epochs: 50 20 | gpus: -1 21 | num_sanity_val_steps: 0 22 | check_val_every_n_epoch: 1 23 | resume_from_checkpoint: False 24 | load_from_checkpoint: True 25 | 26 | data_module: 27 | pin_memory: true 28 | batch_size: 4 29 | shuffle: true 30 | num_workers: 2 31 | drop_last: false 32 | root: data/scannet/scannet_frames_25k 33 | train_image: nerf 34 | train_label: nerf 35 | data_preprocessing: 36 | val_ratio: 0.2 37 | image_regex: /*/color/*.jpg 38 | split_file: split.npz 39 | split_file_cl: split_cl.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | 48 | scenes: 49 | # - scene0000_00 50 | # - scene0001_00 51 | # - scene0002_00 52 | # - scene0003_00 53 | # - scene0004_00 54 | # - scene0005_00 55 | # - scene0006_00 56 | # - scene0007_00 57 | # - scene0008_00 58 | - scene0009_00 59 | 60 | cl: 61 | active: false 62 | use_novel_viewpoints: False 63 | replay_buffer_size: 0 64 | 65 | 66 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s00_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0000_00_finetune 3 | clean_up_folder_if_exists: True 4 | 
checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s10_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0001_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s20_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0002_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: 
true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s30_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0003_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s40_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0004_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - 
scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s50_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0005_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | 65 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s60_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0006_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s70_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0007_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | 
pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | - scene0007_00 56 | # - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s80_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0008_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | - scene0008_00 57 | # - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | -------------------------------------------------------------------------------- /cfg/exp/one_step_joint/s90_lr1e-5.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: joint_train/scannet_scene0009_00_finetune 3 | clean_up_folder_if_exists: True 4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt" 5 | 6 | model: 7 | pretrained: False 8 | pretrained_backbone: True 9 | num_classes: 40 # Scannet (40) 10 | 11 | lr_scheduler: 12 | active: false 13 | 14 | optimizer: 15 | lr_seg: 1.0e-5 16 | lr_nerf: 1.0e-2 17 | name: Adam 18 | 19 | trainer: 20 | max_epochs: 20 21 | gpus: -1 22 | num_sanity_val_steps: 0 23 | check_val_every_n_epoch: 1 24 | resume_from_checkpoint: False 25 | load_from_checkpoint: True 26 | 27 | data_module: 28 | pin_memory: true 29 | batch_size: 4 30 | shuffle: true 31 | num_workers: 0 32 | drop_last: true 33 | root: data/scannet/scannet_frames_25k 34 | 
data_preprocessing: 35 | val_ratio: 0.2 36 | image_regex: /*/color/*.jpg 37 | split_file: split.npz 38 | split_file_cl: split_cl.npz 39 | 40 | visualizer: 41 | store: true 42 | store_n: 43 | train: 3 44 | val: 3 45 | test: 3 46 | 47 | scenes: 48 | # - scene0000_00 49 | # - scene0001_00 50 | # - scene0002_00 51 | # - scene0003_00 52 | # - scene0004_00 53 | # - scene0005_00 54 | # - scene0006_00 55 | # - scene0007_00 56 | # - scene0008_00 57 | - scene0009_00 58 | 59 | cl: 60 | active: false 61 | use_novel_viewpoints: False 62 | replay_buffer_size: 0 63 | 64 | -------------------------------------------------------------------------------- /cfg/exp/pretrain_scannet_25k_deeplabv3.yml: -------------------------------------------------------------------------------- 1 | general: 2 | name: scannet_25k_deeplab/pretrain 3 | clean_up_folder_if_exists: True 4 | 5 | model: 6 | pretrained: False 7 | pretrained_backbone: True 8 | num_classes: 40 # Scannet (40) 9 | 10 | lr_scheduler: 11 | active: true 12 | name: POLY 13 | poly_cfg: 14 | power: 0.9 15 | max_epochs: 150 16 | target_lr: 1.0e-06 17 | 18 | optimizer: 19 | lr: 0.0001 20 | name: Adam 21 | 22 | trainer: 23 | max_epochs: 150 24 | gpus: -1 25 | num_sanity_val_steps: 0 26 | check_val_every_n_epoch: 1 27 | resume_from_checkpoint: false 28 | 29 | data_module: 30 | pin_memory: true 31 | batch_size: 4 32 | shuffle: true 33 | num_workers: 2 34 | drop_last: false 35 | root: data/scannet_frames_25k/scannet_frames_25k 36 | data_preprocessing: 37 | val_ratio: 0.2 38 | image_regex: /*/color/*.jpg 39 | split_file: split.npz 40 | 41 | visualizer: 42 | store: true 43 | store_n: 44 | train: 3 45 | val: 3 46 | test: 3 47 | -------------------------------------------------------------------------------- /nr4seg/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 4 | 5 | if not "ENV_WORKSTATION_NAME" in os.environ: 6 | os.environ["ENV_WORKSTATION_NAME"] = "env" 7 | -------------------------------------------------------------------------------- /nr4seg/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .scannet import ScanNet 2 | from .scannet_cl import ScanNetCL 3 | from .scannet_cl_joint import ScanNetCLJoint 4 | from .scannet_ngp import ScanNetNGP 5 | from .scannet_ngp_joint import ScanNetNGPJoint 6 | 7 | __all__ = [ 8 | "ScanNet", 9 | "ScanNetCL", 10 | "ScanNetCLJoint", 11 | "ScanNetNGP", 12 | "ScanNetNGPJoint", 13 | ] 14 | -------------------------------------------------------------------------------- /nr4seg/dataset/create_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import random 5 | 6 | from glob import glob 7 | 8 | from nr4seg.utils import load_yaml 9 | 10 | parser = argparse.ArgumentParser(description="Generates a data split file.") 11 | 12 | curr_dir_path = os.path.dirname(os.path.abspath(__file__)) 13 | default_exp_path = os.path.join( 14 | curr_dir_path, "../../cfg/exp/pretrain_scannet_25k_deeplabv3.yml") 15 | parser.add_argument("--config", 16 | type=str, 17 | default=default_exp_path, 18 | help="Path to config file.") 19 | 20 | args = parser.parse_args() 21 | 22 | cfg = load_yaml(args.config) 23 | cfg = cfg["data_module"] 24 | 25 | train_all = glob(cfg["root"] + cfg["data_preprocessing"]["image_regex"]) 26 | random.shuffle(train_all) 27 | val = 
train_all[:int(len(train_all) * cfg["data_preprocessing"]["val_ratio"])] 28 | train = train_all[int(len(train_all) * cfg["data_preprocessing"]["val_ratio"]):] 29 | test = val 30 | train, val, test = map(sorted, [train, val, test]) 31 | 32 | split_cl = {"train_cl": train} 33 | 34 | split_dict = {"train": train, "test": test, "val": val} 35 | env_name = os.environ["ENV_WORKSTATION_NAME"] 36 | env = load_yaml(os.path.join("cfg/env", env_name + ".yml")) 37 | scannet_25k_dir = env["scannet_frames_25k"] 38 | out_file = os.path.join(scannet_25k_dir, 39 | cfg["data_preprocessing"]["split_file_cl"]) 40 | np.savez(out_file, **split_cl) 41 | -------------------------------------------------------------------------------- /nr4seg/dataset/helper.py: -------------------------------------------------------------------------------- 1 | import random 2 | import PIL 3 | import torch 4 | from torchvision import transforms as tf 5 | from torchvision.transforms import functional as F 6 | import warnings 7 | 8 | __all__ = ["Augmentation", "AugmentationList", "get_output_size"] 9 | 10 | 11 | def get_output_size(output_size): 12 | if type(output_size) == list or type(output_size) == tuple: 13 | if len(output_size) == 2: 14 | output_size = tuple(output_size) 15 | elif len(output_size) == 1: 16 | output_size = (output_size[0], output_size[0]) 17 | elif type(output_size) == int: 18 | output_size = (output_size, output_size) 19 | return output_size 20 | 21 | 22 | class Augmentation: 23 | 24 | def __init__( 25 | self, 26 | output_size=400, 27 | degrees=10, 28 | flip_p=0.5, 29 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05], 30 | ): 31 | 32 | # training transforms 33 | output_size = get_output_size(output_size) 34 | self._output_size = output_size 35 | self._crop = tf.RandomCrop(self._output_size) 36 | with warnings.catch_warnings(): 37 | # this will suppress all warnings in this block 38 | warnings.simplefilter("ignore") 39 | self._rot = tf.RandomRotation(degrees=degrees, 40 | resample=PIL.Image.BILINEAR) 41 | self._flip_p = flip_p 42 | self._degrees = degrees 43 | 44 | self._jitter = tf.ColorJitter( 45 | brightness=jitter_bcsh[0], 46 | contrast=jitter_bcsh[1], 47 | saturation=jitter_bcsh[2], 48 | hue=jitter_bcsh[3], 49 | ) 50 | self._crop_center = tf.CenterCrop(self._output_size) 51 | 52 | def apply(self, img, label, only_crop=False): 53 | # LABEL is now a list of labels 54 | 55 | scale = False 56 | # Check if rescaling is neccessary based on image height and output height 57 | if img.shape[1] >= 2 * self._output_size[0]: 58 | sf = float(self._output_size[0] / img.shape[1]) * 1.2 59 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2 60 | sf = max(sf, sf2) 61 | 62 | scale = True 63 | elif (img.shape[1] < self._output_size[0] or 64 | img.shape[2] < self._output_size[1]): 65 | sf1 = float(self._output_size[0] / img.shape[1]) * 1.2 66 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2 67 | sf = max(sf1, sf2) 68 | scale = True 69 | 70 | if scale: 71 | img = torch.nn.functional.interpolate( 72 | img[None], 73 | scale_factor=(sf, sf), 74 | mode="bilinear", 75 | recompute_scale_factor=False, 76 | align_corners=False, 77 | )[0] 78 | label = torch.nn.functional.interpolate( 79 | label[None], 80 | scale_factor=(sf, sf), 81 | mode="nearest", 82 | recompute_scale_factor=False, 83 | )[0] 84 | if not only_crop: 85 | # Color Jitter 86 | img = self._jitter(img) 87 | 88 | # Rotate 89 | angle = random.uniform(-self._degrees, self._degrees) 90 | with warnings.catch_warnings(): 91 | # this will suppress all warnings in this block 92 | 
warnings.simplefilter("ignore") 93 | img = F.rotate( 94 | img, 95 | angle, 96 | resample=PIL.Image.BILINEAR, 97 | expand=False, 98 | center=None, 99 | fill=0, 100 | ) 101 | label = F.rotate( 102 | label, 103 | angle, 104 | resample=PIL.Image.NEAREST, 105 | expand=False, 106 | center=None, 107 | fill=0, 108 | ) 109 | 110 | # Crop 111 | i, j, h, w = self._crop.get_params(img, self._output_size) 112 | img = F.crop(img, i, j, h, w) 113 | label = F.crop(label, i, j, h, w) 114 | 115 | # Flip 116 | if torch.rand(1) < self._flip_p: 117 | img = F.hflip(img) 118 | label = F.hflip(label) 119 | 120 | # Performes center crop 121 | img = self._crop_center(img) 122 | label = self._crop_center(label) 123 | 124 | return img, label 125 | 126 | 127 | class AugmentationList: 128 | 129 | def __init__( 130 | self, 131 | output_size=400, 132 | degrees=10, 133 | flip_p=0.5, 134 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05], 135 | ): 136 | 137 | # training transforms 138 | output_size = get_output_size(output_size) 139 | self._output_size = output_size 140 | self._crop = tf.RandomCrop(self._output_size) 141 | with warnings.catch_warnings(): 142 | # this will suppress all warnings in this block 143 | warnings.simplefilter("ignore") 144 | self._rot = tf.RandomRotation(degrees=degrees, 145 | resample=PIL.Image.BILINEAR) 146 | self._flip_p = flip_p 147 | self._degrees = degrees 148 | 149 | self._jitter = tf.ColorJitter( 150 | brightness=jitter_bcsh[0], 151 | contrast=jitter_bcsh[1], 152 | saturation=jitter_bcsh[2], 153 | hue=jitter_bcsh[3], 154 | ) 155 | self._crop_center = tf.CenterCrop(self._output_size) 156 | 157 | def apply(self, img, label, only_crop=False): 158 | scale = False 159 | # Check if rescaling is neccessary based on image height and output height 160 | if img.shape[1] >= 2 * self._output_size[0]: 161 | sf = float(self._output_size[0] / img.shape[1]) * 1.2 162 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2 163 | sf = max(sf, sf2) 164 | 165 | scale = True 166 | elif (img.shape[1] < self._output_size[0] or 167 | img.shape[2] < self._output_size[1]): 168 | sf1 = float(self._output_size[0] / img.shape[1]) * 1.2 169 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2 170 | sf = max(sf1, sf2) 171 | scale = True 172 | 173 | if scale: 174 | img = torch.nn.functional.interpolate( 175 | img[None], 176 | scale_factor=(sf, sf), 177 | mode="bilinear", 178 | recompute_scale_factor=False, 179 | align_corners=False, 180 | )[0] 181 | for _i, l in enumerate(label): 182 | label[_i] = torch.nn.functional.interpolate( 183 | l[None], 184 | scale_factor=(sf, sf), 185 | mode="nearest", 186 | recompute_scale_factor=False, 187 | )[0] 188 | if not only_crop: 189 | # Color Jitter 190 | img = self._jitter(img) 191 | 192 | # Rotate 193 | angle = random.uniform(-self._degrees, self._degrees) 194 | with warnings.catch_warnings(): 195 | # this will suppress all warnings in this block 196 | warnings.simplefilter("ignore") 197 | img = F.rotate( 198 | img, 199 | angle, 200 | resample=PIL.Image.BILINEAR, 201 | expand=False, 202 | center=None, 203 | fill=0, 204 | ) 205 | for _i, l in enumerate(label): 206 | label[_i] = F.rotate( 207 | l, 208 | angle, 209 | resample=PIL.Image.NEAREST, 210 | expand=False, 211 | center=None, 212 | fill=0, 213 | ) 214 | 215 | # Crop 216 | i, j, h, w = self._crop.get_params(img, self._output_size) 217 | img = F.crop(img, i, j, h, w) 218 | for _i, l in enumerate(label): 219 | label[_i] = F.crop(l, i, j, h, w) 220 | 221 | # Flip 222 | if torch.rand(1) < self._flip_p: 223 | img = F.hflip(img) 224 | for _i, l 
in enumerate(label): 225 | label[_i] = F.hflip(l) 226 | 227 | # Performes center crop 228 | img = self._crop_center(img) 229 | for _i, l in enumerate(label): 230 | label[_i] = self._crop_center(l) 231 | 232 | return img, label 233 | -------------------------------------------------------------------------------- /nr4seg/dataset/label_loader.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | import os 4 | import pandas 5 | import torch 6 | 7 | __all__ = ["LabelLoaderAuto"] 8 | 9 | 10 | class LabelLoaderAuto: 11 | 12 | def __init__(self, root_scannet=None, confidence=0, H=968, W=1296): 13 | assert root_scannet is not None 14 | self._get_mapping(root_scannet) 15 | self._confidence = confidence 16 | # return label between 0-40 17 | 18 | self.max_classes = 40 19 | self.label = np.zeros((H, W, self.max_classes)) 20 | iu16 = np.iinfo(np.uint16) 21 | mask = np.full((H, W), iu16.max, dtype=np.uint16) 22 | self.mask_low = np.right_shift(mask, 6, dtype=np.uint16) 23 | 24 | def get(self, path): 25 | img = imageio.imread(path) 26 | if len(img.shape) == 3: 27 | if img.shape[2] == 4: 28 | H, W, _ = img.shape 29 | self.label = np.zeros((H, W, self.max_classes)) 30 | for i in range(3): 31 | prob = np.bitwise_and(img[:, :, i], self.mask_low) / 1023 32 | cls = np.right_shift(img[:, :, i], 10, dtype=np.uint16) 33 | m = np.eye(self.max_classes)[cls] == 1 34 | self.label[m] = prob.reshape(-1) 35 | m = np.max(self.label, axis=2) < self._confidence 36 | self.label = np.argmax(self.label, axis=2).astype(np.int32) + 1 37 | self.label[m] = 0 38 | method = "RGBA" 39 | else: 40 | raise Exception("Type not know") 41 | elif len(img.shape) == 2 and img.dtype == np.uint8: 42 | self.label = img.astype(np.int32) 43 | method = "FAST" 44 | elif len(img.shape) == 2 and img.dtype == np.uint16: 45 | self.label = torch.from_numpy(img.astype(np.int32)).type( 46 | torch.float32)[None, :, :] 47 | sa = self.label.shape 48 | self.label = self.label.flatten() 49 | self.label = self.mapping[self.label.type(torch.int64)] 50 | self.label = self.label.reshape(sa).numpy().astype(np.int32)[0] 51 | method = "MAPPED" 52 | else: 53 | raise Exception("Type not know") 54 | return self.label, method 55 | 56 | def get_probs(self, path): 57 | img = imageio.imread(path) 58 | assert len(img.shape) == 3 59 | assert img.shape[2] == 4 60 | H, W, _ = img.shape 61 | probs = np.zeros((H, W, self.max_classes)) 62 | for i in range(3): 63 | prob = np.bitwise_and(img[:, :, i], self.mask_low) / 1023 64 | cls = np.right_shift(img[:, :, i], 10, dtype=np.uint16) 65 | m = np.eye(self.max_classes)[cls] == 1 66 | probs[m] = prob.reshape(-1) 67 | 68 | return probs 69 | 70 | def _get_mapping(self, root): 71 | tsv = os.path.join(root, "scannetv2-labels.combined.tsv") 72 | df = pandas.read_csv(tsv, sep="\t") 73 | mapping_source = np.array(df["id"]) 74 | mapping_target = np.array(df["nyu40id"]) 75 | 76 | self.mapping = torch.zeros((int(mapping_source.max() + 1)), 77 | dtype=torch.int64) 78 | for so, ta in zip(mapping_source, mapping_target): 79 | self.mapping[so] = ta 80 | -------------------------------------------------------------------------------- /nr4seg/dataset/ngp_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from packaging import version as pver 4 | 5 | 6 | # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 
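
The remapping implemented by `nerf_matrix_to_ngp` below reorders the world axes from the standard NeRF convention to instant-ngp's, (x, y, z) → (y, z, x), and flips the camera's local y/z axes. A minimal sketch of what the conversion does to a camera-to-world matrix (the numeric pose here is made up for illustration):

```python
import numpy as np

from nr4seg.dataset.ngp_utils import nerf_matrix_to_ngp

# Illustrative camera-to-world pose: identity rotation, translation (0.5, 1.0, 2.0).
pose_nerf = np.array(
    [
        [1.0, 0.0, 0.0, 0.5],
        [0.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 2.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=np.float32,
)

pose_ngp = nerf_matrix_to_ngp(pose_nerf)
# Rows are now ordered (y, z, x) and the rotation's 2nd/3rd columns are negated;
# the translation is only permuted, never rescaled.
assert np.allclose(pose_ngp[:, 3], [1.0, 2.0, 0.5, 1.0])
```
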
7 | def nerf_matrix_to_ngp(pose): 8 | new_pose = np.array( 9 | [ 10 | [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3]], 11 | [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3]], 12 | [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3]], 13 | [0, 0, 0, 1], 14 | ], 15 | dtype=np.float32, 16 | ) 17 | return new_pose 18 | 19 | 20 | def custom_meshgrid(*args): 21 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid 22 | if pver.parse(torch.__version__) < pver.parse("1.10"): 23 | return torch.meshgrid(*args) 24 | else: 25 | return torch.meshgrid(*args, indexing="ij") 26 | 27 | 28 | @torch.cuda.amp.autocast(enabled=False) 29 | def get_rays(poses, intrinsics, H, W, error_map=None): 30 | """get rays 31 | Args: 32 | poses: [B, 4, 4], cam2world 33 | intrinsics: [4] 34 | H, W, N: int 35 | error_map: [B, 128 * 128], sample probability based on training error 36 | Returns: 37 | rays_o, rays_d: [B, N, 3] 38 | direction_norms: [B, N, 1] 39 | inds: [B, N] 40 | """ 41 | 42 | device = poses.device 43 | B = poses.shape[0] 44 | fx, fy, cx, cy = intrinsics 45 | 46 | i, j = custom_meshgrid( 47 | torch.linspace(0, W - 1, W, device=device), 48 | torch.linspace(0, H - 1, H, device=device), 49 | ) 50 | i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 51 | j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 52 | 53 | results = {} 54 | zs = torch.ones_like(i) 55 | xs = (i - cx) / fx * zs 56 | ys = (j - cy) / fy * zs 57 | directions = torch.stack((xs, ys, zs), dim=-1) 58 | direction_norms = torch.norm(directions, dim=-1, keepdim=True) 59 | directions = directions / direction_norms 60 | rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) 61 | 62 | rays_o = poses[..., :3, 3] # [B, 3] 63 | rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] 64 | 65 | results["rays_o"] = rays_o 66 | results["rays_d"] = rays_d 67 | results["direction_norms"] = direction_norms 68 | 69 | return results 70 | 71 | 72 | # color palette for nyu40 labels 73 | nyu40_colour_code = np.array([ 74 | (0, 0, 0), 75 | (174, 199, 232), # wall 76 | (152, 223, 138), # floor 77 | (31, 119, 180), # cabinet 78 | (255, 187, 120), # bed 79 | (188, 189, 34), # chair 80 | (140, 86, 75), # sofa 81 | (255, 152, 150), # table 82 | (214, 39, 40), # door 83 | (197, 176, 213), # window 84 | (148, 103, 189), # bookshelf 85 | (196, 156, 148), # picture 86 | (23, 190, 207), # counter 87 | (178, 76, 76), # blinds 88 | (247, 182, 210), # desk 89 | (66, 188, 102), # shelves 90 | (219, 219, 141), # curtain 91 | (140, 57, 197), # dresser 92 | (202, 185, 52), # pillow 93 | (51, 176, 203), # mirror 94 | (200, 54, 131), # floor 95 | (92, 193, 61), # clothes 96 | (78, 71, 183), # ceiling 97 | (172, 114, 82), # books 98 | (255, 127, 14), # refrigerator 99 | (91, 163, 138), # tv 100 | (153, 98, 156), # paper 101 | (140, 153, 101), # towel 102 | (158, 218, 229), # shower curtain 103 | (100, 125, 154), # box 104 | (178, 127, 135), # white board 105 | (120, 185, 128), # person 106 | (146, 111, 194), # night stand 107 | (44, 160, 44), # toilet 108 | (112, 128, 144), # sink 109 | (96, 207, 209), # lamp 110 | (227, 119, 194), # bathtub 111 | (213, 92, 176), # bag 112 | (94, 106, 211), # other struct 113 | (82, 84, 163), # otherfurn 114 | (100, 85, 144), # other prop 115 | ]).astype(np.uint8) 116 | -------------------------------------------------------------------------------- /nr4seg/dataset/scannet.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | 
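
`get_rays` above turns a batch of camera-to-world poses plus pinhole intrinsics into one normalized ray per pixel. A small usage sketch with made-up camera values:

```python
import torch

from nr4seg.dataset.ngp_utils import get_rays

H, W = 240, 320
poses = torch.eye(4)[None]                               # [B=1, 4, 4], camera at the origin
intrinsics = torch.tensor([320.0, 320.0, 160.0, 120.0])  # fx, fy, cx, cy (illustrative)

rays = get_rays(poses, intrinsics, H, W)
print(rays["rays_o"].shape)           # torch.Size([1, 76800, 3]) -- one ray per pixel
print(rays["rays_d"].norm(dim=-1))    # approximately all ones: directions are unit length
print(rays["direction_norms"].shape)  # torch.Size([1, 76800, 1])
```
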
import numpy as np 3 | import os 4 | import random 5 | import torch 6 | 7 | from torch.utils.data import Dataset 8 | 9 | try: 10 | from .helper import AugmentationList 11 | except Exception: 12 | from helper import AugmentationList 13 | 14 | from .label_loader import LabelLoaderAuto 15 | 16 | __all__ = ["ScanNet"] 17 | 18 | 19 | class ScanNet(Dataset): 20 | 21 | def __init__( 22 | self, 23 | root, 24 | img_list, 25 | mode="train", 26 | output_trafo=None, 27 | output_size=(240, 320), 28 | degrees=10, 29 | flip_p=0.5, 30 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05], 31 | sub=10, 32 | data_augmentation=True, 33 | label_setting="default", 34 | confidence_aux=0, 35 | ): 36 | """ 37 | Dataset dosent know if it contains replayed or normal samples ! 38 | 39 | Some images are stored in 640x480 other ins 1296x968 40 | Warning scene0088_03 has wrong resolution -> Ignored 41 | Parameters 42 | ---------- 43 | root : str, path to the ML-Hypersim folder 44 | mode : str, option ['train','val] 45 | """ 46 | 47 | super(ScanNet, self).__init__() 48 | self._sub = sub 49 | self._mode = mode 50 | self._confidence_aux = confidence_aux 51 | 52 | self._label_setting = label_setting 53 | self.image_pths = img_list 54 | self.label_pths = [ 55 | p.replace("color", "label").replace("jpg", "png") for p in img_list 56 | ] 57 | self.length = len(self.image_pths) 58 | self.aux_labels = False 59 | self._augmenter = AugmentationList(output_size, degrees, flip_p, 60 | jitter_bcsh) 61 | self._output_trafo = output_trafo 62 | self._data_augmentation = data_augmentation 63 | self.unique = False 64 | 65 | self.aux_labels_fake = False 66 | 67 | self._label_loader = LabelLoaderAuto(root_scannet=root, 68 | confidence=self._confidence_aux) 69 | if self.aux_labels: 70 | self._preprocessing_hack() 71 | 72 | def set_aux_labels_fake(self, flag=True): 73 | self.aux_labels_fake = flag 74 | self.aux_labels = flag 75 | 76 | def __getitem__(self, index): 77 | # Read Image and Label 78 | label, _ = self._label_loader.get(self.label_pths[index]) 79 | label = torch.from_numpy(label).type( 80 | torch.float32)[None, :, :] # C H W -> contains 0-40 81 | label = [label] 82 | if self.aux_labels and not self.aux_labels_fake: 83 | _p = self.aux_label_pths[index] 84 | if os.path.isfile(_p): 85 | aux_label, _ = self._label_loader.get(_p) 86 | aux_label = torch.from_numpy(aux_label).type( 87 | torch.float32)[None, :, :] 88 | label.append(aux_label) 89 | else: 90 | if _p.find("_.png") != -1: 91 | print(_p) 92 | print("Processed not found") 93 | _p = _p.replace("_.png", ".png") 94 | aux_label, _ = self._label_loader.get(_p) 95 | aux_label = torch.from_numpy(aux_label).type( 96 | torch.float32)[None, :, :] 97 | label.append(aux_label) 98 | 99 | img = imageio.imread(self.image_pths[index]) 100 | img = (torch.from_numpy(img).type(torch.float32).permute(2, 0, 1) / 255 101 | ) # C H W range 0-1 102 | 103 | if self._mode.find("train") != -1 and self._data_augmentation: 104 | img, label = self._augmenter.apply(img, label) 105 | else: 106 | img, label = self._augmenter.apply(img, label, only_crop=True) 107 | 108 | img_ori = img.clone() 109 | if self._output_trafo is not None: 110 | img = self._output_trafo(img) 111 | 112 | for k in range(len(label)): 113 | label[k] = label[k] - 1 # 0 == chairs 39 other prop -1 invalid 114 | 115 | # REJECT LABEL 116 | if (label[0] != -1).sum() < 10: 117 | idx = random.randint(0, len(self) - 1) 118 | if not self.unique: 119 | return self[idx] 120 | else: 121 | return False 122 | 123 | ret = (img, label[0].type(torch.int64)[0, :, :]) 124 | if 
self.aux_labels: 125 | if self.aux_labels_fake: 126 | ret += ( 127 | label[0].type(torch.int64)[0, :, :], 128 | torch.tensor(False), 129 | ) 130 | else: 131 | ret += ( 132 | label[1].type(torch.int64)[0, :, :], 133 | torch.tensor(True), 134 | ) 135 | 136 | ret += (img_ori,) 137 | return ret 138 | 139 | def __len__(self): 140 | return self.length 141 | 142 | def __str__(self): 143 | string = "=" * 90 144 | string += "\nScannet Dataset: \n" 145 | length = len(self) 146 | string += f" Total Samples: {length}" 147 | string += f" » Mode: {self._mode} \n" 148 | string += f" Replay: {self.replay}" 149 | string += f" » DataAug: {self._data_augmentation}" 150 | string += ( 151 | f" » DataAug Replay: {self._data_augmentation_for_replay}\n") 152 | string += "=" * 90 153 | return string 154 | 155 | def _preprocessing_hack(self, force=False): 156 | """ 157 | If training with aux_labels -> 158 | generates label for fast loading with a fixed certainty. 159 | """ 160 | 161 | # check if this has already been performed 162 | aux_label, method = self._label_loader.get( 163 | self.aux_label_pths[self.global_to_local_idx[0]]) 164 | print("Meethod ", method) 165 | print( 166 | "self.global_to_local_idx[0] ", 167 | self.global_to_local_idx[0], 168 | self.aux_label_pths[self.global_to_local_idx[0]], 169 | ) 170 | if method == "RGBA": 171 | 172 | # This should always evaluate to true 173 | if (self.aux_label_pths[self.global_to_local_idx[0]].find("_.png") 174 | == -1): 175 | print( 176 | "self.aux_label_pths[self.global_to_local_idx[0]]", 177 | self.aux_label_pths[self.global_to_local_idx[0]], 178 | self.global_to_local_idx[0], 179 | ) 180 | if (os.path.isfile(self.aux_label_pths[ 181 | self.global_to_local_idx[0]].replace(".png", "_.png")) 182 | and os.path.isfile(self.aux_label_pths[ 183 | self.global_to_local_idx[-1]].replace( 184 | ".png", "_.png")) and not force): 185 | # only perform simple renaming 186 | print("Only do renanming") 187 | self.aux_label_pths = [ 188 | a.replace(".png", "_.png") for a in self.aux_label_pths 189 | ] 190 | else: 191 | print("Start multithread preprocessing of images") 192 | 193 | def parallel(gtli, aux_label_pths, label_loader): 194 | print("Start take care of: ", gtli[0], " - ", gtli[-1]) 195 | for i in gtli: 196 | aux_label, method = label_loader.get( 197 | aux_label_pths[i]) 198 | imageio.imwrite( 199 | aux_label_pths[i].replace(".png", "_.png"), 200 | np.uint8(aux_label), 201 | ) 202 | 203 | def parallel2(aux_pths, label_loader): 204 | for a in aux_pths: 205 | aux_label, method = label_loader.get(a) 206 | imageio.imwrite( 207 | a.replace(".png", "_.png"), 208 | np.uint8(aux_label), 209 | ) 210 | 211 | cores = 16 212 | tasks = [ 213 | t.tolist() for t in np.array_split( 214 | np.array(self.global_to_local_idx), cores) 215 | ] 216 | 217 | from multiprocessing import Process 218 | 219 | for i in range(cores): 220 | p = Process( 221 | target=parallel2, 222 | args=( 223 | np.array(self.aux_label_pths)[np.array( 224 | tasks[i])].tolist(), 225 | self._label_loader, 226 | ), 227 | ) 228 | p.start() 229 | p.join() 230 | print("Done multithread preprocessing of images") 231 | self.aux_label_pths = [ 232 | a.replace(".png", "_.png") for a in self.aux_label_pths 233 | ] 234 | -------------------------------------------------------------------------------- /nr4seg/dataset/scannet_cl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | 5 | from glob import glob 6 | from torch.utils.data import Dataset 7 | 
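
`ScanNetCL` below pairs every NGP-rendered training frame with `ngp_25k_ratio` randomly drawn ScanNet-25k frames for replay, and its static `collate` flattens the nested samples into one batch. A usage sketch with hypothetical data paths and scene names (the real ones come from `cfg/env/*.yml` and the split files):

```python
from torch.utils.data import DataLoader

from nr4seg.dataset import ScanNet, ScanNetCL, ScanNetNGP

# Hypothetical paths/lists purely for illustration.
scannet_25k = ScanNet(root="/data/scannet_frames_25k",
                      img_list=["/data/scannet_frames_25k/scene0000_00/color/000000.jpg"],
                      mode="train")
scannet_ngp = ScanNetNGP(root="/data/scannet",
                         scene_list=["scene0000_00"],
                         mode="train")

dataset = ScanNetCL(scannet_25k, scannet_ngp, ngp_25k_ratio=1)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=ScanNetCL.collate)

# Each of the 4 NGP samples is followed by its replay sample(s), so the
# flattened batch holds 4 * (1 + ngp_25k_ratio) = 8 frames.
img, label, img_ori = next(iter(loader))
print(img.shape, label.shape)  # e.g. torch.Size([8, 3, 240, 320]) torch.Size([8, 240, 320])
```
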
8 | __all__ = ["ScanNetCL"] 9 | 10 | 11 | class ScanNetCL(Dataset): 12 | 13 | def __init__( 14 | self, 15 | scannet_25k, 16 | scannet_ngp, 17 | ngp_25k_ratio=1, 18 | ): 19 | """ 20 | Dataset dosent know if it contains replayed or normal samples ! 21 | 22 | Some images are stored in 640x480 other ins 1296x968 23 | Warning scene0088_03 has wrong resolution -> Ignored 24 | Parameters 25 | ---------- 26 | root : str, path to the ML-Hypersim folder 27 | mode : str, option ['train','val] 28 | """ 29 | 30 | super(ScanNetCL, self).__init__() 31 | self.scannet_25k = scannet_25k 32 | self.scannet_ngp = scannet_ngp 33 | self.ngp_25k_ratio = ngp_25k_ratio 34 | 35 | def get_image_pths(self, scene_list, val_ratio=0.2): 36 | img_list = [] 37 | for scene_name in scene_list: 38 | all_imgs = glob(self.root + "/" + scene_name + "/color/*jpg") 39 | all_imgs = sorted(all_imgs, 40 | key=lambda x: int(os.path.basename(x)[:-4])) 41 | val_imgs = all_imgs[-int(len(all_imgs) * val_ratio):] 42 | train_imgs = all_imgs[:-int(len(all_imgs) * val_ratio)] 43 | if self._mode == "train": 44 | img_list.extend(train_imgs[::self._sub]) 45 | else: 46 | img_list.extend(val_imgs[::self._sub]) 47 | 48 | return img_list 49 | 50 | def __getitem__(self, index): 51 | # Read Image and Label 52 | ret_ngp = self.scannet_ngp.__getitem__(index) 53 | ret_25k = [] 54 | 55 | for _ in range(self.ngp_25k_ratio): 56 | rand_id = random.randint(0, self.scannet_25k.__len__() - 1) 57 | ret_25k.append(self.scannet_25k.__getitem__(rand_id)) 58 | 59 | return (ret_ngp, ret_25k) 60 | 61 | @staticmethod 62 | def collate(batch): 63 | img = [] 64 | label = [] 65 | img_ori = [] 66 | 67 | for bb in batch: 68 | img.append(bb[0][0]) 69 | label.append(bb[0][1]) 70 | img_ori.append(bb[0][2]) 71 | for bbb in bb[1]: 72 | img.append(bbb[0]) 73 | label.append(bbb[1]) 74 | img_ori.append(bbb[2]) 75 | 76 | img = torch.stack(img, dim=0) 77 | label = torch.stack(label, dim=0) 78 | img_ori = torch.stack(img_ori, dim=0) 79 | return img, label, img_ori 80 | 81 | def __len__(self): 82 | return self.scannet_ngp.__len__() 83 | -------------------------------------------------------------------------------- /nr4seg/dataset/scannet_cl_joint.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | __all__ = ["ScanNetCLJoint"] 6 | 7 | 8 | class ScanNetCLJoint(Dataset): 9 | 10 | def __init__( 11 | self, 12 | scannet_25k, 13 | scannet_ngp, 14 | ngp_25k_ratio=1, 15 | ): 16 | """ 17 | Dataset dosent know if it contains replayed or normal samples ! 
18 | 19 | Some images are stored in 640x480 other ins 1296x968 20 | Warning scene0088_03 has wrong resolution -> Ignored 21 | Parameters 22 | ---------- 23 | root : str, path to the ML-Hypersim folder 24 | mode : str, option ['train','val] 25 | """ 26 | 27 | super(ScanNetCLJoint, self).__init__() 28 | self.scannet_25k = scannet_25k 29 | self.scannet_ngp = scannet_ngp 30 | self.ngp_25k_ratio = ngp_25k_ratio 31 | 32 | def __getitem__(self, index): 33 | # Read Image and Label 34 | ret_dict = self.scannet_ngp.__getitem__(index) 35 | ret_25k = {"replay_img": [], "replay_label": []} 36 | 37 | for _ in range(self.ngp_25k_ratio): 38 | rand_id = random.randint(0, self.scannet_25k.__len__() - 1) 39 | img, label, _ = self.scannet_25k.__getitem__(rand_id) 40 | ret_25k["replay_img"].append(img) 41 | ret_25k["replay_label"].append(label) 42 | 43 | for key in ret_25k.keys(): 44 | ret_25k[key] = torch.stack(ret_25k[key], dim=0) 45 | 46 | ret_dict.update(ret_25k) 47 | return ret_dict 48 | 49 | @staticmethod 50 | def collate(batch): 51 | img = [] 52 | label = [] 53 | img_ori = [] 54 | 55 | for bb in batch: 56 | img.append(bb[0][0]) 57 | label.append(bb[0][1]) 58 | img_ori.append(bb[0][2]) 59 | for bbb in bb[1]: 60 | img.append(bbb[0]) 61 | label.append(bbb[1]) 62 | img_ori.append(bbb[2]) 63 | 64 | img = torch.stack(img, dim=0) 65 | label = torch.stack(label, dim=0) 66 | img_ori = torch.stack(img_ori, dim=0) 67 | return batch_new, batch_old 68 | 69 | def __len__(self): 70 | return self.scannet_ngp.__len__() 71 | 72 | def __len__(self): 73 | return self.scannet_ngp.__len__() 74 | -------------------------------------------------------------------------------- /nr4seg/dataset/scannet_ngp.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | import torch 5 | 6 | from glob import glob 7 | from torch.utils.data import Dataset 8 | 9 | try: 10 | from .helper import AugmentationList 11 | except Exception: 12 | from helper import AugmentationList 13 | 14 | __all__ = ["ScanNetNGP"] 15 | 16 | 17 | class ScanNetNGP(Dataset): 18 | 19 | def __init__( 20 | self, 21 | root, 22 | scene_list, 23 | prev_exp_name="one_step_nerf_only", 24 | mode="train", 25 | train_image="nerf", 26 | train_label="nerf", 27 | val_mode="gtgt", 28 | output_trafo=None, 29 | output_size=(240, 320), 30 | degrees=10, 31 | flip_p=0.5, 32 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05], 33 | sub=1, 34 | data_augmentation=True, 35 | label_setting="default", 36 | confidence_aux=0, 37 | ): 38 | """ 39 | Dataset dosent know if it contains replayed or normal samples ! 
40 | 41 | Some images are stored in 640x480 other ins 1296x968 42 | Warning scene0088_03 has wrong resolution -> Ignored 43 | Parameters 44 | ---------- 45 | root : str, path to the ML-Hypersim folder 46 | mode : str, option ['train','val] 47 | """ 48 | 49 | super(ScanNetNGP, self).__init__() 50 | self._sub = sub 51 | self._mode = mode 52 | self._confidence_aux = confidence_aux 53 | 54 | self.H = output_size[0] 55 | self.W = output_size[1] 56 | 57 | self._label_setting = label_setting 58 | self.root = root 59 | self.image_pths, self.img_num = self.get_image_pths(scene_list) 60 | 61 | self.image_gt_pths = self.image_pths 62 | self.image_nerf_pths = [ 63 | p.replace("color_scaled", 64 | prev_exp_name + "/nerf_image").replace("jpg", "png") 65 | for p in self.image_pths 66 | ] 67 | self.label_nerf_pths = [ 68 | p.replace("color_scaled", 69 | prev_exp_name + "/nerf_label").replace("jpg", "png") 70 | for p in self.image_pths 71 | ] 72 | self.label_mapping_pths = [ 73 | p.replace("color_scaled", "mapping_label").replace("jpg", "png") 74 | for p in self.image_pths 75 | ] 76 | self.label_gt_pths = [ 77 | p.replace("color_scaled", "label_scaled").replace("jpg", "png") 78 | for p in self.image_pths 79 | ] 80 | 81 | self.length = len(self.image_pths) 82 | self._augmenter = AugmentationList(output_size, degrees, flip_p, 83 | jitter_bcsh) 84 | self._output_trafo = output_trafo 85 | self._data_augmentation = data_augmentation 86 | self.train_image = train_image 87 | self.train_label = train_label 88 | self.val_mode = val_mode 89 | 90 | def get_image_pths(self, scene_list, val_ratio=0.2): 91 | img_list = [] 92 | img_num = [] 93 | for scene_name in scene_list: 94 | all_imgs = glob(self.root + "/" + scene_name + "/color_scaled/*jpg") 95 | all_imgs = sorted(all_imgs, 96 | key=lambda x: int(os.path.basename(x)[:-4])) 97 | val_imgs = all_imgs[-int(len(all_imgs) * val_ratio):] 98 | # val_imgs = val_imgs[: len(val_imgs)//4*4] 99 | train_imgs = all_imgs[:-int(len(all_imgs) * val_ratio)] 100 | if self._mode == "train": 101 | img_list.extend(train_imgs[::self._sub]) 102 | img_num.append(len(train_imgs[::self._sub])) 103 | else: 104 | img_list.extend(val_imgs[::self._sub]) 105 | 106 | return img_list, img_num 107 | 108 | def __getitem__(self, index): 109 | # Read Image and Label 110 | 111 | use_nerf_label = False 112 | if self._mode == "train": 113 | if self.train_image == "gt": 114 | img = cv2.imread(self.image_gt_pths[index], 115 | cv2.IMREAD_UNCHANGED) 116 | elif self.train_image == "nerf": 117 | img = cv2.imread(self.image_nerf_pths[index], 118 | cv2.IMREAD_UNCHANGED) 119 | elif self.train_image == "half": 120 | img = (cv2.imread(self.image_gt_pths[index], 121 | cv2.IMREAD_UNCHANGED) 122 | if random.random() > 0.5 else cv2.imread( 123 | self.image_nerf_pths[index], cv2.IMREAD_UNCHANGED)) 124 | else: 125 | raise NotImplementedError 126 | else: 127 | if self.val_mode == "gtgt": 128 | img = cv2.imread(self.image_gt_pths[index], 129 | cv2.IMREAD_UNCHANGED) 130 | elif self.val_mode in ["nerfgt", "nerfnerf"]: 131 | img = cv2.imread(self.image_nerf_pths[index], 132 | cv2.IMREAD_UNCHANGED) 133 | else: 134 | raise NotImplementedError 135 | 136 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 137 | img = img / 255.0 138 | img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA) 139 | img = (torch.from_numpy(img).type(torch.float32).permute(2, 0, 1) 140 | ) # C H W range 0-1 141 | 142 | if self._mode == "train": 143 | if self.train_label == "nerf": 144 | label = cv2.imread(self.label_nerf_pths[index], 145 | 
cv2.IMREAD_UNCHANGED) 146 | else: 147 | label = cv2.imread(self.label_mapping_pths[index], 148 | cv2.IMREAD_UNCHANGED) 149 | use_nerf_label = self.train_label == "nerf" 150 | else: 151 | if self.val_mode in ["gtgt", "nerfgt"]: 152 | label = cv2.imread(self.label_gt_pths[index], 153 | cv2.IMREAD_UNCHANGED) 154 | elif self.val_mode == "nerfnerf": 155 | label = cv2.imread(self.label_nerf_pths[index], 156 | cv2.IMREAD_UNCHANGED) 157 | use_nerf_label = True 158 | 159 | label = cv2.resize(label, (self.W, self.H), 160 | interpolation=cv2.INTER_NEAREST) 161 | label = torch.from_numpy(label).type( 162 | torch.float32)[None, :, :] # C H W -> contains 0-40 163 | 164 | # if use nerf labels + 1 165 | if use_nerf_label: 166 | label = label + 1 167 | 168 | label = [label] 169 | 170 | if self._mode.find("train") != -1 and self._data_augmentation: 171 | img, label = self._augmenter.apply(img, label) 172 | else: 173 | img, label = self._augmenter.apply(img, label, only_crop=True) 174 | 175 | label[0] = label[0] - 1 176 | 177 | img_ori = img.clone() 178 | if self._output_trafo is not None: 179 | img = self._output_trafo(img) 180 | 181 | ret = (img, label[0].type(torch.int64)[0, :, :]) 182 | ret += (img_ori,) 183 | 184 | if self._mode != "train": 185 | current_scene_name = os.path.normpath(self.image_pths[index]).split( 186 | os.path.sep)[-3] 187 | ret += (current_scene_name,) 188 | 189 | return ret 190 | 191 | def __len__(self): 192 | return self.length 193 | 194 | def __str__(self): 195 | string = "=" * 90 196 | string += "\nScannet Dataset: \n" 197 | length = len(self) 198 | string += f" Total Samples: {length}" 199 | string += f" » Mode: {self._mode} \n" 200 | string += f" » DataAug: {self._data_augmentation}" 201 | string += "=" * 90 202 | return string 203 | -------------------------------------------------------------------------------- /nr4seg/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from .finetune_data_module import FineTuneDataModule 2 | from .joint_train_data_module import JointTrainDataModule 3 | from .joint_train_lightning_net import JointTrainLightningNet 4 | from .pretrain_data_module import PretrainDataModule 5 | from .semantics_lightning_net import SemanticsLightningNet 6 | 7 | __all__ = [ 8 | "FineTuneDataModule", 9 | "JointTrainDataModule", 10 | "JointTrainLightningNet", 11 | "PretrainDataModule", 12 | "SemanticsLightningNet", 13 | ] 14 | -------------------------------------------------------------------------------- /nr4seg/lightning/finetune_data_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pytorch_lightning as pl 4 | 5 | from torch.utils.data import DataLoader 6 | from typing import Optional 7 | 8 | from nr4seg.dataset import ScanNet, ScanNetCL, ScanNetNGP 9 | 10 | 11 | class FineTuneDataModule(pl.LightningDataModule): 12 | 13 | def __init__( 14 | self, 15 | exp: dict, 16 | env: dict, 17 | prev_exp_name: str = "one_step_nerf_only", 18 | ): 19 | super().__init__() 20 | 21 | self.env = env 22 | self.exp = exp 23 | self.cfg_loader = self.exp["data_module"] 24 | self.prev_exp_name = prev_exp_name 25 | 26 | def setup(self, stage: Optional[str] = None) -> None: 27 | ## test adaption (last 20% of the new scenes) 28 | finetune_seqs = self.exp["scenes"] 29 | self.scannet_test_ada = ScanNetNGP( 30 | root=self.env["scannet"], 31 | mode="val", # val 32 | val_mode="gtgt", 33 | scene_list=finetune_seqs, 34 | ) 35 | ## test generation 36 | 
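
The split files loaded below (and written by `create_split.py` earlier) are plain `.npz` archives keyed by split name, so they can be inspected directly. Paths and file names here are placeholders; the configured names come from the `data_preprocessing` section of the experiment config (the repository ships `split.npz` / `split_cl.npz` under `cfg/dataset/scannet`):

```python
import numpy as np

split = np.load("/data/scannet_frames_25k/split.npz")
print(split.files)              # e.g. ['train', 'test', 'val'] -- sorted arrays of image paths
test_imgs = split["test"]

split_cl = np.load("/data/scannet_frames_25k/split_cl.npz")
replay_imgs = split_cl["train_cl"]  # replay pool for continual learning
```
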
scannet_25k_dir = self.env["scannet_frames_25k"] 37 | split_file = os.path.join( 38 | scannet_25k_dir, 39 | self.cfg_loader["data_preprocessing"]["split_file"], 40 | ) 41 | img_list = np.load(split_file) 42 | self.scannet_test_gen = ScanNet( 43 | root=self.cfg_loader["root"], 44 | img_list=img_list["test"], 45 | mode="test", 46 | ) 47 | ## train (all Torch-NGP generated labels) 48 | # random select from 49 | 50 | if not self.exp["cl"]["active"]: 51 | self.scannet_train = ScanNetNGP( 52 | root=self.env["scannet"], 53 | train_image=self.cfg_loader["train_image"], 54 | train_label=self.cfg_loader["train_label"], 55 | mode="train", 56 | scene_list=finetune_seqs, 57 | prev_exp_name=self.prev_exp_name, 58 | ) 59 | else: 60 | scannet_ngp = ScanNetNGP( 61 | root=self.env["scannet"], 62 | train_image=self.cfg_loader["train_image"], 63 | train_label=self.cfg_loader["train_label"], 64 | mode="train", 65 | scene_list=finetune_seqs, 66 | prev_exp_name=self.prev_exp_name, 67 | ) 68 | split_file_cl = os.path.join( 69 | scannet_25k_dir, 70 | self.cfg_loader["data_preprocessing"]["split_file_cl"], 71 | ) 72 | img_list_cl = np.load(split_file_cl)["train_cl"] 73 | img_list_cl = img_list_cl[:int(self.exp["cl"]["25k_fraction"] * 74 | len(img_list_cl))] 75 | scannet_25k = ScanNet( 76 | root=self.cfg_loader["root"], 77 | img_list=img_list_cl, 78 | mode="train", 79 | ) 80 | self.scannet_train = ScanNetCL( 81 | scannet_25k, 82 | scannet_ngp, 83 | ngp_25k_ratio=self.exp["cl"]["ngp_25k_ratio"], 84 | ) 85 | 86 | def train_dataloader(self) -> DataLoader: 87 | return DataLoader( 88 | self.scannet_train, 89 | batch_size=self.cfg_loader["batch_size"], 90 | drop_last=True, 91 | shuffle=True, # only true in train_dataloader 92 | collate_fn=self.scannet_train.collate 93 | if self.exp["cl"]["active"] else None, 94 | pin_memory=self.cfg_loader["pin_memory"], 95 | num_workers=self.cfg_loader["num_workers"], 96 | ) 97 | 98 | def val_dataloader(self) -> DataLoader: 99 | return DataLoader( 100 | self.scannet_test_ada, 101 | batch_size= 102 | 1, ## set bs=1 to ensure a batch always has frames from the same scene 103 | drop_last=False, 104 | shuffle=False, # false 105 | pin_memory=self.cfg_loader["pin_memory"], 106 | num_workers=self.cfg_loader["num_workers"], 107 | ) 108 | 109 | def test_dataloader(self) -> DataLoader: 110 | return DataLoader( 111 | self.scannet_test_gen, 112 | batch_size=self.cfg_loader["batch_size"], 113 | drop_last=False, 114 | shuffle=False, # false 115 | pin_memory=self.cfg_loader["pin_memory"], 116 | num_workers=self.cfg_loader["num_workers"], 117 | ) 118 | -------------------------------------------------------------------------------- /nr4seg/lightning/joint_train_data_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pytorch_lightning as pl 4 | 5 | from torch.utils.data import DataLoader 6 | from typing import Optional 7 | 8 | from nr4seg.dataset import ScanNet, ScanNetCLJoint, ScanNetNGPJoint 9 | 10 | 11 | class JointTrainDataModule(pl.LightningDataModule): 12 | 13 | def __init__( 14 | self, 15 | exp: dict, 16 | env: dict, 17 | split_ratio: float = 0.2, 18 | ): 19 | super().__init__() 20 | 21 | self.env = env 22 | self.exp = exp 23 | self.cfg_loader = self.exp["data_module"] 24 | 25 | self.split_ratio = split_ratio 26 | self.test_sz = 0 27 | self.val_sz = 0 28 | self.train_sz = 0 29 | 30 | def setup(self, stage: Optional[str] = None) -> None: 31 | finetune_seqs = self.exp["scenes"] 32 | self.scannet_val_ada = 
ScanNetNGPJoint( 33 | root=self.env["scannet"], 34 | mode="val", # val 35 | scene_list=finetune_seqs, 36 | exp_name=self.exp["exp_name"], 37 | only_new_scene=False, 38 | ) 39 | self.scannet_train_val_ada = ScanNetNGPJoint( 40 | root=self.env["scannet"], 41 | mode="train_val", # train_val 42 | scene_list=finetune_seqs, 43 | exp_name=self.exp["exp_name"], 44 | only_new_scene=False, 45 | ) 46 | self.scannet_predict_ada = ScanNetNGPJoint( 47 | root=self.env["scannet"], 48 | mode="predict", # val 49 | scene_list=finetune_seqs, 50 | exp_name=self.exp["exp_name"], 51 | use_novel_viewpoints=self.exp["cl"]["use_novel_viewpoints"], 52 | only_new_scene=True, 53 | ) 54 | ## test generation 55 | scannet_25k_dir = self.env["scannet_frames_25k"] 56 | split_file = os.path.join( 57 | scannet_25k_dir, 58 | self.cfg_loader["data_preprocessing"]["split_file"], 59 | ) 60 | img_list = np.load(split_file) 61 | self.scannet_test_gen = ScanNet( 62 | root=self.env["scannet_frames_25k"], 63 | img_list=img_list["test"], 64 | mode="test", 65 | ) 66 | 67 | self.scannet_train_nerf = ScanNetNGPJoint( 68 | root=self.env["scannet"], 69 | mode="train", 70 | scene_list=finetune_seqs, 71 | exp_name=self.exp["exp_name"], 72 | only_new_scene=True, 73 | ) 74 | print("\033[93mNOTE: By default, the replay buffer is set to have size " 75 | "100.\033[0m") 76 | self.scannet_train_joint = ScanNetNGPJoint( 77 | root=self.env["scannet"], 78 | mode="train", 79 | scene_list=finetune_seqs, 80 | exp_name=self.exp["exp_name"], 81 | only_new_scene=False, 82 | # NOTE: This is referred to whether the previous scenes (if any) 83 | # used for replay were generated from novel viewpoints. 84 | use_novel_viewpoints=self.exp["cl"]["use_novel_viewpoints"], 85 | fix_nerf=False, 86 | replay_buffer_size=self.exp["cl"]["replay_buffer_size"], 87 | ) 88 | ## train (all Torch-NGP generated labels) 89 | # random select from 90 | if self.exp["cl"]["active"]: 91 | split_file_cl = os.path.join( 92 | scannet_25k_dir, 93 | self.cfg_loader["data_preprocessing"]["split_file_cl"], 94 | ) 95 | img_list_cl = np.load(split_file_cl)["train_cl"] 96 | img_list_cl = img_list_cl[:int(self.exp["cl"]["25k_fraction"] * 97 | len(img_list_cl))] 98 | scannet_25k = ScanNet( 99 | root=self.env["scannet_frames_25k"], 100 | img_list=img_list_cl, 101 | mode="train", 102 | ) 103 | self.scannet_train_joint = ScanNetCLJoint( 104 | scannet_25k, 105 | self.scannet_train_joint, 106 | ngp_25k_ratio=self.exp["cl"]["ngp_25k_ratio"], 107 | ) 108 | 109 | self.test_sz = len(self.scannet_test_gen) 110 | self.val_sz = len(self.scannet_val_ada) 111 | self.train_sz = len(self.scannet_train_joint) 112 | print("Train/Val/Test/Total: {}/{}/{}/{}".format( 113 | self.train_sz, 114 | self.val_sz, 115 | self.test_sz, 116 | self.train_sz + self.val_sz + self.test_sz, 117 | )) 118 | 119 | def train_dataloader_nerf(self) -> DataLoader: 120 | return DataLoader( 121 | self.scannet_train_nerf, 122 | batch_size=1, 123 | drop_last=False, 124 | shuffle=True, # only true in train_dataloader 125 | pin_memory=self.cfg_loader["pin_memory"], 126 | num_workers=self.cfg_loader["num_workers"], 127 | ) 128 | 129 | def train_dataloader_joint(self) -> DataLoader: 130 | return DataLoader( 131 | self.scannet_train_joint, 132 | batch_size=self.cfg_loader["batch_size"], 133 | drop_last=True, 134 | shuffle=True, # only true in train_dataloader 135 | pin_memory=self.cfg_loader["pin_memory"], 136 | num_workers=self.cfg_loader["num_workers"], 137 | collate_fn=ScanNetNGPJoint.collate, 138 | ) 139 | 140 | def val_dataloader(self) -> 
DataLoader: 141 | return [ 142 | DataLoader( 143 | self.scannet_val_ada, 144 | # Set bs=1 to ensure a batch always has frames from the same 145 | # scene. 146 | batch_size=1, 147 | drop_last=False, 148 | shuffle=False, # false 149 | pin_memory=self.cfg_loader["pin_memory"], 150 | num_workers=self.cfg_loader["num_workers"], 151 | ), 152 | DataLoader( 153 | self.scannet_train_val_ada, 154 | # Set bs=1 to ensure a batch always has frames from the same 155 | # scene. 156 | batch_size=1, 157 | drop_last=False, 158 | shuffle=False, # false 159 | pin_memory=self.cfg_loader["pin_memory"], 160 | num_workers=self.cfg_loader["num_workers"], 161 | ), 162 | ] 163 | 164 | def test_dataloader(self) -> DataLoader: 165 | return [ 166 | DataLoader( 167 | self.scannet_train_nerf, 168 | batch_size=1, 169 | drop_last=False, 170 | shuffle=False, # false 171 | pin_memory=self.cfg_loader["pin_memory"], 172 | num_workers=self.cfg_loader["num_workers"], 173 | ), 174 | DataLoader( 175 | self.scannet_test_gen, 176 | batch_size=4, 177 | drop_last=False, 178 | shuffle=False, # false 179 | pin_memory=self.cfg_loader["pin_memory"], 180 | num_workers=self.cfg_loader["num_workers"], 181 | ), 182 | ] 183 | 184 | def test_dataloader_nerf(self) -> DataLoader: 185 | return DataLoader( 186 | self.scannet_train_nerf, 187 | batch_size=1, 188 | drop_last=False, 189 | shuffle=False, # false 190 | pin_memory=self.cfg_loader["pin_memory"], 191 | num_workers=self.cfg_loader["num_workers"], 192 | ) 193 | 194 | def predict_dataloader(self) -> DataLoader: 195 | return DataLoader( 196 | self.scannet_predict_ada, 197 | batch_size=1, 198 | drop_last=False, 199 | shuffle=False, # false 200 | pin_memory=self.cfg_loader["pin_memory"], 201 | num_workers=self.cfg_loader["num_workers"], 202 | ) 203 | -------------------------------------------------------------------------------- /nr4seg/lightning/pretrain_data_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pytorch_lightning as pl 4 | 5 | from torch.utils.data import DataLoader 6 | from typing import Optional 7 | 8 | from nr4seg.dataset import ScanNet 9 | 10 | 11 | class PretrainDataModule(pl.LightningDataModule): 12 | 13 | def __init__(self, env: dict, cfg_dm: dict): 14 | super().__init__() 15 | 16 | self.cfg_dm = cfg_dm 17 | self.env = env 18 | 19 | def setup(self, stage: Optional[str] = None) -> None: 20 | split_file = os.path.join( 21 | self.cfg_dm["root"], 22 | self.cfg_dm["data_preprocessing"]["split_file"], 23 | ) 24 | img_list = np.load(split_file) 25 | self.scannet_test = ScanNet(root=self.cfg_dm["root"], 26 | img_list=img_list["test"], 27 | mode="test") 28 | self.scannet_train = ScanNet(root=self.cfg_dm["root"], 29 | img_list=img_list["train"], 30 | mode="train") 31 | self.scannet_val = ScanNet(root=self.cfg_dm["root"], 32 | img_list=img_list["val"], 33 | mode="val") 34 | 35 | def train_dataloader(self) -> DataLoader: 36 | return DataLoader( 37 | self.scannet_train, 38 | batch_size=self.cfg_dm["batch_size"], 39 | drop_last=self.cfg_dm["drop_last"], 40 | shuffle=self.cfg_dm["shuffle"], 41 | pin_memory=self.cfg_dm["pin_memory"], 42 | num_workers=self.cfg_dm["num_workers"], 43 | ) 44 | 45 | def val_dataloader(self) -> DataLoader: 46 | return DataLoader( 47 | self.scannet_val, 48 | batch_size=self.cfg_dm["batch_size"], 49 | drop_last=self.cfg_dm["drop_last"], 50 | shuffle=False, 51 | pin_memory=self.cfg_dm["pin_memory"], 52 | num_workers=self.cfg_dm["num_workers"], 53 | ) 54 | 55 | def 
test_dataloader(self) -> DataLoader: 56 | return DataLoader( 57 | self.scannet_test, 58 | batch_size=self.cfg_dm["batch_size"], 59 | drop_last=self.cfg_dm["drop_last"], 60 | shuffle=False, 61 | pin_memory=self.cfg_dm["pin_memory"], 62 | num_workers=self.cfg_dm["num_workers"], 63 | ) 64 | -------------------------------------------------------------------------------- /nr4seg/lightning/semantics_lightning_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from nr4seg.network import DeepLabV3 7 | from nr4seg.utils.metrics import SemanticsMeter 8 | from nr4seg.visualizer import Visualizer 9 | 10 | 11 | class SemanticsLightningNet(pl.LightningModule): 12 | 13 | def __init__(self, exp, env): 14 | super().__init__() 15 | self._model = DeepLabV3(exp["model"]) 16 | self.prev_scene_name = None 17 | 18 | self._visualizer = Visualizer( 19 | os.path.join(exp["general"]["name"], "visu"), 20 | exp["visualizer"]["store"], 21 | self, 22 | ) 23 | 24 | self._meter = { 25 | "val_1": SemanticsMeter(number_classes=exp["model"]["num_classes"]), 26 | "val_2": SemanticsMeter(number_classes=exp["model"]["num_classes"]), 27 | "val_3": SemanticsMeter(number_classes=exp["model"]["num_classes"]), 28 | "test": SemanticsMeter(number_classes=exp["model"]["num_classes"]), 29 | "train": SemanticsMeter(number_classes=exp["model"]["num_classes"]), 30 | } 31 | 32 | self._visu_count = {"val": 0, "test": 0, "train": 0} 33 | 34 | self._exp = exp 35 | self._env = env 36 | self._mode = "train" 37 | self.length_train_dataloader = 10000 38 | 39 | def forward(self, image: torch.Tensor) -> torch.Tensor: 40 | return self._model(image) 41 | 42 | def visu(self, image, target, pred): 43 | if not (self._visu_count[self._mode] 44 | < self._exp["visualizer"]["store_n"][self._mode]): 45 | return 46 | 47 | for b in range(image.shape[0]): 48 | if (self._visu_count[self._mode] 49 | < self._exp["visualizer"]["store_n"][self._mode]): 50 | c = self._visu_count[self._mode] 51 | self._visualizer.plot_image(image[b], 52 | tag=f"{self._mode}_vis/image_{c}") 53 | self._visualizer.plot_segmentation( 54 | pred[b], tag=f"{self._mode}_vis/pred_{c}") 55 | self._visualizer.plot_segmentation( 56 | target[b], tag=f"{self._mode}_vis/target_{c}") 57 | 58 | self._visualizer.plot_detectron( 59 | image[b], target[b], tag=f"{self._mode}_vis/detectron_{c}") 60 | 61 | self._visu_count[self._mode] += 1 62 | else: 63 | break 64 | 65 | # TRAINING 66 | def on_train_epoch_start(self): 67 | self._mode = "train" 68 | self._visu_count[self._mode] = 0 69 | self._meter["train"].clear() 70 | 71 | def training_step(self, batch, batch_idx: int) -> torch.Tensor: 72 | image, target, ori_image = batch 73 | output = self(image) 74 | pred = F.softmax(output["out"], dim=1) 75 | pred_argmax = torch.argmax(pred, dim=1) 76 | all_pred_argmax = self.all_gather(pred_argmax) 77 | all_target = self.all_gather(target) 78 | self._meter[self._mode].update(all_pred_argmax, all_target) 79 | # Compute Loss 80 | loss = F.cross_entropy(pred, target, ignore_index=-1, reduction="none") 81 | # Visu 82 | # self.visu(ori_image, target+1, pred_argmax+1) 83 | # Loss loggging 84 | self.log( 85 | f"{self._mode}/loss", 86 | loss.mean().item(), 87 | on_step=self._mode == "train", 88 | on_epoch=self._mode != "train", 89 | ) 90 | return loss.mean() 91 | 92 | def on_train_epoch_end(self): 93 | m_iou, total_acc, m_acc = self._meter["train"].measure() 94 | 
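
`SemanticsMeter` (defined in `nr4seg/utils/metrics.py`, not shown in this listing) is expected to return mean IoU, total pixel accuracy, and mean per-class accuracy from `measure()`. A rough confusion-matrix sketch of those three quantities — an illustration, not the repository's implementation:

```python
import torch

def measure_from_confusion(conf: torch.Tensor):
    """conf[i, j] = number of pixels with ground-truth class i predicted as class j."""
    tp = conf.diag().float()
    fp = conf.sum(dim=0).float() - tp
    fn = conf.sum(dim=1).float() - tp
    iou = tp / (tp + fp + fn).clamp(min=1)                         # per-class IoU
    total_acc = tp.sum() / conf.sum().clamp(min=1)                 # overall pixel accuracy
    mean_acc = (tp / conf.sum(dim=1).float().clamp(min=1)).mean()  # mean per-class recall
    return iou.mean().item(), total_acc.item(), mean_acc.item()
```
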
self.log(f"train/total_accuracy", total_acc, rank_zero_only=True) 95 | self.log(f"train/mean_accuracy", m_acc, rank_zero_only=True) 96 | self.log(f"train/mean_IoU", m_iou, rank_zero_only=True) 97 | 98 | # VALIDATION 99 | def on_validation_epoch_start(self): 100 | self._mode = "val" 101 | self._visu_count[self._mode] = 0 102 | self._meter["val_1"].clear() 103 | self._meter["val_2"].clear() 104 | self._meter["val_3"].clear() 105 | 106 | def validation_step(self, batch, batch_idx: int) -> None: 107 | dataloader_idx = 0 # TODO: modify back 108 | image, target, ori_image, scene_name = batch 109 | scene_name = scene_name[0] 110 | output = self(image) 111 | pred = F.softmax(output["out"], dim=1) 112 | pred_argmax = torch.argmax(pred, dim=1) 113 | all_pred_argmax = self.all_gather(pred_argmax) 114 | all_target = self.all_gather(target) 115 | 116 | self.prev_scene_name = scene_name 117 | self._meter[f"val_{dataloader_idx+1}"].update(all_pred_argmax, 118 | all_target) 119 | 120 | # Compute Loss 121 | loss = F.cross_entropy(pred, target, ignore_index=-1, reduction="none") 122 | # Visu 123 | # self.visu(ori_image, target+1, pred_argmax+1) 124 | # Loss loggging 125 | self.log( 126 | f"{self._mode}/loss", 127 | loss.mean().item(), 128 | on_step=self._mode == "train", 129 | on_epoch=self._mode != "train", 130 | ) 131 | return loss.mean() 132 | 133 | def on_validation_epoch_end(self): 134 | m_iou_1, total_acc_1, m_acc_1 = self._meter["val_1"].measure() 135 | self.log(f"val/total_accuracy_gg", total_acc_1, rank_zero_only=True) 136 | self.log(f"val/mean_accuracy_gg", m_acc_1, rank_zero_only=True) 137 | self.log(f"val/mean_IoU_gg", m_iou_1, rank_zero_only=True) 138 | self.prev_scene_name = None 139 | 140 | # TESTING 141 | def on_test_epoch_start(self): 142 | self._mode = "test" 143 | self._visu_count[self._mode] = 0 144 | self._meter["test"].clear() 145 | 146 | def test_step(self, batch, batch_idx: int) -> None: 147 | return self.training_step(batch, batch_idx) 148 | 149 | def on_test_epoch_end(self): 150 | m_iou, total_acc, m_acc = self._meter["test"].measure() 151 | self.log(f"test/total_accuracy", total_acc, rank_zero_only=True) 152 | self.log(f"test/mean_accuracy", m_acc, rank_zero_only=True) 153 | self.log(f"test/mean_IoU", m_iou, rank_zero_only=True) 154 | 155 | def configure_optimizers(self) -> torch.optim.Optimizer: 156 | optimizer = self._exp["optimizer"]["name"] 157 | lr = self._exp["optimizer"]["lr"] 158 | if optimizer == "Adam": 159 | optimizer = torch.optim.Adam(self._model.parameters(), lr=lr) 160 | if optimizer == "SGD": 161 | sgd_cfg = self._exp["optimizer"]["sgd_cfg"] 162 | optimizer = torch.optim.SGD( 163 | self._model.parameters(), 164 | lr=lr, 165 | momentum=sgd_cfg["momentum"], 166 | weight_decay=sgd_cfg["weight_decay"], 167 | ) 168 | if optimizer == "Adadelta": 169 | optimizer = torch.optim.Adadelta(self._model.parameters(), lr=lr) 170 | if optimizer == "RMSprop": 171 | optimizer = torch.optim.RMSprop(self._model.parameters(), 172 | momentum=0.9, 173 | lr=lr) 174 | if self._exp["lr_scheduler"]["active"]: 175 | scheduler = self._exp["lr_scheduler"]["name"] 176 | if scheduler == "POLY": 177 | init_lr = (self._exp["optimizer"]["lr"],) 178 | max_epochs = self._exp["lr_scheduler"]["poly_cfg"]["max_epochs"] 179 | target_lr = self._exp["lr_scheduler"]["poly_cfg"]["target_lr"] 180 | power = self._exp["lr_scheduler"]["poly_cfg"]["power"] 181 | lambda_lr = (lambda epoch: ( 182 | ((max_epochs - min(max_epochs, epoch)) / max_epochs)** 183 | (power)) + (1 - (( 184 | (max_epochs - min(max_epochs, 
epoch)) / max_epochs)** 185 | (power))) * target_lr / init_lr) 186 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 187 | lambda_lr, 188 | last_epoch=-1, 189 | verbose=True) 190 | interval = "epoch" 191 | lr_scheduler = {"scheduler": scheduler, "interval": interval} 192 | ret = {"optimizer": optimizer, "lr_scheduler": lr_scheduler} 193 | else: 194 | ret = [optimizer] 195 | return ret 196 | -------------------------------------------------------------------------------- /nr4seg/nerf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/nr4seg/nerf/__init__.py -------------------------------------------------------------------------------- /nr4seg/nerf/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.autograd import Function 4 | from torch.cuda.amp import custom_bwd, custom_fwd 5 | 6 | 7 | class _trunc_exp(Function): 8 | 9 | @staticmethod 10 | @custom_fwd(cast_inputs=torch.float) 11 | def forward(ctx, x): 12 | ctx.save_for_backward(x) 13 | return torch.exp(x) 14 | 15 | @staticmethod 16 | @custom_bwd 17 | def backward(ctx, g): 18 | x = ctx.saved_tensors[0] 19 | return g * torch.exp(x.clamp(-15, 15)) 20 | 21 | 22 | trunc_exp = _trunc_exp.apply -------------------------------------------------------------------------------- /nr4seg/nerf/network_tcnn_semantics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tinycudann as tcnn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .activation import trunc_exp 7 | from .renderer_semantics import SemanticNeRFRenderer 8 | 9 | 10 | class SemanticNeRFNetwork(SemanticNeRFRenderer): 11 | 12 | def __init__(self, 13 | encoding="HashGrid", 14 | encoding_dir="SphericalHarmonics", 15 | num_layers=2, 16 | hidden_dim=64, 17 | geo_feat_dim=15, 18 | num_layers_color=3, 19 | hidden_dim_color=64, 20 | num_layers_semantics=2, 21 | hidden_dim_semantics=64, 22 | bound=1, 23 | num_semantic_classes=41, 24 | **kwargs): 25 | super().__init__(bound, 26 | **kwargs, 27 | num_semantic_classes=num_semantic_classes) 28 | 29 | # sigma network 30 | self.num_layers = num_layers 31 | self.hidden_dim = hidden_dim 32 | self.geo_feat_dim = geo_feat_dim 33 | 34 | per_level_scale = np.exp2(np.log2(2048 * bound / 16) / (16 - 1)) 35 | 36 | self.encoder = tcnn.Encoding( 37 | n_input_dims=3, 38 | encoding_config={ 39 | "otype": "HashGrid", 40 | "n_levels": 16, 41 | "n_features_per_level": 2, 42 | "log2_hashmap_size": 19, 43 | "base_resolution": 16, 44 | "per_level_scale": per_level_scale, 45 | }, 46 | ) 47 | 48 | self.sigma_net = tcnn.Network( 49 | n_input_dims=32, 50 | n_output_dims=1 + self.geo_feat_dim, 51 | network_config={ 52 | "otype": "FullyFusedMLP", 53 | "activation": "ReLU", 54 | "output_activation": "None", 55 | "n_neurons": hidden_dim, 56 | "n_hidden_layers": num_layers - 1, 57 | }, 58 | ) 59 | 60 | # color network 61 | self.num_layers_color = num_layers_color 62 | self.hidden_dim_color = hidden_dim_color 63 | 64 | self.encoder_dir = tcnn.Encoding( 65 | n_input_dims=3, 66 | encoding_config={ 67 | "otype": "SphericalHarmonics", 68 | "degree": 4, 69 | }, 70 | ) 71 | 72 | self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim 73 | 74 | self.color_net = tcnn.Network( 75 | n_input_dims=self.in_dim_color, 76 | n_output_dims=3, 77 | network_config={ 78 | 
"otype": "FullyFusedMLP", 79 | "activation": "ReLU", 80 | "output_activation": "None", 81 | "n_neurons": hidden_dim_color, 82 | "n_hidden_layers": num_layers_color - 1, 83 | }, 84 | ) 85 | 86 | self.num_layers_semantics = num_layers_semantics 87 | self.hidden_dim_semantics = hidden_dim_semantics 88 | self.num_semantic_classes = num_semantic_classes 89 | self.in_dim_semantics = self.geo_feat_dim 90 | self.semantics_net = tcnn.Network( 91 | n_input_dims=self.in_dim_semantics, 92 | n_output_dims=self.num_semantic_classes, 93 | network_config={ 94 | "otype": "FullyFusedMLP", 95 | "activation": "ReLU", 96 | "output_activation": "None", 97 | "n_neurons": hidden_dim_semantics, 98 | "n_hidden_layers": num_layers_semantics - 1, 99 | }, 100 | ) 101 | 102 | def forward(self, x, d): 103 | # x: [N, 3], in [-bound, bound] 104 | # d: [N, 3], nomalized in [-1, 1] 105 | 106 | # sigma 107 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 108 | x = self.encoder(x) 109 | h = self.sigma_net(x) 110 | 111 | # sigma = F.relu(h[..., 0]) 112 | sigma = trunc_exp(h[..., 0]) 113 | geo_feat = h[..., 1:] 114 | 115 | # color 116 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] 117 | d = self.encoder_dir(d) 118 | 119 | # p = torch.zeros_like(geo_feat[..., :1]) # manual input padding 120 | h = torch.cat([d, geo_feat], dim=-1) 121 | h = self.color_net(h) 122 | 123 | # sigmoid activation for rgb 124 | color = torch.sigmoid(h) 125 | semantics = self.semantics_net(geo_feat) 126 | semantics = F.softmax(semantics, dim=-1) 127 | 128 | return sigma, color, semantics 129 | 130 | def density(self, x): 131 | # x: [N, 3], in [-bound, bound] 132 | 133 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 134 | x = self.encoder(x) 135 | h = self.sigma_net(x) 136 | 137 | # sigma = F.relu(h[..., 0]) 138 | sigma = trunc_exp(h[..., 0]) 139 | geo_feat = h[..., 1:] 140 | 141 | return { 142 | "sigma": sigma, 143 | "geo_feat": geo_feat, 144 | } 145 | 146 | # allow masked inference 147 | def color(self, x, d, mask=None, geo_feat=None, **kwargs): 148 | # x: [N, 3] in [-bound, bound] 149 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 150 | 151 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 152 | 153 | if mask is not None: 154 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, 155 | device=x.device) # [N, 3] 156 | # in case of empty mask 157 | if not mask.any(): 158 | return rgbs 159 | x = x[mask] 160 | d = d[mask] 161 | geo_feat = geo_feat[mask] 162 | 163 | # color 164 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] 165 | d = self.encoder_dir(d) 166 | 167 | h = torch.cat([d, geo_feat], dim=-1) 168 | h = self.color_net(h) 169 | 170 | # sigmoid activation for rgb 171 | h = torch.sigmoid(h) 172 | 173 | if mask is not None: 174 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 175 | else: 176 | rgbs = h 177 | 178 | return rgbs 179 | 180 | def semantics(self, x, d, mask=None, geo_feat=None, **kwargs): 181 | # x: [N, 3] in [-bound, bound] 182 | # mask: [N,], bool, indicates where we actually needs to compute rgb. 
183 | 184 | x = (x + self.bound) / (2 * self.bound) # to [0, 1] 185 | 186 | if mask is not None: 187 | semantics = torch.zeros( 188 | mask.shape[0], 189 | self.num_semantic_classes, 190 | dtype=x.dtype, 191 | device=x.device, 192 | ) # [N, 3] 193 | # in case of empty mask 194 | if not mask.any(): 195 | return semantics 196 | x = x[mask] 197 | geo_feat = geo_feat[mask] 198 | 199 | h = self.semantics_net(geo_feat) 200 | 201 | if mask is not None: 202 | semantics[mask] = h.to(semantics.dtype) # fp16 --> fp32 203 | semantics[mask] = F.softmax(semantics[mask], dim=-1) 204 | else: 205 | semantics = h.to(semantics.dtype) 206 | semantics = F.softmax(semantics, dim=-1) 207 | return semantics 208 | -------------------------------------------------------------------------------- /nr4seg/nerf/raymarching/__init__.py: -------------------------------------------------------------------------------- 1 | from .raymarching import * 2 | -------------------------------------------------------------------------------- /nr4seg/nerf/raymarching/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.cpp_extension import load 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | "-O3", 9 | "-std=c++14", 10 | "-U__CUDA_NO_HALF_OPERATORS__", 11 | "-U__CUDA_NO_HALF_CONVERSIONS__", 12 | "-U__CUDA_NO_HALF2_OPERATORS__", 13 | ] 14 | 15 | if os.name == "posix": 16 | c_flags = ["-O3", "-std=c++14"] 17 | elif os.name == "nt": 18 | c_flags = ["/O2", "/std:c++17"] 19 | 20 | # find cl.exe 21 | def find_cl_path(): 22 | import glob 23 | 24 | for edition in [ 25 | "Enterprise", "Professional", "BuildTools", "Community" 26 | ]: 27 | paths = sorted( 28 | glob.glob( 29 | r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" 30 | % edition), 31 | reverse=True, 32 | ) 33 | if paths: 34 | return paths[0] 35 | 36 | # If cl.exe is not on path, try to find it. 
37 | if os.system("where cl.exe >nul 2>nul") != 0: 38 | cl_path = find_cl_path() 39 | if cl_path is None: 40 | raise RuntimeError( 41 | "Could not locate a supported Microsoft Visual C++ installation" 42 | ) 43 | os.environ["PATH"] += ";" + cl_path 44 | 45 | _backend = load( 46 | name="_raymarching", 47 | extra_cflags=c_flags, 48 | extra_cuda_cflags=nvcc_flags, 49 | sources=[ 50 | os.path.join(_src_path, "src", f) for f in [ 51 | "raymarching.cu", 52 | "bindings.cpp", 53 | ] 54 | ], 55 | ) 56 | 57 | __all__ = ["_backend"] 58 | -------------------------------------------------------------------------------- /nr4seg/nerf/raymarching/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "raymarching.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | // utils 7 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); 8 | // train 9 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); 10 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); 11 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); 12 | // m.def("composite_rays_train_semantics_forward", &composite_rays_train_semantics_forward, "composite_rays_train_forward (CUDA)"); 13 | // m.def("composite_rays_train_semantics_backward", &composite_rays_train_semantics_backward, "composite_rays_train_backward (CUDA)"); 14 | // infer 15 | m.def("march_rays", &march_rays, "march rays (CUDA)"); 16 | // m.def("composite_rays_semantics", &composite_rays_semantics, "composite rays (CUDA)"); 17 | m.def("compact_rays", &compact_rays, "compact rays (CUDA)"); 18 | } -------------------------------------------------------------------------------- /nr4seg/nerf/raymarching/src/pcg32.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Tiny self-contained version of the PCG Random Number Generation for C++ 3 | * put together from pieces of the much larger C/C++ codebase. 4 | * Wenzel Jakob, February 2015 5 | * 6 | * The PCG random number generator was developed by Melissa O'Neill 7 | * 8 | * 9 | * Licensed under the Apache License, Version 2.0 (the "License"); 10 | * you may not use this file except in compliance with the License. 11 | * You may obtain a copy of the License at 12 | * 13 | * http://www.apache.org/licenses/LICENSE-2.0 14 | * 15 | * Unless required by applicable law or agreed to in writing, software 16 | * distributed under the License is distributed on an "AS IS" BASIS, 17 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | * See the License for the specific language governing permissions and 19 | * limitations under the License. 20 | * 21 | * For additional information about the PCG random number generation scheme, 22 | * including its license and other licensing options, visit 23 | * 24 | * http://www.pcg-random.org 25 | * 26 | * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors. 
27 | */ 28 | 29 | #pragma once 30 | 31 | #define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL 32 | #define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL 33 | #define PCG32_MULT 0x5851f42d4c957f2dULL 34 | 35 | #include 36 | #include 37 | #include 38 | 39 | #include 40 | #include 41 | #include 42 | 43 | /// PCG32 Pseudorandom number generator 44 | struct pcg32 { 45 | /// Initialize the pseudorandom number generator with default seed 46 | __host__ __device__ pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {} 47 | 48 | /// Initialize the pseudorandom number generator with the \ref seed() function 49 | __host__ __device__ pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); } 50 | 51 | /** 52 | * \brief Seed the pseudorandom number generator 53 | * 54 | * Specified in two parts: a state initializer and a sequence selection 55 | * constant (a.k.a. stream id) 56 | */ 57 | __host__ __device__ void seed(uint64_t initstate, uint64_t initseq = 1) { 58 | state = 0U; 59 | inc = (initseq << 1u) | 1u; 60 | next_uint(); 61 | state += initstate; 62 | next_uint(); 63 | } 64 | 65 | /// Generate a uniformly distributed unsigned 32-bit random number 66 | __host__ __device__ uint32_t next_uint() { 67 | uint64_t oldstate = state; 68 | state = oldstate * PCG32_MULT + inc; 69 | uint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u); 70 | uint32_t rot = (uint32_t) (oldstate >> 59u); 71 | return (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31)); 72 | } 73 | 74 | /// Generate a uniformly distributed number, r, where 0 <= r < bound 75 | __host__ __device__ uint32_t next_uint(uint32_t bound) { 76 | // To avoid bias, we need to make the range of the RNG a multiple of 77 | // bound, which we do by dropping output less than a threshold. 78 | // A naive scheme to calculate the threshold would be to do 79 | // 80 | // uint32_t threshold = 0x100000000ull % bound; 81 | // 82 | // but 64-bit div/mod is slower than 32-bit div/mod (especially on 83 | // 32-bit platforms). In essence, we do 84 | // 85 | // uint32_t threshold = (0x100000000ull-bound) % bound; 86 | // 87 | // because this version will calculate the same modulus, but the LHS 88 | // value is less than 2^32. 89 | 90 | uint32_t threshold = (~bound+1u) % bound; 91 | 92 | // Uniformity guarantees that this loop will terminate. In practice, it 93 | // should usually terminate quickly; on average (assuming all bounds are 94 | // equally likely), 82.25% of the time, we can expect it to require just 95 | // one iteration. In the worst case, someone passes a bound of 2^31 + 1 96 | // (i.e., 2147483649), which invalidates almost 50% of the range. In 97 | // practice, bounds are typically small and only a tiny amount of the range 98 | // is eliminated. 99 | for (;;) { 100 | uint32_t r = next_uint(); 101 | if (r >= threshold) 102 | return r % bound; 103 | } 104 | } 105 | 106 | /// Generate a single precision floating point value on the interval [0, 1) 107 | __host__ __device__ float next_float() { 108 | /* Trick from MTGP: generate an uniformly distributed 109 | single precision number in [1,2) and subtract 1. 
*/ 110 | union { 111 | uint32_t u; 112 | float f; 113 | } x; 114 | x.u = (next_uint() >> 9) | 0x3f800000u; 115 | return x.f - 1.0f; 116 | } 117 | 118 | /** 119 | * \brief Generate a double precision floating point value on the interval [0, 1) 120 | * 121 | * \remark Since the underlying random number generator produces 32 bit output, 122 | * only the first 32 mantissa bits will be filled (however, the resolution is still 123 | * finer than in \ref next_float(), which only uses 23 mantissa bits) 124 | */ 125 | __host__ __device__ double next_double() { 126 | /* Trick from MTGP: generate an uniformly distributed 127 | double precision number in [1,2) and subtract 1. */ 128 | union { 129 | uint64_t u; 130 | double d; 131 | } x; 132 | x.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL; 133 | return x.d - 1.0; 134 | } 135 | 136 | /** 137 | * \brief Multi-step advance function (jump-ahead, jump-back) 138 | * 139 | * The method used here is based on Brown, "Random Number Generation 140 | * with Arbitrary Stride", Transactions of the American Nuclear 141 | * Society (Nov. 1994). The algorithm is very similar to fast 142 | * exponentiation. 143 | * 144 | * The default value of 2^32 ensures that the PRNG is advanced 145 | * sufficiently far that there is (likely) no overlap with 146 | * previously drawn random numbers, even if small advancements. 147 | * are made inbetween. 148 | */ 149 | __host__ __device__ void advance(int64_t delta_ = (1ll<<32)) { 150 | uint64_t 151 | cur_mult = PCG32_MULT, 152 | cur_plus = inc, 153 | acc_mult = 1u, 154 | acc_plus = 0u; 155 | 156 | /* Even though delta is an unsigned integer, we can pass a signed 157 | integer to go backwards, it just goes "the long way round". */ 158 | uint64_t delta = (uint64_t) delta_; 159 | 160 | while (delta > 0) { 161 | if (delta & 1) { 162 | acc_mult *= cur_mult; 163 | acc_plus = acc_plus * cur_mult + cur_plus; 164 | } 165 | cur_plus = (cur_mult + 1) * cur_plus; 166 | cur_mult *= cur_mult; 167 | delta /= 2; 168 | } 169 | state = acc_mult * state + acc_plus; 170 | } 171 | 172 | /// Compute the distance between two PCG32 pseudorandom number generators 173 | __host__ __device__ int64_t operator-(const pcg32 &other) const { 174 | assert(inc == other.inc); 175 | 176 | uint64_t 177 | cur_mult = PCG32_MULT, 178 | cur_plus = inc, 179 | cur_state = other.state, 180 | the_bit = 1u, 181 | distance = 0u; 182 | 183 | while (state != cur_state) { 184 | if ((state & the_bit) != (cur_state & the_bit)) { 185 | cur_state = cur_state * cur_mult + cur_plus; 186 | distance |= the_bit; 187 | } 188 | assert((state & the_bit) == (cur_state & the_bit)); 189 | the_bit <<= 1; 190 | cur_plus = (cur_mult + 1ULL) * cur_plus; 191 | cur_mult *= cur_mult; 192 | } 193 | 194 | return (int64_t) distance; 195 | } 196 | 197 | /// Equality operator 198 | __host__ __device__ bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; } 199 | 200 | /// Inequality operator 201 | __host__ __device__ bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; } 202 | 203 | uint64_t state; // RNG state. All values are possible. 204 | uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd. 
205 | }; -------------------------------------------------------------------------------- /nr4seg/nerf/raymarching/src/raymarching.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | 7 | void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); 8 | 9 | void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const float bound, const float dt_gamma, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb); 10 | void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 11 | void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs); 12 | // void composite_rays_train_semantics_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor local_semantics, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image, at::Tensor semantics); 13 | // void composite_rays_train_semantics_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor grad_semantics, at::Tensor sigmas, at::Tensor rgbs, at::Tensor local_semantics, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, at::Tensor semantics, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_local_semantics); 14 | 15 | void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t C, const uint32_t H, at::Tensor density_grid, const float mean_density, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb); 16 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); 17 | // void composite_rays_semantics(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor local_semantics, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image, at::Tensor semantics); 18 | void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter); -------------------------------------------------------------------------------- /nr4seg/nerf/renderer_semantics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import trimesh 6 | 7 | from .raymarching import raymarching 8 | 9 | 10 | def sample_pdf(bins, weights, n_samples, det=False): 11 | # This implementation is from NeRF 12 | # bins: [B, T], old_z_vals 13 | # 
weights: [B, T - 1], bin weights. 14 | # return: [B, n_samples], new_z_vals 15 | 16 | # Get pdf 17 | weights = weights + 1e-5 # prevent nans 18 | pdf = weights / torch.sum(weights, -1, keepdim=True) 19 | cdf = torch.cumsum(pdf, -1) 20 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 21 | # Take uniform samples 22 | if det: 23 | u = torch.linspace(0.0 + 0.5 / n_samples, 24 | 1.0 - 0.5 / n_samples, 25 | steps=n_samples).to(weights.device) 26 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 27 | else: 28 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) 29 | 30 | # Invert CDF 31 | u = u.contiguous() 32 | inds = torch.searchsorted(cdf, u, right=True) 33 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 34 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 35 | inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) 36 | 37 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 38 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 39 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 40 | 41 | denom = cdf_g[..., 1] - cdf_g[..., 0] 42 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 43 | t = (u - cdf_g[..., 0]) / denom 44 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 45 | 46 | return samples 47 | 48 | 49 | def plot_pointcloud(pc, color=None): 50 | # pc: [N, 3] 51 | # color: [N, 3/4] 52 | print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0)) 53 | pc = trimesh.PointCloud(pc, color) 54 | # axis 55 | axes = trimesh.creation.axis(axis_length=4) 56 | # sphere 57 | sphere = trimesh.creation.icosphere(radius=1) 58 | trimesh.Scene([pc, axes, sphere]).show() 59 | 60 | 61 | class SemanticNeRFRenderer(nn.Module): 62 | 63 | def __init__( 64 | self, 65 | bound=1, 66 | cuda_ray=False, 67 | density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. 68 | num_semantic_classes=41, 69 | ): 70 | super().__init__() 71 | 72 | self.epoch = 1 73 | self.weights = np.zeros([0]) 74 | self.weights_sum = np.zeros([0]) 75 | 76 | self.bound = bound 77 | self.cascade = 1 + math.ceil(math.log2(bound)) 78 | self.density_scale = density_scale 79 | self.num_semantic_classes = num_semantic_classes 80 | 81 | # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) 82 | # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. 83 | aabb_train = torch.FloatTensor( 84 | [-bound, -bound, -bound, bound, bound, bound]) 85 | aabb_infer = aabb_train.clone() 86 | self.register_buffer("aabb_train", aabb_train) 87 | self.register_buffer("aabb_infer", aabb_infer) 88 | 89 | # extra state for cuda raymarching 90 | self.cuda_ray = cuda_ray 91 | if cuda_ray: 92 | # density grid 93 | density_grid = torch.zeros([self.cascade] + 94 | [128] * 3) # [CAS, H, H, H] 95 | self.register_buffer("density_grid", density_grid) 96 | self.mean_density = 0 97 | self.iter_density = 0 98 | # step counter 99 | step_counter = torch.zeros( 100 | 16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... 
101 | self.register_buffer("step_counter", step_counter) 102 | self.mean_count = 0 103 | self.local_step = 0 104 | 105 | def forward(self, x, d): 106 | raise NotImplementedError() 107 | 108 | def density(self, x): 109 | raise NotImplementedError() 110 | 111 | def reset_extra_state(self): 112 | if not self.cuda_ray: 113 | return 114 | # density grid 115 | self.density_grid.zero_() 116 | self.mean_density = 0 117 | self.iter_density = 0 118 | # step counter 119 | self.step_counter.zero_() 120 | self.mean_count = 0 121 | self.local_step = 0 122 | 123 | def run(self, 124 | rays_o, 125 | rays_d, 126 | direction_norms, 127 | num_steps=256, 128 | upsample_steps=256, 129 | bg_color=None, 130 | perturb=False, 131 | epoch=None, 132 | **kwargs): 133 | # rays_o, rays_d: [B, N, 3], assumes B == 1 134 | # direction_norms: [B, N, 1] 135 | # bg_color: [3] in range [0, 1] 136 | # return: image: [B, N, 3], depth: [B, N] 137 | 138 | prefix = rays_o.shape[:-1] 139 | rays_o = rays_o.contiguous().view(-1, 3) 140 | rays_d = rays_d.contiguous().view(-1, 3) 141 | direction_norms = direction_norms.contiguous().view(-1) 142 | 143 | N = rays_o.shape[0] # N = B * N, in fact 144 | device = rays_o.device 145 | 146 | # choose aabb 147 | aabb = self.aabb_train if self.training else self.aabb_infer 148 | 149 | # sample steps 150 | nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb) 151 | nears.unsqueeze_(-1) 152 | fars.unsqueeze_(-1) 153 | 154 | z_vals = torch.linspace(0.0, 1.0, num_steps, 155 | device=device).unsqueeze(0) # [1, T] 156 | z_vals = z_vals.expand((N, num_steps)) # [N, T] 157 | z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] 158 | 159 | # perturb z_vals 160 | sample_dist = (fars - nears) / num_steps 161 | if perturb: 162 | mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) 163 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 164 | lower = torch.cat([z_vals[..., :1], mids], -1) 165 | # stratified samples in those intervals 166 | t_rand = torch.rand(z_vals.shape, device=device) 167 | 168 | z_vals = lower + (upper - lower) * t_rand 169 | 170 | # generate xyzs 171 | xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze( 172 | -1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] 173 | xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
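The coarse sampling block above is easier to follow in isolation. A minimal standalone sketch of the same stratified-sampling idea with toy values (illustrative only, not part of the repository code):

    import torch

    # Two toy rays, eight samples each.
    N, num_steps = 2, 8
    nears = torch.tensor([[0.1], [0.2]])  # [N, 1]
    fars = torch.tensor([[1.0], [1.5]])   # [N, 1]

    # Evenly spaced depths between near and far for every ray.
    z_vals = torch.linspace(0.0, 1.0, num_steps).unsqueeze(0)  # [1, T]
    z_vals = nears + (fars - nears) * z_vals                   # [N, T]

    # Stratified perturbation: jitter each depth inside its own interval.
    mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
    upper = torch.cat([mids, z_vals[..., -1:]], -1)
    lower = torch.cat([z_vals[..., :1], mids], -1)
    z_vals = lower + (upper - lower) * torch.rand(z_vals.shape)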
174 | 175 | # query density and RGB 176 | density_outputs = self.density(xyzs.reshape(-1, 3)) 177 | 178 | for k, v in density_outputs.items(): 179 | density_outputs[k] = v.view(N, num_steps, -1) 180 | 181 | # upsample z_vals (nerf-like) 182 | if upsample_steps > 0: 183 | with torch.no_grad(): 184 | 185 | deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] 186 | deltas = torch.cat( 187 | [deltas, 1e10 * torch.ones_like(deltas[..., :1])], dim=-1) 188 | 189 | alphas = 1 - torch.exp( 190 | -deltas * self.density_scale * 191 | density_outputs["sigma"].squeeze(-1)) # [N, T] 192 | alphas_shifted = torch.cat( 193 | [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], 194 | dim=-1, 195 | ) # [N, T+1] 196 | weights = (alphas * 197 | torch.cumprod(alphas_shifted, dim=-1)[..., :-1] 198 | ) # [N, T] 199 | 200 | # sample new z_vals 201 | z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1] 202 | ) # [N, T-1] 203 | new_z_vals = sample_pdf(z_vals_mid, 204 | weights[:, 1:-1], 205 | upsample_steps, 206 | det=False).detach() # [N, t] 207 | 208 | new_xyzs = rays_o.unsqueeze( 209 | -2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze( 210 | -1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] 211 | new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), 212 | aabb[3:]) # a manual clip. 213 | 214 | # only forward new points to save computation 215 | new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) 216 | # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] 217 | for k, v in new_density_outputs.items(): 218 | new_density_outputs[k] = v.view(N, upsample_steps, -1) 219 | 220 | # re-order 221 | z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] 222 | z_vals, z_index = torch.sort(z_vals, dim=1) 223 | 224 | xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] 225 | xyzs = torch.gather(xyzs, 226 | dim=1, 227 | index=z_index.unsqueeze(-1).expand_as(xyzs)) 228 | 229 | for k in density_outputs: 230 | tmp_output = torch.cat( 231 | [density_outputs[k], new_density_outputs[k]], dim=1) 232 | density_outputs[k] = torch.gather( 233 | tmp_output, 234 | dim=1, 235 | index=z_index.unsqueeze(-1).expand_as(tmp_output), 236 | ) 237 | 238 | deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] 239 | deltas = torch.cat([deltas, 1e10 * torch.ones_like(deltas[..., :1])], 240 | dim=-1) 241 | alphas = 1 - torch.exp(-deltas * self.density_scale * 242 | density_outputs["sigma"].squeeze(-1)) # [N, T+t] 243 | alphas_shifted = torch.cat( 244 | [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], 245 | dim=-1) # [N, T+t+1] 246 | weights = (alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] 247 | ) # [N, T+t] 248 | 249 | mask_rgb = weights > 1e-4 # hard coded 250 | mask_semantics = weights > 1e-4 # hard coded 251 | 252 | dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) 253 | for k, v in density_outputs.items(): 254 | density_outputs[k] = v.view(-1, v.shape[-1]) 255 | 256 | rgbs = self.color(xyzs.reshape(-1, 3), 257 | dirs.reshape(-1, 3), 258 | mask=mask_rgb.reshape(-1), 259 | **density_outputs) 260 | rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] 261 | 262 | local_semantics = self.semantics(xyzs.reshape(-1, 3), 263 | dirs.reshape(-1, 3), 264 | mask=mask_semantics.reshape(-1), 265 | **density_outputs) 266 | local_semantics = local_semantics.view( 267 | N, -1, self.num_semantic_classes) # [N, T+t, 3] 268 | 269 | # calculate weight_sum (mask) 270 | weights_semantics = weights.clone().detach() 271 | weights[torch.logical_not(mask_rgb)] = 0.0 272 | 273 | # calculate depth 274 | ori_z_vals = ((z_vals - nears) / 
(fars - nears)).clamp(0, 1) 275 | # depth = torch.sum(weights * ori_z_vals, dim=-1) 276 | depth = torch.sum(weights * z_vals, dim=-1) 277 | depth = depth / direction_norms 278 | 279 | # calculate color 280 | 281 | image = torch.sum(weights.unsqueeze(-1) * rgbs, 282 | dim=-2) # [N, 3], in [0, 1] 283 | weights_semantics[torch.logical_not(mask_semantics)] = 0.0 284 | semantics = torch.sum(weights_semantics.unsqueeze(-1) * local_semantics, 285 | dim=-2) # [N, C], in [0, 1] 286 | 287 | # mix background color 288 | if bg_color is None: 289 | bg_color = 1 290 | 291 | image = image.view(*prefix, 3) 292 | depth = depth.view(*prefix) 293 | semantics = semantics.view(*prefix, self.num_semantic_classes) 294 | 295 | return { 296 | "depth": depth, 297 | "image": image, 298 | "semantics": semantics, 299 | } 300 | 301 | def render(self, 302 | rays_o, 303 | rays_d, 304 | direction_norms, 305 | staged=False, 306 | max_ray_batch=4096, 307 | bg_color=None, 308 | perturb=False, 309 | epoch=None, 310 | **kwargs): 311 | # rays_o, rays_d: [B, N, 3], assumes B == 1 312 | # direction_norms: [B, N, 1] 313 | # return: pred_rgb: [B, N, 3] 314 | 315 | _run = self.run 316 | 317 | B, N = rays_o.shape[:2] 318 | device = rays_o.device 319 | 320 | # never stage when cuda_ray 321 | if staged and not self.cuda_ray: 322 | depth = torch.empty((B, N), device=device) 323 | image = torch.empty((B, N, 3), device=device) 324 | semantics = torch.empty((B, N, self.num_semantic_classes), 325 | device=device) 326 | 327 | for b in range(B): 328 | head = 0 329 | while head < N: 330 | tail = min(head + max_ray_batch, N) 331 | results_ = _run(rays_o[b:b + 1, head:tail], 332 | rays_d[b:b + 1, head:tail], 333 | direction_norms=direction_norms[b:b + 1, 334 | head:tail], 335 | bg_color=bg_color, 336 | perturb=perturb, 337 | epoch=epoch, 338 | **kwargs) 339 | depth[b:b + 1, head:tail] = results_["depth"] 340 | image[b:b + 1, head:tail] = results_["image"] 341 | semantics[b:b + 1, head:tail] = results_["semantics"] 342 | head += max_ray_batch 343 | 344 | results = {} 345 | results["depth"] = depth 346 | results["image"] = image 347 | results["semantics"] = semantics 348 | 349 | else: 350 | results = _run(rays_o, 351 | rays_d, 352 | direction_norms=direction_norms, 353 | bg_color=bg_color, 354 | perturb=perturb, 355 | epoch=epoch, 356 | **kwargs) 357 | 358 | return results 359 | -------------------------------------------------------------------------------- /nr4seg/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplabv3 import * 2 | -------------------------------------------------------------------------------- /nr4seg/network/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | from torch import nn 4 | 5 | 6 | class DeepLabV3(nn.Module): 7 | 8 | def __init__(self, cfg_model): 9 | super().__init__() 10 | self._model = torchvision.models.segmentation.deeplabv3_resnet101( 11 | pretrained=cfg_model["pretrained"], 12 | pretrained_backbone=cfg_model["pretrained_backbone"], 13 | progress=True, 14 | num_classes=cfg_model["num_classes"], 15 | aux_loss=None, 16 | ) 17 | 18 | def forward(self, data): 19 | return self._model(data) 20 | -------------------------------------------------------------------------------- /nr4seg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flatten_dict import * 2 | from .get_logger import * 3 | from .loading import * 4 | 
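A minimal usage sketch for the DeepLabV3 wrapper above. The config keys mirror the ones read in its constructor; the batch and image sizes are arbitrary, and the `"out"` key is what torchvision's segmentation models return:

    import torch

    from nr4seg.network import DeepLabV3

    cfg_model = {
        "pretrained": False,
        "pretrained_backbone": True,
        "num_classes": 41,  # NYU40 classes + "unlabeled"
    }
    model = DeepLabV3(cfg_model).eval()

    images = torch.rand(2, 3, 240, 320)  # [B, 3, H, W], values in [0, 1]
    with torch.no_grad():
        logits = model(images)["out"]    # [B, num_classes, H, W]
    pred = logits.argmax(dim=1)          # [B, H, W] per-pixel class ids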
-------------------------------------------------------------------------------- /nr4seg/utils/flatten_dict.py: --------------------------------------------------------------------------------
1 | import collections
2 |
3 | __all__ = ["flatten_dict"]
4 |
5 |
6 | def flatten_dict(d, parent_key="", sep="_"):
7 |     items = []
8 |     for k, v in d.items():
9 |         new_key = parent_key + sep + k if parent_key else k
10 |         if isinstance(v, collections.abc.MutableMapping):
11 |             items.extend(flatten_dict(v, new_key, sep=sep).items())
12 |         else:
13 |             if isinstance(v, list):
14 |                 if isinstance(v[0], dict):
15 |                     items.extend(flatten_list(v, new_key, sep=sep))
16 |                     continue
17 |             items.append((new_key, v))
18 |     return dict(items)
19 |
20 |
21 | def flatten_list(lst, parent_key="", sep="_"):
22 |     # Minimal helper (assumed implementation) so the call above resolves:
23 |     # flattens a list of dicts, using the element index as part of the key.
24 |     items = []
25 |     for idx, element in enumerate(lst):
26 |         new_key = parent_key + sep + str(idx) if parent_key else str(idx)
27 |         items.extend(flatten_dict(element, new_key, sep=sep).items())
28 |     return items
29 |
-------------------------------------------------------------------------------- /nr4seg/utils/get_logger.py: --------------------------------------------------------------------------------
1 | import os
2 |
3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
4 | from pytorch_lightning.loggers.neptune import NeptuneLogger
5 |
6 | from nr4seg.utils import flatten_dict
7 |
8 | __all__ = ["get_neptune_logger", "get_tensorboard_logger", "get_wandb_logger"]
9 |
10 |
11 | def log_important_params(exp):
12 |     dic = {}
13 |     dic = flatten_dict(exp)
14 |     return dic
15 |
16 |
17 | def get_neptune_logger(exp, env, exp_p, env_p, project_name=""):
18 |     params = log_important_params(exp)
19 |
20 |     name_full = exp["general"]["name"]
21 |     name_short = "__".join(name_full.split("/")[-2:])
22 |
23 |     return NeptuneLogger(
24 |         api_key=os.environ["NEPTUNE_API_TOKEN"],
25 |         project=project_name,
26 |         name=name_short,
27 |         tags=[
28 |             os.environ["ENV_WORKSTATION_NAME"],
29 |             name_full.split("/")[-2],
30 |             name_full.split("/")[-1],
31 |         ],
32 |     )
33 |
34 |
35 | def get_wandb_logger(exp, env, exp_p, env_p, project_name, save_dir):
36 |     params = log_important_params(exp)
37 |     name_full = exp["general"]["name"]
38 |     name_short = "__".join(name_full.split("/")[-2:])
39 |     return WandbLogger(
40 |         name=name_short,
41 |         project=project_name,
42 |         save_dir=save_dir,
43 |     )
44 |
45 |
46 | def get_tensorboard_logger(exp, env, exp_p, env_p):
47 |     params = log_important_params(exp)
48 |     return TensorBoardLogger(
49 |         save_dir=exp["general"]["name"],
50 |         name="tensorboard",
51 |         default_hp_metric=params,
52 |     )
53 |
-------------------------------------------------------------------------------- /nr4seg/utils/loading.py: --------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | __all__ = ["file_path", "load_yaml"]
5 |
6 |
7 | def file_path(string):
8 |     if os.path.isfile(string):
9 |         return string
10 |     else:
11 |         raise NotADirectoryError(string)
12 |
13 |
14 | def load_yaml(path):
15 |     with open(path) as file:
16 |         res = yaml.load(file, Loader=yaml.FullLoader)
17 |     return res
18 |
-------------------------------------------------------------------------------- /nr4seg/utils/metrics.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from sklearn.metrics import confusion_matrix
5 |
6 |
7 | def nanmean(data, **args):
8 |     # Mean over the array that ignores NaN entries (e.g. classes that never appear in the ground truth).
9 |     return np.ma.masked_array(data, np.isnan(data)).mean(**args)
10 |     # In np.ma.masked_array(data, np.isnan(data)), entries where data is NaN are marked invalid and are ignored by .mean().
11 |
12 |
13 | class SemanticsMeter:
14 |
15 |     def __init__(self, number_classes):
16 |
self.conf_mat = None 17 | self.number_classes = number_classes 18 | 19 | def clear(self): 20 | self.conf_mat = None 21 | 22 | def prepare_inputs(self, *inputs): 23 | outputs = [] 24 | for i, inp in enumerate(inputs): 25 | if torch.is_tensor(inp): 26 | inp = inp.detach().cpu().numpy() 27 | outputs.append(inp) 28 | 29 | return outputs 30 | 31 | def update(self, preds, truths): 32 | preds, truths = self.prepare_inputs( 33 | preds, truths) # [B, N, 3] or [B, H, W, 3], range[0, 1] 34 | preds = preds.flatten() 35 | truths = truths.flatten() 36 | valid_pix_ids = truths != -1 37 | preds = preds[valid_pix_ids] 38 | truths = truths[valid_pix_ids] 39 | conf_mat_current = confusion_matrix(truths, 40 | preds, 41 | labels=list( 42 | range(self.number_classes))) 43 | if self.conf_mat is None: 44 | self.conf_mat = conf_mat_current 45 | else: 46 | self.conf_mat += conf_mat_current 47 | 48 | def measure(self): 49 | conf_mat = self.conf_mat 50 | norm_conf_mat = np.transpose( 51 | np.transpose(conf_mat) / conf_mat.astype(np.float).sum(axis=1)) 52 | 53 | missing_class_mask = np.isnan(norm_conf_mat.sum( 54 | 1)) # missing class will have NaN at corresponding class 55 | exsiting_class_mask = ~missing_class_mask 56 | 57 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat)) 58 | total_accuracy = np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat) 59 | ious = np.zeros(self.number_classes) 60 | for class_id in range(self.number_classes): 61 | ious[class_id] = conf_mat[class_id, class_id] / ( 62 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) - 63 | conf_mat[class_id, class_id]) 64 | miou_valid_class = np.mean(ious[exsiting_class_mask]) 65 | return miou_valid_class, total_accuracy, class_average_accuracy 66 | -------------------------------------------------------------------------------- /nr4seg/visualizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .colormaps import * 2 | from .visualizer import * 3 | -------------------------------------------------------------------------------- /nr4seg/visualizer/colormaps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from collections import OrderedDict 4 | from matplotlib import cm 5 | 6 | ORDERED_DICT = OrderedDict([ 7 | ("unlabeled", (0, 0, 0)), 8 | ("wall", (174, 199, 232)), 9 | ("floor", (152, 223, 138)), 10 | ("cabinet", (31, 119, 180)), 11 | ("bed", (255, 187, 120)), 12 | ("chair", (188, 189, 34)), 13 | ("sofa", (140, 86, 75)), 14 | ("table", (255, 152, 150)), 15 | ("door", (214, 39, 40)), 16 | ("window", (197, 176, 213)), 17 | ("bookshelf", (148, 103, 189)), 18 | ("picture", (196, 156, 148)), 19 | ("counter", (23, 190, 207)), 20 | ("blinds", (178, 76, 76)), 21 | ("desk", (247, 182, 210)), 22 | ("shelves", (66, 188, 102)), 23 | ("curtain", (219, 219, 141)), 24 | ("dresser", (140, 57, 197)), 25 | ("pillow", (202, 185, 52)), 26 | ("mirror", (51, 176, 203)), 27 | ("floormat", (200, 54, 131)), 28 | ("clothes", (92, 193, 61)), 29 | ("ceiling", (78, 71, 183)), 30 | ("books", (172, 114, 82)), 31 | ("refrigerator", (255, 127, 14)), 32 | ("television", (91, 163, 138)), 33 | ("paper", (153, 98, 156)), 34 | ("towel", (140, 153, 101)), 35 | ("showercurtain", (158, 218, 229)), 36 | ("box", (100, 125, 154)), 37 | ("whiteboard", (178, 127, 135)), 38 | ("person", (120, 185, 128)), 39 | ("nightstand", (146, 111, 194)), 40 | ("toilet", (44, 160, 44)), 41 | ("sink", (112, 128, 144)), 42 | ("lamp", (96, 207, 209)), 43 | ("bathtub", (227, 119, 
194)), 44 | ("bag", (213, 92, 176)), 45 | ("otherstructure", (94, 106, 211)), 46 | ("otherfurniture", (82, 84, 163)), 47 | ("otherprop", (100, 85, 144)), 48 | ]) 49 | SCANNET_CLASSES = [i for i, v in enumerate(ORDERED_DICT.values())] 50 | SCANNET_COLORS = [v for i, v in enumerate(ORDERED_DICT.values())] 51 | 52 | jet = cm.get_cmap("jet") 53 | BINARY_COLORS = (np.stack([jet(v) for v in np.linspace(0, 1, 2)]) * 255).astype( 54 | np.uint8) 55 | -------------------------------------------------------------------------------- /nr4seg/visualizer/visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import os 5 | import skimage 6 | import wandb 7 | 8 | from matplotlib.backends.backend_agg import FigureCanvasAgg 9 | from PIL import Image, ImageDraw 10 | 11 | from nr4seg.visualizer import BINARY_COLORS, SCANNET_CLASSES, SCANNET_COLORS 12 | 13 | __all__ = ["Visualizer"] 14 | 15 | 16 | def get_img_from_fig(fig, dpi=180): 17 | """ 18 | Converts matplot figure to np.array 19 | """ 20 | 21 | fig.set_dpi(dpi) 22 | canvas = FigureCanvasAgg(fig) 23 | # Retrieve a view on the renderer buffer 24 | canvas.draw() 25 | buf = canvas.buffer_rgba() 26 | # convert to a np array 27 | buf = np.asarray(buf) 28 | buf = Image.fromarray(buf) 29 | buf = buf.convert("RGB") 30 | return buf 31 | 32 | 33 | def image_functionality(func): 34 | """ 35 | Decorator to allow for logging functionality. 36 | Should be added to all plotting functions of visualizer. 37 | The plot function has to return a np.uint8 38 | 39 | @image_functionality 40 | def plot_segmentation(self, seg, **kwargs): 41 | 42 | return np.zeros((H, W, 3), dtype=np.uint8) 43 | 44 | not_log [optional, bool, default: false] : the decorator is ignored if set to true 45 | epoch [optional, int, default: visualizer.epoch ] : overwrites visualizer, epoch used to log the image 46 | store [optional, bool, default: visualizer.store ] : overwrites visualizer, flag if image is stored to disk 47 | tag [optinal, str, default: tag_not_defined ] : stores image as: visulaiter._p_visu{epoch}_{tag}.png 48 | """ 49 | 50 | def wrap(*args, **kwargs): 51 | img = func(*args, **kwargs) 52 | 53 | if not kwargs.get("not_log", False): 54 | log_exp = args[0]._pl_model.logger is not None 55 | tag = kwargs.get("tag", "tag_not_defined") 56 | 57 | if kwargs.get("store", None) is not None: 58 | store = kwargs["store"] 59 | else: 60 | store = args[0]._store 61 | 62 | if kwargs.get("epoch", None) is not None: 63 | epoch = kwargs["epoch"] 64 | else: 65 | epoch = args[0]._epoch 66 | 67 | # Store to disk 68 | if store: 69 | p = os.path.join(args[0]._p_visu, f"{tag}_epoch_{epoch}.png") 70 | imageio.imwrite(p, img) 71 | 72 | if log_exp: 73 | H, W, C = img.shape 74 | ds = cv2.resize( 75 | img, 76 | dsize=(int(W / 2), int(H / 2)), 77 | interpolation=cv2.INTER_CUBIC, 78 | ) 79 | if args[0]._pl_model.logger is not None: 80 | args[0]._pl_model.logger.experiment.log( 81 | {tag: [wandb.Image(ds, caption=tag)]}, commit=False) 82 | return func(*args, **kwargs) 83 | 84 | return wrap 85 | 86 | 87 | class Visualizer: 88 | 89 | def __init__(self, p_visu, store, pl_model, epoch=0, num_classes=22): 90 | self._p_visu = p_visu 91 | self._pl_model = pl_model 92 | self._epoch = epoch 93 | self._store = store 94 | 95 | os.makedirs(os.path.join(self._p_visu, "train_vis"), exist_ok=True) 96 | os.makedirs(os.path.join(self._p_visu, "val_vis"), exist_ok=True) 97 | os.makedirs(os.path.join(self._p_visu, "test_vis"), 
exist_ok=True) 98 | 99 | @property 100 | def epoch(self): 101 | return self._epoch 102 | 103 | @epoch.setter 104 | def epoch(self, epoch): 105 | self._epoch = epoch 106 | 107 | @property 108 | def store(self): 109 | return self._store 110 | 111 | @store.setter 112 | def store(self, store): 113 | self._store = store 114 | 115 | @image_functionality 116 | def plot_segmentation(self, seg, **kwargs): 117 | try: 118 | seg = seg.clone().cpu().numpy() 119 | except: 120 | pass 121 | 122 | if seg.dtype == np.bool: 123 | col_map = BINARY_COLORS 124 | else: 125 | col_map = SCANNET_COLORS 126 | 127 | H, W = seg.shape[:2] 128 | img = np.zeros((H, W, 3), dtype=np.uint8) 129 | for i, color in enumerate(col_map): 130 | img[seg == i] = color[:3] 131 | 132 | return img 133 | 134 | @image_functionality 135 | def plot_image(self, img, **kwargs): 136 | """ 137 | ---------- 138 | img : CHW HWC accepts torch.tensor or numpy.array 139 | Range 0-1 or 0-255 140 | """ 141 | try: 142 | img = img.clone().cpu().numpy() 143 | except: 144 | pass 145 | 146 | if img.shape[2] == 3: 147 | pass 148 | elif img.shape[0] == 3: 149 | img = np.moveaxis(img, [0, 1, 2], [2, 0, 1]) 150 | else: 151 | raise Exception("Invalid Shape") 152 | if img.max() <= 1: 153 | img = img * 255 154 | 155 | img = np.uint8(img) 156 | return img 157 | 158 | @image_functionality 159 | def plot_detectron( 160 | self, 161 | img, 162 | label, 163 | text_off=False, 164 | alpha=0.5, 165 | draw_bound=True, 166 | shift=2.5, 167 | font_size=12, 168 | **kwargs, 169 | ): 170 | """ 171 | ---------- 172 | img : CHW HWC accepts torch.tensor or numpy.array 173 | Range 0-1 or 0-255 174 | label: HW accepts torch.tensor or numpy.array 175 | """ 176 | 177 | img = self.plot_image(img, not_log=True) 178 | try: 179 | label = label.clone().cpu().numpy() 180 | except: 181 | pass 182 | label = label.astype(np.long) 183 | 184 | H, W, C = img.shape 185 | uni = np.unique(label) 186 | overlay = np.zeros_like(img) 187 | 188 | centers = [] 189 | for u in uni: 190 | m = label == u 191 | col = SCANNET_COLORS[u] 192 | overlay[m] = col 193 | labels_mask = skimage.measure.label(m) 194 | regions = skimage.measure.regionprops(labels_mask) 195 | regions.sort(key=lambda x: x.area, reverse=True) 196 | cen = np.mean(regions[0].coords, axis=0).astype(np.uint32)[::-1] 197 | 198 | centers.append((SCANNET_CLASSES[u], cen)) 199 | 200 | back = np.zeros((H, W, 4)) 201 | back[:, :, :3] = img 202 | back[:, :, 3] = 255 203 | fore = np.zeros((H, W, 4)) 204 | fore[:, :, :3] = overlay 205 | fore[:, :, 3] = alpha * 255 206 | img_new = Image.alpha_composite(Image.fromarray(np.uint8(back)), 207 | Image.fromarray(np.uint8(fore))) 208 | draw = ImageDraw.Draw(img_new) 209 | 210 | if not text_off: 211 | 212 | for i in centers: 213 | pose = i[1] 214 | pose[0] -= len(str(i[0])) * shift 215 | pose[1] -= font_size / 2 216 | draw.text(tuple(pose), str(i[0]), fill=(255, 255, 255, 128)) 217 | 218 | img_new = img_new.convert("RGB") 219 | mask = skimage.segmentation.mark_boundaries(np.array(img_new), 220 | label, 221 | color=(255, 255, 255)) 222 | mask = mask.sum(axis=2) 223 | m = mask == mask.max() 224 | img_new = np.array(img_new) 225 | if draw_bound: 226 | img_new[m] = (255, 255, 255) 227 | return np.uint8(img_new) 228 | -------------------------------------------------------------------------------- /preprocessing_scripts/scannet2nerf.py: -------------------------------------------------------------------------------- 1 | # Partly based on https://github.com/ashawkey/torch-ngp/blob/main/scripts/ 2 | # colmap2nerf.py. 
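The core of this script is a change of camera-pose conventions: each ScanNet camera-to-world matrix is re-centred, axis-flipped and permuted into the NeRF convention, rotated so that the average up vector becomes +z, and finally rescaled. A condensed standalone sketch of those per-pose steps as they appear further below (illustrative only; `c2w` is a 4x4 camera-to-world matrix, and `room_center`, `R` and `scale` stand in for the quantities the script computes; the optional re-centring on the computed centre of attention is omitted):

    import numpy as np

    def scannet_pose_to_nerf(c2w, room_center, R, scale):
        c2w = c2w.copy()
        c2w[:3, 3] -= room_center   # move the room centre to the origin
        c2w[0:3, 2] *= -1           # flip the camera z axis
        c2w[0:3, 1] *= -1           # flip the camera y axis
        c2w = c2w[[1, 0, 2, 3], :]  # swap the first two rows
        c2w[2, :] *= -1             # flip the whole world upside down
        c2w = R @ c2w               # rotate the average up vector onto +z
        c2w[0:3, 3] *= scale        # metres -> scene units
        return c2w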
3 | 4 | import argparse 5 | import copy 6 | import glob 7 | import json 8 | import numpy as np 9 | import os 10 | 11 | 12 | def rotmat(a, b): 13 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 14 | v = np.cross(a, b) 15 | c = np.dot(a, b) 16 | s = np.linalg.norm(v) 17 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 18 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2 + 1e-10)) 19 | 20 | 21 | def closest_point_2_lines( 22 | oa, da, ob, db 23 | ): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 24 | da = da / np.linalg.norm(da) 25 | db = db / np.linalg.norm(db) 26 | c = np.cross(da, db) 27 | denom = np.linalg.norm(c)**2 28 | t = ob - oa 29 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 30 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 31 | if ta > 0: 32 | ta = 0 33 | if tb > 0: 34 | tb = 0 35 | return (oa + ta * da + ob + tb * db) * 0.5, denom 36 | 37 | 38 | parser = argparse.ArgumentParser( 39 | description= 40 | "Run neural graphics primitives testbed with additional configuration & output options" 41 | ) 42 | 43 | parser.add_argument("--scene_folder", type=str, default="") 44 | parser.add_argument( 45 | "--transform_train", 46 | type=str, 47 | default="", 48 | ) 49 | parser.add_argument( 50 | "--transform_test", 51 | type=str, 52 | default="", 53 | ) 54 | 55 | parser.add_argument("--interval", default=10, type=int, help="Sample Interval.") 56 | 57 | parser.add_argument("--room_center", 58 | action="store_true", 59 | help="Use room centers from the mesh file") 60 | args = parser.parse_args() 61 | 62 | scannet_folder = args.scene_folder 63 | json_train = args.transform_train 64 | json_test = args.transform_test 65 | interval = args.interval 66 | json_train_base = "transforms_train" 67 | json_test_base = "transforms_test" 68 | # Select the frames from the json. This are the frames of which we want to find 69 | # the actual transform. 70 | c2ws = [] 71 | frame_names = [] 72 | with open(json_train, "r") as f: 73 | transforms_train = json.load(f) 74 | # - Get filenames and concurrently load the c2w. 75 | for frame_idx, frame in enumerate(transforms_train["frames"]): 76 | if frame_idx % interval == 0: 77 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0] 78 | pose_name = os.path.join(scannet_folder, f"pose/{frame_name}.txt") 79 | c2w = np.loadtxt(pose_name) 80 | if np.any(np.isinf(c2w)): 81 | continue 82 | frame_names.append(frame_name) 83 | c2ws.append(c2w) 84 | 85 | c2ws_test = [] 86 | frame_names_test = [] 87 | with open(json_test, "r") as f: 88 | transforms_test = json.load(f) 89 | # - Get filenames and concurrently load the c2w. 90 | for frame_idx, frame in enumerate(transforms_test["frames"]): 91 | if frame_idx % interval == 0: 92 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0] 93 | pose_name = os.path.join(scannet_folder, f"pose/{frame_name}.txt") 94 | c2w = np.loadtxt(pose_name) 95 | if np.any(np.isinf(c2w)): 96 | continue 97 | frame_names_test.append(frame_name) 98 | c2ws_test.append(c2w) 99 | 100 | selected_transforms = copy.deepcopy(transforms_train) 101 | selected_transforms.pop("frames") 102 | selected_transforms["frames"] = [] 103 | selected_transforms_test = copy.deepcopy(transforms_test) 104 | selected_transforms_test.pop("frames") 105 | selected_transforms_test["frames"] = [] 106 | 107 | # Open the mesh file to retrieve the scene center. 
108 | if args.room_center: 109 | mesh_files = glob.glob(os.path.join(scannet_folder, "*_vh_clean.ply")) 110 | assert len(mesh_files) == 1, ( 111 | "Found no/more than 1 'vh_clean' mesh files in " 112 | f"{scannet_folder}.") 113 | 114 | mesh = o3d.io.read_triangle_mesh(mesh_files[0]) 115 | max_coord_mesh = np.max(mesh.vertices, axis=0) 116 | min_coord_mesh = np.min(mesh.vertices, axis=0) 117 | room_center = (max_coord_mesh + min_coord_mesh) / 2.0 118 | else: 119 | room_center = np.zeros(3) 120 | 121 | up = np.zeros(3) 122 | print(f"length of c2ws: {len(c2ws)}") 123 | for c2w_idx in range(len(c2ws)): 124 | c2ws[c2w_idx][:3, 3] -= room_center 125 | c2ws[c2w_idx][0:3, 2] *= -1 # flip the y and z axis 126 | c2ws[c2w_idx][0:3, 1] *= -1 127 | c2ws[c2w_idx] = c2ws[c2w_idx][[1, 0, 2, 3], :] # swap y and z 128 | c2ws[c2w_idx][2, :] *= -1 # flip whole world upside down 129 | up += c2ws[c2w_idx][0:3, 1] 130 | 131 | for c2w_idx in range(len(c2ws_test)): 132 | c2ws_test[c2w_idx][:3, 3] -= room_center 133 | c2ws_test[c2w_idx][0:3, 2] *= -1 # flip the y and z axis 134 | c2ws_test[c2w_idx][0:3, 1] *= -1 135 | c2ws_test[c2w_idx] = c2ws_test[c2w_idx][[1, 0, 2, 3], :] # swap y and z 136 | c2ws_test[c2w_idx][2, :] *= -1 # flip whole world upside down 137 | 138 | print(f"up vector: {up}") 139 | 140 | nframes = len(c2ws) 141 | up = up / np.linalg.norm(up) 142 | print("up vector was", up) 143 | R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1] 144 | R = np.pad(R, [0, 1]) 145 | R[-1, -1] = 1 146 | 147 | for c2w_idx in range(len(c2ws)): 148 | c2ws[c2w_idx] = np.matmul(R, c2ws[c2w_idx]) # rotate up to be the z axis 149 | 150 | for c2w_idx in range(len(c2ws_test)): 151 | c2ws_test[c2w_idx] = np.matmul( 152 | R, c2ws_test[c2w_idx]) # rotate up to be the z axis 153 | 154 | # find a central point they are all looking at 155 | if not args.room_center: 156 | print("computing center of attention...") 157 | totw = 0.0 158 | totp = np.array([0.0, 0.0, 0.0]) 159 | for c2w_idx_1 in range(len(c2ws)): 160 | mf = c2ws[c2w_idx_1][0:3, :] 161 | for c2w_idx_2 in range(len(c2ws)): 162 | mg = c2ws[c2w_idx_2][0:3, :] 163 | p, w = closest_point_2_lines(mf[:, 3], mf[:, 2], mg[:, 3], mg[:, 2]) 164 | if w > 0.01: 165 | totp += p * w 166 | totw += w 167 | totp /= totw 168 | print("room center was:") 169 | print(totp) # the cameras are looking at totp 170 | for c2w_idx in range(len(c2ws)): 171 | c2ws[c2w_idx][0:3, 3] -= totp 172 | 173 | for c2w_idx in range(len(c2ws_test)): 174 | c2ws_test[c2w_idx][0:3, 3] -= totp 175 | 176 | avglen = 0.0 177 | for c2w_idx in range(len(c2ws)): 178 | avglen += np.linalg.norm(c2ws[c2w_idx][0:3, 3]) 179 | print(f"avglen:{avglen}") 180 | print(nframes) 181 | avglen /= nframes 182 | 183 | # This factor converts one meter to the unit of measure of the scene. 184 | # NOTE: This incorporates both the scaling previously done in this script 185 | # and the one previously done in `nerf_matrix_to_ngp`, which now no longer 186 | # scales the poses. 
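For intuition, with made-up numbers: if the average camera distance from the origin is `avglen = 2.64`, the factor below is `4.0 / 2.64 * 0.33 = 0.5`, so a camera that sat 2 m from the origin ends up 1.0 scene units away, and the average camera distance becomes `4.0 * 0.33 = 1.32` regardless of the original room size.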
187 | one_m_to_scene_uom = 4.0 / avglen * 0.33 188 | 189 | print("avg camera distance from origin", avglen) 190 | for c2w_idx in range(len(c2ws)): 191 | c2ws[c2w_idx][0:3, 3] *= one_m_to_scene_uom # scale to "nerf sized" 192 | for c2w_idx in range(len(c2ws_test)): 193 | c2ws_test[c2w_idx][0:3, 3] *= one_m_to_scene_uom # scale to "nerf sized" 194 | 195 | store_dict = {} 196 | store_dict["avglen"] = avglen 197 | store_dict["up"] = up 198 | store_dict["totp"] = totp 199 | store_dict["totw"] = totw 200 | scene_name = os.path.basename(os.path.dirname(scannet_folder)) 201 | 202 | curr_frame_name_idx = 0 203 | for frame_idx in range(len(transforms_train["frames"])): 204 | if curr_frame_name_idx == len(frame_names): 205 | break 206 | frame = transforms_train["frames"][frame_idx] 207 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0] 208 | if frame_name == frame_names[curr_frame_name_idx]: 209 | c2w = c2ws[curr_frame_name_idx] 210 | transforms_train["frames"][frame_idx]["transform_matrix"] = c2w.tolist() 211 | selected_transforms["frames"].append( 212 | transforms_train["frames"][frame_idx]) 213 | curr_frame_name_idx += 1 214 | selected_transforms["one_m_to_scene_uom"] = one_m_to_scene_uom 215 | 216 | out_path = os.path.join(scannet_folder, f"{json_train_base}.json") 217 | with open(out_path, "w") as f: 218 | json.dump(selected_transforms, f, indent=4) 219 | 220 | curr_frame_name_idx = 0 221 | for frame_idx in range(len(transforms_test["frames"])): 222 | if curr_frame_name_idx == len(frame_names_test): 223 | break 224 | frame = transforms_test["frames"][frame_idx] 225 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0] 226 | if frame_name == frame_names_test[curr_frame_name_idx]: 227 | c2w = c2ws_test[curr_frame_name_idx] 228 | transforms_test["frames"][frame_idx]["transform_matrix"] = c2w.tolist() 229 | selected_transforms_test["frames"].append( 230 | transforms_test["frames"][frame_idx]) 231 | curr_frame_name_idx += 1 232 | selected_transforms_test["one_m_to_scene_uom"] = one_m_to_scene_uom 233 | 234 | out_path = os.path.join(scannet_folder, f"{json_test_base}.json") 235 | with open(out_path, "w") as f: 236 | json.dump(selected_transforms_test, f, indent=4) 237 | -------------------------------------------------------------------------------- /preprocessing_scripts/scannet2transform.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import csv 4 | import cv2 5 | import json 6 | import numpy as np 7 | import os 8 | 9 | 10 | def load_scannet_nyu40_mapping(path): 11 | mapping = {} 12 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile: 13 | tsvreader = csv.reader(tsvfile, delimiter='\t') 14 | for i, line in enumerate(tsvreader): 15 | if i == 0: 16 | continue 17 | scannet_id, nyu40id = int(line[0]), int(line[4]) 18 | mapping[scannet_id] = nyu40id 19 | return mapping 20 | 21 | 22 | def load_scannet_nyu13_mapping(path): 23 | mapping = {} 24 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile: 25 | tsvreader = csv.reader(tsvfile, delimiter='\t') 26 | for i, line in enumerate(tsvreader): 27 | if i == 0: 28 | continue 29 | scannet_id, nyu40id = int(line[0]), int(line[5]) 30 | mapping[scannet_id] = nyu40id 31 | return mapping 32 | 33 | 34 | parser = argparse.ArgumentParser( 35 | description= 36 | "Run neural graphics primitives testbed with additional configuration & output options" 37 | ) 38 | 39 | parser.add_argument("--scene_folder", type=str, 
default="") 40 | parser.add_argument("--scaled_image", action="store_true") 41 | parser.add_argument("--semantics", action="store_true") 42 | args = parser.parse_args() 43 | basedir = args.scene_folder 44 | 45 | print(f"processing folder: {basedir}") 46 | 47 | # Step for generating training images 48 | step = 1 49 | 50 | frame_ids = os.listdir(os.path.join(basedir, 'color')) 51 | frame_ids = [int(os.path.splitext(frame)[0]) for frame in frame_ids] 52 | frame_ids = sorted(frame_ids) 53 | 54 | intrinsic_file = os.path.join(basedir, "intrinsic/intrinsic_color.txt") 55 | intrinsic = np.loadtxt(intrinsic_file) 56 | print("intrinsic parameters:") 57 | print(intrinsic) 58 | 59 | imgs = [] 60 | poses = [] 61 | 62 | K_unscaled = copy.deepcopy(intrinsic) 63 | 64 | W_unscaled = 1296 65 | H_unscaled = 968 66 | 67 | W = 320 68 | H = 240 69 | K = copy.deepcopy(intrinsic) 70 | scale_x = 320. / 1296. 71 | scale_y = 240. / 968. 72 | 73 | K[0, 0] = K[0, 0] * scale_x # fx 74 | K[1, 1] = K[1, 1] * scale_y # fy 75 | K[0, 2] = K[0, 2] * scale_x # cx 76 | K[1, 2] = K[1, 2] * scale_y # cy 77 | 78 | if args.semantics: 79 | label_mapping_nyu = load_scannet_nyu40_mapping(basedir) 80 | os.makedirs(os.path.join(basedir, 'label_40'), exist_ok=True) 81 | os.makedirs(os.path.join(basedir, 'label_40_scaled'), exist_ok=True) 82 | 83 | train_ids = frame_ids[::step] 84 | test_id_step = 10 85 | test_ids = [] 86 | for x in train_ids: 87 | if x + (test_id_step // 2) < len(frame_ids): 88 | test_ids.append(x + (test_id_step // 2)) 89 | test_ids = test_ids[:: 90 | test_id_step] # only use 10% of the test frames to speed up inference 91 | 92 | print(f"total number of training frames: {len(train_ids)}") 93 | print(f"total number of testing frames: {len(test_ids)}") 94 | 95 | os.makedirs(os.path.join(basedir, 'color_scaled'), exist_ok=True) 96 | 97 | for ids in (train_ids, test_ids): 98 | transform_json = {} 99 | transform_json["fl_x"] = K[0, 0] 100 | transform_json["fl_y"] = K[1, 1] 101 | transform_json["cx"] = K[0, 2] 102 | transform_json["cy"] = K[1, 2] 103 | transform_json["w"] = W 104 | transform_json["h"] = H 105 | transform_json["camera_angle_x"] = np.arctan2(W / 2, K[0, 0]) * 2 106 | transform_json["camera_angle_y"] = np.arctan2(H / 2, K[1, 1]) * 2 107 | transform_json["aabb_scale"] = 16 108 | transform_json["frames"] = [] 109 | 110 | transform_json_unscaled = {} 111 | transform_json_unscaled["fl_x"] = K_unscaled[0, 0] 112 | transform_json_unscaled["fl_y"] = K_unscaled[1, 1] 113 | transform_json_unscaled["cx"] = K_unscaled[0, 2] 114 | transform_json_unscaled["cy"] = K_unscaled[1, 2] 115 | transform_json_unscaled["w"] = W_unscaled 116 | transform_json_unscaled["h"] = H_unscaled 117 | transform_json_unscaled["camera_angle_x"] = np.arctan2( 118 | W_unscaled / 2, K_unscaled[0, 0]) * 2 119 | transform_json_unscaled["camera_angle_y"] = np.arctan2( 120 | H_unscaled / 2, K_unscaled[1, 1]) * 2 121 | transform_json_unscaled["aabb_scale"] = 16 122 | transform_json_unscaled["frames"] = [] 123 | 124 | for frame_id in ids: 125 | pose = np.loadtxt(os.path.join(basedir, 'pose', '%d.txt' % frame_id)) 126 | pose = pose.reshape((4, 4)) 127 | if np.any(np.isinf(pose)): 128 | continue 129 | if args.scaled_image: 130 | file_name_image = os.path.join(basedir, 'color', 131 | '%d.jpg' % frame_id) 132 | image = cv2.imread( 133 | file_name_image)[:, :, :: 134 | -1] # change from BGR uinit 8 to RGB float 135 | #image = cv2.copyMakeBorder(src=image, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=[0,0,0]) # pad 4 pixels to height 
so that images have aspect ratio of 4:3 136 | #assert image.shape[0] * 4==3 * image.shape[1] 137 | image = image / 255.0 138 | image = cv2.resize(image, (W, H), interpolation=cv2.INTER_AREA) 139 | image_save = cv2.cvtColor(image.astype(np.float32), 140 | cv2.COLOR_BGR2RGB) 141 | image_save = image_save * 255.0 142 | cv2.imwrite( 143 | os.path.join(basedir, 'color_scaled', '%d.jpg' % frame_id), 144 | image_save) 145 | 146 | if args.semantics: 147 | file_name_label = os.path.join(basedir, 'label-filt', 148 | '%d.png' % frame_id) 149 | semantic = cv2.imread(file_name_label, cv2.IMREAD_UNCHANGED) 150 | semantic_copy = copy.deepcopy(semantic) 151 | for scan_id, nyu_id in label_mapping_nyu.items(): 152 | semantic[semantic_copy == scan_id] = nyu_id 153 | # semantic_scaled = cv2.copyMakeBorder(src=semantic, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=0) 154 | semantic_scaled = cv2.resize(semantic, (W, H), 155 | interpolation=cv2.INTER_NEAREST) 156 | semantic = semantic.astype(np.uint8) 157 | semantic_scaled = semantic_scaled.astype(np.uint8) 158 | cv2.imwrite( 159 | os.path.join(basedir, 'label_40_scaled', 160 | '%d.png' % frame_id), semantic_scaled) 161 | cv2.imwrite( 162 | os.path.join(basedir, 'label_40', '%d.png' % frame_id), 163 | semantic) 164 | 165 | json_image_dict = {} 166 | json_image_dict["file_path"] = os.path.join('color_scaled', 167 | '%d.jpg' % frame_id) 168 | if args.semantics: 169 | json_image_dict["label_path"] = os.path.join( 170 | 'label_40_scaled', '%d.png' % frame_id) 171 | json_image_dict["transform_matrix"] = pose.tolist() 172 | transform_json["frames"].append(json_image_dict) 173 | 174 | json_image_dict_unscaled = {} 175 | json_image_dict_unscaled["file_path"] = os.path.join( 176 | 'color', '%d.jpg' % frame_id) 177 | if args.semantics: 178 | json_image_dict_unscaled["label_path"] = os.path.join( 179 | 'label_40', '%d.png' % frame_id) 180 | json_image_dict_unscaled["transform_matrix"] = pose.tolist() 181 | transform_json_unscaled["frames"].append(json_image_dict_unscaled) 182 | 183 | if args.scaled_image: 184 | if ids == train_ids: 185 | file_name = 'transforms_train_scaled' 186 | else: 187 | file_name = 'transforms_test_scaled' 188 | 189 | if args.semantics: 190 | file_name += "_semantics_40_raw" 191 | file_name += ".json" 192 | out_file = open(os.path.join(basedir, file_name), "w") 193 | json.dump(transform_json, out_file, indent=4) 194 | out_file.close() 195 | 196 | else: 197 | if ids == train_ids: 198 | file_name = 'transforms_train' 199 | else: 200 | file_name = 'transforms_test' 201 | if args.semantics: 202 | file_name += "_semantics_40_raw" 203 | file_name += ".json" 204 | out_file = open(os.path.join(basedir, file_name), "w") 205 | json.dump(transform_json_unscaled, out_file, indent=4) 206 | out_file.close() 207 | -------------------------------------------------------------------------------- /preprocessing_scripts/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # color palette for nyu40 labels 4 | nyu40_colour_code = np.array([ 5 | (0, 0, 0), 6 | (174, 199, 232), # wall 7 | (152, 223, 138), # floor 8 | (31, 119, 180), # cabinet 9 | (255, 187, 120), # bed 10 | (188, 189, 34), # chair 11 | (140, 86, 75), # sofa 12 | (255, 152, 150), # table 13 | (214, 39, 40), # door 14 | (197, 176, 213), # window 15 | (148, 103, 189), # bookshelf 16 | (196, 156, 148), # picture 17 | (23, 190, 207), # counter 18 | (178, 76, 76), # blinds 19 | (247, 182, 210), # desk 20 | (66, 188, 
102), # shelves 21 | (219, 219, 141), # curtain 22 | (140, 57, 197), # dresser 23 | (202, 185, 52), # pillow 24 | (51, 176, 203), # mirror 25 | (200, 54, 131), # floor 26 | (92, 193, 61), # clothes 27 | (78, 71, 183), # ceiling 28 | (172, 114, 82), # books 29 | (255, 127, 14), # refrigerator 30 | (91, 163, 138), # tv 31 | (153, 98, 156), # paper 32 | (140, 153, 101), # towel 33 | (158, 218, 229), # shower curtain 34 | (100, 125, 154), # box 35 | (178, 127, 135), # white board 36 | (120, 185, 128), # person 37 | (146, 111, 194), # night stand 38 | (44, 160, 44), # toilet 39 | (112, 128, 144), # sink 40 | (96, 207, 209), # lamp 41 | (227, 119, 194), # bathtub 42 | (213, 92, 176), # bag 43 | (94, 106, 211), # other struct 44 | (82, 84, 163), # otherfurn 45 | (100, 85, 144), # other prop 46 | ]).astype(np.uint8) 47 | 48 | nyu13_colour_code = ( 49 | np.array([ 50 | [0, 0, 0], 51 | [0, 0, 1], # BED 52 | [0.9137, 0.3490, 0.1882], # BOOKS 53 | [0, 0.8549, 0], # CEILING 54 | [0.5843, 0, 0.9412], # CHAIR 55 | [0.8706, 0.9451, 0.0941], # FLOOR 56 | [1.0000, 0.8078, 0.8078], # FURNITURE 57 | [0, 0.8784, 0.8980], # OBJECTS 58 | [0.4157, 0.5333, 0.8000], # PAINTING 59 | [0.4588, 0.1137, 0.1608], # SOFA 60 | [0.9412, 0.1373, 0.9216], # TABLE 61 | [0, 0.6549, 0.6118], # TV 62 | [0.9765, 0.5451, 0], # WALL 63 | [0.8824, 0.8980, 0.7608], 64 | ]) * 255).astype(np.uint8) 65 | 66 | nyu40_to13 = { 67 | 0: 0, 68 | 1: 12, 69 | 2: 5, 70 | 3: 6, 71 | 4: 1, 72 | 5: 4, 73 | 6: 9, 74 | 7: 10, 75 | 8: 12, 76 | 9: 13, 77 | 10: 6, 78 | 11: 8, 79 | 12: 6, 80 | 13: 13, 81 | 14: 10, 82 | 15: 6, 83 | 16: 13, 84 | 17: 6, 85 | 18: 7, 86 | 19: 7, 87 | 20: 5, 88 | 21: 7, 89 | 22: 3, 90 | 23: 2, 91 | 24: 6, 92 | 25: 11, 93 | 26: 7, 94 | 27: 7, 95 | 28: 7, 96 | 29: 7, 97 | 30: 7, 98 | 31: 7, 99 | 32: 7, 100 | 33: 7, 101 | 34: 7, 102 | 35: 7, 103 | 36: 7, 104 | 37: 7, 105 | 38: 7, 106 | 39: 6, 107 | 40: 7, 108 | } 109 | 110 | nyu40_to_13 = (np.array([ 111 | 0, 112 | 12, 113 | 5, 114 | 6, 115 | 1, 116 | 4, 117 | 9, 118 | 10, 119 | 12, 120 | 13, 121 | 6, 122 | 8, 123 | 6, 124 | 13, 125 | 10, 126 | 6, 127 | 13, 128 | 6, 129 | 7, 130 | 7, 131 | 5, 132 | 7, 133 | 3, 134 | 2, 135 | 6, 136 | 11, 137 | 7, 138 | 7, 139 | 7, 140 | 7, 141 | 7, 142 | 7, 143 | 7, 144 | 7, 145 | 7, 146 | 7, 147 | 7, 148 | 7, 149 | 7, 150 | 7, 151 | 7, 152 | ])).astype(np.uint8) 153 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | addict==2.4.0 3 | aiohttp==3.8.1 4 | aiosignal==1.2.0 5 | anyio==3.6.1 6 | argon2-cffi-bindings==21.2.0 7 | argon2-cffi==21.3.0 8 | arrow==1.2.2 9 | astroid==2.11.5 10 | asttokens==2.0.5 11 | async-timeout==4.0.2 12 | attrs==21.4.0 13 | babel==2.10.1 14 | backcall==0.2.0 15 | beautifulsoup4==4.11.1 16 | black==22.3.0 17 | bleach==5.0.0 18 | boto3==1.23.1 19 | botocore==1.26.1 20 | bravado-core==5.17.0 21 | bravado==11.0.3 22 | brotlipy==0.7.0 23 | cachetools==5.1.0 24 | certifi==2021.10.8 25 | cffi==1.15.0 26 | charset-normalizer==2.0.4 27 | click==8.1.3 28 | commonmark==0.9.1 29 | cryptography==36.0.0 30 | cycler==0.11.0 31 | dearpygui==1.6.1 32 | debugpy==1.6.0 33 | decorator==5.1.1 34 | defusedxml==0.7.1 35 | deprecation==2.1.0 36 | dill==0.3.4 37 | docker-pycreds==0.4.0 38 | entrypoints==0.4 39 | executing==0.8.3 40 | fastjsonschema==2.15.3 41 | fonttools==4.33.3 42 | fqdn==1.5.1 43 | frozenlist==1.3.0 44 | fsspec==2022.3.0 45 | future==0.18.2 46 | gitdb==4.0.9 47 | 
gitpython==3.1.27 48 | google-auth-oauthlib==0.4.6 49 | google-auth==2.6.6 50 | grpcio==1.46.1 51 | humanfriendly==10.0 52 | idna==3.3 53 | imageio-ffmpeg==0.4.7 54 | imageio==2.17.0 55 | imgviz==1.5.0 56 | importlib-metadata==4.11.3 57 | iniconfig==1.1.1 58 | ipykernel==6.13.0 59 | ipython-genutils==0.2.0 60 | ipython==8.3.0 61 | ipywidgets==7.7.0 62 | isoduration==20.11.0 63 | isort==5.10.1 64 | jedi==0.18.1 65 | jinja2==3.1.2 66 | jmespath==1.0.0 67 | joblib==1.1.0 68 | json5==0.9.8 69 | jsonpointer==2.3 70 | jsonref==0.2 71 | jsonschema==4.5.1 72 | jupyter-client==7.3.1 73 | jupyter-core==4.10.0 74 | jupyter-packaging==0.12.0 75 | jupyter-server==1.17.0 76 | jupyterlab-pygments==0.2.2 77 | jupyterlab-server==2.14.0 78 | jupyterlab-widgets==1.1.0 79 | jupyterlab==3.4.2 80 | kiwisolver==1.4.2 81 | lazy-object-proxy==1.7.1 82 | markdown==3.3.7 83 | markupsafe==2.1.1 84 | matplotlib-inline==0.1.3 85 | matplotlib==3.5.1 86 | mccabe==0.7.0 87 | mistune==0.8.4 88 | monotonic==1.6 89 | msgpack==1.0.3 90 | multidict==6.0.2 91 | mypy-extensions==0.4.3 92 | nbclassic==0.3.7 93 | nbclient==0.6.3 94 | nbconvert==6.5.0 95 | nbformat==5.4.0 96 | neptune-client==0.16.2 97 | nest-asyncio==1.5.5 98 | networkx==2.8 99 | ninja==1.10.2.3 100 | notebook-shim==0.1.0 101 | notebook==6.4.11 102 | numpy==1.23.0 103 | oauthlib==3.2.0 104 | open3d==0.15.2 105 | opencv-python==4.5.5.64 106 | packaging==21.3 107 | pandas==1.4.2 108 | pandocfilters==1.5.0 109 | parso==0.8.3 110 | pathspec==0.9.0 111 | pathtools==0.1.2 112 | pexpect==4.8.0 113 | pickleshare==0.7.5 114 | pillow==9.0.1 115 | platformdirs==2.5.2 116 | pluggy==1.0.0 117 | prometheus-client==0.14.1 118 | promise==2.3 119 | prompt-toolkit==3.0.29 120 | protobuf==3.20.1 121 | psutil==5.9.0 122 | ptyprocess==0.7.0 123 | pure-eval==0.2.2 124 | py==1.11.0 125 | pyasn1-modules==0.2.8 126 | pyasn1==0.4.8 127 | pycparser==2.21 128 | pydeprecate==0.3.2 129 | pygments==2.12.0 130 | pyjwt==2.4.0 131 | pylint==2.13.9 132 | pymcubes==0.1.2 133 | pyopenssl==22.0.0 134 | pyparsing==3.0.8 135 | pyquaternion==0.9.9 136 | pyrsistent==0.18.1 137 | pysdf==0.1.8 138 | pysocks==1.7.1 139 | pytest==7.1.2 140 | python-dateutil==2.8.2 141 | pytorch-lightning==1.6.3 142 | pytz==2022.1 143 | pywavelets==1.3.0 144 | pyyaml==6.0 145 | pyzmq==22.3.0 146 | requests-oauthlib==1.3.1 147 | requests==2.27.1 148 | rfc3339-validator==0.1.4 149 | rfc3987==1.3.8 150 | rich==12.3.0 151 | rsa==4.8 152 | s3transfer==0.5.2 153 | scikit-image==0.19.2 154 | scikit-learn==1.0.2 155 | scipy==1.8.0 156 | send2trash==1.8.0 157 | sentry-sdk==1.5.12 158 | setproctitle==1.2.3 159 | setuptools==61.2.0 160 | shortuuid==1.0.9 161 | simplejson==3.17.6 162 | six==1.16.0 163 | sklearn==0.0 164 | smmap==5.0.0 165 | sniffio==1.2.0 166 | soupsieve==2.3.2.post1 167 | stack-data==0.2.0 168 | swagger-spec-validator==2.7.4 169 | tensorboard-data-server==0.6.1 170 | tensorboard-plugin-wit==1.8.1 171 | tensorboard==2.9.0 172 | tensorboardx==2.5 173 | terminado==0.15.0 174 | threadpoolctl==3.1.0 175 | tifffile==2022.5.4 176 | tinycss2==1.1.1 177 | tomli==2.0.1 178 | tomlkit==0.10.2 179 | torch-ema==0.3 180 | torch-tb-profiler==0.4.0 181 | torch==1.11.0 182 | torchaudio==0.11.0 183 | torchmetrics==0.8.2 184 | torchvision==0.12.0 185 | tornado==6.1 186 | tqdm==4.64.0 187 | traitlets==5.2.1.post0 188 | trimesh==3.11.2 189 | typing-extensions==4.1.1 190 | uri-template==1.2.0 191 | urllib3==1.26.9 192 | wandb==0.12.16 193 | wcwidth==0.2.5 194 | webcolors==1.11.1 195 | webencodings==0.5.1 196 | websocket-client==1.3.2 197 | 
werkzeug==2.1.2 198 | wheel==0.37.1 199 | widgetsnbextension==3.6.0 200 | wrapt==1.14.1 201 | yarl==1.7.2 202 | zipp==3.8.0 -------------------------------------------------------------------------------- /run_scripts/multi_step.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python scripts/cl_deeplab.py "$@" 3 | 4 | -------------------------------------------------------------------------------- /run_scripts/one_step_finetune_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | name=one_step_finetune_nerf 3 | prev_exp_name=one_step_nerf_only 4 | declare -a Scenes=("s00" "s10" "s20" "s30" "s40" "s50" "s60" "s70" "s80" "s90") 5 | for i in "${!Scenes[@]}"; do 6 | python scripts/train_finetune.py --exp cfg/exp/one_step_finetune_nerf/${Scenes[i]}_lr1e-5.yml --project_name $name --prev_exp_name $prev_exp_name 7 | done -------------------------------------------------------------------------------- /run_scripts/one_step_joint_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | name=one_step_joint 3 | declare -a Scenes=("s00" "s10" "s20" "s30" "s40" "s50" "s60" "s70" "s80" "s90") 4 | for i in "${!Scenes[@]}"; do 5 | python scripts/train_joint.py --exp cfg/exp/one_step_joint/${Scenes[i]}_lr1e-5.yml --exp_name $name --project_name $name --nerf_train_epoch 10 --joint_train_epoch 50 6 | done -------------------------------------------------------------------------------- /run_scripts/one_step_nerf_only_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | name=one_step_nerf_only 3 | declare -a Scenes=("s00" "s10" "s20" "s30" "s40" "s50" "s60" "s70" "s80" "s90") 4 | for i in "${!Scenes[@]}"; do 5 | python scripts/train_joint.py --exp cfg/exp/one_step_joint/${Scenes[i]}_lr1e-5.yml --exp_name $name --project_name $name --nerf_train_epoch 60 --joint_train_epoch 0 6 | done -------------------------------------------------------------------------------- /run_scripts/preprocess_scannet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | root_dir="data/scannet/scans" 3 | for i in $(ls -d $root_dir/*/); 4 | do 5 | echo ${i}; 6 | python preprocessing_scripts/scannet2transform.py --scene_folder $i --scaled_image --semantics 7 | python preprocessing_scripts/scannet2nerf.py --scene_folder $i --transform_train $i/transforms_train_scaled_semantics_40_raw.json \ 8 | --transform_test $i/transforms_test_scaled_semantics_40_raw.json --interval 10 9 | done -------------------------------------------------------------------------------- /run_scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python scripts/pretrain.py "$@" 3 | 4 | -------------------------------------------------------------------------------- /scripts/cl_deeplab.py: -------------------------------------------------------------------------------- 1 | """Continual learning protocol that calls joint training over multiple scene stages.""" 2 | import argparse 3 | import os 4 | import sys 5 | import wandb 6 | 7 | from nr4seg import ROOT_DIR 8 | from nr4seg.utils import load_yaml 9 | from scripts.train_joint import train 10 | 11 | SCENE_ORDER = [ 12 | "scene0000_00", 13 | "scene0001_00", 14 | "scene0002_00", 15 | "scene0003_00", 16 | "scene0004_00", 17 | "scene0005_00", 18 | "scene0006_00", 19 | "scene0007_00", 20 | 
"scene0008_00", 21 | "scene0009_00", 22 | ] 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--exp", 29 | default="cfg/exp/finetune/deeplabv3_s0.yml", 30 | help= 31 | "Experiment yaml file path relative to template_project_name/cfg/exp directory.", 32 | ) 33 | parser.add_argument( 34 | "--exp_name", 35 | default="debug", 36 | help="overall experiment of this continual learning experiment.", 37 | ) 38 | parser.add_argument("--seed", default=123, type=int) 39 | parser.add_argument( 40 | "--fix_nerf", 41 | action="store_true", 42 | help="whether or not to fix nerf during joint training", 43 | ) 44 | 45 | parser.add_argument("--project_name", default="test_one_by_one") 46 | parser.add_argument("--nerf_train_epoch", default=10, type=int) 47 | 48 | parser.add_argument("--joint_train_epoch", default=10, type=int) 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | def main(): 54 | """Main function.""" 55 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env", 56 | os.environ["ENV_WORKSTATION_NAME"] + ".yml") 57 | env = load_yaml(env_cfg_path) 58 | os.chdir(ROOT_DIR) 59 | args = parse_args() 60 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp) 61 | exp = load_yaml(exp_cfg_path) 62 | exp_name = args.exp_name 63 | exp["exp_name"] = exp_name 64 | 65 | prev_stage = "" 66 | stage = "init" 67 | exp["scenes"] = [] 68 | 69 | for i, new_scene in enumerate(SCENE_ORDER): 70 | exp["scenes"].append(new_scene) 71 | prev_stage = stage 72 | stage = f"stage_{i}" 73 | exp["general"]["name"] = f"{exp_name}/{stage}" 74 | 75 | # train on new class 76 | exp["trainer"]["resume_from_checkpoint"] = False 77 | exp["trainer"]["load_from_checkpoint"] = True 78 | if i == 0: 79 | exp["general"]["load_pretrain"] = True 80 | old_model_path = exp["general"]["checkpoint_load"] 81 | else: 82 | exp["general"]["load_pretrain"] = False 83 | old_model_path = os.path.join("experiments", exp_name, prev_stage, 84 | "deeplab.ckpt") 85 | 86 | exp["general"]["checkpoint_load"] = old_model_path 87 | 88 | print(f"training on: {new_scene}") 89 | 90 | train(exp, env, exp_cfg_path, env_cfg_path, args) 91 | wandb.finish() 92 | 93 | 94 | if __name__ == "__main__": # pragma: no cover 95 | main() 96 | sys.exit(1) 97 | -------------------------------------------------------------------------------- /scripts/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # color palette for nyu40 labels 4 | nyu40_colour_code = np.array([ 5 | (0, 0, 0), 6 | (174, 199, 232), # wall 7 | (152, 223, 138), # floor 8 | (31, 119, 180), # cabinet 9 | (255, 187, 120), # bed 10 | (188, 189, 34), # chair 11 | (140, 86, 75), # sofa 12 | (255, 152, 150), # table 13 | (214, 39, 40), # door 14 | (197, 176, 213), # window 15 | (148, 103, 189), # bookshelf 16 | (196, 156, 148), # picture 17 | (23, 190, 207), # counter 18 | (178, 76, 76), # blinds 19 | (247, 182, 210), # desk 20 | (66, 188, 102), # shelves 21 | (219, 219, 141), # curtain 22 | (140, 57, 197), # dresser 23 | (202, 185, 52), # pillow 24 | (51, 176, 203), # mirror 25 | (200, 54, 131), # floor 26 | (92, 193, 61), # clothes 27 | (78, 71, 183), # ceiling 28 | (172, 114, 82), # books 29 | (255, 127, 14), # refrigerator 30 | (91, 163, 138), # tv 31 | (153, 98, 156), # paper 32 | (140, 153, 101), # towel 33 | (158, 218, 229), # shower curtain 34 | (100, 125, 154), # box 35 | (178, 127, 135), # white board 36 | (120, 185, 128), # person 37 | (146, 111, 194), # night stand 38 | (44, 160, 44), 
# toilet 39 | (112, 128, 144), # sink 40 | (96, 207, 209), # lamp 41 | (227, 119, 194), # bathtub 42 | (213, 92, 176), # bag 43 | (94, 106, 211), # other struct 44 | (82, 84, 163), # otherfurn 45 | (100, 85, 144), # other prop 46 | ]).astype(np.uint8) 47 | 48 | nyu13_colour_code = ( 49 | np.array([ 50 | [0, 0, 0], 51 | [0, 0, 1], # BED 52 | [0.9137, 0.3490, 0.1882], # BOOKS 53 | [0, 0.8549, 0], # CEILING 54 | [0.5843, 0, 0.9412], # CHAIR 55 | [0.8706, 0.9451, 0.0941], # FLOOR 56 | [1.0000, 0.8078, 0.8078], # FURNITURE 57 | [0, 0.8784, 0.8980], # OBJECTS 58 | [0.4157, 0.5333, 0.8000], # PAINTING 59 | [0.4588, 0.1137, 0.1608], # SOFA 60 | [0.9412, 0.1373, 0.9216], # TABLE 61 | [0, 0.6549, 0.6118], # TV 62 | [0.9765, 0.5451, 0], # WALL 63 | [0.8824, 0.8980, 0.7608], 64 | ]) * 255).astype(np.uint8) 65 | 66 | nyu40_to13 = { 67 | 0: 0, 68 | 1: 12, 69 | 2: 5, 70 | 3: 6, 71 | 4: 1, 72 | 5: 4, 73 | 6: 9, 74 | 7: 10, 75 | 8: 12, 76 | 9: 13, 77 | 10: 6, 78 | 11: 8, 79 | 12: 6, 80 | 13: 13, 81 | 14: 10, 82 | 15: 6, 83 | 16: 13, 84 | 17: 6, 85 | 18: 7, 86 | 19: 7, 87 | 20: 5, 88 | 21: 7, 89 | 22: 3, 90 | 23: 2, 91 | 24: 6, 92 | 25: 11, 93 | 26: 7, 94 | 27: 7, 95 | 28: 7, 96 | 29: 7, 97 | 30: 7, 98 | 31: 7, 99 | 32: 7, 100 | 33: 7, 101 | 34: 7, 102 | 35: 7, 103 | 36: 7, 104 | 37: 7, 105 | 38: 7, 106 | 39: 6, 107 | 40: 7, 108 | } 109 | 110 | nyu40_to_13 = (np.array([ 111 | 0, 112 | 12, 113 | 5, 114 | 6, 115 | 1, 116 | 4, 117 | 9, 118 | 10, 119 | 12, 120 | 13, 121 | 6, 122 | 8, 123 | 6, 124 | 13, 125 | 10, 126 | 6, 127 | 13, 128 | 6, 129 | 7, 130 | 7, 131 | 5, 132 | 7, 133 | 3, 134 | 2, 135 | 6, 136 | 11, 137 | 7, 138 | 7, 139 | 7, 140 | 7, 141 | 7, 142 | 7, 143 | 7, 144 | 7, 145 | 7, 146 | 7, 147 | 7, 148 | 7, 149 | 7, 150 | 7, 151 | 7, 152 | ])).astype(np.uint8) 153 | -------------------------------------------------------------------------------- /scripts/pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import torch 5 | 6 | from pathlib import Path 7 | from pytorch_lightning import Trainer, seed_everything 8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 9 | from pytorch_lightning.plugins import DDPPlugin 10 | from pytorch_lightning.profiler import AdvancedProfiler 11 | 12 | from nr4seg import ROOT_DIR 13 | from nr4seg.lightning import PretrainDataModule, SemanticsLightningNet 14 | from nr4seg.utils import flatten_dict, get_wandb_logger, load_yaml 15 | 16 | 17 | def train(exp, args) -> float: 18 | seed_everything(args.seed) 19 | 20 | ################################################## LOAD CONFIG ################################################### 21 | # ROOT_DIR is defined int template_project_name/__init__.py 22 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env", 23 | os.environ["ENV_WORKSTATION_NAME"] + ".yml") 24 | env = load_yaml(env_cfg_path) 25 | #################################################################################################################### 26 | 27 | ############################################ CREAT EXPERIMENT FOLDER ############################################# 28 | 29 | model_path = os.path.join(env["results"], exp["general"]["name"]) 30 | if exp["general"]["clean_up_folder_if_exists"]: 31 | shutil.rmtree(model_path, ignore_errors=True) 32 | 33 | # Create the directory 34 | Path(model_path).mkdir(parents=True, exist_ok=True) 35 | 36 | # Copy config files 37 | exp_cfg_fn = os.path.split(exp_cfg_path)[-1] 38 | env_cfg_fn = 
os.path.split(env_cfg_path)[-1] 39 | print(f"Copy {env_cfg_path} to {model_path}/{exp_cfg_fn}") 40 | shutil.copy(exp_cfg_path, f"{model_path}/{exp_cfg_fn}") 41 | shutil.copy(env_cfg_path, f"{model_path}/{env_cfg_fn}") 42 | exp["general"]["name"] = model_path 43 | #################################################################################################################### 44 | 45 | ################################################# CREATE LOGGER ################################################## 46 | 47 | logger = get_wandb_logger( 48 | exp=exp, 49 | env=env, 50 | exp_p=exp_cfg_path, 51 | env_p=env_cfg_path, 52 | project_name=args.project_name, 53 | save_dir=model_path, 54 | ) 55 | ex = flatten_dict(exp) 56 | logger.log_hyperparams(ex) 57 | 58 | #################################################################################################################### 59 | 60 | ########################################### CREAET NETWORK AND DATASET ########################################### 61 | model = SemanticsLightningNet(exp, env) 62 | datamodule = PretrainDataModule(env, exp["data_module"]) 63 | #################################################################################################################### 64 | 65 | ################################################# TRAINER SETUP ################################################## 66 | # Callbacks 67 | lr_monitor = LearningRateMonitor(logging_interval="step") 68 | cb_ls = [lr_monitor] 69 | 70 | checkpoint_callback = ModelCheckpoint( 71 | dirpath=model_path, 72 | filename="best-{epoch:02d}-{step:06d}", 73 | verbose=True, 74 | monitor="val/mean_IoU", 75 | mode="max", 76 | save_last=True, 77 | save_top_k=1, 78 | ) 79 | cb_ls.append(checkpoint_callback) 80 | 81 | # set gpus 82 | if (exp["trainer"]).get("gpus", -1) == -1: 83 | nr = torch.cuda.device_count() 84 | print(f"Set GPU Count for Trainer to {nr}!") 85 | for i in range(nr): 86 | print(f"Device {i}: ", torch.cuda.get_device_name(i)) 87 | exp["trainer"]["gpus"] = nr 88 | 89 | # profiler 90 | if exp["trainer"].get("profiler", False): 91 | exp["trainer"]["profiler"] = AdvancedProfiler( 92 | output_filename=os.path.join(model_path, "profile.out")) 93 | else: 94 | exp["trainer"]["profiler"] = False 95 | 96 | # check if restore checkpoint 97 | if exp["trainer"]["resume_from_checkpoint"] is True: 98 | exp["trainer"]["resume_from_checkpoint"] = os.path.join( 99 | env["results"], exp["general"]["checkpoint_load"]) 100 | else: 101 | del exp["trainer"]["resume_from_checkpoint"] 102 | 103 | trainer = Trainer( 104 | **exp["trainer"], 105 | plugins=DDPPlugin(find_unused_parameters=False), 106 | default_root_dir=model_path, 107 | callbacks=cb_ls, 108 | logger=logger, 109 | ) 110 | #################################################################################################################### 111 | 112 | res = trainer.fit(model, datamodule=datamodule) 113 | 114 | return res 115 | 116 | 117 | if __name__ == "__main__": 118 | os.chdir(ROOT_DIR) 119 | 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument( 122 | "--exp", 123 | default="exp.yml", 124 | help= 125 | "Experiment yaml file path relative to template_project_name/cfg/exp directory.", 126 | ) 127 | parser.add_argument("--seed", default=123, type=int) 128 | parser.add_argument("--project_name", default="scannet_debug") 129 | args = parser.parse_args() 130 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp) 131 | exp = load_yaml(exp_cfg_path) 132 | 133 | train(exp, args) 134 | 
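# Example invocation (a minimal sketch; it assumes the default environment file
# cfg/env/env.yml, selected by setting ENV_WORKSTATION_NAME=env, and uses the
# pretraining config shipped in cfg/exp; the project name is arbitrary):
#
#   ENV_WORKSTATION_NAME=env python scripts/pretrain.py \
#       --exp cfg/exp/pretrain_scannet_25k_deeplabv3.yml \
#       --project_name scannet_25k_pretrain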
-------------------------------------------------------------------------------- /scripts/train_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | import torch 5 | import shutil 6 | 7 | from pytorch_lightning import Trainer, seed_everything 8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 9 | from pytorch_lightning.plugins import DDPPlugin 10 | from pytorch_lightning.profiler import AdvancedProfiler 11 | 12 | from nr4seg import ROOT_DIR 13 | from nr4seg.lightning import FineTuneDataModule, SemanticsLightningNet 14 | from nr4seg.utils import flatten_dict, get_wandb_logger, load_yaml 15 | 16 | 17 | def train(exp, env, exp_cfg_path, env_cfg_path, args) -> float: 18 | seed_everything(args.seed) 19 | 20 | #################################################################################################################### 21 | 22 | ############################################ CREAT EXPERIMENT FOLDER ############################################# 23 | model_path = os.path.join(env["results"], exp["general"]["name"]) 24 | if exp["general"]["clean_up_folder_if_exists"]: 25 | shutil.rmtree(model_path, ignore_errors=True) 26 | 27 | # Create the directory 28 | Path(model_path).mkdir(parents=True, exist_ok=True) 29 | 30 | # Copy config files 31 | exp_cfg_fn = os.path.split(exp_cfg_path)[-1] 32 | env_cfg_fn = os.path.split(env_cfg_path)[-1] 33 | print(f"Copy {env_cfg_path} to {model_path}/{exp_cfg_fn}") 34 | shutil.copy(exp_cfg_path, f"{model_path}/{exp_cfg_fn}") 35 | shutil.copy(env_cfg_path, f"{model_path}/{env_cfg_fn}") 36 | exp["general"]["name"] = model_path 37 | #################################################################################################################### 38 | 39 | ################################################# CREATE LOGGER ################################################## 40 | 41 | logger = get_wandb_logger( 42 | exp=exp, 43 | env=env, 44 | exp_p=exp_cfg_path, 45 | env_p=env_cfg_path, 46 | project_name=args.project_name, 47 | save_dir=model_path, 48 | ) 49 | ex = flatten_dict(exp) 50 | logger.log_hyperparams(ex) 51 | 52 | #################################################################################################################### 53 | 54 | ########################################### CREAET NETWORK AND DATASET ########################################### 55 | model = SemanticsLightningNet(exp, env) 56 | datamodule = FineTuneDataModule(exp, env, prev_exp_name=args.prev_exp_name) 57 | #################################################################################################################### 58 | 59 | ################################################# TRAINER SETUP ################################################## 60 | # Callbacks 61 | lr_monitor = LearningRateMonitor(logging_interval="step") 62 | cb_ls = [lr_monitor] 63 | 64 | checkpoint_callback = ModelCheckpoint( 65 | dirpath=model_path, 66 | save_last=True, 67 | save_top_k=1, 68 | ) 69 | cb_ls.append(checkpoint_callback) 70 | 71 | # set gpus 72 | if (exp["trainer"]).get("gpus", -1) == -1: 73 | nr = torch.cuda.device_count() 74 | print(f"Set GPU Count for Trainer to {nr}!") 75 | for i in range(nr): 76 | print(f"Device {i}: ", torch.cuda.get_device_name(i)) 77 | exp["trainer"]["gpus"] = nr 78 | 79 | # profiler 80 | if exp["trainer"].get("profiler", False): 81 | exp["trainer"]["profiler"] = AdvancedProfiler( 82 | output_filename=os.path.join(model_path, "profile.out")) 
83 | else: 84 | exp["trainer"]["profiler"] = False 85 | 86 | # check if restore checkpoint 87 | if exp["trainer"]["resume_from_checkpoint"] is True: 88 | exp["trainer"]["resume_from_checkpoint"] = exp["general"][ 89 | "checkpoint_load"] 90 | else: 91 | del exp["trainer"]["resume_from_checkpoint"] 92 | 93 | if exp["trainer"]["load_from_checkpoint"] is True: 94 | checkpoint = torch.load(exp["general"]["checkpoint_load"]) 95 | checkpoint = checkpoint["state_dict"] 96 | # remove any aux classifier stuff 97 | removekeys = [ 98 | key for key in checkpoint.keys() 99 | if key.startswith("_model._model.aux_classifier") 100 | ] 101 | for key in removekeys: 102 | del checkpoint[key] 103 | model.load_state_dict(checkpoint, strict=True) 104 | 105 | del exp["trainer"]["load_from_checkpoint"] 106 | 107 | trainer = Trainer( 108 | **exp["trainer"], 109 | plugins=DDPPlugin(find_unused_parameters=False), 110 | default_root_dir=model_path, 111 | callbacks=cb_ls, 112 | logger=logger, 113 | ) 114 | #################################################################################################################### 115 | trainer.validate(model=model, datamodule=datamodule) 116 | trainer.test(model=model, datamodule=datamodule) 117 | trainer.fit(model, datamodule=datamodule) 118 | trainer.test(model=model, datamodule=datamodule) 119 | 120 | 121 | if __name__ == "__main__": 122 | os.chdir(ROOT_DIR) 123 | 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument( 126 | "--exp", 127 | default="cfg/exp/finetune/deeplabv3_s0.yml", 128 | help= 129 | "Experiment yaml file path relative to template_project_name/cfg/exp directory.", 130 | ) 131 | parser.add_argument("--seed", default=123, type=int) 132 | parser.add_argument("--project_name", default="scannet_debug") 133 | parser.add_argument("--prev_exp_name", default="one_step_nerf_only") 134 | args = parser.parse_args() 135 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp) 136 | exp = load_yaml(exp_cfg_path) 137 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env", 138 | os.environ["ENV_WORKSTATION_NAME"] + ".yml") 139 | env = load_yaml(env_cfg_path) 140 | 141 | train(exp, env, exp_cfg_path, env_cfg_path, args) 142 | -------------------------------------------------------------------------------- /scripts/train_joint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import torch 5 | 6 | from pathlib import Path 7 | from pytorch_lightning import Trainer, seed_everything 8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 9 | from pytorch_lightning.plugins import DDPPlugin 10 | 11 | from nr4seg import ROOT_DIR 12 | from nr4seg.lightning import JointTrainDataModule, JointTrainLightningNet 13 | from nr4seg.utils import flatten_dict, get_wandb_logger, load_yaml 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--exp", 20 | default="cfg/exp/finetune/deeplabv3_s0.yml", 21 | help= 22 | ("Experiment yaml file path relative to template_project_name/cfg/exp " 23 | "directory."), 24 | ) 25 | parser.add_argument( 26 | "--exp_name", 27 | default="debug", 28 | help="overall experiment of this continual learning experiment.", 29 | ) 30 | 31 | parser.add_argument( 32 | "--fix_nerf", 33 | action="store_true", 34 | help="whether or not to fix nerf during joint training", 35 | ) 36 | 37 | parser.add_argument("--seed", default=123, type=int) 38 | 39 | parser.add_argument("--project_name", 
default="test_one_by_one") 40 | parser.add_argument("--nerf_train_epoch", default=10, type=int) 41 | 42 | parser.add_argument("--joint_train_epoch", default=10, type=int) 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def train(exp, env, exp_cfg_path, env_cfg_path, args) -> float: 48 | seed_everything(args.seed) 49 | exp["exp_name"] = args.exp_name 50 | exp["fix_nerf"] = args.fix_nerf 51 | 52 | # Create experiment folder. 53 | model_path = os.path.join(env["results"], exp["general"]["name"]) 54 | if exp["general"]["clean_up_folder_if_exists"]: 55 | shutil.rmtree(model_path, ignore_errors=True) 56 | 57 | # Create the directory 58 | Path(model_path).mkdir(parents=True, exist_ok=True) 59 | 60 | # Copy config files 61 | exp_cfg_fn = os.path.split(exp_cfg_path)[-1] 62 | env_cfg_fn = os.path.split(env_cfg_path)[-1] 63 | print(f"Copy {env_cfg_path} to {model_path}/{exp_cfg_fn}") 64 | shutil.copy(exp_cfg_path, f"{model_path}/{exp_cfg_fn}") 65 | shutil.copy(env_cfg_path, f"{model_path}/{env_cfg_fn}") 66 | exp["general"]["name"] = model_path 67 | 68 | # Create logger. 69 | logger = get_wandb_logger( 70 | exp=exp, 71 | env=env, 72 | exp_p=exp_cfg_path, 73 | env_p=env_cfg_path, 74 | project_name=args.project_name, 75 | save_dir=model_path, 76 | ) 77 | ex = flatten_dict(exp) 78 | logger.log_hyperparams(ex) 79 | 80 | # Create network and dataset. 81 | model = JointTrainLightningNet(exp, env) 82 | datamodule = JointTrainDataModule(exp, env) 83 | datamodule.setup() 84 | 85 | # Trainer setup. 86 | # - Callbacks 87 | lr_monitor = LearningRateMonitor(logging_interval="step") 88 | cb_ls = [lr_monitor] 89 | 90 | checkpoint_callback = ModelCheckpoint( 91 | dirpath=model_path, 92 | save_last=True, 93 | save_top_k=1, 94 | ) 95 | cb_ls.append(checkpoint_callback) 96 | # - Set GPUs. 97 | if (exp["trainer"]).get("gpus", -1) == -1: 98 | nr = torch.cuda.device_count() 99 | print(f"Set GPU Count for Trainer to {nr}!") 100 | for i in range(nr): 101 | print(f"Device {i}: ", torch.cuda.get_device_name(i)) 102 | exp["trainer"]["gpus"] = nr 103 | 104 | # - Check whether to restore checkpoint. 105 | if exp["trainer"]["resume_from_checkpoint"] is True: 106 | exp["trainer"]["resume_from_checkpoint"] = exp["general"][ 107 | "checkpoint_load"] 108 | else: 109 | del exp["trainer"]["resume_from_checkpoint"] 110 | 111 | if exp["trainer"]["load_from_checkpoint"] is True: 112 | if exp["general"]["load_pretrain"]: 113 | checkpoint = torch.load(exp["general"]["checkpoint_load"]) 114 | checkpoint = checkpoint["state_dict"] 115 | # remove any aux classifier stuff 116 | removekeys = [ 117 | key for key in checkpoint.keys() 118 | if key.startswith("_model._model.aux_classifier") 119 | ] 120 | for key in removekeys: 121 | del checkpoint[key] 122 | 123 | seg_model_state_dict = {} 124 | for key in checkpoint.keys(): 125 | seg_model_key = key.split(".", 1)[1] 126 | seg_model_state_dict[seg_model_key] = checkpoint[key] 127 | 128 | model.seg_model.load_state_dict(seg_model_state_dict, strict=True) 129 | else: 130 | checkpoint = torch.load(exp["general"]["checkpoint_load"]) 131 | checkpoint = checkpoint["state_dict"] 132 | model.seg_model.load_state_dict(checkpoint) 133 | 134 | del exp["trainer"]["load_from_checkpoint"] 135 | 136 | # - Add distributed plugin. 
137 | if exp["trainer"]["gpus"] > 1: 138 | if (exp["trainer"]["accelerator"] == "ddp" or 139 | exp["trainer"]["accelerator"] is None): 140 | ddp_plugin = DDPPlugin(find_unused_parameters=exp["trainer"].get( 141 | "find_unused_parameters", False)) 142 | exp["trainer"]["plugins"] = [ddp_plugin] 143 | 144 | exp["trainer"]["max_epochs"] = args.nerf_train_epoch 145 | 146 | trainer_nerf = Trainer( 147 | **exp["trainer"], 148 | default_root_dir=model_path, 149 | logger=logger, 150 | callbacks=cb_ls, 151 | ) 152 | 153 | exp["trainer"]["check_val_every_n_epoch"] = 10 154 | exp["trainer"]["max_epochs"] = args.joint_train_epoch 155 | trainer_joint = Trainer( 156 | **exp["trainer"], 157 | default_root_dir=model_path, 158 | logger=logger, 159 | callbacks=cb_ls, 160 | ) 161 | 162 | # Train NeRF. 163 | model.joint_train = False 164 | trainer_nerf.fit(model, 165 | train_dataloaders=datamodule.train_dataloader_nerf()) 166 | # test initial nerf performance on the training set 167 | trainer_joint.test(model, dataloaders=datamodule.test_dataloader_nerf()) 168 | # # test initial seg performance on the validation set 169 | trainer_joint.validate(model, dataloaders=datamodule.val_dataloader()) 170 | # joint train + old scenes 171 | model.joint_train = True 172 | # trainer_seg.fit(model, train_dataloaders=datamodule.train_dataloader_seg(), val_dataloaders=datamodule.val_dataloader()) 173 | trainer_joint.fit( 174 | model, 175 | train_dataloaders=datamodule.train_dataloader_joint(), 176 | val_dataloaders=datamodule.val_dataloader(), 177 | ) 178 | # test nerf performance on the training set after joint training + test generalization performance on scannet 25k 179 | trainer_joint.test(model, dataloaders=datamodule.train_dataloader_nerf()) 180 | # predict pseudo labels 181 | trainer_joint.predict(model, dataloaders=datamodule.predict_dataloader()) 182 | # save checkpoint of the deeplab model 183 | torch.save( 184 | {"state_dict": model.seg_model.state_dict()}, 185 | os.path.join(model_path, "deeplab.ckpt"), 186 | ) 187 | 188 | 189 | if __name__ == "__main__": 190 | os.chdir(ROOT_DIR) 191 | args = parse_args() 192 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp) 193 | exp = load_yaml(exp_cfg_path) 194 | exp["general"]["load_pretrain"] = True 195 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env", 196 | os.environ["ENV_WORKSTATION_NAME"] + ".yml") 197 | env = load_yaml(env_cfg_path) 198 | train(exp, env, exp_cfg_path, env_cfg_path, args) 199 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [yapf] 2 | based_on_style = google 3 | indent_width = 4 4 | column_limit = 80 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | from setuptools import find_packages 4 | 5 | setup( 6 | name="nr4seg", 7 | version="0.0.1", 8 | author="Zhizheng Liu, Francesco Milano", 9 | author_email="liuzhi@student.ethz.ch, francesco.milano@mavt.ethz.ch", 10 | packages=find_packages(), 11 | python_requires=">=3.6", 12 | description= 13 | "[CVPR 2023] Unsupervised Continual Semantic Adaptation through Neural Rendering", 14 | ) 15 | --------------------------------------------------------------------------------