├── .gitignore
├── README.md
├── assets
│   └── teaser.png
├── cfg
│   ├── dataset
│   │   └── scannet
│   │       ├── split.npz
│   │       └── split_cl.npz
│   ├── env
│   │   └── env.yml
│   └── exp
│       ├── multi_step
│       │   ├── cl_base.yml
│       │   └── cl_base_novel_viewpoints.yml
│       ├── one_step_finetune_nerf
│       │   ├── s00_lr1e-5.yml
│       │   ├── s10_lr1e-5.yml
│       │   ├── s20_lr1e-5.yml
│       │   ├── s30_lr1e-5.yml
│       │   ├── s40_lr1e-5.yml
│       │   ├── s50_lr1e-5.yml
│       │   ├── s60_lr1e-5.yml
│       │   ├── s70_lr1e-5.yml
│       │   ├── s80_lr1e-5.yml
│       │   └── s90_lr1e-5.yml
│       ├── one_step_joint
│       │   ├── s00_lr1e-5.yml
│       │   ├── s10_lr1e-5.yml
│       │   ├── s20_lr1e-5.yml
│       │   ├── s30_lr1e-5.yml
│       │   ├── s40_lr1e-5.yml
│       │   ├── s50_lr1e-5.yml
│       │   ├── s60_lr1e-5.yml
│       │   ├── s70_lr1e-5.yml
│       │   ├── s80_lr1e-5.yml
│       │   └── s90_lr1e-5.yml
│       └── pretrain_scannet_25k_deeplabv3.yml
├── nr4seg
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── create_split.py
│   │   ├── helper.py
│   │   ├── label_loader.py
│   │   ├── ngp_utils.py
│   │   ├── scannet.py
│   │   ├── scannet_cl.py
│   │   ├── scannet_cl_joint.py
│   │   ├── scannet_ngp.py
│   │   └── scannet_ngp_joint.py
│   ├── lightning
│   │   ├── __init__.py
│   │   ├── finetune_data_module.py
│   │   ├── joint_train_data_module.py
│   │   ├── joint_train_lightning_net.py
│   │   ├── pretrain_data_module.py
│   │   └── semantics_lightning_net.py
│   ├── nerf
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── network_tcnn_semantics.py
│   │   ├── raymarching
│   │   │   ├── __init__.py
│   │   │   ├── backend.py
│   │   │   ├── raymarching.py
│   │   │   └── src
│   │   │       ├── bindings.cpp
│   │   │       ├── pcg32.h
│   │   │       ├── raymarching.cu
│   │   │       └── raymarching.h
│   │   └── renderer_semantics.py
│   ├── network
│   │   ├── __init__.py
│   │   └── deeplabv3.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── flatten_dict.py
│   │   ├── get_logger.py
│   │   ├── loading.py
│   │   └── metrics.py
│   └── visualizer
│       ├── __init__.py
│       ├── colormaps.py
│       └── visualizer.py
├── preprocessing_scripts
│   ├── scannet2nerf.py
│   ├── scannet2transform.py
│   └── utils.py
├── requirements.txt
├── run_scripts
│   ├── multi_step.sh
│   ├── one_step_finetune_train.sh
│   ├── one_step_joint_train.sh
│   ├── one_step_nerf_only_train.sh
│   ├── preprocess_scannet.sh
│   └── pretrain.sh
├── scripts
│   ├── cl_deeplab.py
│   ├── eval_utils.py
│   ├── pretrain.py
│   ├── train_finetune.py
│   └── train_joint.py
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # no IntelliJ files
2 | .idea
3 |
4 | # don't upload macOS folder info
5 | *.DS_Store
6 |
7 | # don't upload node_modules from npm test
8 | node_modules/*
9 | flow-typed/*
10 |
11 | # potential files generated by golang
12 | bin/
13 |
14 | # don't upload webpack bundle file
15 | app/dist/
16 |
17 | # default data directory
18 | local-data/
19 |
20 | # potential integration testing data directory
21 | test_data/
22 |
23 | #python
24 | *.pyc
25 | __pycache__/
26 |
27 | # pytype
28 | .pytype
29 |
30 | # vscode sftp settings
31 | .vscode/sftp.json
32 |
33 | # redis
34 | *.rdb
35 |
36 | # mypy
37 | .mypy_cache
38 |
39 | # jest coverage cache
40 | coverage/
41 |
42 | # python cache
43 | __pycache__/
44 |
45 | # python virtual environment
46 | env/
47 |
48 | # local data folder
49 | data/*
50 |
51 | # vscode workspace configuration
52 | *.code-workspace
53 |
54 | # sphinx build folder
55 | _build/
56 |
57 | # media files are not in this repo
58 | doc/media
59 |
60 | # ignore rope db cache
61 | .vscode/.ropeproject/objectdb
62 |
63 | # ignore vscode launch file
64 | .vscode/launch.json
65 |
66 | # python build
67 | build/
68 | dist/
69 |
70 | # package egg info
71 | *.egg-info
72 |
73 | # legacy files
74 | archived
75 |
76 | # unrelated files
77 | *.pkl
78 | .dist_tmp
79 |
80 | # project files
81 | experiments
82 | ckpts
83 | wandb
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Continual Semantic Adaptation through Neural Rendering
2 |
3 |
4 | Zhizheng Liu*, Francesco Milano*, Jonas Frey, Roland Siegwart, Hermann Blum, Cesar Cadena
5 |
6 |
7 | CVPR 2023
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | We present a framework to improve semantic scene understanding for agents that are deployed across a _sequence of scenes_. In particular, our method performs unsupervised continual semantic adaptation by jointly training a _2-D segmentation model_ and a _Semantic-NeRF network_.
17 |
18 | - Our framework allows successfully adapting the 2-D segmentation model across _multiple, previously unseen scenes_ and with _no ground-truth supervision_, reducing the domain gap in the new scenes and improving on the initial performance of the model.
19 | - By rendering training and novel views, the pipeline can effectively _mitigate forgetting_ and even _gain additional knowledge_ about the previous scenes.
20 |
21 | ## Table of Contents
22 |
23 | 1. [Installation](#installation)
24 | 2. [Running experiments](#running-experiments)
25 | 3. [Citation](#citation)
26 | 4. [Acknowledgements](#acknowledgements)
27 | 5. [Contact](#contact)
28 |
29 | ## Installation
30 |
31 | ### Workspace setup
32 |
33 | We recommend configuring your workspace with a conda environment. You can then install the project and its dependencies as follows. The instructions were tested on Ubuntu 20.04 and 22.04, with CUDA 11.3.
34 |
35 | - Clone this repo to a folder of your choice, which in the following we will refer to with the environment variable `REPO_ROOT`:
36 | ```bash
37 | export REPO_ROOT=
38 | cd ${REPO_ROOT};
39 | git clone git@github.com:ethz-asl/nr_semantic_segmentation.git
40 | ```
41 | - Create a conda environment and install [PyTorch](https://pytorch.org/), [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) and other dependencies:
42 |
43 | ```bash
44 | conda create -n nr4seg python=3.8
45 | conda activate nr4seg
46 | python -m pip install --upgrade pip
47 |
48 | # For CUDA 11.3.
49 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
50 | # Install tiny-cuda-nn
51 | pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
52 |
53 | pip install -r requirements.txt
54 |
55 | python setup.py develop
56 | ```
57 |
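As an optional sanity check, the following should run without errors inside the `nr4seg` environment and confirm that a CUDA device is visible and that the tiny-cuda-nn bindings import correctly:

```bash
# Optional check: PyTorch sees a CUDA device and the tiny-cuda-nn torch bindings import.
conda activate nr4seg
python -c "import torch, tinycudann; print('CUDA available:', torch.cuda.is_available())"
```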
58 | ### Setting up the dataset
59 |
60 | We use the [ScanNet v2](http://www.scan-net.org/) [1] dataset for our experiments.
61 |
62 | > [1] Angela Dai, Angel X. Chang, Manolis Savva, Maciej Halber, Thomas Funkhouser, and Matthias Nießner, "ScanNet: Richly-annotated 3D Reconstructions of Indoor Scenes", in _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)_, pp. 2432-2443, 2017.
63 |
64 | #### Dataset download
65 |
66 | - To get started, visit the official [ScanNet](https://github.com/ScanNet/ScanNet#scannet-data) dataset website to obtain the data download permission and script. We keep all the ScanNet data in the `${REPO_ROOT}/data/scannet` folder. You may use a symbolic link if necessary (_e.g._, `ln -s <path_to_scannet_storage> ${REPO_ROOT}/data/scannet`).
67 | - As detailed in the paper, we use scenes `0000` to `0009` to perform continual semantic adaptation, and a subset of data from the remaining scenes (`0010` to `0706`) to pre-train the segmentation network.
68 | - The data from the pre-training scenes are conveniently provided already in the correct format by the ScanNet dataset, as a `scannet_frames_25k.zip` file. You can download this file using the official download script that you should have received after requesting access to the dataset, specifying the `--preprocessed_frames` flag. Once downloaded, the content of the file should be extracted to the subfolder `${REPO_ROOT}/data/scannet/scannet_frames_25k`.
69 | - For the scenes used to perform continual semantic adaptation, the full data are required. To obtain them, run the official download script, specifying the scene to download through the `--id` flag (_e.g._, `--id scene0000_00` to download scene `0000`) and including the `--label_map` flag to also download the label-mapping file `scannetv2-labels.combined.tsv` (cf. [here](https://github.com/ScanNet/ScanNet#labels)). The downloaded data should be stored in the subfolder `${REPO_ROOT}/data/scannet/scans`. Next, extract all the sensor data (depth images, color images, poses, intrinsics) for each of the downloaded scenes from `0000` to `0009`, using the [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) tool provided by ScanNet (a scripted version of this per-scene extraction is sketched at the end of this list). For instance, for scene `0000`, run
70 | ```bash
71 | python2 reader.py --filename ${REPO_ROOT}/data/scannet/scans/scene0000_00/scene0000_00.sens --output_path ${REPO_ROOT}/data/scannet/scans/scene0000_00 --export_depth_images --export_color_images --export_poses --export_intrinsics
72 | ```
73 | To obtain the raw labels (for evaluation purposes) for each of the continual adaptation scenes, also extract the content of the `sceneXXXX_XX_2d-label-filt.zip` file, so that a `${REPO_ROOT}/data/scannet/scans/sceneXXXX_XX/label-filt` folder is created.
74 | - Copy the `scannetv2-labels.combined.tsv` file to each scene folder under `${REPO_ROOT}/data/scannet/scans`, as well as to the subfolder `${REPO_ROOT}/data/scannet/scannet_frames_25k`.
75 |
76 | - At the end of the process, the `${REPO_ROOT}/data` folder should contain _at least_ the following data, structured as below:
77 |
78 | ```shell
79 | scannet
80 | scannet_frames_25k
81 | scene0010_00
82 | color
83 | 000000.jpg
84 | ...
85 | XXXXXX.jpg
86 | label
87 | 000000.png
88 | ...
89 | XXXXXX.png
90 | ...
91 | ...
92 | scene0706_00
93 | ...
94 | scannetv2-labels.combined.tsv
95 | scans
96 | scene0000_00
97 | color
98 | 000000.jpg
99 | ...
100 | XXXXXX.jpg
101 | depth
102 | 000000.png
103 | ...
104 | XXXXXX.png
105 | label-filt
106 | 000000.png
107 | ...
108 | XXXXXX.png
109 | pose
110 | 000000.txt
111 | ...
112 | XXXXXX.txt
113 | intrinsics
114 | intriniscs_color.txt
115 | intrinsics_depth.txt
116 | scannetv2-labels.combined.tsv
117 | ...
118 | scene0009_00
119 | ...
120 | ```
121 |
122 | You may define the data subfolders differently by adjusting the `scannet` and `scannet_frames_25k` fields in [`cfg/env/env.yml`](./cfg/env/env.yml). You may also define several config files and select the configuration to use by setting the `ENV_WORKSTATION_NAME` environment variable before running the code (_e.g._, `export ENV_WORKSTATION_NAME="gpu_machine"` to use the config in `cfg/env/gpu_machine.yml`; see the sketch below).
123 |
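For instance, a per-machine configuration named `gpu_machine` (the example name used above) could be set up as follows:

```bash
# Create a machine-specific copy of the default environment config and select it.
cp cfg/env/env.yml cfg/env/gpu_machine.yml   # edit the scannet/scannet_frames_25k/results paths as needed
export ENV_WORKSTATION_NAME="gpu_machine"    # the code will then read cfg/env/gpu_machine.yml
```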
124 | - Copy the files [`split.npz`](./cfg/dataset/scannet/split.npz) and [`split_cl.npz`](./cfg/dataset/scannet/split_cl.npz) from the `${REPO_ROOT}/cfg/dataset/scannet/` folder to the `${REPO_ROOT}/data/scannet/scannet_frames_25k` folder. These files contain the sample indices that define the train/validation splits used in pre-training and the replay buffer used in continual adaptation, ensuring reproducibility.
125 |
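The per-scene extraction and copy steps listed above can also be scripted. The loop below is a minimal sketch, assuming that the ScanNet `reader.py` (SensReader) is available in the current directory and that `scannetv2-labels.combined.tsv` has been downloaded to `${REPO_ROOT}/data/scannet`:

```bash
# Sketch: extract sensor data and distribute the label-mapping file for scenes 0000 to 0009.
for i in $(seq 0 9); do
  scene="scene000${i}_00"
  scene_dir="${REPO_ROOT}/data/scannet/scans/${scene}"
  python2 reader.py --filename "${scene_dir}/${scene}.sens" --output_path "${scene_dir}" \
    --export_depth_images --export_color_images --export_poses --export_intrinsics
  cp "${REPO_ROOT}/data/scannet/scannetv2-labels.combined.tsv" "${scene_dir}/"
done
cp "${REPO_ROOT}/data/scannet/scannetv2-labels.combined.tsv" \
   "${REPO_ROOT}/data/scannet/scannet_frames_25k/"
```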
126 | #### Dataset pre-processing
127 |
128 | After organizing the ScanNet files as detailed above, run the following script to pre-process the files:
129 |
130 | ```bash
131 | bash run_scripts/preprocess_scannet.sh
132 | ```
133 |
134 | After pre-processing, the folder structure for each `sceneXXXX_XX` from `scene0000_00` to `scene0009_00` should look as follows:
135 |
136 | ```shell
137 | sceneXXXX_XX
138 | color
139 | 000000.jpg
140 | ...
141 | XXXXXX.jpg
142 | color_scaled
143 | 000000.jpg
144 | ...
145 | XXXXXX.jpg
146 | depth
147 | 000000.png
148 | ...
149 | XXXXXX.png
150 | label_40
151 | 000000.png
152 | ...
153 | XXXXXX.png
154 | label_40_scaled
155 | 000000.png
156 | ...
157 | XXXXXX.png
158 | label-filt
159 | 000000.png
160 | ...
161 | XXXXXX.png
162 | pose
163 | 000000.txt
164 | ...
165 | XXXXXX.txt
166 | intrinsics
167 | intriniscs_color.txt
168 | intrinsics_depth.txt
169 | scannetv2-labels.combined.tsv
170 | transforms_test.json
171 | transforms_test_scaled_semantics_40_raw.json
172 | transforms_train.json
173 | transforms_train_scaled_semantics_40_raw.json
174 | ```
175 |
176 | ## Running experiments
177 |
178 | By default, the data produced when running the code is stored in the `${REPO_ROOT}/experiments` folder. You can modify this by changing the `results` field in [`cfg/env/env.yml`](./cfg/env/env.yml).
179 |
180 | ### DeepLabv3 pre-training
181 |
182 | To pre-train the DeepLabv3 segmentation network on scenes `0010` to `0706`, run the following script:
183 |
184 | ```bash
185 | bash run_scripts/pretrain.sh --exp cfg/exp/pretrain_scannet_25k_deeplabv3.yml
186 | ```
187 |
188 | Alternatively, we provide a pre-trained DeepLabv3 [checkpoint](https://www.research-collection.ethz.ch/bitstream/handle/20.500.11850/637142/best-epoch143-step175536.ckpt), which you may download to the `${REPO_ROOT}/ckpts` folder.
189 |
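One possible way to fetch the checkpoint is shown below. Note that the experiment configs in `cfg/exp` load `ckpts/best-epoch=143-step=175536.ckpt` (relative to the repository root), so the file is saved under that name:

```bash
mkdir -p ${REPO_ROOT}/ckpts
wget -O "${REPO_ROOT}/ckpts/best-epoch=143-step=175536.ckpt" \
  https://www.research-collection.ethz.ch/bitstream/handle/20.500.11850/637142/best-epoch143-step175536.ckpt
```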
190 | ### One-step experiments
191 |
192 | This section contains instructions on how to perform one-step adaptation experiments (cf. Sec. 4.4 in the main paper).
193 |
194 | #### Fine-tuning
195 |
196 | For fine-tuning, NeRF pseudo-labels should first be generated by running NeRF-only training:
197 |
198 | ```bash
199 | bash run_scripts/one_step_nerf_only_train.sh
200 | ```
201 |
202 | Next, run
203 |
204 | ```bash
205 | bash run_scripts/one_step_finetune_train.sh
206 | ```
207 |
208 | to fine-tune DeepLabv3 with the NeRF pseudo-labels. Please make sure the variable `prev_exp_name` defined in the [fine-tuning script](./run_scripts/one_step_finetune_train.sh) matches the variable `name` in the [NeRF-only script](./run_scripts/one_step_nerf_only_train.sh).
209 |
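For instance, if the NeRF-only run uses the (purely illustrative) name below, the fine-tuning script must reference the same string:

```bash
# In run_scripts/one_step_nerf_only_train.sh (illustrative value, not the repository default):
name="one_step_nerf_only"
# In run_scripts/one_step_finetune_train.sh, this must match the name above:
prev_exp_name="one_step_nerf_only"
```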
210 | #### Joint-training
211 |
212 | To perform one-step joint training, run
213 |
214 | ```bash
215 | bash run_scripts/one_step_joint_train.sh
216 | ```
217 |
218 | ### Multi-step experiments
219 |
220 | To perform multi-step adaptation experiments (cf. Sec. 4.5 in the main paper), run the following commands:
221 |
222 | ```bash
223 | # Using training views for replay.
224 | bash run_scripts/multi_step.sh --exp cfg/exp/multi_step/cl_base.yml
225 | # Using novel views for "replay".
226 | bash run_scripts/multi_step.sh --exp cfg/exp/multi_step/cl_base_novel_viewpoints.yml
227 | ```
228 |
229 | ### Logging
230 |
231 | By default, we use [WandB](https://wandb.ai/site) to log our experiments. You can initialize WandB logging by running
232 |
233 | ```bash
234 | wandb init -e ${YOUR_WANDB_ENTITY}
235 | ```
236 | in the terminal. Alternatively, you can disable all logging by defining `export WANDB_MODE=disabled` before launching the experiments.
237 |
238 | ### Seeding
239 |
240 | To estimate the variance of the results, we run the above experiments multiple times with different seeds, specified through the `--seed` argument (see the example below).
241 |
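For example, assuming the run scripts forward additional command-line arguments to the underlying training scripts, a multi-step experiment could be repeated with three seeds as follows:

```bash
# Repeat the same multi-step experiment with three different seeds.
for seed in 123 234 345; do
  bash run_scripts/multi_step.sh --exp cfg/exp/multi_step/cl_base.yml --seed ${seed}
done
```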
242 | ## Citation
243 |
244 | If you find our code or paper useful, please cite:
245 |
246 | ```bibtex
247 | @inproceedings{Liu2023UnsupervisedContinualSemanticAdaptationNR,
248 | author = {Liu, Zhizheng and Milano, Francesco and Frey, Jonas and Siegwart, Roland and Blum, Hermann and Cadena, Cesar},
249 | title = {Unsupervised Continual Semantic Adaptation through Neural Rendering},
250 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
251 | year = {2023}
252 | }
253 | ```
254 |
255 | ## Acknowledgements
256 |
257 | Parts of the NeRF implementation are adapted from [torch-ngp](https://github.com/ashawkey/torch-ngp), [Semantic-NeRF](https://github.com/Harry-Zhi/semantic_nerf/), and [Instant-NGP](https://github.com/NVlabs/instant-ngp).
258 |
259 | ## Contact
260 |
261 | Contact [Zhizheng Liu](mailto:liuzhi@student.ethz.ch) and [Francesco Milano](mailto:francesco.milano@mavt.ethz.ch) for questions, comments, and reporting bugs.
262 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/assets/teaser.png
--------------------------------------------------------------------------------
/cfg/dataset/scannet/split.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/cfg/dataset/scannet/split.npz
--------------------------------------------------------------------------------
/cfg/dataset/scannet/split_cl.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/cfg/dataset/scannet/split_cl.npz
--------------------------------------------------------------------------------
/cfg/env/env.yml:
--------------------------------------------------------------------------------
1 | results: experiments
2 | scannet: data/scannet/scans
3 | scannet_frames_25k: data/scannet/scannet_frames_25k
4 |
--------------------------------------------------------------------------------
/cfg/exp/multi_step/cl_base.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0000_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 2
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 |
60 | cl:
61 | active: true
62 | 25k_fraction: 0.1
63 | ngp_25k_ratio: 1
64 | use_novel_viewpoints: False
65 | replay_buffer_size: 100
66 |
--------------------------------------------------------------------------------
/cfg/exp/multi_step/cl_base_novel_viewpoints.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0000_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 2
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 |
48 | scenes:
49 | - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 |
61 | cl:
62 | active: true
63 | 25k_fraction: 0.1
64 | ngp_25k_ratio: 1
65 | use_novel_viewpoints: True
66 | replay_buffer_size: 100
67 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s00_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0000_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s10_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0001_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s20_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0002_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s30_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0003_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s40_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0004_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s50_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0005_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s60_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0006_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s70_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0007_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | - scene0007_00
57 | # - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s80_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0008_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | - scene0008_00
58 | # - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_finetune_nerf/s90_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: finetune_nerf_train/scannet_scene0009_00
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr: 1.0e-5
16 | name: Adam
17 |
18 | trainer:
19 | max_epochs: 50
20 | gpus: -1
21 | num_sanity_val_steps: 0
22 | check_val_every_n_epoch: 1
23 | resume_from_checkpoint: False
24 | load_from_checkpoint: True
25 |
26 | data_module:
27 | pin_memory: true
28 | batch_size: 4
29 | shuffle: true
30 | num_workers: 2
31 | drop_last: false
32 | root: data/scannet/scannet_frames_25k
33 | train_image: nerf
34 | train_label: nerf
35 | data_preprocessing:
36 | val_ratio: 0.2
37 | image_regex: /*/color/*.jpg
38 | split_file: split.npz
39 | split_file_cl: split_cl.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
48 | scenes:
49 | # - scene0000_00
50 | # - scene0001_00
51 | # - scene0002_00
52 | # - scene0003_00
53 | # - scene0004_00
54 | # - scene0005_00
55 | # - scene0006_00
56 | # - scene0007_00
57 | # - scene0008_00
58 | - scene0009_00
59 |
60 | cl:
61 | active: false
62 | use_novel_viewpoints: False
63 | replay_buffer_size: 0
64 |
65 |
66 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s00_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0000_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s10_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0001_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s20_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0002_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s30_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0003_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s40_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0004_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s50_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0005_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
65 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s60_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0006_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s70_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0007_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | - scene0007_00
56 | # - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s80_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0008_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | - scene0008_00
57 | # - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
--------------------------------------------------------------------------------
/cfg/exp/one_step_joint/s90_lr1e-5.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: joint_train/scannet_scene0009_00_finetune
3 | clean_up_folder_if_exists: True
4 | checkpoint_load: "ckpts/best-epoch=143-step=175536.ckpt"
5 |
6 | model:
7 | pretrained: False
8 | pretrained_backbone: True
9 | num_classes: 40 # Scannet (40)
10 |
11 | lr_scheduler:
12 | active: false
13 |
14 | optimizer:
15 | lr_seg: 1.0e-5
16 | lr_nerf: 1.0e-2
17 | name: Adam
18 |
19 | trainer:
20 | max_epochs: 20
21 | gpus: -1
22 | num_sanity_val_steps: 0
23 | check_val_every_n_epoch: 1
24 | resume_from_checkpoint: False
25 | load_from_checkpoint: True
26 |
27 | data_module:
28 | pin_memory: true
29 | batch_size: 4
30 | shuffle: true
31 | num_workers: 0
32 | drop_last: true
33 | root: data/scannet/scannet_frames_25k
34 | data_preprocessing:
35 | val_ratio: 0.2
36 | image_regex: /*/color/*.jpg
37 | split_file: split.npz
38 | split_file_cl: split_cl.npz
39 |
40 | visualizer:
41 | store: true
42 | store_n:
43 | train: 3
44 | val: 3
45 | test: 3
46 |
47 | scenes:
48 | # - scene0000_00
49 | # - scene0001_00
50 | # - scene0002_00
51 | # - scene0003_00
52 | # - scene0004_00
53 | # - scene0005_00
54 | # - scene0006_00
55 | # - scene0007_00
56 | # - scene0008_00
57 | - scene0009_00
58 |
59 | cl:
60 | active: false
61 | use_novel_viewpoints: False
62 | replay_buffer_size: 0
63 |
64 |
--------------------------------------------------------------------------------
/cfg/exp/pretrain_scannet_25k_deeplabv3.yml:
--------------------------------------------------------------------------------
1 | general:
2 | name: scannet_25k_deeplab/pretrain
3 | clean_up_folder_if_exists: True
4 |
5 | model:
6 | pretrained: False
7 | pretrained_backbone: True
8 | num_classes: 40 # Scannet (40)
9 |
10 | lr_scheduler:
11 | active: true
12 | name: POLY
13 | poly_cfg:
14 | power: 0.9
15 | max_epochs: 150
16 | target_lr: 1.0e-06
17 |
18 | optimizer:
19 | lr: 0.0001
20 | name: Adam
21 |
22 | trainer:
23 | max_epochs: 150
24 | gpus: -1
25 | num_sanity_val_steps: 0
26 | check_val_every_n_epoch: 1
27 | resume_from_checkpoint: false
28 |
29 | data_module:
30 | pin_memory: true
31 | batch_size: 4
32 | shuffle: true
33 | num_workers: 2
34 | drop_last: false
35 | root: data/scannet/scannet_frames_25k
36 | data_preprocessing:
37 | val_ratio: 0.2
38 | image_regex: /*/color/*.jpg
39 | split_file: split.npz
40 |
41 | visualizer:
42 | store: true
43 | store_n:
44 | train: 3
45 | val: 3
46 | test: 3
47 |
--------------------------------------------------------------------------------
/nr4seg/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
4 |
5 | if not "ENV_WORKSTATION_NAME" in os.environ:
6 | os.environ["ENV_WORKSTATION_NAME"] = "env"
7 |
--------------------------------------------------------------------------------
/nr4seg/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .scannet import ScanNet
2 | from .scannet_cl import ScanNetCL
3 | from .scannet_cl_joint import ScanNetCLJoint
4 | from .scannet_ngp import ScanNetNGP
5 | from .scannet_ngp_joint import ScanNetNGPJoint
6 |
7 | __all__ = [
8 | "ScanNet",
9 | "ScanNetCL",
10 | "ScanNetCLJoint",
11 | "ScanNetNGP",
12 | "ScanNetNGPJoint",
13 | ]
14 |
--------------------------------------------------------------------------------
/nr4seg/dataset/create_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | import random
5 |
6 | from glob import glob
7 |
8 | from nr4seg.utils import load_yaml
9 |
10 | parser = argparse.ArgumentParser(description="Generates a data split file.")
11 |
12 | curr_dir_path = os.path.dirname(os.path.abspath(__file__))
13 | default_exp_path = os.path.join(
14 | curr_dir_path, "../../cfg/exp/pretrain_scannet_25k_deeplabv3.yml")
15 | parser.add_argument("--config",
16 | type=str,
17 | default=default_exp_path,
18 | help="Path to config file.")
19 |
20 | args = parser.parse_args()
21 |
22 | cfg = load_yaml(args.config)
23 | cfg = cfg["data_module"]
24 |
25 | train_all = glob(cfg["root"] + cfg["data_preprocessing"]["image_regex"])
26 | random.shuffle(train_all)
27 | val = train_all[:int(len(train_all) * cfg["data_preprocessing"]["val_ratio"])]
28 | train = train_all[int(len(train_all) * cfg["data_preprocessing"]["val_ratio"]):]
29 | test = val
30 | train, val, test = map(sorted, [train, val, test])
31 |
32 | split_cl = {"train_cl": train}
33 |
34 | split_dict = {"train": train, "test": test, "val": val}
35 | env_name = os.environ["ENV_WORKSTATION_NAME"]
36 | env = load_yaml(os.path.join("cfg/env", env_name + ".yml"))
37 | scannet_25k_dir = env["scannet_frames_25k"]
38 | out_file = os.path.join(scannet_25k_dir,
39 | cfg["data_preprocessing"]["split_file_cl"])
40 | np.savez(out_file, **split_cl)
41 |
--------------------------------------------------------------------------------
/nr4seg/dataset/helper.py:
--------------------------------------------------------------------------------
1 | import random
2 | import PIL
3 | import torch
4 | from torchvision import transforms as tf
5 | from torchvision.transforms import functional as F
6 | import warnings
7 |
8 | __all__ = ["Augmentation", "AugmentationList", "get_output_size"]
9 |
10 |
11 | def get_output_size(output_size):
12 | if type(output_size) == list or type(output_size) == tuple:
13 | if len(output_size) == 2:
14 | output_size = tuple(output_size)
15 | elif len(output_size) == 1:
16 | output_size = (output_size[0], output_size[0])
17 | elif type(output_size) == int:
18 | output_size = (output_size, output_size)
19 | return output_size
20 |
21 |
22 | class Augmentation:
23 |
24 | def __init__(
25 | self,
26 | output_size=400,
27 | degrees=10,
28 | flip_p=0.5,
29 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05],
30 | ):
31 |
32 | # training transforms
33 | output_size = get_output_size(output_size)
34 | self._output_size = output_size
35 | self._crop = tf.RandomCrop(self._output_size)
36 | with warnings.catch_warnings():
37 | # this will suppress all warnings in this block
38 | warnings.simplefilter("ignore")
39 | self._rot = tf.RandomRotation(degrees=degrees,
40 | resample=PIL.Image.BILINEAR)
41 | self._flip_p = flip_p
42 | self._degrees = degrees
43 |
44 | self._jitter = tf.ColorJitter(
45 | brightness=jitter_bcsh[0],
46 | contrast=jitter_bcsh[1],
47 | saturation=jitter_bcsh[2],
48 | hue=jitter_bcsh[3],
49 | )
50 | self._crop_center = tf.CenterCrop(self._output_size)
51 |
52 | def apply(self, img, label, only_crop=False):
53 | # `label` is a single label tensor here (AugmentationList below handles lists of labels).
54 |
55 | scale = False
56 | # Check if rescaling is necessary based on the image and output sizes
57 | if img.shape[1] >= 2 * self._output_size[0]:
58 | sf = float(self._output_size[0] / img.shape[1]) * 1.2
59 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2
60 | sf = max(sf, sf2)
61 |
62 | scale = True
63 | elif (img.shape[1] < self._output_size[0] or
64 | img.shape[2] < self._output_size[1]):
65 | sf1 = float(self._output_size[0] / img.shape[1]) * 1.2
66 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2
67 | sf = max(sf1, sf2)
68 | scale = True
69 |
70 | if scale:
71 | img = torch.nn.functional.interpolate(
72 | img[None],
73 | scale_factor=(sf, sf),
74 | mode="bilinear",
75 | recompute_scale_factor=False,
76 | align_corners=False,
77 | )[0]
78 | label = torch.nn.functional.interpolate(
79 | label[None],
80 | scale_factor=(sf, sf),
81 | mode="nearest",
82 | recompute_scale_factor=False,
83 | )[0]
84 | if not only_crop:
85 | # Color Jitter
86 | img = self._jitter(img)
87 |
88 | # Rotate
89 | angle = random.uniform(-self._degrees, self._degrees)
90 | with warnings.catch_warnings():
91 | # this will suppress all warnings in this block
92 | warnings.simplefilter("ignore")
93 | img = F.rotate(
94 | img,
95 | angle,
96 | resample=PIL.Image.BILINEAR,
97 | expand=False,
98 | center=None,
99 | fill=0,
100 | )
101 | label = F.rotate(
102 | label,
103 | angle,
104 | resample=PIL.Image.NEAREST,
105 | expand=False,
106 | center=None,
107 | fill=0,
108 | )
109 |
110 | # Crop
111 | i, j, h, w = self._crop.get_params(img, self._output_size)
112 | img = F.crop(img, i, j, h, w)
113 | label = F.crop(label, i, j, h, w)
114 |
115 | # Flip
116 | if torch.rand(1) < self._flip_p:
117 | img = F.hflip(img)
118 | label = F.hflip(label)
119 |
120 |         # Performs center crop
121 | img = self._crop_center(img)
122 | label = self._crop_center(label)
123 |
124 | return img, label
125 |
126 |
127 | class AugmentationList:
128 |
129 | def __init__(
130 | self,
131 | output_size=400,
132 | degrees=10,
133 | flip_p=0.5,
134 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05],
135 | ):
136 |
137 | # training transforms
138 | output_size = get_output_size(output_size)
139 | self._output_size = output_size
140 | self._crop = tf.RandomCrop(self._output_size)
141 | with warnings.catch_warnings():
142 | # this will suppress all warnings in this block
143 | warnings.simplefilter("ignore")
144 | self._rot = tf.RandomRotation(degrees=degrees,
145 | resample=PIL.Image.BILINEAR)
146 | self._flip_p = flip_p
147 | self._degrees = degrees
148 |
149 | self._jitter = tf.ColorJitter(
150 | brightness=jitter_bcsh[0],
151 | contrast=jitter_bcsh[1],
152 | saturation=jitter_bcsh[2],
153 | hue=jitter_bcsh[3],
154 | )
155 | self._crop_center = tf.CenterCrop(self._output_size)
156 |
157 | def apply(self, img, label, only_crop=False):
158 | scale = False
159 |         # Check if rescaling is necessary based on image height and output height
160 | if img.shape[1] >= 2 * self._output_size[0]:
161 | sf = float(self._output_size[0] / img.shape[1]) * 1.2
162 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2
163 | sf = max(sf, sf2)
164 |
165 | scale = True
166 | elif (img.shape[1] < self._output_size[0] or
167 | img.shape[2] < self._output_size[1]):
168 | sf1 = float(self._output_size[0] / img.shape[1]) * 1.2
169 | sf2 = float(self._output_size[1] / img.shape[2]) * 1.2
170 | sf = max(sf1, sf2)
171 | scale = True
172 |
173 | if scale:
174 | img = torch.nn.functional.interpolate(
175 | img[None],
176 | scale_factor=(sf, sf),
177 | mode="bilinear",
178 | recompute_scale_factor=False,
179 | align_corners=False,
180 | )[0]
181 | for _i, l in enumerate(label):
182 | label[_i] = torch.nn.functional.interpolate(
183 | l[None],
184 | scale_factor=(sf, sf),
185 | mode="nearest",
186 | recompute_scale_factor=False,
187 | )[0]
188 | if not only_crop:
189 | # Color Jitter
190 | img = self._jitter(img)
191 |
192 | # Rotate
193 | angle = random.uniform(-self._degrees, self._degrees)
194 | with warnings.catch_warnings():
195 | # this will suppress all warnings in this block
196 | warnings.simplefilter("ignore")
197 | img = F.rotate(
198 | img,
199 | angle,
200 | resample=PIL.Image.BILINEAR,
201 | expand=False,
202 | center=None,
203 | fill=0,
204 | )
205 | for _i, l in enumerate(label):
206 | label[_i] = F.rotate(
207 | l,
208 | angle,
209 | resample=PIL.Image.NEAREST,
210 | expand=False,
211 | center=None,
212 | fill=0,
213 | )
214 |
215 | # Crop
216 | i, j, h, w = self._crop.get_params(img, self._output_size)
217 | img = F.crop(img, i, j, h, w)
218 | for _i, l in enumerate(label):
219 | label[_i] = F.crop(l, i, j, h, w)
220 |
221 | # Flip
222 | if torch.rand(1) < self._flip_p:
223 | img = F.hflip(img)
224 | for _i, l in enumerate(label):
225 | label[_i] = F.hflip(l)
226 |
227 |         # Performs center crop
228 | img = self._crop_center(img)
229 | for _i, l in enumerate(label):
230 | label[_i] = self._crop_center(l)
231 |
232 | return img, label
233 |
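A minimal usage sketch for the augmentation classes above; the random tensors are stand-ins for a real RGB image and its label map:

    import torch
    from nr4seg.dataset.helper import Augmentation

    img = torch.rand(3, 480, 640)                         # C x H x W in [0, 1]
    label = torch.randint(0, 41, (1, 480, 640)).float()   # 1 x H x W class ids

    aug = Augmentation(output_size=(240, 320))
    img_aug, label_aug = aug.apply(img, label)                   # full augmentation
    img_val, label_val = aug.apply(img, label, only_crop=True)   # resize + crop only
    print(img_aug.shape, label_aug.shape)  # [3, 240, 320] and [1, 240, 320]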
--------------------------------------------------------------------------------
/nr4seg/dataset/label_loader.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import numpy as np
3 | import os
4 | import pandas
5 | import torch
6 |
7 | __all__ = ["LabelLoaderAuto"]
8 |
9 |
10 | class LabelLoaderAuto:
11 |
12 | def __init__(self, root_scannet=None, confidence=0, H=968, W=1296):
13 | assert root_scannet is not None
14 | self._get_mapping(root_scannet)
15 | self._confidence = confidence
16 | # return label between 0-40
17 |
18 | self.max_classes = 40
19 | self.label = np.zeros((H, W, self.max_classes))
20 | iu16 = np.iinfo(np.uint16)
21 | mask = np.full((H, W), iu16.max, dtype=np.uint16)
22 | self.mask_low = np.right_shift(mask, 6, dtype=np.uint16)
23 |
24 | def get(self, path):
25 | img = imageio.imread(path)
26 | if len(img.shape) == 3:
27 | if img.shape[2] == 4:
28 | H, W, _ = img.shape
29 | self.label = np.zeros((H, W, self.max_classes))
30 | for i in range(3):
31 | prob = np.bitwise_and(img[:, :, i], self.mask_low) / 1023
32 | cls = np.right_shift(img[:, :, i], 10, dtype=np.uint16)
33 | m = np.eye(self.max_classes)[cls] == 1
34 | self.label[m] = prob.reshape(-1)
35 | m = np.max(self.label, axis=2) < self._confidence
36 | self.label = np.argmax(self.label, axis=2).astype(np.int32) + 1
37 | self.label[m] = 0
38 | method = "RGBA"
39 | else:
40 |                 raise Exception("Type not known")
41 | elif len(img.shape) == 2 and img.dtype == np.uint8:
42 | self.label = img.astype(np.int32)
43 | method = "FAST"
44 | elif len(img.shape) == 2 and img.dtype == np.uint16:
45 | self.label = torch.from_numpy(img.astype(np.int32)).type(
46 | torch.float32)[None, :, :]
47 | sa = self.label.shape
48 | self.label = self.label.flatten()
49 | self.label = self.mapping[self.label.type(torch.int64)]
50 | self.label = self.label.reshape(sa).numpy().astype(np.int32)[0]
51 | method = "MAPPED"
52 | else:
53 |             raise Exception("Type not known")
54 | return self.label, method
55 |
56 | def get_probs(self, path):
57 | img = imageio.imread(path)
58 | assert len(img.shape) == 3
59 | assert img.shape[2] == 4
60 | H, W, _ = img.shape
61 | probs = np.zeros((H, W, self.max_classes))
62 | for i in range(3):
63 | prob = np.bitwise_and(img[:, :, i], self.mask_low) / 1023
64 | cls = np.right_shift(img[:, :, i], 10, dtype=np.uint16)
65 | m = np.eye(self.max_classes)[cls] == 1
66 | probs[m] = prob.reshape(-1)
67 |
68 | return probs
69 |
70 | def _get_mapping(self, root):
71 | tsv = os.path.join(root, "scannetv2-labels.combined.tsv")
72 | df = pandas.read_csv(tsv, sep="\t")
73 | mapping_source = np.array(df["id"])
74 | mapping_target = np.array(df["nyu40id"])
75 |
76 | self.mapping = torch.zeros((int(mapping_source.max() + 1)),
77 | dtype=torch.int64)
78 | for so, ta in zip(mapping_source, mapping_target):
79 | self.mapping[so] = ta
80 |
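A hedged usage sketch for `LabelLoaderAuto`; the ScanNet root and label path below are hypothetical placeholders, and the root must contain `scannetv2-labels.combined.tsv`:

    from nr4seg.dataset.label_loader import LabelLoaderAuto

    loader = LabelLoaderAuto(root_scannet="/data/scannet", confidence=0.5)
    label, method = loader.get("/data/scannet/scene0000_00/label/0.png")
    # `label` is an (H, W) int32 array of NYU40 ids; `method` reports which
    # encoding was detected ("RGBA", "FAST" or "MAPPED").
    print(label.shape, label.min(), label.max(), method)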
--------------------------------------------------------------------------------
/nr4seg/dataset/ngp_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from packaging import version as pver
4 |
5 |
6 | # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
7 | def nerf_matrix_to_ngp(pose):
8 | new_pose = np.array(
9 | [
10 | [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3]],
11 | [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3]],
12 | [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3]],
13 | [0, 0, 0, 1],
14 | ],
15 | dtype=np.float32,
16 | )
17 | return new_pose
18 |
19 |
20 | def custom_meshgrid(*args):
21 | # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
22 | if pver.parse(torch.__version__) < pver.parse("1.10"):
23 | return torch.meshgrid(*args)
24 | else:
25 | return torch.meshgrid(*args, indexing="ij")
26 |
27 |
28 | @torch.cuda.amp.autocast(enabled=False)
29 | def get_rays(poses, intrinsics, H, W, error_map=None):
30 |     """get rays
31 |     Args:
32 |         poses: [B, 4, 4], cam2world
33 |         intrinsics: [4], (fx, fy, cx, cy)
34 |         H, W: int; N = H * W rays are generated per pose
35 |         error_map: [B, 128 * 128], sample probability based on training error
36 |             (currently unused in this implementation)
37 |     Returns:
38 |         rays_o, rays_d: [B, N, 3]
39 |         direction_norms: [B, N, 1]
40 |     """
41 |
42 | device = poses.device
43 | B = poses.shape[0]
44 | fx, fy, cx, cy = intrinsics
45 |
46 | i, j = custom_meshgrid(
47 | torch.linspace(0, W - 1, W, device=device),
48 | torch.linspace(0, H - 1, H, device=device),
49 | )
50 | i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
51 | j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
52 |
53 | results = {}
54 | zs = torch.ones_like(i)
55 | xs = (i - cx) / fx * zs
56 | ys = (j - cy) / fy * zs
57 | directions = torch.stack((xs, ys, zs), dim=-1)
58 | direction_norms = torch.norm(directions, dim=-1, keepdim=True)
59 | directions = directions / direction_norms
60 | rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
61 |
62 | rays_o = poses[..., :3, 3] # [B, 3]
63 | rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
64 |
65 | results["rays_o"] = rays_o
66 | results["rays_d"] = rays_d
67 | results["direction_norms"] = direction_norms
68 |
69 | return results
70 |
71 |
72 | # color palette for nyu40 labels
73 | nyu40_colour_code = np.array([
74 | (0, 0, 0),
75 | (174, 199, 232), # wall
76 | (152, 223, 138), # floor
77 | (31, 119, 180), # cabinet
78 | (255, 187, 120), # bed
79 | (188, 189, 34), # chair
80 | (140, 86, 75), # sofa
81 | (255, 152, 150), # table
82 | (214, 39, 40), # door
83 | (197, 176, 213), # window
84 | (148, 103, 189), # bookshelf
85 | (196, 156, 148), # picture
86 | (23, 190, 207), # counter
87 | (178, 76, 76), # blinds
88 | (247, 182, 210), # desk
89 | (66, 188, 102), # shelves
90 | (219, 219, 141), # curtain
91 | (140, 57, 197), # dresser
92 | (202, 185, 52), # pillow
93 | (51, 176, 203), # mirror
94 |     (200, 54, 131),  # floor mat
95 | (92, 193, 61), # clothes
96 | (78, 71, 183), # ceiling
97 | (172, 114, 82), # books
98 | (255, 127, 14), # refrigerator
99 | (91, 163, 138), # tv
100 | (153, 98, 156), # paper
101 | (140, 153, 101), # towel
102 | (158, 218, 229), # shower curtain
103 | (100, 125, 154), # box
104 | (178, 127, 135), # white board
105 | (120, 185, 128), # person
106 | (146, 111, 194), # night stand
107 | (44, 160, 44), # toilet
108 | (112, 128, 144), # sink
109 | (96, 207, 209), # lamp
110 | (227, 119, 194), # bathtub
111 | (213, 92, 176), # bag
112 | (94, 106, 211), # other struct
113 | (82, 84, 163), # otherfurn
114 | (100, 85, 144), # other prop
115 | ]).astype(np.uint8)
116 |
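A minimal sketch of calling `get_rays`, with an identity camera pose and made-up pinhole intrinsics purely for shape checking:

    import torch
    from nr4seg.dataset.ngp_utils import get_rays

    poses = torch.eye(4)[None]                 # [B=1, 4, 4] cam2world
    intrinsics = (320.0, 320.0, 160.0, 120.0)  # fx, fy, cx, cy
    rays = get_rays(poses, intrinsics, H=240, W=320)
    # one ray per pixel: N = H * W = 76800
    print(rays["rays_o"].shape, rays["rays_d"].shape)  # [1, 76800, 3] each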
--------------------------------------------------------------------------------
/nr4seg/dataset/scannet.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import numpy as np
3 | import os
4 | import random
5 | import torch
6 |
7 | from torch.utils.data import Dataset
8 |
9 | try:
10 | from .helper import AugmentationList
11 | except Exception:
12 | from helper import AugmentationList
13 |
14 | from .label_loader import LabelLoaderAuto
15 |
16 | __all__ = ["ScanNet"]
17 |
18 |
19 | class ScanNet(Dataset):
20 |
21 | def __init__(
22 | self,
23 | root,
24 | img_list,
25 | mode="train",
26 | output_trafo=None,
27 | output_size=(240, 320),
28 | degrees=10,
29 | flip_p=0.5,
30 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05],
31 | sub=10,
32 | data_augmentation=True,
33 | label_setting="default",
34 | confidence_aux=0,
35 | ):
36 |         """
37 |         The dataset doesn't know whether it contains replayed or normal samples!
38 | 
39 |         Some images are stored at 640x480, others at 1296x968.
40 |         Warning: scene0088_03 has the wrong resolution -> ignored.
41 |         Parameters
42 |         ----------
43 |         root : str, path to the ScanNet data folder
44 |         mode : str, one of ['train', 'val']
45 |         """
46 |
47 | super(ScanNet, self).__init__()
48 | self._sub = sub
49 | self._mode = mode
50 | self._confidence_aux = confidence_aux
51 |
52 | self._label_setting = label_setting
53 | self.image_pths = img_list
54 | self.label_pths = [
55 | p.replace("color", "label").replace("jpg", "png") for p in img_list
56 | ]
57 | self.length = len(self.image_pths)
58 | self.aux_labels = False
59 | self._augmenter = AugmentationList(output_size, degrees, flip_p,
60 | jitter_bcsh)
61 | self._output_trafo = output_trafo
62 | self._data_augmentation = data_augmentation
63 | self.unique = False
64 |
65 | self.aux_labels_fake = False
66 |
67 | self._label_loader = LabelLoaderAuto(root_scannet=root,
68 | confidence=self._confidence_aux)
69 | if self.aux_labels:
70 | self._preprocessing_hack()
71 |
72 | def set_aux_labels_fake(self, flag=True):
73 | self.aux_labels_fake = flag
74 | self.aux_labels = flag
75 |
76 | def __getitem__(self, index):
77 | # Read Image and Label
78 | label, _ = self._label_loader.get(self.label_pths[index])
79 | label = torch.from_numpy(label).type(
80 | torch.float32)[None, :, :] # C H W -> contains 0-40
81 | label = [label]
82 | if self.aux_labels and not self.aux_labels_fake:
83 | _p = self.aux_label_pths[index]
84 | if os.path.isfile(_p):
85 | aux_label, _ = self._label_loader.get(_p)
86 | aux_label = torch.from_numpy(aux_label).type(
87 | torch.float32)[None, :, :]
88 | label.append(aux_label)
89 | else:
90 | if _p.find("_.png") != -1:
91 | print(_p)
92 |                     print("Processed aux label not found, falling back to the raw file")
93 | _p = _p.replace("_.png", ".png")
94 | aux_label, _ = self._label_loader.get(_p)
95 | aux_label = torch.from_numpy(aux_label).type(
96 | torch.float32)[None, :, :]
97 | label.append(aux_label)
98 |
99 | img = imageio.imread(self.image_pths[index])
100 | img = (torch.from_numpy(img).type(torch.float32).permute(2, 0, 1) / 255
101 | ) # C H W range 0-1
102 |
103 | if self._mode.find("train") != -1 and self._data_augmentation:
104 | img, label = self._augmenter.apply(img, label)
105 | else:
106 | img, label = self._augmenter.apply(img, label, only_crop=True)
107 |
108 | img_ori = img.clone()
109 | if self._output_trafo is not None:
110 | img = self._output_trafo(img)
111 |
112 | for k in range(len(label)):
113 |             label[k] = label[k] - 1  # 0 == wall, 39 == other prop, -1 == invalid
114 |
115 | # REJECT LABEL
116 | if (label[0] != -1).sum() < 10:
117 | idx = random.randint(0, len(self) - 1)
118 | if not self.unique:
119 | return self[idx]
120 | else:
121 | return False
122 |
123 | ret = (img, label[0].type(torch.int64)[0, :, :])
124 | if self.aux_labels:
125 | if self.aux_labels_fake:
126 | ret += (
127 | label[0].type(torch.int64)[0, :, :],
128 | torch.tensor(False),
129 | )
130 | else:
131 | ret += (
132 | label[1].type(torch.int64)[0, :, :],
133 | torch.tensor(True),
134 | )
135 |
136 | ret += (img_ori,)
137 | return ret
138 |
139 | def __len__(self):
140 | return self.length
141 |
142 | def __str__(self):
143 | string = "=" * 90
144 | string += "\nScannet Dataset: \n"
145 | length = len(self)
146 | string += f" Total Samples: {length}"
147 | string += f" » Mode: {self._mode} \n"
148 |         string += f"  Replay: {getattr(self, 'replay', False)}"
149 |         string += f"  » DataAug: {self._data_augmentation}"
150 |         string += (f"  » DataAug Replay: "
151 |                    f"{getattr(self, '_data_augmentation_for_replay', False)}\n")
152 | string += "=" * 90
153 | return string
154 |
155 | def _preprocessing_hack(self, force=False):
156 | """
157 | If training with aux_labels ->
158 | generates label for fast loading with a fixed certainty.
159 | """
160 |
161 | # check if this has already been performed
162 | aux_label, method = self._label_loader.get(
163 | self.aux_label_pths[self.global_to_local_idx[0]])
164 |         print("Method ", method)
165 | print(
166 | "self.global_to_local_idx[0] ",
167 | self.global_to_local_idx[0],
168 | self.aux_label_pths[self.global_to_local_idx[0]],
169 | )
170 | if method == "RGBA":
171 |
172 | # This should always evaluate to true
173 | if (self.aux_label_pths[self.global_to_local_idx[0]].find("_.png")
174 | == -1):
175 | print(
176 | "self.aux_label_pths[self.global_to_local_idx[0]]",
177 | self.aux_label_pths[self.global_to_local_idx[0]],
178 | self.global_to_local_idx[0],
179 | )
180 | if (os.path.isfile(self.aux_label_pths[
181 | self.global_to_local_idx[0]].replace(".png", "_.png"))
182 | and os.path.isfile(self.aux_label_pths[
183 | self.global_to_local_idx[-1]].replace(
184 | ".png", "_.png")) and not force):
185 | # only perform simple renaming
186 |                 print("Only do renaming")
187 | self.aux_label_pths = [
188 | a.replace(".png", "_.png") for a in self.aux_label_pths
189 | ]
190 | else:
191 | print("Start multithread preprocessing of images")
192 |
193 | def parallel(gtli, aux_label_pths, label_loader):
194 | print("Start take care of: ", gtli[0], " - ", gtli[-1])
195 | for i in gtli:
196 | aux_label, method = label_loader.get(
197 | aux_label_pths[i])
198 | imageio.imwrite(
199 | aux_label_pths[i].replace(".png", "_.png"),
200 | np.uint8(aux_label),
201 | )
202 |
203 | def parallel2(aux_pths, label_loader):
204 | for a in aux_pths:
205 | aux_label, method = label_loader.get(a)
206 | imageio.imwrite(
207 | a.replace(".png", "_.png"),
208 | np.uint8(aux_label),
209 | )
210 |
211 | cores = 16
212 | tasks = [
213 | t.tolist() for t in np.array_split(
214 | np.array(self.global_to_local_idx), cores)
215 | ]
216 |
217 | from multiprocessing import Process
218 |
219 |                 processes = []
220 |                 for i in range(cores):
221 |                     p = Process(
222 |                         target=parallel2,
223 |                         args=(
224 |                             np.array(self.aux_label_pths)[np.array(
225 |                                 tasks[i])].tolist(),
226 |                             self._label_loader,
227 |                         ),
228 |                     )
229 |                     p.start()
230 |                     processes.append(p)
231 |                 # Join only after all workers have started so that they run
232 |                 # in parallel rather than one after another.
233 |                 for p in processes:
234 |                     p.join()
235 |                 print("Done multithread preprocessing of images")
236 |                 self.aux_label_pths = [
237 |                     a.replace(".png", "_.png") for a in self.aux_label_pths
238 |                 ]
239 |
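A hedged construction sketch for `ScanNet`; the paths are hypothetical, and `split.npz` is the split file produced by `create_split.py`:

    import numpy as np
    from nr4seg.dataset.scannet import ScanNet

    split = np.load("/data/scannet_frames_25k/split.npz")
    dataset = ScanNet(root="/data/scannet_frames_25k",
                      img_list=split["train"],
                      mode="train")
    img, label, img_ori = dataset[0]  # augmented image, int64 label map, raw copy
    print(img.shape, label.shape, img_ori.shape)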
--------------------------------------------------------------------------------
/nr4seg/dataset/scannet_cl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 |
5 | from glob import glob
6 | from torch.utils.data import Dataset
7 |
8 | __all__ = ["ScanNetCL"]
9 |
10 |
11 | class ScanNetCL(Dataset):
12 |
13 | def __init__(
14 | self,
15 | scannet_25k,
16 | scannet_ngp,
17 | ngp_25k_ratio=1,
18 | ):
19 |         """
20 |         Pairs each `ScanNetNGP` sample with `ngp_25k_ratio` randomly drawn
21 |         replay samples from the ScanNet-25k dataset.
22 | 
23 |         Parameters
24 |         ----------
25 |         scannet_25k : ScanNet, replay dataset built from scannet_frames_25k
26 |         scannet_ngp : ScanNetNGP, dataset covering the newly adapted scenes
27 |         ngp_25k_ratio : int, number of replay samples drawn per NGP sample
28 |         """
29 |
30 | super(ScanNetCL, self).__init__()
31 | self.scannet_25k = scannet_25k
32 | self.scannet_ngp = scannet_ngp
33 | self.ngp_25k_ratio = ngp_25k_ratio
34 |
35 | def get_image_pths(self, scene_list, val_ratio=0.2):
36 | img_list = []
37 | for scene_name in scene_list:
38 | all_imgs = glob(self.root + "/" + scene_name + "/color/*jpg")
39 | all_imgs = sorted(all_imgs,
40 | key=lambda x: int(os.path.basename(x)[:-4]))
41 | val_imgs = all_imgs[-int(len(all_imgs) * val_ratio):]
42 | train_imgs = all_imgs[:-int(len(all_imgs) * val_ratio)]
43 | if self._mode == "train":
44 | img_list.extend(train_imgs[::self._sub])
45 | else:
46 | img_list.extend(val_imgs[::self._sub])
47 |
48 | return img_list
49 |
50 | def __getitem__(self, index):
51 | # Read Image and Label
52 | ret_ngp = self.scannet_ngp.__getitem__(index)
53 | ret_25k = []
54 |
55 | for _ in range(self.ngp_25k_ratio):
56 | rand_id = random.randint(0, self.scannet_25k.__len__() - 1)
57 | ret_25k.append(self.scannet_25k.__getitem__(rand_id))
58 |
59 | return (ret_ngp, ret_25k)
60 |
61 | @staticmethod
62 | def collate(batch):
63 | img = []
64 | label = []
65 | img_ori = []
66 |
67 | for bb in batch:
68 | img.append(bb[0][0])
69 | label.append(bb[0][1])
70 | img_ori.append(bb[0][2])
71 | for bbb in bb[1]:
72 | img.append(bbb[0])
73 | label.append(bbb[1])
74 | img_ori.append(bbb[2])
75 |
76 | img = torch.stack(img, dim=0)
77 | label = torch.stack(label, dim=0)
78 | img_ori = torch.stack(img_ori, dim=0)
79 | return img, label, img_ori
80 |
81 | def __len__(self):
82 | return self.scannet_ngp.__len__()
83 |
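A hedged sketch of how `ScanNetCL` is meant to be consumed; it assumes `scannet_25k` and `scannet_ngp` are already-constructed `ScanNet` and `ScanNetNGP` instances, as in `FineTuneDataModule`:

    from torch.utils.data import DataLoader
    from nr4seg.dataset import ScanNetCL

    dataset = ScanNetCL(scannet_25k, scannet_ngp, ngp_25k_ratio=1)
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=ScanNetCL.collate)
    img, label, img_ori = next(iter(loader))
    # each batch stacks 4 NGP samples plus 4 randomly replayed 25k samples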
--------------------------------------------------------------------------------
/nr4seg/dataset/scannet_cl_joint.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.utils.data import Dataset
4 |
5 | __all__ = ["ScanNetCLJoint"]
6 |
7 |
8 | class ScanNetCLJoint(Dataset):
9 |
10 | def __init__(
11 | self,
12 | scannet_25k,
13 | scannet_ngp,
14 | ngp_25k_ratio=1,
15 | ):
16 |         """
17 |         Joint-training variant of the replay wrapper: pairs each
18 |         `ScanNetNGPJoint` sample dict with `ngp_25k_ratio` randomly drawn
19 |         replay samples from the ScanNet-25k dataset.
20 | 
21 |         Parameters
22 |         ----------
23 |         scannet_25k : ScanNet, replay dataset built from scannet_frames_25k
24 |         scannet_ngp : ScanNetNGPJoint, dataset covering the adapted scenes
25 |         """
26 |
27 | super(ScanNetCLJoint, self).__init__()
28 | self.scannet_25k = scannet_25k
29 | self.scannet_ngp = scannet_ngp
30 | self.ngp_25k_ratio = ngp_25k_ratio
31 |
32 | def __getitem__(self, index):
33 | # Read Image and Label
34 | ret_dict = self.scannet_ngp.__getitem__(index)
35 | ret_25k = {"replay_img": [], "replay_label": []}
36 |
37 | for _ in range(self.ngp_25k_ratio):
38 | rand_id = random.randint(0, self.scannet_25k.__len__() - 1)
39 | img, label, _ = self.scannet_25k.__getitem__(rand_id)
40 | ret_25k["replay_img"].append(img)
41 | ret_25k["replay_label"].append(label)
42 |
43 | for key in ret_25k.keys():
44 | ret_25k[key] = torch.stack(ret_25k[key], dim=0)
45 |
46 | ret_dict.update(ret_25k)
47 | return ret_dict
48 |
49 | @staticmethod
50 | def collate(batch):
51 | img = []
52 | label = []
53 | img_ori = []
54 |
55 | for bb in batch:
56 | img.append(bb[0][0])
57 | label.append(bb[0][1])
58 | img_ori.append(bb[0][2])
59 | for bbb in bb[1]:
60 | img.append(bbb[0])
61 | label.append(bbb[1])
62 | img_ori.append(bbb[2])
63 |
64 | img = torch.stack(img, dim=0)
65 | label = torch.stack(label, dim=0)
66 | img_ori = torch.stack(img_ori, dim=0)
67 |         return img, label, img_ori
68 |
69 | def __len__(self):
70 | return self.scannet_ngp.__len__()
71 |
--------------------------------------------------------------------------------
/nr4seg/dataset/scannet_ngp.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import random
4 | import torch
5 |
6 | from glob import glob
7 | from torch.utils.data import Dataset
8 |
9 | try:
10 | from .helper import AugmentationList
11 | except Exception:
12 | from helper import AugmentationList
13 |
14 | __all__ = ["ScanNetNGP"]
15 |
16 |
17 | class ScanNetNGP(Dataset):
18 |
19 | def __init__(
20 | self,
21 | root,
22 | scene_list,
23 | prev_exp_name="one_step_nerf_only",
24 | mode="train",
25 | train_image="nerf",
26 | train_label="nerf",
27 | val_mode="gtgt",
28 | output_trafo=None,
29 | output_size=(240, 320),
30 | degrees=10,
31 | flip_p=0.5,
32 | jitter_bcsh=[0.3, 0.3, 0.3, 0.05],
33 | sub=1,
34 | data_augmentation=True,
35 | label_setting="default",
36 | confidence_aux=0,
37 | ):
38 |         """
39 |         The dataset doesn't know whether it contains replayed or normal samples!
40 | 
41 |         Some images are stored at 640x480, others at 1296x968.
42 |         Warning: scene0088_03 has the wrong resolution -> ignored.
43 |         Parameters
44 |         ----------
45 |         root : str, path to the ScanNet scenes folder
46 |         mode : str, one of ['train', 'val']
47 |         """
48 |
49 | super(ScanNetNGP, self).__init__()
50 | self._sub = sub
51 | self._mode = mode
52 | self._confidence_aux = confidence_aux
53 |
54 | self.H = output_size[0]
55 | self.W = output_size[1]
56 |
57 | self._label_setting = label_setting
58 | self.root = root
59 | self.image_pths, self.img_num = self.get_image_pths(scene_list)
60 |
61 | self.image_gt_pths = self.image_pths
62 | self.image_nerf_pths = [
63 | p.replace("color_scaled",
64 | prev_exp_name + "/nerf_image").replace("jpg", "png")
65 | for p in self.image_pths
66 | ]
67 | self.label_nerf_pths = [
68 | p.replace("color_scaled",
69 | prev_exp_name + "/nerf_label").replace("jpg", "png")
70 | for p in self.image_pths
71 | ]
72 | self.label_mapping_pths = [
73 | p.replace("color_scaled", "mapping_label").replace("jpg", "png")
74 | for p in self.image_pths
75 | ]
76 | self.label_gt_pths = [
77 | p.replace("color_scaled", "label_scaled").replace("jpg", "png")
78 | for p in self.image_pths
79 | ]
80 |
81 | self.length = len(self.image_pths)
82 | self._augmenter = AugmentationList(output_size, degrees, flip_p,
83 | jitter_bcsh)
84 | self._output_trafo = output_trafo
85 | self._data_augmentation = data_augmentation
86 | self.train_image = train_image
87 | self.train_label = train_label
88 | self.val_mode = val_mode
89 |
90 | def get_image_pths(self, scene_list, val_ratio=0.2):
91 | img_list = []
92 | img_num = []
93 | for scene_name in scene_list:
94 | all_imgs = glob(self.root + "/" + scene_name + "/color_scaled/*jpg")
95 | all_imgs = sorted(all_imgs,
96 | key=lambda x: int(os.path.basename(x)[:-4]))
97 | val_imgs = all_imgs[-int(len(all_imgs) * val_ratio):]
98 | # val_imgs = val_imgs[: len(val_imgs)//4*4]
99 | train_imgs = all_imgs[:-int(len(all_imgs) * val_ratio)]
100 | if self._mode == "train":
101 | img_list.extend(train_imgs[::self._sub])
102 | img_num.append(len(train_imgs[::self._sub]))
103 | else:
104 | img_list.extend(val_imgs[::self._sub])
105 |
106 | return img_list, img_num
107 |
108 | def __getitem__(self, index):
109 | # Read Image and Label
110 |
111 | use_nerf_label = False
112 | if self._mode == "train":
113 | if self.train_image == "gt":
114 | img = cv2.imread(self.image_gt_pths[index],
115 | cv2.IMREAD_UNCHANGED)
116 | elif self.train_image == "nerf":
117 | img = cv2.imread(self.image_nerf_pths[index],
118 | cv2.IMREAD_UNCHANGED)
119 | elif self.train_image == "half":
120 | img = (cv2.imread(self.image_gt_pths[index],
121 | cv2.IMREAD_UNCHANGED)
122 | if random.random() > 0.5 else cv2.imread(
123 | self.image_nerf_pths[index], cv2.IMREAD_UNCHANGED))
124 | else:
125 | raise NotImplementedError
126 | else:
127 | if self.val_mode == "gtgt":
128 | img = cv2.imread(self.image_gt_pths[index],
129 | cv2.IMREAD_UNCHANGED)
130 | elif self.val_mode in ["nerfgt", "nerfnerf"]:
131 | img = cv2.imread(self.image_nerf_pths[index],
132 | cv2.IMREAD_UNCHANGED)
133 | else:
134 | raise NotImplementedError
135 |
136 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
137 | img = img / 255.0
138 | img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
139 | img = (torch.from_numpy(img).type(torch.float32).permute(2, 0, 1)
140 | ) # C H W range 0-1
141 |
142 | if self._mode == "train":
143 | if self.train_label == "nerf":
144 | label = cv2.imread(self.label_nerf_pths[index],
145 | cv2.IMREAD_UNCHANGED)
146 | else:
147 | label = cv2.imread(self.label_mapping_pths[index],
148 | cv2.IMREAD_UNCHANGED)
149 | use_nerf_label = self.train_label == "nerf"
150 | else:
151 | if self.val_mode in ["gtgt", "nerfgt"]:
152 | label = cv2.imread(self.label_gt_pths[index],
153 | cv2.IMREAD_UNCHANGED)
154 | elif self.val_mode == "nerfnerf":
155 | label = cv2.imread(self.label_nerf_pths[index],
156 | cv2.IMREAD_UNCHANGED)
157 | use_nerf_label = True
158 |
159 | label = cv2.resize(label, (self.W, self.H),
160 | interpolation=cv2.INTER_NEAREST)
161 | label = torch.from_numpy(label).type(
162 | torch.float32)[None, :, :] # C H W -> contains 0-40
163 |
164 |         # NeRF labels are stored one class lower than GT labels, so shift by +1
165 | if use_nerf_label:
166 | label = label + 1
167 |
168 | label = [label]
169 |
170 | if self._mode.find("train") != -1 and self._data_augmentation:
171 | img, label = self._augmenter.apply(img, label)
172 | else:
173 | img, label = self._augmenter.apply(img, label, only_crop=True)
174 |
175 | label[0] = label[0] - 1
176 |
177 | img_ori = img.clone()
178 | if self._output_trafo is not None:
179 | img = self._output_trafo(img)
180 |
181 | ret = (img, label[0].type(torch.int64)[0, :, :])
182 | ret += (img_ori,)
183 |
184 | if self._mode != "train":
185 | current_scene_name = os.path.normpath(self.image_pths[index]).split(
186 | os.path.sep)[-3]
187 | ret += (current_scene_name,)
188 |
189 | return ret
190 |
191 | def __len__(self):
192 | return self.length
193 |
194 | def __str__(self):
195 | string = "=" * 90
196 | string += "\nScannet Dataset: \n"
197 | length = len(self)
198 | string += f" Total Samples: {length}"
199 | string += f" » Mode: {self._mode} \n"
200 | string += f" » DataAug: {self._data_augmentation}"
201 | string += "=" * 90
202 | return string
203 |
--------------------------------------------------------------------------------
/nr4seg/lightning/__init__.py:
--------------------------------------------------------------------------------
1 | from .finetune_data_module import FineTuneDataModule
2 | from .joint_train_data_module import JointTrainDataModule
3 | from .joint_train_lightning_net import JointTrainLightningNet
4 | from .pretrain_data_module import PretrainDataModule
5 | from .semantics_lightning_net import SemanticsLightningNet
6 |
7 | __all__ = [
8 | "FineTuneDataModule",
9 | "JointTrainDataModule",
10 | "JointTrainLightningNet",
11 | "PretrainDataModule",
12 | "SemanticsLightningNet",
13 | ]
14 |
--------------------------------------------------------------------------------
/nr4seg/lightning/finetune_data_module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pytorch_lightning as pl
4 |
5 | from torch.utils.data import DataLoader
6 | from typing import Optional
7 |
8 | from nr4seg.dataset import ScanNet, ScanNetCL, ScanNetNGP
9 |
10 |
11 | class FineTuneDataModule(pl.LightningDataModule):
12 |
13 | def __init__(
14 | self,
15 | exp: dict,
16 | env: dict,
17 | prev_exp_name: str = "one_step_nerf_only",
18 | ):
19 | super().__init__()
20 |
21 | self.env = env
22 | self.exp = exp
23 | self.cfg_loader = self.exp["data_module"]
24 | self.prev_exp_name = prev_exp_name
25 |
26 | def setup(self, stage: Optional[str] = None) -> None:
27 | ## test adaption (last 20% of the new scenes)
28 | finetune_seqs = self.exp["scenes"]
29 | self.scannet_test_ada = ScanNetNGP(
30 | root=self.env["scannet"],
31 | mode="val", # val
32 | val_mode="gtgt",
33 | scene_list=finetune_seqs,
34 | )
35 | ## test generation
36 | scannet_25k_dir = self.env["scannet_frames_25k"]
37 | split_file = os.path.join(
38 | scannet_25k_dir,
39 | self.cfg_loader["data_preprocessing"]["split_file"],
40 | )
41 | img_list = np.load(split_file)
42 | self.scannet_test_gen = ScanNet(
43 | root=self.cfg_loader["root"],
44 | img_list=img_list["test"],
45 | mode="test",
46 | )
47 | ## train (all Torch-NGP generated labels)
48 |         # (replay frames from ScanNet-25k are randomly mixed in when CL is active)
49 |
50 | if not self.exp["cl"]["active"]:
51 | self.scannet_train = ScanNetNGP(
52 | root=self.env["scannet"],
53 | train_image=self.cfg_loader["train_image"],
54 | train_label=self.cfg_loader["train_label"],
55 | mode="train",
56 | scene_list=finetune_seqs,
57 | prev_exp_name=self.prev_exp_name,
58 | )
59 | else:
60 | scannet_ngp = ScanNetNGP(
61 | root=self.env["scannet"],
62 | train_image=self.cfg_loader["train_image"],
63 | train_label=self.cfg_loader["train_label"],
64 | mode="train",
65 | scene_list=finetune_seqs,
66 | prev_exp_name=self.prev_exp_name,
67 | )
68 | split_file_cl = os.path.join(
69 | scannet_25k_dir,
70 | self.cfg_loader["data_preprocessing"]["split_file_cl"],
71 | )
72 | img_list_cl = np.load(split_file_cl)["train_cl"]
73 | img_list_cl = img_list_cl[:int(self.exp["cl"]["25k_fraction"] *
74 | len(img_list_cl))]
75 | scannet_25k = ScanNet(
76 | root=self.cfg_loader["root"],
77 | img_list=img_list_cl,
78 | mode="train",
79 | )
80 | self.scannet_train = ScanNetCL(
81 | scannet_25k,
82 | scannet_ngp,
83 | ngp_25k_ratio=self.exp["cl"]["ngp_25k_ratio"],
84 | )
85 |
86 | def train_dataloader(self) -> DataLoader:
87 | return DataLoader(
88 | self.scannet_train,
89 | batch_size=self.cfg_loader["batch_size"],
90 | drop_last=True,
91 | shuffle=True, # only true in train_dataloader
92 | collate_fn=self.scannet_train.collate
93 | if self.exp["cl"]["active"] else None,
94 | pin_memory=self.cfg_loader["pin_memory"],
95 | num_workers=self.cfg_loader["num_workers"],
96 | )
97 |
98 | def val_dataloader(self) -> DataLoader:
99 | return DataLoader(
100 | self.scannet_test_ada,
101 |             # Set bs=1 to ensure a batch always has frames from the same scene.
102 |             batch_size=1,
103 | drop_last=False,
104 | shuffle=False, # false
105 | pin_memory=self.cfg_loader["pin_memory"],
106 | num_workers=self.cfg_loader["num_workers"],
107 | )
108 |
109 | def test_dataloader(self) -> DataLoader:
110 | return DataLoader(
111 | self.scannet_test_gen,
112 | batch_size=self.cfg_loader["batch_size"],
113 | drop_last=False,
114 | shuffle=False, # false
115 | pin_memory=self.cfg_loader["pin_memory"],
116 | num_workers=self.cfg_loader["num_workers"],
117 | )
118 |
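A minimal sketch of driving this data module; `exp` and `env` are assumed to be the dictionaries parsed from the YAML files under `cfg/exp` and `cfg/env` (as the training scripts do):

    from nr4seg.lightning import FineTuneDataModule

    datamodule = FineTuneDataModule(exp, env, prev_exp_name="one_step_nerf_only")
    datamodule.setup()
    train_loader = datamodule.train_dataloader()
    images, labels, images_ori = next(iter(train_loader))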
--------------------------------------------------------------------------------
/nr4seg/lightning/joint_train_data_module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pytorch_lightning as pl
4 |
5 | from torch.utils.data import DataLoader
6 | from typing import Optional
7 |
8 | from nr4seg.dataset import ScanNet, ScanNetCLJoint, ScanNetNGPJoint
9 |
10 |
11 | class JointTrainDataModule(pl.LightningDataModule):
12 |
13 | def __init__(
14 | self,
15 | exp: dict,
16 | env: dict,
17 | split_ratio: float = 0.2,
18 | ):
19 | super().__init__()
20 |
21 | self.env = env
22 | self.exp = exp
23 | self.cfg_loader = self.exp["data_module"]
24 |
25 | self.split_ratio = split_ratio
26 | self.test_sz = 0
27 | self.val_sz = 0
28 | self.train_sz = 0
29 |
30 | def setup(self, stage: Optional[str] = None) -> None:
31 | finetune_seqs = self.exp["scenes"]
32 | self.scannet_val_ada = ScanNetNGPJoint(
33 | root=self.env["scannet"],
34 | mode="val", # val
35 | scene_list=finetune_seqs,
36 | exp_name=self.exp["exp_name"],
37 | only_new_scene=False,
38 | )
39 | self.scannet_train_val_ada = ScanNetNGPJoint(
40 | root=self.env["scannet"],
41 | mode="train_val", # train_val
42 | scene_list=finetune_seqs,
43 | exp_name=self.exp["exp_name"],
44 | only_new_scene=False,
45 | )
46 | self.scannet_predict_ada = ScanNetNGPJoint(
47 | root=self.env["scannet"],
48 | mode="predict", # val
49 | scene_list=finetune_seqs,
50 | exp_name=self.exp["exp_name"],
51 | use_novel_viewpoints=self.exp["cl"]["use_novel_viewpoints"],
52 | only_new_scene=True,
53 | )
54 | ## test generation
55 | scannet_25k_dir = self.env["scannet_frames_25k"]
56 | split_file = os.path.join(
57 | scannet_25k_dir,
58 | self.cfg_loader["data_preprocessing"]["split_file"],
59 | )
60 | img_list = np.load(split_file)
61 | self.scannet_test_gen = ScanNet(
62 | root=self.env["scannet_frames_25k"],
63 | img_list=img_list["test"],
64 | mode="test",
65 | )
66 |
67 | self.scannet_train_nerf = ScanNetNGPJoint(
68 | root=self.env["scannet"],
69 | mode="train",
70 | scene_list=finetune_seqs,
71 | exp_name=self.exp["exp_name"],
72 | only_new_scene=True,
73 | )
74 | print("\033[93mNOTE: By default, the replay buffer is set to have size "
75 | "100.\033[0m")
76 | self.scannet_train_joint = ScanNetNGPJoint(
77 | root=self.env["scannet"],
78 | mode="train",
79 | scene_list=finetune_seqs,
80 | exp_name=self.exp["exp_name"],
81 | only_new_scene=False,
82 |             # NOTE: This refers to whether the previous scenes (if any)
83 |             # used for replay were generated from novel viewpoints.
84 | use_novel_viewpoints=self.exp["cl"]["use_novel_viewpoints"],
85 | fix_nerf=False,
86 | replay_buffer_size=self.exp["cl"]["replay_buffer_size"],
87 | )
88 | ## train (all Torch-NGP generated labels)
89 |         # (replay frames from ScanNet-25k are randomly mixed in when CL is active)
90 | if self.exp["cl"]["active"]:
91 | split_file_cl = os.path.join(
92 | scannet_25k_dir,
93 | self.cfg_loader["data_preprocessing"]["split_file_cl"],
94 | )
95 | img_list_cl = np.load(split_file_cl)["train_cl"]
96 | img_list_cl = img_list_cl[:int(self.exp["cl"]["25k_fraction"] *
97 | len(img_list_cl))]
98 | scannet_25k = ScanNet(
99 | root=self.env["scannet_frames_25k"],
100 | img_list=img_list_cl,
101 | mode="train",
102 | )
103 | self.scannet_train_joint = ScanNetCLJoint(
104 | scannet_25k,
105 | self.scannet_train_joint,
106 | ngp_25k_ratio=self.exp["cl"]["ngp_25k_ratio"],
107 | )
108 |
109 | self.test_sz = len(self.scannet_test_gen)
110 | self.val_sz = len(self.scannet_val_ada)
111 | self.train_sz = len(self.scannet_train_joint)
112 | print("Train/Val/Test/Total: {}/{}/{}/{}".format(
113 | self.train_sz,
114 | self.val_sz,
115 | self.test_sz,
116 | self.train_sz + self.val_sz + self.test_sz,
117 | ))
118 |
119 | def train_dataloader_nerf(self) -> DataLoader:
120 | return DataLoader(
121 | self.scannet_train_nerf,
122 | batch_size=1,
123 | drop_last=False,
124 | shuffle=True, # only true in train_dataloader
125 | pin_memory=self.cfg_loader["pin_memory"],
126 | num_workers=self.cfg_loader["num_workers"],
127 | )
128 |
129 | def train_dataloader_joint(self) -> DataLoader:
130 | return DataLoader(
131 | self.scannet_train_joint,
132 | batch_size=self.cfg_loader["batch_size"],
133 | drop_last=True,
134 | shuffle=True, # only true in train_dataloader
135 | pin_memory=self.cfg_loader["pin_memory"],
136 | num_workers=self.cfg_loader["num_workers"],
137 | collate_fn=ScanNetNGPJoint.collate,
138 | )
139 |
140 | def val_dataloader(self) -> DataLoader:
141 | return [
142 | DataLoader(
143 | self.scannet_val_ada,
144 | # Set bs=1 to ensure a batch always has frames from the same
145 | # scene.
146 | batch_size=1,
147 | drop_last=False,
148 | shuffle=False, # false
149 | pin_memory=self.cfg_loader["pin_memory"],
150 | num_workers=self.cfg_loader["num_workers"],
151 | ),
152 | DataLoader(
153 | self.scannet_train_val_ada,
154 | # Set bs=1 to ensure a batch always has frames from the same
155 | # scene.
156 | batch_size=1,
157 | drop_last=False,
158 | shuffle=False, # false
159 | pin_memory=self.cfg_loader["pin_memory"],
160 | num_workers=self.cfg_loader["num_workers"],
161 | ),
162 | ]
163 |
164 | def test_dataloader(self) -> DataLoader:
165 | return [
166 | DataLoader(
167 | self.scannet_train_nerf,
168 | batch_size=1,
169 | drop_last=False,
170 | shuffle=False, # false
171 | pin_memory=self.cfg_loader["pin_memory"],
172 | num_workers=self.cfg_loader["num_workers"],
173 | ),
174 | DataLoader(
175 | self.scannet_test_gen,
176 | batch_size=4,
177 | drop_last=False,
178 | shuffle=False, # false
179 | pin_memory=self.cfg_loader["pin_memory"],
180 | num_workers=self.cfg_loader["num_workers"],
181 | ),
182 | ]
183 |
184 | def test_dataloader_nerf(self) -> DataLoader:
185 | return DataLoader(
186 | self.scannet_train_nerf,
187 | batch_size=1,
188 | drop_last=False,
189 | shuffle=False, # false
190 | pin_memory=self.cfg_loader["pin_memory"],
191 | num_workers=self.cfg_loader["num_workers"],
192 | )
193 |
194 | def predict_dataloader(self) -> DataLoader:
195 | return DataLoader(
196 | self.scannet_predict_ada,
197 | batch_size=1,
198 | drop_last=False,
199 | shuffle=False, # false
200 | pin_memory=self.cfg_loader["pin_memory"],
201 | num_workers=self.cfg_loader["num_workers"],
202 | )
203 |
--------------------------------------------------------------------------------
/nr4seg/lightning/pretrain_data_module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pytorch_lightning as pl
4 |
5 | from torch.utils.data import DataLoader
6 | from typing import Optional
7 |
8 | from nr4seg.dataset import ScanNet
9 |
10 |
11 | class PretrainDataModule(pl.LightningDataModule):
12 |
13 | def __init__(self, env: dict, cfg_dm: dict):
14 | super().__init__()
15 |
16 | self.cfg_dm = cfg_dm
17 | self.env = env
18 |
19 | def setup(self, stage: Optional[str] = None) -> None:
20 | split_file = os.path.join(
21 | self.cfg_dm["root"],
22 | self.cfg_dm["data_preprocessing"]["split_file"],
23 | )
24 | img_list = np.load(split_file)
25 | self.scannet_test = ScanNet(root=self.cfg_dm["root"],
26 | img_list=img_list["test"],
27 | mode="test")
28 | self.scannet_train = ScanNet(root=self.cfg_dm["root"],
29 | img_list=img_list["train"],
30 | mode="train")
31 | self.scannet_val = ScanNet(root=self.cfg_dm["root"],
32 | img_list=img_list["val"],
33 | mode="val")
34 |
35 | def train_dataloader(self) -> DataLoader:
36 | return DataLoader(
37 | self.scannet_train,
38 | batch_size=self.cfg_dm["batch_size"],
39 | drop_last=self.cfg_dm["drop_last"],
40 | shuffle=self.cfg_dm["shuffle"],
41 | pin_memory=self.cfg_dm["pin_memory"],
42 | num_workers=self.cfg_dm["num_workers"],
43 | )
44 |
45 | def val_dataloader(self) -> DataLoader:
46 | return DataLoader(
47 | self.scannet_val,
48 | batch_size=self.cfg_dm["batch_size"],
49 | drop_last=self.cfg_dm["drop_last"],
50 | shuffle=False,
51 | pin_memory=self.cfg_dm["pin_memory"],
52 | num_workers=self.cfg_dm["num_workers"],
53 | )
54 |
55 | def test_dataloader(self) -> DataLoader:
56 | return DataLoader(
57 | self.scannet_test,
58 | batch_size=self.cfg_dm["batch_size"],
59 | drop_last=self.cfg_dm["drop_last"],
60 | shuffle=False,
61 | pin_memory=self.cfg_dm["pin_memory"],
62 | num_workers=self.cfg_dm["num_workers"],
63 | )
64 |
--------------------------------------------------------------------------------
/nr4seg/lightning/semantics_lightning_net.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytorch_lightning as pl
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from nr4seg.network import DeepLabV3
7 | from nr4seg.utils.metrics import SemanticsMeter
8 | from nr4seg.visualizer import Visualizer
9 |
10 |
11 | class SemanticsLightningNet(pl.LightningModule):
12 |
13 | def __init__(self, exp, env):
14 | super().__init__()
15 | self._model = DeepLabV3(exp["model"])
16 | self.prev_scene_name = None
17 |
18 | self._visualizer = Visualizer(
19 | os.path.join(exp["general"]["name"], "visu"),
20 | exp["visualizer"]["store"],
21 | self,
22 | )
23 |
24 | self._meter = {
25 | "val_1": SemanticsMeter(number_classes=exp["model"]["num_classes"]),
26 | "val_2": SemanticsMeter(number_classes=exp["model"]["num_classes"]),
27 | "val_3": SemanticsMeter(number_classes=exp["model"]["num_classes"]),
28 | "test": SemanticsMeter(number_classes=exp["model"]["num_classes"]),
29 | "train": SemanticsMeter(number_classes=exp["model"]["num_classes"]),
30 | }
31 |
32 | self._visu_count = {"val": 0, "test": 0, "train": 0}
33 |
34 | self._exp = exp
35 | self._env = env
36 | self._mode = "train"
37 | self.length_train_dataloader = 10000
38 |
39 | def forward(self, image: torch.Tensor) -> torch.Tensor:
40 | return self._model(image)
41 |
42 | def visu(self, image, target, pred):
43 | if not (self._visu_count[self._mode]
44 | < self._exp["visualizer"]["store_n"][self._mode]):
45 | return
46 |
47 | for b in range(image.shape[0]):
48 | if (self._visu_count[self._mode]
49 | < self._exp["visualizer"]["store_n"][self._mode]):
50 | c = self._visu_count[self._mode]
51 | self._visualizer.plot_image(image[b],
52 | tag=f"{self._mode}_vis/image_{c}")
53 | self._visualizer.plot_segmentation(
54 | pred[b], tag=f"{self._mode}_vis/pred_{c}")
55 | self._visualizer.plot_segmentation(
56 | target[b], tag=f"{self._mode}_vis/target_{c}")
57 |
58 | self._visualizer.plot_detectron(
59 | image[b], target[b], tag=f"{self._mode}_vis/detectron_{c}")
60 |
61 | self._visu_count[self._mode] += 1
62 | else:
63 | break
64 |
65 | # TRAINING
66 | def on_train_epoch_start(self):
67 | self._mode = "train"
68 | self._visu_count[self._mode] = 0
69 | self._meter["train"].clear()
70 |
71 | def training_step(self, batch, batch_idx: int) -> torch.Tensor:
72 | image, target, ori_image = batch
73 | output = self(image)
74 | pred = F.softmax(output["out"], dim=1)
75 | pred_argmax = torch.argmax(pred, dim=1)
76 | all_pred_argmax = self.all_gather(pred_argmax)
77 | all_target = self.all_gather(target)
78 | self._meter[self._mode].update(all_pred_argmax, all_target)
79 | # Compute Loss
80 | loss = F.cross_entropy(pred, target, ignore_index=-1, reduction="none")
81 | # Visu
82 | # self.visu(ori_image, target+1, pred_argmax+1)
83 | # Loss loggging
84 | self.log(
85 | f"{self._mode}/loss",
86 | loss.mean().item(),
87 | on_step=self._mode == "train",
88 | on_epoch=self._mode != "train",
89 | )
90 | return loss.mean()
91 |
92 | def on_train_epoch_end(self):
93 | m_iou, total_acc, m_acc = self._meter["train"].measure()
94 | self.log(f"train/total_accuracy", total_acc, rank_zero_only=True)
95 | self.log(f"train/mean_accuracy", m_acc, rank_zero_only=True)
96 | self.log(f"train/mean_IoU", m_iou, rank_zero_only=True)
97 |
98 | # VALIDATION
99 | def on_validation_epoch_start(self):
100 | self._mode = "val"
101 | self._visu_count[self._mode] = 0
102 | self._meter["val_1"].clear()
103 | self._meter["val_2"].clear()
104 | self._meter["val_3"].clear()
105 |
106 | def validation_step(self, batch, batch_idx: int) -> None:
107 | dataloader_idx = 0 # TODO: modify back
108 | image, target, ori_image, scene_name = batch
109 | scene_name = scene_name[0]
110 | output = self(image)
111 | pred = F.softmax(output["out"], dim=1)
112 | pred_argmax = torch.argmax(pred, dim=1)
113 | all_pred_argmax = self.all_gather(pred_argmax)
114 | all_target = self.all_gather(target)
115 |
116 | self.prev_scene_name = scene_name
117 | self._meter[f"val_{dataloader_idx+1}"].update(all_pred_argmax,
118 | all_target)
119 |
120 | # Compute Loss
121 | loss = F.cross_entropy(pred, target, ignore_index=-1, reduction="none")
122 | # Visu
123 | # self.visu(ori_image, target+1, pred_argmax+1)
124 | # Loss loggging
125 | self.log(
126 | f"{self._mode}/loss",
127 | loss.mean().item(),
128 | on_step=self._mode == "train",
129 | on_epoch=self._mode != "train",
130 | )
131 | return loss.mean()
132 |
133 | def on_validation_epoch_end(self):
134 | m_iou_1, total_acc_1, m_acc_1 = self._meter["val_1"].measure()
135 | self.log(f"val/total_accuracy_gg", total_acc_1, rank_zero_only=True)
136 | self.log(f"val/mean_accuracy_gg", m_acc_1, rank_zero_only=True)
137 | self.log(f"val/mean_IoU_gg", m_iou_1, rank_zero_only=True)
138 | self.prev_scene_name = None
139 |
140 | # TESTING
141 | def on_test_epoch_start(self):
142 | self._mode = "test"
143 | self._visu_count[self._mode] = 0
144 | self._meter["test"].clear()
145 |
146 | def test_step(self, batch, batch_idx: int) -> None:
147 | return self.training_step(batch, batch_idx)
148 |
149 | def on_test_epoch_end(self):
150 | m_iou, total_acc, m_acc = self._meter["test"].measure()
151 | self.log(f"test/total_accuracy", total_acc, rank_zero_only=True)
152 | self.log(f"test/mean_accuracy", m_acc, rank_zero_only=True)
153 | self.log(f"test/mean_IoU", m_iou, rank_zero_only=True)
154 |
155 | def configure_optimizers(self) -> torch.optim.Optimizer:
156 | optimizer = self._exp["optimizer"]["name"]
157 | lr = self._exp["optimizer"]["lr"]
158 | if optimizer == "Adam":
159 | optimizer = torch.optim.Adam(self._model.parameters(), lr=lr)
160 | if optimizer == "SGD":
161 | sgd_cfg = self._exp["optimizer"]["sgd_cfg"]
162 | optimizer = torch.optim.SGD(
163 | self._model.parameters(),
164 | lr=lr,
165 | momentum=sgd_cfg["momentum"],
166 | weight_decay=sgd_cfg["weight_decay"],
167 | )
168 | if optimizer == "Adadelta":
169 | optimizer = torch.optim.Adadelta(self._model.parameters(), lr=lr)
170 | if optimizer == "RMSprop":
171 | optimizer = torch.optim.RMSprop(self._model.parameters(),
172 | momentum=0.9,
173 | lr=lr)
174 | if self._exp["lr_scheduler"]["active"]:
175 | scheduler = self._exp["lr_scheduler"]["name"]
176 | if scheduler == "POLY":
177 |             init_lr = self._exp["optimizer"]["lr"]
178 | max_epochs = self._exp["lr_scheduler"]["poly_cfg"]["max_epochs"]
179 | target_lr = self._exp["lr_scheduler"]["poly_cfg"]["target_lr"]
180 | power = self._exp["lr_scheduler"]["poly_cfg"]["power"]
181 |             def lambda_lr(epoch):
182 |                 frac = (max_epochs - min(max_epochs, epoch)) / max_epochs
183 |                 return (frac**power +
184 |                         (1 - frac**power) * target_lr / init_lr)
185 | 
186 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
187 | lambda_lr,
188 | last_epoch=-1,
189 | verbose=True)
190 | interval = "epoch"
191 | lr_scheduler = {"scheduler": scheduler, "interval": interval}
192 | ret = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
193 | else:
194 | ret = [optimizer]
195 | return ret
196 |
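The `POLY` option above scales the base learning rate by a factor that decays from 1 to `target_lr / init_lr` over `max_epochs`. A standalone sketch of that multiplier, with made-up numbers purely to illustrate the decay:

    init_lr, target_lr, max_epochs, power = 1e-4, 1e-6, 100, 0.9

    def poly_factor(epoch):
        frac = (max_epochs - min(max_epochs, epoch)) / max_epochs
        return frac**power + (1 - frac**power) * target_lr / init_lr

    print(poly_factor(0), poly_factor(50), poly_factor(100))  # 1.0, ~0.54, 0.01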
--------------------------------------------------------------------------------
/nr4seg/nerf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethz-asl/ucsa_neural_rendering/d29b37445e7e6be2cf04fa70ccebebbcbec0c1bf/nr4seg/nerf/__init__.py
--------------------------------------------------------------------------------
/nr4seg/nerf/activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch.autograd import Function
4 | from torch.cuda.amp import custom_bwd, custom_fwd
5 |
6 |
7 | class _trunc_exp(Function):
8 |
9 | @staticmethod
10 | @custom_fwd(cast_inputs=torch.float)
11 | def forward(ctx, x):
12 | ctx.save_for_backward(x)
13 | return torch.exp(x)
14 |
15 | @staticmethod
16 | @custom_bwd
17 | def backward(ctx, g):
18 | x = ctx.saved_tensors[0]
19 | return g * torch.exp(x.clamp(-15, 15))
20 |
21 |
22 | trunc_exp = _trunc_exp.apply
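A short sketch of what `trunc_exp` provides: the forward pass is identical to `torch.exp`, while the incoming gradient is multiplied by `exp(x.clamp(-15, 15))` so that very large activations cannot blow up the backward pass:

    import torch
    from nr4seg.nerf.activation import trunc_exp

    x = torch.tensor([0.0, 5.0, 50.0], requires_grad=True)
    y = trunc_exp(x)
    y.sum().backward()
    print(y)       # same values as torch.exp(x)
    print(x.grad)  # the last entry is clamped to exp(15) ≈ 3.3e6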
--------------------------------------------------------------------------------
/nr4seg/nerf/network_tcnn_semantics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tinycudann as tcnn
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from .activation import trunc_exp
7 | from .renderer_semantics import SemanticNeRFRenderer
8 |
9 |
10 | class SemanticNeRFNetwork(SemanticNeRFRenderer):
11 |
12 | def __init__(self,
13 | encoding="HashGrid",
14 | encoding_dir="SphericalHarmonics",
15 | num_layers=2,
16 | hidden_dim=64,
17 | geo_feat_dim=15,
18 | num_layers_color=3,
19 | hidden_dim_color=64,
20 | num_layers_semantics=2,
21 | hidden_dim_semantics=64,
22 | bound=1,
23 | num_semantic_classes=41,
24 | **kwargs):
25 | super().__init__(bound,
26 | **kwargs,
27 | num_semantic_classes=num_semantic_classes)
28 |
29 | # sigma network
30 | self.num_layers = num_layers
31 | self.hidden_dim = hidden_dim
32 | self.geo_feat_dim = geo_feat_dim
33 |
34 | per_level_scale = np.exp2(np.log2(2048 * bound / 16) / (16 - 1))
35 |
36 | self.encoder = tcnn.Encoding(
37 | n_input_dims=3,
38 | encoding_config={
39 | "otype": "HashGrid",
40 | "n_levels": 16,
41 | "n_features_per_level": 2,
42 | "log2_hashmap_size": 19,
43 | "base_resolution": 16,
44 | "per_level_scale": per_level_scale,
45 | },
46 | )
47 |
48 | self.sigma_net = tcnn.Network(
49 | n_input_dims=32,
50 | n_output_dims=1 + self.geo_feat_dim,
51 | network_config={
52 | "otype": "FullyFusedMLP",
53 | "activation": "ReLU",
54 | "output_activation": "None",
55 | "n_neurons": hidden_dim,
56 | "n_hidden_layers": num_layers - 1,
57 | },
58 | )
59 |
60 | # color network
61 | self.num_layers_color = num_layers_color
62 | self.hidden_dim_color = hidden_dim_color
63 |
64 | self.encoder_dir = tcnn.Encoding(
65 | n_input_dims=3,
66 | encoding_config={
67 | "otype": "SphericalHarmonics",
68 | "degree": 4,
69 | },
70 | )
71 |
72 | self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim
73 |
74 | self.color_net = tcnn.Network(
75 | n_input_dims=self.in_dim_color,
76 | n_output_dims=3,
77 | network_config={
78 | "otype": "FullyFusedMLP",
79 | "activation": "ReLU",
80 | "output_activation": "None",
81 | "n_neurons": hidden_dim_color,
82 | "n_hidden_layers": num_layers_color - 1,
83 | },
84 | )
85 |
86 | self.num_layers_semantics = num_layers_semantics
87 | self.hidden_dim_semantics = hidden_dim_semantics
88 | self.num_semantic_classes = num_semantic_classes
89 | self.in_dim_semantics = self.geo_feat_dim
90 | self.semantics_net = tcnn.Network(
91 | n_input_dims=self.in_dim_semantics,
92 | n_output_dims=self.num_semantic_classes,
93 | network_config={
94 | "otype": "FullyFusedMLP",
95 | "activation": "ReLU",
96 | "output_activation": "None",
97 | "n_neurons": hidden_dim_semantics,
98 | "n_hidden_layers": num_layers_semantics - 1,
99 | },
100 | )
101 |
102 | def forward(self, x, d):
103 | # x: [N, 3], in [-bound, bound]
104 |         # d: [N, 3], normalized in [-1, 1]
105 |
106 | # sigma
107 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
108 | x = self.encoder(x)
109 | h = self.sigma_net(x)
110 |
111 | # sigma = F.relu(h[..., 0])
112 | sigma = trunc_exp(h[..., 0])
113 | geo_feat = h[..., 1:]
114 |
115 | # color
116 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
117 | d = self.encoder_dir(d)
118 |
119 | # p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
120 | h = torch.cat([d, geo_feat], dim=-1)
121 | h = self.color_net(h)
122 |
123 | # sigmoid activation for rgb
124 | color = torch.sigmoid(h)
125 | semantics = self.semantics_net(geo_feat)
126 | semantics = F.softmax(semantics, dim=-1)
127 |
128 | return sigma, color, semantics
129 |
130 | def density(self, x):
131 | # x: [N, 3], in [-bound, bound]
132 |
133 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
134 | x = self.encoder(x)
135 | h = self.sigma_net(x)
136 |
137 | # sigma = F.relu(h[..., 0])
138 | sigma = trunc_exp(h[..., 0])
139 | geo_feat = h[..., 1:]
140 |
141 | return {
142 | "sigma": sigma,
143 | "geo_feat": geo_feat,
144 | }
145 |
146 | # allow masked inference
147 | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
148 | # x: [N, 3] in [-bound, bound]
149 |         # mask: [N,], bool, indicates where we actually need to compute rgb.
150 |
151 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
152 |
153 | if mask is not None:
154 | rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype,
155 | device=x.device) # [N, 3]
156 | # in case of empty mask
157 | if not mask.any():
158 | return rgbs
159 | x = x[mask]
160 | d = d[mask]
161 | geo_feat = geo_feat[mask]
162 |
163 | # color
164 | d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
165 | d = self.encoder_dir(d)
166 |
167 | h = torch.cat([d, geo_feat], dim=-1)
168 | h = self.color_net(h)
169 |
170 | # sigmoid activation for rgb
171 | h = torch.sigmoid(h)
172 |
173 | if mask is not None:
174 | rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
175 | else:
176 | rgbs = h
177 |
178 | return rgbs
179 |
180 | def semantics(self, x, d, mask=None, geo_feat=None, **kwargs):
181 | # x: [N, 3] in [-bound, bound]
182 |         # mask: [N,], bool, indicates where we actually need to compute semantics.
183 |
184 | x = (x + self.bound) / (2 * self.bound) # to [0, 1]
185 |
186 | if mask is not None:
187 | semantics = torch.zeros(
188 | mask.shape[0],
189 | self.num_semantic_classes,
190 | dtype=x.dtype,
191 | device=x.device,
192 |             )  # [N, num_semantic_classes]
193 | # in case of empty mask
194 | if not mask.any():
195 | return semantics
196 | x = x[mask]
197 | geo_feat = geo_feat[mask]
198 |
199 | h = self.semantics_net(geo_feat)
200 |
201 | if mask is not None:
202 | semantics[mask] = h.to(semantics.dtype) # fp16 --> fp32
203 | semantics[mask] = F.softmax(semantics[mask], dim=-1)
204 |         else:
205 |             # no pre-allocated buffer without a mask; use the output directly
206 |             semantics = F.softmax(h, dim=-1)
207 | return semantics
208 |
--------------------------------------------------------------------------------
/nr4seg/nerf/raymarching/__init__.py:
--------------------------------------------------------------------------------
1 | from .raymarching import *
2 |
--------------------------------------------------------------------------------
/nr4seg/nerf/raymarching/backend.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torch.utils.cpp_extension import load
4 |
5 | _src_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | nvcc_flags = [
8 | "-O3",
9 | "-std=c++14",
10 | "-U__CUDA_NO_HALF_OPERATORS__",
11 | "-U__CUDA_NO_HALF_CONVERSIONS__",
12 | "-U__CUDA_NO_HALF2_OPERATORS__",
13 | ]
14 |
15 | if os.name == "posix":
16 | c_flags = ["-O3", "-std=c++14"]
17 | elif os.name == "nt":
18 | c_flags = ["/O2", "/std:c++17"]
19 |
20 | # find cl.exe
21 | def find_cl_path():
22 | import glob
23 |
24 | for edition in [
25 | "Enterprise", "Professional", "BuildTools", "Community"
26 | ]:
27 | paths = sorted(
28 | glob.glob(
29 | r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
30 | % edition),
31 | reverse=True,
32 | )
33 | if paths:
34 | return paths[0]
35 |
36 | # If cl.exe is not on path, try to find it.
37 | if os.system("where cl.exe >nul 2>nul") != 0:
38 | cl_path = find_cl_path()
39 | if cl_path is None:
40 | raise RuntimeError(
41 | "Could not locate a supported Microsoft Visual C++ installation"
42 | )
43 | os.environ["PATH"] += ";" + cl_path
44 |
45 | _backend = load(
46 | name="_raymarching",
47 | extra_cflags=c_flags,
48 | extra_cuda_cflags=nvcc_flags,
49 | sources=[
50 | os.path.join(_src_path, "src", f) for f in [
51 | "raymarching.cu",
52 | "bindings.cpp",
53 | ]
54 | ],
55 | )
56 |
57 | __all__ = ["_backend"]
58 |
--------------------------------------------------------------------------------
/nr4seg/nerf/raymarching/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "raymarching.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | // utils
7 | m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
8 | // train
9 | m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
10 | m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
11 | m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
12 | // m.def("composite_rays_train_semantics_forward", &composite_rays_train_semantics_forward, "composite_rays_train_forward (CUDA)");
13 | // m.def("composite_rays_train_semantics_backward", &composite_rays_train_semantics_backward, "composite_rays_train_backward (CUDA)");
14 | // infer
15 | m.def("march_rays", &march_rays, "march rays (CUDA)");
16 | // m.def("composite_rays_semantics", &composite_rays_semantics, "composite rays (CUDA)");
17 | m.def("compact_rays", &compact_rays, "compact rays (CUDA)");
18 | }
--------------------------------------------------------------------------------
/nr4seg/nerf/raymarching/src/pcg32.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Tiny self-contained version of the PCG Random Number Generation for C++
3 | * put together from pieces of the much larger C/C++ codebase.
4 | * Wenzel Jakob, February 2015
5 | *
6 | * The PCG random number generator was developed by Melissa O'Neill
7 | *
8 | *
9 | * Licensed under the Apache License, Version 2.0 (the "License");
10 | * you may not use this file except in compliance with the License.
11 | * You may obtain a copy of the License at
12 | *
13 | * http://www.apache.org/licenses/LICENSE-2.0
14 | *
15 | * Unless required by applicable law or agreed to in writing, software
16 | * distributed under the License is distributed on an "AS IS" BASIS,
17 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | * See the License for the specific language governing permissions and
19 | * limitations under the License.
20 | *
21 | * For additional information about the PCG random number generation scheme,
22 | * including its license and other licensing options, visit
23 | *
24 | * http://www.pcg-random.org
25 | *
26 | * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors.
27 | */
28 |
29 | #pragma once
30 |
31 | #define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL
32 | #define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL
33 | #define PCG32_MULT 0x5851f42d4c957f2dULL
34 |
35 | #include
36 | #include
37 | #include
38 |
39 | #include
40 | #include
41 | #include
42 |
43 | /// PCG32 Pseudorandom number generator
44 | struct pcg32 {
45 | /// Initialize the pseudorandom number generator with default seed
46 | __host__ __device__ pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {}
47 |
48 | /// Initialize the pseudorandom number generator with the \ref seed() function
49 | __host__ __device__ pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); }
50 |
51 | /**
52 | * \brief Seed the pseudorandom number generator
53 | *
54 | * Specified in two parts: a state initializer and a sequence selection
55 | * constant (a.k.a. stream id)
56 | */
57 | __host__ __device__ void seed(uint64_t initstate, uint64_t initseq = 1) {
58 | state = 0U;
59 | inc = (initseq << 1u) | 1u;
60 | next_uint();
61 | state += initstate;
62 | next_uint();
63 | }
64 |
65 | /// Generate a uniformly distributed unsigned 32-bit random number
66 | __host__ __device__ uint32_t next_uint() {
67 | uint64_t oldstate = state;
68 | state = oldstate * PCG32_MULT + inc;
69 | uint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u);
70 | uint32_t rot = (uint32_t) (oldstate >> 59u);
71 | return (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31));
72 | }
73 |
74 | /// Generate a uniformly distributed number, r, where 0 <= r < bound
75 | __host__ __device__ uint32_t next_uint(uint32_t bound) {
76 | // To avoid bias, we need to make the range of the RNG a multiple of
77 | // bound, which we do by dropping output less than a threshold.
78 | // A naive scheme to calculate the threshold would be to do
79 | //
80 | // uint32_t threshold = 0x100000000ull % bound;
81 | //
82 | // but 64-bit div/mod is slower than 32-bit div/mod (especially on
83 | // 32-bit platforms). In essence, we do
84 | //
85 | // uint32_t threshold = (0x100000000ull-bound) % bound;
86 | //
87 | // because this version will calculate the same modulus, but the LHS
88 | // value is less than 2^32.
89 |
90 | uint32_t threshold = (~bound+1u) % bound;
91 |
92 | // Uniformity guarantees that this loop will terminate. In practice, it
93 | // should usually terminate quickly; on average (assuming all bounds are
94 | // equally likely), 82.25% of the time, we can expect it to require just
95 | // one iteration. In the worst case, someone passes a bound of 2^31 + 1
96 | // (i.e., 2147483649), which invalidates almost 50% of the range. In
97 | // practice, bounds are typically small and only a tiny amount of the range
98 | // is eliminated.
99 | for (;;) {
100 | uint32_t r = next_uint();
101 | if (r >= threshold)
102 | return r % bound;
103 | }
104 | }
105 |
106 | /// Generate a single precision floating point value on the interval [0, 1)
107 | __host__ __device__ float next_float() {
108 | /* Trick from MTGP: generate an uniformly distributed
109 | single precision number in [1,2) and subtract 1. */
110 | union {
111 | uint32_t u;
112 | float f;
113 | } x;
114 | x.u = (next_uint() >> 9) | 0x3f800000u;
115 | return x.f - 1.0f;
116 | }
117 |
118 | /**
119 | * \brief Generate a double precision floating point value on the interval [0, 1)
120 | *
121 | * \remark Since the underlying random number generator produces 32 bit output,
122 | * only the first 32 mantissa bits will be filled (however, the resolution is still
123 | * finer than in \ref next_float(), which only uses 23 mantissa bits)
124 | */
125 | __host__ __device__ double next_double() {
126 | /* Trick from MTGP: generate an uniformly distributed
127 | double precision number in [1,2) and subtract 1. */
128 | union {
129 | uint64_t u;
130 | double d;
131 | } x;
132 | x.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL;
133 | return x.d - 1.0;
134 | }
135 |
136 | /**
137 | * \brief Multi-step advance function (jump-ahead, jump-back)
138 | *
139 | * The method used here is based on Brown, "Random Number Generation
140 | * with Arbitrary Stride", Transactions of the American Nuclear
141 | * Society (Nov. 1994). The algorithm is very similar to fast
142 | * exponentiation.
143 | *
144 | * The default value of 2^32 ensures that the PRNG is advanced
145 | * sufficiently far that there is (likely) no overlap with
146 | * previously drawn random numbers, even if small advancements
147 | * are made in between.
148 | */
149 | __host__ __device__ void advance(int64_t delta_ = (1ll<<32)) {
150 | uint64_t
151 | cur_mult = PCG32_MULT,
152 | cur_plus = inc,
153 | acc_mult = 1u,
154 | acc_plus = 0u;
155 |
156 | /* Even though delta is an unsigned integer, we can pass a signed
157 | integer to go backwards, it just goes "the long way round". */
158 | uint64_t delta = (uint64_t) delta_;
159 |
160 | while (delta > 0) {
161 | if (delta & 1) {
162 | acc_mult *= cur_mult;
163 | acc_plus = acc_plus * cur_mult + cur_plus;
164 | }
165 | cur_plus = (cur_mult + 1) * cur_plus;
166 | cur_mult *= cur_mult;
167 | delta /= 2;
168 | }
169 | state = acc_mult * state + acc_plus;
170 | }
171 |
172 | /// Compute the distance between two PCG32 pseudorandom number generators
173 | __host__ __device__ int64_t operator-(const pcg32 &other) const {
174 | assert(inc == other.inc);
175 |
176 | uint64_t
177 | cur_mult = PCG32_MULT,
178 | cur_plus = inc,
179 | cur_state = other.state,
180 | the_bit = 1u,
181 | distance = 0u;
182 |
183 | while (state != cur_state) {
184 | if ((state & the_bit) != (cur_state & the_bit)) {
185 | cur_state = cur_state * cur_mult + cur_plus;
186 | distance |= the_bit;
187 | }
188 | assert((state & the_bit) == (cur_state & the_bit));
189 | the_bit <<= 1;
190 | cur_plus = (cur_mult + 1ULL) * cur_plus;
191 | cur_mult *= cur_mult;
192 | }
193 |
194 | return (int64_t) distance;
195 | }
196 |
197 | /// Equality operator
198 | __host__ __device__ bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; }
199 |
200 | /// Inequality operator
201 | __host__ __device__ bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; }
202 |
203 | uint64_t state; // RNG state. All values are possible.
204 | uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd.
205 | };
--------------------------------------------------------------------------------
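For reference, the PCG32 state transition and the unbiased bounded draw in pcg32.h translate almost line-for-line into Python. The sketch below mirrors seed() / next_uint() / next_uint(bound) with the same constants; it is only an illustration, the kernels use the C++ struct above:

PCG32_MULT = 0x5851f42d4c957f2d
MASK64 = (1 << 64) - 1
MASK32 = (1 << 32) - 1

class PCG32:
    def __init__(self, initstate, initseq=1):
        # Mirrors pcg32::seed(initstate, initseq).
        self.state = 0
        self.inc = ((initseq << 1) | 1) & MASK64
        self.next_uint()
        self.state = (self.state + initstate) & MASK64
        self.next_uint()

    def next_uint(self):
        old = self.state
        self.state = (old * PCG32_MULT + self.inc) & MASK64
        xorshifted = (((old >> 18) ^ old) >> 27) & MASK32
        rot = old >> 59
        return ((xorshifted >> rot) | (xorshifted << ((-rot) & 31))) & MASK32

    def next_uint_bounded(self, bound):
        # Rejection sampling: discard draws below (2**32 - bound) % bound so the
        # accepted range is an exact multiple of `bound` (no modulo bias).
        threshold = ((1 << 32) - bound) % bound
        while True:
            r = self.next_uint()
            if r >= threshold:
                return r % bound

rng = PCG32(0x853c49e6748fea9b)
print([rng.next_uint_bounded(6) for _ in range(5)])
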
/nr4seg/nerf/raymarching/src/raymarching.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <stdint.h>
4 | #include <torch/torch.h>
5 |
6 |
7 | void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
8 |
9 | void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const float bound, const float dt_gamma, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb);
10 | void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
11 | void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs);
12 | // void composite_rays_train_semantics_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor local_semantics, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image, at::Tensor semantics);
13 | // void composite_rays_train_semantics_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor grad_semantics, at::Tensor sigmas, at::Tensor rgbs, at::Tensor local_semantics, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, at::Tensor semantics, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_local_semantics);
14 |
15 | void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t C, const uint32_t H, at::Tensor density_grid, const float mean_density, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb);
16 | void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
17 | // void composite_rays_semantics(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor local_semantics, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image, at::Tensor semantics);
18 | void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter);
--------------------------------------------------------------------------------
/nr4seg/nerf/renderer_semantics.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import trimesh
6 |
7 | from .raymarching import raymarching
8 |
9 |
10 | def sample_pdf(bins, weights, n_samples, det=False):
11 | # This implementation is from NeRF
12 | # bins: [B, T], old_z_vals
13 | # weights: [B, T - 1], bin weights.
14 | # return: [B, n_samples], new_z_vals
15 |
16 | # Get pdf
17 | weights = weights + 1e-5 # prevent nans
18 | pdf = weights / torch.sum(weights, -1, keepdim=True)
19 | cdf = torch.cumsum(pdf, -1)
20 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
21 | # Take uniform samples
22 | if det:
23 | u = torch.linspace(0.0 + 0.5 / n_samples,
24 | 1.0 - 0.5 / n_samples,
25 | steps=n_samples).to(weights.device)
26 | u = u.expand(list(cdf.shape[:-1]) + [n_samples])
27 | else:
28 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
29 |
30 | # Invert CDF
31 | u = u.contiguous()
32 | inds = torch.searchsorted(cdf, u, right=True)
33 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
34 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
35 | inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
36 |
37 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
38 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
39 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
40 |
41 | denom = cdf_g[..., 1] - cdf_g[..., 0]
42 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
43 | t = (u - cdf_g[..., 0]) / denom
44 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
45 |
46 | return samples
47 |
48 |
49 | def plot_pointcloud(pc, color=None):
50 | # pc: [N, 3]
51 | # color: [N, 3/4]
52 | print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0))
53 | pc = trimesh.PointCloud(pc, color)
54 | # axis
55 | axes = trimesh.creation.axis(axis_length=4)
56 | # sphere
57 | sphere = trimesh.creation.icosphere(radius=1)
58 | trimesh.Scene([pc, axes, sphere]).show()
59 |
60 |
61 | class SemanticNeRFRenderer(nn.Module):
62 |
63 | def __init__(
64 | self,
65 | bound=1,
66 | cuda_ray=False,
67 | density_scale=1, # scale up deltas (or sigmas) to make the density grid sharper; a value larger than 1 usually improves performance.
68 | num_semantic_classes=41,
69 | ):
70 | super().__init__()
71 |
72 | self.epoch = 1
73 | self.weights = np.zeros([0])
74 | self.weights_sum = np.zeros([0])
75 |
76 | self.bound = bound
77 | self.cascade = 1 + math.ceil(math.log2(bound))
78 | self.density_scale = density_scale
79 | self.num_semantic_classes = num_semantic_classes
80 |
81 | # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
82 | # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
83 | aabb_train = torch.FloatTensor(
84 | [-bound, -bound, -bound, bound, bound, bound])
85 | aabb_infer = aabb_train.clone()
86 | self.register_buffer("aabb_train", aabb_train)
87 | self.register_buffer("aabb_infer", aabb_infer)
88 |
89 | # extra state for cuda raymarching
90 | self.cuda_ray = cuda_ray
91 | if cuda_ray:
92 | # density grid
93 | density_grid = torch.zeros([self.cascade] +
94 | [128] * 3) # [CAS, H, H, H]
95 | self.register_buffer("density_grid", density_grid)
96 | self.mean_density = 0
97 | self.iter_density = 0
98 | # step counter
99 | step_counter = torch.zeros(
100 | 16, 2, dtype=torch.int32) # 16 is hardcoded for averaging...
101 | self.register_buffer("step_counter", step_counter)
102 | self.mean_count = 0
103 | self.local_step = 0
104 |
105 | def forward(self, x, d):
106 | raise NotImplementedError()
107 |
108 | def density(self, x):
109 | raise NotImplementedError()
110 |
111 | def reset_extra_state(self):
112 | if not self.cuda_ray:
113 | return
114 | # density grid
115 | self.density_grid.zero_()
116 | self.mean_density = 0
117 | self.iter_density = 0
118 | # step counter
119 | self.step_counter.zero_()
120 | self.mean_count = 0
121 | self.local_step = 0
122 |
123 | def run(self,
124 | rays_o,
125 | rays_d,
126 | direction_norms,
127 | num_steps=256,
128 | upsample_steps=256,
129 | bg_color=None,
130 | perturb=False,
131 | epoch=None,
132 | **kwargs):
133 | # rays_o, rays_d: [B, N, 3], assumes B == 1
134 | # direction_norms: [B, N, 1]
135 | # bg_color: [3] in range [0, 1]
136 | # return: image: [B, N, 3], depth: [B, N]
137 |
138 | prefix = rays_o.shape[:-1]
139 | rays_o = rays_o.contiguous().view(-1, 3)
140 | rays_d = rays_d.contiguous().view(-1, 3)
141 | direction_norms = direction_norms.contiguous().view(-1)
142 |
143 | N = rays_o.shape[0] # N = B * N, in fact
144 | device = rays_o.device
145 |
146 | # choose aabb
147 | aabb = self.aabb_train if self.training else self.aabb_infer
148 |
149 | # sample steps
150 | nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb)
151 | nears.unsqueeze_(-1)
152 | fars.unsqueeze_(-1)
153 |
154 | z_vals = torch.linspace(0.0, 1.0, num_steps,
155 | device=device).unsqueeze(0) # [1, T]
156 | z_vals = z_vals.expand((N, num_steps)) # [N, T]
157 | z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
158 |
159 | # perturb z_vals
160 | sample_dist = (fars - nears) / num_steps
161 | if perturb:
162 | mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
163 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
164 | lower = torch.cat([z_vals[..., :1], mids], -1)
165 | # stratified samples in those intervals
166 | t_rand = torch.rand(z_vals.shape, device=device)
167 |
168 | z_vals = lower + (upper - lower) * t_rand
169 |
170 | # generate xyzs
171 | xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
172 | -1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
173 | xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
174 |
175 | # query density and RGB
176 | density_outputs = self.density(xyzs.reshape(-1, 3))
177 |
178 | for k, v in density_outputs.items():
179 | density_outputs[k] = v.view(N, num_steps, -1)
180 |
181 | # upsample z_vals (nerf-like)
182 | if upsample_steps > 0:
183 | with torch.no_grad():
184 |
185 | deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
186 | deltas = torch.cat(
187 | [deltas, 1e10 * torch.ones_like(deltas[..., :1])], dim=-1)
188 |
189 | alphas = 1 - torch.exp(
190 | -deltas * self.density_scale *
191 | density_outputs["sigma"].squeeze(-1)) # [N, T]
192 | alphas_shifted = torch.cat(
193 | [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15],
194 | dim=-1,
195 | ) # [N, T+1]
196 | weights = (alphas *
197 | torch.cumprod(alphas_shifted, dim=-1)[..., :-1]
198 | ) # [N, T]
199 |
200 | # sample new z_vals
201 | z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]
202 | ) # [N, T-1]
203 | new_z_vals = sample_pdf(z_vals_mid,
204 | weights[:, 1:-1],
205 | upsample_steps,
206 | det=False).detach() # [N, t]
207 |
208 | new_xyzs = rays_o.unsqueeze(
209 | -2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(
210 | -1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
211 | new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]),
212 | aabb[3:]) # a manual clip.
213 |
214 | # only forward new points to save computation
215 | new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
216 | # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
217 | for k, v in new_density_outputs.items():
218 | new_density_outputs[k] = v.view(N, upsample_steps, -1)
219 |
220 | # re-order
221 | z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
222 | z_vals, z_index = torch.sort(z_vals, dim=1)
223 |
224 | xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
225 | xyzs = torch.gather(xyzs,
226 | dim=1,
227 | index=z_index.unsqueeze(-1).expand_as(xyzs))
228 |
229 | for k in density_outputs:
230 | tmp_output = torch.cat(
231 | [density_outputs[k], new_density_outputs[k]], dim=1)
232 | density_outputs[k] = torch.gather(
233 | tmp_output,
234 | dim=1,
235 | index=z_index.unsqueeze(-1).expand_as(tmp_output),
236 | )
237 |
238 | deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
239 | deltas = torch.cat([deltas, 1e10 * torch.ones_like(deltas[..., :1])],
240 | dim=-1)
241 | alphas = 1 - torch.exp(-deltas * self.density_scale *
242 | density_outputs["sigma"].squeeze(-1)) # [N, T+t]
243 | alphas_shifted = torch.cat(
244 | [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15],
245 | dim=-1) # [N, T+t+1]
246 | weights = (alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]
247 | ) # [N, T+t]
248 |
249 | mask_rgb = weights > 1e-4 # hard coded
250 | mask_semantics = weights > 1e-4 # hard coded
251 |
252 | dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
253 | for k, v in density_outputs.items():
254 | density_outputs[k] = v.view(-1, v.shape[-1])
255 |
256 | rgbs = self.color(xyzs.reshape(-1, 3),
257 | dirs.reshape(-1, 3),
258 | mask=mask_rgb.reshape(-1),
259 | **density_outputs)
260 | rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
261 |
262 | local_semantics = self.semantics(xyzs.reshape(-1, 3),
263 | dirs.reshape(-1, 3),
264 | mask=mask_semantics.reshape(-1),
265 | **density_outputs)
266 | local_semantics = local_semantics.view(
267 | N, -1, self.num_semantic_classes) # [N, T+t, num_semantic_classes]
268 |
269 | # calculate weight_sum (mask)
270 | weights_semantics = weights.clone().detach()
271 | weights[torch.logical_not(mask_rgb)] = 0.0
272 |
273 | # calculate depth
274 | ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
275 | # depth = torch.sum(weights * ori_z_vals, dim=-1)
276 | depth = torch.sum(weights * z_vals, dim=-1)
277 | depth = depth / direction_norms
278 |
279 | # calculate color
280 |
281 | image = torch.sum(weights.unsqueeze(-1) * rgbs,
282 | dim=-2) # [N, 3], in [0, 1]
283 | weights_semantics[torch.logical_not(mask_semantics)] = 0.0
284 | semantics = torch.sum(weights_semantics.unsqueeze(-1) * local_semantics,
285 | dim=-2) # [N, C], in [0, 1]
286 |
287 | # mix background color
288 | if bg_color is None:
289 | bg_color = 1
290 |
291 | image = image.view(*prefix, 3)
292 | depth = depth.view(*prefix)
293 | semantics = semantics.view(*prefix, self.num_semantic_classes)
294 |
295 | return {
296 | "depth": depth,
297 | "image": image,
298 | "semantics": semantics,
299 | }
300 |
301 | def render(self,
302 | rays_o,
303 | rays_d,
304 | direction_norms,
305 | staged=False,
306 | max_ray_batch=4096,
307 | bg_color=None,
308 | perturb=False,
309 | epoch=None,
310 | **kwargs):
311 | # rays_o, rays_d: [B, N, 3], assumes B == 1
312 | # direction_norms: [B, N, 1]
313 | # return: pred_rgb: [B, N, 3]
314 |
315 | _run = self.run
316 |
317 | B, N = rays_o.shape[:2]
318 | device = rays_o.device
319 |
320 | # never stage when cuda_ray
321 | if staged and not self.cuda_ray:
322 | depth = torch.empty((B, N), device=device)
323 | image = torch.empty((B, N, 3), device=device)
324 | semantics = torch.empty((B, N, self.num_semantic_classes),
325 | device=device)
326 |
327 | for b in range(B):
328 | head = 0
329 | while head < N:
330 | tail = min(head + max_ray_batch, N)
331 | results_ = _run(rays_o[b:b + 1, head:tail],
332 | rays_d[b:b + 1, head:tail],
333 | direction_norms=direction_norms[b:b + 1,
334 | head:tail],
335 | bg_color=bg_color,
336 | perturb=perturb,
337 | epoch=epoch,
338 | **kwargs)
339 | depth[b:b + 1, head:tail] = results_["depth"]
340 | image[b:b + 1, head:tail] = results_["image"]
341 | semantics[b:b + 1, head:tail] = results_["semantics"]
342 | head += max_ray_batch
343 |
344 | results = {}
345 | results["depth"] = depth
346 | results["image"] = image
347 | results["semantics"] = semantics
348 |
349 | else:
350 | results = _run(rays_o,
351 | rays_d,
352 | direction_norms=direction_norms,
353 | bg_color=bg_color,
354 | perturb=perturb,
355 | epoch=epoch,
356 | **kwargs)
357 |
358 | return results
359 |
--------------------------------------------------------------------------------
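The compositing inside run() is the standard volume-rendering quadrature: per-sample opacity alpha_i = 1 - exp(-sigma_i * delta_i), accumulated transmittance as the running product of (1 - alpha), weights w_i = alpha_i * T_i, and image / semantics / depth as weight-averaged per-sample quantities. A tiny self-contained sketch of that math for a single ray (toy values, same formulas as above):

import torch

sigma = torch.tensor([0.0, 0.5, 2.0, 4.0])   # toy densities along one ray
delta = torch.tensor([0.1, 0.1, 0.1, 1e10])  # step sizes; last one is "infinite"
rgb = torch.rand(4, 3)                       # toy per-sample colors
z = torch.tensor([0.2, 0.3, 0.4, 0.5])       # sample depths

alpha = 1.0 - torch.exp(-sigma * delta)                                        # [T]
trans = torch.cumprod(torch.cat([torch.ones(1), 1.0 - alpha + 1e-15]), dim=0)[:-1]
weights = alpha * trans                                                        # [T]

color = (weights[:, None] * rgb).sum(dim=0)  # composited pixel color
depth = (weights * z).sum()                  # expected ray depth
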
/nr4seg/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .deeplabv3 import *
2 |
--------------------------------------------------------------------------------
/nr4seg/network/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 |
3 | from torch import nn
4 |
5 |
6 | class DeepLabV3(nn.Module):
7 |
8 | def __init__(self, cfg_model):
9 | super().__init__()
10 | self._model = torchvision.models.segmentation.deeplabv3_resnet101(
11 | pretrained=cfg_model["pretrained"],
12 | pretrained_backbone=cfg_model["pretrained_backbone"],
13 | progress=True,
14 | num_classes=cfg_model["num_classes"],
15 | aux_loss=None,
16 | )
17 |
18 | def forward(self, data):
19 | return self._model(data)
20 |
--------------------------------------------------------------------------------
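A quick usage sketch for the DeepLabV3 wrapper above (config values are illustrative; the underlying torchvision model returns a dict whose "out" entry holds per-pixel logits at the input resolution):

import torch
from nr4seg.network import DeepLabV3

cfg_model = {
    "pretrained": False,          # illustrative config values
    "pretrained_backbone": True,
    "num_classes": 41,
}
model = DeepLabV3(cfg_model).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 240, 320))
print(out["out"].shape)  # torch.Size([1, 41, 240, 320])
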
/nr4seg/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flatten_dict import *
2 | from .get_logger import *
3 | from .loading import *
4 |
--------------------------------------------------------------------------------
/nr4seg/utils/flatten_dict.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 |
3 | __all__ = ["flatten_dict"]
4 |
5 |
6 | def flatten_list(ls, parent_key="", sep="_"):
7 |     # Flatten a list of dicts into (key, value) pairs, indexing the entries by position.
8 |     items = []
9 |     for idx, element in enumerate(ls):
10 |         new_key = parent_key + sep + str(idx) if parent_key else str(idx)
11 |         items.extend(flatten_dict(element, new_key, sep=sep).items())
12 |     return items
13 |
14 |
15 | def flatten_dict(d, parent_key="", sep="_"):
16 |     items = []
17 |     for k, v in d.items():
18 |         new_key = parent_key + sep + k if parent_key else k
19 |         if isinstance(v, collections.abc.MutableMapping):
20 |             items.extend(flatten_dict(v, new_key, sep=sep).items())
21 |         else:
22 |             if isinstance(v, list) and v and isinstance(v[0], dict):
23 |                 items.extend(flatten_list(v, new_key, sep=sep))
24 |                 continue
25 |             items.append((new_key, v))
26 |     return dict(items)
27 |
--------------------------------------------------------------------------------
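A short example of what flatten_dict produces for a nested experiment config (keys are illustrative, not taken from the cfg files):

from nr4seg.utils import flatten_dict

cfg = {"trainer": {"max_epochs": 50, "gpus": 1}, "model": {"lr": 1e-5}}
print(flatten_dict(cfg))
# {'trainer_max_epochs': 50, 'trainer_gpus': 1, 'model_lr': 1e-05}
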
/nr4seg/utils/get_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
4 | from pytorch_lightning.loggers.neptune import NeptuneLogger
5 |
6 | from nr4seg.utils import flatten_dict
7 |
8 | __all__ = ["get_neptune_logger", "get_tensorboard_logger", "get_wandb_logger"]
9 |
10 |
11 | def log_important_params(exp):
12 | dic = {}
13 | dic = flatten_dict(exp)
14 | return dic
15 |
16 |
17 | def get_neptune_logger(exp, env, exp_p, env_p, project_name=""):
18 | params = log_important_params(exp)
19 |
20 | name_full = exp["general"]["name"]
21 | name_short = "__".join(name_full.split("/")[-2:])
22 |
23 | return NeptuneLogger(
24 | api_key=os.environ["NEPTUNE_API_TOKEN"],
25 | project=project_name,
26 | name=name_short,
27 | tags=[
28 | os.environ["ENV_WORKSTATION_NAME"],
29 | name_full.split("/")[-2],
30 | name_full.split("/")[-1],
31 | ],
32 | )
33 |
34 |
35 | def get_wandb_logger(exp, env, exp_p, env_p, project_name, save_dir):
36 | params = log_important_params(exp)
37 | name_full = exp["general"]["name"]
38 | name_short = "__".join(name_full.split("/")[-2:])
39 | return WandbLogger(
40 | name=name_short,
41 | project=project_name,
42 | save_dir=save_dir,
43 | )
44 |
45 |
46 | def get_tensorboard_logger(exp, env, exp_p, env_p):
47 | params = log_important_params(exp)
48 | return TensorBoardLogger(
49 | save_dir=exp["general"]["name"],
50 | name="tensorboard",
51 | default_hp_metric=params,
52 | )
53 |
--------------------------------------------------------------------------------
/nr4seg/utils/loading.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | __all__ = ["file_path", "load_yaml"]
5 |
6 |
7 | def file_path(string):
8 | if os.path.isfile(string):
9 | return string
10 | else:
11 | raise FileNotFoundError(string)
12 |
13 |
14 | def load_yaml(path):
15 | with open(path) as file:
16 | res = yaml.load(file, Loader=yaml.FullLoader)
17 | return res
18 |
--------------------------------------------------------------------------------
/nr4seg/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from sklearn.metrics import confusion_matrix
5 |
6 |
7 | def nanmean(data, **args):
8 | # Ignore NaN entries, e.g. classes that never occur (such as the unlabeled/background class).
9 | return np.ma.masked_array(data, np.isnan(data)).mean(**args)
10 | # In np.ma.masked_array(data, np.isnan(data)), NaN elements of data are marked invalid and are ignored when computing the mean.
11 |
12 |
13 | class SemanticsMeter:
14 |
15 | def __init__(self, number_classes):
16 | self.conf_mat = None
17 | self.number_classes = number_classes
18 |
19 | def clear(self):
20 | self.conf_mat = None
21 |
22 | def prepare_inputs(self, *inputs):
23 | outputs = []
24 | for i, inp in enumerate(inputs):
25 | if torch.is_tensor(inp):
26 | inp = inp.detach().cpu().numpy()
27 | outputs.append(inp)
28 |
29 | return outputs
30 |
31 | def update(self, preds, truths):
32 | preds, truths = self.prepare_inputs(
33 | preds, truths) # [B, N, 3] or [B, H, W, 3], range[0, 1]
34 | preds = preds.flatten()
35 | truths = truths.flatten()
36 | valid_pix_ids = truths != -1
37 | preds = preds[valid_pix_ids]
38 | truths = truths[valid_pix_ids]
39 | conf_mat_current = confusion_matrix(truths,
40 | preds,
41 | labels=list(
42 | range(self.number_classes)))
43 | if self.conf_mat is None:
44 | self.conf_mat = conf_mat_current
45 | else:
46 | self.conf_mat += conf_mat_current
47 |
48 | def measure(self):
49 | conf_mat = self.conf_mat
50 | norm_conf_mat = np.transpose(
51 | np.transpose(conf_mat) / conf_mat.astype(np.float).sum(axis=1))
52 |
53 | missing_class_mask = np.isnan(norm_conf_mat.sum(
54 | 1)) # missing class will have NaN at corresponding class
55 | existing_class_mask = ~missing_class_mask
56 |
57 | class_average_accuracy = nanmean(np.diagonal(norm_conf_mat))
58 | total_accuracy = np.sum(np.diagonal(conf_mat)) / np.sum(conf_mat)
59 | ious = np.zeros(self.number_classes)
60 | for class_id in range(self.number_classes):
61 | ious[class_id] = conf_mat[class_id, class_id] / (
62 | np.sum(conf_mat[class_id, :]) + np.sum(conf_mat[:, class_id]) -
63 | conf_mat[class_id, class_id])
64 | miou_valid_class = np.mean(ious[existing_class_mask])
65 | return miou_valid_class, total_accuracy, class_average_accuracy
66 |
--------------------------------------------------------------------------------
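A minimal usage sketch for SemanticsMeter (toy tensors; pixels labeled -1 are treated as invalid and skipped, exactly as in update() above):

import torch
from nr4seg.utils.metrics import SemanticsMeter

meter = SemanticsMeter(number_classes=3)
preds = torch.tensor([0, 1, 2, 1])
truths = torch.tensor([0, 1, 2, -1])  # the last pixel is ignored
meter.update(preds, truths)
miou, total_acc, class_avg_acc = meter.measure()
print(miou, total_acc, class_avg_acc)  # 1.0 1.0 1.0 for this toy input
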
/nr4seg/visualizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .colormaps import *
2 | from .visualizer import *
3 |
--------------------------------------------------------------------------------
/nr4seg/visualizer/colormaps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from collections import OrderedDict
4 | from matplotlib import cm
5 |
6 | ORDERED_DICT = OrderedDict([
7 | ("unlabeled", (0, 0, 0)),
8 | ("wall", (174, 199, 232)),
9 | ("floor", (152, 223, 138)),
10 | ("cabinet", (31, 119, 180)),
11 | ("bed", (255, 187, 120)),
12 | ("chair", (188, 189, 34)),
13 | ("sofa", (140, 86, 75)),
14 | ("table", (255, 152, 150)),
15 | ("door", (214, 39, 40)),
16 | ("window", (197, 176, 213)),
17 | ("bookshelf", (148, 103, 189)),
18 | ("picture", (196, 156, 148)),
19 | ("counter", (23, 190, 207)),
20 | ("blinds", (178, 76, 76)),
21 | ("desk", (247, 182, 210)),
22 | ("shelves", (66, 188, 102)),
23 | ("curtain", (219, 219, 141)),
24 | ("dresser", (140, 57, 197)),
25 | ("pillow", (202, 185, 52)),
26 | ("mirror", (51, 176, 203)),
27 | ("floormat", (200, 54, 131)),
28 | ("clothes", (92, 193, 61)),
29 | ("ceiling", (78, 71, 183)),
30 | ("books", (172, 114, 82)),
31 | ("refrigerator", (255, 127, 14)),
32 | ("television", (91, 163, 138)),
33 | ("paper", (153, 98, 156)),
34 | ("towel", (140, 153, 101)),
35 | ("showercurtain", (158, 218, 229)),
36 | ("box", (100, 125, 154)),
37 | ("whiteboard", (178, 127, 135)),
38 | ("person", (120, 185, 128)),
39 | ("nightstand", (146, 111, 194)),
40 | ("toilet", (44, 160, 44)),
41 | ("sink", (112, 128, 144)),
42 | ("lamp", (96, 207, 209)),
43 | ("bathtub", (227, 119, 194)),
44 | ("bag", (213, 92, 176)),
45 | ("otherstructure", (94, 106, 211)),
46 | ("otherfurniture", (82, 84, 163)),
47 | ("otherprop", (100, 85, 144)),
48 | ])
49 | SCANNET_CLASSES = [i for i, v in enumerate(ORDERED_DICT.values())]
50 | SCANNET_COLORS = [v for i, v in enumerate(ORDERED_DICT.values())]
51 |
52 | jet = cm.get_cmap("jet")
53 | BINARY_COLORS = (np.stack([jet(v) for v in np.linspace(0, 1, 2)]) * 255).astype(
54 | np.uint8)
55 |
--------------------------------------------------------------------------------
/nr4seg/visualizer/visualizer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 | import os
5 | import skimage
6 | import wandb
7 |
8 | from matplotlib.backends.backend_agg import FigureCanvasAgg
9 | from PIL import Image, ImageDraw
10 |
11 | from nr4seg.visualizer import BINARY_COLORS, SCANNET_CLASSES, SCANNET_COLORS
12 |
13 | __all__ = ["Visualizer"]
14 |
15 |
16 | def get_img_from_fig(fig, dpi=180):
17 | """
18 | Converts matplot figure to np.array
19 | """
20 |
21 | fig.set_dpi(dpi)
22 | canvas = FigureCanvasAgg(fig)
23 | # Retrieve a view on the renderer buffer
24 | canvas.draw()
25 | buf = canvas.buffer_rgba()
26 | # convert to a np array
27 | buf = np.asarray(buf)
28 | buf = Image.fromarray(buf)
29 | buf = buf.convert("RGB")
30 | return buf
31 |
32 |
33 | def image_functionality(func):
34 | """
35 | Decorator to allow for logging functionality.
36 | Should be added to all plotting functions of visualizer.
37 | The plot function has to return a np.uint8 image array.
38 |
39 | @image_functionality
40 | def plot_segmentation(self, seg, **kwargs):
41 |
42 | return np.zeros((H, W, 3), dtype=np.uint8)
43 |
44 | not_log [optional, bool, default: false] : the decorator is ignored if set to true
45 | epoch [optional, int, default: visualizer.epoch ] : overwrites visualizer, epoch used to log the image
46 | store [optional, bool, default: visualizer.store ] : overwrites visualizer, flag if image is stored to disk
47 | tag [optional, str, default: tag_not_defined ] : stores image as: {p_visu}/{tag}_epoch_{epoch}.png
48 | """
49 |
50 | def wrap(*args, **kwargs):
51 | img = func(*args, **kwargs)
52 |
53 | if not kwargs.get("not_log", False):
54 | log_exp = args[0]._pl_model.logger is not None
55 | tag = kwargs.get("tag", "tag_not_defined")
56 |
57 | if kwargs.get("store", None) is not None:
58 | store = kwargs["store"]
59 | else:
60 | store = args[0]._store
61 |
62 | if kwargs.get("epoch", None) is not None:
63 | epoch = kwargs["epoch"]
64 | else:
65 | epoch = args[0]._epoch
66 |
67 | # Store to disk
68 | if store:
69 | p = os.path.join(args[0]._p_visu, f"{tag}_epoch_{epoch}.png")
70 | imageio.imwrite(p, img)
71 |
72 | if log_exp:
73 | H, W, C = img.shape
74 | ds = cv2.resize(
75 | img,
76 | dsize=(int(W / 2), int(H / 2)),
77 | interpolation=cv2.INTER_CUBIC,
78 | )
79 | if args[0]._pl_model.logger is not None:
80 | args[0]._pl_model.logger.experiment.log(
81 | {tag: [wandb.Image(ds, caption=tag)]}, commit=False)
82 | return func(*args, **kwargs)
83 |
84 | return wrap
85 |
86 |
87 | class Visualizer:
88 |
89 | def __init__(self, p_visu, store, pl_model, epoch=0, num_classes=22):
90 | self._p_visu = p_visu
91 | self._pl_model = pl_model
92 | self._epoch = epoch
93 | self._store = store
94 |
95 | os.makedirs(os.path.join(self._p_visu, "train_vis"), exist_ok=True)
96 | os.makedirs(os.path.join(self._p_visu, "val_vis"), exist_ok=True)
97 | os.makedirs(os.path.join(self._p_visu, "test_vis"), exist_ok=True)
98 |
99 | @property
100 | def epoch(self):
101 | return self._epoch
102 |
103 | @epoch.setter
104 | def epoch(self, epoch):
105 | self._epoch = epoch
106 |
107 | @property
108 | def store(self):
109 | return self._store
110 |
111 | @store.setter
112 | def store(self, store):
113 | self._store = store
114 |
115 | @image_functionality
116 | def plot_segmentation(self, seg, **kwargs):
117 | try:
118 | seg = seg.clone().cpu().numpy()
119 | except:
120 | pass
121 |
122 | if seg.dtype == np.bool:
123 | col_map = BINARY_COLORS
124 | else:
125 | col_map = SCANNET_COLORS
126 |
127 | H, W = seg.shape[:2]
128 | img = np.zeros((H, W, 3), dtype=np.uint8)
129 | for i, color in enumerate(col_map):
130 | img[seg == i] = color[:3]
131 |
132 | return img
133 |
134 | @image_functionality
135 | def plot_image(self, img, **kwargs):
136 | """
137 | ----------
138 | img : CHW HWC accepts torch.tensor or numpy.array
139 | Range 0-1 or 0-255
140 | """
141 | try:
142 | img = img.clone().cpu().numpy()
143 | except:
144 | pass
145 |
146 | if img.shape[2] == 3:
147 | pass
148 | elif img.shape[0] == 3:
149 | img = np.moveaxis(img, [0, 1, 2], [2, 0, 1])
150 | else:
151 | raise Exception("Invalid Shape")
152 | if img.max() <= 1:
153 | img = img * 255
154 |
155 | img = np.uint8(img)
156 | return img
157 |
158 | @image_functionality
159 | def plot_detectron(
160 | self,
161 | img,
162 | label,
163 | text_off=False,
164 | alpha=0.5,
165 | draw_bound=True,
166 | shift=2.5,
167 | font_size=12,
168 | **kwargs,
169 | ):
170 | """
171 | ----------
172 | img : CHW HWC accepts torch.tensor or numpy.array
173 | Range 0-1 or 0-255
174 | label: HW accepts torch.tensor or numpy.array
175 | """
176 |
177 | img = self.plot_image(img, not_log=True)
178 | try:
179 | label = label.clone().cpu().numpy()
180 | except:
181 | pass
182 | label = label.astype(np.long)
183 |
184 | H, W, C = img.shape
185 | uni = np.unique(label)
186 | overlay = np.zeros_like(img)
187 |
188 | centers = []
189 | for u in uni:
190 | m = label == u
191 | col = SCANNET_COLORS[u]
192 | overlay[m] = col
193 | labels_mask = skimage.measure.label(m)
194 | regions = skimage.measure.regionprops(labels_mask)
195 | regions.sort(key=lambda x: x.area, reverse=True)
196 | cen = np.mean(regions[0].coords, axis=0).astype(np.uint32)[::-1]
197 |
198 | centers.append((SCANNET_CLASSES[u], cen))
199 |
200 | back = np.zeros((H, W, 4))
201 | back[:, :, :3] = img
202 | back[:, :, 3] = 255
203 | fore = np.zeros((H, W, 4))
204 | fore[:, :, :3] = overlay
205 | fore[:, :, 3] = alpha * 255
206 | img_new = Image.alpha_composite(Image.fromarray(np.uint8(back)),
207 | Image.fromarray(np.uint8(fore)))
208 | draw = ImageDraw.Draw(img_new)
209 |
210 | if not text_off:
211 |
212 | for i in centers:
213 | pose = i[1]
214 | pose[0] -= len(str(i[0])) * shift
215 | pose[1] -= font_size / 2
216 | draw.text(tuple(pose), str(i[0]), fill=(255, 255, 255, 128))
217 |
218 | img_new = img_new.convert("RGB")
219 | mask = skimage.segmentation.mark_boundaries(np.array(img_new),
220 | label,
221 | color=(255, 255, 255))
222 | mask = mask.sum(axis=2)
223 | m = mask == mask.max()
224 | img_new = np.array(img_new)
225 | if draw_bound:
226 | img_new[m] = (255, 255, 255)
227 | return np.uint8(img_new)
228 |
--------------------------------------------------------------------------------
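A sketch of how the plotting methods and the decorator's keyword arguments are typically used (the pl_model argument is normally the LightningModule; here a stub with logger=None disables the W&B logging branch):

import numpy as np
from nr4seg.visualizer import Visualizer

class _StubModel:
    logger = None  # no experiment logger attached

vis = Visualizer(p_visu="/tmp/visu", store=True, pl_model=_StubModel(), epoch=0)
seg = np.random.randint(0, 41, size=(240, 320))
vis.plot_segmentation(seg, tag="val_seg", epoch=3, store=True)
# stored as /tmp/visu/val_seg_epoch_3.png
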
/preprocessing_scripts/scannet2nerf.py:
--------------------------------------------------------------------------------
1 | # Partly based on https://github.com/ashawkey/torch-ngp/blob/main/scripts/
2 | # colmap2nerf.py.
3 |
4 | import argparse
5 | import copy
6 | import glob
7 | import json
8 | import numpy as np
9 | import os
10 |
11 |
12 | def rotmat(a, b):
13 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
14 | v = np.cross(a, b)
15 | c = np.dot(a, b)
16 | s = np.linalg.norm(v)
17 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
18 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s**2 + 1e-10))
19 |
20 |
21 | def closest_point_2_lines(
22 | oa, da, ob, db
23 | ): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
24 | da = da / np.linalg.norm(da)
25 | db = db / np.linalg.norm(db)
26 | c = np.cross(da, db)
27 | denom = np.linalg.norm(c)**2
28 | t = ob - oa
29 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
30 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
31 | if ta > 0:
32 | ta = 0
33 | if tb > 0:
34 | tb = 0
35 | return (oa + ta * da + ob + tb * db) * 0.5, denom
36 |
37 |
38 | parser = argparse.ArgumentParser(
39 | description=
40 | "Run neural graphics primitives testbed with additional configuration & output options"
41 | )
42 |
43 | parser.add_argument("--scene_folder", type=str, default="")
44 | parser.add_argument(
45 | "--transform_train",
46 | type=str,
47 | default="",
48 | )
49 | parser.add_argument(
50 | "--transform_test",
51 | type=str,
52 | default="",
53 | )
54 |
55 | parser.add_argument("--interval", default=10, type=int, help="Sample Interval.")
56 |
57 | parser.add_argument("--room_center",
58 | action="store_true",
59 | help="Use room centers from the mesh file")
60 | args = parser.parse_args()
61 |
62 | scannet_folder = args.scene_folder
63 | json_train = args.transform_train
64 | json_test = args.transform_test
65 | interval = args.interval
66 | json_train_base = "transforms_train"
67 | json_test_base = "transforms_test"
68 | # Select the frames from the json. These are the frames for which we want to
69 | # find the actual transform.
70 | c2ws = []
71 | frame_names = []
72 | with open(json_train, "r") as f:
73 | transforms_train = json.load(f)
74 | # - Get filenames and concurrently load the c2w.
75 | for frame_idx, frame in enumerate(transforms_train["frames"]):
76 | if frame_idx % interval == 0:
77 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0]
78 | pose_name = os.path.join(scannet_folder, f"pose/{frame_name}.txt")
79 | c2w = np.loadtxt(pose_name)
80 | if np.any(np.isinf(c2w)):
81 | continue
82 | frame_names.append(frame_name)
83 | c2ws.append(c2w)
84 |
85 | c2ws_test = []
86 | frame_names_test = []
87 | with open(json_test, "r") as f:
88 | transforms_test = json.load(f)
89 | # - Get filenames and concurrently load the c2w.
90 | for frame_idx, frame in enumerate(transforms_test["frames"]):
91 | if frame_idx % interval == 0:
92 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0]
93 | pose_name = os.path.join(scannet_folder, f"pose/{frame_name}.txt")
94 | c2w = np.loadtxt(pose_name)
95 | if np.any(np.isinf(c2w)):
96 | continue
97 | frame_names_test.append(frame_name)
98 | c2ws_test.append(c2w)
99 |
100 | selected_transforms = copy.deepcopy(transforms_train)
101 | selected_transforms.pop("frames")
102 | selected_transforms["frames"] = []
103 | selected_transforms_test = copy.deepcopy(transforms_test)
104 | selected_transforms_test.pop("frames")
105 | selected_transforms_test["frames"] = []
106 |
107 | # Open the mesh file to retrieve the scene center.
108 | if args.room_center:
109 | mesh_files = glob.glob(os.path.join(scannet_folder, "*_vh_clean.ply"))
110 | assert len(mesh_files) == 1, (
111 | "Found no/more than 1 'vh_clean' mesh files in "
112 | f"{scannet_folder}.")
113 |     import open3d as o3d  # local import: open3d is only needed when --room_center is set
114 | mesh = o3d.io.read_triangle_mesh(mesh_files[0])
115 | max_coord_mesh = np.max(mesh.vertices, axis=0)
116 | min_coord_mesh = np.min(mesh.vertices, axis=0)
117 | room_center = (max_coord_mesh + min_coord_mesh) / 2.0
118 | else:
119 | room_center = np.zeros(3)
120 |
121 | up = np.zeros(3)
122 | print(f"length of c2ws: {len(c2ws)}")
123 | for c2w_idx in range(len(c2ws)):
124 | c2ws[c2w_idx][:3, 3] -= room_center
125 | c2ws[c2w_idx][0:3, 2] *= -1 # flip the y and z axis
126 | c2ws[c2w_idx][0:3, 1] *= -1
127 | c2ws[c2w_idx] = c2ws[c2w_idx][[1, 0, 2, 3], :] # swap y and z
128 | c2ws[c2w_idx][2, :] *= -1 # flip whole world upside down
129 | up += c2ws[c2w_idx][0:3, 1]
130 |
131 | for c2w_idx in range(len(c2ws_test)):
132 | c2ws_test[c2w_idx][:3, 3] -= room_center
133 | c2ws_test[c2w_idx][0:3, 2] *= -1 # flip the y and z axis
134 | c2ws_test[c2w_idx][0:3, 1] *= -1
135 | c2ws_test[c2w_idx] = c2ws_test[c2w_idx][[1, 0, 2, 3], :] # swap y and z
136 | c2ws_test[c2w_idx][2, :] *= -1 # flip whole world upside down
137 |
138 | print(f"up vector: {up}")
139 |
140 | nframes = len(c2ws)
141 | up = up / np.linalg.norm(up)
142 | print("up vector was", up)
143 | R = rotmat(up, [0, 0, 1]) # rotate up vector to [0,0,1]
144 | R = np.pad(R, [0, 1])
145 | R[-1, -1] = 1
146 |
147 | for c2w_idx in range(len(c2ws)):
148 | c2ws[c2w_idx] = np.matmul(R, c2ws[c2w_idx]) # rotate up to be the z axis
149 |
150 | for c2w_idx in range(len(c2ws_test)):
151 | c2ws_test[c2w_idx] = np.matmul(
152 | R, c2ws_test[c2w_idx]) # rotate up to be the z axis
153 |
154 | # find a central point they are all looking at
155 | if not args.room_center:
156 | print("computing center of attention...")
157 | totw = 0.0
158 | totp = np.array([0.0, 0.0, 0.0])
159 | for c2w_idx_1 in range(len(c2ws)):
160 | mf = c2ws[c2w_idx_1][0:3, :]
161 | for c2w_idx_2 in range(len(c2ws)):
162 | mg = c2ws[c2w_idx_2][0:3, :]
163 | p, w = closest_point_2_lines(mf[:, 3], mf[:, 2], mg[:, 3], mg[:, 2])
164 | if w > 0.01:
165 | totp += p * w
166 | totw += w
167 | totp /= totw
168 | print("room center was:")
169 | print(totp) # the cameras are looking at totp
170 | for c2w_idx in range(len(c2ws)):
171 | c2ws[c2w_idx][0:3, 3] -= totp
172 |
173 | for c2w_idx in range(len(c2ws_test)):
174 | c2ws_test[c2w_idx][0:3, 3] -= totp
175 |
176 | avglen = 0.0
177 | for c2w_idx in range(len(c2ws)):
178 | avglen += np.linalg.norm(c2ws[c2w_idx][0:3, 3])
179 | print(f"avglen:{avglen}")
180 | print(nframes)
181 | avglen /= nframes
182 |
183 | # This factor converts one meter to the unit of measure of the scene.
184 | # NOTE: This incorporates both the scaling previously done in this script
185 | # and the one previously done in `nerf_matrix_to_ngp`, which now no longer
186 | # scales the poses.
187 | one_m_to_scene_uom = 4.0 / avglen * 0.33
188 |
189 | print("avg camera distance from origin", avglen)
190 | for c2w_idx in range(len(c2ws)):
191 | c2ws[c2w_idx][0:3, 3] *= one_m_to_scene_uom # scale to "nerf sized"
192 | for c2w_idx in range(len(c2ws_test)):
193 | c2ws_test[c2w_idx][0:3, 3] *= one_m_to_scene_uom # scale to "nerf sized"
194 |
195 | store_dict = {}
196 | store_dict["avglen"] = avglen
197 | store_dict["up"] = up
198 | store_dict["totp"] = totp
199 | store_dict["totw"] = totw
200 | scene_name = os.path.basename(os.path.dirname(scannet_folder))
201 |
202 | curr_frame_name_idx = 0
203 | for frame_idx in range(len(transforms_train["frames"])):
204 | if curr_frame_name_idx == len(frame_names):
205 | break
206 | frame = transforms_train["frames"][frame_idx]
207 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0]
208 | if frame_name == frame_names[curr_frame_name_idx]:
209 | c2w = c2ws[curr_frame_name_idx]
210 | transforms_train["frames"][frame_idx]["transform_matrix"] = c2w.tolist()
211 | selected_transforms["frames"].append(
212 | transforms_train["frames"][frame_idx])
213 | curr_frame_name_idx += 1
214 | selected_transforms["one_m_to_scene_uom"] = one_m_to_scene_uom
215 |
216 | out_path = os.path.join(scannet_folder, f"{json_train_base}.json")
217 | with open(out_path, "w") as f:
218 | json.dump(selected_transforms, f, indent=4)
219 |
220 | curr_frame_name_idx = 0
221 | for frame_idx in range(len(transforms_test["frames"])):
222 | if curr_frame_name_idx == len(frame_names_test):
223 | break
224 | frame = transforms_test["frames"][frame_idx]
225 | frame_name = os.path.basename(frame["file_path"]).split(".jpg")[0]
226 | if frame_name == frame_names_test[curr_frame_name_idx]:
227 | c2w = c2ws_test[curr_frame_name_idx]
228 | transforms_test["frames"][frame_idx]["transform_matrix"] = c2w.tolist()
229 | selected_transforms_test["frames"].append(
230 | transforms_test["frames"][frame_idx])
231 | curr_frame_name_idx += 1
232 | selected_transforms_test["one_m_to_scene_uom"] = one_m_to_scene_uom
233 |
234 | out_path = os.path.join(scannet_folder, f"{json_test_base}.json")
235 | with open(out_path, "w") as f:
236 | json.dump(selected_transforms_test, f, indent=4)
237 |
--------------------------------------------------------------------------------
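rotmat(a, b) above builds the rotation that maps direction a onto direction b (Rodrigues' formula with the unnormalized cross product as axis); the script uses it to rotate the averaged camera "up" vector onto +z. A quick sanity check, assuming rotmat from this script is in scope:

import numpy as np

up = np.array([0.1, 0.2, 0.97])
R = rotmat(up, [0, 0, 1])
print(R @ (up / np.linalg.norm(up)))  # approximately [0, 0, 1]
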
/preprocessing_scripts/scannet2transform.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import csv
4 | import cv2
5 | import json
6 | import numpy as np
7 | import os
8 |
9 |
10 | def load_scannet_nyu40_mapping(path):
11 | mapping = {}
12 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile:
13 | tsvreader = csv.reader(tsvfile, delimiter='\t')
14 | for i, line in enumerate(tsvreader):
15 | if i == 0:
16 | continue
17 | scannet_id, nyu40id = int(line[0]), int(line[4])
18 | mapping[scannet_id] = nyu40id
19 | return mapping
20 |
21 |
22 | def load_scannet_nyu13_mapping(path):
23 | mapping = {}
24 | with open(os.path.join(path, 'scannetv2-labels.combined.tsv')) as tsvfile:
25 | tsvreader = csv.reader(tsvfile, delimiter='\t')
26 | for i, line in enumerate(tsvreader):
27 | if i == 0:
28 | continue
29 | scannet_id, nyu13_id = int(line[0]), int(line[5])
30 | mapping[scannet_id] = nyu13_id
31 | return mapping
32 |
33 |
34 | parser = argparse.ArgumentParser(
35 | description=
36 | "Run neural graphics primitives testbed with additional configuration & output options"
37 | )
38 |
39 | parser.add_argument("--scene_folder", type=str, default="")
40 | parser.add_argument("--scaled_image", action="store_true")
41 | parser.add_argument("--semantics", action="store_true")
42 | args = parser.parse_args()
43 | basedir = args.scene_folder
44 |
45 | print(f"processing folder: {basedir}")
46 |
47 | # Step for generating training images
48 | step = 1
49 |
50 | frame_ids = os.listdir(os.path.join(basedir, 'color'))
51 | frame_ids = [int(os.path.splitext(frame)[0]) for frame in frame_ids]
52 | frame_ids = sorted(frame_ids)
53 |
54 | intrinsic_file = os.path.join(basedir, "intrinsic/intrinsic_color.txt")
55 | intrinsic = np.loadtxt(intrinsic_file)
56 | print("intrinsic parameters:")
57 | print(intrinsic)
58 |
59 | imgs = []
60 | poses = []
61 |
62 | K_unscaled = copy.deepcopy(intrinsic)
63 |
64 | W_unscaled = 1296
65 | H_unscaled = 968
66 |
67 | W = 320
68 | H = 240
69 | K = copy.deepcopy(intrinsic)
70 | scale_x = 320. / 1296.
71 | scale_y = 240. / 968.
72 |
73 | K[0, 0] = K[0, 0] * scale_x # fx
74 | K[1, 1] = K[1, 1] * scale_y # fy
75 | K[0, 2] = K[0, 2] * scale_x # cx
76 | K[1, 2] = K[1, 2] * scale_y # cy
77 |
78 | if args.semantics:
79 | label_mapping_nyu = load_scannet_nyu40_mapping(basedir)
80 | os.makedirs(os.path.join(basedir, 'label_40'), exist_ok=True)
81 | os.makedirs(os.path.join(basedir, 'label_40_scaled'), exist_ok=True)
82 |
83 | train_ids = frame_ids[::step]
84 | test_id_step = 10
85 | test_ids = []
86 | for x in train_ids:
87 | if x + (test_id_step // 2) < len(frame_ids):
88 | test_ids.append(x + (test_id_step // 2))
89 | test_ids = test_ids[::
90 | test_id_step] # only use 10% of the test frames to speed up inference
91 |
92 | print(f"total number of training frames: {len(train_ids)}")
93 | print(f"total number of testing frames: {len(test_ids)}")
94 |
95 | os.makedirs(os.path.join(basedir, 'color_scaled'), exist_ok=True)
96 |
97 | for ids in (train_ids, test_ids):
98 | transform_json = {}
99 | transform_json["fl_x"] = K[0, 0]
100 | transform_json["fl_y"] = K[1, 1]
101 | transform_json["cx"] = K[0, 2]
102 | transform_json["cy"] = K[1, 2]
103 | transform_json["w"] = W
104 | transform_json["h"] = H
105 | transform_json["camera_angle_x"] = np.arctan2(W / 2, K[0, 0]) * 2
106 | transform_json["camera_angle_y"] = np.arctan2(H / 2, K[1, 1]) * 2
107 | transform_json["aabb_scale"] = 16
108 | transform_json["frames"] = []
109 |
110 | transform_json_unscaled = {}
111 | transform_json_unscaled["fl_x"] = K_unscaled[0, 0]
112 | transform_json_unscaled["fl_y"] = K_unscaled[1, 1]
113 | transform_json_unscaled["cx"] = K_unscaled[0, 2]
114 | transform_json_unscaled["cy"] = K_unscaled[1, 2]
115 | transform_json_unscaled["w"] = W_unscaled
116 | transform_json_unscaled["h"] = H_unscaled
117 | transform_json_unscaled["camera_angle_x"] = np.arctan2(
118 | W_unscaled / 2, K_unscaled[0, 0]) * 2
119 | transform_json_unscaled["camera_angle_y"] = np.arctan2(
120 | H_unscaled / 2, K_unscaled[1, 1]) * 2
121 | transform_json_unscaled["aabb_scale"] = 16
122 | transform_json_unscaled["frames"] = []
123 |
124 | for frame_id in ids:
125 | pose = np.loadtxt(os.path.join(basedir, 'pose', '%d.txt' % frame_id))
126 | pose = pose.reshape((4, 4))
127 | if np.any(np.isinf(pose)):
128 | continue
129 | if args.scaled_image:
130 | file_name_image = os.path.join(basedir, 'color',
131 | '%d.jpg' % frame_id)
132 | image = cv2.imread(
133 | file_name_image)[:, :, ::
134 | -1]  # change from BGR uint8 to RGB
135 | #image = cv2.copyMakeBorder(src=image, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=[0,0,0]) # pad 4 pixels to height so that images have aspect ratio of 4:3
136 | #assert image.shape[0] * 4==3 * image.shape[1]
137 | image = image / 255.0
138 | image = cv2.resize(image, (W, H), interpolation=cv2.INTER_AREA)
139 | image_save = cv2.cvtColor(image.astype(np.float32),
140 | cv2.COLOR_BGR2RGB)
141 | image_save = image_save * 255.0
142 | cv2.imwrite(
143 | os.path.join(basedir, 'color_scaled', '%d.jpg' % frame_id),
144 | image_save)
145 |
146 | if args.semantics:
147 | file_name_label = os.path.join(basedir, 'label-filt',
148 | '%d.png' % frame_id)
149 | semantic = cv2.imread(file_name_label, cv2.IMREAD_UNCHANGED)
150 | semantic_copy = copy.deepcopy(semantic)
151 | for scan_id, nyu_id in label_mapping_nyu.items():
152 | semantic[semantic_copy == scan_id] = nyu_id
153 | # semantic_scaled = cv2.copyMakeBorder(src=semantic, top=2, bottom=2, left=0, right=0, borderType=cv2.BORDER_CONSTANT, value=0)
154 | semantic_scaled = cv2.resize(semantic, (W, H),
155 | interpolation=cv2.INTER_NEAREST)
156 | semantic = semantic.astype(np.uint8)
157 | semantic_scaled = semantic_scaled.astype(np.uint8)
158 | cv2.imwrite(
159 | os.path.join(basedir, 'label_40_scaled',
160 | '%d.png' % frame_id), semantic_scaled)
161 | cv2.imwrite(
162 | os.path.join(basedir, 'label_40', '%d.png' % frame_id),
163 | semantic)
164 |
165 | json_image_dict = {}
166 | json_image_dict["file_path"] = os.path.join('color_scaled',
167 | '%d.jpg' % frame_id)
168 | if args.semantics:
169 | json_image_dict["label_path"] = os.path.join(
170 | 'label_40_scaled', '%d.png' % frame_id)
171 | json_image_dict["transform_matrix"] = pose.tolist()
172 | transform_json["frames"].append(json_image_dict)
173 |
174 | json_image_dict_unscaled = {}
175 | json_image_dict_unscaled["file_path"] = os.path.join(
176 | 'color', '%d.jpg' % frame_id)
177 | if args.semantics:
178 | json_image_dict_unscaled["label_path"] = os.path.join(
179 | 'label_40', '%d.png' % frame_id)
180 | json_image_dict_unscaled["transform_matrix"] = pose.tolist()
181 | transform_json_unscaled["frames"].append(json_image_dict_unscaled)
182 |
183 | if args.scaled_image:
184 | if ids == train_ids:
185 | file_name = 'transforms_train_scaled'
186 | else:
187 | file_name = 'transforms_test_scaled'
188 |
189 | if args.semantics:
190 | file_name += "_semantics_40_raw"
191 | file_name += ".json"
192 | out_file = open(os.path.join(basedir, file_name), "w")
193 | json.dump(transform_json, out_file, indent=4)
194 | out_file.close()
195 |
196 | else:
197 | if ids == train_ids:
198 | file_name = 'transforms_train'
199 | else:
200 | file_name = 'transforms_test'
201 | if args.semantics:
202 | file_name += "_semantics_40_raw"
203 | file_name += ".json"
204 | out_file = open(os.path.join(basedir, file_name), "w")
205 | json.dump(transform_json_unscaled, out_file, indent=4)
206 | out_file.close()
207 |
--------------------------------------------------------------------------------
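When the 1296x968 ScanNet color frames are resized to 320x240, the intrinsics are scaled linearly with the image dimensions, exactly as done above; the horizontal field of view is preserved because fx and W shrink by the same factor. A small worked example with illustrative intrinsics:

import numpy as np

K = np.array([[1170.0, 0.0, 648.0],
              [0.0, 1170.0, 484.0],
              [0.0, 0.0, 1.0]])  # illustrative ScanNet color intrinsics
scale_x, scale_y = 320.0 / 1296.0, 240.0 / 968.0

K_scaled = K.copy()
K_scaled[0, 0] *= scale_x  # fx
K_scaled[0, 2] *= scale_x  # cx
K_scaled[1, 1] *= scale_y  # fy
K_scaled[1, 2] *= scale_y  # cy

# camera_angle_x = 2 * arctan2(W / 2, fx) is identical before and after scaling:
print(2 * np.arctan2(1296 / 2, K[0, 0]), 2 * np.arctan2(320 / 2, K_scaled[0, 0]))
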
/preprocessing_scripts/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # color palette for nyu40 labels
4 | nyu40_colour_code = np.array([
5 | (0, 0, 0),
6 | (174, 199, 232), # wall
7 | (152, 223, 138), # floor
8 | (31, 119, 180), # cabinet
9 | (255, 187, 120), # bed
10 | (188, 189, 34), # chair
11 | (140, 86, 75), # sofa
12 | (255, 152, 150), # table
13 | (214, 39, 40), # door
14 | (197, 176, 213), # window
15 | (148, 103, 189), # bookshelf
16 | (196, 156, 148), # picture
17 | (23, 190, 207), # counter
18 | (178, 76, 76), # blinds
19 | (247, 182, 210), # desk
20 | (66, 188, 102), # shelves
21 | (219, 219, 141), # curtain
22 | (140, 57, 197), # dresser
23 | (202, 185, 52), # pillow
24 | (51, 176, 203), # mirror
25 | (200, 54, 131), # floor mat
26 | (92, 193, 61), # clothes
27 | (78, 71, 183), # ceiling
28 | (172, 114, 82), # books
29 | (255, 127, 14), # refrigerator
30 | (91, 163, 138), # tv
31 | (153, 98, 156), # paper
32 | (140, 153, 101), # towel
33 | (158, 218, 229), # shower curtain
34 | (100, 125, 154), # box
35 | (178, 127, 135), # white board
36 | (120, 185, 128), # person
37 | (146, 111, 194), # night stand
38 | (44, 160, 44), # toilet
39 | (112, 128, 144), # sink
40 | (96, 207, 209), # lamp
41 | (227, 119, 194), # bathtub
42 | (213, 92, 176), # bag
43 | (94, 106, 211), # other struct
44 | (82, 84, 163), # otherfurn
45 | (100, 85, 144), # other prop
46 | ]).astype(np.uint8)
47 |
48 | nyu13_colour_code = (
49 | np.array([
50 | [0, 0, 0],
51 | [0, 0, 1], # BED
52 | [0.9137, 0.3490, 0.1882], # BOOKS
53 | [0, 0.8549, 0], # CEILING
54 | [0.5843, 0, 0.9412], # CHAIR
55 | [0.8706, 0.9451, 0.0941], # FLOOR
56 | [1.0000, 0.8078, 0.8078], # FURNITURE
57 | [0, 0.8784, 0.8980], # OBJECTS
58 | [0.4157, 0.5333, 0.8000], # PAINTING
59 | [0.4588, 0.1137, 0.1608], # SOFA
60 | [0.9412, 0.1373, 0.9216], # TABLE
61 | [0, 0.6549, 0.6118], # TV
62 | [0.9765, 0.5451, 0], # WALL
63 | [0.8824, 0.8980, 0.7608], # WINDOW
64 | ]) * 255).astype(np.uint8)
65 |
66 | nyu40_to13 = {
67 | 0: 0,
68 | 1: 12,
69 | 2: 5,
70 | 3: 6,
71 | 4: 1,
72 | 5: 4,
73 | 6: 9,
74 | 7: 10,
75 | 8: 12,
76 | 9: 13,
77 | 10: 6,
78 | 11: 8,
79 | 12: 6,
80 | 13: 13,
81 | 14: 10,
82 | 15: 6,
83 | 16: 13,
84 | 17: 6,
85 | 18: 7,
86 | 19: 7,
87 | 20: 5,
88 | 21: 7,
89 | 22: 3,
90 | 23: 2,
91 | 24: 6,
92 | 25: 11,
93 | 26: 7,
94 | 27: 7,
95 | 28: 7,
96 | 29: 7,
97 | 30: 7,
98 | 31: 7,
99 | 32: 7,
100 | 33: 7,
101 | 34: 7,
102 | 35: 7,
103 | 36: 7,
104 | 37: 7,
105 | 38: 7,
106 | 39: 6,
107 | 40: 7,
108 | }
109 |
110 | nyu40_to_13 = (np.array([
111 | 0,
112 | 12,
113 | 5,
114 | 6,
115 | 1,
116 | 4,
117 | 9,
118 | 10,
119 | 12,
120 | 13,
121 | 6,
122 | 8,
123 | 6,
124 | 13,
125 | 10,
126 | 6,
127 | 13,
128 | 6,
129 | 7,
130 | 7,
131 | 5,
132 | 7,
133 | 3,
134 | 2,
135 | 6,
136 | 11,
137 | 7,
138 | 7,
139 | 7,
140 | 7,
141 | 7,
142 | 7,
143 | 7,
144 | 7,
145 | 7,
146 | 7,
147 | 7,
148 | 7,
149 | 7,
150 | 7,
151 | 7,
152 | ])).astype(np.uint8)
153 |
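154 | # Illustrative usage (not part of the original file), assuming `label_40` is an
155 | # HxW uint8 array of NYU40 ids in [0, 40]:
156 | #   label_13 = nyu40_to_13[label_40]       # remap NYU40 ids to NYU13 ids
157 | #   vis_rgb = nyu13_colour_code[label_13]  # HxWx3 uint8 colour visualisation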
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | addict==2.4.0
3 | aiohttp==3.8.1
4 | aiosignal==1.2.0
5 | anyio==3.6.1
6 | argon2-cffi-bindings==21.2.0
7 | argon2-cffi==21.3.0
8 | arrow==1.2.2
9 | astroid==2.11.5
10 | asttokens==2.0.5
11 | async-timeout==4.0.2
12 | attrs==21.4.0
13 | babel==2.10.1
14 | backcall==0.2.0
15 | beautifulsoup4==4.11.1
16 | black==22.3.0
17 | bleach==5.0.0
18 | boto3==1.23.1
19 | botocore==1.26.1
20 | bravado-core==5.17.0
21 | bravado==11.0.3
22 | brotlipy==0.7.0
23 | cachetools==5.1.0
24 | certifi==2021.10.8
25 | cffi==1.15.0
26 | charset-normalizer==2.0.4
27 | click==8.1.3
28 | commonmark==0.9.1
29 | cryptography==36.0.0
30 | cycler==0.11.0
31 | dearpygui==1.6.1
32 | debugpy==1.6.0
33 | decorator==5.1.1
34 | defusedxml==0.7.1
35 | deprecation==2.1.0
36 | dill==0.3.4
37 | docker-pycreds==0.4.0
38 | entrypoints==0.4
39 | executing==0.8.3
40 | fastjsonschema==2.15.3
41 | fonttools==4.33.3
42 | fqdn==1.5.1
43 | frozenlist==1.3.0
44 | fsspec==2022.3.0
45 | future==0.18.2
46 | gitdb==4.0.9
47 | gitpython==3.1.27
48 | google-auth-oauthlib==0.4.6
49 | google-auth==2.6.6
50 | grpcio==1.46.1
51 | humanfriendly==10.0
52 | idna==3.3
53 | imageio-ffmpeg==0.4.7
54 | imageio==2.17.0
55 | imgviz==1.5.0
56 | importlib-metadata==4.11.3
57 | iniconfig==1.1.1
58 | ipykernel==6.13.0
59 | ipython-genutils==0.2.0
60 | ipython==8.3.0
61 | ipywidgets==7.7.0
62 | isoduration==20.11.0
63 | isort==5.10.1
64 | jedi==0.18.1
65 | jinja2==3.1.2
66 | jmespath==1.0.0
67 | joblib==1.1.0
68 | json5==0.9.8
69 | jsonpointer==2.3
70 | jsonref==0.2
71 | jsonschema==4.5.1
72 | jupyter-client==7.3.1
73 | jupyter-core==4.10.0
74 | jupyter-packaging==0.12.0
75 | jupyter-server==1.17.0
76 | jupyterlab-pygments==0.2.2
77 | jupyterlab-server==2.14.0
78 | jupyterlab-widgets==1.1.0
79 | jupyterlab==3.4.2
80 | kiwisolver==1.4.2
81 | lazy-object-proxy==1.7.1
82 | markdown==3.3.7
83 | markupsafe==2.1.1
84 | matplotlib-inline==0.1.3
85 | matplotlib==3.5.1
86 | mccabe==0.7.0
87 | mistune==0.8.4
88 | monotonic==1.6
89 | msgpack==1.0.3
90 | multidict==6.0.2
91 | mypy-extensions==0.4.3
92 | nbclassic==0.3.7
93 | nbclient==0.6.3
94 | nbconvert==6.5.0
95 | nbformat==5.4.0
96 | neptune-client==0.16.2
97 | nest-asyncio==1.5.5
98 | networkx==2.8
99 | ninja==1.10.2.3
100 | notebook-shim==0.1.0
101 | notebook==6.4.11
102 | numpy==1.23.0
103 | oauthlib==3.2.0
104 | open3d==0.15.2
105 | opencv-python==4.5.5.64
106 | packaging==21.3
107 | pandas==1.4.2
108 | pandocfilters==1.5.0
109 | parso==0.8.3
110 | pathspec==0.9.0
111 | pathtools==0.1.2
112 | pexpect==4.8.0
113 | pickleshare==0.7.5
114 | pillow==9.0.1
115 | platformdirs==2.5.2
116 | pluggy==1.0.0
117 | prometheus-client==0.14.1
118 | promise==2.3
119 | prompt-toolkit==3.0.29
120 | protobuf==3.20.1
121 | psutil==5.9.0
122 | ptyprocess==0.7.0
123 | pure-eval==0.2.2
124 | py==1.11.0
125 | pyasn1-modules==0.2.8
126 | pyasn1==0.4.8
127 | pycparser==2.21
128 | pydeprecate==0.3.2
129 | pygments==2.12.0
130 | pyjwt==2.4.0
131 | pylint==2.13.9
132 | pymcubes==0.1.2
133 | pyopenssl==22.0.0
134 | pyparsing==3.0.8
135 | pyquaternion==0.9.9
136 | pyrsistent==0.18.1
137 | pysdf==0.1.8
138 | pysocks==1.7.1
139 | pytest==7.1.2
140 | python-dateutil==2.8.2
141 | pytorch-lightning==1.6.3
142 | pytz==2022.1
143 | pywavelets==1.3.0
144 | pyyaml==6.0
145 | pyzmq==22.3.0
146 | requests-oauthlib==1.3.1
147 | requests==2.27.1
148 | rfc3339-validator==0.1.4
149 | rfc3987==1.3.8
150 | rich==12.3.0
151 | rsa==4.8
152 | s3transfer==0.5.2
153 | scikit-image==0.19.2
154 | scikit-learn==1.0.2
155 | scipy==1.8.0
156 | send2trash==1.8.0
157 | sentry-sdk==1.5.12
158 | setproctitle==1.2.3
159 | setuptools==61.2.0
160 | shortuuid==1.0.9
161 | simplejson==3.17.6
162 | six==1.16.0
163 | sklearn==0.0
164 | smmap==5.0.0
165 | sniffio==1.2.0
166 | soupsieve==2.3.2.post1
167 | stack-data==0.2.0
168 | swagger-spec-validator==2.7.4
169 | tensorboard-data-server==0.6.1
170 | tensorboard-plugin-wit==1.8.1
171 | tensorboard==2.9.0
172 | tensorboardx==2.5
173 | terminado==0.15.0
174 | threadpoolctl==3.1.0
175 | tifffile==2022.5.4
176 | tinycss2==1.1.1
177 | tomli==2.0.1
178 | tomlkit==0.10.2
179 | torch-ema==0.3
180 | torch-tb-profiler==0.4.0
181 | torch==1.11.0
182 | torchaudio==0.11.0
183 | torchmetrics==0.8.2
184 | torchvision==0.12.0
185 | tornado==6.1
186 | tqdm==4.64.0
187 | traitlets==5.2.1.post0
188 | trimesh==3.11.2
189 | typing-extensions==4.1.1
190 | uri-template==1.2.0
191 | urllib3==1.26.9
192 | wandb==0.12.16
193 | wcwidth==0.2.5
194 | webcolors==1.11.1
195 | webencodings==0.5.1
196 | websocket-client==1.3.2
197 | werkzeug==2.1.2
198 | wheel==0.37.1
199 | widgetsnbextension==3.6.0
200 | wrapt==1.14.1
201 | yarl==1.7.2
202 | zipp==3.8.0
--------------------------------------------------------------------------------
/run_scripts/multi_step.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python scripts/cl_deeplab.py "$@"
3 |
4 |
--------------------------------------------------------------------------------
/run_scripts/one_step_finetune_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | name=one_step_finetune_nerf
3 | prev_exp_name=one_step_nerf_only
4 | declare -a Scenes=("s00" "s10" "s20" "s30" "s40" "s50" "s60" "s70" "s80" "s90")
5 | for i in "${!Scenes[@]}"; do
6 | python scripts/train_finetune.py --exp cfg/exp/one_step_finetune_nerf/${Scenes[i]}_lr1e-5.yml --project_name $name --prev_exp_name $prev_exp_name
7 | done
--------------------------------------------------------------------------------
/run_scripts/one_step_joint_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | name=one_step_joint
3 | declare -a Scenes=("s00" "s10" "s20" "s30" "s40" "s50" "s60" "s70" "s80" "s90")
4 | for i in "${!Scenes[@]}"; do
5 | python scripts/train_joint.py --exp cfg/exp/one_step_joint/${Scenes[i]}_lr1e-5.yml --exp_name $name --project_name $name --nerf_train_epoch 10 --joint_train_epoch 50
6 | done
--------------------------------------------------------------------------------
/run_scripts/one_step_nerf_only_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | name=one_step_nerf_only
3 | declare -a Scenes=("s00" "s10" "s20" "s30" "s40" "s50" "s60" "s70" "s80" "s90")
4 | for i in "${!Scenes[@]}"; do
5 | python scripts/train_joint.py --exp cfg/exp/one_step_joint/${Scenes[i]}_lr1e-5.yml --exp_name $name --project_name $name --nerf_train_epoch 60 --joint_train_epoch 0
6 | done
--------------------------------------------------------------------------------
/run_scripts/preprocess_scannet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | root_dir="data/scannet/scans"
3 | for i in $(ls -d $root_dir/*/);
4 | do
5 | echo ${i};
6 | python preprocessing_scripts/scannet2transform.py --scene_folder $i --scaled_image --semantics
7 | python preprocessing_scripts/scannet2nerf.py --scene_folder $i --transform_train $i/transforms_train_scaled_semantics_40_raw.json \
8 | --transform_test $i/transforms_test_scaled_semantics_40_raw.json --interval 10
9 | done
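10 | # Illustrative expansion (not part of the original script) of one loop iteration,
11 | # assuming a scene folder data/scannet/scans/scene0000_00/:
12 | #   python preprocessing_scripts/scannet2transform.py --scene_folder data/scannet/scans/scene0000_00/ --scaled_image --semantics
13 | #   python preprocessing_scripts/scannet2nerf.py --scene_folder data/scannet/scans/scene0000_00/ \
14 | #     --transform_train data/scannet/scans/scene0000_00/transforms_train_scaled_semantics_40_raw.json \
15 | #     --transform_test data/scannet/scans/scene0000_00/transforms_test_scaled_semantics_40_raw.json --interval 10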
--------------------------------------------------------------------------------
/run_scripts/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python scripts/pretrain.py "$@"
3 |
4 |
--------------------------------------------------------------------------------
/scripts/cl_deeplab.py:
--------------------------------------------------------------------------------
1 | """Continual learning protocol for calling train_joint in multiple stages."""
2 | import argparse
3 | import os
4 | import sys
5 | import wandb
6 |
7 | from nr4seg import ROOT_DIR
8 | from nr4seg.utils import load_yaml
9 | from scripts.train_joint import train
10 |
11 | SCENE_ORDER = [
12 | "scene0000_00",
13 | "scene0001_00",
14 | "scene0002_00",
15 | "scene0003_00",
16 | "scene0004_00",
17 | "scene0005_00",
18 | "scene0006_00",
19 | "scene0007_00",
20 | "scene0008_00",
21 | "scene0009_00",
22 | ]
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument(
28 | "--exp",
29 | default="cfg/exp/finetune/deeplabv3_s0.yml",
30 | help=
31 | "Experiment yaml file path, relative to the repository root directory.",
32 | )
33 | parser.add_argument(
34 | "--exp_name",
35 | default="debug",
36 | help="Overall name of this continual learning experiment.",
37 | )
38 | parser.add_argument("--seed", default=123, type=int)
39 | parser.add_argument(
40 | "--fix_nerf",
41 | action="store_true",
42 | help="whether or not to fix nerf during joint training",
43 | )
44 |
45 | parser.add_argument("--project_name", default="test_one_by_one")
46 | parser.add_argument("--nerf_train_epoch", default=10, type=int)
47 |
48 | parser.add_argument("--joint_train_epoch", default=10, type=int)
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | def main():
54 | """Main function."""
55 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env",
56 | os.environ["ENV_WORKSTATION_NAME"] + ".yml")
57 | env = load_yaml(env_cfg_path)
58 | os.chdir(ROOT_DIR)
59 | args = parse_args()
60 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp)
61 | exp = load_yaml(exp_cfg_path)
62 | exp_name = args.exp_name
63 | exp["exp_name"] = exp_name
64 |
65 | prev_stage = ""
66 | stage = "init"
67 | exp["scenes"] = []
68 |
69 | for i, new_scene in enumerate(SCENE_ORDER):
70 | exp["scenes"].append(new_scene)
71 | prev_stage = stage
72 | stage = f"stage_{i}"
73 | exp["general"]["name"] = f"{exp_name}/{stage}"
74 |
75 | # train on new class
76 | exp["trainer"]["resume_from_checkpoint"] = False
77 | exp["trainer"]["load_from_checkpoint"] = True
78 | if i == 0:
79 | exp["general"]["load_pretrain"] = True
80 | old_model_path = exp["general"]["checkpoint_load"]
81 | else:
82 | exp["general"]["load_pretrain"] = False
83 | old_model_path = os.path.join("experiments", exp_name, prev_stage,
84 | "deeplab.ckpt")
85 |
86 | exp["general"]["checkpoint_load"] = old_model_path
87 |
88 | print(f"training on: {new_scene}")
89 |
90 | train(exp, env, exp_cfg_path, env_cfg_path, args)
91 | wandb.finish()
92 |
93 |
94 | if __name__ == "__main__": # pragma: no cover
95 | main()
96 | sys.exit(0)
97 |
--------------------------------------------------------------------------------
/scripts/eval_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # color palette for nyu40 labels
4 | nyu40_colour_code = np.array([
5 | (0, 0, 0),
6 | (174, 199, 232), # wall
7 | (152, 223, 138), # floor
8 | (31, 119, 180), # cabinet
9 | (255, 187, 120), # bed
10 | (188, 189, 34), # chair
11 | (140, 86, 75), # sofa
12 | (255, 152, 150), # table
13 | (214, 39, 40), # door
14 | (197, 176, 213), # window
15 | (148, 103, 189), # bookshelf
16 | (196, 156, 148), # picture
17 | (23, 190, 207), # counter
18 | (178, 76, 76), # blinds
19 | (247, 182, 210), # desk
20 | (66, 188, 102), # shelves
21 | (219, 219, 141), # curtain
22 | (140, 57, 197), # dresser
23 | (202, 185, 52), # pillow
24 | (51, 176, 203), # mirror
25 | (200, 54, 131), # floor mat
26 | (92, 193, 61), # clothes
27 | (78, 71, 183), # ceiling
28 | (172, 114, 82), # books
29 | (255, 127, 14), # refrigerator
30 | (91, 163, 138), # tv
31 | (153, 98, 156), # paper
32 | (140, 153, 101), # towel
33 | (158, 218, 229), # shower curtain
34 | (100, 125, 154), # box
35 | (178, 127, 135), # white board
36 | (120, 185, 128), # person
37 | (146, 111, 194), # night stand
38 | (44, 160, 44), # toilet
39 | (112, 128, 144), # sink
40 | (96, 207, 209), # lamp
41 | (227, 119, 194), # bathtub
42 | (213, 92, 176), # bag
43 | (94, 106, 211), # other struct
44 | (82, 84, 163), # otherfurn
45 | (100, 85, 144), # other prop
46 | ]).astype(np.uint8)
47 |
48 | nyu13_colour_code = (
49 | np.array([
50 | [0, 0, 0],
51 | [0, 0, 1], # BED
52 | [0.9137, 0.3490, 0.1882], # BOOKS
53 | [0, 0.8549, 0], # CEILING
54 | [0.5843, 0, 0.9412], # CHAIR
55 | [0.8706, 0.9451, 0.0941], # FLOOR
56 | [1.0000, 0.8078, 0.8078], # FURNITURE
57 | [0, 0.8784, 0.8980], # OBJECTS
58 | [0.4157, 0.5333, 0.8000], # PAINTING
59 | [0.4588, 0.1137, 0.1608], # SOFA
60 | [0.9412, 0.1373, 0.9216], # TABLE
61 | [0, 0.6549, 0.6118], # TV
62 | [0.9765, 0.5451, 0], # WALL
63 | [0.8824, 0.8980, 0.7608], # WINDOW
64 | ]) * 255).astype(np.uint8)
65 |
66 | nyu40_to13 = {
67 | 0: 0,
68 | 1: 12,
69 | 2: 5,
70 | 3: 6,
71 | 4: 1,
72 | 5: 4,
73 | 6: 9,
74 | 7: 10,
75 | 8: 12,
76 | 9: 13,
77 | 10: 6,
78 | 11: 8,
79 | 12: 6,
80 | 13: 13,
81 | 14: 10,
82 | 15: 6,
83 | 16: 13,
84 | 17: 6,
85 | 18: 7,
86 | 19: 7,
87 | 20: 5,
88 | 21: 7,
89 | 22: 3,
90 | 23: 2,
91 | 24: 6,
92 | 25: 11,
93 | 26: 7,
94 | 27: 7,
95 | 28: 7,
96 | 29: 7,
97 | 30: 7,
98 | 31: 7,
99 | 32: 7,
100 | 33: 7,
101 | 34: 7,
102 | 35: 7,
103 | 36: 7,
104 | 37: 7,
105 | 38: 7,
106 | 39: 6,
107 | 40: 7,
108 | }
109 |
110 | nyu40_to_13 = (np.array([
111 | 0,
112 | 12,
113 | 5,
114 | 6,
115 | 1,
116 | 4,
117 | 9,
118 | 10,
119 | 12,
120 | 13,
121 | 6,
122 | 8,
123 | 6,
124 | 13,
125 | 10,
126 | 6,
127 | 13,
128 | 6,
129 | 7,
130 | 7,
131 | 5,
132 | 7,
133 | 3,
134 | 2,
135 | 6,
136 | 11,
137 | 7,
138 | 7,
139 | 7,
140 | 7,
141 | 7,
142 | 7,
143 | 7,
144 | 7,
145 | 7,
146 | 7,
147 | 7,
148 | 7,
149 | 7,
150 | 7,
151 | 7,
152 | ])).astype(np.uint8)
153 |
--------------------------------------------------------------------------------
/scripts/pretrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import torch
5 |
6 | from pathlib import Path
7 | from pytorch_lightning import Trainer, seed_everything
8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9 | from pytorch_lightning.plugins import DDPPlugin
10 | from pytorch_lightning.profiler import AdvancedProfiler
11 |
12 | from nr4seg import ROOT_DIR
13 | from nr4seg.lightning import PretrainDataModule, SemanticsLightningNet
14 | from nr4seg.utils import flatten_dict, get_wandb_logger, load_yaml
15 |
16 |
17 | def train(exp, args) -> float:
18 | seed_everything(args.seed)
19 |
20 | ################################################## LOAD CONFIG ###################################################
21 | # ROOT_DIR is defined in nr4seg/__init__.py
22 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env",
23 | os.environ["ENV_WORKSTATION_NAME"] + ".yml")
24 | env = load_yaml(env_cfg_path)
25 | ####################################################################################################################
26 |
27 | ############################################ CREATE EXPERIMENT FOLDER ############################################
28 |
29 | model_path = os.path.join(env["results"], exp["general"]["name"])
30 | if exp["general"]["clean_up_folder_if_exists"]:
31 | shutil.rmtree(model_path, ignore_errors=True)
32 |
33 | # Create the directory
34 | Path(model_path).mkdir(parents=True, exist_ok=True)
35 |
36 | # Copy config files
37 | exp_cfg_fn = os.path.split(exp_cfg_path)[-1]
38 | env_cfg_fn = os.path.split(env_cfg_path)[-1]
39 | print(f"Copy {exp_cfg_path} to {model_path}/{exp_cfg_fn}")
40 | shutil.copy(exp_cfg_path, f"{model_path}/{exp_cfg_fn}")
41 | shutil.copy(env_cfg_path, f"{model_path}/{env_cfg_fn}")
42 | exp["general"]["name"] = model_path
43 | ####################################################################################################################
44 |
45 | ################################################# CREATE LOGGER ##################################################
46 |
47 | logger = get_wandb_logger(
48 | exp=exp,
49 | env=env,
50 | exp_p=exp_cfg_path,
51 | env_p=env_cfg_path,
52 | project_name=args.project_name,
53 | save_dir=model_path,
54 | )
55 | ex = flatten_dict(exp)
56 | logger.log_hyperparams(ex)
57 |
58 | ####################################################################################################################
59 |
60 | ########################################### CREATE NETWORK AND DATASET ###########################################
61 | model = SemanticsLightningNet(exp, env)
62 | datamodule = PretrainDataModule(env, exp["data_module"])
63 | ####################################################################################################################
64 |
65 | ################################################# TRAINER SETUP ##################################################
66 | # Callbacks
67 | lr_monitor = LearningRateMonitor(logging_interval="step")
68 | cb_ls = [lr_monitor]
69 |
70 | checkpoint_callback = ModelCheckpoint(
71 | dirpath=model_path,
72 | filename="best-{epoch:02d}-{step:06d}",
73 | verbose=True,
74 | monitor="val/mean_IoU",
75 | mode="max",
76 | save_last=True,
77 | save_top_k=1,
78 | )
79 | cb_ls.append(checkpoint_callback)
80 |
81 | # set gpus
82 | if (exp["trainer"]).get("gpus", -1) == -1:
83 | nr = torch.cuda.device_count()
84 | print(f"Set GPU Count for Trainer to {nr}!")
85 | for i in range(nr):
86 | print(f"Device {i}: ", torch.cuda.get_device_name(i))
87 | exp["trainer"]["gpus"] = nr
88 |
89 | # profiler
90 | if exp["trainer"].get("profiler", False):
91 | exp["trainer"]["profiler"] = AdvancedProfiler(
92 | output_filename=os.path.join(model_path, "profile.out"))
93 | else:
94 | exp["trainer"]["profiler"] = False
95 |
96 | # check if restore checkpoint
97 | if exp["trainer"]["resume_from_checkpoint"] is True:
98 | exp["trainer"]["resume_from_checkpoint"] = os.path.join(
99 | env["results"], exp["general"]["checkpoint_load"])
100 | else:
101 | del exp["trainer"]["resume_from_checkpoint"]
102 |
103 | trainer = Trainer(
104 | **exp["trainer"],
105 | plugins=DDPPlugin(find_unused_parameters=False),
106 | default_root_dir=model_path,
107 | callbacks=cb_ls,
108 | logger=logger,
109 | )
110 | ####################################################################################################################
111 |
112 | res = trainer.fit(model, datamodule=datamodule)
113 |
114 | return res
115 |
116 |
117 | if __name__ == "__main__":
118 | os.chdir(ROOT_DIR)
119 |
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument(
122 | "--exp",
123 | default="exp.yml",
124 | help=
125 | "Experiment yaml file path, relative to the repository root directory.",
126 | )
127 | parser.add_argument("--seed", default=123, type=int)
128 | parser.add_argument("--project_name", default="scannet_debug")
129 | args = parser.parse_args()
130 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp)
131 | exp = load_yaml(exp_cfg_path)
132 |
133 | train(exp, args)
134 |
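135 | # Illustrative sketch (not taken from the original cfg files) of the experiment
136 | # YAML keys that this script reads; values are placeholders:
137 | #   general:
138 | #     name: <output subfolder created under the env "results" directory>
139 | #     clean_up_folder_if_exists: false
140 | #     checkpoint_load: <ckpt path, used when trainer.resume_from_checkpoint is true>
141 | #   trainer:
142 | #     gpus: -1                     # -1 -> use all visible GPUs
143 | #     profiler: false
144 | #     resume_from_checkpoint: false
145 | #   data_module: <config dict passed to PretrainDataModule>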
--------------------------------------------------------------------------------
/scripts/train_finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 | import torch
5 | import shutil
6 |
7 | from pytorch_lightning import Trainer, seed_everything
8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9 | from pytorch_lightning.plugins import DDPPlugin
10 | from pytorch_lightning.profiler import AdvancedProfiler
11 |
12 | from nr4seg import ROOT_DIR
13 | from nr4seg.lightning import FineTuneDataModule, SemanticsLightningNet
14 | from nr4seg.utils import flatten_dict, get_wandb_logger, load_yaml
15 |
16 |
17 | def train(exp, env, exp_cfg_path, env_cfg_path, args) -> float:
18 | seed_everything(args.seed)
19 |
20 | ####################################################################################################################
21 |
22 | ############################################ CREATE EXPERIMENT FOLDER ############################################
23 | model_path = os.path.join(env["results"], exp["general"]["name"])
24 | if exp["general"]["clean_up_folder_if_exists"]:
25 | shutil.rmtree(model_path, ignore_errors=True)
26 |
27 | # Create the directory
28 | Path(model_path).mkdir(parents=True, exist_ok=True)
29 |
30 | # Copy config files
31 | exp_cfg_fn = os.path.split(exp_cfg_path)[-1]
32 | env_cfg_fn = os.path.split(env_cfg_path)[-1]
33 | print(f"Copy {exp_cfg_path} to {model_path}/{exp_cfg_fn}")
34 | shutil.copy(exp_cfg_path, f"{model_path}/{exp_cfg_fn}")
35 | shutil.copy(env_cfg_path, f"{model_path}/{env_cfg_fn}")
36 | exp["general"]["name"] = model_path
37 | ####################################################################################################################
38 |
39 | ################################################# CREATE LOGGER ##################################################
40 |
41 | logger = get_wandb_logger(
42 | exp=exp,
43 | env=env,
44 | exp_p=exp_cfg_path,
45 | env_p=env_cfg_path,
46 | project_name=args.project_name,
47 | save_dir=model_path,
48 | )
49 | ex = flatten_dict(exp)
50 | logger.log_hyperparams(ex)
51 |
52 | ####################################################################################################################
53 |
54 | ########################################### CREATE NETWORK AND DATASET ###########################################
55 | model = SemanticsLightningNet(exp, env)
56 | datamodule = FineTuneDataModule(exp, env, prev_exp_name=args.prev_exp_name)
57 | ####################################################################################################################
58 |
59 | ################################################# TRAINER SETUP ##################################################
60 | # Callbacks
61 | lr_monitor = LearningRateMonitor(logging_interval="step")
62 | cb_ls = [lr_monitor]
63 |
64 | checkpoint_callback = ModelCheckpoint(
65 | dirpath=model_path,
66 | save_last=True,
67 | save_top_k=1,
68 | )
69 | cb_ls.append(checkpoint_callback)
70 |
71 | # set gpus
72 | if (exp["trainer"]).get("gpus", -1) == -1:
73 | nr = torch.cuda.device_count()
74 | print(f"Set GPU Count for Trainer to {nr}!")
75 | for i in range(nr):
76 | print(f"Device {i}: ", torch.cuda.get_device_name(i))
77 | exp["trainer"]["gpus"] = nr
78 |
79 | # profiler
80 | if exp["trainer"].get("profiler", False):
81 | exp["trainer"]["profiler"] = AdvancedProfiler(
82 | output_filename=os.path.join(model_path, "profile.out"))
83 | else:
84 | exp["trainer"]["profiler"] = False
85 |
86 | # check if restore checkpoint
87 | if exp["trainer"]["resume_from_checkpoint"] is True:
88 | exp["trainer"]["resume_from_checkpoint"] = exp["general"][
89 | "checkpoint_load"]
90 | else:
91 | del exp["trainer"]["resume_from_checkpoint"]
92 |
93 | if exp["trainer"]["load_from_checkpoint"] is True:
94 | checkpoint = torch.load(exp["general"]["checkpoint_load"])
95 | checkpoint = checkpoint["state_dict"]
96 | # remove any aux classifier stuff
97 | removekeys = [
98 | key for key in checkpoint.keys()
99 | if key.startswith("_model._model.aux_classifier")
100 | ]
101 | for key in removekeys:
102 | del checkpoint[key]
103 | model.load_state_dict(checkpoint, strict=True)
104 |
105 | del exp["trainer"]["load_from_checkpoint"]
106 |
107 | trainer = Trainer(
108 | **exp["trainer"],
109 | plugins=DDPPlugin(find_unused_parameters=False),
110 | default_root_dir=model_path,
111 | callbacks=cb_ls,
112 | logger=logger,
113 | )
114 | ####################################################################################################################
115 | trainer.validate(model=model, datamodule=datamodule)
116 | trainer.test(model=model, datamodule=datamodule)
117 | trainer.fit(model, datamodule=datamodule)
118 | trainer.test(model=model, datamodule=datamodule)
119 |
120 |
121 | if __name__ == "__main__":
122 | os.chdir(ROOT_DIR)
123 |
124 | parser = argparse.ArgumentParser()
125 | parser.add_argument(
126 | "--exp",
127 | default="cfg/exp/finetune/deeplabv3_s0.yml",
128 | help=
129 | "Experiment yaml file path, relative to the repository root directory.",
130 | )
131 | parser.add_argument("--seed", default=123, type=int)
132 | parser.add_argument("--project_name", default="scannet_debug")
133 | parser.add_argument("--prev_exp_name", default="one_step_nerf_only")
134 | args = parser.parse_args()
135 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp)
136 | exp = load_yaml(exp_cfg_path)
137 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env",
138 | os.environ["ENV_WORKSTATION_NAME"] + ".yml")
139 | env = load_yaml(env_cfg_path)
140 |
141 | train(exp, env, exp_cfg_path, env_cfg_path, args)
142 |
--------------------------------------------------------------------------------
/scripts/train_joint.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import torch
5 |
6 | from pathlib import Path
7 | from pytorch_lightning import Trainer, seed_everything
8 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9 | from pytorch_lightning.plugins import DDPPlugin
10 |
11 | from nr4seg import ROOT_DIR
12 | from nr4seg.lightning import JointTrainDataModule, JointTrainLightningNet
13 | from nr4seg.utils import flatten_dict, get_wandb_logger, load_yaml
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "--exp",
20 | default="cfg/exp/finetune/deeplabv3_s0.yml",
21 | help=
22 | ("Experiment yaml file path, relative to the repository root "
23 | "directory."),
24 | )
25 | parser.add_argument(
26 | "--exp_name",
27 | default="debug",
28 | help="Overall name of this continual learning experiment.",
29 | )
30 |
31 | parser.add_argument(
32 | "--fix_nerf",
33 | action="store_true",
34 | help="whether or not to fix nerf during joint training",
35 | )
36 |
37 | parser.add_argument("--seed", default=123, type=int)
38 |
39 | parser.add_argument("--project_name", default="test_one_by_one")
40 | parser.add_argument("--nerf_train_epoch", default=10, type=int)
41 |
42 | parser.add_argument("--joint_train_epoch", default=10, type=int)
43 | args = parser.parse_args()
44 | return args
45 |
46 |
47 | def train(exp, env, exp_cfg_path, env_cfg_path, args) -> float:
48 | seed_everything(args.seed)
49 | exp["exp_name"] = args.exp_name
50 | exp["fix_nerf"] = args.fix_nerf
51 |
52 | # Create experiment folder.
53 | model_path = os.path.join(env["results"], exp["general"]["name"])
54 | if exp["general"]["clean_up_folder_if_exists"]:
55 | shutil.rmtree(model_path, ignore_errors=True)
56 |
57 | # Create the directory
58 | Path(model_path).mkdir(parents=True, exist_ok=True)
59 |
60 | # Copy config files
61 | exp_cfg_fn = os.path.split(exp_cfg_path)[-1]
62 | env_cfg_fn = os.path.split(env_cfg_path)[-1]
63 | print(f"Copy {exp_cfg_path} to {model_path}/{exp_cfg_fn}")
64 | shutil.copy(exp_cfg_path, f"{model_path}/{exp_cfg_fn}")
65 | shutil.copy(env_cfg_path, f"{model_path}/{env_cfg_fn}")
66 | exp["general"]["name"] = model_path
67 |
68 | # Create logger.
69 | logger = get_wandb_logger(
70 | exp=exp,
71 | env=env,
72 | exp_p=exp_cfg_path,
73 | env_p=env_cfg_path,
74 | project_name=args.project_name,
75 | save_dir=model_path,
76 | )
77 | ex = flatten_dict(exp)
78 | logger.log_hyperparams(ex)
79 |
80 | # Create network and dataset.
81 | model = JointTrainLightningNet(exp, env)
82 | datamodule = JointTrainDataModule(exp, env)
83 | datamodule.setup()
84 |
85 | # Trainer setup.
86 | # - Callbacks
87 | lr_monitor = LearningRateMonitor(logging_interval="step")
88 | cb_ls = [lr_monitor]
89 |
90 | checkpoint_callback = ModelCheckpoint(
91 | dirpath=model_path,
92 | save_last=True,
93 | save_top_k=1,
94 | )
95 | cb_ls.append(checkpoint_callback)
96 | # - Set GPUs.
97 | if (exp["trainer"]).get("gpus", -1) == -1:
98 | nr = torch.cuda.device_count()
99 | print(f"Set GPU Count for Trainer to {nr}!")
100 | for i in range(nr):
101 | print(f"Device {i}: ", torch.cuda.get_device_name(i))
102 | exp["trainer"]["gpus"] = nr
103 |
104 | # - Check whether to restore checkpoint.
105 | if exp["trainer"]["resume_from_checkpoint"] is True:
106 | exp["trainer"]["resume_from_checkpoint"] = exp["general"][
107 | "checkpoint_load"]
108 | else:
109 | del exp["trainer"]["resume_from_checkpoint"]
110 |
111 | if exp["trainer"]["load_from_checkpoint"] is True:
112 | if exp["general"]["load_pretrain"]:
113 | checkpoint = torch.load(exp["general"]["checkpoint_load"])
114 | checkpoint = checkpoint["state_dict"]
115 | # remove any aux classifier stuff
116 | removekeys = [
117 | key for key in checkpoint.keys()
118 | if key.startswith("_model._model.aux_classifier")
119 | ]
120 | for key in removekeys:
121 | del checkpoint[key]
122 |
123 | seg_model_state_dict = {}
124 | for key in checkpoint.keys():
125 | seg_model_key = key.split(".", 1)[1]
126 | seg_model_state_dict[seg_model_key] = checkpoint[key]
127 |
128 | model.seg_model.load_state_dict(seg_model_state_dict, strict=True)
129 | else:
130 | checkpoint = torch.load(exp["general"]["checkpoint_load"])
131 | checkpoint = checkpoint["state_dict"]
132 | model.seg_model.load_state_dict(checkpoint)
133 |
134 | del exp["trainer"]["load_from_checkpoint"]
135 |
136 | # - Add distributed plugin.
137 | if exp["trainer"]["gpus"] > 1:
138 | if (exp["trainer"]["accelerator"] == "ddp" or
139 | exp["trainer"]["accelerator"] is None):
140 | ddp_plugin = DDPPlugin(find_unused_parameters=exp["trainer"].get(
141 | "find_unused_parameters", False))
142 | exp["trainer"]["plugins"] = [ddp_plugin]
143 |
144 | exp["trainer"]["max_epochs"] = args.nerf_train_epoch
145 |
146 | trainer_nerf = Trainer(
147 | **exp["trainer"],
148 | default_root_dir=model_path,
149 | logger=logger,
150 | callbacks=cb_ls,
151 | )
152 |
153 | exp["trainer"]["check_val_every_n_epoch"] = 10
154 | exp["trainer"]["max_epochs"] = args.joint_train_epoch
155 | trainer_joint = Trainer(
156 | **exp["trainer"],
157 | default_root_dir=model_path,
158 | logger=logger,
159 | callbacks=cb_ls,
160 | )
161 |
162 | # Train NeRF.
163 | model.joint_train = False
164 | trainer_nerf.fit(model,
165 | train_dataloaders=datamodule.train_dataloader_nerf())
166 | # test initial nerf performance on the training set
167 | trainer_joint.test(model, dataloaders=datamodule.test_dataloader_nerf())
168 | # test initial seg performance on the validation set
169 | trainer_joint.validate(model, dataloaders=datamodule.val_dataloader())
170 | # joint train + old scenes
171 | model.joint_train = True
172 | # trainer_seg.fit(model, train_dataloaders=datamodule.train_dataloader_seg(), val_dataloaders=datamodule.val_dataloader())
173 | trainer_joint.fit(
174 | model,
175 | train_dataloaders=datamodule.train_dataloader_joint(),
176 | val_dataloaders=datamodule.val_dataloader(),
177 | )
178 | # test nerf performance on the training set after joint training + test generalization performance on scannet 25k
179 | trainer_joint.test(model, dataloaders=datamodule.train_dataloader_nerf())
180 | # predict pseudo labels
181 | trainer_joint.predict(model, dataloaders=datamodule.predict_dataloader())
182 | # save checkpoint of the deeplab model
183 | torch.save(
184 | {"state_dict": model.seg_model.state_dict()},
185 | os.path.join(model_path, "deeplab.ckpt"),
186 | )
187 |
188 |
189 | if __name__ == "__main__":
190 | os.chdir(ROOT_DIR)
191 | args = parse_args()
192 | exp_cfg_path = os.path.join(ROOT_DIR, args.exp)
193 | exp = load_yaml(exp_cfg_path)
194 | exp["general"]["load_pretrain"] = True
195 | env_cfg_path = os.path.join(ROOT_DIR, "cfg/env",
196 | os.environ["ENV_WORKSTATION_NAME"] + ".yml")
197 | env = load_yaml(env_cfg_path)
198 | train(exp, env, exp_cfg_path, env_cfg_path, args)
199 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [yapf]
2 | based_on_style = google
3 | indent_width = 4
4 | column_limit = 80
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | from setuptools import find_packages
4 |
5 | setup(
6 | name="nr4seg",
7 | version="0.0.1",
8 | author="Zhizheng Liu, Francesco Milano",
9 | author_email="liuzhi@student.ethz.ch, francesco.milano@mavt.ethz.ch",
10 | packages=find_packages(),
11 | python_requires=">=3.6",
12 | description=
13 | "[CVPR 2023] Unsupervised Continual Semantic Adaptation through Neural Rendering",
14 | )
15 |
--------------------------------------------------------------------------------